├── .github ├── ISSUE_TEMPLATE │ ├── bug.md │ └── question.md └── workflows │ ├── linter.yml │ └── tests.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Demucs.ipynb ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── conf ├── config.yaml ├── dset │ ├── aetl.yaml │ ├── auto_extra_test.yaml │ ├── auto_mus.yaml │ ├── extra44.yaml │ ├── extra_mmi_goodclean.yaml │ ├── extra_test.yaml │ ├── musdb44.yaml │ ├── sdx23_bleeding.yaml │ └── sdx23_labelnoise.yaml ├── svd │ ├── base.yaml │ ├── base2.yaml │ └── default.yaml └── variant │ ├── default.yaml │ ├── example.yaml │ └── finetune.yaml ├── demucs.png ├── demucs ├── __init__.py ├── __main__.py ├── api.py ├── apply.py ├── audio.py ├── audio_legacy.py ├── augment.py ├── demucs.py ├── distrib.py ├── ema.py ├── evaluate.py ├── grids │ ├── __init__.py │ ├── _explorers.py │ ├── mdx.py │ ├── mdx_extra.py │ ├── mdx_refine.py │ ├── mmi.py │ ├── mmi_ft.py │ ├── repro.py │ ├── repro_ft.py │ └── sdx23.py ├── hdemucs.py ├── htdemucs.py ├── pretrained.py ├── py.typed ├── remote │ ├── files.txt │ ├── hdemucs_mmi.yaml │ ├── htdemucs.yaml │ ├── htdemucs_6s.yaml │ ├── htdemucs_ft.yaml │ ├── mdx.yaml │ ├── mdx_extra.yaml │ ├── mdx_extra_q.yaml │ ├── mdx_q.yaml │ ├── repro_mdx_a.yaml │ ├── repro_mdx_a_hybrid_only.yaml │ └── repro_mdx_a_time_only.yaml ├── repitch.py ├── repo.py ├── separate.py ├── solver.py ├── spec.py ├── states.py ├── svd.py ├── train.py ├── transformer.py ├── utils.py ├── wav.py └── wdemucs.py ├── docs ├── api.md ├── linux.md ├── mac.md ├── mdx.md ├── release.md ├── sdx23.md ├── training.md └── windows.md ├── environment-cpu.yml ├── environment-cuda.yml ├── hubconf.py ├── mypy.ini ├── outputs.tar.gz ├── requirements.txt ├── requirements_minimal.txt ├── setup.cfg ├── setup.py ├── test.mp3 └── tools ├── __init__.py ├── automix.py ├── bench.py ├── convert.py ├── export.py └── test_pretrained.py /.github/ISSUE_TEMPLATE/bug.md: 
-------------------------------------------------------------------------------- 1 | --- 2 | name: 🐛 Bug Report 3 | about: Submit a bug report to help us improve 4 | labels: 'bug' 5 | --- 6 | 7 | ## 🐛 Bug Report 8 | 9 | (A clear and concise description of what the bug is) 10 | 11 | ## To Reproduce 12 | 13 | (Write your steps here:) 14 | 15 | 1. Step 1... 16 | 1. Step 2... 17 | 1. Step 3... 18 | 19 | ## Expected behavior 20 | 21 | (Write what you thought would happen.) 22 | 23 | ## Actual Behavior 24 | 25 | (Write what happened. Add screenshots, if applicable.) 26 | 27 | ## Your Environment 28 | 29 | 30 | 31 | - Python and PyTorch version: 32 | - Operating system and version (desktop or mobile): 33 | - Hardware (gpu or cpu, amount of RAM etc.): 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "❓Questions/Help/Support" 3 | about: If you have a question about the paper, code or algorithm, please ask here! 4 | labels: question 5 | 6 | --- 7 | 8 | ## ❓ Questions 9 | 10 | (Please ask your question here.) 
11 | -------------------------------------------------------------------------------- /.github/workflows/linter.yml: -------------------------------------------------------------------------------- 1 | name: linter 2 | on: 3 | push: 4 | branches: [ main ] 5 | pull_request: 6 | branches: [ main ] 7 | workflow_dispatch: 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | if: ${{ github.repository == 'facebookresearch/demucs' || github.event_name == 'workflow_dispatch' }} 13 | steps: 14 | - uses: actions/checkout@v2 15 | - uses: actions/setup-python@v2 16 | with: 17 | python-version: 3.8 18 | 19 | - uses: actions/cache@v2 20 | with: 21 | path: env 22 | key: env-${{ hashFiles('**/requirements.txt', '.github/workflows/*') }} 23 | 24 | - name: Install dependencies 25 | run: | 26 | python3 -m venv env 27 | . env/bin/activate 28 | python -m pip install --upgrade pip 29 | pip install -r requirements.txt 30 | pip install '.[dev]' 31 | 32 | 33 | - name: Run linter 34 | run: | 35 | . env/bin/activate 36 | make linter 37 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: tests 2 | on: 3 | push: 4 | branches: [ main ] 5 | pull_request: 6 | branches: [ main ] 7 | workflow_dispatch: 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | if: ${{ github.repository == 'facebookresearch/demucs' || github.event_name == 'workflow_dispatch' }} 13 | steps: 14 | - uses: actions/checkout@v2 15 | - uses: actions/setup-python@v2 16 | with: 17 | python-version: 3.8 18 | 19 | - uses: actions/cache@v2 20 | with: 21 | path: env 22 | key: env-${{ hashFiles('**/requirements.txt', '.github/workflows/*') }} 23 | 24 | - name: Install dependencies 25 | run: | 26 | sudo apt-get update 27 | sudo apt-get install -y ffmpeg 28 | python3 -m venv env 29 | . 
env/bin/activate 30 | python -m pip install --upgrade pip 31 | pip install -r requirements.txt 32 | 33 | - name: Run separation test 34 | run: | 35 | . env/bin/activate 36 | make test_eval 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | __pycache__ 3 | Session.vim 4 | /build 5 | /dist 6 | /lab 7 | /metadata 8 | /notebooks 9 | /outputs 10 | /release 11 | /release_models 12 | /separated 13 | /tests 14 | /trash 15 | /misc 16 | /mdx 17 | .mypy_cache 18 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at opensource-conduct@fb.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Demucs 2 | 3 | ## Pull Requests 4 | 5 | In order to accept your pull request, we need you to submit a CLA. You only need 6 | to do this once to work on any of Facebook's open source projects. 7 | 8 | Complete your CLA here: https://code.facebook.com/cla 9 | 10 | Demucs is the implementation of a research paper. 11 | Therefore, we do not plan on accepting many pull requests for new features. 12 | We certainly welcome them for bug fixes. 13 | 14 | 15 | ## Issues 16 | 17 | We use GitHub issues to track public bugs. Please ensure your description is 18 | clear and has sufficient instructions to be able to reproduce the issue. 
19 | 20 | 21 | ## License 22 | By contributing to this repository, you agree that your contributions will be licensed 23 | under the LICENSE file in the root directory of this source tree. 24 | -------------------------------------------------------------------------------- /Demucs.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "Be9yoh-ILfRr" 8 | }, 9 | "source": [ 10 | "# Hybrid Demucs\n", 11 | "\n", 12 | "Feel free to use the Colab version:\n", 13 | "https://colab.research.google.com/drive/1dC9nVxk3V_VPjUADsnFu8EiT-xnU1tGH?usp=sharing" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": { 20 | "colab": { 21 | "base_uri": "https://localhost:8080/", 22 | "height": 139 23 | }, 24 | "colab_type": "code", 25 | "executionInfo": { 26 | "elapsed": 12277, 27 | "status": "ok", 28 | "timestamp": 1583778134659, 29 | "user": { 30 | "displayName": "Marllus Lustosa", 31 | "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GgLl2RbW64ZyWz3Y8IBku0zhHCMnt7fz7fEl0LTdA=s64", 32 | "userId": "14811735256675200480" 33 | }, 34 | "user_tz": 180 35 | }, 36 | "id": "kOjIPLlzhPfn", 37 | "outputId": "c75f17ec-b576-4105-bc5b-c2ac9c1018a3" 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "!pip install -U demucs\n", 42 | "# or for local development, if you have a clone of Demucs\n", 43 | "# pip install -e ." 
44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": { 50 | "colab": {}, 51 | "colab_type": "code", 52 | "id": "5lYOzKKCKAbJ" 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "# You can use the `demucs` command line to separate tracks\n", 57 | "!demucs test.mp3" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "# You can also directly load the pretrained models,\n", 67 | "# for instance for the MDX 2021 winning model of Track A:\n", 68 | "from demucs import pretrained\n", 69 | "model = pretrained.get_model('mdx')" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "# Because `model` is a bag of 4 models, you cannot directly call it on your data,\n", 79 | "# but `apply_model` will know what to do with it.\n", 80 | "import torch\n", 81 | "from demucs.apply import apply_model\n", 82 | "x = torch.randn(1, 2, 44100 * 10) # ten seconds of white noise for the demo\n", 83 | "out = apply_model(model, x)[0] # shape is [S, C, T] with S the number of sources\n", 84 | "\n", 85 | "# So let's see, where is all the white noise content going?\n", 86 | "for name, source in zip(model.sources, out):\n", 87 | " print(name, source.std() / x.std())\n", 88 | "# The outputs are quite weird to be fair, not what I would have expected." 
89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "# now let's take a single model from the bag, and let's test it on a pure cosine\n", 98 | "freq = 440 # in Hz\n", 99 | "sr = model.samplerate\n", 100 | "t = torch.arange(10 * sr).float() / sr\n", 101 | "x = torch.cos(2 * 3.1416 * freq * t).expand(1, 2, -1)\n", 102 | "sub_model = model.models[3]\n", 103 | "out = sub_model(x)[0]\n", 104 | "\n", 105 | "# Same question where does it go?\n", 106 | "for name, source in zip(model.sources, out):\n", 107 | " print(name, source.std() / x.std())\n", 108 | " \n", 109 | "# Well now it makes much more sense, all the energy is going\n", 110 | "# in the `other` source.\n", 111 | "# Feel free to try lower pitch (try 80 Hz) to see what happens !" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "# For training or more fun, refer to the Demucs README on our repo\n", 121 | "# https://github.com/facebookresearch/demucs/tree/main/demucs" 122 | ] 123 | } 124 | ], 125 | "metadata": { 126 | "accelerator": "GPU", 127 | "colab": { 128 | "authorship_tag": "ABX9TyM9xpVr1M86NRcjtQ7g9tCx", 129 | "collapsed_sections": [], 130 | "name": "Demucs.ipynb", 131 | "provenance": [] 132 | }, 133 | "kernelspec": { 134 | "display_name": "Python 3", 135 | "language": "python", 136 | "name": "python3" 137 | }, 138 | "language_info": { 139 | "codemirror_mode": { 140 | "name": "ipython", 141 | "version": 3 142 | }, 143 | "file_extension": ".py", 144 | "mimetype": "text/x-python", 145 | "name": "python", 146 | "nbconvert_exporter": "python", 147 | "pygments_lexer": "ipython3", 148 | "version": "3.8.8" 149 | } 150 | }, 151 | "nbformat": 4, 152 | "nbformat_minor": 1 153 | } 154 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-exclude env * 2 | recursive-include conf *.yaml 3 | include Makefile 4 | include LICENSE 5 | include demucs.png 6 | include outputs.tar.gz 7 | include test.mp3 8 | include requirements.txt 9 | include requirements_minimal.txt 10 | include mypy.ini 11 | include demucs/py.typed 12 | include demucs/remote/*.txt 13 | include demucs/remote/*.yaml 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: linter tests 2 | 3 | linter: 4 | flake8 demucs 5 | mypy demucs 6 | 7 | tests: test_train test_eval 8 | 9 | test_train: tests/musdb 10 | _DORA_TEST_PATH=/tmp/demucs python3 -m dora run --clear \ 11 | dset.musdb=./tests/musdb dset.segment=4 dset.shift=2 epochs=2 model=demucs \ 12 | demucs.depth=2 demucs.channels=4 test.sdr=false misc.num_workers=0 test.workers=0 \ 13 | test.shifts=0 14 | 15 | test_eval: 16 | python3 -m demucs -n demucs_unittest test.mp3 17 | python3 -m demucs -n demucs_unittest --two-stems=vocals test.mp3 18 | python3 -m demucs -n demucs_unittest --mp3 test.mp3 19 | python3 -m demucs -n demucs_unittest --flac --int24 test.mp3 20 | python3 -m demucs -n demucs_unittest --int24 --clip-mode clamp test.mp3 21 | python3 -m demucs -n demucs_unittest --segment 8 test.mp3 22 | python3 -m demucs.api -n demucs_unittest --segment 8 test.mp3 23 | python3 -m demucs --list-models 24 | 25 | tests/musdb: 26 | test -e tests || mkdir tests 27 | python3 -c 'import musdb; musdb.DB("tests/tmp", download=True)' 28 | musdbconvert tests/tmp tests/musdb 29 | 30 | dist: 31 | python3 setup.py sdist 32 | 33 | clean: 34 | rm -r dist build *.egg-info 35 | 36 | .PHONY: linter dist test_train test_eval 37 | 
-------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - dset: musdb44 4 | - svd: default 5 | - variant: default 6 | - override hydra/hydra_logging: colorlog 7 | - override hydra/job_logging: colorlog 8 | 9 | dummy: 10 | dset: 11 | musdb: /checkpoint/defossez/datasets/musdbhq 12 | musdb_samplerate: 44100 13 | use_musdb: true # set to false to not use musdb as training data. 14 | wav: # path to custom wav dataset 15 | wav2: # second custom wav dataset 16 | segment: 11 17 | shift: 1 18 | train_valid: false 19 | full_cv: true 20 | samplerate: 44100 21 | channels: 2 22 | normalize: true 23 | metadata: ./metadata 24 | sources: ['drums', 'bass', 'other', 'vocals'] 25 | valid_samples: # valid dataset size 26 | backend: null # if provided select torchaudio backend. 27 | 28 | test: 29 | save: False 30 | best: True 31 | workers: 2 32 | every: 20 33 | split: true 34 | shifts: 1 35 | overlap: 0.25 36 | sdr: true 37 | metric: 'loss' # metric used for best model selection on the valid set, can also be nsdr 38 | nonhq: # path to non hq MusDB for evaluation 39 | 40 | epochs: 360 41 | batch_size: 64 42 | max_batches: # limit the number of batches per epoch, useful for debugging 43 | # or if your dataset is gigantic. 44 | optim: 45 | lr: 3e-4 46 | momentum: 0.9 47 | beta2: 0.999 48 | loss: l1 # l1 or mse 49 | optim: adam 50 | weight_decay: 0 51 | clip_grad: 0 52 | 53 | seed: 42 54 | debug: false 55 | valid_apply: true 56 | flag: 57 | save_every: 58 | weights: [1., 1., 1., 1.] # weights over each source for the training/valid loss. 59 | 60 | augment: 61 | shift_same: false 62 | repitch: 63 | proba: 0.2 64 | max_tempo: 12 65 | remix: 66 | proba: 1 67 | group_size: 4 68 | scale: 69 | proba: 1 70 | min: 0.25 71 | max: 1.25 72 | flip: true 73 | 74 | continue_from: # continue from other XP, give the XP Dora signature. 
75 | continue_pretrained: # signature of a pretrained XP, this cannot be a bag of models. 76 | pretrained_repo: # repo for pretrained model (default is official AWS) 77 | continue_best: true 78 | continue_opt: false 79 | 80 | misc: 81 | num_workers: 10 82 | num_prints: 4 83 | show: false 84 | verbose: false 85 | 86 | # List of decay for EMA at batch or epoch level, e.g. 0.999. 87 | # Batch level EMA are kept on GPU for speed. 88 | ema: 89 | epoch: [] 90 | batch: [] 91 | 92 | use_train_segment: true # to remove 93 | model_segment: # override the segment parameter for the model, usually 4 times the training segment. 94 | model: demucs # see demucs/train.py for the possibilities, and config for each model hereafter. 95 | demucs: # see demucs/demucs.py for a detailed description 96 | # Channels 97 | channels: 64 98 | growth: 2 99 | # Main structure 100 | depth: 6 101 | rewrite: true 102 | lstm_layers: 0 103 | # Convolutions 104 | kernel_size: 8 105 | stride: 4 106 | context: 1 107 | # Activations 108 | gelu: true 109 | glu: true 110 | # Normalization 111 | norm_groups: 4 112 | norm_starts: 4 113 | # DConv residual branch 114 | dconv_depth: 2 115 | dconv_mode: 1 # 1 = branch in encoder, 2 = in decoder, 3 = in both. 
116 | dconv_comp: 4 117 | dconv_attn: 4 118 | dconv_lstm: 4 119 | dconv_init: 1e-4 120 | # Pre/post treatment 121 | resample: true 122 | normalize: false 123 | # Weight init 124 | rescale: 0.1 125 | 126 | hdemucs: # see demucs/hdemucs.py for a detailed description 127 | # Channels 128 | channels: 48 129 | channels_time: 130 | growth: 2 131 | # STFT 132 | nfft: 4096 133 | wiener_iters: 0 134 | end_iters: 0 135 | wiener_residual: false 136 | cac: true 137 | # Main structure 138 | depth: 6 139 | rewrite: true 140 | hybrid: true 141 | hybrid_old: false 142 | # Frequency Branch 143 | multi_freqs: [] 144 | multi_freqs_depth: 3 145 | freq_emb: 0.2 146 | emb_scale: 10 147 | emb_smooth: true 148 | # Convolutions 149 | kernel_size: 8 150 | stride: 4 151 | time_stride: 2 152 | context: 1 153 | context_enc: 0 154 | # normalization 155 | norm_starts: 4 156 | norm_groups: 4 157 | # DConv residual branch 158 | dconv_mode: 1 159 | dconv_depth: 2 160 | dconv_comp: 4 161 | dconv_attn: 4 162 | dconv_lstm: 4 163 | dconv_init: 1e-3 164 | # Weight init 165 | rescale: 0.1 166 | 167 | # Torchaudio implementation of HDemucs 168 | torch_hdemucs: 169 | # Channels 170 | channels: 48 171 | growth: 2 172 | # STFT 173 | nfft: 4096 174 | # Main structure 175 | depth: 6 176 | freq_emb: 0.2 177 | emb_scale: 10 178 | emb_smooth: true 179 | # Convolutions 180 | kernel_size: 8 181 | stride: 4 182 | time_stride: 2 183 | context: 1 184 | context_enc: 0 185 | # normalization 186 | norm_starts: 4 187 | norm_groups: 4 188 | # DConv residual branch 189 | dconv_depth: 2 190 | dconv_comp: 4 191 | dconv_attn: 4 192 | dconv_lstm: 4 193 | dconv_init: 1e-3 194 | 195 | htdemucs: # see demucs/htdemucs.py for a detailed description 196 | # Channels 197 | channels: 48 198 | channels_time: 199 | growth: 2 200 | # STFT 201 | nfft: 4096 202 | wiener_iters: 0 203 | end_iters: 0 204 | wiener_residual: false 205 | cac: true 206 | # Main structure 207 | depth: 4 208 | rewrite: true 209 | # Frequency Branch 210 | 
multi_freqs: [] 211 | multi_freqs_depth: 3 212 | freq_emb: 0.2 213 | emb_scale: 10 214 | emb_smooth: true 215 | # Convolutions 216 | kernel_size: 8 217 | stride: 4 218 | time_stride: 2 219 | context: 1 220 | context_enc: 0 221 | # normalization 222 | norm_starts: 4 223 | norm_groups: 4 224 | # DConv residual branch 225 | dconv_mode: 1 226 | dconv_depth: 2 227 | dconv_comp: 8 228 | dconv_init: 1e-3 229 | # Before the Transformer 230 | bottom_channels: 0 231 | # CrossTransformer 232 | # ------ Common to all 233 | # Regular parameters 234 | t_layers: 5 235 | t_hidden_scale: 4.0 236 | t_heads: 8 237 | t_dropout: 0.0 238 | t_layer_scale: True 239 | t_gelu: True 240 | # ------------- Positional Embedding 241 | t_emb: sin 242 | t_max_positions: 10000 # for the scaled embedding 243 | t_max_period: 10000.0 244 | t_weight_pos_embed: 1.0 245 | t_cape_mean_normalize: True 246 | t_cape_augment: True 247 | t_cape_glob_loc_scale: [5000.0, 1.0, 1.4] 248 | t_sin_random_shift: 0 249 | # ------------- norm before a transformer encoder 250 | t_norm_in: True 251 | t_norm_in_group: False 252 | # ------------- norm inside the encoder 253 | t_group_norm: False 254 | t_norm_first: True 255 | t_norm_out: True 256 | # ------------- optim 257 | t_weight_decay: 0.0 258 | t_lr: 259 | # ------------- sparsity 260 | t_sparse_self_attn: False 261 | t_sparse_cross_attn: False 262 | t_mask_type: diag 263 | t_mask_random_seed: 42 264 | t_sparse_attn_window: 400 265 | t_global_window: 100 266 | t_sparsity: 0.95 267 | t_auto_sparsity: False 268 | # Cross Encoder First (False) 269 | t_cross_first: False 270 | # Weight init 271 | rescale: 0.1 272 | 273 | svd: # see svd.py for documentation 274 | penalty: 0 275 | min_size: 0.1 276 | dim: 1 277 | niters: 2 278 | powm: false 279 | proba: 1 280 | conv_only: false 281 | convtr: false 282 | bs: 1 283 | 284 | quant: # quantization hyper params 285 | diffq: # diffq penalty, typically 1e-4 or 3e-4 286 | qat: # use QAT with a fixed number of bits (not as good as 
diffq) 287 | min_size: 0.2 288 | group_size: 8 289 | 290 | dora: 291 | dir: outputs 292 | exclude: ["misc.*", "slurm.*", 'test.reval', 'flag', 'dset.backend'] 293 | 294 | slurm: 295 | time: 4320 296 | constraint: volta32gb 297 | setup: ['module load cudnn/v8.4.1.50-cuda.11.6 NCCL/2.11.4-6-cuda.11.6 cuda/11.6'] 298 | 299 | # Hydra config 300 | hydra: 301 | job_logging: 302 | formatters: 303 | colorlog: 304 | datefmt: "%m-%d %H:%M:%S" 305 | -------------------------------------------------------------------------------- /conf/dset/aetl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # automix dataset with Musdb, extra training data and the test set of Musdb. 4 | # This used even more remixes than auto_extra_test. 5 | dset: 6 | wav: /checkpoint/defossez/datasets/aetl 7 | samplerate: 44100 8 | channels: 2 9 | epochs: 320 10 | max_batches: 500 11 | 12 | augment: 13 | shift_same: true 14 | scale: 15 | proba: 0. 16 | remix: 17 | proba: 0 18 | repitch: 19 | proba: 0 20 | -------------------------------------------------------------------------------- /conf/dset/auto_extra_test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # automix dataset with Musdb, extra training data and the test set of Musdb. 4 | dset: 5 | wav: /checkpoint/defossez/datasets/automix_extra_test2 6 | samplerate: 44100 7 | channels: 2 8 | epochs: 320 9 | max_batches: 500 10 | 11 | augment: 12 | shift_same: true 13 | scale: 14 | proba: 0. 15 | remix: 16 | proba: 0 17 | repitch: 18 | proba: 0 19 | -------------------------------------------------------------------------------- /conf/dset/auto_mus.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Automix dataset based on musdb train set. 
4 | dset: 5 | wav: /checkpoint/defossez/datasets/automix_musdb 6 | samplerate: 44100 7 | channels: 2 8 | epochs: 360 9 | max_batches: 300 10 | test: 11 | every: 4 12 | 13 | augment: 14 | shift_same: true 15 | scale: 16 | proba: 0.5 17 | remix: 18 | proba: 0 19 | repitch: 20 | proba: 0 21 | -------------------------------------------------------------------------------- /conf/dset/extra44.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Musdb + extra tracks 4 | dset: 5 | wav: /checkpoint/defossez/datasets/allstems_44/ 6 | samplerate: 44100 7 | channels: 2 8 | epochs: 320 9 | -------------------------------------------------------------------------------- /conf/dset/extra_mmi_goodclean.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Musdb + extra tracks 4 | dset: 5 | wav: /checkpoint/defossez/datasets/allstems_44/ 6 | wav2: /checkpoint/defossez/datasets/mmi44_goodclean 7 | samplerate: 44100 8 | channels: 2 9 | wav2_weight: null 10 | wav2_valid: false 11 | valid_samples: 100 12 | epochs: 1200 13 | -------------------------------------------------------------------------------- /conf/dset/extra_test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Musdb + extra tracks + test set from musdb. 
4 | dset: 5 | wav: /checkpoint/defossez/datasets/allstems_test_44/ 6 | samplerate: 44100 7 | channels: 2 8 | epochs: 320 9 | max_batches: 700 10 | test: 11 | sdr: false 12 | every: 500 13 | -------------------------------------------------------------------------------- /conf/dset/musdb44.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | dset: 4 | samplerate: 44100 5 | channels: 2 -------------------------------------------------------------------------------- /conf/dset/sdx23_bleeding.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Musdb + extra tracks 4 | dset: 5 | wav: /shared/home/defossez/data/datasets/moisesdb23_bleeding_v1.0/ 6 | use_musdb: false 7 | samplerate: 44100 8 | channels: 2 9 | backend: soundfile # must use soundfile as some mixture would clip with sox. 10 | epochs: 320 11 | -------------------------------------------------------------------------------- /conf/dset/sdx23_labelnoise.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Musdb + extra tracks 4 | dset: 5 | wav: /shared/home/defossez/data/datasets/moisesdb23_labelnoise_v1.0 6 | use_musdb: false 7 | samplerate: 44100 8 | channels: 2 9 | backend: soundfile # must use soundfile as some mixture would clip with sox. 10 | epochs: 320 11 | -------------------------------------------------------------------------------- /conf/svd/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | svd: 4 | penalty: 0 5 | min_size: 1 6 | dim: 50 7 | niters: 4 8 | powm: false 9 | proba: 1 10 | conv_only: false 11 | convtr: false # ideally this should be true, but some models were trained with this to false. 
12 | 13 | optim: 14 | beta2: 0.9998 -------------------------------------------------------------------------------- /conf/svd/base2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | svd: 4 | penalty: 0 5 | min_size: 1 6 | dim: 100 7 | niters: 4 8 | powm: false 9 | proba: 1 10 | conv_only: false 11 | convtr: true 12 | 13 | optim: 14 | beta2: 0.9998 -------------------------------------------------------------------------------- /conf/svd/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | -------------------------------------------------------------------------------- /conf/variant/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | -------------------------------------------------------------------------------- /conf/variant/example.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: hdemucs 4 | hdemucs: 5 | channels: 32 -------------------------------------------------------------------------------- /conf/variant/finetune.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | epochs: 4 4 | batch_size: 16 5 | optim: 6 | lr: 0.0006 7 | test: 8 | every: 1 9 | sdr: false 10 | dset: 11 | segment: 28 12 | shift: 2 13 | 14 | augment: 15 | scale: 16 | proba: 0 17 | shift_same: true 18 | remix: 19 | proba: 0 20 | -------------------------------------------------------------------------------- /demucs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adefossez/demucs/b9ab48cad45976ba42b2ff17b229c071f0df9390/demucs.png -------------------------------------------------------------------------------- /demucs/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | __version__ = "4.1.0a3" 8 | -------------------------------------------------------------------------------- /demucs/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .separate import main 8 | 9 | if __name__ == '__main__': 10 | main() 11 | -------------------------------------------------------------------------------- /demucs/apply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Code to apply a model to a mix. It will handle chunking with overlaps and 8 | interpolation between chunks, as well as the "shift trick". 
Model = tp.Union[Demucs, HDemucs, HTDemucs]


class BagOfModels(nn.Module):
    """Ensemble of separation models combined with per-model, per-source weights."""

    def __init__(self, models: tp.List[Model],
                 weights: tp.Optional[tp.List[tp.List[float]]] = None,
                 segment: tp.Optional[float] = None):
        """
        Represents a bag of models with specific weights.
        You should call `apply_model` rather than calling directly the forward here for
        optimal performance.

        Args:
            models (list[nn.Module]): list of Demucs/HDemucs models. All of them must share
                the same `sources`, `samplerate` and `audio_channels`.
            weights (list[list[float]]): list of weights. If None, assumed to
                be all ones, otherwise it should be a list of N list (N number of models),
                each containing S floats (S number of sources).
            segment (None or float): overrides the `segment` attribute of each model
                (this is performed inplace, be careful if you reuse the models passed).
                For `HTDemucs` models the override is only applied when it does not
                exceed the model's current `segment`.
        """
        super().__init__()
        assert len(models) > 0
        first = models[0]
        for other in models:
            assert other.sources == first.sources
            assert other.samplerate == first.samplerate
            assert other.audio_channels == first.audio_channels
            if segment is not None:
                if not isinstance(other, HTDemucs) or segment <= other.segment:
                    other.segment = segment

        self.audio_channels = first.audio_channels
        self.samplerate = first.samplerate
        self.sources = first.sources
        self.models = nn.ModuleList(models)

        if weights is None:
            weights = [[1. for _ in first.sources] for _ in models]
        else:
            assert len(weights) == len(models)
            for weight in weights:
                assert len(weight) == len(first.sources)
        self.weights = weights

    @property
    def max_allowed_segment(self) -> float:
        """Smallest `segment` among the `HTDemucs` members, or infinity if there is none."""
        max_allowed_segment = float('inf')
        for model in self.models:
            if isinstance(model, HTDemucs):
                max_allowed_segment = min(max_allowed_segment, float(model.segment))
        return max_allowed_segment

    def forward(self, x):
        raise NotImplementedError("Call `apply_model` on this.")


class TensorChunk:
    """
    Window over the last dimension of a tensor, described by an `offset` and a `length`.
    Only a reference to the underlying tensor is stored; `padded` extracts the data.
    """
    def __init__(self, tensor, offset=0, length=None):
        total_length = tensor.shape[-1]
        assert offset >= 0
        assert offset < total_length

        if length is None:
            length = total_length - offset
        else:
            # Never let the window run past the end of the tensor.
            length = min(total_length - offset, length)

        if isinstance(tensor, TensorChunk):
            # Chunk of a chunk: point at the root tensor and accumulate the offsets.
            self.tensor = tensor.tensor
            self.offset = offset + tensor.offset
        else:
            self.tensor = tensor
            self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        """Shape of the window, as a list (last dimension is `length`)."""
        shape = list(self.tensor.shape)
        shape[-1] = self.length
        return shape

    def padded(self, target_length):
        """
        Return a tensor whose last dimension is `target_length`, centered on this chunk.
        The extra context is taken from the underlying tensor when available and
        zero-padded otherwise.
        """
        delta = target_length - self.length
        total_length = self.tensor.shape[-1]
        assert delta >= 0

        start = self.offset - delta // 2
        end = start + target_length

        correct_start = max(0, start)
        correct_end = min(total_length, end)

        pad_left = correct_start - start
        pad_right = end - correct_end

        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
        assert out.shape[-1] == target_length
        return out


def tensor_chunk(tensor_or_chunk):
    """Return the argument as a `TensorChunk`, wrapping plain tensors."""
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    else:
        assert isinstance(tensor_or_chunk, th.Tensor)
        return TensorChunk(tensor_or_chunk)


def _replace_dict(_dict: tp.Optional[dict], *subs: tp.Tuple[tp.Hashable, tp.Any]) -> dict:
    """Return a shallow copy of `_dict` (or a new dict if None) updated with `subs` pairs."""
    if _dict is None:
        _dict = {}
    else:
        _dict = copy.copy(_dict)
    for key, value in subs:
        _dict[key] = value
    return _dict


def apply_model(model: tp.Union[BagOfModels, Model],
                mix: tp.Union[th.Tensor, TensorChunk],
                shifts: int = 1, split: bool = True,
                overlap: float = 0.25, transition_power: float = 1.,
                progress: bool = False, device=None,
                num_workers: int = 0, segment: tp.Optional[float] = None,
                pool=None, lock=None,
                callback: tp.Optional[tp.Callable[[dict], None]] = None,
                callback_arg: tp.Optional[dict] = None) -> th.Tensor:
    """
    Apply model to a given mixture.

    Args:
        model: a single model or a `BagOfModels`. For a bag, every sub model is applied
            in turn and the predictions are averaged using the bag weights.
        mix: the mixture, with shape `[batch, channels, length]`.
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` time and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down into extracts of `segment`
            seconds and predictions will be performed individually on each and
            cross-faded together.
            Useful for model with large memory footprint like Tasnet.
        overlap (float): fraction of a segment shared by two consecutive extracts
            when `split` is True.
        transition_power (float): power applied to the triangular cross-fade weight,
            must be >= 1. Larger values lead to sharper transitions.
        progress (bool): if True, show a progress bar (requires split=True)
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.
        num_workers (int): if non zero and device is 'cpu', how many threads to
            use in parallel.
        segment (float or None): override the model segment parameter.
        pool: executor used to schedule the per-extract calls. Created here when None,
            and forwarded as is to the recursive calls.
        lock: lock held while invoking `callback`. Created here when None.
        callback: if provided, called with a dict before (`"state": "start"`) and after
            (`"state": "end"`) each forward pass. The dict also carries
            `model_idx_in_bag`, `shift_idx`, `segment_offset` and `models`.
        callback_arg (dict or None): extra entries forwarded to `callback`. The dict passed
            by the caller is copied, never modified.

    Returns:
        the estimates, with shape `[batch, sources, channels, length]`.
    """
    if device is None:
        device = mix.device
    else:
        device = th.device(device)
    if pool is None:
        if num_workers > 0 and device.type == 'cpu':
            pool = ThreadPoolExecutor(num_workers)
        else:
            pool = DummyPoolExecutor()
    if lock is None:
        lock = Lock()
    callback_arg = _replace_dict(
        callback_arg, *{"model_idx_in_bag": 0, "shift_idx": 0, "segment_offset": 0}.items()
    )
    kwargs: tp.Dict[str, tp.Any] = {
        'shifts': shifts,
        'split': split,
        'overlap': overlap,
        'transition_power': transition_power,
        'progress': progress,
        'device': device,
        'pool': pool,
        'segment': segment,
        'lock': lock,
    }
    out: tp.Union[float, th.Tensor]
    res: tp.Union[float, th.Tensor]
    if isinstance(model, BagOfModels):
        # Special treatment for bag of model.
        # We explicitly apply multiple times `apply_model` so that the random shifts
        # are different for each model.
        estimates: tp.Union[float, th.Tensor] = 0.
        totals = [0.] * len(model.sources)
        callback_arg["models"] = len(model.models)
        for sub_model, model_weights in zip(model.models, model.weights):
            # `i` is bound as a default so each lambda keeps its own model index.
            kwargs["callback"] = ((
                    lambda d, i=callback_arg["model_idx_in_bag"]: callback(
                        _replace_dict(d, ("model_idx_in_bag", i))) if callback else None)
            )
            original_model_device = next(iter(sub_model.parameters())).device
            sub_model.to(device)
            try:
                res = apply_model(sub_model, mix, **kwargs, callback_arg=callback_arg)
            finally:
                # Always give the sub model back on its original device, even on failure.
                sub_model.to(original_model_device)
            out = res
            for k, inst_weight in enumerate(model_weights):
                out[:, k, :, :] *= inst_weight
                totals[k] += inst_weight
            estimates += out
            del out
            callback_arg["model_idx_in_bag"] += 1

        assert isinstance(estimates, th.Tensor)
        for k in range(estimates.shape[1]):
            estimates[:, k, :, :] /= totals[k]
        return estimates

    if "models" not in callback_arg:
        callback_arg["models"] = 1
    model.to(device)
    model.eval()
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    batch, channels, length = mix.shape
    if shifts:
        # The recursive calls must not shift again.
        kwargs['shifts'] = 0
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        assert isinstance(mix, TensorChunk)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0.
        for shift_idx in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            kwargs["callback"] = (
                (lambda d, i=shift_idx: callback(_replace_dict(d, ("shift_idx", i)))
                 if callback else None)
            )
            res = apply_model(model, shifted, **kwargs, callback_arg=callback_arg)
            shifted_out = res
            # Undo the shift so that all predictions are aligned before averaging.
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        assert isinstance(out, th.Tensor)
        return out
    elif split:
        # The recursive calls process one extract each, without splitting again.
        kwargs['split'] = False
        out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
        sum_weight = th.zeros(length, device=mix.device)
        if segment is None:
            segment = model.segment
        assert segment is not None and segment > 0.
        segment_length: int = int(model.samplerate * segment)
        stride = int((1 - overlap) * segment_length)
        offsets = range(0, length, stride)
        scale = float(format(stride / model.samplerate, ".2f"))
        # We start from a triangle shaped weight, with maximal weight in the middle
        # of the segment. Then we normalize and take to the power `transition_power`.
        # Large values of transition power will lead to sharper transitions.
        weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device),
                         th.arange(segment_length - segment_length // 2, 0, -1, device=device)])
        assert len(weight) == segment_length
        # If the overlap < 50%, this will translate to linear transition when
        # transition_power is 1.
        weight = (weight / weight.max())**transition_power
        futures = []
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment_length)
            future = pool.submit(apply_model, model, chunk, **kwargs, callback_arg=callback_arg,
                                 callback=(lambda d, i=offset:
                                           callback(_replace_dict(d, ("segment_offset", i)))
                                           if callback else None))
            futures.append((future, offset))
        if progress:
            futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
        for future, offset in futures:
            try:
                chunk_out = future.result()  # type: th.Tensor
            except Exception:
                pool.shutdown(wait=True, cancel_futures=True)
                raise
            chunk_length = chunk_out.shape[-1]
            # Weighted overlap-add; `sum_weight` records the total weight per sample.
            out[..., offset:offset + segment_length] += (
                weight[:chunk_length] * chunk_out).to(mix.device)
            sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device)
        assert sum_weight.min() > 0
        out /= sum_weight
        assert isinstance(out, th.Tensor)
        return out
    else:
        valid_length: int
        if isinstance(model, HTDemucs) and segment is not None:
            valid_length = int(segment * model.samplerate)
        elif hasattr(model, 'valid_length'):
            valid_length = model.valid_length(length)  # type: ignore
        else:
            valid_length = length
        mix = tensor_chunk(mix)
        assert isinstance(mix, TensorChunk)
        padded_mix = mix.padded(valid_length).to(device)
        with lock:
            if callback is not None:
                callback(_replace_dict(callback_arg, ("state", "start")))  # type: ignore
        with th.no_grad():
            out = model(padded_mix)
        with lock:
            if callback is not None:
                callback(_replace_dict(callback_arg, ("state", "end")))  # type: ignore
        assert isinstance(out, th.Tensor)
        return center_trim(out, length)
def _read_info(path):
    """Run `ffprobe` on `path` and return its JSON report (format and streams) as a dict."""
    stdout_data = sp.check_output([
        'ffprobe', "-loglevel", "panic",
        str(path), '-print_format', 'json', '-show_format', '-show_streams'
    ])
    return json.loads(stdout_data.decode('utf-8'))


class AudioFile:
    """
    Allows to read audio from any format supported by ffmpeg, as well as resampling or
    converting to mono on the fly. See :method:`read` for more details.
    """
    def __init__(self, path: Path):
        self.path = Path(path)
        # `ffprobe` report, loaded lazily by the `info` property.
        self._info = None

    def __repr__(self):
        features = [("path", self.path)]
        features.append(("samplerate", self.samplerate()))
        features.append(("channels", self.channels()))
        features.append(("streams", len(self)))
        features_str = ", ".join(f"{name}={value}" for name, value in features)
        return f"AudioFile({features_str})"

    @property
    def info(self):
        """`ffprobe` report for the file, computed on first access and then cached."""
        if self._info is None:
            self._info = _read_info(self.path)
        return self._info

    @property
    def duration(self):
        """Duration reported in the `format` section of the `ffprobe` report, as a float."""
        return float(self.info['format']['duration'])

    @property
    def _audio_streams(self):
        """Indexes, within the container, of the streams whose codec type is audio."""
        return [
            index for index, stream in enumerate(self.info["streams"])
            if stream["codec_type"] == "audio"
        ]

    def __len__(self):
        """Number of audio streams."""
        return len(self._audio_streams)

    def channels(self, stream=0):
        """Number of channels of the `stream`-th audio stream."""
        return int(self.info['streams'][self._audio_streams[stream]]['channels'])

    def samplerate(self, stream=0):
        """Sample rate of the `stream`-th audio stream."""
        return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])

    def read(self,
             seek_time=None,
             duration=None,
             streams=slice(None),
             samplerate=None,
             channels=None):
        """
        Slightly more efficient implementation than stempeg,
        in particular, this will extract all stems at once
        rather than having to loop over one file multiple times
        for each stream.

        Args:
            seek_time (float):  seek time in seconds or None if no seeking is needed.
            duration (float): duration in seconds to extract or None to extract until the end.
            streams (slice, int or list): streams to extract, can be a single int, a list or
                a slice. If it is a slice or list, the output will be of size [S, C, T]
                with S the number of streams, C the number of channels and T the number of samples.
                If it is an int, the output will be [C, T].
            samplerate (int): if provided, will resample on the fly. If None, no resampling will
                be done. Original sampling rate can be obtained with :method:`samplerate`.
            channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
                as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
                See https://sound.stackexchange.com/a/42710.
                Our definition of mono is simply the average of the two channels. Any other
                value will be ignored.
        """
        streams = np.array(range(len(self)))[streams]
        # Indexing with an int yields a numpy scalar rather than an array.
        single = not isinstance(streams, np.ndarray)
        if single:
            streams = [streams]

        if duration is None:
            target_size = None
            query_duration = None
        else:
            target_size = int((samplerate or self.samplerate()) * duration)
            # Ask for one extra sample, the output is trimmed to `target_size` below.
            query_duration = float((target_size + 1) / (samplerate or self.samplerate()))

        with temp_filenames(len(streams)) as filenames:
            command = ['ffmpeg', '-y']
            command += ['-loglevel', 'panic']
            if seek_time:
                command += ['-ss', str(seek_time)]
            command += ['-i', str(self.path)]
            for stream, filename in zip(streams, filenames):
                command += ['-map', f'0:{self._audio_streams[stream]}']
                if query_duration is not None:
                    command += ['-t', str(query_duration)]
                command += ['-threads', '1']
                command += ['-f', 'f32le']
                if samplerate is not None:
                    command += ['-ar', str(samplerate)]
                command += [filename]

            sp.run(command, check=True)
            wavs = []
            for stream, filename in zip(streams, filenames):
                wav = np.fromfile(filename, dtype=np.float32)
                wav = torch.from_numpy(wav)
                # The raw file is interleaved, de-interleave it with the channel
                # count of this very stream (not the one of the first stream).
                wav = wav.view(-1, self.channels(stream)).t()
                if channels is not None:
                    wav = convert_audio_channels(wav, channels)
                if target_size is not None:
                    wav = wav[..., :target_size]
                wavs.append(wav)
        wav = torch.stack(wavs, dim=0)
        if single:
            wav = wav[0]
        return wav
def convert_audio_channels(wav, channels=2):
    """Convert audio to the given number of channels.

    The channel dimension is the second to last one of `wav`.
    Raises `ValueError` when asked to go from several channels to even more channels.
    """
    *shape, src_channels, length = wav.shape
    if src_channels == channels:
        pass
    elif channels == 1:
        # Case 1:
        # The caller asked 1-channel audio, but the stream have multiple
        # channels, downmix all channels.
        wav = wav.mean(dim=-2, keepdim=True)
    elif src_channels == 1:
        # Case 2:
        # The caller asked for multiple channels, but the input file have
        # one single channel, replicate the audio over all channels.
        wav = wav.expand(*shape, channels, length)
    elif src_channels >= channels:
        # Case 3:
        # The caller asked for multiple channels, and the input file have
        # more channels than requested. In that case return the first channels.
        wav = wav[..., :channels, :]
    else:
        # Case 4: What is a reasonable choice here?
        raise ValueError('The audio file has less channels than requested but is not mono.')
    return wav


def convert_audio(wav, from_samplerate, to_samplerate, channels) -> torch.Tensor:
    """Convert audio from a given samplerate to a target one and target number of channels."""
    wav = convert_audio_channels(wav, channels)
    return julius.resample_frac(wav, from_samplerate, to_samplerate)


def i16_pcm(wav):
    """Convert audio to 16 bits integer PCM format.

    Floating point input is clamped to [-1, 1] before scaling. The input tensor
    is left untouched. Non floating point input is returned as is.
    """
    if wav.dtype.is_floating_point:
        # Out-of-place clamp: the caller's tensor must not be modified.
        return (wav.clamp(-1, 1) * (2**15 - 1)).short()
    else:
        return wav


def f32_pcm(wav):
    """Convert audio to float 32 bits PCM format.

    Floating point input is returned as is, anything else is scaled by 1 / (2**15 - 1).
    """
    if wav.dtype.is_floating_point:
        return wav
    else:
        return wav.float() / (2**15 - 1)


def as_dtype_pcm(wav, dtype):
    """Convert audio to either f32 pcm or i16 pcm depending on the given dtype."""
    # The decision is driven by the requested `dtype`, not by the dtype of `wav`.
    if dtype.is_floating_point:
        return f32_pcm(wav)
    else:
        return i16_pcm(wav)


def encode_mp3(wav, path, samplerate=44100, bitrate=320, quality=2, verbose=False):
    """Save given audio as mp3. This should work on all OSes.

    `wav` is expected with shape `[channels, time]`.
    """
    C, T = wav.shape
    wav = i16_pcm(wav)
    encoder = lameenc.Encoder()
    encoder.set_bit_rate(bitrate)
    encoder.set_in_sample_rate(samplerate)
    encoder.set_channels(C)
    encoder.set_quality(quality)  # 2-highest, 7-fastest
    if not verbose:
        encoder.silence()
    wav = wav.data.cpu()
    # The samples of the different channels are interleaved before encoding.
    wav = wav.transpose(0, 1).numpy()
    mp3_data = encoder.encode(wav.tobytes())
    mp3_data += encoder.flush()
    with open(path, "wb") as f:
        f.write(mp3_data)


def prevent_clip(wav, mode='rescale'):
    """
    Different strategies for avoiding raw clipping.

    Args:
        wav: floating point audio (unless `mode` disables the processing).
        mode: `'rescale'` divides by `1.01 * max(|wav|)` when that exceeds 1,
            `'clamp'` clamps to [-0.99, 0.99], `'tanh'` applies a tanh,
            `None` or `'none'` returns `wav` unchanged.

    Raises:
        ValueError: if `mode` is none of the above.
    """
    if mode is None or mode == 'none':
        return wav
    assert wav.dtype.is_floating_point, "too late for clipping"
    if mode == 'rescale':
        wav = wav / max(1.01 * wav.abs().max(), 1)
    elif mode == 'clamp':
        wav = wav.clamp(-0.99, 0.99)
    elif mode == 'tanh':
        wav = torch.tanh(wav)
    else:
        raise ValueError(f"Invalid mode {mode}")
    return wav


def save_audio(wav: torch.Tensor,
               path: tp.Union[str, Path],
               samplerate: int,
               bitrate: int = 320,
               clip: tp.Literal["rescale", "clamp", "tanh", "none"] = 'rescale',
               bits_per_sample: tp.Literal[16, 24, 32] = 16,
               as_float: bool = False,
               preset: tp.Literal[2, 3, 4, 5, 6, 7] = 2):
    """Save audio file, automatically preventing clipping if necessary
    based on the given `clip` strategy. If the path ends in `.mp3`, this
    will save as mp3 with the given `bitrate`. Use `preset` to set mp3 quality:
    2 for highest quality, 7 for fastest speed.

    `.wav` files use `bits_per_sample` bits signed PCM, or 32 bits float when
    `as_float` is True. `.flac` files use `bits_per_sample`.

    Raises:
        ValueError: if the suffix of `path` is not `.mp3`, `.wav` or `.flac`.
    """
    wav = prevent_clip(wav, mode=clip)
    path = Path(path)
    suffix = path.suffix.lower()
    if suffix == ".mp3":
        encode_mp3(wav, path, samplerate, bitrate, preset, verbose=True)
    elif suffix == ".wav":
        if as_float:
            bits_per_sample = 32
            encoding = 'PCM_F'
        else:
            encoding = 'PCM_S'
        ta.save(str(path), wav, sample_rate=samplerate,
                encoding=encoding, bits_per_sample=bits_per_sample)
    elif suffix == ".flac":
        ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample)
    else:
        raise ValueError(f"Invalid suffix for path: {suffix}")


# --------------------------------------------------------------------------------
# /demucs/audio_legacy.py (separate module in the dump, shares this line range)
# --------------------------------------------------------------------------------
# This file is to extend support for torchaudio 2.1


def _is_at_least_2_1(version):
    """Return True if `version` (e.g. "2.1.0+cu118") denotes torchaudio 2.1 or newer.

    Version strings must be compared numerically: as plain strings "10.0" < "2.1".
    """
    numbers = []
    for piece in version.split(".")[:2]:
        digits = ""
        for char in piece:
            if not char.isdigit():
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    return tuple(numbers) >= (2, 1)


if "torchaudio" not in sys.modules:
    # torchaudio is not imported yet: the variable will be read at import time.
    os.environ["TORCHAUDIO_USE_BACKEND_DISPATCHER"] = "0"
elif os.getenv("TORCHAUDIO_USE_BACKEND_DISPATCHER", default="1") == "1":
    if _is_at_least_2_1(sys.modules["torchaudio"].__version__):
        # torchaudio was imported with the dispatcher enabled, reload it without.
        os.environ["TORCHAUDIO_USE_BACKEND_DISPATCHER"] = "0"
        importlib.reload(sys.modules["torchaudio"])
        warnings.warn(
            "TORCHAUDIO_USE_BACKEND_DISPATCHER is set to 0 and torchaudio is reloaded.",
            ImportWarning,
        )
class Shift(nn.Module):
    """
    Randomly shift audio in time by up to `shift` samples.

    The output is `shift` samples shorter than the input. In eval mode the first
    `time - shift` samples are kept. With `same=True` all the sources of one batch
    item share the same shift.
    """
    def __init__(self, shift=8192, same=False):
        super().__init__()
        self.shift = shift
        self.same = same

    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        length = time - self.shift
        if self.shift > 0:
            if not self.training:
                wav = wav[..., :length]
            else:
                srcs = 1 if self.same else sources
                offsets = th.randint(self.shift, [batch, srcs, 1, 1], device=wav.device)
                offsets = offsets.expand(-1, sources, channels, -1)
                indexes = th.arange(length, device=wav.device)
                # Broadcast to one index per output sample: `start offset + position`.
                wav = wav.gather(3, indexes + offsets)
        return wav


class FlipChannels(nn.Module):
    """
    Flip left-right channels.

    Only active in training mode and for inputs with exactly 2 channels. The decision
    is drawn independently for each (batch item, source) pair.
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        if self.training and wav.size(2) == 2:
            left = th.randint(2, (batch, sources, 1, 1), device=wav.device)
            left = left.expand(-1, -1, -1, time)
            right = 1 - left
            wav = th.cat([wav.gather(2, left), wav.gather(2, right)], dim=2)
        return wav


class FlipSign(nn.Module):
    """
    Random sign flip.

    Only active in training mode. The sign is drawn independently for each
    (batch item, source) pair.
    """
    def forward(self, wav):
        batch, sources, channels, time = wav.size()
        if self.training:
            signs = th.randint(2, (batch, sources, 1, 1), device=wav.device, dtype=th.float32)
            # Map {0, 1} to {-1, +1}.
            wav = wav * (2 * signs - 1)
        return wav


class Remix(nn.Module):
    """
    Shuffle sources to make new mixes.
    """
    def __init__(self, proba=1, group_size=4):
        """
        Shuffle sources within one batch.
        Each batch is divided into groups of size `group_size` and shuffling is done within
        each group separately. This allows to keep the same probability distribution no matter
        the number of GPUs. Without this grouping, using more GPUs would lead to a higher
        probability of keeping two sources from the same track together which can impact
        performance.

        Args:
            proba (float): probability of shuffling a given batch.
            group_size (int): size of the groups, a falsy value means the entire batch.
        """
        super().__init__()
        self.proba = proba
        self.group_size = group_size

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device

        if self.training and random.random() < self.proba:
            group_size = self.group_size or batch
            if batch % group_size != 0:
                raise ValueError(f"Batch size {batch} must be divisible by group size {group_size}")
            groups = batch // group_size
            wav = wav.view(groups, group_size, streams, channels, time)
            # Sorting random values yields one random permutation of the group members
            # for each (group, stream) pair.
            permutations = th.argsort(th.rand(groups, group_size, streams, 1, 1, device=device),
                                      dim=1)
            wav = wav.gather(1, permutations.expand(-1, -1, -1, channels, time))
            wav = wav.view(batch, streams, channels, time)
        return wav


class Scale(nn.Module):
    """
    Randomly rescale each source, with a gain drawn uniformly in [`min`, `max`].

    Only active in training mode, and applied to a batch with probability `proba`.
    The input tensor is never modified in place.
    """
    def __init__(self, proba=1., min=0.25, max=1.25):
        super().__init__()
        self.proba = proba
        self.min = min
        self.max = max

    def forward(self, wav):
        batch, streams, channels, time = wav.size()
        device = wav.device
        if self.training and random.random() < self.proba:
            scales = th.empty(batch, streams, 1, 1, device=device).uniform_(self.min, self.max)
            # Out of place, like the other augmentations: keep the caller's tensor intact.
            wav = wav * scales
        return wav
logger = logging.getLogger(__name__)
# Rank of this process and total number of processes, updated by `init`.
rank = 0
world_size = 1


def init():
    """Initialize distributed training through Dora and cache the rank and world size."""
    global rank, world_size
    if not torch.distributed.is_initialized():
        dora_distrib.init()
    rank = dora_distrib.rank()
    world_size = dora_distrib.world_size()


def average(metrics, count=1.):
    """
    Average `metrics` over all the workers, each one weighted by its own `count`.

    Args:
        metrics: either a sequence of numbers, or a dict mapping names to numbers.
        count (float): weight of this worker in the average.

    Returns:
        a dict with the same keys if `metrics` is a dict. Otherwise `metrics` itself
        when there is a single worker, or a list of floats.
    """
    if isinstance(metrics, dict):
        if not metrics:
            # Nothing to average, and `zip(*...)` below cannot be unpacked when empty.
            return {}
        keys, values = zip(*sorted(metrics.items()))
        values = average(values, count)
        return dict(zip(keys, values))
    if world_size == 1:
        return metrics
    # The extra trailing 1 accumulates the total weight over all the workers.
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()


def wrap(model):
    """Wrap `model` with `DistributedDataParallel`, unless there is a single worker."""
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            # find_unused_parameters=True,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())


def barrier():
    """Synchronize all the workers, no-op with a single worker."""
    if world_size > 1:
        torch.distributed.barrier()


def share(obj=None, src=0):
    """
    Broadcast the picklable `obj` from the worker of rank `src` to all the workers.
    Returns the shared object. With a single worker, `obj` is returned directly.
    """
    if world_size == 1:
        return obj
    size = torch.empty(1, device='cuda', dtype=torch.long)
    if rank == src:
        dump = pickle.dumps(obj)
        size[0] = len(dump)
    torch.distributed.broadcast(size, src=src)
    # size variable is now set to the length of pickled obj in all processes

    if rank == src:
        buffer = torch.from_numpy(np.frombuffer(dump, dtype=np.uint8).copy()).cuda()
    else:
        buffer = torch.empty(size[0].item(), device='cuda', dtype=torch.uint8)
    torch.distributed.broadcast(buffer, src=src)
    # buffer variable is now set to pickled obj in all processes

    if rank != src:
        # NOTE: the payload is unpickled, it must only ever come from trusted workers.
        obj = pickle.loads(buffer.cpu().numpy().tobytes())
    logger.debug(f"Shared object of size {len(buffer)}")
    return obj


def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """
    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed you must set `shuffle=True`.

    Args:
        dataset: dataset to load from.
        shuffle (bool): with several workers, selects between a `DistributedSampler`
            (True) and a manual shard of the dataset (False).
        klass: dataloader class to instantiate.
        *args, **kwargs: forwarded to `klass`.
    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # train means we will compute backward, we use DistributedSampler
        sampler = DistributedSampler(dataset)
        # We ignore shuffle, DistributedSampler already shuffles
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # We make a manual shard, as DistributedSampler otherwise replicate some examples
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)


# --------------------------------------------------------------------------------
# /demucs/ema.py (separate module in the dump, shares this line range)
# --------------------------------------------------------------------------------
# Inspired from https://github.com/rwightman/pytorch-image-models


class ModelEMA:
    """
    Perform EMA on a model. You can switch to the EMA weights temporarily
    with the `swap` method.

        ema = ModelEMA(model)
        with ema.swap():
            # compute valid metrics with averaged model.

    Only the float32 entries of the model `state_dict` are tracked.

    Args:
        model: model whose state is averaged.
        decay (float): decay of the moving average.
        unbias (bool): if True, the weight of a new value is `1 / count` with
            `count = count * decay + 1`, so that the first update copies the model
            state. Otherwise the weight is always `1 - decay`.
        device: where the averaged state is stored, a falsy value means on the same
            device as each tracked tensor.
    """
    def __init__(self, model, decay=0.9999, unbias=True, device='cpu'):
        self.decay = decay
        self.model = model
        self.state = {}
        self.count = 0
        self.device = device
        self.unbias = unbias

        self._init()

    def _init(self):
        """Seed the averaged state with a copy of the current float32 model state."""
        for key, val in self.model.state_dict().items():
            if val.dtype != torch.float32:
                continue
            device = self.device or val.device
            if key not in self.state:
                self.state[key] = val.detach().to(device, copy=True)

    def update(self):
        """Blend the current model state into the averaged state."""
        if self.unbias:
            self.count = self.count * self.decay + 1
            w = 1 / self.count
        else:
            w = 1 - self.decay
        for key, val in self.model.state_dict().items():
            if val.dtype != torch.float32:
                continue
            device = self.device or val.device
            # In place: state <- (1 - w) * state + w * val
            self.state[key].mul_(1 - w)
            self.state[key].add_(val.detach().to(device), alpha=w)

    @contextmanager
    def swap(self):
        """Context manager running its body with the model state swapped for the EMA one."""
        with swap_state(self.model, self.state):
            yield

    def state_dict(self):
        """Return the averaged state and the update counter."""
        return {'state': self.state, 'count': self.count}

    def load_state_dict(self, state):
        """Restore what `state_dict` returned, copying the tensors in place."""
        self.count = state['count']
        for k, v in state['state'].items():
            self.state[k].copy_(v)
def new_sdr(references, estimates):
    """
    Compute the SDR according to the MDX challenge definition.
    Adapted from AIcrowd/music-demixing-challenge-starter-kit (MIT license)

    Both inputs must be 4-dimensional tensors. Energies are summed over the
    last two dimensions, so the result keeps the first two dimensions.
    The score is `10 * log10(|ref|^2 / |ref - est|^2)`.
    """
    assert references.dim() == 4
    assert estimates.dim() == 4
    delta = 1e-7  # avoid numerical errors
    signal_energy = th.sum(th.square(references), dim=(2, 3)) + delta
    error_energy = th.sum(th.square(references - estimates), dim=(2, 3)) + delta
    return 10 * th.log10(signal_energy / error_energy)
72 | """ 73 | 74 | args = solver.args 75 | 76 | output_dir = solver.folder / "results" 77 | output_dir.mkdir(exist_ok=True, parents=True) 78 | json_folder = solver.folder / "results/test" 79 | json_folder.mkdir(exist_ok=True, parents=True) 80 | 81 | # we load tracks from the original musdb set 82 | if args.test.nonhq is None: 83 | test_set = musdb.DB(args.dset.musdb, subsets=["test"], is_wav=True) 84 | else: 85 | test_set = musdb.DB(args.test.nonhq, subsets=["test"], is_wav=False) 86 | src_rate = args.dset.musdb_samplerate 87 | 88 | eval_device = 'cpu' 89 | 90 | model = solver.model 91 | win = int(1. * model.samplerate) 92 | hop = int(1. * model.samplerate) 93 | 94 | indexes = range(distrib.rank, len(test_set), distrib.world_size) 95 | indexes = LogProgress(logger, indexes, updates=args.misc.num_prints, 96 | name='Eval') 97 | pendings = [] 98 | 99 | pool = futures.ProcessPoolExecutor if args.test.workers else DummyPoolExecutor 100 | with pool(args.test.workers) as pool: 101 | for index in indexes: 102 | track = test_set.tracks[index] 103 | 104 | mix = th.from_numpy(track.audio).t().float() 105 | if mix.dim() == 1: 106 | mix = mix[None] 107 | mix = mix.to(solver.device) 108 | ref = mix.mean(dim=0) # mono mixture 109 | mix = (mix - ref.mean()) / ref.std() 110 | mix = convert_audio(mix, src_rate, model.samplerate, model.audio_channels) 111 | estimates = apply_model(model, mix[None], 112 | shifts=args.test.shifts, split=args.test.split, 113 | overlap=args.test.overlap)[0] 114 | estimates = estimates * ref.std() + ref.mean() 115 | estimates = estimates.to(eval_device) 116 | 117 | references = th.stack( 118 | [th.from_numpy(track.targets[name].audio).t() for name in model.sources]) 119 | if references.dim() == 2: 120 | references = references[:, None] 121 | references = references.to(eval_device) 122 | references = convert_audio(references, src_rate, 123 | model.samplerate, model.audio_channels) 124 | if args.test.save: 125 | folder = solver.folder / "wav" / track.name 
126 | folder.mkdir(exist_ok=True, parents=True) 127 | for name, estimate in zip(model.sources, estimates): 128 | save_audio(estimate.cpu(), folder / (name + ".mp3"), model.samplerate) 129 | 130 | pendings.append((track.name, pool.submit( 131 | eval_track, references, estimates, win=win, hop=hop, compute_sdr=compute_sdr))) 132 | 133 | pendings = LogProgress(logger, pendings, updates=args.misc.num_prints, 134 | name='Eval (BSS)') 135 | tracks = {} 136 | for track_name, pending in pendings: 137 | pending = pending.result() 138 | scores, nsdrs = pending 139 | tracks[track_name] = {} 140 | for idx, target in enumerate(model.sources): 141 | tracks[track_name][target] = {'nsdr': [float(nsdrs[idx])]} 142 | if scores is not None: 143 | (sdr, isr, sir, sar) = scores 144 | for idx, target in enumerate(model.sources): 145 | values = { 146 | "SDR": sdr[idx].tolist(), 147 | "SIR": sir[idx].tolist(), 148 | "ISR": isr[idx].tolist(), 149 | "SAR": sar[idx].tolist() 150 | } 151 | tracks[track_name][target].update(values) 152 | 153 | all_tracks = {} 154 | for src in range(distrib.world_size): 155 | all_tracks.update(distrib.share(tracks, src)) 156 | 157 | result = {} 158 | metric_names = next(iter(all_tracks.values()))[model.sources[0]] 159 | for metric_name in metric_names: 160 | avg = 0 161 | avg_of_medians = 0 162 | for source in model.sources: 163 | medians = [ 164 | np.nanmedian(all_tracks[track][source][metric_name]) 165 | for track in all_tracks.keys()] 166 | mean = np.mean(medians) 167 | median = np.median(medians) 168 | result[metric_name.lower() + "_" + source] = mean 169 | result[metric_name.lower() + "_med" + "_" + source] = median 170 | avg += mean / len(model.sources) 171 | avg_of_medians += median / len(model.sources) 172 | result[metric_name.lower()] = avg 173 | result[metric_name.lower() + "_med"] = avg_of_medians 174 | return result 175 | -------------------------------------------------------------------------------- /demucs/grids/__init__.py: 
class MyExplorer(Explorer):
    """Explorer defining the columns shown for the Demucs training grids."""
    # Names of the test metrics displayed in the "test" column group.
    test_metrics = ['nsdr', 'sdr_med']

    def get_grid_metrics(self):
        """Return the metrics that should be displayed in the tracking table.
        """
        return [
            tt.group("train", [
                tt.leaf("epoch"),
                tt.leaf("reco", ".3f"),
            ], align=">"),
            tt.group("valid", [
                tt.leaf("penalty", ".1f"),
                tt.leaf("ms", ".1f"),
                tt.leaf("reco", ".2%"),
                tt.leaf("breco", ".2%"),
                tt.leaf("b_nsdr", ".2f"),
                # tt.leaf("b_nsdr_drums", ".2f"),
                # tt.leaf("b_nsdr_bass", ".2f"),
                # tt.leaf("b_nsdr_other", ".2f"),
                # tt.leaf("b_nsdr_vocals", ".2f"),
            ], align=">"),
            tt.group("test", [
                tt.leaf(name, ".2f")
                for name in self.test_metrics
            ], align=">")
        ]

    def process_history(self, history):
        """Collapse the per-epoch `history` into the values shown in the table.

        `history` is a sequence of per-epoch dicts with `train` and `valid`
        entries and an optional `test` entry. Later epochs overwrite earlier
        ones, except for the "best so far" values:

        - `valid['bmain']`: lowest `valid.main.loss` seen (only when present).
        - `valid['breco']`: lowest `valid.reco` seen.
        - `valid['b_*']`: `reco_*` / `nsdr*` values of the last epoch whose
          `loss` (or `nsdr`) equals its `best` entry.

        An empty history is valid and yields `epoch == 0` with empty groups.
        """
        train = {
            'epoch': len(history),
        }
        valid = {}
        test = {}
        best_v_main = float('inf')
        breco = float('inf')
        for metrics in history:
            epoch_valid = metrics['valid']
            train.update(metrics['train'])
            valid.update(epoch_valid)
            if 'main' in epoch_valid:
                best_v_main = min(best_v_main, epoch_valid['main']['loss'])
                valid['bmain'] = best_v_main
            breco = min(breco, epoch_valid['reco'])
            valid['breco'] = breco
            # This epoch is the best one so far: snapshot its detailed metrics.
            if (epoch_valid['loss'] == epoch_valid['best'] or
                    epoch_valid.get('nsdr') == epoch_valid['best']):
                for k, v in epoch_valid.items():
                    if k.startswith('reco_'):
                        valid['b_' + k[len('reco_'):]] = v
                    if k.startswith('nsdr'):
                        valid[f'b_{k}'] = v
            if 'test' in metrics:
                test.update(metrics['test'])
        return {"train": train, "valid": valid, "test": test}
@MyExplorer
def explorer(launcher):
    """Schedule the first-round trainings for the `TRACK_B` signatures.

    For every signature in `TRACK_B`, the `continue_from` chain is followed
    until an experiment without a parent is reached. That root experiment is
    then scheduled on the `extra44` and `extra_test` datasets. The
    `extra_test` variant is additionally scheduled with two `quant.diffq`
    values.
    """
    launcher.slurm_(
        gpus=8,
        time=3 * 24 * 60,  # presumably minutes, i.e. 3 days -- TODO confirm
        partition='learnlab')

    # Reproduce results from MDX competition Track B (signatures in `TRACK_B`).
    # This trains the first round of models. Once this is trained,
    # you will need to schedule `mdx_refine`.
    # NOTE(review): the `mdx_refine` grid only iterates over `TRACK_A`;
    # confirm whether a refine step applies to these Track B runs.
    for sig in TRACK_B:
        # Walk up the `continue_from` chain; `xp` ends up being the root XP.
        while sig is not None:
            xp = main.get_xp_from_sig(sig)
            sig = xp.cfg.continue_from

        for dset in ['extra44', 'extra_test']:
            sub = launcher.bind(xp.argv, dset=dset)
            sub()
            if dset == 'extra_test':
                sub({'quant.diffq': 1e-4})
                sub({'quant.diffq': 3e-4})
24 | for sig in TRACK_A: 25 | xp = main.get_xp_from_sig(sig) 26 | launcher(xp.argv) 27 | for diffq in [1e-4, 3e-4]: 28 | xp_src = main.get_xp_from_sig(xp.cfg.continue_from) 29 | q_argv = [f'quant.diffq={diffq}'] 30 | actual_src = main.get_xp(xp_src.argv + q_argv) 31 | actual_src.link.load() 32 | assert len(actual_src.link.history) == actual_src.cfg.epochs 33 | argv = xp.argv + q_argv + [f'continue_from="{actual_src.sig}"'] 34 | launcher(argv) 35 | -------------------------------------------------------------------------------- /demucs/grids/mmi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ._explorers import MyExplorer 8 | from dora import Launcher 9 | 10 | 11 | @MyExplorer 12 | def explorer(launcher: Launcher): 13 | launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="devlab,learnlab,learnfair") # 3 days 14 | 15 | sub = launcher.bind_( 16 | { 17 | "dset": "extra_mmi_goodclean", 18 | "test.shifts": 0, 19 | "model": "htdemucs", 20 | "htdemucs.dconv_mode": 3, 21 | "htdemucs.depth": 4, 22 | "htdemucs.t_dropout": 0.02, 23 | "htdemucs.t_layers": 5, 24 | "max_batches": 800, 25 | "ema.epoch": [0.9, 0.95], 26 | "ema.batch": [0.9995, 0.9999], 27 | "dset.segment": 10, 28 | "batch_size": 32, 29 | } 30 | ) 31 | sub({"model": "hdemucs"}) 32 | sub({"model": "hdemucs", "dset": "extra44"}) 33 | sub({"model": "hdemucs", "dset": "musdb44"}) 34 | 35 | sparse = { 36 | 'batch_size': 3 * 8, 37 | 'augment.remix.group_size': 3, 38 | 'htdemucs.t_auto_sparsity': True, 39 | 'htdemucs.t_sparse_self_attn': True, 40 | 'htdemucs.t_sparse_cross_attn': True, 41 | 'htdemucs.t_sparsity': 0.9, 42 | "htdemucs.t_layers": 7 43 | } 44 | 45 | with launcher.job_array(): 46 | for transf_layers in [5, 7]: 47 | for bottom_channels in [0, 512]: 
48 | sub = launcher.bind({ 49 | "htdemucs.t_layers": transf_layers, 50 | "htdemucs.bottom_channels": bottom_channels, 51 | }) 52 | if bottom_channels == 0 and transf_layers == 5: 53 | sub({"augment.remix.proba": 0.0}) 54 | sub({ 55 | "augment.repitch.proba": 0.0, 56 | # when doing repitching, we trim the outut to align on the 57 | # highest change of BPM. When removing repitching, 58 | # we simulate it here to ensure the training context is the same. 59 | # Another second is lost for all experiments due to the random 60 | # shift augmentation. 61 | "dset.segment": 10 * 0.88}) 62 | elif bottom_channels == 512 and transf_layers == 5: 63 | sub(dset="musdb44") 64 | sub(dset="extra44") 65 | # Sparse kernel XP, currently not released as kernels are still experimental. 66 | sub(sparse, {'dset.segment': 15, "htdemucs.t_layers": 7}) 67 | 68 | for duration in [5, 10, 15]: 69 | sub({"dset.segment": duration}) 70 | -------------------------------------------------------------------------------- /demucs/grids/mmi_ft.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from ._explorers import MyExplorer 8 | from dora import Launcher 9 | from demucs import train 10 | 11 | 12 | def get_sub(launcher, sig): 13 | xp = train.main.get_xp_from_sig(sig) 14 | sub = launcher.bind(xp.argv) 15 | sub() 16 | sub.bind_({ 17 | 'continue_from': sig, 18 | 'continue_best': True}) 19 | return sub 20 | 21 | 22 | @MyExplorer 23 | def explorer(launcher: Launcher): 24 | launcher.slurm_(gpus=4, time=3 * 24 * 60, partition="devlab,learnlab,learnfair") # 3 days 25 | ft = { 26 | 'optim.lr': 1e-4, 27 | 'augment.remix.proba': 0, 28 | 'augment.scale.proba': 0, 29 | 'augment.shift_same': True, 30 | 'htdemucs.t_weight_decay': 0.05, 31 | 'batch_size': 8, 32 | 'optim.clip_grad': 5, 33 | 'optim.optim': 'adamw', 34 | 'epochs': 50, 35 | 'dset.wav2_valid': True, 36 | 'ema.epoch': [], # let's make valid a bit faster 37 | } 38 | with launcher.job_array(): 39 | for sig in ['2899e11a']: 40 | sub = get_sub(launcher, sig) 41 | sub.bind_(ft) 42 | for segment in [15, 18]: 43 | for source in range(4): 44 | w = [0] * 4 45 | w[source] = 1 46 | sub({'weights': w, 'dset.segment': segment}) 47 | 48 | for sig in ['955717e8']: 49 | sub = get_sub(launcher, sig) 50 | sub.bind_(ft) 51 | for segment in [10, 15]: 52 | for source in range(4): 53 | w = [0] * 4 54 | w[source] = 1 55 | sub({'weights': w, 'dset.segment': segment}) 56 | -------------------------------------------------------------------------------- /demucs/grids/repro.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
@MyExplorer
def explorer(launcher):
    """Schedule the reproducibility grid.

    Three model configurations are scheduled: `base`, `newt` and `hdem`.
    `base` is run once with 360 epochs. The other two are run with the `svd`
    options for two seeds, and `newt` additionally gets an ablation study.
    """
    launcher.slurm_(
        gpus=8,
        time=3 * 24 * 60,  # presumably minutes, i.e. 3 days -- TODO confirm
        partition='devlab,learnlab')

    # Defaults shared by every experiment of this grid.
    launcher.bind_({'ema.epoch': [0.9, 0.95]})
    launcher.bind_({'ema.batch': [0.9995, 0.9999]})
    launcher.bind_({'epochs': 600})

    base = {'model': 'demucs', 'demucs.dconv_mode': 0, 'demucs.gelu': False,
            'demucs.lstm_layers': 2}
    newt = {'model': 'demucs', 'demucs.normalize': True}
    hdem = {'model': 'hdemucs'}
    svd = {'svd.penalty': 1e-5, 'svd': 'base2'}

    with launcher.job_array():
        for model in [base, newt, hdem]:
            sub = launcher.bind(model)
            if model is base:
                # Training the v2 Demucs on MusDB HQ
                sub(epochs=360)
                continue

            # those two will be used in the repro_mdx_a bag of models.
            sub(svd)
            sub(svd, seed=43)
            # NOTE(review): equality is used here while `base` is matched by
            # identity above; both work since the three dicts are distinct.
            if model == newt:
                # Ablation study
                sub()
                abl = sub.bind(svd)
                abl({'ema.epoch': [], 'ema.batch': []})
                abl({'demucs.dconv_lstm': 10})
                abl({'demucs.dconv_attn': 10})
                abl({'demucs.dconv_attn': 10, 'demucs.dconv_lstm': 10, 'demucs.lstm_layers': 2})
                abl({'demucs.dconv_mode': 0})
                abl({'demucs.gelu': False})
@MyExplorer
def explorer(launcher):
    """Schedule a short fine-tuning run for each finished XP of the `repro` grid.

    Every symlink found in the `repro` grid folder is loaded as an experiment.
    Experiments whose history length differs from their configured number of
    epochs are skipped. The others are continued with the fine-tuning options
    bound below.
    """
    launcher.slurm_(
        gpus=8,
        time=300,
        partition='devlab,learnlab')

    # Mus
    launcher.slurm_(constraint='volta32gb')

    grid = "repro"
    folder = main.dora.dir / "grids" / grid

    for sig in folder.iterdir():
        # Only symlinks are considered; anything else in the folder is ignored.
        if not sig.is_symlink():
            continue
        # NOTE(review): `sig` is a `Path` here, not a signature string --
        # confirm `get_xp_from_sig` accepts it (`sig.name` may be intended).
        xp = main.get_xp_from_sig(sig)
        xp.link.load()
        # Skip experiments that did not run for all their configured epochs.
        if len(xp.link.history) != xp.cfg.epochs:
            continue
        sub = launcher.bind(xp.argv, [f'continue_from="{xp.sig}"'])
        sub.bind_({'ema.epoch': [0.9, 0.95], 'ema.batch': [0.9995, 0.9999]})
        sub.bind_({'test.every': 1, 'test.sdr': True, 'epochs': 4})
        sub.bind_({'dset.segment': 28, 'dset.shift': 2})
        sub.bind_({'batch_size': 32})
        auto = {'dset': 'auto_mus'}
        auto.update({'augment.remix.proba': 0, 'augment.scale.proba': 0,
                     'augment.shift_same': True})
        sub.bind_(auto)
        # NOTE(review): `batch_size` was bound to 32 above; presumably this
        # later binding is the one that takes effect -- confirm.
        sub.bind_({'batch_size': 16})
        sub.bind_({'optim.lr': 1e-4})
        sub.bind_({'model_segment': 44})
        sub()
@MyExplorer
def explorer(launcher: Launcher):
    """Schedule one training per SDX23 dataset config, with `dset.use_musdb` off."""
    launcher.slurm_(gpus=8, time=3 * 24 * 60, partition="speechgpt,learnfair",
                    mem_per_gpu=None, constraint='')
    launcher.bind_({"dset.use_musdb": False})

    # One job per dataset configuration: `sdx23_bleeding` and `sdx23_labelnoise`.
    with launcher.job_array():
        launcher(dset='sdx23_bleeding')
        launcher(dset='sdx23_labelnoise')
def get_model(name: str,
              repo: tp.Optional[Path] = None):
    """`name` must be a bag of models name or a pretrained signature
    from the remote AWS model repo or the specified local repo if `repo` is not None.

    Args:
        name: bag name or model signature. The special value
            `'demucs_unittest'` returns the model built by `demucs_unittest()`.
        repo: optional folder holding local models and bag YAML files. When
            `None`, the remote listing in `REMOTE_ROOT / 'files.txt'` is used.

    Returns:
        The loaded model, switched to eval mode. The `'demucs_unittest'`
        model is returned as built, without calling `eval()`.

    Raises:
        ImportError: re-raised when loading the model fails on an import.
    """
    if name == 'demucs_unittest':
        return demucs_unittest()
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            fatal(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    any_repo = AnyModelRepo(model_repo, bag_repo)
    try:
        model = any_repo.get_model(name)
    except ImportError as exc:
        # `str(exc)` instead of `exc.args[0]`: an ImportError may carry no
        # arguments (or a non-string first one), and indexing it would replace
        # the real import failure with an IndexError/TypeError.
        if 'diffq' in str(exc):
            # presumably reports how to install diffq -- defined in `.states`.
            _check_diffq()
        raise

    model.eval()
    return model
To get back the old default model " 97 | "use `-n mdx_extra_q`.") 98 | return get_model(name=args.name, repo=args.repo) 99 | -------------------------------------------------------------------------------- /demucs/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adefossez/demucs/b9ab48cad45976ba42b2ff17b229c071f0df9390/demucs/py.typed -------------------------------------------------------------------------------- /demucs/remote/files.txt: -------------------------------------------------------------------------------- 1 | # MDX Models 2 | root: mdx_final/ 3 | 0d19c1c6-0f06f20e.th 4 | 5d2d6c55-db83574e.th 5 | 7d865c68-3d5dd56b.th 6 | 7ecf8ec1-70f50cc9.th 7 | a1d90b5c-ae9d2452.th 8 | c511e2ab-fe698775.th 9 | cfa93e08-61801ae1.th 10 | e51eebcc-c1b80bdd.th 11 | 6b9c2ca1-3fd82607.th 12 | b72baf4e-8778635e.th 13 | 42e558d4-196e0e1b.th 14 | 305bc58f-18378783.th 15 | 14fc6a69-a89dd0ee.th 16 | 464b36d7-e5a9386e.th 17 | 7fd6ef75-a905dd85.th 18 | 83fc094f-4a16d450.th 19 | 1ef250f1-592467ce.th 20 | 902315c2-b39ce9c9.th 21 | 9a6b4851-03af0aa6.th 22 | fa0cb7f9-100d8bf4.th 23 | # Hybrid Transformer models 24 | root: hybrid_transformer/ 25 | 955717e8-8726e21a.th 26 | f7e0c4bc-ba3fe64a.th 27 | d12395a8-e57c48e6.th 28 | 92cfc3b6-ef3bcb9c.th 29 | 04573f0d-f3cf25b2.th 30 | 75fc33f5-1941ce65.th 31 | # Experimental 6 sources model 32 | 5c90dfd2-34c22ccb.th 33 | -------------------------------------------------------------------------------- /demucs/remote/hdemucs_mmi.yaml: -------------------------------------------------------------------------------- 1 | models: ['75fc33f5'] 2 | segment: 44 3 | -------------------------------------------------------------------------------- /demucs/remote/htdemucs.yaml: -------------------------------------------------------------------------------- 1 | models: ['955717e8'] 2 | -------------------------------------------------------------------------------- 
/demucs/remote/htdemucs_6s.yaml: -------------------------------------------------------------------------------- 1 | models: ['5c90dfd2'] 2 | -------------------------------------------------------------------------------- /demucs/remote/htdemucs_ft.yaml: -------------------------------------------------------------------------------- 1 | models: ['f7e0c4bc', 'd12395a8', '92cfc3b6', '04573f0d'] 2 | weights: [ 3 | [1., 0., 0., 0.], 4 | [0., 1., 0., 0.], 5 | [0., 0., 1., 0.], 6 | [0., 0., 0., 1.], 7 | ] -------------------------------------------------------------------------------- /demucs/remote/mdx.yaml: -------------------------------------------------------------------------------- 1 | models: ['0d19c1c6', '7ecf8ec1', 'c511e2ab', '7d865c68'] 2 | weights: [ 3 | [1., 1., 0., 0.], 4 | [0., 1., 0., 0.], 5 | [1., 0., 1., 1.], 6 | [1., 0., 1., 1.], 7 | ] 8 | segment: 44 9 | -------------------------------------------------------------------------------- /demucs/remote/mdx_extra.yaml: -------------------------------------------------------------------------------- 1 | models: ['e51eebcc', 'a1d90b5c', '5d2d6c55', 'cfa93e08'] 2 | segment: 44 -------------------------------------------------------------------------------- /demucs/remote/mdx_extra_q.yaml: -------------------------------------------------------------------------------- 1 | models: ['83fc094f', '464b36d7', '14fc6a69', '7fd6ef75'] 2 | segment: 44 3 | -------------------------------------------------------------------------------- /demucs/remote/mdx_q.yaml: -------------------------------------------------------------------------------- 1 | models: ['6b9c2ca1', 'b72baf4e', '42e558d4', '305bc58f'] 2 | weights: [ 3 | [1., 1., 0., 0.], 4 | [0., 1., 0., 0.], 5 | [1., 0., 1., 1.], 6 | [1., 0., 1., 1.], 7 | ] 8 | segment: 44 9 | -------------------------------------------------------------------------------- /demucs/remote/repro_mdx_a.yaml: 
class RepitchedWrapper:
    """
    Wrap a dataset to apply online change of pitch / tempo.

    Each item of `dataset` is iterated stream by stream. With probability
    `proba`, every stream goes through `repitch`. Otherwise the item is only
    trimmed. In both cases the last dimension is cut to
    `int((1 - 0.01 * max_tempo) * length)`, so all items share the same
    output length.

    Args:
        dataset: indexable dataset returning a tensor of streams.
        proba: probability of repitching an item.
        max_pitch: pitch delta is drawn uniformly in `[-max_pitch, max_pitch]`.
        max_tempo: tempo delta is clamped to `[-max_tempo, max_tempo]`.
        tempo_std: standard deviation of the Gaussian tempo delta.
        vocals: indexes of the streams processed with `voice=True`.
            Defaults to `[3]`.
        same: when true, the deltas drawn for the first stream are reused for
            all the streams of the item.
    """
    def __init__(self, dataset, proba=0.2, max_pitch=2, max_tempo=12,
                 tempo_std=5, vocals=None, same=True):
        self.dataset = dataset
        self.proba = proba
        self.max_pitch = max_pitch
        self.max_tempo = max_tempo
        self.tempo_std = tempo_std
        self.same = same
        # A fresh list per instance: a literal `[3]` default would be one
        # object shared (and mutable) across every wrapper.
        self.vocals = [3] if vocals is None else vocals

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        streams = self.dataset[index]
        in_length = streams.shape[-1]
        # Shortest length a stream can have once sped up by `max_tempo` percent.
        out_length = int((1 - 0.01 * self.max_tempo) * in_length)

        if random.random() < self.proba:
            outs = []
            for idx, stream in enumerate(streams):
                # Draw new deltas for the first stream, and for every stream
                # when `same` is disabled.
                if idx == 0 or not self.same:
                    delta_pitch = random.randint(-self.max_pitch, self.max_pitch)
                    delta_tempo = random.gauss(0, self.tempo_std)
                    delta_tempo = min(max(-self.max_tempo, delta_tempo), self.max_tempo)
                stream = repitch(
                    stream,
                    delta_pitch,
                    delta_tempo,
                    voice=idx in self.vocals)
                outs.append(stream[:, :out_length])
            streams = torch.stack(outs)
        else:
            streams = streams[..., :out_length]
        return streams
64 | Requires `soundstretch` to be installed, see 65 | https://www.surina.net/soundtouch/soundstretch.html 66 | """ 67 | infile = tempfile.NamedTemporaryFile(suffix=".wav") 68 | outfile = tempfile.NamedTemporaryFile(suffix=".wav") 69 | save_audio(wav, infile.name, samplerate, clip='clamp') 70 | command = [ 71 | "soundstretch", 72 | infile.name, 73 | outfile.name, 74 | f"-pitch={pitch}", 75 | f"-tempo={tempo:.6f}", 76 | ] 77 | if quick: 78 | command += ["-quick"] 79 | if voice: 80 | command += ["-speech"] 81 | try: 82 | sp.run(command, capture_output=True, check=True) 83 | except sp.CalledProcessError as error: 84 | raise RuntimeError(f"Could not change bpm because {error.stderr.decode('utf-8')}") 85 | wav, sr = ta.load(outfile.name) 86 | assert sr == samplerate 87 | return wav 88 | -------------------------------------------------------------------------------- /demucs/repo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Represents a model repository, including pre-trained models and bags of models. 7 | A repo can either be the main remote repository stored in AWS, or a local repository 8 | with your own models. 
class LocalRepo(ModelOnlyRepo):
    """Model repository backed by the `.th` files of a local folder.

    Files are named either `<sig>.th` or `<sig>-<checksum>.th`. When a
    checksum is present, the file content is verified against it before the
    model is loaded.
    """
    def __init__(self, root: Path):
        self.root = root
        self.scan()

    def scan(self):
        """(Re)build the signature -> file and signature -> checksum tables.

        Raises:
            ModelLoadingError: if two files map to the same signature.
        """
        self._models = {}
        self._checksums = {}
        for file in self.root.iterdir():
            if file.suffix == '.th':
                if '-' in file.stem:
                    # The checksum is the part after the LAST dash. A plain
                    # `split('-')` raised ValueError on any stem with more
                    # than one dash.
                    xp_sig, checksum = file.stem.rsplit('-', 1)
                    self._checksums[xp_sig] = checksum
                else:
                    xp_sig = file.stem
                if xp_sig in self._models:
                    raise ModelLoadingError(
                        f'Duplicate pre-trained model exist for signature {xp_sig}. '
                        'Please delete all but one.')
                self._models[xp_sig] = file

    def has_model(self, sig: str) -> bool:
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        """Load the model with signature `sig`, verifying its checksum if any.

        Raises:
            ModelLoadingError: if no file matches `sig`, or (through
                `check_checksum`) if the checksum does not match.
        """
        try:
            file = self._models[sig]
        except KeyError:
            raise ModelLoadingError(f'Could not find pre-trained model with signature {sig}.')
        if sig in self._checksums:
            check_checksum(file, self._checksums[sig])
        return load_model(file)

    def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]:
        return self._models
class AnyModelRepo:
    """Repository resolving a name against single models first, then bags of models."""
    def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
        self.model_repo = model_repo
        self.bag_repo = bag_repo

    def has_model(self, name_or_sig: str) -> bool:
        """Return True if either underlying repo knows `name_or_sig`."""
        return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)

    def get_model(self, name_or_sig: str) -> AnyModel:
        """Return the single model if one matches, otherwise defer to the bag repo."""
        if self.model_repo.has_model(name_or_sig):
            return self.model_repo.get_model(name_or_sig)
        else:
            return self.bag_repo.get_model(name_or_sig)

    def list_model(self) -> tp.Dict[str, tp.Union[str, Path]]:
        """Return a new dict merging single models and bags (bags win on name clash)."""
        # Copy before merging: the model repo may hand back its own internal
        # index, and writing bag entries into it would make the model repo
        # claim to own bag names on later `has_model` / `get_model` calls.
        models = dict(self.model_repo.list_model())
        models.update(self.bag_repo.list_model())
        return models
def get_parser():
    """Build the command line parser for `demucs.separate`.

    Model selection flags are added by `add_model_flags`; everything else
    (output location and naming, device, splitting, stem selection, output
    encoding, parallelism) is declared here.
    """
    parser = argparse.ArgumentParser("demucs.separate",
                                     description="Separate the sources for the given tracks")
    parser.add_argument("tracks", nargs='*', type=Path, default=[], help='Path to tracks')
    add_model_flags(parser)
    parser.add_argument("--list-models", action="store_true", help="List available models "
                        "from current repo and exit")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("-o",
                        "--out",
                        type=Path,
                        default=Path("separated"),
                        help="Folder where to put extracted tracks. A subfolder "
                        "with the model name will be created.")
    parser.add_argument("--filename",
                        default="{track}/{stem}.{ext}",
                        help="Set the name of output file. \n"
                        'Use "{track}", "{trackext}", "{stem}", "{ext}" to use '
                        "variables of track name without extension, track extension, "
                        "stem name and default output file extension. \n"
                        'Default is "{track}/{stem}.{ext}".')
    parser.add_argument("-d",
                        "--device",
                        default=(
                            "cuda"
                            if th.cuda.is_available()
                            else "mps"
                            if th.backends.mps.is_available()
                            else "cpu"
                        ),
                        # Help text kept in sync with the default above, which
                        # also falls back to mps before cpu.
                        help="Device to use, default is cuda if available, "
                        "else mps if available, else cpu")
    parser.add_argument("--shifts",
                        default=1,
                        type=int,
                        # Trailing space added: the adjacent literals used to
                        # render as "stabilization.Increase".
                        help="Number of random shifts for equivariant stabilization. "
                        "Increase separation time but improves quality for Demucs. 10 was used "
                        "in the original paper.")
    parser.add_argument("--overlap",
                        default=0.25,
                        type=float,
                        help="Overlap between the splits.")
    # --no-split and --segment contradict each other, hence the exclusive group.
    split_group = parser.add_mutually_exclusive_group()
    split_group.add_argument("--no-split",
                             action="store_false",
                             dest="split",
                             default=True,
                             help="Doesn't split audio in chunks. "
                             "This can use large amounts of memory.")
    split_group.add_argument("--segment", type=int,
                             help="Set split size of each chunk. "
                             "This can help save memory of graphic card. ")
    parser.add_argument("--two-stems",
                        dest="stem", metavar="STEM",
                        help="Only separate audio into {STEM} and no_{STEM}. ")
    parser.add_argument("--other-method", dest="other_method", choices=["none", "add", "minus"],
                        default="add", help='Decide how to get "no_{STEM}". "none" will not save '
                        '"no_{STEM}". "add" will add all the other stems. "minus" will use the '
                        "original track minus the selected stem.")
    depth_group = parser.add_mutually_exclusive_group()
    depth_group.add_argument("--int24", action="store_true",
                             help="Save wav output as 24 bits wav.")
    depth_group.add_argument("--float32", action="store_true",
                             help="Save wav output as float32 (2x bigger).")
    parser.add_argument("--clip-mode", default="rescale", choices=["rescale", "clamp", "none"],
                        help="Strategy for avoiding clipping: rescaling entire signal "
                             "if necessary  (rescale) or hard clipping (clamp).")
    format_group = parser.add_mutually_exclusive_group()
    format_group.add_argument("--flac", action="store_true",
                              help="Convert the output wavs to flac.")
    format_group.add_argument("--mp3", action="store_true",
                              help="Convert the output wavs to mp3.")
    parser.add_argument("--mp3-bitrate",
                        default=320,
                        type=int,
                        help="Bitrate of converted mp3.")
    parser.add_argument("--mp3-preset", choices=range(2, 8), type=int, default=2,
                        help="Encoder preset of MP3, 2 for highest quality, 7 for "
                        "fastest speed. Default is 2")
    parser.add_argument("-j", "--jobs",
                        default=0,
                        type=int,
                        help="Number of jobs. This can increase memory usage but will "
                             "be much faster when multiple cores are available.")

    return parser
' 153 | "STEM must be one of {sources}.".format( 154 | stem=args.stem, sources=", ".join(separator.model.sources) 155 | ) 156 | ) 157 | out = args.out / args.name 158 | out.mkdir(parents=True, exist_ok=True) 159 | print(f"Separated tracks will be stored in {out.resolve()}") 160 | for track in args.tracks: 161 | if not track.exists(): 162 | print(f"File {track} does not exist. If the path contains spaces, " 163 | 'please try again after surrounding the entire path with quotes "".', 164 | file=sys.stderr) 165 | continue 166 | print(f"Separating track {track}") 167 | 168 | origin, res = separator.separate_audio_file(track) 169 | 170 | if args.mp3: 171 | ext = "mp3" 172 | elif args.flac: 173 | ext = "flac" 174 | else: 175 | ext = "wav" 176 | kwargs = { 177 | "samplerate": separator.samplerate, 178 | "bitrate": args.mp3_bitrate, 179 | "preset": args.mp3_preset, 180 | "clip": args.clip_mode, 181 | "as_float": args.float32, 182 | "bits_per_sample": 24 if args.int24 else 16, 183 | } 184 | if args.stem is None: 185 | for name, source in res.items(): 186 | stem = out / args.filename.format( 187 | track=track.name.rsplit(".", 1)[0], 188 | trackext=track.name.rsplit(".", 1)[-1], 189 | stem=name, 190 | ext=ext, 191 | ) 192 | stem.parent.mkdir(parents=True, exist_ok=True) 193 | save_audio(source, str(stem), **kwargs) 194 | else: 195 | stem = out / args.filename.format( 196 | track=track.name.rsplit(".", 1)[0], 197 | trackext=track.name.rsplit(".", 1)[-1], 198 | stem="minus_" + args.stem, 199 | ext=ext, 200 | ) 201 | if args.other_method == "minus": 202 | stem.parent.mkdir(parents=True, exist_ok=True) 203 | save_audio(origin - res[args.stem], str(stem), **kwargs) 204 | stem = out / args.filename.format( 205 | track=track.name.rsplit(".", 1)[0], 206 | trackext=track.name.rsplit(".", 1)[-1], 207 | stem=args.stem, 208 | ext=ext, 209 | ) 210 | stem.parent.mkdir(parents=True, exist_ok=True) 211 | save_audio(res.pop(args.stem), str(stem), **kwargs) 212 | # Warning : after poping the 
def spectro(x, n_fft=512, hop_length=None, pad=0):
    """Compute the complex STFT of `x` along its last dimension.

    Any number of leading dimensions is accepted; they are flattened for the
    transform and restored on the output, whose shape is `[*leading, freqs, frames]`.
    The FFT size is `n_fft * (1 + pad)` while the Hann window keeps length `n_fft`.
    `hop_length` defaults to `n_fft // 4`.

    Inputs living on an `mps` or `xpu` device are moved to the CPU before the
    transform; the result is returned from wherever the transform ran.
    """
    leading = x.shape[:-1]
    flat = x.reshape(-1, x.shape[-1])
    if flat.device.type in ['mps', 'xpu']:
        flat = flat.cpu()
    hop = hop_length or n_fft // 4
    window = th.hann_window(n_fft).to(flat)
    spec = th.stft(flat,
                   n_fft * (1 + pad),
                   hop,
                   window=window,
                   win_length=n_fft,
                   normalized=True,
                   center=True,
                   return_complex=True,
                   pad_mode='reflect')
    freqs, frames = spec.shape[1:]
    return spec.view(*leading, freqs, frames)
def get_quantizer(model, args, optimizer=None):
    """Return the quantizer selected by the XP quantization args, or None.

    `args.diffq` takes precedence and yields a `DiffQuantizer`, which is also
    registered on `optimizer` when one is given. Otherwise a truthy `args.qat`
    yields a `UniformQuantizer` using `args.qat` as its `bits` argument.
    """
    if args.diffq:
        _check_diffq()
        from diffq import DiffQuantizer
        diff_quantizer = DiffQuantizer(
            model, min_size=args.min_size, group_size=args.group_size)
        if optimizer is not None:
            diff_quantizer.setup_optimizer(optimizer)
        return diff_quantizer
    if args.qat:
        _check_diffq()
        from diffq import UniformQuantizer
        return UniformQuantizer(
            model, bits=args.qat, min_size=args.min_size)
    return None
def get_state(model, quantizer, half=False):
    """Return the state of `model`, quantized when a quantizer is provided.

    Without a quantizer, every entry of `model.state_dict()` is copied to the
    CPU, and cast to half precision when `half` is True. With a quantizer, its
    quantized state is returned with the `'__quantized'` marker set, and `half`
    is ignored.
    """
    if quantizer is not None:
        quantized = quantizer.get_quantized_state()
        quantized['__quantized'] = True
        return quantized

    target_dtype = torch.half if half else None
    state = {}
    for key, tensor in model.state_dict().items():
        state[key] = tensor.data.to(device='cpu', dtype=target_dtype)
    return state
def capture_init(init):
    """Decorate an `__init__` so the instance remembers its constructor arguments.

    The positional and keyword arguments are stored on the instance as
    `_init_args_kwargs = (args, kwargs)` before the wrapped initializer runs.
    """
    def _recording_init(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return functools.wraps(init)(_recording_init)
def power_iteration(m, niters=1, bs=1):
    """Estimate the largest eigenvalue magnitude of a square matrix with the power method.

    `bs` random starting vectors are iterated in parallel and the mean of their
    estimates is returned as a 0-dim tensor.

    Args:
        m: square 2D tensor.
        niters: number of power iterations, must be at least 1.
        bs: number of starting vectors tried in parallel.

    Raises:
        ValueError: if `niters` is smaller than 1.
    """
    assert m.dim() == 2
    assert m.shape[0] == m.shape[1]
    if niters < 1:
        # Previously surfaced as an UnboundLocalError on `norm` below.
        raise ValueError(f"niters must be at least 1, got {niters}.")
    dim = m.shape[0]
    b = torch.randn(dim, bs, device=m.device, dtype=m.dtype)
    # Start from unit-norm vectors, otherwise the estimate obtained with a
    # single iteration is scaled by the norm of the random vector. Later
    # iterations re-normalize anyway, so results for niters >= 2 are unchanged
    # up to the 1e-10 stabilizer.
    b = b / (1e-10 + b.norm(dim=0, keepdim=True))

    for _ in range(niters):
        n = m.mm(b)
        norm = n.norm(dim=0, keepdim=True)
        b = n / (1e-10 + norm)

    return norm.mean()
53 | 54 | for m in model.modules(): 55 | for name, p in m.named_parameters(recurse=False): 56 | if p.numel() / 2**18 < min_size: 57 | continue 58 | if convtr: 59 | if isinstance(m, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d)): 60 | if p.dim() in [3, 4]: 61 | p = p.transpose(0, 1).contiguous() 62 | if p.dim() == 3: 63 | p = p.view(len(p), -1) 64 | elif p.dim() == 4: 65 | p = p.view(len(p), -1) 66 | elif p.dim() == 1: 67 | continue 68 | elif conv_only: 69 | continue 70 | assert p.dim() == 2, (name, p.shape) 71 | if exact: 72 | estimate = torch.svd(p, compute_uv=False)[1].pow(2).max() 73 | elif powm: 74 | a, b = p.shape 75 | if a < b: 76 | n = p.mm(p.t()) 77 | else: 78 | n = p.t().mm(p) 79 | estimate = power_iteration(n, niters, bs) 80 | else: 81 | estimate = torch.svd_lowrank(p, dim, niters)[1][0].pow(2) 82 | total += estimate 83 | return total / proba 84 | -------------------------------------------------------------------------------- /demucs/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Main training script entry point""" 8 | 9 | import logging 10 | import os 11 | from pathlib import Path 12 | import sys 13 | 14 | from dora import hydra_main 15 | import hydra 16 | from hydra.core.global_hydra import GlobalHydra 17 | from omegaconf import OmegaConf 18 | from . import audio_legacy 19 | import torch 20 | from torch import nn 21 | import torchaudio 22 | from torch.utils.data import ConcatDataset 23 | 24 | from . 
def get_optimizer(model, args):
    """Build the optimizer for `model` from `args.optim`.

    Sub-modules exposing `make_optim_group()` contribute their own parameter
    group (these groups must not share parameters). All remaining parameters
    go into a default group placed first.

    Raises:
        ValueError: if `args.optim.optim` is neither "adam" nor "adamw".
    """
    seen_params = set()
    groups = []
    for _, module in model.named_modules():
        if hasattr(module, "make_optim_group"):
            group = module.make_optim_group()
            params = set(group["params"])
            assert params.isdisjoint(seen_params)
            seen_params |= params
            groups.append(group)
    other_params = [param for param in model.parameters() if param not in seen_params]
    groups.insert(0, {"params": other_params})
    parameters = groups
    if args.optim.optim == "adam":
        return torch.optim.Adam(
            parameters,
            lr=args.optim.lr,
            betas=(args.optim.momentum, args.optim.beta2),
            weight_decay=args.optim.weight_decay,
        )
    elif args.optim.optim == "adamw":
        return torch.optim.AdamW(
            parameters,
            lr=args.optim.lr,
            betas=(args.optim.momentum, args.optim.beta2),
            weight_decay=args.optim.weight_decay,
        )
    else:
        # The config key is `optim.optim` (see the tests above); the previous
        # code read `optim.optimizer` and never formatted the message.
        raise ValueError(f"Invalid optimizer {args.optim.optim}")
def get_solver(args, model_only=False):
    """Build a `Solver` from the experiment config `args`.

    Initializes `distrib`, seeds torch, builds the model and its optimizer,
    and, unless `model_only` is True, the train/valid data loaders.

    Note that `args.batch_size` is divided in place by `distrib.world_size`.
    When `args.misc.show` is set, model information is logged and the process
    exits with status 0 without building anything else.
    """
    distrib.init()

    torch.manual_seed(args.seed)
    model = get_model(args)
    if args.misc.show:
        logger.info(model)
        # Size estimate counts 4 bytes per parameter.
        mb = sum(p.numel() for p in model.parameters()) * 4 / 2**20
        logger.info('Size: %.1f MB', mb)
        if hasattr(model, 'valid_length'):
            field = model.valid_length(1)
            logger.info('Field: %.1f ms', field / args.dset.samplerate * 1000)
        sys.exit(0)

    # torch also initialize cuda seed if available
    if torch.cuda.is_available():
        model.cuda()

    # optimizer
    optimizer = get_optimizer(model, args)

    # The configured batch size is global; each worker gets an equal share.
    assert args.batch_size % distrib.world_size == 0
    args.batch_size //= distrib.world_size

    if model_only:
        return Solver(None, model, optimizer, args)

    train_set, valid_set = get_datasets(args)

    if args.augment.repitch.proba:
        vocals = []
        if 'vocals' in args.dset.sources:
            vocals.append(args.dset.sources.index('vocals'))
        else:
            logger.warning('No vocal source found')
        # NOTE(review): this inner test repeats the enclosing condition and is
        # always true here.
        if args.augment.repitch.proba:
            train_set = RepitchedWrapper(train_set, vocals=vocals, **args.augment.repitch)

    logger.info("train/valid set size: %d %d", len(train_set), len(valid_set))
    train_loader = distrib.loader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.misc.num_workers, drop_last=True)
    if args.dset.full_cv:
        # Full validation: one item per batch, nothing dropped.
        valid_loader = distrib.loader(
            valid_set, batch_size=1, shuffle=False,
            num_workers=args.misc.num_workers)
    else:
        valid_loader = distrib.loader(
            valid_set, batch_size=args.batch_size, shuffle=False,
            num_workers=args.misc.num_workers, drop_last=True)
    loaders = {"train": train_loader, "valid": valid_loader}

    # Construct Solver
    return Solver(loaders, model, optimizer, args)
def unfold(a, kernel_size, stride):
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    This will pad the input on the right with zeros so that
    `F = ceil(T / stride)`, i.e. every input sample starts or falls in a frame.
    The result is a strided view over the padded tensor, so overlapping frames
    share memory.

    see https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    # Length needed so that the last frame is complete.
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, 'data should be contiguous'
    # Moving one frame forward advances by `stride` samples, moving inside a
    # frame advances by one sample.
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)
def EMA(beta: float = 1):
    """
    Exponential Moving Average callback.
    Returns a single function that can be called repeatedly to update the EMA
    with a dict of metrics. Each call returns the averaged value of every
    metric seen so far, not only those passed in that call.

    Note that for `beta=1`, this is just plain (weighted) averaging.
    """
    weight_sums: tp.Dict[str, float] = defaultdict(float)
    value_sums: tp.Dict[str, float] = defaultdict(float)

    def _update(metrics: dict, weight: float = 1) -> dict:
        for name, value in metrics.items():
            # Decay what was accumulated so far, then add the new observation.
            value_sums[name] = value_sums[name] * beta + weight * float(value)
            weight_sums[name] = weight_sums[name] * beta + weight
        averaged = {}
        for name, accumulated in value_sums.items():
            averaged[name] = accumulated / weight_sums[name]
        return averaged
    return _update
91 | Taken from https://stackoverflow.com/a/1094933 92 | """ 93 | for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: 94 | if abs(num) < 1024.0: 95 | return "%3.1f%s%s" % (num, unit, suffix) 96 | num /= 1024.0 97 | return "%.1f%s%s" % (num, 'Yi', suffix) 98 | 99 | 100 | @contextmanager 101 | def temp_filenames(count: int, delete=True): 102 | names = [] 103 | try: 104 | for _ in range(count): 105 | names.append(tempfile.NamedTemporaryFile(delete=False).name) 106 | yield names 107 | finally: 108 | if delete: 109 | for name in names: 110 | os.unlink(name) 111 | 112 | 113 | def random_subset(dataset, max_samples: int, seed: int = 42): 114 | if max_samples >= len(dataset): 115 | return dataset 116 | 117 | generator = torch.Generator().manual_seed(seed) 118 | perm = torch.randperm(len(dataset), generator=generator) 119 | return Subset(dataset, perm[:max_samples].tolist()) 120 | 121 | 122 | class DummyPoolExecutor: 123 | class DummyResult: 124 | def __init__(self, func, _dict, *args, **kwargs): 125 | self.func = func 126 | self._dict = _dict 127 | self.args = args 128 | self.kwargs = kwargs 129 | 130 | def result(self): 131 | if self._dict["run"]: 132 | return self.func(*self.args, **self.kwargs) 133 | else: 134 | raise CancelledError() 135 | 136 | def __init__(self, workers=0): 137 | self._dict = {"run": True} 138 | 139 | def submit(self, func, *args, **kwargs): 140 | return DummyPoolExecutor.DummyResult(func, self._dict, *args, **kwargs) 141 | 142 | def shutdown(self, *_, **__): 143 | self._dict["run"] = False 144 | 145 | def __enter__(self): 146 | return self 147 | 148 | def __exit__(self, exc_type, exc_value, exc_tb): 149 | return 150 | -------------------------------------------------------------------------------- /demucs/wav.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loading wav based datasets, including MusdbHQ."""

from collections import OrderedDict
import hashlib
import math
import json
import os
from pathlib import Path
import tqdm

import musdb
import julius
from . import audio_legacy
import torch as th
from torch import distributed
import torchaudio as ta
from torch.nn import functional as F

from .audio import convert_audio_channels
from . import distrib

MIXTURE = "mixture"
EXT = ".wav"


def _track_metadata(track, sources, normalize=True, ext=EXT):
    """Compute the metadata of one track folder.

    Checks that every source file and the mixture have the same length and
    sample rate. If the mixture file is missing, it is created by summing
    the sources and saved next to them.

    Args:
        track (Path): folder containing one file per source.
        sources (list[str]): source names to look for.
        normalize (bool): if True, load the mixture to compute its mean and std.
        ext (str): extension of the audio files.

    Returns:
        dict with keys `length` (in frames), `mean`, `std` and `samplerate`.

    Raises:
        ValueError: if a file has a different length or sample rate than the others.
    """
    track_length = None
    track_samplerate = None
    mean = 0
    std = 1
    for source in sources + [MIXTURE]:
        file = track / f"{source}{ext}"
        if source == MIXTURE and not file.exists():
            # No mixture on disk: build it from the individual sources.
            audio = 0
            for sub_source in sources:
                sub_file = track / f"{sub_source}{ext}"
                sub_audio, sr = ta.load(sub_file)
                audio += sub_audio
            would_clip = audio.abs().max() >= 1
            if would_clip:
                assert ta.get_audio_backend() == 'soundfile', 'use dset.backend=soundfile'
            ta.save(file, audio, sr, encoding='PCM_F')

        try:
            info = ta.info(str(file))
        except RuntimeError:
            # Print the offending file before re-raising, to ease debugging.
            print(file)
            raise
        length = info.num_frames
        if track_length is None:
            track_length = length
            track_samplerate = info.sample_rate
        elif track_length != length:
            raise ValueError(
                f"Invalid length for file {file}: "
                f"expecting {track_length} but got {length}.")
        elif info.sample_rate != track_samplerate:
            raise ValueError(
                f"Invalid sample rate for file {file}: "
                f"expecting {track_samplerate} but got {info.sample_rate}.")
        if source == MIXTURE and normalize:
            try:
                wav, _ = ta.load(str(file))
            except RuntimeError:
                print(file)
                raise
            # Statistics are computed on the channel-averaged mixture.
            wav = wav.mean(0)
            mean = wav.mean().item()
            std = wav.std().item()

    return {"length": length, "mean": mean, "std": std, "samplerate": track_samplerate}


def build_metadata(path, sources, normalize=True, ext=EXT):
    """
    Build the metadata for `Wavset`.

    Args:
        path (str or Path): path to dataset.
        sources (list[str]): list of sources to look for.
        normalize (bool): if True, loads full track and store normalization
            values based on the mixture file.
        ext (str): extension of audio files (default is .wav).

    Returns:
        dict mapping each track name (folder path relative to `path`)
        to the output of `_track_metadata`.
    """

    meta = {}
    path = Path(path)
    pendings = []
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(8) as pool:
        for root, folders, files in os.walk(path, followlinks=True):
            root = Path(root)
            # A track is a leaf folder: skip hidden folders, folders with
            # sub-folders, and the dataset root itself.
            if root.name.startswith('.') or folders or root == path:
                continue
            name = str(root.relative_to(path))
            pendings.append((name, pool.submit(_track_metadata, root, sources, normalize, ext)))
        for name, pending in tqdm.tqdm(pendings, ncols=120):
            meta[name] = pending.result()
    return meta


class Wavset:
    def __init__(
            self,
            root, metadata, sources,
            segment=None, shift=None, normalize=True,
            samplerate=44100, channels=2, ext=EXT):
        """
        Waveset (or mp3 set for that matter). Can be used to train
        with arbitrary sources. Each track should be one folder inside of `root`.
        The folder should contain files named `{source}{ext}`.

        Args:
            root (Path or str): root folder for the dataset.
            metadata (dict): output from `build_metadata`.
            sources (list[str]): list of source names.
            segment (None or float): segment length in seconds. If `None`, returns entire tracks.
            shift (None or float): stride in seconds between samples.
                Defaults to `segment`.
            normalize (bool): normalizes input audio, **based on the metadata content**,
                i.e. the entire track is normalized, not individual extracts.
            samplerate (int): target sample rate. if the file sample rate
                is different, it will be resampled on the fly.
            channels (int): target nb of channels. if different, will be
                changed on the fly.
            ext (str): extension for audio files (default is .wav).

        samplerate and channels are converted on the fly.
        """
        self.root = Path(root)
        self.metadata = OrderedDict(metadata)
        self.segment = segment
        self.shift = shift or segment
        self.normalize = normalize
        self.sources = sources
        self.channels = channels
        self.samplerate = samplerate
        self.ext = ext
        # Number of examples provided by each track, in metadata order.
        self.num_examples = []
        for name, meta in self.metadata.items():
            track_duration = meta['length'] / meta['samplerate']
            if segment is None or track_duration < segment:
                examples = 1
            else:
                examples = int(math.ceil((track_duration - self.segment) / self.shift) + 1)
            self.num_examples.append(examples)

    def __len__(self):
        return sum(self.num_examples)

    def get_file(self, name, source):
        """Return the path of the file for `source` in track `name`."""
        return self.root / name / f"{source}{self.ext}"

    def __getitem__(self, index):
        """Return the example at `index`, as a tensor stacking all sources.

        Raises:
            IndexError: if `index` is past the last example.
        """
        requested = index
        for name, examples in zip(self.metadata, self.num_examples):
            if index >= examples:
                # Not in this track: make the index relative to the next one.
                index -= examples
                continue
            meta = self.metadata[name]
            num_frames = -1
            offset = 0
            if self.segment is not None:
                offset = int(meta['samplerate'] * self.shift * index)
                num_frames = int(math.ceil(meta['samplerate'] * self.segment))
            wavs = []
            for source in self.sources:
                file = self.get_file(name, source)
                wav, _ = ta.load(str(file), frame_offset=offset, num_frames=num_frames)
                wav = convert_audio_channels(wav, self.channels)
                wavs.append(wav)

            example = th.stack(wavs)
            example = julius.resample_frac(example, meta['samplerate'], self.samplerate)
            if self.normalize:
                example = (example - meta['mean']) / meta['std']
            if self.segment:
                # Trim or zero-pad on the right to the exact segment length.
                length = int(self.segment * self.samplerate)
                example = example[..., :length]
                example = F.pad(example, (0, length - example.shape[-1]))
            return example
        # The loop used to fall through and silently return None here.
        raise IndexError(f"index {requested} is out of range for a dataset of size {len(self)}")


def get_wav_datasets(args, name='wav'):
    """Extract the wav datasets from the XP arguments."""
    path = getattr(args, name)
    sig = hashlib.sha1(str(path).encode()).hexdigest()[:8]
    metadata_file = Path(args.metadata) / ('wav_' + sig + ".json")
    train_path = Path(path) / "train"
    valid_path = Path(path) / "valid"
    # Only rank 0 builds the metadata cache; other workers wait at the barrier.
    if not metadata_file.is_file() and distrib.rank == 0:
        metadata_file.parent.mkdir(exist_ok=True, parents=True)
        train = build_metadata(train_path, args.sources)
        valid = build_metadata(valid_path, args.sources)
        with open(metadata_file, "w") as fout:
            json.dump([train, valid], fout)
    if distrib.world_size > 1:
        distributed.barrier()
    with open(metadata_file) as fin:
        train, valid = json.load(fin)
    if args.full_cv:
        kw_cv = {}
    else:
        kw_cv = {'segment': args.segment, 'shift': args.shift}
    train_set = Wavset(train_path, train, args.sources,
                       segment=args.segment, shift=args.shift,
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize)
    valid_set = Wavset(valid_path, valid, [MIXTURE] + list(args.sources),
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize, **kw_cv)
    return train_set, valid_set


def _get_musdb_valid():
    # Return musdb valid set.
    import yaml
    setup_path = Path(musdb.__path__[0]) / 'configs' / 'mus.yaml'
    with open(setup_path, 'r') as fin:
        setup = yaml.safe_load(fin)
    return setup['validation_tracks']


def get_musdb_wav_datasets(args):
    """Extract the musdb dataset from the XP arguments."""
    sig = hashlib.sha1(str(args.musdb).encode()).hexdigest()[:8]
    metadata_file = Path(args.metadata) / ('musdb_' + sig + ".json")
    root = Path(args.musdb) / "train"
    # Only rank 0 builds the metadata cache; other workers wait at the barrier.
    if not metadata_file.is_file() and distrib.rank == 0:
        metadata_file.parent.mkdir(exist_ok=True, parents=True)
        metadata = build_metadata(root, args.sources)
        with open(metadata_file, "w") as fout:
            json.dump(metadata, fout)
    if distrib.world_size > 1:
        distributed.barrier()
    with open(metadata_file) as fin:
        metadata = json.load(fin)

    valid_tracks = _get_musdb_valid()
    if args.train_valid:
        metadata_train = metadata
    else:
        metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks}
    metadata_valid = {name: meta for name, meta in metadata.items() if name in valid_tracks}
    if args.full_cv:
        kw_cv = {}
    else:
        kw_cv = {'segment': args.segment, 'shift': args.shift}
    train_set = Wavset(root, metadata_train, args.sources,
                       segment=args.segment, shift=args.shift,
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize)
    valid_set = Wavset(root, metadata_valid, [MIXTURE] + list(args.sources),
                       samplerate=args.samplerate, channels=args.channels,
                       normalize=args.normalize, **kw_cv)
    return train_set, valid_set

# --------------------------------------------------------------------------------
# /demucs/wdemucs.py:
# --------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # For compat 7 | from .hdemucs import HDemucs 8 | 9 | WDemucs = HDemucs 10 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | # Demucs APIs 2 | 3 | ## Quick start 4 | 5 | Notes: Type hints have been added to all API functions. It is recommended to check them before passing parameters to a function as some arguments only support limited types (e.g. parameter `repo` of method `load_model` only supports type `pathlib.Path`). 6 | 7 | 1. The first step is to import the api module: 8 | 9 | ```python 10 | import demucs.api 11 | ``` 12 | 13 | 2. Then initialize the `Separator`. Parameters that will serve as default values for methods can be passed. Model should be specified. 14 | 15 | ```python 16 | # Initialize with default parameters: 17 | separator = demucs.api.Separator() 18 | 19 | # Use another model and segment: 20 | separator = demucs.api.Separator(model="mdx_extra", segment=12) 21 | 22 | # You can also use other parameters defined 23 | ``` 24 | 25 | 3. Separate it! 26 | 27 | ```python 28 | # Separating an audio file 29 | origin, separated = separator.separate_audio_file("file.mp3") 30 | 31 | # Separating a loaded audio 32 | origin, separated = separator.separate_tensor(audio) 33 | 34 | # If you encounter an error like CUDA out of memory, you can use this to change parameters like `segment`: 35 | separator.update_parameter(segment=smaller_segment) 36 | ``` 37 | 38 | 4. 
Save audio 39 | 40 | ```python 41 | # Remember to create the destination folder before calling `save_audio` 42 | # Or you are likely to receive `FileNotFoundError` 43 | for file, sources in separated: 44 | for stem, source in sources.items(): 45 | demucs.api.save_audio(source, f"{stem}_{file}", samplerate=separator.samplerate) 46 | ``` 47 | 48 | ## API References 49 | 50 | The types of the parameters and return values are not listed in this document. To know their exact types, please read the type hints in api.py (most modern code editors support inferring types based on type hints). 51 | 52 | ### `class Separator` 53 | 54 | The base separator class 55 | 56 | ##### Parameters 57 | 58 | model: Pretrained model name or signature. Default is htdemucs. 59 | 60 | repo: Folder containing all pre-trained models for use. 61 | 62 | segment: Length (in seconds) of each segment (only available if `split` is `True`). If not specified, will use the command line option. 63 | 64 | shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and apply the opposite shift to the output. This is repeated `shifts` times and all predictions are averaged. This effectively makes the model time equivariant and improves SDR by up to 0.2 points. If not specified, will use the command line option. 65 | 66 | split: If True, the input will be broken down into small chunks (length set by `segment`) and predictions will be performed individually on each and concatenated. Useful for models with a large memory footprint like Tasnet. If not specified, will use the command line option. 67 | 68 | overlap: The overlap between the splits. If not specified, will use the command line option. 69 | 70 | device (torch.device, str, or None): If provided, device on which to execute the computation, otherwise `wav.device` is assumed. When `device` is different from `wav.device`, only local computations will be on `device`, while the entire tracks will be stored on `wav.device`. 
If not specified, will use the command line option. 71 | 72 | jobs: Number of jobs. This can increase memory usage but will be much faster when multiple cores are available. If not specified, will use the command line option. 73 | 74 | callback: A function that will be called when the separation of a chunk starts or finishes. The argument passed to the function will be a dict. For more information, please see the Callback section. 75 | 76 | callback_arg: A dict containing private parameters to be passed to the callback function. For more information, please see the Callback section. 77 | 78 | progress: If true, show a progress bar. 79 | 80 | ##### Notes for callback 81 | 82 | The function will be called with only one positional parameter whose type is `dict`. The `callback_arg` will be combined with information about the current separation progress. The progress information will override the values in `callback_arg` if the same key is used. To abort the separation, raise an exception in `callback`, which you should handle yourself if you want your code to continue running. 83 | 84 | Progress information contains several keys (These keys will always exist): 85 | - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0. 86 | - `shift_idx`: The index of shifts. Starts from 0. 87 | - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't mean that it is at the 441000th second of the audio, but the "frame" of the tensor. 88 | - `state`: Could be `"start"` or `"end"`. 89 | - `audio_length`: Length of the audio (in "frame" of the tensor). 90 | - `models`: Count of submodels in the model. 91 | 92 | #### `property samplerate` 93 | 94 | A read-only property holding the sample rate that the model requires. Will raise a warning if the model is not loaded and return the default value. 95 | 96 | #### `property audio_channels` 97 | 98 | A read-only property holding the number of audio channels that the model requires. 
Will raise a warning if the model is not loaded and return the default value. 99 | 100 | #### `property model` 101 | 102 | A read-only property holding the model. 103 | 104 | #### `method update_parameter()` 105 | 106 | Update the parameters of separation. 107 | 108 | ##### Parameters 109 | 110 | segment: Length (in seconds) of each segment (only available if `split` is `True`). If not specified, will use the command line option. 111 | 112 | shifts: If > 0, will shift in time `wav` by a random amount between 0 and 0.5 sec and apply the opposite shift to the output. This is repeated `shifts` times and all predictions are averaged. This effectively makes the model time equivariant and improves SDR by up to 0.2 points. If not specified, will use the command line option. 113 | 114 | split: If True, the input will be broken down into small chunks (length set by `segment`) and predictions will be performed individually on each and concatenated. Useful for models with a large memory footprint like Tasnet. If not specified, will use the command line option. 115 | 116 | overlap: The overlap between the splits. If not specified, will use the command line option. 117 | 118 | device (torch.device, str, or None): If provided, device on which to execute the computation, otherwise `wav.device` is assumed. When `device` is different from `wav.device`, only local computations will be on `device`, while the entire tracks will be stored on `wav.device`. If not specified, will use the command line option. 119 | 120 | jobs: Number of jobs. This can increase memory usage but will be much faster when multiple cores are available. If not specified, will use the command line option. 121 | 122 | callback: A function that will be called when the separation of a chunk starts or finishes. The argument passed to the function will be a dict. For more information, please see the Callback section. 123 | 124 | callback_arg: A dict containing private parameters to be passed to the callback function. 
For more information, please see the Callback section. 125 | 126 | progress: If true, show a progress bar. 127 | 128 | ##### Notes for callback 129 | 130 | The function will be called with only one positional parameter whose type is `dict`. The `callback_arg` will be combined with information about the current separation progress. The progress information will override the values in `callback_arg` if the same key is used. To abort the separation, raise an exception in `callback`, which you should handle yourself if you want your code to continue running. 131 | 132 | Progress information contains several keys (These keys will always exist): 133 | - `model_idx_in_bag`: The index of the submodel in `BagOfModels`. Starts from 0. 134 | - `shift_idx`: The index of shifts. Starts from 0. 135 | - `segment_offset`: The offset of current segment. If the number is 441000, it doesn't mean that it is at the 441000th second of the audio, but the "frame" of the tensor. 136 | - `state`: Could be `"start"` or `"end"`. 137 | - `audio_length`: Length of the audio (in "frame" of the tensor). 138 | - `models`: Count of submodels in the model. 139 | 140 | #### `method separate_tensor()` 141 | 142 | Separate an audio. 143 | 144 | ##### Parameters 145 | 146 | wav: Waveform of the audio. Should have 2 dimensions, the first is each audio channel, while the second is the waveform of each channel. e.g. `tuple(wav.shape) == (2, 884000)` means the audio has 2 channels. 147 | 148 | sr: Sample rate of the original audio, the wave will be resampled if it doesn't match the model. 149 | 150 | ##### Returns 151 | 152 | A tuple, whose first element is the original wave and second element is a dict, whose keys are the name of stems and values are separated waves. The original wave will have already been resampled. 153 | 154 | ##### Notes 155 | 156 | Use this function with caution. This function does not validate its input. 
157 | 158 | #### `method separate_audio_file()` 159 | 160 | Separate an audio file. The method will automatically read the file. 161 | 162 | ##### Parameters 163 | 164 | wav: Path of the file to be separated. 165 | 166 | ##### Returns 167 | 168 | A tuple, whose first element is the original wave and second element is a dict, whose keys are the name of stems and values are separated waves. The original wave will have already been resampled. 169 | 170 | ### `function save_audio()` 171 | 172 | Save audio file. 173 | 174 | ##### Parameters 175 | 176 | wav: Audio to be saved 177 | 178 | path: The file path to be saved. Ending must be one of `.mp3` and `.wav`. 179 | 180 | samplerate: File sample rate. 181 | 182 | bitrate: If the suffix of `path` is `.mp3`, it will be used to specify the bitrate of mp3. 183 | 184 | clip: Clipping preventing strategy. 185 | 186 | bits_per_sample: If the suffix of `path` is `.wav`, it will be used to specify the bit depth of wav. 187 | 188 | as_float: If it is True and the suffix of `path` is `.wav`, then `bits_per_sample` will be set to 32 and will write the wave file with float format. 189 | 190 | ##### Returns 191 | 192 | None 193 | 194 | ### `function list_models()` 195 | 196 | List the available models. Please remember that not all the returned models can be successfully loaded. 197 | 198 | ##### Parameters 199 | 200 | repo: The repo whose models are to be listed. 201 | 202 | ##### Returns 203 | 204 | A dict with two keys ("single" for single models and "bag" for bag of models). The values are lists whose components are strs. 
-------------------------------------------------------------------------------- /docs/linux.md: -------------------------------------------------------------------------------- 1 | # Linux support for Demucs 2 | 3 | If your distribution has at least Python 3.8, and you just wish to separate 4 | tracks with Demucs, not train it, you can just run 5 | 6 | ```bash 7 | pip3 install --user -U demucs 8 | # Then anytime you want to use demucs, just do 9 | python3 -m demucs -d cpu PATH_TO_AUDIO_FILE_1 10 | # If you have added the user specific pip bin/ folder to your path, you can also do 11 | demucs -d cpu PATH_TO_AUDIO_FILE_1 12 | ``` 13 | 14 | If Python is too old, or you want to be able to train, I recommend [installing Miniconda][miniconda], with Python 3.8 or more. 15 | 16 | ```bash 17 | conda activate 18 | pip3 install -U demucs 19 | # Then anytime you want to use demucs, first do conda activate, then 20 | demucs -d cpu PATH_TO_AUDIO_FILE_1 21 | ``` 22 | 23 | Of course, you can also use a specific env for Demucs. 24 | 25 | **Important, torchaudio 0.12 update:** Torchaudio no longer supports decoding mp3s without ffmpeg installed. You must have ffmpeg installed, either through Anaconda (`conda install ffmpeg -c conda-forge`) or as a distribution package (e.g. `sudo apt-get install ffmpeg`). 
26 | 27 | 28 | [miniconda]: https://docs.conda.io/en/latest/miniconda.html#linux-installers 29 | -------------------------------------------------------------------------------- /docs/mac.md: -------------------------------------------------------------------------------- 1 | # macOS support for Demucs 2 | 3 | If you have a sufficiently recent version of macOS, you can just run 4 | 5 | ```bash 6 | python3 -m pip install --user -U demucs 7 | # Then anytime you want to use demucs, just do 8 | python3 -m demucs -d cpu PATH_TO_AUDIO_FILE_1 9 | # If you have added the user specific pip bin/ folder to your path, you can also do 10 | demucs -d cpu PATH_TO_AUDIO_FILE_1 11 | ``` 12 | 13 | If you do not already have Anaconda installed or much experience with the terminal on macOS, here are some detailed instructions: 14 | 15 | 1. Download [Anaconda 3.8 (or more recent) 64-bit for macOS][anaconda]: 16 | 2. Open [Anaconda Prompt in macOS][prompt] 17 | 3. Follow these commands: 18 | ```bash 19 | conda activate 20 | pip3 install -U demucs 21 | # Then anytime you want to use demucs, first do conda activate, then 22 | demucs -d cpu PATH_TO_AUDIO_FILE_1 23 | ``` 24 | 25 | **Important, torchaudio 0.12 update:** Torchaudio no longer supports decoding mp3s without ffmpeg installed. You must have ffmpeg installed, either through Anaconda (`conda install ffmpeg -c conda-forge`) or with Homebrew for instance (`brew install ffmpeg`). 
26 | 27 | [anaconda]: https://www.anaconda.com/download 28 | [prompt]: https://docs.anaconda.com/anaconda/user-guide/getting-started/#open-nav-mac 29 | -------------------------------------------------------------------------------- /docs/mdx.md: -------------------------------------------------------------------------------- 1 | # Music DemiXing challenge (MDX) 2 | 3 | If you want to use Demucs for the [MDX challenge](https://www.aicrowd.com/challenges/music-demixing-challenge-ismir-2021), 4 | please follow the instructions hereafter 5 | 6 | ## Installing Demucs 7 | 8 | Follow the instructions from the [main README](https://github.com/facebookresearch/demucs#requirements) 9 | in order to set up Demucs using Anaconda. You will need the full setup for training, including soundstretch. 10 | 11 | ## Getting MusDB-HQ 12 | 13 | Download [MusDB-HQ](https://zenodo.org/record/3338373) to some folder and unzip it. 14 | 15 | ## Training Demucs 16 | 17 | Train Demucs (you might need to change the batch size depending on the number of GPUs available). 18 | It seems 48 channels is enough to get the best performance on MusDB-HQ, and training will be faster 19 | and less memory demanding. In any case, the 64 channel version times out on the challenge. 20 | ```bash 21 | ./run.py --channels=48 --batch_size 64 --musdb=PATH_TO_MUSDB --is_wav [EXTRA_FLAGS] 22 | ``` 23 | 24 | ### Post training 25 | 26 | Once the training is completed, a new model file will be exported in `models/`. 27 | 28 | You can look at the SDR on the MusDB dataset using `python result_table.py`. 29 | 30 | 31 | ### Evaluate and export a model before training is over 32 | 33 | If you want to export a model before training is complete, use the following command: 34 | ```bash 35 | python -m demucs [ALL EXACT TRAINING FLAGS] --save_model 36 | ``` 37 | You can also pass the `--half` flag, in order to save weights in half precision. This will divide the model size by 2 and won't impact SDR. 
38 | 39 | Once this is done, you can partially evaluate a model with 40 | ```bash 41 | ./run.py --test NAME_OF_MODEL.th --musdb=PATH_TO_MUSDB --is_wav 42 | ``` 43 | 44 | **Note:** `NAME_OF_MODEL.th` is given relative to the models folder (given by `--models`, defaults to `models/`), so don't include it in the name. 45 | 46 | 47 | ### Training smaller models 48 | 49 | If you want to quickly test an idea, I would recommend training a 16 kHz model, and testing if things work there or not, before training the full 44kHz model. You can train one of those with 50 | ```bash 51 | ./run.py --channels=32 --samplerate 16000 --samples 160000 --data_stride 16000 --depth=5 --batch_size 64 --repitch=0 --musdb=PATH_TO_MUSDB --is_wav [EXTRA_FLAGS] 52 | ``` 53 | (repitch must be turned off, because things will break at 16kHz). 54 | 55 | ## Submitting your model 56 | 57 | 1. Git clone [the Music Demixing Challenge - Starter Kit - Demucs Edition](https://github.com/adefossez/music-demixing-challenge-starter-kit). 58 | 2. Inside the starter kit, create a `models/` folder and copy over the trained model from the Demucs repo (renaming 59 | it for instance `my_model.th`) 60 | 3. Inside the `test_demuc.py` file, change the function `prediction_setup`: comment the loading 61 | of the pre-trained model, and uncomment the code to load your own model. 62 | 4. Edit the file `aicrowd.json` with your username. 63 | 5. Install [git-lfs](https://git-lfs.github.com/). Then run 64 | 65 | ```bash 66 | git lfs install 67 | git add models/ 68 | git add -u . 69 | git commit -m "My Demucs submission" 70 | ``` 71 | 6. Follow the [submission instructions](https://github.com/AIcrowd/music-demixing-challenge-starter-kit/blob/master/docs/SUBMISSION.md). 
72 | 73 | Best of luck 🤞 74 | -------------------------------------------------------------------------------- /docs/release.md: -------------------------------------------------------------------------------- 1 | # Release notes for Demucs 2 | 3 | ## V4.1.0a, TBD 4 | 5 | Get models list 6 | 7 | Check segment of HTDemucs inside BagOfModels 8 | 9 | Added api.py to be called from another program 10 | 11 | Use api in separate.py 12 | 13 | Added `--other-method`: method to get `no_{STEM}`, add up all the other stems (add), subtract the specific stem from the original track (minus), and discard (none) 14 | 15 | Added type `HTDemucs` to type alias `AnyModel`. 16 | 17 | Improving recent torchaudio versions support (Thanks @CarlGao4) 18 | 19 | ## V4.0.1, 8th of September 2023 20 | 21 | **From this version, Python 3.7 is no longer supported. This is not a problem since the latest PyTorch 2.0.0 no longer supports it either.** 22 | 23 | Various improvements by @CarlGao4. Support for `segment` param inside of HTDemucs 24 | model. 25 | 26 | Made diffq an optional dependency, with an error message if not installed. 27 | 28 | Added output format flac (Free Lossless Audio Codec) 29 | 30 | Will use CPU for complex numbers, when using MPS device (all other computations are performed by mps). 31 | 32 | Optimize code to save memory 33 | 34 | Allow changing preset of MP3 35 | 36 | ## V4.0.0, 7th of December 2022 37 | 38 | Adding hybrid transformer Demucs model. 39 | 40 | Added support for [Torchaudio implementation of HDemucs](https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html), thanks @skim0514. 41 | 42 | Added experimental 6 sources model `htdemucs_6s` (`drums`, `bass`, `other`, `vocals`, `piano`, `guitar`). 43 | 44 | ## V3.0.6, 16th of November 2022 45 | 46 | Option to customize output path of stems (@CarlGao4) 47 | 48 | Fixed bug in pad1d leading to failure sometimes. 
49 | 50 | ## V3.0.5, 17th of August 2022 51 | 52 | Added `--segment` flag to customize the segment length and use less memory (thanks @CarlGao4). 53 | 54 | Fix reflect padding bug on small inputs. 55 | 56 | Compatible with pyTorch 1.12 57 | 58 | ## V3.0.4, 24th of February 2022 59 | 60 | Added option to split into two stems (i.e. vocals, vs. non vocals), thanks to @CarlGao4. 61 | 62 | Added `--float32`, `--int24` and `--clip-mode` options to customize how output stems are saved. 63 | 64 | ## V3.0.3, 2nd of December 2021 65 | 66 | Fix bug in weights used for different sources. Thanks @keunwoochoi for the report and fix. 67 | 68 | Improving drastically memory usage on GPU for long files. Thanks a lot @famzah for providing this. 69 | 70 | Adding multithread evaluation on CPU (`-j` option). 71 | 72 | (v3.0.2 had a bug with the CPU pool and is skipped.) 73 | 74 | ## V3.0.1, 12th of November 2021 75 | 76 | Release of Demucs v3, featuring hybrid domain separation and much more. 77 | This drops support for Conv-Tasnet and training on the non HQ MusDB dataset. 78 | There is no version 3.0.0 because I messed up. 79 | 80 | ## V2.0.2, 26th of May 2021 81 | 82 | - Fix in Tasnet (PR #178) 83 | - Use ffmpeg in priority when available instead of torchaudio to avoid small shift in MP3 data. 84 | - other minor fixes 85 | 86 | ## v2.0.1, 11th of May 2021 87 | 88 | MusDB HQ support added. Custom wav dataset support added. 89 | Minor changes: issue with padding of mp3 and torchaudio reading, in order to limit that, 90 | Demucs now uses ffmpeg in priority and fallback to torchaudio. 91 | Replaced pre-trained demucs model with one trained on more recent codebase. 92 | 93 | ## v2.0.0, 28th of April 2021 94 | 95 | This is a big release, with a lot of breaking changes. You will likely 96 | need to install Demucs from scratch. 97 | 98 | 99 | 100 | - Demucs now supports on the fly resampling by a factor of 2. 101 | This improves SDR almost 0.3 points. 
102 | - Random scaling of each source added (From Uhlich et al. 2017). 103 | - Random pitch and tempo augmentation added, from [Cohen-Hadria et al. 2019]. 104 | - With extra augmentation, the best performing Demucs model now has only 64 channels 105 | instead of 100, so model size goes from 2.4GB to 1GB. Also SDR is up from 5.6 SDR to 6.3 when trained only on MusDB. 106 | - Quantized model using [DiffQ](https://github.com/facebookresearch/diffq) has been added. Model size is 150MB, no loss in quality as far as I, or the metrics, 107 | can tell. 108 | - Pretrained models are now using the TorchHub interface. 109 | - Overlap mode for separation, to limit inconsistencies at 110 | frame boundaries, with linear transition over the overlap. Overlap is currently 111 | at 25%. Note that this is only done for separation, not training, because 112 | I added that quite late to the code. For Conv-TasNet this can improve 113 | SDR quite a bit (+0.3 points, to 6.0). 114 | - PyPI hosting, for separation, not training! 115 | -------------------------------------------------------------------------------- /docs/sdx23.md: -------------------------------------------------------------------------------- 1 | # SDX 23 challenge 2 | 3 | Check out [the challenge page](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023) 4 | for more information. This page is specifically on training models for the [MDX'23 sub-challenge](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023/problems/music-demixing-track-mdx-23). 5 | There are two tracks: one trained on a dataset with bleeding, and the other with label mixups. 6 | 7 | This gives instructions on training a Hybrid Demucs model on those datasets. 8 | I haven't tried the HT Demucs model, as it typically requires quite a bit of training data but the same could be done with it. 9 | 10 | You will need to work from an up to date clone of this repo. See the [generic training instructions](./training.md) for more information.
11 | 12 | ## Getting the data 13 | 14 | Register on the challenge, then checkout the [Resources page](https://www.aicrowd.com/challenges/sound-demixing-challenge-2023/problems/music-demixing-track-mdx-23/dataset_files) and download the dataset you are 15 | interested in. 16 | 17 | Update the `conf/dset/sdx23_bleeding.yaml` and `conf/dset/sdx23_labelnoise.yaml` files to point to the right path. 18 | 19 | **Make sure soundfile** is installed (`conda install -c conda-forge libsndfile; pip install soundfile`). 20 | 21 | ### Create proper train / valid structure 22 | 23 | Demucs requires a valid set to work properly. Go to the folder where you extracted the tracks then do: 24 | 25 | ```shell 26 | mkdir train 27 | mv * train # there will be a warning saying cannot move train to itself but that's fine the other tracks should have. 28 | mkdir valid 29 | cd train 30 | mv 5640831d-7853-4d06-8166-988e2844b652 bc964128-da16-4e4c-af95-4d1211e78c70 \ 31 | cc7f7675-d3c8-4a49-a2d7-a8959b694004 f40ffd10-4e8b-41e6-bd8a-971929ca9138 \ 32 | bc1f2967-f834-43bd-aadc-95afc897cfe7 cc3e4991-6cce-40fe-a917-81a4fbb92ea6 \ 33 | ed90a89a-bf22-444d-af3d-d9ac3896ebd2 f4b735de-14b1-4091-a9ba-c8b30c0740a7 ../valid 34 | ``` 35 | 36 | ## Training 37 | 38 | See `dora grid sdx23` for a starting point. You can do `dora grid sdx23 --init --dry_run` then `dora run -f SIG -d` with `SIG` one of the signature 39 | to train on a machine with GPUs if you do not have a SLURM cluster. 40 | 41 | Keep in mind that the valid tracks and train tracks are corrupted in different ways for those tasks, so do not expect 42 | the valid loss to go down as smoothly as with normal training on the clean MusDB. 43 | 44 | I only trained Hybrid Demucs baselines as Hybrid Transformer typically requires more data. 45 | 46 | 47 | ## Exporting models 48 | 49 | Run 50 | ``` 51 | python -m tools.export SIG 52 | ``` 53 | 54 | This will export the trained model into the `release_models` folder. 
55 | 56 | ## Submitting a model 57 | 58 | Clone the [Demucs Starter Kit for SDX23](https://github.com/adefossez/sdx23). Follow the instructions there. 59 | 60 | You will need to copy the models under `release_models` in the `sdx23/models/` folder before you can use them. 61 | Make sure you have git-lfs properly installed and set up before adding those files to your fork of `sdx23`. 62 | -------------------------------------------------------------------------------- /docs/training.md: -------------------------------------------------------------------------------- 1 | # Training (Hybrid) Demucs 2 | 3 | ## Install all the dependencies 4 | 5 | You should install all the dependencies with either Anaconda (using the env file `environment-cuda.yml`) 6 | or `pip`, with `requirements.txt`. 7 | 8 | ## Datasets 9 | 10 | ### MusDB HQ 11 | 12 | Note that we do not support MusDB non HQ training anymore. 13 | Get the [Musdb HQ](https://zenodo.org/record/3338373) dataset, and update the path to it in two places: 14 | - The `dset.musdb` key inside `conf/config.yaml`. 15 | - The variable `MUSDB_PATH` inside `tools/automix.py`. 16 | 17 | ### Create the fine tuning datasets 18 | 19 | **This is only for the MDX 2021 competition models** 20 | 21 | I use a fine tuning on a dataset crafted by remixing songs in a musically plausible way. 22 | The automix script will make sure that BPM, first beat and pitches are aligned. 23 | In the file `tools/automix.py`, edit `OUTPATH` to suit your setup, as well as the `MUSDB_PATH` 24 | to point to your copy of MusDB HQ. Then run 25 | 26 | ```bash 27 | export NUMBA_NUM_THREADS=1; python3 -m tools.automix 28 | ``` 29 | 30 | **Important:** the script will show many errors, those are normal. They just indicate when two stems 31 | do not match due to BPM or music scale difference. 32 | 33 | Finally, edit the file `conf/dset/auto_mus.yaml` and replace `dset.wav` with the value of `OUTPATH`.
34 | 35 | If you have a custom dataset, you can also uncomment the lines `dset2 = ...` and 36 | `dset3 = ...` to add your custom wav data and the test set of MusDB for Track B models. 37 | You can then replace the paths in `conf/dset/auto_extra.yaml`, `conf/dset/auto_extra_test.yaml` 38 | and `conf/dset/aetl.yaml` (this last one was using 10 mixes instead of 6 for each song). 39 | 40 | ### Dataset metadata cache 41 | 42 | Datasets are scanned the first time they are used to determine the files and their durations. 43 | If you change a dataset and need a rescan, just delete the `metadata` folder. 44 | 45 | ## A short intro to Dora 46 | 47 | I use [Dora][dora] for all of the experiment (XP) management. You should have a look at the Dora README 48 | to learn about the tool. Here is a quick summary of what to know: 49 | 50 | - An XP is a unique set of hyper-parameters with a given signature. The signature is a hash of 51 | those hyper-parameters. I will always refer to an XP with its signature, e.g. `9357e12e`. 52 | We will see later that you can retrieve the hyper-params and re-run it in a single command. 53 | - In fact, the hash is defined as a delta between the base config and the one obtained with 54 | the config overrides you passed from the command line. 55 | **This means you must never change the `conf/**.yaml` files directly**, 56 | except for editing things like paths. Changing the default values in the config files means 57 | the XP signature won't reflect that change, and wrong checkpoints might be reused. 58 | I know, this is annoying, but the reason is that otherwise, any change to the config file would 59 | mean that all XPs ran so far would see their signature change. 60 | 61 | ### Dora commands 62 | 63 | Run `tar xvf outputs.tar.gz`. This will initialize the Dora XP repository, so that Dora knows 64 | which hyper-params match a signature like `9357e12e`.
Once you have done that, you should be able 65 | to run the following: 66 | 67 | ```bash 68 | dora info -f 81de367c # this will show the hyper-parameters used by a specific XP. 69 | # Be careful some overrides might be present twice, and the rightmost one 70 | # will give you the right value for it. 71 | dora run -d -f 81de367c # run an XP with the hyper-parameters from XP 81de367c. 72 | # `-d` is for distributed, it will use all available GPUs. 73 | dora run -d -f 81de367c hdemucs.channels=32 # start from the config of XP 81de367c but change some hyper-params. 74 | # This will give you a new XP with a new signature (here 3fe9c332). 75 | ``` 76 | 77 | An XP runs from a specific folder based on its signature, by default under the `outputs/` folder. 78 | You can safely interrupt a training and resume it, it will reuse any existing checkpoint, as it will 79 | reuse the same folder. 80 | If you made some change to the code and need to ignore a previous checkpoint you can use `dora run --clear [RUN ARGS]`. 81 | 82 | If you have a Slurm cluster, you can also use the `dora grid` command, e.g. `dora grid mdx`. 83 | Please refer to the [Dora documentation][dora] for more information. 84 | 85 | ## Hyper parameters 86 | 87 | Have a look at [conf/config.yaml](../conf/config.yaml) for a list of all the hyper-parameters you can override. 88 | If you are not familiar with [Hydra](https://github.com/facebookresearch/hydra), go check out their page 89 | to be familiar with how to provide overrides for your trainings. 90 | 91 | 92 | ## Model architecture 93 | 94 | A number of architectures are supported. You can select one with `model=NAME`, and have a look 95 | in [conf/config.yaml](../conf/config.yaml) for each architecture specific hyperparams. 96 | Those specific params will always be prefixed with the architecture name when passing the override 97 | from the command line or in grid files. Here is the list of models: 98 | 99 | - demucs: original time-only Demucs.
100 | - hdemucs: Hybrid Demucs (v3). 101 | - torch_hdemucs: Same as Hybrid Demucs, but using [torchaudio official implementation](https://pytorch.org/audio/stable/tutorials/hybrid_demucs_tutorial.html). 102 | - htdemucs: Hybrid Transformer Demucs (v4). 103 | 104 | ### Storing config in files 105 | 106 | As mentioned earlier, you should never change the base config files. However, you can use Hydra config groups 107 | in order to store variants you often use. If you want to create a new variant combining multiple hyper-params, 108 | copy the file `conf/variant/example.yaml` to `conf/variant/my_variant.yaml`, and then you can use it with 109 | 110 | ```bash 111 | dora run -d variant=my_variant 112 | ``` 113 | 114 | Once you have created this file, you should not edit it once you have started training models with it. 115 | 116 | 117 | ## Fine tuning 118 | 119 | If a first model is trained, you can fine tune it with other settings (e.g. automix dataset) with 120 | 121 | ```bash 122 | dora run -d -f 81de367c continue_from=81de367c dset=auto_mus variant=finetune 123 | ```` 124 | 125 | Note that you need both `-f 81de367c` and `continue_from=81de367c`. The first one indicates 126 | that the hyper-params of `81de367c` should be used as a starting point for the config. 127 | The second indicates that the weights from `81de367c` should be used as a starting point for the solver. 128 | 129 | 130 | ## Model evaluation 131 | 132 | Your model will be evaluated automatically with the new SDR definition from MDX every 20 epochs. 133 | Old style SDR (which is quite slow) will only happen at the end of training. 134 | 135 | ## Model Export 136 | 137 | 138 | In order to use your models with other commands (such as the `demucs` command for separation) you must 139 | export it. For that run 140 | 141 | ```bash 142 | python3 -m tools.export 9357e12e [OTHER SIGS ...] # replace with the appropriate signatures. 143 | ``` 144 | 145 | The models will be stored under `release_models/`. 
You can use them with the `demucs` separation command with the following flags: 146 | ```bash 147 | demucs --repo ./release_models -n 9357e12e my_track.mp3 148 | ``` 149 | 150 | ### Bag of models 151 | 152 | If you want to combine multiple models, potentially with different weights for each source, you can copy 153 | `demucs/remote/mdx.yaml` to `./release_models/my_bag.yaml`. You can then edit the list of models (all models used should have been exported first) and the weights per source and model (list of list, outer list is over models, inner list is over sources). You can then use your bag of model as 154 | 155 | ```bash 156 | demucs --repo ./release_models -n my_bag my_track.mp3 157 | ``` 158 | 159 | ## Model evaluation 160 | 161 | You can evaluate any pre-trained model or bag of models using the following command: 162 | ```bash 163 | python3 -m tools.test_pretrained -n NAME_OF_MODEL [EXTRA ARGS] 164 | ``` 165 | where `NAME_OF_MODEL` is either the name of the bag (e.g. `mdx`, `repro_mdx_a`), 166 | or a single Dora signature of one of the model of the bags. You can pass `EXTRA ARGS` to customize 167 | the test options, like the number of random shifts (e.g. `test.shifts=2`). This will compute the old-style 168 | SDR and can take quite bit of time. 169 | 170 | For custom models that were trained locally, you will need to indicate that you wish 171 | to use the local model repositories, with the `--repo ./release_models` flag, e.g., 172 | ```bash 173 | python3 -m tools.test_pretrained --repo ./release_models -n my_bag 174 | ``` 175 | 176 | 177 | ## API to retrieve the model 178 | 179 | You can retrieve officially released models in Python using the following API: 180 | ```python 181 | from demucs import pretrained 182 | from demucs.apply import apply_model 183 | bag = pretrained.get_model('htdemucs') # for a bag of models or a named model 184 | # (which is just a bag with 1 model). 
185 | model = pretrained.get_model('955717e8') # using the signature for single models. 186 | 187 | bag.models # list of individual models 188 | stems = apply_model(model, mix) # apply the model to the given mix. 189 | ``` 190 | 191 | ## Model Zoo 192 | 193 | ### Hybrid Transformer Demucs 194 | 195 | The configurations for the Hybrid Transformer models are available in: 196 | 197 | ```shell 198 | dora grid mmi --dry_run --init 199 | dora grid mmi_ft --dry_run --init # fine tuned on each source. 200 | ``` 201 | 202 | We release in particular `955717e8`, Hybrid Transformer Demucs using 5 layers, 512 channels, 10 seconds training segment length. We also release its fine tuned version, with one model 203 | for each source `f7e0c4bc`, `d12395a8`, `92cfc3b6`, `04573f0d` (drums, bass, other, vocals). 204 | The model `955717e8` is also named `htdemucs`, while the bag of models is provided 205 | as `htdemucs_ft`. 206 | 207 | We also release `75fc33f5`, a regular Hybrid Demucs trained on the same dataset, 208 | available as `hdemucs_mmi`. 209 | 210 | 211 | 212 | ### Models from the MDX Competition 2021 213 | 214 | 215 | Here is a short description of the models used for the MDX submission, either Track A (MusDB HQ only) 216 | or Track B (extra training data allowed). Training happens in two stages, with the second stage 217 | being the fine tuning on the automix generated dataset. 218 | All the fine tuned models are available on our AWS repository 219 | (you can retrieve them with `demucs.pretrained.get_model(SIG)`). The bags of models are available 220 | by doing `demucs.pretrained.get_model(NAME)` with `NAME` being either `mdx` (for Track A) or `mdx_extra` 221 | (for Track B).
222 | 223 | #### Track A 224 | 225 | The 4 models are: 226 | 227 | - `0d19c1c6`: fine-tuned on automix dataset from `9357e12e` 228 | - `7ecf8ec1`: fine-tuned on automix dataset from `e312f349` 229 | - `c511e2ab`: fine-tuned on automix dataset from `81de367c` 230 | - `7d865c68`: fine-tuned on automix dataset from `80a68df8` 231 | 232 | The 4 initial models (before fine tuning) are: 233 | 234 | - `9357e12e`: 64ch time domain only improved Demucs, with new residual branches, group norm, 235 | and singular value penalty. 236 | - `e312f349`: 64ch time domain only improved, with new residual branches, group norm, 237 | and singular value penalty, trained with a loss that focuses only on drums and bass. 238 | - `81de367c`: 48ch hybrid model, with residual branches, group norm, 239 | singular value penalty and amplitude spectrogram. 240 | - `80a68df8`: same as b5559babb but using CaC and different 241 | random seed, as well as different weights per frequency band in outermost layers. 242 | 243 | The hybrid models are combined with equal weights for all sources except for the bass. 244 | `0d19c1c6` (time domain) is used for both drums and bass. `7ecf8ec1` is used only for the bass. 245 | 246 | You can see all the hyper parameters at once with (one common line for all common hyper params, and then only shows 247 | the hyper parameters that differ), along with the DiffQ variants that are used for the `mdx_q` models: 248 | ``` 249 | dora grid mdx --dry_run --init 250 | dora grid mdx --dry_run --init 251 | ``` 252 | 253 | #### Track B 254 | 255 | - `e51eebcc` 256 | - `a1d90b5c` 257 | - `5d2d6c55` 258 | - `cfa93e08` 259 | 260 | All the models are 48ch hybrid demucs with different random seeds. Two of them 261 | are using CaC, and two are using amplitude spectrograms with masking. 262 | All the models are combined with equal weights for all sources. 263 | 264 | Things are a bit messy for Track B, there was a lot of fine tuning 265 | over different datasets.
I won't describe the entire genealogy of models here, 266 | but all the information can be accessed with the `dora info -f SIG` command. 267 | 268 | Similarly you can do (those will contain a few extra lines, for training without the MusDB test set as training, and extra DiffQ XPs): 269 | ``` 270 | dora grid mdx_extra --dry_run --init 271 | ``` 272 | 273 | ### Reproducibility and Ablation 274 | 275 | I updated the paper to report numbers with a more homogeneous setup than the one used for the competition. 276 | On MusDB HQ, I still need to use a combination of time only and hybrid models to achieve the best performance. 277 | The experiments are provided in the grids [repro.py](../demucs/grids/repro.py) and 278 | [repro_ft.py](../demucs/grids/repro_ft.py) for the fine tuning on the realistic mix datasets. 279 | 280 | The new bag of models reaches an SDR of 7.64 (vs. 7.68 for the original track A model). It uses 281 | 2 time only models trained with residual branches, local attention and the SVD penalty, 282 | along with 2 hybrid models, with the same features, and using CaC representation. 283 | We average the performance of all the models with the same weight over all sources, unlike 284 | what was done for the original track A model. We trained for 600 epochs, against 360 before. 285 | 286 | The new bag of models is available as part of the pretrained models as `repro_mdx_a`. 287 | The time only bag is named `repro_mdx_a_time_only`, and the hybrid only `repro_mdx_a_hybrid_only`. 288 | Check out the paper for more information on the training. 289 | 290 | [dora]: https://github.com/facebookresearch/dora 291 | -------------------------------------------------------------------------------- /docs/windows.md: -------------------------------------------------------------------------------- 1 | # Windows support for Demucs 2 | 3 | ## Installation and usage 4 | 5 | If you don't have much experience with Anaconda, python or the shell, here are more detailed instructions.
Note that **Demucs is not supported on 32bits systems** (as Pytorch is not available there). 6 | 7 | - First install Anaconda with **Python 3.8** or more recent, which you can find [here][install]. 8 | - Start the [Anaconda prompt][prompt]. 9 | 10 | Then, all commands that follow must be run from this prompt. 11 | 12 |
13 | I have no coding experience and these are too difficult for me 14 | 15 | > Then a GUI is suitable for you. See [Demucs GUI](https://github.com/CarlGao4/Demucs-Gui) 16 | 17 |
18 | 19 | ### If you want to use your GPU 20 | 21 | If you have a graphics card produced by NVIDIA with more than 2GiB of memory, you can separate tracks with GPU acceleration. To achieve this, you must install Pytorch with CUDA. If Pytorch was already installed (you already installed Demucs for instance), first run `python.exe -m pip uninstall torch torchaudio`. 22 | Then visit [Pytorch Home Page](https://pytorch.org/get-started/locally/) and follow the guide on it to install with CUDA support. Please make sure that the version of torchaudio is no greater than 2.1 (the latest version at the time of writing; 2.2.0 is known to be unsupported). 23 | 24 | ### Installation 25 | 26 | Start the Anaconda prompt, and run the following 27 | 28 | ```cmd 29 | conda install -c conda-forge ffmpeg 30 | python.exe -m pip install -U demucs SoundFile 31 | ``` 32 | 33 | ### Upgrade 34 | 35 | To upgrade Demucs, simply run `python.exe -m pip install -U demucs`, from the Anaconda prompt. 36 | 37 | ### Usage 38 | 39 | Then to use Demucs, just start the **Anaconda prompt** and run: 40 | ``` 41 | demucs -d cpu "PATH_TO_AUDIO_FILE_1" ["PATH_TO_AUDIO_FILE_2" ...] 42 | ``` 43 | The `"` around the filename are required if the path contains spaces. A simple way to input these paths is dragging a file from a folder into the terminal. 44 | 45 | To find the separated files, you can run this command and open the folders: 46 | ``` 47 | explorer separated 48 | ``` 49 | 50 | ### Separating an entire folder 51 | 52 | You can use the following command to separate an entire folder of mp3s for instance (replace the extension `.mp3` if need be for other file types) 53 | ``` 54 | cd FOLDER 55 | for %i in (*.mp3) do (demucs -d cpu "%i") 56 | ``` 57 | 58 | ## Potential errors 59 | 60 | If you have an error saying that `mkl_intel_thread.dll` cannot be found, you can try to first run 61 | `conda install -c defaults intel-openmp -f`. Then try again to run the `demucs` command.
If it still doesn't work, you can try to run first `set CONDA_DLL_SEARCH_MODIFICATION_ENABLE=1`, then again the `demucs` command and hopefully it will work 🙏. 62 | 63 | **If you get a permission error**, please try starting the Anaconda Prompt as administrator. 64 | 65 | 66 | [install]: https://www.anaconda.com/download 67 | [prompt]: https://docs.anaconda.com/anaconda/user-guide/getting-started/#open-prompt-win 68 | -------------------------------------------------------------------------------- /environment-cpu.yml: -------------------------------------------------------------------------------- 1 | name: demucs 2 | 3 | channels: 4 | - pytorch 5 | - conda-forge 6 | 7 | dependencies: 8 | - python>=3.8,<3.10 9 | - ffmpeg>=4.2 10 | - pytorch>=1.8.1 11 | - torchaudio>=0.8 12 | - tqdm>=4.36 13 | - pip 14 | - pip: 15 | - diffq>=0.2 16 | - dora-search 17 | - einops 18 | - hydra-colorlog>=1.1 19 | - hydra-core>=1.1 20 | - julius>=0.2.3 21 | - lameenc>=1.2 22 | - openunmix 23 | - musdb>=0.4.0 24 | - museval>=0.4.0 25 | - soundfile 26 | - submitit 27 | - treetable>=0.2.3 28 | 29 | -------------------------------------------------------------------------------- /environment-cuda.yml: -------------------------------------------------------------------------------- 1 | name: demucs 2 | 3 | channels: 4 | - pytorch 5 | - conda-forge 6 | 7 | dependencies: 8 | - python>=3.8,<3.10 9 | - ffmpeg>=4.2 10 | - pytorch>=1.8.1 11 | - torchaudio>=0.8 12 | - cudatoolkit>=10 13 | - tqdm>=4.36 14 | - pip 15 | - pip: 16 | - diffq>=0.2 17 | - dora-search 18 | - einops 19 | - hydra-colorlog>=1.1 20 | - hydra-core>=1.1 21 | - julius>=0.2.3 22 | - lameenc>=1.2 23 | - openunmix 24 | - musdb>=0.4.0 25 | - museval>=0.4.0 26 | - soundfile 27 | - submitit 28 | - treetable>=0.2.3 29 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | dependencies = ['dora-search', 'julius', 'lameenc', 'openunmix', 'pyyaml', 8 | 'torch', 'torchaudio', 'tqdm'] 9 | 10 | from demucs.pretrained import get_model 11 | 12 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | 3 | [mypy-treetable,torchaudio.*,diffq,yaml,tqdm,lameenc,musdb,museval,openunmix.*,einops,xformers.*] 4 | ignore_missing_imports = True 5 | 6 | -------------------------------------------------------------------------------- /outputs.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adefossez/demucs/b9ab48cad45976ba42b2ff17b229c071f0df9390/outputs.tar.gz -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # please make sure you have already a pytorch install that is cuda enabled! 2 | dora-search>=0.1.12 3 | diffq>=0.2.1 4 | einops 5 | flake8 6 | hydra-colorlog>=1.1 7 | hydra-core>=1.1 8 | julius>=0.2.3 9 | lameenc>=1.2 10 | museval 11 | mypy 12 | openunmix 13 | pyyaml 14 | submitit 15 | torch>=1.8.1 16 | torchaudio>=0.8,<2.2 17 | tqdm 18 | treetable 19 | soundfile>=0.10.3 20 | -------------------------------------------------------------------------------- /requirements_minimal.txt: -------------------------------------------------------------------------------- 1 | # please make sure you have already a pytorch install that is cuda enabled! 
2 | dora-search 3 | einops 4 | julius>=0.2.3 5 | lameenc>=1.2 6 | openunmix 7 | pyyaml 8 | torch>=1.8.1 9 | torchaudio>=0.8,<2.2 10 | tqdm 11 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pep8] 2 | max-line-length = 100 3 | 4 | [flake8] 5 | max-line-length = 100 6 | 7 | [yapf] 8 | column_limit = 100 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # author: adefossez 7 | # Inspired from https://github.com/kennethreitz/setup.py 8 | 9 | from pathlib import Path 10 | 11 | from setuptools import setup 12 | 13 | 14 | NAME = 'demucs' 15 | DESCRIPTION = 'Music source separation in the waveform domain.' 16 | 17 | URL = 'https://github.com/facebookresearch/demucs' 18 | EMAIL = 'defossez@fb.com' 19 | AUTHOR = 'Alexandre Défossez' 20 | REQUIRES_PYTHON = '>=3.8.0' 21 | 22 | HERE = Path(__file__).parent 23 | 24 | # Get version without explicitely loading the module. 
25 | for line in open('demucs/__init__.py'): 26 | line = line.strip() 27 | if '__version__' in line: 28 | context = {} 29 | exec(line, context) 30 | VERSION = context['__version__'] 31 | 32 | 33 | def load_requirements(name): 34 | required = [i.strip() for i in open(HERE / name)] 35 | required = [i for i in required if not i.startswith('#')] 36 | return required 37 | 38 | 39 | REQUIRED = load_requirements('requirements_minimal.txt') 40 | ALL_REQUIRED = load_requirements('requirements.txt') 41 | 42 | try: 43 | with open(HERE / "README.md", encoding='utf-8') as f: 44 | long_description = '\n' + f.read() 45 | except FileNotFoundError: 46 | long_description = DESCRIPTION 47 | 48 | setup( 49 | name=NAME, 50 | version=VERSION, 51 | description=DESCRIPTION, 52 | long_description=long_description, 53 | long_description_content_type='text/markdown', 54 | author=AUTHOR, 55 | author_email=EMAIL, 56 | python_requires=REQUIRES_PYTHON, 57 | url=URL, 58 | packages=['demucs'], 59 | extras_require={ 60 | 'dev': ALL_REQUIRED, 61 | }, 62 | install_requires=REQUIRED, 63 | include_package_data=True, 64 | entry_points={ 65 | 'console_scripts': ['demucs=demucs.separate:main'], 66 | }, 67 | license='MIT License', 68 | classifiers=[ 69 | # Trove classifiers 70 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 71 | 'License :: OSI Approved :: MIT License', 72 | 'Topic :: Multimedia :: Sound/Audio', 73 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 74 | ], 75 | ) 76 | -------------------------------------------------------------------------------- /test.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adefossez/demucs/b9ab48cad45976ba42b2ff17b229c071f0df9390/test.mp3 -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, 
Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tools/automix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | This script creates realistic mixes with stems from different songs. 9 | In particular, it will align BPM, sync up the first beat and perform pitch 10 | shift to maximize pitches overlap. 11 | In order to limit artifacts, only parts that can be mixed with less than 15% 12 | tempo shift, and 3 semitones of pitch shift are mixed together. 13 | """ 14 | from collections import namedtuple 15 | from concurrent.futures import ProcessPoolExecutor 16 | import hashlib 17 | from pathlib import Path 18 | import random 19 | import shutil 20 | import tqdm 21 | import pickle 22 | 23 | from librosa.beat import beat_track 24 | from librosa.feature import chroma_cqt 25 | import numpy as np 26 | import torch 27 | from torch.nn import functional as F 28 | 29 | from dora.utils import try_load 30 | from demucs.audio import save_audio 31 | from demucs.repitch import repitch 32 | from demucs.pretrained import SOURCES 33 | from demucs.wav import build_metadata, Wavset, _get_musdb_valid 34 | 35 | 36 | MUSDB_PATH = '/checkpoint/defossez/datasets/musdbhq' 37 | EXTRA_WAV_PATH = "/checkpoint/defossez/datasets/allstems_44" 38 | # WARNING: OUTPATH will be completely erased. 39 | OUTPATH = Path.home() / 'tmp/demucs_mdx/automix_musdb/' 40 | CACHE = Path.home() / 'tmp/automix_cache' # cache BPM and pitch information. 
41 | CHANNELS = 2 42 | SR = 44100 43 | MAX_PITCH = 3 # maximum allowable pitch shift in semi tones 44 | MAX_TEMPO = 0.15 # maximum allowable tempo shift 45 | 46 | 47 | Spec = namedtuple("Spec", "tempo onsets kr track index") 48 | 49 | 50 | def rms(wav, window=10000): 51 | """efficient rms computed for each time step over a given window.""" 52 | half = window // 2 53 | window = 2 * half + 1 54 | wav = F.pad(wav, (half, half)) 55 | tot = wav.pow(2).cumsum(dim=-1) 56 | return ((tot[..., window - 1:] - tot[..., :-window + 1]) / window).sqrt() 57 | 58 | 59 | def analyse_track(dset, index): 60 | """analyse track, extract bpm and distribution of notes from the bass line.""" 61 | track = dset[index] 62 | mix = track.sum(0).mean(0) 63 | ref = mix.std() 64 | 65 | starts = (abs(mix) >= 1e-2 * ref).float().argmax().item() 66 | track = track[..., starts:] 67 | 68 | cache = CACHE / dset.sig 69 | cache.mkdir(exist_ok=True, parents=True) 70 | 71 | cache_file = cache / f"{index}.pkl" 72 | cached = None 73 | if cache_file.exists(): 74 | cached = try_load(cache_file) 75 | if cached is not None: 76 | tempo, events, hist_kr = cached 77 | 78 | if cached is None: 79 | drums = track[0].mean(0) 80 | if drums.std() > 1e-2 * ref: 81 | tempo, events = beat_track(y=drums.numpy(), units='time', sr=SR) 82 | else: 83 | print("failed drums", drums.std(), ref) 84 | return None, track 85 | 86 | bass = track[1].mean(0) 87 | r = rms(bass) 88 | peak = r.max() 89 | mask = r >= 0.05 * peak 90 | bass = bass[mask] 91 | if bass.std() > 1e-2 * ref: 92 | kr = torch.from_numpy(chroma_cqt(y=bass.numpy(), sr=SR)) 93 | hist_kr = (kr.max(dim=0, keepdim=True)[0] == kr).float().mean(1) 94 | else: 95 | print("failed bass", bass.std(), ref) 96 | return None, track 97 | 98 | pickle.dump([tempo, events, hist_kr], open(cache_file, 'wb')) 99 | spec = Spec(tempo, events, hist_kr, track, index) 100 | return spec, None 101 | 102 | 103 | def best_pitch_shift(kr_a, kr_b): 104 | """find the best pitch shift between two chroma 
distributions.""" 105 | deltas = [] 106 | for p in range(12): 107 | deltas.append((kr_a - kr_b).abs().mean()) 108 | kr_b = kr_b.roll(1, 0) 109 | 110 | ps = np.argmin(deltas) 111 | if ps > 6: 112 | ps = ps - 12 113 | return ps 114 | 115 | 116 | def align_stems(stems): 117 | """Align the first beats of the stems. 118 | This is a naive implementation. A grid with a time definition 10ms is defined and 119 | each beat onset is represented as a gaussian over this grid. 120 | Then, we try each possible time shift to make two grids align the best. 121 | We repeat for all sources. 122 | """ 123 | sources = len(stems) 124 | width = 5e-3 # grid of 10ms 125 | limit = 5 126 | std = 2 127 | x = torch.arange(-limit, limit + 1, 1).float() 128 | gauss = torch.exp(-x**2 / (2 * std**2)) 129 | 130 | grids = [] 131 | for wav, onsets in stems: 132 | le = wav.shape[-1] 133 | dur = le / SR 134 | grid = torch.zeros(int(le / width / SR)) 135 | for onset in onsets: 136 | pos = int(onset / width) 137 | if onset >= dur - 1: 138 | continue 139 | if onset < 1: 140 | continue 141 | grid[pos - limit:pos + limit + 1] += gauss 142 | grids.append(grid) 143 | 144 | shifts = [0] 145 | for s in range(1, sources): 146 | max_shift = int(4 / width) 147 | dots = [] 148 | for shift in range(-max_shift, max_shift): 149 | other = grids[s] 150 | ref = grids[0] 151 | if shift >= 0: 152 | other = other[shift:] 153 | else: 154 | ref = ref[shift:] 155 | le = min(len(other), len(ref)) 156 | dots.append((ref[:le].dot(other[:le]), int(shift * width * SR))) 157 | 158 | _, shift = max(dots) 159 | shifts.append(-shift) 160 | 161 | outs = [] 162 | new_zero = min(shifts) 163 | for (wav, _), shift in zip(stems, shifts): 164 | offset = shift - new_zero 165 | wav = F.pad(wav, (offset, 0)) 166 | outs.append(wav) 167 | 168 | le = min(x.shape[-1] for x in outs) 169 | 170 | outs = [w[..., :le] for w in outs] 171 | return torch.stack(outs) 172 | 173 | 174 | def find_candidate(spec_ref, catalog, pitch_match=True): 175 | """Given 
reference track, this finds a track in the catalog that 176 | is a potential match (pitch and tempo delta must be within the allowable limits). 177 | """ 178 | candidates = list(catalog) 179 | random.shuffle(candidates) 180 | 181 | for spec in candidates: 182 | ok = False 183 | for scale in [1/4, 1/2, 1, 2, 4]: 184 | tempo = spec.tempo * scale 185 | delta_tempo = spec_ref.tempo / tempo - 1 186 | if abs(delta_tempo) < MAX_TEMPO: 187 | ok = True 188 | break 189 | if not ok: 190 | print(delta_tempo, spec_ref.tempo, spec.tempo, "FAILED TEMPO") 191 | # too much of a tempo difference 192 | continue 193 | spec = spec._replace(tempo=tempo) 194 | 195 | ps = 0 196 | if pitch_match: 197 | ps = best_pitch_shift(spec_ref.kr, spec.kr) 198 | if abs(ps) > MAX_PITCH: 199 | print("Failed pitch", ps) 200 | # too much pitch difference 201 | continue 202 | return spec, delta_tempo, ps 203 | 204 | 205 | def get_part(spec, source, dt, dp): 206 | """Apply given delta of tempo and delta of pitch to a stem.""" 207 | wav = spec.track[source] 208 | if dt or dp: 209 | wav = repitch(wav, dp, dt * 100, samplerate=SR, voice=source == 3) 210 | spec = spec._replace(onsets=spec.onsets / (1 + dt)) 211 | return wav, spec 212 | 213 | 214 | def build_track(ref_index, catalog): 215 | """Given the reference track index and a catalog of track, builds 216 | a completely new track. One of the source at random from the ref track will 217 | be kept and other sources will be drawn from the catalog. 
218 | """ 219 | order = list(range(len(SOURCES))) 220 | random.shuffle(order) 221 | 222 | stems = [None] * len(order) 223 | indexes = [None] * len(order) 224 | origs = [None] * len(order) 225 | dps = [None] * len(order) 226 | dts = [None] * len(order) 227 | 228 | first = order[0] 229 | spec_ref = catalog[ref_index] 230 | stems[first] = (spec_ref.track[first], spec_ref.onsets) 231 | indexes[first] = ref_index 232 | origs[first] = spec_ref.track[first] 233 | dps[first] = 0 234 | dts[first] = 0 235 | 236 | pitch_match = order != 0 237 | 238 | for src in order[1:]: 239 | spec, dt, dp = find_candidate(spec_ref, catalog, pitch_match=pitch_match) 240 | if not pitch_match: 241 | spec_ref = spec_ref._replace(kr=spec.kr) 242 | pitch_match = True 243 | dps[src] = dp 244 | dts[src] = dt 245 | wav, spec = get_part(spec, src, dt, dp) 246 | stems[src] = (wav, spec.onsets) 247 | indexes[src] = spec.index 248 | origs.append(spec.track[src]) 249 | print("FINAL CHOICES", ref_index, indexes, dps, dts) 250 | stems = align_stems(stems) 251 | return stems, origs 252 | 253 | 254 | def get_musdb_dataset(part='train'): 255 | root = Path(MUSDB_PATH) / part 256 | ext = '.wav' 257 | metadata = build_metadata(root, SOURCES, ext=ext, normalize=False) 258 | valid_tracks = _get_musdb_valid() 259 | metadata_train = {name: meta for name, meta in metadata.items() if name not in valid_tracks} 260 | train_set = Wavset( 261 | root, metadata_train, SOURCES, samplerate=SR, channels=CHANNELS, 262 | normalize=False, ext=ext) 263 | sig = hashlib.sha1(str(root).encode()).hexdigest()[:8] 264 | train_set.sig = sig 265 | return train_set 266 | 267 | 268 | def get_wav_dataset(): 269 | root = Path(EXTRA_WAV_PATH) 270 | ext = '.wav' 271 | metadata = _build_metadata(root, SOURCES, ext=ext, normalize=False) 272 | train_set = Wavset( 273 | root, metadata, SOURCES, samplerate=SR, channels=CHANNELS, 274 | normalize=False, ext=ext) 275 | sig = hashlib.sha1(str(root).encode()).hexdigest()[:8] 276 | train_set.sig = sig 277 
| return train_set 278 | 279 | 280 | def main(): 281 | random.seed(4321) 282 | if OUTPATH.exists(): 283 | shutil.rmtree(OUTPATH) 284 | OUTPATH.mkdir(exist_ok=True, parents=True) 285 | (OUTPATH / 'train').mkdir(exist_ok=True, parents=True) 286 | (OUTPATH / 'valid').mkdir(exist_ok=True, parents=True) 287 | out = OUTPATH / 'train' 288 | 289 | dset = get_musdb_dataset() 290 | # dset2 = get_wav_dataset() 291 | # dset3 = get_musdb_dataset('test') 292 | dset2 = None 293 | dset3 = None 294 | pendings = [] 295 | copies = 6 296 | copies_rej = 2 297 | 298 | with ProcessPoolExecutor(20) as pool: 299 | for index in range(len(dset)): 300 | pendings.append(pool.submit(analyse_track, dset, index)) 301 | 302 | if dset2: 303 | for index in range(len(dset2)): 304 | pendings.append(pool.submit(analyse_track, dset2, index)) 305 | if dset3: 306 | for index in range(len(dset3)): 307 | pendings.append(pool.submit(analyse_track, dset3, index)) 308 | 309 | catalog = [] 310 | rej = 0 311 | for pending in tqdm.tqdm(pendings, ncols=120): 312 | spec, track = pending.result() 313 | if spec is not None: 314 | catalog.append(spec) 315 | else: 316 | mix = track.sum(0) 317 | for copy in range(copies_rej): 318 | folder = out / f'rej_{rej}_{copy}' 319 | folder.mkdir() 320 | save_audio(mix, folder / "mixture.wav", SR) 321 | for stem, source in zip(track, SOURCES): 322 | save_audio(stem, folder / f"{source}.wav", SR, clip='clamp') 323 | rej += 1 324 | 325 | for copy in range(copies): 326 | for index in range(len(catalog)): 327 | track, origs = build_track(index, catalog) 328 | mix = track.sum(0) 329 | mx = mix.abs().max() 330 | scale = max(1, 1.01 * mx) 331 | mix = mix / scale 332 | track = track / scale 333 | folder = out / f'{copy}_{index}' 334 | folder.mkdir() 335 | save_audio(mix, folder / "mixture.wav", SR) 336 | for stem, source, orig in zip(track, SOURCES, origs): 337 | save_audio(stem, folder / f"{source}.wav", SR, clip='clamp') 338 | # save_audio(stem.std() * orig / (1e-6 + orig.std()), folder 
/ f"{source}_orig.wav", 339 | # SR, clip='clamp') 340 | 341 | 342 | if __name__ == '__main__': 343 | main() 344 | -------------------------------------------------------------------------------- /tools/bench.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | benchmarking script, useful to check for OOM, reasonable train time, 9 | and for the MDX competion, estimate if we will match the time limit.""" 10 | from contextlib import contextmanager 11 | import logging 12 | import sys 13 | import time 14 | import torch 15 | 16 | from demucs.train import get_solver, main 17 | from demucs.apply import apply_model 18 | 19 | logging.basicConfig(level=logging.INFO, stream=sys.stderr) 20 | 21 | 22 | class Result: 23 | pass 24 | 25 | 26 | @contextmanager 27 | def bench(): 28 | import gc 29 | gc.collect() 30 | torch.cuda.reset_max_memory_allocated() 31 | torch.cuda.empty_cache() 32 | result = Result() 33 | # before = torch.cuda.memory_allocated() 34 | before = 0 35 | begin = time.time() 36 | try: 37 | yield result 38 | finally: 39 | torch.cuda.synchronize() 40 | mem = (torch.cuda.max_memory_allocated() - before) / 2 ** 20 41 | tim = time.time() - begin 42 | result.mem = mem 43 | result.tim = tim 44 | 45 | 46 | xp = main.get_xp_from_sig(sys.argv[1]) 47 | xp = main.get_xp(xp.argv + sys.argv[2:]) 48 | with xp.enter(): 49 | solver = get_solver(xp.cfg) 50 | if getattr(solver.model, 'use_train_segment', False): 51 | batch = solver.augment(next(iter(solver.loaders['train']))) 52 | solver.model.segment = Fraction(batch.shape[-1], solver.model.samplerate) 53 | train_segment = solver.model.segment 54 | solver.model.eval() 55 | model = solver.model 56 | model.cuda() 57 | x = torch.randn(2, xp.cfg.dset.channels, int(10 * 
model.samplerate), device='cuda') 58 | with bench() as res: 59 | y = model(x) 60 | y.sum().backward() 61 | del y 62 | for p in model.parameters(): 63 | p.grad = None 64 | print(f"FB: {res.mem:.1f} MB, {res.tim * 1000:.1f} ms") 65 | 66 | x = torch.randn(1, xp.cfg.dset.channels, int(model.segment * model.samplerate), device='cuda') 67 | with bench() as res: 68 | with torch.no_grad(): 69 | y = model(x) 70 | del y 71 | print(f"FV: {res.mem:.1f} MB, {res.tim * 1000:.1f} ms") 72 | 73 | model.cpu() 74 | torch.set_num_threads(1) 75 | test = torch.randn(1, xp.cfg.dset.channels, model.samplerate * 40) 76 | b = time.time() 77 | apply_model(model, test, split=True, shifts=1) 78 | print("CPU 40 sec:", time.time() - b) 79 | -------------------------------------------------------------------------------- /tools/convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Script to convert option names and model args from the dev branch to 8 | # the cleanup release one. There should be no reaso to use that anymore. 
9 | 10 | import argparse 11 | import io 12 | import json 13 | from pathlib import Path 14 | import subprocess as sp 15 | 16 | import torch 17 | 18 | from demucs import train, pretrained, states 19 | 20 | DEV_REPO = Path.home() / 'tmp/release_demucs_mdx' 21 | 22 | 23 | TO_REMOVE = [ 24 | 'demucs.dconv_kw.gelu=True', 25 | 'demucs.dconv_kw.nfreqs=0', 26 | 'demucs.dconv_kw.nfreqs=0', 27 | 'demucs.dconv_kw.version=4', 28 | 'demucs.norm=gn', 29 | 'wdemucs.nice=True', 30 | 'wdemucs.good=True', 31 | 'wdemucs.freq_emb=-0.2', 32 | 'special=True', 33 | 'special=False', 34 | ] 35 | 36 | TO_REPLACE = [ 37 | ('power', 'svd'), 38 | ('wdemucs', 'hdemucs'), 39 | ('hdemucs.hybrid=True', 'hdemucs.hybrid_old=True'), 40 | ('hdemucs.hybrid=2', 'hdemucs.hybrid=True'), 41 | ] 42 | 43 | TO_INJECT = [ 44 | ('model=hdemucs', ['hdemucs.cac=False']), 45 | ('model=hdemucs', ['hdemucs.norm_starts=999']), 46 | ] 47 | 48 | 49 | def get_original_argv(sig): 50 | return json.load(open(Path(DEV_REPO) / f'outputs/xps/{sig}/.argv.json')) 51 | 52 | 53 | def transform(argv, mappings, verbose=False): 54 | for rm in TO_REMOVE: 55 | while rm in argv: 56 | argv.remove(rm) 57 | 58 | for old, new in TO_REPLACE: 59 | argv[:] = [a.replace(old, new) for a in argv] 60 | 61 | for condition, args in TO_INJECT: 62 | if condition in argv: 63 | argv[:] = args + argv 64 | 65 | for idx, arg in enumerate(argv): 66 | if 'continue_from=' in arg: 67 | dep_sig = arg.split('=')[1] 68 | if dep_sig.startswith('"'): 69 | dep_sig = eval(dep_sig) 70 | if verbose: 71 | print("Need to recursively convert dependency XP", dep_sig) 72 | new_sig = convert(dep_sig, mappings, verbose).sig 73 | argv[idx] = f'continue_from="{new_sig}"' 74 | 75 | 76 | def convert(sig, mappings, verbose=False): 77 | argv = get_original_argv(sig) 78 | if verbose: 79 | print("Original argv", argv) 80 | transform(argv, mappings, verbose) 81 | if verbose: 82 | print("New argv", argv) 83 | xp = train.main.get_xp(argv) 84 | train.main.init_xp(xp) 85 | if verbose: 86 
| print("Mapping", sig, "->", xp.sig) 87 | mappings[sig] = xp.sig 88 | return xp 89 | 90 | 91 | def _eval_old(old_sig, x): 92 | script = ( 93 | 'from demucs import pretrained; import torch; import sys; import io; ' 94 | 'buf = io.BytesIO(sys.stdin.buffer.read()); ' 95 | 'x = torch.load(buf); m = pretrained.load_pretrained_model(' 96 | f'"{old_sig}"); torch.save(m(x), sys.stdout.buffer)') 97 | 98 | buf = io.BytesIO() 99 | torch.save(x, buf) 100 | proc = sp.run( 101 | ['python3', '-c', script], input=buf.getvalue(), capture_output=True, cwd=DEV_REPO) 102 | if proc.returncode != 0: 103 | print("Error", proc.stderr.decode()) 104 | assert False 105 | 106 | buf = io.BytesIO(proc.stdout) 107 | return torch.load(buf) 108 | 109 | 110 | def compare(old_sig, model): 111 | test = torch.randn(1, 2, 44100 * 10) 112 | old_out = _eval_old(old_sig, test) 113 | out = model(test) 114 | 115 | delta = 20 * torch.log10((out - old_out).norm() / out.norm()).item() 116 | return delta 117 | 118 | 119 | def main(): 120 | torch.manual_seed(1234) 121 | parser = argparse.ArgumentParser('convert') 122 | parser.add_argument('sigs', nargs='*') 123 | parser.add_argument('-o', '--output', type=Path, default=Path('release_models')) 124 | parser.add_argument('-d', '--dump', action='store_true') 125 | parser.add_argument('-c', '--compare', action='store_true') 126 | parser.add_argument('-v', '--verbose', action='store_true') 127 | args = parser.parse_args() 128 | 129 | args.output.mkdir(exist_ok=True, parents=True) 130 | mappings = {} 131 | for sig in args.sigs: 132 | xp = convert(sig, mappings, args.verbose) 133 | if args.dump or args.compare: 134 | old_pkg = pretrained._load_package(sig, old=True) 135 | model = train.get_model(xp.cfg) 136 | model.load_state_dict(old_pkg['state']) 137 | if args.dump: 138 | pkg = states.serialize_model(model, xp.cfg) 139 | states.save_with_checksum(pkg, args.output / f'{xp.sig}.th') 140 | if args.compare: 141 | delta = compare(sig, model) 142 | print("Delta for", sig, 
xp.sig, delta) 143 | 144 | mappings[sig] = xp.sig 145 | 146 | print("FINAL MAPPINGS") 147 | for old, new in mappings.items(): 148 | print(old, " ", new) 149 | 150 | 151 | if __name__ == '__main__': 152 | main() 153 | -------------------------------------------------------------------------------- /tools/export.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Export a trained model from the full checkpoint (with optimizer etc.) to 8 | a final checkpoint, with only the model itself. The model is always stored as 9 | half float to gain space, and because this has zero impact on the final loss. 10 | When DiffQ was used for training, the model will actually be quantized and bitpacked.""" 11 | from argparse import ArgumentParser 12 | from fractions import Fraction 13 | import logging 14 | from pathlib import Path 15 | import sys 16 | import torch 17 | 18 | from demucs import train 19 | from demucs.states import serialize_model, save_with_checksum 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def main(): 26 | logging.basicConfig(level=logging.INFO, stream=sys.stderr) 27 | 28 | parser = ArgumentParser("tools.export", description="Export trained models from XP sigs.") 29 | parser.add_argument('signatures', nargs='*', help='XP signatures.') 30 | parser.add_argument('-o', '--out', type=Path, default=Path("release_models"), 31 | help="Path where to store release models (default release_models)") 32 | parser.add_argument('-s', '--sign', action='store_true', 33 | help='Add sha256 prefix checksum to the filename.') 34 | 35 | args = parser.parse_args() 36 | args.out.mkdir(exist_ok=True, parents=True) 37 | 38 | for sig in args.signatures: 39 | xp = train.main.get_xp_from_sig(sig) 40 | name = 
train.main.get_name(xp) 41 | logger.info('Handling %s/%s', sig, name) 42 | 43 | out_path = args.out / (sig + ".th") 44 | 45 | solver = train.get_solver_from_sig(sig) 46 | if len(solver.history) < solver.args.epochs: 47 | logger.warning( 48 | 'Model %s has less epoch than expected (%d / %d)', 49 | sig, len(solver.history), solver.args.epochs) 50 | 51 | solver.model.load_state_dict(solver.best_state) 52 | pkg = serialize_model(solver.model, solver.args, solver.quantizer, half=True) 53 | if getattr(solver.model, 'use_train_segment', False): 54 | batch = solver.augment(next(iter(solver.loaders['train']))) 55 | pkg['kwargs']['segment'] = Fraction(batch.shape[-1], solver.model.samplerate) 56 | print("Override", pkg['kwargs']['segment']) 57 | valid, test = None, None 58 | for m in solver.history: 59 | if 'valid' in m: 60 | valid = m['valid'] 61 | if 'test' in m: 62 | test = m['test'] 63 | pkg['metrics'] = (valid, test) 64 | if args.sign: 65 | save_with_checksum(pkg, out_path) 66 | else: 67 | torch.save(pkg, out_path) 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /tools/test_pretrained.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Script to evaluate pretrained models. 
8 | 9 | from argparse import ArgumentParser 10 | import logging 11 | import sys 12 | 13 | import torch 14 | 15 | from demucs import train, pretrained, evaluate 16 | 17 | 18 | def main(): 19 | torch.set_num_threads(1) 20 | logging.basicConfig(stream=sys.stderr, level=logging.INFO) 21 | parser = ArgumentParser("tools.test_pretrained", 22 | description="Evaluate pre-trained models or bags of models " 23 | "on MusDB.") 24 | pretrained.add_model_flags(parser) 25 | parser.add_argument('overrides', nargs='*', 26 | help='Extra overrides, e.g. test.shifts=2.') 27 | args = parser.parse_args() 28 | 29 | xp = train.main.get_xp(args.overrides) 30 | with xp.enter(): 31 | solver = train.get_solver(xp.cfg) 32 | 33 | model = pretrained.get_model_from_args(args) 34 | solver.model = model.to(solver.device) 35 | solver.model.eval() 36 | 37 | with torch.no_grad(): 38 | results = evaluate.evaluate(solver, xp.cfg.test.sdr) 39 | print(results) 40 | 41 | 42 | if __name__ == '__main__': 43 | main() 44 | --------------------------------------------------------------------------------