├── demucs ├── py.typed ├── grids │ ├── __init__.py │ ├── mdx.py │ ├── mdx_extra.py │ ├── mdx_refine.py │ └── _explorers.py ├── remote │ ├── mdx_extra.yaml │ ├── mdx_extra_q.yaml │ ├── mdx.yaml │ ├── mdx_q.yaml │ └── files.txt ├── __init__.py ├── wdemucs.py ├── __main__.py ├── spec.py ├── ema.py ├── pretrained.py ├── repitch.py ├── svd.py ├── distrib.py ├── augment.py ├── utils.py ├── states.py ├── repo.py ├── train.py ├── evaluate.py ├── separate.py ├── apply.py ├── audio.py ├── wav.py ├── solver.py └── demucs.py ├── conf ├── svd │ ├── default.yaml │ ├── base2.yaml │ └── base.yaml ├── variant │ ├── default.yaml │ ├── example.yaml │ └── finetune.yaml ├── dset │ ├── musdb44.yaml │ ├── extra44.yaml │ ├── extra_test.yaml │ ├── auto_mus.yaml │ ├── auto_extra_test.yaml │ └── aetl.yaml ├── infer_config.yaml └── train_test_config.yaml ├── .gitignore ├── requirements.txt ├── .pre-commit-config.yaml ├── .github └── workflows │ ├── snyk.yml │ ├── pylint.yml │ └── codeql-analysis.yml ├── inference.py ├── test.py ├── train.py └── README.md /demucs/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demucs/grids/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /conf/svd/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | -------------------------------------------------------------------------------- /conf/variant/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb 2 | .ipynb_checkpoints/ 3 | outputs/ 4 | __pycache__/ 5 | 6 | -------------------------------------------------------------------------------- /conf/dset/musdb44.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dset: 4 | samplerate: 44100 5 | channels: 2 -------------------------------------------------------------------------------- /conf/variant/example.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: hdemucs 4 | hdemucs: 5 | channels: 32 -------------------------------------------------------------------------------- /demucs/remote/mdx_extra.yaml: -------------------------------------------------------------------------------- 1 | models: ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08'] 2 | segment: 44 -------------------------------------------------------------------------------- /demucs/remote/mdx_extra_q.yaml: -------------------------------------------------------------------------------- 1 | models: ['83fc094f', '464b36d7', '14fc6a69', '7fd6ef75'] 2 | segment: 44 3 | -------------------------------------------------------------------------------- /conf/dset/extra44.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Musdb + extra tracks 4 | dset: 5 | wav: /checkpoint/defossez/datasets/allstems_44/ 6 | samplerate: 44100 7 | channels: 2 8 | epochs: 320 9 | -------------------------------------------------------------------------------- 
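The `demucs/remote/*.yaml` files (mdx_extra and mdx_extra_q above, mdx and mdx_q below) each describe a bag of models: `models` lists experiment signatures matching the checkpoint names in `remote/files.txt`, the optional `weights` matrix holds one row per model with one multiplier per source (drums, bass, other, vocals, per `SOURCES` in `demucs/pretrained.py`), and `segment` sets the segment length used when the bag is applied. A minimal sketch of reading such a file, assuming PyYAML (already in `requirements.txt`) and the semantics above:

    import yaml
    from pathlib import Path

    spec = yaml.safe_load(Path("demucs/remote/mdx.yaml").read_text())
    print(spec["models"])       # ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68']
    print(spec.get("weights"))  # one row per model, one weight per source (may be absent)
    print(spec["segment"])      # 44
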
/demucs/remote/mdx.yaml: -------------------------------------------------------------------------------- 1 | models: ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68'] 2 | weights: [ 3 | [1., 1., 0., 0.], 4 | [0., 1., 0., 0.], 5 | [1., 0., 1., 1.], 6 | [1., 0., 1., 1.], 7 | ] 8 | segment: 44 9 | -------------------------------------------------------------------------------- /demucs/remote/mdx_q.yaml: -------------------------------------------------------------------------------- 1 | models: ['6b9c2ca1', 'b72baf4e', '42e558d4', '305bc58f'] 2 | weights: [ 3 | [1., 1., 0., 0.], 4 | [0., 1., 0., 0.], 5 | [1., 0., 1., 1.], 6 | [1., 0., 1., 1.], 7 | ] 8 | segment: 44 9 | -------------------------------------------------------------------------------- /conf/svd/base2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | svd: 4 | penalty: 0 5 | min_size: 1 6 | dim: 100 7 | niters: 4 8 | powm: false 9 | proba: 1 10 | conv_only: false 11 | convtr: true 12 | 13 | optim: 14 | beta2: 0.9998 -------------------------------------------------------------------------------- /demucs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | __version__ = "3.0.5a1" 8 | -------------------------------------------------------------------------------- /conf/dset/extra_test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Musdb + extra tracks + test set from musdb. 4 | dset: 5 | wav: /checkpoint/defossez/datasets/allstems_test_44/ 6 | samplerate: 44100 7 | channels: 2 8 | epochs: 320 9 | max_batches: 700 10 | test: 11 | sdr: false 12 | every: 500 13 | -------------------------------------------------------------------------------- /conf/svd/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | svd: 4 | penalty: 0 5 | min_size: 1 6 | dim: 50 7 | niters: 4 8 | powm: false 9 | proba: 1 10 | conv_only: false 11 | convtr: false # ideally this should be true, but some models were trained with this to false. 12 | 13 | optim: 14 | beta2: 0.9998 -------------------------------------------------------------------------------- /conf/variant/finetune.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | epochs: 4 4 | batch_size: 16 5 | optim: 6 | lr: 0.0006 7 | test: 8 | every: 1 9 | sdr: false 10 | dset: 11 | segment: 28 12 | shift: 2 13 | 14 | augment: 15 | scale: 16 | proba: 0 17 | shift_same: true 18 | remix: 19 | proba: 0 20 | -------------------------------------------------------------------------------- /demucs/wdemucs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # For compat 7 | from .hdemucs import HDemucs 8 | 9 | WDemucs = HDemucs 10 | -------------------------------------------------------------------------------- /demucs/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .separate import main 8 | 9 | if __name__ == "__main__": 10 | main() 11 | -------------------------------------------------------------------------------- /conf/dset/auto_mus.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Automix dataset based on musdb train set. 4 | dset: 5 | wav: /checkpoint/defossez/datasets/automix_musdb 6 | samplerate: 44100 7 | channels: 2 8 | epochs: 360 9 | max_batches: 300 10 | test: 11 | every: 4 12 | 13 | augment: 14 | shift_same: true 15 | scale: 16 | proba: 0.5 17 | remix: 18 | proba: 0 19 | repitch: 20 | proba: 0 21 | -------------------------------------------------------------------------------- /conf/dset/auto_extra_test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # automix dataset with Musdb, extra training data and the test set of Musdb. 4 | dset: 5 | wav: /checkpoint/defossez/datasets/automix_extra_test2 6 | samplerate: 44100 7 | channels: 2 8 | epochs: 320 9 | max_batches: 500 10 | 11 | augment: 12 | shift_same: true 13 | scale: 14 | proba: 0. 15 | remix: 16 | proba: 0 17 | repitch: 18 | proba: 0 19 | -------------------------------------------------------------------------------- /demucs/remote/files.txt: -------------------------------------------------------------------------------- 1 | 0d19c1c6-0f06f20e.th 2 | 5d2d6c55-db83574e.th 3 | 7d865c68-3d5dd56b.th 4 | 7ecf8ec1-70f50cc9.th 5 | a1d90b5c-ae9d2452.th 6 | c511e2ab-fe698775.th 7 | cfa93e08-61801ae1.th 8 | e51eebcc-c1b80bdd.th 9 | 6b9c2ca1-3fd82607.th 10 | b72baf4e-8778635e.th 11 | 42e558d4-196e0e1b.th 12 | 305bc58f-18378783.th 13 | 14fc6a69-a89dd0ee.th 14 | 464b36d7-e5a9386e.th 15 | 7fd6ef75-a905dd85.th 16 | 83fc094f-4a16d450.th 17 | -------------------------------------------------------------------------------- /conf/dset/aetl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # automix dataset with Musdb, extra training data and the test set of Musdb. 4 | # This used even more remixes than auto_extra_test. 5 | dset: 6 | wav: /checkpoint/defossez/datasets/aetl 7 | samplerate: 44100 8 | channels: 2 9 | epochs: 320 10 | max_batches: 500 11 | 12 | augment: 13 | shift_same: true 14 | scale: 15 | proba: 0. 
16 | remix: 17 | proba: 0 18 | repitch: 19 | proba: 0 20 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/KinWaiCheuk/AudioLoader.git 2 | omegaconf>=2.3.0 3 | tqdm>=4.64.1 4 | torch==1.11.0+cu113 5 | torchaudio==0.11.0+cu113 6 | torchmetrics==0.11.0 7 | torchvision==0.12.0+cu113 8 | pytorch-lightning==1.5.8 9 | musdb==0.4.0 10 | dora-search==0.1.11 11 | diffq>=0.2.1 12 | hydra-colorlog>=1.1 13 | hydra-core>=1.1 14 | julius>=0.2.3 15 | lameenc>=1.2 16 | museval>=0.4.0 17 | mypy>=0.991 18 | openunmix>=1.2.1 19 | pyyaml 20 | wandb 21 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # Add these back in if we end up using notebooks in here at all? 3 | #- repo: https://github.com/kynan/nbstripout 4 | # rev: 0.3.9 5 | # hooks: 6 | # - id: nbstripout 7 | #- repo: https://github.com/mwouts/jupytext 8 | # rev: v1.11.2 9 | # hooks: 10 | # - id: jupytext 11 | # args: [--sync, --pipe, black] 12 | # additional_dependencies: 13 | # - black==21.5b0 # Matches hook 14 | - repo: https://github.com/psf/black 15 | rev: 22.12.0 16 | hooks: 17 | - id: black 18 | language_version: python3 19 | -------------------------------------------------------------------------------- /.github/workflows/snyk.yml: -------------------------------------------------------------------------------- 1 | name: Snyk 2 | on: push 3 | jobs: 4 | security: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@master 8 | - name: Run Snyk to check for vulnerabilities 9 | uses: snyk/actions/python-3.8@master 10 | continue-on-error: true # To make sure that SARIF upload gets called 11 | env: 12 | SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }} 13 | with: 14 | args: --sarif-file-output=snyk.sarif 15 | - name: Upload result to GitHub Code Scanning 16 | uses: github/codeql-action/upload-sarif@v2 17 | with: 18 | sarif_file: snyk.sarif 19 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Python linting (black, isort, flake8, etc.) 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | pr-lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v1 10 | name: Checkout 11 | - uses: ricardochaves/python-lint@v1.4.0 12 | with: 13 | python-root-list: 'heareval' 14 | # python-root-list: 'heareval examples' 15 | use-pylint: false 16 | use-pycodestyle: false 17 | use-flake8: true 18 | use-black: true 19 | use-mypy: true 20 | use-isort: false 21 | extra-pylint-options: "" 22 | extra-pycodestyle-options: "" 23 | extra-flake8-options: "--max-line-length=88 --extend-ignore=E203 --per-file-ignores=__init__.py:F401" 24 | extra-black-options: "" 25 | extra-mypy-options: "" 26 | extra-isort-options: "" 27 | -------------------------------------------------------------------------------- /demucs/grids/mdx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Main training for the Track A MDX models. 
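The companion grids are `mdx_extra.py`, which trains the Track B models, and `mdx_refine.py`, which continues the Track A experiments (including the DiffQ-quantized variants) once this grid has completed.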
8 | """ 9 | 10 | from ..train import main 11 | from ._explorers import MyExplorer 12 | 13 | TRACK_A = ["0d19c1c6", "7ecf8ec1", "c511e2ab", "7d865c68"] 14 | 15 | 16 | @MyExplorer 17 | def explorer(launcher): 18 | launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="learnlab") 19 | 20 | # Reproduce results from MDX competition Track A 21 | # This trains the first round of models. Once this is trained, 22 | # you will need to schedule `mdx_refine`. 23 | for sig in TRACK_A: 24 | xp = main.get_xp_from_sig(sig) 25 | parent = xp.cfg.continue_from 26 | xp = main.get_xp_from_sig(parent) 27 | launcher(xp.argv) 28 | launcher(xp.argv, {"quant.diffq": 1e-4}) 29 | launcher(xp.argv, {"quant.diffq": 3e-4}) 30 | -------------------------------------------------------------------------------- /demucs/grids/mdx_extra.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Main training for the Track A MDX models. 8 | """ 9 | 10 | from ..train import main 11 | from ._explorers import MyExplorer 12 | 13 | TRACK_B = ["e51eebcc", "a1d90b5c", "5d2d6c55", "cfa93e08"] 14 | 15 | 16 | @MyExplorer 17 | def explorer(launcher): 18 | launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="learnlab") 19 | 20 | # Reproduce results from MDX competition Track A 21 | # This trains the first round of models. Once this is trained, 22 | # you will need to schedule `mdx_refine`. 23 | for sig in TRACK_B: 24 | while sig is not None: 25 | xp = main.get_xp_from_sig(sig) 26 | sig = xp.cfg.continue_from 27 | 28 | for dset in ["extra44", "extra_test"]: 29 | sub = launcher.bind(xp.argv, dset=dset) 30 | sub() 31 | if dset == "extra_test": 32 | sub({"quant.diffq": 1e-4}) 33 | sub({"quant.diffq": 3e-4}) 34 | -------------------------------------------------------------------------------- /demucs/grids/mdx_refine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Main training for the Track A MDX models. 8 | """ 9 | 10 | from ..train import main 11 | from ._explorers import MyExplorer 12 | from .mdx import TRACK_A 13 | 14 | 15 | @MyExplorer 16 | def explorer(launcher): 17 | launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="learnlab") 18 | 19 | # Reproduce results from MDX competition Track A 20 | # WARNING: all the experiments in the `mdx` grid must have completed. 21 | for sig in TRACK_A: 22 | xp = main.get_xp_from_sig(sig) 23 | launcher(xp.argv) 24 | for diffq in [1e-4, 3e-4]: 25 | xp_src = main.get_xp_from_sig(xp.cfg.continue_from) 26 | q_argv = [f"quant.diffq={diffq}"] 27 | actual_src = main.get_xp(xp_src.argv + q_argv) 28 | actual_src.link.load() 29 | assert len(actual_src.link.history) == actual_src.cfg.epochs 30 | argv = xp.argv + q_argv + [f'continue_from="{actual_src.sig}"'] 31 | launcher(argv) 32 | -------------------------------------------------------------------------------- /demucs/spec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Conveniance wrapper to perform STFT and iSTFT""" 7 | 8 | import torch as th 9 | 10 | 11 | def spectro(x, n_fft=512, hop_length=None, pad=0): 12 | *other, length = x.shape 13 | x = x.reshape(-1, length) 14 | z = th.stft( 15 | x, 16 | n_fft * (1 + pad), 17 | hop_length or n_fft // 4, 18 | window=th.hann_window(n_fft).to(x), 19 | win_length=n_fft, 20 | normalized=True, 21 | center=True, 22 | return_complex=True, 23 | pad_mode="reflect", 24 | ) 25 | _, freqs, frame = z.shape 26 | return z.view(*other, freqs, frame) 27 | 28 | 29 | def ispectro(z, hop_length=None, length=None, pad=0): 30 | *other, freqs, frames = z.shape 31 | n_fft = 2 * freqs - 2 32 | z = z.view(-1, freqs, frames) 33 | win_length = n_fft // (1 + pad) 34 | x = th.istft( 35 | z, 36 | n_fft, 37 | hop_length, 38 | window=th.hann_window(win_length).to(z.real), 39 | win_length=win_length, 40 | normalized=True, 41 | length=length, 42 | center=True, 43 | ) 44 | _, length = x.shape 45 | return x.view(*other, length) 46 | -------------------------------------------------------------------------------- /demucs/ema.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Inspired from https://github.com/rwightman/pytorch-image-models 8 | from contextlib import contextmanager 9 | 10 | import torch 11 | 12 | from .states import swap_state 13 | 14 | 15 | class ModelEMA: 16 | """ 17 | Perform EMA on a model. You can switch to the EMA weights temporarily 18 | with the `swap` method. 19 | 20 | ema = ModelEMA(model) 21 | with ema.swap(): 22 | # compute valid metrics with averaged model. 23 | """ 24 | 25 | def __init__(self, model, decay=0.9999, unbias=True, device="cpu"): 26 | self.decay = decay 27 | self.model = model 28 | self.state = {} 29 | self.count = 0 30 | self.device = device 31 | self.unbias = unbias 32 | 33 | self._init() 34 | 35 | def _init(self): 36 | for key, val in self.model.state_dict().items(): 37 | if val.dtype != torch.float32: 38 | continue 39 | device = self.device or val.device 40 | if key not in self.state: 41 | self.state[key] = val.detach().to(device, copy=True) 42 | 43 | def update(self): 44 | if self.unbias: 45 | self.count = self.count * self.decay + 1 46 | w = 1 / self.count 47 | else: 48 | w = 1 - self.decay 49 | for key, val in self.model.state_dict().items(): 50 | if val.dtype != torch.float32: 51 | continue 52 | device = self.device or val.device 53 | self.state[key].mul_(1 - w) 54 | self.state[key].add_(val.detach().to(device), alpha=w) 55 | 56 | @contextmanager 57 | def swap(self): 58 | with swap_state(self.model, self.state): 59 | yield 60 | 61 | def state_dict(self): 62 | return {"state": self.state, "count": self.count} 63 | 64 | def load_state_dict(self, state): 65 | self.count = state["count"] 66 | for k, v in state["state"].items(): 67 | self.state[k].copy_(v) 68 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 
3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ main ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ main ] 20 | schedule: 21 | - cron: '25 3 * * 5' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | 28 | strategy: 29 | fail-fast: false 30 | matrix: 31 | language: [ 'python' ] 32 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] 33 | # Learn more: 34 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed 35 | 36 | steps: 37 | - name: Checkout repository 38 | uses: actions/checkout@v2 39 | 40 | # Initializes the CodeQL tools for scanning. 41 | - name: Initialize CodeQL 42 | uses: github/codeql-action/init@v1 43 | with: 44 | languages: ${{ matrix.language }} 45 | # If you wish to specify custom queries, you can do so here or in a config file. 46 | # By default, queries listed here will override any specified in a config file. 47 | # Prefix the list here with "+" to use these queries and those in the config file. 48 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 49 | 50 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 51 | # If this step fails, then you should remove it and run the build manually (see below) 52 | - name: Autobuild 53 | uses: github/codeql-action/autobuild@v1 54 | 55 | # ℹ️ Command-line programs to run using the OS shell. 56 | # 📚 https://git.io/JvXDl 57 | 58 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 59 | # and modify them (or add more) to build your code if your project 60 | # uses a compiled language 61 | 62 | #- run: | 63 | # make bootstrap 64 | # make release 65 | 66 | - name: Perform CodeQL Analysis 67 | uses: github/codeql-action/analyze@v1 68 | -------------------------------------------------------------------------------- /demucs/pretrained.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Loading pretrained models. 
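
Typical use, as a sketch (checkpoints listed in `remote/files.txt` are fetched from
`ROOT_URL` unless a local `--repo` folder is given):

    from demucs.pretrained import get_model
    model = get_model("mdx_extra_q")  # a bag of models name or a single signature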
7 | """ 8 | 9 | import logging 10 | import typing as tp 11 | from pathlib import Path 12 | 13 | from dora.log import fatal 14 | 15 | from .hdemucs import HDemucs 16 | from .repo import ( 17 | AnyModelRepo, 18 | BagOnlyRepo, 19 | LocalRepo, # noqa 20 | ModelLoadingError, 21 | ModelOnlyRepo, 22 | RemoteRepo, 23 | ) 24 | 25 | logger = logging.getLogger(__name__) 26 | ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/" 27 | REMOTE_ROOT = Path(__file__).parent / "remote" 28 | 29 | SOURCES = ["drums", "bass", "other", "vocals"] 30 | 31 | 32 | def demucs_unittest(): 33 | model = HDemucs(channels=4, sources=SOURCES) 34 | return model 35 | 36 | 37 | def add_model_flags(parser): 38 | group = parser.add_mutually_exclusive_group(required=False) 39 | group.add_argument("-s", "--sig", help="Locally trained XP signature.") 40 | group.add_argument( 41 | "-n", 42 | "--name", 43 | default="mdx_extra_q", 44 | help="Pretrained model name or signature. Default is mdx_extra_q.", 45 | ) 46 | parser.add_argument( 47 | "--repo", 48 | type=Path, 49 | help="Folder containing all pre-trained models for use with -n.", 50 | ) 51 | 52 | 53 | def get_model(name: str, repo: tp.Optional[Path] = None): 54 | """`name` must be a bag of models name or a pretrained signature 55 | from the remote AWS model repo or the specified local repo if `repo` is not None. 56 | """ 57 | if name == "demucs_unittest": 58 | return demucs_unittest() 59 | model_repo: ModelOnlyRepo 60 | if repo is None: 61 | remote_files = [ 62 | line.strip() 63 | for line in (REMOTE_ROOT / "files.txt").read_text().split("\n") 64 | if line.strip() 65 | ] 66 | model_repo = RemoteRepo(ROOT_URL, remote_files) 67 | bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo) 68 | else: 69 | if not repo.is_dir(): 70 | fatal(f"{repo} must exist and be a directory.") 71 | model_repo = LocalRepo(repo) 72 | bag_repo = BagOnlyRepo(repo, model_repo) 73 | any_repo = AnyModelRepo(model_repo, bag_repo) 74 | return any_repo.get_model(name) 75 | 76 | 77 | def get_model_from_args(args): 78 | """ 79 | Load local model package or pre-trained model. 
80 | """ 81 | return get_model(name=args.name, repo=args.repo) 82 | -------------------------------------------------------------------------------- /conf/infer_config.yaml: -------------------------------------------------------------------------------- 1 | checkpoint: null 2 | samplerate: 44100 3 | model: HDemucs 4 | devices: 1 5 | 6 | infer_audio_folder_path: null 7 | infer_audio_ext: 'wav' 8 | infer_segment: 11 9 | infer_samplerate: ${samplerate} 10 | 11 | dataloader: 12 | inference: 13 | num_workers: 1 14 | 15 | quant: # quantization hyper params 16 | diffq: # diffq penalty, typically 1e-4 or 3e-4 17 | qat: # use QAT with a fixed number of bits (not as good as diffq) 18 | min_size: 0.2 19 | group_size: 8 20 | 21 | trainer: 22 | devices: ${devices} 23 | # Pick only available GPUs 24 | auto_select_gpus: True 25 | accelerator: auto 26 | precision: 32 27 | check_val_every_n_epoch: 1 28 | resume_from_checkpoint: ${checkpoint} 29 | 30 | #logger: wandb 31 | logger: tensorboard 32 | 33 | wandb: 34 | # Optional 35 | #entity: yourusername 36 | 37 | dset: 38 | train: 39 | shift: 1 40 | 41 | data_augmentation: false 42 | demucs: # see demucs/demucs.py for a detailed description 43 | # Channels 44 | audio_channels: 2 45 | channels: 64 46 | growth: 2 47 | # Main structure 48 | depth: 6 49 | rewrite: true 50 | lstm_layers: 0 51 | # Convolutions 52 | kernel_size: 8 53 | stride: 4 54 | context: 1 55 | # Activations 56 | gelu: true 57 | glu: true 58 | # Normalization 59 | norm_groups: 4 60 | norm_starts: 4 61 | # DConv residual branch 62 | dconv_depth: 2 63 | dconv_mode: 1 # 1 = branch in encoder, 2 = in decoder, 3 = in both. 64 | dconv_comp: 4 65 | dconv_attn: 4 66 | dconv_lstm: 4 67 | dconv_init: 1e-4 68 | # Pre/post treatment 69 | resample: true 70 | normalize: false 71 | # Weight init 72 | rescale: 0.1 73 | 74 | hdemucs: # see demucs/hdemucs.py for a detailed description 75 | # Channels 76 | audio_channels: 2 77 | channels: 48 78 | channels_time: 79 | growth: 2 80 | # STFT 81 | nfft: 4096 82 | wiener_iters: 0 83 | end_iters: 0 84 | wiener_residual: false 85 | cac: true 86 | # Main structure 87 | depth: 6 88 | rewrite: true 89 | hybrid: true 90 | hybrid_old: false 91 | # Frequency Branch 92 | multi_freqs: [] 93 | multi_freqs_depth: 3 94 | freq_emb: 0.2 95 | emb_scale: 10 96 | emb_smooth: true 97 | # Convolutions 98 | kernel_size: 8 99 | stride: 4 100 | time_stride: 2 101 | context: 1 102 | context_enc: 0 103 | # normalization 104 | norm_starts: 4 105 | norm_groups: 4 106 | # DConv residual branch 107 | dconv_mode: 1 108 | dconv_depth: 2 109 | dconv_comp: 4 110 | dconv_attn: 4 111 | dconv_lstm: 4 112 | dconv_init: 1e-3 113 | # Weight init 114 | rescale: 0.1 115 | 116 | -------------------------------------------------------------------------------- /demucs/grids/_explorers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
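# Explorer shared by the grid definitions in this folder (mdx.py, mdx_extra.py,
# mdx_refine.py). It tells `dora grid` which train/valid/test metrics to display
# in the tracking table and how to aggregate them from each experiment's history.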
6 | import treetable as tt 7 | from dora import Explorer 8 | 9 | 10 | class MyExplorer(Explorer): 11 | test_metrics = ["nsdr"] 12 | 13 | def get_grid_metrics(self): 14 | """Return the metrics that should be displayed in the tracking table.""" 15 | return [ 16 | tt.group( 17 | "train", 18 | [ 19 | tt.leaf("epoch"), 20 | tt.leaf("reco", ".3f"), 21 | ], 22 | align=">", 23 | ), 24 | tt.group( 25 | "valid", 26 | [ 27 | tt.leaf("penalty", ".1f"), 28 | tt.leaf("ms", ".1f"), 29 | tt.leaf("reco", ".2%"), 30 | tt.leaf("breco", ".2%"), 31 | tt.leaf("b_nsdr", ".2f"), 32 | tt.leaf("b_nsdr_drums", ".2f"), 33 | tt.leaf("b_nsdr_bass", ".2f"), 34 | tt.leaf("b_nsdr_other", ".2f"), 35 | tt.leaf("b_nsdr_vocals", ".2f"), 36 | ], 37 | align=">", 38 | ), 39 | tt.group( 40 | "test", [tt.leaf(name, ".2f") for name in self.test_metrics], align=">" 41 | ), 42 | ] 43 | 44 | def process_history(self, history): 45 | train = { 46 | "epoch": len(history), 47 | } 48 | valid = {} 49 | test = {} 50 | best_v_main = float("inf") 51 | breco = float("inf") 52 | for metrics in history: 53 | train.update(metrics["train"]) 54 | valid.update(metrics["valid"]) 55 | if "main" in metrics["valid"]: 56 | best_v_main = min(best_v_main, metrics["valid"]["main"]["loss"]) 57 | valid["bmain"] = best_v_main 58 | valid["breco"] = min(breco, metrics["valid"]["reco"]) 59 | breco = valid["breco"] 60 | if ( 61 | metrics["valid"]["loss"] == metrics["valid"]["best"] 62 | or metrics["valid"].get("nsdr") == metrics["valid"]["best"] 63 | ): 64 | for k, v in metrics["valid"].items(): 65 | if k.startswith("reco_"): 66 | valid["b_" + k[len("reco_") :]] = v 67 | if k.startswith("nsdr"): 68 | valid[f"b_{k}"] = v 69 | if "test" in metrics: 70 | test.update(metrics["test"]) 71 | metrics = history[-1] 72 | return {"train": train, "valid": valid, "test": test} 73 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from pathlib import Path 3 | 4 | import hydra 5 | import pytorch_lightning as pl 6 | import torch 7 | import torchaudio 8 | from hydra import compose, initialize 9 | from hydra.utils import to_absolute_path 10 | from torch.utils.data import DataLoader, Dataset 11 | from torchaudio.functional import resample 12 | 13 | from demucs.demucs import Demucs 14 | from demucs.hdemucs import HDemucs 15 | from demucs.states import get_quantizer 16 | 17 | 18 | @hydra.main(config_path="conf", config_name="infer_config") 19 | def main(args): 20 | class InferDataset(Dataset): 21 | def __init__(self, audio_folder_path, audio_ext, sampling_rate): 22 | 23 | audiofolder = Path(audio_folder_path) # path of folder 24 | self.audio_path_list = list( 25 | audiofolder.glob(f"*.{audio_ext}") 26 | ) # path of audio 27 | self.sample_rate = sampling_rate 28 | 29 | self.audio_name = [] 30 | for i in self.audio_path_list: 31 | path = pathlib.PurePath(i) 32 | self.audio_name.append(path.name) 33 | 34 | def __len__(self): 35 | return len(self.audio_path_list) 36 | 37 | def __getitem__(self, idx): 38 | try: 39 | waveform, rate = torchaudio.load(self.audio_path_list[idx]) 40 | # return (torch.Tensor, int) 41 | 42 | if rate != self.sample_rate: 43 | waveform = resample(waveform, rate, self.sample_rate) 44 | # resample(waveform: torch.Tensor, orig_freq: int, new_freq: int) 45 | # return waveform tensor at the new frequency of dimension 46 | except: 47 | waveform = torch.tensor([[]]) 48 | rate = 0 49 | print(f"{self.audio_path_list[idx].name} 
is corrupted") 50 | audio_name = self.audio_name[idx] 51 | return waveform, audio_name 52 | 53 | if args.checkpoint == None: 54 | raise ValueError("Please enter the path for your model checkpoint") 55 | 56 | if args.infer_audio_folder_path == None: 57 | raise ValueError("Please enter the path for your inference audio folder") 58 | 59 | inference_set = InferDataset( 60 | to_absolute_path(args.infer_audio_folder_path), 61 | args.infer_audio_ext, 62 | args.infer_samplerate, 63 | ) 64 | inference_loader = DataLoader(inference_set, args.dataloader.inference.num_workers) 65 | 66 | if args.model == "Demucs": 67 | model = Demucs.load_from_checkpoint(to_absolute_path(args.checkpoint)) 68 | # call with pretrained model 69 | 70 | elif args.model == "HDemucs": 71 | model = HDemucs.load_from_checkpoint(to_absolute_path(args.checkpoint)) 72 | 73 | else: 74 | print("Invalid model, please choose Demucs or HDemucs") 75 | 76 | quantizer = get_quantizer(model, args.quant, model.optimizers) 77 | model.quantizer = quantizer # can use as self.quantizer in class Demucs 78 | 79 | trainer = pl.Trainer(**args.trainer) 80 | 81 | trainer.predict(model, dataloaders=inference_loader) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /demucs/repitch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Utility for on the fly pitch/tempo change for data augmentation.""" 7 | 8 | import random 9 | import subprocess as sp 10 | import tempfile 11 | 12 | import torch 13 | import torchaudio as ta 14 | 15 | from .audio import save_audio 16 | 17 | 18 | class RepitchedWrapper: 19 | """ 20 | Wrap a dataset to apply online change of pitch / tempo. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | dataset, 26 | proba=0.2, 27 | max_pitch=2, 28 | max_tempo=12, 29 | tempo_std=5, 30 | vocals=[3], 31 | same=True, 32 | ): 33 | self.dataset = dataset 34 | self.proba = proba 35 | self.max_pitch = max_pitch 36 | self.max_tempo = max_tempo 37 | self.tempo_std = tempo_std 38 | self.same = same 39 | self.vocals = vocals 40 | 41 | def __len__(self): 42 | return len(self.dataset) 43 | 44 | def __getitem__(self, index): 45 | streams = self.dataset[index] 46 | in_length = streams.shape[-1] 47 | out_length = int((1 - 0.01 * self.max_tempo) * in_length) 48 | 49 | if random.random() < self.proba: 50 | outs = [] 51 | for idx, stream in enumerate(streams): 52 | if idx == 0 or not self.same: 53 | delta_pitch = random.randint(-self.max_pitch, self.max_pitch) 54 | delta_tempo = random.gauss(0, self.tempo_std) 55 | delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo) 56 | stream = repitch( 57 | stream, delta_pitch, delta_tempo, voice=idx in self.vocals 58 | ) 59 | outs.append(stream[:, :out_length]) 60 | streams = torch.stack(outs) 61 | else: 62 | streams = streams[..., :out_length] 63 | return streams 64 | 65 | 66 | def repitch(wav, pitch, tempo, voice=False, quick=False, samplerate=44100): 67 | """ 68 | tempo is a relative delta in percentage, so tempo=10 means tempo at 110%! 69 | pitch is in semi tones. 
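For example, `pitch=2, tempo=10` shifts the audio up by two semitones and makes it roughly 10% faster.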
70 | Requires `soundstretch` to be installed, see 71 | https://www.surina.net/soundtouch/soundstretch.html 72 | """ 73 | infile = tempfile.NamedTemporaryFile(suffix=".wav") 74 | outfile = tempfile.NamedTemporaryFile(suffix=".wav") 75 | save_audio(wav, infile.name, samplerate, clip="clamp") 76 | command = [ 77 | "soundstretch", 78 | infile.name, 79 | outfile.name, 80 | f"-pitch={pitch}", 81 | f"-tempo={tempo:.6f}", 82 | ] 83 | if quick: 84 | command += ["-quick"] 85 | if voice: 86 | command += ["-speech"] 87 | try: 88 | sp.run(command, capture_output=True, check=True) 89 | except sp.CalledProcessError as error: 90 | raise RuntimeError( 91 | f"Could not change bpm because {error.stderr.decode('utf-8')}" 92 | ) 93 | wav, sr = ta.load(outfile.name) 94 | assert sr == samplerate 95 | return wav 96 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # for dataset 2 | import hashlib 3 | import json 4 | import os 5 | import sys 6 | from pathlib import Path 7 | 8 | import hydra 9 | import pytorch_lightning as pl 10 | import torch 11 | import torchaudio as ta 12 | import tqdm 13 | 14 | # library for Musdb dataset 15 | from AudioLoader.music.mss import MusdbHQ 16 | from hydra import compose, initialize 17 | from hydra.utils import to_absolute_path 18 | from omegaconf import OmegaConf 19 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 20 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger 21 | from torch.utils.data import DataLoader, Subset 22 | 23 | # library for loader() 24 | from torch.utils.data.distributed import DistributedSampler 25 | 26 | from demucs.demucs import Demucs 27 | from demucs.hdemucs import HDemucs 28 | from demucs.states import get_quantizer 29 | from demucs.svd import svd_penalty 30 | 31 | 32 | @hydra.main(config_path="conf", config_name="train_test_config") 33 | def main(args): 34 | args.data_root = to_absolute_path(args.data_root) 35 | 36 | test_set = MusdbHQ( 37 | root=args.dset.test.root, 38 | subset="test", 39 | sources=args.dset.test.sources, 40 | download=args.dset.test.download, 41 | segment=args.dset.test.segment, 42 | shift=args.dset.test.shift, 43 | normalize=args.dset.test.normalize, 44 | samplerate=args.dset.test.samplerate, 45 | channels=args.dset.test.channels, 46 | ext=args.dset.test.ext, 47 | ) 48 | 49 | test_loader = DataLoader( 50 | test_set, 51 | batch_size=args.dataloader.test.batch_size, 52 | shuffle=args.dataloader.test.shuffle, 53 | num_workers=args.dataloader.test.num_workers, 54 | drop_last=False, 55 | ) 56 | 57 | if args.model == "Demucs": 58 | model = Demucs( 59 | sources=args.sources, 60 | samplerate=args.samplerate, 61 | segment=4 * args.dset.train.segment, 62 | **args.demucs, 63 | args=args, 64 | ) 65 | model = model.load_from_checkpoint(to_absolute_path(args.resume_checkpoint)) 66 | 67 | elif args.model == "HDemucs": 68 | model = HDemucs( 69 | sources=args.sources, 70 | samplerate=args.samplerate, 71 | segment=4 * args.dset.train.segment, 72 | **args.hdemucs, 73 | args=args, 74 | ) 75 | model = model.load_from_checkpoint(to_absolute_path(args.resume_checkpoint)) 76 | 77 | else: 78 | print("Invalid model, please choose Demucs or HDemucs") 79 | 80 | quantizer = get_quantizer(model, args.quant, model.optimizers) 81 | model.quantizer = quantizer # can use as self.quantizer in class Demucs 82 | 83 | name = f"Testing_{args.checkpoint.filename}" 84 | # file name shown in tensorboard 
logger 85 | 86 | lr_monitor = LearningRateMonitor(logging_interval="step") 87 | 88 | if args.logger == "tensorboard": 89 | logger = TensorBoardLogger(save_dir=".", version=1, name=name) 90 | elif args.logger == "wandb": 91 | logger = WandbLogger(project="demucs_lightning_test", **args.wandb) 92 | else: 93 | raise Exception(f"Logger {args.logger} not implemented") 94 | 95 | trainer = pl.Trainer(**args.trainer, logger=logger) 96 | 97 | trainer.test(model, test_loader) 98 | 99 | 100 | if __name__ == "__main__": 101 | main() 102 | -------------------------------------------------------------------------------- /demucs/svd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Ways to make the model stronger.""" 7 | import random 8 | 9 | import torch 10 | 11 | 12 | def power_iteration(m, niters=1, bs=1): 13 | """This is the power method. batch size is used to try multiple starting point in parallel.""" 14 | assert m.dim() == 2 15 | assert m.shape[0] == m.shape[1] 16 | dim = m.shape[0] 17 | b = torch.randn(dim, bs, device=m.device, dtype=m.dtype) 18 | 19 | for _ in range(niters): 20 | n = m.mm(b) 21 | norm = n.norm(dim=0, keepdim=True) 22 | b = n / (1e-10 + norm) 23 | 24 | return norm.mean() 25 | 26 | 27 | # We need a shared RNG to make sure all the distributed worker will skip the penalty together, 28 | # as otherwise we wouldn't get any speed up. 29 | penalty_rng = random.Random(1234) 30 | 31 | 32 | def svd_penalty( 33 | model, 34 | min_size=0.1, 35 | dim=1, 36 | niters=2, 37 | powm=False, 38 | convtr=True, 39 | proba=1, 40 | conv_only=False, 41 | exact=False, 42 | bs=1, 43 | ): 44 | """ 45 | Penalty on the largest singular value for a layer. 46 | Args: 47 | - model: model to penalize 48 | - min_size: minimum size in MB of a layer to penalize. 49 | - dim: projection dimension for the svd_lowrank. Higher is better but slower. 50 | - niters: number of iterations in the algorithm used by svd_lowrank. 51 | - powm: use power method instead of lowrank SVD, my own experience 52 | is that it is both slower and less stable. 53 | - convtr: when True, differentiate between Conv and Transposed Conv. 54 | this is kept for compatibility with older experiments. 55 | - proba: probability to apply the penalty. 56 | - conv_only: only apply to conv and conv transposed, not LSTM 57 | (might not be reliable for other models than Demucs). 58 | - exact: use exact SVD (slow but useful at validation). 59 | - bs: batch_size for power method. 
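The result is divided by `proba`, so randomly skipping the penalty on some steps
(see `penalty_rng`) leaves its expected contribution to the loss unchanged; when a
step is skipped the function returns 0.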
60 | """ 61 | total = 0 62 | if penalty_rng.random() > proba: 63 | return 0.0 64 | 65 | for m in model.modules(): 66 | for name, p in m.named_parameters(recurse=False): 67 | if p.numel() / 2**18 < min_size: 68 | continue 69 | if convtr: 70 | if isinstance(m, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)): 71 | if p.dim() in [3, 4]: 72 | p = p.transpose(0, 1).contiguous() 73 | if p.dim() == 3: 74 | p = p.view(len(p), -1) 75 | elif p.dim() == 4: 76 | p = p.view(len(p), -1) 77 | elif p.dim() == 1: 78 | continue 79 | elif conv_only: 80 | continue 81 | assert p.dim() == 2, (name, p.shape) 82 | if exact: 83 | estimate = torch.svd(p, compute_uv=False)[1].pow(2).max() 84 | elif powm: 85 | a, b = p.shape 86 | if a < b: 87 | n = p.mm(p.t()) 88 | else: 89 | n = p.t().mm(p) 90 | estimate = power_iteration(n, niters, bs) 91 | else: 92 | estimate = torch.svd_lowrank(p, dim, niters)[1][0].pow(2) 93 | total += estimate 94 | return total / proba 95 | -------------------------------------------------------------------------------- /demucs/distrib.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Distributed training utilities. 7 | """ 8 | import logging 9 | import pickle 10 | 11 | import numpy as np 12 | import torch 13 | from dora import distrib as dora_distrib 14 | from torch.nn.parallel.distributed import DistributedDataParallel 15 | from torch.utils.data import DataLoader, Subset 16 | from torch.utils.data.distributed import DistributedSampler 17 | 18 | logger = logging.getLogger(__name__) 19 | rank = 0 20 | world_size = 1 21 | 22 | 23 | def init(): 24 | global rank, world_size 25 | if not torch.distributed.is_initialized(): 26 | dora_distrib.init() 27 | rank = dora_distrib.rank() 28 | world_size = dora_distrib.world_size() 29 | 30 | 31 | def average(metrics, count=1.0): 32 | if isinstance(metrics, dict): 33 | keys, values = zip(*sorted(metrics.items())) 34 | values = average(values, count) 35 | return dict(zip(keys, values)) 36 | if world_size == 1: 37 | return metrics 38 | tensor = torch.tensor(list(metrics) + [1], device="cuda", dtype=torch.float32) 39 | tensor *= count 40 | torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) 41 | return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist() 42 | 43 | 44 | def wrap(model): 45 | if world_size == 1: 46 | return model 47 | else: 48 | return DistributedDataParallel( 49 | model, 50 | # find_unused_parameters=True, 51 | device_ids=[torch.cuda.current_device()], 52 | output_device=torch.cuda.current_device(), 53 | ) 54 | 55 | 56 | def barrier(): 57 | if world_size > 1: 58 | torch.distributed.barrier() 59 | 60 | 61 | def share(obj=None, src=0): 62 | if world_size == 1: 63 | return obj 64 | size = torch.empty(1, device="cuda", dtype=torch.long) 65 | if rank == src: 66 | dump = pickle.dumps(obj) 67 | size[0] = len(dump) 68 | torch.distributed.broadcast(size, src=src) 69 | # size variable is now set to the length of pickled obj in all processes 70 | 71 | if rank == src: 72 | buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda() 73 | else: 74 | buffer = torch.empty(size[0].item(), device="cuda", dtype=torch.uint8) 75 | torch.distributed.broadcast(buffer, src=src) 76 | # buffer variable is now set to pickled obj in all processes 77 | 78 | if rank != src: 79 | obj = 
pickle.loads(buffer.cpu().numpy().tobytes()) 80 | logger.debug(f"Shared object of size {len(buffer)}") 81 | return obj 82 | 83 | 84 | def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs): 85 | """ 86 | Create a dataloader properly in case of distributed training. 87 | If a gradient is going to be computed you must set `shuffle=True`. 88 | """ 89 | if world_size == 1: 90 | return klass(dataset, *args, shuffle=shuffle, **kwargs) 91 | 92 | if shuffle: 93 | # train means we will compute backward, we use DistributedSampler 94 | sampler = DistributedSampler(dataset) 95 | # We ignore shuffle, DistributedSampler already shuffles 96 | return klass(dataset, *args, **kwargs, sampler=sampler) 97 | else: 98 | # We make a manual shard, as DistributedSampler otherwise replicate some examples 99 | dataset = Subset(dataset, list(range(rank, len(dataset), world_size))) 100 | return klass(dataset, *args, shuffle=shuffle, **kwargs) 101 | -------------------------------------------------------------------------------- /demucs/augment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Data augmentations. 7 | """ 8 | 9 | import random 10 | 11 | import torch as th 12 | from torch import nn 13 | 14 | 15 | class Shift(nn.Module): 16 | """ 17 | Randomly shift audio in time by up to `shift` samples. 18 | """ 19 | 20 | def __init__(self, shift=8192, same=False): 21 | super().__init__() 22 | self.shift = shift 23 | self.same = same 24 | 25 | def forward(self, wav): 26 | batch, sources, channels, time = wav.size() 27 | length = time - self.shift 28 | if self.shift > 0: 29 | if not self.training: 30 | wav = wav[..., :length] 31 | else: 32 | srcs = 1 if self.same else sources 33 | offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device) 34 | offsets = offsets.expand(-1, sources, channels, -1) 35 | indexes = th.arange(length, device=wav.device) 36 | wav = wav.gather(3, indexes + offsets) 37 | return wav 38 | 39 | 40 | class FlipChannels(nn.Module): 41 | """ 42 | Flip left-right channels. 43 | """ 44 | 45 | def forward(self, wav): 46 | batch, sources, channels, time = wav.size() 47 | if self.training and wav.size(2) == 2: 48 | left = th.randint(2, (batch, sources, 1, 1), device=wav.device) 49 | left = left.expand(-1, -1, -1, time) 50 | right = 1 - left 51 | wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2) 52 | return wav 53 | 54 | 55 | class FlipSign(nn.Module): 56 | """ 57 | Random sign flip. 58 | """ 59 | 60 | def forward(self, wav): 61 | batch, sources, channels, time = wav.size() 62 | if self.training: 63 | signs = th.randint( 64 | 2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32 65 | ) 66 | wav = wav * (2 * signs - 1) 67 | return wav 68 | 69 | 70 | class Remix(nn.Module): 71 | """ 72 | Shuffle sources to make new mixes. 73 | """ 74 | 75 | def __init__(self, proba=1, group_size=4): 76 | """ 77 | Shuffle sources within one batch. 78 | Each batch is divided into groups of size `group_size` and shuffling is done within 79 | each group separatly. This allow to keep the same probability distribution no matter 80 | the number of GPUs. Without this grouping, using more GPUs would lead to a higher 81 | probability of keeping two sources from the same track together which can impact 82 | performance. 
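For example, with `group_size=4` an example can end up with drums, bass, other and
vocals taken from up to four different items of its group.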
83 | """ 84 | super().__init__() 85 | self.proba = proba 86 | self.group_size = group_size 87 | 88 | def forward(self, wav): 89 | batch, streams, channels, time = wav.size() 90 | device = wav.device 91 | 92 | if self.training and random.random() < self.proba: 93 | group_size = self.group_size or batch 94 | if batch % group_size != 0: 95 | raise ValueError( 96 | f"Batch size {batch} must be divisible by group size {group_size}" 97 | ) 98 | groups = batch // group_size 99 | wav = wav.view(groups, group_size, streams, channels, time) 100 | permutations = th.argsort( 101 | th.rand(groups, group_size, streams, 1, 1, device=device), dim=1 102 | ) 103 | wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time)) 104 | wav = wav.view(batch, streams, channels, time) 105 | return wav 106 | 107 | 108 | class Scale(nn.Module): 109 | def __init__(self, proba=1.0, min=0.25, max=1.25): 110 | super().__init__() 111 | self.proba = proba 112 | self.min = min 113 | self.max = max 114 | 115 | def forward(self, wav): 116 | batch, streams, channels, time = wav.size() 117 | device = wav.device 118 | if self.training and random.random() < self.proba: 119 | scales = th.empty(batch, streams, 1, 1, device=device).uniform_( 120 | self.min, self.max 121 | ) 122 | wav *= scales 123 | return wav 124 | -------------------------------------------------------------------------------- /demucs/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import os 9 | import tempfile 10 | import typing as tp 11 | from collections import defaultdict 12 | from contextlib import contextmanager 13 | 14 | import torch 15 | from torch.nn import functional as F 16 | 17 | 18 | def unfold(a, kernel_size, stride): 19 | """Given input of size [*OT, T], output Tensor of size [*OT, F, K] 20 | with K the kernel size, by extracting frames with the given stride. 21 | 22 | This will pad the input so that `F = ceil(T / K)`. 23 | 24 | see https://github.com/pytorch/pytorch/issues/60466 25 | """ 26 | *shape, length = a.shape 27 | n_frames = math.ceil(length / stride) 28 | tgt_length = (n_frames - 1) * stride + kernel_size 29 | a = F.pad(a, (0, tgt_length - length)) 30 | strides = list(a.stride()) 31 | assert strides[-1] == 1, "data should be contiguous" 32 | strides = strides[:-1] + [stride, 1] 33 | return a.as_strided([*shape, n_frames, kernel_size], strides) 34 | 35 | 36 | def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]): 37 | """ 38 | Center trim `tensor` with respect to `reference`, along the last dimension. 39 | `reference` can also be a number, representing the length to trim to. 40 | If the size difference != 0 mod 2, the extra sample is removed on the right side. 41 | """ 42 | ref_size: int 43 | if isinstance(reference, torch.Tensor): 44 | ref_size = reference.size(-1) 45 | else: 46 | ref_size = reference 47 | delta = tensor.size(-1) - ref_size 48 | if delta < 0: 49 | raise ValueError("tensor must be larger than reference. 
" f"Delta is {delta}.") 50 | if delta: 51 | tensor = tensor[..., delta // 2 : -(delta - delta // 2)] 52 | return tensor 53 | 54 | 55 | def pull_metric(history: tp.List[dict], name: str): 56 | out = [] 57 | for metrics in history: 58 | metric = metrics 59 | for part in name.split("."): 60 | metric = metric[part] 61 | out.append(metric) 62 | return out 63 | 64 | 65 | def EMA(beta: float = 1): 66 | """ 67 | Exponential Moving Average callback. 68 | Returns a single function that can be called to repeatidly update the EMA 69 | with a dict of metrics. The callback will return 70 | the new averaged dict of metrics. 71 | 72 | Note that for `beta=1`, this is just plain averaging. 73 | """ 74 | fix: tp.Dict[str, float] = defaultdict(float) 75 | total: tp.Dict[str, float] = defaultdict(float) 76 | 77 | def _update(metrics: dict, weight: float = 1) -> dict: 78 | nonlocal total, fix 79 | for key, value in metrics.items(): 80 | total[key] = total[key] * beta + weight * float(value) 81 | fix[key] = fix[key] * beta + weight 82 | return {key: tot / fix[key] for key, tot in total.items()} 83 | 84 | return _update 85 | 86 | 87 | def sizeof_fmt(num: float, suffix: str = "B"): 88 | """ 89 | Given `num` bytes, return human readable size. 90 | Taken from https://stackoverflow.com/a/1094933 91 | """ 92 | for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: 93 | if abs(num) < 1024.0: 94 | return "%3.1f%s%s" % (num, unit, suffix) 95 | num /= 1024.0 96 | return "%.1f%s%s" % (num, "Yi", suffix) 97 | 98 | 99 | @contextmanager 100 | def temp_filenames(count: int, delete=True): 101 | names = [] 102 | try: 103 | for _ in range(count): 104 | names.append(tempfile.NamedTemporaryFile(delete=False).name) 105 | yield names 106 | finally: 107 | if delete: 108 | for name in names: 109 | os.unlink(name) 110 | 111 | 112 | class DummyPoolExecutor: 113 | class DummyResult: 114 | def __init__(self, func, *args, **kwargs): 115 | self.func = func 116 | self.args = args 117 | self.kwargs = kwargs 118 | 119 | def result(self): 120 | return self.func(*self.args, **self.kwargs) 121 | 122 | def __init__(self, workers=0): 123 | pass 124 | 125 | def submit(self, func, *args, **kwargs): 126 | return DummyPoolExecutor.DummyResult(func, *args, **kwargs) 127 | 128 | def __enter__(self): 129 | return self 130 | 131 | def __exit__(self, exc_type, exc_value, exc_tb): 132 | return 133 | -------------------------------------------------------------------------------- /demucs/states.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Utilities to save and load models. 
8 | """ 9 | import functools 10 | import hashlib 11 | import inspect 12 | import io 13 | import warnings 14 | from contextlib import contextmanager 15 | from pathlib import Path 16 | 17 | import torch 18 | from diffq import DiffQuantizer, UniformQuantizer, restore_quantized_state 19 | from omegaconf import OmegaConf 20 | 21 | 22 | def get_quantizer(model, args, optimizer=None): 23 | """Return the quantizer given the XP quantization args.""" 24 | quantizer = None 25 | if args.diffq: 26 | quantizer = DiffQuantizer( 27 | model, min_size=args.min_size, group_size=args.group_size 28 | ) 29 | if optimizer is not None: 30 | quantizer.setup_optimizer(optimizer) 31 | elif args.qat: 32 | quantizer = UniformQuantizer(model, bits=args.qat, min_size=args.min_size) 33 | return quantizer 34 | 35 | 36 | def load_model(path_or_package, strict=False): 37 | """Load a model from the given serialized model, either given as a dict (already loaded) 38 | or a path to a file on disk.""" 39 | if isinstance(path_or_package, dict): 40 | package = path_or_package 41 | elif isinstance(path_or_package, (str, Path)): 42 | with warnings.catch_warnings(): 43 | warnings.simplefilter("ignore") 44 | path = path_or_package 45 | package = torch.load(path, "cpu") 46 | else: 47 | raise ValueError(f"Invalid type for {path_or_package}.") 48 | 49 | klass = package["klass"] 50 | args = package["args"] 51 | kwargs = package["kwargs"] 52 | 53 | if strict: 54 | model = klass(*args, **kwargs) 55 | else: 56 | sig = inspect.signature(klass) 57 | for key in list(kwargs): 58 | if key not in sig.parameters: 59 | warnings.warn("Dropping inexistant parameter " + key) 60 | del kwargs[key] 61 | model = klass(*args, **kwargs) 62 | 63 | state = package["state"] 64 | 65 | set_state(model, state) 66 | return model 67 | 68 | 69 | def get_state(model, quantizer, half=False): 70 | """Get the state from a model, potentially with quantization applied. 71 | If `half` is True, model are stored as half precision, which shouldn't impact performance 72 | but half the state size.""" 73 | if quantizer is None: 74 | dtype = torch.half if half else None 75 | state = { 76 | k: p.data.to(device="cpu", dtype=dtype) 77 | for k, p in model.state_dict().items() 78 | } 79 | else: 80 | state = quantizer.get_quantized_state() 81 | state["__quantized"] = True 82 | return state 83 | 84 | 85 | def set_state(model, state, quantizer=None): 86 | """Set the state on a given model.""" 87 | if state.get("__quantized"): 88 | if quantizer is not None: 89 | quantizer.restore_quantized_state(model, state["quantized"]) 90 | else: 91 | restore_quantized_state(model, state) 92 | else: 93 | model.load_state_dict(state) 94 | return state 95 | 96 | 97 | def save_with_checksum(content, path): 98 | """Save the given value on disk, along with a sha256 hash. 
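The first 8 hex digits of the hash are inserted before the extension (e.g. `model.th`
becomes `model-<checksum>.th`), matching the `<signature>-<checksum>.th` naming used
in `remote/files.txt`.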
99 | Should be used with the output of either `serialize_model` or `get_state`.""" 100 | buf = io.BytesIO() 101 | torch.save(content, buf) 102 | sig = hashlib.sha256(buf.getvalue()).hexdigest()[:8] 103 | 104 | path = path.parent / (path.stem + "-" + sig + path.suffix) 105 | path.write_bytes(buf.getvalue()) 106 | 107 | 108 | def serialize_model(model, training_args, quantizer=None, half=True): 109 | args, kwargs = model._init_args_kwargs 110 | klass = model.__class__ 111 | 112 | state = get_state(model, quantizer, half) 113 | return { 114 | "klass": klass, 115 | "args": args, 116 | "kwargs": kwargs, 117 | "state": state, 118 | "training_args": OmegaConf.to_container(training_args, resolve=True), 119 | } 120 | 121 | 122 | def copy_state(state): 123 | return {k: v.cpu().clone() for k, v in state.items()} 124 | 125 | 126 | @contextmanager 127 | def swap_state(model, state): 128 | """ 129 | Context manager that swaps the state of a model, e.g: 130 | 131 | # model is in old state 132 | with swap_state(model, new_state): 133 | # model in new state 134 | # model back to old state 135 | """ 136 | old_state = copy_state(model.state_dict()) 137 | model.load_state_dict(state, strict=False) 138 | try: 139 | yield 140 | finally: 141 | model.load_state_dict(old_state) 142 | 143 | 144 | def capture_init(init): 145 | @functools.wraps(init) 146 | def __init__(self, *args, **kwargs): 147 | self._init_args_kwargs = (args, kwargs) 148 | init(self, *args, **kwargs) 149 | 150 | return __init__ 151 | -------------------------------------------------------------------------------- /demucs/repo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Represents a model repository, including pre-trained models and bags of models. 7 | A repo can either be the main remote repository stored in AWS, or a local repository 8 | with your own models. 
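A hypothetical usage of `copy_state` and `swap_state` defined above (the averaged weights and the validation step are placeholders): temporarily evaluate a model with an alternative state, then restore the original parameters.

```python
from torch import nn
from demucs.states import copy_state, swap_state

model = nn.Linear(4, 2)
# Placeholder "averaged" weights; in practice these could come from an EMA.
averaged_state = {k: v * 0.5 for k, v in copy_state(model.state_dict()).items()}

with swap_state(model, averaged_state):
    pass  # e.g. run validation with the averaged weights here
# The original parameters are restored as soon as the context exits.
```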
9 | """ 10 | 11 | import typing as tp 12 | from hashlib import sha256 13 | from pathlib import Path 14 | 15 | import torch 16 | import yaml 17 | 18 | from .apply import BagOfModels, Model 19 | from .states import load_model 20 | 21 | AnyModel = tp.Union[Model, BagOfModels] 22 | 23 | 24 | class ModelLoadingError(RuntimeError): 25 | pass 26 | 27 | 28 | def check_checksum(path: Path, checksum: str): 29 | sha = sha256() 30 | with open(path, "rb") as file: 31 | while True: 32 | buf = file.read(2**20) 33 | if not buf: 34 | break 35 | sha.update(buf) 36 | actual_checksum = sha.hexdigest()[: len(checksum)] 37 | if actual_checksum != checksum: 38 | raise ModelLoadingError( 39 | f"Invalid checksum for file {path}, " 40 | f"expected {checksum} but got {actual_checksum}" 41 | ) 42 | 43 | 44 | class ModelOnlyRepo: 45 | """Base class for all model only repos.""" 46 | 47 | def has_model(self, sig: str) -> bool: 48 | raise NotImplementedError() 49 | 50 | def get_model(self, sig: str) -> Model: 51 | raise NotImplementedError() 52 | 53 | 54 | class RemoteRepo(ModelOnlyRepo): 55 | def __init__(self, root_url: str, remote_files: tp.List[str]): 56 | if not root_url.endswith("/"): 57 | root_url += "/" 58 | self._models: tp.Dict[str, str] = {} 59 | for file in remote_files: 60 | sig, checksum = file.split(".")[0].split("-") 61 | assert sig not in self._models 62 | self._models[sig] = root_url + file 63 | 64 | def has_model(self, sig: str) -> bool: 65 | return sig in self._models 66 | 67 | def get_model(self, sig: str) -> Model: 68 | try: 69 | url = self._models[sig] 70 | except KeyError: 71 | raise ModelLoadingError( 72 | f"Could not find a pre-trained model with signature {sig}." 73 | ) 74 | pkg = torch.hub.load_state_dict_from_url( 75 | url, map_location="cpu", check_hash=True 76 | ) 77 | return load_model(pkg) 78 | 79 | 80 | class LocalRepo(ModelOnlyRepo): 81 | def __init__(self, root: Path): 82 | self.root = root 83 | self.scan() 84 | 85 | def scan(self): 86 | self._models = {} 87 | self._checksums = {} 88 | for file in self.root.iterdir(): 89 | if file.suffix == ".th": 90 | if "-" in file.stem: 91 | xp_sig, checksum = file.stem.split("-") 92 | self._checksums[xp_sig] = checksum 93 | else: 94 | xp_sig = file.stem 95 | if xp_sig in self._models: 96 | raise ModelLoadingError( 97 | f"Duplicate pre-trained model exist for signature {xp_sig}. " 98 | "Please delete all but one." 99 | ) 100 | self._models[xp_sig] = file 101 | 102 | def has_model(self, sig: str) -> bool: 103 | return sig in self._models 104 | 105 | def get_model(self, sig: str) -> Model: 106 | try: 107 | file = self._models[sig] 108 | except KeyError: 109 | raise ModelLoadingError( 110 | f"Could not find pre-trained model with signature {sig}." 111 | ) 112 | if sig in self._checksums: 113 | check_checksum(file, self._checksums[sig]) 114 | return load_model(file) 115 | 116 | 117 | class BagOnlyRepo: 118 | """Handles only YAML files containing bag of models, leaving the actual 119 | model loading to some Repo. 
120 | """ 121 | 122 | def __init__(self, root: Path, model_repo: ModelOnlyRepo): 123 | self.root = root 124 | self.model_repo = model_repo 125 | self.scan() 126 | 127 | def scan(self): 128 | self._bags = {} 129 | for file in self.root.iterdir(): 130 | if file.suffix == ".yaml": 131 | self._bags[file.stem] = file 132 | 133 | def has_model(self, name: str) -> bool: 134 | return name in self._bags 135 | 136 | def get_model(self, name: str) -> BagOfModels: 137 | try: 138 | yaml_file = self._bags[name] 139 | except KeyError: 140 | raise ModelLoadingError( 141 | f"{name} is neither a single pre-trained model or " "a bag of models." 142 | ) 143 | bag = yaml.safe_load(open(yaml_file)) 144 | signatures = bag["models"] 145 | models = [self.model_repo.get_model(sig) for sig in signatures] 146 | weights = bag.get("weights") 147 | segment = bag.get("segment") 148 | return BagOfModels(models, weights, segment) 149 | 150 | 151 | class AnyModelRepo: 152 | def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo): 153 | self.model_repo = model_repo 154 | self.bag_repo = bag_repo 155 | 156 | def has_model(self, name_or_sig: str) -> bool: 157 | return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model( 158 | name_or_sig 159 | ) 160 | 161 | def get_model(self, name_or_sig: str) -> AnyModel: 162 | if self.model_repo.has_model(name_or_sig): 163 | return self.model_repo.get_model(name_or_sig) 164 | else: 165 | return self.bag_repo.get_model(name_or_sig) 166 | -------------------------------------------------------------------------------- /demucs/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Main training script entry point""" 8 | 9 | import logging 10 | import os 11 | import sys 12 | from pathlib import Path 13 | 14 | import hydra 15 | import torch 16 | from dora import hydra_main 17 | from hydra.core.global_hydra import GlobalHydra 18 | from omegaconf import OmegaConf 19 | from torch.utils.data import ConcatDataset 20 | 21 | from . 
import distrib 22 | from .demucs import Demucs 23 | from .hdemucs import HDemucs 24 | from .repitch import RepitchedWrapper 25 | from .solver import Solver 26 | from .wav import get_musdb_wav_datasets, get_wav_datasets 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | def get_model(args): 32 | extra = { 33 | "sources": list(args.dset.sources), 34 | "audio_channels": args.dset.channels, 35 | "samplerate": args.dset.samplerate, 36 | "segment": args.model_segment or 4 * args.dset.segment, 37 | } 38 | klass = {"demucs": Demucs, "hdemucs": HDemucs}[args.model] 39 | kw = OmegaConf.to_container(getattr(args, args.model), resolve=True) 40 | model = klass(**extra, **kw) 41 | return model 42 | 43 | 44 | def get_solver(args, model_only=False): 45 | distrib.init() 46 | 47 | torch.manual_seed(args.seed) 48 | model = get_model(args) 49 | if args.misc.show: 50 | logger.info(model) 51 | mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20 52 | logger.info("Size: %.1f MB", mb) 53 | if hasattr(model, "valid_length"): 54 | field = model.valid_length(1) 55 | logger.info("Field: %.1f ms", field / args.dset.samplerate * 1000) 56 | sys.exit(0) 57 | 58 | # torch also initialize cuda seed if available 59 | if torch.cuda.is_available(): 60 | model.cuda() 61 | 62 | # optimizer 63 | if args.optim.optim == "adam": 64 | optimizer = torch.optim.Adam( 65 | model.parameters(), 66 | lr=args.optim.lr, 67 | betas=(args.optim.momentum, args.optim.beta2), 68 | weight_decay=args.optim.weight_decay, 69 | ) 70 | elif args.optim.optim == "adamw": 71 | optimizer = torch.optim.AdamW( 72 | model.parameters(), 73 | lr=args.optim.lr, 74 | betas=(args.optim.momentum, args.optim.beta2), 75 | weight_decay=args.optim.weight_decay, 76 | ) 77 | 78 | assert args.batch_size % distrib.world_size == 0 79 | args.batch_size //= distrib.world_size 80 | 81 | if model_only: 82 | return Solver(None, model, optimizer, args) 83 | 84 | train_set, valid_set = get_musdb_wav_datasets(args.dset) 85 | if args.dset.wav: 86 | extra_train_set, extra_valid_set = get_wav_datasets(args.dset) 87 | train_set = ConcatDataset([train_set, extra_train_set]) 88 | valid_set = ConcatDataset([valid_set, extra_valid_set]) 89 | 90 | if args.augment.repitch.proba: 91 | vocals = [] 92 | if "vocals" in args.dset.sources: 93 | vocals.append(args.dset.sources.index("vocals")) 94 | else: 95 | logger.warning("No vocal source found") 96 | if args.augment.repitch.proba: 97 | train_set = RepitchedWrapper( 98 | train_set, vocals=vocals, **args.augment.repitch 99 | ) 100 | 101 | logger.info("train/valid set size: %d %d", len(train_set), len(valid_set)) 102 | train_loader = distrib.loader( 103 | train_set, 104 | batch_size=args.batch_size, 105 | shuffle=True, 106 | num_workers=args.misc.num_workers, 107 | drop_last=True, 108 | ) 109 | if args.dset.full_cv: 110 | valid_loader = distrib.loader( 111 | valid_set, batch_size=1, shuffle=False, num_workers=args.misc.num_workers 112 | ) 113 | else: 114 | valid_loader = distrib.loader( 115 | valid_set, 116 | batch_size=args.batch_size, 117 | shuffle=False, 118 | num_workers=args.misc.num_workers, 119 | drop_last=True, 120 | ) 121 | loaders = {"train": train_loader, "valid": valid_loader} 122 | 123 | # Construct Solver 124 | torch.save(loaders, "loaders.pt") 125 | return Solver(loaders, model, optimizer, args) 126 | 127 | 128 | def get_solver_from_sig(sig, model_only=False): 129 | inst = GlobalHydra.instance() 130 | hyd = None 131 | if inst.is_initialized(): 132 | hyd = inst.hydra 133 | inst.clear() 134 | xp = 
main.get_xp_from_sig(sig) 135 | if hyd is not None: 136 | inst.clear() 137 | inst.initialize(hyd) 138 | 139 | with xp.enter(stack=True): 140 | return get_solver(xp.cfg, model_only) 141 | 142 | 143 | @hydra_main(config_path="../conf", config_name="config") 144 | def main(args): 145 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 146 | global __file__ 147 | __file__ = hydra.utils.to_absolute_path(__file__) 148 | for attr in ["musdb", "wav", "metadata"]: 149 | val = getattr(args.dset, attr) 150 | if val is not None: 151 | setattr(args.dset, attr, hydra.utils.to_absolute_path(val)) 152 | 153 | os.environ["OMP_NUM_THREADS"] = "1" 154 | os.environ["MKL_NUM_THREADS"] = "1" 155 | 156 | if args.misc.verbose: 157 | logger.setLevel(logging.DEBUG) 158 | 159 | logger.info("For logs, checkpoints and samples check %s", os.getcwd()) 160 | logger.debug(args) 161 | from dora import get_xp 162 | 163 | logger.debug(get_xp().cfg) 164 | 165 | solver = get_solver(args) 166 | solver.train() 167 | 168 | 169 | if "_DORA_TEST_PATH" in os.environ: 170 | main.dora.dir = Path(os.environ["_DORA_TEST_PATH"]) 171 | 172 | 173 | if __name__ == "__main__": 174 | main() 175 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # for dataset 2 | import hashlib 3 | import json 4 | import os 5 | 6 | # library for Musdb dataset 7 | import sys 8 | from pathlib import Path 9 | 10 | import hydra 11 | import pytorch_lightning as pl 12 | import torch 13 | import torchaudio as ta 14 | import tqdm 15 | from hydra import compose, initialize 16 | from hydra.utils import to_absolute_path 17 | from omegaconf import OmegaConf 18 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 19 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger 20 | from pytorch_lightning.plugins import DDPPlugin 21 | from torch.utils.data import DataLoader, Subset 22 | 23 | # library for loader() 24 | from torch.utils.data.distributed import DistributedSampler 25 | 26 | from demucs.demucs import Demucs 27 | from demucs.hdemucs import HDemucs 28 | from demucs.states import get_quantizer 29 | from demucs.svd import svd_penalty 30 | 31 | sys.path.insert(0, "/workspace/helen/AudioLoader") 32 | from AudioLoader.music.mss import MusdbHQ 33 | 34 | 35 | @hydra.main(config_path="conf", config_name="train_test_config") 36 | def main(args): 37 | args.data_root = to_absolute_path(args.data_root) 38 | train_set = MusdbHQ( 39 | root=args.dset.train.root, 40 | subset="training", 41 | sources=["drums", "bass", "other", "vocals"], 42 | # have to be 4 sourcse, to make mix in training_step #mix = sources.sum(dim=1) 43 | download=args.dset.train.download, 44 | segment=args.dset.train.segment, 45 | shift=args.dset.train.shift, 46 | normalize=args.dset.train.normalize, 47 | samplerate=args.dset.train.samplerate, 48 | channels=args.dset.train.channels, 49 | ext=args.dset.train.ext, 50 | ) 51 | 52 | valid_set = MusdbHQ( 53 | root=args.dset.valid.root, 54 | subset="validation", 55 | sources=args.dset.valid.sources, 56 | download=args.dset.valid.download, 57 | segment=args.dset.valid.segment, 58 | shift=args.dset.valid.shift, 59 | normalize=args.dset.valid.normalize, 60 | samplerate=args.dset.valid.samplerate, 61 | channels=args.dset.valid.channels, 62 | ext=args.dset.valid.ext, 63 | ) 64 | 65 | test_set = MusdbHQ( 66 | root=args.dset.test.root, 67 | subset="test", 68 | sources=args.dset.test.sources, 69 | 
download=args.dset.test.download, 70 | segment=args.dset.test.segment, 71 | shift=args.dset.test.shift, 72 | normalize=args.dset.test.normalize, 73 | samplerate=args.dset.test.samplerate, 74 | channels=args.dset.test.channels, 75 | ext=args.dset.test.ext, 76 | ) 77 | 78 | train_loader = DataLoader( 79 | train_set, 80 | batch_size=args.dataloader.train.batch_size, 81 | shuffle=args.dataloader.train.shuffle, 82 | num_workers=args.dataloader.train.num_workers, 83 | drop_last=True, 84 | ) 85 | 86 | valid_loader = DataLoader( 87 | valid_set, 88 | batch_size=args.dataloader.valid.batch_size, 89 | shuffle=args.dataloader.valid.shuffle, 90 | num_workers=args.dataloader.valid.num_workers, 91 | drop_last=False, 92 | ) 93 | 94 | test_loader = DataLoader( 95 | test_set, 96 | batch_size=args.dataloader.test.batch_size, 97 | shuffle=args.dataloader.test.shuffle, 98 | num_workers=args.dataloader.test.num_workers, 99 | drop_last=False, 100 | ) 101 | 102 | if args.model == "Demucs": 103 | model = Demucs( 104 | sources=args.sources, 105 | samplerate=args.samplerate, 106 | segment=4 * args.dset.train.segment, 107 | **args.demucs, 108 | args=args, 109 | ) 110 | 111 | elif args.model == "HDemucs": 112 | model = HDemucs( 113 | sources=args.sources, 114 | samplerate=args.samplerate, 115 | segment=4 * args.dset.train.segment, 116 | **args.hdemucs, 117 | args=args, 118 | ) 119 | 120 | else: 121 | print("Invalid model, please choose Demucs or HDemucs") 122 | 123 | quantizer = get_quantizer(model, args.quant, model.optimizers) 124 | model.quantizer = quantizer # can use as self.quantizer in class Demucs 125 | 126 | # print(f'optimizer = {model.optimizers}') 127 | 128 | # print(f'len train_set= {len(train_set)}') #len train_set= 18368 129 | # print(f'len valid_set= {len(valid_set)}') #len valid_set= 14 130 | 131 | checkpoint_callback = ModelCheckpoint( 132 | **args.checkpoint, auto_insert_metric_name=False 133 | ) 134 | # auto_insert_metric_name = False: won't refer the '/' in filename as path 135 | 136 | name = f"{args.model}_experiment_epoch={args.epochs}_augmentation={args.data_augmentation}" 137 | # file name shown in tensorboard logger 138 | 139 | lr_monitor = LearningRateMonitor(logging_interval="step") 140 | 141 | if args.logger == "tensorboard": 142 | logger = TensorBoardLogger(save_dir=".", version=1, name=name) 143 | elif args.logger == "wandb": 144 | logger = WandbLogger(project="demucs_lightning", **args.wandb) 145 | else: 146 | raise Exception(f"Logger {args.logger} not implemented") 147 | 148 | if ( 149 | args.trainer.resume_from_checkpoint 150 | ): # resume previous training when this is given 151 | args.trainer.resume_from_checkpoint = to_absolute_path( 152 | args.trainer.resume_from_checkpoint 153 | ) 154 | print(f"Resume training from {args.trainer.resume_from_checkpoint}") 155 | trainer = pl.Trainer( 156 | **args.trainer, 157 | callbacks=[checkpoint_callback, lr_monitor], 158 | strategy=DDPPlugin(find_unused_parameters=False), 159 | logger=logger, 160 | ) 161 | 162 | trainer.fit(model, train_loader, valid_loader) 163 | trainer.test(model, test_loader) 164 | 165 | 166 | if __name__ == "__main__": 167 | main() 168 | -------------------------------------------------------------------------------- /conf/train_test_config.yaml: -------------------------------------------------------------------------------- 1 | data_root: '../../MusicDataset/musdb18hq' 2 | resume_checkpoint: null 3 | download: false 4 | samplerate: 44100 5 | model: HDemucs 6 | data_augmentation: true 7 | segment: 11 8 | batch_size: 4 9 
| devices: -1 10 | epochs: 360 11 | seed: 42 12 | debug: false 13 | valid_apply: true 14 | flag: 15 | save_every: 16 | weights: [1., 1., 1., 1.] # weights over each source for the training/valid loss. 17 | continue_from: # continue from other XP, give the XP Dora signature. 18 | continue_pretrained: # signature of a pretrained XP, this cannot be a bag of models. 19 | pretrained_repo: # repo for pretrained model (default is official AWS) 20 | continue_best: true 21 | continue_opt: false 22 | sources: ['drums', 'bass', 'other', 'vocals'] 23 | 24 | max_batches: # limit the number of batches per epoch, useful for debugging 25 | # or if your dataset is gigantic. 26 | 27 | defaults: 28 | - _self_ 29 | - dset: musdb44 30 | - svd: default 31 | - variant: default 32 | - override hydra/hydra_logging: colorlog 33 | - override hydra/job_logging: colorlog 34 | 35 | dummy: 36 | dset: 37 | train: 38 | root: ${data_root} 39 | download: ${download} 40 | segment: ${segment} 41 | shift: 1 42 | normalize: True 43 | samplerate: ${samplerate} 44 | channels: 2 45 | ext: '.wav' 46 | valid: 47 | root: ${data_root} 48 | sources: ${sources} 49 | download: False 50 | segment: Null 51 | shift: Null 52 | normalize: True 53 | samplerate: ${samplerate} 54 | channels: 2 55 | ext: '.wav' 56 | test: 57 | root: ${data_root} 58 | sources: ${sources} 59 | download: False 60 | segment: Null 61 | shift: Null 62 | normalize: True 63 | samplerate: ${samplerate} 64 | channels: 2 65 | ext: '.wav' 66 | wav: # path to custom wav dataset 67 | 68 | 69 | test: 70 | save: False 71 | best: True 72 | every: 20 73 | split: true 74 | shifts: 1 75 | overlap: 0.25 76 | sdr: true 77 | metric: 'loss' # metric used for best model selection on the valid set, can also be nsdr 78 | nonhq: # path to non hq MusDB for evaluation 79 | 80 | dataloader: 81 | train: 82 | batch_size: ${batch_size} 83 | shuffle: True 84 | num_workers: 10 85 | valid: 86 | batch_size: 1 #ref for valid batch size: https://github.com/facebookresearch/demucs/blob/cb1d773a35ff889d25a5177b86c86c0ce8ba9ef3/demucs/train.py#L101 87 | shuffle: False 88 | num_workers: 10 89 | 90 | test: 91 | batch_size: 1 92 | shuffle: False 93 | num_workers: 2 94 | 95 | optim: 96 | lr: 3e-4 97 | momentum: 0.9 98 | beta2: 0.999 99 | loss: l1 # l1 or mse 100 | optim: adam 101 | weight_decay: 0 102 | clip_grad: 0 103 | 104 | 105 | augment: 106 | shift_same: false 107 | repitch: 108 | proba: 0.0 109 | max_tempo: 12 110 | remix: 111 | proba: 1 112 | group_size: 4 113 | scale: 114 | proba: 1 115 | min: 0.25 116 | max: 1.25 117 | flip: true 118 | 119 | misc: 120 | num_prints: 4 121 | show: false 122 | verbose: false 123 | 124 | # List of decay for EMA at batch or epoch level, e.g. 0.999. 125 | # Batch level EMA are kept on GPU for speed. 126 | ema: 127 | epoch: [] 128 | batch: [] 129 | 130 | model_segment: # override the segment parameter for the model, usually 4 times the training segment. 131 | 132 | demucs: # see demucs/demucs.py for a detailed description 133 | # Channels 134 | audio_channels: 2 135 | channels: 64 136 | growth: 2 137 | # Main structure 138 | depth: 6 139 | rewrite: true 140 | lstm_layers: 0 141 | # Convolutions 142 | kernel_size: 8 143 | stride: 4 144 | context: 1 145 | # Activations 146 | gelu: true 147 | glu: true 148 | # Normalization 149 | norm_groups: 4 150 | norm_starts: 4 151 | # DConv residual branch 152 | dconv_depth: 2 153 | dconv_mode: 1 # 1 = branch in encoder, 2 = in decoder, 3 = in both. 
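A side note on the `${...}` references used throughout this config (an illustrative sketch with only a subset of keys): OmegaConf resolves interpolations against the root of the composed config, so top-level values such as `samplerate` and `segment` propagate into the nested `dset` entries.

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "samplerate": 44100,
    "segment": 11,
    "dset": {"train": {"segment": "${segment}", "samplerate": "${samplerate}"}},
})
print(cfg.dset.train.segment)      # 11, resolved from the top-level `segment`
print(cfg.dset.train.samplerate)   # 44100, resolved from the top-level `samplerate`
```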
154 | dconv_comp: 4 155 | dconv_attn: 4 156 | dconv_lstm: 4 157 | dconv_init: 1e-4 158 | # Pre/post treatment 159 | resample: true 160 | normalize: false 161 | # Weight init 162 | rescale: 0.1 163 | 164 | hdemucs: # see demucs/hdemucs.py for a detailed description 165 | # Channels 166 | audio_channels: 2 167 | channels: 48 168 | channels_time: 169 | growth: 2 170 | # STFT 171 | nfft: 4096 172 | wiener_iters: 0 173 | end_iters: 0 174 | wiener_residual: false 175 | cac: true 176 | # Main structure 177 | depth: 6 178 | rewrite: true 179 | hybrid: true 180 | hybrid_old: false 181 | # Frequency Branch 182 | multi_freqs: [] 183 | multi_freqs_depth: 3 184 | freq_emb: 0.2 185 | emb_scale: 10 186 | emb_smooth: true 187 | # Convolutions 188 | kernel_size: 8 189 | stride: 4 190 | time_stride: 2 191 | context: 1 192 | context_enc: 0 193 | # normalization 194 | norm_starts: 4 195 | norm_groups: 4 196 | # DConv residual branch 197 | dconv_mode: 1 198 | dconv_depth: 2 199 | dconv_comp: 4 200 | dconv_attn: 4 201 | dconv_lstm: 4 202 | dconv_init: 1e-3 203 | # Weight init 204 | rescale: 0.1 205 | 206 | svd: # see svd.py for documentation 207 | penalty: 0 208 | min_size: 0.1 209 | dim: 1 210 | niters: 2 211 | powm: false 212 | proba: 1 213 | conv_only: false 214 | convtr: false 215 | bs: 1 216 | 217 | quant: # quantization hyper params 218 | diffq: # diffq penalty, typically 1e-4 or 3e-4 219 | qat: # use QAT with a fixed number of bits (not as good as diffq) 220 | min_size: 0.2 221 | group_size: 8 222 | 223 | dora: 224 | dir: outputs 225 | exclude: ["misc.*", "slurm.*", 'test.reval', 'flag'] 226 | 227 | slurm: 228 | time: 4320 229 | constraint: volta32gb 230 | setup: ['module load cuda/11.0 cudnn/v8.0.3.33-cuda.11.0 NCCL/2.8.3-1-cuda.11.0'] 231 | 232 | 233 | checkpoint: 234 | monitor: 'TRAIN/loss' #'Train/Loss' 235 | filename: "e={epoch:02d}-TRAIN_loss={TRAIN/loss:.2f}" 236 | save_top_k: 1 #only save the one whatever the minimum 237 | mode: 'min' #if validation/acc, then will monitor 'max' 238 | save_last: True #save the last point 239 | every_n_epochs: 1 240 | 241 | trainer: 242 | devices: ${devices} 243 | # Pick only available GPUs 244 | auto_select_gpus: True 245 | accelerator: auto 246 | precision: 32 247 | max_epochs: ${epochs} 248 | check_val_every_n_epoch: 1 249 | resume_from_checkpoint: ${resume_checkpoint} 250 | 251 | #logger: wandb 252 | logger: tensorboard 253 | 254 | wandb: 255 | # Optional 256 | #entity: yourusername 257 | 258 | # Hydra config 259 | hydra: 260 | job_logging: 261 | formatters: 262 | colorlog: 263 | datefmt: "%m-%d %H:%M:%S" 264 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Demucs lightning 2 | 3 | 1. [Introduction](#Introduction) 4 | 1. [Requirement](#Requirement) 5 | 1. [Training](#Training) 6 | 1. [Demucs](#Demucs) 7 | 1. [HDemucs](#HDemucs) 8 | 1. [Resume from Checkpoints](#Resume-from-Checkpoints) 9 | 1. [Training with a less powerful GPU](#Training-with-a-less-powerful-GPU) 10 | 1. [Testing pretrained model](#Testing-pretrained-model) 11 | 1. [Half-Precision Training](#Half-Precision-Training) 12 | 1. [Default settings](#Default-settings) 13 | 1. [Inferencing](#Inferencing) 14 | 1. 
[Development](#Development)
15 | 
16 | 
17 | ## Introduction
18 | ```
19 | demucs_lightning
20 | ├──conf
21 | │ ├─train_test_config.yaml
22 | │ ├─infer_config.yaml
23 | │ │
24 | │
25 | ├──demucs
26 | │ ├─demucs.py
27 | │ ├─hdemucs.py
28 | │ ├─other custom modules
29 | │
30 | ├──requirements.txt
31 | ├──train.py
32 | ├──test.py
33 | ├──inference.py
34 | │
35 | ```
36 | 
37 | There are 2 major released versions of Demucs.
38 | * Demucs (v2) operates in the waveform domain.
39 | * Hybrid Demucs (v3) features hybrid source separation.
40 | 
41 | You can find their model structures in
42 | `demucs.py` and `hdemucs.py` in the `demucs` folder.\
43 | For official information about Demucs, you can visit [facebookresearch/demucs](https://github.com/facebookresearch/demucs)
44 | 
45 | Demucs is trained on [MusdbHQ](https://sigsep.github.io/datasets/musdb.html). This repo uses `AudioLoader` to obtain the MusdbHQ dataset.\
46 | For more information about AudioLoader, you can visit [KinWaiCheuk/AudioLoader](https://github.com/KinWaiCheuk/AudioLoader).
47 | 
48 | Alternatively, you can download the MusdbHQ dataset manually from [zenodo](https://zenodo.org/record/3338373#.YoEmSC8RpQI).
49 | 
50 | ## Requirement
51 | `Python==3.8.10` and `ffmpeg` are required to run this repo.
52 | 
53 | If `ffmpeg` is not installed on your machine, you can install it via `apt install ffmpeg`.
54 | 
55 | You can install all required libraries at once via
56 | ``` bash
57 | pip install -r requirements.txt
58 | ```
59 | 
60 | ## Logging
61 | 
62 | TensorBoard logging is used by default. If you want to use the
63 | WandbLogger instead (recommended!), either edit `logger` in
64 | `conf/train_test_config.yaml` or append `logger=wandb` to all
65 | your commands.
66 | 
67 | ## Training
68 | If it is your first time running the repo, you can use the argument `download=True` to automatically download and set up the `musdb18hq` dataset. Otherwise, you can omit this argument.
69 | 
70 | ### Demucs
71 | It requires `16,885 MB` of GPU memory. If you do not have enough GPU memory, please read [this section](#Training-with-a-less-powerful-GPU).
72 | 
73 | ```bash
74 | python train.py devices=[0] model=Demucs download=True
75 | ```
76 | 
77 | ### HDemucs
78 | It requires `19,199 MB` of GPU memory.
79 | ```bash
80 | python train.py devices=[0] model=HDemucs download=True
81 | ```
82 | 
83 | ## Resume from Checkpoints
84 | It is possible to continue training from an existing checkpoint by passing the `resume_checkpoint` argument. By default, Hydra saves all the checkpoints at `'outputs/YYYY-MM-DD/HH-MM-SS/XXX_experiment_epoch=XXX_augmentation=XXX/version_1/checkpoints/XXX.ckpt'`. For example, if you already have a checkpoint trained with 32-bit precision for 100 epochs via the following command:
85 | 
86 | ```bash
87 | python train.py devices=[0] trainer.precision=32 epochs=100
88 | ```
89 | 
90 | If you now want to train for 50 more epochs, you can use the following CLI command:
91 | 
92 | ```bash
93 | python train.py devices=[0] trainer.precision=16 epochs=150 resume_checkpoint='outputs/2022-05-24/21-20-17/Demucs_experiment_epoch=360_augmentation=True/version_1/checkpoints/e=123-TRAIN_loss=0.08.ckpt'
94 | ```
95 | 
96 | You can always move your checkpoints to a more convenient location to shorten the path.
97 | 
98 | ## Training with a less powerful GPU
99 | It is possible to reduce the GPU memory required to train the models by using the following tricks, but they might affect model performance.
100 | ### Reduce Batch Size
101 | You can reduce the batch size to `2`.
By doing so, it only requires `10,851 MB` of GPU memory.
102 | ```bash
103 | python train.py batch_size=2 augment.remix.group_size=2 model=Demucs
104 | ```
105 | 
106 | ### Disable Augmentation
107 | You can further reduce the batch size to `1` if data augmentation is disabled. By doing so, it only requires `7,703 MB` of GPU memory.
108 | ```bash
109 | python train.py batch_size=1 data_augmentation=False model=Demucs
110 | ```
111 | 
112 | 
113 | ### Reduce Audio Segment Length
114 | You can reduce the audio segment length to only `6`. By doing so, it only requires `6,175 MB` of GPU memory.
115 | ```bash
116 | python train.py batch_size=1 data_augmentation=False segment=6 model=Demucs
117 | ```
118 | 
119 | ## Testing pretrained model
120 | You can use `test.py` to evaluate a pretrained model directly from an existing checkpoint. You can give the checkpoint path via the `resume_checkpoint` argument.
121 | 
122 | ```bash
123 | python test.py resume_checkpoint='outputs/2022-05-24/21-20-17/Demucs_experiment_epoch=360_augmentation=True/version_1/checkpoints/e=123-TRAIN_loss=0.08.ckpt'
124 | ```
125 | 
126 | ## Half-Precision Training
127 | By default, PyTorch Lightning uses 32-bit precision for training. To use 16-bit precision (half-precision), you can specify `trainer.precision`:
128 | 
129 | ```bash
130 | python train.py trainer.precision=16
131 | ```
132 | 
133 | Double-precision is also supported by specifying `trainer.precision=64`.
134 | 
135 | 
136 | 
137 | ## Default settings
138 | The full list of arguments and their default values can be found in `conf/train_test_config.yaml`.
139 | 
140 | __devices__: Select which GPU to use. If you have multiple GPUs on your machine and you want to use GPU:2, you can set `devices=[2]`. If you want to use DDP (multi-GPU training), you can set `devices=2`; it will automatically use the first two GPUs available on your machine. If you want to use GPU:0, GPU:2, and GPU:3 for training, you can set `devices=[0,2,3]`.
141 | 
142 | __download__: When set to `True`, it will automatically download and set up the dataset. Defaults to `False`.
143 | 
144 | __data_root__: Select the location of your dataset. If `download=True`, it will become the directory that the dataset is going to be downloaded to. Defaults to `'./musdb18hq'`.
145 | 
146 | __model__: Select which version of Demucs to use. The default model of this repo is Hybrid Demucs (v3). You can switch to Demucs (v2) by setting `model=Demucs`.
147 | 
148 | __samplerate__: The sampling rate for the audio. Defaults to `44100`.
149 | 
150 | __epochs__: The number of epochs to train the model. Defaults to `360`.
151 | 
152 | __optim.lr__: Learning rate of the optimizer. Defaults to `3e-4`.
153 | 
154 | 
155 | ## Inferencing
156 | You can apply your trained model weights to your own audio files by using `inference.py`. The necessary arguments are the following:
157 | 
158 | * `checkpoint` refers to the path of the trained model checkpoint file
159 | * `infer_audio_folder_path` refers to the path of the folder that contains all your audio files
160 | * `infer_audio_ext` refers to the file extension of your audio files. Default value is `'wav'`
161 | 
162 | ```bash
163 | python inference.py infer_audio_folder_path='../../infer_audio' checkpoint='outputs/2022-05-24/21-20-17/Demucs_experiment_epoch=360_augmentation=True/version_1/checkpoints/e=123-TRAIN_loss=0.08.ckpt'
164 | ```
165 | 
166 | By default, Hydra saves all the separated audio in the `outputs` folder.
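If your input files are not wav, the extension can be overridden via `infer_audio_ext`. A hypothetical example, assuming a folder of mp3 files (supported by your ffmpeg setup) and reusing the checkpoint path from above:

```bash
python inference.py infer_audio_folder_path='../../infer_audio' infer_audio_ext='mp3' checkpoint='outputs/2022-05-24/21-20-17/Demucs_experiment_epoch=360_augmentation=True/version_1/checkpoints/e=123-TRAIN_loss=0.08.ckpt'
```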
167 | ## Development 168 | 169 | If you are a developer on this repo, please run: 170 | 171 | ``` 172 | pre-commit install 173 | ``` 174 | 175 | -------------------------------------------------------------------------------- /demucs/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Test time evaluation, either using the original SDR from [Vincent et al. 2006] 8 | or the newest SDR definition from the MDX 2021 competition (this one will 9 | be reported as `nsdr` for `new sdr`). 10 | """ 11 | 12 | import logging 13 | from concurrent import futures 14 | 15 | import musdb 16 | import museval 17 | import numpy as np 18 | import torch as th 19 | from dora.log import LogProgress 20 | 21 | from . import distrib 22 | 23 | # from .apply import apply_model 24 | from .audio import convert_audio, save_audio 25 | from .utils import DummyPoolExecutor 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def new_sdr(references, estimates): 31 | """ 32 | Compute the SDR according to the MDX challenge definition. 33 | Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license) 34 | """ 35 | assert references.dim() == 4 36 | assert estimates.dim() == 4 37 | delta = 1e-7 # avoid numerical errors 38 | num = th.sum(th.square(references), dim=(2, 3)) 39 | den = th.sum(th.square(references - estimates), dim=(2, 3)) 40 | num += delta 41 | den += delta 42 | scores = 10 * th.log10(num / den) 43 | return scores 44 | 45 | 46 | def eval_track(references, estimates, win, hop, compute_sdr=True): 47 | references = references.transpose(1, 2).double() 48 | estimates = estimates.transpose(1, 2).double() 49 | 50 | new_scores = new_sdr(references.cpu()[None], estimates.cpu()[None])[0] 51 | 52 | if not compute_sdr: 53 | return None, new_scores 54 | else: 55 | references = references.numpy() 56 | estimates = estimates.numpy() 57 | scores = museval.metrics.bss_eval( 58 | references, 59 | estimates, 60 | compute_permutation=False, 61 | window=win, 62 | hop=hop, 63 | framewise_filters=False, 64 | bsseval_sources_version=False, 65 | )[:-1] 66 | return scores, new_scores 67 | 68 | 69 | def evaluate(solver, compute_sdr=False): 70 | """ 71 | Evaluate model using museval. 72 | `new_only` means using only the MDX definition of the SDR, which is much faster to evaluate. 
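A worked sketch of the `new_sdr` metric defined above (values are only examples): it reports, in dB, the ratio of reference energy to residual energy, summed over channels and time for each source.

```python
import torch as th

ref = th.ones(1, 1, 2, 8)        # (batch, sources, channels, time)
est = 0.9 * ref                  # estimate uniformly off by 10%
delta = 1e-7                     # same numerical guard as in new_sdr
num = th.sum(th.square(ref), dim=(2, 3)) + delta
den = th.sum(th.square(ref - est), dim=(2, 3)) + delta
print(10 * th.log10(num / den))  # ~20 dB: residual energy is 1% of the reference
```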
73 | """ 74 | from .apply import apply_model 75 | 76 | args = solver.args 77 | 78 | output_dir = solver.folder / "results" 79 | output_dir.mkdir(exist_ok=True, parents=True) 80 | json_folder = solver.folder / "results/test" 81 | json_folder.mkdir(exist_ok=True, parents=True) 82 | 83 | # we load tracks from the original musdb set 84 | if args.test.nonhq is None: 85 | test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True) 86 | else: 87 | test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False) 88 | src_rate = args.dset.musdb_samplerate 89 | 90 | eval_device = "cpu" 91 | 92 | model = solver.model 93 | win = int(1.0 * model.samplerate) 94 | hop = int(1.0 * model.samplerate) 95 | 96 | indexes = range(distrib.rank, len(test_set), distrib.world_size) 97 | indexes = LogProgress(logger, indexes, updates=args.misc.num_prints, name="Eval") 98 | pendings = [] 99 | 100 | pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor 101 | with pool(args.test.workers) as pool: 102 | for index in indexes: 103 | track = test_set.tracks[index] 104 | 105 | mix = th.from_numpy(track.audio).t().float() 106 | if mix.dim() == 1: 107 | mix = mix[None] 108 | mix = mix.to(solver.device) 109 | ref = mix.mean(dim=0) # mono mixture 110 | mix = (mix - ref.mean()) / ref.std() 111 | mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels) 112 | estimates = apply_model( 113 | model, 114 | mix[None], 115 | shifts=args.test.shifts, 116 | split=args.test.split, 117 | overlap=args.test.overlap, 118 | )[0] 119 | estimates = estimates * ref.std() + ref.mean() 120 | estimates = estimates.to(eval_device) 121 | 122 | references = th.stack( 123 | [th.from_numpy(track.targets[name].audio).t() for name in model.sources] 124 | ) 125 | if references.dim() == 2: 126 | references = references[:, None] 127 | references = references.to(eval_device) 128 | references = convert_audio( 129 | references, src_rate, model.samplerate, model.audio_channels 130 | ) 131 | if args.test.save: 132 | folder = solver.folder / "wav" / track.name 133 | folder.mkdir(exist_ok=True, parents=True) 134 | for name, estimate in zip(model.sources, estimates): 135 | save_audio( 136 | estimate.cpu(), folder / (name + ".mp3"), model.samplerate 137 | ) 138 | 139 | pendings.append( 140 | ( 141 | track.name, 142 | pool.submit( 143 | eval_track, 144 | references, 145 | estimates, 146 | win=win, 147 | hop=hop, 148 | compute_sdr=compute_sdr, 149 | ), 150 | ) 151 | ) 152 | 153 | pendings = LogProgress( 154 | logger, pendings, updates=args.misc.num_prints, name="Eval (BSS)" 155 | ) 156 | tracks = {} 157 | for track_name, pending in pendings: 158 | pending = pending.result() 159 | scores, nsdrs = pending 160 | tracks[track_name] = {} 161 | for idx, target in enumerate(model.sources): 162 | tracks[track_name][target] = {"nsdr": [float(nsdrs[idx])]} 163 | if scores is not None: 164 | (sdr, isr, sir, sar) = scores 165 | for idx, target in enumerate(model.sources): 166 | values = { 167 | "SDR": sdr[idx].tolist(), 168 | "SIR": sir[idx].tolist(), 169 | "ISR": isr[idx].tolist(), 170 | "SAR": sar[idx].tolist(), 171 | } 172 | tracks[track_name][target].update(values) 173 | 174 | all_tracks = {} 175 | for src in range(distrib.world_size): 176 | all_tracks.update(distrib.share(tracks, src)) 177 | 178 | result = {} 179 | metric_names = next(iter(all_tracks.values()))[model.sources[0]] 180 | for metric_name in metric_names: 181 | avg = 0 182 | avg_of_medians = 0 183 | for source in model.sources: 184 | medians = [ 185 | 
np.nanmedian(all_tracks[track][source][metric_name]) 186 | for track in all_tracks.keys() 187 | ] 188 | mean = np.mean(medians) 189 | median = np.median(medians) 190 | result[metric_name.lower() + "_" + source] = mean 191 | result[metric_name.lower() + "_med" + "_" + source] = median 192 | avg += mean / len(model.sources) 193 | avg_of_medians += median / len(model.sources) 194 | result[metric_name.lower()] = avg 195 | result[metric_name.lower() + "_med"] = avg_of_medians 196 | return result 197 | -------------------------------------------------------------------------------- /demucs/separate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import subprocess 9 | import sys 10 | from pathlib import Path 11 | 12 | import torch as th 13 | import torchaudio as ta 14 | from dora.log import fatal 15 | 16 | from .apply import BagOfModels, apply_model 17 | from .audio import AudioFile, convert_audio, save_audio 18 | from .pretrained import ModelLoadingError, add_model_flags, get_model_from_args 19 | 20 | 21 | def load_track(track, audio_channels, samplerate): 22 | errors = {} 23 | wav = None 24 | 25 | try: 26 | wav = AudioFile(track).read( 27 | streams=0, samplerate=samplerate, channels=audio_channels 28 | ) 29 | except FileNotFoundError: 30 | errors["ffmpeg"] = "Ffmpeg is not installed." 31 | except subprocess.CalledProcessError: 32 | errors["ffmpeg"] = "FFmpeg could not read the file." 33 | 34 | if wav is None: 35 | try: 36 | wav, sr = ta.load(str(track)) 37 | except RuntimeError as err: 38 | errors["torchaudio"] = err.args[0] 39 | else: 40 | wav = convert_audio(wav, sr, samplerate, audio_channels) 41 | 42 | if wav is None: 43 | print( 44 | f"Could not load file {track}. " "Maybe it is not a supported file format? " 45 | ) 46 | for backend, error in errors.items(): 47 | print( 48 | f"When trying to load using {backend}, got the following error: {error}" 49 | ) 50 | sys.exit(1) 51 | return wav 52 | 53 | 54 | def main(): 55 | parser = argparse.ArgumentParser( 56 | "demucs.separate", description="Separate the sources for the given tracks" 57 | ) 58 | parser.add_argument( 59 | "tracks", nargs="+", type=Path, default=[], help="Path to tracks" 60 | ) 61 | add_model_flags(parser) 62 | parser.add_argument("-v", "--verbose", action="store_true") 63 | parser.add_argument( 64 | "-o", 65 | "--out", 66 | type=Path, 67 | default=Path("separated"), 68 | help="Folder where to put extracted tracks. A subfolder " 69 | "with the model name will be created.", 70 | ) 71 | parser.add_argument( 72 | "-d", 73 | "--device", 74 | default="cuda" if th.cuda.is_available() else "cpu", 75 | help="Device to use, default is cuda if available else cpu", 76 | ) 77 | parser.add_argument( 78 | "--shifts", 79 | default=1, 80 | type=int, 81 | help="Number of random shifts for equivariant stabilization." 82 | "Increase separation time but improves quality for Demucs. 10 was used " 83 | "in the original paper.", 84 | ) 85 | parser.add_argument( 86 | "--overlap", default=0.25, type=float, help="Overlap between the splits." 87 | ) 88 | split_group = parser.add_mutually_exclusive_group() 89 | split_group.add_argument( 90 | "--no-split", 91 | action="store_false", 92 | dest="split", 93 | default=True, 94 | help="Doesn't split audio in chunks. 
" "This can use large amounts of memory.", 95 | ) 96 | split_group.add_argument( 97 | "--segment", 98 | type=int, 99 | help="Set split size of each chunk. " 100 | "This can help save memory of graphic card. ", 101 | ) 102 | parser.add_argument( 103 | "--two-stems", 104 | dest="stem", 105 | metavar="STEM", 106 | help="Only separate audio into {STEM} and no_{STEM}. ", 107 | ) 108 | group = parser.add_mutually_exclusive_group() 109 | group.add_argument( 110 | "--int24", action="store_true", help="Save wav output as 24 bits wav." 111 | ) 112 | group.add_argument( 113 | "--float32", action="store_true", help="Save wav output as float32 (2x bigger)." 114 | ) 115 | parser.add_argument( 116 | "--clip-mode", 117 | default="rescale", 118 | choices=["rescale", "clamp"], 119 | help="Strategy for avoiding clipping: rescaling entire signal " 120 | "if necessary (rescale) or hard clipping (clamp).", 121 | ) 122 | parser.add_argument( 123 | "--mp3", action="store_true", help="Convert the output wavs to mp3." 124 | ) 125 | parser.add_argument( 126 | "--mp3-bitrate", default=320, type=int, help="Bitrate of converted mp3." 127 | ) 128 | parser.add_argument( 129 | "-j", 130 | "--jobs", 131 | default=0, 132 | type=int, 133 | help="Number of jobs. This can increase memory usage but will " 134 | "be much faster when multiple cores are available.", 135 | ) 136 | 137 | args = parser.parse_args() 138 | 139 | try: 140 | model = get_model_from_args(args) 141 | except ModelLoadingError as error: 142 | fatal(error.args[0]) 143 | 144 | if args.segment is not None and args.segment < 8: 145 | fatal("Segment must greater than 8. ") 146 | 147 | if isinstance(model, BagOfModels): 148 | print( 149 | f"Selected model is a bag of {len(model.models)} models. " 150 | "You will see that many progress bars per track." 151 | ) 152 | if args.segment is not None: 153 | for sub in model.models: 154 | sub.segment = args.segment 155 | else: 156 | if args.segment is not None: 157 | sub.segment = args.segment 158 | 159 | model.cpu() 160 | model.eval() 161 | 162 | if args.stem is not None and args.stem not in model.sources: 163 | fatal( 164 | 'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format( 165 | stem=args.stem, sources=", ".join(model.sources) 166 | ) 167 | ) 168 | out = args.out / args.name 169 | out.mkdir(parents=True, exist_ok=True) 170 | print(f"Separated tracks will be stored in {out.resolve()}") 171 | for track in args.tracks: 172 | if not track.exists(): 173 | print( 174 | f"File {track} does not exist. 
If the path contains spaces, " 175 | 'please try again after surrounding the entire path with quotes "".', 176 | file=sys.stderr, 177 | ) 178 | continue 179 | print(f"Separating track {track}") 180 | wav = load_track(track, model.audio_channels, model.samplerate) 181 | 182 | ref = wav.mean(0) 183 | wav = (wav - ref.mean()) / ref.std() 184 | sources = apply_model( 185 | model, 186 | wav[None], 187 | device=args.device, 188 | shifts=args.shifts, 189 | split=args.split, 190 | overlap=args.overlap, 191 | progress=True, 192 | num_workers=args.jobs, 193 | )[0] 194 | sources = sources * ref.std() + ref.mean() 195 | 196 | track_folder = out / track.name.rsplit(".", 1)[0] 197 | track_folder.mkdir(exist_ok=True) 198 | if args.mp3: 199 | ext = ".mp3" 200 | else: 201 | ext = ".wav" 202 | kwargs = { 203 | "samplerate": model.samplerate, 204 | "bitrate": args.mp3_bitrate, 205 | "clip": args.clip_mode, 206 | "as_float": args.float32, 207 | "bits_per_sample": 24 if args.int24 else 16, 208 | } 209 | if args.stem is None: 210 | for source, name in zip(sources, model.sources): 211 | stem = str(track_folder / (name + ext)) 212 | save_audio(source, stem, **kwargs) 213 | else: 214 | sources = list(sources) 215 | stem = str(track_folder / (args.stem + ext)) 216 | save_audio(sources.pop(model.sources.index(args.stem)), stem, **kwargs) 217 | # Warning : after poping the stem, selected stem is no longer in the list 'sources' 218 | other_stem = th.zeros_like(sources[0]) 219 | for i in sources: 220 | other_stem += i 221 | stem = str(track_folder / ("no_" + args.stem + ext)) 222 | save_audio(other_stem, stem, **kwargs) 223 | 224 | 225 | if __name__ == "__main__": 226 | main() 227 | -------------------------------------------------------------------------------- /demucs/apply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Code to apply a model to a mix. It will handle chunking with overlaps and 8 | inteprolation between chunks, as well as the "shift trick". 9 | """ 10 | import random 11 | import typing as tp 12 | from concurrent.futures import ThreadPoolExecutor 13 | 14 | import torch as th 15 | import tqdm 16 | from torch import nn 17 | from torch.nn import functional as F 18 | 19 | from .demucs import Demucs 20 | from .hdemucs import HDemucs 21 | from .utils import DummyPoolExecutor, center_trim 22 | 23 | Model = tp.Union[Demucs, HDemucs] 24 | 25 | 26 | class BagOfModels(nn.Module): 27 | def __init__( 28 | self, 29 | models: tp.List[Model], 30 | weights: tp.Optional[tp.List[tp.List[float]]] = None, 31 | segment: tp.Optional[float] = None, 32 | ): 33 | """ 34 | Represents a bag of models with specific weights. 35 | You should call `apply_model` rather than calling directly the forward here for 36 | optimal performance. 37 | 38 | Args: 39 | models (list[nn.Module]): list of Demucs/HDemucs models. 40 | weights (list[list[float]]): list of weights. If None, assumed to 41 | be all ones, otherwise it should be a list of N list (N number of models), 42 | each containing S floats (S number of sources). 43 | segment (None or float): overrides the `segment` attribute of each model 44 | (this is performed inplace, be careful is you reuse the models passed). 
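The per-source weights above combine the models' outputs as a weighted average. A small numeric sketch (values are only examples; `apply_model` further below is the real implementation):

```python
import torch as th

# Two models, one source; fake outputs of shape (batch, sources, channels, time).
outs = [th.full((1, 1, 1, 4), 1.0), th.full((1, 1, 1, 4), 3.0)]
weights = [[0.25], [0.75]]           # per-model, per-source weights

estimate, total = 0, 0
for out, w in zip(outs, weights):
    estimate = estimate + w[0] * out
    total += w[0]
print(estimate / total)              # all entries 2.5 = 0.25*1 + 0.75*3
```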
45 | """ 46 | super().__init__() 47 | assert len(models) > 0 48 | first = models[0] 49 | for other in models: 50 | assert other.sources == first.sources 51 | assert other.samplerate == first.samplerate 52 | assert other.audio_channels == first.audio_channels 53 | if segment is not None: 54 | other.segment = segment 55 | 56 | self.audio_channels = first.audio_channels 57 | self.samplerate = first.samplerate 58 | self.sources = first.sources 59 | self.models = nn.ModuleList(models) 60 | 61 | if weights is None: 62 | weights = [[1.0 for _ in first.sources] for _ in models] 63 | else: 64 | assert len(weights) == len(models) 65 | for weight in weights: 66 | assert len(weight) == len(first.sources) 67 | self.weights = weights 68 | 69 | def forward(self, x): 70 | raise NotImplementedError("Call `apply_model` on this.") 71 | 72 | 73 | class TensorChunk: 74 | def __init__(self, tensor, offset=0, length=None): 75 | total_length = tensor.shape[-1] 76 | assert offset >= 0 77 | assert offset < total_length 78 | 79 | if length is None: 80 | length = total_length - offset 81 | else: 82 | length = min(total_length - offset, length) 83 | 84 | self.tensor = tensor 85 | self.offset = offset 86 | self.length = length 87 | self.device = tensor.device 88 | 89 | @property 90 | def shape(self): 91 | shape = list(self.tensor.shape) 92 | shape[-1] = self.length 93 | return shape 94 | 95 | def padded(self, target_length): 96 | delta = target_length - self.length 97 | total_length = self.tensor.shape[-1] 98 | assert delta >= 0 99 | 100 | start = self.offset - delta // 2 101 | end = start + target_length 102 | 103 | correct_start = max(0, start) 104 | correct_end = min(total_length, end) 105 | 106 | pad_left = correct_start - start 107 | pad_right = end - correct_end 108 | 109 | out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right)) 110 | assert out.shape[-1] == target_length 111 | return out 112 | 113 | 114 | def tensor_chunk(tensor_or_chunk): 115 | if isinstance(tensor_or_chunk, TensorChunk): 116 | return tensor_or_chunk 117 | else: 118 | assert isinstance(tensor_or_chunk, th.Tensor) 119 | return TensorChunk(tensor_or_chunk) 120 | 121 | 122 | def apply_model( 123 | model, 124 | mix, 125 | shifts=1, 126 | split=True, 127 | overlap=0.25, 128 | transition_power=1.0, 129 | progress=False, 130 | device=None, 131 | num_workers=0, 132 | pool=None, 133 | ): 134 | """ 135 | Apply model to a given mixture. 136 | 137 | Args: 138 | shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec 139 | and apply the oppositve shift to the output. This is repeated `shifts` time and 140 | all predictions are averaged. This effectively makes the model time equivariant 141 | and improves SDR by up to 0.2 points. 142 | split (bool): if True, the input will be broken down in 8 seconds extracts 143 | and predictions will be performed individually on each and concatenated. 144 | Useful for model with large memory footprint like Tasnet. 145 | progress (bool): if True, show a progress bar (requires split=True) 146 | device (torch.device, str, or None): if provided, device on which to 147 | execute the computation, otherwise `mix.device` is assumed. 148 | When `device` is different from `mix.device`, only local computations will 149 | be on `device`, while the entire tracks will be stored on `mix.device`. 
150 | """ 151 | if device is None: 152 | device = mix.device 153 | else: 154 | device = th.device(device) 155 | if pool is None: 156 | if num_workers > 0 and device.type == "cpu": 157 | pool = ThreadPoolExecutor(num_workers) 158 | else: 159 | pool = DummyPoolExecutor() 160 | kwargs = { 161 | "shifts": shifts, 162 | "split": split, 163 | "overlap": overlap, 164 | "transition_power": transition_power, 165 | "progress": progress, 166 | "device": device, 167 | "pool": pool, 168 | } 169 | if isinstance(model, BagOfModels): 170 | # Special treatment for bag of model. 171 | # We explicitely apply multiple times `apply_model` so that the random shifts 172 | # are different for each model. 173 | estimates = 0 174 | totals = [0] * len(model.sources) 175 | for sub_model, weight in zip(model.models, model.weights): 176 | original_model_device = next(iter(sub_model.parameters())).device 177 | sub_model.to(device) 178 | 179 | out = apply_model(sub_model, mix, **kwargs) 180 | sub_model.to(original_model_device) 181 | for k, inst_weight in enumerate(weight): 182 | out[:, k, :, :] *= inst_weight 183 | totals[k] += inst_weight 184 | estimates += out 185 | del out 186 | 187 | for k in range(estimates.shape[1]): 188 | estimates[:, k, :, :] /= totals[k] 189 | return estimates 190 | 191 | model.to(device) 192 | assert transition_power >= 1, "transition_power < 1 leads to weird behavior." 193 | batch, channels, length = mix.shape 194 | if split: 195 | kwargs["split"] = False 196 | out = th.zeros(batch, len(model.sources), channels, length, device=mix.device) 197 | sum_weight = th.zeros(length, device=mix.device) 198 | segment = int(model.samplerate * model.segment) 199 | stride = int((1 - overlap) * segment) 200 | offsets = range(0, length, stride) 201 | scale = stride / model.samplerate 202 | # We start from a triangle shaped weight, with maximal weight in the middle 203 | # of the segment. Then we normalize and take to the power `transition_power`. 204 | # Large values of transition power will lead to sharper transitions. 205 | weight = th.cat( 206 | [ 207 | th.arange(1, segment // 2 + 1, device=device), 208 | th.arange(segment - segment // 2, 0, -1, device=device), 209 | ] 210 | ) 211 | assert len(weight) == segment 212 | # If the overlap < 50%, this will translate to linear transition when 213 | # transition_power is 1. 
214 | weight = (weight / weight.max()) ** transition_power 215 | futures = [] 216 | for offset in offsets: 217 | chunk = TensorChunk(mix, offset, segment) 218 | future = pool.submit(apply_model, model, chunk, **kwargs) 219 | futures.append((future, offset)) 220 | offset += segment 221 | if progress: 222 | futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit="seconds") 223 | for future, offset in futures: 224 | chunk_out = future.result() 225 | chunk_length = chunk_out.shape[-1] 226 | out[..., offset : offset + segment] += ( 227 | weight[:chunk_length] * chunk_out 228 | ).to(mix.device) 229 | sum_weight[offset : offset + segment] += weight[:chunk_length].to( 230 | mix.device 231 | ) 232 | assert sum_weight.min() > 0 233 | out /= sum_weight 234 | return out 235 | elif shifts: 236 | kwargs["shifts"] = 0 237 | max_shift = int(0.5 * model.samplerate) 238 | mix = tensor_chunk(mix) 239 | padded_mix = mix.padded(length + 2 * max_shift) 240 | out = 0 241 | for _ in range(shifts): 242 | offset = random.randint(0, max_shift) 243 | shifted = TensorChunk(padded_mix, offset, length + max_shift - offset) 244 | shifted_out = apply_model(model, shifted, **kwargs) 245 | out += shifted_out[..., max_shift - offset :] 246 | out /= shifts 247 | return out 248 | else: 249 | if hasattr(model, "valid_length"): 250 | valid_length = model.valid_length(length) 251 | else: 252 | valid_length = length 253 | mix = tensor_chunk(mix) 254 | padded_mix = mix.padded(valid_length).to(device) 255 | with th.no_grad(): 256 | out = model(padded_mix) 257 | return center_trim(out, length) 258 | -------------------------------------------------------------------------------- /demucs/audio.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import json 7 | import subprocess as sp 8 | from pathlib import Path 9 | 10 | import julius 11 | import lameenc 12 | import numpy as np 13 | import torch 14 | import torchaudio as ta 15 | 16 | from .utils import temp_filenames 17 | 18 | 19 | def _read_info(path): 20 | stdout_data = sp.check_output( 21 | [ 22 | "ffprobe", 23 | "-loglevel", 24 | "panic", 25 | str(path), 26 | "-print_format", 27 | "json", 28 | "-show_format", 29 | "-show_streams", 30 | ] 31 | ) 32 | return json.loads(stdout_data.decode("utf-8")) 33 | 34 | 35 | class AudioFile: 36 | """ 37 | Allows to read audio from any format supported by ffmpeg, as well as resampling or 38 | converting to mono on the fly. See :method:`read` for more details. 
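A compact sketch of the shift-trick branch above (illustrative; `run_model` is a stand-in for the recursive `apply_model` call): the padded input is shifted by a random offset and the opposite shift is applied to the prediction before averaging.

```python
import random
import torch as th
import torch.nn.functional as F

def run_model(x):                 # stand-in for the recursive apply_model call
    return x[None]                # pretend the model returns one source equal to the mix

length, max_shift, shifts = 16, 4, 2
mix = th.arange(float(length))[None, None]    # (batch, channels, time)
padded = F.pad(mix, (max_shift, max_shift))   # centred zero padding
out = 0
for _ in range(shifts):
    offset = random.randint(0, max_shift)
    shifted = padded[..., offset:length + max_shift]
    out = out + run_model(shifted)[..., max_shift - offset:]
out = out / shifts
print(th.allclose(out, mix[None]))            # True: the opposite shift undoes each random shift
```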
39 | """ 40 | 41 | def __init__(self, path: Path): 42 | self.path = Path(path) 43 | self._info = None 44 | 45 | def __repr__(self): 46 | features = [("path", self.path)] 47 | features.append(("samplerate", self.samplerate())) 48 | features.append(("channels", self.channels())) 49 | features.append(("streams", len(self))) 50 | features_str = ", ".join(f"{name}={value}" for name, value in features) 51 | return f"AudioFile({features_str})" 52 | 53 | @property 54 | def info(self): 55 | if self._info is None: 56 | self._info = _read_info(self.path) 57 | return self._info 58 | 59 | @property 60 | def duration(self): 61 | return float(self.info["format"]["duration"]) 62 | 63 | @property 64 | def _audio_streams(self): 65 | return [ 66 | index 67 | for index, stream in enumerate(self.info["streams"]) 68 | if stream["codec_type"] == "audio" 69 | ] 70 | 71 | def __len__(self): 72 | return len(self._audio_streams) 73 | 74 | def channels(self, stream=0): 75 | return int(self.info["streams"][self._audio_streams[stream]]["channels"]) 76 | 77 | def samplerate(self, stream=0): 78 | return int(self.info["streams"][self._audio_streams[stream]]["sample_rate"]) 79 | 80 | def read( 81 | self, 82 | seek_time=None, 83 | duration=None, 84 | streams=slice(None), 85 | samplerate=None, 86 | channels=None, 87 | temp_folder=None, 88 | ): 89 | """ 90 | Slightly more efficient implementation than stempeg, 91 | in particular, this will extract all stems at once 92 | rather than having to loop over one file multiple times 93 | for each stream. 94 | 95 | Args: 96 | seek_time (float): seek time in seconds or None if no seeking is needed. 97 | duration (float): duration in seconds to extract or None to extract until the end. 98 | streams (slice, int or list): streams to extract, can be a single int, a list or 99 | a slice. If it is a slice or list, the output will be of size [S, C, T] 100 | with S the number of streams, C the number of channels and T the number of samples. 101 | If it is an int, the output will be [C, T]. 102 | samplerate (int): if provided, will resample on the fly. If None, no resampling will 103 | be done. Original sampling rate can be obtained with :method:`samplerate`. 104 | channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that 105 | as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers. 106 | See https://sound.stackexchange.com/a/42710. 107 | Our definition of mono is simply the average of the two channels. Any other 108 | value will be ignored. 109 | temp_folder (str or Path or None): temporary folder to use for decoding. 
110 | 111 | 112 | """ 113 | streams = np.array(range(len(self)))[streams] 114 | single = not isinstance(streams, np.ndarray) 115 | if single: 116 | streams = [streams] 117 | 118 | if duration is None: 119 | target_size = None 120 | query_duration = None 121 | else: 122 | target_size = int((samplerate or self.samplerate()) * duration) 123 | query_duration = float( 124 | (target_size + 1) / (samplerate or self.samplerate()) 125 | ) 126 | 127 | with temp_filenames(len(streams)) as filenames: 128 | command = ["ffmpeg", "-y"] 129 | command += ["-loglevel", "panic"] 130 | if seek_time: 131 | command += ["-ss", str(seek_time)] 132 | command += ["-i", str(self.path)] 133 | for stream, filename in zip(streams, filenames): 134 | command += ["-map", f"0:{self._audio_streams[stream]}"] 135 | if query_duration is not None: 136 | command += ["-t", str(query_duration)] 137 | command += ["-threads", "1"] 138 | command += ["-f", "f32le"] 139 | if samplerate is not None: 140 | command += ["-ar", str(samplerate)] 141 | command += [filename] 142 | 143 | sp.run(command, check=True) 144 | wavs = [] 145 | for filename in filenames: 146 | wav = np.fromfile(filename, dtype=np.float32) 147 | wav = torch.from_numpy(wav) 148 | wav = wav.view(-1, self.channels()).t() 149 | if channels is not None: 150 | wav = convert_audio_channels(wav, channels) 151 | if target_size is not None: 152 | wav = wav[..., :target_size] 153 | wavs.append(wav) 154 | wav = torch.stack(wavs, dim=0) 155 | if single: 156 | wav = wav[0] 157 | return wav 158 | 159 | 160 | def convert_audio_channels(wav, channels=2): 161 | """Convert audio to the given number of channels.""" 162 | *shape, src_channels, length = wav.shape 163 | if src_channels == channels: 164 | pass 165 | elif channels == 1: 166 | # Case 1: 167 | # The caller asked for 1-channel audio, but the stream has multiple 168 | # channels; downmix all of them. 169 | wav = wav.mean(dim=-2, keepdim=True) 170 | elif src_channels == 1: 171 | # Case 2: 172 | # The caller asked for multiple channels, but the input file has a 173 | # single channel; replicate the audio over all channels. 174 | wav = wav.expand(*shape, channels, length) 175 | elif src_channels >= channels: 176 | # Case 3: 177 | # The caller asked for multiple channels, and the input file has 178 | # more channels than requested. In that case return the first channels. 179 | wav = wav[..., :channels, :] 180 | else: 181 | # Case 4: What is a reasonable choice here? 182 | raise ValueError( 183 | "The audio file has fewer channels than requested but is not mono."
184 | ) 185 | return wav 186 | 187 | 188 | def convert_audio(wav, from_samplerate, to_samplerate, channels): 189 | """Convert audio from a given samplerate to a target one and target number of channels.""" 190 | wav = convert_audio_channels(wav, channels) 191 | return julius.resample_frac(wav, from_samplerate, to_samplerate) 192 | 193 | 194 | def i16_pcm(wav): 195 | """Convert audio to 16 bits integer PCM format.""" 196 | if wav.dtype.is_floating_point: 197 | return (wav.clamp_(-1, 1) * (2**15 - 1)).short() 198 | else: 199 | return wav 200 | 201 | 202 | def f32_pcm(wav): 203 | """Convert audio to float 32 bits PCM format.""" 204 | if wav.dtype.is_floating_point: 205 | return wav 206 | else: 207 | return wav.float() / (2**15 - 1) 208 | 209 | 210 | def as_dtype_pcm(wav, dtype): 211 | """Convert audio to either f32 pcm or i16 pcm depending on the given dtype.""" 212 | if wav.dtype.is_floating_point: 213 | return f32_pcm(wav) 214 | else: 215 | return i16_pcm(wav) 216 | 217 | 218 | def encode_mp3(wav, path, samplerate=44100, bitrate=320, verbose=False): 219 | """Save given audio as mp3. This should work on all OSes.""" 220 | C, T = wav.shape 221 | wav = i16_pcm(wav) 222 | encoder = lameenc.Encoder() 223 | encoder.set_bit_rate(bitrate) 224 | encoder.set_in_sample_rate(samplerate) 225 | encoder.set_channels(C) 226 | encoder.set_quality(2) # 2-highest, 7-fastest 227 | if not verbose: 228 | encoder.silence() 229 | wav = wav.transpose(0, 1).numpy() 230 | mp3_data = encoder.encode(wav.tobytes()) 231 | mp3_data += encoder.flush() 232 | with open(path, "wb") as f: 233 | f.write(mp3_data) 234 | 235 | 236 | def prevent_clip(wav, mode="rescale"): 237 | """ 238 | different strategies for avoiding raw clipping. 239 | """ 240 | assert wav.dtype.is_floating_point, "too late for clipping" 241 | if mode == "rescale": 242 | wav = wav / max(1.01 * wav.abs().max(), 1) 243 | elif mode == "clamp": 244 | wav = wav.clamp(-0.99, 0.99) 245 | elif mode == "tanh": 246 | wav = torch.tanh(wav) 247 | else: 248 | raise ValueError(f"Invalid mode {mode}") 249 | return wav 250 | 251 | 252 | def save_audio( 253 | wav, 254 | path, 255 | samplerate, 256 | bitrate=320, 257 | clip="rescale", 258 | bits_per_sample=16, 259 | as_float=False, 260 | ): 261 | """Save audio file, automatically preventing clipping if necessary 262 | based on the given `clip` strategy. If the path ends in `.mp3`, this 263 | will save as mp3 with the given `bitrate`. 264 | """ 265 | wav = prevent_clip(wav, mode=clip) 266 | path = Path(path) 267 | suffix = path.suffix.lower() 268 | if suffix == ".mp3": 269 | encode_mp3(wav, path, samplerate, bitrate) 270 | elif suffix == ".wav": 271 | if as_float: 272 | bits_per_sample = 32 273 | encoding = "PCM_F" 274 | else: 275 | encoding = "PCM_S" 276 | ta.save( 277 | str(path), 278 | wav, 279 | sample_rate=samplerate, 280 | encoding=encoding, 281 | bits_per_sample=bits_per_sample, 282 | ) 283 | else: 284 | raise ValueError(f"Invalid suffix for path: {suffix}") 285 | -------------------------------------------------------------------------------- /demucs/wav.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | """Loading wav based datasets, including MusdbHQ.""" 7 | 8 | import hashlib 9 | import json 10 | import math 11 | import os 12 | from collections import OrderedDict 13 | from pathlib import Path 14 | 15 | import julius 16 | import musdb 17 | import torch as th 18 | import torchaudio as ta 19 | import tqdm 20 | from torch import distributed 21 | from torch.nn import functional as F 22 | 23 | from . import distrib 24 | from .audio import convert_audio_channels 25 | 26 | MIXTURE = "mixture" 27 | EXT = ".wav" 28 | 29 | 30 | def _track_metadata(track, sources, normalize=True, ext=EXT): 31 | track_length = None 32 | track_samplerate = None 33 | mean = 0 34 | std = 1 35 | for source in sources + [MIXTURE]: 36 | file = track / f"{source}{ext}" 37 | try: 38 | info = ta.info(str(file)) 39 | except RuntimeError: 40 | print(file) 41 | raise 42 | length = info.num_frames 43 | if track_length is None: 44 | track_length = length 45 | track_samplerate = info.sample_rate 46 | elif track_length != length: 47 | raise ValueError( 48 | f"Invalid length for file {file}: " 49 | f"expecting {track_length} but got {length}." 50 | ) 51 | elif info.sample_rate != track_samplerate: 52 | raise ValueError( 53 | f"Invalid sample rate for file {file}: " 54 | f"expecting {track_samplerate} but got {info.sample_rate}." 55 | ) 56 | if source == MIXTURE and normalize: 57 | try: 58 | wav, _ = ta.load(str(file)) 59 | except RuntimeError: 60 | print(file) 61 | raise 62 | wav = wav.mean(0) 63 | mean = wav.mean().item() 64 | std = wav.std().item() 65 | 66 | return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate} 67 | 68 | 69 | def build_metadata(path, sources, normalize=True, ext=EXT): 70 | """ 71 | Build the metadata for `Wavset`. 72 | 73 | Args: 74 | path (str or Path): path to dataset. 75 | sources (list[str]): list of sources to look for. 76 | normalize (bool): if True, loads full track and store normalization 77 | values based on the mixture file. 78 | ext (str): extension of audio files (default is .wav). 79 | """ 80 | 81 | meta = {} 82 | path = Path(path) 83 | pendings = [] 84 | from concurrent.futures import ThreadPoolExecutor 85 | 86 | with ThreadPoolExecutor(8) as pool: 87 | for root, folders, files in os.walk(path, followlinks=True): 88 | root = Path(root) 89 | if root.name.startswith(".") or folders or root == path: 90 | continue 91 | name = str(root.relative_to(path)) 92 | pendings.append( 93 | (name, pool.submit(_track_metadata, root, sources, normalize, ext)) 94 | ) 95 | # meta[name] = _track_metadata(root, sources, normalize, ext) 96 | for name, pending in tqdm.tqdm(pendings, ncols=120): 97 | meta[name] = pending.result() 98 | return meta 99 | 100 | 101 | class Wavset: 102 | def __init__( 103 | self, 104 | root, 105 | metadata, 106 | sources, 107 | segment=None, 108 | shift=None, 109 | normalize=True, 110 | samplerate=44100, 111 | channels=2, 112 | ext=EXT, 113 | ): 114 | """ 115 | Waveset (or mp3 set for that matter). Can be used to train 116 | with arbitrary sources. Each track should be one folder inside of `path`. 117 | The folder should contain files named `{source}.{ext}`. 118 | 119 | Args: 120 | root (Path or str): root folder for the dataset. 121 | metadata (dict): output from `build_metadata`. 122 | sources (list[str]): list of source names. 123 | segment (None or float): segment length in seconds. If `None`, returns entire tracks. 124 | shift (None or float): stride in seconds bewteen samples. 
125 | normalize (bool): normalizes input audio, **based on the metadata content**, 126 | i.e. the entire track is normalized, not individual extracts. 127 | samplerate (int): target sample rate. if the file sample rate 128 | is different, it will be resampled on the fly. 129 | channels (int): target nb of channels. if different, will be 130 | changed onthe fly. 131 | ext (str): extension for audio files (default is .wav). 132 | 133 | samplerate and channels are converted on the fly. 134 | """ 135 | self.root = Path(root) 136 | self.metadata = OrderedDict(metadata) 137 | self.segment = segment 138 | self.shift = shift or segment 139 | self.normalize = normalize 140 | self.sources = sources 141 | self.channels = channels 142 | self.samplerate = samplerate 143 | self.ext = ext 144 | self.num_examples = [] 145 | for name, meta in self.metadata.items(): 146 | track_duration = meta["length"] / meta["samplerate"] 147 | if segment is None or track_duration < segment: 148 | examples = 1 149 | else: 150 | examples = int( 151 | math.ceil((track_duration - self.segment) / self.shift) + 1 152 | ) 153 | self.num_examples.append(examples) 154 | 155 | def __len__(self): 156 | return sum(self.num_examples) 157 | 158 | def get_file(self, name, source): 159 | return self.root / name / f"{source}{self.ext}" 160 | 161 | def __getitem__(self, index): 162 | for name, examples in zip(self.metadata, self.num_examples): 163 | if index >= examples: 164 | index -= examples 165 | continue 166 | meta = self.metadata[name] 167 | num_frames = -1 168 | offset = 0 169 | if self.segment is not None: 170 | offset = int(meta["samplerate"] * self.shift * index) 171 | num_frames = int(math.ceil(meta["samplerate"] * self.segment)) 172 | wavs = [] 173 | for source in self.sources: 174 | file = self.get_file(name, source) 175 | wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames) 176 | wav = convert_audio_channels(wav, self.channels) 177 | wavs.append(wav) 178 | 179 | example = th.stack(wavs) 180 | example = julius.resample_frac(example, meta["samplerate"], self.samplerate) 181 | if self.normalize: 182 | example = (example - meta["mean"]) / meta["std"] 183 | if self.segment: 184 | length = int(self.segment * self.samplerate) 185 | example = example[..., :length] 186 | example = F.pad(example, (0, length - example.shape[-1])) 187 | return example 188 | 189 | 190 | def get_wav_datasets(args): 191 | """Extract the wav datasets from the XP arguments.""" 192 | sig = hashlib.sha1(str(args.wav).encode()).hexdigest()[:8] 193 | metadata_file = Path(args.metadata) / ("wav_" + sig + ".json") 194 | train_path = Path(args.wav) / "train" 195 | valid_path = Path(args.wav) / "valid" 196 | if not metadata_file.is_file() and distrib.rank == 0: 197 | metadata_file.parent.mkdir(exist_ok=True, parents=True) 198 | train = build_metadata(train_path, args.sources) 199 | valid = build_metadata(valid_path, args.sources) 200 | json.dump([train, valid], open(metadata_file, "w")) 201 | if distrib.world_size > 1: 202 | distributed.barrier() 203 | train, valid = json.load(open(metadata_file)) 204 | if args.full_cv: 205 | kw_cv = {} 206 | else: 207 | kw_cv = {"segment": args.segment, "shift": args.shift} 208 | train_set = Wavset( 209 | train_path, 210 | train, 211 | args.sources, 212 | segment=args.segment, 213 | shift=args.shift, 214 | samplerate=args.samplerate, 215 | channels=args.channels, 216 | normalize=args.normalize, 217 | ) 218 | valid_set = Wavset( 219 | valid_path, 220 | valid, 221 | [MIXTURE] + list(args.sources), 222 | 
samplerate=args.samplerate, 223 | channels=args.channels, 224 | normalize=args.normalize, 225 | **kw_cv, 226 | ) 227 | return train_set, valid_set 228 | 229 | 230 | def _get_musdb_valid(): 231 | # Return musdb valid set. 232 | import yaml 233 | 234 | setup_path = Path(musdb.__path__[0]) / "configs" / "mus.yaml" 235 | setup = yaml.safe_load(open(setup_path, "r")) 236 | return setup["validation_tracks"] 237 | 238 | 239 | def get_musdb_wav_datasets(args): 240 | """Extract the musdb dataset from the XP arguments.""" 241 | sig = hashlib.sha1(str(args.musdb).encode()).hexdigest()[:8] 242 | metadata_file = Path(args.metadata) / ("musdb_" + sig + ".json") 243 | root = Path(args.musdb) / "train" 244 | if not metadata_file.is_file() and distrib.rank == 0: 245 | metadata_file.parent.mkdir(exist_ok=True, parents=True) 246 | metadata = build_metadata(root, args.sources) 247 | json.dump(metadata, open(metadata_file, "w")) 248 | if distrib.world_size > 1: 249 | distributed.barrier() 250 | metadata = json.load(open(metadata_file)) 251 | 252 | valid_tracks = _get_musdb_valid() 253 | if args.train_valid: 254 | metadata_train = metadata 255 | else: 256 | metadata_train = { 257 | name: meta for name, meta in metadata.items() if name not in valid_tracks 258 | } 259 | metadata_valid = { 260 | name: meta for name, meta in metadata.items() if name in valid_tracks 261 | } 262 | if args.full_cv: 263 | kw_cv = {} 264 | else: 265 | kw_cv = {"segment": args.segment, "shift": args.shift} 266 | train_set = Wavset( 267 | root, 268 | metadata_train, 269 | args.sources, 270 | segment=args.segment, 271 | shift=args.shift, 272 | samplerate=args.samplerate, 273 | channels=args.channels, 274 | normalize=args.normalize, 275 | ) 276 | valid_set = Wavset( 277 | root, 278 | metadata_valid, 279 | [MIXTURE] + list(args.sources), 280 | samplerate=args.samplerate, 281 | channels=args.channels, 282 | normalize=args.normalize, 283 | **kw_cv, 284 | ) 285 | return train_set, valid_set 286 | -------------------------------------------------------------------------------- /demucs/solver.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Main training loop.""" 7 | 8 | import logging 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from dora import get_xp 13 | from dora.log import LogProgress, bold 14 | from dora.utils import write_and_rename 15 | 16 | from . import augment, distrib, pretrained, states 17 | from .apply import apply_model 18 | from .ema import ModelEMA 19 | from .evaluate import evaluate, new_sdr 20 | from .svd import svd_penalty 21 | from .utils import EMA, pull_metric 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def _summary(metrics): 27 | return " | ".join(f"{key.capitalize()}={val}" for key, val in metrics.items()) 28 | 29 | 30 | class Solver(object): 31 | def __init__(self, loaders, model, optimizer, args): 32 | self.args = args 33 | self.loaders = loaders 34 | 35 | self.model = model 36 | self.optimizer = optimizer 37 | self.quantizer = states.get_quantizer(self.model, args.quant, self.optimizer) 38 | self.dmodel = distrib.wrap(model) 39 | self.device = next(iter(self.model.parameters())).device 40 | 41 | # Exponential moving average of the model, either updated every batch or epoch. 
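        # (Sketch of the update rule assumed here, following the usual EMA formulation:
        #  ema_param = decay * ema_param + (1 - decay) * param after each step or epoch.)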
42 | # The best model from all the EMAs and the original one is kept based on the valid 43 | # loss for the final best model. 44 | self.emas = {"batch": [], "epoch": []} 45 | for kind in self.emas.keys(): 46 | decays = getattr(args.ema, kind) 47 | device = self.device if kind == "batch" else "cpu" 48 | if decays: 49 | for decay in decays: 50 | self.emas[kind].append(ModelEMA(self.model, decay, device=device)) 51 | 52 | # data augment 53 | augments = [ 54 | augment.Shift( 55 | shift=int(args.dset.samplerate * args.dset.shift), 56 | same=args.augment.shift_same, 57 | ) 58 | ] 59 | if args.augment.flip: 60 | augments += [augment.FlipChannels(), augment.FlipSign()] 61 | for aug in ["scale", "remix"]: 62 | kw = getattr(args.augment, aug) 63 | if kw.proba: 64 | augments.append(getattr(augment, aug.capitalize())(**kw)) 65 | self.augment = torch.nn.Sequential(*augments) 66 | 67 | xp = get_xp() 68 | self.folder = xp.folder 69 | # Checkpoints 70 | self.checkpoint_file = xp.folder / "checkpoint.th" 71 | self.best_file = xp.folder / "best.th" 72 | logger.debug("Checkpoint will be saved to %s", self.checkpoint_file.resolve()) 73 | self.best_state = None 74 | self.best_changed = False 75 | 76 | self.link = xp.link 77 | self.history = self.link.history 78 | 79 | self._reset() 80 | 81 | def _serialize(self, epoch): 82 | package = {} 83 | package["state"] = self.model.state_dict() 84 | package["optimizer"] = self.optimizer.state_dict() 85 | package["history"] = self.history 86 | package["best_state"] = self.best_state 87 | package["args"] = self.args 88 | for kind, emas in self.emas.items(): 89 | for k, ema in enumerate(emas): 90 | package[f"ema_{kind}_{k}"] = ema.state_dict() 91 | with write_and_rename(self.checkpoint_file) as tmp: 92 | torch.save(package, tmp) 93 | 94 | save_every = self.args.save_every 95 | if ( 96 | save_every 97 | and (epoch + 1) % save_every == 0 98 | and epoch + 1 != self.args.epochs 99 | ): 100 | with write_and_rename(self.folder / f"checkpoint_{epoch + 1}.th") as tmp: 101 | torch.save(package, tmp) 102 | 103 | if self.best_changed: 104 | # Saving only the latest best model. 
105 | with write_and_rename(self.best_file) as tmp: 106 | package = states.serialize_model(self.model, self.args) 107 | package["state"] = self.best_state 108 | torch.save(package, tmp) 109 | self.best_changed = False 110 | 111 | def _reset(self): 112 | """Reset state of the solver, potentially using checkpoint.""" 113 | if self.checkpoint_file.exists(): 114 | logger.info(f"Loading checkpoint model: {self.checkpoint_file}") 115 | package = torch.load(self.checkpoint_file, "cpu") 116 | self.model.load_state_dict(package["state"]) 117 | self.optimizer.load_state_dict(package["optimizer"]) 118 | self.history[:] = package["history"] 119 | self.best_state = package["best_state"] 120 | for kind, emas in self.emas.items(): 121 | for k, ema in enumerate(emas): 122 | ema.load_state_dict(package[f"ema_{kind}_{k}"]) 123 | elif self.args.continue_pretrained: 124 | model = pretrained.get_model( 125 | name=self.args.continue_pretrained, repo=self.args.pretrained_repo 126 | ) 127 | self.model.load_state_dict(model.state_dict()) 128 | elif self.args.continue_from: 129 | name = "checkpoint.th" 130 | root = self.folder.parent 131 | cf = root / str(self.args.continue_from) / name 132 | logger.info("Loading from %s", cf) 133 | package = torch.load(cf, "cpu") 134 | self.best_state = package["best_state"] 135 | if self.args.continue_best: 136 | self.model.load_state_dict(package["best_state"], strict=False) 137 | else: 138 | self.model.load_state_dict(package["state"], strict=False) 139 | if self.args.continue_opt: 140 | self.optimizer.load_state_dict(package["optimizer"]) 141 | 142 | def _format_train(self, metrics: dict) -> dict: 143 | """Formatting for train/valid metrics.""" 144 | losses = { 145 | "loss": format(metrics["loss"], ".4f"), 146 | "reco": format(metrics["reco"], ".4f"), 147 | } 148 | if "nsdr" in metrics: 149 | losses["nsdr"] = format(metrics["nsdr"], ".3f") 150 | if self.quantizer is not None: 151 | losses["ms"] = format(metrics["ms"], ".2f") 152 | if "grad" in metrics: 153 | losses["grad"] = format(metrics["grad"], ".4f") 154 | if "best" in metrics: 155 | losses["best"] = format(metrics["best"], ".4f") 156 | if "bname" in metrics: 157 | losses["bname"] = metrics["bname"] 158 | if "penalty" in metrics: 159 | losses["penalty"] = format(metrics["penalty"], ".4f") 160 | if "hloss" in metrics: 161 | losses["hloss"] = format(metrics["hloss"], ".4f") 162 | return losses 163 | 164 | def _format_test(self, metrics: dict) -> dict: 165 | """Formatting for test metrics.""" 166 | losses = {} 167 | if "sdr" in metrics: 168 | losses["sdr"] = format(metrics["sdr"], ".3f") 169 | if "nsdr" in metrics: 170 | losses["nsdr"] = format(metrics["nsdr"], ".3f") 171 | for source in self.model.sources: 172 | key = f"sdr_{source}" 173 | if key in metrics: 174 | losses[key] = format(metrics[key], ".3f") 175 | key = f"nsdr_{source}" 176 | if key in metrics: 177 | losses[key] = format(metrics[key], ".3f") 178 | return losses 179 | 180 | def train(self): 181 | # Optimizing the model 182 | if self.history: 183 | logger.info("Replaying metrics from previous run") 184 | for epoch, metrics in enumerate(self.history): 185 | formatted = self._format_train(metrics["train"]) 186 | logger.info( 187 | bold(f"Train Summary | Epoch {epoch + 1} | {_summary(formatted)}") 188 | ) 189 | formatted = self._format_train(metrics["valid"]) 190 | logger.info( 191 | bold(f"Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}") 192 | ) 193 | if "test" in metrics: 194 | formatted = self._format_test(metrics["test"]) 195 | if formatted: 196 | 
logger.info( 197 | bold( 198 | f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}" 199 | ) 200 | ) 201 | 202 | epoch = 0 203 | for epoch in range(len(self.history), self.args.epochs): 204 | # Train one epoch 205 | self.model.train() # Turn on BatchNorm & Dropout 206 | metrics = {} 207 | logger.info("-" * 70) 208 | logger.info("Training...") 209 | metrics["train"] = self._run_one_epoch(epoch) 210 | formatted = self._format_train(metrics["train"]) 211 | logger.info( 212 | bold(f"Train Summary | Epoch {epoch + 1} | {_summary(formatted)}") 213 | ) 214 | 215 | # Cross validation 216 | logger.info("-" * 70) 217 | logger.info("Cross validation...") 218 | self.model.eval() # Turn off Batchnorm & Dropout 219 | with torch.no_grad(): 220 | valid = self._run_one_epoch(epoch, train=False) 221 | bvalid = valid 222 | bname = "main" 223 | state = states.copy_state(self.model.state_dict()) 224 | metrics["valid"] = {} 225 | metrics["valid"]["main"] = valid 226 | key = self.args.test.metric 227 | for kind, emas in self.emas.items(): 228 | for k, ema in enumerate(emas): 229 | with ema.swap(): 230 | valid = self._run_one_epoch(epoch, train=False) 231 | name = f"ema_{kind}_{k}" 232 | metrics["valid"][name] = valid 233 | a = valid[key] 234 | b = bvalid[key] 235 | if key.startswith("nsdr"): 236 | a = -a 237 | b = -b 238 | if a < b: 239 | bvalid = valid 240 | state = ema.state 241 | bname = name 242 | metrics["valid"].update(bvalid) 243 | metrics["valid"]["bname"] = bname 244 | 245 | valid_loss = metrics["valid"][key] 246 | mets = pull_metric(self.link.history, f"valid.{key}") + [valid_loss] 247 | if key.startswith("nsdr"): 248 | best_loss = max(mets) 249 | else: 250 | best_loss = min(mets) 251 | metrics["valid"]["best"] = best_loss 252 | if self.args.svd.penalty > 0: 253 | kw = dict(self.args.svd) 254 | kw.pop("penalty") 255 | with torch.no_grad(): 256 | penalty = svd_penalty(self.model, exact=True, **kw) 257 | metrics["valid"]["penalty"] = penalty 258 | 259 | formatted = self._format_train(metrics["valid"]) 260 | logger.info( 261 | bold(f"Valid Summary | Epoch {epoch + 1} | {_summary(formatted)}") 262 | ) 263 | 264 | # Save the best model 265 | if valid_loss == best_loss or self.args.dset.train_valid: 266 | logger.info(bold("New best valid loss %.4f"), valid_loss) 267 | self.best_state = states.copy_state(state) 268 | self.best_changed = True 269 | 270 | # Eval model every `test.every` epoch or on last epoch 271 | should_eval = (epoch + 1) % self.args.test.every == 0 272 | is_last = epoch == self.args.epochs - 1 273 | reco = metrics["valid"]["main"]["reco"] 274 | # Tries to detect divergence in a reliable way and finish job 275 | # not to waste compute. 276 | div = epoch >= 180 and reco > 0.18 277 | div = div or epoch >= 100 and reco > 0.25 278 | div = div and self.args.optim.loss == "l1" 279 | if div: 280 | logger.warning( 281 | "Finishing training early because valid loss is too high." 
282 | ) 283 | is_last = True 284 | if should_eval or is_last: 285 | # Evaluate on the testset 286 | logger.info("-" * 70) 287 | logger.info("Evaluating on the test set...") 288 | # We switch to the best known model for testing 289 | if self.args.test.best: 290 | state = self.best_state 291 | else: 292 | state = states.copy_state(self.model.state_dict()) 293 | compute_sdr = self.args.test.sdr and is_last 294 | with states.swap_state(self.model, state): 295 | with torch.no_grad(): 296 | metrics["test"] = evaluate(self, compute_sdr=compute_sdr) 297 | formatted = self._format_test(metrics["test"]) 298 | logger.info( 299 | bold(f"Test Summary | Epoch {epoch + 1} | {_summary(formatted)}") 300 | ) 301 | self.link.push_metrics(metrics) 302 | 303 | if distrib.rank == 0: 304 | # Save model each epoch 305 | self._serialize(epoch) 306 | logger.debug("Checkpoint saved to %s", self.checkpoint_file.resolve()) 307 | if is_last: 308 | break 309 | 310 | def _run_one_epoch(self, epoch, train=True): 311 | args = self.args 312 | data_loader = self.loaders["train"] if train else self.loaders["valid"] 313 | # get a different order for distributed training, otherwise this will get ignored 314 | data_loader.sampler.epoch = epoch 315 | 316 | label = ["Valid", "Train"][train] 317 | name = label + f" | Epoch {epoch + 1}" 318 | total = len(data_loader) 319 | if args.max_batches: 320 | total = min(total, args.max_batches) 321 | logprog = LogProgress( 322 | logger, 323 | data_loader, 324 | total=total, 325 | updates=self.args.misc.num_prints, 326 | name=name, 327 | ) 328 | averager = EMA() 329 | 330 | for idx, sources in enumerate(logprog): 331 | sources = sources.to(self.device) 332 | if train: 333 | sources = self.augment(sources) 334 | mix = sources.sum(dim=1) 335 | else: 336 | mix = sources[:, 0] 337 | sources = sources[:, 1:] 338 | 339 | if not train and self.args.valid_apply: 340 | estimate = apply_model( 341 | self.model, mix, split=self.args.test.split, overlap=0 342 | ) 343 | else: 344 | estimate = self.dmodel(mix) 345 | if train and hasattr(self.model, "transform_target"): 346 | sources = self.model.transform_target(mix, sources) 347 | assert estimate.shape == sources.shape, (estimate.shape, sources.shape) 348 | dims = tuple(range(2, sources.dim())) 349 | 350 | if args.optim.loss == "l1": 351 | loss = F.l1_loss(estimate, sources, reduction="none") 352 | loss = loss.mean(dims).mean(0) 353 | reco = loss 354 | elif args.optim.loss == "mse": 355 | loss = F.mse_loss(estimate, sources, reduction="none") 356 | loss = loss.mean(dims) 357 | reco = loss**0.5 358 | reco = reco.mean(0) 359 | else: 360 | raise ValueError(f"Invalid loss {self.args.loss}") 361 | weights = torch.tensor(args.weights).to(sources) 362 | loss = (loss * weights).sum() / weights.sum() 363 | 364 | ms = 0 365 | if self.quantizer is not None: 366 | ms = self.quantizer.model_size() 367 | if args.quant.diffq: 368 | loss += args.quant.diffq * ms 369 | 370 | losses = {} 371 | losses["reco"] = (reco * weights).sum() / weights.sum() 372 | losses["ms"] = ms 373 | 374 | if not train: 375 | nsdrs = new_sdr(sources, estimate.detach()).mean(0) 376 | total = 0 377 | for source, nsdr, w in zip(self.model.sources, nsdrs, weights): 378 | losses[f"nsdr_{source}"] = nsdr 379 | total += w * nsdr 380 | losses["nsdr"] = total / weights.sum() 381 | 382 | if train and args.svd.penalty > 0: 383 | kw = dict(args.svd) 384 | kw.pop("penalty") 385 | penalty = svd_penalty(self.model, **kw) 386 | losses["penalty"] = penalty 387 | loss += args.svd.penalty * penalty 388 | 389 | 
losses["loss"] = loss 390 | 391 | for k, source in enumerate(self.model.sources): 392 | losses[f"reco_{source}"] = reco[k] 393 | 394 | # optimize model in training mode 395 | if train: 396 | loss.backward() 397 | grad_norm = 0 398 | grads = [] 399 | for p in self.model.parameters(): 400 | if p.grad is not None: 401 | grad_norm += p.grad.data.norm() ** 2 402 | grads.append(p.grad.data) 403 | losses["grad"] = grad_norm**0.5 404 | if args.optim.clip_grad: 405 | torch.nn.utils.clip_grad_norm_( 406 | self.model.parameters(), args.optim.clip_grad 407 | ) 408 | 409 | if self.args.flag == "uns": 410 | for n, p in self.model.named_parameters(): 411 | if p.grad is None: 412 | print("no grad", n) 413 | self.optimizer.step() 414 | self.optimizer.zero_grad() 415 | for ema in self.emas["batch"]: 416 | ema.update() 417 | losses = averager(losses) 418 | logs = self._format_train(losses) 419 | logprog.update(**logs) 420 | # Just in case, clear some memory 421 | del loss, estimate, reco, ms 422 | if args.max_batches == idx: 423 | break 424 | if self.args.debug and train: 425 | break 426 | if self.args.flag == "debug": 427 | break 428 | if train: 429 | for ema in self.emas["epoch"]: 430 | ema.update() 431 | return distrib.average(losses, idx + 1) 432 | -------------------------------------------------------------------------------- /demucs/demucs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import os 9 | import sys 10 | import typing as tp 11 | 12 | import torch 13 | import torchaudio 14 | from pytorch_lightning.core.lightning import LightningModule 15 | from torch import nn 16 | from torch.nn import functional as F 17 | from torchaudio.transforms import Resample 18 | 19 | from . import augment 20 | from .evaluate import new_sdr 21 | from .states import capture_init, get_quantizer 22 | from .svd import svd_penalty 23 | from .utils import center_trim, unfold 24 | 25 | 26 | class BLSTM(nn.Module): 27 | """ 28 | BiLSTM with same hidden units as input dim. 29 | If `max_steps` is not None, input will be splitting in overlapping 30 | chunks and the LSTM applied separately on each chunk. 
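    Chunks overlap by half a window (stride = max_steps // 2); when the per-chunk outputs
    are stitched back together in `forward`, a quarter-window is trimmed on each side of
    interior chunk boundaries before concatenation.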
31 | """ 32 | 33 | def __init__(self, dim, layers=1, max_steps=None, skip=False): 34 | super().__init__() 35 | assert max_steps is None or max_steps % 4 == 0 36 | self.max_steps = max_steps 37 | self.lstm = nn.LSTM( 38 | bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim 39 | ) 40 | self.linear = nn.Linear(2 * dim, dim) 41 | self.skip = skip 42 | 43 | def forward(self, x): 44 | B, C, T = x.shape 45 | y = x 46 | framed = False 47 | if self.max_steps is not None and T > self.max_steps: 48 | width = self.max_steps 49 | stride = width // 2 50 | frames = unfold(x, width, stride) 51 | nframes = frames.shape[2] 52 | framed = True 53 | x = frames.permute(0, 2, 1, 3).reshape(-1, C, width) 54 | 55 | x = x.permute(2, 0, 1) 56 | 57 | x = self.lstm(x)[0] 58 | x = self.linear(x) 59 | x = x.permute(1, 2, 0) 60 | if framed: 61 | out = [] 62 | frames = x.reshape(B, -1, C, width) 63 | limit = stride // 2 64 | for k in range(nframes): 65 | if k == 0: 66 | out.append(frames[:, k, :, :-limit]) 67 | elif k == nframes - 1: 68 | out.append(frames[:, k, :, limit:]) 69 | else: 70 | out.append(frames[:, k, :, limit:-limit]) 71 | out = torch.cat(out, -1) 72 | out = out[..., :T] 73 | x = out 74 | if self.skip: 75 | x = x + y 76 | return x 77 | 78 | 79 | def rescale_conv(conv, reference): 80 | """Rescale initial weight scale. It is unclear why it helps but it certainly does.""" 81 | std = conv.weight.std().detach() 82 | scale = (std / reference) ** 0.5 83 | conv.weight.data /= scale 84 | if conv.bias is not None: 85 | conv.bias.data /= scale 86 | 87 | 88 | def rescale_module(module, reference): 89 | for sub in module.modules(): 90 | if isinstance( 91 | sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d) 92 | ): 93 | rescale_conv(sub, reference) 94 | 95 | 96 | class LayerScale(nn.Module): 97 | """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf). 98 | This rescales diagonaly residual outputs close to 0 initially, then learnt. 99 | """ 100 | 101 | def __init__(self, channels: int, init: float = 0): 102 | super().__init__() 103 | self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True)) 104 | self.scale.data[:] = init 105 | 106 | def forward(self, x): 107 | return self.scale[:, None] * x 108 | 109 | 110 | class DConv(nn.Module): 111 | """ 112 | New residual branches in each encoder layer. 113 | This alternates dilated convolutions, potentially with LSTMs and attention. 114 | Also before entering each residual branch, dimension is projected on a smaller subspace, 115 | e.g. of dim `channels // compress`. 116 | """ 117 | 118 | def __init__( 119 | self, 120 | channels: int, 121 | compress: float = 4, 122 | depth: int = 2, 123 | init: float = 1e-4, 124 | norm=True, 125 | attn=False, 126 | heads=4, 127 | ndecay=4, 128 | lstm=False, 129 | gelu=True, 130 | kernel=3, 131 | dilate=True, 132 | ): 133 | """ 134 | Args: 135 | channels: input/output channels for residual branch. 136 | compress: amount of channel compression inside the branch. 137 | depth: number of layers in the residual branch. Each layer has its own 138 | projection, and potentially LSTM and attention. 139 | init: initial scale for LayerNorm. 140 | norm: use GroupNorm. 141 | attn: use LocalAttention. 142 | heads: number of heads for the LocalAttention. 143 | ndecay: number of decay controls in the LocalAttention. 144 | lstm: use LSTM. 145 | gelu: Use GELU activation. 146 | kernel: kernel size for the (dilated) convolutions. 147 | dilate: if true, use dilation, increasing with the depth. 
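        Note: padding is set to dilation * (kernel // 2), so with stride 1 each dilated
        convolution preserves the time length of its input.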
148 | """ 149 | 150 | super().__init__() 151 | assert kernel % 2 == 1 152 | self.channels = channels 153 | self.compress = compress 154 | self.depth = abs(depth) 155 | dilate = depth > 0 156 | 157 | norm_fn: tp.Callable[[int], nn.Module] 158 | norm_fn = lambda d: nn.Identity() # noqa 159 | if norm: 160 | norm_fn = lambda d: nn.GroupNorm(1, d) # noqa 161 | 162 | hidden = int(channels / compress) 163 | 164 | act: tp.Type[nn.Module] 165 | if gelu: 166 | act = nn.GELU 167 | else: 168 | act = nn.ReLU 169 | 170 | self.layers = nn.ModuleList([]) 171 | for d in range(self.depth): 172 | dilation = 2**d if dilate else 1 173 | padding = dilation * (kernel // 2) 174 | mods = [ 175 | nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding), 176 | norm_fn(hidden), 177 | act(), 178 | nn.Conv1d(hidden, 2 * channels, 1), 179 | norm_fn(2 * channels), 180 | nn.GLU(1), 181 | LayerScale(channels, init), 182 | ] 183 | if attn: 184 | mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay)) 185 | if lstm: 186 | mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True)) 187 | layer = nn.Sequential(*mods) 188 | self.layers.append(layer) 189 | 190 | def forward(self, x): 191 | for layer in self.layers: 192 | x = x + layer(x) 193 | return x 194 | 195 | 196 | class LocalState(nn.Module): 197 | """LocalState provides attention based only on the data (no positional embedding), 198 | while setting a constraint on the time window (e.g. a decaying penalty term). 199 | 200 | Also includes a failed experiment with providing some frequency-based attention. 201 | """ 202 | 203 | def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4): 204 | super().__init__() 205 | assert channels % heads == 0, (channels, heads) 206 | self.heads = heads 207 | self.nfreqs = nfreqs 208 | self.ndecay = ndecay 209 | self.content = nn.Conv1d(channels, channels, 1) 210 | self.query = nn.Conv1d(channels, channels, 1) 211 | self.key = nn.Conv1d(channels, channels, 1) 212 | if nfreqs: 213 | self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1) 214 | if ndecay: 215 | self.query_decay = nn.Conv1d(channels, heads * ndecay, 1) 216 | # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
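            # (Descriptive sketch of the effect in forward(): the attention logits receive a bias of
            #  -sum_k sigmoid(decay_q_k) / 2 * k * |t - s| / sqrt(ndecay), so decay queries near zero
            #  leave the attention window essentially unconstrained at initialization.)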
217 | self.query_decay.weight.data *= 0.01 218 | assert self.query_decay.bias is not None # stupid type checker 219 | self.query_decay.bias.data[:] = -2 220 | self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1) 221 | 222 | def forward(self, x): 223 | B, C, T = x.shape 224 | heads = self.heads 225 | indexes = torch.arange(T, device=x.device, dtype=x.dtype) 226 | # left index are keys, right index are queries 227 | delta = indexes[:, None] - indexes[None, :] 228 | 229 | queries = self.query(x).view(B, heads, -1, T) 230 | keys = self.key(x).view(B, heads, -1, T) 231 | # t are keys, s are queries 232 | dots = torch.einsum("bhct,bhcs->bhts", keys, queries) 233 | dots /= keys.shape[2] ** 0.5 234 | if self.nfreqs: 235 | periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype) 236 | freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1)) 237 | freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs**0.5 238 | dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q) 239 | if self.ndecay: 240 | decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype) 241 | decay_q = self.query_decay(x).view(B, heads, -1, T) 242 | decay_q = torch.sigmoid(decay_q) / 2 243 | decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5 244 | dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q) 245 | 246 | # Kill self reference. 247 | dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100) 248 | weights = torch.softmax(dots, dim=2) 249 | 250 | content = self.content(x).view(B, heads, -1, T) 251 | result = torch.einsum("bhts,bhct->bhcs", weights, content) 252 | if self.nfreqs: 253 | time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel) 254 | result = torch.cat([result, time_sig], 2) 255 | result = result.reshape(B, -1, T) 256 | return x + self.proj(result) 257 | 258 | 259 | class Demucs(LightningModule): 260 | @capture_init 261 | def __init__( 262 | self, 263 | sources, 264 | # Channels 265 | audio_channels=2, 266 | channels=64, 267 | growth=2.0, 268 | # Main structure 269 | depth=6, 270 | rewrite=True, 271 | lstm_layers=0, 272 | # Convolutions 273 | kernel_size=8, 274 | stride=4, 275 | context=1, 276 | # Activations 277 | gelu=True, 278 | glu=True, 279 | # Normalization 280 | norm_starts=4, 281 | norm_groups=4, 282 | # DConv residual branch 283 | dconv_mode=1, 284 | dconv_depth=2, 285 | dconv_comp=4, 286 | dconv_attn=4, 287 | dconv_lstm=4, 288 | dconv_init=1e-4, 289 | # Pre/post processing 290 | normalize=True, 291 | resample=True, 292 | # Weight init 293 | rescale=0.1, 294 | # Metadata 295 | samplerate=44100, 296 | segment=4 * 10, 297 | args=None, 298 | ): 299 | """ 300 | Args: 301 | sources (list[str]): list of source names 302 | audio_channels (int): stereo or mono 303 | channels (int): first convolution channels 304 | depth (int): number of encoder/decoder layers 305 | growth (float): multiply (resp divide) number of channels by that 306 | for each layer of the encoder (resp decoder) 307 | depth (int): number of layers in the encoder and in the decoder. 308 | rewrite (bool): add 1x1 convolution to each layer. 309 | lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated 310 | by default, as this is now replaced by the smaller and faster small LSTMs 311 | in the DConv branches. 312 | kernel_size (int): kernel size for convolutions 313 | stride (int): stride for convolutions 314 | context (int): kernel size of the convolution in the 315 | decoder before the transposed convolution. 
If > 1, 316 | will provide some context from neighboring time steps. 317 | gelu: use GELU activation function. 318 | glu (bool): use glu instead of ReLU for the 1x1 rewrite conv. 319 | norm_starts: layer at which group norm starts being used. 320 | decoder layers are numbered in reverse order. 321 | norm_groups: number of groups for group norm. 322 | dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both. 323 | dconv_depth: depth of residual DConv branch. 324 | dconv_comp: compression of DConv branch. 325 | dconv_attn: adds attention layers in DConv branch starting at this layer. 326 | dconv_lstm: adds a LSTM layer in DConv branch starting at this layer. 327 | dconv_init: initial scale for the DConv branch LayerScale. 328 | normalize (bool): normalizes the input audio on the fly, and scales back 329 | the output by the same amount. 330 | resample (bool): upsample x2 the input and downsample /2 the output. 331 | rescale (int): rescale initial weights of convolutions 332 | to get their standard deviation closer to `rescale`. 333 | samplerate (int): stored as meta information for easing 334 | future evaluations of the model. 335 | segment (float): duration of the chunks of audio to ideally evaluate the model on. 336 | This is used by `demucs.apply.apply_model`. 337 | """ 338 | 339 | super().__init__() 340 | self.audio_channels = audio_channels 341 | self.sources = sources 342 | self.kernel_size = kernel_size 343 | self.context = context 344 | self.stride = stride 345 | self.depth = depth 346 | self.resample = resample 347 | self.channels = channels 348 | self.normalize = normalize 349 | self.samplerate = samplerate 350 | self.segment = segment 351 | self.encoder = nn.ModuleList() 352 | self.decoder = nn.ModuleList() 353 | self.skip_scales = nn.ModuleList() 354 | self.upsampler = Resample(1, 2) 355 | self.downsampler = Resample(2, 1) 356 | self.args = args 357 | self.save_hyperparameters() 358 | 359 | sources_to_idx = { 360 | "drums": 0, 361 | "bass": 1, 362 | "other": 2, 363 | "vocals": 3, 364 | } # dict to map source type to index 365 | 366 | self.idx_list = ( 367 | [] 368 | ) # users provide customer sources:list, convert source type to idx:list 369 | for source in self.sources: 370 | idx = sources_to_idx[source] 371 | self.idx_list.append(idx) 372 | 373 | if args.data_augmentation: 374 | augments = [ 375 | augment.Shift( 376 | shift=int(args.samplerate * args.dset.train.shift), 377 | same=args.augment.shift_same, 378 | ) 379 | ] 380 | if args.augment.flip: 381 | augments += [augment.FlipChannels(), augment.FlipSign()] 382 | for aug in ["scale", "remix"]: 383 | kw = getattr(args.augment, aug) 384 | if kw.proba: 385 | augments.append(getattr(augment, aug.capitalize())(**kw)) 386 | self.augment = torch.nn.Sequential(*augments) 387 | 388 | if glu: 389 | activation = nn.GLU(dim=1) 390 | ch_scale = 2 391 | else: 392 | activation = nn.ReLU() 393 | ch_scale = 1 394 | if gelu: 395 | act2 = nn.GELU 396 | else: 397 | act2 = nn.ReLU 398 | 399 | in_channels = audio_channels 400 | padding = 0 401 | for index in range(depth): 402 | norm_fn = lambda d: nn.Identity() # noqa 403 | if index >= norm_starts: 404 | norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa 405 | 406 | encode = [] 407 | encode += [ 408 | nn.Conv1d(in_channels, channels, kernel_size, stride), 409 | norm_fn(channels), 410 | act2(), 411 | ] 412 | attn = index >= dconv_attn 413 | lstm = index >= dconv_lstm 414 | if dconv_mode & 1: 415 | encode += [ 416 | DConv( 417 | channels, 418 | depth=dconv_depth, 419 | init=dconv_init, 
420 | compress=dconv_comp, 421 | attn=attn, 422 | lstm=lstm, 423 | ) 424 | ] 425 | if rewrite: 426 | encode += [ 427 | nn.Conv1d(channels, ch_scale * channels, 1), 428 | norm_fn(ch_scale * channels), 429 | activation, 430 | ] 431 | self.encoder.append(nn.Sequential(*encode)) 432 | 433 | decode = [] 434 | if index > 0: 435 | out_channels = in_channels 436 | else: 437 | out_channels = len(self.sources) * audio_channels 438 | if rewrite: 439 | decode += [ 440 | nn.Conv1d( 441 | channels, ch_scale * channels, 2 * context + 1, padding=context 442 | ), 443 | norm_fn(ch_scale * channels), 444 | activation, 445 | ] 446 | if dconv_mode & 2: 447 | decode += [ 448 | DConv( 449 | channels, 450 | depth=dconv_depth, 451 | init=dconv_init, 452 | compress=dconv_comp, 453 | attn=attn, 454 | lstm=lstm, 455 | ) 456 | ] 457 | decode += [ 458 | nn.ConvTranspose1d( 459 | channels, out_channels, kernel_size, stride, padding=padding 460 | ) 461 | ] 462 | if index > 0: 463 | decode += [norm_fn(out_channels), act2()] 464 | self.decoder.insert(0, nn.Sequential(*decode)) 465 | in_channels = channels 466 | channels = int(growth * channels) 467 | 468 | channels = in_channels 469 | if lstm_layers: 470 | self.lstm = BLSTM(channels, lstm_layers) 471 | else: 472 | self.lstm = None 473 | 474 | if rescale: 475 | rescale_module(self, reference=rescale) 476 | 477 | def valid_length(self, length): 478 | """ 479 | Return the nearest valid length to use with the model so that 480 | there is no time steps left over in a convolution, e.g. for all 481 | layers, size of the input - kernel_size % stride = 0. 482 | 483 | Note that input are automatically padded if necessary to ensure that the output 484 | has the same length as the input. 485 | """ 486 | if self.resample: 487 | length *= 2 488 | 489 | for _ in range(self.depth): 490 | length = math.ceil((length - self.kernel_size) / self.stride) + 1 491 | length = max(1, length) 492 | 493 | for idx in range(self.depth): 494 | length = (length - 1) * self.stride + self.kernel_size 495 | 496 | if self.resample: 497 | length = math.ceil(length / 2) 498 | return int(length) 499 | 500 | def forward(self, mix): 501 | x = mix 502 | length = x.shape[-1] 503 | 504 | if self.normalize: 505 | mono = mix.mean(dim=1, keepdim=True) 506 | mean = mono.mean(dim=-1, keepdim=True) 507 | std = mono.std(dim=-1, keepdim=True) 508 | x = (x - mean) / (1e-5 + std) 509 | else: 510 | mean = 0 511 | std = 1 512 | 513 | delta = self.valid_length(length) - length 514 | x = F.pad(x, (delta // 2, delta - delta // 2)) 515 | 516 | if self.resample: 517 | x = self.upsampler(x) 518 | 519 | saved = [] 520 | for encode in self.encoder: 521 | x = encode(x) 522 | saved.append(x) 523 | 524 | if self.lstm: 525 | x = self.lstm(x) 526 | 527 | for decode in self.decoder: 528 | skip = saved.pop(-1) 529 | skip = center_trim(skip, x) 530 | x = decode(x + skip) 531 | 532 | if self.resample: 533 | x = self.downsampler(x) 534 | x = x * std + mean 535 | x = center_trim(x, length) 536 | x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1)) 537 | return x 538 | 539 | def training_step(self, sources, batch_idx): 540 | """ 541 | The original code can be found here 542 | https://github.com/facebookresearch/demucs/blob/cb1d773a35ff889d25a5177b86c86c0ce8ba9ef3/demucs/solver.py#L290 543 | """ 544 | # self.sources (list[str]): list of source names 545 | # train_loader provide a batch of tracks with 4 sources 546 | if self.args.data_augmentation: 547 | sources = self.augment( 548 | sources 549 | ) # sources: [B, 4sources, 2, 
44100*segment_length] 550 | 551 | mix = sources.sum(dim=1) # mix: [B, 2 channels, 441000] 552 | estimate = self( 553 | mix 554 | ) # estimate [B, num_sources from user, 2, 44100*segment_length] 555 | 556 | # custom sources list 557 | # only calculate the loss for the user-provided sources list 558 | selected_sources = sources.index_select( 559 | 1, torch.tensor(self.idx_list).to(sources.device) 560 | ) 561 | 562 | # checking if the estimate has the correct shape 563 | assert estimate.shape == selected_sources.shape, (estimate.shape, selected_sources.shape) 564 | dims = tuple(range(2, sources.dim())) 565 | 566 | if self.args.optim.loss == "l1": 567 | loss = F.l1_loss(estimate, selected_sources, reduction="none") 568 | loss = loss.mean(dims).mean(0) 569 | reco = loss 570 | elif self.args.optim.loss == "mse": 571 | loss = F.mse_loss(estimate, selected_sources, reduction="none") 572 | loss = loss.mean(dims) 573 | reco = loss**0.5 574 | reco = reco.mean(0) 575 | else: 576 | raise ValueError(f"Invalid loss {self.args.loss}") 577 | 578 | weights = torch.tensor(self.args.weights).to(sources) 579 | loss = (loss * weights).sum() / weights.sum() 580 | 581 | # self.quantizer = get_quantizer(self, args.quant, self.optimizer) 582 | # quantization: get self.quantizer from train.py 583 | ms = 0 584 | # ms: model size 585 | if self.quantizer is not None: 586 | ms = self.quantizer.model_size() 587 | if self.args.quant.diffq: 588 | loss += self.args.quant.diffq * ms 589 | # use ms to calculate loss 590 | 591 | losses = {} 592 | losses["TRAIN/reco"] = (reco * weights).sum() / weights.sum() 593 | losses["TRAIN/ms"] = ms 594 | 595 | # penalty 596 | if self.args.svd.penalty > 0: 597 | kw = dict(self.args.svd) 598 | kw.pop("penalty") 599 | penalty = svd_penalty(self, **kw) 600 | losses["penalty"] = penalty 601 | loss += self.args.svd.penalty * penalty 602 | losses["TRAIN/loss"] = loss 603 | 604 | self.log_dict(losses, on_step=False, on_epoch=True, sync_dist=True) 605 | # log_dict(metrics; on_step=False: do not log every step, on_epoch=True: aggregate and log once per epoch) 606 | 607 | return loss 608 | 609 | def validation_step(self, sources, batch_idx): 610 | from .apply import apply_model 611 | 612 | # sources: [1, 1 + num_sources, 2, 7736477], the mixture is the first entry 613 | mix = sources[:, 0] 614 | sources = sources[:, 1:] 615 | 616 | if self.args.valid_apply: 617 | estimate = apply_model(self, mix, split=self.args.test.split, overlap=0) 618 | 619 | # checking if the estimate has the correct shape 620 | assert estimate.shape == sources.shape, (estimate.shape, sources.shape) 621 | dims = tuple(range(2, sources.dim())) 622 | 623 | if self.args.optim.loss == "l1": 624 | loss = F.l1_loss(estimate, sources, reduction="none") 625 | loss = loss.mean(dims).mean(0) 626 | reco = loss 627 | elif self.args.optim.loss == "mse": 628 | loss = F.mse_loss(estimate, sources, reduction="none") 629 | loss = loss.mean(dims) 630 | reco = loss**0.5 631 | reco = reco.mean(0) 632 | else: 633 | raise ValueError(f"Invalid loss {self.args.loss}") 634 | 635 | weights = torch.tensor(self.args.weights).to(sources) 636 | loss = (loss * weights).sum() / weights.sum() 637 | 638 | losses = {} 639 | losses["VAL/loss"] = loss 640 | 641 | nsdrs = new_sdr(sources, estimate.detach()).mean(0) 642 | # sources is each batch of the dataset [tensor] 643 | total = 0 644 | for source, nsdr, w in zip(self.sources, nsdrs, weights): 645 | # self.sources is [str] 646 | losses[f"VAL/nsdr_{source}"] = nsdr 647 | total += w * nsdr 648 | losses["VAL/nsdr"] = total / weights.sum() 649 | 650 | self.log_dict(losses, on_step=False, on_epoch=True,
sync_dist=True) 651 | 652 | return loss, nsdr 653 | 654 | # loss, reco, valid loss, nsdr 655 | def test_step(self, sources, batch_idx): 656 | from .apply import apply_model 657 | 658 | # sources: (1, 5, 2, 7736477), i.e. the mixture followed by the individual sources 659 | mix = sources[:, 0] 660 | sources = sources[:, 1:] 661 | 662 | if self.args.valid_apply: 663 | estimate = apply_model(self, mix, split=self.args.test.split, overlap=0) 664 | 665 | # checking if the estimate has the correct shape 666 | assert estimate.shape == sources.shape, (estimate.shape, sources.shape) 667 | dims = tuple(range(2, sources.dim())) 668 | 669 | if self.args.optim.loss == "l1": 670 | loss = F.l1_loss(estimate, sources, reduction="none") 671 | loss = loss.mean(dims).mean(0) 672 | reco = loss 673 | elif self.args.optim.loss == "mse": 674 | loss = F.mse_loss(estimate, sources, reduction="none") 675 | loss = loss.mean(dims) 676 | reco = loss**0.5 677 | reco = reco.mean(0) 678 | else: 679 | raise ValueError(f"Invalid loss {self.args.loss}") 680 | 681 | weights = torch.tensor(self.args.weights).to(sources) 682 | loss = (loss * weights).sum() / weights.sum() 683 | 684 | losses = {} 685 | losses["Test/loss"] = loss 686 | 687 | nsdrs = new_sdr(sources, estimate.detach()).mean(0) 688 | # sources is each batch of the dataset [tensor] 689 | total = 0 690 | for source, nsdr, w in zip(self.sources, nsdrs, weights): 691 | # self.sources is [str] 692 | losses[f"Test/nsdr_{source}"] = nsdr 693 | total += w * nsdr 694 | losses["Test/nsdr"] = total / weights.sum() 695 | 696 | self.log_dict(losses, on_step=False, on_epoch=True, sync_dist=True) 697 | 698 | # show the audio output in TensorBoard 699 | if batch_idx == 0: 700 | mixture_audio_stereo = mix 701 | # [1, 2, 9675225] 702 | mixture_audio_mono = torch.mean(mixture_audio_stereo, 1) 703 | # from stereo [1, 2, 9675225] to mono [1, 9675225] 704 | 705 | self.logger.experiment.add_audio( 706 | "test/mixture", 707 | snd_tensor=mixture_audio_mono.detach().cpu().numpy(), 708 | sample_rate=44100, 709 | ) 710 | # because snd_tensor needs to be mono (1, L) 711 | 712 | # data visualisation for audio 713 | for i, audio in enumerate(self.sources): 714 | label_stereo = sources[ 715 | :, i 716 | ] # from [1, 4, 2, 9675225] to [1, 2, 9675225] 717 | label_mono = torch.mean( 718 | label_stereo, 1 719 | ) # from stereo [1, 2, 9675225] to mono [1, 9675225] 720 | self.logger.experiment.add_audio( 721 | f"test/label/{audio}", 722 | snd_tensor=label_mono.detach().cpu().numpy(), 723 | sample_rate=44100, 724 | ) 725 | 726 | pred_stereo = estimate[ 727 | :, i 728 | ] # estimate [1, 4, 2, 9675225] to [1, 2, 9675225] 729 | pred_mono = torch.mean( 730 | pred_stereo, 1 731 | ) # from stereo [1, 2, 9675225] to mono [1, 9675225] 732 | self.logger.experiment.add_audio( 733 | f"test/pred/{audio}", 734 | snd_tensor=pred_mono.detach().cpu().numpy(), 735 | sample_rate=44100, 736 | ) 737 | 738 | return loss, nsdr 739 | 740 | def configure_optimizers(self): 741 | optimizer = torch.optim.Adam( 742 | self.parameters(), 743 | lr=self.args.optim.lr, 744 | betas=(self.args.optim.momentum, self.args.optim.beta2), 745 | weight_decay=self.args.optim.weight_decay, 746 | ) 747 | return optimizer 748 | 749 | def load_state_dict(self, state, strict=True): 750 | # fix a mismatch with previous generation Demucs models.
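        # (The remap below moves parameters saved under "{a}.{idx}.2.{b}" in older
        #  checkpoints into the "{a}.{idx}.3.{b}" slots used by the current layer layout.)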
751 | for idx in range(self.depth): 752 | for a in ["encoder", "decoder"]: 753 | for b in ["bias", "weight"]: 754 | new = f"{a}.{idx}.3.{b}" 755 | old = f"{a}.{idx}.2.{b}" 756 | if old in state and new not in state: 757 | state[new] = state.pop(old) 758 | super().load_state_dict(state, strict=strict) 759 | 760 | def predict_step( 761 | self, batch, batch_idx 762 | ): # the dataloader returns each batch as (waveform: tensor, (audio_name: str)) 763 | from .apply import apply_model 764 | 765 | estimate = apply_model( 766 | self, batch[0], split=True, overlap=0 767 | ) # estimate: [1, num_sources, 2, 9675225] 768 | 769 | # os.makedirs() 770 | # make a dir to store the separated tracks 771 | # './' refers to the PyTorch Lightning outputs folder location 772 | os.makedirs(os.path.join("./", (batch[1][0])), exist_ok=True) 773 | 774 | for i, audio in enumerate(self.sources): 775 | pred_stereo = estimate[:, i] # pred_stereo: [1, 2, 9675225] 776 | pred_mono = torch.mean( 777 | pred_stereo, 1 778 | ) # from stereo [1, 2, 9675225] to mono [1, 9675225] 779 | 780 | # export the separated audio with torchaudio.save(path, waveform, sample_rate) 781 | # Input tensor has to be 2D 782 | torchaudio.save( 783 | os.path.join("./", (batch[1][0]), audio + ".wav"), 784 | pred_mono.detach().cpu(), 785 | self.args.samplerate, 786 | ) 787 | --------------------------------------------------------------------------------