├── examples
│   ├── test1.mp3
│   ├── test_autonomous1.mp3
│   └── character_ref_emb_demo.pkl
├── requirements.txt
├── LICENSE
├── spkr.py
├── voila_tokenizer.py
├── .gitignore
├── README.md
├── infer.py
├── gradio_demo.py
├── audio_transformer.py
├── tokenize_func.py
└── model.py

/examples/test1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maitrix-org/Voila/HEAD/examples/test1.mp3

--------------------------------------------------------------------------------
/examples/test_autonomous1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maitrix-org/Voila/HEAD/examples/test_autonomous1.mp3

--------------------------------------------------------------------------------
/examples/character_ref_emb_demo.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maitrix-org/Voila/HEAD/examples/character_ref_emb_demo.pkl

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | torchaudio
4 | transformers
5 | flash-attn
6 | soundfile
7 | librosa
8 | jsonlines
9 | gradio
10 | pyannote.audio

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 Maitrix.org
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /spkr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from torchaudio.functional import resample 4 | 5 | from pyannote.audio import Model 6 | from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding 7 | 8 | 9 | class SpeakerEmbedding: 10 | def __init__(self, model_path="pyannote/wespeaker-voxceleb-resnet34-LM", device="cuda"): 11 | model = Model.from_pretrained(model_path).eval() 12 | 13 | self.device = torch.device(device) 14 | self.sample_rate = 16000 15 | self.model = model.to(self.device) 16 | 17 | @torch.no_grad() 18 | def __call__(self, wav, sr): 19 | wav = torch.tensor(wav, device=self.device) 20 | if sr != self.sample_rate: 21 | wav = resample(wav, sr, self.sample_rate) 22 | sr = self.sample_rate 23 | 24 | assert len(wav.shape) <= 3 25 | is_batch = False 26 | if len(wav.shape) == 3: 27 | is_batch = True 28 | elif len(wav.shape) == 2: 29 | wav = wav[None, :, :] 30 | else: 31 | wav = wav[None, None, :] 32 | 33 | with torch.inference_mode(): 34 | embeddings = self.model(wav) 35 | 36 | if is_batch: 37 | return embeddings 38 | else: 39 | return embeddings[0] 40 | 41 | if __name__ == '__main__': 42 | import argparse 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--wav", type=str, required=True) 45 | args = parser.parse_args() 46 | 47 | model = SpeakerEmbedding(device="cuda") 48 | 49 | wav, sr = torchaudio.load(args.wav) 50 | print(model(wav, sr)) 51 | -------------------------------------------------------------------------------- /voila_tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from torchaudio.functional import resample 4 | 5 | from transformers import AutoProcessor, EncodecModel 6 | 7 | 8 | ALL_BANDWIDTHS = [1.1] 9 | 10 | class VoilaTokenizer: 11 | def __init__( 12 | self, 13 | model_path="maitrix-org/Voila-Tokenizer", 14 | bandwidth_id=0, 15 | device="cpu", 16 | ): 17 | self.device = torch.device(device) 18 | self.bandwidth = ALL_BANDWIDTHS[bandwidth_id] 19 | self.bandwidth_id = torch.tensor([bandwidth_id], device=device) 20 | 21 | self.processor = AutoProcessor.from_pretrained(model_path) 22 | self.model = EncodecModel.from_pretrained(model_path).to(device) 23 | 24 | self.sampling_rate = self.processor.sampling_rate 25 | self.model_version = self.model.config.model_version 26 | 27 | 28 | @torch.no_grad() 29 | def encode(self, wav, sr): 30 | wav = torch.tensor(wav, dtype=torch.float32, device=self.device) 31 | if sr != self.processor.sampling_rate: 32 | wav = resample(wav, sr, self.processor.sampling_rate) 33 | sr = self.processor.sampling_rate 34 | if len(wav.shape) == 1: 35 | wav = wav[None, None, :] 36 | elif len(wav.shape) == 2: 37 | assert wav.shape[0] == 1 38 | wav = wav[None, :] 39 | elif len(wav.shape) == 3: 40 | assert wav.shape[0] == 1 and wav.shape[1] == 1 41 | 42 | # inputs = self.processor(raw_audio=wav, sampling_rate=sr, return_tensors="pt") 43 | encoder_outputs = self.model.encode(wav, bandwidth=self.bandwidth) 44 | return encoder_outputs.audio_codes[0, 0] 45 | 46 | @torch.no_grad() 47 | def decode(self, audio_codes): 48 | assert len(audio_codes.shape) == 2 49 | audio_values = self.model.decode(audio_codes[None, None, :, :], [None])[0] 50 | return audio_values[0, 0] 51 | 52 | if __name__ == '__main__': 53 | import argparse 54 | import soundfile as sf 55 | 56 | parser 
= argparse.ArgumentParser() 57 | parser.add_argument("--wav", type=str) 58 | args = parser.parse_args() 59 | 60 | wav, sr = torchaudio.load(args.wav) 61 | if len(wav.shape) > 1: 62 | wav = wav[0] 63 | 64 | model = VoilaTokenizer(device="cuda") 65 | 66 | audio_codes = model.encode(wav, sr) 67 | audio_values = model.decode(audio_codes).cpu().numpy() 68 | 69 | tps = audio_codes.shape[-1] / (audio_values.shape[-1] / model.processor.sampling_rate) 70 | print(audio_codes.shape, audio_values.shape, tps) 71 | sf.write("audio_mt.wav", audio_values, model.processor.sampling_rate) 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | output 2 | .gradio 3 | run_*.sh 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # UV 102 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | #uv.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 120 | .pdm.toml 121 | .pdm-python 122 | .pdm-build/ 123 | 124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 125 | __pypackages__/ 126 | 127 | # Celery stuff 128 | celerybeat-schedule 129 | celerybeat.pid 130 | 131 | # SageMath parsed files 132 | *.sage.py 133 | 134 | # Environments 135 | .env 136 | .venv 137 | env/ 138 | venv/ 139 | ENV/ 140 | env.bak/ 141 | venv.bak/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | .spyproject 146 | 147 | # Rope project settings 148 | .ropeproject 149 | 150 | # mkdocs documentation 151 | /site 152 | 153 | # mypy 154 | .mypy_cache/ 155 | .dmypy.json 156 | dmypy.json 157 | 158 | # Pyre type checker 159 | .pyre/ 160 | 161 | # pytype static type analyzer 162 | .pytype/ 163 | 164 | # Cython debug symbols 165 | cython_debug/ 166 | 167 | # PyCharm 168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 170 | # and can be added to the global gitignore or merged into this file. For a more nuclear 171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 172 | #.idea/ 173 | 174 | # PyPI configuration file 175 | .pypirc 176 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |
3 | # Voila: Voice-Language Foundation Models

4 | 💜 [Project Page](https://voila.maitrix.org)    |    🖥️ [GitHub](https://github.com/maitrix-org/Voila)    |    🤗 [Hugging Face](https://huggingface.co/maitrix-org)    |    📑 [Paper](https://arxiv.org/abs/2505.02707)    |    🌐 [Online Demo](https://huggingface.co/spaces/maitrix-org/Voila-demo)    |    🏠 [Maitrix.org](https://maitrix.org)
5 |

6 | 7 | Voila is a new family of large voice-language foundation models aiming to lift human-AI interaction experiences to the next level. Breaking away from the constraints of traditional voice AI systems—high latency, loss of vocal nuances, and mechanical responses—Voila employs an innovative end-to-end model design and a novel hierarchical Transformer architecture. This approach enables real-time, autonomous, and rich voice interactions, with latency as low as 195 ms, surpassing average human response times. Combining advanced voice and language modeling, Voila offers customizable, persona-driven engagements and excels in a range of audio tasks from ASR and TTS to speech translation across six languages. With the online [web demo](https://huggingface.co/spaces/maitrix-org/Voila-demo), Voila invites you to explore a transformative, natural dialogue experience between human and AI. 8 | 9 | # ✨ Highlights 10 | - ⭐ High-fidelity, low-latency, real-time streaming audio processing 11 | - ⭐ Effective integration of voice and language modeling capabilities 12 | - ⭐ Millions of pre-built and custom voices, fast voice switching during conversation 13 | - ⭐ Unified model for various audio tasks 14 | 15 | # 🎥 Video Demo 16 | [![Voila Demo](https://img.youtube.com/vi/J27M9-g5KL0/0.jpg)](https://www.youtube.com/watch?v=J27M9-g5KL0) 17 | 18 | # 🔥 Latest News!! 19 | 20 | * April 28, 2025: 👋 We've released the inference code and model weights of Voila. 21 | 22 | # ⚙️ Foundation Models 23 | 24 | | Model | Description | Download Link | 25 | |--------|-----------|-----------------| 26 | |Voila-base|Voila base model|https://huggingface.co/maitrix-org/Voila-base| 27 | |Voila-Chat|End-to-end audio chat model|https://huggingface.co/maitrix-org/Voila-chat| 28 | |Voila-Autonomous (preview)|Full-duplex audio chat model|https://huggingface.co/maitrix-org/Voila-autonomous-preview| 29 | |Voila-Audio-alpha|Empowering LLM with raw audio input|https://huggingface.co/maitrix-org/Voila-audio-alpha| 30 | |Voila-Tokenizer|Audio tokenizer|https://huggingface.co/maitrix-org/Voila-Tokenizer| 31 | 32 | ## Usage 33 | ### CLI demo 34 | ```shell 35 | for model_name in "maitrix-org/Voila-audio-alpha" "maitrix-org/Voila-base" "maitrix-org/Voila-chat"; do 36 | # Text chat 37 | python infer.py \ 38 | --model-name ${model_name} \ 39 | --instruction "" \ 40 | --input-text "Hello" \ 41 | --task-type chat_tito 42 | # Voice chat 43 | python infer.py \ 44 | --model-name ${model_name} \ 45 | --instruction "" \ 46 | --input-audio "examples/test1.mp3" \ 47 | --task-type chat_aiao 48 | done 49 | 50 | # Autonomous mode 51 | python infer.py \ 52 | --model-name "maitrix-org/Voila-autonomous-preview" \ 53 | --instruction "" \ 54 | --input-audio "examples/test_autonomous1.mp3" \ 55 | --task-type chat_aiao_auto 56 | ``` 57 | 58 | ### Gradio demo 59 | ```shell 60 | python gradio_demo.py 61 | ``` 62 | 63 | For more information, please refer to the [code repository](https://github.com/maitrix-org/Voila). 64 | 65 | # 📁 Datasets 66 | We publish the following two datasets: Voila Benchmark and Voila Voice Library. Voila-Benchmark is a novel speech evaluation benchmark, while Voila Voice Library provides millions of pre-built and customizable voices. 
67 | 
68 | | Dataset | Description | Download Link |
69 | |--------|-----------|-----------------|
70 | |Voila Benchmark| Speech evaluation benchmark | https://huggingface.co/datasets/maitrix-org/Voila-Benchmark |
71 | |Voila Voice Library| Millions of pre-built voices | https://huggingface.co/datasets/maitrix-org/Voila-million-voice |
72 | 
73 | # 📊 Benchmark
74 | ## 1. Voila Benchmark
75 | We introduce a novel speech evaluation benchmark called the Voila Benchmark. It is constructed by sampling from five widely used language model evaluation datasets: MMLU, MATH, OpenAI HumanEval, NQ-Open, and GSM8k. We compare our results with SpeechGPT and Moshi.
76 | | Model | Voila Benchmark |
77 | |-------|----------------|
78 | |SpeechGPT| 13.29|
79 | |Moshi | 11.45 |
80 | |**Voila** | **30.56** |
81 | 
82 | _(higher is better)_
83 | 
84 | For detailed scores of the Voila Benchmark on each specific domain, please refer to our paper (Section 5.1, "Evaluation of Voila Benchmark").
85 | ## 2. Evaluation of ASR
86 | As Voila supports multiple tasks, including Automatic Speech Recognition (ASR), Text-to-Speech (TTS), and spoken question answering, we also evaluate the performance of ASR and TTS.
87 | For ASR, we assess performance on the LibriSpeech test-clean dataset, using Word Error Rate (WER) as the metric. Voila attains a WER of 4.8%, outperforming the 5.7% reported by Moshi. When both models use the LibriSpeech training data, Voila achieves an impressive WER of 2.7%.
88 | | Model | LibriSpeech test-clean (WER) |
89 | |-------|-----------------------|
90 | |Whisper large v2|2.7|
91 | |Whisper large v3|2.2|
92 | |FastConformer|3.6|
93 | |VoxtLM |2.7|
94 | |Moshi |5.7|
95 | |**Voila (w/o LibriSpeech train split)** |**4.8**|
96 | |**Voila (with LibriSpeech train split)**|**2.7**|
97 | 
98 | _(lower is better)_
99 | 
100 | ## 3. Evaluation of TTS
101 | For TTS, we follow the evaluation protocol proposed in VALL-E, which involves transcribing the generated audio with HuBERT-Large.
102 | Voila once again leads, with a WER of 3.2% (2.8% when using the LibriSpeech training data).
103 | 
104 | | Model | LibriSpeech test-clean (WER) |
105 | |-------|-----------------------|
106 | |YourTTS |7.7|
107 | |VALL-E|5.9|
108 | |Moshi|4.7|
109 | |**Voila (w/o LibriSpeech train split)** |**3.2**|
110 | |**Voila (with LibriSpeech train split)** |**2.8**|
111 | 
112 | _(lower is better)_
113 | 
114 | # 📝 Citation
115 | If you find our work helpful, please cite us.
116 | 
117 | ```
118 | @article{voila2025,
119 |   author = {Yemin Shi and Yu Shu and Siwei Dong and Guangyi Liu and Jaward Sesay and Jingwen Li and Zhiting Hu},
120 |   title = {Voila: Voice-Language Foundation Models for Real-Time Autonomous Interaction and Voice Roleplay},
121 |   eprint = {2505.02707},
122 |   archivePrefix = {arXiv},
123 |   primaryClass = {cs.CL},
124 |   year = {2025}
125 | }
126 | ```
127 | 
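# 🧩 Python API example

The CLI and Gradio demos above are thin wrappers around `load_model` and `eval_model` from [infer.py](infer.py). Below is a minimal sketch (not an official API) of one audio-in/audio-out chat turn, mirroring what `infer.py` does with its default arguments; a CUDA GPU is assumed, as in the scripts themselves:

```python
import torch
import torchaudio

from infer import load_model, eval_model
from spkr import SpeakerEmbedding

# Load the end-to-end chat model plus its text and audio tokenizers.
model, tokenizer, tokenizer_voila, model_type = load_model(
    "maitrix-org/Voila-chat", "maitrix-org/Voila-Tokenizer"
)

# A reference speaker embedding selects the output voice.
spkr_model = SpeakerEmbedding(device="cuda")
wav, sr = torchaudio.load("examples/test1.mp3")
ref_embs = spkr_model(wav, sr)
ref_embs_mask = torch.tensor([1]).cuda()

# One user turn plus an empty assistant stub marking where generation starts.
history = {
    "instruction": "",
    "conversations": [
        {"from": "user", "audio": {"file": "examples/test1.mp3"}},
        {"from": "assistant"},
    ],
}
out = eval_model(model, tokenizer, tokenizer_voila, model_type,
                 "chat_aiao", history, ref_embs, ref_embs_mask)
print(out["text"])  # for audio tasks, out["audio"] holds (waveform, 16000)
```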
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import random
4 | import jsonlines
5 | import soundfile as sf
6 | import json
7 | import copy
8 | import torch
9 | from pathlib import Path
10 | from threading import Thread
11 | 
12 | import torchaudio
13 | from transformers import AutoTokenizer
14 | 
15 | from model import VoilaAudioAlphaModel, VoilaModel, VoilaAutonomousModel
16 | from spkr import SpeakerEmbedding
17 | from voila_tokenizer import VoilaTokenizer
18 | from tokenize_func import (
19 |     voila_input_format,
20 |     AUDIO_TOKEN_FORMAT,
21 |     DEFAULT_AUDIO_TOKEN,
22 |     DEFAULT_ASSISTANT_TOKEN,
23 | )
24 | 
25 | 
26 | def disable_torch_init():
27 |     """
28 |     Disable the redundant torch default initialization to accelerate model creation.
29 |     """
30 |     import torch
31 |     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
32 |     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
33 | 
34 | def load_model(model_name, audio_tokenizer_path):
35 |     disable_torch_init()
36 | 
37 |     if "Voila-audio" in model_name:
38 |         model_type = "audio"
39 |         cls = VoilaAudioAlphaModel
40 |     elif "Voila-auto" in model_name:
41 |         model_type = "autonomous"
42 |         cls = VoilaAutonomousModel
43 |     else:
44 |         model_type = "token"
45 |         cls = VoilaModel
46 | 
47 |     model = cls.from_pretrained(
48 |         model_name,
49 |         torch_dtype=torch.bfloat16,
50 |         use_flash_attention_2=True,
51 |         use_cache=True,
52 |     )
53 |     model = model.cuda()
54 |     tokenizer = AutoTokenizer.from_pretrained(model_name)
55 |     tokenizer_voila = VoilaTokenizer(model_path=audio_tokenizer_path, device="cuda")
56 |     return model, tokenizer, tokenizer_voila, model_type
57 | 
58 | def is_audio_output_task(task_type):
59 |     return task_type.endswith("ao") or "aiao" in task_type or "tts" in task_type
60 | 
61 | def eval_model(model, tokenizer, tokenizer_voila, model_type, task_type, history, ref_embs, ref_embs_mask, max_new_tokens=512):
62 |     # step1: initializing
63 |     num_codebooks = model.config.num_codebooks
64 |     codebook_size = model.config.codebook_size
65 | 
66 |     AUDIO_MIN_TOKEN_ID = tokenizer.convert_tokens_to_ids(AUDIO_TOKEN_FORMAT.format(0))  # audio code c of codebook n lives in the text vocab as <|c + n*codebook_size|>
67 |     assert isinstance(AUDIO_MIN_TOKEN_ID, int)
68 |     AUDIO_MAX_TOKEN_ID = tokenizer.convert_tokens_to_ids(AUDIO_TOKEN_FORMAT.format(codebook_size*num_codebooks-1))
69 |     assert isinstance(AUDIO_MAX_TOKEN_ID, int)
70 |     AUDIO_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_AUDIO_TOKEN)
71 |     assert isinstance(AUDIO_TOKEN_ID, int)
72 |     ASSISTANT_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_ASSISTANT_TOKEN)
73 |     assert isinstance(ASSISTANT_TOKEN_ID, int)
74 | 
75 |     # step2: set infer config
76 |     data_cfg = {
77 |         "input_type": model_type,
78 |         "task_type": task_type,
79 |         "num_codebooks": num_codebooks,
80 |         "codebook_size": codebook_size,
81 |     }
82 | 
83 |     # step3: infer
84 |     input_ids, audio_datas, audio_data_masks, streaming_user_input_audio_tokens = voila_input_format(history, tokenizer, tokenizer_voila, data_cfg)
85 | 
86 |     # prepare user_streaming_generator to simulate
streaming user input 87 | def get_input_generator(all_tokens): 88 | assert all_tokens is not None 89 | for i in range(len(all_tokens[0])): 90 | yield all_tokens[:,i] 91 | 92 | if model_type == "autonomous": 93 | input_generator = get_input_generator(torch.as_tensor(streaming_user_input_audio_tokens).cuda()) 94 | input_ids = [torch.as_tensor([input]).transpose(1,2).cuda() for input in input_ids] # transpose to [bs, seq, num_codebooks] 95 | input_ids = torch.cat(input_ids, dim=2) # concat to [bs, seq, num_codebooks*2] 96 | else: 97 | input_ids = torch.as_tensor([input_ids]).transpose(1,2).cuda() # transpose to [bs, seq, num_codebooks] 98 | gen_params = { 99 | "input_ids": input_ids, 100 | "ref_embs": ref_embs, 101 | "ref_embs_mask": ref_embs_mask, 102 | "max_new_tokens": max_new_tokens, 103 | "pad_token_id": tokenizer.pad_token_id, 104 | "eos_token_id": tokenizer.eos_token_id, 105 | "llm_audio_token_id": AUDIO_TOKEN_ID, 106 | "min_audio_token_id": AUDIO_MIN_TOKEN_ID, 107 | "temperature": 0.2, 108 | "top_k": 50, 109 | "audio_temperature": 0.8, 110 | "audio_top_k": 50, 111 | } 112 | if model_type == "audio": 113 | audio_datas = torch.tensor([audio_datas], dtype=torch.bfloat16).cuda() 114 | audio_data_masks = torch.tensor([audio_data_masks]).cuda() 115 | gen_params["audio_datas"] = audio_datas 116 | gen_params["audio_data_masks"] = audio_data_masks 117 | elif model_type == "autonomous": 118 | gen_params["input_generator"] = input_generator 119 | gen_params["llm_assistant_token_id"] = ASSISTANT_TOKEN_ID 120 | print(f"Input str: {tokenizer.decode(input_ids[0, :, 0])}") 121 | with torch.inference_mode(): 122 | outputs = model.run_generate(**gen_params) 123 | 124 | if model_type == "autonomous": 125 | outputs = outputs.chunk(2, dim=2)[1] 126 | outputs = outputs[0].cpu().tolist() 127 | 128 | predict_outputs = outputs[input_ids.shape[1]:] 129 | text_outputs = [] 130 | audio_outputs = [] 131 | for _ in range(num_codebooks): 132 | audio_outputs.append([]) 133 | for item in predict_outputs: 134 | if item[0] >= AUDIO_MIN_TOKEN_ID and item[0] <= AUDIO_MAX_TOKEN_ID: 135 | for n, at in enumerate(item): 136 | audio_outputs[n].append((at - AUDIO_MIN_TOKEN_ID)%codebook_size) 137 | elif item[0] != tokenizer.eos_token_id: 138 | text_outputs.append(item[0]) 139 | 140 | out ={ 141 | 'text': tokenizer.decode(text_outputs), 142 | } 143 | if is_audio_output_task(task_type): 144 | audio_values = tokenizer_voila.decode(torch.tensor(audio_outputs).cuda()) 145 | out['audio'] = (audio_values.detach().cpu().numpy(), 16000) 146 | return out 147 | 148 | 149 | if __name__ == "__main__": 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument("--instruction", type=str, default="") 152 | parser.add_argument("--input-text", type=str, default=None) 153 | parser.add_argument("--input-audio", type=str, default=None) 154 | parser.add_argument("--result-path", type=str, default="output") 155 | parser.add_argument("--ref-audio", type=str, default="examples/test1.mp3") 156 | parser.add_argument("--model-name", type=str, default="maitrix-org/Voila-chat") 157 | parser.add_argument("--audio-tokenizer-path", type=str, default="maitrix-org/Voila-Tokenizer") 158 | parser.add_argument("--task-type", type=str, default="chat_aiao") 159 | args = parser.parse_args() 160 | 161 | assert args.model_name in [ 162 | "maitrix-org/Voila-audio-alpha", 163 | "maitrix-org/Voila-base", 164 | "maitrix-org/Voila-chat", 165 | "maitrix-org/Voila-autonomous-preview", 166 | ] 167 | 168 | # step0: Model loading 169 | model, tokenizer, tokenizer_voila, 
model_type = load_model(args.model_name, args.audio_tokenizer_path)
170 | 
171 |     # step1: prepare inputs
172 |     Path(args.result_path).mkdir(exist_ok=True, parents=True)
173 |     history = {
174 |         "instruction": args.instruction,
175 |         "conversations": [],
176 |     }
177 |     if args.input_text is not None:
178 |         history["conversations"].append({"from": "user", "text": args.input_text})
179 |     elif args.input_audio is not None:
180 |         history["conversations"].append({"from": "user", "audio": {"file": args.input_audio}})
181 |     else:
182 |         raise Exception("Please provide at least one of --input-text or --input-audio")
183 |     history["conversations"].append({"from": "assistant"})
184 | 
185 |     # step2: encode ref
186 |     ref_embs, ref_embs_mask = None, None
187 |     if is_audio_output_task(args.task_type):
188 |         spkr_model = SpeakerEmbedding(device="cuda")
189 |         wav, sr = torchaudio.load(args.ref_audio)
190 |         ref_embs = spkr_model(wav, sr)
191 |         ref_embs_mask = torch.tensor([1]).cuda()
192 | 
193 |     out = eval_model(model, tokenizer, tokenizer_voila, model_type, args.task_type, history, ref_embs, ref_embs_mask)
194 |     print(f"Output str: {out['text']}")
195 |     if 'audio' in out:
196 |         wav, sr = out['audio']
197 |         save_name = f"{args.result_path}/out.wav"
198 |         sf.write(save_name, wav, sr)
199 | 
--------------------------------------------------------------------------------
/gradio_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import shutil
4 | import pickle
5 | import gradio as gr
6 | import soundfile as sf
7 | from pathlib import Path
8 | 
9 | import torch
10 | import torchaudio
11 | 
12 | from huggingface_hub import hf_hub_download
13 | 
14 | from infer import load_model, eval_model
15 | from spkr import SpeakerEmbedding
16 | 
17 | 
18 | spkr_model = SpeakerEmbedding(device="cuda")
19 | model, tokenizer, tokenizer_voila, model_type = load_model("maitrix-org/Voila-chat", "maitrix-org/Voila-Tokenizer")
20 | default_ref_file = "examples/character_ref_emb_demo.pkl"
21 | default_ref_name = "Homer Simpson"
22 | million_voice_ref_file = hf_hub_download(repo_id="maitrix-org/Voila-million-voice", filename="character_ref_emb_chunk0.pkl", repo_type="dataset")
23 | 
24 | instruction = "You are a smart AI agent created by Maitrix.org."
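# NOTE: The two pickles loaded just below map a voice name to a precomputed
# speaker embedding, the same kind of tensor that get_ref_embs() further down
# returns. A custom entry could be added with something like (a sketch;
# "my_voice.mp3" is a placeholder path, not a file shipped with the repo):
#     wav, sr = torchaudio.load("my_voice.mp3")
#     default_ref_emb_mask_list["My voice"] = spkr_model(wav, sr).cpu()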
25 | save_path = "output"
26 | 
27 | intro = """**Voila**
28 | 
29 | For more demos, please go to [https://voila.maitrix.org](https://voila.maitrix.org)."""
30 | 
31 | default_ref_emb_mask_list = pickle.load(open(default_ref_file, "rb"))
32 | million_voice_ref_emb_mask_list = pickle.load(open(million_voice_ref_file, "rb"))
33 | 
34 | def get_ref_embs(ref_audio):
35 |     wav, sr = torchaudio.load(ref_audio)
36 |     ref_embs = spkr_model(wav, sr).cpu()
37 |     return ref_embs
38 | 
39 | def delete_directory(request: gr.Request):
40 |     if not request.session_hash:
41 |         return
42 |     user_dir = Path(f"{save_path}/{str(request.session_hash)}")
43 |     if user_dir.exists():
44 |         shutil.rmtree(str(user_dir))
45 | 
46 | def add_message(history, message):
47 |     history.append({"role": "user", "content": {"path": message}})
48 |     return history, gr.Audio(value=None), gr.Button(interactive=False)
49 | 
50 | def call_bot(history, ref_embs, request: gr.Request):
51 |     formatted_history = {
52 |         "instruction": instruction,
53 |         "conversations": [{'from': item["role"], 'audio': {"file": item["content"][0]}} for item in history],
54 |     }
55 |     formatted_history["conversations"].append({"from": "assistant"})
56 |     print(formatted_history)
57 |     ref_embs = torch.tensor(ref_embs, dtype=torch.float32, device="cuda")
58 |     ref_embs_mask = torch.tensor([1], device="cuda")
59 |     out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_aiao", formatted_history, ref_embs, ref_embs_mask, max_new_tokens=512)
60 |     if 'audio' in out:
61 |         wav, sr = out['audio']
62 | 
63 |         user_dir = Path(f"{save_path}/{str(request.session_hash)}")
64 |         user_dir.mkdir(exist_ok=True, parents=True)  # parents=True: the top-level output dir may not exist yet
65 |         save_name = f"{user_dir}/{len(history)}.wav"
66 |         sf.write(save_name, wav, sr)
67 | 
68 |         history.append({"role": "assistant", "content": {"path": save_name}})
69 |     else:
70 |         history.append({"role": "assistant", "content": {"text": out['text']}})
71 | 
72 |     return history
73 | 
74 | def run_tts(text, ref_embs):
75 |     formatted_history = {
76 |         "instruction": "",
77 |         "conversations": [{'from': "user", 'text': text}],
78 |     }
79 |     formatted_history["conversations"].append({"from": "assistant"})
80 |     ref_embs = torch.tensor(ref_embs, dtype=torch.float32, device="cuda")
81 |     ref_embs_mask = torch.tensor([1], device="cuda")
82 |     out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_tts", formatted_history, ref_embs, ref_embs_mask, max_new_tokens=512)
83 |     if 'audio' in out:
84 |         wav, sr = out['audio']
85 |         return (sr, wav)
86 |     else:
87 |         raise Exception("No audio output")
88 | 
89 | def run_asr(audio):
90 |     formatted_history = {
91 |         "instruction": "",
92 |         "conversations": [{'from': "user", 'audio': {"file": audio}}],
93 |     }
94 |     formatted_history["conversations"].append({"from": "assistant"})
95 |     out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_asr", formatted_history, None, None, max_new_tokens=512)
96 |     if 'text' in out:
97 |         return out['text']
98 |     else:
99 |         raise Exception("No text output")
100 | 
101 | 
102 | def markdown_ref_name(ref_name):
103 |     return f"### Current voice id: {ref_name}"
104 | 
105 | def random_million_voice():
106 |     voice_id = random.choice(list(million_voice_ref_emb_mask_list.keys()))
107 |     return markdown_ref_name(voice_id), million_voice_ref_emb_mask_list[voice_id]
108 | 
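# Voice-selection widgets shared by the Chat and TTS tabs. A reference
# embedding can come from three places: the dropdown of bundled demo
# characters, a random pick from the million-voice library, or a custom
# clip embedded on the fly via get_ref_embs(). Whichever is chosen is
# written into the cur_ref_embs gr.State owned by the calling tab.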
109 | def get_ref_modules(cur_ref_embs):
110 |     with gr.Row() as ref_row:
111 |         with gr.Row():
112 |             current_ref_name = gr.Markdown(markdown_ref_name(default_ref_name))
113 |         with gr.Row() as ref_name_row:
114 |             with gr.Column(scale=2, min_width=160):
115 |                 ref_name_dropdown = gr.Dropdown(
116 |                     choices=list(default_ref_emb_mask_list.keys()),
117 |                     value=default_ref_name,
118 |                     label="Reference voice",
119 |                     min_width=160,
120 |                 )
121 |             with gr.Column(scale=1, min_width=80):
122 |                 random_ref_button = gr.Button(
123 |                     "Random from Million Voice", size="md",
124 |                 )
125 |         with gr.Row(visible=False) as ref_audio_row:
126 |             with gr.Column(scale=2, min_width=80):
127 |                 ref_audio = gr.Audio(
128 |                     sources=["microphone", "upload"],
129 |                     type="filepath",
130 |                     show_label=False,
131 |                     min_width=80,
132 |                 )
133 |             with gr.Column(scale=1, min_width=80):
134 |                 change_ref_button = gr.Button(
135 |                     "Change voice",
136 |                     interactive=False,
137 |                     min_width=80,
138 |                 )
139 |         ref_name_dropdown.change(
140 |             lambda x: (markdown_ref_name(x), default_ref_emb_mask_list[x]),
141 |             ref_name_dropdown,
142 |             [current_ref_name, cur_ref_embs]
143 |         )
144 |         random_ref_button.click(
145 |             random_million_voice,
146 |             None,
147 |             [current_ref_name, cur_ref_embs],
148 |         )
149 |         ref_audio.input(lambda: gr.Button(interactive=True), None, change_ref_button)
150 |         # If custom ref voice checkbox is checked, show the Audio component to record or upload a reference voice
151 |         custom_ref_voice = gr.Checkbox(label="Use custom voice", value=False)
152 |         # Checked: enable audio and button
153 |         # Unchecked: disable audio and button
154 |         def custom_ref_voice_change(x, cur_ref_embs):  # two params, matching the two inputs wired below
155 |             if not x:
156 |                 cur_ref_embs = default_ref_emb_mask_list[default_ref_name]
157 |             return [gr.Row(visible=not x), gr.Audio(value=None), gr.Row(visible=x), markdown_ref_name("Custom voice" if x else default_ref_name), cur_ref_embs]
158 |         custom_ref_voice.change(
159 |             custom_ref_voice_change,
160 |             [custom_ref_voice, cur_ref_embs],
161 |             [ref_name_row, ref_audio, ref_audio_row, current_ref_name, cur_ref_embs]
162 |         )
163 |         # When change ref button is clicked, get the reference voice and update the reference voice state
164 |         change_ref_button.click(
165 |             lambda: gr.Button(interactive=False), None, [change_ref_button]
166 |         ).then(
167 |             get_ref_embs, ref_audio, cur_ref_embs
168 |         )
169 |     return ref_row
170 | 
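# Chat tab wiring: add_message() appends the recorded clip to the message
# history and disables the submit button; call_bot() then runs one chat_aiao
# turn and appends the synthesized reply as an audio message (or plain text
# if the model returned no audio).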
171 | def get_chat_tab():
172 |     cur_ref_embs = gr.State(default_ref_emb_mask_list[default_ref_name])
173 |     with gr.Row() as chat_tab:
174 |         with gr.Column(scale=1):
175 |             ref_row = get_ref_modules(cur_ref_embs)
176 |             # Voice chat input
177 |             chat_input = gr.Audio(
178 |                 sources=["microphone", "upload"],
179 |                 type="filepath",
180 |                 show_label=False,
181 |             )
182 |             submit = gr.Button("Submit", interactive=False)
183 |             gr.Markdown(intro)
184 |         with gr.Column(scale=9):
185 |             chatbot = gr.Chatbot(
186 |                 elem_id="chatbot",
187 |                 type="messages",
188 |                 bubble_full_width=False,
189 |                 scale=1,
190 |                 show_copy_button=False,
191 |                 avatar_images=(
192 |                     None,  # os.path.join("files", "avatar.png"),
193 |                     None,  # os.path.join("files", "avatar.png"),
194 |                 ),
195 |             )
196 | 
197 |     chat_input.input(lambda: gr.Button(interactive=True), None, submit)
198 |     submit.click(
199 |         add_message, [chatbot, chat_input], [chatbot, chat_input, submit]
200 |     ).then(
201 |         call_bot, [chatbot, cur_ref_embs], chatbot, api_name="bot_response"
202 |     )
203 |     return chat_tab
204 | 
205 | def get_tts_tab():
206 |     cur_ref_embs = gr.State(default_ref_emb_mask_list[default_ref_name])
207 |     with gr.Row() as tts_tab:
208 |         with gr.Column(scale=1):
209 |             ref_row = get_ref_modules(cur_ref_embs)
210 |             gr.Markdown(intro)
211 |         with gr.Column(scale=9):
212 |             tts_output = gr.Audio(label="TTS output", interactive=False)
213 |             with gr.Row():
214 |                 text_input = gr.Textbox(label="Text", placeholder="Text to TTS")
215 |                 submit = gr.Button("Submit")
216 |             submit.click(
217 |                 run_tts, [text_input, cur_ref_embs], tts_output
218 |             )
219 |     return tts_tab
220 | 
221 | def get_asr_tab():
222 |     with gr.Row() as asr_tab:
223 |         with gr.Column():
224 |             asr_input = gr.Audio(
225 |                 label="ASR input",
226 |                 sources=["microphone", "upload"],
227 |                 type="filepath",
228 |             )
229 |             submit = gr.Button("Submit")
230 |             gr.Markdown(intro)
231 |         with gr.Column():
232 |             asr_output = gr.Textbox(label="ASR output", interactive=False)
233 |     submit.click(
234 |         run_asr, [asr_input], asr_output
235 |     )
236 |     return asr_tab
237 | 
238 | with gr.Blocks(fill_height=True) as demo:
239 |     with gr.Tab("Chat"):
240 |         chat_tab = get_chat_tab()
241 |     with gr.Tab("TTS"):
242 |         tts_tab = get_tts_tab()
243 |     with gr.Tab("ASR"):
244 |         asr_tab = get_asr_tab()
245 |     demo.unload(delete_directory)
246 | 
247 | if __name__ == "__main__":
248 |     demo.launch(share=True)
249 | 
--------------------------------------------------------------------------------
/audio_transformer.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional
3 | from dataclasses import dataclass
4 | 
5 | import torch
6 | import torch.nn as nn
7 | from torch import Tensor
8 | from torch.nn import functional as F
9 | from torch.nn.attention import SDPBackend, sdpa_kernel  # needed by the flash-attention SDPA branch in Attention.forward
10 | from einops import rearrange
11 | 
12 | 
13 | @dataclass
14 | class LocalArgs:
15 |     codebook_size: int = 2048
16 |     num_codebooks: int = 4
17 | 
18 | # Modified from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L105
19 | class KVCache(nn.Module):
20 |     def __init__(
21 |         self, n_layer, batch_size, max_seq_len, n_heads, head_dim, dtype, device
22 |     ):
23 |         super().__init__()
24 |         cache_shape = (n_layer, batch_size, n_heads, max_seq_len, head_dim)
25 |         self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype, device=device))
26 |         self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype, device=device))
27 | 
28 |     def update(self, layer_idx, input_pos, k_val, v_val):
29 |         # k_val: [B, H, S, D]
30 | 
31 |         k_out = self.k_cache
32 |         v_out = self.v_cache
33 |         k_out[layer_idx, :, :, input_pos:input_pos+1] = k_val
34 |         v_out[layer_idx, :, :, input_pos:input_pos+1] = v_val
35 | 
36 |         return k_out[layer_idx], v_out[layer_idx]
37 | 
38 | # Modified from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L756
39 | def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
40 |     freqs = 1.0 / (
41 |         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
42 |     )
43 |     t = torch.arange(seq_len, device=freqs.device)
44 |     freqs = torch.outer(t, freqs)
45 |     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
46 |     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
47 |     return cache
48 | 
49 | # Copied from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L767
50 | def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
51 |     xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
52 |     freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
53 |     x_out2 = torch.stack(
54 |         [
55 |             xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
56 |             xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
57 |         ],
58 |         -1,
59 |     )
60 | 
61 |     x_out2 = x_out2.flatten(3)
62 |     return x_out2.type_as(x)
63 | 
64 | # Copied from
https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L742 65 | class RMSNorm(nn.Module): 66 | def __init__(self, dim: int, eps: float = 1e-5): 67 | super().__init__() 68 | self.eps = eps 69 | self.weight = nn.Parameter(torch.ones(dim)) 70 | 71 | def _norm(self, x): 72 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 73 | 74 | def forward(self, x: Tensor) -> Tensor: 75 | output = self._norm(x.float()).type_as(x) 76 | return output * self.weight 77 | 78 | # Copied from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L731 79 | class FeedForward(nn.Module): 80 | def __init__(self, config: LocalArgs) -> None: 81 | super().__init__() 82 | self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) 83 | self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) 84 | self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) 85 | 86 | def forward(self, x: Tensor) -> Tensor: 87 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 88 | 89 | # Modified from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L615 90 | class Attention(nn.Module): 91 | def __init__(self, config: LocalArgs, layer_idx: int, use_sdpa: bool = True): 92 | super().__init__() 93 | assert config.dim % config.n_head == 0 94 | self.layer_idx = layer_idx 95 | 96 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim 97 | # key, query, value projections for all heads, but in a batch 98 | self.wqkv = nn.Linear( 99 | config.dim, total_head_dim, bias=config.attention_qkv_bias 100 | ) 101 | self.wo = nn.Linear(config.dim, config.dim, bias=False) 102 | 103 | self.dropout = config.dropout 104 | self.n_head = config.n_head 105 | self.head_dim = config.head_dim 106 | self.n_local_heads = config.n_local_heads 107 | self.dim = config.dim 108 | self.use_sdpa = use_sdpa 109 | self._register_load_state_dict_pre_hook(self.load_hook) 110 | 111 | def load_hook(self, state_dict, prefix, *args): 112 | if prefix + "wq.weight" in state_dict: 113 | wq = state_dict.pop(prefix + "wq.weight") 114 | wk = state_dict.pop(prefix + "wk.weight") 115 | wv = state_dict.pop(prefix + "wv.weight") 116 | state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) 117 | 118 | def forward( 119 | self, 120 | x: Tensor, 121 | freqs_cis: Tensor, 122 | mask: Tensor, 123 | input_pos: Optional[int] = None, 124 | kv_cache: Optional[KVCache] = None, 125 | ) -> Tensor: 126 | bsz, seqlen, _ = x.shape 127 | 128 | kv_size = self.n_local_heads * self.head_dim 129 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) 130 | 131 | q = q.view(bsz, seqlen, self.n_head, self.head_dim) 132 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) 133 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) 134 | 135 | q = apply_rotary_emb(q, freqs_cis) 136 | k = apply_rotary_emb(k, freqs_cis) 137 | 138 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) 139 | 140 | if kv_cache is not None: 141 | k, v = kv_cache.update(self.layer_idx, input_pos, k, v) 142 | 143 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 144 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 145 | 146 | if self.use_sdpa: 147 | if mask is None: 148 | with sdpa_kernel(SDPBackend.FLASH_ATTENTION): 149 | y = F.scaled_dot_product_attention( 150 | q, 151 | k, 152 | v, 153 | dropout_p=self.dropout if self.training else 0.0, 154 | is_causal=True, 155 | # No 
attn_mask is passed, so SDPA can select the flash-attention kernel
156 |                     )
157 |             else:
158 |                 y = F.scaled_dot_product_attention(
159 |                     q,
160 |                     k,
161 |                     v,
162 |                     attn_mask=mask,
163 |                     dropout_p=self.dropout if self.training else 0.0,
164 |                 )
165 |         else:
166 |             y = self.eq_scaled_dot_product_attention(
167 |                 q,
168 |                 k,
169 |                 v,
170 |                 attn_mask=mask,
171 |                 dropout_p=self.dropout if self.training else 0.0,
172 |             )
173 | 
174 |         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
175 | 
176 |         return self.wo(y)
177 | 
178 |     def eq_scaled_dot_product_attention(
179 |         self,
180 |         query,
181 |         key,
182 |         value,
183 |         attn_mask=None,
184 |         dropout_p=0.0,
185 |     ) -> torch.Tensor:
186 |         # This is a standard scaled dot product attention
187 |         # It's less efficient, but it doesn't raise CUDA errors
188 | 
189 |         L, S = query.size(-2), key.size(-2)
190 |         scale_factor = 1 / math.sqrt(query.size(-1))
191 |         attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
192 | 
193 |         if attn_mask is not None:
194 |             if attn_mask.dtype == torch.bool:
195 |                 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
196 |             else:
197 |                 attn_bias += attn_mask
198 | 
199 |         attn_weight = query @ key.transpose(-2, -1) * scale_factor
200 |         attn_weight += attn_bias
201 |         attn_weight = torch.softmax(attn_weight, dim=-1)
202 |         attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
203 | 
204 |         return attn_weight @ value
205 | 
206 | # Copied from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L599
207 | class TransformerBlock(nn.Module):
208 |     def __init__(self, config: LocalArgs, layer_idx: int, use_sdpa: bool = True) -> None:
209 |         super().__init__()
210 |         self.attention = Attention(config, layer_idx, use_sdpa=use_sdpa)
211 |         self.feed_forward = FeedForward(config)
212 |         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
213 |         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
214 | 
215 |     def forward(
216 |         self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: int = None, kv_cache: KVCache = None
217 |     ) -> Tensor:
218 |         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, kv_cache)
219 |         out = h + self.feed_forward(self.ffn_norm(h))
220 |         return out
221 | 
222 | # Modified from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L470
223 | class AudioTransformer(nn.Module):
224 |     def __init__(self, config, use_sdpa: bool = False):
225 |         super().__init__()
226 |         self.config = LocalArgs()
227 |         self.config.codebook_size = config.codebook_size
228 |         self.config.num_codebooks = config.num_codebooks
229 |         if hasattr(config, "min_audio_token_id"):
230 |             self.config.min_audio_token_id = config.min_audio_token_id
231 |             self.config.max_audio_token_id = config.max_audio_token_id
232 |         self.config.n_layer = 4
233 |         self.config.dim = 1024
234 |         self.config.n_head = 32
235 |         self.config.n_local_heads = 32
236 |         self.config.intermediate_size = 2816
237 |         self.config.head_dim = self.config.dim // self.config.n_head
238 |         self.config.norm_eps = 1e-5
239 |         self.config.attention_qkv_bias = False
240 |         self.config.dropout = 0.0
241 | 
242 |         self.embeddings = nn.Embedding(self.config.codebook_size, self.config.dim)
243 |         if self.config.dim != config.hidden_size:
244 |             self.input_proj = nn.Linear(config.hidden_size, self.config.dim, bias=False)
245 |         else:
246 |             self.input_proj = nn.Identity()
247 |         self.layers = nn.ModuleList(
248 |             TransformerBlock(self.config, layer_idx, use_sdpa=use_sdpa) for layer_idx in
range(self.config.n_layer) 249 | ) 250 | self.norm = RMSNorm(self.config.dim, eps=self.config.norm_eps) 251 | self.token_head = nn.Linear(self.config.dim, self.config.codebook_size, bias=False) 252 | self.gradient_checkpointing = False 253 | 254 | self.register_buffer( 255 | "freqs_cis", 256 | precompute_freqs_cis(self.config.num_codebooks, self.config.dim // self.config.n_head, 10000), 257 | persistent=False, 258 | ) 259 | self.register_buffer( 260 | "attention_mask", 261 | torch.tril(torch.ones(self.config.num_codebooks, self.config.num_codebooks, dtype=torch.bool)), 262 | persistent=False, 263 | ) 264 | 265 | def run_model(self, hidden_states, freqs_cis, attention_mask, input_pos: int = None, kv_cache: KVCache = None): 266 | for layer in self.layers: 267 | # TODO: gradient_checkpointing is disabled because of bug 268 | if False: # self.gradient_checkpointing and self.training: 269 | hidden_states = self._gradient_checkpointing_func( 270 | layer.__call__, 271 | hidden_states, 272 | freqs_cis, 273 | attention_mask, 274 | use_reentrant=True, 275 | ) 276 | else: 277 | hidden_states = layer(hidden_states, freqs_cis, attention_mask, input_pos, kv_cache) 278 | hidden_states = self.norm(hidden_states) 279 | logits = self.token_head(hidden_states) 280 | return logits.float() 281 | 282 | # inp: [bs, hidden_size] 283 | # labels: [bs, num_codebooks] 284 | # logits: [bs, num_codebooks, codebook_size] 285 | def forward(self, inp, labels): 286 | bs = inp.shape[0] 287 | 288 | hidden_states = self.input_proj(inp) 289 | if self.freqs_cis.dtype != hidden_states.dtype: 290 | self.freqs_cis = self.freqs_cis.to(dtype=hidden_states.dtype) 291 | if labels is not None: 292 | # Training mode 293 | # Get embedding 294 | assert bs == labels.shape[0] and labels.shape[1] == self.config.num_codebooks, f"Labels shape error: {labels.shape}" 295 | hidden_states = [hidden_states[:, None, :], self.embeddings(labels[..., :-1]).to(hidden_states.dtype)] 296 | hidden_states = torch.cat(hidden_states, dim=1) # [bs, num_codebooks, hidden_size] 297 | # Run attention layers 298 | logits = self.run_model(hidden_states, self.freqs_cis, self.attention_mask) 299 | else: 300 | # Inference mode 301 | raise RuntimeError(f"Please call function \"inference\" in inference mode") 302 | return logits 303 | 304 | # inp: [bs, seq_len, hidden_size] 305 | # out_tokens: [bs, 1, num_codebooks] 306 | @torch.inference_mode() 307 | def inference(self, inp, temperature=0, top_k=0): 308 | # Only use the last hidden states for token computation 309 | inp = inp[:, -1:, :] 310 | 311 | bs = inp.shape[0] 312 | if self.freqs_cis.dtype != inp.dtype: 313 | self.freqs_cis = self.freqs_cis.to(dtype=inp.dtype) 314 | 315 | inp = self.input_proj(inp) 316 | 317 | # Inference mode 318 | kv_cache = KVCache( 319 | self.config.n_layer, 320 | bs, 321 | self.config.num_codebooks, 322 | self.config.n_head, 323 | self.config.head_dim, 324 | dtype=inp.dtype, 325 | device=inp.device, 326 | ) 327 | # Generate one token per step 328 | out_tokens = [] 329 | for input_pos in range(self.config.num_codebooks): 330 | inp = inp.reshape(bs, 1, self.config.dim) 331 | local_freqs_cis = self.freqs_cis[input_pos] 332 | local_mask = self.attention_mask[None, None, input_pos, :self.config.num_codebooks] 333 | 334 | logits = self.run_model(inp, local_freqs_cis, local_mask, input_pos, kv_cache) 335 | logits = logits.squeeze(dim=1) 336 | 337 | # Apply temperature and top-k 338 | if temperature > 0: 339 | logits = logits / temperature 340 | if top_k > 0: 341 | top_k = min(top_k, 
logits.size(-1)) # Safety check 342 | # Remove all tokens with a probability less than the last token of the top-k 343 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 344 | logits = logits.masked_fill(indices_to_remove, -float("Inf")) 345 | 346 | # Do sample 347 | probs = nn.functional.softmax(logits, dim=-1) 348 | next_tokens = torch.multinomial(probs, num_samples=1) 349 | 350 | next_tokens = next_tokens.reshape(bs, 1, 1) 351 | inp = self.embeddings(next_tokens) 352 | out_tokens.append(next_tokens) 353 | 354 | return torch.cat(out_tokens, dim=-1) 355 | -------------------------------------------------------------------------------- /tokenize_func.py: -------------------------------------------------------------------------------- 1 | import io 2 | import copy 3 | import librosa 4 | import numpy as np 5 | 6 | 7 | AUDIO_TOKEN_FORMAT = "<|{}|>" 8 | 9 | DEFAULT_SYSTEM_START_TOKEN = "" 10 | DEFAULT_SYSTEM_END_TOKEN = "" 11 | 12 | DEFAULT_TTS_REF_START_TOKEN = "" 13 | DEFAULT_TTS_REF_END_TOKEN = "" 14 | DEFAULT_TTS_REF_TOKEN = "" 15 | 16 | DEFAULT_CHAT_REF_START_TOKEN = "" 17 | DEFAULT_CHAT_REF_END_TOKEN = "" 18 | DEFAULT_CHAT_REF_TOKEN = "" 19 | 20 | DEFAULT_HUMAN_TOKEN = "<|HUMAN|>" 21 | DEFAULT_ASSISTANT_TOKEN = "<|VOILA|>" 22 | 23 | DEFAULT_AUDIO_TOKEN = "" 24 | 25 | # =================================== 26 | # task special token 27 | # ----------------------------------- 28 | TASK_ASR_TOKEN = "" 29 | TASK_TTS_TOKEN = "" 30 | TASK_CHAT_TOKEN = "" 31 | TASK_STREAM_CHAT_TOKEN = "" 32 | 33 | TASK_ASR_TEXT_OUTPUT = "" 34 | TASK_TTS_AUDIO_OUTPUT = "" 35 | TASK_CHAT_TEXT_OUTPUT = "" 36 | TASK_CHAT_AUDIO_OUTPUT = "" 37 | 38 | CHAT_AUDIO_TEXT_SPLIT_TOKEN = "" 39 | # =================================== 40 | 41 | PREPEND_LEN = 80 42 | SEG_LEN = 640 43 | AUDIO_SR = 16000 44 | 45 | TASK_TYPE_CONF = { 46 | "chat_asr": TASK_ASR_TOKEN + TASK_ASR_TEXT_OUTPUT, 47 | "chat_tts": TASK_TTS_TOKEN + TASK_TTS_AUDIO_OUTPUT, 48 | "chat_tito": TASK_CHAT_TOKEN + TASK_CHAT_TEXT_OUTPUT, 49 | "chat_tiao": TASK_CHAT_TOKEN + TASK_CHAT_AUDIO_OUTPUT, 50 | "chat_aiao": TASK_CHAT_TOKEN + TASK_CHAT_AUDIO_OUTPUT, 51 | "chat_atiao": TASK_CHAT_TOKEN + TASK_CHAT_AUDIO_OUTPUT, 52 | "chat_aiao_auto": TASK_STREAM_CHAT_TOKEN + TASK_CHAT_AUDIO_OUTPUT, 53 | } 54 | 55 | 56 | def _get_zero_audio_pad(token_num): 57 | return np.zeros(SEG_LEN*token_num) 58 | 59 | def _wrapper_audio_tokens(audio_tokens, num_codebooks, codebook_size): 60 | ret_audio_tokens = [] 61 | for n in range(num_codebooks): 62 | audio_token = audio_tokens[n] 63 | ret_audio_tokens.append(''.join([AUDIO_TOKEN_FORMAT.format(au + n*codebook_size) if isinstance(au, int) else au for au in audio_token])) 64 | return ret_audio_tokens 65 | 66 | def _wrapper_audio_tokens_autonomous(audio_tokens, num_codebooks, codebook_size, audio_token_min_id): 67 | ret_audio_tokens = [] 68 | for n in range(num_codebooks): 69 | audio_token = audio_tokens[n] 70 | ret_audio_tokens.append([(au + n*codebook_size + audio_token_min_id) for au in audio_token]) 71 | return ret_audio_tokens 72 | 73 | # Item format 74 | # { 75 | # "instruction": "", 76 | # "conversations": [ 77 | # { 78 | # "from": "user" or "assistant", 79 | # "text": "", 80 | # "audio": { 81 | # "array": [], 82 | # "sr": 16000, 83 | # "bytes": "", 84 | # "file": "", 85 | # }, 86 | # } 87 | # ], 88 | # } 89 | def _token_input_format(item, tokenizer, tokenizer_voila, dataset_cfg): 90 | task_type = dataset_cfg["task_type"] 91 | num_codebooks = dataset_cfg["num_codebooks"] 92 | codebook_size = 
dataset_cfg["codebook_size"] 93 | 94 | task_token = TASK_TYPE_CONF[task_type] 95 | 96 | # Construct system message 97 | system = item["instruction"] 98 | if task_type in ["chat_aiao", "chat_atiao", "chat_tiao"]: 99 | system = DEFAULT_CHAT_REF_START_TOKEN + DEFAULT_CHAT_REF_TOKEN + DEFAULT_CHAT_REF_END_TOKEN + system 100 | elif task_type == "chat_tts": 101 | system = DEFAULT_TTS_REF_START_TOKEN + DEFAULT_TTS_REF_TOKEN + DEFAULT_TTS_REF_END_TOKEN + system 102 | else: 103 | print (f"task type {task_type} do not use ref.") 104 | system = task_token + system 105 | system = DEFAULT_SYSTEM_START_TOKEN + system + DEFAULT_SYSTEM_END_TOKEN 106 | 107 | # Get ids for system 108 | system_ids = tokenizer.encode(system, add_special_tokens=False) 109 | 110 | # Copy into num_codebooks input ids 111 | input_ids_list = [] 112 | for _ in range(num_codebooks): 113 | input_ids_list.append(copy.deepcopy(system_ids)) 114 | 115 | # Assemble conversations 116 | for i, turn in enumerate(item["conversations"]): 117 | if turn['from'] == 'assistant': 118 | # task with audio token as input, prepare audio token 119 | if task_type in ["chat_aiao", "chat_tts"]: 120 | if "audio" not in turn: 121 | content = DEFAULT_ASSISTANT_TOKEN 122 | content_ids = tokenizer.encode(content, add_special_tokens=False) 123 | for n in range(num_codebooks): 124 | input_ids_list[n] += copy.deepcopy(content_ids) 125 | else: 126 | # Load audio 127 | if 'array' in turn['audio']: 128 | assert "sr" in turn["audio"] 129 | if len(turn["audio"]['array'].shape) > 1: 130 | assert turn["audio"]['array'].shape[0] <= 2 131 | turn["audio"]['array'] = librosa.to_mono(turn["audio"]['array']) 132 | audio = librosa.resample(turn["audio"]['array'], orig_sr=turn["audio"]["sr"], target_sr=AUDIO_SR) 133 | elif "bytes" in turn['audio']: 134 | audio, _ = librosa.load(io.BytesIO(turn["audio"]['bytes']), sr=AUDIO_SR) 135 | elif "file" in turn['audio']: 136 | audio, _ = librosa.load(turn["audio"]['file'], sr=AUDIO_SR) 137 | else: 138 | raise Exception(f"No audio input for task {task_type}") 139 | 140 | # get audio token 141 | audio_tokens = tokenizer_voila.encode(audio, sr=AUDIO_SR) 142 | audio_tokens = audio_tokens.cpu().numpy().tolist() 143 | audio_tokens = _wrapper_audio_tokens(audio_tokens, num_codebooks, codebook_size) 144 | 145 | for n in range(num_codebooks): 146 | content = DEFAULT_ASSISTANT_TOKEN + audio_tokens[n] + tokenizer.eos_token 147 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True, 148 | max_length=tokenizer.model_max_length) 149 | input_ids_list[n] += content_ids 150 | 151 | elif task_type in ["chat_tito", "chat_asr"]: 152 | if "text" not in turn: 153 | content = DEFAULT_ASSISTANT_TOKEN 154 | content_ids = tokenizer.encode(content, add_special_tokens=False) 155 | for n in range(num_codebooks): 156 | input_ids_list[n] += copy.deepcopy(content_ids) 157 | else: 158 | text = turn['text'].strip() 159 | content = DEFAULT_ASSISTANT_TOKEN + text + tokenizer.eos_token 160 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True, 161 | max_length=tokenizer.model_max_length) 162 | for n in range(num_codebooks): 163 | input_ids_list[n] += copy.deepcopy(content_ids) 164 | else: 165 | raise ValueError (f"[Error] Invalid data type of {task_type}.") 166 | else: 167 | # task with audio token as input, prepare audio token 168 | if task_type in ["chat_aiao", "chat_asr"]: 169 | # Load audio 170 | assert "audio" in turn 171 | if 'array' in turn['audio']: 172 | assert "sr" in turn["audio"] 173 | if 
len(turn["audio"]['array'].shape) > 1: 174 | assert turn["audio"]['array'].shape[0] <= 2 175 | turn["audio"]['array'] = librosa.to_mono(turn["audio"]['array']) 176 | audio = librosa.resample(turn["audio"]['array'], orig_sr=turn["audio"]["sr"], target_sr=AUDIO_SR) 177 | elif "bytes" in turn['audio']: 178 | audio, _ = librosa.load(io.BytesIO(turn["audio"]['bytes']), sr=AUDIO_SR) 179 | elif "file" in turn['audio']: 180 | audio, _ = librosa.load(turn["audio"]['file'], sr=AUDIO_SR) 181 | else: 182 | raise Exception(f"No audio input for task {task_type}") 183 | 184 | # get audio token 185 | audio_tokens = tokenizer_voila.encode(audio, sr=AUDIO_SR) 186 | audio_tokens = audio_tokens.cpu().numpy().tolist() 187 | audio_tokens = _wrapper_audio_tokens(audio_tokens, num_codebooks, codebook_size) 188 | 189 | for n in range(num_codebooks): 190 | content = DEFAULT_HUMAN_TOKEN + audio_tokens[n] 191 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True, 192 | max_length=tokenizer.model_max_length) 193 | input_ids_list[n] += copy.deepcopy(content_ids) 194 | elif task_type in ["chat_tito", "chat_tts"]: 195 | text = turn['text'].strip() 196 | content = DEFAULT_HUMAN_TOKEN + text 197 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True, 198 | max_length=tokenizer.model_max_length) 199 | for n in range(num_codebooks): 200 | input_ids_list[n] += copy.deepcopy(content_ids) 201 | else: 202 | raise ValueError (f"[Error] Invalid data type of {task_type}.") 203 | 204 | for n in range(num_codebooks): 205 | input_ids_list[n] = input_ids_list[n][:tokenizer.model_max_length] 206 | 207 | return input_ids_list, None, None, None 208 | 209 | def _token_input_format_autonomous(item, tokenizer, tokenizer_voila, dataset_cfg): 210 | task_type = dataset_cfg["task_type"] 211 | num_codebooks = dataset_cfg["num_codebooks"] 212 | codebook_size = dataset_cfg["codebook_size"] 213 | assert task_type == "chat_aiao_auto", f"only support chat_aiao_auto, {task_type} is invalid" 214 | 215 | DEFAULT_HUMAN_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_HUMAN_TOKEN) 216 | assert isinstance(DEFAULT_HUMAN_TOKEN_ID, int), "DEFAULT_HUMAN_TOKEN_ID should be an integer" 217 | AUDIO_MIN_TOKEN_ID = tokenizer.convert_tokens_to_ids(AUDIO_TOKEN_FORMAT.format(0)) 218 | assert isinstance(AUDIO_MIN_TOKEN_ID, int), "AUDIO_MIN_TOKEN_ID should be an integer" 219 | 220 | task_token = TASK_TYPE_CONF[task_type] 221 | 222 | # Construct system message 223 | system = DEFAULT_CHAT_REF_START_TOKEN + DEFAULT_CHAT_REF_TOKEN + DEFAULT_CHAT_REF_END_TOKEN 224 | system = task_token + system 225 | system = DEFAULT_SYSTEM_START_TOKEN + system + DEFAULT_SYSTEM_END_TOKEN 226 | 227 | # Get ids for system 228 | system_ids_list = [[], []] 229 | system_ids = tokenizer.encode(system, add_special_tokens=False) 230 | 231 | # Insert instruction tokens into system prompt tokens 232 | instruction = item["instruction"] 233 | if instruction != "": 234 | instruction_ids = tokenizer.encode(instruction, add_special_tokens=False) 235 | else: 236 | instruction_ids = [] 237 | 238 | system_ids_list[0] = system_ids[:-1] + instruction_ids + system_ids[-1:] 239 | system_ids_list[1] = system_ids[:-1] + instruction_ids + system_ids[-1:] 240 | 241 | # Copy into num_codebooks input ids 242 | channel1_input_ids_list = [[] for _ in range(num_codebooks)] 243 | channel2_input_ids_list = [[] for _ in range(num_codebooks)] 244 | for n in range(num_codebooks): 245 | channel1_input_ids_list[n] += copy.deepcopy(system_ids_list[0]) + 
[DEFAULT_HUMAN_TOKEN_ID] 246 | channel2_input_ids_list[n] += copy.deepcopy(system_ids_list[1]) + [DEFAULT_HUMAN_TOKEN_ID] 247 | 248 | # prepare audio token to simulate streaming input 249 | audio_meta = item['conversations'][0]['audio'] 250 | if 'array' in audio_meta: 251 | assert "sr" in audio_meta 252 | if len(audio_meta['array'].shape) > 1: 253 | assert audio_meta['array'].shape[0] <= 2 254 | audio_meta['array'] = librosa.to_mono(audio_meta['array']) 255 | audio = librosa.resample(audio_meta['array'], orig_sr=audio_meta["sr"], target_sr=AUDIO_SR) 256 | elif "bytes" in audio_meta: 257 | audio, _ = librosa.load(io.BytesIO(audio_meta['bytes']), sr=AUDIO_SR) 258 | elif "file" in audio_meta: 259 | audio, _ = librosa.load(audio_meta['file'], sr=AUDIO_SR) 260 | else: 261 | raise Exception(f"No audio input for task {task_type}") 262 | 263 | # get audio token 264 | streaming_user_input_audio_tokens = tokenizer_voila.encode(audio, sr=AUDIO_SR) 265 | streaming_user_input_audio_tokens = streaming_user_input_audio_tokens.cpu().numpy().tolist() 266 | streaming_user_input_audio_tokens = _wrapper_audio_tokens_autonomous(streaming_user_input_audio_tokens, num_codebooks, codebook_size, AUDIO_MIN_TOKEN_ID) 267 | 268 | return [channel1_input_ids_list, channel2_input_ids_list], None, None, streaming_user_input_audio_tokens 269 | 270 | def _alpha_audio_input_format(item, tokenizer, dataset_cfg): 271 | task_type = dataset_cfg["task_type"] 272 | num_codebooks = dataset_cfg["num_codebooks"] 273 | codebook_size = dataset_cfg["codebook_size"] 274 | 275 | task_token = TASK_TYPE_CONF[task_type] 276 | 277 | # Construct system message 278 | system = item["instruction"] 279 | if task_type in ["chat_aiao", "chat_atiao", "chat_tiao"]: 280 | system = DEFAULT_CHAT_REF_START_TOKEN + DEFAULT_CHAT_REF_TOKEN + DEFAULT_CHAT_REF_END_TOKEN + system 281 | elif task_type == "chat_tts": 282 | system = DEFAULT_TTS_REF_START_TOKEN + DEFAULT_TTS_REF_TOKEN + DEFAULT_TTS_REF_END_TOKEN + system 283 | else: 284 | print(f"task type {task_type} does not use ref.") 285 | system = task_token + system 286 | system = DEFAULT_SYSTEM_START_TOKEN + system + DEFAULT_SYSTEM_END_TOKEN 287 | 288 | # Get ids for system 289 | system_ids = tokenizer.encode(system, add_special_tokens=False) 290 | 291 | # Copy into num_codebooks input ids 292 | input_ids_list = [] 293 | for _ in range(num_codebooks): 294 | input_ids_list.append(copy.deepcopy(system_ids)) 295 | 296 | # Construct audio data and mask 297 | audio_data = [np.array([0]*PREPEND_LEN)] 298 | audio_data.append(_get_zero_audio_pad(len(system_ids))) 299 | audio_data_mask = [0] * len(system_ids) 300 | 301 | # Assemble conversations 302 | for i, turn in enumerate(item["conversations"]): 303 | if turn['from'] == 'assistant': 304 | # task with audio token as output, prepare audio token 305 | if task_type in ["chat_aiao"]: 306 | if "audio" not in turn: 307 | content = DEFAULT_ASSISTANT_TOKEN 308 | content_ids = tokenizer.encode(content, add_special_tokens=False) 309 | for n in range(num_codebooks): 310 | input_ids_list[n] += copy.deepcopy(content_ids) 311 | # preprocess audio_data & audio_data_mask 312 | audio_data.append(_get_zero_audio_pad(len(content_ids))) 313 | audio_data_mask += [0] * len(content_ids) 314 | else: 315 | # Load audio 316 | if 'array' in turn['audio']: 317 | assert "sr" in turn["audio"] 318 | if len(turn["audio"]['array'].shape) > 1: 319 | assert turn["audio"]['array'].shape[0] <= 2 320 | turn["audio"]['array'] = librosa.to_mono(turn["audio"]['array']) 321 | audio =
librosa.resample(turn["audio"]['array'], orig_sr=turn["audio"]["sr"], target_sr=AUDIO_SR) 322 | elif "bytes" in turn['audio']: 323 | audio, _ = librosa.load(io.BytesIO(turn["audio"]['bytes']), sr=AUDIO_SR) 324 | elif "file" in turn['audio']: 325 | audio, _ = librosa.load(turn["audio"]['file'], sr=AUDIO_SR) 326 | else: 327 | raise Exception(f"No audio input for task {task_type}") 328 | 329 | # get audio token 330 | audio_token_num = int(len(audio) / SEG_LEN) 331 | audio_token = [DEFAULT_AUDIO_TOKEN] * audio_token_num 332 | audio_token = ''.join(audio_token) 333 | audio = audio[:SEG_LEN*audio_token_num] # trim audio 334 | 335 | content = DEFAULT_ASSISTANT_TOKEN + audio_token + tokenizer.eos_token 336 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True, 337 | max_length=tokenizer.model_max_length) 338 | for n in range(num_codebooks): 339 | input_ids_list[n] += copy.deepcopy(content_ids) 340 | 341 | audio_data.append(_get_zero_audio_pad(1)) 342 | audio_data_mask += [0] 343 | audio_data.append(audio) 344 | audio_data_mask += [1] * audio_token_num 345 | audio_data.append(_get_zero_audio_pad(1)) 346 | audio_data_mask += [0] 347 | elif task_type in ["chat_tito"]: 348 | if "text" not in turn: 349 | content = DEFAULT_ASSISTANT_TOKEN 350 | content_ids = tokenizer.encode(content, add_special_tokens=False) 351 | for n in range(num_codebooks): 352 | input_ids_list[n] += copy.deepcopy(content_ids) 353 | # preprocess audio_data & audio_data_mask 354 | audio_data.append(_get_zero_audio_pad(len(content_ids))) 355 | audio_data_mask += [0] * len(content_ids) 356 | else: 357 | text = turn['text'].strip() 358 | content = DEFAULT_ASSISTANT_TOKEN + text + tokenizer.eos_token 359 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True, 360 | max_length=tokenizer.model_max_length) 361 | for n in range(num_codebooks): 362 | input_ids_list[n] += copy.deepcopy(content_ids) 363 | audio_data.append(_get_zero_audio_pad(len(content_ids))) 364 | audio_data_mask += [0] * len(content_ids) 365 | else: 366 | raise ValueError(f"[Error] Invalid task type: {task_type}.") 367 | else: 368 | # task with audio token as input, prepare audio token 369 | if task_type in ["chat_aiao"]: 370 | # Load audio 371 | assert "audio" in turn 372 | if 'array' in turn['audio']: 373 | assert "sr" in turn["audio"] 374 | if len(turn["audio"]['array'].shape) > 1: 375 | assert turn["audio"]['array'].shape[0] <= 2 376 | turn["audio"]['array'] = librosa.to_mono(turn["audio"]['array']) 377 | audio = librosa.resample(turn["audio"]['array'], orig_sr=turn["audio"]["sr"], target_sr=AUDIO_SR) 378 | elif "bytes" in turn['audio']: 379 | audio, _ = librosa.load(io.BytesIO(turn["audio"]['bytes']), sr=AUDIO_SR) 380 | elif "file" in turn['audio']: 381 | audio, _ = librosa.load(turn["audio"]['file'], sr=AUDIO_SR) 382 | else: 383 | raise Exception(f"No audio input for task {task_type}") 384 | 385 | # get audio token 386 | audio_token_num = int(len(audio) / SEG_LEN) 387 | audio_token = [DEFAULT_AUDIO_TOKEN] * audio_token_num 388 | audio_token = ''.join(audio_token) 389 | audio = audio[:SEG_LEN*audio_token_num] # trim audio 390 | 391 | content = DEFAULT_HUMAN_TOKEN + audio_token 392 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True, 393 | max_length=tokenizer.model_max_length) 394 | for n in range(num_codebooks): 395 | input_ids_list[n] += copy.deepcopy(content_ids) 396 | 397 | audio_data.append(_get_zero_audio_pad(1)) 398 | audio_data_mask += [0] 399 |
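# --- Editor's sketch (not part of the original file) --------------------------
# The alpha-audio format keeps text positions and raw waveform aligned: every
# DEFAULT_AUDIO_TOKEN placeholder in input_ids stands for exactly SEG_LEN raw
# samples, and audio_data_mask is 1 only at positions backed by real audio.
# A minimal illustration of the trimming invariant (names are assumptions):
#
#     def segment_audio(audio, seg_len):
#         n = len(audio) // seg_len        # one placeholder token per segment
#         return audio[:n * seg_len], n    # drop the ragged tail
#
# e.g. 5.3 segments of waveform -> 5 placeholder tokens, the mask extended by
# [1, 1, 1, 1, 1], and the trailing 0.3 segment of samples discarded.
# ------------------------------------------------------------------------------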
audio_data.append(audio) 400 | audio_data_mask += [1] * audio_token_num 401 | elif task_type in ["chat_tito"]: 402 | text = turn['text'].strip() 403 | content = DEFAULT_HUMAN_TOKEN + text 404 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True, 405 | max_length=tokenizer.model_max_length) 406 | for n in range(num_codebooks): 407 | input_ids_list[n] += copy.deepcopy(content_ids) 408 | audio_data.append(_get_zero_audio_pad(len(content_ids))) 409 | audio_data_mask += [0] * len(content_ids) 410 | else: 411 | raise ValueError(f"[Error] Invalid task type: {task_type}.") 412 | 413 | for n in range(num_codebooks): 414 | input_ids_list[n] = input_ids_list[n][:tokenizer.model_max_length] 415 | audio_data_mask = audio_data_mask[:tokenizer.model_max_length] 416 | audio_data = np.concatenate(audio_data) 417 | audio_data = audio_data[:PREPEND_LEN + tokenizer.model_max_length*SEG_LEN] 418 | 419 | return input_ids_list, audio_data, audio_data_mask, None 420 | 421 | # Item format 422 | # { 423 | # "instruction": "", 424 | # "conversations": [ 425 | # { 426 | # "from": "user" or "assistant", 427 | # "text": "", 428 | # "audio": { 429 | # "array": [], 430 | # "sr": 16000, 431 | # "bytes": "", 432 | # "file": "", 433 | # }, 434 | # } 435 | # ], 436 | # } 437 | def voila_input_format(item, tokenizer, tokenizer_voila, dataset_cfg): 438 | if dataset_cfg["input_type"] == "audio": 439 | return _alpha_audio_input_format(item, tokenizer, dataset_cfg) 440 | elif dataset_cfg["input_type"] == "autonomous": 441 | return _token_input_format_autonomous(item, tokenizer, tokenizer_voila, dataset_cfg) 442 | else: 443 | return _token_input_format(item, tokenizer, tokenizer_voila, dataset_cfg) 444 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Tuple, Union, Dict, Any 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from torch.nn import CrossEntropyLoss 9 | 10 | from transformers.cache_utils import Cache, DynamicCache 11 | from transformers.utils import ModelOutput, logging 12 | from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel 13 | 14 | from audio_transformer import AudioTransformer 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | 19 | # Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L43 20 | class LayerNorm(torch.nn.LayerNorm): 21 | """Layer norm with transpose""" 22 | 23 | def forward(self, input: torch.Tensor) -> torch.Tensor: 24 | x = input.transpose(-2, -1) 25 | x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 26 | x = x.transpose(-2, -1) 27 | return x 28 | 29 | # Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L53 30 | class ConvLayerBlock(torch.nn.Module): 31 | """Convolution unit of FeatureExtractor""" 32 | 33 | def __init__( 34 | self, 35 | in_channels: int, 36 | out_channels: int, 37 | kernel_size: int, 38 | stride: int, 39 | bias: bool, 40 | layer_norm: Optional[torch.nn.Module], 41 | ): 42 | super().__init__() 43 | self.kernel_size = kernel_size 44 | self.stride = stride 45 | self.layer_norm = layer_norm 46 | self.conv = torch.nn.Conv1d( 47 | in_channels=in_channels, 48 | out_channels=out_channels, 49 | kernel_size=kernel_size, 50 |
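# --- Editor's note (not part of the original file) ----------------------------
# Each ConvLayerBlock shortens the time axis like any strided, unpadded Conv1d:
#
#     out_frames = (in_frames - kernel_size) // stride + 1
#
# With the default FeatureExtractor shapes used further below (strides
# 5, 2, 2, 2, 2, 2, 2), the hop sizes multiply to 5 * 2**6 = 320 samples per
# output frame, i.e. roughly 50 frames/s if the input audio is 16 kHz (an
# assumption based on the sample rates used elsewhere in this repo).
# ------------------------------------------------------------------------------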
stride=stride, 51 | bias=bias, 52 | ) 53 | 54 | def forward( 55 | self, 56 | x: torch.Tensor, 57 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 58 | """ 59 | Args: 60 | x (Tensor): Shape: ``[batch, in_channels, in_frame]``. 61 | Returns: 62 | Tensor: Shape ``[batch, out_channels, out_frames]``. 63 | Optional[Tensor]: Shape ``[batch, ]``. 64 | """ 65 | x = self.conv(x) 66 | if self.layer_norm is not None: 67 | x = self.layer_norm(x) 68 | x = torch.nn.functional.gelu(x) 69 | 70 | return x 71 | 72 | # Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L146 73 | class FeatureProjection(torch.nn.Module): 74 | """Layer that connects FeatureExtractor and Encoder 75 | 76 | Projects features to encoder dimension. 77 | 78 | Args: 79 | in_features (int): Input feature dim. 80 | out_features (int): Output feature dim. 81 | dropout (float): Dropout probability. 82 | """ 83 | 84 | def __init__( 85 | self, 86 | in_features: int, 87 | out_features: int, 88 | dropout=0.1, 89 | ): 90 | super().__init__() 91 | self.layer_norm = torch.nn.LayerNorm(in_features) 92 | self.projection = torch.nn.Linear( 93 | in_features, 94 | out_features, 95 | ) 96 | self.dropout = torch.nn.Dropout(dropout) 97 | 98 | def forward(self, x): 99 | """ 100 | Args: 101 | x (Tensor): 102 | Feature Tensor. shape: ``[batch, frame, in_feature]`` 103 | Returns: 104 | Tensor: Projected features. ``[batch, frame, out_feature]``. 105 | """ 106 | x = self.layer_norm(x) 107 | x = self.projection(x) 108 | x = self.dropout(x) 109 | return x 110 | 111 | # Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102 112 | class FeatureExtractor(torch.nn.Module): 113 | """Extract features from audio 114 | 115 | Args: 116 | conv_layers (nn.ModuleList): 117 | convolution layers 118 | """ 119 | 120 | def __init__( 121 | self, 122 | shapes=[(512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 2, 2), (512, 2, 2)], 123 | bias=False, 124 | norm_mode="group_norm", 125 | ): 126 | super().__init__() 127 | if norm_mode not in ["group_norm", "layer_norm"]: 128 | raise ValueError("Invalid norm mode") 129 | blocks = [] 130 | in_channels = 1 131 | for i, (out_channels, kernel_size, stride) in enumerate(shapes): 132 | normalization = None 133 | if norm_mode == "group_norm" and i == 0: 134 | normalization = torch.nn.GroupNorm( 135 | num_groups=out_channels, 136 | num_channels=out_channels, 137 | affine=True, 138 | ) 139 | elif norm_mode == "layer_norm": 140 | normalization = LayerNorm( 141 | normalized_shape=out_channels, 142 | elementwise_affine=True, 143 | ) 144 | blocks.append( 145 | ConvLayerBlock( 146 | in_channels=in_channels, 147 | out_channels=out_channels, 148 | kernel_size=kernel_size, 149 | stride=stride, 150 | bias=bias, 151 | layer_norm=normalization, 152 | ) 153 | ) 154 | in_channels = out_channels 155 | self.conv_layers = torch.nn.ModuleList(blocks) 156 | 157 | def forward( 158 | self, 159 | x: torch.Tensor, 160 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 161 | """ 162 | Args: 163 | x (Tensor): 164 | Input Tensor representing a batch of audio, 165 | shape: ``[batch, time]``. 166 | 167 | Returns: 168 | Tensor: 169 | The resulting feature, shape: ``[batch, frame, feature]`` 170 | Optional[Tensor]: 171 | Valid length of each output sample. shape: ``[batch, ]``. 172 | """ 173 | if x.ndim != 2: 174 | raise ValueError(f"Expected the input Tensor to be 2D (batch, time). 
Found: {list(x.shape)}") 175 | 176 | x = x.unsqueeze(1) # (batch, channel==1, frame) 177 | for layer in self.conv_layers: 178 | x = layer(x) # (batch, feature, frame) 179 | x = x.transpose(1, 2) # (batch, frame, feature) 180 | return x 181 | 182 | # Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102 183 | class FeatureExtractorAdapter(torch.nn.Module): 184 | """Extract features from audio 185 | 186 | Args: 187 | conv_layers (nn.ModuleList): 188 | convolution layers 189 | """ 190 | 191 | def __init__( 192 | self, 193 | shapes=(512, 512, 2, 2), 194 | hidden_size=2048, 195 | bias=False, 196 | norm_mode="group_norm", 197 | ): 198 | super().__init__() 199 | if norm_mode not in ["group_norm", "layer_norm"]: 200 | raise ValueError("Invalid norm mode") 201 | in_channels, out_channels, kernel_size, stride = shapes 202 | normalization = LayerNorm( 203 | normalized_shape=out_channels, 204 | elementwise_affine=True, 205 | ) 206 | self.conv_layers = ConvLayerBlock( 207 | in_channels=in_channels, 208 | out_channels=out_channels, 209 | kernel_size=kernel_size, 210 | stride=stride, 211 | bias=False, 212 | layer_norm=normalization, 213 | ) 214 | self.feat_proj = FeatureProjection(out_channels, hidden_size) 215 | 216 | def forward( 217 | self, 218 | x: torch.Tensor, 219 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 220 | """ 221 | Args: 222 | x (Tensor): 223 | Input Tensor representing a batch of audio, 224 | shape: ``[batch, time]``. 225 | 226 | Returns: 227 | Tensor: 228 | The resulting feature, shape: ``[batch, frame, feature]`` 229 | Optional[Tensor]: 230 | Valid length of each output sample. shape: ``[batch, ]``. 231 | """ 232 | x = x.transpose(1, 2) # (batch, feature, frame) 233 | x = self.conv_layers(x) # (batch, feature, frame) 234 | x = x.transpose(1, 2) # (batch, frame, feature) 235 | x = self.feat_proj(x) 236 | return x 237 | 238 | @dataclass 239 | class VoilaOutput(ModelOutput): 240 | """ 241 | Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_outputs.py#L678 242 | 243 | Base class for Voila outputs. 244 | 245 | Args: 246 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 247 | Language modeling loss (for next-token prediction). 248 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): 249 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 250 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 251 | The hidden state of the last attention layer. 252 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 253 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 254 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) 255 | 256 | Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see 257 | `past_key_values` input) to speed up sequential decoding. 258 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 259 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 260 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 
261 | 262 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 263 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 264 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 265 | sequence_length)`. 266 | 267 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 268 | heads. 269 | """ 270 | 271 | loss: Optional[torch.FloatTensor] = None 272 | logits: torch.FloatTensor = None 273 | last_hidden_state: torch.FloatTensor = None 274 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 275 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 276 | attentions: Optional[Tuple[torch.FloatTensor]] = None 277 | voila_pred: Optional[torch.FloatTensor] = None 278 | 279 | 280 | # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103 281 | class VoilaModel(LlamaPreTrainedModel): 282 | _tied_weights_keys = ["lm_head.weight"] 283 | 284 | def __init__(self, config): 285 | super().__init__(config) 286 | self.model = LlamaModel(config) 287 | self.vocab_size = config.vocab_size 288 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 289 | self.pad_vocab_size_multiple = 64 290 | 291 | self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True) 292 | self.audio_transformer = AudioTransformer(config, use_sdpa=False) 293 | 294 | # Initialize weights and apply final processing 295 | self.post_init() 296 | 297 | def get_input_embeddings(self): 298 | return self.model.embed_tokens 299 | 300 | def set_input_embeddings(self, value): 301 | self.model.embed_tokens = value 302 | 303 | def get_output_embeddings(self): 304 | return self.lm_head 305 | 306 | def set_output_embeddings(self, new_embeddings): 307 | self.lm_head = new_embeddings 308 | 309 | def set_decoder(self, decoder): 310 | self.model = decoder 311 | 312 | def get_decoder(self): 313 | return self.model 314 | 315 | def forward( 316 | self, 317 | input_ids: torch.LongTensor = None, 318 | attention_mask: Optional[torch.Tensor] = None, 319 | position_ids: Optional[torch.LongTensor] = None, 320 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 321 | inputs_embeds: Optional[torch.FloatTensor] = None, 322 | labels: Optional[torch.LongTensor] = None, 323 | audio_labels: Optional[torch.LongTensor] = None, 324 | ref_embs: Optional[List[torch.Tensor]] = None, 325 | ref_embs_mask: Optional[torch.LongTensor] = None, 326 | use_cache: Optional[bool] = None, 327 | output_attentions: Optional[bool] = None, 328 | output_hidden_states: Optional[bool] = None, 329 | return_dict: Optional[bool] = None, 330 | cache_position: Optional[torch.LongTensor] = None, 331 | num_logits_to_keep: int = 0, 332 | ) -> Union[Tuple, VoilaOutput]: 333 | r""" 334 | Args: 335 | input_ids: [bs, seq_len, num_codebooks] 336 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 337 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 338 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 339 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
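        Shape example (editor's illustration, not part of the original docstring): with
        num_codebooks == 4 (an assumed value), input_ids of shape [bs, seq_len, 4] is
        embedded per codebook and mean-pooled over the codebook axis, so the Llama
        backbone still receives inputs_embeds of shape [bs, seq_len, hidden_size]:

            >>> emb = self.model.embed_tokens(input_ids)  # [bs, seq_len, 4, hidden_size]
            >>> emb = emb.mean(dim=2)                     # [bs, seq_len, hidden_size]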
340 | """ 341 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 342 | output_hidden_states = ( 343 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 344 | ) 345 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 346 | 347 | if input_ids is not None and inputs_embeds is not None: 348 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 349 | if inputs_embeds is None: 350 | inputs_embeds = self.model.embed_tokens(input_ids) 351 | assert len(inputs_embeds.shape) == 4 352 | if len(inputs_embeds.shape) == 4: 353 | inputs_embeds = inputs_embeds.mean(dim=2) 354 | 355 | if self.training or \ 356 | (past_key_values is None and ref_embs is not None) or \ 357 | (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None): 358 | ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype)) 359 | ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1) 360 | # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back) 361 | padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0) 362 | ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0) 363 | inputs_embeds = inputs_embeds + ref_embs 364 | 365 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 366 | outputs = self.model( 367 | attention_mask=attention_mask, 368 | position_ids=position_ids, 369 | past_key_values=past_key_values, 370 | inputs_embeds=inputs_embeds, 371 | use_cache=use_cache, 372 | output_attentions=output_attentions, 373 | output_hidden_states=output_hidden_states, 374 | return_dict=return_dict, 375 | cache_position=cache_position, 376 | ) 377 | 378 | hidden_states = outputs[0] 379 | if self.config.pretraining_tp > 1: 380 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) 381 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] 382 | logits = torch.cat(logits, dim=-1) 383 | else: 384 | # Only compute necessary logits, and do not upcast them to float if we are not computing the loss 385 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 386 | 387 | loss = None 388 | 389 | if not return_dict: 390 | output = (logits,) + outputs[1:] 391 | return (loss,) + output if loss is not None else output 392 | 393 | return VoilaOutput( 394 | loss=loss, 395 | logits=logits, 396 | last_hidden_state=hidden_states, 397 | past_key_values=outputs.past_key_values, 398 | hidden_states=outputs.hidden_states, 399 | attentions=outputs.attentions, 400 | ) 401 | 402 | def _prepare_inputs_for_generation( 403 | self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 404 | ): 405 | if past_key_values is not None and past_key_values.get_seq_length() > 0: 406 | if isinstance(past_key_values, Cache): 407 | cache_length = past_key_values.get_seq_length() 408 | past_length = past_key_values.seen_tokens 409 | max_cache_length = past_key_values.get_max_cache_shape() 410 | else: 411 | cache_length = past_length = past_key_values[0][0].shape[2] 412 | max_cache_length = None 413 | 414 | # Keep only the unprocessed tokens: 415 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 416 | # some of the inputs are exclusively passed as part of 
the cache (e.g. when passing input_embeds as 417 | # input) 418 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 419 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 420 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 421 | # input_ids based on the past_length. 422 | elif past_length < input_ids.shape[1]: 423 | input_ids = input_ids[:, past_length:] 424 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 425 | 426 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 427 | if ( 428 | max_cache_length is not None 429 | and attention_mask is not None 430 | and cache_length + input_ids.shape[1] > max_cache_length 431 | ): 432 | attention_mask = attention_mask[:, -max_cache_length:] 433 | 434 | position_ids = kwargs.get("position_ids", None) 435 | if attention_mask is not None and position_ids is None: 436 | # create position_ids on the fly for batch generation 437 | position_ids = attention_mask.long().cumsum(-1) - 1 438 | position_ids.masked_fill_(attention_mask == 0, 1) 439 | if past_key_values: 440 | position_ids = position_ids[:, -input_ids.shape[1] :] 441 | 442 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 443 | if inputs_embeds is None and \ 444 | (past_key_values is None or past_key_values.get_seq_length() <= 0): 445 | inputs_embeds = self.model.embed_tokens(input_ids) 446 | if inputs_embeds is not None and \ 447 | (past_key_values is None or past_key_values.get_seq_length() <= 0): 448 | model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask} 449 | else: 450 | model_inputs = {"input_ids": input_ids, "ref_embs": None} 451 | 452 | model_inputs.update( 453 | { 454 | "position_ids": position_ids, 455 | "past_key_values": past_key_values, 456 | "use_cache": kwargs.get("use_cache"), 457 | "attention_mask": attention_mask, 458 | } 459 | ) 460 | return model_inputs 461 | 462 | def _update_model_kwargs_for_generation( 463 | self, 464 | outputs, 465 | model_kwargs: Dict[str, Any], 466 | num_new_token: int = 1, 467 | ) -> Dict[str, Any]: 468 | # update past_key_values 469 | model_kwargs["past_key_values"] = outputs.past_key_values 470 | 471 | # update attention mask 472 | if "attention_mask" in model_kwargs: 473 | attention_mask = model_kwargs["attention_mask"] 474 | model_kwargs["attention_mask"] = torch.cat( 475 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1 476 | ) 477 | 478 | return model_kwargs 479 | 480 | def _prepare_attention_mask_for_generation( 481 | self, 482 | inputs: torch.Tensor, 483 | pad_token_id: Optional[int], 484 | eos_token_id: Optional[Union[int, List[int]]], 485 | ) -> torch.LongTensor: 486 | is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] 487 | is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) 488 | if isinstance(eos_token_id, int): 489 | eos_token_id = [eos_token_id] 490 | is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) 491 | 492 | # Check if input is input_ids and padded -> only then is attention_mask defined 493 | if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: 494 | return inputs.ne(pad_token_id).long() 495 | else: 496 | return torch.ones(inputs.shape[:2], 
dtype=torch.long, device=inputs.device) 497 | 498 | @torch.inference_mode() 499 | def run_generate( 500 | self, 501 | input_ids: torch.LongTensor, 502 | ref_embs: Optional[List[torch.Tensor]] = None, 503 | ref_embs_mask: Optional[torch.LongTensor] = None, 504 | max_new_tokens: Optional[int] = 128, 505 | pad_token_id: Optional[int] = None, 506 | eos_token_id: Optional[Union[int, List[int]]] = None, 507 | streamer: Optional["BaseStreamer"] = None, 508 | llm_audio_token_id: Optional[int] = None, 509 | min_audio_token_id: Optional[int] = None, 510 | temperature=0.2, 511 | top_k=50, 512 | audio_temperature=0.2, 513 | audio_top_k=50, 514 | ): 515 | assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference" 516 | assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference" 517 | assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], but got {input_ids.shape}" 518 | 519 | eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device) 520 | 521 | # keep track of which sequences are already finished 522 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) 523 | 524 | # Extend input_ids with additional num_codebooks dim 525 | if len(input_ids.shape) == 2: 526 | input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks) # [batch, seq_len] -> [batch, seq_len, num_codebooks] 527 | 528 | this_peer_finished = False # used by synced_gpus only 529 | max_length = input_ids.shape[1] + max_new_tokens 530 | 531 | model_kwargs = { 532 | "use_cache": True, 533 | "past_key_values": DynamicCache(), 534 | "attention_mask": self._prepare_attention_mask_for_generation( 535 | input_ids, pad_token_id, eos_token_id 536 | ), 537 | } 538 | # auto-regressive generation 539 | while True: 540 | # prepare model inputs 541 | model_inputs = self._prepare_inputs_for_generation( 542 | input_ids, 543 | ref_embs=ref_embs, 544 | ref_embs_mask=ref_embs_mask, 545 | **model_kwargs 546 | ) 547 | 548 | # forward pass to get next token 549 | outputs = self( 550 | **model_inputs, 551 | return_dict=True, 552 | ) 553 | audio_tokens = self.audio_transformer.inference( 554 | outputs.last_hidden_state, 555 | temperature=audio_temperature, 556 | top_k=audio_top_k, 557 | ) 558 | audio_tokens = torch.stack( 559 | [ 560 | audio_tokens[:, :, ci] + min_audio_token_id + ci*self.config.codebook_size 561 | for ci in range(self.config.num_codebooks) 562 | ], 563 | dim=2, 564 | ) 565 | 566 | next_token_logits = outputs.logits[:, -1, :] 567 | 568 | # pre-process distribution 569 | # Apply temperature and top-k 570 | if temperature > 0: 571 | next_token_logits = next_token_logits / temperature 572 | if top_k > 0: 573 | top_k = min(top_k, next_token_logits.size(-1)) # Safety check 574 | # Remove all tokens with a probability less than the last token of the top-k 575 | indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] 576 | next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf")) 577 | 578 | # sample 579 | probs = nn.functional.softmax(next_token_logits, dim=-1) 580 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 581 | 582 | # finished sentences should have their next token be a padding token 583 | if eos_token_id is not None: 584 | if pad_token_id is None: 585 | raise ValueError("If `eos_token_id` is defined, make sure
that `pad_token_id` is defined.") 586 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 587 | 588 | # Append NUM_CODEBOOK text tokens or audio_tokens 589 | if len(next_tokens.shape) == 1: 590 | next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks) 591 | next_tokens = torch.where(next_tokens==llm_audio_token_id, audio_tokens, next_tokens) 592 | 593 | input_ids = torch.cat([input_ids, next_tokens], dim=1) 594 | if streamer is not None: 595 | streamer.put(next_tokens.cpu()) 596 | model_kwargs = self._update_model_kwargs_for_generation( 597 | outputs, model_kwargs 598 | ) 599 | 600 | # if eos_token was found in one sentence, set sentence to finished 601 | if eos_token_id_tensor is not None: 602 | unfinished_sequences = unfinished_sequences.mul( 603 | next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1) 604 | ) 605 | 606 | # stop when each sentence is finished 607 | if unfinished_sequences.max() == 0: 608 | this_peer_finished = True 609 | 610 | # stop if we exceed the maximum length 611 | if input_ids.shape[1] >= max_length: 612 | this_peer_finished = True 613 | 614 | if this_peer_finished: 615 | break 616 | 617 | if streamer is not None: 618 | streamer.end() 619 | 620 | return input_ids 621 | 622 | 623 | # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103 624 | class VoilaAudioAlphaModel(LlamaPreTrainedModel): 625 | _tied_weights_keys = ["lm_head.weight"] 626 | 627 | def __init__(self, config): 628 | super().__init__(config) 629 | self.model = LlamaModel(config) 630 | self.vocab_size = config.vocab_size 631 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 632 | self.pad_vocab_size_multiple = 64 633 | 634 | 635 | self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True) 636 | self.audio_transformer = AudioTransformer(config, use_sdpa=False) 637 | 638 | self.feature_extractor = FeatureExtractor() 639 | self.audio_feature_extractor_adapter = FeatureExtractorAdapter(hidden_size=config.hidden_size) 640 | 641 | # Initialize weights and apply final processing 642 | self.post_init() 643 | 644 | def get_input_embeddings(self): 645 | return self.model.embed_tokens 646 | 647 | def set_input_embeddings(self, value): 648 | self.model.embed_tokens = value 649 | 650 | def get_output_embeddings(self): 651 | return self.lm_head 652 | 653 | def set_output_embeddings(self, new_embeddings): 654 | self.lm_head = new_embeddings 655 | 656 | def set_decoder(self, decoder): 657 | self.model = decoder 658 | 659 | def get_decoder(self): 660 | return self.model 661 | 662 | def forward( 663 | self, 664 | input_ids: torch.LongTensor = None, 665 | attention_mask: Optional[torch.Tensor] = None, 666 | position_ids: Optional[torch.LongTensor] = None, 667 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 668 | inputs_embeds: Optional[torch.FloatTensor] = None, 669 | labels: Optional[torch.LongTensor] = None, 670 | audio_labels: Optional[torch.LongTensor] = None, 671 | ref_embs: Optional[List[torch.Tensor]] = None, 672 | ref_embs_mask: Optional[torch.LongTensor] = None, 673 | audio_datas: Optional[torch.FloatTensor] = None, 674 | audio_data_masks: Optional[torch.LongTensor] = None, 675 | use_cache: Optional[bool] = None, 676 | output_attentions: Optional[bool] = None, 677 | output_hidden_states: Optional[bool] = None, 678 | return_dict: Optional[bool] = None, 679 | cache_position: Optional[torch.LongTensor] 
= None, 680 | num_logits_to_keep: int = 0, 681 | ) -> Union[Tuple, VoilaOutput]: 682 | r""" 683 | Args: 684 | input_ids: [bs, seq_len, num_codebooks] 685 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 686 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 687 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 688 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 689 | """ 690 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 691 | output_hidden_states = ( 692 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 693 | ) 694 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 695 | 696 | if input_ids is not None and inputs_embeds is not None: 697 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 698 | if inputs_embeds is None: 699 | inputs_embeds = self.model.embed_tokens(input_ids) 700 | assert len(inputs_embeds.shape) == 4 701 | if len(inputs_embeds.shape) == 4: 702 | inputs_embeds = inputs_embeds.mean(dim=2) 703 | 704 | if self.training or \ 705 | (past_key_values is None and ref_embs is not None) or \ 706 | (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None): 707 | ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype)) 708 | ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1) 709 | # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back) 710 | padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0) 711 | ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0) 712 | inputs_embeds = inputs_embeds + ref_embs 713 | 714 | if self.training or audio_datas is not None: 715 | audio_embeds = self.feature_extractor(audio_datas) 716 | audio_embeds = self.audio_feature_extractor_adapter(audio_embeds) 717 | audio_embeds = audio_embeds * audio_data_masks[..., None] 718 | inputs_embeds = inputs_embeds + audio_embeds 719 | 720 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 721 | outputs = self.model( 722 | attention_mask=attention_mask, 723 | position_ids=position_ids, 724 | past_key_values=past_key_values, 725 | inputs_embeds=inputs_embeds, 726 | use_cache=use_cache, 727 | output_attentions=output_attentions, 728 | output_hidden_states=output_hidden_states, 729 | return_dict=return_dict, 730 | cache_position=cache_position, 731 | ) 732 | 733 | hidden_states = outputs[0] 734 | if self.config.pretraining_tp > 1: 735 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) 736 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] 737 | logits = torch.cat(logits, dim=-1) 738 | else: 739 | # Only compute necessary logits, and do not upcast them to float if we are not computing the loss 740 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 741 | 742 | loss = None 743 | if labels is not None: 744 | # Upcast to float if we need to compute the loss to avoid potential precision issues 745 | logits = logits.float() 746 | # We shift tokens and labels in dataloader 747 | shift_logits = logits.contiguous() 748 | shift_labels = labels.contiguous() 749 | # Flatten the tokens 750 | loss_fct = 
CrossEntropyLoss() 751 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 752 | shift_labels = shift_labels.view(-1) 753 | # Enable model parallelism 754 | shift_labels = shift_labels.to(shift_logits.device) 755 | loss = loss_fct(shift_logits, shift_labels) 756 | 757 | if audio_labels is not None: 758 | au_mask = (audio_labels >= 0).all(dim=-1) 759 | au_hidden_states = hidden_states[au_mask] 760 | au_audio_labels = audio_labels[au_mask] 761 | if len(au_hidden_states) <= 0: 762 | au_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]) 763 | au_audio_labels = torch.zeros_like(audio_labels).reshape(-1, self.config.num_codebooks) 764 | loss_weight = 0.0 765 | else: 766 | loss_weight = 1.0 767 | au_logits = self.audio_transformer(au_hidden_states, au_audio_labels) 768 | # We shift tokens and labels in dataloader 769 | shift_au_logits = au_logits.contiguous() 770 | shift_audio_labels = au_audio_labels.contiguous() 771 | # Flatten the tokens 772 | loss_fct = CrossEntropyLoss() 773 | shift_au_logits = shift_au_logits.view(-1, self.config.codebook_size) 774 | shift_audio_labels = shift_audio_labels.view(-1) 775 | # Enable model parallelism 776 | shift_audio_labels = shift_audio_labels.to(shift_au_logits.device) 777 | au_loss = loss_fct(shift_au_logits, shift_audio_labels) 778 | 779 | loss += au_loss * loss_weight 780 | else: 781 | # au_tokens = self.audio_transformer.inference(hidden_states) 782 | pass 783 | 784 | if not return_dict: 785 | output = (logits,) + outputs[1:] 786 | return (loss,) + output if loss is not None else output 787 | 788 | return VoilaOutput( 789 | loss=loss, 790 | logits=logits, 791 | last_hidden_state=hidden_states, 792 | past_key_values=outputs.past_key_values, 793 | hidden_states=outputs.hidden_states, 794 | attentions=outputs.attentions, 795 | ) 796 | 797 | def _prepare_inputs_for_generation( 798 | self, input_ids, ref_embs=None, ref_embs_mask=None, audio_datas=None, audio_data_masks=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 799 | ): 800 | if past_key_values is not None and past_key_values.get_seq_length() > 0: 801 | if isinstance(past_key_values, Cache): 802 | cache_length = past_key_values.get_seq_length() 803 | past_length = past_key_values.seen_tokens 804 | max_cache_length = past_key_values.get_max_cache_shape() 805 | else: 806 | cache_length = past_length = past_key_values[0][0].shape[2] 807 | max_cache_length = None 808 | 809 | # Keep only the unprocessed tokens: 810 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 811 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 812 | # input) 813 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 814 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 815 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 816 | # input_ids based on the past_length. 817 | elif past_length < input_ids.shape[1]: 818 | input_ids = input_ids[:, past_length:] 819 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 820 | 821 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
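# --- Editor's sketch (not part of the original file) --------------------------
# Case 2 above in miniature: with past_length tokens already cached, only the
# suffix of input_ids is new work for the model, e.g. for past_length == 10 and
# 12 prompt tokens:
#
#     input_ids = input_ids[:, 10:]   # feed only the 2 genuinely new tokens
#
# The attention mask must still cover the cached positions, which is why it is
# only cropped against max_cache_length in the check that follows.
# ------------------------------------------------------------------------------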
822 | if ( 823 | max_cache_length is not None 824 | and attention_mask is not None 825 | and cache_length + input_ids.shape[1] > max_cache_length 826 | ): 827 | attention_mask = attention_mask[:, -max_cache_length:] 828 | 829 | position_ids = kwargs.get("position_ids", None) 830 | if attention_mask is not None and position_ids is None: 831 | # create position_ids on the fly for batch generation 832 | position_ids = attention_mask.long().cumsum(-1) - 1 833 | position_ids.masked_fill_(attention_mask == 0, 1) 834 | if past_key_values: 835 | position_ids = position_ids[:, -input_ids.shape[1] :] 836 | 837 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 838 | if inputs_embeds is None and \ 839 | (past_key_values is None or past_key_values.get_seq_length() <= 0): 840 | inputs_embeds = self.model.embed_tokens(input_ids) 841 | if inputs_embeds is not None and \ 842 | (past_key_values is None or past_key_values.get_seq_length() <= 0): 843 | model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask, "audio_datas": audio_datas, "audio_data_masks": audio_data_masks} 844 | else: 845 | model_inputs = {"input_ids": input_ids, "ref_embs": None, "audio_datas": None, "audio_data_masks": None} 846 | 847 | model_inputs.update( 848 | { 849 | "position_ids": position_ids, 850 | "past_key_values": past_key_values, 851 | "use_cache": kwargs.get("use_cache"), 852 | "attention_mask": attention_mask, 853 | } 854 | ) 855 | return model_inputs 856 | 857 | def _update_model_kwargs_for_generation( 858 | self, 859 | outputs, 860 | model_kwargs: Dict[str, Any], 861 | num_new_token: int = 1, 862 | ) -> Dict[str, Any]: 863 | # update past_key_values 864 | model_kwargs["past_key_values"] = outputs.past_key_values 865 | 866 | # update attention mask 867 | if "attention_mask" in model_kwargs: 868 | attention_mask = model_kwargs["attention_mask"] 869 | model_kwargs["attention_mask"] = torch.cat( 870 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1 871 | ) 872 | 873 | return model_kwargs 874 | 875 | def _prepare_attention_mask_for_generation( 876 | self, 877 | inputs: torch.Tensor, 878 | pad_token_id: Optional[int], 879 | eos_token_id: Optional[Union[int, List[int]]], 880 | ) -> torch.LongTensor: 881 | is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] 882 | is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) 883 | if isinstance(eos_token_id, int): 884 | eos_token_id = [eos_token_id] 885 | is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id) 886 | 887 | # Check if input is input_ids and padded -> only then is attention_mask defined 888 | if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: 889 | return inputs.ne(pad_token_id).long() 890 | else: 891 | return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) 892 | 893 | @torch.inference_mode() 894 | def run_generate( 895 | self, 896 | input_ids: torch.LongTensor, 897 | ref_embs: Optional[List[torch.Tensor]] = None, 898 | ref_embs_mask: Optional[torch.LongTensor] = None, 899 | audio_datas: Optional[torch.FloatTensor] = None, 900 | audio_data_masks: Optional[torch.LongTensor] = None, 901 | max_new_tokens: Optional[int] = 128, 902 | pad_token_id: Optional[int] = None, 903 | eos_token_id: Optional[Union[int, List[int]]] = None, 904 | streamer: Optional["BaseStreamer"] = None, 905 | llm_audio_token_id: 
Optional[int] = None, 906 | min_audio_token_id: Optional[int] = None, 907 | temperature=0.2, 908 | top_k=50, 909 | audio_temperature=0.2, 910 | audio_top_k=50, 911 | ): 912 | assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference" 913 | assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference" 914 | assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], but got {input_ids.shape}" 915 | 916 | eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device) 917 | 918 | # keep track of which sequences are already finished 919 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) 920 | 921 | # Extend input_ids with additional num_codebooks dim 922 | if len(input_ids.shape) == 2: 923 | input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks) # [batch, seq_len] -> [batch, seq_len, num_codebooks] 924 | 925 | this_peer_finished = False # used by synced_gpus only 926 | max_length = input_ids.shape[1] + max_new_tokens 927 | 928 | model_kwargs = { 929 | "use_cache": True, 930 | "past_key_values": DynamicCache(), 931 | "attention_mask": self._prepare_attention_mask_for_generation( 932 | input_ids, pad_token_id, eos_token_id 933 | ), 934 | } 935 | # auto-regressive generation 936 | while True: 937 | # prepare model inputs 938 | model_inputs = self._prepare_inputs_for_generation( 939 | input_ids, 940 | ref_embs=ref_embs, 941 | ref_embs_mask=ref_embs_mask, 942 | audio_datas=audio_datas, 943 | audio_data_masks=audio_data_masks, 944 | **model_kwargs 945 | ) 946 | 947 | # forward pass to get next token 948 | outputs = self( 949 | **model_inputs, 950 | return_dict=True, 951 | ) 952 | audio_tokens = self.audio_transformer.inference( 953 | outputs.last_hidden_state, 954 | temperature=audio_temperature, 955 | top_k=audio_top_k, 956 | ) 957 | audio_tokens = torch.stack( 958 | [ 959 | audio_tokens[:, :, ci] + min_audio_token_id + ci*self.config.codebook_size 960 | for ci in range(self.config.num_codebooks) 961 | ], 962 | dim=2, 963 | ) 964 | 965 | next_token_logits = outputs.logits[:, -1, :] 966 | 967 | # pre-process distribution 968 | # Apply temperature and top-k 969 | if temperature > 0: 970 | next_token_logits = next_token_logits / temperature 971 | if top_k > 0: 972 | top_k = min(top_k, next_token_logits.size(-1)) # Safety check 973 | # Remove all tokens with a probability less than the last token of the top-k 974 | indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] 975 | next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf")) 976 | 977 | # sample 978 | probs = nn.functional.softmax(next_token_logits, dim=-1) 979 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 980 | 981 | # finished sentences should have their next token be a padding token 982 | if eos_token_id is not None: 983 | if pad_token_id is None: 984 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 985 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 986 | 987 | # Append NUM_CODEBOOK text tokens or audio_tokens 988 | if len(next_tokens.shape) == 1: 989 | next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks) 990 | next_tokens = torch.where(next_tokens==llm_audio_token_id, audio_tokens,
next_tokens) 991 | 992 | input_ids = torch.cat([input_ids, next_tokens], dim=1) 993 | if streamer is not None: 994 | streamer.put(next_tokens.cpu()) 995 | model_kwargs = self._update_model_kwargs_for_generation( 996 | outputs, model_kwargs 997 | ) 998 | 999 | # if eos_token was found in one sentence, set sentence to finished 1000 | if eos_token_id_tensor is not None: 1001 | unfinished_sequences = unfinished_sequences.mul( 1002 | next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1) 1003 | ) 1004 | 1005 | # stop when each sentence is finished 1006 | if unfinished_sequences.max() == 0: 1007 | this_peer_finished = True 1008 | 1009 | # stop if we exceed the maximum length 1010 | if input_ids.shape[1] >= max_length: 1011 | this_peer_finished = True 1012 | 1013 | if this_peer_finished: 1014 | break 1015 | 1016 | if streamer is not None: 1017 | streamer.end() 1018 | 1019 | return input_ids 1020 | 1021 | 1022 | # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103 1023 | class VoilaAutonomousModel(LlamaPreTrainedModel): 1024 | _tied_weights_keys = ["lm_head.weight"] 1025 | 1026 | def __init__(self, config): 1027 | super().__init__(config) 1028 | self.model = LlamaModel(config) 1029 | self.vocab_size = config.vocab_size 1030 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1031 | self.pad_vocab_size_multiple = 64 1032 | 1033 | self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True) 1034 | self.audio_transformer = AudioTransformer(config, use_sdpa=False) 1035 | self.voila_predictor = nn.Sequential(nn.Linear(config.hidden_size, 2, bias=True),) 1036 | 1037 | # Initialize weights and apply final processing 1038 | self.post_init() 1039 | 1040 | def get_input_embeddings(self): 1041 | return self.model.embed_tokens 1042 | 1043 | def set_input_embeddings(self, value): 1044 | self.model.embed_tokens = value 1045 | 1046 | def get_output_embeddings(self): 1047 | return self.lm_head 1048 | 1049 | def set_output_embeddings(self, new_embeddings): 1050 | self.lm_head = new_embeddings 1051 | 1052 | def set_decoder(self, decoder): 1053 | self.model = decoder 1054 | 1055 | def get_decoder(self): 1056 | return self.model 1057 | 1058 | def forward( 1059 | self, 1060 | input_ids: torch.LongTensor = None, 1061 | attention_mask: Optional[torch.Tensor] = None, 1062 | position_ids: Optional[torch.LongTensor] = None, 1063 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 1064 | inputs_embeds: Optional[torch.FloatTensor] = None, 1065 | labels: Optional[torch.LongTensor] = None, 1066 | audio_labels: Optional[torch.LongTensor] = None, 1067 | voila_labels: Optional[torch.LongTensor] = None, 1068 | ref_embs: Optional[List[torch.Tensor]] = None, 1069 | ref_embs_mask: Optional[torch.LongTensor] = None, 1070 | use_cache: Optional[bool] = None, 1071 | output_attentions: Optional[bool] = None, 1072 | output_hidden_states: Optional[bool] = None, 1073 | return_dict: Optional[bool] = None, 1074 | cache_position: Optional[torch.LongTensor] = None, 1075 | num_logits_to_keep: int = 0, 1076 | ) -> Union[Tuple, VoilaOutput]: 1077 | r""" 1078 | Args: 1079 | input_ids: [bs, seq_len, num_codebooks] 1080 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1081 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1082 | config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored 1083 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1084 | """ 1085 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1086 | output_hidden_states = ( 1087 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1088 | ) 1089 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1090 | 1091 | if input_ids is not None and inputs_embeds is not None: 1092 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 1093 | if inputs_embeds is None: 1094 | inputs_embeds = self.model.embed_tokens(input_ids) 1095 | assert len(inputs_embeds.shape) == 4 1096 | if len(inputs_embeds.shape) == 4: 1097 | inputs_embeds = inputs_embeds.mean(dim=2) 1098 | 1099 | if self.training or \ 1100 | (past_key_values is None and ref_embs is not None) or \ 1101 | (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None): 1102 | ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype)) 1103 | ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1) 1104 | # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back) 1105 | padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0) 1106 | ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0) 1107 | inputs_embeds = inputs_embeds + ref_embs 1108 | 1109 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1110 | outputs = self.model( 1111 | attention_mask=attention_mask, 1112 | position_ids=position_ids, 1113 | past_key_values=past_key_values, 1114 | inputs_embeds=inputs_embeds, 1115 | use_cache=use_cache, 1116 | output_attentions=output_attentions, 1117 | output_hidden_states=output_hidden_states, 1118 | return_dict=return_dict, 1119 | cache_position=cache_position, 1120 | ) 1121 | 1122 | hidden_states = outputs[0] 1123 | if self.config.pretraining_tp > 1: 1124 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) 1125 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] 1126 | logits = torch.cat(logits, dim=-1) 1127 | else: 1128 | # Only compute necessary logits, and do not upcast them to float if we are not computing the loss 1129 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) 1130 | 1131 | # calc voila_predict_loss 1132 | voila_pred = self.voila_predictor(hidden_states) 1133 | voila_pred = voila_pred.float() 1134 | 1135 | loss = None 1136 | 1137 | if not return_dict: 1138 | output = (logits,) + outputs[1:] 1139 | return (loss,) + output if loss is not None else output 1140 | 1141 | return VoilaOutput( 1142 | loss=loss, 1143 | logits=logits, 1144 | last_hidden_state=hidden_states, 1145 | past_key_values=outputs.past_key_values, 1146 | hidden_states=outputs.hidden_states, 1147 | attentions=outputs.attentions, 1148 | voila_pred=voila_pred, 1149 | ) 1150 | 1151 | def _prepare_inputs_for_generation( 1152 | self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 1153 | ): 1154 | if past_key_values is not None and past_key_values.get_seq_length() > 0: 1155 | if isinstance(past_key_values, Cache): 1156 | cache_length = past_key_values.get_seq_length() 1157 | past_length = past_key_values.seen_tokens 1158 | 
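# --- Editor's note (not part of the original file) ----------------------------
# For the DynamicCache used by run_generate below, get_seq_length() and
# seen_tokens both report how many positions are already cached, while
# get_max_cache_shape() returns None because a dynamic cache grows without
# bound; the max_cache_length cropping further down therefore only triggers
# for bounded cache implementations. (Behavior as of recent transformers
# releases; seen_tokens is deprecated in newer versions.)
# ------------------------------------------------------------------------------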
max_cache_length = past_key_values.get_max_cache_shape() 1159 | else: 1160 | cache_length = past_length = past_key_values[0][0].shape[2] 1161 | max_cache_length = None 1162 | 1163 | # Keep only the unprocessed tokens: 1164 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 1165 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 1166 | # input) 1167 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 1168 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 1169 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 1170 | # input_ids based on the past_length. 1171 | elif past_length < input_ids.shape[1]: 1172 | input_ids = input_ids[:, past_length:] 1173 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 1174 | 1175 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 1176 | if ( 1177 | max_cache_length is not None 1178 | and attention_mask is not None 1179 | and cache_length + input_ids.shape[1] > max_cache_length 1180 | ): 1181 | attention_mask = attention_mask[:, -max_cache_length:] 1182 | 1183 | position_ids = kwargs.get("position_ids", None) 1184 | if attention_mask is not None and position_ids is None: 1185 | # create position_ids on the fly for batch generation 1186 | position_ids = attention_mask.long().cumsum(-1) - 1 1187 | position_ids.masked_fill_(attention_mask == 0, 1) 1188 | if past_key_values: 1189 | position_ids = position_ids[:, -input_ids.shape[1] :] 1190 | 1191 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 1192 | if inputs_embeds is None and \ 1193 | (past_key_values is None or past_key_values.get_seq_length() <= 0): 1194 | inputs_embeds = self.model.embed_tokens(input_ids) 1195 | if inputs_embeds is not None and \ 1196 | (past_key_values is None or past_key_values.get_seq_length() <= 0): 1197 | model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask} 1198 | else: 1199 | model_inputs = {"input_ids": input_ids, "ref_embs": None} 1200 | 1201 | model_inputs.update( 1202 | { 1203 | "position_ids": position_ids, 1204 | "past_key_values": past_key_values, 1205 | "use_cache": kwargs.get("use_cache"), 1206 | "attention_mask": attention_mask, 1207 | } 1208 | ) 1209 | return model_inputs 1210 | 1211 | def _update_model_kwargs_for_generation( 1212 | self, 1213 | outputs, 1214 | model_kwargs: Dict[str, Any], 1215 | num_new_token: int = 1, 1216 | ) -> Dict[str, Any]: 1217 | # update past_key_values 1218 | model_kwargs["past_key_values"] = outputs.past_key_values 1219 | 1220 | # update attention mask 1221 | if "attention_mask" in model_kwargs: 1222 | attention_mask = model_kwargs["attention_mask"] 1223 | model_kwargs["attention_mask"] = torch.cat( 1224 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1 1225 | ) 1226 | 1227 | return model_kwargs 1228 | 1229 | def _prepare_attention_mask_for_generation( 1230 | self, 1231 | inputs: torch.Tensor, 1232 | pad_token_id: Optional[int], 1233 | eos_token_id: Optional[Union[int, List[int]]], 1234 | ) -> torch.LongTensor: 1235 | is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] 1236 | is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) 1237 | if 
1238 |             eos_token_id = [eos_token_id]
1239 |         is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
1240 | 
1241 |         # Check if input is input_ids and padded -> only then is attention_mask defined
1242 |         if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
1243 |             return inputs.ne(pad_token_id).long()
1244 |         else:
1245 |             return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
1246 | 
1247 |     @torch.inference_mode()
1248 |     def run_generate(
1249 |         self,
1250 |         input_ids: torch.LongTensor,
1251 |         input_generator,
1252 |         ref_embs: Optional[List[torch.Tensor]] = None,
1253 |         ref_embs_mask: Optional[torch.LongTensor] = None,
1254 |         max_new_tokens: Optional[int] = 128,
1255 |         pad_token_id: Optional[int] = None,
1256 |         eos_token_id: Optional[Union[int, List[int]]] = None,
1257 |         streamer: Optional["BaseStreamer"] = None,
1258 |         llm_audio_token_id: Optional[int] = None,
1259 |         min_audio_token_id: Optional[int] = None,
1260 |         llm_assistant_token_id: Optional[int] = None,
1261 |         temperature=0.2,
1262 |         top_k=50,
1263 |         audio_temperature=0.8,
1264 |         audio_top_k=50,
1265 |     ):
1266 |         assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
1267 |         assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
1268 |         assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"
1269 | 
1270 |         eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device)  # only used by the commented-out early-stopping block below
1271 | 
1272 |         # keep track of which sequences are already finished
1273 |         unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1274 | 
1275 |         # Extend input_ids with additional num_codebooks dim
1276 |         input_ids = input_ids.clone()
1277 |         if len(input_ids.shape) == 2:
1278 |             input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)  # -1 keeps batch/seq dims; replicate ids across codebooks
1279 | 
1280 |         this_peer_finished = False  # used by synced_gpus only
1281 |         max_length = input_ids.shape[1] + max_new_tokens
1282 | 
1283 |         model_kwargs = {
1284 |             "use_cache": True,
1285 |             "past_key_values": DynamicCache(),
1286 |             "attention_mask": self._prepare_attention_mask_for_generation(
1287 |                 input_ids, pad_token_id, eos_token_id
1288 |             ),
1289 |         }
1290 |         speaking = False
1291 |         # auto-regressive generation
1292 |         while True:
1293 |             # prepare model inputs
1294 |             model_inputs = self._prepare_inputs_for_generation(
1295 |                 input_ids,
1296 |                 ref_embs=ref_embs,
1297 |                 ref_embs_mask=ref_embs_mask,
1298 |                 **model_kwargs
1299 |             )
1300 | 
1301 |             # forward pass to get next token
1302 |             outputs = self(
1303 |                 **model_inputs,
1304 |                 return_dict=True,
1305 |             )
1306 |             audio_tokens = self.audio_transformer.inference(
1307 |                 outputs.last_hidden_state,
1308 |                 temperature=audio_temperature,
1309 |                 top_k=audio_top_k,
1310 |             )
1311 |             audio_tokens = torch.stack(  # offset each codebook into its own id range in the LLM vocab
1312 |                 [
1313 |                     audio_tokens[:, :, ci] + min_audio_token_id + ci * self.config.codebook_size
1314 |                     for ci in range(self.config.num_codebooks)
1315 |                 ],
1316 |                 dim=2,
1317 |             )
1318 | 
1319 |             next_token_logits = outputs.logits[:, -1, :]
1320 | 
1321 |             # voila head output
1322 |             voila_head_pred = outputs.voila_pred[:, -1, :]
1323 |             voila_head_pred = torch.argmax(voila_head_pred, dim=-1)
1324 |             voila_head_pred = voila_head_pred.cpu()[0].item()
1325 | 
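            # The voila head prediction computed above acts as a turn-taking signal:
            # in the branch further below, a prediction of 1 while the model is not
            # already speaking forces the next token to llm_assistant_token_id,
            # i.e. the model decides to start its spoken response.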
1326 |             # pre-process distribution
1327 |             # Apply temperature and top-k
1328 |             if temperature > 0:
1329 |                 next_token_logits = next_token_logits / temperature
1330 |             if top_k > 0:
1331 |                 top_k = min(top_k, next_token_logits.size(-1))  # Safety check
1332 |                 # Remove all tokens with a probability less than the last token of the top-k
1333 |                 indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
1334 |                 next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))
1335 | 
1336 |             # sample
1337 |             probs = nn.functional.softmax(next_token_logits, dim=-1)
1338 |             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1339 | 
1340 |             # voila head pred == 1, use assistant token
1341 |             if voila_head_pred == 1 and not speaking:
1342 |                 next_tokens[0] = llm_assistant_token_id
1343 |                 speaking = True
1344 |             elif next_tokens[0] == eos_token_id:
1345 |                 speaking = False
1346 | 
1347 |             # finished sentences should have their next token be a padding token
1348 |             if eos_token_id is not None:
1349 |                 if pad_token_id is None:
1350 |                     raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
1351 |                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
1352 | 
1353 |             # Append NUM_CODEBOOK text tokens or audio_tokens
1354 |             if len(next_tokens.shape) == 1:
1355 |                 next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
1356 |             audio_token_mask = next_tokens == llm_audio_token_id
1357 |             next_tokens = next_tokens * torch.logical_not(audio_token_mask) + audio_tokens * audio_token_mask
1358 | 
1359 |             if audio_token_mask[0, 0, 0].item():
1360 |                 try:
1361 |                     new_input_tokens = next(input_generator)
1362 |                 except StopIteration:  # user-side input stream exhausted
1363 |                     this_peer_finished = True
1364 |                     break
1365 |                 new_input_tokens = new_input_tokens[None, None, :]
1366 |             else:
1367 |                 new_input_tokens = next_tokens
1368 |             new_input_tokens = torch.cat([new_input_tokens, next_tokens], dim=2)  # pair input-stream tokens with generated tokens along the codebook dim
1369 | 
1370 |             input_ids = torch.cat([input_ids, new_input_tokens], dim=1)
1371 |             if streamer is not None:
1372 |                 streamer.put(next_tokens.cpu())
1373 |             model_kwargs = self._update_model_kwargs_for_generation(
1374 |                 outputs, model_kwargs
1375 |             )
1376 | 
1377 |             # # if eos_token was found in one sentence, set sentence to finished
1378 |             # if eos_token_id_tensor is not None:
1379 |             #     unfinished_sequences = unfinished_sequences.mul(
1380 |             #         next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1)
1381 |             #     )
1382 | 
1383 |             # # stop when each sentence is finished
1384 |             # if unfinished_sequences.max() == 0:
1385 |             #     this_peer_finished = True
1386 | 
1387 |             # stop if we exceed the maximum length
1388 |             if input_ids.shape[1] >= max_length:
1389 |                 this_peer_finished = True
1390 | 
1391 |             if this_peer_finished:
1392 |                 break
1393 | 
1394 |         if streamer is not None:
1395 |             streamer.end()
1396 | 
1397 |         return input_ids
1398 | 
--------------------------------------------------------------------------------
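For reference, the text-token sampling rule used at each step of `run_generate`
(temperature scaling, then top-k filtering, then multinomial sampling) can be
read in isolation. A minimal self-contained sketch of the same logic; the
function name and the dummy logits below are illustrative, not part of the
repository:

    import torch
    import torch.nn.functional as F

    def sample_next_token(logits, temperature=0.2, top_k=50):
        # Scale logits by temperature, keep only the top-k entries, then sample.
        if temperature > 0:
            logits = logits / temperature
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # safety check
            cutoff = torch.topk(logits, top_k).values[..., -1, None]
            logits = logits.masked_fill(logits < cutoff, float("-inf"))
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)

    dummy_logits = torch.randn(1, 32000)      # [batch, vocab_size]
    print(sample_next_token(dummy_logits))    # tensor with one sampled token id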