├── audiocraft ├── py.typed ├── utils │ ├── samples │ │ └── __init__.py │ ├── __init__.py │ ├── notebook.py │ ├── profiler.py │ ├── autocast.py │ ├── deadlock.py │ ├── export_legacy.py │ ├── cluster.py │ ├── export.py │ └── best_state.py ├── grids │ ├── __init__.py │ ├── audiogen │ │ ├── __init__.py │ │ ├── audiogen_base_16khz.py │ │ └── audiogen_pretrained_16khz_eval.py │ ├── compression │ │ ├── __init__.py │ │ ├── encodec_base_24khz.py │ │ ├── encodec_audiogen_16khz.py │ │ ├── debug.py │ │ ├── encodec_musicgen_32khz.py │ │ └── _explorers.py │ ├── diffusion │ │ ├── __init__.py │ │ ├── 4_bands_base_32khz.py │ │ └── _explorers.py │ ├── musicgen │ │ ├── __init__.py │ │ ├── musicgen_clapemb_32khz.py │ │ ├── musicgen_base_32khz.py │ │ ├── musicgen_melody_32khz.py │ │ ├── musicgen_base_cached_32khz.py │ │ ├── _explorers.py │ │ └── musicgen_pretrained_32khz_eval.py │ └── _base_explorers.py ├── quantization │ ├── __init__.py │ └── base.py ├── adversarial │ ├── discriminators │ │ ├── __init__.py │ │ ├── base.py │ │ └── mpd.py │ └── __init__.py ├── data │ ├── __init__.py │ ├── zip.py │ └── info_audio_dataset.py ├── metrics │ ├── __init__.py │ ├── chroma_cosinesim.py │ └── clap_consistency.py ├── solvers │ ├── __init__.py │ └── audiogen.py ├── losses │ ├── __init__.py │ └── sisnr.py ├── modules │ ├── __init__.py │ ├── lstm.py │ ├── chroma.py │ ├── activations.py │ └── streaming.py ├── models │ └── __init__.py ├── optim │ ├── __init__.py │ ├── linear_warmup_lr_scheduler.py │ ├── inverse_sqrt_lr_scheduler.py │ ├── cosine_lr_scheduler.py │ ├── polynomial_decay_lr_scheduler.py │ └── ema.py └── __init__.py ├── assets ├── bach.mp3 ├── bolero_ravel.mp3 ├── sirens_and_a_humming_engine_approach_and_pass.mp3 └── a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 ├── dataset └── example │ ├── electro_1.mp3 │ ├── electro_2.mp3 │ ├── electro_2.json │ └── electro_1.json ├── config ├── solver │ ├── audiogen │ │ ├── evaluation │ │ │ ├── none.yaml │ │ │ └── objective_eval.yaml │ │ ├── default.yaml │ │ ├── debug.yaml │ │ └── audiogen_base_16khz.yaml │ ├── musicgen │ │ ├── evaluation │ │ │ ├── none.yaml │ │ │ └── objective_eval.yaml │ │ ├── debug.yaml │ │ ├── musicgen_base_32khz.yaml │ │ ├── musicgen_melody_32khz.yaml │ │ └── default.yaml │ ├── compression │ │ ├── encodec_base_24khz.yaml │ │ ├── encodec_audiogen_16khz.yaml │ │ ├── encodec_musicgen_32khz.yaml │ │ ├── debug.yaml │ │ └── default.yaml │ ├── diffusion │ │ ├── encodec_24khz.yaml │ │ ├── debug.yaml │ │ └── default.yaml │ └── default.yaml ├── model │ ├── lm │ │ ├── model_scale │ │ │ ├── base.yaml │ │ │ ├── small.yaml │ │ │ ├── medium.yaml │ │ │ ├── large.yaml │ │ │ └── xsmall.yaml │ │ ├── audiogen_lm.yaml │ │ ├── musicgen_lm.yaml │ │ └── default.yaml │ ├── none.yaml │ ├── encodec │ │ ├── encodec_base_causal.yaml │ │ ├── encodec_large_nq4_s640.yaml │ │ ├── encodec_large_nq4_s320.yaml │ │ └── default.yaml │ └── score │ │ └── basic.yaml ├── dset │ ├── audio │ │ ├── default.yaml │ │ ├── example.yaml │ │ ├── audiocaps_16khz.yaml │ │ └── musiccaps_32khz.yaml │ ├── internal │ │ ├── music_400k_32khz.yaml │ │ ├── music_10k_32khz.yaml │ │ └── sounds_16khz.yaml │ └── default.yaml ├── conditioner │ ├── none.yaml │ ├── text2sound.yaml │ ├── text2music.yaml │ ├── chroma2music.yaml │ └── clapemb2music.yaml ├── teams │ ├── default.yaml │ └── labs.yaml └── config.yaml ├── mypy.ini ├── scripts ├── __init__.py ├── templates │ ├── base.html │ ├── login.html │ ├── results.html │ └── index.html └── static │ └── style.css ├── tests ├── __init__.py ├── data │ ├── __init__.py │ 
└── test_audio_utils.py ├── losses │ ├── __init__.py │ └── test_losses.py ├── modules │ ├── __init__.py │ ├── test_lstm.py │ └── test_activations.py ├── utils │ └── __init__.py ├── adversarial │ ├── __init__.py │ └── test_discriminators.py ├── common_utils │ ├── __init__.py │ ├── wav_utils.py │ └── temp_utils.py ├── quantization │ └── test_vq.py └── models │ ├── test_audiogen.py │ ├── test_musicgen.py │ ├── test_multibanddiffusion.py │ └── test_encodec_model.py ├── MANIFEST.in ├── egs └── example │ └── data.jsonl ├── setup.cfg ├── .github ├── workflows │ ├── audiocraft_linter.yml │ ├── audiocraft_tests.yml │ └── audiocraft_docs.yml └── actions │ └── audiocraft_build │ └── action.yml ├── requirements.txt ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── CONTRIBUTING.md ├── Makefile ├── setup.py ├── docs └── DATASETS.md ├── CODE_OF_CONDUCT.md └── README.md /audiocraft/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/bach.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sugarforever/audiocraft/main/assets/bach.mp3 -------------------------------------------------------------------------------- /assets/bolero_ravel.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sugarforever/audiocraft/main/assets/bolero_ravel.mp3 -------------------------------------------------------------------------------- /dataset/example/electro_1.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sugarforever/audiocraft/main/dataset/example/electro_1.mp3 -------------------------------------------------------------------------------- /dataset/example/electro_2.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sugarforever/audiocraft/main/dataset/example/electro_2.mp3 -------------------------------------------------------------------------------- /config/solver/audiogen/evaluation/none.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | dataset: 4 | evaluate: 5 | num_samples: 10000 6 | -------------------------------------------------------------------------------- /config/solver/musicgen/evaluation/none.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | dataset: 4 | evaluate: 5 | num_samples: 10000 6 | -------------------------------------------------------------------------------- /config/model/lm/model_scale/base.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # overrides nothing because default is already transformer base (~ 60M params) 4 | -------------------------------------------------------------------------------- /config/model/lm/model_scale/small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # 300M Param. 
4 | 5 | transformer_lm: 6 | dim: 1024 7 | num_heads: 16 8 | num_layers: 24 9 | -------------------------------------------------------------------------------- /assets/sirens_and_a_humming_engine_approach_and_pass.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sugarforever/audiocraft/main/assets/sirens_and_a_humming_engine_approach_and_pass.mp3 -------------------------------------------------------------------------------- /config/model/lm/model_scale/medium.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # gpt2 like (~1.5B params) 4 | transformer_lm: 5 | dim: 1536 6 | num_heads: 24 7 | num_layers: 48 8 | -------------------------------------------------------------------------------- /assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sugarforever/audiocraft/main/assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 -------------------------------------------------------------------------------- /config/model/lm/model_scale/large.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # gpt2 inspired, even bigger (~3.3B params) 4 | transformer_lm: 5 | dim: 2048 6 | num_heads: 32 7 | num_layers: 48 8 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | 3 | [mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub,transformers,dac.*] 4 | ignore_missing_imports = True 5 | -------------------------------------------------------------------------------- /config/model/none.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # This file exist so that model is recognized as a config group 4 | # by Hydra, and Dora. A bit weird we might need a better fix someday. 5 | -------------------------------------------------------------------------------- /config/model/encodec/encodec_base_causal.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - encodec/default 5 | 6 | encodec: 7 | causal: true 8 | 9 | rvq: 10 | n_q: 32 11 | q_dropout: true 12 | -------------------------------------------------------------------------------- /config/dset/audio/default.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | datasource: 4 | max_sample_rate: ??? 5 | max_channels: ??? 6 | 7 | train: ??? 8 | valid: ??? 9 | evaluate: ??? 10 | generate: null 11 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include Makefile 2 | include LICENSE 3 | include LICENSE_weights 4 | include *.md 5 | include *.ini 6 | include requirements.txt 7 | include audiocraft/py.typed 8 | include assets/*.mp3 9 | recursive-include conf *.yaml 10 | -------------------------------------------------------------------------------- /tests/adversarial/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /config/dset/audio/example.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | datasource: 4 | max_sample_rate: 44100 5 | max_channels: 2 6 | 7 | train: egs/example 8 | valid: egs/example 9 | evaluate: egs/example 10 | generate: egs/example 11 | -------------------------------------------------------------------------------- /config/model/lm/model_scale/xsmall.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # just used for debugging or when we just want to populate the cache 3 | # and do not care about training. 
4 | 5 | transformer_lm: 6 | dim: 64 7 | num_heads: 2 8 | num_layers: 2 9 | -------------------------------------------------------------------------------- /audiocraft/utils/samples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /config/model/encodec/encodec_large_nq4_s640.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - encodec/default 5 | 6 | seanet: 7 | ratios: [8, 5, 4, 4] 8 | n_filters: 64 9 | 10 | rvq: 11 | bins: 2048 12 | n_q: 4 13 | q_dropout: false 14 | -------------------------------------------------------------------------------- /audiocraft/grids/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Dora Grids.""" 7 | -------------------------------------------------------------------------------- /audiocraft/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Utilities.""" 7 | -------------------------------------------------------------------------------- /config/model/encodec/encodec_large_nq4_s320.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - encodec/default 5 | 6 | seanet: 7 | # default ratios are [8, 5, 4, 2] 8 | n_filters: 64 9 | 10 | rvq: 11 | bins: 2048 12 | n_q: 4 13 | q_dropout: false 14 | -------------------------------------------------------------------------------- /config/solver/compression/encodec_base_24khz.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - compression/default 5 | - /model: encodec/encodec_base_causal 6 | - override /dset: audio/default 7 | - _self_ 8 | 9 | channels: 1 10 | sample_rate: 24000 11 | -------------------------------------------------------------------------------- /audiocraft/grids/audiogen/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """AudioGen grids.""" 7 | -------------------------------------------------------------------------------- /audiocraft/grids/compression/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | """EnCodec grids.""" 7 | -------------------------------------------------------------------------------- /audiocraft/grids/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Diffusion grids.""" 7 | -------------------------------------------------------------------------------- /audiocraft/grids/musicgen/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """MusicGen grids.""" 7 | -------------------------------------------------------------------------------- /config/solver/compression/encodec_audiogen_16khz.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - compression/default 5 | - /model: encodec/encodec_large_nq4_s320 6 | - override /dset: audio/default 7 | - _self_ 8 | 9 | channels: 1 10 | sample_rate: 16000 11 | -------------------------------------------------------------------------------- /config/solver/compression/encodec_musicgen_32khz.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - compression/default 5 | - /model: encodec/encodec_large_nq4_s640 6 | - override /dset: audio/default 7 | - _self_ 8 | 9 | channels: 1 10 | sample_rate: 32000 11 | -------------------------------------------------------------------------------- /config/solver/diffusion/encodec_24khz.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - diffusion/default 5 | - _self_ 6 | 7 | 8 | sample_rate: 24000 9 | channels: 1 10 | compression_model_checkpoint: //pretrained/facebook/encodec_24khz 11 | n_q: 4 # num quantizers, 3kbps 12 | -------------------------------------------------------------------------------- /egs/example/data.jsonl: -------------------------------------------------------------------------------- 1 | {"path": "dataset/example/electro_1.mp3", "duration": 15.024, "sample_rate": 48000, "amplitude": null, "weight": null, "info_path": null} 2 | {"path": "dataset/example/electro_2.mp3", "duration": 20.035918367346937, "sample_rate": 44100, "amplitude": null, "weight": null, "info_path": null} 3 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pep8] 2 | max-line-length = 120 3 | 4 | [flake8] 5 | max-line-length = 120 6 | 7 | [coverage:report] 8 | include = audiocraft/* 9 | omit = 10 | audiocraft/environment.py 11 | audiocraft/solvers/* 12 | audiocraft/utils/* 13 | audiocraft/*/loaders.py 14 | audiocraft/*/builders.py 15 | -------------------------------------------------------------------------------- /dataset/example/electro_2.json: -------------------------------------------------------------------------------- 1 | {"key": "", "artist": "Voyager I", "sample_rate": 44100, "file_extension": "mp3", "description": "This is an electronic song sending positive vibes.", "keywords": 
"", "duration": 20.0, "bpm": "", "genre": "electronic", "title": "Untitled song", "name": "electro_2", "instrument": "Mix", "moods": []} 2 | -------------------------------------------------------------------------------- /dataset/example/electro_1.json: -------------------------------------------------------------------------------- 1 | {"key": "", "artist": "Voyager I", "sample_rate": 48000, "file_extension": "mp3", "description": "A cool song from Voyager.", "keywords": "bright, pulsing, cool", "duration": 15.0, "bpm": "", "genre": "electronic", "title": "Enracinement", "name": "electro_1", "instrument": "Mix", "moods": ["uplifting", "motivational"]} 2 | -------------------------------------------------------------------------------- /config/conditioner/none.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # No conditioning 4 | 5 | classifier_free_guidance: 6 | training_dropout: 0 7 | inference_coef: 1 8 | 9 | attribute_dropout: 10 | text: {} 11 | wav: {} 12 | 13 | fuser: 14 | sum: [] 15 | prepend: [] 16 | cross: [] 17 | input_interpolate: [] 18 | 19 | conditioners: null 20 | -------------------------------------------------------------------------------- /config/dset/internal/music_400k_32khz.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | datasource: 4 | max_sample_rate: 32000 5 | max_channels: 1 6 | 7 | train: egs/music/music_400k_32khz/train 8 | valid: egs/music/music_400k_32khz/valid 9 | evaluate: egs/music/music_400k_32khz/test 10 | generate: egs/music/music_400k_32khz/test # identical to evaluate 11 | -------------------------------------------------------------------------------- /config/dset/audio/audiocaps_16khz.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # AudioCaps dataset 4 | datasource: 5 | max_sample_rate: 16000 6 | max_channels: 1 7 | 8 | train: null # only evaluation set 9 | valid: null # only evaluation set 10 | evaluate: egs/audiocaps/audiocaps_16khz 11 | generate: egs/audiocaps/audiocaps_16khz # identical to evaluate 12 | -------------------------------------------------------------------------------- /config/dset/default.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # WARNING: This is a base configuration file shared across ALL solvers in AudioCraft 4 | # Please don't update this file directly. Instead use distinct configuration files 5 | # to override the below configuration. 6 | datasource: 7 | train: ??? 8 | valid: ??? 9 | evaluate: ??? 10 | generate: ??? 11 | -------------------------------------------------------------------------------- /config/model/score/basic.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | diffusion_unet: 4 | hidden: 48 5 | depth: 4 6 | res_blocks: 1 7 | norm_groups: 4 8 | kernel: 8 9 | stride: 4 10 | growth: 4 11 | max_channels: 10_000 12 | dropout: 0. 13 | emb_all_layers: true 14 | bilstm: false 15 | codec_dim: null 16 | transformer: false 17 | cross_attention: false -------------------------------------------------------------------------------- /tests/common_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # flake8: noqa 8 | from .temp_utils import TempDirMixin 9 | from .wav_utils import get_batch_white_noise, get_white_noise, save_wav 10 | -------------------------------------------------------------------------------- /audiocraft/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """RVQ.""" 7 | # flake8: noqa 8 | from .vq import ResidualVectorQuantizer 9 | from .base import BaseQuantizer, DummyQuantizer, QuantizedResult 10 | -------------------------------------------------------------------------------- /config/teams/default.yaml: -------------------------------------------------------------------------------- 1 | default: 2 | dora_dir: /tmp/audiocraft_${oc.env:USER} 3 | partitions: 4 | global: debug 5 | team: debug 6 | reference_dir: /tmp 7 | darwin: # if we detect we are on a Mac, then most likely we are doing unit testing etc. 8 | dora_dir: /tmp/audiocraft_${oc.env:USER} 9 | partitions: 10 | global: debug 11 | team: debug 12 | reference_dir: /tmp 13 | -------------------------------------------------------------------------------- /config/dset/internal/music_10k_32khz.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # high quality music dataset with no artist overlap between splits 4 | datasource: 5 | max_sample_rate: 32000 6 | max_channels: 1 7 | 8 | train: egs/music/music_10k_32khz/train 9 | valid: egs/music/music_10k_32khz/valid 10 | evaluate: egs/music/music_10k_32khz/test 11 | generate: egs/music/music_10k_32khz/test # identical to evaluate 12 | -------------------------------------------------------------------------------- /config/dset/internal/sounds_16khz.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # environmental sounds dataset compiling all datasets 4 | # with applied filters on tags 5 | datasource: 6 | max_sample_rate: 16000 7 | max_channels: 1 8 | 9 | train: egs/sound/sounds_16khz/train 10 | valid: egs/sound/sounds_16khz/valid 11 | evaluate: egs/sound/sounds_16khz/test 12 | generate: egs/sound/sounds_16khz/test # identical to evaluate 13 | -------------------------------------------------------------------------------- /audiocraft/adversarial/discriminators/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
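# Exposes the discriminator architectures used by the adversarial losses: multi-period (MPD), multi-scale (MSD), and multi-scale STFT (MS-STFT) discriminators.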
6 | 7 | # flake8: noqa 8 | from .mpd import MultiPeriodDiscriminator 9 | from .msd import MultiScaleDiscriminator 10 | from .msstftd import MultiScaleSTFTDiscriminator 11 | -------------------------------------------------------------------------------- /.github/workflows/audiocraft_linter.yml: -------------------------------------------------------------------------------- 1 | name: audiocraft_linter 2 | on: 3 | push: 4 | branches: [ main ] 5 | pull_request: 6 | branches: [ main ] 7 | 8 | jobs: 9 | run_linter: 10 | name: Run linter 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: ./.github/actions/audiocraft_build 15 | - run: | 16 | . env/bin/activate 17 | make linter 18 | -------------------------------------------------------------------------------- /config/dset/audio/musiccaps_32khz.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # total samples obtained from MusicCaps = 5469 4 | # (out of 5521 due to AudioSet corrupted samples) 5 | datasource: 6 | max_sample_rate: 32000 7 | max_channels: 2 8 | 9 | train: null # only evaluation set 10 | valid: null # only evaluation set 11 | evaluate: egs/musiccaps/musiccaps_32khz 12 | generate: egs/musiccaps/musiccaps_32khz # identical to evaluate 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # please make sure you have already a pytorch install that is cuda enabled! 2 | av 3 | einops 4 | flashy>=0.0.1 5 | hydra-core>=1.1 6 | hydra_colorlog 7 | julius 8 | num2words 9 | numpy 10 | sentencepiece 11 | spacy==3.5.2 12 | torch>=2.0.0 13 | torchaudio>=2.0.0 14 | huggingface_hub 15 | tqdm 16 | transformers>=4.31.0 # need Encodec there. 17 | xformers 18 | demucs 19 | librosa 20 | gradio 21 | torchmetrics 22 | encodec 23 | -------------------------------------------------------------------------------- /audiocraft/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Audio loading and writing support. Datasets for raw audio 7 | or also including some metadata.""" 8 | 9 | # flake8: noqa 10 | from . import audio, audio_dataset, info_audio_dataset, music_dataset, sound_dataset 11 | -------------------------------------------------------------------------------- /scripts/templates/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | {% block head %} 5 | 6 | 7 | AudioCraft — MOS 8 | {% endblock %} 9 | 10 | 11 |
12 | AudioCraft — MOS 13 | {% block content %}{% endblock %} 14 |
15 | 16 | 17 | -------------------------------------------------------------------------------- /config/conditioner/text2sound.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | classifier_free_guidance: 4 | training_dropout: 0.1 5 | inference_coef: 3.0 6 | 7 | attribute_dropout: {} 8 | 9 | fuser: 10 | cross_attention_pos_emb: false 11 | cross_attention_pos_emb_scale: 1 12 | sum: [] 13 | prepend: [] 14 | cross: [description] 15 | input_interpolate: [] 16 | 17 | conditioners: 18 | description: 19 | model: t5 20 | t5: 21 | name: t5-large 22 | finetune: false 23 | word_dropout: 0. 24 | normalize_text: false 25 | -------------------------------------------------------------------------------- /scripts/templates/login.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block content %} 3 | 4 |

5 | You must identify yourself first! We use a highly secure protocol 6 | where you just decide your username, and that's it. No password, no encryption, 7 | just pure trust. 8 | 9 | 10 | {% if error %} 11 | {{error}} 12 | {% endif %} 13 |
14 | 17 | 18 | 19 | 20 | {% endblock %} 21 | -------------------------------------------------------------------------------- /.github/workflows/audiocraft_tests.yml: -------------------------------------------------------------------------------- 1 | name: audiocraft_tests 2 | on: 3 | push: 4 | branches: [ main ] 5 | pull_request: 6 | branches: [ main ] 7 | 8 | jobs: 9 | run_tests: 10 | name: Run tests 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: ./.github/actions/audiocraft_build 15 | - name: Run unit tests 16 | run: | 17 | . env/bin/activate 18 | make tests 19 | - name: Run integration tests 20 | run: | 21 | . env/bin/activate 22 | make tests_integ 23 | -------------------------------------------------------------------------------- /scripts/templates/results.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block content %} 3 | 4 |

Results for survey #{{signature}} 5 | Check out the survey page for details on the models. 6 | The following users voted: 7 | {% for user in users %} 8 | {{user}} 9 | {% endfor %} 10 | 11 | {% for model in models %} 12 | {{model['sig']}} ({{model['samples']}} samples) 13 | Ratings: {{model['mean_rating']}} ± {{model['std_rating']}}
14 | 15 | {% endfor %} 16 | 17 | {% endblock %} 18 | -------------------------------------------------------------------------------- /tests/quantization/test_vq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from audiocraft.quantization.vq import ResidualVectorQuantizer 10 | 11 | 12 | class TestResidualVectorQuantizer: 13 | 14 | def test_rvq(self): 15 | x = torch.randn(1, 16, 2048) 16 | vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8) 17 | res = vq(x, 1.) 18 | assert res.x.shape == torch.Size([1, 16, 2048]) 19 | -------------------------------------------------------------------------------- /config/conditioner/text2music.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | classifier_free_guidance: 4 | training_dropout: 0.3 5 | inference_coef: 3.0 6 | 7 | attribute_dropout: {} 8 | 9 | fuser: 10 | cross_attention_pos_emb: false 11 | cross_attention_pos_emb_scale: 1 12 | sum: [] 13 | prepend: [] 14 | cross: [description] 15 | input_interpolate: [] 16 | 17 | conditioners: 18 | description: 19 | model: t5 20 | t5: 21 | name: t5-base 22 | finetune: false 23 | word_dropout: 0.3 24 | normalize_text: false 25 | 26 | dataset: 27 | train: 28 | merge_text_p: 0.25 29 | drop_desc_p: 0.5 30 | drop_other_p: 0.5 31 | -------------------------------------------------------------------------------- /audiocraft/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Metrics like CLAP score, FAD, KLD, Visqol, Chroma similarity, etc. 7 | """ 8 | # flake8: noqa 9 | from .clap_consistency import CLAPTextConsistencyMetric, TextConsistencyMetric 10 | from .chroma_cosinesim import ChromaCosineSimilarityMetric 11 | from .fad import FrechetAudioDistanceMetric 12 | from .kld import KLDivergenceMetric, PasstKLDivergenceMetric 13 | from .rvm import RelativeVolumeMel 14 | from .visqol import ViSQOL 15 | -------------------------------------------------------------------------------- /audiocraft/solvers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Solvers. A Solver is a training recipe, combining the dataloaders, models, 8 | optimizer, losses etc into a single convenient object. 9 | """ 10 | 11 | # flake8: noqa 12 | from .audiogen import AudioGenSolver 13 | from .builders import get_solver 14 | from .base import StandardSolver 15 | from .compression import CompressionSolver 16 | from .musicgen import MusicGenSolver 17 | from .diffusion import DiffusionSolver 18 | -------------------------------------------------------------------------------- /audiocraft/adversarial/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Adversarial losses and discriminator architectures.""" 7 | 8 | # flake8: noqa 9 | from .discriminators import ( 10 | MultiPeriodDiscriminator, 11 | MultiScaleDiscriminator, 12 | MultiScaleSTFTDiscriminator 13 | ) 14 | from .losses import ( 15 | AdversarialLoss, 16 | AdvLossType, 17 | get_adv_criterion, 18 | get_fake_criterion, 19 | get_real_criterion, 20 | FeatLossType, 21 | FeatureMatchingLoss 22 | ) 23 | -------------------------------------------------------------------------------- /audiocraft/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Loss related classes and functions. In particular the loss balancer from 7 | EnCodec, and the usual spectral losses.""" 8 | 9 | # flake8: noqa 10 | from .balancer import Balancer 11 | from .sisnr import SISNR 12 | from .stftloss import ( 13 | LogSTFTMagnitudeLoss, 14 | MRSTFTLoss, 15 | SpectralConvergenceLoss, 16 | STFTLoss 17 | ) 18 | from .specloss import ( 19 | MelSpectrogramL1Loss, 20 | MultiScaleMelSpectrogramLoss, 21 | ) 22 | -------------------------------------------------------------------------------- /audiocraft/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Modules used for building the models.""" 7 | 8 | # flake8: noqa 9 | from .conv import ( 10 | NormConv1d, 11 | NormConv2d, 12 | NormConvTranspose1d, 13 | NormConvTranspose2d, 14 | StreamableConv1d, 15 | StreamableConvTranspose1d, 16 | pad_for_conv1d, 17 | pad1d, 18 | unpad1d, 19 | ) 20 | from .lstm import StreamableLSTM 21 | from .seanet import SEANetEncoder, SEANetDecoder 22 | from .transformer import StreamingTransformer -------------------------------------------------------------------------------- /audiocraft/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel. 8 | """ 9 | # flake8: noqa 10 | from . import builders, loaders 11 | from .encodec import ( 12 | CompressionModel, EncodecModel, DAC, 13 | HFEncodecModel, HFEncodecCompressionModel) 14 | from .audiogen import AudioGen 15 | from .lm import LMModel 16 | from .multibanddiffusion import MultiBandDiffusion 17 | from .musicgen import MusicGen 18 | from .unet import DiffusionUnet 19 | -------------------------------------------------------------------------------- /audiocraft/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Optimization stuff. In particular, optimizers (DAdaptAdam), schedulers 7 | and Exponential Moving Average. 8 | """ 9 | 10 | # flake8: noqa 11 | from .cosine_lr_scheduler import CosineLRScheduler 12 | from .dadam import DAdaptAdam 13 | from .inverse_sqrt_lr_scheduler import InverseSquareRootLRScheduler 14 | from .linear_warmup_lr_scheduler import LinearWarmupLRScheduler 15 | from .polynomial_decay_lr_scheduler import PolynomialDecayLRScheduler 16 | from .ema import ModuleDictEMA 17 | -------------------------------------------------------------------------------- /audiocraft/solvers/audiogen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from . import builders, musicgen 8 | 9 | 10 | class AudioGenSolver(musicgen.MusicGenSolver): 11 | """Solver for AudioGen re-implementation training task. 12 | 13 | Note that this implementation does not strictly follows 14 | the method proposed in https://arxiv.org/abs/2209.15352 15 | but is derived from MusicGen's training pipeline. 16 | 17 | More information can be found in the AudioGen model card. 18 | """ 19 | DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND 20 | -------------------------------------------------------------------------------- /config/solver/audiogen/evaluation/objective_eval.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # Setup for execute only on audiocaps for audio generation 4 | # evaluation with objective metrics 5 | # execute_only=evaluate 6 | 7 | dataset: 8 | max_audio_duration: null 9 | # ensure the proper values are broadcasted here for evaluate 10 | evaluate: 11 | min_audio_duration: 1. # some metrics requires a minimum audio length 12 | max_audio_duration: null # all samples from audiocaps should be ~10s 13 | num_samples: null 14 | segment_duration: null 15 | generate: 16 | min_audio_duration: 1. 17 | max_audio_duration: null 18 | num_samples: 500 19 | 20 | evaluate: 21 | metrics: 22 | fad: true 23 | kld: true 24 | text_consistency: true 25 | -------------------------------------------------------------------------------- /config/solver/musicgen/evaluation/objective_eval.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # Setup for execute only on musiccaps for audio generation 4 | # evaluation with objective metrics 5 | # execute_only=evaluate 6 | 7 | dataset: 8 | max_audio_duration: null 9 | # ensure the proper values are broadcasted here for evaluate 10 | evaluate: 11 | min_audio_duration: 1. # some metrics requires a minimum audio length 12 | max_audio_duration: null # all samples from musiccaps should be < 20s 13 | num_samples: null 14 | segment_duration: null 15 | generate: 16 | min_audio_duration: 1. 
17 | max_audio_duration: null 18 | num_samples: 500 19 | 20 | evaluate: 21 | metrics: 22 | fad: true 23 | kld: true 24 | text_consistency: true 25 | -------------------------------------------------------------------------------- /audiocraft/modules/lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torch import nn 8 | 9 | 10 | class StreamableLSTM(nn.Module): 11 | """LSTM without worrying about the hidden state, nor the layout of the data. 12 | Expects input as convolutional layout. 13 | """ 14 | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): 15 | super().__init__() 16 | self.skip = skip 17 | self.lstm = nn.LSTM(dimension, dimension, num_layers) 18 | 19 | def forward(self, x): 20 | x = x.permute(2, 0, 1) 21 | y, _ = self.lstm(x) 22 | if self.skip: 23 | y = y + x 24 | y = y.permute(1, 2, 0) 25 | return y 26 | -------------------------------------------------------------------------------- /.github/workflows/audiocraft_docs.yml: -------------------------------------------------------------------------------- 1 | name: audiocraft_docs 2 | on: 3 | push: 4 | branches: [ main ] 5 | 6 | jobs: 7 | run_docs: 8 | name: Run docs 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: ./.github/actions/audiocraft_build 13 | - name: Config git 14 | run: | 15 | git config --global user.email "defossez@fb.com" 16 | git config --global user.name "Alexandre Défossez (autodoc)" 17 | 18 | - name: Reset branch 19 | run: | 20 | git branch -f gh-docs main 21 | git checkout gh-docs 22 | 23 | - name: Make docs 24 | run: | 25 | . env/bin/activate 26 | make api_docs 27 | git add -f api_docs 28 | git commit -m api_docs 29 | 30 | - name: Push branch 31 | run: | 32 | git push -f -u origin gh-docs 33 | -------------------------------------------------------------------------------- /.github/actions/audiocraft_build/action.yml: -------------------------------------------------------------------------------- 1 | name: audiocraft_build 2 | description: 'Build audiocraft env.' 3 | runs: 4 | using: "composite" 5 | steps: 6 | - uses: actions/setup-python@v2 7 | with: 8 | python-version: 3.8 9 | - uses: actions/cache@v2 10 | id: cache 11 | with: 12 | path: env 13 | key: audiocraft_env-${{ hashFiles('**/requirements.txt') }} 14 | 15 | - if: ${{ steps.cache.outputs.cache-hit != 'true' }} 16 | name: Install dependencies 17 | shell: bash 18 | run: | 19 | sudo apt-get update 20 | sudo apt-get install libsndfile1-dev ffmpeg 21 | python3 -m venv env 22 | . env/bin/activate 23 | python -m pip install --upgrade pip 24 | pip install -e '.[dev]' 25 | - name: System Dependencies 26 | shell: bash 27 | run: | 28 | sudo apt-get update 29 | sudo apt-get install libsndfile1-dev ffmpeg 30 | -------------------------------------------------------------------------------- /audiocraft/grids/audiogen/audiogen_base_16khz.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
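# This grid schedules the base AudioGen 16 kHz experiment: a 64-GPU run with FSDP enabled (autocast disabled) and the medium LM scale, bound to a 16 kHz environmental sound dataset.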
6 | 7 | from ..musicgen._explorers import LMExplorer 8 | from ...environment import AudioCraftEnvironment 9 | 10 | 11 | @LMExplorer 12 | def explorer(launcher): 13 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) 14 | launcher.slurm_(gpus=64, partition=partitions) 15 | launcher.bind_(solver='audiogen/audiogen_base_16khz') 16 | # replace this by the desired environmental sound dataset 17 | launcher.bind_(dset='internal/sounds_16khz') 18 | 19 | fsdp = {'autocast': False, 'fsdp.use': True} 20 | medium = {'model/lm/model_scale': 'medium'} 21 | 22 | launcher.bind_(fsdp) 23 | launcher(medium) 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # macOS dir files 10 | .DS_Store 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | .ipynb_checkpoints 31 | 32 | # Tests and linter 33 | .pytest_cache/ 34 | .mypy_cache/ 35 | .coverage 36 | 37 | # docs 38 | /api_docs 39 | 40 | # dotenv 41 | .env 42 | .envrc 43 | 44 | # virtualenv 45 | .venv 46 | venv/ 47 | ENV/ 48 | 49 | # egs with manifest files 50 | egs/* 51 | !egs/example 52 | # local datasets 53 | dataset/* 54 | !dataset/example 55 | 56 | # personal notebooks & scripts 57 | */local_scripts 58 | */notes 59 | .vscode/ 60 | /notebooks 61 | /local_scripts 62 | /notes 63 | -------------------------------------------------------------------------------- /config/model/lm/audiogen_lm.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - lm/default 5 | - override /conditioner: text2sound 6 | - override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly 7 | 8 | lm_model: transformer_lm 9 | 10 | codebooks_pattern: 11 | modeling: delay 12 | delay: 13 | delays: [0, 1, 2, 3] 14 | flatten_first: 0 15 | empty_initial: 0 16 | unroll: 17 | flattening: [0, 1, 2, 3] 18 | delays: [0, 0, 0, 0] 19 | music_lm: 20 | group_by: 2 21 | valle: 22 | delays: [0, 0, 0] 23 | 24 | transformer_lm: 25 | n_q: 4 26 | card: 2048 27 | memory_efficient: true 28 | bias_proj: false 29 | bias_ff: false 30 | bias_attn: false 31 | norm_first: true 32 | layer_scale: null 33 | weight_init: gaussian 34 | depthwise_init: current 35 | zero_bias_init: true 36 | attention_as_float32: false 37 | -------------------------------------------------------------------------------- /config/model/lm/musicgen_lm.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - lm/default 5 | - override /conditioner: text2music 6 | - override /model/lm/model_scale: small # prefer this group to set model scale instead of transformer_lm keys directly 7 | 8 | lm_model: transformer_lm 9 | 10 | codebooks_pattern: 11 | modeling: delay 12 | delay: 13 | delays: [0, 1, 2, 3] 14 | flatten_first: 0 15 | empty_initial: 0 16 | unroll: 17 | flattening: [0, 1, 2, 3] 18 | delays: [0, 0, 0, 0] 19 | music_lm: 20 | group_by: 2 21 | valle: 22 | delays: [0, 0, 0] 23 | 24 | transformer_lm: 25 | n_q: 4 26 | card: 2048 27 | memory_efficient: 
true 28 | bias_proj: false 29 | bias_ff: false 30 | bias_attn: false 31 | norm_first: true 32 | layer_scale: null 33 | weight_init: gaussian 34 | depthwise_init: current 35 | zero_bias_init: true 36 | attention_as_float32: false 37 | -------------------------------------------------------------------------------- /tests/modules/test_lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import random 8 | import torch 9 | 10 | from audiocraft.modules.lstm import StreamableLSTM 11 | 12 | 13 | class TestStreamableLSTM: 14 | 15 | def test_lstm(self): 16 | B, C, T = 4, 2, random.randint(1, 100) 17 | 18 | lstm = StreamableLSTM(C, 3, skip=False) 19 | x = torch.randn(B, C, T) 20 | y = lstm(x) 21 | 22 | print(y.shape) 23 | assert y.shape == torch.Size([B, C, T]) 24 | 25 | def test_lstm_skip(self): 26 | B, C, T = 4, 2, random.randint(1, 100) 27 | 28 | lstm = StreamableLSTM(C, 3, skip=True) 29 | x = torch.randn(B, C, T) 30 | y = lstm(x) 31 | 32 | assert y.shape == torch.Size([B, C, T]) 33 | -------------------------------------------------------------------------------- /config/teams/labs.yaml: -------------------------------------------------------------------------------- 1 | aws: 2 | dora_dir: /fsx-audio-craft-llm/${oc.env:USER}/experiments/audiocraft/outputs 3 | partitions: 4 | global: learnlab 5 | team: learnlab 6 | reference_dir: /fsx-audio-craft-llm/shared/audiocraft/reference 7 | dataset_mappers: 8 | "^/checkpoint/[a-z]+": "/fsx-audio-craft-llm" 9 | fair: 10 | dora_dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs 11 | partitions: 12 | global: learnlab 13 | team: learnlab 14 | reference_dir: /large_experiments/audiocraft/reference 15 | dataset_mappers: 16 | "^/datasets01/datasets01": "/datasets01" 17 | darwin: 18 | dora_dir: /tmp/audiocraft_${oc.env:USER} 19 | partitions: 20 | global: debug 21 | team: debug 22 | reference_dir: /tmp 23 | rsc: 24 | dora_dir: /checkpoint/audiocraft/${oc.env:USER}/experiments/audiocraft/outputs 25 | partitions: 26 | global: learn 27 | team: learn 28 | reference_dir: /checkpoint/audiocraft/shared/reference 29 | -------------------------------------------------------------------------------- /scripts/templates/index.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block content %} 3 | 4 |

5 | Welcome {{session['user']}} to the internal MOS assistant for AudioCraft. 6 | You can create custom surveys between your models that you can 7 | evaluate yourself, or with the help of your teammates, by simply 8 | sharing a link! 9 | 10 | 11 | {% for error in errors %} 12 | {{error}} 13 | {% endfor %} 14 | 15 |
16 |
18 | 19 |
20 |
21 | 24 |
25 | 26 | 27 | 28 | {% endblock %} 29 | -------------------------------------------------------------------------------- /config/solver/audiogen/default.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - /solver/musicgen/default 5 | - _self_ 6 | - /solver/audiogen/evaluation: none 7 | - override /dset: audio/default 8 | 9 | # See config/solver/musicgen/default.yaml for a list of possible values. 10 | # We only keep the most important here. 11 | 12 | autocast: true 13 | autocast_dtype: float16 14 | 15 | solver: audiogen 16 | sample_rate: ??? 17 | channels: ??? 18 | compression_model_checkpoint: ??? 19 | 20 | tokens: 21 | padding_with_special_token: false 22 | 23 | dataset: 24 | batch_size: 128 25 | segment_duration: 10 26 | min_segment_ratio: 1.0 # lower values such as 0.5 result in generations with a lot of silence. 27 | 28 | optim: 29 | epochs: 100 30 | updates_per_epoch: 2000 31 | lr: 1e-4 32 | optimizer: adamw 33 | max_norm: 1.0 34 | adam: 35 | betas: [0.9, 0.95] 36 | weight_decay: 0.1 37 | eps: 1e-8 38 | 39 | schedule: 40 | lr_scheduler: null 41 | -------------------------------------------------------------------------------- /tests/modules/test_activations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from audiocraft.modules.activations import CustomGLU 11 | 12 | 13 | class TestActivations: 14 | def test_custom_glu_calculation(self): 15 | 16 | activation = CustomGLU(nn.Identity()) 17 | 18 | initial_shape = (4, 8, 8) 19 | 20 | part_a = torch.ones(initial_shape) * 2 21 | part_b = torch.ones(initial_shape) * -1 22 | input = torch.cat((part_a, part_b), dim=-1) 23 | 24 | output = activation(input) 25 | 26 | # ensure all dimensions match initial shape 27 | assert output.shape == initial_shape 28 | # ensure the gating was calculated correctly a * f(b) 29 | assert torch.all(output == -2).item() 30 | -------------------------------------------------------------------------------- /tests/common_utils/wav_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from pathlib import Path 8 | import typing as tp 9 | 10 | import torch 11 | import torchaudio 12 | 13 | 14 | def get_white_noise(chs: int = 1, num_frames: int = 1): 15 | wav = torch.randn(chs, num_frames) 16 | return wav 17 | 18 | 19 | def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1): 20 | wav = torch.randn(bs, chs, num_frames) 21 | return wav 22 | 23 | 24 | def save_wav(path: str, wav: torch.Tensor, sample_rate: int): 25 | fp = Path(path) 26 | kwargs: tp.Dict[str, tp.Any] = {} 27 | if fp.suffix == '.wav': 28 | kwargs['encoding'] = 'PCM_S' 29 | kwargs['bits_per_sample'] = 16 30 | elif fp.suffix == '.mp3': 31 | kwargs['compression'] = 320 32 | torchaudio.save(str(fp), wav, sample_rate, **kwargs) 33 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 6 | 7 | ## [1.0.0] - 2023-08-02 8 | 9 | Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion. 10 | Added pretrained model for AudioGen and MultiBandDiffusion. 11 | 12 | ## [0.0.2] - 2023-08-01 13 | 14 | Improved demo, fixed top p (thanks @jnordberg). 15 | 16 | Compressor tanh on output to avoid clipping with some style (especially piano). 17 | Now repeating the conditioning periodically if it is too short. 18 | 19 | More options when launching Gradio app locally (thanks @ashleykleynhans). 20 | 21 | Testing out PyTorch 2.0 memory efficient attention. 22 | 23 | Added extended generation (infinite length) by slowly moving the windows. 24 | Note that other implementations exist: https://github.com/camenduru/MusicGen-colab. 25 | 26 | ## [0.0.1] - 2023-06-09 27 | 28 | Initial release, with model evaluation only. 29 | -------------------------------------------------------------------------------- /audiocraft/utils/notebook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | try: 8 | import IPython.display as ipd # type: ignore 9 | except ImportError: 10 | # Note in a notebook... 11 | pass 12 | 13 | 14 | import torch 15 | 16 | 17 | def display_audio(samples: torch.Tensor, sample_rate: int): 18 | """Renders an audio player for the given audio samples. 19 | 20 | Args: 21 | samples (torch.Tensor): a Tensor of decoded audio samples 22 | with shapes [B, C, T] or [C, T] 23 | sample_rate (int): sample rate audio should be displayed with. 24 | """ 25 | assert samples.dim() == 2 or samples.dim() == 3 26 | 27 | samples = samples.detach().cpu() 28 | if samples.dim() == 2: 29 | samples = samples[None, ...] 
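# each sample in the batch is rendered as its own audio player below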
30 | 31 | for audio in samples: 32 | ipd.display(ipd.Audio(audio, rate=sample_rate)) 33 | -------------------------------------------------------------------------------- /config/conditioner/chroma2music.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | classifier_free_guidance: 4 | training_dropout: 0.2 5 | inference_coef: 3.0 6 | 7 | attribute_dropout: 8 | args: 9 | active_on_eval: false 10 | text: {} 11 | wav: 12 | self_wav: 0.5 13 | 14 | fuser: 15 | cross_attention_pos_emb: false 16 | cross_attention_pos_emb_scale: 1 17 | sum: [] 18 | prepend: [self_wav, description] 19 | cross: [] 20 | input_interpolate: [] 21 | 22 | conditioners: 23 | self_wav: 24 | model: chroma_stem 25 | chroma_stem: 26 | sample_rate: ${sample_rate} 27 | n_chroma: 12 28 | radix2_exp: 14 29 | argmax: true 30 | match_len_on_eval: false 31 | eval_wavs: null 32 | n_eval_wavs: 100 33 | cache_path: null 34 | description: 35 | model: t5 36 | t5: 37 | name: t5-base 38 | finetune: false 39 | word_dropout: 0.2 40 | normalize_text: false 41 | 42 | dataset: 43 | train: 44 | merge_text_p: 0.25 45 | drop_desc_p: 0.5 46 | drop_other_p: 0.5 47 | -------------------------------------------------------------------------------- /audiocraft/adversarial/discriminators/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from abc import ABC, abstractmethod 8 | import typing as tp 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | FeatureMapType = tp.List[torch.Tensor] 15 | LogitsType = torch.Tensor 16 | MultiDiscriminatorOutputType = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] 17 | 18 | 19 | class MultiDiscriminator(ABC, nn.Module): 20 | """Base implementation for discriminators composed of sub-discriminators acting at different scales. 21 | """ 22 | def __init__(self): 23 | super().__init__() 24 | 25 | @abstractmethod 26 | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: 27 | ... 28 | 29 | @property 30 | @abstractmethod 31 | def num_discriminators(self) -> int: 32 | """Number of discriminators. 33 | """ 34 | ... 35 | -------------------------------------------------------------------------------- /config/solver/compression/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - compression/default 5 | - /model: encodec/encodec_base_causal 6 | - override /dset: audio/example 7 | - _self_ 8 | 9 | channels: 1 10 | sample_rate: 16000 11 | 12 | # debug config uses just L1 13 | losses: 14 | adv: 0. 15 | feat: 0. 16 | l1: 1. 17 | mel: 0. 18 | msspec: 0. 19 | # no balancer 20 | balancer: 21 | balance_grads: false 22 | ema_decay: 1. 23 | total_norm: 1. 
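A concrete MultiDiscriminator only has to return one logits tensor and one list of feature maps per sub-discriminator, and report how many sub-discriminators it holds. The toy subclass below is a hypothetical sketch (not one of the repository's discriminators such as MPD) that illustrates the expected output structure, assuming the package is installed:

```python
import torch
import torch.nn as nn

from audiocraft.adversarial.discriminators.base import MultiDiscriminator


class ToyMultiDiscriminator(MultiDiscriminator):
    """Two 1D convolutions acting at different temporal scales."""
    def __init__(self, num_scales: int = 2):
        super().__init__()
        self.discs = nn.ModuleList(
            [nn.Conv1d(1, 1, kernel_size=15, stride=2 ** i, padding=7) for i in range(num_scales)])

    def forward(self, x: torch.Tensor):
        logits, fmaps = [], []
        for disc in self.discs:
            out = disc(x)
            logits.append(out)       # one logits tensor per sub-discriminator
            fmaps.append([out])      # and one list of intermediate feature maps
        return logits, fmaps

    @property
    def num_discriminators(self) -> int:
        return len(self.discs)


logits, fmaps = ToyMultiDiscriminator()(torch.randn(2, 1, 4000))
assert len(logits) == len(fmaps) == 2
```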
24 | per_batch_item: false 25 | # no adversaries 26 | adversarial: 27 | adversaries: [] 28 | adv_loss: hinge 29 | feat_loss: l1 30 | 31 | # faster model for local dev 32 | seanet: 33 | dimension: 16 34 | n_filters: 4 35 | 36 | # very small dataset 37 | dataset: 38 | batch_size: 8 39 | num_workers: 10 40 | num_samples: 100 41 | segment_duration: 1 42 | evaluate: 43 | batch_size: 32 44 | generate: 45 | batch_size: 1 46 | num_samples: 5 47 | segment_duration: 10 48 | 49 | # limited training 50 | evaluate: 51 | every: 5 52 | generate: 53 | every: 5 54 | optim: 55 | epochs: 50 56 | -------------------------------------------------------------------------------- /audiocraft/grids/compression/encodec_base_24khz.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Grid search file, simply list all the exp you want in `explorer`. 9 | Any new exp added there will be scheduled. 10 | You can cancel and experiment by commenting its line. 11 | 12 | This grid shows how to train a base causal EnCodec model at 24 kHz. 13 | """ 14 | 15 | from ._explorers import CompressionExplorer 16 | from ...environment import AudioCraftEnvironment 17 | 18 | 19 | @CompressionExplorer 20 | def explorer(launcher): 21 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) 22 | launcher.slurm_(gpus=8, partition=partitions) 23 | # base causal EnCodec trained on monophonic audio sampled at 24 kHz 24 | launcher.bind_(solver='compression/encodec_base_24khz') 25 | # replace this by the desired dataset 26 | launcher.bind_(dset='audio/example') 27 | # launch xp 28 | launcher() 29 | -------------------------------------------------------------------------------- /config/solver/audiogen/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # This is a minimal debugging configuration 4 | # for MusicGen training solver 5 | defaults: 6 | - audiogen/default 7 | - /model: lm/audiogen_lm 8 | - override /model/lm/model_scale: xsmall 9 | - override /dset: audio/example 10 | - _self_ 11 | 12 | autocast: false 13 | compression_model_checkpoint: null 14 | 15 | codebooks_pattern: 16 | modeling: parallel 17 | 18 | channels: 1 19 | sample_rate: 16000 20 | 21 | deadlock: 22 | use: false # deadlock detection 23 | 24 | dataset: 25 | batch_size: 4 26 | segment_duration: 5 27 | sample_on_weight: false # Uniform sampling all the way 28 | sample_on_duration: false # Uniform sampling all the way 29 | 30 | generate: 31 | audio: 32 | strategy: peak 33 | lm: 34 | use_sampling: false 35 | top_k: 0 36 | top_p: 0.0 37 | 38 | checkpoint: 39 | save_every: 0 40 | keep_last: 0 41 | 42 | optim: 43 | epochs: 2 44 | updates_per_epoch: 10 45 | optimizer: adamw 46 | lr: 1e-4 47 | 48 | logging: 49 | log_tensorboard: true 50 | 51 | schedule: 52 | lr_scheduler: null 53 | -------------------------------------------------------------------------------- /config/conditioner/clapemb2music.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | classifier_free_guidance: 4 | training_dropout: 0.3 5 | inference_coef: 3.0 6 | 7 | attribute_dropout: 8 | text: {} 9 | wav: {} 10 | 11 | fuser: 12 | cross_attention_pos_emb: false 13 | cross_attention_pos_emb_scale: 1 
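Both conditioner configs above define a `classifier_free_guidance` block: `training_dropout` is the probability of dropping the conditioning during training, so the model also learns an unconditional distribution, and `inference_coef` controls how strongly the conditional and unconditional predictions are combined at sampling time. Schematically, using the usual classifier-free guidance formula (a sketch, not the repository's exact sampling code):

```python
import torch

def cfg_logits(cond: torch.Tensor, uncond: torch.Tensor, coef: float = 3.0) -> torch.Tensor:
    """Classifier-free guidance: coef=1.0 recovers the conditional prediction,
    larger values push sampling further toward the conditioning."""
    return uncond + coef * (cond - uncond)

cond, uncond = torch.randn(2, 4, 2048), torch.randn(2, 4, 2048)
guided = cfg_logits(cond, uncond, coef=3.0)   # coef matches inference_coef above
```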
14 | sum: [] 15 | prepend: [] 16 | cross: [description] 17 | input_interpolate: [] 18 | 19 | conditioners: 20 | description: 21 | model: clap 22 | clap: 23 | checkpoint: //reference/clap/music_audioset_epoch_15_esc_90.14.pt 24 | model_arch: 'HTSAT-base' 25 | enable_fusion: false 26 | sample_rate: 44100 27 | max_audio_length: 10 28 | audio_stride: 1 29 | dim: 512 30 | attribute: description 31 | normalize: true 32 | quantize: true # use RVQ quantization 33 | n_q: 12 34 | bins: 1024 35 | kmeans_iters: 50 36 | text_p: 0. # probability of using text embed at train time 37 | cache_path: null 38 | 39 | dataset: 40 | joint_embed_attributes: [description] 41 | train: 42 | merge_text_p: 0.25 43 | drop_desc_p: 0.5 44 | drop_other_p: 0.5 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /config/solver/musicgen/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # This is a minimal debugging configuration 4 | # for MusicGen training solver 5 | defaults: 6 | - musicgen/default 7 | - /model: lm/musicgen_lm 8 | - override /model/lm/model_scale: xsmall 9 | - override /dset: audio/example 10 | - _self_ 11 | 12 | autocast: false 13 | compression_model_checkpoint: //pretrained/debug_compression_model 14 | transformer_lm: 15 | n_q: 4 16 | card: 400 17 | 18 | codebooks_pattern: 19 | modeling: parallel 20 | 21 | channels: 1 22 | sample_rate: 32000 23 | 24 | deadlock: 25 | use: false # deadlock detection 26 | 27 | dataset: 28 | batch_size: 4 29 | segment_duration: 5 30 | sample_on_weight: false # Uniform sampling all the way 31 | sample_on_duration: false # Uniform sampling all the way 32 | 33 | generate: 34 | audio: 35 | strategy: peak 36 | lm: 37 | use_sampling: false 38 | top_k: 0 39 | top_p: 0.0 40 | 41 | checkpoint: 42 | save_every: 0 43 | keep_last: 0 44 | 45 | optim: 46 | epochs: 2 47 | updates_per_epoch: 10 48 | optimizer: adamw 49 | lr: 1e-4 50 | 51 | logging: 52 | log_tensorboard: true 53 | 54 | schedule: 55 | lr_scheduler: null 56 | -------------------------------------------------------------------------------- /audiocraft/grids/diffusion/4_bands_base_32khz.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Training of the 4 diffusion models described in 9 | "From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion" 10 | (paper link). 11 | """ 12 | 13 | from ._explorers import DiffusionExplorer 14 | 15 | 16 | @DiffusionExplorer 17 | def explorer(launcher): 18 | launcher.slurm_(gpus=4, partition='learnfair') 19 | 20 | launcher.bind_({'solver': 'diffusion/default', 21 | 'dset': 'internal/music_10k_32khz'}) 22 | 23 | with launcher.job_array(): 24 | launcher({'filter.use': True, 'filter.idx_band': 0, "processor.use": False, 'processor.power_std': 0.4}) 25 | launcher({'filter.use': True, 'filter.idx_band': 1, "processor.use": False, 'processor.power_std': 0.4}) 26 | launcher({'filter.use': True, 'filter.idx_band': 2, "processor.use": True, 'processor.power_std': 0.4}) 27 | launcher({'filter.use': True, 'filter.idx_band': 3, "processor.use": True, 'processor.power_std': 0.75}) 28 | -------------------------------------------------------------------------------- /audiocraft/grids/compression/encodec_audiogen_16khz.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Grid search file, simply list all the exp you want in `explorer`. 9 | Any new exp added there will be scheduled. 10 | You can cancel and experiment by commenting its line. 11 | 12 | This grid shows how to train the new AudioGen EnCodec model at 16 kHz. 
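The four jobs in the multi-band diffusion grid above differ mainly in which frequency band each model is responsible for (`filter.idx_band` out of `n_bands: 4`) and in the processor's `power_std`. Purely as an illustration of what selecting one of four equal frequency bands means (this is a sketch, not the repository's filter implementation):

```python
import torch

def split_bands(wav: torch.Tensor, n_bands: int = 4) -> list:
    """Illustrative only: split a waveform into n_bands equal FFT frequency bands."""
    spec = torch.fft.rfft(wav, dim=-1)
    n_bins = spec.shape[-1]
    bands = []
    for idx in range(n_bands):
        lo, hi = idx * n_bins // n_bands, (idx + 1) * n_bins // n_bands
        masked = torch.zeros_like(spec)
        masked[..., lo:hi] = spec[..., lo:hi]
        bands.append(torch.fft.irfft(masked, n=wav.shape[-1], dim=-1))
    return bands

wav = torch.randn(1, 32000)
bands = split_bands(wav)
assert len(bands) == 4
assert torch.allclose(sum(bands), wav, atol=1e-4)   # the bands sum back to the input
```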
13 | """ 14 | 15 | from ._explorers import CompressionExplorer 16 | from ...environment import AudioCraftEnvironment 17 | 18 | 19 | @CompressionExplorer 20 | def explorer(launcher): 21 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) 22 | launcher.slurm_(gpus=8, partition=partitions) 23 | # use configuration for AudioGen's EnCodec model trained on monophonic audio sampled at 16 kHz 24 | # AudioGen's EnCodec is trained with a total stride of 320 leading to a frame rate of 50 hz 25 | launcher.bind_(solver='compression/encodec_audiogen_16khz') 26 | # replace this by the desired sound dataset 27 | launcher.bind_(dset='internal/sounds_16khz') 28 | # launch xp 29 | launcher() 30 | -------------------------------------------------------------------------------- /audiocraft/grids/compression/debug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Grid search file, simply list all the exp you want in `explorer`. 9 | Any new exp added there will be scheduled. 10 | You can cancel and experiment by commenting its line. 11 | 12 | This grid is a minimal example for debugging compression task 13 | and how to override parameters directly in a grid. 14 | Learn more about dora grids: https://github.com/facebookresearch/dora 15 | """ 16 | 17 | from ._explorers import CompressionExplorer 18 | from ...environment import AudioCraftEnvironment 19 | 20 | 21 | @CompressionExplorer 22 | def explorer(launcher): 23 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) 24 | launcher.slurm_(gpus=2, partition=partitions) 25 | launcher.bind_(solver='compression/debug') 26 | 27 | with launcher.job_array(): 28 | # base debug task using config from solver=compression/debug 29 | launcher() 30 | # we can override parameters in the grid to launch additional xps 31 | launcher({'rvq.bins': 2048, 'rvq.n_q': 4}) 32 | -------------------------------------------------------------------------------- /audiocraft/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """ 7 | AudioCraft is a general framework for training audio generative models. 8 | At the moment we provide the training code for: 9 | 10 | - [MusicGen](https://arxiv.org/abs/2306.05284), a state-of-the-art 11 | text-to-music and melody+text autoregressive generative model. 12 | For the solver, see `audiocraft.solvers.musicgen.MusicGenSolver`, and for the model, 13 | `audiocraft.models.musicgen.MusicGen`. 14 | - [AudioGen](https://arxiv.org/abs/2209.15352), a state-of-the-art 15 | text-to-general-audio generative model. 16 | - [EnCodec](https://arxiv.org/abs/2210.13438), efficient and high fidelity 17 | neural audio codec which provides an excellent tokenizer for autoregressive language models. 18 | See `audiocraft.solvers.compression.CompressionSolver`, and `audiocraft.models.encodec.EncodecModel`. 19 | - [MultiBandDiffusion](TODO), alternative diffusion-based decoder compatible with EnCodec that 20 | improves the perceived quality and reduces the artifacts coming from adversarial decoders. 
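The stride comments in the grids above follow from simple arithmetic: the EnCodec frame rate is the audio sample rate divided by the total encoder stride. A quick check with the values used in this repository:

```python
# frame rate = sample_rate / total_stride
audiogen_fps = 16_000 / 320   # AudioGen codec: stride 320 at 16 kHz -> 50 frames/s
musicgen_fps = 32_000 / 640   # MusicGen codec: stride 640 at 32 kHz -> 50 frames/s
assert audiogen_fps == musicgen_fps == 50
```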
21 | """ 22 | 23 | # flake8: noqa 24 | from . import data, modules, models 25 | 26 | __version__ = '1.0.0' 27 | -------------------------------------------------------------------------------- /audiocraft/grids/musicgen/musicgen_clapemb_32khz.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ._explorers import LMExplorer 8 | from ...environment import AudioCraftEnvironment 9 | 10 | 11 | @LMExplorer 12 | def explorer(launcher): 13 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) 14 | launcher.slurm_(gpus=32, partition=partitions) 15 | launcher.bind_(solver='musicgen/musicgen_base_32khz') 16 | # replace this by the desired music dataset 17 | launcher.bind_(dset='internal/music_400k_32khz') 18 | launcher.bind_(conditioner='clapemb2music') 19 | 20 | fsdp = {'autocast': False, 'fsdp.use': True} 21 | cache_path = {'conditioners.description.clap.cache_path': 22 | '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/clap_embed_music'} 23 | text_wav_training_opt = {'conditioners.description.clap.text_p': 0.5} 24 | 25 | launcher.bind_(fsdp) 26 | 27 | launcher.slurm_(gpus=32).bind_(label='32gpus') 28 | with launcher.job_array(): 29 | launcher() 30 | launcher(text_wav_training_opt) 31 | launcher(cache_path) 32 | launcher(cache_path, text_wav_training_opt) 33 | -------------------------------------------------------------------------------- /config/model/encodec/default.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | compression_model: encodec 4 | 5 | encodec: 6 | autoencoder: seanet 7 | quantizer: rvq 8 | sample_rate: ${sample_rate} 9 | channels: ${channels} 10 | causal: false 11 | renormalize: false 12 | 13 | seanet: 14 | dimension: 128 15 | channels: ${channels} 16 | causal: ${encodec.causal} 17 | n_filters: 32 18 | n_residual_layers: 1 19 | ratios: [8, 5, 4, 2] 20 | activation: ELU 21 | activation_params: {"alpha": 1.} 22 | norm: weight_norm 23 | norm_params: {} 24 | kernel_size: 7 25 | residual_kernel_size: 3 26 | last_kernel_size: 7 27 | dilation_base: 2 28 | pad_mode: constant 29 | true_skip: true 30 | compress: 2 31 | lstm: 2 32 | disable_norm_outer_blocks: 0 33 | # Specific encoder or decoder params. 34 | # You can also override any param for the encoder or decoder only 35 | # by using Hydra `+param=` syntax, i.e.` 36 | # `+seanet.decoder.n_filters=64`. 37 | decoder: 38 | trim_right_ratio: 1.0 39 | final_activation: null 40 | final_activation_params: null 41 | encoder: {} 42 | 43 | rvq: 44 | n_q: 8 45 | q_dropout: false 46 | bins: 1024 47 | decay: 0.99 48 | kmeans_init: true 49 | kmeans_iters: 50 50 | threshold_ema_dead_code: 2 51 | orthogonal_reg_weight: 0.0 52 | orthogonal_reg_active_codes_only: false 53 | 54 | no_quant: {} 55 | -------------------------------------------------------------------------------- /audiocraft/utils/profiler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
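From the default SEANet ratios and RVQ settings above, the token rate and bitrate of the codec follow directly: the total stride is the product of the ratios, and each frame carries `n_q * log2(bins)` bits. A small sanity check, assuming the 24 kHz base compression solver:

```python
import math

sample_rate = 24_000
total_stride = 8 * 5 * 4 * 2                    # product of the seanet ratios = 320
frame_rate = sample_rate / total_stride         # 75 frames/s
n_q, bins = 8, 1024
bitrate = frame_rate * n_q * math.log2(bins)    # 75 * 8 * 10 = 6000 bit/s, i.e. 6 kbps
assert bitrate == 6000
```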
6 | 7 | import logging 8 | import typing as tp 9 | 10 | import dora 11 | import torch 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class Profiler: 18 | """Context manager wrapper for xformers profiler. 19 | """ 20 | def __init__(self, module: torch.nn.Module, enabled: bool = False): 21 | self.profiler: tp.Optional[tp.Any] = None 22 | if enabled: 23 | from xformers.profiler import profile 24 | output_dir = dora.get_xp().folder / 'profiler_data' 25 | logger.info("Profiling activated, results with be saved to %s", output_dir) 26 | self.profiler = profile(output_dir=output_dir, module=module) 27 | 28 | def step(self): 29 | if self.profiler is not None: 30 | self.profiler.step() # type: ignore 31 | 32 | def __enter__(self): 33 | if self.profiler is not None: 34 | return self.profiler.__enter__() # type: ignore 35 | 36 | def __exit__(self, exc_type, exc_value, exc_tb): 37 | if self.profiler is not None: 38 | return self.profiler.__exit__(exc_type, exc_value, exc_tb) # type: ignore 39 | -------------------------------------------------------------------------------- /config/solver/musicgen/musicgen_base_32khz.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # This is the training loop solver 4 | # for the base MusicGen model (text-to-music) 5 | # on monophonic audio sampled at 32 kHz 6 | defaults: 7 | - musicgen/default 8 | - /model: lm/musicgen_lm 9 | - override /dset: audio/default 10 | - _self_ 11 | 12 | autocast: true 13 | autocast_dtype: float16 14 | 15 | # EnCodec large trained on mono-channel music audio sampled at 32khz 16 | # with a total stride of 640 leading to 50 frames/s. 17 | # rvq.n_q=4, rvq.bins=2048, no quantization dropout 18 | # (transformer_lm card and n_q must be compatible) 19 | compression_model_checkpoint: //pretrained/facebook/encodec_32khz 20 | 21 | channels: 1 22 | sample_rate: 32000 23 | 24 | deadlock: 25 | use: true # deadlock detection 26 | 27 | dataset: 28 | batch_size: 192 # 32 GPUs 29 | sample_on_weight: false # Uniform sampling all the way 30 | sample_on_duration: false # Uniform sampling all the way 31 | 32 | generate: 33 | lm: 34 | use_sampling: true 35 | top_k: 250 36 | top_p: 0.0 37 | 38 | optim: 39 | epochs: 500 40 | optimizer: dadam 41 | lr: 1 42 | ema: 43 | use: true 44 | updates: 10 45 | device: cuda 46 | 47 | logging: 48 | log_tensorboard: true 49 | 50 | schedule: 51 | lr_scheduler: cosine 52 | cosine: 53 | warmup: 4000 54 | lr_min_ratio: 0.0 55 | cycle_length: 1.0 56 | -------------------------------------------------------------------------------- /audiocraft/grids/compression/encodec_musicgen_32khz.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Grid search file, simply list all the exp you want in `explorer`. 9 | Any new exp added there will be scheduled. 10 | You can cancel and experiment by commenting its line. 11 | 12 | This grid shows how to train a MusicGen EnCodec model at 32 kHz. 
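The Profiler above is a no-op unless `enabled=True`, in which case it wraps the xformers profiler around the module. A hypothetical training-loop usage with a placeholder model and synthetic batches (enabling it requires xformers and a Dora experiment folder):

```python
import torch
from audiocraft.utils.profiler import Profiler

model = torch.nn.Linear(8, 8)
profiler = Profiler(module=model, enabled=False)   # flip to True only with xformers available
with profiler:
    for _ in range(3):
        loss = model(torch.randn(4, 8)).sum()
        loss.backward()
        profiler.step()                            # advances the profiling schedule when enabled
```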
13 | """ 14 | 15 | from ._explorers import CompressionExplorer 16 | from ...environment import AudioCraftEnvironment 17 | 18 | 19 | @CompressionExplorer 20 | def explorer(launcher): 21 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) 22 | launcher.slurm_(gpus=8, partition=partitions) 23 | # use configuration for MusicGen's EnCodec model trained on monophonic audio sampled at 32 kHz 24 | # MusicGen's EnCodec is trained with a total stride of 640 leading to a frame rate of 50 hz 25 | launcher.bind_(solver='compression/encodec_musicgen_32khz') 26 | # replace this by the desired music dataset 27 | launcher.bind_(dset='internal/music_400k_32khz') 28 | # launch xp 29 | launcher() 30 | launcher({ 31 | 'metrics.visqol.bin': '/data/home/jadecopet/local/usr/opt/visqol', 32 | 'label': 'visqol', 33 | 'evaluate.metrics.visqol': True 34 | }) 35 | -------------------------------------------------------------------------------- /audiocraft/optim/linear_warmup_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import typing as tp 8 | 9 | from torch.optim import Optimizer 10 | from torch.optim.lr_scheduler import _LRScheduler 11 | 12 | 13 | class LinearWarmupLRScheduler(_LRScheduler): 14 | """Inverse square root LR scheduler. 15 | 16 | Args: 17 | optimizer (Optimizer): Torch optimizer. 18 | warmup_steps (int): Number of warmup steps. 19 | warmup_init_lr (tp.Optional[float]): Initial learning rate 20 | during warmup phase. When not set, use the provided learning rate. 21 | """ 22 | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0): 23 | self.warmup_steps = warmup_steps 24 | self.warmup_init_lr = warmup_init_lr 25 | super().__init__(optimizer) 26 | 27 | def _get_sched_lr(self, lr: float, step: int): 28 | if step < self.warmup_steps: 29 | warmup_init_lr = self.warmup_init_lr or 0 30 | lr_step = (lr - warmup_init_lr) / self.warmup_steps 31 | lr = warmup_init_lr + step * lr_step 32 | return lr 33 | 34 | def get_lr(self): 35 | return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs] 36 | -------------------------------------------------------------------------------- /config/solver/musicgen/musicgen_melody_32khz.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # This is the training loop solver 4 | # for the melody MusicGen model (text+chroma to music) 5 | # on monophonic audio sampled at 32 kHz 6 | defaults: 7 | - musicgen/default 8 | - /model: lm/musicgen_lm 9 | - override /conditioner: chroma2music 10 | - override /dset: audio/default 11 | - _self_ 12 | 13 | autocast: true 14 | autocast_dtype: float16 15 | 16 | # EnCodec large trained on mono-channel music audio sampled at 32khz 17 | # with a total stride of 640 leading to 50 frames/s. 
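The scheduler above performs a linear warmup: the learning rate ramps from `warmup_init_lr` to the optimizer's base learning rate over `warmup_steps` updates and then stays flat. A small sketch of driving it once per update (the optimizer and parameter are placeholders):

```python
import torch
from audiocraft.optim.linear_warmup_lr_scheduler import LinearWarmupLRScheduler

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-4)
sched = LinearWarmupLRScheduler(opt, warmup_steps=1000)
for _ in range(10):
    opt.step()
    sched.step()                 # one scheduler step per optimizer update
print(sched.get_last_lr())       # ~[1e-6]: 10 of the 1000 warmup steps done
```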
18 | # rvq.n_q=4, rvq.bins=2048, no quantization dropout 19 | # (transformer_lm card and n_q must be compatible) 20 | compression_model_checkpoint: //pretrained/facebook/encodec_32khz 21 | 22 | channels: 1 23 | sample_rate: 32000 24 | 25 | deadlock: 26 | use: true # deadlock detection 27 | 28 | dataset: 29 | batch_size: 192 # 32 GPUs 30 | sample_on_weight: false # Uniform sampling all the way 31 | sample_on_duration: false # Uniform sampling all the way 32 | 33 | generate: 34 | lm: 35 | use_sampling: true 36 | top_k: 250 37 | top_p: 0.0 38 | 39 | optim: 40 | epochs: 500 41 | optimizer: dadam 42 | lr: 1 43 | ema: 44 | use: true 45 | updates: 10 46 | device: cuda 47 | 48 | logging: 49 | log_tensorboard: true 50 | 51 | schedule: 52 | lr_scheduler: cosine 53 | cosine: 54 | warmup: 4000 55 | lr_min_ratio: 0.0 56 | cycle_length: 1.0 57 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to AudioCraft 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | 8 | AudioCraft is the implementation of a research paper. 9 | Therefore, we do not plan on accepting many pull requests for new features. 10 | We certainly welcome them for bug fixes. 11 | 12 | 1. Fork the repo and create your branch from `main`. 13 | 2. If you've added code that should be tested, add tests. 14 | 3. If you've changed APIs, update the documentation. 15 | 4. Ensure the test suite passes. 16 | 5. Make sure your code lints. 17 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 18 | 19 | ## Contributor License Agreement ("CLA") 20 | In order to accept your pull request, we need you to submit a CLA. You only need 21 | to do this once to work on any of Meta's open source projects. 22 | 23 | Complete your CLA here: 24 | 25 | ## Issues 26 | We use GitHub issues to track public bugs. Please ensure your description is 27 | clear and has sufficient instructions to be able to reproduce the issue. 28 | 29 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 30 | disclosure of security bugs. In those cases, please go through the process 31 | outlined on that page and do not file a public issue. 32 | 33 | ## License 34 | By contributing to encodec, you agree that your contributions will be licensed 35 | under the LICENSE file in the root directory of this source tree. 36 | -------------------------------------------------------------------------------- /audiocraft/utils/autocast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | 10 | class TorchAutocast: 11 | """TorchAutocast utility class. 12 | Allows you to enable and disable autocast. This is specially useful 13 | when dealing with different architectures and clusters with different 14 | levels of support. 15 | 16 | Args: 17 | enabled (bool): Whether to enable torch.autocast or not. 18 | args: Additional args for torch.autocast. 
19 | kwargs: Additional kwargs for torch.autocast 20 | """ 21 | def __init__(self, enabled: bool, *args, **kwargs): 22 | self.autocast = torch.autocast(*args, **kwargs) if enabled else None 23 | 24 | def __enter__(self): 25 | if self.autocast is None: 26 | return 27 | try: 28 | self.autocast.__enter__() 29 | except RuntimeError: 30 | device = self.autocast.device 31 | dtype = self.autocast.fast_dtype 32 | raise RuntimeError( 33 | f"There was an error autocasting with dtype={dtype} device={device}\n" 34 | "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16" 35 | ) 36 | 37 | def __exit__(self, *args, **kwargs): 38 | if self.autocast is None: 39 | return 40 | self.autocast.__exit__(*args, **kwargs) 41 | -------------------------------------------------------------------------------- /audiocraft/optim/inverse_sqrt_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import typing as tp 8 | 9 | from torch.optim import Optimizer 10 | from torch.optim.lr_scheduler import _LRScheduler 11 | 12 | 13 | class InverseSquareRootLRScheduler(_LRScheduler): 14 | """Inverse square root LR scheduler. 15 | 16 | Args: 17 | optimizer (Optimizer): Torch optimizer. 18 | warmup_steps (int): Number of warmup steps. 19 | warmup_init_lr (tp.Optional[float]): Initial learning rate 20 | during warmup phase. When not set, use the provided learning rate. 21 | """ 22 | def __init__(self, optimizer: Optimizer, warmup_steps: int, warmup_init_lr: tp.Optional[float] = 0): 23 | self.warmup_steps = warmup_steps 24 | self.warmup_init_lr = warmup_init_lr 25 | super().__init__(optimizer) 26 | 27 | def _get_sched_lr(self, lr: float, step: int): 28 | if step < self.warmup_steps: 29 | warmup_init_lr = self.warmup_init_lr or 0 30 | lr_step = (lr - warmup_init_lr) / self.warmup_steps 31 | lr = warmup_init_lr + step * lr_step 32 | else: 33 | decay_factor = lr * self.warmup_steps**0.5 34 | lr = decay_factor * step**-0.5 35 | return lr 36 | 37 | def get_lr(self): 38 | return [self._get_sched_lr(base_lr, self._step_count) for base_lr in self.base_lrs] 39 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | INTEG=AUDIOCRAFT_DORA_DIR="/tmp/magma_$(USER)" python3 -m dora -v run --clear device=cpu dataset.num_workers=0 optim.epochs=1 \ 2 | dataset.train.num_samples=10 dataset.valid.num_samples=10 \ 3 | dataset.evaluate.num_samples=10 dataset.generate.num_samples=2 sample_rate=16000 \ 4 | logging.level=DEBUG 5 | INTEG_COMPRESSION = $(INTEG) solver=compression/debug rvq.n_q=2 rvq.bins=48 checkpoint.save_last=true # SIG is 616d7b3c 6 | INTEG_MUSICGEN = $(INTEG) solver=musicgen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \ 7 | transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 616d7b3c 8 | INTEG_AUDIOGEN = $(INTEG) solver=audiogen/debug dset=audio/example compression_model_checkpoint=//sig/5091833e \ 9 | transformer_lm.n_q=2 transformer_lm.card=48 transformer_lm.dim=16 checkpoint.save_last=false # Using compression model from 616d7b3c 10 | INTEG_MBD = $(INTEG) solver=diffusion/debug dset=audio/example \ 11 | checkpoint.save_last=false # 
Using compression model from 616d7b3c 12 | 13 | default: linter tests 14 | 15 | install: 16 | pip install -U pip 17 | pip install -U -e '.[dev]' 18 | 19 | linter: 20 | flake8 audiocraft && mypy audiocraft 21 | flake8 tests && mypy tests 22 | 23 | tests: 24 | coverage run -m pytest tests 25 | coverage report 26 | 27 | tests_integ: 28 | $(INTEG_COMPRESSION) 29 | $(INTEG_MBD) 30 | $(INTEG_MUSICGEN) 31 | $(INTEG_AUDIOGEN) 32 | 33 | 34 | api_docs: 35 | pdoc3 --html -o api_docs -f audiocraft 36 | 37 | dist: 38 | python setup.py sdist 39 | 40 | .PHONY: linter tests api_docs dist 41 | -------------------------------------------------------------------------------- /audiocraft/grids/musicgen/musicgen_base_32khz.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ._explorers import LMExplorer 8 | from ...environment import AudioCraftEnvironment 9 | 10 | 11 | @LMExplorer 12 | def explorer(launcher): 13 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) 14 | launcher.slurm_(gpus=32, partition=partitions) 15 | launcher.bind_(solver='musicgen/musicgen_base_32khz') 16 | # replace this by the desired music dataset 17 | launcher.bind_(dset='internal/music_400k_32khz') 18 | 19 | fsdp = {'autocast': False, 'fsdp.use': True} 20 | medium = {'model/lm/model_scale': 'medium'} 21 | large = {'model/lm/model_scale': 'large'} 22 | 23 | cfg_low = {'classifier_free_guidance.training_dropout': 0.2} 24 | wd_low = {'conditioners.description.t5.word_dropout': 0.2} 25 | 26 | adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4} 27 | 28 | launcher.bind_(fsdp) 29 | 30 | launcher.slurm_(gpus=32).bind_(label='32gpus') 31 | with launcher.job_array(): 32 | sub = launcher.bind() 33 | sub() 34 | 35 | launcher.slurm_(gpus=64).bind_(label='64gpus') 36 | with launcher.job_array(): 37 | sub = launcher.bind() 38 | sub(medium, adam) 39 | 40 | launcher.slurm_(gpus=96).bind_(label='96gpus') 41 | with launcher.job_array(): 42 | sub = launcher.bind() 43 | sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3}) 44 | -------------------------------------------------------------------------------- /audiocraft/grids/compression/_explorers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import treetable as tt 8 | 9 | from .._base_explorers import BaseExplorer 10 | 11 | 12 | class CompressionExplorer(BaseExplorer): 13 | eval_metrics = ["sisnr", "visqol"] 14 | 15 | def stages(self): 16 | return ["train", "valid", "evaluate"] 17 | 18 | def get_grid_meta(self): 19 | """Returns the list of Meta information to display for each XP/job. 20 | """ 21 | return [ 22 | tt.leaf("index", align=">"), 23 | tt.leaf("name", wrap=140), 24 | tt.leaf("state"), 25 | tt.leaf("sig", align=">"), 26 | ] 27 | 28 | def get_grid_metrics(self): 29 | """Return the metrics that should be displayed in the tracking table. 
30 | """ 31 | return [ 32 | tt.group( 33 | "train", 34 | [ 35 | tt.leaf("epoch"), 36 | tt.leaf("bandwidth", ".2f"), 37 | tt.leaf("adv", ".4f"), 38 | tt.leaf("d_loss", ".4f"), 39 | ], 40 | align=">", 41 | ), 42 | tt.group( 43 | "valid", 44 | [ 45 | tt.leaf("bandwidth", ".2f"), 46 | tt.leaf("adv", ".4f"), 47 | tt.leaf("msspec", ".4f"), 48 | tt.leaf("sisnr", ".2f"), 49 | ], 50 | align=">", 51 | ), 52 | tt.group( 53 | "evaluate", [tt.leaf(name, ".3f") for name in self.eval_metrics], align=">" 54 | ), 55 | ] 56 | -------------------------------------------------------------------------------- /tests/models/test_audiogen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pytest 8 | import torch 9 | 10 | from audiocraft.models import AudioGen 11 | 12 | 13 | class TestAudioGenModel: 14 | def get_audiogen(self): 15 | ag = AudioGen.get_pretrained(name='debug', device='cpu') 16 | ag.set_generation_params(duration=2.0, extend_stride=2.) 17 | return ag 18 | 19 | def test_base(self): 20 | ag = self.get_audiogen() 21 | assert ag.frame_rate == 25 22 | assert ag.sample_rate == 16000 23 | assert ag.audio_channels == 1 24 | 25 | def test_generate_continuation(self): 26 | ag = self.get_audiogen() 27 | prompt = torch.randn(3, 1, 16000) 28 | wav = ag.generate_continuation(prompt, 16000) 29 | assert list(wav.shape) == [3, 1, 32000] 30 | 31 | prompt = torch.randn(2, 1, 16000) 32 | wav = ag.generate_continuation( 33 | prompt, 16000, ['youpi', 'lapin dort']) 34 | assert list(wav.shape) == [2, 1, 32000] 35 | 36 | prompt = torch.randn(2, 1, 16000) 37 | with pytest.raises(AssertionError): 38 | wav = ag.generate_continuation( 39 | prompt, 16000, ['youpi', 'lapin dort', 'one too many']) 40 | 41 | def test_generate(self): 42 | ag = self.get_audiogen() 43 | wav = ag.generate( 44 | ['youpi', 'lapin dort']) 45 | assert list(wav.shape) == [2, 1, 32000] 46 | 47 | def test_generate_long(self): 48 | ag = self.get_audiogen() 49 | ag.max_duration = 3. 50 | ag.set_generation_params(duration=4., extend_stride=2.) 51 | wav = ag.generate( 52 | ['youpi', 'lapin dort']) 53 | assert list(wav.shape) == [2, 1, 16000 * 4] 54 | -------------------------------------------------------------------------------- /audiocraft/optim/cosine_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | from torch.optim import Optimizer 10 | from torch.optim.lr_scheduler import _LRScheduler 11 | 12 | 13 | class CosineLRScheduler(_LRScheduler): 14 | """Cosine LR scheduler. 15 | 16 | Args: 17 | optimizer (Optimizer): Torch optimizer. 18 | warmup_steps (int): Number of warmup steps. 19 | total_steps (int): Total number of steps. 20 | lr_min_ratio (float): Minimum learning rate. 21 | cycle_length (float): Cycle length. 
22 | """ 23 | def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int, 24 | lr_min_ratio: float = 0.0, cycle_length: float = 1.0): 25 | self.warmup_steps = warmup_steps 26 | assert self.warmup_steps >= 0 27 | self.total_steps = total_steps 28 | assert self.total_steps >= 0 29 | self.lr_min_ratio = lr_min_ratio 30 | self.cycle_length = cycle_length 31 | super().__init__(optimizer) 32 | 33 | def _get_sched_lr(self, lr: float, step: int): 34 | if step < self.warmup_steps: 35 | lr_ratio = step / self.warmup_steps 36 | lr = lr_ratio * lr 37 | elif step <= self.total_steps: 38 | s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps) 39 | lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \ 40 | (1. + math.cos(math.pi * s / self.cycle_length)) 41 | lr = lr_ratio * lr 42 | else: 43 | lr_ratio = self.lr_min_ratio 44 | lr = lr_ratio * lr 45 | return lr 46 | 47 | def get_lr(self): 48 | return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs] 49 | -------------------------------------------------------------------------------- /audiocraft/utils/deadlock.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | from queue import Queue, Empty 10 | import signal 11 | import sys 12 | import threading 13 | import traceback 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class DeadlockDetect: 19 | def __init__(self, use: bool = False, timeout: float = 120.): 20 | self.use = use 21 | self.timeout = timeout 22 | self._queue: Queue = Queue() 23 | 24 | def update(self, stage: str): 25 | if self.use: 26 | self._queue.put(stage) 27 | 28 | def __enter__(self): 29 | if self.use: 30 | self._thread = threading.Thread(target=self._detector_thread) 31 | self._thread.start() 32 | 33 | def __exit__(self, exc_type, exc_val, exc_tb): 34 | if self.use: 35 | self._queue.put(None) 36 | self._thread.join() 37 | 38 | def _detector_thread(self): 39 | logger.debug("Deadlock detector started") 40 | last_stage = "init" 41 | while True: 42 | try: 43 | stage = self._queue.get(timeout=self.timeout) 44 | except Empty: 45 | break 46 | if stage is None: 47 | logger.debug("Exiting deadlock detector thread") 48 | return 49 | else: 50 | last_stage = stage 51 | logger.error("Deadlock detector timed out, last stage was %s", last_stage) 52 | for th in threading.enumerate(): 53 | print(th, file=sys.stderr) 54 | traceback.print_stack(sys._current_frames()[th.ident]) 55 | print(file=sys.stderr) 56 | sys.stdout.flush() 57 | sys.stderr.flush() 58 | os.kill(os.getpid(), signal.SIGKILL) 59 | -------------------------------------------------------------------------------- /tests/common_utils/temp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import tempfile 9 | 10 | 11 | class TempDirMixin: 12 | """Mixin to provide easy access to temp dir. 13 | """ 14 | 15 | temp_dir_ = None 16 | 17 | @classmethod 18 | def get_base_temp_dir(cls): 19 | # If AUDIOCRAFT_TEST_DIR is set, use it instead of temporary directory. 
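DeadlockDetect above is meant to wrap the training loop as a context manager, with `update()` acting as a heartbeat: if no heartbeat arrives within `timeout` seconds while armed, the watchdog thread dumps every thread's stack trace and kills the process so the job scheduler can restart it. A minimal, hypothetical usage:

```python
from audiocraft.utils.deadlock import DeadlockDetect

detector = DeadlockDetect(use=False, timeout=600.)   # use=True arms the watchdog thread
with detector:
    for step in range(3):
        detector.update('batch')      # heartbeat recording the current stage
        detector.update('forward')
        detector.update('backward')
```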
20 | # this is handy for debugging. 21 | key = "AUDIOCRAFT_TEST_DIR" 22 | if key in os.environ: 23 | return os.environ[key] 24 | if cls.temp_dir_ is None: 25 | cls.temp_dir_ = tempfile.TemporaryDirectory() 26 | return cls.temp_dir_.name 27 | 28 | @classmethod 29 | def tearDownClass(cls): 30 | if cls.temp_dir_ is not None: 31 | try: 32 | cls.temp_dir_.cleanup() 33 | cls.temp_dir_ = None 34 | except PermissionError: 35 | # On Windows there is a know issue with `shutil.rmtree`, 36 | # which fails intermittently. 37 | # https://github.com/python/cpython/issues/74168 38 | # Following the above thread, we ignore it. 39 | pass 40 | super().tearDownClass() 41 | 42 | @property 43 | def id(self): 44 | return self.__class__.__name__ 45 | 46 | def get_temp_path(self, *paths): 47 | temp_dir = os.path.join(self.get_base_temp_dir(), self.id) 48 | path = os.path.join(temp_dir, *paths) 49 | os.makedirs(os.path.dirname(path), exist_ok=True) 50 | return path 51 | 52 | def get_temp_dir(self, *paths): 53 | temp_dir = os.path.join(self.get_base_temp_dir(), self.id) 54 | path = os.path.join(temp_dir, *paths) 55 | os.makedirs(path, exist_ok=True) 56 | return path 57 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from pathlib import Path 8 | 9 | from setuptools import setup, find_packages 10 | 11 | 12 | NAME = 'audiocraft' 13 | DESCRIPTION = 'Audio generation research library for PyTorch' 14 | 15 | URL = 'https://github.com/facebookresearch/audiocraft' 16 | AUTHOR = 'FAIR Speech & Audio' 17 | EMAIL = 'defossez@meta.com, jadecopet@meta.com' 18 | REQUIRES_PYTHON = '>=3.8.0' 19 | 20 | for line in open('audiocraft/__init__.py'): 21 | line = line.strip() 22 | if '__version__' in line: 23 | context = {} 24 | exec(line, context) 25 | VERSION = context['__version__'] 26 | 27 | HERE = Path(__file__).parent 28 | 29 | try: 30 | with open(HERE / "README.md", encoding='utf-8') as f: 31 | long_description = '\n' + f.read() 32 | except FileNotFoundError: 33 | long_description = DESCRIPTION 34 | 35 | REQUIRED = [i.strip() for i in open(HERE / 'requirements.txt') if not i.startswith('#')] 36 | 37 | setup( 38 | name=NAME, 39 | version=VERSION, 40 | description=DESCRIPTION, 41 | author_email=EMAIL, 42 | long_description=long_description, 43 | long_description_content_type='text/markdown', 44 | author=AUTHOR, 45 | url=URL, 46 | python_requires=REQUIRES_PYTHON, 47 | install_requires=REQUIRED, 48 | extras_require={ 49 | 'dev': ['coverage', 'flake8', 'mypy', 'pdoc3', 'pytest'], 50 | }, 51 | packages=find_packages(), 52 | package_data={'audiocraft': ['py.typed']}, 53 | include_package_data=True, 54 | license='MIT License', 55 | classifiers=[ 56 | # Trove classifiers 57 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 58 | 'License :: OSI Approved :: MIT License', 59 | 'Topic :: Multimedia :: Sound/Audio', 60 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 61 | ], 62 | ) 63 | -------------------------------------------------------------------------------- /config/solver/audiogen/audiogen_base_16khz.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # This is the training loop solver 4 
| # for the base AudioGen model (text-to-sound) 5 | # on monophonic audio sampled at 16 kHz 6 | # using a similar EnCodec+LM setup to MusicGen 7 | defaults: 8 | - audiogen/default 9 | - /model: lm/audiogen_lm 10 | - override /dset: audio/default 11 | - _self_ 12 | 13 | autocast: true 14 | autocast_dtype: float16 15 | 16 | # EnCodec large trained on mono-channel music audio sampled at 16khz 17 | # with a total stride of 320 leading to 50 frames/s. 18 | # rvq.n_q=4, rvq.bins=2048, no quantization dropout 19 | # (transformer_lm card and n_q must be compatible) 20 | compression_model_checkpoint: //reference/bd44a852/checkpoint.th 21 | 22 | channels: 1 23 | sample_rate: 16000 24 | 25 | deadlock: 26 | use: true # deadlock detection 27 | 28 | dataset: 29 | batch_size: 128 # matching AudioGen paper setup (256 * mix_p=0.5 = 128) 30 | num_workers: 10 31 | segment_duration: 10 32 | min_segment_ratio: 1.0 33 | sample_on_weight: false # Uniform sampling all the way 34 | sample_on_duration: false # Uniform sampling all the way 35 | external_metadata_source: null 36 | # sample mixing augmentation at train time 37 | train: 38 | batch_size: 256 # matching AudioGen paper setup 39 | aug_p: 0.5 # perform audio mixing 50% of the time 40 | mix_p: 0.5 # proportion of batch items mixed together 41 | # important: note that this will reduce the 42 | # actual batch size used at train time 43 | # which will be equal to mix_p * batch_size 44 | mix_snr_low: -5 45 | mix_snr_high: 5 46 | mix_min_overlap: 0.5 47 | 48 | generate: 49 | lm: 50 | use_sampling: true 51 | top_k: 250 52 | top_p: 0.0 53 | 54 | optim: 55 | epochs: 100 56 | optimizer: adamw 57 | lr: 5e-4 58 | ema: 59 | use: true 60 | updates: 10 61 | device: cuda 62 | 63 | logging: 64 | log_tensorboard: true 65 | 66 | schedule: 67 | lr_scheduler: inverse_sqrt 68 | inverse_sqrt: 69 | warmup: 3000 70 | warmup_init_lr: 0.0 71 | -------------------------------------------------------------------------------- /scripts/static/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | background-color: #fbfbfb; 3 | margin: 0; 4 | } 5 | 6 | select, input { 7 | font-size: 1em; 8 | max-width: 100%; 9 | } 10 | 11 | .xp_name { 12 | font-family: monospace; 13 | } 14 | 15 | .simple_form { 16 | background-color: #dddddd; 17 | padding: 1em; 18 | margin: 0.5em; 19 | } 20 | 21 | textarea { 22 | margin-top: 0.5em; 23 | margin-bottom: 0.5em; 24 | } 25 | 26 | .rating { 27 | background-color: grey; 28 | padding-top: 5px; 29 | padding-bottom: 5px; 30 | padding-left: 8px; 31 | padding-right: 8px; 32 | margin-right: 2px; 33 | cursor:pointer; 34 | } 35 | 36 | .rating_selected { 37 | background-color: purple; 38 | } 39 | 40 | .content { 41 | font-family: sans-serif; 42 | background-color: #f6f6f6; 43 | padding: 40px; 44 | margin: 0 auto; 45 | max-width: 1000px; 46 | } 47 | 48 | .track label { 49 | padding-top: 10px; 50 | padding-bottom: 10px; 51 | } 52 | .track { 53 | padding: 15px; 54 | margin: 5px; 55 | background-color: #c8c8c8; 56 | } 57 | 58 | .submit-big { 59 | width:400px; 60 | height:30px; 61 | font-size: 20px; 62 | } 63 | 64 | .error { 65 | color: red; 66 | } 67 | 68 | .ratings { 69 | margin-left: 10px; 70 | } 71 | 72 | .important { 73 | font-weight: bold; 74 | } 75 | 76 | .survey { 77 | margin-bottom: 100px; 78 | } 79 | 80 | .success { 81 | color: #25901b; 82 | font-weight: bold; 83 | } 84 | .warning { 85 | color: #8a1f19; 86 | font-weight: bold; 87 | } 88 | .track>section { 89 | display: flex; 90 | align-items: center; 91 | } 92 | 93 
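The batch-size comments in the dataset section above hide a small piece of arithmetic worth spelling out: with the mixing augmentation, pairs of items from the train batch are summed together (at a random SNR in `[mix_snr_low, mix_snr_high]` dB with at least `mix_min_overlap` overlap), so 256 loaded items collapse to 128 effective items, matching the top-level `batch_size`:

```python
train_batch, mix_p = 256, 0.5
effective_items = int(train_batch * mix_p)   # 128 items actually reach the model
assert effective_items == 128                # equals dataset.batch_size above
```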
| .prompt { 94 | display: flex; 95 | align-items: center; 96 | } 97 | 98 | .track>section>div { 99 | padding-left: 10px; 100 | } 101 | 102 | audio { 103 | max-width: 280px; 104 | max-height: 40px; 105 | margin-left: 10px; 106 | margin-right: 10px; 107 | } 108 | 109 | .special { 110 | font-weight: bold; 111 | color: #2c2c2c; 112 | } 113 | 114 | -------------------------------------------------------------------------------- /tests/models/test_musicgen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import pytest 8 | import torch 9 | 10 | from audiocraft.models import MusicGen 11 | 12 | 13 | class TestMusicGenModel: 14 | def get_musicgen(self): 15 | mg = MusicGen.get_pretrained(name='debug', device='cpu') 16 | mg.set_generation_params(duration=2.0, extend_stride=2.) 17 | return mg 18 | 19 | def test_base(self): 20 | mg = self.get_musicgen() 21 | assert mg.frame_rate == 25 22 | assert mg.sample_rate == 32000 23 | assert mg.audio_channels == 1 24 | 25 | def test_generate_unconditional(self): 26 | mg = self.get_musicgen() 27 | wav = mg.generate_unconditional(3) 28 | assert list(wav.shape) == [3, 1, 64000] 29 | 30 | def test_generate_continuation(self): 31 | mg = self.get_musicgen() 32 | prompt = torch.randn(3, 1, 32000) 33 | wav = mg.generate_continuation(prompt, 32000) 34 | assert list(wav.shape) == [3, 1, 64000] 35 | 36 | prompt = torch.randn(2, 1, 32000) 37 | wav = mg.generate_continuation( 38 | prompt, 32000, ['youpi', 'lapin dort']) 39 | assert list(wav.shape) == [2, 1, 64000] 40 | 41 | prompt = torch.randn(2, 1, 32000) 42 | with pytest.raises(AssertionError): 43 | wav = mg.generate_continuation( 44 | prompt, 32000, ['youpi', 'lapin dort', 'one too many']) 45 | 46 | def test_generate(self): 47 | mg = self.get_musicgen() 48 | wav = mg.generate( 49 | ['youpi', 'lapin dort']) 50 | assert list(wav.shape) == [2, 1, 64000] 51 | 52 | def test_generate_long(self): 53 | mg = self.get_musicgen() 54 | mg.max_duration = 3. 55 | mg.set_generation_params(duration=4., extend_stride=2.) 56 | wav = mg.generate( 57 | ['youpi', 'lapin dort']) 58 | assert list(wav.shape) == [2, 1, 32000 * 4] 59 | -------------------------------------------------------------------------------- /audiocraft/utils/export_legacy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Legacy functions used at the time of the first release, kept for referencd. 9 | """ 10 | 11 | from pathlib import Path 12 | import typing as tp 13 | 14 | from omegaconf import OmegaConf, DictConfig 15 | import torch 16 | 17 | 18 | def _clean_lm_cfg(cfg: DictConfig): 19 | OmegaConf.set_struct(cfg, False) 20 | # This used to be set automatically in the LM solver, need a more robust solution 21 | # for the future. 22 | cfg['transformer_lm']['card'] = 2048 23 | cfg['transformer_lm']['n_q'] = 4 24 | # Experimental params no longer supported. 
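The MusicGen tests above exercise the same public API used for inference. A short usage sketch; the `'facebook/musicgen-small'` checkpoint name and output path are illustrative (the tests use the tiny `'debug'` model on CPU):

```python
import torchaudio
from audiocraft.models import MusicGen

model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(duration=4.0)        # seconds of audio to generate
wav = model.generate(['lofi hip hop beat'])      # [B, C, T] with T = duration * sample_rate
torchaudio.save('sample.wav', wav[0].cpu(), model.sample_rate)
```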
25 | bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters', 26 | 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop'] 27 | for name in bad_params: 28 | del cfg['transformer_lm'][name] 29 | OmegaConf.set_struct(cfg, True) 30 | return cfg 31 | 32 | 33 | def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): 34 | sig = Path(checkpoint_path).parent.name 35 | assert len(sig) == 8, "Not a valid Dora signature" 36 | pkg = torch.load(checkpoint_path, 'cpu') 37 | new_pkg = { 38 | 'best_state': pkg['ema']['state']['model'], 39 | 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), 40 | } 41 | out_file = Path(out_folder) / f'{sig}.th' 42 | torch.save(new_pkg, out_file) 43 | return out_file 44 | 45 | 46 | def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): 47 | sig = Path(checkpoint_path).parent.name 48 | assert len(sig) == 8, "Not a valid Dora signature" 49 | pkg = torch.load(checkpoint_path, 'cpu') 50 | new_pkg = { 51 | 'best_state': pkg['fsdp_best_state']['model'], 52 | 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])) 53 | } 54 | out_file = Path(out_folder) / f'{sig}.th' 55 | torch.save(new_pkg, out_file) 56 | return out_file 57 | -------------------------------------------------------------------------------- /tests/losses/test_losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import random 8 | 9 | import torch 10 | 11 | from audiocraft.losses import ( 12 | MelSpectrogramL1Loss, 13 | MultiScaleMelSpectrogramLoss, 14 | MRSTFTLoss, 15 | SISNR, 16 | STFTLoss, 17 | ) 18 | 19 | 20 | def test_mel_l1_loss(): 21 | N, C, T = 2, 2, random.randrange(1000, 100_000) 22 | t1 = torch.randn(N, C, T) 23 | t2 = torch.randn(N, C, T) 24 | 25 | mel_l1 = MelSpectrogramL1Loss(sample_rate=22_050) 26 | loss = mel_l1(t1, t2) 27 | loss_same = mel_l1(t1, t1) 28 | 29 | assert isinstance(loss, torch.Tensor) 30 | assert isinstance(loss_same, torch.Tensor) 31 | assert loss_same.item() == 0.0 32 | 33 | 34 | def test_msspec_loss(): 35 | N, C, T = 2, 2, random.randrange(1000, 100_000) 36 | t1 = torch.randn(N, C, T) 37 | t2 = torch.randn(N, C, T) 38 | 39 | msspec = MultiScaleMelSpectrogramLoss(sample_rate=22_050) 40 | loss = msspec(t1, t2) 41 | loss_same = msspec(t1, t1) 42 | 43 | assert isinstance(loss, torch.Tensor) 44 | assert isinstance(loss_same, torch.Tensor) 45 | assert loss_same.item() == 0.0 46 | 47 | 48 | def test_mrstft_loss(): 49 | N, C, T = 2, 2, random.randrange(1000, 100_000) 50 | t1 = torch.randn(N, C, T) 51 | t2 = torch.randn(N, C, T) 52 | 53 | mrstft = MRSTFTLoss() 54 | loss = mrstft(t1, t2) 55 | 56 | assert isinstance(loss, torch.Tensor) 57 | 58 | 59 | def test_sisnr_loss(): 60 | N, C, T = 2, 2, random.randrange(1000, 100_000) 61 | t1 = torch.randn(N, C, T) 62 | t2 = torch.randn(N, C, T) 63 | 64 | sisnr = SISNR() 65 | loss = sisnr(t1, t2) 66 | 67 | assert isinstance(loss, torch.Tensor) 68 | 69 | 70 | def test_stft_loss(): 71 | N, C, T = 2, 2, random.randrange(1000, 100_000) 72 | t1 = torch.randn(N, C, T) 73 | t2 = torch.randn(N, C, T) 74 | 75 | mrstft = STFTLoss() 76 | loss = mrstft(t1, t2) 77 | 78 | assert isinstance(loss, torch.Tensor) 79 | -------------------------------------------------------------------------------- /config/model/lm/default.yaml: 
-------------------------------------------------------------------------------- 1 | # @package __global__ 2 | defaults: 3 | - _self_ 4 | - /model/lm/model_scale: base # prefer this group to set model scale instead of transformer_lm keys directly 5 | 6 | lm_model: transformer_lm 7 | 8 | codebooks_pattern: 9 | modeling: parallel 10 | 11 | transformer_lm: 12 | dim: 512 13 | num_heads: 8 14 | num_layers: 8 15 | hidden_scale: 4 16 | n_q: 8 # number of streams to model 17 | card: 1024 18 | dropout: 0. 19 | emb_lr: null 20 | activation: gelu 21 | norm_first: false # use pre-norm instead of post-norm 22 | bias_ff: true # use bias for the feedforward 23 | bias_attn: true # use bias for the attention 24 | bias_proj: true # use bias for the output projections 25 | past_context: null 26 | causal: true 27 | custom: false # use custom MHA implementation 28 | memory_efficient: false # use flash attention 29 | attention_as_float32: false # use float32 for the attention part, 30 | # recommended at the moment when memory_efficient is True. 31 | layer_scale: null 32 | positional_embedding: sin # positional embedding strategy (sin, rope, or sin_rope). 33 | xpos: false # apply xpos decay (rope only). 34 | checkpointing: none # layer checkpointing method, can be none, torch, xformers_default. 35 | # torch is the slowest but uses the least memory, 36 | # xformers_default is somewhere in between. 37 | weight_init: null # weight initialization (null, gaussian or uniform) 38 | depthwise_init: null # perform depthwise initialization (null, current, global) 39 | zero_bias_init: false # initialize bias to zero if bias in linears and 40 | # if a weight_init method is used. 41 | norm: layer_norm # normalization method to use in transformer. 42 | cross_attention: false 43 | qk_layer_norm: false 44 | qk_layer_norm_cross: false 45 | attention_dropout: null 46 | kv_repeat: 1 47 | two_step_cfg: false # whether to do true 2 steps CFG, potentially resolving some padding issues or not... 48 | -------------------------------------------------------------------------------- /audiocraft/optim/polynomial_decay_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torch.optim import Optimizer 8 | from torch.optim.lr_scheduler import _LRScheduler 9 | 10 | 11 | class PolynomialDecayLRScheduler(_LRScheduler): 12 | """Polynomial decay LR scheduler. 13 | 14 | Args: 15 | optimizer (Optimizer): Torch optimizer. 16 | warmup_steps (int): Number of warmup steps. 17 | total_steps (int): Total number of steps. 18 | end_lr (float): Final learning rate to achieve over total number of steps. 19 | zero_lr_warmup_steps (int): Number of steps with a learning rate of value 0. 20 | power (float): Decay exponent. 
21 | """ 22 | def __init__(self, optimizer: Optimizer, warmup_steps: int, total_steps: int, 23 | end_lr: float = 0., zero_lr_warmup_steps: int = 0, power: float = 1.): 24 | self.warmup_steps = warmup_steps 25 | self.total_steps = total_steps 26 | self.end_lr = end_lr 27 | self.zero_lr_warmup_steps = zero_lr_warmup_steps 28 | self.power = power 29 | super().__init__(optimizer) 30 | 31 | def _get_sched_lr(self, lr: float, step: int): 32 | if self.zero_lr_warmup_steps > 0 and step <= self.zero_lr_warmup_steps: 33 | lr = 0 34 | elif self.warmup_steps > 0 and step <= self.warmup_steps + self.zero_lr_warmup_steps: 35 | lr_ratio = (step - self.zero_lr_warmup_steps) / float(self.warmup_steps) 36 | lr = lr_ratio * lr 37 | elif step >= self.total_steps: 38 | lr = self.end_lr 39 | else: 40 | total_warmup_steps = self.warmup_steps + self.zero_lr_warmup_steps 41 | lr_range = lr - self.end_lr 42 | pct_remaining = 1 - (step - total_warmup_steps) / (self.total_steps - total_warmup_steps) 43 | lr = lr_range * pct_remaining ** self.power + self.end_lr 44 | return lr 45 | 46 | def get_lr(self): 47 | return [self._get_sched_lr(base_lr, self.last_epoch) for base_lr in self.base_lrs] 48 | -------------------------------------------------------------------------------- /config/solver/diffusion/debug.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - /solver/default 5 | - /model: score/basic 6 | - override /dset: audio/default 7 | - _self_ 8 | 9 | solver: diffusion 10 | 11 | sample_rate: 16000 12 | channels: 1 13 | compression_model_checkpoint: //sig/5091833e 14 | n_q: 2 # number of codebooks to keep 15 | 16 | dataset: 17 | batch_size: 8 18 | num_workers: 10 19 | segment_duration: 1 20 | train: 21 | num_samples: 100 22 | valid: 23 | num_samples: 100 24 | evaluate: 25 | batch_size: 8 26 | num_samples: 10 27 | generate: 28 | batch_size: 8 29 | num_samples: 10 30 | segment_duration: 10 31 | 32 | loss: 33 | kind: mse 34 | norm_power: 0. 35 | 36 | valid: 37 | every: 1 38 | 39 | evaluate: 40 | every: 5 41 | num_workers: 5 42 | metrics: 43 | visqol: false 44 | sisnr: false 45 | rvm: true 46 | 47 | generate: 48 | every: 5 49 | num_workers: 5 50 | audio: 51 | sample_rate: ${sample_rate} 52 | 53 | checkpoint: 54 | save_last: true 55 | save_every: 25 56 | keep_last: 10 57 | keep_every_states: null 58 | 59 | 60 | optim: 61 | epochs: 50 62 | updates_per_epoch: 2000 63 | lr: 2e-4 64 | max_norm: 0 65 | optimizer: adam 66 | adam: 67 | betas: [0.9, 0.999] 68 | weight_decay: 0. 69 | ema: 70 | use: true # whether to use EMA or not 71 | updates: 1 # update at every step 72 | device: ${device} # device for EMA, can be put on GPU if more frequent updates 73 | decay: 0.99 # EMA decay value, if null, no EMA is used 74 | 75 | processor: 76 | name: multi_band_processor 77 | use: false 78 | n_bands: 8 79 | num_samples: 10_000 80 | power_std: 1. 81 | 82 | resampling: 83 | use: false 84 | target_sr: 16000 85 | 86 | filter: 87 | use: false 88 | n_bands: 4 89 | idx_band: 0 90 | cutoffs: null 91 | 92 | schedule: 93 | repartition: "power" 94 | variable_step_batch: true 95 | beta_t0: 1.0e-5 96 | beta_t1: 2.9e-2 97 | beta_exp: 7.5 98 | num_steps: 1000 99 | variance: 'beta' 100 | clip: 5. 101 | rescale: 1. 
102 | n_bands: null 103 | noise_scale: 1.0 104 | 105 | metrics: 106 | num_stage: 4 107 | -------------------------------------------------------------------------------- /config/solver/diffusion/default.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - /solver/default 5 | - /model: score/basic 6 | - override /dset: audio/default 7 | - _self_ 8 | 9 | solver: diffusion 10 | 11 | sample_rate: ??? 12 | channels: ??? 13 | compression_model_checkpoint: ??? 14 | n_q: ??? # number of codebooks to keep 15 | 16 | 17 | dataset: 18 | batch_size: 128 19 | num_workers: 10 20 | segment_duration: 1 21 | train: 22 | num_samples: 500000 23 | valid: 24 | num_samples: 10000 25 | evaluate: 26 | batch_size: 16 27 | num_samples: 10000 28 | generate: 29 | batch_size: 32 30 | num_samples: 50 31 | segment_duration: 10 32 | audio: 33 | sample_rate: ${sample_rate} 34 | 35 | loss: 36 | kind: mse 37 | norm_power: 0. 38 | 39 | valid: 40 | every: 1 41 | 42 | evaluate: 43 | every: 20 44 | num_workers: 5 45 | metrics: 46 | visqol: false 47 | sisnr: false 48 | rvm: true 49 | 50 | generate: 51 | every: 25 52 | num_workers: 5 53 | 54 | checkpoint: 55 | save_last: true 56 | save_every: 25 57 | keep_last: 10 58 | keep_every_states: null 59 | 60 | 61 | optim: 62 | epochs: 20000 63 | updates_per_epoch: 2000 64 | lr: 2e-4 65 | max_norm: 0 66 | optimizer: adam 67 | adam: 68 | betas: [0.9, 0.999] 69 | weight_decay: 0. 70 | ema: 71 | use: true # whether to use EMA or not 72 | updates: 1 # update at every step 73 | device: ${device} # device for EMA, can be put on GPU if more frequent updates 74 | decay: 0.99 # EMA decay value, if null, no EMA is used 75 | 76 | processor: 77 | name: multi_band_processor 78 | use: false 79 | n_bands: 8 80 | num_samples: 10_000 81 | power_std: 1. 82 | 83 | resampling: 84 | use: false 85 | target_sr: 16000 86 | 87 | filter: 88 | use: false 89 | n_bands: 4 90 | idx_band: 0 91 | cutoffs: null 92 | 93 | schedule: 94 | repartition: "power" 95 | variable_step_batch: true 96 | beta_t0: 1.0e-5 97 | beta_t1: 2.9e-2 98 | beta_exp: 7.5 99 | num_steps: 1000 100 | variance: 'beta' 101 | clip: 5. 102 | rescale: 1. 103 | n_bands: null 104 | noise_scale: 1.0 105 | 106 | metrics: 107 | num_stage: 4 108 | -------------------------------------------------------------------------------- /audiocraft/grids/diffusion/_explorers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import treetable as tt 8 | 9 | from .._base_explorers import BaseExplorer 10 | 11 | 12 | class DiffusionExplorer(BaseExplorer): 13 | eval_metrics = ["sisnr", "visqol"] 14 | 15 | def stages(self): 16 | return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"] 17 | 18 | def get_grid_meta(self): 19 | """Returns the list of Meta information to display for each XP/job. 20 | """ 21 | return [ 22 | tt.leaf("index", align=">"), 23 | tt.leaf("name", wrap=140), 24 | tt.leaf("state"), 25 | tt.leaf("sig", align=">"), 26 | ] 27 | 28 | def get_grid_metrics(self): 29 | """Return the metrics that should be displayed in the tracking table. 
30 | """ 31 | return [ 32 | tt.group( 33 | "train", 34 | [ 35 | tt.leaf("epoch"), 36 | tt.leaf("loss", ".3%"), 37 | ], 38 | align=">", 39 | ), 40 | tt.group( 41 | "valid", 42 | [ 43 | tt.leaf("loss", ".3%"), 44 | # tt.leaf("loss_0", ".3%"), 45 | ], 46 | align=">", 47 | ), 48 | tt.group( 49 | "valid_ema", 50 | [ 51 | tt.leaf("loss", ".3%"), 52 | # tt.leaf("loss_0", ".3%"), 53 | ], 54 | align=">", 55 | ), 56 | tt.group( 57 | "evaluate", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"), 58 | tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"), 59 | tt.leaf("rvm_3", ".4f"), ], align=">" 60 | ), 61 | tt.group( 62 | "evaluate_ema", [tt.leaf("rvm", ".4f"), tt.leaf("rvm_0", ".4f"), 63 | tt.leaf("rvm_1", ".4f"), tt.leaf("rvm_2", ".4f"), 64 | tt.leaf("rvm_3", ".4f")], align=">" 65 | ), 66 | ] 67 | -------------------------------------------------------------------------------- /audiocraft/utils/cluster.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Utility functions for SLURM configuration and cluster settings. 9 | """ 10 | 11 | from enum import Enum 12 | import os 13 | import socket 14 | import typing as tp 15 | 16 | import omegaconf 17 | 18 | 19 | class ClusterType(Enum): 20 | AWS = "aws" 21 | FAIR = "fair" 22 | RSC = "rsc" 23 | LOCAL_DARWIN = "darwin" 24 | DEFAULT = "default" # used for any other cluster. 25 | 26 | 27 | def _guess_cluster_type() -> ClusterType: 28 | uname = os.uname() 29 | fqdn = socket.getfqdn() 30 | if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn): 31 | return ClusterType.AWS 32 | 33 | if fqdn.endswith(".fair"): 34 | return ClusterType.FAIR 35 | 36 | if fqdn.endswith(".facebook.com"): 37 | return ClusterType.RSC 38 | 39 | if uname.sysname == "Darwin": 40 | return ClusterType.LOCAL_DARWIN 41 | 42 | return ClusterType.DEFAULT 43 | 44 | 45 | def get_cluster_type( 46 | cluster_type: tp.Optional[ClusterType] = None, 47 | ) -> tp.Optional[ClusterType]: 48 | if cluster_type is None: 49 | return _guess_cluster_type() 50 | 51 | return cluster_type 52 | 53 | 54 | def get_slurm_parameters( 55 | cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None 56 | ) -> omegaconf.DictConfig: 57 | """Update SLURM parameters in configuration based on cluster type. 58 | If the cluster type is not specify, it infers it automatically. 59 | """ 60 | from ..environment import AudioCraftEnvironment 61 | cluster_type = get_cluster_type(cluster_type) 62 | # apply cluster-specific adjustments 63 | if cluster_type == ClusterType.AWS: 64 | cfg["mem_per_gpu"] = None 65 | cfg["constraint"] = None 66 | cfg["setup"] = [] 67 | elif cluster_type == ClusterType.RSC: 68 | cfg["mem_per_gpu"] = None 69 | cfg["setup"] = [] 70 | cfg["constraint"] = None 71 | cfg["partition"] = "learn" 72 | slurm_exclude = AudioCraftEnvironment.get_slurm_exclude() 73 | if slurm_exclude is not None: 74 | cfg["exclude"] = slurm_exclude 75 | return cfg 76 | -------------------------------------------------------------------------------- /audiocraft/data/zip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """Utility for reading some info from inside a zip file.
7 | """
8 |
9 | import typing
10 | import zipfile
11 |
12 | from dataclasses import dataclass
13 | from functools import lru_cache
14 | from typing_extensions import Literal
15 |
16 |
17 | DEFAULT_SIZE = 32
18 | MODE = Literal['r', 'w', 'x', 'a']
19 |
20 |
21 | @dataclass(order=True)
22 | class PathInZip:
23 |     """Holds the path of a file within a zip file.
24 |
25 |     Args:
26 |         path (str): The convention is <zip_path>:<file_path>.
27 |             Let's assume there is a zip file /some/location/foo.zip
28 |             and inside of it is a json file located at /data/file1.json,
29 |             then we expect path = "/some/location/foo.zip:/data/file1.json".
30 |     """
31 |
32 |     INFO_PATH_SEP = ':'
33 |     zip_path: str
34 |     file_path: str
35 |
36 |     def __init__(self, path: str) -> None:
37 |         split_path = path.split(self.INFO_PATH_SEP)
38 |         assert len(split_path) == 2
39 |         self.zip_path, self.file_path = split_path
40 |
41 |     @classmethod
42 |     def from_paths(cls, zip_path: str, file_path: str):
43 |         return cls(zip_path + cls.INFO_PATH_SEP + file_path)
44 |
45 |     def __str__(self) -> str:
46 |         return self.zip_path + self.INFO_PATH_SEP + self.file_path
47 |
48 |
49 | def _open_zip(path: str, mode: MODE = 'r'):
50 |     return zipfile.ZipFile(path, mode)
51 |
52 |
53 | _cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
54 |
55 |
56 | def set_zip_cache_size(max_size: int):
57 |     """Sets the maximal LRU cache size for zip file opening.
58 |
59 |     Args:
60 |         max_size (int): the maximal LRU cache size.
61 |     """
62 |     global _cached_open_zip
63 |     _cached_open_zip = lru_cache(max_size)(_open_zip)
64 |
65 |
66 | def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
67 |     """Opens a file stored inside a zip and returns a file-like object.
68 |
69 |     Args:
70 |         path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
71 |         mode (str): The mode in which to open the file.
72 |     Returns:
73 |         A file-like object for PathInZip.
74 |     """
75 |     zf = _cached_open_zip(path_in_zip.zip_path)
76 |     return zf.open(path_in_zip.file_path)
77 |
-------------------------------------------------------------------------------- /audiocraft/grids/musicgen/musicgen_melody_32khz.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 7 | from ._explorers import LMExplorer 8 | from ...environment import AudioCraftEnvironment 9 | 10 | 11 | @LMExplorer 12 | def explorer(launcher): 13 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) 14 | launcher.slurm_(gpus=32, partition=partitions) 15 | launcher.bind_(solver='musicgen/musicgen_melody_32khz') 16 | # replace this by the desired music dataset 17 | launcher.bind_(dset='internal/music_400k_32khz') 18 | 19 | fsdp = {'autocast': False, 'fsdp.use': True} 20 | medium = {'model/lm/model_scale': 'medium'} 21 | large = {'model/lm/model_scale': 'large'} 22 | 23 | cfg_low = {'classifier_free_guidance.training_dropout': 0.2} 24 | wd_low = {'conditioners.description.t5.word_dropout': 0.2} 25 | 26 | adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4} 27 | 28 | cache_path = {'conditioners.self_wav.chroma_stem.cache_path': 29 | '/fsx-audio-craft-llm/jadecopet/experiments/audiocraft/caches/chroma_stem'} 30 | 31 | # CACHE GENERATION JOBS 32 | n_cache_gen_jobs = 4 33 | gen_sub = launcher.slurm(gpus=1) 34 | gen_sub.bind_( 35 | cache_path, { 36 | # the cache is always computed over the whole file, so duration doesn't matter here. 37 | 'dataset.segment_duration': 2., 38 | 'dataset.batch_size': 8, 39 | 'dataset.train.permutation_on_files': True, # try to not repeat files. 40 | 'optim.epochs': 10, 41 | 'model/lm/model_scale': 'xsmall', 42 | 43 | }) 44 | with gen_sub.job_array(): 45 | for gen_job in range(n_cache_gen_jobs): 46 | gen_sub({'dataset.train.shuffle_seed': gen_job}) 47 | 48 | # ACTUAL TRAINING JOBS. 49 | launcher.bind_(fsdp) 50 | 51 | launcher.slurm_(gpus=32).bind_(label='32gpus') 52 | with launcher.job_array(): 53 | sub = launcher.bind() 54 | sub() 55 | sub(cache_path) 56 | 57 | launcher.slurm_(gpus=64).bind_(label='64gpus') 58 | with launcher.job_array(): 59 | sub = launcher.bind() 60 | sub(medium, adam) 61 | 62 | launcher.slurm_(gpus=96).bind_(label='96gpus') 63 | with launcher.job_array(): 64 | sub = launcher.bind() 65 | sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3}) 66 | -------------------------------------------------------------------------------- /tests/models/test_multibanddiffusion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import random 8 | 9 | import numpy as np 10 | import torch 11 | from audiocraft.models.multibanddiffusion import MultiBandDiffusion, DiffusionProcess 12 | from audiocraft.models import EncodecModel, DiffusionUnet 13 | from audiocraft.modules import SEANetEncoder, SEANetDecoder 14 | from audiocraft.modules.diffusion_schedule import NoiseSchedule 15 | from audiocraft.quantization import DummyQuantizer 16 | 17 | 18 | class TestMBD: 19 | 20 | def _create_mbd(self, 21 | sample_rate: int, 22 | channels: int, 23 | n_filters: int = 3, 24 | n_residual_layers: int = 1, 25 | ratios: list = [5, 4, 3, 2], 26 | num_steps: int = 1000, 27 | codec_dim: int = 128, 28 | **kwargs): 29 | frame_rate = np.prod(ratios) 30 | encoder = SEANetEncoder(channels=channels, dimension=codec_dim, n_filters=n_filters, 31 | n_residual_layers=n_residual_layers, ratios=ratios) 32 | decoder = SEANetDecoder(channels=channels, dimension=codec_dim, n_filters=n_filters, 33 | n_residual_layers=n_residual_layers, ratios=ratios) 34 | quantizer = DummyQuantizer() 35 | compression_model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate, 36 | sample_rate=sample_rate, channels=channels, **kwargs) 37 | diffusion_model = DiffusionUnet(chin=channels, num_steps=num_steps, codec_dim=codec_dim) 38 | schedule = NoiseSchedule(device='cpu', num_steps=num_steps) 39 | DP = DiffusionProcess(model=diffusion_model, noise_schedule=schedule) 40 | mbd = MultiBandDiffusion(DPs=[DP], codec_model=compression_model) 41 | return mbd 42 | 43 | def test_model(self): 44 | random.seed(1234) 45 | sample_rate = 24_000 46 | channels = 1 47 | codec_dim = 128 48 | mbd = self._create_mbd(sample_rate=sample_rate, channels=channels, codec_dim=codec_dim) 49 | for _ in range(10): 50 | length = random.randrange(1, 10_000) 51 | x = torch.randn(2, channels, length) 52 | res = mbd.regenerate(x, sample_rate) 53 | assert res.shape == x.shape 54 | -------------------------------------------------------------------------------- /tests/adversarial/test_discriminators.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import random 8 | 9 | import torch 10 | 11 | from audiocraft.adversarial.discriminators import ( 12 | MultiPeriodDiscriminator, 13 | MultiScaleDiscriminator, 14 | MultiScaleSTFTDiscriminator 15 | ) 16 | 17 | 18 | class TestMultiPeriodDiscriminator: 19 | 20 | def test_mpd_discriminator(self): 21 | N, C, T = 2, 2, random.randrange(1, 100_000) 22 | t0 = torch.randn(N, C, T) 23 | periods = [1, 2, 3] 24 | mpd = MultiPeriodDiscriminator(periods=periods, in_channels=C) 25 | logits, fmaps = mpd(t0) 26 | 27 | assert len(logits) == len(periods) 28 | assert len(fmaps) == len(periods) 29 | assert all([logit.shape[0] == N and len(logit.shape) == 4 for logit in logits]) 30 | assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap]) 31 | 32 | 33 | class TestMultiScaleDiscriminator: 34 | 35 | def test_msd_discriminator(self): 36 | N, C, T = 2, 2, random.randrange(1, 100_000) 37 | t0 = torch.randn(N, C, T) 38 | 39 | scale_norms = ['weight_norm', 'weight_norm'] 40 | msd = MultiScaleDiscriminator(scale_norms=scale_norms, in_channels=C) 41 | logits, fmaps = msd(t0) 42 | 43 | assert len(logits) == len(scale_norms) 44 | assert len(fmaps) == len(scale_norms) 45 | assert all([logit.shape[0] == N and len(logit.shape) == 3 for logit in logits]) 46 | assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap]) 47 | 48 | 49 | class TestMultiScaleStftDiscriminator: 50 | 51 | def test_msstftd_discriminator(self): 52 | N, C, T = 2, 2, random.randrange(1, 100_000) 53 | t0 = torch.randn(N, C, T) 54 | 55 | n_filters = 4 56 | n_ffts = [128, 256, 64] 57 | hop_lengths = [32, 64, 16] 58 | win_lengths = [128, 256, 64] 59 | 60 | msstftd = MultiScaleSTFTDiscriminator(filters=n_filters, n_ffts=n_ffts, hop_lengths=hop_lengths, 61 | win_lengths=win_lengths, in_channels=C) 62 | logits, fmaps = msstftd(t0) 63 | 64 | assert len(logits) == len(n_ffts) 65 | assert len(fmaps) == len(n_ffts) 66 | assert all([logit.shape[0] == N and len(logit.shape) == 4 for logit in logits]) 67 | assert all([feature.shape[0] == N for fmap in fmaps for feature in fmap]) 68 | -------------------------------------------------------------------------------- /tests/models/test_encodec_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import random 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from audiocraft.models import EncodecModel 13 | from audiocraft.modules import SEANetEncoder, SEANetDecoder 14 | from audiocraft.quantization import DummyQuantizer 15 | 16 | 17 | class TestEncodecModel: 18 | 19 | def _create_encodec_model(self, 20 | sample_rate: int, 21 | channels: int, 22 | dim: int = 5, 23 | n_filters: int = 3, 24 | n_residual_layers: int = 1, 25 | ratios: list = [5, 4, 3, 2], 26 | **kwargs): 27 | frame_rate = np.prod(ratios) 28 | encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters, 29 | n_residual_layers=n_residual_layers, ratios=ratios) 30 | decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters, 31 | n_residual_layers=n_residual_layers, ratios=ratios) 32 | quantizer = DummyQuantizer() 33 | model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate, 34 | sample_rate=sample_rate, channels=channels, **kwargs) 35 | return model 36 | 37 | def test_model(self): 38 | random.seed(1234) 39 | sample_rate = 24_000 40 | channels = 1 41 | model = self._create_encodec_model(sample_rate, channels) 42 | for _ in range(10): 43 | length = random.randrange(1, 10_000) 44 | x = torch.randn(2, channels, length) 45 | res = model(x) 46 | assert res.x.shape == x.shape 47 | 48 | def test_model_renorm(self): 49 | random.seed(1234) 50 | sample_rate = 24_000 51 | channels = 1 52 | model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False) 53 | model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True) 54 | 55 | for _ in range(10): 56 | length = random.randrange(1, 10_000) 57 | x = torch.randn(2, channels, length) 58 | codes, scales = model_nonorm.encode(x) 59 | codes, scales = model_renorm.encode(x) 60 | assert scales is not None 61 | -------------------------------------------------------------------------------- /audiocraft/grids/musicgen/musicgen_base_cached_32khz.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from ._explorers import LMExplorer 8 | from ...environment import AudioCraftEnvironment 9 | 10 | 11 | @LMExplorer 12 | def explorer(launcher): 13 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) 14 | launcher.slurm_(gpus=32, partition=partitions) 15 | launcher.bind_(solver='musicgen/musicgen_base_32khz') 16 | # replace this by the desired music dataset 17 | launcher.bind_(dset='internal/music_400k_32khz') 18 | 19 | fsdp = {'autocast': False, 'fsdp.use': True} 20 | medium = {'model/lm/model_scale': 'medium'} 21 | large = {'model/lm/model_scale': 'large'} 22 | 23 | cfg_low = {'classifier_free_guidance.training_dropout': 0.2} 24 | wd_low = {'conditioners.description.t5.word_dropout': 0.2} 25 | 26 | adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4} 27 | 28 | # BEGINNING OF CACHE WRITING JOBS. 
29 | cache_write = { 30 | 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k', 31 | 'cache.write': True, 32 | 'generate.every': 500, 33 | 'evaluate.every': 500, 34 | 'logging.log_updates': 50, 35 | } 36 | 37 | cache_sub = launcher.bind({'model/lm/model_scale': 'xsmall', 'conditioner': 'none'}) 38 | cache_sub.bind_({'deadlock.use': True}) 39 | cache_sub.slurm_(gpus=8) 40 | with launcher.job_array(): 41 | num_shards = 10 # total number of jobs running in parallel. 42 | for shard in range(0, num_shards): 43 | launcher(cache_write, {'cache.write_num_shards': num_shards, 'cache.write_shard': shard}) 44 | 45 | # REMOVE THE FOLLOWING RETURN STATEMENT ONCE THE ABOVE JOBS ARE DONE, 46 | # OR SUFFICIENTLY AHEAD. 47 | return 48 | 49 | cache = { 50 | 'cache.path': '/fsx-codegen/defossez/cache/interleave_stereo_nv_32k', 51 | } 52 | launcher.bind_(fsdp, cache) 53 | 54 | launcher.slurm_(gpus=32).bind_(label='32gpus') 55 | with launcher.job_array(): 56 | sub = launcher.bind() 57 | sub() 58 | 59 | launcher.slurm_(gpus=64).bind_(label='64gpus') 60 | with launcher.job_array(): 61 | sub = launcher.bind() 62 | sub(medium, adam) 63 | 64 | launcher.slurm_(gpus=96).bind_(label='96gpus') 65 | with launcher.job_array(): 66 | sub = launcher.bind() 67 | sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3}) 68 | -------------------------------------------------------------------------------- /audiocraft/grids/audiogen/audiogen_pretrained_16khz_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Evaluation with objective metrics for the pretrained AudioGen models. 9 | This grid takes signature from the training grid and runs evaluation-only stage. 10 | 11 | When running the grid for the first time, please use: 12 | REGEN=1 dora grid audiogen.audiogen_pretrained_16khz_eval 13 | and re-use the REGEN=1 option when the grid is changed to force regenerating it. 14 | 15 | Note that you need the proper metrics external libraries setup to use all 16 | the objective metrics activated in this grid. Refer to the README for more information. 17 | """ 18 | 19 | import os 20 | 21 | from ..musicgen._explorers import GenerationEvalExplorer 22 | from ...environment import AudioCraftEnvironment 23 | from ... 
import train 24 | 25 | 26 | def eval(launcher, batch_size: int = 32): 27 | opts = { 28 | 'dset': 'audio/audiocaps_16khz', 29 | 'solver/audiogen/evaluation': 'objective_eval', 30 | 'execute_only': 'evaluate', 31 | '+dataset.evaluate.batch_size': batch_size, 32 | '+metrics.fad.tf.batch_size': 32, 33 | } 34 | # binary for FAD computation: replace this path with your own path 35 | metrics_opts = { 36 | 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research' 37 | } 38 | opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.} 39 | opt2 = {'transformer_lm.two_step_cfg': True} 40 | 41 | sub = launcher.bind(opts) 42 | sub.bind_(metrics_opts) 43 | 44 | # base objective metrics 45 | sub(opt1, opt2) 46 | 47 | 48 | @GenerationEvalExplorer 49 | def explorer(launcher): 50 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) 51 | launcher.slurm_(gpus=4, partition=partitions) 52 | 53 | if 'REGEN' not in os.environ: 54 | folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1] 55 | with launcher.job_array(): 56 | for sig in folder.iterdir(): 57 | if not sig.is_symlink(): 58 | continue 59 | xp = train.main.get_xp_from_sig(sig.name) 60 | launcher(xp.argv) 61 | return 62 | 63 | audiogen_base = launcher.bind(solver="audiogen/audiogen_base_16khz") 64 | audiogen_base.bind_({'autocast': False, 'fsdp.use': True}) 65 | 66 | audiogen_base_medium = audiogen_base.bind({'continue_from': '//pretrained/facebook/audiogen-medium'}) 67 | audiogen_base_medium.bind_({'model/lm/model_scale': 'medium'}) 68 | eval(audiogen_base_medium, batch_size=128) 69 | -------------------------------------------------------------------------------- /audiocraft/grids/_base_explorers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from abc import ABC, abstractmethod 8 | import time 9 | import typing as tp 10 | from dora import Explorer 11 | import treetable as tt 12 | 13 | 14 | def get_sheep_ping(sheep) -> tp.Optional[str]: 15 | """Return the amount of time since the Sheep made some update 16 | to its log. Returns a str using the relevant time unit.""" 17 | ping = None 18 | if sheep.log is not None and sheep.log.exists(): 19 | delta = time.time() - sheep.log.stat().st_mtime 20 | if delta > 3600 * 24: 21 | ping = f'{delta / (3600 * 24):.1f}d' 22 | elif delta > 3600: 23 | ping = f'{delta / (3600):.1f}h' 24 | elif delta > 60: 25 | ping = f'{delta / 60:.1f}m' 26 | else: 27 | ping = f'{delta:.1f}s' 28 | return ping 29 | 30 | 31 | class BaseExplorer(ABC, Explorer): 32 | """Base explorer for AudioCraft grids. 33 | 34 | All task specific solvers are expected to implement the `get_grid_metrics` 35 | method to specify logic about metrics to display for a given task. 36 | 37 | If additional stages are used, the child explorer must define how to handle 38 | these new stages in the `process_history` and `process_sheep` methods. 39 | """ 40 | def stages(self): 41 | return ["train", "valid", "evaluate"] 42 | 43 | def get_grid_meta(self): 44 | """Returns the list of Meta information to display for each XP/job. 
45 | """ 46 | return [ 47 | tt.leaf("index", align=">"), 48 | tt.leaf("name", wrap=140), 49 | tt.leaf("state"), 50 | tt.leaf("sig", align=">"), 51 | tt.leaf("sid", align="<"), 52 | ] 53 | 54 | @abstractmethod 55 | def get_grid_metrics(self): 56 | """Return the metrics that should be displayed in the tracking table. 57 | """ 58 | ... 59 | 60 | def process_sheep(self, sheep, history): 61 | train = { 62 | "epoch": len(history), 63 | } 64 | parts = {"train": train} 65 | for metrics in history: 66 | for key, sub in metrics.items(): 67 | part = parts.get(key, {}) 68 | if 'duration' in sub: 69 | # Convert to minutes for readability. 70 | sub['duration'] = sub['duration'] / 60. 71 | part.update(sub) 72 | parts[key] = part 73 | ping = get_sheep_ping(sheep) 74 | if ping is not None: 75 | for name in self.stages(): 76 | if name not in parts: 77 | parts[name] = {} 78 | # Add the ping to each part for convenience. 79 | parts[name]['ping'] = ping 80 | return parts 81 | -------------------------------------------------------------------------------- /config/solver/musicgen/default.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - /solver/default 5 | - /conditioner: none 6 | - _self_ 7 | - /solver/musicgen/evaluation: none 8 | - override /dset: audio/default 9 | 10 | autocast: true 11 | autocast_dtype: float16 12 | 13 | solver: musicgen 14 | sample_rate: ??? 15 | channels: ??? 16 | compression_model_checkpoint: ??? 17 | 18 | tokens: 19 | padding_with_special_token: false 20 | 21 | cache: 22 | path: 23 | write: false 24 | write_shard: 0 25 | write_num_shards: 1 26 | 27 | 28 | dataset: 29 | batch_size: 128 30 | num_workers: 10 31 | segment_duration: 30 32 | min_segment_ratio: 0.8 # lower values such as 0.5 result in generations with a lot of silence. 
33 | return_info: true 34 | train: 35 | num_samples: 1000000 # need a randomly large number here for AudioDataset 36 | valid: 37 | num_samples: 10000 38 | generate: 39 | num_samples: 50 40 | 41 | metrics: 42 | fad: 43 | use_gt: false 44 | model: tf 45 | tf: 46 | bin: null # path to local frechet_audio_distance code 47 | model_path: //reference/fad/vggish_model.ckpt 48 | kld: 49 | use_gt: false 50 | model: passt 51 | passt: 52 | pretrained_length: 20 53 | text_consistency: 54 | use_gt: false 55 | model: clap 56 | clap: 57 | model_path: //reference/clap/music_audioset_epoch_15_esc_90.14.pt 58 | model_arch: 'HTSAT-base' 59 | enable_fusion: false 60 | chroma_cosine: 61 | use_gt: false 62 | model: chroma_base 63 | chroma_base: 64 | sample_rate: ${sample_rate} 65 | n_chroma: 12 66 | radix2_exp: 14 67 | argmax: true 68 | 69 | generate: 70 | every: 25 71 | num_workers: 5 72 | path: samples 73 | audio: 74 | format: wav 75 | strategy: loudness 76 | sample_rate: ${sample_rate} 77 | loudness_headroom_db: 14 78 | lm: 79 | prompted_samples: true 80 | unprompted_samples: true 81 | gen_gt_samples: false 82 | prompt_duration: null # if not set, will use dataset.generate.segment_duration / 4 83 | gen_duration: null # if not set, will use dataset.generate.segment_duration 84 | remove_prompts: false 85 | # generation params 86 | use_sampling: false 87 | temp: 1.0 88 | top_k: 0 89 | top_p: 0.0 90 | evaluate: 91 | every: 25 92 | num_workers: 5 93 | metrics: 94 | base: false 95 | fad: false 96 | kld: false 97 | text_consistency: false 98 | chroma_cosine: false 99 | 100 | checkpoint: 101 | save_last: true 102 | save_every: 50 103 | keep_last: 10 104 | keep_every_states: null 105 | 106 | optim: 107 | epochs: 200 108 | updates_per_epoch: 2000 109 | lr: 1e-4 110 | optimizer: adamw 111 | max_norm: 1.0 112 | eager_sync: true 113 | adam: 114 | betas: [0.9, 0.95] 115 | weight_decay: 0.1 116 | eps: 1e-8 117 | 118 | schedule: 119 | lr_scheduler: null 120 | -------------------------------------------------------------------------------- /audiocraft/utils/export.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Utility to export a training checkpoint to a lightweight release checkpoint. 9 | """ 10 | 11 | from pathlib import Path 12 | import typing as tp 13 | 14 | from omegaconf import OmegaConf 15 | import torch 16 | 17 | from audiocraft import __version__ 18 | 19 | 20 | def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): 21 | """Export only the best state from the given EnCodec checkpoint. This 22 | should be used if you trained your own EnCodec model. 23 | """ 24 | pkg = torch.load(checkpoint_path, 'cpu') 25 | new_pkg = { 26 | 'best_state': pkg['best_state']['model'], 27 | 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), 28 | 'version': __version__, 29 | 'exported': True, 30 | } 31 | Path(out_file).parent.mkdir(exist_ok=True, parents=True) 32 | torch.save(new_pkg, out_file) 33 | return out_file 34 | 35 | 36 | def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]): 37 | """Export a compression model (potentially EnCodec) from a pretrained model. 38 | This is required for packaging the audio tokenizer along a MusicGen or AudioGen model. 39 | Do not include the //pretrained/ prefix. 
For instance if you trained a model 40 | with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`. 41 | 42 | In that case, this will not actually include a copy of the model, simply the reference 43 | to the model used. 44 | """ 45 | if Path(pretrained_encodec).exists(): 46 | pkg = torch.load(pretrained_encodec) 47 | assert 'best_state' in pkg 48 | assert 'xp.cfg' in pkg 49 | assert 'version' in pkg 50 | assert 'exported' in pkg 51 | else: 52 | pkg = { 53 | 'pretrained': pretrained_encodec, 54 | 'exported': True, 55 | 'version': __version__, 56 | } 57 | Path(out_file).parent.mkdir(exist_ok=True, parents=True) 58 | torch.save(pkg, out_file) 59 | 60 | 61 | def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): 62 | """Export only the best state from the given MusicGen or AudioGen checkpoint. 63 | """ 64 | pkg = torch.load(checkpoint_path, 'cpu') 65 | if pkg['fsdp_best_state']: 66 | best_state = pkg['fsdp_best_state']['model'] 67 | else: 68 | assert pkg['best_state'] 69 | best_state = pkg['best_state']['model'] 70 | new_pkg = { 71 | 'best_state': best_state, 72 | 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), 73 | 'version': __version__, 74 | 'exported': True, 75 | } 76 | 77 | Path(out_file).parent.mkdir(exist_ok=True, parents=True) 78 | torch.save(new_pkg, out_file) 79 | return out_file 80 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | # WARNING: This is the base configuration file shared across ALL solvers in AudioCraft 2 | # Please don't update this file directly. Instead use distinct configuration files 3 | # to override the below configuration. 4 | defaults: 5 | - _self_ 6 | - dset: default 7 | - solver: default 8 | 9 | device: cuda 10 | dtype: float32 11 | autocast: false 12 | autocast_dtype: bfloat16 13 | seed: 2036 14 | show: false # just show the model and its size and exit 15 | continue_from: # continue from a given sig or path 16 | execute_only: # can be set to generate/evaluate/valid to run that stage 17 | execute_inplace: false # don't enforce continue_from to be set 18 | # to enable inplace execution of the stage. This assume 19 | # that you know what you are doing and execute stage 20 | # preserving the original xp sig. 21 | benchmark_no_load: false # if set to true, will repeat the same batch instead of loading them 22 | 23 | efficient_attention_backend: torch # can be torch or xformers. 24 | num_threads: 1 # called with torch.set_num_thread. 25 | mp_start_method: forkserver # multiprocessing method (spawn, fork or fork_server). 26 | 27 | 28 | label: # use this if you want twice the same exp, with a name. 29 | 30 | # logging parameters 31 | logging: 32 | level: INFO 33 | log_updates: 10 34 | log_tensorboard: false 35 | log_wandb: false 36 | tensorboard: 37 | with_media_logging: false 38 | name: # optional name for the experiment 39 | sub_dir: # optional sub directory to store tensorboard data 40 | wandb: 41 | with_media_logging: true 42 | project: # project name 43 | name: # optional name for the experiment 44 | group: # optional group 45 | 46 | # SLURM launcher configuration. 47 | slurm: 48 | gpus: 4 # convenience parameter, number of GPUs to use. 49 | mem_per_gpu: 40 # in GB, total mem is automatically scaled with `gpus`. 
50 | time: 3600 51 | constraint: 52 | partition: 53 | comment: 54 | setup: [] 55 | exclude: '' 56 | 57 | # dora parameters 58 | dora: 59 | # Output folder for all artifacts of an experiment. 60 | dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs 61 | # The following entries will be ignored by dora when computing the unique XP signature. 62 | # Note that slurm.* and dora.* are automatically ignored. 63 | exclude: [ 64 | 'device', 'wandb.*', 'tensorboard.*', 'logging.*', 65 | 'dataset.num_workers', 'eval.num_workers', 'special.*', 66 | 'metrics.visqol.bin', 'metrics.fad.bin', 67 | 'execute_only', 'execute_best', 'generate.every', 68 | 'optim.eager_sync', 'profiler.*', 'deadlock.*', 69 | 'efficient_attention_backend', 'num_threads', 'mp_start_method', 70 | ] 71 | use_rendezvous: false 72 | # for grids, always run from a clean repo, allowing reliable runs and storing 73 | # the exact commit. Your repo must be absolutely pristine clean. 74 | # Local `dora run` are not impacted for easier debugging. 75 | git_save: true 76 | -------------------------------------------------------------------------------- /config/solver/default.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | # WARNING: This is a base configuration file shared across ALL solvers in AudioCraft 4 | # Please don't update this file directly. Instead use distinct configuration files 5 | # to override the below configuration. 6 | solver: ??? 7 | 8 | fsdp: 9 | use: false # should we use FSDP. 10 | param_dtype: float16 # equivalent to autocast_dtype for FSDP. 11 | reduce_dtype: float32 # gradient averaging dtype, float32 will give max stability. 12 | buffer_dtype: float32 # dtype used for buffers, we don't have much buffers, so let's leave it. 13 | sharding_strategy: shard_grad_op # can be shard_grad_op or full_shard. 14 | # full_shard will use less memory but slower ?? 15 | per_block: true # If True, uses nested FSDP. 16 | 17 | profiler: 18 | enabled: false 19 | 20 | deadlock: 21 | use: false 22 | timeout: 600 23 | 24 | dataset: 25 | batch_size: ??? 26 | num_workers: 10 27 | segment_duration: null 28 | num_samples: null 29 | return_info: false 30 | shuffle: false 31 | sample_on_duration: true 32 | sample_on_weight: true 33 | min_segment_ratio: 0.5 34 | train: 35 | num_samples: null 36 | shuffle: true 37 | shuffle_seed: 0 # if you want to sample the data differently. 38 | permutation_on_files: false 39 | valid: 40 | num_samples: null 41 | evaluate: 42 | num_samples: null 43 | generate: 44 | num_samples: null 45 | return_info: true 46 | 47 | checkpoint: 48 | save_last: true 49 | save_every: null 50 | keep_last: null 51 | keep_every_states: null 52 | 53 | generate: 54 | every: null 55 | path: 'samples' 56 | audio: 57 | format: 'mp3' 58 | strategy: 'clip' 59 | sample_rate: null 60 | lm: 61 | use_sampling: false 62 | temp: 1.0 63 | top_k: 0 64 | top_p: 0.0 65 | evaluate: 66 | every: null 67 | num_workers: 5 68 | truncate_audio: null 69 | fixed_generation_duration: null # in secs 70 | metrics: 71 | base: true # run default evaluation (e.g. like train/valid stage) 72 | 73 | optim: 74 | epochs: ??? 75 | updates_per_epoch: null 76 | lr: ??? 77 | optimizer: ??? 78 | adam: 79 | betas: [0.9, 0.999] 80 | weight_decay: 0. 
81 | ema: 82 | use: false # whether to use EMA or not 83 | updates: ${optim.updates_per_epoch} # frequency of updates of the EMA 84 | device: cpu # device for EMA, can be put on GPU if more frequent updates 85 | decay: 0.99 # EMA decay value, if null, no EMA is used 86 | 87 | schedule: 88 | lr_scheduler: null 89 | step: 90 | step_size: null 91 | gamma: null 92 | exponential: 93 | lr_decay: null 94 | cosine: 95 | warmup: null 96 | lr_min_ratio: 0.0 97 | cycle_length: 1.0 98 | polynomial_decay: 99 | warmup: null 100 | zero_lr_warmup_steps: 0 101 | end_lr: 0.0 102 | power: 1 103 | inverse_sqrt: 104 | warmup: null 105 | warmup_init_lr: 0.0 106 | linear_warmup: 107 | warmup: null 108 | warmup_init_lr: 0.0 109 | -------------------------------------------------------------------------------- /audiocraft/modules/chroma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import typing as tp 7 | 8 | from einops import rearrange 9 | from librosa import filters 10 | import torch 11 | from torch import nn 12 | import torch.nn.functional as F 13 | import torchaudio 14 | 15 | 16 | class ChromaExtractor(nn.Module): 17 | """Chroma extraction and quantization. 18 | 19 | Args: 20 | sample_rate (int): Sample rate for the chroma extraction. 21 | n_chroma (int): Number of chroma bins for the chroma extraction. 22 | radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12). 23 | nfft (int, optional): Number of FFT. 24 | winlen (int, optional): Window length. 25 | winhop (int, optional): Window hop size. 26 | argmax (bool, optional): Whether to use argmax. Defaults to False. 27 | norm (float, optional): Norm for chroma normalization. Defaults to inf. 
28 | """ 29 | def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None, 30 | winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False, 31 | norm: float = torch.inf): 32 | super().__init__() 33 | self.winlen = winlen or 2 ** radix2_exp 34 | self.nfft = nfft or self.winlen 35 | self.winhop = winhop or (self.winlen // 4) 36 | self.sample_rate = sample_rate 37 | self.n_chroma = n_chroma 38 | self.norm = norm 39 | self.argmax = argmax 40 | self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, 41 | n_chroma=self.n_chroma)), persistent=False) 42 | self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen, 43 | hop_length=self.winhop, power=2, center=True, 44 | pad=0, normalized=True) 45 | 46 | def forward(self, wav: torch.Tensor) -> torch.Tensor: 47 | T = wav.shape[-1] 48 | # in case we are getting a wav that was dropped out (nullified) 49 | # from the conditioner, make sure wav length is no less that nfft 50 | if T < self.nfft: 51 | pad = self.nfft - T 52 | r = 0 if pad % 2 == 0 else 1 53 | wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0) 54 | assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}" 55 | 56 | spec = self.spec(wav).squeeze(1) 57 | raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec) 58 | norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6) 59 | norm_chroma = rearrange(norm_chroma, 'b d t -> b t d') 60 | 61 | if self.argmax: 62 | idx = norm_chroma.argmax(-1, keepdim=True) 63 | norm_chroma[:] = 0 64 | norm_chroma.scatter_(dim=-1, index=idx, value=1) 65 | 66 | return norm_chroma 67 | -------------------------------------------------------------------------------- /audiocraft/losses/sisnr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import typing as tp 9 | 10 | import torch 11 | from torch import nn 12 | from torch.nn import functional as F 13 | 14 | 15 | def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor: 16 | """Given input of size [*OT, T], output Tensor of size [*OT, F, K] 17 | with K the kernel size, by extracting frames with the given stride. 18 | This will pad the input so that `F = ceil(T / K)`. 19 | see https://github.com/pytorch/pytorch/issues/60466 20 | """ 21 | *shape, length = a.shape 22 | n_frames = math.ceil(length / stride) 23 | tgt_length = (n_frames - 1) * stride + kernel_size 24 | a = F.pad(a, (0, tgt_length - length)) 25 | strides = list(a.stride()) 26 | assert strides[-1] == 1, "data should be contiguous" 27 | strides = strides[:-1] + [stride, 1] 28 | return a.as_strided([*shape, n_frames, kernel_size], strides) 29 | 30 | 31 | def _center(x: torch.Tensor) -> torch.Tensor: 32 | return x - x.mean(-1, True) 33 | 34 | 35 | def _norm2(x: torch.Tensor) -> torch.Tensor: 36 | return x.pow(2).sum(-1, True) 37 | 38 | 39 | class SISNR(nn.Module): 40 | """SISNR loss. 41 | 42 | Input should be [B, C, T], output is scalar. 43 | 44 | Args: 45 | sample_rate (int): Sample rate. 46 | segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on 47 | entire audio only. 
48 |         overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
49 |         epsilon (float): Epsilon value for numerical stability.
50 |     """
51 |     def __init__(
52 |         self,
53 |         sample_rate: int = 16000,
54 |         segment: tp.Optional[float] = 20,
55 |         overlap: float = 0.5,
56 |         epsilon: float = torch.finfo(torch.float32).eps,
57 |     ):
58 |         super().__init__()
59 |         self.sample_rate = sample_rate
60 |         self.segment = segment
61 |         self.overlap = overlap
62 |         self.epsilon = epsilon
63 |
64 |     def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
65 |         B, C, T = ref_sig.shape
66 |         assert ref_sig.shape == out_sig.shape
67 |
68 |         if self.segment is None:
69 |             frame = T
70 |             stride = T
71 |         else:
72 |             frame = int(self.segment * self.sample_rate)
73 |             stride = int(frame * (1 - self.overlap))
74 |
75 |         epsilon = self.epsilon * frame # make epsilon prop to frame size.
76 |
77 |         gt = _unfold(ref_sig, frame, stride)
78 |         est = _unfold(out_sig, frame, stride)
79 |         if self.segment is None:
80 |             assert gt.shape[-1] == 1
81 |
82 |         gt = _center(gt)
83 |         est = _center(est)
84 |         dot = torch.einsum("bcft,bcft->bcf", gt, est)
85 |
86 |         proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
87 |         noise = est - proj
88 |
89 |         sisnr = 10 * (
90 |             torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
91 |         )
92 |         return -1 * sisnr[..., 0].mean()
93 |
-------------------------------------------------------------------------------- /docs/DATASETS.md: --------------------------------------------------------------------------------
1 | # AudioCraft datasets
2 |
3 | Our dataset manifest files consist of 1-json-per-line files, potentially gzipped,
4 | as `data.jsonl` or `data.jsonl.gz` files. Each JSON line contains the path to the audio
5 | file and associated metadata. The manifest files are then provided in the configuration,
6 | as `datasource` sub-configuration. A datasource contains the pointers to the paths of
7 | the manifest files for each AudioCraft stage (or split) along with additional information
8 | (e.g. the maximum sample rate to use for this dataset). All the datasources are under the
9 | `dset` group config, with a dedicated configuration file for each dataset.
10 |
11 | ## Getting started
12 |
13 | ### Example
14 |
15 | See the provided example manifest for the example dataset
16 | provided under the [dataset folder](../dataset/example).
17 |
18 | The manifest files are stored in the [egs folder](../egs/example).
19 |
20 | ```shell
21 | egs/
22 |   example/data.jsonl
23 | ```
24 |
25 | A datasource is defined in the configuration folder, in the dset group config for this dataset
26 | at [config/dset/audio/example](../config/dset/audio/example.yaml):
27 |
28 | ```yaml
29 | # @package __global__
30 |
31 | datasource:
32 |   max_sample_rate: 44100
33 |   max_channels: 2
34 |
35 |   train: egs/example
36 |   valid: egs/example
37 |   evaluate: egs/example
38 |   generate: egs/example
39 | ```
40 |
41 | For a proper dataset, one should create a manifest for each of the splits and specify the correct path
42 | to the given manifest in the datasource for each split.
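
For illustration, each line of such a manifest is a single JSON object describing one audio file. Below is a minimal sketch of what one line may look like; the field names other than `path` are indicative of what `python -m audiocraft.data.audio_dataset` typically writes and should be checked against a generated manifest for the authoritative schema:

```json
{"path": "dataset/example/electro_1.mp3", "duration": 15.0, "sample_rate": 44100, "amplitude": null, "weight": null, "info_path": null}
```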
43 |
44 | Then, using a dataset through the configuration can be done by pointing to the
45 | corresponding dataset configuration:
46 | ```shell
47 | dset=<dataset_name> # should match the yaml file name
48 |
49 | # for example
50 | dset=audio/example
51 | ```
52 |
53 | ### Creating manifest files
54 |
55 | Assuming you want to create manifest files to load with AudioCraft's AudioDataset, you can use
56 | the following command to create new manifest files from a given folder containing audio files:
57 |
58 | ```shell
59 | python -m audiocraft.data.audio_dataset <path_to_audio_folder> egs/my_dataset/my_dataset_split/data.jsonl.gz
60 |
61 | # For example, to generate the manifest for dset=audio/example
62 | # note: we don't use any split and we don't compress the jsonl file for this dummy example
63 | python -m audiocraft.data.audio_dataset dataset/example egs/example/data.jsonl
64 |
65 | # More info with: python -m audiocraft.data.audio_dataset --help
66 | ```
67 |
68 | ## Additional information
69 |
70 | ### MusicDataset and metadata
71 |
72 | The MusicDataset is an AudioDataset with additional metadata. The MusicDataset expects
73 | the additional metadata to be stored in a JSON file that has the same path as the corresponding
74 | audio file, but with a `.json` extension.
75 |
76 | ### SoundDataset and metadata
77 |
78 | The SoundDataset is an AudioDataset with description metadata. Similarly to the MusicDataset,
79 | the SoundDataset expects the additional metadata to be stored in a JSON file that has the same
80 | path as the corresponding audio file, but with a `.json` extension. Additionally, the SoundDataset
81 | supports an extra parameter, `external_metadata_source`, pointing to a separate folder containing
82 | all the JSON metadata files, provided they have the same filename as the corresponding audio file.
83 |
-------------------------------------------------------------------------------- /audiocraft/grids/musicgen/_explorers.py: --------------------------------------------------------------------------------
6 | 7 | import typing as tp 8 | 9 | import treetable as tt 10 | 11 | from .._base_explorers import BaseExplorer 12 | 13 | 14 | class LMExplorer(BaseExplorer): 15 | eval_metrics: tp.List[str] = [] 16 | 17 | def stages(self) -> tp.List[str]: 18 | return ['train', 'valid'] 19 | 20 | def get_grid_metrics(self): 21 | """Return the metrics that should be displayed in the tracking table.""" 22 | return [ 23 | tt.group( 24 | 'train', 25 | [ 26 | tt.leaf('epoch'), 27 | tt.leaf('duration', '.1f'), # duration in minutes 28 | tt.leaf('ping'), 29 | tt.leaf('ce', '.4f'), # cross entropy 30 | tt.leaf("ppl", '.3f'), # perplexity 31 | ], 32 | align='>', 33 | ), 34 | tt.group( 35 | 'valid', 36 | [ 37 | tt.leaf('ce', '.4f'), 38 | tt.leaf('ppl', '.3f'), 39 | tt.leaf('best_ppl', '.3f'), 40 | ], 41 | align='>', 42 | ), 43 | ] 44 | 45 | def process_sheep(self, sheep, history): 46 | parts = super().process_sheep(sheep, history) 47 | 48 | track_by = {'ppl': 'lower'} # values should be in ['lower', 'higher'] 49 | best_metrics = {k: (1 if v == 'lower' else -1) * float('inf') for k, v in track_by.items()} 50 | 51 | def comparator(mode, a, b): 52 | return a < b if mode == 'lower' else a > b 53 | 54 | for metrics in history: 55 | for key, sub in metrics.items(): 56 | for metric in track_by: 57 | # for the validation set, keep track of best metrics (ppl in this example) 58 | # this is so we can conveniently compare metrics between runs in the grid 59 | if key == 'valid' and metric in sub and comparator( 60 | track_by[metric], sub[metric], best_metrics[metric] 61 | ): 62 | best_metrics[metric] = sub[metric] 63 | 64 | if 'valid' in parts: 65 | parts['valid'].update({f'best_{k}': v for k, v in best_metrics.items()}) 66 | return parts 67 | 68 | 69 | class GenerationEvalExplorer(BaseExplorer): 70 | eval_metrics: tp.List[str] = [] 71 | 72 | def stages(self) -> tp.List[str]: 73 | return ['evaluate'] 74 | 75 | def get_grid_metrics(self): 76 | """Return the metrics that should be displayed in the tracking table.""" 77 | return [ 78 | tt.group( 79 | 'evaluate', 80 | [ 81 | tt.leaf('epoch', '.3f'), 82 | tt.leaf('duration', '.1f'), 83 | tt.leaf('ping'), 84 | tt.leaf('ce', '.4f'), 85 | tt.leaf('ppl', '.3f'), 86 | tt.leaf('fad', '.3f'), 87 | tt.leaf('kld', '.3f'), 88 | tt.leaf('text_consistency', '.3f'), 89 | tt.leaf('chroma_cosine', '.3f'), 90 | ], 91 | align='>', 92 | ), 93 | ] 94 | -------------------------------------------------------------------------------- /audiocraft/optim/ema.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # ModelEMA implementation is taken from 8 | # https://github.com/facebookresearch/demucs 9 | 10 | from collections import defaultdict 11 | import typing as tp 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | def _get_all_non_persistent_buffers_set(module: nn.Module, root: str = "") -> set: 18 | names: set = set() 19 | for (name, sub_module) in module.named_modules(): 20 | if name == '': 21 | buffer_names = module._non_persistent_buffers_set 22 | buffer_names = {f"{root}.{buff_name}" if len(root) > 0 else buff_name 23 | for buff_name in buffer_names} 24 | names.update(buffer_names) 25 | else: 26 | sub_name = f"{root}.{name}" if len(root) > 0 else name 27 | sub_buffer_names = _get_all_non_persistent_buffers_set(sub_module, sub_name) 28 | names.update(sub_buffer_names) 29 | return names 30 | 31 | 32 | def _get_named_tensors(module: nn.Module): 33 | non_persistent_buffers_set = _get_all_non_persistent_buffers_set(module) 34 | named_buffers = [(name, buffer) for (name, buffer) in module.named_buffers() 35 | if name not in non_persistent_buffers_set] 36 | named_parameters = list(module.named_parameters()) 37 | return named_parameters + named_buffers 38 | 39 | 40 | class ModuleDictEMA: 41 | """Exponential Moving Average over a nn.ModuleDict. 42 | 43 | You can switch to the EMA weights temporarily. 44 | """ 45 | def __init__(self, module_dict: nn.ModuleDict, decay: float = 0.999, 46 | unbias: bool = True, device: tp.Union[torch.device, str] = 'cpu'): 47 | self.decay = decay 48 | self.module_dict = module_dict 49 | self.state: dict = defaultdict(dict) 50 | self.count = 0 51 | self.device = device 52 | self.unbias = unbias 53 | self._init() 54 | 55 | def _init(self): 56 | for module_name, module in self.module_dict.items(): 57 | for key, val in _get_named_tensors(module): 58 | if not val.is_floating_point(): 59 | continue 60 | device = self.device or val.device 61 | if key not in self.state[module_name]: 62 | self.state[module_name][key] = val.detach().to(device, copy=True) 63 | 64 | def step(self): 65 | if self.unbias: 66 | self.count = self.count * self.decay + 1 67 | w = 1 / self.count 68 | else: 69 | w = 1 - self.decay 70 | for module_name, module in self.module_dict.items(): 71 | for key, val in _get_named_tensors(module): 72 | if not val.is_floating_point(): 73 | continue 74 | device = self.device or val.device 75 | self.state[module_name][key].mul_(1 - w) 76 | self.state[module_name][key].add_(val.detach().to(device), alpha=w) 77 | 78 | def state_dict(self): 79 | return {'state': self.state, 'count': self.count} 80 | 81 | def load_state_dict(self, state): 82 | self.count = state['count'] 83 | for module_name, module in state['state'].items(): 84 | for key, val in module.items(): 85 | self.state[module_name][key].copy_(val) 86 | -------------------------------------------------------------------------------- /audiocraft/modules/activations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import Tensor 10 | from typing import Union, Callable 11 | 12 | 13 | class CustomGLU(nn.Module): 14 | """Custom Gated Linear Unit activation. 
15 | Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half 16 | of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation 17 | function (i.e. sigmoid, swish, etc.). 18 | 19 | Args: 20 | activation (nn.Module): The custom activation to apply in the Gated Linear Unit 21 | dim (int): the dimension on which to split the input. Default: -1 22 | 23 | Shape: 24 | - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional 25 | dimensions 26 | - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` 27 | 28 | Examples:: 29 | >>> m = CustomGLU(nn.Sigmoid()) 30 | >>> input = torch.randn(4, 2) 31 | >>> output = m(input) 32 | """ 33 | def __init__(self, activation: nn.Module, dim: int = -1): 34 | super(CustomGLU, self).__init__() 35 | self.dim = dim 36 | self.activation = activation 37 | 38 | def forward(self, x: Tensor): 39 | assert x.shape[self.dim] % 2 == 0 # M = N / 2 40 | a, b = torch.chunk(x, 2, dim=self.dim) 41 | return a * self.activation(b) 42 | 43 | 44 | class SwiGLU(CustomGLU): 45 | """SiLU Gated Linear Unit activation. 46 | Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is 47 | the first half of the input matrices, :math:`b` is the second half. 48 | 49 | Args: 50 | dim (int): the dimension on which to split the input. Default: -1 51 | """ 52 | def __init__(self, dim: int = -1): 53 | super(SwiGLU, self).__init__(nn.SiLU(), dim) 54 | 55 | 56 | class GeGLU(CustomGLU): 57 | """GeLU Gated Linear Unit activation. 58 | Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is 59 | the first half of the input matrices, :math:`b` is the second half. 60 | 61 | Args: 62 | dim (int): the dimension on which to split the input. Default: -1 63 | """ 64 | def __init__(self, dim: int = -1): 65 | super(GeGLU, self).__init__(nn.GELU(), dim) 66 | 67 | 68 | class ReGLU(CustomGLU): 69 | """ReLU Gated Linear Unit activation. 70 | Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is 71 | the first half of the input matrices, :math:`b` is the second half. 72 | 73 | Args: 74 | dim (int): the dimension on which to split the input. Default: -1 75 | """ 76 | def __init__(self, dim: int = -1): 77 | super(ReGLU, self).__init__(nn.ReLU(), dim) 78 | 79 | 80 | def get_activation_fn( 81 | activation: Union[str, Callable[[Tensor], Tensor]] 82 | ) -> Union[str, Callable[[Tensor], Tensor]]: 83 | """Helper function to map an activation string to the activation class. 84 | If the supplied activation is not a string that is recognized, the activation is passed back. 85 | 86 | Args: 87 | activation (str, or Callable[[Tensor], Tensor]): Activation to check 88 | """ 89 | if isinstance(activation, str): 90 | if activation == "reglu": 91 | return ReGLU() 92 | elif activation == "geglu": 93 | return GeGLU() 94 | elif activation == "swiglu": 95 | return SwiGLU() 96 | return activation 97 | -------------------------------------------------------------------------------- /audiocraft/quantization/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Base class for all quantizers. 
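A quantizer maps a continuous representation to discrete codes and back. As a rough
sketch of the expected round trip (using the interfaces defined below):

    result = quantizer(x, frame_rate)   # QuantizedResult with quantized x, codes, bandwidth
    codes = quantizer.encode(x)
    x_q = quantizer.decode(codes)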
9 | """ 10 | 11 | from dataclasses import dataclass, field 12 | import typing as tp 13 | 14 | import torch 15 | from torch import nn 16 | 17 | 18 | @dataclass 19 | class QuantizedResult: 20 | x: torch.Tensor 21 | codes: torch.Tensor 22 | bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. 23 | penalty: tp.Optional[torch.Tensor] = None 24 | metrics: dict = field(default_factory=dict) 25 | 26 | 27 | class BaseQuantizer(nn.Module): 28 | """Base class for quantizers. 29 | """ 30 | 31 | def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult: 32 | """ 33 | Given input tensor x, returns first the quantized (or approximately quantized) 34 | representation along with quantized codes, bandwidth, and any penalty term for the loss. 35 | Finally, this returns a dict of metrics to update logging etc. 36 | Frame rate must be passed so that the bandwidth is properly computed. 37 | """ 38 | raise NotImplementedError() 39 | 40 | def encode(self, x: torch.Tensor) -> torch.Tensor: 41 | """Encode a given input tensor with the specified sample rate at the given bandwidth.""" 42 | raise NotImplementedError() 43 | 44 | def decode(self, codes: torch.Tensor) -> torch.Tensor: 45 | """Decode the given codes to the quantized representation.""" 46 | raise NotImplementedError() 47 | 48 | @property 49 | def total_codebooks(self): 50 | """Total number of codebooks.""" 51 | raise NotImplementedError() 52 | 53 | @property 54 | def num_codebooks(self): 55 | """Number of active codebooks.""" 56 | raise NotImplementedError() 57 | 58 | def set_num_codebooks(self, n: int): 59 | """Set the number of active codebooks.""" 60 | raise NotImplementedError() 61 | 62 | 63 | class DummyQuantizer(BaseQuantizer): 64 | """Fake quantizer that actually does not perform any quantization. 65 | """ 66 | def __init__(self): 67 | super().__init__() 68 | 69 | def forward(self, x: torch.Tensor, frame_rate: int): 70 | q = x.unsqueeze(1) 71 | return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x)) 72 | 73 | def encode(self, x: torch.Tensor) -> torch.Tensor: 74 | """Encode a given input tensor with the specified sample rate at the given bandwidth. 75 | In the case of the DummyQuantizer, the codes are actually identical 76 | to the input and resulting quantized representation as no quantization is done. 77 | """ 78 | return x.unsqueeze(1) 79 | 80 | def decode(self, codes: torch.Tensor) -> torch.Tensor: 81 | """Decode the given codes to the quantized representation. 82 | In the case of the DummyQuantizer, the codes are actually identical 83 | to the input and resulting quantized representation as no quantization is done. 
84 | """ 85 | return codes.squeeze(1) 86 | 87 | @property 88 | def total_codebooks(self): 89 | """Total number of codebooks.""" 90 | return 1 91 | 92 | @property 93 | def num_codebooks(self): 94 | """Total number of codebooks.""" 95 | return self.total_codebooks 96 | 97 | def set_num_codebooks(self, n: int): 98 | """Set the number of active codebooks.""" 99 | raise AttributeError("Cannot override the number of codebooks for the dummy quantizer") 100 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /config/solver/compression/default.yaml: -------------------------------------------------------------------------------- 1 | # @package __global__ 2 | 3 | defaults: 4 | - ../default 5 | - override /dset: audio/default 6 | - _self_ 7 | 8 | solver: compression 9 | sample_rate: ??? 10 | channels: ??? 11 | 12 | # loss balancing 13 | losses: 14 | adv: 4. 15 | feat: 4. 16 | l1: 0.1 17 | mel: 0. 18 | msspec: 2. 19 | sisnr: 0. 20 | balancer: 21 | balance_grads: true 22 | ema_decay: 0.999 23 | per_batch_item: true 24 | total_norm: 1. 25 | 26 | adversarial: 27 | every: 1 28 | adversaries: [msstftd] 29 | adv_loss: hinge 30 | feat_loss: l1 31 | 32 | # losses hyperparameters 33 | l1: {} 34 | l2: {} 35 | mrstft: 36 | factor_sc: .5 37 | factor_mag: .5 38 | normalized: false 39 | mel: 40 | sample_rate: ${sample_rate} 41 | n_fft: 1024 42 | hop_length: 256 43 | win_length: 1024 44 | n_mels: 64 45 | f_min: 64 46 | f_max: null 47 | normalized: false 48 | floor_level: 1e-5 49 | sisnr: 50 | sample_rate: ${sample_rate} 51 | segment: 5. 
52 | msspec: 53 | sample_rate: ${sample_rate} 54 | range_start: 6 55 | range_end: 11 56 | n_mels: 64 57 | f_min: 64 58 | f_max: null 59 | normalized: true 60 | alphas: false 61 | floor_level: 1e-5 62 | 63 | # metrics 64 | metrics: 65 | visqol: 66 | mode: audio 67 | bin: null # path to visqol install 68 | model: tcdaudio14_aacvopus_coresv_svrnsim_n.68_g.01_c1.model # visqol v3 69 | 70 | # adversaries hyperparameters 71 | msstftd: 72 | in_channels: 1 73 | out_channels: 1 74 | filters: 32 75 | norm: weight_norm 76 | n_ffts: [1024, 2048, 512, 256, 128] 77 | hop_lengths: [256, 512, 128, 64, 32] 78 | win_lengths: [1024, 2048, 512, 256, 128] 79 | activation: LeakyReLU 80 | activation_params: {negative_slope: 0.3} 81 | msd: 82 | in_channels: 1 83 | out_channels: 1 84 | scale_norms: [spectral_norm, weight_norm, weight_norm] 85 | kernel_sizes: [5, 3] 86 | filters: 16 87 | max_filters: 1024 88 | downsample_scales: [4, 4, 4, 4] 89 | inner_kernel_sizes: null 90 | groups: [4, 4, 4, 4] 91 | strides: null 92 | paddings: null 93 | activation: LeakyReLU 94 | activation_params: {negative_slope: 0.3} 95 | mpd: 96 | in_channels: 1 97 | out_channels: 1 98 | periods: [2, 3, 5, 7, 11] 99 | n_layers: 5 100 | kernel_size: 5 101 | stride: 3 102 | filters: 8 103 | filter_scales: 4 104 | max_filters: 1024 105 | activation: LeakyReLU 106 | activation_params: {negative_slope: 0.3} 107 | norm: weight_norm 108 | 109 | # data hyperparameters 110 | dataset: 111 | batch_size: 64 112 | num_workers: 10 113 | segment_duration: 1 114 | train: 115 | num_samples: 500000 116 | valid: 117 | num_samples: 10000 118 | evaluate: 119 | batch_size: 32 120 | num_samples: 10000 121 | generate: 122 | batch_size: 32 123 | num_samples: 50 124 | segment_duration: 10 125 | 126 | # solver hyperparameters 127 | evaluate: 128 | every: 25 129 | num_workers: 5 130 | metrics: 131 | visqol: false 132 | sisnr: true 133 | generate: 134 | every: 25 135 | num_workers: 5 136 | audio: 137 | sample_rate: ${sample_rate} 138 | 139 | # checkpointing schedule 140 | checkpoint: 141 | save_last: true 142 | save_every: 25 143 | keep_last: 10 144 | keep_every_states: null 145 | 146 | # optimization hyperparameters 147 | optim: 148 | epochs: 200 149 | updates_per_epoch: 2000 150 | lr: 3e-4 151 | max_norm: 0. 152 | optimizer: adam 153 | adam: 154 | betas: [0.5, 0.9] 155 | weight_decay: 0. 156 | ema: 157 | use: true # whether to use EMA or not 158 | updates: 1 # update at every step 159 | device: ${device} # device for EMA, can be put on GPU if more frequent updates 160 | decay: 0.99 # EMA decay value, if null, no EMA is used 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AudioCraft 2 | ![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg) 3 | ![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg) 4 | ![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg) 5 | 6 | AudioCraft is a PyTorch library for deep learning research on audio generation. AudioCraft contains inference and training code 7 | for two state-of-the-art AI generative models producing high-quality audio: AudioGen and MusicGen. 8 | 9 | 10 | ## Installation 11 | AudioCraft requires Python 3.9, PyTorch 2.0.0. 
To install AudioCraft, you can run the following: 12 | 13 | ```shell 14 | # Best to make sure you have torch installed first, in particular before installing xformers. 15 | # Don't run this if you already have PyTorch installed. 16 | pip install 'torch>=2.0' 17 | # Then proceed to one of the following 18 | pip install -U audiocraft # stable release 19 | pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge 20 | pip install -e . # or if you cloned the repo locally (mandatory if you want to train). 21 | ``` 22 | 23 | We also recommend having `ffmpeg` installed, either through your system or Anaconda: 24 | ```bash 25 | sudo apt-get install ffmpeg 26 | # Or if you are using Anaconda or Miniconda 27 | conda install 'ffmpeg<5' -c conda-forge 28 | ``` 29 | 30 | ## Models 31 | 32 | At the moment, AudioCraft contains the training code and inference code for: 33 | * [MusicGen](./docs/MUSICGEN.md): A state-of-the-art controllable text-to-music model. 34 | * [AudioGen](./docs/AUDIOGEN.md): A state-of-the-art text-to-sound model. 35 | * [EnCodec](./docs/ENCODEC.md): A state-of-the-art high-fidelity neural audio codec. 36 | * [Multi Band Diffusion](./docs/MBD.md): An EnCodec-compatible decoder using diffusion. 37 | 38 | ## Training code 39 | 40 | AudioCraft contains PyTorch components for deep learning research in audio and training pipelines for the developed models. 41 | For a general introduction to AudioCraft design principles and instructions to develop your own training pipeline, refer to 42 | the [AudioCraft training documentation](./docs/TRAINING.md). 43 | 44 | For reproducing existing work and using the developed training pipelines, refer to the instructions for each specific model, 45 | which provide pointers to configuration, example grids, and model/task-specific information and FAQ. 46 | 47 | 48 | ## API documentation 49 | 50 | We provide some [API documentation](https://facebookresearch.github.io/audiocraft/api_docs/audiocraft/index.html) for AudioCraft. 51 | 52 | 53 | ## FAQ 54 | 55 | #### Is the training code available? 56 | 57 | Yes! We provide the training code for [EnCodec](./docs/ENCODEC.md), [MusicGen](./docs/MUSICGEN.md) and [Multi Band Diffusion](./docs/MBD.md). 58 | 59 | #### Where are the models stored? 60 | 61 | Hugging Face stores the models in a specific cache location, which can be overridden by setting the `AUDIOCRAFT_CACHE_DIR` environment variable. 62 | 63 | 64 | ## License 65 | * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE). 66 | * The model weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights). 67 | 68 | 69 | ## Citation 70 | 71 | For the general framework of AudioCraft, please cite the following. 72 | ``` 73 | @article{copet2023simple, 74 | title={Simple and Controllable Music Generation}, 75 | author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez}, 76 | year={2023}, 77 | journal={arXiv preprint arXiv:2306.05284}, 78 | } 79 | ``` 80 | 81 | When referring to a specific model, please cite as mentioned in the model-specific README, e.g. 82 | [./docs/MUSICGEN.md](./docs/MUSICGEN.md), [./docs/AUDIOGEN.md](./docs/AUDIOGEN.md), etc.
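As a quick orientation to the Python API documented above, here is a minimal text-to-music sketch (the model name matches the pretrained checkpoints referenced in the docs; the prompt and output file name are illustrative):

```python
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write

model = MusicGen.get_pretrained('facebook/musicgen-small')  # downloads weights on first use
model.set_generation_params(duration=8)                     # generate 8 seconds of audio
wav = model.generate(['happy rock with a catchy guitar riff'])  # tensor of shape [B, C, T]

# Write the generated sample to disk with loudness normalization.
audio_write('musicgen_demo', wav[0].cpu(), model.sample_rate, strategy='loudness')
```

The text-to-sound AudioGen model follows the same pattern via `audiocraft.models.AudioGen`.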
83 | -------------------------------------------------------------------------------- /audiocraft/metrics/chroma_cosinesim.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torchmetrics 9 | 10 | from ..data.audio_utils import convert_audio 11 | from ..modules.chroma import ChromaExtractor 12 | 13 | 14 | class ChromaCosineSimilarityMetric(torchmetrics.Metric): 15 | """Chroma cosine similarity metric. 16 | 17 | This metric extracts a chromagram for a reference waveform and 18 | a generated waveform and compares each frame using the cosine similarity 19 | function. The output is the mean cosine similarity. 20 | 21 | Args: 22 | sample_rate (int): Sample rate used by the chroma extractor. 23 | n_chroma (int): Number of chroma used by the chroma extractor. 24 | radix2_exp (int): Exponent for the chroma extractor. 25 | argmax (bool): Whether the chroma extractor uses argmax. 26 | eps (float): Epsilon for cosine similarity computation. 27 | """ 28 | def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8): 29 | super().__init__() 30 | self.chroma_sample_rate = sample_rate 31 | self.n_chroma = n_chroma 32 | self.eps = eps 33 | self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma, 34 | radix2_exp=radix2_exp, argmax=argmax) 35 | self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum") 36 | self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum") 37 | 38 | def update(self, preds: torch.Tensor, targets: torch.Tensor, 39 | sizes: torch.Tensor, sample_rates: torch.Tensor) -> None: 40 | """Compute cosine similarity between chromagrams and accumulate scores over the dataset.""" 41 | if preds.size(0) == 0: 42 | return 43 | 44 | assert preds.shape == targets.shape, ( 45 | f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}") 46 | assert preds.size(0) == sizes.size(0), ( 47 | f"Number of items in preds ({preds.shape}) mismatch ", 48 | f"with sizes ({sizes.shape})") 49 | assert preds.size(0) == sample_rates.size(0), ( 50 | f"Number of items in preds ({preds.shape}) mismatch ", 51 | f"with sample_rates ({sample_rates.shape})") 52 | assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch" 53 | 54 | device = self.weight.device 55 | preds, targets = preds.to(device), targets.to(device) # type: ignore 56 | sample_rate = sample_rates[0].item() 57 | preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1) 58 | targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1) 59 | gt_chroma = self.chroma_extractor(targets) 60 | gen_chroma = self.chroma_extractor(preds) 61 | chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int() 62 | for i in range(len(gt_chroma)): 63 | t = int(chroma_lens[i].item()) 64 | cosine_sim = torch.nn.functional.cosine_similarity( 65 | gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps) 66 | self.cosine_sum += cosine_sim.sum(dim=0) # type: ignore 67 | self.weight += torch.tensor(t) # type: ignore 68 | 69 | def compute(self) -> float: 70 | """Computes the average cosine similarty across all generated/target 
chromagrams pairs.""" 71 | assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore 72 | return (self.cosine_sum / self.weight).item() # type: ignore 73 | -------------------------------------------------------------------------------- /audiocraft/utils/best_state.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from collections import defaultdict 8 | import logging 9 | import typing as tp 10 | 11 | import flashy 12 | import torch 13 | 14 | from ..optim import ModuleDictEMA 15 | from .utils import copy_state 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class BestStateDictManager(flashy.state.StateDictSource): 22 | """BestStateDictManager maintains a copy of best state_dict() for registered sources. 23 | 24 | BestStateDictManager has two main attributes: 25 | states (dict): State dict of the registered StateDictSource. 26 | param_ids (dict): Dict of parameter ids for registered states from ModuleDictEMA and other sources. 27 | 28 | When registering new sources, the BestStateDictManager will ensure two conflicting sources between 29 | ModuleDictEMA and original modules are not both registered as it would otherwise create ambiguity about 30 | what to consider for best state. 31 | 32 | Args: 33 | device (torch.device or str): Device on which we keep the copy. 34 | dtype (torch.dtype): Data type for the state parameters. 35 | """ 36 | def __init__(self, device: tp.Union[torch.device, str] = 'cpu', 37 | dtype: tp.Optional[torch.dtype] = None): 38 | self.device = device 39 | self.states: dict = {} 40 | self.param_ids: dict = defaultdict(dict) 41 | self.dtype = dtype 42 | 43 | def _get_parameter_ids(self, state_dict): 44 | return {id(p): name for name, p in state_dict.items() if isinstance(p, torch.Tensor)} 45 | 46 | def _validate_no_parameter_ids_overlap(self, name: str, param_ids: dict): 47 | for registered_name, registered_param_ids in self.param_ids.items(): 48 | if registered_name != name: 49 | overlap = set.intersection(registered_param_ids.keys(), param_ids.keys()) 50 | assert len(overlap) == 0, f"Found {len(overlap)} / {len(param_ids.keys())} overlapping parameters" 51 | f" in {name} and already registered {registered_name}: {' '.join(overlap)}" 52 | 53 | def update(self, name: str, source: flashy.state.StateDictSource): 54 | if name not in self.states: 55 | raise ValueError(f"{name} missing from registered states.") 56 | self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype) 57 | 58 | def register(self, name: str, source: flashy.state.StateDictSource): 59 | if name in self.states: 60 | raise ValueError(f"{name} already present in states.") 61 | # Registering parameter ids for EMA and non-EMA states allows us to check that 62 | # there is no overlap that would create ambiguity about how to handle the best state 63 | param_ids = self._get_parameter_ids(source.state_dict()) 64 | if isinstance(source, ModuleDictEMA): 65 | logger.debug(f"Registering to best state: ModuleDictEMA '{name}' with {len(param_ids)} params") 66 | self._validate_no_parameter_ids_overlap(name, param_ids) 67 | self.param_ids[name] = param_ids 68 | else: 69 | logger.debug(f"Registering to best state: StateDictSource '{name}' with {len(param_ids)} params") 70 | 
self._validate_no_parameter_ids_overlap('base', param_ids) 71 | self.param_ids['base'].update(param_ids) 72 | # Register state 73 | self.states[name] = copy_state(source.state_dict(), device=self.device, dtype=self.dtype) 74 | 75 | def state_dict(self) -> flashy.state.StateDict: 76 | return self.states 77 | 78 | def load_state_dict(self, state: flashy.state.StateDict): 79 | for name, sub_state in state.items(): 80 | for k, v in sub_state.items(): 81 | self.states[name][k].copy_(v) 82 | -------------------------------------------------------------------------------- /tests/data/test_audio_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import julius 8 | import torch 9 | import pytest 10 | 11 | from audiocraft.data.audio_utils import ( 12 | _clip_wav, 13 | convert_audio_channels, 14 | convert_audio, 15 | normalize_audio 16 | ) 17 | from ..common_utils import get_batch_white_noise 18 | 19 | 20 | class TestConvertAudioChannels: 21 | 22 | def test_convert_audio_channels_downmix(self): 23 | b, c, t = 2, 3, 100 24 | audio = get_batch_white_noise(b, c, t) 25 | mixed = convert_audio_channels(audio, channels=2) 26 | assert list(mixed.shape) == [b, 2, t] 27 | 28 | def test_convert_audio_channels_nochange(self): 29 | b, c, t = 2, 3, 100 30 | audio = get_batch_white_noise(b, c, t) 31 | mixed = convert_audio_channels(audio, channels=c) 32 | assert list(mixed.shape) == list(audio.shape) 33 | 34 | def test_convert_audio_channels_upmix(self): 35 | b, c, t = 2, 1, 100 36 | audio = get_batch_white_noise(b, c, t) 37 | mixed = convert_audio_channels(audio, channels=3) 38 | assert list(mixed.shape) == [b, 3, t] 39 | 40 | def test_convert_audio_channels_upmix_error(self): 41 | b, c, t = 2, 2, 100 42 | audio = get_batch_white_noise(b, c, t) 43 | with pytest.raises(ValueError): 44 | convert_audio_channels(audio, channels=3) 45 | 46 | 47 | class TestConvertAudio: 48 | 49 | def test_convert_audio_channels_downmix(self): 50 | b, c, dur = 2, 3, 4. 51 | sr = 128 52 | audio = get_batch_white_noise(b, c, int(sr * dur)) 53 | out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=2) 54 | assert list(out.shape) == [audio.shape[0], 2, audio.shape[-1]] 55 | 56 | def test_convert_audio_channels_upmix(self): 57 | b, c, dur = 2, 1, 4. 58 | sr = 128 59 | audio = get_batch_white_noise(b, c, int(sr * dur)) 60 | out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=3) 61 | assert list(out.shape) == [audio.shape[0], 3, audio.shape[-1]] 62 | 63 | def test_convert_audio_upsample(self): 64 | b, c, dur = 2, 1, 4. 65 | sr = 2 66 | new_sr = 3 67 | audio = get_batch_white_noise(b, c, int(sr * dur)) 68 | out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c) 69 | out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr) 70 | assert torch.allclose(out, out_j) 71 | 72 | def test_convert_audio_resample(self): 73 | b, c, dur = 2, 1, 4. 74 | sr = 3 75 | new_sr = 2 76 | audio = get_batch_white_noise(b, c, int(sr * dur)) 77 | out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c) 78 | out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr) 79 | assert torch.allclose(out, out_j) 80 | 81 | 82 | class TestNormalizeAudio: 83 | 84 | def test_clip_wav(self): 85 | b, c, dur = 2, 1, 4. 
86 | sr = 3 87 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) 88 | _clip_wav(audio) 89 | assert audio.abs().max() <= 1 90 | 91 | def test_normalize_audio_clip(self): 92 | b, c, dur = 2, 1, 4. 93 | sr = 3 94 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) 95 | norm_audio = normalize_audio(audio, strategy='clip') 96 | assert norm_audio.abs().max() <= 1 97 | 98 | def test_normalize_audio_rms(self): 99 | b, c, dur = 2, 1, 4. 100 | sr = 3 101 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) 102 | norm_audio = normalize_audio(audio, strategy='rms') 103 | assert norm_audio.abs().max() <= 1 104 | 105 | def test_normalize_audio_peak(self): 106 | b, c, dur = 2, 1, 4. 107 | sr = 3 108 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) 109 | norm_audio = normalize_audio(audio, strategy='peak') 110 | assert norm_audio.abs().max() <= 1 111 | -------------------------------------------------------------------------------- /audiocraft/grids/musicgen/musicgen_pretrained_32khz_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Evaluation with objective metrics for the pretrained MusicGen models. 9 | This grid takes signature from the training grid and runs evaluation-only stage. 10 | 11 | When running the grid for the first time, please use: 12 | REGEN=1 dora grid musicgen.musicgen_pretrained_32khz_eval 13 | and re-use the REGEN=1 option when the grid is changed to force regenerating it. 14 | 15 | Note that you need the proper metrics external libraries setup to use all 16 | the objective metrics activated in this grid. Refer to the README for more information. 17 | """ 18 | 19 | import os 20 | 21 | from ._explorers import GenerationEvalExplorer 22 | from ...environment import AudioCraftEnvironment 23 | from ... 
import train 24 | 25 | 26 | def eval(launcher, batch_size: int = 32, eval_melody: bool = False): 27 | opts = { 28 | 'dset': 'audio/musiccaps_32khz', 29 | 'solver/musicgen/evaluation': 'objective_eval', 30 | 'execute_only': 'evaluate', 31 | '+dataset.evaluate.batch_size': batch_size, 32 | '+metrics.fad.tf.batch_size': 16, 33 | } 34 | # chroma-specific evaluation 35 | chroma_opts = { 36 | 'dset': 'internal/music_400k_32khz', 37 | 'dataset.evaluate.segment_duration': 30, 38 | 'dataset.evaluate.num_samples': 1000, 39 | 'evaluate.metrics.chroma_cosine': True, 40 | 'evaluate.metrics.fad': False, 41 | 'evaluate.metrics.kld': False, 42 | 'evaluate.metrics.text_consistency': False, 43 | } 44 | # binary for FAD computation: replace this path with your own path 45 | metrics_opts = { 46 | 'metrics.fad.tf.bin': '/data/home/jadecopet/local/usr/opt/google-research' 47 | } 48 | opt1 = {'generate.lm.use_sampling': True, 'generate.lm.top_k': 250, 'generate.lm.top_p': 0.} 49 | opt2 = {'transformer_lm.two_step_cfg': True} 50 | 51 | sub = launcher.bind(opts) 52 | sub.bind_(metrics_opts) 53 | 54 | # base objective metrics 55 | sub(opt1, opt2) 56 | 57 | if eval_melody: 58 | # chroma-specific metrics 59 | sub(opt1, opt2, chroma_opts) 60 | 61 | 62 | @GenerationEvalExplorer 63 | def explorer(launcher): 64 | partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) 65 | launcher.slurm_(gpus=4, partition=partitions) 66 | 67 | if 'REGEN' not in os.environ: 68 | folder = train.main.dora.dir / 'grids' / __name__.split('.', 2)[-1] 69 | with launcher.job_array(): 70 | for sig in folder.iterdir(): 71 | if not sig.is_symlink(): 72 | continue 73 | xp = train.main.get_xp_from_sig(sig.name) 74 | launcher(xp.argv) 75 | return 76 | 77 | with launcher.job_array(): 78 | musicgen_base = launcher.bind(solver="musicgen/musicgen_base_32khz") 79 | musicgen_base.bind_({'autocast': False, 'fsdp.use': True}) 80 | 81 | # base musicgen models 82 | musicgen_base_small = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-small'}) 83 | eval(musicgen_base_small, batch_size=128) 84 | 85 | musicgen_base_medium = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-medium'}) 86 | musicgen_base_medium.bind_({'model/lm/model_scale': 'medium'}) 87 | eval(musicgen_base_medium, batch_size=128) 88 | 89 | musicgen_base_large = musicgen_base.bind({'continue_from': '//pretrained/facebook/musicgen-large'}) 90 | musicgen_base_large.bind_({'model/lm/model_scale': 'large'}) 91 | eval(musicgen_base_large, batch_size=128) 92 | 93 | # melody musicgen model 94 | musicgen_melody = launcher.bind(solver="musicgen/musicgen_melody_32khz") 95 | musicgen_melody.bind_({'autocast': False, 'fsdp.use': True}) 96 | 97 | musicgen_melody_medium = musicgen_melody.bind({'continue_from': '//pretrained/facebook/musicgen-melody'}) 98 | musicgen_melody_medium.bind_({'model/lm/model_scale': 'medium'}) 99 | eval(musicgen_melody_medium, batch_size=128, eval_melody=True) 100 | -------------------------------------------------------------------------------- /audiocraft/data/info_audio_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Base classes for the datasets that also provide non-audio metadata, 7 | e.g. description, text transcription etc. 
8 | """ 9 | from dataclasses import dataclass 10 | import logging 11 | import math 12 | import re 13 | import typing as tp 14 | 15 | import torch 16 | 17 | from .audio_dataset import AudioDataset, AudioMeta 18 | from ..environment import AudioCraftEnvironment 19 | from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def _clusterify_meta(meta: AudioMeta) -> AudioMeta: 26 | """Monkey-patch meta to match cluster specificities.""" 27 | meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path) 28 | if meta.info_path is not None: 29 | meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path) 30 | return meta 31 | 32 | 33 | def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]: 34 | """Monkey-patch all meta to match cluster specificities.""" 35 | return [_clusterify_meta(m) for m in meta] 36 | 37 | 38 | @dataclass 39 | class AudioInfo(SegmentWithAttributes): 40 | """Dummy SegmentInfo with empty attributes. 41 | 42 | The InfoAudioDataset is expected to return metadata that inherits 43 | from SegmentWithAttributes class and can return conditioning attributes. 44 | 45 | This basically guarantees all datasets will be compatible with current 46 | solver that contain conditioners requiring this. 47 | """ 48 | audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM. 49 | 50 | def to_condition_attributes(self) -> ConditioningAttributes: 51 | return ConditioningAttributes() 52 | 53 | 54 | class InfoAudioDataset(AudioDataset): 55 | """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform. 56 | 57 | See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments. 
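    Example (sketch; the keyword arguments are simply forwarded to ``AudioDataset``,
    so the exact names below are assumptions about that constructor)::

        >>> dataset = InfoAudioDataset(meta, segment_duration=30., return_info=True)
        >>> wav, info = dataset[0]  # `info` is an AudioInfo with the segment metadata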
58 | """ 59 | def __init__(self, meta: tp.List[AudioMeta], **kwargs): 60 | super().__init__(clusterify_all_meta(meta), **kwargs) 61 | 62 | def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]: 63 | if not self.return_info: 64 | wav = super().__getitem__(index) 65 | assert isinstance(wav, torch.Tensor) 66 | return wav 67 | wav, meta = super().__getitem__(index) 68 | return wav, AudioInfo(**meta.to_dict()) 69 | 70 | 71 | def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]: 72 | """Preprocess a single keyword or possible a list of keywords.""" 73 | if isinstance(value, list): 74 | return get_keyword_list(value) 75 | else: 76 | return get_keyword(value) 77 | 78 | 79 | def get_string(value: tp.Optional[str]) -> tp.Optional[str]: 80 | """Preprocess a single keyword.""" 81 | if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None': 82 | return None 83 | else: 84 | return value.strip() 85 | 86 | 87 | def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]: 88 | """Preprocess a single keyword.""" 89 | if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None': 90 | return None 91 | else: 92 | return value.strip().lower() 93 | 94 | 95 | def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]: 96 | """Preprocess a list of keywords.""" 97 | if isinstance(values, str): 98 | values = [v.strip() for v in re.split(r'[,\s]', values)] 99 | elif isinstance(values, float) and math.isnan(values): 100 | values = [] 101 | if not isinstance(values, list): 102 | logger.debug(f"Unexpected keyword list {values}") 103 | values = [str(values)] 104 | 105 | kws = [get_keyword(v) for v in values] 106 | kw_list = [k for k in kws if k is not None] 107 | if len(kw_list) == 0: 108 | return None 109 | else: 110 | return kw_list 111 | -------------------------------------------------------------------------------- /audiocraft/adversarial/discriminators/mpd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import typing as tp 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from ...modules import NormConv2d 14 | from .base import MultiDiscriminator, MultiDiscriminatorOutputType 15 | 16 | 17 | def get_padding(kernel_size: int, dilation: int = 1) -> int: 18 | return int((kernel_size * dilation - dilation) / 2) 19 | 20 | 21 | class PeriodDiscriminator(nn.Module): 22 | """Period sub-discriminator. 23 | 24 | Args: 25 | period (int): Period between samples of audio. 26 | in_channels (int): Number of input channels. 27 | out_channels (int): Number of output channels. 28 | n_layers (int): Number of convolutional layers. 29 | kernel_sizes (list of int): Kernel sizes for convolutions. 30 | stride (int): Stride for convolutions. 31 | filters (int): Initial number of filters in convolutions. 32 | filters_scale (int): Multiplier of number of filters as we increase depth. 33 | max_filters (int): Maximum number of filters. 34 | norm (str): Normalization method. 35 | activation (str): Activation function. 36 | activation_params (dict): Parameters to provide to the activation function. 
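    Example (illustrative shapes; the forward pass right-pads the input so its length
    is a multiple of ``period`` and folds it into 2D before the convolutions)::

        >>> disc = PeriodDiscriminator(period=2)
        >>> x = torch.randn(1, 1, 1001)     # length is not a multiple of the period
        >>> logits, fmaps = disc(x)         # x is padded to 1002 and viewed as (1, 1, 501, 2)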
37 | """ 38 | def __init__(self, period: int, in_channels: int = 1, out_channels: int = 1, 39 | n_layers: int = 5, kernel_sizes: tp.List[int] = [5, 3], stride: int = 3, 40 | filters: int = 8, filters_scale: int = 4, max_filters: int = 1024, 41 | norm: str = 'weight_norm', activation: str = 'LeakyReLU', 42 | activation_params: dict = {'negative_slope': 0.2}): 43 | super().__init__() 44 | self.period = period 45 | self.n_layers = n_layers 46 | self.activation = getattr(torch.nn, activation)(**activation_params) 47 | self.convs = nn.ModuleList() 48 | in_chs = in_channels 49 | for i in range(self.n_layers): 50 | out_chs = min(filters * (filters_scale ** (i + 1)), max_filters) 51 | eff_stride = 1 if i == self.n_layers - 1 else stride 52 | self.convs.append(NormConv2d(in_chs, out_chs, kernel_size=(kernel_sizes[0], 1), stride=(eff_stride, 1), 53 | padding=((kernel_sizes[0] - 1) // 2, 0), norm=norm)) 54 | in_chs = out_chs 55 | self.conv_post = NormConv2d(in_chs, out_channels, kernel_size=(kernel_sizes[1], 1), stride=1, 56 | padding=((kernel_sizes[1] - 1) // 2, 0), norm=norm) 57 | 58 | def forward(self, x: torch.Tensor): 59 | fmap = [] 60 | # 1d to 2d 61 | b, c, t = x.shape 62 | if t % self.period != 0: # pad first 63 | n_pad = self.period - (t % self.period) 64 | x = F.pad(x, (0, n_pad), 'reflect') 65 | t = t + n_pad 66 | x = x.view(b, c, t // self.period, self.period) 67 | 68 | for conv in self.convs: 69 | x = conv(x) 70 | x = self.activation(x) 71 | fmap.append(x) 72 | x = self.conv_post(x) 73 | fmap.append(x) 74 | # x = torch.flatten(x, 1, -1) 75 | 76 | return x, fmap 77 | 78 | 79 | class MultiPeriodDiscriminator(MultiDiscriminator): 80 | """Multi-Period (MPD) Discriminator. 81 | 82 | Args: 83 | in_channels (int): Number of input channels. 84 | out_channels (int): Number of output channels. 85 | periods (Sequence[int]): Periods between samples of audio for the sub-discriminators. 86 | **kwargs: Additional args for `PeriodDiscriminator` 87 | """ 88 | def __init__(self, in_channels: int = 1, out_channels: int = 1, 89 | periods: tp.Sequence[int] = [2, 3, 5, 7, 11], **kwargs): 90 | super().__init__() 91 | self.discriminators = nn.ModuleList([ 92 | PeriodDiscriminator(p, in_channels, out_channels, **kwargs) for p in periods 93 | ]) 94 | 95 | @property 96 | def num_discriminators(self): 97 | return len(self.discriminators) 98 | 99 | def forward(self, x: torch.Tensor) -> MultiDiscriminatorOutputType: 100 | logits = [] 101 | fmaps = [] 102 | for disc in self.discriminators: 103 | logit, fmap = disc(x) 104 | logits.append(logit) 105 | fmaps.append(fmap) 106 | return logits, fmaps 107 | -------------------------------------------------------------------------------- /audiocraft/metrics/clap_consistency.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
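# The CLAPTextConsistencyMetric defined below scores how well generated audio matches
# its text prompt. As a rough sketch, for audio/text pairs (a_i, t_i) the metric is the
# dataset-level average
#
#   score = (1 / N) * sum_i cos_sim(CLAP_audio(a_i), CLAP_text(t_i))
#
# computed with the audio and text towers of a pretrained CLAP model.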
6 | 7 | from pathlib import Path 8 | import typing as tp 9 | 10 | import torch 11 | import torchmetrics 12 | from transformers import RobertaTokenizer # type: ignore 13 | 14 | from ..data.audio_utils import convert_audio 15 | from ..environment import AudioCraftEnvironment 16 | from ..utils.utils import load_clap_state_dict 17 | 18 | try: 19 | import laion_clap # type: ignore 20 | except ImportError: 21 | laion_clap = None 22 | 23 | 24 | class TextConsistencyMetric(torchmetrics.Metric): 25 | """Text consistency metric measuring consistency between audio and text pairs.""" 26 | 27 | def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None: 28 | raise NotImplementedError("implement how to update the metric from the audio and text pairs.") 29 | 30 | def compute(self): 31 | raise NotImplementedError("implement how to compute the final metric score.") 32 | 33 | 34 | class CLAPTextConsistencyMetric(TextConsistencyMetric): 35 | """Text consistency metric relying on Contrastive Language-Audio Pretraining (CLAP). 36 | 37 | This metric is similar to the MuLan Cycle Consistency from MusicLM (https://arxiv.org/pdf/2301.11325.pdf) 38 | or the CLAP score used in Make-An-Audio (https://arxiv.org/pdf/2301.12661v1.pdf). 39 | 40 | As a joint audio-text embedding model, a pretrained CLAP model can be used to quantify the 41 | similarity between audio-text pairs. We compute the CLAP embeddings from the text descriptions as 42 | well as the generated audio based on them, and define the MCC metric as the average cosine similarity 43 | between these embeddings. 44 | 45 | Model implementation & pre-trained checkpoints: https://github.com/LAION-AI/CLAP 46 | """ 47 | def __init__(self, model_path: tp.Union[str, Path], model_arch: str = 'HTSAT-tiny', enable_fusion: bool = False): 48 | super().__init__() 49 | if laion_clap is None: 50 | raise ImportError("Please install CLAP to compute text consistency: 'pip install laion_clap'") 51 | self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum") 52 | self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum") 53 | self._initialize_model(model_path, model_arch, enable_fusion) 54 | 55 | def _initialize_model(self, model_path: tp.Union[str, Path], model_arch: str, enable_fusion: bool): 56 | model_path = AudioCraftEnvironment.resolve_reference_path(model_path) 57 | self.tokenize = RobertaTokenizer.from_pretrained('roberta-base') 58 | self.model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch) 59 | self.model_sample_rate = 48_000 60 | load_clap_state_dict(self.model, model_path) 61 | self.model.eval() 62 | 63 | def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: 64 | # we use the default params from CLAP module here as well 65 | return self.tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt") 66 | 67 | def update(self, audio: torch.Tensor, text: tp.List[str], sizes: torch.Tensor, sample_rates: torch.Tensor) -> None: 68 | """Compute cosine similarity between audio and text pairs and accumulate scores over the dataset.""" 69 | assert audio.size(0) == len(text), "Number of audio and text samples should match" 70 | assert torch.all(sample_rates == sample_rates[0].item()), "All items in batch should have the same sample rate" 71 | sample_rate = int(sample_rates[0].item()) 72 | # convert audio batch to 48kHz monophonic audio with no channel dimension: [B, C, T] -> [B, T] 73 | audio = convert_audio(audio, 
from_rate=sample_rate, to_rate=self.model_sample_rate, to_channels=1).mean(dim=1) 74 | audio_embeddings = self.model.get_audio_embedding_from_data(audio, use_tensor=True) 75 | text_embeddings = self.model.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True) 76 | # cosine similarity between the text and the audio embedding 77 | cosine_sim = torch.nn.functional.cosine_similarity(audio_embeddings, text_embeddings, dim=1, eps=1e-8) 78 | self.cosine_sum += cosine_sim.sum(dim=0) 79 | self.weight += torch.tensor(cosine_sim.size(0)) 80 | 81 | def compute(self): 82 | """Computes the average cosine similarity across all audio/text pairs.""" 83 | assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0" # type: ignore 84 | return (self.cosine_sum / self.weight).item() # type: ignore 85 | -------------------------------------------------------------------------------- /audiocraft/modules/streaming.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Streaming module API that should be implemented by all Streaming components. 9 | """ 10 | 11 | from contextlib import contextmanager 12 | import typing as tp 13 | from torch import nn 14 | import torch 15 | 16 | 17 | State = tp.Dict[str, torch.Tensor] 18 | 19 | 20 | class StreamingModule(nn.Module): 21 | """Common API for streaming components. 22 | 23 | Each streaming component has a streaming state, which is just a dict[str, Tensor]. 24 | By convention, the first dim of each tensor must be the batch size. 25 | Don't use dots in the key names, as this would clash with submodules 26 | (like in state_dict). 27 | 28 | If `self._is_streaming` is True, the component should use and remember 29 | the proper state inside `self._streaming_state`. 30 | 31 | To set a streaming component in streaming state, use 32 | 33 | with module.streaming(): 34 | ... 35 | 36 | This will automatically reset the streaming state when exiting the context manager. 37 | This also automatically propagates to all streaming children modules. 38 | 39 | Some modules might also implement the `StreamingModule.flush` method, although 40 | this one is trickier, as all parent modules must be StreamingModule and implement 41 | it as well for it to work properly. See `StreamingSequential` below. 42 | """ 43 | def __init__(self) -> None: 44 | super().__init__() 45 | self._streaming_state: State = {} 46 | self._is_streaming = False 47 | 48 | def _apply_named_streaming(self, fn: tp.Any): 49 | for name, module in self.named_modules(): 50 | if isinstance(module, StreamingModule): 51 | fn(name, module) 52 | 53 | def _set_streaming(self, streaming: bool): 54 | def _set_streaming(name, module): 55 | module._is_streaming = streaming 56 | self._apply_named_streaming(_set_streaming) 57 | 58 | @contextmanager 59 | def streaming(self): 60 | """Context manager to enter streaming mode.
Reset streaming state on exit.""" 61 | self._set_streaming(True) 62 | try: 63 | yield 64 | finally: 65 | self._set_streaming(False) 66 | self.reset_streaming() 67 | 68 | def reset_streaming(self): 69 | """Reset the streaming state.""" 70 | def _reset(name: str, module: StreamingModule): 71 | module._streaming_state.clear() 72 | 73 | self._apply_named_streaming(_reset) 74 | 75 | def get_streaming_state(self) -> State: 76 | """Return the streaming state, including that of sub-modules.""" 77 | state: State = {} 78 | 79 | def _add(name: str, module: StreamingModule): 80 | if name: 81 | name += "." 82 | for key, value in module._streaming_state.items(): 83 | state[name + key] = value 84 | 85 | self._apply_named_streaming(_add) 86 | return state 87 | 88 | def set_streaming_state(self, state: State): 89 | """Set the streaming state, including that of sub-modules.""" 90 | state = dict(state) 91 | 92 | def _set(name: str, module: StreamingModule): 93 | if name: 94 | name += "." 95 | module._streaming_state.clear() 96 | for key, value in list(state.items()): 97 | # complexity is not ideal here, but probably fine. 98 | if key.startswith(name): 99 | local_key = key[len(name):] 100 | if '.' not in local_key: 101 | module._streaming_state[local_key] = value 102 | del state[key] 103 | 104 | self._apply_named_streaming(_set) 105 | assert len(state) == 0, list(state.keys()) 106 | 107 | def flush(self, x: tp.Optional[torch.Tensor] = None): 108 | """Flush any remaining outputs that were waiting for completion. 109 | Typically, for convolutions, this will add the final padding 110 | and process the last buffer. 111 | 112 | This should take an optional argument `x`, which will be provided 113 | if a module before this one in the streaming pipeline has already 114 | spit out a flushed buffer. 115 | """ 116 | if x is None: 117 | return None 118 | else: 119 | return self(x) 120 | 121 | 122 | class StreamingSequential(StreamingModule, nn.Sequential): 123 | """A streaming-compatible alternative to `nn.Sequential`. 124 | """ 125 | def flush(self, x: tp.Optional[torch.Tensor] = None): 126 | for module in self: 127 | if isinstance(module, StreamingModule): 128 | x = module.flush(x) 129 | elif x is not None: 130 | x = module(x) 131 | return x 132 | --------------------------------------------------------------------------------
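# A minimal sketch (not part of the library) showing how a StreamingModule subclass
# can use `_streaming_state` together with the `streaming()` context manager defined
# above: while streaming, the module below concatenates each new chunk with a cached
# `history` tensor, so successive calls behave like a single call on the full sequence.

import torch

from audiocraft.modules.streaming import StreamingModule


class ConcatHistory(StreamingModule):
    """Toy streaming module that remembers everything it has seen so far."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._is_streaming:
            previous = self._streaming_state.get('history')
            if previous is not None:
                x = torch.cat([previous, x], dim=-1)
            # By convention the first dim of the cached tensor stays the batch size.
            self._streaming_state['history'] = x
        return x


module = ConcatHistory()
with module.streaming():
    out1 = module(torch.randn(1, 2, 4))   # shape [1, 2, 4]
    out2 = module(torch.randn(1, 2, 4))   # shape [1, 2, 8], includes the first chunk
# The streaming state is reset automatically when leaving the context manager.
assert not module.get_streaming_state()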