├── .nojekyll ├── tests ├── __init__.py ├── data │ ├── test.wav │ ├── Al James - Schoolboy Facination.spectrogram.pt │ └── Al James - Schoolboy Facination.json ├── cli_test.sh ├── test_utils.py ├── create_dummy_datasets.sh ├── test_augmentations.py ├── create_vectors.py ├── test_io.py ├── test_jit.py ├── test_transforms.py ├── test_model.py ├── test_datasets.py ├── test_wiener.py └── test_regression.py ├── codecov.yml ├── .github ├── workflows │ ├── test_black.yml │ ├── test_cli.yml │ └── test_unittests.yml ├── PULL_REQUEST_TEMPLATE.md └── ISSUE_TEMPLATE │ ├── improved-model.md │ └── bug-report.md ├── Dockerfile ├── scripts ├── environment-cpu-osx.yml ├── environment-cpu-linux.yml ├── environment-gpu-linux-cuda10.yml ├── requirements.txt ├── train.py └── README.md ├── .flake8 ├── hubconf.py ├── LICENSE ├── .gitignore ├── pyproject.toml ├── docs ├── faq.md ├── inference.md ├── extensions.md ├── evaluate.html ├── training.md └── predict.html ├── pdoc └── config.mako ├── openunmix ├── predict.py ├── evaluate.py ├── cli.py ├── transforms.py ├── utils.py ├── model.py └── __init__.py ├── CONTRIBUTING.md └── README.md /.nojekyll: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: no 3 | -------------------------------------------------------------------------------- /tests/data/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sigsep/open-unmix-pytorch/HEAD/tests/data/test.wav -------------------------------------------------------------------------------- /tests/data/Al James - Schoolboy Facination.spectrogram.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sigsep/open-unmix-pytorch/HEAD/tests/data/Al James - Schoolboy Facination.spectrogram.pt -------------------------------------------------------------------------------- /tests/cli_test.sh: -------------------------------------------------------------------------------- 1 | # run umx on url 2 | umx https://samples.ffmpeg.org/A-codecs/wavpcm/test-96.wav --audio-backend stempeg 3 | umx https://samples.ffmpeg.org/A-codecs/wavpcm/test-96.wav --audio-backend stempeg --outdir out --niter 0 4 | -------------------------------------------------------------------------------- /.github/workflows/test_black.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | on: [pull_request] # yamllint disable-line rule:truthy 3 | jobs: 4 | lint: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v4 8 | - uses: psf/black@stable 9 | with: 10 | options: "--check --verbose" 11 | version: "~= 24.4.0" 12 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime 2 | 3 | RUN apt-get update && apt-get install -y --no-install-recommends \ 4 | libsox-fmt-all \ 5 | sox \ 6 | libsox-dev 7 | 8 | WORKDIR /workspace 9 | 10 | RUN conda install ffmpeg -c conda-forge 11 | RUN pip 
install musdb>=0.4.0 12 | 13 | RUN pip install openunmix['stempeg'] 14 | 15 | ENTRYPOINT ["umx"] -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 8 | 9 | #### What does this implement/fix? Explain your changes. 10 | 11 | 12 | #### Any other comments? 13 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from openunmix.utils import AverageMeter, EarlyStopping 2 | 3 | 4 | def test_average_meter(): 5 | losses = AverageMeter() 6 | losses.update(1.0) 7 | losses.update(3.0) 8 | assert losses.avg == 2.0 9 | 10 | 11 | def test_early_stopping(): 12 | es = EarlyStopping(patience=2) 13 | es.step(1.0) 14 | 15 | assert not es.step(0.5) 16 | assert not es.step(0.6) 17 | assert es.step(0.7) 18 | -------------------------------------------------------------------------------- /scripts/environment-cpu-osx.yml: -------------------------------------------------------------------------------- 1 | name: umx-osx 2 | 3 | channels: 4 | - conda-forge 5 | - pytorch 6 | 7 | dependencies: 8 | - python=3.7 9 | - numpy=1.18 10 | - scipy=1.4 11 | - pytorch==1.9.0 12 | - torchaudio==0.9.0 13 | - tqdm 14 | - scikit-learn=0.22 15 | - ffmpeg 16 | - libsndfile 17 | - pip 18 | - pip: 19 | - musdb>=0.4.0 20 | - museval>=0.4.0 21 | - asteroid-filterbanks>=0.3.2 22 | - gitpython 23 | 24 | -------------------------------------------------------------------------------- /scripts/environment-cpu-linux.yml: -------------------------------------------------------------------------------- 1 | name: umx-cpu 2 | 3 | channels: 4 | - conda-forge 5 | - pytorch 6 | 7 | dependencies: 8 | - python=3.7 9 | - numpy=1.18 10 | - scipy=1.4 11 | - pytorch==1.9.0 12 | - torchaudio==0.9.0 13 | - cpuonly 14 | - tqdm 15 | - scikit-learn=0.22 16 | - ffmpeg 17 | - libsndfile 18 | - pip 19 | - pip: 20 | - musdb>=0.4.0 21 | - museval>=0.4.0 22 | - asteroid-filterbanks>=0.3.2 23 | - gitpython 24 | -------------------------------------------------------------------------------- /scripts/environment-gpu-linux-cuda10.yml: -------------------------------------------------------------------------------- 1 | name: umx-gpu 2 | 3 | channels: 4 | - conda-forge 5 | - pytorch 6 | - nvidia 7 | 8 | dependencies: 9 | - python=3.7 10 | - numpy=1.18 11 | - scipy=1.4 12 | - pytorch==1.9.0 13 | - torchaudio==0.9.0 14 | - cudatoolkit=11.1 15 | - scikit-learn=0.22 16 | - tqdm 17 | - libsndfile 18 | - ffmpeg 19 | - pip 20 | - pip: 21 | - musdb>=0.4.0 22 | - museval>=0.4.0 23 | - asteroid-filterbanks>=0.3.2 24 | - gitpython 25 | 26 | -------------------------------------------------------------------------------- /tests/create_dummy_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Set the number of files to generate 4 | NBFILES=4 5 | BASEDIR=TrackfolderDataset 6 | subsets=( 7 | train 8 | valid 9 | ) 10 | for subset in "${subsets[@]}"; do 11 | for k in $(seq 1 4); do 12 | path=$BASEDIR/$subset/$k 13 | mkdir -p $path 14 | for i in $(seq 1 $NBFILES); do 15 | sox -n -r 8000 -b 16 $path/$i.wav synth "0:3" whitenoise vol 0.5 fade q 1 "0:3" 1 16 | done 17 | done 18 | done 19 | -------------------------------------------------------------------------------- /.flake8: 
-------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 119 3 | exclude = docs/source,*.egg,build 4 | select = E,W,F 5 | verbose = 2 6 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes 7 | format = pylint 8 | ignore = 9 | E731 # E731 - Do not assign a lambda expression, use a def 10 | W605 # W605 - invalid escape sequence '\_'. Needed for docs 11 | W504 # W504 - line break after binary operator 12 | W503 # W503 - line break before binary operator, need for black 13 | E203 # E203 - whitespace before ':'. Opposite convention enforced by black -------------------------------------------------------------------------------- /scripts/requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==21.2.0 2 | cffi==1.14.6 3 | ffmpeg-python==0.2.0 4 | future==0.18.3 5 | gitdb==4.0.7 6 | GitPython==3.1.41 7 | joblib==1.2.0 8 | jsonschema==3.2.0 9 | musdb==0.3.1 10 | museval==0.3.1 11 | numpy==1.22.0 12 | pandas==1.3.3 13 | pyaml==21.8.3 14 | pycparser==2.20 15 | pyrsistent==0.18.0 16 | python-dateutil==2.8.2 17 | pytz==2021.1 18 | PyYAML==5.4.1 19 | scikit-learn==0.22 20 | scipy==1.10.0 21 | simplejson==3.17.5 22 | six==1.16.0 23 | smmap==4.0.0 24 | SoundFile==0.10.3.post1 25 | stempeg==0.2.3 26 | tqdm==4.62.2 27 | typing-extensions==3.10.0.2 28 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # This file is to be parsed by torch.hub mechanics 2 | # 3 | # `xxx_spec` take spectrogram inputs and output separated spectrograms 4 | # `xxx` take waveform inputs and output separated waveforms 5 | 6 | # Optional list of dependencies required by the package 7 | dependencies = ["torch", "numpy"] 8 | 9 | from openunmix import umxse_spec 10 | from openunmix import umxse 11 | 12 | from openunmix import umxhq_spec 13 | from openunmix import umxhq 14 | 15 | from openunmix import umx_spec 16 | from openunmix import umx 17 | 18 | from openunmix import umxl_spec 19 | from openunmix import umxl 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/improved-model.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680Improved Model" 3 | about: Submit a proposal for an improved separation model 4 | 5 | --- 6 | 7 | ## 🚀 Model Improvement 8 | 12 | 13 | ## Motivation 14 | 15 | 16 | 17 | ## Objective Evaluation 18 | 19 | -------------------------------------------------------------------------------- /tests/test_augmentations.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from openunmix import data 5 | 6 | 7 | @pytest.fixture(params=[4096]) 8 | def nb_timesteps(request): 9 | return int(request.param) 10 | 11 | 12 | @pytest.fixture(params=[1, 2]) 13 | def nb_channels(request): 14 | return request.param 15 | 16 | 17 | @pytest.fixture 18 | def audio(request, nb_channels, nb_timesteps): 19 | return torch.rand((nb_channels, nb_timesteps)) 20 | 21 | 22 | def test_gain(audio): 23 | out = data._augment_gain(audio) 24 | assert out.shape == audio.shape 25 | 26 | 27 | def test_channelswap(audio): 28 | out = data._augment_channelswap(audio) 29 | assert out.shape == audio.shape 30 | 31 | 32 | def test_forcestereo(audio, nb_channels): 33 | out = data._augment_force_stereo(audio) 34 | assert 
out.shape[-1] == audio.shape[-1] 35 | assert out.shape[0] == 2 36 | -------------------------------------------------------------------------------- /tests/create_vectors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import musdb 4 | import numpy as np 5 | from openunmix import model, utils 6 | 7 | """script to create spectrogram test vectors for STFT regression tests 8 | 9 | Test vectors have been created using the `v1.0.0` release tag as this 10 | was the commit that umx was trained with 11 | """ 12 | 13 | 14 | def main(): 15 | test_track = "Al James - Schoolboy Facination" 16 | mus = musdb.DB(download=True) 17 | 18 | # load audio track 19 | track = [track for track in mus.tracks if track.name == test_track][0] 20 | 21 | # convert to torch tensor 22 | audio = torch.tensor(track.audio.T, dtype=torch.float32) 23 | 24 | stft = model.STFT(n_fft=4096, n_hop=1024) 25 | spec = model.Spectrogram(power=1, mono=False) 26 | magnitude_spectrogram = spec(stft(audio[None, ...])) 27 | torch.save(magnitude_spectrogram, "Al James - Schoolboy Facination.spectrogram.pt") 28 | 29 | 30 | if __name__ == "__main__": 31 | main() 32 | -------------------------------------------------------------------------------- /tests/test_io.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import os 4 | import torchaudio 5 | 6 | from openunmix import data 7 | 8 | 9 | audio_path = os.path.join( 10 | os.path.dirname(os.path.realpath(__file__)), 11 | "data/test.wav", 12 | ) 13 | 14 | 15 | @pytest.fixture(params=["soundfile", "sox_io"]) 16 | def torch_backend(request): 17 | return request.param 18 | 19 | 20 | @pytest.fixture(params=[1.0, 2.0, None]) 21 | def dur(request): 22 | return request.param 23 | 24 | 25 | @pytest.fixture(params=[True, False]) 26 | def info(request, torch_backend): 27 | torchaudio.set_audio_backend(torch_backend) 28 | 29 | if request.param: 30 | return data.load_info(audio_path) 31 | else: 32 | return None 33 | 34 | 35 | def test_loadwav(dur, info, torch_backend): 36 | torchaudio.set_audio_backend(torch_backend) 37 | audio, _ = data.load_audio(audio_path, dur=dur, info=info) 38 | rate = 8000.0 39 | if dur: 40 | assert audio.shape[-1] == int(dur * rate) 41 | else: 42 | assert audio.shape[-1] == rate * 3 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Inria (Fabian-Robert Stöter, Antoine Liutkus) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/test_jit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.onnx 3 | import pytest 4 | 5 | from openunmix import model 6 | 7 | 8 | @pytest.mark.skip(reason="Currently not supported") 9 | def test_onnx(): 10 | """Test ONNX export of the separator 11 | 12 | currently results in erros, blocked by 13 | https://github.com/pytorch/pytorch/issues/49958 14 | """ 15 | nb_samples = 1 16 | nb_channels = 2 17 | nb_timesteps = 11111 18 | 19 | example = torch.rand((nb_samples, nb_channels, nb_timesteps), device="cpu") 20 | # set model to eval due to non-deterministic behaviour of dropout 21 | umx = model.OpenUnmix(nb_bins=2049, nb_channels=2).eval().to("cpu") 22 | 23 | # creatr separator 24 | separator = ( 25 | model.Separator(target_models={"source_1": umx, "source_2": umx}, niter=1, filterbank="asteroid") 26 | .eval() 27 | .to("cpu") 28 | ) 29 | 30 | torch_out = separator(example) 31 | 32 | # Export the model 33 | torch.onnx.export( 34 | separator, 35 | example, 36 | "umx.onnx", 37 | export_params=True, 38 | opset_version=10, 39 | do_constant_folding=True, 40 | input_names=["input"], 41 | output_names=["output"], 42 | dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, 43 | ) 44 | -------------------------------------------------------------------------------- /tests/test_transforms.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | from openunmix import transforms 5 | 6 | 7 | @pytest.fixture(params=[4096]) 8 | def nb_timesteps(request): 9 | return int(request.param) 10 | 11 | 12 | @pytest.fixture(params=[2]) 13 | def nb_channels(request): 14 | return request.param 15 | 16 | 17 | @pytest.fixture(params=[2]) 18 | def nb_samples(request): 19 | return request.param 20 | 21 | 22 | @pytest.fixture(params=[2048]) 23 | def nfft(request): 24 | return int(request.param) 25 | 26 | 27 | @pytest.fixture(params=[2]) 28 | def hop(request, nfft): 29 | return nfft // request.param 30 | 31 | 32 | @pytest.fixture(params=["torch", "asteroid"]) 33 | def method(request): 34 | return request.param 35 | 36 | 37 | @pytest.fixture 38 | def audio(request, nb_samples, nb_channels, nb_timesteps): 39 | return torch.rand((nb_samples, nb_channels, nb_timesteps)) 40 | 41 | 42 | def test_stft(audio, nfft, hop, method): 43 | # we should only test for center=True as 44 | # False doesn't pass COLA 45 | # https://github.com/pytorch/audio/issues/500 46 | stft, istft = transforms.make_filterbanks(n_fft=nfft, n_hop=hop, center=True, method=method) 47 | 48 | X = stft(audio) 49 | X = X.detach() 50 | out = istft(X, length=audio.shape[-1]) 51 | assert np.sqrt(np.mean((audio.detach().numpy() - out.detach().numpy()) ** 2)) < 1e-6 52 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41B Bug Report" 3 | about: Submit a bug report to help to improve Open-Unmix 4 | 5 | --- 6 | 7 | ## 🐛 Bug 8 | 9 | 10 | 11 | ## To Reproduce 12 | 
13 | Steps to reproduce the behavior: 14 | 15 | 1. 16 | 1. 17 | 1. 18 | 19 | 20 | 21 | ## Expected behavior 22 | 23 | 24 | 25 | ## Environment 26 | 27 | Please add some information about your environment 28 | 29 | - PyTorch Version (e.g., 1.2): 30 | - OS (e.g., Linux): 31 | - torchaudio loader (y/n): 32 | - Python version: 33 | - CUDA/cuDNN version: 34 | - Any other relevant information: 35 | 36 | If unsure you can paste the output from the [pytorch environment collection script](https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py) 37 | (or fill out the checklist below manually). 38 | 39 | You can get that script and run it with: 40 | ``` 41 | wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py 42 | # For security purposes, please check the contents of collect_env.py before running it. 43 | python collect_env.py 44 | ``` 45 | 46 | ## Additional context 47 | 48 | 49 | -------------------------------------------------------------------------------- /.github/workflows/test_cli.yml: -------------------------------------------------------------------------------- 1 | name: UMX 2 | # thanks for @mpariente for copying this workflow 3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | # Trigger the workflow on push or pull request 5 | on: # yamllint disable-line rule:truthy 6 | pull_request: 7 | types: [opened, synchronize, reopened, ready_for_review] 8 | 9 | jobs: 10 | src-test: 11 | name: separation test 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: [3.9] 16 | 17 | # Timeout: https://stackoverflow.com/a/59076067/4521646 18 | timeout-minutes: 20 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install libnsdfile, ffmpeg and sox 26 | run: | 27 | sudo apt update 28 | sudo apt install libsndfile1-dev libsndfile1 ffmpeg sox 29 | - name: Install package dependencies 30 | run: | 31 | python -m pip install --upgrade --user pip --quiet 32 | python -m pip install .["stempeg"] 33 | python --version 34 | pip --version 35 | python -m pip list 36 | 37 | - name: CLI tests 38 | run: | 39 | umx https://samples.ffmpeg.org/A-codecs/wavpcm/test-96.wav --audio-backend stempeg 40 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from openunmix import model 5 | from openunmix import umxse 6 | from openunmix import umxhq 7 | from openunmix import umx 8 | from openunmix import umxl 9 | 10 | 11 | @pytest.fixture(params=[100]) 12 | def nb_frames(request): 13 | return int(request.param) 14 | 15 | 16 | @pytest.fixture(params=[1, 2]) 17 | def nb_channels(request): 18 | return request.param 19 | 20 | 21 | @pytest.fixture(params=[2]) 22 | def nb_samples(request): 23 | return request.param 24 | 25 | 26 | @pytest.fixture(params=[1024]) 27 | def nb_bins(request): 28 | return request.param 29 | 30 | 31 | @pytest.fixture 32 | def spectrogram(request, nb_samples, nb_channels, nb_bins, nb_frames): 33 | return torch.rand((nb_samples, nb_channels, nb_bins, nb_frames)) 34 | 35 | 36 | @pytest.fixture(params=[False]) 37 | def unidirectional(request): 38 | return request.param 39 | 40 | 41 | @pytest.fixture(params=[32]) 42 | def hidden_size(request): 43 | return 
request.param 44 | 45 | 46 | def test_shape(spectrogram, nb_bins, nb_channels, unidirectional, hidden_size): 47 | unmix = model.OpenUnmix( 48 | nb_bins=nb_bins, 49 | nb_channels=nb_channels, 50 | unidirectional=unidirectional, 51 | nb_layers=1, # speed up training 52 | hidden_size=hidden_size, 53 | ) 54 | unmix.eval() 55 | Y = unmix(spectrogram) 56 | assert spectrogram.shape == Y.shape 57 | 58 | 59 | @pytest.mark.parametrize("model_fn", [umx, umxhq, umxse, umxl]) 60 | def test_model_loading(model_fn): 61 | X = torch.rand((1, 2, 4096)) 62 | model = model_fn(niter=0, pretrained=True) 63 | Y = model(X) 64 | assert Y[:, 0, ...].shape == X.shape 65 | -------------------------------------------------------------------------------- /.github/workflows/test_unittests.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | # thanks for @mpariente for copying this workflow 3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows 4 | # Trigger the workflow on push or pull request 5 | on: # yamllint disable-line rule:truthy 6 | pull_request: 7 | types: [opened, synchronize, reopened, ready_for_review] 8 | 9 | jobs: 10 | src-test: 11 | name: unit-tests 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: [3.9] 16 | 17 | # Timeout: https://stackoverflow.com/a/59076067/4521646 18 | timeout-minutes: 20 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install libnsdfile, ffmpeg and sox 26 | run: | 27 | sudo apt update 28 | sudo apt install libsndfile1-dev libsndfile1 ffmpeg sox 29 | - name: Install python dependencies 30 | run: | 31 | python -m pip install --upgrade --user pip --quiet 32 | python -m pip install numpy Cython --upgrade-strategy only-if-needed --quiet 33 | python -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu 34 | python -m pip install -e .['tests'] 35 | python --version 36 | pip --version 37 | python -m pip list 38 | - name: Create dummy dataset 39 | run: | 40 | chmod +x tests/create_dummy_datasets.sh 41 | ./tests/create_dummy_datasets.sh 42 | shell: bash 43 | 44 | - name: Run tests 45 | env: 46 | PY_COLORS: "1" 47 | run: | 48 | pytest tests -vv -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | #### joe made this: http://goel.io/joe 2 | 3 | data/ 4 | OSU/ 5 | .mypy_cache 6 | .vscode 7 | *.json 8 | *.wav 9 | *.mp3 10 | *.pth.tar 11 | env*/ 12 | 13 | #####=== Python ===##### 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | env/ 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *,cover 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | #####=== OSX ===##### 75 | .DS_Store 76 | .AppleDouble 77 | .LSOverride 78 | 79 | # Icon must end with two \r 80 | Icon 81 | 82 | 83 | # Thumbnails 84 | ._* 85 | 86 | # Files that might appear in the root of a volume 87 | .DocumentRevisions-V100 88 | .fseventsd 89 | .Spotlight-V100 90 | .TemporaryItems 91 | .Trashes 92 | .VolumeIcon.icns 93 | 94 | # Directories potentially created on remote AFP share 95 | .AppleDB 96 | .AppleDesktop 97 | Network Trash Folder 98 | Temporary Items 99 | .apdisk 100 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torchaudio 4 | 5 | from openunmix import data 6 | 7 | 8 | @pytest.fixture(params=["soundfile", "sox_io"]) 9 | def torch_backend(request): 10 | return request.param 11 | 12 | 13 | def test_musdb(): 14 | musdb = data.MUSDBDataset(download=True, samples_per_track=1, seq_duration=1.0) 15 | for x, y in musdb: 16 | assert x.shape[-1] == 44100 17 | 18 | 19 | def test_trackfolder_fix(torch_backend): 20 | torchaudio.set_audio_backend(torch_backend) 21 | 22 | train_dataset = data.FixedSourcesTrackFolderDataset( 23 | split="train", 24 | seq_duration=1.0, 25 | root="TrackfolderDataset", 26 | sample_rate=8000.0, 27 | target_file="1.wav", 28 | interferer_files=["2.wav", "3.wav", "4.wav"], 29 | ) 30 | for x, y in train_dataset: 31 | assert x.shape[-1] == 8000 32 | 33 | 34 | def test_trackfolder_var(torch_backend): 35 | torchaudio.set_audio_backend(torch_backend) 36 | 37 | train_dataset = data.VariableSourcesTrackFolderDataset( 38 | split="train", 39 | seq_duration=1.0, 40 | root="TrackfolderDataset", 41 | sample_rate=8000.0, 42 | target_file="1.wav", 43 | ) 44 | for x, y in train_dataset: 45 | assert x.shape[-1] == 8000 46 | 47 | 48 | def test_sourcefolder(torch_backend): 49 | torchaudio.set_audio_backend(torch_backend) 50 | 51 | train_dataset = data.SourceFolderDataset( 52 | split="train", 53 | seq_duration=1.0, 54 | root="TrackfolderDataset", 55 | sample_rate=8000.0, 56 | target_dir="1", 57 | interferer_dirs=["2", "3"], 58 | ext=".wav", 59 | nb_samples=20, 60 | ) 61 | for k in range(len(train_dataset)): 62 | x, y = train_dataset[k] 63 | assert x.shape[-1] == 8000 64 | -------------------------------------------------------------------------------- /tests/test_wiener.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | 5 | from openunmix import model 6 | from openunmix.filtering import wiener 7 | 8 | 9 | @pytest.fixture(params=[10, 100]) 10 | def nb_frames(request): 11 | return int(request.param) 12 | 13 | 14 | @pytest.fixture(params=[1, 2]) 15 | def nb_channels(request): 16 | return request.param 17 | 18 | 19 | @pytest.fixture(params=[10, 127]) 20 | def nb_bins(request): 21 | return request.param 22 | 23 | 24 | @pytest.fixture(params=[1, 2, 3]) 25 | def nb_sources(request): 26 | return request.param 27 | 28 | 29 | @pytest.fixture(params=[0, 1, 2]) 30 | def 
iterations(request): 31 | return request.param 32 | 33 | 34 | @pytest.fixture(params=[True, False]) 35 | def softmask(request): 36 | return request.param 37 | 38 | 39 | @pytest.fixture(params=[True, False]) 40 | def residual(request): 41 | return request.param 42 | 43 | 44 | @pytest.fixture 45 | def target(request, nb_frames, nb_channels, nb_bins, nb_sources): 46 | return torch.rand((nb_frames, nb_bins, nb_channels, nb_sources)) 47 | 48 | 49 | @pytest.fixture 50 | def mix(request, nb_frames, nb_channels, nb_bins): 51 | return torch.rand((nb_frames, nb_bins, nb_channels, 2)) 52 | 53 | 54 | @pytest.fixture(params=[torch.float32, torch.float64]) 55 | def dtype(request): 56 | return request.param 57 | 58 | 59 | def test_wiener(target, mix, iterations, softmask, residual): 60 | output = wiener(target, mix, iterations=iterations, softmask=softmask, residual=residual) 61 | # nb_frames, nb_bins, nb_channels, 2, nb_sources 62 | assert output.shape[:3] == mix.shape[:3] 63 | assert output.shape[3] == 2 64 | if residual: 65 | assert output.shape[4] == target.shape[3] + 1 66 | else: 67 | assert output.shape[4] == target.shape[3] 68 | 69 | 70 | def test_dtype(target, mix, dtype): 71 | output = wiener(target.to(dtype=dtype), mix.to(dtype=dtype), iterations=1) 72 | assert output.dtype == dtype 73 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "openunmix" 7 | authors = [ 8 | {name = "Fabian-Robert Stöter", email = "mail@faroit.com"}, 9 | {name = "Antoine Liutkus", email = "antoine.liutkus@inria.fr"}, 10 | ] 11 | version = "1.3.0" 12 | description = "PyTorch-based music source separation toolkit" 13 | readme = "README.md" 14 | license = { text = "MIT" } 15 | requires-python = ">=3.9" 16 | classifiers = [ 17 | "Development Status :: 5 - Production/Stable", 18 | "Environment :: Console", 19 | "Intended Audience :: Developers", 20 | "License :: OSI Approved :: MIT License", 21 | "Operating System :: OS Independent", 22 | "Programming Language :: Python", 23 | "Programming Language :: Python :: 3 :: Only", 24 | "Programming Language :: Python :: 3.8", 25 | "Programming Language :: Python :: 3.9", 26 | "Programming Language :: Python :: 3.10", 27 | "Programming Language :: Python :: 3.11", 28 | "Programming Language :: Python :: 3.12", 29 | "Topic :: Software Development :: Libraries :: Python Modules", 30 | "Topic :: Software Development :: Quality Assurance", 31 | ] 32 | dependencies = [ 33 | "numpy", 34 | "torchaudio>=0.9.0", 35 | "torch>=1.9.0", 36 | "tqdm", 37 | ] 38 | 39 | [project.optional-dependencies] 40 | asteroid = ["asteroid-filterbanks>=0.3.2"] 41 | stempeg = ["stempeg"] 42 | evaluation = ["musdb>=0.4.0", "museval>=0.4.0"] 43 | tests = [ 44 | "pytest", 45 | "musdb>=0.4.0", 46 | "museval>=0.4.0", 47 | "stempeg", 48 | "asteroid-filterbanks>=0.3.2", 49 | "onnx", 50 | "tqdm", 51 | ] 52 | 53 | [project.scripts] 54 | umx = "openunmix.cli:separate" 55 | 56 | [project.urls] 57 | Homepage = "https://github.com/sigsep/open-unmix-pytorch" 58 | 59 | [tool.black] 60 | line-length = 120 61 | target-version = ['py39'] 62 | include = '\.pyi?$' 63 | exclude = ''' 64 | ( 65 | /( 66 | \.git 67 | | \.hg 68 | | \.mypy_cache 69 | | \.tox 70 | | \.venv 71 | | _build 72 | | buck-out 73 | | build 74 | | dist 75 | | \.idea 76 | | \.vscode 77 | | scripts 78 | | 
notebooks 79 | | \.eggs 80 | )/ 81 | ) 82 | ''' 83 | 84 | [tool.setuptools.packages.find] 85 | include = ["openunmix"] 86 | 87 | [tool.setuptools.package-data] 88 | openunmix = ["*.txt", "*.rst", "*.json", "*.wav", "*.pt"] -------------------------------------------------------------------------------- /tests/test_regression.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import musdb 4 | import simplejson as json 5 | import numpy as np 6 | import torch 7 | 8 | 9 | from openunmix import model 10 | from openunmix import evaluate 11 | from openunmix import utils 12 | from openunmix import transforms 13 | 14 | 15 | test_track = "Al James - Schoolboy Facination" 16 | 17 | json_path = os.path.join( 18 | os.path.dirname(os.path.realpath(__file__)), 19 | "data/%s.json" % test_track, 20 | ) 21 | 22 | spec_path = os.path.join( 23 | os.path.dirname(os.path.realpath(__file__)), 24 | "data/%s.spectrogram.pt" % test_track, 25 | ) 26 | 27 | 28 | @pytest.fixture(params=["torch", "asteroid"]) 29 | def method(request): 30 | return request.param 31 | 32 | 33 | @pytest.fixture() 34 | def mus(): 35 | return musdb.DB(download=True) 36 | 37 | 38 | def test_estimate_and_evaluate(mus): 39 | # return any number of targets 40 | with open(json_path) as json_file: 41 | ref = json.loads(json_file.read()) 42 | 43 | track = [track for track in mus.tracks if track.name == test_track][0] 44 | 45 | scores = evaluate.separate_and_evaluate( 46 | track, 47 | targets=["vocals", "drums", "bass", "other"], 48 | model_str_or_path="umx", 49 | niter=1, 50 | residual=None, 51 | mus=mus, 52 | aggregate_dict=None, 53 | output_dir=None, 54 | eval_dir=None, 55 | device="cpu", 56 | wiener_win_len=None, 57 | ) 58 | 59 | assert scores.validate() is None 60 | 61 | with open(os.path.join(".", track.name) + ".json", "w+") as f: 62 | f.write(scores.json) 63 | 64 | scores = json.loads(scores.json) 65 | 66 | for target in ref["targets"]: 67 | for metric in ["SDR", "SIR", "SAR", "ISR"]: 68 | 69 | ref = np.array([d["metrics"][metric] for d in target["frames"]]) 70 | idx = [t["name"] for t in scores["targets"]].index(target["name"]) 71 | est = np.array([d["metrics"][metric] for d in scores["targets"][idx]["frames"]]) 72 | 73 | assert np.allclose(ref, est, atol=1e-01) 74 | 75 | 76 | def test_spectrogram(mus, method): 77 | """Regression test for spectrogram transform 78 | 79 | Loads pre-computed transform and compare to current spectrogram 80 | e.g. this makes sure that the training is reproducible if parameters 81 | such as STFT centering would be subject to change. 
82 | """ 83 | track = [track for track in mus.tracks if track.name == test_track][0] 84 | 85 | stft, _ = transforms.make_filterbanks(n_fft=4096, n_hop=1024, sample_rate=track.rate, method=method) 86 | encoder = torch.nn.Sequential(stft, model.ComplexNorm(mono=False)) 87 | audio = torch.as_tensor(track.audio, dtype=torch.float32, device="cpu") 88 | audio = utils.preprocess(audio, track.rate, track.rate) 89 | ref = torch.load(spec_path) 90 | dut = encoder(audio).permute(3, 0, 1, 2) 91 | 92 | assert torch.allclose(ref, dut, atol=1e-1) 93 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | # Frequently Asked Questions 2 | 3 | ## Separating tracks crashes because it used too much memory 4 | 5 | First, separating an audio track into four separation models `vocals`, `drums`, `bass` and `other` is requires a significant amount of RAM to load all four separate models. 6 | Furthermore, another computationally important step in the separation is the post-processing, controlled by the parameter `niter`. 7 | For faster and less memory intensive inference (at the expense of separation quality) it is advised to use `niter 0`. 8 | Another way to improve performance is to apply separation on smaller excerpts using the `start` and `duration`, arguments. We suggest to only perform separation of ~30s stereo excerpts on machines with less 8GB of memory. 9 | 10 | ## Why is the training so slow? 11 | 12 | In the default configuration using the stems dataset, yielding a single batch from the dataset is very slow. This is a known issue of decoding mp4 stems since native decoders for pytorch or numpy are not available. 13 | 14 | There are two ways to speed up the training: 15 | 16 | ### 1. Increase the number of workers 17 | 18 | The default configuration does not use multiprocessing to yield the batches. You can increase the number of workers using the `--nb-workers k` configuration. E.g. `k=8` workers batch loading can get down to 1 batch per second. 19 | 20 | ### 2. Use WAV instead of MP4 21 | 22 | Convert the MUSDB18 dataset to wav using the builtin `musdb` cli tool 23 | 24 | ``` 25 | musdbconvert path/to/musdb-stems-root path/to/new/musdb-wav-root 26 | ``` 27 | 28 | or alternatively use the [MUSDB18-HQ](https://zenodo.org/record/3338373) dataset that is already stored and distributed as WAV files. Note that __if you want to compare to SiSEC 2018 participants, you should use the standard (Stems) MUSDB18 dataset and decode it to WAV, instead.__ 29 | 30 | Training on wav files can be launched using the `--is-wav` flag: 31 | 32 | ``` 33 | python scripts/train.py --root path/to/musdb18-wav --is-wav --target vocals 34 | ``` 35 | 36 | This will get you down to 0.6s per batch on 4 workers, likely hitting the bandwidth of standard hard-drives. It can be further improved using an SSD, which brings it down to 0.4s per batch on a GTX1080Ti which this leads to 95% GPU utilization. thus data-loading will not be the bottleneck anymore. 37 | 38 | ## Can I use the pre-trained models without torchhub? 39 | 40 | for some reason the torchub automatic download might not work and you want to download the files offline and use them. For that you can download [umx](https://zenodo.org/record/3340804) or [umxhq](https://zenodo.org/record/3267291) from Zenodo and create a local folder of your choice (e.g. 
`umx-weights`) where the model is stored in a flat file hierarchy: 41 | 42 | ``` 43 | umx-weights/vocals-*.pth 44 | umx-weights/drums-*.pth 45 | umx-weights/bass-*.pth 46 | umx-weights/other-*.pth 47 | umx-weights/vocals.json 48 | umx-weights/drums.json 49 | umx-weights/bass.json 50 | umx-weights/other.json 51 | umx-weights/separator.json 52 | ``` 53 | 54 | Test and eval can then be started using: 55 | 56 | ```bash 57 | umx --model umx-weights --input test.wav 58 | ``` 59 | -------------------------------------------------------------------------------- /pdoc/config.mako: -------------------------------------------------------------------------------- 1 | <%! 2 | # Template configuration. Copy over in your template directory 3 | # (used with `--template-dir`) and adapt as necessary. 4 | # Note, defaults are loaded from this distribution file, so your 5 | # config.mako only needs to contain values you want overridden. 6 | # You can also run pdoc with `--config KEY=VALUE` to override 7 | # individual values. 8 | html_lang = 'en' 9 | show_inherited_members = False 10 | extract_module_toc_into_sidebar = True 11 | list_class_variables_in_index = True 12 | sort_identifiers = True 13 | show_type_annotations = True 14 | # Show collapsed source code block next to each item. 15 | # Disabling this can improve rendering speed of large modules. 16 | show_source_code = True 17 | # If set, format links to objects in online source code repository 18 | # according to this template. Supported keywords for interpolation 19 | # are: commit, path, start_line, end_line. 20 | git_link_template = 'https://github.com/sigsep/open-unmix-pytorch/blob/{commit}/{path}#L{start_line}-L{end_line}' 21 | #git_link_template = 'https://gitlab.com/USER/PROJECT/blob/{commit}/{path}#L{start_line}-L{end_line}' 22 | #git_link_template = 'https://bitbucket.org/USER/PROJECT/src/{commit}/{path}#lines-{start_line}:{end_line}' 23 | #git_link_template = 'https://CGIT_HOSTNAME/PROJECT/tree/{path}?id={commit}#n{start-line}' 24 | # A prefix to use for every HTML hyperlink in the generated documentation. 25 | # No prefix results in all links being relative. 26 | link_prefix = '' 27 | # Enable syntax highlighting for code/source blocks by including Highlight.js 28 | syntax_highlighting = True 29 | # Set the style keyword such as 'atom-one-light' or 'github-gist' 30 | # Options: https://github.com/highlightjs/highlight.js/tree/master/src/styles 31 | # Demo: https://highlightjs.org/static/demo/ 32 | hljs_style = 'github' 33 | # If set, insert Google Analytics tracking code. Value is GA 34 | # tracking id (UA-XXXXXX-Y). 35 | google_analytics = '' 36 | # If set, insert Google Custom Search search bar widget above the sidebar index. 37 | # The whitespace-separated tokens represent arbitrary extra queries (at least one 38 | # must match) passed to regular Google search. Example: 39 | #google_search_query = 'inurl:github.com/USER/PROJECT site:PROJECT.github.io site:PROJECT.website' 40 | google_search_query = '' 41 | # Enable offline search using Lunr.js. For explanation of 'fuzziness' parameter, which is 42 | # added to every query word, see: https://lunrjs.com/guides/searching.html#fuzzy-matches 43 | # If 'index_docstrings' is False, a shorter index is built, indexing only 44 | # the full object reference names. 45 | #lunr_search = {'fuzziness': 1, 'index_docstrings': True} 46 | lunr_search = None 47 | # If set, render LaTeX math syntax within \(...\) (inline equations), 48 | # or within \[...\] or $$...$$ or `.. 
math::` (block equations)
49 | # as nicely-formatted math formulas using MathJax.
50 | # Note: in Python docstrings, either all backslashes need to be escaped (\\)
51 | # or you need to use raw r-strings.
52 | latex_math = True
53 | %>
-------------------------------------------------------------------------------- /openunmix/predict.py: --------------------------------------------------------------------------------
1 | from openunmix import utils
2 |
3 |
4 | def separate(
5 | audio,
6 | rate=None,
7 | model_str_or_path="umxl",
8 | targets=None,
9 | niter=1,
10 | residual=False,
11 | wiener_win_len=300,
12 | aggregate_dict=None,
13 | separator=None,
14 | device=None,
15 | filterbank="torch",
16 | ):
17 | """
18 | Open Unmix functional interface
19 |
20 | Separates a torch.Tensor or the content of an audio file.
21 |
22 | If a separator is provided, use it for inference. If not, create one
23 | and use it afterwards.
24 |
25 | Args:
26 | audio: audio to process
27 | torch Tensor: shape (channels, length), and
28 | `rate` must also be provided.
29 | rate: int or None: only used if audio is a Tensor. Otherwise,
30 | inferred from the file.
31 | model_str_or_path: the pretrained model to use, defaults to UMX-L
32 | targets (str): select the targets for the source to be separated.
33 | a list including: ['vocals', 'drums', 'bass', 'other'].
34 | If you don't pick them all, you probably want to
35 | activate the `residual=True` option.
36 | Defaults to all available targets per model.
37 | niter (int): the number of post-processing iterations, defaults to 1
38 | residual (bool): if True, a "garbage" target is created
39 | wiener_win_len (int): the number of frames to use when batching
40 | the post-processing step
41 | aggregate_dict (str): if provided, must be a string containing a '
42 | 'valid expression for a dictionary, with keys as output '
43 | 'target names, and values a list of targets that are used to '
44 | 'build it. For instance: \'{\"vocals\":[\"vocals\"], '
45 | '\"accompaniment\":[\"drums\",\"bass\",\"other\"]}\'
46 | separator: if provided, the model.Separator object that will be used
47 | to perform separation
48 | device (str): selects device to be used for inference
49 | filterbank (str): filterbank implementation method.
50 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster
51 | compared to `asteroid` on large FFT sizes such as 4096. However,
52 | asteroid's STFT can be exported to ONNX, which makes it practical
53 | for deployment.
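Example:
    A minimal usage sketch (added for illustration; "mixture.wav" is a
    placeholder path, and loading via torchaudio is an assumption rather
    than a requirement of this function):

        >>> import torchaudio
        >>> from openunmix import predict
        >>> audio, rate = torchaudio.load("mixture.wav")  # shape (channels, length)
        >>> estimates = predict.separate(audio, rate=rate, niter=1)
        >>> # `estimates` maps target names (e.g. "vocals") to waveform tensors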
54 | """ 55 | if separator is None: 56 | separator = utils.load_separator( 57 | model_str_or_path=model_str_or_path, 58 | targets=targets, 59 | niter=niter, 60 | residual=residual, 61 | wiener_win_len=wiener_win_len, 62 | device=device, 63 | pretrained=True, 64 | filterbank=filterbank, 65 | ) 66 | separator.freeze() 67 | if device: 68 | separator.to(device) 69 | 70 | if rate is None: 71 | raise Exception("rate` must be provided.") 72 | 73 | if device: 74 | audio = audio.to(device) 75 | audio = utils.preprocess(audio, rate, separator.sample_rate) 76 | 77 | # getting the separated signals 78 | estimates = separator(audio) 79 | estimates = separator.to_dict(estimates, aggregate_dict=aggregate_dict) 80 | return estimates 81 | -------------------------------------------------------------------------------- /docs/inference.md: -------------------------------------------------------------------------------- 1 | # Performing separation 2 | 3 | ## Interfacing using the command line 4 | 5 | The primary interface to separate files is the command line. To separate a mixture file into the four stems you can just run 6 | 7 | ```bash 8 | umx input_file.wav 9 | ``` 10 | 11 | Note that we support all files that can be read by torchaudio, depending on the set backend (either `soundfile` (libsndfile) or `sox`). 12 | For training, we set the default to `soundfile` as it is faster than `sox`. However for inference users might prefer `mp3` decoding capabilities. 13 | The separation can be controlled with additional parameters that influence the performance of the separation. 14 | 15 | | Command line Argument | Description | Default | 16 | |----------------------------|---------------------------------------------------------------------------------|-----------------| 17 | |`--start ` | set start in seconds to reduce the duration of the audio being loaded | `0.0` | 18 | |`--duration ` | set duration in seconds to reduce length of the audio being loaded. Negative values will make the full audio being loaded | `-1.0` | 19 | |`--model ` | path or string of model name to select either a self pre-trained model or a model loaded from `torchhub`. | | 20 | | `--targets list(str)` | Targets to be used for separation. For each target a model file with with same name is required. | `['vocals', 'drums', 'bass', 'other']` | 21 | | `--niter ` | Number of EM steps for refining initial estimates in a post-processing stage. `--niter 0` skips this step altogether (and thus makes separation significantly faster) More iterations can get better interference reduction at the price of more artifacts. | `1` | 22 | | `--residual` | computes a residual target, for custom separation scenarios when not all targets are available (at the expense of slightly less performance). E.g vocal/accompaniment can be performed with `--targets vocals --residual`. | not set | 23 | | `--softmask` | if activated, then the initial estimates for the sources will be obtained through a ratio mask of the mixture STFT, and not by using the default behavior of reconstructing waveforms by using the mixture phase. 
| not set |
24 | | `--wiener-win-len <int>` | Number of frames on which to apply filtering independently | `300` |
25 | | `--audio-backend <str>` | choose audio loading backend, either `sox_io`, `soundfile` or `stempeg` (which needs additional installation requirements) | [torchaudio default](https://pytorch.org/audio/stable/backend.html) |
26 | | `--aggregate <str>` | if provided, must be a string containing a valid expression for a dictionary, with keys as output target names, and values a list of targets that are used to build it. For instance: `{ "vocals": ["vocals"], "accompaniment": ["drums", "bass", "other"]}` | `None` |
27 | | `--filterbank <str>` | filterbank implementation method. Supported: `['torch', 'asteroid']`. While `torch` is ~30% faster compared to `asteroid` on large FFT sizes such as 4096, asteroid's STFT may be easier to export for deployment. | `torch` |
28 |
29 | ## Interfacing from python
30 |
31 | At the core of the separation process is the `Separator` module, which
32 | takes a numpy audio array or a `torch.Tensor` as input (the mixture) and separates it into `targets` stems.
33 | Note that for each target a separate model will be loaded. E.g. for `umx` and `umxhq` the supported targets are
34 | `['vocals', 'drums', 'bass', 'other']`. The models have to be passed to the separator's `target_models` parameter.
35 |
36 | The models `umx`, `umxhq`, `umxl` and `umxse` are all downloaded automatically.
37 |
38 | The constructor of the `Separator` takes the following arguments, with suggested default values:
39 |
40 | ```python
41 | separator = openunmix.Separator(
42 | target_models: dict,
43 | niter: int = 0,
44 | softmask: bool = False,
45 | residual: bool = False,
46 | sample_rate: float = 44100.0,
47 | n_fft: int = 4096,
48 | n_hop: int = 1024,
49 | nb_channels: int = 2,
50 | wiener_win_len: Optional[int] = 300,
51 | filterbank: str = 'torch'
52 | ):
53 | ```
54 |
55 | A complete usage sketch is shown at the end of this page.
56 |
57 | > __Caution__ `training` using the EM algorithm (`niter>0`) is not supported. Only plain post-processing is supported right now for gradient computation. This is because the performance overhead of avoiding all the in-place operations is too large.
58 |
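Below is a minimal end-to-end sketch of the high-level interface. It is illustrative only: the random tensor stands in for a real stereo mixture sampled at 44.1 kHz, and `umxl` is used as an example (`umx`, `umxhq` and `umxse` work the same way).

```python
import torch
from openunmix import umxl

# build a Separator with all pre-trained UMX-L target models
# (weights are downloaded automatically on first use)
separator = umxl(niter=1, pretrained=True)
separator.freeze()  # inference only

# stand-in for a real mixture: (nb_samples, nb_channels, nb_timesteps) at 44.1 kHz
audio = torch.rand((1, 2, 44100))

estimates = separator(audio)  # (nb_samples, nb_targets, nb_channels, nb_timesteps)
estimates = separator.to_dict(estimates, aggregate_dict=None)  # e.g. estimates["vocals"]
```

Setting `niter=0` skips the EM post-processing, which is faster at the cost of some separation quality.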
-------------------------------------------------------------------------------- /CONTRIBUTING.md: --------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | Open-Unmix is designed as scientific software. Therefore, we encourage the community to submit bug-fixes and comments and improve the __computational performance__, __reproducibility__ and the __readability__ of the code where possible. When contributing to this repository, please first discuss the change you wish to make in the issue tracker with the owners of this repository before making a change.
4 |
5 | We are not looking for contributions that only focus on improving the __separation performance__. However, if this is the case, we instead encourage researchers to
6 |
7 | 1. Use Open-Unmix for their own research, e.g. by modification of the model.
8 | 2. Publish and present the results in a scientific paper / conference and __cite open-unmix__.
9 | 3. Contact us via mail or open a performance issue if you are interested in contributing the new model.
10 | In this case we will rerun the training on our internal cluster and update the pre-trained weights accordingly.
11 |
12 | Please note that we have a code of conduct; please follow it in all your interactions with the project.
13 |
14 | ## Pull Request Process
15 |
16 | The preferred way to contribute to open-unmix is to fork the
17 | [main repository](http://github.com/sigsep/open-unmix-pytorch/) on
18 | GitHub:
19 |
20 | 1. Fork the [project repository](http://github.com/sigsep/open-unmix-pytorch):
21 | click on the 'Fork' button near the top of the page. This creates
22 | a copy of the code under your account on the GitHub server.
23 |
24 | 2. Clone this copy to your local disk:
25 |
26 | ```
27 | $ git clone git@github.com:YourLogin/open-unmix-pytorch.git
28 | $ cd open-unmix-pytorch
29 | ```
30 |
31 | 3. Create a branch to hold your changes:
32 |
33 | ```
34 | $ git checkout -b my-feature
35 | ```
36 |
37 | and start making changes. Never work in the ``master`` branch!
38 |
39 | 4. Ensure any install or build artifacts are removed before making the pull request.
40 |
41 | 5. Update the README.md and/or the appropriate document in the `/docs` folder with details of changes to the interface; this includes new command line arguments, dataset descriptions or command line examples.
42 |
43 | 6. Work on this copy on your computer using Git to do the version
44 | control. When you're done editing, do:
45 |
46 | ```
47 | $ git add modified_files
48 | $ git commit
49 | ```
50 |
51 | to record your changes in Git, then push them to GitHub with:
52 |
53 | ```
54 | $ git push -u origin my-feature
55 | ```
56 |
57 | Finally, go to the web page of your fork of the open-unmix repo,
58 | and click 'Pull request' to send your changes to the maintainers for
59 | review. This will send an email to the committers.
60 |
61 | (If any of the above seems like magic to you, then look up the
62 | [Git documentation](http://git-scm.com/documentation) on the web.)
63 |
64 | ## Code of Conduct
65 |
66 | ### Our Pledge
67 |
68 | In the interest of fostering an open and welcoming environment, we as
69 | contributors and maintainers pledge to making participation in our project and
70 | our community a harassment-free experience for everyone, regardless of age, body
71 | size, disability, ethnicity, gender identity and expression, level of experience,
72 | nationality, personal appearance, race, religion, or sexual identity and
73 | orientation.
74 |
75 | ### Our Standards
76 |
77 | Examples of behavior that contributes to creating a positive environment
78 | include:
79 |
80 | * Using welcoming and inclusive language
81 | * Being respectful of differing viewpoints and experiences
82 | * Gracefully accepting constructive criticism
83 | * Focusing on what is best for the community
84 | * Showing empathy towards other community members
85 |
86 | Examples of unacceptable behavior by participants include:
87 |
88 | * The use of sexualized language or imagery and unwelcome sexual attention or
89 | advances
90 | * Trolling, insulting/derogatory comments, and personal or political attacks
91 | * Public or private harassment
92 | * Publishing others' private information, such as a physical or electronic
93 | address, without explicit permission
94 | * Other conduct which could reasonably be considered inappropriate in a
95 | professional setting
96 |
97 | ### Our Responsibilities
98 |
99 | Project maintainers are responsible for clarifying the standards of acceptable
100 | behavior and are expected to take appropriate and fair corrective action in
101 | response to any instances of unacceptable behavior.
102 | 103 | Project maintainers have the right and responsibility to remove, edit, or 104 | reject comments, commits, code, wiki edits, issues, and other contributions 105 | that are not aligned to this Code of Conduct, or to ban temporarily or 106 | permanently any contributor for other behaviors that they deem inappropriate, 107 | threatening, offensive, or harmful. 108 | 109 | ### Scope 110 | 111 | This Code of Conduct applies both within project spaces and in public spaces 112 | when an individual is representing the project or its community. Examples of 113 | representing a project or community include using an official project e-mail 114 | address, posting via an official social media account, or acting as an appointed 115 | representative at an online or offline event. Representation of a project may be 116 | further defined and clarified by project maintainers. 117 | 118 | ### Enforcement 119 | 120 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 121 | reported by contacting the project team @aliutkus, @faroit. All 122 | complaints will be reviewed and investigated and will result in a response that 123 | is deemed necessary and appropriate to the circumstances. The project team is 124 | obligated to maintain confidentiality with regard to the reporter of an incident. 125 | Further details of specific enforcement policies may be posted separately. 126 | 127 | Project maintainers who do not follow or enforce the Code of Conduct in good 128 | faith may face temporary or permanent repercussions as determined by other 129 | members of the project's leadership. 130 | 131 | ### Attribution 132 | 133 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 134 | available at [http://contributor-covenant.org/version/1/4][version] 135 | 136 | [homepage]: http://contributor-covenant.org 137 | [version]: http://contributor-covenant.org/version/1/4/ -------------------------------------------------------------------------------- /openunmix/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import json 4 | import multiprocessing 5 | from typing import Optional, Union 6 | 7 | import musdb 8 | import museval 9 | import torch 10 | import tqdm 11 | 12 | from openunmix import utils 13 | 14 | 15 | def separate_and_evaluate( 16 | track: musdb.MultiTrack, 17 | targets: list, 18 | model_str_or_path: str, 19 | niter: int, 20 | output_dir: str, 21 | eval_dir: str, 22 | residual: bool, 23 | mus, 24 | aggregate_dict: dict = None, 25 | device: Union[str, torch.device] = "cpu", 26 | wiener_win_len: Optional[int] = None, 27 | filterbank="torch", 28 | ) -> str: 29 | 30 | separator = utils.load_separator( 31 | model_str_or_path=model_str_or_path, 32 | targets=targets, 33 | niter=niter, 34 | residual=residual, 35 | wiener_win_len=wiener_win_len, 36 | device=device, 37 | pretrained=True, 38 | filterbank=filterbank, 39 | ) 40 | 41 | separator.freeze() 42 | separator.to(device) 43 | 44 | audio = torch.as_tensor(track.audio, dtype=torch.float32, device=device) 45 | audio = utils.preprocess(audio, track.rate, separator.sample_rate) 46 | 47 | estimates = separator(audio) 48 | estimates = separator.to_dict(estimates, aggregate_dict=aggregate_dict) 49 | 50 | for key in estimates: 51 | estimates[key] = estimates[key][0].cpu().detach().numpy().T 52 | if output_dir: 53 | mus.save_estimates(estimates, track, output_dir) 54 | 55 | scores = museval.eval_mus_track(track, estimates, 
output_dir=eval_dir)
56 |     return scores
57 | 
58 | 
59 | if __name__ == "__main__":
60 |     # Evaluation settings
61 |     parser = argparse.ArgumentParser(description="MUSDB18 Evaluation", add_help=False)
62 | 
63 |     parser.add_argument(
64 |         "--targets",
65 |         nargs="+",
66 |         default=["vocals", "drums", "bass", "other"],
67 |         type=str,
68 |         help="provide targets to be processed. \
69 |             If none, all available targets will be computed",
70 |     )
71 | 
72 |     parser.add_argument(
73 |         "--model",
74 |         default="umxl",
75 |         type=str,
76 |         help="path to model base directory of pretrained models",
77 |     )
78 | 
79 |     parser.add_argument(
80 |         "--outdir",
81 |         type=str,
82 |         help="Results path where the separated audio estimates are stored",
83 |     )
84 | 
85 |     parser.add_argument("--evaldir", type=str, help="Results path where the museval evaluation scores are stored")
86 | 
87 |     parser.add_argument("--root", type=str, help="Path to MUSDB18")
88 | 
89 |     parser.add_argument("--subset", type=str, default="test", help="MUSDB subset (`train`/`test`)")
90 | 
91 |     parser.add_argument("--cores", type=int, default=1)
92 | 
93 |     parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA inference")
94 | 
95 |     parser.add_argument(
96 |         "--is-wav",
97 |         action="store_true",
98 |         default=False,
99 |         help="flags wav version of the dataset",
100 |     )
101 | 
102 |     parser.add_argument(
103 |         "--niter",
104 |         type=int,
105 |         default=1,
106 |         help="number of iterations for refining results.",
107 |     )
108 | 
109 |     parser.add_argument(
110 |         "--wiener-win-len",
111 |         type=int,
112 |         default=300,
113 |         help="Number of frames on which to apply filtering independently",
114 |     )
115 | 
116 |     parser.add_argument(
117 |         "--residual",
118 |         type=str,
119 |         default=None,
120 |         help="if provided, build a source with given name " "for the mix minus all estimated targets",
121 |     )
122 | 
123 |     parser.add_argument(
124 |         "--aggregate",
125 |         type=str,
126 |         default=None,
127 |         help="if provided, must be a string containing a valid expression for "
128 |         "a dictionary, with keys as output target names, and values "
129 |         "a list of targets that are used to build it. 
For instance: " 130 | '\'{"vocals":["vocals"], "accompaniment":["drums",' 131 | '"bass","other"]}\'', 132 | ) 133 | 134 | args = parser.parse_args() 135 | 136 | use_cuda = not args.no_cuda and torch.cuda.is_available() 137 | device = torch.device("cuda" if use_cuda else "cpu") 138 | 139 | mus = musdb.DB( 140 | root=args.root, 141 | download=args.root is None, 142 | subsets=args.subset, 143 | is_wav=args.is_wav, 144 | ) 145 | aggregate_dict = None if args.aggregate is None else json.loads(args.aggregate) 146 | 147 | if args.cores > 1: 148 | pool = multiprocessing.Pool(args.cores) 149 | results = museval.EvalStore() 150 | scores_list = list( 151 | pool.imap_unordered( 152 | func=functools.partial( 153 | separate_and_evaluate, 154 | targets=args.targets, 155 | model_str_or_path=args.model, 156 | niter=args.niter, 157 | residual=args.residual, 158 | mus=mus, 159 | aggregate_dict=aggregate_dict, 160 | output_dir=args.outdir, 161 | eval_dir=args.evaldir, 162 | device=device, 163 | ), 164 | iterable=mus.tracks, 165 | chunksize=1, 166 | ) 167 | ) 168 | pool.close() 169 | pool.join() 170 | for scores in scores_list: 171 | results.add_track(scores) 172 | 173 | else: 174 | results = museval.EvalStore() 175 | for track in tqdm.tqdm(mus.tracks): 176 | scores = separate_and_evaluate( 177 | track, 178 | targets=args.targets, 179 | model_str_or_path=args.model, 180 | niter=args.niter, 181 | residual=args.residual, 182 | mus=mus, 183 | aggregate_dict=aggregate_dict, 184 | output_dir=args.outdir, 185 | eval_dir=args.evaldir, 186 | device=device, 187 | ) 188 | print(track, "\n", scores) 189 | results.add_track(scores) 190 | 191 | print(results) 192 | method = museval.MethodStore() 193 | method.add_evalstore(results, args.model) 194 | method.save(args.model + ".pandas") 195 | -------------------------------------------------------------------------------- /tests/data/Al James - Schoolboy Facination.json: -------------------------------------------------------------------------------- 1 | { 2 | "targets": [ 3 | { 4 | "name": "vocals", 5 | "frames": [ 6 | { 7 | "time": 0.0, 8 | "duration": 1.0, 9 | "metrics": { 10 | "SDR": 4.32297, 11 | "SIR": 10.12637, 12 | "SAR": 7.59260, 13 | "ISR": 7.33622 14 | } 15 | }, 16 | { 17 | "time": 1.0, 18 | "duration": 1.0, 19 | "metrics": { 20 | "SDR": 7.22083, 21 | "SIR": 10.20295, 22 | "SAR": 8.57502, 23 | "ISR": 10.96141 24 | } 25 | }, 26 | { 27 | "time": 2.0, 28 | "duration": 1.0, 29 | "metrics": { 30 | "SDR": 7.80233, 31 | "SIR": 13.23437, 32 | "SAR": 10.15870, 33 | "ISR": 10.97836 34 | } 35 | }, 36 | { 37 | "time": 3.0, 38 | "duration": 1.0, 39 | "metrics": { 40 | "SDR": 8.56464, 41 | "SIR": 15.05897, 42 | "SAR": 11.07534, 43 | "ISR": 11.29683 44 | } 45 | }, 46 | { 47 | "time": 4.0, 48 | "duration": 1.0, 49 | "metrics": { 50 | "SDR": 7.42655, 51 | "SIR": 14.16498, 52 | "SAR": 10.40192, 53 | "ISR": 11.35024 54 | } 55 | }, 56 | { 57 | "time": 5.0, 58 | "duration": 1.0, 59 | "metrics": { 60 | "SDR": 6.66938, 61 | "SIR": 11.71322, 62 | "SAR": 8.97930, 63 | "ISR": 10.42365 64 | } 65 | } 66 | ] 67 | }, 68 | { 69 | "name": "drums", 70 | "frames": [ 71 | { 72 | "time": 0.0, 73 | "duration": 1.0, 74 | "metrics": { 75 | "SDR": 3.98718, 76 | "SIR": 5.23581, 77 | "SAR": 3.66383, 78 | "ISR": 8.04258 79 | } 80 | }, 81 | { 82 | "time": 1.0, 83 | "duration": 1.0, 84 | "metrics": { 85 | "SDR": 3.10503, 86 | "SIR": 2.47233, 87 | "SAR": 2.91512, 88 | "ISR": 7.29689 89 | } 90 | }, 91 | { 92 | "time": 2.0, 93 | "duration": 1.0, 94 | "metrics": { 95 | "SDR": 4.06987, 96 | "SIR": 6.64704, 97 | 
"SAR": 5.02986, 98 | "ISR": 6.72215 99 | } 100 | }, 101 | { 102 | "time": 3.0, 103 | "duration": 1.0, 104 | "metrics": { 105 | "SDR": 3.67143, 106 | "SIR": 4.86323, 107 | "SAR": 3.75070, 108 | "ISR": 7.21293 109 | } 110 | }, 111 | { 112 | "time": 4.0, 113 | "duration": 1.0, 114 | "metrics": { 115 | "SDR": 3.88528, 116 | "SIR": 2.51644, 117 | "SAR": 3.34250, 118 | "ISR": 8.09842 119 | } 120 | }, 121 | { 122 | "time": 5.0, 123 | "duration": 1.0, 124 | "metrics": { 125 | "SDR": 5.39026, 126 | "SIR": 3.22021, 127 | "SAR": 3.37839, 128 | "ISR": 8.22456 129 | } 130 | } 131 | ] 132 | }, 133 | { 134 | "name": "bass", 135 | "frames": [ 136 | { 137 | "time": 0.0, 138 | "duration": 1.0, 139 | "metrics": { 140 | "SDR": 5.60414, 141 | "SIR": 12.81026, 142 | "SAR": 8.76971, 143 | "ISR": 7.72415 144 | } 145 | }, 146 | { 147 | "time": 1.0, 148 | "duration": 1.0, 149 | "metrics": { 150 | "SDR": 6.79312, 151 | "SIR": 15.50595, 152 | "SAR": 7.36306, 153 | "ISR": 6.43593 154 | } 155 | }, 156 | { 157 | "time": 2.0, 158 | "duration": 1.0, 159 | "metrics": { 160 | "SDR": 5.42226, 161 | "SIR": 14.25947, 162 | "SAR": 7.97667, 163 | "ISR": 5.99902 164 | } 165 | }, 166 | { 167 | "time": 3.0, 168 | "duration": 1.0, 169 | "metrics": { 170 | "SDR": 5.88173, 171 | "SIR": 13.25402, 172 | "SAR": 7.96686, 173 | "ISR": 7.43682 174 | } 175 | }, 176 | { 177 | "time": 4.0, 178 | "duration": 1.0, 179 | "metrics": { 180 | "SDR": 6.81244, 181 | "SIR": 13.69640, 182 | "SAR": 9.50696, 183 | "ISR": 7.39392 184 | } 185 | }, 186 | { 187 | "time": 5.0, 188 | "duration": 1.0, 189 | "metrics": { 190 | "SDR": 4.63153, 191 | "SIR": 10.98876, 192 | "SAR": 6.74840, 193 | "ISR": 6.41703 194 | } 195 | } 196 | ] 197 | }, 198 | { 199 | "name": "other", 200 | "frames": [ 201 | { 202 | "time": 0.0, 203 | "duration": 1.0, 204 | "metrics": { 205 | "SDR": 2.02272, 206 | "SIR": 3.09448, 207 | "SAR": 5.76186, 208 | "ISR": 7.85117 209 | } 210 | }, 211 | { 212 | "time": 1.0, 213 | "duration": 1.0, 214 | "metrics": { 215 | "SDR": 3.25884, 216 | "SIR": 1.92849, 217 | "SAR": 4.78471, 218 | "ISR": 7.66760 219 | } 220 | }, 221 | { 222 | "time": 2.0, 223 | "duration": 1.0, 224 | "metrics": { 225 | "SDR": 1.58550, 226 | "SIR": 1.34803, 227 | "SAR": 5.57206, 228 | "ISR": 7.48901 229 | } 230 | }, 231 | { 232 | "time": 3.0, 233 | "duration": 1.0, 234 | "metrics": { 235 | "SDR": 2.04527, 236 | "SIR": 2.84563, 237 | "SAR": 5.45523, 238 | "ISR": 6.97621 239 | } 240 | }, 241 | { 242 | "time": 4.0, 243 | "duration": 1.0, 244 | "metrics": { 245 | "SDR": 2.03671, 246 | "SIR": 2.73438, 247 | "SAR": 5.29621, 248 | "ISR": 6.24533 249 | } 250 | }, 251 | { 252 | "time": 5.0, 253 | "duration": 1.0, 254 | "metrics": { 255 | "SDR": 3.31991, 256 | "SIR": 3.90981, 257 | "SAR": 4.51333, 258 | "ISR": 5.17112 259 | } 260 | } 261 | ] 262 | } 263 | ], 264 | "museval_version": "0.3.0" 265 | } -------------------------------------------------------------------------------- /docs/extensions.md: -------------------------------------------------------------------------------- 1 | # Extending Open-Unmix 2 | 3 | ![](https://docs.google.com/drawings/d/e/2PACX-1vQ1WgVU4PGeEqTQ26j-2RbwaN9ZPlxabBI5N7mYqOK66VjT96UmT9wAaX1s6u6jDHe0ARfAo9E--lQM/pub?w=1918&h=703) 4 | One of the key aspects of _Open-Unmix_ is that it was made to be easily extensible and thus is a good starting point for new research on music source separation. In fact, the open-unmix training code is based on the [pytorch MNIST example](https://github.com/pytorch/examples/blob/master/mnist/main.py). 
In this document we provide a short overview of ways to extend open-unmix.
5 | 
6 | ## Code Structure
7 | 
8 | * `data.py` includes several torch datasets that can all be used to train _open-unmix_.
9 | * `train.py` includes all code that is necessary to start a training run.
10 | * `model.py` includes the open-unmix torch modules.
11 | * `test.py` includes code to predict/unmix from audio files.
12 | * `eval.py` includes all code to run the objective evaluation using museval on the MUSDB18 dataset.
13 | * `utils.py` includes additional tools like audio loading and metadata loading.
14 | 
15 | ## Provide a custom dataset
16 | 
17 | Users of open-unmix that have their own datasets and cannot fit them into one of our predefined datasets might want to implement or use their own `torch.utils.data.Dataset` for training. Such a modification is straightforward, since all of our datasets are standard torch datasets (see `data.py`).
18 | 
19 | ### Template Dataset
20 | 
21 | In case you want to create your own dataset, we provide a template for the open-unmix API. You can use our efficient torchaudio- or libsndfile-based `load_audio` loaders or just use your own files. Since pytorch (<=1.1) currently uses index-based datasets (instead of iterable-based datasets), the best way to load audio is to assign each index to one audio track. However, there are possible applications where the index is ignored and the `__len__()` method just returns an arbitrary number of samples.
22 | 
23 | ```python
24 | from utils import load_audio, load_info
25 | class TemplateDataset(UnmixDataset):
26 |     """A template dataset class for you to implement custom datasets."""
27 | 
28 |     def __init__(self, root, split='train', sample_rate=44100, seq_dur=None):
29 |         """Initialize the dataset
30 |         """
31 |         self.root = root
32 |         self.tracks = get_tracks(root, split)
33 | 
34 |     def __getitem__(self, index):
35 |         """Returns a time domain audio example
36 |         of shape=(channel, sample)
37 |         """
38 |         path = self.tracks[index]
39 |         x = load_audio(path)
40 |         y = load_audio(path)  # placeholder: load the matching target audio here
41 |         return x, y
42 | 
43 |     def __len__(self):
44 |         """Return the number of audio samples"""
45 |         return len(self.tracks)
46 | ```
47 | 
48 | ## Provide a custom model
49 | 
50 | We think that recurrent models provide the best trade-off between good results, fast training and training flexibility, due to their ability to learn from arbitrary durations of audio and different audio representations. If you want to try different models, you can easily build upon our model template below:
51 | 
52 | ### Template Spectrogram Model
53 | 
54 | ```python
55 | from model import Spectrogram, STFT
56 | class Model(nn.Module):
57 |     def __init__(
58 |         self,
59 |         n_fft=4096,
60 |         n_hop=1024,
61 |         nb_channels=2,
62 |         input_is_spectrogram=False,
63 |         sample_rate=44100.0,
64 |     ):
65 |         """
66 |         Input: (batch, channel, sample)
67 |             or (frame, batch, channels, frequency)
68 |         Output: (frame, batch, channels, frequency)
69 |         """
70 | 
71 |         super(Model, self).__init__()  # define self.transform here (e.g. STFT + Spectrogram from model.py)
72 | 
73 |     def forward(self, mix):
74 |         # transform to spectrogram on the fly
75 |         X = self.transform(mix)
76 |         nb_frames, nb_samples, nb_channels, nb_bins = X.shape
77 | 
78 |         # transform X to estimate
79 |         # ....
80 | 
81 |         return X
82 | ```
83 | 
84 | ## Jointly train targets
85 | 
86 | We designed _open-unmix_ so that the training of multiple targets is handled in separate models. We think that this has several benefits, such as:
87 | 
88 | * single-source models can leverage unbalanced data, where a different amount of training data is available for each source.
89 | * training can easily be distributed by training multiple models on different nodes in parallel.
90 | * at test time the selection of different models can be adjusted for specific applications.
91 | 
92 | However, we acknowledge that there might be reasons to train a model jointly for all sources to improve separation performance. These changes can easily be made in _open-unmix_ with the following modifications, based on the way pytorch handles single-input, multiple-output models.
93 | 
94 | ### 1. Extend `data.py`
95 | 
96 | The dataset should be able to yield a list of tensors (one for each target). E.g., the `musdb` dataset can be extended with:
97 | 
98 | ```python
99 | y = [stems[ind] for ind, _ in enumerate(self.targets)]
100 | ```
101 | 
102 | ### 2. Extend `model.py`
103 | 
104 | The _open-unmix_ model can be left unchanged; instead, a "supermodel" can be added that joins the forward paths of all targets:
105 | 
106 | ```python
107 | class OpenUnmixJoint(nn.Module):
108 |     def __init__(
109 |         self,
110 |         targets,
111 |         *args, **kwargs
112 |     ):
113 |         super(OpenUnmixJoint, self).__init__()
114 |         self.models = nn.ModuleList(
115 |             [OpenUnmix(*args, **kwargs) for target in targets]
116 |         )
117 | 
118 |     def forward(self, x):
119 |         return [model(x) for model in self.models]
120 | ```
121 | 
122 | ### 3. Extend `train.py`
123 | 
124 | The training should be updated so that the total loss is an aggregation of the individual target losses. For the mean squared error, the following modifications should be sufficient:
125 | 
126 | ```python
127 | criteria = [torch.nn.MSELoss() for t in args.targets]
128 | # ...
129 | for x, y in tqdm.tqdm(train_sampler, disable=args.quiet):
130 |     x = x.to(device)
131 |     y = [i.to(device) for i in y]
132 |     optimizer.zero_grad()
133 |     Y_hats = unmix(x)
134 |     loss = 0
135 |     for Y_hat, target, criterion in zip(Y_hats, y, criteria):
136 |         loss = loss + criterion(Y_hat, unmix.models[0].transform(target))
137 | ```
138 | 
139 | ## End-to-End time-domain models
140 | 
141 | If you want to evaluate models that work in the time domain, such as WaveNet or WaveRNN, the training code would have to be modified. Instead of a spectrogram output `Y`, the output is simply a time-domain signal `y` that can directly be compared with `x`, e.g. 
going from: 142 | 143 | ```python 144 | Y_hat = unmix(x) 145 | Y = unmix.transform(y) 146 | loss = criterion(Y_hat, Y) 147 | ``` 148 | 149 | to: 150 | 151 | ```python 152 | y_hat = unmix(x) 153 | loss = criterion(y_hat, y) 154 | ``` 155 | 156 | Inference, in that case, would then have to drop the spectral wiener filter and instead directly save the time domain signal (and maybe its residual): 157 | 158 | ```python 159 | est = unmix(audio_torch).cpu().detach().numpy() 160 | estimates[target] = est[0].T 161 | estimates['residual'] = audio - est[0].T 162 | ``` 163 | -------------------------------------------------------------------------------- /openunmix/cli.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | import torchaudio 4 | import json 5 | import numpy as np 6 | import tqdm 7 | 8 | from openunmix import utils 9 | from openunmix import predict 10 | from openunmix import data 11 | 12 | import argparse 13 | 14 | 15 | def separate(): 16 | parser = argparse.ArgumentParser( 17 | description="UMX Inference", 18 | add_help=True, 19 | formatter_class=argparse.RawDescriptionHelpFormatter, 20 | ) 21 | 22 | parser.add_argument("input", type=str, nargs="+", help="List of paths to wav/flac files.") 23 | 24 | parser.add_argument( 25 | "--model", 26 | default="umxl", 27 | type=str, 28 | help="path to mode base directory of pretrained models, defaults to UMX-L", 29 | ) 30 | 31 | parser.add_argument( 32 | "--targets", 33 | nargs="+", 34 | type=str, 35 | help="provide targets to be processed. \ 36 | If none, all available targets will be computed", 37 | ) 38 | 39 | parser.add_argument( 40 | "--outdir", 41 | type=str, 42 | help="Results path where audio evaluation results are stored", 43 | ) 44 | 45 | parser.add_argument( 46 | "--ext", 47 | type=str, 48 | default=".wav", 49 | help="Output extension which sets the audio format", 50 | ) 51 | 52 | parser.add_argument("--start", type=float, default=0.0, help="Audio chunk start in seconds") 53 | 54 | parser.add_argument( 55 | "--duration", 56 | type=float, 57 | help="Audio chunk duration in seconds, negative values load full track", 58 | ) 59 | 60 | parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA inference") 61 | 62 | parser.add_argument( 63 | "--audio-backend", 64 | type=str, 65 | help="Sets audio backend. Default to torchaudio's default backend: See https://pytorch.org/audio/stable/backend.html" 66 | "(`sox_io`, `sox`, `soundfile` or `stempeg`)", 67 | ) 68 | 69 | parser.add_argument( 70 | "--niter", 71 | type=int, 72 | default=1, 73 | help="number of iterations for refining results.", 74 | ) 75 | 76 | parser.add_argument( 77 | "--wiener-win-len", 78 | type=int, 79 | default=300, 80 | help="Number of frames on which to apply filtering independently", 81 | ) 82 | 83 | parser.add_argument( 84 | "--residual", 85 | type=str, 86 | default=None, 87 | help="if provided, build a source with given name " "for the mix minus all estimated targets", 88 | ) 89 | 90 | parser.add_argument( 91 | "--aggregate", 92 | type=str, 93 | default=None, 94 | help="if provided, must be a string containing a valid expression for " 95 | "a dictionary, with keys as output target names, and values " 96 | "a list of targets that are used to build it. 
For instance: " 97 | '\'{"vocals":["vocals"], "accompaniment":["drums",' 98 | '"bass","other"]}\'', 99 | ) 100 | 101 | parser.add_argument( 102 | "--filterbank", 103 | type=str, 104 | default="torch", 105 | help="filterbank implementation method. " 106 | "Supported: `['torch', 'asteroid']`. `torch` is ~30%% faster " 107 | "compared to `asteroid` on large FFT sizes such as 4096. However " 108 | "asteroids stft can be exported to onnx, which makes is practical " 109 | "for deployment.", 110 | ) 111 | parser.add_argument( 112 | "--verbose", 113 | action="store_true", 114 | default=False, 115 | help="Enable log messages", 116 | ) 117 | args = parser.parse_args() 118 | 119 | if args.audio_backend != "stempeg" and args.audio_backend is not None: 120 | torchaudio.set_audio_backend(args.audio_backend) 121 | 122 | use_cuda = not args.no_cuda and torch.cuda.is_available() 123 | device = torch.device("cuda" if use_cuda else "cpu") 124 | if args.verbose: 125 | print("Using ", device) 126 | # parsing the output dict 127 | aggregate_dict = None if args.aggregate is None else json.loads(args.aggregate) 128 | 129 | # create separator only once to reduce model loading 130 | # when using multiple files 131 | separator = utils.load_separator( 132 | model_str_or_path=args.model, 133 | targets=args.targets, 134 | niter=args.niter, 135 | residual=args.residual, 136 | wiener_win_len=args.wiener_win_len, 137 | device=device, 138 | pretrained=True, 139 | filterbank=args.filterbank, 140 | ) 141 | 142 | separator.freeze() 143 | separator.to(device) 144 | 145 | if args.audio_backend == "stempeg": 146 | try: 147 | import stempeg 148 | except ImportError: 149 | raise RuntimeError("Please install pip package `stempeg`") 150 | 151 | # loop over the files 152 | for input_file in tqdm.tqdm(args.input): 153 | if args.audio_backend == "stempeg": 154 | audio, rate = stempeg.read_stems( 155 | input_file, 156 | start=args.start, 157 | duration=args.duration, 158 | sample_rate=separator.sample_rate, 159 | dtype=np.float32, 160 | ) 161 | audio = torch.tensor(audio) 162 | else: 163 | audio, rate = data.load_audio(input_file, start=args.start, dur=args.duration) 164 | estimates = predict.separate( 165 | audio=audio, 166 | rate=rate, 167 | aggregate_dict=aggregate_dict, 168 | separator=separator, 169 | device=device, 170 | ) 171 | if not args.outdir: 172 | model_path = Path(args.model) 173 | if not model_path.exists(): 174 | outdir = Path(Path(input_file).stem + "_" + args.model) 175 | else: 176 | outdir = Path(Path(input_file).stem + "_" + model_path.stem) 177 | else: 178 | outdir = Path(args.outdir) / Path(input_file).stem 179 | outdir.mkdir(exist_ok=True, parents=True) 180 | 181 | # write out estimates 182 | if args.audio_backend == "stempeg": 183 | target_path = str(outdir / Path("target").with_suffix(args.ext)) 184 | # convert torch dict to numpy dict 185 | estimates_numpy = {} 186 | for target, estimate in estimates.items(): 187 | estimates_numpy[target] = torch.squeeze(estimate).detach().cpu().numpy().T 188 | 189 | stempeg.write_stems( 190 | target_path, 191 | estimates_numpy, 192 | sample_rate=separator.sample_rate, 193 | writer=stempeg.FilesWriter(multiprocess=True, output_sample_rate=rate), 194 | ) 195 | else: 196 | for target, estimate in estimates.items(): 197 | target_path = str(outdir / Path(target).with_suffix(args.ext)) 198 | torchaudio.save( 199 | target_path, 200 | torch.squeeze(estimate).to("cpu"), 201 | sample_rate=separator.sample_rate, 202 | ) 203 | 
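The CLI above is a thin wrapper around the package's Python API. Below is a minimal programmatic sketch of the same flow for a single file; it only uses calls that already appear in `cli.py` and `utils.py` (`utils.load_separator`, `data.load_audio`, `predict.separate`). The file name `mixture.wav` is a placeholder, and loading the pretrained `umxl` weights requires network access:

```python
import torch

from openunmix import data, predict, utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load the pretrained separator once and reuse it for all files
separator = utils.load_separator(
    model_str_or_path="umxl",  # CLI default model
    targets=None,              # None loads all available targets
    niter=1,                   # EM refinement passes; 0 skips post-processing
    device=device,
    pretrained=True,
)
separator.freeze()
separator.to(device)

# load the full mixture (start=0.0, dur=None mirrors the CLI defaults)
audio, rate = data.load_audio("mixture.wav", start=0.0, dur=None)

estimates = predict.separate(
    audio=audio,
    rate=rate,
    aggregate_dict=None,
    separator=separator,
    device=device,
)

for target, estimate in estimates.items():
    # each estimate has shape (nb_samples, nb_channels, nb_timesteps)
    print(target, tuple(estimate.shape))
```

As the `load_separator` docstring notes, `niter=0` skips the Wiener/EM post-processing entirely, which makes separation significantly faster at the price of less interference reduction.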
-------------------------------------------------------------------------------- /openunmix/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torchaudio 5 | from torch import Tensor 6 | import torch.nn as nn 7 | 8 | try: 9 | from asteroid_filterbanks.enc_dec import Encoder, Decoder 10 | from asteroid_filterbanks.transforms import to_torchaudio, from_torchaudio 11 | from asteroid_filterbanks import torch_stft_fb 12 | except ImportError: 13 | pass 14 | 15 | 16 | def make_filterbanks(n_fft=4096, n_hop=1024, center=False, sample_rate=44100.0, method="torch"): 17 | window = nn.Parameter(torch.hann_window(n_fft), requires_grad=False) 18 | 19 | if method == "torch": 20 | encoder = TorchSTFT(n_fft=n_fft, n_hop=n_hop, window=window, center=center) 21 | decoder = TorchISTFT(n_fft=n_fft, n_hop=n_hop, window=window, center=center) 22 | elif method == "asteroid": 23 | fb = torch_stft_fb.TorchSTFTFB.from_torch_args( 24 | n_fft=n_fft, 25 | hop_length=n_hop, 26 | win_length=n_fft, 27 | window=window, 28 | center=center, 29 | sample_rate=sample_rate, 30 | ) 31 | encoder = AsteroidSTFT(fb) 32 | decoder = AsteroidISTFT(fb) 33 | else: 34 | raise NotImplementedError 35 | return encoder, decoder 36 | 37 | 38 | class AsteroidSTFT(nn.Module): 39 | def __init__(self, fb): 40 | super(AsteroidSTFT, self).__init__() 41 | self.enc = Encoder(fb) 42 | 43 | def forward(self, x): 44 | aux = self.enc(x) 45 | return to_torchaudio(aux) 46 | 47 | 48 | class AsteroidISTFT(nn.Module): 49 | def __init__(self, fb): 50 | super(AsteroidISTFT, self).__init__() 51 | self.dec = Decoder(fb) 52 | 53 | def forward(self, X: Tensor, length: Optional[int] = None) -> Tensor: 54 | aux = from_torchaudio(X) 55 | return self.dec(aux, length=length) 56 | 57 | 58 | class TorchSTFT(nn.Module): 59 | """Multichannel Short-Time-Fourier Forward transform 60 | uses hard coded hann_window. 61 | Args: 62 | n_fft (int, optional): transform FFT size. Defaults to 4096. 63 | n_hop (int, optional): transform hop size. Defaults to 1024. 64 | center (bool, optional): If True, the signals first window is 65 | zero padded. Centering is required for a perfect 66 | reconstruction of the signal. However, during training 67 | of spectrogram models, it can safely turned off. 
68 | Defaults to `true` 69 | window (nn.Parameter, optional): window function 70 | """ 71 | 72 | def __init__( 73 | self, 74 | n_fft: int = 4096, 75 | n_hop: int = 1024, 76 | center: bool = False, 77 | window: Optional[nn.Parameter] = None, 78 | ): 79 | super(TorchSTFT, self).__init__() 80 | if window is None: 81 | self.window = nn.Parameter(torch.hann_window(n_fft), requires_grad=False) 82 | else: 83 | self.window = window 84 | 85 | self.n_fft = n_fft 86 | self.n_hop = n_hop 87 | self.center = center 88 | 89 | def forward(self, x: Tensor) -> Tensor: 90 | """STFT forward path 91 | Args: 92 | x (Tensor): audio waveform of 93 | shape (nb_samples, nb_channels, nb_timesteps) 94 | Returns: 95 | STFT (Tensor): complex stft of 96 | shape (nb_samples, nb_channels, nb_bins, nb_frames, complex=2) 97 | last axis is stacked real and imaginary 98 | """ 99 | 100 | shape = x.size() 101 | nb_samples, nb_channels, nb_timesteps = shape 102 | 103 | # pack batch 104 | x = x.view(-1, shape[-1]) 105 | 106 | complex_stft = torch.stft( 107 | x, 108 | n_fft=self.n_fft, 109 | hop_length=self.n_hop, 110 | window=self.window, 111 | center=self.center, 112 | normalized=False, 113 | onesided=True, 114 | pad_mode="reflect", 115 | return_complex=True, 116 | ) 117 | stft_f = torch.view_as_real(complex_stft) 118 | # unpack batch 119 | stft_f = stft_f.view(shape[:-1] + stft_f.shape[-3:]) 120 | return stft_f 121 | 122 | 123 | class TorchISTFT(nn.Module): 124 | """Multichannel Inverse-Short-Time-Fourier functional 125 | wrapper for torch.istft to support batches 126 | Args: 127 | STFT (Tensor): complex stft of 128 | shape (nb_samples, nb_channels, nb_bins, nb_frames, complex=2) 129 | last axis is stacked real and imaginary 130 | n_fft (int, optional): transform FFT size. Defaults to 4096. 131 | n_hop (int, optional): transform hop size. Defaults to 1024. 132 | window (callable, optional): window function 133 | center (bool, optional): If True, the signals first window is 134 | zero padded. Centering is required for a perfect 135 | reconstruction of the signal. However, during training 136 | of spectrogram models, it can safely turned off. 137 | Defaults to `true` 138 | length (int, optional): audio signal length to crop the signal 139 | Returns: 140 | x (Tensor): audio waveform of 141 | shape (nb_samples, nb_channels, nb_timesteps) 142 | """ 143 | 144 | def __init__( 145 | self, 146 | n_fft: int = 4096, 147 | n_hop: int = 1024, 148 | center: bool = False, 149 | sample_rate: float = 44100.0, 150 | window: Optional[nn.Parameter] = None, 151 | ) -> None: 152 | super(TorchISTFT, self).__init__() 153 | 154 | self.n_fft = n_fft 155 | self.n_hop = n_hop 156 | self.center = center 157 | self.sample_rate = sample_rate 158 | 159 | if window is None: 160 | self.window = nn.Parameter(torch.hann_window(n_fft), requires_grad=False) 161 | else: 162 | self.window = window 163 | 164 | def forward(self, X: Tensor, length: Optional[int] = None) -> Tensor: 165 | shape = X.size() 166 | X = X.reshape(-1, shape[-3], shape[-2], shape[-1]) 167 | 168 | y = torch.istft( 169 | torch.view_as_complex(X), 170 | n_fft=self.n_fft, 171 | hop_length=self.n_hop, 172 | window=self.window, 173 | center=self.center, 174 | normalized=False, 175 | onesided=True, 176 | length=length, 177 | ) 178 | 179 | y = y.reshape(shape[:-3] + y.shape[-1:]) 180 | 181 | return y 182 | 183 | 184 | class ComplexNorm(nn.Module): 185 | r"""Compute the norm of complex tensor input. 
186 | 187 | Extension of `torchaudio.functional.complex_norm` with mono 188 | 189 | Args: 190 | mono (bool): Downmix to single channel after applying power norm 191 | to maximize 192 | """ 193 | 194 | def __init__(self, mono: bool = False): 195 | super(ComplexNorm, self).__init__() 196 | self.mono = mono 197 | 198 | def forward(self, spec: Tensor) -> Tensor: 199 | """ 200 | Args: 201 | spec: complex_tensor (Tensor): Tensor shape of 202 | `(..., complex=2)` 203 | 204 | Returns: 205 | Tensor: Power/Mag of input 206 | `(...,)` 207 | """ 208 | # take the magnitude 209 | 210 | spec = torch.abs(torch.view_as_complex(spec)) 211 | 212 | # downmix in the mag domain to preserve energy 213 | if self.mono: 214 | spec = torch.mean(spec, 1, keepdim=True) 215 | 216 | return spec 217 | -------------------------------------------------------------------------------- /openunmix/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import torch 4 | import os 5 | import numpy as np 6 | import torchaudio 7 | import warnings 8 | from pathlib import Path 9 | from contextlib import redirect_stderr 10 | import io 11 | import json 12 | 13 | import openunmix 14 | from openunmix import model 15 | 16 | 17 | def bandwidth_to_max_bin(rate: float, n_fft: int, bandwidth: float) -> np.ndarray: 18 | """Convert bandwidth to maximum bin count 19 | 20 | Assuming lapped transforms such as STFT 21 | 22 | Args: 23 | rate (int): Sample rate 24 | n_fft (int): FFT length 25 | bandwidth (float): Target bandwidth in Hz 26 | 27 | Returns: 28 | np.ndarray: maximum frequency bin 29 | """ 30 | freqs = np.linspace(0, rate / 2, n_fft // 2 + 1, endpoint=True) 31 | 32 | return np.max(np.where(freqs <= bandwidth)[0]) + 1 33 | 34 | 35 | def save_checkpoint(state: dict, is_best: bool, path: str, target: str): 36 | """Convert bandwidth to maximum bin count 37 | 38 | Assuming lapped transforms such as STFT 39 | 40 | Args: 41 | state (dict): torch model state dict 42 | is_best (bool): if current model is about to be saved as best model 43 | path (str): model path 44 | target (str): target name 45 | """ 46 | # save full checkpoint including optimizer 47 | torch.save(state, os.path.join(path, target + ".chkpnt")) 48 | if is_best: 49 | # save just the weights 50 | torch.save(state["state_dict"], os.path.join(path, target + ".pth")) 51 | 52 | 53 | class AverageMeter(object): 54 | """Computes and stores the average and current value""" 55 | 56 | def __init__(self): 57 | self.reset() 58 | 59 | def reset(self): 60 | self.val = 0 61 | self.avg = 0 62 | self.sum = 0 63 | self.count = 0 64 | 65 | def update(self, val, n=1): 66 | self.val = val 67 | self.sum += val * n 68 | self.count += n 69 | self.avg = self.sum / self.count 70 | 71 | 72 | class EarlyStopping(object): 73 | """Early Stopping Monitor""" 74 | 75 | def __init__(self, mode="min", min_delta=0, patience=10): 76 | self.mode = mode 77 | self.min_delta = min_delta 78 | self.patience = patience 79 | self.best = None 80 | self.num_bad_epochs = 0 81 | self.is_better = None 82 | self._init_is_better(mode, min_delta) 83 | 84 | if patience == 0: 85 | self.is_better = lambda a, b: True 86 | 87 | def step(self, metrics): 88 | if self.best is None: 89 | self.best = metrics 90 | return False 91 | 92 | if np.isnan(metrics): 93 | return True 94 | 95 | if self.is_better(metrics, self.best): 96 | self.num_bad_epochs = 0 97 | self.best = metrics 98 | else: 99 | self.num_bad_epochs += 1 100 | 101 | if self.num_bad_epochs >= 
self.patience: 102 | return True 103 | 104 | return False 105 | 106 | def _init_is_better(self, mode, min_delta): 107 | if mode not in {"min", "max"}: 108 | raise ValueError("mode " + mode + " is unknown!") 109 | if mode == "min": 110 | self.is_better = lambda a, best: a < best - min_delta 111 | if mode == "max": 112 | self.is_better = lambda a, best: a > best + min_delta 113 | 114 | 115 | def load_target_models(targets, model_str_or_path="umxl", device="cpu", pretrained=True): 116 | """Core model loader 117 | 118 | target model path can be either .pth, or -sha256.pth 119 | (as used on torchub) 120 | 121 | The loader either loads the models from a known model string 122 | as registered in the __init__.py or loads from custom configs. 123 | """ 124 | if isinstance(targets, str): 125 | targets = [targets] 126 | 127 | model_path = Path(model_str_or_path).expanduser() 128 | if not model_path.exists(): 129 | # model path does not exist, use pretrained models 130 | try: 131 | # disable progress bar 132 | hub_loader = getattr(openunmix, model_str_or_path + "_spec") 133 | err = io.StringIO() 134 | with redirect_stderr(err): 135 | return hub_loader(targets=targets, device=device, pretrained=pretrained) 136 | print(err.getvalue()) 137 | except AttributeError: 138 | raise NameError("Model does not exist on torchhub") 139 | # assume model is a path to a local model_str_or_path directory 140 | else: 141 | models = {} 142 | for target in targets: 143 | # load model from disk 144 | with open(Path(model_path, target + ".json"), "r") as stream: 145 | results = json.load(stream) 146 | 147 | target_model_path = next(Path(model_path).glob("%s*.pth" % target)) 148 | state = torch.load(target_model_path, map_location=device) 149 | 150 | models[target] = model.OpenUnmix( 151 | nb_bins=results["args"]["nfft"] // 2 + 1, 152 | nb_channels=results["args"]["nb_channels"], 153 | hidden_size=results["args"]["hidden_size"], 154 | max_bin=state["input_mean"].shape[0], 155 | ) 156 | 157 | if pretrained: 158 | models[target].load_state_dict(state, strict=False) 159 | 160 | models[target].to(device) 161 | return models 162 | 163 | 164 | def load_separator( 165 | model_str_or_path: str = "umxl", 166 | targets: Optional[list] = None, 167 | niter: int = 1, 168 | residual: bool = False, 169 | wiener_win_len: Optional[int] = 300, 170 | device: Union[str, torch.device] = "cpu", 171 | pretrained: bool = True, 172 | filterbank: str = "torch", 173 | ): 174 | """Separator loader 175 | 176 | Args: 177 | model_str_or_path (str): Model name or path to model _parent_ directory 178 | E.g. The following files are assumed to present when 179 | loading `model_str_or_path='mymodel', targets=['vocals']` 180 | 'mymodel/separator.json', mymodel/vocals.pth', 'mymodel/vocals.json'. 181 | Defaults to `umxl`. 182 | targets (list of str or None): list of target names. When loading a 183 | pre-trained model, all `targets` can be None as all targets 184 | will be loaded 185 | niter (int): Number of EM steps for refining initial estimates 186 | in a post-processing stage. `--niter 0` skips this step altogether 187 | (and thus makes separation significantly faster) More iterations 188 | can get better interference reduction at the price of artifacts. 189 | Defaults to `1`. 190 | residual (bool): Computes a residual target, for custom separation 191 | scenarios when not all targets are available (at the expense 192 | of slightly less performance). E.g vocal/accompaniment 193 | Defaults to `False`. 
194 | wiener_win_len (int): The size of the excerpts (number of frames) on 195 | which to apply filtering independently. This means assuming 196 | time varying stereo models and localization of sources. 197 | None means not batching but using the whole signal. It comes at the 198 | price of a much larger memory usage. 199 | Defaults to `300` 200 | device (str): torch device, defaults to `cpu` 201 | pretrained (bool): determines if loading pre-trained weights 202 | filterbank (str): filterbank implementation method. 203 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster 204 | compared to `asteroid` on large FFT sizes such as 4096. However, 205 | asteroids stft can be exported to onnx, which makes is practical 206 | for deployment. 207 | """ 208 | model_path = Path(model_str_or_path).expanduser() 209 | 210 | # when path exists, we assume its a custom model saved locally 211 | if model_path.exists(): 212 | if targets is None: 213 | raise UserWarning("For custom models, please specify the targets") 214 | 215 | target_models = load_target_models(targets=targets, model_str_or_path=model_path, pretrained=pretrained) 216 | 217 | with open(Path(model_path, "separator.json"), "r") as stream: 218 | enc_conf = json.load(stream) 219 | 220 | separator = model.Separator( 221 | target_models=target_models, 222 | niter=niter, 223 | residual=residual, 224 | wiener_win_len=wiener_win_len, 225 | sample_rate=enc_conf["sample_rate"], 226 | n_fft=enc_conf["nfft"], 227 | n_hop=enc_conf["nhop"], 228 | nb_channels=enc_conf["nb_channels"], 229 | filterbank=filterbank, 230 | ).to(device) 231 | 232 | # otherwise we load the separator from torchhub 233 | else: 234 | hub_loader = getattr(openunmix, model_str_or_path) 235 | separator = hub_loader( 236 | targets=targets, 237 | device=device, 238 | pretrained=True, 239 | niter=niter, 240 | residual=residual, 241 | wiener_win_len=wiener_win_len, 242 | filterbank=filterbank, 243 | ) 244 | 245 | return separator 246 | 247 | 248 | def preprocess( 249 | audio: torch.Tensor, 250 | rate: Optional[float] = None, 251 | model_rate: Optional[float] = None, 252 | ) -> torch.Tensor: 253 | """ 254 | From an input tensor, convert it to a tensor of shape 255 | shape=(nb_samples, nb_channels, nb_timesteps). This includes: 256 | - if input is 1D, adding the samples and channels dimensions. 257 | - if input is 2D 258 | o and the smallest dimension is 1 or 2, adding the samples one. 259 | o and all dimensions are > 2, assuming the smallest is the samples 260 | one, and adding the channel one 261 | - at the end, if the number of channels is greater than the number 262 | of time steps, swap those two. 263 | - resampling to target rate if necessary 264 | 265 | Args: 266 | audio (Tensor): input waveform 267 | rate (float): sample rate for the audio 268 | model_rate (float): sample rate for the model 269 | 270 | Returns: 271 | Tensor: [shape=(nb_samples, nb_channels=2, nb_timesteps)] 272 | """ 273 | shape = torch.as_tensor(audio.shape, device=audio.device) 274 | 275 | if len(shape) == 1: 276 | # assuming only time dimension is provided. 277 | audio = audio[None, None, ...] 278 | elif len(shape) == 2: 279 | if shape.min() <= 2: 280 | # assuming sample dimension is missing 281 | audio = audio[None, ...] 282 | else: 283 | # assuming channel dimension is missing 284 | audio = audio[:, None, ...] 285 | if audio.shape[1] > audio.shape[2]: 286 | # swapping channel and time 287 | audio = audio.transpose(1, 2) 288 | if audio.shape[1] > 2: 289 | warnings.warn("Channel count > 2!. 
Only the first two channels " "will be processed!") 290 | audio = audio[..., :2] 291 | 292 | if audio.shape[1] == 1: 293 | # if we have mono, we duplicate it to get stereo 294 | audio = torch.repeat_interleave(audio, 2, dim=1) 295 | 296 | if rate != model_rate: 297 | warnings.warn("resample to model sample rate") 298 | # we have to resample to model samplerate if needed 299 | # this makes sure we resample input only once 300 | resampler = torchaudio.transforms.Resample( 301 | orig_freq=rate, new_freq=model_rate, resampling_method="sinc_interpolation" 302 | ).to(audio.device) 303 | audio = resampler(audio) 304 | return audio 305 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import time 4 | from pathlib import Path 5 | import tqdm 6 | import json 7 | import sklearn.preprocessing 8 | import numpy as np 9 | import random 10 | from git import Repo 11 | import os 12 | import copy 13 | import torchaudio 14 | 15 | from openunmix import data 16 | from openunmix import model 17 | from openunmix import utils 18 | from openunmix import transforms 19 | 20 | tqdm.monitor_interval = 0 21 | 22 | 23 | def train(args, unmix, encoder, device, train_sampler, optimizer): 24 | losses = utils.AverageMeter() 25 | unmix.train() 26 | pbar = tqdm.tqdm(train_sampler, disable=args.quiet) 27 | for x, y in pbar: 28 | pbar.set_description("Training batch") 29 | x, y = x.to(device), y.to(device) 30 | optimizer.zero_grad() 31 | X = encoder(x) 32 | Y_hat = unmix(X) 33 | Y = encoder(y) 34 | loss = torch.nn.functional.mse_loss(Y_hat, Y) 35 | loss.backward() 36 | optimizer.step() 37 | losses.update(loss.item(), Y.size(1)) 38 | pbar.set_postfix(loss="{:.3f}".format(losses.avg)) 39 | return losses.avg 40 | 41 | 42 | def valid(args, unmix, encoder, device, valid_sampler): 43 | losses = utils.AverageMeter() 44 | unmix.eval() 45 | with torch.no_grad(): 46 | for x, y in valid_sampler: 47 | x, y = x.to(device), y.to(device) 48 | X = encoder(x) 49 | Y_hat = unmix(X) 50 | Y = encoder(y) 51 | loss = torch.nn.functional.mse_loss(Y_hat, Y) 52 | losses.update(loss.item(), Y.size(1)) 53 | return losses.avg 54 | 55 | 56 | def get_statistics(args, encoder, dataset): 57 | encoder = copy.deepcopy(encoder).to("cpu") 58 | scaler = sklearn.preprocessing.StandardScaler() 59 | 60 | dataset_scaler = copy.deepcopy(dataset) 61 | if isinstance(dataset_scaler, data.SourceFolderDataset): 62 | dataset_scaler.random_chunks = False 63 | else: 64 | dataset_scaler.random_chunks = False 65 | dataset_scaler.seq_duration = None 66 | 67 | dataset_scaler.samples_per_track = 1 68 | dataset_scaler.augmentations = None 69 | dataset_scaler.random_track_mix = False 70 | dataset_scaler.random_interferer_mix = False 71 | 72 | pbar = tqdm.tqdm(range(len(dataset_scaler)), disable=args.quiet) 73 | for ind in pbar: 74 | x, y = dataset_scaler[ind] 75 | pbar.set_description("Compute dataset statistics") 76 | # downmix to mono channel 77 | X = encoder(x[None, ...]).mean(1, keepdim=False).permute(0, 2, 1) 78 | 79 | scaler.partial_fit(np.squeeze(X)) 80 | 81 | # set inital input scaler values 82 | std = np.maximum(scaler.scale_, 1e-4 * np.max(scaler.scale_)) 83 | return scaler.mean_, std 84 | 85 | 86 | def main(): 87 | parser = argparse.ArgumentParser(description="Open Unmix Trainer") 88 | 89 | # which target do we want to train? 
90 | parser.add_argument( 91 | "--target", 92 | type=str, 93 | default="vocals", 94 | help="target source (will be passed to the dataset)", 95 | ) 96 | 97 | # Dataset paramaters 98 | parser.add_argument( 99 | "--dataset", 100 | type=str, 101 | default="musdb", 102 | choices=[ 103 | "musdb", 104 | "aligned", 105 | "sourcefolder", 106 | "trackfolder_var", 107 | "trackfolder_fix", 108 | ], 109 | help="Name of the dataset.", 110 | ) 111 | parser.add_argument("--root", type=str, help="root path of dataset") 112 | parser.add_argument( 113 | "--output", 114 | type=str, 115 | default="open-unmix", 116 | help="provide output path base folder name", 117 | ) 118 | parser.add_argument("--model", type=str, help="Name or path of pretrained model to fine-tune") 119 | parser.add_argument("--checkpoint", type=str, help="Path of checkpoint to resume training") 120 | parser.add_argument( 121 | "--audio-backend", 122 | type=str, 123 | default="soundfile", 124 | help="Set torchaudio backend (`sox_io` or `soundfile`", 125 | ) 126 | 127 | # Training Parameters 128 | parser.add_argument("--epochs", type=int, default=1000) 129 | parser.add_argument("--batch-size", type=int, default=16) 130 | parser.add_argument("--lr", type=float, default=0.001, help="learning rate, defaults to 1e-3") 131 | parser.add_argument( 132 | "--patience", 133 | type=int, 134 | default=140, 135 | help="maximum number of train epochs (default: 140)", 136 | ) 137 | parser.add_argument( 138 | "--lr-decay-patience", 139 | type=int, 140 | default=80, 141 | help="lr decay patience for plateau scheduler", 142 | ) 143 | parser.add_argument( 144 | "--lr-decay-gamma", 145 | type=float, 146 | default=0.3, 147 | help="gamma of learning rate scheduler decay", 148 | ) 149 | parser.add_argument("--weight-decay", type=float, default=0.00001, help="weight decay") 150 | parser.add_argument( 151 | "--seed", type=int, default=42, metavar="S", help="random seed (default: 42)" 152 | ) 153 | 154 | # Model Parameters 155 | parser.add_argument( 156 | "--seq-dur", 157 | type=float, 158 | default=6.0, 159 | help="Sequence duration in seconds" "value of <=0.0 will use full/variable length", 160 | ) 161 | parser.add_argument( 162 | "--unidirectional", 163 | action="store_true", 164 | default=False, 165 | help="Use unidirectional LSTM", 166 | ) 167 | parser.add_argument("--nfft", type=int, default=4096, help="STFT fft size and window size") 168 | parser.add_argument("--nhop", type=int, default=1024, help="STFT hop size") 169 | parser.add_argument( 170 | "--hidden-size", 171 | type=int, 172 | default=512, 173 | help="hidden size parameter of bottleneck layers", 174 | ) 175 | parser.add_argument( 176 | "--bandwidth", type=int, default=16000, help="maximum model bandwidth in herz" 177 | ) 178 | parser.add_argument( 179 | "--nb-channels", 180 | type=int, 181 | default=2, 182 | help="set number of channels for model (1, 2)", 183 | ) 184 | parser.add_argument( 185 | "--nb-workers", type=int, default=0, help="Number of workers for dataloader." 
186 | ) 187 | parser.add_argument( 188 | "--debug", 189 | action="store_true", 190 | default=False, 191 | help="Speed up training init for dev purposes", 192 | ) 193 | 194 | # Misc Parameters 195 | parser.add_argument( 196 | "--quiet", 197 | action="store_true", 198 | default=False, 199 | help="less verbose during training", 200 | ) 201 | parser.add_argument( 202 | "--no-cuda", action="store_true", default=False, help="disables CUDA training" 203 | ) 204 | 205 | args, _ = parser.parse_known_args() 206 | 207 | torchaudio.set_audio_backend(args.audio_backend) 208 | use_cuda = not args.no_cuda and torch.cuda.is_available() 209 | print("Using GPU:", use_cuda) 210 | dataloader_kwargs = {"num_workers": args.nb_workers, "pin_memory": True} if use_cuda else {} 211 | 212 | repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) 213 | repo = Repo(repo_dir) 214 | commit = repo.head.commit.hexsha[:7] 215 | 216 | # use jpg or npy 217 | torch.manual_seed(args.seed) 218 | random.seed(args.seed) 219 | 220 | device = torch.device("cuda" if use_cuda else "cpu") 221 | 222 | train_dataset, valid_dataset, args = data.load_datasets(parser, args) 223 | 224 | # create output dir if not exist 225 | target_path = Path(args.output) 226 | target_path.mkdir(parents=True, exist_ok=True) 227 | 228 | train_sampler = torch.utils.data.DataLoader( 229 | train_dataset, batch_size=args.batch_size, shuffle=True, **dataloader_kwargs 230 | ) 231 | valid_sampler = torch.utils.data.DataLoader(valid_dataset, batch_size=1, **dataloader_kwargs) 232 | 233 | stft, _ = transforms.make_filterbanks( 234 | n_fft=args.nfft, n_hop=args.nhop, sample_rate=train_dataset.sample_rate 235 | ) 236 | encoder = torch.nn.Sequential(stft, model.ComplexNorm(mono=args.nb_channels == 1)).to(device) 237 | 238 | separator_conf = { 239 | "nfft": args.nfft, 240 | "nhop": args.nhop, 241 | "sample_rate": train_dataset.sample_rate, 242 | "nb_channels": args.nb_channels, 243 | } 244 | 245 | with open(Path(target_path, "separator.json"), "w") as outfile: 246 | outfile.write(json.dumps(separator_conf, indent=4, sort_keys=True)) 247 | 248 | if args.checkpoint or args.model or args.debug: 249 | scaler_mean = None 250 | scaler_std = None 251 | else: 252 | scaler_mean, scaler_std = get_statistics(args, encoder, train_dataset) 253 | 254 | max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate, args.nfft, args.bandwidth) 255 | 256 | if args.model: 257 | # fine tune model 258 | print(f"Fine-tuning model from {args.model}") 259 | unmix = utils.load_target_models( 260 | args.target, model_str_or_path=args.model, device=device, pretrained=True 261 | )[args.target] 262 | unmix = unmix.to(device) 263 | else: 264 | unmix = model.OpenUnmix( 265 | input_mean=scaler_mean, 266 | input_scale=scaler_std, 267 | nb_bins=args.nfft // 2 + 1, 268 | nb_channels=args.nb_channels, 269 | hidden_size=args.hidden_size, 270 | max_bin=max_bin, 271 | unidirectional=args.unidirectional 272 | ).to(device) 273 | 274 | optimizer = torch.optim.Adam(unmix.parameters(), lr=args.lr, weight_decay=args.weight_decay) 275 | 276 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 277 | optimizer, 278 | factor=args.lr_decay_gamma, 279 | patience=args.lr_decay_patience, 280 | cooldown=10, 281 | ) 282 | 283 | es = utils.EarlyStopping(patience=args.patience) 284 | 285 | # if a checkpoint is specified: resume training 286 | if args.checkpoint: 287 | model_path = Path(args.checkpoint).expanduser() 288 | with open(Path(model_path, args.target + ".json"), "r") as stream: 289 | results = 
json.load(stream) 290 | 291 | target_model_path = Path(model_path, args.target + ".chkpnt") 292 | checkpoint = torch.load(target_model_path, map_location=device) 293 | unmix.load_state_dict(checkpoint["state_dict"], strict=False) 294 | optimizer.load_state_dict(checkpoint["optimizer"]) 295 | scheduler.load_state_dict(checkpoint["scheduler"]) 296 | # train for another epochs_trained 297 | t = tqdm.trange( 298 | results["epochs_trained"], 299 | results["epochs_trained"] + args.epochs + 1, 300 | disable=args.quiet, 301 | ) 302 | train_losses = results["train_loss_history"] 303 | valid_losses = results["valid_loss_history"] 304 | train_times = results["train_time_history"] 305 | best_epoch = results["best_epoch"] 306 | es.best = results["best_loss"] 307 | es.num_bad_epochs = results["num_bad_epochs"] 308 | # else start optimizer from scratch 309 | else: 310 | t = tqdm.trange(1, args.epochs + 1, disable=args.quiet) 311 | train_losses = [] 312 | valid_losses = [] 313 | train_times = [] 314 | best_epoch = 0 315 | 316 | for epoch in t: 317 | t.set_description("Training epoch") 318 | end = time.time() 319 | train_loss = train(args, unmix, encoder, device, train_sampler, optimizer) 320 | valid_loss = valid(args, unmix, encoder, device, valid_sampler) 321 | scheduler.step(valid_loss) 322 | train_losses.append(train_loss) 323 | valid_losses.append(valid_loss) 324 | 325 | t.set_postfix(train_loss=train_loss, val_loss=valid_loss) 326 | 327 | stop = es.step(valid_loss) 328 | 329 | if valid_loss == es.best: 330 | best_epoch = epoch 331 | 332 | utils.save_checkpoint( 333 | { 334 | "epoch": epoch + 1, 335 | "state_dict": unmix.state_dict(), 336 | "best_loss": es.best, 337 | "optimizer": optimizer.state_dict(), 338 | "scheduler": scheduler.state_dict(), 339 | }, 340 | is_best=valid_loss == es.best, 341 | path=target_path, 342 | target=args.target, 343 | ) 344 | 345 | # save params 346 | params = { 347 | "epochs_trained": epoch, 348 | "args": vars(args), 349 | "best_loss": es.best, 350 | "best_epoch": best_epoch, 351 | "train_loss_history": train_losses, 352 | "valid_loss_history": valid_losses, 353 | "train_time_history": train_times, 354 | "num_bad_epochs": es.num_bad_epochs, 355 | "commit": commit, 356 | } 357 | 358 | with open(Path(target_path, args.target + ".json"), "w") as outfile: 359 | outfile.write(json.dumps(params, indent=4, sort_keys=True)) 360 | 361 | train_times.append(time.time() - end) 362 | 363 | if stop: 364 | print("Apply Early Stopping") 365 | break 366 | 367 | 368 | if __name__ == "__main__": 369 | main() 370 | -------------------------------------------------------------------------------- /openunmix/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Mapping 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch import Tensor 8 | from torch.nn import LSTM, BatchNorm1d, Linear, Parameter 9 | from .filtering import wiener 10 | from .transforms import make_filterbanks, ComplexNorm 11 | 12 | 13 | class OpenUnmix(nn.Module): 14 | """OpenUnmix Core spectrogram based separation module. 15 | 16 | Args: 17 | nb_bins (int): Number of input time-frequency bins (Default: `4096`). 18 | nb_channels (int): Number of input audio channels (Default: `2`). 19 | hidden_size (int): Size for bottleneck layers (Default: `512`). 20 | nb_layers (int): Number of Bi-LSTM layers (Default: `3`). 21 | unidirectional (bool): Use causal model useful for realtime purpose. 
22 | (Default `False`) 23 | input_mean (ndarray or None): global data mean of shape `(nb_bins, )`. 24 | Defaults to zeros(nb_bins) 25 | input_scale (ndarray or None): global data mean of shape `(nb_bins, )`. 26 | Defaults to ones(nb_bins) 27 | max_bin (int or None): Internal frequency bin threshold to 28 | reduce high frequency content. Defaults to `None` which results 29 | in `nb_bins` 30 | """ 31 | 32 | def __init__( 33 | self, 34 | nb_bins: int = 4096, 35 | nb_channels: int = 2, 36 | hidden_size: int = 512, 37 | nb_layers: int = 3, 38 | unidirectional: bool = False, 39 | input_mean: Optional[np.ndarray] = None, 40 | input_scale: Optional[np.ndarray] = None, 41 | max_bin: Optional[int] = None, 42 | ): 43 | super(OpenUnmix, self).__init__() 44 | 45 | self.nb_output_bins = nb_bins 46 | if max_bin: 47 | self.nb_bins = max_bin 48 | else: 49 | self.nb_bins = self.nb_output_bins 50 | 51 | self.hidden_size = hidden_size 52 | 53 | self.fc1 = Linear(self.nb_bins * nb_channels, hidden_size, bias=False) 54 | 55 | self.bn1 = BatchNorm1d(hidden_size) 56 | 57 | if unidirectional: 58 | lstm_hidden_size = hidden_size 59 | else: 60 | lstm_hidden_size = hidden_size // 2 61 | 62 | self.lstm = LSTM( 63 | input_size=hidden_size, 64 | hidden_size=lstm_hidden_size, 65 | num_layers=nb_layers, 66 | bidirectional=not unidirectional, 67 | batch_first=False, 68 | dropout=0.4 if nb_layers > 1 else 0, 69 | ) 70 | 71 | fc2_hiddensize = hidden_size * 2 72 | self.fc2 = Linear(in_features=fc2_hiddensize, out_features=hidden_size, bias=False) 73 | 74 | self.bn2 = BatchNorm1d(hidden_size) 75 | 76 | self.fc3 = Linear( 77 | in_features=hidden_size, 78 | out_features=self.nb_output_bins * nb_channels, 79 | bias=False, 80 | ) 81 | 82 | self.bn3 = BatchNorm1d(self.nb_output_bins * nb_channels) 83 | 84 | if input_mean is not None: 85 | input_mean = torch.from_numpy(-input_mean[: self.nb_bins]).float() 86 | else: 87 | input_mean = torch.zeros(self.nb_bins) 88 | 89 | if input_scale is not None: 90 | input_scale = torch.from_numpy(1.0 / input_scale[: self.nb_bins]).float() 91 | else: 92 | input_scale = torch.ones(self.nb_bins) 93 | 94 | self.input_mean = Parameter(input_mean) 95 | self.input_scale = Parameter(input_scale) 96 | 97 | self.output_scale = Parameter(torch.ones(self.nb_output_bins).float()) 98 | self.output_mean = Parameter(torch.ones(self.nb_output_bins).float()) 99 | 100 | def freeze(self): 101 | # set all parameters as not requiring gradient, more RAM-efficient 102 | # at test time 103 | for p in self.parameters(): 104 | p.requires_grad = False 105 | self.eval() 106 | 107 | def forward(self, x: Tensor) -> Tensor: 108 | """ 109 | Args: 110 | x: input spectrogram of shape 111 | `(nb_samples, nb_channels, nb_bins, nb_frames)` 112 | 113 | Returns: 114 | Tensor: filtered spectrogram of shape 115 | `(nb_samples, nb_channels, nb_bins, nb_frames)` 116 | """ 117 | 118 | # permute so that batch is last for lstm 119 | x = x.permute(3, 0, 1, 2) 120 | # get current spectrogram shape 121 | nb_frames, nb_samples, nb_channels, nb_bins = x.data.shape 122 | 123 | mix = x.detach().clone() 124 | 125 | # crop 126 | x = x[..., : self.nb_bins] 127 | # shift and scale input to mean=0 std=1 (across all bins) 128 | x = x + self.input_mean 129 | x = x * self.input_scale 130 | 131 | # to (nb_frames*nb_samples, nb_channels*nb_bins) 132 | # and encode to (nb_frames*nb_samples, hidden_size) 133 | x = self.fc1(x.reshape(-1, nb_channels * self.nb_bins)) 134 | # normalize every instance in a batch 135 | x = self.bn1(x) 136 | x = x.reshape(nb_frames, 
nb_samples, self.hidden_size) 137 | # squash range ot [-1, 1] 138 | x = torch.tanh(x) 139 | 140 | # apply 3-layers of stacked LSTM 141 | lstm_out = self.lstm(x) 142 | 143 | # lstm skip connection 144 | x = torch.cat([x, lstm_out[0]], -1) 145 | 146 | # first dense stage + batch norm 147 | x = self.fc2(x.reshape(-1, x.shape[-1])) 148 | x = self.bn2(x) 149 | 150 | x = F.relu(x) 151 | 152 | # second dense stage + layer norm 153 | x = self.fc3(x) 154 | x = self.bn3(x) 155 | 156 | # reshape back to original dim 157 | x = x.reshape(nb_frames, nb_samples, nb_channels, self.nb_output_bins) 158 | 159 | # apply output scaling 160 | x *= self.output_scale 161 | x += self.output_mean 162 | 163 | # since our output is non-negative, we can apply RELU 164 | x = F.relu(x) * mix 165 | # permute back to (nb_samples, nb_channels, nb_bins, nb_frames) 166 | return x.permute(1, 2, 3, 0) 167 | 168 | 169 | class Separator(nn.Module): 170 | """ 171 | Separator class to encapsulate all the stereo filtering 172 | as a torch Module, to enable end-to-end learning. 173 | 174 | Args: 175 | targets (dict of str: nn.Module): dictionary of target models 176 | the spectrogram models to be used by the Separator. 177 | niter (int): Number of EM steps for refining initial estimates in a 178 | post-processing stage. Zeroed if only one target is estimated. 179 | defaults to `1`. 180 | residual (bool): adds an additional residual target, obtained by 181 | subtracting the other estimated targets from the mixture, 182 | before any potential EM post-processing. 183 | Defaults to `False`. 184 | wiener_win_len (int or None): The size of the excerpts 185 | (number of frames) on which to apply filtering 186 | independently. This means assuming time varying stereo models and 187 | localization of sources. 188 | None means not batching but using the whole signal. It comes at the 189 | price of a much larger memory usage. 190 | filterbank (str): filterbank implementation method. 191 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster 192 | compared to `asteroid` on large FFT sizes such as 4096. However, 193 | asteroids stft can be exported to onnx, which makes is practical 194 | for deployment. 
195 | """ 196 | 197 | def __init__( 198 | self, 199 | target_models: Mapping[str, nn.Module], 200 | niter: int = 0, 201 | softmask: bool = False, 202 | residual: bool = False, 203 | sample_rate: float = 44100.0, 204 | n_fft: int = 4096, 205 | n_hop: int = 1024, 206 | nb_channels: int = 2, 207 | wiener_win_len: Optional[int] = 300, 208 | filterbank: str = "torch", 209 | ): 210 | super(Separator, self).__init__() 211 | 212 | # saving parameters 213 | self.niter = niter 214 | self.residual = residual 215 | self.softmask = softmask 216 | self.wiener_win_len = wiener_win_len 217 | 218 | self.stft, self.istft = make_filterbanks( 219 | n_fft=n_fft, 220 | n_hop=n_hop, 221 | center=True, 222 | method=filterbank, 223 | sample_rate=sample_rate, 224 | ) 225 | self.complexnorm = ComplexNorm(mono=nb_channels == 1) 226 | 227 | # registering the targets models 228 | self.target_models = nn.ModuleDict(target_models) 229 | # adding till https://github.com/pytorch/pytorch/issues/38963 230 | self.nb_targets = len(self.target_models) 231 | # get the sample_rate as the sample_rate of the first model 232 | # (tacitly assume it's the same for all targets) 233 | self.register_buffer("sample_rate", torch.as_tensor(sample_rate)) 234 | 235 | def freeze(self): 236 | # set all parameters as not requiring gradient, more RAM-efficient 237 | # at test time 238 | for p in self.parameters(): 239 | p.requires_grad = False 240 | self.eval() 241 | 242 | def forward(self, audio: Tensor) -> Tensor: 243 | """Performing the separation on audio input 244 | 245 | Args: 246 | audio (Tensor): [shape=(nb_samples, nb_channels, nb_timesteps)] 247 | mixture audio waveform 248 | 249 | Returns: 250 | Tensor: stacked tensor of separated waveforms 251 | shape `(nb_samples, nb_targets, nb_channels, nb_timesteps)` 252 | """ 253 | 254 | nb_sources = self.nb_targets 255 | nb_samples = audio.shape[0] 256 | 257 | # getting the STFT of mix: 258 | # (nb_samples, nb_channels, nb_bins, nb_frames, 2) 259 | mix_stft = self.stft(audio) 260 | X = self.complexnorm(mix_stft) 261 | 262 | # initializing spectrograms variable 263 | spectrograms = torch.zeros(X.shape + (nb_sources,), dtype=audio.dtype, device=X.device) 264 | 265 | for j, (target_name, target_module) in enumerate(self.target_models.items()): 266 | # apply current model to get the source spectrogram 267 | target_spectrogram = target_module(X.detach().clone()) 268 | spectrograms[..., j] = target_spectrogram 269 | 270 | # transposing it as 271 | # (nb_samples, nb_frames, nb_bins,{1,nb_channels}, nb_sources) 272 | spectrograms = spectrograms.permute(0, 3, 2, 1, 4) 273 | 274 | # rearranging it into: 275 | # (nb_samples, nb_frames, nb_bins, nb_channels, 2) to feed 276 | # into filtering methods 277 | mix_stft = mix_stft.permute(0, 3, 2, 1, 4) 278 | 279 | # create an additional target if we need to build a residual 280 | if self.residual: 281 | # we add an additional target 282 | nb_sources += 1 283 | 284 | if nb_sources == 1 and self.niter > 0: 285 | raise Exception( 286 | "Cannot use EM if only one target is estimated." 
287 | "Provide two targets or create an additional " 288 | "one with `--residual`" 289 | ) 290 | 291 | nb_frames = spectrograms.shape[1] 292 | targets_stft = torch.zeros(mix_stft.shape + (nb_sources,), dtype=audio.dtype, device=mix_stft.device) 293 | for sample in range(nb_samples): 294 | pos = 0 295 | if self.wiener_win_len: 296 | wiener_win_len = self.wiener_win_len 297 | else: 298 | wiener_win_len = nb_frames 299 | while pos < nb_frames: 300 | cur_frame = torch.arange(pos, min(nb_frames, pos + wiener_win_len)) 301 | pos = int(cur_frame[-1]) + 1 302 | 303 | targets_stft[sample, cur_frame] = wiener( 304 | spectrograms[sample, cur_frame], 305 | mix_stft[sample, cur_frame], 306 | self.niter, 307 | softmask=self.softmask, 308 | residual=self.residual, 309 | ) 310 | 311 | # getting to (nb_samples, nb_targets, channel, fft_size, n_frames, 2) 312 | targets_stft = targets_stft.permute(0, 5, 3, 2, 1, 4).contiguous() 313 | 314 | # inverse STFT 315 | estimates = self.istft(targets_stft, length=audio.shape[2]) 316 | 317 | return estimates 318 | 319 | def to_dict(self, estimates: Tensor, aggregate_dict: Optional[dict] = None) -> dict: 320 | """Convert estimates as stacked tensor to dictionary 321 | 322 | Args: 323 | estimates (Tensor): separated targets of shape 324 | (nb_samples, nb_targets, nb_channels, nb_timesteps) 325 | aggregate_dict (dict or None) 326 | 327 | Returns: 328 | (dict of str: Tensor): 329 | """ 330 | estimates_dict = {} 331 | for k, target in enumerate(self.target_models): 332 | estimates_dict[target] = estimates[:, k, ...] 333 | 334 | # in the case of residual, we added another source 335 | if self.residual: 336 | estimates_dict["residual"] = estimates[:, -1, ...] 337 | 338 | if aggregate_dict is not None: 339 | new_estimates = {} 340 | for key in aggregate_dict: 341 | new_estimates[key] = torch.tensor(0.0) 342 | for target in aggregate_dict[key]: 343 | new_estimates[key] = new_estimates[key] + estimates_dict[target] 344 | estimates_dict = new_estimates 345 | return estimates_dict 346 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # _Open-Unmix_ for PyTorch: end-to-end torch branch 2 | 3 | [![status](https://joss.theoj.org/papers/571753bc54c5d6dd36382c3d801de41d/status.svg)](https://joss.theoj.org/papers/571753bc54c5d6dd36382c3d801de41d) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/open-unmix-a-reference-implementation-for/music-source-separation-on-musdb18)](https://paperswithcode.com/sota/music-source-separation-on-musdb18?p=open-unmix-a-reference-implementation-for) 4 | 5 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1mijF0zGWxN-KaxTnd0q6hayAlrID5fEQ) [![Gitter](https://badges.gitter.im/sigsep/open-unmix.svg)](https://gitter.im/sigsep/open-unmix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [![Google group : Open-Unmix](https://img.shields.io/badge/discuss-on%20google%20groups-orange.svg)](https://groups.google.com/forum/#!forum/open-unmix) 6 | 7 | [![Build Status](https://travis-ci.com/sigsep/open-unmix-pytorch.svg?branch=master)](https://travis-ci.com/sigsep/open-unmix-pytorch) [![Docker hub](https://img.shields.io/docker/cloud/build/faroit/open-unmix-pytorch)](https://cloud.docker.com/u/faroit/repository/docker/faroit/open-unmix-pytorch) 8 | 9 | This repository contains the PyTorch (1.0+) implementation of 
__Open-Unmix__, a deep neural network reference implementation for music source separation, applicable for researchers, audio engineers and artists. __Open-Unmix__ provides ready-to-use models that allow users to separate pop music into four stems: __vocals__, __drums__, __bass__ and the remaining __other__ instruments. The models were pre-trained on the [MUSDB18](https://sigsep.github.io/datasets/musdb.html) dataset. See details at [apply pre-trained model](#getting-started). 10 | 11 | ## News: 12 | 13 | - 06/05/2020: We also added a pre-trained speech enhancement model (`umxse`) provided by Sony. For more information we refer [to this site](https://sigsep.github.io/open-unmix/se) 14 | 15 | __Related Projects:__ open-unmix-pytorch | [open-unmix-nnabla](https://github.com/sigsep/open-unmix-nnabla) | [musdb](https://github.com/sigsep/sigsep-mus-db) | [museval](https://github.com/sigsep/sigsep-mus-eval) | [norbert](https://github.com/sigsep/norbert) 16 | 17 | ## The Model for one source 18 | 19 | ![](https://docs.google.com/drawings/d/e/2PACX-1vTPoQiPwmdfET4pZhue1RvG7oEUJz7eUeQvCu6vzYeKRwHl6by4RRTnphImSKM0k5KXw9rZ1iIFnpGW/pub?w=959&h=308) 20 | 21 | To perform separation into multiple sources, _Open-unmix_ comprises multiple models that are trained for each particular target. While this makes the training less comfortable, it allows great flexibility to customize the training data for each target source. 22 | 23 | Each _Open-Unmix_ source model is based on a three-layer bidirectional deep LSTM. The model learns to predict the magnitude spectrogram of a target source, like _vocals_, from the magnitude spectrogram of a mixture input. Internally, the prediction is obtained by applying a mask on the input. The model is optimized in the magnitude domain using mean squared error. 24 | 25 | ### Input Stage 26 | 27 | __Open-Unmix__ operates in the time-frequency domain to perform its prediction. The input of the model is either: 28 | 29 | * __A time domain__ signal tensor of shape `(nb_samples, nb_channels, nb_timesteps)`, where `nb_samples` are the samples in a batch, `nb_channels` is 1 or 2 for mono or stereo audio, respectively, and `nb_timesteps` is the number of audio samples in the recording. 30 | 31 | In that case, the model computes spectrograms with `torch.STFT` on the fly. 32 | 33 | * Alternatively _open-unmix_ also takes **magnitude spectrograms** directly (e.g. when pre-computed and loaded from disk). 34 | 35 | In that case, the input is of shape `(nb_frames, nb_samples, nb_channels, nb_bins)`, where `nb_frames` and `nb_bins` are the time and frequency-dimensions of a Short-Time-Fourier-Transform. 36 | 37 | The input spectrogram is _standardized_ using the global mean and standard deviation for every frequency bin across all frames. Furthermore, we apply batch normalization in multiple stages of the model to make the training more robust against gain variation. 38 | 39 | ### Dimensionality reduction 40 | 41 | The LSTM is not operating on the original input spectrogram resolution. Instead, in the first step after the normalization, the network learns to compresses the frequency and channel axis of the model to reduce redundancy and make the model converge faster. 42 | 43 | ### Bidirectional-LSTM 44 | 45 | The core of __open-unmix__ is a three layer bidirectional [LSTM network](https://dl.acm.org/citation.cfm?id=1246450). Due to its recurrent nature, the model can be trained and evaluated on arbitrary length of audio signals. 
Since the model takes information from past and future simultaneously, the model cannot be used in an online/real-time manner. 46 | A uni-directional model can easily be trained as described [here](docs/training.md). 47 | 48 | ### Output Stage 49 | 50 | After applying the LSTM, the signal is decoded back to its original input dimensionality. In the last steps, the output is multiplied with the input magnitude spectrogram, so that the model is asked to learn a mask. 51 | 52 | ## Putting source models together: the `Separator` 53 | 54 | For inference, this branch enables a `Separator` pytorch Module that puts together one _Open-unmix_ model for each desired target, and combines their output through a multichannel generalized Wiener filter, before application of inverse STFTs using `torchaudio`. 55 | The filtering is a rewriting in torch of the [numpy implementation](https://github.com/sigsep/norbert) used in the main branch. 56 | 57 | 58 | ## Getting started 59 | 60 | ### Installation 61 | 62 | For installation we recommend using the [Anaconda](https://anaconda.org/) python distribution. To create a conda environment for _open-unmix_, simply run: 63 | 64 | `conda env create -f environment-X.yml` where `X` is either [`cpu-linux`, `gpu-linux-cuda10`, `cpu-osx`], depending on your system. For now, we haven't tested windows support. 65 | 66 | ### Using Docker 67 | 68 | We also provide a docker container as an alternative to anaconda. That way, separation of a local track in `~/Music/track1.wav` can be performed in a single line: 69 | 70 | ``` 71 | docker run -v ~/Music/:/data -it faroit/open-unmix-pytorch python test.py "/data/track1.wav" --outdir /data/track1 72 | ``` 73 | 74 | ### Applying pre-trained models on audio files 75 | 76 | We provide two pre-trained music separation models: 77 | 78 | * __`umxhq` (default)__ trained on [MUSDB18-HQ](https://sigsep.github.io/datasets/musdb.html#uncompressed-wav), which comprises the same tracks as in MUSDB18 but uncompressed, yielding a full bandwidth of 22050 Hz. 79 | 80 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3370489.svg)](https://doi.org/10.5281/zenodo.3370489) 81 | 82 | * __`umx`__ is trained on the regular [MUSDB18](https://sigsep.github.io/datasets/musdb.html#compressed-stems), which is bandwidth limited to 16 kHz due to AAC compression. This model should be used for comparison with other (older) methods for evaluation in [SiSEC18](sisec18.unmix.app). 83 | 84 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3370486.svg)](https://doi.org/10.5281/zenodo.3370486) 85 | 86 | Furthermore, we provide a model for speech enhancement trained by [Sony Corporation](link): 87 | 88 | * __`umxse`__ speech enhancement model is trained on the 28-speaker version of the [Voicebank+DEMAND corpus](https://datashare.is.ed.ac.uk/handle/10283/1942?show=full). 89 | 90 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3786908.svg)](https://doi.org/10.5281/zenodo.3786908) 91 | 92 | To separate audio files (`wav`, `flac`, `ogg` - but not `mp3`), just run: 93 | 94 | ```bash 95 | umx input_file.wav --model umxhq 96 | ``` 97 | 98 | A more detailed list of the parameters used for the separation is given in the [inference.md](/docs/inference.md) document. 99 | We provide a [jupyter notebook on google colab](https://colab.research.google.com/drive/1mijF0zGWxN-KaxTnd0q6hayAlrID5fEQ) to 100 | experiment with open-unmix and to separate files online without any installation setup.
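Besides the `umx` command line tool (and the torch.hub interface described in the next section), the same pre-trained models can also be constructed directly from Python via the package-level functions defined in `openunmix/__init__.py` (shown later in this document). A minimal, illustrative sketch:

```python
import torch
from openunmix import umxhq  # umx and umxse can be used analogously

# build a Separator holding all four music targets (downloads weights on first use)
separator = umxhq(niter=1, device="cpu")

# dummy stereo batch of shape (nb_samples, nb_channels, nb_timesteps)
audio = torch.rand(1, 2, 44100)

estimates = separator(audio)          # (nb_samples, nb_targets, nb_channels, nb_timesteps)
named = separator.to_dict(estimates)  # {"vocals": ..., "drums": ..., "bass": ..., "other": ...}
```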
101 | 102 | ### Interface with separator from python via torch.hub 103 | 104 | A pre-trained `Separator` can be loaded from pytorch-based code using `torch.hub.load`: 105 | 106 | ```python 107 | separator = torch.hub.load('sigsep/open-unmix-pytorch', 'umxhq') 108 | ``` 109 | 110 | This object may then simply be used for separation of some `audio` (`torch.Tensor` of shape `(nb_samples, nb_channels, nb_timesteps)`), sampled at a sampling rate `rate`, through: 111 | 112 | ```python 113 | audio_estimates = separator(audio) 114 | ``` 115 | 116 | ### Load user-trained models (only music separation models) 117 | 118 | When a path instead of a model name is provided to `--model`, the pre-trained model will be loaded from disk. 119 | 120 | ```bash 121 | umx --model /path/to/model/root/directory input_file.wav 122 | ``` 123 | 124 | Note that `model` usually contains individual models for each target and performs separation using all models. E.g. if `model_path` contains `vocals` and `drums` models, two output files are generated, unless the `--residual-model` option is selected, in which case an additional source will be produced, containing an estimate of all that is not the targets in the mixtures. 125 | 126 | ### Evaluation using `museval` 127 | 128 | To perform evaluation in comparison to other SiSEC systems, you would need to install the `museval` package using 129 | 130 | ``` 131 | pip install museval 132 | ``` 133 | 134 | and then run the evaluation using 135 | 136 | `python -m openunmix.evaluate --outdir /path/to/musdb/estimates --evaldir /path/to/museval/results` 137 | 138 | ### Results compared to SiSEC 2018 (SDR/Vocals) 139 | 140 | Open-Unmix yields state-of-the-art results compared to participants from [SiSEC 2018](https://sisec18.unmix.app/#/methods). The performance of `UMXHQ` and `UMX` is almost identical since it was evaluated on compressed STEMS. 141 | 142 | ![boxplot_updated](https://user-images.githubusercontent.com/72940/63944652-3f624c80-ca72-11e9-8d33-bed701679fe6.png) 143 | 144 | Note that 145 | 146 | 1. [`STL1`, `TAK2`, `TAK3`, `TAU1`, `UHL3`, `UMXHQ`] were omitted as they were _not_ trained on only _MUSDB18_. 147 | 2. [`HEL1`, `TAK1`, `UHL1`, `UHL2`] are not open-source. 148 | 149 | #### Scores (Median of frames, Median of tracks) 150 | 151 | |target|SDR |SIR | SAR | ISR | SDR | SIR | SAR | ISR | 152 | |------|-----|-----|-----|-----|-----|-----|-----|-----| 153 | |`model`|UMX |UMX |UMX |UMX |UMXHQ|UMXHQ|UMXHQ|UMXHQ| 154 | |vocals|6.32 |13.33| 6.52|11.93| 6.25|12.95| 6.50|12.70| 155 | |bass |5.23 |10.93| 6.34| 9.23| 5.07|10.35| 6.02| 9.71| 156 | |drums |5.73 |11.12| 6.02|10.51| 6.04|11.65| 5.93|11.17| 157 | |other |4.02 |6.59 | 4.74| 9.31| 4.28| 7.10| 4.62| 8.78| 158 | 159 | ## Training 160 | 161 | Details on the training are provided in a separate document [here](docs/training.md). 162 | 163 | ## Extensions 164 | 165 | Details on how _open-unmix_ can be extended or improved for future research on music separation are described in a separate document [here](docs/extensions.md). 166 | 167 | 168 | ## Design Choices 169 | 170 | We favored simplicity over performance to promote clarity of the code. The rationale is to have __open-unmix__ serve as a __baseline__ for future research while performance still meets current state-of-the-art (see [Evaluation](#Evaluation)). The results are comparable/better to those of `UHL1`/`UHL2` which obtained the best performance over all systems trained on MUSDB18 in the [SiSEC 2018 Evaluation campaign](https://sisec18.unmix.app).
171 | We designed the code to allow researchers to reproduce existing results, quickly develop new architectures, and add their own data for training and testing. We favored framework-specific implementations instead of a monolithic repository with common code for all frameworks. 172 | 173 | ## How to contribute 174 | 175 | _open-unmix_ is a community-focused project; we therefore encourage the community to submit bug-fixes and requests for technical support through [github issues](https://github.com/sigsep/open-unmix-pytorch/issues/new/choose). For more details on how to contribute, please follow our [`CONTRIBUTING.md`](CONTRIBUTING.md). For help and support, please use the gitter chat or the google groups forums. 176 | 177 | ### Authors 178 | 179 | [Fabian-Robert Stöter](https://www.faroit.com/), [Antoine Liutkus](https://github.com/aliutkus), Inria and LIRMM, Montpellier, France 180 | 181 | ## References 182 | 183 |
If you use open-unmix for your research – Cite Open-Unmix 184 | 185 | ```latex 186 | @article{stoter19, 187 | author={F.-R. St\\"oter and S. Uhlich and A. Liutkus and Y. Mitsufuji}, 188 | title={Open-Unmix - A Reference Implementation for Music Source Separation}, 189 | journal={Journal of Open Source Software}, 190 | year=2019, 191 | doi = {10.21105/joss.01667}, 192 | url = {https://doi.org/10.21105/joss.01667} 193 | } 194 | ``` 195 | 196 |

197 |
198 | 199 |
If you use the MUSDB dataset for your research - Cite the MUSDB18 Dataset 200 |

201 | 202 | ```latex 203 | @misc{MUSDB18, 204 | author = {Rafii, Zafar and 205 | Liutkus, Antoine and 206 | Fabian-Robert St{\"o}ter and 207 | Mimilakis, Stylianos Ioannis and 208 | Bittner, Rachel}, 209 | title = {The {MUSDB18} corpus for music separation}, 210 | month = dec, 211 | year = 2017, 212 | doi = {10.5281/zenodo.1117372}, 213 | url = {https://doi.org/10.5281/zenodo.1117372} 214 | } 215 | ``` 216 | 217 |

218 |
219 | 220 | 221 |
If you compare your results with SiSEC 2018 Participants - Cite the SiSEC 2018 LVA/ICA Paper 222 |

223 | 224 | ```latex 225 | @inproceedings{SiSEC18, 226 | author="St{\"o}ter, Fabian-Robert and Liutkus, Antoine and Ito, Nobutaka", 227 | title="The 2018 Signal Separation Evaluation Campaign", 228 | booktitle="Latent Variable Analysis and Signal Separation: 229 | 14th International Conference, LVA/ICA 2018, Surrey, UK", 230 | year="2018", 231 | pages="293--305" 232 | } 233 | ``` 234 | 235 |

236 |
237 | 238 | ⚠️ Please note that the official acronym for _open-unmix_ is **UMX**. 239 | 240 | ### License 241 | 242 | MIT 243 | 244 | ### Acknowledgements 245 | 246 |

247 | 248 | anr 249 |

250 | -------------------------------------------------------------------------------- /openunmix/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ![sigsep logo](https://sigsep.github.io/hero.png) 3 | Open-Unmix is a deep neural network reference implementation for music source separation, applicable for researchers, audio engineers and artists. Open-Unmix provides ready-to-use models that allow users to separate pop music into four stems: vocals, drums, bass and the remaining other instruments. The models were pre-trained on the MUSDB18 dataset. See details at apply pre-trained model. 4 | 5 | This is the python package API documentation. 6 | Please checkout [the open-unmix website](https://sigsep.github.io/open-unmix) for more information. 7 | """ 8 | 9 | from openunmix import utils 10 | import torch.hub 11 | 12 | 13 | def umxse_spec(targets=None, device="cpu", pretrained=True): 14 | target_urls = { 15 | "speech": "https://zenodo.org/records/3786908/files/speech_f5e0d9f9.pth", 16 | "noise": "https://zenodo.org/records/3786908/files/noise_04a6fc2d.pth", 17 | } 18 | 19 | from .model import OpenUnmix 20 | 21 | if targets is None: 22 | targets = ["speech", "noise"] 23 | 24 | # determine the maximum bin count for a 16khz bandwidth model 25 | max_bin = utils.bandwidth_to_max_bin(rate=16000.0, n_fft=1024, bandwidth=16000) 26 | 27 | # load open unmix models speech enhancement models 28 | target_models = {} 29 | for target in targets: 30 | target_unmix = OpenUnmix(nb_bins=1024 // 2 + 1, nb_channels=1, hidden_size=256, max_bin=max_bin) 31 | 32 | # enable centering of stft to minimize reconstruction error 33 | if pretrained: 34 | state_dict = torch.hub.load_state_dict_from_url(target_urls[target], map_location=device) 35 | target_unmix.load_state_dict(state_dict, strict=False) 36 | target_unmix.eval() 37 | 38 | target_unmix.to(device) 39 | target_models[target] = target_unmix 40 | return target_models 41 | 42 | 43 | def umxse(targets=None, residual=False, niter=1, device="cpu", pretrained=True, filterbank="torch", wiener_win_len=300): 44 | """ 45 | Open Unmix Speech Enhancemennt 1-channel BiLSTM Model 46 | trained on the 28-speaker version of Voicebank+Demand 47 | (Sampling rate: 16kHz) 48 | 49 | Args: 50 | targets (str): select the targets for the source to be separated. 51 | a list including: ['speech', 'noise']. 52 | If you don't pick them all, you probably want to 53 | activate the `residual=True` option. 54 | Defaults to all available targets per model. 55 | pretrained (bool): If True, returns a model pre-trained on MUSDB18-HQ 56 | residual (bool): if True, a "garbage" target is created 57 | niter (int): the number of post-processingiterations, defaults to 0 58 | device (str): selects device to be used for inference 59 | wiener_win_len (int or None): The size of the excerpts 60 | (number of frames) on which to apply filtering 61 | independently. This means assuming time varying stereo models and 62 | localization of sources. 63 | None means not batching but using the whole signal. It comes at the 64 | price of a much larger memory usage. 65 | filterbank (str): filterbank implementation method. 66 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster 67 | compared to `asteroid` on large FFT sizes such as 4096. However, 68 | asteroids stft can be exported to onnx, which makes is practical 69 | for deployment. 70 | 71 | Reference: 72 | Uhlich, Stefan, & Mitsufuji, Yuki. (2020). 73 | Open-Unmix for Speech Enhancement (UMX SE). 
74 | Zenodo. http://doi.org/10.5281/zenodo.3786908 75 | """ 76 | from .model import Separator 77 | 78 | target_models = umxse_spec(targets=targets, device=device, pretrained=pretrained) 79 | 80 | separator = Separator( 81 | target_models=target_models, 82 | niter=niter, 83 | residual=residual, 84 | n_fft=1024, 85 | n_hop=512, 86 | nb_channels=1, 87 | sample_rate=16000.0, 88 | wiener_win_len=wiener_win_len, 89 | filterbank=filterbank, 90 | ).to(device) 91 | 92 | return separator 93 | 94 | 95 | def umxhq_spec(targets=None, device="cpu", pretrained=True): 96 | from .model import OpenUnmix 97 | 98 | # set urls for weights 99 | target_urls = { 100 | "bass": "https://zenodo.org/records/3370489/files/bass-8d85a5bd.pth", 101 | "drums": "https://zenodo.org/records/3370489/files/drums-9619578f.pth", 102 | "other": "https://zenodo.org/records/3370489/files/other-b52fbbf7.pth", 103 | "vocals": "https://zenodo.org/records/3370489/files/vocals-b62c91ce.pth", 104 | } 105 | 106 | if targets is None: 107 | targets = ["vocals", "drums", "bass", "other"] 108 | 109 | # determine the maximum bin count for a 16khz bandwidth model 110 | max_bin = utils.bandwidth_to_max_bin(rate=44100.0, n_fft=4096, bandwidth=16000) 111 | 112 | target_models = {} 113 | for target in targets: 114 | # load open unmix model 115 | target_unmix = OpenUnmix(nb_bins=4096 // 2 + 1, nb_channels=2, hidden_size=512, max_bin=max_bin) 116 | 117 | # enable centering of stft to minimize reconstruction error 118 | if pretrained: 119 | state_dict = torch.hub.load_state_dict_from_url(target_urls[target], map_location=device) 120 | target_unmix.load_state_dict(state_dict, strict=False) 121 | target_unmix.eval() 122 | 123 | target_unmix.to(device) 124 | target_models[target] = target_unmix 125 | return target_models 126 | 127 | 128 | def umxhq( 129 | targets=None, 130 | residual=False, 131 | niter=1, 132 | device="cpu", 133 | pretrained=True, 134 | wiener_win_len=300, 135 | filterbank="torch", 136 | ): 137 | """ 138 | Open Unmix 2-channel/stereo BiLSTM Model trained on MUSDB18-HQ 139 | 140 | Args: 141 | targets (str): select the targets for the source to be separated. 142 | a list including: ['vocals', 'drums', 'bass', 'other']. 143 | If you don't pick them all, you probably want to 144 | activate the `residual=True` option. 145 | Defaults to all available targets per model. 146 | pretrained (bool): If True, returns a model pre-trained on MUSDB18-HQ 147 | residual (bool): if True, a "garbage" target is created 148 | niter (int): the number of post-processingiterations, defaults to 0 149 | device (str): selects device to be used for inference 150 | wiener_win_len (int or None): The size of the excerpts 151 | (number of frames) on which to apply filtering 152 | independently. This means assuming time varying stereo models and 153 | localization of sources. 154 | None means not batching but using the whole signal. It comes at the 155 | price of a much larger memory usage. 156 | filterbank (str): filterbank implementation method. 157 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster 158 | compared to `asteroid` on large FFT sizes such as 4096. However, 159 | asteroids stft can be exported to onnx, which makes is practical 160 | for deployment. 
161 | """ 162 | 163 | from .model import Separator 164 | 165 | target_models = umxhq_spec(targets=targets, device=device, pretrained=pretrained) 166 | 167 | separator = Separator( 168 | target_models=target_models, 169 | niter=niter, 170 | residual=residual, 171 | n_fft=4096, 172 | n_hop=1024, 173 | nb_channels=2, 174 | sample_rate=44100.0, 175 | wiener_win_len=wiener_win_len, 176 | filterbank=filterbank, 177 | ).to(device) 178 | 179 | return separator 180 | 181 | 182 | def umx_spec(targets=None, device="cpu", pretrained=True): 183 | from .model import OpenUnmix 184 | 185 | # set urls for weights 186 | target_urls = { 187 | "bass": "https://zenodo.org/records/3370486/files/bass-646024d3.pth", 188 | "drums": "https://zenodo.org/records/3370486/files/drums-5a48008b.pth", 189 | "other": "https://zenodo.org/records/3370486/files/other-f8e132cc.pth", 190 | "vocals": "https://zenodo.org/records/3370486/files/vocals-c8df74a5.pth", 191 | } 192 | 193 | if targets is None: 194 | targets = ["vocals", "drums", "bass", "other"] 195 | 196 | # determine the maximum bin count for a 16khz bandwidth model 197 | max_bin = utils.bandwidth_to_max_bin(rate=44100.0, n_fft=4096, bandwidth=16000) 198 | 199 | target_models = {} 200 | for target in targets: 201 | # load open unmix model 202 | target_unmix = OpenUnmix(nb_bins=4096 // 2 + 1, nb_channels=2, hidden_size=512, max_bin=max_bin) 203 | 204 | # enable centering of stft to minimize reconstruction error 205 | if pretrained: 206 | state_dict = torch.hub.load_state_dict_from_url(target_urls[target], map_location=device) 207 | target_unmix.load_state_dict(state_dict, strict=False) 208 | target_unmix.eval() 209 | 210 | target_unmix.to(device) 211 | target_models[target] = target_unmix 212 | return target_models 213 | 214 | 215 | def umx( 216 | targets=None, 217 | residual=False, 218 | niter=1, 219 | device="cpu", 220 | pretrained=True, 221 | wiener_win_len=300, 222 | filterbank="torch", 223 | ): 224 | """ 225 | Open Unmix 2-channel/stereo BiLSTM Model trained on MUSDB18 226 | 227 | Args: 228 | targets (str): select the targets for the source to be separated. 229 | a list including: ['vocals', 'drums', 'bass', 'other']. 230 | If you don't pick them all, you probably want to 231 | activate the `residual=True` option. 232 | Defaults to all available targets per model. 233 | pretrained (bool): If True, returns a model pre-trained on MUSDB18-HQ 234 | residual (bool): if True, a "garbage" target is created 235 | niter (int): the number of post-processingiterations, defaults to 0 236 | device (str): selects device to be used for inference 237 | wiener_win_len (int or None): The size of the excerpts 238 | (number of frames) on which to apply filtering 239 | independently. This means assuming time varying stereo models and 240 | localization of sources. 241 | None means not batching but using the whole signal. It comes at the 242 | price of a much larger memory usage. 243 | filterbank (str): filterbank implementation method. 244 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster 245 | compared to `asteroid` on large FFT sizes such as 4096. However, 246 | asteroids stft can be exported to onnx, which makes is practical 247 | for deployment. 
248 | 249 | """ 250 | 251 | from .model import Separator 252 | 253 | target_models = umx_spec(targets=targets, device=device, pretrained=pretrained) 254 | separator = Separator( 255 | target_models=target_models, 256 | niter=niter, 257 | residual=residual, 258 | n_fft=4096, 259 | n_hop=1024, 260 | nb_channels=2, 261 | sample_rate=44100.0, 262 | wiener_win_len=wiener_win_len, 263 | filterbank=filterbank, 264 | ).to(device) 265 | 266 | return separator 267 | 268 | 269 | def umxl_spec(targets=None, device="cpu", pretrained=True): 270 | from .model import OpenUnmix 271 | 272 | # set urls for weights 273 | target_urls = { 274 | "bass": "https://zenodo.org/records/5069601/files/bass-2ca1ce51.pth", 275 | "drums": "https://zenodo.org/records/5069601/files/drums-69e0ebd4.pth", 276 | "other": "https://zenodo.org/records/5069601/files/other-c8c5b3e6.pth", 277 | "vocals": "https://zenodo.org/records/5069601/files/vocals-bccbd9aa.pth", 278 | } 279 | 280 | if targets is None: 281 | targets = ["vocals", "drums", "bass", "other"] 282 | 283 | # determine the maximum bin count for a 16khz bandwidth model 284 | max_bin = utils.bandwidth_to_max_bin(rate=44100.0, n_fft=4096, bandwidth=16000) 285 | 286 | target_models = {} 287 | for target in targets: 288 | # load open unmix model 289 | target_unmix = OpenUnmix(nb_bins=4096 // 2 + 1, nb_channels=2, hidden_size=1024, max_bin=max_bin) 290 | 291 | # enable centering of stft to minimize reconstruction error 292 | if pretrained: 293 | state_dict = torch.hub.load_state_dict_from_url(target_urls[target], map_location=device) 294 | target_unmix.load_state_dict(state_dict, strict=False) 295 | target_unmix.eval() 296 | 297 | target_unmix.to(device) 298 | target_models[target] = target_unmix 299 | return target_models 300 | 301 | 302 | def umxl( 303 | targets=None, 304 | residual=False, 305 | niter=1, 306 | device="cpu", 307 | pretrained=True, 308 | wiener_win_len=300, 309 | filterbank="torch", 310 | ): 311 | """ 312 | Open Unmix Extra (UMX-L), 2-channel/stereo BLSTM Model trained on a private dataset 313 | of ~400h of multi-track audio. 314 | 315 | 316 | Args: 317 | targets (str): select the targets for the source to be separated. 318 | a list including: ['vocals', 'drums', 'bass', 'other']. 319 | If you don't pick them all, you probably want to 320 | activate the `residual=True` option. 321 | Defaults to all available targets per model. 322 | pretrained (bool): If True, returns a model pre-trained on MUSDB18-HQ 323 | residual (bool): if True, a "garbage" target is created 324 | niter (int): the number of post-processingiterations, defaults to 0 325 | device (str): selects device to be used for inference 326 | wiener_win_len (int or None): The size of the excerpts 327 | (number of frames) on which to apply filtering 328 | independently. This means assuming time varying stereo models and 329 | localization of sources. 330 | None means not batching but using the whole signal. It comes at the 331 | price of a much larger memory usage. 332 | filterbank (str): filterbank implementation method. 333 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster 334 | compared to `asteroid` on large FFT sizes such as 4096. However, 335 | asteroids stft can be exported to onnx, which makes is practical 336 | for deployment. 
337 | 338 | """ 339 | 340 | from .model import Separator 341 | 342 | target_models = umxl_spec(targets=targets, device=device, pretrained=pretrained) 343 | separator = Separator( 344 | target_models=target_models, 345 | niter=niter, 346 | residual=residual, 347 | n_fft=4096, 348 | n_hop=1024, 349 | nb_channels=2, 350 | sample_rate=44100.0, 351 | wiener_win_len=wiener_win_len, 352 | filterbank=filterbank, 353 | ).to(device) 354 | 355 | return separator 356 | -------------------------------------------------------------------------------- /docs/evaluate.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | openunmix.evaluate API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 |
21 |
22 |
23 |

Module openunmix.evaluate

24 |
25 |
26 |
27 | 28 | Expand source code 29 | Browse git 30 | 31 |
import argparse
 32 | import functools
 33 | import json
 34 | import multiprocessing
 35 | from typing import Optional, Union
 36 | 
 37 | import musdb
 38 | import museval
 39 | import torch
 40 | import tqdm
 41 | 
 42 | from openunmix import utils
 43 | 
 44 | 
 45 | def separate_and_evaluate(
 46 |     track: musdb.MultiTrack,
 47 |     targets: list,
 48 |     model_str_or_path: str,
 49 |     niter: int,
 50 |     output_dir: str,
 51 |     eval_dir: str,
 52 |     residual: bool,
 53 |     mus,
 54 |     aggregate_dict: dict = None,
 55 |     device: Union[str, torch.device] = "cpu",
 56 |     wiener_win_len: Optional[int] = None,
 57 |     filterbank="torch",
 58 | ) -> str:
 59 | 
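    # Apply a pre-trained separator to a single MUSDB track: load the separator
    # for the requested targets, separate the track, optionally write the
    # estimates to `output_dir`, and return the museval scores for that track.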
 60 |     separator = utils.load_separator(
 61 |         model_str_or_path=model_str_or_path,
 62 |         targets=targets,
 63 |         niter=niter,
 64 |         residual=residual,
 65 |         wiener_win_len=wiener_win_len,
 66 |         device=device,
 67 |         pretrained=True,
 68 |         filterbank=filterbank,
 69 |     )
 70 | 
 71 |     separator.freeze()
 72 |     separator.to(device)
 73 | 
 74 |     audio = torch.as_tensor(track.audio, dtype=torch.float32, device=device)
 75 |     audio = utils.preprocess(audio, track.rate, separator.sample_rate)
 76 | 
 77 |     estimates = separator(audio)
 78 |     estimates = separator.to_dict(estimates, aggregate_dict=aggregate_dict)
 79 | 
 80 |     for key in estimates:
 81 |         estimates[key] = estimates[key][0].cpu().detach().numpy().T
 82 |     if output_dir:
 83 |         mus.save_estimates(estimates, track, output_dir)
 84 | 
 85 |     scores = museval.eval_mus_track(track, estimates, output_dir=eval_dir)
 86 |     return scores
 87 | 
 88 | 
 89 | if __name__ == "__main__":
 90 |     # Training settings
 91 |     parser = argparse.ArgumentParser(description="MUSDB18 Evaluation", add_help=False)
 92 | 
 93 |     parser.add_argument(
 94 |         "--targets",
 95 |         nargs="+",
 96 |         default=["vocals", "drums", "bass", "other"],
 97 |         type=str,
 98 |         help="provide targets to be processed. \
 99 |               If none, all available targets will be computed",
100 |     )
101 | 
102 |     parser.add_argument(
103 |         "--model",
104 |         default="umxhq",
105 |         type=str,
106 |         help="path to mode base directory of pretrained models",
107 |     )
108 | 
109 |     parser.add_argument(
110 |         "--outdir",
111 |         type=str,
112 |         help="Results path where audio evaluation results are stored",
113 |     )
114 | 
115 |     parser.add_argument("--evaldir", type=str, help="Results path for museval estimates")
116 | 
117 |     parser.add_argument("--root", type=str, help="Path to MUSDB18")
118 | 
119 |     parser.add_argument("--subset", type=str, default="test", help="MUSDB subset (`train`/`test`)")
120 | 
121 |     parser.add_argument("--cores", type=int, default=1)
122 | 
123 |     parser.add_argument(
124 |         "--no-cuda", action="store_true", default=False, help="disables CUDA inference"
125 |     )
126 | 
127 |     parser.add_argument(
128 |         "--is-wav",
129 |         action="store_true",
130 |         default=False,
131 |         help="flags wav version of the dataset",
132 |     )
133 | 
134 |     parser.add_argument(
135 |         "--niter",
136 |         type=int,
137 |         default=1,
138 |         help="number of iterations for refining results.",
139 |     )
140 | 
141 |     parser.add_argument(
142 |         "--wiener-win-len",
143 |         type=int,
144 |         default=300,
145 |         help="Number of frames on which to apply filtering independently",
146 |     )
147 | 
148 |     parser.add_argument(
149 |         "--residual",
150 |         type=str,
151 |         default=None,
152 |         help="if provided, build a source with given name"
153 |         "for the mix minus all estimated targets",
154 |     )
155 | 
156 |     parser.add_argument(
157 |         "--aggregate",
158 |         type=str,
159 |         default=None,
160 |         help="if provided, must be a string containing a valid expression for "
161 |         "a dictionary, with keys as output target names, and values "
162 |         "a list of targets that are used to build it. For instance: "
163 |         '\'{"vocals":["vocals"], "accompaniment":["drums",'
164 |         '"bass","other"]}\'',
165 |     )
166 | 
167 |     args = parser.parse_args()
168 | 
169 |     use_cuda = not args.no_cuda and torch.cuda.is_available()
170 |     device = torch.device("cuda" if use_cuda else "cpu")
171 | 
172 |     mus = musdb.DB(
173 |         root=args.root,
174 |         download=args.root is None,
175 |         subsets=args.subset,
176 |         is_wav=args.is_wav,
177 |     )
178 |     aggregate_dict = None if args.aggregate is None else json.loads(args.aggregate)
179 | 
180 |     if args.cores > 1:
181 |         pool = multiprocessing.Pool(args.cores)
182 |         results = museval.EvalStore()
183 |         scores_list = list(
184 |             pool.imap_unordered(
185 |                 func=functools.partial(
186 |                     separate_and_evaluate,
187 |                     targets=args.targets,
188 |                     model_str_or_path=args.model,
189 |                     niter=args.niter,
190 |                     residual=args.residual,
191 |                     mus=mus,
192 |                     aggregate_dict=aggregate_dict,
193 |                     output_dir=args.outdir,
194 |                     eval_dir=args.evaldir,
195 |                     device=device,
196 |                 ),
197 |                 iterable=mus.tracks,
198 |                 chunksize=1,
199 |             )
200 |         )
201 |         pool.close()
202 |         pool.join()
203 |         for scores in scores_list:
204 |             results.add_track(scores)
205 | 
206 |     else:
207 |         results = museval.EvalStore()
208 |         for track in tqdm.tqdm(mus.tracks):
209 |             scores = separate_and_evaluate(
210 |                 track,
211 |                 targets=args.targets,
212 |                 model_str_or_path=args.model,
213 |                 niter=args.niter,
214 |                 residual=args.residual,
215 |                 mus=mus,
216 |                 aggregate_dict=aggregate_dict,
217 |                 output_dir=args.outdir,
218 |                 eval_dir=args.evaldir,
219 |                 device=device,
220 |             )
221 |             print(track, "\n", scores)
222 |             results.add_track(scores)
223 | 
224 |     print(results)
225 |     method = museval.MethodStore()
226 |     method.add_evalstore(results, args.model)
227 |     method.save(args.model + ".pandas")
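    # Example invocation (paths are illustrative), matching the command shown in
    # the README's evaluation section:
    #   python -m openunmix.evaluate --root /path/to/musdb18 --outdir ./estimates --evaldir ./museval_results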
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |

Functions

236 |
237 |
238 | def separate_and_evaluate(track: musdb.audio_classes.MultiTrack, targets: list, model_str_or_path: str, niter: int, output_dir: str, eval_dir: str, residual: bool, mus, aggregate_dict: dict = None, device: Union[str, torch.device] = 'cpu', wiener_win_len: Union[int, NoneType] = None, filterbank='torch') ‑> str 239 |
240 |
241 |
242 |
243 | 244 | Expand source code 245 | Browse git 246 | 247 |
def separate_and_evaluate(
248 |     track: musdb.MultiTrack,
249 |     targets: list,
250 |     model_str_or_path: str,
251 |     niter: int,
252 |     output_dir: str,
253 |     eval_dir: str,
254 |     residual: bool,
255 |     mus,
256 |     aggregate_dict: dict = None,
257 |     device: Union[str, torch.device] = "cpu",
258 |     wiener_win_len: Optional[int] = None,
259 |     filterbank="torch",
260 | ) -> str:
261 | 
262 |     separator = utils.load_separator(
263 |         model_str_or_path=model_str_or_path,
264 |         targets=targets,
265 |         niter=niter,
266 |         residual=residual,
267 |         wiener_win_len=wiener_win_len,
268 |         device=device,
269 |         pretrained=True,
270 |         filterbank=filterbank,
271 |     )
272 | 
273 |     separator.freeze()
274 |     separator.to(device)
275 | 
276 |     audio = torch.as_tensor(track.audio, dtype=torch.float32, device=device)
277 |     audio = utils.preprocess(audio, track.rate, separator.sample_rate)
278 | 
279 |     estimates = separator(audio)
280 |     estimates = separator.to_dict(estimates, aggregate_dict=aggregate_dict)
281 | 
282 |     for key in estimates:
283 |         estimates[key] = estimates[key][0].cpu().detach().numpy().T
284 |     if output_dir:
285 |         mus.save_estimates(estimates, track, output_dir)
286 | 
287 |     scores = museval.eval_mus_track(track, estimates, output_dir=eval_dir)
288 |     return scores
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 | 314 |
315 | 318 | 319 | -------------------------------------------------------------------------------- /docs/training.md: -------------------------------------------------------------------------------- 1 | # Training Open-Unmix 2 | 3 | > This documentation refers to the standard training procedure for _Open-unmix_, where each target is trained independently. It has not been updated for the end-to-end training capabilities that the `Separator` module allows. Please contribute if you try this. 4 | 5 | Both models, `umxhq` and `umx` that are provided with pre-trained weights, can be trained using the default parameters of the `scripts/train.py` function. 6 | 7 | ## Installation 8 | 9 | The train function is not part of the python package, thus we suggest to use [Anaconda](https://anaconda.org/) to install the training requirments, also because the environment would allow reproducible results. 10 | 11 | To create a conda environment for _open-unmix_, simply run: 12 | 13 | `conda env create -f scripts/environment-X.yml` where `X` is either [`cpu-linux`, `gpu-linux-cuda10`, `cpu-osx`], depending on your system. For now, we haven't tested windows support. 14 | 15 | ## Training API 16 | 17 | The [MUSDB18](https://sigsep.github.io/datasets/musdb.html) and [MUSDB18-HQ](https://sigsep.github.io/datasets/musdb.html) are the largest freely available datasets for professionally produced music tracks (~10h duration) of different styles. They come with isolated `drums`, `bass`, `vocals` and `others` stems. _MUSDB18_ contains two subsets: "train", composed of 100 songs, and "test", composed of 50 songs. 18 | 19 | To directly train a vocal model with _open-unmix_, we first would need to download one of the datasets and place in _unzipped_ in a directory of your choice (called `root`). 20 | 21 | | Argument | Description | Default | 22 | |----------|-------------|---------| 23 | | `--root ` | path to root of dataset on disk. | `None` | 24 | 25 | Also note that, if `--root` is not specified, we automatically download a 7 second preview version of the MUSDB18 dataset. While this is comfortable for testing purposes, we wouldn't recommend to actually train your model on this. 26 | 27 | Training can be started using 28 | 29 | ```bash 30 | python train.py --root path/to/musdb18 --target vocals 31 | ``` 32 | 33 | Training `MUSDB18` using _open-unmix_ comes with several design decisions that we made as part of our defaults to improve efficiency and performance: 34 | 35 | * __chunking__: we do not feed full audio tracks into _open-unmix_ but instead chunk the audio into 6s excerpts (`--seq-dur 6.0`). 36 | * __balanced track sampling__: to not create a bias for longer audio tracks we randomly yield one track from MUSDB18 and select a random chunk subsequently. In one epoch we select (on average) 64 samples from each track. 37 | * __source augmentation__: we apply random gains between `0.25` and `1.25` to all sources before mixing. Furthermore, we randomly swap the channels the input mixture. 38 | * __random track mixing__: for a given target we select a _random track_ with replacement. To yield a mixture we draw the interfering sources from different tracks (again with replacement) to increase generalization of the model. 39 | * __fixed validation split__: we provide a fixed validation split of [14 tracks](https://github.com/sigsep/sigsep-mus-db/blob/b283da5b8f24e84172a60a06bb8f3dacd57aa6cd/musdb/configs/mus.yaml#L41). 
We evaluate on these tracks in full length instead of using chunking to have evaluation as close as possible to the actual test data. 40 | 41 | Some of the parameters for the MUSDB sampling can be controlled using the following arguments: 42 | 43 | | Argument | Description | Default | 44 | |---------------------|-----------------------------------------------|--------------| 45 | | `--is-wav` | loads the decoded WAVs instead of STEMS for faster data loading. See [more details here](https://github.com/sigsep/sigsep-mus-db#using-wav-files-optional). | `False` | 46 | | `--samples-per-track ` | sets the number of samples that are randomly drawn from each track | `64` | 47 | | `--source-augmentations ` | applies augmentations to each audio source before mixing, available augmentations: `[gain, channelswap]`| [gain, channelswap] | 48 | 49 | ## Training and Model Parameters 50 | 51 | An extensive list of additional training parameters allows researchers to quickly try out different parameterizations such as a different FFT size. The table below, we list the additional training parameters and their default values (used for `umxhq` and `umx`L: 52 | 53 | | Argument | Description | Default | 54 | |----------------------------|---------------------------------------------------------------------------------|-----------------| 55 | | `--target ` | name of target source (will be passed to the dataset) | `vocals` | 56 | | `--output ` | path where to save the trained output model as well as checkpoints. | `./open-unmix` | 57 | | `--checkpoint ` | path to checkpoint of target model to resume training. | not set | 58 | | `--model ` | path or str to pretrained target to fine-tune model | not set | 59 | | `--no_cuda` | disable cuda even if available | not set | 60 | | `--epochs ` | Number of epochs to train | `1000` | 61 | | `--batch-size ` | Batch size has influence on memory usage and performance of the LSTM layer | `16` | 62 | | `--patience ` | early stopping patience | `140` | 63 | | `--seq-dur ` | Sequence duration in seconds of chunks taken from the dataset. A value of `<=0.0` results in full/variable length | `6.0` | 64 | | `--unidirectional` | changes the bidirectional LSTM to unidirectional (for real-time applications) | not set | 65 | | `--hidden-size ` | Hidden size parameter of dense bottleneck layers | `512` | 66 | | `--nfft ` | STFT FFT window length in samples | `4096` | 67 | | `--nhop ` | STFT hop length in samples | `1024` | 68 | | `--lr ` | learning rate | `0.001` | 69 | | `--lr-decay-patience ` | learning rate decay patience for plateau scheduler | `80` | 70 | | `--lr-decay-gamma ` | gamma of learning rate plateau scheduler. | `0.3` | 71 | | `--weight-decay ` | weight decay for regularization | `0.00001` | 72 | | `--bandwidth ` | maximum bandwidth in Hertz processed by the LSTM. Input and Output is always full bandwidth! | `16000` | 73 | | `--nb-channels ` | set number of channels for model (1 for mono (spectral downmix is applied,) 2 for stereo) | `2` | 74 | | `--nb-workers ` | Number of (parallel) workers for data-loader, can be safely increased for wav files | `0` | 75 | | `--quiet` | disable print and progress bar during training | not set | 76 | | `--seed ` | Initial seed to set the random initialization | `42` | 77 | | `--audio-backend ` | choose audio loading backend, either `sox` or `soundfile` | `soundfile` for training, `sox` for inference | 78 | 79 | ### Training details of `umxhq` 80 | 81 | The training of `umxhq` took place on Nvidia RTX2080 cards. 
Equipped with fast SSDs and `--nb-workers 4`, we could utilize around 90% of the GPU, thus training time was around 80 seconds per epoch. We ran four different seeds for each target and selected the model with the lowest validation loss. 82 | 83 | The training and validation loss curves are plotted below: 84 | 85 | ![umx-hq](https://user-images.githubusercontent.com/72940/61230598-9e6e3b00-a72a-11e9-8a89-aca1862341eb.png) 86 | 87 | ## Other Datasets 88 | 89 | _open-unmix_ uses standard PyTorch [`torch.utils.data.Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) classes. The repository comes with __five__ different datasets which cover a wide range of tasks and applications around source separation. Furthermore, we also provide a template Dataset if you want to start using your own dataset. The dataset can be selected through a command line argument: 90 | 91 | | Argument | Description | Default | 92 | |----------------------------|------------------------------------------------------------------------|--------------| 93 | | `--dataset ` | Name of the dataset (select from `musdb`, `aligned`, `sourcefolder`, `trackfolder_var`, `trackfolder_fix`) | `musdb` | 94 | 95 | ### `AlignedDataset` (aligned) 96 | 97 | This dataset assumes multiple track folders, where each track includes an input and one output file, directly corresponding to the input and the output of the model. 98 | 99 | This dataset is the most basic of all datasets provided here; due to the least amount of 100 | preprocessing, it is also the fastest option. However, it lacks any kind of source augmentations or custom mixing. Instead, it directly uses the target files that are within the folder. The filenames would have to be identical for each track. E.g., for the first sample of the training, input could be `1/mixture.wav` and output could be `1/vocals.wav`. 101 | 102 | Typical use cases: 103 | 104 | * Source Separation (Mixture -> Target) 105 | * Denoising (Noisy -> Clean) 106 | * Bandwidth Extension (Low Bandwidth -> High Bandwidth) 107 | 108 | #### File Structure 109 | 110 | ``` 111 | data/train/1/mixture.wav --> input 112 | data/train/1/vocals.wav ---> output 113 | ... 114 | data/valid/1/mixture.wav --> input 115 | data/valid/1/vocals.wav ---> output 116 | 117 | ``` 118 | 119 | #### Parameters 120 | 121 | | Argument | Description | Default | 122 | |----------|-------------|---------| 123 | |`--input-file ` | input file name | `None` | 124 | |`--output-file ` | output file name | `None` | 125 | 126 | #### Example 127 | 128 | ```bash 129 | python train.py --dataset aligned --root /dataset --input_file mixture.wav --output_file vocals.wav 130 | ``` 131 | 132 | ### `SourceFolderDataset` (sourcefolder) 133 | 134 | A dataset that assumes folders of sources, 135 | instead of track folders. This is a common 136 | format for speech and environmental sound datasets 137 | such as DCASE. For each source a variable number of 138 | tracks/sounds is available, therefore the dataset is unaligned by design. 139 | 140 | In this scenario one could easily train a network to separate target sounds from interfering sounds. For each sample, the data loader loads a random combination of target+interferer as the input and performs a linear mixture of these. The output of the model is the target.
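To make the sampling idea concrete, here is a minimal sketch (illustrative only, not the actual `SourceFolderDataset` implementation) of drawing one random excerpt per source directory and mixing them linearly:

```python
import random


def draw_training_pair(target_excerpts, interferer_excerpts_per_dir):
    """Illustrative sketch of the sourcefolder sampling idea (not the actual
    dataset class). All excerpts are assumed to be equal-length tensors of
    shape (nb_channels, nb_timesteps), already loaded from the source folders."""
    target = random.choice(target_excerpts)
    interferers = [random.choice(excerpts) for excerpts in interferer_excerpts_per_dir]
    mixture = target + sum(interferers)  # linear mix of target + interferers -> model input
    return mixture, target               # the isolated target is the model output
```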
141 | 142 | #### File structure 143 | 144 | ``` 145 | train/vocals/track11.wav -----------------\ 146 | train/drums/track202.wav (interferer1) ---+--> input 147 | train/bass/track007a.wav (interferer2) --/ 148 | 149 | train/vocals/track11.wav ---------------------> output 150 | ``` 151 | 152 | #### Parameters 153 | 154 | | Argument | Description | Default | 155 | |----------|-------------|---------| 156 | |`--interferer-dirs list[]` | list of directories used as interferers | `None` | 157 | |`--target-dir ` | directory that contains the target source | `None` | 158 | |`--ext ` | File extension | `.wav` | 159 | |`--ext ` | File extension | `.wav` | 160 | |`--nb-train-samples ` | Number of samples drawn for training | `1000` | 161 | |`--nb-valid-samples ` | Number of samples drawn for validation | `100` | 162 | |`--source-augmentations list[]` | List of augmentation functions that are processed in the order of the list | | 163 | 164 | #### Example 165 | 166 | ```bash 167 | python train.py --dataset sourcefolder --root /data --target-dir vocals --interferer-dirs car_noise wind_noise --ext .ogg --nb-train-samples 1000 168 | ``` 169 | 170 | ### `FixedSourcesTrackFolderDataset` (trackfolder_fix) 171 | 172 | A dataset of that assumes audio sources to be stored 173 | in track folder where each track has a fixed number of sources. For each track the users specifies the target file-name (`target_file`) and a list of interferences files (`interferer_files`). 174 | A linear mix is performed on the fly by summing the target and the interferers up. 175 | 176 | Due to the fact that all tracks comprise the exact same set of sources, the random track mixing augmentation technique can be used, where sources from different tracks are mixed together. Setting `random_track_mix=True` results in an unaligned dataset. 177 | When random track mixing is enabled, we define an epoch as when the the target source from all tracks has been seen and only once with whatever interfering sources has randomly been drawn. 178 | 179 | This dataset is recommended to be used for small/medium size for example like the MUSDB18 or other custom source separation datasets. 180 | 181 | #### File structure 182 | 183 | ```sh 184 | train/1/vocals.wav ---------------\ 185 | train/1/drums.wav (interferer1) ---+--> input 186 | train/1/bass.wav -(interferer2) --/ 187 | 188 | train/1/vocals.wav -------------------> output 189 | ``` 190 | 191 | #### Parameters 192 | 193 | | Argument | Description | Default | 194 | |----------|-------------|---------| 195 | |`--target-file ` | Target file (includes extension) | `None` | 196 | |`--interferer-files list[]` | list of interfering sources | `None` | 197 | |`--random-track-mix` | Applies random track mixing | `False` | 198 | |`--source-augmentations list[]` | List of augmentation functions that are processed in the order of the list | | 199 | 200 | #### Example 201 | 202 | ``` 203 | python train.py --root /data --dataset trackfolder_fix --target-file vocals.flac --interferer-files bass.flac drums.flac other.flac 204 | ``` 205 | 206 | ### `VariableSourcesTrackFolderDataset` (trackfolder_var) 207 | 208 | A dataset of that assumes audio sources to be stored in track folder where each track has a _variable_ number of sources. The users specifies the target file-name (`target_file`) and the extension of sources to used for mixing. A linear mix is performed on the fly by summing all sources in a track folder. 
209 | 210 | Since the number of sources differ per track, while target is fixed, a random track mix augmentation cannot be used. 211 | Also make sure, that you do not provide the mixture file among the sources! This dataset maximizes the number of tracks that can be used since it doesn't require the presence of a fixed number of sources per track. However, it is required to 212 | have the target file to be present. To increase the dataset utilization even further users can enable the `--silence-missing-targets` option that outputs silence to missing targets. 213 | 214 | #### File structure 215 | 216 | ```sh 217 | train/1/vocals.wav --> input target \ 218 | train/1/drums.wav --> input target | 219 | train/1/bass.wav --> input target --+--> input 220 | train/1/accordion.wav --> input target | 221 | train/1/marimba.wav --> input target / 222 | 223 | train/1/vocals.wav -----------------------> output 224 | ``` 225 | 226 | #### Parameters 227 | 228 | | Argument | Description | Default | 229 | |----------|-------------|---------| 230 | |`--target-file ` | file name of target file | `None` | 231 | |`--silence-missing-targets` | if a target is not among the list of sources it will be filled with zero | not set | 232 | |`random interferer mixing` | use _random track_ for the inference track to increase generalization of the model. | not set | 233 | |`--ext ` | File extension that is used to find the interfering files | `.wav` | 234 | |`--source-augmentations list[]` | List of augmentation functions that are processed in the order of the list | | 235 | 236 | #### Example 237 | 238 | ``` 239 | python train.py --root /data --dataset trackfolder_var --target-file vocals.flac --ext .wav 240 | ``` 241 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # _Open-Unmix_ for PyTorch 2 | 3 | [![status](https://joss.theoj.org/papers/571753bc54c5d6dd36382c3d801de41d/status.svg)](https://joss.theoj.org/papers/571753bc54c5d6dd36382c3d801de41d) 4 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1mijF0zGWxN-KaxTnd0q6hayAlrID5fEQ) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/open-unmix-a-reference-implementation-for/music-source-separation-on-musdb18)](https://paperswithcode.com/sota/music-source-separation-on-musdb18?p=open-unmix-a-reference-implementation-for) 6 | 7 | [![CI](https://github.com/sigsep/open-unmix-pytorch/actions/workflows/test_unittests.yml/badge.svg?event=pull_request)](https://github.com/sigsep/open-unmix-pytorch/actions/workflows/test_unittests.yml)[![Latest Version](https://img.shields.io/pypi/v/openunmix.svg)](https://pypi.python.org/pypi/openunmix) 8 | [![Supported Python versions](https://img.shields.io/pypi/pyversions/openunmix.svg)](https://pypi.python.org/pypi/openunmix) 9 | 10 | This repository contains the PyTorch (1.8+) implementation of __Open-Unmix__, a deep neural network reference implementation for music source separation, applicable for researchers, audio engineers and artists. __Open-Unmix__ provides ready-to-use models that allow users to separate pop music into four stems: __vocals__, __drums__, __bass__ and the remaining __other__ instruments. The models were pre-trained on the freely available [MUSDB18](https://sigsep.github.io/datasets/musdb.html) dataset. See details at [apply pre-trained model](#getting-started). 
11 | 
12 | ## ⭐️ News
13 | 
14 | - 16/04/2024: We brought the repo to torch 2.0 level. Everything seems to work fine again, but we needed to relax the regression tests. With the most recent version, results are slightly different, so be warned when running the unit tests.
15 | - 03/07/2021: We added `umxl`, a model trained on extra data, which significantly improves the performance, especially generalization.
16 | - 14/02/2021: We released the new version of open-unmix as a python package. This comes with a fully differentiable version of [norbert](https://github.com/sigsep/norbert), an improved audio loading pipeline and a large number of bug fixes. See [release notes](https://github.com/sigsep/open-unmix-pytorch/releases/) for further info.
17 | 
18 | - 06/05/2020: We added a pre-trained speech enhancement model `umxse` provided by Sony.
19 | 
20 | - 13/03/2020: Open-unmix was awarded 2nd place in the [PyTorch Global Summer Hackathon 2020](https://devpost.com/software/open-unmix).
21 | 
22 | __Related Projects:__ open-unmix-pytorch | [open-unmix-nnabla](https://github.com/sigsep/open-unmix-nnabla) | [musdb](https://github.com/sigsep/sigsep-mus-db) | [museval](https://github.com/sigsep/sigsep-mus-eval) | [norbert](https://github.com/sigsep/norbert)
23 | 
24 | ## 🧠 The Model (for one source)
25 | 
26 | ![](https://docs.google.com/drawings/d/e/2PACX-1vTPoQiPwmdfET4pZhue1RvG7oEUJz7eUeQvCu6vzYeKRwHl6by4RRTnphImSKM0k5KXw9rZ1iIFnpGW/pub?w=959&h=308)
27 | 
28 | To perform separation into multiple sources, _Open-unmix_ comprises multiple models that are trained for each particular target. While this makes training less convenient, it allows great flexibility to customize the training data for each target source.
29 | 
30 | Each _Open-Unmix_ source model is based on a three-layer bidirectional deep LSTM. The model learns to predict the magnitude spectrogram of a target source, like _vocals_, from the magnitude spectrogram of a mixture input. Internally, the prediction is obtained by applying a mask on the input. The model is optimized in the magnitude domain using mean squared error.
31 | 
32 | ### Input Stage
33 | 
34 | __Open-Unmix__ operates in the time-frequency domain to perform its prediction. The input of the model is either:
35 | 
36 | * __`models.Separator`:__ A time domain signal tensor of shape `(nb_samples, nb_channels, nb_timesteps)`, where `nb_samples` are the samples in a batch, `nb_channels` is 1 or 2 for mono or stereo audio, respectively, and `nb_timesteps` is the number of audio samples in the recording. In this case, the model computes STFTs with either `torch` or `asteroid-filterbanks` on the fly.
37 | 
38 | * __`models.OpenUnmix`:__ The core open-unmix takes **magnitude spectrograms** directly (e.g. when pre-computed and loaded from disk). In that case, the input is of shape `(nb_frames, nb_samples, nb_channels, nb_bins)`, where `nb_frames` and `nb_bins` are the time and frequency dimensions of a Short-Time Fourier Transform.
39 | 
40 | The input spectrogram is _standardized_ using the global mean and standard deviation for every frequency bin across all frames. Furthermore, we apply batch normalization in multiple stages of the model to make the training more robust against gain variation.
41 | 
42 | ### Dimensionality reduction
43 | 
44 | The LSTM does not operate at the original input spectrogram resolution. Instead, in the first step after the normalization, the network learns to compress the frequency and channel axes of the input to reduce redundancy and make the model converge faster.
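As a schematic illustration of such a compression step (the dimensions, layer sizes and activation below are assumptions for the sketch, not necessarily the shipped defaults):

```python
import torch
import torch.nn as nn

# illustrative dimensions (assumed, not the shipped defaults)
nb_frames, nb_samples, nb_channels, nb_bins = 255, 1, 2, 2049
hidden_size = 512

# magnitude spectrogram input: (nb_frames, nb_samples, nb_channels, nb_bins)
spec = torch.rand(nb_frames, nb_samples, nb_channels, nb_bins)

# compress the channel and frequency axes into a smaller per-frame feature vector
fc = nn.Linear(nb_channels * nb_bins, hidden_size, bias=False)
bn = nn.BatchNorm1d(hidden_size)

x = spec.reshape(-1, nb_channels * nb_bins)   # (nb_frames * nb_samples, nb_channels * nb_bins)
x = torch.tanh(bn(fc(x)))                     # normalized, squashed features
x = x.reshape(nb_frames, nb_samples, hidden_size)
print(x.shape)                                # torch.Size([255, 1, 512])
```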
45 | 
46 | ### Bidirectional-LSTM
47 | 
48 | The core of __open-unmix__ is a three-layer bidirectional [LSTM network](https://dl.acm.org/citation.cfm?id=1246450). Due to its recurrent nature, the model can be trained and evaluated on audio signals of arbitrary length. Since the model uses information from the past and the future simultaneously, it cannot be used in an online/real-time manner.
49 | A uni-directional model can easily be trained as described [here](docs/training.md).
50 | 
51 | ### Output Stage
52 | 
53 | After applying the LSTM, the signal is decoded back to its original input dimensionality. In the last step, the output is multiplied with the input magnitude spectrogram, so that the model effectively learns a mask.
54 | 
55 | ## 🤹‍♀️ Putting source models together: the `Separator`
56 | 
57 | `models.Separator` puts together one _Open-Unmix_ spectrogram model for each desired target, and combines their outputs through a multichannel generalized Wiener filter, before applying inverse STFTs using `torchaudio`.
58 | The filtering is a differentiable (but parameter-free) version of [norbert](https://github.com/sigsep/norbert). The separator is currently only used during inference.
59 | 
60 | ## 🏁 Getting started
61 | 
62 | ### Installation
63 | 
64 | `openunmix` can be installed from PyPI using:
65 | 
66 | ```
67 | pip install openunmix
68 | ```
69 | 
70 | Note that the PyPI version of openunmix uses `torchaudio` to load and save audio files. To increase the number of supported input and output file formats (such as STEMS export), please additionally install [stempeg](https://github.com/faroit/stempeg).
71 | 
72 | Training is not part of the open-unmix package; please follow [docs/training.md](docs/training.md) for more information.
73 | 
74 | #### Using Docker
75 | 
76 | We also provide a Docker container. Separating a local track in `~/Music/track1.wav` can be performed in a single line:
77 | 
78 | ```
79 | docker run -v ~/Music/:/data -it faroit/open-unmix-pytorch "/data/track1.wav" --outdir /data/track1
80 | ```
81 | 
82 | ### Pre-trained models
83 | 
84 | We provide three core pre-trained music separation models. All three models are end-to-end models that take waveform inputs and output the separated waveforms.
85 | 
86 | * __`umxl` (default)__ trained on a private dataset of compressed stems. __Note that the weights are only licensed for non-commercial use (CC BY-NC-SA 4.0).__
87 | 
88 |   [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5069601.svg)](https://doi.org/10.5281/zenodo.5069601)
89 | 
90 | * __`umxhq`__ trained on [MUSDB18-HQ](https://sigsep.github.io/datasets/musdb.html#uncompressed-wav), which comprises the same tracks as MUSDB18 but uncompressed, yielding a full bandwidth of 22050 Hz.
91 | 
92 |   [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3370489.svg)](https://doi.org/10.5281/zenodo.3370489)
93 | 
94 | * __`umx`__ trained on the regular [MUSDB18](https://sigsep.github.io/datasets/musdb.html#compressed-stems), which is bandwidth-limited to 16 kHz due to AAC compression. This model should be used for comparison with other (older) methods for evaluation in [SiSEC18](https://sisec18.unmix.app).
 95 | 
 96 |   [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3370486.svg)](https://doi.org/10.5281/zenodo.3370486)
 97 | 
 98 | Furthermore, we provide a model for speech enhancement trained by Sony Corporation:
 99 | 
100 | * __`umxse`__ speech enhancement model, trained on the 28-speaker version of the [Voicebank+DEMAND corpus](https://datashare.is.ed.ac.uk/handle/10283/1942?show=full).
101 | 
102 |   [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3786908.svg)](https://doi.org/10.5281/zenodo.3786908)
103 | 
104 | All four models are also available as spectrogram (core) models, which take magnitude spectrogram inputs and output separated spectrograms.
105 | These models can be loaded using `umxl_spec`, `umxhq_spec`, `umx_spec` and `umxse_spec`.
106 | 
107 | To separate audio files (`wav`, `flac`, `ogg`, but not `mp3`), just run:
108 | 
109 | ```bash
110 | umx input_file.wav
111 | ```
112 | 
113 | A more detailed list of the parameters used for the separation is given in the [inference.md](/docs/inference.md) document.
114 | 
115 | We provide a [Jupyter notebook on Google Colab](https://colab.research.google.com/drive/1mijF0zGWxN-KaxTnd0q6hayAlrID5fEQ) to experiment with open-unmix and to separate files online without any installation setup.
116 | 
117 | ### Using pre-trained models from within python
118 | 
119 | We implemented several ways to load pre-trained models and use them from within your python projects:
120 | #### When the package is installed
121 | 
122 | Loading a pre-trained model is as simple as:
123 | 
124 | ```python
125 | separator = openunmix.umxl(...)
126 | ```
127 | #### torch.hub
128 | 
129 | We also provide torch.hub compatible modules that can be loaded directly. Note that this does _not_ even require installing the open-unmix package, and should generally work when the pytorch versions match.
130 | 
131 | ```python
132 | separator = torch.hub.load('sigsep/open-unmix-pytorch', 'umxl', device=device)
133 | ```
134 | 
135 | Here, `umxl` specifies the pre-trained model.
136 | #### Performing separation
137 | 
138 | With a separator object created, one can perform separation of some `audio` (a torch.Tensor of shape `(channels, length)`, sampled at `separator.sample_rate`) through:
139 | 
140 | ```python
141 | estimates = separator(audio, ...)
142 | # returns estimates as tensor
143 | ```
144 | 
145 | Note that this requires the audio to be in the right shape and at the right sampling rate. For convenience, we provide a pre-processing helper in `openunmix.utils.preprocess(...)` that takes numpy audio and converts it to be used for open-unmix.
146 | 
147 | #### One-liner
148 | 
149 | To perform model loading, preprocessing and separation in one step, just use:
150 | 
151 | ```python
152 | from openunmix.predict import separate
153 | estimates = separate(audio, ...)
154 | ```
155 | 
156 | ### Load user-trained models
157 | 
158 | When a path instead of a model name is provided to `--model`, a pre-trained `Separator` will be loaded from disk.
159 | For example, the following files are assumed to be present when loading `--model mymodel --targets vocals`:
160 | 
161 | * `mymodel/separator.json`
162 | * `mymodel/vocals.pth`
163 | * `mymodel/vocals.json`
164 | 
165 | 
166 | Note that the separator usually joins multiple models, one for each target, and performs separation using all of them. For example,
if the separator contains `vocals` and `drums` models, two output files are generated, unless the `--residual` option is selected, in which case an additional source will be produced, containing an estimate of everything in the mixture that is not one of the selected targets.
167 | 
168 | ### Evaluation using `museval`
169 | 
170 | To perform evaluation in comparison to other SiSEC systems, you would need to install the `museval` package using
171 | 
172 | ```
173 | pip install museval
174 | ```
175 | 
176 | and then run the evaluation using
177 | 
178 | `python -m openunmix.evaluate --outdir /path/to/musdb/estimates --evaldir /path/to/museval/results`
179 | 
180 | ### Results compared to SiSEC 2018 (SDR/Vocals)
181 | 
182 | Open-Unmix yields state-of-the-art results compared to participants from [SiSEC 2018](https://sisec18.unmix.app/#/methods). The performance of `UMXHQ` and `UMX` is almost identical, since the evaluation was carried out on compressed STEMS.
183 | 
184 | ![boxplot_updated](https://user-images.githubusercontent.com/72940/63944652-3f624c80-ca72-11e9-8d33-bed701679fe6.png)
185 | 
186 | Note that
187 | 
188 | 1. [`STL1`, `TAK2`, `TAK3`, `TAU1`, `UHL3`, `UMXHQ`] were omitted as they were _not_ trained on only _MUSDB18_.
189 | 2. [`HEL1`, `TAK1`, `UHL1`, `UHL2`] are not open-source.
190 | 
191 | #### Scores (SDR in dB, median of frames, median of tracks)
192 | 
193 | |target|UMX  |UMXHQ|UMXL |
194 | |------|-----|-----|-----|
196 | |vocals|6.32 | 6.25|__7.21__ |
197 | |bass  |5.23 | 5.07|__6.02__ |
198 | |drums |5.73 | 6.04|__7.15__ |
199 | |other |4.02 | 4.28|__4.89__ |
200 | 
201 | ## Training
202 | 
203 | Details on training are provided in a separate document [here](docs/training.md).
204 | 
205 | ## Extensions
206 | 
207 | Details on how _open-unmix_ can be extended or improved for future research on music separation are described in a separate document [here](docs/extensions.md).
208 | 
209 | 
210 | ## Design Choices
211 | 
212 | We favored simplicity over performance to promote clarity of the code. The rationale is to have __open-unmix__ serve as a __baseline__ for future research while its performance still meets the current state of the art (see [Evaluation](#Evaluation)). The results are comparable to or better than those of `UHL1`/`UHL2`, which obtained the best performance among all systems trained on MUSDB18 in the [SiSEC 2018 Evaluation campaign](https://sisec18.unmix.app).
213 | We designed the code to allow researchers to reproduce existing results, quickly develop new architectures and add their own data for training and testing. We favored framework-specific implementations instead of a monolithic repository with common code for all frameworks.
214 | 
215 | ## How to contribute
216 | 
217 | _open-unmix_ is a community-focused project; we therefore encourage the community to submit bug fixes and requests for technical support through [github issues](https://github.com/sigsep/open-unmix-pytorch/issues/new/choose). For more details on how to contribute, please follow our [`CONTRIBUTING.md`](CONTRIBUTING.md). For help and support, please use the gitter chat or the google groups forums.
218 | 
219 | ### Authors
220 | 
221 | [Fabian-Robert Stöter](https://www.faroit.com/), [Antoine Liutkus](https://github.com/aliutkus), Inria and LIRMM, Montpellier, France
222 | 
223 | ## References
224 | 
225 | 
If you use open-unmix for your research – Cite Open-Unmix
226 | 
227 | ```latex
228 | @article{stoter19,
229 |   author={F.-R. St\"oter and S. Uhlich and A. Liutkus and Y. Mitsufuji},
230 |   title={Open-Unmix - A Reference Implementation for Music Source Separation},
231 |   journal={Journal of Open Source Software},
232 |   year=2019,
233 |   doi = {10.21105/joss.01667},
234 |   url = {https://doi.org/10.21105/joss.01667}
235 | }
236 | ```
237 | 
238 | 

239 |
240 | 241 |
If you use the MUSDB dataset for your research - Cite the MUSDB18 Dataset 242 |

243 | 
244 | ```latex
245 | @misc{MUSDB18,
246 |   author = {Rafii, Zafar and
247 |             Liutkus, Antoine and
248 |             Fabian-Robert St{\"o}ter and
249 |             Mimilakis, Stylianos Ioannis and
250 |             Bittner, Rachel},
251 |   title = {The {MUSDB18} corpus for music separation},
252 |   month = dec,
253 |   year = 2017,
254 |   doi = {10.5281/zenodo.1117372},
255 |   url = {https://doi.org/10.5281/zenodo.1117372}
256 | }
257 | ```
258 | 
259 | 

260 |
261 | 262 | 263 |
If you compare your results with SiSEC 2018 participants - Cite the SiSEC 2018 LVA/ICA Paper

265 | 
266 | ```latex
267 | @inproceedings{SiSEC18,
268 |   author="St{\"o}ter, Fabian-Robert and Liutkus, Antoine and Ito, Nobutaka",
269 |   title="The 2018 Signal Separation Evaluation Campaign",
270 |   booktitle="Latent Variable Analysis and Signal Separation:
271 |              14th International Conference, LVA/ICA 2018, Surrey, UK",
272 |   year="2018",
273 |   pages="293--305"
274 | }
275 | ```
276 | 
277 | 

278 |
279 | 
280 | ⚠️ Please note that the official acronym for _open-unmix_ is **UMX**.
281 | 
282 | ### License
283 | 
284 | MIT
285 | 
286 | ### Acknowledgements
287 | 
288 | 

289 | 290 | anr 291 |

292 | 
--------------------------------------------------------------------------------
/docs/predict.html:
--------------------------------------------------------------------------------
  7 | openunmix.predict API documentation
Module openunmix.predict
from openunmix import utils
 32 | 
 33 | 
 34 | def separate(
 35 |     audio,
 36 |     rate=None,
 37 |     model_str_or_path="umxhq",
 38 |     targets=None,
 39 |     niter=1,
 40 |     residual=False,
 41 |     wiener_win_len=300,
 42 |     aggregate_dict=None,
 43 |     separator=None,
 44 |     device=None,
 45 |     filterbank="torch",
 46 | ):
 47 |     """
 48 |     Open Unmix functional interface
 49 | 
 50 |     Separates a torch.Tensor or the content of an audio file.
 51 | 
 52 |     If a separator is provided, use it for inference. If not, create one
 53 |     and use it afterwards.
 54 | 
 55 |     Args:
 56 |         audio: audio to process
 57 |             torch Tensor: shape (channels, length), and
 58 |             `rate` must also be provided.
 59 |         rate: int or None: only used if audio is a Tensor. Otherwise,
 60 |             inferred from the file.
 61 |         model_str_or_path: the pretrained model to use
 62 |         targets (str): select the targets for the source to be separated.
 63 |             a list including: ['vocals', 'drums', 'bass', 'other'].
 64 |             If you don't pick them all, you probably want to
 65 |             activate the `residual=True` option.
 66 |             Defaults to all available targets per model.
 67 |         niter (int): the number of post-processing iterations, defaults to 1
 68 |         residual (bool): if True, a "garbage" target is created
 69 |         wiener_win_len (int): the number of frames to use when batching
 70 |             the post-processing step
 71 |         aggregate_dict (str): if provided, must be a string containing a '
 72 |             'valid expression for a dictionary, with keys as output '
 73 |             'target names, and values a list of targets that are used to '
 74 |             'build it. For instance: \'{\"vocals\":[\"vocals\"], '
 75 |             '\"accompaniment\":[\"drums\",\"bass\",\"other\"]}\'
 76 |         separator: if provided, the model.Separator object that will be used
 77 |              to perform separation
 78 |         device (str): selects device to be used for inference
 79 |         filterbank (str): filterbank implementation method.
 80 |             Supported are `['torch', 'asteroid']`. `torch` is about 30% faster
 81 |             compared to `asteroid` on large FFT sizes such as 4096. However,
 82 |             asteroid's STFT can be exported to ONNX, which makes it practical
 83 |             for deployment.
 84 |     """
 85 |     if separator is None:
 86 |         separator = utils.load_separator(
 87 |             model_str_or_path=model_str_or_path,
 88 |             targets=targets,
 89 |             niter=niter,
 90 |             residual=residual,
 91 |             wiener_win_len=wiener_win_len,
 92 |             device=device,
 93 |             pretrained=True,
 94 |             filterbank=filterbank,
 95 |         )
 96 |         separator.freeze()
 97 |         if device:
 98 |             separator.to(device)
 99 | 
100 |     if rate is None:
101 |         raise Exception("`rate` must be provided.")
102 | 
103 |     if device:
104 |         audio = audio.to(device)
105 |     audio = utils.preprocess(audio, rate, separator.sample_rate)
106 | 
107 |     # getting the separated signals
108 |     estimates = separator(audio)
109 |     estimates = separator.to_dict(estimates, aggregate_dict=aggregate_dict)
110 |     return estimates
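For illustration, here is a minimal usage sketch of the function above (the input file name and the model choice are placeholders):

```python
import torchaudio

from openunmix.predict import separate

# placeholder input file; any (channels, length) float tensor works
audio, rate = torchaudio.load("mix.wav")

# `rate` is required because the audio is passed as a tensor
estimates = separate(audio, rate=rate, model_str_or_path="umxl", device="cpu")

for target, waveform in estimates.items():
    print(target, tuple(waveform.shape))
```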
Functions

def separate(audio, rate=None, model_str_or_path='umxhq', targets=None, niter=1, residual=False, wiener_win_len=300, aggregate_dict=None, separator=None, device=None, filterbank='torch')

--------------------------------------------------------------------------------