├── .github
│   ├── actions
│   │   └── audiocraft_build
│   │       └── action.yml
│   └── workflows
│       ├── audiocraft_docs.yml
│       ├── audiocraft_linter.yml
│       └── audiocraft_tests.yml
├── .gitignore
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── LICENSE_weights
├── MANIFEST.in
├── MODEL_CARD.md
├── Makefile
├── README.md
├── app.py
├── app_batched.py
├── assets
│   ├── bach.mp3
│   └── bolero_ravel.mp3
├── audiocraft
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── audio.py
│   │   ├── audio_dataset.py
│   │   ├── audio_utils.py
│   │   └── zip.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── builders.py
│   │   ├── encodec.py
│   │   ├── lm.py
│   │   ├── loaders.py
│   │   └── musicgen.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── activations.py
│   │   ├── codebooks_patterns.py
│   │   ├── conditioners.py
│   │   ├── conv.py
│   │   ├── lstm.py
│   │   ├── rope.py
│   │   ├── seanet.py
│   │   ├── streaming.py
│   │   └── transformer.py
│   ├── py.typed
│   ├── quantization
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── core_vq.py
│   │   └── vq.py
│   └── utils
│       ├── __init__.py
│       ├── autocast.py
│       ├── export.py
│       ├── notebook.py
│       └── utils.py
├── demo.ipynb
├── mypy.ini
├── requirements.txt
├── setup.cfg
├── setup.py
└── tests
    ├── __init__.py
    ├── common_utils
    │   ├── __init__.py
    │   ├── temp_utils.py
    │   └── wav_utils.py
    ├── data
    │   ├── __init__.py
    │   ├── test_audio.py
    │   ├── test_audio_dataset.py
    │   └── test_audio_utils.py
    ├── models
    │   ├── test_encodec_model.py
    │   └── test_musicgen.py
    ├── modules
    │   ├── __init__.py
    │   ├── test_codebooks_patterns.py
    │   ├── test_conv.py
    │   ├── test_lstm.py
    │   ├── test_rope.py
    │   ├── test_seanet.py
    │   └── test_transformer.py
    ├── quantization
    │   └── test_vq.py
    └── utils
        └── __init__.py
/.github/actions/audiocraft_build/action.yml:
--------------------------------------------------------------------------------
1 | name: audiocraft_build
2 | description: 'Build audiocraft env.'
3 | runs:
4 | using: "composite"
5 | steps:
6 | - uses: actions/setup-python@v2
7 | with:
8 | python-version: 3.8
9 | - uses: actions/cache@v2
10 | id: cache
11 | with:
12 | path: env
13 | key: audiocraft_env-${{ hashFiles('**/requirements.txt') }}
14 |
15 | - if: ${{ steps.cache.outputs.cache-hit != 'true' }}
16 | name: Install dependencies
17 | shell: bash
18 | run: |
19 | sudo apt-get update
20 | sudo apt-get install libsndfile1-dev ffmpeg
21 | python3 -m venv env
22 | . env/bin/activate
23 | python -m pip install --upgrade pip
24 | pip install -e '.[dev]'
25 | - name: System Dependencies
26 | shell: bash
27 | run: |
28 | sudo apt-get update
29 | sudo apt-get install libsndfile1-dev ffmpeg
30 |
--------------------------------------------------------------------------------
/.github/workflows/audiocraft_docs.yml:
--------------------------------------------------------------------------------
1 | name: audiocraft_docs
2 | on:
3 | push:
4 | branches: [ main ]
5 |
6 | jobs:
7 | run_docs:
8 | name: Run docs
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v2
12 | - uses: ./.github/actions/audiocraft_build
13 | - name: Config git
14 | run: |
15 | git config --global user.email "defossez@fb.com"
16 | git config --global user.name "Alexandre Défossez (autodoc)"
17 |
18 | - name: Reset branch
19 | run: |
20 | git branch -f gh-docs main
21 | git checkout gh-docs
22 |
23 | - name: Make docs
24 | run: |
25 | . env/bin/activate
26 | make docs
27 | git add -f docs
28 | git commit -m docs
29 |
30 | - name: Push branch
31 | run: |
32 | git push -f -u origin gh-docs
33 |
--------------------------------------------------------------------------------
/.github/workflows/audiocraft_linter.yml:
--------------------------------------------------------------------------------
1 | name: audiocraft_linter
2 | on:
3 | push:
4 | branches: [ main ]
5 | pull_request:
6 | branches: [ main ]
7 |
8 | jobs:
9 | run_linter:
10 | name: Run linter
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v2
14 | - uses: ./.github/actions/audiocraft_build
15 | - run: |
16 | . env/bin/activate
17 | make linter
18 |
--------------------------------------------------------------------------------
/.github/workflows/audiocraft_tests.yml:
--------------------------------------------------------------------------------
1 | name: audiocraft_tests
2 | on:
3 | push:
4 | branches: [ main ]
5 | pull_request:
6 | branches: [ main ]
7 |
8 | jobs:
9 | run_tests:
10 | name: Run tests
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v2
14 | - uses: ./.github/actions/audiocraft_build
15 | - run: |
16 | . env/bin/activate
17 | make tests
18 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # macOS dir files
10 | .DS_Store
11 |
12 | # Distribution / packaging
13 | .Python
14 | env/
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | .ipynb_checkpoints
31 |
32 | # Tests and linter
33 | .pytest_cache/
34 | .mypy_cache/
35 | .coverage
36 |
37 | # docs
38 | /docs
39 |
40 | # dotenv
41 | .env
42 | .envrc
43 |
44 | # virtualenv
45 | .venv
46 | venv/
47 | ENV/
48 |
49 | # personal notebooks & scripts
50 | */local_scripts
51 | */notes
52 | .vscode/
53 | /notebooks
54 | /local_scripts
55 | /notes
56 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
6 |
7 | ## [0.0.2a] - TBD
8 |
9 | Improved demo, fixed top p (thanks @jnordberg).
10 |
11 | Compressor tanh on output to avoid clipping with some styles (especially piano).
12 | Now repeating the conditioning periodically if it is too short.
13 |
14 | More options when launching Gradio app locally (thanks @ashleykleynhans).
15 |
16 | ## [0.0.1] - 2023-06-09
17 |
18 | Initial release, with model evaluation only.
19 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at . All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
81 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Audiocraft
2 |
3 | We want to make contributing to this project as easy and transparent as
4 | possible.
5 |
6 | ## Pull Requests
7 |
8 | Audiocraft is the implementation of a research paper.
9 | Therefore, we do not plan on accepting many pull requests for new features.
10 | We certainly welcome them for bug fixes.
11 |
12 | 1. Fork the repo and create your branch from `main`.
13 | 2. If you've added code that should be tested, add tests.
14 | 3. If you've changed APIs, update the documentation.
15 | 4. Ensure the test suite passes.
16 | 5. Make sure your code lints.
17 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
18 |
19 | ## Contributor License Agreement ("CLA")
20 | In order to accept your pull request, we need you to submit a CLA. You only need
21 | to do this once to work on any of Meta's open source projects.
22 |
23 | Complete your CLA here:
24 |
25 | ## Issues
26 | We use GitHub issues to track public bugs. Please ensure your description is
27 | clear and has sufficient instructions to be able to reproduce the issue.
28 |
29 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
30 | disclosure of security bugs. In those cases, please go through the process
31 | outlined on that page and do not file a public issue.
32 |
33 | ## License
34 | By contributing to Audiocraft, you agree that your contributions will be licensed
35 | under the LICENSE file in the root directory of this source tree.
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Meta Platforms, Inc. and affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include Makefile
2 | include LICENSE
3 | include LICENSE_weights
4 | include *.md
5 | include *.ini
6 | include requirements.txt
7 | include audiocraft/py.typed
8 | include assets/*.mp3
9 |
--------------------------------------------------------------------------------
/MODEL_CARD.md:
--------------------------------------------------------------------------------
1 | # MusicGen Model Card
2 |
3 | ## Model details
4 |
5 | **Organization developing the model:** The FAIR team of Meta AI.
6 |
7 | **Model date:** MusicGen was trained between April 2023 and May 2023.
8 |
9 | **Model version:** This is the version 1 of the model.
10 |
11 | **Model type:** MusicGen consists of an EnCodec model for audio tokenization and an auto-regressive language model based on the transformer architecture for music modeling. The model comes in different sizes: 300M, 1.5B and 3.3B parameters; and in two variants: a model trained for the text-to-music generation task and a model trained for melody-guided music generation.
12 |
13 | **Paper or resources for more information:** More information can be found in the paper [Simple and Controllable Music Generation][arxiv].
14 |
15 | **Citation details** See [our paper][arxiv]
16 |
17 | **License** Code is released under MIT, model weights are released under CC-BY-NC 4.0.
18 |
19 | **Where to send questions or comments about the model:** Questions and comments about MusicGen can be sent via the [Github repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue.
20 |
21 | ## Intended use
22 | **Primary intended use:** The primary use of MusicGen is research on AI-based music generation, including:
23 |
24 | - Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science
25 | - Generation of music guided by text or melody to understand current abilities of generative AI models by machine learning amateurs
26 |
27 | **Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateurs seeking to better understand those models.
28 |
29 | **Out-of-scope use cases** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
30 |
31 | ## Metrics
32 |
33 | **Model performance measures:** We used the following objective measures to evaluate the model on a standard music benchmark:
34 |
35 | - Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish)
36 | - Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST)
37 | - CLAP Score between audio embedding and text embedding extracted from a pre-trained CLAP model
38 |
39 | Additionally, we run qualitative studies with human participants, evaluating the performance of the model along the following axes:
40 |
41 | - Overall quality of the music samples;
42 | - Text relevance to the provided text input;
43 | - Adherence to the melody for melody-guided music generation.
44 |
45 | More details on performance measures and human studies can be found in the paper.
46 |
47 | **Decision thresholds:** Not applicable.
48 |
49 | ## Evaluation datasets
50 |
51 | The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/datasets/googleai/musiccaps) and on an in-domain held-out evaluation set, with no artist overlap with the training set.
52 |
53 | ## Training datasets
54 |
55 | The model was trained on licensed data using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound), [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing.
56 |
57 | ## Quantitative analysis
58 |
59 | More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Experimental Setup section.
60 |
61 | ## Limitations and biases
62 |
63 | **Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data; we believe that scaling the model on larger datasets can further improve its performance.
64 |
65 | **Mitigations:** Vocals have been removed from the data source using corresponding tags, and then using a state-of-the-art music source separation method, namely the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs).
66 |
67 | **Limitations:**
68 |
69 | - The model is not able to generate realistic vocals.
70 | - The model has been trained with English descriptions and will not perform as well in other languages.
71 | - The model does not perform equally well for all music styles and cultures.
72 | - The model sometimes generates end of songs, collapsing to silence.
73 | - It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results.
74 |
75 | **Biases:** The source of data is potentially lacking diversity and not all music cultures are equally represented in the dataset. The model may not perform equally well on the wide variety of music genres that exist. The generated samples from the model will reflect the biases from the training data. Further work on this model should include methods for balanced and just representations of cultures, for example, by scaling the training data to be both diverse and inclusive.
76 |
77 | **Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow broadening the application to new and more representative data.
78 |
79 | **Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks.
80 |
81 | [arxiv]: https://arxiv.org/abs/2306.05284
82 |
--------------------------------------------------------------------------------
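As an illustration of the KL-divergence metric listed in the Metrics section of the model card above, the following is a minimal sketch comparing two label distributions. It assumes pre-computed classifier outputs (random softmax vectors stand in for them here) and is not the evaluation code used for the paper.

```python
# Minimal sketch: KL divergence between label distributions, as described in the
# metrics section of MODEL_CARD.md. The real evaluation relies on a pre-trained
# classifier (PaSST); random softmax vectors stand in for its outputs here.
import torch
import torch.nn.functional as F

def kl_label_distributions(p_ref: torch.Tensor, q_gen: torch.Tensor) -> torch.Tensor:
    """KL(p_ref || q_gen) averaged over a batch of probability vectors [B, num_labels]."""
    eps = 1e-8
    p = p_ref.clamp_min(eps)
    q = q_gen.clamp_min(eps)
    return (p * (p.log() - q.log())).sum(dim=-1).mean()

p = F.softmax(torch.randn(4, 527), dim=-1)  # hypothetical reference label distributions
q = F.softmax(torch.randn(4, 527), dim=-1)  # hypothetical generated-audio label distributions
print(kl_label_distributions(p, q).item())
```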
/Makefile:
--------------------------------------------------------------------------------
1 | default: linter tests
2 |
3 | install:
4 | pip install -U pip
5 | pip install -U -e '.[dev]'
6 |
7 | linter:
8 | flake8 audiocraft && mypy audiocraft
9 | flake8 tests && mypy tests
10 |
11 | tests:
12 | coverage run -m pytest tests
13 | coverage report --include 'audiocraft/*'
14 |
15 | docs:
16 | pdoc3 --html -o docs -f audiocraft
17 |
18 | dist:
19 | python setup.py sdist
20 |
21 | .PHONY: linter tests docs dist
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Audiocraft
2 | 
3 | 
4 | 
5 |
6 | Audiocraft is a PyTorch library for deep learning research on audio generation. At the moment, it contains the code for MusicGen, a state-of-the-art controllable text-to-music model.
7 |
8 | ## MusicGen
9 |
10 | Audiocraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv]. MusicGen is a single stage auto-regressive
11 | Transformer model trained over a 32kHz EnCodec tokenizer with 4 codebooks sampled at 50 Hz. Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't require a self-supervised semantic representation, and it generates
12 | all 4 codebooks in one pass. By introducing a small delay between the codebooks, we show we can predict
13 | them in parallel, thus having only 50 auto-regressive steps per second of audio.
14 | Check out our [sample page][musicgen_samples] or test the available demo!
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data.
25 |
26 | ## Installation
27 | Audiocraft requires Python 3.9, PyTorch 2.0.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following:
28 |
29 | ```shell
30 | # Best to make sure you have torch installed first, in particular before installing xformers.
31 | # Don't run this if you already have PyTorch installed.
32 | pip install 'torch>=2.0'
33 | # Then proceed to one of the following
34 | pip install -U audiocraft # stable release
35 | pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
36 | pip install -e . # or if you cloned the repo locally
37 | ```
38 |
39 | ## Usage
40 | We offer a number of ways to interact with MusicGen:
41 | 1. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally, or use the provided [colab notebook](https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing).
42 | 2. You can use the Gradio demo locally by running `python app.py`.
43 | 3. A demo is also available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support).
44 | 4. Finally, you can run the [Gradio demo with a Colab GPU](https://colab.research.google.com/drive/1-Xe9NCdIs2sCUbiSmwHXozK6AAhMm7_i?usp=sharing),
45 | as adapted from [@camenduru Colab](https://github.com/camenduru/MusicGen-colab).
46 |
47 | ## API
48 |
49 | We provide a simple API and 4 pre-trained models. The pre-trained models are:
50 | - `small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small)
51 | - `medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium)
52 | - `melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody)
53 | - `large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large)
54 |
55 | We observe the best trade-off between quality and compute with the `medium` or `melody` model.
56 | In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
57 | GPUs will be able to generate short sequences, or longer sequences with the `small` model.
58 |
59 | **Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using a newer version of `torchaudio`.
60 | You can install it with:
61 | ```
62 | apt-get install ffmpeg
63 | ```
64 |
65 | Below is a quick example of using the API.
66 |
67 | ```python
68 | import torchaudio
69 | from audiocraft.models import MusicGen
70 | from audiocraft.data.audio import audio_write
71 |
72 | model = MusicGen.get_pretrained('melody')
73 | model.set_generation_params(duration=8) # generate 8 seconds.
74 | wav = model.generate_unconditional(4) # generates 4 unconditional audio samples
75 | descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
76 | wav = model.generate(descriptions) # generates 3 samples.
77 |
78 | melody, sr = torchaudio.load('./assets/bach.mp3')
79 | # generates using the melody from the given audio and the provided descriptions.
80 | wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr)
81 |
82 | for idx, one_wav in enumerate(wav):
83 | # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
84 | audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
85 | ```
86 |
87 |
88 | ## Model Card
89 |
90 | See [the model card page](./MODEL_CARD.md).
91 |
92 | ## FAQ
93 |
94 | #### Will the training code be released?
95 |
96 | Yes. We will soon release the training code for MusicGen and EnCodec.
97 |
98 |
99 | #### I need help on Windows
100 |
101 | @FurkanGozukara made a complete tutorial for [Audiocraft/MusicGen on Windows](https://youtu.be/v-YpvPkhdO4)
102 |
103 |
104 | ## Citation
105 | ```
106 | @article{copet2023simple,
107 | title={Simple and Controllable Music Generation},
108 | author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
109 | year={2023},
110 | journal={arXiv preprint arXiv:2306.05284},
111 | }
112 | ```
113 |
114 | ## License
115 | * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
116 | * The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
117 |
118 | [arxiv]: https://arxiv.org/abs/2306.05284
119 | [musicgen_samples]: https://ai.honu.io/papers/musicgen/
120 |
--------------------------------------------------------------------------------
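The "small delay between the codebooks" mentioned in the README's MusicGen section can be pictured with a few lines of tensor shuffling. This is a simplified sketch of the idea only; the actual pattern logic lives in `audiocraft/modules/codebooks_patterns.py`.

```python
# Sketch of the codebook delay idea: codebook k is shifted right by k steps,
# so one model step can predict all 4 codebooks in parallel (~50 auto-regressive
# steps per second of audio at the 50 Hz frame rate). Illustration, not library code.
import torch

n_q, T, pad = 4, 10, -1                        # 4 codebooks, 10 frames, pad placeholder
codes = torch.arange(n_q * T).reshape(n_q, T)  # stand-in for EnCodec tokens [n_q, T]
delayed = torch.full((n_q, T + n_q - 1), pad)
for k in range(n_q):
    delayed[k, k:k + T] = codes[k]             # apply a delay of k frames to codebook k
print(delayed)
```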
/app.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Meta Platforms, Inc. and affiliates.
3 | All rights reserved.
4 |
5 | This source code is licensed under the license found in the
6 | LICENSE file in the root directory of this source tree.
7 | """
8 |
9 | from tempfile import NamedTemporaryFile
10 | import argparse
11 | import torch
12 | import gradio as gr
13 | import os
14 | from audiocraft.models import MusicGen
15 | from audiocraft.data.audio import audio_write
16 |
17 | MODEL = None
18 | IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '')
19 |
20 |
21 | def load_model(version):
22 | print("Loading model", version)
23 | return MusicGen.get_pretrained(version)
24 |
25 |
26 | def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
27 | global MODEL
28 | topk = int(topk)
29 | if MODEL is None or MODEL.name != model:
30 | MODEL = load_model(model)
31 |
32 | if duration > MODEL.lm.cfg.dataset.segment_duration:
33 | raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
34 | MODEL.set_generation_params(
35 | use_sampling=True,
36 | top_k=topk,
37 | top_p=topp,
38 | temperature=temperature,
39 | cfg_coef=cfg_coef,
40 | duration=duration,
41 | )
42 |
43 | if melody:
44 | sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
45 | print(melody.shape)
46 | if melody.dim() == 2:
47 | melody = melody[None]
48 | melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
49 | output = MODEL.generate_with_chroma(
50 | descriptions=[text],
51 | melody_wavs=melody,
52 | melody_sample_rate=sr,
53 | progress=False
54 | )
55 | else:
56 | output = MODEL.generate(descriptions=[text], progress=False)
57 |
58 | output = output.detach().cpu().float()[0]
59 | with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
60 | audio_write(
61 | file.name, output, MODEL.sample_rate, strategy="loudness",
62 | loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
63 | waveform_video = gr.make_waveform(file.name)
64 | return waveform_video
65 |
66 |
67 | def ui(**kwargs):
68 | with gr.Blocks() as interface:
69 | gr.Markdown(
70 | """
71 | # MusicGen
72 | This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
73 | presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
74 | """
75 | )
76 | if IS_SHARED_SPACE:
77 | gr.Markdown("""
78 | ⚠ This Space doesn't work in this shared UI ⚠
79 |
80 |
81 |
82 | to use it privately, or use the public demo
83 | """)
84 | with gr.Row():
85 | with gr.Column():
86 | with gr.Row():
87 | text = gr.Text(label="Input Text", interactive=True)
88 | melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True)
89 | with gr.Row():
90 | submit = gr.Button("Submit")
91 | with gr.Row():
92 | model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
93 | with gr.Row():
94 | duration = gr.Slider(minimum=1, maximum=30, value=10, label="Duration", interactive=True)
95 | with gr.Row():
96 | topk = gr.Number(label="Top-k", value=250, interactive=True)
97 | topp = gr.Number(label="Top-p", value=0, interactive=True)
98 | temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
99 | cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
100 | with gr.Column():
101 | output = gr.Video(label="Generated Music")
102 | submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
103 | gr.Examples(
104 | fn=predict,
105 | examples=[
106 | [
107 | "An 80s driving pop song with heavy drums and synth pads in the background",
108 | "./assets/bach.mp3",
109 | "melody"
110 | ],
111 | [
112 | "A cheerful country song with acoustic guitars",
113 | "./assets/bolero_ravel.mp3",
114 | "melody"
115 | ],
116 | [
117 | "90s rock song with electric guitar and heavy drums",
118 | None,
119 | "medium"
120 | ],
121 | [
122 | "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
123 | "./assets/bach.mp3",
124 | "melody"
125 | ],
126 | [
127 | "lofi slow bpm electro chill with organic samples",
128 | None,
129 | "medium",
130 | ],
131 | ],
132 | inputs=[text, melody, model],
133 | outputs=[output]
134 | )
135 | gr.Markdown(
136 | """
137 | ### More details
138 |
139 | The model will generate a short music extract based on the description you provided.
140 | You can generate up to 30 seconds of audio.
141 |
142 | We present 4 model variations:
143 | 1. Melody -- a music generation model capable of generating music conditioned on text and melody inputs. **Note**, you can also use text only.
144 | 2. Small -- a 300M transformer decoder conditioned on text only.
145 | 3. Medium -- a 1.5B transformer decoder conditioned on text only.
146 | 4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
147 |
148 | When using `melody`, you can optionally provide a reference audio from
149 | which a broad melody will be extracted. The model will then try to follow both the description and melody provided.
150 |
151 | You can also use your own GPU or a Google Colab by following the instructions on our repo.
152 | See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
153 | for more details.
154 | """
155 | )
156 |
157 | # Show the interface
158 | launch_kwargs = {}
159 | username = kwargs.get('username')
160 | password = kwargs.get('password')
161 | server_port = kwargs.get('server_port', 0)
162 | inbrowser = kwargs.get('inbrowser', False)
163 | share = kwargs.get('share', False)
164 | server_name = kwargs.get('listen')
165 |
166 | launch_kwargs['server_name'] = server_name
167 |
168 | if username and password:
169 | launch_kwargs['auth'] = (username, password)
170 | if server_port > 0:
171 | launch_kwargs['server_port'] = server_port
172 | if inbrowser:
173 | launch_kwargs['inbrowser'] = inbrowser
174 | if share:
175 | launch_kwargs['share'] = share
176 |
177 | interface.queue().launch(**launch_kwargs, max_threads=1)
178 |
179 |
180 | if __name__ == "__main__":
181 | parser = argparse.ArgumentParser()
182 | parser.add_argument(
183 | '--listen',
184 | type=str,
185 | default='127.0.0.1',
186 | help='IP to listen on for connections to Gradio',
187 | )
188 | parser.add_argument(
189 | '--username', type=str, default='', help='Username for authentication'
190 | )
191 | parser.add_argument(
192 | '--password', type=str, default='', help='Password for authentication'
193 | )
194 | parser.add_argument(
195 | '--server_port',
196 | type=int,
197 | default=0,
198 | help='Port to run the server listener on',
199 | )
200 | parser.add_argument(
201 | '--inbrowser', action='store_true', help='Open in browser'
202 | )
203 | parser.add_argument(
204 | '--share', action='store_true', help='Share the gradio UI'
205 | )
206 |
207 | args = parser.parse_args()
208 |
209 | ui(
210 | username=args.username,
211 | password=args.password,
212 | inbrowser=args.inbrowser,
213 | server_port=args.server_port,
214 | share=args.share,
215 | listen=args.listen
216 | )
217 |
--------------------------------------------------------------------------------
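For reference, here is a small standalone sketch of the melody preprocessing performed in `predict()` above, outside of Gradio: a `(sample_rate, numpy_array)` pair is turned into the float tensor that `generate_with_chroma` expects. The input audio and durations are hypothetical.

```python
# Standalone sketch of the melody preprocessing in app.py's predict(): a Gradio
# (sample_rate, numpy_array) pair becomes the float tensor passed to MusicGen.
import numpy as np
import torch

sr = 44100
audio_np = np.random.randn(5 * sr, 2).astype(np.float32)      # 5 s of fake stereo audio [T, C]
melody = torch.from_numpy(audio_np).float().t().unsqueeze(0)  # -> [1, C, T]
if melody.dim() == 2:                                         # mono safety net, as in predict()
    melody = melody[None]
melody = melody[..., :int(sr * 30)]                           # cap at the 30 s duration limit
print(melody.shape)
```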
/app_batched.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Meta Platforms, Inc. and affiliates.
3 | All rights reserved.
4 |
5 | This source code is licensed under the license found in the
6 | LICENSE file in the root directory of this source tree.
7 | """
8 |
9 | from tempfile import NamedTemporaryFile
10 | import torch
11 | import gradio as gr
12 | from audiocraft.data.audio_utils import convert_audio
13 | from audiocraft.data.audio import audio_write
14 | from audiocraft.models import MusicGen
15 |
16 |
17 | MODEL = None
18 |
19 |
20 | def load_model():
21 | print("Loading model")
22 | return MusicGen.get_pretrained("melody")
23 |
24 |
25 | def predict(texts, melodies):
26 | global MODEL
27 | if MODEL is None:
28 | MODEL = load_model()
29 |
30 | duration = 12
31 | MODEL.set_generation_params(duration=duration)
32 |
33 | print(texts, melodies)
34 | processed_melodies = []
35 |
36 | target_sr = 32000
37 | target_ac = 1
38 | for melody in melodies:
39 | if melody is None:
40 | processed_melodies.append(None)
41 | else:
42 | sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
43 | if melody.dim() == 1:
44 | melody = melody[None]
45 | melody = melody[..., :int(sr * duration)]
46 | melody = convert_audio(melody, sr, target_sr, target_ac)
47 | processed_melodies.append(melody)
48 |
49 | outputs = MODEL.generate_with_chroma(
50 | descriptions=texts,
51 | melody_wavs=processed_melodies,
52 | melody_sample_rate=target_sr,
53 | progress=False
54 | )
55 |
56 | outputs = outputs.detach().cpu().float()
57 | out_files = []
58 | for output in outputs:
59 | with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
60 | audio_write(
61 | file.name, output, MODEL.sample_rate, strategy="loudness",
62 | loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
63 | waveform_video = gr.make_waveform(file.name)
64 | out_files.append(waveform_video)
65 | return [out_files]
66 |
67 |
68 | with gr.Blocks() as demo:
69 | gr.Markdown(
70 | """
71 | # MusicGen
72 |
73 | This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
74 | presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
75 |
76 |
77 |
78 | for longer sequences, more control and no queue.
79 | """
80 | )
81 | with gr.Row():
82 | with gr.Column():
83 | with gr.Row():
84 | text = gr.Text(label="Describe your music", lines=2, interactive=True)
85 | melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
86 | with gr.Row():
87 | submit = gr.Button("Generate")
88 | with gr.Column():
89 | output = gr.Video(label="Generated Music")
90 | submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=12)
91 | gr.Examples(
92 | fn=predict,
93 | examples=[
94 | [
95 | "An 80s driving pop song with heavy drums and synth pads in the background",
96 | "./assets/bach.mp3",
97 | ],
98 | [
99 | "A cheerful country song with acoustic guitars",
100 | "./assets/bolero_ravel.mp3",
101 | ],
102 | [
103 | "90s rock song with electric guitar and heavy drums",
104 | None,
105 | ],
106 | [
107 | "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
108 | "./assets/bach.mp3",
109 | ],
110 | [
111 | "lofi slow bpm electro chill with organic samples",
112 | None,
113 | ],
114 | ],
115 | inputs=[text, melody],
116 | outputs=[output]
117 | )
118 | gr.Markdown("""
119 | ### More details
120 |
121 | The model will generate 12 seconds of audio based on the description you provided.
122 | You can optionally provide a reference audio from which a broad melody will be extracted.
123 | The model will then try to follow both the description and melody provided.
124 | All samples are generated with the `melody` model.
125 |
126 | You can also use your own GPU or a Google Colab by following the instructions on our repo.
127 |
128 | See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
129 | for more details.
130 | """)
131 |
132 | demo.queue(max_size=15).launch()
133 |
--------------------------------------------------------------------------------
/assets/bach.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rkfg/audiocraft/6d70065e31c2fb422a76237e03740dd3b627de8d/assets/bach.mp3
--------------------------------------------------------------------------------
/assets/bolero_ravel.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rkfg/audiocraft/6d70065e31c2fb422a76237e03740dd3b627de8d/assets/bolero_ravel.mp3
--------------------------------------------------------------------------------
/audiocraft/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa
8 | from . import data, modules, models
9 |
10 | __version__ = '0.0.2a1'
11 |
--------------------------------------------------------------------------------
/audiocraft/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa
8 | from . import audio, audio_dataset
9 |
--------------------------------------------------------------------------------
/audiocraft/data/audio.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Audio IO methods are defined in this module (info, read, write).
9 | We rely on the av library for faster reads when possible, otherwise on torchaudio.
10 | """
11 |
12 | from dataclasses import dataclass
13 | from pathlib import Path
14 | import logging
15 | import typing as tp
16 |
17 | import numpy as np
18 | import soundfile
19 | import torch
20 | from torch.nn import functional as F
21 | import torchaudio as ta
22 |
23 | import av
24 |
25 | from .audio_utils import f32_pcm, i16_pcm, normalize_audio
26 |
27 |
28 | _av_initialized = False
29 |
30 |
31 | def _init_av():
32 | global _av_initialized
33 | if _av_initialized:
34 | return
35 | logger = logging.getLogger('libav.mp3')
36 | logger.setLevel(logging.ERROR)
37 | _av_initialized = True
38 |
39 |
40 | @dataclass(frozen=True)
41 | class AudioFileInfo:
42 | sample_rate: int
43 | duration: float
44 | channels: int
45 |
46 |
47 | def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
48 | _init_av()
49 | with av.open(str(filepath)) as af:
50 | stream = af.streams.audio[0]
51 | sample_rate = stream.codec_context.sample_rate
52 | duration = float(stream.duration * stream.time_base)
53 | channels = stream.channels
54 | return AudioFileInfo(sample_rate, duration, channels)
55 |
56 |
57 | def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
58 | info = soundfile.info(filepath)
59 | return AudioFileInfo(info.samplerate, info.duration, info.channels)
60 |
61 |
62 | def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
63 | # torchaudio no longer returns useful duration information for some formats like mp3s.
64 | filepath = Path(filepath)
65 | if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info
66 | # ffmpeg has some weird issue with flac.
67 | return _soundfile_info(filepath)
68 | else:
69 | return _av_info(filepath)
70 |
71 |
72 | def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
73 | """FFMPEG-based audio file reading using PyAV bindings.
74 | Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
75 |
76 | Args:
77 | filepath (str or Path): Path to audio file to read.
78 | seek_time (float): Time at which to start reading in the file.
79 | duration (float): Duration to read from the file. If set to -1, the whole file is read.
80 | Returns:
81 | Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate
82 | """
83 | _init_av()
84 | with av.open(str(filepath)) as af:
85 | stream = af.streams.audio[0]
86 | sr = stream.codec_context.sample_rate
87 | num_frames = int(sr * duration) if duration >= 0 else -1
88 | frame_offset = int(sr * seek_time)
89 | # we need a small negative offset otherwise we get some edge artifact
90 | # from the mp3 decoder.
91 | af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
92 | frames = []
93 | length = 0
94 | for frame in af.decode(streams=stream.index):
95 | current_offset = int(frame.rate * frame.pts * frame.time_base)
96 | strip = max(0, frame_offset - current_offset)
97 | buf = torch.from_numpy(frame.to_ndarray())
98 | if buf.shape[0] != stream.channels:
99 | buf = buf.view(-1, stream.channels).t()
100 | buf = buf[:, strip:]
101 | frames.append(buf)
102 | length += buf.shape[1]
103 | if num_frames > 0 and length >= num_frames:
104 | break
105 | assert frames
106 | # If the above assert fails, it is likely because we seeked past the end of file point,
107 | # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
108 | # This will need proper debugging, in due time.
109 | wav = torch.cat(frames, dim=1)
110 | assert wav.shape[0] == stream.channels
111 | if num_frames > 0:
112 | wav = wav[:, :num_frames]
113 | return f32_pcm(wav), sr
114 |
115 |
116 | def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
117 | duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
118 | """Read audio by picking the most appropriate backend tool based on the audio format.
119 |
120 | Args:
121 | filepath (str or Path): Path to audio file to read.
122 | seek_time (float): Time at which to start reading in the file.
123 | duration (float): Duration to read from the file. If set to -1, the whole file is read.
124 | pad (bool): Pad output audio if not reaching expected duration.
125 | Returns:
126 | Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
127 | """
128 | fp = Path(filepath)
129 | if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
130 | # There is some bug with ffmpeg and reading flac
131 | info = _soundfile_info(filepath)
132 | frames = -1 if duration <= 0 else int(duration * info.sample_rate)
133 | frame_offset = int(seek_time * info.sample_rate)
134 | wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
135 | assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
136 | wav = torch.from_numpy(wav).t().contiguous()
137 | if len(wav.shape) == 1:
138 | wav = torch.unsqueeze(wav, 0)
139 | elif (
140 | fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
141 | and duration <= 0 and seek_time == 0
142 | ):
143 | # Torchaudio is faster if we load an entire file at once.
144 | wav, sr = ta.load(fp)
145 | else:
146 | wav, sr = _av_read(filepath, seek_time, duration)
147 | if pad and duration > 0:
148 | expected_frames = int(duration * sr)
149 | wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
150 | return wav, sr
151 |
152 |
153 | def audio_write(stem_name: tp.Union[str, Path],
154 | wav: torch.Tensor, sample_rate: int,
155 | format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
156 | strategy: str = 'peak', peak_clip_headroom_db: float = 1,
157 | rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
158 | loudness_compressor: bool = False,
159 | log_clipping: bool = True, make_parent_dir: bool = True,
160 | add_suffix: bool = True) -> Path:
161 | """Convenience function for saving audio to disk. Returns the filename the audio was written to.
162 |
163 | Args:
164 | stem_name (str or Path): Filename without extension which will be added automatically.
165 | format (str): Either "wav" or "mp3".
166 | mp3_rate (int): kbps when using mp3s.
167 | normalize (bool): if `True` (default), normalizes according to the prescribed
168 | strategy (see after). If `False`, the strategy is only used in case clipping
169 | would happen.
170 | strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
171 | i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
172 | with extra headroom to avoid clipping. 'clip' just clips.
173 | peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
174 | rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
175 | than the `peak_clip` one to avoid further clipping.
176 | loudness_headroom_db (float): Target loudness for loudness normalization.
177 | loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
178 | log_clipping (bool): If True, basic logging on stderr when clipping still
179 | occurs despite strategy (only for 'rms').
180 | make_parent_dir (bool): Make parent directory if it doesn't exist.
181 | Returns:
182 | Path: Path of the saved audio.
183 | """
184 | assert wav.dtype.is_floating_point, "wav is not floating point"
185 | if wav.dim() == 1:
186 | wav = wav[None]
187 | elif wav.dim() > 2:
188 | raise ValueError("Input wav should be at most 2 dimension.")
189 | assert wav.isfinite().all()
190 | wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
191 | rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
192 | sample_rate=sample_rate, stem_name=str(stem_name))
193 | kwargs: dict = {}
194 | if format == 'mp3':
195 | suffix = '.mp3'
196 | kwargs.update({"compression": mp3_rate})
197 | elif format == 'wav':
198 | wav = i16_pcm(wav)
199 | suffix = '.wav'
200 | kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
201 | else:
202 | raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
203 | if not add_suffix:
204 | suffix = ''
205 | path = Path(str(stem_name) + suffix)
206 | if make_parent_dir:
207 | path.parent.mkdir(exist_ok=True, parents=True)
208 | try:
209 | ta.save(path, wav, sample_rate, **kwargs)
210 | except Exception:
211 | if path.exists():
212 | # we do not want to leave half written files around.
213 | path.unlink()
214 | raise
215 | return path
216 |
--------------------------------------------------------------------------------
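A short usage sketch for the helpers defined above; it assumes the bundled `assets/bach.mp3` is available and writes to a hypothetical `out_sample.wav`.

```python
# Usage sketch for audio_info / audio_read / audio_write defined in audio.py.
# Assumes the repo's assets/bach.mp3 is present; the output file name is arbitrary.
from audiocraft.data.audio import audio_info, audio_read, audio_write

info = audio_info('assets/bach.mp3')                       # sample rate, duration, channels
wav, sr = audio_read('assets/bach.mp3', seek_time=0.0, duration=2.0, pad=True)
out_path = audio_write('out_sample', wav, sr,
                       strategy='loudness', loudness_compressor=True)
print(info, wav.shape, out_path)                           # e.g. .../out_sample.wav
```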
/audiocraft/data/audio_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 | import typing as tp
9 |
10 | import julius
11 | import torch
12 | import torchaudio
13 |
14 |
15 | def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
16 | """Convert audio to the given number of channels.
17 |
18 | Args:
19 | wav (torch.Tensor): Audio wave of shape [B, C, T].
20 | channels (int): Expected number of channels as output.
21 | Returns:
22 | torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
23 | """
24 | *shape, src_channels, length = wav.shape
25 | if src_channels == channels:
26 | pass
27 | elif channels == 1:
28 | # Case 1:
29 | # The caller asked for 1-channel audio, and the stream has multiple
30 | # channels, downmix all channels.
31 | wav = wav.mean(dim=-2, keepdim=True)
32 | elif src_channels == 1:
33 | # Case 2:
34 | # The caller asked for multiple channels, but the input file has
35 | # a single channel, replicate the audio over all channels.
36 | wav = wav.expand(*shape, channels, length)
37 | elif src_channels >= channels:
38 | # Case 3:
39 | # The caller asked for multiple channels, and the input file has
40 | # more channels than requested. In that case return the first channels.
41 | wav = wav[..., :channels, :]
42 | else:
43 | # Case 4: What is a reasonable choice here?
44 | raise ValueError('The audio file has fewer channels than requested but is not mono.')
45 | return wav
46 |
47 |
48 | def convert_audio(wav: torch.Tensor, from_rate: float,
49 | to_rate: float, to_channels: int) -> torch.Tensor:
50 | """Convert audio to new sample rate and number of audio channels.
51 | """
52 | wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
53 | wav = convert_audio_channels(wav, to_channels)
54 | return wav
55 |
56 |
57 | def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
58 | loudness_compressor: bool = False, energy_floor: float = 2e-3):
59 | """Normalize an input signal to a user loudness in dB LKFS.
60 | Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
61 |
62 | Args:
63 | wav (torch.Tensor): Input multichannel audio data.
64 | sample_rate (int): Sample rate.
65 | loudness_headroom_db (float): Target loudness of the output in dB LUFS.
66 | loudness_compressor (bool): Uses tanh for soft clipping.
67 | energy_floor (float): anything below that RMS level will not be rescaled.
68 | Returns:
69 | output (torch.Tensor): Loudness normalized output data.
70 | """
71 | energy = wav.pow(2).mean().sqrt().item()
72 | if energy < energy_floor:
73 | return wav
74 | transform = torchaudio.transforms.Loudness(sample_rate)
75 | input_loudness_db = transform(wav).item()
76 | # calculate the gain needed to scale to the desired loudness level
77 | delta_loudness = -loudness_headroom_db - input_loudness_db
78 | gain = 10.0 ** (delta_loudness / 20.0)
79 | output = gain * wav
80 | if loudness_compressor:
81 | output = torch.tanh(output)
82 | assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
83 | return output
84 |
85 |
86 | def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
87 | """Utility function to clip the audio with logging if specified."""
88 | max_scale = wav.abs().max()
89 | if log_clipping and max_scale > 1:
90 | clamp_prob = (wav.abs() > 1).float().mean().item()
91 | print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
92 | clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
93 | wav.clamp_(-1, 1)
94 |
95 |
96 | def normalize_audio(wav: torch.Tensor, normalize: bool = True,
97 | strategy: str = 'peak', peak_clip_headroom_db: float = 1,
98 | rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
99 | loudness_compressor: bool = False, log_clipping: bool = False,
100 | sample_rate: tp.Optional[int] = None,
101 | stem_name: tp.Optional[str] = None) -> torch.Tensor:
102 | """Normalize the audio according to the prescribed strategy (see after).
103 |
104 | Args:
105 | wav (torch.Tensor): Audio data.
106 | normalize (bool): if `True` (default), normalizes according to the prescribed
107 | strategy (see after). If `False`, the strategy is only used in case clipping
108 | would happen.
109 | strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
110 | i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
111 | with extra headroom to avoid clipping. 'clip' just clips.
112 | peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
113 | rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
114 | than the `peak_clip` one to avoid further clipping.
115 | loudness_headroom_db (float): Target loudness for loudness normalization.
116 | loudness_compressor (bool): If True, uses tanh based soft clipping.
117 | log_clipping (bool): If True, basic logging on stderr when clipping still
118 | occurs despite strategy (only for 'rms').
119 | sample_rate (int): Sample rate for the audio data (required for loudness).
120 | stem_name (Optional[str]): Stem name for clipping logging.
121 | Returns:
122 | torch.Tensor: Normalized audio.
123 | """
124 | scale_peak = 10 ** (-peak_clip_headroom_db / 20)
125 | scale_rms = 10 ** (-rms_headroom_db / 20)
126 | if strategy == 'peak':
127 | rescaling = (scale_peak / wav.abs().max())
128 | if normalize or rescaling < 1:
129 | wav = wav * rescaling
130 | elif strategy == 'clip':
131 | wav = wav.clamp(-scale_peak, scale_peak)
132 | elif strategy == 'rms':
133 | mono = wav.mean(dim=0)
134 | rescaling = scale_rms / mono.pow(2).mean().sqrt()
135 | if normalize or rescaling < 1:
136 | wav = wav * rescaling
137 | _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
138 | elif strategy == 'loudness':
139 | assert sample_rate is not None, "Loudness normalization requires sample rate."
140 | wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
141 | _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
142 | else:
143 | assert wav.abs().max() < 1
144 | assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
145 | return wav
146 |
147 |
148 | def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
149 | """Convert audio to float 32 bits PCM format.
150 | """
151 | if wav.dtype.is_floating_point:
152 | return wav
153 | else:
154 | assert wav.dtype == torch.int16
155 | return wav.float() / 2**15
156 |
157 |
158 | def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
159 | """Convert audio to int 16 bits PCM format.
160 |
161 | ..Warning:: There exist many formulas for doing this conversion. None are perfect
162 | due to the asymmetry of the int16 range. One either has possible clipping, DC offset,
163 | or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
164 | it is possible that `i16_pcm(f32_pcm(wav)) != wav`.
165 | """
166 | if wav.dtype.is_floating_point:
167 | assert wav.abs().max() <= 1
168 | candidate = (wav * 2 ** 15).round()
169 | if candidate.max() >= 2 ** 15: # clipping would occur
170 | candidate = (wav * (2 ** 15 - 1)).round()
171 | return candidate.short()
172 | else:
173 | assert wav.dtype == torch.int16
174 | return wav
175 |
--------------------------------------------------------------------------------
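
A minimal usage sketch of the helpers above (assuming they are imported from audiocraft.data.audio_utils; the input waveform and headroom values are illustrative): peak and RMS normalization, then an int16 round trip.

# Sketch only: made-up audio, explicit headroom values.
import torch
from audiocraft.data.audio_utils import normalize_audio, i16_pcm, f32_pcm

wav = 0.5 * torch.randn(2, 32000)  # 1 second of stereo noise at 32 kHz
# 'peak' rescales by the largest absolute value, keeping 1 dB of headroom.
peak = normalize_audio(wav, strategy='peak', peak_clip_headroom_db=1)
# 'rms' rescales by the root-mean-square of the mono mix; the larger headroom limits clipping.
rms = normalize_audio(wav, strategy='rms', rms_headroom_db=18)
# The int16 round trip is only exact with enough headroom, as the i16_pcm warning explains.
restored = f32_pcm(i16_pcm(peak))
print(peak.abs().max().item(), (restored - peak).abs().max().item())
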
/audiocraft/data/zip.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import typing
8 | import zipfile
9 |
10 | from dataclasses import dataclass
11 | from functools import lru_cache
12 | from typing_extensions import Literal
13 |
14 |
15 | DEFAULT_SIZE = 32
16 | MODE = Literal['r', 'w', 'x', 'a']
17 |
18 |
19 | @dataclass(order=True)
20 | class PathInZip:
21 | """Class for holding a path of file within a zip file.
22 |
23 | Args:
24 |         path: The convention is:
25 |         Let's assume there is a zip file /some/location/foo.zip
26 |         and inside of it is a json file located at /data/file1.json,
27 |         then we expect path = "/some/location/foo.zip:/data/file1.json".
28 | """
29 |
30 | INFO_PATH_SEP = ':'
31 | zip_path: str
32 | file_path: str
33 |
34 | def __init__(self, path: str) -> None:
35 | split_path = path.split(self.INFO_PATH_SEP)
36 | assert len(split_path) == 2
37 | self.zip_path, self.file_path = split_path
38 |
39 | @classmethod
40 | def from_paths(cls, zip_path: str, file_path: str):
41 | return cls(zip_path + cls.INFO_PATH_SEP + file_path)
42 |
43 | def __str__(self) -> str:
44 | return self.zip_path + self.INFO_PATH_SEP + self.file_path
45 |
46 |
47 | def _open_zip(path: str, mode: MODE = 'r'):
48 | return zipfile.ZipFile(path, mode)
49 |
50 |
51 | _cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
52 |
53 |
54 | def set_zip_cache_size(max_size: int):
55 |     """Sets the maximal LRU cache size for zip file opening.
56 |
57 |     Args:
58 |         max_size: the maximal size of the LRU cache.
59 | """
60 | global _cached_open_zip
61 | _cached_open_zip = lru_cache(max_size)(_open_zip)
62 |
63 |
64 | def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
65 | """Opens a file stored inside a zip and returns a file-like object.
66 |
67 | Args:
68 | path_in_zip: A PathInZip object representing the file to return a file-like object of.
69 |         mode: The mode in which to open the file.
70 | Returns:
71 | A file-like object for PathInZip.
72 | """
73 | zf = _cached_open_zip(path_in_zip.zip_path)
74 | return zf.open(path_in_zip.file_path)
75 |
--------------------------------------------------------------------------------
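
A small sketch of the PathInZip convention and the cached zip opener above (the archive and member paths are the placeholder ones from the docstring; substitute a real zip file before running):

from audiocraft.data.zip import PathInZip, open_file_in_zip, set_zip_cache_size

path = PathInZip('/some/location/foo.zip:/data/file1.json')
assert str(path) == '/some/location/foo.zip:/data/file1.json'
same = PathInZip.from_paths('/some/location/foo.zip', '/data/file1.json')
assert path == same  # dataclass equality compares zip_path and file_path

set_zip_cache_size(8)               # shrink the LRU cache of open ZipFile handles
with open_file_in_zip(path) as f:   # file-like object for the member, opened read-only
    data = f.read()
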
/audiocraft/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa
8 | from .musicgen import MusicGen
9 | from .lm import LMModel
10 | from .encodec import CompressionModel, EncodecModel
11 |
--------------------------------------------------------------------------------
/audiocraft/models/builders.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | All the functions to build the relevant models and modules
9 | from the Hydra config.
10 | """
11 |
12 | import typing as tp
13 | import warnings
14 |
15 | import audiocraft
16 | import omegaconf
17 | import torch
18 |
19 | from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel # noqa
20 | from .lm import LMModel
21 | from ..modules.codebooks_patterns import (
22 | CodebooksPatternProvider,
23 | DelayedPatternProvider,
24 | ParallelPatternProvider,
25 | UnrolledPatternProvider,
26 | VALLEPattern,
27 | MusicLMPattern,
28 | )
29 | from ..modules.conditioners import (
30 | BaseConditioner,
31 | ConditioningProvider,
32 | LUTConditioner,
33 | T5Conditioner,
34 | ConditionFuser,
35 | ChromaStemConditioner,
36 | )
37 | from .. import quantization as qt
38 | from ..utils.utils import dict_from_config
39 |
40 |
41 | def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
42 | klass = {
43 | 'no_quant': qt.DummyQuantizer,
44 | 'rvq': qt.ResidualVectorQuantizer
45 | }[quantizer]
46 | kwargs = dict_from_config(getattr(cfg, quantizer))
47 | if quantizer != 'no_quant':
48 | kwargs['dimension'] = dimension
49 | return klass(**kwargs)
50 |
51 |
52 | def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
53 | if encoder_name == 'seanet':
54 | kwargs = dict_from_config(getattr(cfg, 'seanet'))
55 | encoder_override_kwargs = kwargs.pop('encoder')
56 | decoder_override_kwargs = kwargs.pop('decoder')
57 | encoder_kwargs = {**kwargs, **encoder_override_kwargs}
58 | decoder_kwargs = {**kwargs, **decoder_override_kwargs}
59 | encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
60 | decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
61 | return encoder, decoder
62 | else:
63 | raise KeyError(f'Unexpected compression model {cfg.compression_model}')
64 |
65 |
66 | def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
67 | """Instantiate a compression model.
68 | """
69 | if cfg.compression_model == 'encodec':
70 | kwargs = dict_from_config(getattr(cfg, 'encodec'))
71 | encoder_name = kwargs.pop('autoencoder')
72 | quantizer_name = kwargs.pop('quantizer')
73 | encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
74 | quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
75 | frame_rate = kwargs['sample_rate'] // encoder.hop_length
76 | renormalize = kwargs.pop('renormalize', None)
77 | renorm = kwargs.pop('renorm')
78 | if renormalize is None:
79 | renormalize = renorm is not None
80 | warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.")
81 | return EncodecModel(encoder, decoder, quantizer,
82 | frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
83 | else:
84 | raise KeyError(f'Unexpected compression model {cfg.compression_model}')
85 |
86 |
87 | def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
88 | """Instantiate a transformer LM.
89 | """
90 | if cfg.lm_model == 'transformer_lm':
91 | kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
92 | n_q = kwargs['n_q']
93 | q_modeling = kwargs.pop('q_modeling', None)
94 | codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
95 | attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
96 | cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
97 | cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"]
98 | fuser = get_condition_fuser(cfg)
99 | condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
100 |         if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-attention programmatically
101 | kwargs['cross_attention'] = True
102 | if codebooks_pattern_cfg.modeling is None:
103 | assert q_modeling is not None, \
104 | 'LM model should either have a codebook pattern defined or transformer_lm.q_modeling'
105 | codebooks_pattern_cfg = omegaconf.OmegaConf.create(
106 | {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
107 | )
108 | pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
109 | return LMModel(
110 | pattern_provider=pattern_provider,
111 | condition_provider=condition_provider,
112 | fuser=fuser,
113 | cfg_dropout=cfg_prob,
114 | cfg_coef=cfg_coef,
115 | attribute_dropout=attribute_dropout,
116 | dtype=getattr(torch, cfg.dtype),
117 | device=cfg.device,
118 | **kwargs
119 | ).to(cfg.device)
120 | else:
121 | raise KeyError(f'Unexpected LM model {cfg.lm_model}')
122 |
123 |
124 | def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
125 | """Instantiate a conditioning model.
126 | """
127 | device = cfg.device
128 | duration = cfg.dataset.segment_duration
129 | cfg = getattr(cfg, "conditioners")
130 | cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg
131 | conditioners: tp.Dict[str, BaseConditioner] = {}
132 | with omegaconf.open_dict(cfg):
133 | condition_provider_args = cfg.pop('args', {})
134 | for cond, cond_cfg in cfg.items():
135 | model_type = cond_cfg["model"]
136 | model_args = cond_cfg[model_type]
137 | if model_type == "t5":
138 | conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
139 | elif model_type == "lut":
140 | conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
141 | elif model_type == "chroma_stem":
142 | model_args.pop('cache_path', None)
143 | conditioners[str(cond)] = ChromaStemConditioner(
144 | output_dim=output_dim,
145 | duration=duration,
146 | device=device,
147 | **model_args
148 | )
149 | else:
150 | raise ValueError(f"unrecognized conditioning model: {model_type}")
151 | conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
152 | return conditioner
153 |
154 |
155 | def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
156 | """Instantiate a condition fuser object.
157 | """
158 | fuser_cfg = getattr(cfg, "fuser")
159 | fuser_methods = ["sum", "cross", "prepend", "input_interpolate"]
160 | fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
161 | kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
162 | fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
163 | return fuser
164 |
165 |
166 | def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
167 | """Instantiate a codebooks pattern provider object.
168 | """
169 | pattern_providers = {
170 | 'parallel': ParallelPatternProvider,
171 | 'delay': DelayedPatternProvider,
172 | 'unroll': UnrolledPatternProvider,
173 | 'valle': VALLEPattern,
174 | 'musiclm': MusicLMPattern,
175 | }
176 | name = cfg.modeling
177 | kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
178 | klass = pattern_providers[name]
179 | return klass(n_q, **kwargs)
180 |
181 |
182 | def get_debug_compression_model(device='cpu'):
183 | """Instantiate a debug compression model to be used for unit tests.
184 | """
185 | seanet_kwargs = {
186 | 'n_filters': 4,
187 | 'n_residual_layers': 1,
188 | 'dimension': 32,
189 | 'ratios': [10, 8, 16] # 25 Hz at 32kHz
190 | }
191 | encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
192 | decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
193 | quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
194 | init_x = torch.randn(8, 32, 128)
195 | quantizer(init_x, 1) # initialize kmeans etc.
196 | compression_model = EncodecModel(
197 | encoder, decoder, quantizer,
198 | frame_rate=25, sample_rate=32000, channels=1).to(device)
199 | return compression_model.eval()
200 |
201 |
202 | def get_debug_lm_model(device='cpu'):
203 | """Instantiate a debug LM to be used for unit tests.
204 | """
205 | pattern = DelayedPatternProvider(n_q=4)
206 | dim = 16
207 | providers = {
208 | 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
209 | }
210 | condition_provider = ConditioningProvider(providers)
211 | fuser = ConditionFuser(
212 | {'cross': ['description'], 'prepend': [],
213 | 'sum': [], 'input_interpolate': []})
214 | lm = LMModel(
215 | pattern, condition_provider, fuser,
216 | n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
217 | cross_attention=True, causal=True)
218 | return lm.to(device).eval()
219 |
--------------------------------------------------------------------------------
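
The two debug builders above make it easy to check tensor shapes without a Hydra config; a sketch on CPU with the tiny 32 kHz settings defined above:

import torch
from audiocraft.models.builders import get_debug_compression_model, get_debug_lm_model

comp = get_debug_compression_model()      # tiny SEANet + RVQ, 25 Hz frame rate at 32 kHz
wav = torch.randn(2, 1, 32000)            # [B, C, T], one second of mono audio
codes, scale = comp.encode(wav)           # codes: [B, K=4, T=25]
recon = comp.decode(codes, scale)         # may be slightly longer than the input (padding)
print(codes.shape, recon.shape)

lm = get_debug_lm_model()                 # 2-layer transformer LM with a whitespace LUT conditioner
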
/audiocraft/models/encodec.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from abc import ABC, abstractmethod
8 | import typing as tp
9 |
10 | from einops import rearrange
11 | import torch
12 | from torch import nn
13 |
14 | from .. import quantization as qt
15 |
16 |
17 | class CompressionModel(ABC, nn.Module):
18 |
19 | @abstractmethod
20 | def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
21 | ...
22 |
23 | @abstractmethod
24 | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
25 | """See `EncodecModel.encode`"""
26 | ...
27 |
28 | @abstractmethod
29 | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
30 | """See `EncodecModel.decode`"""
31 | ...
32 |
33 | @property
34 | @abstractmethod
35 | def channels(self) -> int:
36 | ...
37 |
38 | @property
39 | @abstractmethod
40 | def frame_rate(self) -> int:
41 | ...
42 |
43 | @property
44 | @abstractmethod
45 | def sample_rate(self) -> int:
46 | ...
47 |
48 | @property
49 | @abstractmethod
50 | def cardinality(self) -> int:
51 | ...
52 |
53 | @property
54 | @abstractmethod
55 | def num_codebooks(self) -> int:
56 | ...
57 |
58 | @property
59 | @abstractmethod
60 | def total_codebooks(self) -> int:
61 | ...
62 |
63 | @abstractmethod
64 | def set_num_codebooks(self, n: int):
65 | """Set the active number of codebooks used by the quantizer.
66 | """
67 | ...
68 |
69 |
70 | class EncodecModel(CompressionModel):
71 | """Encodec model operating on the raw waveform.
72 |
73 | Args:
74 | encoder (nn.Module): Encoder network.
75 | decoder (nn.Module): Decoder network.
76 | quantizer (qt.BaseQuantizer): Quantizer network.
77 | frame_rate (int): Frame rate for the latent representation.
78 | sample_rate (int): Audio sample rate.
79 | channels (int): Number of audio channels.
80 | causal (bool): Whether to use a causal version of the model.
81 | renormalize (bool): Whether to renormalize the audio before running the model.
82 | """
83 |     # we need assignment to override the property in the abstract class,
84 | # I couldn't find a better way...
85 | frame_rate: int = 0
86 | sample_rate: int = 0
87 | channels: int = 0
88 |
89 | def __init__(self,
90 | encoder: nn.Module,
91 | decoder: nn.Module,
92 | quantizer: qt.BaseQuantizer,
93 | frame_rate: int,
94 | sample_rate: int,
95 | channels: int,
96 | causal: bool = False,
97 | renormalize: bool = False):
98 | super().__init__()
99 | self.encoder = encoder
100 | self.decoder = decoder
101 | self.quantizer = quantizer
102 | self.frame_rate = frame_rate
103 | self.sample_rate = sample_rate
104 | self.channels = channels
105 | self.renormalize = renormalize
106 | self.causal = causal
107 | if self.causal:
108 | # we force disabling here to avoid handling linear overlap of segments
109 | # as supported in original EnCodec codebase.
110 | assert not self.renormalize, 'Causal model does not support renormalize'
111 |
112 | @property
113 | def total_codebooks(self):
114 | """Total number of quantizer codebooks available.
115 | """
116 | return self.quantizer.total_codebooks
117 |
118 | @property
119 | def num_codebooks(self):
120 | """Active number of codebooks used by the quantizer.
121 | """
122 | return self.quantizer.num_codebooks
123 |
124 | def set_num_codebooks(self, n: int):
125 | """Set the active number of codebooks used by the quantizer.
126 | """
127 | self.quantizer.set_num_codebooks(n)
128 |
129 | @property
130 | def cardinality(self):
131 | """Cardinality of each codebook.
132 | """
133 | return self.quantizer.bins
134 |
135 | def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
136 | scale: tp.Optional[torch.Tensor]
137 | if self.renormalize:
138 | mono = x.mean(dim=1, keepdim=True)
139 | volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
140 | scale = 1e-8 + volume
141 | x = x / scale
142 | scale = scale.view(-1, 1)
143 | else:
144 | scale = None
145 | return x, scale
146 |
147 | def postprocess(self,
148 | x: torch.Tensor,
149 | scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
150 | if scale is not None:
151 | assert self.renormalize
152 | x = x * scale.view(-1, 1, 1)
153 | return x
154 |
155 | def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
156 | assert x.dim() == 3
157 | length = x.shape[-1]
158 | x, scale = self.preprocess(x)
159 |
160 | emb = self.encoder(x)
161 | q_res = self.quantizer(emb, self.frame_rate)
162 | out = self.decoder(q_res.x)
163 |
164 | # remove extra padding added by the encoder and decoder
165 | assert out.shape[-1] >= length, (out.shape[-1], length)
166 | out = out[..., :length]
167 |
168 | q_res.x = self.postprocess(out, scale)
169 |
170 | return q_res
171 |
172 | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
173 | """Encode the given input tensor to quantized representation along with scale parameter.
174 |
175 | Args:
176 | x (torch.Tensor): Float tensor of shape [B, C, T]
177 |
178 | Returns:
179 | codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of:
180 |                 codes, an int tensor of shape [B, K, T] with K the number of codebooks used and T the number of timesteps.
181 |                 scale, a float tensor containing the scale for audio renormalization.
182 | """
183 | assert x.dim() == 3
184 | x, scale = self.preprocess(x)
185 | emb = self.encoder(x)
186 | codes = self.quantizer.encode(emb)
187 | return codes, scale
188 |
189 | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
190 | """Decode the given codes to a reconstructed representation, using the scale to perform
191 | audio denormalization if needed.
192 |
193 | Args:
194 | codes (torch.Tensor): Int tensor of shape [B, K, T]
195 | scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value.
196 |
197 | Returns:
198 | out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
199 | """
200 | emb = self.quantizer.decode(codes)
201 | out = self.decoder(emb)
202 | out = self.postprocess(out, scale)
203 | # out contains extra padding added by the encoder and decoder
204 | return out
205 |
206 |
207 | class FlattenedCompressionModel(CompressionModel):
208 |     """Wraps a CompressionModel and flattens its codebooks, e.g.
209 |     instead of returning [B, K, T], it returns [B, S, T * (K // S)] with
210 | S the number of codebooks per step, and `K // S` the number of 'virtual steps'
211 | for each real time step.
212 |
213 | Args:
214 | model (CompressionModel): compression model to wrap.
215 | codebooks_per_step (int): number of codebooks to keep per step,
216 | this must divide the number of codebooks provided by the wrapped model.
217 | extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1,
218 | if each codebook has a cardinality N, then the first codebook will
219 | use the range [0, N - 1], and the second [N, 2 N - 1] etc.
220 | On decoding, this can lead to potentially invalid sequences.
221 | Any invalid entry will be silently remapped to the proper range
222 | with a modulo.
223 | """
224 | def __init__(self, model: CompressionModel, codebooks_per_step: int = 1,
225 | extend_cardinality: bool = True):
226 | super().__init__()
227 | self.model = model
228 | self.codebooks_per_step = codebooks_per_step
229 | self.extend_cardinality = extend_cardinality
230 |
231 | @property
232 | def total_codebooks(self):
233 | return self.model.total_codebooks
234 |
235 | @property
236 | def num_codebooks(self):
237 | """Active number of codebooks used by the quantizer.
238 |
239 | ..Warning:: this reports the number of codebooks after the flattening
240 | of the codebooks!
241 | """
242 | assert self.model.num_codebooks % self.codebooks_per_step == 0
243 | return self.codebooks_per_step
244 |
245 | def set_num_codebooks(self, n: int):
246 | """Set the active number of codebooks used by the quantizer.
247 |
248 | ..Warning:: this sets the number of codebooks **before** the flattening
249 | of the codebooks.
250 | """
251 | assert n % self.codebooks_per_step == 0
252 | self.model.set_num_codebooks(n)
253 |
254 | @property
255 | def num_virtual_steps(self) -> int:
256 |         """Return the number of virtual steps, i.e. one real step
257 | will be split into that many steps.
258 | """
259 | return self.model.num_codebooks // self.codebooks_per_step
260 |
261 | @property
262 | def frame_rate(self) -> int:
263 | return self.model.frame_rate * self.num_virtual_steps
264 |
265 | @property
266 | def sample_rate(self) -> int:
267 | return self.model.sample_rate
268 |
269 | @property
270 | def channels(self) -> int:
271 | return self.model.channels
272 |
273 | @property
274 | def cardinality(self):
275 | """Cardinality of each codebook.
276 | """
277 | if self.extend_cardinality:
278 | return self.model.cardinality * self.num_virtual_steps
279 | else:
280 | return self.model.cardinality
281 |
282 | def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
283 | raise NotImplementedError("Not supported, use encode and decode.")
284 |
285 | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
286 | indices, scales = self.model.encode(x)
287 | B, K, T = indices.shape
288 | indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step)
289 | if self.extend_cardinality:
290 | for virtual_step in range(1, self.num_virtual_steps):
291 | indices[..., virtual_step] += self.model.cardinality * virtual_step
292 | indices = rearrange(indices, 'b k t v -> b k (t v)')
293 | return (indices, scales)
294 |
295 | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
296 | B, K, T = codes.shape
297 | assert T % self.num_virtual_steps == 0
298 | codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps)
299 | # We silently ignore potential errors from the LM when
300 | # using extend_cardinality.
301 | codes = codes % self.model.cardinality
302 | return self.model.decode(codes, scale)
303 |
--------------------------------------------------------------------------------
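
How FlattenedCompressionModel reshapes the codebooks is easiest to see with the debug model from builders.py (a sketch, assuming the 4-codebook, 25 Hz debug EnCodec above):

import torch
from audiocraft.models.builders import get_debug_compression_model
from audiocraft.models.encodec import FlattenedCompressionModel

base = get_debug_compression_model()                 # 4 codebooks at 25 Hz
flat = FlattenedCompressionModel(base, codebooks_per_step=1)
wav = torch.randn(1, 1, 32000)
codes, scale = flat.encode(wav)                      # [B, 1, 25 * 4]: 4 virtual steps per real frame
print(flat.num_codebooks, flat.frame_rate, codes.shape)
out = flat.decode(codes, scale)                      # entries remapped modulo the base cardinality
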
/audiocraft/models/loaders.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Utility functions to load from the checkpoints.
9 | Each checkpoint is a torch.saved dict with the following keys:
10 | - 'xp.cfg': the hydra config as dumped during training. This should be used
11 | to rebuild the object using the audiocraft.models.builders functions,
12 | - 'model_best_state': a readily loadable best state for the model, including
13 | the conditioner. The model obtained from `xp.cfg` should be compatible
14 |     with this state dict. In the case of an LM, the encodec model would not be
15 | bundled along but instead provided separately.
16 |
17 | Those functions also support loading from a remote location with the Torch Hub API.
18 | They also support overriding some parameters, in particular the device and dtype
19 | of the returned model.
20 | """
21 |
22 | from pathlib import Path
23 | from huggingface_hub import hf_hub_download
24 | import typing as tp
25 | import os
26 |
27 | from omegaconf import OmegaConf
28 | import torch
29 |
30 | from . import builders
31 |
32 |
33 | HF_MODEL_CHECKPOINTS_MAP = {
34 | "small": "facebook/musicgen-small",
35 | "medium": "facebook/musicgen-medium",
36 | "large": "facebook/musicgen-large",
37 | "melody": "facebook/musicgen-melody",
38 | }
39 |
40 |
41 | def _get_state_dict(
42 | file_or_url_or_id: tp.Union[Path, str],
43 | filename: tp.Optional[str] = None,
44 | device='cpu',
45 | cache_dir: tp.Optional[str] = None,
46 | ):
47 | # Return the state dict either from a file or url
48 | file_or_url_or_id = str(file_or_url_or_id)
49 | assert isinstance(file_or_url_or_id, str)
50 |
51 | if os.path.isfile(file_or_url_or_id):
52 | return torch.load(file_or_url_or_id, map_location=device)
53 |
54 | elif file_or_url_or_id.startswith('https://'):
55 | return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
56 |
57 | elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
58 | assert filename is not None, "filename needs to be defined if using HF checkpoints"
59 |
60 | repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
61 | file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
62 | return torch.load(file, map_location=device)
63 |
64 | else:
65 | raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")
66 |
67 |
68 | def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
69 | pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
70 | cfg = OmegaConf.create(pkg['xp.cfg'])
71 | cfg.device = str(device)
72 | model = builders.get_compression_model(cfg)
73 | model.load_state_dict(pkg['best_state'])
74 | model.eval()
75 | return model
76 |
77 |
78 | def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
79 | pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
80 | cfg = OmegaConf.create(pkg['xp.cfg'])
81 | cfg.device = str(device)
82 | if cfg.device == 'cpu':
83 | cfg.transformer_lm.memory_efficient = False
84 | cfg.transformer_lm.custom = True
85 | cfg.dtype = 'float32'
86 | else:
87 | cfg.dtype = 'float16'
88 | model = builders.get_lm_model(cfg)
89 | model.load_state_dict(pkg['best_state'])
90 | model.eval()
91 | model.cfg = cfg
92 | return model
93 |
--------------------------------------------------------------------------------
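
A sketch of the loaders above (note: the short names resolve through HF_MODEL_CHECKPOINTS_MAP and download the pretrained weights from the Hugging Face Hub on first use, so this needs network access and disk space):

from audiocraft.models.loaders import load_compression_model, load_lm_model

compression = load_compression_model('small', device='cpu')  # facebook/musicgen-small EnCodec
lm = load_lm_model('small', device='cpu')                    # CPU forces float32 and the custom attention path
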
/audiocraft/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa
8 | from .conv import (
9 | NormConv1d,
10 | NormConv2d,
11 | NormConvTranspose1d,
12 | NormConvTranspose2d,
13 | StreamableConv1d,
14 | StreamableConvTranspose1d,
15 | pad_for_conv1d,
16 | pad1d,
17 | unpad1d,
18 | )
19 | from .lstm import StreamableLSTM
20 | from .seanet import SEANetEncoder, SEANetDecoder
21 |
--------------------------------------------------------------------------------
/audiocraft/modules/activations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch import Tensor
10 | from typing import Union, Callable
11 |
12 |
13 | class CustomGLU(nn.Module):
14 | """Custom Gated Linear Unit activation.
15 | Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
16 | of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
17 | function (i.e. sigmoid, swish, etc.).
18 |
19 | Args:
20 | activation (nn.Module): The custom activation to apply in the Gated Linear Unit
21 | dim (int): the dimension on which to split the input. Default: -1
22 |
23 | Shape:
24 | - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
25 | dimensions
26 | - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
27 |
28 | Examples::
29 | >>> m = CustomGLU(nn.Sigmoid())
30 | >>> input = torch.randn(4, 2)
31 | >>> output = m(input)
32 | """
33 | def __init__(self, activation: nn.Module, dim: int = -1):
34 | super(CustomGLU, self).__init__()
35 | self.dim = dim
36 | self.activation = activation
37 |
38 | def forward(self, x: Tensor):
39 | assert x.shape[self.dim] % 2 == 0 # M = N / 2
40 | a, b = torch.chunk(x, 2, dim=self.dim)
41 | return a * self.activation(b)
42 |
43 |
44 | class SwiGLU(CustomGLU):
45 | """SiLU Gated Linear Unit activation.
46 | Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
47 | the first half of the input matrices, :math:`b` is the second half.
48 |
49 | Args:
50 | dim (int): the dimension on which to split the input. Default: -1
51 | """
52 | def __init__(self, dim: int = -1):
53 | super(SwiGLU, self).__init__(nn.SiLU(), dim)
54 |
55 |
56 | class GeGLU(CustomGLU):
57 | """GeLU Gated Linear Unit activation.
58 | Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
59 | the first half of the input matrices, :math:`b` is the second half.
60 |
61 | Args:
62 | dim (int): the dimension on which to split the input. Default: -1
63 | """
64 | def __init__(self, dim: int = -1):
65 | super(GeGLU, self).__init__(nn.GELU(), dim)
66 |
67 |
68 | class ReGLU(CustomGLU):
69 | """ReLU Gated Linear Unit activation.
70 | Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
71 | the first half of the input matrices, :math:`b` is the second half.
72 |
73 | Args:
74 | dim (int): the dimension on which to split the input. Default: -1
75 | """
76 | def __init__(self, dim: int = -1):
77 | super(ReGLU, self).__init__(nn.ReLU(), dim)
78 |
79 |
80 | def get_activation_fn(
81 | activation: Union[str, Callable[[Tensor], Tensor]]
82 | ) -> Union[str, Callable[[Tensor], Tensor]]:
83 | """Helper function to map an activation string to the activation class.
84 | If the supplied activation is not a string that is recognized, the activation is passed back.
85 |
86 | Args:
87 | activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check
88 | """
89 | if isinstance(activation, str):
90 | if activation == "reglu":
91 | return ReGLU()
92 | elif activation == "geglu":
93 | return GeGLU()
94 | elif activation == "swiglu":
95 | return SwiGLU()
96 | return activation
97 |
--------------------------------------------------------------------------------
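
The GLU variants above split the chosen dimension in half; a quick sketch:

import torch
from audiocraft.modules.activations import SwiGLU, get_activation_fn

x = torch.randn(4, 8)
act = get_activation_fn("swiglu")  # the string maps to a SwiGLU instance
y = act(x)                         # a * SiLU(b), with a, b the two halves of the last dim
assert isinstance(act, SwiGLU) and y.shape == (4, 4)
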
/audiocraft/modules/conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | import typing as tp
9 | import warnings
10 |
11 | import torch
12 | from torch import nn
13 | from torch.nn import functional as F
14 | from torch.nn.utils import spectral_norm, weight_norm
15 |
16 |
17 | CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
18 | 'time_group_norm'])
19 |
20 |
21 | def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
22 | assert norm in CONV_NORMALIZATIONS
23 | if norm == 'weight_norm':
24 | return weight_norm(module)
25 | elif norm == 'spectral_norm':
26 | return spectral_norm(module)
27 | else:
28 |         # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other choice
29 | # doesn't need reparametrization.
30 | return module
31 |
32 |
33 | def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
34 | """Return the proper normalization module. If causal is True, this will ensure the returned
35 |     module is causal, or raise an error if the normalization doesn't support causal evaluation.
36 | """
37 | assert norm in CONV_NORMALIZATIONS
38 | if norm == 'time_group_norm':
39 | if causal:
40 | raise ValueError("GroupNorm doesn't support causal evaluation.")
41 | assert isinstance(module, nn.modules.conv._ConvNd)
42 | return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
43 | else:
44 | return nn.Identity()
45 |
46 |
47 | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
48 | padding_total: int = 0) -> int:
49 | """See `pad_for_conv1d`.
50 | """
51 | length = x.shape[-1]
52 | n_frames = (length - kernel_size + padding_total) / stride + 1
53 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
54 | return ideal_length - length
55 |
56 |
57 | def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
58 | """Pad for a convolution to make sure that the last window is full.
59 | Extra padding is added at the end. This is required to ensure that we can rebuild
60 | an output of the same length, as otherwise, even with padding, some time steps
61 | might get removed.
62 | For instance, with total padding = 4, kernel size = 4, stride = 2:
63 | 0 0 1 2 3 4 5 0 0 # (0s are padding)
64 | 1 2 3 # (output frames of a convolution, last 0 is never used)
65 | 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
66 |     1 2 3 4 # once you remove the padding, we are missing one time step!
67 | """
68 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
69 | return F.pad(x, (0, extra_padding))
70 |
71 |
72 | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
73 | """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
74 |     If the input is too small for the reflect padding, we insert extra 0 padding to the right before the reflection happens.
75 | """
76 | length = x.shape[-1]
77 | padding_left, padding_right = paddings
78 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
79 | if mode == 'reflect':
80 | max_pad = max(padding_left, padding_right)
81 | extra_pad = 0
82 | if length <= max_pad:
83 | extra_pad = max_pad - length + 1
84 | x = F.pad(x, (0, extra_pad))
85 | padded = F.pad(x, paddings, mode, value)
86 | end = padded.shape[-1] - extra_pad
87 | return padded[..., :end]
88 | else:
89 | return F.pad(x, paddings, mode, value)
90 |
91 |
92 | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
93 | """Remove padding from x, handling properly zero padding. Only for 1d!
94 | """
95 | padding_left, padding_right = paddings
96 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
97 | assert (padding_left + padding_right) <= x.shape[-1]
98 | end = x.shape[-1] - padding_right
99 | return x[..., padding_left: end]
100 |
101 |
102 | class NormConv1d(nn.Module):
103 | """Wrapper around Conv1d and normalization applied to this conv
104 | to provide a uniform interface across normalization approaches.
105 | """
106 | def __init__(self, *args, causal: bool = False, norm: str = 'none',
107 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
108 | super().__init__()
109 | self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
110 | self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
111 | self.norm_type = norm
112 |
113 | def forward(self, x):
114 | x = self.conv(x)
115 | x = self.norm(x)
116 | return x
117 |
118 |
119 | class NormConv2d(nn.Module):
120 | """Wrapper around Conv2d and normalization applied to this conv
121 | to provide a uniform interface across normalization approaches.
122 | """
123 | def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
124 | super().__init__()
125 | self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
126 | self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
127 | self.norm_type = norm
128 |
129 | def forward(self, x):
130 | x = self.conv(x)
131 | x = self.norm(x)
132 | return x
133 |
134 |
135 | class NormConvTranspose1d(nn.Module):
136 | """Wrapper around ConvTranspose1d and normalization applied to this conv
137 | to provide a uniform interface across normalization approaches.
138 | """
139 | def __init__(self, *args, causal: bool = False, norm: str = 'none',
140 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
141 | super().__init__()
142 | self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
143 | self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
144 | self.norm_type = norm
145 |
146 | def forward(self, x):
147 | x = self.convtr(x)
148 | x = self.norm(x)
149 | return x
150 |
151 |
152 | class NormConvTranspose2d(nn.Module):
153 | """Wrapper around ConvTranspose2d and normalization applied to this conv
154 | to provide a uniform interface across normalization approaches.
155 | """
156 | def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
157 | super().__init__()
158 | self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
159 | self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
160 |
161 | def forward(self, x):
162 | x = self.convtr(x)
163 | x = self.norm(x)
164 | return x
165 |
166 |
167 | class StreamableConv1d(nn.Module):
168 | """Conv1d with some builtin handling of asymmetric or causal padding
169 | and normalization.
170 | """
171 | def __init__(self, in_channels: int, out_channels: int,
172 | kernel_size: int, stride: int = 1, dilation: int = 1,
173 | groups: int = 1, bias: bool = True, causal: bool = False,
174 | norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
175 | pad_mode: str = 'reflect'):
176 | super().__init__()
177 | # warn user on unusual setup between dilation and stride
178 | if stride > 1 and dilation > 1:
179 | warnings.warn('StreamableConv1d has been initialized with stride > 1 and dilation > 1'
180 | f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
181 | self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
182 | dilation=dilation, groups=groups, bias=bias, causal=causal,
183 | norm=norm, norm_kwargs=norm_kwargs)
184 | self.causal = causal
185 | self.pad_mode = pad_mode
186 |
187 | def forward(self, x):
188 | B, C, T = x.shape
189 | kernel_size = self.conv.conv.kernel_size[0]
190 | stride = self.conv.conv.stride[0]
191 | dilation = self.conv.conv.dilation[0]
192 | kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
193 | padding_total = kernel_size - stride
194 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
195 | if self.causal:
196 | # Left padding for causal
197 | x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
198 | else:
199 | # Asymmetric padding required for odd strides
200 | padding_right = padding_total // 2
201 | padding_left = padding_total - padding_right
202 | x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
203 | return self.conv(x)
204 |
205 |
206 | class StreamableConvTranspose1d(nn.Module):
207 | """ConvTranspose1d with some builtin handling of asymmetric or causal padding
208 | and normalization.
209 | """
210 | def __init__(self, in_channels: int, out_channels: int,
211 | kernel_size: int, stride: int = 1, causal: bool = False,
212 | norm: str = 'none', trim_right_ratio: float = 1.,
213 | norm_kwargs: tp.Dict[str, tp.Any] = {}):
214 | super().__init__()
215 | self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
216 | causal=causal, norm=norm, norm_kwargs=norm_kwargs)
217 | self.causal = causal
218 | self.trim_right_ratio = trim_right_ratio
219 | assert self.causal or self.trim_right_ratio == 1., \
220 | "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
221 | assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
222 |
223 | def forward(self, x):
224 | kernel_size = self.convtr.convtr.kernel_size[0]
225 | stride = self.convtr.convtr.stride[0]
226 | padding_total = kernel_size - stride
227 |
228 | y = self.convtr(x)
229 |
230 | # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
231 | # removed at the very end, when keeping only the right length for the output,
232 | # as removing it here would require also passing the length at the matching layer
233 | # in the encoder.
234 | if self.causal:
235 | # Trim the padding on the right according to the specified ratio
236 | # if trim_right_ratio = 1.0, trim everything from right
237 | padding_right = math.ceil(padding_total * self.trim_right_ratio)
238 | padding_left = padding_total - padding_right
239 | y = unpad1d(y, (padding_left, padding_right))
240 | else:
241 | # Asymmetric padding required for odd strides
242 | padding_right = padding_total // 2
243 | padding_left = padding_total - padding_right
244 | y = unpad1d(y, (padding_left, padding_right))
245 | return y
246 |
--------------------------------------------------------------------------------
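
The padding logic above keeps the number of output frames at ceil(T / stride), whatever the kernel size; a sketch with StreamableConv1d in causal mode (constant padding so small inputs are not an issue):

import math
import torch
from audiocraft.modules.conv import StreamableConv1d

conv = StreamableConv1d(1, 8, kernel_size=7, stride=4, causal=True, pad_mode='constant')
for T in (100, 101, 103):
    y = conv(torch.randn(2, 1, T))
    # padding_total plus the extra padding make sure the last window is full,
    # so the frame count depends only on the stride.
    assert y.shape[-1] == math.ceil(T / 4), y.shape
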
/audiocraft/modules/lstm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from torch import nn
8 |
9 |
10 | class StreamableLSTM(nn.Module):
11 | """LSTM without worrying about the hidden state, nor the layout of the data.
12 | Expects input as convolutional layout.
13 | """
14 | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
15 | super().__init__()
16 | self.skip = skip
17 | self.lstm = nn.LSTM(dimension, dimension, num_layers)
18 |
19 | def forward(self, x):
20 | x = x.permute(2, 0, 1)
21 | y, _ = self.lstm(x)
22 | if self.skip:
23 | y = y + x
24 | y = y.permute(1, 2, 0)
25 | return y
26 |
--------------------------------------------------------------------------------
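
StreamableLSTM keeps the convolutional [B, C, T] layout, permuting internally; a sketch:

import torch
from audiocraft.modules.lstm import StreamableLSTM

lstm = StreamableLSTM(dimension=32, num_layers=2, skip=True)
x = torch.randn(4, 32, 50)   # [B, C, T] convolutional layout
y = lstm(x)                  # permuted to [T, B, C] for the LSTM, then back, with a skip connection
assert y.shape == x.shape
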
/audiocraft/modules/rope.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import typing as tp
8 |
9 | from torch import nn
10 | import torch
11 |
12 |
13 | class XPos(nn.Module):
14 | """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
15 | This applies an exponential decay to the RoPE rotation matrix.
16 |
17 | Args:
18 | dim (int): Embedding dimension.
19 | smoothing (float): Smoothing factor applied to the decay rates.
20 | base_scale (int): Base decay rate, given in terms of scaling time.
21 | device (torch.device or None): Device on which to initialize the module.
22 | dtype (torch.dtype): dtype to use to generate the embedding.
23 | """
24 | def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
25 | device=None, dtype: torch.dtype = torch.float32):
26 | super().__init__()
27 | assert dim % 2 == 0
28 | assert dtype in [torch.float64, torch.float32]
29 | self.dtype = dtype
30 | self.base_scale = base_scale
31 |
32 | half_dim = dim // 2
33 | adim = torch.arange(half_dim, device=device, dtype=dtype)
34 | decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
35 | self.register_buffer("decay_rates", decay_rates)
36 | self.decay: tp.Optional[torch.Tensor] = None
37 |
38 | def get_decay(self, start: int, end: int):
39 | """Create complex decay tensor, cache values for fast computation.
40 | """
41 | if self.decay is None or end > self.decay.shape[0]:
42 | assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker.
43 | idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
44 | power = idx / self.base_scale
45 | scale = self.decay_rates ** power.unsqueeze(-1)
46 | self.decay = torch.polar(scale, torch.zeros_like(scale))
47 | return self.decay[start:end] # [T, C/2]
48 |
49 |
50 | class RotaryEmbedding(nn.Module):
51 | """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
52 |
53 | Args:
54 | dim (int): Embedding dimension (twice the number of frequencies).
55 | max_period (float): Maximum period of the rotation frequencies.
56 | xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
57 | scale (float): Scale of positional embedding, set to 0 to deactivate.
58 | device (torch.device or None): Device on which to initialize the module.
59 | dtype (torch.dtype): dtype to use to generate the embedding.
60 | """
61 | def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
62 | scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
63 | super().__init__()
64 | assert dim % 2 == 0
65 | self.scale = scale
66 | assert dtype in [torch.float64, torch.float32]
67 | self.dtype = dtype
68 |
69 | adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
70 | frequencies = 1.0 / (max_period ** (adim / dim))
71 | self.register_buffer("frequencies", frequencies)
72 | self.rotation: tp.Optional[torch.Tensor] = None
73 |
74 | self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
75 |
76 | def get_rotation(self, start: int, end: int):
77 | """Create complex rotation tensor, cache values for fast computation.
78 | """
79 | if self.rotation is None or end > self.rotation.shape[0]:
80 | assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker.
81 | idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
82 | angles = torch.outer(idx, self.frequencies)
83 | self.rotation = torch.polar(torch.ones_like(angles), angles)
84 | return self.rotation[start:end]
85 |
86 | def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
87 | """Apply rope rotation to query or key tensor.
88 | """
89 | T = x.shape[1]
90 | rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
91 |
92 | if self.xpos:
93 | decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
94 | else:
95 | decay = 1.0
96 |
97 | if invert_decay:
98 | decay = decay ** -1
99 |
100 | x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
101 | scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
102 | x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
103 |
104 | return x_out.type_as(x)
105 |
106 | def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
107 | """ Apply rope rotation to both query and key tensors.
108 | Supports streaming mode, in which query and key are not expected to have the same shape.
109 |     In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
110 | query will be [C] (typically C == 1).
111 |
112 | Args:
113 | query (torch.Tensor): Query to rotate.
114 | key (torch.Tensor): Key to rotate.
115 | start (int): Start index of the sequence for time offset.
116 | """
117 | query_timesteps = query.shape[1]
118 | key_timesteps = key.shape[1]
119 | streaming_offset = key_timesteps - query_timesteps
120 |
121 | query_out = self.rotate(query, start + streaming_offset)
122 | key_out = self.rotate(key, start, invert_decay=True)
123 |
124 | return query_out, key_out
125 |
--------------------------------------------------------------------------------
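
rotate_qk above expects time on dim 1 and the (even) embedding dimension last, i.e. roughly [B, T, heads, dim per head]; a sketch (the exact layout is an assumption inferred from the unsqueeze pattern in rotate):

import torch
from audiocraft.modules.rope import RotaryEmbedding

rope = RotaryEmbedding(dim=16)        # per-head dimension, must be even
q = torch.randn(2, 8, 4, 16)          # [B, T, num_heads, dim]
k = torch.randn(2, 12, 4, 16)         # streaming case: the key holds 4 extra cached past steps
q_rot, k_rot = rope.rotate_qk(q, k)   # the query is offset by the 12 - 8 = 4 cached steps
assert q_rot.shape == q.shape and k_rot.shape == k.shape
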
/audiocraft/modules/streaming.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Streaming module API that should be implemented by all Streaming components.
9 | """
10 |
11 | from contextlib import contextmanager
12 | import typing as tp
13 | from torch import nn
14 | import torch
15 |
16 |
17 | State = tp.Dict[str, torch.Tensor]
18 |
19 |
20 | class StreamingModule(nn.Module):
21 | """Common API for streaming components.
22 |
23 | Each streaming component has a streaming state, which is just a dict[str, Tensor].
24 | By convention, the first dim of each tensor must be the batch size.
25 | Don't use dots in the key names, as this would clash with submodules
26 | (like in state_dict).
27 |
28 | If `self._is_streaming` is True, the component should use and remember
29 | the proper state inside `self._streaming_state`.
30 |
31 | To set a streaming component in streaming state, use
32 |
33 | with module.streaming():
34 | ...
35 |
36 | This will automatically reset the streaming state when exiting the context manager.
37 |     This also automatically propagates to all streaming child modules.
38 |
39 |     Some modules might also implement the `StreamingModule.flush` method, although
40 |     this one is trickier, as all parent modules must be StreamingModule and implement
41 |     it as well for it to work properly. See `StreamingSequential` below.
42 | """
43 | def __init__(self) -> None:
44 | super().__init__()
45 | self._streaming_state: State = {}
46 | self._is_streaming = False
47 |
48 | def _apply_named_streaming(self, fn: tp.Any):
49 | for name, module in self.named_modules():
50 | if isinstance(module, StreamingModule):
51 | fn(name, module)
52 |
53 | def _set_streaming(self, streaming: bool):
54 | def _set_streaming(name, module):
55 | module._is_streaming = streaming
56 | self._apply_named_streaming(_set_streaming)
57 |
58 | @contextmanager
59 | def streaming(self):
60 | """Context manager to enter streaming mode. Reset streaming state on exit.
61 | """
62 | self._set_streaming(True)
63 | try:
64 | yield
65 | finally:
66 | self._set_streaming(False)
67 | self.reset_streaming()
68 |
69 | def reset_streaming(self):
70 | """Reset the streaming state.
71 | """
72 | def _reset(name: str, module: StreamingModule):
73 | module._streaming_state.clear()
74 |
75 | self._apply_named_streaming(_reset)
76 |
77 | def get_streaming_state(self) -> State:
78 | """Return the streaming state, including that of sub-modules.
79 | """
80 | state: State = {}
81 |
82 | def _add(name: str, module: StreamingModule):
83 | if name:
84 | name += "."
85 | for key, value in module._streaming_state.items():
86 | state[name + key] = value
87 |
88 | self._apply_named_streaming(_add)
89 | return state
90 |
91 | def set_streaming_state(self, state: State):
92 | """Set the streaming state, including that of sub-modules.
93 | """
94 | state = dict(state)
95 |
96 | def _set(name: str, module: StreamingModule):
97 | if name:
98 | name += "."
99 | module._streaming_state.clear()
100 | for key, value in list(state.items()):
101 | # complexity is not ideal here, but probably fine.
102 | if key.startswith(name):
103 | local_key = key[len(name):]
104 | if '.' not in local_key:
105 | module._streaming_state[local_key] = value
106 | del state[key]
107 |
108 | self._apply_named_streaming(_set)
109 | assert len(state) == 0, list(state.keys())
110 |
111 | def flush(self, x: tp.Optional[torch.Tensor] = None):
112 | """Flush any remaining outputs that were waiting for completion.
113 | Typically, for convolutions, this will add the final padding
114 | and process the last buffer.
115 |
116 | This should take an optional argument `x`, which will be provided
117 | if a module before this one in the streaming pipeline has already
118 |         spit out a flushed buffer.
119 | """
120 | if x is None:
121 | return None
122 | else:
123 | return self(x)
124 |
125 |
126 | class StreamingSequential(StreamingModule, nn.Sequential):
127 | """A streaming compatible alternative of `nn.Sequential`.
128 | """
129 | def flush(self, x: tp.Optional[torch.Tensor] = None):
130 | for module in self:
131 | if isinstance(module, StreamingModule):
132 | x = module.flush(x)
133 | elif x is not None:
134 | x = module(x)
135 | return x
136 |
--------------------------------------------------------------------------------
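
A minimal sketch of a custom StreamingModule that keeps a running sum in its streaming state (RunningSum is a hypothetical module, only meant to show the state dict and the streaming() context above):

import torch
from audiocraft.modules.streaming import StreamingModule

class RunningSum(StreamingModule):
    """Hypothetical example: accumulates a sum over time while streaming."""
    def forward(self, x: torch.Tensor):
        total = x.sum(dim=-1, keepdim=True)
        if self._is_streaming:
            prev = self._streaming_state.get('sum', torch.zeros_like(total))
            total = prev + total
            self._streaming_state['sum'] = total  # first dim stays the batch size, no dots in keys
        return total

m = RunningSum()
with m.streaming():                      # state is reset automatically on exit
    for chunk in torch.ones(2, 3, 10).split(5, dim=-1):
        out = m(chunk)
print(out.squeeze(-1))                   # tensor of tens: 2 chunks of 5 ones each
assert not m.get_streaming_state()       # cleared after leaving the context
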
/audiocraft/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rkfg/audiocraft/6d70065e31c2fb422a76237e03740dd3b627de8d/audiocraft/py.typed
--------------------------------------------------------------------------------
/audiocraft/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa
8 | from .vq import ResidualVectorQuantizer
9 | from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
10 |
--------------------------------------------------------------------------------
/audiocraft/quantization/base.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Base class for all quantizers.
9 | """
10 |
11 | from dataclasses import dataclass, field
12 | import typing as tp
13 |
14 | import torch
15 | from torch import nn
16 |
17 |
18 | @dataclass
19 | class QuantizedResult:
20 | x: torch.Tensor
21 | codes: torch.Tensor
22 | bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
23 | penalty: tp.Optional[torch.Tensor] = None
24 | metrics: dict = field(default_factory=dict)
25 |
26 |
27 | class BaseQuantizer(nn.Module):
28 | """Base class for quantizers.
29 | """
30 |
31 | def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
32 | """
33 | Given input tensor x, returns first the quantized (or approximately quantized)
34 | representation along with quantized codes, bandwidth, and any penalty term for the loss.
35 | Finally, this returns a dict of metrics to update logging etc.
36 | Frame rate must be passed so that the bandwidth is properly computed.
37 | """
38 | raise NotImplementedError()
39 |
40 | def encode(self, x: torch.Tensor) -> torch.Tensor:
41 | """Encode a given input tensor with the specified sample rate at the given bandwidth.
42 | """
43 | raise NotImplementedError()
44 |
45 | def decode(self, codes: torch.Tensor) -> torch.Tensor:
46 | """Decode the given codes to the quantized representation.
47 | """
48 | raise NotImplementedError()
49 |
50 | @property
51 | def total_codebooks(self):
52 | """Total number of codebooks.
53 | """
54 | raise NotImplementedError()
55 |
56 | @property
57 | def num_codebooks(self):
58 | """Number of active codebooks.
59 | """
60 | raise NotImplementedError()
61 |
62 | def set_num_codebooks(self, n: int):
63 | """Set the number of active codebooks.
64 | """
65 | raise NotImplementedError()
66 |
67 |
68 | class DummyQuantizer(BaseQuantizer):
69 | """Fake quantizer that actually does not perform any quantization.
70 | """
71 | def __init__(self):
72 | super().__init__()
73 |
74 | def forward(self, x: torch.Tensor, frame_rate: int):
75 | q = x.unsqueeze(1)
76 | return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
77 |
78 | def encode(self, x: torch.Tensor) -> torch.Tensor:
79 | """Encode a given input tensor with the specified sample rate at the given bandwidth.
80 |         In the case of the DummyQuantizer, the codes are identical
81 |         to the input (and to the resulting quantized representation), as no quantization is done.
82 | """
83 | return x.unsqueeze(1)
84 |
85 | def decode(self, codes: torch.Tensor) -> torch.Tensor:
86 | """Decode the given codes to the quantized representation.
87 |         In the case of the DummyQuantizer, the codes are identical
88 |         to the input (and to the resulting quantized representation), as no quantization is done.
89 | """
90 | return codes.squeeze(1)
91 |
92 | @property
93 | def total_codebooks(self):
94 | """Total number of codebooks.
95 | """
96 | return 1
97 |
98 | @property
99 | def num_codebooks(self):
100 | """Total number of codebooks.
101 | """
102 | return self.total_codebooks
103 |
104 | def set_num_codebooks(self, n: int):
105 | """Set the number of active codebooks.
106 | """
107 | raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
108 |
--------------------------------------------------------------------------------
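
DummyQuantizer above only reshapes; a sketch showing that encode followed by decode is the identity, and what the reported bandwidth corresponds to:

import torch
from audiocraft.quantization.base import DummyQuantizer

q = DummyQuantizer()
x = torch.randn(2, 16, 50)   # [B, dim, T]
res = q(x, frame_rate=50)
assert torch.equal(q.decode(q.encode(x)), x)
# bandwidth as computed above: codes.numel() * 32 bits * frame_rate / 1000 / batch size
print(res.bandwidth)
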
/audiocraft/quantization/vq.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | import typing as tp
9 |
10 | import torch
11 |
12 | from .base import BaseQuantizer, QuantizedResult
13 | from .core_vq import ResidualVectorQuantization
14 |
15 |
16 | class ResidualVectorQuantizer(BaseQuantizer):
17 | """Residual Vector Quantizer.
18 |
19 | Args:
20 | dimension (int): Dimension of the codebooks.
21 | n_q (int): Number of residual vector quantizers used.
22 | q_dropout (bool): Random quantizer drop out at train time.
23 | bins (int): Codebook size.
24 | decay (float): Decay for exponential moving average over the codebooks.
25 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
26 | kmeans_iters (int): Number of iterations used for kmeans initialization.
27 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
28 | that have an exponential moving average cluster size less than the specified threshold with
29 |             a randomly selected vector from the current batch.
30 |         orthogonal_reg_weight (float): Orthogonal regularization weight.
31 |         orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
32 |         orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
33 |             for orthogonal regularization.
34 | """
35 | def __init__(
36 | self,
37 | dimension: int = 256,
38 | n_q: int = 8,
39 | q_dropout: bool = False,
40 | bins: int = 1024,
41 | decay: float = 0.99,
42 | kmeans_init: bool = True,
43 | kmeans_iters: int = 10,
44 | threshold_ema_dead_code: int = 2,
45 | orthogonal_reg_weight: float = 0.0,
46 | orthogonal_reg_active_codes_only: bool = False,
47 | orthogonal_reg_max_codes: tp.Optional[int] = None,
48 | ):
49 | super().__init__()
50 | self.max_n_q = n_q
51 | self.n_q = n_q
52 | self.q_dropout = q_dropout
53 | self.dimension = dimension
54 | self.bins = bins
55 | self.decay = decay
56 | self.kmeans_init = kmeans_init
57 | self.kmeans_iters = kmeans_iters
58 | self.threshold_ema_dead_code = threshold_ema_dead_code
59 | self.orthogonal_reg_weight = orthogonal_reg_weight
60 | self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
61 | self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
62 | self.vq = ResidualVectorQuantization(
63 | dim=self.dimension,
64 | codebook_size=self.bins,
65 | num_quantizers=self.n_q,
66 | decay=self.decay,
67 | kmeans_init=self.kmeans_init,
68 | kmeans_iters=self.kmeans_iters,
69 | threshold_ema_dead_code=self.threshold_ema_dead_code,
70 | orthogonal_reg_weight=self.orthogonal_reg_weight,
71 | orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
72 | orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
73 | channels_last=False
74 | )
75 |
76 | def forward(self, x: torch.Tensor, frame_rate: int):
77 | n_q = self.n_q
78 | if self.training and self.q_dropout:
79 | n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
80 | bw_per_q = math.log2(self.bins) * frame_rate / 1000
81 | quantized, codes, commit_loss = self.vq(x, n_q=n_q)
82 | codes = codes.transpose(0, 1)
83 | # codes is [B, K, T], with T frames, K nb of codebooks.
84 | bw = torch.tensor(n_q * bw_per_q).to(x)
85 | return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
86 |
87 | def encode(self, x: torch.Tensor) -> torch.Tensor:
88 | """Encode a given input tensor with the specified frame rate at the given bandwidth.
89 | The RVQ encode method sets the appropriate number of quantizer to use
90 | and returns indices for each quantizer.
91 | """
92 | n_q = self.n_q
93 | codes = self.vq.encode(x, n_q=n_q)
94 | codes = codes.transpose(0, 1)
95 | # codes is [B, K, T], with T frames, K nb of codebooks.
96 | return codes
97 |
98 | def decode(self, codes: torch.Tensor) -> torch.Tensor:
99 | """Decode the given codes to the quantized representation.
100 | """
101 | # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
102 | codes = codes.transpose(0, 1)
103 | quantized = self.vq.decode(codes)
104 | return quantized
105 |
106 | @property
107 | def total_codebooks(self):
108 | return self.max_n_q
109 |
110 | @property
111 | def num_codebooks(self):
112 | return self.n_q
113 |
114 | def set_num_codebooks(self, n: int):
115 | assert n > 0 and n <= self.max_n_q
116 | self.n_q = n
117 |
--------------------------------------------------------------------------------
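A rough sketch of how the residual vector quantizer above is driven; the dimension, codebook count, codebook size and frame rate are picked for illustration rather than taken from a real configuration.

    import torch

    from audiocraft.quantization.vq import ResidualVectorQuantizer

    rvq = ResidualVectorQuantizer(dimension=32, n_q=4, bins=64)
    x = torch.randn(3, 32, 100)            # [B, D, T] continuous latents
    out = rvq(x, frame_rate=50)            # QuantizedResult: quantized x, codes [B, K, T], bandwidth, penalty
    rvq.set_num_codebooks(2)               # restrict to the first 2 quantizers
    codes = rvq.encode(x)                  # [B, 2, T] integer indices
    decoded = rvq.decode(codes)            # approximate reconstruction, [B, D, T]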
/audiocraft/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/audiocraft/utils/autocast.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 |
10 | class TorchAutocast:
11 | """TorchAutocast utility class.
12 |     Allows you to enable and disable autocast. This is especially useful
13 | when dealing with different architectures and clusters with different
14 | levels of support.
15 |
16 | Args:
17 | enabled (bool): Whether to enable torch.autocast or not.
18 | args: Additional args for torch.autocast.
19 | kwargs: Additional kwargs for torch.autocast
20 | """
21 | def __init__(self, enabled: bool, *args, **kwargs):
22 | self.autocast = torch.autocast(*args, **kwargs) if enabled else None
23 |
24 | def __enter__(self):
25 | if self.autocast is None:
26 | return
27 | try:
28 | self.autocast.__enter__()
29 | except RuntimeError:
30 | device = self.autocast.device
31 | dtype = self.autocast.fast_dtype
32 | raise RuntimeError(
33 | f"There was an error autocasting with dtype={dtype} device={device}\n"
34 | "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
35 | )
36 |
37 | def __exit__(self, *args, **kwargs):
38 | if self.autocast is None:
39 | return
40 | self.autocast.__exit__(*args, **kwargs)
41 |
--------------------------------------------------------------------------------
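A short usage sketch for the wrapper above; the device type, dtype, and the model call are placeholders.

    import torch

    from audiocraft.utils.autocast import TorchAutocast

    autocast = TorchAutocast(enabled=torch.cuda.is_available(),
                             device_type='cuda', dtype=torch.float16)

    def run(model, batch):
        # No-op when enabled=False, so the same code path also works on CPU-only machines.
        with autocast:
            return model(batch)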
/audiocraft/utils/export.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Utility to export a training checkpoint to a lightweight release checkpoint.
9 | """
10 |
11 | from pathlib import Path
12 | import typing as tp
13 |
14 | from omegaconf import OmegaConf, DictConfig
15 | import torch
16 |
17 |
18 | def _clean_lm_cfg(cfg: DictConfig):
19 | OmegaConf.set_struct(cfg, False)
20 | # This used to be set automatically in the LM solver, need a more robust solution
21 | # for the future.
22 | cfg['transformer_lm']['card'] = 2048
23 | cfg['transformer_lm']['n_q'] = 4
24 | # Experimental params no longer supported.
25 | bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
26 | 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
27 | for name in bad_params:
28 | del cfg['transformer_lm'][name]
29 | OmegaConf.set_struct(cfg, True)
30 | return cfg
31 |
32 |
33 | def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
34 | sig = Path(checkpoint_path).parent.name
35 | assert len(sig) == 8, "Not a valid Dora signature"
36 | pkg = torch.load(checkpoint_path, 'cpu')
37 | new_pkg = {
38 | 'best_state': pkg['ema']['state']['model'],
39 | 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
40 | }
41 | out_file = Path(out_folder) / f'{sig}.th'
42 | torch.save(new_pkg, out_file)
43 | return out_file
44 |
45 |
46 | def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
47 | sig = Path(checkpoint_path).parent.name
48 | assert len(sig) == 8, "Not a valid Dora signature"
49 | pkg = torch.load(checkpoint_path, 'cpu')
50 | new_pkg = {
51 | 'best_state': pkg['fsdp_best_state']['model'],
52 | 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
53 | }
54 | out_file = Path(out_folder) / f'{sig}.th'
55 | torch.save(new_pkg, out_file)
56 | return out_file
57 |
--------------------------------------------------------------------------------
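Exporting checkpoints with the helpers above might look as follows; the paths and 8-character signature folders are hypothetical, the only hard requirements being that the checkpoint's parent directory name is a valid Dora signature and that the output folder already exists.

    from audiocraft.utils.export import export_encodec, export_lm

    # Hypothetical training checkpoints; the parent folder name is used as the Dora signature.
    compression_ckpt = '/checkpoints/0abc1234/checkpoint.th'
    lm_ckpt = '/checkpoints/9def5678/checkpoint.th'

    export_encodec(compression_ckpt, 'releases/')  # writes releases/0abc1234.th
    export_lm(lm_ckpt, 'releases/')                # writes releases/9def5678.th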
/audiocraft/utils/notebook.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | try:
8 | import IPython.display as ipd # type: ignore
9 | except ImportError:
10 |     # Not in a notebook...
11 | pass
12 |
13 |
14 | import torch
15 |
16 |
17 | def display_audio(samples: torch.Tensor, sample_rate: int):
18 | """Renders an audio player for the given audio samples.
19 |
20 | Args:
21 | samples (torch.Tensor): a Tensor of decoded audio samples
22 | with shapes [B, C, T] or [C, T]
23 | sample_rate (int): sample rate audio should be displayed with.
24 | """
25 | assert samples.dim() == 2 or samples.dim() == 3
26 |
27 | samples = samples.detach().cpu()
28 | if samples.dim() == 2:
29 | samples = samples[None, ...]
30 |
31 | for audio in samples:
32 | ipd.display(ipd.Audio(audio, rate=sample_rate))
33 |
--------------------------------------------------------------------------------
/audiocraft/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from concurrent.futures import ProcessPoolExecutor
8 | from functools import wraps
9 | import hashlib
10 | import logging
11 | import typing as tp
12 |
13 | import flashy
14 | import flashy.distrib
15 | import omegaconf
16 | import torch
17 | from torch.nn.utils.rnn import pad_sequence
18 |
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
24 | """Convenience function to map an omegaconf configuration to a dictionary.
25 |
26 | Args:
27 | cfg (omegaconf.DictConfig): Original configuration to map to dict.
28 | Returns:
29 | dict: Config as dictionary object.
30 | """
31 | dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
32 | assert isinstance(dct, dict)
33 | return dct
34 |
35 |
36 | def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
37 | if max_samples >= len(dataset):
38 | return dataset
39 |
40 | generator = torch.Generator().manual_seed(seed)
41 | perm = torch.randperm(len(dataset), generator=generator)
42 | return torch.utils.data.Subset(dataset, perm[:max_samples].tolist())
43 |
44 |
45 | def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
46 | num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
47 | """Convenience function to load dataset into a dataloader with optional subset sampling.
48 |
49 | Args:
50 | dataset: Dataset to load.
51 | num_samples (Optional[int]): Number of samples to limit subset size.
52 | batch_size (int): Batch size.
53 | num_workers (int): Number of workers for data loading.
54 | seed (int): Random seed.
55 | """
56 | if num_samples is not None:
57 | dataset = random_subset(dataset, num_samples, seed)
58 |
59 | dataloader = flashy.distrib.loader(
60 | dataset,
61 | batch_size=batch_size,
62 | num_workers=num_workers,
63 | **kwargs
64 | )
65 | return dataloader
66 |
67 |
68 | def get_dataset_from_loader(dataloader):
69 | dataset = dataloader.dataset
70 | if isinstance(dataset, torch.utils.data.Subset):
71 | return dataset.dataset
72 | else:
73 | return dataset
74 |
75 |
76 | def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
77 | """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
78 |
79 | Args:
80 | input (torch.Tensor): The input tensor containing probabilities.
81 | num_samples (int): Number of samples to draw.
82 | replacement (bool): Whether to draw with replacement or not.
83 | Keywords args:
84 | generator (torch.Generator): A pseudorandom number generator for sampling.
85 | Returns:
86 | torch.Tensor: Last dimension contains num_samples indices
87 | sampled from the multinomial probability distribution
88 | located in the last dimension of tensor input.
89 | """
90 | input_ = input.reshape(-1, input.shape[-1])
91 | output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
92 | output = output_.reshape(*list(input.shape[:-1]), -1)
93 | return output
94 |
95 |
96 | def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
97 | """Sample next token from top K values along the last dimension of the input probs tensor.
98 |
99 | Args:
100 | probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
101 | k (int): The k in “top-k”.
102 | Returns:
103 | torch.Tensor: Sampled tokens.
104 | """
105 | top_k_value, _ = torch.topk(probs, k, dim=-1)
106 | min_value_top_k = top_k_value[..., [-1]]
107 | probs *= (probs >= min_value_top_k).float()
108 | probs.div_(probs.sum(dim=-1, keepdim=True))
109 | next_token = multinomial(probs, num_samples=1)
110 | return next_token
111 |
112 |
113 | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
114 | """Sample next token from top P probabilities along the last dimension of the input probs tensor.
115 |
116 | Args:
117 | probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
118 |         p (float): The p in “top-p”.
119 | Returns:
120 | torch.Tensor: Sampled tokens.
121 | """
122 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
123 | probs_sum = torch.cumsum(probs_sort, dim=-1)
124 | mask = probs_sum - probs_sort > p
125 | probs_sort *= (~mask).float()
126 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
127 | next_token = multinomial(probs_sort, num_samples=1)
128 | next_token = torch.gather(probs_idx, -1, next_token)
129 | return next_token
130 |
131 |
132 | class DummyPoolExecutor:
133 | """Dummy pool executor to use when we actually have only 1 worker.
134 | (e.g. instead of ProcessPoolExecutor).
135 | """
136 | class DummyResult:
137 | def __init__(self, func, *args, **kwargs):
138 | self.func = func
139 | self.args = args
140 | self.kwargs = kwargs
141 |
142 | def result(self):
143 | return self.func(*self.args, **self.kwargs)
144 |
145 | def __init__(self, workers, mp_context=None):
146 | pass
147 |
148 | def submit(self, func, *args, **kwargs):
149 | return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
150 |
151 | def __enter__(self):
152 | return self
153 |
154 | def __exit__(self, exc_type, exc_value, exc_tb):
155 | return
156 |
157 |
158 | def get_pool_executor(num_workers: int, mp_context=None):
159 | return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)
160 |
161 |
162 | def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
163 | """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
164 | For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
165 |
166 | Args:
167 | lengths (torch.Tensor): tensor with lengths
168 | max_len (int): can set the max length manually. Defaults to None.
169 | Returns:
170 |         torch.Tensor: mask with 0s where there are pad tokens, else 1s.
171 | """
172 | assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
173 | final_length = lengths.max().item() if not max_len else max_len
174 | final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor
175 | return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
176 |
177 |
178 | def hash_trick(word: str, vocab_size: int) -> int:
179 | """Hash trick to pair each word with an index
180 |
181 | Args:
182 | word (str): word we wish to convert to an index
183 | vocab_size (int): size of the vocabulary
184 | Returns:
185 | int: index of the word in the embedding LUT
186 | """
187 | hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
188 | return hash % vocab_size
189 |
190 |
191 | def with_rank_rng(base_seed: int = 1234):
192 | """Decorator for a function so that the function will use a Random Number Generator
193 |     whose state depends on the GPU rank. The original RNG state is restored upon returning.
194 |
195 | Args:
196 | base_seed (int): Random seed.
197 | """
198 | def _decorator(fun: tp.Callable):
199 | @wraps(fun)
200 | def _decorated(*args, **kwargs):
201 | state = torch.get_rng_state()
202 | seed = base_seed ^ flashy.distrib.rank()
203 | torch.manual_seed(seed)
204 | logger.debug('Rank dependent seed set to %d', seed)
205 | try:
206 | return fun(*args, **kwargs)
207 | finally:
208 | torch.set_rng_state(state)
209 | logger.debug('RNG state restored.')
210 | return _decorated
211 | return _decorator
212 |
213 |
214 | def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
215 | """Get a list of tensors and collate them to a single tensor. according to the following logic:
216 | - `dim` specifies the time dimension which will be stacked and padded.
217 | - The output will contain 1 new dimension (dimension index 0) which will be the size of
218 | of the original list.
219 |
220 | Args:
221 | tensors (tp.List[torch.Tensor]): List of tensors to collate.
222 | dim (int): Dimension which will be stacked and padded.
223 | Returns:
224 | tp.Tuple[torch.Tensor, torch.Tensor]:
225 | torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
226 | (dimension index 0) which will be the size of the original list.
227 | torch.Tensor: Tensor containing length of original tensor sizes (without padding).
228 | """
229 | tensors = [x.transpose(0, dim) for x in tensors]
230 | lens = torch.LongTensor([len(x) for x in tensors])
231 | padded_tensors = pad_sequence(tensors)
232 | padded_tensors = padded_tensors.transpose(0, 1)
233 | padded_tensors = padded_tensors.transpose(1, dim + 1)
234 | return padded_tensors, lens
235 |
--------------------------------------------------------------------------------
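A few of the helpers above in action; the tensors and the word are toy values.

    import torch

    from audiocraft.utils.utils import collate, hash_trick, length_to_mask, sample_top_k, sample_top_p

    # Boolean mask from lengths: [3, 5] -> [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]].
    mask = length_to_mask(torch.tensor([3, 5]))

    # Pad and stack variable-length tensors; returns the padded batch and the original lengths.
    padded, lens = collate([torch.ones(3, 4), torch.ones(5, 4)])   # padded: [2, 5, 4], lens: [3, 5]

    # Sampling from a toy distribution over 10 token candidates.
    probs = torch.softmax(torch.randn(2, 10), dim=-1)
    tok_k = sample_top_k(probs.clone(), k=3)     # clone() because sample_top_k modifies probs in place
    tok_p = sample_top_p(probs.clone(), p=0.9)

    # Hash a word into a fixed-size embedding table.
    idx = hash_trick("guitar", vocab_size=128)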
/demo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# MusicGen\n",
8 | "Welcome to MusicGen's demo jupyter notebook. Here you will find a series of self-contained examples of how to use MusicGen in different settings.\n",
9 | "\n",
10 | "First, we start by initializing MusicGen, you can choose a model from the following selection:\n",
11 | "1. `small` - 300M transformer decoder.\n",
12 | "2. `medium` - 1.5B transformer decoder.\n",
13 | "3. `melody` - 1.5B transformer decoder also supporting melody conditioning.\n",
14 | "4. `large` - 3.3B transformer decoder.\n",
15 | "\n",
16 | "We will use the `small` variant for the purpose of this demonstration."
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "from audiocraft.models import MusicGen\n",
26 | "\n",
27 | "# Using small model, better results would be obtained with `medium` or `large`.\n",
28 | "model = MusicGen.get_pretrained('small')"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "Next, let us configure the generation parameters. Specifically, you can control the following:\n",
36 | "* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n",
37 | "* `top_k` (int, optional): top_k used for sampling. Defaults to 250.\n",
38 | "* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.\n",
39 | "* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n",
40 | "* `duration` (float, optional): duration of the generated waveform. Defaults to 30.0.\n",
41 | "* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.\n",
42 | "\n",
43 | "When left unchanged, MusicGen will revert to its default parameters."
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": null,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "model.set_generation_params(\n",
53 | " use_sampling=True,\n",
54 | " top_k=250,\n",
55 | " duration=5\n",
56 | ")"
57 | ]
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "metadata": {},
62 | "source": [
63 | "Next, we can go ahead and start generating music using one of the following modes:\n",
64 | "* Unconditional samples using `model.generate_unconditional`\n",
65 | "* Music continuation using `model.generate_continuation`\n",
66 | "* Text-conditional samples using `model.generate`\n",
67 | "* Melody-conditional samples using `model.generate_with_chroma`"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "metadata": {},
73 | "source": [
74 | "### Unconditional Generation"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": null,
80 | "metadata": {},
81 | "outputs": [],
82 | "source": [
83 | "from audiocraft.utils.notebook import display_audio\n",
84 | "\n",
85 | "output = model.generate_unconditional(num_samples=2, progress=True)\n",
86 | "display_audio(output, sample_rate=32000)"
87 | ]
88 | },
89 | {
90 | "cell_type": "markdown",
91 | "metadata": {},
92 | "source": [
93 | "### Music Continuation"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "import math\n",
103 | "import torchaudio\n",
104 | "import torch\n",
105 | "from audiocraft.utils.notebook import display_audio\n",
106 | "\n",
107 | "def get_bip_bip(bip_duration=0.125, frequency=440,\n",
108 | " duration=0.5, sample_rate=32000, device=\"cuda\"):\n",
109 | " \"\"\"Generates a series of bip bip at the given frequency.\"\"\"\n",
110 | " t = torch.arange(\n",
111 | " int(duration * sample_rate), device=\"cuda\", dtype=torch.float) / sample_rate\n",
112 | " wav = torch.cos(2 * math.pi * 440 * t)[None]\n",
113 | " tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n",
114 | " envelope = (tp >= 0.5).float()\n",
115 | " return wav * envelope\n"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "# Here we use a synthetic signal to prompt both the tonality and the BPM\n",
125 | "# of the generated audio.\n",
126 | "res = model.generate_continuation(\n",
127 | " get_bip_bip(0.125).expand(2, -1, -1), \n",
128 | " 32000, ['Jazz jazz and only jazz', \n",
129 | " 'Heartful EDM with beautiful synths and chords'], \n",
130 | " progress=True)\n",
131 | "display_audio(res, 32000)"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": null,
137 | "metadata": {},
138 | "outputs": [],
139 | "source": [
140 | "# You can also use any audio from a file. Make sure to trim the file if it is too long!\n",
141 | "prompt_waveform, prompt_sr = torchaudio.load(\"./assets/bach.mp3\")\n",
142 | "prompt_duration = 2\n",
143 | "prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]\n",
144 | "output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True)\n",
145 | "display_audio(output, sample_rate=32000)"
146 | ]
147 | },
148 | {
149 | "cell_type": "markdown",
150 | "metadata": {},
151 | "source": [
152 | "### Text-conditional Generation"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "from audiocraft.utils.notebook import display_audio\n",
162 | "\n",
163 | "output = model.generate(\n",
164 | " descriptions=[\n",
165 | " '80s pop track with bassy drums and synth',\n",
166 | " '90s rock song with loud guitars and heavy drums',\n",
167 | " ],\n",
168 | " progress=True\n",
169 | ")\n",
170 | "display_audio(output, sample_rate=32000)"
171 | ]
172 | },
173 | {
174 | "cell_type": "markdown",
175 | "metadata": {},
176 | "source": [
177 | "### Melody-conditional Generation"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": null,
183 | "metadata": {},
184 | "outputs": [],
185 | "source": [
186 | "import torchaudio\n",
187 | "from audiocraft.utils.notebook import display_audio\n",
188 | "\n",
189 | "model = MusicGen.get_pretrained('melody')\n",
190 | "model.set_generation_params(duration=8)\n",
191 | "\n",
192 | "melody_waveform, sr = torchaudio.load(\"assets/bach.mp3\")\n",
193 | "melody_waveform = melody_waveform.unsqueeze(0).repeat(2, 1, 1)\n",
194 | "output = model.generate_with_chroma(\n",
195 | " descriptions=[\n",
196 | " '80s pop track with bassy drums and synth',\n",
197 | " '90s rock song with loud guitars and heavy drums',\n",
198 | " ],\n",
199 | " melody_wavs=melody_waveform,\n",
200 | " melody_sample_rate=sr,\n",
201 | " progress=True\n",
202 | ")\n",
203 | "display_audio(output, sample_rate=32000)"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "metadata": {},
210 | "outputs": [],
211 | "source": []
212 | }
213 | ],
214 | "metadata": {
215 | "kernelspec": {
216 | "display_name": "Python 3 (ipykernel)",
217 | "language": "python",
218 | "name": "python3"
219 | },
220 | "language_info": {
221 | "codemirror_mode": {
222 | "name": "ipython",
223 | "version": 3
224 | },
225 | "file_extension": ".py",
226 | "mimetype": "text/x-python",
227 | "name": "python",
228 | "nbconvert_exporter": "python",
229 | "pygments_lexer": "ipython3",
230 | "version": "3.9.7"
231 | }
232 | },
233 | "nbformat": 4,
234 | "nbformat_minor": 2
235 | }
236 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 |
3 | [mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub]
4 | ignore_missing_imports = True
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # please make sure you already have a PyTorch install that is CUDA enabled!
2 | av
3 | einops
4 | flashy>=0.0.1
5 | hydra-core>=1.1
6 | hydra_colorlog
7 | julius
8 | num2words
9 | numpy
10 | sentencepiece
11 | spacy==3.5.2
12 | torch>=2.0.0
13 | torchaudio>=2.0.0
14 | huggingface_hub
15 | tqdm
16 | transformers
17 | xformers
18 | demucs
19 | librosa
20 | gradio
21 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [pep8]
2 | max-line-length = 120
3 |
4 | [flake8]
5 | max-line-length = 120
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Meta Platforms, Inc. and affiliates.
3 | All rights reserved.
4 |
5 | This source code is licensed under the license found in the
6 | LICENSE file in the root directory of this source tree.
7 |
8 | """
9 |
10 | from pathlib import Path
11 |
12 | from setuptools import setup, find_packages
13 |
14 |
15 | NAME = 'audiocraft'
16 | DESCRIPTION = 'Audio research library for PyTorch'
17 |
18 | URL = 'https://github.com/fairinternal/audiocraft'
19 | AUTHOR = 'FAIR Speech & Audio'
20 | EMAIL = 'defossez@meta.com'
21 | REQUIRES_PYTHON = '>=3.8.0'
22 |
23 | for line in open('audiocraft/__init__.py'):
24 | line = line.strip()
25 | if '__version__' in line:
26 | context = {}
27 | exec(line, context)
28 | VERSION = context['__version__']
29 |
30 | HERE = Path(__file__).parent
31 |
32 | try:
33 | with open(HERE / "README.md", encoding='utf-8') as f:
34 | long_description = '\n' + f.read()
35 | except FileNotFoundError:
36 | long_description = DESCRIPTION
37 |
38 | REQUIRED = [i.strip() for i in open(HERE / 'requirements.txt') if not i.startswith('#')]
39 |
40 | setup(
41 | name=NAME,
42 | version=VERSION,
43 | description=DESCRIPTION,
44 | author_email=EMAIL,
45 | long_description=long_description,
46 | long_description_content_type='text/markdown',
47 | author=AUTHOR,
48 | url=URL,
49 | python_requires=REQUIRES_PYTHON,
50 | install_requires=REQUIRED,
51 | extras_require={
52 | 'dev': ['coverage', 'flake8', 'mypy', 'pdoc3', 'pytest'],
53 | },
54 | packages=find_packages(),
55 | package_data={'audiocraft': ['py.typed']},
56 | include_package_data=True,
57 | license='MIT License',
58 | classifiers=[
59 | # Trove classifiers
60 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
61 | 'License :: OSI Approved :: MIT License',
62 | 'Topic :: Multimedia :: Sound/Audio',
63 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
64 | ],
65 | )
66 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/common_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # flake8: noqa
8 | from .temp_utils import TempDirMixin
9 | from .wav_utils import get_batch_white_noise, get_white_noise, save_wav
10 |
--------------------------------------------------------------------------------
/tests/common_utils/temp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import tempfile
9 |
10 |
11 | class TempDirMixin:
12 | """Mixin to provide easy access to temp dir.
13 | """
14 |
15 | temp_dir_ = None
16 |
17 | @classmethod
18 | def get_base_temp_dir(cls):
19 | # If AUDIOCRAFT_TEST_DIR is set, use it instead of temporary directory.
20 |         # This is handy for debugging.
21 | key = "AUDIOCRAFT_TEST_DIR"
22 | if key in os.environ:
23 | return os.environ[key]
24 | if cls.temp_dir_ is None:
25 | cls.temp_dir_ = tempfile.TemporaryDirectory()
26 | return cls.temp_dir_.name
27 |
28 | @classmethod
29 | def tearDownClass(cls):
30 | if cls.temp_dir_ is not None:
31 | try:
32 | cls.temp_dir_.cleanup()
33 | cls.temp_dir_ = None
34 | except PermissionError:
35 |                 # On Windows there is a known issue with `shutil.rmtree`,
36 |                 # which fails intermittently.
37 | # https://github.com/python/cpython/issues/74168
38 | # Following the above thread, we ignore it.
39 | pass
40 | super().tearDownClass()
41 |
42 | @property
43 | def id(self):
44 | return self.__class__.__name__
45 |
46 | def get_temp_path(self, *paths):
47 | temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
48 | path = os.path.join(temp_dir, *paths)
49 | os.makedirs(os.path.dirname(path), exist_ok=True)
50 | return path
51 |
52 | def get_temp_dir(self, *paths):
53 | temp_dir = os.path.join(self.get_base_temp_dir(), self.id)
54 | path = os.path.join(temp_dir, *paths)
55 | os.makedirs(path, exist_ok=True)
56 | return path
57 |
--------------------------------------------------------------------------------
/tests/common_utils/wav_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from pathlib import Path
8 | import typing as tp
9 |
10 | import torch
11 | import torchaudio
12 |
13 |
14 | def get_white_noise(chs: int = 1, num_frames: int = 1):
15 | wav = torch.randn(chs, num_frames)
16 | return wav
17 |
18 |
19 | def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1):
20 | wav = torch.randn(bs, chs, num_frames)
21 | return wav
22 |
23 |
24 | def save_wav(path: str, wav: torch.Tensor, sample_rate: int):
25 | fp = Path(path)
26 | kwargs: tp.Dict[str, tp.Any] = {}
27 | if fp.suffix == '.wav':
28 | kwargs['encoding'] = 'PCM_S'
29 | kwargs['bits_per_sample'] = 16
30 | elif fp.suffix == '.mp3':
31 | kwargs['compression'] = 320
32 | torchaudio.save(str(fp), wav, sample_rate, **kwargs)
33 |
--------------------------------------------------------------------------------
/tests/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/data/test_audio.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from itertools import product
8 | import random
9 |
10 | import numpy as np
11 | import torch
12 | import torchaudio
13 |
14 | from audiocraft.data.audio import audio_info, audio_read, audio_write, _av_read
15 |
16 | from ..common_utils import TempDirMixin, get_white_noise, save_wav
17 |
18 |
19 | class TestInfo(TempDirMixin):
20 |
21 | def test_info_mp3(self):
22 | sample_rates = [8000, 16_000]
23 | channels = [1, 2]
24 | duration = 1.
25 | for sample_rate, ch in product(sample_rates, channels):
26 | wav = get_white_noise(ch, int(sample_rate * duration))
27 | path = self.get_temp_path('sample_wav.mp3')
28 | save_wav(path, wav, sample_rate)
29 | info = audio_info(path)
30 | assert info.sample_rate == sample_rate
31 | assert info.channels == ch
32 | # we cannot trust torchaudio for num_frames, so we don't check
33 |
34 | def _test_info_format(self, ext: str):
35 | sample_rates = [8000, 16_000]
36 | channels = [1, 2]
37 | duration = 1.
38 | for sample_rate, ch in product(sample_rates, channels):
39 | n_frames = int(sample_rate * duration)
40 | wav = get_white_noise(ch, n_frames)
41 | path = self.get_temp_path(f'sample_wav{ext}')
42 | save_wav(path, wav, sample_rate)
43 | info = audio_info(path)
44 | assert info.sample_rate == sample_rate
45 | assert info.channels == ch
46 | assert np.isclose(info.duration, duration, atol=1e-5)
47 |
48 | def test_info_wav(self):
49 | self._test_info_format('.wav')
50 |
51 | def test_info_flac(self):
52 | self._test_info_format('.flac')
53 |
54 | def test_info_ogg(self):
55 | self._test_info_format('.ogg')
56 |
57 | def test_info_m4a(self):
58 | # TODO: generate m4a file programmatically
59 | # self._test_info_format('.m4a')
60 | pass
61 |
62 |
63 | class TestRead(TempDirMixin):
64 |
65 | def test_read_full_wav(self):
66 | sample_rates = [8000, 16_000]
67 | channels = [1, 2]
68 | duration = 1.
69 | for sample_rate, ch in product(sample_rates, channels):
70 | n_frames = int(sample_rate * duration)
71 | wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
72 | path = self.get_temp_path('sample_wav.wav')
73 | save_wav(path, wav, sample_rate)
74 | read_wav, read_sr = audio_read(path)
75 | assert read_sr == sample_rate
76 | assert read_wav.shape[0] == wav.shape[0]
77 | assert read_wav.shape[1] == wav.shape[1]
78 | assert torch.allclose(read_wav, wav, rtol=1e-03, atol=1e-04)
79 |
80 | def test_read_partial_wav(self):
81 | sample_rates = [8000, 16_000]
82 | channels = [1, 2]
83 | duration = 1.
84 | read_duration = torch.rand(1).item()
85 | for sample_rate, ch in product(sample_rates, channels):
86 | n_frames = int(sample_rate * duration)
87 | read_frames = int(sample_rate * read_duration)
88 | wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
89 | path = self.get_temp_path('sample_wav.wav')
90 | save_wav(path, wav, sample_rate)
91 | read_wav, read_sr = audio_read(path, 0, read_duration)
92 | assert read_sr == sample_rate
93 | assert read_wav.shape[0] == wav.shape[0]
94 | assert read_wav.shape[1] == read_frames
95 | assert torch.allclose(read_wav[..., 0:read_frames], wav[..., 0:read_frames], rtol=1e-03, atol=1e-04)
96 |
97 | def test_read_seek_time_wav(self):
98 | sample_rates = [8000, 16_000]
99 | channels = [1, 2]
100 | duration = 1.
101 | read_duration = 1.
102 | for sample_rate, ch in product(sample_rates, channels):
103 | n_frames = int(sample_rate * duration)
104 | wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
105 | path = self.get_temp_path('sample_wav.wav')
106 | save_wav(path, wav, sample_rate)
107 | seek_time = torch.rand(1).item()
108 | read_wav, read_sr = audio_read(path, seek_time, read_duration)
109 | seek_frames = int(sample_rate * seek_time)
110 | expected_frames = n_frames - seek_frames
111 | assert read_sr == sample_rate
112 | assert read_wav.shape[0] == wav.shape[0]
113 | assert read_wav.shape[1] == expected_frames
114 | assert torch.allclose(read_wav, wav[..., seek_frames:], rtol=1e-03, atol=1e-04)
115 |
116 | def test_read_seek_time_wav_padded(self):
117 | sample_rates = [8000, 16_000]
118 | channels = [1, 2]
119 | duration = 1.
120 | read_duration = 1.
121 | for sample_rate, ch in product(sample_rates, channels):
122 | n_frames = int(sample_rate * duration)
123 | read_frames = int(sample_rate * read_duration)
124 | wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
125 | path = self.get_temp_path('sample_wav.wav')
126 | save_wav(path, wav, sample_rate)
127 | seek_time = torch.rand(1).item()
128 | seek_frames = int(sample_rate * seek_time)
129 | expected_frames = n_frames - seek_frames
130 | read_wav, read_sr = audio_read(path, seek_time, read_duration, pad=True)
131 | expected_pad_wav = torch.zeros(wav.shape[0], read_frames - expected_frames)
132 | assert read_sr == sample_rate
133 | assert read_wav.shape[0] == wav.shape[0]
134 | assert read_wav.shape[1] == read_frames
135 | assert torch.allclose(read_wav[..., :expected_frames], wav[..., seek_frames:], rtol=1e-03, atol=1e-04)
136 | assert torch.allclose(read_wav[..., expected_frames:], expected_pad_wav)
137 |
138 |
139 | class TestAvRead(TempDirMixin):
140 |
141 | def test_avread_seek_base(self):
142 | sample_rates = [8000, 16_000]
143 | channels = [1, 2]
144 | duration = 2.
145 | for sample_rate, ch in product(sample_rates, channels):
146 | n_frames = int(sample_rate * duration)
147 | wav = get_white_noise(ch, n_frames)
148 | path = self.get_temp_path(f'reference_a_{sample_rate}_{ch}.wav')
149 | save_wav(path, wav, sample_rate)
150 | for _ in range(100):
151 | # seek will always load a full duration segment in the file
152 | seek_time = random.uniform(0.0, 1.0)
153 | seek_duration = random.uniform(0.001, 1.0)
154 | read_wav, read_sr = _av_read(path, seek_time, seek_duration)
155 | assert read_sr == sample_rate
156 | assert read_wav.shape[0] == wav.shape[0]
157 | assert read_wav.shape[-1] == int(seek_duration * sample_rate)
158 |
159 | def test_avread_seek_partial(self):
160 | sample_rates = [8000, 16_000]
161 | channels = [1, 2]
162 | duration = 1.
163 | for sample_rate, ch in product(sample_rates, channels):
164 | n_frames = int(sample_rate * duration)
165 | wav = get_white_noise(ch, n_frames)
166 | path = self.get_temp_path(f'reference_b_{sample_rate}_{ch}.wav')
167 | save_wav(path, wav, sample_rate)
168 | for _ in range(100):
169 | # seek will always load a partial segment
170 | seek_time = random.uniform(0.5, 1.)
171 | seek_duration = 1.
172 | expected_num_frames = n_frames - int(seek_time * sample_rate)
173 | read_wav, read_sr = _av_read(path, seek_time, seek_duration)
174 | assert read_sr == sample_rate
175 | assert read_wav.shape[0] == wav.shape[0]
176 | assert read_wav.shape[-1] == expected_num_frames
177 |
178 | def test_avread_seek_outofbound(self):
179 | sample_rates = [8000, 16_000]
180 | channels = [1, 2]
181 | duration = 1.
182 | for sample_rate, ch in product(sample_rates, channels):
183 | n_frames = int(sample_rate * duration)
184 | wav = get_white_noise(ch, n_frames)
185 | path = self.get_temp_path(f'reference_c_{sample_rate}_{ch}.wav')
186 | save_wav(path, wav, sample_rate)
187 | seek_time = 1.5
188 | read_wav, read_sr = _av_read(path, seek_time, 1.)
189 | assert read_sr == sample_rate
190 | assert read_wav.shape[0] == wav.shape[0]
191 | assert read_wav.shape[-1] == 0
192 |
193 | def test_avread_seek_edge(self):
194 | sample_rates = [8000, 16_000]
195 | # some of these values will have
196 | # int(((frames - 1) / sample_rate) * sample_rate) != (frames - 1)
197 | n_frames = [1000, 1001, 1002]
198 | channels = [1, 2]
199 | for sample_rate, ch, frames in product(sample_rates, channels, n_frames):
200 | duration = frames / sample_rate
201 | wav = get_white_noise(ch, frames)
202 | path = self.get_temp_path(f'reference_d_{sample_rate}_{ch}.wav')
203 | save_wav(path, wav, sample_rate)
204 | seek_time = (frames - 1) / sample_rate
205 | seek_frames = int(seek_time * sample_rate)
206 | read_wav, read_sr = _av_read(path, seek_time, duration)
207 | assert read_sr == sample_rate
208 | assert read_wav.shape[0] == wav.shape[0]
209 | assert read_wav.shape[-1] == (frames - seek_frames)
210 |
211 |
212 | class TestAudioWrite(TempDirMixin):
213 |
214 | def test_audio_write_wav(self):
215 | torch.manual_seed(1234)
216 | sample_rates = [8000, 16_000]
217 | n_frames = [1000, 1001, 1002]
218 | channels = [1, 2]
219 | strategies = ["peak", "clip", "rms"]
220 | formats = ["wav", "mp3"]
221 | for sample_rate, ch, frames in product(sample_rates, channels, n_frames):
222 | for format_, strategy in product(formats, strategies):
223 | wav = get_white_noise(ch, frames)
224 | path = self.get_temp_path(f'pred_{sample_rate}_{ch}')
225 | audio_write(path, wav, sample_rate, format_, strategy=strategy)
226 | read_wav, read_sr = torchaudio.load(f'{path}.{format_}')
227 | if format_ == "wav":
228 | assert read_wav.shape == wav.shape
229 |
230 | if format_ == "wav" and strategy in ["peak", "rms"]:
231 | rescaled_read_wav = read_wav / read_wav.abs().max() * wav.abs().max()
232 | # for a Gaussian, the typical max scale will be less than ~5x the std.
233 |                     # The error when writing to disk will be ~ 1/2**15, and when rescaling, 5x that.
234 | # For RMS target, rescaling leaves more headroom by default, leading
235 | # to a 20x rescaling typically
236 | atol = (5 if strategy == "peak" else 20) / 2**15
237 | delta = (rescaled_read_wav - wav).abs().max()
238 | assert torch.allclose(wav, rescaled_read_wav, rtol=0, atol=atol), (delta, atol)
239 | formats = ["wav"] # faster unit tests
240 |
--------------------------------------------------------------------------------
/tests/data/test_audio_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import julius
8 | import torch
9 | import pytest
10 |
11 | from audiocraft.data.audio_utils import (
12 | _clip_wav,
13 | convert_audio_channels,
14 | convert_audio,
15 | normalize_audio
16 | )
17 | from ..common_utils import get_batch_white_noise
18 |
19 |
20 | class TestConvertAudioChannels:
21 |
22 | def test_convert_audio_channels_downmix(self):
23 | b, c, t = 2, 3, 100
24 | audio = get_batch_white_noise(b, c, t)
25 | mixed = convert_audio_channels(audio, channels=2)
26 | assert list(mixed.shape) == [b, 2, t]
27 |
28 | def test_convert_audio_channels_nochange(self):
29 | b, c, t = 2, 3, 100
30 | audio = get_batch_white_noise(b, c, t)
31 | mixed = convert_audio_channels(audio, channels=c)
32 | assert list(mixed.shape) == list(audio.shape)
33 |
34 | def test_convert_audio_channels_upmix(self):
35 | b, c, t = 2, 1, 100
36 | audio = get_batch_white_noise(b, c, t)
37 | mixed = convert_audio_channels(audio, channels=3)
38 | assert list(mixed.shape) == [b, 3, t]
39 |
40 | def test_convert_audio_channels_upmix_error(self):
41 | b, c, t = 2, 2, 100
42 | audio = get_batch_white_noise(b, c, t)
43 | with pytest.raises(ValueError):
44 | convert_audio_channels(audio, channels=3)
45 |
46 |
47 | class TestConvertAudio:
48 |
49 | def test_convert_audio_channels_downmix(self):
50 | b, c, dur = 2, 3, 4.
51 | sr = 128
52 | audio = get_batch_white_noise(b, c, int(sr * dur))
53 | out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=2)
54 | assert list(out.shape) == [audio.shape[0], 2, audio.shape[-1]]
55 |
56 | def test_convert_audio_channels_upmix(self):
57 | b, c, dur = 2, 1, 4.
58 | sr = 128
59 | audio = get_batch_white_noise(b, c, int(sr * dur))
60 | out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=3)
61 | assert list(out.shape) == [audio.shape[0], 3, audio.shape[-1]]
62 |
63 | def test_convert_audio_upsample(self):
64 | b, c, dur = 2, 1, 4.
65 | sr = 2
66 | new_sr = 3
67 | audio = get_batch_white_noise(b, c, int(sr * dur))
68 | out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
69 | out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
70 | assert torch.allclose(out, out_j)
71 |
72 | def test_convert_audio_resample(self):
73 | b, c, dur = 2, 1, 4.
74 | sr = 3
75 | new_sr = 2
76 | audio = get_batch_white_noise(b, c, int(sr * dur))
77 | out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c)
78 | out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr)
79 | assert torch.allclose(out, out_j)
80 |
81 |
82 | class TestNormalizeAudio:
83 |
84 | def test_clip_wav(self):
85 | b, c, dur = 2, 1, 4.
86 | sr = 3
87 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
88 | _clip_wav(audio)
89 | assert audio.abs().max() <= 1
90 |
91 | def test_normalize_audio_clip(self):
92 | b, c, dur = 2, 1, 4.
93 | sr = 3
94 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
95 | norm_audio = normalize_audio(audio, strategy='clip')
96 | assert norm_audio.abs().max() <= 1
97 |
98 | def test_normalize_audio_rms(self):
99 | b, c, dur = 2, 1, 4.
100 | sr = 3
101 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
102 | norm_audio = normalize_audio(audio, strategy='rms')
103 | assert norm_audio.abs().max() <= 1
104 |
105 | def test_normalize_audio_peak(self):
106 | b, c, dur = 2, 1, 4.
107 | sr = 3
108 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur))
109 | norm_audio = normalize_audio(audio, strategy='peak')
110 | assert norm_audio.abs().max() <= 1
111 |
--------------------------------------------------------------------------------
/tests/models/test_encodec_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import random
8 |
9 | import numpy as np
10 | import torch
11 |
12 | from audiocraft.models import EncodecModel
13 | from audiocraft.modules import SEANetEncoder, SEANetDecoder
14 | from audiocraft.quantization import DummyQuantizer
15 |
16 |
17 | class TestEncodecModel:
18 |
19 | def _create_encodec_model(self,
20 | sample_rate: int,
21 | channels: int,
22 | dim: int = 5,
23 | n_filters: int = 3,
24 | n_residual_layers: int = 1,
25 | ratios: list = [5, 4, 3, 2],
26 | **kwargs):
27 | frame_rate = np.prod(ratios)
28 | encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters,
29 | n_residual_layers=n_residual_layers, ratios=ratios)
30 | decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters,
31 | n_residual_layers=n_residual_layers, ratios=ratios)
32 | quantizer = DummyQuantizer()
33 | model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate,
34 | sample_rate=sample_rate, channels=channels, **kwargs)
35 | return model
36 |
37 | def test_model(self):
38 | random.seed(1234)
39 | sample_rate = 24_000
40 | channels = 1
41 | model = self._create_encodec_model(sample_rate, channels)
42 | for _ in range(10):
43 | length = random.randrange(1, 10_000)
44 | x = torch.randn(2, channels, length)
45 | res = model(x)
46 | assert res.x.shape == x.shape
47 |
48 | def test_model_renorm(self):
49 | random.seed(1234)
50 | sample_rate = 24_000
51 | channels = 1
52 | model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False)
53 | model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True)
54 |
55 | for _ in range(10):
56 | length = random.randrange(1, 10_000)
57 | x = torch.randn(2, channels, length)
58 | codes, scales = model_nonorm.encode(x)
59 | codes, scales = model_renorm.encode(x)
60 | assert scales is not None
61 |
--------------------------------------------------------------------------------
/tests/models/test_musicgen.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 |
10 | from audiocraft.models import MusicGen
11 |
12 |
13 | class TestMusicGenModel:
14 | def get_musicgen(self):
15 | mg = MusicGen.get_pretrained(name='debug', device='cpu')
16 | mg.set_generation_params(duration=2.0)
17 | return mg
18 |
19 | def test_base(self):
20 | mg = self.get_musicgen()
21 | assert mg.frame_rate == 25
22 | assert mg.sample_rate == 32000
23 | assert mg.audio_channels == 1
24 |
25 | def test_generate_unconditional(self):
26 | mg = self.get_musicgen()
27 | wav = mg.generate_unconditional(3)
28 | assert list(wav.shape) == [3, 1, 64000]
29 |
30 | def test_generate_continuation(self):
31 | mg = self.get_musicgen()
32 | prompt = torch.randn(3, 1, 32000)
33 | wav = mg.generate_continuation(prompt, 32000)
34 | assert list(wav.shape) == [3, 1, 64000]
35 |
36 | prompt = torch.randn(2, 1, 32000)
37 | wav = mg.generate_continuation(
38 | prompt, 32000, ['youpi', 'lapin dort'])
39 | assert list(wav.shape) == [2, 1, 64000]
40 |
41 | prompt = torch.randn(2, 1, 32000)
42 | with pytest.raises(AssertionError):
43 | wav = mg.generate_continuation(
44 | prompt, 32000, ['youpi', 'lapin dort', 'one too many'])
45 |
46 | def test_generate(self):
47 | mg = self.get_musicgen()
48 | wav = mg.generate(
49 | ['youpi', 'lapin dort'])
50 | assert list(wav.shape) == [2, 1, 64000]
51 |
--------------------------------------------------------------------------------
/tests/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/modules/test_codebooks_patterns.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pytest
8 | import torch
9 |
10 | from audiocraft.modules.codebooks_patterns import (
11 | DelayedPatternProvider,
12 | ParallelPatternProvider,
13 | Pattern,
14 | UnrolledPatternProvider,
15 | )
16 |
17 |
18 | class TestParallelPatternProvider:
19 |
20 | @pytest.mark.parametrize("n_q", [1, 4, 32])
21 | @pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
22 | def test_get_pattern(self, n_q: int, timesteps: int):
23 | provider = ParallelPatternProvider(n_q)
24 | pattern = provider.get_pattern(timesteps)
25 | # + 1 to account for 1st step
26 | assert len(pattern.layout) == timesteps + 1
27 |
28 | @pytest.mark.parametrize("n_q", [1, 4, 32])
29 | @pytest.mark.parametrize("timesteps", [8, 16, 100])
30 | def test_pattern_content(self, n_q: int, timesteps: int):
31 | provider = ParallelPatternProvider(n_q)
32 | pattern = provider.get_pattern(timesteps)
33 | for s, v in enumerate(pattern.layout):
34 | for i, code in enumerate(v):
35 | assert i == code.q
36 | assert code.t == s - 1 # account for the 1st empty step
37 |
38 | @pytest.mark.parametrize("n_q", [1, 4, 32])
39 | @pytest.mark.parametrize("timesteps", [8, 16, 100])
40 | def test_pattern_max_delay(self, n_q: int, timesteps: int):
41 | provider = ParallelPatternProvider(n_q)
42 | pattern = provider.get_pattern(timesteps)
43 | assert pattern.max_delay == 0
44 | assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
45 |
46 |
47 | class TestDelayedPatternProvider:
48 |
49 | @pytest.mark.parametrize("n_q", [1, 4, 32])
50 | @pytest.mark.parametrize("timesteps", [0, 1, 16, 100])
51 | def test_get_pattern(self, n_q: int, timesteps: int):
52 | delays = [
53 | list(range(n_q)),
54 | [0] + [1] * (n_q - 1),
55 | [0] + [4] * (n_q - 1),
56 | ]
57 | for delay in delays:
58 | provider = DelayedPatternProvider(n_q, delay)
59 | pattern = provider.get_pattern(timesteps)
60 | # + 1 to account for 1st step
61 | assert len(pattern.layout) == timesteps + max(delay) + 1
62 |
63 | @pytest.mark.parametrize("n_q", [1, 4, 32])
64 | @pytest.mark.parametrize("timesteps", [8, 16, 100])
65 | def test_pattern_content(self, n_q: int, timesteps: int):
66 | provider = DelayedPatternProvider(n_q)
67 | pattern = provider.get_pattern(timesteps)
68 | for s, v in enumerate(pattern.layout):
69 | for i, code in enumerate(v):
70 | assert i == code.q
71 | assert code.t == max(0, s - code.q - 1)
72 |
73 | @pytest.mark.parametrize("timesteps", [8, 16, 100])
74 | @pytest.mark.parametrize("delay", [[0, 1, 2, 3], [0, 1, 1, 1], [0, 3, 3, 3], [0, 3]])
75 | def test_pattern_max_delay(self, timesteps: int, delay: list):
76 | provider = DelayedPatternProvider(len(delay), delay)
77 | pattern = provider.get_pattern(timesteps)
78 | assert pattern.max_delay == max(delay)
79 | assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay
80 |
81 |
82 | class TestUnrolledPatternProvider:
83 |
84 | @pytest.mark.parametrize("timesteps", [0, 1, 16])
85 | @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
86 | @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
87 | def test_get_pattern(self, timesteps: int, flattening: list, delays: list):
88 | n_q = len(flattening)
89 | max_delay = max(delays)
90 | provider = UnrolledPatternProvider(n_q, flattening, delays)
91 | pattern = provider.get_pattern(timesteps)
92 | assert len(pattern.layout) == provider.num_virtual_steps(timesteps) + max_delay
93 |
94 | @pytest.mark.parametrize("timesteps", [0, 1, 16])
95 | @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]])
96 | @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]])
97 | def test_pattern_max_delay(self, timesteps: int, flattening: list, delays: list):
98 | n_q = len(flattening)
99 | max_delay = max(delays)
100 | provider = UnrolledPatternProvider(n_q, flattening, delays)
101 | pattern = provider.get_pattern(timesteps)
102 | assert pattern.max_delay == max_delay
103 |
104 |
105 | class TestPattern:
106 |
107 | def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
108 | """Reference method to build the sequence from the pattern without using fancy scatter."""
109 | bs, n_q, T = z.shape
110 | z = z.cpu().numpy()
111 | assert n_q == pattern.n_q
112 | assert T <= pattern.timesteps
113 | inp = torch.full((bs, n_q, len(pattern.layout)), special_token, dtype=torch.long).numpy()
114 | inp[:] = special_token
115 | for s, v in enumerate(pattern.layout):
116 | for (t, q) in v:
117 | if t < T:
118 | inp[:, q, s] = z[:, q, t]
119 | return torch.from_numpy(inp)
120 |
121 | def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int):
122 | """Reference method to revert the sequence from the pattern without using fancy scatter."""
123 | z = z.cpu().numpy()
124 | bs, n_q, S = z.shape
125 | assert pattern.n_q == n_q
126 | inp = torch.full((bs, pattern.n_q, pattern.timesteps), special_token, dtype=torch.long).numpy()
127 | inp[:] = special_token
128 | for s, v in enumerate(pattern.layout):
129 | for (t, q) in v:
130 | if t < pattern.timesteps:
131 | inp[:, q, t] = z[:, q, s]
132 | return torch.from_numpy(inp)
133 |
134 | def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern, special_token: float):
135 | """Reference method to revert the logits from the pattern without using fancy scatter."""
136 | z = z.cpu().numpy()
137 | bs, card, n_q, S = z.shape
138 | assert pattern.n_q == n_q
139 | ref_layout = pattern.layout
140 | inp = torch.full((bs, card, pattern.n_q, pattern.timesteps), special_token, dtype=torch.float).numpy()
141 | inp[:] = special_token
142 | for s, v in enumerate(ref_layout[1:]):
143 | if s < S:
144 | for (t, q) in v:
145 | if t < pattern.timesteps:
146 | inp[:, :, q, t] = z[:, :, q, s]
147 | return torch.from_numpy(inp)
148 |
149 | def _get_pattern_providers(self, n_q: int):
150 | pattern_provider_1 = ParallelPatternProvider(n_q)
151 | pattern_provider_2 = DelayedPatternProvider(n_q, list(range(n_q)))
152 | pattern_provider_3 = DelayedPatternProvider(n_q, [0] + [1] * (n_q - 1))
153 | pattern_provider_4 = UnrolledPatternProvider(
154 | n_q, flattening=list(range(n_q)), delays=[0] * n_q
155 | )
156 | pattern_provider_5 = UnrolledPatternProvider(
157 | n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] * n_q
158 | )
159 | pattern_provider_6 = UnrolledPatternProvider(
160 | n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] + [5] * (n_q - 1)
161 | )
162 | return [
163 | pattern_provider_1,
164 | pattern_provider_2,
165 | pattern_provider_3,
166 | pattern_provider_4,
167 | pattern_provider_5,
168 | pattern_provider_6,
169 | ]
170 |
171 | @pytest.mark.parametrize("n_q", [1, 4, 32])
172 | @pytest.mark.parametrize("timesteps", [16, 72])
173 | def test_build_pattern_sequence(self, n_q: int, timesteps: int):
174 | bs = 2
175 | card = 256
176 | special_token = card
177 |
178 | pattern_providers = self._get_pattern_providers(n_q)
179 | for pattern_provider in pattern_providers:
180 | pattern = pattern_provider.get_pattern(timesteps)
181 | # we can correctly build the sequence from the pattern
182 | z = torch.randint(0, card, (bs, n_q, timesteps))
183 | ref_res = self.ref_build_pattern_sequence(z, pattern, special_token)
184 | res, indexes, mask = pattern.build_pattern_sequence(z, special_token)
185 | assert (res == ref_res).float().mean() == 1.0
186 |
187 | # an invalid number of timesteps is expected to raise an assertion error
188 | invalid_timesteps = [timesteps + 1]
189 | if pattern.num_sequence_steps != pattern.timesteps:
190 | invalid_timesteps.append(pattern.num_sequence_steps)
191 | for i_timesteps in invalid_timesteps:
192 | z2 = torch.randint(0, card, (bs, n_q, i_timesteps))
193 | with pytest.raises(AssertionError):
194 | pattern.build_pattern_sequence(z2, special_token)
195 |
196 | # an invalid number of codebooks is expected to raise an assertion error
197 | invalid_qs = [0, n_q - 1, n_q + 1]
198 | for i_q in invalid_qs:
199 | z3 = torch.randint(0, card, (bs, i_q, timesteps))
200 | with pytest.raises(AssertionError):
201 | pattern.build_pattern_sequence(z3, special_token)
202 |
203 | @pytest.mark.parametrize("n_q", [1, 4, 32])
204 | @pytest.mark.parametrize("timesteps", [16, 72])
205 | def test_revert_pattern_sequence(self, n_q: int, timesteps: int):
206 | bs = 2
207 | card = 256
208 | special_token = card
209 |
210 | pattern_providers = self._get_pattern_providers(n_q)
211 | for pattern_provider in pattern_providers:
212 | pattern = pattern_provider.get_pattern(timesteps)
213 | # this works assuming previous tests are successful
214 | z = torch.randint(0, card, (bs, n_q, timesteps))
215 | s = self.ref_build_pattern_sequence(z, pattern, special_token)
216 | ref_out = self.ref_revert_pattern_sequence(s, pattern, special_token)
217 | # ensure our reference script retrieves the original sequence
218 | assert z.shape == ref_out.shape
219 | assert (z == ref_out).float().mean() == 1.0
220 | # now we can test the scatter version
221 | out, indexes, mask = pattern.revert_pattern_sequence(s, special_token)
222 | assert out.shape == ref_out.shape
223 | assert (out == ref_out).float().mean() == 1.0
224 |
225 | @pytest.mark.parametrize("n_q", [1, 4, 32])
226 | @pytest.mark.parametrize("timesteps", [16, 72])
227 | @pytest.mark.parametrize("card", [1, 2, 256, 1024])
228 | def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: int):
229 | bs = 2
230 | special_token = card
231 | logits_special_token = float('nan')
232 |
233 | pattern_providers = self._get_pattern_providers(n_q)
234 | for pattern_provider in pattern_providers:
235 | pattern = pattern_provider.get_pattern(timesteps)
236 | # this works assuming previous tests are successful
237 | z = torch.randint(0, card, (bs, n_q, timesteps))
238 | s = self.ref_build_pattern_sequence(z, pattern, special_token)
239 | logits = torch.randn((bs, card, n_q, s.shape[-1]))
240 | ref_out = self.ref_revert_pattern_logits(logits, pattern, logits_special_token)
241 | # ensure our reference script retrieves the original sequence
242 | assert ref_out.shape == torch.Size([bs, card, n_q, timesteps])
243 | # now we can test the scatter version
244 | out, indexes, mask = pattern.revert_pattern_logits(logits, logits_special_token)
245 | assert out.shape == ref_out.shape
246 | assert (out == ref_out).float().mean() == 1.0
247 |
--------------------------------------------------------------------------------
/tests/modules/test_conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from itertools import product
8 | import math
9 | import random
10 |
11 | import pytest
12 | import torch
13 | from torch import nn
14 |
15 | from audiocraft.modules import (
16 | NormConv1d,
17 | NormConvTranspose1d,
18 | StreamableConv1d,
19 | StreamableConvTranspose1d,
20 | pad1d,
21 | unpad1d,
22 | )
23 |
24 |
25 | def test_get_extra_padding_for_conv1d():
26 | # TODO: Implement me!
27 | pass
28 |
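29 | # A minimal sketch for the TODO above, assuming `get_extra_padding_for_conv1d(x,
30 | # kernel_size, stride, padding_total)` from `audiocraft.modules.conv` returns the
31 | # extra right padding needed so that the last convolution window is full:
32 | def _sketch_get_extra_padding_for_conv1d():
33 |     from audiocraft.modules.conv import get_extra_padding_for_conv1d
34 |
35 |     # 20 samples split evenly into windows of 4 with stride 4: no extra padding needed.
36 |     x = torch.randn(1, 1, 20)
37 |     assert get_extra_padding_for_conv1d(x, kernel_size=4, stride=4, padding_total=0) == 0
38 |     # 21 samples leave a partial last window: 3 more samples are needed to complete it.
39 |     x = torch.randn(1, 1, 21)
40 |     assert get_extra_padding_for_conv1d(x, kernel_size=4, stride=4, padding_total=0) == 3
41 |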
29 |
30 | def test_pad1d_zeros():
31 | x = torch.randn(1, 1, 20)
32 |
33 | xp1 = pad1d(x, (0, 5), mode='constant', value=0.)
34 | assert xp1.shape[-1] == 25
35 | xp2 = pad1d(x, (5, 5), mode='constant', value=0.)
36 | assert xp2.shape[-1] == 30
37 | xp3 = pad1d(x, (0, 0), mode='constant', value=0.)
38 | assert xp3.shape[-1] == 20
39 | xp4 = pad1d(x, (10, 30), mode='constant', value=0.)
40 | assert xp4.shape[-1] == 60
41 |
42 | with pytest.raises(AssertionError):
43 | pad1d(x, (-1, 0), mode='constant', value=0.)
44 |
45 | with pytest.raises(AssertionError):
46 | pad1d(x, (0, -1), mode='constant', value=0.)
47 |
48 | with pytest.raises(AssertionError):
49 | pad1d(x, (-1, -1), mode='constant', value=0.)
50 |
51 |
52 | def test_pad1d_reflect():
53 | x = torch.randn(1, 1, 20)
54 |
55 | xp1 = pad1d(x, (0, 5), mode='reflect', value=0.)
56 | assert xp1.shape[-1] == 25
57 | xp2 = pad1d(x, (5, 5), mode='reflect', value=0.)
58 | assert xp2.shape[-1] == 30
59 | xp3 = pad1d(x, (0, 0), mode='reflect', value=0.)
60 | assert xp3.shape[-1] == 20
61 | xp4 = pad1d(x, (10, 30), mode='reflect', value=0.)
62 | assert xp4.shape[-1] == 60
63 |
64 | with pytest.raises(AssertionError):
65 | pad1d(x, (-1, 0), mode='reflect', value=0.)
66 |
67 | with pytest.raises(AssertionError):
68 | pad1d(x, (0, -1), mode='reflect', value=0.)
69 |
70 | with pytest.raises(AssertionError):
71 | pad1d(x, (-1, -1), mode='reflect', value=0.)
72 |
73 |
74 | def test_unpad1d():
75 | x = torch.randn(1, 1, 20)
76 |
77 | u1 = unpad1d(x, (5, 5))
78 | assert u1.shape[-1] == 10
79 | u2 = unpad1d(x, (0, 5))
80 | assert u2.shape[-1] == 15
81 | u3 = unpad1d(x, (5, 0))
82 | assert u3.shape[-1] == 15
83 | u4 = unpad1d(x, (0, 0))
84 | assert u4.shape[-1] == x.shape[-1]
85 |
86 | with pytest.raises(AssertionError):
87 | unpad1d(x, (-1, 0))
88 |
89 | with pytest.raises(AssertionError):
90 | unpad1d(x, (0, -1))
91 |
92 | with pytest.raises(AssertionError):
93 | unpad1d(x, (-1, -1))
94 |
95 |
96 | class TestNormConv1d:
97 |
98 | def test_norm_conv1d_modules(self):
99 | N, C, T = 2, 2, random.randrange(1, 100_000)
100 | t0 = torch.randn(N, C, T)
101 |
102 | C_out, kernel_size, stride = 1, 4, 1
103 | expected_out_length = int((T - kernel_size) / stride + 1)
104 | wn_conv = NormConv1d(C, 1, kernel_size=4, norm='weight_norm')
105 | gn_conv = NormConv1d(C, 1, kernel_size=4, norm='time_group_norm')
106 | nn_conv = NormConv1d(C, 1, kernel_size=4, norm='none')
107 |
108 | assert isinstance(wn_conv.norm, nn.Identity)
109 | assert isinstance(wn_conv.conv, nn.Conv1d)
110 |
111 | assert isinstance(gn_conv.norm, nn.GroupNorm)
112 | assert isinstance(gn_conv.conv, nn.Conv1d)
113 |
114 | assert isinstance(nn_conv.norm, nn.Identity)
115 | assert isinstance(nn_conv.conv, nn.Conv1d)
116 |
117 | for conv_layer in [wn_conv, gn_conv, nn_conv]:
118 | out = conv_layer(t0)
119 | assert isinstance(out, torch.Tensor)
120 | assert list(out.shape) == [N, C_out, expected_out_length]
121 |
122 |
123 | class TestNormConvTranspose1d:
124 |
125 | def test_normalizations(self):
126 | N, C, T = 2, 2, random.randrange(1, 100_000)
127 | t0 = torch.randn(N, C, T)
128 |
129 | C_out, kernel_size, stride = 1, 4, 1
130 | expected_out_length = (T - 1) * stride + (kernel_size - 1) + 1
131 |
132 | wn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='weight_norm')
133 | gn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='time_group_norm')
134 | nn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='none')
135 |
136 | assert isinstance(wn_convtr.norm, nn.Identity)
137 | assert isinstance(wn_convtr.convtr, nn.ConvTranspose1d)
138 |
139 | assert isinstance(gn_convtr.norm, nn.GroupNorm)
140 | assert isinstance(gn_convtr.convtr, nn.ConvTranspose1d)
141 |
142 | assert isinstance(nn_convtr.norm, nn.Identity)
143 | assert isinstance(nn_convtr.convtr, nn.ConvTranspose1d)
144 |
145 | for convtr_layer in [wn_convtr, gn_convtr, nn_convtr]:
146 | out = convtr_layer(t0)
147 | assert isinstance(out, torch.Tensor)
148 | assert list(out.shape) == [N, C_out, expected_out_length]
149 |
150 |
151 | class TestStreamableConv1d:
152 |
153 | def get_streamable_conv1d_output_length(self, length, kernel_size, stride, dilation):
154 | # StreamableConv1d internally pads to make sure that the last window is full
155 | padding_total = (kernel_size - 1) * dilation - (stride - 1)
156 | n_frames = (length - kernel_size + padding_total) / stride + 1
157 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
158 | return ideal_length // stride
159 |
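160 | # For example, for the formula above: length=21, kernel_size=4, stride=2, dilation=1
161 | # gives padding_total=2, n_frames=10.5, ideal_length=22, hence an output length of
162 | # 22 // 2 == 11, i.e. ceil(21 / 2), as expected when the last window is always completed.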
160 | def test_streamable_conv1d(self):
161 | N, C, T = 2, 2, random.randrange(1, 100_000)
162 | t0 = torch.randn(N, C, T)
163 | C_out = 1
164 |
165 | # conv params are [(kernel_size, stride, dilation)]
166 | conv_params = [(4, 1, 1), (4, 2, 1), (3, 1, 3), (10, 5, 1), (3, 2, 3)]
167 | for causal, (kernel_size, stride, dilation) in product([False, True], conv_params):
168 | expected_out_length = self.get_streamable_conv1d_output_length(T, kernel_size, stride, dilation)
169 | sconv = StreamableConv1d(C, C_out, kernel_size=kernel_size, stride=stride, dilation=dilation, causal=causal)
170 | out = sconv(t0)
171 | assert isinstance(out, torch.Tensor)
172 | print(list(out.shape), [N, C_out, expected_out_length])
173 | assert list(out.shape) == [N, C_out, expected_out_length]
174 |
175 |
176 | class TestStreamableConvTranspose1d:
177 |
178 | def get_streamable_convtr1d_output_length(self, length, kernel_size, stride):
179 | padding_total = (kernel_size - stride)
180 | return (length - 1) * stride - padding_total + (kernel_size - 1) + 1
181 |
182 | def test_streamable_convtr1d(self):
183 | N, C, T = 2, 2, random.randrange(1, 100_000)
184 | t0 = torch.randn(N, C, T)
185 |
186 | C_out = 1
187 |
188 | # each invalid configuration must raise on its own; a single raises block
189 | # would only exercise the first call
190 | with pytest.raises(AssertionError):
191 |     StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=False, trim_right_ratio=0.5)
192 | with pytest.raises(AssertionError):
193 |     StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=-1.)
194 | with pytest.raises(AssertionError):
195 |     StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=2)
192 |
193 | # causal params are [(causal, trim_right_ratio)]
194 | causal_params = [(False, 1.0), (True, 1.0), (True, 0.5), (True, 0.0)]
195 | # conv params are [(kernel_size, stride)]
196 | conv_params = [(4, 1), (4, 2), (3, 1), (10, 5)]
197 | for ((causal, trim_right_ratio), (kernel_size, stride)) in product(causal_params, conv_params):
198 | expected_out_length = self.get_streamable_convtr1d_output_length(T, kernel_size, stride)
199 | sconvtr = StreamableConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride,
200 | causal=causal, trim_right_ratio=trim_right_ratio)
201 | out = sconvtr(t0)
202 | assert isinstance(out, torch.Tensor)
203 | assert list(out.shape) == [N, C_out, expected_out_length]
204 |
--------------------------------------------------------------------------------
/tests/modules/test_lstm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import random
8 | import torch
9 |
10 | from audiocraft.modules.lstm import StreamableLSTM
11 |
12 |
13 | class TestStreamableLSTM:
14 |
15 | def test_lstm(self):
16 | B, C, T = 4, 2, random.randint(1, 100)
17 |
18 | lstm = StreamableLSTM(C, 3, skip=False)
19 | x = torch.randn(B, C, T)
20 | y = lstm(x)
21 |
22 | print(y.shape)
23 | assert y.shape == torch.Size([B, C, T])
24 |
25 | def test_lstm_skip(self):
26 | B, C, T = 4, 2, random.randint(1, 100)
27 |
28 | lstm = StreamableLSTM(C, 3, skip=True)
29 | x = torch.randn(B, C, T)
30 | y = lstm(x)
31 |
32 | assert y.shape == torch.Size([B, C, T])
33 |
--------------------------------------------------------------------------------
/tests/modules/test_rope.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from audiocraft.modules.rope import RotaryEmbedding
10 | from audiocraft.modules.transformer import StreamingTransformer
11 |
12 |
13 | def test_rope():
14 | B, T, H, C = 8, 75, 16, 128
15 |
16 | rope = RotaryEmbedding(dim=C)
17 | xq = torch.rand((B, T, H, C))
18 | xk = torch.rand((B, T, H, C))
19 | xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
20 |
21 | assert list(xq_out.shape) == [B, T, H, C]
22 | assert list(xk_out.shape) == [B, T, H, C]
23 |
24 |
25 | def test_rope_io_dtypes():
26 | B, T, H, C = 8, 75, 16, 128
27 |
28 | rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32)
29 | rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64)
30 |
31 | # Test bfloat16 inputs w/ both 32 and 64 precision rope.
32 | xq_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
33 | xk_16 = torch.rand((B, T, H, C)).to(torch.bfloat16)
34 | xq_out, xk_out = rope_32.rotate_qk(xq_16, xk_16)
35 | assert xq_out.dtype == torch.bfloat16
36 | xq_out, xk_out = rope_64.rotate_qk(xq_16, xk_16)
37 | assert xq_out.dtype == torch.bfloat16
38 |
39 | # Test float32 inputs w/ both 32 and 64 precision rope.
40 | xq_32 = torch.rand((B, T, H, C)).to(torch.float32)
41 | xk_32 = torch.rand((B, T, H, C)).to(torch.float32)
42 | xq_out, xk_out = rope_32.rotate_qk(xq_32, xk_32)
43 | assert xq_out.dtype == torch.float32
44 | xq_out, xk_out = rope_64.rotate_qk(xq_32, xk_32)
45 | assert xq_out.dtype == torch.float32
46 |
47 |
48 | def test_transformer_with_rope():
49 | torch.manual_seed(1234)
50 | for pos in ['rope', 'sin_rope']:
51 | tr = StreamingTransformer(
52 | 16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
53 | positional_embedding=pos)
54 | tr.eval()
55 | steps = 12
56 | x = torch.randn(3, steps, 16)
57 |
58 | out = tr(x)
59 | assert list(out.shape) == list(x.shape)
60 |
61 |
62 | @torch.no_grad()
63 | def test_rope_streaming():
64 | torch.manual_seed(1234)
65 | tr = StreamingTransformer(
66 | 16, 4, 2, causal=True, dropout=0.,
67 | custom=True, positional_embedding='rope')
68 | tr.eval()
69 | steps = 12
70 | x = torch.randn(3, steps, 16)
71 |
72 | ref = tr(x)
73 |
74 | with tr.streaming():
75 | outs = []
76 | frame_sizes = [1] * steps
77 |
78 | for frame_size in frame_sizes:
79 | frame = x[:, :frame_size]
80 | x = x[:, frame_size:]
81 | outs.append(tr(frame))
82 |
83 | out = torch.cat(outs, dim=1)
84 | assert list(out.shape) == [3, steps, 16]
85 | delta = torch.norm(out - ref) / torch.norm(out)
86 | assert delta < 1e-6, delta
87 |
88 |
89 | @torch.no_grad()
90 | def test_rope_streaming_past_context():
91 | torch.manual_seed(1234)
92 |
93 | for context in [None, 10]:
94 | tr = StreamingTransformer(
95 | 16, 4, 1 if context else 2,
96 | causal=True, past_context=context, custom=True,
97 | dropout=0., positional_embedding='rope')
98 | tr.eval()
99 |
100 | steps = 20
101 | x = torch.randn(3, steps, 16)
102 | ref = tr(x)
103 |
104 | with tr.streaming():
105 | outs = []
106 | frame_sizes = [1] * steps
107 |
108 | for frame_size in frame_sizes:
109 | frame = x[:, :frame_size]
110 | x = x[:, frame_size:]
111 | outs.append(tr(frame))
112 |
113 | out = torch.cat(outs, dim=1)
114 | assert list(out.shape) == [3, steps, 16]
115 | delta = torch.norm(out - ref) / torch.norm(out)
116 | assert delta < 1e-6, delta
117 |
118 |
119 | def test_rope_memory_efficient():
120 | torch.manual_seed(1234)
121 | tr = StreamingTransformer(
122 | 16, 4, 2, custom=True, dropout=0., layer_scale=0.1,
123 | positional_embedding='rope')
124 | tr_mem_efficient = StreamingTransformer(
125 | 16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1,
126 | positional_embedding='rope')
127 | tr_mem_efficient.load_state_dict(tr.state_dict())
128 | tr.eval()
129 | steps = 12
130 | x = torch.randn(3, steps, 16)
131 |
132 | with torch.no_grad():
133 | y = tr(x)
134 | y2 = tr_mem_efficient(x)
135 | # Check at float precision b/c this is the rope default.
136 | assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm()
137 |
138 |
139 | def test_rope_with_xpos():
140 | B, T, H, C = 8, 75, 16, 128
141 |
142 | rope = RotaryEmbedding(dim=C, xpos=True)
143 | xq = torch.rand((B, T, H, C))
144 | xk = torch.rand((B, T, H, C))
145 | xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
146 |
147 | assert list(xq_out.shape) == [B, T, H, C]
148 | assert list(xk_out.shape) == [B, T, H, C]
149 |
150 |
151 | def test_positional_scale():
152 | B, T, H, C = 8, 75, 16, 128
153 |
154 | rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0)
155 | xq = torch.rand((B, T, H, C))
156 | xk = torch.rand((B, T, H, C))
157 | xq_out, xk_out = rope.rotate_qk(xq, xk, start=7)
158 |
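159 | # With scale=0.0 the positional rotation has no effect, so rotate_qk is expected
160 | # to return the queries and keys unchanged, which is what the checks below verify.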
159 | assert torch.allclose(xq, xq_out)
160 | assert torch.allclose(xk, xk_out)
161 |
--------------------------------------------------------------------------------
/tests/modules/test_seanet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from itertools import product
8 |
9 | import pytest
10 | import torch
11 |
12 | from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock
13 | from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d
14 |
15 |
16 | class TestSEANetModel:
17 |
18 | def test_base(self):
19 | encoder = SEANetEncoder()
20 | decoder = SEANetDecoder()
21 |
22 | x = torch.randn(1, 1, 24000)
23 | z = encoder(x)
24 | assert list(z.shape) == [1, 128, 75], z.shape
25 | y = decoder(z)
26 | assert y.shape == x.shape, (x.shape, y.shape)
27 |
28 | def test_causal(self):
29 | encoder = SEANetEncoder(causal=True)
30 | decoder = SEANetDecoder(causal=True)
31 | x = torch.randn(1, 1, 24000)
32 |
33 | z = encoder(x)
34 | assert list(z.shape) == [1, 128, 75], z.shape
35 | y = decoder(z)
36 | assert y.shape == x.shape, (x.shape, y.shape)
37 |
38 | def test_conv_skip_connection(self):
39 | encoder = SEANetEncoder(true_skip=False)
40 | decoder = SEANetDecoder(true_skip=False)
41 |
42 | x = torch.randn(1, 1, 24000)
43 | z = encoder(x)
44 | assert list(z.shape) == [1, 128, 75], z.shape
45 | y = decoder(z)
46 | assert y.shape == x.shape, (x.shape, y.shape)
47 |
48 | def test_seanet_encoder_decoder_final_act(self):
49 | encoder = SEANetEncoder(true_skip=False)
50 | decoder = SEANetDecoder(true_skip=False, final_activation='Tanh')
51 |
52 | x = torch.randn(1, 1, 24000)
53 | z = encoder(x)
54 | assert list(z.shape) == [1, 128, 75], z.shape
55 | y = decoder(z)
56 | assert y.shape == x.shape, (x.shape, y.shape)
57 |
58 | def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str):
59 | n_blocks = 0
60 | for layer in encoder.model:
61 | if isinstance(layer, StreamableConv1d):
62 | n_blocks += 1
63 | assert layer.conv.norm_type == ('none' if n_blocks <= n_disable_blocks else norm)
64 | elif isinstance(layer, SEANetResnetBlock):
65 | for resnet_layer in layer.block:
66 | if isinstance(resnet_layer, StreamableConv1d):
67 | # + 1 because n_blocks is only incremented at the downsampling conv that follows this block
68 | assert resnet_layer.conv.norm_type == ('none' if (n_blocks + 1) <= n_disable_blocks else norm)
69 |
70 | def test_encoder_disable_norm(self):
71 | n_residuals = [0, 1, 3]
72 | disable_blocks = [0, 1, 2, 3, 4, 5, 6]
73 | norms = ['weight_norm', 'none']
74 | for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms):
75 | encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm,
76 | disable_norm_outer_blocks=disable_blocks)
77 | self._check_encoder_blocks_norm(encoder, disable_blocks, norm)
78 |
79 | def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str):
80 | n_blocks = 0
81 | for layer in decoder.model:
82 | if isinstance(layer, StreamableConv1d):
83 | n_blocks += 1
84 | assert layer.conv.norm_type == ('none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm)
85 | elif isinstance(layer, StreamableConvTranspose1d):
86 | n_blocks += 1
87 | assert layer.convtr.norm_type == ('none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm)
88 | elif isinstance(layer, SEANetResnetBlock):
89 | for resnet_layer in layer.block:
90 | if isinstance(resnet_layer, StreamableConv1d):
91 | assert resnet_layer.conv.norm_type == (
92 |     'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm)
93 |
94 | def test_decoder_disable_norm(self):
95 | n_residuals = [0, 1, 3]
96 | disable_blocks = [0, 1, 2, 3, 4, 5, 6]
97 | norms = ['weight_norm', 'none']
98 | for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms):
99 | decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm,
100 | disable_norm_outer_blocks=disable_blocks)
101 | self._check_decoder_blocks_norm(decoder, disable_blocks, norm)
102 |
103 | def test_disable_norm_raises_exception(self):
104 | # Invalid disable_norm_outer_blocks values raise exceptions
105 | with pytest.raises(AssertionError):
106 | SEANetEncoder(disable_norm_outer_blocks=-1)
107 |
108 | with pytest.raises(AssertionError):
109 | SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
110 |
111 | with pytest.raises(AssertionError):
112 | SEANetDecoder(disable_norm_outer_blocks=-1)
113 |
114 | with pytest.raises(AssertionError):
115 | SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
116 |
--------------------------------------------------------------------------------
/tests/modules/test_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from itertools import product
8 |
9 | import pytest
10 | import torch
11 |
12 | from audiocraft.modules.transformer import StreamingMultiheadAttention, StreamingTransformer
13 |
14 |
15 | def test_transformer_causal_streaming():
16 | torch.manual_seed(1234)
17 |
18 | for context, custom in product([None, 10], [False, True]):
19 | # Test that causality and receptive fields are properly handled
20 | # by looking at the gradients.
21 | tr = StreamingTransformer(
22 | 16, 4, 1 if context else 2,
23 | causal=True, past_context=context, custom=custom,
24 | dropout=0.)
25 | steps = 20
26 | for k in [0, 10, 15, 19]:
27 | x = torch.randn(4, steps, 16, requires_grad=True)
28 | y = tr(x)
29 | y[:, k].abs().sum().backward()
30 | if k + 1 < steps:
31 | assert torch.allclose(x.grad[:, k + 1:], torch.tensor(0.)), x.grad[:, k + 1:].norm()
32 | assert not torch.allclose(x.grad[:, :k + 1], torch.tensor(0.)), x.grad[:, :k + 1].norm()
33 | if context is not None and k > context:
34 | limit = k - context - 1
35 | assert torch.allclose(x.grad[:, :limit],
36 | torch.tensor(0.)), x.grad[:, :limit].norm()
37 |
38 | # Now check that streaming gives the same result as batch eval.
39 | x = torch.randn(4, steps, 16)
40 | y = tr(x)
41 | ys = []
42 | with tr.streaming():
43 | for k in range(steps):
44 | chunk = x[:, k:k + 1, :]
45 | ys.append(tr(chunk))
46 | y_stream = torch.cat(ys, dim=1)
47 | delta = torch.norm(y_stream - y) / torch.norm(y)
48 | assert delta < 1e-6, delta
49 |
50 |
51 | def test_transformer_vs_pytorch():
52 | torch.manual_seed(1234)
53 | # Check that in the non-causal setting, we get the same result as
54 | # the PyTorch Transformer encoder.
55 | for custom in [False, True]:
56 | tr = StreamingTransformer(
57 | 16, 4, 2,
58 | causal=False, custom=custom, dropout=0., positional_scale=0.)
59 | layer = torch.nn.TransformerEncoderLayer(16, 4, dropout=0., batch_first=True)
60 | tr_ref = torch.nn.TransformerEncoder(layer, 2)
61 | tr.load_state_dict(tr_ref.state_dict())
62 |
63 | x = torch.randn(4, 20, 16)
64 | y = tr(x)
65 | y2 = tr_ref(x)
66 | delta = torch.norm(y2 - y) / torch.norm(y)
67 | assert delta < 1e-6, delta
68 |
69 |
70 | def test_streaming_api():
71 | tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0.)
72 | tr.eval()
73 | steps = 12
74 | x = torch.randn(1, steps, 16)
75 |
76 | with torch.no_grad():
77 | with tr.streaming():
78 | _ = tr(x[:, :1])
79 | state = {k: v.clone() for k, v in tr.get_streaming_state().items()}
80 | y = tr(x[:, 1:2])
81 | tr.set_streaming_state(state)
82 | y2 = tr(x[:, 1:2])
83 | assert torch.allclose(y, y2), (y - y2).norm()
84 | assert tr.flush() is None
85 |
86 |
87 | def test_memory_efficient():
88 | torch.manual_seed(1234)
89 | tr = StreamingTransformer(
90 | 16, 4, 2, custom=True, dropout=0., layer_scale=0.1)
91 | tr_mem_efficient = StreamingTransformer(
92 | 16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1)
93 | tr_mem_efficient.load_state_dict(tr.state_dict())
94 | tr.eval()
95 | steps = 12
96 | x = torch.randn(3, steps, 16)
97 |
98 | with torch.no_grad():
99 | y = tr(x)
100 | y2 = tr_mem_efficient(x)
101 | assert torch.allclose(y, y2), (y - y2).norm()
102 |
103 |
104 | def test_attention_as_float32():
105 | torch.manual_seed(1234)
106 | cases = [
107 | {'custom': True},
108 | {'custom': False},
109 | ]
110 | for case in cases:
111 | tr = StreamingTransformer(16, 4, 2, dropout=0., dtype=torch.bfloat16, **case)
112 | tr_float32 = StreamingTransformer(
113 | 16, 4, 2, dropout=0., attention_as_float32=True, dtype=torch.bfloat16, **case)
114 | if not case['custom']:
115 | # we are not using autocast here because it doesn't really
116 | # work as expected on CPU, so we have to manually cast the weights of the MHA.
117 | for layer in tr_float32.layers:
118 | layer.self_attn.mha.to(torch.float32)
119 | tr_float32.load_state_dict(tr.state_dict())
120 | steps = 12
121 | x = torch.randn(3, steps, 16, dtype=torch.bfloat16)
122 |
123 | with torch.no_grad():
124 | y = tr(x)
125 | y2 = tr_float32(x)
126 | assert not torch.allclose(y, y2), (y - y2).norm()
127 |
128 |
129 | @torch.no_grad()
130 | def test_streaming_memory_efficient():
131 | torch.manual_seed(1234)
132 | tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True)
133 | tr_mem_efficient = StreamingTransformer(
134 | 16, 4, 2, dropout=0., memory_efficient=True, causal=True)
135 | tr.load_state_dict(tr_mem_efficient.state_dict())
136 | tr.eval()
137 | tr_mem_efficient.eval()
138 | steps = 12
139 | x = torch.randn(3, steps, 16)
140 |
141 | ref = tr(x)
142 |
143 | with tr_mem_efficient.streaming():
144 | outs = []
145 | # frame_sizes = [2] + [1] * (steps - 2)
146 | frame_sizes = [1] * steps
147 |
148 | for frame_size in frame_sizes:
149 | frame = x[:, :frame_size]
150 | x = x[:, frame_size:]
151 | outs.append(tr_mem_efficient(frame))
152 |
153 | out = torch.cat(outs, dim=1)
154 | delta = torch.norm(out - ref) / torch.norm(out)
155 | assert delta < 1e-6, delta
156 |
157 |
158 | def test_cross_attention():
159 | torch.manual_seed(1234)
160 | for norm_first in [True, False]:
161 | m = StreamingTransformer(
162 | 16, 4, 2, cross_attention=False, norm_first=norm_first, dropout=0., custom=True)
163 | m_cross = StreamingTransformer(
164 | 16, 4, 2, cross_attention=True, norm_first=norm_first, dropout=0., custom=True)
165 | m_cross.load_state_dict(m.state_dict(), strict=False)
166 | x = torch.randn(2, 5, 16)
167 | cross_x = torch.randn(2, 3, 16)
168 | y_ref = m(x)
169 | y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x)
170 | # With norm_first, the two should be exactly the same,
171 | # but with norm_first=False, we get two normalizations in a row
172 | # and the epsilon value leads to a tiny difference.
173 | atol = 0. if norm_first else 1e-6
174 | print((y_ref - y_cross_zero).norm() / y_ref.norm())
175 | assert torch.allclose(y_ref, y_cross_zero, atol=atol)
176 |
177 | # We now expect a difference even with a generous atol of 1e-2.
178 | y_cross = m_cross(x, cross_attention_src=cross_x)
179 | assert not torch.allclose(y_cross, y_cross_zero, atol=1e-2)
180 |
181 | # each misuse must raise on its own; a single raises block would stop
182 | # at the first failing call
183 | with pytest.raises(AssertionError):
184 |     _ = m_cross(x)
185 | with pytest.raises(AssertionError):
186 |     _ = m(x, cross_attention_src=cross_x)
184 |
185 |
186 | def test_cross_attention_compat():
187 | torch.manual_seed(1234)
188 | num_heads = 2
189 | dim = num_heads * 64
190 | with pytest.raises(AssertionError):
191 | StreamingMultiheadAttention(dim, num_heads, causal=True, cross_attention=True)
192 |
193 | cross_attn = StreamingMultiheadAttention(
194 | dim, num_heads, dropout=0, cross_attention=True, custom=True)
195 | ref_attn = torch.nn.MultiheadAttention(dim, num_heads, dropout=0, batch_first=True)
196 |
197 | # We can load the regular attention state dict,
198 | # so we keep compatibility when loading old checkpoints.
199 | cross_attn.load_state_dict(ref_attn.state_dict())
200 |
201 | queries = torch.randn(3, 7, dim)
202 | keys = torch.randn(3, 9, dim)
203 | values = torch.randn(3, 9, dim)
204 |
205 | y = cross_attn(queries, keys, values)[0]
206 | y_ref = ref_attn(queries, keys, values)[0]
207 | assert torch.allclose(y, y_ref, atol=1e-7)
208 |
209 | # Now let's check that streaming is working properly.
210 | with cross_attn.streaming():
211 | ys = []
212 | for step in range(queries.shape[1]):
213 | ys.append(cross_attn(queries[:, step: step + 1], keys, values)[0])
214 | y_streaming = torch.cat(ys, dim=1)
215 | assert torch.allclose(y_streaming, y, atol=1e-7)
216 |
217 |
218 | def test_repeat_kv():
219 | torch.manual_seed(1234)
220 | num_heads = 8
221 | kv_repeat = 4
222 | dim = num_heads * 64
223 | with pytest.raises(AssertionError):
224 | mha = StreamingMultiheadAttention(
225 | dim, num_heads, causal=True, kv_repeat=kv_repeat, cross_attention=True)
226 | mha = StreamingMultiheadAttention(
227 | dim, num_heads, causal=True, kv_repeat=kv_repeat)
228 | mha = StreamingMultiheadAttention(
229 | dim, num_heads, causal=True, kv_repeat=kv_repeat, custom=True)
230 | x = torch.randn(4, 18, dim)
231 | y = mha(x, x, x)[0]
232 | assert x.shape == y.shape
233 |
234 |
235 | def test_qk_layer_norm():
236 | torch.manual_seed(1234)
237 | tr = StreamingTransformer(
238 | 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, bias_attn=False)
239 | steps = 12
240 | x = torch.randn(3, steps, 16)
241 | y = tr(x)
242 |
243 | tr = StreamingTransformer(
244 | 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, cross_attention=True)
245 | z = torch.randn(3, 21, 16)
246 | y = tr(x, cross_attention_src=z)
247 | assert y.shape == x.shape
248 |
--------------------------------------------------------------------------------
/tests/quantization/test_vq.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 |
9 | from audiocraft.quantization.vq import ResidualVectorQuantizer
10 |
11 |
12 | class TestResidualVectorQuantizer:
13 |
14 | def test_rvq(self):
15 | x = torch.randn(1, 16, 2048)
16 | vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
17 | res = vq(x, 1.)
18 | assert res.x.shape == torch.Size([1, 16, 2048])
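19 | # Beyond the reconstruction `res.x` checked above, the result is expected to also
20 | # carry the discrete codes (assumed shape [1, n_q, 2048] here) and an estimated
21 | # bandwidth; those fields are assumptions and are not asserted in this test.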
19 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------