├── asteroid ├── scripts │ ├── __init__.py │ └── asteroid_versions.py ├── losses │ ├── bark_matrix_16k.mat │ ├── bark_matrix_8k.mat │ ├── soft_f1.py │ └── __init__.py ├── engine │ └── __init__.py ├── models │ ├── README.md │ └── dccrnet.py ├── masknn │ ├── __init__.py │ └── _dccrn_architectures.py ├── dsp │ ├── __init__.py │ ├── normalization.py │ └── vad.py ├── __init__.py ├── data │ ├── utils.py │ ├── __init__.py │ └── vad_dataset.py └── utils │ ├── __init__.py │ └── test_utils.py ├── egs ├── TAC │ ├── utils │ ├── README.md │ └── local │ │ ├── conf.yml │ │ └── parse_data.py ├── LibriVAD │ ├── utils │ ├── local │ │ ├── prepare_data.sh │ │ ├── generate_librivad.sh │ │ └── conf.yml │ └── README.md ├── demask │ ├── utils │ └── local │ │ ├── parse_data.py │ │ └── conf.yml ├── wham │ ├── DPRNN │ │ ├── utils │ │ ├── README.md │ │ └── local │ │ │ ├── conf.yml │ │ │ ├── prepare_data.sh │ │ │ ├── convert_sphere2wav.sh │ │ │ └── preprocess_wham.py │ ├── TwoStep │ │ ├── utils │ │ ├── local │ │ │ ├── conf.yml │ │ │ ├── prepare_data.sh │ │ │ ├── convert_sphere2wav.sh │ │ │ └── preprocess_wham.py │ │ └── README.md │ ├── DynamicMixing │ │ ├── requirements.txt │ │ ├── README.md │ │ ├── utils │ │ │ ├── prepare_python_env.sh │ │ │ └── get_training_stats.py │ │ └── local │ │ │ ├── prepare_data.sh │ │ │ ├── conf.yml │ │ │ ├── resample_dataset.py │ │ │ ├── convert_sphere2wav.sh │ │ │ └── preprocess_wham.py │ ├── DPTNet │ │ ├── README.md │ │ ├── local │ │ │ ├── conf.yml │ │ │ ├── prepare_data.sh │ │ │ ├── convert_sphere2wav.sh │ │ │ └── preprocess_wham.py │ │ └── utils │ │ │ └── prepare_python_env.sh │ ├── ConvTasNet │ │ ├── local │ │ │ ├── conf.yml │ │ │ ├── prepare_data.sh │ │ │ ├── convert_sphere2wav.sh │ │ │ └── preprocess_wham.py │ │ ├── utils │ │ │ └── prepare_python_env.sh │ │ └── README.md │ ├── README.md │ ├── FilterbankDesign │ │ └── local │ │ │ └── conf.yml │ └── MixIT │ │ ├── local │ │ ├── conf.yml │ │ ├── prepare_data.sh │ │ ├── convert_sphere2wav.sh │ │ └── 
preprocess_wham.py │ │ └── README.md ├── fuss │ ├── baseline │ │ └── utils │ └── README.md ├── librimix │ ├── DCCRNet │ │ ├── utils │ │ ├── local │ │ │ ├── generate_librimix.sh │ │ │ ├── prepare_data.sh │ │ │ ├── conf.yml │ │ │ ├── get_text.py │ │ │ └── create_local_metadata.py │ │ └── README.md │ ├── DCUNet │ │ ├── utils │ │ ├── local │ │ │ ├── generate_librimix.sh │ │ │ ├── prepare_data.sh │ │ │ ├── conf.yml │ │ │ ├── get_text.py │ │ │ └── create_local_metadata.py │ │ └── README.md │ ├── DPTNet │ │ ├── utils │ │ ├── local │ │ │ ├── generate_librimix.sh │ │ │ ├── prepare_data.sh │ │ │ ├── conf.yml │ │ │ ├── get_text.py │ │ │ └── create_local_metadata.py │ │ └── README.md │ ├── ConvTasNet │ │ ├── utils │ │ ├── local │ │ │ ├── generate_librimix.sh │ │ │ ├── prepare_data.sh │ │ │ ├── conf.yml │ │ │ ├── get_text.py │ │ │ └── create_local_metadata.py │ │ └── README.md │ ├── DPRNNTasNet │ │ ├── utils │ │ ├── local │ │ │ ├── generate_librimix.sh │ │ │ ├── prepare_data.sh │ │ │ ├── conf.yml │ │ │ ├── get_text.py │ │ │ └── create_local_metadata.py │ │ └── README.md │ ├── SuDORMRFNet │ │ ├── utils │ │ ├── README.md │ │ └── local │ │ │ ├── generate_librimix.sh │ │ │ ├── prepare_data.sh │ │ │ ├── conf.yml │ │ │ ├── get_text.py │ │ │ └── create_local_metadata.py │ ├── SuDORMRFImprovedNet │ │ ├── utils │ │ ├── README.md │ │ └── local │ │ │ ├── generate_librimix.sh │ │ │ ├── prepare_data.sh │ │ │ ├── conf.yml │ │ │ ├── get_text.py │ │ │ └── create_local_metadata.py │ └── README.md ├── musdb18 │ ├── X-UMX │ │ ├── utils │ │ ├── requirements.txt │ │ └── local │ │ │ └── conf.yml │ └── README.md ├── sms_wsj │ ├── CaCGMM │ │ ├── utils │ │ ├── local │ │ │ ├── conf.yml │ │ │ └── prepare_data.sh │ │ ├── README.md │ │ ├── start_evaluation.py │ │ └── run.sh │ └── README.md ├── whamr │ ├── TasNet │ │ ├── utils │ │ └── local │ │ │ ├── conf.yml │ │ │ ├── prepare_data.sh │ │ │ └── convert_sphere2wav.sh │ └── README.md ├── dampvsep │ ├── ConvTasNet │ │ ├── utils │ │ ├── local │ │ │ ├── 
prepare_data.sh │ │ │ └── conf.yml │ │ └── README.md │ └── README.md ├── kinect-wsj │ ├── DeepClustering │ │ ├── utils │ │ ├── README.md │ │ ├── requirements.txt │ │ ├── model.py │ │ └── local │ │ │ ├── conf.yml │ │ │ └── convert_sphere2wav.sh │ └── README.md ├── wsj0-mix │ ├── DeepClustering │ │ ├── utils │ │ ├── requirements.txt │ │ ├── README.md │ │ └── local │ │ │ ├── conf.yml │ │ │ ├── convert_sphere2wav.sh │ │ │ └── preprocess_wsj0mix.py │ └── README.md ├── dns_challenge_INTERSPEECH2020 │ ├── baseline │ │ ├── utils │ │ └── local │ │ │ ├── download_data.sh │ │ │ ├── conf.yml │ │ │ ├── create_dns_dataset.sh │ │ │ └── install_git_lfs.sh │ └── README.md ├── wsj0-mix-var │ └── Multi-Decoder-DPRNN │ │ ├── requirements.txt │ │ ├── .vscode │ │ └── settings.json │ │ ├── local │ │ ├── conf.yml │ │ ├── convert_sphere2wav.sh │ │ └── preprocess_wsj0mix.py │ │ ├── utils │ │ └── prepare_python_env.sh │ │ └── separate.py └── avspeech │ ├── looking-to-listen │ ├── local │ │ ├── __init__.py │ │ ├── loader │ │ │ ├── __init__.py │ │ │ ├── constants │ │ │ │ └── __init__.py │ │ │ ├── remove_corrupt.py │ │ │ └── remove_empty_audio.py │ │ ├── postprocess │ │ │ ├── __init__.py │ │ │ └── postprocess_audio.py │ │ ├── conf.yml │ │ ├── requirements.txt │ │ └── data_prep.yml │ └── train │ │ ├── __init__.py │ │ ├── config.py │ │ └── metric_utils.py │ └── README.md ├── .gitattributes ├── docs ├── source │ ├── readmes │ │ ├── egs_README.md │ │ ├── CONTRIBUTING.md │ │ ├── fuss_README.md │ │ ├── wham_README.md │ │ ├── whamr_README.md │ │ ├── musdb18_README.md │ │ ├── sms_wsj_README.md │ │ ├── avspeech_README.md │ │ ├── dampvsep_README.md │ │ ├── kinect-wsj_README.md │ │ ├── librimix_README.md │ │ ├── wsj0-mix_README.md │ │ └── dns_challenge_README.md │ ├── _static │ │ └── images │ │ │ ├── favicon.ico │ │ │ ├── train_val_loss.png │ │ │ ├── asteroid_logo_dark.png │ │ │ └── code_example_croped.png │ ├── package_reference │ │ ├── system.rst │ │ ├── blocks.rst │ │ ├── utils.rst │ │ ├── models.rst │ 
│ ├── data.rst │ │ ├── optimizers.rst │ │ ├── dsp.rst │ │ └── filterbanks.rst │ ├── cli.rst │ ├── _templates │ │ └── theme_variables.jinja │ ├── supported_datasets.rst │ ├── installation.rst │ └── why_use_asteroid.rst ├── Makefile └── make.bat ├── tests ├── version_consistency │ └── dummy_test.py ├── engine │ ├── __init__.py │ └── system_test.py ├── losses │ └── __init__.py ├── dsp │ ├── normalization_test.py │ ├── vad_test.py │ ├── deltas_tests.py │ ├── overlap_add_test.py │ ├── spatial_test.py │ └── consistency_test.py ├── jit │ ├── __init__.py │ ├── jit_torch_utils_test.py │ ├── jit_masknn_test.py │ └── jit_filterbanks_test.py ├── README.md ├── cli_test.sh ├── binarize_test.py ├── models │ └── fasnet_test.py ├── masknn │ ├── norms_test.py │ └── convolutional_test.py ├── utils │ └── hub_utils_test.py ├── cli_setup.py └── cli_test.py ├── MANIFEST.in ├── requirements ├── dev.txt ├── install.txt ├── torchhub.txt └── docs.txt ├── environment.yml ├── .coveragerc ├── requirements.txt ├── .github ├── ISSUE_TEMPLATE │ ├── documentation.md │ ├── how-to-question.md │ ├── feature_request.md │ └── bug_report.md └── workflows │ ├── rebase.yml │ ├── lint.yml │ ├── test_formatting.yml │ └── test_torch_hub.yml ├── codecov.yml ├── .pre-commit-config.yaml ├── pyproject.toml ├── .readthedocs.yml ├── .flake8 ├── model_card_template.md └── LICENSE /asteroid/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /egs/TAC/utils: -------------------------------------------------------------------------------- 1 | ../wham/ConvTasNet/utils/ -------------------------------------------------------------------------------- /egs/LibriVAD/utils: -------------------------------------------------------------------------------- 1 | ../wham/ConvTasNet/utils -------------------------------------------------------------------------------- /egs/demask/utils: 
-------------------------------------------------------------------------------- 1 | ../wham/ConvTasNet/utils/ -------------------------------------------------------------------------------- /egs/wham/DPRNN/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils/ -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | notebooks/* linguist-vendored=true 2 | -------------------------------------------------------------------------------- /egs/fuss/baseline/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils/ -------------------------------------------------------------------------------- /egs/librimix/DCCRNet/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils -------------------------------------------------------------------------------- /egs/librimix/DCUNet/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils -------------------------------------------------------------------------------- /egs/librimix/DPTNet/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils -------------------------------------------------------------------------------- /egs/musdb18/X-UMX/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils -------------------------------------------------------------------------------- /egs/sms_wsj/CaCGMM/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils/ -------------------------------------------------------------------------------- /egs/wham/TwoStep/utils: 
-------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils/ -------------------------------------------------------------------------------- /egs/whamr/TasNet/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils/ -------------------------------------------------------------------------------- /docs/source/readmes/egs_README.md: -------------------------------------------------------------------------------- 1 | ../../../egs/README.md -------------------------------------------------------------------------------- /egs/dampvsep/ConvTasNet/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils/ -------------------------------------------------------------------------------- /egs/librimix/ConvTasNet/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils/ -------------------------------------------------------------------------------- /egs/librimix/DPRNNTasNet/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils -------------------------------------------------------------------------------- /egs/librimix/SuDORMRFNet/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils -------------------------------------------------------------------------------- /docs/source/readmes/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ../../../CONTRIBUTING.md -------------------------------------------------------------------------------- /docs/source/readmes/fuss_README.md: -------------------------------------------------------------------------------- 1 | ../../../egs/fuss/README.md 
-------------------------------------------------------------------------------- /docs/source/readmes/wham_README.md: -------------------------------------------------------------------------------- 1 | ../../../egs/wham/README.md -------------------------------------------------------------------------------- /docs/source/readmes/whamr_README.md: -------------------------------------------------------------------------------- 1 | ../../../egs/whamr/README.md -------------------------------------------------------------------------------- /egs/kinect-wsj/DeepClustering/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils -------------------------------------------------------------------------------- /egs/wsj0-mix/DeepClustering/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils -------------------------------------------------------------------------------- /docs/source/readmes/musdb18_README.md: -------------------------------------------------------------------------------- 1 | ../../../egs/musdb18/README.md -------------------------------------------------------------------------------- /docs/source/readmes/sms_wsj_README.md: -------------------------------------------------------------------------------- 1 | ../../../egs/sms_wsj/README.md -------------------------------------------------------------------------------- /egs/librimix/SuDORMRFImprovedNet/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils -------------------------------------------------------------------------------- /egs/librimix/SuDORMRFNet/README.md: -------------------------------------------------------------------------------- 1 | ### Results 2 | 3 | 4 | Coming -------------------------------------------------------------------------------- /docs/source/readmes/avspeech_README.md: 
-------------------------------------------------------------------------------- 1 | ../../../egs/avspeech/README.md -------------------------------------------------------------------------------- /docs/source/readmes/dampvsep_README.md: -------------------------------------------------------------------------------- 1 | ../../../egs/dampvsep/README.md -------------------------------------------------------------------------------- /docs/source/readmes/kinect-wsj_README.md: -------------------------------------------------------------------------------- 1 | ../../../egs/kinect-wsj/README.md -------------------------------------------------------------------------------- /docs/source/readmes/librimix_README.md: -------------------------------------------------------------------------------- 1 | ../../../egs/librimix/README.md -------------------------------------------------------------------------------- /docs/source/readmes/wsj0-mix_README.md: -------------------------------------------------------------------------------- 1 | ../../../egs/wsj0-mix/README.md -------------------------------------------------------------------------------- /egs/kinect-wsj/DeepClustering/README.md: -------------------------------------------------------------------------------- 1 | # Results 2 | Coming soon 3 | -------------------------------------------------------------------------------- /egs/kinect-wsj/DeepClustering/requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn>=0.20.2 2 | -------------------------------------------------------------------------------- /egs/wham/DynamicMixing/requirements.txt: -------------------------------------------------------------------------------- 1 | pysndfx>=1.16.4 2 | scipy>=1.4.1 -------------------------------------------------------------------------------- /egs/wsj0-mix/DeepClustering/requirements.txt: 
-------------------------------------------------------------------------------- 1 | scikit-learn>=0.20.2 2 | -------------------------------------------------------------------------------- /egs/dns_challenge_INTERSPEECH2020/baseline/utils: -------------------------------------------------------------------------------- 1 | ../../wham/ConvTasNet/utils/ -------------------------------------------------------------------------------- /egs/kinect-wsj/DeepClustering/model.py: -------------------------------------------------------------------------------- 1 | ../../wsj0-mix/DeepClustering/model.py -------------------------------------------------------------------------------- /tests/version_consistency/dummy_test.py: -------------------------------------------------------------------------------- 1 | def dummy_test(): 2 | pass 3 | -------------------------------------------------------------------------------- /docs/source/readmes/dns_challenge_README.md: -------------------------------------------------------------------------------- 1 | ../../../egs/dns_challenge/README.md -------------------------------------------------------------------------------- /egs/librimix/SuDORMRFImprovedNet/README.md: -------------------------------------------------------------------------------- 1 | ### Results 2 | 3 | 4 | 5 | Coming -------------------------------------------------------------------------------- /egs/wsj0-mix-var/Multi-Decoder-DPRNN/requirements.txt: -------------------------------------------------------------------------------- 1 | asteroid 2 | numpy 3 | librosa -------------------------------------------------------------------------------- /egs/avspeech/looking-to-listen/local/__init__.py: -------------------------------------------------------------------------------- 1 | from .postprocess import filter_audio, shelf 2 | -------------------------------------------------------------------------------- /egs/wsj0-mix-var/Multi-Decoder-DPRNN/.vscode/settings.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "ros.distro": "noetic" 3 | } -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include asteroid/losses/bark_matrix_8k.mat 2 | include asteroid/losses/bark_matrix_16k.mat -------------------------------------------------------------------------------- /egs/avspeech/looking-to-listen/local/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from .frames import input_face_embeddings 2 | -------------------------------------------------------------------------------- /egs/LibriVAD/local/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir "data" 3 | cp Libri_VAD/metadata/LibriSpeech/* data/ 4 | -------------------------------------------------------------------------------- /egs/avspeech/looking-to-listen/local/postprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from .postprocess_audio import filter_audio, shelf 2 | -------------------------------------------------------------------------------- /egs/musdb18/X-UMX/requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn>=0.22 2 | musdb>=0.4.0 3 | museval>=0.4.0 4 | norbert>=0.2.1 5 | -------------------------------------------------------------------------------- /asteroid/losses/bark_matrix_16k.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asteroid-team/asteroid/HEAD/asteroid/losses/bark_matrix_16k.mat -------------------------------------------------------------------------------- /asteroid/losses/bark_matrix_8k.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/asteroid-team/asteroid/HEAD/asteroid/losses/bark_matrix_8k.mat -------------------------------------------------------------------------------- /docs/source/_static/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asteroid-team/asteroid/HEAD/docs/source/_static/images/favicon.ico -------------------------------------------------------------------------------- /docs/source/_static/images/train_val_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asteroid-team/asteroid/HEAD/docs/source/_static/images/train_val_loss.png -------------------------------------------------------------------------------- /asteroid/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .system import System 2 | from .optimizers import make_optimizer 3 | 4 | __all__ = ["System", "make_optimizer"] 5 | -------------------------------------------------------------------------------- /docs/source/_static/images/asteroid_logo_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asteroid-team/asteroid/HEAD/docs/source/_static/images/asteroid_logo_dark.png -------------------------------------------------------------------------------- /docs/source/_static/images/code_example_croped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asteroid-team/asteroid/HEAD/docs/source/_static/images/code_example_croped.png -------------------------------------------------------------------------------- /egs/wsj0-mix/DeepClustering/README.md: -------------------------------------------------------------------------------- 1 | Results will be 
updated soon 2 | 3 | Deep clustering alone, with VAD weights 9.9 dB improvement (SDR) 4 | MI alone: 10.0 5 | Chimera: 11.5 -------------------------------------------------------------------------------- /tests/engine/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | ignored_warnings = ["ignore:Could not log computational graph since"] 4 | 5 | pytestmark = pytest.mark.filterwarnings(*ignored_warnings) 6 | -------------------------------------------------------------------------------- /requirements/dev.txt: -------------------------------------------------------------------------------- 1 | # Requirements for development on Asteroid and running tests 2 | -r ./install.txt 3 | pre-commit 4 | black==25.9.0 5 | pytest 6 | coverage 7 | codecov 8 | 9 | librosa 10 | -------------------------------------------------------------------------------- /egs/avspeech/looking-to-listen/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import ParamConfig 2 | from .metric_utils import snr, sdr 3 | from .callbacks import SNRCallback, SDRCallback 4 | from .trainer import train 5 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: asteroid 2 | channels: 3 | - anaconda 4 | - conda-forge 5 | dependencies: 6 | - python=3.8 7 | - Cython 8 | - pip: 9 | - -r file:requirements.txt 10 | - -e . 
11 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | fail_under=80 3 | # Regexes for lines to exclude from consideration 4 | exclude_lines = 5 | pragma: no cover 6 | raise NotImplementedError 7 | 8 | omit = 9 | asteroid/data/* 10 | 11 | [run] 12 | source=asteroid 13 | -------------------------------------------------------------------------------- /egs/avspeech/looking-to-listen/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Training config 2 | training: 3 | epochs: 40 4 | batch_size: 2 5 | num_workers: 2 6 | device: "cuda:0" 7 | # Optimizer config 8 | optim: 9 | optimizer: adam 10 | lr: 0.0003 11 | -------------------------------------------------------------------------------- /egs/avspeech/looking-to-listen/local/requirements.txt: -------------------------------------------------------------------------------- 1 | catalyst==20.1 2 | facenet-pytorch>=2.2.9 3 | librosa>=0.7.2 4 | mir-eval>=0.6 5 | opencv-python==4.2.0.34 6 | pysndfx>=0.3.6 7 | threadpoolctl>=2.0.0 8 | torchvision>=0.6.0 9 | zipp>=3.1.0 10 | -------------------------------------------------------------------------------- /egs/musdb18/README.md: -------------------------------------------------------------------------------- 1 | ### MUSDB18 Dataset 2 | 3 | The musdb18 is a dataset of 150 full lengths music tracks (~10h duration) of different genres along with their isolated drums, bass, vocals and others stems. 4 | 5 | More info [here](https://sigsep.github.io/datasets/musdb.html). -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Requirements for using Asteroid. Using this file is equivalent to using 2 | # requirements/install.txt. 
Note that we cannot make this file a symlink to 3 | # requirements/install.txt because of how pip resolves relative paths with -r. 4 | 5 | -r requirements/install.txt 6 | -------------------------------------------------------------------------------- /requirements/install.txt: -------------------------------------------------------------------------------- 1 | # Requirements for using Asteroid 2 | -r ./torchhub.txt 3 | PyYAML>=5.0 4 | pandas>=0.23.4 5 | pytorch-lightning>=2.0.0 6 | torchmetrics==1.8.0 7 | torchaudio>=0.8.0 8 | pb_bss_eval>=0.0.2 9 | torch_stoi>=0.0.1 10 | torch_optimizer>=0.0.1a12,<=0.3.0 11 | julius 12 | -------------------------------------------------------------------------------- /asteroid/models/README.md: -------------------------------------------------------------------------------- 1 | ### Publishing models 2 | 3 | - First, create a account on [Zenodo](https://zenodo.org/) 4 | (you can log in with GitHub directly) 5 | - Then [create an access token](https://zenodo.org/account/settings/applications/tokens/new/), 6 | we'll need one to upload anything. 7 | 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F4DD Typos and doc fixes" 3 | about: Typos and doc fixes 4 | title: '' 5 | labels: typo, documentation 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 📚 Documentation 11 | 12 | For typos and doc fixes, rather and submit a PR, thanks in advance! 13 | -------------------------------------------------------------------------------- /egs/LibriVAD/local/generate_librivad.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | python_path=python 5 | 6 | . 
./utils/parse_options.sh 7 | 8 | # Clone Libri_VAD repo 9 | git clone https://github.com/asteroid-team/Libri_VAD 10 | 11 | # Run generation script 12 | cd Libri_VAD 13 | . run.sh $storage_dir 14 | -------------------------------------------------------------------------------- /tests/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | ignored_warnings = [ 4 | "ignore:Could not log computational graph since", 5 | "ignore:The dataloader, val dataloader", 6 | "ignore:The dataloader, train dataloader", 7 | ] 8 | 9 | pytestmark = pytest.mark.filterwarnings(*ignored_warnings) 10 | -------------------------------------------------------------------------------- /egs/wham/DPTNet/README.md: -------------------------------------------------------------------------------- 1 | ### Results 2 | 3 | | | task |kernel size|chunk size|batch size|SI-SNRi(dB) | SDRi(dB)| 4 | |:----:|:---------:|:---------:|:--------:|:--------:|:----------:|:-------:| 5 | | Paper| sep_clean | 2 | | | | | 6 | | Here | sep_clean | 2 | | | | | 7 | -------------------------------------------------------------------------------- /requirements/torchhub.txt: -------------------------------------------------------------------------------- 1 | # Minimal set of requirements required to be able to run Asteroid models from Torch Hub. 2 | # Note that Asteroid itself is not required to be installed. 
3 | numpy>=1.16.4 4 | scipy>=1.10.1 5 | torch>=2.0.0 6 | asteroid-filterbanks>=0.4.0 7 | requests 8 | filelock 9 | SoundFile>=0.10.2 10 | huggingface_hub>=0.0.2 11 | -------------------------------------------------------------------------------- /asteroid/masknn/__init__.py: -------------------------------------------------------------------------------- 1 | from .convolutional import TDConvNet, TDConvNetpp, SuDORMRF, SuDORMRFImproved 2 | from .recurrent import DPRNN, LSTMMasker 3 | from .attention import DPTransformer 4 | 5 | __all__ = [ 6 | "TDConvNet", 7 | "DPRNN", 8 | "DPTransformer", 9 | "LSTMMasker", 10 | "SuDORMRF", 11 | "SuDORMRFImproved", 12 | ] 13 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | github_checks: 2 | annotations: false 3 | comment: false 4 | coverage: 5 | status: 6 | patch: 7 | default: 8 | target: 60% 9 | project: 10 | default: 11 | target: auto # target is the base commit coverage 12 | threshold: 5% # allow this little decrease on project 13 | base: auto 14 | -------------------------------------------------------------------------------- /egs/librimix/ConvTasNet/local/generate_librimix.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | . ./utils/parse_options.sh 8 | 9 | current_dir=$(pwd) 10 | # Clone LibriMix repo 11 | git clone https://github.com/JorisCos/LibriMix 12 | 13 | # Run generation script 14 | cd LibriMix 15 | . generate_librimix.sh $storage_dir 16 | -------------------------------------------------------------------------------- /egs/librimix/DCCRNet/local/generate_librimix.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | . 
./utils/parse_options.sh 8 | 9 | current_dir=$(pwd) 10 | # Clone LibriMix repo 11 | git clone https://github.com/JorisCos/LibriMix 12 | 13 | # Run generation script 14 | cd LibriMix 15 | . generate_librimix.sh $storage_dir 16 | -------------------------------------------------------------------------------- /egs/librimix/DCUNet/local/generate_librimix.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | . ./utils/parse_options.sh 8 | 9 | current_dir=$(pwd) 10 | # Clone LibriMix repo 11 | git clone https://github.com/JorisCos/LibriMix 12 | 13 | # Run generation script 14 | cd LibriMix 15 | . generate_librimix.sh $storage_dir 16 | -------------------------------------------------------------------------------- /egs/librimix/DPTNet/local/generate_librimix.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | . ./utils/parse_options.sh 8 | 9 | current_dir=$(pwd) 10 | # Clone LibriMix repo 11 | git clone https://github.com/JorisCos/LibriMix 12 | 13 | # Run generation script 14 | cd LibriMix 15 | . generate_librimix.sh $storage_dir 16 | -------------------------------------------------------------------------------- /egs/librimix/DPRNNTasNet/local/generate_librimix.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | . ./utils/parse_options.sh 8 | 9 | current_dir=$(pwd) 10 | # Clone LibriMix repo 11 | git clone https://github.com/JorisCos/LibriMix 12 | 13 | # Run generation script 14 | cd LibriMix 15 | . 
generate_librimix.sh $storage_dir 16 | -------------------------------------------------------------------------------- /egs/librimix/SuDORMRFNet/local/generate_librimix.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | . ./utils/parse_options.sh 8 | 9 | current_dir=$(pwd) 10 | # Clone LibriMix repo 11 | git clone https://github.com/JorisCos/LibriMix 12 | 13 | # Run generation script 14 | cd LibriMix 15 | . generate_librimix.sh $storage_dir 16 | -------------------------------------------------------------------------------- /egs/librimix/SuDORMRFImprovedNet/local/generate_librimix.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | . ./utils/parse_options.sh 8 | 9 | current_dir=$(pwd) 10 | # Clone LibriMix repo 11 | git clone https://github.com/JorisCos/LibriMix 12 | 13 | # Run generation script 14 | cd LibriMix 15 | . 
generate_librimix.sh $storage_dir 16 | -------------------------------------------------------------------------------- /egs/sms_wsj/CaCGMM/local/conf.yml: -------------------------------------------------------------------------------- 1 | mm_config: 2 | # stft config 3 | stft_size: 512 4 | stft_shift: 128 5 | stft_window_length: null 6 | stft_window: 'hann' 7 | # Mask config 8 | out: 'mm_mvdr_souden' 9 | Observation: 'Observation' 10 | mask_estimator: 'cacgmm' 11 | weight_constant_axis: -3 # pi_tk 12 | # Beamformer config 13 | beamformer: 'mvdr_souden' 14 | postfilter: null 15 | 16 | -------------------------------------------------------------------------------- /egs/avspeech/looking-to-listen/train/config.py: -------------------------------------------------------------------------------- 1 | # use dataclass if >=python3.7 2 | class ParamConfig: 3 | def __init__(self, batch_size, epochs, workers, cuda, use_half, learning_rate): 4 | self.batch_size = batch_size 5 | self.epochs = epochs 6 | self.workers = workers 7 | self.cuda = cuda 8 | self.use_half = use_half 9 | self.learning_rate = learning_rate 10 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.2.0 4 | hooks: 5 | - id: end-of-file-fixer 6 | - id: trailing-whitespace 7 | - id: check-yaml 8 | 9 | - repo: https://github.com/psf/black 10 | rev: 22.3.0 11 | hooks: 12 | - id: black 13 | args: 14 | - --config=pyproject.toml 15 | types: [python] 16 | -------------------------------------------------------------------------------- /docs/source/package_reference/system.rst: -------------------------------------------------------------------------------- 1 | .. 
role:: hidden 2 | :class: hidden-section 3 | 4 | Lightning Wrapper 5 | ================= 6 | 7 | As explained in :ref:`Training and Evaluation`, Asteroid provides a thin wrapper 8 | on the top of `PyTorchLightning `_ 9 | for training your models. 10 | 11 | .. automodule:: asteroid.engine.system 12 | :members: 13 | -------------------------------------------------------------------------------- /tests/dsp/normalization_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from asteroid.dsp.normalization import normalize_estimates 3 | 4 | 5 | def test_normalization(): 6 | 7 | mix = (np.random.rand(1600) - 0.5) * 2 # random [-1,1[ 8 | est = (np.random.rand(2, 1600) - 0.5) * 10 9 | est_normalized = normalize_estimates(est, mix) 10 | 11 | assert np.max(est_normalized) < 1 12 | assert np.min(est_normalized) >= -1 13 | -------------------------------------------------------------------------------- /asteroid/dsp/__init__.py: -------------------------------------------------------------------------------- 1 | from .consistency import mixture_consistency 2 | from .overlap_add import LambdaOverlapAdd, DualPathProcessing 3 | from .beamforming import ( 4 | SCM, 5 | Beamformer, 6 | RTFMVDRBeamformer, 7 | SoudenMVDRBeamformer, 8 | SDWMWFBeamformer, 9 | GEVBeamformer, 10 | ) 11 | 12 | __all__ = [ 13 | "mixture_consistency", 14 | "LambdaOverlapAdd", 15 | "DualPathProcessing", 16 | ] 17 | -------------------------------------------------------------------------------- /egs/dns_challenge_INTERSPEECH2020/baseline/local/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | clone_dir=$1 4 | 5 | recipe_dir=$PWD 6 | cd $clone_dir 7 | 8 | # Clone repo 9 | git clone -b interspeech2020/master https://github.com/microsoft/DNS-Challenge 10 | cd DNS-Challenge 11 | 12 | # Run lfs stuff in the repo 13 | git lfs install 14 | git lfs track "*.wav" 15 | git add 
.gitattributes 16 | 17 | # Go back to the recipe 18 | cd $recipe_dir -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel", 5 | ] 6 | 7 | [tool.black] 8 | # https://github.com/psf/black 9 | line-length = 100 10 | target-version = ["py38"] 11 | exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv|.svn|_build|buck-out|build|dist)" 12 | 13 | [tool.pytest.ini_options] 14 | filterwarnings = [ 15 | "ignore:Using or importing the ABCs.*:DeprecationWarning" 16 | ] 17 | -------------------------------------------------------------------------------- /tests/dsp/vad_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from asteroid.dsp.vad import ebased_vad 3 | 4 | 5 | def test_ebased_vad(): 6 | mag_spec = torch.abs(torch.randn(10, 2, 65, 16)) # Need positive inputs 7 | batch_src_mask = ebased_vad(mag_spec) 8 | 9 | assert isinstance(batch_src_mask, torch.BoolTensor) 10 | batch_1_mask = ebased_vad(mag_spec[:, 0]) 11 | # Assert independence of VAD output 12 | assert (batch_src_mask[:, 0] == batch_1_mask).all() 13 | -------------------------------------------------------------------------------- /tests/jit/__init__.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | ignored_warnings = [ 4 | "ignore:torch.tensor results are registered as constants in the trace.", 5 | "ignore:Converting a tensor to a Python boolean might cause the trace to be incorrect.", 6 | "ignore:Converting a tensor to a Python float might cause the trace to be incorrect.", 7 | "ignore:Using or importing the ABCs from", 8 | ] 9 | 10 | pytestmark = pytest.mark.filterwarnings(*ignored_warnings) 11 | -------------------------------------------------------------------------------- 
/asteroid/dsp/normalization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def normalize_estimates(est_np, mix_np): 5 | """Normalizes estimates according to the mixture maximum amplitude 6 | 7 | Args: 8 | est_np (np.array): Estimates with shape (n_src, time). 9 | mix_np (np.array): One mixture with shape (time, ). 10 | 11 | """ 12 | mix_max = np.max(np.abs(mix_np)) 13 | return np.stack([est * mix_max / np.max(np.abs(est)) for est in est_np]) 14 | -------------------------------------------------------------------------------- /egs/dampvsep/ConvTasNet/local/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | dampvsep_root= 3 | python_path=python 4 | 5 | . ./utils/parse_options.sh 6 | if [ ! -d DAMP-VSEP-Singles ]; then 7 | # Clone preprocessed DAMP-VSEP-Singles repo 8 | git clone https://github.com/groadabike/DAMP-VSEP-Singles.git 9 | fi 10 | 11 | if [ ! -d metadata ]; then 12 | # Generate the splits 13 | . DAMP-VSEP-Singles/generate_dampvsep_singles.sh $dampvsep_root ../metadata $python_path 14 | fi 15 | -------------------------------------------------------------------------------- /egs/librimix/DCCRNet/local/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | . 
./utils/parse_options.sh 8 | 9 | if [[ $n_src -le 1 ]] 10 | then 11 | changed_n_src=2 12 | else 13 | changed_n_src=$n_src 14 | fi 15 | 16 | $python_path local/create_local_metadata.py --librimix_dir $storage_dir/Libri$changed_n_src"Mix" 17 | 18 | $python_path local/get_text.py \ 19 | --libridir $storage_dir/LibriSpeech \ 20 | --split test-clean \ 21 | --outfile data/test_annotations.csv 22 | -------------------------------------------------------------------------------- /egs/librimix/DCUNet/local/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | . ./utils/parse_options.sh 8 | 9 | if [[ $n_src -le 1 ]] 10 | then 11 | changed_n_src=2 12 | else 13 | changed_n_src=$n_src 14 | fi 15 | 16 | $python_path local/create_local_metadata.py --librimix_dir $storage_dir/Libri$changed_n_src"Mix" 17 | 18 | $python_path local/get_text.py \ 19 | --libridir $storage_dir/LibriSpeech \ 20 | --split test-clean \ 21 | --outfile data/test_annotations.csv 22 | -------------------------------------------------------------------------------- /egs/librimix/DPTNet/local/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | .
./utils/parse_options.sh 8 | 9 | if [[ $n_src -le 1 ]] 10 | then 11 | changed_n_src=2 12 | else 13 | changed_n_src=$n_src 14 | fi 15 | 16 | $python_path local/create_local_metadata.py --librimix_dir $storage_dir/Libri$changed_n_src"Mix" 17 | 18 | $python_path local/get_text.py \ 19 | --libridir $storage_dir/LibriSpeech \ 20 | --split test-clean \ 21 | --outfile data/test_annotations.csv 22 | -------------------------------------------------------------------------------- /egs/librimix/ConvTasNet/local/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | . ./utils/parse_options.sh 8 | 9 | if [[ $n_src -le 1 ]] 10 | then 11 | changed_n_src=2 12 | else 13 | changed_n_src=$n_src 14 | fi 15 | 16 | $python_path local/create_local_metadata.py --librimix_dir $storage_dir/Libri$changed_n_src"Mix" 17 | 18 | $python_path local/get_text.py \ 19 | --libridir $storage_dir/LibriSpeech \ 20 | --split test-clean \ 21 | --outfile data/test_annotations.csv 22 | -------------------------------------------------------------------------------- /egs/librimix/DPRNNTasNet/local/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | .
./utils/parse_options.sh 8 | 9 | if [[ $n_src -le 1 ]] 10 | then 11 | changed_n_src=2 12 | else 13 | changed_n_src=$n_src 14 | fi 15 | 16 | $python_path local/create_local_metadata.py --librimix_dir $storage_dir/Libri$changed_n_src"Mix" 17 | 18 | $python_path local/get_text.py \ 19 | --libridir $storage_dir/LibriSpeech \ 20 | --split test-clean \ 21 | --outfile data/test_annotations.csv 22 | -------------------------------------------------------------------------------- /egs/librimix/SuDORMRFNet/local/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | . ./utils/parse_options.sh 8 | 9 | if [[ $n_src -le 1 ]] 10 | then 11 | changed_n_src=2 12 | else 13 | changed_n_src=$n_src 14 | fi 15 | 16 | $python_path local/create_local_metadata.py --librimix_dir $storage_dir/Libri$changed_n_src"Mix" 17 | 18 | $python_path local/get_text.py \ 19 | --libridir $storage_dir/LibriSpeech \ 20 | --split test-clean \ 21 | --outfile data/test_annotations.csv 22 | -------------------------------------------------------------------------------- /egs/librimix/SuDORMRFImprovedNet/local/prepare_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | storage_dir= 4 | n_src= 5 | python_path=python 6 | 7 | .
./utils/parse_options.sh 8 | 9 | if [[ $n_src -le 1 ]] 10 | then 11 | changed_n_src=2 12 | else 13 | changed_n_src=$n_src 14 | fi 15 | 16 | $python_path local/create_local_metadata.py --librimix_dir $storage_dir/Libri$changed_n_src"Mix" 17 | 18 | $python_path local/get_text.py \ 19 | --libridir $storage_dir/LibriSpeech \ 20 | --split test-clean \ 21 | --outfile data/test_annotations.csv 22 | -------------------------------------------------------------------------------- /requirements/docs.txt: -------------------------------------------------------------------------------- 1 | -r ./dev.txt 2 | sphinx-rtd-theme==0.4.3 3 | sphinxcontrib-jsmath==1.0.1 4 | sphinxcontrib-programoutput>=0.16 5 | sphinx==8.*, <9.0 6 | jinja2>=3.1.0,<3.2.0 7 | #recommonmark # fails with badges 8 | myst-parser==4.*,<5.0 9 | nbsphinx==0.9.*, <1.0 10 | lxml[html_clean] 11 | m2r2==0.3.4 12 | -e git+https://github.com/asteroid-team/asteroid_sphinx_theme#egg=asteroid_sphinx_theme 13 | # asteroid_sphinx_theme>=0.0.3 14 | # pandoc 15 | # docutils 16 | # sphinxcontrib-fulltoc 17 | # sphinxcontrib-mockautodoc 18 | # pip_shims 19 | -------------------------------------------------------------------------------- /egs/librimix/DCUNet/README.md: -------------------------------------------------------------------------------- 1 | ### Results 2 | 3 | The model was trained on the `enh_single` task on Libri1Mix `train_360`, at 16kHz. 4 | 5 | The pretrained model is available [here](https://huggingface.co/JorisCos/DCUNet_Libri1Mix_enhsingle_16k).
6 | 7 | On Libri1Mix min test set : 8 | 9 | ``` yaml 10 | si_sdr: 11.853042303532362 11 | si_sdr_imp: 8.403260997672662 12 | sdr: 12.248064453851127 13 | sdr_imp: 8.745401654638112 14 | sar: 12.248064453851127 15 | sar_imp: 8.745401654638112 16 | stoi: 0.9095207713713066 17 | stoi_imp: 0.11360094783076659 18 | ``` -------------------------------------------------------------------------------- /egs/librimix/DPTNet/README.md: -------------------------------------------------------------------------------- 1 | ### Results 2 | 3 | The model was trained on the `enh_single` task on Libri1Mix `train_360`, at 16kHz. 4 | 5 | The pretrained model is available [here](https://huggingface.co/JorisCos/DPTNet_Libri1Mix_enhsingle_16k) 6 | On Libri1Mix min test set : 7 | 8 | ``` yaml 9 | si_sdr: 14.829670037349064 10 | si_sdr_imp: 11.379888731489366 11 | sdr: 15.395712644737149 12 | sdr_imp: 11.893049845524112 13 | sar: 15.395712644737149 14 | sar_imp: 11.893049845524112 15 | stoi: 0.9301948391058859 16 | stoi_imp: 0.13427501556534832 17 | ``` -------------------------------------------------------------------------------- /egs/librimix/DCCRNet/README.md: -------------------------------------------------------------------------------- 1 | ### Results 2 | 3 | The model was trained on the `enh_single` task on Libri1Mix `train_360`, at 16kHz. 4 | 5 | The pretrained model is available [here](https://huggingface.co/JorisCos/DCCRNet_Libri1Mix_enhsingle_16k).
6 | 7 | On Libri1Mix min test set : 8 | 9 | ``` yaml 10 | si_sdr: 13.329767398333798 11 | si_sdr_imp: 9.879986092474098 12 | sdr: 13.87279932997016 13 | sdr_imp: 10.370136530757103 14 | sar: 13.87279932997016 15 | sar_imp: 10.370136530757103 16 | stoi: 0.9140907015623948 17 | stoi_imp: 0.11817087802185405 18 | ``` -------------------------------------------------------------------------------- /egs/librimix/DPRNNTasNet/README.md: -------------------------------------------------------------------------------- 1 | ### Results 2 | 3 | The model was trained on the `enh_single` task on Libri1Mix `train_360`, at 16kHz. 4 | 5 | The pretrained model is available [here](https://huggingface.co/JorisCos/DPRNNTasNet-ks2_Libri1Mix_enhsingle_16k) 6 | On Libri1Mix min test set : 7 | 8 | ``` yaml 9 | si_sdr: 14.7228101708889 10 | si_sdr_imp: 11.2730288650292 11 | sdr: 15.35661405197161 12 | sdr_imp: 11.853951252758595 13 | sar: 15.35661405197161 14 | sar_imp: 11.853951252758595 15 | stoi: 0.9300461826351578 16 | stoi_imp: 0.13412635909461715 17 | ``` -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Asteroid tests 2 | ## Running the tests locally 3 | 4 | ```bash 5 | git clone https://github.com/asteroid-team/asteroid 6 | cd asteroid 7 | 8 | # install module locally 9 | pip install -e .
10 | 11 | # install dev deps 12 | pip install -r requirements/dev.txt 13 | 14 | # run tests 15 | py.test -v 16 | ``` 17 | 18 | ### Running with coverage 19 | From `asteroid` parent directory 20 | ```bash 21 | # generate coverage 22 | coverage run --source asteroid -m py.test tests -v --doctest-modules 23 | # print coverage stats 24 | coverage report -m 25 | ``` 26 | -------------------------------------------------------------------------------- /egs/dns_challenge_INTERSPEECH2020/baseline/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Filterbank config 2 | filterbank: 3 | n_filters: 512 4 | kernel_size: 512 5 | stride: 256 6 | # Network config 7 | masknet: 8 | hidden_size: 500 9 | rnn_type: gru 10 | n_layers: 3 11 | dropout: 0.3 12 | # Training config 13 | training: 14 | epochs: 200 15 | batch_size: 64 16 | num_workers: 32 17 | half_lr: y 18 | early_stop: y 19 | # Optim config 20 | optim: 21 | optimizer: adam 22 | lr: 0.001 23 | weight_decay: 0. 24 | # Data config 25 | data: 26 | json_dir: data/ 27 | val_prop: 0.2 28 | -------------------------------------------------------------------------------- /egs/TAC/README.md: -------------------------------------------------------------------------------- 1 | This model was trained using the [dataset](https://github.com/yluo42/TAC/tree/master/data) provided by Yi Luo, author of 2 | ["End-to-end Microphone Permutation and Number Invariant Multi-channel Speech Separation"](https://arxiv.org/abs/1910.14104). 3 | 4 | ### Results 5 | 6 | | | task | dataset type | SI-SNRi(dB) | 7 | |:----:|:---------:|:----------:|:----------:| 8 | | Paper| sep_clean | adhoc | ? 
| 9 | | Here | sep_clean | adhoc | 9.9 | 10 | 11 | The pretrained model is available [here](https://huggingface.co/JorisCos/FasNet) 12 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Build documentation in the docs/ directory with Sphinx 8 | sphinx: 9 | configuration: docs/source/conf.py 10 | 11 | # Optionally build your docs in additional formats such as PDF 12 | formats: 13 | - htmlzip 14 | - pdf 15 | 16 | # Optionally set the version of Python and requirements required to build your docs 17 | python: 18 | version: 3.7 19 | install: 20 | - requirements: requirements/docs.txt 21 | -------------------------------------------------------------------------------- /egs/librimix/ConvTasNet/README.md: -------------------------------------------------------------------------------- 1 | ### Results 2 | 3 | All the models were trained using 8 Khz min subsets with the same model 4 | parameters (see [here](./local/conf.yml) for more details). 5 | 6 | 7 | | | task |SI-SNRi(dB) | SDRi(dB)| 8 | |:---------:|:---------:|:---------:|:-------:| 9 | | train-100 | sep_clean | 13.0 | 13.4 | 10 | | train-360 | sep_clean | 14.7 | 15.1 | 11 | | train-100 | sep_noisy | 10.8 | 11.4 | 12 | | train-360 | sep_noisy | 12 | 12.5 | 13 | 14 | See available models [here](https://huggingface.co/models?filter=asteroid). 
-------------------------------------------------------------------------------- /tests/jit/jit_torch_utils_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from asteroid.utils import torch_utils 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "data", 9 | ( 10 | torch.tensor([1]), 11 | torch.tensor([1, 2]), 12 | torch.tensor([[1], [2]]), 13 | torch.tensor([[2, 5], [3, 8]]), 14 | ), 15 | ) 16 | def test_jitable_shape(data): 17 | expected = torch_utils.jitable_shape(data) 18 | scripted = torch.jit.trace(torch_utils.jitable_shape, torch.tensor([1])) 19 | output = scripted(data) 20 | assert torch.equal(output, expected) 21 | -------------------------------------------------------------------------------- /egs/dns_challenge_INTERSPEECH2020/baseline/local/create_dns_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | clone_dir=$1 4 | storage_dir=$2 5 | 6 | cd $clone_dir/DNS-Challenge 7 | # SED the cfg file to modify windows-like path to linux-like path 8 | sed -i 's+\\+\/+g' noisyspeech_synthesizer.cfg 9 | 10 | # Change default saving directories 11 | # We keep the default values for all the rest feel free to modify it. 
12 | sed -i 's+./training+'"$storage_dir"'+g' noisyspeech_synthesizer.cfg 13 | 14 | # Run the dataset recipe 15 | python -m pip install librosa pandas 16 | python noisyspeech_synthesizer_singleprocess.py -------------------------------------------------------------------------------- /egs/librimix/DCCRNet/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Filterbank config 2 | filterbank: 3 | stft_n_filters: 512 4 | stft_kernel_size: 400 5 | stft_stride: 100 6 | masknet: 7 | architecture: DCCRN-CL 8 | data: 9 | task: enh_single 10 | train_dir: data/wav16k/max/train-360 11 | valid_dir: data/wav16k/max/dev 12 | sample_rate: 16000 13 | n_src: 1 14 | segment: 4 15 | training: 16 | epochs: 200 17 | batch_size: 12 18 | num_workers: 4 19 | half_lr: yes 20 | early_stop: yes 21 | gradient_clipping: 5 22 | optim: 23 | optimizer: adam 24 | lr: 0.001 25 | weight_decay: !!float 1e-5 26 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 119 3 | exclude = docs/source,*.egg,build 4 | select = E,W,F 5 | verbose = 2 6 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes 7 | format = pylint 8 | ignore = 9 | # E731 - Do not assign a lambda expression, use a def 10 | E731 11 | # W605 - invalid escape sequence '\_'. Needed for docs 12 | W605 13 | # W504 - line break after binary operator 14 | W504 15 | # W503 - line break before binary operator, need for black 16 | W503 17 | # E203 - whitespace before ':'. 
Opposite convention enforced by black 18 | E203 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/how-to-question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F914 ❓ Asking a question" 3 | about: Asking a question 4 | title: '' 5 | labels: question 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## ❓ Questions and Help 11 | 12 | ### Before asking: 13 | 1. search the [issues](https://github.com/asteroid-team/asteroid/issues). 14 | 2. search the [discussions](https://github.com/asteroid-team/asteroid/discussions) 15 | 3. search the docs. 16 | 17 | 18 | 19 | Open a question on the [Discussions](https://github.com/asteroid-team/asteroid/discussions/new) 20 | -------------------------------------------------------------------------------- /.github/workflows/rebase.yml: -------------------------------------------------------------------------------- 1 | name: Automatic Rebase 2 | # https://github.com/marketplace/actions/automatic-rebase 3 | 4 | on: 5 | issue_comment: 6 | types: [created] 7 | 8 | jobs: 9 | rebase: 10 | name: Rebase 11 | if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/rebase') 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v5 15 | with: 16 | fetch-depth: 0 17 | - name: Automatic Rebase 18 | uses: cirrus-actions/rebase@1.8 19 | 20 | env: 21 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 22 | -------------------------------------------------------------------------------- /egs/wham/DPRNN/README.md: -------------------------------------------------------------------------------- 1 | ### Results 2 | 3 | | | task |kernel size|chunk size|batch size|SI-SNRi(dB) | SDRi(dB)| 4 | |:----:|:---------:|:---------:|:--------:|:--------:|:----------:|:-------:| 5 | | Paper| sep_clean | 16 | 100 | - | 15.9 | 16.1 | 6 | | Here | sep_clean | 16 | 100 | 8 | 17.7 | 18.0 | 7 | | Paper| sep_clean | 2 | 250 | - 
| 18.8 | 19.0 | 8 | | Here | sep_clean | 2 | 250 | 3 | 19.3 | 19.5 | 9 | 10 | Both models with ks=16 and ks=2 were trained with segments of 2seconds only. -------------------------------------------------------------------------------- /egs/LibriVAD/README.md: -------------------------------------------------------------------------------- 1 | ### Results 2 | 3 | This model was on train on [LibriVAD](https://github.com/asteroid-team/Libri_VAD) a dataset for voice activity detection 4 | in noisy environment based on [LibriSpeech](https://www.openslr.org/12) and 5 | [DNS challenge](https://github.com/microsoft/DNS-Challenge/tree/master/datasets/noise) noises. 6 | 7 | 8 | 9 | | | F1 score |Accuracy| Precision| Recall| 10 | |:-------:|:--------------:|:------:|:--------:|:-----:| 11 | | Asteroid | 0.84 | 0.82 | 0.83 | 0.89 | 12 | 13 | 14 | See available models [here](https://huggingface.co/models?filter=asteroid). 15 | -------------------------------------------------------------------------------- /egs/librimix/DCUNet/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Filterbank config 2 | filterbank: 3 | stft_n_filters: 1024 4 | stft_kernel_size: 1024 5 | stft_stride: 256 6 | masknet: 7 | architecture: Large-DCUNet-20 8 | fix_length_mode: pad 9 | data: 10 | task: enh_single 11 | train_dir: data/wav16k/max/train-360 12 | valid_dir: data/wav16k/max/dev 13 | sample_rate: 16000 14 | n_src: 1 15 | segment: 2 16 | training: 17 | epochs: 200 18 | batch_size: 4 19 | num_workers: 4 20 | half_lr: yes 21 | early_stop: yes 22 | gradient_clipping: 5 23 | optim: 24 | optimizer: adam 25 | lr: 0.001 26 | weight_decay: !!float 1e-5 27 | -------------------------------------------------------------------------------- /docs/source/package_reference/blocks.rst: -------------------------------------------------------------------------------- 1 | DNN building blocks 2 | =================== 3 | 4 | Convolutional blocks 5 | 
-------------------- 6 | .. automodule:: asteroid.masknn.convolutional 7 | :members: 8 | 9 | Recurrent blocks 10 | ---------------- 11 | .. automodule:: asteroid.masknn.recurrent 12 | :members: 13 | 14 | Attention blocks 15 | ---------------- 16 | .. automodule:: asteroid.masknn.attention 17 | :members: 18 | 19 | Norms 20 | ----- 21 | .. automodule:: asteroid.masknn.norms 22 | :members: 23 | 24 | Complex number support 25 | ---------------------- 26 | .. automodule:: asteroid.complex_nn 27 | :members: 28 | -------------------------------------------------------------------------------- /egs/TAC/local/conf.yml: -------------------------------------------------------------------------------- 1 | data: 2 | sample_rate: 16000 3 | segment: 4 | train_json: ./data/train.json 5 | dev_json: ./data/validation.json 6 | test_json: ./data/test.json 7 | net: 8 | enc_dim: 64 9 | chunk_size: 50 10 | hop_size: 25 11 | feature_dim: 64 12 | hidden_dim: 128 13 | n_layers: 4 14 | n_src: 2 15 | window_ms: 4 16 | context_ms: 16 17 | optim: 18 | lr: 0.001 19 | weight_decay: !!float 1e-5 20 | training: 21 | epochs: 200 22 | batch_size: 1 23 | gradient_clipping: 5 24 | accumulate_batches: 1 25 | save_top_k: 5 26 | num_workers: 8 27 | patience: 30 28 | half_lr: true 29 | early_stop: true 30 | -------------------------------------------------------------------------------- /egs/whamr/TasNet/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Filterbank config 2 | filterbank: 3 | n_filters: 512 4 | kernel_size: 40 5 | stride: 20 6 | # Network config 7 | masknet: 8 | n_layers: 4 9 | n_units: 500 10 | dropout: 0.3 11 | # Training config 12 | training: 13 | epochs: 200 14 | half_lr: yes 15 | early_stop: yes 16 | batch_size: 16 17 | num_workers: 16 18 | # Optim config 19 | optim: 20 | optimizer: adam 21 | lr: 0.001 22 | weight_decay: 0. 
23 | # Data config 24 | data: 25 | train_dir: data/wav8k/min/tr/ 26 | valid_dir: data/wav8k/min/cv/ 27 | task: sep_reverb_noisy 28 | nondefault_nsrc: 29 | sample_rate: 8000 30 | mode: min -------------------------------------------------------------------------------- /egs/LibriVAD/local/conf.yml: -------------------------------------------------------------------------------- 1 | # filterbank config 2 | filterbank: 3 | n_filters: 512 4 | kernel_size: 16 5 | stride: 8 6 | # Network config 7 | masknet: 8 | n_src: 1 9 | n_blocks: 3 10 | n_repeats: 5 11 | mask_act: relu 12 | bn_chan: 128 13 | skip_chan: 128 14 | hid_chan: 512 15 | causal: False 16 | # Training config 17 | training: 18 | epochs: 200 19 | batch_size: 8 20 | num_workers: 4 21 | half_lr: yes 22 | early_stop: yes 23 | # Optim config 24 | optim: 25 | optimizer: adam 26 | lr: 0.001 27 | weight_decay: 0. 28 | # Data config 29 | data: 30 | train_dir: data/train.json 31 | valid_dir: data/dev.json 32 | segment: 3 33 | -------------------------------------------------------------------------------- /egs/librimix/README.md: -------------------------------------------------------------------------------- 1 | ### LibriMix dataset 2 | 3 | The LibriMix dataset is an open source dataset 4 | derived from LibriSpeech dataset. It's meant as an alternative and complement 5 | to [WHAM](./../wham/). 6 | 7 | More info [here](https://github.com/JorisCos/LibriMix). 
8 | 9 | **References** 10 | ```BibTeX 11 | @misc{cosentino2020librimix, 12 | title={LibriMix: An Open-Source Dataset for Generalizable Speech Separation}, 13 | author={Joris Cosentino and Manuel Pariente and Samuele Cornell and Antoine Deleforge and Emmanuel Vincent}, 14 | year={2020}, 15 | eprint={2005.11262}, 16 | archivePrefix={arXiv}, 17 | primaryClass={eess.AS} 18 | } 19 | ``` -------------------------------------------------------------------------------- /asteroid/losses/soft_f1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.loss import _Loss 3 | 4 | 5 | class F1_loss(_Loss): 6 | """Calculate F1 score""" 7 | 8 | def __init__(self, eps=1e-10): 9 | super().__init__() 10 | self.eps = eps 11 | 12 | def forward(self, estimates, targets): 13 | tp = (targets * estimates).sum() 14 | fp = ((1 - targets) * estimates).sum() 15 | fn = (targets * (1 - estimates)).sum() 16 | 17 | precision = tp / (tp + fp + self.eps) 18 | recall = tp / (tp + fn + self.eps) 19 | 20 | f1 = 2 * (precision * recall) / (precision + recall + self.eps) 21 | return 1 - f1.mean() 22 | -------------------------------------------------------------------------------- /egs/avspeech/looking-to-listen/local/loader/constants/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | STORAGE_DIR = os.environ.get("STORAGE_DIR", "storage_dir") 4 | if not STORAGE_DIR.startswith("/"): 5 | STORAGE_DIR = os.path.join("../..", STORAGE_DIR) # We are in local/loader 6 | 7 | AUDIO_MIX_COMMAND_PREFIX = "ffmpeg -y -t 00:00:03 -ac 1 " 8 | 9 | AUDIO_DIR = f"{STORAGE_DIR}/storage/audio" 10 | VIDEO_DIR = f"{STORAGE_DIR}/storage/video" 11 | EMBED_DIR = f"{STORAGE_DIR}/storage/embed" 12 | MIXED_AUDIO_DIR = f"{STORAGE_DIR}/storage/mixed" 13 | SPEC_DIR = f"{STORAGE_DIR}/storage/spec" 14 | 15 | AUDIO_SET_DIR = f"{STORAGE_DIR}/audio_set/audio" 16 | 17 | STORAGE_LIMIT = 5_000_000_000 18 | 
-------------------------------------------------------------------------------- /docs/source/package_reference/utils.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | Utils 5 | ===== 6 | 7 | Parser utils 8 | ------------- 9 | Asteroid has its own argument parser (built on ``argparse``) that handles 10 | dict-like structure, created from a config YAML file. 11 | 12 | .. automodule:: asteroid.utils.parser_utils 13 | :members: 14 | 15 | 16 | Torch utils 17 | ------------ 18 | .. automodule:: asteroid.utils.torch_utils 19 | :members: 20 | 21 | 22 | Hub utils 23 | ---------- 24 | .. automodule:: asteroid.utils.hub_utils 25 | :members: 26 | 27 | 28 | Generic utils 29 | -------------- 30 | .. automodule:: asteroid.utils.generic_utils 31 | :members: 32 | 33 | 34 | 35 | -------------------------------------------------------------------------------- /egs/dampvsep/README.md: -------------------------------------------------------------------------------- 1 | ### DAMP-VSEP dataset 2 | 3 | All the information regarding the dataset can be found in 4 | [zenodo](https://zenodo.org/record/3553059#.X5xKGnX7S-o). 
5 | 6 | **References** 7 | If you use this dataset, please cite as follows : 8 | 9 | ```BibTex 10 | @dataset{smule_inc_2019_3553059, 11 | author = {Smule, Inc}, 12 | title = {{DAMP-VSEP: Smule Digital Archive of Mobile 13 | Performances - Vocal Separation}}, 14 | month = oct, 15 | year = 2019, 16 | publisher = {Zenodo}, 17 | version = {1.0.1}, 18 | doi = {10.5281/zenodo.3553059}, 19 | url = {https://doi.org/10.5281/zenodo.3553059} 20 | } 21 | ``` 22 | -------------------------------------------------------------------------------- /egs/wham/ConvTasNet/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Filterbank config 2 | filterbank: 3 | n_filters: 512 4 | kernel_size: 16 5 | stride: 8 6 | # Network config 7 | masknet: 8 | n_blocks: 8 9 | n_repeats: 3 10 | mask_act: relu 11 | bn_chan: 128 12 | skip_chan: 128 13 | hid_chan: 512 14 | # Training config 15 | training: 16 | epochs: 200 17 | batch_size: 8 18 | num_workers: 4 19 | half_lr: yes 20 | early_stop: yes 21 | # Optim config 22 | optim: 23 | optimizer: adam 24 | lr: 0.001 25 | weight_decay: 0. 26 | # Data config 27 | data: 28 | train_dir: data/wav8k/min/tr/ 29 | valid_dir: data/wav8k/min/cv/ 30 | task: sep_clean 31 | nondefault_nsrc: 32 | sample_rate: 8000 33 | mode: min 34 | -------------------------------------------------------------------------------- /egs/librimix/ConvTasNet/local/conf.yml: -------------------------------------------------------------------------------- 1 | # filterbank config 2 | filterbank: 3 | n_filters: 512 4 | kernel_size: 16 5 | stride: 8 6 | # Network config 7 | masknet: 8 | n_blocks: 8 9 | n_repeats: 3 10 | mask_act: relu 11 | bn_chan: 128 12 | skip_chan: 128 13 | hid_chan: 512 14 | # Training config 15 | training: 16 | epochs: 200 17 | batch_size: 6 18 | num_workers: 4 19 | half_lr: yes 20 | early_stop: yes 21 | # Optim config 22 | optim: 23 | optimizer: adam 24 | lr: 0.001 25 | weight_decay: 0. 
26 | # Data config 27 | data: 28 | task: sep_clean 29 | train_dir: data/wav8k/min/train-100 30 | valid_dir: data/wav8k/min/dev 31 | sample_rate: 8000 32 | n_src: 2 33 | segment: 3 34 | -------------------------------------------------------------------------------- /egs/whamr/README.md: -------------------------------------------------------------------------------- 1 | ### WHAMR dataset 2 | WHAMR! is a noisy and reverberant single-channel speech separation dataset 3 | based on WSJ0. 4 | It is a reverberant extension of [WHAM!](./../wham). 5 | 6 | Note that WHAMR! can synthesize binaural recordings, but we only consider 7 | the single channel for now. 8 | 9 | More info [here](http://wham.whisper.ai/). 10 | **References** 11 | ```BibTex 12 | @misc{maciejewski2019whamr, 13 | title={WHAMR!: Noisy and Reverberant Single-Channel Speech Separation}, 14 | author={Matthew Maciejewski and Gordon Wichern and Emmett McQuinn and Jonathan Le Roux}, 15 | year={2019}, 16 | eprint={1910.10279}, 17 | archivePrefix={arXiv}, 18 | primaryClass={cs.SD} 19 | } 20 | ``` -------------------------------------------------------------------------------- /egs/wsj0-mix/README.md: -------------------------------------------------------------------------------- 1 | ### wsj0-2mix dataset 2 | 3 | wsj0-2mix is a single channel speech separation dataset base on WSJ0. 4 | Three speaker extension (wsj0-3mix) is also considered here. 5 | 6 | **Reference** 7 | ```BibTex 8 | @article{Hershey_2016, 9 | title={Deep clustering: Discriminative embeddings for segmentation and separation}, 10 | ISBN={9781479999880}, 11 | url={http://dx.doi.org/10.1109/ICASSP.2016.7471631}, 12 | DOI={10.1109/icassp.2016.7471631}, 13 | journal={2016 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 14 | publisher={IEEE}, 15 | author={Hershey, John R. 
and Chen, Zhuo and Le Roux, Jonathan and Watanabe, Shinji}, 16 | year={2016}, 17 | } 18 | ``` -------------------------------------------------------------------------------- /egs/wham/README.md: -------------------------------------------------------------------------------- 1 | ### WHAM dataset 2 | WHAM! is a noisy single-channel speech separation dataset based on WSJ0. 3 | It is a noisy extension of [wsj0-2mix](./../wsj0-mix/). 4 | 5 | More info [here](http://wham.whisper.ai/). 6 | 7 | **References** 8 | ```BibTex 9 | @inproceedings{WHAMWichern2019, 10 | author={Gordon Wichern and Joe Antognini and Michael Flynn and Licheng Richard Zhu and Emmett McQuinn and Dwight Crow and Ethan Manilow and Jonathan Le Roux}, 11 | title={{WHAM!: extending speech separation to noisy environments}}, 12 | year=2019, 13 | booktitle={Proc. Interspeech}, 14 | pages={1368--1372}, 15 | doi={10.21437/Interspeech.2019-2821}, 16 | url={http://dx.doi.org/10.21437/Interspeech.2019-2821} 17 | } 18 | ``` 19 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = source 8 | BUILDDIR = build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | 21 | clean : 22 | rm -rf $(BUILDDIR)/* 23 | rm -rf $(SOURCEDIR)/apidoc/* 24 | -------------------------------------------------------------------------------- /docs/source/cli.rst: -------------------------------------------------------------------------------- 1 | Command-line interface 2 | ====================== 3 | 4 | 5 | Inference 6 | --------- 7 | 8 | asteroid-infer 9 | ~~~~~~~~~~~~~~ 10 | 11 | Example 12 | ....... 13 | 14 | :: 15 | 16 | asteroid-infer "mpariente/ConvTasNet_WHAM!_sepclean" --files myaudio.wav --resample --ola-window 8000 --ola-hop 4000 17 | 18 | Reference 19 | ......... 20 | 21 | .. program-output:: asteroid-infer --help 22 | 23 | 24 | Publishing models 25 | ----------------- 26 | 27 | asteroid-upload 28 | ~~~~~~~~~~~~~~~ 29 | 30 | Reference 31 | ......... 32 | 33 | .. program-output:: asteroid-upload --help 34 | 35 | asteroid-register-sr 36 | ~~~~~~~~~~~~~~~~~~~~ 37 | 38 | Reference 39 | ......... 40 | 41 | .. 
program-output:: asteroid-register-sr --help 42 | -------------------------------------------------------------------------------- /egs/wsj0-mix/DeepClustering/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Filterbank config 2 | filterbank: 3 | n_filters: 256 4 | kernel_size: 256 5 | stride: 64 6 | # Network config 7 | masknet: 8 | rnn_type: lstm 9 | n_layers: 4 10 | hidden_size: 600 11 | dropout: 0.3 12 | embedding_dim: 40 13 | take_log: y 14 | # Training config 15 | training: 16 | epochs: 200 17 | batch_size: 32 18 | num_workers: 8 19 | half_lr: yes 20 | early_stop: yes 21 | loss_alpha: 1.0 # DC loss weight : 1.0 => DC, <1.0 => Chimera 22 | # Optim config 23 | optim: 24 | optimizer: rmsprop 25 | lr: 0.0001 26 | weight_decay: 0.00000 27 | # momentum: 0.9 28 | # Data config 29 | data: 30 | train_dir: data/wav8k/min/tr/ 31 | valid_dir: data/wav8k/min/cv/ 32 | n_src: 2 33 | sample_rate: 8000 -------------------------------------------------------------------------------- /egs/librimix/SuDORMRFImprovedNet/local/conf.yml: -------------------------------------------------------------------------------- 1 | ## Filterbank config 2 | filterbank: 3 | fb_name: "free" 4 | n_filters: 512 5 | kernel_size: 41 6 | stride: 20 7 | ## Network config 8 | masknet: 9 | bn_chan: 128 10 | num_blocks: 16 11 | upsampling_depth: 4 12 | mask_act: "relu" 13 | in_chan: 512 14 | # Training config 15 | training: 16 | epochs: 200 17 | batch_size: 2 18 | num_workers: 4 19 | half_lr: yes 20 | early_stop: yes 21 | gradient_clipping: 5 22 | # Optim config 23 | optim: 24 | optimizer: adam 25 | lr: 0.001 26 | weight_decay: !!float 1e-5 27 | # Data config 28 | data: 29 | task: enh_single 30 | train_dir: data/wav16k/max/train-360 31 | valid_dir: data/wav16k/max/dev 32 | sample_rate: 16000 33 | n_src: 1 34 | segment: 3 35 | -------------------------------------------------------------------------------- /egs/librimix/SuDORMRFNet/local/conf.yml: 
-------------------------------------------------------------------------------- 1 | ## Filterbank config 2 | filterbank: 3 | fb_name: "free" 4 | n_filters: 512 5 | kernel_size: 41 6 | stride: 20 7 | ## Network config 8 | masknet: 9 | bn_chan: 128 10 | num_blocks: 16 11 | upsampling_depth: 4 12 | mask_act: "softmax" 13 | in_chan: 512 14 | # Training config 15 | training: 16 | epochs: 200 17 | batch_size: 2 18 | num_workers: 4 19 | half_lr: yes 20 | early_stop: yes 21 | gradient_clipping: 5 22 | # Optim config 23 | optim: 24 | optimizer: adam 25 | lr: 0.001 26 | weight_decay: !!float 1e-5 27 | # Data config 28 | data: 29 | task: enh_single 30 | train_dir: data/wav16k/max/train-360 31 | valid_dir: data/wav16k/max/dev 32 | sample_rate: 16000 33 | n_src: 1 34 | segment: 3 35 | 36 | -------------------------------------------------------------------------------- /egs/sms_wsj/README.md: -------------------------------------------------------------------------------- 1 | ### SMS_WSJ dataset 2 | 3 | SMS_WSJ (stands for Spatialized Multi-Speaker Wall Street Journal) 4 | is a multichannel source separation dataset, based on WSJ0 and WSJ1. 5 | 6 | All the information regarding the dataset can be found in 7 | [this repo](https://github.com/fgnt/sms_wsj). 
@pytest.mark.parametrize("dim", [1, 2, -1, -2])
def test_delta(dim):
    # Delta computation must be shape-preserving along any axis choice,
    # including negative axis indices.
    phase = torch.randn(2, 257, 100)
    delta_phase = compute_delta(phase, dim=dim)
    assert phase.shape == delta_phase.shape
2]) 16 | def test_concat_deltas(dim, order): 17 | phase_shape = [2, 257, 100] 18 | phase = torch.randn(*phase_shape) 19 | cat_deltas = concat_deltas(phase, order=order, dim=dim) 20 | out_shape = list(phase_shape) 21 | out_shape[dim] = phase_shape[dim] * (1 + order) 22 | assert out_shape == list(cat_deltas.shape) 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680 Feature request" 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement, help wanted 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## 🚀 Feature 11 | 12 | 13 | ### Motivation 14 | 15 | 17 | 18 | ### What you'd like 19 | 20 | 21 | 22 | ### Alternatives 23 | 24 | 25 | 26 | ### Additional context 27 | 28 | 29 | -------------------------------------------------------------------------------- /egs/wham/FilterbankDesign/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Filterbank config 2 | filterbank: 3 | fb_name: analytic_free 4 | n_filters: 512 5 | kernel_size: 16 6 | stride: 8 7 | inp_mode: reim 8 | mask_mode: reim 9 | # Network config 10 | masknet: 11 | n_blocks: 3 12 | n_repeats: 2 13 | mask_act: relu 14 | bn_chan: 128 15 | skip_chan: 128 16 | hid_chan: 512 17 | # Training config 18 | training: 19 | epochs: 100 20 | batch_size: 16 21 | num_workers: 4 22 | half_lr: yes 23 | early_stop: yes 24 | # Optim config 25 | optim: 26 | optimizer: adam 27 | lr: 0.001 28 | weight_decay: 0. 
29 | # Data config 30 | data: 31 | train_dir: data/2speakers_wham/wav8k/min/tr/ 32 | valid_dir: data/2speakers_wham/wav8k/min/cv/ 33 | task: sep_clean 34 | nondefault_nsrc: 35 | sample_rate: 8000 36 | mode: min 37 | -------------------------------------------------------------------------------- /egs/wham/MixIT/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Filterbank config 2 | filterbank: 3 | n_filters: 64 4 | kernel_size: 16 5 | stride: 8 6 | # Network config 7 | masknet: 8 | in_chan: 64 9 | n_src: 4 10 | out_chan: 64 11 | bn_chan: 128 12 | hid_size: 128 13 | chunk_size: 100 14 | hop_size: 50 15 | n_repeats: 6 16 | mask_act: 'sigmoid' 17 | bidirectional: true 18 | dropout: 0 19 | # Training config 20 | training: 21 | epochs: 200 22 | batch_size: 4 23 | num_workers: 4 24 | half_lr: yes 25 | early_stop: yes 26 | gradient_clipping: 5 27 | # Optim config 28 | optim: 29 | optimizer: adam 30 | lr: 0.001 31 | weight_decay: !!float 1e-5 32 | # Data config 33 | data: 34 | train_dir: data/wav8k/min/tr/ 35 | valid_dir: data/wav8k/min/cv/ 36 | sample_rate: 8000 37 | mode: min 38 | segment: 2.0 39 | -------------------------------------------------------------------------------- /egs/wham/DynamicMixing/README.md: -------------------------------------------------------------------------------- 1 | ### WHAM/WSJ0-2MIX Dynamic Mixing 2 | 3 | - introduced in [WaveSplit]() [1] 4 | - we added speed perturbation with SoX via [python-audio-effects](https://github.com/carlthome/python-audio-effects). 5 | - this recipe comes with DPRNN as default model but the model can be swapped by modifying model.py and local/conf.yml. 6 | - Original WSJ0 data is needed to run this recipe. 
7 | 8 | ### Results: 9 | 10 | | model | task |kernel size|chunk size|batch size|SI-SNRi(dB) | SDRi(dB)| 11 | |:----:|:---------:|:---------:|:--------:|:--------:|:----------:|:-------:| 12 | | DPRNN + DM | sep_clean | 16 | 100 | 8 | 18.4 | 18.64 | 13 | | DPRNN | sep_clean | 16 | 100 | 8 | 17.7 | 17.9 | 14 | 15 | --- 16 | #### References: 17 | 18 | [1] 19 | -------------------------------------------------------------------------------- /egs/dns_challenge_INTERSPEECH2020/baseline/local/install_git_lfs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Make lfs firectory 4 | mkdir -p lfs_install 5 | cd lfs_install 6 | 7 | # Download .tar file and extract 8 | wget https://github.com/git-lfs/git-lfs/releases/download/v2.10.0/git-lfs-linux-amd64-v2.10.0.tar.gz 9 | tar -xzvf git-lfs-linux-amd64-v2.10.0.tar.gz 10 | 11 | # To install without sudo? git-lfs will be installed in ~/.local/bin instead of /usr/local/bin 12 | sed -i 's+/usr/local+$HOME/.local+g' install.sh 13 | 14 | # Run the install script 15 | . ./install.sh 16 | 17 | # Export path 18 | export PATH=$PATH:$HOME/.local/bin 19 | 20 | echo -e "Installed git-lfs and temporarly added it to your path. 
import pathlib

from .models import ConvTasNet, DCCRNet, DCUNet, DPRNNTasNet, DPTNet, LSTMTasNet, DeMask
from .utils import deprecation_utils, torch_utils  # noqa

# Absolute path to the repository root (two levels above this file).
project_root = str(pathlib.Path(__file__).expanduser().absolute().parent.parent)
__version__ = "0.7.1dev"


def show_available_models():
    """Print the name of every pretrained model known to the Asteroid hub."""
    # Imported lazily to keep `import asteroid` light.
    from .utils.hub_utils import MODELS_URLS_HASHTABLE

    print(" \n".join(list(MODELS_URLS_HASHTABLE.keys())))


def available_models():
    """Return the mapping from pretrained model names to their URLs."""
    from .utils.hub_utils import MODELS_URLS_HASHTABLE

    return MODELS_URLS_HASHTABLE


__all__ = [
    "ConvTasNet",
    "DPRNNTasNet",
    "DPTNet",
    "LSTMTasNet",
    "DeMask",
    "DCUNet",
    "DCCRNet",
    "show_available_models",
    # Fix: available_models is public API (used alongside show_available_models)
    # but was previously missing from __all__.
    "available_models",
]
data/wav16k/max/dev 36 | sample_rate: 16000 37 | n_src: 1 38 | segment: 1 39 | -------------------------------------------------------------------------------- /egs/wham/MixIT/README.md: -------------------------------------------------------------------------------- 1 | ### Description 2 | This simple recipe demonstrates MixIT Unsupervised Separation [1] 3 | on WSJ0-2Mix (we use WHAM clean) with DPRNN. 4 | We use MixIT to train DPRNN on mixtures of mixtures of always two speakers. 5 | Test and validation are the plain WHAM clean with always two speakers. 6 | Results can be improved by not having always 2 speakers in each mixture in 7 | train so that the mixture of mixtures will have not always 4 speakers as shown by [1]. 8 | 9 | 10 | 11 | References: 12 | 13 | ```BibTeX 14 | @article{wisdom2020unsupervised, 15 | title={Unsupervised sound separation using mixtures of mixtures}, 16 | author={Wisdom, Scott and Tzinis, Efthymios and Erdogan, Hakan and Weiss, Ron J and Wilson, Kevin and Hershey, John R}, 17 | journal={arXiv preprint arXiv:2006.12701}, 18 | year={2020} 19 | } 20 | -------------------------------------------------------------------------------- /egs/avspeech/looking-to-listen/local/postprocess/postprocess_audio.py: -------------------------------------------------------------------------------- 1 | import scipy.signal as sg 2 | from pysndfx import AudioEffectsChain 3 | 4 | 5 | def filter_audio(y, sr=16_000, cutoff=15_000, low_cutoff=1, filter_order=5): 6 | sos = sg.butter( 7 | filter_order, 8 | [low_cutoff / sr / 2, cutoff / sr / 2], 9 | btype="band", 10 | analog=False, 11 | output="sos", 12 | ) 13 | filtered = sg.sosfilt(sos, y) 14 | 15 | return filtered 16 | 17 | 18 | def shelf(y, sr=16_000, gain=5, frequency=500, slope=0.5, high_frequency=7_000): 19 | afc = AudioEffectsChain() 20 | fx = afc.lowshelf(gain=gain, frequency=frequency, slope=slope).highshelf( 21 | gain=-gain, frequency=high_frequency, slope=slope 22 | ) 23 | 24 | y = fx(y, sample_in=sr, 
@pytest.mark.parametrize("length", [1390, 8372])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("n_src", [1, 2])
@pytest.mark.parametrize("window", ["hann", None])
@pytest.mark.parametrize("window_size", [128])
@pytest.mark.parametrize("hop_size", [64])
def test_overlap_add(length, batch_size, n_src, window, window_size, hop_size):
    # With an identity-per-source "network" (the input duplicated n_src
    # times), chunked overlap-add must reconstruct the signal exactly,
    # with or without a synthesis window, and for lengths that are not
    # multiples of the window/hop sizes.
    mix = torch.randn((batch_size, length)).reshape(batch_size, 1, -1)
    nnet = lambda x: x.unsqueeze(1).repeat(1, n_src, 1)
    oladd = LambdaOverlapAdd(nnet, n_src, window_size, hop_size, window)
    oladded = oladd(mix)
    assert_close(mix.repeat(1, n_src, 1), oladded)
sample_rate: 8000 39 | mode: min 40 | segment: 2.0 41 | -------------------------------------------------------------------------------- /egs/fuss/README.md: -------------------------------------------------------------------------------- 1 | ### FUSS dataset 2 | 3 | The Free Universal Sound Separation (FUSS) dataset comprises audio mixtures of arbitrary sounds with source references for use in experiments on arbitrary sound separation. 4 | 5 | All the information related to this dataset can be found in [this repo](https://github.com/google-research/sound-separation/tree/master/datasets/fuss). 6 | 7 | 8 | 9 | **References** 10 | If you use this dataset, please cite the corresponding paper as follows: 11 | ```BibTex 12 | @Article{Wisdom2020, 13 | author = {Scott Wisdom and Hakan Erdogan and Daniel P. W. Ellis and Romain Serizel and Nicolas Turpault and Eduardo Fonseca and Justin Salamon and Prem Seetharaman and John R. Hershey}, 14 | title = {What's All the FUSS About Free Universal Sound Separation Data?}, 15 | journal = {in preparation}, 16 | year = {2020}, 17 | } 18 | ``` -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Automatic lint 2 | 3 | on: 4 | issue_comment: 5 | types: [created] 6 | 7 | jobs: 8 | build: 9 | name: Lint code base 10 | if: github.event.issue.pull_request != '' && contains(github.event.comment.body, '/lint') 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Set up Python 3.13 15 | uses: actions/setup-python@v6 16 | with: 17 | python-version: 3.13 18 | 19 | - uses: actions/checkout@v5 20 | - run: git pull 21 | - run: pip install black==25.9.0 flake8 22 | - run: python -m black --config=pyproject.toml asteroid tests egs 23 | - name: Commit changes 24 | uses: EndBug/add-and-commit@v9 25 | with: 26 | message: "[Bot] Lint PR" 27 | add: "*.py" 28 | env: 29 | GITHUB_TOKEN: ${{ 
@pytest.mark.parametrize("cls", (norms.GlobLN, norms.FeatsGlobLN, norms.ChanLN))
def test_lns(cls):
    # Trace each layer-norm variant on a 3D example, then check that the
    # traced graph generalizes to inputs with extra trailing dimensions
    # (i.e. tracing did not bake in the example's shape).
    chan_size = 10
    model = cls(channel_size=chan_size)
    x = torch.randn(1, chan_size, 12)

    traced = torch.jit.trace(model, x)

    y = torch.randn(3, chan_size, 18, 12)
    assert_close(traced(y), model(y))

    y = torch.randn(2, chan_size, 10, 5, 4)
    assert_close(traced(y), model(y))
def online_mixing_collate(batch):
    """Create fresh mixtures by shuffling target sources across the batch.

    The default collate function is expected to yield two objects:
    inputs of shape (batch, time) and targets of shape (batch, n_src, time).
    Each source slot is drawn from a random example in the batch and rescaled
    to the energy of the source it replaces; the new mixture is the sum of
    the shuffled sources.
    """
    _, targets = default_collate(batch)
    n_batch, n_src, _ = targets.shape

    # Per-(example, source) energies, kept broadcastable against the signals.
    source_energies = torch.sum(targets**2, dim=-1, keepdim=True)

    shuffled = []
    for src_idx in range(n_src):
        # Pick this source from a random example of the batch...
        permuted = targets[torch.randperm(n_batch), src_idx, :]
        # ...then rescale it so its energy matches the source it replaces.
        gain = torch.sqrt(
            source_energies[:, src_idx] / (permuted**2).sum(-1, keepdim=True)
        )
        shuffled.append(permuted * gain)

    targets = torch.stack(shuffled, dim=1)
    return targets.sum(1), targets
goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 33 | 34 | :end 35 | popd -------------------------------------------------------------------------------- /egs/wsj0-mix-var/Multi-Decoder-DPRNN/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Filterbank config 2 | filterbank: 3 | n_filters: 64 4 | kernel_size: 8 5 | stride: 4 6 | # Network config 7 | masknet: 8 | n_srcs: [2, 3, 4, 5] 9 | bn_chan: 128 10 | hid_size: 128 11 | chunk_size: 128 12 | hop_size: 64 13 | n_repeats: 8 14 | mask_act: 'sigmoid' 15 | bidirectional: true 16 | dropout: 0 17 | use_mulcat: false 18 | # Training config 19 | training: 20 | epochs: 200 21 | batch_size: 2 22 | num_workers: 2 23 | half_lr: yes 24 | lr_decay: yes 25 | early_stop: yes 26 | gradient_clipping: 5 27 | # Optim config 28 | optim: 29 | optimizer: adam 30 | lr: 0.001 31 | weight_decay: 0.00000 32 | # Data config 33 | data: 34 | train_dir: "data/{}speakers/wav8k/min/tr" 35 | valid_dir: "data/{}speakers/wav8k/min/cv" 36 | task: sep_count 37 | sample_rate: 8000 38 | seglen: 4.0 39 | minlen: 2.0 40 | loss: 41 | lambda: 0.05 42 | -------------------------------------------------------------------------------- /egs/librimix/DPTNet/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Filterbank config 2 | filterbank: 3 | n_filters: 64 4 | 
kernel_size: 32 5 | stride: 16 6 | # Network config 7 | masknet: 8 | in_chan: 64 9 | out_chan: 64 10 | ff_hid: 256 11 | ff_activation: "relu" 12 | norm_type: "gLN" 13 | chunk_size: 100 14 | hop_size: 50 15 | n_repeats: 2 16 | mask_act: 'sigmoid' 17 | bidirectional: true 18 | dropout: 0 19 | # Training config 20 | training: 21 | epochs: 200 22 | batch_size: 4 23 | num_workers: 4 24 | half_lr: yes 25 | early_stop: yes 26 | gradient_clipping: 5 27 | # Optim config 28 | optim: 29 | optimizer: adam 30 | lr: 0.001 31 | weight_decay: !!float 1e-5 32 | scheduler: 33 | steps_per_epoch: 10000 34 | d_model: 64 35 | # Data config 36 | data: 37 | task: enh_single 38 | train_dir: data/wav16k/max/train-360 39 | valid_dir: data/wav16k/max/dev 40 | sample_rate: 16000 41 | n_src: 1 42 | segment: 3 43 | 44 | -------------------------------------------------------------------------------- /egs/librimix/DCCRNet/local/get_text.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import pandas as pd 5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--libridir", required=True, type=str) 9 | parser.add_argument("--outfile", required=True, type=str) 10 | parser.add_argument("--split", type=str, default="train-360") 11 | 12 | args = parser.parse_args() 13 | 14 | libridir = os.path.join(args.libridir, args.split) 15 | trans_txt_list = glob.glob(os.path.join(libridir, "**/*.txt"), recursive=True) 16 | row_list = [] 17 | for name in trans_txt_list: 18 | f = open(name, "r") 19 | for line in f: 20 | dict1 = {} 21 | split_line = line.split(" ", maxsplit=1) 22 | dict1["utt_id"] = split_line[0] 23 | dict1["text"] = split_line[1].replace("\n", "").replace("\r", "") 24 | row_list.append(dict1) 25 | 26 | df = pd.DataFrame(row_list) 27 | df.to_csv(args.outfile, index=False) 28 | -------------------------------------------------------------------------------- /egs/librimix/DCUNet/local/get_text.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import pandas as pd 5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--libridir", required=True, type=str) 9 | parser.add_argument("--outfile", required=True, type=str) 10 | parser.add_argument("--split", type=str, default="train-360") 11 | 12 | args = parser.parse_args() 13 | 14 | libridir = os.path.join(args.libridir, args.split) 15 | trans_txt_list = glob.glob(os.path.join(libridir, "**/*.txt"), recursive=True) 16 | row_list = [] 17 | for name in trans_txt_list: 18 | f = open(name, "r") 19 | for line in f: 20 | dict1 = {} 21 | split_line = line.split(" ", maxsplit=1) 22 | dict1["utt_id"] = split_line[0] 23 | dict1["text"] = split_line[1].replace("\n", "").replace("\r", "") 24 | row_list.append(dict1) 25 | 26 | df = pd.DataFrame(row_list) 27 | df.to_csv(args.outfile, index=False) 28 | -------------------------------------------------------------------------------- /egs/librimix/DPTNet/local/get_text.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import pandas as pd 5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--libridir", required=True, type=str) 9 | parser.add_argument("--outfile", required=True, type=str) 10 | parser.add_argument("--split", type=str, default="train-360") 11 | 12 | args = parser.parse_args() 13 | 14 | libridir = os.path.join(args.libridir, args.split) 15 | trans_txt_list = glob.glob(os.path.join(libridir, "**/*.txt"), recursive=True) 16 | row_list = [] 17 | for name in trans_txt_list: 18 | f = open(name, "r") 19 | for line in f: 20 | dict1 = {} 21 | split_line = line.split(" ", maxsplit=1) 22 | dict1["utt_id"] = split_line[0] 23 | dict1["text"] = split_line[1].replace("\n", "").replace("\r", "") 24 | row_list.append(dict1) 25 | 26 | df = pd.DataFrame(row_list) 27 | 
df.to_csv(args.outfile, index=False) 28 | -------------------------------------------------------------------------------- /egs/wham/TwoStep/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Filterbank config 2 | filterbank: 3 | n_filters: 128 4 | kernel_size: 21 5 | stride: 10 6 | # Network config 7 | masknet: 8 | n_blocks: 8 9 | n_repeats: 4 10 | conv_kernel_size: 3 11 | bn_chan: 256 12 | hid_chan: 512 13 | # Training config for the filterbank 14 | filterbank_training: 15 | reuse_pretrained_filterbank: yes 16 | f_epochs: 50 17 | f_batch_size: 4 18 | f_num_workers: 4 19 | f_half_lr: yes 20 | f_early_stop: yes 21 | f_optimizer: adam 22 | f_lr: 0.0005 23 | # Training config for the separation module 24 | separator_training: 25 | s_epochs: 200 26 | s_batch_size: 4 27 | s_num_workers: 4 28 | s_half_lr: yes 29 | s_early_stop: yes 30 | s_optimizer: adam 31 | s_lr: 0.001 32 | # Data config 33 | data: 34 | train_dir: data/wav8k/min/tr/ 35 | valid_dir: data/wav8k/min/cv/ 36 | task: sep_clean 37 | nondefault_nsrc: 38 | sample_rate: 8000 39 | mode: min 40 | -------------------------------------------------------------------------------- /egs/librimix/ConvTasNet/local/get_text.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import pandas as pd 5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--libridir", required=True, type=str) 9 | parser.add_argument("--outfile", required=True, type=str) 10 | parser.add_argument("--split", type=str, default="train-360") 11 | 12 | args = parser.parse_args() 13 | 14 | libridir = os.path.join(args.libridir, args.split) 15 | trans_txt_list = glob.glob(os.path.join(libridir, "**/*.txt"), recursive=True) 16 | row_list = [] 17 | for name in trans_txt_list: 18 | f = open(name, "r") 19 | for line in f: 20 | dict1 = {} 21 | split_line = line.split(" ", maxsplit=1) 22 | dict1["utt_id"] = 
split_line[0] 23 | dict1["text"] = split_line[1].replace("\n", "").replace("\r", "") 24 | row_list.append(dict1) 25 | 26 | df = pd.DataFrame(row_list) 27 | df.to_csv(args.outfile, index=False) 28 | -------------------------------------------------------------------------------- /egs/librimix/DPRNNTasNet/local/get_text.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import pandas as pd 5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--libridir", required=True, type=str) 9 | parser.add_argument("--outfile", required=True, type=str) 10 | parser.add_argument("--split", type=str, default="train-360") 11 | 12 | args = parser.parse_args() 13 | 14 | libridir = os.path.join(args.libridir, args.split) 15 | trans_txt_list = glob.glob(os.path.join(libridir, "**/*.txt"), recursive=True) 16 | row_list = [] 17 | for name in trans_txt_list: 18 | f = open(name, "r") 19 | for line in f: 20 | dict1 = {} 21 | split_line = line.split(" ", maxsplit=1) 22 | dict1["utt_id"] = split_line[0] 23 | dict1["text"] = split_line[1].replace("\n", "").replace("\r", "") 24 | row_list.append(dict1) 25 | 26 | df = pd.DataFrame(row_list) 27 | df.to_csv(args.outfile, index=False) 28 | -------------------------------------------------------------------------------- /egs/librimix/SuDORMRFNet/local/get_text.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import pandas as pd 5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--libridir", required=True, type=str) 9 | parser.add_argument("--outfile", required=True, type=str) 10 | parser.add_argument("--split", type=str, default="train-360") 11 | 12 | args = parser.parse_args() 13 | 14 | libridir = os.path.join(args.libridir, args.split) 15 | trans_txt_list = glob.glob(os.path.join(libridir, "**/*.txt"), recursive=True) 16 | row_list = 
[] 17 | for name in trans_txt_list: 18 | f = open(name, "r") 19 | for line in f: 20 | dict1 = {} 21 | split_line = line.split(" ", maxsplit=1) 22 | dict1["utt_id"] = split_line[0] 23 | dict1["text"] = split_line[1].replace("\n", "").replace("\r", "") 24 | row_list.append(dict1) 25 | 26 | df = pd.DataFrame(row_list) 27 | df.to_csv(args.outfile, index=False) 28 | -------------------------------------------------------------------------------- /egs/librimix/SuDORMRFImprovedNet/local/get_text.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import pandas as pd 5 | 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--libridir", required=True, type=str) 9 | parser.add_argument("--outfile", required=True, type=str) 10 | parser.add_argument("--split", type=str, default="train-360") 11 | 12 | args = parser.parse_args() 13 | 14 | libridir = os.path.join(args.libridir, args.split) 15 | trans_txt_list = glob.glob(os.path.join(libridir, "**/*.txt"), recursive=True) 16 | row_list = [] 17 | for name in trans_txt_list: 18 | f = open(name, "r") 19 | for line in f: 20 | dict1 = {} 21 | split_line = line.split(" ", maxsplit=1) 22 | dict1["utt_id"] = split_line[0] 23 | dict1["text"] = split_line[1].replace("\n", "").replace("\r", "") 24 | row_list.append(dict1) 25 | 26 | df = pd.DataFrame(row_list) 27 | df.to_csv(args.outfile, index=False) 28 | -------------------------------------------------------------------------------- /egs/avspeech/looking-to-listen/local/data_prep.yml: -------------------------------------------------------------------------------- 1 | # Dataset config 2 | data: 3 | n_src: 2 4 | train_input_df_path: "local/train.csv" 5 | val_input_df_path: "local/val.csv" 6 | # Download config: 7 | download: 8 | # keeping all path w.r.t. 
file location 9 | download_jobs: 4 10 | download_path: "../../data/audio_visual/avspeech_train.csv" 11 | download_start: 0 12 | download_end: 10000 13 | # Extract audio config 14 | extract: 15 | extract_jobs: 4 16 | extract_sampling_rate: 16000 17 | extract_input_audio_channel: 2 18 | extract_audio_extension: "wav" 19 | extract_duration: 3 20 | # Mix audio config 21 | mix: 22 | mix_remove_random_chance: 0.9 23 | mix_use_audio_set: False 24 | mix_file_limit: 100_000_000 25 | mix_validation_size: 0.3 26 | # Extract face config 27 | face: 28 | face_cuda: True 29 | face_use_half: False 30 | face_corrupt_file_path: "../../data/corrupt_frames_list.txt" 31 | -------------------------------------------------------------------------------- /egs/wham/DPTNet/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Filterbank config 2 | filterbank: 3 | n_filters: 64 4 | kernel_size: 16 5 | stride: 8 6 | # Network config 7 | masknet: 8 | in_chan: 64 9 | n_src: 2 10 | out_chan: 64 11 | ff_hid: 256 12 | ff_activation: "relu" 13 | norm_type: "gLN" 14 | chunk_size: 100 15 | hop_size: 50 16 | n_repeats: 2 17 | mask_act: 'sigmoid' 18 | bidirectional: true 19 | dropout: 0 20 | # Training config 21 | training: 22 | epochs: 200 23 | batch_size: 4 24 | num_workers: 4 25 | half_lr: yes 26 | early_stop: yes 27 | gradient_clipping: 5 28 | # Optim config 29 | optim: 30 | optimizer: adam 31 | lr: 0.001 32 | weight_decay: !!float 1e-5 33 | scheduler: 34 | steps_per_epoch: 10000 35 | d_model: 64 36 | # Data config 37 | data: 38 | train_dir: data/wav8k/min/tr/ 39 | valid_dir: data/wav8k/min/cv/ 40 | task: sep_clean 41 | nondefault_nsrc: 42 | sample_rate: 8000 43 | mode: min 44 | segment: 2.0 45 | -------------------------------------------------------------------------------- /egs/avspeech/README.md: -------------------------------------------------------------------------------- 1 | ### AVSpeech dataset 2 | 3 | AVSpeech is an audio-visual speech 
separation dataset which was introduced by Google 4 | in this article [Looking to Listen at the Cocktail Party: 5 | A Speaker-Independent Audio-Visual Model for Speech 6 | Separation](https://arxiv.org/abs/1804.03619). 7 | 8 | More info [here](https://looking-to-listen.github.io/avspeech/download.html). 9 | 10 | **References** 11 | ```BibTex 12 | @article{Ephrat_2018, 13 | title={Looking to listen at the cocktail party}, 14 | volume={37}, 15 | url={http://dx.doi.org/10.1145/3197517.3201357}, 16 | DOI={10.1145/3197517.3201357}, 17 | journal={ACM Transactions on Graphics}, 18 | publisher={Association for Computing Machinery (ACM)}, 19 | author={Ephrat, Ariel and Mosseri, Inbar and Lang, Oran and Dekel, Tali and Wilson, Kevin and Hassidim, Avinatan and Freeman, William T. and Rubinstein, Michael}, 20 | year={2018}, 21 | pages={1–11} 22 | } 23 | ``` 24 | -------------------------------------------------------------------------------- /egs/musdb18/X-UMX/local/conf.yml: -------------------------------------------------------------------------------- 1 | # Training config 2 | training: 3 | epochs: 1000 4 | batch_size: 14 5 | loss_combine_sources: yes 6 | loss_use_multidomain: yes 7 | mix_coef: 10.0 8 | val_dur: 80.0 9 | # Optim config 10 | optim: 11 | optimizer: adam 12 | lr: 0.001 13 | patience: 1000 14 | lr_decay_patience: 80 15 | lr_decay_gamma: 0.3 16 | weight_decay: 0.00001 17 | # Data config 18 | data: 19 | train_dir: ./data 20 | output: x-umx_outputs 21 | sample_rate: 44100 22 | num_workers: 4 23 | seed: 42 24 | seq_dur: 6.0 25 | samples_per_track: 64 26 | source_augmentations: 27 | - gain 28 | - channelswap 29 | sources: 30 | - bass 31 | - drums 32 | - vocals 33 | - other 34 | # Network config 35 | model: 36 | pretrained: null 37 | bidirectional: yes 38 | window_length: 4096 39 | in_chan: 4096 40 | nhop: 1024 41 | hidden_size: 512 42 | bandwidth: 16000 43 | nb_channels: 2 44 | spec_power: 1 45 | 
-------------------------------------------------------------------------------- /egs/dampvsep/ConvTasNet/local/conf.yml: -------------------------------------------------------------------------------- 1 | # filterbank config 2 | filterbank: 3 | n_filters: 256 # N 4 | kernel_size: 20 # L 5 | stride: 10 # L/2 6 | # Network config 7 | masknet: 8 | n_blocks: 8 # X 9 | n_repeats: 4 # R 10 | mask_act: relu # mask_nonlinear 11 | conv_kernel_size: 3 # P 12 | bn_chan: 256 # B 13 | skip_chan: 256 # Sc 14 | hid_chan: 512 # H 15 | norm_type: gLN 16 | n_src: 2 17 | # Training config 18 | training: 19 | epochs: 200 20 | batch_size: 16 21 | num_workers: 4 22 | half_lr: yes 23 | early_stop: yes 24 | loss_alpha: 0.3 # Alpha for loss function 25 | # Optim config 26 | optim: 27 | optimizer: adam 28 | lr: 0.001 29 | weight_decay: 0. 30 | # Data config 31 | data: 32 | task: enh_both 33 | train_set: english 34 | root_path: 35 | mixture: remix 36 | sample_rate: 16000 37 | segment: 8. 38 | ex_per_track: 16 39 | channels: 1 40 | 41 | -------------------------------------------------------------------------------- /egs/avspeech/looking-to-listen/local/loader/remove_corrupt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | 4 | df_train = pd.read_csv("../../data/train.csv") 5 | df_val = pd.read_csv("../../data/val.csv") 6 | 7 | print(df_train.shape) 8 | print(df_val.shape) 9 | 10 | corrupt_files = [] 11 | 12 | with open("../../data/corrupt_frames_list.txt") as f: 13 | corrupt_files = f.readlines() 14 | 15 | corrupt_files = set(corrupt_files) 16 | print(len(corrupt_files)) 17 | corrupt_files = [c[:-1] for c in corrupt_files] 18 | print(corrupt_files) 19 | 20 | df_train = df_train[~df_train["video_1"].isin(corrupt_files)] 21 | df_val = df_val[~df_val["video_1"].isin(corrupt_files)] 22 | 23 | df_train = df_train[~df_train["video_2"].isin(corrupt_files)] 24 | df_val = df_val[~df_val["video_2"].isin(corrupt_files)] 25 | 
26 | print(df_train.shape) 27 | print(df_val.shape) 28 | 29 | df_train.to_csv("../../data/train.csv", index=False) 30 | 31 | df_val.to_csv("../../data/val.csv", index=False) 32 | -------------------------------------------------------------------------------- /tests/binarize_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from asteroid.binarize import Binarize 4 | 5 | 6 | def test_Binarize(): 7 | # fmt: off 8 | inputs_list = [ 9 | torch.Tensor([0.1, 0.6, 0.2, 0.6, 0.1, 0.1, 0.1, 0.7, 0.7, 0.7, 0.1, 0.7, 0.7, 0.7, 0.1, 10 | 0.8, 0.9, 0.2, 0.7, 0.1, 0.1, 0.1, 0.8, 0.1]), 11 | torch.Tensor([0.1, 0.1, 0.2, 0.1]), 12 | torch.Tensor([0.7, 0.7, 0.7, 0.7]), 13 | torch.Tensor([0.1, 0.7]), 14 | ] 15 | # fmt: on 16 | expected_result_list = [ 17 | torch.Tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), 18 | torch.Tensor([0.0, 0.0, 0.0, 0.0]), 19 | torch.Tensor([1, 1, 1, 1]), 20 | torch.Tensor([0.0, 0.0]), 21 | ] 22 | binarizer = Binarize(0.5, 3, 1) 23 | for i in range(len(inputs_list)): 24 | result = binarizer(inputs_list[i].unsqueeze(0).unsqueeze(0)) 25 | assert torch.allclose(result, expected_result_list[i]) 26 | -------------------------------------------------------------------------------- /egs/wham/ConvTasNet/utils/prepare_python_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Usage ./utils/install_env.sh --install_dir A --asteroid_root B --pip_requires C 3 | install_dir=~ 4 | asteroid_root=../../../../ 5 | pip_requires=../../../requirements.txt # Expects a requirement.txt 6 | 7 | . 
utils/parse_options.sh || exit 1 8 | 9 | mkdir -p $install_dir 10 | cd $install_dir 11 | echo "Download and install latest version of miniconda3 into ${install_dir}" 12 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 13 | 14 | bash Miniconda3-latest-Linux-x86_64.sh -b -p miniconda3 15 | pip_path=$PWD/miniconda3/bin/pip 16 | 17 | rm Miniconda3-latest-Linux-x86_64.sh 18 | cd - 19 | 20 | if [[ ! -z ${pip_requires} ]]; then 21 | $pip_path install -r $pip_requires 22 | fi 23 | $pip_path install soundfile 24 | $pip_path install -e $asteroid_root 25 | #$pip_path install ${asteroid_root}/\[""evaluate""\] 26 | echo -e "\nAsteroid has been installed in editable mode. Feel free to apply your changes !" -------------------------------------------------------------------------------- /egs/wham/DPTNet/utils/prepare_python_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Usage ./utils/install_env.sh --install_dir A --asteroid_root B --pip_requires C 3 | install_dir=~ 4 | asteroid_root=../../../../ 5 | pip_requires=../../../requirements.txt # Expects a requirement.txt 6 | 7 | . utils/parse_options.sh || exit 1 8 | 9 | mkdir -p $install_dir 10 | cd $install_dir 11 | echo "Download and install latest version of miniconda3 into ${install_dir}" 12 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 13 | 14 | bash Miniconda3-latest-Linux-x86_64.sh -b -p miniconda3 15 | pip_path=$PWD/miniconda3/bin/pip 16 | 17 | rm Miniconda3-latest-Linux-x86_64.sh 18 | cd - 19 | 20 | if [[ ! -z ${pip_requires} ]]; then 21 | $pip_path install -r $pip_requires 22 | fi 23 | $pip_path install soundfile 24 | $pip_path install -e $asteroid_root 25 | #$pip_path install ${asteroid_root}/\[""evaluate""\] 26 | echo -e "\nAsteroid has been installed in editable mode. Feel free to apply your changes !" 
27 | -------------------------------------------------------------------------------- /egs/wham/DynamicMixing/utils/prepare_python_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Usage ./utils/install_env.sh --install_dir A --asteroid_root B --pip_requires C 3 | install_dir=~ 4 | asteroid_root=../../../../ 5 | pip_requires=../../../requirements.txt # Expects a requirement.txt 6 | 7 | . utils/parse_options.sh || exit 1 8 | 9 | mkdir -p $install_dir 10 | cd $install_dir 11 | echo "Download and install latest version of miniconda3 into ${install_dir}" 12 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 13 | 14 | bash Miniconda3-latest-Linux-x86_64.sh -b -p miniconda3 15 | pip_path=$PWD/miniconda3/bin/pip 16 | 17 | rm Miniconda3-latest-Linux-x86_64.sh 18 | cd - 19 | 20 | if [[ ! -z ${pip_requires} ]]; then 21 | $pip_path install -r $pip_requires 22 | fi 23 | $pip_path install soundfile 24 | $pip_path install -e $asteroid_root 25 | #$pip_path install ${asteroid_root}/\[""evaluate""\] 26 | echo -e "\nAsteroid has been installed in editable mode. Feel free to apply your changes !" 
-------------------------------------------------------------------------------- /egs/demask/local/parse_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import soundfile as sf 4 | from glob import glob 5 | import argparse 6 | from pathlib import Path 7 | 8 | parser = argparse.ArgumentParser( 9 | "Script to parse data to json in order to avoid parsing each time at beginning of each experiment" 10 | ) 11 | parser.add_argument("--input_dir", type=str) 12 | parser.add_argument("--output_json", type=str) 13 | parser.add_argument("--regex", type=str) 14 | 15 | if __name__ == "__main__": 16 | args = parser.parse_args() 17 | assert os.path.exists(args.input_dir), "Input dir does not exist" 18 | files = glob(os.path.join(args.input_dir, args.regex), recursive=True) 19 | to_json = [] 20 | for f in files: 21 | meta = sf.SoundFile(f) 22 | samples = len(meta) 23 | to_json.append({"file": f, "length": samples}) 24 | 25 | os.makedirs(Path(args.output_json).parent, exist_ok=True) 26 | with open(args.output_json, "w") as f: 27 | json.dump(to_json, f) 28 | -------------------------------------------------------------------------------- /egs/wsj0-mix-var/Multi-Decoder-DPRNN/utils/prepare_python_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Usage ./utils/install_env.sh --install_dir A --asteroid_root B --pip_requires C 3 | install_dir=~ 4 | asteroid_root=../../../../ 5 | pip_requires=../../../requirements.txt # Expects a requirement.txt 6 | 7 | . 
utils/parse_options.sh || exit 1 8 | 9 | mkdir -p $install_dir 10 | cd $install_dir 11 | echo "Download and install latest version of miniconda3 into ${install_dir}" 12 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 13 | 14 | bash Miniconda3-latest-Linux-x86_64.sh -b -p miniconda3 15 | pip_path=$PWD/miniconda3/bin/pip 16 | 17 | rm Miniconda3-latest-Linux-x86_64.sh 18 | cd - 19 | 20 | if [[ ! -z ${pip_requires} ]]; then 21 | $pip_path install -r $pip_requires 22 | fi 23 | $pip_path install soundfile 24 | $pip_path install -e $asteroid_root 25 | #$pip_path install ${asteroid_root}/\[""evaluate""\] 26 | echo -e "\nAsteroid has been installed in editable mode. Feel free to apply your changes !" -------------------------------------------------------------------------------- /docs/source/package_reference/models.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | Models 5 | ====== 6 | 7 | Base classes 8 | ------------ 9 | 10 | .. automodule:: asteroid.models.base_models 11 | :members: 12 | 13 | Ready-to-use models 14 | ------------------- 15 | 16 | .. automodule:: asteroid.models.conv_tasnet 17 | :members: 18 | 19 | .. automodule:: asteroid.models.dccrnet 20 | :members: 21 | 22 | .. automodule:: asteroid.models.dcunet 23 | :members: 24 | 25 | .. automodule:: asteroid.models.demask 26 | :members: 27 | 28 | .. automodule:: asteroid.models.dprnn_tasnet 29 | :members: 30 | 31 | .. automodule:: asteroid.models.dptnet 32 | :members: 33 | 34 | .. automodule:: asteroid.models.lstm_tasnet 35 | :members: 36 | 37 | .. automodule:: asteroid.models.sudormrf 38 | :members: 39 | 40 | 41 | 42 | Publishing models 43 | ----------------- 44 | .. automodule:: asteroid.models.zenodo 45 | :members: 46 | 47 | .. 
automodule:: asteroid.models.publisher 48 | :members: 49 | -------------------------------------------------------------------------------- /egs/demask/local/conf.yml: -------------------------------------------------------------------------------- 1 | filterbank: 2 | fb_type: stft 3 | n_filters: 512 4 | kernel_size: 512 5 | stride: 256 6 | demask_net: 7 | input_type: mag 8 | output_type: mag 9 | hidden_dims: [1024] 10 | dropout: 0 11 | activation: relu 12 | mask_act: relu 13 | norm_type: gLN 14 | data: 15 | fs: 16000 16 | length: 4 17 | clean_speech_train: ./data/clean/train-clean-360.json 18 | clean_speech_valid: ./data/clean/dev-clean.json 19 | rir_train: ./data/rirs/train.json 20 | rir_valid: ./data/rirs/validation.json 21 | optim: 22 | lr: 0.001 23 | weight_decay: !!float 1e-5 24 | training: 25 | epochs: 200 26 | batch_size: 4 27 | gradient_clipping: 5 28 | accumulate_batches: 1 29 | save_top_k: 10 30 | num_workers: 8 31 | patience: 30 32 | half_lr: true 33 | early_stop: true 34 | gaussian_mask_noise_snr_dB: np.random.randint(3, 12) 35 | white_noise_dB: np.random.randint(-3, 30) 36 | speed_augm: np.random.uniform(0.95, 1.05) 37 | gain_augm: np.random.randint(-30, -2) 38 | n_taps: 97 39 | -------------------------------------------------------------------------------- /tests/models/fasnet_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from asteroid.models.fasnet import FasNetTAC 5 | 6 | 7 | @pytest.mark.parametrize("samples", [8372]) 8 | @pytest.mark.parametrize("batch_size", [1, 2]) 9 | @pytest.mark.parametrize("n_mics", [1, 2]) 10 | @pytest.mark.parametrize("n_src", [1, 2, 3]) 11 | @pytest.mark.parametrize("use_tac", [True, False]) 12 | @pytest.mark.parametrize("enc_dim", [4]) 13 | @pytest.mark.parametrize("feature_dim", [8]) 14 | @pytest.mark.parametrize("window", [2]) 15 | @pytest.mark.parametrize("context", [3]) 16 | def test_fasnet(batch_size, n_mics, samples, n_src, 
use_tac, enc_dim, feature_dim, window, context): 17 | mixture = torch.rand((batch_size, n_mics, samples)) 18 | valid_mics = torch.tensor([n_mics for x in range(batch_size)]) 19 | fasnet = FasNetTAC( 20 | n_src, 21 | use_tac=use_tac, 22 | enc_dim=enc_dim, 23 | feature_dim=feature_dim, 24 | window_ms=window, 25 | context_ms=context, 26 | ) 27 | fasnet(mixture, valid_mics) 28 | -------------------------------------------------------------------------------- /asteroid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .generic_utils import ( 2 | average_arrays_in_dic, 3 | flatten_dict, 4 | get_wav_random_start_stop, 5 | has_arg, 6 | unet_decoder_args, 7 | ) 8 | from .parser_utils import ( 9 | prepare_parser_from_dict, 10 | parse_args_as_dict, 11 | str_int_float, 12 | str2bool, 13 | str2bool_arg, 14 | isfloat, 15 | isint, 16 | ) 17 | from .torch_utils import tensors_to_device, to_cuda, get_device 18 | 19 | # The functions above were all in asteroid/utils.py before refactoring into 20 | # asteroid/utils/*_utils.py files. They are imported for backward compatibility. 
21 | 22 | __all__ = [ 23 | "prepare_parser_from_dict", 24 | "parse_args_as_dict", 25 | "str_int_float", 26 | "str2bool", 27 | "str2bool_arg", 28 | "isfloat", 29 | "isint", 30 | "tensors_to_device", 31 | "to_cuda", 32 | "get_device", 33 | "has_arg", 34 | "flatten_dict", 35 | "average_arrays_in_dic", 36 | "get_wav_random_start_stop", 37 | "unet_decoder_args", 38 | ] 39 | -------------------------------------------------------------------------------- /docs/source/_templates/theme_variables.jinja: -------------------------------------------------------------------------------- 1 | {%- set external_urls = { 2 | 'github': 'https://github.com/mpariente/asteroid', 3 | 'github_issues': 'https://github.com/mpariente/asteroid/issues', 4 | 'contributing': 'https://github.com/mpariente/asteroid/blob/master/CONTRIBUTING.md', 5 | 'governance': 'https://github.com/mpariente/asteroid/blob/master/CONTRIBUTING.md', 6 | 'docs': 'https://mpariente.github.io/asteroid/', 7 | 'twitter': 'https://twitter.com/', 8 | 'discuss': 'https://discuss.pytorch.org', 9 | 'tutorials': '', 10 | 'previous_pytorch_versions': 'https://pytorch-lightning.rtfd.io/en/latest/', 11 | 'home': 'https://mpariente.github.io/asteroid/', 12 | 'get_started': 'https://colab.research.google.com/github/mpariente/asteroid/blob/master/notebooks/00_GettingStarted.ipynb', 13 | 'features': 'https://mpariente.github.io/asteroid/', 14 | 'blog': 'https://mpariente.github.io/asteroid/', 15 | 'resources': 'https://mpariente.github.io/asteroid/', 16 | 'support': 'https://mpariente.github.io/asteroid/', 17 | } 18 | -%} 19 | -------------------------------------------------------------------------------- /egs/wsj0-mix-var/Multi-Decoder-DPRNN/separate.py: -------------------------------------------------------------------------------- 1 | import torch, torchaudio 2 | import argparse 3 | import os 4 | from model import MultiDecoderDPRNN 5 | 6 | os.makedirs("outputs", exist_ok=True) 7 | parser = argparse.ArgumentParser() 8 | 
parser.add_argument( 9 | "--wav_file", 10 | type=str, 11 | default="", 12 | help="Path to the wav file to run model inference on.", 13 | ) 14 | args = parser.parse_args() 15 | 16 | mixture, sample_rate = torchaudio.load(args.wav_file) 17 | 18 | model = MultiDecoderDPRNN.from_pretrained("JunzheJosephZhu/MultiDecoderDPRNN").eval() 19 | if torch.cuda.is_available(): 20 | model.cuda() 21 | mixture = mixture.cuda() 22 | sources_est = model.separate(mixture).cpu() 23 | for i, source in enumerate(sources_est): 24 | torchaudio.save(f"outputs/{i}.wav", source[None], sample_rate) 25 | 26 | print( 27 | "Thank you for using Multi-Decoder-DPRNN to separate your mixture files. \ 28 | Please support our work by citing our paper: http://www.isle.illinois.edu/speech_web_lg/pubs/2021/zhu2021multi.pdf" 29 | ) 30 | -------------------------------------------------------------------------------- /docs/source/package_reference/data.rst: -------------------------------------------------------------------------------- 1 | PyTorch Datasets 2 | ================ 3 | 4 | This page lists the supported datasets and their corresponding 5 | PyTorch's ``Dataset`` class. If you're interested in the datasets more 6 | than in the code, see `this page <../supported_datasets.rst>`__. 7 | 8 | .. currentmodule:: asteroid.data 9 | 10 | LibriMix 11 | -------- 12 | .. autoclass:: LibriMix 13 | 14 | Wsj0mix 15 | -------- 16 | .. autoclass:: Wsj0mixDataset 17 | 18 | WHAM! 19 | ------ 20 | .. autoclass:: WhamDataset 21 | 22 | WHAMR! 23 | ------- 24 | .. autoclass:: WhamRDataset 25 | 26 | SMS-WSJ 27 | --------- 28 | .. autoclass:: SmsWsjDataset 29 | 30 | KinectWSJMix 31 | ------------- 32 | .. autoclass:: KinectWsjMixDataset 33 | 34 | DNSDataset 35 | ---------- 36 | .. autoclass:: DNSDataset 37 | 38 | MUSDB18 39 | -------- 40 | .. autoclass:: MUSDB18Dataset 41 | 42 | DAMP-VSEP 43 | --------- 44 | .. autoclass:: DAMPVSEPSinglesDataset 45 | 46 | FUSS 47 | ---- 48 | .. 
autoclass:: FUSSDataset 49 | 50 | AVSpeech 51 | -------- 52 | .. autoclass:: AVSpeechDataset 53 | -------------------------------------------------------------------------------- /model_card_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | tags: 3 | - asteroid 4 | - audio 5 | - [model_name] 6 | datasets: 7 | - [dataset_name] 8 | - [task_name] 9 | license: cc-by-sa-3.0 10 | inference: false 11 | --- 12 | 13 | Fill in all the field within brackets []. 14 | 15 | 16 | ## Description: 17 | This model was trained by [your name] using the [the recipe name] recipe in Asteroid. 18 | It was trained on the `[task name]` task of the [dataset name] dataset. 19 | 20 | 21 | ## Training config: 22 | ```yaml 23 | Paste here the content of conf.yml 24 | ``` 25 | 26 | 27 | ## Results: 28 | ```yaml 29 | Paste here the content of final_metrics.json 30 | ``` 31 | 32 | 33 | ## License notice: 34 | 35 | ** This is important, please fill it, if you need help, you can ask on Asteroid's slack.** 36 | 37 | This work "[the name of the repo]" 38 | is a derivative of [the dataset]() by 39 | [the author](), 40 | used under [the license](). 41 | "[the name of the repo]" 42 | is licensed under [Attribution-ShareAlike 3.0 Unported](https://creativecommons.org/licenses/by-sa/3.0/) 43 | by [your name]. 44 | -------------------------------------------------------------------------------- /docs/source/package_reference/optimizers.rst: -------------------------------------------------------------------------------- 1 | Optimizers & Schedulers 2 | ======================= 3 | 4 | Optimizers 5 | ---------- 6 | 7 | Asteroid relies on `torch_optimizer `_ and 8 | ``torch`` for optimizers. 9 | We provide a simple ``get`` method that retrieves optimizers from string, 10 | which makes it easy to specify optimizers from the command line. 
11 | 12 | Here is a list of supported optimizers, retrievable from string: 13 | 14 | - AccSGD 15 | - AdaBound 16 | - AdaMod 17 | - DiffGrad 18 | - Lamb 19 | - NovoGrad 20 | - PID 21 | - QHAdam 22 | - QHM 23 | - RAdam 24 | - SGDW 25 | - Yogi 26 | - Ranger 27 | - RangerQH 28 | - RangerVA 29 | - Adam 30 | - RMSprop 31 | - SGD 32 | - Adadelta 33 | - Adagrad 34 | - Adamax 35 | - AdamW 36 | - ASG 37 | 38 | .. automodule:: asteroid.engine.optimizers 39 | :members: 40 | 41 | 42 | Schedulers 43 | ---------- 44 | 45 | Asteroid provides step-wise learning schedulers, integrable to 46 | ``pytorch-lightning`` via ``System``. 47 | 48 | .. automodule:: asteroid.engine.schedulers 49 | :members: 50 | -------------------------------------------------------------------------------- /egs/wham/ConvTasNet/README.md: -------------------------------------------------------------------------------- 1 | ### Results 2 | 3 | | | task | n_blocks | n_repeats | batch size |SI-SNRi(dB) | SDRi(dB)| 4 | |:----:|:---------:|:--------:|:---------:|:----------:|:----------:|:-------:| 5 | | Paper| sep_clean | 8 | 3 | - | 15.3 | 15.6 | 6 | | Here | sep_clean | 8 | 3 | 12 | 16.2 | 16.5 | 7 | 8 | 9 | ### References 10 | If you use this model, please cite the original work. 11 | ```BibTex 12 | @article{Luo_2019, 13 | title={Conv-TasNet: Surpassing Ideal Time–Frequency Magnitude Masking for Speech Separation}, 14 | volume={27}, 15 | ISSN={2329-9304}, 16 | url={http://dx.doi.org/10.1109/TASLP.2019.2915167}, 17 | DOI={10.1109/taslp.2019.2915167}, 18 | journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing}, 19 | publisher={Institute of Electrical and Electronics Engineers (IEEE)}, 20 | author={Luo, Yi and Mesgarani, Nima}, 21 | year={2019}, 22 | month={Aug}, 23 | pages={1256–1266} 24 | } 25 | ``` 26 | 27 | and if you like using `asteroid` you can give us a star! 
class DummyDataset(data.Dataset):
    """Tiny random-tensor dataset used for tests.

    Each item is a ``(input, output)`` pair of shape ``(1, 10)`` each;
    the dataset always advertises 20 items. Values are freshly sampled
    on every access, so the index is ignored.
    """

    def __init__(self):
        # Input and output feature dimensions (both fixed to 10).
        self.inp_dim = 10
        self.out_dim = 10

    def __len__(self):
        # Fixed number of examples, independent of any state.
        return 20

    def __getitem__(self, idx):
        # `idx` is intentionally unused: every access yields new random data.
        inp = torch.randn(1, self.inp_dim)
        out = torch.randn(1, self.out_dim)
        return inp, out
#!/bin/bash
# Download the WHAM noise corpus and generation scripts, then build the
# WHAM mixtures from a local wsj0 copy.
# Options (overridable on the command line via utils/parse_options.sh):
#   --wav_dir     : root of the wsj0 wav files
#   --out_dir     : where noises, scripts and mixtures are written
#   --python_path : python interpreter to use

wav_dir=tmp
out_dir=tmp
python_path=python

. utils/parse_options.sh

## Download WHAM noises
mkdir -p $out_dir
echo "Download WHAM noises into $out_dir"
# If downloading stalls for more than 20s, relaunch from previous state.
wget -c --tries=0 --read-timeout=20 https://storage.googleapis.com/whisper-public/wham_noise.zip -P $out_dir

echo "Download WHAM scripts into $out_dir"
wget https://storage.googleapis.com/whisper-public/wham_scripts.tar.gz -P $out_dir
mkdir -p $out_dir/wham_scripts
tar -xzvf $out_dir/wham_scripts.tar.gz -C $out_dir/wham_scripts
mv $out_dir/wham_scripts.tar.gz $out_dir/wham_scripts

wait

# Fix: the log directory must exist before redirecting into it, and unzip
# needs -d to treat $out_dir as the extraction destination (without -d the
# second argument is interpreted as a member pattern inside the archive).
mkdir -p logs
unzip $out_dir/wham_noise.zip -d $out_dir >> logs/unzip_wham.log

echo "Run python scripts to create the WHAM mixtures"
# Requires : Numpy, Scipy, Pandas, and Pysoundfile
# Fix: the scripts land one level deeper after extraction (tar -C into
# $out_dir/wham_scripts), matching the path used by the sibling WHAM recipes.
cd $out_dir/wham_scripts/wham_scripts
$python_path create_wham_from_scratch.py \
    --wsj0-root $wav_dir \
    --wham-noise-root $out_dir/wham_noise\
    --output-dir $out_dir
cd -
#!/bin/bash
# Download the WHAM noise corpus and generation scripts, then build the
# WHAM mixtures from a local wsj0 copy.
# Options (overridable on the command line via utils/parse_options.sh):
#   --wav_dir     : root of the wsj0 wav files
#   --out_dir     : where noises, scripts and mixtures are written
#   --python_path : python interpreter to use

wav_dir=tmp
out_dir=tmp
python_path=python

. utils/parse_options.sh

## Download WHAM noises
mkdir -p $out_dir
echo "Download WHAM noises into $out_dir"
# If downloading stalls for more than 20s, relaunch from previous state.
wget -c --tries=0 --read-timeout=20 https://storage.googleapis.com/whisper-public/wham_noise.zip -P $out_dir

echo "Download WHAM scripts into $out_dir"
wget https://storage.googleapis.com/whisper-public/wham_scripts.tar.gz -P $out_dir
mkdir -p $out_dir/wham_scripts
tar -xzvf $out_dir/wham_scripts.tar.gz -C $out_dir/wham_scripts
mv $out_dir/wham_scripts.tar.gz $out_dir/wham_scripts

wait

# Fix: create the log directory before redirecting into it, and pass -d so
# unzip extracts into $out_dir (without -d the second argument is read as a
# member pattern inside the archive, not a destination directory).
mkdir -p logs
unzip $out_dir/wham_noise.zip -d $out_dir >> logs/unzip_wham.log

echo "Run python scripts to create the WHAM mixtures"
# Requires : Numpy, Scipy, Pandas, and Pysoundfile
cd $out_dir/wham_scripts/wham_scripts
$python_path create_wham_from_scratch.py \
    --wsj0-root $wav_dir \
    --wham-noise-root $out_dir/wham_noise\
    --output-dir $out_dir
cd -
#!/bin/bash
# Download the WHAM noise corpus and generation scripts, then build the
# WHAM mixtures from a local wsj0 copy.
# Options (overridable on the command line via utils/parse_options.sh):
#   --wav_dir     : root of the wsj0 wav files
#   --out_dir     : where noises, scripts and mixtures are written
#   --python_path : python interpreter to use

wav_dir=/mnt/data/wsj0-mix/wsj0
out_dir=/mnt/data/wham
python_path=python

. utils/parse_options.sh

## Download WHAM noises
mkdir -p $out_dir
echo "Download WHAM noises into $out_dir"
# If downloading stalls for more than 20s, relaunch from previous state.
wget -c --tries=0 --read-timeout=20 https://storage.googleapis.com/whisper-public/wham_noise.zip -P $out_dir
unzip $out_dir/wham_noise.zip -d $out_dir

echo "Download WHAM scripts into $out_dir"
wget https://storage.googleapis.com/whisper-public/wham_scripts.tar.gz -P $out_dir
mkdir -p $out_dir/wham_scripts
tar -xzvf $out_dir/wham_scripts.tar.gz -C $out_dir/wham_scripts
# Keep the downloaded tarball next to the extracted scripts.
mv $out_dir/wham_scripts.tar.gz $out_dir/wham_scripts

wait

echo "Run python scripts to create the WHAM mixtures"
# Requires : Numpy, Scipy, Pandas, and Pysoundfile
# NOTE(review): the doubled path suggests the tarball itself contains a
# top-level wham_scripts/ folder after the tar -C above — confirm if it moves.
cd $out_dir/wham_scripts/wham_scripts
$python_path create_wham_from_scratch.py \
    --wsj0-root $wav_dir \
    --wham-noise-root $out_dir/wham_noise\
    --output-dir $out_dir
cd -
46 | #global_db_range: [-50, 0] 47 | #abs_stats: [-16.7, 7] 48 | #rel_stats: [2.52, 4] 49 | #noise_stats: [5.1, 6.4] 50 | speed_perturb: [1, 1] 51 | 52 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Pariente Manuel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/source/package_reference/dsp.rst: -------------------------------------------------------------------------------- 1 | DSP Modules 2 | =========== 3 | 4 | .. role:: hidden 5 | :class: hidden-section 6 | 7 | 8 | 9 | :hidden:`Beamforming` 10 | ~~~~~~~~~~~~~~~~~~~~~~~~ 11 | 12 | .. autoclass:: asteroid.dsp.beamforming.Beamformer 13 | .. autoclass:: asteroid.dsp.beamforming.SDWMWFBeamformer 14 | .. 
autoclass:: asteroid.dsp.beamforming.GEVBeamformer 15 | .. autoclass:: asteroid.dsp.beamforming.RTFMVDRBeamformer 16 | .. autoclass:: asteroid.dsp.beamforming.SoudenMVDRBeamformer 17 | .. autoclass:: asteroid.dsp.beamforming.SCM 18 | 19 | :hidden:`LambdaOverlapAdd` 20 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 21 | .. autoclass:: asteroid.dsp.LambdaOverlapAdd 22 | :members: 23 | 24 | :hidden:`DualPath Processing` 25 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 26 | .. autoclass:: asteroid.dsp.DualPathProcessing 27 | :members: 28 | 29 | :hidden:`Mixture Consistency` 30 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 31 | .. autofunction:: asteroid.dsp.mixture_consistency 32 | 33 | :hidden:`VAD` 34 | ~~~~~~~~~~~~~~~~ 35 | .. autofunction:: asteroid.dsp.vad.ebased_vad 36 | 37 | :hidden:`Delta Features` 38 | ~~~~~~~~~~~~~~~~~~~~~~~~ 39 | .. automodule:: asteroid.dsp.deltas 40 | -------------------------------------------------------------------------------- /docs/source/supported_datasets.rst: -------------------------------------------------------------------------------- 1 | Datasets and tasks 2 | ================== 3 | The following is a list of supported datasets, sorted by task. 4 | If you're more interested in the corresponding PyTorch ``Dataset``, see 5 | `this page `__ 6 | 7 | Speech separation 8 | ----------------- 9 | 10 | .. mdinclude:: readmes/wsj0-mix_README.md 11 | .. mdinclude:: readmes/wham_README.md 12 | .. mdinclude:: readmes/whamr_README.md 13 | .. mdinclude:: readmes/librimix_README.md 14 | .. mdinclude:: readmes/kinect-wsj_README.md 15 | .. mdinclude:: readmes/sms_wsj_README.md 16 | 17 | Speech enhancement 18 | ------------------ 19 | .. mdinclude:: readmes/dns_challenge_README.md 20 | 21 | 22 | Music source separation 23 | ----------------------- 24 | .. mdinclude:: readmes/musdb18_README.md 25 | .. mdinclude:: readmes/dampvsep_README.md 26 | 27 | 28 | Environmental sound separation 29 | ------------------------------ 30 | .. 
def create_local_metadata(librimix_dir):
    """Mirror the LibriMix ``mix*`` metadata files under a local ``data/`` tree.

    Scans ``<librimix_dir>/*/*/*`` for directories named ``metadata`` and
    copies every file whose name starts with ``mix`` into
    ``data/<relative path>/<subset>``, where ``<subset>`` is the second
    ``_``-separated token of the file name and the trailing ``/metadata``
    path component is dropped.
    """
    pattern = os.path.join(librimix_dir, "*/*/*")
    for md_dir in (d for d in glob(pattern) if d.endswith("metadata")):
        for md_file in os.listdir(md_dir):
            if not md_file.startswith("mix"):
                continue
            subset = md_file.split("_")[1]
            rel_dir = os.path.relpath(md_dir, librimix_dir)
            # Drop the "metadata" component and group files by subset.
            local_path = os.path.join("data", rel_dir, subset).replace("/metadata", "")
            os.makedirs(local_path, exist_ok=True)
            shutil.copy(os.path.join(md_dir, md_file), local_path)
librimix_dir = args.librimix_dir 15 | create_local_metadata(librimix_dir) 16 | 17 | 18 | def create_local_metadata(librimix_dir): 19 | 20 | md_dirs = [f for f in glob(os.path.join(librimix_dir, "*/*/*")) if f.endswith("metadata")] 21 | for md_dir in md_dirs: 22 | md_files = [f for f in os.listdir(md_dir) if f.startswith("mix")] 23 | for md_file in md_files: 24 | subset = md_file.split("_")[1] 25 | local_path = os.path.join( 26 | "data", os.path.relpath(md_dir, librimix_dir), subset 27 | ).replace("/metadata", "") 28 | os.makedirs(local_path, exist_ok=True) 29 | shutil.copy(os.path.join(md_dir, md_file), local_path) 30 | 31 | 32 | if __name__ == "__main__": 33 | args = parser.parse_args() 34 | main(args) 35 | -------------------------------------------------------------------------------- /egs/librimix/DPTNet/local/create_local_metadata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | from glob import glob 5 | 6 | # Command line arguments 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--librimix_dir", type=str, default=None, help="Path to librispeech root directory" 10 | ) 11 | 12 | 13 | def main(args): 14 | librimix_dir = args.librimix_dir 15 | create_local_metadata(librimix_dir) 16 | 17 | 18 | def create_local_metadata(librimix_dir): 19 | 20 | md_dirs = [f for f in glob(os.path.join(librimix_dir, "*/*/*")) if f.endswith("metadata")] 21 | for md_dir in md_dirs: 22 | md_files = [f for f in os.listdir(md_dir) if f.startswith("mix")] 23 | for md_file in md_files: 24 | subset = md_file.split("_")[1] 25 | local_path = os.path.join( 26 | "data", os.path.relpath(md_dir, librimix_dir), subset 27 | ).replace("/metadata", "") 28 | os.makedirs(local_path, exist_ok=True) 29 | shutil.copy(os.path.join(md_dir, md_file), local_path) 30 | 31 | 32 | if __name__ == "__main__": 33 | args = parser.parse_args() 34 | main(args) 35 | 
-------------------------------------------------------------------------------- /egs/librimix/ConvTasNet/local/create_local_metadata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | from glob import glob 5 | 6 | # Command line arguments 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--librimix_dir", type=str, default=None, help="Path to librispeech root directory" 10 | ) 11 | 12 | 13 | def main(args): 14 | librimix_dir = args.librimix_dir 15 | create_local_metadata(librimix_dir) 16 | 17 | 18 | def create_local_metadata(librimix_dir): 19 | 20 | md_dirs = [f for f in glob(os.path.join(librimix_dir, "*/*/*")) if f.endswith("metadata")] 21 | for md_dir in md_dirs: 22 | md_files = [f for f in os.listdir(md_dir) if f.startswith("mix")] 23 | for md_file in md_files: 24 | subset = md_file.split("_")[1] 25 | local_path = os.path.join( 26 | "data", os.path.relpath(md_dir, librimix_dir), subset 27 | ).replace("/metadata", "") 28 | os.makedirs(local_path, exist_ok=True) 29 | shutil.copy(os.path.join(md_dir, md_file), local_path) 30 | 31 | 32 | if __name__ == "__main__": 33 | args = parser.parse_args() 34 | main(args) 35 | -------------------------------------------------------------------------------- /egs/librimix/DPRNNTasNet/local/create_local_metadata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | from glob import glob 5 | 6 | # Command line arguments 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--librimix_dir", type=str, default=None, help="Path to librispeech root directory" 10 | ) 11 | 12 | 13 | def main(args): 14 | librimix_dir = args.librimix_dir 15 | create_local_metadata(librimix_dir) 16 | 17 | 18 | def create_local_metadata(librimix_dir): 19 | 20 | md_dirs = [f for f in glob(os.path.join(librimix_dir, "*/*/*")) if f.endswith("metadata")] 
@pytest.mark.parametrize("seq_len_input", [1390])
@pytest.mark.parametrize("seq_len_ref", [1390, 1290])
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize("n_mics_input", [1])
@pytest.mark.parametrize("n_mics_ref", [1, 2])
@pytest.mark.parametrize("normalized", [False, True])
def test_xcorr(seq_len_input, seq_len_ref, batch_size, n_mics_input, n_mics_ref, normalized):
    """Check xcorr output length and, in the unnormalized case, agreement
    with ``numpy.correlate`` for every (batch, mic) pair."""
    target = torch.rand((batch_size, n_mics_input, seq_len_input))
    ref = torch.rand((batch_size, n_mics_ref, seq_len_ref))
    result = xcorr(target, ref, normalized)
    # Valid-mode correlation output length.
    expected_len = seq_len_input - seq_len_ref + 1
    assert result.shape[-1] == expected_len

    if not normalized:
        # Unnormalized xcorr should track numpy's valid-mode correlate.
        for batch_idx in range(batch_size):
            for mic_idx in range(n_mics_input):
                reference = np.correlate(
                    target[batch_idx, mic_idx].numpy(), ref[batch_idx, mic_idx].numpy()
                )
                np.testing.assert_array_almost_equal(
                    result[batch_idx, mic_idx, : len(reference)].numpy(),
                    reference,
                    decimal=2,
                )
os.path.relpath(md_dir, librimix_dir), subset 27 | ).replace("/metadata", "") 28 | os.makedirs(local_path, exist_ok=True) 29 | shutil.copy(os.path.join(md_dir, md_file), local_path) 30 | 31 | 32 | if __name__ == "__main__": 33 | args = parser.parse_args() 34 | main(args) 35 | -------------------------------------------------------------------------------- /egs/sms_wsj/CaCGMM/README.md: -------------------------------------------------------------------------------- 1 | ### Mixture model based source separation on SMS-WSJ 2 | 3 | Configuration: 4 | beamformer = 'mvdr_souden' 5 | mask_estimator = 'cacgmm' 6 | postfilter = None 7 | stft_shift = 128 8 | stft_size = 512 9 | stft_window = 'hann' 10 | 11 | evaluation metric | cv_dev93 | test_eval92 12 | :-------------------|--------------:|--------------: 13 | PESQ | 2.068 | 2.187 14 | STOI | 0.820 | 0.800 15 | mir_eval SDR | 12.34 | 12.11 16 | invasive SDR | 15.74 | 15.47 17 | 18 | ### References 19 | ```BibTex 20 | @Article{SmsWsj19, 21 | author = {Drude, Lukas and Heitkaemper, Jens and Boeddeker, Christoph and Haeb-Umbach, Reinhold}, 22 | title = {{SMS-WSJ}: Database, performance measures, and baseline recipe for multi-channel source separation and recognition}, 23 | journal = {arXiv preprint arXiv:1910.13934}, 24 | year = {2019}, 25 | } 26 | ``` -------------------------------------------------------------------------------- /egs/wham/DynamicMixing/local/resample_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from glob import glob 4 | from distutils.dir_util import copy_tree 5 | from scipy.signal import resample_poly 6 | import soundfile as sf 7 | 8 | parser = argparse.ArgumentParser("Script for resampling a dataset") 9 | parser.add_argument("source_dir", type=str) 10 | parser.add_argument("out_dir", type=str) 11 | parser.add_argument("original_sr", type=int) 12 | parser.add_argument("target_sr", type=int) 13 | 
def main(out_dir, original_sr, target_sr, extension):
    """Resample every matching audio file under ``out_dir`` in place.

    Recursively finds ``*.<extension>`` files, checks they are at
    ``original_sr`` Hz and rewrites them resampled to ``target_sr`` Hz.

    Args:
        out_dir (str): Root directory to walk; files are overwritten.
        original_sr (int): Expected sample rate of the input files.
        target_sr (int): Desired output sample rate (must not exceed
            ``original_sr``; upsampling is not supported).
        extension (str): Audio file extension to match, e.g. ``"wav"``.

    Raises:
        ValueError: If upsampling is requested, or if a file's actual
            sample rate does not match ``original_sr``.
    """
    # Use explicit exceptions instead of `assert`: asserts are stripped when
    # Python runs with -O, which would silently disable these checks.
    if original_sr < target_sr:
        raise ValueError("Upsampling not supported")
    wavs = glob(os.path.join(out_dir, "**/*.{}".format(extension)), recursive=True)
    for wav in wavs:
        data, fs = sf.read(wav)
        if fs != original_sr:
            raise ValueError(
                "Expected {} Hz but found {} Hz in {}".format(original_sr, fs, wav)
            )
        data = resample_poly(data, target_sr, fs)
        sf.write(wav, data, samplerate=target_sr)
@script_if_tracing
def ebased_vad(mag_spec, th_db: int = 40):
    """Compute energy-based VAD from a magnitude spectrogram (or equivalent).

    Args:
        mag_spec (torch.Tensor): the spectrogram to perform VAD on.
            Expected shape (batch, *, freq, time).
            The VAD mask is computed independently over all leading
            dimensions up to the last two, and is independent of the
            ordering of those last two dimensions.
        th_db (int): Threshold in dB below the per-utterance peak under
            which a TF-bin is considered silent.

    Returns:
        :class:`torch.BoolTensor`, the VAD mask.


    Examples
        >>> import torch
        >>> mag_spec = torch.abs(torch.randn(10, 2, 65, 16))
        >>> batch_src_mask = ebased_vad(mag_spec)
    """
    log_mag = 20 * torch.log10(mag_spec)
    # Flatten the last two (freq, time) dims so the peak is found per utterance.
    flat_shape = list(mag_spec.shape[:-2]) + [1, -1]
    peak = torch.max(log_mag.view(flat_shape), -1, keepdim=True).values
    # A bin is active when it lies within th_db of its utterance's peak energy.
    return log_mag > (peak - th_db)
#!/bin/bash
# MIT Copyright (c) 2018 Kaituo XU
# Convert wsj0 sphere (.wv*) files to wav using sph2pipe.
# Options (overridable via utils/parse_options.sh):
#   --sphere_dir : root of the wsj0 sphere files
#   --wav_dir    : destination root for the converted wav files


sphere_dir=tmp
wav_dir=tmp

. utils/parse_options.sh || exit 1;


# Fetch and build sph2pipe once, shared across recipes under egs/tools.
echo "Download sph2pipe_v2.5 into egs/tools"
mkdir -p ../../tools
wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P ../../tools
cd ../../tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm && cd -

echo "Convert sphere format to wav format"
sph2pipe=../../tools/sph2pipe_v2.5/sph2pipe

# Bail out early if the build above failed or the binary is not executable.
if [ ! -x $sph2pipe ]; then
  echo "Could not find (or execute) the sph2pipe program at $sph2pipe";
  exit 1;
fi

tmp=data/local/
mkdir -p $tmp

# List the sphere files of the train/dev/test partitions once; the list is
# cached in $tmp/sph.list and reused on subsequent runs.
[ ! -f $tmp/sph.list ] && find $sphere_dir -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > $tmp/sph.list

if [ ! -d $wav_dir ]; then
  while read line; do
    # Map each sphere path to <wav_dir>/<last three path components>, with
    # the wv1 extension rewritten to wav.
    wav=`echo "$line" | sed "s:wv1:wav:g" | awk -v dir=$wav_dir -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}'`
    echo $wav
    mkdir -p `dirname $wav`
    $sph2pipe -f wav $line > $wav
  done < $tmp/sph.list > $tmp/wav.list
else
  # $wav_dir existing is taken as "conversion already done".
  echo "Do you already get wav files? if not, please remove $wav_dir"
fi
#!/bin/bash
# MIT Copyright (c) 2018 Kaituo XU
# Convert WSJ0 sphere (.wv*) files to wav with sph2pipe.


sphere_dir=tmp
wav_dir=tmp

. utils/parse_options.sh || exit 1;


echo "Download sph2pipe_v2.5 into egs/tools"
mkdir -p ../../tools
wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P ../../tools
cd ../../tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm && cd -

echo "Convert sphere format to wav format"
sph2pipe=../../tools/sph2pipe_v2.5/sph2pipe

if [ ! -x "$sph2pipe" ]; then
  echo "Could not find (or execute) the sph2pipe program at $sph2pipe";
  exit 1;
fi

tmp=data/local/
mkdir -p $tmp

# Build the sphere file list once; re-runs reuse the existing list.
[ ! -f $tmp/sph.list ] && find $sphere_dir -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > $tmp/sph.list

if [ ! -d $wav_dir ]; then
  # `read -r` keeps backslashes intact; quote paths so whitespace survives.
  while read -r line; do
    wav=$(echo "$line" | sed "s:wv1:wav:g" | awk -v dir=$wav_dir -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}')
    echo "$wav"
    mkdir -p "$(dirname "$wav")"
    $sph2pipe -f wav "$line" > "$wav"
  done < $tmp/sph.list > $tmp/wav.list
else
  echo "Do you already get wav files? if not, please remove $wav_dir"
fi
utils/parse_options.sh || exit 1; 9 | 10 | 11 | echo "Download sph2pipe_v2.5 into egs/tools" 12 | mkdir -p ../../tools 13 | wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P ../../tools 14 | cd ../../tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm && cd - 15 | 16 | echo "Convert sphere format to wav format" 17 | sph2pipe=../../tools/sph2pipe_v2.5/sph2pipe 18 | 19 | if [ ! -x $sph2pipe ]; then 20 | echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; 21 | exit 1; 22 | fi 23 | 24 | tmp=data/local/ 25 | mkdir -p $tmp 26 | 27 | [ ! -f $tmp/sph.list ] && find $sphere_dir -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > $tmp/sph.list 28 | 29 | if [ ! -d $wav_dir ]; then 30 | while read line; do 31 | wav=`echo "$line" | sed "s:wv1:wav:g" | awk -v dir=$wav_dir -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}'` 32 | echo $wav 33 | mkdir -p `dirname $wav` 34 | $sph2pipe -f wav $line > $wav 35 | done < $tmp/sph.list > $tmp/wav.list 36 | else 37 | echo "Do you already get wav files? if not, please remove $wav_dir" 38 | fi 39 | -------------------------------------------------------------------------------- /egs/whamr/TasNet/local/convert_sphere2wav.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # MIT Copyright (c) 2018 Kaituo XU 3 | 4 | 5 | sphere_dir=tmp 6 | wav_dir=tmp 7 | 8 | . utils/parse_options.sh || exit 1; 9 | 10 | 11 | echo "Download sph2pipe_v2.5 into egs/tools" 12 | mkdir -p ../../tools 13 | wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P ../../tools 14 | cd ../../tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm && cd - 15 | 16 | echo "Convert sphere format to wav format" 17 | sph2pipe=../../tools/sph2pipe_v2.5/sph2pipe 18 | 19 | if [ ! 
-x $sph2pipe ]; then 20 | echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; 21 | exit 1; 22 | fi 23 | 24 | tmp=data/local/ 25 | mkdir -p $tmp 26 | 27 | [ ! -f $tmp/sph.list ] && find $sphere_dir -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > $tmp/sph.list 28 | 29 | if [ ! -d $wav_dir ]; then 30 | while read line; do 31 | wav=`echo "$line" | sed "s:wv1:wav:g" | awk -v dir=$wav_dir -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}'` 32 | echo $wav 33 | mkdir -p `dirname $wav` 34 | $sph2pipe -f wav $line > $wav 35 | done < $tmp/sph.list > $tmp/wav.list 36 | else 37 | echo "Do you already get wav files? if not, please remove $wav_dir" 38 | fi 39 | -------------------------------------------------------------------------------- /egs/wham/ConvTasNet/local/convert_sphere2wav.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # MIT Copyright (c) 2018 Kaituo XU 3 | 4 | 5 | sphere_dir=tmp 6 | wav_dir=tmp 7 | 8 | . utils/parse_options.sh || exit 1; 9 | 10 | 11 | echo "Download sph2pipe_v2.5 into egs/tools" 12 | mkdir -p ../../tools 13 | wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P ../../tools 14 | cd ../../tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm && cd - 15 | 16 | echo "Convert sphere format to wav format" 17 | sph2pipe=../../tools/sph2pipe_v2.5/sph2pipe 18 | 19 | if [ ! -x $sph2pipe ]; then 20 | echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; 21 | exit 1; 22 | fi 23 | 24 | tmp=data/local/ 25 | mkdir -p $tmp 26 | 27 | [ ! -f $tmp/sph.list ] && find $sphere_dir -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > $tmp/sph.list 28 | 29 | if [ ! 
-d $wav_dir ]; then 30 | while read line; do 31 | wav=`echo "$line" | sed "s:wv1:wav:g" | awk -v dir=$wav_dir -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}'` 32 | echo $wav 33 | mkdir -p `dirname $wav` 34 | $sph2pipe -f wav $line > $wav 35 | done < $tmp/sph.list > $tmp/wav.list 36 | else 37 | echo "Do you already get wav files? if not, please remove $wav_dir" 38 | fi 39 | -------------------------------------------------------------------------------- /egs/wham/DynamicMixing/local/convert_sphere2wav.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # MIT Copyright (c) 2018 Kaituo XU 3 | 4 | 5 | sphere_dir=tmp 6 | wav_dir=tmp 7 | 8 | . utils/parse_options.sh || exit 1; 9 | 10 | 11 | echo "Download sph2pipe_v2.5 into egs/tools" 12 | mkdir -p ../../tools 13 | wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P ../../tools 14 | cd ../../tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm && cd - 15 | 16 | echo "Convert sphere format to wav format" 17 | sph2pipe=../../tools/sph2pipe_v2.5/sph2pipe 18 | 19 | if [ ! -x $sph2pipe ]; then 20 | echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; 21 | exit 1; 22 | fi 23 | 24 | tmp=data/local/ 25 | mkdir -p $tmp 26 | 27 | [ ! -f $tmp/sph.list ] && find $sphere_dir -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > $tmp/sph.list 28 | 29 | if [ ! -d $wav_dir ]; then 30 | while read line; do 31 | wav=`echo "$line" | sed "s:wv1:wav:g" | awk -v dir=$wav_dir -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}'` 32 | echo $wav 33 | mkdir -p `dirname $wav` 34 | $sph2pipe -f wav $line > $wav 35 | done < $tmp/sph.list > $tmp/wav.list 36 | else 37 | echo "Do you already get wav files? 
if not, please remove $wav_dir" 38 | fi 39 | -------------------------------------------------------------------------------- /egs/kinect-wsj/DeepClustering/local/convert_sphere2wav.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # MIT Copyright (c) 2018 Kaituo XU 3 | 4 | 5 | sphere_dir=tmp 6 | wav_dir=tmp 7 | 8 | . utils/parse_options.sh || exit 1; 9 | 10 | 11 | echo "Download sph2pipe_v2.5 into egs/tools" 12 | mkdir -p ../../tools 13 | wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P ../../tools 14 | cd ../../tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm && cd - 15 | 16 | echo "Convert sphere format to wav format" 17 | sph2pipe=../../tools/sph2pipe_v2.5/sph2pipe 18 | 19 | if [ ! -x $sph2pipe ]; then 20 | echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; 21 | exit 1; 22 | fi 23 | 24 | tmp=data/local/ 25 | mkdir -p $tmp 26 | 27 | [ ! -f $tmp/sph.list ] && find $sphere_dir -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > $tmp/sph.list 28 | 29 | if [ ! -d $wav_dir ]; then 30 | while read line; do 31 | wav=`echo "$line" | sed "s:wv1:wav:g" | awk -v dir=$wav_dir -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}'` 32 | echo $wav 33 | mkdir -p `dirname $wav` 34 | $sph2pipe -f wav $line > $wav 35 | done < $tmp/sph.list > $tmp/wav.list 36 | else 37 | echo "Do you already get wav files? if not, please remove $wav_dir" 38 | fi 39 | -------------------------------------------------------------------------------- /egs/wham/TwoStep/README.md: -------------------------------------------------------------------------------- 1 | ### Description 2 | A two-step training procedure for source 3 | separation via a deep neural network. In the first step we learn a 4 | transform (and it’s inverse) to a latent space where masking-based 5 | separation performance using oracles is optimal. 
For the second step, 6 | we train a separation module that operates on the previously learned 7 | space. 8 | 9 | ### Results 10 | 11 | | | Task | n_blocks | n_repeats | batch size |SI-SNRi(dB) | 12 | |:----:|:---------:|:--------:|:---------:|:----------:|:----------:| 13 | | Paper| sep_clean | 8 | 4 | - | 16.10 | 14 | | Here | sep_clean | 8 | 4 | - | 15.23 | 15 | 16 | ### References 17 | If you use this model, please cite the original work. 18 | ```BibTex 19 | @article{tzinis2019two, 20 | title={Two-Step Sound Source Separation: Training on Learned Latent Targets}, 21 | author={Tzinis, Efthymios and Venkataramani, Shrikant and Wang, Zhepei and Subakan, Cem and Smaragdis, Paris}, 22 | booktitle={ICASSP 2020-2020 IEEE International Conference on 23 | Acoustics, Speech and Signal Processing (ICASSP)}, 24 | pages={}, 25 | year={2020}, 26 | organization={IEEE} 27 | } 28 | ``` 29 | 30 | and if you like using `asteroid` you can give us a star! :star: 31 | -------------------------------------------------------------------------------- /egs/wsj0-mix/DeepClustering/local/convert_sphere2wav.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # MIT Copyright (c) 2018 Kaituo XU 3 | 4 | 5 | sphere_dir=tmp 6 | wav_dir=tmp 7 | 8 | . utils/parse_options.sh || exit 1; 9 | 10 | 11 | echo "Download sph2pipe_v2.5 into egs/tools" 12 | mkdir -p ../../tools 13 | wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P ../../tools 14 | cd ../../tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm && cd - 15 | 16 | echo "Convert sphere format to wav format" 17 | sph2pipe=../../tools/sph2pipe_v2.5/sph2pipe 18 | 19 | if [ ! -x $sph2pipe ]; then 20 | echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; 21 | exit 1; 22 | fi 23 | 24 | tmp=data/local/ 25 | mkdir -p $tmp 26 | 27 | [ ! 
-f $tmp/sph.list ] && find $sphere_dir -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > $tmp/sph.list 28 | 29 | if [ ! -d $wav_dir ]; then 30 | while read line; do 31 | wav=`echo "$line" | sed "s:wv1:wav:g" | awk -v dir=$wav_dir -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}'` 32 | echo $wav 33 | mkdir -p `dirname $wav` 34 | $sph2pipe -f wav $line > $wav 35 | done < $tmp/sph.list > $tmp/wav.list 36 | else 37 | echo "Do you already get wav files? if not, please remove $wav_dir" 38 | fi 39 | -------------------------------------------------------------------------------- /egs/dns_challenge_INTERSPEECH2020/README.md: -------------------------------------------------------------------------------- 1 | ### INTERSPEECH 2020 DNS Challenge's dataset 2 | 3 | The Deep Noise Suppression (DNS) Challenge is a single-channel speech enhancement 4 | challenge organized by Microsoft, with a focus on real-time applications. 5 | More info can be found on the [official page](https://dns-challenge.azurewebsites.net/). 6 | 7 | **References** 8 | The challenge paper, [here](https://arxiv.org/abs/2001.08662). 9 | ```BibTex 10 | @misc{DNSChallenge2020, 11 | title={The INTERSPEECH 2020 Deep Noise Suppression Challenge: Datasets, Subjective Speech Quality and Testing Framework}, 12 | author={Chandan K. A. Reddy and Ebrahim Beyrami and Harishchandra Dubey and Vishak Gopal and Roger Cheng and Ross Cutler and Sergiy Matusevych and Robert Aichner and Ashkan Aazami and Sebastian Braun and Puneet Rana and Sriram Srinivasan and Johannes Gehrke}, year={2020}, 13 | eprint={2001.08662}, 14 | } 15 | ``` 16 | The baseline paper, [here](https://arxiv.org/abs/2001.10601). 17 | ```BibTex 18 | @misc{xia2020weighted, 19 | title={Weighted Speech Distortion Losses for Neural-network-based Real-time Speech Enhancement}, 20 | author={Yangyang Xia and Sebastian Braun and Chandan K. A. 
Reddy and Harishchandra Dubey and Ross Cutler and Ivan Tashev}, 21 | year={2020}, 22 | eprint={2001.10601}, 23 | } 24 | ``` 25 | -------------------------------------------------------------------------------- /egs/wsj0-mix-var/Multi-Decoder-DPRNN/local/convert_sphere2wav.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # MIT Copyright (c) 2018 Kaituo XU 3 | 4 | 5 | sphere_dir=tmp 6 | wav_dir=tmp 7 | 8 | . utils/parse_options.sh || exit 1; 9 | 10 | 11 | echo "Download sph2pipe_v2.5 into egs/tools" 12 | mkdir -p ../../tools 13 | wget http://www.openslr.org/resources/3/sph2pipe_v2.5.tar.gz -P ../../tools 14 | cd ../../tools && tar -xzvf sph2pipe_v2.5.tar.gz && gcc -o sph2pipe_v2.5/sph2pipe sph2pipe_v2.5/*.c -lm && cd - 15 | 16 | echo "Convert sphere format to wav format" 17 | sph2pipe=../../tools/sph2pipe_v2.5/sph2pipe 18 | 19 | if [ ! -x $sph2pipe ]; then 20 | echo "Could not find (or execute) the sph2pipe program at $sph2pipe"; 21 | exit 1; 22 | fi 23 | 24 | tmp=data/local/ 25 | mkdir -p $tmp 26 | 27 | [ ! -f $tmp/sph.list ] && find $sphere_dir -iname '*.wv*' | grep -e 'si_tr_s' -e 'si_dt_05' -e 'si_et_05' > $tmp/sph.list 28 | 29 | if [ ! -d $wav_dir ]; then 30 | while read line; do 31 | wav=`echo "$line" | sed "s:wv1:wav:g" | awk -v dir=$wav_dir -F'/' '{printf("%s/%s/%s/%s", dir, $(NF-2), $(NF-1), $NF)}'` 32 | echo $wav 33 | mkdir -p `dirname $wav` 34 | $sph2pipe -f wav $line > $wav 35 | done < $tmp/sph.list > $tmp/wav.list 36 | else 37 | echo "Do you already get wav files? if not, please remove $wav_dir" 38 | fi 39 | -------------------------------------------------------------------------------- /egs/kinect-wsj/README.md: -------------------------------------------------------------------------------- 1 | ### Kinect-WSJ dataset 2 | Kinect-WSJ is a reverberated, noisy version of the WSJ0-2MIX dataset. 
3 | Microphones are placed on a linear array with spacing between the devices 4 | resembling that of Microsoft Kinect ™, the device used to record the CHiME-5 dataset. 5 | This was done so that we could use the real ambient noise captured as part of CHiME-5 dataset. 6 | The room impulse responses (RIR) were simulated for a sampling rate of 16,000 Hz. 7 | 8 | **Requirements** 9 | * wsj_path : Path to precomputed wsj-2mix dataset. Should contain the folder 2speakers/wav16k/. 10 | If you don't have wsj_mix dataset, please create it using the scripts in egs/wsj0_mix 11 | * chime_path : Path to chime-5 dataset. Should contain the folders train, dev and eval 12 | * dihard_path : Path to dihard labels. Should contain ```*.lab``` files for the train and dev set 13 | 14 | **References** 15 | [Original repo](https://github.com/sunits/Reverberated_WSJ_2MIX/) 16 | 17 | ``` 18 | @inproceedings{sivasankaran2020, 19 | booktitle = {2020 28th {{European Signal Processing Conference}} ({{EUSIPCO}})}, 20 | title={Analyzing the impact of speaker localization errors on speech separation for automatic speech recognition}, 21 | author={Sunit Sivasankaran and Emmanuel Vincent and Dominique Fohr}, 22 | year={2021}, 23 | month = Jan, 24 | } 25 | ``` 26 | 27 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41B Bug report" 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug, help wanted 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Before reporting a bug: 11 | First, please search [previous issues](https://github.com/asteroid-team/asteroid/issues) 12 | and [the FAQ](https://asteroid-team.github.io/asteroid/faq.html) and be sure this hasn't 13 | been answered elsewhere. 
def parse_dataset(in_dir, out_json):
    """Index a TAC dataset directory and dump the examples to a JSON file.

    Walks ``in_dir`` expecting the layout ``<n_mic folder>/<sample dir>/*.wav``
    and records, per sample and per microphone index, the mixture/source wav
    paths together with their length in samples.

    Args:
        in_dir (str): Root directory of the generated TAC dataset.
        out_json (str): Path of the JSON file to write. Parent directories
            are created if needed.
    """
    examples = []
    for n_mic_f in glob.glob(os.path.join(in_dir, "*")):
        for sample_dir in glob.glob(os.path.join(n_mic_f, "*")):
            c_ex = {}
            for wav in glob.glob(os.path.join(sample_dir, "*.wav")):
                # File stems look like "<mixture|spk?>_..._mic<k>.wav"
                # — presumably; TODO confirm against the generation script.
                source_or_mix = Path(wav).stem.split("_")[0]
                # Raw string so "\d" is a regex digit class, not a (now
                # deprecated) string escape.
                n_mic = int(re.findall(r"\d+", Path(wav).stem.split("_")[-1])[0])
                # Length in samples, read from the header only (no audio load).
                length = len(sf.SoundFile(wav))

                if n_mic not in c_ex:
                    c_ex[n_mic] = {source_or_mix: wav, "length": length}
                else:
                    # Every wav of a given mic must have the same length.
                    assert c_ex[n_mic]["length"] == length
                    c_ex[n_mic][source_or_mix] = wav
            examples.append(c_ex)

    os.makedirs(Path(out_json).parent, exist_ok=True)

    with open(out_json, "w") as f:
        json.dump(examples, f, indent=4)
args = parser.parse_args() 41 | parse_dataset(args.in_dir, args.out_json) 42 | -------------------------------------------------------------------------------- /egs/dampvsep/ConvTasNet/README.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | ConvTasNet model trained using DAMP-VSEP dataset. 4 | The dataset is preprocessed to obtain only single ensembles performances. 5 | The preprocess return two train sets, one validation and one test sets. 6 | 7 | The preprocessing steps can be found is [this repo](https://github.com/groadabike/DAMP-VSEP-Singles) 8 | 9 | The details of the dataset: 10 | 11 | | Dataset | Perf | hrs | 12 | |:--------------|----------:|----------:| 13 | | train_english | 9243 | 77 | 14 | | train_singles | 20660 | 174 | 15 | | valid | 100 | 0.8 | 16 | | test | 100 | 0.8 | 17 | 18 | 19 | 20 | ## Results 21 | The next results were obtained by remixing the sources. 22 | Results using the original mixture are pending. 
23 | 24 | | | Mixture |SI-SNRi(dB) (v)| STOI (v)|SDRi(dB) (b)| 25 | |:-------------:|:---------:|:-------------:|:-------:|:----------:| 26 | | train_english | remix | 14.3 | 0.6872 | 14.5 | 27 | | train_english | original | --- | --- | --- | 28 | | train_singles | remix | 15.0 | 0.6808 | 14.8 | 29 | | train_singles | original | --- | --- | --- | 30 | 31 | (v): vocal 32 | (b): background accompaniment 33 | 34 | ## Python requirements 35 | 36 | pip install librosa 37 | conda install -c conda-forge ffmpeg 38 | -------------------------------------------------------------------------------- /tests/masknn/norms_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | 5 | from asteroid.masknn import norms 6 | 7 | 8 | @pytest.mark.parametrize("norm_str", ["gLN", "cLN", "cgLN", "bN", "fgLN"]) 9 | @pytest.mark.parametrize("channel_size", [8, 128, 4]) 10 | def test_norms(norm_str, channel_size): 11 | norm_layer = norms.get(norm_str) 12 | # Use get on the class 13 | out_from_get = norms.get(norm_layer) 14 | assert out_from_get == norm_layer 15 | # Use get on the instance 16 | norm_layer = norm_layer(channel_size) 17 | out_from_get = norms.get(norm_layer) 18 | assert out_from_get == norm_layer 19 | 20 | # Test forward 21 | inp = torch.randn(4, channel_size, 12) 22 | out = norm_layer(inp) 23 | assert not torch.isnan(out).any() 24 | 25 | 26 | @pytest.mark.parametrize("wrong", ["wrong_string", 12, object()]) 27 | def test_get_errors(wrong): 28 | with pytest.raises(ValueError): 29 | # Should raise for anything not a Optimizer instance + unknown string 30 | norms.get(wrong) 31 | 32 | 33 | def test_get_none(): 34 | assert norms.get(None) is None 35 | 36 | 37 | def test_register(): 38 | class Custom(nn.Module): 39 | def __init__(self): 40 | super().__init__() 41 | 42 | norms.register_norm(Custom) 43 | cls = norms.get("Custom") 44 | assert cls == Custom 45 | 46 | with 
pytest.raises(ValueError): 47 | norms.register_norm(norms.CumLN) 48 | -------------------------------------------------------------------------------- /tests/utils/hub_utils_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from asteroid.utils import hub_utils 4 | 5 | 6 | HF_EXAMPLE_MODEL_IDENTIFER = "julien-c/DPRNNTasNet-ks16_WHAM_sepclean" 7 | HF_EXAMPLE_MODEL_IDENTIFER_URL = "https://huggingface.co/julien-c/DPRNNTasNet-ks16_WHAM_sepclean" 8 | # An actual model hosted on huggingface.co 9 | 10 | REVISION_ID_ONE_SPECIFIC_COMMIT = "8ab5ef18ef2eda141dd11a5d037a8bede7804ce4" 11 | # One particular commit (not the top of `main`) 12 | 13 | 14 | def test_download(): 15 | # We download 16 | path1 = hub_utils.cached_download("mpariente/ConvTasNet_WHAM!_sepclean") 17 | assert os.path.isfile(path1) 18 | # We use cache 19 | path2 = hub_utils.cached_download("mpariente/ConvTasNet_WHAM!_sepclean") 20 | assert path1 == path2 21 | 22 | 23 | @pytest.mark.parametrize( 24 | "model_id", 25 | [HF_EXAMPLE_MODEL_IDENTIFER, HF_EXAMPLE_MODEL_IDENTIFER_URL], 26 | ) 27 | def test_hf_download(model_id): 28 | # We download 29 | path1 = hub_utils.cached_download(model_id) 30 | assert os.path.isfile(path1) 31 | # We use cache 32 | path2 = hub_utils.cached_download(model_id) 33 | assert path1 == path2 34 | # However if specifying a particular commit, 35 | # file will be different. 36 | path3 = hub_utils.cached_download( 37 | f"{HF_EXAMPLE_MODEL_IDENTIFER}@{REVISION_ID_ONE_SPECIFIC_COMMIT}" 38 | ) 39 | assert path3 != path1 40 | 41 | 42 | def test_model_list(): 43 | hub_utils.model_list() 44 | hub_utils.model_list(name_only=True) 45 | -------------------------------------------------------------------------------- /docs/source/why_use_asteroid.rst: -------------------------------------------------------------------------------- 1 | What is Asteroid? 
2 | ================= 3 | 4 | Asteroid is a PyTorch-based audio source separation toolkit. 5 | 6 | The main goals of Asteroid are: 7 | 8 | - Gather a wider **community** around audio source separation by lowering the barriers to entry. 9 | - **Promote reproducibility** by replicating important research papers. 10 | - Automatize most engineering and **make way for research**. 11 | - Simplify **model sharing** to reduce compute costs and carbon footprint. 12 | 13 | 14 | So, how do we do that? We aim to provide 15 | 16 | - PyTorch ``Dataset`` for **common datasets**. 17 | - Ready-to-use state-of-the art source separation architectures in **native PyTorch**. 18 | - **Configurable recipes** from data preparation to evaluation. 19 | - **Pretrained models** for a wide variety of tasks and architectures. 20 | 21 | Who is it for? 22 | -------------- 23 | 24 | Asteroid has several target usage: 25 | 26 | - Use asteroid in your own code, as a package. 27 | - Use available recipes to build your own separation model. 28 | - Use pretrained models to process your files. 29 | - Hit the ground running with your research ideas! 30 | 31 | 32 | Want to know more? 33 | ------------------ 34 | 35 | - `Visit our webpage `__ 36 | - `Read our paper `__ 37 | - `Watch the presentation video `__ 38 | - `Check how we won the PyTorch Hackathon 2020 ! 
`__ 39 | -------------------------------------------------------------------------------- /asteroid/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .pit_wrapper import PITLossWrapper 2 | from .mixit_wrapper import MixITLossWrapper 3 | from .sinkpit_wrapper import SinkPITLossWrapper 4 | from .sdr import PairwiseNegSDR 5 | from .sdr import pairwise_neg_sisdr, singlesrc_neg_sisdr, multisrc_neg_sisdr 6 | from .sdr import pairwise_neg_sdsdr, singlesrc_neg_sdsdr, multisrc_neg_sdsdr 7 | from .sdr import pairwise_neg_snr, singlesrc_neg_snr, multisrc_neg_snr 8 | from .mse import pairwise_mse, singlesrc_mse, multisrc_mse 9 | from .cluster import deep_clustering_loss 10 | from .pmsqe import SingleSrcPMSQE 11 | from .multi_scale_spectral import SingleSrcMultiScaleSpectral 12 | 13 | try: 14 | from .stoi import NegSTOILoss as SingleSrcNegSTOI 15 | except ModuleNotFoundError: 16 | # Is installed with asteroid, but remove the deps for TorchHub. 
def snr(pred_signal: torch.Tensor, true_signal: torch.Tensor) -> torch.FloatTensor:
    """Compute the Signal-to-Noise Ratio between two signals.

    Args:
        pred_signal (torch.Tensor): predicted signal spectrogram.
        true_signal (torch.Tensor): original signal spectrogram.

    """
    # The interference is whatever separates the prediction from the target.
    interference = true_signal - pred_signal
    power_ratio = (true_signal**2).sum() / (interference**2).sum()
    # SNR in decibels.
    return 10 * torch.log10(power_ratio)
36 | 37 | """ 38 | n_sources = pred_signal.shape[0] 39 | 40 | y_pred_wav = np.zeros((n_sources, 48_000)) 41 | y_wav = np.zeros((n_sources, 48_000)) 42 | 43 | for i in range(n_sources): 44 | y_pred_wav[i] = AVSpeechDataset.decode(pred_signal[i, ...]).numpy() 45 | y_wav[i] = AVSpeechDataset.decode(true_signal[i, ...]).numpy() 46 | sdr, sir, sar, _ = mir_eval.separation.bss_eval_sources(y_wav, y_pred_wav) 47 | 48 | return sdr 49 | -------------------------------------------------------------------------------- /asteroid/masknn/_dccrn_architectures.py: -------------------------------------------------------------------------------- 1 | # fmt: off 2 | DCCRN_ARCHITECTURES = { 3 | "DCCRN-CL": ( 4 | # Encoders: 5 | # (in_chan, out_chan, kernel_size, stride, padding) 6 | ( 7 | ( 1, 16, (5, 2), (2, 1), (2, 0)), 8 | ( 16, 32, (5, 2), (2, 1), (2, 0)), 9 | ( 32, 64, (5, 2), (2, 1), (2, 0)), 10 | ( 64, 128, (5, 2), (2, 1), (2, 0)), 11 | (128, 128, (5, 2), (2, 1), (2, 0)), 12 | (128, 128, (5, 2), (2, 1), (2, 0)), 13 | ), 14 | # Decoders: 15 | # (in_chan, out_chan, kernel_size, stride, padding, output_padding) 16 | ( 17 | (256, 128, (5, 2), (2, 1), (2, 0), (1, 0)), 18 | (256, 128, (5, 2), (2, 1), (2, 0), (1, 0)), 19 | (256, 64, (5, 2), (2, 1), (2, 0), (1, 0)), 20 | (128, 32, (5, 2), (2, 1), (2, 0), (1, 0)), 21 | ( 64, 16, (5, 2), (2, 1), (2, 0), (1, 0)), 22 | ( 32, 1, (5, 2), (2, 1), (2, 0), (1, 0)), 23 | ), 24 | ), 25 | "mini": ( 26 | # This is a dummy architecture used for Asteroid unit tests. 
27 | 28 | # Encoders: 29 | # (in_chan, out_chan, kernel_size, stride, padding) 30 | ( 31 | (1, 4, (5, 2), (2, 1), (2, 0)), 32 | (4, 8, (5, 2), (2, 1), (2, 0)), 33 | ), 34 | # Decoders: 35 | # (in_chan, out_chan, kernel_size, stride, padding, output_padding) 36 | ( 37 | (16, 4, (5, 2), (2, 1), (2, 0), (1, 0)), 38 | ( 8, 1, (5, 2), (2, 1), (2, 0), (1, 0)), 39 | ), 40 | ), 41 | } 42 | -------------------------------------------------------------------------------- /tests/cli_setup.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import soundfile as sf 3 | import numpy as np 4 | import os 5 | from asteroid.models import ConvTasNet, save_publishable 6 | from asteroid.data.wham_dataset import wham_noise_license, wsj0_license 7 | 8 | 9 | def setup_register_sr(): 10 | model = ConvTasNet( 11 | n_src=2, 12 | n_repeats=2, 13 | n_blocks=3, 14 | bn_chan=16, 15 | hid_chan=4, 16 | skip_chan=8, 17 | n_filters=32, 18 | ) 19 | to_save = model.serialize() 20 | to_save["model_args"].pop("sample_rate") 21 | torch.save(to_save, "tmp.th") 22 | 23 | 24 | def setup_infer(): 25 | sf.write("tmp.wav", np.random.randn(16000), 8000) 26 | sf.write("tmp2.wav", np.random.randn(16000), 8000) 27 | 28 | 29 | def setup_upload(): 30 | train_set_infos = dict( 31 | dataset="WHAM", task="sep_noisy", licenses=[wsj0_license, wham_noise_license] 32 | ) 33 | final_results = {"si_sdr": 8.67, "si_sdr_imp": 13.16} 34 | model = ConvTasNet( 35 | n_src=2, 36 | n_repeats=2, 37 | n_blocks=3, 38 | bn_chan=16, 39 | hid_chan=4, 40 | skip_chan=8, 41 | n_filters=32, 42 | ) 43 | model_dict = model.serialize() 44 | model_dict.update(train_set_infos) 45 | 46 | os.makedirs("publish_dir", exist_ok=True) 47 | save_publishable( 48 | "publish_dir", 49 | model_dict, 50 | metrics=final_results, 51 | train_conf=dict(), 52 | ) 53 | 54 | 55 | if __name__ == "__main__": 56 | setup_register_sr() 57 | setup_infer() 58 | setup_upload() 59 | 
def test_asteroid_versions():
    """The version dict must report Asteroid, PyTorch and Lightning."""
    versions = asteroid_versions.asteroid_versions()
    for key in ("Asteroid", "PyTorch", "PyTorch-Lightning"):
        assert key in versions


def test_print_versions():
    """Printing the versions should not raise."""
    asteroid_versions.print_versions()


def test_asteroid_versions_without_git(monkeypatch):
    """Version lookup must still work when the git binary cannot be found."""
    monkeypatch.setenv("PATH", "")
    asteroid_versions.asteroid_versions()
# flat, as returned by argparse) to facilitate calls to the different
def get_git_version(root):
    """Return a short description of the Git checkout at ``root``.

    The result looks like ``"<12-char commit> (<branch>), dirty tree"``, with
    the branch and the dirty suffix only included when applicable.  If the
    information cannot be retrieved (no ``git`` binary on the PATH, ``root``
    is not a repository, ...) the error is printed to stderr and an empty
    string is returned.
    """

    def _git(*cmd):
        # Run `git <cmd>` inside `root` and return its stripped, decoded output.
        return subprocess.check_output(["git", *cmd], cwd=root).strip().decode("ascii", "ignore")

    try:
        commit = _git("rev-parse", "HEAD")
        branch = _git("rev-parse", "--symbolic-full-name", "--abbrev-ref", "HEAD")
        dirty = _git("status", "--porcelain")
    except Exception as err:
        print(f"Failed to get Git checkout info: {err}", file=sys.stderr)
        return ""
    s = commit[:12]
    if branch:
        s += f" ({branch})"
    if dirty:
        # Was an f-string with no placeholder (lint F541); a plain literal is equivalent.
        s += ", dirty tree"
    return s
def remove_corrupt_audio(audio_dir, df, path, expected_audio_size=96_000):
    """Drop rows of ``df`` whose mixed-audio file is smaller than expected.

    Recursively scans ``audio_dir`` for ``*wav`` files, collects those whose
    on-disk size is below ``expected_audio_size`` bytes, and writes ``df``
    without the matching ``mixed_audio`` rows back to ``path``.
    """
    too_small = []

    for wav in audio_dir.rglob("*wav"):
        n_bytes = wav.stat().st_size
        if wav.as_posix().startswith("../.."):
            # pathname should match with content of {train/val}.csv
            wav = Path(*wav.parts[2:])

        if n_bytes < expected_audio_size:
            too_small.append(wav.as_posix())

    print(f"Found total corrupted files: {len(too_small)}")

    kept = df[~df["mixed_audio"].isin(too_small)]
    print(df.shape, kept.shape)

    kept.to_csv(path, index=False)
ubuntu-latest 9 | 10 | # Timeout: https://stackoverflow.com/a/59076067/4521646 11 | timeout-minutes: 10 12 | 13 | strategy: 14 | matrix: 15 | python-version: [3.13] 16 | 17 | env: 18 | ACTIONS_ALLOW_UNSECURE_COMMANDS: True 19 | steps: 20 | - name: Install libnsdfile 21 | run: | 22 | sudo apt update 23 | sudo apt install libsndfile1-dev libsndfile1 24 | 25 | - uses: actions/checkout@v5 26 | 27 | - name: Extract branch name 28 | run: echo "::set-env name=BRANCH::${GITHUB_REF#refs/heads/}" 29 | - name: Check branch name 30 | run: echo $BRANCH 31 | 32 | - name: Set up Python ${{ matrix.python-version }} 33 | uses: actions/setup-python@v6 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | 37 | - name: Install python dependencies 38 | run: | 39 | python -m pip install --upgrade pip --quiet 40 | python -m pip install numpy Cython --upgrade-strategy only-if-needed --quiet 41 | python -m pip install -r requirements/torchhub.txt --quiet 42 | python --version 43 | pip --version 44 | python -m pip list 45 | 46 | - name: TorchHub list 47 | run: | 48 | python -c "import torch; print(torch.hub.list('mpariente/asteroid:$BRANCH'))" 49 | 50 | - name: TorchHub help 51 | run: | 52 | python -c "import torch; print(torch.hub.help('mpariente/asteroid:$BRANCH', 'conv_tasnet'))" 53 | -------------------------------------------------------------------------------- /tests/jit/jit_filterbanks_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | from torch.testing import assert_close 4 | from asteroid_filterbanks import make_enc_dec 5 | from asteroid.models.base_models import BaseEncoderMaskerDecoder 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "filter_bank_name", 10 | ("free", "stft", "analytic_free", "param_sinc"), 11 | ) 12 | @pytest.mark.parametrize( 13 | "inference_data", 14 | ( 15 | (torch.rand(240) - 0.5) * 2, 16 | (torch.rand(1, 220) - 0.5) * 2, 17 | (torch.rand(4, 256) - 0.5) * 2, 18 | 
class DummyModel(BaseEncoderMaskerDecoder):
    """Minimal encoder/masker/decoder model used to JIT-trace filterbanks.

    The masker is ``torch.nn.Identity``, so the model reduces to encoding
    the input with the chosen filterbank and decoding it straight back.
    """

    def __init__(
        self,
        fb_name="free",
        kernel_size=16,
        n_filters=32,
        stride=8,
        encoder_activation=None,
        **fb_kwargs,
    ):
        # Build a matching encoder/decoder pair for the requested filterbank;
        # extra fb_kwargs are forwarded untouched to the filterbank factory.
        encoder, decoder = make_enc_dec(
            fb_name, kernel_size=kernel_size, n_filters=n_filters, stride=stride, **fb_kwargs
        )
        # Identity masker: the "separation" step is a no-op for this test model.
        masker = torch.nn.Identity()
        super().__init__(encoder, masker, decoder, encoder_activation=encoder_activation)
conditions.""" 27 | speaker_list = ["mix_both", "mix_clean", "mix_single", "s1", "s2", "noise"] 28 | for data_type in ["tr", "cv", "tt"]: 29 | for spk in speaker_list: 30 | preprocess_one_dir( 31 | os.path.join(inp_args.in_dir, data_type, spk), 32 | os.path.join(inp_args.out_dir, data_type), 33 | spk, 34 | ) 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser("WHAM data preprocessing") 39 | parser.add_argument( 40 | "--in_dir", type=str, default=None, help="Directory path of wham including tr, cv and tt" 41 | ) 42 | parser.add_argument( 43 | "--out_dir", type=str, default=None, help="Directory path to put output files" 44 | ) 45 | args = parser.parse_args() 46 | print(args) 47 | preprocess(args) 48 | -------------------------------------------------------------------------------- /egs/wham/DPTNet/local/preprocess_wham.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import soundfile as sf 5 | 6 | 7 | def preprocess_one_dir(in_dir, out_dir, out_filename): 8 | """Create .json file for one condition.""" 9 | file_infos = [] 10 | in_dir = os.path.abspath(in_dir) 11 | wav_list = os.listdir(in_dir) 12 | wav_list.sort() 13 | for wav_file in wav_list: 14 | if not wav_file.endswith(".wav"): 15 | continue 16 | wav_path = os.path.join(in_dir, wav_file) 17 | samples = sf.SoundFile(wav_path) 18 | file_infos.append((wav_path, len(samples))) 19 | if not os.path.exists(out_dir): 20 | os.makedirs(out_dir) 21 | with open(os.path.join(out_dir, out_filename + ".json"), "w") as f: 22 | json.dump(file_infos, f, indent=4) 23 | 24 | 25 | def preprocess(inp_args): 26 | """Create .json files for all conditions.""" 27 | speaker_list = ["mix_both", "mix_clean", "mix_single", "s1", "s2", "noise"] 28 | for data_type in ["tr", "cv", "tt"]: 29 | for spk in speaker_list: 30 | preprocess_one_dir( 31 | os.path.join(inp_args.in_dir, data_type, spk), 32 | os.path.join(inp_args.out_dir, 
def preprocess_one_dir(in_dir, out_dir, out_filename):
    """Create a ``<out_filename>.json`` file for one condition.

    The JSON file lists ``(wav_path, n_frames)`` pairs for every ``.wav``
    file found in ``in_dir``, sorted by file name.
    """
    file_infos = []
    in_dir = os.path.abspath(in_dir)
    for wav_file in sorted(os.listdir(in_dir)):
        if not wav_file.endswith(".wav"):
            continue
        wav_path = os.path.join(in_dir, wav_file)
        # Context manager: the original leaked one open file handle per wav.
        with sf.SoundFile(wav_path) as samples:
            file_infos.append((wav_path, len(samples)))
    # exist_ok avoids the check-then-create race of the original.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, out_filename + ".json"), "w") as f:
        json.dump(file_infos, f, indent=4)
def preprocess_one_dir(in_dir, out_dir, out_filename):
    """Create a ``<out_filename>.json`` file for one condition.

    The JSON file lists ``(wav_path, n_frames)`` pairs for every ``.wav``
    file found in ``in_dir``, sorted by file name.
    """
    file_infos = []
    in_dir = os.path.abspath(in_dir)
    for wav_file in sorted(os.listdir(in_dir)):
        if not wav_file.endswith(".wav"):
            continue
        wav_path = os.path.join(in_dir, wav_file)
        # Context manager: the original leaked one open file handle per wav.
        with sf.SoundFile(wav_path) as samples:
            file_infos.append((wav_path, len(samples)))
    # exist_ok avoids the check-then-create race of the original.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, out_filename + ".json"), "w") as f:
        json.dump(file_infos, f, indent=4)
def preprocess_one_dir(in_dir, out_dir, out_filename):
    """Create a ``<out_filename>.json`` file for one condition.

    The JSON file lists ``(wav_path, n_frames)`` pairs for every ``.wav``
    file found in ``in_dir``, sorted by file name.
    """
    file_infos = []
    in_dir = os.path.abspath(in_dir)
    for wav_file in sorted(os.listdir(in_dir)):
        if not wav_file.endswith(".wav"):
            continue
        wav_path = os.path.join(in_dir, wav_file)
        # Context manager: the original leaked one open file handle per wav.
        with sf.SoundFile(wav_path) as samples:
            file_infos.append((wav_path, len(samples)))
    # exist_ok avoids the check-then-create race of the original.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, out_filename + ".json"), "w") as f:
        json.dump(file_infos, f, indent=4)
| """Create .json file for one condition.""" 9 | file_infos = [] 10 | in_dir = os.path.abspath(in_dir) 11 | wav_list = os.listdir(in_dir) 12 | wav_list.sort() 13 | for wav_file in wav_list: 14 | if not wav_file.endswith(".wav"): 15 | continue 16 | wav_path = os.path.join(in_dir, wav_file) 17 | samples = sf.SoundFile(wav_path) 18 | file_infos.append((wav_path, len(samples))) 19 | if not os.path.exists(out_dir): 20 | os.makedirs(out_dir) 21 | with open(os.path.join(out_dir, out_filename + ".json"), "w") as f: 22 | json.dump(file_infos, f, indent=4) 23 | 24 | 25 | def preprocess(inp_args): 26 | """Create .json files for all conditions.""" 27 | speaker_list = ["mix_both", "mix_clean", "mix_single", "s1", "s2", "noise"] 28 | for data_type in ["tr", "cv", "tt"]: 29 | for spk in speaker_list: 30 | preprocess_one_dir( 31 | os.path.join(inp_args.in_dir, data_type, spk), 32 | os.path.join(inp_args.out_dir, data_type), 33 | spk, 34 | ) 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser("WHAM data preprocessing") 39 | parser.add_argument( 40 | "--in_dir", type=str, default=None, help="Directory path of wham including tr, cv and tt" 41 | ) 42 | parser.add_argument( 43 | "--out_dir", type=str, default=None, help="Directory path to put output files" 44 | ) 45 | args = parser.parse_args() 46 | print(args) 47 | preprocess(args) 48 | -------------------------------------------------------------------------------- /asteroid/models/dccrnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from asteroid_filterbanks.transforms import from_torch_complex, to_torch_complex 3 | from ..masknn.recurrent import DCCRMaskNet 4 | from .dcunet import BaseDCUNet 5 | 6 | 7 | class DCCRNet(BaseDCUNet): 8 | """DCCRNet as proposed in [1]. 9 | 10 | Args: 11 | architecture (str): The architecture to use, must be "DCCRN-CL". 
12 | stft_kernel_size (int): STFT frame length to use 13 | stft_stride (int, optional): STFT hop length to use. 14 | sample_rate (float): Sampling rate of the model. 15 | masknet_kwargs (optional): Passed to :class:`DCCRMaskNet` 16 | 17 | References 18 | - [1] : "DCCRN: Deep Complex Convolution Recurrent Network for Phase-Aware Speech Enhancement", 19 | Yanxin Hu et al. https://arxiv.org/abs/2008.00264 20 | """ 21 | 22 | masknet_class = DCCRMaskNet 23 | 24 | def __init__( 25 | self, *args, stft_n_filters=512, stft_kernel_size=400, stft_stride=100, **masknet_kwargs 26 | ): 27 | masknet_kwargs.setdefault("n_freqs", stft_n_filters // 2) 28 | super().__init__( 29 | *args, 30 | stft_n_filters=stft_n_filters, 31 | stft_kernel_size=stft_kernel_size, 32 | stft_stride=stft_stride, 33 | **masknet_kwargs, 34 | ) 35 | 36 | def forward_encoder(self, wav): 37 | tf_rep = self.encoder(wav) 38 | # Remove Nyquist frequency bin 39 | return to_torch_complex(tf_rep)[..., :-1, :] 40 | 41 | def apply_masks(self, tf_rep, est_masks): 42 | masked_tf_rep = est_masks * tf_rep.unsqueeze(1) 43 | # Pad Nyquist frequency bin 44 | return from_torch_complex(torch.nn.functional.pad(masked_tf_rep, [0, 0, 0, 1])) 45 | -------------------------------------------------------------------------------- /tests/engine/system_test.py: -------------------------------------------------------------------------------- 1 | from torch import nn, optim 2 | from torch.utils import data 3 | from pytorch_lightning import Trainer 4 | 5 | from asteroid.engine.system import System 6 | from asteroid.utils.test_utils import DummyDataset 7 | 8 | 9 | def test_system(): 10 | model = nn.Sequential(nn.Linear(10, 10), nn.ReLU()) 11 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 12 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) 13 | dataset = DummyDataset() 14 | loader = data.DataLoader(dataset, batch_size=2, num_workers=4) 15 | system = System( 16 | model, 17 | optimizer, 18 | loss_func=nn.MSELoss(), 19 | 
def test_system_no_scheduler():
    """System should train fine when no LR scheduler is provided."""
    net = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
    opt = optim.Adam(net.parameters(), lr=1e-3)
    dummy_loader = data.DataLoader(DummyDataset(), batch_size=2, num_workers=4)
    system = System(
        net,
        opt,
        loss_func=nn.MSELoss(),
        train_loader=dummy_loader,
        val_loader=dummy_loader,
        scheduler=None,
    )
    trainer = Trainer(
        max_epochs=1, fast_dev_run=True, accelerator="cpu", strategy="ddp", devices="auto"
    )
    trainer.fit(system)
def from_vad_to_label(length, vad, begin, end):
    """Build a per-sample VAD target and crop it to ``[begin, end)``.

    Args:
        length (int): Total number of samples of the signal.
        vad (dict): Dict with ``"start"`` and ``"stop"`` lists of sample
            indices delimiting the active-speech segments.
        begin (int): First sample of the returned crop.
        end (int or None): One past the last sample of the crop.

    Returns:
        torch.Tensor: Float tensor, 1.0 inside speech segments, 0.0 elsewhere.
    """
    target = torch.zeros(length, dtype=torch.float)
    # Mark every annotated speech segment as active.
    for seg_start, seg_stop in zip(vad["start"], vad["stop"]):
        target[..., seg_start:seg_stop] = 1
    return target[..., begin:end]
import soundfile as sf
from glob import glob
import os
import matplotlib.pyplot as plt
import numpy as np

# Root of the 8 kHz, 2-speaker WHAM data to analyse (machine-specific path).
WHAM_ROOT = "/media/sam/Data/WSJ/wham_scripts/2speakers_wham/wav8k"

for mode in ["min", "max"]:
    for split in ["tr"]:

        # Gather noise and both speaker sources for this mode/split.
        noises = glob(os.path.join(WHAM_ROOT, mode, split, "noise", "*.wav"))
        s1 = glob(os.path.join(WHAM_ROOT, mode, split, "s1", "*.wav"))
        s2 = glob(os.path.join(WHAM_ROOT, mode, split, "s2", "*.wav"))

        # Per-mixture [s1_lvl, s2_lvl, noise_lvl] peak levels in dB.
        joint_src_stats = []

        for i in range(len(s1)):
            c_s1 = s1[i]
            # The matching s2/noise files share the basename of the s1 file.
            c_s2 = os.path.join(WHAM_ROOT, mode, split, "s2", c_s1.split("/")[-1])
            noise = os.path.join(WHAM_ROOT, mode, split, "noise", c_s1.split("/")[-1])

            c_s1_audio, _ = sf.read(c_s1)
            c_s2_audio, _ = sf.read(c_s2)
            noise, _ = sf.read(noise)

            # Peak level in dB: 20*log10 of the max absolute sample value.
            c_s1_lvl = 20 * np.log10(np.max(np.abs(c_s1_audio)))
            c_s2_lvl = 20 * np.log10(np.max(np.abs(c_s2_audio)))
            noises_lvl = 20 * np.log10(np.max(np.abs(noise)))

            joint_src_stats.append([c_s1_lvl, c_s2_lvl, noises_lvl])

        # 2D histograms (s1 vs s2, then s1 vs noise peak levels) over [-50, 0] dB.
        # NOTE(review): the second hist2d draws over the first on the same axes
        # and nothing is shown/saved here — presumably run interactively; confirm.
        plt.hist2d(
            [x[0] for x in joint_src_stats],
            [x[1] for x in joint_src_stats],
            100,
            [[-50, 0], [-50, 0]],
        )
        plt.hist2d(
            [x[0] for x in joint_src_stats],
            [x[2] for x in joint_src_stats],
            100,
            [[-50, 0], [-50, 0]],
        )
utils/parse_options.sh || exit 1; 21 | 22 | data_dir=data # Local data directory (Not much disk space required) 23 | sms_wsj=${storage_dir}/DATA/sms_wsj/ # Directory where to save SMS-WSJ wav files 24 | 25 | export OMP_NUM_THREADS=1 26 | export MKL_NUM_THREADS=1 27 | 28 | 29 | if [[ $stage -le 0 ]]; then 30 | echo "Stage 0: Cloning and installing SMS-WSJ repository" 31 | if [[ ! -d local/sms_wsj ]]; then 32 | git clone https://github.com/fgnt/sms_wsj.git local/sms_wsj 33 | fi 34 | ${python_path} -m pip install -e local/sms_wsj 35 | fi 36 | 37 | 38 | if [[ $stage -le 1 ]]; then 39 | echo "Stage 1: Generating SMS_WSJ data" 40 | . local/prepare_data.sh --wsj0_dir $wsj0_dir --wsj1_dir $wsj1_dir --num_jobs $num_jobs \ 41 | --sms_wsj_dir $sms_wsj --json_dir $data_dir --python_path $python_path 42 | fi 43 | 44 | 45 | if [[ $stage -le 2 ]]; then 46 | if [[ ! -d local/pb_bss ]]; then 47 | echo "Downloading and installing pb_bss (a model based source separation toolbox)" 48 | git clone https://github.com/fgnt/pb_bss.git local/pb_bss 49 | ${python_path} -m pip install einops 50 | ${python_path} -m pip install nara_wpe 51 | ${python_path} -m pip install cython 52 | ${python_path} -m pip install -e local/pb_bss[all] 53 | fi 54 | mpiexec -n $num_jobs ${python_path} start_evaluation.py --json_path $data_dir/sms_wsj.json 55 | fi 56 | -------------------------------------------------------------------------------- /docs/source/package_reference/filterbanks.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | Filterbank API 5 | ============== 6 | 7 | Filterbank, Encoder and Decoder 8 | ------------------------------- 9 | .. autoclass:: asteroid_filterbanks.Filterbank 10 | :members: 11 | .. autoclass:: asteroid_filterbanks.Encoder 12 | :members: 13 | :show-inheritance: 14 | .. autoclass:: asteroid_filterbanks.Decoder 15 | :members: 16 | :show-inheritance: 17 | .. 
autoclass:: asteroid_filterbanks.make_enc_dec 18 | :members: 19 | .. autoclass:: asteroid_filterbanks.get 20 | 21 | Learnable filterbanks 22 | --------------------- 23 | 24 | :hidden:`Free` 25 | ~~~~~~~~~~~~~~~~ 26 | .. automodule:: asteroid_filterbanks.free_fb 27 | :members: 28 | 29 | :hidden:`Analytic Free` 30 | ~~~~~~~~~~~~~~~~~~~~~~~ 31 | .. automodule:: asteroid_filterbanks.analytic_free_fb 32 | :members: 33 | 34 | :hidden:`Parameterized Sinc` 35 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 36 | .. automodule:: asteroid_filterbanks.param_sinc_fb 37 | :members: 38 | 39 | Fixed filterbanks 40 | ----------------- 41 | 42 | :hidden:`STFT` 43 | ~~~~~~~~~~~~~~~~ 44 | .. automodule:: asteroid_filterbanks.stft_fb 45 | :members: 46 | 47 | :hidden:`MelGram` 48 | ~~~~~~~~~~~~~~~~~ 49 | .. automodule:: asteroid_filterbanks.melgram_fb 50 | :members: 51 | 52 | :hidden:`MPGT` 53 | ~~~~~~~~~~~~~~~~ 54 | .. autoclass:: asteroid_filterbanks.multiphase_gammatone_fb.MultiphaseGammatoneFB 55 | :members: 56 | 57 | Transforms 58 | ---------- 59 | 60 | :hidden:`Griffin-Lim and MISI` 61 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 62 | 63 | .. automodule:: asteroid_filterbanks.griffin_lim 64 | :members: 65 | 66 | :hidden:`Complex transforms` 67 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 68 | 69 | .. 
automodule:: asteroid_filterbanks.transforms 70 | :members: 71 | -------------------------------------------------------------------------------- /tests/dsp/consistency_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.testing import assert_close 3 | import pytest 4 | 5 | from asteroid.dsp.consistency import mixture_consistency 6 | 7 | 8 | @pytest.mark.parametrize("mix_shape", [[2, 1600], [2, 130, 10]]) 9 | @pytest.mark.parametrize("dim", [1, 2]) 10 | @pytest.mark.parametrize("n_src", [1, 2, 3]) 11 | def test_consistency_noweight(mix_shape, dim, n_src): 12 | mix = torch.randn(mix_shape) 13 | est_shape = mix_shape[:dim] + [n_src] + mix_shape[dim:] 14 | est_sources = torch.randn(est_shape) 15 | consistent_est_sources = mixture_consistency(mix, est_sources, dim=dim) 16 | assert_close(mix, consistent_est_sources.sum(dim)) 17 | 18 | 19 | @pytest.mark.parametrize("mix_shape", [[2, 1600], [2, 130, 10]]) 20 | @pytest.mark.parametrize("dim", [1, 2]) 21 | @pytest.mark.parametrize("n_src", [1, 2, 3]) 22 | def test_consistency_withweight(mix_shape, dim, n_src): 23 | mix = torch.randn(mix_shape) 24 | est_shape = mix_shape[:dim] + [n_src] + mix_shape[dim:] 25 | est_sources = torch.randn(est_shape) 26 | # Create source weights : should have the same number of dims as 27 | # est_sources with ones out of batch and n_src dims. 
def preprocess_one_dir(in_dir, out_dir, out_filename):
    """Create a ``.json`` file indexing the wav files of one condition.

    Writes ``<out_dir>/<out_filename>.json`` containing a list of
    ``[wav_path, num_frames]`` pairs, sorted by filename.

    Args:
        in_dir (str): Directory containing the ``.wav`` files to index.
        out_dir (str): Output directory, created if it does not exist.
        out_filename (str): Basename (without extension) of the JSON file.
    """
    file_infos = []
    in_dir = os.path.abspath(in_dir)
    wav_list = os.listdir(in_dir)
    wav_list.sort()
    for wav_file in wav_list:
        if not wav_file.endswith(".wav"):
            continue
        wav_path = os.path.join(in_dir, wav_file)
        # Use a context manager: the original leaked one open SoundFile handle
        # per wav. `frames` is the sample count, which is what the deprecated
        # len(SoundFile) returned.
        with sf.SoundFile(wav_path) as wav_handle:
            file_infos.append((wav_path, wav_handle.frames))
    # exist_ok avoids the check-then-create race of the original exists() test.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, out_filename + ".json"), "w") as f:
        json.dump(file_infos, f, indent=4)
@pytest.mark.parametrize("mask_act", ["relu", "softmax"])
@pytest.mark.parametrize("out_chan", [None, 10])
@pytest.mark.parametrize("skip_chan", [0, 12])
@pytest.mark.parametrize("causal", [True, False])
def test_tdconvnet(mask_act, out_chan, skip_chan, causal):
    """TDConvNet forward pass returns (batch, n_src, out_chan, n_frames) masks."""
    in_chan, n_src = 20, 2
    batch, n_frames = 2, 24
    masker = TDConvNet(
        in_chan=in_chan,
        n_src=n_src,
        mask_act=mask_act,
        n_blocks=2,
        n_repeats=2,
        bn_chan=10,
        hid_chan=11,
        skip_chan=skip_chan,
        out_chan=out_chan,
        causal=causal,
    )
    features = torch.randn(batch, in_chan, n_frames)
    masks = masker(features)
    # get_config() must not raise: it is used for model serialization.
    _ = masker.get_config()
    # With out_chan=None the network defaults to in_chan output channels.
    expected_chan = out_chan or in_chan
    assert masks.shape == (batch, n_src, expected_chan, n_frames)
#!/bin/bash
# Prepare the SMS-WSJ database: resample WSJ to 8 kHz zero-mean wavs, build the
# JSON metadata, download the room impulse responses, and write the final
# mixture files. Each stage is skipped if its output already exists, so the
# script can be safely re-run after an interruption.

# Required: paths to the original WSJ0/WSJ1 corpora and the SMS-WSJ output dir.
wsj0_dir=
wsj1_dir=
# Default to one job per available core.
num_jobs=$(nproc --all)
sms_wsj_dir=
python_path=python
json_dir=${sms_wsj_dir}

# Kaldi-style option parsing: lets callers override the variables above
# via --wsj0_dir, --num_jobs, etc.
. utils/parse_options.sh || exit 1;

wsj_8k_zeromean=${sms_wsj_dir}/wsj_8k_zeromean
rir_dir=${sms_wsj_dir}/rirs

echo using ${num_jobs} parallel jobs

# Stage 1: resample WSJ to 8 kHz zero-mean wav files (parallel over MPI ranks).
# The `with key=value` arguments are sacred CLI config overrides.
if [[ ! -d $wsj_8k_zeromean ]]; then
  echo creating ${wsj_8k_zeromean}
  mpiexec -np ${num_jobs} $python_path -m sms_wsj.database.wsj.write_wav \
    with dst_dir=${wsj_8k_zeromean} wsj0_root=${wsj0_dir} wsj1_root=${wsj1_dir} sample_rate=8000
fi

if [[ ! -d $json_dir ]]; then
  mkdir -p $json_dir
fi

# Stage 2: JSON index of the resampled WSJ data.
if [[ ! -f $json_dir/wsj_8k_zeromean.json ]]; then
  echo creating $json_dir/wsj_8k_zeromean.json
  $python_path -m sms_wsj.database.wsj.create_json \
    with json_path=$json_dir/wsj_8k_zeromean.json database_dir=$wsj_8k_zeromean as_wav=True
fi

# Stage 3: download the precomputed RIRs from Zenodo. The archive is split
# into five parts (parta{a..e}); they are concatenated and untarred on the fly.
if [[ ! -d $rir_dir ]]; then
  echo "RIR directory does not exist, starting download."
  mkdir -p ${rir_dir}
  wget -qO- https://zenodo.org/record/3517889/files/sms_wsj.tar.gz.parta{a,b,c,d,e} \
    | tar -C ${rir_dir}/ -zx --checkpoint=10000 --checkpoint-action=echo="%u/5530000 %c"
fi

# Stage 4: build the SMS-WSJ mixture description JSON from RIRs + WSJ index.
if [[ ! -f $json_dir/sms_wsj.json ]]; then
  echo creating $json_dir/sms_wsj.json
  $python_path -m sms_wsj.database.create_json \
    with json_path=$json_dir/sms_wsj.json rir_dir=$rir_dir \
    wsj_json_path=$json_dir/wsj_8k_zeromean.json

fi


# Stage 5: render the actual mixture wav files and amend sms_wsj.json in place
# with the new file paths (no existence guard: always re-run to refresh paths).
echo creating $sms_wsj_dir files
echo This amends the sms_wsj.json with the new paths.
mpiexec -np ${num_jobs} $python_path -m sms_wsj.database.write_files \
  with dst_dir=${sms_wsj_dir} json_path=$json_dir/sms_wsj.json \
  write_all=True new_json_path=$json_dir/sms_wsj.json
| "--in_dir", type=str, default=None, help="Directory path of wham including tr, cv and tt" 41 | ) 42 | parser.add_argument("--n_src", type=int, default=2, help="Number of sources in wsj0-mix") 43 | parser.add_argument( 44 | "--out_dir", type=str, default=None, help="Directory path to put output files" 45 | ) 46 | args = parser.parse_args() 47 | print(args) 48 | preprocess(args) 49 | --------------------------------------------------------------------------------