├── .github ├── actions │ └── audiocraft_build │ │ └── action.yml └── workflows │ ├── audiocraft_docs.yml │ ├── audiocraft_linter.yml │ └── audiocraft_tests.yml ├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── LICENSE_weights ├── MANIFEST.in ├── MODEL_CARD.md ├── Makefile ├── README.md ├── app.py ├── app_batched.py ├── assets ├── bach.mp3 └── bolero_ravel.mp3 ├── audiocraft ├── __init__.py ├── data │ ├── __init__.py │ ├── audio.py │ ├── audio_dataset.py │ ├── audio_utils.py │ └── zip.py ├── models │ ├── __init__.py │ ├── builders.py │ ├── encodec.py │ ├── lm.py │ ├── loaders.py │ └── musicgen.py ├── modules │ ├── __init__.py │ ├── activations.py │ ├── codebooks_patterns.py │ ├── conditioners.py │ ├── conv.py │ ├── lstm.py │ ├── rope.py │ ├── seanet.py │ ├── streaming.py │ └── transformer.py ├── py.typed ├── quantization │ ├── __init__.py │ ├── base.py │ ├── core_vq.py │ └── vq.py └── utils │ ├── __init__.py │ ├── autocast.py │ ├── export.py │ ├── notebook.py │ └── utils.py ├── demo.ipynb ├── mypy.ini ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── common_utils ├── __init__.py ├── temp_utils.py └── wav_utils.py ├── data ├── __init__.py ├── test_audio.py ├── test_audio_dataset.py └── test_audio_utils.py ├── models ├── test_encodec_model.py └── test_musicgen.py ├── modules ├── __init__.py ├── test_codebooks_patterns.py ├── test_conv.py ├── test_lstm.py ├── test_rope.py ├── test_seanet.py └── test_transformer.py ├── quantization └── test_vq.py └── utils └── __init__.py /.github/actions/audiocraft_build/action.yml: -------------------------------------------------------------------------------- 1 | name: audiocraft_build 2 | description: 'Build audiocraft env.' 3 | runs: 4 | using: "composite" 5 | steps: 6 | - uses: actions/setup-python@v2 7 | with: 8 | python-version: 3.8 9 | - uses: actions/cache@v2 10 | id: cache 11 | with: 12 | path: env 13 | key: audiocraft_env-${{ hashFiles('**/requirements.txt') }} 14 | 15 | - if: ${{ steps.cache.outputs.cache-hit != 'true' }} 16 | name: Install dependencies 17 | shell: bash 18 | run: | 19 | sudo apt-get update 20 | sudo apt-get install libsndfile1-dev ffmpeg 21 | python3 -m venv env 22 | . env/bin/activate 23 | python -m pip install --upgrade pip 24 | pip install -e '.[dev]' 25 | - name: System Dependencies 26 | shell: bash 27 | run: | 28 | sudo apt-get update 29 | sudo apt-get install libsndfile1-dev ffmpeg 30 | -------------------------------------------------------------------------------- /.github/workflows/audiocraft_docs.yml: -------------------------------------------------------------------------------- 1 | name: audiocraft_docs 2 | on: 3 | push: 4 | branches: [ main ] 5 | 6 | jobs: 7 | run_docs: 8 | name: Run docs 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: ./.github/actions/audiocraft_build 13 | - name: Config git 14 | run: | 15 | git config --global user.email "defossez@fb.com" 16 | git config --global user.name "Alexandre Défossez (autodoc)" 17 | 18 | - name: Reset branch 19 | run: | 20 | git branch -f gh-docs main 21 | git checkout gh-docs 22 | 23 | - name: Make docs 24 | run: | 25 | . 
env/bin/activate 26 | make docs 27 | git add -f docs 28 | git commit -m docs 29 | 30 | - name: Push branch 31 | run: | 32 | git push -f -u origin gh-docs 33 | -------------------------------------------------------------------------------- /.github/workflows/audiocraft_linter.yml: -------------------------------------------------------------------------------- 1 | name: audiocraft_linter 2 | on: 3 | push: 4 | branches: [ main ] 5 | pull_request: 6 | branches: [ main ] 7 | 8 | jobs: 9 | run_linter: 10 | name: Run linter 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: ./.github/actions/audiocraft_build 15 | - run: | 16 | . env/bin/activate 17 | make linter 18 | -------------------------------------------------------------------------------- /.github/workflows/audiocraft_tests.yml: -------------------------------------------------------------------------------- 1 | name: audiocraft_tests 2 | on: 3 | push: 4 | branches: [ main ] 5 | pull_request: 6 | branches: [ main ] 7 | 8 | jobs: 9 | run_tests: 10 | name: Run tests 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: ./.github/actions/audiocraft_build 15 | - run: | 16 | . env/bin/activate 17 | make tests 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # macOS dir files 10 | .DS_Store 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | .ipynb_checkpoints 31 | 32 | # Tests and linter 33 | .pytest_cache/ 34 | .mypy_cache/ 35 | .coverage 36 | 37 | # docs 38 | /docs 39 | 40 | # dotenv 41 | .env 42 | .envrc 43 | 44 | # virtualenv 45 | .venv 46 | venv/ 47 | ENV/ 48 | 49 | # personal notebooks & scripts 50 | */local_scripts 51 | */notes 52 | .vscode/ 53 | /notebooks 54 | /local_scripts 55 | /notes 56 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 6 | 7 | ## [0.0.2a] - TBD 8 | 9 | Improved demo, fixed top p (thanks @jnordberg). 10 | 11 | Compressor tanh on output to avoid clipping with some style (especially piano). 12 | Now repeating the conditioning periodically if it is too short. 13 | 14 | More options when launching Gradio app locally (thanks @ashleykleynhans). 15 | 16 | ## [0.0.1] - 2023-06-09 17 | 18 | Initial release, with model evaluation only. 
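The tanh compressor mentioned in 0.0.2a corresponds to the `loudness_compressor` flag of `audio_write` / `normalize_loudness` in `audiocraft/data/audio_utils.py` (included further below). A minimal sketch of the idea, using a dummy tensor in place of real model output:

```python
import torch

def soft_clip(wav: torch.Tensor) -> torch.Tensor:
    # tanh is close to the identity for small values and saturates smoothly
    # toward +/-1, avoiding the harsh artifacts of a hard clamp on peaky
    # material such as piano.
    return torch.tanh(wav)

wav = torch.randn(1, 32000) * 0.5  # dummy 1-second mono signal at 32 kHz
out = soft_clip(wav)
assert out.abs().max() <= 1.0
```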
19 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 
67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Audiocraft 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | 8 | Audiocraft is the implementation of a research paper. 9 | Therefore, we do not plan on accepting many pull requests for new features. 10 | We certainly welcome them for bug fixes. 11 | 12 | 1. Fork the repo and create your branch from `main`. 13 | 2. If you've added code that should be tested, add tests. 14 | 3. If you've changed APIs, update the documentation. 15 | 4. Ensure the test suite passes. 16 | 5. Make sure your code lints. 17 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 18 | 19 | ## Contributor License Agreement ("CLA") 20 | In order to accept your pull request, we need you to submit a CLA. You only need 21 | to do this once to work on any of Meta's open source projects. 22 | 23 | Complete your CLA here: 24 | 25 | ## Issues 26 | We use GitHub issues to track public bugs. Please ensure your description is 27 | clear and has sufficient instructions to be able to reproduce the issue. 28 | 29 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 30 | disclosure of security bugs. In those cases, please go through the process 31 | outlined on that page and do not file a public issue. 32 | 33 | ## License 34 | By contributing to encodec, you agree that your contributions will be licensed 35 | under the LICENSE file in the root directory of this source tree. 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include Makefile 2 | include LICENSE 3 | include LICENSE_weights 4 | include *.md 5 | include *.ini 6 | include requirements.txt 7 | include audiocraft/py.typed 8 | include assets/*.mp3 9 | -------------------------------------------------------------------------------- /MODEL_CARD.md: -------------------------------------------------------------------------------- 1 | # MusicGen Model Card 2 | 3 | ## Model details 4 | 5 | **Organization developing the model:** The FAIR team of Meta AI. 6 | 7 | **Model date:** MusicGen was trained between April 2023 and May 2023. 8 | 9 | **Model version:** This is the version 1 of the model. 10 | 11 | **Model type:** MusicGen consists of an EnCodec model for audio tokenization, an auto-regressive language model based on the transformer architecture for music modeling. The model comes in different sizes: 300M, 1.5B and 3.3B parameters ; and two variants: a model trained for text-to-music generation task and a model trained for melody-guided music generation. 12 | 13 | **Paper or resources for more information:** More information can be found in the paper [Simple and Controllable Music Generation][arxiv]. 14 | 15 | **Citation details** See [our paper][arxiv] 16 | 17 | **License** Code is released under MIT, model weights are released under CC-BY-NC 4.0. 18 | 19 | **Where to send questions or comments about the model:** Questions and comments about MusicGen can be sent via the [Github repository](https://github.com/facebookresearch/audiocraft) of the project, or by opening an issue. 20 | 21 | ## Intended use 22 | **Primary intended use:** The primary use of MusicGen is research on AI-based music generation, including: 23 | 24 | - Research efforts, such as probing and better understanding the limitations of generative models to further improve the state of science 25 | - Generation of music guided by text or melody to understand current abilities of generative AI models by machine learning amateurs 26 | 27 | **Primary intended users:** The primary intended users of the model are researchers in audio, machine learning and artificial intelligence, as well as amateur seeking to better understand those models. 28 | 29 | **Out-of-scope use cases** The model should not be used on downstream applications without further risk evaluation and mitigation. The model should not be used to intentionally create or disseminate music pieces that create hostile or alienating environments for people. This includes generating music that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. 
30 | 31 | ## Metrics 32 | 33 | **Models performance measures:** We used the following objective measure to evaluate the model on a standard music benchmark: 34 | 35 | - Frechet Audio Distance computed on features extracted from a pre-trained audio classifier (VGGish) 36 | - Kullback-Leibler Divergence on label distributions extracted from a pre-trained audio classifier (PaSST) 37 | - CLAP Score between audio embedding and text embedding extracted from a pre-trained CLAP model 38 | 39 | Additionally, we run qualitative studies with human participants, evaluating the performance of the model with the following axes: 40 | 41 | - Overall quality of the music samples; 42 | - Text relevance to the provided text input; 43 | - Adherence to the melody for melody-guided music generation. 44 | 45 | More details on performance measures and human studies can be found in the paper. 46 | 47 | **Decision thresholds:** Not applicable. 48 | 49 | ## Evaluation datasets 50 | 51 | The model was evaluated on the [MusicCaps benchmark](https://www.kaggle.com/datasets/googleai/musiccaps) and on an in-domain held-out evaluation set, with no artist overlap with the training set. 52 | 53 | ## Training datasets 54 | 55 | The model was trained on licensed data using the following sources: the [Meta Music Initiative Sound Collection](https://www.fb.com/sound), [Shutterstock music collection](https://www.shutterstock.com/music) and the [Pond5 music collection](https://www.pond5.com/). See the paper for more details about the training set and corresponding preprocessing. 56 | 57 | ## Quantitative analysis 58 | 59 | More information can be found in the paper [Simple and Controllable Music Generation][arxiv], in the Experimental Setup section. 60 | 61 | ## Limitations and biases 62 | 63 | **Data:** The data sources used to train the model are created by music professionals and covered by legal agreements with the right holders. The model is trained on 20K hours of data, we believe that scaling the model on larger datasets can further improve the performance of the model. 64 | 65 | **Mitigations:** Vocals have been removed from the data source using corresponding tags, and then using using a state-of-the-art music source separation method, namely using the open source [Hybrid Transformer for Music Source Separation](https://github.com/facebookresearch/demucs) (HT-Demucs). 66 | 67 | **Limitations:** 68 | 69 | - The model is not able to generate realistic vocals. 70 | - The model has been trained with English descriptions and will not perform as well in other languages. 71 | - The model does not perform equally well for all music styles and cultures. 72 | - The model sometimes generates end of songs, collapsing to silence. 73 | - It is sometimes difficult to assess what types of text descriptions provide the best generations. Prompt engineering may be required to obtain satisfying results. 74 | 75 | **Biases:** The source of data is potentially lacking diversity and all music cultures are not equally represented in the dataset. The model may not perform equally well on the wide variety of music genres that exists. The generated samples from the model will reflect the biases from the training data. Further work on this model should include methods for balanced and just representations of cultures, for example, by scaling the training data to be both diverse and inclusive. 
76 | 77 | **Risks and harms:** Biases and limitations of the model may lead to generation of samples that may be considered as biased, inappropriate or offensive. We believe that providing the code to reproduce the research and train new models will allow to broaden the application to new and more representative data. 78 | 79 | **Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks. 80 | 81 | [arxiv]: https://arxiv.org/abs/2306.05284 82 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | default: linter tests 2 | 3 | install: 4 | pip install -U pip 5 | pip install -U -e '.[dev]' 6 | 7 | linter: 8 | flake8 audiocraft && mypy audiocraft 9 | flake8 tests && mypy tests 10 | 11 | tests: 12 | coverage run -m pytest tests 13 | coverage report --include 'audiocraft/*' 14 | 15 | docs: 16 | pdoc3 --html -o docs -f audiocraft 17 | 18 | dist: 19 | python setup.py sdist 20 | 21 | .PHONY: linter tests docs dist 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Audiocraft 2 | ![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg) 3 | ![linter badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_linter/badge.svg) 4 | ![tests badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_tests/badge.svg) 5 | 6 | Audiocraft is a PyTorch library for deep learning research on audio generation. At the moment, it contains the code for MusicGen, a state-of-the-art controllable text-to-music model. 7 | 8 | ## MusicGen 9 | 10 | Audiocraft provides the code and models for MusicGen, [a simple and controllable model for music generation][arxiv]. MusicGen is a single stage auto-regressive 11 | Transformer model trained over a 32kHz EnCodec tokenizer with 4 codebooks sampled at 50 Hz. Unlike existing methods like [MusicLM](https://arxiv.org/abs/2301.11325), MusicGen doesn't require a self-supervised semantic representation, and it generates 12 | all 4 codebooks in one pass. By introducing a small delay between the codebooks, we show we can predict 13 | them in parallel, thus having only 50 auto-regressive steps per second of audio. 14 | Check out our [sample page][musicgen_samples] or test the available demo! 15 | 16 | 17 | Open In Colab 18 | 19 | 20 | Open in HugginFace 21 | 22 |
23 | 24 | We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data. 25 | 26 | ## Installation 27 | Audiocraft requires Python 3.9, PyTorch 2.0.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following: 28 | 29 | ```shell 30 | # Best to make sure you have torch installed first, in particular before installing xformers. 31 | # Don't run this if you already have PyTorch installed. 32 | pip install 'torch>=2.0' 33 | # Then proceed to one of the following 34 | pip install -U audiocraft # stable release 35 | pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge 36 | pip install -e . # or if you cloned the repo locally 37 | ``` 38 | 39 | ## Usage 40 | We offer a number of way to interact with MusicGen: 41 | 1. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally, or use the provided [colab notebook](https://colab.research.google.com/drive/1fxGqfg96RBUvGxZ1XXN07s3DthrKUl4-?usp=sharing). 42 | 2. You can use the gradio demo locally by running `python app.py`. 43 | 3. A demo is also available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support). 44 | 4. Finally, you can run the [Gradio demo with a Colab GPU](https://colab.research.google.com/drive/1-Xe9NCdIs2sCUbiSmwHXozK6AAhMm7_i?usp=sharing), 45 | as adapted from [@camenduru Colab](https://github.com/camenduru/MusicGen-colab). 46 | 47 | ## API 48 | 49 | We provide a simple API and 4 pre-trained models. The pre trained models are: 50 | - `small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small) 51 | - `medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium) 52 | - `melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody) 53 | - `large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large) 54 | 55 | We observe the best trade-off between quality and compute with the `medium` or `melody` model. 56 | In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller 57 | GPUs will be able to generate short sequences, or longer sequences with the `small` model. 58 | 59 | **Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using newer version of `torchaudio`. 60 | You can install it with: 61 | ``` 62 | apt-get install ffmpeg 63 | ``` 64 | 65 | See after a quick example for using the API. 66 | 67 | ```python 68 | import torchaudio 69 | from audiocraft.models import MusicGen 70 | from audiocraft.data.audio import audio_write 71 | 72 | model = MusicGen.get_pretrained('melody') 73 | model.set_generation_params(duration=8) # generate 8 seconds. 74 | wav = model.generate_unconditional(4) # generates 4 unconditional audio samples 75 | descriptions = ['happy rock', 'energetic EDM', 'sad jazz'] 76 | wav = model.generate(descriptions) # generates 3 samples. 77 | 78 | melody, sr = torchaudio.load('./assets/bach.mp3') 79 | # generates using the melody from the given audio and the provided descriptions. 
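# Note: torchaudio.load returns `melody` with shape [channels, time]; `melody[None]`
# adds a batch dimension, and `.expand(3, -1, -1)` reuses the same melody for each of
# the 3 descriptions above.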
80 | wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr) 81 | 82 | for idx, one_wav in enumerate(wav): 83 | # Will save under {idx}.wav, with loudness normalization at -14 db LUFS. 84 | audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True) 85 | ``` 86 | 87 | 88 | ## Model Card 89 | 90 | See [the model card page](./MODEL_CARD.md). 91 | 92 | ## FAQ 93 | 94 | #### Will the training code be released? 95 | 96 | Yes. We will soon release the training code for MusicGen and EnCodec. 97 | 98 | 99 | #### I need help on Windows 100 | 101 | @FurkanGozukara made a complete tutorial for [Audiocraft/MusicGen on Windows](https://youtu.be/v-YpvPkhdO4) 102 | 103 | 104 | ## Citation 105 | ``` 106 | @article{copet2023simple, 107 | title={Simple and Controllable Music Generation}, 108 | author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez}, 109 | year={2023}, 110 | journal={arXiv preprint arXiv:2306.05284}, 111 | } 112 | ``` 113 | 114 | ## License 115 | * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE). 116 | * The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights). 117 | 118 | [arxiv]: https://arxiv.org/abs/2306.05284 119 | [musicgen_samples]: https://ai.honu.io/papers/musicgen/ 120 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 
7 | """ 8 | 9 | from tempfile import NamedTemporaryFile 10 | import argparse 11 | import torch 12 | import gradio as gr 13 | import os 14 | from audiocraft.models import MusicGen 15 | from audiocraft.data.audio import audio_write 16 | 17 | MODEL = None 18 | IS_SHARED_SPACE = "musicgen/MusicGen" in os.environ.get('SPACE_ID', '') 19 | 20 | 21 | def load_model(version): 22 | print("Loading model", version) 23 | return MusicGen.get_pretrained(version) 24 | 25 | 26 | def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef): 27 | global MODEL 28 | topk = int(topk) 29 | if MODEL is None or MODEL.name != model: 30 | MODEL = load_model(model) 31 | 32 | if duration > MODEL.lm.cfg.dataset.segment_duration: 33 | raise gr.Error("MusicGen currently supports durations of up to 30 seconds!") 34 | MODEL.set_generation_params( 35 | use_sampling=True, 36 | top_k=topk, 37 | top_p=topp, 38 | temperature=temperature, 39 | cfg_coef=cfg_coef, 40 | duration=duration, 41 | ) 42 | 43 | if melody: 44 | sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0) 45 | print(melody.shape) 46 | if melody.dim() == 2: 47 | melody = melody[None] 48 | melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)] 49 | output = MODEL.generate_with_chroma( 50 | descriptions=[text], 51 | melody_wavs=melody, 52 | melody_sample_rate=sr, 53 | progress=False 54 | ) 55 | else: 56 | output = MODEL.generate(descriptions=[text], progress=False) 57 | 58 | output = output.detach().cpu().float()[0] 59 | with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file: 60 | audio_write( 61 | file.name, output, MODEL.sample_rate, strategy="loudness", 62 | loudness_headroom_db=16, loudness_compressor=True, add_suffix=False) 63 | waveform_video = gr.make_waveform(file.name) 64 | return waveform_video 65 | 66 | 67 | def ui(**kwargs): 68 | with gr.Blocks() as interface: 69 | gr.Markdown( 70 | """ 71 | # MusicGen 72 | This is your private demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation 73 | presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284) 74 | """ 75 | ) 76 | if IS_SHARED_SPACE: 77 | gr.Markdown(""" 78 | ⚠ This Space doesn't work in this shared UI ⚠ 79 | 80 | 81 | Duplicate Space 82 | to use it privately, or use the public demo 83 | """) 84 | with gr.Row(): 85 | with gr.Column(): 86 | with gr.Row(): 87 | text = gr.Text(label="Input Text", interactive=True) 88 | melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True) 89 | with gr.Row(): 90 | submit = gr.Button("Submit") 91 | with gr.Row(): 92 | model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True) 93 | with gr.Row(): 94 | duration = gr.Slider(minimum=1, maximum=30, value=10, label="Duration", interactive=True) 95 | with gr.Row(): 96 | topk = gr.Number(label="Top-k", value=250, interactive=True) 97 | topp = gr.Number(label="Top-p", value=0, interactive=True) 98 | temperature = gr.Number(label="Temperature", value=1.0, interactive=True) 99 | cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True) 100 | with gr.Column(): 101 | output = gr.Video(label="Generated Music") 102 | submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output]) 103 | gr.Examples( 104 | fn=predict, 105 | examples=[ 106 | [ 107 | "An 80s driving pop song with 
heavy drums and synth pads in the background", 108 | "./assets/bach.mp3", 109 | "melody" 110 | ], 111 | [ 112 | "A cheerful country song with acoustic guitars", 113 | "./assets/bolero_ravel.mp3", 114 | "melody" 115 | ], 116 | [ 117 | "90s rock song with electric guitar and heavy drums", 118 | None, 119 | "medium" 120 | ], 121 | [ 122 | "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions", 123 | "./assets/bach.mp3", 124 | "melody" 125 | ], 126 | [ 127 | "lofi slow bpm electro chill with organic samples", 128 | None, 129 | "medium", 130 | ], 131 | ], 132 | inputs=[text, melody, model], 133 | outputs=[output] 134 | ) 135 | gr.Markdown( 136 | """ 137 | ### More details 138 | 139 | The model will generate a short music extract based on the description you provided. 140 | You can generate up to 30 seconds of audio. 141 | 142 | We present 4 model variations: 143 | 1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only. 144 | 2. Small -- a 300M transformer decoder conditioned on text only. 145 | 3. Medium -- a 1.5B transformer decoder conditioned on text only. 146 | 4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.) 147 | 148 | When using `melody`, ou can optionaly provide a reference audio from 149 | which a broad melody will be extracted. The model will then try to follow both the description and melody provided. 150 | 151 | You can also use your own GPU or a Google Colab by following the instructions on our repo. 152 | See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft) 153 | for more details. 154 | """ 155 | ) 156 | 157 | # Show the interface 158 | launch_kwargs = {} 159 | username = kwargs.get('username') 160 | password = kwargs.get('password') 161 | server_port = kwargs.get('server_port', 0) 162 | inbrowser = kwargs.get('inbrowser', False) 163 | share = kwargs.get('share', False) 164 | server_name = kwargs.get('listen') 165 | 166 | launch_kwargs['server_name'] = server_name 167 | 168 | if username and password: 169 | launch_kwargs['auth'] = (username, password) 170 | if server_port > 0: 171 | launch_kwargs['server_port'] = server_port 172 | if inbrowser: 173 | launch_kwargs['inbrowser'] = inbrowser 174 | if share: 175 | launch_kwargs['share'] = share 176 | 177 | interface.queue().launch(**launch_kwargs, max_threads=1) 178 | 179 | 180 | if __name__ == "__main__": 181 | parser = argparse.ArgumentParser() 182 | parser.add_argument( 183 | '--listen', 184 | type=str, 185 | default='127.0.0.1', 186 | help='IP to listen on for connections to Gradio', 187 | ) 188 | parser.add_argument( 189 | '--username', type=str, default='', help='Username for authentication' 190 | ) 191 | parser.add_argument( 192 | '--password', type=str, default='', help='Password for authentication' 193 | ) 194 | parser.add_argument( 195 | '--server_port', 196 | type=int, 197 | default=0, 198 | help='Port to run the server listener on', 199 | ) 200 | parser.add_argument( 201 | '--inbrowser', action='store_true', help='Open in browser' 202 | ) 203 | parser.add_argument( 204 | '--share', action='store_true', help='Share the gradio UI' 205 | ) 206 | 207 | args = parser.parse_args() 208 | 209 | ui( 210 | username=args.username, 211 | password=args.password, 212 | inbrowser=args.inbrowser, 213 | server_port=args.server_port, 214 | share=args.share, 215 | listen=args.listen 216 | ) 217 | 
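The launcher above exposes the flags defined in its argparse block (`--listen`, `--username`, `--password`, `--server_port`, `--inbrowser`, `--share`). A rough sketch of driving it programmatically instead, assuming the file is importable as `app`:

```python
# Hypothetical programmatic launch, roughly equivalent to:
#   python app.py --listen 0.0.0.0 --server_port 7860 --inbrowser
from app import ui  # assumes the file above is saved/importable as app.py

ui(
    listen="0.0.0.0",   # accept non-local connections
    server_port=7860,   # fixed port instead of a random free one
    inbrowser=True,     # open a browser tab on launch
    username="",        # empty credentials keep authentication disabled
    password="",
    share=False,
)
```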
-------------------------------------------------------------------------------- /app_batched.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | """ 8 | 9 | from tempfile import NamedTemporaryFile 10 | import torch 11 | import gradio as gr 12 | from audiocraft.data.audio_utils import convert_audio 13 | from audiocraft.data.audio import audio_write 14 | from audiocraft.models import MusicGen 15 | 16 | 17 | MODEL = None 18 | 19 | 20 | def load_model(): 21 | print("Loading model") 22 | return MusicGen.get_pretrained("melody") 23 | 24 | 25 | def predict(texts, melodies): 26 | global MODEL 27 | if MODEL is None: 28 | MODEL = load_model() 29 | 30 | duration = 12 31 | MODEL.set_generation_params(duration=duration) 32 | 33 | print(texts, melodies) 34 | processed_melodies = [] 35 | 36 | target_sr = 32000 37 | target_ac = 1 38 | for melody in melodies: 39 | if melody is None: 40 | processed_melodies.append(None) 41 | else: 42 | sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t() 43 | if melody.dim() == 1: 44 | melody = melody[None] 45 | melody = melody[..., :int(sr * duration)] 46 | melody = convert_audio(melody, sr, target_sr, target_ac) 47 | processed_melodies.append(melody) 48 | 49 | outputs = MODEL.generate_with_chroma( 50 | descriptions=texts, 51 | melody_wavs=processed_melodies, 52 | melody_sample_rate=target_sr, 53 | progress=False 54 | ) 55 | 56 | outputs = outputs.detach().cpu().float() 57 | out_files = [] 58 | for output in outputs: 59 | with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file: 60 | audio_write( 61 | file.name, output, MODEL.sample_rate, strategy="loudness", 62 | loudness_headroom_db=16, loudness_compressor=True, add_suffix=False) 63 | waveform_video = gr.make_waveform(file.name) 64 | out_files.append(waveform_video) 65 | return [out_files] 66 | 67 | 68 | with gr.Blocks() as demo: 69 | gr.Markdown( 70 | """ 71 | # MusicGen 72 | 73 | This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation 74 | presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284). 75 |
76 | 77 | Duplicate Space 78 | for longer sequences, more control and no queue.
>
79 | """ 80 | ) 81 | with gr.Row(): 82 | with gr.Column(): 83 | with gr.Row(): 84 | text = gr.Text(label="Describe your music", lines=2, interactive=True) 85 | melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True) 86 | with gr.Row(): 87 | submit = gr.Button("Generate") 88 | with gr.Column(): 89 | output = gr.Video(label="Generated Music") 90 | submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=12) 91 | gr.Examples( 92 | fn=predict, 93 | examples=[ 94 | [ 95 | "An 80s driving pop song with heavy drums and synth pads in the background", 96 | "./assets/bach.mp3", 97 | ], 98 | [ 99 | "A cheerful country song with acoustic guitars", 100 | "./assets/bolero_ravel.mp3", 101 | ], 102 | [ 103 | "90s rock song with electric guitar and heavy drums", 104 | None, 105 | ], 106 | [ 107 | "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130", 108 | "./assets/bach.mp3", 109 | ], 110 | [ 111 | "lofi slow bpm electro chill with organic samples", 112 | None, 113 | ], 114 | ], 115 | inputs=[text, melody], 116 | outputs=[output] 117 | ) 118 | gr.Markdown(""" 119 | ### More details 120 | 121 | The model will generate 12 seconds of audio based on the description you provided. 122 | You can optionaly provide a reference audio from which a broad melody will be extracted. 123 | The model will then try to follow both the description and melody provided. 124 | All samples are generated with the `melody` model. 125 | 126 | You can also use your own GPU or a Google Colab by following the instructions on our repo. 127 | 128 | See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft) 129 | for more details. 130 | """) 131 | 132 | demo.queue(max_size=15).launch() 133 | -------------------------------------------------------------------------------- /assets/bach.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rkfg/audiocraft/6d70065e31c2fb422a76237e03740dd3b627de8d/assets/bach.mp3 -------------------------------------------------------------------------------- /assets/bolero_ravel.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rkfg/audiocraft/6d70065e31c2fb422a76237e03740dd3b627de8d/assets/bolero_ravel.mp3 -------------------------------------------------------------------------------- /audiocraft/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # flake8: noqa 8 | from . import data, modules, models 9 | 10 | __version__ = '0.0.2a1' 11 | -------------------------------------------------------------------------------- /audiocraft/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # flake8: noqa 8 | from . 
import audio, audio_dataset 9 | -------------------------------------------------------------------------------- /audiocraft/data/audio.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Audio IO methods are defined in this module (info, read, write), 9 | We rely on av library for faster read when possible, otherwise on torchaudio. 10 | """ 11 | 12 | from dataclasses import dataclass 13 | from pathlib import Path 14 | import logging 15 | import typing as tp 16 | 17 | import numpy as np 18 | import soundfile 19 | import torch 20 | from torch.nn import functional as F 21 | import torchaudio as ta 22 | 23 | import av 24 | 25 | from .audio_utils import f32_pcm, i16_pcm, normalize_audio 26 | 27 | 28 | _av_initialized = False 29 | 30 | 31 | def _init_av(): 32 | global _av_initialized 33 | if _av_initialized: 34 | return 35 | logger = logging.getLogger('libav.mp3') 36 | logger.setLevel(logging.ERROR) 37 | _av_initialized = True 38 | 39 | 40 | @dataclass(frozen=True) 41 | class AudioFileInfo: 42 | sample_rate: int 43 | duration: float 44 | channels: int 45 | 46 | 47 | def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: 48 | _init_av() 49 | with av.open(str(filepath)) as af: 50 | stream = af.streams.audio[0] 51 | sample_rate = stream.codec_context.sample_rate 52 | duration = float(stream.duration * stream.time_base) 53 | channels = stream.channels 54 | return AudioFileInfo(sample_rate, duration, channels) 55 | 56 | 57 | def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: 58 | info = soundfile.info(filepath) 59 | return AudioFileInfo(info.samplerate, info.duration, info.channels) 60 | 61 | 62 | def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo: 63 | # torchaudio no longer returns useful duration informations for some formats like mp3s. 64 | filepath = Path(filepath) 65 | if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info 66 | # ffmpeg has some weird issue with flac. 67 | return _soundfile_info(filepath) 68 | else: 69 | return _av_info(filepath) 70 | 71 | 72 | def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]: 73 | """FFMPEG-based audio file reading using PyAV bindings. 74 | Soundfile cannot read mp3 and av_read is more efficient than torchaudio. 75 | 76 | Args: 77 | filepath (str or Path): Path to audio file to read. 78 | seek_time (float): Time at which to start reading in the file. 79 | duration (float): Duration to read from the file. If set to -1, the whole file is read. 80 | Returns: 81 | Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate 82 | """ 83 | _init_av() 84 | with av.open(str(filepath)) as af: 85 | stream = af.streams.audio[0] 86 | sr = stream.codec_context.sample_rate 87 | num_frames = int(sr * duration) if duration >= 0 else -1 88 | frame_offset = int(sr * seek_time) 89 | # we need a small negative offset otherwise we get some edge artifact 90 | # from the mp3 decoder. 
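# The call below converts seconds into the stream's time_base units, which is what
# PyAV's seek() expects when a stream is passed; the extra audio introduced by the
# 0.1 s back-off is trimmed afterwards via `frame_offset` and `strip`.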
91 | af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream) 92 | frames = [] 93 | length = 0 94 | for frame in af.decode(streams=stream.index): 95 | current_offset = int(frame.rate * frame.pts * frame.time_base) 96 | strip = max(0, frame_offset - current_offset) 97 | buf = torch.from_numpy(frame.to_ndarray()) 98 | if buf.shape[0] != stream.channels: 99 | buf = buf.view(-1, stream.channels).t() 100 | buf = buf[:, strip:] 101 | frames.append(buf) 102 | length += buf.shape[1] 103 | if num_frames > 0 and length >= num_frames: 104 | break 105 | assert frames 106 | # If the above assert fails, it is likely because we seeked past the end of file point, 107 | # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp. 108 | # This will need proper debugging, in due time. 109 | wav = torch.cat(frames, dim=1) 110 | assert wav.shape[0] == stream.channels 111 | if num_frames > 0: 112 | wav = wav[:, :num_frames] 113 | return f32_pcm(wav), sr 114 | 115 | 116 | def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0., 117 | duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]: 118 | """Read audio by picking the most appropriate backend tool based on the audio format. 119 | 120 | Args: 121 | filepath (str or Path): Path to audio file to read. 122 | seek_time (float): Time at which to start reading in the file. 123 | duration (float): Duration to read from the file. If set to -1, the whole file is read. 124 | pad (bool): Pad output audio if not reaching expected duration. 125 | Returns: 126 | Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate. 127 | """ 128 | fp = Path(filepath) 129 | if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg 130 | # There is some bug with ffmpeg and reading flac 131 | info = _soundfile_info(filepath) 132 | frames = -1 if duration <= 0 else int(duration * info.sample_rate) 133 | frame_offset = int(seek_time * info.sample_rate) 134 | wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32) 135 | assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}" 136 | wav = torch.from_numpy(wav).t().contiguous() 137 | if len(wav.shape) == 1: 138 | wav = torch.unsqueeze(wav, 0) 139 | elif ( 140 | fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats() 141 | and duration <= 0 and seek_time == 0 142 | ): 143 | # Torchaudio is faster if we load an entire file at once. 144 | wav, sr = ta.load(fp) 145 | else: 146 | wav, sr = _av_read(filepath, seek_time, duration) 147 | if pad and duration > 0: 148 | expected_frames = int(duration * sr) 149 | wav = F.pad(wav, (0, expected_frames - wav.shape[-1])) 150 | return wav, sr 151 | 152 | 153 | def audio_write(stem_name: tp.Union[str, Path], 154 | wav: torch.Tensor, sample_rate: int, 155 | format: str = 'wav', mp3_rate: int = 320, normalize: bool = True, 156 | strategy: str = 'peak', peak_clip_headroom_db: float = 1, 157 | rms_headroom_db: float = 18, loudness_headroom_db: float = 14, 158 | loudness_compressor: bool = False, 159 | log_clipping: bool = True, make_parent_dir: bool = True, 160 | add_suffix: bool = True) -> Path: 161 | """Convenience function for saving audio to disk. Returns the filename the audio was written to. 162 | 163 | Args: 164 | stem_name (str or Path): Filename without extension which will be added automatically. 165 | format (str): Either "wav" or "mp3". 166 | mp3_rate (int): kbps when using mp3s. 
167 | normalize (bool): if `True` (default), normalizes according to the prescribed 168 | strategy (see after). If `False`, the strategy is only used in case clipping 169 | would happen. 170 | strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', 171 | i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square 172 | with extra headroom to avoid clipping. 'clip' just clips. 173 | peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. 174 | rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger 175 | than the `peak_clip` one to avoid further clipping. 176 | loudness_headroom_db (float): Target loudness for loudness normalization. 177 | loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'. 178 | when strategy is 'loudness'log_clipping (bool): If True, basic logging on stderr when clipping still 179 | occurs despite strategy (only for 'rms'). 180 | make_parent_dir (bool): Make parent directory if it doesn't exist. 181 | Returns: 182 | Path: Path of the saved audio. 183 | """ 184 | assert wav.dtype.is_floating_point, "wav is not floating point" 185 | if wav.dim() == 1: 186 | wav = wav[None] 187 | elif wav.dim() > 2: 188 | raise ValueError("Input wav should be at most 2 dimension.") 189 | assert wav.isfinite().all() 190 | wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db, 191 | rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping, 192 | sample_rate=sample_rate, stem_name=str(stem_name)) 193 | kwargs: dict = {} 194 | if format == 'mp3': 195 | suffix = '.mp3' 196 | kwargs.update({"compression": mp3_rate}) 197 | elif format == 'wav': 198 | wav = i16_pcm(wav) 199 | suffix = '.wav' 200 | kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16}) 201 | else: 202 | raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.") 203 | if not add_suffix: 204 | suffix = '' 205 | path = Path(str(stem_name) + suffix) 206 | if make_parent_dir: 207 | path.parent.mkdir(exist_ok=True, parents=True) 208 | try: 209 | ta.save(path, wav, sample_rate, **kwargs) 210 | except Exception: 211 | if path.exists(): 212 | # we do not want to leave half written files around. 213 | path.unlink() 214 | raise 215 | return path 216 | -------------------------------------------------------------------------------- /audiocraft/data/audio_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | import typing as tp 9 | 10 | import julius 11 | import torch 12 | import torchaudio 13 | 14 | 15 | def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor: 16 | """Convert audio to the given number of channels. 17 | 18 | Args: 19 | wav (torch.Tensor): Audio wave of shape [B, C, T]. 20 | channels (int): Expected number of channels as output. 21 | Returns: 22 | torch.Tensor: Downmixed or unchanged audio wave [B, C, T]. 23 | """ 24 | *shape, src_channels, length = wav.shape 25 | if src_channels == channels: 26 | pass 27 | elif channels == 1: 28 | # Case 1: 29 | # The caller asked 1-channel audio, and the stream has multiple 30 | # channels, downmix all channels. 
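# Averaging over dim=-2 (the channel axis of a [..., C, T] tensor) with keepdim=True
# yields a mono signal of shape [..., 1, T].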
31 | wav = wav.mean(dim=-2, keepdim=True) 32 | elif src_channels == 1: 33 | # Case 2: 34 | # The caller asked for multiple channels, but the input file has 35 | # a single channel, replicate the audio over all channels. 36 | wav = wav.expand(*shape, channels, length) 37 | elif src_channels >= channels: 38 | # Case 3: 39 | # The caller asked for multiple channels, and the input file has 40 | # more channels than requested. In that case return the first channels. 41 | wav = wav[..., :channels, :] 42 | else: 43 | # Case 4: What is a reasonable choice here? 44 | raise ValueError('The audio file has less channels than requested but is not mono.') 45 | return wav 46 | 47 | 48 | def convert_audio(wav: torch.Tensor, from_rate: float, 49 | to_rate: float, to_channels: int) -> torch.Tensor: 50 | """Convert audio to new sample rate and number of audio channels. 51 | """ 52 | wav = julius.resample_frac(wav, int(from_rate), int(to_rate)) 53 | wav = convert_audio_channels(wav, to_channels) 54 | return wav 55 | 56 | 57 | def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14, 58 | loudness_compressor: bool = False, energy_floor: float = 2e-3): 59 | """Normalize an input signal to a user loudness in dB LKFS. 60 | Audio loudness is defined according to the ITU-R BS.1770-4 recommendation. 61 | 62 | Args: 63 | wav (torch.Tensor): Input multichannel audio data. 64 | sample_rate (int): Sample rate. 65 | loudness_headroom_db (float): Target loudness of the output in dB LUFS. 66 | loudness_compressor (bool): Uses tanh for soft clipping. 67 | energy_floor (float): anything below that RMS level will not be rescaled. 68 | Returns: 69 | output (torch.Tensor): Loudness normalized output data. 70 | """ 71 | energy = wav.pow(2).mean().sqrt().item() 72 | if energy < energy_floor: 73 | return wav 74 | transform = torchaudio.transforms.Loudness(sample_rate) 75 | input_loudness_db = transform(wav).item() 76 | # calculate the gain needed to scale to the desired loudness level 77 | delta_loudness = -loudness_headroom_db - input_loudness_db 78 | gain = 10.0 ** (delta_loudness / 20.0) 79 | output = gain * wav 80 | if loudness_compressor: 81 | output = torch.tanh(output) 82 | assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt()) 83 | return output 84 | 85 | 86 | def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None: 87 | """Utility function to clip the audio with logging if specified.""" 88 | max_scale = wav.abs().max() 89 | if log_clipping and max_scale > 1: 90 | clamp_prob = (wav.abs() > 1).float().mean().item() 91 | print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):", 92 | clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr) 93 | wav.clamp_(-1, 1) 94 | 95 | 96 | def normalize_audio(wav: torch.Tensor, normalize: bool = True, 97 | strategy: str = 'peak', peak_clip_headroom_db: float = 1, 98 | rms_headroom_db: float = 18, loudness_headroom_db: float = 14, 99 | loudness_compressor: bool = False, log_clipping: bool = False, 100 | sample_rate: tp.Optional[int] = None, 101 | stem_name: tp.Optional[str] = None) -> torch.Tensor: 102 | """Normalize the audio according to the prescribed strategy (see after). 103 | 104 | Args: 105 | wav (torch.Tensor): Audio data. 106 | normalize (bool): if `True` (default), normalizes according to the prescribed 107 | strategy (see after). If `False`, the strategy is only used in case clipping 108 | would happen. 
109 | strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak', 110 | i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square 111 | with extra headroom to avoid clipping. 'clip' just clips. 112 | peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy. 113 | rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger 114 | than the `peak_clip` one to avoid further clipping. 115 | loudness_headroom_db (float): Target loudness for loudness normalization. 116 | loudness_compressor (bool): If True, uses tanh based soft clipping. 117 | log_clipping (bool): If True, basic logging on stderr when clipping still 118 | occurs despite strategy (only for 'rms'). 119 | sample_rate (int): Sample rate for the audio data (required for loudness). 120 | stem_name (Optional[str]): Stem name for clipping logging. 121 | Returns: 122 | torch.Tensor: Normalized audio. 123 | """ 124 | scale_peak = 10 ** (-peak_clip_headroom_db / 20) 125 | scale_rms = 10 ** (-rms_headroom_db / 20) 126 | if strategy == 'peak': 127 | rescaling = (scale_peak / wav.abs().max()) 128 | if normalize or rescaling < 1: 129 | wav = wav * rescaling 130 | elif strategy == 'clip': 131 | wav = wav.clamp(-scale_peak, scale_peak) 132 | elif strategy == 'rms': 133 | mono = wav.mean(dim=0) 134 | rescaling = scale_rms / mono.pow(2).mean().sqrt() 135 | if normalize or rescaling < 1: 136 | wav = wav * rescaling 137 | _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name) 138 | elif strategy == 'loudness': 139 | assert sample_rate is not None, "Loudness normalization requires sample rate." 140 | wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor) 141 | _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name) 142 | else: 143 | assert wav.abs().max() < 1 144 | assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'" 145 | return wav 146 | 147 | 148 | def f32_pcm(wav: torch.Tensor) -> torch.Tensor: 149 | """Convert audio to float 32 bits PCM format. 150 | """ 151 | if wav.dtype.is_floating_point: 152 | return wav 153 | else: 154 | assert wav.dtype == torch.int16 155 | return wav.float() / 2**15 156 | 157 | 158 | def i16_pcm(wav: torch.Tensor) -> torch.Tensor: 159 | """Convert audio to int 16 bits PCM format. 160 | 161 | ..Warning:: There exist many formula for doing this convertion. None are perfect 162 | due to the asymetry of the int16 range. One either have possible clipping, DC offset, 163 | or inconsistancies with f32_pcm. If the given wav doesn't have enough headroom, 164 | it is possible that `i16_pcm(f32_pcm)) != Identity`. 165 | """ 166 | if wav.dtype.is_floating_point: 167 | assert wav.abs().max() <= 1 168 | candidate = (wav * 2 ** 15).round() 169 | if candidate.max() >= 2 ** 15: # clipping would occur 170 | candidate = (wav * (2 ** 15 - 1)).round() 171 | return candidate.short() 172 | else: 173 | assert wav.dtype == torch.int16 174 | return wav 175 | -------------------------------------------------------------------------------- /audiocraft/data/zip.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import typing 8 | import zipfile 9 | 10 | from dataclasses import dataclass 11 | from functools import lru_cache 12 | from typing_extensions import Literal 13 | 14 | 15 | DEFAULT_SIZE = 32 16 | MODE = Literal['r', 'w', 'x', 'a'] 17 | 18 | 19 | @dataclass(order=True) 20 | class PathInZip: 21 | """Class for holding a path of file within a zip file. 22 | 23 | Args: 24 | path: The convention is : 25 | Let's assume there is a zip file /some/location/foo.zip 26 | and inside of it is a json file located at /data/file1.json, 27 | Then we expect path = "/some/location/foo.zip:/data/file1.json" 28 | """ 29 | 30 | INFO_PATH_SEP = ':' 31 | zip_path: str 32 | file_path: str 33 | 34 | def __init__(self, path: str) -> None: 35 | split_path = path.split(self.INFO_PATH_SEP) 36 | assert len(split_path) == 2 37 | self.zip_path, self.file_path = split_path 38 | 39 | @classmethod 40 | def from_paths(cls, zip_path: str, file_path: str): 41 | return cls(zip_path + cls.INFO_PATH_SEP + file_path) 42 | 43 | def __str__(self) -> str: 44 | return self.zip_path + self.INFO_PATH_SEP + self.file_path 45 | 46 | 47 | def _open_zip(path: str, mode: MODE = 'r'): 48 | return zipfile.ZipFile(path, mode) 49 | 50 | 51 | _cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip) 52 | 53 | 54 | def set_zip_cache_size(max_size: int): 55 | """Sets the maximal LRU caching for zip file opening. 56 | 57 | Args: 58 | max_size: the maximal LRU cache. 59 | """ 60 | global _cached_open_zip 61 | _cached_open_zip = lru_cache(max_size)(_open_zip) 62 | 63 | 64 | def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO: 65 | """Opens a file stored inside a zip and returns a file-like object. 66 | 67 | Args: 68 | path_in_zip: A PathInZip object representing the file to return a file-like object of. 69 | mode: The mode in which to open the file with. 70 | Returns: 71 | A file-like object for PathInZip. 72 | """ 73 | zf = _cached_open_zip(path_in_zip.zip_path) 74 | return zf.open(path_in_zip.file_path) 75 | -------------------------------------------------------------------------------- /audiocraft/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # flake8: noqa 8 | from .musicgen import MusicGen 9 | from .lm import LMModel 10 | from .encodec import CompressionModel, EncodecModel 11 | -------------------------------------------------------------------------------- /audiocraft/models/builders.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | All the functions to build the relevant models and modules 9 | from the Hydra config. 
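A small sketch of the `PathInZip` convention described in the docstring above; the path is the illustrative one from that docstring, and the zip member is not actually opened here.

```python
from audiocraft.data.zip import PathInZip, set_zip_cache_size, open_file_in_zip

# A path of the form "<zip file>:<file inside the zip>".
path = PathInZip('/some/location/foo.zip:/data/file1.json')
print(path.zip_path)   # '/some/location/foo.zip'
print(path.file_path)  # '/data/file1.json'
print(str(path))       # the original string, rebuilt from the two parts

# Optionally bound the number of zip files kept open in the LRU cache.
set_zip_cache_size(8)

# Reading the member requires the zip file to exist on disk:
# with open_file_in_zip(path) as fobj:
#     data = fobj.read()
```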
10 | """ 11 | 12 | import typing as tp 13 | import warnings 14 | 15 | import audiocraft 16 | import omegaconf 17 | import torch 18 | 19 | from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel # noqa 20 | from .lm import LMModel 21 | from ..modules.codebooks_patterns import ( 22 | CodebooksPatternProvider, 23 | DelayedPatternProvider, 24 | ParallelPatternProvider, 25 | UnrolledPatternProvider, 26 | VALLEPattern, 27 | MusicLMPattern, 28 | ) 29 | from ..modules.conditioners import ( 30 | BaseConditioner, 31 | ConditioningProvider, 32 | LUTConditioner, 33 | T5Conditioner, 34 | ConditionFuser, 35 | ChromaStemConditioner, 36 | ) 37 | from .. import quantization as qt 38 | from ..utils.utils import dict_from_config 39 | 40 | 41 | def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer: 42 | klass = { 43 | 'no_quant': qt.DummyQuantizer, 44 | 'rvq': qt.ResidualVectorQuantizer 45 | }[quantizer] 46 | kwargs = dict_from_config(getattr(cfg, quantizer)) 47 | if quantizer != 'no_quant': 48 | kwargs['dimension'] = dimension 49 | return klass(**kwargs) 50 | 51 | 52 | def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig): 53 | if encoder_name == 'seanet': 54 | kwargs = dict_from_config(getattr(cfg, 'seanet')) 55 | encoder_override_kwargs = kwargs.pop('encoder') 56 | decoder_override_kwargs = kwargs.pop('decoder') 57 | encoder_kwargs = {**kwargs, **encoder_override_kwargs} 58 | decoder_kwargs = {**kwargs, **decoder_override_kwargs} 59 | encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs) 60 | decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs) 61 | return encoder, decoder 62 | else: 63 | raise KeyError(f'Unexpected compression model {cfg.compression_model}') 64 | 65 | 66 | def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel: 67 | """Instantiate a compression model. 68 | """ 69 | if cfg.compression_model == 'encodec': 70 | kwargs = dict_from_config(getattr(cfg, 'encodec')) 71 | encoder_name = kwargs.pop('autoencoder') 72 | quantizer_name = kwargs.pop('quantizer') 73 | encoder, decoder = get_encodec_autoencoder(encoder_name, cfg) 74 | quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension) 75 | frame_rate = kwargs['sample_rate'] // encoder.hop_length 76 | renormalize = kwargs.pop('renormalize', None) 77 | renorm = kwargs.pop('renorm') 78 | if renormalize is None: 79 | renormalize = renorm is not None 80 | warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.") 81 | return EncodecModel(encoder, decoder, quantizer, 82 | frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device) 83 | else: 84 | raise KeyError(f'Unexpected compression model {cfg.compression_model}') 85 | 86 | 87 | def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel: 88 | """Instantiate a transformer LM. 
89 | """ 90 | if cfg.lm_model == 'transformer_lm': 91 | kwargs = dict_from_config(getattr(cfg, 'transformer_lm')) 92 | n_q = kwargs['n_q'] 93 | q_modeling = kwargs.pop('q_modeling', None) 94 | codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern') 95 | attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout')) 96 | cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance')) 97 | cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"] 98 | fuser = get_condition_fuser(cfg) 99 | condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device) 100 | if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programatically 101 | kwargs['cross_attention'] = True 102 | if codebooks_pattern_cfg.modeling is None: 103 | assert q_modeling is not None, \ 104 | 'LM model should either have a codebook pattern defined or transformer_lm.q_modeling' 105 | codebooks_pattern_cfg = omegaconf.OmegaConf.create( 106 | {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}} 107 | ) 108 | pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg) 109 | return LMModel( 110 | pattern_provider=pattern_provider, 111 | condition_provider=condition_provider, 112 | fuser=fuser, 113 | cfg_dropout=cfg_prob, 114 | cfg_coef=cfg_coef, 115 | attribute_dropout=attribute_dropout, 116 | dtype=getattr(torch, cfg.dtype), 117 | device=cfg.device, 118 | **kwargs 119 | ).to(cfg.device) 120 | else: 121 | raise KeyError(f'Unexpected LM model {cfg.lm_model}') 122 | 123 | 124 | def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider: 125 | """Instantiate a conditioning model. 126 | """ 127 | device = cfg.device 128 | duration = cfg.dataset.segment_duration 129 | cfg = getattr(cfg, "conditioners") 130 | cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg 131 | conditioners: tp.Dict[str, BaseConditioner] = {} 132 | with omegaconf.open_dict(cfg): 133 | condition_provider_args = cfg.pop('args', {}) 134 | for cond, cond_cfg in cfg.items(): 135 | model_type = cond_cfg["model"] 136 | model_args = cond_cfg[model_type] 137 | if model_type == "t5": 138 | conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args) 139 | elif model_type == "lut": 140 | conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args) 141 | elif model_type == "chroma_stem": 142 | model_args.pop('cache_path', None) 143 | conditioners[str(cond)] = ChromaStemConditioner( 144 | output_dim=output_dim, 145 | duration=duration, 146 | device=device, 147 | **model_args 148 | ) 149 | else: 150 | raise ValueError(f"unrecognized conditioning model: {model_type}") 151 | conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args) 152 | return conditioner 153 | 154 | 155 | def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: 156 | """Instantiate a condition fuser object. 157 | """ 158 | fuser_cfg = getattr(cfg, "fuser") 159 | fuser_methods = ["sum", "cross", "prepend", "input_interpolate"] 160 | fuse2cond = {k: fuser_cfg[k] for k in fuser_methods} 161 | kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods} 162 | fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs) 163 | return fuser 164 | 165 | 166 | def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider: 167 | """Instantiate a codebooks pattern provider object. 
168 | """ 169 | pattern_providers = { 170 | 'parallel': ParallelPatternProvider, 171 | 'delay': DelayedPatternProvider, 172 | 'unroll': UnrolledPatternProvider, 173 | 'valle': VALLEPattern, 174 | 'musiclm': MusicLMPattern, 175 | } 176 | name = cfg.modeling 177 | kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {} 178 | klass = pattern_providers[name] 179 | return klass(n_q, **kwargs) 180 | 181 | 182 | def get_debug_compression_model(device='cpu'): 183 | """Instantiate a debug compression model to be used for unit tests. 184 | """ 185 | seanet_kwargs = { 186 | 'n_filters': 4, 187 | 'n_residual_layers': 1, 188 | 'dimension': 32, 189 | 'ratios': [10, 8, 16] # 25 Hz at 32kHz 190 | } 191 | encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs) 192 | decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs) 193 | quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4) 194 | init_x = torch.randn(8, 32, 128) 195 | quantizer(init_x, 1) # initialize kmeans etc. 196 | compression_model = EncodecModel( 197 | encoder, decoder, quantizer, 198 | frame_rate=25, sample_rate=32000, channels=1).to(device) 199 | return compression_model.eval() 200 | 201 | 202 | def get_debug_lm_model(device='cpu'): 203 | """Instantiate a debug LM to be used for unit tests. 204 | """ 205 | pattern = DelayedPatternProvider(n_q=4) 206 | dim = 16 207 | providers = { 208 | 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"), 209 | } 210 | condition_provider = ConditioningProvider(providers) 211 | fuser = ConditionFuser( 212 | {'cross': ['description'], 'prepend': [], 213 | 'sum': [], 'input_interpolate': []}) 214 | lm = LMModel( 215 | pattern, condition_provider, fuser, 216 | n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2, 217 | cross_attention=True, causal=True) 218 | return lm.to(device).eval() 219 | -------------------------------------------------------------------------------- /audiocraft/models/encodec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from abc import ABC, abstractmethod 8 | import typing as tp 9 | 10 | from einops import rearrange 11 | import torch 12 | from torch import nn 13 | 14 | from .. import quantization as qt 15 | 16 | 17 | class CompressionModel(ABC, nn.Module): 18 | 19 | @abstractmethod 20 | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: 21 | ... 22 | 23 | @abstractmethod 24 | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: 25 | """See `EncodecModel.encode`""" 26 | ... 27 | 28 | @abstractmethod 29 | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): 30 | """See `EncodecModel.decode`""" 31 | ... 32 | 33 | @property 34 | @abstractmethod 35 | def channels(self) -> int: 36 | ... 37 | 38 | @property 39 | @abstractmethod 40 | def frame_rate(self) -> int: 41 | ... 42 | 43 | @property 44 | @abstractmethod 45 | def sample_rate(self) -> int: 46 | ... 47 | 48 | @property 49 | @abstractmethod 50 | def cardinality(self) -> int: 51 | ... 52 | 53 | @property 54 | @abstractmethod 55 | def num_codebooks(self) -> int: 56 | ... 57 | 58 | @property 59 | @abstractmethod 60 | def total_codebooks(self) -> int: 61 | ... 
62 | 63 | @abstractmethod 64 | def set_num_codebooks(self, n: int): 65 | """Set the active number of codebooks used by the quantizer. 66 | """ 67 | ... 68 | 69 | 70 | class EncodecModel(CompressionModel): 71 | """Encodec model operating on the raw waveform. 72 | 73 | Args: 74 | encoder (nn.Module): Encoder network. 75 | decoder (nn.Module): Decoder network. 76 | quantizer (qt.BaseQuantizer): Quantizer network. 77 | frame_rate (int): Frame rate for the latent representation. 78 | sample_rate (int): Audio sample rate. 79 | channels (int): Number of audio channels. 80 | causal (bool): Whether to use a causal version of the model. 81 | renormalize (bool): Whether to renormalize the audio before running the model. 82 | """ 83 | # we need assignement to override the property in the abstract class, 84 | # I couldn't find a better way... 85 | frame_rate: int = 0 86 | sample_rate: int = 0 87 | channels: int = 0 88 | 89 | def __init__(self, 90 | encoder: nn.Module, 91 | decoder: nn.Module, 92 | quantizer: qt.BaseQuantizer, 93 | frame_rate: int, 94 | sample_rate: int, 95 | channels: int, 96 | causal: bool = False, 97 | renormalize: bool = False): 98 | super().__init__() 99 | self.encoder = encoder 100 | self.decoder = decoder 101 | self.quantizer = quantizer 102 | self.frame_rate = frame_rate 103 | self.sample_rate = sample_rate 104 | self.channels = channels 105 | self.renormalize = renormalize 106 | self.causal = causal 107 | if self.causal: 108 | # we force disabling here to avoid handling linear overlap of segments 109 | # as supported in original EnCodec codebase. 110 | assert not self.renormalize, 'Causal model does not support renormalize' 111 | 112 | @property 113 | def total_codebooks(self): 114 | """Total number of quantizer codebooks available. 115 | """ 116 | return self.quantizer.total_codebooks 117 | 118 | @property 119 | def num_codebooks(self): 120 | """Active number of codebooks used by the quantizer. 121 | """ 122 | return self.quantizer.num_codebooks 123 | 124 | def set_num_codebooks(self, n: int): 125 | """Set the active number of codebooks used by the quantizer. 126 | """ 127 | self.quantizer.set_num_codebooks(n) 128 | 129 | @property 130 | def cardinality(self): 131 | """Cardinality of each codebook. 
132 | """ 133 | return self.quantizer.bins 134 | 135 | def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: 136 | scale: tp.Optional[torch.Tensor] 137 | if self.renormalize: 138 | mono = x.mean(dim=1, keepdim=True) 139 | volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt() 140 | scale = 1e-8 + volume 141 | x = x / scale 142 | scale = scale.view(-1, 1) 143 | else: 144 | scale = None 145 | return x, scale 146 | 147 | def postprocess(self, 148 | x: torch.Tensor, 149 | scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor: 150 | if scale is not None: 151 | assert self.renormalize 152 | x = x * scale.view(-1, 1, 1) 153 | return x 154 | 155 | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: 156 | assert x.dim() == 3 157 | length = x.shape[-1] 158 | x, scale = self.preprocess(x) 159 | 160 | emb = self.encoder(x) 161 | q_res = self.quantizer(emb, self.frame_rate) 162 | out = self.decoder(q_res.x) 163 | 164 | # remove extra padding added by the encoder and decoder 165 | assert out.shape[-1] >= length, (out.shape[-1], length) 166 | out = out[..., :length] 167 | 168 | q_res.x = self.postprocess(out, scale) 169 | 170 | return q_res 171 | 172 | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: 173 | """Encode the given input tensor to quantized representation along with scale parameter. 174 | 175 | Args: 176 | x (torch.Tensor): Float tensor of shape [B, C, T] 177 | 178 | Returns: 179 | codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of: 180 | codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep. 181 | scale a float tensor containing the scale for audio renormalizealization. 182 | """ 183 | assert x.dim() == 3 184 | x, scale = self.preprocess(x) 185 | emb = self.encoder(x) 186 | codes = self.quantizer.encode(emb) 187 | return codes, scale 188 | 189 | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): 190 | """Decode the given codes to a reconstructed representation, using the scale to perform 191 | audio denormalization if needed. 192 | 193 | Args: 194 | codes (torch.Tensor): Int tensor of shape [B, K, T] 195 | scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value. 196 | 197 | Returns: 198 | out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio. 199 | """ 200 | emb = self.quantizer.decode(codes) 201 | out = self.decoder(emb) 202 | out = self.postprocess(out, scale) 203 | # out contains extra padding added by the encoder and decoder 204 | return out 205 | 206 | 207 | class FlattenedCompressionModel(CompressionModel): 208 | """Wraps a CompressionModel and flatten its codebooks, e.g. 209 | instead of returning [B, K, T], return [B, S, T * (K // S)] with 210 | S the number of codebooks per step, and `K // S` the number of 'virtual steps' 211 | for each real time step. 212 | 213 | Args: 214 | model (CompressionModel): compression model to wrap. 215 | codebooks_per_step (int): number of codebooks to keep per step, 216 | this must divide the number of codebooks provided by the wrapped model. 217 | extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1, 218 | if each codebook has a cardinality N, then the first codebook will 219 | use the range [0, N - 1], and the second [N, 2 N - 1] etc. 220 | On decoding, this can lead to potentially invalid sequences. 221 | Any invalid entry will be silently remapped to the proper range 222 | with a modulo. 
223 | """ 224 | def __init__(self, model: CompressionModel, codebooks_per_step: int = 1, 225 | extend_cardinality: bool = True): 226 | super().__init__() 227 | self.model = model 228 | self.codebooks_per_step = codebooks_per_step 229 | self.extend_cardinality = extend_cardinality 230 | 231 | @property 232 | def total_codebooks(self): 233 | return self.model.total_codebooks 234 | 235 | @property 236 | def num_codebooks(self): 237 | """Active number of codebooks used by the quantizer. 238 | 239 | ..Warning:: this reports the number of codebooks after the flattening 240 | of the codebooks! 241 | """ 242 | assert self.model.num_codebooks % self.codebooks_per_step == 0 243 | return self.codebooks_per_step 244 | 245 | def set_num_codebooks(self, n: int): 246 | """Set the active number of codebooks used by the quantizer. 247 | 248 | ..Warning:: this sets the number of codebooks **before** the flattening 249 | of the codebooks. 250 | """ 251 | assert n % self.codebooks_per_step == 0 252 | self.model.set_num_codebooks(n) 253 | 254 | @property 255 | def num_virtual_steps(self) -> int: 256 | """Return the number of virtual steps, e.g. one real step 257 | will be split into that many steps. 258 | """ 259 | return self.model.num_codebooks // self.codebooks_per_step 260 | 261 | @property 262 | def frame_rate(self) -> int: 263 | return self.model.frame_rate * self.num_virtual_steps 264 | 265 | @property 266 | def sample_rate(self) -> int: 267 | return self.model.sample_rate 268 | 269 | @property 270 | def channels(self) -> int: 271 | return self.model.channels 272 | 273 | @property 274 | def cardinality(self): 275 | """Cardinality of each codebook. 276 | """ 277 | if self.extend_cardinality: 278 | return self.model.cardinality * self.num_virtual_steps 279 | else: 280 | return self.model.cardinality 281 | 282 | def forward(self, x: torch.Tensor) -> qt.QuantizedResult: 283 | raise NotImplementedError("Not supported, use encode and decode.") 284 | 285 | def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: 286 | indices, scales = self.model.encode(x) 287 | B, K, T = indices.shape 288 | indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step) 289 | if self.extend_cardinality: 290 | for virtual_step in range(1, self.num_virtual_steps): 291 | indices[..., virtual_step] += self.model.cardinality * virtual_step 292 | indices = rearrange(indices, 'b k t v -> b k (t v)') 293 | return (indices, scales) 294 | 295 | def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): 296 | B, K, T = codes.shape 297 | assert T % self.num_virtual_steps == 0 298 | codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps) 299 | # We silently ignore potential errors from the LM when 300 | # using extend_cardinality. 301 | codes = codes % self.model.cardinality 302 | return self.model.decode(codes, scale) 303 | -------------------------------------------------------------------------------- /audiocraft/models/loaders.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Utility functions to load from the checkpoints. 9 | Each checkpoint is a torch.saved dict with the following keys: 10 | - 'xp.cfg': the hydra config as dumped during training. 
This should be used 11 | to rebuild the object using the audiocraft.models.builders functions, 12 | - 'model_best_state': a readily loadable best state for the model, including 13 | the conditioner. The model obtained from `xp.cfg` should be compatible 14 | with this state dict. In the case of a LM, the encodec model would not be 15 | bundled along but instead provided separately. 16 | 17 | Those functions also support loading from a remote location with the Torch Hub API. 18 | They also support overriding some parameters, in particular the device and dtype 19 | of the returned model. 20 | """ 21 | 22 | from pathlib import Path 23 | from huggingface_hub import hf_hub_download 24 | import typing as tp 25 | import os 26 | 27 | from omegaconf import OmegaConf 28 | import torch 29 | 30 | from . import builders 31 | 32 | 33 | HF_MODEL_CHECKPOINTS_MAP = { 34 | "small": "facebook/musicgen-small", 35 | "medium": "facebook/musicgen-medium", 36 | "large": "facebook/musicgen-large", 37 | "melody": "facebook/musicgen-melody", 38 | } 39 | 40 | 41 | def _get_state_dict( 42 | file_or_url_or_id: tp.Union[Path, str], 43 | filename: tp.Optional[str] = None, 44 | device='cpu', 45 | cache_dir: tp.Optional[str] = None, 46 | ): 47 | # Return the state dict either from a file or url 48 | file_or_url_or_id = str(file_or_url_or_id) 49 | assert isinstance(file_or_url_or_id, str) 50 | 51 | if os.path.isfile(file_or_url_or_id): 52 | return torch.load(file_or_url_or_id, map_location=device) 53 | 54 | elif file_or_url_or_id.startswith('https://'): 55 | return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True) 56 | 57 | elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP: 58 | assert filename is not None, "filename needs to be defined if using HF checkpoints" 59 | 60 | repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id] 61 | file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir) 62 | return torch.load(file, map_location=device) 63 | 64 | else: 65 | raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.") 66 | 67 | 68 | def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): 69 | pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir) 70 | cfg = OmegaConf.create(pkg['xp.cfg']) 71 | cfg.device = str(device) 72 | model = builders.get_compression_model(cfg) 73 | model.load_state_dict(pkg['best_state']) 74 | model.eval() 75 | return model 76 | 77 | 78 | def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None): 79 | pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir) 80 | cfg = OmegaConf.create(pkg['xp.cfg']) 81 | cfg.device = str(device) 82 | if cfg.device == 'cpu': 83 | cfg.transformer_lm.memory_efficient = False 84 | cfg.transformer_lm.custom = True 85 | cfg.dtype = 'float32' 86 | else: 87 | cfg.dtype = 'float16' 88 | model = builders.get_lm_model(cfg) 89 | model.load_state_dict(pkg['best_state']) 90 | model.eval() 91 | model.cfg = cfg 92 | return model 93 | -------------------------------------------------------------------------------- /audiocraft/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
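A usage sketch for the loaders above. The `'small'` name resolves through `HF_MODEL_CHECKPOINTS_MAP` and triggers a download from the Hugging Face Hub on first use, so network access is required.

```python
import torch

from audiocraft.models.loaders import load_compression_model, load_lm_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'

compression_model = load_compression_model('small', device=device)
lm = load_lm_model('small', device=device)

# A local training checkpoint or an https:// URL works as well, e.g.:
# lm = load_lm_model('/checkpoints/my_run/state_dict.bin', device=device)
```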
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # flake8: noqa 8 | from .conv import ( 9 | NormConv1d, 10 | NormConv2d, 11 | NormConvTranspose1d, 12 | NormConvTranspose2d, 13 | StreamableConv1d, 14 | StreamableConvTranspose1d, 15 | pad_for_conv1d, 16 | pad1d, 17 | unpad1d, 18 | ) 19 | from .lstm import StreamableLSTM 20 | from .seanet import SEANetEncoder, SEANetDecoder 21 | -------------------------------------------------------------------------------- /audiocraft/modules/activations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch import Tensor 10 | from typing import Union, Callable 11 | 12 | 13 | class CustomGLU(nn.Module): 14 | """Custom Gated Linear Unit activation. 15 | Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half 16 | of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation 17 | function (i.e. sigmoid, swish, etc.). 18 | 19 | Args: 20 | activation (nn.Module): The custom activation to apply in the Gated Linear Unit 21 | dim (int): the dimension on which to split the input. Default: -1 22 | 23 | Shape: 24 | - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional 25 | dimensions 26 | - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` 27 | 28 | Examples:: 29 | >>> m = CustomGLU(nn.Sigmoid()) 30 | >>> input = torch.randn(4, 2) 31 | >>> output = m(input) 32 | """ 33 | def __init__(self, activation: nn.Module, dim: int = -1): 34 | super(CustomGLU, self).__init__() 35 | self.dim = dim 36 | self.activation = activation 37 | 38 | def forward(self, x: Tensor): 39 | assert x.shape[self.dim] % 2 == 0 # M = N / 2 40 | a, b = torch.chunk(x, 2, dim=self.dim) 41 | return a * self.activation(b) 42 | 43 | 44 | class SwiGLU(CustomGLU): 45 | """SiLU Gated Linear Unit activation. 46 | Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is 47 | the first half of the input matrices, :math:`b` is the second half. 48 | 49 | Args: 50 | dim (int): the dimension on which to split the input. Default: -1 51 | """ 52 | def __init__(self, dim: int = -1): 53 | super(SwiGLU, self).__init__(nn.SiLU(), dim) 54 | 55 | 56 | class GeGLU(CustomGLU): 57 | """GeLU Gated Linear Unit activation. 58 | Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is 59 | the first half of the input matrices, :math:`b` is the second half. 60 | 61 | Args: 62 | dim (int): the dimension on which to split the input. Default: -1 63 | """ 64 | def __init__(self, dim: int = -1): 65 | super(GeGLU, self).__init__(nn.GELU(), dim) 66 | 67 | 68 | class ReGLU(CustomGLU): 69 | """ReLU Gated Linear Unit activation. 70 | Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is 71 | the first half of the input matrices, :math:`b` is the second half. 72 | 73 | Args: 74 | dim (int): the dimension on which to split the input. 
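A minimal sketch of the GLU variants defined above: the gating halves the split dimension.

```python
import torch
from torch import nn

from audiocraft.modules.activations import CustomGLU, SwiGLU

x = torch.randn(4, 8)

# SwiGLU computes a * SiLU(b) with a, b the two halves of the last dimension.
swiglu = SwiGLU(dim=-1)
print(swiglu(x).shape)  # torch.Size([4, 4])

# The same mechanism with an arbitrary gating activation.
tanh_glu = CustomGLU(nn.Tanh(), dim=-1)
print(tanh_glu(x).shape)  # torch.Size([4, 4])
```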
Default: -1 75 | """ 76 | def __init__(self, dim: int = -1): 77 | super(ReGLU, self).__init__(nn.ReLU(), dim) 78 | 79 | 80 | def get_activation_fn( 81 | activation: Union[str, Callable[[Tensor], Tensor]] 82 | ) -> Union[str, Callable[[Tensor], Tensor]]: 83 | """Helper function to map an activation string to the activation class. 84 | If the supplied activation is not a string that is recognized, the activation is passed back. 85 | 86 | Args: 87 | activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check 88 | """ 89 | if isinstance(activation, str): 90 | if activation == "reglu": 91 | return ReGLU() 92 | elif activation == "geglu": 93 | return GeGLU() 94 | elif activation == "swiglu": 95 | return SwiGLU() 96 | return activation 97 | -------------------------------------------------------------------------------- /audiocraft/modules/conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import typing as tp 9 | import warnings 10 | 11 | import torch 12 | from torch import nn 13 | from torch.nn import functional as F 14 | from torch.nn.utils import spectral_norm, weight_norm 15 | 16 | 17 | CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', 18 | 'time_group_norm']) 19 | 20 | 21 | def apply_parametrization_norm(module: nn.Module, norm: str = 'none'): 22 | assert norm in CONV_NORMALIZATIONS 23 | if norm == 'weight_norm': 24 | return weight_norm(module) 25 | elif norm == 'spectral_norm': 26 | return spectral_norm(module) 27 | else: 28 | # We already check was in CONV_NORMALIZATION, so any other choice 29 | # doesn't need reparametrization. 30 | return module 31 | 32 | 33 | def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs): 34 | """Return the proper normalization module. If causal is True, this will ensure the returned 35 | module is causal, or return an error if the normalization doesn't support causal evaluation. 36 | """ 37 | assert norm in CONV_NORMALIZATIONS 38 | if norm == 'time_group_norm': 39 | if causal: 40 | raise ValueError("GroupNorm doesn't support causal evaluation.") 41 | assert isinstance(module, nn.modules.conv._ConvNd) 42 | return nn.GroupNorm(1, module.out_channels, **norm_kwargs) 43 | else: 44 | return nn.Identity() 45 | 46 | 47 | def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, 48 | padding_total: int = 0) -> int: 49 | """See `pad_for_conv1d`. 50 | """ 51 | length = x.shape[-1] 52 | n_frames = (length - kernel_size + padding_total) / stride + 1 53 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) 54 | return ideal_length - length 55 | 56 | 57 | def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): 58 | """Pad for a convolution to make sure that the last window is full. 59 | Extra padding is added at the end. This is required to ensure that we can rebuild 60 | an output of the same length, as otherwise, even with padding, some time steps 61 | might get removed. 62 | For instance, with total padding = 4, kernel size = 4, stride = 2: 63 | 0 0 1 2 3 4 5 0 0 # (0s are padding) 64 | 1 2 3 # (output frames of a convolution, last 0 is never used) 65 | 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 
5 is going to get removed as padding) 66 | 1 2 3 4 # once you removed padding, we are missing one time step ! 67 | """ 68 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) 69 | return F.pad(x, (0, extra_padding)) 70 | 71 | 72 | def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): 73 | """Tiny wrapper around F.pad, just to allow for reflect padding on small input. 74 | If this is the case, we insert extra 0 padding to the right before the reflection happen. 75 | """ 76 | length = x.shape[-1] 77 | padding_left, padding_right = paddings 78 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) 79 | if mode == 'reflect': 80 | max_pad = max(padding_left, padding_right) 81 | extra_pad = 0 82 | if length <= max_pad: 83 | extra_pad = max_pad - length + 1 84 | x = F.pad(x, (0, extra_pad)) 85 | padded = F.pad(x, paddings, mode, value) 86 | end = padded.shape[-1] - extra_pad 87 | return padded[..., :end] 88 | else: 89 | return F.pad(x, paddings, mode, value) 90 | 91 | 92 | def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): 93 | """Remove padding from x, handling properly zero padding. Only for 1d! 94 | """ 95 | padding_left, padding_right = paddings 96 | assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) 97 | assert (padding_left + padding_right) <= x.shape[-1] 98 | end = x.shape[-1] - padding_right 99 | return x[..., padding_left: end] 100 | 101 | 102 | class NormConv1d(nn.Module): 103 | """Wrapper around Conv1d and normalization applied to this conv 104 | to provide a uniform interface across normalization approaches. 105 | """ 106 | def __init__(self, *args, causal: bool = False, norm: str = 'none', 107 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 108 | super().__init__() 109 | self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) 110 | self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) 111 | self.norm_type = norm 112 | 113 | def forward(self, x): 114 | x = self.conv(x) 115 | x = self.norm(x) 116 | return x 117 | 118 | 119 | class NormConv2d(nn.Module): 120 | """Wrapper around Conv2d and normalization applied to this conv 121 | to provide a uniform interface across normalization approaches. 122 | """ 123 | def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 124 | super().__init__() 125 | self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm) 126 | self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs) 127 | self.norm_type = norm 128 | 129 | def forward(self, x): 130 | x = self.conv(x) 131 | x = self.norm(x) 132 | return x 133 | 134 | 135 | class NormConvTranspose1d(nn.Module): 136 | """Wrapper around ConvTranspose1d and normalization applied to this conv 137 | to provide a uniform interface across normalization approaches. 
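A sketch of what `pad1d` adds on top of `F.pad` for short inputs, together with its `unpad1d` counterpart; the shapes in the comments are the expected ones.

```python
import torch

from audiocraft.modules.conv import pad1d, unpad1d

# Plain reflect padding fails when the requested padding is not smaller than the
# input length; pad1d works around it by zero-extending to the right first.
x = torch.randn(1, 1, 2)
y = pad1d(x, (3, 3), mode='reflect')
print(y.shape)  # torch.Size([1, 1, 8])

# unpad1d is the inverse bookkeeping step used by the transposed convolutions.
z = unpad1d(y, (3, 3))
print(z.shape)  # torch.Size([1, 1, 2])
```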
138 | """ 139 | def __init__(self, *args, causal: bool = False, norm: str = 'none', 140 | norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 141 | super().__init__() 142 | self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) 143 | self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) 144 | self.norm_type = norm 145 | 146 | def forward(self, x): 147 | x = self.convtr(x) 148 | x = self.norm(x) 149 | return x 150 | 151 | 152 | class NormConvTranspose2d(nn.Module): 153 | """Wrapper around ConvTranspose2d and normalization applied to this conv 154 | to provide a uniform interface across normalization approaches. 155 | """ 156 | def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): 157 | super().__init__() 158 | self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm) 159 | self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs) 160 | 161 | def forward(self, x): 162 | x = self.convtr(x) 163 | x = self.norm(x) 164 | return x 165 | 166 | 167 | class StreamableConv1d(nn.Module): 168 | """Conv1d with some builtin handling of asymmetric or causal padding 169 | and normalization. 170 | """ 171 | def __init__(self, in_channels: int, out_channels: int, 172 | kernel_size: int, stride: int = 1, dilation: int = 1, 173 | groups: int = 1, bias: bool = True, causal: bool = False, 174 | norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, 175 | pad_mode: str = 'reflect'): 176 | super().__init__() 177 | # warn user on unusual setup between dilation and stride 178 | if stride > 1 and dilation > 1: 179 | warnings.warn('StreamableConv1d has been initialized with stride > 1 and dilation > 1' 180 | f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).') 181 | self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, 182 | dilation=dilation, groups=groups, bias=bias, causal=causal, 183 | norm=norm, norm_kwargs=norm_kwargs) 184 | self.causal = causal 185 | self.pad_mode = pad_mode 186 | 187 | def forward(self, x): 188 | B, C, T = x.shape 189 | kernel_size = self.conv.conv.kernel_size[0] 190 | stride = self.conv.conv.stride[0] 191 | dilation = self.conv.conv.dilation[0] 192 | kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations 193 | padding_total = kernel_size - stride 194 | extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) 195 | if self.causal: 196 | # Left padding for causal 197 | x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode) 198 | else: 199 | # Asymmetric padding required for odd strides 200 | padding_right = padding_total // 2 201 | padding_left = padding_total - padding_right 202 | x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) 203 | return self.conv(x) 204 | 205 | 206 | class StreamableConvTranspose1d(nn.Module): 207 | """ConvTranspose1d with some builtin handling of asymmetric or causal padding 208 | and normalization. 
209 | """ 210 | def __init__(self, in_channels: int, out_channels: int, 211 | kernel_size: int, stride: int = 1, causal: bool = False, 212 | norm: str = 'none', trim_right_ratio: float = 1., 213 | norm_kwargs: tp.Dict[str, tp.Any] = {}): 214 | super().__init__() 215 | self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, 216 | causal=causal, norm=norm, norm_kwargs=norm_kwargs) 217 | self.causal = causal 218 | self.trim_right_ratio = trim_right_ratio 219 | assert self.causal or self.trim_right_ratio == 1., \ 220 | "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" 221 | assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. 222 | 223 | def forward(self, x): 224 | kernel_size = self.convtr.convtr.kernel_size[0] 225 | stride = self.convtr.convtr.stride[0] 226 | padding_total = kernel_size - stride 227 | 228 | y = self.convtr(x) 229 | 230 | # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be 231 | # removed at the very end, when keeping only the right length for the output, 232 | # as removing it here would require also passing the length at the matching layer 233 | # in the encoder. 234 | if self.causal: 235 | # Trim the padding on the right according to the specified ratio 236 | # if trim_right_ratio = 1.0, trim everything from right 237 | padding_right = math.ceil(padding_total * self.trim_right_ratio) 238 | padding_left = padding_total - padding_right 239 | y = unpad1d(y, (padding_left, padding_right)) 240 | else: 241 | # Asymmetric padding required for odd strides 242 | padding_right = padding_total // 2 243 | padding_left = padding_total - padding_right 244 | y = unpad1d(y, (padding_left, padding_right)) 245 | return y 246 | -------------------------------------------------------------------------------- /audiocraft/modules/lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torch import nn 8 | 9 | 10 | class StreamableLSTM(nn.Module): 11 | """LSTM without worrying about the hidden state, nor the layout of the data. 12 | Expects input as convolutional layout. 13 | """ 14 | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): 15 | super().__init__() 16 | self.skip = skip 17 | self.lstm = nn.LSTM(dimension, dimension, num_layers) 18 | 19 | def forward(self, x): 20 | x = x.permute(2, 0, 1) 21 | y, _ = self.lstm(x) 22 | if self.skip: 23 | y = y + x 24 | y = y.permute(1, 2, 0) 25 | return y 26 | -------------------------------------------------------------------------------- /audiocraft/modules/rope.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import typing as tp 8 | 9 | from torch import nn 10 | import torch 11 | 12 | 13 | class XPos(nn.Module): 14 | """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1). 15 | This applies an exponential decay to the RoPE rotation matrix. 16 | 17 | Args: 18 | dim (int): Embedding dimension. 19 | smoothing (float): Smoothing factor applied to the decay rates. 
20 | base_scale (int): Base decay rate, given in terms of scaling time. 21 | device (torch.device or None): Device on which to initialize the module. 22 | dtype (torch.dtype): dtype to use to generate the embedding. 23 | """ 24 | def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512, 25 | device=None, dtype: torch.dtype = torch.float32): 26 | super().__init__() 27 | assert dim % 2 == 0 28 | assert dtype in [torch.float64, torch.float32] 29 | self.dtype = dtype 30 | self.base_scale = base_scale 31 | 32 | half_dim = dim // 2 33 | adim = torch.arange(half_dim, device=device, dtype=dtype) 34 | decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing) 35 | self.register_buffer("decay_rates", decay_rates) 36 | self.decay: tp.Optional[torch.Tensor] = None 37 | 38 | def get_decay(self, start: int, end: int): 39 | """Create complex decay tensor, cache values for fast computation. 40 | """ 41 | if self.decay is None or end > self.decay.shape[0]: 42 | assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker. 43 | idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype) 44 | power = idx / self.base_scale 45 | scale = self.decay_rates ** power.unsqueeze(-1) 46 | self.decay = torch.polar(scale, torch.zeros_like(scale)) 47 | return self.decay[start:end] # [T, C/2] 48 | 49 | 50 | class RotaryEmbedding(nn.Module): 51 | """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864). 52 | 53 | Args: 54 | dim (int): Embedding dimension (twice the number of frequencies). 55 | max_period (float): Maximum period of the rotation frequencies. 56 | xpos (bool): Use xPos, applies an exponential decay to rotation matrix. 57 | scale (float): Scale of positional embedding, set to 0 to deactivate. 58 | device (torch.device or None): Device on which to initialize the module. 59 | dtype (torch.dtype): dtype to use to generate the embedding. 60 | """ 61 | def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False, 62 | scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32): 63 | super().__init__() 64 | assert dim % 2 == 0 65 | self.scale = scale 66 | assert dtype in [torch.float64, torch.float32] 67 | self.dtype = dtype 68 | 69 | adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)] 70 | frequencies = 1.0 / (max_period ** (adim / dim)) 71 | self.register_buffer("frequencies", frequencies) 72 | self.rotation: tp.Optional[torch.Tensor] = None 73 | 74 | self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None 75 | 76 | def get_rotation(self, start: int, end: int): 77 | """Create complex rotation tensor, cache values for fast computation. 78 | """ 79 | if self.rotation is None or end > self.rotation.shape[0]: 80 | assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker. 81 | idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype) 82 | angles = torch.outer(idx, self.frequencies) 83 | self.rotation = torch.polar(torch.ones_like(angles), angles) 84 | return self.rotation[start:end] 85 | 86 | def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False): 87 | """Apply rope rotation to query or key tensor. 
88 | """ 89 | T = x.shape[1] 90 | rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2) 91 | 92 | if self.xpos: 93 | decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2) 94 | else: 95 | decay = 1.0 96 | 97 | if invert_decay: 98 | decay = decay ** -1 99 | 100 | x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2)) 101 | scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale) 102 | x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2) 103 | 104 | return x_out.type_as(x) 105 | 106 | def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0): 107 | """ Apply rope rotation to both query and key tensors. 108 | Supports streaming mode, in which query and key are not expected to have the same shape. 109 | In streaming mode, key will be of legnth [P + C] with P the cached past timesteps, but 110 | query will be [C] (typically C == 1). 111 | 112 | Args: 113 | query (torch.Tensor): Query to rotate. 114 | key (torch.Tensor): Key to rotate. 115 | start (int): Start index of the sequence for time offset. 116 | """ 117 | query_timesteps = query.shape[1] 118 | key_timesteps = key.shape[1] 119 | streaming_offset = key_timesteps - query_timesteps 120 | 121 | query_out = self.rotate(query, start + streaming_offset) 122 | key_out = self.rotate(key, start, invert_decay=True) 123 | 124 | return query_out, key_out 125 | -------------------------------------------------------------------------------- /audiocraft/modules/streaming.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Streaming module API that should be implemented by all Streaming components, 9 | """ 10 | 11 | from contextlib import contextmanager 12 | import typing as tp 13 | from torch import nn 14 | import torch 15 | 16 | 17 | State = tp.Dict[str, torch.Tensor] 18 | 19 | 20 | class StreamingModule(nn.Module): 21 | """Common API for streaming components. 22 | 23 | Each streaming component has a streaming state, which is just a dict[str, Tensor]. 24 | By convention, the first dim of each tensor must be the batch size. 25 | Don't use dots in the key names, as this would clash with submodules 26 | (like in state_dict). 27 | 28 | If `self._is_streaming` is True, the component should use and remember 29 | the proper state inside `self._streaming_state`. 30 | 31 | To set a streaming component in streaming state, use 32 | 33 | with module.streaming(): 34 | ... 35 | 36 | This will automatically reset the streaming state when exiting the context manager. 37 | This also automatically propagates to all streaming children module. 38 | 39 | Some module might also implement the `StreamingModule.flush` method, although 40 | this one is trickier, as all parents module must be StreamingModule and implement 41 | it as well for it to work properly. See `StreamingSequential` after. 
42 | """ 43 | def __init__(self) -> None: 44 | super().__init__() 45 | self._streaming_state: State = {} 46 | self._is_streaming = False 47 | 48 | def _apply_named_streaming(self, fn: tp.Any): 49 | for name, module in self.named_modules(): 50 | if isinstance(module, StreamingModule): 51 | fn(name, module) 52 | 53 | def _set_streaming(self, streaming: bool): 54 | def _set_streaming(name, module): 55 | module._is_streaming = streaming 56 | self._apply_named_streaming(_set_streaming) 57 | 58 | @contextmanager 59 | def streaming(self): 60 | """Context manager to enter streaming mode. Reset streaming state on exit. 61 | """ 62 | self._set_streaming(True) 63 | try: 64 | yield 65 | finally: 66 | self._set_streaming(False) 67 | self.reset_streaming() 68 | 69 | def reset_streaming(self): 70 | """Reset the streaming state. 71 | """ 72 | def _reset(name: str, module: StreamingModule): 73 | module._streaming_state.clear() 74 | 75 | self._apply_named_streaming(_reset) 76 | 77 | def get_streaming_state(self) -> State: 78 | """Return the streaming state, including that of sub-modules. 79 | """ 80 | state: State = {} 81 | 82 | def _add(name: str, module: StreamingModule): 83 | if name: 84 | name += "." 85 | for key, value in module._streaming_state.items(): 86 | state[name + key] = value 87 | 88 | self._apply_named_streaming(_add) 89 | return state 90 | 91 | def set_streaming_state(self, state: State): 92 | """Set the streaming state, including that of sub-modules. 93 | """ 94 | state = dict(state) 95 | 96 | def _set(name: str, module: StreamingModule): 97 | if name: 98 | name += "." 99 | module._streaming_state.clear() 100 | for key, value in list(state.items()): 101 | # complexity is not ideal here, but probably fine. 102 | if key.startswith(name): 103 | local_key = key[len(name):] 104 | if '.' not in local_key: 105 | module._streaming_state[local_key] = value 106 | del state[key] 107 | 108 | self._apply_named_streaming(_set) 109 | assert len(state) == 0, list(state.keys()) 110 | 111 | def flush(self, x: tp.Optional[torch.Tensor] = None): 112 | """Flush any remaining outputs that were waiting for completion. 113 | Typically, for convolutions, this will add the final padding 114 | and process the last buffer. 115 | 116 | This should take an optional argument `x`, which will be provided 117 | if a module before this one in the streaming pipeline has already 118 | spitted out a flushed out buffer. 119 | """ 120 | if x is None: 121 | return None 122 | else: 123 | return self(x) 124 | 125 | 126 | class StreamingSequential(StreamingModule, nn.Sequential): 127 | """A streaming compatible alternative of `nn.Sequential`. 128 | """ 129 | def flush(self, x: tp.Optional[torch.Tensor] = None): 130 | for module in self: 131 | if isinstance(module, StreamingModule): 132 | x = module.flush(x) 133 | elif x is not None: 134 | x = module(x) 135 | return x 136 | -------------------------------------------------------------------------------- /audiocraft/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rkfg/audiocraft/6d70065e31c2fb422a76237e03740dd3b627de8d/audiocraft/py.typed -------------------------------------------------------------------------------- /audiocraft/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # flake8: noqa 8 | from .vq import ResidualVectorQuantizer 9 | from .base import BaseQuantizer, DummyQuantizer, QuantizedResult 10 | -------------------------------------------------------------------------------- /audiocraft/quantization/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Base class for all quantizers. 9 | """ 10 | 11 | from dataclasses import dataclass, field 12 | import typing as tp 13 | 14 | import torch 15 | from torch import nn 16 | 17 | 18 | @dataclass 19 | class QuantizedResult: 20 | x: torch.Tensor 21 | codes: torch.Tensor 22 | bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. 23 | penalty: tp.Optional[torch.Tensor] = None 24 | metrics: dict = field(default_factory=dict) 25 | 26 | 27 | class BaseQuantizer(nn.Module): 28 | """Base class for quantizers. 29 | """ 30 | 31 | def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult: 32 | """ 33 | Given input tensor x, returns first the quantized (or approximately quantized) 34 | representation along with quantized codes, bandwidth, and any penalty term for the loss. 35 | Finally, this returns a dict of metrics to update logging etc. 36 | Frame rate must be passed so that the bandwidth is properly computed. 37 | """ 38 | raise NotImplementedError() 39 | 40 | def encode(self, x: torch.Tensor) -> torch.Tensor: 41 | """Encode a given input tensor with the specified sample rate at the given bandwidth. 42 | """ 43 | raise NotImplementedError() 44 | 45 | def decode(self, codes: torch.Tensor) -> torch.Tensor: 46 | """Decode the given codes to the quantized representation. 47 | """ 48 | raise NotImplementedError() 49 | 50 | @property 51 | def total_codebooks(self): 52 | """Total number of codebooks. 53 | """ 54 | raise NotImplementedError() 55 | 56 | @property 57 | def num_codebooks(self): 58 | """Number of active codebooks. 59 | """ 60 | raise NotImplementedError() 61 | 62 | def set_num_codebooks(self, n: int): 63 | """Set the number of active codebooks. 64 | """ 65 | raise NotImplementedError() 66 | 67 | 68 | class DummyQuantizer(BaseQuantizer): 69 | """Fake quantizer that actually does not perform any quantization. 70 | """ 71 | def __init__(self): 72 | super().__init__() 73 | 74 | def forward(self, x: torch.Tensor, frame_rate: int): 75 | q = x.unsqueeze(1) 76 | return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x)) 77 | 78 | def encode(self, x: torch.Tensor) -> torch.Tensor: 79 | """Encode a given input tensor with the specified sample rate at the given bandwidth. 80 | In the case of the DummyQuantizer, the codes are actually identical 81 | to the input and resulting quantized representation as no quantization is done. 82 | """ 83 | return x.unsqueeze(1) 84 | 85 | def decode(self, codes: torch.Tensor) -> torch.Tensor: 86 | """Decode the given codes to the quantized representation. 87 | In the case of the DummyQuantizer, the codes are actually identical 88 | to the input and resulting quantized representation as no quantization is done. 
89 | """ 90 | return codes.squeeze(1) 91 | 92 | @property 93 | def total_codebooks(self): 94 | """Total number of codebooks. 95 | """ 96 | return 1 97 | 98 | @property 99 | def num_codebooks(self): 100 | """Total number of codebooks. 101 | """ 102 | return self.total_codebooks 103 | 104 | def set_num_codebooks(self, n: int): 105 | """Set the number of active codebooks. 106 | """ 107 | raise AttributeError("Cannot override the number of codebooks for the dummy quantizer") 108 | -------------------------------------------------------------------------------- /audiocraft/quantization/vq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import typing as tp 9 | 10 | import torch 11 | 12 | from .base import BaseQuantizer, QuantizedResult 13 | from .core_vq import ResidualVectorQuantization 14 | 15 | 16 | class ResidualVectorQuantizer(BaseQuantizer): 17 | """Residual Vector Quantizer. 18 | 19 | Args: 20 | dimension (int): Dimension of the codebooks. 21 | n_q (int): Number of residual vector quantizers used. 22 | q_dropout (bool): Random quantizer drop out at train time. 23 | bins (int): Codebook size. 24 | decay (float): Decay for exponential moving average over the codebooks. 25 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks. 26 | kmeans_iters (int): Number of iterations used for kmeans initialization. 27 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 28 | that have an exponential moving average cluster size less than the specified threshold with 29 | randomly selected vector from the current batch. 30 | orthogonal_reg_weight (float): Orthogonal regularization weights. 31 | orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes. 32 | orthogonal_reg_max_codes (optional int): Maximum number of codes to consider. 33 | for orthogonal regulariation. 
34 | """ 35 | def __init__( 36 | self, 37 | dimension: int = 256, 38 | n_q: int = 8, 39 | q_dropout: bool = False, 40 | bins: int = 1024, 41 | decay: float = 0.99, 42 | kmeans_init: bool = True, 43 | kmeans_iters: int = 10, 44 | threshold_ema_dead_code: int = 2, 45 | orthogonal_reg_weight: float = 0.0, 46 | orthogonal_reg_active_codes_only: bool = False, 47 | orthogonal_reg_max_codes: tp.Optional[int] = None, 48 | ): 49 | super().__init__() 50 | self.max_n_q = n_q 51 | self.n_q = n_q 52 | self.q_dropout = q_dropout 53 | self.dimension = dimension 54 | self.bins = bins 55 | self.decay = decay 56 | self.kmeans_init = kmeans_init 57 | self.kmeans_iters = kmeans_iters 58 | self.threshold_ema_dead_code = threshold_ema_dead_code 59 | self.orthogonal_reg_weight = orthogonal_reg_weight 60 | self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only 61 | self.orthogonal_reg_max_codes = orthogonal_reg_max_codes 62 | self.vq = ResidualVectorQuantization( 63 | dim=self.dimension, 64 | codebook_size=self.bins, 65 | num_quantizers=self.n_q, 66 | decay=self.decay, 67 | kmeans_init=self.kmeans_init, 68 | kmeans_iters=self.kmeans_iters, 69 | threshold_ema_dead_code=self.threshold_ema_dead_code, 70 | orthogonal_reg_weight=self.orthogonal_reg_weight, 71 | orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only, 72 | orthogonal_reg_max_codes=self.orthogonal_reg_max_codes, 73 | channels_last=False 74 | ) 75 | 76 | def forward(self, x: torch.Tensor, frame_rate: int): 77 | n_q = self.n_q 78 | if self.training and self.q_dropout: 79 | n_q = int(torch.randint(1, self.n_q + 1, (1,)).item()) 80 | bw_per_q = math.log2(self.bins) * frame_rate / 1000 81 | quantized, codes, commit_loss = self.vq(x, n_q=n_q) 82 | codes = codes.transpose(0, 1) 83 | # codes is [B, K, T], with T frames, K nb of codebooks. 84 | bw = torch.tensor(n_q * bw_per_q).to(x) 85 | return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) 86 | 87 | def encode(self, x: torch.Tensor) -> torch.Tensor: 88 | """Encode a given input tensor with the specified frame rate at the given bandwidth. 89 | The RVQ encode method sets the appropriate number of quantizer to use 90 | and returns indices for each quantizer. 91 | """ 92 | n_q = self.n_q 93 | codes = self.vq.encode(x, n_q=n_q) 94 | codes = codes.transpose(0, 1) 95 | # codes is [B, K, T], with T frames, K nb of codebooks. 96 | return codes 97 | 98 | def decode(self, codes: torch.Tensor) -> torch.Tensor: 99 | """Decode the given codes to the quantized representation. 100 | """ 101 | # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T]. 102 | codes = codes.transpose(0, 1) 103 | quantized = self.vq.decode(codes) 104 | return quantized 105 | 106 | @property 107 | def total_codebooks(self): 108 | return self.max_n_q 109 | 110 | @property 111 | def num_codebooks(self): 112 | return self.n_q 113 | 114 | def set_num_codebooks(self, n: int): 115 | assert n > 0 and n <= self.max_n_q 116 | self.n_q = n 117 | -------------------------------------------------------------------------------- /audiocraft/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /audiocraft/utils/autocast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | 10 | class TorchAutocast: 11 | """TorchAutocast utility class. 12 | Allows you to enable and disable autocast. This is specially useful 13 | when dealing with different architectures and clusters with different 14 | levels of support. 15 | 16 | Args: 17 | enabled (bool): Whether to enable torch.autocast or not. 18 | args: Additional args for torch.autocast. 19 | kwargs: Additional kwargs for torch.autocast 20 | """ 21 | def __init__(self, enabled: bool, *args, **kwargs): 22 | self.autocast = torch.autocast(*args, **kwargs) if enabled else None 23 | 24 | def __enter__(self): 25 | if self.autocast is None: 26 | return 27 | try: 28 | self.autocast.__enter__() 29 | except RuntimeError: 30 | device = self.autocast.device 31 | dtype = self.autocast.fast_dtype 32 | raise RuntimeError( 33 | f"There was an error autocasting with dtype={dtype} device={device}\n" 34 | "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16" 35 | ) 36 | 37 | def __exit__(self, *args, **kwargs): 38 | if self.autocast is None: 39 | return 40 | self.autocast.__exit__(*args, **kwargs) 41 | -------------------------------------------------------------------------------- /audiocraft/utils/export.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Utility to export a training checkpoint to a lightweight release checkpoint. 9 | """ 10 | 11 | from pathlib import Path 12 | import typing as tp 13 | 14 | from omegaconf import OmegaConf, DictConfig 15 | import torch 16 | 17 | 18 | def _clean_lm_cfg(cfg: DictConfig): 19 | OmegaConf.set_struct(cfg, False) 20 | # This used to be set automatically in the LM solver, need a more robust solution 21 | # for the future. 22 | cfg['transformer_lm']['card'] = 2048 23 | cfg['transformer_lm']['n_q'] = 4 24 | # Experimental params no longer supported. 
25 | bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters', 26 | 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop'] 27 | for name in bad_params: 28 | del cfg['transformer_lm'][name] 29 | OmegaConf.set_struct(cfg, True) 30 | return cfg 31 | 32 | 33 | def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): 34 | sig = Path(checkpoint_path).parent.name 35 | assert len(sig) == 8, "Not a valid Dora signature" 36 | pkg = torch.load(checkpoint_path, 'cpu') 37 | new_pkg = { 38 | 'best_state': pkg['ema']['state']['model'], 39 | 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), 40 | } 41 | out_file = Path(out_folder) / f'{sig}.th' 42 | torch.save(new_pkg, out_file) 43 | return out_file 44 | 45 | 46 | def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): 47 | sig = Path(checkpoint_path).parent.name 48 | assert len(sig) == 8, "Not a valid Dora signature" 49 | pkg = torch.load(checkpoint_path, 'cpu') 50 | new_pkg = { 51 | 'best_state': pkg['fsdp_best_state']['model'], 52 | 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])) 53 | } 54 | out_file = Path(out_folder) / f'{sig}.th' 55 | torch.save(new_pkg, out_file) 56 | return out_file 57 | -------------------------------------------------------------------------------- /audiocraft/utils/notebook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | try: 8 | import IPython.display as ipd # type: ignore 9 | except ImportError: 10 | # Note in a notebook... 11 | pass 12 | 13 | 14 | import torch 15 | 16 | 17 | def display_audio(samples: torch.Tensor, sample_rate: int): 18 | """Renders an audio player for the given audio samples. 19 | 20 | Args: 21 | samples (torch.Tensor): a Tensor of decoded audio samples 22 | with shapes [B, C, T] or [C, T] 23 | sample_rate (int): sample rate audio should be displayed with. 24 | """ 25 | assert samples.dim() == 2 or samples.dim() == 3 26 | 27 | samples = samples.detach().cpu() 28 | if samples.dim() == 2: 29 | samples = samples[None, ...] 30 | 31 | for audio in samples: 32 | ipd.display(ipd.Audio(audio, rate=sample_rate)) 33 | -------------------------------------------------------------------------------- /audiocraft/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from concurrent.futures import ProcessPoolExecutor 8 | from functools import wraps 9 | import hashlib 10 | import logging 11 | import typing as tp 12 | 13 | import flashy 14 | import flashy.distrib 15 | import omegaconf 16 | import torch 17 | from torch.nn.utils.rnn import pad_sequence 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def dict_from_config(cfg: omegaconf.DictConfig) -> dict: 24 | """Convenience function to map an omegaconf configuration to a dictionary. 25 | 26 | Args: 27 | cfg (omegaconf.DictConfig): Original configuration to map to dict. 28 | Returns: 29 | dict: Config as dictionary object. 
30 | """ 31 | dct = omegaconf.OmegaConf.to_container(cfg, resolve=True) 32 | assert isinstance(dct, dict) 33 | return dct 34 | 35 | 36 | def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset: 37 | if max_samples >= len(dataset): 38 | return dataset 39 | 40 | generator = torch.Generator().manual_seed(seed) 41 | perm = torch.randperm(len(dataset), generator=generator) 42 | return torch.utils.data.Subset(dataset, perm[:max_samples].tolist()) 43 | 44 | 45 | def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int, 46 | num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader: 47 | """Convenience function to load dataset into a dataloader with optional subset sampling. 48 | 49 | Args: 50 | dataset: Dataset to load. 51 | num_samples (Optional[int]): Number of samples to limit subset size. 52 | batch_size (int): Batch size. 53 | num_workers (int): Number of workers for data loading. 54 | seed (int): Random seed. 55 | """ 56 | if num_samples is not None: 57 | dataset = random_subset(dataset, num_samples, seed) 58 | 59 | dataloader = flashy.distrib.loader( 60 | dataset, 61 | batch_size=batch_size, 62 | num_workers=num_workers, 63 | **kwargs 64 | ) 65 | return dataloader 66 | 67 | 68 | def get_dataset_from_loader(dataloader): 69 | dataset = dataloader.dataset 70 | if isinstance(dataset, torch.utils.data.Subset): 71 | return dataset.dataset 72 | else: 73 | return dataset 74 | 75 | 76 | def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): 77 | """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. 78 | 79 | Args: 80 | input (torch.Tensor): The input tensor containing probabilities. 81 | num_samples (int): Number of samples to draw. 82 | replacement (bool): Whether to draw with replacement or not. 83 | Keywords args: 84 | generator (torch.Generator): A pseudorandom number generator for sampling. 85 | Returns: 86 | torch.Tensor: Last dimension contains num_samples indices 87 | sampled from the multinomial probability distribution 88 | located in the last dimension of tensor input. 89 | """ 90 | input_ = input.reshape(-1, input.shape[-1]) 91 | output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) 92 | output = output_.reshape(*list(input.shape[:-1]), -1) 93 | return output 94 | 95 | 96 | def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: 97 | """Sample next token from top K values along the last dimension of the input probs tensor. 98 | 99 | Args: 100 | probs (torch.Tensor): Input probabilities with token candidates on the last dimension. 101 | k (int): The k in “top-k”. 102 | Returns: 103 | torch.Tensor: Sampled tokens. 104 | """ 105 | top_k_value, _ = torch.topk(probs, k, dim=-1) 106 | min_value_top_k = top_k_value[..., [-1]] 107 | probs *= (probs >= min_value_top_k).float() 108 | probs.div_(probs.sum(dim=-1, keepdim=True)) 109 | next_token = multinomial(probs, num_samples=1) 110 | return next_token 111 | 112 | 113 | def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: 114 | """Sample next token from top P probabilities along the last dimension of the input probs tensor. 115 | 116 | Args: 117 | probs (torch.Tensor): Input probabilities with token candidates on the last dimension. 118 | p (int): The p in “top-p”. 119 | Returns: 120 | torch.Tensor: Sampled tokens. 
121 | """ 122 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 123 | probs_sum = torch.cumsum(probs_sort, dim=-1) 124 | mask = probs_sum - probs_sort > p 125 | probs_sort *= (~mask).float() 126 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 127 | next_token = multinomial(probs_sort, num_samples=1) 128 | next_token = torch.gather(probs_idx, -1, next_token) 129 | return next_token 130 | 131 | 132 | class DummyPoolExecutor: 133 | """Dummy pool executor to use when we actually have only 1 worker. 134 | (e.g. instead of ProcessPoolExecutor). 135 | """ 136 | class DummyResult: 137 | def __init__(self, func, *args, **kwargs): 138 | self.func = func 139 | self.args = args 140 | self.kwargs = kwargs 141 | 142 | def result(self): 143 | return self.func(*self.args, **self.kwargs) 144 | 145 | def __init__(self, workers, mp_context=None): 146 | pass 147 | 148 | def submit(self, func, *args, **kwargs): 149 | return DummyPoolExecutor.DummyResult(func, *args, **kwargs) 150 | 151 | def __enter__(self): 152 | return self 153 | 154 | def __exit__(self, exc_type, exc_value, exc_tb): 155 | return 156 | 157 | 158 | def get_pool_executor(num_workers: int, mp_context=None): 159 | return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1) 160 | 161 | 162 | def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor: 163 | """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences). 164 | For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]] 165 | 166 | Args: 167 | lengths (torch.Tensor): tensor with lengths 168 | max_len (int): can set the max length manually. Defaults to None. 169 | Returns: 170 | torch.Tensor: mask with 0s where there is pad tokens else 1s 171 | """ 172 | assert len(lengths.shape) == 1, "Length shape should be 1 dimensional." 173 | final_length = lengths.max().item() if not max_len else max_len 174 | final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor 175 | return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None] 176 | 177 | 178 | def hash_trick(word: str, vocab_size: int) -> int: 179 | """Hash trick to pair each word with an index 180 | 181 | Args: 182 | word (str): word we wish to convert to an index 183 | vocab_size (int): size of the vocabulary 184 | Returns: 185 | int: index of the word in the embedding LUT 186 | """ 187 | hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16) 188 | return hash % vocab_size 189 | 190 | 191 | def with_rank_rng(base_seed: int = 1234): 192 | """Decorator for a function so that the function will use a Random Number Generator 193 | whose state depend on the GPU rank. The original RNG state is restored upon returning. 194 | 195 | Args: 196 | base_seed (int): Random seed. 197 | """ 198 | def _decorator(fun: tp.Callable): 199 | @wraps(fun) 200 | def _decorated(*args, **kwargs): 201 | state = torch.get_rng_state() 202 | seed = base_seed ^ flashy.distrib.rank() 203 | torch.manual_seed(seed) 204 | logger.debug('Rank dependent seed set to %d', seed) 205 | try: 206 | return fun(*args, **kwargs) 207 | finally: 208 | torch.set_rng_state(state) 209 | logger.debug('RNG state restored.') 210 | return _decorated 211 | return _decorator 212 | 213 | 214 | def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]: 215 | """Get a list of tensors and collate them to a single tensor. 
according to the following logic:
216 | - `dim` specifies the time dimension which will be stacked and padded.
217 | - The output will contain 1 new dimension (dimension index 0) which will be the size of
218 | the original list.
219 | 
220 | Args:
221 | tensors (tp.List[torch.Tensor]): List of tensors to collate.
222 | dim (int): Dimension which will be stacked and padded.
223 | Returns:
224 | tp.Tuple[torch.Tensor, torch.Tensor]:
225 | torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
226 | (dimension index 0) which will be the size of the original list.
227 | torch.Tensor: Tensor containing length of original tensor sizes (without padding).
228 | """
229 | tensors = [x.transpose(0, dim) for x in tensors]
230 | lens = torch.LongTensor([len(x) for x in tensors])
231 | padded_tensors = pad_sequence(tensors)
232 | padded_tensors = padded_tensors.transpose(0, 1)
233 | padded_tensors = padded_tensors.transpose(1, dim + 1)
234 | return padded_tensors, lens
235 | -------------------------------------------------------------------------------- /demo.ipynb: --------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# MusicGen\n",
8 | "Welcome to MusicGen's demo Jupyter notebook. Here you will find a series of self-contained examples of how to use MusicGen in different settings.\n",
9 | "\n",
10 | "First, we start by initializing MusicGen; you can choose a model from the following selection:\n",
11 | "1. `small` - 300M transformer decoder.\n",
12 | "2. `medium` - 1.5B transformer decoder.\n",
13 | "3. `melody` - 1.5B transformer decoder also supporting melody conditioning.\n",
14 | "4. `large` - 3.3B transformer decoder.\n",
15 | "\n",
16 | "We will use the `small` variant for the purpose of this demonstration."
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "from audiocraft.models import MusicGen\n",
26 | "\n",
27 | "# Using the small model; better results would be obtained with `medium` or `large`.\n",
28 | "model = MusicGen.get_pretrained('small')"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "Next, let us configure the generation parameters. Specifically, you can control the following:\n",
36 | "* `use_sampling` (bool, optional): use sampling if True, else do argmax decoding. Defaults to True.\n",
37 | "* `top_k` (int, optional): top_k used for sampling. Defaults to 250.\n",
38 | "* `top_p` (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.\n",
39 | "* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n",
40 | "* `duration` (float, optional): duration of the generated waveform. Defaults to 30.0.\n",
41 | "* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.\n",
42 | "\n",
43 | "When left unchanged, MusicGen will revert to its default parameters."
44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "model.set_generation_params(\n", 53 | " use_sampling=True,\n", 54 | " top_k=250,\n", 55 | " duration=5\n", 56 | ")" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "Next, we can go ahead and start generating music using one of the following modes:\n", 64 | "* Unconditional samples using `model.generate_unconditional`\n", 65 | "* Music continuation using `model.generate_continuation`\n", 66 | "* Text-conditional samples using `model.generate`\n", 67 | "* Melody-conditional samples using `model.generate_with_chroma`" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "### Unconditional Generation" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "from audiocraft.utils.notebook import display_audio\n", 84 | "\n", 85 | "output = model.generate_unconditional(num_samples=2, progress=True)\n", 86 | "display_audio(output, sample_rate=32000)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "### Music Continuation" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "import math\n", 103 | "import torchaudio\n", 104 | "import torch\n", 105 | "from audiocraft.utils.notebook import display_audio\n", 106 | "\n", 107 | "def get_bip_bip(bip_duration=0.125, frequency=440,\n", 108 | " duration=0.5, sample_rate=32000, device=\"cuda\"):\n", 109 | " \"\"\"Generates a series of bip bip at the given frequency.\"\"\"\n", 110 | " t = torch.arange(\n", 111 | " int(duration * sample_rate), device=\"cuda\", dtype=torch.float) / sample_rate\n", 112 | " wav = torch.cos(2 * math.pi * 440 * t)[None]\n", 113 | " tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n", 114 | " envelope = (tp >= 0.5).float()\n", 115 | " return wav * envelope\n" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "# Here we use a synthetic signal to prompt both the tonality and the BPM\n", 125 | "# of the generated audio.\n", 126 | "res = model.generate_continuation(\n", 127 | " get_bip_bip(0.125).expand(2, -1, -1), \n", 128 | " 32000, ['Jazz jazz and only jazz', \n", 129 | " 'Heartful EDM with beautiful synths and chords'], \n", 130 | " progress=True)\n", 131 | "display_audio(res, 32000)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "# You can also use any audio from a file. 
Make sure to trim the file if it is too long!\n", 141 | "prompt_waveform, prompt_sr = torchaudio.load(\"./assets/bach.mp3\")\n", 142 | "prompt_duration = 2\n", 143 | "prompt_waveform = prompt_waveform[..., :int(prompt_duration * prompt_sr)]\n", 144 | "output = model.generate_continuation(prompt_waveform, prompt_sample_rate=prompt_sr, progress=True)\n", 145 | "display_audio(output, sample_rate=32000)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "### Text-conditional Generation" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "from audiocraft.utils.notebook import display_audio\n", 162 | "\n", 163 | "output = model.generate(\n", 164 | " descriptions=[\n", 165 | " '80s pop track with bassy drums and synth',\n", 166 | " '90s rock song with loud guitars and heavy drums',\n", 167 | " ],\n", 168 | " progress=True\n", 169 | ")\n", 170 | "display_audio(output, sample_rate=32000)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "### Melody-conditional Generation" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "import torchaudio\n", 187 | "from audiocraft.utils.notebook import display_audio\n", 188 | "\n", 189 | "model = MusicGen.get_pretrained('melody')\n", 190 | "model.set_generation_params(duration=8)\n", 191 | "\n", 192 | "melody_waveform, sr = torchaudio.load(\"assets/bach.mp3\")\n", 193 | "melody_waveform = melody_waveform.unsqueeze(0).repeat(2, 1, 1)\n", 194 | "output = model.generate_with_chroma(\n", 195 | " descriptions=[\n", 196 | " '80s pop track with bassy drums and synth',\n", 197 | " '90s rock song with loud guitars and heavy drums',\n", 198 | " ],\n", 199 | " melody_wavs=melody_waveform,\n", 200 | " melody_sample_rate=sr,\n", 201 | " progress=True\n", 202 | ")\n", 203 | "display_audio(output, sample_rate=32000)" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [] 212 | } 213 | ], 214 | "metadata": { 215 | "kernelspec": { 216 | "display_name": "Python 3 (ipykernel)", 217 | "language": "python", 218 | "name": "python3" 219 | }, 220 | "language_info": { 221 | "codemirror_mode": { 222 | "name": "ipython", 223 | "version": 3 224 | }, 225 | "file_extension": ".py", 226 | "mimetype": "text/x-python", 227 | "name": "python", 228 | "nbconvert_exporter": "python", 229 | "pygments_lexer": "ipython3", 230 | "version": "3.9.7" 231 | } 232 | }, 233 | "nbformat": 4, 234 | "nbformat_minor": 2 235 | } 236 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | 3 | [mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub] 4 | ignore_missing_imports = True 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # please make sure you have already a pytorch install that is cuda enabled! 
2 | av 3 | einops 4 | flashy>=0.0.1 5 | hydra-core>=1.1 6 | hydra_colorlog 7 | julius 8 | num2words 9 | numpy 10 | sentencepiece 11 | spacy==3.5.2 12 | torch>=2.0.0 13 | torchaudio>=2.0.0 14 | huggingface_hub 15 | tqdm 16 | transformers 17 | xformers 18 | demucs 19 | librosa 20 | gradio 21 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pep8] 2 | max-line-length = 120 3 | 4 | [flake8] 5 | max-line-length = 120 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | 5 | This source code is licensed under the license found in the 6 | LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | 10 | from pathlib import Path 11 | 12 | from setuptools import setup, find_packages 13 | 14 | 15 | NAME = 'audiocraft' 16 | DESCRIPTION = 'Audio research library for PyTorch' 17 | 18 | URL = 'https://github.com/fairinternal/audiocraft' 19 | AUTHOR = 'FAIR Speech & Audio' 20 | EMAIL = 'defossez@meta.com' 21 | REQUIRES_PYTHON = '>=3.8.0' 22 | 23 | for line in open('audiocraft/__init__.py'): 24 | line = line.strip() 25 | if '__version__' in line: 26 | context = {} 27 | exec(line, context) 28 | VERSION = context['__version__'] 29 | 30 | HERE = Path(__file__).parent 31 | 32 | try: 33 | with open(HERE / "README.md", encoding='utf-8') as f: 34 | long_description = '\n' + f.read() 35 | except FileNotFoundError: 36 | long_description = DESCRIPTION 37 | 38 | REQUIRED = [i.strip() for i in open(HERE / 'requirements.txt') if not i.startswith('#')] 39 | 40 | setup( 41 | name=NAME, 42 | version=VERSION, 43 | description=DESCRIPTION, 44 | author_email=EMAIL, 45 | long_description=long_description, 46 | long_description_content_type='text/markdown', 47 | author=AUTHOR, 48 | url=URL, 49 | python_requires=REQUIRES_PYTHON, 50 | install_requires=REQUIRED, 51 | extras_require={ 52 | 'dev': ['coverage', 'flake8', 'mypy', 'pdoc3', 'pytest'], 53 | }, 54 | packages=find_packages(), 55 | package_data={'audiocraft': ['py.typed']}, 56 | include_package_data=True, 57 | license='MIT License', 58 | classifiers=[ 59 | # Trove classifiers 60 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 61 | 'License :: OSI Approved :: MIT License', 62 | 'Topic :: Multimedia :: Sound/Audio', 63 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 64 | ], 65 | ) 66 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/common_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # flake8: noqa 8 | from .temp_utils import TempDirMixin 9 | from .wav_utils import get_batch_white_noise, get_white_noise, save_wav 10 | -------------------------------------------------------------------------------- /tests/common_utils/temp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import tempfile 9 | 10 | 11 | class TempDirMixin: 12 | """Mixin to provide easy access to temp dir. 13 | """ 14 | 15 | temp_dir_ = None 16 | 17 | @classmethod 18 | def get_base_temp_dir(cls): 19 | # If AUDIOCRAFT_TEST_DIR is set, use it instead of temporary directory. 20 | # this is handy for debugging. 21 | key = "AUDIOCRAFT_TEST_DIR" 22 | if key in os.environ: 23 | return os.environ[key] 24 | if cls.temp_dir_ is None: 25 | cls.temp_dir_ = tempfile.TemporaryDirectory() 26 | return cls.temp_dir_.name 27 | 28 | @classmethod 29 | def tearDownClass(cls): 30 | if cls.temp_dir_ is not None: 31 | try: 32 | cls.temp_dir_.cleanup() 33 | cls.temp_dir_ = None 34 | except PermissionError: 35 | # On Windows there is a know issue with `shutil.rmtree`, 36 | # which fails intermittenly. 37 | # https://github.com/python/cpython/issues/74168 38 | # Following the above thread, we ignore it. 39 | pass 40 | super().tearDownClass() 41 | 42 | @property 43 | def id(self): 44 | return self.__class__.__name__ 45 | 46 | def get_temp_path(self, *paths): 47 | temp_dir = os.path.join(self.get_base_temp_dir(), self.id) 48 | path = os.path.join(temp_dir, *paths) 49 | os.makedirs(os.path.dirname(path), exist_ok=True) 50 | return path 51 | 52 | def get_temp_dir(self, *paths): 53 | temp_dir = os.path.join(self.get_base_temp_dir(), self.id) 54 | path = os.path.join(temp_dir, *paths) 55 | os.makedirs(path, exist_ok=True) 56 | return path 57 | -------------------------------------------------------------------------------- /tests/common_utils/wav_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from pathlib import Path 8 | import typing as tp 9 | 10 | import torch 11 | import torchaudio 12 | 13 | 14 | def get_white_noise(chs: int = 1, num_frames: int = 1): 15 | wav = torch.randn(chs, num_frames) 16 | return wav 17 | 18 | 19 | def get_batch_white_noise(bs: int = 1, chs: int = 1, num_frames: int = 1): 20 | wav = torch.randn(bs, chs, num_frames) 21 | return wav 22 | 23 | 24 | def save_wav(path: str, wav: torch.Tensor, sample_rate: int): 25 | fp = Path(path) 26 | kwargs: tp.Dict[str, tp.Any] = {} 27 | if fp.suffix == '.wav': 28 | kwargs['encoding'] = 'PCM_S' 29 | kwargs['bits_per_sample'] = 16 30 | elif fp.suffix == '.mp3': 31 | kwargs['compression'] = 320 32 | torchaudio.save(str(fp), wav, sample_rate, **kwargs) 33 | -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/data/test_audio.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from itertools import product 8 | import random 9 | 10 | import numpy as np 11 | import torch 12 | import torchaudio 13 | 14 | from audiocraft.data.audio import audio_info, audio_read, audio_write, _av_read 15 | 16 | from ..common_utils import TempDirMixin, get_white_noise, save_wav 17 | 18 | 19 | class TestInfo(TempDirMixin): 20 | 21 | def test_info_mp3(self): 22 | sample_rates = [8000, 16_000] 23 | channels = [1, 2] 24 | duration = 1. 25 | for sample_rate, ch in product(sample_rates, channels): 26 | wav = get_white_noise(ch, int(sample_rate * duration)) 27 | path = self.get_temp_path('sample_wav.mp3') 28 | save_wav(path, wav, sample_rate) 29 | info = audio_info(path) 30 | assert info.sample_rate == sample_rate 31 | assert info.channels == ch 32 | # we cannot trust torchaudio for num_frames, so we don't check 33 | 34 | def _test_info_format(self, ext: str): 35 | sample_rates = [8000, 16_000] 36 | channels = [1, 2] 37 | duration = 1. 38 | for sample_rate, ch in product(sample_rates, channels): 39 | n_frames = int(sample_rate * duration) 40 | wav = get_white_noise(ch, n_frames) 41 | path = self.get_temp_path(f'sample_wav{ext}') 42 | save_wav(path, wav, sample_rate) 43 | info = audio_info(path) 44 | assert info.sample_rate == sample_rate 45 | assert info.channels == ch 46 | assert np.isclose(info.duration, duration, atol=1e-5) 47 | 48 | def test_info_wav(self): 49 | self._test_info_format('.wav') 50 | 51 | def test_info_flac(self): 52 | self._test_info_format('.flac') 53 | 54 | def test_info_ogg(self): 55 | self._test_info_format('.ogg') 56 | 57 | def test_info_m4a(self): 58 | # TODO: generate m4a file programmatically 59 | # self._test_info_format('.m4a') 60 | pass 61 | 62 | 63 | class TestRead(TempDirMixin): 64 | 65 | def test_read_full_wav(self): 66 | sample_rates = [8000, 16_000] 67 | channels = [1, 2] 68 | duration = 1. 69 | for sample_rate, ch in product(sample_rates, channels): 70 | n_frames = int(sample_rate * duration) 71 | wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99) 72 | path = self.get_temp_path('sample_wav.wav') 73 | save_wav(path, wav, sample_rate) 74 | read_wav, read_sr = audio_read(path) 75 | assert read_sr == sample_rate 76 | assert read_wav.shape[0] == wav.shape[0] 77 | assert read_wav.shape[1] == wav.shape[1] 78 | assert torch.allclose(read_wav, wav, rtol=1e-03, atol=1e-04) 79 | 80 | def test_read_partial_wav(self): 81 | sample_rates = [8000, 16_000] 82 | channels = [1, 2] 83 | duration = 1. 
84 | read_duration = torch.rand(1).item() 85 | for sample_rate, ch in product(sample_rates, channels): 86 | n_frames = int(sample_rate * duration) 87 | read_frames = int(sample_rate * read_duration) 88 | wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99) 89 | path = self.get_temp_path('sample_wav.wav') 90 | save_wav(path, wav, sample_rate) 91 | read_wav, read_sr = audio_read(path, 0, read_duration) 92 | assert read_sr == sample_rate 93 | assert read_wav.shape[0] == wav.shape[0] 94 | assert read_wav.shape[1] == read_frames 95 | assert torch.allclose(read_wav[..., 0:read_frames], wav[..., 0:read_frames], rtol=1e-03, atol=1e-04) 96 | 97 | def test_read_seek_time_wav(self): 98 | sample_rates = [8000, 16_000] 99 | channels = [1, 2] 100 | duration = 1. 101 | read_duration = 1. 102 | for sample_rate, ch in product(sample_rates, channels): 103 | n_frames = int(sample_rate * duration) 104 | wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99) 105 | path = self.get_temp_path('sample_wav.wav') 106 | save_wav(path, wav, sample_rate) 107 | seek_time = torch.rand(1).item() 108 | read_wav, read_sr = audio_read(path, seek_time, read_duration) 109 | seek_frames = int(sample_rate * seek_time) 110 | expected_frames = n_frames - seek_frames 111 | assert read_sr == sample_rate 112 | assert read_wav.shape[0] == wav.shape[0] 113 | assert read_wav.shape[1] == expected_frames 114 | assert torch.allclose(read_wav, wav[..., seek_frames:], rtol=1e-03, atol=1e-04) 115 | 116 | def test_read_seek_time_wav_padded(self): 117 | sample_rates = [8000, 16_000] 118 | channels = [1, 2] 119 | duration = 1. 120 | read_duration = 1. 121 | for sample_rate, ch in product(sample_rates, channels): 122 | n_frames = int(sample_rate * duration) 123 | read_frames = int(sample_rate * read_duration) 124 | wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99) 125 | path = self.get_temp_path('sample_wav.wav') 126 | save_wav(path, wav, sample_rate) 127 | seek_time = torch.rand(1).item() 128 | seek_frames = int(sample_rate * seek_time) 129 | expected_frames = n_frames - seek_frames 130 | read_wav, read_sr = audio_read(path, seek_time, read_duration, pad=True) 131 | expected_pad_wav = torch.zeros(wav.shape[0], read_frames - expected_frames) 132 | assert read_sr == sample_rate 133 | assert read_wav.shape[0] == wav.shape[0] 134 | assert read_wav.shape[1] == read_frames 135 | assert torch.allclose(read_wav[..., :expected_frames], wav[..., seek_frames:], rtol=1e-03, atol=1e-04) 136 | assert torch.allclose(read_wav[..., expected_frames:], expected_pad_wav) 137 | 138 | 139 | class TestAvRead(TempDirMixin): 140 | 141 | def test_avread_seek_base(self): 142 | sample_rates = [8000, 16_000] 143 | channels = [1, 2] 144 | duration = 2. 145 | for sample_rate, ch in product(sample_rates, channels): 146 | n_frames = int(sample_rate * duration) 147 | wav = get_white_noise(ch, n_frames) 148 | path = self.get_temp_path(f'reference_a_{sample_rate}_{ch}.wav') 149 | save_wav(path, wav, sample_rate) 150 | for _ in range(100): 151 | # seek will always load a full duration segment in the file 152 | seek_time = random.uniform(0.0, 1.0) 153 | seek_duration = random.uniform(0.001, 1.0) 154 | read_wav, read_sr = _av_read(path, seek_time, seek_duration) 155 | assert read_sr == sample_rate 156 | assert read_wav.shape[0] == wav.shape[0] 157 | assert read_wav.shape[-1] == int(seek_duration * sample_rate) 158 | 159 | def test_avread_seek_partial(self): 160 | sample_rates = [8000, 16_000] 161 | channels = [1, 2] 162 | duration = 1. 
163 | for sample_rate, ch in product(sample_rates, channels): 164 | n_frames = int(sample_rate * duration) 165 | wav = get_white_noise(ch, n_frames) 166 | path = self.get_temp_path(f'reference_b_{sample_rate}_{ch}.wav') 167 | save_wav(path, wav, sample_rate) 168 | for _ in range(100): 169 | # seek will always load a partial segment 170 | seek_time = random.uniform(0.5, 1.) 171 | seek_duration = 1. 172 | expected_num_frames = n_frames - int(seek_time * sample_rate) 173 | read_wav, read_sr = _av_read(path, seek_time, seek_duration) 174 | assert read_sr == sample_rate 175 | assert read_wav.shape[0] == wav.shape[0] 176 | assert read_wav.shape[-1] == expected_num_frames 177 | 178 | def test_avread_seek_outofbound(self): 179 | sample_rates = [8000, 16_000] 180 | channels = [1, 2] 181 | duration = 1. 182 | for sample_rate, ch in product(sample_rates, channels): 183 | n_frames = int(sample_rate * duration) 184 | wav = get_white_noise(ch, n_frames) 185 | path = self.get_temp_path(f'reference_c_{sample_rate}_{ch}.wav') 186 | save_wav(path, wav, sample_rate) 187 | seek_time = 1.5 188 | read_wav, read_sr = _av_read(path, seek_time, 1.) 189 | assert read_sr == sample_rate 190 | assert read_wav.shape[0] == wav.shape[0] 191 | assert read_wav.shape[-1] == 0 192 | 193 | def test_avread_seek_edge(self): 194 | sample_rates = [8000, 16_000] 195 | # some of these values will have 196 | # int(((frames - 1) / sample_rate) * sample_rate) != (frames - 1) 197 | n_frames = [1000, 1001, 1002] 198 | channels = [1, 2] 199 | for sample_rate, ch, frames in product(sample_rates, channels, n_frames): 200 | duration = frames / sample_rate 201 | wav = get_white_noise(ch, frames) 202 | path = self.get_temp_path(f'reference_d_{sample_rate}_{ch}.wav') 203 | save_wav(path, wav, sample_rate) 204 | seek_time = (frames - 1) / sample_rate 205 | seek_frames = int(seek_time * sample_rate) 206 | read_wav, read_sr = _av_read(path, seek_time, duration) 207 | assert read_sr == sample_rate 208 | assert read_wav.shape[0] == wav.shape[0] 209 | assert read_wav.shape[-1] == (frames - seek_frames) 210 | 211 | 212 | class TestAudioWrite(TempDirMixin): 213 | 214 | def test_audio_write_wav(self): 215 | torch.manual_seed(1234) 216 | sample_rates = [8000, 16_000] 217 | n_frames = [1000, 1001, 1002] 218 | channels = [1, 2] 219 | strategies = ["peak", "clip", "rms"] 220 | formats = ["wav", "mp3"] 221 | for sample_rate, ch, frames in product(sample_rates, channels, n_frames): 222 | for format_, strategy in product(formats, strategies): 223 | wav = get_white_noise(ch, frames) 224 | path = self.get_temp_path(f'pred_{sample_rate}_{ch}') 225 | audio_write(path, wav, sample_rate, format_, strategy=strategy) 226 | read_wav, read_sr = torchaudio.load(f'{path}.{format_}') 227 | if format_ == "wav": 228 | assert read_wav.shape == wav.shape 229 | 230 | if format_ == "wav" and strategy in ["peak", "rms"]: 231 | rescaled_read_wav = read_wav / read_wav.abs().max() * wav.abs().max() 232 | # for a Gaussian, the typical max scale will be less than ~5x the std. 233 | # The error when writing to disk will ~ 1/2**15, and when rescaling, 5x that. 
234 | # For RMS target, rescaling leaves more headroom by default, leading 235 | # to a 20x rescaling typically 236 | atol = (5 if strategy == "peak" else 20) / 2**15 237 | delta = (rescaled_read_wav - wav).abs().max() 238 | assert torch.allclose(wav, rescaled_read_wav, rtol=0, atol=atol), (delta, atol) 239 | formats = ["wav"] # faster unit tests 240 | -------------------------------------------------------------------------------- /tests/data/test_audio_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import julius 8 | import torch 9 | import pytest 10 | 11 | from audiocraft.data.audio_utils import ( 12 | _clip_wav, 13 | convert_audio_channels, 14 | convert_audio, 15 | normalize_audio 16 | ) 17 | from ..common_utils import get_batch_white_noise 18 | 19 | 20 | class TestConvertAudioChannels: 21 | 22 | def test_convert_audio_channels_downmix(self): 23 | b, c, t = 2, 3, 100 24 | audio = get_batch_white_noise(b, c, t) 25 | mixed = convert_audio_channels(audio, channels=2) 26 | assert list(mixed.shape) == [b, 2, t] 27 | 28 | def test_convert_audio_channels_nochange(self): 29 | b, c, t = 2, 3, 100 30 | audio = get_batch_white_noise(b, c, t) 31 | mixed = convert_audio_channels(audio, channels=c) 32 | assert list(mixed.shape) == list(audio.shape) 33 | 34 | def test_convert_audio_channels_upmix(self): 35 | b, c, t = 2, 1, 100 36 | audio = get_batch_white_noise(b, c, t) 37 | mixed = convert_audio_channels(audio, channels=3) 38 | assert list(mixed.shape) == [b, 3, t] 39 | 40 | def test_convert_audio_channels_upmix_error(self): 41 | b, c, t = 2, 2, 100 42 | audio = get_batch_white_noise(b, c, t) 43 | with pytest.raises(ValueError): 44 | convert_audio_channels(audio, channels=3) 45 | 46 | 47 | class TestConvertAudio: 48 | 49 | def test_convert_audio_channels_downmix(self): 50 | b, c, dur = 2, 3, 4. 51 | sr = 128 52 | audio = get_batch_white_noise(b, c, int(sr * dur)) 53 | out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=2) 54 | assert list(out.shape) == [audio.shape[0], 2, audio.shape[-1]] 55 | 56 | def test_convert_audio_channels_upmix(self): 57 | b, c, dur = 2, 1, 4. 58 | sr = 128 59 | audio = get_batch_white_noise(b, c, int(sr * dur)) 60 | out = convert_audio(audio, from_rate=sr, to_rate=sr, to_channels=3) 61 | assert list(out.shape) == [audio.shape[0], 3, audio.shape[-1]] 62 | 63 | def test_convert_audio_upsample(self): 64 | b, c, dur = 2, 1, 4. 65 | sr = 2 66 | new_sr = 3 67 | audio = get_batch_white_noise(b, c, int(sr * dur)) 68 | out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c) 69 | out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr) 70 | assert torch.allclose(out, out_j) 71 | 72 | def test_convert_audio_resample(self): 73 | b, c, dur = 2, 1, 4. 74 | sr = 3 75 | new_sr = 2 76 | audio = get_batch_white_noise(b, c, int(sr * dur)) 77 | out = convert_audio(audio, from_rate=sr, to_rate=new_sr, to_channels=c) 78 | out_j = julius.resample.resample_frac(audio, old_sr=sr, new_sr=new_sr) 79 | assert torch.allclose(out, out_j) 80 | 81 | 82 | class TestNormalizeAudio: 83 | 84 | def test_clip_wav(self): 85 | b, c, dur = 2, 1, 4. 
86 | sr = 3 87 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) 88 | _clip_wav(audio) 89 | assert audio.abs().max() <= 1 90 | 91 | def test_normalize_audio_clip(self): 92 | b, c, dur = 2, 1, 4. 93 | sr = 3 94 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) 95 | norm_audio = normalize_audio(audio, strategy='clip') 96 | assert norm_audio.abs().max() <= 1 97 | 98 | def test_normalize_audio_rms(self): 99 | b, c, dur = 2, 1, 4. 100 | sr = 3 101 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) 102 | norm_audio = normalize_audio(audio, strategy='rms') 103 | assert norm_audio.abs().max() <= 1 104 | 105 | def test_normalize_audio_peak(self): 106 | b, c, dur = 2, 1, 4. 107 | sr = 3 108 | audio = 10.0 * get_batch_white_noise(b, c, int(sr * dur)) 109 | norm_audio = normalize_audio(audio, strategy='peak') 110 | assert norm_audio.abs().max() <= 1 111 | -------------------------------------------------------------------------------- /tests/models/test_encodec_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import random 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from audiocraft.models import EncodecModel 13 | from audiocraft.modules import SEANetEncoder, SEANetDecoder 14 | from audiocraft.quantization import DummyQuantizer 15 | 16 | 17 | class TestEncodecModel: 18 | 19 | def _create_encodec_model(self, 20 | sample_rate: int, 21 | channels: int, 22 | dim: int = 5, 23 | n_filters: int = 3, 24 | n_residual_layers: int = 1, 25 | ratios: list = [5, 4, 3, 2], 26 | **kwargs): 27 | frame_rate = np.prod(ratios) 28 | encoder = SEANetEncoder(channels=channels, dimension=dim, n_filters=n_filters, 29 | n_residual_layers=n_residual_layers, ratios=ratios) 30 | decoder = SEANetDecoder(channels=channels, dimension=dim, n_filters=n_filters, 31 | n_residual_layers=n_residual_layers, ratios=ratios) 32 | quantizer = DummyQuantizer() 33 | model = EncodecModel(encoder, decoder, quantizer, frame_rate=frame_rate, 34 | sample_rate=sample_rate, channels=channels, **kwargs) 35 | return model 36 | 37 | def test_model(self): 38 | random.seed(1234) 39 | sample_rate = 24_000 40 | channels = 1 41 | model = self._create_encodec_model(sample_rate, channels) 42 | for _ in range(10): 43 | length = random.randrange(1, 10_000) 44 | x = torch.randn(2, channels, length) 45 | res = model(x) 46 | assert res.x.shape == x.shape 47 | 48 | def test_model_renorm(self): 49 | random.seed(1234) 50 | sample_rate = 24_000 51 | channels = 1 52 | model_nonorm = self._create_encodec_model(sample_rate, channels, renormalize=False) 53 | model_renorm = self._create_encodec_model(sample_rate, channels, renormalize=True) 54 | 55 | for _ in range(10): 56 | length = random.randrange(1, 10_000) 57 | x = torch.randn(2, channels, length) 58 | codes, scales = model_nonorm.encode(x) 59 | codes, scales = model_renorm.encode(x) 60 | assert scales is not None 61 | -------------------------------------------------------------------------------- /tests/models/test_musicgen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import pytest
8 | import torch
9 | 
10 | from audiocraft.models import MusicGen
11 | 
12 | 
13 | class TestMusicGen:
14 | def get_musicgen(self):
15 | mg = MusicGen.get_pretrained(name='debug', device='cpu')
16 | mg.set_generation_params(duration=2.0)
17 | return mg
18 | 
19 | def test_base(self):
20 | mg = self.get_musicgen()
21 | assert mg.frame_rate == 25
22 | assert mg.sample_rate == 32000
23 | assert mg.audio_channels == 1
24 | 
25 | def test_generate_unconditional(self):
26 | mg = self.get_musicgen()
27 | wav = mg.generate_unconditional(3)
28 | assert list(wav.shape) == [3, 1, 64000]
29 | 
30 | def test_generate_continuation(self):
31 | mg = self.get_musicgen()
32 | prompt = torch.randn(3, 1, 32000)
33 | wav = mg.generate_continuation(prompt, 32000)
34 | assert list(wav.shape) == [3, 1, 64000]
35 | 
36 | prompt = torch.randn(2, 1, 32000)
37 | wav = mg.generate_continuation(
38 | prompt, 32000, ['youpi', 'lapin dort'])
39 | assert list(wav.shape) == [2, 1, 64000]
40 | 
41 | prompt = torch.randn(2, 1, 32000)
42 | with pytest.raises(AssertionError):
43 | wav = mg.generate_continuation(
44 | prompt, 32000, ['youpi', 'lapin dort', 'one too many'])
45 | 
46 | def test_generate(self):
47 | mg = self.get_musicgen()
48 | wav = mg.generate(
49 | ['youpi', 'lapin dort'])
50 | assert list(wav.shape) == [2, 1, 64000]
51 | -------------------------------------------------------------------------------- /tests/modules/__init__.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | -------------------------------------------------------------------------------- /tests/modules/test_codebooks_patterns.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import pytest 8 | import torch 9 | 10 | from audiocraft.modules.codebooks_patterns import ( 11 | DelayedPatternProvider, 12 | ParallelPatternProvider, 13 | Pattern, 14 | UnrolledPatternProvider, 15 | ) 16 | 17 | 18 | class TestParallelPatternProvider: 19 | 20 | @pytest.mark.parametrize("n_q", [1, 4, 32]) 21 | @pytest.mark.parametrize("timesteps", [0, 1, 16, 100]) 22 | def test_get_pattern(self, n_q: int, timesteps: int): 23 | provider = ParallelPatternProvider(n_q) 24 | pattern = provider.get_pattern(timesteps) 25 | # + 1 to account for 1st step 26 | assert len(pattern.layout) == timesteps + 1 27 | 28 | @pytest.mark.parametrize("n_q", [1, 4, 32]) 29 | @pytest.mark.parametrize("timesteps", [8, 16, 100]) 30 | def test_pattern_content(self, n_q: int, timesteps: int): 31 | provider = ParallelPatternProvider(n_q) 32 | pattern = provider.get_pattern(timesteps) 33 | for s, v in enumerate(pattern.layout): 34 | for i, code in enumerate(v): 35 | assert i == code.q 36 | assert code.t == s - 1 # account for the 1st empty step 37 | 38 | @pytest.mark.parametrize("n_q", [1, 4, 32]) 39 | @pytest.mark.parametrize("timesteps", [8, 16, 100]) 40 | def test_pattern_max_delay(self, n_q: int, timesteps: int): 41 | provider = ParallelPatternProvider(n_q) 42 | pattern = provider.get_pattern(timesteps) 43 | assert pattern.max_delay == 0 44 | assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay 45 | 46 | 47 | class TestDelayedPatternProvider: 48 | 49 | @pytest.mark.parametrize("n_q", [1, 4, 32]) 50 | @pytest.mark.parametrize("timesteps", [0, 1, 16, 100]) 51 | def test_get_pattern(self, n_q: int, timesteps: int): 52 | delays = [ 53 | list(range(n_q)), 54 | [0] + [1] * (n_q - 1), 55 | [0] + [4] * (n_q - 1), 56 | ] 57 | for delay in delays: 58 | provider = DelayedPatternProvider(n_q, delay) 59 | pattern = provider.get_pattern(timesteps) 60 | # + 1 to account for 1st step 61 | assert len(pattern.layout) == timesteps + max(delay) + 1 62 | 63 | @pytest.mark.parametrize("n_q", [1, 4, 32]) 64 | @pytest.mark.parametrize("timesteps", [8, 16, 100]) 65 | def test_pattern_content(self, n_q: int, timesteps: int): 66 | provider = DelayedPatternProvider(n_q) 67 | pattern = provider.get_pattern(timesteps) 68 | for s, v in enumerate(pattern.layout): 69 | for i, code in enumerate(v): 70 | assert i == code.q 71 | assert code.t == max(0, s - code.q - 1) 72 | 73 | @pytest.mark.parametrize("timesteps", [8, 16, 100]) 74 | @pytest.mark.parametrize("delay", [[0, 1, 2, 3], [0, 1, 1, 1], [0, 3, 3, 3], [0, 3]]) 75 | def test_pattern_max_delay(self, timesteps: int, delay: list): 76 | provider = DelayedPatternProvider(len(delay), delay) 77 | pattern = provider.get_pattern(timesteps) 78 | assert pattern.max_delay == max(delay) 79 | assert len(pattern.valid_layout) == len(pattern.layout) - pattern.max_delay 80 | 81 | 82 | class TestUnrolledPatternProvider: 83 | 84 | @pytest.mark.parametrize("timesteps", [0, 1, 16]) 85 | @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]]) 86 | @pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]]) 87 | def test_get_pattern(self, timesteps: int, flattening: list, delays: list): 88 | n_q = len(flattening) 89 | max_delay = max(delays) 90 | provider = UnrolledPatternProvider(n_q, flattening, delays) 91 | pattern = provider.get_pattern(timesteps) 92 | assert len(pattern.layout) == provider.num_virtual_steps(timesteps) + max_delay 93 | 94 | @pytest.mark.parametrize("timesteps", [0, 1, 16]) 95 | @pytest.mark.parametrize("flattening", [[0, 1, 2], [0, 1, 1]]) 96 | 
@pytest.mark.parametrize("delays", [[0, 0, 0], [0, 5, 5]]) 97 | def test_pattern_max_delay(self, timesteps: int, flattening: list, delays: list): 98 | n_q = len(flattening) 99 | max_delay = max(delays) 100 | provider = UnrolledPatternProvider(n_q, flattening, delays) 101 | pattern = provider.get_pattern(timesteps) 102 | assert pattern.max_delay == max_delay 103 | 104 | 105 | class TestPattern: 106 | 107 | def ref_build_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int): 108 | """Reference method to build the sequence from the pattern without using fancy scatter.""" 109 | bs, n_q, T = z.shape 110 | z = z.cpu().numpy() 111 | assert n_q == pattern.n_q 112 | assert T <= pattern.timesteps 113 | inp = torch.full((bs, n_q, len(pattern.layout)), special_token, dtype=torch.long).numpy() 114 | inp[:] = special_token 115 | for s, v in enumerate(pattern.layout): 116 | for (t, q) in v: 117 | if t < T: 118 | inp[:, q, s] = z[:, q, t] 119 | return torch.from_numpy(inp) 120 | 121 | def ref_revert_pattern_sequence(self, z: torch.Tensor, pattern: Pattern, special_token: int): 122 | """Reference method to revert the sequence from the pattern without using fancy scatter.""" 123 | z = z.cpu().numpy() 124 | bs, n_q, S = z.shape 125 | assert pattern.n_q == n_q 126 | inp = torch.full((bs, pattern.n_q, pattern.timesteps), special_token, dtype=torch.long).numpy() 127 | inp[:] = special_token 128 | for s, v in enumerate(pattern.layout): 129 | for (t, q) in v: 130 | if t < pattern.timesteps: 131 | inp[:, q, t] = z[:, q, s] 132 | return torch.from_numpy(inp) 133 | 134 | def ref_revert_pattern_logits(self, z: torch.Tensor, pattern: Pattern, special_token: float): 135 | """Reference method to revert the logits from the pattern without using fancy scatter.""" 136 | z = z.cpu().numpy() 137 | bs, card, n_q, S = z.shape 138 | assert pattern.n_q == n_q 139 | ref_layout = pattern.layout 140 | inp = torch.full((bs, card, pattern.n_q, pattern.timesteps), special_token, dtype=torch.float).numpy() 141 | inp[:] = special_token 142 | for s, v in enumerate(ref_layout[1:]): 143 | if s < S: 144 | for (t, q) in v: 145 | if t < pattern.timesteps: 146 | inp[:, :, q, t] = z[:, :, q, s] 147 | return torch.from_numpy(inp) 148 | 149 | def _get_pattern_providers(self, n_q: int): 150 | pattern_provider_1 = ParallelPatternProvider(n_q) 151 | pattern_provider_2 = DelayedPatternProvider(n_q, list(range(n_q))) 152 | pattern_provider_3 = DelayedPatternProvider(n_q, [0] + [1] * (n_q - 1)) 153 | pattern_provider_4 = UnrolledPatternProvider( 154 | n_q, flattening=list(range(n_q)), delays=[0] * n_q 155 | ) 156 | pattern_provider_5 = UnrolledPatternProvider( 157 | n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] * n_q 158 | ) 159 | pattern_provider_6 = UnrolledPatternProvider( 160 | n_q, flattening=[0] + [1] * (n_q - 1), delays=[0] + [5] * (n_q - 1) 161 | ) 162 | return [ 163 | pattern_provider_1, 164 | pattern_provider_2, 165 | pattern_provider_3, 166 | pattern_provider_4, 167 | pattern_provider_5, 168 | pattern_provider_6, 169 | ] 170 | 171 | @pytest.mark.parametrize("n_q", [1, 4, 32]) 172 | @pytest.mark.parametrize("timesteps", [16, 72]) 173 | def test_build_pattern_sequence(self, n_q: int, timesteps: int): 174 | bs = 2 175 | card = 256 176 | special_token = card 177 | 178 | pattern_providers = self._get_pattern_providers(n_q) 179 | for pattern_provider in pattern_providers: 180 | pattern = pattern_provider.get_pattern(timesteps) 181 | # we can correctly build the sequence from the pattern 182 | z = torch.randint(0, card, 
(bs, n_q, timesteps)) 183 | ref_res = self.ref_build_pattern_sequence(z, pattern, special_token) 184 | res, indexes, mask = pattern.build_pattern_sequence(z, special_token) 185 | assert (res == ref_res).float().mean() == 1.0 186 | 187 | # expected assertion failures on the number of timesteps 188 | invalid_timesteps = [timesteps + 1] 189 | if pattern.num_sequence_steps != pattern.timesteps: 190 | invalid_timesteps.append(pattern.num_sequence_steps) 191 | for i_timesteps in invalid_timesteps: 192 | z2 = torch.randint(0, card, (bs, n_q, i_timesteps)) 193 | with pytest.raises(AssertionError): 194 | pattern.build_pattern_sequence(z2, special_token) 195 | 196 | # expected assertion failures on the number of codebooks 197 | invalid_qs = [0, n_q - 1, n_q + 1] 198 | for i_q in invalid_qs: 199 | z3 = torch.randint(0, card, (bs, i_q, timesteps)) 200 | with pytest.raises(AssertionError): 201 | pattern.build_pattern_sequence(z3, special_token) 202 | 203 | @pytest.mark.parametrize("n_q", [1, 4, 32]) 204 | @pytest.mark.parametrize("timesteps", [16, 72]) 205 | def test_revert_pattern_sequence(self, n_q: int, timesteps: int): 206 | bs = 2 207 | card = 256 208 | special_token = card 209 | 210 | pattern_providers = self._get_pattern_providers(n_q) 211 | for pattern_provider in pattern_providers: 212 | pattern = pattern_provider.get_pattern(timesteps) 213 | # this works assuming previous tests are successful 214 | z = torch.randint(0, card, (bs, n_q, timesteps)) 215 | s = self.ref_build_pattern_sequence(z, pattern, special_token) 216 | ref_out = self.ref_revert_pattern_sequence(s, pattern, special_token) 217 | # ensure our reference script retrieves the original sequence 218 | assert z.shape == ref_out.shape 219 | assert (z == ref_out).float().mean() == 1.0 220 | # now we can test the scatter version 221 | out, indexes, mask = pattern.revert_pattern_sequence(s, special_token) 222 | assert out.shape == ref_out.shape 223 | assert (out == ref_out).float().mean() == 1.0 224 | 225 | @pytest.mark.parametrize("n_q", [1, 4, 32]) 226 | @pytest.mark.parametrize("timesteps", [16, 72]) 227 | @pytest.mark.parametrize("card", [1, 2, 256, 1024]) 228 | def test_revert_pattern_logits(self, n_q: int, timesteps: int, card: int): 229 | bs = 2 230 | special_token = card 231 | logits_special_token = float('nan') 232 | 233 | pattern_providers = self._get_pattern_providers(n_q) 234 | for pattern_provider in pattern_providers: 235 | pattern = pattern_provider.get_pattern(timesteps) 236 | # this works assuming previous tests are successful 237 | z = torch.randint(0, card, (bs, n_q, timesteps)) 238 | s = self.ref_build_pattern_sequence(z, pattern, special_token) 239 | logits = torch.randn((bs, card, n_q, s.shape[-1])) 240 | ref_out = self.ref_revert_pattern_logits(logits, pattern, logits_special_token) 241 | # ensure our reference script retrieves the original sequence 242 | assert ref_out.shape == torch.Size([bs, card, n_q, timesteps]) 243 | # now we can test the scatter version 244 | out, indexes, mask = pattern.revert_pattern_logits(logits, logits_special_token) 245 | assert out.shape == ref_out.shape 246 | assert (out == ref_out).float().mean() == 1.0 247 | -------------------------------------------------------------------------------- /tests/modules/test_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from itertools import product 8 | import math 9 | import random 10 | 11 | import pytest 12 | import torch 13 | from torch import nn 14 | 15 | from audiocraft.modules import ( 16 | NormConv1d, 17 | NormConvTranspose1d, 18 | StreamableConv1d, 19 | StreamableConvTranspose1d, 20 | pad1d, 21 | unpad1d, 22 | ) 23 | 24 | 25 | def test_get_extra_padding_for_conv1d(): 26 | # TODO: Implement me! 27 | pass 28 | 29 | 30 | def test_pad1d_zeros(): 31 | x = torch.randn(1, 1, 20) 32 | 33 | xp1 = pad1d(x, (0, 5), mode='constant', value=0.) 34 | assert xp1.shape[-1] == 25 35 | xp2 = pad1d(x, (5, 5), mode='constant', value=0.) 36 | assert xp2.shape[-1] == 30 37 | xp3 = pad1d(x, (0, 0), mode='constant', value=0.) 38 | assert xp3.shape[-1] == 20 39 | xp4 = pad1d(x, (10, 30), mode='constant', value=0.) 40 | assert xp4.shape[-1] == 60 41 | 42 | with pytest.raises(AssertionError): 43 | pad1d(x, (-1, 0), mode='constant', value=0.) 44 | 45 | with pytest.raises(AssertionError): 46 | pad1d(x, (0, -1), mode='constant', value=0.) 47 | 48 | with pytest.raises(AssertionError): 49 | pad1d(x, (-1, -1), mode='constant', value=0.) 50 | 51 | 52 | def test_pad1d_reflect(): 53 | x = torch.randn(1, 1, 20) 54 | 55 | xp1 = pad1d(x, (0, 5), mode='reflect', value=0.) 56 | assert xp1.shape[-1] == 25 57 | xp2 = pad1d(x, (5, 5), mode='reflect', value=0.) 58 | assert xp2.shape[-1] == 30 59 | xp3 = pad1d(x, (0, 0), mode='reflect', value=0.) 60 | assert xp3.shape[-1] == 20 61 | xp4 = pad1d(x, (10, 30), mode='reflect', value=0.) 62 | assert xp4.shape[-1] == 60 63 | 64 | with pytest.raises(AssertionError): 65 | pad1d(x, (-1, 0), mode='reflect', value=0.) 66 | 67 | with pytest.raises(AssertionError): 68 | pad1d(x, (0, -1), mode='reflect', value=0.) 69 | 70 | with pytest.raises(AssertionError): 71 | pad1d(x, (-1, -1), mode='reflect', value=0.) 
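# The pad1d tests above cover the main reason pad1d exists: plain
# torch.nn.functional.pad rejects reflect padding larger than the input
# length, while pad1d accepts it (the (10, 30) reflect case on a length-20
# input above). The sketch below is not part of the original test suite and
# only restates that contrast; the helper name is illustrative.
def _demo_pad1d_vs_f_pad():
    x = torch.randn(1, 1, 20)
    # Plain reflect padding larger than the input is rejected by PyTorch.
    with pytest.raises(RuntimeError):
        torch.nn.functional.pad(x, (10, 30), mode='reflect')
    # pad1d accepts the same request and returns the expected length.
    y = pad1d(x, (10, 30), mode='reflect')
    assert y.shape[-1] == 20 + 10 + 30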
72 | 73 | 74 | def test_unpad1d(): 75 | x = torch.randn(1, 1, 20) 76 | 77 | u1 = unpad1d(x, (5, 5)) 78 | assert u1.shape[-1] == 10 79 | u2 = unpad1d(x, (0, 5)) 80 | assert u2.shape[-1] == 15 81 | u3 = unpad1d(x, (5, 0)) 82 | assert u3.shape[-1] == 15 83 | u4 = unpad1d(x, (0, 0)) 84 | assert u4.shape[-1] == x.shape[-1] 85 | 86 | with pytest.raises(AssertionError): 87 | unpad1d(x, (-1, 0)) 88 | 89 | with pytest.raises(AssertionError): 90 | unpad1d(x, (0, -1)) 91 | 92 | with pytest.raises(AssertionError): 93 | unpad1d(x, (-1, -1)) 94 | 95 | 96 | class TestNormConv1d: 97 | 98 | def test_norm_conv1d_modules(self): 99 | N, C, T = 2, 2, random.randrange(1, 100_000) 100 | t0 = torch.randn(N, C, T) 101 | 102 | C_out, kernel_size, stride = 1, 4, 1 103 | expected_out_length = int((T - kernel_size) / stride + 1) 104 | wn_conv = NormConv1d(C, 1, kernel_size=4, norm='weight_norm') 105 | gn_conv = NormConv1d(C, 1, kernel_size=4, norm='time_group_norm') 106 | nn_conv = NormConv1d(C, 1, kernel_size=4, norm='none') 107 | 108 | assert isinstance(wn_conv.norm, nn.Identity) 109 | assert isinstance(wn_conv.conv, nn.Conv1d) 110 | 111 | assert isinstance(gn_conv.norm, nn.GroupNorm) 112 | assert isinstance(gn_conv.conv, nn.Conv1d) 113 | 114 | assert isinstance(nn_conv.norm, nn.Identity) 115 | assert isinstance(nn_conv.conv, nn.Conv1d) 116 | 117 | for conv_layer in [wn_conv, gn_conv, nn_conv]: 118 | out = conv_layer(t0) 119 | assert isinstance(out, torch.Tensor) 120 | assert list(out.shape) == [N, C_out, expected_out_length] 121 | 122 | 123 | class TestNormConvTranspose1d: 124 | 125 | def test_normalizations(self): 126 | N, C, T = 2, 2, random.randrange(1, 100_000) 127 | t0 = torch.randn(N, C, T) 128 | 129 | C_out, kernel_size, stride = 1, 4, 1 130 | expected_out_length = (T - 1) * stride + (kernel_size - 1) + 1 131 | 132 | wn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='weight_norm') 133 | gn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='time_group_norm') 134 | nn_convtr = NormConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, norm='none') 135 | 136 | assert isinstance(wn_convtr.norm, nn.Identity) 137 | assert isinstance(wn_convtr.convtr, nn.ConvTranspose1d) 138 | 139 | assert isinstance(gn_convtr.norm, nn.GroupNorm) 140 | assert isinstance(gn_convtr.convtr, nn.ConvTranspose1d) 141 | 142 | assert isinstance(nn_convtr.norm, nn.Identity) 143 | assert isinstance(nn_convtr.convtr, nn.ConvTranspose1d) 144 | 145 | for convtr_layer in [wn_convtr, gn_convtr, nn_convtr]: 146 | out = convtr_layer(t0) 147 | assert isinstance(out, torch.Tensor) 148 | assert list(out.shape) == [N, C_out, expected_out_length] 149 | 150 | 151 | class TestStreamableConv1d: 152 | 153 | def get_streamable_conv1d_output_length(self, length, kernel_size, stride, dilation): 154 | # StreamableConv1d internally pads to make sure that the last window is full 155 | padding_total = (kernel_size - 1) * dilation - (stride - 1) 156 | n_frames = (length - kernel_size + padding_total) / stride + 1 157 | ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) 158 | return ideal_length // stride 159 | 160 | def test_streamable_conv1d(self): 161 | N, C, T = 2, 2, random.randrange(1, 100_000) 162 | t0 = torch.randn(N, C, T) 163 | C_out = 1 164 | 165 | # conv params are [(kernel_size, stride, dilation)] 166 | conv_params = [(4, 1, 1), (4, 2, 1), (3, 1, 3), (10, 5, 1), (3, 2, 3)] 167 | for causal, (kernel_size, stride, dilation) in product([False, 
True], conv_params): 168 | expected_out_length = self.get_streamable_conv1d_output_length(T, kernel_size, stride, dilation) 169 | sconv = StreamableConv1d(C, C_out, kernel_size=kernel_size, stride=stride, dilation=dilation, causal=causal) 170 | out = sconv(t0) 171 | assert isinstance(out, torch.Tensor) 172 | print(list(out.shape), [N, C_out, expected_out_length]) 173 | assert list(out.shape) == [N, C_out, expected_out_length] 174 | 175 | 176 | class TestStreamableConvTranspose1d: 177 | 178 | def get_streamable_convtr1d_output_length(self, length, kernel_size, stride): 179 | padding_total = (kernel_size - stride) 180 | return (length - 1) * stride - padding_total + (kernel_size - 1) + 1 181 | 182 | def test_streamable_convtr1d(self): 183 | N, C, T = 2, 2, random.randrange(1, 100_000) 184 | t0 = torch.randn(N, C, T) 185 | 186 | C_out = 1 187 | 188 | with pytest.raises(AssertionError): 189 | StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=False, trim_right_ratio=0.5) 190 | StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=-1.) 191 | StreamableConvTranspose1d(C, C_out, kernel_size=4, causal=True, trim_right_ratio=2) 192 | 193 | # causal params are [(causal, trim_right)] 194 | causal_params = [(False, 1.0), (True, 1.0), (True, 0.5), (True, 0.0)] 195 | # conv params are [(kernel_size, stride)] 196 | conv_params = [(4, 1), (4, 2), (3, 1), (10, 5)] 197 | for ((causal, trim_right_ratio), (kernel_size, stride)) in product(causal_params, conv_params): 198 | expected_out_length = self.get_streamable_convtr1d_output_length(T, kernel_size, stride) 199 | sconvtr = StreamableConvTranspose1d(C, C_out, kernel_size=kernel_size, stride=stride, 200 | causal=causal, trim_right_ratio=trim_right_ratio) 201 | out = sconvtr(t0) 202 | assert isinstance(out, torch.Tensor) 203 | assert list(out.shape) == [N, C_out, expected_out_length] 204 | -------------------------------------------------------------------------------- /tests/modules/test_lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import random 8 | import torch 9 | 10 | from audiocraft.modules.lstm import StreamableLSTM 11 | 12 | 13 | class TestStreamableLSTM: 14 | 15 | def test_lstm(self): 16 | B, C, T = 4, 2, random.randint(1, 100) 17 | 18 | lstm = StreamableLSTM(C, 3, skip=False) 19 | x = torch.randn(B, C, T) 20 | y = lstm(x) 21 | 22 | print(y.shape) 23 | assert y.shape == torch.Size([B, C, T]) 24 | 25 | def test_lstm_skip(self): 26 | B, C, T = 4, 2, random.randint(1, 100) 27 | 28 | lstm = StreamableLSTM(C, 3, skip=True) 29 | x = torch.randn(B, C, T) 30 | y = lstm(x) 31 | 32 | assert y.shape == torch.Size([B, C, T]) 33 | -------------------------------------------------------------------------------- /tests/modules/test_rope.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
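# For orientation before the tests below: RotaryEmbedding.rotate_qk applies the
# standard rotary position embedding, rotating consecutive channel pairs of the
# queries and keys by position-dependent angles. The reference sketch below is
# editorial and shows only that core rotation, assuming the usual 10000 base
# frequency; the actual audiocraft implementation (dtype handling, xpos scaling)
# lives in audiocraft/modules/rope.py and may differ in details.
def _rope_reference(x, start=0, base=10000.0):
    """Rotate channel pairs of x [B, T, H, C] by angles (start + t) * base**(-2i / C)."""
    import torch  # local import so the sketch stands alone ahead of the file's own imports
    B, T, H, C = x.shape
    half = C // 2
    freqs = base ** (-torch.arange(half, dtype=torch.float32) * 2 / C)
    angles = (start + torch.arange(T, dtype=torch.float32))[:, None] * freqs  # [T, half]
    cos = angles.cos()[None, :, None, :]
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out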
6 | 7 | import torch 8 | 9 | from audiocraft.modules.rope import RotaryEmbedding 10 | from audiocraft.modules.transformer import StreamingTransformer 11 | 12 | 13 | def test_rope(): 14 | B, T, H, C = 8, 75, 16, 128 15 | 16 | rope = RotaryEmbedding(dim=C) 17 | xq = torch.rand((B, T, H, C)) 18 | xk = torch.rand((B, T, H, C)) 19 | xq_out, xk_out = rope.rotate_qk(xq, xk, start=7) 20 | 21 | assert list(xq_out.shape) == [B, T, H, C] 22 | assert list(xk_out.shape) == [B, T, H, C] 23 | 24 | 25 | def test_rope_io_dtypes(): 26 | B, T, H, C = 8, 75, 16, 128 27 | 28 | rope_32 = RotaryEmbedding(dim=C, dtype=torch.float32) 29 | rope_64 = RotaryEmbedding(dim=C, dtype=torch.float64) 30 | 31 | # Test bfloat16 inputs w/ both 32 and 64 precision rope. 32 | xq_16 = torch.rand((B, T, H, C)).to(torch.bfloat16) 33 | xk_16 = torch.rand((B, T, H, C)).to(torch.bfloat16) 34 | xq_out, xk_out = rope_32.rotate_qk(xq_16, xk_16) 35 | assert xq_out.dtype == torch.bfloat16 36 | xq_out, xk_out = rope_64.rotate_qk(xq_16, xk_16) 37 | assert xq_out.dtype == torch.bfloat16 38 | 39 | # Test float32 inputs w/ both 32 and 64 precision rope. 40 | xq_32 = torch.rand((B, T, H, C)).to(torch.float32) 41 | xk_32 = torch.rand((B, T, H, C)).to(torch.float32) 42 | xq_out, xk_out = rope_32.rotate_qk(xq_32, xk_32) 43 | assert xq_out.dtype == torch.float32 44 | xq_out, xk_out = rope_64.rotate_qk(xq_32, xk_32) 45 | assert xq_out.dtype == torch.float32 46 | 47 | 48 | def test_transformer_with_rope(): 49 | torch.manual_seed(1234) 50 | for pos in ['rope', 'sin_rope']: 51 | tr = StreamingTransformer( 52 | 16, 4, 2, custom=True, dropout=0., layer_scale=0.1, 53 | positional_embedding=pos) 54 | tr.eval() 55 | steps = 12 56 | x = torch.randn(3, steps, 16) 57 | 58 | out = tr(x) 59 | assert list(out.shape) == list(x.shape) 60 | 61 | 62 | @torch.no_grad() 63 | def test_rope_streaming(): 64 | torch.manual_seed(1234) 65 | tr = StreamingTransformer( 66 | 16, 4, 2, causal=True, dropout=0., 67 | custom=True, positional_embedding='rope') 68 | tr.eval() 69 | steps = 12 70 | x = torch.randn(3, steps, 16) 71 | 72 | ref = tr(x) 73 | 74 | with tr.streaming(): 75 | outs = [] 76 | frame_sizes = [1] * steps 77 | 78 | for frame_size in frame_sizes: 79 | frame = x[:, :frame_size] 80 | x = x[:, frame_size:] 81 | outs.append(tr(frame)) 82 | 83 | out = torch.cat(outs, dim=1) 84 | assert list(out.shape) == [3, steps, 16] 85 | delta = torch.norm(out - ref) / torch.norm(out) 86 | assert delta < 1e-6, delta 87 | 88 | 89 | @torch.no_grad() 90 | def test_rope_streaming_past_context(): 91 | torch.manual_seed(1234) 92 | 93 | for context in [None, 10]: 94 | tr = StreamingTransformer( 95 | 16, 4, 1 if context else 2, 96 | causal=True, past_context=context, custom=True, 97 | dropout=0., positional_embedding='rope') 98 | tr.eval() 99 | 100 | steps = 20 101 | x = torch.randn(3, steps, 16) 102 | ref = tr(x) 103 | 104 | with tr.streaming(): 105 | outs = [] 106 | frame_sizes = [1] * steps 107 | 108 | for frame_size in frame_sizes: 109 | frame = x[:, :frame_size] 110 | x = x[:, frame_size:] 111 | outs.append(tr(frame)) 112 | 113 | out = torch.cat(outs, dim=1) 114 | assert list(out.shape) == [3, steps, 16] 115 | delta = torch.norm(out - ref) / torch.norm(out) 116 | assert delta < 1e-6, delta 117 | 118 | 119 | def test_rope_memory_efficient(): 120 | torch.manual_seed(1234) 121 | tr = StreamingTransformer( 122 | 16, 4, 2, custom=True, dropout=0., layer_scale=0.1, 123 | positional_embedding='rope') 124 | tr_mem_efficient = StreamingTransformer( 125 | 16, 4, 2, dropout=0., 
memory_efficient=True, layer_scale=0.1, 126 | positional_embedding='rope') 127 | tr_mem_efficient.load_state_dict(tr.state_dict()) 128 | tr.eval() 129 | steps = 12 130 | x = torch.randn(3, steps, 16) 131 | 132 | with torch.no_grad(): 133 | y = tr(x) 134 | y2 = tr_mem_efficient(x) 135 | # Check at float precision b/c this is the rope default. 136 | assert torch.allclose(y, y2, atol=1e-7), (y - y2).norm() 137 | 138 | 139 | def test_rope_with_xpos(): 140 | B, T, H, C = 8, 75, 16, 128 141 | 142 | rope = RotaryEmbedding(dim=C, xpos=True) 143 | xq = torch.rand((B, T, H, C)) 144 | xk = torch.rand((B, T, H, C)) 145 | xq_out, xk_out = rope.rotate_qk(xq, xk, start=7) 146 | 147 | assert list(xq_out.shape) == [B, T, H, C] 148 | assert list(xk_out.shape) == [B, T, H, C] 149 | 150 | 151 | def test_positional_scale(): 152 | B, T, H, C = 8, 75, 16, 128 153 | 154 | rope = RotaryEmbedding(dim=C, xpos=True, scale=0.0) 155 | xq = torch.rand((B, T, H, C)) 156 | xk = torch.rand((B, T, H, C)) 157 | xq_out, xk_out = rope.rotate_qk(xq, xk, start=7) 158 | 159 | assert torch.allclose(xq, xq_out) 160 | assert torch.allclose(xk, xk_out) 161 | -------------------------------------------------------------------------------- /tests/modules/test_seanet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from itertools import product 8 | 9 | import pytest 10 | import torch 11 | 12 | from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock 13 | from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d 14 | 15 | 16 | class TestSEANetModel: 17 | 18 | def test_base(self): 19 | encoder = SEANetEncoder() 20 | decoder = SEANetDecoder() 21 | 22 | x = torch.randn(1, 1, 24000) 23 | z = encoder(x) 24 | assert list(z.shape) == [1, 128, 75], z.shape 25 | y = decoder(z) 26 | assert y.shape == x.shape, (x.shape, y.shape) 27 | 28 | def test_causal(self): 29 | encoder = SEANetEncoder(causal=True) 30 | decoder = SEANetDecoder(causal=True) 31 | x = torch.randn(1, 1, 24000) 32 | 33 | z = encoder(x) 34 | assert list(z.shape) == [1, 128, 75], z.shape 35 | y = decoder(z) 36 | assert y.shape == x.shape, (x.shape, y.shape) 37 | 38 | def test_conv_skip_connection(self): 39 | encoder = SEANetEncoder(true_skip=False) 40 | decoder = SEANetDecoder(true_skip=False) 41 | 42 | x = torch.randn(1, 1, 24000) 43 | z = encoder(x) 44 | assert list(z.shape) == [1, 128, 75], z.shape 45 | y = decoder(z) 46 | assert y.shape == x.shape, (x.shape, y.shape) 47 | 48 | def test_seanet_encoder_decoder_final_act(self): 49 | encoder = SEANetEncoder(true_skip=False) 50 | decoder = SEANetDecoder(true_skip=False, final_activation='Tanh') 51 | 52 | x = torch.randn(1, 1, 24000) 53 | z = encoder(x) 54 | assert list(z.shape) == [1, 128, 75], z.shape 55 | y = decoder(z) 56 | assert y.shape == x.shape, (x.shape, y.shape) 57 | 58 | def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str): 59 | n_blocks = 0 60 | for layer in encoder.model: 61 | if isinstance(layer, StreamableConv1d): 62 | n_blocks += 1 63 | assert layer.conv.norm_type == 'none' if n_blocks <= n_disable_blocks else norm 64 | elif isinstance(layer, SEANetResnetBlock): 65 | for resnet_layer in layer.block: 66 | if isinstance(resnet_layer, StreamableConv1d): 67 | # here we add + 1 to 
n_blocks as we increment n_blocks just after the block 68 | assert resnet_layer.conv.norm_type == 'none' if (n_blocks + 1) <= n_disable_blocks else norm 69 | 70 | def test_encoder_disable_norm(self): 71 | n_residuals = [0, 1, 3] 72 | disable_blocks = [0, 1, 2, 3, 4, 5, 6] 73 | norms = ['weight_norm', 'none'] 74 | for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms): 75 | encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm, 76 | disable_norm_outer_blocks=disable_blocks) 77 | self._check_encoder_blocks_norm(encoder, disable_blocks, norm) 78 | 79 | def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str): 80 | n_blocks = 0 81 | for layer in decoder.model: 82 | if isinstance(layer, StreamableConv1d): 83 | n_blocks += 1 84 | assert layer.conv.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm 85 | elif isinstance(layer, StreamableConvTranspose1d): 86 | n_blocks += 1 87 | assert layer.convtr.norm_type == 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm 88 | elif isinstance(layer, SEANetResnetBlock): 89 | for resnet_layer in layer.block: 90 | if isinstance(resnet_layer, StreamableConv1d): 91 | assert resnet_layer.conv.norm_type == 'none' \ 92 | if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm 93 | 94 | def test_decoder_disable_norm(self): 95 | n_residuals = [0, 1, 3] 96 | disable_blocks = [0, 1, 2, 3, 4, 5, 6] 97 | norms = ['weight_norm', 'none'] 98 | for n_res, disable_blocks, norm in product(n_residuals, disable_blocks, norms): 99 | decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm, 100 | disable_norm_outer_blocks=disable_blocks) 101 | self._check_decoder_blocks_norm(decoder, disable_blocks, norm) 102 | 103 | def test_disable_norm_raises_exception(self): 104 | # Invalid disable_norm_outer_blocks values raise exceptions 105 | with pytest.raises(AssertionError): 106 | SEANetEncoder(disable_norm_outer_blocks=-1) 107 | 108 | with pytest.raises(AssertionError): 109 | SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7) 110 | 111 | with pytest.raises(AssertionError): 112 | SEANetDecoder(disable_norm_outer_blocks=-1) 113 | 114 | with pytest.raises(AssertionError): 115 | SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7) 116 | -------------------------------------------------------------------------------- /tests/modules/test_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from itertools import product 8 | 9 | import pytest 10 | import torch 11 | 12 | from audiocraft.modules.transformer import StreamingMultiheadAttention, StreamingTransformer 13 | 14 | 15 | def test_transformer_causal_streaming(): 16 | torch.manual_seed(1234) 17 | 18 | for context, custom in product([None, 10], [False, True]): 19 | # Test that causality and receptive fields are properly handled. 20 | # looking at the gradients 21 | tr = StreamingTransformer( 22 | 16, 4, 1 if context else 2, 23 | causal=True, past_context=context, custom=custom, 24 | dropout=0.) 
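        # The block below checks causality and the limited receptive field through
        # autograd rather than by inspecting attention masks: backpropagating from
        # output step k must leave zero gradient on every input after step k (no
        # lookahead), a non-zero gradient somewhere in steps <= k, and, when
        # past_context is set, zero gradient again on inputs more than `context`
        # steps before k.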
25 | steps = 20 26 | for k in [0, 10, 15, 19]: 27 | x = torch.randn(4, steps, 16, requires_grad=True) 28 | y = tr(x) 29 | y[:, k].abs().sum().backward() 30 | if k + 1 < steps: 31 | assert torch.allclose(x.grad[:, k + 1:], torch.tensor(0.)), x.grad[:, k + 1:].norm() 32 | assert not torch.allclose(x.grad[:, :k + 1], torch.tensor(0.)), x.grad[:, :k + 1].norm() 33 | if context is not None and k > context: 34 | limit = k - context - 1 35 | assert torch.allclose(x.grad[:, :limit], 36 | torch.tensor(0.)), x.grad[:, :limit].norm() 37 | 38 | # Now check that streaming gives the same result at batch eval. 39 | x = torch.randn(4, steps, 16) 40 | y = tr(x) 41 | ys = [] 42 | with tr.streaming(): 43 | for k in range(steps): 44 | chunk = x[:, k:k + 1, :] 45 | ys.append(tr(chunk)) 46 | y_stream = torch.cat(ys, dim=1) 47 | delta = torch.norm(y_stream - y) / torch.norm(y) 48 | assert delta < 1e-6, delta 49 | 50 | 51 | def test_transformer_vs_pytorch(): 52 | torch.manual_seed(1234) 53 | # Check that in the non causal setting, we get the same result as 54 | # PyTorch Transformer encoder. 55 | for custom in [False, True]: 56 | tr = StreamingTransformer( 57 | 16, 4, 2, 58 | causal=False, custom=custom, dropout=0., positional_scale=0.) 59 | layer = torch.nn.TransformerEncoderLayer(16, 4, dropout=0., batch_first=True) 60 | tr_ref = torch.nn.TransformerEncoder(layer, 2) 61 | tr.load_state_dict(tr_ref.state_dict()) 62 | 63 | x = torch.randn(4, 20, 16) 64 | y = tr(x) 65 | y2 = tr_ref(x) 66 | delta = torch.norm(y2 - y) / torch.norm(y) 67 | assert delta < 1e-6, delta 68 | 69 | 70 | def test_streaming_api(): 71 | tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0.) 72 | tr.eval() 73 | steps = 12 74 | x = torch.randn(1, steps, 16) 75 | 76 | with torch.no_grad(): 77 | with tr.streaming(): 78 | _ = tr(x[:, :1]) 79 | state = {k: v.clone() for k, v in tr.get_streaming_state().items()} 80 | y = tr(x[:, 1:2]) 81 | tr.set_streaming_state(state) 82 | y2 = tr(x[:, 1:2]) 83 | assert torch.allclose(y, y2), (y - y2).norm() 84 | assert tr.flush() is None 85 | 86 | 87 | def test_memory_efficient(): 88 | torch.manual_seed(1234) 89 | tr = StreamingTransformer( 90 | 16, 4, 2, custom=True, dropout=0., layer_scale=0.1) 91 | tr_mem_efficient = StreamingTransformer( 92 | 16, 4, 2, dropout=0., memory_efficient=True, layer_scale=0.1) 93 | tr_mem_efficient.load_state_dict(tr.state_dict()) 94 | tr.eval() 95 | steps = 12 96 | x = torch.randn(3, steps, 16) 97 | 98 | with torch.no_grad(): 99 | y = tr(x) 100 | y2 = tr_mem_efficient(x) 101 | assert torch.allclose(y, y2), (y - y2).norm() 102 | 103 | 104 | def test_attention_as_float32(): 105 | torch.manual_seed(1234) 106 | cases = [ 107 | {'custom': True}, 108 | {'custom': False}, 109 | ] 110 | for case in cases: 111 | tr = StreamingTransformer(16, 4, 2, dropout=0., dtype=torch.bfloat16, **case) 112 | tr_float32 = StreamingTransformer( 113 | 16, 4, 2, dropout=0., attention_as_float32=True, dtype=torch.bfloat16, **case) 114 | if not case['custom']: 115 | # we are not using autocast here because it doesn't really 116 | # work as expected on CPU, so we have to manually cast the weights of the MHA. 
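            # Casting only the MHA weights is enough here: the test just needs the
            # float32 attention path to accumulate visibly different rounding than the
            # bfloat16 path, which the final `assert not torch.allclose(y, y2)` checks.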
117 | for layer in tr_float32.layers: 118 | layer.self_attn.mha.to(torch.float32) 119 | tr_float32.load_state_dict(tr.state_dict()) 120 | steps = 12 121 | x = torch.randn(3, steps, 16, dtype=torch.bfloat16) 122 | 123 | with torch.no_grad(): 124 | y = tr(x) 125 | y2 = tr_float32(x) 126 | assert not torch.allclose(y, y2), (y - y2).norm() 127 | 128 | 129 | @torch.no_grad() 130 | def test_streaming_memory_efficient(): 131 | torch.manual_seed(1234) 132 | tr = StreamingTransformer(16, 4, 2, causal=True, dropout=0., custom=True) 133 | tr_mem_efficient = StreamingTransformer( 134 | 16, 4, 2, dropout=0., memory_efficient=True, causal=True) 135 | tr.load_state_dict(tr_mem_efficient.state_dict()) 136 | tr.eval() 137 | tr_mem_efficient.eval() 138 | steps = 12 139 | x = torch.randn(3, steps, 16) 140 | 141 | ref = tr(x) 142 | 143 | with tr_mem_efficient.streaming(): 144 | outs = [] 145 | # frame_sizes = [2] + [1] * (steps - 2) 146 | frame_sizes = [1] * steps 147 | 148 | for frame_size in frame_sizes: 149 | frame = x[:, :frame_size] 150 | x = x[:, frame_size:] 151 | outs.append(tr_mem_efficient(frame)) 152 | 153 | out = torch.cat(outs, dim=1) 154 | delta = torch.norm(out - ref) / torch.norm(out) 155 | assert delta < 1e-6, delta 156 | 157 | 158 | def test_cross_attention(): 159 | torch.manual_seed(1234) 160 | for norm_first in [True, False]: 161 | m = StreamingTransformer( 162 | 16, 4, 2, cross_attention=False, norm_first=norm_first, dropout=0., custom=True) 163 | m_cross = StreamingTransformer( 164 | 16, 4, 2, cross_attention=True, norm_first=norm_first, dropout=0., custom=True) 165 | m_cross.load_state_dict(m.state_dict(), strict=False) 166 | x = torch.randn(2, 5, 16) 167 | cross_x = torch.randn(2, 3, 16) 168 | y_ref = m(x) 169 | y_cross_zero = m_cross(x, cross_attention_src=0 * cross_x) 170 | # With norm_first, the two should be exactly the same, 171 | # but with norm_first=False, we get 2 normalizations in a row 172 | # and the epsilon value leads to a tiny change. 173 | atol = 0. if norm_first else 1e-6 174 | print((y_ref - y_cross_zero).norm() / y_ref.norm()) 175 | assert torch.allclose(y_ref, y_cross_zero, atol=atol) 176 | 177 | # We now expect a difference even with a generous atol of 1e-2. 178 | y_cross = m_cross(x, cross_attention_src=cross_x) 179 | assert not torch.allclose(y_cross, y_cross_zero, atol=1e-2) 180 | 181 | with pytest.raises(AssertionError): 182 | _ = m_cross(x) 183 | _ = m(x, cross_attention_src=cross_x) 184 | 185 | 186 | def test_cross_attention_compat(): 187 | torch.manual_seed(1234) 188 | num_heads = 2 189 | dim = num_heads * 64 190 | with pytest.raises(AssertionError): 191 | StreamingMultiheadAttention(dim, num_heads, causal=True, cross_attention=True) 192 | 193 | cross_attn = StreamingMultiheadAttention( 194 | dim, num_heads, dropout=0, cross_attention=True, custom=True) 195 | ref_attn = torch.nn.MultiheadAttention(dim, num_heads, dropout=0, batch_first=True) 196 | 197 | # We can load the regular attention state dict 198 | # so we have compat when loading old checkpoints. 199 | cross_attn.load_state_dict(ref_attn.state_dict()) 200 | 201 | queries = torch.randn(3, 7, dim) 202 | keys = torch.randn(3, 9, dim) 203 | values = torch.randn(3, 9, dim) 204 | 205 | y = cross_attn(queries, keys, values)[0] 206 | y_ref = ref_attn(queries, keys, values)[0] 207 | assert torch.allclose(y, y_ref, atol=1e-7) 208 | 209 | # Now let's check that streaming is working properly.
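    # In cross-attention the keys and values are fixed, so feeding the queries one
    # step at a time inside the streaming context must reproduce the full-sequence
    # output exactly.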
210 | with cross_attn.streaming(): 211 | ys = [] 212 | for step in range(queries.shape[1]): 213 | ys.append(cross_attn(queries[:, step: step + 1], keys, values)[0]) 214 | y_streaming = torch.cat(ys, dim=1) 215 | assert torch.allclose(y_streaming, y, atol=1e-7) 216 | 217 | 218 | def test_repeat_kv(): 219 | torch.manual_seed(1234) 220 | num_heads = 8 221 | kv_repeat = 4 222 | dim = num_heads * 64 223 | with pytest.raises(AssertionError): 224 | mha = StreamingMultiheadAttention( 225 | dim, num_heads, causal=True, kv_repeat=kv_repeat, cross_attention=True) 226 | mha = StreamingMultiheadAttention( 227 | dim, num_heads, causal=True, kv_repeat=kv_repeat) 228 | mha = StreamingMultiheadAttention( 229 | dim, num_heads, causal=True, kv_repeat=kv_repeat, custom=True) 230 | x = torch.randn(4, 18, dim) 231 | y = mha(x, x, x)[0] 232 | assert x.shape == y.shape 233 | 234 | 235 | def test_qk_layer_norm(): 236 | torch.manual_seed(1234) 237 | tr = StreamingTransformer( 238 | 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, bias_attn=False) 239 | steps = 12 240 | x = torch.randn(3, steps, 16) 241 | y = tr(x) 242 | 243 | tr = StreamingTransformer( 244 | 16, 4, 2, custom=True, dropout=0., qk_layer_norm=True, cross_attention=True) 245 | z = torch.randn(3, 21, 16) 246 | y = tr(x, cross_attention_src=z) 247 | assert y.shape == x.shape 248 | -------------------------------------------------------------------------------- /tests/quantization/test_vq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | from audiocraft.quantization.vq import ResidualVectorQuantizer 10 | 11 | 12 | class TestResidualVectorQuantizer: 13 | 14 | def test_rvq(self): 15 | x = torch.randn(1, 16, 2048) 16 | vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8) 17 | res = vq(x, 1.) 18 | assert res.x.shape == torch.Size([1, 16, 2048]) 19 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | --------------------------------------------------------------------------------
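The ResidualVectorQuantizer test above only checks output shapes. As a closing illustration, the sketch below shows the residual vector quantization idea it exercises (each stage quantizes the residual left by the previous stage) in plain PyTorch. It is not audiocraft's implementation, which adds EMA codebook updates, commitment losses, and more; the helper name and the random codebooks are purely illustrative.

import torch

def residual_vq_sketch(x, codebooks):
    """Quantize x [B, D, T] with a list of [bins, D] codebooks, each coding the
    residual left by the previous stage. Returns (quantized, codes [B, n_q, T])."""
    B, D, T = x.shape
    residual = x
    quantized = torch.zeros_like(x)
    codes = []
    for codebook in codebooks:
        flat = residual.permute(0, 2, 1).reshape(-1, D)       # [B*T, D]
        idx = torch.cdist(flat, codebook).argmin(dim=-1)      # nearest codebook entry per step
        q = codebook[idx].reshape(B, T, D).permute(0, 2, 1)   # back to [B, D, T]
        quantized = quantized + q
        residual = residual - q
        codes.append(idx.reshape(B, T))
    return quantized, torch.stack(codes, dim=1)

# Shapes mirroring test_rvq above: 8 quantizers, dimension 16, 8 bins.
x = torch.randn(1, 16, 2048)
quantized, codes = residual_vq_sketch(x, [torch.randn(8, 16) for _ in range(8)])
assert quantized.shape == x.shape and codes.shape == (1, 8, 2048)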