├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── documentation.md │ ├── feature-request.md │ └── questions-help-support.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── publish.yml │ └── test.yml ├── .gitignore ├── CHANGELOG.rst ├── CONTRIBUTING.rst ├── LICENSE ├── README.rst ├── audtorch ├── __init__.py ├── collate.py ├── datasets │ ├── __init__.py │ ├── audio_set.py │ ├── base.py │ ├── emodb.py │ ├── libri_speech.py │ ├── mixture.py │ ├── mozilla_common_voice.py │ ├── speech_commands.py │ ├── utils.py │ ├── voxceleb1.py │ └── white_noise.py ├── metrics │ ├── __init__.py │ ├── functional.py │ ├── losses.py │ └── metrics.py ├── samplers.py ├── transforms │ ├── __init__.py │ ├── functional.py │ └── transforms.py └── utils.py ├── docs ├── api-collate.rst ├── api-datasets.rst ├── api-metrics-functional.rst ├── api-metrics.rst ├── api-samplers.rst ├── api-transforms-functional.rst ├── api-transforms.rst ├── api-utils.rst ├── changelog.rst ├── conf.py ├── develop.rst ├── genindex.rst ├── index.rst ├── install.rst ├── refs.bib ├── refs.rst ├── requirements.txt ├── tutorials │ ├── data │ │ └── emodb │ │ │ ├── 03a01Eb.wav │ │ │ └── 03a01Fa.wav │ └── introduction.ipynb └── usage.rst ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── requirements.txt ├── test_collate.py ├── test_datasets.py ├── test_metrics.py ├── test_samplers.py ├── test_transforms.py ├── test_transforms_functional.py └── test_utils.py /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Bug Report" 3 | about: Submit a bug report to help us improve audtorch 4 | 5 | --- 6 | 7 | ### Bug 8 | 9 | Describe the bug. 10 | 11 | 12 | ### Steps to reproduce 13 | 14 | List of steps/code to reproduce the bug: 15 | 16 | 1. 17 | 1. 18 | 19 | 20 | ### Expected behaviour 21 | 22 | What is the expected outcome. 23 | 24 | 25 | ### Environment 26 | 27 | * OS: 28 | * Python version: 29 | * audtorch version: 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Documentation" 3 | about: Report an issue related to https://audtorch.readthedocs.io/ 4 | 5 | --- 6 | 7 | ### Documentation 8 | 9 | Issues or proposal for the documentation at https://audtorch.readthedocs.io/. 10 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Feature Request" 3 | about: Submit a proposal/request for a new audtorch feature 4 | 5 | --- 6 | 7 | ### Feature 8 | 9 | Describe the desired feature. 10 | 11 | 12 | ### Motivation 13 | 14 | Why this feature is a great idea. 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/questions-help-support.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Questions/Help/Support" 3 | about: Do you need support? We have resources. 4 | 5 | --- 6 | 7 | ### Questions and Help 8 | 9 | Ask a question or ask for help. 10 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ### Summary 2 | 3 | Goal of this pull request. 
4 | 5 | 6 | ### Proposed Changes 7 | 8 | List of actual changes. 9 | 10 | 11 | ### Discussion 12 | 13 | Only if needed: 14 | 15 | 1. Point one for discussion 16 | 2. Point two for discussion 17 | 18 | 19 | Please review the [guidelines for contributing](../CONTRIBUTING.md). 20 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' # Push events to matching v*, i.e. v1.0, v20.15.10 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | with: 14 | fetch-depth: 2 15 | - name: Set up Python 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: '3.x' 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install setuptools wheel twine 23 | # PyPI package 24 | - name: Build and publish 25 | env: 26 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 27 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 28 | run: | 29 | python setup.py sdist bdist_wheel 30 | python -m twine upload dist/* 31 | # Docuemntation 32 | - name: Install doc dependencies 33 | run: | 34 | sudo apt-get update 35 | sudo apt-get install pandoc 36 | pip install -r docs/requirements.txt 37 | - name: Build documentation 38 | run: python -m sphinx docs/ docs/_build/ -b html 39 | - name: Deploy documentation to Github pages 40 | uses: peaceiris/actions-gh-pages@v3 41 | with: 42 | github_token: ${{ secrets.GITHUB_TOKEN }} 43 | publish_dir: ./docs/_build 44 | # Github release 45 | - name: Read CHANGELOG 46 | id: changelog 47 | run: | 48 | # Get bullet points from last CHANGELOG entry 49 | CHANGELOG=$(git diff -U0 HEAD^ HEAD | grep '^[+][\* ]' | sed 's/\+//') 50 | # Support for multiline, see 51 | # https://github.com/actions/create-release/pull/11#issuecomment-640071918 52 | CHANGELOG="${CHANGELOG//'%'/'%25'}" 53 | CHANGELOG="${CHANGELOG//$'\n'/'%0A'}" 54 | CHANGELOG="${CHANGELOG//$'\r'/'%0D'}" 55 | echo "Got changelog: $CHANGELOG" 56 | echo "::set-output name=body::$CHANGELOG" 57 | - name: Create release on Github 58 | id: create_release 59 | uses: actions/create-release@v1 60 | env: 61 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 62 | with: 63 | tag_name: ${{ github.ref }} 64 | release_name: Release ${{ github.ref }} 65 | body: ${{ steps.changelog.outputs.body }} 66 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ${{ matrix.os }} 13 | strategy: 14 | matrix: 15 | os: [ ubuntu-latest ] 16 | python-version: [3.6, 3.7, 3.8] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install dependencies 25 | run: | 26 | sudo apt-get update 27 | sudo apt-get install -y sox libsndfile1 28 | python -m pip install --upgrade pip 29 | pip install -r requirements.txt 30 | - name: Test with pytest 31 | run: | 32 | pip install -r tests/requirements.txt 33 | python -m pytest 34 | - name: Upload coverage to Codecov 35 | uses: codecov/codecov-action@v1 36 | with: 37 | token: ${{ secrets.CODECOV_TOKEN }} 38 | file: 
./coverage.xml 39 | if: matrix.os == 'ubuntu-latest' 40 | - name: Test building documentation 41 | run: | 42 | sudo apt-get install -y pandoc 43 | pip install -r docs/requirements.txt 44 | python -m sphinx docs/ docs/_build/ -b html -W -D nbsphinx_execute='always' 45 | python -m sphinx docs/ docs/_build/ -b linkcheck -W 46 | if: matrix.python-version == '3.6' 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .eggs 3 | *.egg-info 4 | build/ 5 | dist/ 6 | *.pyc 7 | coverage.xml 8 | .coverage 9 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | All notable changes to this project will be documented in this file. 5 | 6 | The format is based on `Keep a Changelog`_, 7 | and this project adheres to `Semantic Versioning`_. 8 | 9 | 10 | Version 0.6.4 (2020-11-02) 11 | -------------------------- 12 | 13 | * Fixed: link to documentation on Github pages in Python package 14 | 15 | 16 | Version 0.6.3 (2020-10-30) 17 | -------------------------- 18 | 19 | * Added: use copy-button Sphinx plugin 20 | * Added: links to usage and installation to README 21 | * Changed: use sphinx-audeering-theme 22 | * Changed: update all documentation links to Github pages 23 | 24 | 25 | Version 0.6.2 (2020-10-30) 26 | -------------------------- 27 | 28 | * Fixed: install missing pandoc for publishing documentation 29 | 30 | 31 | Version 0.6.1 (2020-10-30) 32 | -------------------------- 33 | 34 | * Fixed: only install doc dependency for automatic release 35 | 36 | 37 | Version 0.6.0 (2020-10-30) 38 | -------------------------- 39 | 40 | * Added: code coverage 41 | * Added: automatic publishing using Github Actions 42 | * Changed: use Github Actions for testing 43 | * Changed: host documentation as Github pages 44 | * Fixed: use newest librosa version 45 | 46 | Version 0.5.2 (2020-03-03) 47 | -------------------------- 48 | 49 | * Fixed: disable automatic execution of notebook 50 | 51 | 52 | Version 0.5.1 (2020-03-03) 53 | -------------------------- 54 | 55 | * Fixed: execute jupyter notebook on readthedocs 56 | * Fixed: release date of 5.0.0 in CHANGELOG 57 | 58 | 59 | Version 0.5.0 (2020-03-03) 60 | -------------------------- 61 | 62 | * Added: `RandomConvolutionalMix` transform 63 | * Added: `EmoDB` data set 64 | * Added: introduction tutorial 65 | * Added: Python 3.8 support 66 | * Added: ``column_end`` + ``column_start`` to ``CsvDataset`` and 67 | ``PandasDataset`` 68 | * Added: random convolutional mix transform 69 | * Changed: default filename column in data sets is now ``file`` 70 | * Changed: force keyword only arguments 71 | * Fixed: ``stft`` functional example 72 | * Fixed: import of ``librosa`` 73 | * Removed: Python 3.5 support 74 | 75 | 76 | Version 0.4.2 (2019-11-04) 77 | -------------------------- 78 | 79 | * Fixed: critical bug of missing files in wheel package (#60) 80 | 81 | 82 | Version 0.4.1 (2019-10-25) 83 | -------------------------- 84 | 85 | * Fixed: default axis values for Masking transforms (#59) 86 | 87 | 88 | Version 0.4.0 (2019-10-21) 89 | -------------------------- 90 | 91 | * Added: masking transforms in time and frequency domain 92 | 93 | 94 | Version 0.3.2 (2019-10-04) 95 | -------------------------- 96 | 97 | * Fixed: long description in ``setup.cfg`` 98 | 99 | 100 | Version 0.3.1 
(2019-10-04) 101 | -------------------------- 102 | 103 | * Changed: define package in ``setup.cfg`` 104 | 105 | 106 | Version 0.3.0 (2019-09-13) 107 | -------------------------- 108 | 109 | * Added: ``datasets.SpeechCommands`` (#49) 110 | * Removed: ``LogSpectrogram`` (#52) 111 | 112 | 113 | Version 0.2.1 (2019-08-01) 114 | -------------------------- 115 | 116 | * Changed: Remove os.system call for moving files (#43) 117 | * Fixed: Remove broken logos from issue templates (#31) 118 | * Fixed: Wrong ``Spectrogram`` output shape in documentation (#40) 119 | * Fixed: Broken data set loading for relative paths (#33) 120 | 121 | 122 | Version 0.2.0 (2019-06-28) 123 | -------------------------- 124 | 125 | * Added: ``Standardize``, ``Log`` (#29) 126 | * Changed: Switch to `Keep a Changelog`_ format (#34) 127 | * Deprecated: ``LogSpectrogram`` (#29) 128 | * Fixed: ``normalize`` axis (#28) 129 | 130 | 131 | Version 0.1.1 (2019-05-23) 132 | -------------------------- 133 | 134 | * Fixed: Broken API documentation on readthedocs 135 | 136 | 137 | Version 0.1.0 (2019-05-22) 138 | -------------------------- 139 | 140 | * Added: Public release 141 | 142 | 143 | .. _Keep a Changelog: https://keepachangelog.com/en/1.0.0/ 144 | .. _Semantic Versioning: https://semver.org/spec/v2.0.0.html 145 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | Contributing 2 | ============ 3 | 4 | Everyone is invited to contribute to this project. Feel free to create a 5 | `pull request`_. 6 | If you find errors, omissions, inconsistencies or other things that need 7 | improvement, please create an issue_. 8 | 9 | .. _issue: https://github.com/audeering/audtorch/issues/new/ 10 | .. _pull request: https://github.com/audeering/audtorch/compare/ 11 | 12 | 13 | Development Installation 14 | ------------------------ 15 | 16 | Instead of pip-installing the latest release from PyPI_, you should get the 17 | newest development version from Github_:: 18 | 19 | git clone https://github.com/audeering/audtorch/ 20 | cd audtorch 21 | # Create virtual environment, e.g. 22 | # virtualenv --python=python3 _env 23 | # source _env/bin/activate 24 | python setup.py develop 25 | 26 | .. _PyPI: https://pypi.org/project/audtorch/ 27 | .. _Github: https://github.com/audeering/audtorch/ 28 | 29 | This way, your installation always stays up-to-date, even if you pull new 30 | changes from the Github_ repository. 31 | 32 | If you prefer, you can also replace the last command with:: 33 | 34 | pip install -r requirements.txt 35 | 36 | 37 | Pull requests 38 | ------------- 39 | 40 | When creating a new pull request, please consider the following points: 41 | 42 | * Focus on a single topic as it is easier to review short pull requests 43 | * Ensure your code is readable and `PEP 8`_ compatible 44 | * Provide a test for proposed new functionality 45 | * Add a docstring, see the `Writing Documentation` remarks below 46 | * Choose `meaningful commit messages`_ 47 | 48 | .. _PEP 8: https://www.python.org/dev/peps/pep-0008/ 49 | .. _meaningful commit messages: https://chris.beams.io/posts/git-commit/ 50 | 51 | 52 | Writing Documentation 53 | --------------------- 54 | 55 | The API documentation of :mod:`audtorch` is built automatically from the 56 | docstrings_ of its classes and functions. 
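For orientation, a docstring for a small utility function could look like the following sketch (it follows the conventions listed below; the function name ``rms()`` and its body are purely illustrative and not part of :mod:`audtorch`)::

    import numpy as np

    def rms(signal):
        r"""Compute the root mean square of a signal.

        Args:
            signal (numpy.ndarray): audio signal

        Returns:
            float: root mean square of the signal

        Example:
            >>> rms(np.array([3.0, 4.0]))
            3.5355339059327378

        """
        return float(np.sqrt(np.mean(signal ** 2)))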
57 | 58 | docstrings_ are written in reStructuredText_, as indicated by the ``r`` at 59 | their beginning, and they follow the `Google docstring convention`_ 60 | with the following additions: 61 | 62 | * Start argument descriptions in lower case and end the last sentence without 63 | punctuation. 64 | * If the argument is optional, its default value has to be indicated. 65 | * Descriptions of attributes also start in lower case and stop without 66 | punctuation. 67 | * Attributes that can influence the behavior of the class should be described by 68 | the word ``controls``. 69 | * Attributes that are supposed to be read only and provide only information 70 | should be described by the word ``holds``. 71 | * Have a special section for class attributes. 72 | * Python variables should be set in single back ticks in the description of the 73 | docstring, e.g. ```True```. Only for some explicit statements like a list 74 | of variables it might look better to write them as code, e.g. 75 | ```'mean'```. 76 | 77 | The most important part of the docstrings_ is the first line, which holds a short 78 | summary of the functionality; it should not be longer than one line, be written 79 | in the imperative, and end with a period. It is also considered good practice to 80 | include a usage example. 81 | 82 | reStructuredText_ allows for easy inclusion of math in LaTeX syntax that will 83 | be dynamically rendered in the browser. 84 | 85 | After you are happy with your docstring, you have to include it in the main 86 | documentation under the ``docs/`` folder in the appropriate API file. E.g. 87 | ``energy()`` is part of the ``utils`` module and the corresponding file in the 88 | documentation would be ``docs/api-utils.rst``, where it is included. 89 | 90 | .. _docstrings: https://www.python.org/dev/peps/pep-0257/ 91 | .. _reStructuredText: 92 | http://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html 93 | .. _Google docstring convention: 94 | https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html 95 | 96 | 97 | Building Documentation 98 | ---------------------- 99 | 100 | If you make changes to the documentation, you can re-create the HTML pages 101 | using Sphinx_. 102 | You can install it and a few other necessary packages with:: 103 | 104 | pip install -r docs/requirements.txt 105 | 106 | To create the HTML pages, use:: 107 | 108 | sphinx-build docs/ build/sphinx/html/ -b html 109 | 110 | The generated files will be available in the directory ``build/sphinx/html/``. 111 | 112 | It is also possible to automatically check if all links are still valid:: 113 | 114 | sphinx-build docs/ build/sphinx/html/ -b linkcheck 115 | 116 | .. _Sphinx: http://sphinx-doc.org/ 117 | 118 | 119 | Running Tests 120 | ------------- 121 | 122 | You'll need pytest_ and a few dependencies for that. 123 | They can be installed with:: 124 | 125 | pip install -r tests/requirements.txt 126 | 127 | To execute the tests, simply run:: 128 | 129 | pytest 130 | 131 | .. _pytest: https://pytest.org/ 132 | 133 | 134 | Creating a New Release 135 | ---------------------- 136 | 137 | New releases are made using the following steps: 138 | 139 | #. Update ``CHANGELOG.rst`` 140 | #. Commit those changes as "Release X.Y.Z" 141 | #. Create an (annotated) tag with ``git tag -a X.Y.Z`` 142 | #. 
Push the commit and the tag to Github 143 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 audEERING GmbH and Contributors 4 | 5 | Authors: 6 | Andreas Triantafyllopoulos 7 | Stephan Huber 8 | Johannes Wagner 9 | Hagen Wierstorf 10 | 11 | Contributors: 12 | Harri Taylor 13 | 14 | Permission is hereby granted, free of charge, to any person obtaining a copy 15 | of this software and associated documentation files (the "Software"), to deal 16 | in the Software without restriction, including without limitation the rights 17 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 18 | copies of the Software, and to permit persons to whom the Software is 19 | furnished to do so, subject to the following conditions: 20 | 21 | The above copyright notice and this permission notice shall be included in all 22 | copies or substantial portions of the Software. 23 | 24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 26 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 27 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 28 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 29 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 30 | SOFTWARE. 31 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | audtorch 3 | ======== 4 | 5 | |tests| |coverage| |docs| |python-versions| |license| 6 | 7 | Deep learning with PyTorch_ and audio. 8 | 9 | Have a look at the installation_ and usage_ instructions as a starting point. 10 | 11 | If you are interested in PyTorch_ and audio you should also check out the 12 | efforts to integrate more audio directly into PyTorch_: 13 | 14 | * `pytorch/audio`_ 15 | * `keunwoochoi/torchaudio-contrib`_ 16 | 17 | .. _installation: https://audeering.github.io/audtorch/install.html 18 | .. _keunwoochoi/torchaudio-contrib: https://github.com/keunwoochoi/torchaudio-contrib 19 | .. _PyTorch: https://pytorch.org 20 | .. _pytorch/audio: https://github.com/pytorch/audio 21 | .. _usage: https://audeering.github.io/audtorch/usage.html 22 | 23 | .. |tests| image:: https://github.com/audeering/audtorch/workflows/Test/badge.svg 24 | :target: https://github.com/audeering/audtorch/actions?query=workflow%3ATest 25 | :alt: Test status 26 | .. |coverage| image:: https://codecov.io/gh/audeering/audtorch/branch/master/graph/badge.svg?token=PUA9P2UJW1 27 | :target: https://codecov.io/gh/audeering/audtorch/ 28 | :alt: code coverage 29 | .. |docs| image:: https://img.shields.io/pypi/v/audtorch?label=docs 30 |    :target: https://audeering.github.io/audtorch/ 31 |    :alt: audtorch's documentation 32 | .. |python-versions| image:: https://img.shields.io/pypi/pyversions/audtorch.svg 33 | :target: https://pypi.org/project/audtorch/ 34 | :alt: audtorch's supported Python versions 35 | .. 
|license| image:: https://img.shields.io/badge/license-MIT-green.svg 36 | :target: https://github.com/audeering/audtorch/blob/master/LICENSE 37 | :alt: audtorch's MIT license 38 | -------------------------------------------------------------------------------- /audtorch/__init__.py: -------------------------------------------------------------------------------- 1 | from . import collate 2 | from . import datasets 3 | from . import metrics 4 | from . import samplers 5 | from . import transforms 6 | from . import utils 7 | -------------------------------------------------------------------------------- /audtorch/collate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | 4 | 5 | class Collation(object): 6 | r"""Abstract interface for collation classes. 7 | 8 | All other collation classes should subclass it. All subclasses should 9 | override ``__call__``, that executes the actual collate function. 10 | 11 | """ 12 | 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def __call__(self, batch): 17 | raise NotImplementedError("This is an abstract interface for " 18 | "modularizing collate functions." 19 | "Please use one of its subclasses.") 20 | 21 | 22 | class Seq2Seq(Collation): 23 | r"""Pads mini-batches to longest contained sequence for seq2seq-purposes. 24 | 25 | This class pads features and targets to the largest sequence in the batch. 26 | Before padding, length information are extracted from them. 27 | 28 | Note: 29 | The tensors can be sorted in descending order of features' lengths 30 | by enabling :attr:`sort_sequences`. Thereby the requirements of 31 | :py:func:`torch.nn.utils.rnn.pack_padded_sequence` 32 | are anticipated, which is used by recurrent layers. 33 | 34 | * :attr:`sequence_dimensions` holds dimension of sequence in features 35 | and targets 36 | * :attr:`batch_first` controls output shape of features and targets 37 | * :attr:`pad_values` controls values to pad features (targets) with 38 | * :attr:`sort_sequences` controls if sequences are sorted in 39 | descending order of `features`' lengths 40 | 41 | Args: 42 | sequence_dimensions (list of ints): indices representing dimension of 43 | sequence in feature and target tensors. 44 | Position `0` represents sequence dimension of `features`, 45 | position `1` represents sequence dimension of `targets`. 46 | Negative indexing is permitted 47 | batch_first (bool or None, optional): determines output shape of 48 | collate function. If `None`, original shape of 49 | `features` and `targets` is kept with dimension of `batch size` 50 | prepended. See Shape for more information. 51 | Default: `None` 52 | pad_values (list, optional): values to pad shorter sequences with. 53 | Position `0` represents value of `features`, 54 | position `1` represents value of `targets`. Default: `[0, 0]` 55 | sort_sequences (bool, optional): option whether to sort sequences 56 | in descending order of `features`' lengths. Default: `True` 57 | 58 | Shape: 59 | - Input: :math:`(*, S, *)`, where :math:`*` can be any number 60 | of further dimensions except :math:`N` which is the batch size, 61 | and where :math:`S` is the sequence dimension. 62 | - Output: 63 | 64 | - `features`: 65 | 66 | - :math:`(N, *, S, *)` if :attr:`batch_first` is `None`, 67 | i.e. 
the original input shape with :math:`N` prepended 68 | which is the batch size 69 | - :math:`(N, S, *, *)` if :attr:`batch_first` is `True` 70 | - :math:`(S, N, *, *)` if :attr:`batch_first` is `False` 71 | 72 | - `feats_lengths`: :math:`(N,)` 73 | 74 | - `targets`: analogous to `features` 75 | 76 | - `tgt_lengths`: analogous to `feats_lengths` 77 | 78 | Example: 79 | >>> # data format: FS = (feature dimension, sequence dimension) 80 | >>> batch = [[torch.zeros(161, 108), torch.zeros(10)], 81 | ... [torch.zeros(161, 223), torch.zeros(12)]] 82 | >>> collate_fn = Seq2Seq([-1, -1], batch_first=None) 83 | >>> features = collate_fn(batch)[0] 84 | >>> list(features.shape) 85 | [2, 161, 223] 86 | 87 | """ 88 | 89 | def __init__( 90 | self, 91 | sequence_dimensions, 92 | *, 93 | batch_first=None, 94 | pad_values=[0, 0], 95 | sort_sequences=True, 96 | ): 97 | 98 | self.sequence_dimensions = sequence_dimensions 99 | self.batch_first = batch_first 100 | self.pad_values = pad_values 101 | self.sort_sequences = sort_sequences 102 | 103 | def __call__(self, batch): 104 | r"""Collate and pad sequences of mini-batch. 105 | 106 | The output tensor is augmented by the dimension of `batch_size`. 107 | 108 | Args: 109 | batch (list of tuples): contains all samples of a batch. 110 | Each sample is represented by a tuple (`features`, `targets`) 111 | which is returned by data set's __getitem__ method 112 | 113 | Returns: 114 | torch.tensors: `features`, `feature lengths`, `targets` 115 | and `target lengths` in data format according to 116 | :attr:`batch_first`. 117 | 118 | """ 119 | features = [torch.as_tensor(sample[0]) for sample in batch] 120 | features, feats_lengths, sorted_indices = _collate_sequences( 121 | features, self.sequence_dimensions[0], self.pad_values[0], 122 | self.batch_first, self.sort_sequences, []) 123 | 124 | targets = [torch.as_tensor(sample[1]) for sample in batch] 125 | targets, tgt_lengths, _ = _collate_sequences( 126 | targets, self.sequence_dimensions[1], self.pad_values[1], 127 | self.batch_first, self.sort_sequences, sorted_indices) 128 | 129 | return features, feats_lengths, targets, tgt_lengths 130 | 131 | 132 | def _collate_sequences( 133 | sequences, 134 | sequence_dimension, 135 | pad_value, 136 | batch_first, 137 | sort_sequences=True, 138 | sorted_indices=[], 139 | ): 140 | r"""Collate and pad sequences. 141 | 142 | Args: 143 | sequences (list of torch.tensors): contains all samples of a batch 144 | sequence_dimension (int): index representing dimension of sequence 145 | in tensors 146 | batch_first (bool or None): determines output shape of tensors 147 | pad_value (float): value to pad shorter sequences with. 148 | sort_sequences (bool, optional): option whether to sort `sequences`. 149 | Default: `True` 150 | sorted_indices (list of ints, optional): indices to sort sequences 151 | and their lengths in descending order with. 
Default: `[]` 152 | 153 | Returns: 154 | tuple: 155 | 156 | * torch.Tensor: data of sequences in format 157 | according to :attr:`batch_first`, list of sorted indices 158 | * torch.IntTensor: lengths of sequences 159 | * list of int: indices of sequences sorted in descending order 160 | of their lengths 161 | 162 | """ 163 | # handle negative indexing 164 | sequence_dimension = len(sequences[0].shape) + sequence_dimension \ 165 | if sequence_dimension < 0 else sequence_dimension 166 | 167 | # input to `pad_sequence` requires shape `(S, *)` 168 | # swap sequence-dimension to the front 169 | sequences = [t.transpose(0, sequence_dimension) for t in sequences] 170 | 171 | # extract lengths (first dimension of permuted tensor) 172 | lengths = [t.shape[0] for t in sequences] 173 | 174 | if sort_sequences: 175 | # sort sequences and lengths in descending order of `lengths` 176 | if not sorted_indices: 177 | sorted_indices = sorted( 178 | range(len(lengths)), key=lambda i: lengths[i], reverse=True) 179 | 180 | sequences = [sequences[idx] for idx in sorted_indices] 181 | lengths = [lengths[idx] for idx in sorted_indices] 182 | 183 | # pad sequences 184 | sequences = pad_sequence( 185 | sequences=sequences, 186 | batch_first=batch_first if batch_first is not None else True, 187 | padding_value=pad_value) 188 | 189 | if batch_first is None: # recover input data format 190 | sequences = sequences.transpose(1, sequence_dimension + 1) 191 | 192 | else: # recover order of "*"-dimensions 193 | permuted = list(range(len(sequences.shape))) 194 | if sequence_dimension >= 2: 195 | permuted.insert( 196 | 2, permuted.pop(sequence_dimension + 1)) 197 | sequences = sequences.permute(permuted) 198 | 199 | lengths = torch.IntTensor(lengths) 200 | 201 | return sequences, lengths, sorted_indices 202 | -------------------------------------------------------------------------------- /audtorch/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .audio_set import * 2 | from .base import * 3 | from .emodb import * 4 | from .libri_speech import * 5 | from .mixture import * 6 | from .mozilla_common_voice import * 7 | from .speech_commands import * 8 | from .voxceleb1 import * 9 | from .white_noise import * 10 | from .utils import * 11 | -------------------------------------------------------------------------------- /audtorch/datasets/audio_set.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import pandas as pd 5 | 6 | from ..utils import flatten_list 7 | from .base import AudioDataset 8 | from .utils import safe_path 9 | 10 | 11 | __doctest_skip__ = ['*'] 12 | 13 | 14 | class AudioSet(AudioDataset): 15 | r"""A large-scale dataset of manually annotated audio events. 16 | 17 | Open and publicly available data set of audio events from Google: 18 | https://research.google.com/audioset/ 19 | 20 | License: CC BY 4.0 21 | 22 | The categories corresponding to an audio signal are returned as a list, 23 | starting with those included in the top hierarchy of the 24 | `AudioSet ontology`_, followed by those from the second hierarchy and then 25 | all other categories in a random order. 26 | 27 | The signals to be returned can be limited by excluding or including only 28 | certain categories. This is achieved by first including only the desired 29 | categories, estimating all its parent categories and then applying the 30 | exclusion. 31 | 32 | .. 
_AudioSet ontology: https://research.google.com/audioset/ontology/ 33 | 34 | * :attr:`transform` controls the input transform 35 | * :attr:`target_transform` controls the target transform 36 | * :attr:`files` controls the audio files of the data set 37 | * :attr:`targets` controls the corresponding targets 38 | * :attr:`sampling_rate` holds the sampling rate of the returned data 39 | * :attr:`original_sampling_rate` holds the sampling rate of the audio files 40 | of the data set 41 | 42 | Args: 43 | root (str): root directory of dataset 44 | csv_file (str, optional): name of a CSV file located in `root`. Can be 45 | one of `balanced_train_segments.csv`, 46 | `unbalanced_train_segments.csv`, `eval_segments.csv`. 47 | Default: `balanced_train_segments.csv` 48 | include (list of str, optional): list of categories to include. 49 | If `None` all categories are included. Default: `None` 50 | exclude (list of str, optional): list of categories to exclude. 51 | If `None` no category is excluded. Default: `None` 52 | transform (callable, optional): function/transform applied on the 53 | signal. Default: `None` 54 | target_transform (callable, optional): function/transform applied on 55 | the target. Default: `None` 56 | 57 | `AudioSet ontology`_ categories of the two top hierarchies: 58 | 59 | .. code-block:: none 60 | 61 | Human sounds Animal Music 62 | |-Human voice |-Domestic animals, pets |-Musical instrument 63 | |-Whistling |-Livestock, farm |-Music genre 64 | |-Respiratory sounds | animals, working |-Musical concepts 65 | |-Human locomotion | animals |-Music role 66 | |-Digestive \-Wild animals \-Music mood 67 | |-Hands 68 | |-Heart sounds, Sounds of things Natural sounds 69 | | heartbeat |-Vehicle |-Wind 70 | |-Otoacoustic emission |-Engine |-Thunderstorm 71 | \-Human group actions |-Domestic sounds, |-Water 72 | | home sounds \-Fire 73 | Source-ambiguous sounds |-Bell 74 | |-Generic impact sounds |-Alarm Channel, environment 75 | |-Surface contact |-Mechanisms and background 76 | |-Deformable shell |-Tools |-Acoustic environment 77 | |-Onomatopoeia |-Explosion |-Noise 78 | |-Silence |-Wood \-Sound reproduction 79 | \-Other sourceless |-Glass 80 | |-Liquid 81 | |-Miscellaneous sources 82 | \-Specific impact sounds 83 | 84 | Warning: 85 | Some of the recordings in `AudioSet` were captured with `mono` and 86 | others with `stereo` input. The user must be careful to handle this, 87 | e.g. using a transform to adjust number of channels. 
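A minimal sketch of such a channel adjustment (this assumes a downmix
transform such as ``audtorch.transforms.Downmix`` suits your pipeline;
it is an illustration, not part of this data set's API)::

    >>> from audtorch import transforms
    >>> to_mono = transforms.Downmix(1)  # assumption: downmix to one channel
    >>> data = AudioSet(root='/data/AudioSet', transform=to_mono)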
88 | 89 | Example: 90 | >>> import sounddevice as sd 91 | >>> data = AudioSet(root='/data/AudioSet', include=['Thunderstorm']) 92 | >>> print(data) 93 | Dataset AudioSet 94 | Number of data points: 73 95 | Root Location: /data/AudioSet 96 | Sampling Rate: 16000Hz 97 | CSV file: balanced_train_segments.csv 98 | Included categories: ['Thunderstorm'] 99 | >>> signal, target = data[4] 100 | >>> target 101 | ['Natural sounds', 'Thunderstorm', 'Water', 'Rain', 'Thunder'] 102 | >>> sd.play(signal.transpose(), data.sampling_rate) 103 | 104 | """ 105 | 106 | # Categories of the two top hieararchies of AudioSet 107 | # https://research.google.com/audioset/ontology/ 108 | categories = { 109 | 'Human sounds': [ 110 | 'Human voice', 'Whistling', 'Respiratory sounds', 111 | 'Human locomotion', 'Digestive', 'Hands', 112 | 'Heart sounds, heartbeat', 'Otoacoustic emission', 113 | 'Human group actions'], 114 | 'Source-ambiguous sounds': [ 115 | 'Generic impact sounds', 'Surface contact', 'Deformable shell', 116 | 'Onomatopoeia', 'Silence', 'Other sourceless'], 117 | 'Animal': [ 118 | 'Domestic animals, pets', 119 | 'Livestock, farm animals, working animals', 'Wild animals'], 120 | 'Sounds of things': [ 121 | 'Vehicle', 'Engine', 'Domestic sounds, home sounds', 'Bell', 122 | 'Alarm', 'Mechanisms', 'Tools', 'Explosion', 'Wood', 'Glass', 123 | 'Liquid', 'Miscellaneous sources', 'Specific impact sounds'], 124 | 'Music': [ 125 | 'Musical instrument', 'Music genre', 'Musical concepts', 126 | 'Music role', 'Music mood'], 127 | 'Natural sounds': [ 128 | 'Wind', 'Thunderstorm', 'Water', 'Fire'], 129 | 'Channel, environment and background': [ 130 | 'Acoustic environment', 'Noise', 'Sound reproduction'] 131 | } 132 | 133 | def __init__( 134 | self, 135 | root, 136 | *, 137 | csv_file='balanced_train_segments.csv', 138 | sampling_rate=16000, 139 | include=None, 140 | exclude=None, 141 | transform=None, 142 | target_transform=None, 143 | ): 144 | root = safe_path(root) 145 | # Allow only official CSV files as no audio paths are defined otherwise 146 | assert csv_file in ['eval_segments.csv', 'balanced_train_segments.csv', 147 | 'unbalanced_train_segments.csv'] 148 | csv_file = os.path.join(root, csv_file) 149 | 150 | # Load complete ontology 151 | with open(os.path.join(root, 'ontology.json')) as fp: 152 | self.ontology = json.load(fp) 153 | 154 | # Get the desired filenames and categories 155 | df = pd.read_csv(csv_file, skiprows=2, sep=', ', engine='python') 156 | df = self._filename_and_ids(df) 157 | if include is not None: 158 | df = self._filter_by_categories(df, include) 159 | df['ids'] = df['ids'].map(self._add_parent_ids) 160 | if exclude is not None: 161 | df = self._filter_by_categories(df, exclude, exclude_mode=True) 162 | categories = df['ids'].map(self._convert_ids_to_categories) 163 | 164 | audio_folder = os.path.splitext(os.path.basename(csv_file))[0] 165 | files = [os.path.join(audio_folder, f) for f in df['filename']] 166 | 167 | super().__init__( 168 | files=files, 169 | targets=categories, 170 | sampling_rate=16000, 171 | root=root, 172 | transform=transform, 173 | target_transform=target_transform, 174 | ) 175 | self.csv_file = csv_file 176 | self.include = include 177 | self.exclude = exclude 178 | 179 | def _filename_and_ids(self, df): 180 | r"""Return data frame with filenames and IDs. 
181 | 182 | Args: 183 | df (pandas.DataFrame): data frame as read in from the CSV file 184 | 185 | Results: 186 | pandas.DataFrame: data frame with columns `filename` and `ids` 187 | 188 | """ 189 | df.rename(columns={'positive_labels': 'ids'}, inplace=True) 190 | # Translate labels from "label1,label2" to [label1, label2] 191 | df['ids'] = [label.strip('\"').split(',') for label in df['ids']] 192 | # Insert filename 193 | df['filename'] = (df['# YTID'] 194 | + '_' 195 | + [f'{x:.3f}' for x in df['start_seconds']] 196 | + '.wav') 197 | return df[['filename', 'ids']] 198 | 199 | def _add_parent_ids(self, child_ids): 200 | r"""Add all parent IDs to the list of given child IDs. 201 | 202 | Args: 203 | child_ids (list of str): child IDs 204 | 205 | Return: 206 | list of str: list of child and parent IDs 207 | 208 | """ 209 | ids = child_ids 210 | for id in child_ids: 211 | ids += [x['id'] for x in self.ontology if id in x['child_ids']] 212 | # Remove duplicates 213 | return list(set(ids)) 214 | 215 | def _convert_ids_to_categories(self, ids): 216 | r"""Convert list of ids to sorted list of categories. 217 | 218 | Args: 219 | ids (list of str): list of IDs 220 | 221 | Returns: 222 | list of str: list of sorted categories 223 | 224 | """ 225 | # Convert IDs to categories 226 | categories = [] 227 | for id in ids: 228 | categories += [x['name'] for x in self.ontology 229 | if x['id'] == id] 230 | # Order categories after the first two top ontologies 231 | order = [] 232 | first_hierarchy = self.categories.keys() 233 | second_hierarchy = flatten_list(list(self.categories.values())) 234 | for cat in categories: 235 | if cat in first_hierarchy: 236 | order += [0] 237 | elif cat in second_hierarchy: 238 | order += [1] 239 | else: 240 | order += [2] 241 | # Sort list `categories` by the list `order` 242 | categories = [cat for _, cat in sorted(zip(order, categories))] 243 | return categories 244 | 245 | def _filter_by_categories( 246 | self, 247 | df, 248 | categories, 249 | exclude_mode=False, 250 | ): 251 | r"""Return data frame containing only specified categories. 252 | 253 | Args: 254 | df (pandas.DataFrame): data frame containing the columns `ids` 255 | categories (list of str): list of categories to include or exclude 256 | exclude_mode (bool, optional): if `False` the specified categories 257 | should be included in the data frame, otherwise excluded. 258 | Default: `False` 259 | 260 | Returns: 261 | pandas.DataFrame: data frame containing only the desired categories 262 | 263 | """ 264 | ids = self._ids_for_categories(categories) 265 | if exclude_mode: 266 | # Remove rows that have an intersection of actual and desired IDs 267 | df = df[[False if set(row['ids']) & set(ids) else True 268 | for _, row in df.iterrows()]] 269 | else: 270 | # Include rows that have an intersection of actual and desired IDs 271 | df = df[[True if set(row['ids']) & set(ids) else False 272 | for _, row in df.iterrows()]] 273 | df = df.reset_index(drop=True) 274 | return df 275 | 276 | def _ids_for_categories(self, categories): 277 | r"""All IDs and child IDs for a given set of categories. 
278 | 279 | Args: 280 | categories (list of str): list of categories 281 | 282 | Returns: 283 | list: list of IDs 284 | 285 | """ 286 | ids = [] 287 | category_ids = \ 288 | [x['id'] for x in self.ontology if x['name'] in categories] 289 | for category_id in category_ids: 290 | ids += self._subcategory_ids(category_id) 291 | # Remove duplicates 292 | return list(set(ids)) 293 | 294 | def _subcategory_ids(self, parent_id): 295 | r"""Recursively identify all IDs of a given category. 296 | 297 | Args: 298 | parent_id (unicode str): ID of parent category 299 | 300 | Returns: 301 | list: list of all children IDs and the parent ID 302 | 303 | """ 304 | id_list = [parent_id] 305 | child_ids = \ 306 | [x['child_ids'] for x in self.ontology if x['id'] == parent_id] 307 | child_ids = flatten_list(child_ids) 308 | # Add all subcategories 309 | for child_id in child_ids: 310 | id_list += self._subcategory_ids(child_id) 311 | return id_list 312 | 313 | def extra_repr(self): 314 | fmt_str = f' CSV file: {os.path.basename(self.csv_file)}\n' 315 | if self.include: 316 | fmt_str += f' Included categories: {self.include}\n' 317 | if self.exclude: 318 | fmt_str += f' Excluded categories: {self.exclude}\n' 319 | return fmt_str 320 | -------------------------------------------------------------------------------- /audtorch/datasets/emodb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | from typing import Callable 5 | 6 | from audtorch.datasets.base import AudioDataset 7 | from audtorch.datasets.utils import download_url 8 | 9 | 10 | __doctest_skip__ = ['*'] 11 | 12 | 13 | class EmoDB(AudioDataset): 14 | r"""EmoDB data set. 15 | 16 | Open and publicly available data set of acted emotions: 17 | http://www.emodb.bilderbar.info/navi.html 18 | 19 | EmoDB is a small audio data set collected in an anechoic chamber in the 20 | Berlin Institute of Technology, it contains 5 male and 5 female speakers, 21 | consists of 10 unique sentences, and is annotated for 6 emotions plus a 22 | neutral state. The spoken language is German. 23 | 24 | Args: 25 | root: root directory of dataset 26 | transform: function/transform applied on the signal 27 | target_transform: function/transform applied on the target 28 | 29 | Note: 30 | * When using the EmoDB data set in your research, please cite 31 | the following publication: :cite:`burkhardt2005database`. 
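The target returned for each file is the single emotion letter encoded in
its name. If you prefer readable labels, a small ``target_transform`` can
map the codes (a sketch; the letter-to-emotion mapping below follows the
official EmoDB description and should be verified against your copy of
the data)::

    >>> codes = {'W': 'anger', 'L': 'boredom', 'E': 'disgust',
    ...          'A': 'fear', 'F': 'happiness', 'T': 'sadness',
    ...          'N': 'neutral'}
    >>> data = EmoDB('/data/emodb', target_transform=lambda x: codes[x])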
32 | 33 | Example: 34 | >>> import sounddevice as sd 35 | >>> data = EmoDB('/data/emodb') 36 | >>> print(data) 37 | Dataset EmoDB 38 | Number of data points: 465 39 | Root Location: /data/emodb 40 | Sampling Rate: 16000Hz 41 | Labels: emotion 42 | >>> signal, target = data[0] 43 | >>> target 44 | 'A' 45 | >>> sd.play(signal.transpose(), data.sampling_rate) 46 | 47 | """ 48 | url = ('http://www.emodb.bilderbar.info/navi.html') 49 | 50 | def __init__(self, root: str, *, transform: Callable = None, 51 | target_transform: Callable = None): 52 | files = glob.glob(root + '/*.wav') 53 | files = [os.path.basename(x) for x in files] 54 | targets = [x.split('.')[0][-2] for x in files] 55 | super().__init__(root=root, files=files, targets=targets, 56 | transform=transform, 57 | sampling_rate=16000, 58 | target_transform=target_transform) 59 | 60 | def extra_repr(self): 61 | fmt_str = ' Labels: emotion\n' 62 | return fmt_str 63 | -------------------------------------------------------------------------------- /audtorch/datasets/libri_speech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import shutil 4 | 5 | import pandas as pd 6 | 7 | from .utils import (download_url_list, extract_archive, safe_path) 8 | from ..utils import run_worker_threads 9 | from .base import PandasDataset 10 | 11 | 12 | __doctest_skip__ = ['*'] 13 | 14 | 15 | class LibriSpeech(PandasDataset): 16 | r"""`LibriSpeech` speech data set. 17 | 18 | Open and publicly available data set of voices from OpenSLR: 19 | http://www.openslr.org/12/ 20 | 21 | License: CC BY 4.0. 22 | 23 | `LibriSpeech` contains several hundred hours of English speech 24 | with corresponding transcriptions in capital letters without punctuation. 25 | 26 | It is split into different subsets according to WER-level achieved when 27 | performing speech recognition on the speakers. The subsets are: 28 | `train-clean-100`, `train-clean-360`, `train-other-500` `dev-clean`, 29 | `dev-other`, `test-clean`, `test-other` 30 | 31 | * :attr:`root` holds the data set's location 32 | * :attr:`transform` controls the input transform 33 | * :attr:`target_transform` controls the target transform 34 | * :attr:`files` controls the audio files of the data set 35 | * :attr:`labels` controls the corresponding labels 36 | * :attr:`sampling_rate` holds the sampling rate of data set 37 | 38 | In addition, the following class attributes are available 39 | 40 | * :attr:`all_sets` holds the names of the different pre-defined sets 41 | * :attr:`urls` holds the download links of the different sets 42 | 43 | Args: 44 | root (str): root directory of data set 45 | sets (str or list, optional): desired sets of `LibriSpeech`. 46 | Mutually exclusive with :attr:`dataframe`. 47 | Default: `None` 48 | dataframe (pandas.DataFrame, optional): pandas data frame containing 49 | columns `audio_path` (relative to root) and `transcription`. 50 | It can be used to pre-select files based on meta information, 51 | e.g. sequence length. Mutually exclusive with :attr:`sets`. 52 | Default: `None` 53 | transform (callable, optional): function/transform applied on 54 | the signal. Default: `None` 55 | target_transform (callable, optional): function/transform applied on 56 | the target. Default: `None` 57 | download (bool, optional): download data set to root directory 58 | if not present. 
Default: `False` 59 | 60 | Example: 61 | >>> import sounddevice as sd 62 | >>> data = LibriSpeech(root='/data/LibriSpeech', sets='dev-clean') 63 | >>> print(data) 64 | Dataset LibriSpeech 65 | Number of data points: 2703 66 | Root Location: /data/LibriSpeech 67 | Sampling Rate: 16000Hz 68 | Sets: dev-clean 69 | >>> signal, label = data[8] 70 | >>> label 71 | AS FOR ETCHINGS THEY ARE OF TWO KINDS BRITISH AND FOREIGN 72 | >>> sd.play(signal.transpose(), data.sampling_rate) 73 | 74 | """ 75 | 76 | all_sets = ['train-clean-100', 'train-clean-360', 'train-other-500', 77 | 'dev-clean', 'dev-other', 'test-clean', 'test-other'] 78 | urls = { 79 | "train-clean-100": 80 | 'https://openslr.org/resources/12/train-clean-100.tar.gz', 81 | "train-clean-360": 82 | 'https://openslr.org/resources/12/train-clean-360.tar.gz', 83 | "train-other-500": 84 | 'https://openslr.org/resources/12/train-other-500.tar.gz', 85 | "dev-clean": 86 | 'https://openslr.org/resources/12/dev-clean.tar.gz', 87 | "dev-other": 88 | 'https://openslr.org/resources/12/dev-other.tar.gz', 89 | "test-clean": 90 | 'https://openslr.org/resources/12/test-clean.tar.gz', 91 | "test-other": 92 | 'https://openslr.org/resources/12/test-other.tar.gz'} 93 | _transcription = 'transcription' 94 | _audio_path = 'audio_path' 95 | 96 | def __init__( 97 | self, 98 | root, 99 | *, 100 | sets=None, 101 | dataframe=None, 102 | transform=None, 103 | target_transform=None, 104 | download=False, 105 | ): 106 | 107 | self.root = safe_path(root) 108 | 109 | if isinstance(sets, str): 110 | sets = [sets] 111 | if dataframe is None and sets is None: 112 | self.sets = self.all_sets 113 | elif dataframe is None: 114 | assert set(sets) <= set(self.all_sets) 115 | self.sets = sets 116 | elif dataframe is not None: 117 | self.sets = None 118 | else: 119 | raise ValueError('Either `sets` or `dataframe` can be specified.') 120 | 121 | if download: # data not available 122 | self._download() 123 | 124 | if not self._check_exists(): 125 | raise RuntimeError('Requested sets of data set not found.') 126 | 127 | if dataframe is None: 128 | files = self._get_files() 129 | dataframe = self._create_dataframe(files) 130 | 131 | super().__init__( 132 | root=self.root, 133 | sampling_rate=16000, 134 | df=dataframe, 135 | column_filename=self._audio_path, 136 | column_labels=self._transcription, 137 | transform=transform, 138 | target_transform=target_transform) 139 | 140 | def _check_exists(self): 141 | return all([os.path.exists(os.path.join(self.root, s)) 142 | for s in self.sets]) 143 | 144 | def _download(self): 145 | absent_sets = [s for s in self.sets 146 | if not os.path.exists(os.path.join(self.root, s))] 147 | if not absent_sets: 148 | return 149 | 150 | out_path = os.path.join(self.root, "tmp") 151 | if not os.path.exists(out_path): 152 | os.makedirs(out_path) 153 | 154 | urls = [self.urls[s] for s in absent_sets] 155 | filenames = download_url_list(urls, out_path, num_workers=0) 156 | for filename in filenames: 157 | extract_archive(os.path.join(out_path, filename), 158 | out_path=out_path, 159 | remove_finished=True) 160 | contents = glob.glob(os.path.join(out_path, 'LibriSpeech/*')) 161 | for f in contents: 162 | shutil.move(f, self.root) 163 | os.rmdir(os.path.join(out_path, "LibriSpeech")) 164 | os.rmdir(out_path) 165 | 166 | def _get_files(self): 167 | files = [] 168 | for set in self.sets: 169 | path_to_files = os.path.join(self.root, set, '**/**/*.flac') 170 | files += glob.glob(path_to_files) 171 | return files 172 | 173 | @classmethod 174 | def 
_create_dataframe(cls, files): 175 | 176 | def _create_df_per_txt(txt_file): 177 | _set = txt_file.rsplit('/', 4)[-4] 178 | df_per_txt = pd.read_csv(txt_file, names=[cls._audio_path]) 179 | 180 | # split content once with delimiter 181 | df_per_txt[[cls._audio_path, cls._transcription]] = \ 182 | df_per_txt[cls._audio_path].str.split(" ", 1, expand=True) 183 | 184 | # compose audio paths relative to root 185 | df_per_txt[cls._audio_path] = df_per_txt[cls._audio_path].apply( 186 | _compose_relative_audio_paths(_set)) 187 | return df_per_txt 188 | 189 | def _compose_relative_audio_paths(_set): 190 | return lambda row: os.path.join( 191 | _set, *row.split('-')[:-1], row + '.flac') 192 | 193 | # get absolute paths to txt files from audio files 194 | txt_files = sorted(list(set( 195 | [f.rsplit('-', 1)[0] + '.trans.txt' for f in files]))) 196 | args = [(f, ) for f in txt_files] 197 | dataframes = run_worker_threads(12, _create_df_per_txt, args) 198 | return pd.concat(dataframes, ignore_index=True) 199 | 200 | def extra_repr(self): 201 | if self.sets is not None: 202 | fmt_str = f' Sets: {", ".join(self.sets)}\n' 203 | return fmt_str 204 | -------------------------------------------------------------------------------- /audtorch/datasets/mixture.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | from .base import _include_repr 7 | from .utils import ensure_same_sampling_rate 8 | 9 | 10 | __doctest_skip__ = ['*'] 11 | 12 | 13 | class SpeechNoiseMix(Dataset): 14 | r"""Mix speech and noise with speech as target. 15 | 16 | Add noise to each speech sample from the provided data by a 17 | mix transform. Return the mix as input and the speech signal as 18 | corresponding target. In addition, allow to replace randomly some of the 19 | mixes by noise as input and silence as output. This helps to train a speech 20 | enhancement algorithm to deal with background noise only as input signal 21 | :cite:`Rethage2018`. 22 | 23 | * :attr:`speech_dataset` controls the speech data set 24 | * :attr:`mix_transform` controls the transform that adds noise 25 | * :attr:`transform` controls the transform applied on the mix 26 | * :attr:`target_transform` controls the transform applied on the target 27 | clean speech 28 | * :attr:`joint_transform` controls the transform applied jointly on the 29 | mixture an the target clean speech 30 | * :attr:`percentage_silence` controls the amount of noise-silent data 31 | augmentation 32 | 33 | Args: 34 | speech_dataset (Dataset): speech data set 35 | mix_transform (callable): function/transform that can augment a signal 36 | with noise 37 | transform (callable, optional): function/transform applied on the 38 | speech-noise-mixture (input) only. Default; `None` 39 | target_transform (callable, optional): function/transform applied 40 | on the speech (target) only. Default: `None` 41 | joint_transform (callable, optional): function/transform applied 42 | on the mixtue (input) and speech (target) simultaneously. If the 43 | transform includes randomization it is applied with the same random 44 | parameter during both calls 45 | percentage_silence (float, optional): value between `0` and `1`, which 46 | controls the percentage of randomly inserted noise input, silent 47 | target pairs. 
Default: `0` 48 | 49 | Examples: 50 | >>> import sounddevice as sd 51 | >>> from audtorch import datasets, transforms 52 | >>> noise = datasets.WhiteNoise(duration=10, sampling_rate=48000) 53 | >>> mix = transforms.RandomAdditiveMix(noise) 54 | >>> normalize = transforms.Normalize() 55 | >>> speech = datasets.MozillaCommonVoice(root='/data/MozillaCommonVoice/cv_corpus_v1') 56 | >>> data = SpeechNoiseMix(speech, mix, transform=normalize) 57 | >>> noisy, clean = data[0] 58 | >>> sd.play(noisy.transpose(), data.sampling_rate) 59 | 60 | """ # noqa: E501 61 | def __init__( 62 | self, 63 | speech_dataset, 64 | mix_transform, 65 | *, 66 | transform=None, 67 | target_transform=None, 68 | joint_transform=None, 69 | percentage_silence=0, 70 | ): 71 | super().__init__() 72 | self.speech_dataset = speech_dataset 73 | self.mix_transform = mix_transform 74 | self.transform = transform 75 | self.target_transform = target_transform 76 | self.joint_transform = joint_transform 77 | self.percentage_silence = percentage_silence 78 | 79 | if not (0 <= self.percentage_silence <= 1): 80 | raise ValueError('`percentage_silence` needs to be in [0, 1]`') 81 | 82 | if hasattr(mix_transform, 'dataset'): 83 | ensure_same_sampling_rate([speech_dataset, mix_transform.dataset]) 84 | 85 | def __len__(self): 86 | return len(self.speech_dataset) 87 | 88 | def __getitem__(self, item): 89 | # [0] ensures that we get only data, no targets 90 | speech = self.speech_dataset[item][0] 91 | # Randomly add (noise, silence) as (input, target) 92 | if random.random() < self.percentage_silence: 93 | speech = np.zeros(speech.shape) 94 | mixture = self.mix_transform(speech) 95 | 96 | if self.joint_transform is not None: 97 | randomness = getattr(self.joint_transform, 'fix_randomization', 98 | None) 99 | mixture = self.joint_transform(mixture) 100 | if randomness is not None: 101 | self.joint_transform.fix_randomization = True 102 | speech = self.joint_transform(speech) 103 | if randomness is not None: 104 | self.joint_transform.fix_randomization = randomness 105 | 106 | if self.transform is not None: 107 | mixture = self.transform(mixture) 108 | 109 | if self.target_transform is not None: 110 | speech = self.target_transform(speech) 111 | 112 | # input, target 113 | return mixture, speech 114 | 115 | @property 116 | def sampling_rate(self): 117 | return self.speech_dataset.sampling_rate 118 | 119 | def __repr__(self): 120 | speech_dataset_name = self.speech_dataset.__class__.__name__ 121 | fmt_str = f'Dataset {self.__class__.__name__}\n' 122 | fmt_str += f' Number of data points: {self.__len__()}\n' 123 | fmt_str += f' Speech dataset: {speech_dataset_name}\n' 124 | fmt_str += f' Sampling rate: {self.sampling_rate}Hz\n' 125 | if self.percentage_silence > 0: 126 | fmt_str += ( 127 | f' Silence augmentation: ' 128 | f'{100 * self.percentage_silence:.0f}%\n' 129 | ) 130 | fmt_str += ' Labels: speech signal\n' 131 | fmt_str += _include_repr('Mixing Transform', self.mix_transform) 132 | if self.transform: 133 | fmt_str += _include_repr('Transform', self.transform) 134 | if self.target_transform: 135 | fmt_str += _include_repr('Target Transform', self.target_transform) 136 | if self.joint_transform: 137 | fmt_str += _include_repr('Joint Transform', self.joint_transform) 138 | return fmt_str 139 | -------------------------------------------------------------------------------- /audtorch/datasets/mozilla_common_voice.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .utils import 
(download_url, extract_archive, safe_path) 4 | from .base import CsvDataset 5 | 6 | 7 | __doctest_skip__ = ['*'] 8 | 9 | 10 | class MozillaCommonVoice(CsvDataset): 11 | """Mozilla Common Voice speech data set. 12 | 13 | Open and publicly available data set of voices from Mozilla: 14 | https://voice.mozilla.org/en/datasets 15 | 16 | License: CC-0 (public domain) 17 | 18 | Mozilla Common Voice includes the labels `text`, `up_votes`, 19 | `down_votes`, `age`, `gender`, `accent`, `duration`. You can select one of 20 | those labels which is returned as a string by the data set as target or you 21 | can specify a list of the labels and the data set will return a dictionary 22 | containing those labels. The default label that is returned is `text`. 23 | 24 | * :attr:`root` holds the data set's location 25 | * :attr:`transform` controls the input transform 26 | * :attr:`target_transform` controls the target transform 27 | * :attr:`files` controls the audio files of the data set 28 | * :attr:`targets` controls the corresponding targets 29 | * :attr:`sampling_rate` holds the sampling rate of the returned data 30 | * :attr:`original_sampling_rate` holds the sampling rate of the audio files 31 | of the data set 32 | 33 | In addition, the following class attribute is available 34 | 35 | * :attr:`url` holds the download link of the data set 36 | 37 | Args: 38 | root (str): root directory of data set, where the CSV files are 39 | located, e.g. `/data/MozillaCommonVoice/cv_corpus_v1` 40 | csv_file (str, optional): name of a CSV file from the `root` 41 | folder. No absolute path is possible. You are most probably 42 | interested in `cv-valid-train.csv`, `cv-valid-dev.csv`, and 43 | `cv-valid-test.csv`. Default: `cv-valid-train.csv`. 44 | label_type (str or list of str, optional): one of `text`, `up_votes`, 45 | `down_votes`, `age`, `gender`, `accent`, `duration`. Or a list of 46 | any combination of those. Default: `text` 47 | transform (callable, optional): function/transform applied on the 48 | signal. Default: `None` 49 | target_transform (callable, optional): function/transform applied on 50 | the target. Default: `None` 51 | download (bool, optional): download data set if not present. 52 | Default: `False` 53 | 54 | Note: 55 | The Mozilla Common Voice data set is constantly growing. If you 56 | choose to download it, it will always grep the latest version. If 57 | you require reproducibility of your results, make sure to store a 58 | safe snapshot of the version you used. 
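For instance, to receive both transcription and gender as a dictionary
target (a sketch based on the ``label_type`` description above)::

    >>> data = MozillaCommonVoice(
    ...     '/data/MozillaCommonVoice/cv_corpus_v1',
    ...     label_type=['text', 'gender'],
    ... )
    >>> signal, target = data[0]  # target is a dict with keys 'text' and 'gender'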
59 | 60 | Example: 61 | >>> import sounddevice as sd 62 | >>> data = MozillaCommonVoice('/data/MozillaCommonVoice/cv_corpus_v1') 63 | >>> print(data) 64 | Dataset MozillaCommonVoice 65 | Number of data points: 195776 66 | Root Location: /data/MozillaCommonVoice/cv_corpus_v1 67 | Sampling Rate: 48000Hz 68 | Labels: text 69 | CSV file: cv-valid-train.csv 70 | >>> signal, target = data[0] 71 | >>> target 72 | 'learn to recognize omens and follow them the old king had said' 73 | >>> sd.play(signal.transpose(), data.sampling_rate) 74 | 75 | """ # noqa: E501 76 | 77 | url = ('https://common-voice-data-download.s3.amazonaws.com/' 78 | 'cv_corpus_v1.tar.gz') 79 | 80 | def __init__( 81 | self, 82 | root, 83 | *, 84 | csv_file='cv-valid-train.csv', 85 | label_type='text', 86 | transform=None, 87 | target_transform=None, 88 | download=False, 89 | ): 90 | 91 | self.root = safe_path(root) 92 | csv_file = os.path.join(root, csv_file) 93 | 94 | if download: 95 | self._download() 96 | 97 | super().__init__( 98 | csv_file=csv_file, 99 | sampling_rate=48000, 100 | root=root, 101 | sep=',', 102 | column_labels=label_type, 103 | column_filename='filename', 104 | transform=transform, 105 | target_transform=target_transform, 106 | ) 107 | 108 | def _download(self): 109 | if self._check_exists(): 110 | return 111 | download_dir = self.root 112 | corpus = 'cv_corpus_v1' 113 | if download_dir.endswith(corpus): 114 | download_dir = download_dir[:-len(corpus)] 115 | filename = download_url(self.url, download_dir) 116 | extract_archive(filename, remove_finished=True) 117 | -------------------------------------------------------------------------------- /audtorch/datasets/speech_commands.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from warnings import warn 4 | 5 | import resampy 6 | from .utils import (download_url, extract_archive, safe_path, load) 7 | from .base import AudioDataset 8 | from ..transforms import RandomCrop 9 | from os.path import join 10 | 11 | __doctest_skip__ = ['*'] 12 | 13 | 14 | class SpeechCommands(AudioDataset): 15 | r"""Data set of spoken words designed for keyword spotting tasks. 16 | 17 | Speech Commands V2 publicly available from Google: 18 | http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz 19 | 20 | License: CC BY 4.0 21 | 22 | Args: 23 | root (str): root directory of data set, 24 | where the CSV files are located, 25 | e.g. `/data/speech_commands_v0.02` 26 | train (bool, optional): Partition the dataset into the training set. 27 | `False` returns the test split. 28 | Default: `False` 29 | download (bool, optional): Download the dataset to `root` 30 | if it's not already available. 31 | Default: `False` 32 | include (str, or list of str, optional): commands to include 33 | as 'recognised' words. 34 | Options: `"10cmd"`, `"full"`. 35 | A custom dataset can be defined using a list of command words. 36 | For example, `["stop","go"]`. 37 | Words that are not in the "include" list 38 | are treated as unknown words. 39 | Default: `'10cmd'` 40 | silence (bool, optional): include a 'silence' class composed of 41 | background noise (Note: use randomcrop when training). 42 | Default: `True` 43 | transform (callable, optional): function/transform applied on the 44 | signal. 45 | Default: `None` 46 | target_transform (callable, optional): function/transform applied on 47 | the target. 
48 | Default: `None` 49 | 50 | Example: 51 | >>> import sounddevice as sd 52 | >>> data = SpeechCommands(root='/data/speech_commands_v0.02') 53 | >>> print(data) 54 | Dataset SpeechCommands 55 | Number of data points: 97524 56 | Root Location: /data/speech_commands_v0.02 57 | Sampling Rate: 16000Hz 58 | >>> signal, target = data[4] 59 | >>> target 60 | 'right' 61 | >>> sd.play(signal.transpose(), data.sampling_rate) 62 | """ 63 | 64 | url = ('http://download.tensorflow.org/' 65 | 'data/speech_commands_v0.02.tar.gz') 66 | 67 | # Available target commands 68 | classes = [ 69 | 'right', 'eight', 'cat', 'tree', 'backward', 70 | 'learn', 'bed', 'happy', 'go', 'dog', 'no', 71 | 'wow', 'follow', 'nine', 'left', 'stop', 'three', 72 | 'sheila', 'one', 'bird', 'zero', 'seven', 'up', 73 | 'visual', 'marvin', 'two', 'house', 'down', 'six', 74 | 'yes', 'on', 'five', 'forward', 'off', 'four'] 75 | 76 | partitions = { 77 | # https://arxiv.org/pdf/1710.06554.pdf 78 | '10cmd': ['yes', 'no', 'up', 'down', 'left', 79 | 'right', 'on', 'off', 'stop', 'go'], 80 | 'full': classes 81 | } 82 | 83 | def __init__( 84 | self, 85 | root, 86 | train=True, 87 | download=False, 88 | *, 89 | sampling_rate=16000, 90 | include='10cmd', 91 | transform=None, 92 | target_transform=None, 93 | ): 94 | self.root = safe_path(root) 95 | self.same_length = False 96 | self.silence_label = -1 97 | self.trim = RandomCrop(sampling_rate) 98 | 99 | if download: 100 | self._download() 101 | 102 | if type(include) is not list: 103 | include = self.partitions[include] 104 | 105 | if not set(include) == set(self.classes): 106 | include.append("_unknown_") 107 | 108 | with open(safe_path(join(self.root, 'testing_list.txt'))) as f: 109 | test_files = f.read().splitlines() 110 | 111 | files, targets = [], [] 112 | for speech_cmd in self.classes: 113 | d = os.listdir(join(self.root, speech_cmd)) 114 | d = [join(speech_cmd, x) for x in d] 115 | 116 | # Filter out test / train files using `testing_list.txt` 117 | d_f = list(set(d) - set(test_files)) \ 118 | if train else list(set(d) & set(test_files)) 119 | 120 | files.extend([join(self.root, p) for p in d_f]) 121 | target = speech_cmd if speech_cmd in include else '_unknown_' 122 | # speech commands is a classification dataset, so return logits 123 | targets.extend([include.index(target) for _ in range(len(d_f))]) 124 | 125 | self.silence_label = len(include) 126 | 127 | # Match occurrences of silence with `unknown` 128 | # if silence: 129 | # n_samples = max(targets.count(len(include) - 1), 3000) 130 | # n_samples = int(n_samples * 0.9) \ 131 | # if train else int(n_samples * 0.1) 132 | # 133 | # sf = [] 134 | # for file in os.listdir(join(self.root, '_background_noise_')): 135 | # if file.endswith('.wav'): 136 | # sf.append(join(self.root, '_background_noise_', file)) 137 | # 138 | # targets.extend([len(include) for _ in range(n_samples)]) 139 | # files.extend(random.choices(sf, k=n_samples)) 140 | 141 | super().__init__( 142 | files=files, 143 | targets=targets, 144 | sampling_rate=sampling_rate, 145 | root=root, 146 | transform=transform, 147 | target_transform=target_transform, 148 | ) 149 | 150 | def add_silence( 151 | self, 152 | n_samples=3000, 153 | same_length=True, 154 | ): 155 | # https://github.com/audeering/audtorch/pull/49#discussion_r317489141 156 | self.same_length = same_length 157 | self.targets.extend([self.silence_label for _ in range(n_samples)]) 158 | 159 | bg_noises = [] 160 | for file in os.listdir(join(self.root, '_background_noise_')): 161 | if file.endswith('.wav'): 162 
| bg_noises.append(join(self.root, '_background_noise_', file)) 163 | 164 | self.files.extend(random.choices(bg_noises, k=n_samples)) 165 | 166 | def __getitem__(self, index): 167 | signal, signal_sampling_rate = load(self.files[index]) 168 | # Handle empty signals 169 | if signal.shape[1] == 0: 170 | warn('Returning previous file.', UserWarning) 171 | return self.__getitem__(index - 1) 172 | # Handle different sampling rate 173 | if signal_sampling_rate != self.original_sampling_rate: 174 | warn( 175 | (f'Resample from {signal_sampling_rate} ' 176 | f'to {self.original_sampling_rate}'), 177 | UserWarning, 178 | ) 179 | signal = resampy.resample(signal, signal_sampling_rate, 180 | self.original_sampling_rate, axis=-1) 181 | 182 | target = self.targets[index] 183 | 184 | # https://github.com/audeering/audtorch/pull/49#discussion_r319044362 185 | if target == self.silence_label and self.same_length: 186 | signal = self.trim(signal) 187 | 188 | if self.transform is not None: 189 | signal = self.transform(signal) 190 | 191 | if self.target_transform is not None: 192 | target = self.target_transform(target) 193 | 194 | return signal, target 195 | 196 | def _download(self): 197 | if self._check_exists(): 198 | return 199 | download_dir = self.root 200 | corpus = 'speech_commands_v0.02' 201 | if download_dir.endswith(corpus): 202 | download_dir = download_dir[:-len(corpus)] 203 | filename = download_url(self.url, download_dir) 204 | extract_archive(filename, out_path=self.root, remove_finished=True) 205 | -------------------------------------------------------------------------------- /audtorch/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from warnings import warn 4 | import urllib 5 | import tarfile 6 | 7 | from tqdm import tqdm 8 | import numpy as np 9 | import audiofile as af 10 | from torch.utils.data import Subset 11 | 12 | from ..utils import run_worker_threads 13 | 14 | 15 | __doctest_skip__ = ['load'] 16 | 17 | 18 | def load( 19 | filename, 20 | *, 21 | duration=None, 22 | offset=0, 23 | ): 24 | r"""Load audio file. 25 | 26 | If an error occurrs during loading as the file could not be found, 27 | is empty, or has the wrong format an empty signal is returned and a warning 28 | shown. 29 | 30 | Args: 31 | file (str or int or file-like object): file name of input audio file 32 | duration (float, optional): return only a specified duration in 33 | seconds. Default: `None` 34 | offset (float, optional): start reading at offset in seconds. 
35 | Default: `0` 36 | 37 | Returns: 38 | tuple: 39 | 40 | * **numpy.ndarray**: two-dimensional array with shape 41 | `(channels, samples)` 42 | * **int**: sample rate of the audio file 43 | 44 | Example: 45 | >>> signal, sampling_rate = load('speech.wav') 46 | 47 | """ 48 | signal = np.array([[]]) # empty signal of shape (1, 0) 49 | sampling_rate = None 50 | try: 51 | signal, sampling_rate = af.read(filename, 52 | duration=duration, 53 | offset=offset, 54 | always_2d=True) 55 | except ValueError: 56 | warn(f'File opening error for: {filename}', UserWarning) 57 | except (IOError, FileNotFoundError): 58 | warn(f'File does not exist: {filename}', UserWarning) 59 | except RuntimeError: 60 | warn(f'Runtime error for file: {filename}', UserWarning) 61 | except subprocess.CalledProcessError: 62 | warn(f'ffmpeg conversion failed for: {filename}', UserWarning) 63 | return signal, sampling_rate 64 | 65 | 66 | def download_url( 67 | url, 68 | root, 69 | *, 70 | filename=None, 71 | md5=None, 72 | ): 73 | r"""Download a file from an url to a specified directory. 74 | 75 | Args: 76 | url (str): URL to download file from 77 | root (str): directory to place downloaded file in 78 | filename (str, optional): name to save the file under. 79 | If `None`, use basename of URL. Default: `None` 80 | md5 (str, optional): MD5 checksum of the download. 81 | If None, do not check. Default: `None` 82 | 83 | Returns: 84 | str: path to downloaded file 85 | 86 | """ 87 | root = safe_path(root) 88 | if not filename: 89 | filename = os.path.basename(url) 90 | filename = os.path.join(root, filename) 91 | 92 | os.makedirs(root, exist_ok=True) 93 | 94 | # downloads file 95 | if not os.path.isfile(filename): 96 | bar_updater = _gen_bar_updater(tqdm(unit='B', unit_scale=True)) 97 | try: 98 | print('Downloading ' + url + ' to ' + filename) 99 | urllib.request.urlretrieve(url, filename, reporthook=bar_updater) 100 | except OSError: 101 | if url[:5] == 'https': 102 | url = url.replace('https:', 'http:') 103 | print('Failed download. Trying https -> http instead.' 104 | ' Downloading ' + url + ' to ' + filename) 105 | urllib.request.urlretrieve(url, filename, 106 | reporthook=bar_updater) 107 | return safe_path(filename) 108 | 109 | 110 | def download_url_list( 111 | urls, 112 | root, 113 | *, 114 | num_workers=0, 115 | ): 116 | r"""Download files from a list of URLs to a specified directory. 117 | 118 | Args: 119 | urls (list of str or dict): either list of URLs or dictionary 120 | with URLs as keys and with either filenames or tuples of 121 | filename and MD5 checksum as values. Uses basename of URL if 122 | filename is `None`. Performs no check if MD5 checksum is `None` 123 | root (str): directory to place downloaded files in 124 | num_workers (int, optional): number of worker threads 125 | (0 = len(urls)). Default: `0` 126 | 127 | """ 128 | # always convert to dict 129 | if type(urls) is list: 130 | urls = {x: None for x in urls} 131 | 132 | # download file and extract 133 | def _task(url, filename): 134 | md5 = None 135 | if type(filename) is tuple: 136 | filename, md5 = filename 137 | return download_url(url, root, filename=filename, md5=md5) 138 | 139 | # start workers 140 | params = [(url, filename) for url, filename in urls.items()] 141 | return run_worker_threads(num_workers, _task, params) 142 | 143 | 144 | def extract_archive( 145 | filename, 146 | *, 147 | out_path=None, 148 | remove_finished=False, 149 | ): 150 | r"""Extract archive. 151 | 152 | Currently `tar.gz` and `tar` archives are supported. 
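A minimal usage sketch together with :func:`download_url` (the URL and paths are placeholders, hence the doctest skip): >>> archive = download_url('https://example.com/corpus.tar.gz', '/data')  # doctest: +SKIP >>> extract_archive(archive, remove_finished=True)  # doctest: +SKIP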
153 | 154 | Args: 155 | filename (str): path to archive 156 | out_path (str, optional): extract archive in this folder. 157 | Default: folder where archive is located in 158 | remove_finished (bool, optional): if `True` remove archive after 159 | extraction. Default: `False` 160 | 161 | """ 162 | print(f'Extracting {filename}') 163 | if out_path is None: 164 | out_path = os.path.dirname(filename) 165 | if filename.endswith('tar.gz'): 166 | tar = tarfile.open(filename, 'r:gz') 167 | elif filename.endswith('tar'): 168 | tar = tarfile.open(filename, 'r:') 169 | else: 170 | raise RuntimeError('Archive format not supported.') 171 | tar.extractall(path=out_path) 172 | tar.close() 173 | if remove_finished: 174 | os.unlink(filename) 175 | 176 | 177 | def sampling_rate_after_transform( 178 | dataset, 179 | ): 180 | r"""Sampling rate of data set after all transforms are applied. 181 | 182 | A change of sampling rate by a transform is only recognized, if that 183 | transform has the attribute :attr:`output_sampling_rate`. 184 | 185 | Args: 186 | dataset (torch.utils.data.Dataset): data set with `sampling_rate` 187 | attribute or property 188 | 189 | Returns: 190 | int: sampling rate in Hz after all transforms are applied 191 | 192 | Example: 193 | >>> from audtorch import datasets, transforms 194 | >>> t = transforms.Resample(input_sampling_rate=16000, 195 | ... output_sampling_rate=8000) 196 | >>> data = datasets.WhiteNoise(sampling_rate=16000, transform=t) 197 | >>> sampling_rate_after_transform(data) 198 | 8000 199 | 200 | """ 201 | sampling_rate = dataset.original_sampling_rate 202 | try: 203 | # List of composed transforms 204 | transforms = dataset.transform.transforms 205 | except AttributeError: 206 | # Single transform 207 | transforms = [dataset.transform] 208 | for transform in transforms: 209 | if hasattr(transform, 'output_sampling_rate'): 210 | sampling_rate = transform.output_sampling_rate 211 | return sampling_rate 212 | 213 | 214 | def ensure_same_sampling_rate( 215 | datasets, 216 | ): 217 | r"""Raise error if provided data set differ in sampling rate. 218 | 219 | All data sets that are checked need to have a `sampling_rate` attribute or 220 | property. 221 | 222 | Args: 223 | datasets (list of torch.utils.data.Dataset): list of at least two audio 224 | data sets. 225 | 226 | """ 227 | for dataset in datasets: 228 | if not hasattr(dataset, 'sampling_rate'): 229 | raise RuntimeError( 230 | f"{dataset} doesn't have a `sampling_rate` attribute." 231 | ) 232 | for n in range(1, len(datasets)): 233 | if datasets[0].sampling_rate != datasets[n].sampling_rate: 234 | error_msg = 'Sampling rates do not match:\n' 235 | for dataset in datasets: 236 | info = dataset.__repr__() 237 | error_msg += f'{dataset.sampling_rate}Hz from {info}' 238 | raise ValueError(error_msg) 239 | 240 | 241 | def ensure_df_columns_contain( 242 | df, 243 | labels, 244 | ): 245 | r"""Raise error if list of labels are not in dataframe columns. 
246 | 247 | Args: 248 | df (pandas.dataframe): data frame 249 | labels (list of str): labels to be expected in `df.columns` 250 | 251 | Example: 252 | >>> import pandas as pd 253 | >>> df = pd.DataFrame(data=[(1, 2)], columns=['a', 'b']) 254 | >>> ensure_df_columns_contain(df, ['a', 'c']) 255 | Traceback (most recent call last): 256 | RuntimeError: Dataframe contains only these columns: 'a, b' 257 | 258 | """ 259 | ensure_df_not_empty(df) 260 | if labels is not None and not set(labels) <= set(df.columns): 261 | raise RuntimeError( 262 | f"Dataframe contains only these columns: '{', '.join(df.columns)}'" 263 | ) 264 | 265 | 266 | def ensure_df_not_empty( 267 | df, 268 | labels=None, 269 | ): 270 | r"""Raise error if dataframe is empty. 271 | 272 | Args: 273 | df (pandas.dataframe): data frame 274 | labels (list of str, optional): list of labels used to shrink data 275 | set. Default: `None` 276 | 277 | Example: 278 | >>> import pandas as pd 279 | >>> df = pd.DataFrame() 280 | >>> ensure_df_not_empty(df) 281 | Traceback (most recent call last): 282 | RuntimeError: No valid data points found in data set 283 | 284 | """ 285 | error_message = 'No valid data points found in data set' 286 | if labels is not None: 287 | error_message += f" for the selected labels: {', '.join(labels)}" 288 | if len(df) == 0: 289 | raise RuntimeError(error_message) 290 | 291 | 292 | def files_and_labels_from_df( 293 | df, 294 | *, 295 | column_labels=None, 296 | column_filename='filename', 297 | ): 298 | r"""Extract list of files and labels from dataframe columns. 299 | 300 | Args: 301 | df (pandas.DataFrame): data frame with filenames and labels 302 | column_labels (str or list of str, optional): name of data frame 303 | column(s) containing the desired labels. Default: `None` 304 | column_filename (str, optional): name of column holding the file 305 | names. Default: `filename` 306 | 307 | Returns: 308 | tuple: 309 | * list of str: list of files 310 | * list of str or list of dicts: list of labels 311 | 312 | Example: 313 | >>> import pandas as pd 314 | >>> df = pd.DataFrame(data=[('speech.wav', 'speech')], 315 | ... columns=['filename', 'label']) 316 | >>> files, labels = files_and_labels_from_df(df, column_labels='label') 317 | >>> os.path.relpath(files[0]), labels[0] 318 | ('speech.wav', 'speech') 319 | 320 | """ 321 | if df is None: 322 | return [], [] 323 | 324 | ensure_df_columns_contain(df, [column_filename]) 325 | df = df.copy() 326 | files = df.pop(column_filename).tolist() 327 | 328 | if column_labels is None: 329 | return files, [''] * len(files) 330 | 331 | if isinstance(column_labels, str): 332 | column_labels = [column_labels] 333 | ensure_df_columns_contain(df, column_labels) 334 | df = df[column_labels] 335 | # Drop empty entries 336 | df = df.dropna().reset_index(drop=True) 337 | ensure_df_not_empty(df, column_labels) 338 | if len(column_labels) == 1: 339 | # list of strings 340 | labels = df.values.T[0].tolist() 341 | else: 342 | # list of dicts 343 | labels = df.to_dict('records') 344 | return files, labels 345 | 346 | 347 | def _gen_bar_updater(pbar): 348 | def bar_update(count, block_size, total_size): 349 | if pbar.total is None and total_size: 350 | pbar.total = total_size 351 | progress_bytes = count * block_size 352 | pbar.update(progress_bytes - pbar.n) 353 | 354 | return bar_update 355 | 356 | 357 | def defined_split( 358 | dataset, 359 | split_func, 360 | ): 361 | r"""Split data set into desired non-overlapping subsets. 
362 | 363 | Args: 364 | dataset (torch.utils.data.Dataset): data set to be split 365 | split_func (func): function mapping from data set index to subset id, 366 | :math:`f(\text{index}) = \text{subset\_id}`. 367 | The target domain of subset ids does not need to cover the 368 | complete range `[0, 1, ..., (num_subsets - 1)]` 369 | 370 | Returns: 371 | (list of Subsets): desired subsets according to :attr:`split_func` 372 | 373 | Example: 374 | >>> import torch 375 | >>> from torch.utils.data import TensorDataset 376 | >>> from audtorch.samplers import buckets_of_even_size 377 | >>> data = TensorDataset(torch.randn(100)) 378 | >>> lengths = np.random.randint(0, 1000, (100,)) 379 | >>> split_func = buckets_of_even_size(lengths, 5) 380 | >>> subsets = defined_split(data, split_func) 381 | >>> [len(subset) for subset in subsets] 382 | [20, 20, 20, 20, 20] 383 | 384 | """ 385 | subset_ids = [split_func(i) for i in range(len(dataset))] 386 | unique_subset_ids = sorted(set(subset_ids)) 387 | num_subsets = len(unique_subset_ids) 388 | 389 | split_indices = [[] for _ in range(num_subsets)] 390 | 391 | for i, subset_id in enumerate(subset_ids): 392 | # handle non-coherent target domain 393 | subset_id = unique_subset_ids.index(subset_id) 394 | split_indices[subset_id] += [i] 395 | 396 | return [Subset(dataset, indices) 397 | for indices in split_indices] 398 | 399 | 400 | def safe_path( 401 | path, 402 | ): 403 | """Ensure the path is absolute and doesn't include `..` or `~`. 404 | 405 | Args: 406 | path (str): absolute or relative path 407 | 408 | Returns: 409 | str: absolute path 410 | 411 | """ 412 | return os.path.abspath(os.path.expanduser(path)) 413 | -------------------------------------------------------------------------------- /audtorch/datasets/voxceleb1.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | 5 | from audtorch.datasets.base import AudioDataset 6 | from audtorch.datasets.utils import download_url, safe_path 7 | 8 | 9 | __doctest_skip__ = ['*'] 10 | 11 | 12 | class VoxCeleb1(AudioDataset): 13 | r"""VoxCeleb1 data set. 14 | 15 | Open and publicly available data set of voices from University of Oxford: 16 | http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html 17 | 18 | VoxCeleb1 is a large audio-visual data set consisting of short clips of 19 | human speech extracted from YouTube interviews with celebrities. It is 20 | free for commercial and research purposes. 21 | 22 | Licence: CC BY-SA 4.0 23 | 24 | * :attr:`transform` controls the input transform 25 | * :attr:`target_transform` controls the target transform 26 | * :attr:`files` controls the audio files of the data set 27 | * :attr:`targets` controls the corresponding targets 28 | * :attr:`sampling_rate` holds the sampling rate of data set 29 | 30 | In addition, the following class attributes are available: 31 | 32 | * :attr:`url` holds its URL 33 | 34 | Args: 35 | root (str): root directory of dataset 36 | partition (str, optional): name of the data partition to use. 37 | Choose one of `train`, `dev`, `test` or `None`. If `None` is given, 38 | then the whole data set will be returned. Default: `train` 39 | transform (callable, optional): function/transform applied on the 40 | signal. Default: `None` 41 | target_transform (callable, optional): function/transform applied on 42 | the target. Default: `None` 43 | 44 | Note: 45 | * This data set will work only if the identification file is downloaded 46 | as is from the official homepage. 
Please open it in your browser and 47 | copy paste its contents in a file in your computer. 48 | * To download the data set go to 49 | http://www.robots.ox.ac.uk/~vgg/data/voxceleb/ and fill in the form 50 | to request a password. Get the Audio Files that the owners provide. 51 | 52 | * When using the VoxCeleb1 data set in your research, please cite 53 | the following publication: :cite:`nagrani2017voxceleb`. 54 | 55 | Example: 56 | >>> import sounddevice as sd 57 | >>> data = VoxCeleb1('/data/voxceleb1') 58 | >>> print(data) 59 | Dataset VoxCeleb1 60 | Number of data points: 138361 61 | Root Location: /data/voxceleb1 62 | Sampling Rate: 16000Hz 63 | Labels: speaker ID 64 | >>> signal, target = data[0] 65 | >>> target 66 | 'id10003' 67 | >>> sd.play(signal.transpose(), data.sampling_rate) 68 | 69 | """ 70 | url = ('http://www.robots.ox.ac.uk/~vgg/data/voxceleb/') 71 | _iden_file_url = ( 72 | 'http://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/iden_split.txt') 73 | _partitions = {'train': 1, 'dev': 2, 'test': 3} 74 | 75 | def __init__( 76 | self, 77 | root, 78 | *, 79 | partition='train', 80 | transform=None, 81 | target_transform=None, 82 | ): 83 | self.root = safe_path(root) 84 | 85 | filelist = pd.read_csv( 86 | os.path.join( 87 | self.root, 88 | download_url(self._iden_file_url, self.root)), 89 | sep=' ', 90 | header=None, 91 | ) 92 | files, targets = self._get_files_speaker_lists(filelist) 93 | 94 | if partition is not None: 95 | # filter indices based on identification split 96 | indices = [index for index, x in enumerate(filelist[0]) 97 | if x == self._partitions[partition]] 98 | files = [files[index] for index in indices] 99 | targets = [targets[index] for index in indices] 100 | 101 | super().__init__( 102 | files=files, 103 | targets=targets, 104 | sampling_rate=16000, 105 | root=root, 106 | transform=transform, 107 | target_transform=target_transform, 108 | ) 109 | 110 | def _get_files_speaker_lists(self, filelist): 111 | r"""Extract file names and speaker IDs. 112 | 113 | Args: 114 | filelist (pandas.DataFrame): data frame containing file list 115 | and speakers 116 | 117 | Returns: 118 | list: files belonging to data set 119 | list: speaker IDs per file 120 | 121 | """ 122 | files = [os.path.join(self.root, 'wav', x) 123 | for x in filelist[1]] 124 | speakers = [x.split('/')[0] for x in filelist[1]] 125 | return files, speakers 126 | 127 | def extra_repr(self): 128 | fmt_str = ' Labels: speaker ID\n' 129 | return fmt_str 130 | -------------------------------------------------------------------------------- /audtorch/datasets/white_noise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset 3 | 4 | from ..transforms import functional as F 5 | from .utils import sampling_rate_after_transform 6 | 7 | 8 | __doctest_skip__ = ['*'] 9 | 10 | 11 | class WhiteNoise(Dataset): 12 | r"""White noise data set. 13 | 14 | The white noise is generated by numpy.random.standard_normal. 
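(More precisely, the implementation below draws the samples with `numpy.random.normal`, using the configurable `mean` and `stdev`, and then normalizes the resulting signal; the default parameters correspond to a standard normal distribution.)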
15 | 16 | * :attr:`duration` controls the duration of the noise signal 17 | * :attr:`sampling_rate` holds the sampling rate of the returned data 18 | * :attr:`mean` controls the mean of the underlying distribution 19 | * :attr:`stdev` controls the standard deviation of the underlying 20 | distribution 21 | * :attr:`transform` controls the input transform 22 | * :attr:`target_transform` controls the target transform 23 | 24 | As white noise does not really have a sampling rate, you can use the following 25 | attribute to change it instead of resampling: 26 | 27 | * :attr:`original_sampling_rate` controls the sampling rate of the data set 28 | 29 | Args: 30 | duration (float): duration of the noise signal in seconds 31 | sampling_rate (int, optional): sampling rate in Hz. Default: `44100` 32 | mean (float, optional): mean of underlying distribution. Default: `0` 33 | stdev (float, optional): standard deviation of underlying distribution. 34 | Default: `1` 35 | transform (callable, optional): function/transform applied on the 36 | signal. Default: `None` 37 | target_transform (callable, optional): function/transform applied on 38 | the target. Default: `None` 39 | 40 | Note: 41 | Even though `WhiteNoise` has an infinite number of entries, its length is 42 | `1`, as repeated calls with the same index return different signals. 43 | 44 | Example: 45 | >>> import sounddevice as sd 46 | >>> data = WhiteNoise(duration=1, sampling_rate=44100) 47 | >>> print(data) 48 | Dataset WhiteNoise 49 | Number of data points: Inf 50 | Signal length: 1s 51 | Sampling Rate: 44100Hz 52 | Label (str): noise type 53 | >>> signal, target = data[0] 54 | >>> target 55 | 'white noise' 56 | >>> sd.play(signal.transpose(), data.sampling_rate) 57 | 58 | """ 59 | 60 | def __init__( 61 | self, 62 | *, 63 | duration=1, 64 | sampling_rate=44100, 65 | mean=0, 66 | stdev=1, 67 | transform=None, 68 | target_transform=None, 69 | ): 70 | super().__init__() 71 | self.duration = duration 72 | self.mean = mean 73 | self.stdev = stdev 74 | self.transform = transform 75 | self.target_transform = target_transform 76 | self.original_sampling_rate = sampling_rate 77 | 78 | @property 79 | def sampling_rate(self): 80 | return sampling_rate_after_transform(self) 81 | 82 | def __len__(self): 83 | # This has no meaningful __len__ as its actual length is `Inf`. 84 | # Return `1` to make it work with `random.choice` like operations. 
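# As a consequence, a DataLoader over this data set yields exactly one (randomly drawn) example per epoch unless the data set is wrapped or repeated.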
85 | return 1 86 | 87 | def __getitem__(self, item): 88 | samples = int(np.ceil(self.duration * self.sampling_rate)) 89 | signal = np.random.normal(loc=self.mean, scale=self.stdev, 90 | size=samples) 91 | signal = F.normalize(np.expand_dims(signal, axis=0)) 92 | target = 'white noise' 93 | 94 | if self.transform is not None: 95 | signal = self.transform(signal) 96 | 97 | if self.target_transform is not None: 98 | target = self.target_transform(target) 99 | 100 | return signal, target 101 | 102 | def __repr__(self): 103 | fmt_str = f'Dataset {self.__class__.__name__}\n' 104 | fmt_str += ' Number of data points: Inf\n' 105 | fmt_str += f' Signal length: {self.duration}s\n' 106 | if self.sampling_rate == self.original_sampling_rate: 107 | fmt_str += f' Sampling Rate: {self.sampling_rate}Hz\n' 108 | else: 109 | fmt_str += ( 110 | f' Sampling Rate: {self.sampling_rate}Hz ' 111 | f'(original: {self.original_sampling_rate}Hz)\n' 112 | ) 113 | fmt_str += ' Label (str): noise type\n' 114 | tmp1 = ' Transform: ' 115 | tmp2 = self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp1)) 116 | if self.transform: 117 | fmt_str += f'{tmp1}{tmp2}\n' 118 | tmp1 = ' Target Transform: ' 119 | tmp2 = self.target_transform.__repr__().replace('\n', 120 | '\n' + ' ' * len(tmp1)) 121 | if self.target_transform: 122 | fmt_str += f'{tmp1}{tmp2}' 123 | return fmt_str 124 | -------------------------------------------------------------------------------- /audtorch/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import * 2 | from .metrics import * 3 | from . import functional 4 | -------------------------------------------------------------------------------- /audtorch/metrics/functional.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def pearsonr( 4 | x, 5 | y, 6 | batch_first=True, 7 | ): 8 | r"""Computes Pearson Correlation Coefficient across rows. 9 | 10 | Pearson Correlation Coefficient (also known as Linear Correlation 11 | Coefficient or Pearson's :math:`\rho`) is computed as: 12 | 13 | .. math:: 14 | 15 | \rho = \frac {E[(X-\mu_X)(Y-\mu_Y)]} {\sigma_X\sigma_Y} 16 | 17 | If inputs are matrices, then then we assume that we are given a 18 | mini-batch of sequences, and the correlation coefficient is 19 | computed for each sequence independently and returned as a vector. If 20 | `batch_fist` is `True`, then we assume that every row represents a 21 | sequence in the mini-batch, otherwise we assume that batch information 22 | is in the columns. 23 | 24 | Warning: 25 | We do not account for the multi-dimensional case. This function has 26 | been tested only for the 2D case, either in `batch_first==True` or in 27 | `batch_first==False` mode. In the multi-dimensional case, 28 | it is possible that the values returned will be meaningless. 29 | 30 | Args: 31 | x (torch.Tensor): input tensor 32 | y (torch.Tensor): target tensor 33 | batch_first (bool, optional): controls if batch dimension is first. 34 | Default: `True` 35 | 36 | Returns: 37 | torch.Tensor: correlation coefficient between `x` and `y` 38 | 39 | Note: 40 | :math:`\sigma_X` is computed using **PyTorch** builtin 41 | **Tensor.std()**, which by default uses Bessel correction: 42 | 43 | .. math:: 44 | 45 | \sigma_X=\displaystyle\frac{1}{N-1}\sum_{i=1}^N({x_i}-\bar{x})^2 46 | 47 | We therefore account for this correction in the computation of the 48 | covariance by multiplying it with :math:`\frac{1}{N-1}`. 
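As a quick, purely illustrative sanity check, a perfectly linear relation yields a coefficient of (numerically) one: >>> import torch >>> x = torch.arange(5, dtype=torch.float32) >>> bool(torch.isclose(pearsonr(x, 2 * x + 1), torch.tensor([1.]))) True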
49 | 50 | Shape: 51 | - Input: :math:`(N, M)` for correlation between matrices, 52 | or :math:`(M)` for correlation between vectors 53 | - Target: :math:`(N, M)` or :math:`(M)`. Must be identical to input 54 | - Output: :math:`(N, 1)` for correlation between matrices, 55 | or :math:`(1)` for correlation between vectors 56 | 57 | Examples: 58 | >>> import torch 59 | >>> _ = torch.manual_seed(0) 60 | >>> input = torch.rand(3, 5) 61 | >>> target = torch.rand(3, 5) 62 | >>> output = pearsonr(input, target) 63 | >>> print('Pearson Correlation between input and target is {0}'.format(output[:, 0])) 64 | Pearson Correlation between input and target is tensor([ 0.2991, -0.8471, 0.9138]) 65 | 66 | """ # noqa: E501 67 | assert x.shape == y.shape 68 | 69 | if batch_first: 70 | dim = -1 71 | else: 72 | dim = 0 73 | 74 | centered_x = x - x.mean(dim=dim, keepdim=True) 75 | centered_y = y - y.mean(dim=dim, keepdim=True) 76 | 77 | covariance = (centered_x * centered_y).sum(dim=dim, keepdim=True) 78 | 79 | bessel_corrected_covariance = covariance / (x.shape[dim] - 1) 80 | 81 | x_std = x.std(dim=dim, keepdim=True) 82 | y_std = y.std(dim=dim, keepdim=True) 83 | 84 | corr = bessel_corrected_covariance / (x_std * y_std) 85 | 86 | return corr 87 | 88 | 89 | def concordance_cc( 90 | x, 91 | y, 92 | batch_first=True, 93 | ): 94 | r"""Computes Concordance Correlation Coefficient across rows. 95 | 96 | Concordance Correlation Coefficient is computed as: 97 | 98 | .. math:: 99 | 100 | \rho_c = \frac {2\rho\sigma_X\sigma_Y} {\sigma_X\sigma_X + 101 | \sigma_Y\sigma_Y + (\mu_X - \mu_Y)^2} 102 | 103 | where :math:`\rho` is Pearson Correlation Coefficient, :math:`\sigma_X`, 104 | :math:`\sigma_Y` are the standard deviations, and :math:`\mu_X`, 105 | :math:`\mu_Y` the mean values of :math:`X` and :math:`Y` accordingly. 106 | 107 | If inputs are matrices, then we assume that we are given a 108 | mini-batch of sequences, and the concordance correlation coefficient is 109 | computed for each sequence independently and returned as a vector. If 110 | `batch_first` is `True`, then we assume that every row represents a 111 | sequence in the mini-batch, otherwise we assume that batch information 112 | is in the columns. 113 | 114 | Warning: 115 | We do not account for the multi-dimensional case. This function has 116 | been tested only for the 2D case, either in `batch_first==True` or in 117 | `batch_first==False` mode. In the multi-dimensional case, 118 | it is possible that the values returned will be meaningless. 119 | 120 | Note: 121 | :math:`\sigma_X` is computed using **PyTorch** builtin 122 | **Tensor.std()**, which by default uses Bessel correction: 123 | 124 | .. math:: 125 | 126 | \sigma_X=\displaystyle\sqrt{\frac{1}{N-1}\sum_{i=1}^N({x_i}-\bar{x})^2} 127 | 128 | We therefore account for this correction in the computation of the 129 | concordance correlation coefficient by multiplying all products of 130 | standard deviations with :math:`\frac{N-1}{N}`. This is equivalent to 131 | multiplying only :math:`(\mu_X - \mu_Y)^2` with :math:`\frac{N}{ 132 | N-1}`. We choose that option for numerical stability. 133 | 134 | Args: 135 | x (torch.Tensor): input tensor 136 | y (torch.Tensor): target tensor 137 | batch_first (bool, optional): controls if batch dimension is first. 
138 | Default: `True` 139 | 140 | Returns: 141 | torch.Tensor: concordance correlation coefficient between `x` and `y` 142 | 143 | Shape: 144 | - Input: :math:`(N, M)` for correlation between matrices, 145 | or :math:`(M)` for correlation between vectors 146 | - Target: :math:`(N, M)` or :math:`(M)`. Must be identical to input 147 | - Output: :math:`(N, 1)` for correlation between matrices, 148 | or :math:`(1)` for correlation between vectors 149 | 150 | Examples: 151 | >>> import torch 152 | >>> _ = torch.manual_seed(0) 153 | >>> input = torch.rand(3, 5) 154 | >>> target = torch.rand(3, 5) 155 | >>> output = concordance_cc(input, target) 156 | >>> print('Concordance Correlation between input and target is {0}'.format(output[:, 0])) 157 | Concordance Correlation between input and target is tensor([ 0.2605, -0.7862, 0.5298]) 158 | 159 | """ # noqa: E501 160 | assert x.shape == y.shape 161 | 162 | if batch_first: 163 | dim = -1 164 | else: 165 | dim = 0 166 | 167 | bessel_correction_term = (x.shape[dim] - 1) / x.shape[dim] 168 | 169 | r = pearsonr(x, y, batch_first) 170 | x_mean = x.mean(dim=dim, keepdim=True) 171 | y_mean = y.mean(dim=dim, keepdim=True) 172 | x_std = x.std(dim=dim, keepdim=True) 173 | y_std = y.std(dim=dim, keepdim=True) 174 | ccc = 2 * r * x_std * y_std / (x_std * x_std 175 | + y_std * y_std 176 | + (x_mean - y_mean) 177 | * (x_mean - y_mean) 178 | / bessel_correction_term) 179 | return ccc 180 | -------------------------------------------------------------------------------- /audtorch/metrics/losses.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class EnergyConservingLoss(nn.L1Loss): 6 | r"""Energy conserving loss. 7 | 8 | A two term loss that enforces energy conservation after 9 | :cite:`Rethage2018`. 10 | 11 | The loss can be described as: 12 | 13 | .. math:: 14 | \ell(x, y, m) = L = \{l_1,\dots,l_N\}^\top, \quad 15 | l_n = |x_n - y_n| + |b_n - \hat{b_n}|, 16 | 17 | where :math:`N` is the batch size. If reduction is not ``'none'``, then: 18 | 19 | .. math:: 20 | \ell(x, y, m) = 21 | \begin{cases} 22 | \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ 23 | \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} 24 | \end{cases} 25 | 26 | :math:`x` is the input signal (estimated target), :math:`y` the target 27 | signal, :math:`m` the mixture signal, :math:`b` the background signal given 28 | by :math:`b = m - y`, and :math:`\hat{b}` the estimated background signal 29 | given by :math:`\hat{b} = m - x`. 30 | 31 | Args: 32 | reduction (string, optional): specifies the reduction to apply to the 33 | output: 'none' | 'mean' | 'sum'. 34 | 'none': no reduction will be applied, 'mean': the sum of the output 35 | will be divided by the number of elements in the output, 'sum': the 36 | output will be summed. 37 | 38 | Shape: 39 | - Input: :math:`(N, *)` where `*` means, any number of additional 40 | dimensions 41 | - Target: :math:`(N, *)`, same shape as the input 42 | - Mixture: :math:`(N, *)`, same shape as the input 43 | - Output: scalar. 
If reduction is ``'none'``, then :math:`(N, *)`, same 44 | shape as the input 45 | 46 | Examples: 47 | >>> import torch 48 | >>> _ = torch.manual_seed(0) 49 | >>> loss = EnergyConservingLoss() 50 | >>> input = torch.randn(3, 5, requires_grad=True) 51 | >>> target = torch.randn(3, 5) 52 | >>> mixture = torch.randn(3, 5) 53 | >>> loss(input, target, mixture) 54 | tensor(2.1352, grad_fn=) 55 | 56 | """ 57 | def __init__( 58 | self, 59 | *, 60 | reduction='mean', 61 | ): 62 | super().__init__(None, None, reduction) 63 | 64 | def forward(self, y_predicted, y, x): 65 | noise = x - y 66 | noise_predicted = x - y_predicted 67 | return (F.l1_loss(y_predicted, y, reduction=self.reduction) 68 | + F.l1_loss(noise_predicted, noise, reduction=self.reduction)) 69 | -------------------------------------------------------------------------------- /audtorch/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | from . import functional as F 2 | 3 | 4 | class PearsonR(object): 5 | r"""Computes Pearson Correlation Coefficient. 6 | 7 | Pearson Correlation Coefficient (also known as Linear Correlation 8 | Coefficient or Pearson's :math:`\rho`) is computed as: 9 | 10 | .. math:: 11 | 12 | \rho = \frac {E[(X-\mu_X)(Y-\mu_Y)]} {\sigma_X\sigma_Y} 13 | 14 | If inputs are vectors, computes Pearson's :math:`\rho` between the two 15 | of them. If inputs are multi-dimensional arrays, computes Pearson's 16 | :math:`\rho` along the first or last input dimension according to the 17 | `batch_first` argument, returns a **torch.Tensor** as output, 18 | and optionally reduces it according to the `reduction` argument. 19 | 20 | Args: 21 | reduction (string, optional): specifies the reduction to apply to the 22 | output: 'none' | 'mean' | 'sum'. 23 | 'none': no reduction will be applied, 'mean': the sum of the output 24 | will be divided by the number of elements in the output, 'sum': the 25 | output will be summed. Default: 'mean' 26 | batch_first (bool, optional): controls if batch dimension is first. 27 | Default: `True` 28 | 29 | Shape: 30 | - Input: :math:`(N, *)` where `*` means, any number of additional 31 | dimensions 32 | - Target: :math:`(N, *)`, same shape as the input 33 | - Output: scalar. If reduction is ``'none'``, then :math:`(N, 1)` 34 | 35 | Example: 36 | >>> import torch 37 | >>> _ = torch.manual_seed(0) 38 | >>> metric = PearsonR() 39 | >>> input = torch.rand(3, 5) 40 | >>> target = torch.rand(3, 5) 41 | >>> metric(input, target) 42 | tensor(0.1220) 43 | 44 | """ 45 | def __init__( 46 | self, 47 | *, 48 | reduction='mean', 49 | batch_first=True, 50 | ): 51 | self.reduction = reduction 52 | self.batch_first = batch_first 53 | 54 | def __call__(self, x, y): 55 | r = F.pearsonr(x, y, self.batch_first) 56 | if self.reduction == 'mean': 57 | r = r.mean() 58 | elif self.reduction == 'sum': 59 | r = r.sum() 60 | return r 61 | 62 | 63 | class ConcordanceCC(object): 64 | r"""Computes Concordance Correlation Coefficient (CCC). 65 | 66 | CCC is computed as: 67 | 68 | .. math:: 69 | 70 | \rho_c = \frac {2\rho\sigma_X\sigma_Y} {\sigma_X\sigma_X + 71 | \sigma_Y\sigma_Y + (\mu_X - \mu_Y)^2} 72 | 73 | where :math:`\rho` is Pearson Correlation Coefficient, :math:`\sigma_X`, 74 | :math:`\sigma_Y` are the standard deviation, and :math:`\mu_X`, 75 | :math:`\mu_Y` the mean values of :math:`X` and :math:`Y` accordingly. 76 | 77 | If inputs are vectors, computes CCC between the two of them. 
If inputs 78 | are multi-dimensional arrays, computes CCC along the first or last input 79 | dimension according to the `batch_first` argument, returns a 80 | **torch.Tensor** as output, and optionally reduces it according to the 81 | `reduction` argument. 82 | 83 | Args: 84 | reduction (string, optional): specifies the reduction to apply to the 85 | output: 'none' | 'mean' | 'sum'. 86 | 'none': no reduction will be applied, 'mean': the sum of the output 87 | will be divided by the number of elements in the output, 'sum': the 88 | output will be summed. Default: 'mean' 89 | batch_first (bool, optional): controls if batch dimension is first. 90 | Default: `True` 91 | 92 | Shape: 93 | - Input: :math:`(N, *)` where `*` means, any number of additional 94 | dimensions 95 | - Target: :math:`(N, *)`, same shape as the input 96 | - Output: scalar. If reduction is ``'none'``, then :math:`(N, 1)` 97 | 98 | Example: 99 | >>> import torch 100 | >>> _ = torch.manual_seed(0) 101 | >>> metric = ConcordanceCC() 102 | >>> input = torch.rand(3, 5) 103 | >>> target = torch.rand(3, 5) 104 | >>> metric(input, target) 105 | tensor(0.0014) 106 | 107 | """ 108 | def __init__( 109 | self, 110 | *, 111 | reduction='mean', 112 | batch_first=True, 113 | ): 114 | self.reduction = reduction 115 | self.batch_first = batch_first 116 | 117 | def __call__(self, x, y): 118 | r = F.concordance_cc(x, y, self.batch_first) 119 | if self.reduction == 'mean': 120 | r = r.mean() 121 | elif self.reduction == 'sum': 122 | r = r.sum() 123 | return r 124 | -------------------------------------------------------------------------------- /audtorch/samplers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import BatchSampler 3 | 4 | 5 | class BucketSampler(BatchSampler): 6 | r"""Creates batches from ordered data sets. 7 | 8 | This sampler iterates over the data sets of `concat_dataset` 9 | and samples sequentially from them. 10 | Samples of each batch deliberately originate solely from the same data set. 11 | Only when the current data set is exhausted, the next data set 12 | is sampled from. In other words, 13 | samples from different buckets are never mixed. 14 | 15 | In each epoch `num_batches` batches of size `batch_sizes` 16 | are extracted from each data set. 17 | If the requested number of batches cannot be extracted from a data set, 18 | only its available batches are queued. 19 | By default, the data sets (and thus their batches) are iterated over 20 | in increasing order of their data set id. 21 | 22 | Note: 23 | The information in :attr:`batch_sizes` and :attr:`num_batches` 24 | refer to :attr:`datasets` at the same index 25 | independently of :attr:`permuted_order`. 26 | 27 | Simple Use Case: "Train on data with increasing sequence length" 28 | 29 | ======================= ================================== 30 | bucket_id: [0, 1, 2, ... end ] 31 | batch_sizes: [32, 16, 8, ... 2 ] 32 | num_batches: [None, None, None, ... 
None ] 33 | ======================= ================================== 34 | 35 | Result: 36 | "Extract all batches (`None`) from all data sets, 37 | all of different batch size, and queue them 38 | in increasing order of their data set id" 39 | 40 | * :attr:`batch_sizes` controls batch size for each data set 41 | * :attr:`num_batches` controls number of batches to extract 42 | from each data set 43 | * :attr:`permuted_order` controls if order in which data sets are iterated 44 | over is permuted or in which specific order iteration is permuted 45 | * :attr:`shuffle_each_bucket` controls if each data set is shuffled 46 | * :attr:`drop_last` controls whether to drop last samples of a bucket 47 | which cannot form a mini-batch 48 | 49 | Args: 50 | concat_dataset (torch.utils.data.ConcatDataset): ordered concatenated 51 | data set 52 | batch_sizes (list): batch sizes per data set. Permissible values are 53 | unsigned integers 54 | num_batches (list or None, optional): number of batches per data set. 55 | Permissible values are non-negative integers and None. 56 | If None, then as many batches are extracted as data set provides. 57 | Default: `None` 58 | permuted_order (bool or list, optional): option whether to permute the 59 | order of data set ids in which the respective data set's batches 60 | are queued. If True (False), data set ids are (not) shuffled. 61 | Besides, a customized list of permuted data set ids can be 62 | specified. Default: `False` 63 | shuffle_each_bucket (bool, optional): option whether to shuffle samples 64 | in each data set. Recommended to set to True. Default: `True` 65 | drop_last (bool, optional): controls whether the last samples of a 66 | bucket which cannot form a mini-batch should be dropped. 67 | Default: `False` 68 | 69 | Example: 70 | >>> import torch 71 | >>> from torch.utils.data import (TensorDataset, ConcatDataset) 72 | >>> from audtorch.datasets.utils import defined_split 73 | >>> data = TensorDataset(torch.randn(100)) 74 | >>> lengths = np.random.randint(0, 890, (100,)) 75 | >>> split_func = buckets_of_even_size(lengths, num_buckets=3) 76 | >>> subsets = defined_split(data, split_func) 77 | >>> concat_dataset = ConcatDataset(subsets) 78 | >>> batch_sampler = BucketSampler(concat_dataset, 3 * [16]) 79 | 80 | """ 81 | 82 | def __init__( 83 | self, 84 | concat_dataset, 85 | batch_sizes, 86 | *, 87 | num_batches=None, 88 | permuted_order=False, 89 | shuffle_each_bucket=True, 90 | drop_last=False, 91 | ): 92 | self.datasets = _stack_concatenated_datasets(concat_dataset) 93 | self.batch_sizes = batch_sizes 94 | self.num_batches = num_batches 95 | self.permuted_order = permuted_order 96 | self.shuffle_each_bucket = shuffle_each_bucket 97 | self.drop_last = drop_last 98 | self.dataset_ids = list(range(len(self.datasets))) 99 | 100 | if isinstance(permuted_order, list): 101 | assert sorted(self.permuted_order) == self.dataset_ids, \ 102 | '`permuted_order` not consistent with number of data sets.' 103 | self.dataset_ids = permuted_order 104 | 105 | assert all( 106 | [self.batch_sizes[i] > 0 and isinstance(self.batch_sizes[i], int) 107 | for i in self.dataset_ids] 108 | ), 'Only positive integers permitted for `num_batches`.' 109 | 110 | if self.num_batches is not None: 111 | assert all( 112 | [self.num_batches[i] >= 0 and isinstance( 113 | self.num_batches[i], int) for i in self.dataset_ids 114 | if self.num_batches[i] is not None]), \ 115 | "Only non-negative integers or " \ 116 | "`None` permitted for `num_batches`." 
117 | 118 | if not isinstance(drop_last, bool): 119 | raise ValueError( 120 | f'drop_last should be a boolean value, but got {drop_last}' 121 | ) 122 | 123 | def __iter__(self): 124 | r"""Iterates sequentially over data sets and forms batches 125 | 126 | """ 127 | all_batches = [] 128 | batch = [] 129 | 130 | if self.permuted_order is True: 131 | self.dataset_ids = list(np.random.permutation(self.dataset_ids)) 132 | 133 | # iterate over data sets ordered by data set id 134 | for dset_id in self.dataset_ids: 135 | 136 | dataset = self.datasets[dset_id] 137 | num_batch = 0 138 | 139 | if self.shuffle_each_bucket: # random samples from data set 140 | dataset = list(np.random.permutation(dataset)) 141 | 142 | for sample in dataset: # iterate over samples in data set 143 | if self.num_batches is not None and \ 144 | self.num_batches[dset_id] is not None: 145 | if num_batch == self.num_batches[dset_id]: 146 | break 147 | batch.append(sample) 148 | if len(batch) == self.batch_sizes[dset_id]: 149 | all_batches.append(batch) 150 | num_batch += 1 151 | batch = [] 152 | 153 | # yield full batch and also \ 154 | # handle last samples of bucket which cannot form entire batch 155 | if len(batch) > 0 and not self.drop_last: 156 | all_batches.append(batch) 157 | batch = [] 158 | 159 | return iter(all_batches) 160 | 161 | def __len__(self): 162 | 163 | sampler_size = 0 164 | for dset_id in self.dataset_ids: 165 | 166 | dataset = self.datasets[dset_id] 167 | bs = self.batch_sizes[dset_id] 168 | requested_batches = None if self.num_batches is None \ 169 | else self.num_batches[dset_id] 170 | 171 | if self.drop_last: 172 | fitted_batches = len(dataset) // bs 173 | else: 174 | fitted_batches = (len(dataset) + bs - 1) // bs 175 | 176 | if requested_batches is None: 177 | sampler_size += fitted_batches 178 | else: 179 | sampler_size += fitted_batches if \ 180 | fitted_batches < requested_batches \ 181 | else requested_batches 182 | 183 | return sampler_size 184 | 185 | 186 | def buckets_by_boundaries( 187 | key_values, 188 | bucket_boundaries, 189 | ): 190 | r"""Split samples into buckets based on key values using bucket boundaries. 191 | 192 | Note: 193 | A sample is sorted into bucket :math:`i` if for their key value 194 | holds: 195 | 196 | :math:`b_{i-1} <= \text{key value} < b_i`, 197 | where :math:`b_i` is `bucket boundary` at index :math:`i` 198 | 199 | Args: 200 | key_values (list): contains key values, e.g. sequence lengths 201 | bucket_boundaries (list): contains boundaries of buckets in 202 | ascending order. The list should neither contain a lower or 203 | upper boundary, e.g. not numpy.iinfo.min or numpy.iinfo.max. 204 | 205 | Returns: 206 | func: Key function to use for splitting: \ 207 | :math:`f(\text{item}) = \text{bucket\_id}` 208 | 209 | Example: 210 | >>> lengths = [288, 258, 156, 99, 47, 13] 211 | >>> boundaries = [80, 150] 212 | >>> split_func = buckets_by_boundaries(lengths, boundaries) 213 | >>> [split_func(i) for i in range(len(lengths))] 214 | [2, 2, 2, 1, 0, 0] 215 | 216 | """ 217 | assert bucket_boundaries == sorted(bucket_boundaries), \ 218 | "Iterable `bucket_boundaries` not given in ascending order." 219 | 220 | assert len(bucket_boundaries) == \ 221 | len(np.unique(bucket_boundaries)), \ 222 | "Iterable `bucket_boundaries` contains duplicate(s)." 
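# With boundaries [b_0, ..., b_{k-1}] this yields k + 1 buckets: bucket 0 collects key values < b_0, bucket i the values with b_{i-1} <= value < b_i, and the last bucket everything >= b_{k-1}.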
223 | 224 | num_buckets = len(bucket_boundaries) + 1 225 | 226 | def key_func(item): 227 | key_val = key_values[item] 228 | for bucket_id in range(num_buckets - 1): 229 | if key_val < bucket_boundaries[bucket_id]: 230 | return bucket_id 231 | return num_buckets - 1 232 | 233 | return key_func 234 | 235 | 236 | def buckets_of_even_size( 237 | key_values, 238 | num_buckets, 239 | *, 240 | reverse=False, 241 | ): 242 | r"""Split samples into buckets of even size based on key values. 243 | 244 | The samples are sorted with either increasing (or decreasing) key value. 245 | If number of samples cannot be distributed evenly to buckets, 246 | the first buckets are filled up with one remainder each. 247 | 248 | Args: 249 | key_values (list): contains key values, e.g. sequence lengths 250 | num_buckets (int): number of buckets to form. Permitted are 251 | positive integers 252 | reverse (bool, optional): if True, then sort in descending order. 253 | Default: `False` 254 | 255 | Returns: 256 | func: Key function to use for splitting: \ 257 | :math:`f(\text{item}) = \text{bucket\_id}` 258 | 259 | Example: 260 | >>> lengths = [288, 258, 156, 47, 112, 99, 13] 261 | >>> num_buckets = 4 262 | >>> split_func = buckets_of_even_size(lengths, num_buckets) 263 | >>> [split_func(i) for i in range(len(lengths))] 264 | [3, 2, 2, 0, 1, 1, 0] 265 | 266 | """ 267 | # make sure that bucket size is larger than 0 268 | assert (len(key_values) >= num_buckets), \ 269 | "Not enough `key_values` for `num_buckets` in order to" \ 270 | " form even buckets." 271 | 272 | assert (num_buckets > 0) and isinstance(num_buckets, int), \ 273 | "Specified value for `num_buckets` not a positive integer." 274 | 275 | bucket_size, remainder = divmod(len(key_values), num_buckets) 276 | sorted_indices = sorted(range(len(key_values)), 277 | key=lambda i: key_values[i], 278 | reverse=reverse) 279 | 280 | bucket_membership = len(key_values) * [0] 281 | sample_count, bucket_id, r = 0, 0, remainder 282 | 283 | for i in sorted_indices: 284 | bucket_membership[i] = bucket_id 285 | sample_count += 1 286 | if sample_count == bucket_size: # bucket full 287 | # distribute one remainder per bucket 288 | if r > 0 and bucket_id + r == remainder: 289 | r -= 1 290 | sample_count -= 1 291 | else: 292 | sample_count = 0 293 | bucket_id += 1 294 | 295 | def key_func(item): 296 | return bucket_membership[item] 297 | 298 | return key_func 299 | 300 | 301 | def _stack_concatenated_datasets(concat_dataset): 302 | r"""Extract and stack indices of different data sets from `concat_dataset`. 303 | 304 | Each data set is represented by a complete range of indices 305 | starting from 0...len(data set) for the first data set. 306 | The ranges of the following data sets build on 307 | the range of the corresponding previous data set so that 308 | the indices of the lists are cumulative:: 309 | 310 | datasets = [[0 ... len(data_1) - 1], 311 | [len(data_1) ... len(data_1) - 1], ... 
] 312 | 313 | Args: 314 | concat_dataset (ConcatDataset): data set to sample from 315 | 316 | Returns: 317 | list: list of lists of data set indices 318 | 319 | Example: 320 | >>> import torch 321 | >>> from torch.utils.data import (TensorDataset, ConcatDataset) 322 | >>> data_1 = TensorDataset(torch.Tensor([0, 1, 2, 3])) 323 | >>> data_2 = TensorDataset(torch.Tensor([0, 1, 2])) 324 | >>> concat_dataset = ConcatDataset((data_1, data_2)) 325 | >>> concat_dataset.cumulative_sizes 326 | [4, 7] 327 | >>> _stack_concatenated_datasets(concat_dataset) 328 | [[0, 1, 2, 3], [4, 5, 6]] 329 | 330 | """ 331 | datasets = [] 332 | start_idx = 0 333 | for upper_edge in concat_dataset.cumulative_sizes: 334 | datasets += [list(range(start_idx, upper_edge))] 335 | start_idx = upper_edge 336 | 337 | return datasets 338 | -------------------------------------------------------------------------------- /audtorch/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | from . import functional 3 | -------------------------------------------------------------------------------- /audtorch/transforms/functional.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import torch 4 | 5 | from .. import utils 6 | 7 | 8 | def crop( 9 | signal, 10 | idx, 11 | *, 12 | axis=-1, 13 | ): 14 | r"""Crop signal along an axis. 15 | 16 | Args: 17 | signal (numpy.ndarray): audio signal 18 | idx (int or tuple): first (and last) index to return 19 | axis (int, optional): axis along to crop. Default: `-1` 20 | 21 | Note: 22 | Indexing from the end with `-1`, `-2`, ... is allowed. But you cannot 23 | use `-1` in the second part of the tuple to specify the last entry. 24 | Instead you have to write `(-2, signal.shape[axis])` to get the last 25 | two entries of `axis`, or simply `-1` if you only want to get the last 26 | entry. 27 | 28 | Returns: 29 | numpy.ndarray: cropped signal 30 | 31 | Example: 32 | >>> a = np.array([[1, 2], [3, 4]]) 33 | >>> crop(a, 1) 34 | array([[2], 35 | [4]]) 36 | 37 | """ 38 | # Ensure idx is iterate able 39 | if isinstance(idx, int): 40 | idx = [idx] 41 | # Allow for -1 like syntax for index 42 | length = signal.shape[axis] 43 | idx = [length + i if i < 0 else i for i in idx] 44 | # Add stop index for single values 45 | if len(idx) == 1: 46 | idx = [idx[0], idx[0] + 1] 47 | 48 | # Split into three parts and return middle one 49 | return np.split(signal, idx, axis=axis)[1] 50 | 51 | 52 | def pad( 53 | signal, 54 | padding, 55 | *, 56 | value=0, 57 | axis=-1, 58 | ): 59 | r"""Pad signal along an axis. 60 | 61 | If padding is an integer it pads equally on the left and right of the 62 | signal. If padding is a tuple with two entries it uses the first for the 63 | left side and the second for the right side. 64 | 65 | Args: 66 | signal (numpy.ndarray): audio signal 67 | padding (int or tuple): padding to apply on the left and right 68 | value (float, optional): value to pad with. Default: `0` 69 | axis (int, optional): axis along which to pad. 
Default: `-1` 70 | 71 | Returns: 72 | numpy.ndarray: padded signal 73 | 74 | Example: 75 | >>> a = np.array([[1, 2], [3, 4]]) 76 | >>> pad(a, (0, 1)) 77 | array([[1, 2, 0], 78 | [3, 4, 0]]) 79 | 80 | """ 81 | padding = utils.to_tuple(padding) 82 | dimensions = np.ndim(signal) 83 | pad = [(0, 0) for _ in range(dimensions)] # no padding for all axes 84 | pad[axis] = padding # padding along selected axis 85 | 86 | return np.pad(signal, pad, 'constant', constant_values=value) 87 | 88 | 89 | def replicate( 90 | signal, 91 | repetitions, 92 | *, 93 | axis=-1, 94 | ): 95 | r"""Replicate signal along an axis. 96 | 97 | Args: 98 | signal (numpy.ndarray): audio signal 99 | repetitions (int): number of times to replicate signal 100 | axis (int, optional): axis along which to replicate. Default: `-1` 101 | 102 | Returns: 103 | numpy.ndarray: replicated signal 104 | 105 | Example: 106 | >>> a = np.array([1, 2, 3]) 107 | >>> replicate(a, 3) 108 | array([1, 2, 3, 1, 2, 3, 1, 2, 3]) 109 | 110 | """ 111 | reps = [1 for _ in range(len(signal.shape))] 112 | reps[axis] = repetitions 113 | return np.tile(signal, reps) 114 | 115 | 116 | def mask( 117 | signal, 118 | num_blocks, 119 | max_width, 120 | *, 121 | value=0., 122 | axis=-1, 123 | ): 124 | r"""Randomly mask signal along axis. 125 | 126 | Args: 127 | signal (torch.Tensor): audio signal 128 | num_blocks (int): number of mask blocks 129 | max_width (int): maximum size of block 130 | value (float, optional): mask value. Default: `0.` 131 | axis (int, optional): axis along which to mask. Default: `-1` 132 | 133 | Returns: 134 | torch.Tensor: masked signal 135 | 136 | """ 137 | signal_size = signal.shape[axis] 138 | # add 1 to `max_width` to include value `max_width` in sampling 139 | widths = torch.randint(low=1, high=max_width + 1, size=(num_blocks,)) 140 | start = torch.LongTensor( 141 | [torch.randint(0, signal_size - widths[i].item(), (1,)) 142 | for i in range(num_blocks)]) 143 | 144 | for i, s in enumerate(start): 145 | signal.narrow(start=s.item(), 146 | length=widths[i].item(), 147 | dim=axis).fill_(value) 148 | return signal 149 | 150 | 151 | def downmix( 152 | signal, 153 | channels, 154 | *, 155 | method='mean', 156 | axis=-2, 157 | ): 158 | r"""Downmix signal to the provided number of channels. 159 | 160 | The downmix is done by one of these methods: 161 | 162 | * ``'mean'`` replace last desired channel by mean across itself and 163 | all remaining channels 164 | * ``'crop'`` drop all remaining channels 165 | 166 | Args: 167 | signal (numpy.ndarray): audio signal 168 | channels (int): number of desired channels 169 | method (str, optional): downmix method. Default: `'mean'` 170 | axis (int, optional): axis to downmix. 
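For a two-dimensional signal of shape `(channels, samples)` the default corresponds to the channel axis.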
Default: `-2` 171 | 172 | Returns: 173 | numpy.ndarray: reshaped signal 174 | 175 | Example: 176 | >>> a = np.array([[1, 2], [3, 4]]) 177 | >>> downmix(a, 1) 178 | array([[2, 3]]) 179 | 180 | """ 181 | input_channels = np.atleast_2d(signal).shape[axis] 182 | 183 | if input_channels <= channels: 184 | return signal 185 | 186 | if method == 'mean': 187 | downmix = crop(signal, (channels - 1, input_channels), axis=axis) 188 | downmix = np.mean(downmix, axis=axis) 189 | signal = np.insert(signal, channels - 1, downmix, axis=axis) 190 | elif method == 'crop': 191 | pass 192 | else: 193 | raise TypeError(f'Method {method} not supported.') 194 | 195 | signal = crop(signal, (0, channels), axis=axis) 196 | return signal 197 | 198 | 199 | def upmix( 200 | signal, 201 | channels, 202 | *, 203 | method='mean', 204 | axis=-2, 205 | ): 206 | r"""Upmix signal to the provided number of channels. 207 | 208 | The upmix is achieved by adding the same signal in the additional channels. 209 | The fixed signal is calculated by one of the following methods: 210 | 211 | * ``'mean'`` mean across all input channels 212 | * ``'zero'`` zeros 213 | * ``'repeat'`` last input channel 214 | 215 | Args: 216 | signal (numpy.ndarray): audio signal 217 | channels (int): number of desired channels 218 | method (str, optional): upmix method. Default: `'mean'` 219 | axis (int, optional): axis to upmix. Default: `-2` 220 | 221 | Returns: 222 | numpy.ndarray: reshaped signal 223 | 224 | Example: 225 | >>> a = np.array([[1, 2], [3, 4]]) 226 | >>> upmix(a, 3) 227 | array([[1., 2.], 228 | [3., 4.], 229 | [2., 3.]]) 230 | 231 | """ 232 | input_channels = np.atleast_2d(signal).shape[axis] 233 | 234 | if input_channels >= channels: 235 | return signal 236 | 237 | signal = np.atleast_2d(signal) 238 | if method == 'mean': 239 | upmix = np.mean(signal, axis=axis) 240 | upmix = np.expand_dims(upmix, axis=axis) 241 | elif method == 'zero': 242 | upmix = np.zeros(signal.shape) 243 | upmix = crop(upmix, -1, axis=axis) 244 | elif method == 'repeat': 245 | upmix = crop(signal, -1, axis=axis) 246 | else: 247 | raise TypeError(f'Method {method} not supported.') 248 | 249 | upmix = np.repeat(upmix, channels - input_channels, axis=axis) 250 | return np.concatenate((signal, upmix), axis=axis) 251 | 252 | 253 | def additive_mix( 254 | signal1, 255 | signal2, 256 | ratio, 257 | ): 258 | r"""Mix two signals additively by given ratio. 259 | 260 | If the power of one of the signals is below 1e-7, the signals are added 261 | without adjusting the signal-to-noise ratio. 262 | 263 | Args: 264 | signal1 (numpy.ndarray): audio signal 265 | signal2 (numpy.ndarray): audio signal 266 | ratio (int): ratio in dB of the second signal compared to the first one 267 | 268 | Returns: 269 | numpy.ndarray: mixture 270 | 271 | Example: 272 | >>> a = np.array([[1, 2], [3, 4]]) 273 | >>> additive_mix(a, a, -10 * np.log10(0.5 ** 2)) 274 | array([[1.5, 3. ], 275 | [4.5, 6. 
]]) 276 | 277 | """ 278 | if signal1.shape != signal2.shape: 279 | raise ValueError( 280 | f'Shape of signal1 ({signal1.shape}) ' 281 | f'and signal2 ({signal2.shape}) do not match' 282 | ) 283 | # If one of the signals includes only silence, don't apply SNR 284 | tol = 1e-7 285 | if utils.power(signal1) < tol or utils.power(signal2) < tol: 286 | scaling_factor = 1 287 | else: 288 | scaling_factor = (utils.power(signal1) 289 | / utils.power(signal2) 290 | * 10 ** (-ratio / 10)) 291 | return signal1 + np.sqrt(scaling_factor) * signal2 292 | 293 | 294 | def normalize( 295 | signal, 296 | *, 297 | axis=None, 298 | ): 299 | r"""Normalize signal. 300 | 301 | Ensure the maximum of the absolute value of the signal is 1. 302 | 303 | Note: 304 | The signal will never be divided by a number smaller than 1e-7. 305 | Meaning signals which are nearly silent are only slightly 306 | amplified. 307 | 308 | Args: 309 | signal (numpy.ndarray): audio signal 310 | axis (int, optional): normalize only along the given axis. 311 | Default: `None` 312 | 313 | Returns: 314 | numpy.ndarray: normalized signal 315 | 316 | Example: 317 | >>> a = np.array([[1, 2], [3, 4]]) 318 | >>> normalize(a) 319 | array([[0.25, 0.5 ], 320 | [0.75, 1. ]]) 321 | 322 | """ 323 | if axis is not None: 324 | peak = np.expand_dims(np.amax(np.abs(signal), axis=axis), axis=axis) 325 | else: 326 | peak = np.amax(np.abs(signal)) 327 | return signal / np.maximum(peak, 1e-7) 328 | 329 | 330 | def standardize( 331 | signal, 332 | *, 333 | mean=True, 334 | std=True, 335 | axis=None, 336 | ): 337 | r"""Standardize signal. 338 | 339 | Ensure the signal has a mean value of 0 and a variance of 1. 340 | 341 | Note: 342 | The signal will never be divided by a variance smaller than 1e-7. 343 | 344 | Args: 345 | signal (numpy.ndarray): audio signal 346 | mean (bool, optional): apply mean centering. Default: `True` 347 | std (bool, optional): normalize by standard deviation. Default: `True` 348 | axis (int, optional): standardize only along the given axis. 349 | Default: `None` 350 | 351 | Returns: 352 | numpy.ndarray: standardized signal 353 | 354 | Example: 355 | >>> a = np.array([[1, 2], [3, 4]]) 356 | >>> standardize(a) 357 | array([[-1.34164079, -0.4472136 ], 358 | [ 0.4472136 , 1.34164079]]) 359 | 360 | """ 361 | if mean: 362 | signal_mean = np.mean(signal, axis=axis) 363 | if axis is not None: 364 | signal_mean = np.expand_dims(signal_mean, axis=axis) 365 | signal = signal - signal_mean 366 | if std: 367 | signal_std = np.std(signal, axis=axis) 368 | if axis is not None: 369 | signal_std = np.expand_dims(signal_std, axis=axis) 370 | signal = signal / np.maximum(signal_std, 1e-7) 371 | return signal 372 | 373 | 374 | def stft( 375 | signal, 376 | window_size, 377 | hop_size, 378 | *, 379 | fft_size=None, 380 | window='hann', 381 | axis=-1, 382 | ): 383 | r"""Short-time Fourier transform. 384 | 385 | The Short-time Fourier transform (STFT) is calculated by using librosa. 386 | It returns an array with the same shape as the input array, besides the 387 | axis chosen for STFT calculation is replaced by the two new ones of the 388 | spectrogram. 389 | 390 | The chosen FFT size is set identical to `window_size`. 391 | 392 | Args: 393 | signal (numpy.ndarray): audio signal 394 | window_size (int): size of STFT window in samples 395 | hop_size (int): size of STFT window hop in samples 396 | window (str, tuple, number, function, or numpy.ndarray, optional): type 397 | of STFT window. Default: `hann` 398 | axis (int, optional): axis of STFT calculation. 
Default: `-1` 399 | 400 | Returns: 401 | numpy.ndarray: complex spectrogram with the shape of its last two 402 | dimensions as `(window_size/2 + 1, 403 | np.ceil((len(signal) + window_size/2) / hop_size))` 404 | 405 | Example: 406 | >>> a = np.array([1., 2., 3., 4.]) 407 | >>> stft(a, 2, 1) 408 | array([[ 1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j, 3.+0.j], 409 | [-1.+0.j, -2.+0.j, -3.+0.j, -4.+0.j, -3.+0.j]]) 410 | 411 | """ 412 | samples = signal.shape[axis] 413 | if samples < window_size: 414 | raise ValueError( 415 | f'`signal` of length {samples} needs to be at least ' 416 | f'as long as the `window_size` of {window_size}' 417 | ) 418 | 419 | # Pad to ensure same signal length after reconstruction 420 | # See discussion at https://github.com/librosa/librosa/issues/328 421 | signal = pad(signal, (0, np.mod(samples, hop_size)), value=0, axis=axis) 422 | if fft_size is None: 423 | fft_size = window_size 424 | fft_config = dict(n_fft=fft_size, hop_length=hop_size, 425 | win_length=window_size, window=window) 426 | spectrogram = np.apply_along_axis(librosa.stft, axis, signal, **fft_config) 427 | return spectrogram 428 | 429 | 430 | def istft( 431 | spectrogram, 432 | window_size, 433 | hop_size, 434 | *, 435 | window='hann', 436 | axis=-2, 437 | ): 438 | r"""Inverse Short-time Fourier transform. 439 | 440 | The inverse Short-time Fourier transform (iSTFT) is calculated by using 441 | librosa. 442 | It handles multi-dimensional inputs, but assumes that the two spectrogram 443 | axis are beside each other, starting with the axis corresponding to 444 | frequency bins. 445 | The returned audio signal has one dimension less than the spectrogram. 446 | 447 | Args: 448 | spectrogram (numpy.ndarray): complex spectrogram 449 | window_size (int): size of STFT window in samples 450 | hop_size (int): size of STFT window hop in samples 451 | window (str, tuple, number, function, or numpy.ndarray, optional): type 452 | of STFT window. Default: `hann` 453 | axis (int, optional): axis of frequency bins of the spectrogram. Time 454 | bins are expected at `axis + 1`. Default: `-2` 455 | 456 | Returns: 457 | numpy.ndarray: signal with shape `(number_of_time_bins * hop_size - 458 | window_size/2)` 459 | 460 | Example: 461 | >>> a = np.array([1., 2., 3., 4.]) 462 | >>> D = stft(a, 4, 1) 463 | >>> istft(D, 4, 1) 464 | array([1., 2., 3., 4.]) 465 | 466 | """ 467 | if axis == -1: 468 | raise ValueError('`axis` of spectrogram frequency bins cannot be -1') 469 | 470 | ifft_config = dict(hop_length=hop_size, win_length=window_size, 471 | window=window) 472 | # Size of frequency and time axis 473 | f = spectrogram.shape[axis] 474 | t = spectrogram.shape[axis + 1] 475 | # Reshape the two axes of spectrogram into one axis 476 | shape_before = spectrogram.shape[:axis] 477 | shape_after = spectrogram.shape[axis:][2:] 478 | D = np.reshape(spectrogram, [*shape_before, f * t, *shape_after]) 479 | # Adjust negative axis values as the second spectrogram axis was removed 480 | if axis < -1: 481 | axis += 1 482 | # iSTFT along the axis 483 | signal = np.apply_along_axis(_istft, axis, D, f, t, **ifft_config) 484 | # Remove padding that was added for STFT 485 | samples = signal.shape[axis] 486 | signal = crop(signal, (0, samples - np.mod(samples, hop_size)), axis=axis) 487 | return signal 488 | 489 | 490 | def _istft(spectrogram, frequency_bins, time_bins, **config): 491 | """Inverse Short-time Fourier transform from a single axis. 492 | 493 | Time and frequency bins have to be provided in a single vector. 
This allows 494 | effective computation using `numpy.apply_along_axis`. 495 | 496 | Args: 497 | spectrogram (numpy.array): one dimensional vector 498 | frequency_bins (int): number of frequency bins 499 | time_bins (int): number of time bins 500 | **config (dict, optional): keyword arguments for librosa.istft 501 | 502 | Returns: 503 | numpy.array: time-series 504 | 505 | """ 506 | # Reshape to [frequency_bins, time_bins] as expected by librosa 507 | spectrogram = np.reshape(spectrogram, [frequency_bins, time_bins]) 508 | return librosa.istft(spectrogram, **config) 509 | -------------------------------------------------------------------------------- /audtorch/utils.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import threading 3 | import queue 4 | 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | 9 | def flatten_list( 10 | nested_list, 11 | ): 12 | """Flatten an arbitrarily nested list. 13 | 14 | Implemented without recursion to avoid stack overflows. 15 | Returns a new list, the original list is unchanged. 16 | 17 | Args: 18 | nested_list (list): nested list 19 | 20 | Returns: 21 | list: flattened list 22 | 23 | Example: 24 | >>> flatten_list([1, 2, 3, [4], [], [[[[[[[[[5]]]]]]]]]]) 25 | [1, 2, 3, 4, 5] 26 | >>> flatten_list([[1, 2], 3]) 27 | [1, 2, 3] 28 | 29 | """ 30 | def _flat_generator(nested_list): 31 | while nested_list: 32 | sublist = nested_list.pop(0) 33 | if isinstance(sublist, list): 34 | nested_list = sublist + nested_list 35 | else: 36 | yield sublist 37 | nested_list = deepcopy(nested_list) 38 | return list(_flat_generator(nested_list)) 39 | 40 | 41 | def to_tuple( 42 | input, 43 | *, 44 | tuple_len=2, 45 | ): 46 | r"""Convert to tuple of given length. 47 | 48 | This utility function is used to convert single-value arguments to tuples 49 | of appropriate length, e.g. for multi-dimensional inputs where each 50 | dimension requires the same value. If the argument is already an iterable 51 | it is returned as a tuple if its length matches the desired tuple length. 52 | Otherwise a `ValueError` is raised. 53 | 54 | Args: 55 | input (non-iterable or iterable): argument to be converted to tuple 56 | tuple_len (int): required length of argument tuple. Default: `2` 57 | 58 | Returns: 59 | tuple: tuple of desired length 60 | 61 | Example: 62 | >>> to_tuple(2) 63 | (2, 2) 64 | 65 | """ 66 | try: 67 | iter(input) 68 | if len(input) != tuple_len: 69 | raise ValueError( 70 | f'Input length expected to be {tuple_len} but was {len(input)}' 71 | ) 72 | else: 73 | input = tuple(input) 74 | except TypeError: 75 | input = tuple([input] * tuple_len) 76 | return input 77 | 78 | 79 | def energy( 80 | signal, 81 | ): 82 | r"""Energy of input signal. 83 | 84 | .. math:: 85 | E = \sum_n |x_n|^2 86 | 87 | Args: 88 | signal (numpy.ndarray): signal 89 | 90 | Returns: 91 | float: energy of signal 92 | 93 | Example: 94 | >>> a = np.array([[2, 2]]) 95 | >>> energy(a) 96 | 8 97 | 98 | """ 99 | return np.sum(np.abs(signal) ** 2) 100 | 101 | 102 | def power( 103 | signal, 104 | ): 105 | r"""Power of input signal. 106 | 107 | .. 
math:: 108 | P = {1 \over N} \sum_n |x_n|^2 109 | 110 | Args: 111 | signal (numpy.ndarray): signal 112 | 113 | Returns: 114 | float: power of signal 115 | 116 | Example: 117 | >>> a = np.array([[2, 2]]) 118 | >>> power(a) 119 | 4.0 120 | 121 | """ 122 | return np.sum(np.abs(signal) ** 2) / signal.size 123 | 124 | 125 | def run_worker_threads( 126 | num_workers, 127 | task_fun, params, 128 | *, 129 | progress_bar=False, 130 | ): 131 | r"""Run parallel tasks using worker threads. 132 | 133 | Args: 134 | num_workers (int): number of worker threads 135 | task_fun (Callable): task function with one or more 136 | parameters, e.g. x, y, z, and optionally returning a value 137 | params (list of tuples): list of tuples holding parameters 138 | for each task, e.g. [(x1, y1, z1), (x2, y2, z2), ...] 139 | progress_bar (bool): show a progress bar. Default: False 140 | 141 | Returns: 142 | list: result values in order of `params` 143 | 144 | Example: 145 | >>> power = lambda x, n: x ** n 146 | >>> params = [(2, n) for n in range(10)] 147 | >>> run_worker_threads(3, power, params) 148 | [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] 149 | 150 | """ 151 | num_workers = max(0, num_workers) 152 | n_tasks = len(params) 153 | results = [None] * n_tasks 154 | 155 | # do not use more workers as needed 156 | num_workers = n_tasks if num_workers == 0 else min(num_workers, n_tasks) 157 | 158 | # num_workers == 1 -> run sequentially 159 | if num_workers == 1: 160 | for index, param in enumerate(params): 161 | results[index] = task_fun(*param) 162 | 163 | # number_workers > 1 -> parallalize work 164 | else: 165 | 166 | # define worker thread 167 | def _worker(): 168 | while True: 169 | item = q.get() 170 | if item is None: 171 | break 172 | index, param = item 173 | results[index] = task_fun(*param) 174 | q.task_done() 175 | 176 | # create queue, possibly with a progress bar 177 | if progress_bar: 178 | class QueueWithProgbar(queue.Queue): 179 | def __init__(self, n_tasks, maxsize=0): 180 | super().__init__(maxsize) 181 | self.pbar = tqdm(total=n_tasks) 182 | 183 | def task_done(self): 184 | super().task_done() 185 | self.pbar.update(1) 186 | q = QueueWithProgbar(n_tasks) 187 | else: 188 | q = queue.Queue() 189 | 190 | # fill queue 191 | for index, param in enumerate(params): 192 | q.put((index, param)) 193 | 194 | # start workers 195 | threads = [] 196 | for i in range(num_workers): 197 | t = threading.Thread(target=_worker) 198 | t.start() 199 | threads.append(t) 200 | 201 | # block until all tasks are done 202 | q.join() 203 | 204 | # stop workers 205 | for _ in range(num_workers): 206 | q.put(None) 207 | for t in threads: 208 | t.join() 209 | 210 | return results 211 | -------------------------------------------------------------------------------- /docs/api-collate.rst: -------------------------------------------------------------------------------- 1 | audtorch.collate 2 | ================ 3 | 4 | Collate functions manipulate and merge a list 5 | of samples to form a mini-batch, see :py:class:`torch.utils.data.DataLoader`. 6 | An example use case is batching sequences of variable-length, 7 | which requires padding each sample to the maximum length in the batch. 8 | 9 | .. automodule:: audtorch.collate 10 | 11 | Collation 12 | --------- 13 | 14 | .. autoclass:: Collation 15 | :members: 16 | 17 | Seq2Seq 18 | ------- 19 | 20 | .. 
autoclass:: Seq2Seq 21 | :members: 22 | -------------------------------------------------------------------------------- /docs/api-datasets.rst: -------------------------------------------------------------------------------- 1 | audtorch.datasets 2 | ================= 3 | 4 | Audio data sets. 5 | 6 | .. automodule:: audtorch.datasets 7 | 8 | 9 | AudioSet 10 | -------- 11 | 12 | .. autoclass:: AudioSet 13 | :members: 14 | 15 | EmoDB 16 | ----- 17 | 18 | .. autoclass:: EmoDB 19 | :members: 20 | 21 | LibriSpeech 22 | ----------- 23 | 24 | .. autoclass:: LibriSpeech 25 | :members: 26 | 27 | MozillaCommonVoice 28 | ------------------ 29 | 30 | .. autoclass:: MozillaCommonVoice 31 | :members: 32 | 33 | SpeechCommands 34 | -------------- 35 | 36 | .. autoclass:: SpeechCommands 37 | :members: 38 | 39 | VoxCeleb1 40 | --------- 41 | 42 | .. autoclass:: VoxCeleb1 43 | :members: 44 | 45 | WhiteNoise 46 | ---------- 47 | 48 | .. autoclass:: WhiteNoise 49 | :members: 50 | 51 | Base 52 | ---- 53 | 54 | This section contains a mix of generic data sets that are useful for a wide 55 | variety of cases and can be used as base classes for other data sets. 56 | 57 | AudioDataset 58 | ~~~~~~~~~~~~ 59 | 60 | .. autoclass:: AudioDataset 61 | :members: 62 | 63 | PandasDataset 64 | ~~~~~~~~~~~~~ 65 | 66 | .. autoclass:: PandasDataset 67 | :members: 68 | 69 | CsvDataset 70 | ~~~~~~~~~~ 71 | 72 | .. autoclass:: CsvDataset 73 | :members: 74 | 75 | AudioConcatDataset 76 | ~~~~~~~~~~~~~~~~~~ 77 | 78 | .. autoclass:: AudioConcatDataset 79 | :members: 80 | 81 | Mixture 82 | ------- 83 | 84 | This section contains data sets that are primarily used for mixing different 85 | data sets. 86 | 87 | SpeechNoiseMix 88 | ~~~~~~~~~~~~~~ 89 | 90 | .. autoclass:: SpeechNoiseMix 91 | :members: 92 | 93 | Utils 94 | ----- 95 | 96 | Utility functions for handling audio data sets. 97 | 98 | load 99 | ~~~~ 100 | 101 | .. autofunction:: load 102 | 103 | download_url 104 | ~~~~~~~~~~~~ 105 | 106 | .. autofunction:: download_url 107 | 108 | download_url_list 109 | ~~~~~~~~~~~~~~~~~ 110 | 111 | .. autofunction:: download_url_list 112 | 113 | extract_archive 114 | ~~~~~~~~~~~~~~~ 115 | 116 | .. autofunction:: extract_archive 117 | 118 | sampling_rate_after_transform 119 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 120 | 121 | .. autofunction:: sampling_rate_after_transform 122 | 123 | ensure_same_sampling_rate 124 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 125 | 126 | .. autofunction:: ensure_same_sampling_rate 127 | 128 | ensure_df_columns_contain 129 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 130 | 131 | .. autofunction:: ensure_df_columns_contain 132 | 133 | ensure_df_not_empty 134 | ~~~~~~~~~~~~~~~~~~~ 135 | 136 | .. autofunction:: ensure_df_not_empty 137 | 138 | files_and_labels_from_df 139 | ~~~~~~~~~~~~~~~~~~~~~~~~ 140 | 141 | .. autofunction:: files_and_labels_from_df 142 | 143 | defined_split 144 | ~~~~~~~~~~~~~ 145 | 146 | .. autofunction:: defined_split 147 | -------------------------------------------------------------------------------- /docs/api-metrics-functional.rst: -------------------------------------------------------------------------------- 1 | audtorch.metrics.functional 2 | =========================== 3 | 4 | The goal of the metrics functionals is to provide functions that work 5 | independent on the dimensions of the input signal and can be used easily to 6 | create additional metrics and losses. 7 | 8 | .. automodule:: audtorch.metrics.functional 9 | 10 | pearsonr 11 | -------- 12 | 13 | .. 
autofunction:: pearsonr 14 | 15 | concordance_cc 16 | -------------- 17 | 18 | .. autofunction:: concordance_cc 19 | 20 | -------------------------------------------------------------------------------- /docs/api-metrics.rst: -------------------------------------------------------------------------------- 1 | audtorch.metrics 2 | ================ 3 | 4 | .. automodule:: audtorch.metrics 5 | 6 | EnergyConservingLoss 7 | -------------------- 8 | 9 | .. autoclass:: EnergyConservingLoss 10 | :members: 11 | 12 | PearsonR 13 | -------- 14 | 15 | .. autoclass:: PearsonR 16 | :members: 17 | 18 | ConcordanceCC 19 | ------------- 20 | 21 | .. autoclass:: ConcordanceCC 22 | :members: 23 | -------------------------------------------------------------------------------- /docs/api-samplers.rst: -------------------------------------------------------------------------------- 1 | audtorch.samplers 2 | ================= 3 | 4 | .. automodule:: audtorch.samplers 5 | 6 | BucketSampler 7 | ----------------------------- 8 | 9 | .. autoclass:: BucketSampler 10 | :members: 11 | 12 | buckets_by_boundaries 13 | --------------------------- 14 | 15 | .. autofunction:: buckets_by_boundaries 16 | 17 | buckets_of_even_size 18 | ------------------------------ 19 | 20 | .. autofunction:: buckets_of_even_size 21 | -------------------------------------------------------------------------------- /docs/api-transforms-functional.rst: -------------------------------------------------------------------------------- 1 | audtorch.transforms.functional 2 | ============================== 3 | 4 | The goal of the transform functionals is to provide functions that work 5 | independent on the dimensions of the input signal and can be used easily to 6 | create the actual transforms. 7 | 8 | .. Note:: 9 | 10 | All of the transforms work currently only with :py:obj:`numpy.array` as 11 | inputs, not :py:obj:`torch.Tensor`. 12 | 13 | .. automodule:: audtorch.transforms.functional 14 | 15 | crop 16 | ---- 17 | 18 | .. autofunction:: crop 19 | 20 | pad 21 | --- 22 | 23 | .. autofunction:: pad 24 | 25 | replicate 26 | --------- 27 | 28 | .. autofunction:: replicate 29 | 30 | downmix 31 | ------- 32 | 33 | .. autofunction:: downmix 34 | 35 | upmix 36 | ----- 37 | 38 | .. autofunction:: upmix 39 | 40 | additive_mix 41 | ------------ 42 | 43 | .. autofunction:: additive_mix 44 | 45 | mask 46 | ---- 47 | 48 | .. autofunction:: mask 49 | 50 | normalize 51 | --------- 52 | 53 | .. autofunction:: normalize 54 | 55 | standardize 56 | ----------- 57 | 58 | .. autofunction:: standardize 59 | 60 | stft 61 | ---- 62 | 63 | .. autofunction:: stft 64 | 65 | istft 66 | ----- 67 | 68 | .. autofunction:: istft 69 | -------------------------------------------------------------------------------- /docs/api-transforms.rst: -------------------------------------------------------------------------------- 1 | audtorch.transforms 2 | =================== 3 | 4 | The transforms can be provided to :py:class:`audtorch.datasets` as an argument 5 | and work on the data before it will be returned. 6 | 7 | .. Note:: 8 | 9 | All of the transforms work currently only with :py:obj:`numpy.array` as 10 | inputs, not :py:obj:`torch.Tensor`. 11 | 12 | .. automodule:: audtorch.transforms 13 | 14 | Compose 15 | ------- 16 | 17 | .. autoclass:: Compose 18 | :members: 19 | 20 | Crop 21 | ---- 22 | 23 | .. autoclass:: Crop 24 | :members: 25 | 26 | RandomCrop 27 | ---------- 28 | 29 | .. autoclass:: RandomCrop 30 | :members: 31 | 32 | Pad 33 | --- 34 | 35 | .. 
autoclass:: Pad 36 | :members: 37 | 38 | RandomPad 39 | --------- 40 | 41 | .. autoclass:: RandomPad 42 | :members: 43 | 44 | Replicate 45 | --------- 46 | 47 | .. autoclass:: Replicate 48 | :members: 49 | 50 | RandomReplicate 51 | --------------- 52 | 53 | .. autoclass:: RandomReplicate 54 | :members: 55 | 56 | Expand 57 | ------ 58 | 59 | .. autoclass:: Expand 60 | :members: 61 | 62 | RandomMask 63 | ---------- 64 | 65 | .. autoclass:: RandomMask 66 | :members: 67 | 68 | MaskSpectrogramTime 69 | ------------------- 70 | 71 | .. autoclass:: MaskSpectrogramTime 72 | :members: 73 | 74 | MaskSpectrogramFrequency 75 | ------------------------ 76 | 77 | .. autoclass:: MaskSpectrogramFrequency 78 | :members: 79 | 80 | Downmix 81 | ------- 82 | 83 | .. autoclass:: Downmix 84 | :members: 85 | 86 | Upmix 87 | ----- 88 | 89 | .. autoclass:: Upmix 90 | :members: 91 | 92 | Remix 93 | ----- 94 | 95 | .. autoclass:: Remix 96 | :members: 97 | 98 | Normalize 99 | --------- 100 | 101 | .. autoclass:: Normalize 102 | :members: 103 | 104 | Standardize 105 | ----------- 106 | 107 | .. autoclass:: Standardize 108 | :members: 109 | 110 | Resample 111 | -------- 112 | 113 | .. autoclass:: Resample 114 | :members: 115 | 116 | Spectrogram 117 | ----------- 118 | 119 | .. autoclass:: Spectrogram 120 | :members: 121 | 122 | Log 123 | --- 124 | 125 | .. autoclass:: Log 126 | :members: 127 | 128 | RandomAdditiveMix 129 | ----------------- 130 | 131 | .. autoclass:: RandomAdditiveMix 132 | :members: 133 | 134 | RandomConvolutionalMix 135 | ---------------------- 136 | 137 | .. autoclass:: RandomConvolutionalMix 138 | :members: 139 | -------------------------------------------------------------------------------- /docs/api-utils.rst: -------------------------------------------------------------------------------- 1 | audtorch.utils 2 | ============== 3 | 4 | Utility functions. 5 | 6 | .. automodule:: audtorch.utils 7 | 8 | flatten_list 9 | ------------ 10 | 11 | .. autofunction:: flatten_list 12 | 13 | to_tuple 14 | -------- 15 | 16 | .. autofunction:: to_tuple 17 | 18 | energy 19 | ------ 20 | 21 | .. autofunction:: energy 22 | 23 | power 24 | ----- 25 | 26 | .. autofunction:: power 27 | 28 | run_worker_threads 29 | ------------------ 30 | 31 | .. autofunction:: run_worker_threads 32 | -------------------------------------------------------------------------------- /docs/changelog.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CHANGELOG.rst 2 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from subprocess import check_output 4 | 5 | 6 | # Import ------------------------------------------------------------------ 7 | 8 | # Relative source code path. Avoids `import audtorch`, which need package to be 9 | # installed first. 
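Adding the repository root to `sys.path` lets Sphinx import the local sources instead.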
10 | sys.path.insert(0, os.path.abspath('..')) 11 | 12 | # Ignore package dependencies during building the docs 13 | autodoc_mock_imports = [ 14 | 'audiofile', 15 | 'librosa', 16 | 'numpy', 17 | 'pandas', 18 | 'resampy', 19 | 'torch', 20 | 'scipy', 21 | 'tqdm', 22 | 'tabulate', 23 | ] 24 | 25 | 26 | # Project ----------------------------------------------------------------- 27 | 28 | project = 'audtorch' 29 | copyright = '2019 audEERING GmbH' 30 | author = ('Andreas Triantafyllopoulos, ' 31 | 'Stephan Huber, ' 32 | 'Johannes Wagner, ' 33 | 'Hagen Wierstorf') 34 | # The x.y.z version read from tags 35 | try: 36 | version = check_output(['git', 'describe', '--tags', '--always']) 37 | version = version.decode().strip() 38 | except Exception: 39 | version = '' 40 | title = '{} Documentation'.format(project) 41 | 42 | 43 | # General ----------------------------------------------------------------- 44 | 45 | master_doc = 'index' 46 | extensions = [] 47 | source_suffix = '.rst' 48 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.ipynb_checkpoints'] 49 | pygments_style = None 50 | extensions = [ 51 | 'sphinx.ext.autodoc', 52 | 'sphinx.ext.doctest', 53 | 'sphinx.ext.intersphinx', 54 | 'sphinx.ext.napoleon', # support for Google-style docstrings 55 | 'sphinx_copybutton', 56 | 'sphinxcontrib.bibtex', 57 | 'sphinxcontrib.katex', 58 | 'nbsphinx', 59 | ] 60 | 61 | napoleon_use_ivar = True # List of class attributes 62 | autodoc_inherit_docstrings = False # disable docstring inheritance 63 | 64 | intersphinx_mapping = { 65 | 'python': ('https://docs.python.org/3/', None), 66 | 'numpy': ('https://docs.scipy.org/doc/numpy/', None), 67 | 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None), 68 | 'torch': ('https://pytorch.org/docs/stable/', None), 69 | } 70 | 71 | copybutton_prompt_text = r'>>> |\.\.\. |$ ' 72 | copybutton_prompt_is_regexp = True 73 | 74 | nbsphinx_execute = 'never' 75 | 76 | linkcheck_ignore = [ 77 | 'https://doi.org/', # has timeouts from time to time 78 | ] 79 | bibtex_bibfiles = ['refs.bib'] 80 | 81 | # HTML -------------------------------------------------------------------- 82 | 83 | html_theme = 'sphinx_audeering_theme' 84 | html_theme_options = { 85 | 'display_version': True, 86 | 'footer_links': False, 87 | 'logo_only': False, 88 | } 89 | html_context = { 90 | 'display_github': True, 91 | } 92 | 93 | html_title = title 94 | -------------------------------------------------------------------------------- /docs/develop.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /docs/genindex.rst: -------------------------------------------------------------------------------- 1 | Index 2 | ----- 3 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | 3 | .. toctree:: 4 | :caption: Getting started 5 | :hidden: 6 | 7 | install 8 | usage 9 | develop 10 | changelog 11 | 12 | .. Warning: then usage of genindex is a hack to get a TOC entry, see 13 | .. https://stackoverflow.com/a/42310803. This might break the usage of sphinx if 14 | .. you want to create something different than HTML output. 15 | 16 | .. toctree:: 17 | :caption: Tutorials 18 | :hidden: 19 | 20 | tutorials/introduction 21 | 22 | .. 
toctree:: 23 | :caption: API Documentation 24 | :hidden: 25 | 26 | api-collate 27 | api-datasets 28 | api-metrics 29 | api-metrics-functional 30 | api-samplers 31 | api-transforms 32 | api-transforms-functional 33 | api-utils 34 | refs 35 | genindex 36 | -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | :mod:`audtorch` requires Python 3.6 or higher. To install it run 5 | (preferably in a `virtual environment`_): 6 | 7 | .. code-block:: bash 8 | 9 | pip install audtorch 10 | 11 | .. _virtual environment: https://docs.python-guide.org/dev/virtualenvs 12 | -------------------------------------------------------------------------------- /docs/refs.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{Rethage2018, 2 | author = {Rethage, Dario and Pons, Jordi and Serra, Xavier}, 3 | booktitle = {IEEE International Conference on Acoustics, Speech and 4 | Signal Processing (ICASSP)}, 5 | title = {A Wavenet for Speech Denoising}, 6 | pages = {5069-5073}, 7 | doi = {10.1109/ICASSP.2018.8462417}, 8 | url = {https://arxiv.org/abs/1706.07162}, 9 | year = {2018} 10 | } 11 | @inproceedings{nagrani2017voxceleb, 12 | author={Arsha Nagrani and Joon Son Chung and Andrew Zisserman}, 13 | title={VoxCeleb: A Large-Scale Speaker Identification Dataset}, 14 | year=2017, 15 | booktitle={Proc. Interspeech 2017}, 16 | pages={2616--2620}, 17 | doi={10.21437/Interspeech.2017-950}, 18 | url={https://arxiv.org/abs/1706.08612} 19 | } 20 | @inproceedings{burkhardt2005database, 21 | title={A database of German emotional speech}, 22 | author={Burkhardt, Felix and Paeschke, Astrid and 23 | Rolfes, Miriam and Sendlmeier, Walter F and Weiss, Benjamin}, 24 | booktitle={Ninth European Conference on Speech Communication and Technology}, 25 | year={2005} 26 | } 27 | -------------------------------------------------------------------------------- /docs/refs.rst: -------------------------------------------------------------------------------- 1 | .. _sec-references: 2 | 3 | .. only:: html 4 | 5 | References 6 | ---------- 7 | 8 | ..
bibliography:: refs.bib 9 | :style: alpha 10 | :all: 11 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | jupyter 2 | matplotlib 3 | nbsphinx 4 | sphinx 5 | sphinx-audeering-theme >=0.9.0 6 | sphinxcontrib-katex 7 | sphinxcontrib-bibtex >=2.1.0 8 | sphinx-copybutton 9 | -------------------------------------------------------------------------------- /docs/tutorials/data/emodb/03a01Eb.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/audeering/audtorch/d82ae7f7f8c7edb7b7180b83442224e9a68483bd/docs/tutorials/data/emodb/03a01Eb.wav -------------------------------------------------------------------------------- /docs/tutorials/data/emodb/03a01Fa.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/audeering/audtorch/d82ae7f7f8c7edb7b7180b83442224e9a68483bd/docs/tutorials/data/emodb/03a01Fa.wav -------------------------------------------------------------------------------- /docs/tutorials/introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Introduction\n", 8 | "In this tutorial, we will see how one can use `audtorch` to rapidly speed up the development of audio-based deep learning applications." 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "## Preliminaries\n", 16 | "\n", 17 | "* [PyTorch](https://pytorch.org/) already has an inteface for data sets, aptly called [Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) \n", 18 | "* It then wraps this interface with a [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) that efficiently allows us to loop through the data in parallel, and takes care of the random order as well\n", 19 | "* All we need to do is implement the [Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) interface to get the input for the model and the labels\n", 20 | "* **However**, it is not easy for beginners to see how one can go from a bunch of files in their hard drive, to the features that will be used as input in a machine learning model\n", 21 | "* **Thankfully**, `audtorch` is there to take of all that for you :-)\n", 22 | "\n", 23 | "Before you start, you might want to familiarize yourselves with [PyTorch's data pipeline](https://pytorch.org/docs/stable/data.html)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Data loading\n", 31 | "We are going to start with loading the necessary data. \n", 32 | "\n", 33 | "`audtorch` offers a growing [collection of data sets](https://audeering.github.io/audtorch/api-datasets.html). Normally, using this interface requires one to have that particular data set on their hard drive. Some of them even support downloading from their original source. \n", 34 | "\n", 35 | "We will be using the Berlin Database of Emotional Speech (EmoDB) for this tutorial. For convenience, we have included two of its files in a sub-directory. We recommend you to get the whole data base from its [original website ](http://www.emodb.bilderbar.info/navi.html). 
" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "%matplotlib inline\n", 45 | "import matplotlib\n", 46 | "import matplotlib.pyplot as plt\n", 47 | "import numpy as np\n", 48 | "import audtorch\n", 49 | "import IPython.display as ipd" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "dataset = audtorch.datasets.EmoDB(\n", 59 | " root='data/emodb'\n", 60 | ")\n", 61 | "print(dataset)" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "x, y = dataset[0]\n", 71 | "print(x.shape)\n", 72 | "print(y)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "ipd.Audio(x, rate=dataset.sampling_rate)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "That's it really. Up to this point, `audtorch` does not add much to the PyTorch's data API, which is already quite advanced anyway." 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "## Feature extraction\n", 96 | "Feature extraction is the first important benefit of using `audtorch`. \n", 97 | "\n", 98 | "`audtorch` collects an ever growing set of [feature transformation and data pre-processing utilities](https://audeering.github.io/audtorch/api-transforms.html#). That way you don't need to worry too much about getting your data pipeline ready, but you can quickly start with the cool modelling part. \n", 99 | "\n", 100 | "A typical kind of features used in the audio domain, are spectral features. Audio signals are analyzed with respect to their frequency content using something called a [Fourier transform](https://en.wikipedia.org/wiki/Fourier_transform). \n", 101 | "\n", 102 | "Moreover, since that content changes over time, we normally use a [short-time Fourier Transform](https://en.wikipedia.org/wiki/Short-time_Fourier_transform). This leads then to the generation of a so-called [spectrogram](https://en.wikipedia.org/wiki/Spectrogram), which is nothing more than an image representation of the frequency content of a signal over time. \n", 103 | "\n", 104 | "We assume that the reader is already familiar with this terminology. What's important to point out, is that `audtorch` is designed to allow for easy usage of those features in a typical `PyTorch` workflow. Below, we see an example of how a feature extraction transform is defined:" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "spec = audtorch.transforms.Spectrogram(\n", 114 | " window_size=int(0.025 * dataset.sampling_rate),\n", 115 | " hop_size=int(0.010 * dataset.sampling_rate)\n", 116 | ")\n", 117 | "print(spec)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "By plotting the spectrogram, we see what frequency content our signal has over time." 
125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "spectrogram = spec(x)\n", 134 | "plt.imshow(spectrogram.squeeze())\n", 135 | "plt.gca().invert_yaxis()" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": {}, 141 | "source": [ 142 | "The above image looks mostly empty. That's why we have a lot of content with very low power that is dominated by the presence of a few frequencies where most of the signal's power is concentrated. \n", 143 | "\n", 144 | "It is typical to compute the logarithm of the spectrogram to reveal more information. That squashes the input and reveals previously hidden structure in other frequency bands. Incidentally, this squashing reduces the dynamic range of the resulting image, which makes our input more suitable for deep neural network training. \n", 145 | "\n", 146 | "`audtorch` provides a nice wrapper function for [numpy's log](https://docs.scipy.org/doc/numpy/reference/generated/numpy.log.html) to simplify things." 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "lg = audtorch.transforms.Log()\n", 156 | "print(lg)" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "log_spectrogram = lg(spectrogram)" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "plt.imshow(log_spectrogram.squeeze())\n", 175 | "plt.gca().invert_yaxis()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "This image shows that there is a lot more going on in our signal than we previously thought. \n", 183 | "\n", 184 | "In general, we recommend to always start with a preliminary data analysis before you jump into modelling to ensure you have the proper understanding of your problem. \n", 185 | "\n", 186 | "`audtorch` is here to help you with that, and another useful feature is that it allows you to stack multiple transforms in a [Compose transform](https://audeering.github.io/audtorch/api-transforms.html#audtorch.transforms.Compose). Below, we stack together the spectrogram and the log transforms to form a single object." 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "t = audtorch.transforms.Compose(\n", 196 | " [\n", 197 | " audtorch.transforms.Spectrogram(\n", 198 | " window_size=int(0.025 * 16000),\n", 199 | " hop_size=int(0.010 * 16000)\n", 200 | " ),\n", 201 | " audtorch.transforms.Log()\n", 202 | " ]\n", 203 | ")\n", 204 | "print(t)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "plt.imshow(t(x).squeeze())\n", 214 | "plt.gca().invert_yaxis()" 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "metadata": {}, 220 | "source": [ 221 | "This stacking can continue *ad infinum*, as seen below with the [Standardize transform](https://audeering.github.io/audtorch/api-transforms.html#standardize). \n", 222 | "\n", 223 | "Make sure to always stay up to date with [all the transforms offered by audtorch](https://audeering.github.io/audtorch/api-transforms.html)!" 
224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "t = audtorch.transforms.Compose(\n", 233 | " [\n", 234 | " audtorch.transforms.Spectrogram(\n", 235 | " window_size=int(0.025 * 16000),\n", 236 | " hop_size=int(0.010 * 16000)\n", 237 | " ),\n", 238 | " audtorch.transforms.Log(),\n", 239 | " audtorch.transforms.Standardize()\n", 240 | " ]\n", 241 | ")\n", 242 | "print(t)" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": null, 248 | "metadata": {}, 249 | "outputs": [], 250 | "source": [ 251 | "plt.imshow(t(x).squeeze())\n", 252 | "plt.gca().invert_yaxis()" 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": {}, 258 | "source": [ 259 | "## Data augmentation\n", 260 | "\n", 261 | "One of the most crucial aspects of recent deep learning successes is arguably data augmentation. Roughly, this means increasing the sampling of your input space by creating slightly different copies of the original input without changing the label. \n", 262 | "\n", 263 | "In the image domain, people use a variety of transforms, such as:\n", 264 | "\n", 265 | "* Adding noise\n", 266 | "* Cropping\n", 267 | "* Rotating\n", 268 | "* Etc.\n", 269 | "\n", 270 | "Things are not so easy in the audio domain. Rotation, for example, does not make any sense for spectrogram features, since the two axes are not interchangeable. In general, the community seems to use the following transforms:\n", 271 | "\n", 272 | "* Noise\n", 273 | "* Time/frequency masking\n", 274 | "* Pitch shifting\n", 275 | "* Etc.\n", 276 | "\n", 277 | "An important feature of `audtorch` is making these transformations very easy to use in practice. In the following example, we will be using [RandomAdditiveMix](https://audeering.github.io/audtorch/api-transforms.html#randomadditivemix). This transforms allows you to randomly mix audio samples with a noise data set of your choice (e.g. a large audio data set like [AudioSet](https://audeering.github.io/audtorch/api-datasets.html#audioset)). \n", 278 | "\n", 279 | "In this example, we will use a built-in data set, [WhiteNoise](https://audeering.github.io/audtorch/api-datasets.html#whitenoise), which simply creates a random white noise signal every time it is called." 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "metadata": {}, 286 | "outputs": [], 287 | "source": [ 288 | "random_mix = audtorch.transforms.RandomAdditiveMix(\n", 289 | " dataset=audtorch.datasets.WhiteNoise(sampling_rate=dataset.sampling_rate)\n", 290 | ")\n", 291 | "print(random_mix)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "metadata": {}, 297 | "source": [ 298 | "You can see that this transforms modifies the audio signal itself, by adding this \"static\" TV noise to our original signal. Obviously though, the emotion of the speaker remains the same. This is a very practical way to augment your training set without changing the labels." 
299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "import IPython.display as ipd\n", 308 | "ipd.Audio(random_mix(x), rate=dataset.sampling_rate)" 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "### Stacking data augmentation and feature extraction\n", 316 | "What is really important, is that `audtorch` allows us to do simultaneous data augmentation and feature extraction **on-the-fly**. \n", 317 | "\n", 318 | "This is very useful in the typical case where we run the same training samples multiple times through the network (i.e. when we train for multiple epochs), and would like to slightly change the input every time. All we have to do is stack our data augmentation transforms on top of our feature extraction ones." 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": null, 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [ 327 | "t = audtorch.transforms.Compose(\n", 328 | " [\n", 329 | " audtorch.transforms.RandomAdditiveMix(\n", 330 | " dataset=audtorch.datasets.WhiteNoise(sampling_rate=dataset.sampling_rate),\n", 331 | " expand_method='multiple'\n", 332 | " ),\n", 333 | " audtorch.transforms.Spectrogram(\n", 334 | " window_size=int(0.025 * dataset.sampling_rate),\n", 335 | " hop_size=int(0.010 * dataset.sampling_rate)\n", 336 | " ),\n", 337 | " audtorch.transforms.Log(),\n", 338 | " audtorch.transforms.Standardize()\n", 339 | " ]\n", 340 | ")\n", 341 | "print(t)" 342 | ] 343 | }, 344 | { 345 | "cell_type": "markdown", 346 | "metadata": {}, 347 | "source": [ 348 | "We can clearly see how this spectrogram seems noisier than the one we had before. Hopefully, this will be enough to make our classifier generalize better!" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": null, 354 | "metadata": {}, 355 | "outputs": [], 356 | "source": [ 357 | "plt.imshow(t(x).squeeze())\n", 358 | "plt.gca().invert_yaxis()" 359 | ] 360 | } 361 | ], 362 | "metadata": { 363 | "celltoolbar": "Slideshow", 364 | "language_info": { 365 | "codemirror_mode": { 366 | "name": "ipython", 367 | "version": 3 368 | }, 369 | "file_extension": ".py", 370 | "mimetype": "text/x-python", 371 | "name": "python", 372 | "nbconvert_exporter": "python", 373 | "pygments_lexer": "ipython3", 374 | "version": "3.7.4" 375 | } 376 | }, 377 | "nbformat": 4, 378 | "nbformat_minor": 2 379 | } 380 | -------------------------------------------------------------------------------- /docs/usage.rst: -------------------------------------------------------------------------------- 1 | Usage 2 | ===== 3 | 4 | :mod:`audtorch` automates the data iteration process for deep neural 5 | network training using PyTorch_. It provides a set of feature extraction 6 | transforms that can be implemented on-the-fly on the CPU. 7 | 8 | The following example creates a data set of speech samples that are cut to a 9 | fixed length of 10240 samples. In addition they are augmented on the fly during 10 | data loading by a transform that adds samples from another data set: 11 | 12 | .. code-block:: python 13 | 14 | >>> import sounddevice as sd 15 | >>> from audtorch import datasets, transforms 16 | >>> noise = datasets.WhiteNoise(duration=10240, sampling_rate=16000) 17 | >>> augment = transforms.Compose([transforms.RandomCrop(10240), 18 | ... 
transforms.RandomAdditiveMix(noise)]) 19 | >>> data = datasets.LibriSpeech(root='~/LibriSpeech', sets='dev-clean', 20 | ... download=True, transform=augment) 21 | >>> signal, label = data[8] 22 | >>> sd.play(signal.transpose(), data.sampling_rate) 23 | 24 | Besides data sets and transforms the package provides standard evaluation 25 | metrics, samplers, and necessary collate functions for training deep neural 26 | networks for audio tasks. 27 | 28 | .. _PyTorch: https://pytorch.org 29 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e . 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = audtorch 3 | author = Andreas Triantafyllopoulos, Stephan Huber, Johannes Wagner, Hagen Wierstorf 4 | author_email = atriant@audeering.com 5 | description = Deep learning with PyTorch and audio 6 | long_description = file: README.rst, CHANGELOG.rst 7 | license = MIT License 8 | license_file = LICENSE 9 | keywords = audio, torch 10 | url = https://github.com/audeering/audtorch 11 | project_urls = 12 | Documentation = https://audeering.github.io/audtorch/ 13 | Tracker = https://github.com/audeering/audtorch/issues/ 14 | platforms = any 15 | classifiers = 16 | Development Status :: 5 - Production/Stable 17 | Intended Audience :: Developers 18 | Intended Audience :: Science/Research 19 | License :: OSI Approved :: MIT License 20 | Operating System :: OS Independent 21 | Programming Language :: Python 22 | Programming Language :: Python :: 3 23 | Programming Language :: Python :: 3.6 24 | Programming Language :: Python :: 3.7 25 | Programming Language :: Python :: 3.8 26 | Topic :: Multimedia :: Sound/Audio 27 | 28 | [options] 29 | packages = find: 30 | setup_requires = 31 | setuptools_scm 32 | install_requires = 33 | numpy 34 | audiofile 35 | librosa >=0.8.0 36 | resampy 37 | torch 38 | pandas 39 | tqdm 40 | tabulate 41 | tests_require = 42 | pytest 43 | 44 | [tool:pytest] 45 | addopts = 46 | --flake8 47 | --doctest-plus 48 | --cov=audtorch 49 | --cov-report term-missing 50 | --cov-report xml 51 | --cov-fail-under=60 52 | xfail_strict = true 53 | 54 | [flake8] 55 | ignore = 56 | W503 # math, https://github.com/PyCQA/pycodestyle/issues/513 57 | __init__.py F401 F403 # ignore unused and * imports 58 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(use_scm_version=True) 4 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-cov 3 | pytest-flake8 4 | pytest-doctestplus 5 | librosa 6 | scipy 7 | -------------------------------------------------------------------------------- /tests/test_collate.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from audtorch import collate 4 | 5 | 6 | # data format FS = (feature dim, sequence dim) 7 | batch_1 = [[torch.zeros(4, 5), torch.zeros(10)], 8 | [torch.zeros(4, 6), torch.zeros(12)]] 9 | sequence_dimensions_1 = [-1, -1] 10 | expected_1 = [[2, 4, 6], [2, 6, 4], [6, 2, 4]] 11 | 12 | # data 
format CFS = (channel dim, feature dim, sequence dim) 13 | batch_2 = [[torch.zeros(3, 4, 5), torch.zeros(10)], 14 | [torch.zeros(3, 4, 6), torch.zeros(12)]] 15 | sequence_dimensions_2 = [2, 0] 16 | expected_2 = [[2, 3, 4, 6], [2, 6, 3, 4], [6, 2, 3, 4]] 17 | 18 | 19 | @pytest.mark.parametrize("batch,sequence_dimensions,expected", 20 | [(batch_1, sequence_dimensions_1, expected_1), 21 | (batch_2, sequence_dimensions_2, expected_2)]) 22 | @pytest.mark.parametrize("batch_first", [None, True, False]) 23 | def test_seq2seq(batch, sequence_dimensions, expected, batch_first): 24 | 25 | collation = collate.Seq2Seq( 26 | sequence_dimensions=sequence_dimensions, 27 | batch_first=batch_first, 28 | pad_values=[-1, -1]) 29 | output = collation(batch) 30 | 31 | if batch_first is None: 32 | assert list(output[0].shape) == expected[0] 33 | 34 | if batch_first: 35 | assert list(output[0].shape) == expected[1] 36 | 37 | if batch_first is False: 38 | assert list(output[0].shape) == expected[2] 39 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | from torch.utils.data import TensorDataset 6 | 7 | from audtorch import (datasets, samplers, transforms) 8 | 9 | 10 | xfail = pytest.mark.xfail 11 | filterwarnings = pytest.mark.filterwarnings 12 | 13 | 14 | # --- datasets/noise.py --- 15 | @pytest.mark.parametrize('duration', [0.01, 0.1, 1]) 16 | @pytest.mark.parametrize('sampling_rate', [8000]) 17 | @pytest.mark.parametrize('mean', [0, 1]) 18 | @pytest.mark.parametrize('stdev', [1, 0.5]) 19 | def test_whitenoise(duration, sampling_rate, mean, stdev): 20 | dataset = datasets.WhiteNoise(duration=duration, 21 | sampling_rate=sampling_rate, 22 | mean=mean, 23 | stdev=stdev) 24 | noise, label = next(iter(dataset)) 25 | samples = int(np.ceil(duration * sampling_rate)) 26 | assert noise.shape == (1, samples) 27 | assert label == 'white noise' 28 | assert -1 <= np.max(np.abs(noise)) <= 1 29 | assert len(dataset) == 1 30 | 31 | 32 | # --- datasets/utils.py --- 33 | crop = transforms.RandomCrop(8192) 34 | resamp1 = transforms.Resample(48000, 44100) 35 | resamp2 = transforms.Resample(44100, 16000) 36 | t1 = transforms.Compose([crop, resamp1]) 37 | t2 = transforms.Compose([crop, resamp1, resamp2]) 38 | t3 = transforms.Compose([resamp1, crop, resamp2]) 39 | d0 = datasets.WhiteNoise(duration=0.5, sampling_rate=48000, transform=crop) 40 | d1 = datasets.WhiteNoise(duration=0.5, sampling_rate=48000, transform=t1) 41 | d2 = datasets.WhiteNoise(duration=0.5, sampling_rate=48000, transform=t2) 42 | d3 = datasets.WhiteNoise(duration=0.5, sampling_rate=48000, transform=t3) 43 | df_empty = pd.DataFrame() 44 | df_a = pd.DataFrame(data=[0], columns=['a']) 45 | df_ab = pd.DataFrame(data=[('0', 1)], columns=['a', 'b']) 46 | 47 | 48 | @pytest.mark.parametrize('list_of_datasets', [ 49 | (d2, d3), 50 | pytest.param([d0, d1], marks=xfail(raises=ValueError)) 51 | ]) 52 | def test_audioconcatdataset(list_of_datasets): 53 | datasets.AudioConcatDataset(list_of_datasets) 54 | 55 | 56 | @pytest.mark.parametrize(('input,expected_output,expected_sampling_rate,' 57 | 'expected_warning'), [ 58 | ('', np.array([[]]), None, 'File does not exist: '), 59 | ]) 60 | def test_load(input, expected_output, expected_sampling_rate, 61 | expected_warning): 62 | with pytest.warns(UserWarning, match=expected_warning): 63 | output, sampling_rate = 
datasets.load(input) 64 | assert np.array_equal(output, expected_output) 65 | assert sampling_rate == expected_sampling_rate 66 | 67 | 68 | @pytest.mark.parametrize('transform', [t1, t2, t3]) 69 | def test_sampling_rate_after_transform(transform): 70 | expected_sampling_rate = transform.transforms[-1].output_sampling_rate 71 | dataset = datasets.WhiteNoise(duration=0.5, 72 | sampling_rate=48000, 73 | transform=transform) 74 | sampling_rate = datasets.sampling_rate_after_transform(dataset) 75 | assert sampling_rate == expected_sampling_rate 76 | assert dataset.sampling_rate == expected_sampling_rate 77 | 78 | 79 | @pytest.mark.parametrize('list_of_datasets', [ 80 | [d0, d0, d0], 81 | [d1, d1], 82 | [d2, d2], 83 | pytest.param([1], marks=xfail(raises=RuntimeError)), 84 | pytest.param([d0, d1], marks=xfail(raises=ValueError)), 85 | pytest.param([d1, d2], marks=xfail(raises=ValueError)), 86 | ]) 87 | def test_ensure_same_sampling_rate(list_of_datasets): 88 | datasets.ensure_same_sampling_rate(list_of_datasets) 89 | 90 | 91 | @pytest.mark.parametrize('df,labels', [ 92 | pytest.param(df_empty, 'a', marks=xfail(raises=RuntimeError)), 93 | (df_a, 'a'), 94 | pytest.param(df_a, 'b', marks=xfail(raises=RuntimeError)), 95 | (df_ab, ['a', 'b']), 96 | (df_ab, 'a'), 97 | pytest.param(df_ab, 'c', marks=xfail(raises=RuntimeError)), 98 | ]) 99 | def test_ensure_df_columns_contain(df, labels): 100 | datasets.ensure_df_columns_contain(df, labels) 101 | 102 | 103 | @pytest.mark.parametrize('df', [ 104 | pytest.param(df_empty, marks=xfail(raises=RuntimeError)), 105 | df_a, 106 | ]) 107 | def test_ensure_df_not_empty(df): 108 | datasets.ensure_df_not_empty(df) 109 | 110 | 111 | @pytest.mark.parametrize('df,expected_files,expected_labels', [ 112 | (df_ab, ['0'], [1]), 113 | (None, [], []), 114 | pytest.param(df_empty, [], [], marks=xfail(raises=RuntimeError)), 115 | ]) 116 | def test_files_and_labels_from_df(df, expected_files, expected_labels): 117 | files, labels = datasets.files_and_labels_from_df( 118 | df, 119 | column_filename='a', 120 | column_labels='b', 121 | ) 122 | assert files == expected_files 123 | assert labels == expected_labels 124 | _, labels = datasets.files_and_labels_from_df( 125 | df, 126 | column_filename='a', 127 | column_labels=None, 128 | ) 129 | if df is None: 130 | assert labels == [] 131 | else: 132 | assert labels == [''] * len(df) 133 | 134 | 135 | @pytest.mark.parametrize('key_values,split_func,kwargs,expected', [ 136 | (list(range(5)), samplers.buckets_of_even_size, 2, [3, 2]), 137 | (list(range(10)), samplers.buckets_by_boundaries, [0, 3, 8], [3, 5, 2]) 138 | ]) 139 | def test_defined_split(key_values, split_func, kwargs, expected): 140 | data = TensorDataset(torch.arange(len(key_values))) 141 | split_func = split_func(torch.Tensor(key_values), kwargs) 142 | subsets = datasets.defined_split(data, split_func) 143 | 144 | assert expected == [len(subset) for subset in subsets] 145 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import scipy 3 | import torch 4 | 5 | import numpy as np 6 | 7 | import audtorch.metrics as metrics 8 | import audtorch.metrics.functional as F 9 | 10 | 11 | @pytest.mark.parametrize('reduction', ['none', 'sum', 'mean']) 12 | def test_energypreservingloss(reduction): 13 | loss = metrics.EnergyConservingLoss(reduction=reduction) 14 | # Random integers as tensors to avoid precision problems with 
torch.equal 15 | input = torch.rand((3, 5), requires_grad=True) 16 | target = torch.rand((3, 5)) 17 | mixture = torch.rand((3, 5)) 18 | noise = mixture - target 19 | noise_predicted = mixture - input 20 | expected_output = (torch.abs(input - target) 21 | + torch.abs(noise - noise_predicted)) 22 | output = loss(input, target, mixture) 23 | if reduction == 'none': 24 | assert torch.equal(output, expected_output) 25 | elif reduction == 'sum': 26 | assert torch.equal(output, torch.sum(expected_output)) 27 | elif reduction == 'mean': 28 | assert torch.equal(output, torch.mean(expected_output)) 29 | 30 | 31 | @pytest.mark.parametrize('shape', [(5,), (5, 3)]) 32 | def test_pearsonr(shape): 33 | input = torch.rand(shape) 34 | target = torch.rand(shape) 35 | 36 | if len(shape) == 1: 37 | r = F.pearsonr(input, target) 38 | assert r.shape[0] == 1 39 | np.testing.assert_almost_equal( 40 | r.numpy()[0], scipy.stats.pearsonr( 41 | input.numpy(), target.numpy())[0], decimal=6) 42 | else: 43 | r = F.pearsonr(input, target) 44 | assert r.shape[0] == shape[0] 45 | for index, (input_row, target_row) in enumerate(zip(input, target)): 46 | np.testing.assert_almost_equal( 47 | r[index].numpy()[0], scipy.stats.pearsonr( 48 | input_row.numpy(), target_row.numpy())[0], decimal=6) 49 | 50 | r = F.pearsonr(input, target, batch_first=False) 51 | assert r.shape[1] == shape[1] 52 | for index, (input_col, target_col) in enumerate( 53 | zip(input.transpose(0, 1), target.transpose(0, 1))): 54 | np.testing.assert_almost_equal( 55 | r[:, index].numpy()[0], scipy.stats.pearsonr( 56 | input_col.numpy(), target_col.numpy())[0], decimal=6) 57 | 58 | 59 | @pytest.mark.parametrize('shape', [(5,), (5, 3)]) 60 | def test_concordance_cc(shape): 61 | input = torch.rand(shape) 62 | target = torch.rand(shape) 63 | 64 | def concordance_cc(x, y): 65 | r = scipy.stats.pearsonr(x, y)[0] 66 | ccc = 2 * r * x.std() * y.std() / (x.std() * x.std() 67 | + y.std() * y.std() 68 | + (x.mean() - y.mean()) 69 | * (x.mean() - y.mean())) 70 | return ccc 71 | 72 | if len(shape) == 1: 73 | ccc = F.concordance_cc(input, target) 74 | assert ccc.shape[0] == 1 75 | np.testing.assert_almost_equal( 76 | ccc.numpy()[0], concordance_cc(input.numpy(), target.numpy()), 77 | decimal=6) 78 | else: 79 | ccc = F.concordance_cc(input, target) 80 | assert ccc.shape[0] == shape[0] 81 | for index, (input_row, target_row) in enumerate(zip(input, target)): 82 | np.testing.assert_almost_equal( 83 | ccc[index].numpy()[0], concordance_cc( 84 | input_row.numpy(), target_row.numpy()), decimal=6) 85 | 86 | ccc = F.concordance_cc(input, target, batch_first=False) 87 | assert ccc.shape[1] == shape[1] 88 | for index, (input_col, target_col) in enumerate( 89 | zip(input.transpose(0, 1), target.transpose(0, 1))): 90 | np.testing.assert_almost_equal( 91 | ccc[:, index].numpy()[0], concordance_cc( 92 | input_col.numpy(), target_col.numpy()), decimal=6) 93 | -------------------------------------------------------------------------------- /tests/test_samplers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | from numpy.random import randint 5 | import random 6 | from bisect import bisect_right 7 | from torch.utils.data import (TensorDataset, ConcatDataset) 8 | 9 | from audtorch.samplers import ( 10 | buckets_of_even_size, buckets_by_boundaries, BucketSampler) 11 | from audtorch.datasets.utils import defined_split 12 | 13 | 14 | # data sets 15 | data_size = 1000 16 | num_feats = 160 17 | 
max_length = 300 18 | max_feature = 100 19 | lengths = torch.randint(0, max_length, (data_size,)) 20 | inputs = torch.randint(0, max_feature, (data_size, num_feats, max_length)) 21 | data = TensorDataset(inputs) 22 | 23 | # function params 24 | num_buckets = randint(1, 10) 25 | batch_sizes = num_buckets * [randint(1, np.iinfo(np.int8).max)] 26 | num_batches = num_buckets * [randint(0, np.iinfo(np.int8).max)] 27 | bucket_boundaries = [b + num for num, b in enumerate( 28 | sorted(list(random.sample(range(max_length - num_buckets), 29 | (num_buckets - 1)))))] 30 | 31 | 32 | @pytest.mark.parametrize("key_values", [lengths]) 33 | @pytest.mark.parametrize("num_buckets", [num_buckets]) 34 | @pytest.mark.parametrize("reverse", [True, False]) 35 | def test_buckets_of_even_size(key_values, num_buckets, reverse): 36 | 37 | expected_bucket_size, remainders = divmod(data_size, num_buckets) 38 | expected_bucket_dist = remainders * [expected_bucket_size + 1] 39 | expected_bucket_dist += (num_buckets - remainders) * [expected_bucket_size] 40 | 41 | buckets = buckets_of_even_size(key_values=key_values, 42 | num_buckets=num_buckets, 43 | reverse=reverse) 44 | 45 | expected_bucket_ids = set(range(num_buckets)) 46 | bucket_ids = [buckets(i) for i in range(data_size)] 47 | key_values = [key_values[i] for i in range(data_size)] 48 | bucket_dist = [bucket_ids.count(bucket_id) 49 | for bucket_id in range(num_buckets)] 50 | 51 | # do bucket ids only range from 0 to num_buckets-1? 52 | assert expected_bucket_ids == set(bucket_ids) 53 | 54 | sort_indices = sorted(range(data_size), key=lambda idx: key_values[idx], 55 | reverse=reverse) 56 | sorted_buckets = [bucket_ids[idx] for idx in sort_indices] 57 | diff_buckets = np.diff(sorted_buckets) 58 | 59 | # sorted with monotonically increasing/decreasing length of key values? 60 | assert all(diff >= 0 for diff in diff_buckets) 61 | 62 | # are buckets evenly distributed (except for remainders)? 63 | assert expected_bucket_dist == bucket_dist 64 | 65 | 66 | @pytest.mark.parametrize("key_values", [lengths]) 67 | @pytest.mark.parametrize("bucket_boundaries", [bucket_boundaries]) 68 | def test_buckets_by_boundaries(key_values, bucket_boundaries): 69 | 70 | buckets = buckets_by_boundaries(key_values=lengths, 71 | bucket_boundaries=bucket_boundaries) 72 | 73 | num_buckets = len(bucket_boundaries) + 1 74 | expected_bucket_ids = list(range(num_buckets)) 75 | 76 | data_size = key_values.shape[0] 77 | bucket_ids = [buckets(idx) for idx in range(data_size)] 78 | key_values = [key_values[idx] for idx in range(data_size)] 79 | 80 | # do bucket ids only range from 0 to num_buckets-1? 81 | # missing ids only allowed if corresponding bucket empty 82 | missing_buckets = list(set(expected_bucket_ids) - set(bucket_ids)) 83 | for bucket_id in missing_buckets: 84 | if bucket_id == 0: 85 | assert not any([v < bucket_boundaries[bucket_id] 86 | for v in key_values]) 87 | elif bucket_id == expected_bucket_ids[-1]: 88 | assert not any([v > bucket_boundaries[(bucket_id - 1)] 89 | for v in key_values]) 90 | else: 91 | assert not any([v in range(bucket_boundaries[(bucket_id - 1)], 92 | bucket_boundaries[bucket_id]) 93 | for v in key_values]) 94 | 95 | sort_indices = sorted(range(data_size), key=lambda i: key_values[i]) 96 | sorted_buckets = [bucket_ids[i] for i in sort_indices] 97 | diff_buckets = np.diff(sorted_buckets) 98 | 99 | # sorted with monotonically increasing/decreasing length of key values? 
100 | assert all(diff >= 0 for diff in diff_buckets) 101 | 102 | 103 | @pytest.mark.parametrize("data", [data]) 104 | @pytest.mark.parametrize( 105 | "key_func", 106 | [buckets_of_even_size(lengths, num_buckets, reverse=False), 107 | buckets_by_boundaries(lengths, bucket_boundaries)]) 108 | @pytest.mark.parametrize("expected_num_datasets", [num_buckets]) 109 | @pytest.mark.parametrize("batch_sizes", [batch_sizes]) 110 | @pytest.mark.parametrize("expected_num_batches", [num_batches]) 111 | @pytest.mark.parametrize("permuted_order", [False, random.shuffle( 112 | list(range(num_buckets)))]) 113 | @pytest.mark.parametrize("shuffle_each_bucket", [True, False]) 114 | @pytest.mark.parametrize("drop_last", [True, False]) 115 | def test_bucket_sampler(data, key_func, expected_num_datasets, 116 | batch_sizes, expected_num_batches, permuted_order, 117 | shuffle_each_bucket, drop_last): 118 | 119 | subsets = defined_split(data, key_func) 120 | concat_dataset = ConcatDataset(subsets) 121 | 122 | batch_sampler = BucketSampler( 123 | concat_dataset=concat_dataset, 124 | batch_sizes=batch_sizes, 125 | num_batches=expected_num_batches, 126 | permuted_order=permuted_order, 127 | shuffle_each_bucket=shuffle_each_bucket, 128 | drop_last=drop_last) 129 | 130 | num_datasets = len(batch_sampler.datasets) 131 | expected_dataset_ids = list(range(num_datasets)) 132 | if isinstance(permuted_order, list) and permuted_order: 133 | expected_dataset_ids = permuted_order 134 | 135 | # assert data sets via batch sampler 136 | batch_indices = list(iter(batch_sampler)) 137 | epoch_batch_sizes = [len(batch) for batch in batch_indices] 138 | dataset_sizes = [len(d) for d in batch_sampler.datasets] 139 | 140 | expected_epoch_batch_sizes = [] 141 | expected_dset_ids = [] 142 | 143 | for i in expected_dataset_ids: 144 | 145 | skip = False 146 | 147 | if batch_sizes[i] == 0 or expected_num_batches[i] == 0: 148 | continue 149 | 150 | fitted_batches = dataset_sizes[i] // batch_sizes[i] 151 | num_batches = fitted_batches 152 | if expected_num_batches[i] <= fitted_batches: 153 | num_batches = expected_num_batches[i] 154 | skip = True 155 | 156 | add_batch_sizes = [batch_sizes[i]] * num_batches 157 | 158 | if not drop_last and not skip: 159 | remainder = dataset_sizes[i] % batch_sizes[i] 160 | if remainder != 0: 161 | add_batch_sizes += [remainder] 162 | 163 | expected_epoch_batch_sizes += add_batch_sizes 164 | 165 | if len(add_batch_sizes) > 0: 166 | expected_dset_ids += [i] 167 | 168 | # are batch sizes as expected? 169 | assert expected_epoch_batch_sizes == epoch_batch_sizes 170 | 171 | unique_batch_ids = [] 172 | for batch in batch_indices: 173 | ids = list(set([bisect_right( 174 | concat_dataset.cumulative_sizes, idx) for idx in batch])) 175 | unique_batch_ids += [ids] 176 | 177 | # are all samples of each batch from identical data set? 178 | assert all(list(map(lambda l: len(l) == 1, unique_batch_ids))) 179 | 180 | # flatten list 181 | unique_batch_ids = [i for ids in unique_batch_ids for i in ids] 182 | 183 | dataset_ids = [] 184 | prev_id = None 185 | 186 | for i in unique_batch_ids: 187 | current_id = i 188 | if prev_id is None: 189 | dataset_ids += [current_id] 190 | else: 191 | if not current_id == prev_id: 192 | dataset_ids += [current_id] 193 | prev_id = current_id 194 | 195 | # are batches drawn from data sets in desired order? 
196 | assert expected_dset_ids == dataset_ids 197 | -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | import scipy 5 | import resampy 6 | import librosa 7 | 8 | import audtorch.transforms as transforms 9 | import audtorch.transforms.functional as F 10 | 11 | xfail = pytest.mark.xfail 12 | 13 | a11 = np.array([1, 2, 3, 4], dtype=float) 14 | a12 = np.array([5, 6, 7, 8], dtype=float) 15 | a21 = np.array([9, 10, 11, 12], dtype=float) 16 | a22 = np.array([13, 14, 15, 16], dtype=float) 17 | ones = np.ones(4) 18 | zeros = np.zeros(4) 19 | A = np.array([[a11, a12], [a21, a22]]) # Tensor of shape (2, 2, 4) 20 | 21 | # Ratio in dB to add two signals to yield 1.5 magnitude 22 | _half_ratio = -10 * np.log10(0.5 ** 2) 23 | 24 | 25 | def _mean(signal, axis): 26 | """Return mean along axis and preserve number of dimensions.""" 27 | return np.expand_dims(np.mean(signal, axis=axis), axis=axis) 28 | 29 | 30 | def _pad(vector, padding, value): 31 | """Add padding to a vector using np.pad.""" 32 | return np.pad(vector, padding, 'constant', constant_values=value) 33 | 34 | 35 | @pytest.mark.parametrize('input,idx,axis,expected_output', [ 36 | (A, (0, 2), -1, np.array([[a11[:2], a12[:2]], [a21[:2], a22[:2]]])), 37 | (A, (1, 2), -2, np.array([[a12], [a22]])), 38 | (A, (0, 1), 0, np.array([[a11, a12]])), 39 | (A, -1, -2, np.array([[a12], [a22]])), 40 | (A, 0, -2, np.array([[a11], [a21]])), 41 | (a11, (1, 2), -1, a11[1:2]), 42 | ]) 43 | def test_crop(input, idx, axis, expected_output): 44 | t = transforms.Crop(idx, axis=axis) 45 | assert np.array_equal(t(input), expected_output) 46 | 47 | 48 | @pytest.mark.parametrize('input,padding,value,axis,expected_output', [ 49 | (A, 1, 0, -1, np.array([[_pad(a11, 1, 0), _pad(a12, 1, 0)], 50 | [_pad(a21, 1, 0), _pad(a22, 1, 0)]])), 51 | (A, (0, 1), 1, 1, np.array([[a11, a12, ones], [a21, a22, ones]])), 52 | (A, (1, 0), 0, 0, np.array([[zeros, zeros], [a11, a12], [a21, a22]])), 53 | (a11, 1, 1, -1, _pad(a11, 1, 1)), 54 | ]) 55 | def test_pad(input, padding, value, axis, expected_output): 56 | t = transforms.Pad(padding, value=value, axis=axis) 57 | assert np.array_equal(t(input), expected_output) 58 | 59 | 60 | @pytest.mark.parametrize('input,repetitions,axis', [ 61 | (A, 2, -1), 62 | (A, 3, 0), 63 | (A, 4, 1), 64 | (a11, 1, 0)]) 65 | def test_replicate(input, repetitions, axis): 66 | expected_output = np.concatenate(tuple([input] * repetitions), axis) 67 | t = transforms.Replicate(repetitions, axis=axis) 68 | assert np.array_equal(t(input), expected_output) 69 | 70 | 71 | @pytest.mark.parametrize('input,size,axis,method', [ 72 | (A, 12, 0, 'pad'), 73 | (A, 12, 1, 'pad'), 74 | (A, 12, 2, 'pad'), 75 | (A, 12, 0, 'replicate'), 76 | (A, 12, 1, 'replicate'), 77 | (A, 12, 2, 'replicate')]) 78 | def test_expand(input, size, axis, method): 79 | t = transforms.Expand(size=size, axis=axis, method=method) 80 | if method == 'pad': 81 | assert np.array_equal(t(input), F.pad( 82 | input, (0, size - input.shape[axis]), axis=axis)) 83 | else: 84 | assert np.array_equal(t(input), F.crop(F.replicate( 85 | input, 86 | repetitions=size // input.shape[axis] + 1, 87 | axis=axis), (0, size), axis=axis)) 88 | 89 | 90 | @pytest.mark.parametrize('input,coverage,max_width,value,axis,' 91 | 'expected_masked_items', [ 92 | (A, 0., 3, 0, -1, 0), 93 | (A, 0.1, 1, 0, -1, 4), 94 | (A, 0.25, 1, 0, -1, 4), 95 | 
(A, 0.3, 1, 0, -1, 4), 96 | (A, 0.1, 1, 0, -2, 8) 97 | ]) 98 | def test_random_mask(input, coverage, max_width, value, axis, 99 | expected_masked_items): 100 | input = torch.from_numpy(input).clone() 101 | t = transforms.RandomMask(coverage, max_width, value, axis) 102 | masked_items = len((t(input) == value).nonzero()) 103 | assert masked_items == expected_masked_items 104 | 105 | 106 | @pytest.mark.parametrize('input,channels,method,axis,expected_output', [ 107 | (A, 2, 'mean', -2, A), 108 | (A, 1, 'mean', -2, _mean(A, axis=-2)), 109 | (A, 1, 'crop', -2, np.array([[a11], [a21]])), 110 | (A, 0, 'mean', -2, np.empty((2, 0, 4))), # empty array with correct shape 111 | (A, 0, 'crop', -2, np.empty((2, 0, 4))), 112 | (A, 1, 'mean', 0, _mean(A, axis=0)), 113 | (a11, 1, 'crop', -1, np.array([a11[0]])), 114 | (a11, 1, 'crop', -2, np.array(a11)), 115 | ]) 116 | def test_downmix(input, channels, method, axis, expected_output): 117 | t = transforms.Downmix(channels, method=method, axis=axis) 118 | assert np.array_equal(t(input), expected_output) 119 | 120 | 121 | @pytest.mark.parametrize('input,channels,method,axis,expected_output', [ 122 | (A, 2, 'mean', -2, A), 123 | (A, 3, 'mean', -2, np.hstack((A, _mean(A, axis=-2)))), 124 | (A, 3, 'zero', -2, np.hstack((A, [[zeros], [zeros]]))), 125 | (A, 3, 'repeat', -2, np.hstack((A, A[:, -1, None, :]))), 126 | (A, 3, 'mean', 0, np.vstack((A, _mean(A, axis=0)))), 127 | (a11, 2, 'repeat', -2, np.array([a11, a11])), 128 | ]) 129 | def test_upmix(input, channels, method, axis, expected_output): 130 | t = transforms.Upmix(channels, method=method, axis=axis) 131 | assert np.array_equal(t(input), expected_output) 132 | 133 | 134 | @pytest.mark.parametrize('input,channels,axis,expected_output', [ 135 | (A, 2, -2, A), 136 | (A, 3, -2, np.hstack((A, _mean(A, axis=-2)))), 137 | (A, 3, 0, np.vstack((A, _mean(A, axis=0)))), 138 | (A, 1, -2, _mean(A, axis=-2)), 139 | (A, 0, -2, np.empty((2, 0, 4))), # empty array with correct shape 140 | (A, 1, 0, _mean(A, axis=0)), 141 | ]) 142 | def test_remix(input, channels, axis, expected_output): 143 | t = transforms.Remix(channels, axis=axis) 144 | assert np.array_equal(t(input), expected_output) 145 | 146 | 147 | @pytest.mark.parametrize('input,axis,expected_output', [ 148 | (A, None, A / np.max(A)), 149 | (A, -1, np.array([[a11 / max(a11), a12 / max(a12)], 150 | [a21 / max(a21), a22 / max(a22)]])), 151 | (a11, None, a11 / np.max(a11)), 152 | (a11, -1, a11 / np.max(a11)), 153 | ]) 154 | def test_normalize(input, axis, expected_output): 155 | t = transforms.Normalize(axis=axis) 156 | assert np.array_equal(t(input), expected_output) 157 | 158 | 159 | @pytest.mark.parametrize('input,axis,mean,std', [ 160 | (A, None, True, True), 161 | (A, -1, True, True), 162 | (a11, None, True, True), 163 | (a11, -1, True, True), 164 | (A, -1, False, True), 165 | (A, -1, True, False), 166 | (A, -1, False, False), 167 | ]) 168 | def test_standardize(input, axis, mean, std): 169 | t = transforms.Standardize(axis=axis) 170 | output = t(input) 171 | if mean: 172 | np.testing.assert_almost_equal(output.mean(axis=axis).mean(), 0) 173 | if std: 174 | np.testing.assert_almost_equal(output.std(axis=axis).mean(), 1) 175 | 176 | 177 | @pytest.mark.parametrize('input,idx,axis', [ 178 | (A, (0, 2), -1), 179 | (a11, (1, 2), -1), 180 | ]) 181 | def test_compose(input, idx, axis): 182 | t = transforms.Compose([transforms.Crop(idx, axis=axis), 183 | transforms.Normalize(axis=axis)]) 184 | expected_output = F.crop(input, idx, axis=axis) 185 | expected_output = 
F.normalize(expected_output, axis=axis) 186 | assert np.array_equal(t(input), expected_output) 187 | 188 | 189 | @pytest.mark.parametrize('input,size,axis', [ 190 | (A, 2, -1), 191 | (A, 1, -2), 192 | (A, 1, 0), 193 | (A, 0, -2), 194 | (a11, 3, -1), 195 | ]) 196 | def test_randomcrop(input, size, axis): 197 | t = transforms.RandomCrop(size, axis=axis) 198 | t.fix_randomization = True 199 | assert np.array_equal(t(input), t(input)) 200 | assert np.array_equal(t(input), F.crop(input, t.idx, axis=t.axis)) 201 | 202 | 203 | @pytest.mark.parametrize('input,padding,value,axis', [ 204 | (A, 1, 0, -1), 205 | (A, 2, 1, 1), 206 | (A, 0, 0, 0), 207 | (a11, 1, 1, -1), 208 | ]) 209 | def test_randompad(input, padding, value, axis): 210 | t = transforms.RandomPad(padding, value=value, axis=axis) 211 | t.fix_randomization = True 212 | assert np.array_equal(t(input), t(input)) 213 | expected_output = F.pad(input, t.pad, value=t.value, axis=t.axis) 214 | assert np.array_equal(t(input), expected_output) 215 | 216 | 217 | @pytest.mark.parametrize('input,input_sample_rate,output_sample_rate,axis', [ 218 | (A, 4, 2, -1), 219 | (np.ones([4, 4, 2]), 4, 2, -2), 220 | (a11, 3, 2, -1), 221 | ]) 222 | @pytest.mark.parametrize('method', ['kaiser_best', 'kaiser_fast', 'scipy']) 223 | def test_resample(input, input_sample_rate, output_sample_rate, method, axis): 224 | t = transforms.Resample(input_sample_rate, output_sample_rate, 225 | method=method, axis=axis) 226 | output_length = int(input.shape[axis] * output_sample_rate 227 | / float(input_sample_rate)) 228 | print(input.shape) 229 | if method == 'scipy': 230 | expected_output = scipy.signal.resample(input, output_length, 231 | axis=axis) 232 | else: 233 | expected_output = resampy.resample(input, input_sample_rate, 234 | output_sample_rate, method=method, 235 | axis=axis) 236 | transformed_input = t(input) 237 | assert transformed_input.shape[axis] == output_length 238 | assert np.array_equal(transformed_input, expected_output) 239 | 240 | 241 | @pytest.mark.parametrize('input,window_size,hop_size,axis', [ 242 | (A, 4, 1, 2), 243 | pytest.param(A, 2048, 1024, 2, marks=xfail(raises=ValueError)), 244 | (np.random.normal(size=[2, 3, 16000]), 2048, 1024, 2), 245 | (np.random.normal(size=[2, 16000, 3]), 2048, 1024, 1), 246 | (np.random.normal(size=[16000, 2, 3]), 2048, 1024, 0), 247 | (np.random.normal(size=[16000, 2, 3]), 2048, 1024, 0), 248 | (np.random.normal(size=16000), 2048, 1024, 0), 249 | ]) 250 | def test_stft(input, window_size, hop_size, axis): 251 | t = transforms.Spectrogram(window_size, hop_size, axis=axis) 252 | spectrogram = F.stft(input, window_size, hop_size, axis=axis) 253 | magnitude, phase = librosa.magphase(spectrogram) 254 | assert np.array_equal(t(input), magnitude) 255 | assert np.array_equal(t.phase, phase) 256 | 257 | 258 | @pytest.mark.parametrize('input,magnitude_boost', [ 259 | (A, 1e-07), 260 | pytest.param(A, -1e-07, marks=xfail(raises=ValueError)), 261 | pytest.param(A, -1, marks=xfail(raises=ValueError)), 262 | (np.random.normal(size=[2, 3, 16000]), 1e-03), 263 | (np.random.normal(size=[2, 16000, 3]), 1e-07), 264 | (np.random.normal(size=[16000, 2, 3]), 1e-07), 265 | (np.random.normal(size=16000), 1e-07), 266 | ]) 267 | def test_log(input, magnitude_boost): 268 | input = input + abs(input.min()) 269 | t_log = transforms.Log(magnitude_boost=magnitude_boost) 270 | assert np.array_equal(t_log(input), np.log(input + magnitude_boost)) 271 | 272 | 273 | @pytest.mark.parametrize('signal1,signal2,ratio,' 274 | 
'percentage_silence,expected_signal', [ 275 | (A, A, 0, 0, 2 * A), 276 | (A, A, _half_ratio, 0, 1.5 * A), 277 | (a11, a11, 0, 0, 2 * a11), 278 | (a11, a11, _half_ratio, 0, 1.5 * a11), 279 | (A, np.zeros_like(A), 0, 0, A), 280 | (np.zeros_like(A), A, 0, 0, A), 281 | (A, np.zeros_like(A), 0, 1, A), 282 | ]) 283 | def test_randomadditivemix(signal1, signal2, ratio, percentage_silence, 284 | expected_signal): 285 | mixed_signal = F.additive_mix(signal1, signal2, ratio) 286 | assert np.array_equal(mixed_signal, expected_signal) 287 | augmentation_data_set = [[signal2]] 288 | t = transforms.RandomAdditiveMix(dataset=augmentation_data_set, 289 | ratios=[ratio], 290 | percentage_silence=percentage_silence) 291 | assert np.array_equal(t(signal1), expected_signal) 292 | t = transforms.RandomAdditiveMix(dataset=augmentation_data_set) 293 | transform = t(signal1) 294 | functional = F.additive_mix(signal1, signal2, t.ratio) 295 | assert np.array_equal(transform, functional) 296 | 297 | 298 | @pytest.mark.parametrize('signal,impulse_response,axis,expected_signal', [ 299 | (A, [1], -1, A), 300 | (A, [0], -1, 0 * A), 301 | (a11, a11, -1, np.convolve(a11, a11)), 302 | (A, a11, -2, np.apply_along_axis(np.convolve, -2, A, a11)) 303 | ]) 304 | def test_random_convolutional_mix(signal, impulse_response, 305 | axis, expected_signal): 306 | augmentation_dataset = [[impulse_response]] 307 | t = transforms.RandomConvolutionalMix(augmentation_dataset, axis=axis) 308 | convolved_signal = t(signal) 309 | assert np.array_equal(convolved_signal, expected_signal) 310 | expected_length = signal.shape[axis] + len(np.array(impulse_response)) - 1 311 | assert convolved_signal.shape[axis] == expected_length 312 | -------------------------------------------------------------------------------- /tests/test_transforms_functional.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | import librosa 5 | 6 | import audtorch.transforms.functional as F 7 | 8 | xfail = pytest.mark.xfail 9 | 10 | a11 = np.array([1, 2, 3, 4], dtype=float) 11 | a12 = np.array([5, 6, 7, 8], dtype=float) 12 | a21 = np.array([9, 10, 11, 12], dtype=float) 13 | a22 = np.array([13, 14, 15, 16], dtype=float) 14 | ones = np.ones(4) 15 | zeros = np.zeros(4) 16 | A = np.array([[a11, a12], [a21, a22]]) # Tensor of shape (2, 2, 4) 17 | 18 | # Ratio in dB to add two inputs to yield 1.5 magnitude 19 | _half_ratio = -10 * np.log10(0.5 ** 2) 20 | 21 | 22 | def _mean(input, axis): 23 | """Return mean along axis and preserve number of dimensions.""" 24 | return np.expand_dims(np.mean(input, axis=axis), axis=axis) 25 | 26 | 27 | def _pad(vector, padding, value): 28 | """Add padding to a vector using np.pad.""" 29 | return np.pad(vector, padding, 'constant', constant_values=value) 30 | 31 | 32 | @pytest.mark.parametrize('input,idx,axis,expected_output', [ 33 | (A, (0, 2), -1, np.array([[a11[:2], a12[:2]], [a21[:2], a22[:2]]])), 34 | (A, (1, 2), -2, np.array([[a12], [a22]])), 35 | (A, (0, 1), 0, np.array([[a11, a12]])), 36 | (A, -1, -2, np.array([[a12], [a22]])), 37 | (A, 0, -2, np.array([[a11], [a21]])), 38 | (a11, (1, 2), -1, a11[1:2]), 39 | ]) 40 | def test_crop(input, idx, axis, expected_output): 41 | output = F.crop(input, idx, axis=axis) 42 | assert np.array_equal(output, expected_output) 43 | 44 | 45 | @pytest.mark.parametrize('input,padding,value,axis,expected_output', [ 46 | (A, 1, 0, -1, np.array([[_pad(a11, 1, 0), _pad(a12, 1, 0)], 47 | [_pad(a21, 1, 0), _pad(a22, 1, 
0)]])), 48 | (A, (0, 1), 1, 1, np.array([[a11, a12, ones], [a21, a22, ones]])), 49 | (A, (1, 0), 0, 0, np.array([[zeros, zeros], [a11, a12], [a21, a22]])), 50 | (a11, 1, 1, -1, _pad(a11, 1, 1)), 51 | ]) 52 | def test_pad(input, padding, value, axis, expected_output): 53 | output = F.pad(input, padding, value=value, axis=axis) 54 | assert np.array_equal(output, expected_output) 55 | 56 | 57 | @pytest.mark.parametrize('input,repetitions,axis', [ 58 | (A, 2, -1), 59 | (A, 3, 0), 60 | (A, 4, 1), 61 | (a11, 1, 0)]) 62 | def test_replicate(input, repetitions, axis): 63 | expected_output = np.concatenate(tuple([input] * repetitions), axis) 64 | output = F.replicate(input, repetitions, axis=axis) 65 | assert np.array_equal(output, expected_output) 66 | 67 | 68 | @pytest.mark.parametrize('input,num_blocks,max_width,value,axis,' 69 | 'min_expected_items', [ 70 | (A, 0, 1, 0, -1, 0), 71 | (A, 1, 1, 0, -1, 4), 72 | (A, 2, 1, 0, -1, 4), 73 | (A, 1, 1, -1, -1, 4), 74 | (A, 1, 1, 0, -2, 8) 75 | ]) 76 | def test_mask(input, num_blocks, max_width, value, axis, 77 | min_expected_items): 78 | input = torch.from_numpy(input).clone() 79 | masked_signal = F.mask( 80 | input, num_blocks, max_width, value=value, axis=axis) 81 | masked_items = len((masked_signal == value).nonzero()) 82 | if num_blocks <= 1 and max_width == 1: 83 | assert masked_items == min_expected_items 84 | else: 85 | # widths can vary, blocks can overlap 86 | assert masked_items % min_expected_items == 0 87 | 88 | 89 | @pytest.mark.parametrize('input,channels,method,axis,expected_output', [ 90 | (A, 2, 'mean', -2, A), 91 | (A, 1, 'mean', -2, _mean(A, axis=-2)), 92 | (A, 1, 'crop', -2, np.array([[a11], [a21]])), 93 | (A, 0, 'mean', -2, np.empty((2, 0, 4))), # empty array with correct shape 94 | (A, 0, 'crop', -2, np.empty((2, 0, 4))), 95 | (A, 1, 'mean', 0, _mean(A, axis=0)), 96 | (a11, 1, 'crop', -1, np.array([a11[0]])), 97 | (a11, 1, 'crop', -2, np.array(a11)), 98 | ]) 99 | def test_downmix(input, channels, method, axis, expected_output): 100 | output = F.downmix(input, channels, method=method, axis=axis) 101 | assert np.array_equal(output, expected_output) 102 | 103 | 104 | @pytest.mark.parametrize('input,channels,method,axis,expected_output', [ 105 | (A, 2, 'mean', -2, A), 106 | (A, 3, 'mean', -2, np.hstack((A, _mean(A, axis=-2)))), 107 | (A, 3, 'zero', -2, np.hstack((A, [[zeros], [zeros]]))), 108 | (A, 3, 'repeat', -2, np.hstack((A, A[:, -1, None, :]))), 109 | (A, 3, 'mean', 0, np.vstack((A, _mean(A, axis=0)))), 110 | (a11, 2, 'repeat', -2, np.array([a11, a11])), 111 | ]) 112 | def test_upmix(input, channels, method, axis, expected_output): 113 | output = F.upmix(input, channels, method=method, axis=axis) 114 | assert np.array_equal(output, expected_output) 115 | 116 | 117 | @pytest.mark.parametrize(('input1,input2,ratio,percentage_silence,' 118 | 'expected_output'), [ 119 | (A, A, 0, 0, 2 * A), 120 | (A, A, _half_ratio, 0, 1.5 * A), 121 | (a11, a11, 0, 0, 2 * a11), 122 | (a11, a11, _half_ratio, 0, 1.5 * a11), 123 | (A, np.zeros_like(A), 0, 0, A), 124 | (np.zeros_like(A), A, 0, 0, A), 125 | (A, np.zeros_like(A), 0, 1, A), 126 | ]) 127 | def test_additivemix(input1, input2, ratio, percentage_silence, 128 | expected_output): 129 | output = F.additive_mix(input1, input2, ratio) 130 | assert np.array_equal(output, expected_output) 131 | 132 | 133 | @pytest.mark.parametrize('input,axis,expected_output', [ 134 | (A, None, A / np.max(A)), 135 | (A, -1, np.array([[a11 / max(a11), a12 / max(a12)], 136 | [a21 / max(a21), a22 / max(a22)]])), 137 | 
(a11, None, a11 / np.max(a11)), 138 | (a11, -1, a11 / np.max(a11)), 139 | ([a11, a12], 1, np.array([a11 / max(a11), a12 / max(a12)])), 140 | ([[1, 4], [3, 2]], 0, np.array([[1 / 3, 4 / 4], [3 / 3, 2 / 4]])), 141 | ]) 142 | def test_normalize(input, axis, expected_output): 143 | output = F.normalize(input, axis=axis) 144 | assert np.array_equal(output, expected_output) 145 | 146 | 147 | @pytest.mark.parametrize('input,axis,mean,std', [ 148 | (A, None, True, True), 149 | (A, -1, True, True), 150 | (a11, None, True, True), 151 | (a11, -1, True, True), 152 | (A, -1, False, True), 153 | (A, -1, True, False), 154 | (A, -1, False, False), 155 | ]) 156 | def test_standardize(input, axis, mean, std): 157 | output = F.standardize(input, axis=axis, mean=mean, std=std) 158 | if mean: 159 | np.testing.assert_almost_equal(output.mean(axis=axis).mean(), 0) 160 | if std: 161 | np.testing.assert_almost_equal(output.std(axis=axis).mean(), 1) 162 | 163 | 164 | @pytest.mark.parametrize('input,window_size,hop_size,axis', [ 165 | (A, 4, 1, 2), 166 | pytest.param(A, 2048, 1024, 2, marks=xfail(raises=ValueError)), 167 | pytest.param(A, 3, 1, 2, marks=xfail), 168 | (np.random.normal(size=[2, 3, 16000]), 2048, 1024, 2), 169 | (np.random.normal(size=[2, 16000, 3]), 2048, 1024, 1), 170 | (np.random.normal(size=[16000, 2, 3]), 2048, 1024, 0), 171 | (np.random.normal(size=16000), 2048, 1024, 0), 172 | ]) 173 | def test_stft(input, window_size, hop_size, axis): 174 | expected_output = input 175 | samples = input.shape[axis] 176 | spectrogram = F.stft(input, window_size, hop_size, axis=axis) 177 | magnitude, phase = librosa.magphase(spectrogram) 178 | output = F.istft(spectrogram, window_size, hop_size, axis=axis) 179 | output = F.crop(output, (0, samples), axis=axis) 180 | np.testing.assert_almost_equal(output, expected_output, decimal=6) 181 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | import audtorch as at 5 | 6 | 7 | xfail = pytest.mark.xfail 8 | 9 | 10 | @pytest.mark.parametrize('nested_list,expected_list', [ 11 | ([1, 2, 3, [4], [], [[[[[[[[[5]]]]]]]]]], [1, 2, 3, 4, 5]), 12 | ([[1, 2], 3], [1, 2, 3]), 13 | ([1, 2, 3], [1, 2, 3]), 14 | ]) 15 | def test_flatten_list(nested_list, expected_list): 16 | flattened_list = at.utils.flatten_list(nested_list) 17 | assert flattened_list == expected_list 18 | 19 | 20 | @pytest.mark.parametrize('input,tuple_len,expected_output', [ 21 | ('aa', 2, ('a', 'a')), 22 | (2, 1, (2,)), 23 | (1, 3, (1, 1, 1)), 24 | ((1, (1, 2)), 2, (1, (1, 2))), 25 | ([1, 2], 2, (1, 2)), 26 | pytest.param([1], 2, [], marks=xfail(raises=ValueError)), 27 | pytest.param([], 2, [], marks=xfail(raises=ValueError)), 28 | ]) 29 | def test_to_tuple(input, tuple_len, expected_output): 30 | output = at.utils.to_tuple(input, tuple_len=tuple_len) 31 | assert output == expected_output 32 | 33 | 34 | @pytest.mark.parametrize('input,expected_output', [ 35 | (np.array([[2, 2]]), 8), 36 | ]) 37 | def test_energy(input, expected_output): 38 | output = at.utils.energy(input) 39 | assert output == expected_output 40 | 41 | 42 | @pytest.mark.parametrize('input,expected_output', [ 43 | (np.array([[2, 2]]), 4), 44 | ]) 45 | def test_power(input, expected_output): 46 | output = at.utils.power(input) 47 | assert output == expected_output 48 | 49 | 50 | @pytest.mark.parametrize('n_workers,task_fun,params', [ 51 | (3, lambda x, n: x ** n, [(2, 
n) for n in range(10)]), 52 | ]) 53 | def test_run_worker_threads(n_workers, task_fun, params): 54 | list1 = at.utils.run_worker_threads(n_workers, task_fun, params) 55 | list2 = [task_fun(*p) for p in params] 56 | assert len(list1) == len(list2) and list1 == list2 57 | --------------------------------------------------------------------------------
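The tests above exercise the public pieces of audtorch end to end: data sets with transforms, the Seq2Seq collation, bucket samplers, and the metric functionals. A rough sketch of how these pieces fit together is given below; the white-noise input and the synthetic prediction tensor are placeholders chosen purely for illustration, and a full training setup would additionally wire a DataLoader with samplers.BucketSampler as exercised in tests/test_samplers.py::

    import torch

    from audtorch import collate, datasets, transforms
    import audtorch.metrics.functional as F

    # Dataset plus transform, mirroring tests/test_datasets.py
    crop = transforms.RandomCrop(1024)
    data = datasets.WhiteNoise(duration=0.5, sampling_rate=8000, transform=crop)
    signal, _ = next(iter(data))                 # numpy array of shape (1, 1024)
    signal = torch.from_numpy(signal).float()

    # Placeholder for a model output: the signal plus a little extra noise
    prediction = signal + 0.1 * torch.randn_like(signal)

    # Metric functionals, mirroring tests/test_metrics.py
    r = F.pearsonr(prediction, signal)           # Pearson correlation per row
    ccc = F.concordance_cc(prediction, signal)   # concordance correlation

    # Collate variable-length examples into padded batch tensors,
    # mirroring tests/test_collate.py
    batch = [[torch.zeros(4, 5), torch.zeros(10)],
             [torch.zeros(4, 6), torch.zeros(12)]]
    collated = collate.Seq2Seq(sequence_dimensions=[-1, -1],
                               batch_first=True,
                               pad_values=[0, 0])(batch)
    # collated[0] has shape (2, 6, 4): batch x padded sequence x feature dim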