├── examples
│   ├── test1.mp3
│   ├── test_autonomous1.mp3
│   └── character_ref_emb_demo.pkl
├── requirements.txt
├── LICENSE
├── spkr.py
├── voila_tokenizer.py
├── .gitignore
├── README.md
├── infer.py
├── gradio_demo.py
├── audio_transformer.py
├── tokenize_func.py
└── model.py
/examples/test1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maitrix-org/Voila/HEAD/examples/test1.mp3
--------------------------------------------------------------------------------
/examples/test_autonomous1.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maitrix-org/Voila/HEAD/examples/test_autonomous1.mp3
--------------------------------------------------------------------------------
/examples/character_ref_emb_demo.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/maitrix-org/Voila/HEAD/examples/character_ref_emb_demo.pkl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | torchaudio
4 | transformers
5 | flash-attn
6 | soundfile
7 | librosa
8 | jsonlines
9 | gradio
10 | pyannote.audio
11 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Maitrix.org
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/spkr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | from torchaudio.functional import resample
4 |
5 | from pyannote.audio import Model
6 | from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
7 |
8 |
9 | class SpeakerEmbedding:
10 | def __init__(self, model_path="pyannote/wespeaker-voxceleb-resnet34-LM", device="cuda"):
11 | model = Model.from_pretrained(model_path).eval()
12 |
13 | self.device = torch.device(device)
14 | self.sample_rate = 16000
15 | self.model = model.to(self.device)
16 |
17 | @torch.no_grad()
18 | def __call__(self, wav, sr):
19 | wav = torch.tensor(wav, device=self.device)
20 | if sr != self.sample_rate:
21 | wav = resample(wav, sr, self.sample_rate)
22 | sr = self.sample_rate
23 |
24 | assert len(wav.shape) <= 3
25 | is_batch = False
26 | if len(wav.shape) == 3:
27 | is_batch = True
28 | elif len(wav.shape) == 2:
29 | wav = wav[None, :, :]
30 | else:
31 | wav = wav[None, None, :]
32 |
33 | with torch.inference_mode():
34 | embeddings = self.model(wav)
35 |
36 | if is_batch:
37 | return embeddings
38 | else:
39 | return embeddings[0]
40 |
41 | if __name__ == '__main__':
42 | import argparse
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument("--wav", type=str, required=True)
45 | args = parser.parse_args()
46 |
47 | model = SpeakerEmbedding(device="cuda")
48 |
49 | wav, sr = torchaudio.load(args.wav)
50 | print(model(wav, sr))
51 |
--------------------------------------------------------------------------------
/voila_tokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | from torchaudio.functional import resample
4 |
5 | from transformers import AutoProcessor, EncodecModel
6 |
7 |
8 | ALL_BANDWIDTHS = [1.1]
9 |
10 | class VoilaTokenizer:
11 | def __init__(
12 | self,
13 | model_path="maitrix-org/Voila-Tokenizer",
14 | bandwidth_id=0,
15 | device="cpu",
16 | ):
17 | self.device = torch.device(device)
18 | self.bandwidth = ALL_BANDWIDTHS[bandwidth_id]
19 | self.bandwidth_id = torch.tensor([bandwidth_id], device=device)
20 |
21 | self.processor = AutoProcessor.from_pretrained(model_path)
22 | self.model = EncodecModel.from_pretrained(model_path).to(device)
23 |
24 | self.sampling_rate = self.processor.sampling_rate
25 | self.model_version = self.model.config.model_version
26 |
27 |
28 | @torch.no_grad()
29 | def encode(self, wav, sr):
30 | wav = torch.tensor(wav, dtype=torch.float32, device=self.device)
31 | if sr != self.processor.sampling_rate:
32 | wav = resample(wav, sr, self.processor.sampling_rate)
33 | sr = self.processor.sampling_rate
34 | if len(wav.shape) == 1:
35 | wav = wav[None, None, :]
36 | elif len(wav.shape) == 2:
37 | assert wav.shape[0] == 1
38 | wav = wav[None, :]
39 | elif len(wav.shape) == 3:
40 | assert wav.shape[0] == 1 and wav.shape[1] == 1
41 |
42 | # inputs = self.processor(raw_audio=wav, sampling_rate=sr, return_tensors="pt")
43 | encoder_outputs = self.model.encode(wav, bandwidth=self.bandwidth)
44 | return encoder_outputs.audio_codes[0, 0]
45 |
46 | @torch.no_grad()
47 | def decode(self, audio_codes):
48 | assert len(audio_codes.shape) == 2
49 | audio_values = self.model.decode(audio_codes[None, None, :, :], [None])[0]
50 | return audio_values[0, 0]
51 |
52 | if __name__ == '__main__':
53 | import argparse
54 | import soundfile as sf
55 |
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument("--wav", type=str)
58 | args = parser.parse_args()
59 |
60 | wav, sr = torchaudio.load(args.wav)
61 | if len(wav.shape) > 1:
62 | wav = wav[0]
63 |
64 | model = VoilaTokenizer(device="cuda")
65 |
66 | audio_codes = model.encode(wav, sr)
67 | audio_values = model.decode(audio_codes).cpu().numpy()
68 |
69 | tps = audio_codes.shape[-1] / (audio_values.shape[-1] / model.processor.sampling_rate)
70 | print(audio_codes.shape, audio_values.shape, tps)
71 | sf.write("audio_mt.wav", audio_values, model.processor.sampling_rate)
72 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | output
2 | .gradio
3 | run_*.sh
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 | cover/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | .pybuilder/
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | # For a library or package, you might want to ignore these files since the code is
91 | # intended to run in multiple environments; otherwise, check them in:
92 | # .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # UV
102 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
103 | # This is especially recommended for binary packages to ensure reproducibility, and is more
104 | # commonly ignored for libraries.
105 | #uv.lock
106 |
107 | # poetry
108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
109 | # This is especially recommended for binary packages to ensure reproducibility, and is more
110 | # commonly ignored for libraries.
111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
112 | #poetry.lock
113 |
114 | # pdm
115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
116 | #pdm.lock
117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
118 | # in version control.
119 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
120 | .pdm.toml
121 | .pdm-python
122 | .pdm-build/
123 |
124 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
125 | __pypackages__/
126 |
127 | # Celery stuff
128 | celerybeat-schedule
129 | celerybeat.pid
130 |
131 | # SageMath parsed files
132 | *.sage.py
133 |
134 | # Environments
135 | .env
136 | .venv
137 | env/
138 | venv/
139 | ENV/
140 | env.bak/
141 | venv.bak/
142 |
143 | # Spyder project settings
144 | .spyderproject
145 | .spyproject
146 |
147 | # Rope project settings
148 | .ropeproject
149 |
150 | # mkdocs documentation
151 | /site
152 |
153 | # mypy
154 | .mypy_cache/
155 | .dmypy.json
156 | dmypy.json
157 |
158 | # Pyre type checker
159 | .pyre/
160 |
161 | # pytype static type analyzer
162 | .pytype/
163 |
164 | # Cython debug symbols
165 | cython_debug/
166 |
167 | # PyCharm
168 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
169 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
170 | # and can be added to the global gitignore or merged into this file. For a more nuclear
171 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
172 | #.idea/
173 |
174 | # PyPI configuration file
175 | .pypirc
176 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | 
3 | # Voila: Voice-Language Foundation Models
4 | 💜 Project Page | 🖥️ GitHub | 🤗 Hugging Face | 📑 Paper | 🌐 Online Demo | 🏠 Maitrix.org
5 |
6 |
7 | Voila is a new family of large voice-language foundation models aiming to lift human-AI interaction experiences to the next level. Breaking away from the constraints of traditional voice AI systems—high latency, loss of vocal nuances, and mechanical responses—Voila employs an innovative end-to-end model design and a novel hierarchical Transformer architecture. This approach enables real-time, autonomous, and rich voice interactions, with latency as low as 195 ms, surpassing average human response times. Combining advanced voice and language modeling, Voila offers customizable, persona-driven engagements and excels in a range of audio tasks from ASR and TTS to speech translation across six languages. With the online [web demo](https://huggingface.co/spaces/maitrix-org/Voila-demo), Voila invites you to explore a transformative, natural dialogue experience between human and AI.
8 |
9 | # ✨ Highlights
10 | - ⭐ High-fidelity, low-latency, real-time streaming audio processing
11 | - ⭐ Effective integration of voice and language modeling capabilities
12 | - ⭐ Millions of pre-built and custom voices, fast voice switching during conversation
13 | - ⭐ Unified model for various audio tasks
14 |
15 | # 🎥 Video Demo
16 | [Watch the video demo](https://www.youtube.com/watch?v=J27M9-g5KL0)
17 |
18 | # 🔥 Latest News!!
19 |
20 | * April 28, 2025: 👋 We've released the inference code and model weights of Voila.
21 |
22 | # ⚙️ Foundation Models
23 |
24 | | Model | Description | Download Link |
25 | |--------|-----------|-----------------|
26 | |Voila-base|Voila base model|https://huggingface.co/maitrix-org/Voila-base|
27 | |Voila-Chat|End-to-end audio chat model|https://huggingface.co/maitrix-org/Voila-chat|
28 | |Voila-Autonomous (preview)|Full-duplex audio chat model|https://huggingface.co/maitrix-org/Voila-autonomous-preview|
29 | |Voila-Audio-alpha|Empowering LLM with raw audio input|https://huggingface.co/maitrix-org/Voila-audio-alpha|
30 | |Voila-Tokenizer|Audio tokenizer|https://huggingface.co/maitrix-org/Voila-Tokenizer|
31 |
32 | ## Usage
33 | ### CLI demo
34 | ```shell
35 | for model_name in "maitrix-org/Voila-audio-alpha" "maitrix-org/Voila-base" "maitrix-org/Voila-chat"; do
36 | # Text chat
37 | python infer.py \
38 | --model-name ${model_name} \
39 | --instruction "" \
40 | --input-text "Hello" \
41 | --task-type chat_tito
42 | # Voice chat
43 | python infer.py \
44 | --model-name ${model_name} \
45 | --instruction "" \
46 | --input-audio "examples/test1.mp3" \
47 | --task-type chat_aiao
48 | done
49 |
50 | # Autonomous mode
51 | python infer.py \
52 | --model-name "maitrix-org/Voila-autonomous-preview" \
53 | --instruction "" \
54 | --input-audio "examples/test_autonomous1.mp3" \
55 | --task-type chat_aiao_auto
56 | ```
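
For audio-output tasks (e.g. `chat_aiao`, `chat_tts`), `infer.py` writes the generated audio to `<result-path>/out.wav` (`output/out.wav` by default) and prints the decoded text to stdout.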
57 |
58 | ### Gradio demo
59 | ```shell
60 | python gradio_demo.py
61 | ```
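
The app loads `maitrix-org/Voila-chat` together with `maitrix-org/Voila-Tokenizer` and the speaker-embedding model on a CUDA GPU, and exposes three tabs: Chat, TTS, and ASR.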
62 |
63 | For more information, please refer to the [code repository](https://github.com/maitrix-org/Voila).
64 |
65 | # 📁 Datasets
66 | We publish the following two datasets: the Voila Benchmark and the Voila Voice Library. The Voila Benchmark is a novel speech evaluation benchmark, while the Voila Voice Library provides millions of pre-built and customizable voices.
67 |
68 | | Dataset | Description | Download Link |
69 | |--------|-----------|-----------------|
70 | |Voila Benchmark| Speech evaluation benchmark | https://huggingface.co/datasets/maitrix-org/Voila-Benchmark |
71 | |Voila Voice Library| Millions of pre-built voices | https://huggingface.co/datasets/maitrix-org/Voila-million-voice |
72 |
73 | # 📊 Benchmark
74 | ## 1. Voila Benchmark
75 | We introduce a novel speech evaluation benchmark, the Voila Benchmark, constructed by sampling from five widely used language model evaluation datasets: MMLU, MATH, OpenAI HumanEval, NQ-Open, and GSM8k. We compare our results with SpeechGPT and Moshi.
76 | | Model | Voila Benchmark |
77 | |-------|----------------|
78 | |SpeechGPT| 13.29|
79 | |Moshi | 11.45 |
80 | |**Voila** | **30.56** |
81 |
82 | _(higher is better)_
83 |
84 | For detailed scores of Voila Benchmark on each specific domain, please refer to our paper (Section 5.1 "Evaluation of Voila Benchmark").
85 | ## 2. Evaluation of ASR
86 | As Voila supports multiple tasks, including Automatic Speech Recognition (ASR), Text-to-Speech (TTS), and spoken question answering, we also evaluate the performance of ASR and TTS.
87 | For ASR, we assess performance on the LibriSpeech test-clean dataset, using Word Error Rate (WER) as our metric. Voila attains a WER of 4.8%, outperforming the 5.7% reported by Moshi; when the LibriSpeech training split is included in training, Voila reaches a WER of 2.7%.
88 | | Model | LibriSpeech test-clean (WER) |
89 | |-------|-----------------------|
90 | |Whisper large v2|2.7|
91 | |Whisper large v3|2.2|
92 | |FastConformer|3.6|
93 | |VoxtLM |2.7|
94 | |Moshi |5.7|
95 | |**Voila (w/o LibriSpeech train split)** |**4.8**|
96 | |**Voila (with LibriSpeech train split)**|**2.7**|
97 |
98 | _(lower is better)_
99 |
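WER is the word-level edit distance between the model transcript and the reference transcript, divided by the number of reference words. A minimal sketch of the computation (not the evaluation script used for the numbers above; toolkits such as jiwer are commonly used in practice):

```python
def wer(reference: str, hypothesis: str) -> float:
    """(substitutions + insertions + deletions) / number of reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    # Word-level Levenshtein distance via dynamic programming.
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,         # deletion
                          d[i][j - 1] + 1,         # insertion
                          d[i - 1][j - 1] + cost)  # substitution
    return d[len(ref)][len(hyp)] / max(len(ref), 1)

print(wer("the cat sat on the mat", "the cat sat on mat"))  # 1 deletion / 6 words ≈ 0.167
```
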
100 | ## 3. Evaluation of TTS
101 | For TTS, we follow the evaluation protocol proposed in Vall-E: the generated audio is transcribed with HuBERT-Large and the transcript is scored against the input text with WER.
102 | Voila once again leads with a WER of 3.2% (2.8% when the LibriSpeech training split is used in training).
103 |
104 | | Model | LibriSpeech test-clean (WER) |
105 | |-------|-----------------------|
106 | |YourTTS |7.7|
107 | |Vall-E|5.9|
108 | |Moshi|4.7|
109 | |**Voila (w/o LibriSpeech train split)** |**3.2**|
110 | |**Voila (with LibriSpeech train split)** |**2.8**|
111 |
112 | _(lower is better)_
113 |
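One plausible way to reproduce this transcribe-then-score protocol with Hugging Face Transformers is sketched below; the exact checkpoint and scoring script behind the numbers above are not specified here, so treat `facebook/hubert-large-ls960-ft` and the file path as assumptions.

```python
import torch
import torchaudio
from transformers import Wav2Vec2Processor, HubertForCTC

# HuBERT-Large CTC model (assumed checkpoint) used only to transcribe generated speech.
processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
asr = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").eval()

# e.g. audio produced by `python infer.py --task-type chat_tts --input-text "..."`
wav, sr = torchaudio.load("output/out.wav")
if sr != 16000:
    wav = torchaudio.functional.resample(wav, sr, 16000)

inputs = processor(wav[0].numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = asr(inputs.input_values).logits
transcript = processor.batch_decode(torch.argmax(logits, dim=-1))[0]
print(transcript)  # WER is then computed between this transcript and the input text (see the sketch above).
```
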
114 | # 📝 Citation
115 | If you find our work helpful, please cite us.
116 |
117 | ```
118 | @article{voila2025,
119 | author = {Yemin Shi and Yu Shu and Siwei Dong and Guangyi Liu and Jaward Sesay and Jingwen Li and Zhiting Hu},
120 | title = {Voila: Voice-Language Foundation Models for Real-Time Autonomous Interaction and Voice Roleplay},
121 | eprint={2505.02707},
122 | archivePrefix={arXiv},
123 | primaryClass={cs.CL},
124 | year = {2025}
125 | }
126 | ```
127 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import random
4 | import jsonlines
5 | import soundfile as sf
6 | import json
7 | import copy
8 | import torch
9 | from pathlib import Path
10 | from threading import Thread
11 |
12 | import torchaudio
13 | from transformers import AutoTokenizer
14 |
15 | from model import VoilaAudioAlphaModel, VoilaModel, VoilaAutonomousModel
16 | from spkr import SpeakerEmbedding
17 | from voila_tokenizer import VoilaTokenizer
18 | from tokenize_func import (
19 | voila_input_format,
20 | AUDIO_TOKEN_FORMAT,
21 | DEFAULT_AUDIO_TOKEN,
22 | DEFAULT_ASSISTANT_TOKEN,
23 | )
24 |
25 |
26 | def disable_torch_init():
27 | """
28 | Disable the redundant torch default initialization to accelerate model creation.
29 | """
30 | import torch
31 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
32 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
33 |
34 | def load_model(model_name, audio_tokenizer_path):
35 | disable_torch_init()
36 |
37 | if "Voila-audio" in model_name:
38 | model_type = "audio"
39 | cls = VoilaAudioAlphaModel
40 | elif "Voila-auto" in model_name:
41 | model_type = "autonomous"
42 | cls = VoilaAutonomousModel
43 | else:
44 | model_type = "token"
45 | cls = VoilaModel
46 |
47 | model = cls.from_pretrained(
48 | model_name,
49 | torch_dtype=torch.bfloat16,
50 | use_flash_attention_2=True,
51 | use_cache=True,
52 | )
53 | model = model.cuda()
54 | tokenizer = AutoTokenizer.from_pretrained(model_name)
55 | tokenizer_voila = VoilaTokenizer(model_path=audio_tokenizer_path, device="cuda")
56 | return model, tokenizer, tokenizer_voila, model_type
57 |
58 | def is_audio_output_task(task_type):
59 | return task_type.endswith("ao") or "aiao" in task_type or "tts" in task_type
60 |
61 | def eval_model(model, tokenizer, tokenizer_voila, model_type, task_type, history, ref_embs, ref_embs_mask, max_new_tokens=512):
62 | # step1: initializing
63 | num_codebooks = model.config.num_codebooks
64 | codebook_size = model.config.codebook_size
65 |
66 | AUDIO_MIN_TOKEN_ID = tokenizer.convert_tokens_to_ids(AUDIO_TOKEN_FORMAT.format(0))
67 | assert isinstance(AUDIO_MIN_TOKEN_ID, int)
68 | AUDIO_MAX_TOKEN_ID = tokenizer.convert_tokens_to_ids(AUDIO_TOKEN_FORMAT.format(codebook_size*num_codebooks-1))
69 | assert isinstance(AUDIO_MAX_TOKEN_ID, int)
70 | AUDIO_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_AUDIO_TOKEN)
71 | assert isinstance(AUDIO_TOKEN_ID, int)
72 | ASSISTANT_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_ASSISTANT_TOKEN)
73 | assert isinstance(ASSISTANT_TOKEN_ID, int)
74 |
75 | # step2: set infer config
76 | data_cfg = {
77 | "input_type": model_type,
78 | "task_type": task_type,
79 | "num_codebooks": num_codebooks,
80 | "codebook_size": codebook_size,
81 | }
82 |
83 | # step3: infer
84 | input_ids, audio_datas, audio_data_masks, streaming_user_input_audio_tokens = voila_input_format(history, tokenizer, tokenizer_voila, data_cfg)
85 |
86 | # prepare user_streaming_generator to simulate streaming user input
87 | def get_input_generator(all_tokens):
88 | assert all_tokens is not None
89 | for i in range(len(all_tokens[0])):
90 | yield all_tokens[:,i]
91 |
92 | if model_type == "autonomous":
93 | input_generator = get_input_generator(torch.as_tensor(streaming_user_input_audio_tokens).cuda())
94 | input_ids = [torch.as_tensor([input]).transpose(1,2).cuda() for input in input_ids] # transpose to [bs, seq, num_codebooks]
95 | input_ids = torch.cat(input_ids, dim=2) # concat to [bs, seq, num_codebooks*2]
96 | else:
97 | input_ids = torch.as_tensor([input_ids]).transpose(1,2).cuda() # transpose to [bs, seq, num_codebooks]
98 | gen_params = {
99 | "input_ids": input_ids,
100 | "ref_embs": ref_embs,
101 | "ref_embs_mask": ref_embs_mask,
102 | "max_new_tokens": max_new_tokens,
103 | "pad_token_id": tokenizer.pad_token_id,
104 | "eos_token_id": tokenizer.eos_token_id,
105 | "llm_audio_token_id": AUDIO_TOKEN_ID,
106 | "min_audio_token_id": AUDIO_MIN_TOKEN_ID,
107 | "temperature": 0.2,
108 | "top_k": 50,
109 | "audio_temperature": 0.8,
110 | "audio_top_k": 50,
111 | }
112 | if model_type == "audio":
113 | audio_datas = torch.tensor([audio_datas], dtype=torch.bfloat16).cuda()
114 | audio_data_masks = torch.tensor([audio_data_masks]).cuda()
115 | gen_params["audio_datas"] = audio_datas
116 | gen_params["audio_data_masks"] = audio_data_masks
117 | elif model_type == "autonomous":
118 | gen_params["input_generator"] = input_generator
119 | gen_params["llm_assistant_token_id"] = ASSISTANT_TOKEN_ID
120 | print(f"Input str: {tokenizer.decode(input_ids[0, :, 0])}")
121 | with torch.inference_mode():
122 | outputs = model.run_generate(**gen_params)
123 |
124 | if model_type == "autonomous":
125 | outputs = outputs.chunk(2, dim=2)[1]
126 | outputs = outputs[0].cpu().tolist()
127 |
128 | predict_outputs = outputs[input_ids.shape[1]:]
129 | text_outputs = []
130 | audio_outputs = []
131 | for _ in range(num_codebooks):
132 | audio_outputs.append([])
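    # Each generated step carries num_codebooks token ids. Ids in
    # [AUDIO_MIN_TOKEN_ID, AUDIO_MAX_TOKEN_ID] are audio codes, where codebook n is
    # offset by n*codebook_size (hence the modulo below); any other non-EOS id is
    # collected as plain text.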
133 | for item in predict_outputs:
134 | if item[0] >= AUDIO_MIN_TOKEN_ID and item[0] <= AUDIO_MAX_TOKEN_ID:
135 | for n, at in enumerate(item):
136 | audio_outputs[n].append((at - AUDIO_MIN_TOKEN_ID)%codebook_size)
137 | elif item[0] != tokenizer.eos_token_id:
138 | text_outputs.append(item[0])
139 |
140 |     out = {
141 | 'text': tokenizer.decode(text_outputs),
142 | }
143 | if is_audio_output_task(task_type):
144 | audio_values = tokenizer_voila.decode(torch.tensor(audio_outputs).cuda())
145 | out['audio'] = (audio_values.detach().cpu().numpy(), 16000)
146 | return out
147 |
148 |
149 | if __name__ == "__main__":
150 | parser = argparse.ArgumentParser()
151 | parser.add_argument("--instruction", type=str, default="")
152 | parser.add_argument("--input-text", type=str, default=None)
153 | parser.add_argument("--input-audio", type=str, default=None)
154 | parser.add_argument("--result-path", type=str, default="output")
155 | parser.add_argument("--ref-audio", type=str, default="examples/test1.mp3")
156 | parser.add_argument("--model-name", type=str, default="maitrix-org/Voila-chat")
157 | parser.add_argument("--audio-tokenizer-path", type=str, default="maitrix-org/Voila-Tokenizer")
158 | parser.add_argument("--task-type", type=str, default="chat_aiao")
159 | args = parser.parse_args()
160 |
161 | assert args.model_name in [
162 | "maitrix-org/Voila-audio-alpha",
163 | "maitrix-org/Voila-base",
164 | "maitrix-org/Voila-chat",
165 | "maitrix-org/Voila-autonomous-preview",
166 | ]
167 |
168 | # step0: Model loading
169 | model, tokenizer, tokenizer_voila, model_type = load_model(args.model_name, args.audio_tokenizer_path)
170 |
171 | # step1: prepare inputs
172 | Path(args.result_path).mkdir(exist_ok=True, parents=True)
173 | history = {
174 | "instruction": args.instruction,
175 | "conversations": [],
176 | }
177 | if args.input_text is not None:
178 | history["conversations"].append({"from": "user", "text": args.input_text})
179 | elif args.input_audio is not None:
180 | history["conversations"].append({"from": "user", "audio": {"file": args.input_audio}})
181 | else:
182 |         raise Exception("Please provide at least one of --input-text or --input-audio")
183 | history["conversations"].append({"from": "assistant"})
184 |
185 | # step2: encode ref
186 | ref_embs, ref_embs_mask = None, None
187 | if is_audio_output_task(args.task_type):
188 | spkr_model = SpeakerEmbedding(device="cuda")
189 | wav, sr = torchaudio.load(args.ref_audio)
190 | ref_embs = spkr_model(wav, sr)
191 | ref_embs_mask = torch.tensor([1]).cuda()
192 |
193 | out = eval_model(model, tokenizer, tokenizer_voila, model_type, args.task_type, history, ref_embs, ref_embs_mask)
194 | print(f"Output str: {out['text']}")
195 | if 'audio' in out:
196 | wav, sr = out['audio']
197 | save_name = f"{args.result_path}/out.wav"
198 | sf.write(save_name, wav, sr)
199 |
--------------------------------------------------------------------------------
/gradio_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import shutil
4 | import pickle
5 | import gradio as gr
6 | import soundfile as sf
7 | from pathlib import Path
8 |
9 | import torch
10 | import torchaudio
11 |
12 | from huggingface_hub import hf_hub_download
13 |
14 | from infer import load_model, eval_model
15 | from spkr import SpeakerEmbedding
16 |
17 |
18 | spkr_model = SpeakerEmbedding(device="cuda")
19 | model, tokenizer, tokenizer_voila, model_type = load_model("maitrix-org/Voila-chat", "maitrix-org/Voila-Tokenizer")
20 | default_ref_file = "examples/character_ref_emb_demo.pkl"
21 | default_ref_name = "Homer Simpson"
22 | million_voice_ref_file = hf_hub_download(repo_id="maitrix-org/Voila-million-voice", filename="character_ref_emb_chunk0.pkl", repo_type="dataset")
23 |
24 | instruction = "You are a smart AI agent created by Maitrix.org."
25 | save_path = "output"
26 |
27 | intro = """**Voila**
28 |
29 | For more demos, please goto [https://voila.maitrix.org](https://voila.maitrix.org)."""
30 |
31 | default_ref_emb_mask_list = pickle.load(open(default_ref_file, "rb"))
32 | million_voice_ref_emb_mask_list = pickle.load(open(million_voice_ref_file, "rb"))
33 |
34 | def get_ref_embs(ref_audio):
35 | wav, sr = torchaudio.load(ref_audio)
36 | ref_embs = spkr_model(wav, sr).cpu()
37 | return ref_embs
38 |
39 | def delete_directory(request: gr.Request):
40 | if not request.session_hash:
41 | return
42 | user_dir = Path(f"{save_path}/{str(request.session_hash)}")
43 | if user_dir.exists():
44 | shutil.rmtree(str(user_dir))
45 |
46 | def add_message(history, message):
47 | history.append({"role": "user", "content": {"path": message}})
48 | return history, gr.Audio(value=None), gr.Button(interactive=False)
49 |
50 | def call_bot(history, ref_embs, request: gr.Request):
51 | formated_history = {
52 | "instruction": instruction,
53 | "conversations": [{'from': item["role"], 'audio': {"file": item["content"][0]}} for item in history],
54 | }
55 | formated_history["conversations"].append({"from": "assistant"})
56 | print(formated_history)
57 | ref_embs = torch.tensor(ref_embs, dtype=torch.float32, device="cuda")
58 | ref_embs_mask = torch.tensor([1], device="cuda")
59 | out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_aiao", formated_history, ref_embs, ref_embs_mask, max_new_tokens=512)
60 | if 'audio' in out:
61 | wav, sr = out['audio']
62 |
63 | user_dir = Path(f"{save_path}/{str(request.session_hash)}")
64 |         user_dir.mkdir(exist_ok=True, parents=True)
65 | save_name = f"{user_dir}/{len(history)}.wav"
66 | sf.write(save_name, wav, sr)
67 |
68 | history.append({"role": "assistant", "content": {"path": save_name}})
69 | else:
70 | history.append({"role": "assistant", "content": {"text": out['text']}})
71 |
72 | return history
73 |
74 | def run_tts(text, ref_embs):
75 | formated_history = {
76 | "instruction": "",
77 | "conversations": [{'from': "user", 'text': text}],
78 | }
79 | formated_history["conversations"].append({"from": "assistant"})
80 | ref_embs = torch.tensor(ref_embs, dtype=torch.float32, device="cuda")
81 | ref_embs_mask = torch.tensor([1], device="cuda")
82 | out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_tts", formated_history, ref_embs, ref_embs_mask, max_new_tokens=512)
83 | if 'audio' in out:
84 | wav, sr = out['audio']
85 | return (sr, wav)
86 | else:
87 | raise Exception("No audio output")
88 |
89 | def run_asr(audio):
90 | formated_history = {
91 | "instruction": "",
92 | "conversations": [{'from': "user", 'audio': {"file": audio}}],
93 | }
94 | formated_history["conversations"].append({"from": "assistant"})
95 | out = eval_model(model, tokenizer, tokenizer_voila, model_type, "chat_asr", formated_history, None, None, max_new_tokens=512)
96 | if 'text' in out:
97 | return out['text']
98 | else:
99 | raise Exception("No text output")
100 |
101 |
102 | def markdown_ref_name(ref_name):
103 | return f"### Current voice id: {ref_name}"
104 |
105 | def random_million_voice():
106 | voice_id = random.choice(list(million_voice_ref_emb_mask_list.keys()))
107 | return markdown_ref_name(voice_id), million_voice_ref_emb_mask_list[voice_id]
108 |
109 | def get_ref_modules(cur_ref_embs):
110 | with gr.Row() as ref_row:
111 | with gr.Row():
112 | current_ref_name = gr.Markdown(markdown_ref_name(default_ref_name))
113 | with gr.Row() as ref_name_row:
114 | with gr.Column(scale=2, min_width=160):
115 | ref_name_dropdown = gr.Dropdown(
116 | choices=list(default_ref_emb_mask_list.keys()),
117 | value=default_ref_name,
118 | label="Reference voice",
119 | min_width=160,
120 | )
121 | with gr.Column(scale=1, min_width=80):
122 | random_ref_button = gr.Button(
123 | "Random from Million Voice", size="md",
124 | )
125 | with gr.Row(visible=False) as ref_audio_row:
126 | with gr.Column(scale=2, min_width=80):
127 | ref_audio = gr.Audio(
128 | sources=["microphone", "upload"],
129 | type="filepath",
130 | show_label=False,
131 | min_width=80,
132 | )
133 | with gr.Column(scale=1, min_width=80):
134 | change_ref_button = gr.Button(
135 | "Change voice",
136 | interactive=False,
137 | min_width=80,
138 | )
139 | ref_name_dropdown.change(
140 | lambda x: (markdown_ref_name(x), default_ref_emb_mask_list[x]),
141 | ref_name_dropdown,
142 | [current_ref_name, cur_ref_embs]
143 | )
144 | random_ref_button.click(
145 | random_million_voice,
146 | None,
147 | [current_ref_name, cur_ref_embs],
148 | )
149 | ref_audio.input(lambda: gr.Button(interactive=True), None, change_ref_button)
150 | # If custom ref voice checkbox is checked, show the Audio component to record or upload a reference voice
151 | custom_ref_voice = gr.Checkbox(label="Use custom voice", value=False)
152 | # Checked: enable audio and button
153 | # Unchecked: disable audio and button
154 |     def custom_ref_voice_change(x, cur_ref_embs):
155 | if not x:
156 | cur_ref_embs = default_ref_emb_mask_list[default_ref_name]
157 | return [gr.Row(visible=not x), gr.Audio(value=None), gr.Row(visible=x), markdown_ref_name("Custom voice"), cur_ref_embs]
158 | custom_ref_voice.change(
159 | custom_ref_voice_change,
160 | [custom_ref_voice, cur_ref_embs],
161 | [ref_name_row, ref_audio, ref_audio_row, current_ref_name, cur_ref_embs]
162 | )
163 | # When change ref button is clicked, get the reference voice and update the reference voice state
164 | change_ref_button.click(
165 | lambda: gr.Button(interactive=False), None, [change_ref_button]
166 | ).then(
167 | get_ref_embs, ref_audio, cur_ref_embs
168 | )
169 | return ref_row
170 |
171 | def get_chat_tab():
172 | cur_ref_embs = gr.State(default_ref_emb_mask_list[default_ref_name])
173 | with gr.Row() as chat_tab:
174 | with gr.Column(scale=1):
175 | ref_row = get_ref_modules(cur_ref_embs)
176 | # Voice chat input
177 | chat_input = gr.Audio(
178 | sources=["microphone", "upload"],
179 | type="filepath",
180 | show_label=False,
181 | )
182 | submit = gr.Button("Submit", interactive=False)
183 | gr.Markdown(intro)
184 | with gr.Column(scale=9):
185 | chatbot = gr.Chatbot(
186 | elem_id="chatbot",
187 | type="messages",
188 | bubble_full_width=False,
189 | scale=1,
190 | show_copy_button=False,
191 | avatar_images=(
192 | None, # os.path.join("files", "avatar.png"),
193 | None, # os.path.join("files", "avatar.png"),
194 | ),
195 | )
196 |
197 | chat_input.input(lambda: gr.Button(interactive=True), None, submit)
198 | submit.click(
199 | add_message, [chatbot, chat_input], [chatbot, chat_input, submit]
200 | ).then(
201 | call_bot, [chatbot, cur_ref_embs], chatbot, api_name="bot_response"
202 | )
203 | return chat_tab
204 |
205 | def get_tts_tab():
206 | cur_ref_embs = gr.State(default_ref_emb_mask_list[default_ref_name])
207 | with gr.Row() as tts_tab:
208 | with gr.Column(scale=1):
209 | ref_row = get_ref_modules(cur_ref_embs)
210 | gr.Markdown(intro)
211 | with gr.Column(scale=9):
212 | tts_output = gr.Audio(label="TTS output", interactive=False)
213 | with gr.Row():
214 | text_input = gr.Textbox(label="Text", placeholder="Text to TTS")
215 | submit = gr.Button("Submit")
216 | submit.click(
217 | run_tts, [text_input, cur_ref_embs], tts_output
218 | )
219 | return tts_tab
220 |
221 | def get_asr_tab():
222 | with gr.Row() as asr_tab:
223 | with gr.Column():
224 | asr_input = gr.Audio(
225 | label="ASR input",
226 | sources=["microphone", "upload"],
227 | type="filepath",
228 | )
229 | submit = gr.Button("Submit")
230 | gr.Markdown(intro)
231 | with gr.Column():
232 | asr_output = gr.Textbox(label="ASR output", interactive=False)
233 | submit.click(
234 | run_asr, [asr_input], asr_output
235 | )
236 | return asr_tab
237 |
238 | with gr.Blocks(fill_height=True) as demo:
239 | with gr.Tab("Chat"):
240 | chat_tab = get_chat_tab()
241 | with gr.Tab("TTS"):
242 | tts_tab = get_tts_tab()
243 | with gr.Tab("ASR"):
244 | asr_tab = get_asr_tab()
245 | demo.unload(delete_directory)
246 |
247 | if __name__ == "__main__":
248 | demo.launch(share=True)
249 |
--------------------------------------------------------------------------------
/audio_transformer.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional
3 | from dataclasses import dataclass
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch import Tensor
8 | from torch.nn import functional as F
9 | from torch.nn.attention import SDPBackend, sdpa_kernel
10 | from einops import rearrange
11 |
12 |
13 | @dataclass
14 | class LocalArgs:
15 | codebook_size: int = 2048
16 | num_codebooks: int = 4
17 |
18 | # Modified from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L105
19 | class KVCache(nn.Module):
20 | def __init__(
21 | self, n_layer, batch_size, max_seq_len, n_heads, head_dim, dtype, device
22 | ):
23 | super().__init__()
24 | cache_shape = (n_layer, batch_size, n_heads, max_seq_len, head_dim)
25 | self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype, device=device))
26 | self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype, device=device))
27 |
28 | def update(self, layer_idx, input_pos, k_val, v_val):
29 | # k_val: [B, H, S, D]
30 |
31 | k_out = self.k_cache
32 | v_out = self.v_cache
33 | k_out[layer_idx, :, :, input_pos:input_pos+1] = k_val
34 | v_out[layer_idx, :, :, input_pos:input_pos+1] = v_val
35 |
36 | return k_out[layer_idx], v_out[layer_idx]
37 |
38 | # Modified from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L756
39 | def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
40 | freqs = 1.0 / (
41 | base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
42 | )
43 | t = torch.arange(seq_len, device=freqs.device)
44 | freqs = torch.outer(t, freqs)
45 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
46 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
47 | return cache
48 |
49 | # Copied from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L767
50 | def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
51 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
52 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
53 | x_out2 = torch.stack(
54 | [
55 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
56 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
57 | ],
58 | -1,
59 | )
60 |
61 | x_out2 = x_out2.flatten(3)
62 | return x_out2.type_as(x)
63 |
64 | # Copied from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L742
65 | class RMSNorm(nn.Module):
66 | def __init__(self, dim: int, eps: float = 1e-5):
67 | super().__init__()
68 | self.eps = eps
69 | self.weight = nn.Parameter(torch.ones(dim))
70 |
71 | def _norm(self, x):
72 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
73 |
74 | def forward(self, x: Tensor) -> Tensor:
75 | output = self._norm(x.float()).type_as(x)
76 | return output * self.weight
77 |
78 | # Copied from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L731
79 | class FeedForward(nn.Module):
80 | def __init__(self, config: LocalArgs) -> None:
81 | super().__init__()
82 | self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
83 | self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
84 | self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
85 |
86 | def forward(self, x: Tensor) -> Tensor:
87 | return self.w2(F.silu(self.w1(x)) * self.w3(x))
88 |
89 | # Modified from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L615
90 | class Attention(nn.Module):
91 | def __init__(self, config: LocalArgs, layer_idx: int, use_sdpa: bool = True):
92 | super().__init__()
93 | assert config.dim % config.n_head == 0
94 | self.layer_idx = layer_idx
95 |
96 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
97 | # key, query, value projections for all heads, but in a batch
98 | self.wqkv = nn.Linear(
99 | config.dim, total_head_dim, bias=config.attention_qkv_bias
100 | )
101 | self.wo = nn.Linear(config.dim, config.dim, bias=False)
102 |
103 | self.dropout = config.dropout
104 | self.n_head = config.n_head
105 | self.head_dim = config.head_dim
106 | self.n_local_heads = config.n_local_heads
107 | self.dim = config.dim
108 | self.use_sdpa = use_sdpa
109 | self._register_load_state_dict_pre_hook(self.load_hook)
110 |
111 | def load_hook(self, state_dict, prefix, *args):
112 | if prefix + "wq.weight" in state_dict:
113 | wq = state_dict.pop(prefix + "wq.weight")
114 | wk = state_dict.pop(prefix + "wk.weight")
115 | wv = state_dict.pop(prefix + "wv.weight")
116 | state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
117 |
118 | def forward(
119 | self,
120 | x: Tensor,
121 | freqs_cis: Tensor,
122 | mask: Tensor,
123 | input_pos: Optional[int] = None,
124 | kv_cache: Optional[KVCache] = None,
125 | ) -> Tensor:
126 | bsz, seqlen, _ = x.shape
127 |
128 | kv_size = self.n_local_heads * self.head_dim
129 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
130 |
131 | q = q.view(bsz, seqlen, self.n_head, self.head_dim)
132 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
133 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
134 |
135 | q = apply_rotary_emb(q, freqs_cis)
136 | k = apply_rotary_emb(k, freqs_cis)
137 |
138 | q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
139 |
140 | if kv_cache is not None:
141 | k, v = kv_cache.update(self.layer_idx, input_pos, k, v)
142 |
143 | k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
144 | v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
145 |
146 | if self.use_sdpa:
147 | if mask is None:
148 | with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
149 | y = F.scaled_dot_product_attention(
150 | q,
151 | k,
152 | v,
153 | dropout_p=self.dropout if self.training else 0.0,
154 | is_causal=True,
155 | # No third party attn_mask here to use flash_attention
156 | )
157 | else:
158 | y = F.scaled_dot_product_attention(
159 | q,
160 | k,
161 | v,
162 | attn_mask=mask,
163 | dropout_p=self.dropout if self.training else 0.0,
164 | )
165 | else:
166 | y = self.eq_scaled_dot_product_attention(
167 | q,
168 | k,
169 | v,
170 | attn_mask=mask,
171 | dropout_p=self.dropout if self.training else 0.0,
172 | )
173 |
174 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
175 |
176 | return self.wo(y)
177 |
178 | def eq_scaled_dot_product_attention(
179 | self,
180 | query,
181 | key,
182 | value,
183 | attn_mask=None,
184 | dropout_p=0.0,
185 | ) -> torch.Tensor:
186 |         # This is a standard (unfused) scaled dot product attention.
187 |         # It is less efficient, but it does not raise CUDA errors.
188 |
189 | L, S = query.size(-2), key.size(-2)
190 | scale_factor = 1 / math.sqrt(query.size(-1))
191 | attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
192 |
193 | if attn_mask is not None:
194 | if attn_mask.dtype == torch.bool:
195 | attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
196 | else:
197 | attn_bias += attn_mask
198 |
199 | attn_weight = query @ key.transpose(-2, -1) * scale_factor
200 | attn_weight += attn_bias
201 | attn_weight = torch.softmax(attn_weight, dim=-1)
202 | attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
203 |
204 | return attn_weight @ value
205 |
206 | # Copied from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L599
207 | class TransformerBlock(nn.Module):
208 | def __init__(self, config: LocalArgs, layer_idx: int, use_sdpa: bool = True) -> None:
209 | super().__init__()
210 | self.attention = Attention(config, layer_idx, use_sdpa=use_sdpa)
211 | self.feed_forward = FeedForward(config)
212 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
213 | self.attention_norm = RMSNorm(config.dim, config.norm_eps)
214 |
215 | def forward(
216 | self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: int = None, kv_cache: KVCache = None
217 | ) -> Tensor:
218 | h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, kv_cache)
219 | out = h + self.feed_forward(self.ffn_norm(h))
220 | return out
221 |
222 | # Modified from https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/text2semantic/llama.py#L470
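# Local audio head: the backbone LLM produces one hidden state per audio frame, and this
# small transformer decodes that hidden state into num_codebooks codebook ids
# autoregressively (one id per codebook position; see inference()).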
223 | class AudioTransformer(nn.Module):
224 | def __init__(self, config, use_sdpa: bool = False):
225 | super().__init__()
226 | self.config = LocalArgs()
227 | self.config.codebook_size = config.codebook_size
228 | self.config.num_codebooks = config.num_codebooks
229 | if hasattr(config, "min_audio_token_id"):
230 | self.config.min_audio_token_id = config.min_audio_token_id
231 | self.config.max_audio_token_id = config.max_audio_token_id
232 | self.config.n_layer = 4
233 | self.config.dim = 1024
234 | self.config.n_head = 32
235 | self.config.n_local_heads = 32
236 | self.config.intermediate_size = 2816
237 | self.config.head_dim = self.config.dim // self.config.n_head
238 | self.config.norm_eps = 1e-5
239 | self.config.attention_qkv_bias = False
240 | self.config.dropout = 0.0
241 |
242 | self.embeddings = nn.Embedding(self.config.codebook_size, self.config.dim)
243 | if self.config.dim != config.hidden_size:
244 | self.input_proj = nn.Linear(config.hidden_size, self.config.dim, bias=False)
245 | else:
246 | self.input_proj = nn.Identity()
247 | self.layers = nn.ModuleList(
248 | TransformerBlock(self.config, layer_idx, use_sdpa=use_sdpa) for layer_idx in range(self.config.n_layer)
249 | )
250 | self.norm = RMSNorm(self.config.dim, eps=self.config.norm_eps)
251 | self.token_head = nn.Linear(self.config.dim, self.config.codebook_size, bias=False)
252 | self.gradient_checkpointing = False
253 |
254 | self.register_buffer(
255 | "freqs_cis",
256 | precompute_freqs_cis(self.config.num_codebooks, self.config.dim // self.config.n_head, 10000),
257 | persistent=False,
258 | )
259 | self.register_buffer(
260 | "attention_mask",
261 | torch.tril(torch.ones(self.config.num_codebooks, self.config.num_codebooks, dtype=torch.bool)),
262 | persistent=False,
263 | )
264 |
265 | def run_model(self, hidden_states, freqs_cis, attention_mask, input_pos: int = None, kv_cache: KVCache = None):
266 | for layer in self.layers:
267 |             # TODO: gradient_checkpointing is disabled because of a bug
268 | if False: # self.gradient_checkpointing and self.training:
269 | hidden_states = self._gradient_checkpointing_func(
270 | layer.__call__,
271 | hidden_states,
272 | freqs_cis,
273 | attention_mask,
274 | use_reentrant=True,
275 | )
276 | else:
277 | hidden_states = layer(hidden_states, freqs_cis, attention_mask, input_pos, kv_cache)
278 | hidden_states = self.norm(hidden_states)
279 | logits = self.token_head(hidden_states)
280 | return logits.float()
281 |
282 | # inp: [bs, hidden_size]
283 | # labels: [bs, num_codebooks]
284 | # logits: [bs, num_codebooks, codebook_size]
285 | def forward(self, inp, labels):
286 | bs = inp.shape[0]
287 |
288 | hidden_states = self.input_proj(inp)
289 | if self.freqs_cis.dtype != hidden_states.dtype:
290 | self.freqs_cis = self.freqs_cis.to(dtype=hidden_states.dtype)
291 | if labels is not None:
292 | # Training mode
293 | # Get embedding
294 | assert bs == labels.shape[0] and labels.shape[1] == self.config.num_codebooks, f"Labels shape error: {labels.shape}"
295 | hidden_states = [hidden_states[:, None, :], self.embeddings(labels[..., :-1]).to(hidden_states.dtype)]
296 | hidden_states = torch.cat(hidden_states, dim=1) # [bs, num_codebooks, hidden_size]
297 | # Run attention layers
298 | logits = self.run_model(hidden_states, self.freqs_cis, self.attention_mask)
299 | else:
300 | # Inference mode
301 |             raise RuntimeError('In inference mode, call the "inference" method instead')
302 | return logits
303 |
304 | # inp: [bs, seq_len, hidden_size]
305 | # out_tokens: [bs, 1, num_codebooks]
306 | @torch.inference_mode()
307 | def inference(self, inp, temperature=0, top_k=0):
308 | # Only use the last hidden states for token computation
309 | inp = inp[:, -1:, :]
310 |
311 | bs = inp.shape[0]
312 | if self.freqs_cis.dtype != inp.dtype:
313 | self.freqs_cis = self.freqs_cis.to(dtype=inp.dtype)
314 |
315 | inp = self.input_proj(inp)
316 |
317 | # Inference mode
318 | kv_cache = KVCache(
319 | self.config.n_layer,
320 | bs,
321 | self.config.num_codebooks,
322 | self.config.n_head,
323 | self.config.head_dim,
324 | dtype=inp.dtype,
325 | device=inp.device,
326 | )
327 | # Generate one token per step
328 | out_tokens = []
329 | for input_pos in range(self.config.num_codebooks):
330 | inp = inp.reshape(bs, 1, self.config.dim)
331 | local_freqs_cis = self.freqs_cis[input_pos]
332 | local_mask = self.attention_mask[None, None, input_pos, :self.config.num_codebooks]
333 |
334 | logits = self.run_model(inp, local_freqs_cis, local_mask, input_pos, kv_cache)
335 | logits = logits.squeeze(dim=1)
336 |
337 | # Apply temperature and top-k
338 | if temperature > 0:
339 | logits = logits / temperature
340 | if top_k > 0:
341 | top_k = min(top_k, logits.size(-1)) # Safety check
342 | # Remove all tokens with a probability less than the last token of the top-k
343 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
344 | logits = logits.masked_fill(indices_to_remove, -float("Inf"))
345 |
346 | # Do sample
347 | probs = nn.functional.softmax(logits, dim=-1)
348 | next_tokens = torch.multinomial(probs, num_samples=1)
349 |
350 | next_tokens = next_tokens.reshape(bs, 1, 1)
351 | inp = self.embeddings(next_tokens)
352 | out_tokens.append(next_tokens)
353 |
354 | return torch.cat(out_tokens, dim=-1)
355 |
--------------------------------------------------------------------------------
/tokenize_func.py:
--------------------------------------------------------------------------------
1 | import io
2 | import copy
3 | import librosa
4 | import numpy as np
5 |
6 |
7 | AUDIO_TOKEN_FORMAT = "<|{}|>"
8 |
9 | DEFAULT_SYSTEM_START_TOKEN = ""
10 | DEFAULT_SYSTEM_END_TOKEN = ""
11 |
12 | DEFAULT_TTS_REF_START_TOKEN = ""
13 | DEFAULT_TTS_REF_END_TOKEN = ""
14 | DEFAULT_TTS_REF_TOKEN = ""
15 |
16 | DEFAULT_CHAT_REF_START_TOKEN = ""
17 | DEFAULT_CHAT_REF_END_TOKEN = ""
18 | DEFAULT_CHAT_REF_TOKEN = ""
19 |
20 | DEFAULT_HUMAN_TOKEN = "<|HUMAN|>"
21 | DEFAULT_ASSISTANT_TOKEN = "<|VOILA|>"
22 |
23 | DEFAULT_AUDIO_TOKEN = ""
24 |
25 | # ===================================
26 | # task special token
27 | # -----------------------------------
28 | TASK_ASR_TOKEN = ""
29 | TASK_TTS_TOKEN = ""
30 | TASK_CHAT_TOKEN = ""
31 | TASK_STREAM_CHAT_TOKEN = ""
32 |
33 | TASK_ASR_TEXT_OUTPUT = ""
34 | TASK_TTS_AUDIO_OUTPUT = ""
35 | TASK_CHAT_TEXT_OUTPUT = ""
36 | TASK_CHAT_AUDIO_OUTPUT = ""
37 |
38 | CHAT_AUDIO_TEXT_SPLIT_TOKEN = ""
39 | # ===================================
40 |
41 | PREPEND_LEN = 80
42 | SEG_LEN = 640
43 | AUDIO_SR = 16000
44 |
45 | TASK_TYPE_CONF = {
46 | "chat_asr": TASK_ASR_TOKEN + TASK_ASR_TEXT_OUTPUT,
47 | "chat_tts": TASK_TTS_TOKEN + TASK_TTS_AUDIO_OUTPUT,
48 | "chat_tito": TASK_CHAT_TOKEN + TASK_CHAT_TEXT_OUTPUT,
49 | "chat_tiao": TASK_CHAT_TOKEN + TASK_CHAT_AUDIO_OUTPUT,
50 | "chat_aiao": TASK_CHAT_TOKEN + TASK_CHAT_AUDIO_OUTPUT,
51 | "chat_atiao": TASK_CHAT_TOKEN + TASK_CHAT_AUDIO_OUTPUT,
52 | "chat_aiao_auto": TASK_STREAM_CHAT_TOKEN + TASK_CHAT_AUDIO_OUTPUT,
53 | }
54 |
55 |
56 | def _get_zero_audio_pad(token_num):
57 | return np.zeros(SEG_LEN*token_num)
58 |
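# Render per-codebook integer codes as "<|id|>" token strings, offsetting codebook n by
# n*codebook_size so each codebook occupies its own slice of the flat audio-token range.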
59 | def _wrapper_audio_tokens(audio_tokens, num_codebooks, codebook_size):
60 | ret_audio_tokens = []
61 | for n in range(num_codebooks):
62 | audio_token = audio_tokens[n]
63 | ret_audio_tokens.append(''.join([AUDIO_TOKEN_FORMAT.format(au + n*codebook_size) if isinstance(au, int) else au for au in audio_token]))
64 | return ret_audio_tokens
65 |
66 | def _wrapper_audio_tokens_autonomous(audio_tokens, num_codebooks, codebook_size, audio_token_min_id):
67 | ret_audio_tokens = []
68 | for n in range(num_codebooks):
69 | audio_token = audio_tokens[n]
70 | ret_audio_tokens.append([(au + n*codebook_size + audio_token_min_id) for au in audio_token])
71 | return ret_audio_tokens
72 |
73 | # Item format
74 | # {
75 | # "instruction": "",
76 | # "conversations": [
77 | # {
78 | # "from": "user" or "assistant",
79 | # "text": "",
80 | # "audio": {
81 | # "array": [],
82 | # "sr": 16000,
83 | # "bytes": "",
84 | # "file": "",
85 | # },
86 | # }
87 | # ],
88 | # }
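#
# For example, infer.py builds the following item for a single voice-chat turn
# (audio given as a file path; the trailing empty assistant turn marks where
# generation should start):
# {
#     "instruction": "",
#     "conversations": [
#         {"from": "user", "audio": {"file": "examples/test1.mp3"}},
#         {"from": "assistant"},
#     ],
# }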
89 | def _token_input_format(item, tokenizer, tokenizer_voila, dataset_cfg):
90 | task_type = dataset_cfg["task_type"]
91 | num_codebooks = dataset_cfg["num_codebooks"]
92 | codebook_size = dataset_cfg["codebook_size"]
93 |
94 | task_token = TASK_TYPE_CONF[task_type]
95 |
96 | # Construct system message
97 | system = item["instruction"]
98 | if task_type in ["chat_aiao", "chat_atiao", "chat_tiao"]:
99 | system = DEFAULT_CHAT_REF_START_TOKEN + DEFAULT_CHAT_REF_TOKEN + DEFAULT_CHAT_REF_END_TOKEN + system
100 | elif task_type == "chat_tts":
101 | system = DEFAULT_TTS_REF_START_TOKEN + DEFAULT_TTS_REF_TOKEN + DEFAULT_TTS_REF_END_TOKEN + system
102 | else:
103 |         print(f"task type {task_type} does not use a reference voice.")
104 | system = task_token + system
105 | system = DEFAULT_SYSTEM_START_TOKEN + system + DEFAULT_SYSTEM_END_TOKEN
106 |
107 | # Get ids for system
108 | system_ids = tokenizer.encode(system, add_special_tokens=False)
109 |
110 | # Copy into num_codebooks input ids
111 | input_ids_list = []
112 | for _ in range(num_codebooks):
113 | input_ids_list.append(copy.deepcopy(system_ids))
114 |
115 | # Assemble conversations
116 | for i, turn in enumerate(item["conversations"]):
117 | if turn['from'] == 'assistant':
118 | # task with audio token as input, prepare audio token
119 | if task_type in ["chat_aiao", "chat_tts"]:
120 | if "audio" not in turn:
121 | content = DEFAULT_ASSISTANT_TOKEN
122 | content_ids = tokenizer.encode(content, add_special_tokens=False)
123 | for n in range(num_codebooks):
124 | input_ids_list[n] += copy.deepcopy(content_ids)
125 | else:
126 | # Load audio
127 | if 'array' in turn['audio']:
128 | assert "sr" in turn["audio"]
129 | if len(turn["audio"]['array'].shape) > 1:
130 | assert turn["audio"]['array'].shape[0] <= 2
131 | turn["audio"]['array'] = librosa.to_mono(turn["audio"]['array'])
132 | audio = librosa.resample(turn["audio"]['array'], orig_sr=turn["audio"]["sr"], target_sr=AUDIO_SR)
133 | elif "bytes" in turn['audio']:
134 | audio, _ = librosa.load(io.BytesIO(turn["audio"]['bytes']), sr=AUDIO_SR)
135 | elif "file" in turn['audio']:
136 | audio, _ = librosa.load(turn["audio"]['file'], sr=AUDIO_SR)
137 | else:
138 | raise Exception(f"No audio input for task {task_type}")
139 |
140 | # get audio token
141 | audio_tokens = tokenizer_voila.encode(audio, sr=AUDIO_SR)
142 | audio_tokens = audio_tokens.cpu().numpy().tolist()
143 | audio_tokens = _wrapper_audio_tokens(audio_tokens, num_codebooks, codebook_size)
144 |
145 | for n in range(num_codebooks):
146 | content = DEFAULT_ASSISTANT_TOKEN + audio_tokens[n] + tokenizer.eos_token
147 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
148 | max_length=tokenizer.model_max_length)
149 | input_ids_list[n] += content_ids
150 |
151 | elif task_type in ["chat_tito", "chat_asr"]:
152 | if "text" not in turn:
153 | content = DEFAULT_ASSISTANT_TOKEN
154 | content_ids = tokenizer.encode(content, add_special_tokens=False)
155 | for n in range(num_codebooks):
156 | input_ids_list[n] += copy.deepcopy(content_ids)
157 | else:
158 | text = turn['text'].strip()
159 | content = DEFAULT_ASSISTANT_TOKEN + text + tokenizer.eos_token
160 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
161 | max_length=tokenizer.model_max_length)
162 | for n in range(num_codebooks):
163 | input_ids_list[n] += copy.deepcopy(content_ids)
164 | else:
165 |                 raise ValueError(f"Invalid task type {task_type} for assistant turn.")
166 | else:
167 | # task with audio token as input, prepare audio token
168 | if task_type in ["chat_aiao", "chat_asr"]:
169 | # Load audio
170 | assert "audio" in turn
171 | if 'array' in turn['audio']:
172 | assert "sr" in turn["audio"]
173 | if len(turn["audio"]['array'].shape) > 1:
174 | assert turn["audio"]['array'].shape[0] <= 2
175 | turn["audio"]['array'] = librosa.to_mono(turn["audio"]['array'])
176 | audio = librosa.resample(turn["audio"]['array'], orig_sr=turn["audio"]["sr"], target_sr=AUDIO_SR)
177 | elif "bytes" in turn['audio']:
178 | audio, _ = librosa.load(io.BytesIO(turn["audio"]['bytes']), sr=AUDIO_SR)
179 | elif "file" in turn['audio']:
180 | audio, _ = librosa.load(turn["audio"]['file'], sr=AUDIO_SR)
181 | else:
182 | raise Exception(f"No audio input for task {task_type}")
183 |
184 | # get audio token
185 | audio_tokens = tokenizer_voila.encode(audio, sr=AUDIO_SR)
186 | audio_tokens = audio_tokens.cpu().numpy().tolist()
187 | audio_tokens = _wrapper_audio_tokens(audio_tokens, num_codebooks, codebook_size)
188 |
189 | for n in range(num_codebooks):
190 | content = DEFAULT_HUMAN_TOKEN + audio_tokens[n]
191 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
192 | max_length=tokenizer.model_max_length)
193 | input_ids_list[n] += copy.deepcopy(content_ids)
194 | elif task_type in ["chat_tito", "chat_tts"]:
195 | text = turn['text'].strip()
196 | content = DEFAULT_HUMAN_TOKEN + text
197 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
198 | max_length=tokenizer.model_max_length)
199 | for n in range(num_codebooks):
200 | input_ids_list[n] += copy.deepcopy(content_ids)
201 | else:
202 |                 raise ValueError(f"Invalid task type {task_type} for user turn.")
203 |
204 | for n in range(num_codebooks):
205 | input_ids_list[n] = input_ids_list[n][:tokenizer.model_max_length]
206 |
207 | return input_ids_list, None, None, None
208 |
209 | def _token_input_format_autonomous(item, tokenizer, tokenizer_voila, dataset_cfg):
210 | task_type = dataset_cfg["task_type"]
211 | num_codebooks = dataset_cfg["num_codebooks"]
212 | codebook_size = dataset_cfg["codebook_size"]
213 |     assert task_type == "chat_aiao_auto", f"only chat_aiao_auto is supported, {task_type} is invalid"
214 |
215 | DEFAULT_HUMAN_TOKEN_ID = tokenizer.convert_tokens_to_ids(DEFAULT_HUMAN_TOKEN)
216 | assert isinstance(DEFAULT_HUMAN_TOKEN_ID, int), "DEFAULT_HUMAN_TOKEN_ID should be an integer"
217 | AUDIO_MIN_TOKEN_ID = tokenizer.convert_tokens_to_ids(AUDIO_TOKEN_FORMAT.format(0))
218 | assert isinstance(AUDIO_MIN_TOKEN_ID, int), "AUDIO_MIN_TOKEN_ID should be an integer"
219 |
220 | task_token = TASK_TYPE_CONF[task_type]
221 |
222 | # Construct system message
223 | system = DEFAULT_CHAT_REF_START_TOKEN + DEFAULT_CHAT_REF_TOKEN + DEFAULT_CHAT_REF_END_TOKEN
224 | system = task_token + system
225 | system = DEFAULT_SYSTEM_START_TOKEN + system + DEFAULT_SYSTEM_END_TOKEN
226 |
227 | # Get ids for system
228 | system_ids_list = [[], []]
229 | system_ids = tokenizer.encode(system, add_special_tokens=False)
230 |
231 | # Insert instruction tokens into system prompt tokens
232 | instruction = item["instruction"]
233 | if instruction != "":
234 | instruction_ids = tokenizer.encode(instruction, add_special_tokens=False)
235 | else:
236 | instruction_ids = []
237 |
238 | system_ids_list[0] = system_ids[:-1] + instruction_ids + system_ids[-1:]
239 | system_ids_list[1] = system_ids[:-1] + instruction_ids + system_ids[-1:]
240 |
241 | # Copy into num_codebooks input ids
242 | channel1_input_ids_list = [[] for _ in range(num_codebooks)]
243 | channel2_input_ids_list = [[] for _ in range(num_codebooks)]
244 | for n in range(num_codebooks):
245 | channel1_input_ids_list[n] += copy.deepcopy(system_ids_list[0]) + [DEFAULT_HUMAN_TOKEN_ID]
246 | channel2_input_ids_list[n] += copy.deepcopy(system_ids_list[1]) + [DEFAULT_HUMAN_TOKEN_ID]
247 |
248 | # prepare audio token to simulate streaming input
249 | audio_meta = item['conversations'][0]['audio']
250 | if 'array' in audio_meta:
251 | assert "sr" in audio_meta
252 | if len(audio_meta['array'].shape) > 1:
253 | assert audio_meta['array'].shape[0] <= 2
254 | audio_meta['array'] = librosa.to_mono(audio_meta['array'])
255 | audio = librosa.resample(audio_meta['array'], orig_sr=audio_meta["sr"], target_sr=AUDIO_SR)
256 | elif "bytes" in audio_meta:
257 | audio, _ = librosa.load(io.BytesIO(audio_meta['bytes']), sr=AUDIO_SR)
258 | elif "file" in audio_meta:
259 | audio, _ = librosa.load(audio_meta['file'], sr=AUDIO_SR)
260 | else:
261 | raise Exception(f"No audio input for task {task_type}")
262 |
263 | # get audio token
264 | streaming_user_input_audio_tokens = tokenizer_voila.encode(audio, sr=AUDIO_SR)
265 | streaming_user_input_audio_tokens = streaming_user_input_audio_tokens.cpu().numpy().tolist()
266 | streaming_user_input_audio_tokens = _wrapper_audio_tokens_autonomous(streaming_user_input_audio_tokens, num_codebooks, codebook_size, AUDIO_MIN_TOKEN_ID)
267 |
268 | return [channel1_input_ids_list, channel2_input_ids_list], None, None, streaming_user_input_audio_tokens
269 |
270 | def _alpha_audio_input_format(item, tokenizer, dataset_cfg):
271 | task_type = dataset_cfg["task_type"]
272 | num_codebooks = dataset_cfg["num_codebooks"]
273 | codebook_size = dataset_cfg["codebook_size"]
274 |
275 | task_token = TASK_TYPE_CONF[task_type]
276 |
277 | # Construct system message
278 | system = item["instruction"]
279 | if task_type in ["chat_aiao", "chat_atiao", "chat_tiao"]:
280 | system = DEFAULT_CHAT_REF_START_TOKEN + DEFAULT_CHAT_REF_TOKEN + DEFAULT_CHAT_REF_END_TOKEN + system
281 | elif task_type == "chat_tts":
282 | system = DEFAULT_TTS_REF_START_TOKEN + DEFAULT_TTS_REF_TOKEN + DEFAULT_TTS_REF_END_TOKEN + system
283 | else:
284 |         print(f"task type {task_type} does not use ref.")
285 | system = task_token + system
286 | system = DEFAULT_SYSTEM_START_TOKEN + system + DEFAULT_SYSTEM_END_TOKEN
287 |
288 | # Get ids for system
289 | system_ids = tokenizer.encode(system, add_special_tokens=False)
290 |
291 | # Copy into num_codebooks input ids
292 | input_ids_list = []
293 | for _ in range(num_codebooks):
294 | input_ids_list.append(copy.deepcopy(system_ids))
295 |
296 | # Construct audio data and mask
297 | audio_data = [np.array([0]*PREPEND_LEN)]
298 | audio_data.append(_get_zero_audio_pad(len(system_ids)))
299 | audio_data_mask = [0] * len(system_ids)
300 |
301 | # Assemble conversations
302 | for i, turn in enumerate(item["conversations"]):
303 | if turn['from'] == 'assistant':
304 |             # task with audio token as output, prepare audio token
305 | if task_type in ["chat_aiao"]:
306 | if "audio" not in turn:
307 | content = DEFAULT_ASSISTANT_TOKEN
308 | content_ids = tokenizer.encode(content, add_special_tokens=False)
309 | for n in range(num_codebooks):
310 | input_ids_list[n] += copy.deepcopy(content_ids)
311 | # preprocess audio_data & audio_data_mask
312 | audio_data.append(_get_zero_audio_pad(len(content_ids)))
313 | audio_data_mask += [0] * len(content_ids)
314 | else:
315 | # Load audio
316 | if 'array' in turn['audio']:
317 | assert "sr" in turn["audio"]
318 | if len(turn["audio"]['array'].shape) > 1:
319 | assert turn["audio"]['array'].shape[0] <= 2
320 | turn["audio"]['array'] = librosa.to_mono(turn["audio"]['array'])
321 | audio = librosa.resample(turn["audio"]['array'], orig_sr=turn["audio"]["sr"], target_sr=AUDIO_SR)
322 | elif "bytes" in turn['audio']:
323 | audio, _ = librosa.load(io.BytesIO(turn["audio"]['bytes']), sr=AUDIO_SR)
324 | elif "file" in turn['audio']:
325 | audio, _ = librosa.load(turn["audio"]['file'], sr=AUDIO_SR)
326 | else:
327 | raise Exception(f"No audio input for task {task_type}")
328 |
329 | # get audio token
330 | audio_token_num = int(len(audio) / SEG_LEN)
331 | audio_token = [DEFAULT_AUDIO_TOKEN] * audio_token_num
332 | audio_token = ''.join(audio_token)
333 | audio = audio[:SEG_LEN*audio_token_num] # trim audio
334 |
335 | content = DEFAULT_ASSISTANT_TOKEN + audio_token + tokenizer.eos_token
336 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
337 | max_length=tokenizer.model_max_length)
338 | for n in range(num_codebooks):
339 | input_ids_list[n] += copy.deepcopy(content_ids)
340 |
341 | audio_data.append(_get_zero_audio_pad(1))
342 | audio_data_mask += [0]
343 | audio_data.append(audio)
344 | audio_data_mask += [1] * audio_token_num
345 | audio_data.append(_get_zero_audio_pad(1))
346 | audio_data_mask += [0]
347 | elif task_type in ["chat_tito"]:
348 | if "text" not in turn:
349 | content = DEFAULT_ASSISTANT_TOKEN
350 | content_ids = tokenizer.encode(content, add_special_tokens=False)
351 | for n in range(num_codebooks):
352 | input_ids_list[n] += copy.deepcopy(content_ids)
353 | # preprocess audio_data & audio_data_mask
354 | audio_data.append(_get_zero_audio_pad(len(content_ids)))
355 | audio_data_mask += [0] * len(content_ids)
356 | else:
357 | text = turn['text'].strip()
358 | content = DEFAULT_ASSISTANT_TOKEN + text + tokenizer.eos_token
359 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
360 | max_length=tokenizer.model_max_length)
361 | for n in range(num_codebooks):
362 | input_ids_list[n] += copy.deepcopy(content_ids)
363 | audio_data.append(_get_zero_audio_pad(len(content_ids)))
364 | audio_data_mask += [0] * len(content_ids)
365 | else:
366 |                     raise ValueError(f"[Error] Invalid task type: {task_type}.")
367 | else:
368 | # task with audio token as input, prepare audio token
369 | if task_type in ["chat_aiao"]:
370 | # Load audio
371 | assert "audio" in turn
372 | if 'array' in turn['audio']:
373 | assert "sr" in turn["audio"]
374 | if len(turn["audio"]['array'].shape) > 1:
375 | assert turn["audio"]['array'].shape[0] <= 2
376 | turn["audio"]['array'] = librosa.to_mono(turn["audio"]['array'])
377 | audio = librosa.resample(turn["audio"]['array'], orig_sr=turn["audio"]["sr"], target_sr=AUDIO_SR)
378 | elif "bytes" in turn['audio']:
379 | audio, _ = librosa.load(io.BytesIO(turn["audio"]['bytes']), sr=AUDIO_SR)
380 | elif "file" in turn['audio']:
381 | audio, _ = librosa.load(turn["audio"]['file'], sr=AUDIO_SR)
382 | else:
383 | raise Exception(f"No audio input for task {task_type}")
384 |
385 | # get audio token
386 | audio_token_num = int(len(audio) / SEG_LEN)
387 | audio_token = [DEFAULT_AUDIO_TOKEN] * audio_token_num
388 | audio_token = ''.join(audio_token)
389 | audio = audio[:SEG_LEN*audio_token_num] # trim audio
390 |
391 | content = DEFAULT_HUMAN_TOKEN + audio_token
392 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
393 | max_length=tokenizer.model_max_length)
394 | for n in range(num_codebooks):
395 | input_ids_list[n] += copy.deepcopy(content_ids)
396 |
397 | audio_data.append(_get_zero_audio_pad(1))
398 | audio_data_mask += [0]
399 | audio_data.append(audio)
400 | audio_data_mask += [1] * audio_token_num
401 | elif task_type in ["chat_tito"]:
402 | text = turn['text'].strip()
403 | content = DEFAULT_HUMAN_TOKEN + text
404 | content_ids = tokenizer.encode(content, add_special_tokens=False, truncation=True,
405 | max_length=tokenizer.model_max_length)
406 | for n in range(num_codebooks):
407 | input_ids_list[n] += copy.deepcopy(content_ids)
408 | audio_data.append(_get_zero_audio_pad(len(content_ids)))
409 | audio_data_mask += [0] * len(content_ids)
410 | else:
411 |                     raise ValueError(f"[Error] Invalid task type: {task_type}.")
412 |
413 | for n in range(num_codebooks):
414 | input_ids_list[n] = input_ids_list[n][:tokenizer.model_max_length]
415 | audio_data_mask = audio_data_mask[:tokenizer.model_max_length]
416 | audio_data = np.concatenate(audio_data)
417 | audio_data = audio_data[:PREPEND_LEN + tokenizer.model_max_length*SEG_LEN]
418 |
419 | return input_ids_list, audio_data, audio_data_mask, None
420 |
421 | # Item format
422 | # {
423 | # "instruction": "",
424 | # "conversations": [
425 | # {
426 | # "from": "user" or "assistant",
427 | # "text": "",
428 | # "audio": {
429 | # "array": [],
430 | # "sr": 16000,
431 | # "bytes": "",
432 | # "file": "",
433 | # },
434 | # }
435 | # ],
436 | # }
437 | def voila_input_format(item, tokenizer, tokenizer_voila, dataset_cfg):
438 | if dataset_cfg["input_type"] == "audio":
439 | return _alpha_audio_input_format(item, tokenizer, dataset_cfg)
440 | elif dataset_cfg["input_type"] == "autonomous":
441 | return _token_input_format_autonomous(item, tokenizer, tokenizer_voila, dataset_cfg)
442 | else:
443 | return _token_input_format(item, tokenizer, tokenizer_voila, dataset_cfg)
444 |
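445 | # Example usage (illustrative only): `tokenizer` is assumed to be the text
446 | # tokenizer and `tokenizer_voila` the Voila audio tokenizer, both loaded
447 | # elsewhere; the audio path and the codebook numbers below are placeholders.
448 | #
449 | #   item = {
450 | #       "instruction": "You are a helpful voice assistant.",
451 | #       "conversations": [
452 | #           {"from": "user", "audio": {"file": "path/to/query.wav"}},
453 | #       ],
454 | #   }
455 | #   dataset_cfg = {"task_type": "chat_aiao", "input_type": "token",
456 | #                  "num_codebooks": 4, "codebook_size": 2048}
457 | #   input_ids_list, audio_data, audio_data_mask, streaming_tokens = \
458 | #       voila_input_format(item, tokenizer, tokenizer_voila, dataset_cfg)
459 | #
460 | #   # input_ids_list holds one token sequence per codebook; the remaining
461 | #   # return values are None for token-input (non-"audio"/"autonomous") tasks.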
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 | from typing import List, Optional, Tuple, Union, Dict, Any
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from torch.nn import CrossEntropyLoss
9 |
10 | from transformers.cache_utils import Cache, DynamicCache
11 | from transformers.utils import ModelOutput, logging
12 | from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel
13 |
14 | from audio_transformer import AudioTransformer
15 |
16 | logger = logging.get_logger(__name__)
17 |
18 |
19 | # Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L43
20 | class LayerNorm(torch.nn.LayerNorm):
21 | """Layer norm with transpose"""
22 |
23 | def forward(self, input: torch.Tensor) -> torch.Tensor:
24 | x = input.transpose(-2, -1)
25 | x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
26 | x = x.transpose(-2, -1)
27 | return x
28 |
29 | # Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L53
30 | class ConvLayerBlock(torch.nn.Module):
31 | """Convolution unit of FeatureExtractor"""
32 |
33 | def __init__(
34 | self,
35 | in_channels: int,
36 | out_channels: int,
37 | kernel_size: int,
38 | stride: int,
39 | bias: bool,
40 | layer_norm: Optional[torch.nn.Module],
41 | ):
42 | super().__init__()
43 | self.kernel_size = kernel_size
44 | self.stride = stride
45 | self.layer_norm = layer_norm
46 | self.conv = torch.nn.Conv1d(
47 | in_channels=in_channels,
48 | out_channels=out_channels,
49 | kernel_size=kernel_size,
50 | stride=stride,
51 | bias=bias,
52 | )
53 |
54 | def forward(
55 | self,
56 | x: torch.Tensor,
57 |     ) -> torch.Tensor:
58 | """
59 | Args:
60 | x (Tensor): Shape: ``[batch, in_channels, in_frame]``.
61 | Returns:
62 |             Tensor:
63 |                 Shape ``[batch, out_channels, out_frames]``.
64 | """
65 | x = self.conv(x)
66 | if self.layer_norm is not None:
67 | x = self.layer_norm(x)
68 | x = torch.nn.functional.gelu(x)
69 |
70 | return x
71 |
72 | # Copied from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L146
73 | class FeatureProjection(torch.nn.Module):
74 | """Layer that connects FeatureExtractor and Encoder
75 |
76 | Projects features to encoder dimension.
77 |
78 | Args:
79 | in_features (int): Input feature dim.
80 | out_features (int): Output feature dim.
81 | dropout (float): Dropout probability.
82 | """
83 |
84 | def __init__(
85 | self,
86 | in_features: int,
87 | out_features: int,
88 | dropout=0.1,
89 | ):
90 | super().__init__()
91 | self.layer_norm = torch.nn.LayerNorm(in_features)
92 | self.projection = torch.nn.Linear(
93 | in_features,
94 | out_features,
95 | )
96 | self.dropout = torch.nn.Dropout(dropout)
97 |
98 | def forward(self, x):
99 | """
100 | Args:
101 | x (Tensor):
102 | Feature Tensor. shape: ``[batch, frame, in_feature]``
103 | Returns:
104 | Tensor: Projected features. ``[batch, frame, out_feature]``.
105 | """
106 | x = self.layer_norm(x)
107 | x = self.projection(x)
108 | x = self.dropout(x)
109 | return x
110 |
111 | # Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102
112 | class FeatureExtractor(torch.nn.Module):
113 | """Extract features from audio
114 |
115 | Args:
116 |         shapes (list of tuple):
117 |             ``(out_channels, kernel_size, stride)`` for each convolution layer
118 | """
119 |
120 | def __init__(
121 | self,
122 | shapes=[(512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 3, 2), (512, 2, 2), (512, 2, 2)],
123 | bias=False,
124 | norm_mode="group_norm",
125 | ):
126 | super().__init__()
127 | if norm_mode not in ["group_norm", "layer_norm"]:
128 | raise ValueError("Invalid norm mode")
129 | blocks = []
130 | in_channels = 1
131 | for i, (out_channels, kernel_size, stride) in enumerate(shapes):
132 | normalization = None
133 | if norm_mode == "group_norm" and i == 0:
134 | normalization = torch.nn.GroupNorm(
135 | num_groups=out_channels,
136 | num_channels=out_channels,
137 | affine=True,
138 | )
139 | elif norm_mode == "layer_norm":
140 | normalization = LayerNorm(
141 | normalized_shape=out_channels,
142 | elementwise_affine=True,
143 | )
144 | blocks.append(
145 | ConvLayerBlock(
146 | in_channels=in_channels,
147 | out_channels=out_channels,
148 | kernel_size=kernel_size,
149 | stride=stride,
150 | bias=bias,
151 | layer_norm=normalization,
152 | )
153 | )
154 | in_channels = out_channels
155 | self.conv_layers = torch.nn.ModuleList(blocks)
156 |
157 | def forward(
158 | self,
159 | x: torch.Tensor,
160 |     ) -> torch.Tensor:
161 | """
162 | Args:
163 | x (Tensor):
164 | Input Tensor representing a batch of audio,
165 | shape: ``[batch, time]``.
166 |
167 | Returns:
168 |             Tensor:
169 |                 The resulting feature, shape: ``[batch, frame, feature]``,
170 |                 where the frame axis is the time axis downsampled by the
171 |                 convolution strides.
172 | """
173 | if x.ndim != 2:
174 | raise ValueError(f"Expected the input Tensor to be 2D (batch, time). Found: {list(x.shape)}")
175 |
176 | x = x.unsqueeze(1) # (batch, channel==1, frame)
177 | for layer in self.conv_layers:
178 | x = layer(x) # (batch, feature, frame)
179 | x = x.transpose(1, 2) # (batch, frame, feature)
180 | return x
181 |
182 | # Modified from https://github.com/pytorch/audio/blob/main/src/torchaudio/models/wav2vec2/components.py#L102
183 | class FeatureExtractorAdapter(torch.nn.Module):
184 | """Extract features from audio
185 |
186 | Args:
187 | conv_layers (nn.ModuleList):
188 | convolution layers
189 | """
190 |
191 | def __init__(
192 | self,
193 | shapes=(512, 512, 2, 2),
194 | hidden_size=2048,
195 | bias=False,
196 | norm_mode="group_norm",
197 | ):
198 | super().__init__()
199 | if norm_mode not in ["group_norm", "layer_norm"]:
200 | raise ValueError("Invalid norm mode")
201 | in_channels, out_channels, kernel_size, stride = shapes
202 | normalization = LayerNorm(
203 | normalized_shape=out_channels,
204 | elementwise_affine=True,
205 | )
206 | self.conv_layers = ConvLayerBlock(
207 | in_channels=in_channels,
208 | out_channels=out_channels,
209 | kernel_size=kernel_size,
210 | stride=stride,
211 | bias=False,
212 | layer_norm=normalization,
213 | )
214 | self.feat_proj = FeatureProjection(out_channels, hidden_size)
215 |
216 | def forward(
217 | self,
218 | x: torch.Tensor,
219 |     ) -> torch.Tensor:
220 | """
221 | Args:
222 |             x (Tensor):
223 |                 Features produced by ``FeatureExtractor``,
224 |                 shape: ``[batch, frame, in_feature]``.
225 | 
226 |         Returns:
227 |             Tensor:
228 |                 The projected feature, shape: ``[batch, frame, hidden_size]``,
229 |                 with the frame axis further downsampled by the
230 |                 convolution stride.
231 | """
232 | x = x.transpose(1, 2) # (batch, feature, frame)
233 | x = self.conv_layers(x) # (batch, feature, frame)
234 | x = x.transpose(1, 2) # (batch, frame, feature)
235 | x = self.feat_proj(x)
236 | return x
237 |
238 | @dataclass
239 | class VoilaOutput(ModelOutput):
240 | """
241 | Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_outputs.py#L678
242 |
243 | Base class for Voila outputs.
244 |
245 | Args:
246 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
247 | Language modeling loss (for next-token prediction).
248 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
249 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
250 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
251 | The hidden state of the last attention layer.
252 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
253 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
254 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
255 |
256 | Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
257 | `past_key_values` input) to speed up sequential decoding.
258 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
259 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
260 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
261 |
262 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
263 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
264 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
265 | sequence_length)`.
266 |
267 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
268 | heads.
269 | """
270 |
271 | loss: Optional[torch.FloatTensor] = None
272 | logits: torch.FloatTensor = None
273 | last_hidden_state: torch.FloatTensor = None
274 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
275 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
276 | attentions: Optional[Tuple[torch.FloatTensor]] = None
277 | voila_pred: Optional[torch.FloatTensor] = None
278 |
279 |
280 | # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
281 | class VoilaModel(LlamaPreTrainedModel):
282 | _tied_weights_keys = ["lm_head.weight"]
283 |
284 | def __init__(self, config):
285 | super().__init__(config)
286 | self.model = LlamaModel(config)
287 | self.vocab_size = config.vocab_size
288 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
289 | self.pad_vocab_size_multiple = 64
290 |
291 | self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
292 | self.audio_transformer = AudioTransformer(config, use_sdpa=False)
293 |
294 | # Initialize weights and apply final processing
295 | self.post_init()
296 |
297 | def get_input_embeddings(self):
298 | return self.model.embed_tokens
299 |
300 | def set_input_embeddings(self, value):
301 | self.model.embed_tokens = value
302 |
303 | def get_output_embeddings(self):
304 | return self.lm_head
305 |
306 | def set_output_embeddings(self, new_embeddings):
307 | self.lm_head = new_embeddings
308 |
309 | def set_decoder(self, decoder):
310 | self.model = decoder
311 |
312 | def get_decoder(self):
313 | return self.model
314 |
315 | def forward(
316 | self,
317 | input_ids: torch.LongTensor = None,
318 | attention_mask: Optional[torch.Tensor] = None,
319 | position_ids: Optional[torch.LongTensor] = None,
320 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
321 | inputs_embeds: Optional[torch.FloatTensor] = None,
322 | labels: Optional[torch.LongTensor] = None,
323 | audio_labels: Optional[torch.LongTensor] = None,
324 | ref_embs: Optional[List[torch.Tensor]] = None,
325 | ref_embs_mask: Optional[torch.LongTensor] = None,
326 | use_cache: Optional[bool] = None,
327 | output_attentions: Optional[bool] = None,
328 | output_hidden_states: Optional[bool] = None,
329 | return_dict: Optional[bool] = None,
330 | cache_position: Optional[torch.LongTensor] = None,
331 | num_logits_to_keep: int = 0,
332 | ) -> Union[Tuple, VoilaOutput]:
333 | r"""
334 | Args:
335 | input_ids: [bs, seq_len, num_codebooks]
336 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
337 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
338 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
339 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
340 | """
341 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
342 | output_hidden_states = (
343 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
344 | )
345 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
346 |
347 | if input_ids is not None and inputs_embeds is not None:
348 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
349 | if inputs_embeds is None:
350 | inputs_embeds = self.model.embed_tokens(input_ids)
351 | assert len(inputs_embeds.shape) == 4
352 | if len(inputs_embeds.shape) == 4:
353 | inputs_embeds = inputs_embeds.mean(dim=2)
354 |
355 | if self.training or \
356 | (past_key_values is None and ref_embs is not None) or \
357 | (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None):
358 | ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
359 | ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
360 | # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back)
361 |             padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)  # pad the single ref-emb step to seq_len, placing it at position 4
362 | ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0)
363 | inputs_embeds = inputs_embeds + ref_embs
364 |
365 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
366 | outputs = self.model(
367 | attention_mask=attention_mask,
368 | position_ids=position_ids,
369 | past_key_values=past_key_values,
370 | inputs_embeds=inputs_embeds,
371 | use_cache=use_cache,
372 | output_attentions=output_attentions,
373 | output_hidden_states=output_hidden_states,
374 | return_dict=return_dict,
375 | cache_position=cache_position,
376 | )
377 |
378 | hidden_states = outputs[0]
379 | if self.config.pretraining_tp > 1:
380 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
381 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
382 | logits = torch.cat(logits, dim=-1)
383 | else:
384 | # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
385 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
386 |
387 | loss = None
388 |
389 | if not return_dict:
390 | output = (logits,) + outputs[1:]
391 | return (loss,) + output if loss is not None else output
392 |
393 | return VoilaOutput(
394 | loss=loss,
395 | logits=logits,
396 | last_hidden_state=hidden_states,
397 | past_key_values=outputs.past_key_values,
398 | hidden_states=outputs.hidden_states,
399 | attentions=outputs.attentions,
400 | )
401 |
402 | def _prepare_inputs_for_generation(
403 | self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
404 | ):
405 | if past_key_values is not None and past_key_values.get_seq_length() > 0:
406 | if isinstance(past_key_values, Cache):
407 | cache_length = past_key_values.get_seq_length()
408 | past_length = past_key_values.seen_tokens
409 | max_cache_length = past_key_values.get_max_cache_shape()
410 | else:
411 | cache_length = past_length = past_key_values[0][0].shape[2]
412 | max_cache_length = None
413 |
414 | # Keep only the unprocessed tokens:
415 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
416 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
417 | # input)
418 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
419 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
420 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
421 | # input_ids based on the past_length.
422 | elif past_length < input_ids.shape[1]:
423 | input_ids = input_ids[:, past_length:]
424 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
425 |
426 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
427 | if (
428 | max_cache_length is not None
429 | and attention_mask is not None
430 | and cache_length + input_ids.shape[1] > max_cache_length
431 | ):
432 | attention_mask = attention_mask[:, -max_cache_length:]
433 |
434 | position_ids = kwargs.get("position_ids", None)
435 | if attention_mask is not None and position_ids is None:
436 | # create position_ids on the fly for batch generation
437 | position_ids = attention_mask.long().cumsum(-1) - 1
438 | position_ids.masked_fill_(attention_mask == 0, 1)
439 | if past_key_values:
440 | position_ids = position_ids[:, -input_ids.shape[1] :]
441 |
442 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
443 | if inputs_embeds is None and \
444 | (past_key_values is None or past_key_values.get_seq_length() <= 0):
445 | inputs_embeds = self.model.embed_tokens(input_ids)
446 | if inputs_embeds is not None and \
447 | (past_key_values is None or past_key_values.get_seq_length() <= 0):
448 | model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask}
449 | else:
450 | model_inputs = {"input_ids": input_ids, "ref_embs": None}
451 |
452 | model_inputs.update(
453 | {
454 | "position_ids": position_ids,
455 | "past_key_values": past_key_values,
456 | "use_cache": kwargs.get("use_cache"),
457 | "attention_mask": attention_mask,
458 | }
459 | )
460 | return model_inputs
461 |
462 | def _update_model_kwargs_for_generation(
463 | self,
464 | outputs,
465 | model_kwargs: Dict[str, Any],
466 | num_new_token: int = 1,
467 | ) -> Dict[str, Any]:
468 | # update past_key_values
469 | model_kwargs["past_key_values"] = outputs.past_key_values
470 |
471 | # update attention mask
472 | if "attention_mask" in model_kwargs:
473 | attention_mask = model_kwargs["attention_mask"]
474 | model_kwargs["attention_mask"] = torch.cat(
475 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1
476 | )
477 |
478 | return model_kwargs
479 |
480 | def _prepare_attention_mask_for_generation(
481 | self,
482 | inputs: torch.Tensor,
483 | pad_token_id: Optional[int],
484 | eos_token_id: Optional[Union[int, List[int]]],
485 | ) -> torch.LongTensor:
486 | is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
487 | is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
488 | if isinstance(eos_token_id, int):
489 | eos_token_id = [eos_token_id]
490 | is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
491 |
492 | # Check if input is input_ids and padded -> only then is attention_mask defined
493 | if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
494 | return inputs.ne(pad_token_id).long()
495 | else:
496 | return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
497 |
498 | @torch.inference_mode()
499 | def run_generate(
500 | self,
501 | input_ids: torch.LongTensor,
502 | ref_embs: Optional[List[torch.Tensor]] = None,
503 | ref_embs_mask: Optional[torch.LongTensor] = None,
504 | max_new_tokens: Optional[int] = 128,
505 | pad_token_id: Optional[int] = None,
506 | eos_token_id: Optional[Union[int, List[int]]] = None,
507 | streamer: Optional["BaseStreamer"] = None,
508 | llm_audio_token_id: Optional[int] = None,
509 | min_audio_token_id: Optional[int] = None,
510 | temperature=0.2,
511 | top_k=50,
512 | audio_temperature=0.2,
513 | audio_top_k=50,
514 | ):
515 | assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
516 | assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
517 | assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"
518 |
519 | eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device)
520 |
521 | # keep track of which sequences are already finished
522 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
523 |
524 | # Extend input_ids with additional num_codebooks dim
525 | if len(input_ids.shape) == 2:
526 |             input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)
527 |
528 | this_peer_finished = False # used by synced_gpus only
529 | max_length = input_ids.shape[1] + max_new_tokens
530 |
531 | model_kwargs = {
532 | "use_cache": True,
533 | "past_key_values": DynamicCache(),
534 | "attention_mask": self._prepare_attention_mask_for_generation(
535 | input_ids, pad_token_id, eos_token_id
536 | ),
537 | }
538 | # auto-regressive generation
539 | while True:
540 | # prepare model inputs
541 | model_inputs = self._prepare_inputs_for_generation(
542 | input_ids,
543 | ref_embs=ref_embs,
544 | ref_embs_mask=ref_embs_mask,
545 | **model_kwargs
546 | )
547 |
548 | # forward pass to get next token
549 | outputs = self(
550 | **model_inputs,
551 | return_dict=True,
552 | )
553 | audio_tokens = self.audio_transformer.inference(
554 | outputs.last_hidden_state,
555 | temperature=audio_temperature,
556 | top_k=audio_top_k,
557 | )
558 | audio_tokens = torch.stack(
559 | [
560 |                 audio_tokens[:, :, ci] + min_audio_token_id + ci*self.config.codebook_size  # shift codebook-local ids into the LLM vocab range
561 | for ci in range(self.config.num_codebooks)
562 | ],
563 | dim=2,
564 | )
565 |
566 | next_token_logits = outputs.logits[:, -1, :]
567 |
568 | # pre-process distribution
569 | # Apply temperature and top-k
570 | if temperature > 0:
571 | next_token_logits = next_token_logits / temperature
572 | if top_k > 0:
573 | top_k = min(top_k, next_token_logits.size(-1)) # Safety check
574 | # Remove all tokens with a probability less than the last token of the top-k
575 | indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
576 | next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))
577 |
578 | # sample
579 | probs = nn.functional.softmax(next_token_logits, dim=-1)
580 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
581 |
582 | # finished sentences should have their next token be a padding token
583 | if eos_token_id is not None:
584 | if pad_token_id is None:
585 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
586 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
587 |
588 | # Append NUM_CODEBOOK text tokens or audio_tokens
589 | if len(next_tokens.shape) == 1:
590 | next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
591 | next_tokens = torch.where(next_tokens==llm_audio_token_id, audio_tokens, next_tokens)
592 |
593 | input_ids = torch.cat([input_ids, next_tokens], dim=1)
594 | if streamer is not None:
595 | streamer.put(next_tokens.cpu())
596 | model_kwargs = self._update_model_kwargs_for_generation(
597 | outputs, model_kwargs
598 | )
599 |
600 | # if eos_token was found in one sentence, set sentence to finished
601 | if eos_token_id_tensor is not None:
602 | unfinished_sequences = unfinished_sequences.mul(
603 | next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1)
604 | )
605 |
606 | # stop when each sentence is finished
607 | if unfinished_sequences.max() == 0:
608 | this_peer_finished = True
609 |
610 | # stop if we exceed the maximum length
611 | if input_ids.shape[1] >= max_length:
612 | this_peer_finished = True
613 |
614 | if this_peer_finished:
615 | break
616 |
617 | if streamer is not None:
618 | streamer.end()
619 |
620 | return input_ids
621 |
622 |
623 | # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
624 | class VoilaAudioAlphaModel(LlamaPreTrainedModel):
625 | _tied_weights_keys = ["lm_head.weight"]
626 |
627 | def __init__(self, config):
628 | super().__init__(config)
629 | self.model = LlamaModel(config)
630 | self.vocab_size = config.vocab_size
631 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
632 | self.pad_vocab_size_multiple = 64
633 |
634 |
635 | self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
636 | self.audio_transformer = AudioTransformer(config, use_sdpa=False)
637 |
638 | self.feature_extractor = FeatureExtractor()
639 | self.audio_feature_extractor_adapter = FeatureExtractorAdapter(hidden_size=config.hidden_size)
640 |
641 | # Initialize weights and apply final processing
642 | self.post_init()
643 |
644 | def get_input_embeddings(self):
645 | return self.model.embed_tokens
646 |
647 | def set_input_embeddings(self, value):
648 | self.model.embed_tokens = value
649 |
650 | def get_output_embeddings(self):
651 | return self.lm_head
652 |
653 | def set_output_embeddings(self, new_embeddings):
654 | self.lm_head = new_embeddings
655 |
656 | def set_decoder(self, decoder):
657 | self.model = decoder
658 |
659 | def get_decoder(self):
660 | return self.model
661 |
662 | def forward(
663 | self,
664 | input_ids: torch.LongTensor = None,
665 | attention_mask: Optional[torch.Tensor] = None,
666 | position_ids: Optional[torch.LongTensor] = None,
667 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
668 | inputs_embeds: Optional[torch.FloatTensor] = None,
669 | labels: Optional[torch.LongTensor] = None,
670 | audio_labels: Optional[torch.LongTensor] = None,
671 | ref_embs: Optional[List[torch.Tensor]] = None,
672 | ref_embs_mask: Optional[torch.LongTensor] = None,
673 | audio_datas: Optional[torch.FloatTensor] = None,
674 | audio_data_masks: Optional[torch.LongTensor] = None,
675 | use_cache: Optional[bool] = None,
676 | output_attentions: Optional[bool] = None,
677 | output_hidden_states: Optional[bool] = None,
678 | return_dict: Optional[bool] = None,
679 | cache_position: Optional[torch.LongTensor] = None,
680 | num_logits_to_keep: int = 0,
681 | ) -> Union[Tuple, VoilaOutput]:
682 | r"""
683 | Args:
684 | input_ids: [bs, seq_len, num_codebooks]
685 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
686 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
687 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
688 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
689 | """
690 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
691 | output_hidden_states = (
692 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
693 | )
694 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
695 |
696 | if input_ids is not None and inputs_embeds is not None:
697 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
698 | if inputs_embeds is None:
699 | inputs_embeds = self.model.embed_tokens(input_ids)
700 | assert len(inputs_embeds.shape) == 4
701 | if len(inputs_embeds.shape) == 4:
702 | inputs_embeds = inputs_embeds.mean(dim=2)
703 |
704 | if self.training or \
705 | (past_key_values is None and ref_embs is not None) or \
706 | (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None):
707 | ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
708 | ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
709 | # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back)
710 | padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
711 | ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0)
712 | inputs_embeds = inputs_embeds + ref_embs
713 |
714 | if self.training or audio_datas is not None:
715 | audio_embeds = self.feature_extractor(audio_datas)
716 | audio_embeds = self.audio_feature_extractor_adapter(audio_embeds)
717 | audio_embeds = audio_embeds * audio_data_masks[..., None]
718 | inputs_embeds = inputs_embeds + audio_embeds
719 |
720 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
721 | outputs = self.model(
722 | attention_mask=attention_mask,
723 | position_ids=position_ids,
724 | past_key_values=past_key_values,
725 | inputs_embeds=inputs_embeds,
726 | use_cache=use_cache,
727 | output_attentions=output_attentions,
728 | output_hidden_states=output_hidden_states,
729 | return_dict=return_dict,
730 | cache_position=cache_position,
731 | )
732 |
733 | hidden_states = outputs[0]
734 | if self.config.pretraining_tp > 1:
735 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
736 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
737 | logits = torch.cat(logits, dim=-1)
738 | else:
739 | # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
740 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
741 |
742 | loss = None
743 | if labels is not None:
744 | # Upcast to float if we need to compute the loss to avoid potential precision issues
745 | logits = logits.float()
746 | # We shift tokens and labels in dataloader
747 | shift_logits = logits.contiguous()
748 | shift_labels = labels.contiguous()
749 | # Flatten the tokens
750 | loss_fct = CrossEntropyLoss()
751 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
752 | shift_labels = shift_labels.view(-1)
753 | # Enable model parallelism
754 | shift_labels = shift_labels.to(shift_logits.device)
755 | loss = loss_fct(shift_logits, shift_labels)
756 |
757 | if audio_labels is not None:
758 | au_mask = (audio_labels >= 0).all(dim=-1)
759 | au_hidden_states = hidden_states[au_mask]
760 | au_audio_labels = audio_labels[au_mask]
761 | if len(au_hidden_states) <= 0:
762 | au_hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
763 | au_audio_labels = torch.zeros_like(audio_labels).reshape(-1, self.config.num_codebooks)
764 | loss_weight = 0.0
765 | else:
766 | loss_weight = 1.0
767 | au_logits = self.audio_transformer(au_hidden_states, au_audio_labels)
768 | # We shift tokens and labels in dataloader
769 | shift_au_logits = au_logits.contiguous()
770 | shift_audio_labels = au_audio_labels.contiguous()
771 | # Flatten the tokens
772 | loss_fct = CrossEntropyLoss()
773 | shift_au_logits = shift_au_logits.view(-1, self.config.codebook_size)
774 | shift_audio_labels = shift_audio_labels.view(-1)
775 | # Enable model parallelism
776 | shift_audio_labels = shift_audio_labels.to(shift_au_logits.device)
777 | au_loss = loss_fct(shift_au_logits, shift_audio_labels)
778 |
779 | loss += au_loss * loss_weight
780 | else:
781 | # au_tokens = self.audio_transformer.inference(hidden_states)
782 | pass
783 |
784 | if not return_dict:
785 | output = (logits,) + outputs[1:]
786 | return (loss,) + output if loss is not None else output
787 |
788 | return VoilaOutput(
789 | loss=loss,
790 | logits=logits,
791 | last_hidden_state=hidden_states,
792 | past_key_values=outputs.past_key_values,
793 | hidden_states=outputs.hidden_states,
794 | attentions=outputs.attentions,
795 | )
796 |
797 | def _prepare_inputs_for_generation(
798 | self, input_ids, ref_embs=None, ref_embs_mask=None, audio_datas=None, audio_data_masks=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
799 | ):
800 | if past_key_values is not None and past_key_values.get_seq_length() > 0:
801 | if isinstance(past_key_values, Cache):
802 | cache_length = past_key_values.get_seq_length()
803 | past_length = past_key_values.seen_tokens
804 | max_cache_length = past_key_values.get_max_cache_shape()
805 | else:
806 | cache_length = past_length = past_key_values[0][0].shape[2]
807 | max_cache_length = None
808 |
809 | # Keep only the unprocessed tokens:
810 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
811 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
812 | # input)
813 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
814 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
815 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
816 | # input_ids based on the past_length.
817 | elif past_length < input_ids.shape[1]:
818 | input_ids = input_ids[:, past_length:]
819 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
820 |
821 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
822 | if (
823 | max_cache_length is not None
824 | and attention_mask is not None
825 | and cache_length + input_ids.shape[1] > max_cache_length
826 | ):
827 | attention_mask = attention_mask[:, -max_cache_length:]
828 |
829 | position_ids = kwargs.get("position_ids", None)
830 | if attention_mask is not None and position_ids is None:
831 | # create position_ids on the fly for batch generation
832 | position_ids = attention_mask.long().cumsum(-1) - 1
833 | position_ids.masked_fill_(attention_mask == 0, 1)
834 | if past_key_values:
835 | position_ids = position_ids[:, -input_ids.shape[1] :]
836 |
837 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
838 | if inputs_embeds is None and \
839 | (past_key_values is None or past_key_values.get_seq_length() <= 0):
840 | inputs_embeds = self.model.embed_tokens(input_ids)
841 | if inputs_embeds is not None and \
842 | (past_key_values is None or past_key_values.get_seq_length() <= 0):
843 | model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask, "audio_datas": audio_datas, "audio_data_masks": audio_data_masks}
844 | else:
845 | model_inputs = {"input_ids": input_ids, "ref_embs": None, "audio_datas": None, "audio_data_masks": None}
846 |
847 | model_inputs.update(
848 | {
849 | "position_ids": position_ids,
850 | "past_key_values": past_key_values,
851 | "use_cache": kwargs.get("use_cache"),
852 | "attention_mask": attention_mask,
853 | }
854 | )
855 | return model_inputs
856 |
857 | def _update_model_kwargs_for_generation(
858 | self,
859 | outputs,
860 | model_kwargs: Dict[str, Any],
861 | num_new_token: int = 1,
862 | ) -> Dict[str, Any]:
863 | # update past_key_values
864 | model_kwargs["past_key_values"] = outputs.past_key_values
865 |
866 | # update attention mask
867 | if "attention_mask" in model_kwargs:
868 | attention_mask = model_kwargs["attention_mask"]
869 | model_kwargs["attention_mask"] = torch.cat(
870 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1
871 | )
872 |
873 | return model_kwargs
874 |
875 | def _prepare_attention_mask_for_generation(
876 | self,
877 | inputs: torch.Tensor,
878 | pad_token_id: Optional[int],
879 | eos_token_id: Optional[Union[int, List[int]]],
880 | ) -> torch.LongTensor:
881 | is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
882 | is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
883 | if isinstance(eos_token_id, int):
884 | eos_token_id = [eos_token_id]
885 | is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
886 |
887 | # Check if input is input_ids and padded -> only then is attention_mask defined
888 | if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
889 | return inputs.ne(pad_token_id).long()
890 | else:
891 | return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
892 |
893 | @torch.inference_mode()
894 | def run_generate(
895 | self,
896 | input_ids: torch.LongTensor,
897 | ref_embs: Optional[List[torch.Tensor]] = None,
898 | ref_embs_mask: Optional[torch.LongTensor] = None,
899 | audio_datas: Optional[torch.FloatTensor] = None,
900 | audio_data_masks: Optional[torch.LongTensor] = None,
901 | max_new_tokens: Optional[int] = 128,
902 | pad_token_id: Optional[int] = None,
903 | eos_token_id: Optional[Union[int, List[int]]] = None,
904 | streamer: Optional["BaseStreamer"] = None,
905 | llm_audio_token_id: Optional[int] = None,
906 | min_audio_token_id: Optional[int] = None,
907 | temperature=0.2,
908 | top_k=50,
909 | audio_temperature=0.2,
910 | audio_top_k=50,
911 | ):
912 | assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
913 | assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
914 | assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"
915 |
916 | eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device)
917 |
918 | # keep track of which sequences are already finished
919 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
920 |
921 | # Extend input_ids with additional num_codebooks dim
922 | if len(input_ids.shape) == 2:
923 |             input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)
924 |
925 | this_peer_finished = False # used by synced_gpus only
926 | max_length = input_ids.shape[1] + max_new_tokens
927 |
928 | model_kwargs = {
929 | "use_cache": True,
930 | "past_key_values": DynamicCache(),
931 | "attention_mask": self._prepare_attention_mask_for_generation(
932 | input_ids, pad_token_id, eos_token_id
933 | ),
934 | }
935 | # auto-regressive generation
936 | while True:
937 | # prepare model inputs
938 | model_inputs = self._prepare_inputs_for_generation(
939 | input_ids,
940 | ref_embs=ref_embs,
941 | ref_embs_mask=ref_embs_mask,
942 | audio_datas=audio_datas,
943 | audio_data_masks=audio_data_masks,
944 | **model_kwargs
945 | )
946 |
947 | # forward pass to get next token
948 | outputs = self(
949 | **model_inputs,
950 | return_dict=True,
951 | )
952 | audio_tokens = self.audio_transformer.inference(
953 | outputs.last_hidden_state,
954 | temperature=audio_temperature,
955 | top_k=audio_top_k,
956 | )
957 | audio_tokens = torch.stack(
958 | [
959 | audio_tokens[:, :, ci] + min_audio_token_id + ci*self.config.codebook_size
960 | for ci in range(self.config.num_codebooks)
961 | ],
962 | dim=2,
963 | )
964 |
965 | next_token_logits = outputs.logits[:, -1, :]
966 |
967 | # pre-process distribution
968 | # Apply temperature and top-k
969 | if temperature > 0:
970 | next_token_logits = next_token_logits / temperature
971 | if top_k > 0:
972 | top_k = min(top_k, next_token_logits.size(-1)) # Safety check
973 | # Remove all tokens with a probability less than the last token of the top-k
974 | indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
975 | next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))
976 |
977 | # sample
978 | probs = nn.functional.softmax(next_token_logits, dim=-1)
979 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
980 |
981 | # finished sentences should have their next token be a padding token
982 | if eos_token_id is not None:
983 | if pad_token_id is None:
984 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
985 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
986 |
987 | # Append NUM_CODEBOOK text tokens or audio_tokens
988 | if len(next_tokens.shape) == 1:
989 | next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
990 | next_tokens = torch.where(next_tokens==llm_audio_token_id, audio_tokens, next_tokens)
991 |
992 | input_ids = torch.cat([input_ids, next_tokens], dim=1)
993 | if streamer is not None:
994 | streamer.put(next_tokens.cpu())
995 | model_kwargs = self._update_model_kwargs_for_generation(
996 | outputs, model_kwargs
997 | )
998 |
999 | # if eos_token was found in one sentence, set sentence to finished
1000 | if eos_token_id_tensor is not None:
1001 | unfinished_sequences = unfinished_sequences.mul(
1002 | next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1)
1003 | )
1004 |
1005 | # stop when each sentence is finished
1006 | if unfinished_sequences.max() == 0:
1007 | this_peer_finished = True
1008 |
1009 | # stop if we exceed the maximum length
1010 | if input_ids.shape[1] >= max_length:
1011 | this_peer_finished = True
1012 |
1013 | if this_peer_finished:
1014 | break
1015 |
1016 | if streamer is not None:
1017 | streamer.end()
1018 |
1019 | return input_ids
1020 |
1021 |
1022 | # Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1103
1023 | class VoilaAutonomousModel(LlamaPreTrainedModel):
1024 | _tied_weights_keys = ["lm_head.weight"]
1025 |
1026 | def __init__(self, config):
1027 | super().__init__(config)
1028 | self.model = LlamaModel(config)
1029 | self.vocab_size = config.vocab_size
1030 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1031 | self.pad_vocab_size_multiple = 64
1032 |
1033 | self.ref_emb_linear = nn.Linear(256, config.hidden_size, bias=True)
1034 | self.audio_transformer = AudioTransformer(config, use_sdpa=False)
1035 | self.voila_predictor = nn.Sequential(nn.Linear(config.hidden_size, 2, bias=True),)
1036 |
1037 | # Initialize weights and apply final processing
1038 | self.post_init()
1039 |
1040 | def get_input_embeddings(self):
1041 | return self.model.embed_tokens
1042 |
1043 | def set_input_embeddings(self, value):
1044 | self.model.embed_tokens = value
1045 |
1046 | def get_output_embeddings(self):
1047 | return self.lm_head
1048 |
1049 | def set_output_embeddings(self, new_embeddings):
1050 | self.lm_head = new_embeddings
1051 |
1052 | def set_decoder(self, decoder):
1053 | self.model = decoder
1054 |
1055 | def get_decoder(self):
1056 | return self.model
1057 |
1058 | def forward(
1059 | self,
1060 | input_ids: torch.LongTensor = None,
1061 | attention_mask: Optional[torch.Tensor] = None,
1062 | position_ids: Optional[torch.LongTensor] = None,
1063 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1064 | inputs_embeds: Optional[torch.FloatTensor] = None,
1065 | labels: Optional[torch.LongTensor] = None,
1066 | audio_labels: Optional[torch.LongTensor] = None,
1067 | voila_labels: Optional[torch.LongTensor] = None,
1068 | ref_embs: Optional[List[torch.Tensor]] = None,
1069 | ref_embs_mask: Optional[torch.LongTensor] = None,
1070 | use_cache: Optional[bool] = None,
1071 | output_attentions: Optional[bool] = None,
1072 | output_hidden_states: Optional[bool] = None,
1073 | return_dict: Optional[bool] = None,
1074 | cache_position: Optional[torch.LongTensor] = None,
1075 | num_logits_to_keep: int = 0,
1076 | ) -> Union[Tuple, VoilaOutput]:
1077 | r"""
1078 | Args:
1079 | input_ids: [bs, seq_len, num_codebooks]
1080 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1081 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1082 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1083 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1084 | """
1085 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1086 | output_hidden_states = (
1087 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1088 | )
1089 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1090 |
1091 | if input_ids is not None and inputs_embeds is not None:
1092 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1093 | if inputs_embeds is None:
1094 | inputs_embeds = self.model.embed_tokens(input_ids)
1095 | assert len(inputs_embeds.shape) == 4
1096 | if len(inputs_embeds.shape) == 4:
1097 | inputs_embeds = inputs_embeds.mean(dim=2)
1098 |
1099 | if self.training or \
1100 | (past_key_values is None and ref_embs is not None) or \
1101 | (past_key_values is not None and past_key_values.get_seq_length() < 4 and ref_embs is not None):
1102 | ref_embs = self.ref_emb_linear(ref_embs.to(self.ref_emb_linear.weight.dtype))
1103 | ref_embs = ref_embs * ref_embs_mask.unsqueeze(-1).unsqueeze(-1)
1104 | # (padding_left,padding_right,padding_top,padding_bottom,padding_front,padding_back)
1105 | padding = (0, 0, 4, inputs_embeds.shape[1] - 5, 0, 0)
1106 | ref_embs = torch.nn.functional.pad(ref_embs, padding, mode='constant', value=0.0)
1107 | inputs_embeds = inputs_embeds + ref_embs
1108 |
1109 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1110 | outputs = self.model(
1111 | attention_mask=attention_mask,
1112 | position_ids=position_ids,
1113 | past_key_values=past_key_values,
1114 | inputs_embeds=inputs_embeds,
1115 | use_cache=use_cache,
1116 | output_attentions=output_attentions,
1117 | output_hidden_states=output_hidden_states,
1118 | return_dict=return_dict,
1119 | cache_position=cache_position,
1120 | )
1121 |
1122 | hidden_states = outputs[0]
1123 | if self.config.pretraining_tp > 1:
1124 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1125 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1126 | logits = torch.cat(logits, dim=-1)
1127 | else:
1128 | # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1129 | logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1130 |
1131 |         # Voila head prediction (used at inference for turn-taking; no loss is computed here)
1132 | voila_pred = self.voila_predictor(hidden_states)
1133 | voila_pred = voila_pred.float()
1134 |
1135 |         loss = None  # the label arguments are accepted but no training loss is computed in this forward
1136 |
1137 | if not return_dict:
1138 | output = (logits,) + outputs[1:]
1139 | return (loss,) + output if loss is not None else output
1140 |
1141 | return VoilaOutput(
1142 | loss=loss,
1143 | logits=logits,
1144 | last_hidden_state=hidden_states,
1145 | past_key_values=outputs.past_key_values,
1146 | hidden_states=outputs.hidden_states,
1147 | attentions=outputs.attentions,
1148 | voila_pred=voila_pred,
1149 | )
1150 |
1151 | def _prepare_inputs_for_generation(
1152 | self, input_ids, ref_embs=None, ref_embs_mask=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1153 | ):
1154 | if past_key_values is not None and past_key_values.get_seq_length() > 0:
1155 | if isinstance(past_key_values, Cache):
1156 | cache_length = past_key_values.get_seq_length()
1157 | past_length = past_key_values.seen_tokens
1158 | max_cache_length = past_key_values.get_max_cache_shape()
1159 | else:
1160 | cache_length = past_length = past_key_values[0][0].shape[2]
1161 | max_cache_length = None
1162 |
1163 | # Keep only the unprocessed tokens:
1164 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1165 |             # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
1166 | # input)
1167 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1168 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1169 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1170 | # input_ids based on the past_length.
1171 | elif past_length < input_ids.shape[1]:
1172 | input_ids = input_ids[:, past_length:]
1173 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1174 |
1175 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1176 | if (
1177 | max_cache_length is not None
1178 | and attention_mask is not None
1179 | and cache_length + input_ids.shape[1] > max_cache_length
1180 | ):
1181 | attention_mask = attention_mask[:, -max_cache_length:]
1182 |
1183 | position_ids = kwargs.get("position_ids", None)
1184 | if attention_mask is not None and position_ids is None:
1185 | # create position_ids on the fly for batch generation
1186 | position_ids = attention_mask.long().cumsum(-1) - 1
1187 | position_ids.masked_fill_(attention_mask == 0, 1)
1188 | if past_key_values:
1189 | position_ids = position_ids[:, -input_ids.shape[1] :]
1190 |
1191 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
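        # ref_embs are likewise only consumed on the first step; later steps pass ref_embs=None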
1192 | if inputs_embeds is None and \
1193 | (past_key_values is None or past_key_values.get_seq_length() <= 0):
1194 | inputs_embeds = self.model.embed_tokens(input_ids)
1195 | if inputs_embeds is not None and \
1196 | (past_key_values is None or past_key_values.get_seq_length() <= 0):
1197 | model_inputs = {"inputs_embeds": inputs_embeds, "ref_embs": ref_embs, "ref_embs_mask": ref_embs_mask}
1198 | else:
1199 | model_inputs = {"input_ids": input_ids, "ref_embs": None}
1200 |
1201 | model_inputs.update(
1202 | {
1203 | "position_ids": position_ids,
1204 | "past_key_values": past_key_values,
1205 | "use_cache": kwargs.get("use_cache"),
1206 | "attention_mask": attention_mask,
1207 | }
1208 | )
1209 | return model_inputs
1210 |
1211 | def _update_model_kwargs_for_generation(
1212 | self,
1213 | outputs,
1214 | model_kwargs: Dict[str, Any],
1215 | num_new_token: int = 1,
1216 | ) -> Dict[str, Any]:
1217 | # update past_key_values
1218 | model_kwargs["past_key_values"] = outputs.past_key_values
1219 |
1220 | # update attention mask
1221 | if "attention_mask" in model_kwargs:
1222 | attention_mask = model_kwargs["attention_mask"]
1223 | model_kwargs["attention_mask"] = torch.cat(
1224 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_token))], dim=-1
1225 | )
1226 |
1227 | return model_kwargs
1228 |
1229 | def _prepare_attention_mask_for_generation(
1230 | self,
1231 | inputs: torch.Tensor,
1232 | pad_token_id: Optional[int],
1233 | eos_token_id: Optional[Union[int, List[int]]],
1234 | ) -> torch.LongTensor:
1235 | is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
1236 | is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
1237 | if isinstance(eos_token_id, int):
1238 | eos_token_id = [eos_token_id]
1239 | is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id not in eos_token_id)
1240 |
1241 | # Check if input is input_ids and padded -> only then is attention_mask defined
1242 | if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
1243 | return inputs.ne(pad_token_id).long()
1244 | else:
1245 | return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
1246 |
1247 | @torch.inference_mode()
1248 | def run_generate(
1249 | self,
1250 | input_ids: torch.LongTensor,
1251 | input_generator,
1252 | ref_embs: Optional[List[torch.Tensor]] = None,
1253 | ref_embs_mask: Optional[torch.LongTensor] = None,
1254 | max_new_tokens: Optional[int] = 128,
1255 | pad_token_id: Optional[int] = None,
1256 | eos_token_id: Optional[Union[int, List[int]]] = None,
1257 | streamer: Optional["BaseStreamer"] = None,
1258 | llm_audio_token_id: Optional[int] = None,
1259 | min_audio_token_id: Optional[int] = None,
1260 | llm_assistant_token_id: Optional[int] = None,
1261 | temperature=0.2,
1262 | top_k=50,
1263 | audio_temperature=0.8,
1264 | audio_top_k=50,
1265 | ):
1266 | assert eos_token_id is not None and pad_token_id is not None, "eos_token_id and pad_token_id are required for inference"
1267 | assert llm_audio_token_id is not None and min_audio_token_id is not None, "llm_audio_token_id and min_audio_token_id are required for inference"
1268 | assert len(input_ids.shape) == 2 or len(input_ids.shape) == 3, f"input_ids is supposed to be [batch, seq_len] or [batch, seq_len, num_codebooks], and got {input_ids.shape}"
1269 |
1270 | eos_token_id_tensor = torch.tensor([eos_token_id]).to(input_ids.device)
1271 |
1272 | # keep track of which sequences are already finished
1273 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1274 |
1275 | # Extend input_ids with additional num_codebooks dim
1276 | input_ids = input_ids.clone()
1277 | if len(input_ids.shape) == 2:
1278 |             input_ids = input_ids[:, :, None].expand(-1, -1, self.config.num_codebooks)
1279 |
1280 |         this_peer_finished = False  # set when generation should stop (input stream exhausted or max length reached)
1281 | max_length = input_ids.shape[1] + max_new_tokens
1282 |
1283 | model_kwargs = {
1284 | "use_cache": True,
1285 | "past_key_values": DynamicCache(),
1286 | "attention_mask": self._prepare_attention_mask_for_generation(
1287 | input_ids, pad_token_id, eos_token_id
1288 | ),
1289 | }
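        # Whether the model is currently producing an assistant (speaking) turn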
1290 | speaking = False
1291 | # auto-regressive generation
1292 | while True:
1293 | # prepare model inputs
1294 | model_inputs = self._prepare_inputs_for_generation(
1295 | input_ids,
1296 | ref_embs=ref_embs,
1297 | ref_embs_mask=ref_embs_mask,
1298 | **model_kwargs
1299 | )
1300 |
1301 | # forward pass to get next token
1302 | outputs = self(
1303 | **model_inputs,
1304 | return_dict=True,
1305 | )
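            # Decode audio tokens (one id per codebook) from the last hidden state, then shift each
            # codebook's local ids into its own slice of the LLM vocabulary.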
1306 | audio_tokens = self.audio_transformer.inference(
1307 | outputs.last_hidden_state,
1308 | temperature=audio_temperature,
1309 | top_k=audio_top_k,
1310 | )
1311 | audio_tokens = torch.stack(
1312 | [
1313 | audio_tokens[:, :, ci] + min_audio_token_id + ci*self.config.codebook_size
1314 | for ci in range(self.config.num_codebooks)
1315 | ],
1316 | dim=2,
1317 | )
1318 |
1319 | next_token_logits = outputs.logits[:, -1, :]
1320 |
1321 | # voila head output
1322 | voila_head_pred = outputs.voila_pred[:, -1, :]
1323 | voila_head_pred = torch.argmax(voila_head_pred, dim=-1)
1324 | voila_head_pred = voila_head_pred.cpu()[0].item()
1325 |
1326 | # pre-process distribution
1327 | # Apply temperature and top-k
1328 | if temperature > 0:
1329 | next_token_logits = next_token_logits / temperature
1330 | if top_k > 0:
1331 | top_k = min(top_k, next_token_logits.size(-1)) # Safety check
1332 | # Remove all tokens with a probability less than the last token of the top-k
1333 | indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
1334 | next_token_logits = next_token_logits.masked_fill(indices_to_remove, -float("Inf"))
1335 |
1336 | # sample
1337 | probs = nn.functional.softmax(next_token_logits, dim=-1)
1338 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
1339 |
1340 | # voila head pred == 1, use assistant token
1341 | if voila_head_pred == 1 and not speaking:
1342 | next_tokens[0] = llm_assistant_token_id
1343 | speaking = True
1344 | elif next_tokens[0] == eos_token_id:
1345 | speaking = False
1346 |
1347 | # finished sentences should have their next token be a padding token
1348 | if eos_token_id is not None:
1349 | if pad_token_id is None:
1350 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
1351 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
1352 |
1353 | # Append NUM_CODEBOOK text tokens or audio_tokens
1354 | if len(next_tokens.shape) == 1:
1355 | next_tokens = next_tokens[:, None, None].expand(-1, 1, self.config.num_codebooks)
1356 | audio_token_mask = next_tokens == llm_audio_token_id
1357 | next_tokens = next_tokens * torch.logical_not(audio_token_mask) + audio_tokens * audio_token_mask
1358 |
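            # If the model emitted an audio token, pull the next chunk of input-side tokens from
            # input_generator (generation stops when the stream is exhausted) and concatenate them
            # with the generated tokens along the codebook dimension before appending to input_ids.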
1359 | if audio_token_mask[0, 0, 0].item():
1360 | try:
1361 | new_input_tokens = next(input_generator)
1362 |                 except StopIteration:  # input stream exhausted
1363 | this_peer_finished = True
1364 | break
1365 | new_input_tokens = new_input_tokens[None,None,:]
1366 | else:
1367 | new_input_tokens = next_tokens
1368 | new_input_tokens = torch.cat([new_input_tokens, next_tokens], dim=2)
1369 |
1370 | input_ids = torch.cat([input_ids, new_input_tokens], dim=1)
1371 | if streamer is not None:
1372 | streamer.put(next_tokens.cpu())
1373 | model_kwargs = self._update_model_kwargs_for_generation(
1374 | outputs, model_kwargs
1375 | )
1376 |
1377 | # # if eos_token was found in one sentence, set sentence to finished
1378 | # if eos_token_id_tensor is not None:
1379 | # unfinished_sequences = unfinished_sequences.mul(
1380 | # next_tokens[:, :, 0].ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=1)
1381 | # )
1382 |
1383 | # # stop when each sentence is finished
1384 | # if unfinished_sequences.max() == 0:
1385 | # this_peer_finished = True
1386 |
1387 | # stop if we exceed the maximum length
1388 | if input_ids.shape[1] >= max_length:
1389 | this_peer_finished = True
1390 |
1391 | if this_peer_finished:
1392 | break
1393 |
1394 | if streamer is not None:
1395 | streamer.end()
1396 |
1397 | return input_ids
1398 |
--------------------------------------------------------------------------------