├── audiocraft
├── py.typed
├── utils
│ ├── samples
│ │ └── __init__.py
│ ├── __init__.py
│ ├── notebook.py
│ ├── profiler.py
│ ├── autocast.py
│ ├── deadlock.py
│ ├── export_legacy.py
│ ├── cluster.py
│ ├── export.py
│ └── best_state.py
├── grids
│ ├── __init__.py
│ ├── audiogen
│ │ ├── __init__.py
│ │ ├── audiogen_base_16khz.py
│ │ └── audiogen_pretrained_16khz_eval.py
│ ├── compression
│ │ ├── __init__.py
│ │ ├── encodec_base_24khz.py
│ │ ├── encodec_audiogen_16khz.py
│ │ ├── debug.py
│ │ ├── encodec_musicgen_32khz.py
│ │ └── _explorers.py
│ ├── diffusion
│ │ ├── __init__.py
│ │ ├── 4_bands_base_32khz.py
│ │ └── _explorers.py
│ ├── musicgen
│ │ ├── __init__.py
│ │ ├── musicgen_clapemb_32khz.py
│ │ ├── musicgen_base_32khz.py
│ │ ├── musicgen_melody_32khz.py
│ │ ├── musicgen_base_cached_32khz.py
│ │ ├── _explorers.py
│ │ └── musicgen_pretrained_32khz_eval.py
│ └── _base_explorers.py
├── quantization
│ ├── __init__.py
│ └── base.py
├── adversarial
│ ├── discriminators
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── mpd.py
│ └── __init__.py
├── data
│ ├── __init__.py
│ ├── zip.py
│ └── info_audio_dataset.py
├── metrics
│ ├── __init__.py
│ ├── chroma_cosinesim.py
│ └── clap_consistency.py
├── solvers
│ ├── __init__.py
│ └── audiogen.py
├── losses
│ ├── __init__.py
│ └── sisnr.py
├── modules
│ ├── __init__.py
│ ├── lstm.py
│ ├── chroma.py
│ └── activations.py
├── models
│ └── __init__.py
├── optim
│ ├── __init__.py
│ ├── linear_warmup_lr_scheduler.py
│ ├── inverse_sqrt_lr_scheduler.py
│ ├── cosine_lr_scheduler.py
│ ├── polynomial_decay_lr_scheduler.py
│ └── ema.py
└── __init__.py
├── assets
├── bach.mp3
├── bolero_ravel.mp3
├── sirens_and_a_humming_engine_approach_and_pass.mp3
└── a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3
├── dataset
└── example
│ ├── electro_1.mp3
│ ├── electro_2.mp3
│ ├── electro_2.json
│ └── electro_1.json
├── config
├── solver
│ ├── audiogen
│ │ ├── evaluation
│ │ │ ├── none.yaml
│ │ │ └── objective_eval.yaml
│ │ ├── default.yaml
│ │ ├── debug.yaml
│ │ └── audiogen_base_16khz.yaml
│ ├── musicgen
│ │ ├── evaluation
│ │ │ ├── none.yaml
│ │ │ └── objective_eval.yaml
│ │ ├── debug.yaml
│ │ ├── musicgen_base_32khz.yaml
│ │ ├── musicgen_melody_32khz.yaml
│ │ └── default.yaml
│ ├── compression
│ │ ├── encodec_base_24khz.yaml
│ │ ├── encodec_audiogen_16khz.yaml
│ │ ├── encodec_musicgen_32khz.yaml
│ │ ├── debug.yaml
│ │ └── default.yaml
│ ├── diffusion
│ │ ├── encodec_24khz.yaml
│ │ ├── debug.yaml
│ │ └── default.yaml
│ └── default.yaml
├── model
│ ├── lm
│ │ ├── model_scale
│ │ │ ├── base.yaml
│ │ │ ├── small.yaml
│ │ │ ├── medium.yaml
│ │ │ ├── large.yaml
│ │ │ └── xsmall.yaml
│ │ ├── audiogen_lm.yaml
│ │ ├── musicgen_lm.yaml
│ │ └── default.yaml
│ ├── none.yaml
│ ├── encodec
│ │ ├── encodec_base_causal.yaml
│ │ ├── encodec_large_nq4_s640.yaml
│ │ ├── encodec_large_nq4_s320.yaml
│ │ └── default.yaml
│ └── score
│ │ └── basic.yaml
├── dset
│ ├── audio
│ │ ├── default.yaml
│ │ ├── example.yaml
│ │ ├── audiocaps_16khz.yaml
│ │ └── musiccaps_32khz.yaml
│ ├── internal
│ │ ├── music_400k_32khz.yaml
│ │ ├── music_10k_32khz.yaml
│ │ └── sounds_16khz.yaml
│ └── default.yaml
├── conditioner
│ ├── none.yaml
│ ├── text2sound.yaml
│ ├── text2music.yaml
│ ├── chroma2music.yaml
│ └── clapemb2music.yaml
├── teams
│ ├── default.yaml
│ └── labs.yaml
└── config.yaml
├── mypy.ini
├── scripts
├── __init__.py
├── templates
│ ├── base.html
│ ├── login.html
│ ├── results.html
│ └── index.html
└── static
│ └── style.css
├── tests
├── __init__.py
├── data
│ ├── __init__.py
│ └── test_audio_utils.py
├── losses
│ ├── __init__.py
│ └── test_losses.py
├── modules
│ ├── __init__.py
│ ├── test_lstm.py
│ └── test_activations.py
├── utils
│ └── __init__.py
├── adversarial
│ ├── __init__.py
│ └── test_discriminators.py
├── common_utils
│ ├── __init__.py
│ ├── wav_utils.py
│ └── temp_utils.py
├── quantization
│ └── test_vq.py
└── models
│ ├── test_audiogen.py
│ ├── test_musicgen.py
│ ├── test_multibanddiffusion.py
│ └── test_encodec_model.py
├── egs
└── example
│ └── data.jsonl
├── setup.cfg
├── .github
├── workflows
│ ├── audiocraft_linter.yml
│ ├── audiocraft_tests.yml
│ └── audiocraft_docs.yml
└── actions
│ └── audiocraft_build
│ └── action.yml
├── MANIFEST.in
├── requirements.txt
├── .gitignore
├── LICENSE
├── CHANGELOG.md
├── CONTRIBUTING.md
├── Makefile
├── setup.py
├── docs
└── DATASETS.md
├── CODE_OF_CONDUCT.md
└── README.md
/audiocraft/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/bach.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/audiocraft/main/assets/bach.mp3
--------------------------------------------------------------------------------
/assets/bolero_ravel.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/audiocraft/main/assets/bolero_ravel.mp3
--------------------------------------------------------------------------------
/dataset/example/electro_1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/audiocraft/main/dataset/example/electro_1.mp3
--------------------------------------------------------------------------------
/dataset/example/electro_2.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/audiocraft/main/dataset/example/electro_2.mp3
--------------------------------------------------------------------------------
/config/solver/audiogen/evaluation/none.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | dataset:
4 | evaluate:
5 | num_samples: 10000
6 |
--------------------------------------------------------------------------------
/config/solver/musicgen/evaluation/none.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | dataset:
4 | evaluate:
5 | num_samples: 10000
6 |
--------------------------------------------------------------------------------
/config/model/lm/model_scale/base.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # overrides nothing because default is already transformer base (~ 60M params)
4 |
--------------------------------------------------------------------------------
/config/model/lm/model_scale/small.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # 300M Param.
4 |
5 | transformer_lm:
6 | dim: 1024
7 | num_heads: 16
8 | num_layers: 24
9 |
--------------------------------------------------------------------------------
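
A quick sanity check on the parameter counts quoted in these model_scale comments: in a decoder-only transformer, each layer contributes roughly 4*d^2 parameters for self-attention and 8*d^2 for the 4x feed-forward, i.e. about 12*L*d^2 overall, ignoring embeddings and cross-attention. The sketch below is only this rule of thumb applied to the `small` scale above, not the exact MusicGen count.

def approx_transformer_params(dim: int, num_layers: int) -> int:
    # ~12 * L * d^2: self-attention (4*d^2) + 4x feed-forward (8*d^2) per layer,
    # embeddings and cross-attention ignored.
    return 12 * num_layers * dim * dim

print(f"{approx_transformer_params(1024, 24) / 1e6:.0f}M")  # ~302M, consistent with the "300M" comment
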
/assets/sirens_and_a_humming_engine_approach_and_pass.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/audiocraft/main/assets/sirens_and_a_humming_engine_approach_and_pass.mp3
--------------------------------------------------------------------------------
/config/model/lm/model_scale/medium.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # gpt2 like (~1.5B params)
4 | transformer_lm:
5 | dim: 1536
6 | num_heads: 24
7 | num_layers: 48
8 |
--------------------------------------------------------------------------------
/assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vaibhavs10/audiocraft/main/assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3
--------------------------------------------------------------------------------
/config/model/lm/model_scale/large.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # gpt2 inspired, even bigger (~3.3B params)
4 | transformer_lm:
5 | dim: 2048
6 | num_heads: 32
7 | num_layers: 48
8 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 |
3 | [mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub,transformers,dac.*]
4 | ignore_missing_imports = True
5 |
--------------------------------------------------------------------------------
/config/model/none.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # This file exists so that model is recognized as a config group
4 | # by Hydra and Dora. A bit weird, we might need a better fix someday.
5 |
--------------------------------------------------------------------------------
/config/model/encodec/encodec_base_causal.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - encodec/default
5 |
6 | encodec:
7 | causal: true
8 |
9 | rvq:
10 | n_q: 32
11 | q_dropout: true
12 |
--------------------------------------------------------------------------------
/config/dset/audio/default.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | datasource:
4 | max_sample_rate: ???
5 | max_channels: ???
6 |
7 | train: ???
8 | valid: ???
9 | evaluate: ???
10 | generate: null
11 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/adversarial/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/config/dset/audio/example.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | datasource:
4 | max_sample_rate: 44100
5 | max_channels: 2
6 |
7 | train: egs/example
8 | valid: egs/example
9 | evaluate: egs/example
10 | generate: egs/example
11 |
--------------------------------------------------------------------------------
/config/model/lm/model_scale/xsmall.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | # only used for debugging, or when we just want to populate the cache
3 | # and do not care about training.
4 |
5 | transformer_lm:
6 | dim: 64
7 | num_heads: 2
8 | num_layers: 2
9 |
--------------------------------------------------------------------------------
/audiocraft/utils/samples/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/config/model/encodec/encodec_large_nq4_s640.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - encodec/default
5 |
6 | seanet:
7 | ratios: [8, 5, 4, 4]
8 | n_filters: 64
9 |
10 | rvq:
11 | bins: 2048
12 | n_q: 4
13 | q_dropout: false
14 |
--------------------------------------------------------------------------------
/audiocraft/grids/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Dora Grids."""
7 |
--------------------------------------------------------------------------------
/audiocraft/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Utilities."""
7 |
--------------------------------------------------------------------------------
/config/model/encodec/encodec_large_nq4_s320.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - encodec/default
5 |
6 | seanet:
7 | # default ratios are [8, 5, 4, 2]
8 | n_filters: 64
9 |
10 | rvq:
11 | bins: 2048
12 | n_q: 4
13 | q_dropout: false
14 |
--------------------------------------------------------------------------------
/config/solver/compression/encodec_base_24khz.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - compression/default
5 | - /model: encodec/encodec_base_causal
6 | - override /dset: audio/default
7 | - _self_
8 |
9 | channels: 1
10 | sample_rate: 24000
11 |
--------------------------------------------------------------------------------
/audiocraft/grids/audiogen/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """AudioGen grids."""
7 |
--------------------------------------------------------------------------------
/audiocraft/grids/compression/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """EnCodec grids."""
7 |
--------------------------------------------------------------------------------
/audiocraft/grids/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Diffusion grids."""
7 |
--------------------------------------------------------------------------------
/audiocraft/grids/musicgen/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """MusicGen grids."""
7 |
--------------------------------------------------------------------------------
/config/solver/compression/encodec_audiogen_16khz.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - compression/default
5 | - /model: encodec/encodec_large_nq4_s320
6 | - override /dset: audio/default
7 | - _self_
8 |
9 | channels: 1
10 | sample_rate: 16000
11 |
--------------------------------------------------------------------------------
/config/solver/compression/encodec_musicgen_32khz.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - compression/default
5 | - /model: encodec/encodec_large_nq4_s640
6 | - override /dset: audio/default
7 | - _self_
8 |
9 | channels: 1
10 | sample_rate: 32000
11 |
--------------------------------------------------------------------------------
/config/solver/diffusion/encodec_24khz.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - diffusion/default
5 | - _self_
6 |
7 |
8 | sample_rate: 24000
9 | channels: 1
10 | compression_model_checkpoint: //pretrained/facebook/encodec_24khz
11 | n_q: 4 # num quantizers, 3kbps
12 |
--------------------------------------------------------------------------------
/egs/example/data.jsonl:
--------------------------------------------------------------------------------
1 | {"path": "dataset/example/electro_1.mp3", "duration": 15.024, "sample_rate": 48000, "amplitude": null, "weight": null, "info_path": null}
2 | {"path": "dataset/example/electro_2.mp3", "duration": 20.035918367346937, "sample_rate": 44100, "amplitude": null, "weight": null, "info_path": null}
3 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [pep8]
2 | max-line-length = 120
3 |
4 | [flake8]
5 | max-line-length = 120
6 |
7 | [coverage:report]
8 | include = audiocraft/*
9 | omit =
10 | audiocraft/environment.py
11 | audiocraft/solvers/*
12 | audiocraft/utils/*
13 | audiocraft/*/loaders.py
14 | audiocraft/*/builders.py
15 |
--------------------------------------------------------------------------------
/dataset/example/electro_2.json:
--------------------------------------------------------------------------------
1 | {"key": "", "artist": "Voyager I", "sample_rate": 44100, "file_extension": "mp3", "description": "This is an electronic song sending positive vibes.", "keywords": "", "duration": 20.0, "bpm": "", "genre": "electronic", "title": "Untitled song", "name": "electro_2", "instrument": "Mix", "moods": []}
2 |
--------------------------------------------------------------------------------
/dataset/example/electro_1.json:
--------------------------------------------------------------------------------
1 | {"key": "", "artist": "Voyager I", "sample_rate": 48000, "file_extension": "mp3", "description": "A cool song from Voyager.", "keywords": "bright, pulsing, cool", "duration": 15.0, "bpm": "", "genre": "electronic", "title": "Enracinement", "name": "electro_1", "instrument": "Mix", "moods": ["uplifting", "motivational"]}
2 |
--------------------------------------------------------------------------------
/config/conditioner/none.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # No conditioning
4 |
5 | classifier_free_guidance:
6 | training_dropout: 0
7 | inference_coef: 1
8 |
9 | attribute_dropout:
10 | text: {}
11 | wav: {}
12 |
13 | fuser:
14 | sum: []
15 | prepend: []
16 | cross: []
17 | input_interpolate: []
18 |
19 | conditioners: null
20 |
--------------------------------------------------------------------------------
/config/dset/internal/music_400k_32khz.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | datasource:
4 | max_sample_rate: 32000
5 | max_channels: 1
6 |
7 | train: egs/music/music_400k_32khz/train
8 | valid: egs/music/music_400k_32khz/valid
9 | evaluate: egs/music/music_400k_32khz/test
10 | generate: egs/music/music_400k_32khz/test # identical to evaluate
11 |
--------------------------------------------------------------------------------
/config/dset/audio/audiocaps_16khz.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # AudioCaps dataset
4 | datasource:
5 | max_sample_rate: 16000
6 | max_channels: 1
7 |
8 | train: null # only evaluation set
9 | valid: null # only evaluation set
10 | evaluate: egs/audiocaps/audiocaps_16khz
11 | generate: egs/audiocaps/audiocaps_16khz # identical to evaluate
12 |
--------------------------------------------------------------------------------
/config/dset/default.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # WARNING: This is a base configuration file shared across ALL solvers in AudioCraft
4 | # Please don't update this file directly. Instead use distinct configuration files
5 | # to override the below configuration.
6 | datasource:
7 | train: ???
8 | valid: ???
9 | evaluate: ???
10 | generate: ???
11 |
--------------------------------------------------------------------------------
/config/model/score/basic.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | diffusion_unet:
4 | hidden: 48
5 | depth: 4
6 | res_blocks: 1
7 | norm_groups: 4
8 | kernel: 8
9 | stride: 4
10 | growth: 4
11 | max_channels: 10_000
12 | dropout: 0.
13 | emb_all_layers: true
14 | bilstm: false
15 | codec_dim: null
16 | transformer: false
17 | cross_attention: false
--------------------------------------------------------------------------------
/tests/common_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa
8 | from .temp_utils import TempDirMixin
9 | from .wav_utils import get_batch_white_noise, get_white_noise, save_wav
10 |
--------------------------------------------------------------------------------
/audiocraft/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """RVQ."""
7 | # flake8: noqa
8 | from .vq import ResidualVectorQuantizer
9 | from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
10 |
--------------------------------------------------------------------------------
/config/teams/default.yaml:
--------------------------------------------------------------------------------
1 | default:
2 | dora_dir: /tmp/audiocraft_${oc.env:USER}
3 | partitions:
4 | global: debug
5 | team: debug
6 | reference_dir: /tmp
7 | darwin: # if we detect we are on a Mac, then most likely we are doing unit testing etc.
8 | dora_dir: /tmp/audiocraft_${oc.env:USER}
9 | partitions:
10 | global: debug
11 | team: debug
12 | reference_dir: /tmp
13 |
--------------------------------------------------------------------------------
/config/dset/internal/music_10k_32khz.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # high quality music dataset with no artist overlap between splits
4 | datasource:
5 | max_sample_rate: 32000
6 | max_channels: 1
7 |
8 | train: egs/music/music_10k_32khz/train
9 | valid: egs/music/music_10k_32khz/valid
10 | evaluate: egs/music/music_10k_32khz/test
11 | generate: egs/music/music_10k_32khz/test # identical to evaluate
12 |
--------------------------------------------------------------------------------
/config/dset/internal/sounds_16khz.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # environmental sounds dataset compiled from all source datasets,
4 | # with filters applied on tags
5 | datasource:
6 | max_sample_rate: 16000
7 | max_channels: 1
8 |
9 | train: egs/sound/sounds_16khz/train
10 | valid: egs/sound/sounds_16khz/valid
11 | evaluate: egs/sound/sounds_16khz/test
12 | generate: egs/sound/sounds_16khz/test # identical to evaluate
13 |
--------------------------------------------------------------------------------
/audiocraft/adversarial/discriminators/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa
8 | from .mpd import MultiPeriodDiscriminator
9 | from .msd import MultiScaleDiscriminator
10 | from .msstftd import MultiScaleSTFTDiscriminator
11 |
--------------------------------------------------------------------------------
/.github/workflows/audiocraft_linter.yml:
--------------------------------------------------------------------------------
1 | name: audiocraft_linter
2 | on:
3 | push:
4 | branches: [ main ]
5 | pull_request:
6 | branches: [ main ]
7 |
8 | jobs:
9 | run_linter:
10 | name: Run linter
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v2
14 | - uses: ./.github/actions/audiocraft_build
15 | - run: |
16 | . env/bin/activate
17 | make linter
18 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include Makefile
2 | include LICENSE
3 | include LICENSE_weights
4 | include *.md
5 | include *.ini
6 | include requirements.txt
7 | include audiocraft/py.typed
8 | include assets/*.mp3
9 | include datasets/*.mp3
10 | recursive-include config *.yaml
11 | recursive-include demos *.py
12 | recursive-include demos *.ipynb
13 | recursive-include scripts *.py
14 | recursive-include model_cards *.md
15 | recursive-include docs *.md
16 |
--------------------------------------------------------------------------------
/config/dset/audio/musiccaps_32khz.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # total samples obtained from MusicCaps = 5469
4 | # (out of 5521 due to AudioSet corrupted samples)
5 | datasource:
6 | max_sample_rate: 32000
7 | max_channels: 2
8 |
9 | train: null # only evaluation set
10 | valid: null # only evaluation set
11 | evaluate: egs/musiccaps/musiccaps_32khz
12 | generate: egs/musiccaps/musiccaps_32khz # identical to evaluate
13 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # please make sure you already have a PyTorch install that is CUDA enabled!
2 | av
3 | einops
4 | flashy>=0.0.1
5 | hydra-core>=1.1
6 | hydra_colorlog
7 | julius
8 | num2words
9 | numpy
10 | sentencepiece
11 | spacy==3.5.2
12 | torch>=2.0.0
13 | torchaudio>=2.0.0
14 | huggingface_hub
15 | tqdm
16 | transformers>=4.31.0 # need Encodec there.
17 | xformers
18 | demucs
19 | librosa
20 | gradio
21 | torchmetrics
22 | encodec
23 | protobuf
--------------------------------------------------------------------------------
/audiocraft/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Audio loading and writing support. Datasets for raw audio
7 | or also including some metadata."""
8 |
9 | # flake8: noqa
10 | from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset
11 |
--------------------------------------------------------------------------------
/scripts/templates/base.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | {% block head %}
5 |
6 |
7 | AudioCraft — MOS
8 | {% endblock %}
9 |
10 |
11 |
12 | AudioCraft — MOS
13 | {% block content %}{% endblock %}
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/config/conditioner/text2sound.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | classifier_free_guidance:
4 | training_dropout: 0.1
5 | inference_coef: 3.0
6 |
7 | attribute_dropout: {}
8 |
9 | fuser:
10 | cross_attention_pos_emb: false
11 | cross_attention_pos_emb_scale: 1
12 | sum: []
13 | prepend: []
14 | cross: [description]
15 | input_interpolate: []
16 |
17 | conditioners:
18 | description:
19 | model: t5
20 | t5:
21 | name: t5-large
22 | finetune: false
23 | word_dropout: 0.
24 | normalize_text: false
25 |
--------------------------------------------------------------------------------
/scripts/templates/login.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 | {% block content %}
3 |
4 |
5 | You must identify yourself first! We use a highly secured protocol
6 | where you just decide your username, and that's it. No password, no encryption,
7 | just pure trust.
8 |
9 |
10 | {% if error %}
11 | {{error}}
12 | {% endif %}
13 |
27 |
28 | {% endblock %}
29 |
--------------------------------------------------------------------------------
/config/solver/audiogen/default.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - /solver/musicgen/default
5 | - _self_
6 | - /solver/audiogen/evaluation: none
7 | - override /dset: audio/default
8 |
9 | # See config/solver/musicgen/default.yaml for a list of possible values.
10 | # We only keep the most important here.
11 |
12 | autocast: true
13 | autocast_dtype: float16
14 |
15 | solver: audiogen
16 | sample_rate: ???
17 | channels: ???
18 | compression_model_checkpoint: ???
19 |
20 | tokens:
21 | padding_with_special_token: false
22 |
23 | dataset:
24 | batch_size: 128
25 | segment_duration: 10
26 | min_segment_ratio: 1.0 # lower values such as 0.5 result in generations with a lot of silence.
27 |
28 | optim:
29 | epochs: 100
30 | updates_per_epoch: 2000
31 | lr: 1e-4
32 | optimizer: adamw
33 | max_norm: 1.0
34 | adam:
35 | betas: [0.9, 0.95]
36 | weight_decay: 0.1
37 | eps: 1e-8
38 |
39 | schedule:
40 | lr_scheduler: null
41 |
--------------------------------------------------------------------------------
/tests/modules/test_activations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch import nn
9 |
10 | from audiocraft.modules.activations import CustomGLU
11 |
12 |
13 | class TestActivations:
14 | def test_custom_glu_calculation(self):
15 |
16 | activation = CustomGLU(nn.Identity())
17 |
18 | initial_shape = (4, 8, 8)
19 |
20 | part_a = torch.ones(initial_shape) * 2
21 | part_b = torch.ones(initial_shape) * -1
22 | input = torch.cat((part_a, part_b), dim=-1)
23 |
24 | output = activation(input)
25 |
26 | # ensure all dimensions match initial shape
27 | assert output.shape == initial_shape
28 | # ensure the gating was calculated correctly a * f(b)
29 | assert torch.all(output == -2).item()
30 |
--------------------------------------------------------------------------------
/tests/common_utils/wav_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from pathlib import Path
8 |
9 | import torch
10 |
11 | from audiocraft.data.audio import audio_write
12 |
13 |
14 | def get_white_noise(chs: int = 1, num_frames: int = 1):
15 | wav = torch.randn(chs, num_frames)
16 | return wav
17 |
18 |
19 | def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):
20 | wav = torch.randn(bs, chs, num_frames)
21 | return wav
22 |
23 |
24 | def save_wav(path: str, wav: torch.Tensor, sample_rate: int):
25 | assert wav.dim() == 2, wav.shape
26 | fp = Path(path)
27 | assert fp.suffix in ['.mp3', '.ogg', '.wav', '.flac'], fp
28 | audio_write(fp.parent / fp.stem, wav, sample_rate, fp.suffix[1:],
29 | normalize=False, strategy='clip', peak_clip_headroom_db=0)
30 |
--------------------------------------------------------------------------------
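
These helpers are used across the test suite to synthesize tiny inputs and write them to disk. A minimal usage sketch, assuming it runs from the repository root (the /tmp output path is just an example):

from tests.common_utils import get_batch_white_noise, get_white_noise, save_wav

wav = get_white_noise(chs=2, num_frames=16_000)                # [C, T]: 1 s of stereo noise at 16 kHz
batch = get_batch_white_noise(bs=4, chs=1, num_frames=16_000)  # [B, C, T]
save_wav("/tmp/white_noise.wav", wav, sample_rate=16_000)      # clips rather than normalizes
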
/.github/actions/audiocraft_build/action.yml:
--------------------------------------------------------------------------------
1 | name: audiocraft_build
2 | description: 'Build audiocraft env.'
3 | runs:
4 | using: "composite"
5 | steps:
6 | - uses: actions/setup-python@v2
7 | with:
8 | python-version: 3.8
9 | - uses: actions/cache@v2
10 | id: cache
11 | with:
12 | path: env
13 | key: audiocraft_env-${{ hashFiles('**/requirements.txt') }}
14 |
15 | - if: ${{ steps.cache.outputs.cache-hit != 'true' }}
16 | name: Install dependencies
17 | shell: bash
18 | run: |
19 | sudo apt-get update
20 | sudo apt-get install libsndfile1-dev ffmpeg
21 | python3 -m venv env
22 | . env/bin/activate
23 | python -m pip install --upgrade pip
24 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
25 | pip install --pre xformers
26 | pip install -e '.[dev]'
27 | - name: System Dependencies
28 | shell: bash
29 | run: |
30 | sudo apt-get update
31 | sudo apt-get install libsndfile1-dev ffmpeg
32 |
--------------------------------------------------------------------------------
/audiocraft/utils/notebook.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | try:
8 | import IPython.display as ipd # type: ignore
9 | except ImportError:
10 | # Not in a notebook...
11 | pass
12 |
13 |
14 | import torch
15 |
16 |
17 | def display_audio(samples: torch.Tensor, sample_rate: int):
18 | """Renders an audio player for the given audio samples.
19 |
20 | Args:
21 | samples (torch.Tensor): a Tensor of decoded audio samples
22 | with shapes [B, C, T] or [C, T]
23 | sample_rate (int): sample rate audio should be displayed with.
24 | """
25 | assert samples.dim() == 2 or samples.dim() == 3
26 |
27 | samples = samples.detach().cpu()
28 | if samples.dim() == 2:
29 | samples = samples[None, ...]
30 |
31 | for audio in samples:
32 | ipd.display(ipd.Audio(audio, rate=sample_rate))
33 |
--------------------------------------------------------------------------------
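
display_audio accepts either a single [C, T] tensor or a batch [B, C, T] and renders one player per item via IPython. A minimal sketch for use inside a notebook; the white noise stands in for real generated audio:

import torch
from audiocraft.utils.notebook import display_audio

samples = 0.1 * torch.randn(2, 1, 2 * 32_000)  # two mono clips of 2 s at 32 kHz, [B, C, T]
display_audio(samples, sample_rate=32_000)     # renders two audio players
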
/config/conditioner/chroma2music.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | classifier_free_guidance:
4 | training_dropout: 0.2
5 | inference_coef: 3.0
6 |
7 | attribute_dropout:
8 | args:
9 | active_on_eval: false
10 | text: {}
11 | wav:
12 | self_wav: 0.5
13 |
14 | fuser:
15 | cross_attention_pos_emb: false
16 | cross_attention_pos_emb_scale: 1
17 | sum: []
18 | prepend: [self_wav, description]
19 | cross: []
20 | input_interpolate: []
21 |
22 | conditioners:
23 | self_wav:
24 | model: chroma_stem
25 | chroma_stem:
26 | sample_rate: ${sample_rate}
27 | n_chroma: 12
28 | radix2_exp: 14
29 | argmax: true
30 | match_len_on_eval: false
31 | eval_wavs: null
32 | n_eval_wavs: 100
33 | cache_path: null
34 | description:
35 | model: t5
36 | t5:
37 | name: t5-base
38 | finetune: false
39 | word_dropout: 0.2
40 | normalize_text: false
41 |
42 | dataset:
43 | train:
44 | merge_text_p: 0.25
45 | drop_desc_p: 0.5
46 | drop_other_p: 0.5
47 |
--------------------------------------------------------------------------------
/audiocraft/adversarial/discriminators/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from abc import ABC, abstractmethod
8 | import typing as tp
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 |
14 | FeatureMapType = tp.List[torch.Tensor]
15 | LogitsType = torch.Tensor
16 | MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
17 |
18 |
19 | class MultiDiscriminator(ABC, nn.Module):
20 | """Base implementation for discriminators composed of sub-discriminators acting at different scales.
21 | """
22 | def __init__(self):
23 | super().__init__()
24 |
25 | @abstractmethod
26 | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
27 | ...
28 |
29 | @property
30 | @abstractmethod
31 | def num_discriminators(self) -> int:
32 | """Number of discriminators.
33 | """
34 | ...
35 |
--------------------------------------------------------------------------------
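
Concrete discriminators such as MultiPeriodDiscriminator subclass MultiDiscriminator and return one logits tensor plus one list of feature maps per sub-discriminator. The toy subclass below only illustrates that output contract; it is a hypothetical example, not the actual MPD/MSD implementation.

import torch
import torch.nn as nn

from audiocraft.adversarial.discriminators.base import (
    MultiDiscriminator, MultiDiscriminatorOutputType)


class ToyMultiDiscriminator(MultiDiscriminator):
    """Two tiny 1D conv discriminators at different strides (illustrative only)."""
    def __init__(self, channels: int = 1):
        super().__init__()
        self.discs = nn.ModuleList(
            nn.Conv1d(channels, 4, kernel_size=15, stride=s) for s in (1, 2))

    def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
        logits, fmaps = [], []
        for disc in self.discs:
            out = disc(x)
            logits.append(out)    # logits of this sub-discriminator
            fmaps.append([out])   # its intermediate feature maps
        return logits, fmaps

    @property
    def num_discriminators(self) -> int:
        return len(self.discs)
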
/config/solver/compression/debug.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - compression/default
5 | - /model: encodec/encodec_base_causal
6 | - override /dset: audio/example
7 | - _self_
8 |
9 | channels: 1
10 | sample_rate: 16000
11 |
12 | # debug config uses just L1
13 | losses:
14 | adv: 0.
15 | feat: 0.
16 | l1: 1.
17 | mel: 0.
18 | msspec: 0.
19 | # no balancer
20 | balancer:
21 | balance_grads: false
22 | ema_decay: 1.
23 | total_norm: 1.
24 | per_batch_item: false
25 | # no adversaries
26 | adversarial:
27 | adversaries: []
28 | adv_loss: hinge
29 | feat_loss: l1
30 |
31 | # faster model for local dev
32 | seanet:
33 | dimension: 16
34 | n_filters: 4
35 |
36 | # very small dataset
37 | dataset:
38 | batch_size: 8
39 | num_workers: 10
40 | num_samples: 100
41 | segment_duration: 1
42 | evaluate:
43 | batch_size: 32
44 | generate:
45 | batch_size: 1
46 | num_samples: 5
47 | segment_duration: 10
48 |
49 | # limited training
50 | evaluate:
51 | every: 5
52 | generate:
53 | every: 5
54 | optim:
55 | epochs: 50
56 |
--------------------------------------------------------------------------------
/audiocraft/grids/compression/encodec_base_24khz.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Grid search file, simply list all the exp you want in `explorer`.
9 | Any new exp added there will be scheduled.
10 | You can cancel an experiment by commenting out its line.
11 |
12 | This grid shows how to train a base causal EnCodec model at 24 kHz.
13 | """
14 |
15 | from ._explorers import CompressionExplorer
16 | from ...environment import AudioCraftEnvironment
17 |
18 |
19 | @CompressionExplorer
20 | def explorer(launcher):
21 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22 | launcher.slurm_(gpus=8, partition=partitions)
23 | # base causal EnCodec trained on monophonic audio sampled at 24 kHz
24 | launcher.bind_(solver='compression/encodec_base_24khz')
25 | # replace this by the desired dataset
26 | launcher.bind_(dset='audio/example')
27 | # launch xp
28 | launcher()
29 |
--------------------------------------------------------------------------------
/config/conditioner/clapemb2music.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | classifier_free_guidance:
4 | training_dropout: 0.3
5 | inference_coef: 3.0
6 |
7 | attribute_dropout:
8 | text: {}
9 | wav: {}
10 |
11 | fuser:
12 | cross_attention_pos_emb: false
13 | cross_attention_pos_emb_scale: 1
14 | sum: []
15 | prepend: []
16 | cross: [description]
17 | input_interpolate: []
18 |
19 | conditioners:
20 | description:
21 | model: clap
22 | clap:
23 | checkpoint: //reference/clap/music_audioset_epoch_15_esc_90.14.pt
24 | model_arch: 'HTSAT-base'
25 | enable_fusion: false
26 | sample_rate: 44100
27 | max_audio_length: 10
28 | audio_stride: 1
29 | dim: 512
30 | attribute: description
31 | normalize: true
32 | quantize: true # use RVQ quantization
33 | n_q: 12
34 | bins: 1024
35 | kmeans_iters: 50
36 | text_p: 0. # probability of using text embed at train time
37 | cache_path: null
38 |
39 | dataset:
40 | joint_embed_attributes: [description]
41 | train:
42 | merge_text_p: 0.25
43 | drop_desc_p: 0.5
44 | drop_other_p: 0.5
45 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Meta Platforms, Inc. and affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/audiocraft/grids/diffusion/4_bands_base_32khz.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Training of the 4 diffusion models described in
9 | "From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
10 | (paper link).
11 | """
12 |
13 | from ._explorers import DiffusionExplorer
14 |
15 |
16 | @DiffusionExplorer
17 | def explorer(launcher):
18 | launcher.slurm_(gpus=4, partition='learnfair')
19 |
20 | launcher.bind_({'solver': 'diffusion/default',
21 | 'dset': 'internal/music_10k_32khz'})
22 |
23 | with launcher.job_array():
24 | launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4})
25 | launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4})
26 | launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4})
27 | launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75})
28 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
6 |
7 |
8 | ## [1.0.1] - TBD
9 |
10 | Not using torchaudio anymore when writing audio files, relying instead directly on the command-line ffmpeg. Also not using it anymore for reading audio files, for similar reasons.
11 |
12 | ## [1.0.0] - 2023-09-07
13 |
14 | Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
15 | Added pretrained model for AudioGen and MultiBandDiffusion.
16 |
17 | ## [0.0.2] - 2023-08-01
18 |
19 | Improved demo, fixed top p (thanks @jnordberg).
20 |
21 | Compressor tanh on output to avoid clipping with some styles (especially piano).
22 | Now repeating the conditioning periodically if it is too short.
23 |
24 | More options when launching Gradio app locally (thanks @ashleykleynhans).
25 |
26 | Testing out PyTorch 2.0 memory efficient attention.
27 |
28 | Added extended generation (infinite length) by slowly moving the windows.
29 | Note that other implementations exist: https://github.com/camenduru/MusicGen-colab.
30 |
31 | ## [0.0.1] - 2023-06-09
32 |
33 | Initial release, with model evaluation only.
34 |
--------------------------------------------------------------------------------
/audiocraft/grids/compression/encodec_audiogen_16khz.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Grid search file, simply list all the exp you want in `explorer`.
9 | Any new exp added there will be scheduled.
10 | You can cancel an experiment by commenting out its line.
11 |
12 | This grid shows how to train the new AudioGen EnCodec model at 16 kHz.
13 | """
14 |
15 | from ._explorers import CompressionExplorer
16 | from ...environment import AudioCraftEnvironment
17 |
18 |
19 | @CompressionExplorer
20 | def explorer(launcher):
21 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22 | launcher.slurm_(gpus=8, partition=partitions)
23 | # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz
24 | # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 Hz
25 | launcher.bind_(solver='compression/encodec_audiogen_16khz')
26 | # replace this by the desired sound dataset
27 | launcher.bind_(dset='internal/sounds_16khz')
28 | # launch xp
29 | launcher()
30 |
--------------------------------------------------------------------------------
/audiocraft/grids/compression/debug.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Grid search file, simply list all the exp you want in `explorer`.
9 | Any new exp added there will be scheduled.
10 | You can cancel an experiment by commenting out its line.
11 |
12 | This grid is a minimal example for debugging compression task
13 | and how to override parameters directly in a grid.
14 | Learn more about dora grids: https://github.com/facebookresearch/dora
15 | """
16 |
17 | from ._explorers import CompressionExplorer
18 | from ...environment import AudioCraftEnvironment
19 |
20 |
21 | @CompressionExplorer
22 | def explorer(launcher):
23 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
24 | launcher.slurm_(gpus=2, partition=partitions)
25 | launcher.bind_(solver='compression/debug')
26 |
27 | with launcher.job_array():
28 | # base debug task using config from solver=compression/debug
29 | launcher()
30 | # we can override parameters in the grid to launch additional xps
31 | launcher({'rvq.bins': 2048, 'rvq.n_q': 4})
32 |
--------------------------------------------------------------------------------
/config/solver/audiogen/debug.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # This is a minimal debugging configuration
4 | # for the AudioGen training solver
5 | defaults:
6 | - audiogen/default
7 | - /model: lm/audiogen_lm
8 | - override /model/lm/model_scale: xsmall
9 | - override /dset: audio/example
10 | - _self_
11 |
12 | autocast: false
13 | compression_model_checkpoint: null
14 | transformer_lm:
15 | n_q: 4
16 | card: 400
17 |
18 | conditioners:
19 | description:
20 | model: t5
21 | t5:
22 | name: t5-small
23 |
24 | codebooks_pattern:
25 | modeling: parallel
26 |
27 | channels: 1
28 | sample_rate: 16000
29 |
30 | deadlock:
31 | use: false # deadlock detection
32 |
33 | dataset:
34 | batch_size: 4
35 | segment_duration: 5
36 | sample_on_weight: false # Uniform sampling all the way
37 | sample_on_duration: false # Uniform sampling all the way
38 |
39 | generate:
40 | audio:
41 | strategy: peak
42 | lm:
43 | use_sampling: false
44 | top_k: 0
45 | top_p: 0.0
46 |
47 | checkpoint:
48 | save_every: 0
49 | keep_last: 0
50 |
51 | optim:
52 | epochs: 2
53 | updates_per_epoch: 10
54 | optimizer: adamw
55 | lr: 1e-4
56 |
57 | logging:
58 | log_tensorboard: true
59 |
60 | schedule:
61 | lr_scheduler: null
62 |
--------------------------------------------------------------------------------
/config/solver/musicgen/debug.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # This is a minimal debugging configuration
4 | # for MusicGen training solver
5 | defaults:
6 | - musicgen/default
7 | - /model: lm/musicgen_lm
8 | - override /model/lm/model_scale: xsmall
9 | - override /dset: audio/example
10 | - _self_
11 |
12 | autocast: false
13 | compression_model_checkpoint: //pretrained/debug_compression_model
14 | transformer_lm:
15 | n_q: 4
16 | card: 400
17 |
18 | conditioners:
19 | description:
20 | model: t5
21 | t5:
22 | name: t5-small
23 |
24 | codebooks_pattern:
25 | modeling: parallel
26 |
27 | channels: 1
28 | sample_rate: 32000
29 |
30 | deadlock:
31 | use: false # deadlock detection
32 |
33 | dataset:
34 | batch_size: 4
35 | segment_duration: 5
36 | sample_on_weight: false # Uniform sampling all the way
37 | sample_on_duration: false # Uniform sampling all the way
38 |
39 | generate:
40 | audio:
41 | strategy: peak
42 | lm:
43 | use_sampling: false
44 | top_k: 0
45 | top_p: 0.0
46 |
47 | checkpoint:
48 | save_every: 0
49 | keep_last: 0
50 |
51 | optim:
52 | epochs: 2
53 | updates_per_epoch: 10
54 | optimizer: adamw
55 | lr: 1e-4
56 |
57 | logging:
58 | log_tensorboard: true
59 |
60 | schedule:
61 | lr_scheduler: null
62 |
--------------------------------------------------------------------------------
/audiocraft/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | AudioCraft is a general framework for training audio generative models.
8 | At the moment we provide the training code for:
9 |
10 | - [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art
11 | text-to-music and melody+text autoregressive generative model.
12 | For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model,
13 | `audiocraft.models.musicgen.MusicGen`.
14 | - [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art
15 | text-to-general-audio generative model.
16 | - [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity
17 | neural audio codec which provides an excellent tokenizer for autoregressive language models.
18 | See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`.
19 | - [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that
20 | improves the perceived quality and reduces the artifacts coming from adversarial decoders.
21 | """
22 |
23 | # flake8: noqa
24 | from . import data, modules, models
25 |
26 | __version__ = '1.0.0'
27 |
--------------------------------------------------------------------------------
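
For quick reference, a minimal end-to-end sketch of the MusicGen API referenced in the docstring above (the model name, prompts and output names are just examples; a GPU is recommended but not required):

from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write

model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(duration=8)          # seconds of audio per sample

descriptions = ['happy rock', 'calm lofi beat']
wav = model.generate(descriptions)               # [B, C, T] tensor of decoded audio

for desc, one_wav in zip(descriptions, wav):
    # writes <name>.wav with loudness normalization
    audio_write(desc.replace(' ', '_'), one_wav.cpu(), model.sample_rate, strategy='loudness')
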
/audiocraft/grids/musicgen/musicgen_clapemb_32khz.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._explorers import LMExplorer
8 | from ...environment import AudioCraftEnvironment
9 |
10 |
11 | @LMExplorer
12 | def explorer(launcher):
13 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14 | launcher.slurm_(gpus=32, partition=partitions)
15 | launcher.bind_(solver='musicgen/musicgen_base_32khz')
16 | # replace this by the desired music dataset
17 | launcher.bind_(dset='internal/music_400k_32khz')
18 | launcher.bind_(conditioner='clapemb2music')
19 |
20 | fsdp = {'autocast': False, 'fsdp.use': True}
21 | cache_path = {'conditioners.description.clap.cache_path':
22 | '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/clap_embed_music'}
23 | text_wav_training_opt = {'conditioners.description.clap.text_p': 0.5}
24 |
25 | launcher.bind_(fsdp)
26 |
27 | launcher.slurm_(gpus=32).bind_(label='32gpus')
28 | with launcher.job_array():
29 | launcher()
30 | launcher(text_wav_training_opt)
31 | launcher(cache_path)
32 | launcher(cache_path, text_wav_training_opt)
33 |
--------------------------------------------------------------------------------
/config/model/encodec/default.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | compression_model: encodec
4 |
5 | encodec:
6 | autoencoder: seanet
7 | quantizer: rvq
8 | sample_rate: ${sample_rate}
9 | channels: ${channels}
10 | causal: false
11 | renormalize: false
12 |
13 | seanet:
14 | dimension: 128
15 | channels: ${channels}
16 | causal: ${encodec.causal}
17 | n_filters: 32
18 | n_residual_layers: 1
19 | ratios: [8, 5, 4, 2]
20 | activation: ELU
21 | activation_params: {"alpha": 1.}
22 | norm: weight_norm
23 | norm_params: {}
24 | kernel_size: 7
25 | residual_kernel_size: 3
26 | last_kernel_size: 7
27 | dilation_base: 2
28 | pad_mode: constant
29 | true_skip: true
30 | compress: 2
31 | lstm: 2
32 | disable_norm_outer_blocks: 0
33 | # Specific encoder or decoder params.
34 | # You can also override any param for the encoder or decoder only
35 | # by using Hydra `+param=` syntax, e.g.
36 | # `+seanet.decoder.n_filters=64`.
37 | decoder:
38 | trim_right_ratio: 1.0
39 | final_activation: null
40 | final_activation_params: null
41 | encoder: {}
42 |
43 | rvq:
44 | n_q: 8
45 | q_dropout: false
46 | bins: 1024
47 | decay: 0.99
48 | kmeans_init: true
49 | kmeans_iters: 50
50 | threshold_ema_dead_code: 2
51 | orthogonal_reg_weight: 0.0
52 | orthogonal_reg_active_codes_only: false
53 |
54 | no_quant: {}
55 |
--------------------------------------------------------------------------------
/audiocraft/utils/profiler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import typing as tp
9 |
10 | import dora
11 | import torch
12 |
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class Profiler:
18 | """Context manager wrapper for xformers profiler.
19 | """
20 | def __init__(self, module: torch.nn.Module, enabled: bool = False):
21 | self.profiler: tp.Optional[tp.Any] = None
22 | if enabled:
23 | from xformers.profiler import profile
24 | output_dir = dora.get_xp().folder / 'profiler_data'
25 | logger.info("Profiling activated, results will be saved to %s", output_dir)
26 | self.profiler = profile(output_dir=output_dir, module=module)
27 |
28 | def step(self):
29 | if self.profiler is not None:
30 | self.profiler.step() # type: ignore
31 |
32 | def __enter__(self):
33 | if self.profiler is not None:
34 | return self.profiler.__enter__() # type: ignore
35 |
36 | def __exit__(self, exc_type, exc_value, exc_tb):
37 | if self.profiler is not None:
38 | return self.profiler.__exit__(exc_type, exc_value, exc_tb) # type: ignore
39 |
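A short sketch of how `Profiler` wraps a training loop. With `enabled=False` it degrades to a no-op, so the example below runs without xformers or a Dora experiment; `enabled=True` additionally requires both (the output folder is resolved through `dora.get_xp()`).

```python
import torch
from audiocraft.utils.profiler import Profiler

model = torch.nn.Linear(8, 8)
profiler = Profiler(module=model, enabled=False)  # no-op unless enabled

with profiler:
    for _ in range(3):  # stand-in for the real training loop
        loss = model(torch.randn(4, 8)).sum()
        loss.backward()
        profiler.step()  # advances the profiling schedule when enabled
```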
--------------------------------------------------------------------------------
/config/solver/musicgen/musicgen_base_32khz.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # This is the training loop solver
4 | # for the base MusicGen model (text-to-music)
5 | # on monophonic audio sampled at 32 kHz
6 | defaults:
7 | - musicgen/default
8 | - /model: lm/musicgen_lm
9 | - override /dset: audio/default
10 | - _self_
11 |
12 | autocast: true
13 | autocast_dtype: float16
14 |
15 | # EnCodec large trained on mono-channel music audio sampled at 32khz
16 | # with a total stride of 640 leading to 50 frames/s.
17 | # rvq.n_q=4, rvq.bins=2048, no quantization dropout
18 | # (transformer_lm card and n_q must be compatible)
19 | compression_model_checkpoint: //pretrained/facebook/encodec_32khz
20 |
21 | channels: 1
22 | sample_rate: 32000
23 |
24 | deadlock:
25 | use: true # deadlock detection
26 |
27 | dataset:
28 | batch_size: 192 # 32 GPUs
29 | sample_on_weight: false # Uniform sampling all the way
30 | sample_on_duration: false # Uniform sampling all the way
31 |
32 | generate:
33 | lm:
34 | use_sampling: true
35 | top_k: 250
36 | top_p: 0.0
37 |
38 | optim:
39 | epochs: 500
40 | optimizer: dadam
41 | lr: 1
42 | ema:
43 | use: true
44 | updates: 10
45 | device: cuda
46 |
47 | logging:
48 | log_tensorboard: true
49 |
50 | schedule:
51 | lr_scheduler: cosine
52 | cosine:
53 | warmup: 4000
54 | lr_min_ratio: 0.0
55 | cycle_length: 1.0
56 |
--------------------------------------------------------------------------------
/audiocraft/grids/compression/encodec_musicgen_32khz.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Grid search file; simply list all the experiments you want in `explorer`.
9 | Any new experiment added there will be scheduled.
10 | You can cancel an experiment by commenting out its line.
11 |
12 | This grid shows how to train a MusicGen EnCodec model at 32 kHz.
13 | """
14 |
15 | from ._explorers import CompressionExplorer
16 | from ...environment import AudioCraftEnvironment
17 |
18 |
19 | @CompressionExplorer
20 | def explorer(launcher):
21 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
22 | launcher.slurm_(gpus=8, partition=partitions)
23 | # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz
24 | # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz
25 | launcher.bind_(solver='compression/encodec_musicgen_32khz')
26 | # replace this by the desired music dataset
27 | launcher.bind_(dset='internal/music_400k_32khz')
28 | # launch xp
29 | launcher()
30 | launcher({
31 | 'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol',
32 | 'label': 'visqol',
33 | 'evaluate.metrics.visqol': True
34 | })
35 |
--------------------------------------------------------------------------------
/audiocraft/optim/linear_warmup_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import typing as tp
8 |
9 | from torch.optim import Optimizer
10 | from torch.optim.lr_scheduler import _LRScheduler
11 |
12 |
13 | class LinearWarmupLRScheduler(_LRScheduler):
14 | """Inverse square root LR scheduler.
15 |
16 | Args:
17 | optimizer (Optimizer): Torch optimizer.
18 | warmup_steps (int): Number of warmup steps.
19 | warmup_init_lr (tp.Optional[float]): Initial learning rate
20 | during warmup phase. When not set, use the provided learning rate.
21 | """
22 | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
23 | self.warmup_steps = warmup_steps
24 | self.warmup_init_lr = warmup_init_lr
25 | super().__init__(optimizer)
26 |
27 | def _get_sched_lr(self, lr: float, step: int):
28 | if step < self.warmup_steps:
29 | warmup_init_lr = self.warmup_init_lr or 0
30 | lr_step = (lr - warmup_init_lr) / self.warmup_steps
31 | lr = warmup_init_lr + step * lr_step
32 | return lr
33 |
34 | def get_lr(self):
35 | return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
36 |
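A sketch of the resulting schedule with a dummy optimizer: the learning rate ramps linearly from `warmup_init_lr` to the base LR over `warmup_steps`, then is held constant (once warmup is over the class leaves the LR untouched).

```python
import torch
from audiocraft.optim.linear_warmup_lr_scheduler import LinearWarmupLRScheduler

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = LinearWarmupLRScheduler(opt, warmup_steps=4)

lrs = []
for _ in range(6):
    opt.step()
    sched.step()
    lrs.append(opt.param_groups[0]['lr'])
# Approximately [0.00025, 0.0005, 0.00075, 0.001, 0.001, 0.001]:
# a linear ramp over 4 steps, then the base LR is kept.
print(lrs)
```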
--------------------------------------------------------------------------------
/config/solver/musicgen/musicgen_melody_32khz.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # This is the training loop solver
4 | # for the melody MusicGen model (text+chroma to music)
5 | # on monophonic audio sampled at 32 kHz
6 | defaults:
7 | - musicgen/default
8 | - /model: lm/musicgen_lm
9 | - override /conditioner: chroma2music
10 | - override /dset: audio/default
11 | - _self_
12 |
13 | autocast: true
14 | autocast_dtype: float16
15 |
16 | # EnCodec large trained on mono-channel music audio sampled at 32khz
17 | # with a total stride of 640 leading to 50 frames/s.
18 | # rvq.n_q=4, rvq.bins=2048, no quantization dropout
19 | # (transformer_lm card and n_q must be compatible)
20 | compression_model_checkpoint: //pretrained/facebook/encodec_32khz
21 |
22 | channels: 1
23 | sample_rate: 32000
24 |
25 | deadlock:
26 | use: true # deadlock detection
27 |
28 | dataset:
29 | batch_size: 192 # 32 GPUs
30 | sample_on_weight: false # Uniform sampling all the way
31 | sample_on_duration: false # Uniform sampling all the way
32 |
33 | generate:
34 | lm:
35 | use_sampling: true
36 | top_k: 250
37 | top_p: 0.0
38 |
39 | optim:
40 | epochs: 500
41 | optimizer: dadam
42 | lr: 1
43 | ema:
44 | use: true
45 | updates: 10
46 | device: cuda
47 |
48 | logging:
49 | log_tensorboard: true
50 |
51 | schedule:
52 | lr_scheduler: cosine
53 | cosine:
54 | warmup: 4000
55 | lr_min_ratio: 0.0
56 | cycle_length: 1.0
57 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to AudioCraft
2 |
3 | We want to make contributing to this project as easy and transparent as
4 | possible.
5 |
6 | ## Pull Requests
7 |
8 | AudioCraft is the implementation of a research paper.
9 | Therefore, we do not plan on accepting many pull requests for new features.
10 | We certainly welcome them for bug fixes.
11 |
12 | 1. Fork the repo and create your branch from `main`.
13 | 2. If you've added code that should be tested, add tests.
14 | 3. If you've changed APIs, update the documentation.
15 | 4. Ensure the test suite passes.
16 | 5. Make sure your code lints.
17 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
18 |
19 | ## Contributor License Agreement ("CLA")
20 | In order to accept your pull request, we need you to submit a CLA. You only need
21 | to do this once to work on any of Meta's open source projects.
22 |
23 | Complete your CLA here: <https://code.facebook.com/cla>
24 |
25 | ## Issues
26 | We use GitHub issues to track public bugs. Please ensure your description is
27 | clear and has sufficient instructions to be able to reproduce the issue.
28 |
29 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
30 | disclosure of security bugs. In those cases, please go through the process
31 | outlined on that page and do not file a public issue.
32 |
33 | ## License
34 | By contributing to AudioCraft, you agree that your contributions will be licensed
35 | under the LICENSE file in the root directory of this source tree.
36 |
--------------------------------------------------------------------------------
/audiocraft/utils/autocast.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 |
10 | class TorchAutocast:
11 | """TorchAutocast utility class.
12 | Allows you to enable and disable autocast. This is especially useful
13 | when dealing with different architectures and clusters with different
14 | levels of support.
15 |
16 | Args:
17 | enabled (bool): Whether to enable torch.autocast or not.
18 | args: Additional args for torch.autocast.
19 | kwargs: Additional kwargs for torch.autocast.
20 | """
21 | def __init__(self, enabled: bool, *args, **kwargs):
22 | self.autocast = torch.autocast(*args, **kwargs) if enabled else None
23 |
24 | def __enter__(self):
25 | if self.autocast is None:
26 | return
27 | try:
28 | self.autocast.__enter__()
29 | except RuntimeError:
30 | device = self.autocast.device
31 | dtype = self.autocast.fast_dtype
32 | raise RuntimeError(
33 | f"There was an error autocasting with dtype={dtype} device={device}\n"
34 | "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
35 | )
36 |
37 | def __exit__(self, *args, **kwargs):
38 | if self.autocast is None:
39 | return
40 | self.autocast.__exit__(*args, **kwargs)
41 |
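A sketch of how `TorchAutocast` is used: the constructor arguments are forwarded to `torch.autocast` when enabled, and the whole context becomes a no-op when disabled, which is what lets the solver configs toggle mixed precision through the `autocast` / `autocast_dtype` keys.

```python
import torch
from audiocraft.utils.autocast import TorchAutocast

enabled = torch.cuda.is_available()
autocast = TorchAutocast(enabled=enabled, device_type='cuda', dtype=torch.float16)

model = torch.nn.Linear(16, 16)
x = torch.randn(2, 16)
if enabled:
    model, x = model.cuda(), x.cuda()

with autocast:
    y = model(x)  # float16 matmuls on CUDA, plain float32 otherwise
```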
--------------------------------------------------------------------------------
/audiocraft/optim/inverse_sqrt_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import typing as tp
8 |
9 | from torch.optim import Optimizer
10 | from torch.optim.lr_scheduler import _LRScheduler
11 |
12 |
13 | class InverseSquareRootLRScheduler(_LRScheduler):
14 | """Inverse square root LR scheduler.
15 |
16 | Args:
17 | optimizer (Optimizer): Torch optimizer.
18 | warmup_steps (int): Number of warmup steps.
19 | warmup_init_lr (tp.Optional[float]): Initial learning rate
20 | during warmup phase. When not set, use the provided learning rate.
21 | """
22 | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0):
23 | self.warmup_steps = warmup_steps
24 | self.warmup_init_lr = warmup_init_lr
25 | super().__init__(optimizer)
26 |
27 | def _get_sched_lr(self, lr: float, step: int):
28 | if step < self.warmup_steps:
29 | warmup_init_lr = self.warmup_init_lr or 0
30 | lr_step = (lr - warmup_init_lr) / self.warmup_steps
31 | lr = warmup_init_lr + step * lr_step
32 | else:
33 | decay_factor = lr * self.warmup_steps**0.5
34 | lr = decay_factor * step**-0.5
35 | return lr
36 |
37 | def get_lr(self):
38 | return [self._get_sched_lr(base_lr, self._step_count) for base_lr in self.base_lrs]
39 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \
2 | dataset.train.num_samples=10 dataset.valid.num_samples=10 \
3 | dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \
4 | logging.level=DEBUG
5 | INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 5091833e
6 | INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
7 | transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
8 | INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \
9 | transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 5091833e
10 | INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \
11 | checkpoint.save_last=false # Using compression model from 616d7b3c
12 |
13 | default: linter tests
14 |
15 | install:
16 | pip install -U pip
17 | pip install -U -e '.[dev]'
18 |
19 | linter:
20 | flake8 audiocraft && mypy audiocraft
21 | flake8 tests && mypy tests
22 |
23 | tests:
24 | coverage run -m pytest tests
25 | coverage report
26 |
27 | tests_integ:
28 | $(INTEG_COMPRESSION)
29 | $(INTEG_MBD)
30 | $(INTEG_MUSICGEN)
31 | $(INTEG_AUDIOGEN)
32 |
33 |
34 | api_docs:
35 | pdoc3 --html -o api_docs -f audiocraft
36 |
37 | dist:
38 | python setup.py sdist
39 |
40 | .PHONY: linter tests api_docs dist
41 |
--------------------------------------------------------------------------------
/audiocraft/grids/musicgen/musicgen_base_32khz.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._explorers import LMExplorer
8 | from ...environment import AudioCraftEnvironment
9 |
10 |
11 | @LMExplorer
12 | def explorer(launcher):
13 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14 | launcher.slurm_(gpus=32, partition=partitions)
15 | launcher.bind_(solver='musicgen/musicgen_base_32khz')
16 | # replace this by the desired music dataset
17 | launcher.bind_(dset='internal/music_400k_32khz')
18 |
19 | fsdp = {'autocast': False, 'fsdp.use': True}
20 | medium = {'model/lm/model_scale': 'medium'}
21 | large = {'model/lm/model_scale': 'large'}
22 |
23 | cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
24 | wd_low = {'conditioners.description.t5.word_dropout': 0.2}
25 |
26 | adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
27 |
28 | launcher.bind_(fsdp)
29 |
30 | launcher.slurm_(gpus=32).bind_(label='32gpus')
31 | with launcher.job_array():
32 | sub = launcher.bind()
33 | sub()
34 |
35 | launcher.slurm_(gpus=64).bind_(label='64gpus')
36 | with launcher.job_array():
37 | sub = launcher.bind()
38 | sub(medium, adam)
39 |
40 | launcher.slurm_(gpus=96).bind_(label='96gpus')
41 | with launcher.job_array():
42 | sub = launcher.bind()
43 | sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
44 |
--------------------------------------------------------------------------------
/audiocraft/grids/compression/_explorers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import treetable as tt
8 |
9 | from .._base_explorers import BaseExplorer
10 |
11 |
12 | class CompressionExplorer(BaseExplorer):
13 | eval_metrics = ["sisnr", "visqol"]
14 |
15 | def stages(self):
16 | return ["train", "valid", "evaluate"]
17 |
18 | def get_grid_meta(self):
19 | """Returns the list of Meta information to display for each XP/job.
20 | """
21 | return [
22 | tt.leaf("index", align=">"),
23 | tt.leaf("name", wrap=140),
24 | tt.leaf("state"),
25 | tt.leaf("sig", align=">"),
26 | ]
27 |
28 | def get_grid_metrics(self):
29 | """Return the metrics that should be displayed in the tracking table.
30 | """
31 | return [
32 | tt.group(
33 | "train",
34 | [
35 | tt.leaf("epoch"),
36 | tt.leaf("bandwidth", ".2f"),
37 | tt.leaf("adv", ".4f"),
38 | tt.leaf("d_loss", ".4f"),
39 | ],
40 | align=">",
41 | ),
42 | tt.group(
43 | "valid",
44 | [
45 | tt.leaf("bandwidth", ".2f"),
46 | tt.leaf("adv", ".4f"),
47 | tt.leaf("msspec", ".4f"),
48 | tt.leaf("sisnr", ".2f"),
49 | ],
50 | align=">",
51 | ),
52 | tt.group(
53 | "evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">"
54 | ),
55 | ]
56 |
--------------------------------------------------------------------------------
/tests/models/test_audiogen.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 |
10 | from audiocraft.models import AudioGen
11 |
12 |
13 | class TestAudioGenModel:
14 | def get_audiogen(self):
15 | ag = AudioGen.get_pretrained(name='debug', device='cpu')
16 | ag.set_generation_params(duration=2.0, extend_stride=2.)
17 | return ag
18 |
19 | def test_base(self):
20 | ag = self.get_audiogen()
21 | assert ag.frame_rate == 25
22 | assert ag.sample_rate == 16000
23 | assert ag.audio_channels == 1
24 |
25 | def test_generate_continuation(self):
26 | ag = self.get_audiogen()
27 | prompt = torch.randn(3, 1, 16000)
28 | wav = ag.generate_continuation(prompt, 16000)
29 | assert list(wav.shape) == [3, 1, 32000]
30 |
31 | prompt = torch.randn(2, 1, 16000)
32 | wav = ag.generate_continuation(
33 | prompt, 16000, ['youpi', 'lapin dort'])
34 | assert list(wav.shape) == [2, 1, 32000]
35 |
36 | prompt = torch.randn(2, 1, 16000)
37 | with pytest.raises(AssertionError):
38 | wav = ag.generate_continuation(
39 | prompt, 16000, ['youpi', 'lapin dort', 'one too many'])
40 |
41 | def test_generate(self):
42 | ag = self.get_audiogen()
43 | wav = ag.generate(
44 | ['youpi', 'lapin dort'])
45 | assert list(wav.shape) == [2, 1, 32000]
46 |
47 | def test_generate_long(self):
48 | ag = self.get_audiogen()
49 | ag.max_duration = 3.
50 | ag.set_generation_params(duration=4., extend_stride=2.)
51 | wav = ag.generate(
52 | ['youpi', 'lapin dort'])
53 | assert list(wav.shape) == [2, 1, 16000 * 4]
54 |
--------------------------------------------------------------------------------
/audiocraft/optim/cosine_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | from torch.optim import Optimizer
10 | from torch.optim.lr_scheduler import _LRScheduler
11 |
12 |
13 | class CosineLRScheduler(_LRScheduler):
14 | """Cosine LR scheduler.
15 |
16 | Args:
17 | optimizer (Optimizer): Torch optimizer.
18 | warmup_steps (int): Number of warmup steps.
19 | total_steps (int): Total number of steps.
20 | lr_min_ratio (float): Minimum learning rate.
21 | cycle_length (float): Cycle length.
22 | """
23 | def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int,
24 | lr_min_ratio: float = 0.0, cycle_length: float = 1.0):
25 | self.warmup_steps = warmup_steps
26 | assert self.warmup_steps >= 0
27 | self.total_steps = total_steps
28 | assert self.total_steps >= 0
29 | self.lr_min_ratio = lr_min_ratio
30 | self.cycle_length = cycle_length
31 | super().__init__(optimizer)
32 |
33 | def _get_sched_lr(self, lr: float, step: int):
34 | if step < self.warmup_steps:
35 | lr_ratio = step / self.warmup_steps
36 | lr = lr_ratio * lr
37 | elif step <= self.total_steps:
38 | s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
39 | lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \
40 | (1. + math.cos(math.pi * s / self.cycle_length))
41 | lr = lr_ratio * lr
42 | else:
43 | lr_ratio = self.lr_min_ratio
44 | lr = lr_ratio * lr
45 | return lr
46 |
47 | def get_lr(self):
48 | return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs]
49 |
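A small-scale sketch of the shape this scheduler produces (the MusicGen solver configs above use it with `warmup: 4000` and `lr_min_ratio: 0.0`); the tiny step counts below are only for illustration.

```python
import torch
from audiocraft.optim.cosine_lr_scheduler import CosineLRScheduler

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
sched = CosineLRScheduler(opt, total_steps=10, warmup_steps=2, lr_min_ratio=0.0)

lrs = []
for _ in range(10):
    opt.step()
    sched.step()
    lrs.append(round(opt.param_groups[0]['lr'], 3))
# Ramps to the base LR over 2 steps, then follows half a cosine down to ~0.
print(lrs)
```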
--------------------------------------------------------------------------------
/audiocraft/utils/deadlock.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 | from queue import Queue, Empty
10 | import signal
11 | import sys
12 | import threading
13 | import traceback
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class DeadlockDetect:
19 | def __init__(self, use: bool = False, timeout: float = 120.):
20 | self.use = use
21 | self.timeout = timeout
22 | self._queue: Queue = Queue()
23 |
24 | def update(self, stage: str):
25 | if self.use:
26 | self._queue.put(stage)
27 |
28 | def __enter__(self):
29 | if self.use:
30 | self._thread = threading.Thread(target=self._detector_thread)
31 | self._thread.start()
32 |
33 | def __exit__(self, exc_type, exc_val, exc_tb):
34 | if self.use:
35 | self._queue.put(None)
36 | self._thread.join()
37 |
38 | def _detector_thread(self):
39 | logger.debug("Deadlock detector started")
40 | last_stage = "init"
41 | while True:
42 | try:
43 | stage = self._queue.get(timeout=self.timeout)
44 | except Empty:
45 | break
46 | if stage is None:
47 | logger.debug("Exiting deadlock detector thread")
48 | return
49 | else:
50 | last_stage = stage
51 | logger.error("Deadlock detector timed out, last stage was %s", last_stage)
52 | for th in threading.enumerate():
53 | print(th, file=sys.stderr)
54 | traceback.print_stack(sys._current_frames()[th.ident])
55 | print(file=sys.stderr)
56 | sys.stdout.flush()
57 | sys.stderr.flush()
58 | os.kill(os.getpid(), signal.SIGKILL)
59 |
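A sketch of the intended usage, matching the `deadlock.use` flag in the solver configs above: the training loop reports a heartbeat per stage, and if no update arrives within `timeout` seconds the detector dumps all thread stacks and kills the process with SIGKILL.

```python
from audiocraft.utils.deadlock import DeadlockDetect

detect = DeadlockDetect(use=True, timeout=120.)
with detect:
    for batch_idx in range(3):        # stand-in for the real training loop
        detect.update('train_batch')  # heartbeat: still making progress
        # ... forward / backward / optimizer step ...
```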
--------------------------------------------------------------------------------
/tests/common_utils/temp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import tempfile
9 |
10 |
11 | class TempDirMixin:
12 | """Mixin to provide easy access to temp dir.
13 | """
14 |
15 | temp_dir_ = None
16 |
17 | @classmethod
18 | def get_base_temp_dir(cls):
19 | # If AUDIOCRAFT_TEST_DIR is set, use it instead of temporary directory.
20 | # this is handy for debugging.
21 | key = "AUDIOCRAFT_TEST_DIR"
22 | if key in os.environ:
23 | return os.environ[key]
24 | if cls.temp_dir_ is None:
25 | cls.temp_dir_ = tempfile.TemporaryDirectory()
26 | return cls.temp_dir_.name
27 |
28 | @classmethod
29 | def tearDownClass(cls):
30 | if cls.temp_dir_ is not None:
31 | try:
32 | cls.temp_dir_.cleanup()
33 | cls.temp_dir_ = None
34 | except PermissionError:
35 | # On Windows there is a known issue with `shutil.rmtree`,
36 | # which fails intermittently.
37 | # https://github.com/python/cpython/issues/74168
38 | # Following the above thread, we ignore it.
39 | pass
40 | super().tearDownClass()
41 |
42 | @property
43 | def id(self):
44 | return self.__class__.__name__
45 |
46 | def get_temp_path(self, *paths):
47 | temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
48 | path = os.path.join(temp_dir, *paths)
49 | os.makedirs(os.path.dirname(path), exist_ok=True)
50 | return path
51 |
52 | def get_temp_dir(self, *paths):
53 | temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
54 | path = os.path.join(temp_dir, *paths)
55 | os.makedirs(path, exist_ok=True)
56 | return path
57 |
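A sketch of how a test class would use this mixin (assuming it is imported from this module directly, or re-exported by `tests.common_utils`): each class gets its own subdirectory, and setting `AUDIOCRAFT_TEST_DIR` redirects everything to a fixed location for debugging.

```python
from tests.common_utils.temp_utils import TempDirMixin


class TestSomething(TempDirMixin):
    def test_writes_a_file(self):
        # Resolves to <base_temp_dir>/TestSomething/out/dummy.bin and creates parents.
        path = self.get_temp_path('out', 'dummy.bin')
        with open(path, 'wb') as f:
            f.write(b'\x00')
        assert path.startswith(self.get_base_temp_dir())
```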
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from pathlib import Path
8 |
9 | from setuptools import setup, find_packages
10 |
11 |
12 | NAME = 'audiocraft'
13 | DESCRIPTION = 'Audio generation research library for PyTorch'
14 |
15 | URL = 'https://github.com/facebookresearch/audiocraft'
16 | AUTHOR = 'FAIR Speech & Audio'
17 | EMAIL = 'defossez@meta.com, jadecopet@meta.com'
18 | REQUIRES_PYTHON = '>=3.8.0'
19 |
20 | for line in open('audiocraft/__init__.py'):
21 | line = line.strip()
22 | if '__version__' in line:
23 | context = {}
24 | exec(line, context)
25 | VERSION = context['__version__']
26 |
27 | HERE = Path(__file__).parent
28 |
29 | try:
30 | with open(HERE / "README.md", encoding='utf-8') as f:
31 | long_description = '\n' + f.read()
32 | except FileNotFoundError:
33 | long_description = DESCRIPTION
34 |
35 | REQUIRED = [i.strip() for i in open(HERE / 'requirements.txt') if not i.startswith('#')]
36 |
37 | setup(
38 | name=NAME,
39 | version=VERSION,
40 | description=DESCRIPTION,
41 | author_email=EMAIL,
42 | long_description=long_description,
43 | long_description_content_type='text/markdown',
44 | author=AUTHOR,
45 | url=URL,
46 | python_requires=REQUIRES_PYTHON,
47 | install_requires=REQUIRED,
48 | extras_require={
49 | 'dev': ['coverage', 'flake8', 'mypy', 'pdoc3', 'pytest'],
50 | },
51 | packages=find_packages(),
52 | package_data={'audiocraft': ['py.typed']},
53 | include_package_data=True,
54 | license='MIT License',
55 | classifiers=[
56 | # Trove classifiers
57 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
58 | 'License :: OSI Approved :: MIT License',
59 | 'Topic :: Multimedia :: Sound/Audio',
60 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
61 | ],
62 | )
63 |
--------------------------------------------------------------------------------
/config/solver/audiogen/audiogen_base_16khz.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # This is the training loop solver
4 | # for the base AudioGen model (text-to-sound)
5 | # on monophonic audio sampled at 16 kHz
6 | # using a similar EnCodec+LM setup to MusicGen
7 | defaults:
8 | - audiogen/default
9 | - /model: lm/audiogen_lm
10 | - override /dset: audio/default
11 | - _self_
12 |
13 | autocast: true
14 | autocast_dtype: float16
15 |
16 | # EnCodec large trained on mono-channel music audio sampled at 16khz
17 | # with a total stride of 320 leading to 50 frames/s.
18 | # rvq.n_q=4, rvq.bins=2048, no quantization dropout
19 | # (transformer_lm card and n_q must be compatible)
20 | compression_model_checkpoint: //reference/bd44a852/checkpoint.th
21 |
22 | channels: 1
23 | sample_rate: 16000
24 |
25 | deadlock:
26 | use: true # deadlock detection
27 |
28 | dataset:
29 | batch_size: 128 # matching AudioGen paper setup (256 * mix_p=0.5 = 128)
30 | num_workers: 10
31 | segment_duration: 10
32 | min_segment_ratio: 1.0
33 | sample_on_weight: false # Uniform sampling all the way
34 | sample_on_duration: false # Uniform sampling all the way
35 | external_metadata_source: null
36 | # sample mixing augmentation at train time
37 | train:
38 | batch_size: 256 # matching AudioGen paper setup
39 | aug_p: 0.5 # perform audio mixing 50% of the time
40 | mix_p: 0.5 # proportion of batch items mixed together
41 | # important: note that this will reduce the
42 | # actual batch size used at train time
43 | # which will be equal to mix_p * batch_size
44 | mix_snr_low: -5
45 | mix_snr_high: 5
46 | mix_min_overlap: 0.5
47 |
48 | generate:
49 | lm:
50 | use_sampling: true
51 | top_k: 250
52 | top_p: 0.0
53 |
54 | optim:
55 | epochs: 100
56 | optimizer: adamw
57 | lr: 5e-4
58 | ema:
59 | use: true
60 | updates: 10
61 | device: cuda
62 |
63 | logging:
64 | log_tensorboard: true
65 |
66 | schedule:
67 | lr_scheduler: inverse_sqrt
68 | inverse_sqrt:
69 | warmup: 3000
70 | warmup_init_lr: 0.0
71 |
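The arithmetic behind the comments in this config, spelled out as a quick sketch (values copied from the keys above).

```python
# EnCodec stride vs. LM frame rate: 16 kHz audio with a total stride of 320
# gives 16000 / 320 = 50 frames per second.
frame_rate = 16000 / 320

# Sample mixing: per the comment above, mixing reduces the actual train batch to
# mix_p * dataset.train.batch_size = 0.5 * 256 = 128, i.e. the top-level batch_size.
effective_batch = int(0.5 * 256)
print(frame_rate, effective_batch)  # 50.0 128
```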
--------------------------------------------------------------------------------
/scripts/static/style.css:
--------------------------------------------------------------------------------
1 | body {
2 | background-color: #fbfbfb;
3 | margin: 0;
4 | }
5 |
6 | select, input {
7 | font-size: 1em;
8 | max-width: 100%;
9 | }
10 |
11 | .xp_name {
12 | font-family: monospace;
13 | }
14 |
15 | .simple_form {
16 | background-color: #dddddd;
17 | padding: 1em;
18 | margin: 0.5em;
19 | }
20 |
21 | textarea {
22 | margin-top: 0.5em;
23 | margin-bottom: 0.5em;
24 | }
25 |
26 | .rating {
27 | background-color: grey;
28 | padding-top: 5px;
29 | padding-bottom: 5px;
30 | padding-left: 8px;
31 | padding-right: 8px;
32 | margin-right: 2px;
33 | cursor:pointer;
34 | }
35 |
36 | .rating_selected {
37 | background-color: purple;
38 | }
39 |
40 | .content {
41 | font-family: sans-serif;
42 | background-color: #f6f6f6;
43 | padding: 40px;
44 | margin: 0 auto;
45 | max-width: 1000px;
46 | }
47 |
48 | .track label {
49 | padding-top: 10px;
50 | padding-bottom: 10px;
51 | }
52 | .track {
53 | padding: 15px;
54 | margin: 5px;
55 | background-color: #c8c8c8;
56 | }
57 |
58 | .submit-big {
59 | width:400px;
60 | height:30px;
61 | font-size: 20px;
62 | }
63 |
64 | .error {
65 | color: red;
66 | }
67 |
68 | .ratings {
69 | margin-left: 10px;
70 | }
71 |
72 | .important {
73 | font-weight: bold;
74 | }
75 |
76 | .survey {
77 | margin-bottom: 100px;
78 | }
79 |
80 | .success {
81 | color: #25901b;
82 | font-weight: bold;
83 | }
84 | .warning {
85 | color: #8a1f19;
86 | font-weight: bold;
87 | }
88 | .track>section {
89 | display: flex;
90 | align-items: center;
91 | }
92 |
93 | .prompt {
94 | display: flex;
95 | align-items: center;
96 | }
97 |
98 | .track>section>div {
99 | padding-left: 10px;
100 | }
101 |
102 | audio {
103 | max-width: 280px;
104 | max-height: 40px;
105 | margin-left: 10px;
106 | margin-right: 10px;
107 | }
108 |
109 | .special {
110 | font-weight: bold;
111 | color: #2c2c2c;
112 | }
113 |
114 |
--------------------------------------------------------------------------------
/tests/models/test_musicgen.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 |
10 | from audiocraft.models import MusicGen
11 |
12 |
13 | class TestMusicGenModel:
14 | def get_musicgen(self):
15 | mg = MusicGen.get_pretrained(name='debug', device='cpu')
16 | mg.set_generation_params(duration=2.0, extend_stride=2.)
17 | return mg
18 |
19 | def test_base(self):
20 | mg = self.get_musicgen()
21 | assert mg.frame_rate == 25
22 | assert mg.sample_rate == 32000
23 | assert mg.audio_channels == 1
24 |
25 | def test_generate_unconditional(self):
26 | mg = self.get_musicgen()
27 | wav = mg.generate_unconditional(3)
28 | assert list(wav.shape) == [3, 1, 64000]
29 |
30 | def test_generate_continuation(self):
31 | mg = self.get_musicgen()
32 | prompt = torch.randn(3, 1, 32000)
33 | wav = mg.generate_continuation(prompt, 32000)
34 | assert list(wav.shape) == [3, 1, 64000]
35 |
36 | prompt = torch.randn(2, 1, 32000)
37 | wav = mg.generate_continuation(
38 | prompt, 32000, ['youpi', 'lapin dort'])
39 | assert list(wav.shape) == [2, 1, 64000]
40 |
41 | prompt = torch.randn(2, 1, 32000)
42 | with pytest.raises(AssertionError):
43 | wav = mg.generate_continuation(
44 | prompt, 32000, ['youpi', 'lapin dort', 'one too many'])
45 |
46 | def test_generate(self):
47 | mg = self.get_musicgen()
48 | wav = mg.generate(
49 | ['youpi', 'lapin dort'])
50 | assert list(wav.shape) == [2, 1, 64000]
51 |
52 | def test_generate_long(self):
53 | mg = self.get_musicgen()
54 | mg.max_duration = 3.
55 | mg.set_generation_params(duration=4., extend_stride=2.)
56 | wav = mg.generate(
57 | ['youpi', 'lapin dort'])
58 | assert list(wav.shape) == [2, 1, 32000 * 4]
59 |
--------------------------------------------------------------------------------
/audiocraft/utils/export_legacy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Legacy functions used at the time of the first release, kept for reference.
9 | """
10 |
11 | from pathlib import Path
12 | import typing as tp
13 |
14 | from omegaconf import OmegaConf, DictConfig
15 | import torch
16 |
17 |
18 | def _clean_lm_cfg(cfg: DictConfig):
19 | OmegaConf.set_struct(cfg, False)
20 | # This used to be set automatically in the LM solver; we need a more robust solution
21 | # for the future.
22 | cfg['transformer_lm']['card'] = 2048
23 | cfg['transformer_lm']['n_q'] = 4
24 | # Experimental params no longer supported.
25 | bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
26 | 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
27 | for name in bad_params:
28 | del cfg['transformer_lm'][name]
29 | OmegaConf.set_struct(cfg, True)
30 | return cfg
31 |
32 |
33 | def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
34 | sig = Path(checkpoint_path).parent.name
35 | assert len(sig) == 8, "Not a valid Dora signature"
36 | pkg = torch.load(checkpoint_path, 'cpu')
37 | new_pkg = {
38 | 'best_state': pkg['ema']['state']['model'],
39 | 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
40 | }
41 | out_file = Path(out_folder) / f'{sig}.th'
42 | torch.save(new_pkg, out_file)
43 | return out_file
44 |
45 |
46 | def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
47 | sig = Path(checkpoint_path).parent.name
48 | assert len(sig) == 8, "Not a valid Dora signature"
49 | pkg = torch.load(checkpoint_path, 'cpu')
50 | new_pkg = {
51 | 'best_state': pkg['fsdp_best_state']['model'],
52 | 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
53 | }
54 | out_file = Path(out_folder) / f'{sig}.th'
55 | torch.save(new_pkg, out_file)
56 | return out_file
57 |
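A sketch of how these legacy export helpers are called; the checkpoint and output paths are purely hypothetical, and the parent directory of each checkpoint must be an 8-character Dora signature (the assertions above enforce this).

```python
from audiocraft.utils.export_legacy import export_encodec, export_lm

# Hypothetical paths for illustration only.
encodec_file = export_encodec('/checkpoints/abcd1234/checkpoint.th', '/exported')
lm_file = export_lm('/checkpoints/ef567890/checkpoint.th', '/exported')
print(encodec_file, lm_file)  # /exported/abcd1234.th /exported/ef567890.th
```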
--------------------------------------------------------------------------------
/tests/losses/test_losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import random
8 |
9 | import torch
10 |
11 | from audiocraft.losses import (
12 | MelSpectrogramL1Loss,
13 | MultiScaleMelSpectrogramLoss,
14 | MRSTFTLoss,
15 | SISNR,
16 | STFTLoss,
17 | )
18 |
19 |
20 | def test_mel_l1_loss():
21 | N, C, T = 2, 2, random.randrange(1000, 100_000)
22 | t1 = torch.randn(N, C, T)
23 | t2 = torch.randn(N, C, T)
24 |
25 | mel_l1 = MelSpectrogramL1Loss(sample_rate=22_050)
26 | loss = mel_l1(t1, t2)
27 | loss_same = mel_l1(t1, t1)
28 |
29 | assert isinstance(loss, torch.Tensor)
30 | assert isinstance(loss_same, torch.Tensor)
31 | assert loss_same.item() == 0.0
32 |
33 |
34 | def test_msspec_loss():
35 | N, C, T = 2, 2, random.randrange(1000, 100_000)
36 | t1 = torch.randn(N, C, T)
37 | t2 = torch.randn(N, C, T)
38 |
39 | msspec = MultiScaleMelSpectrogramLoss(sample_rate=22_050)
40 | loss = msspec(t1, t2)
41 | loss_same = msspec(t1, t1)
42 |
43 | assert isinstance(loss, torch.Tensor)
44 | assert isinstance(loss_same, torch.Tensor)
45 | assert loss_same.item() == 0.0
46 |
47 |
48 | def test_mrstft_loss():
49 | N, C, T = 2, 2, random.randrange(1000, 100_000)
50 | t1 = torch.randn(N, C, T)
51 | t2 = torch.randn(N, C, T)
52 |
53 | mrstft = MRSTFTLoss()
54 | loss = mrstft(t1, t2)
55 |
56 | assert isinstance(loss, torch.Tensor)
57 |
58 |
59 | def test_sisnr_loss():
60 | N, C, T = 2, 2, random.randrange(1000, 100_000)
61 | t1 = torch.randn(N, C, T)
62 | t2 = torch.randn(N, C, T)
63 |
64 | sisnr = SISNR()
65 | loss = sisnr(t1, t2)
66 |
67 | assert isinstance(loss, torch.Tensor)
68 |
69 |
70 | def test_stft_loss():
71 | N, C, T = 2, 2, random.randrange(1000, 100_000)
72 | t1 = torch.randn(N, C, T)
73 | t2 = torch.randn(N, C, T)
74 |
75 | mrstft = STFTLoss()
76 | loss = mrstft(t1, t2)
77 |
78 | assert isinstance(loss, torch.Tensor)
79 |
--------------------------------------------------------------------------------
/config/model/lm/default.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 | defaults:
3 | - _self_
4 | - /model/lm/model_scale: base # prefer this group to set model scale instead of transformer_lm keys directly
5 |
6 | lm_model: transformer_lm
7 |
8 | codebooks_pattern:
9 | modeling: parallel
10 |
11 | transformer_lm:
12 | dim: 512
13 | num_heads: 8
14 | num_layers: 8
15 | hidden_scale: 4
16 | n_q: 8 # number of streams to model
17 | card: 1024
18 | dropout: 0.
19 | emb_lr: null
20 | activation: gelu
21 | norm_first: false # use pre-norm instead of post-norm
22 | bias_ff: true # use bias for the feedforward
23 | bias_attn: true # use bias for the attention
24 | bias_proj: true # use bias for the output projections
25 | past_context: null
26 | causal: true
27 | custom: false # use custom MHA implementation
28 | memory_efficient: false # use flash attention
29 | attention_as_float32: false # use float32 for the attention part,
30 | # recommended at the moment when memory_efficient is True.
31 | layer_scale: null
32 | positional_embedding: sin # positional embedding strategy (sin, rope, or sin_rope).
33 | xpos: false # apply xpos decay (rope only).
34 | checkpointing: none # layer checkpointing method, can be none, torch, xformers_default.
35 | # torch is the slowest but uses the least memory,
36 | # xformers_default is somewhere in between.
37 | weight_init: null # weight initialization (null, gaussian or uniform)
38 | depthwise_init: null # perform depthwise initialization (null, current, global)
39 | zero_bias_init: false # initialize bias to zero if bias in linears and
40 | # if a weight_init method is used.
41 | norm: layer_norm # normalization method to use in transformer.
42 | cross_attention: false
43 | qk_layer_norm: false
44 | qk_layer_norm_cross: false
45 | attention_dropout: null
46 | kv_repeat: 1
47 | two_step_cfg: false # whether to do true 2 steps CFG, potentially resolving some padding issues or not...
48 |
--------------------------------------------------------------------------------
/audiocraft/optim/polynomial_decay_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from torch.optim import Optimizer
8 | from torch.optim.lr_scheduler import _LRScheduler
9 |
10 |
11 | class PolynomialDecayLRScheduler(_LRScheduler):
12 | """Polynomial decay LR scheduler.
13 |
14 | Args:
15 | optimizer (Optimizer): Torch optimizer.
16 | warmup_steps (int): Number of warmup steps.
17 | total_steps (int): Total number of steps.
18 | end_lr (float): Final learning rate to achieve over total number of steps.
19 | zero_lr_warmup_steps (int): Number of steps with a learning rate of value 0.
20 | power (float): Decay exponent.
21 | """
22 | def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int,
23 | end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.):
24 | self.warmup_steps = warmup_steps
25 | self.total_steps = total_steps
26 | self.end_lr = end_lr
27 | self.zero_lr_warmup_steps = zero_lr_warmup_steps
28 | self.power = power
29 | super().__init__(optimizer)
30 |
31 | def _get_sched_lr(self, lr: float, step: int):
32 | if self.zero_lr_warmup_steps > 0 and step <= self.zero_lr_warmup_steps:
33 | lr = 0
34 | elif self.warmup_steps > 0 and step <= self.warmup_steps + self.zero_lr_warmup_steps:
35 | lr_ratio = (step - self.zero_lr_warmup_steps) / float(self.warmup_steps)
36 | lr = lr_ratio * lr
37 | elif step >= self.total_steps:
38 | lr = self.end_lr
39 | else:
40 | total_warmup_steps = self.warmup_steps + self.zero_lr_warmup_steps
41 | lr_range = lr - self.end_lr
42 | pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps)
43 | lr = lr_range * pct_remaining ** self.power + self.end_lr
44 | return lr
45 |
46 | def get_lr(self):
47 | return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs]
48 |
--------------------------------------------------------------------------------
/config/solver/diffusion/debug.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - /solver/default
5 | - /model: score/basic
6 | - override /dset: audio/default
7 | - _self_
8 |
9 | solver: diffusion
10 |
11 | sample_rate: 16000
12 | channels: 1
13 | compression_model_checkpoint: //sig/5091833e
14 | n_q: 2 # number of codebooks to keep
15 |
16 | dataset:
17 | batch_size: 8
18 | num_workers: 10
19 | segment_duration: 1
20 | train:
21 | num_samples: 100
22 | valid:
23 | num_samples: 100
24 | evaluate:
25 | batch_size: 8
26 | num_samples: 10
27 | generate:
28 | batch_size: 8
29 | num_samples: 10
30 | segment_duration: 10
31 |
32 | loss:
33 | kind: mse
34 | norm_power: 0.
35 |
36 | valid:
37 | every: 1
38 |
39 | evaluate:
40 | every: 5
41 | num_workers: 5
42 | metrics:
43 | visqol: false
44 | sisnr: false
45 | rvm: true
46 |
47 | generate:
48 | every: 5
49 | num_workers: 5
50 | audio:
51 | sample_rate: ${sample_rate}
52 |
53 | checkpoint:
54 | save_last: true
55 | save_every: 25
56 | keep_last: 10
57 | keep_every_states: null
58 |
59 |
60 | optim:
61 | epochs: 50
62 | updates_per_epoch: 2000
63 | lr: 2e-4
64 | max_norm: 0
65 | optimizer: adam
66 | adam:
67 | betas: [0.9, 0.999]
68 | weight_decay: 0.
69 | ema:
70 | use: true # whether to use EMA or not
71 | updates: 1 # update at every step
72 | device: ${device} # device for EMA, can be put on GPU if more frequent updates
73 | decay: 0.99 # EMA decay value, if null, no EMA is used
74 |
75 | processor:
76 | name: multi_band_processor
77 | use: false
78 | n_bands: 8
79 | num_samples: 10_000
80 | power_std: 1.
81 |
82 | resampling:
83 | use: false
84 | target_sr: 16000
85 |
86 | filter:
87 | use: false
88 | n_bands: 4
89 | idx_band: 0
90 | cutoffs: null
91 |
92 | schedule:
93 | repartition: "power"
94 | variable_step_batch: true
95 | beta_t0: 1.0e-5
96 | beta_t1: 2.9e-2
97 | beta_exp: 7.5
98 | num_steps: 1000
99 | variance: 'beta'
100 | clip: 5.
101 | rescale: 1.
102 | n_bands: null
103 | noise_scale: 1.0
104 |
105 | metrics:
106 | num_stage: 4
107 |
--------------------------------------------------------------------------------
/config/solver/diffusion/default.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - /solver/default
5 | - /model: score/basic
6 | - override /dset: audio/default
7 | - _self_
8 |
9 | solver: diffusion
10 |
11 | sample_rate: ???
12 | channels: ???
13 | compression_model_checkpoint: ???
14 | n_q: ??? # number of codebooks to keep
15 |
16 |
17 | dataset:
18 | batch_size: 128
19 | num_workers: 10
20 | segment_duration: 1
21 | train:
22 | num_samples: 500000
23 | valid:
24 | num_samples: 10000
25 | evaluate:
26 | batch_size: 16
27 | num_samples: 10000
28 | generate:
29 | batch_size: 32
30 | num_samples: 50
31 | segment_duration: 10
32 | audio:
33 | sample_rate: ${sample_rate}
34 |
35 | loss:
36 | kind: mse
37 | norm_power: 0.
38 |
39 | valid:
40 | every: 1
41 |
42 | evaluate:
43 | every: 20
44 | num_workers: 5
45 | metrics:
46 | visqol: false
47 | sisnr: false
48 | rvm: true
49 |
50 | generate:
51 | every: 25
52 | num_workers: 5
53 |
54 | checkpoint:
55 | save_last: true
56 | save_every: 25
57 | keep_last: 10
58 | keep_every_states: null
59 |
60 |
61 | optim:
62 | epochs: 20000
63 | updates_per_epoch: 2000
64 | lr: 2e-4
65 | max_norm: 0
66 | optimizer: adam
67 | adam:
68 | betas: [0.9, 0.999]
69 | weight_decay: 0.
70 | ema:
71 | use: true # whether to use EMA or not
72 | updates: 1 # update at every step
73 | device: ${device} # device for EMA, can be put on GPU if more frequent updates
74 | decay: 0.99 # EMA decay value, if null, no EMA is used
75 |
76 | processor:
77 | name: multi_band_processor
78 | use: false
79 | n_bands: 8
80 | num_samples: 10_000
81 | power_std: 1.
82 |
83 | resampling:
84 | use: false
85 | target_sr: 16000
86 |
87 | filter:
88 | use: false
89 | n_bands: 4
90 | idx_band: 0
91 | cutoffs: null
92 |
93 | schedule:
94 | repartition: "power"
95 | variable_step_batch: true
96 | beta_t0: 1.0e-5
97 | beta_t1: 2.9e-2
98 | beta_exp: 7.5
99 | num_steps: 1000
100 | variance: 'beta'
101 | clip: 5.
102 | rescale: 1.
103 | n_bands: null
104 | noise_scale: 1.0
105 |
106 | metrics:
107 | num_stage: 4
108 |
--------------------------------------------------------------------------------
/audiocraft/grids/diffusion/_explorers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import treetable as tt
8 |
9 | from .._base_explorers import BaseExplorer
10 |
11 |
12 | class DiffusionExplorer(BaseExplorer):
13 | eval_metrics = ["sisnr", "visqol"]
14 |
15 | def stages(self):
16 | return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]
17 |
18 | def get_grid_meta(self):
19 | """Returns the list of Meta information to display for each XP/job.
20 | """
21 | return [
22 | tt.leaf("index", align=">"),
23 | tt.leaf("name", wrap=140),
24 | tt.leaf("state"),
25 | tt.leaf("sig", align=">"),
26 | ]
27 |
28 | def get_grid_metrics(self):
29 | """Return the metrics that should be displayed in the tracking table.
30 | """
31 | return [
32 | tt.group(
33 | "train",
34 | [
35 | tt.leaf("epoch"),
36 | tt.leaf("loss", ".3%"),
37 | ],
38 | align=">",
39 | ),
40 | tt.group(
41 | "valid",
42 | [
43 | tt.leaf("loss", ".3%"),
44 | # tt.leaf("loss_0", ".3%"),
45 | ],
46 | align=">",
47 | ),
48 | tt.group(
49 | "valid_ema",
50 | [
51 | tt.leaf("loss", ".3%"),
52 | # tt.leaf("loss_0", ".3%"),
53 | ],
54 | align=">",
55 | ),
56 | tt.group(
57 | "evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
58 | tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
59 | tt.leaf("rvm_3", ".4f"), ], align=">"
60 | ),
61 | tt.group(
62 | "evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"),
63 | tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"),
64 | tt.leaf("rvm_3", ".4f")], align=">"
65 | ),
66 | ]
67 |
--------------------------------------------------------------------------------
/audiocraft/utils/cluster.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Utility functions for SLURM configuration and cluster settings.
9 | """
10 |
11 | from enum import Enum
12 | import os
13 | import socket
14 | import typing as tp
15 |
16 | import omegaconf
17 |
18 |
19 | class ClusterType(Enum):
20 | AWS = "aws"
21 | FAIR = "fair"
22 | RSC = "rsc"
23 | LOCAL_DARWIN = "darwin"
24 | DEFAULT = "default" # used for any other cluster.
25 |
26 |
27 | def _guess_cluster_type() -> ClusterType:
28 | uname = os.uname()
29 | fqdn = socket.getfqdn()
30 | if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn):
31 | return ClusterType.AWS
32 |
33 | if fqdn.endswith(".fair"):
34 | return ClusterType.FAIR
35 |
36 | if fqdn.endswith(".facebook.com"):
37 | return ClusterType.RSC
38 |
39 | if uname.sysname == "Darwin":
40 | return ClusterType.LOCAL_DARWIN
41 |
42 | return ClusterType.DEFAULT
43 |
44 |
45 | def get_cluster_type(
46 | cluster_type: tp.Optional[ClusterType] = None,
47 | ) -> tp.Optional[ClusterType]:
48 | if cluster_type is None:
49 | return _guess_cluster_type()
50 |
51 | return cluster_type
52 |
53 |
54 | def get_slurm_parameters(
55 | cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
56 | ) -> omegaconf.DictConfig:
57 | """Update SLURM parameters in configuration based on cluster type.
58 | If the cluster type is not specified, it is inferred automatically.
59 | """
60 | from ..environment import AudioCraftEnvironment
61 | cluster_type = get_cluster_type(cluster_type)
62 | # apply cluster-specific adjustments
63 | if cluster_type == ClusterType.AWS:
64 | cfg["mem_per_gpu"] = None
65 | cfg["constraint"] = None
66 | cfg["setup"] = []
67 | elif cluster_type == ClusterType.RSC:
68 | cfg["mem_per_gpu"] = None
69 | cfg["setup"] = []
70 | cfg["constraint"] = None
71 | cfg["partition"] = "learn"
72 | slurm_exclude = AudioCraftEnvironment.get_slurm_exclude()
73 | if slurm_exclude is not None:
74 | cfg["exclude"] = slurm_exclude
75 | return cfg
76 |
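A sketch of calling `get_slurm_parameters` directly; the SLURM keys in the example config are assumptions for illustration (real values come from the Hydra configs), and forcing `ClusterType.AWS` shows the AWS-specific overrides.

```python
import omegaconf

from audiocraft.utils.cluster import ClusterType, get_slurm_parameters

slurm_cfg = omegaconf.OmegaConf.create({
    'mem_per_gpu': 40, 'constraint': 'volta32gb', 'setup': ['module load cuda'],
    'partition': None, 'exclude': None,
})
slurm_cfg = get_slurm_parameters(slurm_cfg, cluster_type=ClusterType.AWS)
print(slurm_cfg.mem_per_gpu, slurm_cfg.constraint, slurm_cfg.setup)  # None None []
```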
--------------------------------------------------------------------------------
/audiocraft/data/zip.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Utility for reading some info from inside a zip file.
7 | """
8 |
9 | import typing
10 | import zipfile
11 |
12 | from dataclasses import dataclass
13 | from functools import lru_cache
14 | from typing_extensions import Literal
15 |
16 |
17 | DEFAULT_SIZE = 32
18 | MODE = Literal['r', 'w', 'x', 'a']
19 |
20 |
21 | @dataclass(order=True)
22 | class PathInZip:
23 | """Hold a path of file within a zip file.
24 |
25 | Args:
26 | path (str): The convention is <zip_path>:<file_path>.
27 | Let's assume there is a zip file /some/location/foo.zip
28 | and inside of it is a json file located at /data/file1.json.
29 | Then we expect path = "/some/location/foo.zip:/data/file1.json".
30 | """
31 |
32 | INFO_PATH_SEP = ':'
33 | zip_path: str
34 | file_path: str
35 |
36 | def __init__(self, path: str) -> None:
37 | split_path = path.split(self.INFO_PATH_SEP)
38 | assert len(split_path) == 2
39 | self.zip_path, self.file_path = split_path
40 |
41 | @classmethod
42 | def from_paths(cls, zip_path: str, file_path: str):
43 | return cls(zip_path + cls.INFO_PATH_SEP + file_path)
44 |
45 | def __str__(self) -> str:
46 | return self.zip_path + self.INFO_PATH_SEP + self.file_path
47 |
48 |
49 | def _open_zip(path: str, mode: MODE = 'r'):
50 | return zipfile.ZipFile(path, mode)
51 |
52 |
53 | _cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
54 |
55 |
56 | def set_zip_cache_size(max_size: int):
57 | """Sets the maximal LRU caching for zip file opening.
58 |
59 | Args:
60 | max_size (int): the maximal LRU cache.
61 | """
62 | global _cached_open_zip
63 | _cached_open_zip = lru_cache(max_size)(_open_zip)
64 |
65 |
66 | def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
67 | """Opens a file stored inside a zip and returns a file-like object.
68 |
69 | Args:
70 | path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
71 | mode (str): The mode in which to open the file with.
72 | Returns:
73 | A file-like object for PathInZip.
74 | """
75 | zf = _cached_open_zip(path_in_zip.zip_path)
76 | return zf.open(path_in_zip.file_path)
77 |
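A sketch of the path convention and the cached zip access; the zip file and its member below are hypothetical, reusing the example from the `PathInZip` docstring.

```python
from audiocraft.data.zip import PathInZip, open_file_in_zip, set_zip_cache_size

path = PathInZip('/some/location/foo.zip:/data/file1.json')
print(path.zip_path, path.file_path)  # /some/location/foo.zip /data/file1.json

set_zip_cache_size(8)               # shrink the LRU cache of open zip handles
with open_file_in_zip(path) as f:   # file-like object for the zip member
    data = f.read()
```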
--------------------------------------------------------------------------------
/audiocraft/grids/musicgen/musicgen_melody_32khz.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._explorers import LMExplorer
8 | from ...environment import AudioCraftEnvironment
9 |
10 |
11 | @LMExplorer
12 | def explorer(launcher):
13 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14 | launcher.slurm_(gpus=32, partition=partitions)
15 | launcher.bind_(solver='musicgen/musicgen_melody_32khz')
16 | # replace this by the desired music dataset
17 | launcher.bind_(dset='internal/music_400k_32khz')
18 |
19 | fsdp = {'autocast': False, 'fsdp.use': True}
20 | medium = {'model/lm/model_scale': 'medium'}
21 | large = {'model/lm/model_scale': 'large'}
22 |
23 | cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
24 | wd_low = {'conditioners.description.t5.word_dropout': 0.2}
25 |
26 | adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
27 |
28 | cache_path = {'conditioners.self_wav.chroma_stem.cache_path':
29 | '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem'}
30 |
31 | # CACHE GENERATION JOBS
32 | n_cache_gen_jobs = 4
33 | gen_sub = launcher.slurm(gpus=1)
34 | gen_sub.bind_(
35 | cache_path, {
36 | # the cache is always computed over the whole file, so duration doesn't matter here.
37 | 'dataset.segment_duration': 2.,
38 | 'dataset.batch_size': 8,
39 | 'dataset.train.permutation_on_files': True, # try to not repeat files.
40 | 'optim.epochs': 10,
41 | 'model/lm/model_scale': 'xsmall',
42 |
43 | })
44 | with gen_sub.job_array():
45 | for gen_job in range(n_cache_gen_jobs):
46 | gen_sub({'dataset.train.shuffle_seed': gen_job})
47 |
48 | # ACTUAL TRAINING JOBS.
49 | launcher.bind_(fsdp)
50 |
51 | launcher.slurm_(gpus=32).bind_(label='32gpus')
52 | with launcher.job_array():
53 | sub = launcher.bind()
54 | sub()
55 | sub(cache_path)
56 |
57 | launcher.slurm_(gpus=64).bind_(label='64gpus')
58 | with launcher.job_array():
59 | sub = launcher.bind()
60 | sub(medium, adam)
61 |
62 | launcher.slurm_(gpus=96).bind_(label='96gpus')
63 | with launcher.job_array():
64 | sub = launcher.bind()
65 | sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
66 |
--------------------------------------------------------------------------------
/tests/models/test_multibanddiffusion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import random
8 |
9 | import numpy as np
10 | import torch
11 | from audiocraft.models.multibanddiffusion import MultiBandDiffusion, DiffusionProcess
12 | from audiocraft.models import EncodecModel, DiffusionUnet
13 | from audiocraft.modules import SEANetEncoder, SEANetDecoder
14 | from audiocraft.modules.diffusion_schedule import NoiseSchedule
15 | from audiocraft.quantization import DummyQuantizer
16 |
17 |
18 | class TestMBD:
19 |
20 | def _create_mbd(self,
21 | sample_rate: int,
22 | channels: int,
23 | n_filters: int = 3,
24 | n_residual_layers: int = 1,
25 | ratios: list = [5, 4, 3, 2],
26 | num_steps: int = 1000,
27 | codec_dim: int = 128,
28 | **kwargs):
29 | frame_rate = np.prod(ratios)
30 | encoder = SEANetEncoder(channels=channels, dimension=codec_dim, n_filters=n_filters,
31 | n_residual_layers=n_residual_layers, ratios=ratios)
32 | decoder = SEANetDecoder(channels=channels, dimension=codec_dim, n_filters=n_filters,
33 | n_residual_layers=n_residual_layers, ratios=ratios)
34 | quantizer = DummyQuantizer()
35 | compression_model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
36 | sample_rate=sample_rate, channels=channels, **kwargs)
37 | diffusion_model = DiffusionUnet(chin=channels, num_steps=num_steps, codec_dim=codec_dim)
38 | schedule = NoiseSchedule(device='cpu', num_steps=num_steps)
39 | DP = DiffusionProcess(model=diffusion_model, noise_schedule=schedule)
40 | mbd = MultiBandDiffusion(DPs=[DP], codec_model=compression_model)
41 | return mbd
42 |
43 | def test_model(self):
44 | random.seed(1234)
45 | sample_rate = 24_000
46 | channels = 1
47 | codec_dim = 128
48 | mbd = self._create_mbd(sample_rate=sample_rate, channels=channels, codec_dim=codec_dim)
49 | for _ in range(10):
50 | length = random.randrange(1, 10_000)
51 | x = torch.randn(2, channels, length)
52 | res = mbd.regenerate(x, sample_rate)
53 | assert res.shape == x.shape
54 |
--------------------------------------------------------------------------------
/tests/adversarial/test_discriminators.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import random
8 |
9 | import torch
10 |
11 | from audiocraft.adversarial.discriminators import (
12 | MultiPeriodDiscriminator,
13 | MultiScaleDiscriminator,
14 | MultiScaleSTFTDiscriminator
15 | )
16 |
17 |
18 | class TestMultiPeriodDiscriminator:
19 |
20 | def test_mpd_discriminator(self):
21 | N, C, T = 2, 2, random.randrange(1, 100_000)
22 | t0 = torch.randn(N, C, T)
23 | periods = [1, 2, 3]
24 | mpd = MultiPeriodDiscriminator(periods=periods, in_channels=C)
25 | logits, fmaps = mpd(t0)
26 |
27 | assert len(logits) == len(periods)
28 | assert len(fmaps) == len(periods)
29 | assert all([logit.shape[0] == N and len(logit.shape) == 4 for logit in logits])
30 | assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap])
31 |
32 |
33 | class TestMultiScaleDiscriminator:
34 |
35 | def test_msd_discriminator(self):
36 | N, C, T = 2, 2, random.randrange(1, 100_000)
37 | t0 = torch.randn(N, C, T)
38 |
39 | scale_norms = ['weight_norm', 'weight_norm']
40 | msd = MultiScaleDiscriminator(scale_norms=scale_norms, in_channels=C)
41 | logits, fmaps = msd(t0)
42 |
43 | assert len(logits) == len(scale_norms)
44 | assert len(fmaps) == len(scale_norms)
45 | assert all([logit.shape[0] == N and len(logit.shape) == 3 for logit in logits])
46 | assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap])
47 |
48 |
49 | class TestMultiScaleStftDiscriminator:
50 |
51 | def test_msstftd_discriminator(self):
52 | N, C, T = 2, 2, random.randrange(1, 100_000)
53 | t0 = torch.randn(N, C, T)
54 |
55 | n_filters = 4
56 | n_ffts = [128, 256, 64]
57 | hop_lengths = [32, 64, 16]
58 | win_lengths = [128, 256, 64]
59 |
60 | msstftd = MultiScaleSTFTDiscriminator(filters=n_filters, n_ffts=n_ffts, hop_lengths=hop_lengths,
61 | win_lengths=win_lengths, in_channels=C)
62 | logits, fmaps = msstftd(t0)
63 |
64 | assert len(logits) == len(n_ffts)
65 | assert len(fmaps) == len(n_ffts)
66 | assert all([logit.shape[0] == N and len(logit.shape) == 4 for logit in logits])
67 | assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap])
68 |
--------------------------------------------------------------------------------
/tests/models/test_encodec_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import random
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from audiocraft.models import EncodecModel
13 | from audiocraft.modules import SEANetEncoder, SEANetDecoder
14 | from audiocraft.quantization import DummyQuantizer
15 |
16 |
17 | class TestEncodecModel:
18 |
19 | def _create_encodec_model(self,
20 | sample_rate: int,
21 | channels: int,
22 | dim: int = 5,
23 | n_filters: int = 3,
24 | n_residual_layers: int = 1,
25 | ratios: list = [5, 4, 3, 2],
26 | **kwargs):
27 | frame_rate = np.prod(ratios)
28 | encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters,
29 | n_residual_layers=n_residual_layers, ratios=ratios)
30 | decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters,
31 | n_residual_layers=n_residual_layers, ratios=ratios)
32 | quantizer = DummyQuantizer()
33 | model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
34 | sample_rate=sample_rate, channels=channels, **kwargs)
35 | return model
36 |
37 | def test_model(self):
38 | random.seed(1234)
39 | sample_rate = 24_000
40 | channels = 1
41 | model = self._create_encodec_model(sample_rate, channels)
42 | for _ in range(10):
43 | length = random.randrange(1, 10_000)
44 | x = torch.randn(2, channels, length)
45 | res = model(x)
46 | assert res.x.shape == x.shape
47 |
48 | def test_model_renorm(self):
49 | random.seed(1234)
50 | sample_rate = 24_000
51 | channels = 1
52 | model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False)
53 | model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True)
54 |
55 | for _ in range(10):
56 | length = random.randrange(1, 10_000)
57 | x = torch.randn(2, channels, length)
58 | codes, scales = model_nonorm.encode(x)
59 | codes, scales = model_renorm.encode(x)
60 | assert scales is not None
61 |
--------------------------------------------------------------------------------
/audiocraft/grids/musicgen/musicgen_base_cached_32khz.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from ._explorers import LMExplorer
8 | from ...environment import AudioCraftEnvironment
9 |
10 |
11 | @LMExplorer
12 | def explorer(launcher):
13 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
14 | launcher.slurm_(gpus=32, partition=partitions)
15 | launcher.bind_(solver='musicgen/musicgen_base_32khz')
16 |     # replace this with the desired music dataset
17 | launcher.bind_(dset='internal/music_400k_32khz')
18 |
19 | fsdp = {'autocast': False, 'fsdp.use': True}
20 | medium = {'model/lm/model_scale': 'medium'}
21 | large = {'model/lm/model_scale': 'large'}
22 |
23 | cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
24 | wd_low = {'conditioners.description.t5.word_dropout': 0.2}
25 |
26 | adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}
27 |
28 | # BEGINNING OF CACHE WRITING JOBS.
29 | cache_write = {
30 | 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
31 | 'cache.write': True,
32 | 'generate.every': 500,
33 | 'evaluate.every': 500,
34 | 'logging.log_updates': 50,
35 | }
36 |
37 | cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'})
38 | cache_sub.bind_({'deadlock.use': True})
39 | cache_sub.slurm_(gpus=8)
40 | with launcher.job_array():
41 | num_shards = 10 # total number of jobs running in parallel.
42 | for shard in range(0, num_shards):
43 | launcher(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard})
44 |
45 | # REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE,
46 | # OR SUFFICIENTLY AHEAD.
47 | return
48 |
49 | cache = {
50 | 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k',
51 | }
52 | launcher.bind_(fsdp, cache)
53 |
54 | launcher.slurm_(gpus=32).bind_(label='32gpus')
55 | with launcher.job_array():
56 | sub = launcher.bind()
57 | sub()
58 |
59 | launcher.slurm_(gpus=64).bind_(label='64gpus')
60 | with launcher.job_array():
61 | sub = launcher.bind()
62 | sub(medium, adam)
63 |
64 | launcher.slurm_(gpus=96).bind_(label='96gpus')
65 | with launcher.job_array():
66 | sub = launcher.bind()
67 | sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
68 |
--------------------------------------------------------------------------------
/audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Evaluation with objective metrics for the pretrained AudioGen models.
9 | This grid takes its signatures from the training grid and runs an evaluation-only stage.
10 |
11 | When running the grid for the first time, please use:
12 | REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval
13 | and re-use the REGEN=1 option when the grid is changed to force regenerating it.
14 |
15 | Note that you need the external libraries required by the metrics properly set up to use all
16 | the objective metrics activated in this grid. Refer to the README for more information.
17 | """
18 |
19 | import os
20 |
21 | from ..musicgen._explorers import GenerationEvalExplorer
22 | from ...environment import AudioCraftEnvironment
23 | from ... import train
24 |
25 |
26 | def eval(launcher, batch_size: int = 32):
27 | opts = {
28 | 'dset': 'audio/audiocaps_16khz',
29 | 'solver/audiogen/evaluation': 'objective_eval',
30 | 'execute_only': 'evaluate',
31 | '+dataset.evaluate.batch_size': batch_size,
32 | '+metrics.fad.tf.batch_size': 32,
33 | }
34 | # binary for FAD computation: replace this path with your own path
35 | metrics_opts = {
36 | 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
37 | }
38 | opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
39 | opt2 = {'transformer_lm.two_step_cfg': True}
40 |
41 | sub = launcher.bind(opts)
42 | sub.bind_(metrics_opts)
43 |
44 | # base objective metrics
45 | sub(opt1, opt2)
46 |
47 |
48 | @GenerationEvalExplorer
49 | def explorer(launcher):
50 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
51 | launcher.slurm_(gpus=4, partition=partitions)
52 |
53 | if 'REGEN' not in os.environ:
54 | folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
55 | with launcher.job_array():
56 | for sig in folder.iterdir():
57 | if not sig.is_symlink():
58 | continue
59 | xp = train.main.get_xp_from_sig(sig.name)
60 | launcher(xp.argv)
61 | return
62 |
63 | audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz")
64 | audiogen_base.bind_({'autocast': False, 'fsdp.use': True})
65 |
66 | audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'})
67 | audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'})
68 | eval(audiogen_base_medium, batch_size=128)
69 |
--------------------------------------------------------------------------------
/audiocraft/grids/_base_explorers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from abc import ABC, abstractmethod
8 | import time
9 | import typing as tp
10 | from dora import Explorer
11 | import treetable as tt
12 |
13 |
14 | def get_sheep_ping(sheep) -> tp.Optional[str]:
15 | """Return the amount of time since the Sheep made some update
16 | to its log. Returns a str using the relevant time unit."""
17 | ping = None
18 | if sheep.log is not None and sheep.log.exists():
19 | delta = time.time() - sheep.log.stat().st_mtime
20 | if delta > 3600 * 24:
21 | ping = f'{delta / (3600 * 24):.1f}d'
22 | elif delta > 3600:
23 | ping = f'{delta / (3600):.1f}h'
24 | elif delta > 60:
25 | ping = f'{delta / 60:.1f}m'
26 | else:
27 | ping = f'{delta:.1f}s'
28 | return ping
29 |
30 |
31 | class BaseExplorer(ABC, Explorer):
32 | """Base explorer for AudioCraft grids.
33 |
34 | All task specific solvers are expected to implement the `get_grid_metrics`
35 | method to specify logic about metrics to display for a given task.
36 |
37 | If additional stages are used, the child explorer must define how to handle
38 | these new stages in the `process_history` and `process_sheep` methods.
39 | """
40 | def stages(self):
41 | return ["train", "valid", "evaluate"]
42 |
43 | def get_grid_meta(self):
44 | """Returns the list of Meta information to display for each XP/job.
45 | """
46 | return [
47 | tt.leaf("index", align=">"),
48 | tt.leaf("name", wrap=140),
49 | tt.leaf("state"),
50 | tt.leaf("sig", align=">"),
51 | tt.leaf("sid", align="<"),
52 | ]
53 |
54 | @abstractmethod
55 | def get_grid_metrics(self):
56 | """Return the metrics that should be displayed in the tracking table.
57 | """
58 | ...
59 |
60 | def process_sheep(self, sheep, history):
61 | train = {
62 | "epoch": len(history),
63 | }
64 | parts = {"train": train}
65 | for metrics in history:
66 | for key, sub in metrics.items():
67 | part = parts.get(key, {})
68 | if 'duration' in sub:
69 | # Convert to minutes for readability.
70 | sub['duration'] = sub['duration'] / 60.
71 | part.update(sub)
72 | parts[key] = part
73 | ping = get_sheep_ping(sheep)
74 | if ping is not None:
75 | for name in self.stages():
76 | if name not in parts:
77 | parts[name] = {}
78 | # Add the ping to each part for convenience.
79 | parts[name]['ping'] = ping
80 | return parts
81 |
--------------------------------------------------------------------------------
/config/solver/musicgen/default.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - /solver/default
5 | - /conditioner: none
6 | - _self_
7 | - /solver/musicgen/evaluation: none
8 | - override /dset: audio/default
9 |
10 | autocast: true
11 | autocast_dtype: float16
12 |
13 | solver: musicgen
14 | sample_rate: ???
15 | channels: ???
16 | compression_model_checkpoint: ???
17 |
18 | tokens:
19 | padding_with_special_token: false
20 |
21 | cache:
22 | path:
23 | write: false
24 | write_shard: 0
25 | write_num_shards: 1
26 |
27 |
28 | dataset:
29 | batch_size: 128
30 | num_workers: 10
31 | segment_duration: 30
32 | min_segment_ratio: 0.8 # lower values such as 0.5 result in generations with a lot of silence.
33 | return_info: true
34 | train:
35 |     num_samples: 1000000 # need an arbitrarily large number here for AudioDataset
36 | valid:
37 | num_samples: 10000
38 | generate:
39 | num_samples: 50
40 |
41 | metrics:
42 | fad:
43 | use_gt: false
44 | model: tf
45 | tf:
46 | bin: null # path to local frechet_audio_distance code
47 | model_path: //reference/fad/vggish_model.ckpt
48 | kld:
49 | use_gt: false
50 | model: passt
51 | passt:
52 | pretrained_length: 20
53 | text_consistency:
54 | use_gt: false
55 | model: clap
56 | clap:
57 | model_path: //reference/clap/music_audioset_epoch_15_esc_90.14.pt
58 | model_arch: 'HTSAT-base'
59 | enable_fusion: false
60 | chroma_cosine:
61 | use_gt: false
62 | model: chroma_base
63 | chroma_base:
64 | sample_rate: ${sample_rate}
65 | n_chroma: 12
66 | radix2_exp: 14
67 | argmax: true
68 |
69 | generate:
70 | every: 25
71 | num_workers: 5
72 | path: samples
73 | audio:
74 | format: wav
75 | strategy: loudness
76 | sample_rate: ${sample_rate}
77 | loudness_headroom_db: 14
78 | lm:
79 | prompted_samples: true
80 | unprompted_samples: true
81 | gen_gt_samples: false
82 | prompt_duration: null # if not set, will use dataset.generate.segment_duration / 4
83 | gen_duration: null # if not set, will use dataset.generate.segment_duration
84 | remove_prompts: false
85 | # generation params
86 | use_sampling: false
87 | temp: 1.0
88 | top_k: 0
89 | top_p: 0.0
90 | evaluate:
91 | every: 25
92 | num_workers: 5
93 | metrics:
94 | base: false
95 | fad: false
96 | kld: false
97 | text_consistency: false
98 | chroma_cosine: false
99 |
100 | checkpoint:
101 | save_last: true
102 | save_every: 50
103 | keep_last: 10
104 | keep_every_states: null
105 |
106 | optim:
107 | epochs: 200
108 | updates_per_epoch: 2000
109 | lr: 1e-4
110 | optimizer: adamw
111 | max_norm: 1.0
112 | eager_sync: true
113 | adam:
114 | betas: [0.9, 0.95]
115 | weight_decay: 0.1
116 | eps: 1e-8
117 |
118 | schedule:
119 | lr_scheduler: null
120 |
--------------------------------------------------------------------------------
/audiocraft/utils/export.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Utility to export a training checkpoint to a lightweight release checkpoint.
9 | """
10 |
11 | from pathlib import Path
12 | import typing as tp
13 |
14 | from omegaconf import OmegaConf
15 | import torch
16 |
17 | from audiocraft import __version__
18 |
19 |
20 | def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
21 | """Export only the best state from the given EnCodec checkpoint. This
22 | should be used if you trained your own EnCodec model.
23 | """
24 | pkg = torch.load(checkpoint_path, 'cpu')
25 | new_pkg = {
26 | 'best_state': pkg['best_state']['model'],
27 | 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
28 | 'version': __version__,
29 | 'exported': True,
30 | }
31 | Path(out_file).parent.mkdir(exist_ok=True, parents=True)
32 | torch.save(new_pkg, out_file)
33 | return out_file
34 |
35 |
36 | def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
37 | """Export a compression model (potentially EnCodec) from a pretrained model.
38 |     This is required for packaging the audio tokenizer along with a MusicGen or AudioGen model.
39 |     Do not include the //pretrained/ prefix. For instance, if you trained a model
40 |     with `facebook/encodec_32khz`, just pass that as the name. Same for `dac_44khz`.
41 |
42 | In that case, this will not actually include a copy of the model, simply the reference
43 | to the model used.
44 | """
45 | if Path(pretrained_encodec).exists():
46 | pkg = torch.load(pretrained_encodec)
47 | assert 'best_state' in pkg
48 | assert 'xp.cfg' in pkg
49 | assert 'version' in pkg
50 | assert 'exported' in pkg
51 | else:
52 | pkg = {
53 | 'pretrained': pretrained_encodec,
54 | 'exported': True,
55 | 'version': __version__,
56 | }
57 | Path(out_file).parent.mkdir(exist_ok=True, parents=True)
58 | torch.save(pkg, out_file)
59 |
60 |
61 | def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
62 | """Export only the best state from the given MusicGen or AudioGen checkpoint.
63 | """
64 | pkg = torch.load(checkpoint_path, 'cpu')
65 | if pkg['fsdp_best_state']:
66 | best_state = pkg['fsdp_best_state']['model']
67 | else:
68 | assert pkg['best_state']
69 | best_state = pkg['best_state']['model']
70 | new_pkg = {
71 | 'best_state': best_state,
72 | 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
73 | 'version': __version__,
74 | 'exported': True,
75 | }
76 |
77 | Path(out_file).parent.mkdir(exist_ok=True, parents=True)
78 | torch.save(new_pkg, out_file)
79 | return out_file
80 |
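For reference, a minimal usage sketch combining the two export helpers above. The checkpoint path, output folder, and output file names below are assumptions for illustration, not a convention confirmed by this file:

```python
from pathlib import Path

from audiocraft.utils import export

# Hypothetical paths: point these at your own training checkpoint and output folder.
checkpoint = Path('/checkpoints/my_musicgen_xp/checkpoint.th')
out_dir = Path('/checkpoints/my_released_models/my_musicgen')

# Export the LM best state as a lightweight release checkpoint.
export.export_lm(checkpoint, out_dir / 'state_dict.bin')

# Reference the pretrained EnCodec used as the audio tokenizer (no //pretrained/ prefix).
export.export_pretrained_compression_model('facebook/encodec_32khz',
                                            out_dir / 'compression_state_dict.bin')
```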
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | # WARNING: This is the base configuration file shared across ALL solvers in AudioCraft
2 | # Please don't update this file directly. Instead use distinct configuration files
3 | # to override the below configuration.
4 | defaults:
5 | - _self_
6 | - dset: default
7 | - solver: default
8 |
9 | device: cuda
10 | dtype: float32
11 | autocast: false
12 | autocast_dtype: bfloat16
13 | seed: 2036
14 | show: false # just show the model and its size and exit
15 | continue_from: # continue from a given sig or path
16 | execute_only: # can be set to generate/evaluate/valid to run that stage
17 | execute_inplace: false # don't enforce continue_from to be set
18 |                        # to enable inplace execution of the stage. This assumes
19 |                        # that you know what you are doing and executes the stage
20 |                        # preserving the original xp sig.
21 | benchmark_no_load: false # if set to true, will repeat the same batch instead of loading new ones
22 |
23 | efficient_attention_backend: torch # can be torch or xformers.
24 | num_threads: 1 # called with torch.set_num_threads.
25 | mp_start_method: forkserver # multiprocessing method (spawn, fork or forkserver).
26 |
27 |
28 | label: # use this if you want to run the same exp twice, distinguished by a name.
29 |
30 | # logging parameters
31 | logging:
32 | level: INFO
33 | log_updates: 10
34 | log_tensorboard: false
35 | log_wandb: false
36 | tensorboard:
37 | with_media_logging: false
38 | name: # optional name for the experiment
39 | sub_dir: # optional sub directory to store tensorboard data
40 | wandb:
41 | with_media_logging: true
42 | project: # project name
43 | name: # optional name for the experiment
44 | group: # optional group
45 |
46 | # SLURM launcher configuration.
47 | slurm:
48 | gpus: 4 # convenience parameter, number of GPUs to use.
49 | mem_per_gpu: 40 # in GB, total mem is automatically scaled with `gpus`.
50 | time: 3600
51 | constraint:
52 | partition:
53 | comment:
54 | setup: []
55 | exclude: ''
56 |
57 | # dora parameters
58 | dora:
59 | # Output folder for all artifacts of an experiment.
60 | dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs
61 | # The following entries will be ignored by dora when computing the unique XP signature.
62 | # Note that slurm.* and dora.* are automatically ignored.
63 | exclude: [
64 | 'device', 'wandb.*', 'tensorboard.*', 'logging.*',
65 | 'dataset.num_workers', 'eval.num_workers', 'special.*',
66 | 'metrics.visqol.bin', 'metrics.fad.bin',
67 | 'execute_only', 'execute_best', 'generate.every',
68 | 'optim.eager_sync', 'profiler.*', 'deadlock.*',
69 | 'efficient_attention_backend', 'num_threads', 'mp_start_method',
70 | ]
71 | use_rendezvous: false
72 | # for grids, always run from a clean repo, allowing reliable runs and storing
73 | # the exact commit. Your repo must be absolutely pristine clean.
74 |   # the exact commit. Your repo must be absolutely pristine.
75 | git_save: true
76 |
--------------------------------------------------------------------------------
/config/solver/default.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | # WARNING: This is a base configuration file shared across ALL solvers in AudioCraft
4 | # Please don't update this file directly. Instead use distinct configuration files
5 | # to override the below configuration.
6 | solver: ???
7 |
8 | fsdp:
9 | use: false # should we use FSDP.
10 | param_dtype: float16 # equivalent to autocast_dtype for FSDP.
11 | reduce_dtype: float32 # gradient averaging dtype, float32 will give max stability.
12 |   buffer_dtype: float32 # dtype used for buffers, we don't have many buffers, so let's leave it.
13 | sharding_strategy: shard_grad_op # can be shard_grad_op or full_shard.
14 |                                     # full_shard will use less memory but may be slower.
15 | per_block: true # If True, uses nested FSDP.
16 |
17 | profiler:
18 | enabled: false
19 |
20 | deadlock:
21 | use: false
22 | timeout: 600
23 |
24 | dataset:
25 | batch_size: ???
26 | num_workers: 10
27 | segment_duration: null
28 | num_samples: null
29 | return_info: false
30 | shuffle: false
31 | sample_on_duration: true
32 | sample_on_weight: true
33 | min_segment_ratio: 0.5
34 | train:
35 | num_samples: null
36 | shuffle: true
37 | shuffle_seed: 0 # if you want to sample the data differently.
38 | permutation_on_files: false
39 | valid:
40 | num_samples: null
41 | evaluate:
42 | num_samples: null
43 | generate:
44 | num_samples: null
45 | return_info: true
46 |
47 | checkpoint:
48 | save_last: true
49 | save_every: null
50 | keep_last: null
51 | keep_every_states: null
52 |
53 | generate:
54 | every: null
55 | path: 'samples'
56 | audio:
57 | format: 'mp3'
58 | strategy: 'clip'
59 | sample_rate: null
60 | lm:
61 | use_sampling: false
62 | temp: 1.0
63 | top_k: 0
64 | top_p: 0.0
65 | evaluate:
66 | every: null
67 | num_workers: 5
68 | truncate_audio: null
69 | fixed_generation_duration: null # in secs
70 | metrics:
71 | base: true # run default evaluation (e.g. like train/valid stage)
72 |
73 | optim:
74 | epochs: ???
75 | updates_per_epoch: null
76 | lr: ???
77 | optimizer: ???
78 | adam:
79 | betas: [0.9, 0.999]
80 | weight_decay: 0.
81 | ema:
82 | use: false # whether to use EMA or not
83 | updates: ${optim.updates_per_epoch} # frequency of updates of the EMA
84 | device: cpu # device for EMA, can be put on GPU if more frequent updates
85 | decay: 0.99 # EMA decay value, if null, no EMA is used
86 |
87 | schedule:
88 | lr_scheduler: null
89 | step:
90 | step_size: null
91 | gamma: null
92 | exponential:
93 | lr_decay: null
94 | cosine:
95 | warmup: null
96 | lr_min_ratio: 0.0
97 | cycle_length: 1.0
98 | polynomial_decay:
99 | warmup: null
100 | zero_lr_warmup_steps: 0
101 | end_lr: 0.0
102 | power: 1
103 | inverse_sqrt:
104 | warmup: null
105 | warmup_init_lr: 0.0
106 | linear_warmup:
107 | warmup: null
108 | warmup_init_lr: 0.0
109 |
--------------------------------------------------------------------------------
/audiocraft/modules/chroma.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import typing as tp
7 |
8 | from einops import rearrange
9 | from librosa import filters
10 | import torch
11 | from torch import nn
12 | import torch.nn.functional as F
13 | import torchaudio
14 |
15 |
16 | class ChromaExtractor(nn.Module):
17 | """Chroma extraction and quantization.
18 |
19 | Args:
20 | sample_rate (int): Sample rate for the chroma extraction.
21 | n_chroma (int): Number of chroma bins for the chroma extraction.
22 |         radix2_exp (int): Size of the STFT window for the chroma extraction, as a power of 2 (e.g. 12 -> 2^12).
23 |         nfft (int, optional): Size of the FFT.
24 | winlen (int, optional): Window length.
25 | winhop (int, optional): Window hop size.
26 | argmax (bool, optional): Whether to use argmax. Defaults to False.
27 | norm (float, optional): Norm for chroma normalization. Defaults to inf.
28 | """
29 | def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
30 | winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
31 | norm: float = torch.inf):
32 | super().__init__()
33 | self.winlen = winlen or 2 ** radix2_exp
34 | self.nfft = nfft or self.winlen
35 | self.winhop = winhop or (self.winlen // 4)
36 | self.sample_rate = sample_rate
37 | self.n_chroma = n_chroma
38 | self.norm = norm
39 | self.argmax = argmax
40 | self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
41 | n_chroma=self.n_chroma)), persistent=False)
42 | self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
43 | hop_length=self.winhop, power=2, center=True,
44 | pad=0, normalized=True)
45 |
46 | def forward(self, wav: torch.Tensor) -> torch.Tensor:
47 | T = wav.shape[-1]
48 | # in case we are getting a wav that was dropped out (nullified)
49 |         # from the conditioner, make sure wav length is no less than nfft
50 | if T < self.nfft:
51 | pad = self.nfft - T
52 | r = 0 if pad % 2 == 0 else 1
53 | wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
54 | assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
55 |
56 | spec = self.spec(wav).squeeze(1)
57 | raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
58 | norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
59 | norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
60 |
61 | if self.argmax:
62 | idx = norm_chroma.argmax(-1, keepdim=True)
63 | norm_chroma[:] = 0
64 | norm_chroma.scatter_(dim=-1, index=idx, value=1)
65 |
66 | return norm_chroma
67 |
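For reference, a minimal usage sketch of `ChromaExtractor`; the sample rate, batch size and waveform length are illustrative assumptions:

```python
import torch

from audiocraft.modules.chroma import ChromaExtractor

# Extract a 12-bin chromagram from a batch of mono waveforms.
extractor = ChromaExtractor(sample_rate=32000, n_chroma=12, radix2_exp=12, argmax=False)
wav = torch.randn(2, 1, 32000)   # [B, C, T], one second of audio at 32 kHz
chroma = extractor(wav)          # [B, frames, n_chroma]
print(chroma.shape)
```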
--------------------------------------------------------------------------------
/docs/DATASETS.md:
--------------------------------------------------------------------------------
1 | # AudioCraft datasets
2 |
3 | Our dataset manifest files consist of 1-json-per-line files, potentially gzipped,
4 | as `data.jsons` or `data.jsons.gz` files. Each JSON line contains the path to the audio
5 | file and its associated metadata. The manifest files are then provided in the configuration,
6 | as the `datasource` sub-configuration. A datasource contains the pointers to the paths of
7 | the manifest files for each AudioCraft stage (or split) along with additional information
8 | (e.g. the maximum sample rate to use against this dataset). All the datasources are under the
9 | `dset` group config, with a dedicated configuration file for each dataset.
10 |
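To make the one-JSON-per-line format concrete, here is a minimal sketch of writing such a manifest in Python. The `duration` and `sample_rate` fields and all values are assumptions for illustration; inspect a manifest generated by the command described below for the exact schema:

```python
import gzip
import json

# Illustrative entries only: paths and metadata values are placeholders.
entries = [
    {"path": "path/to/audio_1.wav", "duration": 12.3, "sample_rate": 32000},
    {"path": "path/to/audio_2.wav", "duration": 45.6, "sample_rate": 32000},
]
# One JSON object per line, gzipped.
with gzip.open("data.jsonl.gz", "wt") as fout:
    for entry in entries:
        fout.write(json.dumps(entry) + "\n")
```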
11 | ## Getting started
12 |
13 | ### Example
14 |
15 | See the provided example manifest, which can be used with the example dataset
16 | available under the [dataset folder](../dataset/example).
17 |
18 | The manifest files are stored in the [egs folder](../egs/example).
19 |
20 | ```shell
21 | egs/
22 | example/data.json.gz
23 | ```
24 |
25 | A datasource is defined in the configuration folder, in the dset group config for this dataset
26 | at [config/dset/audio/example](../config/dset/audio/example.yaml):
27 |
28 | ```shell
29 | # @package __global__
30 |
31 | datasource:
32 | max_sample_rate: 44100
33 | max_channels: 2
34 |
35 | train: egs/example
36 | valid: egs/example
37 | evaluate: egs/example
38 | generate: egs/example
39 | ```
40 |
41 | For a proper dataset, one should create a manifest for each of the splits and specify the correct path
42 | to the given manifest in the datasource for each split.
43 |
44 | Then, using a dataset through the configuration can be done by pointing to the
45 | corresponding dataset configuration:
46 | ```shell
47 | dset= # should match the yaml file name
48 |
49 | # for example
50 | dset=audio/example
51 | ```
52 |
53 | ### Creating manifest files
54 |
55 | Assuming you want to create manifest files to load with AudioCraft's AudioDataset, you can use
56 | the following command to create new manifest files from a given folder containing audio files:
57 |
58 | ```shell
59 | python -m audiocraft.data.audio_dataset <path to your audio folder> egs/my_dataset/my_dataset_split/data.jsonl.gz
60 |
61 | # For example to generate the manifest for dset=audio/example
62 | # note: we don't use any split and we don't compress the jsonl file for this dummy example
63 | python -m audiocraft.data.audio_dataset dataset/example egs/example/data.jsonl
64 |
65 | # More info with: python -m audiocraft.data.audio_dataset --help
66 | ```
67 |
68 | ## Additional information
69 |
70 | ### MusicDataset and metadata
71 |
72 | The MusicDataset is an AudioDataset with additional metadata. The MusicDataset expects
73 | the additional metadata to be stored in a JSON file that has the same path as the corresponding
74 | audio file, but with a `.json` extension.
75 |
76 | ### SoundDataset and metadata
77 |
78 | The SoundDataset is an AudioDataset with description metadata. Similarly to the MusicDataset,
79 | the SoundDataset expects the additional metadata to be stored in a JSON file that has the same
80 | path as the corresponding audio file, but with a `.json` extension. Additionally, the SoundDataset
81 | supports an extra parameter, `external_metadata_source`, pointing to a separate folder containing
82 | all the JSON metadata files, provided they have the same filename as the audio files.
83 |
--------------------------------------------------------------------------------
/audiocraft/grids/musicgen/_explorers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import typing as tp
8 |
9 | import treetable as tt
10 |
11 | from .._base_explorers import BaseExplorer
12 |
13 |
14 | class LMExplorer(BaseExplorer):
15 | eval_metrics: tp.List[str] = []
16 |
17 | def stages(self) -> tp.List[str]:
18 | return ['train', 'valid']
19 |
20 | def get_grid_metrics(self):
21 | """Return the metrics that should be displayed in the tracking table."""
22 | return [
23 | tt.group(
24 | 'train',
25 | [
26 | tt.leaf('epoch'),
27 | tt.leaf('duration', '.1f'), # duration in minutes
28 | tt.leaf('ping'),
29 | tt.leaf('ce', '.4f'), # cross entropy
30 | tt.leaf("ppl", '.3f'), # perplexity
31 | ],
32 | align='>',
33 | ),
34 | tt.group(
35 | 'valid',
36 | [
37 | tt.leaf('ce', '.4f'),
38 | tt.leaf('ppl', '.3f'),
39 | tt.leaf('best_ppl', '.3f'),
40 | ],
41 | align='>',
42 | ),
43 | ]
44 |
45 | def process_sheep(self, sheep, history):
46 | parts = super().process_sheep(sheep, history)
47 |
48 | track_by = {'ppl': 'lower'} # values should be in ['lower', 'higher']
49 | best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()}
50 |
51 | def comparator(mode, a, b):
52 | return a < b if mode == 'lower' else a > b
53 |
54 | for metrics in history:
55 | for key, sub in metrics.items():
56 | for metric in track_by:
57 | # for the validation set, keep track of best metrics (ppl in this example)
58 | # this is so we can conveniently compare metrics between runs in the grid
59 | if key == 'valid' and metric in sub and comparator(
60 | track_by[metric], sub[metric], best_metrics[metric]
61 | ):
62 | best_metrics[metric] = sub[metric]
63 |
64 | if 'valid' in parts:
65 | parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()})
66 | return parts
67 |
68 |
69 | class GenerationEvalExplorer(BaseExplorer):
70 | eval_metrics: tp.List[str] = []
71 |
72 | def stages(self) -> tp.List[str]:
73 | return ['evaluate']
74 |
75 | def get_grid_metrics(self):
76 | """Return the metrics that should be displayed in the tracking table."""
77 | return [
78 | tt.group(
79 | 'evaluate',
80 | [
81 | tt.leaf('epoch', '.3f'),
82 | tt.leaf('duration', '.1f'),
83 | tt.leaf('ping'),
84 | tt.leaf('ce', '.4f'),
85 | tt.leaf('ppl', '.3f'),
86 | tt.leaf('fad', '.3f'),
87 | tt.leaf('kld', '.3f'),
88 | tt.leaf('text_consistency', '.3f'),
89 | tt.leaf('chroma_cosine', '.3f'),
90 | ],
91 | align='>',
92 | ),
93 | ]
94 |
--------------------------------------------------------------------------------
/audiocraft/optim/ema.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # ModelEMA implementation is taken from
8 | # https://github.com/facebookresearch/demucs
9 |
10 | from collections import defaultdict
11 | import typing as tp
12 |
13 | import torch
14 | import torch.nn as nn
15 |
16 |
17 | def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "") -> set:
18 | names: set = set()
19 | for (name, sub_module) in module.named_modules():
20 | if name == '':
21 | buffer_names = module._non_persistent_buffers_set
22 | buffer_names = {f"{root}.{buff_name}" if len(root) > 0 else buff_name
23 | for buff_name in buffer_names}
24 | names.update(buffer_names)
25 | else:
26 | sub_name = f"{root}.{name}" if len(root) > 0 else name
27 | sub_buffer_names = _get_all_non_persistent_buffers_set(sub_module, sub_name)
28 | names.update(sub_buffer_names)
29 | return names
30 |
31 |
32 | def _get_named_tensors(module: nn.Module):
33 | non_persistent_buffers_set = _get_all_non_persistent_buffers_set(module)
34 | named_buffers = [(name, buffer) for (name, buffer) in module.named_buffers()
35 | if name not in non_persistent_buffers_set]
36 | named_parameters = list(module.named_parameters())
37 | return named_parameters + named_buffers
38 |
39 |
40 | class ModuleDictEMA:
41 | """Exponential Moving Average over a nn.ModuleDict.
42 |
43 | You can switch to the EMA weights temporarily.
44 | """
45 | def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999,
46 | unbias: bool = True, device: tp.Union[torch.device, str] = 'cpu'):
47 | self.decay = decay
48 | self.module_dict = module_dict
49 | self.state: dict = defaultdict(dict)
50 | self.count = 0
51 | self.device = device
52 | self.unbias = unbias
53 | self._init()
54 |
55 | def _init(self):
56 | for module_name, module in self.module_dict.items():
57 | for key, val in _get_named_tensors(module):
58 | if not val.is_floating_point():
59 | continue
60 | device = self.device or val.device
61 | if key not in self.state[module_name]:
62 | self.state[module_name][key] = val.detach().to(device, copy=True)
63 |
64 | def step(self):
65 | if self.unbias:
66 | self.count = self.count * self.decay + 1
67 | w = 1 / self.count
68 | else:
69 | w = 1 - self.decay
70 | for module_name, module in self.module_dict.items():
71 | for key, val in _get_named_tensors(module):
72 | if not val.is_floating_point():
73 | continue
74 | device = self.device or val.device
75 | self.state[module_name][key].mul_(1 - w)
76 | self.state[module_name][key].add_(val.detach().to(device), alpha=w)
77 |
78 | def state_dict(self):
79 | return {'state': self.state, 'count': self.count}
80 |
81 | def load_state_dict(self, state):
82 | self.count = state['count']
83 | for module_name, module in state['state'].items():
84 | for key, val in module.items():
85 | self.state[module_name][key].copy_(val)
86 |
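For reference, a minimal usage sketch of `ModuleDictEMA`; the toy model, optimizer and number of steps are illustrative assumptions:

```python
import torch
from torch import nn

from audiocraft.optim.ema import ModuleDictEMA

# Keep an EMA copy of the weights (and persistent buffers) of a ModuleDict.
modules = nn.ModuleDict({'model': nn.Linear(8, 8)})
ema = ModuleDictEMA(modules, decay=0.999, device='cpu')

optimizer = torch.optim.Adam(modules.parameters(), lr=1e-3)
for _ in range(10):
    loss = modules['model'](torch.randn(4, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.step()  # update the EMA state after each optimizer step

# The EMA tensors live in `ema.state`, keyed by module name then parameter/buffer name.
print(ema.state['model']['weight'].shape)
```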
--------------------------------------------------------------------------------
/audiocraft/losses/sisnr.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | import typing as tp
9 |
10 | import torch
11 | from torch import nn
12 | from torch.nn import functional as F
13 |
14 |
15 | def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
16 | """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
17 | with K the kernel size, by extracting frames with the given stride.
18 |     This will pad the input so that `F = ceil(T / stride)`.
19 | see https://github.com/pytorch/pytorch/issues/60466
20 | """
21 | *shape, length = a.shape
22 | n_frames = math.ceil(length / stride)
23 | tgt_length = (n_frames - 1) * stride + kernel_size
24 | a = F.pad(a, (0, tgt_length - length))
25 | strides = list(a.stride())
26 | assert strides[-1] == 1, "data should be contiguous"
27 | strides = strides[:-1] + [stride, 1]
28 | return a.as_strided([*shape, n_frames, kernel_size], strides)
29 |
30 |
31 | def _center(x: torch.Tensor) -> torch.Tensor:
32 | return x - x.mean(-1, True)
33 |
34 |
35 | def _norm2(x: torch.Tensor) -> torch.Tensor:
36 | return x.pow(2).sum(-1, True)
37 |
38 |
39 | class SISNR(nn.Module):
40 | """SISNR loss.
41 |
42 | Input should be [B, C, T], output is scalar.
43 |
44 |     ..Warning:: This function returns the opposite of the SI-SNR (i.e. `-1 * regular_SI_SNR`).
45 |         Consequently, lower scores are better in terms of reconstruction quality,
46 |         in particular, it should be negative if training goes well. It is done this way so
47 |         that this module can also be used as a loss function for training a model.
48 |
49 | Args:
50 | sample_rate (int): Sample rate.
51 | segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
52 | entire audio only.
53 | overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
54 | epsilon (float): Epsilon value for numerical stability.
55 | """
56 | def __init__(
57 | self,
58 | sample_rate: int = 16000,
59 | segment: tp.Optional[float] = 20,
60 | overlap: float = 0.5,
61 | epsilon: float = torch.finfo(torch.float32).eps,
62 | ):
63 | super().__init__()
64 | self.sample_rate = sample_rate
65 | self.segment = segment
66 | self.overlap = overlap
67 | self.epsilon = epsilon
68 |
69 | def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
70 | B, C, T = ref_sig.shape
71 | assert ref_sig.shape == out_sig.shape
72 |
73 | if self.segment is None:
74 | frame = T
75 | stride = T
76 | else:
77 | frame = int(self.segment * self.sample_rate)
78 | stride = int(frame * (1 - self.overlap))
79 |
80 | epsilon = self.epsilon * frame # make epsilon prop to frame size.
81 |
82 | gt = _unfold(ref_sig, frame, stride)
83 | est = _unfold(out_sig, frame, stride)
84 | if self.segment is None:
85 | assert gt.shape[-1] == 1
86 |
87 | gt = _center(gt)
88 | est = _center(est)
89 | dot = torch.einsum("bcft,bcft->bcf", gt, est)
90 |
91 | proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
92 | noise = est - proj
93 |
94 | sisnr = 10 * (
95 | torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
96 | )
97 | return -1 * sisnr[..., 0].mean()
98 |
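For reference, a minimal usage sketch of the `SISNR` module; the shapes, sample rate and noise level are illustrative assumptions:

```python
import torch

from audiocraft.losses.sisnr import SISNR

loss_fn = SISNR(sample_rate=16000)
ref = torch.randn(2, 1, 16000)            # [B, C, T] reference waveforms
out = ref + 0.01 * torch.randn_like(ref)  # an estimate close to the reference
loss = loss_fn(out, ref)                  # negated SI-SNR: lower (more negative) is better
print(loss.item())
```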
--------------------------------------------------------------------------------
/audiocraft/modules/activations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch import Tensor
10 | from typing import Union, Callable
11 |
12 |
13 | class CustomGLU(nn.Module):
14 | """Custom Gated Linear Unit activation.
15 | Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
16 | of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
17 |     function (e.g. sigmoid, swish, etc.).
18 |
19 | Args:
20 | activation (nn.Module): The custom activation to apply in the Gated Linear Unit
21 | dim (int): the dimension on which to split the input. Default: -1
22 |
23 | Shape:
24 |         - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
25 | dimensions
26 | - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
27 |
28 | Examples::
29 | >>> m = CustomGLU(nn.Sigmoid())
30 | >>> input = torch.randn(4, 2)
31 | >>> output = m(input)
32 | """
33 | def __init__(self, activation: nn.Module, dim: int = -1):
34 | super(CustomGLU, self).__init__()
35 | self.dim = dim
36 | self.activation = activation
37 |
38 | def forward(self, x: Tensor):
39 | assert x.shape[self.dim] % 2 == 0 # M = N / 2
40 | a, b = torch.chunk(x, 2, dim=self.dim)
41 | return a * self.activation(b)
42 |
43 |
44 | class SwiGLU(CustomGLU):
45 | """SiLU Gated Linear Unit activation.
46 | Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
47 | the first half of the input matrices, :math:`b` is the second half.
48 |
49 | Args:
50 | dim (int): the dimension on which to split the input. Default: -1
51 | """
52 | def __init__(self, dim: int = -1):
53 | super(SwiGLU, self).__init__(nn.SiLU(), dim)
54 |
55 |
56 | class GeGLU(CustomGLU):
57 | """GeLU Gated Linear Unit activation.
58 | Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
59 | the first half of the input matrices, :math:`b` is the second half.
60 |
61 | Args:
62 | dim (int): the dimension on which to split the input. Default: -1
63 | """
64 | def __init__(self, dim: int = -1):
65 | super(GeGLU, self).__init__(nn.GELU(), dim)
66 |
67 |
68 | class ReGLU(CustomGLU):
69 | """ReLU Gated Linear Unit activation.
70 | Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
71 | the first half of the input matrices, :math:`b` is the second half.
72 |
73 | Args:
74 | dim (int): the dimension on which to split the input. Default: -1
75 | """
76 | def __init__(self, dim: int = -1):
77 | super(ReGLU, self).__init__(nn.ReLU(), dim)
78 |
79 |
80 | def get_activation_fn(
81 | activation: Union[str, Callable[[Tensor], Tensor]]
82 | ) -> Union[str, Callable[[Tensor], Tensor]]:
83 | """Helper function to map an activation string to the activation class.
84 | If the supplied activation is not a string that is recognized, the activation is passed back.
85 |
86 | Args:
87 | activation (str, or Callable[[Tensor], Tensor]): Activation to check
88 | """
89 | if isinstance(activation, str):
90 | if activation == "reglu":
91 | return ReGLU()
92 | elif activation == "geglu":
93 | return GeGLU()
94 | elif activation == "swiglu":
95 | return SwiGLU()
96 | return activation
97 |
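For reference, a minimal usage sketch of `get_activation_fn` with one of the gated activations defined above; the input shape is an illustrative assumption:

```python
import torch

from audiocraft.modules.activations import get_activation_fn

x = torch.randn(4, 16)
act = get_activation_fn("swiglu")  # maps the string to a SwiGLU() instance
y = act(x)                         # splits the last dim into halves a, b and returns a * SiLU(b)
print(y.shape)                     # torch.Size([4, 8])
```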
--------------------------------------------------------------------------------
/audiocraft/quantization/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Base class for all quantizers.
9 | """
10 |
11 | from dataclasses import dataclass, field
12 | import typing as tp
13 |
14 | import torch
15 | from torch import nn
16 |
17 |
18 | @dataclass
19 | class QuantizedResult:
20 | x: torch.Tensor
21 | codes: torch.Tensor
22 | bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
23 | penalty: tp.Optional[torch.Tensor] = None
24 | metrics: dict = field(default_factory=dict)
25 |
26 |
27 | class BaseQuantizer(nn.Module):
28 | """Base class for quantizers.
29 | """
30 |
31 | def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
32 | """
33 |         Given an input tensor x, returns the quantized (or approximately quantized)
34 |         representation along with quantized codes, bandwidth, and any penalty term for the loss.
35 |         It also returns a dict of metrics to update logging, etc.
36 | Frame rate must be passed so that the bandwidth is properly computed.
37 | """
38 | raise NotImplementedError()
39 |
40 | def encode(self, x: torch.Tensor) -> torch.Tensor:
41 | """Encode a given input tensor with the specified sample rate at the given bandwidth."""
42 | raise NotImplementedError()
43 |
44 | def decode(self, codes: torch.Tensor) -> torch.Tensor:
45 | """Decode the given codes to the quantized representation."""
46 | raise NotImplementedError()
47 |
48 | @property
49 | def total_codebooks(self):
50 | """Total number of codebooks."""
51 | raise NotImplementedError()
52 |
53 | @property
54 | def num_codebooks(self):
55 | """Number of active codebooks."""
56 | raise NotImplementedError()
57 |
58 | def set_num_codebooks(self, n: int):
59 | """Set the number of active codebooks."""
60 | raise NotImplementedError()
61 |
62 |
63 | class DummyQuantizer(BaseQuantizer):
64 | """Fake quantizer that actually does not perform any quantization.
65 | """
66 | def __init__(self):
67 | super().__init__()
68 |
69 | def forward(self, x: torch.Tensor, frame_rate: int):
70 | q = x.unsqueeze(1)
71 | return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
72 |
73 | def encode(self, x: torch.Tensor) -> torch.Tensor:
74 | """Encode a given input tensor with the specified sample rate at the given bandwidth.
75 | In the case of the DummyQuantizer, the codes are actually identical
76 | to the input and resulting quantized representation as no quantization is done.
77 | """
78 | return x.unsqueeze(1)
79 |
80 | def decode(self, codes: torch.Tensor) -> torch.Tensor:
81 | """Decode the given codes to the quantized representation.
82 | In the case of the DummyQuantizer, the codes are actually identical
83 | to the input and resulting quantized representation as no quantization is done.
84 | """
85 | return codes.squeeze(1)
86 |
87 | @property
88 | def total_codebooks(self):
89 | """Total number of codebooks."""
90 | return 1
91 |
92 | @property
93 | def num_codebooks(self):
94 |         """Number of active codebooks."""
95 | return self.total_codebooks
96 |
97 | def set_num_codebooks(self, n: int):
98 | """Set the number of active codebooks."""
99 | raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
100 |
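For reference, a minimal usage sketch of `DummyQuantizer`; the latent dimensions and frame rate are illustrative assumptions:

```python
import torch

from audiocraft.quantization.base import DummyQuantizer

# The dummy quantizer only adds/removes a codebook dimension; no quantization happens.
quantizer = DummyQuantizer()
x = torch.randn(2, 128, 50)            # [B, dimension, frames] latent, e.g. from an encoder
result = quantizer(x, frame_rate=50)   # QuantizedResult carrying x, codes and bandwidth
codes = quantizer.encode(x)            # same values, with an extra codebook dimension
assert torch.equal(quantizer.decode(codes), x)
print(result.bandwidth)                # bandwidth in kb/s per batch item
```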
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at . All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
81 |
--------------------------------------------------------------------------------
/config/solver/compression/default.yaml:
--------------------------------------------------------------------------------
1 | # @package __global__
2 |
3 | defaults:
4 | - ../default
5 | - override /dset: audio/default
6 | - _self_
7 |
8 | solver: compression
9 | sample_rate: ???
10 | channels: ???
11 |
12 | # loss balancing
13 | losses:
14 | adv: 4.
15 | feat: 4.
16 | l1: 0.1
17 | mel: 0.
18 | msspec: 2.
19 | sisnr: 0.
20 | balancer:
21 | balance_grads: true
22 | ema_decay: 0.999
23 | per_batch_item: true
24 | total_norm: 1.
25 |
26 | adversarial:
27 | every: 1
28 | adversaries: [msstftd]
29 | adv_loss: hinge
30 | feat_loss: l1
31 |
32 | # losses hyperparameters
33 | l1: {}
34 | l2: {}
35 | mrstft:
36 | factor_sc: .5
37 | factor_mag: .5
38 | normalized: false
39 | mel:
40 | sample_rate: ${sample_rate}
41 | n_fft: 1024
42 | hop_length: 256
43 | win_length: 1024
44 | n_mels: 64
45 | f_min: 64
46 | f_max: null
47 | normalized: false
48 | floor_level: 1e-5
49 | sisnr:
50 | sample_rate: ${sample_rate}
51 | segment: 5.
52 | msspec:
53 | sample_rate: ${sample_rate}
54 | range_start: 6
55 | range_end: 11
56 | n_mels: 64
57 | f_min: 64
58 | f_max: null
59 | normalized: true
60 | alphas: false
61 | floor_level: 1e-5
62 |
63 | # metrics
64 | metrics:
65 | visqol:
66 | mode: audio
67 | bin: null # path to visqol install
68 | model: tcdaudio14_aacvopus_coresv_svrnsim_n.68_g.01_c1.model # visqol v3
69 |
70 | # adversaries hyperparameters
71 | msstftd:
72 | in_channels: 1
73 | out_channels: 1
74 | filters: 32
75 | norm: weight_norm
76 | n_ffts: [1024, 2048, 512, 256, 128]
77 | hop_lengths: [256, 512, 128, 64, 32]
78 | win_lengths: [1024, 2048, 512, 256, 128]
79 | activation: LeakyReLU
80 | activation_params: {negative_slope: 0.3}
81 | msd:
82 | in_channels: 1
83 | out_channels: 1
84 | scale_norms: [spectral_norm, weight_norm, weight_norm]
85 | kernel_sizes: [5, 3]
86 | filters: 16
87 | max_filters: 1024
88 | downsample_scales: [4, 4, 4, 4]
89 | inner_kernel_sizes: null
90 | groups: [4, 4, 4, 4]
91 | strides: null
92 | paddings: null
93 | activation: LeakyReLU
94 | activation_params: {negative_slope: 0.3}
95 | mpd:
96 | in_channels: 1
97 | out_channels: 1
98 | periods: [2, 3, 5, 7, 11]
99 | n_layers: 5
100 | kernel_size: 5
101 | stride: 3
102 | filters: 8
103 | filter_scales: 4
104 | max_filters: 1024
105 | activation: LeakyReLU
106 | activation_params: {negative_slope: 0.3}
107 | norm: weight_norm
108 |
109 | # data hyperparameters
110 | dataset:
111 | batch_size: 64
112 | num_workers: 10
113 | segment_duration: 1
114 | train:
115 | num_samples: 500000
116 | valid:
117 | num_samples: 10000
118 | evaluate:
119 | batch_size: 32
120 | num_samples: 10000
121 | generate:
122 | batch_size: 32
123 | num_samples: 50
124 | segment_duration: 10
125 |
126 | # solver hyperparameters
127 | evaluate:
128 | every: 25
129 | num_workers: 5
130 | metrics:
131 | visqol: false
132 | sisnr: true
133 | generate:
134 | every: 25
135 | num_workers: 5
136 | audio:
137 | sample_rate: ${sample_rate}
138 |
139 | # checkpointing schedule
140 | checkpoint:
141 | save_last: true
142 | save_every: 25
143 | keep_last: 10
144 | keep_every_states: null
145 |
146 | # optimization hyperparameters
147 | optim:
148 | epochs: 200
149 | updates_per_epoch: 2000
150 | lr: 3e-4
151 | max_norm: 0.
152 | optimizer: adam
153 | adam:
154 | betas: [0.5, 0.9]
155 | weight_decay: 0.
156 | ema:
157 | use: true # whether to use EMA or not
158 | updates: 1 # update at every step
159 | device: ${device} # device for EMA, can be put on GPU if more frequent updates
160 | decay: 0.99 # EMA decay value, if null, no EMA is used
161 |
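
The `???` entries above are mandatory values that must be supplied by the solver/dset configs that include this file, and the `${sample_rate}` entries are interpolations resolved against the top-level value. A minimal sketch of this behaviour with OmegaConf (which underlies the Hydra-based config system); the values below are illustrative only:

```python
from omegaconf import OmegaConf

# Illustrative subset of the config above; real values come from the including solver/dset configs.
cfg = OmegaConf.create("""
sample_rate: ???
mel:
  sample_rate: ${sample_rate}
""")
print(OmegaConf.is_missing(cfg, "sample_rate"))  # True: the value must be provided by an override
cfg.sample_rate = 24000                          # e.g. set by the solver config that includes this file
print(cfg.mel.sample_rate)                       # 24000: the interpolation resolves to the top-level value
```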
--------------------------------------------------------------------------------
/audiocraft/metrics/chroma_cosinesim.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torchmetrics
9 |
10 | from ..data.audio_utils import convert_audio
11 | from ..modules.chroma import ChromaExtractor
12 |
13 |
14 | class ChromaCosineSimilarityMetric(torchmetrics.Metric):
15 | """Chroma cosine similarity metric.
16 |
17 | This metric extracts a chromagram for a reference waveform and
18 | a generated waveform and compares each frame using the cosine similarity
19 | function. The output is the mean cosine similarity.
20 |
21 | Args:
22 | sample_rate (int): Sample rate used by the chroma extractor.
23 | n_chroma (int): Number of chroma used by the chroma extractor.
24 | radix2_exp (int): Exponent for the chroma extractor.
25 | argmax (bool): Whether the chroma extractor uses argmax.
26 | eps (float): Epsilon for cosine similarity computation.
27 | """
28 | def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
29 | super().__init__()
30 | self.chroma_sample_rate = sample_rate
31 | self.n_chroma = n_chroma
32 | self.eps = eps
33 | self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
34 | radix2_exp=radix2_exp, argmax=argmax)
35 | self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
36 | self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
37 |
38 | def update(self, preds: torch.Tensor, targets: torch.Tensor,
39 | sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
40 | """Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
41 | if preds.size(0) == 0:
42 | return
43 |
44 | assert preds.shape == targets.shape, (
45 | f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
46 |         assert preds.size(0) == sizes.size(0), (
47 |             f"Number of items in preds ({preds.shape}) mismatch "
48 |             f"with sizes ({sizes.shape})")
49 |         assert preds.size(0) == sample_rates.size(0), (
50 |             f"Number of items in preds ({preds.shape}) mismatch "
51 |             f"with sample_rates ({sample_rates.shape})")
52 |         assert torch.all(sample_rates == sample_rates[0].item()), "All items in the batch should have the same sample rate"
53 |
54 | device = self.weight.device
55 | preds, targets = preds.to(device), targets.to(device) # type: ignore
56 | sample_rate = sample_rates[0].item()
57 | preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
58 | targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
59 | gt_chroma = self.chroma_extractor(targets)
60 | gen_chroma = self.chroma_extractor(preds)
61 | chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
62 | for i in range(len(gt_chroma)):
63 | t = int(chroma_lens[i].item())
64 | cosine_sim = torch.nn.functional.cosine_similarity(
65 | gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
66 | self.cosine_sum += cosine_sim.sum(dim=0) # type: ignore
67 | self.weight += torch.tensor(t) # type: ignore
68 |
69 | def compute(self) -> float:
70 |         """Computes the average cosine similarity across all generated/target chromagram pairs."""
71 | assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
72 | return (self.cosine_sum / self.weight).item() # type: ignore
73 |
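
A hypothetical usage sketch (not part of the repository) of the metric above; the constructor values are illustrative and the random tensors stand in for real generated/reference audio. It assumes the audio extras used by `ChromaExtractor` (e.g. librosa) are installed:

```python
import torch
from audiocraft.metrics.chroma_cosinesim import ChromaCosineSimilarityMetric

# Illustrative settings; actual values come from the evaluation configs.
metric = ChromaCosineSimilarityMetric(sample_rate=32000, n_chroma=12, radix2_exp=12, argmax=True)
preds = torch.randn(4, 1, 32000)        # generated audio [B, C, T]
targets = torch.randn(4, 1, 32000)      # reference audio, same shape
sizes = torch.full((4,), 32000)         # valid length (in samples) of each item
sample_rates = torch.full((4,), 32000)  # per-item sample rates, all equal
metric.update(preds, targets, sizes, sample_rates)
print(metric.compute())                 # mean chroma cosine similarity over valid frames
```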
--------------------------------------------------------------------------------
/audiocraft/utils/best_state.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from collections import defaultdict
8 | import logging
9 | import typing as tp
10 |
11 | import flashy
12 | import torch
13 |
14 | from ..optim import ModuleDictEMA
15 | from .utils import copy_state
16 |
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class BestStateDictManager(flashy.state.StateDictSource):
22 | """BestStateDictManager maintains a copy of best state_dict() for registered sources.
23 |
24 | BestStateDictManager has two main attributes:
25 | states (dict): State dict of the registered StateDictSource.
26 | param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources.
27 |
28 |     When registering new sources, the BestStateDictManager ensures that two conflicting sources, e.g. a
29 |     ModuleDictEMA and the original module it tracks, are not both registered, as that would otherwise create
30 |     ambiguity about what to consider for the best state.
31 |
32 | Args:
33 | device (torch.device or str): Device on which we keep the copy.
34 | dtype (torch.dtype): Data type for the state parameters.
35 | """
36 | def __init__(self, device: tp.Union[torch.device, str] = 'cpu',
37 | dtype: tp.Optional[torch.dtype] = None):
38 | self.device = device
39 | self.states: dict = {}
40 | self.param_ids: dict = defaultdict(dict)
41 | self.dtype = dtype
42 |
43 | def _get_parameter_ids(self, state_dict):
44 | return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)}
45 |
46 | def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict):
47 | for registered_name, registered_param_ids in self.param_ids.items():
48 | if registered_name != name:
49 |                 overlap = registered_param_ids.keys() & param_ids.keys()
50 |                 assert len(overlap) == 0, (f"Found {len(overlap)} / {len(param_ids.keys())} overlapping parameters"
51 |                     f" in {name} and already registered {registered_name}: {' '.join(map(str, overlap))}")
52 |
53 | def update(self, name: str, source: flashy.state.StateDictSource):
54 | if name not in self.states:
55 | raise ValueError(f"{name} missing from registered states.")
56 | self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)
57 |
58 | def register(self, name: str, source: flashy.state.StateDictSource):
59 | if name in self.states:
60 | raise ValueError(f"{name} already present in states.")
61 | # Registering parameter ids for EMA and non-EMA states allows us to check that
62 | # there is no overlap that would create ambiguity about how to handle the best state
63 | param_ids = self._get_parameter_ids(source.state_dict())
64 | if isinstance(source, ModuleDictEMA):
65 | logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params")
66 | self._validate_no_parameter_ids_overlap(name, param_ids)
67 | self.param_ids[name] = param_ids
68 | else:
69 | logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params")
70 | self._validate_no_parameter_ids_overlap('base', param_ids)
71 | self.param_ids['base'].update(param_ids)
72 | # Register state
73 | self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype)
74 |
75 | def state_dict(self) -> flashy.state.StateDict:
76 | return self.states
77 |
78 | def load_state_dict(self, state: flashy.state.StateDict):
79 | for name, sub_state in state.items():
80 | for k, v in sub_state.items():
81 | self.states[name][k].copy_(v)
82 |
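
A hypothetical usage sketch (not part of the repository), assuming `flashy` is installed as in the training requirements; a plain `nn.Module` satisfies the `StateDictSource` interface through its `state_dict`/`load_state_dict` methods:

```python
import torch
from audiocraft.utils.best_state import BestStateDictManager

model = torch.nn.Linear(8, 8)
best = BestStateDictManager(device='cpu')
best.register('model', model)   # snapshot the current weights and record their parameter ids
# ... training runs and the validation metric improves ...
best.update('model', model)     # refresh the stored copy with the new best weights
model.load_state_dict(best.state_dict()['model'])  # later, restore the best weights
```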
--------------------------------------------------------------------------------
/tests/data/test_audio_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import julius
8 | import torch
9 | import pytest
10 |
11 | from audiocraft.data.audio_utils import (
12 | _clip_wav,
13 | convert_audio_channels,
14 | convert_audio,
15 | normalize_audio
16 | )
17 | from ..common_utils import get_batch_white_noise
18 |
19 |
20 | class TestConvertAudioChannels:
21 |
22 | def test_convert_audio_channels_downmix(self):
23 | b, c, t = 2, 3, 100
24 | audio = get_batch_white_noise(b, c, t)
25 | mixed = convert_audio_channels(audio, channels=2)
26 | assert list(mixed.shape) == [b, 2, t]
27 |
28 | def test_convert_audio_channels_nochange(self):
29 | b, c, t = 2, 3, 100
30 | audio = get_batch_white_noise(b, c, t)
31 | mixed = convert_audio_channels(audio, channels=c)
32 | assert list(mixed.shape) == list(audio.shape)
33 |
34 | def test_convert_audio_channels_upmix(self):
35 | b, c, t = 2, 1, 100
36 | audio = get_batch_white_noise(b, c, t)
37 | mixed = convert_audio_channels(audio, channels=3)
38 | assert list(mixed.shape) == [b, 3, t]
39 |
40 | def test_convert_audio_channels_upmix_error(self):
41 | b, c, t = 2, 2, 100
42 | audio = get_batch_white_noise(b, c, t)
43 | with pytest.raises(ValueError):
44 | convert_audio_channels(audio, channels=3)
45 |
46 |
47 | class TestConvertAudio:
48 |
49 | def test_convert_audio_channels_downmix(self):
50 | b, c, dur = 2, 3, 4.
51 | sr = 128
52 | audio = get_batch_white_noise(b, c, int(sr * dur))
53 | out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=2)
54 | assert list(out.shape) == [audio.shape[0], 2, audio.shape[-1]]
55 |
56 | def test_convert_audio_channels_upmix(self):
57 | b, c, dur = 2, 1, 4.
58 | sr = 128
59 | audio = get_batch_white_noise(b, c, int(sr * dur))
60 | out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=3)
61 | assert list(out.shape) == [audio.shape[0], 3, audio.shape[-1]]
62 |
63 | def test_convert_audio_upsample(self):
64 | b, c, dur = 2, 1, 4.
65 | sr = 2
66 | new_sr = 3
67 | audio = get_batch_white_noise(b, c, int(sr * dur))
68 | out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
69 | out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
70 | assert torch.allclose(out, out_j)
71 |
72 | def test_convert_audio_resample(self):
73 | b, c, dur = 2, 1, 4.
74 | sr = 3
75 | new_sr = 2
76 | audio = get_batch_white_noise(b, c, int(sr * dur))
77 | out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
78 | out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
79 | assert torch.allclose(out, out_j)
80 |
81 |
82 | class TestNormalizeAudio:
83 |
84 | def test_clip_wav(self):
85 | b, c, dur = 2, 1, 4.
86 | sr = 3
87 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
88 | _clip_wav(audio)
89 | assert audio.abs().max() <= 1
90 |
91 | def test_normalize_audio_clip(self):
92 | b, c, dur = 2, 1, 4.
93 | sr = 3
94 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
95 | norm_audio = normalize_audio(audio, strategy='clip')
96 | assert norm_audio.abs().max() <= 1
97 |
98 | def test_normalize_audio_rms(self):
99 | b, c, dur = 2, 1, 4.
100 | sr = 3
101 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
102 | norm_audio = normalize_audio(audio, strategy='rms')
103 | assert norm_audio.abs().max() <= 1
104 |
105 | def test_normalize_audio_peak(self):
106 | b, c, dur = 2, 1, 4.
107 | sr = 3
108 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
109 | norm_audio = normalize_audio(audio, strategy='peak')
110 | assert norm_audio.abs().max() <= 1
111 |
--------------------------------------------------------------------------------
/audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Evaluation with objective metrics for the pretrained MusicGen models.
9 | This grid takes the signatures from the training grid and runs an evaluation-only stage.
10 |
11 | When running the grid for the first time, please use:
12 | REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval
13 | and re-use the REGEN=1 option when the grid is changed to force regenerating it.
14 |
15 | Note that you need the proper metrics external libraries setup to use all
16 | the objective metrics activated in this grid. Refer to the README for more information.
17 | """
18 |
19 | import os
20 |
21 | from ._explorers import GenerationEvalExplorer
22 | from ...environment import AudioCraftEnvironment
23 | from ... import train
24 |
25 |
26 | def eval(launcher, batch_size: int = 32, eval_melody: bool = False):
27 | opts = {
28 | 'dset': 'audio/musiccaps_32khz',
29 | 'solver/musicgen/evaluation': 'objective_eval',
30 | 'execute_only': 'evaluate',
31 | '+dataset.evaluate.batch_size': batch_size,
32 | '+metrics.fad.tf.batch_size': 16,
33 | }
34 | # chroma-specific evaluation
35 | chroma_opts = {
36 | 'dset': 'internal/music_400k_32khz',
37 | 'dataset.evaluate.segment_duration': 30,
38 | 'dataset.evaluate.num_samples': 1000,
39 | 'evaluate.metrics.chroma_cosine': True,
40 | 'evaluate.metrics.fad': False,
41 | 'evaluate.metrics.kld': False,
42 | 'evaluate.metrics.text_consistency': False,
43 | }
44 | # binary for FAD computation: replace this path with your own path
45 | metrics_opts = {
46 | 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research'
47 | }
48 | opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.}
49 | opt2 = {'transformer_lm.two_step_cfg': True}
50 |
51 | sub = launcher.bind(opts)
52 | sub.bind_(metrics_opts)
53 |
54 | # base objective metrics
55 | sub(opt1, opt2)
56 |
57 | if eval_melody:
58 | # chroma-specific metrics
59 | sub(opt1, opt2, chroma_opts)
60 |
61 |
62 | @GenerationEvalExplorer
63 | def explorer(launcher):
64 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
65 | launcher.slurm_(gpus=4, partition=partitions)
66 |
67 | if 'REGEN' not in os.environ:
68 | folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1]
69 | with launcher.job_array():
70 | for sig in folder.iterdir():
71 | if not sig.is_symlink():
72 | continue
73 | xp = train.main.get_xp_from_sig(sig.name)
74 | launcher(xp.argv)
75 | return
76 |
77 | with launcher.job_array():
78 | musicgen_base = launcher.bind(solver="musicgen/musicgen_base_32khz")
79 | musicgen_base.bind_({'autocast': False, 'fsdp.use': True})
80 |
81 | # base musicgen models
82 | musicgen_base_small = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-small'})
83 | eval(musicgen_base_small, batch_size=128)
84 |
85 | musicgen_base_medium = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-medium'})
86 | musicgen_base_medium.bind_({'model/lm/model_scale': 'medium'})
87 | eval(musicgen_base_medium, batch_size=128)
88 |
89 | musicgen_base_large = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-large'})
90 | musicgen_base_large.bind_({'model/lm/model_scale': 'large'})
91 | eval(musicgen_base_large, batch_size=128)
92 |
93 | # melody musicgen model
94 | musicgen_melody = launcher.bind(solver="musicgen/musicgen_melody_32khz")
95 | musicgen_melody.bind_({'autocast': False, 'fsdp.use': True})
96 |
97 | musicgen_melody_medium = musicgen_melody.bind({'continue_from': '//pretrained/facebook/musicgen-melody'})
98 | musicgen_melody_medium.bind_({'model/lm/model_scale': 'medium'})
99 | eval(musicgen_melody_medium, batch_size=128, eval_melody=True)
100 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AudioCraft
2 | 
3 | 
4 | 
5 |
6 | AudioCraft is a PyTorch library for deep learning research on audio generation. AudioCraft contains inference and training code
7 | for two state-of-the-art AI generative models producing high-quality audio: AudioGen and MusicGen.
8 |
9 |
10 | ## Installation
11 | AudioCraft requires Python 3.9 and PyTorch 2.0.0. To install AudioCraft, you can run the following:
12 |
13 | ```shell
14 | # Best to make sure you have torch installed first, in particular before installing xformers.
15 | # Don't run this if you already have PyTorch installed.
16 | pip install 'torch>=2.0'
17 | # Then proceed to one of the following
18 | pip install -U audiocraft # stable release
19 | pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
20 | pip install -e . # or if you cloned the repo locally (mandatory if you want to train).
21 | ```
22 |
23 | We also recommend having `ffmpeg` installed, either through your system or Anaconda:
24 | ```bash
25 | sudo apt-get install ffmpeg
26 | # Or if you are using Anaconda or Miniconda
27 | conda install "ffmpeg<5" -c conda-forge
28 | ```
29 |
30 | ## Models
31 |
32 | At the moment, AudioCraft contains the training code and inference code for:
33 | * [MusicGen](./docs/MUSICGEN.md): A state-of-the-art controllable text-to-music model.
34 | * [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model.
35 | * [EnCodec](./docs/ENCODEC.md): A state-of-the-art high fidelity neural audio codec.
36 | * [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion.
37 |
38 | ## Training code
39 |
40 | AudioCraft contains PyTorch components for deep learning research in audio and training pipelines for the developed models.
41 | For a general introduction of AudioCraft design principles and instructions to develop your own training pipeline, refer to
42 | the [AudioCraft training documentation](./docs/TRAINING.md).
43 |
44 | To reproduce existing work and use the developed training pipelines, refer to the instructions for each specific model,
45 | which provide pointers to configuration, example grids, and model/task-specific information and FAQ.
46 |
47 |
48 | ## API documentation
49 |
50 | We provide some [API documentation](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html) for AudioCraft.
51 |
52 |
53 | ## FAQ
54 |
55 | #### Is the training code available?
56 |
57 | Yes! We provide the training code for [EnCodec](./docs/ENCODEC.md), [MusicGen](./docs/MUSICGEN.md) and [Multi Band Diffusion](./docs/MBD.md).
58 |
59 | #### Where are the models stored?
60 |
61 | Hugging Face stores the models in a specific location, which can be overridden for the AudioCraft models by setting the `AUDIOCRAFT_CACHE_DIR` environment variable.
62 | In order to change the cache location of the other Hugging Face models, please check out the [Hugging Face Transformers documentation for the cache setup](https://huggingface.co/docs/transformers/installation#cache-setup).
63 | Finally, if you use a model that relies on Demucs (e.g. `musicgen-melody`) and want to change the download location for Demucs, refer to the [Torch Hub documentation](https://pytorch.org/docs/stable/hub.html#where-are-my-downloaded-models-saved).
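For example (the path below is only a placeholder), you can point the AudioCraft cache to a custom directory before loading a pretrained model:
```python
import os
os.environ['AUDIOCRAFT_CACHE_DIR'] = '/path/to/audiocraft_cache'  # placeholder path

from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')  # weights are then cached under the directory above
```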
64 |
65 |
66 | ## License
67 | * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
68 | * The model weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
69 |
70 |
71 | ## Citation
72 |
73 | For the general framework of AudioCraft, please cite the following.
74 | ```
75 | @article{copet2023simple,
76 | title={Simple and Controllable Music Generation},
77 | author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
78 | year={2023},
79 | journal={arXiv preprint arXiv:2306.05284},
80 | }
81 | ```
82 |
83 | When referring to a specific model, please cite as mentioned in the model-specific README, e.g.
84 | [./docs/MUSICGEN.md](./docs/MUSICGEN.md), [./docs/AUDIOGEN.md](./docs/AUDIOGEN.md), etc.
85 |
--------------------------------------------------------------------------------
/audiocraft/data/info_audio_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Base classes for the datasets that also provide non-audio metadata,
7 | e.g. description, text transcription etc.
8 | """
9 | from dataclasses import dataclass
10 | import logging
11 | import math
12 | import re
13 | import typing as tp
14 |
15 | import torch
16 |
17 | from .audio_dataset import AudioDataset, AudioMeta
18 | from ..environment import AudioCraftEnvironment
19 | from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
20 |
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
26 | """Monkey-patch meta to match cluster specificities."""
27 | meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
28 | if meta.info_path is not None:
29 | meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
30 | return meta
31 |
32 |
33 | def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
34 | """Monkey-patch all meta to match cluster specificities."""
35 | return [_clusterify_meta(m) for m in meta]
36 |
37 |
38 | @dataclass
39 | class AudioInfo(SegmentWithAttributes):
40 | """Dummy SegmentInfo with empty attributes.
41 |
42 | The InfoAudioDataset is expected to return metadata that inherits
43 | from SegmentWithAttributes class and can return conditioning attributes.
44 |
45 |     This essentially guarantees all datasets will be compatible with the current
46 |     solvers that contain conditioners requiring such attributes.
47 | """
48 |     audio_tokens: tp.Optional[torch.Tensor] = None  # populated when using a cached batch for training an LM.
49 |
50 | def to_condition_attributes(self) -> ConditioningAttributes:
51 | return ConditioningAttributes()
52 |
53 |
54 | class InfoAudioDataset(AudioDataset):
55 | """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
56 |
57 | See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
58 | """
59 | def __init__(self, meta: tp.List[AudioMeta], **kwargs):
60 | super().__init__(clusterify_all_meta(meta), **kwargs)
61 |
62 | def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
63 | if not self.return_info:
64 | wav = super().__getitem__(index)
65 | assert isinstance(wav, torch.Tensor)
66 | return wav
67 | wav, meta = super().__getitem__(index)
68 | return wav, AudioInfo(**meta.to_dict())
69 |
70 |
71 | def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
72 |     """Preprocess a single keyword or possibly a list of keywords."""
73 | if isinstance(value, list):
74 | return get_keyword_list(value)
75 | else:
76 | return get_keyword(value)
77 |
78 |
79 | def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
80 |     """Preprocess a single string value."""
81 | if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
82 | return None
83 | else:
84 | return value.strip()
85 |
86 |
87 | def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
88 | """Preprocess a single keyword."""
89 | if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
90 | return None
91 | else:
92 | return value.strip().lower()
93 |
94 |
95 | def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
96 | """Preprocess a list of keywords."""
97 | if isinstance(values, str):
98 | values = [v.strip() for v in re.split(r'[,\s]', values)]
99 | elif isinstance(values, float) and math.isnan(values):
100 | values = []
101 | if not isinstance(values, list):
102 | logger.debug(f"Unexpected keyword list {values}")
103 | values = [str(values)]
104 |
105 | kws = [get_keyword(v) for v in values]
106 | kw_list = [k for k in kws if k is not None]
107 | if len(kw_list) == 0:
108 | return None
109 | else:
110 | return kw_list
111 |
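
A minimal sketch (not part of the repository) of how the keyword helpers above normalize raw metadata values:

```python
from audiocraft.data.info_audio_dataset import get_keyword, get_keyword_list

print(get_keyword("  Rock "))                  # 'rock': stripped and lower-cased
print(get_keyword("None"))                     # None: empty or sentinel values are dropped
print(get_keyword_list("drums, bass guitar"))  # ['drums', 'bass', 'guitar']: split on commas/whitespace
print(get_keyword_list(""))                    # None: nothing is left after filtering
```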
--------------------------------------------------------------------------------
/audiocraft/adversarial/discriminators/mpd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import typing as tp
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from ...modules import NormConv2d
14 | from .base import MultiDiscriminator, MultiDiscriminatorOutputType
15 |
16 |
17 | def get_padding(kernel_size: int, dilation: int = 1) -> int:
18 | return int((kernel_size * dilation - dilation) / 2)
19 |
20 |
21 | class PeriodDiscriminator(nn.Module):
22 | """Period sub-discriminator.
23 |
24 | Args:
25 | period (int): Period between samples of audio.
26 | in_channels (int): Number of input channels.
27 | out_channels (int): Number of output channels.
28 | n_layers (int): Number of convolutional layers.
29 | kernel_sizes (list of int): Kernel sizes for convolutions.
30 | stride (int): Stride for convolutions.
31 | filters (int): Initial number of filters in convolutions.
32 | filters_scale (int): Multiplier of number of filters as we increase depth.
33 | max_filters (int): Maximum number of filters.
34 | norm (str): Normalization method.
35 | activation (str): Activation function.
36 | activation_params (dict): Parameters to provide to the activation function.
37 | """
38 | def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1,
39 | n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3,
40 | filters: int = 8, filters_scale: int = 4, max_filters: int = 1024,
41 | norm: str = 'weight_norm', activation: str = 'LeakyReLU',
42 | activation_params: dict = {'negative_slope': 0.2}):
43 | super().__init__()
44 | self.period = period
45 | self.n_layers = n_layers
46 | self.activation = getattr(torch.nn, activation)(**activation_params)
47 | self.convs = nn.ModuleList()
48 | in_chs = in_channels
49 | for i in range(self.n_layers):
50 | out_chs = min(filters * (filters_scale ** (i + 1)), max_filters)
51 | eff_stride = 1 if i == self.n_layers - 1 else stride
52 | self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1),
53 | padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm))
54 | in_chs = out_chs
55 | self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1,
56 | padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm)
57 |
58 | def forward(self, x: torch.Tensor):
59 | fmap = []
60 | # 1d to 2d
61 | b, c, t = x.shape
62 | if t % self.period != 0: # pad first
63 | n_pad = self.period - (t % self.period)
64 | x = F.pad(x, (0, n_pad), 'reflect')
65 | t = t + n_pad
66 | x = x.view(b, c, t // self.period, self.period)
67 |
68 | for conv in self.convs:
69 | x = conv(x)
70 | x = self.activation(x)
71 | fmap.append(x)
72 | x = self.conv_post(x)
73 | fmap.append(x)
74 | # x = torch.flatten(x, 1, -1)
75 |
76 | return x, fmap
77 |
78 |
79 | class MultiPeriodDiscriminator(MultiDiscriminator):
80 | """Multi-Period (MPD) Discriminator.
81 |
82 | Args:
83 | in_channels (int): Number of input channels.
84 | out_channels (int): Number of output channels.
85 | periods (Sequence[int]): Periods between samples of audio for the sub-discriminators.
86 | **kwargs: Additional args for `PeriodDiscriminator`
87 | """
88 | def __init__(self, in_channels: int = 1, out_channels: int = 1,
89 | periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs):
90 | super().__init__()
91 | self.discriminators = nn.ModuleList([
92 | PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods
93 | ])
94 |
95 | @property
96 | def num_discriminators(self):
97 | return len(self.discriminators)
98 |
99 | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType:
100 | logits = []
101 | fmaps = []
102 | for disc in self.discriminators:
103 | logit, fmap = disc(x)
104 | logits.append(logit)
105 | fmaps.append(fmap)
106 | return logits, fmaps
107 |
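
A hypothetical usage sketch (not part of the repository) of the multi-period discriminator above on a batch of mono waveforms; the output shapes are approximate and depend on the default strides:

```python
import torch
from audiocraft.adversarial.discriminators.mpd import MultiPeriodDiscriminator

mpd = MultiPeriodDiscriminator()        # default periods [2, 3, 5, 7, 11]
wav = torch.randn(2, 1, 16000)          # [B, C, T] mono audio
logits, fmaps = mpd(wav)
assert len(logits) == mpd.num_discriminators  # one logit tensor per period
print(logits[0].shape)                  # roughly [2, 1, T', 2] for the period-2 sub-discriminator
print(len(fmaps[0]))                    # n_layers + 1 feature maps from that sub-discriminator
```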
--------------------------------------------------------------------------------
/audiocraft/metrics/clap_consistency.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from pathlib import Path
8 | import typing as tp
9 |
10 | import torch
11 | import torchmetrics
12 | from transformers import RobertaTokenizer # type: ignore
13 |
14 | from ..data.audio_utils import convert_audio
15 | from ..environment import AudioCraftEnvironment
16 | from ..utils.utils import load_clap_state_dict
17 |
18 | try:
19 | import laion_clap # type: ignore
20 | except ImportError:
21 | laion_clap = None
22 |
23 |
24 | class TextConsistencyMetric(torchmetrics.Metric):
25 | """Text consistency metric measuring consistency between audio and text pairs."""
26 |
27 | def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
28 | raise NotImplementedError("implement how to update the metric from the audio and text pairs.")
29 |
30 | def compute(self):
31 | raise NotImplementedError("implement how to compute the final metric score.")
32 |
33 |
34 | class CLAPTextConsistencyMetric(TextConsistencyMetric):
35 | """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP).
36 |
37 | This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf)
38 | or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf).
39 |
40 | As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the
41 | similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as
42 | well as the generated audio based on them, and define the MCC metric as the average cosine similarity
43 | between these embeddings.
44 |
45 | Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP
46 | """
47 | def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False):
48 | super().__init__()
49 | if laion_clap is None:
50 | raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'")
51 | self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
52 | self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")
53 | self._initialize_model(model_path, model_arch, enable_fusion)
54 |
55 | def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool):
56 | model_path = AudioCraftEnvironment.resolve_reference_path(model_path)
57 | self.tokenize = RobertaTokenizer.from_pretrained('roberta-base')
58 | self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
59 | self.model_sample_rate = 48_000
60 | load_clap_state_dict(self.model, model_path)
61 | self.model.eval()
62 |
63 | def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
64 | # we use the default params from CLAP module here as well
65 | return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
66 |
67 | def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
68 | """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset."""
69 | assert audio.size(0) == len(text), "Number of audio and text samples should match"
70 | assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate"
71 | sample_rate = int(sample_rates[0].item())
72 | # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T]
73 | audio = convert_audio(audio, from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1)
74 | audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
75 | text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
76 | # cosine similarity between the text and the audio embedding
77 | cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8)
78 | self.cosine_sum += cosine_sim.sum(dim=0)
79 | self.weight += torch.tensor(cosine_sim.size(0))
80 |
81 | def compute(self):
82 |         """Computes the average cosine similarity across all audio/text pairs."""
83 | assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore
84 | return (self.cosine_sum / self.weight).item() # type: ignore
85 |
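
A hypothetical usage sketch (not part of the repository): the checkpoint path is a placeholder that must point to a downloaded LAION-CLAP checkpoint, and `laion_clap` must be installed:

```python
import torch
from audiocraft.metrics.clap_consistency import CLAPTextConsistencyMetric

metric = CLAPTextConsistencyMetric(model_path='/path/to/clap_checkpoint.pt')  # placeholder path
audio = torch.randn(2, 1, 32000 * 5)               # [B, C, T] generated audio
text = ["happy rock", "sad jazz"]                  # the prompts used to generate it
sizes = torch.tensor([32000 * 5, 32000 * 5])       # valid lengths (not used by this metric)
sample_rates = torch.tensor([32000, 32000])        # must all be equal
metric.update(audio, text, sizes, sample_rates)
print(metric.compute())                            # average audio/text CLAP cosine similarity
```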
--------------------------------------------------------------------------------