├── .nojekyll
├── tests
│   ├── __init__.py
│   ├── data
│   │   ├── test.wav
│   │   ├── Al James - Schoolboy Facination.spectrogram.pt
│   │   └── Al James - Schoolboy Facination.json
│   ├── cli_test.sh
│   ├── test_utils.py
│   ├── create_dummy_datasets.sh
│   ├── test_augmentations.py
│   ├── create_vectors.py
│   ├── test_io.py
│   ├── test_jit.py
│   ├── test_transforms.py
│   ├── test_model.py
│   ├── test_datasets.py
│   ├── test_wiener.py
│   └── test_regression.py
├── codecov.yml
├── .github
│   ├── workflows
│   │   ├── test_black.yml
│   │   ├── test_cli.yml
│   │   └── test_unittests.yml
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── ISSUE_TEMPLATE
│       ├── improved-model.md
│       └── bug-report.md
├── Dockerfile
├── scripts
│   ├── environment-cpu-osx.yml
│   ├── environment-cpu-linux.yml
│   ├── environment-gpu-linux-cuda10.yml
│   ├── requirements.txt
│   ├── train.py
│   └── README.md
├── .flake8
├── hubconf.py
├── LICENSE
├── .gitignore
├── pyproject.toml
├── docs
│   ├── faq.md
│   ├── inference.md
│   ├── extensions.md
│   ├── evaluate.html
│   ├── training.md
│   └── predict.html
├── pdoc
│   └── config.mako
├── openunmix
│   ├── predict.py
│   ├── evaluate.py
│   ├── cli.py
│   ├── transforms.py
│   ├── utils.py
│   ├── model.py
│   └── __init__.py
├── CONTRIBUTING.md
└── README.md
/.nojekyll:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | codecov:
2 | require_ci_to_pass: no
3 |
--------------------------------------------------------------------------------
/tests/data/test.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sigsep/open-unmix-pytorch/HEAD/tests/data/test.wav
--------------------------------------------------------------------------------
/tests/data/Al James - Schoolboy Facination.spectrogram.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sigsep/open-unmix-pytorch/HEAD/tests/data/Al James - Schoolboy Facination.spectrogram.pt
--------------------------------------------------------------------------------
/tests/cli_test.sh:
--------------------------------------------------------------------------------
1 | # run umx on url
2 | umx https://samples.ffmpeg.org/A-codecs/wavpcm/test-96.wav --audio-backend stempeg
3 | umx https://samples.ffmpeg.org/A-codecs/wavpcm/test-96.wav --audio-backend stempeg --outdir out --niter 0
4 |
--------------------------------------------------------------------------------
/.github/workflows/test_black.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 | on: [pull_request] # yamllint disable-line rule:truthy
3 | jobs:
4 | lint:
5 | runs-on: ubuntu-latest
6 | steps:
7 | - uses: actions/checkout@v4
8 | - uses: psf/black@stable
9 | with:
10 | options: "--check --verbose"
11 | version: "~= 24.4.0"
12 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime
2 |
3 | RUN apt-get update && apt-get install -y --no-install-recommends \
4 | libsox-fmt-all \
5 | sox \
6 | libsox-dev
7 |
8 | WORKDIR /workspace
9 |
10 | RUN conda install ffmpeg -c conda-forge
11 | RUN pip install "musdb>=0.4.0"
12 |
13 | RUN pip install "openunmix[stempeg]"
14 |
15 | ENTRYPOINT ["umx"]
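16 |
17 | # Example usage (a sketch; the image tag and host paths are illustrative):
18 | #   docker build -t umx .
19 | #   docker run --rm -v "$(pwd)":/workspace umx /workspace/mixture.wav --outdir /workspace/out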
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
8 |
9 | #### What does this implement/fix? Explain your changes.
10 |
11 |
12 | #### Any other comments?
13 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | from openunmix.utils import AverageMeter, EarlyStopping
2 |
3 |
4 | def test_average_meter():
5 | losses = AverageMeter()
6 | losses.update(1.0)
7 | losses.update(3.0)
8 | assert losses.avg == 2.0
9 |
10 |
11 | def test_early_stopping():
12 | es = EarlyStopping(patience=2)
13 | es.step(1.0)
14 |
15 | assert not es.step(0.5)
16 | assert not es.step(0.6)
17 | assert es.step(0.7)
18 |
--------------------------------------------------------------------------------
/scripts/environment-cpu-osx.yml:
--------------------------------------------------------------------------------
1 | name: umx-osx
2 |
3 | channels:
4 | - conda-forge
5 | - pytorch
6 |
7 | dependencies:
8 | - python=3.7
9 | - numpy=1.18
10 | - scipy=1.4
11 | - pytorch==1.9.0
12 | - torchaudio==0.9.0
13 | - tqdm
14 | - scikit-learn=0.22
15 | - ffmpeg
16 | - libsndfile
17 | - pip
18 | - pip:
19 | - musdb>=0.4.0
20 | - museval>=0.4.0
21 | - asteroid-filterbanks>=0.3.2
22 | - gitpython
23 |
24 |
--------------------------------------------------------------------------------
/scripts/environment-cpu-linux.yml:
--------------------------------------------------------------------------------
1 | name: umx-cpu
2 |
3 | channels:
4 | - conda-forge
5 | - pytorch
6 |
7 | dependencies:
8 | - python=3.7
9 | - numpy=1.18
10 | - scipy=1.4
11 | - pytorch==1.9.0
12 | - torchaudio==0.9.0
13 | - cpuonly
14 | - tqdm
15 | - scikit-learn=0.22
16 | - ffmpeg
17 | - libsndfile
18 | - pip
19 | - pip:
20 | - musdb>=0.4.0
21 | - museval>=0.4.0
22 | - asteroid-filterbanks>=0.3.2
23 | - gitpython
24 |
--------------------------------------------------------------------------------
/scripts/environment-gpu-linux-cuda10.yml:
--------------------------------------------------------------------------------
1 | name: umx-gpu
2 |
3 | channels:
4 | - conda-forge
5 | - pytorch
6 | - nvidia
7 |
8 | dependencies:
9 | - python=3.7
10 | - numpy=1.18
11 | - scipy=1.4
12 | - pytorch==1.9.0
13 | - torchaudio==0.9.0
14 | - cudatoolkit=11.1
15 | - scikit-learn=0.22
16 | - tqdm
17 | - libsndfile
18 | - ffmpeg
19 | - pip
20 | - pip:
21 | - musdb>=0.4.0
22 | - museval>=0.4.0
23 | - asteroid-filterbanks>=0.3.2
24 | - gitpython
25 |
26 |
--------------------------------------------------------------------------------
/tests/create_dummy_datasets.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Set the number of files to generate
4 | NBFILES=4
5 | BASEDIR=TrackfolderDataset
6 | subsets=(
7 | train
8 | valid
9 | )
10 | for subset in "${subsets[@]}"; do
11 | for k in $(seq 1 4); do
12 | path=$BASEDIR/$subset/$k
13 | mkdir -p $path
14 | for i in $(seq 1 $NBFILES); do
15 | sox -n -r 8000 -b 16 $path/$i.wav synth "0:3" whitenoise vol 0.5 fade q 1 "0:3" 1
16 | done
17 | done
18 | done
19 |
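20 | # Resulting layout (for reference), which tests/test_datasets.py expects:
21 | #   TrackfolderDataset/{train,valid}/{1..4}/{1..4}.wav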
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 119
3 | exclude = docs/source,*.egg,build
4 | select = E,W,F
5 | verbose = 2
6 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes
7 | format = pylint
8 | ignore =
9 | E731 # E731 - Do not assign a lambda expression, use a def
10 | W605 # W605 - invalid escape sequence '\_'. Needed for docs
11 | W504 # W504 - line break after binary operator
12 | W503 # W503 - line break before binary operator, need for black
13 | E203 # E203 - whitespace before ':'. Opposite convention enforced by black
--------------------------------------------------------------------------------
/scripts/requirements.txt:
--------------------------------------------------------------------------------
1 | attrs==21.2.0
2 | cffi==1.14.6
3 | ffmpeg-python==0.2.0
4 | future==0.18.3
5 | gitdb==4.0.7
6 | GitPython==3.1.41
7 | joblib==1.2.0
8 | jsonschema==3.2.0
9 | musdb==0.3.1
10 | museval==0.3.1
11 | numpy==1.22.0
12 | pandas==1.3.3
13 | pyaml==21.8.3
14 | pycparser==2.20
15 | pyrsistent==0.18.0
16 | python-dateutil==2.8.2
17 | pytz==2021.1
18 | PyYAML==5.4.1
19 | scikit-learn==0.22
20 | scipy==1.10.0
21 | simplejson==3.17.5
22 | six==1.16.0
23 | smmap==4.0.0
24 | SoundFile==0.10.3.post1
25 | stempeg==0.2.3
26 | tqdm==4.62.2
27 | typing-extensions==3.10.0.2
28 |
--------------------------------------------------------------------------------
/hubconf.py:
--------------------------------------------------------------------------------
1 | # This file is to be parsed by torch.hub mechanics
2 | #
3 | # `xxx_spec` take spectrogram inputs and output separated spectrograms
4 | # `xxx` take waveform inputs and output separated waveforms
5 |
6 | # Optional list of dependencies required by the package
7 | dependencies = ["torch", "numpy"]
8 |
9 | from openunmix import umxse_spec
10 | from openunmix import umxse
11 |
12 | from openunmix import umxhq_spec
13 | from openunmix import umxhq
14 |
15 | from openunmix import umx_spec
16 | from openunmix import umx
17 |
18 | from openunmix import umxl_spec
19 | from openunmix import umxl
20 |
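21 | # Example (a sketch): the entry points above can be loaded via torch.hub, e.g.
22 | #   separator = torch.hub.load("sigsep/open-unmix-pytorch", "umxl")
23 | #   estimates = separator(mixture)  # mixture shape: (nb_samples, nb_channels, nb_timesteps)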
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/improved-model.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: "\U0001F680 Improved Model"
3 | about: Submit a proposal for an improved separation model
4 |
5 | ---
6 |
7 | ## 🚀 Model Improvement
8 |
12 |
13 | ## Motivation
14 |
15 |
16 |
17 | ## Objective Evaluation
18 |
19 |
--------------------------------------------------------------------------------
/tests/test_augmentations.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from openunmix import data
5 |
6 |
7 | @pytest.fixture(params=[4096])
8 | def nb_timesteps(request):
9 | return int(request.param)
10 |
11 |
12 | @pytest.fixture(params=[1, 2])
13 | def nb_channels(request):
14 | return request.param
15 |
16 |
17 | @pytest.fixture
18 | def audio(request, nb_channels, nb_timesteps):
19 | return torch.rand((nb_channels, nb_timesteps))
20 |
21 |
22 | def test_gain(audio):
23 | out = data._augment_gain(audio)
24 | assert out.shape == audio.shape
25 |
26 |
27 | def test_channelswap(audio):
28 | out = data._augment_channelswap(audio)
29 | assert out.shape == audio.shape
30 |
31 |
32 | def test_forcestereo(audio, nb_channels):
33 | out = data._augment_force_stereo(audio)
34 | assert out.shape[-1] == audio.shape[-1]
35 | assert out.shape[0] == 2
36 |
--------------------------------------------------------------------------------
/tests/create_vectors.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import musdb
4 | import numpy as np
5 | from openunmix import model, utils
6 |
7 | """script to create spectrogram test vectors for STFT regression tests
8 |
9 | Test vectors have been created using the `v1.0.0` release tag as this
10 | was the commit that umx was trained with
11 | """
12 |
13 |
14 | def main():
15 | test_track = "Al James - Schoolboy Facination"
16 | mus = musdb.DB(download=True)
17 |
18 | # load audio track
19 | track = [track for track in mus.tracks if track.name == test_track][0]
20 |
21 | # convert to torch tensor
22 | audio = torch.tensor(track.audio.T, dtype=torch.float32)
23 |
24 | stft = model.STFT(n_fft=4096, n_hop=1024)
25 | spec = model.Spectrogram(power=1, mono=False)
26 | magnitude_spectrogram = spec(stft(audio[None, ...]))
27 | torch.save(magnitude_spectrogram, "Al James - Schoolboy Facination.spectrogram.pt")
28 |
29 |
30 | if __name__ == "__main__":
31 | main()
32 |
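33 | # Usage (a sketch): run from the repository root with `python tests/create_vectors.py`.
34 | # The resulting `.pt` file corresponds to the reference stored under `tests/data/`
35 | # that is loaded by `tests/test_regression.py`.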
--------------------------------------------------------------------------------
/tests/test_io.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | import os
4 | import torchaudio
5 |
6 | from openunmix import data
7 |
8 |
9 | audio_path = os.path.join(
10 | os.path.dirname(os.path.realpath(__file__)),
11 | "data/test.wav",
12 | )
13 |
14 |
15 | @pytest.fixture(params=["soundfile", "sox_io"])
16 | def torch_backend(request):
17 | return request.param
18 |
19 |
20 | @pytest.fixture(params=[1.0, 2.0, None])
21 | def dur(request):
22 | return request.param
23 |
24 |
25 | @pytest.fixture(params=[True, False])
26 | def info(request, torch_backend):
27 | torchaudio.set_audio_backend(torch_backend)
28 |
29 | if request.param:
30 | return data.load_info(audio_path)
31 | else:
32 | return None
33 |
34 |
35 | def test_loadwav(dur, info, torch_backend):
36 | torchaudio.set_audio_backend(torch_backend)
37 | audio, _ = data.load_audio(audio_path, dur=dur, info=info)
38 | rate = 8000.0
39 | if dur:
40 | assert audio.shape[-1] == int(dur * rate)
41 | else:
42 | assert audio.shape[-1] == rate * 3
43 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Inria (Fabian-Robert Stöter, Antoine Liutkus)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/tests/test_jit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.onnx
3 | import pytest
4 |
5 | from openunmix import model
6 |
7 |
8 | @pytest.mark.skip(reason="Currently not supported")
9 | def test_onnx():
10 | """Test ONNX export of the separator
11 |
12 | currently results in errors, blocked by
13 | https://github.com/pytorch/pytorch/issues/49958
14 | """
15 | nb_samples = 1
16 | nb_channels = 2
17 | nb_timesteps = 11111
18 |
19 | example = torch.rand((nb_samples, nb_channels, nb_timesteps), device="cpu")
20 | # set model to eval due to non-deterministic behaviour of dropout
21 | umx = model.OpenUnmix(nb_bins=2049, nb_channels=2).eval().to("cpu")
22 |
23 | # create separator
24 | separator = (
25 | model.Separator(target_models={"source_1": umx, "source_2": umx}, niter=1, filterbank="asteroid")
26 | .eval()
27 | .to("cpu")
28 | )
29 |
30 | torch_out = separator(example)
31 |
32 | # Export the model
33 | torch.onnx.export(
34 | separator,
35 | example,
36 | "umx.onnx",
37 | export_params=True,
38 | opset_version=10,
39 | do_constant_folding=True,
40 | input_names=["input"],
41 | output_names=["output"],
42 | dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
43 | )
44 |
--------------------------------------------------------------------------------
/tests/test_transforms.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | import torch
4 | from openunmix import transforms
5 |
6 |
7 | @pytest.fixture(params=[4096])
8 | def nb_timesteps(request):
9 | return int(request.param)
10 |
11 |
12 | @pytest.fixture(params=[2])
13 | def nb_channels(request):
14 | return request.param
15 |
16 |
17 | @pytest.fixture(params=[2])
18 | def nb_samples(request):
19 | return request.param
20 |
21 |
22 | @pytest.fixture(params=[2048])
23 | def nfft(request):
24 | return int(request.param)
25 |
26 |
27 | @pytest.fixture(params=[2])
28 | def hop(request, nfft):
29 | return nfft // request.param
30 |
31 |
32 | @pytest.fixture(params=["torch", "asteroid"])
33 | def method(request):
34 | return request.param
35 |
36 |
37 | @pytest.fixture
38 | def audio(request, nb_samples, nb_channels, nb_timesteps):
39 | return torch.rand((nb_samples, nb_channels, nb_timesteps))
40 |
41 |
42 | def test_stft(audio, nfft, hop, method):
43 | # we should only test for center=True as
44 | # False doesn't pass COLA
45 | # https://github.com/pytorch/audio/issues/500
46 | stft, istft = transforms.make_filterbanks(n_fft=nfft, n_hop=hop, center=True, method=method)
47 |
48 | X = stft(audio)
49 | X = X.detach()
50 | out = istft(X, length=audio.shape[-1])
51 | assert np.sqrt(np.mean((audio.detach().numpy() - out.detach().numpy()) ** 2)) < 1e-6
52 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: "\U0001F41B Bug Report"
3 | about: Submit a bug report to help to improve Open-Unmix
4 |
5 | ---
6 |
7 | ## 🐛 Bug
8 |
9 |
10 |
11 | ## To Reproduce
12 |
13 | Steps to reproduce the behavior:
14 |
15 | 1.
16 | 1.
17 | 1.
18 |
19 |
20 |
21 | ## Expected behavior
22 |
23 |
24 |
25 | ## Environment
26 |
27 | Please add some information about your environment
28 |
29 | - PyTorch Version (e.g., 1.2):
30 | - OS (e.g., Linux):
31 | - torchaudio loader (y/n):
32 | - Python version:
33 | - CUDA/cuDNN version:
34 | - Any other relevant information:
35 |
36 | If unsure you can paste the output from the [pytorch environment collection script](https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py)
37 | (or fill out the checklist below manually).
38 |
39 | You can get that script and run it with:
40 | ```
41 | wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
42 | # For security purposes, please check the contents of collect_env.py before running it.
43 | python collect_env.py
44 | ```
45 |
46 | ## Additional context
47 |
48 |
49 |
--------------------------------------------------------------------------------
/.github/workflows/test_cli.yml:
--------------------------------------------------------------------------------
1 | name: UMX
2 | # thanks to @mpariente; this workflow was copied from his setup
3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
4 | # Trigger the workflow on push or pull request
5 | on: # yamllint disable-line rule:truthy
6 | pull_request:
7 | types: [opened, synchronize, reopened, ready_for_review]
8 |
9 | jobs:
10 | src-test:
11 | name: separation test
12 | runs-on: ubuntu-latest
13 | strategy:
14 | matrix:
15 | python-version: [3.9]
16 |
17 | # Timeout: https://stackoverflow.com/a/59076067/4521646
18 | timeout-minutes: 20
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install libsndfile, ffmpeg and sox
26 | run: |
27 | sudo apt update
28 | sudo apt install libsndfile1-dev libsndfile1 ffmpeg sox
29 | - name: Install package dependencies
30 | run: |
31 | python -m pip install --upgrade --user pip --quiet
32 | python -m pip install .["stempeg"]
33 | python --version
34 | pip --version
35 | python -m pip list
36 |
37 | - name: CLI tests
38 | run: |
39 | umx https://samples.ffmpeg.org/A-codecs/wavpcm/test-96.wav --audio-backend stempeg
40 |
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from openunmix import model
5 | from openunmix import umxse
6 | from openunmix import umxhq
7 | from openunmix import umx
8 | from openunmix import umxl
9 |
10 |
11 | @pytest.fixture(params=[100])
12 | def nb_frames(request):
13 | return int(request.param)
14 |
15 |
16 | @pytest.fixture(params=[1, 2])
17 | def nb_channels(request):
18 | return request.param
19 |
20 |
21 | @pytest.fixture(params=[2])
22 | def nb_samples(request):
23 | return request.param
24 |
25 |
26 | @pytest.fixture(params=[1024])
27 | def nb_bins(request):
28 | return request.param
29 |
30 |
31 | @pytest.fixture
32 | def spectrogram(request, nb_samples, nb_channels, nb_bins, nb_frames):
33 | return torch.rand((nb_samples, nb_channels, nb_bins, nb_frames))
34 |
35 |
36 | @pytest.fixture(params=[False])
37 | def unidirectional(request):
38 | return request.param
39 |
40 |
41 | @pytest.fixture(params=[32])
42 | def hidden_size(request):
43 | return request.param
44 |
45 |
46 | def test_shape(spectrogram, nb_bins, nb_channels, unidirectional, hidden_size):
47 | unmix = model.OpenUnmix(
48 | nb_bins=nb_bins,
49 | nb_channels=nb_channels,
50 | unidirectional=unidirectional,
51 | nb_layers=1, # speed up training
52 | hidden_size=hidden_size,
53 | )
54 | unmix.eval()
55 | Y = unmix(spectrogram)
56 | assert spectrogram.shape == Y.shape
57 |
58 |
59 | @pytest.mark.parametrize("model_fn", [umx, umxhq, umxse, umxl])
60 | def test_model_loading(model_fn):
61 | X = torch.rand((1, 2, 4096))
62 | model = model_fn(niter=0, pretrained=True)
63 | Y = model(X)
64 | assert Y[:, 0, ...].shape == X.shape
65 |
--------------------------------------------------------------------------------
/.github/workflows/test_unittests.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | # thanks to @mpariente; this workflow was copied from his setup
3 | # see: https://help.github.com/en/actions/reference/events-that-trigger-workflows
4 | # Trigger the workflow on push or pull request
5 | on: # yamllint disable-line rule:truthy
6 | pull_request:
7 | types: [opened, synchronize, reopened, ready_for_review]
8 |
9 | jobs:
10 | src-test:
11 | name: unit-tests
12 | runs-on: ubuntu-latest
13 | strategy:
14 | matrix:
15 | python-version: [3.9]
16 |
17 | # Timeout: https://stackoverflow.com/a/59076067/4521646
18 | timeout-minutes: 20
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install libsndfile, ffmpeg and sox
26 | run: |
27 | sudo apt update
28 | sudo apt install libsndfile1-dev libsndfile1 ffmpeg sox
29 | - name: Install python dependencies
30 | run: |
31 | python -m pip install --upgrade --user pip --quiet
32 | python -m pip install numpy Cython --upgrade-strategy only-if-needed --quiet
33 | python -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cpu
34 | python -m pip install -e .['tests']
35 | python --version
36 | pip --version
37 | python -m pip list
38 | - name: Create dummy dataset
39 | run: |
40 | chmod +x tests/create_dummy_datasets.sh
41 | ./tests/create_dummy_datasets.sh
42 | shell: bash
43 |
44 | - name: Run tests
45 | env:
46 | PY_COLORS: "1"
47 | run: |
48 | pytest tests -vv
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | #### joe made this: http://goel.io/joe
2 |
3 | data/
4 | OSU/
5 | .mypy_cache
6 | .vscode
7 | *.json
8 | *.wav
9 | *.mp3
10 | *.pth.tar
11 | env*/
12 |
13 | #####=== Python ===#####
14 |
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 | *$py.class
19 |
20 | # C extensions
21 | *.so
22 |
23 | # Distribution / packaging
24 | .Python
25 | env/
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | *.egg-info/
38 | .installed.cfg
39 | *.egg
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *,cover
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | #####=== OSX ===#####
75 | .DS_Store
76 | .AppleDouble
77 | .LSOverride
78 |
79 | # Icon must end with two \r
80 | Icon
81 |
82 |
83 | # Thumbnails
84 | ._*
85 |
86 | # Files that might appear in the root of a volume
87 | .DocumentRevisions-V100
88 | .fseventsd
89 | .Spotlight-V100
90 | .TemporaryItems
91 | .Trashes
92 | .VolumeIcon.icns
93 |
94 | # Directories potentially created on remote AFP share
95 | .AppleDB
96 | .AppleDesktop
97 | Network Trash Folder
98 | Temporary Items
99 | .apdisk
100 |
--------------------------------------------------------------------------------
/tests/test_datasets.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | import torchaudio
4 |
5 | from openunmix import data
6 |
7 |
8 | @pytest.fixture(params=["soundfile", "sox_io"])
9 | def torch_backend(request):
10 | return request.param
11 |
12 |
13 | def test_musdb():
14 | musdb = data.MUSDBDataset(download=True, samples_per_track=1, seq_duration=1.0)
15 | for x, y in musdb:
16 | assert x.shape[-1] == 44100
17 |
18 |
19 | def test_trackfolder_fix(torch_backend):
20 | torchaudio.set_audio_backend(torch_backend)
21 |
22 | train_dataset = data.FixedSourcesTrackFolderDataset(
23 | split="train",
24 | seq_duration=1.0,
25 | root="TrackfolderDataset",
26 | sample_rate=8000.0,
27 | target_file="1.wav",
28 | interferer_files=["2.wav", "3.wav", "4.wav"],
29 | )
30 | for x, y in train_dataset:
31 | assert x.shape[-1] == 8000
32 |
33 |
34 | def test_trackfolder_var(torch_backend):
35 | torchaudio.set_audio_backend(torch_backend)
36 |
37 | train_dataset = data.VariableSourcesTrackFolderDataset(
38 | split="train",
39 | seq_duration=1.0,
40 | root="TrackfolderDataset",
41 | sample_rate=8000.0,
42 | target_file="1.wav",
43 | )
44 | for x, y in train_dataset:
45 | assert x.shape[-1] == 8000
46 |
47 |
48 | def test_sourcefolder(torch_backend):
49 | torchaudio.set_audio_backend(torch_backend)
50 |
51 | train_dataset = data.SourceFolderDataset(
52 | split="train",
53 | seq_duration=1.0,
54 | root="TrackfolderDataset",
55 | sample_rate=8000.0,
56 | target_dir="1",
57 | interferer_dirs=["2", "3"],
58 | ext=".wav",
59 | nb_samples=20,
60 | )
61 | for k in range(len(train_dataset)):
62 | x, y = train_dataset[k]
63 | assert x.shape[-1] == 8000
64 |
--------------------------------------------------------------------------------
/tests/test_wiener.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 |
5 | from openunmix import model
6 | from openunmix.filtering import wiener
7 |
8 |
9 | @pytest.fixture(params=[10, 100])
10 | def nb_frames(request):
11 | return int(request.param)
12 |
13 |
14 | @pytest.fixture(params=[1, 2])
15 | def nb_channels(request):
16 | return request.param
17 |
18 |
19 | @pytest.fixture(params=[10, 127])
20 | def nb_bins(request):
21 | return request.param
22 |
23 |
24 | @pytest.fixture(params=[1, 2, 3])
25 | def nb_sources(request):
26 | return request.param
27 |
28 |
29 | @pytest.fixture(params=[0, 1, 2])
30 | def iterations(request):
31 | return request.param
32 |
33 |
34 | @pytest.fixture(params=[True, False])
35 | def softmask(request):
36 | return request.param
37 |
38 |
39 | @pytest.fixture(params=[True, False])
40 | def residual(request):
41 | return request.param
42 |
43 |
44 | @pytest.fixture
45 | def target(request, nb_frames, nb_channels, nb_bins, nb_sources):
46 | return torch.rand((nb_frames, nb_bins, nb_channels, nb_sources))
47 |
48 |
49 | @pytest.fixture
50 | def mix(request, nb_frames, nb_channels, nb_bins):
51 | return torch.rand((nb_frames, nb_bins, nb_channels, 2))
52 |
53 |
54 | @pytest.fixture(params=[torch.float32, torch.float64])
55 | def dtype(request):
56 | return request.param
57 |
58 |
59 | def test_wiener(target, mix, iterations, softmask, residual):
60 | output = wiener(target, mix, iterations=iterations, softmask=softmask, residual=residual)
61 | # nb_frames, nb_bins, nb_channels, 2, nb_sources
62 | assert output.shape[:3] == mix.shape[:3]
63 | assert output.shape[3] == 2
64 | if residual:
65 | assert output.shape[4] == target.shape[3] + 1
66 | else:
67 | assert output.shape[4] == target.shape[3]
68 |
69 |
70 | def test_dtype(target, mix, dtype):
71 | output = wiener(target.to(dtype=dtype), mix.to(dtype=dtype), iterations=1)
72 | assert output.dtype == dtype
73 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "openunmix"
7 | authors = [
8 | {name = "Fabian-Robert Stöter", email = "mail@faroit.com"},
9 | {name = "Antoine Liutkus", email = "antoine.liutkus@inria.fr"},
10 | ]
11 | version = "1.3.0"
12 | description = "PyTorch-based music source separation toolkit"
13 | readme = "README.md"
14 | license = { text = "MIT" }
15 | requires-python = ">=3.9"
16 | classifiers = [
17 | "Development Status :: 5 - Production/Stable",
18 | "Environment :: Console",
19 | "Intended Audience :: Developers",
20 | "License :: OSI Approved :: MIT License",
21 | "Operating System :: OS Independent",
22 | "Programming Language :: Python",
23 | "Programming Language :: Python :: 3 :: Only",
24 | "Programming Language :: Python :: 3.8",
25 | "Programming Language :: Python :: 3.9",
26 | "Programming Language :: Python :: 3.10",
27 | "Programming Language :: Python :: 3.11",
28 | "Programming Language :: Python :: 3.12",
29 | "Topic :: Software Development :: Libraries :: Python Modules",
30 | "Topic :: Software Development :: Quality Assurance",
31 | ]
32 | dependencies = [
33 | "numpy",
34 | "torchaudio>=0.9.0",
35 | "torch>=1.9.0",
36 | "tqdm",
37 | ]
38 |
39 | [project.optional-dependencies]
40 | asteroid = ["asteroid-filterbanks>=0.3.2"]
41 | stempeg = ["stempeg"]
42 | evaluation = ["musdb>=0.4.0", "museval>=0.4.0"]
43 | tests = [
44 | "pytest",
45 | "musdb>=0.4.0",
46 | "museval>=0.4.0",
47 | "stempeg",
48 | "asteroid-filterbanks>=0.3.2",
49 | "onnx",
50 | "tqdm",
51 | ]
52 |
53 | [project.scripts]
54 | umx = "openunmix.cli:separate"
55 |
56 | [project.urls]
57 | Homepage = "https://github.com/sigsep/open-unmix-pytorch"
58 |
59 | [tool.black]
60 | line-length = 120
61 | target-version = ['py39']
62 | include = '\.pyi?$'
63 | exclude = '''
64 | (
65 | /(
66 | \.git
67 | | \.hg
68 | | \.mypy_cache
69 | | \.tox
70 | | \.venv
71 | | _build
72 | | buck-out
73 | | build
74 | | dist
75 | | \.idea
76 | | \.vscode
77 | | scripts
78 | | notebooks
79 | | \.eggs
80 | )/
81 | )
82 | '''
83 |
84 | [tool.setuptools.packages.find]
85 | include = ["openunmix"]
86 |
87 | [tool.setuptools.package-data]
88 | openunmix = ["*.txt", "*.rst", "*.json", "*.wav", "*.pt"]
--------------------------------------------------------------------------------
/tests/test_regression.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | import musdb
4 | import simplejson as json
5 | import numpy as np
6 | import torch
7 |
8 |
9 | from openunmix import model
10 | from openunmix import evaluate
11 | from openunmix import utils
12 | from openunmix import transforms
13 |
14 |
15 | test_track = "Al James - Schoolboy Facination"
16 |
17 | json_path = os.path.join(
18 | os.path.dirname(os.path.realpath(__file__)),
19 | "data/%s.json" % test_track,
20 | )
21 |
22 | spec_path = os.path.join(
23 | os.path.dirname(os.path.realpath(__file__)),
24 | "data/%s.spectrogram.pt" % test_track,
25 | )
26 |
27 |
28 | @pytest.fixture(params=["torch", "asteroid"])
29 | def method(request):
30 | return request.param
31 |
32 |
33 | @pytest.fixture()
34 | def mus():
35 | return musdb.DB(download=True)
36 |
37 |
38 | def test_estimate_and_evaluate(mus):
39 | # return any number of targets
40 | with open(json_path) as json_file:
41 | ref = json.loads(json_file.read())
42 |
43 | track = [track for track in mus.tracks if track.name == test_track][0]
44 |
45 | scores = evaluate.separate_and_evaluate(
46 | track,
47 | targets=["vocals", "drums", "bass", "other"],
48 | model_str_or_path="umx",
49 | niter=1,
50 | residual=None,
51 | mus=mus,
52 | aggregate_dict=None,
53 | output_dir=None,
54 | eval_dir=None,
55 | device="cpu",
56 | wiener_win_len=None,
57 | )
58 |
59 | assert scores.validate() is None
60 |
61 | with open(os.path.join(".", track.name) + ".json", "w+") as f:
62 | f.write(scores.json)
63 |
64 | scores = json.loads(scores.json)
65 |
66 | for target in ref["targets"]:
67 | for metric in ["SDR", "SIR", "SAR", "ISR"]:
68 |
69 | ref = np.array([d["metrics"][metric] for d in target["frames"]])
70 | idx = [t["name"] for t in scores["targets"]].index(target["name"])
71 | est = np.array([d["metrics"][metric] for d in scores["targets"][idx]["frames"]])
72 |
73 | assert np.allclose(ref, est, atol=1e-01)
74 |
75 |
76 | def test_spectrogram(mus, method):
77 | """Regression test for spectrogram transform
78 |
79 | Loads a pre-computed transform and compares it to the current spectrogram.
80 | This makes sure that training remains reproducible if parameters
81 | such as STFT centering were to change.
82 | """
83 | track = [track for track in mus.tracks if track.name == test_track][0]
84 |
85 | stft, _ = transforms.make_filterbanks(n_fft=4096, n_hop=1024, sample_rate=track.rate, method=method)
86 | encoder = torch.nn.Sequential(stft, model.ComplexNorm(mono=False))
87 | audio = torch.as_tensor(track.audio, dtype=torch.float32, device="cpu")
88 | audio = utils.preprocess(audio, track.rate, track.rate)
89 | ref = torch.load(spec_path)
90 | dut = encoder(audio).permute(3, 0, 1, 2)
91 |
92 | assert torch.allclose(ref, dut, atol=1e-1)
93 |
--------------------------------------------------------------------------------
/docs/faq.md:
--------------------------------------------------------------------------------
1 | # Frequently Asked Questions
2 |
3 | ## Separating tracks crashes because it used too much memory
4 |
5 | First, separating an audio track into the four targets `vocals`, `drums`, `bass` and `other` requires a significant amount of RAM, since all four separation models have to be loaded.
6 | Furthermore, the post-processing step, controlled by the `--niter` parameter, is computationally expensive.
7 | For faster and less memory-intensive inference (at the expense of separation quality) it is advised to use `--niter 0`.
8 | Another way to improve performance is to apply separation to smaller excerpts using the `--start` and `--duration` arguments. We suggest separating only ~30 s stereo excerpts on machines with less than 8 GB of memory.
9 |
10 | ## Why is the training so slow?
11 |
12 | In the default configuration using the stems dataset, yielding a single batch from the dataset is very slow. This is a known issue of decoding mp4 stems since native decoders for pytorch or numpy are not available.
13 |
14 | There are two ways to speed up the training:
15 |
16 | ### 1. Increase the number of workers
17 |
18 | The default configuration does not use multiprocessing to yield the batches. You can increase the number of workers using the `--nb-workers k` argument. E.g. with `k=8` workers, batch loading can get down to about one batch per second.
19 |
20 | ### 2. Use WAV instead of MP4
21 |
22 | Convert the MUSDB18 dataset to wav using the builtin `musdb` cli tool
23 |
24 | ```
25 | musdbconvert path/to/musdb-stems-root path/to/new/musdb-wav-root
26 | ```
27 |
28 | or alternatively use the [MUSDB18-HQ](https://zenodo.org/record/3338373) dataset that is already stored and distributed as WAV files. Note that __if you want to compare to SiSEC 2018 participants, you should use the standard (Stems) MUSDB18 dataset and decode it to WAV, instead.__
29 |
30 | Training on wav files can be launched using the `--is-wav` flag:
31 |
32 | ```
33 | python scripts/train.py --root path/to/musdb18-wav --is-wav --target vocals
34 | ```
35 |
36 | This will get you down to 0.6 s per batch with 4 workers, likely hitting the bandwidth of standard hard drives. It can be further improved using an SSD, which brings it down to 0.4 s per batch on a GTX 1080 Ti and leads to 95% GPU utilization, so data loading is no longer the bottleneck.
37 |
38 | ## Can I use the pre-trained models without torchhub?
39 |
40 | The automatic torch.hub download might not work for some reason, or you may want to download the files once and use them offline. For that, you can download [umx](https://zenodo.org/record/3340804) or [umxhq](https://zenodo.org/record/3267291) from Zenodo and create a local folder of your choice (e.g. `umx-weights`) where the models are stored in a flat file hierarchy:
41 |
42 | ```
43 | umx-weights/vocals-*.pth
44 | umx-weights/drums-*.pth
45 | umx-weights/bass-*.pth
46 | umx-weights/other-*.pth
47 | umx-weights/vocals.json
48 | umx-weights/drums.json
49 | umx-weights/bass.json
50 | umx-weights/other.json
51 | umx-weights/separator.json
52 | ```
53 |
54 | Test and eval can then be started using:
55 |
56 | ```bash
57 | umx --model umx-weights --input test.wav
58 | ```
59 |
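60 | Alternatively, the local folder can be used from Python. Here is a minimal sketch (assuming the weights were downloaded into `umx-weights` as shown above):
61 |
62 | ```python
63 | from openunmix import utils
64 |
65 | # build a Separator from the locally stored target models
66 | separator = utils.load_separator(
67 |     model_str_or_path="umx-weights",
68 |     targets=["vocals", "drums", "bass", "other"],
69 |     niter=1,
70 |     pretrained=True,
71 | )
72 | ```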
--------------------------------------------------------------------------------
/pdoc/config.mako:
--------------------------------------------------------------------------------
1 | <%!
2 | # Template configuration. Copy over in your template directory
3 | # (used with `--template-dir`) and adapt as necessary.
4 | # Note, defaults are loaded from this distribution file, so your
5 | # config.mako only needs to contain values you want overridden.
6 | # You can also run pdoc with `--config KEY=VALUE` to override
7 | # individual values.
8 | html_lang = 'en'
9 | show_inherited_members = False
10 | extract_module_toc_into_sidebar = True
11 | list_class_variables_in_index = True
12 | sort_identifiers = True
13 | show_type_annotations = True
14 | # Show collapsed source code block next to each item.
15 | # Disabling this can improve rendering speed of large modules.
16 | show_source_code = True
17 | # If set, format links to objects in online source code repository
18 | # according to this template. Supported keywords for interpolation
19 | # are: commit, path, start_line, end_line.
20 | git_link_template = 'https://github.com/sigsep/open-unmix-pytorch/blob/{commit}/{path}#L{start_line}-L{end_line}'
21 | #git_link_template = 'https://gitlab.com/USER/PROJECT/blob/{commit}/{path}#L{start_line}-L{end_line}'
22 | #git_link_template = 'https://bitbucket.org/USER/PROJECT/src/{commit}/{path}#lines-{start_line}:{end_line}'
23 | #git_link_template = 'https://CGIT_HOSTNAME/PROJECT/tree/{path}?id={commit}#n{start-line}'
24 | # A prefix to use for every HTML hyperlink in the generated documentation.
25 | # No prefix results in all links being relative.
26 | link_prefix = ''
27 | # Enable syntax highlighting for code/source blocks by including Highlight.js
28 | syntax_highlighting = True
29 | # Set the style keyword such as 'atom-one-light' or 'github-gist'
30 | # Options: https://github.com/highlightjs/highlight.js/tree/master/src/styles
31 | # Demo: https://highlightjs.org/static/demo/
32 | hljs_style = 'github'
33 | # If set, insert Google Analytics tracking code. Value is GA
34 | # tracking id (UA-XXXXXX-Y).
35 | google_analytics = ''
36 | # If set, insert Google Custom Search search bar widget above the sidebar index.
37 | # The whitespace-separated tokens represent arbitrary extra queries (at least one
38 | # must match) passed to regular Google search. Example:
39 | #google_search_query = 'inurl:github.com/USER/PROJECT site:PROJECT.github.io site:PROJECT.website'
40 | google_search_query = ''
41 | # Enable offline search using Lunr.js. For explanation of 'fuzziness' parameter, which is
42 | # added to every query word, see: https://lunrjs.com/guides/searching.html#fuzzy-matches
43 | # If 'index_docstrings' is False, a shorter index is built, indexing only
44 | # the full object reference names.
45 | #lunr_search = {'fuzziness': 1, 'index_docstrings': True}
46 | lunr_search = None
47 | # If set, render LaTeX math syntax within \(...\) (inline equations),
48 | # or within \[...\] or $$...$$ or `.. math::` (block equations)
49 | # as nicely-formatted math formulas using MathJax.
50 | # Note: in Python docstrings, either all backslashes need to be escaped (\\)
51 | # or you need to use raw r-strings.
52 | latex_math = True
53 | %>
--------------------------------------------------------------------------------
/openunmix/predict.py:
--------------------------------------------------------------------------------
1 | from openunmix import utils
2 |
3 |
4 | def separate(
5 | audio,
6 | rate=None,
7 | model_str_or_path="umxl",
8 | targets=None,
9 | niter=1,
10 | residual=False,
11 | wiener_win_len=300,
12 | aggregate_dict=None,
13 | separator=None,
14 | device=None,
15 | filterbank="torch",
16 | ):
17 | """
18 | Open Unmix functional interface
19 |
20 | Separates a torch.Tensor or the content of an audio file.
21 |
22 | If a separator is provided, use it for inference. If not, create one
23 | and use it afterwards.
24 |
25 | Args:
26 | audio: audio to process
27 | torch Tensor: shape (channels, length), and
28 | `rate` must also be provided.
29 | rate: int or None: only used if audio is a Tensor. Otherwise,
30 | inferred from the file.
31 | model_str_or_path: the pretrained model to use, defaults to UMX-L
32 | targets (str): select the targets for the source to be separated.
33 | a list including: ['vocals', 'drums', 'bass', 'other'].
34 | If you don't pick them all, you probably want to
35 | activate the `residual=True` option.
36 | Defaults to all available targets per model.
37 | niter (int): the number of post-processing iterations, defaults to 1
38 | residual (bool): if True, a "garbage" target is created
39 | wiener_win_len (int): the number of frames to use when batching
40 | the post-processing step
41 | aggregate_dict (str): if provided, must be a string containing a
42 | valid expression for a dictionary, with keys as output
43 | target names, and values a list of targets that are used to
44 | build it. For instance: '{"vocals":["vocals"],
45 | "accompaniment":["drums","bass","other"]}'
46 | separator: if provided, the model.Separator object that will be used
47 | to perform separation
48 | device (str): selects device to be used for inference
49 | filterbank (str): filterbank implementation method.
50 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster
51 | compared to `asteroid` on large FFT sizes such as 4096. However,
52 | asteroid's STFT can be exported to ONNX, which makes it practical
53 | for deployment.
54 | """
55 | if separator is None:
56 | separator = utils.load_separator(
57 | model_str_or_path=model_str_or_path,
58 | targets=targets,
59 | niter=niter,
60 | residual=residual,
61 | wiener_win_len=wiener_win_len,
62 | device=device,
63 | pretrained=True,
64 | filterbank=filterbank,
65 | )
66 | separator.freeze()
67 | if device:
68 | separator.to(device)
69 |
70 | if rate is None:
71 | raise Exception("rate` must be provided.")
72 |
73 | if device:
74 | audio = audio.to(device)
75 | audio = utils.preprocess(audio, rate, separator.sample_rate)
76 |
77 | # getting the separated signals
78 | estimates = separator(audio)
79 | estimates = separator.to_dict(estimates, aggregate_dict=aggregate_dict)
80 | return estimates
81 |
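82 | # Example (a sketch; the file name is illustrative):
83 | #   import torchaudio
84 | #   audio, rate = torchaudio.load("mixture.wav")  # shape: (channels, length)
85 | #   estimates = separate(audio, rate=rate, niter=1)
86 | #   vocals = estimates["vocals"]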
--------------------------------------------------------------------------------
/docs/inference.md:
--------------------------------------------------------------------------------
1 | # Performing separation
2 |
3 | ## Interfacing using the command line
4 |
5 | The primary interface to separate files is the command line. To separate a mixture file into the four stems you can just run
6 |
7 | ```bash
8 | umx input_file.wav
9 | ```
10 |
11 | Note that we support all files that can be read by torchaudio, depending on the set backend (either `soundfile` (libsndfile) or `sox`).
12 | For training, we set the default to `soundfile` as it is faster than `sox`. However for inference users might prefer `mp3` decoding capabilities.
13 | The separation can be controlled with additional parameters that influence the performance of the separation.
14 |
15 | | Command line Argument | Description | Default |
16 | |----------------------------|---------------------------------------------------------------------------------|-----------------|
17 | |`--start ` | set start in seconds to reduce the duration of the audio being loaded | `0.0` |
18 | |`--duration ` | set duration in seconds to reduce length of the audio being loaded. Negative values will make the full audio being loaded | `-1.0` |
19 | |`--model ` | path or string of model name to select either a self pre-trained model or a model loaded from `torchhub`. | |
20 | | `--targets list(str)` | Targets to be used for separation. For each target a model file with with same name is required. | `['vocals', 'drums', 'bass', 'other']` |
21 | | `--niter ` | Number of EM steps for refining initial estimates in a post-processing stage. `--niter 0` skips this step altogether (and thus makes separation significantly faster) More iterations can get better interference reduction at the price of more artifacts. | `1` |
22 | | `--residual` | computes a residual target, for custom separation scenarios when not all targets are available (at the expense of slightly less performance). E.g. vocals/accompaniment can be performed with `--targets vocals --residual`. | not set |
23 | | `--softmask` | if activated, then the initial estimates for the sources will be obtained through a ratio mask of the mixture STFT, and not by using the default behavior of reconstructing waveforms by using the mixture phase. | not set |
24 | | `--wiener-win-len ` | Number of frames on which to apply filtering independently | `300` |
25 | | `--audio-backend ` | choose audio loading backend, either `sox_io`, `soundfile` or `stempeg` (which needs additional installation requirements) | [torchaudio default](https://pytorch.org/audio/stable/backend.html) |
26 | | `--aggregate ` | if provided, must be a string containing a valid expression for a dictionary, with keys as output target names, and values a list of targets that are used to build it. For instance: `{ "vocals": ["vocals"], "accompaniment": ["drums", "bass", "other"]}` | `None` |
27 | | `--filterbank ` | filterbank implementation method. Supported: `['torch', 'asteroid']`. While `torch` is ~30% faster compared to `asteroid` on large FFT sizes such as 4096, asteroid's STFT may be easier to export for deployment. | `torch` |
28 |
29 | ## Interfacing from python
30 |
31 | At the core of the separation process is the `Separator` module, which
32 | takes a numpy audio array or a `torch.Tensor` (the mixture) as input and separates it into `targets` stems.
33 | Note that a separate model is loaded for each target. E.g. for `umx` and `umxhq` the supported targets are
34 | `['vocals', 'drums', 'bass', 'other']`. The models have to be passed to the separator's `target_models` parameter.
35 |
36 | The pre-trained models `umx`, `umxhq`, `umxl` and `umxse` are downloaded automatically.
37 |
38 | The constructor of the `Separator` takes the following arguments, with suggested default values:
39 |
40 | ```python
41 | separator = openunmix.Separator(
42 | target_models: dict,
43 | niter: int = 0,
44 | softmask: bool = False,
45 | residual: bool = False,
46 | sample_rate: float = 44100.0,
47 | n_fft: int = 4096,
48 | n_hop: int = 1024,
49 | nb_channels: int = 2,
50 | wiener_win_len: Optional[int] = 300,
51 | filterbank: str = 'torch'
52 | ):
53 | ```
54 |
55 | When passing
56 |
57 | > __Caution__ `training` using the EM algorithm (`niter>0`) is not supported. Only plain post-processing is supported right now for gradient computation. This is because the performance overhead of avoiding all the in-place operations is too large.
58 |
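59 | As a minimal end-to-end sketch (the file name is a placeholder), the pre-trained `umxl` separator can be used directly from Python:
60 |
61 | ```python
62 | import torchaudio
63 | import openunmix
64 | from openunmix import utils
65 |
66 | # load the pre-trained UMX-L separator (weights are downloaded on first use)
67 | separator = openunmix.umxl(niter=1, pretrained=True)
68 |
69 | # load a stereo mixture and resample/reshape it to what the separator expects
70 | audio, rate = torchaudio.load("mixture.wav")
71 | audio = utils.preprocess(audio, rate, separator.sample_rate)
72 |
73 | # separate and convert the stacked estimates to a {target: waveform} dict
74 | estimates = separator(audio)
75 | estimates = separator.to_dict(estimates, aggregate_dict=None)
76 | ```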
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | Open-Unmix is designed as scientific software. Therefore, we encourage the community to submit bug-fixes and comments and improve the __computational performance__, __reproducibility__ and the __readability__ of the code where possible. When contributing to this repository, please first discuss the change you wish to make in the issue tracker with the owners of this repository before making a change.
4 |
5 | We are not looking for contributions that only focus on improving the __separation performance__. However, if this is the case, we instead encourage researchers to
6 |
7 | 1. Use Open-Unmix for their own research, e.g. by modification of the model.
8 | 2. Publish and present the results in a scientific paper / conference and __cite open-unmix__.
9 | 3. Contact us via mail or open a [performance issue]() if you are interested in contributing the new model.
10 | In this case we will rerun the training on our internal cluster and update the pre-trained weights accordingly.
11 |
12 | Please note we have a code of conduct, please follow it in all your interactions with the project.
13 |
14 | ## Pull Request Process
15 |
16 | The preferred way to contribute to open-unmix is to fork the
17 | [main repository](http://github.com/sigsep/open-unmix-pytorch/) on
18 | GitHub:
19 |
20 | 1. Fork the [project repository](http://github.com/sigsep/open-unmix-pytorch):
21 | click on the 'Fork' button near the top of the page. This creates
22 | a copy of the code under your account on the GitHub server.
23 |
24 | 2. Clone this copy to your local disk:
25 |
26 | ```
27 | $ git clone git@github.com:YourLogin/open-unmix-pytorch.git
28 | $ cd open-unmix-pytorch
29 | ```
30 |
31 | 3. Create a branch to hold your changes:
32 |
33 | ```
34 | $ git checkout -b my-feature
35 | ```
36 |
37 | and start making changes. Never work in the ``master`` branch!
38 |
39 | 4. Ensure any install or build artifacts are removed before making the pull request.
40 |
41 | 5. Update the README.md and/or the appropriate document in the `/docs` folder with details of changes to the interface; this includes new command line arguments, dataset descriptions or command line examples.
42 |
43 | 6. Work on this copy on your computer using Git to do the version
44 | control. When you're done editing, do:
45 |
46 | ```
47 | $ git add modified_files
48 | $ git commit
49 | ```
50 |
51 | to record your changes in Git, then push them to GitHub with:
52 |
53 | ```
54 | $ git push -u origin my-feature
55 | ```
56 |
57 | Finally, go to the web page of your fork of the open-unmix repo,
58 | and click 'Pull request' to send your changes to the maintainers for
59 | review. This will send an email to the committers.
60 |
61 | (If any of the above seems like magic to you, then look up the
62 | [Git documentation](http://git-scm.com/documentation) on the web.)
63 |
64 | ## Code of Conduct
65 |
66 | ### Our Pledge
67 |
68 | In the interest of fostering an open and welcoming environment, we as
69 | contributors and maintainers pledge to making participation in our project and
70 | our community a harassment-free experience for everyone, regardless of age, body
71 | size, disability, ethnicity, gender identity and expression, level of experience,
72 | nationality, personal appearance, race, religion, or sexual identity and
73 | orientation.
74 |
75 | ### Our Standards
76 |
77 | Examples of behavior that contributes to creating a positive environment
78 | include:
79 |
80 | * Using welcoming and inclusive language
81 | * Being respectful of differing viewpoints and experiences
82 | * Gracefully accepting constructive criticism
83 | * Focusing on what is best for the community
84 | * Showing empathy towards other community members
85 |
86 | Examples of unacceptable behavior by participants include:
87 |
88 | * The use of sexualized language or imagery and unwelcome sexual attention or
89 | advances
90 | * Trolling, insulting/derogatory comments, and personal or political attacks
91 | * Public or private harassment
92 | * Publishing others' private information, such as a physical or electronic
93 | address, without explicit permission
94 | * Other conduct which could reasonably be considered inappropriate in a
95 | professional setting
96 |
97 | ### Our Responsibilities
98 |
99 | Project maintainers are responsible for clarifying the standards of acceptable
100 | behavior and are expected to take appropriate and fair corrective action in
101 | response to any instances of unacceptable behavior.
102 |
103 | Project maintainers have the right and responsibility to remove, edit, or
104 | reject comments, commits, code, wiki edits, issues, and other contributions
105 | that are not aligned to this Code of Conduct, or to ban temporarily or
106 | permanently any contributor for other behaviors that they deem inappropriate,
107 | threatening, offensive, or harmful.
108 |
109 | ### Scope
110 |
111 | This Code of Conduct applies both within project spaces and in public spaces
112 | when an individual is representing the project or its community. Examples of
113 | representing a project or community include using an official project e-mail
114 | address, posting via an official social media account, or acting as an appointed
115 | representative at an online or offline event. Representation of a project may be
116 | further defined and clarified by project maintainers.
117 |
118 | ### Enforcement
119 |
120 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
121 | reported by contacting the project team @aliutkus, @faroit. All
122 | complaints will be reviewed and investigated and will result in a response that
123 | is deemed necessary and appropriate to the circumstances. The project team is
124 | obligated to maintain confidentiality with regard to the reporter of an incident.
125 | Further details of specific enforcement policies may be posted separately.
126 |
127 | Project maintainers who do not follow or enforce the Code of Conduct in good
128 | faith may face temporary or permanent repercussions as determined by other
129 | members of the project's leadership.
130 |
131 | ### Attribution
132 |
133 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
134 | available at [http://contributor-covenant.org/version/1/4][version]
135 |
136 | [homepage]: http://contributor-covenant.org
137 | [version]: http://contributor-covenant.org/version/1/4/
--------------------------------------------------------------------------------
/openunmix/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import functools
3 | import json
4 | import multiprocessing
5 | from typing import Optional, Union
6 |
7 | import musdb
8 | import museval
9 | import torch
10 | import tqdm
11 |
12 | from openunmix import utils
13 |
14 |
15 | def separate_and_evaluate(
16 | track: musdb.MultiTrack,
17 | targets: list,
18 | model_str_or_path: str,
19 | niter: int,
20 | output_dir: str,
21 | eval_dir: str,
22 | residual: bool,
23 | mus,
24 | aggregate_dict: dict = None,
25 | device: Union[str, torch.device] = "cpu",
26 | wiener_win_len: Optional[int] = None,
27 | filterbank="torch",
28 | ) -> str:
29 |
30 | separator = utils.load_separator(
31 | model_str_or_path=model_str_or_path,
32 | targets=targets,
33 | niter=niter,
34 | residual=residual,
35 | wiener_win_len=wiener_win_len,
36 | device=device,
37 | pretrained=True,
38 | filterbank=filterbank,
39 | )
40 |
41 | separator.freeze()
42 | separator.to(device)
43 |
44 | audio = torch.as_tensor(track.audio, dtype=torch.float32, device=device)
45 | audio = utils.preprocess(audio, track.rate, separator.sample_rate)
46 |
47 | estimates = separator(audio)
48 | estimates = separator.to_dict(estimates, aggregate_dict=aggregate_dict)
49 |
50 | for key in estimates:
51 | estimates[key] = estimates[key][0].cpu().detach().numpy().T
52 | if output_dir:
53 | mus.save_estimates(estimates, track, output_dir)
54 |
55 | scores = museval.eval_mus_track(track, estimates, output_dir=eval_dir)
56 | return scores
57 |
58 |
59 | if __name__ == "__main__":
60 |     # Evaluation settings
61 | parser = argparse.ArgumentParser(description="MUSDB18 Evaluation", add_help=False)
62 |
63 | parser.add_argument(
64 | "--targets",
65 | nargs="+",
66 | default=["vocals", "drums", "bass", "other"],
67 | type=str,
68 | help="provide targets to be processed. \
69 | If none, all available targets will be computed",
70 | )
71 |
72 | parser.add_argument(
73 | "--model",
74 | default="umxl",
75 | type=str,
76 | help="path to mode base directory of pretrained models",
77 | )
78 |
79 | parser.add_argument(
80 | "--outdir",
81 | type=str,
82 | help="Results path where audio evaluation results are stored",
83 | )
84 |
85 | parser.add_argument("--evaldir", type=str, help="Results path for museval estimates")
86 |
87 | parser.add_argument("--root", type=str, help="Path to MUSDB18")
88 |
89 | parser.add_argument("--subset", type=str, default="test", help="MUSDB subset (`train`/`test`)")
90 |
91 | parser.add_argument("--cores", type=int, default=1)
92 |
93 | parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA inference")
94 |
95 | parser.add_argument(
96 | "--is-wav",
97 | action="store_true",
98 | default=False,
99 | help="flags wav version of the dataset",
100 | )
101 |
102 | parser.add_argument(
103 | "--niter",
104 | type=int,
105 | default=1,
106 | help="number of iterations for refining results.",
107 | )
108 |
109 | parser.add_argument(
110 | "--wiener-win-len",
111 | type=int,
112 | default=300,
113 | help="Number of frames on which to apply filtering independently",
114 | )
115 |
116 | parser.add_argument(
117 | "--residual",
118 | type=str,
119 | default=None,
120 | help="if provided, build a source with given name" "for the mix minus all estimated targets",
121 | )
122 |
123 | parser.add_argument(
124 | "--aggregate",
125 | type=str,
126 | default=None,
127 | help="if provided, must be a string containing a valid expression for "
128 | "a dictionary, with keys as output target names, and values "
129 | "a list of targets that are used to build it. For instance: "
130 | '\'{"vocals":["vocals"], "accompaniment":["drums",'
131 | '"bass","other"]}\'',
132 | )
133 |
134 | args = parser.parse_args()
135 |
136 | use_cuda = not args.no_cuda and torch.cuda.is_available()
137 | device = torch.device("cuda" if use_cuda else "cpu")
138 |
139 | mus = musdb.DB(
140 | root=args.root,
141 | download=args.root is None,
142 | subsets=args.subset,
143 | is_wav=args.is_wav,
144 | )
145 | aggregate_dict = None if args.aggregate is None else json.loads(args.aggregate)
146 |
147 | if args.cores > 1:
148 | pool = multiprocessing.Pool(args.cores)
149 | results = museval.EvalStore()
150 | scores_list = list(
151 | pool.imap_unordered(
152 | func=functools.partial(
153 | separate_and_evaluate,
154 | targets=args.targets,
155 | model_str_or_path=args.model,
156 | niter=args.niter,
157 | residual=args.residual,
158 | mus=mus,
159 | aggregate_dict=aggregate_dict,
160 | output_dir=args.outdir,
161 | eval_dir=args.evaldir,
162 | device=device,
163 | ),
164 | iterable=mus.tracks,
165 | chunksize=1,
166 | )
167 | )
168 | pool.close()
169 | pool.join()
170 | for scores in scores_list:
171 | results.add_track(scores)
172 |
173 | else:
174 | results = museval.EvalStore()
175 | for track in tqdm.tqdm(mus.tracks):
176 | scores = separate_and_evaluate(
177 | track,
178 | targets=args.targets,
179 | model_str_or_path=args.model,
180 | niter=args.niter,
181 | residual=args.residual,
182 | mus=mus,
183 | aggregate_dict=aggregate_dict,
184 | output_dir=args.outdir,
185 | eval_dir=args.evaldir,
186 | device=device,
187 | )
188 | print(track, "\n", scores)
189 | results.add_track(scores)
190 |
191 | print(results)
192 | method = museval.MethodStore()
193 | method.add_evalstore(results, args.model)
194 | method.save(args.model + ".pandas")
195 |
--------------------------------------------------------------------------------
/tests/data/Al James - Schoolboy Facination.json:
--------------------------------------------------------------------------------
1 | {
2 | "targets": [
3 | {
4 | "name": "vocals",
5 | "frames": [
6 | {
7 | "time": 0.0,
8 | "duration": 1.0,
9 | "metrics": {
10 | "SDR": 4.32297,
11 | "SIR": 10.12637,
12 | "SAR": 7.59260,
13 | "ISR": 7.33622
14 | }
15 | },
16 | {
17 | "time": 1.0,
18 | "duration": 1.0,
19 | "metrics": {
20 | "SDR": 7.22083,
21 | "SIR": 10.20295,
22 | "SAR": 8.57502,
23 | "ISR": 10.96141
24 | }
25 | },
26 | {
27 | "time": 2.0,
28 | "duration": 1.0,
29 | "metrics": {
30 | "SDR": 7.80233,
31 | "SIR": 13.23437,
32 | "SAR": 10.15870,
33 | "ISR": 10.97836
34 | }
35 | },
36 | {
37 | "time": 3.0,
38 | "duration": 1.0,
39 | "metrics": {
40 | "SDR": 8.56464,
41 | "SIR": 15.05897,
42 | "SAR": 11.07534,
43 | "ISR": 11.29683
44 | }
45 | },
46 | {
47 | "time": 4.0,
48 | "duration": 1.0,
49 | "metrics": {
50 | "SDR": 7.42655,
51 | "SIR": 14.16498,
52 | "SAR": 10.40192,
53 | "ISR": 11.35024
54 | }
55 | },
56 | {
57 | "time": 5.0,
58 | "duration": 1.0,
59 | "metrics": {
60 | "SDR": 6.66938,
61 | "SIR": 11.71322,
62 | "SAR": 8.97930,
63 | "ISR": 10.42365
64 | }
65 | }
66 | ]
67 | },
68 | {
69 | "name": "drums",
70 | "frames": [
71 | {
72 | "time": 0.0,
73 | "duration": 1.0,
74 | "metrics": {
75 | "SDR": 3.98718,
76 | "SIR": 5.23581,
77 | "SAR": 3.66383,
78 | "ISR": 8.04258
79 | }
80 | },
81 | {
82 | "time": 1.0,
83 | "duration": 1.0,
84 | "metrics": {
85 | "SDR": 3.10503,
86 | "SIR": 2.47233,
87 | "SAR": 2.91512,
88 | "ISR": 7.29689
89 | }
90 | },
91 | {
92 | "time": 2.0,
93 | "duration": 1.0,
94 | "metrics": {
95 | "SDR": 4.06987,
96 | "SIR": 6.64704,
97 | "SAR": 5.02986,
98 | "ISR": 6.72215
99 | }
100 | },
101 | {
102 | "time": 3.0,
103 | "duration": 1.0,
104 | "metrics": {
105 | "SDR": 3.67143,
106 | "SIR": 4.86323,
107 | "SAR": 3.75070,
108 | "ISR": 7.21293
109 | }
110 | },
111 | {
112 | "time": 4.0,
113 | "duration": 1.0,
114 | "metrics": {
115 | "SDR": 3.88528,
116 | "SIR": 2.51644,
117 | "SAR": 3.34250,
118 | "ISR": 8.09842
119 | }
120 | },
121 | {
122 | "time": 5.0,
123 | "duration": 1.0,
124 | "metrics": {
125 | "SDR": 5.39026,
126 | "SIR": 3.22021,
127 | "SAR": 3.37839,
128 | "ISR": 8.22456
129 | }
130 | }
131 | ]
132 | },
133 | {
134 | "name": "bass",
135 | "frames": [
136 | {
137 | "time": 0.0,
138 | "duration": 1.0,
139 | "metrics": {
140 | "SDR": 5.60414,
141 | "SIR": 12.81026,
142 | "SAR": 8.76971,
143 | "ISR": 7.72415
144 | }
145 | },
146 | {
147 | "time": 1.0,
148 | "duration": 1.0,
149 | "metrics": {
150 | "SDR": 6.79312,
151 | "SIR": 15.50595,
152 | "SAR": 7.36306,
153 | "ISR": 6.43593
154 | }
155 | },
156 | {
157 | "time": 2.0,
158 | "duration": 1.0,
159 | "metrics": {
160 | "SDR": 5.42226,
161 | "SIR": 14.25947,
162 | "SAR": 7.97667,
163 | "ISR": 5.99902
164 | }
165 | },
166 | {
167 | "time": 3.0,
168 | "duration": 1.0,
169 | "metrics": {
170 | "SDR": 5.88173,
171 | "SIR": 13.25402,
172 | "SAR": 7.96686,
173 | "ISR": 7.43682
174 | }
175 | },
176 | {
177 | "time": 4.0,
178 | "duration": 1.0,
179 | "metrics": {
180 | "SDR": 6.81244,
181 | "SIR": 13.69640,
182 | "SAR": 9.50696,
183 | "ISR": 7.39392
184 | }
185 | },
186 | {
187 | "time": 5.0,
188 | "duration": 1.0,
189 | "metrics": {
190 | "SDR": 4.63153,
191 | "SIR": 10.98876,
192 | "SAR": 6.74840,
193 | "ISR": 6.41703
194 | }
195 | }
196 | ]
197 | },
198 | {
199 | "name": "other",
200 | "frames": [
201 | {
202 | "time": 0.0,
203 | "duration": 1.0,
204 | "metrics": {
205 | "SDR": 2.02272,
206 | "SIR": 3.09448,
207 | "SAR": 5.76186,
208 | "ISR": 7.85117
209 | }
210 | },
211 | {
212 | "time": 1.0,
213 | "duration": 1.0,
214 | "metrics": {
215 | "SDR": 3.25884,
216 | "SIR": 1.92849,
217 | "SAR": 4.78471,
218 | "ISR": 7.66760
219 | }
220 | },
221 | {
222 | "time": 2.0,
223 | "duration": 1.0,
224 | "metrics": {
225 | "SDR": 1.58550,
226 | "SIR": 1.34803,
227 | "SAR": 5.57206,
228 | "ISR": 7.48901
229 | }
230 | },
231 | {
232 | "time": 3.0,
233 | "duration": 1.0,
234 | "metrics": {
235 | "SDR": 2.04527,
236 | "SIR": 2.84563,
237 | "SAR": 5.45523,
238 | "ISR": 6.97621
239 | }
240 | },
241 | {
242 | "time": 4.0,
243 | "duration": 1.0,
244 | "metrics": {
245 | "SDR": 2.03671,
246 | "SIR": 2.73438,
247 | "SAR": 5.29621,
248 | "ISR": 6.24533
249 | }
250 | },
251 | {
252 | "time": 5.0,
253 | "duration": 1.0,
254 | "metrics": {
255 | "SDR": 3.31991,
256 | "SIR": 3.90981,
257 | "SAR": 4.51333,
258 | "ISR": 5.17112
259 | }
260 | }
261 | ]
262 | }
263 | ],
264 | "museval_version": "0.3.0"
265 | }
--------------------------------------------------------------------------------
/docs/extensions.md:
--------------------------------------------------------------------------------
1 | # Extending Open-Unmix
2 |
3 | 
4 | One of the key aspects of _Open-Unmix_ is that it was made to be easily extensible and thus is a good starting point for new research on music source separation. In fact, the open-unmix training code is based on the [pytorch MNIST example](https://github.com/pytorch/examples/blob/master/mnist/main.py). In this document we provide a short overview of ways to extend open-unmix.
5 |
6 | ## Code Structure
7 |
8 | * `data.py` includes several torch datasets that can all be used to train _open-unmix_.
9 | * `train.py` includes all code that is necessary to start a training.
10 | * `model.py` includes the open-unmix torch modules.
11 | * `test.py` includes code to predict/unmix from audio files.
12 | * `eval.py` includes all code to run the objective evaluation using museval on the MUSDB18 dataset.
13 | * `utils.py` includes additional tools like audio loading and metadata loading.
14 |
15 | ## Provide a custom dataset
16 |
17 | Users of open-unmix who have their own datasets that do not fit one of our predefined datasets might want to implement or use their own `torch.utils.data.Dataset` for training. Such a modification is straightforward since all of our datasets share the same simple interface.
18 |
19 | ### Template Dataset
20 |
21 | In case you want to create your own dataset, we provide a template for the open-unmix API. You can use our efficient torchaudio or libsndfile based `load_audio` loaders or just use your own files. Since pytorch (<=1.1) currently uses index-based datasets (instead of iterable-based datasets), the best way to load audio is to assign each index to one audio track. However, there are possible applications where the index is ignored and the `__len__()` method just returns an arbitrary number of samples.
22 |
23 | ```python
24 | from utils import load_audio, load_info
25 | class TemplateDataset(UnmixDataset):
26 | """A template dataset class for you to implement custom datasets."""
27 |
28 | def __init__(self, root, split='train', sample_rate=44100, seq_dur=None):
29 | """Initialize the dataset
30 | """
31 | self.root = root
32 |         self.tracks = get_tracks(root, split)  # get_tracks() is your own helper that lists the audio files of a split
33 |
34 | def __getitem__(self, index):
35 | """Returns a time domain audio example
36 | of shape=(channel, sample)
37 | """
38 | path = self.tracks[index]
39 |         x = load_audio(path)  # load the mixture
40 |         y = load_audio(path)  # load the target (replace with the path of your target stem)
41 | return x, y
42 |
43 | def __len__(self):
44 | """Return the number of audio samples"""
45 | return len(self.tracks)
46 | ```
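
For a quick sanity check, such a dataset can be wrapped in a standard PyTorch `DataLoader`. The following is only a minimal sketch, assuming the `TemplateDataset` above and a hypothetical dataset folder; note that batching with `batch_size > 1` requires all examples to be cropped to the same length (e.g. via `seq_dur`):

```python
import torch

# hypothetical dataset root, replace with your own data
dataset = TemplateDataset(root="/data/my_dataset", split="train")
loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

for x, y in loader:
    # x and y are batches of shape (batch, channel, sample)
    print(x.shape, y.shape)
    break
```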
47 |
48 | ## Provide a custom model
49 |
50 | We think that recurrent models provide the best trade-off between good results, fast training, and flexibility, due to their ability to learn from arbitrary durations of audio and different audio representations. If you want to try different models you can easily build upon our model template below:
51 |
52 | ### Template Spectrogram Model
53 |
54 | ```python
55 | from model import Spectrogram, STFT
56 | import torch.nn as nn
57 | 
58 | class Model(nn.Module):
59 |     def __init__(
60 |         self,
61 |         n_fft=4096,
62 |         n_hop=1024,
63 |         nb_channels=2,
64 |         input_is_spectrogram=False,
65 |         sample_rate=44100.0,
66 |     ):
67 |         """
68 |         Input: (batch, channel, sample) or (frame, batch, channels, frequency)
69 |         Output: (frame, batch, channels, frequency)
70 |         """
71 |         super(Model, self).__init__()
72 |         # compute magnitude spectrograms on the fly, unless the input already is one
73 |         self.transform = nn.Identity() if input_is_spectrogram else nn.Sequential(
74 |             STFT(n_fft=n_fft, n_hop=n_hop), Spectrogram(power=1, mono=(nb_channels == 1))
75 |         )
76 | 
77 |     def forward(self, mix):
78 |         X = self.transform(mix)
79 |         nb_frames, nb_samples, nb_channels, nb_bins = X.shape
80 |         # transform X into the target estimate here
81 |         return X
82 | ```
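
As a quick shape check, the template can be fed a batch of random audio. This is only a sketch assuming the `Model` defined above with its default STFT parameters:

```python
import torch

model = Model(n_fft=4096, n_hop=1024, nb_channels=2)
mix = torch.rand(1, 2, 44100 * 5)  # (batch, channel, sample): 5 seconds of stereo noise
X = model(mix)
print(X.shape)  # (frame, batch, channels, frequency)
```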
83 |
84 | ## Jointly train targets
85 |
86 | We designed _open-unmix_ so that the training of multiple targets is handled by separate models. We think that this has several benefits, such as:
87 | 
88 | * single-source models can leverage unbalanced data, where a different amount of training data is available for each source.
89 | * training can easily be distributed by training multiple models on different nodes in parallel.
90 | * at test time, the selection of models can be adjusted for specific applications.
91 |
92 | However, we acknowledge that there might be reasons to train a model jointly for all sources to improve separation performance. These changes can easily be made in _open-unmix_ with the following modifications, based on the way pytorch handles single-input, multiple-output models.
93 |
94 | ### 1. Extend `data.py`
95 |
96 | The dataset should be able to yield a list of tensors (one for each target). E.g. the `musdb` dataset can be extended with:
97 |
98 | ```python
99 | y = [stems[ind] for ind, _ in enumerate(self.targets)]
100 | ```
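
Put into context, a joint-target `__getitem__` could look like the following sketch, where `load_mix_and_stems` is a hypothetical helper that loads the mixture and all target stems of one track:

```python
def __getitem__(self, index):
    # x: mixture of shape (channel, sample)
    # stems: one tensor per source, each of the same shape as the mixture
    x, stems = self.load_mix_and_stems(index)  # hypothetical helper
    y = [stems[ind] for ind, _ in enumerate(self.targets)]
    return x, y
```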
101 |
102 | ### 2. Extend `model.py`
103 |
104 | The _open-unmix_ model itself can be left unchanged; instead, a "supermodel" can be added that joins the forward paths of all targets:
105 |
106 | ```python
107 | class OpenUnmixJoint(nn.Module):
108 | def __init__(
109 | self,
110 | targets,
111 | *args, **kwargs
112 | ):
113 | super(OpenUnmixJoint, self).__init__()
114 | self.models = nn.ModuleList(
115 | [OpenUnmix(*args, **kwargs) for target in targets]
116 | )
117 |
118 | def forward(self, x):
119 | return [model(x) for model in self.models]
120 | ```
121 |
122 | ### 3. Extend `train.py`
123 |
124 | The training should be updated so that the total loss is an aggregation of the individual target losses. For the mean squared error, the following modifications should be sufficient:
125 |
126 | ```python
127 | criteria = [torch.nn.MSELoss() for t in args.targets]
128 | # ...
129 | for x, y in tqdm.tqdm(train_sampler, disable=args.quiet):
130 | x = x.to(device)
131 | y = [i.to(device) for i in y]
132 | optimizer.zero_grad()
133 | Y_hats = unmix(x)
134 | loss = 0
135 | for Y_hat, target, criterion in zip(Y_hats, y, criteria):
136 | loss = loss + criterion(Y_hat, unmix.models[0].transform(target))
137 | ```
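
The remaining lines of the loop can stay as in the original `train.py`, i.e. the aggregated loss is backpropagated as usual:

```python
    loss.backward()
    optimizer.step()
    losses.update(loss.item(), x.size(0))  # x.size(0) is the batch size
```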
138 |
139 | ## End-to-End time-domain models
140 |
141 | If you want to evaluate models that work in the time domain, such as WaveNet or WaveRNN, the training code would have to be modified. Instead of the spectrogram output `Y`, the model output is simply a time-domain signal that can be compared directly with the target `y`. E.g. going from:
142 |
143 | ```python
144 | Y_hat = unmix(x)
145 | Y = unmix.transform(y)
146 | loss = criterion(Y_hat, Y)
147 | ```
148 |
149 | to:
150 |
151 | ```python
152 | y_hat = unmix(x)
153 | loss = criterion(y_hat, y)
154 | ```
155 |
156 | Inference, in that case, would then have to drop the spectral Wiener filter and instead directly save the time-domain signal (and maybe its residual):
157 |
158 | ```python
159 | est = unmix(audio_torch).cpu().detach().numpy()
160 | estimates[target] = est[0].T
161 | estimates['residual'] = audio - est[0].T
162 | ```
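
To write such time-domain estimates to disk, one could use `torchaudio` directly. This is only a sketch, assuming `estimates` maps target names to numpy arrays of shape `(sample, channel)` and `rate` holds the output sample rate:

```python
import torch
import torchaudio

for name, estimate in estimates.items():
    # torchaudio expects a float tensor of shape (channel, sample)
    audio_out = torch.as_tensor(estimate.T, dtype=torch.float32)
    torchaudio.save(f"{name}.wav", audio_out, sample_rate=rate)
```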
163 |
--------------------------------------------------------------------------------
/openunmix/cli.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import torch
3 | import torchaudio
4 | import json
5 | import numpy as np
6 | import tqdm
7 |
8 | from openunmix import utils
9 | from openunmix import predict
10 | from openunmix import data
11 |
12 | import argparse
13 |
14 |
15 | def separate():
16 | parser = argparse.ArgumentParser(
17 | description="UMX Inference",
18 | add_help=True,
19 | formatter_class=argparse.RawDescriptionHelpFormatter,
20 | )
21 |
22 | parser.add_argument("input", type=str, nargs="+", help="List of paths to wav/flac files.")
23 |
24 | parser.add_argument(
25 | "--model",
26 | default="umxl",
27 | type=str,
28 | help="path to mode base directory of pretrained models, defaults to UMX-L",
29 | )
30 |
31 | parser.add_argument(
32 | "--targets",
33 | nargs="+",
34 | type=str,
35 | help="provide targets to be processed. \
36 | If none, all available targets will be computed",
37 | )
38 |
39 | parser.add_argument(
40 | "--outdir",
41 | type=str,
42 | help="Results path where audio evaluation results are stored",
43 | )
44 |
45 | parser.add_argument(
46 | "--ext",
47 | type=str,
48 | default=".wav",
49 | help="Output extension which sets the audio format",
50 | )
51 |
52 | parser.add_argument("--start", type=float, default=0.0, help="Audio chunk start in seconds")
53 |
54 | parser.add_argument(
55 | "--duration",
56 | type=float,
57 | help="Audio chunk duration in seconds, negative values load full track",
58 | )
59 |
60 | parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA inference")
61 |
62 | parser.add_argument(
63 | "--audio-backend",
64 | type=str,
65 | help="Sets audio backend. Default to torchaudio's default backend: See https://pytorch.org/audio/stable/backend.html"
66 | "(`sox_io`, `sox`, `soundfile` or `stempeg`)",
67 | )
68 |
69 | parser.add_argument(
70 | "--niter",
71 | type=int,
72 | default=1,
73 | help="number of iterations for refining results.",
74 | )
75 |
76 | parser.add_argument(
77 | "--wiener-win-len",
78 | type=int,
79 | default=300,
80 | help="Number of frames on which to apply filtering independently",
81 | )
82 |
83 | parser.add_argument(
84 | "--residual",
85 | type=str,
86 | default=None,
87 | help="if provided, build a source with given name " "for the mix minus all estimated targets",
88 | )
89 |
90 | parser.add_argument(
91 | "--aggregate",
92 | type=str,
93 | default=None,
94 | help="if provided, must be a string containing a valid expression for "
95 | "a dictionary, with keys as output target names, and values "
96 | "a list of targets that are used to build it. For instance: "
97 | '\'{"vocals":["vocals"], "accompaniment":["drums",'
98 | '"bass","other"]}\'',
99 | )
100 |
101 | parser.add_argument(
102 | "--filterbank",
103 | type=str,
104 | default="torch",
105 | help="filterbank implementation method. "
106 | "Supported: `['torch', 'asteroid']`. `torch` is ~30%% faster "
107 | "compared to `asteroid` on large FFT sizes such as 4096. However "
108 | "asteroids stft can be exported to onnx, which makes is practical "
109 | "for deployment.",
110 | )
111 | parser.add_argument(
112 | "--verbose",
113 | action="store_true",
114 | default=False,
115 | help="Enable log messages",
116 | )
117 | args = parser.parse_args()
118 |
119 | if args.audio_backend != "stempeg" and args.audio_backend is not None:
120 | torchaudio.set_audio_backend(args.audio_backend)
121 |
122 | use_cuda = not args.no_cuda and torch.cuda.is_available()
123 | device = torch.device("cuda" if use_cuda else "cpu")
124 | if args.verbose:
125 | print("Using ", device)
126 | # parsing the output dict
127 | aggregate_dict = None if args.aggregate is None else json.loads(args.aggregate)
128 |
129 | # create separator only once to reduce model loading
130 | # when using multiple files
131 | separator = utils.load_separator(
132 | model_str_or_path=args.model,
133 | targets=args.targets,
134 | niter=args.niter,
135 | residual=args.residual,
136 | wiener_win_len=args.wiener_win_len,
137 | device=device,
138 | pretrained=True,
139 | filterbank=args.filterbank,
140 | )
141 |
142 | separator.freeze()
143 | separator.to(device)
144 |
145 | if args.audio_backend == "stempeg":
146 | try:
147 | import stempeg
148 | except ImportError:
149 | raise RuntimeError("Please install pip package `stempeg`")
150 |
151 | # loop over the files
152 | for input_file in tqdm.tqdm(args.input):
153 | if args.audio_backend == "stempeg":
154 | audio, rate = stempeg.read_stems(
155 | input_file,
156 | start=args.start,
157 | duration=args.duration,
158 | sample_rate=separator.sample_rate,
159 | dtype=np.float32,
160 | )
161 | audio = torch.tensor(audio)
162 | else:
163 | audio, rate = data.load_audio(input_file, start=args.start, dur=args.duration)
164 | estimates = predict.separate(
165 | audio=audio,
166 | rate=rate,
167 | aggregate_dict=aggregate_dict,
168 | separator=separator,
169 | device=device,
170 | )
171 | if not args.outdir:
172 | model_path = Path(args.model)
173 | if not model_path.exists():
174 | outdir = Path(Path(input_file).stem + "_" + args.model)
175 | else:
176 | outdir = Path(Path(input_file).stem + "_" + model_path.stem)
177 | else:
178 | outdir = Path(args.outdir) / Path(input_file).stem
179 | outdir.mkdir(exist_ok=True, parents=True)
180 |
181 | # write out estimates
182 | if args.audio_backend == "stempeg":
183 | target_path = str(outdir / Path("target").with_suffix(args.ext))
184 | # convert torch dict to numpy dict
185 | estimates_numpy = {}
186 | for target, estimate in estimates.items():
187 | estimates_numpy[target] = torch.squeeze(estimate).detach().cpu().numpy().T
188 |
189 | stempeg.write_stems(
190 | target_path,
191 | estimates_numpy,
192 | sample_rate=separator.sample_rate,
193 | writer=stempeg.FilesWriter(multiprocess=True, output_sample_rate=rate),
194 | )
195 | else:
196 | for target, estimate in estimates.items():
197 | target_path = str(outdir / Path(target).with_suffix(args.ext))
198 | torchaudio.save(
199 | target_path,
200 | torch.squeeze(estimate).to("cpu"),
201 | sample_rate=separator.sample_rate,
202 | )
203 |
--------------------------------------------------------------------------------
/openunmix/transforms.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torchaudio
5 | from torch import Tensor
6 | import torch.nn as nn
7 |
8 | try:
9 | from asteroid_filterbanks.enc_dec import Encoder, Decoder
10 | from asteroid_filterbanks.transforms import to_torchaudio, from_torchaudio
11 | from asteroid_filterbanks import torch_stft_fb
12 | except ImportError:
13 | pass
14 |
15 |
16 | def make_filterbanks(n_fft=4096, n_hop=1024, center=False, sample_rate=44100.0, method="torch"):
17 | window = nn.Parameter(torch.hann_window(n_fft), requires_grad=False)
18 |
19 | if method == "torch":
20 | encoder = TorchSTFT(n_fft=n_fft, n_hop=n_hop, window=window, center=center)
21 | decoder = TorchISTFT(n_fft=n_fft, n_hop=n_hop, window=window, center=center)
22 | elif method == "asteroid":
23 | fb = torch_stft_fb.TorchSTFTFB.from_torch_args(
24 | n_fft=n_fft,
25 | hop_length=n_hop,
26 | win_length=n_fft,
27 | window=window,
28 | center=center,
29 | sample_rate=sample_rate,
30 | )
31 | encoder = AsteroidSTFT(fb)
32 | decoder = AsteroidISTFT(fb)
33 | else:
34 | raise NotImplementedError
35 | return encoder, decoder
36 |
37 |
38 | class AsteroidSTFT(nn.Module):
39 | def __init__(self, fb):
40 | super(AsteroidSTFT, self).__init__()
41 | self.enc = Encoder(fb)
42 |
43 | def forward(self, x):
44 | aux = self.enc(x)
45 | return to_torchaudio(aux)
46 |
47 |
48 | class AsteroidISTFT(nn.Module):
49 | def __init__(self, fb):
50 | super(AsteroidISTFT, self).__init__()
51 | self.dec = Decoder(fb)
52 |
53 | def forward(self, X: Tensor, length: Optional[int] = None) -> Tensor:
54 | aux = from_torchaudio(X)
55 | return self.dec(aux, length=length)
56 |
57 |
58 | class TorchSTFT(nn.Module):
59 | """Multichannel Short-Time-Fourier Forward transform
60 | uses hard coded hann_window.
61 | Args:
62 | n_fft (int, optional): transform FFT size. Defaults to 4096.
63 | n_hop (int, optional): transform hop size. Defaults to 1024.
64 |         center (bool, optional): If True, the signal's first window is
65 |             zero padded. Centering is required for a perfect
66 |             reconstruction of the signal. However, during training
67 |             of spectrogram models, it can safely be turned off.
68 |             Defaults to `False`
69 | window (nn.Parameter, optional): window function
70 | """
71 |
72 | def __init__(
73 | self,
74 | n_fft: int = 4096,
75 | n_hop: int = 1024,
76 | center: bool = False,
77 | window: Optional[nn.Parameter] = None,
78 | ):
79 | super(TorchSTFT, self).__init__()
80 | if window is None:
81 | self.window = nn.Parameter(torch.hann_window(n_fft), requires_grad=False)
82 | else:
83 | self.window = window
84 |
85 | self.n_fft = n_fft
86 | self.n_hop = n_hop
87 | self.center = center
88 |
89 | def forward(self, x: Tensor) -> Tensor:
90 | """STFT forward path
91 | Args:
92 | x (Tensor): audio waveform of
93 | shape (nb_samples, nb_channels, nb_timesteps)
94 | Returns:
95 | STFT (Tensor): complex stft of
96 | shape (nb_samples, nb_channels, nb_bins, nb_frames, complex=2)
97 | last axis is stacked real and imaginary
98 | """
99 |
100 | shape = x.size()
101 | nb_samples, nb_channels, nb_timesteps = shape
102 |
103 | # pack batch
104 | x = x.view(-1, shape[-1])
105 |
106 | complex_stft = torch.stft(
107 | x,
108 | n_fft=self.n_fft,
109 | hop_length=self.n_hop,
110 | window=self.window,
111 | center=self.center,
112 | normalized=False,
113 | onesided=True,
114 | pad_mode="reflect",
115 | return_complex=True,
116 | )
117 | stft_f = torch.view_as_real(complex_stft)
118 | # unpack batch
119 | stft_f = stft_f.view(shape[:-1] + stft_f.shape[-3:])
120 | return stft_f
121 |
122 |
123 | class TorchISTFT(nn.Module):
124 | """Multichannel Inverse-Short-Time-Fourier functional
125 | wrapper for torch.istft to support batches
126 | Args:
127 | STFT (Tensor): complex stft of
128 | shape (nb_samples, nb_channels, nb_bins, nb_frames, complex=2)
129 | last axis is stacked real and imaginary
130 | n_fft (int, optional): transform FFT size. Defaults to 4096.
131 | n_hop (int, optional): transform hop size. Defaults to 1024.
132 | window (callable, optional): window function
133 |         center (bool, optional): If True, the signal's first window is
134 |             zero padded. Centering is required for a perfect
135 |             reconstruction of the signal. However, during training
136 |             of spectrogram models, it can safely be turned off.
137 |             Defaults to `False`
138 | length (int, optional): audio signal length to crop the signal
139 | Returns:
140 | x (Tensor): audio waveform of
141 | shape (nb_samples, nb_channels, nb_timesteps)
142 | """
143 |
144 | def __init__(
145 | self,
146 | n_fft: int = 4096,
147 | n_hop: int = 1024,
148 | center: bool = False,
149 | sample_rate: float = 44100.0,
150 | window: Optional[nn.Parameter] = None,
151 | ) -> None:
152 | super(TorchISTFT, self).__init__()
153 |
154 | self.n_fft = n_fft
155 | self.n_hop = n_hop
156 | self.center = center
157 | self.sample_rate = sample_rate
158 |
159 | if window is None:
160 | self.window = nn.Parameter(torch.hann_window(n_fft), requires_grad=False)
161 | else:
162 | self.window = window
163 |
164 | def forward(self, X: Tensor, length: Optional[int] = None) -> Tensor:
165 | shape = X.size()
166 | X = X.reshape(-1, shape[-3], shape[-2], shape[-1])
167 |
168 | y = torch.istft(
169 | torch.view_as_complex(X),
170 | n_fft=self.n_fft,
171 | hop_length=self.n_hop,
172 | window=self.window,
173 | center=self.center,
174 | normalized=False,
175 | onesided=True,
176 | length=length,
177 | )
178 |
179 | y = y.reshape(shape[:-3] + y.shape[-1:])
180 |
181 | return y
182 |
183 |
184 | class ComplexNorm(nn.Module):
185 | r"""Compute the norm of complex tensor input.
186 |
187 |     Extension of `torchaudio.functional.complex_norm` with an optional mono downmix
188 |
189 | Args:
190 |         mono (bool): Downmix to a single channel after applying the
191 |             norm, to preserve energy.
192 | """
193 |
194 | def __init__(self, mono: bool = False):
195 | super(ComplexNorm, self).__init__()
196 | self.mono = mono
197 |
198 | def forward(self, spec: Tensor) -> Tensor:
199 | """
200 | Args:
201 | spec: complex_tensor (Tensor): Tensor shape of
202 | `(..., complex=2)`
203 |
204 | Returns:
205 | Tensor: Power/Mag of input
206 | `(...,)`
207 | """
208 | # take the magnitude
209 |
210 | spec = torch.abs(torch.view_as_complex(spec))
211 |
212 | # downmix in the mag domain to preserve energy
213 | if self.mono:
214 | spec = torch.mean(spec, 1, keepdim=True)
215 |
216 | return spec
217 |
--------------------------------------------------------------------------------
/openunmix/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import torch
4 | import os
5 | import numpy as np
6 | import torchaudio
7 | import warnings
8 | from pathlib import Path
9 | from contextlib import redirect_stderr
10 | import io
11 | import json
12 |
13 | import openunmix
14 | from openunmix import model
15 |
16 |
17 | def bandwidth_to_max_bin(rate: float, n_fft: int, bandwidth: float) -> np.ndarray:
18 | """Convert bandwidth to maximum bin count
19 |
20 | Assuming lapped transforms such as STFT
21 |
22 | Args:
23 | rate (int): Sample rate
24 | n_fft (int): FFT length
25 | bandwidth (float): Target bandwidth in Hz
26 |
27 | Returns:
28 | np.ndarray: maximum frequency bin
29 | """
30 | freqs = np.linspace(0, rate / 2, n_fft // 2 + 1, endpoint=True)
31 |
32 | return np.max(np.where(freqs <= bandwidth)[0]) + 1
33 |
34 |
35 | def save_checkpoint(state: dict, is_best: bool, path: str, target: str):
36 | """Convert bandwidth to maximum bin count
37 |
38 | Assuming lapped transforms such as STFT
39 |
40 | Args:
41 | state (dict): torch model state dict
42 | is_best (bool): if current model is about to be saved as best model
43 | path (str): model path
44 | target (str): target name
45 | """
46 | # save full checkpoint including optimizer
47 | torch.save(state, os.path.join(path, target + ".chkpnt"))
48 | if is_best:
49 | # save just the weights
50 | torch.save(state["state_dict"], os.path.join(path, target + ".pth"))
51 |
52 |
53 | class AverageMeter(object):
54 | """Computes and stores the average and current value"""
55 |
56 | def __init__(self):
57 | self.reset()
58 |
59 | def reset(self):
60 | self.val = 0
61 | self.avg = 0
62 | self.sum = 0
63 | self.count = 0
64 |
65 | def update(self, val, n=1):
66 | self.val = val
67 | self.sum += val * n
68 | self.count += n
69 | self.avg = self.sum / self.count
70 |
71 |
72 | class EarlyStopping(object):
73 | """Early Stopping Monitor"""
74 |
75 | def __init__(self, mode="min", min_delta=0, patience=10):
76 | self.mode = mode
77 | self.min_delta = min_delta
78 | self.patience = patience
79 | self.best = None
80 | self.num_bad_epochs = 0
81 | self.is_better = None
82 | self._init_is_better(mode, min_delta)
83 |
84 | if patience == 0:
85 | self.is_better = lambda a, b: True
86 |
87 | def step(self, metrics):
88 | if self.best is None:
89 | self.best = metrics
90 | return False
91 |
92 | if np.isnan(metrics):
93 | return True
94 |
95 | if self.is_better(metrics, self.best):
96 | self.num_bad_epochs = 0
97 | self.best = metrics
98 | else:
99 | self.num_bad_epochs += 1
100 |
101 | if self.num_bad_epochs >= self.patience:
102 | return True
103 |
104 | return False
105 |
106 | def _init_is_better(self, mode, min_delta):
107 | if mode not in {"min", "max"}:
108 | raise ValueError("mode " + mode + " is unknown!")
109 | if mode == "min":
110 | self.is_better = lambda a, best: a < best - min_delta
111 | if mode == "max":
112 | self.is_better = lambda a, best: a > best + min_delta
113 |
114 |
115 | def load_target_models(targets, model_str_or_path="umxl", device="cpu", pretrained=True):
116 | """Core model loader
117 |
118 | target model path can be either .pth, or -sha256.pth
119 |     (as used on torchhub)
120 |
121 | The loader either loads the models from a known model string
122 | as registered in the __init__.py or loads from custom configs.
123 | """
124 | if isinstance(targets, str):
125 | targets = [targets]
126 |
127 | model_path = Path(model_str_or_path).expanduser()
128 | if not model_path.exists():
129 | # model path does not exist, use pretrained models
130 | try:
131 | # disable progress bar
132 | hub_loader = getattr(openunmix, model_str_or_path + "_spec")
133 | err = io.StringIO()
134 | with redirect_stderr(err):
135 | return hub_loader(targets=targets, device=device, pretrained=pretrained)
136 | print(err.getvalue())
137 | except AttributeError:
138 | raise NameError("Model does not exist on torchhub")
139 | # assume model is a path to a local model_str_or_path directory
140 | else:
141 | models = {}
142 | for target in targets:
143 | # load model from disk
144 | with open(Path(model_path, target + ".json"), "r") as stream:
145 | results = json.load(stream)
146 |
147 | target_model_path = next(Path(model_path).glob("%s*.pth" % target))
148 | state = torch.load(target_model_path, map_location=device)
149 |
150 | models[target] = model.OpenUnmix(
151 | nb_bins=results["args"]["nfft"] // 2 + 1,
152 | nb_channels=results["args"]["nb_channels"],
153 | hidden_size=results["args"]["hidden_size"],
154 | max_bin=state["input_mean"].shape[0],
155 | )
156 |
157 | if pretrained:
158 | models[target].load_state_dict(state, strict=False)
159 |
160 | models[target].to(device)
161 | return models
162 |
163 |
164 | def load_separator(
165 | model_str_or_path: str = "umxl",
166 | targets: Optional[list] = None,
167 | niter: int = 1,
168 | residual: bool = False,
169 | wiener_win_len: Optional[int] = 300,
170 | device: Union[str, torch.device] = "cpu",
171 | pretrained: bool = True,
172 | filterbank: str = "torch",
173 | ):
174 | """Separator loader
175 |
176 | Args:
177 | model_str_or_path (str): Model name or path to model _parent_ directory
178 |             E.g. The following files are assumed to be present when
179 |             loading `model_str_or_path='mymodel', targets=['vocals']`:
180 |             'mymodel/separator.json', 'mymodel/vocals.pth', 'mymodel/vocals.json'.
181 | Defaults to `umxl`.
182 | targets (list of str or None): list of target names. When loading a
183 | pre-trained model, all `targets` can be None as all targets
184 | will be loaded
185 | niter (int): Number of EM steps for refining initial estimates
186 | in a post-processing stage. `--niter 0` skips this step altogether
187 |             (and thus makes separation significantly faster). More iterations
188 | can get better interference reduction at the price of artifacts.
189 | Defaults to `1`.
190 | residual (bool): Computes a residual target, for custom separation
191 | scenarios when not all targets are available (at the expense
192 |             of slightly less performance). E.g. vocals/accompaniment.
193 | Defaults to `False`.
194 | wiener_win_len (int): The size of the excerpts (number of frames) on
195 | which to apply filtering independently. This means assuming
196 | time varying stereo models and localization of sources.
197 | None means not batching but using the whole signal. It comes at the
198 | price of a much larger memory usage.
199 | Defaults to `300`
200 | device (str): torch device, defaults to `cpu`
201 | pretrained (bool): determines if loading pre-trained weights
202 | filterbank (str): filterbank implementation method.
203 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster
204 | compared to `asteroid` on large FFT sizes such as 4096. However,
205 |             asteroid's stft can be exported to onnx, which makes it practical
206 | for deployment.
207 | """
208 | model_path = Path(model_str_or_path).expanduser()
209 |
210 | # when path exists, we assume its a custom model saved locally
211 | if model_path.exists():
212 | if targets is None:
213 | raise UserWarning("For custom models, please specify the targets")
214 |
215 | target_models = load_target_models(targets=targets, model_str_or_path=model_path, pretrained=pretrained)
216 |
217 | with open(Path(model_path, "separator.json"), "r") as stream:
218 | enc_conf = json.load(stream)
219 |
220 | separator = model.Separator(
221 | target_models=target_models,
222 | niter=niter,
223 | residual=residual,
224 | wiener_win_len=wiener_win_len,
225 | sample_rate=enc_conf["sample_rate"],
226 | n_fft=enc_conf["nfft"],
227 | n_hop=enc_conf["nhop"],
228 | nb_channels=enc_conf["nb_channels"],
229 | filterbank=filterbank,
230 | ).to(device)
231 |
232 | # otherwise we load the separator from torchhub
233 | else:
234 | hub_loader = getattr(openunmix, model_str_or_path)
235 | separator = hub_loader(
236 | targets=targets,
237 | device=device,
238 | pretrained=True,
239 | niter=niter,
240 | residual=residual,
241 | wiener_win_len=wiener_win_len,
242 | filterbank=filterbank,
243 | )
244 |
245 | return separator
246 |
247 |
248 | def preprocess(
249 | audio: torch.Tensor,
250 | rate: Optional[float] = None,
251 | model_rate: Optional[float] = None,
252 | ) -> torch.Tensor:
253 | """
254 | From an input tensor, convert it to a tensor of shape
255 | shape=(nb_samples, nb_channels, nb_timesteps). This includes:
256 | - if input is 1D, adding the samples and channels dimensions.
257 | - if input is 2D
258 | o and the smallest dimension is 1 or 2, adding the samples one.
259 | o and all dimensions are > 2, assuming the smallest is the samples
260 | one, and adding the channel one
261 | - at the end, if the number of channels is greater than the number
262 | of time steps, swap those two.
263 | - resampling to target rate if necessary
264 |
265 | Args:
266 | audio (Tensor): input waveform
267 | rate (float): sample rate for the audio
268 | model_rate (float): sample rate for the model
269 |
270 | Returns:
271 | Tensor: [shape=(nb_samples, nb_channels=2, nb_timesteps)]
272 | """
273 | shape = torch.as_tensor(audio.shape, device=audio.device)
274 |
275 | if len(shape) == 1:
276 | # assuming only time dimension is provided.
277 | audio = audio[None, None, ...]
278 | elif len(shape) == 2:
279 | if shape.min() <= 2:
280 | # assuming sample dimension is missing
281 | audio = audio[None, ...]
282 | else:
283 | # assuming channel dimension is missing
284 | audio = audio[:, None, ...]
285 | if audio.shape[1] > audio.shape[2]:
286 | # swapping channel and time
287 | audio = audio.transpose(1, 2)
288 | if audio.shape[1] > 2:
289 | warnings.warn("Channel count > 2!. Only the first two channels " "will be processed!")
290 | audio = audio[..., :2]
291 |
292 | if audio.shape[1] == 1:
293 | # if we have mono, we duplicate it to get stereo
294 | audio = torch.repeat_interleave(audio, 2, dim=1)
295 |
296 | if rate != model_rate:
297 | warnings.warn("resample to model sample rate")
298 | # we have to resample to model samplerate if needed
299 | # this makes sure we resample input only once
300 | resampler = torchaudio.transforms.Resample(
301 | orig_freq=rate, new_freq=model_rate, resampling_method="sinc_interpolation"
302 | ).to(audio.device)
303 | audio = resampler(audio)
304 | return audio
305 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import time
4 | from pathlib import Path
5 | import tqdm
6 | import json
7 | import sklearn.preprocessing
8 | import numpy as np
9 | import random
10 | from git import Repo
11 | import os
12 | import copy
13 | import torchaudio
14 |
15 | from openunmix import data
16 | from openunmix import model
17 | from openunmix import utils
18 | from openunmix import transforms
19 |
20 | tqdm.monitor_interval = 0
21 |
22 |
23 | def train(args, unmix, encoder, device, train_sampler, optimizer):
24 | losses = utils.AverageMeter()
25 | unmix.train()
26 | pbar = tqdm.tqdm(train_sampler, disable=args.quiet)
27 | for x, y in pbar:
28 | pbar.set_description("Training batch")
29 | x, y = x.to(device), y.to(device)
30 | optimizer.zero_grad()
31 | X = encoder(x)
32 | Y_hat = unmix(X)
33 | Y = encoder(y)
34 | loss = torch.nn.functional.mse_loss(Y_hat, Y)
35 | loss.backward()
36 | optimizer.step()
37 | losses.update(loss.item(), Y.size(1))
38 | pbar.set_postfix(loss="{:.3f}".format(losses.avg))
39 | return losses.avg
40 |
41 |
42 | def valid(args, unmix, encoder, device, valid_sampler):
43 | losses = utils.AverageMeter()
44 | unmix.eval()
45 | with torch.no_grad():
46 | for x, y in valid_sampler:
47 | x, y = x.to(device), y.to(device)
48 | X = encoder(x)
49 | Y_hat = unmix(X)
50 | Y = encoder(y)
51 | loss = torch.nn.functional.mse_loss(Y_hat, Y)
52 | losses.update(loss.item(), Y.size(1))
53 | return losses.avg
54 |
55 |
56 | def get_statistics(args, encoder, dataset):
57 | encoder = copy.deepcopy(encoder).to("cpu")
58 | scaler = sklearn.preprocessing.StandardScaler()
59 |
60 | dataset_scaler = copy.deepcopy(dataset)
61 | if isinstance(dataset_scaler, data.SourceFolderDataset):
62 | dataset_scaler.random_chunks = False
63 | else:
64 | dataset_scaler.random_chunks = False
65 | dataset_scaler.seq_duration = None
66 |
67 | dataset_scaler.samples_per_track = 1
68 | dataset_scaler.augmentations = None
69 | dataset_scaler.random_track_mix = False
70 | dataset_scaler.random_interferer_mix = False
71 |
72 | pbar = tqdm.tqdm(range(len(dataset_scaler)), disable=args.quiet)
73 | for ind in pbar:
74 | x, y = dataset_scaler[ind]
75 | pbar.set_description("Compute dataset statistics")
76 | # downmix to mono channel
77 | X = encoder(x[None, ...]).mean(1, keepdim=False).permute(0, 2, 1)
78 |
79 | scaler.partial_fit(np.squeeze(X))
80 |
81 |     # set initial input scaler values
82 | std = np.maximum(scaler.scale_, 1e-4 * np.max(scaler.scale_))
83 | return scaler.mean_, std
84 |
85 |
86 | def main():
87 | parser = argparse.ArgumentParser(description="Open Unmix Trainer")
88 |
89 | # which target do we want to train?
90 | parser.add_argument(
91 | "--target",
92 | type=str,
93 | default="vocals",
94 | help="target source (will be passed to the dataset)",
95 | )
96 |
97 | # Dataset paramaters
98 | parser.add_argument(
99 | "--dataset",
100 | type=str,
101 | default="musdb",
102 | choices=[
103 | "musdb",
104 | "aligned",
105 | "sourcefolder",
106 | "trackfolder_var",
107 | "trackfolder_fix",
108 | ],
109 | help="Name of the dataset.",
110 | )
111 | parser.add_argument("--root", type=str, help="root path of dataset")
112 | parser.add_argument(
113 | "--output",
114 | type=str,
115 | default="open-unmix",
116 | help="provide output path base folder name",
117 | )
118 | parser.add_argument("--model", type=str, help="Name or path of pretrained model to fine-tune")
119 | parser.add_argument("--checkpoint", type=str, help="Path of checkpoint to resume training")
120 | parser.add_argument(
121 | "--audio-backend",
122 | type=str,
123 | default="soundfile",
124 | help="Set torchaudio backend (`sox_io` or `soundfile`",
125 | )
126 |
127 | # Training Parameters
128 | parser.add_argument("--epochs", type=int, default=1000)
129 | parser.add_argument("--batch-size", type=int, default=16)
130 | parser.add_argument("--lr", type=float, default=0.001, help="learning rate, defaults to 1e-3")
131 | parser.add_argument(
132 | "--patience",
133 | type=int,
134 | default=140,
135 | help="maximum number of train epochs (default: 140)",
136 | )
137 | parser.add_argument(
138 | "--lr-decay-patience",
139 | type=int,
140 | default=80,
141 | help="lr decay patience for plateau scheduler",
142 | )
143 | parser.add_argument(
144 | "--lr-decay-gamma",
145 | type=float,
146 | default=0.3,
147 | help="gamma of learning rate scheduler decay",
148 | )
149 | parser.add_argument("--weight-decay", type=float, default=0.00001, help="weight decay")
150 | parser.add_argument(
151 | "--seed", type=int, default=42, metavar="S", help="random seed (default: 42)"
152 | )
153 |
154 | # Model Parameters
155 | parser.add_argument(
156 | "--seq-dur",
157 | type=float,
158 | default=6.0,
159 | help="Sequence duration in seconds" "value of <=0.0 will use full/variable length",
160 | )
161 | parser.add_argument(
162 | "--unidirectional",
163 | action="store_true",
164 | default=False,
165 | help="Use unidirectional LSTM",
166 | )
167 | parser.add_argument("--nfft", type=int, default=4096, help="STFT fft size and window size")
168 | parser.add_argument("--nhop", type=int, default=1024, help="STFT hop size")
169 | parser.add_argument(
170 | "--hidden-size",
171 | type=int,
172 | default=512,
173 | help="hidden size parameter of bottleneck layers",
174 | )
175 | parser.add_argument(
176 | "--bandwidth", type=int, default=16000, help="maximum model bandwidth in herz"
177 | )
178 | parser.add_argument(
179 | "--nb-channels",
180 | type=int,
181 | default=2,
182 | help="set number of channels for model (1, 2)",
183 | )
184 | parser.add_argument(
185 | "--nb-workers", type=int, default=0, help="Number of workers for dataloader."
186 | )
187 | parser.add_argument(
188 | "--debug",
189 | action="store_true",
190 | default=False,
191 | help="Speed up training init for dev purposes",
192 | )
193 |
194 | # Misc Parameters
195 | parser.add_argument(
196 | "--quiet",
197 | action="store_true",
198 | default=False,
199 | help="less verbose during training",
200 | )
201 | parser.add_argument(
202 | "--no-cuda", action="store_true", default=False, help="disables CUDA training"
203 | )
204 |
205 | args, _ = parser.parse_known_args()
206 |
207 | torchaudio.set_audio_backend(args.audio_backend)
208 | use_cuda = not args.no_cuda and torch.cuda.is_available()
209 | print("Using GPU:", use_cuda)
210 | dataloader_kwargs = {"num_workers": args.nb_workers, "pin_memory": True} if use_cuda else {}
211 |
212 | repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
213 | repo = Repo(repo_dir)
214 | commit = repo.head.commit.hexsha[:7]
215 |
216 |     # set random seeds for reproducibility
217 | torch.manual_seed(args.seed)
218 | random.seed(args.seed)
219 |
220 | device = torch.device("cuda" if use_cuda else "cpu")
221 |
222 | train_dataset, valid_dataset, args = data.load_datasets(parser, args)
223 |
224 | # create output dir if not exist
225 | target_path = Path(args.output)
226 | target_path.mkdir(parents=True, exist_ok=True)
227 |
228 | train_sampler = torch.utils.data.DataLoader(
229 | train_dataset, batch_size=args.batch_size, shuffle=True, **dataloader_kwargs
230 | )
231 | valid_sampler = torch.utils.data.DataLoader(valid_dataset, batch_size=1, **dataloader_kwargs)
232 |
233 | stft, _ = transforms.make_filterbanks(
234 | n_fft=args.nfft, n_hop=args.nhop, sample_rate=train_dataset.sample_rate
235 | )
236 | encoder = torch.nn.Sequential(stft, model.ComplexNorm(mono=args.nb_channels == 1)).to(device)
237 |
238 | separator_conf = {
239 | "nfft": args.nfft,
240 | "nhop": args.nhop,
241 | "sample_rate": train_dataset.sample_rate,
242 | "nb_channels": args.nb_channels,
243 | }
244 |
245 | with open(Path(target_path, "separator.json"), "w") as outfile:
246 | outfile.write(json.dumps(separator_conf, indent=4, sort_keys=True))
247 |
248 | if args.checkpoint or args.model or args.debug:
249 | scaler_mean = None
250 | scaler_std = None
251 | else:
252 | scaler_mean, scaler_std = get_statistics(args, encoder, train_dataset)
253 |
254 | max_bin = utils.bandwidth_to_max_bin(train_dataset.sample_rate, args.nfft, args.bandwidth)
255 |
256 | if args.model:
257 | # fine tune model
258 | print(f"Fine-tuning model from {args.model}")
259 | unmix = utils.load_target_models(
260 | args.target, model_str_or_path=args.model, device=device, pretrained=True
261 | )[args.target]
262 | unmix = unmix.to(device)
263 | else:
264 | unmix = model.OpenUnmix(
265 | input_mean=scaler_mean,
266 | input_scale=scaler_std,
267 | nb_bins=args.nfft // 2 + 1,
268 | nb_channels=args.nb_channels,
269 | hidden_size=args.hidden_size,
270 | max_bin=max_bin,
271 | unidirectional=args.unidirectional
272 | ).to(device)
273 |
274 | optimizer = torch.optim.Adam(unmix.parameters(), lr=args.lr, weight_decay=args.weight_decay)
275 |
276 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
277 | optimizer,
278 | factor=args.lr_decay_gamma,
279 | patience=args.lr_decay_patience,
280 | cooldown=10,
281 | )
282 |
283 | es = utils.EarlyStopping(patience=args.patience)
284 |
285 | # if a checkpoint is specified: resume training
286 | if args.checkpoint:
287 | model_path = Path(args.checkpoint).expanduser()
288 | with open(Path(model_path, args.target + ".json"), "r") as stream:
289 | results = json.load(stream)
290 |
291 | target_model_path = Path(model_path, args.target + ".chkpnt")
292 | checkpoint = torch.load(target_model_path, map_location=device)
293 | unmix.load_state_dict(checkpoint["state_dict"], strict=False)
294 | optimizer.load_state_dict(checkpoint["optimizer"])
295 | scheduler.load_state_dict(checkpoint["scheduler"])
296 |         # continue training for args.epochs more epochs, starting from the saved epoch count
297 | t = tqdm.trange(
298 | results["epochs_trained"],
299 | results["epochs_trained"] + args.epochs + 1,
300 | disable=args.quiet,
301 | )
302 | train_losses = results["train_loss_history"]
303 | valid_losses = results["valid_loss_history"]
304 | train_times = results["train_time_history"]
305 | best_epoch = results["best_epoch"]
306 | es.best = results["best_loss"]
307 | es.num_bad_epochs = results["num_bad_epochs"]
308 | # else start optimizer from scratch
309 | else:
310 | t = tqdm.trange(1, args.epochs + 1, disable=args.quiet)
311 | train_losses = []
312 | valid_losses = []
313 | train_times = []
314 | best_epoch = 0
315 |
316 | for epoch in t:
317 | t.set_description("Training epoch")
318 | end = time.time()
319 | train_loss = train(args, unmix, encoder, device, train_sampler, optimizer)
320 | valid_loss = valid(args, unmix, encoder, device, valid_sampler)
321 | scheduler.step(valid_loss)
322 | train_losses.append(train_loss)
323 | valid_losses.append(valid_loss)
324 |
325 | t.set_postfix(train_loss=train_loss, val_loss=valid_loss)
326 |
327 | stop = es.step(valid_loss)
328 |
329 | if valid_loss == es.best:
330 | best_epoch = epoch
331 |
332 | utils.save_checkpoint(
333 | {
334 | "epoch": epoch + 1,
335 | "state_dict": unmix.state_dict(),
336 | "best_loss": es.best,
337 | "optimizer": optimizer.state_dict(),
338 | "scheduler": scheduler.state_dict(),
339 | },
340 | is_best=valid_loss == es.best,
341 | path=target_path,
342 | target=args.target,
343 | )
344 |
345 | # save params
346 | params = {
347 | "epochs_trained": epoch,
348 | "args": vars(args),
349 | "best_loss": es.best,
350 | "best_epoch": best_epoch,
351 | "train_loss_history": train_losses,
352 | "valid_loss_history": valid_losses,
353 | "train_time_history": train_times,
354 | "num_bad_epochs": es.num_bad_epochs,
355 | "commit": commit,
356 | }
357 |
358 | with open(Path(target_path, args.target + ".json"), "w") as outfile:
359 | outfile.write(json.dumps(params, indent=4, sort_keys=True))
360 |
361 | train_times.append(time.time() - end)
362 |
363 | if stop:
364 | print("Apply Early Stopping")
365 | break
366 |
367 |
368 | if __name__ == "__main__":
369 | main()
370 |
--------------------------------------------------------------------------------
/openunmix/model.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Mapping
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch import Tensor
8 | from torch.nn import LSTM, BatchNorm1d, Linear, Parameter
9 | from .filtering import wiener
10 | from .transforms import make_filterbanks, ComplexNorm
11 |
12 |
13 | class OpenUnmix(nn.Module):
14 | """OpenUnmix Core spectrogram based separation module.
15 |
16 | Args:
17 | nb_bins (int): Number of input time-frequency bins (Default: `4096`).
18 | nb_channels (int): Number of input audio channels (Default: `2`).
19 | hidden_size (int): Size for bottleneck layers (Default: `512`).
20 | nb_layers (int): Number of Bi-LSTM layers (Default: `3`).
21 |         unidirectional (bool): Use a causal (unidirectional) model, useful for real-time applications.
22 | (Default `False`)
23 | input_mean (ndarray or None): global data mean of shape `(nb_bins, )`.
24 | Defaults to zeros(nb_bins)
25 |         input_scale (ndarray or None): global data standard deviation of shape `(nb_bins, )`.
26 | Defaults to ones(nb_bins)
27 | max_bin (int or None): Internal frequency bin threshold to
28 | reduce high frequency content. Defaults to `None` which results
29 | in `nb_bins`
30 | """
31 |
32 | def __init__(
33 | self,
34 | nb_bins: int = 4096,
35 | nb_channels: int = 2,
36 | hidden_size: int = 512,
37 | nb_layers: int = 3,
38 | unidirectional: bool = False,
39 | input_mean: Optional[np.ndarray] = None,
40 | input_scale: Optional[np.ndarray] = None,
41 | max_bin: Optional[int] = None,
42 | ):
43 | super(OpenUnmix, self).__init__()
44 |
45 | self.nb_output_bins = nb_bins
46 | if max_bin:
47 | self.nb_bins = max_bin
48 | else:
49 | self.nb_bins = self.nb_output_bins
50 |
51 | self.hidden_size = hidden_size
52 |
53 | self.fc1 = Linear(self.nb_bins * nb_channels, hidden_size, bias=False)
54 |
55 | self.bn1 = BatchNorm1d(hidden_size)
56 |
57 | if unidirectional:
58 | lstm_hidden_size = hidden_size
59 | else:
60 | lstm_hidden_size = hidden_size // 2
61 |
62 | self.lstm = LSTM(
63 | input_size=hidden_size,
64 | hidden_size=lstm_hidden_size,
65 | num_layers=nb_layers,
66 | bidirectional=not unidirectional,
67 | batch_first=False,
68 | dropout=0.4 if nb_layers > 1 else 0,
69 | )
70 |
71 | fc2_hiddensize = hidden_size * 2
72 | self.fc2 = Linear(in_features=fc2_hiddensize, out_features=hidden_size, bias=False)
73 |
74 | self.bn2 = BatchNorm1d(hidden_size)
75 |
76 | self.fc3 = Linear(
77 | in_features=hidden_size,
78 | out_features=self.nb_output_bins * nb_channels,
79 | bias=False,
80 | )
81 |
82 | self.bn3 = BatchNorm1d(self.nb_output_bins * nb_channels)
83 |
84 | if input_mean is not None:
85 | input_mean = torch.from_numpy(-input_mean[: self.nb_bins]).float()
86 | else:
87 | input_mean = torch.zeros(self.nb_bins)
88 |
89 | if input_scale is not None:
90 | input_scale = torch.from_numpy(1.0 / input_scale[: self.nb_bins]).float()
91 | else:
92 | input_scale = torch.ones(self.nb_bins)
93 |
94 | self.input_mean = Parameter(input_mean)
95 | self.input_scale = Parameter(input_scale)
96 |
97 | self.output_scale = Parameter(torch.ones(self.nb_output_bins).float())
98 | self.output_mean = Parameter(torch.ones(self.nb_output_bins).float())
99 |
100 | def freeze(self):
101 | # set all parameters as not requiring gradient, more RAM-efficient
102 | # at test time
103 | for p in self.parameters():
104 | p.requires_grad = False
105 | self.eval()
106 |
107 | def forward(self, x: Tensor) -> Tensor:
108 | """
109 | Args:
110 | x: input spectrogram of shape
111 | `(nb_samples, nb_channels, nb_bins, nb_frames)`
112 |
113 | Returns:
114 | Tensor: filtered spectrogram of shape
115 | `(nb_samples, nb_channels, nb_bins, nb_frames)`
116 | """
117 |
118 |         # permute so that the frame (sequence) dimension comes first for the lstm
119 | x = x.permute(3, 0, 1, 2)
120 | # get current spectrogram shape
121 | nb_frames, nb_samples, nb_channels, nb_bins = x.data.shape
122 |
123 | mix = x.detach().clone()
124 |
125 | # crop
126 | x = x[..., : self.nb_bins]
127 | # shift and scale input to mean=0 std=1 (across all bins)
128 | x = x + self.input_mean
129 | x = x * self.input_scale
130 |
131 | # to (nb_frames*nb_samples, nb_channels*nb_bins)
132 | # and encode to (nb_frames*nb_samples, hidden_size)
133 | x = self.fc1(x.reshape(-1, nb_channels * self.nb_bins))
134 | # normalize every instance in a batch
135 | x = self.bn1(x)
136 | x = x.reshape(nb_frames, nb_samples, self.hidden_size)
137 |         # squash range to [-1, 1]
138 | x = torch.tanh(x)
139 |
140 | # apply 3-layers of stacked LSTM
141 | lstm_out = self.lstm(x)
142 |
143 | # lstm skip connection
144 | x = torch.cat([x, lstm_out[0]], -1)
145 |
146 | # first dense stage + batch norm
147 | x = self.fc2(x.reshape(-1, x.shape[-1]))
148 | x = self.bn2(x)
149 |
150 | x = F.relu(x)
151 |
152 |         # second dense stage + batch norm
153 | x = self.fc3(x)
154 | x = self.bn3(x)
155 |
156 | # reshape back to original dim
157 | x = x.reshape(nb_frames, nb_samples, nb_channels, self.nb_output_bins)
158 |
159 | # apply output scaling
160 | x *= self.output_scale
161 | x += self.output_mean
162 |
163 | # since our output is non-negative, we can apply RELU
164 | x = F.relu(x) * mix
165 | # permute back to (nb_samples, nb_channels, nb_bins, nb_frames)
166 | return x.permute(1, 2, 3, 0)
167 |
168 |
169 | class Separator(nn.Module):
170 | """
171 | Separator class to encapsulate all the stereo filtering
172 | as a torch Module, to enable end-to-end learning.
173 |
174 | Args:
175 |         target_models (dict of str: nn.Module): dictionary of target models,
176 | the spectrogram models to be used by the Separator.
177 | niter (int): Number of EM steps for refining initial estimates in a
178 | post-processing stage. Zeroed if only one target is estimated.
179 | defaults to `1`.
180 | residual (bool): adds an additional residual target, obtained by
181 | subtracting the other estimated targets from the mixture,
182 | before any potential EM post-processing.
183 | Defaults to `False`.
184 | wiener_win_len (int or None): The size of the excerpts
185 | (number of frames) on which to apply filtering
186 | independently. This means assuming time varying stereo models and
187 | localization of sources.
188 | None means not batching but using the whole signal. It comes at the
189 | price of a much larger memory usage.
190 | filterbank (str): filterbank implementation method.
191 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster
192 | compared to `asteroid` on large FFT sizes such as 4096. However,
193 |             asteroid's STFT can be exported to ONNX, which makes it practical
194 | for deployment.
195 | """
196 |
197 | def __init__(
198 | self,
199 | target_models: Mapping[str, nn.Module],
200 | niter: int = 0,
201 | softmask: bool = False,
202 | residual: bool = False,
203 | sample_rate: float = 44100.0,
204 | n_fft: int = 4096,
205 | n_hop: int = 1024,
206 | nb_channels: int = 2,
207 | wiener_win_len: Optional[int] = 300,
208 | filterbank: str = "torch",
209 | ):
210 | super(Separator, self).__init__()
211 |
212 | # saving parameters
213 | self.niter = niter
214 | self.residual = residual
215 | self.softmask = softmask
216 | self.wiener_win_len = wiener_win_len
217 |
218 | self.stft, self.istft = make_filterbanks(
219 | n_fft=n_fft,
220 | n_hop=n_hop,
221 | center=True,
222 | method=filterbank,
223 | sample_rate=sample_rate,
224 | )
225 | self.complexnorm = ComplexNorm(mono=nb_channels == 1)
226 |
227 | # registering the targets models
228 | self.target_models = nn.ModuleDict(target_models)
229 |         # workaround until https://github.com/pytorch/pytorch/issues/38963 is resolved
230 | self.nb_targets = len(self.target_models)
231 | # get the sample_rate as the sample_rate of the first model
232 | # (tacitly assume it's the same for all targets)
233 | self.register_buffer("sample_rate", torch.as_tensor(sample_rate))
234 |
235 | def freeze(self):
236 | # set all parameters as not requiring gradient, more RAM-efficient
237 | # at test time
238 | for p in self.parameters():
239 | p.requires_grad = False
240 | self.eval()
241 |
242 | def forward(self, audio: Tensor) -> Tensor:
243 | """Performing the separation on audio input
244 |
245 | Args:
246 | audio (Tensor): [shape=(nb_samples, nb_channels, nb_timesteps)]
247 | mixture audio waveform
248 |
249 | Returns:
250 | Tensor: stacked tensor of separated waveforms
251 | shape `(nb_samples, nb_targets, nb_channels, nb_timesteps)`
252 | """
253 |
254 | nb_sources = self.nb_targets
255 | nb_samples = audio.shape[0]
256 |
257 | # getting the STFT of mix:
258 | # (nb_samples, nb_channels, nb_bins, nb_frames, 2)
259 | mix_stft = self.stft(audio)
260 | X = self.complexnorm(mix_stft)
261 |
262 | # initializing spectrograms variable
263 | spectrograms = torch.zeros(X.shape + (nb_sources,), dtype=audio.dtype, device=X.device)
264 |
265 | for j, (target_name, target_module) in enumerate(self.target_models.items()):
266 | # apply current model to get the source spectrogram
267 | target_spectrogram = target_module(X.detach().clone())
268 | spectrograms[..., j] = target_spectrogram
269 |
270 | # transposing it as
271 | # (nb_samples, nb_frames, nb_bins,{1,nb_channels}, nb_sources)
272 | spectrograms = spectrograms.permute(0, 3, 2, 1, 4)
273 |
274 | # rearranging it into:
275 | # (nb_samples, nb_frames, nb_bins, nb_channels, 2) to feed
276 | # into filtering methods
277 | mix_stft = mix_stft.permute(0, 3, 2, 1, 4)
278 |
279 | # create an additional target if we need to build a residual
280 | if self.residual:
281 | # we add an additional target
282 | nb_sources += 1
283 |
284 | if nb_sources == 1 and self.niter > 0:
285 | raise Exception(
286 | "Cannot use EM if only one target is estimated."
287 | "Provide two targets or create an additional "
288 | "one with `--residual`"
289 | )
290 |
291 | nb_frames = spectrograms.shape[1]
292 | targets_stft = torch.zeros(mix_stft.shape + (nb_sources,), dtype=audio.dtype, device=mix_stft.device)
293 | for sample in range(nb_samples):
294 | pos = 0
295 | if self.wiener_win_len:
296 | wiener_win_len = self.wiener_win_len
297 | else:
298 | wiener_win_len = nb_frames
299 | while pos < nb_frames:
300 | cur_frame = torch.arange(pos, min(nb_frames, pos + wiener_win_len))
301 | pos = int(cur_frame[-1]) + 1
302 |
303 | targets_stft[sample, cur_frame] = wiener(
304 | spectrograms[sample, cur_frame],
305 | mix_stft[sample, cur_frame],
306 | self.niter,
307 | softmask=self.softmask,
308 | residual=self.residual,
309 | )
310 |
311 | # getting to (nb_samples, nb_targets, channel, fft_size, n_frames, 2)
312 | targets_stft = targets_stft.permute(0, 5, 3, 2, 1, 4).contiguous()
313 |
314 | # inverse STFT
315 | estimates = self.istft(targets_stft, length=audio.shape[2])
316 |
317 | return estimates
318 |
319 | def to_dict(self, estimates: Tensor, aggregate_dict: Optional[dict] = None) -> dict:
320 | """Convert estimates as stacked tensor to dictionary
321 |
322 | Args:
323 | estimates (Tensor): separated targets of shape
324 | (nb_samples, nb_targets, nb_channels, nb_timesteps)
325 | aggregate_dict (dict or None)
326 |
327 | Returns:
328 | (dict of str: Tensor):
329 | """
330 | estimates_dict = {}
331 | for k, target in enumerate(self.target_models):
332 | estimates_dict[target] = estimates[:, k, ...]
333 |
334 | # in the case of residual, we added another source
335 | if self.residual:
336 | estimates_dict["residual"] = estimates[:, -1, ...]
337 |
338 | if aggregate_dict is not None:
339 | new_estimates = {}
340 | for key in aggregate_dict:
341 | new_estimates[key] = torch.tensor(0.0)
342 | for target in aggregate_dict[key]:
343 | new_estimates[key] = new_estimates[key] + estimates_dict[target]
344 | estimates_dict = new_estimates
345 | return estimates_dict
346 |
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
1 | # _Open-Unmix_ for PyTorch: end-to-end torch branch
2 |
3 | [](https://joss.theoj.org/papers/571753bc54c5d6dd36382c3d801de41d) [](https://paperswithcode.com/sota/music-source-separation-on-musdb18?p=open-unmix-a-reference-implementation-for)
4 |
5 | [](https://colab.research.google.com/drive/1mijF0zGWxN-KaxTnd0q6hayAlrID5fEQ) [](https://gitter.im/sigsep/open-unmix?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) [](https://groups.google.com/forum/#!forum/open-unmix)
6 |
7 | [](https://travis-ci.com/sigsep/open-unmix-pytorch) [](https://cloud.docker.com/u/faroit/repository/docker/faroit/open-unmix-pytorch)
8 |
9 | This repository contains the PyTorch (1.0+) implementation of __Open-Unmix__, a deep neural network reference implementation for music source separation, applicable for researchers, audio engineers and artists. __Open-Unmix__ provides ready-to-use models that allow users to separate pop music into four stems: __vocals__, __drums__, __bass__ and the remaining __other__ instruments. The models were pre-trained on the [MUSDB18](https://sigsep.github.io/datasets/musdb.html) dataset. See details at [apply pre-trained model](#getting-started).
10 |
11 | ## News:
12 |
13 | - 06/05/2020: We also added a pre-trained speech enhancement model (`umxse`) provided by Sony. For more information we refer [to this site](https://sigsep.github.io/open-unmix/se)
14 |
15 | __Related Projects:__ open-unmix-pytorch | [open-unmix-nnabla](https://github.com/sigsep/open-unmix-nnabla) | [musdb](https://github.com/sigsep/sigsep-mus-db) | [museval](https://github.com/sigsep/sigsep-mus-eval) | [norbert](https://github.com/sigsep/norbert)
16 |
17 | ## The Model for one source
18 |
19 | 
20 |
21 | To perform separation into multiple sources, _Open-unmix_ comprises multiple models that are trained for each particular target. While this makes the training less comfortable, it allows great flexibility to customize the training data for each target source.
22 |
23 | Each _Open-Unmix_ source model is based on a three-layer bidirectional deep LSTM. The model learns to predict the magnitude spectrogram of a target source, like _vocals_, from the magnitude spectrogram of a mixture input. Internally, the prediction is obtained by applying a mask on the input. The model is optimized in the magnitude domain using mean squared error.
24 |
25 | ### Input Stage
26 |
27 | __Open-Unmix__ operates in the time-frequency domain to perform its prediction. The input of the model is either:
28 |
29 | * __A time domain__ signal tensor of shape `(nb_samples, nb_channels, nb_timesteps)`, where `nb_samples` are the samples in a batch, `nb_channels` is 1 or 2 for mono or stereo audio, respectively, and `nb_timesteps` is the number of audio samples in the recording.
30 |
31 | In that case, the model computes spectrograms with `torch.STFT` on the fly.
32 |
33 | * Alternatively _open-unmix_ also takes **magnitude spectrograms** directly (e.g. when pre-computed and loaded from disk).
34 |
35 | In that case, the input is of shape `(nb_frames, nb_samples, nb_channels, nb_bins)`, where `nb_frames` and `nb_bins` are the time and frequency-dimensions of a Short-Time-Fourier-Transform.
36 |
37 | The input spectrogram is _standardized_ using the global mean and standard deviation for every frequency bin across all frames. Furthermore, we apply batch normalization in multiple stages of the model to make the training more robust against gain variation.
38 |
39 | ### Dimensionality reduction
40 |
41 | The LSTM is not operating on the original input spectrogram resolution. Instead, in the first step after the normalization, the network learns to compress the frequency and channel axes of the input to reduce redundancy and make the model converge faster.
42 |
43 | ### Bidirectional-LSTM
44 |
45 | The core of __open-unmix__ is a three-layer bidirectional [LSTM network](https://dl.acm.org/citation.cfm?id=1246450). Due to its recurrent nature, the model can be trained and evaluated on audio signals of arbitrary length. Since the model takes information from past and future simultaneously, it cannot be used in an online/real-time manner.
46 | A uni-directional model can easily be trained as described [here](docs/training.md).
47 |
48 | ### Output Stage
49 |
50 | After applying the LSTM, the signal is decoded back to its original input dimensionality. In the last steps, the output is multiplied with the input magnitude spectrogram, so that the model is asked to learn a mask.
51 |
52 | ## Putting source models together: the `Separator`
53 |
54 | For inference, this branch provides a `Separator` pytorch Module that puts together one _Open-unmix_ model for each desired target and combines their outputs through a multichannel generalized Wiener filter, before applying inverse STFTs using `torchaudio`.
55 | The filtering is a rewriting in torch of the [numpy implementation](https://github.com/sigsep/norbert) used in the main branch.
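To make the wiring concrete, here is a minimal sketch of how such a `Separator` could be assembled from per-target spectrogram models. The untrained `OpenUnmix` instances below are stand-ins for illustration only; the constructor arguments mirror the signatures found in `openunmix/model.py`:

```python
import torch
from openunmix.model import OpenUnmix, Separator

# untrained stand-ins for per-target spectrogram models (illustration only)
vocals = OpenUnmix(nb_bins=4096 // 2 + 1, nb_channels=2, hidden_size=512)
drums = OpenUnmix(nb_bins=4096 // 2 + 1, nb_channels=2, hidden_size=512)

separator = Separator(
    target_models={"vocals": vocals, "drums": drums},
    niter=1,            # one EM pass of the multichannel Wiener filter
    n_fft=4096,
    n_hop=1024,
    nb_channels=2,
    sample_rate=44100.0,
)

audio = torch.rand(1, 2, 44100)   # (nb_samples, nb_channels, nb_timesteps)
estimates = separator(audio)      # (nb_samples, nb_targets, nb_channels, nb_timesteps)
```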
56 |
57 |
58 | ## Getting started
59 |
60 | ### Installation
61 |
62 | For installation we recommend using the [Anaconda](https://anaconda.org/) python distribution. To create a conda environment for _open-unmix_, simply run:
63 |
64 | `conda env create -f environment-X.yml` where `X` is either [`cpu-linux`, `gpu-linux-cuda10`, `cpu-osx`], depending on your system. For now, we haven't tested windows support.
65 |
66 | ### Using Docker
67 |
68 | We also provide a docker container as an alternative to Anaconda. That way, separation of a local track in `~/Music/track1.wav` can be performed in a single line:
69 |
70 | ```
71 | docker run -v ~/Music/:/data -it faroit/open-unmix-pytorch python test.py "/data/track1.wav" --outdir /data/track1
72 | ```
73 |
74 | ### Applying pre-trained models on audio files
75 |
76 | We provide two pre-trained music separation models:
77 |
78 | * __`umxhq` (default)__ trained on [MUSDB18-HQ](https://sigsep.github.io/datasets/musdb.html#uncompressed-wav), which comprises the same tracks as MUSDB18 but uncompressed, yielding a full bandwidth of 22050 Hz.
79 |
80 | [](https://doi.org/10.5281/zenodo.3370489)
81 |
82 | * __`umx`__ is trained on the regular [MUSDB18](https://sigsep.github.io/datasets/musdb.html#compressed-stems) which is bandwidth limited to 16 kHz due to AAC compression. This model should be used for comparison with other (older) methods for evaluation in [SiSEC18](https://sisec18.unmix.app).
83 |
84 | [](https://doi.org/10.5281/zenodo.3370486)
85 |
86 | Furthermore, we provide a model for speech enhancement trained by [Sony Corporation](link)
87 |
88 | * __`umxse`__ speech enhancement model is trained on the 28-speaker version of the [Voicebank+DEMAND corpus](https://datashare.is.ed.ac.uk/handle/10283/1942?show=full).
89 |
90 | [](https://doi.org/10.5281/zenodo.3786908)
91 |
92 | To separate audio files (`wav`, `flac`, `ogg` - but not `mp3`) just run:
93 |
94 | ```bash
95 | umx input_file.wav --model umxhq
96 | ```
97 |
98 | A more detailed list of the parameters used for the separation is given in the [inference.md](/docs/inference.md) document.
99 | We provide a [jupyter notebook on google colab](https://colab.research.google.com/drive/1mijF0zGWxN-KaxTnd0q6hayAlrID5fEQ) to
100 | experiment with open-unmix and to separate files online without any installation setup.
101 |
102 | ### Interface with the separator from python via torch.hub
103 |
104 | A pre-trained `Separator` can be loaded from pytorch based code using torch.hub.load:
105 |
106 | ```python
107 | separator = torch.hub.load('sigsep/open-unmix-pytorch', 'umxhq')
108 | ```
109 |
110 | This object may then simply be used to separate some `audio` (a `torch.Tensor` of shape `(nb_samples, nb_channels, nb_timesteps)`), sampled at a sampling rate `rate`, through:
111 |
112 | ```python
113 | estimates = separator(audio)
114 | ```
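The returned tensor stacks all targets along one dimension; it can be mapped back to target names with the separator's `to_dict` helper (a short usage sketch):

```python
estimates_dict = separator.to_dict(estimates)
vocals_estimate = estimates_dict["vocals"]  # (nb_samples, nb_channels, nb_timesteps)
```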
115 |
116 | ### Load user-trained models (only music separation models)
117 |
118 | When a path instead of a model-name is provided to `--model` the pre-trained model will be loaded from disk.
119 |
120 | ```bash
121 | umx --model /path/to/model/root/directory input_file.wav
122 | ```
123 |
124 | Note that the model directory usually contains individual models for each target, and separation is performed using all of them. E.g. if the directory contains `vocals` and `drums` models, two output files are generated, unless the `--residual-model` option is selected, in which case an additional source will be produced, containing an estimate of everything in the mixture that is not one of the targets.
125 |
126 | ### Evaluation using `museval`
127 |
128 | To perform evaluation in comparison to other SISEC systems, you would need to install the `museval` package using
129 |
130 | ```
131 | pip install museval
132 | ```
133 |
134 | and then run the evaluation using
135 |
136 | `python -m openunmix.evaluate --outdir /path/to/musdb/estimates --evaldir /path/to/museval/results`
137 |
138 | ### Results compared to SiSEC 2018 (SDR/Vocals)
139 |
140 | Open-Unmix yields state-of-the-art results compared to participants from [SiSEC 2018](https://sisec18.unmix.app/#/methods). The performance of `UMXHQ` and `UMX` is almost identical since both were evaluated on compressed STEMS.
141 |
142 | 
143 |
144 | Note that
145 |
146 | 1. [`STL1`, `TAK2`, `TAK3`, `TAU1`, `UHL3`, `UMXHQ`] were omitted as they were _not_ trained on only _MUSDB18_.
147 | 2. [`HEL1`, `TAK1`, `UHL1`, `UHL2`] are not open-source.
148 |
149 | #### Scores (Median of frames, Median of tracks)
150 |
151 | |target|SDR |SIR | SAR | ISR | SDR | SIR | SAR | ISR |
152 | |------|-----|-----|-----|-----|-----|-----|-----|-----|
153 | |`model`|UMX |UMX |UMX |UMX |UMXHQ|UMXHQ|UMXHQ|UMXHQ|
154 | |vocals|6.32 |13.33| 6.52|11.93| 6.25|12.95| 6.50|12.70|
155 | |bass |5.23 |10.93| 6.34| 9.23| 5.07|10.35| 6.02| 9.71|
156 | |drums |5.73 |11.12| 6.02|10.51| 6.04|11.65| 5.93|11.17|
157 | |other |4.02 |6.59 | 4.74| 9.31| 4.28| 7.10| 4.62| 8.78|
158 |
159 | ## Training
160 |
161 | Details on the training is provided in a separate document [here](docs/training.md).
162 |
163 | ## Extensions
164 |
165 | Details on how _open-unmix_ can be extended or improved for future research on music separation is described in a separate document [here](docs/extensions.md).
166 |
167 |
168 | ## Design Choices
169 |
170 | We favored simplicity over performance to promote clarity of the code. The rationale is to have __open-unmix__ serve as a __baseline__ for future research while performance still meets current state-of-the-art (see [Evaluation](#Evaluation)). The results are comparable to or better than those of `UHL1`/`UHL2`, which obtained the best performance over all systems trained on MUSDB18 in the [SiSEC 2018 Evaluation campaign](https://sisec18.unmix.app).
171 | We designed the code to allow researchers to reproduce existing results, quickly develop new architectures and add their own data for training and testing. We favored framework-specific implementations instead of having a monolithic repository with common code for all frameworks.
172 |
173 | ## How to contribute
174 |
175 | _open-unmix_ is a community-focused project; we therefore encourage the community to submit bug-fixes and requests for technical support through [github issues](https://github.com/sigsep/open-unmix-pytorch/issues/new/choose). For more details on how to contribute, please follow our [`CONTRIBUTING.md`](CONTRIBUTING.md). For help and support, please use the gitter chat or the google groups forums.
176 |
177 | ### Authors
178 |
179 | [Fabian-Robert Stöter](https://www.faroit.com/), [Antoine Liutkus](https://github.com/aliutkus), Inria and LIRMM, Montpellier, France
180 |
181 | ## References
182 |
183 | If you use open-unmix for your research – Cite Open-Unmix
184 |
185 | ```latex
186 | @article{stoter19,
187 | author={F.-R. St\\"oter and S. Uhlich and A. Liutkus and Y. Mitsufuji},
188 | title={Open-Unmix - A Reference Implementation for Music Source Separation},
189 | journal={Journal of Open Source Software},
190 | year=2019,
191 | doi = {10.21105/joss.01667},
192 | url = {https://doi.org/10.21105/joss.01667}
193 | }
194 | ```
195 |
196 |
197 |
198 |
199 | If you use the MUSDB dataset for your research - Cite the MUSDB18 Dataset
200 |
201 |
202 | ```latex
203 | @misc{MUSDB18,
204 | author = {Rafii, Zafar and
205 | Liutkus, Antoine and
206 | Fabian-Robert St{\"o}ter and
207 | Mimilakis, Stylianos Ioannis and
208 | Bittner, Rachel},
209 | title = {The {MUSDB18} corpus for music separation},
210 | month = dec,
211 | year = 2017,
212 | doi = {10.5281/zenodo.1117372},
213 | url = {https://doi.org/10.5281/zenodo.1117372}
214 | }
215 | ```
216 |
217 |
218 |
219 |
220 |
221 | If you compare your results with SiSEC 2018 participants - Cite the SiSEC 2018 LVA/ICA Paper
222 |
223 |
224 | ```latex
225 | @inproceedings{SiSEC18,
226 | author="St{\"o}ter, Fabian-Robert and Liutkus, Antoine and Ito, Nobutaka",
227 | title="The 2018 Signal Separation Evaluation Campaign",
228 | booktitle="Latent Variable Analysis and Signal Separation:
229 | 14th International Conference, LVA/ICA 2018, Surrey, UK",
230 | year="2018",
231 | pages="293--305"
232 | }
233 | ```
234 |
235 |
236 |
237 |
238 | ⚠️ Please note that the official acronym for _open-unmix_ is **UMX**.
239 |
240 | ### License
241 |
242 | MIT
243 |
244 | ### Acknowledgements
245 |
246 |
247 |
248 |
249 |
250 |
--------------------------------------------------------------------------------
/openunmix/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | 
3 | Open-Unmix is a deep neural network reference implementation for music source separation, applicable for researchers, audio engineers and artists. Open-Unmix provides ready-to-use models that allow users to separate pop music into four stems: vocals, drums, bass and the remaining other instruments. The models were pre-trained on the MUSDB18 dataset. See details at apply pre-trained model.
4 |
5 | This is the python package API documentation.
6 | Please checkout [the open-unmix website](https://sigsep.github.io/open-unmix) for more information.
7 | """
8 |
9 | from openunmix import utils
10 | import torch.hub
11 |
12 |
13 | def umxse_spec(targets=None, device="cpu", pretrained=True):
14 | target_urls = {
15 | "speech": "https://zenodo.org/records/3786908/files/speech_f5e0d9f9.pth",
16 | "noise": "https://zenodo.org/records/3786908/files/noise_04a6fc2d.pth",
17 | }
18 |
19 | from .model import OpenUnmix
20 |
21 | if targets is None:
22 | targets = ["speech", "noise"]
23 |
24 | # determine the maximum bin count for a 16khz bandwidth model
25 | max_bin = utils.bandwidth_to_max_bin(rate=16000.0, n_fft=1024, bandwidth=16000)
26 |
27 | # load open unmix models speech enhancement models
28 | target_models = {}
29 | for target in targets:
30 | target_unmix = OpenUnmix(nb_bins=1024 // 2 + 1, nb_channels=1, hidden_size=256, max_bin=max_bin)
31 |
32 | # enable centering of stft to minimize reconstruction error
33 | if pretrained:
34 | state_dict = torch.hub.load_state_dict_from_url(target_urls[target], map_location=device)
35 | target_unmix.load_state_dict(state_dict, strict=False)
36 | target_unmix.eval()
37 |
38 | target_unmix.to(device)
39 | target_models[target] = target_unmix
40 | return target_models
41 |
42 |
43 | def umxse(targets=None, residual=False, niter=1, device="cpu", pretrained=True, filterbank="torch", wiener_win_len=300):
44 | """
45 |     Open Unmix Speech Enhancement 1-channel BiLSTM Model
46 | trained on the 28-speaker version of Voicebank+Demand
47 | (Sampling rate: 16kHz)
48 |
49 | Args:
50 | targets (str): select the targets for the source to be separated.
51 | a list including: ['speech', 'noise'].
52 | If you don't pick them all, you probably want to
53 | activate the `residual=True` option.
54 | Defaults to all available targets per model.
55 |         pretrained (bool): If True, returns a model pre-trained on the 28-speaker version of Voicebank+DEMAND
56 | residual (bool): if True, a "garbage" target is created
57 |         niter (int): the number of post-processing iterations, defaults to `1`
58 | device (str): selects device to be used for inference
59 | wiener_win_len (int or None): The size of the excerpts
60 | (number of frames) on which to apply filtering
61 | independently. This means assuming time varying stereo models and
62 | localization of sources.
63 | None means not batching but using the whole signal. It comes at the
64 | price of a much larger memory usage.
65 | filterbank (str): filterbank implementation method.
66 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster
67 | compared to `asteroid` on large FFT sizes such as 4096. However,
68 |             asteroid's STFT can be exported to ONNX, which makes it practical
69 | for deployment.
70 |
71 | Reference:
72 | Uhlich, Stefan, & Mitsufuji, Yuki. (2020).
73 | Open-Unmix for Speech Enhancement (UMX SE).
74 | Zenodo. http://doi.org/10.5281/zenodo.3786908
75 | """
76 | from .model import Separator
77 |
78 | target_models = umxse_spec(targets=targets, device=device, pretrained=pretrained)
79 |
80 | separator = Separator(
81 | target_models=target_models,
82 | niter=niter,
83 | residual=residual,
84 | n_fft=1024,
85 | n_hop=512,
86 | nb_channels=1,
87 | sample_rate=16000.0,
88 | wiener_win_len=wiener_win_len,
89 | filterbank=filterbank,
90 | ).to(device)
91 |
92 | return separator
93 |
94 |
95 | def umxhq_spec(targets=None, device="cpu", pretrained=True):
96 | from .model import OpenUnmix
97 |
98 | # set urls for weights
99 | target_urls = {
100 | "bass": "https://zenodo.org/records/3370489/files/bass-8d85a5bd.pth",
101 | "drums": "https://zenodo.org/records/3370489/files/drums-9619578f.pth",
102 | "other": "https://zenodo.org/records/3370489/files/other-b52fbbf7.pth",
103 | "vocals": "https://zenodo.org/records/3370489/files/vocals-b62c91ce.pth",
104 | }
105 |
106 | if targets is None:
107 | targets = ["vocals", "drums", "bass", "other"]
108 |
109 | # determine the maximum bin count for a 16khz bandwidth model
110 | max_bin = utils.bandwidth_to_max_bin(rate=44100.0, n_fft=4096, bandwidth=16000)
111 |
112 | target_models = {}
113 | for target in targets:
114 | # load open unmix model
115 | target_unmix = OpenUnmix(nb_bins=4096 // 2 + 1, nb_channels=2, hidden_size=512, max_bin=max_bin)
116 |
117 | # enable centering of stft to minimize reconstruction error
118 | if pretrained:
119 | state_dict = torch.hub.load_state_dict_from_url(target_urls[target], map_location=device)
120 | target_unmix.load_state_dict(state_dict, strict=False)
121 | target_unmix.eval()
122 |
123 | target_unmix.to(device)
124 | target_models[target] = target_unmix
125 | return target_models
126 |
127 |
128 | def umxhq(
129 | targets=None,
130 | residual=False,
131 | niter=1,
132 | device="cpu",
133 | pretrained=True,
134 | wiener_win_len=300,
135 | filterbank="torch",
136 | ):
137 | """
138 | Open Unmix 2-channel/stereo BiLSTM Model trained on MUSDB18-HQ
139 |
140 | Args:
141 | targets (str): select the targets for the source to be separated.
142 | a list including: ['vocals', 'drums', 'bass', 'other'].
143 | If you don't pick them all, you probably want to
144 | activate the `residual=True` option.
145 | Defaults to all available targets per model.
146 | pretrained (bool): If True, returns a model pre-trained on MUSDB18-HQ
147 | residual (bool): if True, a "garbage" target is created
148 |         niter (int): the number of post-processing iterations, defaults to `1`
149 | device (str): selects device to be used for inference
150 | wiener_win_len (int or None): The size of the excerpts
151 | (number of frames) on which to apply filtering
152 | independently. This means assuming time varying stereo models and
153 | localization of sources.
154 | None means not batching but using the whole signal. It comes at the
155 | price of a much larger memory usage.
156 | filterbank (str): filterbank implementation method.
157 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster
158 | compared to `asteroid` on large FFT sizes such as 4096. However,
159 |             asteroid's STFT can be exported to ONNX, which makes it practical
160 | for deployment.
161 | """
162 |
163 | from .model import Separator
164 |
165 | target_models = umxhq_spec(targets=targets, device=device, pretrained=pretrained)
166 |
167 | separator = Separator(
168 | target_models=target_models,
169 | niter=niter,
170 | residual=residual,
171 | n_fft=4096,
172 | n_hop=1024,
173 | nb_channels=2,
174 | sample_rate=44100.0,
175 | wiener_win_len=wiener_win_len,
176 | filterbank=filterbank,
177 | ).to(device)
178 |
179 | return separator
180 |
181 |
182 | def umx_spec(targets=None, device="cpu", pretrained=True):
183 | from .model import OpenUnmix
184 |
185 | # set urls for weights
186 | target_urls = {
187 | "bass": "https://zenodo.org/records/3370486/files/bass-646024d3.pth",
188 | "drums": "https://zenodo.org/records/3370486/files/drums-5a48008b.pth",
189 | "other": "https://zenodo.org/records/3370486/files/other-f8e132cc.pth",
190 | "vocals": "https://zenodo.org/records/3370486/files/vocals-c8df74a5.pth",
191 | }
192 |
193 | if targets is None:
194 | targets = ["vocals", "drums", "bass", "other"]
195 |
196 | # determine the maximum bin count for a 16khz bandwidth model
197 | max_bin = utils.bandwidth_to_max_bin(rate=44100.0, n_fft=4096, bandwidth=16000)
198 |
199 | target_models = {}
200 | for target in targets:
201 | # load open unmix model
202 | target_unmix = OpenUnmix(nb_bins=4096 // 2 + 1, nb_channels=2, hidden_size=512, max_bin=max_bin)
203 |
204 | # enable centering of stft to minimize reconstruction error
205 | if pretrained:
206 | state_dict = torch.hub.load_state_dict_from_url(target_urls[target], map_location=device)
207 | target_unmix.load_state_dict(state_dict, strict=False)
208 | target_unmix.eval()
209 |
210 | target_unmix.to(device)
211 | target_models[target] = target_unmix
212 | return target_models
213 |
214 |
215 | def umx(
216 | targets=None,
217 | residual=False,
218 | niter=1,
219 | device="cpu",
220 | pretrained=True,
221 | wiener_win_len=300,
222 | filterbank="torch",
223 | ):
224 | """
225 | Open Unmix 2-channel/stereo BiLSTM Model trained on MUSDB18
226 |
227 | Args:
228 | targets (str): select the targets for the source to be separated.
229 | a list including: ['vocals', 'drums', 'bass', 'other'].
230 | If you don't pick them all, you probably want to
231 | activate the `residual=True` option.
232 | Defaults to all available targets per model.
233 |         pretrained (bool): If True, returns a model pre-trained on MUSDB18
234 | residual (bool): if True, a "garbage" target is created
235 |         niter (int): the number of post-processing iterations, defaults to `1`
236 | device (str): selects device to be used for inference
237 | wiener_win_len (int or None): The size of the excerpts
238 | (number of frames) on which to apply filtering
239 | independently. This means assuming time varying stereo models and
240 | localization of sources.
241 | None means not batching but using the whole signal. It comes at the
242 | price of a much larger memory usage.
243 | filterbank (str): filterbank implementation method.
244 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster
245 | compared to `asteroid` on large FFT sizes such as 4096. However,
246 |             asteroid's STFT can be exported to ONNX, which makes it practical
247 | for deployment.
248 |
249 | """
250 |
251 | from .model import Separator
252 |
253 | target_models = umx_spec(targets=targets, device=device, pretrained=pretrained)
254 | separator = Separator(
255 | target_models=target_models,
256 | niter=niter,
257 | residual=residual,
258 | n_fft=4096,
259 | n_hop=1024,
260 | nb_channels=2,
261 | sample_rate=44100.0,
262 | wiener_win_len=wiener_win_len,
263 | filterbank=filterbank,
264 | ).to(device)
265 |
266 | return separator
267 |
268 |
269 | def umxl_spec(targets=None, device="cpu", pretrained=True):
270 | from .model import OpenUnmix
271 |
272 | # set urls for weights
273 | target_urls = {
274 | "bass": "https://zenodo.org/records/5069601/files/bass-2ca1ce51.pth",
275 | "drums": "https://zenodo.org/records/5069601/files/drums-69e0ebd4.pth",
276 | "other": "https://zenodo.org/records/5069601/files/other-c8c5b3e6.pth",
277 | "vocals": "https://zenodo.org/records/5069601/files/vocals-bccbd9aa.pth",
278 | }
279 |
280 | if targets is None:
281 | targets = ["vocals", "drums", "bass", "other"]
282 |
283 | # determine the maximum bin count for a 16khz bandwidth model
284 | max_bin = utils.bandwidth_to_max_bin(rate=44100.0, n_fft=4096, bandwidth=16000)
285 |
286 | target_models = {}
287 | for target in targets:
288 | # load open unmix model
289 | target_unmix = OpenUnmix(nb_bins=4096 // 2 + 1, nb_channels=2, hidden_size=1024, max_bin=max_bin)
290 |
291 | # enable centering of stft to minimize reconstruction error
292 | if pretrained:
293 | state_dict = torch.hub.load_state_dict_from_url(target_urls[target], map_location=device)
294 | target_unmix.load_state_dict(state_dict, strict=False)
295 | target_unmix.eval()
296 |
297 | target_unmix.to(device)
298 | target_models[target] = target_unmix
299 | return target_models
300 |
301 |
302 | def umxl(
303 | targets=None,
304 | residual=False,
305 | niter=1,
306 | device="cpu",
307 | pretrained=True,
308 | wiener_win_len=300,
309 | filterbank="torch",
310 | ):
311 | """
312 | Open Unmix Extra (UMX-L), 2-channel/stereo BLSTM Model trained on a private dataset
313 | of ~400h of multi-track audio.
314 |
315 |
316 | Args:
317 | targets (str): select the targets for the source to be separated.
318 | a list including: ['vocals', 'drums', 'bass', 'other'].
319 | If you don't pick them all, you probably want to
320 | activate the `residual=True` option.
321 | Defaults to all available targets per model.
322 |         pretrained (bool): If True, returns a model pre-trained on the private ~400h multi-track dataset described above
323 | residual (bool): if True, a "garbage" target is created
324 |         niter (int): the number of post-processing iterations, defaults to `1`
325 | device (str): selects device to be used for inference
326 | wiener_win_len (int or None): The size of the excerpts
327 | (number of frames) on which to apply filtering
328 | independently. This means assuming time varying stereo models and
329 | localization of sources.
330 | None means not batching but using the whole signal. It comes at the
331 | price of a much larger memory usage.
332 | filterbank (str): filterbank implementation method.
333 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster
334 | compared to `asteroid` on large FFT sizes such as 4096. However,
335 |             asteroid's STFT can be exported to ONNX, which makes it practical
336 | for deployment.
337 |
338 | """
339 |
340 | from .model import Separator
341 |
342 | target_models = umxl_spec(targets=targets, device=device, pretrained=pretrained)
343 | separator = Separator(
344 | target_models=target_models,
345 | niter=niter,
346 | residual=residual,
347 | n_fft=4096,
348 | n_hop=1024,
349 | nb_channels=2,
350 | sample_rate=44100.0,
351 | wiener_win_len=wiener_win_len,
352 | filterbank=filterbank,
353 | ).to(device)
354 |
355 | return separator
356 |
--------------------------------------------------------------------------------
/docs/evaluate.html:
--------------------------------------------------------------------------------
1 | <!-- pdoc-generated HTML page "openunmix.evaluate API documentation"; markup omitted -->
--------------------------------------------------------------------------------
/docs/training.md:
--------------------------------------------------------------------------------
1 | # Training Open-Unmix
2 |
3 | > This documentation refers to the standard training procedure for _Open-unmix_, where each target is trained independently. It has not been updated for the end-to-end training capabilities that the `Separator` module allows. Please contribute if you try this.
4 |
5 | Both models, `umxhq` and `umx` that are provided with pre-trained weights, can be trained using the default parameters of the `scripts/train.py` function.
6 |
7 | ## Installation
8 |
9 | The train function is not part of the python package, thus we suggest using [Anaconda](https://anaconda.org/) to install the training requirements, also because the environment allows for reproducible results.
10 |
11 | To create a conda environment for _open-unmix_, simply run:
12 |
13 | `conda env create -f scripts/environment-X.yml` where `X` is either [`cpu-linux`, `gpu-linux-cuda10`, `cpu-osx`], depending on your system. For now, we haven't tested windows support.
14 |
15 | ## Training API
16 |
17 | The [MUSDB18](https://sigsep.github.io/datasets/musdb.html) and [MUSDB18-HQ](https://sigsep.github.io/datasets/musdb.html) are the largest freely available datasets for professionally produced music tracks (~10h duration) of different styles. They come with isolated `drums`, `bass`, `vocals` and `others` stems. _MUSDB18_ contains two subsets: "train", composed of 100 songs, and "test", composed of 50 songs.
18 |
19 | To directly train a vocals model with _open-unmix_, you first need to download one of the datasets and place it _unzipped_ in a directory of your choice (called `root`).
20 |
21 | | Argument | Description | Default |
22 | |----------|-------------|---------|
23 | | `--root ` | path to root of dataset on disk. | `None` |
24 |
25 | Also note that, if `--root` is not specified, we automatically download a 7 second preview version of the MUSDB18 dataset. While this is convenient for testing purposes, we wouldn't recommend actually training your model on it.
26 |
27 | Training can be started using
28 |
29 | ```bash
30 | python train.py --root path/to/musdb18 --target vocals
31 | ```
32 |
33 | Training `MUSDB18` using _open-unmix_ comes with several design decisions that we made as part of our defaults to improve efficiency and performance:
34 |
35 | * __chunking__: we do not feed full audio tracks into _open-unmix_ but instead chunk the audio into 6s excerpts (`--seq-dur 6.0`).
36 | * __balanced track sampling__: to not create a bias for longer audio tracks we randomly yield one track from MUSDB18 and select a random chunk subsequently. In one epoch we select (on average) 64 samples from each track.
37 | * __source augmentation__: we apply random gains between `0.25` and `1.25` to all sources before mixing. Furthermore, we randomly swap the channels of the input mixture (see the sketch after this list).
38 | * __random track mixing__: for a given target we select a _random track_ with replacement. To yield a mixture we draw the interfering sources from different tracks (again with replacement) to increase generalization of the model.
39 | * __fixed validation split__: we provide a fixed validation split of [14 tracks](https://github.com/sigsep/sigsep-mus-db/blob/b283da5b8f24e84172a60a06bb8f3dacd57aa6cd/musdb/configs/mus.yaml#L41). We evaluate on these tracks in full length instead of using chunking to have evaluation as close as possible to the actual test data.
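As a rough illustration of the source augmentations mentioned above, the two shipped augmentations (`gain` and `channelswap`) do something along these lines; the functions below are sketches, not the repository code:

```python
import random
import torch

def random_gain(audio: torch.Tensor, low: float = 0.25, high: float = 1.25) -> torch.Tensor:
    # scale a source excerpt by a random gain before it is mixed
    return audio * random.uniform(low, high)

def random_channel_swap(audio: torch.Tensor) -> torch.Tensor:
    # audio: (nb_channels, nb_timesteps); swap left/right with probability 0.5
    if audio.shape[0] == 2 and random.random() < 0.5:
        return torch.flip(audio, dims=[0])
    return audio
```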
40 |
41 | Some of the parameters for the MUSDB sampling can be controlled using the following arguments:
42 |
43 | | Argument | Description | Default |
44 | |---------------------|-----------------------------------------------|--------------|
45 | | `--is-wav` | loads the decoded WAVs instead of STEMS for faster data loading. See [more details here](https://github.com/sigsep/sigsep-mus-db#using-wav-files-optional). | `False` |
46 | | `--samples-per-track ` | sets the number of samples that are randomly drawn from each track | `64` |
47 | | `--source-augmentations ` | applies augmentations to each audio source before mixing, available augmentations: `[gain, channelswap]`| [gain, channelswap] |
48 |
49 | ## Training and Model Parameters
50 |
51 | An extensive list of additional training parameters allows researchers to quickly try out different parameterizations such as a different FFT size. In the table below, we list the additional training parameters and their default values (used for `umxhq` and `umx`):
52 |
53 | | Argument | Description | Default |
54 | |----------------------------|---------------------------------------------------------------------------------|-----------------|
55 | | `--target ` | name of target source (will be passed to the dataset) | `vocals` |
56 | | `--output ` | path where to save the trained output model as well as checkpoints. | `./open-unmix` |
57 | | `--checkpoint ` | path to checkpoint of target model to resume training. | not set |
58 | | `--model ` | path or str to pretrained target to fine-tune model | not set |
59 | | `--no_cuda` | disable cuda even if available | not set |
60 | | `--epochs ` | Number of epochs to train | `1000` |
61 | | `--batch-size ` | Batch size has influence on memory usage and performance of the LSTM layer | `16` |
62 | | `--patience ` | early stopping patience | `140` |
63 | | `--seq-dur ` | Sequence duration in seconds of chunks taken from the dataset. A value of `<=0.0` results in full/variable length | `6.0` |
64 | | `--unidirectional` | changes the bidirectional LSTM to unidirectional (for real-time applications) | not set |
65 | | `--hidden-size ` | Hidden size parameter of dense bottleneck layers | `512` |
66 | | `--nfft ` | STFT FFT window length in samples | `4096` |
67 | | `--nhop ` | STFT hop length in samples | `1024` |
68 | | `--lr ` | learning rate | `0.001` |
69 | | `--lr-decay-patience ` | learning rate decay patience for plateau scheduler | `80` |
70 | | `--lr-decay-gamma ` | gamma of learning rate plateau scheduler. | `0.3` |
71 | | `--weight-decay ` | weight decay for regularization | `0.00001` |
72 | | `--bandwidth ` | maximum bandwidth in Hertz processed by the LSTM. Input and Output is always full bandwidth! | `16000` |
73 | | `--nb-channels ` | set number of channels for model (1 for mono (a spectral downmix is applied), 2 for stereo) | `2` |
74 | | `--nb-workers ` | Number of (parallel) workers for data-loader, can be safely increased for wav files | `0` |
75 | | `--quiet` | disable print and progress bar during training | not set |
76 | | `--seed ` | Initial seed to set the random initialization | `42` |
77 | | `--audio-backend ` | choose audio loading backend, either `sox` or `soundfile` | `soundfile` for training, `sox` for inference |
78 |
79 | ### Training details of `umxhq`
80 |
81 | The training of `umxhq` took place on Nvidia RTX2080 cards. Equipped with fast SSDs and `--nb-workers 4`, we could utilize around 90% of the GPU, thus training time was around 80 seconds per epoch. We ran four different seeds for each target and selected the model with the lowest validation loss.
82 |
83 | The training and validation loss curves are plotted below:
84 |
85 | 
86 |
87 | ## Other Datasets
88 |
89 | _open-unmix_ uses standard PyTorch [`torch.utils.data.Dataset`](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) classes. The repository comes with __five__ different datasets which cover a wide range of tasks and applications around source separation. Furthermore we also provide a template Dataset if you want to start using your own dataset. The dataset can be selected through a command line argument:
90 |
91 | | Argument | Description | Default |
92 | |----------------------------|------------------------------------------------------------------------|--------------|
93 | | `--dataset ` | Name of the dataset (select from `musdb`, `aligned`, `sourcefolder`, `trackfolder_var`, `trackfolder_fix`) | `musdb` |
94 |
95 | ### `AlignedDataset` (aligned)
96 |
97 | This dataset assumes multiple track folders, where each track includes an input and one output file, directly corresponding to the input and the output of the model.
98 |
99 | This dataset is the most basic of all datasets provided here; due to the minimal amount of
100 | preprocessing, it is also the fastest option. However, it lacks any kind of source augmentations or custom mixing. Instead, it directly uses the target files that are within the folder. The filenames have to be identical for each track. E.g., for the first sample of the training, the input could be `1/mixture.wav` and the output could be `1/vocals.wav`.
101 |
102 | Typical use cases:
103 |
104 | * Source Separation (Mixture -> Target)
105 | * Denoising (Noisy -> Clean)
106 | * Bandwidth Extension (Low Bandwidth -> High Bandwidth)
107 |
108 | #### File Structure
109 |
110 | ```
111 | data/train/1/mixture.wav --> input
112 | data/train/1/vocals.wav ---> output
113 | ...
114 | data/valid/1/mixture.wav --> input
115 | data/valid/1/vocals.wav ---> output
116 |
117 | ```
118 |
119 | #### Parameters
120 |
121 | | Argument | Description | Default |
122 | |----------|-------------|---------|
123 | |`--input-file ` | input file name | `None` |
124 | |`--output-file ` | output file name | `None` |
125 |
126 | #### Example
127 |
128 | ```bash
129 | python train.py --dataset aligned --root /dataset --input_file mixture.wav --output_file vocals.wav
130 | ```
131 |
132 | ### `SourceFolderDataset` (sourcefolder)
133 |
134 | A dataset that assumes folders of sources,
135 | instead of track folders. This is a common
136 | format for speech and environmental sound datasets
137 | such as DCASE. For each source a variable number of
138 | tracks/sounds is available, therefore the dataset is unaligned by design.
139 |
140 | In this scenario one could easily train a network to separate target sounds from interfering sounds. For each sample, the data loader loads a random combination of target+interferer as the input and mixes them linearly. The output of the model is the target (see the sketch below).
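A minimal sketch of this sampling idea (illustrative only, not the actual data loader; the directory names echo the example command further below):

```python
import random
import torch

# placeholder excerpts standing in for files from the source directories
vocals_excerpts = [torch.rand(1, 16000) for _ in range(5)]      # --target-dir
car_noise_excerpts = [torch.rand(1, 16000) for _ in range(5)]   # --interferer-dirs
wind_noise_excerpts = [torch.rand(1, 16000) for _ in range(5)]

target = random.choice(vocals_excerpts)
mixture = target + random.choice(car_noise_excerpts) + random.choice(wind_noise_excerpts)
# training pair: input = mixture, output = target
```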
141 |
142 | #### File structure
143 |
144 | ```
145 | train/vocals/track11.wav -----------------\
146 | train/drums/track202.wav (interferer1) ---+--> input
147 | train/bass/track007a.wav (interferer2) --/
148 |
149 | train/vocals/track11.wav ---------------------> output
150 | ```
151 |
152 | #### Parameters
153 |
154 | | Argument | Description | Default |
155 | |----------|-------------|---------|
156 | |`--interferer-dirs list[]` | list of directories used as interferers | `None` |
157 | |`--target-dir ` | directory that contains the target source | `None` |
158 | |`--ext ` | File extension | `.wav` |
160 | |`--nb-train-samples ` | Number of samples drawn for training | `1000` |
161 | |`--nb-valid-samples ` | Number of samples drawn for validation | `100` |
162 | |`--source-augmentations list[]` | List of augmentation functions that are processed in the order of the list | |
163 |
164 | #### Example
165 |
166 | ```bash
167 | python train.py --dataset sourcefolder --root /data --target-dir vocals --interferer-dirs car_noise wind_noise --ext .ogg --nb-train-samples 1000
168 | ```
169 |
170 | ### `FixedSourcesTrackFolderDataset` (trackfolder_fix)
171 |
172 | A dataset that assumes audio sources to be stored
173 | in track folders where each track has a fixed number of sources. For each track, the user specifies the target file-name (`target_file`) and a list of interferer files (`interferer_files`).
174 | A linear mix is performed on the fly by summing the target and the interferers up.
175 |
176 | Due to the fact that all tracks comprise the exact same set of sources, the random track mixing augmentation technique can be used, where sources from different tracks are mixed together. Setting `random_track_mix=True` results in an unaligned dataset.
177 | When random track mixing is enabled, we define an epoch as one pass in which the target source from every track has been seen exactly once, with whatever interfering sources have randomly been drawn.
178 |
179 | This dataset is recommended for small/medium-sized datasets such as MUSDB18 or other custom source separation datasets.
180 |
181 | #### File structure
182 |
183 | ```sh
184 | train/1/vocals.wav ---------------\
185 | train/1/drums.wav (interferer1) ---+--> input
186 | train/1/bass.wav -(interferer2) --/
187 |
188 | train/1/vocals.wav -------------------> output
189 | ```
190 |
191 | #### Parameters
192 |
193 | | Argument | Description | Default |
194 | |----------|-------------|---------|
195 | |`--target-file ` | Target file (includes extension) | `None` |
196 | |`--interferer-files list[]` | list of interfering sources | `None` |
197 | |`--random-track-mix` | Applies random track mixing | `False` |
198 | |`--source-augmentations list[]` | List of augmentation functions that are processed in the order of the list | |
199 |
200 | #### Example
201 |
202 | ```
203 | python train.py --root /data --dataset trackfolder_fix --target-file vocals.flac --interferer-files bass.flac drums.flac other.flac
204 | ```
205 |
206 | ### `VariableSourcesTrackFolderDataset` (trackfolder_var)
207 |
208 | A dataset that assumes audio sources to be stored in track folders where each track has a _variable_ number of sources. The user specifies the target file-name (`target_file`) and the extension of the sources used for mixing. A linear mix is performed on the fly by summing all sources in a track folder.
209 |
210 | Since the number of sources differs per track, while the target is fixed, the random track mixing augmentation cannot be used.
211 | Also make sure that you do not provide the mixture file among the sources! This dataset maximizes the number of tracks that can be used since it doesn't require a fixed number of sources per track. However, the target file is required to
212 | be present. To increase the dataset utilization even further, users can enable the `--silence-missing-targets` option that substitutes silence for missing targets.
213 |
214 | #### File structure
215 |
216 | ```sh
217 | train/1/vocals.wav --> input target \
218 | train/1/drums.wav --> input target |
219 | train/1/bass.wav --> input target --+--> input
220 | train/1/accordion.wav --> input target |
221 | train/1/marimba.wav --> input target /
222 |
223 | train/1/vocals.wav -----------------------> output
224 | ```
225 |
226 | #### Parameters
227 |
228 | | Argument | Description | Default |
229 | |----------|-------------|---------|
230 | |`--target-file ` | file name of target file | `None` |
231 | |`--silence-missing-targets` | if a target is not among the list of sources it will be filled with zero | not set |
232 | |`random interferer mixing` | use a _random track_ for the interferer tracks to increase generalization of the model. | not set |
233 | |`--ext ` | File extension that is used to find the interfering files | `.wav` |
234 | |`--source-augmentations list[]` | List of augmentation functions that are processed in the order of the list | |
235 |
236 | #### Example
237 |
238 | ```
239 | python train.py --root /data --dataset trackfolder_var --target-file vocals.flac --ext .wav
240 | ```
241 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # _Open-Unmix_ for PyTorch
2 |
3 | [](https://joss.theoj.org/papers/571753bc54c5d6dd36382c3d801de41d)
4 | [](https://colab.research.google.com/drive/1mijF0zGWxN-KaxTnd0q6hayAlrID5fEQ)
5 | [](https://paperswithcode.com/sota/music-source-separation-on-musdb18?p=open-unmix-a-reference-implementation-for)
6 |
7 | [](https://github.com/sigsep/open-unmix-pytorch/actions/workflows/test_unittests.yml)[](https://pypi.python.org/pypi/openunmix)
8 | [](https://pypi.python.org/pypi/openunmix)
9 |
10 | This repository contains the PyTorch (1.8+) implementation of __Open-Unmix__, a deep neural network reference implementation for music source separation, applicable for researchers, audio engineers and artists. __Open-Unmix__ provides ready-to-use models that allow users to separate pop music into four stems: __vocals__, __drums__, __bass__ and the remaining __other__ instruments. The models were pre-trained on the freely available [MUSDB18](https://sigsep.github.io/datasets/musdb.html) dataset. See details at [apply pre-trained model](#getting-started).
11 |
12 | ## ⭐️ News
13 |
14 | - 16/04/2024: We brought the repo to torch 2.0 level. Everything seems to work fine again, but we needed to relax the regression tests. With the most recent version, results are slightly different, so be warned when running unit tests.
15 | - 03/07/2021: We added `umxl`, a model that was trained on extra data which significantly improves the performance, especially generalization.
16 | - 14/02/2021: We released the new version of open-unmix as a python package. This comes with: a fully differentiable version of [norbert](https://github.com/sigsep/norbert), improved audio loading pipeline and large number of bug fixes. See [release notes](https://github.com/sigsep/open-unmix-pytorch/releases/) for further info.
17 |
18 | - 06/05/2020: We added a pre-trained speech enhancement model `umxse` provided by Sony.
19 |
20 | - 13/03/2020: Open-unmix was awarded 2nd place in the [PyTorch Global Summer Hackathon 2020](https://devpost.com/software/open-unmix).
21 |
22 | __Related Projects:__ open-unmix-pytorch | [open-unmix-nnabla](https://github.com/sigsep/open-unmix-nnabla) | [musdb](https://github.com/sigsep/sigsep-mus-db) | [museval](https://github.com/sigsep/sigsep-mus-eval) | [norbert](https://github.com/sigsep/norbert)
23 |
24 | ## 🧠 The Model (for one source)
25 |
26 | 
27 |
28 | To perform separation into multiple sources, _Open-unmix_ comprises multiple models that are trained for each particular target. While this makes the training less comfortable, it allows great flexibility to customize the training data for each target source.
29 |
30 | Each _Open-Unmix_ source model is based on a three-layer bidirectional deep LSTM. The model learns to predict the magnitude spectrogram of a target source, like _vocals_, from the magnitude spectrogram of a mixture input. Internally, the prediction is obtained by applying a mask on the input. The model is optimized in the magnitude domain using mean squared error.
31 |
32 | ### Input Stage
33 |
34 | __Open-Unmix__ operates in the time-frequency domain to perform its prediction. The input of the model is either:
35 |
36 | * __`models.Separator`:__ A time domain signal tensor of shape `(nb_samples, nb_channels, nb_timesteps)`, where `nb_samples` are the samples in a batch, `nb_channels` is 1 or 2 for mono or stereo audio, respectively, and `nb_timesteps` is the number of audio samples in the recording. In this case, the model computes STFTs with either `torch` or `asteroid_filterbanks` on the fly.
37 |
38 | * __`models.OpenUnmix`:__ The core open-unmix takes **magnitude spectrograms** directly (e.g. when pre-computed and loaded from disk). In that case, the input is of shape `(nb_frames, nb_samples, nb_channels, nb_bins)`, where `nb_frames` and `nb_bins` are the time and frequency-dimensions of a Short-Time-Fourier-Transform.
39 |
40 | The input spectrogram is _standardized_ using the global mean and standard deviation for every frequency bin across all frames. Furthermore, we apply batch normalization in multiple stages of the model to make the training more robust against gain variations.
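As an illustration of this standardization step (a sketch with placeholder data, not the training code), the statistics are computed per frequency bin and then applied to every frame:

```python
import torch

# placeholder magnitude spectrograms: (nb_frames, nb_channels, nb_bins)
mag = torch.rand(1000, 2, 2049)

input_mean = mag.mean(dim=(0, 1))              # one value per frequency bin
input_std = mag.std(dim=(0, 1))
standardized = (mag - input_mean) / input_std  # broadcast over frames and channels
```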
41 |
42 | ### Dimensionality reduction
43 |
44 | The LSTM is not operating on the original input spectrogram resolution. Instead, in the first step after the normalization, the network learns to compress the frequency and channel axes of the input to reduce redundancy and make the model converge faster.
45 |
46 | ### Bidirectional-LSTM
47 |
48 | The core of __open-unmix__ is a three-layer bidirectional [LSTM network](https://dl.acm.org/citation.cfm?id=1246450). Due to its recurrent nature, the model can be trained and evaluated on audio signals of arbitrary length. Since the model takes information from past and future simultaneously, it cannot be used in an online/real-time manner.
49 | A uni-directional model can easily be trained as described [here](docs/training.md).
50 |
51 | ### Output Stage
52 |
53 | After applying the LSTM, the signal is decoded back to its original input dimensionality. In the last steps, the output is multiplied with the input magnitude spectrogram, so that the model is asked to learn a mask.
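In code terms, the final masking step amounts to something like the following (names and shapes are illustrative; the decoder output and the mixture magnitude share the same shape):

```python
import torch
import torch.nn.functional as F

mix_magnitude = torch.rand(16, 2, 2049, 255)   # magnitude spectrogram of the mixture
decoded = torch.randn(16, 2, 2049, 255)        # decoder output, same shape

estimate = F.relu(decoded) * mix_magnitude     # non-negative mask applied to the mix
```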
54 |
55 | ## 🤹♀️ Putting source models together: the `Separator`
56 |
57 | `models.Separator` puts together one _Open-unmix_ spectrogram model for each desired target and combines their outputs through a multichannel generalized Wiener filter, before applying inverse STFTs using `torchaudio`.
58 | The filtering is a differentiable (but parameter-free) version of [norbert](https://github.com/sigsep/norbert). The separator is currently only used during inference.
59 |
60 | ## 🏁 Getting started
61 |
62 | ### Installation
63 |
64 | `openunmix` can be installed from pypi using:
65 |
66 | ```
67 | pip install openunmix
68 | ```
69 |
70 | Note that the pypi version of openunmix uses [torchaudio] to load and save audio files. To increase the number of supported input and output file formats (such as STEMS export), please additionally install [stempeg](https://github.com/faroit/stempeg).
71 |
72 | Training is not part of the open-unmix package, please follow [docs/training.md](docs/training.md) for more information.
73 |
74 | #### Using Docker
75 |
76 | We also provide a docker container. Performing separation of a local track in `~/Music/track1.wav` can be performed in a single line:
77 |
78 | ```
79 | docker run -v ~/Music/:/data -it faroit/open-unmix-pytorch "/data/track1.wav" --outdir /data/track1
80 | ```
81 |
82 | ### Pre-trained models
83 |
84 | We provide three core pre-trained music separation models. All three models are end-to-end models that take waveform inputs and output the separated waveforms.
85 |
86 | * __`umxl` (default)__ trained on a private dataset of compressed stems. __Note that the weights are only licensed for non-commercial use (CC BY-NC-SA 4.0).__
87 |
88 | [](https://doi.org/10.5281/zenodo.5069601)
89 |
90 | * __`umxhq`__ trained on [MUSDB18-HQ](https://sigsep.github.io/datasets/musdb.html#uncompressed-wav), which comprises the same tracks as MUSDB18 but uncompressed, yielding a full bandwidth of 22050 Hz.
91 |
92 | [](https://doi.org/10.5281/zenodo.3370489)
93 |
94 | * __`umx`__ is trained on the regular [MUSDB18](https://sigsep.github.io/datasets/musdb.html#compressed-stems) which is bandwidth limited to 16 kHz due to AAC compression. This model should be used for comparison with other (older) methods for evaluation in [SiSEC18](https://sisec18.unmix.app).
95 |
96 | [](https://doi.org/10.5281/zenodo.3370486)
97 |
98 | Furthermore, we provide a model for speech enhancement trained by [Sony Corporation](link)
99 |
100 | * __`umxse`__ speech enhancement model is trained on the 28-speaker version of the [Voicebank+DEMAND corpus](https://datashare.is.ed.ac.uk/handle/10283/1942?show=full).
101 |
102 | [](https://doi.org/10.5281/zenodo.3786908)
103 |
104 | All four models are also available as spectrogram (core) models, which take magnitude spectrogram inputs and output the separated spectrograms.
105 | These models can be loaded using `umxl_spec`, `umxhq_spec`, `umx_spec` and `umxse_spec`.
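
For example, loading just the core spectrogram models might look like this (a hedged sketch; in recent versions these loaders return one `OpenUnmix` model per requested target):

```python
import openunmix

# Hedged sketch: the *_spec loaders return only the core spectrogram model(s),
# which map magnitude spectrograms to estimated target spectrograms.
spec_models = openunmix.umxl_spec(targets=["vocals"], device="cpu")
```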
106 |
107 | To separate audio files (`wav`, `flac`, `ogg` - but not `mp3`), just run:
108 |
109 | ```bash
110 | umx input_file.wav
111 | ```
112 |
113 | A more detailed list of the parameters used for the separation is given in the [inference.md](/docs/inference.md) document.
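
A hedged example using only options that also appear elsewhere in this README; refer to [inference.md](/docs/inference.md) for the authoritative list:

```bash
# hedged example; see docs/inference.md for the full set of options
umx input_file.wav --outdir ./separated --targets vocals drums --niter 1
```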
114 |
115 | We provide a [Jupyter notebook on Google Colab](https://colab.research.google.com/drive/1mijF0zGWxN-KaxTnd0q6hayAlrID5fEQ) to experiment with open-unmix and to separate files online without any local installation.
116 |
117 | ### Using pre-trained models from within python
118 |
119 | We provide several ways to load pre-trained models and use them from within your Python projects:
120 | #### When the package is installed
121 |
122 | Loading a pre-trained model is as simple as:
123 |
124 | ```python
125 | separator = openunmix.umxl(...)
126 | ```
127 | #### torch.hub
128 |
129 | We also provide torch.hub compatible modules that can be loaded. Note that this does _not_ even require installing the open-unmix package, and should generally work as long as a compatible PyTorch version is installed.
130 |
131 | ```python
132 | separator = torch.hub.load('sigsep/open-unmix-pytorch', 'umxl', device=device)
133 | ```
134 |
135 | Here, `umxl` specifies the pre-trained model.
136 | #### Performing separation
137 |
138 | With a created separator object, one can perform separation of some `audio` (a torch.Tensor of shape `(channels, length)`, provided at the sampling rate `separator.sample_rate`) through:
139 |
140 | ```python
141 | estimates = separator(audio, ...)
142 | # returns estimates as tensor
143 | ```
144 |
145 | Note that this requires the audio to be in the right shape and at the right sampling rate. For convenience, we provide a pre-processing function `openunmix.utils.preprocess(...)` that takes numpy audio and converts it for use with open-unmix.
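
A hedged sketch of that convenience path, mirroring the call used in `openunmix/predict.py` (the numpy layout and sampling rate are assumptions for this example):

```python
import numpy as np
import torch
import openunmix
from openunmix import utils

# Hedged sketch mirroring the call in openunmix/predict.py; the stereo numpy layout
# (channels, samples) and the 44.1 kHz rate are assumptions for this example.
separator = openunmix.umxl(device="cpu")
audio_np = np.random.rand(2, 44100).astype("float32")   # pretend this came from a file
audio = utils.preprocess(torch.as_tensor(audio_np), 44100, separator.sample_rate)
estimates = separator(audio)
```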
146 |
147 | #### One-liner
148 |
149 | To perform model loading, preprocessing and separation in one step, just use:
150 |
151 | ```python
152 | from openunmix.predict import separate
153 | estimates = separate(audio, ...)
154 | ```
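
A slightly more complete, hedged example that loads a file with `torchaudio` and passes its sampling rate (the file path is a placeholder):

```python
import torchaudio
from openunmix.predict import separate

# Hedged example: "track.wav" is a placeholder path; the keyword arguments follow
# the signature of openunmix.predict.separate.
audio, rate = torchaudio.load("track.wav")   # (channels, samples)
estimates = separate(audio, rate=rate, targets=["vocals"], residual=True)
for target, waveform in estimates.items():
    print(target, waveform.shape)
```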
155 |
156 | ### Load user-trained models
157 |
158 | When a path instead of a model name is provided to `--model`, a pre-trained `Separator` will be loaded from disk.
159 | E.g. the following files are assumed to be present when loading with `--model mymodel --targets vocals`:
160 |
161 | * `mymodel/separator.json`
162 | * `mymodel/vocals.pth`
163 | * `mymodel/vocals.json`
164 |
165 |
166 | Note that the separator usually joins multiple models, one for each target, and performs separation using all of them. E.g. if the separator contains `vocals` and `drums` models, two output files are generated. If the `--residual` option is selected, an additional source is produced, containing an estimate of everything in the mixture that is not covered by the targets.
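
Putting this together, loading such a user-trained model from the command line might look like the following hedged example, where `mymodel` is the directory sketched above:

```bash
# hedged example: mymodel is the user-trained model directory described above
umx input_file.wav --model mymodel --targets vocals --outdir ./out
```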
167 |
168 | ### Evaluation using `museval`
169 |
170 | To perform evaluation in comparison to other SiSEC systems, you would need to install the `museval` package using
171 |
172 | ```
173 | pip install museval
174 | ```
175 |
176 | and then run the evaluation using
177 |
178 | `python -m openunmix.evaluate --outdir /path/to/musdb/estimates --evaldir /path/to/museval/results`
179 |
180 | ### Results compared to SiSEC 2018 (SDR/Vocals)
181 |
182 | Open-Unmix yields state-of-the-art results compared to participants from [SiSEC 2018](https://sisec18.unmix.app/#/methods). The performance of `UMXHQ` and `UMX` is almost identical since both were evaluated on compressed STEMS.
183 |
184 | 
185 |
186 | Note that
187 |
188 | 1. [`STL1`, `TAK2`, `TAK3`, `TAU1`, `UHL3`, `UMXHQ`] were omitted as they were _not_ trained on only _MUSDB18_.
189 | 2. [`HEL1`, `TAK1`, `UHL1`, `UHL2`] are not open-source.
190 |
191 | #### Scores (Median of frames, Median of tracks)
192 |
193 | |target|UMX (SDR in dB)|UMXHQ (SDR in dB)|UMXL (SDR in dB)|
194 | |------|------|------|------|
195 | |vocals|6.32 | 6.25|__7.21__ |
196 | |bass  |5.23 | 5.07|__6.02__ |
197 | |drums |5.73 | 6.04|__7.15__ |
198 | |other |4.02 | 4.28|__4.89__ |
199 |
200 |
201 | ## Training
202 |
203 | Details on training are provided in a separate document [here](docs/training.md).
204 |
205 | ## Extensions
206 |
207 | Details on how _open-unmix_ can be extended or improved for future research on music separation are described in a separate document [here](docs/extensions.md).
208 |
209 |
210 | ## Design Choices
211 |
212 | We favored simplicity over performance to promote clarity of the code. The rationale is to have __open-unmix__ serve as a __baseline__ for future research while its performance still meets the current state of the art (see [Evaluation](#Evaluation)). The results are comparable to or better than those of `UHL1`/`UHL2`, which obtained the best performance of all systems trained on MUSDB18 in the [SiSEC 2018 evaluation campaign](https://sisec18.unmix.app).
213 | We designed the code to allow researchers to reproduce existing results, quickly develop new architectures, and add their own data for training and testing. We favored framework-specific implementations instead of a monolithic repository with common code for all frameworks.
214 |
215 | ## How to contribute
216 |
217 | _open-unmix_ is a community-focused project; we therefore encourage the community to submit bug-fixes and requests for technical support through [GitHub issues](https://github.com/sigsep/open-unmix-pytorch/issues/new/choose). For more details on how to contribute, please follow our [`CONTRIBUTING.md`](CONTRIBUTING.md). For help and support, please use the Gitter chat or the Google Groups forums.
218 |
219 | ### Authors
220 |
221 | [Fabian-Robert Stöter](https://www.faroit.com/), [Antoine Liutkus](https://github.com/aliutkus), Inria and LIRMM, Montpellier, France
222 |
223 | ## References
224 |
225 | If you use open-unmix for your research, cite Open-Unmix:
226 |
227 | ```latex
228 | @article{stoter19,
229 | author={F.-R. St{\"o}ter and S. Uhlich and A. Liutkus and Y. Mitsufuji},
230 | title={Open-Unmix - A Reference Implementation for Music Source Separation},
231 | journal={Journal of Open Source Software},
232 | year=2019,
233 | doi = {10.21105/joss.01667},
234 | url = {https://doi.org/10.21105/joss.01667}
235 | }
236 | ```
237 |
238 |
239 |
240 |
241 | If you use the MUSDB dataset for your research, cite the MUSDB18 dataset:
242 |
243 |
244 | ```latex
245 | @misc{MUSDB18,
246 | author = {Rafii, Zafar and
247 | Liutkus, Antoine and
248 | Fabian-Robert St{\"o}ter and
249 | Mimilakis, Stylianos Ioannis and
250 | Bittner, Rachel},
251 | title = {The {MUSDB18} corpus for music separation},
252 | month = dec,
253 | year = 2017,
254 | doi = {10.5281/zenodo.1117372},
255 | url = {https://doi.org/10.5281/zenodo.1117372}
256 | }
257 | ```
258 |
259 |
260 |
261 |
262 |
263 | If you compare your results with SiSEC 2018 participants, cite the SiSEC 2018 LVA/ICA paper:
264 |
265 |
266 | ```latex
267 | @inproceedings{SiSEC18,
268 | author="St{\"o}ter, Fabian-Robert and Liutkus, Antoine and Ito, Nobutaka",
269 | title="The 2018 Signal Separation Evaluation Campaign",
270 | booktitle="Latent Variable Analysis and Signal Separation:
271 | 14th International Conference, LVA/ICA 2018, Surrey, UK",
272 | year="2018",
273 | pages="293--305"
274 | }
275 | ```
276 |
277 |
278 |
279 |
280 | ⚠️ Please note that the official acronym for _open-unmix_ is **UMX**.
281 |
282 | ### License
283 |
284 | MIT
285 |
286 | ### Acknowledgements
287 |
288 |
from openunmix import utils
32 |
33 |
34 | def separate(
35 | audio,
36 | rate=None,
37 | model_str_or_path="umxhq",
38 | targets=None,
39 | niter=1,
40 | residual=False,
41 | wiener_win_len=300,
42 | aggregate_dict=None,
43 | separator=None,
44 | device=None,
45 | filterbank="torch",
46 | ):
47 | """
48 | Open Unmix functional interface
49 |
50 | Separates a torch.Tensor or the content of an audio file.
51 |
52 | If a separator is provided, use it for inference. If not, create one
53 | and use it afterwards.
54 |
55 | Args:
56 | audio: audio to process
57 | torch Tensor: shape (channels, length), and
58 | `rate` must also be provided.
59 | rate: int or None: only used if audio is a Tensor. Otherwise,
60 | inferred from the file.
61 | model_str_or_path: the pretrained model to use
62 | targets (str): select the targets for the source to be separated.
63 | a list including: ['vocals', 'drums', 'bass', 'other'].
64 | If you don't pick them all, you probably want to
65 | activate the `residual=True` option.
66 | Defaults to all available targets per model.
67 | niter (int): the number of post-processing iterations, defaults to 1
68 | residual (bool): if True, a "garbage" target is created
69 | wiener_win_len (int): the number of frames to use when batching
70 | the post-processing step
71 | aggregate_dict (str): if provided, must be a string containing a
72 | valid expression for a dictionary, with keys as output
73 | target names, and values a list of targets that are used to
74 | build it. For instance: '{"vocals":["vocals"],
75 | "accompaniment":["drums","bass","other"]}'
76 | separator: if provided, the model.Separator object that will be used
77 | to perform separation
78 | device (str): selects device to be used for inference
79 | filterbank (str): filterbank implementation method.
80 | Supported are `['torch', 'asteroid']`. `torch` is about 30% faster
81 | compared to `asteroid` on large FFT sizes such as 4096. However,
82 | asteroid's STFT can be exported to ONNX, which makes it practical
83 | for deployment.
84 | """
85 | if separator is None:
86 | separator = utils.load_separator(
87 | model_str_or_path=model_str_or_path,
88 | targets=targets,
89 | niter=niter,
90 | residual=residual,
91 | wiener_win_len=wiener_win_len,
92 | device=device,
93 | pretrained=True,
94 | filterbank=filterbank,
95 | )
96 | separator.freeze()
97 | if device:
98 | separator.to(device)
99 |
100 | if rate is None:
101 | raise Exception("`rate` must be provided.")
102 |
103 | if device:
104 | audio = audio.to(device)
105 | audio = utils.preprocess(audio, rate, separator.sample_rate)
106 |
107 | # getting the separated signals
108 | estimates = separator(audio)
109 | estimates = separator.to_dict(estimates, aggregate_dict=aggregate_dict)
110 | return estimates