├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MODEL_CARD.md ├── README.md ├── assets └── spiritlm_overview.png ├── checkpoints └── README.md ├── data └── examples │ ├── pred.jsonl │ └── ref.jsonl ├── env.yml ├── examples ├── audio │ └── 7143-88743-0029.flac ├── distributed_inference_recipe │ ├── multi_nodes.slurm │ └── run_dist.py ├── speech_generation │ └── spirit_model.ipynb └── speech_tokenizer │ └── spiritlm_speech_tokenizer.ipynb ├── requirements.dev.txt ├── requirements.txt ├── setup.py ├── spiritlm ├── __init__.py ├── eval │ ├── README.md │ ├── eval_stsp.py │ ├── load_data.py │ ├── stsp │ │ ├── few_shot_prompt.py │ │ ├── predict_stsp.py │ │ ├── sanity_check_download.py │ │ ├── sentiment_classifiers.py │ │ ├── stsp_constants.py │ │ └── utils.py │ └── utils.py ├── model │ ├── README.md │ ├── __init__.py │ ├── spiritlm_model.py │ └── utils.py └── speech_tokenizer │ ├── README.md │ ├── __init__.py │ ├── f0 │ ├── __init__.py │ ├── f0_extractor.py │ ├── f0_tokenizer.py │ └── vqvae.py │ ├── hifigan │ ├── __init__.py │ └── hifigan_vocoder.py │ ├── hubert │ ├── __init__.py │ ├── hubert_model │ │ ├── __init__.py │ │ ├── hubert_model.py │ │ └── wav2vec2_model.py │ ├── hubert_tokenizer.py │ └── quantizer_model.py │ ├── spiritlm_tokenizer.py │ └── style_encoder │ ├── __init__.py │ └── w2v2_encoder.py └── tests ├── __init__.py ├── test_spirit_model.py └── test_tokenizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | # local checkpoints folder 165 | checkpoints/ 166 | 167 | # local data folder 168 | data/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to spiritlm 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to spiritlm, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | FAIR Noncommercial Research License 2 | Last Updated: October 18, 2024 3 | 4 | “Acceptable Use Policy” means the FAIR Acceptable Use Policy, applicable to Research Materials, that is incorporated into this Agreement. 5 | 6 | “Agreement” means the terms and conditions for use, reproduction, distribution and modification of the Research Materials set forth herein. 7 | 8 | “Documentation” means the specifications, manuals and documentation accompanying 9 | Research Materials distributed by Meta. 10 | 11 | “Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. 12 | 13 | “Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland). 14 | 15 | “Noncommercial Research Uses” means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others. 
16 | 17 | “Research Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta and made available under this Agreement. 18 | 19 | By clicking “I Accept” below or by using or distributing any portion or element of the Research Materials, you agree to be bound by this Agreement. 20 | 21 | 22 | 1. License Rights and Redistribution. 23 | 24 | a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the Research Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Research Materials. 25 | 26 | b. Redistribution and Use. 27 | i. You will not use the Research Materials or any outputs or results of the Research Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses; 28 | 29 | ii. Distribution of Research Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Research Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party. 30 | 31 | iii. If you submit for publication the results of research you perform on, using, or otherwise in connection with Research Materials, you must acknowledge the use of Research Materials in your publication. 32 | 33 | iv. Your use of the Research Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the FAIR Acceptable Use Policy, which is hereby incorporated by reference into this Agreement. 34 | 35 | 2. User Support. Your Noncommercial Research Use of the Research Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the Research Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind. 36 | 37 | 3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE RESEARCH MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE RESEARCH MATERIALS AND ANY OUTPUT AND RESULTS. 38 | 39 | 4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 40 | 41 | 5. Intellectual Property. 42 | 43 | a. 
Subject to Meta’s ownership of Research Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Research Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications. 44 | 45 | b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Research Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Research Materials. 46 | 47 | 6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Research Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Research Materials. Sections 5, 6 and 9 shall survive the termination of this Agreement. 48 | 49 | 7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement. 50 | 51 | 8. Modifications and Amendments. Meta may modify this Agreement from time to time by posting a revised version at [https://github.com/facebookresearch/spiritlm/blob/main/LICENSE]; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Research Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta. 52 | 53 | 54 | FAIR Acceptable Use Policy 55 | 56 | The Fundamental AI Research (FAIR) team at Meta seeks to further understanding of new and existing research domains with the mission of advancing the state-of-the-art in artificial intelligence through open research for the benefit of all. 57 | 58 | As part of this mission, Meta makes certain research materials available for noncommercial research use. Meta is committed to promoting the safe and responsible use of such research materials. 59 | 60 | 61 | Prohibited Uses 62 | 63 | You agree you will not use, or allow others to use, Research Materials to: 64 | 65 | 1. Violate the law or others’ rights, including to: 66 | a. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as: 67 | i. Violence or terrorism 68 | ii. 
Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material 69 | iii. Human trafficking, exploitation, and sexual violence 70 | iv. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials. 71 | v. Sexual solicitation 72 | iv. Any other criminal activity 73 | 74 | b. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals 75 | 76 | c. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services 77 | 78 | d. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices 79 | 80 | e. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws 81 | 82 | f. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any technology using FAIR research materials 83 | 84 | g. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system 85 | 86 | 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of research artifacts related to the following: 87 | 88 | a. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State 89 | 90 | b. Guns and illegal weapons (including weapon development) 91 | 92 | c. Illegal drugs and regulated/controlled substances 93 | 94 | d. Operation of critical infrastructure, transportation technologies, or heavy machinery 95 | 96 | e. Self-harm or harm to others, including suicide, cutting, and eating disorders 97 | 98 | f. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual 99 | 100 | 3. Intentionally deceive or mislead others, including use of FAIR Research Materials related to the following: 101 | 102 | a. Generating, promoting, or furthering fraud or the creation or promotion of disinformation 103 | 104 | b. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content 105 | 106 | c. Generating, promoting, or further distributing spam 107 | 108 | d. Impersonating another individual without consent, authorization, or legal right 109 | 110 | e. Representing that outputs of FAIR research materials or outputs from technology using FAIR research materials o are human-generated 111 | 112 | f. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement 113 | 114 | 4. 
Fail to appropriately disclose to end users any known dangers of your Research Materials. 115 | 116 | Please report any violation of this Policy or other problems that could lead to a violation of this Policy by submitting a report here [https://docs.google.com/forms/d/e/1FAIpQLSeb11cryAopJ7LNrC4nxEUXrHY26hfkXQMf_uH-oFgA3WlYZQ/viewform]. 117 | -------------------------------------------------------------------------------- /MODEL_CARD.md: -------------------------------------------------------------------------------- 1 | # Meta Spirit LM Model Card 2 | 3 | ## Model Details 4 | 5 | *Note: Use of this model is governed by the FAIR Noncommercial Research License.* 6 | 7 | Spirit LM is a multimodal language model that freely mixes text and speech. The model can be prompted with either text or speech and is capable of generating outputs in either modality, while preserving the expressivity of the input prompt. The model is also able to learn new tasks across modalities such as automatic speech recognition, text-to-speech, and speech classification in a few-shot manner. 8 | 9 | ## Model Developers 10 | Meta 11 | 12 | ## Variations 13 | Spirit LM comes in two versions: Spirit LM Base that uses speech phonetic tokens and Spirit LM Expressive that models expressivity using pitch and style tokens in addition to the phonetic tokens. 14 | 15 | ## Input 16 | Models input text or speech or a mixed sequence of the two. 17 | 18 | ## Output 19 | Models generate text or speech or a mixed sequence of the two. 20 | 21 | ## Model Architecture 22 | ### Speech Tokenizer 23 | Spirit LM uses 3 types of speech tokenizers: Phonetic Tokenizer (HuBERT), Pitch Tokenizer (VQ-VAE) and Style Tokenizer (Speechprop or Wav2vec2). We use Hifi-GAN to convert the speech tokens back to audio. 24 | 25 | It is worth noting that in the associated paper, for Spirit LM Expressive, we used Speechprop to extract style tokens, while we use a Wav2vec2 model to extract style tokens in this release. 26 | 27 | | | Model | Parameters | Input | Output | 28 | |------------------------|--------------------------|------------|---------------------|--------------------| 29 | | Phonetic Tokenizer | HuBERT+LinearQuantizer | 96M | Waveform | Phonetic Tokens | 30 | | Pitch Tokenizer | VQ-VAE | 0.2M | Extracted F0 | Pitch Tokens | 31 | | Style Tokenizer | Wav2vec2+LinearProjection| 95M | Waveform | Style Tokens | 32 | | Base Speech Decoder | Hifi-GAN | 14M | Phonetic Tokens | Waveform | 33 | | Expressive Speech Decoder | Hifi-GAN | 15M | Phonetic, Pitch, Style Tokens | Waveform 34 | 35 | ### Language Model 36 | Spirit LM is initialized from the Llama-2 7B model. 37 | 38 | | | Architecture | Parameters | Input/Output Tokens | Vocab Size | 39 | |----------------------|----------------|------------|----------------------------------------------------------|------------| 40 | | Spirit LM Base | Llama-2 7B | 7B | Text Tokens, Phonetic Tokens | 32512 | 41 | | Spirit LM Expressive | Llama-2 7B | 7B | Text Tokens, Phonetic Tokens, Pitch Tokens, Style Tokens | 32768 | 42 | 43 | ### Release Date 44 | The models were trained between October and December 2023. The research paper was released on February 8th 2024. We released the model on October 18th 2024. 45 | 46 | ### Status 47 | This is a static model trained on an offline dataset. 48 | 49 | ### License 50 | We release the model under the FAIR Noncommercial Research License found in the [LICENSE](LICENSE) file in the root directory of this repo. 
51 | 52 | ### Research Paper 53 | More information can be found in the paper ["SpiRit-LM: Interleaved Spoken and Written Language Model"](https://arxiv.org/pdf/2402.05755.pdf). 54 | 55 | ## Hardware and Software 56 | ### Training Factors 57 | We used custom training libraries. The training of the released models has been performed on Meta’s Research Clusters. 58 | 59 | The training of each model (Spirit LM Base and Spirit LM Expressive) takes 21K GPU hours of computation on hardware of type A100-80GB (TDP of 350-400W), not including the training of Llama-2. 60 | 61 | ## Training Data 62 | We trained the models on a combination of text-only datasets, speech-only datasets and aligned speech-text datasets. All the speech datasets are publicly available. Here are the statistics of the datasets we used: 63 | 64 | | | Hours | Speech Tokens | Text Tokens | 65 | |--------------|-------|---------------|-------------| 66 | | Speech-only | 458K | 28.2B | - | 67 | | Speech+Text | 111K | 7.0B | 1.4B | 68 | | Text-only | - | - | 307B | 69 | 70 | ## Evaluation Results 71 | See evaluations for our models and detailed ablations in Section 4 and 5, and safety evaluations in Section 6 of the [research paper](https://arxiv.org/pdf/2402.05755.pdf). 72 | 73 | ## Intended Use 74 | ### Intended Use Cases 75 | Spirit LM is intended for noncommercial research use in English. 76 | 77 | ### Out-of-Scope Uses 78 | Use in any manner that violates applicable laws or regulations (including trade compliance laws). Use in languages other than English. Use in any other way that is prohibited by the FAIR Noncommercial Research License and Acceptable Use Policy. 79 | 80 | ## Ethical Considerations and Limitations 81 | This model is built on Llama 2 which carries risks with use. Testing conducted to date has been in English, and has not covered, nor could it cover all scenarios. For these reasons, as with all LLMs, Llama 2’s potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate, biased or other objectionable responses to user prompts. The model’s speech capabilities are designed to analyze speaker agnostic qualities of any input speech and output speech in one of four pre-set voices. The model is meant for use for noncommercial research purposes only and should not be deployed in any consumer-facing applications. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Meta Spirit LM: Interleaved Spoken and Written Language Model 2 | 3 | This repository contains the model weights, inference code and evaluation scripts for the Spirit LM [paper](https://arxiv.org/pdf/2402.05755.pdf). You can find more generation samples on our [demo page](https://speechbot.github.io/spiritlm/). 4 | 5 | ## Spirit LM Model Overview 6 | 7 | 8 | ## Installation Setup 9 | ### Conda 10 | ``` 11 | conda env create -f env.yml 12 | pip install -e '.[eval]' 13 | 14 | ``` 15 | ### Pip 16 | ``` 17 | pip install -e '.[eval]' 18 | ``` 19 | 20 | ### Dev 21 | (Optionally, use only if you want to run the tests.) 
22 | ``` 23 | pip install -e '.[dev]' 24 | ``` 25 | 26 | ## Checkpoints Setup 27 | See [checkpoints/README.md](checkpoints/README.md) 28 | 29 | ## Quick Start 30 | ### Speech Tokenization 31 | See [spiritlm/speech_tokenizer/README.md](spiritlm/speech_tokenizer/README.md) 32 | ### Spirit LM Generation 33 | See [spiritlm/model/README.md](spiritlm/model/README.md) 34 | ### Speech-Text Sentiment Preservation benchmark (STSP) 35 | See [spiritlm/eval/README.md](spiritlm/eval/README.md) 36 | 37 | ## Model Card 38 | More details of the model can be found in [MODEL_CARD.md](MODEL_CARD.md). 39 | 40 | ## License 41 | The present code is provided under the **FAIR Noncommercial Research License** found in [LICENSE](LICENSE). 42 | 43 | ## Citation 44 | ``` 45 | @misc{nguyen2024spiritlminterleavedspokenwritten, 46 | title={SpiRit-LM: Interleaved Spoken and Written Language Model}, 47 | author={Tu Anh Nguyen and Benjamin Muller and Bokai Yu and Marta R. Costa-jussa and Maha Elbayad and Sravya Popuri and Paul-Ambroise Duquenne and Robin Algayres and Ruslan Mavlyutov and Itai Gat and Gabriel Synnaeve and Juan Pino and Benoit Sagot and Emmanuel Dupoux}, 48 | year={2024}, 49 | eprint={2402.05755}, 50 | archivePrefix={arXiv}, 51 | primaryClass={cs.CL}, 52 | url={https://arxiv.org/abs/2402.05755}, 53 | } 54 | ``` 55 | 56 | -------------------------------------------------------------------------------- /assets/spiritlm_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/spiritlm/52fb2f4d585811450f192732a1c81760208b9fd0/assets/spiritlm_overview.png -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | # Spirit LM Checkpoints 2 | 3 | ## Download Checkpoints 4 | To access and download Spirit LM Checkpoints, please request the model artifacts in this link: 5 | 6 | [https://ai.meta.com/resources/models-and-libraries/spirit-lm-downloads/](https://ai.meta.com/resources/models-and-libraries/spirit-lm-downloads/) 7 | 8 | Upon approval, you will then receive an email with download links to each model artifact. 9 | 10 | Please note that Spirit LM is made available under the **FAIR Noncommercial Research License** 11 | found in the [LICENSE](../LICENSE) file in the root directory of this source tree and Acceptable Use Policy. 
12 | 13 | ## Structure 14 | The checkpoints directory should look like this: 15 | ``` 16 | checkpoints/ 17 | ├── README.md 18 | ├── speech_tokenizer 19 | │   ├── hifigan_spiritlm_base 20 | │   │   ├── config.json 21 | │   │   ├── generator.pt 22 | │   │   ├── speakers.txt 23 | │   │   └── styles.txt 24 | │   ├── hifigan_spiritlm_expressive_w2v2 25 | │   │   ├── config.json 26 | │   │   ├── generator.pt 27 | │   │   └── speakers.txt 28 | │   ├── hubert_25hz 29 | │   │   ├── L11_quantizer_500.pt 30 | │   │   └── mhubert_base_25hz.pt 31 | │   ├── style_encoder_w2v2 32 | │   │   ├── config.json 33 | │   │   └── pytorch_model.bin 34 | │   └── vqvae_f0_quantizer 35 | │   ├── config.yaml 36 | │   └── model.pt 37 | └── spiritlm_model 38 | ├── spirit-lm-base-7b 39 | │   ├── config.json 40 | │   ├── generation_config.json 41 | │   ├── pytorch_model.bin 42 | │   ├── special_tokens_map.json 43 | │   ├── tokenizer_config.json 44 | │   └── tokenizer.model 45 | └── spirit-lm-expressive-7b 46 | ├── config.json 47 | ├── generation_config.json 48 | ├── pytorch_model.bin 49 | ├── special_tokens_map.json 50 | ├── tokenizer_config.json 51 | └── tokenizer.model 52 | ``` 53 | You can export `SPIRITLM_CHECKPOINTS_DIR` to point to a differnt directory where you downloaded checkpoints. -------------------------------------------------------------------------------- /data/examples/pred.jsonl: -------------------------------------------------------------------------------- 1 | {"pred": "angry", "id": 4792320029370491913} 2 | {"pred": "neutral", "id": -5682350483296949563} 3 | {"pred": "amused", "id": -8754508989367964614} 4 | {"pred": "angry", "id": -9018665079841831624} 5 | {"pred": "neutral", "id": 1159246029716120600} 6 | -------------------------------------------------------------------------------- /data/examples/ref.jsonl: -------------------------------------------------------------------------------- 1 | {"emotion": "angry", "sentiment": "negative", "wav_path": "emov/sam/Angry/anger_281-308_0286.wav", "split": "test", "speaker": "sam", "id": 4792320029370491913} 2 | {"emotion": "neutral", "sentiment": "neutral", "wav_path": "emov/sam/Neutral/neutral_281-308_0286.wav", "split": "test", "speaker": "sam", "id": -5682350483296949563} 3 | {"emotion": "amused", "sentiment": "positive", "wav_path": "emov/sam/Amused/amused_281-308_0286.wav", "split": "test", "speaker": "sam", "id": -8754508989367964614} 4 | {"emotion": "angry", "sentiment": "negative", "wav_path": "emov/jenie/Angry/anger_57-84_0084.wav", "split": "test", "speaker": "jenie", "id": -9018665079841831624} 5 | {"emotion": "neutral", "sentiment": "neutral", "wav_path": "emov/jenie/Neutral/neutral_57-84_0084.wav", "split": "test", "speaker": "jenie", "id": 1159246029716120600} 6 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: spiritlm_test 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python=3.9 7 | - pip 8 | - pytorch-cuda=11.8 9 | - pytorch 10 | - torchaudio 11 | - pip: 12 | - omegaconf==2.2.0 13 | - librosa~=0.10 14 | - local-attention~=1.9 15 | - encodec~=0.1 16 | - transformers 17 | - fairscale~=0.4 18 | - sentencepiece 19 | - torchfcpe~=0.0.4 -------------------------------------------------------------------------------- /examples/audio/7143-88743-0029.flac: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/spiritlm/52fb2f4d585811450f192732a1c81760208b9fd0/examples/audio/7143-88743-0029.flac -------------------------------------------------------------------------------- /examples/distributed_inference_recipe/multi_nodes.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the FAIR Noncommercial Research License 7 | # found in the LICENSE file in the root directory of this source tree. 8 | 9 | #SBATCH --job-name=spiritlm 10 | #SBATCH --ntasks-per-node=1 11 | #SBATCH --gpus-per-node=8 12 | #SBATCH --nodes=2 13 | #SBATCH --cpus-per-task=12 14 | #SBATCH --output=./logs/%j.stdout 15 | #SBATCH --error=./logs/%j.stderr 16 | #SBATCH --time=01:00:00 17 | 18 | set -e 19 | 20 | srun bash -c 'torchrun --nnodes $SLURM_JOB_NUM_NODES --nproc-per-node $SLURM_GPUS_ON_NODE \ 21 | --node-rank $SLURM_PROCID \ 22 | --master-addr $(scontrol show hostnames $SLURM_NODELIST | head -n1) \ 23 | --master-port 12345 \ 24 | examples/distributed_inference_recipe/run_dist.py' 25 | -------------------------------------------------------------------------------- /examples/distributed_inference_recipe/run_dist.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Usage example: 9 | 10 | cd {SPIRITLM ROOT FOLDER} 11 | export PYTHONPATH=. 12 | 13 | Single node, multi-gpus: 14 | (Assume that your machine has 8 GPUs) 15 | torchrun --nnodes 1 --nproc-per-node 8 examples/distributed_inference_recipe/run_dist.py 16 | 17 | Multi-nodes, multi-gpus: 18 | (2 nodes, 8 GPUs for each node, via sbatch) 19 | mkdir -p logs 20 | sbatch examples/distributed_inference_recipe/multi_nodes.slurm 21 | """ 22 | 23 | import os 24 | 25 | import torch 26 | import torch.distributed as dist 27 | import torchaudio 28 | from spiritlm.model.spiritlm_model import ( 29 | ContentType, 30 | GenerationInput, 31 | OutputModality, 32 | Spiritlm, 33 | ) 34 | from torch.utils.data import TensorDataset 35 | from torch.utils.data.distributed import DistributedSampler 36 | from transformers import GenerationConfig, set_seed 37 | 38 | 39 | def run(seed: int = 0): 40 | world_size = int(os.environ["WORLD_SIZE"]) 41 | world_rank = int(os.environ["RANK"]) 42 | print( 43 | f"Running distributed inference with world_size: {world_size}, world_rank: {world_rank}" 44 | ) 45 | dist.init_process_group("nccl", rank=world_rank, world_size=world_size) 46 | 47 | set_seed(seed) 48 | 49 | wav = torchaudio.load("examples/audio/7143-88743-0029.flac")[0].squeeze() 50 | 51 | # toy dataset: the same waveform repeated 32 times 52 | dataset = TensorDataset(wav.repeat(32, 1)) 53 | 54 | sampler = DistributedSampler(dataset=dataset) 55 | loader = torch.utils.data.DataLoader( 56 | dataset=dataset, 57 | batch_size=1, # don't change 58 | sampler=sampler, 59 | num_workers=4, 60 | ) 61 | 62 | spirit_lm = Spiritlm("spirit-lm-expressive-7b") 63 | 64 | for _, data in enumerate(loader): 65 | outs = spirit_lm.generate( 66 | output_modality=OutputModality.ARBITRARY, 67 | interleaved_inputs=[ 68 | GenerationInput( 69 | content=data[0], # 0 because of batch size 1 70 | content_type=ContentType.SPEECH, 71 | ) 72 | ],
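            # Standard Hugging Face sampling settings used by this recipe; adjust temperature/top_p/max_new_tokens as needed.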
73 | generation_config=GenerationConfig( 74 | temperature=0.9, 75 | top_p=0.95, 76 | max_new_tokens=200, 77 | do_sample=True, 78 | ), 79 | ) 80 | print(f"outs: {outs}") 81 | 82 | 83 | def setup_env(): 84 | os.environ["OMP_NUM_THREADS"] = "1" 85 | 86 | 87 | if __name__ == "__main__": 88 | setup_env() 89 | run() 90 | -------------------------------------------------------------------------------- /requirements.dev.txt: -------------------------------------------------------------------------------- 1 | pytest -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | omegaconf>=2.2.0 2 | librosa>=0.10 3 | local-attention>=1.9 4 | encodec>=0.1 5 | transformers 6 | fairscale>=0.4 7 | sentencepiece 8 | pyarrow>=14.0 9 | torchfcpe>=0.0.4 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | from pathlib import Path 9 | 10 | from setuptools import find_packages, setup 11 | 12 | NAME = "spiritlm" 13 | VERSION = "0.1.0" 14 | DESCRIPTION = "Interleaved Spoken and Written Language Model" 15 | URL = "https://github.com/facebookresearch/spiritlm" 16 | KEYWORDS = [ 17 | "Language Model, Speech Language Model, Multimodal, Crossmodal, Expressivity Modeling" 18 | ] 19 | LICENSE = "FAIR Noncommercial Research License" 20 | 21 | 22 | def _get_long_description(): 23 | with (Path(__file__).parent / "README.md").open(encoding="utf-8") as file: 24 | long_description = file.read() 25 | return long_description 26 | 27 | 28 | def _read_reqs(relpath): 29 | fullpath = os.path.join(os.path.dirname(__file__), relpath) 30 | with open(fullpath) as f: 31 | return [ 32 | s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#")) 33 | ] 34 | 35 | 36 | setup( 37 | name=NAME, 38 | version=VERSION, 39 | description=DESCRIPTION, 40 | long_description=_get_long_description(), 41 | long_description_content_type="text/plain", 42 | url=URL, 43 | license=LICENSE, 44 | author="Meta", 45 | keywords=KEYWORDS, 46 | classifiers=[ 47 | "Intended Audience :: Science/Research", 48 | "License :: FAIR Noncommercial Research License", 49 | "Topic :: Multimedia :: Sound/Audio", 50 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 51 | ], 52 | packages=find_packages(), 53 | zip_safe=False, 54 | python_requires=">=3.9", 55 | install_requires=_read_reqs("requirements.txt"), 56 | extras_require={ 57 | "dev": ["pytest"], 58 | "eval": ["pandas"], 59 | }, 60 | ) 61 | -------------------------------------------------------------------------------- /spiritlm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /spiritlm/eval/README.md: -------------------------------------------------------------------------------- 1 | # STSP Evaluation 2 | The Speech-Text Sentiment Preservation (STSP) benchmark is made of a collection of speech and text prompts with positive, negative or neutral sentiment. 3 | Given a spoken or written prompt, the task consists in generating a text or speech sequence of tokens that preserves the sentiment of the prompt. 4 | 5 | The sentiment of the prompt is evaluated automatically with a sentiment/emotion classifier in speech or text depending on the output modality. 6 | Based on these, we derive an STSP accuracy score. 7 | 8 | ## Data Download 9 | Download the data as well as the speech/text classifier checkpoints via this [link](https://dl.fbaipublicfiles.com/textless_nlp/spiritlm/stsp.tar.gz) 10 | then extract the data into the folder `{spiritlm ROOT FOLDER}/data/stsp_data` 11 | ``` 12 | cd {spiritlm ROOT FOLDER} 13 | mkdir data/stsp_data 14 | tar -xvzf stsp.tar.gz -C data/stsp_data --strip-components=1 15 | ``` 16 | Run the following script to check that the dataset is all correctly present: 17 | ``` 18 | python spiritlm/eval/stsp/sanity_check_download.py 19 | ``` 20 | ## Data structure 21 | The dataset contains 3 folders: 22 | - `data`: raw audio files 23 | - `manifest`: data splits 24 | - `model`: speech/text classifier checkpoints 25 | ### Data 26 | The raw audio files for: 27 | - `emov`: EMOV 28 | - `expresso/conversational`: EXPRESSO-ASR 29 | - `expresso/read`: EXPRESSO-READ 30 | 31 | ### Manifest 32 | The train/validation/test splits; concretely we have: 33 | 34 | #### EMOV 35 | - 1053 records for emov train split at `manifest/emov/emov.train.jsonl` 36 | - 351 records for emov dev split at `manifest/emov/emov.dev.jsonl` 37 | - 351 records for emov test split at `manifest/emov/emov.test.jsonl` 38 | 39 | #### EXPRESSO-ASR 40 | - 1373 records for EXPRESSO-ASR train split at `manifest/expresso/expresso_asr.train` 41 | - 479 records for EXPRESSO-ASR dev split at `manifest/expresso/expresso_asr.dev.jsonl` 42 | - 462 records for EXPRESSO-ASR test split at `manifest/expresso/expresso_asr.test.jsonl` 43 | 44 | #### EXPRESSO-READ 45 | - 1024 records for EXPRESSO-READ train split at `manifest/expresso/expresso_read.train` 46 | - 60 records for EXPRESSO-READ dev split at `manifest/expresso/expresso_read.dev.jsonl` 47 | - 54 records for EXPRESSO-READ test split at `manifest/expresso/expresso_read.test.jsonl` 48 | 49 | #### Few-shot Samples 50 | The subset from the EXPRESSO-ASR training set, used for the few-shot experiments: 51 | - `s2s.jsonl`: S -> S direction 52 | - `s2t.jsonl`: S -> T direction 53 | - `t2t.jsonl`: T -> T direction 54 | - `t2s.jsonl`: T -> S direction 55 | 56 | ### Auto-Eval Speech And Text Classifiers 57 | 58 | The sentiment of the generated sequence is estimated in an auto-eval fashion with Speech and Text classifiers. We point to the [paper](https://arxiv.org/abs/2402.05755) for details on these classifiers. 59 | 60 | 61 | ## Prediction & Evaluation of Spirit LM on STSP (Speech/Text) 62 | 63 | ```export PYTHONPATH=.``` 64 | 65 | Set `spiritlm` to the model you want to evaluate: e.g.
```spiritlm=spirit-lm-base-7b``` or ```spiritlm=spirit-lm-expressive-7b``` 66 | 67 | #### Speech to Text 68 | torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_s_t.jsonl --input_output speech_text 69 | #### Text to Text 70 | torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_t_t.jsonl --input_output text_text 71 | #### Text to Speech 72 | torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_t_s.jsonl --input_output text_speech 73 | #### Speech to Speech 74 | torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --model $spiritlm --eval_manifest_path data/stsp_data/manifest/emov/emov.test.jsonl --eval --write_pred ./pred_s_s.jsonl --input_output speech_speech 75 | 76 | 77 | ### Post-hoc Evaluation 78 | 79 | To evaluate the performance of a model different from SpiritLM, you can use the following evaluation script that takes as input a prediction.jsonl file. 80 | 81 | ``` 82 | python spiritlm/eval/eval_stsp.py --ref_file $REF_FILE --pred_file $pred_file 83 | ``` 84 | 85 | e.g. 86 | 87 | ``` 88 | python spiritlm/eval/eval_stsp.py \ 89 | --ref_file ./data/examples/demo.jsonl \ 90 | --pred_file ./data/examples/pred.jsonl 91 | > Accuracy: 100.00% for predictions ./data/examples/pred.jsonl 92 | ``` 93 | -------------------------------------------------------------------------------- /spiritlm/eval/eval_stsp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 
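# Computes STSP accuracy by matching each predicted sentiment (or emotion) label against the gold reference records.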
6 | 7 | import argparse 8 | import json 9 | from typing import Dict, Union 10 | 11 | import pandas as pd 12 | from spiritlm.eval.stsp.utils import EMOTION_2_SENTIMENT 13 | 14 | 15 | def load_pred(predictions): 16 | ret = {} 17 | with open(predictions) as f: 18 | for line in f: 19 | pred = json.loads(line) 20 | ret[str(pred["id"])] = pred["pred"] 21 | 22 | assert sum(1 for _ in open(predictions)) == len(ret) 23 | 24 | return ret 25 | 26 | 27 | def eval( 28 | gold_records: str, predictions: Union[str, Dict], info_data="", label="sentiment" 29 | ): 30 | n_gold_records = sum(1 for _ in open(gold_records)) 31 | n_lines_pred = ( 32 | sum(1 for _ in open(predictions)) 33 | if isinstance(predictions, str) 34 | else len(predictions) 35 | ) 36 | assert ( 37 | n_gold_records == n_lines_pred 38 | ), f"Mismatch between prediction ({n_lines_pred} samples in {predictions}) and reference ({n_gold_records} in {gold_records})" 39 | 40 | pred_dic = load_pred(predictions) if isinstance(predictions, str) else predictions 41 | scores = [] 42 | 43 | with open(gold_records) as gold: 44 | for line in gold: 45 | ref = json.loads(line) 46 | try: 47 | if label in ref: 48 | scores.append(pred_dic[str(ref["id"])] == ref[label]) 49 | else: 50 | assert label == "sentiment" and "emotion" in ref, ref 51 | sentiment = EMOTION_2_SENTIMENT[ref["emotion"]] 52 | scores.append(pred_dic[str(ref["id"])] == sentiment) 53 | except Exception as e: 54 | print( 55 | f"ERROR in matching the predicted labels with the gold ones: {e}: ref['id'] do not match any key in {pred_dic}', {ref['id']}: " 56 | ) 57 | # TODO: add other metrics if needed : F1 per class, etc. 58 | report = pd.DataFrame({"Correct": scores}) 59 | if isinstance(predictions, str): 60 | info_data += f"from {predictions}" 61 | print( 62 | f"Accuracy: {(report['Correct']==1).sum()/len(report)*100:0.2f}% for predictions {info_data}" 63 | ) 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = argparse.ArgumentParser() 68 | 69 | parser.add_argument( 70 | "--ref_file", 71 | type=str, 72 | help="Path to reference record", 73 | ) 74 | parser.add_argument( 75 | "--pred_file", 76 | type=str, 77 | help="Path to prediction: should be jsonl with each entry {'pred': , 'id': }", 78 | ) 79 | parser.add_argument( 80 | "--label", 81 | type=str, 82 | default="sentiment", 83 | help="sentiment or emotion", 84 | ) 85 | args = parser.parse_args() 86 | 87 | eval(args.ref_file, args.pred_file, label=args.label) 88 | -------------------------------------------------------------------------------- /spiritlm/eval/load_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | from pathlib import Path 9 | 10 | import torch 11 | import torchaudio 12 | 13 | 14 | class SpeechData(torch.utils.data.Dataset): 15 | def __init__(self, manifest_dir, root_dir=None): 16 | if root_dir is None: 17 | root_dir = "." 
18 | self.root_dir = Path(root_dir) 19 | self.manifest_dir = self.root_dir / manifest_dir 20 | self.wav_field = "wav_path" 21 | self.manifest = [json.loads(line.strip()) for line in open(manifest_dir)] 22 | 23 | def __getitem__(self, idx): 24 | wav_path = self.root_dir / self.manifest[idx][self.wav_field] 25 | return { 26 | "wav": torchaudio.load(wav_path)[0].squeeze(0), 27 | "id": str(self.manifest[idx]["id"]), 28 | } 29 | 30 | def __len__(self): 31 | return len(self.manifest) 32 | 33 | 34 | class TextData(torch.utils.data.Dataset): 35 | def __init__(self, manifest_dir, root_dir=None): 36 | if root_dir is None: 37 | root_dir = "." 38 | self.root_dir = Path(root_dir) 39 | self.manifest_dir = self.root_dir / manifest_dir 40 | self.text_field = "asr" 41 | self.manifest = [json.loads(line.strip()) for line in open(manifest_dir)] 42 | 43 | def __getitem__(self, idx): 44 | return { 45 | "text": self.manifest[idx][self.text_field], 46 | "id": str(self.manifest[idx]["id"]), 47 | } 48 | 49 | def __len__(self): 50 | return len(self.manifest) 51 | -------------------------------------------------------------------------------- /spiritlm/eval/stsp/few_shot_prompt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Union 9 | 10 | import pandas as pd 11 | import torch 12 | import torchaudio 13 | from spiritlm.eval.stsp.stsp_constants import STSP_DATA_ROOT, STSP_MANIFEST_ROOT 14 | from spiritlm.model.spiritlm_model import Spiritlm 15 | 16 | FEW_SHOT_MANIFEST_DIR = STSP_MANIFEST_ROOT / "few_shot" 17 | FEW_SHOT_TEMPLATE = "{prompt}{generation}" 18 | 19 | 20 | def wav_prompt(spiritlm_model: Spiritlm, wav: Union[str, torch.Tensor]) -> str: 21 | return spiritlm_model.SPEECH_PROMPT_PREFIX + spiritlm_model.speech_tokenizer(wav) 22 | 23 | 24 | def text_prompt(spiritlm_model: Spiritlm, text: str) -> str: 25 | return spiritlm_model.TEXT_PROMPT_PREFIX + text 26 | 27 | 28 | def _load_half_wav(wav_path: str, load_first_half: bool) -> torch.Tensor: 29 | wav_path = STSP_DATA_ROOT / wav_path 30 | wav = torchaudio.load(wav_path)[0].squeeze(0) 31 | size = wav.size()[0] 32 | half_size = size // 2 33 | if load_first_half: 34 | wav = wav[:half_size] 35 | else: 36 | wav = wav[half_size:] 37 | return wav 38 | 39 | 40 | def build_few_shot_prompt( 41 | spiritlm_model: Spiritlm, 42 | input_output: str, 43 | n_shots: int = 3, 44 | ) -> str: 45 | """ 46 | Build the few-shot prompt by simply concatenating a set of examples. 
47 | 48 | E.g., a 3-shots T->S prompt would like this: 49 | "[Text]text1[Speech]speech_tokens1\n[Text]text2[Speech]speech_tokens2\n[Text]text3[Speech]speech_tokens3\n" 50 | """ 51 | manifset_file_mapping = { 52 | "text_text": "t2t", 53 | "speech_text": "s2t", 54 | "text_speech": "t2s", 55 | "speech_speech": "s2s", 56 | } 57 | manifest_path = ( 58 | FEW_SHOT_MANIFEST_DIR / f"{manifset_file_mapping[input_output]}.jsonl" 59 | ) 60 | df = pd.read_json(manifest_path, lines=True) 61 | assert n_shots <= len(df) 62 | 63 | # ensure a balanced sampels for each sentiment 64 | nb_samples_per_sentiment = math.ceil(n_shots / 3) 65 | df = df.groupby("sentiment").sample(n=nb_samples_per_sentiment) 66 | 67 | prompts = [] 68 | for _, row in df.iterrows(): 69 | prompt = row["prompt"] 70 | generation = row["generation"] 71 | if input_output == "text_text": 72 | prompt = FEW_SHOT_TEMPLATE.format( 73 | prompt=text_prompt(spiritlm_model, prompt), 74 | generation=text_prompt(spiritlm_model, generation), 75 | ) 76 | elif input_output == "text_speech": 77 | prompt = FEW_SHOT_TEMPLATE.format( 78 | prompt=text_prompt(spiritlm_model, prompt), 79 | generation=wav_prompt( 80 | spiritlm_model, _load_half_wav(generation, load_first_half=False) 81 | ), 82 | ) 83 | elif input_output == "speech_text": 84 | prompt = FEW_SHOT_TEMPLATE.format( 85 | prompt=wav_prompt( 86 | spiritlm_model, _load_half_wav(prompt, load_first_half=True) 87 | ), 88 | generation=text_prompt(spiritlm_model, generation), 89 | ) 90 | elif input_output == "speech_speech": 91 | prompt = FEW_SHOT_TEMPLATE.format( 92 | prompt=wav_prompt( 93 | spiritlm_model, _load_half_wav(prompt, load_first_half=True) 94 | ), 95 | generation=wav_prompt( 96 | spiritlm_model, _load_half_wav(generation, load_first_half=False) 97 | ), 98 | ) 99 | prompts.append(prompt) 100 | print(f"prompts: {prompts}") 101 | return "\n".join(prompts) + "\n" 102 | -------------------------------------------------------------------------------- /spiritlm/eval/stsp/predict_stsp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Usage example: 9 | 10 | cd {SPIRITLM ROOT FOLDER} 11 | export PYTHONPATH=. 
12 | 13 | # Speech to Text 14 | torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_s_t.jsonl --input_output speech_text 15 | # Text to Text 16 | torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_t_t.jsonl --input_output text_text 17 | # Text to Speech# 18 | torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred._t_s.jsonl --input_output text_speech 19 | # Speech to Speech 20 | torchrun --nnodes 1 --nproc-per-node 1 spiritlm/eval/stsp/predict_stsp.py --eval_manifest_path data/examples/ref.jsonl --eval --write_pred ./pred_s_s.jsonl --input_output speech_speech 21 | 22 | """ 23 | 24 | import argparse 25 | import json 26 | import os 27 | import uuid 28 | from pathlib import Path 29 | from typing import Union 30 | 31 | import torch 32 | import torch.distributed as dist 33 | import torchaudio 34 | from spiritlm.eval.eval_stsp import eval 35 | from spiritlm.eval.load_data import SpeechData, TextData 36 | from spiritlm.eval.stsp.few_shot_prompt import build_few_shot_prompt 37 | from spiritlm.eval.stsp.sentiment_classifiers import ( 38 | get_text_sentiment_prediction, 39 | load_sentiment_classifier, 40 | ) 41 | from spiritlm.eval.stsp.stsp_constants import STSP_DATA_ROOT, STSP_MODEL_ROOT 42 | from spiritlm.eval.stsp.utils import ( 43 | ExpressoEmotionClassifier, 44 | load_emotion_classifier, 45 | wav2emotion_and_sentiment, 46 | ) 47 | from spiritlm.model.spiritlm_model import ( 48 | ContentType, 49 | GenerationInput, 50 | InterleavedOutputs, 51 | OutputModality, 52 | Spiritlm, 53 | ) 54 | from torch.utils.data.distributed import DistributedSampler 55 | from tqdm import tqdm 56 | from transformers import AutoModelForSequenceClassification, GenerationConfig, set_seed 57 | 58 | SPEECH_CLASSIFIER = STSP_MODEL_ROOT / "speech_classifier" 59 | TEXT_CLASSIFIER = STSP_MODEL_ROOT / "text_classifier" 60 | 61 | NB_RETRIES = 3 62 | 63 | 64 | def get_eval_classifier(args): 65 | if args.input_output.endswith("speech"): 66 | return load_emotion_classifier(str(SPEECH_CLASSIFIER)) 67 | elif args.input_output.endswith("text"): 68 | return load_sentiment_classifier(str(TEXT_CLASSIFIER)) 69 | else: 70 | raise (Exception(f"{args.input_output} not supported")) 71 | 72 | 73 | def get_sentiment( 74 | input_output, 75 | generation, 76 | classifer: Union[AutoModelForSequenceClassification, ExpressoEmotionClassifier], 77 | ): 78 | if input_output.endswith("speech"): 79 | _, pred_sentiment = wav2emotion_and_sentiment(generation, classifer) 80 | elif input_output.endswith("text"): 81 | _, pred_sentiment = get_text_sentiment_prediction(generation, classifer) 82 | return pred_sentiment 83 | 84 | 85 | def write_jsonl(dir: str, predictions: dict): 86 | Path(dir).parent.mkdir(exist_ok=True, parents=True) 87 | with open(dir, "w") as f: 88 | for id, result_dict in predictions.items(): 89 | record = {"id": id, **result_dict} 90 | json_string = json.dumps(record) 91 | f.write(json_string + "\n") # Add a newline to separate JSON objects 92 | print(f"{dir} written") 93 | 94 | 95 | def write_wav( 96 | wav, 97 | save_dir: Path, 98 | sample_rate: int = 16_000, 99 | ) -> str: 100 | """Save wav under `save_dir` with a random name and return the full path.""" 101 | save_dir.mkdir(exist_ok=True, parents=True) 102 | random_path = save_dir / (str(uuid.uuid4()) + ".wav") 
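    # torchaudio.save expects a 2-D (channels, samples) tensor, hence the unsqueeze(0) on the mono waveform below.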
103 | torchaudio.save( 104 | random_path, torch.from_numpy(wav).unsqueeze(0), sample_rate=sample_rate 105 | ) 106 | return str(random_path) 107 | 108 | 109 | def run(args): 110 | world_size = int(os.environ["WORLD_SIZE"]) 111 | world_rank = int(os.environ["RANK"]) 112 | print( 113 | f"Running distributed inference with world_size: {world_size}, world_rank: {world_rank}" 114 | ) 115 | dist.init_process_group("nccl", rank=world_rank, world_size=world_size) 116 | set_seed(args.seed) 117 | spiritlm_model = Spiritlm(args.model) 118 | evaluation_classifier = get_eval_classifier(args) 119 | input_output = args.input_output 120 | eval_manifest_path = args.eval_manifest_path 121 | write_wav_output = args.write_wav_output 122 | 123 | if args.few_shot > 0: 124 | prompt = build_few_shot_prompt( 125 | spiritlm_model=spiritlm_model, 126 | input_output=args.input_output, 127 | n_shots=args.few_shot, 128 | ) 129 | else: 130 | prompt = None 131 | 132 | # load 133 | if input_output.startswith("speech"): 134 | eval_dataset = SpeechData(eval_manifest_path, root_dir=STSP_DATA_ROOT) 135 | elif input_output.startswith("text"): 136 | eval_dataset = TextData(eval_manifest_path, root_dir=STSP_DATA_ROOT) 137 | 138 | sampler = DistributedSampler(dataset=eval_dataset) 139 | loader = torch.utils.data.DataLoader( 140 | dataset=eval_dataset, 141 | batch_size=1, # large batch size is not supported yet 142 | sampler=sampler, 143 | num_workers=4, 144 | ) 145 | predictions = {} 146 | if input_output.endswith("speech"): 147 | output_modality = OutputModality.SPEECH 148 | max_new_tokens = 300 149 | else: 150 | output_modality = OutputModality.TEXT 151 | max_new_tokens = 50 152 | for _, data in tqdm( 153 | enumerate(loader), 154 | desc=f"Predict {eval_manifest_path}", 155 | total=eval_dataset.__len__() // world_size, 156 | ): 157 | # retry the generation multiple times because sometime it does not generate hubert tokens 158 | for i in range(NB_RETRIES): 159 | try: 160 | out: InterleavedOutputs = spiritlm_model.generate( 161 | output_modality=output_modality, 162 | interleaved_inputs=[ 163 | GenerationInput( 164 | content=( 165 | data["wav"][0] 166 | if input_output.startswith("speech") 167 | else data["text"][0] 168 | ), # 0 because of batch size 1 169 | content_type=( 170 | ContentType.SPEECH 171 | if input_output.startswith("speech") 172 | else ContentType.TEXT 173 | ), 174 | ) 175 | ], 176 | generation_config=GenerationConfig( 177 | temperature=0.8, 178 | top_p=0.95, 179 | max_new_tokens=max_new_tokens, 180 | do_sample=True, 181 | ), 182 | prompt=prompt, 183 | ) 184 | except Exception as e: 185 | print(f"Got an exception when generating: {e}") 186 | if i == NB_RETRIES - 1: 187 | raise Exception(f"Failed to generate after {NB_RETRIES}") 188 | else: 189 | break 190 | assert len(out) == 1 191 | generated_output = out[0].content 192 | detected_sentiment = get_sentiment( 193 | input_output, generated_output, evaluation_classifier 194 | ) 195 | if output_modality == OutputModality.TEXT: 196 | generation = generated_output 197 | elif write_wav_output and output_modality == OutputModality.SPEECH: 198 | generation = write_wav(generated_output, Path(write_wav_output)) 199 | else: 200 | generation = None 201 | result_dict = {"pred": detected_sentiment} 202 | if generation is not None: 203 | result_dict["generation"] = generation 204 | predictions[str(data["id"][0])] = result_dict 205 | 206 | if args.eval: 207 | gathered_predictions = [None for _ in range(world_size)] 208 | dist.gather_object( 209 | predictions, gathered_predictions if 
world_rank == 0 else None, dst=0 210 | ) 211 | if world_rank == 0: 212 | all_predictions = {k: v for d in gathered_predictions for k, v in d.items()} 213 | eval( 214 | eval_manifest_path, 215 | {k: v["pred"] for k, v in all_predictions.items()}, 216 | info_data=f"{eval_manifest_path}, input-output {input_output}", 217 | label="sentiment", 218 | ) 219 | 220 | if args.write_pred is not None and world_rank == 0: 221 | write_jsonl(args.write_pred, all_predictions) 222 | 223 | 224 | def setup_env(): 225 | os.environ["OMP_NUM_THREADS"] = "1" 226 | 227 | 228 | if __name__ == "__main__": 229 | parser = argparse.ArgumentParser() 230 | parser.add_argument( 231 | "--eval_manifest_path", # data/examples/ref.jsonl 232 | type=str, 233 | help="Path to reference record", 234 | required=True, 235 | ) 236 | 237 | parser.add_argument( 238 | "--data_root_dir", # data/stsp_data 239 | type=str, 240 | help=f"Path to root data folder, default to {str(STSP_DATA_ROOT)}", 241 | default=str(STSP_DATA_ROOT), 242 | required=False, 243 | ) 244 | 245 | parser.add_argument( 246 | "--model", 247 | type=str, 248 | default="spirit-lm-expressive-7b", 249 | help="Model name (spirit-lm-base-7b or spirit-lm-expressive-7b) or path to model", 250 | required=False, 251 | ) 252 | parser.add_argument( 253 | "--few_shot", 254 | type=int, 255 | default=0, 256 | help="Number of few shot examples, 3/6/9", 257 | required=False, 258 | ) 259 | parser.add_argument( 260 | "--input_output", 261 | type=str, 262 | default="speech_speech", 263 | help="speech_speech speech_text text_speech text_text", 264 | required=False, 265 | ) 266 | parser.add_argument( 267 | "--eval_type", 268 | type=str, 269 | default="emotion", 270 | required=False, 271 | ) 272 | parser.add_argument( 273 | "--write_pred", 274 | type=str, 275 | default=None, 276 | help="Path to save the predictions output", 277 | required=False, 278 | ) 279 | parser.add_argument( 280 | "--write_wav_output", 281 | type=str, 282 | default=None, 283 | help="Path to save the generated audio if the output is speech", 284 | required=False, 285 | ) 286 | parser.add_argument( 287 | "--eval", 288 | default=False, 289 | action="store_true", 290 | ) 291 | parser.add_argument( 292 | "--seed", 293 | default=0, 294 | type=int, 295 | ) 296 | 297 | args = parser.parse_args() 298 | setup_env() 299 | run(args) 300 | -------------------------------------------------------------------------------- /spiritlm/eval/stsp/sanity_check_download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 
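# Sanity check for the downloaded STSP data: verifies that every wav file
# referenced in the manifests (wav_path / prompt / generation fields)
# actually exists under STSP_DATA_ROOT.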
6 | 7 | import json 8 | 9 | from spiritlm.eval.stsp.stsp_constants import STSP_DATA_ROOT, STSP_MANIFEST_ROOT 10 | 11 | 12 | def check_all_datasets(): 13 | for dataset_manifset in STSP_MANIFEST_ROOT.glob("**/*jsonl"): 14 | records_checked = 0 15 | print(f"dataset_manifset: {dataset_manifset}") 16 | with dataset_manifset.open() as f: 17 | for record in f: 18 | record = json.loads(record) 19 | for wav_key in ["wav_path", "prompt", "generation"]: 20 | if wav_key in record and record[wav_key].endswith(".wav"): 21 | wav_path = STSP_DATA_ROOT / record[wav_key] 22 | assert ( 23 | wav_path.is_file() 24 | ), f"Record {record[wav_key]} not found in {str(wav_path)} and listed in {dataset_manifset}" 25 | records_checked += 1 26 | print(f"{records_checked} records checked for {dataset_manifset.stem} split") 27 | 28 | 29 | if __name__ == "__main__": 30 | check_all_datasets() 31 | -------------------------------------------------------------------------------- /spiritlm/eval/stsp/sentiment_classifiers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Dict, List, Tuple 8 | 9 | from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline 10 | 11 | 12 | def pred_to_label( 13 | sentiment_prediction_scores: List[List[Dict[str, Any]]], 14 | ) -> Tuple[str, float]: 15 | if isinstance(sentiment_prediction_scores[0], list): 16 | sentiment_prediction_scores = sentiment_prediction_scores[0] 17 | item_with_max_score = max( 18 | sentiment_prediction_scores, key=lambda _dict: _dict["score"] 19 | ) 20 | score = item_with_max_score["score"] 21 | return score, item_with_max_score["label"].lower() 22 | 23 | 24 | def get_text_sentiment_prediction(text: str, sentiment_classifier) -> Tuple[str, float]: 25 | return pred_to_label(sentiment_classifier(text)) 26 | 27 | 28 | def load_sentiment_classifier(model_dir: str): 29 | classifier = pipeline( 30 | task="text-classification", 31 | model=AutoModelForSequenceClassification.from_pretrained(model_dir), 32 | tokenizer=AutoTokenizer.from_pretrained( 33 | "j-hartmann/sentiment-roberta-large-english-3-classes" 34 | ), 35 | top_k=None, 36 | ) 37 | return classifier 38 | -------------------------------------------------------------------------------- /spiritlm/eval/stsp/stsp_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | from pathlib import Path 8 | 9 | STSP_ROOT = Path(__file__).parents[3] / "data" / "stsp_data" 10 | STSP_DATA_ROOT = STSP_ROOT / "data" 11 | STSP_MODEL_ROOT = STSP_ROOT / "model" 12 | STSP_MANIFEST_ROOT = STSP_ROOT / "manifest" 13 | -------------------------------------------------------------------------------- /spiritlm/eval/stsp/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | from dataclasses import dataclass 8 | from functools import cache 9 | from typing import List, Optional, Tuple 10 | 11 | import torch 12 | import torchaudio 13 | from transformers import AutoFeatureExtractor, AutoModelForAudioClassification 14 | 15 | EXPRESSO_EMOTION_2_SENTIMENT = { 16 | "happy": "positive", 17 | "angry": "negative", 18 | "sad": "negative", 19 | "default": "neutral", 20 | } 21 | 22 | EMOTION_2_SENTIMENT = { 23 | "happy": "positive", 24 | "angry": "negative", 25 | "sad": "negative", 26 | "default": "neutral", 27 | "neutral": "neutral", 28 | "amused": "positive", 29 | } 30 | 31 | 32 | @cache 33 | def emotions2new_label_names_and_indices( 34 | emotions_to_select: Tuple[str], 35 | label_names: Tuple[str], 36 | ) -> Tuple[List[str], List[int]]: 37 | emotion2index = {e: i for i, e in enumerate(label_names)} 38 | sorted_indices_emotions = sorted( 39 | [(emotion2index[emotion], emotion) for emotion in emotions_to_select] 40 | ) 41 | zipped = list(zip(*sorted_indices_emotions)) 42 | return zipped 43 | 44 | 45 | def expresso_emotion2_sentiment(emotion: str): 46 | return EXPRESSO_EMOTION_2_SENTIMENT[emotion] 47 | 48 | 49 | @dataclass 50 | class ExpressoEmotionClassifier: 51 | feature_extractor: AutoFeatureExtractor 52 | model: AutoModelForAudioClassification 53 | label_names: List[str] 54 | 55 | 56 | def load_emotion_classifier(checkpoint_path: str) -> ExpressoEmotionClassifier: 57 | feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint_path) 58 | model = ( 59 | AutoModelForAudioClassification.from_pretrained(checkpoint_path).cuda().eval() 60 | ) 61 | label_names = [model.config.id2label[i] for i in range(model.config.num_labels)] 62 | print(f"Classification model loaded from {checkpoint_path} !") 63 | return ExpressoEmotionClassifier(feature_extractor, model, label_names) 64 | 65 | 66 | @torch.inference_mode() 67 | def predict_audio( 68 | audio, 69 | expresso_emotion_classifier: ExpressoEmotionClassifier, 70 | emotions_to_predict: Optional[List[str]] = None, 71 | ): 72 | if isinstance(audio, str): 73 | speech, _ = torchaudio.load(audio) 74 | resampler = torchaudio.transforms.Resample( 75 | expresso_emotion_classifier.feature_extractor.sampling_rate 76 | ) 77 | speech = resampler(speech).squeeze().numpy() 78 | else: 79 | speech = audio 80 | 81 | features = expresso_emotion_classifier.feature_extractor( 82 | speech, 83 | sampling_rate=expresso_emotion_classifier.feature_extractor.sampling_rate, 84 | return_tensors="pt", 85 | ) 86 | features["input_values"] = features["input_values"].cuda() 87 | 88 | logits = expresso_emotion_classifier.model(**features).logits 89 | if emotions_to_predict is not None: 90 | (indices, label_names) = emotions2new_label_names_and_indices( 91 | tuple(emotions_to_predict), tuple(expresso_emotion_classifier.label_names) 92 | ) 93 | logits = logits[:, indices] 94 | else: 95 | label_names = expresso_emotion_classifier.label_names 96 | pred_id = torch.argmax(logits, dim=-1)[0].item() 97 | 98 | return label_names[pred_id], logits.detach().cpu().numpy() 99 | 100 | 101 | def wav2emotion( 102 | wav, 103 | expresso_emotion_classifier: ExpressoEmotionClassifier, 104 | emotions_to_predict: Optional[List[str]] = None, 105 | ) -> str: 106 | label_logits = predict_audio( 107 | audio=wav, 108 | expresso_emotion_classifier=expresso_emotion_classifier, 109 | 
emotions_to_predict=emotions_to_predict, 110 | ) 111 | pred_emotion = label_logits[0] 112 | return pred_emotion 113 | 114 | 115 | def wav2emotion_and_sentiment( 116 | wav, 117 | expresso_emotion_classifier: ExpressoEmotionClassifier, 118 | emotions_to_predict: Optional[List[str]] = None, 119 | ) -> Tuple[str, str]: 120 | pred_emotion = wav2emotion(wav, expresso_emotion_classifier, emotions_to_predict) 121 | mapped_sentiment = expresso_emotion2_sentiment(pred_emotion) 122 | return pred_emotion, mapped_sentiment 123 | -------------------------------------------------------------------------------- /spiritlm/eval/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | import torchaudio 8 | from spiritlm.model.spiritlm_model import Spiritlm 9 | 10 | 11 | def wav_prompt(spiritlm_model: Spiritlm, wav_path: str) -> str: 12 | wav = torchaudio.load(wav_path)[0].squeeze(0) 13 | return spiritlm_model.SPEECH_PROMPT_PREFIX + spiritlm_model.speech_tokenizer(wav) 14 | 15 | 16 | def text_prompt(spiritlm_model: Spiritlm, text: str) -> str: 17 | return spiritlm_model.TEXT_PROMPT_PREFIX + text 18 | -------------------------------------------------------------------------------- /spiritlm/model/README.md: -------------------------------------------------------------------------------- 1 | # Model for Spirit LM 2 | This repo includes the Spirit LM model wrapper. 3 | 4 | ## Usage examples 5 | 6 | ### Model Loading 7 | ```python 8 | from spiritlm.model.spiritlm_model import Spiritlm 9 | 10 | # Spirit LM Base 7B 11 | spirit_lm = Spiritlm("spirit-lm-base-7b") 12 | 13 | # Spirit LM Expressive 7B 14 | spirit_lm = Spiritlm("spirit-lm-expressive-7b") 15 | ``` 16 | 17 | ### Generation examples 18 | ```python 19 | from spiritlm.model.spiritlm_model import OutputModality, GenerationInput, ContentType 20 | from transformers import GenerationConfig 21 | 22 | # Generate only text 23 | spirit_lm.generate( 24 | output_modality=OutputModality.TEXT, 25 | interleaved_inputs=[ 26 | GenerationInput( 27 | content="The largest country in the world is", 28 | content_type=ContentType.TEXT, 29 | ) 30 | ], 31 | generation_config=GenerationConfig( 32 | temperature=0.9, 33 | top_p=0.95, 34 | max_new_tokens=50, 35 | do_sample=True, 36 | ), 37 | ) 38 | 39 | # Expected output format: 40 | # [GenerationOuput(content='Russia, with an area of ...', content_type=)] 41 | 42 | # Generate only speech 43 | spirit_lm.generate( 44 | output_modality=OutputModality.SPEECH, 45 | interleaved_inputs=[ 46 | GenerationInput( 47 | content="examples/audio/7143-88743-0029.flac", 48 | content_type=ContentType.SPEECH, 49 | ) 50 | ], 51 | generation_config=GenerationConfig( 52 | temperature=0.9, 53 | top_p=0.95, 54 | max_new_tokens=200, 55 | do_sample=True, 56 | ), 57 | ) 58 | 59 | # Expected output format: 60 | # [GenerationOuput(content=array([ 3.6673620e-05, 2.6468514e-04, 1.0735081e-03, ...,], dtype=float32), content_type=)] 61 | 62 | 63 | # Arbitrary generation 64 | spirit_lm.generate( 65 | output_modality=OutputModality.ARBITRARY, 66 | interleaved_inputs=[ 67 | GenerationInput( 68 | content="examples/audio/7143-88743-0029.flac", 69 | content_type=ContentType.SPEECH, 70 | ) 71 | ], 72 | generation_config=GenerationConfig( 73 | temperature=0.9, 74 | top_p=0.95, 75 | 
max_new_tokens=200, 76 | do_sample=True, 77 | ), 78 | ) 79 | # Expected output format is a list of GenerationOuput where content type could be `ContentType.TEXT' or `ContentType.SPEECH`: 80 | # [GenerationOuput(content='xxx', content_type=), GenerationOuput(content=array([ 0.00553902, -0.03210586, ... ], dtype=float32), content_type=), GenerationOuput(content='yyy', content_type=), GenerationOuput(content=array([0.04051103, 0.03596291, 0.03381396, ..., 0.05103811, 0.05429034, ..,,], dtype=float32), content_type=)] 81 | ``` 82 | See more examples with other modalites in [examples/speech_generation/spirit_model.ipynb](../../examples/speech_generation/spirit_model.ipynb). -------------------------------------------------------------------------------- /spiritlm/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /spiritlm/model/spiritlm_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import math 9 | import os 10 | from dataclasses import dataclass 11 | from enum import Enum, auto 12 | from functools import cache 13 | from pathlib import Path 14 | from typing import List, Optional, Tuple, Union 15 | 16 | import numpy as np 17 | import torch 18 | import torchaudio 19 | from spiritlm.model.utils import ( 20 | convert_to_wav_tensor, 21 | does_end_with_speech_token, 22 | does_start_with_speech_token, 23 | find_prompt_last_speech_start_position, 24 | get_forbidden_tokens, 25 | ) 26 | from spiritlm.speech_tokenizer import spiritlm_base, spiritlm_expressive 27 | from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, set_seed 28 | 29 | _logger = logging.getLogger(__name__) 30 | 31 | 32 | # Get the base checkpoints directory from environment variable or use the default base path 33 | base_checkpoints_dir = Path(os.getenv("SPIRITLM_CHECKPOINTS_DIR", Path(__file__).parent.parent.parent / "checkpoints")) 34 | 35 | # Append 'spiritlm_model' to the base path 36 | CHECKPOINT_DIR = base_checkpoints_dir / "spiritlm_model" 37 | 38 | class ContentType(Enum): 39 | TEXT = "TEXT" 40 | SPEECH = "SPEECH" 41 | 42 | 43 | class OutputModality(Enum): 44 | TEXT = auto() 45 | SPEECH = auto() 46 | ARBITRARY = auto() 47 | 48 | 49 | @dataclass 50 | class GenerationInput: 51 | content: Union[str, os.PathLike, torch.Tensor, np.ndarray] 52 | content_type: ContentType 53 | 54 | @classmethod 55 | def from_tuple(cls, tup): 56 | content_type, content = tup 57 | content_type = content_type.upper() 58 | assert content_type in [ 59 | "SPEECH", 60 | "TEXT", 61 | ], f"expects content_type to be one of ['SPEECH', 'TEXT'], found '{content_type}'" 62 | if content_type == "TEXT": 63 | content_type = ContentType.TEXT 64 | elif content_type == "SPEECH": 65 | content_type = ContentType.SPEECH 66 | return cls(content=content, content_type=content_type) 67 | 68 | 69 | @dataclass 70 | class GenerationOuput: 71 | content: Union[str, np.ndarray] 
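    # str for text outputs, float32 waveform (numpy array) for speech outputs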
72 | content_type: ContentType 73 | 74 | 75 | InterleavedInputs = List[GenerationInput] 76 | InterleavedOutputs = List[GenerationOuput] 77 | 78 | 79 | class SpiritlmVariants(Enum): 80 | BASE_7B = "spirit-lm-base-7b" 81 | EXPRESSIVIE_7B = "spirit-lm-expressive-7b" 82 | 83 | @classmethod 84 | def values_as_list(cls): 85 | return [e.value for e in cls] 86 | 87 | 88 | def _ensure_model_name(name: str): 89 | if Path(name).exists(): 90 | name = Path(name).stem 91 | expected_names = SpiritlmVariants.values_as_list() 92 | assert ( 93 | name in SpiritlmVariants.values_as_list() 94 | ), f"Unknown model name, expected one of {expected_names}" 95 | 96 | 97 | def _set_device_and_return(): 98 | if not torch.cuda.is_available(): 99 | return "cpu" 100 | local_rank = int(os.environ.get("LOCAL_RANK", "0")) 101 | torch.cuda.set_device(local_rank) 102 | return torch.device(local_rank) 103 | 104 | 105 | def _convert_str_output_modality(output_modality): 106 | """Convert from string to an instance of OutputModality""" 107 | output_modality_str_map = { 108 | "TEXT": OutputModality.TEXT, 109 | "SPEECH": OutputModality.SPEECH, 110 | "ARBITRARY": OutputModality.ARBITRARY, 111 | } 112 | if isinstance(output_modality, str): 113 | output_modality = output_modality.upper() 114 | assert ( 115 | output_modality in output_modality_str_map 116 | ), f"invalid string output_modality (found '{output_modality}', but expects one of {list(output_modality_str_map)})" 117 | output_modality = output_modality_str_map[output_modality] 118 | assert isinstance(output_modality, OutputModality) 119 | return output_modality 120 | 121 | 122 | def _get_generation_inputs(interleaved_inputs): 123 | """Convert from a list of tuple (content_type, content) to a list of GenrationInput""" 124 | for i, item in enumerate(interleaved_inputs): 125 | assert isinstance(item, tuple) or isinstance(item, GenerationInput), ( 126 | "Each element of interleaved_inputs is expected to be either an instance of GenerationInput " 127 | "or a tuple of (content_modality, content)" 128 | ) 129 | if isinstance(item, tuple): 130 | interleaved_inputs[i] = GenerationInput.from_tuple(interleaved_inputs[i]) 131 | return interleaved_inputs 132 | 133 | 134 | def _overwrite_generation_config(generation_config, kwargs): 135 | """Overwrite generation_config from the kwargs""" 136 | if generation_config is None: 137 | generation_config = GenerationConfig() 138 | assert isinstance(generation_config, GenerationConfig) 139 | gen_diff_dict = generation_config.to_diff_dict() 140 | for attr_name, attr_value in kwargs.items(): 141 | assert hasattr( 142 | generation_config, attr_name 143 | ), f"attribute '{attr_name}' not found in transformers.GenerationConfig" 144 | if attr_name in gen_diff_dict and attr_value != gen_diff_dict[attr_name]: 145 | _logger.warning( 146 | f"Overwrite generation_config's {attr_name} to {attr_value}" 147 | ) 148 | setattr(generation_config, attr_name, attr_value) 149 | return generation_config 150 | 151 | 152 | class Spiritlm: 153 | TEXT_PROMPT_PREFIX = "[Text]" 154 | SPEECH_PROMPT_PREFIX = "[Speech]" 155 | 156 | def __init__(self, name: str, **speech_tokenizer_kwargs): 157 | if Path(name).exists(): 158 | path = name 159 | else: 160 | path = CHECKPOINT_DIR / name 161 | _ensure_model_name(name) 162 | self.device = _set_device_and_return() 163 | _logger.info(f"Loading SPIRIT-LM model from the path {path}...") 164 | self.model = LlamaForCausalLM.from_pretrained( 165 | path, torch_dtype=torch.bfloat16 166 | ).to(self.device) 167 | _logger.info(f"SPIRIT-LM model 
is loaded.") 168 | self.tokenizer = LlamaTokenizer.from_pretrained( 169 | pretrained_model_name_or_path=path, 170 | add_bos_token=True, 171 | add_eos_token=False, 172 | ) 173 | _logger.info("Loading SPIRIT-LM speech tokenizers ...") 174 | if name == SpiritlmVariants.BASE_7B.value: 175 | self.speech_tokenizer = spiritlm_base(**speech_tokenizer_kwargs) 176 | self.is_expressive_model = False 177 | elif name == SpiritlmVariants.EXPRESSIVIE_7B.value: 178 | self.speech_tokenizer = spiritlm_expressive(**speech_tokenizer_kwargs) 179 | self.is_expressive_model = True 180 | _logger.info("SPIRIT-LM speech tokenizers are loaded.") 181 | 182 | def _build_prompt( 183 | self, 184 | generation_inputs: List[GenerationInput], 185 | output_modality: OutputModality, 186 | ) -> str: 187 | """ 188 | Build the prompt according the input content and the output modality. 189 | """ 190 | if not isinstance(output_modality, OutputModality): 191 | raise ValueError(f"Unknown output_modality: {output_modality}") 192 | prompts = [] 193 | prev_modality = None 194 | for gen_input in generation_inputs: 195 | if gen_input.content_type.value == ContentType.SPEECH.value: 196 | gen_input.content = convert_to_wav_tensor(gen_input.content) 197 | if prev_modality != "s": 198 | prompts.append(Spiritlm.SPEECH_PROMPT_PREFIX) 199 | prompts.append(self.speech_tokenizer(gen_input.content)) 200 | prev_modality = "s" # speech 201 | elif gen_input.content_type.value == ContentType.TEXT.value: 202 | if prev_modality != "t": 203 | prompts.append(Spiritlm.TEXT_PROMPT_PREFIX) 204 | prompts.append(gen_input.content) 205 | prev_modality = "t" # text 206 | else: 207 | raise ValueError( 208 | f"Unknown content type: {gen_input.content_type.value}" 209 | ) 210 | if output_modality == OutputModality.TEXT: 211 | if prev_modality != "t": 212 | prompts.append(Spiritlm.TEXT_PROMPT_PREFIX) 213 | elif output_modality == OutputModality.SPEECH: 214 | if prev_modality != "s": 215 | prompts.append(Spiritlm.SPEECH_PROMPT_PREFIX) 216 | return "".join(prompts) 217 | 218 | @cache 219 | def _build_forbidden_tokens( 220 | self, 221 | output_modality: OutputModality, 222 | ) -> List[int]: 223 | """ 224 | Build a set of token ids that we don't want to generate according the modality direction. 225 | 226 | For instance, when the modality direction is speech to text (S2T), i.e., we continue 227 | generating text given a speech prompt, we want that the output contains only the text tokens. 228 | """ 229 | if output_modality == OutputModality.TEXT: 230 | forbidden_tokens = get_forbidden_tokens( 231 | ban_special_tokens=True, 232 | generate_only_text=True, 233 | ban_expressivity_tokens=True if self.is_expressive_model else False, 234 | ) 235 | elif output_modality == OutputModality.SPEECH: 236 | forbidden_tokens = get_forbidden_tokens( 237 | ban_special_tokens=True, 238 | generate_only_speech=True, 239 | ) 240 | elif output_modality == OutputModality.ARBITRARY: 241 | forbidden_tokens = [] 242 | else: 243 | raise ValueError(f"Unknown output_modality: {output_modality}") 244 | return forbidden_tokens 245 | 246 | def _parse_speech_and_text( 247 | self, 248 | generated_content: str, 249 | ): 250 | # TODO: clean this function, it is too long! 
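        # Scan the generated string left to right and group consecutive
        # characters into (content, modality) chunks, where modality is
        # "s" for speech tokens ([Hu..]/[Pi..]/[St..]) and "t" for plain text.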
251 | splits = [] 252 | i = 0 253 | last_pos = len(generated_content) 254 | char_and_types = [] 255 | is_speech_token = False 256 | is_text_token = False 257 | text_prefix_length = len(Spiritlm.TEXT_PROMPT_PREFIX) 258 | speech_prefix_length = len(Spiritlm.SPEECH_PROMPT_PREFIX) 259 | while i < last_pos: 260 | ch = generated_content[i] 261 | j = i 262 | if ch == "[": 263 | if ( 264 | j + text_prefix_length - 1 < last_pos 265 | and generated_content[j : j + text_prefix_length] 266 | == Spiritlm.TEXT_PROMPT_PREFIX 267 | ): # text prefix token 268 | j += text_prefix_length # skip "[Text] 269 | elif ( 270 | j + speech_prefix_length - 1 < last_pos 271 | and generated_content[j : j + speech_prefix_length] 272 | == Spiritlm.SPEECH_PROMPT_PREFIX 273 | ): # speech prefix token 274 | j += speech_prefix_length # skip "[Speech]" 275 | elif j + 2 < last_pos and generated_content[j + 1 : j + 3] in ( 276 | "Hu", 277 | "Pi", 278 | "St", 279 | ): 280 | j += 3 # skip "["" and Hu/Pi/St 281 | while j < last_pos and generated_content[j] != "]": 282 | j += 1 283 | j += 1 # skip "]" 284 | is_speech_token = True 285 | else: # other texts starting with "[" e.g., "[abc" 286 | is_text_token = True 287 | j += 1 288 | else: 289 | is_text_token = True 290 | while j < last_pos and generated_content[j] != "[": 291 | j += 1 292 | 293 | cur_content = generated_content[i:j] 294 | if is_speech_token: 295 | if len(char_and_types) and char_and_types[-1][1] == "t": 296 | splits.append( 297 | ( 298 | "".join( 299 | ( 300 | content_and_type[0] 301 | for content_and_type in char_and_types 302 | ) 303 | ), 304 | "t", 305 | ) 306 | ) 307 | char_and_types = [] 308 | char_and_types.append((cur_content, "s")) # speech 309 | elif is_text_token: 310 | if len(char_and_types) and char_and_types[-1][1] == "s": 311 | splits.append( 312 | ( 313 | "".join( 314 | ( 315 | content_and_type[0] 316 | for content_and_type in char_and_types 317 | ) 318 | ), 319 | "s", 320 | ) 321 | ) 322 | char_and_types = [] 323 | char_and_types.append((cur_content, "t")) # text 324 | is_speech_token, is_text_token = False, False 325 | i = j 326 | if len(char_and_types): 327 | if char_and_types[-1][1] == "t": 328 | splits.append( 329 | ( 330 | "".join( 331 | (content_and_type[0] for content_and_type in char_and_types) 332 | ), 333 | "t", 334 | ) 335 | ) 336 | else: 337 | splits.append( 338 | ( 339 | "".join( 340 | (content_and_type[0] for content_and_type in char_and_types) 341 | ), 342 | "s", 343 | ) 344 | ) 345 | return splits 346 | 347 | def _decode_from_generated_output( 348 | self, 349 | output_modality: OutputModality, 350 | generated_content: str, 351 | prompt: str, 352 | speaker_id: int = 2, 353 | ) -> InterleavedOutputs: 354 | """ 355 | Decode the generated tokens according the modality direction. 356 | 357 | If the output is text, we return what it is. 358 | If the output is speech, we decode speech tokens by the speech tokenizer. 359 | If the output is arbitrary, we decode the generated content according to the its modality. 
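        Note: in the arbitrary case, very short generated speech chunks
        (fewer than 25 Hubert tokens) are generally skipped, see the
        threshold checks below.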
360 | """ 361 | 362 | def _decode( 363 | modality: OutputModality, 364 | gen: str, 365 | ) -> InterleavedOutputs: 366 | if modality == OutputModality.TEXT: 367 | return [ 368 | GenerationOuput( 369 | content=gen, 370 | content_type=ContentType.TEXT, 371 | ) 372 | ] 373 | elif modality == OutputModality.SPEECH: 374 | return [ 375 | GenerationOuput( 376 | content=self.speech_tokenizer.decode( 377 | gen, speaker_id=speaker_id 378 | ), 379 | content_type=ContentType.SPEECH, 380 | ) 381 | ] 382 | elif modality == OutputModality.ARBITRARY: 383 | decoded_chunks = [] 384 | for i, (chunk_content, chunk_modality) in enumerate( 385 | self._parse_speech_and_text(gen) 386 | ): 387 | if chunk_modality == "s": 388 | # TODO: the way of finding Hubert token could be false positive 389 | nb_content_hubert_tokens = len(chunk_content.split("[Hu")) 390 | decoded = _decode( 391 | modality=OutputModality.SPEECH, 392 | gen=chunk_content, 393 | )[0] 394 | if i == 0 and is_last_content_speech: 395 | # edge case when the prompt ends with speech and the generation starts with speech 396 | nb_prompt_hubert_tokens = ( 397 | len(prompt[last_speech_start_pos:].split("[Hu")) - 1 398 | ) # minus the one in prefix 399 | if nb_content_hubert_tokens - nb_prompt_hubert_tokens < 25: 400 | # continued speech from the prompt is too short 401 | continue 402 | # we drop the prompt part from the generation 403 | prompt_ratio = ( 404 | nb_prompt_hubert_tokens / nb_content_hubert_tokens 405 | ) 406 | decoded.content = decoded.content[ 407 | math.ceil(decoded.content.size * prompt_ratio) : 408 | ] 409 | elif i > 0 and nb_content_hubert_tokens < 25: 410 | # new speech in generation is too short 411 | continue 412 | else: 413 | decoded = _decode( 414 | modality=OutputModality.TEXT, 415 | gen=chunk_content, 416 | )[0] 417 | decoded_chunks.append(decoded) 418 | return decoded_chunks 419 | else: 420 | raise ValueError(f"Unknown output_modality: {output_modality}") 421 | 422 | generated_new_content = generated_content[len(prompt) :].strip() 423 | is_last_content_speech, last_speech_start_pos = False, 0 424 | if ( 425 | output_modality == OutputModality.ARBITRARY 426 | and does_end_with_speech_token(prompt) 427 | and does_start_with_speech_token(generated_new_content) 428 | ): 429 | is_last_content_speech = True 430 | last_speech_start_pos = find_prompt_last_speech_start_position(prompt) 431 | # If the prompt ends with speech, we decode both the prompt and the generation 432 | # because we probably don't have pitch and style tokens in the generation. 433 | generated_new_content = generated_content[last_speech_start_pos:] 434 | return _decode(output_modality, generated_new_content) 435 | 436 | def generate( 437 | self, 438 | interleaved_inputs: Optional[List[Union[GenerationInput, tuple]]] = None, 439 | prompt: Optional[str] = None, 440 | output_modality: Union[OutputModality, str] = OutputModality.ARBITRARY, 441 | generation_config: Optional[GenerationConfig] = None, 442 | force_tokens_to_output_modality: bool = True, 443 | speaker_id: int = 2, 444 | return_prompt: bool = False, 445 | seed: Optional[int] = None, 446 | **kwargs, # GenerationConfig args can be passing here 447 | ) -> Union[InterleavedOutputs, Tuple[InterleavedOutputs, str]]: 448 | """ 449 | Speech/text generation given speech/text prompt. 450 | 451 | Parameters: 452 | interleaved_inputs (List of `GenerationInput` or list of tuples): 453 | List of speech/text inputs. 
454 | Each element can be an instance of `GenerationInput` or a tuple of (content_type, content) 455 | Text content is string; Speech content is either audio path, audio tensor, or nummpy array. 456 | The prompt will be built by interleaving them in order. 457 | prompt (str): 458 | The prompt in encoded tokens string, 459 | e.g., "[Speech][Hu99][Hu38]...", "[Text]whatever text" or mix of speech & text. 460 | output_modality (str or `OutputModality`): 461 | 'TEXT' or OutputModality.TEXT: generate text 462 | 'SPEECH' or OutputModality.SPEECH: generate speech 463 | 'ARBITRARY' or OutputModality.ARBITRARY: generate arbitrary modality output (default) 464 | generation_config (`GenerationConfig`): 465 | Generation configuration used by Huggingface `generate` function. 466 | force_tokens_to_output_modality (bool): 467 | Whether to force generating tokens to the output modality that you specify in `output_modality`. 468 | For instance, if the `output_modality` is TEXT and force_tokens_to_output_modality is True, 469 | we force the model to generate only the text tokens. 470 | speaker_id (int): 471 | Speaker id, 0, 1, 2 or 3. 472 | return_prompt (bool): 473 | Whether to return the constructed prompt (could be used for debug). 474 | **kwargs: 475 | Directly passing arguments from transformers.GenerationConfig (e.g. temperature, max_new_tokens, do_sample). 476 | See: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig 477 | """ 478 | 479 | if seed is not None: 480 | _logger.info(f"Set seed to {seed}") 481 | set_seed(seed) 482 | 483 | # Set the output modality 484 | output_modality = _convert_str_output_modality(output_modality) 485 | 486 | # Get the input prompt 487 | assert not ( 488 | interleaved_inputs is None and prompt is None 489 | ), "interleaved_inputs and prompt can not both be None" 490 | if ( 491 | prompt is not None 492 | and interleaved_inputs is not None 493 | and len(interleaved_inputs) > 0 494 | ): 495 | _logger.warning( 496 | "When prompt is specified, interleaved_inputs will not be used." 
497 | ) 498 | if prompt is None: 499 | if not isinstance(interleaved_inputs, list): 500 | interleaved_inputs = [interleaved_inputs] 501 | interleaved_inputs = _get_generation_inputs(interleaved_inputs) 502 | prompt = self._build_prompt( 503 | interleaved_inputs, 504 | output_modality, 505 | ) 506 | 507 | # Get input tensor 508 | inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device) 509 | 510 | # Get generation config from kwargs 511 | generation_config = _overwrite_generation_config(generation_config, kwargs) 512 | 513 | # Get forbidden token ids 514 | if ( 515 | force_tokens_to_output_modality 516 | and output_modality != OutputModality.ARBITRARY 517 | ): 518 | forbidden_token_ids = [ 519 | [tok_id] for tok_id in self._build_forbidden_tokens(output_modality) 520 | ] 521 | else: 522 | forbidden_token_ids = None 523 | 524 | # Perform the generation 525 | generate_ids = self.model.generate( 526 | **inputs, 527 | generation_config=generation_config, 528 | bad_words_ids=forbidden_token_ids, 529 | pad_token_id=-1, 530 | ) 531 | 532 | # Decode the output 533 | gen = self.tokenizer.batch_decode( 534 | generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False 535 | )[0] 536 | try: 537 | decoded_output = self._decode_from_generated_output( 538 | output_modality=output_modality, 539 | generated_content=gen, 540 | prompt=prompt, 541 | speaker_id=speaker_id, 542 | ) 543 | except Exception as e: 544 | _logger.error(f"Fail to decode the content: {gen[len(prompt) :].strip()}") 545 | raise e 546 | 547 | if return_prompt: 548 | return decoded_output, prompt 549 | else: 550 | return decoded_output 551 | 552 | 553 | if __name__ == "__main__": 554 | spirit_lm = Spiritlm("spirit-lm-expressive-7b") 555 | # run several time to test speech text interleaved outputs 556 | wav = torchaudio.load("examples/audio/7143-88743-0029.flac")[0].squeeze() 557 | for i in range(5): 558 | outs = spirit_lm.generate( 559 | output_modality=OutputModality.ARBITRARY, 560 | interleaved_inputs=[ 561 | GenerationInput( 562 | content=wav, 563 | content_type=ContentType.SPEECH, 564 | ) 565 | ], 566 | generation_config=GenerationConfig( 567 | temperature=0.9, 568 | top_p=0.95, 569 | max_new_tokens=200, 570 | do_sample=True, 571 | ), 572 | ) 573 | print("-" * 100) 574 | print(i) 575 | print("-" * 100) 576 | print(outs) 577 | -------------------------------------------------------------------------------- /spiritlm/model/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import re 9 | from io import BytesIO 10 | from typing import List, Optional, Union 11 | 12 | import numpy as np 13 | import torch 14 | import torchaudio 15 | 16 | EXPECTED_SAMPLING_RATE = 16_000 17 | 18 | 19 | def find_prompt_last_speech_start_position(prompt: str) -> Optional[int]: 20 | prev_end = None 21 | # revert the prompt so we can search from right to left, the speech token patterns are also reverted. 
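        # e.g. a speech token "[Hu99]" becomes "]99uH[" in the reversed prompt,
        # which is what the reversed patterns below match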
22 | for match in re.finditer("(\]\d+uH\[)|(\]\d+iP\[)|(\]\d+tS\[)", prompt[::-1]): 23 | start, end = match.start(), match.end() 24 | if prev_end is not None and start != prev_end: 25 | return len(prompt) - prev_end 26 | prev_end = end 27 | if prev_end is None: 28 | # speech token is not found in the prompt 29 | return None 30 | return len(prompt) - prev_end 31 | 32 | 33 | def convert_to_wav_tensor( 34 | content: Union[str, os.PathLike, torch.Tensor, np.ndarray] 35 | ) -> torch.Tensor: 36 | if isinstance(content, os.PathLike) or isinstance(content, str): 37 | audio_path = str(content) 38 | wav, sr = torchaudio.load(audio_path) 39 | if sr != EXPECTED_SAMPLING_RATE: 40 | wav = torchaudio.functional.resample( 41 | wav, orig_freq=sr, new_freq=EXPECTED_SAMPLING_RATE 42 | ) 43 | elif isinstance(content, np.ndarray): 44 | wav = torch.from_numpy(content) 45 | elif isinstance(content, bytes): 46 | wav, sr = torchaudio.load(BytesIO(content)) 47 | if sr != EXPECTED_SAMPLING_RATE: 48 | wav = torchaudio.functional.resample( 49 | wav, orig_freq=sr, new_freq=EXPECTED_SAMPLING_RATE 50 | ) 51 | else: 52 | wav = content 53 | 54 | # TODO: what about stereo ? 55 | 56 | return wav.squeeze() 57 | 58 | 59 | def does_start_with_speech_token(encoded_string) -> bool: 60 | if ( 61 | encoded_string is None or len(encoded_string) <= 4 62 | ): # shortest speech token is "[Hu1]" 63 | return False 64 | if encoded_string[0] != "[": 65 | return False 66 | end_pos = 1 67 | while end_pos < len(encoded_string): 68 | if encoded_string[end_pos] == "]" and end_pos >= 4: 69 | if any(encoded_string[1:3].startswith(tok) for tok in ["Hu", "Pi", "St"]): 70 | return True 71 | return False 72 | # longest speech token is "[Huxxxxx]" 73 | if end_pos >= 10: 74 | return False 75 | end_pos += 1 76 | return False 77 | 78 | 79 | def does_end_with_speech_token(encoded_string: str) -> bool: 80 | if ( 81 | encoded_string is None or len(encoded_string) <= 4 82 | ): # shortest speech token is "[Hu1]" 83 | return False 84 | if encoded_string[-1] != "]": 85 | return False 86 | start_pos = len(encoded_string) - 2 87 | while start_pos >= 0: 88 | if encoded_string[start_pos] == "[" and start_pos + 3 < len(encoded_string): 89 | if any( 90 | encoded_string[start_pos + 1 : start_pos + 3].startswith(tok) 91 | for tok in ["Hu", "Pi", "St"] 92 | ): 93 | return True 94 | return False 95 | # longest speech token is "[Huxxxxx]" 96 | if start_pos < len(encoded_string) - 10: 97 | return False 98 | start_pos -= 1 99 | return False 100 | 101 | 102 | def get_forbidden_tokens( 103 | ban_special_tokens: bool = True, 104 | generate_only_speech: bool = False, 105 | generate_only_text: bool = False, 106 | ban_expressivity_tokens: bool = False, 107 | ) -> List[int]: 108 | assert not ( 109 | generate_only_speech and generate_only_text 110 | ), "Nothing will be generated when generate_only_speech and generate_only_text is all True." 
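    # Token id layout assumed below: ids < 32000 are text tokens, 32000/32001
    # are the [Text]/[Speech] prefixes, 32002-32502 are Hubert units,
    # 32503-32566 are pitch units and 32567-32666 are style units.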
111 | forbidden_tokens = [] 112 | if ban_special_tokens: 113 | forbidden_tokens += [ 114 | 32000, 115 | 32001, 116 | ] # [Text], [Speech] 117 | if generate_only_speech: 118 | forbidden_tokens += list(range(32000)) 119 | elif generate_only_text: 120 | forbidden_tokens += list(range(32002, 32002 + 501)) # hubert tokens 121 | if ban_expressivity_tokens: 122 | forbidden_tokens += list(range(32503, 32503 + 64)) # pitch tokens 123 | forbidden_tokens += list( 124 | range(32567, 32567 + 100) 125 | ) # forbidden style tokens 126 | return forbidden_tokens 127 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/README.md: -------------------------------------------------------------------------------- 1 | # Speech Tokenization for Spirit LM 2 | 3 | This repo contains the speech encoder/decoder used for the Spirit LM. 4 | 5 | Here is an example of how to use spiritlm_tokenizer 6 | 7 | ```python 8 | import IPython.display as ipd 9 | from spiritlm.speech_tokenizer import spiritlm_base, spiritlm_expressive 10 | 11 | tokenizer = spiritlm_base() # base version, only has hubert units 12 | # tokenizer = spiritlm_expressive() # expressive version, with pitch & style units 13 | 14 | # Input audio 15 | audio = "examples/audio/7143-88743-0029.flac" 16 | print('Original audio:') 17 | ipd.display(ipd.Audio(audio)) 18 | 19 | ## encode_units 20 | print('\nEncode audio into units (not deduplicated) \n', '-'*20) 21 | units = tokenizer.encode_units(audio) 22 | print(units) 23 | # > {'audio': '.../audio/7143-88743-0029.flac', 'hubert': '99 49 38 149 149 71...'} 24 | 25 | ## encode_string 26 | print('\nEncode audio into string (deduplicated and sorted units) \n', '-'*20) 27 | string_tokens = tokenizer.encode_string(audio) 28 | print(string_tokens) 29 | # > '[Hu99][Hu49][Hu38][Hu149][Hu71]...' 30 | 31 | ## decode from units 32 | print('\nDecode back to audio from units (not deduplicated) \n', '-'*20) 33 | resyn_wav = tokenizer.decode(units, speaker_id=2, dur_pred=False) 34 | ipd.display(ipd.Audio(resyn_wav, rate=16000)) 35 | 36 | ## decode from string 37 | print('\nDecode back to audio from string (deduplicated and sorted units) \n', '-'*20) 38 | resyn_dedup_wav = tokenizer.decode(string_tokens, speaker_id=2) 39 | ipd.display(ipd.Audio(resyn_dedup_wav, rate=16000)) 40 | ``` 41 | 42 | An example notebook can be found in [examples/speech_tokenizer/spiritlm_speech_tokenizer.ipynb](../../examples/speech_tokenizer/spiritlm_speech_tokenizer.ipynb). -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 
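# Factory helpers for the two tokenizer variants: spiritlm_base() uses Hubert
# units only, while spiritlm_expressive() adds pitch and style units. Loaded
# sub-models are cached in module-level globals so repeated calls reuse them.
#
# Illustrative usage sketch (paths and arguments are examples only):
#
#   from spiritlm.speech_tokenizer import spiritlm_expressive
#   tok = spiritlm_expressive()  # Hubert + F0 + style encoders + HiFi-GAN vocoder
#   string_tokens = tok.encode_string("examples/audio/7143-88743-0029.flac")
#   resyn_wav = tok.decode(string_tokens, speaker_id=2)  # 16 kHz waveform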
6 | 7 | from .f0 import spiritlm_expressive_f0 8 | from .hifigan import spiritlm_base_hifigan, spiritlm_expressive_hifigan_w2v2 9 | from .hubert import spiritlm_hubert 10 | from .spiritlm_tokenizer import SpiritLMTokenizer 11 | from .style_encoder import spiritlm_expressive_style_encoder_w2v2 12 | 13 | # Trick to avoid reloading the same model twice when calling multiple times 14 | HUBERT = None 15 | HIFIGAN_BASE = None 16 | F0 = None 17 | STYLE_W2V2 = None 18 | HIFIGAN_EXPRESSIVE_W2V2 = None 19 | 20 | 21 | def spiritlm_base( 22 | default_speaker=2, 23 | default_style=8, # conv-default 24 | ): 25 | # Hubert 26 | global HUBERT 27 | if HUBERT is None: 28 | HUBERT = spiritlm_hubert() 29 | 30 | # Hifigan 31 | global HIFIGAN_BASE 32 | if HIFIGAN_BASE is None: 33 | HIFIGAN_BASE = spiritlm_base_hifigan( 34 | default_speaker=default_speaker, default_style=default_style 35 | ) 36 | 37 | return SpiritLMTokenizer( 38 | hubert_model=HUBERT, 39 | hifigan_model=HIFIGAN_BASE, 40 | ) 41 | 42 | 43 | def spiritlm_expressive(f0_backbone="fcpe", default_speaker=2): 44 | # Hubert 45 | global HUBERT 46 | if HUBERT is None: 47 | HUBERT = spiritlm_hubert() 48 | 49 | # F0 50 | global F0 51 | if F0 is None: 52 | F0 = spiritlm_expressive_f0(f0_backbone=f0_backbone) 53 | 54 | # Style 55 | global STYLE_W2V2 56 | if STYLE_W2V2 is None: 57 | STYLE_W2V2 = spiritlm_expressive_style_encoder_w2v2() 58 | 59 | # Hifigan 60 | global HIFIGAN_EXPRESSIVE_W2V2 61 | if HIFIGAN_EXPRESSIVE_W2V2 is None: 62 | HIFIGAN_EXPRESSIVE_W2V2 = spiritlm_expressive_hifigan_w2v2( 63 | default_speaker=default_speaker 64 | ) 65 | 66 | return SpiritLMTokenizer( 67 | hubert_model=HUBERT, 68 | pitch_model=F0, 69 | style_model=STYLE_W2V2, 70 | hifigan_model=HIFIGAN_EXPRESSIVE_W2V2, 71 | hubert_key="hubert", 72 | pitch_key="pitch", 73 | style_key="style", 74 | ) 75 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/f0/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | from pathlib import Path 8 | import os 9 | 10 | import torch 11 | 12 | from .f0_tokenizer import F0Tokenizer 13 | 14 | # Get the base checkpoints directory from environment variable or use the default base path 15 | base_checkpoints_dir = Path(os.getenv("SPIRITLM_CHECKPOINTS_DIR", Path(__file__).parents[3] / "checkpoints")) 16 | 17 | # Append 'speech_tokenizer' to the base path 18 | CHECKPOINT_DIR = base_checkpoints_dir / "speech_tokenizer" 19 | 20 | CURRENT_DEVICE = ( 21 | torch.device(torch.cuda.current_device()) 22 | if torch.cuda.is_available() 23 | else "mps" if torch.backends.mps.is_available() else "cpu" 24 | ) 25 | 26 | 27 | def spiritlm_expressive_f0(f0_backbone="fcpe"): 28 | return F0Tokenizer( 29 | f0_extractor_method=f0_backbone, 30 | quantizer_path=CHECKPOINT_DIR / "vqvae_f0_quantizer/model.pt", 31 | hop_length=80, 32 | sampling_rate=16000, 33 | interpolate=True, 34 | device=CURRENT_DEVICE, 35 | ) 36 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/f0/f0_extractor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import logging 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torchaudio 14 | 15 | _logger = logging.getLogger(__name__) 16 | 17 | 18 | class F0Extractor(nn.Module): 19 | 20 | def __init__( 21 | self, 22 | hop_length=80, 23 | sampling_rate=16000, 24 | interpolate=True, 25 | ): 26 | """Each second will have sampling_rate/hop_length frames.""" 27 | super().__init__() 28 | 29 | self.hop_length = hop_length 30 | self.sampling_rate = sampling_rate 31 | self.interpolate = interpolate 32 | 33 | def load_audio(self, path, mono=True): 34 | wav, sr = torchaudio.load(path) 35 | if sr != self.sampling_rate: 36 | wav = torchaudio.functional.resample( 37 | wav, orig_freq=sr, new_freq=self.sampling_rate 38 | ) 39 | if mono and wav.ndim == 2: 40 | wav = wav.mean(dim=0) 41 | wav = wav.numpy() 42 | return wav 43 | 44 | def compute_f0_uv(self, wav, interpolate=True): 45 | raise NotImplementedError("Not implemented!") 46 | 47 | @torch.inference_mode() 48 | def forward(self, audio, vuv=False): 49 | if isinstance(audio, str): 50 | audio = self.load_audio(audio) 51 | 52 | f0, uv = self.compute_f0_uv(audio, interpolate=self.interpolate) 53 | 54 | if not vuv: 55 | return f0 56 | else: 57 | return f0, uv 58 | 59 | 60 | class pYAAPTF0Extractor(F0Extractor): 61 | 62 | def compute_f0_uv(self, wav, interpolate=True): 63 | pitch = self.get_pitch(wav) 64 | # take interpolate, otherwise pitch.samp_values 65 | # pyaapt has some problems with pitch.samp_values, so do it manually (from pgslm) 66 | f0 = pitch.samp_values 67 | if interpolate: 68 | f0 = self.interpolate_f0(f0) 69 | vuv = pitch.vuv 70 | return f0, vuv 71 | 72 | def get_pitch(self, wav): 73 | try: 74 | import amfm_decompy.basic_tools as basic 75 | import amfm_decompy.pYAAPT as pYAAPT 76 | from librosa.util import normalize 77 | except ImportError as error: 78 | raise ImportError( 79 | "To use pYAAPTF0Extractor, please install AMFM-decompy and librosa" 80 | ) from error 81 | 82 | wav = wav.squeeze() 83 | assert wav.ndim == 1 84 | if not isinstance(wav, np.ndarray): 85 | wav = np.array(wav) 86 | frame_length = 20.0 # ms 87 | to_pad = int(frame_length / 1000 * self.sampling_rate) // 2 88 | 89 | # remove remainders for large hop length 90 | n_frames = len(wav) // self.hop_length * self.hop_length 91 | wav = wav[:n_frames] 92 | 93 | audio = normalize(wav) * 0.95 94 | if self.hop_length == 80: 95 | audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0) 96 | audio = basic.SignalObj(audio, self.sampling_rate) 97 | pitch = pYAAPT.yaapt( 98 | audio, 99 | frame_length=frame_length, 100 | frame_space=self.hop_length / self.sampling_rate * 1000, 101 | nccf_thresh1=0.25, 102 | tda_frame_length=25.0, 103 | ) 104 | 105 | return pitch 106 | 107 | def interpolate_f0(self, f0, fill_extremities=True): 108 | try: 109 | from scipy.interpolate import interp1d 110 | except ImportError as error: 111 | raise ImportError( 112 | "To use pYAAPTF0Extractor, please install scipy (`pip install scipy`)" 113 | ) from error 114 | 115 | orig_t = np.arange(f0.shape[0]) 116 | f0_interp = f0[:] 117 | ii = f0_interp != 0 118 | if ii.sum() > 1: 119 | f0_interp = interp1d( 120 | orig_t[ii], 121 | f0_interp[ii], 122 | bounds_error=False, 123 | kind="linear", 124 | fill_value=0, 125 | )(orig_t) 126 | 127 | # Fill extreme values with border values 128 | if fill_extremities: 129 | 
f0_interp[: orig_t[ii][0]] = f0_interp[ii][0] 130 | f0_interp[orig_t[ii][-1] + 1 :] = f0_interp[ii][-1] 131 | 132 | return f0_interp 133 | 134 | 135 | class FCPEF0Extractor(F0Extractor): 136 | 137 | def __init__( 138 | self, 139 | hop_length=80, 140 | sampling_rate=16000, 141 | interpolate=True, 142 | device=None, 143 | ): 144 | try: 145 | from torchfcpe import spawn_bundled_infer_model 146 | except ImportError as error: 147 | raise ImportError( 148 | "To use FCPEF0Extractor, please install torchfcpe (`pip install torchfcpe`)" 149 | ) from error 150 | 151 | super().__init__( 152 | hop_length=hop_length, sampling_rate=sampling_rate, interpolate=interpolate 153 | ) 154 | 155 | self.model = spawn_bundled_infer_model(device=device) 156 | 157 | def compute_f0_uv(self, wav, interpolate=True): 158 | wav = wav.squeeze() 159 | assert wav.ndim == 1 160 | f0_target_length = (len(wav) // self.hop_length) + 1 161 | if not isinstance(wav, torch.Tensor): 162 | wav = torch.from_numpy(wav) 163 | wav = wav.float().unsqueeze(0).unsqueeze(-1) 164 | f0, uv = self.model.infer( 165 | wav, 166 | sr=self.sampling_rate, 167 | decoder_mode="local_argmax", # Recommended mode 168 | threshold=0.05, # Threshold for V/UV decision 169 | f0_min=50, # Minimum pitch 170 | f0_max=1100, # Maximum pitch 171 | interp_uv=interpolate, # Interpolate unvoiced frames 172 | output_interp_target_length=f0_target_length, # Interpolate to target length 173 | retur_uv=True, 174 | ) 175 | vuv = 1 - uv 176 | return f0.squeeze().cpu().numpy(), vuv.squeeze().cpu().numpy() 177 | 178 | 179 | def load_f0_extractor( 180 | f0_extractor_method, hop_length, sampling_rate, interpolate, device=None 181 | ): 182 | expected_methods = ["pyaapt", "fcpe"] 183 | assert ( 184 | f0_extractor_method in expected_methods 185 | ), f"Unexpected f0 extractor method: {f0_extractor_method} (choices are: {expected_methods})" 186 | if f0_extractor_method == "pyaapt": 187 | f0_extractor = pYAAPTF0Extractor( 188 | hop_length=hop_length, sampling_rate=sampling_rate, interpolate=interpolate 189 | ) 190 | elif f0_extractor_method == "fcpe": 191 | f0_extractor = FCPEF0Extractor( 192 | hop_length=hop_length, 193 | sampling_rate=sampling_rate, 194 | interpolate=interpolate, 195 | device=device, 196 | ) 197 | _logger.info( 198 | f"Using '{f0_extractor_method}' f0 extractor method (choices are: {expected_methods})" 199 | ) 200 | return f0_extractor 201 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/f0/f0_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 
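# F0Tokenizer extracts an F0 (pitch) contour with a pYAAPT or FCPE backend and
# quantizes it into discrete pitch codes with a pretrained VQ-VAE quantizer.
#
# Minimal construction sketch (checkpoint path is the default wired up in
# f0/__init__.py; the audio path is illustrative):
#
#   tokenizer = F0Tokenizer(
#       f0_extractor_method="fcpe",
#       quantizer_path="checkpoints/speech_tokenizer/vqvae_f0_quantizer/model.pt",
#       hop_length=80,
#       sampling_rate=16000,
#       interpolate=True,
#   )
#   pitch_codes = tokenizer("path/to/audio.wav")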
6 | 7 | 8 | import logging 9 | import os 10 | 11 | import torch 12 | 13 | from .f0_extractor import load_f0_extractor 14 | from .vqvae import load_vqvae 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | 19 | class F0Tokenizer(torch.nn.Module): 20 | 21 | def __init__( 22 | self, 23 | f0_extractor_method, 24 | quantizer_path, 25 | f0_speaker_stats=None, 26 | hop_length=80, 27 | sampling_rate=16000, 28 | interpolate=False, 29 | device="cuda", 30 | ): 31 | super().__init__() 32 | 33 | self.f0_extractor = load_f0_extractor( 34 | f0_extractor_method=f0_extractor_method, 35 | hop_length=hop_length, 36 | sampling_rate=sampling_rate, 37 | interpolate=interpolate, 38 | device=device, 39 | ) 40 | 41 | self.quantizer, self.quantizer_cfg = load_vqvae(quantizer_path) 42 | self.quantizer.eval() 43 | self.quantizer.to(device) 44 | # Load speaker stats 45 | self.speaker_f0_stats = f0_speaker_stats 46 | if self.speaker_f0_stats is None and ( 47 | self.quantizer_cfg.get("speaker_norm", False) 48 | or "norm_" in self.quantizer_cfg.features 49 | ): 50 | speaker_stats_path = self.quantizer_cfg.get("speaker_stats", None) 51 | if speaker_stats_path is not None and os.path.exists(speaker_stats_path): 52 | self.speaker_f0_stats = torch.load( 53 | speaker_stats_path, weights_only=True 54 | ) 55 | _logger.info(f"Speaker f0 stats loaded from '{speaker_stats_path}'") 56 | else: 57 | _logger.info( 58 | "It seems that model is using normalized f0 but no speaker stats is given, will infer mean f0 from input utterance." 59 | ) 60 | 61 | # this is useful for determining the device 62 | self.register_buffer( 63 | "_float_tensor", torch.tensor([0], dtype=torch.float, device=device) 64 | ) 65 | 66 | @property 67 | def device(self): 68 | return self._float_tensor.device 69 | 70 | def quantize_vqvae(self, f0, vuv, speaker=None, compute_vqvae_pred=False): 71 | assert self.quantizer_cfg.features in [ 72 | "f0_interp,vuv", 73 | "f0,vuv", 74 | "norm_f0_interp,vuv", 75 | "norm_f0,vuv", 76 | ], self.quantizer_cfg.features 77 | 78 | if not isinstance(f0, torch.Tensor): 79 | f0 = torch.tensor(f0) 80 | if not isinstance(vuv, torch.Tensor): 81 | vuv = torch.tensor(vuv) 82 | 83 | # normalize f0 84 | if ( 85 | self.quantizer_cfg.get("speaker_norm", False) 86 | or "norm_" in self.quantizer_cfg.features 87 | ): 88 | mask = f0 != 0 89 | if speaker is not None and speaker in self.speaker_f0_stats: 90 | mean = self.speaker_f0_stats[speaker]["f0_mean"] 91 | else: 92 | # Get statistics from utterance (maybe it is more accurate to get mean from voiced segments) 93 | vuv_mask = vuv != 0 94 | mean = torch.mean(f0[vuv_mask]) 95 | f0[mask] = f0[mask] - mean 96 | 97 | x = torch.stack([f0, vuv]) # (2, T) 98 | x = x.float().unsqueeze(0).to(self.device) # (1, 2, T) 99 | if not compute_vqvae_pred: 100 | quant_f0 = self.quantizer(x, compute_pred=False) 101 | quant_f0 = quant_f0[0].squeeze(0) 102 | return quant_f0 103 | else: 104 | quant_f0, pred = self.quantizer(x, compute_pred=True) 105 | quant_f0 = quant_f0[0].squeeze(0) 106 | pred = pred[0] 107 | return quant_f0, pred 108 | 109 | def forward(self, x, speaker=None, dense=False, compute_vqvae_pred=False): 110 | f0, vuv = self.f0_extractor(x, vuv=True) 111 | if dense: 112 | return f0 113 | return self.quantize_vqvae( 114 | f0, vuv, speaker=speaker, compute_vqvae_pred=compute_vqvae_pred 115 | ) 116 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/f0/vqvae.py: -------------------------------------------------------------------------------- 1 
| # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | # VQ-VAE model, adapted from: 8 | # - https://github.com/openai/jukebox/blob/master/jukebox/vqvae/ 9 | # - https://github.com/facebookresearch/speech-resynthesis/blob/main/modules/vq.py 10 | 11 | 12 | import logging 13 | import math 14 | from pathlib import Path 15 | 16 | import numpy as np 17 | import torch 18 | import torch.distributed as dist 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from omegaconf import OmegaConf 22 | 23 | _logger = logging.getLogger(__name__) 24 | 25 | 26 | def load_vqvae(checkpoint): 27 | config = Path(checkpoint).parent / "config.yaml" 28 | cfg = OmegaConf.load(config) 29 | model = VQVAE(cfg) 30 | state_dict = torch.load(checkpoint, map_location="cpu", weights_only=False) 31 | model.load_state_dict(state_dict["model"]) 32 | model.eval() 33 | _logger.info(f"VQVAE model loaded from '{checkpoint}'!") 34 | return model, cfg 35 | 36 | 37 | class VQVAE(nn.Module): 38 | def __init__(self, h): 39 | super().__init__() 40 | 41 | self.encoder = Encoder(**h.encoder_params) 42 | self.vq = Bottleneck(**h.vq_params) 43 | self.decoder = Decoder(**h.decoder_params) 44 | 45 | def forward(self, x, compute_pred=False): 46 | with torch.no_grad(): 47 | z = self.encoder(x) 48 | codes, z_q, commit_losses, metrics = self.vq(z) 49 | 50 | if not compute_pred: 51 | return codes 52 | x_hat = self.decoder(z_q) 53 | 54 | return codes, x_hat 55 | 56 | 57 | class BottleneckBlock(nn.Module): 58 | def __init__(self, k_bins, emb_width, mu): 59 | super().__init__() 60 | self.k_bins = k_bins 61 | self.emb_width = emb_width 62 | self.mu = mu 63 | self.reset_k() 64 | self.threshold = 1.0 65 | 66 | def reset_k(self): 67 | self.init = False 68 | self.k_sum = None 69 | self.k_elem = None 70 | self.register_buffer("k", torch.zeros(self.k_bins, self.emb_width)) 71 | 72 | def _tile(self, x): 73 | d, ew = x.shape 74 | if d < self.k_bins: 75 | n_repeats = (self.k_bins + d - 1) // d 76 | std = 0.01 / np.sqrt(ew) 77 | x = x.repeat(n_repeats, 1) 78 | x = x + torch.randn_like(x) * std 79 | return x 80 | 81 | def init_k(self, x): 82 | mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins 83 | self.init = True 84 | # init k_w using random vectors from x 85 | y = self._tile(x) 86 | _k_rand = y[torch.randperm(y.shape[0])][:k_bins] 87 | dist.broadcast(_k_rand, 0) 88 | self.k = _k_rand 89 | assert self.k.shape == (k_bins, emb_width) 90 | self.k_sum = self.k 91 | self.k_elem = torch.ones(k_bins, device=self.k.device) 92 | 93 | def restore_k(self, num_tokens=None, threshold=1.0): 94 | mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins 95 | self.init = True 96 | assert self.k.shape == (k_bins, emb_width) 97 | self.k_sum = self.k.clone() 98 | self.k_elem = torch.ones(k_bins, device=self.k.device) 99 | if num_tokens is not None: 100 | expected_usage = num_tokens / k_bins 101 | self.k_elem.data.mul_(expected_usage) 102 | self.k_sum.data.mul_(expected_usage) 103 | self.threshold = threshold 104 | 105 | def update_k(self, x, x_l): 106 | mu, emb_width, k_bins = self.mu, self.emb_width, self.k_bins 107 | with torch.no_grad(): 108 | # Calculate new centres 109 | x_l_onehot = torch.zeros( 110 | k_bins, x.shape[0], device=x.device 111 | ) # k_bins, N * L 112 | x_l_onehot.scatter_(0, x_l.view(1, x.shape[0]), 1) 113 | 114 | _k_sum = 
torch.matmul(x_l_onehot, x) # k_bins, w 115 | _k_elem = x_l_onehot.sum(dim=-1) # k_bins 116 | y = self._tile(x) 117 | _k_rand = y[torch.randperm(y.shape[0])][:k_bins] 118 | 119 | dist.broadcast(_k_rand, 0) 120 | dist.all_reduce(_k_sum) 121 | dist.all_reduce(_k_elem) 122 | 123 | # Update centres 124 | old_k = self.k 125 | self.k_sum = mu * self.k_sum + (1.0 - mu) * _k_sum # w, k_bins 126 | self.k_elem = mu * self.k_elem + (1.0 - mu) * _k_elem # k_bins 127 | usage = (self.k_elem.view(k_bins, 1) >= self.threshold).float() 128 | self.k = ( 129 | usage 130 | * (self.k_sum.view(k_bins, emb_width) / self.k_elem.view(k_bins, 1)) 131 | + (1 - usage) * _k_rand 132 | ) 133 | _k_prob = _k_elem / torch.sum( 134 | _k_elem 135 | ) # x_l_onehot.mean(dim=-1) # prob of each bin 136 | entropy = -torch.sum( 137 | _k_prob * torch.log(_k_prob + 1e-8) 138 | ) # entropy ie how diverse 139 | used_curr = (_k_elem >= self.threshold).sum() 140 | usage = torch.sum(usage) 141 | dk = torch.norm(self.k - old_k) / np.sqrt(np.prod(old_k.shape)) 142 | return dict(entropy=entropy, used_curr=used_curr, usage=usage, dk=dk) 143 | 144 | def preprocess(self, x): 145 | # NCT -> NTC -> [NT, C] 146 | x = x.permute(0, 2, 1).contiguous() 147 | x = x.view(-1, x.shape[-1]) # x_en = (N * L, w), k_j = (w, k_bins) 148 | 149 | if x.shape[-1] == self.emb_width: 150 | prenorm = torch.norm(x - torch.mean(x)) / np.sqrt(np.prod(x.shape)) 151 | elif x.shape[-1] == 2 * self.emb_width: 152 | x1, x2 = x[..., : self.emb_width], x[..., self.emb_width :] 153 | prenorm = (torch.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + ( 154 | torch.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape)) 155 | ) 156 | 157 | # Normalise 158 | x = x1 + x2 159 | else: 160 | assert False, f"Expected {x.shape[-1]} to be (1 or 2) * {self.emb_width}" 161 | return x, prenorm 162 | 163 | def postprocess(self, x_l, x_d, x_shape): 164 | # [NT, C] -> NTC -> NCT 165 | N, T = x_shape 166 | x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous() 167 | x_l = x_l.view(N, T) 168 | return x_l, x_d 169 | 170 | def quantise(self, x): 171 | # Calculate latent code x_l 172 | k_w = self.k.t() 173 | distance = ( 174 | torch.sum(x**2, dim=-1, keepdim=True) 175 | - 2 * torch.matmul(x, k_w) 176 | + torch.sum(k_w**2, dim=0, keepdim=True) 177 | ) # (N * L, b) 178 | min_distance, x_l = torch.min(distance, dim=-1) 179 | fit = torch.mean(min_distance) 180 | return x_l, fit 181 | 182 | def dequantise(self, x_l): 183 | x = F.embedding(x_l, self.k) 184 | return x 185 | 186 | def encode(self, x): 187 | N, width, T = x.shape 188 | 189 | # Preprocess. 190 | x, prenorm = self.preprocess(x) 191 | 192 | # Quantise 193 | x_l, fit = self.quantise(x) 194 | 195 | # Postprocess. 
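# (editorial note, not in the original source) quantise() returned a flat vector of
# codebook indices of length N*T; the reshape below restores the (batch, time) layout,
# and encode() intentionally drops the fit statistic since only the discrete codes are needed.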
196 | x_l = x_l.view(N, T) 197 | return x_l 198 | 199 | def decode(self, x_l): 200 | N, T = x_l.shape 201 | width = self.emb_width 202 | 203 | # Dequantise 204 | x_d = self.dequantise(x_l) 205 | 206 | # Postprocess 207 | x_d = x_d.view(N, T, width).permute(0, 2, 1).contiguous() 208 | return x_d 209 | 210 | def forward(self, x, update_k=True): 211 | N, width, T = x.shape 212 | 213 | # Preprocess 214 | x, prenorm = self.preprocess(x) 215 | 216 | # Init k if not inited 217 | if update_k and not self.init: 218 | self.init_k(x) 219 | 220 | # Quantise and dequantise through bottleneck 221 | x_l, fit = self.quantise(x) 222 | x_d = self.dequantise(x_l) 223 | 224 | # Update embeddings 225 | if update_k and self.training: 226 | update_metrics = self.update_k(x, x_l) 227 | else: 228 | update_metrics = {} 229 | 230 | # Loss 231 | commit_loss = torch.norm(x_d.detach() - x) ** 2 / np.prod(x.shape) 232 | 233 | # Passthrough 234 | x_d = x + (x_d - x).detach() 235 | 236 | # Postprocess 237 | x_l, x_d = self.postprocess(x_l, x_d, (N, T)) 238 | return x_l, x_d, commit_loss, dict(fit=fit, pn=prenorm, **update_metrics) 239 | 240 | 241 | class Bottleneck(nn.Module): 242 | def __init__(self, l_bins, emb_width, mu, levels): 243 | super().__init__() 244 | self.levels = levels 245 | level_block = lambda level: BottleneckBlock(l_bins, emb_width, mu) 246 | self.level_blocks = nn.ModuleList() 247 | for level in range(self.levels): 248 | self.level_blocks.append(level_block(level)) 249 | 250 | def encode(self, xs): 251 | zs = [level_block.encode(x) for (level_block, x) in zip(self.level_blocks, xs)] 252 | return zs 253 | 254 | def decode(self, zs, start_level=0, end_level=None): 255 | if end_level is None: 256 | end_level = self.levels 257 | xs_quantised = [ 258 | level_block.decode(z) 259 | for (level_block, z) in zip(self.level_blocks[start_level:end_level], zs) 260 | ] 261 | return xs_quantised 262 | 263 | def forward(self, xs): 264 | zs, xs_quantised, commit_losses, metrics = [], [], [], [] 265 | for level in range(self.levels): 266 | level_block = self.level_blocks[level] 267 | x = xs[level] 268 | z, x_quantised, commit_loss, metric = level_block(x, update_k=self.training) 269 | zs.append(z) 270 | if not self.training: 271 | # Be extra paranoid and make sure the encoder weights can't 272 | # change from straight-through estimator 273 | x_quantised = x_quantised.detach() 274 | xs_quantised.append(x_quantised) 275 | commit_losses.append(commit_loss) 276 | if self.training: 277 | metrics.append(metric) 278 | return zs, xs_quantised, commit_losses, metrics 279 | 280 | 281 | class ResConvBlock(nn.Module): 282 | def __init__(self, n_in, n_state): 283 | super().__init__() 284 | self.model = nn.Sequential( 285 | nn.ReLU(), 286 | nn.Conv2d(n_in, n_state, 3, 1, 1), 287 | nn.ReLU(), 288 | nn.Conv2d(n_state, n_in, 1, 1, 0), 289 | ) 290 | 291 | def forward(self, x): 292 | return x + self.model(x) 293 | 294 | 295 | class Resnet(nn.Module): 296 | def __init__(self, n_in, n_depth, m_conv=1.0): 297 | super().__init__() 298 | self.model = nn.Sequential( 299 | *[ResConvBlock(n_in, int(m_conv * n_in)) for _ in range(n_depth)] 300 | ) 301 | 302 | def forward(self, x): 303 | return self.model(x) 304 | 305 | 306 | class ResConv1DBlock(nn.Module): 307 | def __init__(self, n_in, n_state, dilation=1, zero_out=False, res_scale=1.0): 308 | super().__init__() 309 | padding = dilation 310 | self.model = nn.Sequential( 311 | nn.ReLU(), 312 | nn.Conv1d(n_in, n_state, 3, 1, padding, dilation), 313 | nn.ReLU(), 314 | nn.Conv1d(n_state, n_in, 1, 1, 
0), 315 | ) 316 | if zero_out: 317 | out = self.model[-1] 318 | nn.init.zeros_(out.weight) 319 | nn.init.zeros_(out.bias) 320 | self.res_scale = res_scale 321 | 322 | def forward(self, x): 323 | return x + self.res_scale * self.model(x) 324 | 325 | 326 | class Resnet1D(nn.Module): 327 | def __init__( 328 | self, 329 | n_in, 330 | n_depth, 331 | m_conv=1.0, 332 | dilation_growth_rate=1, 333 | dilation_cycle=None, 334 | zero_out=False, 335 | res_scale=False, 336 | reverse_dilation=False, 337 | checkpoint_res=False, 338 | ): 339 | super().__init__() 340 | 341 | def _get_depth(depth): 342 | if dilation_cycle is None: 343 | return depth 344 | else: 345 | return depth % dilation_cycle 346 | 347 | blocks = [ 348 | ResConv1DBlock( 349 | n_in, 350 | int(m_conv * n_in), 351 | dilation=dilation_growth_rate ** _get_depth(depth), 352 | zero_out=zero_out, 353 | res_scale=1.0 if not res_scale else 1.0 / math.sqrt(n_depth), 354 | ) 355 | for depth in range(n_depth) 356 | ] 357 | if reverse_dilation: 358 | blocks = blocks[::-1] 359 | self.checkpoint_res = checkpoint_res 360 | if self.checkpoint_res == 1: 361 | if dist.get_rank() == 0: 362 | _logger.warning("Checkpointing convs") 363 | self.blocks = nn.ModuleList(blocks) 364 | else: 365 | self.model = nn.Sequential(*blocks) 366 | 367 | def forward(self, x): 368 | if self.checkpoint_res == 1: 369 | raise NotImplementedError("Checkpoint not implemented") 370 | else: 371 | return self.model(x) 372 | 373 | 374 | def assert_shape(x, exp_shape): 375 | assert x.shape == exp_shape, f"Expected {exp_shape} got {x.shape}" 376 | 377 | 378 | class EncoderConvBlock(nn.Module): 379 | def __init__( 380 | self, 381 | input_emb_width, 382 | output_emb_width, 383 | down_t, 384 | stride_t, 385 | width, 386 | depth, 387 | m_conv, 388 | dilation_growth_rate=1, 389 | dilation_cycle=None, 390 | zero_out=False, 391 | res_scale=False, 392 | ): 393 | super().__init__() 394 | blocks = [] 395 | if type(stride_t) is tuple or type(stride_t) is list: 396 | start = True 397 | for s_t, d_t in zip(stride_t, down_t): 398 | if s_t % 2 == 0: 399 | filter_t, pad_t = s_t * 2, s_t // 2 400 | else: 401 | filter_t, pad_t = s_t * 2 + 1, s_t // 2 + 1 402 | if d_t > 0: 403 | for i in range(d_t): 404 | block = nn.Sequential( 405 | nn.Conv1d( 406 | input_emb_width if i == 0 and start else width, 407 | width, 408 | filter_t, 409 | s_t, 410 | pad_t, 411 | ), 412 | Resnet1D( 413 | width, 414 | depth, 415 | m_conv, 416 | dilation_growth_rate, 417 | dilation_cycle, 418 | zero_out, 419 | res_scale, 420 | ), 421 | ) 422 | blocks.append(block) 423 | start = False 424 | block = nn.Conv1d(width, output_emb_width, 3, 1, 1) 425 | blocks.append(block) 426 | else: 427 | filter_t, pad_t = stride_t * 2, stride_t // 2 428 | if down_t > 0: 429 | for i in range(down_t): 430 | block = nn.Sequential( 431 | nn.Conv1d( 432 | input_emb_width if i == 0 else width, 433 | width, 434 | filter_t, 435 | stride_t, 436 | pad_t, 437 | ), 438 | Resnet1D( 439 | width, 440 | depth, 441 | m_conv, 442 | dilation_growth_rate, 443 | dilation_cycle, 444 | zero_out, 445 | res_scale, 446 | ), 447 | ) 448 | blocks.append(block) 449 | block = nn.Conv1d(width, output_emb_width, 3, 1, 1) 450 | blocks.append(block) 451 | self.model = nn.Sequential(*blocks) 452 | 453 | def forward(self, x): 454 | return self.model(x) 455 | 456 | 457 | class DecoderConvBock(nn.Module): 458 | def __init__( 459 | self, 460 | input_emb_width, 461 | output_emb_width, 462 | down_t, 463 | stride_t, 464 | width, 465 | depth, 466 | m_conv, 467 | dilation_growth_rate=1, 468 | 
dilation_cycle=None, 469 | zero_out=False, 470 | res_scale=False, 471 | reverse_decoder_dilation=False, 472 | checkpoint_res=False, 473 | ): 474 | super().__init__() 475 | blocks = [] 476 | 477 | if type(stride_t) is tuple or type(stride_t) is list: 478 | block = nn.Conv1d(output_emb_width, width, 3, 1, 1) 479 | blocks.append(block) 480 | for k, (s_t, d_t) in enumerate(zip(stride_t, down_t)): 481 | if d_t > 0: 482 | if s_t % 2 == 0: 483 | filter_t, pad_t = s_t * 2, s_t // 2 484 | else: 485 | filter_t, pad_t = s_t * 2 + 1, s_t // 2 + 1 486 | end = k == len(stride_t) - 1 487 | for i in range(d_t): 488 | block = nn.Sequential( 489 | Resnet1D( 490 | width, 491 | depth, 492 | m_conv, 493 | dilation_growth_rate, 494 | dilation_cycle, 495 | zero_out=zero_out, 496 | res_scale=res_scale, 497 | reverse_dilation=reverse_decoder_dilation, 498 | checkpoint_res=checkpoint_res, 499 | ), 500 | nn.ConvTranspose1d( 501 | width, 502 | input_emb_width if i == (d_t - 1) and end else width, 503 | filter_t, 504 | s_t, 505 | pad_t, 506 | ), 507 | ) 508 | blocks.append(block) 509 | else: 510 | if down_t > 0: 511 | filter_t, pad_t = stride_t * 2, stride_t // 2 512 | block = nn.Conv1d(output_emb_width, width, 3, 1, 1) 513 | blocks.append(block) 514 | for i in range(down_t): 515 | block = nn.Sequential( 516 | Resnet1D( 517 | width, 518 | depth, 519 | m_conv, 520 | dilation_growth_rate, 521 | dilation_cycle, 522 | zero_out=zero_out, 523 | res_scale=res_scale, 524 | reverse_dilation=reverse_decoder_dilation, 525 | checkpoint_res=checkpoint_res, 526 | ), 527 | nn.ConvTranspose1d( 528 | width, 529 | input_emb_width if i == (down_t - 1) else width, 530 | filter_t, 531 | stride_t, 532 | pad_t, 533 | ), 534 | ) 535 | blocks.append(block) 536 | self.model = nn.Sequential(*blocks) 537 | 538 | def forward(self, x): 539 | return self.model(x) 540 | 541 | 542 | class Encoder(nn.Module): 543 | def __init__( 544 | self, 545 | input_emb_width, 546 | output_emb_width, 547 | levels, 548 | downs_t, 549 | strides_t, 550 | **block_kwargs, 551 | ): 552 | super().__init__() 553 | self.input_emb_width = input_emb_width 554 | self.output_emb_width = output_emb_width 555 | self.levels = levels 556 | self.downs_t = downs_t 557 | self.strides_t = strides_t 558 | 559 | block_kwargs_copy = dict(**block_kwargs) 560 | if "reverse_decoder_dilation" in block_kwargs_copy: 561 | del block_kwargs_copy["reverse_decoder_dilation"] 562 | level_block = lambda level, down_t, stride_t: EncoderConvBlock( 563 | input_emb_width if level == 0 else output_emb_width, 564 | output_emb_width, 565 | down_t, 566 | stride_t, 567 | **block_kwargs_copy, 568 | ) 569 | self.level_blocks = nn.ModuleList() 570 | iterator = zip(list(range(self.levels)), downs_t, strides_t) 571 | for level, down_t, stride_t in iterator: 572 | self.level_blocks.append(level_block(level, down_t, stride_t)) 573 | 574 | def forward(self, x): 575 | N, T = x.shape[0], x.shape[-1] 576 | emb = self.input_emb_width 577 | assert_shape(x, (N, emb, T)) 578 | xs = [] 579 | 580 | # 64, 32, ... 
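# (editorial note, not in the original source) each level block below downsamples the time
# axis by stride_t ** down_t (or the product over the tuple entries), which is what the
# shrinking lengths in the comment above refer to; xs collects one tensor of shape
# (N, output_emb_width, T_level) per level for the bottleneck.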
581 | iterator = zip(list(range(self.levels)), self.downs_t, self.strides_t) 582 | for level, down_t, stride_t in iterator: 583 | level_block = self.level_blocks[level] 584 | x = level_block(x) 585 | if type(stride_t) is tuple or type(stride_t) is list: 586 | emb, T = self.output_emb_width, T // np.prod( 587 | [s**d for s, d in zip(stride_t, down_t)] 588 | ) 589 | else: 590 | emb, T = self.output_emb_width, T // (stride_t**down_t) 591 | assert_shape(x, (N, emb, T)) 592 | xs.append(x) 593 | 594 | return xs 595 | 596 | 597 | class Decoder(nn.Module): 598 | def __init__( 599 | self, 600 | input_emb_width, 601 | output_emb_width, 602 | levels, 603 | downs_t, 604 | strides_t, 605 | **block_kwargs, 606 | ): 607 | super().__init__() 608 | self.input_emb_width = input_emb_width 609 | self.output_emb_width = output_emb_width 610 | self.levels = levels 611 | 612 | self.downs_t = downs_t 613 | 614 | self.strides_t = strides_t 615 | 616 | level_block = lambda level, down_t, stride_t: DecoderConvBock( 617 | output_emb_width, output_emb_width, down_t, stride_t, **block_kwargs 618 | ) 619 | self.level_blocks = nn.ModuleList() 620 | iterator = zip(list(range(self.levels)), downs_t, strides_t) 621 | for level, down_t, stride_t in iterator: 622 | self.level_blocks.append(level_block(level, down_t, stride_t)) 623 | 624 | self.out = nn.Conv1d(output_emb_width, input_emb_width, 3, 1, 1) 625 | 626 | def forward(self, xs, all_levels=True): 627 | if all_levels: 628 | assert len(xs) == self.levels 629 | else: 630 | assert len(xs) == 1 631 | x = xs[-1] 632 | N, T = x.shape[0], x.shape[-1] 633 | emb = self.output_emb_width 634 | assert_shape(x, (N, emb, T)) 635 | 636 | # 32, 64 ... 637 | iterator = reversed( 638 | list(zip(list(range(self.levels)), self.downs_t, self.strides_t)) 639 | ) 640 | for level, down_t, stride_t in iterator: 641 | level_block = self.level_blocks[level] 642 | x = level_block(x) 643 | if type(stride_t) is tuple or type(stride_t) is list: 644 | emb, T = self.output_emb_width, T * np.prod( 645 | [s**d for s, d in zip(stride_t, down_t)] 646 | ) 647 | else: 648 | emb, T = self.output_emb_width, T * (stride_t**down_t) 649 | assert_shape(x, (N, emb, T)) 650 | if level != 0 and all_levels: 651 | x = x + xs[level - 1] 652 | 653 | x = self.out(x) 654 | return x 655 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 
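# Editorial usage sketch (not part of the original file): assuming the HiFi-GAN
# checkpoints have been fetched into SPIRITLM_CHECKPOINTS_DIR/speech_tokenizer,
# the factory defined below can turn a space-separated string of deduplicated
# HuBERT units back into a waveform, for example:
#
#     from spiritlm.speech_tokenizer.hifigan import spiritlm_base_hifigan
#     vocoder = spiritlm_base_hifigan()              # loads hifigan_spiritlm_base/generator.pt
#     wav = vocoder("78 42 81 159 316 259",          # HuBERT unit ids as a string
#                   speaker_id=2, dur_pred=True)     # dur_pred asks the checkpoint's duration
#                                                    # predictor to re-expand the dedup'd units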
6 | 7 | from pathlib import Path 8 | import os 9 | 10 | import torch 11 | 12 | from .hifigan_vocoder import HifiGANVocoder 13 | 14 | # Get the base checkpoints directory from environment variable or use the default base path 15 | base_checkpoints_dir = Path(os.getenv("SPIRITLM_CHECKPOINTS_DIR", Path(__file__).parents[3] / "checkpoints")) 16 | 17 | # Append 'speech_tokenizer' to the base path 18 | CHECKPOINT_DIR = base_checkpoints_dir / "speech_tokenizer" 19 | 20 | CURRENT_DEVICE = ( 21 | torch.device(torch.cuda.current_device()) 22 | if torch.cuda.is_available() 23 | else "mps" if torch.backends.mps.is_available() else "cpu" 24 | ) 25 | 26 | 27 | def spiritlm_base_hifigan( 28 | default_speaker=2, 29 | default_style=8, # conv-default 30 | ): 31 | return HifiGANVocoder( 32 | CHECKPOINT_DIR / "hifigan_spiritlm_base/generator.pt", 33 | default_speaker=default_speaker, 34 | default_style=default_style, 35 | ).to(CURRENT_DEVICE) 36 | 37 | 38 | def spiritlm_expressive_hifigan_w2v2(default_speaker=2): 39 | return HifiGANVocoder( 40 | CHECKPOINT_DIR / "hifigan_spiritlm_expressive_w2v2/generator.pt", 41 | default_speaker=default_speaker, 42 | ).to(CURRENT_DEVICE) 43 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/hifigan/hifigan_vocoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | # Standalone Hifigan vocoder 8 | # Adapted from: 9 | # - https://github.com/jik876/hifi-gan 10 | # - https://github.com/facebookresearch/fairseq/tree/main/fairseq/models/text_to_speech 11 | # - https://github.com/facebookresearch/speech-resynthesis/blob/main/examples/speech_to_speech_translation/models.py 12 | # - https://github.com/facebookresearch/speech-resynthesis/blob/main/examples/expresso/models.py 13 | 14 | import json 15 | import logging 16 | from pathlib import Path 17 | from typing import Dict 18 | 19 | import numpy as np 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | from torch.nn import Conv1d, ConvTranspose1d 24 | from torch.nn.utils import remove_weight_norm, weight_norm 25 | 26 | _logger = logging.getLogger(__name__) 27 | 28 | 29 | class HifiGANVocoder(nn.Module): 30 | def __init__( 31 | self, 32 | checkpoint_path, 33 | config_path=None, 34 | default_speaker=0, 35 | default_style=0, 36 | fp16=False, 37 | ): 38 | super().__init__() 39 | 40 | if config_path is None: 41 | config_path = Path(checkpoint_path).parent / "config.json" 42 | with open(config_path) as f: 43 | cfg = json.load(f) 44 | self.vocoder = CodeHiFiGANVocoderModel(checkpoint_path, cfg, fp16) 45 | self.vocoder.eval() 46 | 47 | self.multispkr = self.vocoder.model.multispkr 48 | if self.multispkr: 49 | self.default_speaker = default_speaker 50 | speakers_path = Path(checkpoint_path).parent / "speakers.txt" 51 | if speakers_path.exists(): 52 | with open(speakers_path) as f: 53 | self.speakers = [line.strip() for line in f] 54 | _logger.info( 55 | f"Loaded {len(self.speakers)} speakers. 
First few speakers: {self.speakers[:10]}" 56 | ) 57 | 58 | self.multistyle = self.vocoder.model.multistyle 59 | if self.multistyle: 60 | self.default_style = default_style 61 | styles_path = Path(checkpoint_path).parent / "styles.txt" 62 | if styles_path.exists(): 63 | with open(styles_path) as f: 64 | self.styles = [line.strip() for line in f] 65 | _logger.info( 66 | f"Loaded {len(self.styles)} styles. First few styles: {self.styles[:10]}" 67 | ) 68 | 69 | self.dur_pred = self.vocoder.model.dur_predictor is not None 70 | self.cfg = cfg 71 | 72 | _logger.info( 73 | f"HifiGAN: Duration Prediction = {self.dur_pred} - " 74 | f"Multiple Speaker = {bool(self.multispkr)} - " 75 | f"Multiple Style = {bool(self.multistyle)}" 76 | ) 77 | 78 | # this is useful for determining the device 79 | self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) 80 | 81 | @property 82 | def device(self): 83 | return self._float_tensor.device 84 | 85 | def preprocess_code(self, code, deduplicate_code=False): 86 | if isinstance(code, str): 87 | code = code.split() 88 | if isinstance(code, list): 89 | code = list(map(int, code)) 90 | code = torch.tensor(code) 91 | elif isinstance(code, np.ndarray): 92 | code = torch.from_numpy(code) 93 | code = code.long() 94 | if deduplicate_code: 95 | code = torch.unique_consecutive(code) 96 | return code.view(1, -1) 97 | 98 | def forward( 99 | self, 100 | code, 101 | speaker_id=None, 102 | style_id=None, 103 | dur_pred=True, 104 | f0_code=None, 105 | style_code=None, 106 | not_dedup_code=False, 107 | ): 108 | assert not ( 109 | dur_pred and not self.dur_pred 110 | ), "Model doesnt't support duration prediction" 111 | inp = dict() 112 | inp["code"] = self.preprocess_code(code, dur_pred and not not_dedup_code) 113 | if f0_code is not None: 114 | inp["f0_code"] = self.preprocess_code(f0_code, deduplicate_code=False) 115 | if style_code is not None: 116 | inp["style_code"] = self.preprocess_code(style_code, deduplicate_code=False) 117 | if self.multispkr: 118 | if speaker_id is None: 119 | speaker_id = self.default_speaker 120 | inp["spkr"] = torch.LongTensor([speaker_id]).view(1, 1) 121 | if self.multistyle: 122 | if style_id is None: 123 | style_id = self.default_style 124 | inp["style"] = torch.LongTensor([style_id]).view(1, 1) 125 | inp = {k: v.to(self.device) for k, v in inp.items()} 126 | return self.vocoder(inp, dur_pred) 127 | 128 | 129 | class CodeHiFiGANVocoderModel(nn.Module): 130 | def __init__( 131 | self, checkpoint_path: str, model_cfg: Dict[str, str], fp16: bool = False 132 | ) -> None: 133 | super().__init__() 134 | self.model = CodeGenerator(model_cfg) 135 | state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) 136 | self.model.load_state_dict(state_dict["generator"]) 137 | self.model.eval() 138 | if fp16: 139 | self.model.half() 140 | self.model.remove_weight_norm() 141 | _logger.info(f"Loaded CodeHiFiGAN checkpoint from '{checkpoint_path}'") 142 | 143 | def upsample(self, code, downsampled_code, uprate): 144 | N = code.size(1) 145 | K = downsampled_code.size(1) 146 | assert abs(K * uprate - N) / uprate <= 1, (N, K, uprate) 147 | upsampled_code = torch.repeat_interleave(downsampled_code, uprate, dim=1) 148 | if upsampled_code.size(1) < N: 149 | z = torch.zeros_like(code) 150 | z[:, : upsampled_code.size(1)] = upsampled_code 151 | z[:, upsampled_code.size(1) :] = upsampled_code[:, -1].view(-1, 1) 152 | upsampled_code = z 153 | upsampled_code = upsampled_code[:, :N] 154 | return upsampled_code 155 | 156 | def forward(self, 
x: Dict[str, torch.Tensor], dur_prediction=False) -> torch.Tensor: 157 | assert "code" in x 158 | x["dur_prediction"] = dur_prediction 159 | 160 | # remove invalid code 161 | mask = x["code"] >= 0 162 | x["code"] = x["code"][mask].unsqueeze(dim=0) 163 | if "f0" in x: 164 | f0_up_ratio = x["f0"].size(1) // x["code"].size(1) 165 | mask = mask.unsqueeze(2).repeat(1, 1, f0_up_ratio).view(-1, x["f0"].size(1)) 166 | x["f0"] = x["f0"][mask].unsqueeze(dim=0) 167 | 168 | # preprocess f0 & style codes 169 | if "f0_code" in x: 170 | if dur_prediction: # f0 must be upsampled first if dedup 171 | assert len(x["f0_code"][0]) == len( 172 | x["code"][0] 173 | ), f"f0 must be upsampled first if dedup (f0_code length: {len(x['f0_code'][0])}, code length: {len(x['code'][0])})" 174 | else: 175 | x["f0_code"] = self.upsample( 176 | x["code"], x["f0_code"], self.model.hubert_to_f0 177 | ) 178 | 179 | if "style_code" in x: 180 | if dur_prediction: # style must be upsampled first if dedup 181 | f"style must be upsampled first if dedup (style_code length: {len(x['style_code'][0])}, code length: {len(x['code'][0])})" 182 | else: 183 | x["style_code"] = self.upsample( 184 | x["code"], x["style_code"], self.model.hubert_to_style 185 | ) 186 | 187 | return self.model(**x).detach().squeeze() 188 | 189 | 190 | # Higigan Generator 191 | LRELU_SLOPE = 0.1 192 | 193 | 194 | def init_weights(m, mean=0.0, std=0.01): 195 | classname = m.__class__.__name__ 196 | if classname.find("Conv") != -1: 197 | m.weight.data.normal_(mean, std) 198 | 199 | 200 | def get_padding(kernel_size, dilation=1): 201 | return (kernel_size * dilation - dilation) // 2 202 | 203 | 204 | class ResBlock(nn.Module): 205 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 206 | super(ResBlock, self).__init__() 207 | self.convs1 = nn.ModuleList( 208 | [ 209 | weight_norm( 210 | Conv1d( 211 | channels, 212 | channels, 213 | kernel_size, 214 | 1, 215 | dilation=dilation[0], 216 | padding=get_padding(kernel_size, dilation[0]), 217 | ) 218 | ), 219 | weight_norm( 220 | Conv1d( 221 | channels, 222 | channels, 223 | kernel_size, 224 | 1, 225 | dilation=dilation[1], 226 | padding=get_padding(kernel_size, dilation[1]), 227 | ) 228 | ), 229 | weight_norm( 230 | Conv1d( 231 | channels, 232 | channels, 233 | kernel_size, 234 | 1, 235 | dilation=dilation[2], 236 | padding=get_padding(kernel_size, dilation[2]), 237 | ) 238 | ), 239 | ] 240 | ) 241 | self.convs1.apply(init_weights) 242 | 243 | self.convs2 = nn.ModuleList( 244 | [ 245 | weight_norm( 246 | Conv1d( 247 | channels, 248 | channels, 249 | kernel_size, 250 | 1, 251 | dilation=1, 252 | padding=get_padding(kernel_size, 1), 253 | ) 254 | ), 255 | weight_norm( 256 | Conv1d( 257 | channels, 258 | channels, 259 | kernel_size, 260 | 1, 261 | dilation=1, 262 | padding=get_padding(kernel_size, 1), 263 | ) 264 | ), 265 | weight_norm( 266 | Conv1d( 267 | channels, 268 | channels, 269 | kernel_size, 270 | 1, 271 | dilation=1, 272 | padding=get_padding(kernel_size, 1), 273 | ) 274 | ), 275 | ] 276 | ) 277 | self.convs2.apply(init_weights) 278 | 279 | def forward(self, x): 280 | for c1, c2 in zip(self.convs1, self.convs2): 281 | xt = F.leaky_relu(x, LRELU_SLOPE) 282 | xt = c1(xt) 283 | xt = F.leaky_relu(xt, LRELU_SLOPE) 284 | xt = c2(xt) 285 | x = xt + x 286 | return x 287 | 288 | def remove_weight_norm(self): 289 | for layer in self.convs1: 290 | remove_weight_norm(layer) 291 | for layer in self.convs2: 292 | remove_weight_norm(layer) 293 | 294 | 295 | class Generator(nn.Module): 296 | def __init__(self, 
cfg): 297 | super(Generator, self).__init__() 298 | self.num_kernels = len(cfg["resblock_kernel_sizes"]) 299 | self.num_upsamples = len(cfg["upsample_rates"]) 300 | self.conv_pre = weight_norm( 301 | Conv1d( 302 | cfg.get("model_in_dim", 80), 303 | cfg["upsample_initial_channel"], 304 | 7, 305 | 1, 306 | padding=3, 307 | ) 308 | ) 309 | 310 | self.ups = nn.ModuleList() 311 | for i, (u, k) in enumerate( 312 | zip(cfg["upsample_rates"], cfg["upsample_kernel_sizes"]) 313 | ): 314 | self.ups.append( 315 | weight_norm( 316 | ConvTranspose1d( 317 | cfg["upsample_initial_channel"] // (2**i), 318 | cfg["upsample_initial_channel"] // (2 ** (i + 1)), 319 | k, 320 | u, 321 | padding=(k - u) // 2, 322 | ) 323 | ) 324 | ) 325 | 326 | self.resblocks = nn.ModuleList() 327 | for i in range(len(self.ups)): 328 | ch = cfg["upsample_initial_channel"] // (2 ** (i + 1)) 329 | for k, d in zip( 330 | cfg["resblock_kernel_sizes"], cfg["resblock_dilation_sizes"] 331 | ): 332 | self.resblocks.append(ResBlock(ch, k, d)) 333 | 334 | self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) 335 | self.ups.apply(init_weights) 336 | self.conv_post.apply(init_weights) 337 | 338 | def forward(self, x): 339 | x = self.conv_pre(x) 340 | for i in range(self.num_upsamples): 341 | x = F.leaky_relu(x, LRELU_SLOPE) 342 | x = self.ups[i](x) 343 | xs = None 344 | for j in range(self.num_kernels): 345 | if xs is None: 346 | xs = self.resblocks[i * self.num_kernels + j](x) 347 | else: 348 | xs += self.resblocks[i * self.num_kernels + j](x) 349 | x = xs / self.num_kernels 350 | x = F.leaky_relu(x) 351 | x = self.conv_post(x) 352 | x = torch.tanh(x) 353 | 354 | return x 355 | 356 | def remove_weight_norm(self): 357 | _logger.info("Removing weight norm...") 358 | for layer in self.ups: 359 | remove_weight_norm(layer) 360 | for layer in self.resblocks: 361 | layer.remove_weight_norm() 362 | remove_weight_norm(self.conv_pre) 363 | remove_weight_norm(self.conv_post) 364 | 365 | 366 | class VariancePredictor(nn.Module): 367 | def __init__( 368 | self, 369 | encoder_embed_dim, 370 | var_pred_hidden_dim, 371 | var_pred_kernel_size, 372 | var_pred_dropout, 373 | ): 374 | super().__init__() 375 | self.conv1 = nn.Sequential( 376 | nn.Conv1d( 377 | encoder_embed_dim, 378 | var_pred_hidden_dim, 379 | kernel_size=var_pred_kernel_size, 380 | padding=(var_pred_kernel_size - 1) // 2, 381 | ), 382 | nn.ReLU(), 383 | ) 384 | self.ln1 = nn.LayerNorm(var_pred_hidden_dim) 385 | self.dropout = var_pred_dropout 386 | self.conv2 = nn.Sequential( 387 | nn.Conv1d( 388 | var_pred_hidden_dim, 389 | var_pred_hidden_dim, 390 | kernel_size=var_pred_kernel_size, 391 | padding=1, 392 | ), 393 | nn.ReLU(), 394 | ) 395 | self.ln2 = nn.LayerNorm(var_pred_hidden_dim) 396 | self.proj = nn.Linear(var_pred_hidden_dim, 1) 397 | 398 | def forward(self, x): 399 | # Input: B x T x C; Output: B x T 400 | x = self.conv1(x.transpose(1, 2)).transpose(1, 2) 401 | x = F.dropout(self.ln1(x), p=self.dropout, training=self.training) 402 | x = self.conv2(x.transpose(1, 2)).transpose(1, 2) 403 | x = F.dropout(self.ln2(x), p=self.dropout, training=self.training) 404 | return self.proj(x).squeeze(dim=2) 405 | 406 | 407 | class CodeGenerator(Generator): 408 | def __init__(self, cfg): 409 | super().__init__(cfg) 410 | self.dict = nn.Embedding(cfg["num_embeddings"], cfg["embedding_dim"]) 411 | self.multispkr = cfg.get("multispkr", None) 412 | self.embedder = cfg.get("embedder_params", None) 413 | 414 | self.f0_dict = None 415 | if cfg.get("num_f0_tokens", None): 416 | self.f0_dict = 
nn.Embedding(cfg["num_f0_tokens"], cfg["embedding_dim"]) 417 | self.hubert_to_f0 = round( 418 | cfg["f0_hop_size"] / cfg["code_hop_size"] 419 | ) # 4 for 25hz hubert and 6.25hz f0 420 | 421 | self.style_dict = None 422 | if cfg.get("num_style_tokens", None): 423 | self.style_dict = nn.Embedding( 424 | cfg["num_style_tokens"], cfg["embedding_dim"] 425 | ) 426 | self.hubert_to_style = round( 427 | cfg["style_hop_size"] / cfg["code_hop_size"] 428 | ) # 25 for 25hz hubert and 1hz style 429 | 430 | self.multistyle = cfg.get("multistyle", None) 431 | 432 | if self.multispkr and not self.embedder: 433 | self.spkr = nn.Embedding(cfg.get("num_speakers", 200), cfg["embedding_dim"]) 434 | elif self.embedder: 435 | self.spkr = nn.Linear(cfg.get("embedder_dim", 256), cfg["embedding_dim"]) 436 | 437 | if self.multistyle: 438 | self.style = nn.Embedding(cfg.get("num_styles", 100), cfg["embedding_dim"]) 439 | 440 | self.dur_predictor = None 441 | if cfg.get("dur_predictor_params", None): 442 | self.dur_predictor = VariancePredictor( 443 | cfg["dur_predictor_params"]["encoder_embed_dim"], 444 | cfg["dur_predictor_params"]["var_pred_hidden_dim"], 445 | cfg["dur_predictor_params"]["var_pred_kernel_size"], 446 | cfg["dur_predictor_params"]["var_pred_dropout"], 447 | ) 448 | 449 | self.f0 = cfg.get("f0", None) 450 | n_f0_bin = cfg.get("f0_quant_num_bin", 0) 451 | self.f0_quant_embed = ( 452 | None if n_f0_bin <= 0 else nn.Embedding(n_f0_bin, cfg["embedding_dim"]) 453 | ) 454 | 455 | @staticmethod 456 | def _upsample(signal, max_frames): 457 | if signal.dim() == 3: 458 | bsz, channels, cond_length = signal.size() 459 | elif signal.dim() == 2: 460 | signal = signal.unsqueeze(2) 461 | bsz, channels, cond_length = signal.size() 462 | else: 463 | signal = signal.view(-1, 1, 1) 464 | bsz, channels, cond_length = signal.size() 465 | 466 | signal = signal.unsqueeze(3).repeat(1, 1, 1, max_frames // cond_length) 467 | 468 | # pad zeros as needed (if signal's shape does not divide completely with max_frames) 469 | reminder = (max_frames - signal.shape[2] * signal.shape[3]) // signal.shape[3] 470 | if reminder > 0: 471 | raise NotImplementedError( 472 | "Padding condition signal - misalignment between condition features." 
473 | ) 474 | 475 | signal = signal.view(bsz, channels, max_frames) 476 | return signal 477 | 478 | def forward(self, **kwargs): 479 | x = self.dict(kwargs["code"]).transpose(1, 2) 480 | 481 | dur_out = None 482 | if self.dur_predictor and kwargs.get("dur_prediction", False): 483 | assert x.size(0) == 1, "only support single sample" 484 | log_dur_pred = self.dur_predictor(x.transpose(1, 2)) 485 | dur_out = torch.clamp( 486 | torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1 487 | ) 488 | # B x C x T 489 | x = torch.repeat_interleave(x, dur_out.view(-1), dim=2) 490 | 491 | if self.f0: 492 | if self.f0_quant_embed: 493 | kwargs["f0"] = self.f0_quant_embed(kwargs["f0"].long()).transpose(1, 2) 494 | else: 495 | kwargs["f0"] = kwargs["f0"].unsqueeze(1) 496 | 497 | if x.shape[-1] < kwargs["f0"].shape[-1]: 498 | x = self._upsample(x, kwargs["f0"].shape[-1]) 499 | elif x.shape[-1] > kwargs["f0"].shape[-1]: 500 | kwargs["f0"] = self._upsample(kwargs["f0"], x.shape[-1]) 501 | x = torch.cat([x, kwargs["f0"]], dim=1) 502 | 503 | if self.f0_dict is not None: 504 | f0 = self.f0_dict(kwargs["f0_code"]).transpose(1, 2) # B, C, T 505 | if dur_out is not None: 506 | f0 = torch.repeat_interleave(f0, dur_out.view(-1), dim=2) 507 | x = torch.cat([x, f0], dim=1) # B, 2C, T 508 | 509 | if self.style_dict is not None: 510 | style = self.style_dict(kwargs["style_code"]).transpose(1, 2) # B, C, T 511 | if dur_out is not None: 512 | style = torch.repeat_interleave(style, dur_out.view(-1), dim=2) 513 | x = torch.cat([x, style], dim=1) # B, 2C, T 514 | 515 | if self.multispkr: 516 | assert ( 517 | "spkr" in kwargs 518 | ), 'require "spkr" input for multispeaker CodeHiFiGAN vocoder' 519 | spkr = self.spkr(kwargs["spkr"]).transpose(1, 2) 520 | spkr = self._upsample(spkr, x.shape[-1]) 521 | x = torch.cat([x, spkr], dim=1) 522 | 523 | if self.multistyle: 524 | assert ( 525 | "style" in kwargs 526 | ), 'require "style" input for multispeaker CodeHiFiGAN vocoder' 527 | style = self.style(kwargs["style"]).transpose(1, 2) 528 | style = self._upsample(style, x.shape[-1]) 529 | x = torch.cat([x, style], dim=1) 530 | 531 | for k, feat in kwargs.items(): 532 | if k in [ 533 | "spkr", 534 | "code", 535 | "f0", 536 | "dur_prediction", 537 | "style", 538 | "f0_code", 539 | "style_code", 540 | ]: 541 | continue 542 | 543 | feat = self._upsample(feat, x.shape[-1]) 544 | x = torch.cat([x, feat], dim=1) 545 | 546 | return super().forward(x) 547 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/hubert/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 
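# Editorial usage sketch (not part of the original file): spiritlm_hubert() below wires the
# 25 Hz mHuBERT encoder to its layer-11 linear quantizer, both loaded from
# SPIRITLM_CHECKPOINTS_DIR/speech_tokenizer/hubert_25hz. Once the checkpoints are in place
# (the audio path here is a placeholder), tokenizing a recording might look like:
#
#     from spiritlm.speech_tokenizer.hubert import spiritlm_hubert
#     tokenizer = spiritlm_hubert()
#     units = tokenizer("path/to/audio.wav")   # 1-D LongTensor of unit ids, roughly 25 per second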
6 | 7 | from pathlib import Path 8 | import os 9 | 10 | import torch 11 | 12 | from .hubert_tokenizer import HubertTokenizer 13 | 14 | # Get the base checkpoints directory from environment variable or use the default base path 15 | base_checkpoints_dir = Path(os.getenv("SPIRITLM_CHECKPOINTS_DIR", Path(__file__).parents[3] / "checkpoints")) 16 | 17 | # Append 'speech_tokenizer' to the base path 18 | CHECKPOINT_DIR = base_checkpoints_dir / "speech_tokenizer" 19 | 20 | CURRENT_DEVICE = ( 21 | torch.device(torch.cuda.current_device()) 22 | if torch.cuda.is_available() 23 | else "mps" if torch.backends.mps.is_available() else "cpu" 24 | ) 25 | 26 | 27 | def spiritlm_hubert(): 28 | return HubertTokenizer( 29 | hubert_ckpt=CHECKPOINT_DIR / "hubert_25hz/mhubert_base_25hz.pt", 30 | hubert_layer=11, 31 | quantizer_ckpt=CHECKPOINT_DIR / "hubert_25hz/L11_quantizer_500.pt", 32 | is_linear_quantizer=True, 33 | ).to(CURRENT_DEVICE) 34 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/hubert/hubert_model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | from .hubert_model import * 8 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/hubert/hubert_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 
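# Editorial note (not part of the original file): the tokenizer below streams audio through
# mHuBERT in chunks of at most 100 seconds, extracts the configured intermediate layer, and
# quantizes it. With the 25 Hz checkpoint the convolutional front-end has a total stride of
# 640 samples, so frame_rate = expected_sample_rate / code_hop_size = 16000 / 640 = 25 units
# per second of input audio.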
6 | 7 | import torch 8 | import torchaudio 9 | from torch import nn 10 | 11 | from .hubert_model import load_hubert_model 12 | from .quantizer_model import load_quantizer_model 13 | 14 | 15 | class HubertTokenizer(nn.Module): 16 | def __init__( 17 | self, 18 | hubert_ckpt, 19 | hubert_layer, 20 | quantizer_ckpt, 21 | is_linear_quantizer=True, 22 | min_chunk=400, 23 | max_chunk=100 * 16_000, 24 | ): 25 | super().__init__() 26 | 27 | # hubert model 28 | self.hubert_ckpt = str(hubert_ckpt) 29 | self.hubert_layer = hubert_layer 30 | self.hubert_model = None 31 | self.should_normalize = False 32 | self.min_chunk = min_chunk 33 | self.max_chunk = max_chunk 34 | 35 | # quantizer model 36 | self.quantizer_ckpt = str(quantizer_ckpt) 37 | self.is_linear_quantizer = is_linear_quantizer 38 | self.quantizer_model = None 39 | 40 | # this is useful for determining the device 41 | self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) 42 | self.load_models() 43 | 44 | @torch.no_grad() # otherwise some non-leaf nodes appear which breaks serialization 45 | def load_models(self): 46 | # Load hubert model 47 | hubert_model, model_cfg, task_cfg = load_hubert_model(self.hubert_ckpt) 48 | self.hubert_task_cfg = task_cfg 49 | self.hubert_model_cfg = model_cfg 50 | self.hubert_model = hubert_model 51 | self.hubert_model.to(self.device) 52 | self.hubert_model.eval() 53 | for parameter in self.hubert_model.parameters(): 54 | parameter.requires_grad_(False) 55 | self.should_normalize = task_cfg.normalize 56 | 57 | # Load quantizer model 58 | self.quantizer_model = load_quantizer_model( 59 | self.quantizer_ckpt, is_linear_quantizer=self.is_linear_quantizer 60 | ) 61 | self.quantizer_model.to(self.device) 62 | self.quantizer_model.eval() 63 | 64 | @property 65 | def device(self): 66 | return self._float_tensor.device 67 | 68 | @property 69 | def code_hop_size(self) -> int: 70 | hop_size = 1 71 | for dim, kernel, stride in eval(self.hubert_model_cfg.conv_feature_layers): 72 | hop_size *= stride 73 | return hop_size # 320 for 50hz model and 640 for 25hz model 74 | 75 | @property 76 | def frame_rate(self) -> int: 77 | return self.expected_sample_rate / self.code_hop_size # 50 or 25 78 | 79 | @property 80 | def n_units(self) -> int: 81 | return self.kmeans_model.K 82 | 83 | @property 84 | def expected_sample_rate(self) -> int: 85 | return self.hubert_task_cfg.sample_rate # 16_000 86 | 87 | def load_audio(self, path): 88 | wav, sr = torchaudio.load(path) 89 | if sr != self.expected_sample_rate: 90 | wav = torchaudio.functional.resample( 91 | wav, orig_freq=sr, new_freq=self.expected_sample_rate 92 | ) 93 | return wav 94 | 95 | @torch.inference_mode() 96 | def forward(self, x, separate_channels=False, dense=False): 97 | if isinstance(x, str): 98 | x = self.load_audio(x) 99 | i_ndim = x.dim() 100 | if i_ndim == 2: 101 | x = x.unsqueeze(0) 102 | elif i_ndim == 1: 103 | x = x.view(1, 1, -1) 104 | 105 | # x should expect a shape [B, C, T], where C is number of channels 106 | assert len(x.shape) == 3 107 | feats = self.get_dense_features(x) # [B, T_enc] 108 | 109 | if dense: 110 | return feats 111 | 112 | tokens = self.quantizer_model(feats) # [B, T_enc] 113 | 114 | if i_ndim == 3: 115 | tokens = tokens.view(x.shape[0], 1, -1) 116 | else: 117 | tokens = tokens.squeeze(0) 118 | 119 | if not separate_channels: 120 | return tokens 121 | 122 | @torch.inference_mode() 123 | def get_dense_features(self, x, separate_channels=False): 124 | x = x.to(self.device) 125 | 126 | assert separate_channels == False, "Not 
supported yet" # TODO: Fix this 127 | 128 | if not separate_channels: 129 | x = x.mean(1) # [B, T] 130 | 131 | if self.should_normalize: 132 | x = torch.cat([nn.functional.layer_norm(item, item.shape) for item in x]) 133 | 134 | feat = [] 135 | for start in range(0, x.size(1), self.max_chunk): 136 | x_chunk = x[:, start : start + self.max_chunk] 137 | if x_chunk.size(1) < self.min_chunk: 138 | continue 139 | feat_chunk, _ = self.hubert_model.extract_features( 140 | source=x_chunk, 141 | padding_mask=None, 142 | mask=False, 143 | output_layer=self.hubert_layer, 144 | ) 145 | feat.append(feat_chunk) 146 | 147 | return torch.cat(feat, 1) 148 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/hubert/quantizer_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | 9 | import torch 10 | from torch import nn 11 | 12 | _logger = logging.getLogger(__name__) 13 | 14 | 15 | class LinearQuantizerModel(nn.Module): 16 | def __init__(self, ckpt_path): 17 | super().__init__() 18 | state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) 19 | 20 | self.vocab_size = state_dict["model_cfg"]["vocab_size"] 21 | dim = state_dict["model_cfg"]["dim"] 22 | upstream_dim = state_dict["model_cfg"]["upstream_dim"] 23 | 24 | out_dim = self.vocab_size + 1 # vocab_size + 1 for blank in CTC 25 | mid_dim = upstream_dim - out_dim 26 | 27 | self.encoder = nn.Sequential( 28 | *[ 29 | nn.Linear(dim, dim - mid_dim // 4), 30 | nn.LeakyReLU(), 31 | nn.Linear(dim - mid_dim // 4, dim - mid_dim // 2), 32 | nn.LeakyReLU(), 33 | nn.Linear(dim - mid_dim // 2, self.vocab_size + 1), 34 | ] 35 | ) 36 | 37 | self.encoder.load_state_dict(state_dict["model_weight"]) 38 | 39 | def forward(self, x): 40 | logits = self.encoder(x) 41 | logits = torch.nn.functional.log_softmax(logits, dim=-1) 42 | code = logits.argmax(dim=-1) 43 | 44 | # post-process units: replace BLANK with most-left non-BLANK units 45 | non_stop_counter = 0 46 | while (code == self.vocab_size).any(): 47 | non_stop_counter += 1 48 | code[code == self.vocab_size] = torch.roll(code, 1)[code == self.vocab_size] 49 | if non_stop_counter == 10000: 50 | break 51 | 52 | return code 53 | 54 | 55 | class KmeansModel(nn.Module): 56 | def __init__(self, km_path): 57 | super().__init__() 58 | states = torch.load(km_path, map_location="cpu", weights_only=True) 59 | assert ( 60 | "cluster_centers" in states and "n_clusters" in states 61 | ), "Not a valid kmeans checkpoint." 
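# (editorial note, not in the original source) the centroids are stored transposed as C with
# shape [d_feats, K] alongside their squared norms, so forward() can score every frame against
# every centroid with a single matmul via ||x - c||^2 = ||x||^2 - 2 x·C + ||c||^2 and then
# take the argmin.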
62 | C_np = states["cluster_centers"].transpose() # [d_feats, K] 63 | Cnorm_np = (C_np**2).sum(0, keepdims=True) # [K,] 64 | self.K = states["n_clusters"] 65 | assert self.K == C_np.shape[-1] 66 | 67 | self.C = nn.Parameter(torch.from_numpy(C_np), requires_grad=False) 68 | self.Cnorm = nn.Parameter(torch.from_numpy(Cnorm_np), requires_grad=False) 69 | 70 | def forward(self, x): 71 | batched = False 72 | if len(x.shape) == 3: # [B, T, d] 73 | batched = True 74 | B, T, d = x.shape 75 | x = x.view(-1, d) 76 | 77 | # x: [T, d]; C: [d, K]; Cnorm: [K,] 78 | dist = x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm 79 | assigned_clusters = dist.argmin(dim=1) # [T,] 80 | 81 | if batched: 82 | assigned_clusters = assigned_clusters.view(B, T) 83 | 84 | return assigned_clusters 85 | 86 | 87 | def load_quantizer_model(ckpt_path, is_linear_quantizer): 88 | if is_linear_quantizer: 89 | model = LinearQuantizerModel(ckpt_path) 90 | _logger.info(f"Loaded LinearQuantizer from '{ckpt_path}'") 91 | else: 92 | model = KmeansModel(ckpt_path) 93 | _logger.info(f"Loaded KmeansModel from '{ckpt_path}'") 94 | return model 95 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/spiritlm_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | import random 10 | from typing import Dict, List 11 | 12 | import torchaudio 13 | 14 | MOST_COMMON_STYLES = [71, 68, 98] 15 | 16 | 17 | _logger = logging.getLogger(__name__) 18 | 19 | 20 | def _toks_positions(toks: List[str], rate: float, dedup: bool): 21 | prev_tok = None 22 | res = [] 23 | for i, tok in enumerate(toks): 24 | if (not dedup) or (prev_tok is None or tok != prev_tok): 25 | res += [(tok, i / rate)] 26 | prev_tok = tok 27 | return res 28 | 29 | 30 | def units_to_string( 31 | units: Dict[str, str], 32 | has_pitch=False, 33 | has_style=False, 34 | hubert_rate=24.99, 35 | hubert_dedup=True, 36 | hubert_key="hubert", 37 | pitch_rate=12.5, 38 | pitch_dedup=True, 39 | pitch_key="pitch", 40 | style_rate=1, 41 | style_dedup=False, 42 | style_key="style", 43 | ) -> str: 44 | """ 45 | Example: 46 | - input (units): 47 | { 48 | 'hubert': '78 42 81 159 316 259', 49 | 'pitch': '13 13 13 13 13 3', 50 | 'style': '81 81 81 81 81 81', 51 | } 52 | - output: 53 | '[St81][Hu78][Pi13][Hu42][Hu81][Hu159][Hu316][Pi3][Hu259]' 54 | """ 55 | 56 | combine_toks = [] 57 | 58 | if has_style: 59 | combine_toks += _toks_positions( 60 | [f"[St{i}]" for i in units[style_key].split()], style_rate, style_dedup 61 | ) 62 | if has_pitch: 63 | combine_toks += _toks_positions( 64 | [f"[Pi{i}]" for i in units[pitch_key].split()], pitch_rate, pitch_dedup 65 | ) 66 | combine_toks += _toks_positions( 67 | [f"[Hu{i}]" for i in units[hubert_key].split()], hubert_rate, hubert_dedup 68 | ) 69 | combine_toks = [tok_pos[0] for tok_pos in sorted(combine_toks, key=lambda x: x[1])] 70 | return "".join(combine_toks) 71 | 72 | 73 | def get_random_most_common_style() -> int: 74 | return random.choice(MOST_COMMON_STYLES) 75 | 76 | 77 | def string_to_units( 78 | gen, 79 | hubert_key="hubert", 80 | pitch_key="pitch", 81 | style_key="style", 82 | duplicate_hubert_for_multiple_pitch=False, 83 | ): 84 | """ 85 | Convert from 
tokenized string to dictionary of units. 86 | The units are 'pre-duplicated' to match the number of hubert units. 87 | Examples 88 | - input: 89 | '[St81][Hu78][Pi13][Hu42][Hu81][Hu159][Hu316][Pi3][Hu259]' 90 | - output: 91 | { 92 | 'hubert': '78 42 81 159 316 259', 93 | 'pitch': '13 13 13 13 13 3', 94 | 'style': '81 81 81 81 81 81', 95 | } 96 | """ 97 | prev_hubert = None 98 | first_hubert = None 99 | prev_pitch = None 100 | first_pitch = None 101 | prev_style = None 102 | first_style = None 103 | prev_is_pitch = False # If this is True, add prev_hubert to the codes 104 | hubert = [] 105 | pitch = [] 106 | style = [] 107 | for item in gen.split("["): 108 | if item and len(item) > 2: 109 | if item.startswith("Hu") and item[2].isdigit(): 110 | hubert += [item[2:-1]] 111 | pitch += [prev_pitch] 112 | style += [prev_style] 113 | prev_is_pitch = False 114 | prev_hubert = item[2:-1] 115 | if first_hubert is None: 116 | first_hubert = item[2:-1] 117 | elif item.startswith("St") and item[2].isdigit(): 118 | if prev_style is None: 119 | first_style = item[2:-1] 120 | prev_style = item[2:-1] 121 | elif item.startswith("Pi") and item[2].isdigit(): 122 | if duplicate_hubert_for_multiple_pitch and prev_is_pitch: 123 | hubert += [prev_hubert] 124 | pitch += [item[2:-1]] 125 | style += [prev_style] 126 | if prev_pitch is None: 127 | first_pitch = item[2:-1] 128 | prev_pitch = item[2:-1] 129 | prev_is_pitch = True 130 | if first_pitch is not None and first_style is None: 131 | # in rare case, style is not present, we select randomly a common style token to make decoding work 132 | first_style = str(get_random_most_common_style()) 133 | for i in range(len(hubert)): 134 | if hubert[i] is None: 135 | hubert[i] = first_hubert 136 | if style[i] is None: 137 | style[i] = first_style 138 | if pitch[i] is None: 139 | pitch[i] = first_pitch 140 | units = {hubert_key: " ".join(hubert)} 141 | if first_pitch is not None: 142 | units[pitch_key] = " ".join(pitch) 143 | if first_style is not None: 144 | units[style_key] = " ".join(style) 145 | return units 146 | 147 | 148 | class SpiritLMTokenizer: 149 | def __init__( 150 | self, 151 | hubert_model, 152 | pitch_model=None, 153 | style_model=None, 154 | hifigan_model=None, 155 | hubert_rate=24.99, 156 | hubert_dedup=True, 157 | hubert_key="hubert", 158 | pitch_rate=12.5, 159 | pitch_dedup=True, 160 | pitch_key="pitch", 161 | style_rate=1, 162 | style_dedup=False, 163 | style_key="style", 164 | expected_sample_rate=16_000, 165 | max_wav_chunk=100 * 16_000, 166 | min_wav_chunk=1280, # 400 is minimum for hubert, 1280 (80ms) is minimum for pitch, so let's take 1280 167 | ): 168 | super().__init__() 169 | 170 | self.hubert_model = hubert_model 171 | self.pitch_model = pitch_model 172 | self.style_model = style_model 173 | self.hifigan_model = hifigan_model 174 | 175 | self.hubert_rate = hubert_rate 176 | self.hubert_dedup = hubert_dedup 177 | self.hubert_key = hubert_key 178 | 179 | self.speech_token = "[Speech]" 180 | self.pitch_key = None 181 | self.style_key = None 182 | if pitch_model is not None: 183 | self.pitch_rate = pitch_rate 184 | self.pitch_dedup = pitch_dedup 185 | self.pitch_key = pitch_key 186 | if style_model is not None: 187 | self.style_rate = style_rate 188 | self.style_dedup = style_dedup 189 | self.style_key = style_key 190 | 191 | self.expected_sample_rate = expected_sample_rate 192 | self.max_wav_chunk = max_wav_chunk 193 | self.min_wav_chunk = min_wav_chunk 194 | 195 | def load_audio(self, path): 196 | wav, sr = torchaudio.load(path) 197 | if sr != 
self.expected_sample_rate: 198 | wav = torchaudio.functional.resample( 199 | wav, orig_freq=sr, new_freq=self.expected_sample_rate 200 | ) 201 | return wav 202 | 203 | def encode_units(self, audio, channel_id=None): 204 | """ 205 | Get the speech units in dictionary format, e.g. 206 | { 207 | 'audio': 'path/to/audio.wav', 208 | 'hubert': '1 1 2 2 3', 209 | 'pitch': '15 15 20', 210 | 'style': '7', 211 | } 212 | The audio can be the path to audio file or an array. 213 | For stereo audio file, channel_id can be set (0 or 1). 214 | """ 215 | units = {} 216 | 217 | if isinstance(audio, str): 218 | units["audio"] = os.path.abspath(audio) 219 | audio = self.load_audio(audio) 220 | audio = audio.squeeze() 221 | if len(audio.shape) == 2: 222 | assert ( 223 | audio.shape[0] == 2 224 | ), f"expected a stereo wav of shape (2,x), found {audio.shape}" 225 | if channel_id is None: 226 | _logger.warning( 227 | "Found stereo audio input, averaging audio from 2 channels. If you want to extract" 228 | "only one channel, set channel_id to 0 or 1" 229 | ) 230 | audio = audio.mean(0) 231 | else: 232 | audio = audio[channel_id] 233 | assert len(audio.shape) == 1, audio.shape 234 | 235 | hubert_units = [] 236 | pitch_units = [] 237 | style_units = [] 238 | for start in range(0, len(audio), self.max_wav_chunk): 239 | audio_chunk = audio[start : start + self.max_wav_chunk] 240 | if len(audio_chunk) < self.min_wav_chunk: 241 | continue 242 | hubert_units.extend([str(i.item()) for i in self.hubert_model(audio_chunk)]) 243 | if self.pitch_model is not None: 244 | pitch_units.extend( 245 | [str(i.item()) for i in self.pitch_model(audio_chunk)] 246 | ) 247 | if self.style_model is not None: 248 | style_units.extend( 249 | [str(i.item()) for i in self.style_model(audio_chunk)] 250 | ) 251 | 252 | units[self.hubert_key] = " ".join(hubert_units) 253 | if self.pitch_model is not None: 254 | units[self.pitch_key] = " ".join(pitch_units) 255 | if self.style_model is not None: 256 | units[self.style_key] = " ".join(style_units) 257 | return units 258 | 259 | def units2string(self, units): 260 | """ 261 | Convert from dictionary of units to tokenized string. 262 | The units are (optionally deduped) sorted by time steps and interleaved 263 | """ 264 | has_pitch = self.pitch_model is not None 265 | has_style = self.style_model is not None 266 | return units_to_string( 267 | units=units, 268 | has_pitch=has_pitch, 269 | has_style=has_style, 270 | hubert_rate=self.hubert_rate, 271 | hubert_dedup=self.hubert_dedup, 272 | hubert_key=self.hubert_key, 273 | pitch_rate=self.pitch_rate if has_pitch else None, 274 | pitch_dedup=self.pitch_dedup if has_pitch else None, 275 | pitch_key=self.pitch_key if has_pitch else None, 276 | style_rate=self.style_rate if has_style else None, 277 | style_dedup=self.style_dedup if has_style else None, 278 | style_key=self.style_key if has_style else None, 279 | ) 280 | 281 | def encode_string(self, audio): 282 | """ 283 | Tokenize the audio into string format, e.g. 284 | '[St7][Pi15][Hu1][Hu2][Pi20][Hu3]' 285 | """ 286 | units = self.encode_units(audio) 287 | return self.units2string(units) 288 | 289 | def __call__(self, audio): 290 | """ 291 | Default call method 292 | """ 293 | return self.encode_string(audio) 294 | 295 | def string2units(self, gen, duplicate_hubert_for_multiple_pitch=False): 296 | """ 297 | Convert from tokenized string to dictionary of units. 298 | The units are 'pre-duplicated' to match the number of hubert units. 
299 | Examples 300 | - input: 301 | '[St81][Hu78][Pi13][Hu42][Hu81][Hu159][Hu316][Pi3][Hu259]' 302 | - output: 303 | { 304 | 'hubert': '78 42 81 159 316 259', 305 | 'pitch': '13 13 13 13 13 3', 306 | 'style': '81 81 81 81 81 81', 307 | } 308 | """ 309 | return string_to_units( 310 | gen, 311 | hubert_key=self.hubert_key, 312 | pitch_key=self.pitch_key if self.pitch_key else "pitch", 313 | style_key=self.style_key if self.style_key else "style", 314 | duplicate_hubert_for_multiple_pitch=duplicate_hubert_for_multiple_pitch, 315 | ) 316 | 317 | def decode(self, code, speaker_id=2, dur_pred=True): 318 | """ 319 | code can be under text form ([Hu1][Hu2]) or units form ({'hubert': '1 2'}) 320 | """ 321 | 322 | assert self.hifigan_model is not None 323 | 324 | if isinstance(code, str): 325 | units = self.string2units(code) 326 | else: 327 | units = code 328 | 329 | # if units['hubert'] doesn't have the same number as units['f0'] 330 | # then likely this is resynthesis task, and we'll set dur_pred=False 331 | if ( 332 | self.pitch_key 333 | and self.pitch_key in units 334 | and len(units[self.pitch_key].split()) 335 | != len(units[self.hubert_key].split()) 336 | ): 337 | dur_pred = False 338 | 339 | wav = ( 340 | self.hifigan_model( 341 | code=units[self.hubert_key], 342 | f0_code=( 343 | units[self.pitch_key] 344 | if self.pitch_key and self.pitch_key in units 345 | else None 346 | ), 347 | style_code=( 348 | units[self.style_key] 349 | if self.style_key and self.style_key in units 350 | else None 351 | ), 352 | dur_pred=dur_pred, 353 | speaker_id=speaker_id, 354 | not_dedup_code=True, 355 | ) 356 | .detach() 357 | .cpu() 358 | .numpy() 359 | ) 360 | 361 | return wav 362 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/style_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import logging 9 | import os 10 | from pathlib import Path 11 | 12 | import torch 13 | 14 | from .w2v2_encoder import Wav2Vec2StyleEncoder 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | # Get the base checkpoints directory from environment variable or use the default base path 19 | base_checkpoints_dir = Path(os.getenv("SPIRITLM_CHECKPOINTS_DIR", Path(__file__).parents[3] / "checkpoints")) 20 | 21 | # Append 'speech_tokenizer' to the base path 22 | CHECKPOINT_DIR = base_checkpoints_dir / "speech_tokenizer" 23 | 24 | CURRENT_DEVICE = ( 25 | torch.device(torch.cuda.current_device()) 26 | if torch.cuda.is_available() 27 | else "mps" if torch.backends.mps.is_available() else "cpu" 28 | ) 29 | 30 | 31 | def spiritlm_expressive_style_encoder_w2v2() -> Wav2Vec2StyleEncoder: 32 | STYLE_ENCODER_CKPT_PATH = CHECKPOINT_DIR / "style_encoder_w2v2" 33 | model = Wav2Vec2StyleEncoder.from_pretrained( 34 | pretrained_model_name_or_path=STYLE_ENCODER_CKPT_PATH 35 | ).to(CURRENT_DEVICE) 36 | _logger.info(f"Style encoder loaded from {str(STYLE_ENCODER_CKPT_PATH)}") 37 | return model 38 | -------------------------------------------------------------------------------- /spiritlm/speech_tokenizer/style_encoder/w2v2_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Union 8 | 9 | import torch 10 | import torchaudio 11 | from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification 12 | 13 | 14 | class Wav2Vec2StyleEncoder(Wav2Vec2ForSequenceClassification): 15 | def __init__(self, config, pool_size: int = 50): 16 | super().__init__(config) 17 | self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( 18 | "facebook/wav2vec2-base" 19 | ) 20 | self.pool_size = pool_size 21 | 22 | # this is useful for determining the device 23 | self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float)) 24 | 25 | @property 26 | def device(self): 27 | return self._float_tensor.device 28 | 29 | @torch.no_grad() 30 | def forward(self, wavs: Union[torch.Tensor, str]) -> torch.Tensor: 31 | if isinstance(wavs, str): 32 | # TODO: resampling if applicable 33 | wavs = torchaudio.load(wavs)[0].squeeze(0) 34 | # TODO: handle list of strs 35 | inputs = self.feature_extractor( 36 | wavs, sampling_rate=16_000, return_tensors="pt" 37 | ).input_values 38 | outputs = self.wav2vec2(inputs.to(self.device)) 39 | hidden_states = outputs[0] 40 | hidden_states = self.projector(hidden_states) 41 | chunk_size = self.pool_size 42 | batch_size, sequence_length, hidden_size = hidden_states.shape 43 | pooled_output = [] 44 | for i in range(0, sequence_length, chunk_size): 45 | chunk = hidden_states[:, i : i + chunk_size, :] 46 | pooled_output.append(chunk.mean(dim=1)) 47 | pooled_output = torch.cat( 48 | pooled_output, dim=1 49 | ) # Concatenate the chunks along the desired dimension 50 | pooled_output = pooled_output.view( 51 | batch_size, -1, hidden_size 52 | ) # Reshape back to the original shape 53 | logits = self.classifier(pooled_output) 54 | lprobs = torch.nn.functional.log_softmax(logits, dim=-1) 55 | pred = torch.argmax(lprobs, dim=-1).squeeze(0) 56 | return pred 57 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/test_spirit_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 
6 | 7 | from unittest.mock import Mock, patch 8 | 9 | import pytest 10 | from spiritlm.model.spiritlm_model import Spiritlm 11 | from spiritlm.model.utils import ( 12 | does_end_with_speech_token, 13 | does_start_with_speech_token, 14 | find_prompt_last_speech_start_position, 15 | ) 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "content,expected", 20 | [ 21 | ( 22 | "abc[Speech][St1][Pi234][Hu123][Hu45][Text]hello world[", 23 | [("abc", "t"), ("[St1][Pi234][Hu123][Hu45]", "s"), ("hello world[", "t")], 24 | ), 25 | ( 26 | "[St1][Pi234][Hu123][Hu45]", 27 | [("[St1][Pi234][Hu123][Hu45]", "s")], 28 | ), 29 | ( 30 | "abc", 31 | [("abc", "t")], 32 | ), 33 | ( 34 | "abc[]", 35 | [("abc[]", "t")], 36 | ), 37 | ( 38 | "[St1][Pi234][Hu123][Hu45][Text][abc", 39 | [("[St1][Pi234][Hu123][Hu45]", "s"), ("[abc", "t")], 40 | ), 41 | ( 42 | "abc[Text]def", 43 | [("abcdef", "t")], 44 | ), 45 | ], 46 | ) 47 | def test_parse_speech_and_text(content, expected): 48 | with patch( 49 | "spiritlm.model.spiritlm_model.Spiritlm.__init__", Mock(return_value=None) 50 | ): 51 | mock_spiritlm_model = Spiritlm("spirit-lm-base-7b") 52 | mock_spiritlm_model.speech_prompt_prefix = "[Speech]" 53 | assert mock_spiritlm_model._parse_speech_and_text(content) == expected 54 | 55 | 56 | @pytest.mark.parametrize( 57 | "content,expected", 58 | [ 59 | ( 60 | "[Hu338][Text] and they went out together[Speech][Hu431][Pi0][Hu457][Hu79][Pi11][Hu258][Hu85][Hu28][Hu50][Text] and mrs johnson shoes except in mourning[Speech][Pi59][Hu32][Pi20][Hu453][Pi35][Pi26][Hu166]", 61 | [ 62 | ("[Hu338]", "s"), 63 | (" and they went out together", "t"), 64 | ("[Hu431][Pi0][Hu457][Hu79][Pi11][Hu258][Hu85][Hu28][Hu50]", "s"), 65 | (" and mrs johnson shoes except in mourning", "t"), 66 | ("[Pi59][Hu32][Pi20][Hu453][Pi35][Pi26][Hu166]", "s"), 67 | ], 68 | ) 69 | ], 70 | ) 71 | def test_parse_speech_and_text_with_expressive_tokens(content, expected): 72 | with patch( 73 | "spiritlm.model.spiritlm_model.Spiritlm.__init__", Mock(return_value=None) 74 | ): 75 | mock_spiritlm_model = Spiritlm("spirit-lm-base-7b") 76 | mock_spiritlm_model.speech_prompt_prefix = "[Speech]" 77 | print(f"content: {content}") 78 | print(f"expected: {expected}") 79 | assert mock_spiritlm_model._parse_speech_and_text(content) == expected 80 | 81 | 82 | @pytest.mark.parametrize( 83 | "encoded_string,expected", 84 | [ 85 | ( 86 | "]]", 87 | False, 88 | ), 89 | ( 90 | "[]", 91 | False, 92 | ), 93 | ( 94 | "[Hu100]", 95 | True, 96 | ), 97 | ("abc[]", False), 98 | ( 99 | "[St1][Pi234][Hu123][Hu45][Text][abc]", 100 | False, 101 | ), 102 | ( 103 | "abc[Text]def", 104 | False, 105 | ), 106 | ( 107 | "[Pi9]", 108 | True, 109 | ), 110 | ( 111 | "[St0]", 112 | True, 113 | ), 114 | ], 115 | ) 116 | def test_does_prompt_end_by_speech(encoded_string, expected): 117 | assert does_end_with_speech_token(encoded_string) == expected 118 | 119 | 120 | @pytest.mark.parametrize( 121 | "encoded_string,expected", 122 | [ 123 | ( 124 | "abc[Hu123][Hu456][Pi23][St2]", 125 | 3, 126 | ), 127 | ( 128 | "[Hu123]abc[Hu123][Hu456][Pi23][St2]", 129 | 10, 130 | ), 131 | ( 132 | "[Hu123][Hu456][Pi23][St2]", 133 | 0, 134 | ), 135 | ( 136 | "abc", 137 | None, 138 | ), 139 | ( 140 | 
"[Speech][St71][Pi39][Hu99][Hu49][Pi57][Hu38][Hu149][Pi48][Hu71][Hu423][Hu427][Pi56][Hu492][Hu288][Pi44][Hu315][Hu153][Pi42][Hu389][Pi59][Hu497][Hu412][Pi51][Hu247][Hu354][Pi44][Hu7][Hu96][Pi43][Hu452][Pi0][Hu176][Hu266][Pi54][St71][Hu77][Pi13][Hu248][Hu336][Pi39][Hu211][Pi25][Hu166][Hu65][Pi58][Hu94][Hu224][Pi26][Hu148][Pi44][Hu492][Hu191][Pi26][Hu440][Pi13][Hu41][Pi20][Hu457][Hu79][Pi46][Hu382][Hu451][Pi26][Hu332][Hu216][Hu114][Hu340][St71][Pi40][Hu478][Hu74][Pi26][Hu79][Hu370][Pi56][Hu272][Hu370][Pi51][Hu53][Pi14][Hu477][Hu65][Pi46][Hu171][Hu60][Pi41][Hu258][Hu111][Pi40][Hu338][Hu23][Pi39][Hu338][Hu23][Hu338][St71][Pi57][Hu7][Hu338][Hu149][Pi59][Hu406][Hu7][Hu361][Hu99][Pi20][Hu209][Hu479][Pi35][Hu50][St71][Hu7][Hu149][Pi55][Hu35][Pi13][Hu130][Pi3][Hu169][Pi52][Hu72][Pi9][Hu434][Hu119][Hu272][Hu4][Pi20][Hu249][Hu245][Pi57][Hu433][Pi56][Hu159][Hu294][Hu139][Hu359][Hu343][Hu269][Hu302][St71][Hu226][Pi32][Hu370][Hu216][Pi39][Hu459][Hu424][Pi57][Hu226][Pi46][Hu382][Hu7][Pi27][Hu58][Hu138][Pi20][Hu428][Hu397][Pi44][Hu350][Pi32][Hu306][Pi59][Hu84][Hu11][Hu171][Pi42][Hu60][Pi48][Hu314][Hu227][St71][Hu355][Pi56][Hu9][Hu58][Pi44][Hu138][Hu226][Pi25][Hu370][Hu272][Pi56][Hu382][Hu334][Pi26][Hu330][Hu176][Pi56][Hu307][Pi46][Hu145][Hu248][Pi56][Hu493][Hu64][Pi40][Hu44][Hu388][Pi39][Hu7][Hu111][Pi59][St71][Hu23][Hu481][Pi13][Hu149][Pi15][Hu80][Hu70][Pi47][Hu431][Hu457][Pi13][Hu79][Pi27][Hu249][Pi55][Hu245][Pi54][Hu433][Pi36][Hu316][Pi53][Hu180][Pi3][Hu458][Pi26][Hu86][St71][Pi43][Hu225][Pi49][Hu103][Hu60][Pi3][Hu96][Hu119][Pi39][Hu129][Pi41][Hu356][Hu218][Pi14][Hu4][Hu259][Pi56][Hu392][Pi46][Hu490][Hu75][Pi14][Hu488][Hu166][Pi46][Hu65][Hu171][Pi40][Hu60][Hu7][Hu54][Pi39][Hu85][St83][Pi40][Hu361]", 141 | 8, 142 | ), 143 | ], 144 | ) 145 | def test_find_prompt_last_speech_start_position(encoded_string, expected): 146 | assert find_prompt_last_speech_start_position(encoded_string) == expected 147 | 148 | 149 | @pytest.mark.parametrize( 150 | "encoded_string,expected", 151 | [ 152 | ( 153 | "[[", 154 | False, 155 | ), 156 | ( 157 | "[]", 158 | False, 159 | ), 160 | ( 161 | "[Hu100]", 162 | True, 163 | ), 164 | ("abc[]", False), 165 | ( 166 | "[St1][Pi234][Hu123][Hu45][Text][abc]", 167 | True, 168 | ), 169 | ( 170 | "abc[Text]def", 171 | False, 172 | ), 173 | ( 174 | "[Pi9]", 175 | True, 176 | ), 177 | ( 178 | "[St0]", 179 | True, 180 | ), 181 | ], 182 | ) 183 | def test_does_start_with_speech_token(encoded_string, expected): 184 | assert does_start_with_speech_token(encoded_string) == expected 185 | -------------------------------------------------------------------------------- /tests/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the FAIR Noncommercial Research License 5 | # found in the LICENSE file in the root directory of this source tree. 
6 | 7 | import pytest 8 | import torchaudio 9 | from spiritlm.speech_tokenizer import spiritlm_base, spiritlm_expressive 10 | 11 | 12 | @pytest.fixture 13 | def spiritlm_expressive_tokenizer(): 14 | return spiritlm_expressive() 15 | 16 | 17 | @pytest.fixture 18 | def spiritlm_base_tokenizer(): 19 | return spiritlm_base() 20 | 21 | 22 | def test_expressive_tokenizer_encode_units(spiritlm_expressive_tokenizer): 23 | audio = "examples/audio/7143-88743-0029.flac" 24 | units = spiritlm_expressive_tokenizer.encode_units(audio) 25 | expected = { 26 | "hubert": "99 49 38 149 149 71 423 427 492 288 315 153 153 389 497 412 247 354 7 96 452 452 176 266 266 77 248 336 336 211 166 65 94 224 224 148 492 191 440 440 41 41 457 79 382 451 332 216 114 340 478 74 79 370 272 370 370 53 477 65 171 60 258 111 111 111 111 338 338 23 23 338 23 338 338 338 7 338 338 149 406 7 361 361 361 99 99 99 99 99 99 99 209 209 209 209 209 479 50 50 7 149 149 35 35 130 130 169 169 72 434 119 272 4 249 245 245 433 159 294 139 359 343 269 302 226 370 216 459 424 424 226 382 7 58 138 428 397 350 350 306 306 306 84 11 171 171 60 314 227 227 355 9 58 138 226 370 272 382 334 330 176 176 307 145 248 493 64 44 388 7 111 111 111 111 23 23 481 149 149 80 70 431 457 79 79 249 249 245 245 245 433 433 316 316 180 458 458 458 86 86 225 103 60 96 119 119 129 356 218 4 259 259 392 490 75 488 166 65 171 60 7 54 54 85 85 361 361", 27 | "pitch": "39 39 39 48 56 40 42 39 51 40 43 54 3 35 39 25 58 26 44 40 13 20 46 41 26 40 26 56 41 46 46 41 41 40 40 40 39 39 57 59 59 59 59 59 59 59 59 20 20 20 35 35 13 3 9 6 0 20 57 56 56 56 56 59 44 57 41 59 42 51 59 57 59 59 39 39 46 56 58 41 41 40 39 39 39 59 59 59 15 27 13 55 13 27 35 36 3 53 3 26 43 53 54 39 25 14 41 46 46 46 46 41 41 41", 28 | "style": "71 71 71 71 71 71 71 71 71 83", 29 | } 30 | for token_key in ["hubert", "pitch", "style"]: 31 | assert ( 32 | expected[token_key] == units[token_key] 33 | ), f"{token_key} expected {expected[token_key]}, got {units[token_key]}" 34 | 35 | 36 | def test_expressive_tokenizer_encode_units_with_tensor_input( 37 | spiritlm_expressive_tokenizer, 38 | ): 39 | wav = torchaudio.load("examples/audio/7143-88743-0029.flac")[0].squeeze(0) 40 | units = spiritlm_expressive_tokenizer.encode_units(wav) 41 | expected = { 42 | "hubert": "99 49 38 149 149 71 423 427 492 288 315 153 153 389 497 412 247 354 7 96 452 452 176 266 266 77 248 336 336 211 166 65 94 224 224 148 492 191 440 440 41 41 457 79 382 451 332 216 114 340 478 74 79 370 272 370 370 53 477 65 171 60 258 111 111 111 111 338 338 23 23 338 23 338 338 338 7 338 338 149 406 7 361 361 361 99 99 99 99 99 99 99 209 209 209 209 209 479 50 50 7 149 149 35 35 130 130 169 169 72 434 119 272 4 249 245 245 433 159 294 139 359 343 269 302 226 370 216 459 424 424 226 382 7 58 138 428 397 350 350 306 306 306 84 11 171 171 60 314 227 227 355 9 58 138 226 370 272 382 334 330 176 176 307 145 248 493 64 44 388 7 111 111 111 111 23 23 481 149 149 80 70 431 457 79 79 249 249 245 245 245 433 433 316 316 180 458 458 458 86 86 225 103 60 96 119 119 129 356 218 4 259 259 392 490 75 488 166 65 171 60 7 54 54 85 85 361 361", 43 | "pitch": "39 39 39 48 56 40 42 39 51 40 43 54 3 35 39 25 58 26 44 40 13 20 46 41 26 40 26 56 41 46 46 41 41 40 40 40 39 39 57 59 59 59 59 59 59 59 59 20 20 20 35 35 13 3 9 6 0 20 57 56 56 56 56 59 44 57 41 59 42 51 59 57 59 59 39 39 46 56 58 41 41 40 39 39 39 59 59 59 15 27 13 55 13 27 35 36 3 53 3 26 43 53 54 39 25 14 41 46 46 46 46 41 41 41", 44 | "style": "71 71 71 71 71 71 71 71 71 83", 45 | } 46 | for token_key 
in ["hubert", "pitch", "style"]: 47 | assert ( 48 | expected[token_key] == units[token_key] 49 | ), f"{token_key} expected {expected[token_key]}, got {units[token_key]}" 50 | 51 | 52 | def test_base_tokenizer_encode_units(spiritlm_base_tokenizer): 53 | audio = "examples/audio/7143-88743-0029.flac" 54 | units = spiritlm_base_tokenizer.encode_units(audio) 55 | expected_hubert = "99 49 38 149 149 71 423 427 492 288 315 153 153 389 497 412 247 354 7 96 452 452 176 266 266 77 248 336 336 211 166 65 94 224 224 148 492 191 440 440 41 41 457 79 382 451 332 216 114 340 478 74 79 370 272 370 370 53 477 65 171 60 258 111 111 111 111 338 338 23 23 338 23 338 338 338 7 338 338 149 406 7 361 361 361 99 99 99 99 99 99 99 209 209 209 209 209 479 50 50 7 149 149 35 35 130 130 169 169 72 434 119 272 4 249 245 245 433 159 294 139 359 343 269 302 226 370 216 459 424 424 226 382 7 58 138 428 397 350 350 306 306 306 84 11 171 171 60 314 227 227 355 9 58 138 226 370 272 382 334 330 176 176 307 145 248 493 64 44 388 7 111 111 111 111 23 23 481 149 149 80 70 431 457 79 79 249 249 245 245 245 433 433 316 316 180 458 458 458 86 86 225 103 60 96 119 119 129 356 218 4 259 259 392 490 75 488 166 65 171 60 7 54 54 85 85 361 361" 56 | assert expected_hubert == units["hubert"] 57 | 58 | 59 | def test_expressive_tokenizer_encode_string(spiritlm_expressive_tokenizer): 60 | audio = "examples/audio/7143-88743-0029.flac" 61 | encoded_string = spiritlm_expressive_tokenizer.encode_string(audio) 62 | expected = "[St71][Pi39][Hu99][Hu49][Hu38][Hu149][Hu71][Pi48][Hu423][Hu427][Pi56][Hu492][Hu288][Pi40][Hu315][Hu153][Pi42][Hu389][Pi39][Hu497][Hu412][Pi51][Hu247][Hu354][Pi40][Hu7][Hu96][Pi43][Hu452][Pi54][Hu176][Hu266][Pi3][St71][Hu77][Pi35][Hu248][Hu336][Pi39][Hu211][Pi25][Hu166][Hu65][Pi58][Hu94][Hu224][Pi26][Hu148][Pi44][Hu492][Hu191][Pi40][Hu440][Pi13][Hu41][Pi20][Hu457][Hu79][Pi46][Hu382][Hu451][Pi41][Hu332][Hu216][Pi26][Hu114][Hu340][St71][Pi40][Hu478][Hu74][Pi26][Hu79][Hu370][Pi56][Hu272][Hu370][Pi41][Hu53][Pi46][Hu477][Hu65][Hu171][Hu60][Pi41][Hu258][Hu111][Pi40][Hu338][Hu23][Hu338][Pi39][Hu23][Hu338][St71][Pi57][Hu7][Hu338][Pi59][Hu149][Hu406][Hu7][Hu361][Hu99][Hu209][Pi20][Hu479][Hu50][St71][Pi35][Hu7][Hu149][Hu35][Pi13][Hu130][Pi3][Hu169][Pi9][Hu72][Pi6][Hu434][Hu119][Pi0][Hu272][Hu4][Pi20][Hu249][Hu245][Pi57][Hu433][Pi56][Hu159][Hu294][Hu139][Hu359][Hu343][Hu269][Hu302][St71][Hu226][Pi59][Hu370][Hu216][Pi44][Hu459][Hu424][Pi57][Hu226][Pi41][Hu382][Hu7][Pi59][Hu58][Hu138][Pi42][Hu428][Hu397][Pi51][Hu350][Pi59][Hu306][Pi57][Hu84][Pi59][Hu11][Hu171][Hu60][Pi39][Hu314][Hu227][St71][Hu355][Pi46][Hu9][Hu58][Pi56][Hu138][Hu226][Pi58][Hu370][Hu272][Pi41][Hu382][Hu334][Hu330][Hu176][Pi40][Hu307][Pi39][Hu145][Hu248][Hu493][Hu64][Hu44][Hu388][Pi59][Hu7][Hu111][St71][Hu23][Pi15][Hu481][Pi27][Hu149][Pi13][Hu80][Hu70][Pi55][Hu431][Hu457][Pi13][Hu79][Pi27][Hu249][Pi35][Hu245][Pi36][Hu433][Pi3][Hu316][Pi53][Hu180][Pi3][Hu458][Pi26][Hu86][St71][Pi43][Hu225][Pi53][Hu103][Hu60][Pi54][Hu96][Hu119][Pi39][Hu129][Pi25][Hu356][Hu218][Pi14][Hu4][Hu259][Pi41][Hu392][Pi46][Hu490][Hu75][Hu488][Hu166][Hu65][Hu171][Hu60][Hu7][Pi41][Hu54][Hu85][St83][Hu361]" 63 | assert encoded_string == expected 64 | 65 | 66 | def test_base_tokenizer_encode_string(spiritlm_base_tokenizer): 67 | audio = "examples/audio/7143-88743-0029.flac" 68 | encoded_string = spiritlm_base_tokenizer.encode_string(audio) 69 | expected = 
"[Hu99][Hu49][Hu38][Hu149][Hu71][Hu423][Hu427][Hu492][Hu288][Hu315][Hu153][Hu389][Hu497][Hu412][Hu247][Hu354][Hu7][Hu96][Hu452][Hu176][Hu266][Hu77][Hu248][Hu336][Hu211][Hu166][Hu65][Hu94][Hu224][Hu148][Hu492][Hu191][Hu440][Hu41][Hu457][Hu79][Hu382][Hu451][Hu332][Hu216][Hu114][Hu340][Hu478][Hu74][Hu79][Hu370][Hu272][Hu370][Hu53][Hu477][Hu65][Hu171][Hu60][Hu258][Hu111][Hu338][Hu23][Hu338][Hu23][Hu338][Hu7][Hu338][Hu149][Hu406][Hu7][Hu361][Hu99][Hu209][Hu479][Hu50][Hu7][Hu149][Hu35][Hu130][Hu169][Hu72][Hu434][Hu119][Hu272][Hu4][Hu249][Hu245][Hu433][Hu159][Hu294][Hu139][Hu359][Hu343][Hu269][Hu302][Hu226][Hu370][Hu216][Hu459][Hu424][Hu226][Hu382][Hu7][Hu58][Hu138][Hu428][Hu397][Hu350][Hu306][Hu84][Hu11][Hu171][Hu60][Hu314][Hu227][Hu355][Hu9][Hu58][Hu138][Hu226][Hu370][Hu272][Hu382][Hu334][Hu330][Hu176][Hu307][Hu145][Hu248][Hu493][Hu64][Hu44][Hu388][Hu7][Hu111][Hu23][Hu481][Hu149][Hu80][Hu70][Hu431][Hu457][Hu79][Hu249][Hu245][Hu433][Hu316][Hu180][Hu458][Hu86][Hu225][Hu103][Hu60][Hu96][Hu119][Hu129][Hu356][Hu218][Hu4][Hu259][Hu392][Hu490][Hu75][Hu488][Hu166][Hu65][Hu171][Hu60][Hu7][Hu54][Hu85][Hu361]" 70 | assert encoded_string == expected 71 | --------------------------------------------------------------------------------