├── .gitignore ├── README.md ├── example_prompt.flac ├── pyproject.toml ├── scripts ├── eval_continue.py ├── eval_reference.py ├── eval_stream.py ├── run_offline.py ├── run_stream.py ├── run_voice_clone.py ├── train_libriheavy.py └── train_libriheavy_stream.py ├── shell_scripts ├── eval.sh ├── run.sh ├── train_libriheavy.sh └── train_libriheavy_stream.sh ├── sled.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── requires.txt └── top_level.txt ├── sled ├── __init__.py ├── energy_distance.py ├── modeling_llama_with_dropout.py ├── sled.py ├── sled_stream.py ├── trainer.py └── trainer_libriheavy.py └── tokenizer_bpe_libriheavy ├── merges.txt └── vocab.json /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | /runs/ 3 | *.sh 4 | *.wav 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🛷SLED-TTS: Efficient Speech Language Modeling via Energy Distance in Continuous Latent Space 2 | > **Authors: [Zhengrui Ma](https://scholar.google.com/citations?user=dUgq6tEAAAAJ), [Yang Feng*](https://people.ucas.edu.cn/~yangfeng?language=en), [Chenze Shao](https://scholar.google.com/citations?user=LH_rZf8AAAAJ&hl), [Fandong Meng](https://fandongmeng.github.io/), [Jie Zhou](https://scholar.google.com.hk/citations?user=OijxQCMAAAAJ&hl=en), [Min Zhang](https://scholar.google.com/citations?user=CncXH-YAAAAJ&hl=en)** 3 | 4 | [![arXiv](https://img.shields.io/badge/arXiv-2505.13181-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2505.13181) 5 | [![code](https://img.shields.io/badge/Github-Code-keygen.svg?logo=github)](https://github.com/ictnlp/SLED-TTS) 6 | [![HuggingFace](https://img.shields.io/badge/HuggingFace-FEC200?style=flat&logo=Hugging%20Face)](https://huggingface.co/collections/ICTNLP/sled-tts-680253e19c889010a1a376ac) 7 | [![WeChat 
AI](https://img.shields.io/badge/WeChat%20AI-4CAF50?style=flat&logo=wechat)](https://www.wechat.com) 8 | [![ICT/CAS](https://img.shields.io/badge/ICT%2FCAS-0066cc?style=flat&logo=school)](https://ict.cas.cn) 9 | 10 | ## News 11 | - **Our paper has been released on [arXiv](https://arxiv.org/abs/2505.13181).** 12 | 13 | ## Key features 14 | - **Continuous Autoregressive Modeling**: SLED models speech in a continuous latent space, eliminating the need for complex hierarchical architectures. 15 | - **Streaming Synthesis**: SLED supports streaming synthesis, so speech generation can start as soon as the text stream begins. 16 | - **Voice Cloning**: SLED can generate speech conditioned on a 3-second speech prefix or on a reference utterance used as the prompt. 17 | 18 | 19 | ## Demo 20 | You can see SLED in action on the [demo page](https://sled-demo.github.io/). 21 |
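The paper title pairs continuous latent modeling with an energy distance. As a generic illustration only (this is not the code in `sled/energy_distance.py`, which is not reproduced here), the sample-based estimate of `D(P, Q) = 2E||X - Y|| - E||X - X'|| - E||Y - Y'||` can be sketched in plain Python. All function names are hypothetical, and `energy_score_loss` is an assumed two-sample training form, not SLED's confirmed loss.

```python
import math
from itertools import product
from typing import Sequence

Vector = Sequence[float]


def euclidean(a: Vector, b: Vector) -> float:
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def mean_pairwise(xs: Sequence[Vector], ys: Sequence[Vector]) -> float:
    """Average distance over all (x, y) pairs: a plug-in estimate of E||X - Y||."""
    return sum(euclidean(x, y) for x, y in product(xs, ys)) / (len(xs) * len(ys))


def energy_distance(xs: Sequence[Vector], ys: Sequence[Vector]) -> float:
    """V-statistic estimate of 2E||X - Y|| - E||X - X'|| - E||Y - Y'||.

    It is zero for identical sample sets and grows as the two
    empirical distributions move apart.
    """
    return 2 * mean_pairwise(xs, ys) - mean_pairwise(xs, xs) - mean_pairwise(ys, ys)


def energy_score_loss(x1: Vector, x2: Vector, y: Vector) -> float:
    """Hypothetical per-target loss from two model samples x1, x2 and a target y.

    Unbiased estimate of 2E||X - y|| - E||X - X'||; the E||Y - Y'|| term
    does not depend on the model, so it is dropped.
    """
    return euclidean(x1, y) + euclidean(x2, y) - euclidean(x1, x2)


if __name__ == "__main__":
    a = [[0.0, 0.0], [1.0, 1.0]]
    b = [[2.0, 2.0], [3.0, 3.0]]
    print(energy_distance(a, a))  # 0.0 for identical samples
    print(energy_distance(a, b))  # approximately 3 * sqrt(2)
    print(energy_score_loss([0.0], [2.0], [1.0]))  # 0.0
```

The estimate needs only samples and pairwise distances, so it applies directly to continuous vectors without a token vocabulary or an explicit density.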
22 | 23 | 24 |
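Several durations in this README map onto latent frames. Here is a minimal sketch that assumes the constants hard-coded in the evaluation scripts (`SAMPLING_RATE=24000`, `STRIDE=320`, hence `FREQ=75` frames per second). Under those constants, the 3-second voice-cloning prefix spans 225 frames and a 0.6-second streaming chunk spans 45 frames. The 45 appears to correspond to `--stream_m 45`, paired with `--stream_n 5`, in the streaming training command. `frames_available` is a hypothetical simplification of the "0.6 s per 5 text tokens" schedule, not the logic in `sled/sled_stream.py`.

```python
# Constants as hard-coded in scripts/eval_*.py of this repository.
SAMPLING_RATE = 24000           # Encodec_24khz sample rate in Hz
STRIDE = 320                    # waveform samples per latent frame
FREQ = SAMPLING_RATE // STRIDE  # latent frames per second (75)


def seconds_to_frames(seconds: float) -> int:
    """Latent frames needed to cover `seconds` of audio (nearest integer)."""
    return round(seconds * FREQ)


def frames_to_samples(frames: int) -> int:
    """Waveform samples produced by decoding `frames` latent frames,
    mirroring how the eval scripts multiply a frame count by STRIDE."""
    return frames * STRIDE


def frames_available(n_text_tokens: int, stream_n: int = 5, stream_m: int = 45) -> int:
    """Hypothetical simplification of the streaming schedule: after each full
    group of `stream_n` text tokens, `stream_m` more frames may be emitted.
    End-of-text flushing is ignored here."""
    return (n_text_tokens // stream_n) * stream_m


if __name__ == "__main__":
    print(FREQ)                         # 75
    print(seconds_to_frames(3.0))       # 225 frames for a 3-second prompt
    print(seconds_to_frames(0.6))       # 45 frames per streaming chunk
    print(frames_to_samples(45))        # 14400 samples, i.e. 0.6 s at 24 kHz
    print(frames_available(12) / FREQ)  # 1.2 seconds of audio after 12 tokens
```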
25 | 26 | ## Available Models on Hugging Face 27 | 28 | We currently offer two English models trained on LibriHeavy on [Hugging Face](https://huggingface.co/collections/ICTNLP/sled-tts-680253e19c889010a1a376ac): 29 | 30 | 1. **[SLED-TTS-Libriheavy](https://huggingface.co/ICTNLP/SLED-TTS-Libriheavy)**: This model is trained on LibriHeavy and provides high-quality text-to-speech synthesis. 31 | 32 | 2. **[SLED-TTS-Streaming-Libriheavy](https://huggingface.co/ICTNLP/SLED-TTS-Streaming-Libriheavy)**: This variant supports **streaming decoding**, generating a 0.6-second speech chunk for every 5 text tokens received. 33 | 34 | **Alternatively, you can train SLED on your own data by following the guidelines below.** 35 | 36 | ## Usage 37 | **We provide the training and inference code for SLED-TTS.** 38 | 39 | ### Installation 40 | ``` sh 41 | git clone https://github.com/ictnlp/SLED-TTS.git 42 | cd SLED-TTS 43 | pip install -e ./ 44 | ``` 45 | 46 | We currently use the sum of the first 8 codebook embedding vectors from [Encodec_24khz](https://huggingface.co/facebook/encodec_24khz) as the continuous latent vector. Before proceeding, make sure that [Encodec_24khz](https://huggingface.co/facebook/encodec_24khz) has been downloaded to your Hugging Face cache directory. 47 | 48 | ### Inference 49 | - Set the `CHECKPOINT` variable to the path of the cached **[SLED-TTS-Libriheavy](https://huggingface.co/ICTNLP/SLED-TTS-Libriheavy)** or **[SLED-TTS-Streaming-Libriheavy](https://huggingface.co/ICTNLP/SLED-TTS-Streaming-Libriheavy)** model. 50 | - Vary the `SEED` variable to obtain diverse generation results. 51 | - Use the `--bf16` flag to enable bf16 inference. 52 | ``` sh 53 | CHECKPOINT=/path/to/checkpoint 54 | CFG=2.0 55 | SEED=0 56 | ``` 57 | ***Offline Inference*** 58 | ``` sh 59 | python scripts/run_offline.py \ 60 | --model_name_or_path ${CHECKPOINT} \ 61 | --cfg ${CFG} \ 62 | --input "My remark pleases him, but I soon prove to him that it is not the right way to speak.
However perfect may have been the language of that ancient writer." \ 63 | --seed ${SEED} 64 | ``` 65 | ***Streaming Inference*** 66 | ``` sh 67 | python scripts/run_stream.py \ 68 | --model_name_or_path ${CHECKPOINT} \ 69 | --cfg ${CFG} \ 70 | --input "My remark pleases him, but I soon prove to him that it is not the right way to speak. However perfect may have been the language of that ancient writer." \ 71 | --seed ${SEED} 72 | # Note: run_stream.py only simulates streaming generation so that its output quality can be evaluated. 73 | # The current code does not provide an actual streaming API. 74 | ``` 75 | ***Voice Clone*** 76 | 77 | To change the prompt speech, set `--prompt_text` to the transcript of the prompt audio and `--prompt_audio` to its path. 78 | ``` sh 79 | python scripts/run_voice_clone.py \ 80 | --prompt_text "Were I in the warm room with all the splendor and magnificence!" \ 81 | --prompt_audio "example_prompt.flac" \ 82 | --model_name_or_path ${CHECKPOINT} \ 83 | --cfg ${CFG} \ 84 | --input "Perhaps the other trees from the forest will come to look at me!" \ 85 | --seed ${SEED} 86 | ``` 87 | 88 | ### Training 89 | 90 | ***Data Processing*** 91 | 92 | Process the LibriHeavy data so that each line follows the JSON format shown below. 93 | ``` 94 | {"id": "large/10022/essayoncriticism_1505_librivox_64kb_mp3/essayoncriticism_01_pope_64kb_5", "start": 610.32, "duration": 19.76, "supervisions": [{"text": "Hail! bards triumphant! born in happier days; Immortal heirs of universal praise!
Whose honors with increase of ages grow, As streams roll down, enlarging as they flow; Nations unborn your mighty names shall sound, [193] And worlds applaud that must not yet be found!"}], "recording": {"sources": [{"source": "download/librilight/large/10022/essayoncriticism_1505_librivox_64kb_mp3/essayoncriticism_01_pope_64kb.flac"}], "sampling_rate": 16000}, "type": "MonoCut"} 95 | ``` 96 | Alternatively, you can use the LibriHeavy manifest available [here](https://huggingface.co/datasets/ICTNLP/LibriHeavy_manifest). Convert your own datasets into the same format. 97 | 98 | ***Training Offline Model*** 99 | ``` sh 100 | OUTPUT_DIR=./runs/libriheavy 101 | mkdir -p $OUTPUT_DIR 102 | LOG_FILE=${OUTPUT_DIR}/log 103 | 104 | BATCH_SIZE=8 105 | UPDATE_FREQ=8 106 | # assuming 8 processes per node, keep WORLD_SIZE * 8 * BATCH_SIZE * UPDATE_FREQ == 512 (the global batch size); the values here assume WORLD_SIZE=1 107 | 108 | torchrun --nnodes ${WORLD_SIZE} --node_rank ${RANK} --nproc_per_node 8 --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \ 109 | ./scripts/train_libriheavy.py \ 110 | --training_cfg 0.1 \ 111 | --num_hidden_layers 12 --diffloss_d 6 --noise_channels 128 \ 112 | --dataloader_num_workers 8 \ 113 | --dataloader_pin_memory True \ 114 | --remove_unused_columns False \ 115 | --label_names audio_inputs \ 116 | --group_by_speech_length \ 117 | --do_train \ 118 | --do_eval \ 119 | --eval_strategy steps \ 120 | --eval_steps 10000 \ 121 | --prediction_loss_only \ 122 | --per_device_train_batch_size ${BATCH_SIZE} \ 123 | --per_device_eval_batch_size 24 \ 124 | --gradient_accumulation_steps ${UPDATE_FREQ} \ 125 | --bf16 \ 126 | --learning_rate 5e-4 \ 127 | --weight_decay 0.01 \ 128 | --adam_beta1 0.9 \ 129 | --adam_beta2 0.999 \ 130 | --adam_epsilon 1e-8 \ 131 | --max_grad_norm 1.0 \ 132 | --max_steps 300000 \ 133 | --lr_scheduler_type "linear" \ 134 | --warmup_steps 32000 \ 135 | --logging_first_step \ 136 | --logging_steps 100 \ 137 | --save_steps 10000 \ 138 | --save_total_limit 10 \ 139 | --output_dir
${OUTPUT_DIR} \ 140 | --report_to tensorboard \ 141 | --disable_tqdm True \ 142 | --ddp_timeout 3600 --overwrite_output_dir 143 | 144 | ``` 145 | 146 | ***Training Streaming Model*** 147 | ``` sh 148 | OUTPUT_DIR=./runs/libriheavy_stream 149 | mkdir -p $OUTPUT_DIR 150 | LOG_FILE=${OUTPUT_DIR}/log 151 | 152 | BATCH_SIZE=8 153 | UPDATE_FREQ=8 154 | # assuming 8 processes per node, keep WORLD_SIZE * 8 * BATCH_SIZE * UPDATE_FREQ == 512 (the global batch size); the values here assume WORLD_SIZE=1 155 | 156 | torchrun --nnodes ${WORLD_SIZE} --node_rank ${RANK} --nproc_per_node 8 --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \ 157 | ./scripts/train_libriheavy_stream.py \ 158 | --finetune_path ./runs/libriheavy/checkpoint-300000/model.safetensors \ 159 | --stream_n 5 --stream_m 45 \ 160 | --training_cfg 0.1 \ 161 | --num_hidden_layers 12 --diffloss_d 6 --noise_channels 128 \ 162 | --dataloader_num_workers 8 \ 163 | --dataloader_pin_memory True \ 164 | --remove_unused_columns False \ 165 | --label_names audio_inputs \ 166 | --group_by_speech_length \ 167 | --do_train \ 168 | --do_eval \ 169 | --eval_strategy steps \ 170 | --eval_steps 10000 \ 171 | --prediction_loss_only \ 172 | --per_device_train_batch_size ${BATCH_SIZE} \ 173 | --per_device_eval_batch_size 24 \ 174 | --gradient_accumulation_steps ${UPDATE_FREQ} \ 175 | --bf16 \ 176 | --learning_rate 3e-4 \ 177 | --weight_decay 0.01 \ 178 | --adam_beta1 0.9 \ 179 | --adam_beta2 0.999 \ 180 | --adam_epsilon 1e-8 \ 181 | --max_grad_norm 1.0 \ 182 | --max_steps 100000 \ 183 | --lr_scheduler_type "linear" \ 184 | --warmup_steps 10000 \ 185 | --logging_first_step \ 186 | --logging_steps 100 \ 187 | --save_steps 10000 \ 188 | --save_total_limit 10 \ 189 | --output_dir ${OUTPUT_DIR} \ 190 | --report_to tensorboard \ 191 | --disable_tqdm True \ 192 | --ddp_timeout 3600 --overwrite_output_dir 193 | ``` 194 | ### BF16 Support 195 | When the `--bf16` flag is set, the model is loaded in bf16 for inference and in fp32 for training (for bf16 mixed-precision training). To enable pure bf16 training, you can change
To enable pure bf16 training, you can change 196 | https://github.com/ictnlp/SLED-TTS/blob/69a0a77d37180ec711a21f39f1b6bffa8b068072/scripts/train_libriheavy.py#L298 197 | to 198 | ``` 199 | torch_dtype = torch.bfloat16 if training_args.bf16 else None 200 | ``` 201 | However, Encodec should always execute in fp32 to maintain the precision of latents. Therefore, we load Encodec in fp32 and downcast the encoded latent to bf16. 202 | 203 | 204 | ## Citation 205 | If you have any questions, please feel free to submit an issue or contact `mazhengrui21b@ict.ac.cn`. 206 | 207 | If our work is useful for you, please cite as: 208 | 209 | ``` 210 | @misc{ma2025efficientspeechlanguagemodeling, 211 | title={Efficient Speech Language Modeling via Energy Distance in Continuous Latent Space}, 212 | author={Zhengrui Ma and Yang Feng and Chenze Shao and Fandong Meng and Jie Zhou and Min Zhang}, 213 | year={2025}, 214 | eprint={2505.13181}, 215 | archivePrefix={arXiv}, 216 | primaryClass={cs.CL}, 217 | url={https://arxiv.org/abs/2505.13181}, 218 | } 219 | ``` 220 | -------------------------------------------------------------------------------- /example_prompt.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/SLED-TTS/b8ed10d9953160efd8a0538b4ea5af80a57c9e96/example_prompt.flac -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "sled" 7 | version = "0.1.0" 8 | description = "Implementation of SLED" 9 | authors = [ 10 | { name="Zhengrui Ma", email="mazhengrui21b@ict.ac.cn" } 11 | ] 12 | dependencies = [ 13 | "torch==2.5.1", 14 | "torchaudio==2.5.1", 15 | "transformers==4.47.0", 16 | "datasets==3.1.0", 17 | "accelerate==1.2.0", 18 | "numpy==1.26.4", 19 | 
"librosa==0.10.2.post1", 20 | "soundfile==0.12.1" 21 | ] 22 | 23 | [tool.setuptools] 24 | packages = ["sled"] 25 | -------------------------------------------------------------------------------- /scripts/eval_continue.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | from typing import Tuple 5 | from pathlib import Path 6 | 7 | import torch 8 | import torchaudio 9 | 10 | 11 | from datasets import load_dataset, Audio 12 | from transformers import RobertaTokenizer, AutoProcessor 13 | from transformers.data.data_collator import pad_without_fast_tokenizer_warning 14 | from accelerate.utils import set_seed 15 | 16 | import pdb 17 | 18 | 19 | from sled.sled import SpeechLlamaForCausalLM 20 | 21 | BANDWIDTH=6 22 | SAMPLING_RATE=24000 23 | STRIDE=320 24 | FREQ=75 25 | MIN_LEN=4.0 26 | MAX_LEN=10.0 27 | PROMPT_LEN=3 28 | 29 | 30 | 31 | logging.basicConfig( 32 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 33 | datefmt="%m/%d/%Y %H:%M:%S", 34 | level=logging.INFO, 35 | ) 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | 40 | def adjust_length_to_model(length, max_sequence_length): 41 | assert max_sequence_length > 0 42 | if length <= 0 and max_sequence_length > 0: 43 | length = max_sequence_length 44 | elif 0 < max_sequence_length < length: 45 | length = max_sequence_length # No generation bigger than model size 46 | return length 47 | 48 | 49 | def filter_function(example): 50 | return ((len(example["audio"]["array"]) / SAMPLING_RATE) < MAX_LEN) and ((len(example["audio"]["array"]) / SAMPLING_RATE) > MIN_LEN) 51 | 52 | 53 | 54 | def main(): 55 | parser = argparse.ArgumentParser() 56 | 57 | parser.add_argument( 58 | "--model_name_or_path", 59 | default=None, 60 | type=str, 61 | required=True, 62 | ) 63 | parser.add_argument("--max_length", type=int, default=0) 64 | parser.add_argument( 65 | "--cfg", 66 | type=float, 67 | default=1.0, 68 | ) 69 | 
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 70 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") 71 | parser.add_argument( 72 | "--fp16", 73 | action="store_true", 74 | ) 75 | parser.add_argument( 76 | "--bf16", 77 | action="store_true", 78 | ) 79 | parser.add_argument("--batch_size", type=int, default=32) 80 | args = parser.parse_args() 81 | 82 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 83 | torch_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32) 84 | logger.warning(f"device: {device}, 16-bits inference: {args.fp16 or args.bf16}") 85 | 86 | if args.seed is not None: 87 | set_seed(args.seed) 88 | 89 | # Initialize the model and tokenizer 90 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path) 91 | eos_token_id = tokenizer.eos_token_id 92 | 93 | model = SpeechLlamaForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype) 94 | model.infer_cfg = args.cfg 95 | model.initialize_codec("facebook/encodec_24khz") 96 | model.to(device) 97 | 98 | 99 | processor = AutoProcessor.from_pretrained("facebook/encodec_24khz") 100 | 101 | assert tokenizer.pad_token is not None 102 | logger.info(f"tokenizer pad token: {tokenizer.pad_token}") 103 | 104 | 105 | max_seq_length = getattr(model.config, "max_position_embeddings", 0) 106 | args.max_length = adjust_length_to_model(args.max_length, max_sequence_length=max_seq_length) 107 | logger.info(args) 108 | 109 | 110 | eval_dataset = load_dataset("yoom618/librispeech_pc", split="test.clean") 111 | eval_dataset = eval_dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE)) 112 | 113 | 114 | logger.info(f"original eval dataset: {len(eval_dataset)} samples.") 115 | eval_dataset = eval_dataset.filter(filter_function) 116 | logger.info(f"filtered eval dataset: {len(eval_dataset)} samples.") 117 | 118 | 
tokenized_eval_dataset = eval_dataset.map( 119 | lambda example: tokenizer(example["text"]), 120 | batched=True, 121 | ) 122 | 123 | 124 | batch_size = args.batch_size 125 | output_path = Path("eval_continue") 126 | output_path.mkdir(parents=True, exist_ok=True) 127 | 128 | with torch.no_grad(): 129 | for i in range(0, len(tokenized_eval_dataset), batch_size): 130 | batch = tokenized_eval_dataset.select(range(i, min(i + batch_size, len(tokenized_eval_dataset)))) 131 | 132 | input_ids = [{"input_ids":instance["input_ids"]} for instance in batch] 133 | 134 | encodes = pad_without_fast_tokenizer_warning( 135 | tokenizer, 136 | input_ids, 137 | padding=True, 138 | return_attention_mask=True, 139 | return_tensors="pt" 140 | ) 141 | 142 | input_ids = encodes["input_ids"].to(device) 143 | attention_mask = encodes["attention_mask"].to(device) 144 | text_input_length = input_ids.shape[1] 145 | 146 | audio_arrays = [instance["audio"]["array"] for instance in batch] 147 | audio_inputs = processor(raw_audio=audio_arrays, sampling_rate=SAMPLING_RATE, return_tensors="pt") # 'padding_mask': b,t 'input_values': b,c,t 148 | 149 | 150 | encoder_outputs = model.codec.encode(audio_inputs["input_values"].to(device), audio_inputs["padding_mask"].to(device), bandwidth=BANDWIDTH) #1,b,r,t, 1 due to one chunk 151 | speech_inputs_embeds = model.codec.quantizer.decode(encoder_outputs.audio_codes[0].transpose(0, 1)) #b,d,t 152 | 153 | speech_attention_mask = audio_inputs["padding_mask"][..., ::STRIDE].to(device) 154 | assert speech_inputs_embeds.size(-1) == speech_attention_mask.size(-1) 155 | speech_inputs_embeds = speech_inputs_embeds.transpose(1,2).to(model.dtype) #b,t,d 156 | 157 | 158 | speech_inputs_embeds = speech_inputs_embeds[:,:FREQ * PROMPT_LEN,:] 159 | speech_attention_mask = speech_attention_mask[:,:FREQ * PROMPT_LEN] 160 | speech_input_length = speech_inputs_embeds.shape[1] 161 | 162 | 163 | new_attention_mask = torch.concat([attention_mask, speech_attention_mask], dim=1) 164 | 
165 | 166 | output_sequences = model.generate( 167 | input_ids=input_ids, 168 | inputs_embeds=speech_inputs_embeds, 169 | attention_mask=new_attention_mask, 170 | max_length=args.max_length, 171 | do_sample=True, 172 | num_return_sequences=args.num_return_sequences, 173 | ) 174 | 175 | 176 | 177 | new_embeds = output_sequences[1] 178 | generated_ids = output_sequences[0][:, text_input_length:] 179 | 180 | new_audio_values = model.codec.decoder(new_embeds.transpose(-1,-2).float()) 181 | 182 | wav_len = (generated_ids.ne(eos_token_id).sum(dim=-1) + speech_input_length) * STRIDE 183 | 184 | for i in range(len(wav_len)): 185 | id = batch["id"][i] 186 | torchaudio.save(output_path / f"{id}.wav", new_audio_values[i][:,:wav_len[i]].cpu(), SAMPLING_RATE) 187 | 188 | return 189 | 190 | 191 | if __name__ == "__main__": 192 | main() 193 | -------------------------------------------------------------------------------- /scripts/eval_reference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | from typing import Tuple 5 | from pathlib import Path 6 | 7 | import torch 8 | import torchaudio 9 | 10 | 11 | from datasets import load_dataset, Audio 12 | from transformers import AutoTokenizer, RobertaTokenizer, AutoProcessor 13 | from transformers.data.data_collator import pad_without_fast_tokenizer_warning 14 | from accelerate.utils import set_seed 15 | 16 | from sled.sled import SpeechLlamaForCausalLM 17 | 18 | 19 | 20 | import pdb 21 | 22 | BANDWIDTH=6 23 | SAMPLING_RATE=24000 24 | STRIDE=320 25 | FREQ=75 26 | MIN_LEN=4.0 27 | MAX_LEN=10.0 28 | 29 | 30 | 31 | 32 | logging.basicConfig( 33 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 34 | datefmt="%m/%d/%Y %H:%M:%S", 35 | level=logging.INFO, 36 | ) 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | 41 | def adjust_length_to_model(length, max_sequence_length): 42 | assert max_sequence_length > 0 43 | if length <= 0 and 
max_sequence_length > 0: 44 | length = max_sequence_length 45 | elif 0 < max_sequence_length < length: 46 | length = max_sequence_length # No generation bigger than model size 47 | return length 48 | 49 | 50 | def filter_function(example): 51 | return ((len(example["audio"]["array"]) / SAMPLING_RATE) < MAX_LEN) and ((len(example["audio"]["array"]) / SAMPLING_RATE) > MIN_LEN) 52 | 53 | 54 | 55 | def main(): 56 | parser = argparse.ArgumentParser() 57 | 58 | parser.add_argument( 59 | "--model_name_or_path", 60 | default=None, 61 | type=str, 62 | required=True, 63 | ) 64 | parser.add_argument("--max_length", type=int, default=0) 65 | parser.add_argument( 66 | "--cfg", 67 | type=float, 68 | default=1.0, 69 | ) 70 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 71 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") 72 | parser.add_argument( 73 | "--fp16", 74 | action="store_true", 75 | ) 76 | parser.add_argument( 77 | "--bf16", 78 | action="store_true", 79 | ) 80 | parser.add_argument("--batch_size", type=int, default=1) 81 | args = parser.parse_args() 82 | 83 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 84 | torch_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32) 85 | logger.warning(f"device: {device}, 16-bits inference: {args.fp16 or args.bf16}") 86 | 87 | if args.seed is not None: 88 | set_seed(args.seed) 89 | 90 | # Initialize the model and tokenizer 91 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path) 92 | eos_token_id = tokenizer.eos_token_id 93 | 94 | model = SpeechLlamaForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype) 95 | model.infer_cfg = args.cfg 96 | model.initialize_codec("facebook/encodec_24khz") 97 | model.to(device) 98 | 99 | 100 | processor = AutoProcessor.from_pretrained("facebook/encodec_24khz") 101 | 102 | assert tokenizer.pad_token 
is not None 103 | logger.info(f"tokenizer pad token: {tokenizer.pad_token}") 104 | 105 | 106 | max_seq_length = getattr(model.config, "max_position_embeddings", 0) 107 | args.max_length = adjust_length_to_model(args.max_length, max_sequence_length=max_seq_length) 108 | logger.info(args) 109 | 110 | 111 | eval_dataset = load_dataset("yoom618/librispeech_pc", split="test.clean") 112 | eval_dataset = eval_dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE)) 113 | 114 | 115 | logger.info(f"original eval dataset: {len(eval_dataset)} samples.") 116 | eval_dataset = eval_dataset.filter(filter_function) 117 | logger.info(f"filtered eval dataset: {len(eval_dataset)} samples.") 118 | 119 | 120 | batch_size = args.batch_size 121 | assert batch_size == 1 122 | 123 | pathname = "eval_reference" 124 | output_path = Path(pathname) 125 | output_path_concat = Path(pathname + "_concat") 126 | 127 | output_path.mkdir(parents=True, exist_ok=True) 128 | output_path_concat.mkdir(parents=True, exist_ok=True) 129 | last_index = None 130 | with torch.no_grad(): 131 | #reverse to implement 132 | for i in range(len(eval_dataset) - 1, -1, -batch_size): 133 | current_sample = eval_dataset[i] 134 | current_speaker_id = current_sample["speaker_id"] 135 | 136 | if last_index is None: 137 | last_index = i 138 | 139 | if i != 0: 140 | prompt_sample = eval_dataset[i - 1] 141 | if prompt_sample["speaker_id"] != current_speaker_id: 142 | prompt_sample = eval_dataset[last_index] 143 | last_index = None 144 | else: 145 | prompt_sample = eval_dataset[last_index] 146 | last_index = None 147 | 148 | 149 | text_to_synthesize = current_sample["text"] 150 | prompt_text = prompt_sample["text"] 151 | 152 | prompt_audio = prompt_sample["audio"]["array"] 153 | 154 | input_text = [prompt_text + " " + text_to_synthesize] 155 | 156 | batch_encoded = tokenizer.batch_encode_plus( 157 | input_text, 158 | add_special_tokens=True, 159 | padding="longest", 160 | truncation=True, 161 | return_tensors="pt" 162 | 
) 163 | 164 | input_ids = batch_encoded["input_ids"].to(device) 165 | attention_mask = batch_encoded["attention_mask"].to(device) 166 | text_input_length = input_ids.shape[1] 167 | 168 | audio_arrays = [prompt_audio] 169 | audio_inputs = processor(raw_audio=audio_arrays, sampling_rate=SAMPLING_RATE, return_tensors="pt") # 'padding_mask': b,t 'input_values': b,c,t 170 | 171 | 172 | encoder_outputs = model.codec.encode(audio_inputs["input_values"].to(device), audio_inputs["padding_mask"].to(device), bandwidth=BANDWIDTH) #1,b,r,t, 1 due to one chunk 173 | speech_inputs_embeds = model.codec.quantizer.decode(encoder_outputs.audio_codes[0].transpose(0, 1)) #b,d,t 174 | 175 | speech_attention_mask = audio_inputs["padding_mask"][..., ::STRIDE].to(device) 176 | assert speech_inputs_embeds.size(-1) == speech_attention_mask.size(-1) 177 | speech_inputs_embeds = speech_inputs_embeds.transpose(1,2).to(model.dtype) #b,t,d 178 | 179 | 180 | speech_input_length = speech_inputs_embeds.shape[1] 181 | 182 | 183 | new_attention_mask = torch.concat([attention_mask, speech_attention_mask], dim=1) 184 | 185 | 186 | output_sequences = model.generate( 187 | input_ids=input_ids, 188 | inputs_embeds=speech_inputs_embeds, 189 | attention_mask=new_attention_mask, 190 | max_length=args.max_length, 191 | do_sample=True, 192 | num_return_sequences=args.num_return_sequences, 193 | ) 194 | 195 | 196 | 197 | new_embeds = output_sequences[1] 198 | generated_ids = output_sequences[0][:, text_input_length:] 199 | 200 | new_audio_values = model.codec.decoder(new_embeds.transpose(-1,-2).float()) 201 | 202 | wav_len = (generated_ids.ne(eos_token_id).sum(dim=-1) + speech_input_length) * STRIDE 203 | 204 | 205 | assert len(wav_len) == 1 206 | for i in range(len(wav_len)): 207 | id = current_sample["id"] 208 | torchaudio.save(output_path / f"{id}.wav", new_audio_values[i][:, speech_input_length* STRIDE:wav_len[i]].cpu(), SAMPLING_RATE) 209 | torchaudio.save(output_path_concat / f"{id}.wav", 
new_audio_values[i][:,:wav_len[i]].cpu(), SAMPLING_RATE) 210 | return 211 | 212 | 213 | if __name__ == "__main__": 214 | main() 215 | -------------------------------------------------------------------------------- /scripts/eval_stream.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | from typing import Tuple 5 | from pathlib import Path 6 | 7 | import torch 8 | import torchaudio 9 | 10 | 11 | from datasets import load_dataset, Audio 12 | from transformers import RobertaTokenizer, AutoProcessor 13 | from transformers.data.data_collator import pad_without_fast_tokenizer_warning 14 | from accelerate.utils import set_seed 15 | 16 | from sled.sled_stream import SpeechLlamaForCausalLM 17 | 18 | 19 | 20 | import pdb 21 | 22 | BANDWIDTH=6 23 | SAMPLING_RATE=24000 24 | STRIDE=320 25 | FREQ=75 26 | MIN_LEN=4.0 27 | MAX_LEN=10.0 28 | 29 | 30 | 31 | logging.basicConfig( 32 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 33 | datefmt="%m/%d/%Y %H:%M:%S", 34 | level=logging.INFO, 35 | ) 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | 40 | def adjust_length_to_model(length, max_sequence_length): 41 | assert max_sequence_length > 0 42 | if length <= 0 and max_sequence_length > 0: 43 | length = max_sequence_length 44 | elif 0 < max_sequence_length < length: 45 | length = max_sequence_length # No generation bigger than model size 46 | return length 47 | 48 | 49 | def filter_function(example): 50 | return ((len(example["audio"]["array"]) / SAMPLING_RATE) < MAX_LEN) and ((len(example["audio"]["array"]) / SAMPLING_RATE) > MIN_LEN) 51 | 52 | 53 | 54 | 55 | 56 | def main(): 57 | parser = argparse.ArgumentParser() 58 | 59 | parser.add_argument( 60 | "--model_name_or_path", 61 | default=None, 62 | type=str, 63 | required=True, 64 | ) 65 | parser.add_argument("--max_length", type=int, default=0) 66 | parser.add_argument( 67 | "--cfg", 68 | type=float, 69 | default=1.0, 70 | 
) 71 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 72 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") 73 | parser.add_argument( 74 | "--fp16", 75 | action="store_true", 76 | ) 77 | parser.add_argument( 78 | "--bf16", 79 | action="store_true", 80 | ) 81 | parser.add_argument("--batch_size", type=int, default=32) 82 | args = parser.parse_args() 83 | 84 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 85 | torch_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32) 86 | logger.warning(f"device: {device}, 16-bits inference: {args.fp16 or args.bf16}") 87 | 88 | if args.seed is not None: 89 | set_seed(args.seed) 90 | 91 | # Initialize the model and tokenizer 92 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path) 93 | eos_token_id = tokenizer.eos_token_id 94 | 95 | model = SpeechLlamaForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype) 96 | model.infer_cfg = args.cfg 97 | model.initialize_codec("facebook/encodec_24khz") 98 | model.to(device) 99 | 100 | 101 | processor = AutoProcessor.from_pretrained("facebook/encodec_24khz") 102 | 103 | assert tokenizer.pad_token is not None 104 | logger.info(f"tokenizer pad token: {tokenizer.pad_token}") 105 | 106 | 107 | max_seq_length = getattr(model.config, "max_position_embeddings", 0) 108 | args.max_length = adjust_length_to_model(args.max_length, max_sequence_length=max_seq_length) 109 | logger.info(args) 110 | 111 | 112 | eval_dataset = load_dataset("yoom618/librispeech_pc", split="test.clean") 113 | eval_dataset = eval_dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE)) 114 | 115 | 116 | logger.info(f"original eval dataset: {len(eval_dataset)} samples.") 117 | eval_dataset = eval_dataset.filter(filter_function) 118 | logger.info(f"filtered eval dataset: {len(eval_dataset)} samples.") 119 | 120 | 
tokenized_eval_dataset = eval_dataset.map( 121 | lambda example: tokenizer(example["text"]), 122 | batched=True, 123 | ) 124 | 125 | 126 | batch_size = args.batch_size 127 | output_path = Path("eval_stream") 128 | output_path.mkdir(parents=True, exist_ok=True) 129 | 130 | with torch.no_grad(): 131 | for i in range(0, len(tokenized_eval_dataset), batch_size): 132 | batch = tokenized_eval_dataset.select(range(i, min(i + batch_size, len(tokenized_eval_dataset)))) 133 | 134 | input_ids = [{"input_ids":instance["input_ids"]} for instance in batch] 135 | 136 | encodes = pad_without_fast_tokenizer_warning( 137 | tokenizer, 138 | input_ids, 139 | padding=True, 140 | return_attention_mask=True, 141 | return_tensors="pt" 142 | ) 143 | 144 | input_ids = encodes["input_ids"].to(device) 145 | attention_mask = encodes["attention_mask"].to(device) 146 | text_input_length = input_ids.shape[1] 147 | 148 | 149 | 150 | output_sequences = model.generate( 151 | input_ids=input_ids, 152 | attention_mask=attention_mask, 153 | max_length=args.max_length, 154 | do_sample=True, 155 | num_return_sequences=args.num_return_sequences, 156 | ) 157 | 158 | 159 | 160 | new_embeds = output_sequences[1] 161 | generated_ids = output_sequences[0] 162 | 163 | new_audio_values = model.codec.decoder(new_embeds.transpose(-1,-2).float()) 164 | 165 | wav_len = (generated_ids.ne(eos_token_id).sum(dim=-1)) * STRIDE 166 | 167 | 168 | 169 | for i in range(len(wav_len)): 170 | id = batch["id"][i] 171 | torchaudio.save(output_path / f"{id}.wav", new_audio_values[i][:,:wav_len[i]].cpu(), SAMPLING_RATE) 172 | 173 | return 174 | 175 | 176 | if __name__ == "__main__": 177 | main() 178 | -------------------------------------------------------------------------------- /scripts/run_offline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from typing import Tuple 4 | from pathlib import Path 5 | 6 | import pdb 7 | import torch 8 | import 
torchaudio 9 | from accelerate.utils import set_seed 10 | 11 | from transformers import AutoTokenizer, PreTrainedTokenizerFast, RobertaTokenizer 12 | from sled.sled import SpeechLlamaForCausalLM 13 | 14 | 15 | 16 | SAMPLING_RATE=24000 17 | 18 | logging.basicConfig( 19 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 20 | datefmt="%m/%d/%Y %H:%M:%S", 21 | level=logging.INFO, 22 | ) 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | 27 | def adjust_length_to_model(length, max_sequence_length): 28 | assert max_sequence_length > 0 29 | if length <= 0 and max_sequence_length > 0: 30 | length = max_sequence_length 31 | elif 0 < max_sequence_length < length: 32 | length = max_sequence_length # No generation bigger than model size 33 | return length 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser() 37 | 38 | parser.add_argument( 39 | "--model_name_or_path", 40 | default=None, 41 | type=str, 42 | required=True, 43 | ) 44 | 45 | parser.add_argument("--input", type=str, default="It was silent and gloomy, being tenanted solely by the captive and lighted by the dying embers of a fire which had been used for the purposes of cookery.") 46 | parser.add_argument("--max_length", type=int, default=0) 47 | 48 | parser.add_argument( 49 | "--temperature", 50 | type=float, 51 | default=1.0, 52 | help="temperature of 1.0 has no effect, lower tend toward greedy sampling", 53 | ) 54 | parser.add_argument( 55 | "--cfg", 56 | type=float, 57 | default=1.0, 58 | ) 59 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 60 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") 61 | parser.add_argument( 62 | "--fp16", 63 | action="store_true", 64 | ) 65 | parser.add_argument( 66 | "--bf16", 67 | action="store_true", 68 | ) 69 | args = parser.parse_args() 70 | 71 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 72 | torch_dtype = torch.bfloat16 
if args.bf16 else (torch.float16 if args.fp16 else torch.float32) 73 | logger.warning(f"device: {device}, 16-bits inference: {args.fp16 or args.bf16}") 74 | 75 | if args.seed is not None: 76 | set_seed(args.seed) 77 | 78 | # Initialize the model and tokenizer 79 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path) 80 | eos_token_id = tokenizer.eos_token_id 81 | 82 | model = SpeechLlamaForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype) 83 | model.infer_cfg = args.cfg 84 | model.initialize_codec("facebook/encodec_24khz") 85 | model.to(device) 86 | 87 | assert tokenizer.pad_token is not None 88 | logger.info(f"tokenizer pad token: {tokenizer.pad_token}") 89 | 90 | 91 | max_seq_length = getattr(model.config, "max_position_embeddings", 0) 92 | args.max_length = adjust_length_to_model(args.max_length, max_sequence_length=max_seq_length) 93 | logger.info(args) 94 | 95 | 96 | input_text = args.input if args.input else input("Model input >>> ") 97 | input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device) 98 | 99 | 100 | output_sequences = model.generate( 101 | input_ids=input_ids, 102 | max_length=args.max_length, 103 | do_sample=True, 104 | num_return_sequences=args.num_return_sequences, 105 | ) 106 | 107 | 108 | new_embeds = output_sequences[1] 109 | new_audio_values = model.codec.decoder(new_embeds.transpose(-1,-2).float()) 110 | 111 | 112 | output_path = "output.wav" 113 | torchaudio.save(output_path, new_audio_values[0].cpu(), SAMPLING_RATE) 114 | 115 | return 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- /scripts/run_stream.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from typing import Tuple 4 | from pathlib import Path 5 | 6 | import pdb 7 | import torch 8 | import torchaudio 9 | from accelerate.utils import set_seed 10 | 11 
| from transformers import AutoTokenizer, PreTrainedTokenizerFast, RobertaTokenizer 12 | from sled.sled_stream import SpeechLlamaForCausalLM 13 | 14 | 15 | 16 | SAMPLING_RATE=24000 17 | 18 | logging.basicConfig( 19 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 20 | datefmt="%m/%d/%Y %H:%M:%S", 21 | level=logging.INFO, 22 | ) 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | 27 | def adjust_length_to_model(length, max_sequence_length): 28 | assert max_sequence_length > 0 29 | if length <= 0 and max_sequence_length > 0: 30 | length = max_sequence_length 31 | elif 0 < max_sequence_length < length: 32 | length = max_sequence_length # No generation bigger than model size 33 | return length 34 | 35 | def main(): 36 | parser = argparse.ArgumentParser() 37 | 38 | parser.add_argument( 39 | "--model_name_or_path", 40 | default=None, 41 | type=str, 42 | required=True, 43 | ) 44 | 45 | parser.add_argument("--input", type=str, default="It was silent and gloomy, being tenanted solely by the captive and lighted by the dying embers of a fire which had been used for the purposes of cookery.") 46 | parser.add_argument("--max_length", type=int, default=0) 47 | 48 | parser.add_argument( 49 | "--temperature", 50 | type=float, 51 | default=1.0, 52 | help="temperature of 1.0 has no effect, lower tend toward greedy sampling", 53 | ) 54 | parser.add_argument( 55 | "--cfg", 56 | type=float, 57 | default=1.0, 58 | ) 59 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 60 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") 61 | parser.add_argument( 62 | "--fp16", 63 | action="store_true", 64 | ) 65 | parser.add_argument( 66 | "--bf16", 67 | action="store_true", 68 | ) 69 | args = parser.parse_args() 70 | 71 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 72 | torch_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else 
torch.float32) 73 | logger.warning(f"device: {device}, 16-bits inference: {args.fp16 or args.bf16}") 74 | 75 | if args.seed is not None: 76 | set_seed(args.seed) 77 | 78 | # Initialize the model and tokenizer 79 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path) 80 | eos_token_id = tokenizer.eos_token_id 81 | 82 | model = SpeechLlamaForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype) 83 | model.infer_cfg = args.cfg 84 | model.initialize_codec("facebook/encodec_24khz") 85 | model.to(device) 86 | 87 | assert tokenizer.pad_token is not None 88 | logger.info(f"tokenizer pad token: {tokenizer.pad_token}") 89 | 90 | 91 | max_seq_length = getattr(model.config, "max_position_embeddings", 0) 92 | args.max_length = adjust_length_to_model(args.max_length, max_sequence_length=max_seq_length) 93 | logger.info(args) 94 | 95 | 96 | input_text = args.input if args.input else input("Model input >>> ") 97 | input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device) 98 | 99 | 100 | output_sequences = model.generate( 101 | input_ids=input_ids, 102 | max_length=args.max_length, 103 | do_sample=True, 104 | num_return_sequences=args.num_return_sequences, 105 | ) 106 | 107 | 108 | new_embeds = output_sequences[1] 109 | new_audio_values = model.codec.decoder(new_embeds.transpose(-1,-2).float()) 110 | 111 | 112 | output_path = "output.wav" 113 | torchaudio.save(output_path, new_audio_values[0].cpu(), SAMPLING_RATE) 114 | 115 | return 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- /scripts/run_voice_clone.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from typing import Tuple 4 | from pathlib import Path 5 | 6 | import pdb 7 | import torch 8 | import torchaudio 9 | from accelerate.utils import set_seed 10 | 11 | from transformers import AutoTokenizer, 
PreTrainedTokenizerFast, RobertaTokenizer, AutoProcessor 12 | from sled.sled import SpeechLlamaForCausalLM 13 | 14 | 15 | BANDWIDTH=6 16 | STRIDE=320 17 | SAMPLING_RATE=24000 18 | 19 | logging.basicConfig( 20 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 21 | datefmt="%m/%d/%Y %H:%M:%S", 22 | level=logging.INFO, 23 | ) 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | 28 | def adjust_length_to_model(length, max_sequence_length): 29 | assert max_sequence_length > 0 30 | if length <= 0 and max_sequence_length > 0: 31 | length = max_sequence_length 32 | elif 0 < max_sequence_length < length: 33 | length = max_sequence_length # No generation bigger than model size 34 | return length 35 | 36 | def main(): 37 | parser = argparse.ArgumentParser() 38 | 39 | parser.add_argument( 40 | "--model_name_or_path", 41 | default=None, 42 | type=str, 43 | required=True, 44 | ) 45 | parser.add_argument("--prompt_audio", type=str, required=True) 46 | parser.add_argument("--prompt_text", type=str) 47 | parser.add_argument("--input", type=str, default="It was silent and gloomy, being tenanted solely by the captive and lighted by the dying embers of a fire which had been used for the purposes of cookery.") 48 | parser.add_argument("--max_length", type=int, default=0) 49 | 50 | parser.add_argument( 51 | "--temperature", 52 | type=float, 53 | default=1.0, 54 | help="temperature of 1.0 has no effect, lower tend toward greedy sampling", 55 | ) 56 | parser.add_argument( 57 | "--cfg", 58 | type=float, 59 | default=1.0, 60 | ) 61 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 62 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") 63 | parser.add_argument( 64 | "--fp16", 65 | action="store_true", 66 | ) 67 | parser.add_argument( 68 | "--bf16", 69 | action="store_true", 70 | ) 71 | args = parser.parse_args() 72 | 73 | device = torch.device("cuda" if torch.cuda.is_available() else
"cpu") 74 | torch_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32) 75 | logger.warning(f"device: {device}, 16-bits inference: {args.fp16 or args.bf16}") 76 | 77 | if args.seed is not None: 78 | set_seed(args.seed) 79 | 80 | # Initialize the model and tokenizer 81 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path) 82 | eos_token_id = tokenizer.eos_token_id 83 | 84 | model = SpeechLlamaForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype) 85 | model.infer_cfg = args.cfg 86 | model.initialize_codec("facebook/encodec_24khz") 87 | model.to(device) 88 | 89 | processor = AutoProcessor.from_pretrained("facebook/encodec_24khz") 90 | 91 | assert tokenizer.pad_token is not None 92 | logger.info(f"tokenizer pad token: {tokenizer.pad_token}") 93 | 94 | 95 | max_seq_length = getattr(model.config, "max_position_embeddings", 0) 96 | args.max_length = adjust_length_to_model(args.max_length, max_sequence_length=max_seq_length) 97 | logger.info(args) 98 | 99 | prompt_text = args.prompt_text if args.prompt_text else input("Prompt Text >>> ") 100 | prompt_audio = args.prompt_audio 101 | waveform, sample_rate = torchaudio.load(prompt_audio, normalize=True) 102 | if sample_rate != SAMPLING_RATE: 103 | resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=SAMPLING_RATE) 104 | waveform = resampler(waveform) 105 | waveform = waveform.mean(dim=0).numpy() # outside the if: downmix (c, t) to mono (a no-op for mono files) and give the processor a 1-D numpy array even when no resampling was needed 106 | 107 | input_text = args.input if args.input else input("Model input >>> ") 108 | #input_ids = tokenizer.encode(input_text + " " + prompt_text, return_tensors='pt').to(device) 109 | input_text = [prompt_text + " " + input_text] 110 | 111 | batch_encoded = tokenizer.batch_encode_plus( 112 | input_text, 113 | add_special_tokens=True, 114 | padding="longest", 115 | truncation=True, 116 | return_tensors="pt", 117 | ) 118 | 119 | input_ids = batch_encoded["input_ids"].to(device) 120 | attention_mask = batch_encoded["attention_mask"].to(device) 121 |
text_input_length = input_ids.shape[1] 122 | 123 | 124 | audio_arrays = [waveform] 125 | audio_inputs = processor(raw_audio=audio_arrays, sampling_rate=SAMPLING_RATE, return_tensors="pt") # 'padding_mask': b,t 'input_values': b,c,t 126 | 127 | 128 | encoder_outputs = model.codec.encode(audio_inputs["input_values"].to(device), audio_inputs["padding_mask"].to(device), bandwidth=BANDWIDTH) #1,b,r,t, 1 due to one chunk 129 | speech_inputs_embeds = model.codec.quantizer.decode(encoder_outputs.audio_codes[0].transpose(0, 1)) #b,d,t 130 | 131 | speech_attention_mask = audio_inputs["padding_mask"][..., ::STRIDE].to(device) 132 | assert speech_inputs_embeds.size(-1) == speech_attention_mask.size(-1) 133 | speech_inputs_embeds = speech_inputs_embeds.transpose(1,2).to(model.dtype) #b,t,d 134 | 135 | speech_input_length = speech_inputs_embeds.shape[1] 136 | 137 | 138 | new_attention_mask = torch.concat([attention_mask, speech_attention_mask], dim=1) 139 | 140 | 141 | output_sequences = model.generate( 142 | input_ids=input_ids, 143 | inputs_embeds=speech_inputs_embeds, 144 | attention_mask=new_attention_mask, 145 | max_length=args.max_length, 146 | do_sample=True, 147 | num_return_sequences=args.num_return_sequences, 148 | ) 149 | 150 | 151 | new_embeds = output_sequences[1] 152 | new_audio_values = model.codec.decoder(new_embeds.transpose(-1,-2).float()) 153 | 154 | 155 | output_path = "output.wav" 156 | torchaudio.save(output_path, new_audio_values[0][:, speech_input_length* STRIDE:].cpu(), SAMPLING_RATE) 157 | 158 | return 159 | 160 | 161 | if __name__ == "__main__": 162 | main() 163 | -------------------------------------------------------------------------------- /scripts/train_libriheavy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import json 5 | import logging 6 | import pathlib 7 | from pathlib import Path 8 | from dataclasses import dataclass, field, asdict 9 | from typing import 
Dict, Optional, Sequence, List, Union 10 | 11 | import numpy as np 12 | import torch 13 | import datasets 14 | import transformers 15 | 16 | import soundfile as sf 17 | import librosa 18 | 19 | 20 | 21 | from transformers import ( 22 | HfArgumentParser, 23 | set_seed, 24 | ) 25 | from transformers.testing_utils import CaptureLogger 26 | from transformers.trainer_utils import get_last_checkpoint 27 | from transformers.data.data_collator import pad_without_fast_tokenizer_warning 28 | 29 | from transformers import AutoProcessor, AutoTokenizer, RobertaTokenizer 30 | 31 | 32 | from sled.sled import SpeechLlamaConfig, SpeechLlamaForCausalLM 33 | from sled.trainer_libriheavy import SpeechLlamaTrainer 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | 39 | SAMPLING_RATE=24000 40 | SAMPLING_RATE_LIBRIHEAVY=16000 41 | SAMPLING_RATE_TOKENIZER=75 42 | 43 | 44 | @dataclass 45 | class ArchArguments: 46 | # -------------------------------------------------------------------------- 47 | # Llama Arguments 48 | hidden_size: int = 1024 49 | intermediate_size: int = 2752 50 | num_hidden_layers: int = 12 51 | num_attention_heads: int = 16 52 | num_key_value_heads: Optional[int] = None 53 | hidden_act: str = "silu" 54 | max_position_embeddings: int = 2048 55 | initializer_range: float = 0.02 56 | rms_norm_eps: float = 1e-6 57 | use_cache: bool = True 58 | pad_token_id: Optional[int] = None 59 | bos_token_id: int = 0 60 | eos_token_id: int = 2 61 | pretraining_tp: int = 1 62 | tie_word_embeddings: bool = False 63 | rope_theta: float = 10000.0 64 | rope_scaling: Optional[float] = None 65 | attention_bias: bool = False 66 | attention_dropout: float = 0.1 67 | mlp_bias: bool = False 68 | vocab_size: int = 32000 69 | dropout: float = 0.1 70 | activation_dropout: float = 0.1 71 | 72 | # -------------------------------------------------------------------------- 73 | # Score Arguments 74 | vae_embed_dim: int = 128 75 | diffloss_d: int = 3 76 | diffloss_w: int = 1024 77 | 
training_cfg: float = 0.0 78 | noise_channels: int = 128 79 | 80 | 81 | 82 | @dataclass 83 | class ModelArguments: 84 | # -------------------------------------------------------------------------- 85 | # Codec & Tokenizer Arguments 86 | codec: str = "facebook/encodec_24khz" 87 | tokenizer: str = "/path/tokenizer_bpe_libriheavy" 88 | 89 | 90 | 91 | @dataclass 92 | class DataArguments: 93 | data_path: str = "/path/libriheavy" 94 | train_manifest: List[str] = field(default_factory=lambda: ["/path/libriheavy/cases_and_punc/libriheavy_cuts_large.jsonl", "/path/libriheavy/cases_and_punc/libriheavy_cuts_medium.jsonl", "/path/libriheavy/cases_and_punc/libriheavy_cuts_small.jsonl"]) 95 | eval_manifest: List[str] = field(default_factory=lambda: ["/path/libriheavy/cases_and_punc/filtered2/libriheavy_cuts_dev.jsonl"]) 96 | pad_to_multiple_of: Optional[int] = None 97 | max_train_samples: Optional[int] = field( 98 | default=None, 99 | metadata={ 100 | "help": ( 101 | "For debugging purposes or quicker training, truncate the number of training examples to this " 102 | "value if set." 103 | ) 104 | }, 105 | ) 106 | max_eval_samples: Optional[int] = field( 107 | default=None, 108 | metadata={ 109 | "help": ( 110 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 111 | "value if set." 
112 | ) 113 | }, 114 | ) 115 | 116 | 117 | @dataclass 118 | class TrainingArguments(transformers.TrainingArguments): 119 | group_by_speech_length: bool = field(default=True) 120 | 121 | 122 | 123 | @dataclass 124 | class DataCollatorForSupervisedDataset(object): 125 | """Collate examples for supervised fine-tuning.""" 126 | 127 | tokenizer: transformers.PreTrainedTokenizer 128 | processor: transformers.PreTrainedTokenizer 129 | data_path: str 130 | pad_to_multiple_of: Optional[int] = None 131 | 132 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 133 | input_ids = [{"input_ids":instance["input_ids"]} for instance in instances] 134 | 135 | batch = pad_without_fast_tokenizer_warning( 136 | self.tokenizer, 137 | input_ids, 138 | padding=True, 139 | pad_to_multiple_of=self.pad_to_multiple_of, 140 | return_attention_mask=True, 141 | return_tensors="pt" 142 | ) 143 | 144 | audio_files = [instance["recording"]["sources"][0]["source"] for instance in instances] 145 | durations = [instance["duration"] for instance in instances] 146 | start_times = [instance["start"] for instance in instances] 147 | 148 | audio_arrays = [self.load_audio(file_path, start, duration) for file_path, start, duration in zip(audio_files, start_times, durations)] 149 | 150 | audio_inputs = self.processor(raw_audio=audio_arrays, sampling_rate=SAMPLING_RATE, return_tensors="pt") # 'padding_mask': b,t 'input_values': b,c,t 151 | 152 | batch["audio_inputs"] = audio_inputs 153 | 154 | return batch 155 | 156 | def load_audio(self, file_path: str, start: float, duration: float) -> np.ndarray: 157 | abs_path = Path(self.data_path) / file_path 158 | audio, sampling_rate = sf.read(abs_path, start=int(start * SAMPLING_RATE_LIBRIHEAVY), stop=int((start + duration) * SAMPLING_RATE_LIBRIHEAVY)) 159 | resampled_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=SAMPLING_RATE) 160 | return resampled_audio 161 | 162 | 163 | def load_manifest(file_paths): 164 | all_data = [] 165 |
for file_path in file_paths: 166 | with open(file_path, "r") as f: 167 | all_data.extend([json.loads(line) for line in f]) 168 | return all_data 169 | 170 | 171 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, 172 | arch_args, model_args, data_args, training_args) -> tuple: 173 | """Make the train/eval datasets and the data collator; returns (train_dataset, eval_dataset, data_collator).""" 174 | train_dataset = None 175 | eval_dataset = None 176 | 177 | if training_args.do_train: 178 | train_manifest = load_manifest(data_args.train_manifest) 179 | 180 | if data_args.max_train_samples is not None: 181 | train_manifest = train_manifest[:data_args.max_train_samples] 182 | 183 | train_dataset = train_manifest 184 | 185 | 186 | if training_args.do_eval: 187 | eval_manifest = load_manifest(data_args.eval_manifest) 188 | 189 | if data_args.max_eval_samples is not None: 190 | eval_manifest = eval_manifest[:data_args.max_eval_samples] 191 | 192 | eval_dataset = eval_manifest 193 | 194 | def tokenize_example(example): 195 | text = example["supervisions"][0]["text"] 196 | return tokenizer(text) 197 | 198 | tokenized_train_dataset = None 199 | tokenized_eval_dataset = None 200 | 201 | with training_args.main_process_first(desc="dataset map tokenization"): 202 | if training_args.do_train and train_dataset: 203 | tokenized_train_dataset = [ 204 | {**example, **tokenize_example(example)} for example in train_dataset 205 | ] 206 | 207 | if training_args.do_eval and eval_dataset: 208 | tokenized_eval_dataset = [ 209 | {**example, **tokenize_example(example)} for example in eval_dataset 210 | ] 211 | 212 | 213 | def filter_function(example): 214 | file_path = example["recording"]["sources"][0]["source"] 215 | abs_path = Path(data_args.data_path) / file_path 216 | exists = abs_path.exists() 217 | return ((len(example['input_ids']) + int(example["duration"] * SAMPLING_RATE_TOKENIZER)) < arch_args.max_position_embeddings) and exists 218 | 219 | if tokenized_train_dataset is not None: 220 |
logger.info(f"original train dataset: {len(tokenized_train_dataset)} samples.") 221 | tokenized_train_dataset = [ex for ex in tokenized_train_dataset if filter_function(ex)] 222 | logger.info(f"filtered train dataset: {len(tokenized_train_dataset)} samples.") 223 | 224 | 225 | if tokenized_eval_dataset is not None: 226 | logger.info(f"original eval dataset: {len(tokenized_eval_dataset)} samples.") 227 | tokenized_eval_dataset = [ex for ex in tokenized_eval_dataset if filter_function(ex)] 228 | logger.info(f"filtered eval dataset: {len(tokenized_eval_dataset)} samples.") 229 | 230 | processor = AutoProcessor.from_pretrained(model_args.codec) 231 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, processor=processor, pad_to_multiple_of=data_args.pad_to_multiple_of, data_path=data_args.data_path) 232 | 233 | return tokenized_train_dataset, tokenized_eval_dataset, data_collator 234 | 235 | 236 | def train(attn_implementation="sdpa"): 237 | 238 | parser = HfArgumentParser( 239 | (ArchArguments, ModelArguments, DataArguments, TrainingArguments)) 240 | arch_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses() 241 | 242 | 243 | # Setup logging 244 | logging.basicConfig( 245 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 246 | datefmt="%m/%d/%Y %H:%M:%S", 247 | handlers=[logging.StreamHandler(sys.stdout)], 248 | ) 249 | 250 | if training_args.should_log: 251 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
252 | transformers.utils.logging.set_verbosity_info() 253 | 254 | log_level = training_args.get_process_log_level() 255 | logger.setLevel(log_level) 256 | datasets.utils.logging.set_verbosity(log_level) 257 | transformers.utils.logging.set_verbosity(log_level) 258 | transformers.utils.logging.enable_default_handler() 259 | transformers.utils.logging.enable_explicit_format() 260 | 261 | # Log on each process the small summary: 262 | logger.warning( 263 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 264 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16 or training_args.bf16}" 265 | ) 266 | logger.info(f"Training/evaluation parameters {training_args}") 267 | 268 | 269 | # Detecting last checkpoint. 270 | last_checkpoint = None 271 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 272 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 273 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 274 | raise ValueError( 275 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 276 | "Use --overwrite_output_dir to overcome." 277 | ) 278 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 279 | logger.info( 280 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 281 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 282 | ) 283 | 284 | # Set seed before initializing model. 
285 | set_seed(training_args.seed) 286 | 287 | 288 | tokenizer = RobertaTokenizer.from_pretrained( 289 | model_args.tokenizer, 290 | padding_side="left", 291 | add_eos_token=True, 292 | ) 293 | arch_args.vocab_size = tokenizer.vocab_size 294 | model_config = SpeechLlamaConfig(**asdict(arch_args)) 295 | logger.info(f"config: {model_config}") 296 | 297 | 298 | torch_dtype = None #torch.bfloat16 if training_args.bf16 else None 299 | 300 | model = SpeechLlamaForCausalLM._from_config(model_config, attn_implementation=attn_implementation, torch_dtype=torch_dtype) 301 | model.initialize_codec(model_args) 302 | 303 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values()) 304 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") 305 | 306 | 307 | train_dataset, eval_dataset, data_collator = make_supervised_data_module(tokenizer, arch_args, model_args, data_args, training_args) 308 | trainer = SpeechLlamaTrainer( 309 | model=model, 310 | args=training_args, 311 | data_collator=data_collator, 312 | train_dataset=train_dataset if training_args.do_train else None, 313 | eval_dataset=eval_dataset if training_args.do_eval else None, 314 | tokenizer=tokenizer, 315 | ) 316 | 317 | # Training 318 | if training_args.do_train: 319 | checkpoint = None 320 | if training_args.resume_from_checkpoint is not None: 321 | checkpoint = training_args.resume_from_checkpoint 322 | elif last_checkpoint is not None: 323 | checkpoint = last_checkpoint 324 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 325 | trainer.save_model() # Saves the tokenizer too for easy upload 326 | 327 | metrics = train_result.metrics 328 | 329 | max_train_samples = ( 330 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 331 | ) 332 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 333 | 334 | trainer.log_metrics("train", metrics) 335 | trainer.save_metrics("train", metrics) 
336 | trainer.save_state() 337 | 338 | # Evaluation 339 | if training_args.do_eval: 340 | logger.info("*** Evaluate ***") 341 | 342 | metrics = trainer.evaluate() 343 | 344 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 345 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 346 | 347 | 348 | trainer.log_metrics("eval", metrics) 349 | trainer.save_metrics("eval", metrics) 350 | 351 | if __name__ == "__main__": 352 | train(attn_implementation="sdpa") 353 | -------------------------------------------------------------------------------- /scripts/train_libriheavy_stream.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import json 5 | import logging 6 | import pathlib 7 | from pathlib import Path 8 | from dataclasses import dataclass, field, asdict 9 | from typing import Dict, Optional, Sequence, List, Union 10 | 11 | import numpy as np 12 | import torch 13 | import datasets 14 | import transformers 15 | from safetensors.torch import load_file 16 | 17 | import soundfile as sf 18 | import librosa 19 | 20 | 21 | 22 | from transformers import ( 23 | HfArgumentParser, 24 | set_seed, 25 | ) 26 | from transformers.testing_utils import CaptureLogger 27 | from transformers.trainer_utils import get_last_checkpoint 28 | from transformers.data.data_collator import pad_without_fast_tokenizer_warning 29 | 30 | from transformers import AutoProcessor, AutoTokenizer, RobertaTokenizer 31 | 32 | 33 | from sled.sled_stream import SpeechLlamaConfig, SpeechLlamaForCausalLM 34 | from sled.trainer_libriheavy import SpeechLlamaTrainer 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | 40 | SAMPLING_RATE=24000 41 | SAMPLING_RATE_LIBRIHEAVY=16000 42 | SAMPLING_RATE_TOKENIZER=75 43 | 44 | 45 | @dataclass 46 | class ArchArguments: 47 | # -------------------------------------------------------------------------- 48 | # 
Llama Arguments 49 | hidden_size: int = 1024 50 | intermediate_size: int = 2752 51 | num_hidden_layers: int = 12 52 | num_attention_heads: int = 16 53 | num_key_value_heads: Optional[int] = None 54 | hidden_act: str = "silu" 55 | max_position_embeddings: int = 2048 56 | initializer_range: float = 0.02 57 | rms_norm_eps: float = 1e-6 58 | use_cache: bool = True 59 | pad_token_id: Optional[int] = None 60 | bos_token_id: int = 0 61 | eos_token_id: int = 2 62 | pretraining_tp: int = 1 63 | tie_word_embeddings: bool = False 64 | rope_theta: float = 10000.0 65 | rope_scaling: Optional[float] = None 66 | attention_bias: bool = False 67 | attention_dropout: float = 0.1 68 | mlp_bias: bool = False 69 | vocab_size: int = 32000 70 | dropout: float = 0.1 71 | activation_dropout: float = 0.1 72 | 73 | # -------------------------------------------------------------------------- 74 | # Score Arguments 75 | vae_embed_dim: int = 128 76 | diffloss_d: int = 3 77 | diffloss_w: int = 1024 78 | training_cfg: float = 0.0 79 | noise_channels: int = 128 80 | 81 | # -------------------------------------------------------------------------- 82 | # Stream Arguments 83 | stream_n: int = 5 84 | stream_m: int = 45 85 | 86 | 87 | 88 | @dataclass 89 | class ModelArguments: 90 | # -------------------------------------------------------------------------- 91 | # Codec & Tokenizer Arguments 92 | codec: str = "facebook/encodec_24khz" 93 | tokenizer: str = "/path/tokenizer_bpe_libriheavy" 94 | 95 | 96 | 97 | @dataclass 98 | class DataArguments: 99 | finetune_path: Optional[str] = None 100 | data_path: str = "/path/libriheavy" 101 | train_manifest: List[str] = field(default_factory=lambda: ["/path/libriheavy/cases_and_punc/libriheavy_cuts_large.jsonl", "/path/libriheavy/cases_and_punc/libriheavy_cuts_medium.jsonl", "/path/libriheavy/cases_and_punc/libriheavy_cuts_small.jsonl"]) 102 | eval_manifest: List[str] = field(default_factory=lambda: ["/path/libriheavy/cases_and_punc/filtered2/libriheavy_cuts_dev.jsonl"])
103 | pad_to_multiple_of: Optional[int] = None 104 | max_train_samples: Optional[int] = field( 105 | default=None, 106 | metadata={ 107 | "help": ( 108 | "For debugging purposes or quicker training, truncate the number of training examples to this " 109 | "value if set." 110 | ) 111 | }, 112 | ) 113 | max_eval_samples: Optional[int] = field( 114 | default=None, 115 | metadata={ 116 | "help": ( 117 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 118 | "value if set." 119 | ) 120 | }, 121 | ) 122 | 123 | 124 | @dataclass 125 | class TrainingArguments(transformers.TrainingArguments): 126 | group_by_speech_length: bool = field(default=True) 127 | 128 | 129 | 130 | @dataclass 131 | class DataCollatorForSupervisedDataset(object): 132 | """Collate examples for supervised fine-tuning.""" 133 | 134 | tokenizer: transformers.PreTrainedTokenizer 135 | processor: transformers.PreTrainedTokenizer 136 | data_path: str 137 | pad_to_multiple_of: Optional[int] = None 138 | 139 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 140 | input_ids = [{"input_ids":instance["input_ids"]} for instance in instances] 141 | 142 | batch = pad_without_fast_tokenizer_warning( 143 | self.tokenizer, 144 | input_ids, 145 | padding=True, 146 | pad_to_multiple_of=self.pad_to_multiple_of, 147 | return_attention_mask=True, 148 | return_tensors="pt" 149 | ) 150 | 151 | audio_files = [instance["recording"]["sources"][0]["source"] for instance in instances] 152 | durations = [instance["duration"] for instance in instances] 153 | start_times = [instance["start"] for instance in instances] 154 | 155 | audio_arrays = [self.load_audio(file_path, start, duration) for file_path, start, duration in zip(audio_files, start_times, durations)] 156 | 157 | audio_inputs = self.processor(raw_audio=audio_arrays, sampling_rate=SAMPLING_RATE, return_tensors="pt") # 'padding_mask': b,t 'input_values': b,c,t 158 | 159 | batch["audio_inputs"] = 
audio_inputs 160 | 161 | return batch 162 | 163 | def load_audio(self, file_path: str, start: float, duration: float) -> np.ndarray: 164 | abs_path = Path(self.data_path) / file_path 165 | audio, sampling_rate = sf.read(abs_path, start=int(start * SAMPLING_RATE_LIBRIHEAVY), stop=int((start + duration) * SAMPLING_RATE_LIBRIHEAVY)) 166 | resampled_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=SAMPLING_RATE) 167 | return resampled_audio 168 | 169 | 170 | def load_manifest(file_paths): 171 | all_data = [] 172 | for file_path in file_paths: 173 | with open(file_path, "r") as f: 174 | all_data.extend([json.loads(line) for line in f]) 175 | return all_data 176 | 177 | 178 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, 179 | arch_args, model_args, data_args, training_args) -> tuple: 180 | """Make dataset and collator for supervised fine-tuning.""" 181 | train_dataset = None 182 | eval_dataset = None 183 | 184 | if training_args.do_train: 185 | train_manifest = load_manifest(data_args.train_manifest) 186 | 187 | if data_args.max_train_samples is not None: 188 | train_manifest = train_manifest[:data_args.max_train_samples] 189 | 190 | train_dataset = train_manifest 191 | 192 | 193 | if training_args.do_eval: 194 | eval_manifest = load_manifest(data_args.eval_manifest) 195 | 196 | if data_args.max_eval_samples is not None: 197 | eval_manifest = eval_manifest[:data_args.max_eval_samples] 198 | 199 | eval_dataset = eval_manifest 200 | 201 | def tokenize_example(example): 202 | text = example["supervisions"][0]["text"] 203 | return tokenizer(text) 204 | 205 | tokenized_train_dataset = None 206 | tokenized_eval_dataset = None 207 | 208 | with training_args.main_process_first(desc="dataset map tokenization"): 209 | if training_args.do_train and train_dataset: 210 | tokenized_train_dataset = [ 211 | {**example, **tokenize_example(example)} for example in train_dataset 
212 | ] 213 | 214 | if training_args.do_eval and eval_dataset: 215 
| tokenized_eval_dataset = [ 216 | {**example, **tokenize_example(example)} for example in eval_dataset 217 | ] 218 | 219 | 220 | def filter_function(example): 221 | file_path = example["recording"]["sources"][0]["source"] 222 | abs_path = Path(data_args.data_path) / file_path 223 | exists = abs_path.exists() 224 | return ((len(example['input_ids']) + int(example["duration"] * SAMPLING_RATE_TOKENIZER)) < arch_args.max_position_embeddings) and exists 225 | 226 | if tokenized_train_dataset is not None: 227 | logger.info(f"original train dataset: {len(tokenized_train_dataset)} samples.") 228 | tokenized_train_dataset = [ex for ex in tokenized_train_dataset if filter_function(ex)] 229 | logger.info(f"filtered train dataset: {len(tokenized_train_dataset)} samples.") 230 | 231 | 232 | if tokenized_eval_dataset is not None: 233 | logger.info(f"original eval dataset: {len(tokenized_eval_dataset)} samples.") 234 | tokenized_eval_dataset = [ex for ex in tokenized_eval_dataset if filter_function(ex)] 235 | logger.info(f"filtered eval dataset: {len(tokenized_eval_dataset)} samples.") 236 | 237 | processor = AutoProcessor.from_pretrained(model_args.codec) 238 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, processor=processor, pad_to_multiple_of=data_args.pad_to_multiple_of, data_path=data_args.data_path) 239 | 240 | return tokenized_train_dataset, tokenized_eval_dataset, data_collator 241 | 242 | 243 | def train(attn_implementation="sdpa"): 244 | 245 | parser = HfArgumentParser( 246 | (ArchArguments, ModelArguments, DataArguments, TrainingArguments)) 247 | arch_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses() 248 | 249 | 250 | # Setup logging 251 | logging.basicConfig( 252 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 253 | datefmt="%m/%d/%Y %H:%M:%S", 254 | handlers=[logging.StreamHandler(sys.stdout)], 255 | ) 256 | 257 | if training_args.should_log: 258 | # The default of training_args.log_level is 
passive, so we set log level at info here to have that default. 259 | transformers.utils.logging.set_verbosity_info() 260 | 261 | log_level = training_args.get_process_log_level() 262 | logger.setLevel(log_level) 263 | datasets.utils.logging.set_verbosity(log_level) 264 | transformers.utils.logging.set_verbosity(log_level) 265 | transformers.utils.logging.enable_default_handler() 266 | transformers.utils.logging.enable_explicit_format() 267 | 268 | # Log on each process the small summary: 269 | logger.warning( 270 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 271 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16 or training_args.bf16}" 272 | ) 273 | logger.info(f"Training/evaluation parameters {training_args}") 274 | 275 | 276 | # Detecting last checkpoint. 277 | last_checkpoint = None 278 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 279 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 280 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 281 | raise ValueError( 282 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 283 | "Use --overwrite_output_dir to overcome." 284 | ) 285 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 286 | logger.info( 287 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 288 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 289 | ) 290 | 291 | # Set seed before initializing model. 
292 | set_seed(training_args.seed) 293 | 294 | 295 | tokenizer = RobertaTokenizer.from_pretrained( 296 | model_args.tokenizer, 297 | padding_side="right", 298 | add_eos_token=True, 299 | ) 300 | arch_args.vocab_size = tokenizer.vocab_size 301 | model_config = SpeechLlamaConfig(**asdict(arch_args)) 302 | logger.info(f"config: {model_config}") 303 | 304 | 305 | torch_dtype = None #torch.bfloat16 if training_args.bf16 else None 306 | 307 | model = SpeechLlamaForCausalLM._from_config(model_config, attn_implementation=attn_implementation, torch_dtype=torch_dtype) 308 | 309 | state_dict = load_file(data_args.finetune_path) 310 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) 311 | model.initialize_codec(model_args) 312 | 313 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values()) 314 | logger.info(f"Fine-tuning model initialized from {data_args.finetune_path} - Total size={n_params/2**20:.2f}M params") 315 | 316 | 317 | train_dataset, eval_dataset, data_collator = make_supervised_data_module(tokenizer, arch_args, model_args, data_args, training_args) 318 | trainer = SpeechLlamaTrainer( 319 | model=model, 320 | args=training_args, 321 | data_collator=data_collator, 322 | train_dataset=train_dataset if training_args.do_train else None, 323 | eval_dataset=eval_dataset if training_args.do_eval else None, 324 | tokenizer=tokenizer, 325 | ) 326 | 327 | # Training 328 | if training_args.do_train: 329 | checkpoint = None 330 | if training_args.resume_from_checkpoint is not None: 331 | checkpoint = training_args.resume_from_checkpoint 332 | elif last_checkpoint is not None: 333 | checkpoint = last_checkpoint 334 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 335 | trainer.save_model() # Saves the tokenizer too for easy upload 336 | 337 | metrics = train_result.metrics 338 | 339 | max_train_samples = ( 340 | data_args.max_train_samples if data_args.max_train_samples is not 
None else len(train_dataset) 341 | ) 342 | 
metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 343 | 344 | trainer.log_metrics("train", metrics) 345 | trainer.save_metrics("train", metrics) 346 | trainer.save_state() 347 | 348 | # Evaluation 349 | if training_args.do_eval: 350 | logger.info("*** Evaluate ***") 351 | 352 | metrics = trainer.evaluate() 353 | 354 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 355 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 356 | 357 | 358 | trainer.log_metrics("eval", metrics) 359 | trainer.save_metrics("eval", metrics) 360 | 361 | if __name__ == "__main__": 362 | train(attn_implementation="sdpa") 363 | -------------------------------------------------------------------------------- /shell_scripts/eval.sh: -------------------------------------------------------------------------------- 1 | 2 | CHECKPOINT=/path/to/checkpoint 3 | BSZ=1 4 | CFG=2.0 5 | 6 | python ./scripts/eval_stream.py \ 7 | --model_name_or_path ${CHECKPOINT} \ 8 | --batch_size ${BSZ} --cfg ${CFG} --seed 0 9 | -------------------------------------------------------------------------------- /shell_scripts/run.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | CHECKPOINT=/path/to/checkpoint 4 | CFG=2.0 5 | 6 | # Offline Inference 7 | python scripts/run_offline.py \ 8 | --model_name_or_path ${CHECKPOINT} \ 9 | --cfg ${CFG} \ 10 | --input "My remark pleases him, but I soon prove to him that it is not the right way to speak. However perfect may have been the language of that ancient writer." \ 11 | --seed 42 12 | 13 | # Or Streaming Inference 14 | python scripts/run_stream.py \ 15 | --model_name_or_path ${CHECKPOINT} \ 16 | --cfg ${CFG} \ 17 | --input "My remark pleases him, but I soon prove to him that it is not the right way to speak. However perfect may have been the language of that ancient writer." 
\ 18 | --seed 42 19 | 20 | # Or Voice Clone 21 | python scripts/run_voice_clone.py \ 22 | --prompt_text "Were I in the warm room with all the splendor and magnificence!" \ 23 | --prompt_audio "example_prompt.flac" \ 24 | --model_name_or_path ${CHECKPOINT} \ 25 | --cfg ${CFG} \ 26 | --input "Perhaps the other trees from the forest will come to look at me!" \ 27 | --seed 42 28 | -------------------------------------------------------------------------------- /shell_scripts/train_libriheavy.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=./runs/libriheavy 2 | mkdir -p $OUTPUT_DIR 3 | LOG_FILE=${OUTPUT_DIR}/log 4 | 5 | BATCH_SIZE=8 6 | UPDATE_FREQ=8 7 | # assume 8 proc per node, then WORLD_SIZE * 8 * BATCH_SIZE * UPDATE_FREQ == 512 8 | 9 | torchrun --nnodes ${WORLD_SIZE} --node_rank ${RANK} --nproc_per_node 8 --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \ 10 | ./scripts/train_libriheavy.py \ 11 | --training_cfg 0.1 \ 12 | --num_hidden_layers 12 --diffloss_d 6 --noise_channels 128 \ 13 | --dataloader_num_workers 8 \ 14 | --dataloader_pin_memory True \ 15 | --remove_unused_columns False \ 16 | --label_names audio_inputs \ 17 | --group_by_speech_length \ 18 | --do_train \ 19 | --do_eval \ 20 | --eval_strategy steps \ 21 | --eval_steps 10000 \ 22 | --prediction_loss_only \ 23 | --per_device_train_batch_size ${BATCH_SIZE} \ 24 | --per_device_eval_batch_size 24 \ 25 | --gradient_accumulation_steps ${UPDATE_FREQ} \ 26 | --bf16 \ 27 | --learning_rate 5e-4 \ 28 | --weight_decay 0.01 \ 29 | --adam_beta1 0.9 \ 30 | --adam_beta2 0.999 \ 31 | --adam_epsilon 1e-8 \ 32 | --max_grad_norm 1.0 \ 33 | --max_steps 300000 \ 34 | --lr_scheduler_type "linear" \ 35 | --warmup_steps 32000 \ 36 | --logging_first_step \ 37 | --logging_steps 100 \ 38 | --save_steps 10000 \ 39 | --save_total_limit 10 \ 40 | --output_dir ${OUTPUT_DIR} \ 41 | --report_to tensorboard \ 42 | --disable_tqdm True \ 43 | --ddp_timeout 3600 
--overwrite_output_dir \ 44 | 2>&1 |tee -a ${LOG_FILE} 45 | -------------------------------------------------------------------------------- /shell_scripts/train_libriheavy_stream.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_DIR=./runs/libriheavy_stream 2 | mkdir -p $OUTPUT_DIR 3 | LOG_FILE=${OUTPUT_DIR}/log 4 | 5 | BATCH_SIZE=8 6 | UPDATE_FREQ=8 7 | # assume 8 proc per node, then WORLD_SIZE * 8 * BATCH_SIZE * UPDATE_FREQ == 512 8 | 9 | torchrun --nnodes ${WORLD_SIZE} --node_rank ${RANK} --nproc_per_node 8 --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \ 10 | ./scripts/train_libriheavy_stream.py \ 11 | --finetune_path ./runs/libriheavy/checkpoint-300000/model.safetensors \ 12 | --stream_n 5 --stream_m 45 \ 13 | --training_cfg 0.1 \ 14 | --num_hidden_layers 12 --diffloss_d 6 --noise_channels 128 \ 15 | --dataloader_num_workers 8 \ 16 | --dataloader_pin_memory True \ 17 | --remove_unused_columns False \ 18 | --label_names audio_inputs \ 19 | --group_by_speech_length \ 20 | --do_train \ 21 | --do_eval \ 22 | --eval_strategy steps \ 23 | --eval_steps 10000 \ 24 | --prediction_loss_only \ 25 | --per_device_train_batch_size ${BATCH_SIZE} \ 26 | --per_device_eval_batch_size 24 \ 27 | --gradient_accumulation_steps ${UPDATE_FREQ} \ 28 | --bf16 \ 29 | --learning_rate 3e-4 \ 30 | --weight_decay 0.01 \ 31 | --adam_beta1 0.9 \ 32 | --adam_beta2 0.999 \ 33 | --adam_epsilon 1e-8 \ 34 | --max_grad_norm 1.0 \ 35 | --max_steps 100000 \ 36 | --lr_scheduler_type "linear" \ 37 | --warmup_steps 10000 \ 38 | --logging_first_step \ 39 | --logging_steps 100 \ 40 | --save_steps 10000 \ 41 | --save_total_limit 10 \ 42 | --output_dir ${OUTPUT_DIR} \ 43 | --report_to tensorboard \ 44 | --disable_tqdm True \ 45 | --ddp_timeout 3600 --overwrite_output_dir \ 46 | 2>&1 |tee -a ${LOG_FILE} 47 | -------------------------------------------------------------------------------- /sled.egg-info/PKG-INFO: 
-------------------------------------------------------------------------------- 1 | Metadata-Version: 2.4 2 | Name: sled 3 | Version: 0.1.0 4 | Summary: Implementation of SLED 5 | Author-email: Zhengrui Ma 6 | Requires-Dist: torch==2.5.1 7 | Requires-Dist: torchaudio==2.5.1 8 | Requires-Dist: transformers==4.47.0 9 | Requires-Dist: datasets==3.1.0 10 | Requires-Dist: accelerate==1.2.0 11 | Requires-Dist: numpy==1.26.4 12 | Requires-Dist: librosa==0.10.2.post1 13 | Requires-Dist: soundfile==0.12.1 14 | -------------------------------------------------------------------------------- /sled.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | pyproject.toml 3 | scripts/eval_continue.py 4 | scripts/eval_reference.py 5 | scripts/eval_stream.py 6 | scripts/run_offline.py 7 | scripts/run_stream.py 8 | scripts/train_libriheavy.py 9 | scripts/train_libriheavy_stream.py 10 | sled/__init__.py 11 | sled/energy_distance.py 12 | sled/modeling_llama_with_dropout.py 13 | sled/sled.py 14 | sled/sled_stream.py 15 | sled/trainer.py 16 | sled/trainer_libriheavy.py 17 | sled.egg-info/PKG-INFO 18 | sled.egg-info/SOURCES.txt 19 | sled.egg-info/dependency_links.txt 20 | sled.egg-info/requires.txt 21 | sled.egg-info/top_level.txt -------------------------------------------------------------------------------- /sled.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /sled.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | torch==2.5.1 2 | torchaudio==2.5.1 3 | transformers==4.47.0 4 | datasets==3.1.0 5 | accelerate==1.2.0 6 | numpy==1.26.4 7 | librosa==0.10.2.post1 8 | soundfile==0.12.1 9 | -------------------------------------------------------------------------------- /sled.egg-info/top_level.txt: 
-------------------------------------------------------------------------------- 1 | sled 2 | -------------------------------------------------------------------------------- /sled/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ictnlp/SLED-TTS/b8ed10d9953160efd8a0538b4ea5af80a57c9e96/sled/__init__.py -------------------------------------------------------------------------------- /sled/energy_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | import math 5 | import pdb 6 | 7 | 8 | class ScoreLossZ(nn.Module): 9 | """Energy Distance Loss""" 10 | def __init__(self, target_channels, z_channels, depth, width, beta=1, gamma=1, noise_channels=16): 11 | super(ScoreLossZ, self).__init__() 12 | self.noise_channels = noise_channels 13 | self.net = SimpleMLPAdaLN( 14 | in_channels=self.noise_channels, 15 | model_channels=width, 16 | out_channels=target_channels, 17 | z_channels=z_channels, 18 | num_res_blocks=depth, 19 | ) 20 | self.beta = beta 21 | self.gamma = gamma 22 | 23 | 24 | def forward(self, target, z, mask=None, additional_targets=None): 25 | noise_1 = torch.randn((z.shape[0], self.noise_channels), dtype=z.dtype, device=z.device) 26 | sample_1 = self.net(noise_1, z) 27 | noise_2 = torch.randn((z.shape[0], self.noise_channels), dtype=z.dtype, device=z.device) 28 | sample_2 = self.net(noise_2, z) 29 | 30 | score = self.energy_score(sample_1, sample_2, target, additional_targets) 31 | loss = - score 32 | if mask is not None: 33 | loss = (loss * mask).sum() / mask.sum() 34 | return loss.mean() 35 | 36 | def energy_distance(self, x_1, x_2): 37 | return torch.pow(torch.linalg.norm(x_1 - x_2, ord=2, dim=-1), self.beta) 38 | 39 | def energy_score(self, sample_1, sample_2, target, additional_targets = None): 40 | distance_1 = self.energy_distance(sample_1, 
target) 41 | distance_2 = self.energy_distance(sample_2, target) 42 | variance = self.energy_distance(sample_1, sample_2) 43 | score = variance - distance_1 - distance_2 44 | return score 45 | 46 | def kernel_distance(self, x_1, x_2): 47 | return - torch.exp(- torch.sum(torch.pow(x_1 - x_2, 2), dim = -1).div(2 * self.gamma**2)) 48 | 49 | def kernel_score(self, sample_1, sample_2, target): 50 | distance_1 = self.kernel_distance(sample_1, target) 51 | distance_2 = self.kernel_distance(sample_2, target) 52 | variance = self.kernel_distance(sample_1, sample_2) 53 | score = variance - distance_1 - distance_2 54 | return score 55 | 56 | def sample(self, z, temperature=1.0, cfg=1.0): 57 | if cfg != 1.0: 58 | z_1, z_2 = z.chunk(2, dim=0) 59 | z = z_1 * cfg + (1 - cfg) * z_2 60 | 61 | noise = torch.randn((z.shape[0], self.noise_channels), dtype=z.dtype, device=z.device) 62 | return self.net(noise, z) 63 | 64 | 65 | 66 | def modulate(x, shift, scale): 67 | return x * (1 + scale) + shift 68 | 69 | class ResBlock(nn.Module): 70 | """ 71 | An adaLN-modulated residual MLP block; the number of channels is unchanged. 72 | :param channels: the number of input channels.
73 | """ 74 | 75 | def __init__( 76 | self, 77 | channels 78 | ): 79 | super().__init__() 80 | self.channels = channels 81 | 82 | self.in_ln = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=False) 83 | self.mlp = nn.Sequential( 84 | nn.Linear(channels, channels, bias=True), 85 | nn.SiLU(), 86 | nn.Linear(channels, channels, bias=True), 87 | ) 88 | self.noise_ln = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=True) 89 | 90 | self.adaLN_modulation = nn.Sequential( 91 | nn.SiLU(), 92 | nn.Linear(channels, 3 * channels, bias=True) 93 | ) 94 | 95 | def forward(self, x, y): 96 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(self.noise_ln(y)).chunk(3, dim=-1) 97 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp) 98 | h = self.mlp(h) 99 | return x + gate_mlp * h 100 | 101 | 102 | class FinalLayer(nn.Module): 103 | """ 104 | The final adaLN-modulated projection layer, adapted from DiT. 105 | """ 106 | def __init__(self, model_channels, out_channels): 107 | super().__init__() 108 | self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) 109 | self.linear = nn.Linear(model_channels, out_channels, bias=True) 110 | self.adaLN_modulation = nn.Sequential( 111 | nn.SiLU(), 112 | nn.Linear(model_channels, 2 * model_channels, bias=True) 113 | ) 114 | 115 | def forward(self, x, c): 116 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 117 | x = modulate(self.norm_final(x), shift, scale) 118 | x = self.linear(x) 119 | return x 120 | 121 | 122 | class SimpleMLPAdaLN(nn.Module): 123 | """ 124 | The MLP for Energy Distance Loss. 125 | :param in_channels: channels in the input noise Tensor. 126 | :param model_channels: base channel count for the model. 127 | :param out_channels: channels in the output Tensor. 128 | :param z_channels: channels in the condition. 129 | :param num_res_blocks: number of residual blocks.
130 | """ 131 | 132 | def __init__( 133 | self, 134 | in_channels, 135 | model_channels, 136 | out_channels, 137 | z_channels, 138 | num_res_blocks, 139 | grad_checkpointing=False 140 | ): 141 | super().__init__() 142 | 143 | self.in_channels = in_channels 144 | self.model_channels = model_channels 145 | self.out_channels = out_channels 146 | self.num_res_blocks = num_res_blocks 147 | self.grad_checkpointing = grad_checkpointing 148 | 149 | self.cond_embed = nn.Linear(z_channels, model_channels) 150 | 151 | self.input_proj = nn.Linear(in_channels, model_channels) 152 | 153 | res_blocks = [] 154 | for i in range(num_res_blocks): 155 | res_blocks.append(ResBlock( 156 | model_channels, 157 | )) 158 | 159 | self.res_blocks = nn.ModuleList(res_blocks) 160 | self.final_layer = FinalLayer(model_channels, out_channels) 161 | 162 | self.initialize_weights() 163 | 164 | def initialize_weights(self): 165 | def _basic_init(module): 166 | if isinstance(module, nn.Linear): 167 | torch.nn.init.xavier_uniform_(module.weight) 168 | if module.bias is not None: 169 | nn.init.constant_(module.bias, 0) 170 | module._is_hf_initialized = True 171 | self.apply(_basic_init) 172 | 173 | # Zero-out adaLN modulation layers 174 | for block in self.res_blocks: 175 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 176 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 177 | 178 | # Zero-out output layers 179 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 180 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 181 | nn.init.constant_(self.final_layer.linear.weight, 0) 182 | nn.init.constant_(self.final_layer.linear.bias, 0) 183 | 184 | def forward(self, x, c): 185 | """ 186 | Map a batch of noise vectors to samples, conditioned on c. 187 | :param x: an [N x in_channels] Tensor of input noise. 188 | :param c: an [N x z_channels] conditioning Tensor from the AR transformer. 189 | :return: an [N x out_channels] Tensor of 
samples. 190 | The projected noise drives the adaLN modulation; the embedded condition forms the residual stream.
191 | """ 192 | y = self.input_proj(x) 193 | x = self.cond_embed(c) 194 | 195 | 196 | 197 | if self.grad_checkpointing and not torch.jit.is_scripting(): 198 | raise NotImplementedError 199 | else: 200 | for block in self.res_blocks: 201 | x = block(x, y) 202 | 203 | return self.final_layer(x, y) 204 | 205 | def forward_with_cfg(self, x, c, cfg_scale): 206 | half = x[: len(x) // 2] 207 | combined = torch.cat([half, half], dim=0) 208 | model_out = self.forward(combined, c) 209 | eps, rest = model_out[:, :self.out_channels], model_out[:, self.out_channels:] 210 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 211 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 212 | eps = torch.cat([half_eps, half_eps], dim=0) 213 | return torch.cat([eps, rest], dim=1) 214 | -------------------------------------------------------------------------------- /sled/sled.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union, Dict, Any 2 | from pathlib import Path 3 | import torch 4 | import torch.nn as nn 5 | 6 | from transformers import EncodecModel 7 | from .modeling_llama_with_dropout import LlamaConfig, LlamaModel, LlamaForCausalLM 8 | from .modeling_llama_with_dropout import _prepare_4d_causal_attention_mask_with_cache_position 9 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 10 | from transformers.modeling_outputs import CausalLMOutputWithPast 11 | from transformers.generation.utils import GenerateDecoderOnlyOutput 12 | from transformers.utils.import_utils import is_torchdynamo_compiling 13 | 14 | from .energy_distance import ScoreLossZ 15 | import copy 16 | 17 | 18 | class ModifiedGenerateDecoderOnlyOutput(GenerateDecoderOnlyOutput): 19 | features = None 20 | 21 | 22 | 23 | class SpeechLlamaConfig(LlamaConfig): 24 | model_type = "speech_llama" 25 | 26 | def __init__( 27 | self, 28 | vae_embed_dim=128, 29 | diffloss_d=3, 30 
| diffloss_w=1024, 31 |
| training_cfg=0.0, 32 | noise_channels=128, 33 | **kwargs, 34 | ): 35 | self.vae_embed_dim = vae_embed_dim 36 | self.diffloss_d = diffloss_d 37 | self.diffloss_w = diffloss_w 38 | self.training_cfg = training_cfg 39 | self.noise_channels = noise_channels 40 | 41 | 42 | super().__init__(**kwargs) 43 | 44 | 45 | 46 | 47 | 48 | class SpeechLlamaForCausalLM(LlamaForCausalLM): 49 | config_class = SpeechLlamaConfig 50 | 51 | def __init__(self, config): 52 | super(LlamaForCausalLM, self).__init__(config) 53 | self.model = LlamaModel(config) 54 | self.codec = None 55 | 56 | # -------------------------------------------------------------------------- 57 | # Speech Embedding 58 | self.token_embed_dim = config.vae_embed_dim 59 | self.hidden_size = config.hidden_size 60 | self.z_proj = nn.Linear(self.token_embed_dim, self.hidden_size, bias=True) 61 | self.z_proj_ln = nn.LayerNorm(self.hidden_size, eps=1e-6) 62 | self.eos_head = nn.Linear(self.hidden_size, 1, bias=False) 63 | self.embed_mean = None 64 | self.embed_std = None 65 | self.training_cfg = config.training_cfg 66 | 67 | # -------------------------------------------------------------------------- 68 | # Score Loss 69 | self.scoreloss = ScoreLossZ( 70 | target_channels=self.token_embed_dim, 71 | z_channels=self.hidden_size, 72 | width=config.diffloss_w, 73 | depth=config.diffloss_d, 74 | noise_channels=config.noise_channels, 75 | ) 76 | 77 | # -------------------------------------------------------------------------- 78 | # BCE Loss 79 | pos_weight = torch.Tensor([100.]) # Weight of EOS is equal to 100 80 | self.bceloss = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight) 81 | 82 | 83 | # Initialize weights and apply final processing 84 | self.post_init() # Check whether affects init of LlamaModel 85 | 86 | def initialize_codec(self, model_args): 87 | if hasattr(model_args, "codec"): 88 | self.codec = EncodecModel.from_pretrained(model_args.codec, torch_dtype=torch.float32) # keep encodec model in fp32 
89 | else: 90 | self.codec = EncodecModel.from_pretrained(model_args, torch_dtype=torch.float32) 91 | for param in self.codec.parameters(): 92 | param.requires_grad = False 93 | 94 | 95 | 96 | def _init_weights(self, m): 97 | if isinstance(m, nn.Linear): 98 | # we use xavier_uniform following official JAX ViT: 99 | torch.nn.init.xavier_uniform_(m.weight) 100 | if isinstance(m, nn.Linear) and m.bias is not None: 101 | nn.init.constant_(m.bias, 0) 102 | elif isinstance(m, nn.LayerNorm): 103 | if m.bias is not None: 104 | nn.init.constant_(m.bias, 0) 105 | if m.weight is not None: 106 | nn.init.constant_(m.weight, 1.0) 107 | 108 | 109 | def prepare_inputs_labels_for_multimodal( 110 | self, 111 | input_ids, 112 | position_ids, 113 | attention_mask, 114 | past_key_values, 115 | audio_inputs, 116 | ): 117 | if self.training_cfg > 0.0: 118 | bsz = attention_mask.size(0) 119 | random_mask = torch.rand(bsz) < self.training_cfg 120 | cfg_mask = torch.zeros_like(attention_mask, dtype=torch.bool) 121 | cfg_mask[:, :-1] = random_mask[:, None] 122 | attention_mask[cfg_mask] = 0 123 | 124 | text_inputs_embeds = self.model.embed_tokens(input_ids) 125 | 126 | with torch.no_grad(): 127 | encoder_outputs = self.codec.encode(audio_inputs["input_values"], audio_inputs["padding_mask"], bandwidth=6) #1,b,r,t, 1 due to one chunk 128 | speech_inputs_embeds = self.codec.quantizer.decode(encoder_outputs.audio_codes[0].transpose(0, 1)) #b,d,t, always fp32 129 | 130 | speech_attention_mask = audio_inputs["padding_mask"][..., ::320] 131 | assert speech_inputs_embeds.size(-1) == speech_attention_mask.size(-1) 132 | speech_inputs_embeds = speech_inputs_embeds.transpose(1,2).to(self.model.dtype) #b,t,d, support full bf16 training 133 | 134 | net_speech_inputs_embeds = self.z_proj(speech_inputs_embeds) 135 | new_inputs_embeds = torch.concat([text_inputs_embeds, net_speech_inputs_embeds], dim=1) #bsz, seq_len, hidden_size 136 | new_attention_mask = torch.concat([attention_mask, 
speech_attention_mask], dim=1) 137 | new_labels = speech_inputs_embeds 138 | 139 | return None, position_ids, new_attention_mask, past_key_values, new_inputs_embeds, new_labels, speech_attention_mask 140 | 141 | def forward( 142 | self, 143 | input_ids: torch.LongTensor = None, 144 | attention_mask: Optional[torch.Tensor] = None, 145 | position_ids: Optional[torch.LongTensor] = None, 146 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 147 | inputs_embeds: Optional[torch.FloatTensor] = None, 148 | labels: Optional[torch.LongTensor] = None, 149 | use_cache: Optional[bool] = None, 150 | output_attentions: Optional[bool] = None, 151 | output_hidden_states: Optional[bool] = None, 152 | return_dict: Optional[bool] = None, 153 | cache_position: Optional[torch.LongTensor] = None, 154 | num_logits_to_keep: int = 0, 155 | audio_inputs: Optional[Dict[str, Any]] = None, 156 | speech_inputs_embeds: Optional[torch.FloatTensor] = None, 157 | speech_attention_mask: Optional[torch.Tensor] = None, 158 | ) -> Union[Tuple, CausalLMOutputWithPast]: 159 | 160 | 161 | if audio_inputs is not None: 162 | ( 163 | _, 164 | position_ids, 165 | attention_mask, 166 | past_key_values, 167 | inputs_embeds, 168 | labels, 169 | speech_attention_mask, 170 | ) = self.prepare_inputs_labels_for_multimodal( 171 | input_ids, 172 | position_ids, 173 | attention_mask, 174 | past_key_values, 175 | audio_inputs, 176 | ) 177 | else: 178 | assert not ((input_ids is None) and (inputs_embeds is None)) 179 | if input_ids is not None and inputs_embeds is None: 180 | inputs_embeds = self.model.embed_tokens(input_ids) 181 | elif input_ids is None and inputs_embeds is not None: 182 | inputs_embeds = self.z_proj(inputs_embeds) 183 | else: 184 | inputs_embeds = torch.cat([self.model.embed_tokens(input_ids), self.z_proj(inputs_embeds)], dim=1) 185 | 186 | inputs_embeds = self.z_proj_ln(inputs_embeds) 187 | output_attentions = output_attentions if output_attentions is not None else 
self.config.output_attentions 188 | output_hidden_states = ( 189 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 190 | ) 191 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 192 | 193 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 194 | outputs = self.model( 195 | input_ids=None, 196 | attention_mask=attention_mask, 197 | position_ids=position_ids, 198 | past_key_values=past_key_values, 199 | inputs_embeds=inputs_embeds, 200 | use_cache=use_cache, 201 | output_attentions=output_attentions, 202 | output_hidden_states=output_hidden_states, 203 | return_dict=return_dict, 204 | cache_position=cache_position, 205 | ) 206 | 207 | hidden_states = outputs[0] 208 | logits = hidden_states[:, -num_logits_to_keep:, :] 209 | 210 | loss = None 211 | 212 | if labels is not None: 213 | bsz, speech_len, _ = labels.shape 214 | 215 | z = logits[:, -(speech_len+1):-1, :] #bsz, speech_len, hid_dim 216 | 217 | labels = labels.reshape(bsz * speech_len, -1) 218 | z = z.reshape(bsz * speech_len, -1) 219 | mask = speech_attention_mask.reshape(bsz * speech_len) 220 | loss = self.scoreloss(z=z, target=labels, mask=mask) 221 | 222 | eos_score = self.eos_head(logits[:, -(speech_len+1):, :]).squeeze(-1).float() #bsz, speech_len+1 223 | 224 | non_pad_counts = speech_attention_mask.sum(dim=1) 225 | eos_labels = torch.zeros(bsz, speech_len + 1) 226 | eos_labels[torch.arange(bsz), non_pad_counts] = 1 #bsz, speech_len+1 227 | eos_labels = eos_labels.to(eos_score.device) 228 | 229 | eos_loss = self.bceloss(eos_score, eos_labels) #bsz, speech_len+1 230 | 231 | #Check BCE loss weight BROADCASTING 232 | ones_column = torch.ones(bsz, 1).to(speech_attention_mask.device) 233 | loss_mask = torch.cat((ones_column, speech_attention_mask), dim=1) #bsz, speech_len+1 234 | 235 | eos_loss = (eos_loss * loss_mask).sum() / loss_mask.sum() 236 | 237 | loss = eos_loss + loss 238 | 239 | if not 
return_dict: 240 | output = (logits,) + outputs[1:] 241 | return (loss,) + output if loss is not None else output 242 | 243 | return CausalLMOutputWithPast( 244 | loss=loss, 245 | logits=logits, 246 | past_key_values=outputs.past_key_values, 247 | hidden_states=outputs.hidden_states, 248 | attentions=outputs.attentions, 249 | ) 250 | 251 | # Check 252 | def state_dict(self, *args, **kwargs): 253 | state_dict = super().state_dict(*args, **kwargs) 254 | codec_keys = [k for k in state_dict if 'codec' in k] 255 | for key in codec_keys: 256 | del state_dict[key] 257 | return state_dict 258 | 259 | # Check 260 | def load_state_dict(self, state_dict, strict=True): 261 | codec_keys = [k for k in state_dict if 'codec' in k] 262 | for key in codec_keys: 263 | del state_dict[key] 264 | return super().load_state_dict(state_dict, strict=False) 265 | 266 | 267 | def prepare_inputs_for_generation( 268 | self, 269 | input_ids, 270 | inputs_embeds=None, 271 | past_key_values=None, 272 | attention_mask=None, 273 | cache_position=None, 274 | position_ids=None, 275 | use_cache=True, 276 | num_logits_to_keep=None, 277 | **kwargs, 278 | ): 279 | if cache_position[0] == 0: 280 | if attention_mask is not None and position_ids is None: 281 | # create position_ids on the fly for batch generation 282 | position_ids = attention_mask.long().cumsum(-1) - 1 283 | position_ids.masked_fill_(attention_mask == 0, 1) 284 | 285 | if inputs_embeds is not None: 286 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": inputs_embeds} 287 | else: 288 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} 289 | 290 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: 291 | raise NotImplementedError 292 | 293 | if num_logits_to_keep is not None: 294 | model_inputs["num_logits_to_keep"] = num_logits_to_keep 295 | 296 | model_inputs.update( 297 | { 298 | "position_ids": position_ids, 
299 | "cache_position": cache_position, 300 | "past_key_values": past_key_values, 301 | "use_cache": use_cache, 302 | "attention_mask": attention_mask, 303 | } 304 | ) 305 | else: 306 | if past_key_values is not None: 307 | inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :, :] 308 | 309 | if attention_mask is not None and position_ids is None: 310 | # create position_ids on the fly for batch generation 311 | position_ids = attention_mask.long().cumsum(-1) - 1 312 | position_ids.masked_fill_(attention_mask == 0, 1) 313 | if past_key_values: 314 | position_ids = position_ids[:, -inputs_embeds.shape[1] :] 315 | position_ids = position_ids.clone(memory_format=torch.contiguous_format) 316 | 317 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} 318 | 319 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: 320 | raise NotImplementedError 321 | 322 | if num_logits_to_keep is not None: 323 | model_inputs["num_logits_to_keep"] = num_logits_to_keep 324 | 325 | model_inputs.update( 326 | { 327 | "position_ids": position_ids, 328 | "cache_position": cache_position, 329 | "past_key_values": past_key_values, 330 | "use_cache": use_cache, 331 | "attention_mask": attention_mask, 332 | } 333 | ) 334 | 335 | return model_inputs 336 | 337 | 338 | def prepare_inputs_for_generation_cfg( 339 | self, 340 | input_ids, 341 | inputs_embeds=None, 342 | past_key_values=None, 343 | attention_mask=None, 344 | cache_position=None, 345 | position_ids=None, 346 | use_cache=True, 347 | num_logits_to_keep=None, 348 | **kwargs, 349 | ): 350 | attention_mask = attention_mask.clone() 351 | attention_mask[:, :self.prompt_length-1] = 0 352 | 353 | if cache_position[0] == 0: 354 | if attention_mask is not None and position_ids is None: 355 | # create position_ids on the fly for batch generation 356 | position_ids = attention_mask.long().cumsum(-1) - 1 357 | position_ids.masked_fill_(attention_mask == 0, 1) 358 | 359 | if inputs_embeds is not None: 360 | 
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": inputs_embeds} 361 | else: 362 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} 363 | 364 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: 365 | raise NotImplementedError 366 | 367 | if num_logits_to_keep is not None: 368 | model_inputs["num_logits_to_keep"] = num_logits_to_keep 369 | 370 | model_inputs.update( 371 | { 372 | "position_ids": position_ids, 373 | "cache_position": cache_position, 374 | "past_key_values": copy.deepcopy(past_key_values), 375 | "use_cache": use_cache, 376 | "attention_mask": attention_mask, 377 | } 378 | ) 379 | else: 380 | if past_key_values is not None: 381 | inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :, :] 382 | 383 | if attention_mask is not None and position_ids is None: 384 | # create position_ids on the fly for batch generation 385 | position_ids = attention_mask.long().cumsum(-1) - 1 386 | position_ids.masked_fill_(attention_mask == 0, 1) 387 | if past_key_values: 388 | position_ids = position_ids[:, -inputs_embeds.shape[1] :] 389 | position_ids = position_ids.clone(memory_format=torch.contiguous_format) 390 | 391 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} 392 | 393 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: 394 | raise NotImplementedError 395 | 396 | if num_logits_to_keep is not None: 397 | model_inputs["num_logits_to_keep"] = num_logits_to_keep 398 | 399 | model_inputs.update( 400 | { 401 | "position_ids": position_ids, 402 | "cache_position": cache_position, 403 | "past_key_values":copy.deepcopy(past_key_values), 404 | "use_cache": use_cache, 405 | "attention_mask": attention_mask, 406 | } 407 | ) 408 | 409 | return model_inputs 410 | 411 | 412 | 413 | def _get_initial_cache_position(self, input_ids, model_kwargs): 414 | """Calculates `cache_position` for the pre-fill 
stage based on `input_ids` and optionally past length""" 415 | # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` 416 | if "inputs_embeds" in model_kwargs: 417 | #cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 418 | cache_position = (torch.ones(model_kwargs["inputs_embeds"].shape[1] + input_ids.shape[1], dtype=torch.int64).cumsum(0) - 1).to(model_kwargs["inputs_embeds"].device) 419 | else: 420 | cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 421 | 422 | past_length = 0 423 | if model_kwargs.get("past_key_values") is not None: 424 | cache = model_kwargs["past_key_values"] 425 | past_length = 0 426 | if not isinstance(cache, Cache): 427 | past_length = cache[0][0].shape[2] 428 | elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: 429 | past_length = cache.get_seq_length() 430 | 431 | # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty, 432 | # end-to-end compilation will yield bad results because `cache_position` will be incorrect. 433 | if not is_torchdynamo_compiling(): 434 | cache_position = cache_position[past_length:] 435 | 436 | model_kwargs["cache_position"] = cache_position 437 | return model_kwargs 438 | 439 | 440 | def _sample( 441 | self, 442 | input_ids, 443 | logits_processor, 444 | stopping_criteria, 445 | generation_config, 446 | synced_gpus, 447 | streamer, 448 | **model_kwargs, 449 | ): 450 | r""" 451 | Parameters: 452 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 453 | The sequence used as a prompt for the generation. 454 | logits_processor (`LogitsProcessorList`): 455 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] 456 | used to modify the prediction scores of the language modeling head applied at each generation step. 
457 | stopping_criteria (`StoppingCriteriaList`): 458 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] 459 | used to tell if the generation loop should stop. 460 | generation_config ([`~generation.GenerationConfig`]): 461 | The generation configuration to be used as parametrization of the decoding method. 462 | synced_gpus (`bool`): 463 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3) 464 | streamer (`BaseStreamer`, *optional*): 465 | Accepted for compatibility with the `transformers` generation API; this loop never calls 466 | `streamer.put`, so generated features are not streamed through it. 467 | model_kwargs: 468 | Additional model-specific kwargs forwarded to `prepare_inputs_for_generation` and the model's `forward` 469 | (e.g. `attention_mask`, and optionally `inputs_embeds` holding prompt speech features). 470 | 471 | Return: 472 | `Tuple[torch.LongTensor, torch.FloatTensor]` by default, or [`ModifiedGenerateDecoderOnlyOutput`]: 473 | By default, a tuple `(input_ids, inputs_embeds)`: the prompt ids extended with one control token per step 474 | (`bos_token_id` to continue, `eos_token_id` to stop, `pad_token_id` once finished) and the speech features 475 | (any prompt features followed by the generated ones). With classifier-free guidance both are truncated to the 476 | real batch. `return_dict_in_generate=True` currently raises `NotImplementedError` inside the sampling loop.
477 | """ 478 | # init values 479 | pad_token_id = generation_config._pad_token_tensor 480 | bos_token_id = generation_config._bos_token_tensor 481 | eos_token_id = generation_config._eos_token_tensor 482 | output_attentions = generation_config.output_attentions 483 | output_hidden_states = generation_config.output_hidden_states 484 | output_scores = generation_config.output_scores 485 | output_logits = generation_config.output_logits 486 | return_dict_in_generate = generation_config.return_dict_in_generate 487 | max_length = generation_config.max_length 488 | has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) 489 | do_sample = generation_config.do_sample 490 | 491 | # init attention / hidden states / scores tuples 492 | scores = () if (return_dict_in_generate and output_scores) else None 493 | raw_logits = () if (return_dict_in_generate and output_logits) else None 494 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 495 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 496 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 497 | 498 | 499 | inputs_embeds = model_kwargs.get("inputs_embeds", None) 500 | 501 | if self.infer_cfg != 1.0: 502 | real_batch_size = input_ids.shape[0] 503 | if inputs_embeds is not None: 504 | inputs_embeds = inputs_embeds.repeat(2, 1, 1) # expand 505 | input_ids = input_ids.repeat(2, 1) # expand 506 | self.prompt_length = input_ids.shape[1] 507 | extended_attention_mask = model_kwargs["attention_mask"].clone() 508 | extended_attention_mask[:, :self.prompt_length-1] = 0 509 | model_kwargs["attention_mask"] = torch.cat([model_kwargs["attention_mask"], extended_attention_mask], dim=0) # expand 510 | 511 | 512 | # keep track of which sequences are already finished 513 | batch_size, cur_len = input_ids.shape if inputs_embeds is None else (input_ids.shape[0], input_ids.shape[1] + 
inputs_embeds.shape[1]) 514 | this_peer_finished = False 515 | unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) 516 | model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) 517 | model_kwargs.pop("inputs_embeds", None) 518 | 519 | 520 | while self._has_unfinished_sequences( 521 | this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length 522 | ): 523 | # prepare model inputs 524 | model_inputs = self.prepare_inputs_for_generation(input_ids, inputs_embeds, **model_kwargs) 525 | 526 | # prepare variable output controls (note: some models won't accept all output controls) 527 | model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) 528 | model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) 529 | 530 | # forward pass to get next token 531 | outputs = self(**model_inputs, return_dict=True) 532 | 533 | 534 | if synced_gpus and this_peer_finished: 535 | continue # don't waste resources running the code we don't need 536 | 537 | # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration 538 | # (the clone itself is always small) 539 | next_token_logits = outputs.logits.clone()[:, -1, :] 540 | eos_next_token_logits = next_token_logits 541 | if self.infer_cfg != 1.0: 542 | next_token_logits_normal = next_token_logits[:real_batch_size, :] 543 | next_token_logits_cfg = next_token_logits[real_batch_size:, :] 544 | next_token_logits = next_token_logits_cfg + self.infer_cfg * (next_token_logits_normal - next_token_logits_cfg) 545 | eos_next_token_logits = eos_next_token_logits[:real_batch_size, :] 546 | # Store scores, attentions and hidden_states when required 547 | if return_dict_in_generate: 548 | raise NotImplementedError 549 | 550 | 551 | # token selection 552 | if do_sample: 553 | next_embeds = self.scoreloss.sample(next_token_logits, 
temperature=1.0) # bsz, dim 554 | 555 | next_actions = torch.sigmoid(self.eos_head(eos_next_token_logits)) >= 0.8 # 0: continue, 1: stop 556 | next_tokens = torch.where(next_actions == 0, bos_token_id, eos_token_id) 557 | 558 | if self.infer_cfg != 1.0: 559 | # expand back to the doubled (conditional + unconditional) CFG batch 560 | next_embeds = next_embeds.repeat(2, 1) 561 | next_tokens = next_tokens.repeat(2, 1) 562 | 563 | 564 | else: 565 | raise NotImplementedError 566 | 567 | 568 | # finished sentences should have their next token be a padding token 569 | if has_eos_stopping_criteria: 570 | next_tokens = next_tokens * unfinished_sequences.unsqueeze(1) + pad_token_id * (1 - unfinished_sequences.unsqueeze(1)) 571 | 572 | # update generated ids, model inputs, and length for next step 573 | if inputs_embeds is not None: 574 | inputs_embeds = torch.cat([inputs_embeds, next_embeds[:, None, :]], dim=1) 575 | else: 576 | inputs_embeds = next_embeds[:, None, :] 577 | input_ids = torch.cat([input_ids, next_tokens], dim=-1) 578 | 579 | 580 | model_kwargs = self._update_model_kwargs_for_generation( 581 | outputs, 582 | model_kwargs, 583 | ) 584 | 585 | unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None) 586 | this_peer_finished = unfinished_sequences.max() == 0 587 | cur_len += 1 588 | 589 | # This is needed to properly delete outputs.logits which may be very large for first iteration 590 | # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration 591 | del outputs 592 | 593 | if self.infer_cfg != 1.0: 594 | input_ids = input_ids[:real_batch_size] 595 | inputs_embeds = inputs_embeds[:real_batch_size] 596 | 597 | if return_dict_in_generate: 598 | return ModifiedGenerateDecoderOnlyOutput( 599 | sequences=input_ids, 600 | features=inputs_embeds, 601 | scores=scores, 602 | logits=raw_logits, 603 | attentions=decoder_attentions, 604 | hidden_states=decoder_hidden_states, 605 | past_key_values=model_kwargs.get("past_key_values"), 606 | ) 607 | else: 608 | return
(input_ids, inputs_embeds) 609 | 610 | 611 | def _update_model_kwargs_for_generation( 612 | self, 613 | outputs, 614 | model_kwargs: Dict[str, Any], 615 | num_new_tokens: int = 1, 616 | ) -> Dict[str, Any]: 617 | # update past_key_values keeping its naming used in model code 618 | cache_name, cache = self._extract_past_from_model_output(outputs) 619 | model_kwargs[cache_name] = cache 620 | 621 | # update attention mask 622 | if "attention_mask" in model_kwargs: 623 | attention_mask = model_kwargs["attention_mask"] 624 | model_kwargs["attention_mask"] = torch.cat( 625 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 626 | ) 627 | 628 | if model_kwargs.get("use_cache", True): 629 | model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens 630 | else: 631 | raise NotImplementedError 632 | 633 | return model_kwargs 634 | -------------------------------------------------------------------------------- /sled/sled_stream.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union, Dict, Any 2 | from pathlib import Path 3 | import torch 4 | import torch.nn as nn 5 | 6 | from transformers import EncodecModel 7 | from .modeling_llama_with_dropout import LlamaConfig, LlamaModel, LlamaForCausalLM 8 | from .modeling_llama_with_dropout import _prepare_4d_causal_attention_mask_with_cache_position 9 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 10 | from transformers.modeling_outputs import CausalLMOutputWithPast 11 | from transformers.generation.utils import GenerateDecoderOnlyOutput 12 | from transformers.utils.import_utils import is_torchdynamo_compiling 13 | 14 | from .energy_distance import ScoreLossZ 15 | import copy 16 | 17 | 18 | class ModifiedGenerateDecoderOnlyOutput(GenerateDecoderOnlyOutput): 19 | features = None 20 | 21 | 22 | def interleave_embeddings_and_mask_efficient(text_embeds, speech_embeds, 
text_mask, speech_mask, n, m): 23 | 24 | """ 25 | Interleave n text positions with m speech positions per cycle; leftovers are appended as text[:n], speech, text[n:]. 26 | Parameters: 27 | - text_embeds: Tensor with shape (b, t1, d) 28 | - speech_embeds: Tensor with shape (b, t2, d) 29 | - text_mask: Tensor with shape (b, t1) 30 | - speech_mask: Tensor with shape (b, t2) 31 | - n: int, number of consecutive text_embeds per cycle 32 | - m: int, number of consecutive speech_embeds per cycle 33 | 34 | Returns: 35 | - interleaved: Tensor with shape (b, t1 + t2, d) 36 | - interleaved_mask: Tensor with shape (b, t1 + t2) 37 | - speech_positions_mask: Tensor with shape (b, t1 + t2), 1 at speech positions and 0 at text positions 38 | """ 39 | 40 | b, t1, d = text_embeds.size() 41 | _, t2, _ = speech_embeds.size() 42 | 43 | 44 | num_cycles_text = t1 // n 45 | num_cycles_speech = t2 // m 46 | total_cycles = min(num_cycles_text, num_cycles_speech) 47 | 48 | interleaved_blocks = [] 49 | interleaved_mask_blocks = [] 50 | speech_positions_mask_blocks = [] 51 | 52 | 53 | if total_cycles > 0: 54 | text_main = text_embeds[:, :total_cycles * n, :].reshape(b, total_cycles, n, d) 55 | speech_main = speech_embeds[:, :total_cycles * m, :].reshape(b, total_cycles, m, d) 56 | 57 | text_mask_main = text_mask[:, :total_cycles * n].reshape(b, total_cycles, n) 58 | speech_mask_main = speech_mask[:, :total_cycles * m].reshape(b, total_cycles, m) 59 | 60 | interleaved_main = torch.cat([text_main, speech_main], dim=2).reshape(b, total_cycles * (n + m), d) 61 | interleaved_blocks.append(interleaved_main) 62 | 63 | interleaved_mask_main = torch.cat([text_mask_main, speech_mask_main], dim=2).reshape(b, total_cycles * (n + m)) 64 | interleaved_mask_blocks.append(interleaved_mask_main) 65 | 66 | 67 | text_zero_mask = torch.zeros_like(text_mask_main) 68 | speech_one_main = torch.ones_like(speech_mask_main) 69 | 70 | speech_positions_main = torch.cat([text_zero_mask, speech_one_main], dim=2).reshape(b, total_cycles * (n + m)) 71 | speech_positions_mask_blocks.append(speech_positions_main) 72 | 73 | 74 | remaining_text = text_embeds[:, total_cycles * n:, :] 75 | remaining_speech =
speech_embeds[:, total_cycles * m:, :] 76 | 77 | remaining_mask_text = text_mask[:, total_cycles * n:] 78 | remaining_mask_speech = speech_mask[:, total_cycles * m:] 79 | 80 | 81 | remaining_num_text = remaining_text.size(1) 82 | remaining_num_speech = remaining_speech.size(1) 83 | 84 | assert (remaining_num_text < n) or (remaining_num_speech < m) 85 | 86 | interleaved_blocks.append(remaining_text[:, :n, :]) 87 | interleaved_blocks.append(remaining_speech) 88 | interleaved_blocks.append(remaining_text[:, n:, :]) 89 | 90 | interleaved_mask_blocks.append(remaining_mask_text[:, :n]) 91 | interleaved_mask_blocks.append(remaining_mask_speech) 92 | interleaved_mask_blocks.append(remaining_mask_text[:, n:]) 93 | 94 | speech_positions_mask_blocks.append(torch.zeros_like(remaining_mask_text[:, :n])) 95 | speech_positions_mask_blocks.append(torch.ones_like(remaining_mask_speech)) 96 | speech_positions_mask_blocks.append(torch.zeros_like(remaining_mask_text[:, n:])) 97 | 98 | interleaved = torch.cat(interleaved_blocks, dim=1) 99 | interleaved_mask = torch.cat(interleaved_mask_blocks, dim=1) 100 | speech_positions_mask = torch.cat(speech_positions_mask_blocks, dim=1) 101 | 102 | assert interleaved.size(1) == (t1 + t2) == interleaved_mask.size(1) == speech_positions_mask.size(1) 103 | 104 | return interleaved, interleaved_mask, speech_positions_mask 105 | 106 | 107 | 108 | def get_previous_non_pad_indices(attention_mask): 109 | """ 110 | For each position, find the index of the closest non-pad position strictly before it. 111 | 112 | Parameters: 113 | attention_mask (torch.Tensor): shape (batch_size, seq_length), 0 for pad, 1 for non-pad. 114 | 115 | Returns: 116 | torch.Tensor: shape (batch_size, seq_length), each element is the index of the previous non-pad position, or -1 if there is none.
117 | """ 118 | 119 | batch_size, seq_length = attention_mask.shape 120 | 121 | 122 | indices = torch.arange(seq_length, device=attention_mask.device).unsqueeze(0).expand(batch_size, -1) # (batch_size, seq_length) 123 | 124 | 125 | mask = attention_mask == 1 126 | indices = torch.where(mask, indices, torch.full_like(indices, -1)) 127 | 128 | 129 | shifted_indices = torch.cat([torch.full((batch_size, 1), -1, device=attention_mask.device), indices[:, :-1]], dim=1) 130 | 131 | 132 | previous_non_pad = torch.cummax(shifted_indices, dim=1).values 133 | 134 | return previous_non_pad 135 | 136 | 137 | 138 | class SpeechLlamaConfig(LlamaConfig): 139 | model_type = "speech_llama" 140 | 141 | def __init__( 142 | self, 143 | vae_embed_dim=128, 144 | diffloss_d=3, 145 | diffloss_w=1024, 146 | training_cfg=0.0, 147 | noise_channels=128, 148 | stream_n=5, 149 | stream_m=45, 150 | **kwargs, 151 | ): 152 | 153 | self.vae_embed_dim = vae_embed_dim 154 | self.diffloss_d = diffloss_d 155 | self.diffloss_w = diffloss_w 156 | self.training_cfg = training_cfg 157 | self.noise_channels = noise_channels 158 | self.stream_n = stream_n 159 | self.stream_m = stream_m 160 | 161 | 162 | super().__init__(**kwargs) 163 | 164 | 165 | 166 | 167 | 168 | class SpeechLlamaForCausalLM(LlamaForCausalLM): 169 | config_class = SpeechLlamaConfig 170 | 171 | def __init__(self, config): 172 | super(LlamaForCausalLM, self).__init__(config) 173 | self.model = LlamaModel(config) 174 | self.codec = None 175 | 176 | # -------------------------------------------------------------------------- 177 | # Speech Embedding 178 | self.token_embed_dim = config.vae_embed_dim 179 | self.hidden_size = config.hidden_size 180 | self.z_proj = nn.Linear(self.token_embed_dim, self.hidden_size, bias=True) 181 | self.z_proj_ln = nn.LayerNorm(self.hidden_size, eps=1e-6) 182 | self.eos_head = nn.Linear(self.hidden_size, 1, bias=False) 183 | self.embed_mean = None 184 | self.embed_std = None 185 | self.training_cfg = 
config.training_cfg 186 | 187 | # -------------------------------------------------------------------------- 188 | # Score Loss 189 | self.scoreloss = ScoreLossZ( 190 | target_channels=self.token_embed_dim, 191 | z_channels=self.hidden_size, 192 | width=config.diffloss_w, 193 | depth=config.diffloss_d, 194 | noise_channels=config.noise_channels, 195 | ) 196 | 197 | # -------------------------------------------------------------------------- 198 | # BCE Loss 199 | pos_weight = torch.Tensor([100.]) # Weight of EOS is equal to 100 200 | self.bceloss = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight) 201 | 202 | self.stream_n = config.stream_n 203 | self.stream_m = config.stream_m 204 | 205 | # Initialize weights and apply final processing 206 | self.post_init() # Check whether affects init of LlamaModel 207 | 208 | def initialize_codec(self, model_args): 209 | if hasattr(model_args, "codec"): 210 | self.codec = EncodecModel.from_pretrained(model_args.codec, torch_dtype=torch.float32) # keep encodec model in fp32 211 | else: 212 | self.codec = EncodecModel.from_pretrained(model_args, torch_dtype=torch.float32) 213 | for param in self.codec.parameters(): 214 | param.requires_grad = False 215 | 216 | 217 | 218 | def _init_weights(self, m): 219 | if isinstance(m, nn.Linear): 220 | # we use xavier_uniform following official JAX ViT: 221 | torch.nn.init.xavier_uniform_(m.weight) 222 | if isinstance(m, nn.Linear) and m.bias is not None: 223 | nn.init.constant_(m.bias, 0) 224 | elif isinstance(m, nn.LayerNorm): 225 | if m.bias is not None: 226 | nn.init.constant_(m.bias, 0) 227 | if m.weight is not None: 228 | nn.init.constant_(m.weight, 1.0) 229 | 230 | 231 | def prepare_inputs_labels_for_multimodal( 232 | self, 233 | input_ids, 234 | position_ids, 235 | attention_mask, 236 | past_key_values, 237 | audio_inputs, 238 | ): 239 | if self.training_cfg > 0.0: 240 | bsz = attention_mask.size(0) 241 | random_mask = torch.rand(bsz) < self.training_cfg 242 | 
attention_mask[random_mask] = 0 243 | attention_mask[random_mask, 0] = 1 244 | input_ids[random_mask, 0] = 2 245 | 246 | 247 | text_inputs_embeds = self.model.embed_tokens(input_ids) 248 | 249 | with torch.no_grad(): 250 | encoder_outputs = self.codec.encode(audio_inputs["input_values"], audio_inputs["padding_mask"], bandwidth=6) #1,b,r,t, 1 due to one chunk 251 | speech_inputs_embeds = self.codec.quantizer.decode(encoder_outputs.audio_codes[0].transpose(0, 1)) #b,d,t, always fp32 252 | 253 | speech_attention_mask = audio_inputs["padding_mask"][..., ::320] 254 | assert speech_inputs_embeds.size(-1) == speech_attention_mask.size(-1) 255 | speech_inputs_embeds = speech_inputs_embeds.transpose(1,2).to(self.model.dtype) #b,t,d, support full bf16 training 256 | 257 | 258 | net_speech_inputs_embeds = self.z_proj(speech_inputs_embeds) 259 | new_inputs_embeds, new_attention_mask, speech_positions_mask = interleave_embeddings_and_mask_efficient(text_inputs_embeds, net_speech_inputs_embeds, attention_mask, speech_attention_mask, self.stream_n, self.stream_m) 260 | 261 | new_labels = speech_inputs_embeds 262 | 263 | return None, position_ids, new_attention_mask, past_key_values, new_inputs_embeds, new_labels, speech_attention_mask, speech_positions_mask 264 | 265 | def forward( 266 | self, 267 | input_ids: torch.LongTensor = None, 268 | attention_mask: Optional[torch.Tensor] = None, 269 | position_ids: Optional[torch.LongTensor] = None, 270 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 271 | inputs_embeds: Optional[torch.FloatTensor] = None, 272 | labels: Optional[torch.LongTensor] = None, 273 | use_cache: Optional[bool] = None, 274 | output_attentions: Optional[bool] = None, 275 | output_hidden_states: Optional[bool] = None, 276 | return_dict: Optional[bool] = None, 277 | cache_position: Optional[torch.LongTensor] = None, 278 | num_logits_to_keep: int = 0, 279 | audio_inputs: Optional[Dict[str, Any]] = None, 280 | speech_inputs_embeds: 
Optional[torch.FloatTensor] = None, 281 | speech_attention_mask: Optional[torch.Tensor] = None, 282 | ) -> Union[Tuple, CausalLMOutputWithPast]: 283 | 284 | 285 | if audio_inputs is not None: 286 | ( 287 | _, 288 | position_ids, 289 | attention_mask, 290 | past_key_values, 291 | inputs_embeds, 292 | labels, 293 | speech_attention_mask, 294 | speech_positions_mask, 295 | ) = self.prepare_inputs_labels_for_multimodal( 296 | input_ids, 297 | position_ids, 298 | attention_mask, 299 | past_key_values, 300 | audio_inputs, 301 | ) 302 | else: 303 | assert not ((input_ids is None) and (inputs_embeds is None)) 304 | if input_ids is not None and inputs_embeds is None: 305 | inputs_embeds = self.model.embed_tokens(input_ids) 306 | elif input_ids is None and inputs_embeds is not None: 307 | inputs_embeds = self.z_proj(inputs_embeds) 308 | else: 309 | inputs_embeds = torch.cat([self.z_proj(inputs_embeds), self.model.embed_tokens(input_ids)], dim=1) 310 | 311 | inputs_embeds = self.z_proj_ln(inputs_embeds) 312 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 313 | output_hidden_states = ( 314 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 315 | ) 316 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 317 | 318 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 319 | outputs = self.model( 320 | input_ids=None, 321 | attention_mask=attention_mask, 322 | position_ids=position_ids, 323 | past_key_values=past_key_values, 324 | inputs_embeds=inputs_embeds, 325 | use_cache=use_cache, 326 | output_attentions=output_attentions, 327 | output_hidden_states=output_hidden_states, 328 | return_dict=return_dict, 329 | cache_position=cache_position, 330 | ) 331 | 332 | hidden_states = outputs[0] 333 | logits = hidden_states[:, -num_logits_to_keep:, :] 334 | 335 | loss = None 336 | 337 | if labels is not None: 338 | 339 | 
bsz, speech_len, _ = labels.shape 340 | feat_dim = logits.size(-1) 341 | generate_index = get_previous_non_pad_indices(attention_mask) 342 | selected_indices = generate_index.masked_select((speech_positions_mask == 1)) # bsz * speech_len 343 | selected_indices = selected_indices.view(bsz, speech_len) # bsz, speech_len 344 | selected_indices_expanded = selected_indices.unsqueeze(-1).expand(-1, -1, feat_dim) # bsz * speech_len * feat_dim 345 | 346 | z = torch.gather(logits, dim=1, index=selected_indices_expanded) # bsz * speech_len * feat_dim 347 | 348 | 349 | labels = labels.reshape(bsz * speech_len, -1) 350 | mask = speech_attention_mask.reshape(bsz * speech_len) 351 | loss = self.scoreloss(z=z.reshape(bsz * speech_len, -1), target=labels, mask=mask) 352 | 353 | current_index = torch.arange(attention_mask.size(1), device=attention_mask.device).unsqueeze(0).expand(bsz, -1) # (bsz, full_seq_length) 354 | selected_current_indices = current_index.masked_select((speech_positions_mask == 1)) # bsz * speech_len 355 | selected_current_indices = selected_current_indices.view(bsz, speech_len)[:, -1:] # bsz, 1 356 | selected_current_indices_expanded = selected_current_indices.unsqueeze(-1).expand(-1, -1, feat_dim) # bsz * 1 * feat_dim 357 | 358 | z_last = torch.gather(logits, dim=1, index=selected_current_indices_expanded) # bsz * 1 * feat_dim 359 | 360 | z = torch.cat([z, z_last], dim=1) 361 | 362 | eos_score = self.eos_head(z).squeeze(-1).float() #bsz, speech_len+1 363 | 364 | non_pad_counts = speech_attention_mask.sum(dim=1) 365 | eos_labels = torch.zeros(bsz, speech_len + 1) 366 | eos_labels[torch.arange(bsz), non_pad_counts] = 1 #bsz, speech_len+1 367 | eos_labels = eos_labels.to(eos_score.device) 368 | 369 | eos_loss = self.bceloss(eos_score, eos_labels) #bsz, speech_len+1 370 | 371 | #Check BCE loss weight BROADCASTING 372 | ones_column = torch.ones(bsz, 1).to(speech_attention_mask.device) 373 | loss_mask = torch.cat((ones_column, speech_attention_mask), dim=1) #bsz, 
speech_len+1 374 | 375 | eos_loss = (eos_loss * loss_mask).sum() / loss_mask.sum() 376 | 377 | loss = eos_loss + loss 378 | 379 | if not return_dict: 380 | output = (logits,) + outputs[1:] 381 | return (loss,) + output if loss is not None else output 382 | 383 | return CausalLMOutputWithPast( 384 | loss=loss, 385 | logits=logits, 386 | past_key_values=outputs.past_key_values, 387 | hidden_states=outputs.hidden_states, 388 | attentions=outputs.attentions, 389 | ) 390 | 391 | # Check 392 | def state_dict(self, *args, **kwargs): 393 | state_dict = super().state_dict(*args, **kwargs) 394 | codec_keys = [k for k in state_dict if 'codec' in k] 395 | for key in codec_keys: 396 | del state_dict[key] 397 | return state_dict 398 | 399 | # Check 400 | def load_state_dict(self, state_dict, strict=True): 401 | codec_keys = [k for k in state_dict if 'codec' in k] 402 | for key in codec_keys: 403 | del state_dict[key] 404 | return super().load_state_dict(state_dict, strict=False) 405 | 406 | 407 | def prepare_inputs_for_generation( 408 | self, 409 | input_ids, 410 | inputs_embeds=None, 411 | turn=0, 412 | num_write_turn=0, 413 | past_key_values=None, 414 | attention_mask=None, 415 | cache_position=None, 416 | position_ids=None, 417 | use_cache=True, 418 | num_logits_to_keep=None, 419 | **kwargs, 420 | ): 421 | if num_write_turn !=0 : 422 | assert cache_position.shape[0] == 1 423 | assert past_key_values is not None 424 | inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :, :] 425 | 426 | if attention_mask is not None and position_ids is None: 427 | # create position_ids on the fly for batch generation 428 | position_ids = attention_mask.long().cumsum(-1) - 1 429 | position_ids.masked_fill_(attention_mask == 0, 1) 430 | 431 | position_ids = position_ids[:, -inputs_embeds.shape[1] :] 432 | position_ids = position_ids.clone(memory_format=torch.contiguous_format) 433 | 434 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} 435 | 436 | if 
isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: 437 | raise NotImplementedError 438 | 439 | if num_logits_to_keep is not None: 440 | model_inputs["num_logits_to_keep"] = num_logits_to_keep 441 | 442 | model_inputs.update( 443 | { 444 | "position_ids": position_ids, 445 | "cache_position": cache_position, 446 | "past_key_values": past_key_values, 447 | "use_cache": use_cache, 448 | "attention_mask": attention_mask, 449 | } 450 | ) 451 | 452 | else: 453 | if inputs_embeds is not None: 454 | inputs_embeds = inputs_embeds[:, -1 :, :] 455 | 456 | if attention_mask is not None and position_ids is None: 457 | # create position_ids on the fly for batch generation 458 | position_ids = attention_mask.long().cumsum(-1) - 1 459 | position_ids.masked_fill_(attention_mask == 0, 1) 460 | 461 | position_ids = position_ids[:, -cache_position.shape[0] :] 462 | position_ids = position_ids.clone(memory_format=torch.contiguous_format) 463 | 464 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": inputs_embeds} 465 | 466 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: 467 | raise NotImplementedError 468 | 469 | if num_logits_to_keep is not None: 470 | model_inputs["num_logits_to_keep"] = num_logits_to_keep 471 | 472 | model_inputs.update( 473 | { 474 | "position_ids": position_ids, 475 | "cache_position": cache_position, 476 | "past_key_values": past_key_values, 477 | "use_cache": use_cache, 478 | "attention_mask": attention_mask, 479 | } 480 | ) 481 | return model_inputs 482 | 483 | 484 | def prepare_inputs_for_generation_cfg( 485 | self, 486 | input_ids, 487 | inputs_embeds=None, 488 | past_key_values=None, 489 | attention_mask=None, 490 | cache_position=None, 491 | position_ids=None, 492 | use_cache=True, 493 | num_logits_to_keep=None, 494 | **kwargs, 495 | ): 496 | if cache_position[0] == 0: 497 | if attention_mask is not None and position_ids is None: 498 | # create position_ids 
on the fly for batch generation 499 | position_ids = attention_mask.long().cumsum(-1) - 1 500 | position_ids.masked_fill_(attention_mask == 0, 1) 501 | 502 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} 503 | 504 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: 505 | raise NotImplementedError 506 | 507 | if num_logits_to_keep is not None: 508 | model_inputs["num_logits_to_keep"] = num_logits_to_keep 509 | 510 | model_inputs.update( 511 | { 512 | "position_ids": position_ids, 513 | "cache_position": cache_position, 514 | "past_key_values": past_key_values, 515 | "use_cache": use_cache, 516 | "attention_mask": attention_mask, 517 | } 518 | ) 519 | else: 520 | if past_key_values is not None: 521 | inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :, :] 522 | 523 | if attention_mask is not None and position_ids is None: 524 | # create position_ids on the fly for batch generation 525 | position_ids = attention_mask.long().cumsum(-1) - 1 526 | position_ids.masked_fill_(attention_mask == 0, 1) 527 | if past_key_values: 528 | position_ids = position_ids[:, -inputs_embeds.shape[1] :] 529 | position_ids = position_ids.clone(memory_format=torch.contiguous_format) 530 | 531 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} 532 | 533 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: 534 | raise NotImplementedError 535 | 536 | if num_logits_to_keep is not None: 537 | model_inputs["num_logits_to_keep"] = num_logits_to_keep 538 | 539 | model_inputs.update( 540 | { 541 | "position_ids": position_ids, 542 | "cache_position": cache_position, 543 | "past_key_values": past_key_values, 544 | "use_cache": use_cache, 545 | "attention_mask": attention_mask, 546 | } 547 | ) 548 | 549 | return model_inputs 550 | 551 | 552 | 553 | def _get_initial_cache_position_after_each_read(self, input_ids, model_kwargs): 554 | """Calculates `cache_position` for 
the pre-fill stage based on `input_ids` and optionally past length""" 555 | # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` 556 | cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 557 | 558 | past_length = 0 559 | if model_kwargs.get("past_key_values") is not None: 560 | cache = model_kwargs["past_key_values"] 561 | past_length = 0 562 | if not isinstance(cache, Cache): 563 | past_length = cache[0][0].shape[2] 564 | elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: 565 | past_length = cache.get_seq_length() 566 | 567 | # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty, 568 | # end-to-end compilation will yield bad results because `cache_position` will be incorrect. 569 | if not is_torchdynamo_compiling(): 570 | cache_position = cache_position + past_length 571 | if model_kwargs.get("cache_position", None) is not None: 572 | model_kwargs["cache_position"] = torch.cat([model_kwargs["cache_position"], cache_position + 1], dim=0) 573 | else: 574 | model_kwargs["cache_position"] = cache_position 575 | return model_kwargs 576 | 577 | 578 | def _get_initial_cache_position_cfg(self, input_ids, model_kwargs): 579 | """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" 580 | # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` 581 | cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 582 | 583 | past_length = 0 584 | if model_kwargs.get("past_key_values") is not None: 585 | cache = model_kwargs["past_key_values"] 586 | past_length = 0 587 | if not isinstance(cache, Cache): 588 | past_length = cache[0][0].shape[2] 589 | elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: 590 | past_length = cache.get_seq_length() 591 | 592 | # TODO(joao): this is not 
torch.compile-friendly, find a work-around. If the cache is not empty, 593 | # end-to-end compilation will yield bad results because `cache_position` will be incorrect. 594 | if not is_torchdynamo_compiling(): 595 | cache_position = cache_position[past_length:] 596 | 597 | model_kwargs["cache_position"] = cache_position 598 | return model_kwargs 599 | 600 | 601 | def _sample( 602 | self, 603 | input_ids, 604 | logits_processor, 605 | stopping_criteria, 606 | generation_config, 607 | synced_gpus, 608 | streamer, 609 | **model_kwargs, 610 | ): 611 | r""" 612 | Parameters: 613 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 614 | The sequence used as a prompt for the generation. 615 | logits_processor (`LogitsProcessorList`): 616 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] 617 | used to modify the prediction scores of the language modeling head applied at each generation step. 618 | stopping_criteria (`StoppingCriteriaList`): 619 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] 620 | used to tell if the generation loop should stop. 621 | generation_config ([`~generation.GenerationConfig`]): 622 | The generation configuration to be used as parametrization of the decoding method. 623 | synced_gpus (`bool`): 624 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3) 625 | streamer (`BaseStreamer`, *optional*): 626 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed 627 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing. 628 | model_kwargs: 629 | Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is 630 | an encoder-decoder model the kwargs should include `encoder_outputs`. 
631 | 632 | Return: 633 | [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: 634 | A `torch.LongTensor` containing the generated tokens (default behaviour) or a 635 | [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and 636 | `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if 637 | `model.config.is_encoder_decoder=True`. 638 | """ 639 | # init values 640 | pad_token_id = generation_config._pad_token_tensor 641 | bos_token_id = generation_config._bos_token_tensor 642 | eos_token_id = generation_config._eos_token_tensor 643 | output_attentions = generation_config.output_attentions 644 | output_hidden_states = generation_config.output_hidden_states 645 | output_scores = generation_config.output_scores 646 | output_logits = generation_config.output_logits 647 | return_dict_in_generate = generation_config.return_dict_in_generate 648 | max_length = generation_config.max_length 649 | has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) 650 | do_sample = generation_config.do_sample 651 | 652 | # init attention / hidden states / scores tuples 653 | scores = () if (return_dict_in_generate and output_scores) else None 654 | raw_logits = () if (return_dict_in_generate and output_logits) else None 655 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 656 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 657 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 658 | 659 | 660 | 661 | # keep track of which sequences are already finished 662 | batch_size = input_ids.shape[0] 663 | assert batch_size == 1 # only support batch size 1 664 | turn = 0 665 | read_finished = False 666 | this_peer_finished = False 667 | unfinished_sequences = torch.ones(batch_size, dtype=torch.long, 
device=input_ids.device) 668 | last_src_read_len = 0 669 | cur_len = 0 670 | complete_attention_mask = model_kwargs.pop("attention_mask") 671 | model_kwargs["attention_mask"] = torch.zeros((batch_size, 0), dtype=complete_attention_mask.dtype, device=complete_attention_mask.device) 672 | inputs_embeds = None 673 | generate_ids = torch.zeros((batch_size, 0), dtype=input_ids.dtype, device=input_ids.device) 674 | 675 | 676 | if self.infer_cfg != 1.0: 677 | model_kwargs_cfg = copy.deepcopy(model_kwargs) 678 | model_kwargs_cfg["attention_mask"] = torch.ones((batch_size, 1), dtype=complete_attention_mask.dtype, device=complete_attention_mask.device) 679 | input_ids_cfg = torch.ones((batch_size, 1), dtype=input_ids.dtype, device=input_ids.device) 680 | input_ids_cfg[:, 0] = 2 681 | inputs_embeds_cfg = None 682 | model_kwargs_cfg = self._get_initial_cache_position_cfg(input_ids_cfg, model_kwargs_cfg) 683 | 684 | 685 | while read_finished == False: 686 | 687 | src_read_len = min(input_ids.shape[1], (turn + 1) * self.stream_n) 688 | if src_read_len == input_ids.shape[1]: 689 | read_finished = True 690 | 691 | 692 | this_turn_input_ids = input_ids[:, last_src_read_len:src_read_len] 693 | model_kwargs = self._get_initial_cache_position_after_each_read(this_turn_input_ids, model_kwargs) 694 | model_kwargs["attention_mask"] = torch.cat([model_kwargs["attention_mask"], complete_attention_mask[:, last_src_read_len:src_read_len]], dim=1) 695 | 696 | #self.prompt_length = input_ids.shape[1] if inputs_embeds is None else input_ids.shape[1] + inputs_embeds.shape[1] 697 | 698 | cur_len = cur_len + src_read_len - last_src_read_len 699 | 700 | 701 | last_src_read_len = src_read_len 702 | num_write_in_turn = 0 703 | 704 | 705 | while self._has_unfinished_sequences( 706 | this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length 707 | ): 708 | if (not read_finished) and (num_write_in_turn == self.stream_m): 709 | break 710 | 711 | # prepare model 
inputs 712 | model_inputs = self.prepare_inputs_for_generation(this_turn_input_ids, inputs_embeds, turn, num_write_in_turn, **model_kwargs) 713 | if self.infer_cfg != 1.0: 714 | model_inputs_cfg = self.prepare_inputs_for_generation_cfg(input_ids_cfg, inputs_embeds_cfg, **model_kwargs_cfg) 715 | 716 | 717 | # prepare variable output controls (note: some models won't accept all output controls) 718 | model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) 719 | model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) 720 | 721 | # forward pass to get next token 722 | outputs = self(**model_inputs, return_dict=True) 723 | if self.infer_cfg != 1.0: 724 | outputs_cfg = self(**model_inputs_cfg, return_dict=True) 725 | 726 | 727 | if synced_gpus and this_peer_finished: 728 | continue # don't waste resources running the code we don't need 729 | 730 | # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration 731 | # (the clone itself is always small) 732 | next_token_logits = outputs.logits.clone()[:, -1, :] 733 | eos_next_token_logits = next_token_logits.clone() 734 | if self.infer_cfg != 1.0: 735 | next_token_logits_cfg = outputs_cfg.logits.clone()[:, -1, :] 736 | next_token_logits = next_token_logits_cfg + self.infer_cfg * (next_token_logits - next_token_logits_cfg) 737 | # Store scores, attentions and hidden_states when required 738 | if return_dict_in_generate: 739 | if output_logits: 740 | raw_logits += (next_token_logits,) 741 | if output_attentions: 742 | decoder_attentions += ( 743 | (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) 744 | ) 745 | if self.config.is_encoder_decoder: 746 | cross_attentions += (outputs.cross_attentions,) 747 | 748 | if output_hidden_states: 749 | decoder_hidden_states += ( 750 | (outputs.decoder_hidden_states,) 751 | if self.config.is_encoder_decoder 752 | else 
(outputs.hidden_states,) 753 | ) 754 | 755 | # token selection 756 | if do_sample: 757 | next_embeds = self.scoreloss.sample(next_token_logits, temperature=1.0) # bsz, dim 758 | 759 | next_actions = torch.sigmoid(self.eos_head(eos_next_token_logits)) >= 0.5 # 0: continue, 1: stop 760 | # if read_finished: 761 | # next_tokens = torch.where(next_actions == 0, bos_token_id, eos_token_id) 762 | # else: 763 | # next_tokens = torch.full((batch_size,), bos_token_id ,device=next_actions.device) 764 | next_tokens = torch.where(next_actions == 0, bos_token_id, eos_token_id) 765 | else: 766 | raise NotImplementedError 767 | 768 | 769 | # finished sentences should have their next token be a padding token 770 | if has_eos_stopping_criteria: 771 | next_tokens = next_tokens * unfinished_sequences.unsqueeze(1) + pad_token_id * (1 - unfinished_sequences.unsqueeze(1)) 772 | 773 | # update generated ids, model inputs, and length for next step 774 | if inputs_embeds is not None: 775 | inputs_embeds = torch.cat([inputs_embeds, next_embeds[:, None, :]], dim=1) 776 | else: 777 | inputs_embeds = next_embeds[:, None, :] 778 | 779 | generate_ids = torch.cat([generate_ids, next_tokens], dim=-1) 780 | 781 | if self.infer_cfg != 1.0: 782 | if inputs_embeds_cfg is not None: 783 | inputs_embeds_cfg = torch.cat([inputs_embeds_cfg, next_embeds[:, None, :]], dim=1) 784 | else: 785 | inputs_embeds_cfg = next_embeds[:, None, :] 786 | 787 | input_ids_cfg = torch.cat([input_ids_cfg, next_tokens], dim=-1) 788 | 789 | 790 | model_kwargs = self._update_model_kwargs_for_generation( 791 | outputs, 792 | model_kwargs, 793 | ) 794 | 795 | 796 | if self.infer_cfg != 1.0: 797 | model_kwargs_cfg = self._update_model_kwargs_for_generation(outputs_cfg, model_kwargs_cfg) 798 | 799 | 800 | 801 | unfinished_sequences = unfinished_sequences & ~stopping_criteria(generate_ids, None) 802 | this_peer_finished = unfinished_sequences.max() == 0 803 | cur_len += 1 804 | num_write_in_turn += 1 805 | 806 | # This is needed to 
properly delete outputs.logits which may be very large for first iteration 807 | # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration 808 | del outputs 809 | 810 | turn += 1 811 | 812 | if return_dict_in_generate: 813 | return ModifiedGenerateDecoderOnlyOutput( 814 | sequences=generate_ids, 815 | features=inputs_embeds, 816 | scores=scores, 817 | logits=raw_logits, 818 | attentions=decoder_attentions, 819 | hidden_states=decoder_hidden_states, 820 | past_key_values=model_kwargs.get("past_key_values"), 821 | ) 822 | else: 823 | return (generate_ids, inputs_embeds) 824 | 825 | def _update_model_kwargs_for_generation( 826 | self, 827 | outputs, 828 | model_kwargs: Dict[str, Any], 829 | num_new_tokens: int = 1, 830 | ) -> Dict[str, Any]: 831 | # update past_key_values keeping its naming used in model code 832 | cache_name, cache = self._extract_past_from_model_output(outputs) 833 | model_kwargs[cache_name] = cache 834 | 835 | # update attention mask 836 | if "attention_mask" in model_kwargs: 837 | attention_mask = model_kwargs["attention_mask"] 838 | model_kwargs["attention_mask"] = torch.cat( 839 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 840 | ) 841 | 842 | if model_kwargs.get("use_cache", True): 843 | model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens 844 | else: 845 | raise NotImplementedError 846 | 847 | return model_kwargs 848 | -------------------------------------------------------------------------------- /sled/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Optional 3 | 4 | from torch.utils.data import Dataset 5 | from transformers import Trainer 6 | from transformers.trainer import has_length 7 | from transformers.trainer_pt_utils import LengthGroupedSampler 8 | 9 | class SpeechLengthGroupedSampler(LengthGroupedSampler): 10 | r""" 11 | Sampler that samples 
indices in a way that groups together features of the dataset of roughly the same length while 12 | keeping a bit of randomness. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | batch_size: int, 18 | dataset: Optional[Dataset] = None, 19 | generator=None, 20 | ): 21 | 22 | self.batch_size = batch_size 23 | 24 | lengths = [len(feature["audio"]["array"]) for feature in dataset] 25 | self.lengths = lengths 26 | self.generator = generator 27 | 28 | 29 | 30 | 31 | class SpeechLlamaTrainer(Trainer): 32 | 33 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 34 | if self.train_dataset is None or not has_length(self.train_dataset): 35 | return None 36 | 37 | 38 | if self.args.group_by_speech_length: 39 | return SpeechLengthGroupedSampler( 40 | self.args.train_batch_size * self.args.gradient_accumulation_steps, 41 | dataset=self.train_dataset, 42 | ) 43 | else: 44 | return super()._get_train_sampler() -------------------------------------------------------------------------------- /sled/trainer_libriheavy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Optional 3 | 4 | from torch.utils.data import Dataset 5 | from transformers import Trainer 6 | from transformers.trainer import has_length 7 | from transformers.trainer_pt_utils import LengthGroupedSampler 8 | 9 | class SpeechLengthGroupedSampler(LengthGroupedSampler): 10 | r""" 11 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 12 | keeping a bit of randomness. 
13 | """ 14 | 15 | def __init__( 16 | self, 17 | batch_size: int, 18 | dataset: Optional[Dataset] = None, 19 | generator=None, 20 | ): 21 | 22 | self.batch_size = batch_size 23 | 24 | lengths = [int(feature["duration"] * 75) for feature in dataset] 25 | self.lengths = lengths 26 | self.generator = generator 27 | 28 | 29 | 30 | 31 | class SpeechLlamaTrainer(Trainer): 32 | 33 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 34 | if self.train_dataset is None or not has_length(self.train_dataset): 35 | return None 36 | 37 | 38 | if self.args.group_by_speech_length: 39 | return SpeechLengthGroupedSampler( 40 | self.args.train_batch_size * self.args.gradient_accumulation_steps, 41 | dataset=self.train_dataset, 42 | ) 43 | else: 44 | return super()._get_train_sampler() 45 | --------------------------------------------------------------------------------
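In `sled_stream.py`, the stop-prediction loss puts exactly one positive label per example in a vector of length `speech_len + 1`. That label sits at index `speech_attention_mask.sum(dim=1)`. The BCE is then masked with a leading 1 followed by `speech_attention_mask`, so every valid frame position plus the single stop position contributes to the loss and trailing padding does not. The plain-Python sketch below mirrors that bookkeeping without tensors. It assumes `self.bceloss` is an unreduced `BCEWithLogitsLoss`, since its construction is not shown in this excerpt. The helper names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def eos_targets_and_mask(speech_mask: Sequence[Sequence[int]]) -> Tuple[List[List[float]], List[List[int]]]:
    """Per example, build stop labels and a loss mask of length speech_len + 1."""
    labels, masks = [], []
    for row in speech_mask:
        n_valid = sum(row)  # same role as speech_attention_mask.sum(dim=1)
        lab = [0.0] * (len(row) + 1)
        lab[n_valid] = 1.0  # single positive right after the last real frame
        labels.append(lab)
        masks.append([1] + list(row))  # leading 1, then the speech mask
    return labels, masks


def bce_with_logits(logit: float, target: float) -> float:
    """Numerically stable binary cross-entropy on a raw (pre-sigmoid) score."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def masked_eos_loss(scores, labels, masks) -> float:
    """Sum of masked per-position BCE terms divided by the number of unmasked positions."""
    total, count = 0.0, 0
    for s_row, t_row, m_row in zip(scores, labels, masks):
        for s, t, m in zip(s_row, t_row, m_row):
            total += bce_with_logits(s, t) * m
            count += m
    return total / count
```

For a speech mask `[1, 1, 1, 0]`, the labels are `[0, 0, 0, 1, 0]` and the loss mask is `[1, 1, 1, 1, 0]`. Because of the prepended 1, the stop position is always inside the loss.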
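Both `prepare_inputs_for_generation` variants derive `position_ids` from the attention mask using `cumsum(-1) - 1`. They then overwrite padded slots with the dummy value 1 and keep only the trailing positions that belong to the tokens or embeddings fed in the current step. Below is a list-based sketch of that rule for a single sequence. The helper name is hypothetical.

```python
from typing import List


def position_ids_from_mask(attention_mask: List[int], keep_last: int) -> List[int]:
    """Index real tokens 0, 1, 2, ...; padded slots get the dummy value 1."""
    assert keep_last >= 1, "at least one new token/embedding is fed per step"
    positions, running = [], 0
    for m in attention_mask:
        running += m  # running count of real tokens (cumsum)
        positions.append(running - 1 if m else 1)  # cumsum - 1, pads filled with 1
    # like position_ids[:, -n:], keep only the positions of the new inputs
    return positions[-keep_last:]
```

With a left-padded mask `[0, 0, 1, 1, 1]`, the full result is `[1, 1, 0, 1, 2]`. A single-step decode keeps only `[2]`.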
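In `_sample`, generation alternates between reading text and writing speech frames. Turn `t` extends the visible text prefix to `min(len, (t + 1) * stream_n)` tokens. The inner loop then emits at most `stream_m` frames before breaking out to read again. Once the whole text has been read, the cap is lifted and frames are written until the stopping criteria fire. The sketch below simulates only that read/write schedule. `num_frames` is a hypothetical stand-in for the point where the EOS head (or `max_length`) ends generation. The original also asserts `batch_size == 1`.

```python
from typing import List, Tuple


def stream_schedule(text_len: int, stream_n: int, stream_m: int, num_frames: int) -> List[Tuple[str, int]]:
    """Return interleaved ("read", n_tokens) / ("write", n_frames) actions."""
    actions: List[Tuple[str, int]] = []
    turn, last_read, written = 0, 0, 0
    read_finished = False
    while not read_finished:
        # extend the text prefix by up to stream_n tokens
        src_read_len = min(text_len, (turn + 1) * stream_n)
        read_finished = src_read_len == text_len
        actions.append(("read", src_read_len - last_read))
        last_read = src_read_len
        n_write = 0
        while written < num_frames:  # stand-in for "sequence not finished"
            if not read_finished and n_write == stream_m:
                break  # write budget for this turn used up, go read more text
            written += 1
            n_write += 1
        if n_write:
            actions.append(("write", n_write))
        turn += 1
    return actions
```

With 10 text tokens, `stream_n=4`, `stream_m=3`, and 12 frames, the schedule reads 4 and writes 3, reads 4 and writes 3, then reads the final 2 and writes the remaining 6.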
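When `infer_cfg != 1.0`, `_sample` runs a second, unconditional branch in parallel with the main one. That branch starts from a single token with id `2` instead of the text. The two outputs are combined as `uncond + infer_cfg * (cond - uncond)` before `self.scoreloss.sample` draws the next continuous frame. The EOS decision uses `eos_next_token_logits`, which is cloned before guidance is applied, so it sees only the conditional output. The elementwise rule is sketched below on plain lists. The function name is hypothetical.

```python
from typing import List


def apply_cfg(cond: List[float], uncond: List[float], scale: float) -> List[float]:
    """Classifier-free guidance: uncond + scale * (cond - uncond), elementwise."""
    if scale == 1.0:
        # mirrors infer_cfg == 1.0, where the unconditional pass is skipped
        return list(cond)
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]
```

A scale of 1 returns the conditional vector unchanged, and a scale of 0 returns the unconditional one. Values above 1 push the output further in the direction of the text-conditioned prediction.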
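`SpeechLengthGroupedSampler` skips the parent constructor and sets only `batch_size`, `lengths`, and `generator`. These appear to be the attributes that the inherited `LengthGroupedSampler` iteration relies on. In `trainer_libriheavy.py`, lengths are estimated as `int(duration * 75)`, i.e. seconds converted at 75 units per second. That these units are codec frames is an assumption not stated in this excerpt. `trainer.py` instead uses the raw audio array length. The standard-library sketch below approximates the grouping idea: shuffle the indices, cut them into megabatches, and sort each megabatch longest-first. It is a simplified illustration, not the exact `transformers` routine, which among other details also moves the single longest sample to the front. Function names are hypothetical.

```python
import random
from typing import List, Optional


def durations_to_frames(durations: List[float], frames_per_second: int = 75) -> List[int]:
    """Length estimate used for grouping: truncated seconds * frames_per_second."""
    return [int(d * frames_per_second) for d in durations]


def length_grouped_indices(lengths: List[int], batch_size: int, seed: Optional[int] = None) -> List[int]:
    """Shuffle, split into megabatches, and sort each megabatch by length (descending)."""
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    rng.shuffle(indices)
    # megabatch holds several batches, capped at 50 batches, at least 1
    mega_mult = max(1, min(len(lengths) // (batch_size * 4), 50))
    mega = mega_mult * batch_size
    out: List[int] = []
    for start in range(0, len(indices), mega):
        chunk = indices[start:start + mega]
        out.extend(sorted(chunk, key=lambda i: lengths[i], reverse=True))
    return out
```

Sorting inside shuffled megabatches keeps batches roughly homogeneous in length, which limits padding, while the shuffle keeps some randomness across epochs.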