├── .gitignore
├── README.md
├── example_prompt.flac
├── pyproject.toml
├── scripts
│   ├── eval_continue.py
│   ├── eval_reference.py
│   ├── eval_stream.py
│   ├── run_offline.py
│   ├── run_stream.py
│   ├── run_voice_clone.py
│   ├── train_libriheavy.py
│   └── train_libriheavy_stream.py
├── shell_scripts
│   ├── eval.sh
│   ├── run.sh
│   ├── train_libriheavy.sh
│   └── train_libriheavy_stream.sh
├── sled.egg-info
│   ├── PKG-INFO
│   ├── SOURCES.txt
│   ├── dependency_links.txt
│   ├── requires.txt
│   └── top_level.txt
├── sled
│   ├── __init__.py
│   ├── energy_distance.py
│   ├── modeling_llama_with_dropout.py
│   ├── sled.py
│   ├── sled_stream.py
│   ├── trainer.py
│   └── trainer_libriheavy.py
└── tokenizer_bpe_libriheavy
    ├── merges.txt
    └── vocab.json
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/
2 | /runs/
3 | *.sh
4 | *.wav
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🛷SLED-TTS: Efficient Speech Language Modeling via Energy Distance in Continuous Latent Space
2 | > **Authors: [Zhengrui Ma](https://scholar.google.com/citations?user=dUgq6tEAAAAJ), [Yang Feng*](https://people.ucas.edu.cn/~yangfeng?language=en), [Chenze Shao](https://scholar.google.com/citations?user=LH_rZf8AAAAJ&hl), [Fandong Meng](https://fandongmeng.github.io/), [Jie Zhou](https://scholar.google.com.hk/citations?user=OijxQCMAAAAJ&hl=en), [Min Zhang](https://scholar.google.com/citations?user=CncXH-YAAAAJ&hl=en)**
3 |
4 | [arXiv](https://arxiv.org/abs/2505.13181)
5 | [Code](https://github.com/ictnlp/SLED-TTS)
6 | [Models](https://huggingface.co/collections/ICTNLP/sled-tts-680253e19c889010a1a376ac)
7 | [WeChat](https://www.wechat.com)
8 | [ICT/CAS](https://ict.cas.cn)
9 |
10 | ## News
11 | - **Our paper has been released on [arXiv](https://arxiv.org/abs/2505.13181).**
12 |
13 | ## Key features
14 | - **Continuous Autoregressive Modeling**: SLED models speech in a continuous latent space, eliminating the need for complex hierarchical architectures.
15 | - **Streaming Synthesis**: SLED supports streaming synthesis, enabling speech generation to start as soon as the text stream begins.
16 | - **Voice Cloning**: Capable of generating speech conditioned on a 3-second speech prefix or a reference utterance used as the prompt.
17 |
18 |
19 | ## Demo
20 | You can see SLED in action by exploring the [demo page](https://sled-demo.github.io/).
21 |
22 |

23 |

24 |
25 |
26 | ## Available Models on Hugging Face
27 |
28 | We currently offer two English models trained on LibriHeavy, available on [Hugging Face](https://huggingface.co/collections/ICTNLP/sled-tts-680253e19c889010a1a376ac):
29 |
30 | 1. **[SLED-TTS-Libriheavy](https://huggingface.co/ICTNLP/SLED-TTS-Libriheavy)**: This model is trained on Libriheavy and provides high-quality text-to-speech synthesis.
31 |
32 | 2. **[SLED-TTS-Streaming-Libriheavy](https://huggingface.co/ICTNLP/SLED-TTS-Streaming-Libriheavy)**: This variant supports **streaming decoding**, generating a 0.6-second speech chunk (45 latent frames at 75 Hz) for every 5 text tokens received; a rough sketch of this schedule is given below.
33 |
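The read/write pattern implied by this setting (matching the `--stream_n 5 --stream_m 45` flags in the streaming training recipe below) can be sketched as follows; this is purely illustrative and is not the model's API:

``` python
# Illustrative sketch of the streaming schedule (not the model's API):
# read 5 text tokens, then emit 45 latent frames (0.6 s at 75 Hz), and repeat.
def stream_schedule(text_tokens, n=5, m=45):
    for i in range(0, len(text_tokens), n):
        yield ("read", text_tokens[i:i + n])   # wait for the next n text tokens
        yield ("write", m)                     # generate m speech latent frames (~0.6 s)

for step in stream_schedule(list(range(12))):
    print(step)
```
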
34 | **Alternatively, you can train SLED on your own data by following the guidelines below.**
35 |
36 | ## Usage
37 | **We provide the training and inference code for SLED-TTS.**
38 |
39 | ### Installation
40 | ``` sh
41 | git clone https://github.com/ictnlp/SLED-TTS.git
42 | cd SLED-TTS
43 | pip install -e ./
44 | ```
45 |
46 | We currently use the sum of the embeddings from the first 8 quantizer codebooks of [Encodec_24khz](https://huggingface.co/facebook/encodec_24khz) (i.e., encoding at 6 kbps) as the continuous latent vector. To proceed, make sure [Encodec_24khz](https://huggingface.co/facebook/encodec_24khz) is downloaded and cached in your Hugging Face cache directory.
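For reference, the way the scripts in `scripts/` build this latent can be sketched as follows (6 kbps selects the first 8 RVQ codebooks, whose embeddings are summed by `quantizer.decode`); the model name and audio file come from this repo, the rest is illustrative:

``` python
# Minimal sketch of how the continuous latent is obtained (mirrors the eval/run scripts).
import librosa
import torch
from transformers import AutoProcessor, EncodecModel

processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
codec = EncodecModel.from_pretrained("facebook/encodec_24khz").eval()

audio, _ = librosa.load("example_prompt.flac", sr=24000, mono=True)   # 24 kHz mono waveform
inputs = processor(raw_audio=audio, sampling_rate=24000, return_tensors="pt")

with torch.no_grad():
    # bandwidth=6 kbps -> the first 8 codebooks of the residual vector quantizer
    codes = codec.encode(inputs["input_values"], inputs["padding_mask"], bandwidth=6.0).audio_codes
    # decoding the codes sums the 8 codebook embeddings -> continuous latents of shape (batch, 128, frames) at 75 Hz
    latents = codec.quantizer.decode(codes[0].transpose(0, 1))
```
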
47 |
48 | ### Inference
49 | - Set the `CHECKPOINT` variable to the path of the cached **[SLED-TTS-Libriheavy](https://huggingface.co/ICTNLP/SLED-TTS-Libriheavy)** or **[SLED-TTS-Streaming-Libriheavy](https://huggingface.co/ICTNLP/SLED-TTS-Streaming-Libriheavy)** model.
50 | - Diverse generation results can be obtained by varying the `SEED` variable.
51 | - Use the `--bf16` flag to enable bf16 inference.
52 | ``` sh
53 | CHECKPOINT=/path/to/checkpoint
54 | CFG=2.0
55 | SEED=0
56 | ```
57 | ***Offline Inference***
58 | ``` sh
59 | python scripts/run_offline.py \
60 | --model_name_or_path ${CHECKPOINT} \
61 | --cfg ${CFG} \
62 | --input "My remark pleases him, but I soon prove to him that it is not the right way to speak. However perfect may have been the language of that ancient writer." \
63 | --seed ${SEED}
64 | ```
65 | ***Streaming Inference***
66 | ``` sh
67 | python scripts/run_stream.py \
68 | --model_name_or_path ${CHECKPOINT} \
69 | --cfg ${CFG} \
70 | --input "My remark pleases him, but I soon prove to him that it is not the right way to speak. However perfect may have been the language of that ancient writer." \
71 | --seed ${SEED}
72 | # Note: run_stream.py simulates generation in a streaming environment so that streaming quality can be evaluated;
73 | # the current code does not yet expose an actual streaming API.
74 | ```
75 | ***Voice Clone***
76 |
77 | You can adjust the prompt speech by setting `--prompt_text` and `--prompt_audio`. The prompt text is prepended to the input text, the prompt audio is encoded into a speech prefix, and only the speech generated after the prefix is written to `output.wav`.
78 | ``` sh
79 | python scripts/run_voice_clone.py \
80 | --prompt_text "Were I in the warm room with all the splendor and magnificence!" \
81 | --prompt_audio "example_prompt.flac" \
82 | --model_name_or_path ${CHECKPOINT} \
83 | --cfg ${CFG} \
84 | --input "Perhaps the other trees from the forest will come to look at me!" \
85 | --seed ${SEED}
86 | ```
87 |
88 | ### Training
89 |
90 | ***Data Processing***
91 |
92 | Process the LibriHeavy data so that each line follows the JSON format shown below.
93 | ```
94 | {"id": "large/10022/essayoncriticism_1505_librivox_64kb_mp3/essayoncriticism_01_pope_64kb_5", "start": 610.32, "duration": 19.76, "supervisions": [{"text": "Hail! bards triumphant! born in happier days; Immortal heirs of universal praise! Whose honors with increase of ages grow, As streams roll down, enlarging as they flow; Nations unborn your mighty names shall sound, [193] And worlds applaud that must not yet be found!"}], "recording": {"sources": [{"source": "download/librilight/large/10022/essayoncriticism_1505_librivox_64kb_mp3/essayoncriticism_01_pope_64kb.flac"}], "sampling_rate": 16000}, "type": "MonoCut"}
95 | ```
96 | Alternatively, you can use the LibriHeavy manifest available at this [URL](https://huggingface.co/datasets/ICTNLP/LibriHeavy_manifest). For your own datasets, process them into the same format.
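If you build your own manifest, each line is a single JSON object; below is a minimal sketch of writing one line (every field value is a placeholder for your data):

``` python
# Minimal sketch of writing one manifest line in the format above (placeholder values).
import json

cut = {
    "id": "my_dataset/utt_0001",
    "start": 0.0,            # offset of the cut inside the source file, in seconds
    "duration": 5.2,         # length of the cut, in seconds
    "supervisions": [{"text": "Transcript of the utterance."}],
    "recording": {
        # path resolved relative to the training script's data_path argument
        "sources": [{"source": "audio/utt_0001.flac"}],
        "sampling_rate": 16000,   # the training script assumes 16 kHz sources and resamples to 24 kHz
    },
    "type": "MonoCut",
}

with open("my_manifest.jsonl", "a") as f:
    f.write(json.dumps(cut) + "\n")
```
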
97 |
98 | ***Training Offline Model***
99 | ``` sh
100 | OUTPUT_DIR=./runs/libriheavy
101 | mkdir -p $OUTPUT_DIR
102 | LOG_FILE=${OUTPUT_DIR}/log
103 |
104 | BATCH_SIZE=8
105 | UPDATE_FREQ=8
106 | # assume 8 proc per node, then WORLD_SIZE * 8 * BATCH_SIZE * UPDATE_FREQ == 512
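# e.g. a single node (WORLD_SIZE=1) with 8 GPUs: 1 * 8 * 8 * 8 = 512 samples per optimizer update;
# with more nodes, scale BATCH_SIZE or UPDATE_FREQ down to keep the product at 512.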
107 |
108 | torchrun --nnodes ${WORLD_SIZE} --node_rank ${RANK} --nproc_per_node 8 --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \
109 | ./scripts/train_libriheavy.py \
110 | --training_cfg 0.1 \
111 | --num_hidden_layers 12 --diffloss_d 6 --noise_channels 128 \
112 | --dataloader_num_workers 8 \
113 | --dataloader_pin_memory True \
114 | --remove_unused_columns False \
115 | --label_names audio_inputs \
116 | --group_by_speech_length \
117 | --do_train \
118 | --do_eval \
119 | --eval_strategy steps \
120 | --eval_steps 10000 \
121 | --prediction_loss_only \
122 | --per_device_train_batch_size ${BATCH_SIZE} \
123 | --per_device_eval_batch_size 24 \
124 | --gradient_accumulation_steps ${UPDATE_FREQ} \
125 | --bf16 \
126 | --learning_rate 5e-4 \
127 | --weight_decay 0.01 \
128 | --adam_beta1 0.9 \
129 | --adam_beta2 0.999 \
130 | --adam_epsilon 1e-8 \
131 | --max_grad_norm 1.0 \
132 | --max_steps 300000 \
133 | --lr_scheduler_type "linear" \
134 | --warmup_steps 32000 \
135 | --logging_first_step \
136 | --logging_steps 100 \
137 | --save_steps 10000 \
138 | --save_total_limit 10 \
139 | --output_dir ${OUTPUT_DIR} \
140 | --report_to tensorboard \
141 | --disable_tqdm True \
142 | --ddp_timeout 3600 --overwrite_output_dir
143 |
144 | ```
145 |
146 | ***Training Streaming Model***
147 | ``` sh
148 | OUTPUT_DIR=./runs/libriheavy_stream
149 | mkdir -p $OUTPUT_DIR
150 | LOG_FILE=${OUTPUT_DIR}/log
151 |
152 | BATCH_SIZE=8
153 | UPDATE_FREQ=8
154 | # assume 8 proc per node, then WORLD_SIZE * 8 * BATCH_SIZE * UPDATE_FREQ == 512
155 |
156 | torchrun --nnodes ${WORLD_SIZE} --node_rank ${RANK} --nproc_per_node 8 --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \
157 | ./scripts/train_libriheavy_stream.py \
158 | --finetune_path ./runs/libriheavy/checkpoint-300000/model.safetensors \
159 | --stream_n 5 --stream_m 45 \
160 | --training_cfg 0.1 \
161 | --num_hidden_layers 12 --diffloss_d 6 --noise_channels 128 \
162 | --dataloader_num_workers 8 \
163 | --dataloader_pin_memory True \
164 | --remove_unused_columns False \
165 | --label_names audio_inputs \
166 | --group_by_speech_length \
167 | --do_train \
168 | --do_eval \
169 | --eval_strategy steps \
170 | --eval_steps 10000 \
171 | --prediction_loss_only \
172 | --per_device_train_batch_size ${BATCH_SIZE} \
173 | --per_device_eval_batch_size 24 \
174 | --gradient_accumulation_steps ${UPDATE_FREQ} \
175 | --bf16 \
176 | --learning_rate 3e-4 \
177 | --weight_decay 0.01 \
178 | --adam_beta1 0.9 \
179 | --adam_beta2 0.999 \
180 | --adam_epsilon 1e-8 \
181 | --max_grad_norm 1.0 \
182 | --max_steps 100000 \
183 | --lr_scheduler_type "linear" \
184 | --warmup_steps 10000 \
185 | --logging_first_step \
186 | --logging_steps 100 \
187 | --save_steps 10000 \
188 | --save_total_limit 10 \
189 | --output_dir ${OUTPUT_DIR} \
190 | --report_to tensorboard \
191 | --disable_tqdm True \
192 | --ddp_timeout 3600 --overwrite_output_dir
193 | ```
194 | ### BF16 Support
195 | When the `--bf16` flag is set, the model is loaded in bf16 for inference and in fp32 for training (i.e., mixed-precision training). To enable pure bf16 training, change
196 | https://github.com/ictnlp/SLED-TTS/blob/69a0a77d37180ec711a21f39f1b6bffa8b068072/scripts/train_libriheavy.py#L298
197 | to
198 | ```
199 | torch_dtype = torch.bfloat16 if training_args.bf16 else None
200 | ```
201 | However, Encodec should always execute in fp32 to maintain the precision of latents. Therefore, we load Encodec in fp32 and downcast the encoded latent to bf16.
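As a small illustrative sketch of this dtype policy (not the exact code in `sled/`; the latent tensor below is a random stand-in):

``` python
# Illustrative only: keep the codec in fp32, cast just the encoded latents to bf16.
import torch
from transformers import EncodecModel

codec = EncodecModel.from_pretrained("facebook/encodec_24khz")   # stays in fp32
assert next(codec.parameters()).dtype == torch.float32

latent_fp32 = torch.randn(1, 75, 128)             # stand-in for 1 s of encoded latents (75 Hz, dim 128)
latent_for_lm = latent_fp32.to(torch.bfloat16)    # downcast only when feeding the language model
```
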
202 |
203 |
204 | ## Citation
205 | If you have any questions, please feel free to submit an issue or contact `mazhengrui21b@ict.ac.cn`.
206 |
207 | If our work is useful for you, please cite as:
208 |
209 | ```
210 | @misc{ma2025efficientspeechlanguagemodeling,
211 | title={Efficient Speech Language Modeling via Energy Distance in Continuous Latent Space},
212 | author={Zhengrui Ma and Yang Feng and Chenze Shao and Fandong Meng and Jie Zhou and Min Zhang},
213 | year={2025},
214 | eprint={2505.13181},
215 | archivePrefix={arXiv},
216 | primaryClass={cs.CL},
217 | url={https://arxiv.org/abs/2505.13181},
218 | }
219 | ```
220 |
--------------------------------------------------------------------------------
/example_prompt.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/SLED-TTS/b8ed10d9953160efd8a0538b4ea5af80a57c9e96/example_prompt.flac
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "sled"
7 | version = "0.1.0"
8 | description = "Implementation of SLED"
9 | authors = [
10 | { name="Zhengrui Ma", email="mazhengrui21b@ict.ac.cn" }
11 | ]
12 | dependencies = [
13 | "torch==2.5.1",
14 | "torchaudio==2.5.1",
15 | "transformers==4.47.0",
16 | "datasets==3.1.0",
17 | "accelerate==1.2.0",
18 | "numpy==1.26.4",
19 | "librosa==0.10.2.post1",
20 | "soundfile==0.12.1"
21 | ]
22 |
23 | [tool.setuptools]
24 | packages = ["sled"]
25 |
--------------------------------------------------------------------------------
/scripts/eval_continue.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | from typing import Tuple
5 | from pathlib import Path
6 |
7 | import torch
8 | import torchaudio
9 |
10 |
11 | from datasets import load_dataset, Audio
12 | from transformers import RobertaTokenizer, AutoProcessor
13 | from transformers.data.data_collator import pad_without_fast_tokenizer_warning
14 | from accelerate.utils import set_seed
15 |
16 | import pdb
17 |
18 |
19 | from sled.sled import SpeechLlamaForCausalLM
20 |
21 | BANDWIDTH=6
22 | SAMPLING_RATE=24000
23 | STRIDE=320
24 | FREQ=75
25 | MIN_LEN=4.0
26 | MAX_LEN=10.0
27 | PROMPT_LEN=3
28 |
29 |
30 |
31 | logging.basicConfig(
32 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
33 | datefmt="%m/%d/%Y %H:%M:%S",
34 | level=logging.INFO,
35 | )
36 | logger = logging.getLogger(__name__)
37 |
38 |
39 |
40 | def adjust_length_to_model(length, max_sequence_length):
41 | assert max_sequence_length > 0
42 | if length <= 0 and max_sequence_length > 0:
43 | length = max_sequence_length
44 | elif 0 < max_sequence_length < length:
45 | length = max_sequence_length # No generation bigger than model size
46 | return length
47 |
48 |
49 | def filter_function(example):
50 | return ((len(example["audio"]["array"]) / SAMPLING_RATE) < MAX_LEN) and ((len(example["audio"]["array"]) / SAMPLING_RATE) > MIN_LEN)
51 |
52 |
53 |
54 | def main():
55 | parser = argparse.ArgumentParser()
56 |
57 | parser.add_argument(
58 | "--model_name_or_path",
59 | default=None,
60 | type=str,
61 | required=True,
62 | )
63 | parser.add_argument("--max_length", type=int, default=0)
64 | parser.add_argument(
65 | "--cfg",
66 | type=float,
67 | default=1.0,
68 | )
69 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
70 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
71 | parser.add_argument(
72 | "--fp16",
73 | action="store_true",
74 | )
75 | parser.add_argument(
76 | "--bf16",
77 | action="store_true",
78 | )
79 | parser.add_argument("--batch_size", type=int, default=32)
80 | args = parser.parse_args()
81 |
82 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83 | torch_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32)
84 | logger.warning(f"device: {device}, 16-bits inference: {args.fp16 or args.bf16}")
85 |
86 | if args.seed is not None:
87 | set_seed(args.seed)
88 |
89 | # Initialize the model and tokenizer
90 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
91 | eos_token_id = tokenizer.eos_token_id
92 |
93 | model = SpeechLlamaForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype)
94 | model.infer_cfg = args.cfg
95 | model.initialize_codec("facebook/encodec_24khz")
96 | model.to(device)
97 |
98 |
99 | processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
100 |
101 | assert tokenizer.pad_token is not None
102 | logger.info(f"tokenizer pad token: {tokenizer.pad_token}")
103 |
104 |
105 | max_seq_length = getattr(model.config, "max_position_embeddings", 0)
106 | args.max_length = adjust_length_to_model(args.max_length, max_sequence_length=max_seq_length)
107 | logger.info(args)
108 |
109 |
110 | eval_dataset = load_dataset("yoom618/librispeech_pc", split="test.clean")
111 | eval_dataset = eval_dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))
112 |
113 |
114 | logger.info(f"original eval dataset: {len(eval_dataset)} samples.")
115 | eval_dataset = eval_dataset.filter(filter_function)
116 | logger.info(f"filtered eval dataset: {len(eval_dataset)} samples.")
117 |
118 | tokenized_eval_dataset = eval_dataset.map(
119 | lambda example: tokenizer(example["text"]),
120 | batched=True,
121 | )
122 |
123 |
124 | batch_size = args.batch_size
125 | output_path = Path("eval_continue")
126 | output_path.mkdir(parents=True, exist_ok=True)
127 |
128 | with torch.no_grad():
129 | for i in range(0, len(tokenized_eval_dataset), batch_size):
130 | batch = tokenized_eval_dataset.select(range(i, min(i + batch_size, len(tokenized_eval_dataset))))
131 |
132 | input_ids = [{"input_ids":instance["input_ids"]} for instance in batch]
133 |
134 | encodes = pad_without_fast_tokenizer_warning(
135 | tokenizer,
136 | input_ids,
137 | padding=True,
138 | return_attention_mask=True,
139 | return_tensors="pt"
140 | )
141 |
142 | input_ids = encodes["input_ids"].to(device)
143 | attention_mask = encodes["attention_mask"].to(device)
144 | text_input_length = input_ids.shape[1]
145 |
146 | audio_arrays = [instance["audio"]["array"] for instance in batch]
147 | audio_inputs = processor(raw_audio=audio_arrays, sampling_rate=SAMPLING_RATE, return_tensors="pt") # 'padding_mask': b,t 'input_values': b,c,t
148 |
149 |
150 | encoder_outputs = model.codec.encode(audio_inputs["input_values"].to(device), audio_inputs["padding_mask"].to(device), bandwidth=BANDWIDTH) #1,b,r,t, 1 due to one chunk
151 | speech_inputs_embeds = model.codec.quantizer.decode(encoder_outputs.audio_codes[0].transpose(0, 1)) #b,d,t
152 |
153 | speech_attention_mask = audio_inputs["padding_mask"][..., ::STRIDE].to(device)
154 | assert speech_inputs_embeds.size(-1) == speech_attention_mask.size(-1)
155 | speech_inputs_embeds = speech_inputs_embeds.transpose(1,2).to(model.dtype) #b,t,d
156 |
157 |
158 | speech_inputs_embeds = speech_inputs_embeds[:,:FREQ * PROMPT_LEN,:]
159 | speech_attention_mask = speech_attention_mask[:,:FREQ * PROMPT_LEN]
160 | speech_input_length = speech_inputs_embeds.shape[1]
161 |
162 |
163 | new_attention_mask = torch.concat([attention_mask, speech_attention_mask], dim=1)
164 |
165 |
166 | output_sequences = model.generate(
167 | input_ids=input_ids,
168 | inputs_embeds=speech_inputs_embeds,
169 | attention_mask=new_attention_mask,
170 | max_length=args.max_length,
171 | do_sample=True,
172 | num_return_sequences=args.num_return_sequences,
173 | )
174 |
175 |
176 |
177 | new_embeds = output_sequences[1]
178 | generated_ids = output_sequences[0][:, text_input_length:]
179 |
180 | new_audio_values = model.codec.decoder(new_embeds.transpose(-1,-2).float())
181 |
182 | wav_len = (generated_ids.ne(eos_token_id).sum(dim=-1) + speech_input_length) * STRIDE
183 |
184 | for i in range(len(wav_len)):
185 | id = batch["id"][i]
186 | torchaudio.save(output_path / f"{id}.wav", new_audio_values[i][:,:wav_len[i]].cpu(), SAMPLING_RATE)
187 |
188 | return
189 |
190 |
191 | if __name__ == "__main__":
192 | main()
193 |
--------------------------------------------------------------------------------
/scripts/eval_reference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | from typing import Tuple
5 | from pathlib import Path
6 |
7 | import torch
8 | import torchaudio
9 |
10 |
11 | from datasets import load_dataset, Audio
12 | from transformers import AutoTokenizer, RobertaTokenizer, AutoProcessor
13 | from transformers.data.data_collator import pad_without_fast_tokenizer_warning
14 | from accelerate.utils import set_seed
15 |
16 | from sled.sled import SpeechLlamaForCausalLM
17 |
18 |
19 |
20 | import pdb
21 |
22 | BANDWIDTH=6
23 | SAMPLING_RATE=24000
24 | STRIDE=320
25 | FREQ=75
26 | MIN_LEN=4.0
27 | MAX_LEN=10.0
28 |
29 |
30 |
31 |
32 | logging.basicConfig(
33 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
34 | datefmt="%m/%d/%Y %H:%M:%S",
35 | level=logging.INFO,
36 | )
37 | logger = logging.getLogger(__name__)
38 |
39 |
40 |
41 | def adjust_length_to_model(length, max_sequence_length):
42 | assert max_sequence_length > 0
43 | if length <= 0 and max_sequence_length > 0:
44 | length = max_sequence_length
45 | elif 0 < max_sequence_length < length:
46 | length = max_sequence_length # No generation bigger than model size
47 | return length
48 |
49 |
50 | def filter_function(example):
51 | return ((len(example["audio"]["array"]) / SAMPLING_RATE) < MAX_LEN) and ((len(example["audio"]["array"]) / SAMPLING_RATE) > MIN_LEN)
52 |
53 |
54 |
55 | def main():
56 | parser = argparse.ArgumentParser()
57 |
58 | parser.add_argument(
59 | "--model_name_or_path",
60 | default=None,
61 | type=str,
62 | required=True,
63 | )
64 | parser.add_argument("--max_length", type=int, default=0)
65 | parser.add_argument(
66 | "--cfg",
67 | type=float,
68 | default=1.0,
69 | )
70 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
71 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
72 | parser.add_argument(
73 | "--fp16",
74 | action="store_true",
75 | )
76 | parser.add_argument(
77 | "--bf16",
78 | action="store_true",
79 | )
80 | parser.add_argument("--batch_size", type=int, default=1)
81 | args = parser.parse_args()
82 |
83 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
84 | torch_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32)
85 | logger.warning(f"device: {device}, 16-bits inference: {args.fp16 or args.bf16}")
86 |
87 | if args.seed is not None:
88 | set_seed(args.seed)
89 |
90 | # Initialize the model and tokenizer
91 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
92 | eos_token_id = tokenizer.eos_token_id
93 |
94 | model = SpeechLlamaForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype)
95 | model.infer_cfg = args.cfg
96 | model.initialize_codec("facebook/encodec_24khz")
97 | model.to(device)
98 |
99 |
100 | processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
101 |
102 | assert tokenizer.pad_token is not None
103 | logger.info(f"tokenizer pad token: {tokenizer.pad_token}")
104 |
105 |
106 | max_seq_length = getattr(model.config, "max_position_embeddings", 0)
107 | args.max_length = adjust_length_to_model(args.max_length, max_sequence_length=max_seq_length)
108 | logger.info(args)
109 |
110 |
111 | eval_dataset = load_dataset("yoom618/librispeech_pc", split="test.clean")
112 | eval_dataset = eval_dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))
113 |
114 |
115 | logger.info(f"original eval dataset: {len(eval_dataset)} samples.")
116 | eval_dataset = eval_dataset.filter(filter_function)
117 | logger.info(f"filtered eval dataset: {len(eval_dataset)} samples.")
118 |
119 |
120 | batch_size = args.batch_size
121 | assert batch_size == 1
122 |
123 | pathname = "eval_reference"
124 | output_path = Path(pathname)
125 | output_path_concat = Path(pathname + "_concat")
126 |
127 | output_path.mkdir(parents=True, exist_ok=True)
128 | output_path_concat.mkdir(parents=True, exist_ok=True)
129 | last_index = None
130 | with torch.no_grad():
131 | # iterate in reverse: each sample is prompted by the preceding utterance of the same speaker, or by the last utterance of its speaker block
132 | for i in range(len(eval_dataset) - 1, -1, -batch_size):
133 | current_sample = eval_dataset[i]
134 | current_speaker_id = current_sample["speaker_id"]
135 |
136 | if last_index is None:
137 | last_index = i
138 |
139 | if i != 0:
140 | prompt_sample = eval_dataset[i - 1]
141 | if prompt_sample["speaker_id"] != current_speaker_id:
142 | prompt_sample = eval_dataset[last_index]
143 | last_index = None
144 | else:
145 | prompt_sample = eval_dataset[last_index]
146 | last_index = None
147 |
148 |
149 | text_to_synthesize = current_sample["text"]
150 | prompt_text = prompt_sample["text"]
151 |
152 | prompt_audio = prompt_sample["audio"]["array"]
153 |
154 | input_text = [prompt_text + " " + text_to_synthesize]
155 |
156 | batch_encoded = tokenizer.batch_encode_plus(
157 | input_text,
158 | add_special_tokens=True,
159 | padding="longest",
160 | truncation=True,
161 | return_tensors="pt"
162 | )
163 |
164 | input_ids = batch_encoded["input_ids"].to(device)
165 | attention_mask = batch_encoded["attention_mask"].to(device)
166 | text_input_length = input_ids.shape[1]
167 |
168 | audio_arrays = [prompt_audio]
169 | audio_inputs = processor(raw_audio=audio_arrays, sampling_rate=SAMPLING_RATE, return_tensors="pt") # 'padding_mask': b,t 'input_values': b,c,t
170 |
171 |
172 | encoder_outputs = model.codec.encode(audio_inputs["input_values"].to(device), audio_inputs["padding_mask"].to(device), bandwidth=BANDWIDTH) #1,b,r,t, 1 due to one chunk
173 | speech_inputs_embeds = model.codec.quantizer.decode(encoder_outputs.audio_codes[0].transpose(0, 1)) #b,d,t
174 |
175 | speech_attention_mask = audio_inputs["padding_mask"][..., ::STRIDE].to(device)
176 | assert speech_inputs_embeds.size(-1) == speech_attention_mask.size(-1)
177 | speech_inputs_embeds = speech_inputs_embeds.transpose(1,2).to(model.dtype) #b,t,d
178 |
179 |
180 | speech_input_length = speech_inputs_embeds.shape[1]
181 |
182 |
183 | new_attention_mask = torch.concat([attention_mask, speech_attention_mask], dim=1)
184 |
185 |
186 | output_sequences = model.generate(
187 | input_ids=input_ids,
188 | inputs_embeds=speech_inputs_embeds,
189 | attention_mask=new_attention_mask,
190 | max_length=args.max_length,
191 | do_sample=True,
192 | num_return_sequences=args.num_return_sequences,
193 | )
194 |
195 |
196 |
197 | new_embeds = output_sequences[1]
198 | generated_ids = output_sequences[0][:, text_input_length:]
199 |
200 | new_audio_values = model.codec.decoder(new_embeds.transpose(-1,-2).float())
201 |
202 | wav_len = (generated_ids.ne(eos_token_id).sum(dim=-1) + speech_input_length) * STRIDE
203 |
204 |
205 | assert len(wav_len) == 1
206 | for i in range(len(wav_len)):
207 | id = current_sample["id"]
208 | torchaudio.save(output_path / f"{id}.wav", new_audio_values[i][:, speech_input_length* STRIDE:wav_len[i]].cpu(), SAMPLING_RATE)
209 | torchaudio.save(output_path_concat / f"{id}.wav", new_audio_values[i][:,:wav_len[i]].cpu(), SAMPLING_RATE)
210 | return
211 |
212 |
213 | if __name__ == "__main__":
214 | main()
215 |
--------------------------------------------------------------------------------
/scripts/eval_stream.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | from typing import Tuple
5 | from pathlib import Path
6 |
7 | import torch
8 | import torchaudio
9 |
10 |
11 | from datasets import load_dataset, Audio
12 | from transformers import RobertaTokenizer, AutoProcessor
13 | from transformers.data.data_collator import pad_without_fast_tokenizer_warning
14 | from accelerate.utils import set_seed
15 |
16 | from sled.sled_stream import SpeechLlamaForCausalLM
17 |
18 |
19 |
20 | import pdb
21 |
22 | BANDWIDTH=6
23 | SAMPLING_RATE=24000
24 | STRIDE=320
25 | FREQ=75
26 | MIN_LEN=4.0
27 | MAX_LEN=10.0
28 |
29 |
30 |
31 | logging.basicConfig(
32 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
33 | datefmt="%m/%d/%Y %H:%M:%S",
34 | level=logging.INFO,
35 | )
36 | logger = logging.getLogger(__name__)
37 |
38 |
39 |
40 | def adjust_length_to_model(length, max_sequence_length):
41 | assert max_sequence_length > 0
42 | if length <= 0 and max_sequence_length > 0:
43 | length = max_sequence_length
44 | elif 0 < max_sequence_length < length:
45 | length = max_sequence_length # No generation bigger than model size
46 | return length
47 |
48 |
49 | def filter_function(example):
50 | return ((len(example["audio"]["array"]) / SAMPLING_RATE) < MAX_LEN) and ((len(example["audio"]["array"]) / SAMPLING_RATE) > MIN_LEN)
51 |
52 |
53 |
54 |
55 |
56 | def main():
57 | parser = argparse.ArgumentParser()
58 |
59 | parser.add_argument(
60 | "--model_name_or_path",
61 | default=None,
62 | type=str,
63 | required=True,
64 | )
65 | parser.add_argument("--max_length", type=int, default=0)
66 | parser.add_argument(
67 | "--cfg",
68 | type=float,
69 | default=1.0,
70 | )
71 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
72 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
73 | parser.add_argument(
74 | "--fp16",
75 | action="store_true",
76 | )
77 | parser.add_argument(
78 | "--bf16",
79 | action="store_true",
80 | )
81 | parser.add_argument("--batch_size", type=int, default=32)
82 | args = parser.parse_args()
83 |
84 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85 | torch_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32)
86 | logger.warning(f"device: {device}, 16-bits inference: {args.fp16 or args.bf16}")
87 |
88 | if args.seed is not None:
89 | set_seed(args.seed)
90 |
91 | # Initialize the model and tokenizer
92 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
93 | eos_token_id = tokenizer.eos_token_id
94 |
95 | model = SpeechLlamaForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype)
96 | model.infer_cfg = args.cfg
97 | model.initialize_codec("facebook/encodec_24khz")
98 | model.to(device)
99 |
100 |
101 | processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
102 |
103 | assert tokenizer.pad_token is not None
104 | logger.info(f"tokenizer pad token: {tokenizer.pad_token}")
105 |
106 |
107 | max_seq_length = getattr(model.config, "max_position_embeddings", 0)
108 | args.max_length = adjust_length_to_model(args.max_length, max_sequence_length=max_seq_length)
109 | logger.info(args)
110 |
111 |
112 | eval_dataset = load_dataset("yoom618/librispeech_pc", split="test.clean")
113 | eval_dataset = eval_dataset.cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))
114 |
115 |
116 | logger.info(f"original eval dataset: {len(eval_dataset)} samples.")
117 | eval_dataset = eval_dataset.filter(filter_function)
118 | logger.info(f"filtered eval dataset: {len(eval_dataset)} samples.")
119 |
120 | tokenized_eval_dataset = eval_dataset.map(
121 | lambda example: tokenizer(example["text"]),
122 | batched=True,
123 | )
124 |
125 |
126 | batch_size = args.batch_size
127 | output_path = Path("eval_stream")
128 | output_path.mkdir(parents=True, exist_ok=True)
129 |
130 | with torch.no_grad():
131 | for i in range(0, len(tokenized_eval_dataset), batch_size):
132 | batch = tokenized_eval_dataset.select(range(i, min(i + batch_size, len(tokenized_eval_dataset))))
133 |
134 | input_ids = [{"input_ids":instance["input_ids"]} for instance in batch]
135 |
136 | encodes = pad_without_fast_tokenizer_warning(
137 | tokenizer,
138 | input_ids,
139 | padding=True,
140 | return_attention_mask=True,
141 | return_tensors="pt"
142 | )
143 |
144 | input_ids = encodes["input_ids"].to(device)
145 | attention_mask = encodes["attention_mask"].to(device)
146 | text_input_length = input_ids.shape[1]
147 |
148 |
149 |
150 | output_sequences = model.generate(
151 | input_ids=input_ids,
152 | attention_mask=attention_mask,
153 | max_length=args.max_length,
154 | do_sample=True,
155 | num_return_sequences=args.num_return_sequences,
156 | )
157 |
158 |
159 |
160 | new_embeds = output_sequences[1]
161 | generated_ids = output_sequences[0]
162 |
163 | new_audio_values = model.codec.decoder(new_embeds.transpose(-1,-2).float())
164 |
165 | wav_len = (generated_ids.ne(eos_token_id).sum(dim=-1)) * STRIDE
166 |
167 |
168 |
169 | for i in range(len(wav_len)):
170 | id = batch["id"][i]
171 | torchaudio.save(output_path / f"{id}.wav", new_audio_values[i][:,:wav_len[i]].cpu(), SAMPLING_RATE)
172 |
173 | return
174 |
175 |
176 | if __name__ == "__main__":
177 | main()
178 |
--------------------------------------------------------------------------------
/scripts/run_offline.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | from typing import Tuple
4 | from pathlib import Path
5 |
6 | import pdb
7 | import torch
8 | import torchaudio
9 | from accelerate.utils import set_seed
10 |
11 | from transformers import AutoTokenizer, PreTrainedTokenizerFast, RobertaTokenizer
12 | from sled.sled import SpeechLlamaForCausalLM
13 |
14 |
15 |
16 | SAMPLING_RATE=24000
17 |
18 | logging.basicConfig(
19 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
20 | datefmt="%m/%d/%Y %H:%M:%S",
21 | level=logging.INFO,
22 | )
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 |
27 | def adjust_length_to_model(length, max_sequence_length):
28 | assert max_sequence_length > 0
29 | if length <= 0 and max_sequence_length > 0:
30 | length = max_sequence_length
31 | elif 0 < max_sequence_length < length:
32 | length = max_sequence_length # No generation bigger than model size
33 | return length
34 |
35 | def main():
36 | parser = argparse.ArgumentParser()
37 |
38 | parser.add_argument(
39 | "--model_name_or_path",
40 | default=None,
41 | type=str,
42 | required=True,
43 | )
44 |
45 | parser.add_argument("--input", type=str, default="It was silent and gloomy, being tenanted solely by the captive and lighted by the dying embers of a fire which had been used for the purposes of cookery.")
46 | parser.add_argument("--max_length", type=int, default=0)
47 |
48 | parser.add_argument(
49 | "--temperature",
50 | type=float,
51 | default=1.0,
52 | help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
53 | )
54 | parser.add_argument(
55 | "--cfg",
56 | type=float,
57 | default=1.0,
58 | )
59 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
60 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
61 | parser.add_argument(
62 | "--fp16",
63 | action="store_true",
64 | )
65 | parser.add_argument(
66 | "--bf16",
67 | action="store_true",
68 | )
69 | args = parser.parse_args()
70 |
71 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72 | torch_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32)
73 | logger.warning(f"device: {device}, 16-bits inference: {args.fp16 or args.bf16}")
74 |
75 | if args.seed is not None:
76 | set_seed(args.seed)
77 |
78 | # Initialize the model and tokenizer
79 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
80 | eos_token_id = tokenizer.eos_token_id
81 |
82 | model = SpeechLlamaForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype)
83 | model.infer_cfg = args.cfg
84 | model.initialize_codec("facebook/encodec_24khz")
85 | model.to(device)
86 |
87 | assert tokenizer.pad_token is not None
88 | logger.info(f"tokenizer pad token: {tokenizer.pad_token}")
89 |
90 |
91 | max_seq_length = getattr(model.config, "max_position_embeddings", 0)
92 | args.max_length = adjust_length_to_model(args.max_length, max_sequence_length=max_seq_length)
93 | logger.info(args)
94 |
95 |
96 | input_text = args.input if args.input else input("Model input >>> ")
97 | input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
98 |
99 |
100 | output_sequences = model.generate(
101 | input_ids=input_ids,
102 | max_length=args.max_length,
103 | do_sample=True,
104 | num_return_sequences=args.num_return_sequences,
105 | )
106 |
107 |
108 | new_embeds = output_sequences[1]
109 | new_audio_values = model.codec.decoder(new_embeds.transpose(-1,-2).float())
110 |
111 |
112 | output_path = "output.wav"
113 | torchaudio.save(output_path, new_audio_values[0].cpu(), SAMPLING_RATE)
114 |
115 | return
116 |
117 |
118 | if __name__ == "__main__":
119 | main()
120 |
--------------------------------------------------------------------------------
/scripts/run_stream.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | from typing import Tuple
4 | from pathlib import Path
5 |
6 | import pdb
7 | import torch
8 | import torchaudio
9 | from accelerate.utils import set_seed
10 |
11 | from transformers import AutoTokenizer, PreTrainedTokenizerFast, RobertaTokenizer
12 | from sled.sled_stream import SpeechLlamaForCausalLM
13 |
14 |
15 |
16 | SAMPLING_RATE=24000
17 |
18 | logging.basicConfig(
19 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
20 | datefmt="%m/%d/%Y %H:%M:%S",
21 | level=logging.INFO,
22 | )
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 |
27 | def adjust_length_to_model(length, max_sequence_length):
28 | assert max_sequence_length > 0
29 | if length <= 0 and max_sequence_length > 0:
30 | length = max_sequence_length
31 | elif 0 < max_sequence_length < length:
32 | length = max_sequence_length # No generation bigger than model size
33 | return length
34 |
35 | def main():
36 | parser = argparse.ArgumentParser()
37 |
38 | parser.add_argument(
39 | "--model_name_or_path",
40 | default=None,
41 | type=str,
42 | required=True,
43 | )
44 |
45 | parser.add_argument("--input", type=str, default="It was silent and gloomy, being tenanted solely by the captive and lighted by the dying embers of a fire which had been used for the purposes of cookery.")
46 | parser.add_argument("--max_length", type=int, default=0)
47 |
48 | parser.add_argument(
49 | "--temperature",
50 | type=float,
51 | default=1.0,
52 | help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
53 | )
54 | parser.add_argument(
55 | "--cfg",
56 | type=float,
57 | default=1.0,
58 | )
59 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
60 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
61 | parser.add_argument(
62 | "--fp16",
63 | action="store_true",
64 | )
65 | parser.add_argument(
66 | "--bf16",
67 | action="store_true",
68 | )
69 | args = parser.parse_args()
70 |
71 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72 | torch_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32)
73 | logger.warning(f"device: {device}, 16-bits inference: {args.fp16 or args.bf16}")
74 |
75 | if args.seed is not None:
76 | set_seed(args.seed)
77 |
78 | # Initialize the model and tokenizer
79 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
80 | eos_token_id = tokenizer.eos_token_id
81 |
82 | model = SpeechLlamaForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype)
83 | model.infer_cfg = args.cfg
84 | model.initialize_codec("facebook/encodec_24khz")
85 | model.to(device)
86 |
87 | assert tokenizer.pad_token is not None
88 | logger.info(f"tokenizer pad token: {tokenizer.pad_token}")
89 |
90 |
91 | max_seq_length = getattr(model.config, "max_position_embeddings", 0)
92 | args.max_length = adjust_length_to_model(args.max_length, max_sequence_length=max_seq_length)
93 | logger.info(args)
94 |
95 |
96 | input_text = args.input if args.input else input("Model input >>> ")
97 | input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
98 |
99 |
100 | output_sequences = model.generate(
101 | input_ids=input_ids,
102 | max_length=args.max_length,
103 | do_sample=True,
104 | num_return_sequences=args.num_return_sequences,
105 | )
106 |
107 |
108 | new_embeds = output_sequences[1]
109 | new_audio_values = model.codec.decoder(new_embeds.transpose(-1,-2).float())
110 |
111 |
112 | output_path = "output.wav"
113 | torchaudio.save(output_path, new_audio_values[0].cpu(), SAMPLING_RATE)
114 |
115 | return
116 |
117 |
118 | if __name__ == "__main__":
119 | main()
120 |
--------------------------------------------------------------------------------
/scripts/run_voice_clone.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | from typing import Tuple
4 | from pathlib import Path
5 |
6 | import pdb
7 | import torch
8 | import torchaudio
9 | from accelerate.utils import set_seed
10 |
11 | from transformers import AutoTokenizer, PreTrainedTokenizerFast, RobertaTokenizer, AutoProcessor
12 | from sled.sled import SpeechLlamaForCausalLM
13 |
14 |
15 | BANDWIDTH=6
16 | STRIDE=320
17 | SAMPLING_RATE=24000
18 |
19 | logging.basicConfig(
20 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
21 | datefmt="%m/%d/%Y %H:%M:%S",
22 | level=logging.INFO,
23 | )
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 |
28 | def adjust_length_to_model(length, max_sequence_length):
29 | assert max_sequence_length > 0
30 | if length <= 0 and max_sequence_length > 0:
31 | length = max_sequence_length
32 | elif 0 < max_sequence_length < length:
33 | length = max_sequence_length # No generation bigger than model size
34 | return length
35 |
36 | def main():
37 | parser = argparse.ArgumentParser()
38 |
39 | parser.add_argument(
40 | "--model_name_or_path",
41 | default=None,
42 | type=str,
43 | required=True,
44 | )
45 | parser.add_argument("--prompt_audio", type=str)
46 | parser.add_argument("--prompt_text", type=str)
47 | parser.add_argument("--input", type=str, default="It was silent and gloomy, being tenanted solely by the captive and lighted by the dying embers of a fire which had been used for the purposes of cookery.")
48 | parser.add_argument("--max_length", type=int, default=0)
49 |
50 | parser.add_argument(
51 | "--temperature",
52 | type=float,
53 | default=1.0,
54 | help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
55 | )
56 | parser.add_argument(
57 | "--cfg",
58 | type=float,
59 | default=1.0,
60 | )
61 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
62 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
63 | parser.add_argument(
64 | "--fp16",
65 | action="store_true",
66 | )
67 | parser.add_argument(
68 | "--bf16",
69 | action="store_true",
70 | )
71 | args = parser.parse_args()
72 |
73 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74 | torch_dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32)
75 | logger.warning(f"device: {device}, 16-bits inference: {args.fp16 or args.bf16}")
76 |
77 | if args.seed is not None:
78 | set_seed(args.seed)
79 |
80 | # Initialize the model and tokenizer
81 | tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
82 | eos_token_id = tokenizer.eos_token_id
83 |
84 | model = SpeechLlamaForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=torch_dtype)
85 | model.infer_cfg = args.cfg
86 | model.initialize_codec("facebook/encodec_24khz")
87 | model.to(device)
88 |
89 | processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
90 |
91 | assert tokenizer.pad_token is not None
92 | logger.info(f"tokenizer pad token: {tokenizer.pad_token}")
93 |
94 |
95 | max_seq_length = getattr(model.config, "max_position_embeddings", 0)
96 | args.max_length = adjust_length_to_model(args.max_length, max_sequence_length=max_seq_length)
97 | logger.info(args)
98 |
99 | prompt_text = args.prompt_text if args.prompt_text else input("Prompt Text >>> ")
100 | prompt_audio = args.prompt_audio
101 | waveform, sample_rate = torchaudio.load(prompt_audio, normalize=True)
102 | if sample_rate != SAMPLING_RATE:
103 |     waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=SAMPLING_RATE)(waveform)
104 | waveform = waveform.squeeze().numpy()  # the Encodec feature extractor expects a 1-D numpy array
105 |
106 |
107 | input_text = args.input if args.input else input("Model input >>> ")
108 | #input_ids = tokenizer.encode(input_text + " " + prompt_text, return_tensors='pt').to(device)
109 | input_text = [prompt_text + " " + input_text]
110 |
111 | batch_encoded = tokenizer.batch_encode_plus(
112 | input_text,
113 | add_special_tokens=True,
114 | padding="longest",
115 | truncation=True,
116 | return_tensors="pt",
117 | )
118 |
119 | input_ids = batch_encoded["input_ids"].to(device)
120 | attention_mask = batch_encoded["attention_mask"].to(device)
121 | text_input_length = input_ids.shape[1]
122 |
123 |
124 | audio_arrays = [waveform]
125 | audio_inputs = processor(raw_audio=audio_arrays, sampling_rate=SAMPLING_RATE, return_tensors="pt") # 'padding_mask': b,t 'input_values': b,c,t
126 |
127 |
128 | encoder_outputs = model.codec.encode(audio_inputs["input_values"].to(device), audio_inputs["padding_mask"].to(device), bandwidth=BANDWIDTH) #1,b,r,t, 1 due to one chunk
129 | speech_inputs_embeds = model.codec.quantizer.decode(encoder_outputs.audio_codes[0].transpose(0, 1)) #b,d,t
130 |
131 | speech_attention_mask = audio_inputs["padding_mask"][..., ::STRIDE].to(device)
132 | assert speech_inputs_embeds.size(-1) == speech_attention_mask.size(-1)
133 | speech_inputs_embeds = speech_inputs_embeds.transpose(1,2).to(model.dtype) #b,t,d
134 |
135 | speech_input_length = speech_inputs_embeds.shape[1]
136 |
137 |
138 | new_attention_mask = torch.concat([attention_mask, speech_attention_mask], dim=1)
139 |
140 |
141 | output_sequences = model.generate(
142 | input_ids=input_ids,
143 | inputs_embeds=speech_inputs_embeds,
144 | attention_mask=new_attention_mask,
145 | max_length=args.max_length,
146 | do_sample=True,
147 | num_return_sequences=args.num_return_sequences,
148 | )
149 |
150 |
151 | new_embeds = output_sequences[1]
152 | new_audio_values = model.codec.decoder(new_embeds.transpose(-1,-2).float())
153 |
154 |
155 | output_path = "output.wav"
156 | torchaudio.save(output_path, new_audio_values[0][:, speech_input_length* STRIDE:].cpu(), SAMPLING_RATE)
157 |
158 | return
159 |
160 |
161 | if __name__ == "__main__":
162 | main()
163 |
--------------------------------------------------------------------------------
/scripts/train_libriheavy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import json
5 | import logging
6 | import pathlib
7 | from pathlib import Path
8 | from dataclasses import dataclass, field, asdict
9 | from typing import Dict, Optional, Sequence, List, Union
10 |
11 | import numpy as np
12 | import torch
13 | import datasets
14 | import transformers
15 |
16 | import soundfile as sf
17 | import librosa
18 |
19 |
20 |
21 | from transformers import (
22 | HfArgumentParser,
23 | set_seed,
24 | )
25 | from transformers.testing_utils import CaptureLogger
26 | from transformers.trainer_utils import get_last_checkpoint
27 | from transformers.data.data_collator import pad_without_fast_tokenizer_warning
28 |
29 | from transformers import AutoProcessor, AutoTokenizer, RobertaTokenizer
30 |
31 |
32 | from sled.sled import SpeechLlamaConfig, SpeechLlamaForCausalLM
33 | from sled.trainer_libriheavy import SpeechLlamaTrainer
34 |
35 | logger = logging.getLogger(__name__)
36 |
37 |
38 |
39 | SAMPLING_RATE=24000
40 | SAMPLING_RATE_LIBRIHEAVY=16000
41 | SAMPLING_RATE_TOKENIZER=75
42 |
43 |
44 | @dataclass
45 | class ArchArguments:
46 | # --------------------------------------------------------------------------
47 | # Llama Arguments
48 | hidden_size: int = 1024
49 | intermediate_size: int = 2752
50 | num_hidden_layers: int = 12
51 | num_attention_heads: int = 16
52 | num_key_value_heads: Optional[int] = None
53 | hidden_act: str = "silu"
54 | max_position_embeddings: int = 2048
55 | initializer_range: float = 0.02
56 | rms_norm_eps: float = 1e-6
57 | use_cache: bool = True
58 | pad_token_id: Optional[int] = None
59 | bos_token_id: int = 0
60 | eos_token_id: int = 2
61 | pretraining_tp: int = 1
62 | tie_word_embeddings: bool = False
63 | rope_theta: float = 10000.0
64 | rope_scaling: Optional[float] = None
65 | attention_bias: bool = False
66 | attention_dropout: float = 0.1
67 | mlp_bias: bool = False
68 | vocab_size: int = 32000
69 | dropout: float = 0.1
70 | activation_dropout: float = 0.1
71 |
72 | # --------------------------------------------------------------------------
73 | # Score Arguments
74 | vae_embed_dim: int = 128
75 | diffloss_d: int = 3
76 | diffloss_w: int = 1024
77 | training_cfg: float = 0.0
78 | noise_channels: int = 128
79 |
80 |
81 |
82 | @dataclass
83 | class ModelArguments:
84 | # --------------------------------------------------------------------------
85 | # Codec & Tokenizer Arguments
86 | codec: str = "facebook/encodec_24khz"
87 | tokenizer: str = "/path/tokenizer_bpe_libriheavy"
88 |
89 |
90 |
91 | @dataclass
92 | class DataArguments:
93 | data_path: str = "/path/libriheavy"
94 | train_manifest: List[str] = field(default_factory=lambda: ["/path/libriheavy/cases_and_punc/libriheavy_cuts_large.jsonl", "/path/libriheavy/cases_and_punc/libriheavy_cuts_medium.jsonl", "/path/libriheavy/cases_and_punc/libriheavy_cuts_small.jsonl"])
95 | eval_manifest: List[str] = field(default_factory=lambda: ["/path/libriheavy/cases_and_punc/filtered2/libriheavy_cuts_dev.jsonl"])
96 | pad_to_multiple_of: Optional[int] = None
97 | max_train_samples: Optional[int] = field(
98 | default=None,
99 | metadata={
100 | "help": (
101 | "For debugging purposes or quicker training, truncate the number of training examples to this "
102 | "value if set."
103 | )
104 | },
105 | )
106 | max_eval_samples: Optional[int] = field(
107 | default=None,
108 | metadata={
109 | "help": (
110 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
111 | "value if set."
112 | )
113 | },
114 | )
115 |
116 |
117 | @dataclass
118 | class TrainingArguments(transformers.TrainingArguments):
119 | group_by_speech_length: bool = field(default=True)
120 |
121 |
122 |
123 | @dataclass
124 | class DataCollatorForSupervisedDataset(object):
125 | """Collate examples for supervised fine-tuning."""
126 |
127 | tokenizer: transformers.PreTrainedTokenizer
128 | processor: transformers.PreTrainedTokenizer
129 | data_path: str
130 | pad_to_multiple_of: Optional[int] = None
131 |
132 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
133 | input_ids = [{"input_ids":instance["input_ids"]} for instance in instances]
134 |
135 | batch = pad_without_fast_tokenizer_warning(
136 | self.tokenizer,
137 | input_ids,
138 | padding=True,
139 | pad_to_multiple_of=self.pad_to_multiple_of,
140 | return_attention_mask=True,
141 | return_tensors="pt"
142 | )
143 |
144 | audio_files = [instance["recording"]["sources"][0]["source"] for instance in instances]
145 | durations = [instance["duration"] for instance in instances]
146 | start_times = [instance["start"] for instance in instances]
147 |
148 | audio_arrays = [self.load_audio(file_path, start, duration) for file_path, start, duration in zip(audio_files, start_times, durations)]
149 |
150 | audio_inputs = self.processor(raw_audio=audio_arrays, sampling_rate=SAMPLING_RATE, return_tensors="pt") # 'padding_mask': b,t 'input_values': b,c,t
151 |
152 | batch["audio_inputs"] = audio_inputs
153 |
154 | return batch
155 |
156 | def load_audio(self, file_path: str, start: float, duration: float) -> np.array:
157 | abs_path = Path(self.data_path) / file_path
158 | audio, sampling_rate = sf.read(abs_path, start=int(start * SAMPLING_RATE_LIBRIHEAVY), stop=int((start + duration) * SAMPLING_RATE_LIBRIHEAVY))
159 | resampled_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=SAMPLING_RATE)
160 | return resampled_audio
161 |
162 |
163 | def load_manifest(file_paths):
164 | all_data = []
165 | for file_path in file_paths:
166 | with open(file_path, "r") as f:
167 | all_data.extend([json.loads(line) for line in f])
168 | return all_data
169 |
170 |
171 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
172 | arch_args, model_args, data_args, training_args) -> Dict:
173 | """Make dataset and collator for supervised fine-tuning."""
174 | train_dataset = None
175 | eval_dataset = None
176 |
177 | if training_args.do_train:
178 | train_manifest = load_manifest(data_args.train_manifest)
179 |
180 | if data_args.max_train_samples is not None:
181 | train_manifest = train_manifest[:data_args.max_train_samples]
182 |
183 | train_dataset = train_manifest
184 |
185 |
186 | if training_args.do_eval:
187 | eval_manifest = load_manifest(data_args.eval_manifest)
188 |
189 | if data_args.max_eval_samples is not None:
190 | eval_manifest = eval_manifest[:data_args.max_eval_samples]
191 |
192 | eval_dataset = eval_manifest
193 |
194 | def tokenize_example(example):
195 | text = example["supervisions"][0]["text"]
196 | return tokenizer(text)
197 |
198 | tokenized_train_dataset = None
199 | tokenized_eval_dataset = None
200 |
201 | with training_args.main_process_first(desc="dataset map tokenization"):
202 | if training_args.do_train and train_dataset:
203 | tokenized_train_dataset = [
204 | {**example, **tokenize_example(example)} for example in train_dataset
205 | ]
206 |
207 | if training_args.do_eval and eval_dataset:
208 | tokenized_eval_dataset = [
209 | {**example, **tokenize_example(example)} for example in eval_dataset
210 | ]
211 |
212 |
213 | def filter_function(example):
214 | file_path = example["recording"]["sources"][0]["source"]
215 | abs_path = Path(data_args.data_path) / file_path
216 | exists = abs_path.exists()
217 | return ((len(example['input_ids']) + int(example["duration"] * SAMPLING_RATE_TOKENIZER)) < arch_args.max_position_embeddings) and exists
218 |
219 | if tokenized_train_dataset is not None:
220 | logger.info(f"original train dataset: {len(tokenized_train_dataset)} samples.")
221 | tokenized_train_dataset = [ex for ex in tokenized_train_dataset if filter_function(ex)]
222 | logger.info(f"filtered train dataset: {len(tokenized_train_dataset)} samples.")
223 |
224 |
225 | if tokenized_eval_dataset is not None:
226 | logger.info(f"original eval dataset: {len(tokenized_eval_dataset)} samples.")
227 | tokenized_eval_dataset = [ex for ex in tokenized_eval_dataset if filter_function(ex)]
228 | logger.info(f"filtered eval dataset: {len(tokenized_eval_dataset)} samples.")
229 |
230 | processor = AutoProcessor.from_pretrained(model_args.codec)
231 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, processor=processor, pad_to_multiple_of=data_args.pad_to_multiple_of, data_path=data_args.data_path)
232 |
233 | return tokenized_train_dataset, tokenized_eval_dataset, data_collator
234 |
235 |
236 | def train(attn_implementation="sdpa"):
237 |
238 | parser = HfArgumentParser(
239 | (ArchArguments, ModelArguments, DataArguments, TrainingArguments))
240 | arch_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses()
241 |
242 |
243 | # Setup logging
244 | logging.basicConfig(
245 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
246 | datefmt="%m/%d/%Y %H:%M:%S",
247 | handlers=[logging.StreamHandler(sys.stdout)],
248 | )
249 |
250 | if training_args.should_log:
251 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
252 | transformers.utils.logging.set_verbosity_info()
253 |
254 | log_level = training_args.get_process_log_level()
255 | logger.setLevel(log_level)
256 | datasets.utils.logging.set_verbosity(log_level)
257 | transformers.utils.logging.set_verbosity(log_level)
258 | transformers.utils.logging.enable_default_handler()
259 | transformers.utils.logging.enable_explicit_format()
260 |
261 | # Log on each process the small summary:
262 | logger.warning(
263 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
264 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16 or training_args.bf16}"
265 | )
266 | logger.info(f"Training/evaluation parameters {training_args}")
267 |
268 |
269 | # Detecting last checkpoint.
270 | last_checkpoint = None
271 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
272 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
273 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
274 | raise ValueError(
275 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
276 | "Use --overwrite_output_dir to overcome."
277 | )
278 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
279 | logger.info(
280 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
281 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
282 | )
283 |
284 | # Set seed before initializing model.
285 | set_seed(training_args.seed)
286 |
287 |
288 | tokenizer = RobertaTokenizer.from_pretrained(
289 | model_args.tokenizer,
290 | padding_side="left",
291 | add_eos_token=True,
292 | )
293 | arch_args.vocab_size = tokenizer.vocab_size
294 | model_config = SpeechLlamaConfig(**asdict(arch_args))
295 | logger.info(f"config: {model_config}")
296 |
297 |
298 | torch_dtype = None #torch.bfloat16 if training_args.bf16 else None
299 |
300 | model = SpeechLlamaForCausalLM._from_config(model_config, attn_implementation=attn_implementation, torch_dtype=torch_dtype)
301 | model.initialize_codec(model_args)
302 |
303 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
304 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
305 |
306 |
307 | train_dataset, eval_dataset, data_collator = make_supervised_data_module(tokenizer, arch_args, model_args, data_args, training_args)
308 | trainer = SpeechLlamaTrainer(
309 | model=model,
310 | args=training_args,
311 | data_collator=data_collator,
312 | train_dataset=train_dataset if training_args.do_train else None,
313 | eval_dataset=eval_dataset if training_args.do_eval else None,
314 | tokenizer=tokenizer,
315 | )
316 |
317 | # Training
318 | if training_args.do_train:
319 | checkpoint = None
320 | if training_args.resume_from_checkpoint is not None:
321 | checkpoint = training_args.resume_from_checkpoint
322 | elif last_checkpoint is not None:
323 | checkpoint = last_checkpoint
324 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
325 | trainer.save_model() # Saves the tokenizer too for easy upload
326 |
327 | metrics = train_result.metrics
328 |
329 | max_train_samples = (
330 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
331 | )
332 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
333 |
334 | trainer.log_metrics("train", metrics)
335 | trainer.save_metrics("train", metrics)
336 | trainer.save_state()
337 |
338 | # Evaluation
339 | if training_args.do_eval:
340 | logger.info("*** Evaluate ***")
341 |
342 | metrics = trainer.evaluate()
343 |
344 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
345 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
346 |
347 |
348 | trainer.log_metrics("eval", metrics)
349 | trainer.save_metrics("eval", metrics)
350 |
351 | if __name__ == "__main__":
352 | train(attn_implementation="sdpa")
353 |
--------------------------------------------------------------------------------
/scripts/train_libriheavy_stream.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import json
5 | import logging
6 | import pathlib
7 | from pathlib import Path
8 | from dataclasses import dataclass, field, asdict
9 | from typing import Dict, Optional, Sequence, List, Union
10 |
11 | import numpy as np
12 | import torch
13 | import datasets
14 | import transformers
15 | from safetensors.torch import load_file
16 |
17 | import soundfile as sf
18 | import librosa
19 |
20 |
21 |
22 | from transformers import (
23 | HfArgumentParser,
24 | set_seed,
25 | )
26 | from transformers.testing_utils import CaptureLogger
27 | from transformers.trainer_utils import get_last_checkpoint
28 | from transformers.data.data_collator import pad_without_fast_tokenizer_warning
29 |
30 | from transformers import AutoProcessor, AutoTokenizer, RobertaTokenizer
31 |
32 |
33 | from sled.sled_stream import SpeechLlamaConfig, SpeechLlamaForCausalLM
34 | from sled.trainer_libriheavy import SpeechLlamaTrainer
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 |
39 |
40 | SAMPLING_RATE=24000
41 | SAMPLING_RATE_LIBRIHEAVY=16000
42 | SAMPLING_RATE_TOKENIZER=75
43 |
44 |
45 | @dataclass
46 | class ArchArguments:
47 | # --------------------------------------------------------------------------
48 | # Llama Arguments
49 | hidden_size: int = 1024
50 | intermediate_size: int = 2752
51 | num_hidden_layers: int = 12
52 | num_attention_heads: int = 16
53 | num_key_value_heads: Optional[int] = None
54 | hidden_act: str = "silu"
55 | max_position_embeddings: int = 2048
56 | initializer_range: float = 0.02
57 | rms_norm_eps: float = 1e-6
58 | use_cache: bool = True
59 | pad_token_id: Optional[int] = None
60 | bos_token_id: int = 0
61 | eos_token_id: int = 2
62 | pretraining_tp: int = 1
63 | tie_word_embeddings: bool = False
64 | rope_theta: float = 10000.0
65 | rope_scaling: Optional[float] = None
66 | attention_bias: bool = False
67 | attention_dropout: float = 0.1
68 | mlp_bias: bool = False
69 | vocab_size: int = 32000
70 | dropout: float = 0.1
71 | activation_dropout: float = 0.1
72 |
73 | # --------------------------------------------------------------------------
74 | # Score Arguments
75 | vae_embed_dim: int = 128
76 | diffloss_d: int = 3
77 | diffloss_w: int = 1024
78 | training_cfg: float = 0.0
79 | noise_channels: int = 128
80 |
81 | # --------------------------------------------------------------------------
82 | # Stream Arguments
83 | stream_n: int = 5
84 | stream_m: int = 45
85 |
86 |
87 |
88 | @dataclass
89 | class ModelArguments:
90 | # --------------------------------------------------------------------------
91 | # Codec & Tokenizer Arguments
92 | codec: str = "facebook/encodec_24khz"
93 | tokenizer: str = "/path/tokenizer_bpe_libriheavy"
94 |
95 |
96 |
97 | @dataclass
98 | class DataArguments:
99 | finetune_path: str = None
100 | data_path: str = "/path/libriheavy"
101 | train_manifest: List[str] = field(default_factory=lambda: ["/path/libriheavy/cases_and_punc/libriheavy_cuts_large.jsonl", "/path/libriheavy/cases_and_punc/libriheavy_cuts_medium.jsonl", "/path/libriheavy/cases_and_punc/libriheavy_cuts_small.jsonl"])
102 | eval_manifest: List[str] = field(default_factory=lambda: ["/path/libriheavy/cases_and_punc/filtered2/libriheavy_cuts_dev.jsonl"])
103 | pad_to_multiple_of: Optional[int] = None
104 | max_train_samples: Optional[int] = field(
105 | default=None,
106 | metadata={
107 | "help": (
108 | "For debugging purposes or quicker training, truncate the number of training examples to this "
109 | "value if set."
110 | )
111 | },
112 | )
113 | max_eval_samples: Optional[int] = field(
114 | default=None,
115 | metadata={
116 | "help": (
117 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
118 | "value if set."
119 | )
120 | },
121 | )
122 |
123 |
124 | @dataclass
125 | class TrainingArguments(transformers.TrainingArguments):
126 | group_by_speech_length: bool = field(default=True)
127 |
128 |
129 |
130 | @dataclass
131 | class DataCollatorForSupervisedDataset(object):
132 | """Collate examples for supervised fine-tuning."""
133 |
134 | tokenizer: transformers.PreTrainedTokenizer
135 | processor: transformers.PreTrainedTokenizer
136 | data_path: str
137 | pad_to_multiple_of: Optional[int] = None
138 |
139 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
140 | input_ids = [{"input_ids":instance["input_ids"]} for instance in instances]
141 |
142 | batch = pad_without_fast_tokenizer_warning(
143 | self.tokenizer,
144 | input_ids,
145 | padding=True,
146 | pad_to_multiple_of=self.pad_to_multiple_of,
147 | return_attention_mask=True,
148 | return_tensors="pt"
149 | )
150 |
151 | audio_files = [instance["recording"]["sources"][0]["source"] for instance in instances]
152 | durations = [instance["duration"] for instance in instances]
153 | start_times = [instance["start"] for instance in instances]
154 |
155 | audio_arrays = [self.load_audio(file_path, start, duration) for file_path, start, duration in zip(audio_files, start_times, durations)]
156 |
157 | audio_inputs = self.processor(raw_audio=audio_arrays, sampling_rate=SAMPLING_RATE, return_tensors="pt") # 'padding_mask': b,t 'input_values': b,c,t
158 |
159 | batch["audio_inputs"] = audio_inputs
160 |
161 | return batch
162 |
163 |     def load_audio(self, file_path: str, start: float, duration: float) -> np.ndarray:
164 | abs_path = Path(self.data_path) / file_path
165 | audio, sampling_rate = sf.read(abs_path, start=int(start * SAMPLING_RATE_LIBRIHEAVY), stop=int((start + duration) * SAMPLING_RATE_LIBRIHEAVY))
166 | resampled_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=SAMPLING_RATE)
167 | return resampled_audio
168 |
169 |
170 | def load_manifest(file_paths):
171 | all_data = []
172 | for file_path in file_paths:
173 | with open(file_path, "r") as f:
174 | all_data.extend([json.loads(line) for line in f])
175 | return all_data
176 |
177 |
178 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
179 | arch_args, model_args, data_args, training_args) -> Dict:
180 | """Make dataset and collator for supervised fine-tuning."""
181 | train_dataset = None
182 | eval_dataset = None
183 |
184 | if training_args.do_train:
185 | train_manifest = load_manifest(data_args.train_manifest)
186 |
187 | if data_args.max_train_samples is not None:
188 | train_manifest = train_manifest[:data_args.max_train_samples]
189 |
190 | train_dataset = train_manifest
191 |
192 |
193 | if training_args.do_eval:
194 | eval_manifest = load_manifest(data_args.eval_manifest)
195 |
196 | if data_args.max_eval_samples is not None:
197 | eval_manifest = eval_manifest[:data_args.max_eval_samples]
198 |
199 | eval_dataset = eval_manifest
200 |
201 | def tokenize_example(example):
202 | text = example["supervisions"][0]["text"]
203 | return tokenizer(text)
204 |
205 | tokenized_train_dataset = None
206 | tokenized_eval_dataset = None
207 |
208 | with training_args.main_process_first(desc="dataset map tokenization"):
209 | if training_args.do_train and train_dataset:
210 | tokenized_train_dataset = [
211 | {**example, **tokenize_example(example)} for example in train_dataset
212 | ]
213 |
214 | if training_args.do_eval and eval_dataset:
215 | tokenized_eval_dataset = [
216 | {**example, **tokenize_example(example)} for example in eval_dataset
217 | ]
218 |
219 |
220 | def filter_function(example):
221 | file_path = example["recording"]["sources"][0]["source"]
222 | abs_path = Path(data_args.data_path) / file_path
223 | exists = abs_path.exists()
224 | return ((len(example['input_ids']) + int(example["duration"] * SAMPLING_RATE_TOKENIZER)) < arch_args.max_position_embeddings) and exists
225 |
226 | if tokenized_train_dataset is not None:
227 | logger.info(f"original train dataset: {len(tokenized_train_dataset)} samples.")
228 | tokenized_train_dataset = [ex for ex in tokenized_train_dataset if filter_function(ex)]
229 | logger.info(f"filtered train dataset: {len(tokenized_train_dataset)} samples.")
230 |
231 |
232 | if tokenized_eval_dataset is not None:
233 | logger.info(f"original eval dataset: {len(tokenized_eval_dataset)} samples.")
234 | tokenized_eval_dataset = [ex for ex in tokenized_eval_dataset if filter_function(ex)]
235 | logger.info(f"filtered eval dataset: {len(tokenized_eval_dataset)} samples.")
236 |
237 | processor = AutoProcessor.from_pretrained(model_args.codec)
238 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, processor=processor, pad_to_multiple_of=data_args.pad_to_multiple_of, data_path=data_args.data_path)
239 |
240 | return tokenized_train_dataset, tokenized_eval_dataset, data_collator
241 |
242 |
243 | def train(attn_implementation="sdpa"):
244 |
245 | parser = HfArgumentParser(
246 | (ArchArguments, ModelArguments, DataArguments, TrainingArguments))
247 | arch_args, model_args, data_args, training_args = parser.parse_args_into_dataclasses()
248 |
249 |
250 | # Setup logging
251 | logging.basicConfig(
252 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
253 | datefmt="%m/%d/%Y %H:%M:%S",
254 | handlers=[logging.StreamHandler(sys.stdout)],
255 | )
256 |
257 | if training_args.should_log:
258 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
259 | transformers.utils.logging.set_verbosity_info()
260 |
261 | log_level = training_args.get_process_log_level()
262 | logger.setLevel(log_level)
263 | datasets.utils.logging.set_verbosity(log_level)
264 | transformers.utils.logging.set_verbosity(log_level)
265 | transformers.utils.logging.enable_default_handler()
266 | transformers.utils.logging.enable_explicit_format()
267 |
268 | # Log on each process the small summary:
269 | logger.warning(
270 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
271 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16 or training_args.bf16}"
272 | )
273 | logger.info(f"Training/evaluation parameters {training_args}")
274 |
275 |
276 | # Detecting last checkpoint.
277 | last_checkpoint = None
278 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
279 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
280 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
281 | raise ValueError(
282 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
283 | "Use --overwrite_output_dir to overcome."
284 | )
285 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
286 | logger.info(
287 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
288 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
289 | )
290 |
291 | # Set seed before initializing model.
292 | set_seed(training_args.seed)
293 |
294 |
295 | tokenizer = RobertaTokenizer.from_pretrained(
296 | model_args.tokenizer,
297 | padding_side="right",
298 | add_eos_token=True,
299 | )
300 | arch_args.vocab_size = tokenizer.vocab_size
301 | model_config = SpeechLlamaConfig(**asdict(arch_args))
302 | logger.info(f"config: {model_config}")
303 |
304 |
305 | torch_dtype = None #torch.bfloat16 if training_args.bf16 else None
306 |
307 | model = SpeechLlamaForCausalLM._from_config(model_config, attn_implementation=attn_implementation, torch_dtype=torch_dtype)
308 |
309 | state_dict = load_file(data_args.finetune_path)
310 | missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
311 | model.initialize_codec(model_args)
312 |
313 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
314 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
315 |
316 |
317 | train_dataset, eval_dataset, data_collator = make_supervised_data_module(tokenizer, arch_args, model_args, data_args, training_args)
318 | trainer = SpeechLlamaTrainer(
319 | model=model,
320 | args=training_args,
321 | data_collator=data_collator,
322 | train_dataset=train_dataset if training_args.do_train else None,
323 | eval_dataset=eval_dataset if training_args.do_eval else None,
324 | tokenizer=tokenizer,
325 | )
326 |
327 | # Training
328 | if training_args.do_train:
329 | checkpoint = None
330 | if training_args.resume_from_checkpoint is not None:
331 | checkpoint = training_args.resume_from_checkpoint
332 | elif last_checkpoint is not None:
333 | checkpoint = last_checkpoint
334 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
335 | trainer.save_model() # Saves the tokenizer too for easy upload
336 |
337 | metrics = train_result.metrics
338 |
339 | max_train_samples = (
340 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
341 | )
342 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
343 |
344 | trainer.log_metrics("train", metrics)
345 | trainer.save_metrics("train", metrics)
346 | trainer.save_state()
347 |
348 | # Evaluation
349 | if training_args.do_eval:
350 | logger.info("*** Evaluate ***")
351 |
352 | metrics = trainer.evaluate()
353 |
354 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
355 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
356 |
357 |
358 | trainer.log_metrics("eval", metrics)
359 | trainer.save_metrics("eval", metrics)
360 |
361 | if __name__ == "__main__":
362 | train(attn_implementation="sdpa")
363 |
--------------------------------------------------------------------------------
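
A note on the length filter in the two training scripts: Encodec at 24 kHz emits 75 latent frames per second (hop size 320, i.e. 24000 / 320), so `filter_function` keeps an example only if its BPE tokens plus `duration * 75` speech frames fit inside `max_position_embeddings` (2048 by default, roughly 27 seconds of audio plus a short transcript). A minimal sketch of that budget check, with constants copied from the script and made-up example numbers:

``` python
# Minimal sketch of the length budget enforced by filter_function above.
SAMPLING_RATE_TOKENIZER = 75      # Encodec 24 kHz latent rate: 24000 / 320 hop
MAX_POSITION_EMBEDDINGS = 2048    # default from ArchArguments

def fits_in_context(num_text_tokens: int, duration_s: float) -> bool:
    speech_frames = int(duration_s * SAMPLING_RATE_TOKENIZER)
    return (num_text_tokens + speech_frames) < MAX_POSITION_EMBEDDINGS

print(fits_in_context(num_text_tokens=60, duration_s=20.0))  # True  (60 + 1500 < 2048)
print(fits_in_context(num_text_tokens=60, duration_s=30.0))  # False (60 + 2250 >= 2048)
```
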
/shell_scripts/eval.sh:
--------------------------------------------------------------------------------
1 |
2 | CHECKPOINT=/path/to/checkpoint
3 | BSZ=1
4 | CFG=2.0
5 |
6 | python ./scripts/eval_stream.py \
7 | --model_name_or_path ${CHECKPOINT} \
8 | --batch_size ${BSZ} --cfg ${CFG} --seed 0
9 |
--------------------------------------------------------------------------------
/shell_scripts/run.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | CHECKPOINT=/path/to/checkpoint
4 | CFG=2.0
5 |
6 | # Offline Inference
7 | python scripts/run_offline.py \
8 | --model_name_or_path ${CHECKPOINT} \
9 | --cfg ${CFG} \
10 | --input "My remark pleases him, but I soon prove to him that it is not the right way to speak. However perfect may have been the language of that ancient writer." \
11 | --seed 42
12 |
13 | # Or Streaming Inference
14 | python scripts/run_stream.py \
15 | --model_name_or_path ${CHECKPOINT} \
16 | --cfg ${CFG} \
17 | --input "My remark pleases him, but I soon prove to him that it is not the right way to speak. However perfect may have been the language of that ancient writer." \
18 | --seed 42
19 |
20 | # Or Voice Clone
21 | python scripts/run_voice_clone.py \
22 | --prompt_text "Were I in the warm room with all the splendor and magnificence!" \
23 | --prompt_audio "example_prompt.flac" \
24 | --model_name_or_path ${CHECKPOINT} \
25 | --cfg ${CFG} \
26 | --input "Perhaps the other trees from the forest will come to look at me!" \
27 | --seed 42
28 |
--------------------------------------------------------------------------------
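
The `CFG` variable in these scripts sets the classifier-free guidance scale used at inference. In `sled/sled.py`, generation runs a conditional branch and a second branch whose text prompt is masked out, then mixes the two hidden states before sampling the next latent; `CFG=1.0` disables guidance. A hedged sketch of that mixing rule (illustrative tensors, not the scripts' actual plumbing):

``` python
import torch

# Sketch of how the --cfg scale is applied, mirroring the mixing in
# ScoreLossZ.sample / SpeechLlamaForCausalLM._sample in sled/sled.py.
def apply_cfg(z_cond: torch.Tensor, z_uncond: torch.Tensor, cfg: float) -> torch.Tensor:
    # cfg = 1.0 leaves z_cond unchanged; larger values push harder toward the text condition.
    return z_uncond + cfg * (z_cond - z_uncond)

z_cond, z_uncond = torch.randn(2, 1024), torch.randn(2, 1024)
guided = apply_cfg(z_cond, z_uncond, cfg=2.0)
```
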
/shell_scripts/train_libriheavy.sh:
--------------------------------------------------------------------------------
1 | OUTPUT_DIR=./runs/libriheavy
2 | mkdir -p $OUTPUT_DIR
3 | LOG_FILE=${OUTPUT_DIR}/log
4 |
5 | BATCH_SIZE=8
6 | UPDATE_FREQ=8
7 | # assume 8 proc per node, then WORLD_SIZE * 8 * BATCH_SIZE * UPDATE_FREQ == 512
8 |
9 | torchrun --nnodes ${WORLD_SIZE} --node_rank ${RANK} --nproc_per_node 8 --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \
10 | ./scripts/train_libriheavy.py \
11 | --training_cfg 0.1 \
12 | --num_hidden_layers 12 --diffloss_d 6 --noise_channels 128 \
13 | --dataloader_num_workers 8 \
14 | --dataloader_pin_memory True \
15 | --remove_unused_columns False \
16 | --label_names audio_inputs \
17 | --group_by_speech_length \
18 | --do_train \
19 | --do_eval \
20 | --eval_strategy steps \
21 | --eval_steps 10000 \
22 | --prediction_loss_only \
23 | --per_device_train_batch_size ${BATCH_SIZE} \
24 | --per_device_eval_batch_size 24 \
25 | --gradient_accumulation_steps ${UPDATE_FREQ} \
26 | --bf16 \
27 | --learning_rate 5e-4 \
28 | --weight_decay 0.01 \
29 | --adam_beta1 0.9 \
30 | --adam_beta2 0.999 \
31 | --adam_epsilon 1e-8 \
32 | --max_grad_norm 1.0 \
33 | --max_steps 300000 \
34 | --lr_scheduler_type "linear" \
35 | --warmup_steps 32000 \
36 | --logging_first_step \
37 | --logging_steps 100 \
38 | --save_steps 10000 \
39 | --save_total_limit 10 \
40 | --output_dir ${OUTPUT_DIR} \
41 | --report_to tensorboard \
42 | --disable_tqdm True \
43 | --ddp_timeout 3600 --overwrite_output_dir \
44 | 2>&1 |tee -a ${LOG_FILE}
45 |
--------------------------------------------------------------------------------
/shell_scripts/train_libriheavy_stream.sh:
--------------------------------------------------------------------------------
1 | OUTPUT_DIR=./runs/libriheavy_stream
2 | mkdir -p $OUTPUT_DIR
3 | LOG_FILE=${OUTPUT_DIR}/log
4 |
5 | BATCH_SIZE=8
6 | UPDATE_FREQ=8
7 | # assume 8 proc per node, then WORLD_SIZE * 8 * BATCH_SIZE * UPDATE_FREQ == 512
8 |
9 | torchrun --nnodes ${WORLD_SIZE} --node_rank ${RANK} --nproc_per_node 8 --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} \
10 | ./scripts/train_libriheavy_stream.py \
11 | --finetune_path ./runs/libriheavy/checkpoint-300000/model.safetensors \
12 | --stream_n 5 --stream_m 45 \
13 | --training_cfg 0.1 \
14 | --num_hidden_layers 12 --diffloss_d 6 --noise_channels 128 \
15 | --dataloader_num_workers 8 \
16 | --dataloader_pin_memory True \
17 | --remove_unused_columns False \
18 | --label_names audio_inputs \
19 | --group_by_speech_length \
20 | --do_train \
21 | --do_eval \
22 | --eval_strategy steps \
23 | --eval_steps 10000 \
24 | --prediction_loss_only \
25 | --per_device_train_batch_size ${BATCH_SIZE} \
26 | --per_device_eval_batch_size 24 \
27 | --gradient_accumulation_steps ${UPDATE_FREQ} \
28 | --bf16 \
29 | --learning_rate 3e-4 \
30 | --weight_decay 0.01 \
31 | --adam_beta1 0.9 \
32 | --adam_beta2 0.999 \
33 | --adam_epsilon 1e-8 \
34 | --max_grad_norm 1.0 \
35 | --max_steps 100000 \
36 | --lr_scheduler_type "linear" \
37 | --warmup_steps 10000 \
38 | --logging_first_step \
39 | --logging_steps 100 \
40 | --save_steps 10000 \
41 | --save_total_limit 10 \
42 | --output_dir ${OUTPUT_DIR} \
43 | --report_to tensorboard \
44 | --disable_tqdm True \
45 | --ddp_timeout 3600 --overwrite_output_dir \
46 | 2>&1 |tee -a ${LOG_FILE}
47 |
--------------------------------------------------------------------------------
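
A note on the streaming flags: `--stream_n 5 --stream_m 45` interleaves 5 text tokens with 45 speech latents per cycle (see `interleave_embeddings_and_mask_efficient` in `sled/sled_stream.py`), and at the 75 Hz latent rate that corresponds to about 0.6 s of audio generated for every 5 text tokens received. A one-line check (illustrative only):

``` python
# Relation between the streaming flags and the audio chunk duration.
STREAM_N, STREAM_M, LATENT_RATE = 5, 45, 75   # 75 Hz is the Encodec 24 kHz latent rate
chunk_seconds = STREAM_M / LATENT_RATE
print(f"{STREAM_M} latents per {STREAM_N} text tokens = {chunk_seconds:.1f} s of audio")  # 0.6 s
```
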
/sled.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.4
2 | Name: sled
3 | Version: 0.1.0
4 | Summary: Implementation of SLED
5 | Author-email: Zhengrui Ma
6 | Requires-Dist: torch==2.5.1
7 | Requires-Dist: torchaudio==2.5.1
8 | Requires-Dist: transformers==4.47.0
9 | Requires-Dist: datasets==3.1.0
10 | Requires-Dist: accelerate==1.2.0
11 | Requires-Dist: numpy==1.26.4
12 | Requires-Dist: librosa==0.10.2.post1
13 | Requires-Dist: soundfile==0.12.1
14 |
--------------------------------------------------------------------------------
/sled.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | README.md
2 | pyproject.toml
3 | scripts/eval_continue.py
4 | scripts/eval_reference.py
5 | scripts/eval_stream.py
6 | scripts/run_offline.py
7 | scripts/run_stream.py
8 | scripts/train_libriheavy.py
9 | scripts/train_libriheavy_stream.py
10 | sled/__init__.py
11 | sled/energy_distance.py
12 | sled/modeling_llama_with_dropout.py
13 | sled/sled.py
14 | sled/sled_stream.py
15 | sled/trainer.py
16 | sled/trainer_libriheavy.py
17 | sled.egg-info/PKG-INFO
18 | sled.egg-info/SOURCES.txt
19 | sled.egg-info/dependency_links.txt
20 | sled.egg-info/requires.txt
21 | sled.egg-info/top_level.txt
--------------------------------------------------------------------------------
/sled.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/sled.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | torch==2.5.1
2 | torchaudio==2.5.1
3 | transformers==4.47.0
4 | datasets==3.1.0
5 | accelerate==1.2.0
6 | numpy==1.26.4
7 | librosa==0.10.2.post1
8 | soundfile==0.12.1
9 |
--------------------------------------------------------------------------------
/sled.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | sled
2 |
--------------------------------------------------------------------------------
/sled/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/SLED-TTS/b8ed10d9953160efd8a0538b4ea5af80a57c9e96/sled/__init__.py
--------------------------------------------------------------------------------
/sled/energy_distance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.checkpoint import checkpoint
4 | import math
5 | import pdb
6 |
7 |
8 | class ScoreLossZ(nn.Module):
9 | """Energy Distance Loss"""
10 | def __init__(self, target_channels, z_channels, depth, width, beta=1, gamma=1, noise_channels=16):
11 | super(ScoreLossZ, self).__init__()
12 | self.noise_channels = noise_channels
13 | self.net = SimpleMLPAdaLN(
14 | in_channels=self.noise_channels,
15 | model_channels=width,
16 | out_channels=target_channels,
17 | z_channels=z_channels,
18 | num_res_blocks=depth,
19 | )
20 | self.beta = beta
21 | self.gamma = gamma
22 |
23 |
24 | def forward(self, target, z, mask=None, additional_targets=None):
25 | noise_1 = torch.randn((z.shape[0], self.noise_channels), dtype=z.dtype, device=z.device)
26 | sample_1 = self.net(noise_1, z)
27 | noise_2 = torch.randn((z.shape[0], self.noise_channels), dtype=z.dtype, device=z.device)
28 | sample_2 = self.net(noise_2, z)
29 |
30 | score = self.energy_score(sample_1, sample_2, target, additional_targets)
31 | loss = - score
32 | if mask is not None:
33 | loss = (loss * mask).sum() / mask.sum()
34 | return loss.mean()
35 |
36 | def energy_distance(self, x_1, x_2):
37 | return torch.pow(torch.linalg.norm(x_1 - x_2, ord=2, dim=-1), self.beta)
38 |
39 | def energy_score(self, sample_1, sample_2, target, additional_targets = None):
40 | distance_1 = self.energy_distance(sample_1, target)
41 | distance_2 = self.energy_distance(sample_2, target)
42 | variance = self.energy_distance(sample_1, sample_2)
43 | score = variance - distance_1 - distance_2
44 | return score
45 |
46 | def kernel_distance(self, x_1, x_2):
47 | return - torch.exp(- torch.sum(torch.pow(x_1 - x_2, 2), dim = -1).div(2 * self.gamma**2))
48 |
49 | def kernel_score(self, sample_1, sample_2, target):
50 | distance_1 = self.kernel_distance(sample_1, target)
51 | distance_2 = self.kernel_distance(sample_2, target)
52 | variance = self.kernel_distance(sample_1, sample_2)
53 | score = variance - distance_1 - distance_2
54 | return score
55 |
56 | def sample(self, z, temperature=1.0, cfg=1.0):
57 | if cfg != 1.0:
58 | z_1, z_2 = z.chunk(2, dim=0)
59 | z = z_1 * cfg + (1 - cfg) * z_2
60 |
61 | noise = torch.randn((z.shape[0], self.noise_channels), dtype=z.dtype, device=z.device)
62 | return self.net(noise, z)
63 |
64 |
65 |
66 | def modulate(x, shift, scale):
67 | return x * (1 + scale) + shift
68 |
69 | class ResBlock(nn.Module):
70 | """
71 |     A residual MLP block modulated by adaptive LayerNorm (AdaLN); the channel count is unchanged.
72 | :param channels: the number of input channels.
73 | """
74 |
75 | def __init__(
76 | self,
77 | channels
78 | ):
79 | super().__init__()
80 | self.channels = channels
81 |
82 | self.in_ln = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=False)
83 | self.mlp = nn.Sequential(
84 | nn.Linear(channels, channels, bias=True),
85 | nn.SiLU(),
86 | nn.Linear(channels, channels, bias=True),
87 | )
88 | self.noise_ln = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=True)
89 |
90 | self.adaLN_modulation = nn.Sequential(
91 | nn.SiLU(),
92 | nn.Linear(channels, 3 * channels, bias=True)
93 | )
94 |
95 | def forward(self, x, y):
96 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(self.noise_ln(y)).chunk(3, dim=-1)
97 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
98 | h = self.mlp(h)
99 | return x + gate_mlp * h
100 |
101 |
102 | class FinalLayer(nn.Module):
103 | """
104 | The final layer of DiT.
105 | """
106 | def __init__(self, model_channels, out_channels):
107 | super().__init__()
108 | self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
109 | self.linear = nn.Linear(model_channels, out_channels, bias=True)
110 | self.adaLN_modulation = nn.Sequential(
111 | nn.SiLU(),
112 | nn.Linear(model_channels, 2 * model_channels, bias=True)
113 | )
114 |
115 | def forward(self, x, c):
116 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
117 | x = modulate(self.norm_final(x), shift, scale)
118 | x = self.linear(x)
119 | return x
120 |
121 |
122 | class SimpleMLPAdaLN(nn.Module):
123 | """
124 | The MLP for Energy Distance Loss.
125 | :param in_channels: channels in the input Tensor.
126 | :param model_channels: base channel count for the model.
127 | :param out_channels: channels in the output Tensor.
128 | :param z_channels: channels in the condition.
129 | :param num_res_blocks: number of residual blocks per downsample.
130 | """
131 |
132 | def __init__(
133 | self,
134 | in_channels,
135 | model_channels,
136 | out_channels,
137 | z_channels,
138 | num_res_blocks,
139 | grad_checkpointing=False
140 | ):
141 | super().__init__()
142 |
143 | self.in_channels = in_channels
144 | self.model_channels = model_channels
145 | self.out_channels = out_channels
146 | self.num_res_blocks = num_res_blocks
147 | self.grad_checkpointing = grad_checkpointing
148 |
149 | self.cond_embed = nn.Linear(z_channels, model_channels)
150 |
151 | self.input_proj = nn.Linear(in_channels, model_channels)
152 |
153 | res_blocks = []
154 | for i in range(num_res_blocks):
155 | res_blocks.append(ResBlock(
156 | model_channels,
157 | ))
158 |
159 | self.res_blocks = nn.ModuleList(res_blocks)
160 | self.final_layer = FinalLayer(model_channels, out_channels)
161 |
162 | self.initialize_weights()
163 |
164 | def initialize_weights(self):
165 | def _basic_init(module):
166 | if isinstance(module, nn.Linear):
167 | torch.nn.init.xavier_uniform_(module.weight)
168 | if module.bias is not None:
169 | nn.init.constant_(module.bias, 0)
170 | module._is_hf_initialized = True
171 | self.apply(_basic_init)
172 |
173 | # Zero-out adaLN modulation layers
174 | for block in self.res_blocks:
175 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
176 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
177 |
178 | # Zero-out output layers
179 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
180 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
181 | nn.init.constant_(self.final_layer.linear.weight, 0)
182 | nn.init.constant_(self.final_layer.linear.bias, 0)
183 |
184 | def forward(self, x, c):
185 | """
186 | Apply the model to an input batch.
187 |         :param x: an [N x noise_channels] Tensor of sampled noise
188 |             (projected to `y` and used as the AdaLN conditioning signal).
189 |         :param c: conditioning from the AR transformer, an [N x z_channels] Tensor.
190 |         :return: an [N x out_channels] Tensor of predicted latents.
191 | """
192 | y = self.input_proj(x)
193 | x = self.cond_embed(c)
194 |
195 |
196 |
197 | if self.grad_checkpointing and not torch.jit.is_scripting():
198 | raise NotImplementedError
199 | else:
200 | for block in self.res_blocks:
201 | x = block(x, y)
202 |
203 | return self.final_layer(x, y)
204 |
205 | def forward_with_cfg(self, x, t, c, cfg_scale):
206 | half = x[: len(x) // 2]
207 | combined = torch.cat([half, half], dim=0)
208 | model_out = self.forward(combined, t, c)
209 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
210 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
211 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
212 | eps = torch.cat([half_eps, half_eps], dim=0)
213 | return torch.cat([eps, rest], dim=1)
214 |
--------------------------------------------------------------------------------
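
For orientation, `ScoreLossZ` implements a generalized energy-distance objective: it draws two independent samples from the noise-conditioned MLP and is trained to maximize ||x1 - x2||^beta - ||x1 - y||^beta - ||x2 - y||^beta against the target latent y. A self-contained sketch of that estimator on plain tensors (the real loss conditions both samples on the AR hidden state z):

``` python
import torch

# Minimal sketch of the (negative) energy score used by ScoreLossZ.forward above:
# two independent generator samples are compared against the ground-truth latent.
def energy_score(sample_1, sample_2, target, beta=1.0):
    dist = lambda a, b: torch.linalg.norm(a - b, ord=2, dim=-1).pow(beta)
    return dist(sample_1, sample_2) - dist(sample_1, target) - dist(sample_2, target)

target = torch.randn(4, 128)                        # ground-truth latent
s1, s2 = torch.randn(4, 128), torch.randn(4, 128)   # two independent samples
loss = -energy_score(s1, s2, target).mean()         # lower when samples cluster around the target
print(loss)
```
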
/sled/sled.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union, Dict, Any
2 | from pathlib import Path
3 | import torch
4 | import torch.nn as nn
5 |
6 | from transformers import EncodecModel
7 | from .modeling_llama_with_dropout import LlamaConfig, LlamaModel, LlamaForCausalLM
8 | from .modeling_llama_with_dropout import _prepare_4d_causal_attention_mask_with_cache_position
9 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
10 | from transformers.modeling_outputs import CausalLMOutputWithPast
11 | from transformers.generation.utils import GenerateDecoderOnlyOutput
12 | from transformers.utils.import_utils import is_torchdynamo_compiling
13 |
14 | from .energy_distance import ScoreLossZ
15 | import copy
16 |
17 |
18 | class ModifiedGenerateDecoderOnlyOutput(GenerateDecoderOnlyOutput):
19 | features = None
20 |
21 |
22 |
23 | class SpeechLlamaConfig(LlamaConfig):
24 | model_type = "speech_llama"
25 |
26 | def __init__(
27 | self,
28 | vae_embed_dim=128,
29 | diffloss_d=3,
30 | diffloss_w=1024,
31 | training_cfg=0.0,
32 | noise_channels=128,
33 | **kwargs,
34 | ):
35 | self.vae_embed_dim = vae_embed_dim
36 | self.diffloss_d = diffloss_d
37 | self.diffloss_w = diffloss_w
38 | self.training_cfg = training_cfg
39 | self.noise_channels = noise_channels
40 |
41 |
42 | super().__init__(**kwargs)
43 |
44 |
45 |
46 |
47 |
48 | class SpeechLlamaForCausalLM(LlamaForCausalLM):
49 | config_class = SpeechLlamaConfig
50 |
51 | def __init__(self, config):
52 | super(LlamaForCausalLM, self).__init__(config)
53 | self.model = LlamaModel(config)
54 | self.codec = None
55 |
56 | # --------------------------------------------------------------------------
57 | # Speech Embedding
58 | self.token_embed_dim = config.vae_embed_dim
59 | self.hidden_size = config.hidden_size
60 | self.z_proj = nn.Linear(self.token_embed_dim, self.hidden_size, bias=True)
61 | self.z_proj_ln = nn.LayerNorm(self.hidden_size, eps=1e-6)
62 | self.eos_head = nn.Linear(self.hidden_size, 1, bias=False)
63 | self.embed_mean = None
64 | self.embed_std = None
65 | self.training_cfg = config.training_cfg
66 |
67 | # --------------------------------------------------------------------------
68 | # Score Loss
69 | self.scoreloss = ScoreLossZ(
70 | target_channels=self.token_embed_dim,
71 | z_channels=self.hidden_size,
72 | width=config.diffloss_w,
73 | depth=config.diffloss_d,
74 | noise_channels=config.noise_channels,
75 | )
76 |
77 | # --------------------------------------------------------------------------
78 | # BCE Loss
79 | pos_weight = torch.Tensor([100.]) # Weight of EOS is equal to 100
80 | self.bceloss = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight)
81 |
82 |
83 | # Initialize weights and apply final processing
84 | self.post_init() # Check whether affects init of LlamaModel
85 |
86 | def initialize_codec(self, model_args):
87 | if hasattr(model_args, "codec"):
88 | self.codec = EncodecModel.from_pretrained(model_args.codec, torch_dtype=torch.float32) # keep encodec model in fp32
89 | else:
90 | self.codec = EncodecModel.from_pretrained(model_args, torch_dtype=torch.float32)
91 | for param in self.codec.parameters():
92 | param.requires_grad = False
93 |
94 |
95 |
96 | def _init_weights(self, m):
97 | if isinstance(m, nn.Linear):
98 | # we use xavier_uniform following official JAX ViT:
99 | torch.nn.init.xavier_uniform_(m.weight)
100 | if isinstance(m, nn.Linear) and m.bias is not None:
101 | nn.init.constant_(m.bias, 0)
102 | elif isinstance(m, nn.LayerNorm):
103 | if m.bias is not None:
104 | nn.init.constant_(m.bias, 0)
105 | if m.weight is not None:
106 | nn.init.constant_(m.weight, 1.0)
107 |
108 |
109 | def prepare_inputs_labels_for_multimodal(
110 | self,
111 | input_ids,
112 | position_ids,
113 | attention_mask,
114 | past_key_values,
115 | audio_inputs,
116 | ):
117 | if self.training_cfg > 0.0:
118 | bsz = attention_mask.size(0)
119 | random_mask = torch.rand(bsz) < self.training_cfg
120 | cfg_mask = torch.zeros_like(attention_mask, dtype=torch.bool)
121 | cfg_mask[:, :-1] = random_mask[:, None]
122 | attention_mask[cfg_mask] = 0
123 |
124 | text_inputs_embeds = self.model.embed_tokens(input_ids)
125 |
126 | with torch.no_grad():
127 | encoder_outputs = self.codec.encode(audio_inputs["input_values"], audio_inputs["padding_mask"], bandwidth=6) #1,b,r,t, 1 due to one chunk
128 | speech_inputs_embeds = self.codec.quantizer.decode(encoder_outputs.audio_codes[0].transpose(0, 1)) #b,d,t, always fp32
129 |
130 | speech_attention_mask = audio_inputs["padding_mask"][..., ::320]
131 | assert speech_inputs_embeds.size(-1) == speech_attention_mask.size(-1)
132 | speech_inputs_embeds = speech_inputs_embeds.transpose(1,2).to(self.model.dtype) #b,t,d, support full bf16 training
133 |
134 | net_speech_inputs_embeds = self.z_proj(speech_inputs_embeds)
135 | new_inputs_embeds = torch.concat([text_inputs_embeds, net_speech_inputs_embeds], dim=1) #bsz, seq_len, hidden_size
136 | new_attention_mask = torch.concat([attention_mask, speech_attention_mask], dim=1)
137 | new_labels = speech_inputs_embeds
138 |
139 | return None, position_ids, new_attention_mask, past_key_values, new_inputs_embeds, new_labels, speech_attention_mask
140 |
141 | def forward(
142 | self,
143 | input_ids: torch.LongTensor = None,
144 | attention_mask: Optional[torch.Tensor] = None,
145 | position_ids: Optional[torch.LongTensor] = None,
146 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
147 | inputs_embeds: Optional[torch.FloatTensor] = None,
148 | labels: Optional[torch.LongTensor] = None,
149 | use_cache: Optional[bool] = None,
150 | output_attentions: Optional[bool] = None,
151 | output_hidden_states: Optional[bool] = None,
152 | return_dict: Optional[bool] = None,
153 | cache_position: Optional[torch.LongTensor] = None,
154 | num_logits_to_keep: int = 0,
155 | audio_inputs: Optional[Dict[str, Any]] = None,
156 | speech_inputs_embeds: Optional[torch.FloatTensor] = None,
157 | speech_attention_mask: Optional[torch.Tensor] = None,
158 | ) -> Union[Tuple, CausalLMOutputWithPast]:
159 |
160 |
161 | if audio_inputs is not None:
162 | (
163 | _,
164 | position_ids,
165 | attention_mask,
166 | past_key_values,
167 | inputs_embeds,
168 | labels,
169 | speech_attention_mask,
170 | ) = self.prepare_inputs_labels_for_multimodal(
171 | input_ids,
172 | position_ids,
173 | attention_mask,
174 | past_key_values,
175 | audio_inputs,
176 | )
177 | else:
178 | assert not ((input_ids is None) and (inputs_embeds is None))
179 | if input_ids is not None and inputs_embeds is None:
180 | inputs_embeds = self.model.embed_tokens(input_ids)
181 | elif input_ids is None and inputs_embeds is not None:
182 | inputs_embeds = self.z_proj(inputs_embeds)
183 | else:
184 | inputs_embeds = torch.cat([self.model.embed_tokens(input_ids), self.z_proj(inputs_embeds)], dim=1)
185 |
186 | inputs_embeds = self.z_proj_ln(inputs_embeds)
187 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
188 | output_hidden_states = (
189 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
190 | )
191 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
192 |
193 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
194 | outputs = self.model(
195 | input_ids=None,
196 | attention_mask=attention_mask,
197 | position_ids=position_ids,
198 | past_key_values=past_key_values,
199 | inputs_embeds=inputs_embeds,
200 | use_cache=use_cache,
201 | output_attentions=output_attentions,
202 | output_hidden_states=output_hidden_states,
203 | return_dict=return_dict,
204 | cache_position=cache_position,
205 | )
206 |
207 | hidden_states = outputs[0]
208 | logits = hidden_states[:, -num_logits_to_keep:, :]
209 |
210 | loss = None
211 |
212 | if labels is not None:
213 | bsz, speech_len, _ = labels.shape
214 |
215 | z = logits[:, -(speech_len+1):-1, :] #bsz, speech_len, hid_dim
216 |
217 | labels = labels.reshape(bsz * speech_len, -1)
218 | z = z.reshape(bsz * speech_len, -1)
219 | mask = speech_attention_mask.reshape(bsz * speech_len)
220 | loss = self.scoreloss(z=z, target=labels, mask=mask)
221 |
222 | eos_score = self.eos_head(logits[:, -(speech_len+1):, :]).squeeze(-1).float() #bsz, speech_len+1
223 |
224 | non_pad_counts = speech_attention_mask.sum(dim=1)
225 | eos_labels = torch.zeros(bsz, speech_len + 1)
226 | eos_labels[torch.arange(bsz), non_pad_counts] = 1 #bsz, speech_len+1
227 | eos_labels = eos_labels.to(eos_score.device)
228 |
229 | eos_loss = self.bceloss(eos_score, eos_labels) #bsz, speech_len+1
230 |
231 | #Check BCE loss weight BROADCASTING
232 | ones_column = torch.ones(bsz, 1).to(speech_attention_mask.device)
233 | loss_mask = torch.cat((ones_column, speech_attention_mask), dim=1) #bsz, speech_len+1
234 |
235 | eos_loss = (eos_loss * loss_mask).sum() / loss_mask.sum()
236 |
237 | loss = eos_loss + loss
238 |
239 | if not return_dict:
240 | output = (logits,) + outputs[1:]
241 | return (loss,) + output if loss is not None else output
242 |
243 | return CausalLMOutputWithPast(
244 | loss=loss,
245 | logits=logits,
246 | past_key_values=outputs.past_key_values,
247 | hidden_states=outputs.hidden_states,
248 | attentions=outputs.attentions,
249 | )
250 |
251 | # Check
252 | def state_dict(self, *args, **kwargs):
253 | state_dict = super().state_dict(*args, **kwargs)
254 | codec_keys = [k for k in state_dict if 'codec' in k]
255 | for key in codec_keys:
256 | del state_dict[key]
257 | return state_dict
258 |
259 | # Check
260 | def load_state_dict(self, state_dict, strict=True):
261 | codec_keys = [k for k in state_dict if 'codec' in k]
262 | for key in codec_keys:
263 | del state_dict[key]
264 | return super().load_state_dict(state_dict, strict=False)
265 |
266 |
267 | def prepare_inputs_for_generation(
268 | self,
269 | input_ids,
270 | inputs_embeds=None,
271 | past_key_values=None,
272 | attention_mask=None,
273 | cache_position=None,
274 | position_ids=None,
275 | use_cache=True,
276 | num_logits_to_keep=None,
277 | **kwargs,
278 | ):
279 | if cache_position[0] == 0:
280 | if attention_mask is not None and position_ids is None:
281 | # create position_ids on the fly for batch generation
282 | position_ids = attention_mask.long().cumsum(-1) - 1
283 | position_ids.masked_fill_(attention_mask == 0, 1)
284 |
285 | if inputs_embeds is not None:
286 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": inputs_embeds}
287 | else:
288 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
289 |
290 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
291 | raise NotImplementedError
292 |
293 | if num_logits_to_keep is not None:
294 | model_inputs["num_logits_to_keep"] = num_logits_to_keep
295 |
296 | model_inputs.update(
297 | {
298 | "position_ids": position_ids,
299 | "cache_position": cache_position,
300 | "past_key_values": past_key_values,
301 | "use_cache": use_cache,
302 | "attention_mask": attention_mask,
303 | }
304 | )
305 | else:
306 | if past_key_values is not None:
307 | inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :, :]
308 |
309 | if attention_mask is not None and position_ids is None:
310 | # create position_ids on the fly for batch generation
311 | position_ids = attention_mask.long().cumsum(-1) - 1
312 | position_ids.masked_fill_(attention_mask == 0, 1)
313 | if past_key_values:
314 | position_ids = position_ids[:, -inputs_embeds.shape[1] :]
315 | position_ids = position_ids.clone(memory_format=torch.contiguous_format)
316 |
317 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
318 |
319 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
320 | raise NotImplementedError
321 |
322 | if num_logits_to_keep is not None:
323 | model_inputs["num_logits_to_keep"] = num_logits_to_keep
324 |
325 | model_inputs.update(
326 | {
327 | "position_ids": position_ids,
328 | "cache_position": cache_position,
329 | "past_key_values": past_key_values,
330 | "use_cache": use_cache,
331 | "attention_mask": attention_mask,
332 | }
333 | )
334 |
335 | return model_inputs
336 |
337 |
338 | def prepare_inputs_for_generation_cfg(
339 | self,
340 | input_ids,
341 | inputs_embeds=None,
342 | past_key_values=None,
343 | attention_mask=None,
344 | cache_position=None,
345 | position_ids=None,
346 | use_cache=True,
347 | num_logits_to_keep=None,
348 | **kwargs,
349 | ):
350 | attention_mask = attention_mask.clone()
351 | attention_mask[:, :self.prompt_length-1] = 0
352 |
353 | if cache_position[0] == 0:
354 | if attention_mask is not None and position_ids is None:
355 | # create position_ids on the fly for batch generation
356 | position_ids = attention_mask.long().cumsum(-1) - 1
357 | position_ids.masked_fill_(attention_mask == 0, 1)
358 |
359 | if inputs_embeds is not None:
360 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": inputs_embeds}
361 | else:
362 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
363 |
364 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
365 | raise NotImplementedError
366 |
367 | if num_logits_to_keep is not None:
368 | model_inputs["num_logits_to_keep"] = num_logits_to_keep
369 |
370 | model_inputs.update(
371 | {
372 | "position_ids": position_ids,
373 | "cache_position": cache_position,
374 | "past_key_values": copy.deepcopy(past_key_values),
375 | "use_cache": use_cache,
376 | "attention_mask": attention_mask,
377 | }
378 | )
379 | else:
380 | if past_key_values is not None:
381 | inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :, :]
382 |
383 | if attention_mask is not None and position_ids is None:
384 | # create position_ids on the fly for batch generation
385 | position_ids = attention_mask.long().cumsum(-1) - 1
386 | position_ids.masked_fill_(attention_mask == 0, 1)
387 | if past_key_values:
388 | position_ids = position_ids[:, -inputs_embeds.shape[1] :]
389 | position_ids = position_ids.clone(memory_format=torch.contiguous_format)
390 |
391 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
392 |
393 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
394 | raise NotImplementedError
395 |
396 | if num_logits_to_keep is not None:
397 | model_inputs["num_logits_to_keep"] = num_logits_to_keep
398 |
399 | model_inputs.update(
400 | {
401 | "position_ids": position_ids,
402 | "cache_position": cache_position,
403 | "past_key_values":copy.deepcopy(past_key_values),
404 | "use_cache": use_cache,
405 | "attention_mask": attention_mask,
406 | }
407 | )
408 |
409 | return model_inputs
410 |
411 |
412 |
413 | def _get_initial_cache_position(self, input_ids, model_kwargs):
414 | """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
415 | # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
416 | if "inputs_embeds" in model_kwargs:
417 | #cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
418 | cache_position = (torch.ones(model_kwargs["inputs_embeds"].shape[1] + input_ids.shape[1], dtype=torch.int64).cumsum(0) - 1).to(model_kwargs["inputs_embeds"].device)
419 | else:
420 | cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
421 |
422 | past_length = 0
423 | if model_kwargs.get("past_key_values") is not None:
424 | cache = model_kwargs["past_key_values"]
425 | past_length = 0
426 | if not isinstance(cache, Cache):
427 | past_length = cache[0][0].shape[2]
428 | elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
429 | past_length = cache.get_seq_length()
430 |
431 | # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
432 | # end-to-end compilation will yield bad results because `cache_position` will be incorrect.
433 | if not is_torchdynamo_compiling():
434 | cache_position = cache_position[past_length:]
435 |
436 | model_kwargs["cache_position"] = cache_position
437 | return model_kwargs
438 |
439 |
440 | def _sample(
441 | self,
442 | input_ids,
443 | logits_processor,
444 | stopping_criteria,
445 | generation_config,
446 | synced_gpus,
447 | streamer,
448 | **model_kwargs,
449 | ):
450 | r"""
451 | Parameters:
452 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
453 | The sequence used as a prompt for the generation.
454 | logits_processor (`LogitsProcessorList`):
455 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
456 | used to modify the prediction scores of the language modeling head applied at each generation step.
457 | stopping_criteria (`StoppingCriteriaList`):
458 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
459 | used to tell if the generation loop should stop.
460 | generation_config ([`~generation.GenerationConfig`]):
461 | The generation configuration to be used as parametrization of the decoding method.
462 | synced_gpus (`bool`):
463 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
464 | streamer (`BaseStreamer`, *optional*):
465 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed
466 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
467 | model_kwargs:
468 | Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
469 | an encoder-decoder model the kwargs should include `encoder_outputs`.
470 |
471 | Return:
472 | [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
473 | A `torch.LongTensor` containing the generated tokens (default behaviour) or a
474 | [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
475 | `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
476 | `model.config.is_encoder_decoder=True`.
477 | """
478 | # init values
479 | pad_token_id = generation_config._pad_token_tensor
480 | bos_token_id = generation_config._bos_token_tensor
481 | eos_token_id = generation_config._eos_token_tensor
482 | output_attentions = generation_config.output_attentions
483 | output_hidden_states = generation_config.output_hidden_states
484 | output_scores = generation_config.output_scores
485 | output_logits = generation_config.output_logits
486 | return_dict_in_generate = generation_config.return_dict_in_generate
487 | max_length = generation_config.max_length
488 | has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
489 | do_sample = generation_config.do_sample
490 |
491 | # init attention / hidden states / scores tuples
492 | scores = () if (return_dict_in_generate and output_scores) else None
493 | raw_logits = () if (return_dict_in_generate and output_logits) else None
494 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
495 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None
496 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
497 |
498 |
499 | inputs_embeds = model_kwargs.get("inputs_embeds", None)
500 |
501 | if self.infer_cfg != 1.0:
502 | real_batch_size = input_ids.shape[0]
503 | if inputs_embeds is not None:
504 | inputs_embeds = inputs_embeds.repeat(2, 1, 1) # expand
505 | input_ids = input_ids.repeat(2, 1) # expand
506 | self.prompt_length = input_ids.shape[1]
507 | extended_attention_mask = model_kwargs["attention_mask"].clone()
508 | extended_attention_mask[:, :self.prompt_length-1] = 0
509 | model_kwargs["attention_mask"] = torch.cat([model_kwargs["attention_mask"], extended_attention_mask], dim=0) # expand
510 |
511 |
512 | # keep track of which sequences are already finished
513 | batch_size, cur_len = input_ids.shape if inputs_embeds is None else (input_ids.shape[0], input_ids.shape[1] + inputs_embeds.shape[1])
514 | this_peer_finished = False
515 | unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
516 | model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
517 | model_kwargs.pop("inputs_embeds", None)
518 |
519 |
520 | while self._has_unfinished_sequences(
521 | this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
522 | ):
523 | # prepare model inputs
524 | model_inputs = self.prepare_inputs_for_generation(input_ids, inputs_embeds, **model_kwargs)
525 |
526 | # prepare variable output controls (note: some models won't accept all output controls)
527 | model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
528 | model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
529 |
530 | # forward pass to get next token
531 | outputs = self(**model_inputs, return_dict=True)
532 |
533 |
534 | if synced_gpus and this_peer_finished:
535 | continue # don't waste resources running the code we don't need
536 |
537 | # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
538 | # (the clone itself is always small)
539 | next_token_logits = outputs.logits.clone()[:, -1, :]
540 | eos_next_token_logits = next_token_logits
541 | if self.infer_cfg != 1.0:
542 | next_token_logits_normal = next_token_logits[:real_batch_size, :]
543 | next_token_logits_cfg = next_token_logits[real_batch_size:, :]
544 | next_token_logits = next_token_logits_cfg + self.infer_cfg * (next_token_logits_normal - next_token_logits_cfg)
545 | eos_next_token_logits = eos_next_token_logits[:real_batch_size, :]
546 | # Store scores, attentions and hidden_states when required
547 | if return_dict_in_generate:
548 | raise NotImplementedError
549 |
550 |
551 | # token selection
552 | if do_sample:
553 | next_embeds = self.scoreloss.sample(next_token_logits, temperature=1.0) # bsz, dim
554 |
555 | next_actions = torch.sigmoid(self.eos_head(eos_next_token_logits)) >= 0.8 # 0: continue, 1: stop
556 | next_tokens = torch.where(next_actions == 0, bos_token_id, eos_token_id)
557 |
558 | if self.infer_cfg != 1.0:
559 |                 # expand
560 | next_embeds = next_embeds.repeat(2, 1)
561 | next_tokens = next_tokens.repeat(2, 1)
562 |
563 |
564 | else:
565 | raise NotImplementedError
566 |
567 |
568 | # finished sentences should have their next token be a padding token
569 | if has_eos_stopping_criteria:
570 | next_tokens = next_tokens * unfinished_sequences.unsqueeze(1) + pad_token_id * (1 - unfinished_sequences.unsqueeze(1))
571 |
572 | # update generated ids, model inputs, and length for next step
573 | if inputs_embeds is not None:
574 | inputs_embeds = torch.cat([inputs_embeds, next_embeds[:, None, :]], dim=1)
575 | else:
576 | inputs_embeds = next_embeds[:, None, :]
577 | input_ids = torch.cat([input_ids, next_tokens], dim=-1)
578 |
579 |
580 | model_kwargs = self._update_model_kwargs_for_generation(
581 | outputs,
582 | model_kwargs,
583 | )
584 |
585 | unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
586 | this_peer_finished = unfinished_sequences.max() == 0
587 | cur_len += 1
588 |
589 | # This is needed to properly delete outputs.logits which may be very large for first iteration
590 | # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
591 | del outputs
592 |
593 | if self.infer_cfg != 1.0:
594 | input_ids = input_ids[:real_batch_size]
595 | inputs_embeds = inputs_embeds[:real_batch_size]
596 |
597 | if return_dict_in_generate:
598 | return ModifiedGenerateDecoderOnlyOutput(
599 | sequences=input_ids,
600 | features=inputs_embeds,
601 | scores=scores,
602 | logits=raw_logits,
603 | attentions=decoder_attentions,
604 | hidden_states=decoder_hidden_states,
605 | past_key_values=model_kwargs.get("past_key_values"),
606 | )
607 | else:
608 | return (input_ids, inputs_embeds)
609 |
610 |
611 | def _update_model_kwargs_for_generation(
612 | self,
613 | outputs,
614 | model_kwargs: Dict[str, Any],
615 | num_new_tokens: int = 1,
616 | ) -> Dict[str, Any]:
617 | # update past_key_values keeping its naming used in model code
618 | cache_name, cache = self._extract_past_from_model_output(outputs)
619 | model_kwargs[cache_name] = cache
620 |
621 | # update attention mask
622 | if "attention_mask" in model_kwargs:
623 | attention_mask = model_kwargs["attention_mask"]
624 | model_kwargs["attention_mask"] = torch.cat(
625 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
626 | )
627 |
628 | if model_kwargs.get("use_cache", True):
629 | model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
630 | else:
631 | raise NotImplementedError
632 |
633 | return model_kwargs
634 |
--------------------------------------------------------------------------------
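
As context for `initialize_codec` and `prepare_inputs_labels_for_multimodal` above: encoding at `bandwidth=6` keeps 8 Encodec codebooks at 24 kHz, `quantizer.decode` sums their embeddings into a single 128-dimensional vector per 75 Hz frame, and slicing the waveform padding mask with `[..., ::320]` yields the matching frame-level mask. A standalone sketch of that latent extraction (requires downloading the codec; the dummy waveform is made up):

``` python
import torch
from transformers import AutoProcessor, EncodecModel

processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
codec = EncodecModel.from_pretrained("facebook/encodec_24khz")

wav = torch.randn(24000).numpy()  # 1 s of dummy audio at 24 kHz
inputs = processor(raw_audio=wav, sampling_rate=24000, return_tensors="pt")

with torch.no_grad():
    enc = codec.encode(inputs["input_values"], inputs["padding_mask"], bandwidth=6)
    # audio_codes[0]: (batch, 8 codebooks, frames) -> quantizer.decode sums the embeddings
    latents = codec.quantizer.decode(enc.audio_codes[0].transpose(0, 1))  # (1, 128, 75)

frame_mask = inputs["padding_mask"][..., ::320]  # waveform mask downsampled to 75 Hz
print(latents.shape, frame_mask.shape)
```
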
/sled/sled_stream.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union, Dict, Any
2 | from pathlib import Path
3 | import torch
4 | import torch.nn as nn
5 |
6 | from transformers import EncodecModel
7 | from .modeling_llama_with_dropout import LlamaConfig, LlamaModel, LlamaForCausalLM
8 | from .modeling_llama_with_dropout import _prepare_4d_causal_attention_mask_with_cache_position
9 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
10 | from transformers.modeling_outputs import CausalLMOutputWithPast
11 | from transformers.generation.utils import GenerateDecoderOnlyOutput
12 | from transformers.utils.import_utils import is_torchdynamo_compiling
13 |
14 | from .energy_distance import ScoreLossZ
15 | import copy
16 |
17 |
18 | class ModifiedGenerateDecoderOnlyOutput(GenerateDecoderOnlyOutput):
19 | features = None
20 |
21 |
22 | def interleave_embeddings_and_mask_efficient(text_embeds, speech_embeds, text_mask, speech_mask, n, m):
23 |
24 |     """
25 | 
26 |     Parameters:
27 |     - text_embeds: Tensor with shape (b, t1, d)
28 |     - speech_embeds: Tensor with shape (b, t2, d)
29 |     - text_mask: Tensor with shape (b, t1)
30 |     - speech_mask: Tensor with shape (b, t2)
31 |     - n: int, number of consecutive text_embeds per cycle
32 |     - m: int, number of consecutive speech_embeds per cycle
33 | 
34 |     Returns:
35 |     - interleaved: Tensor with shape (b, t1 + t2, d)
36 |     - interleaved_mask: Tensor with shape (b, t1 + t2)
37 |     - speech_positions_mask: Tensor with shape (b, t1 + t2)
38 |     """
39 |
40 | b, t1, d = text_embeds.size()
41 | _, t2, _ = speech_embeds.size()
42 |
43 |
44 | num_cycles_text = t1 // n
45 | num_cycles_speech = t2 // m
46 | total_cycles = min(num_cycles_text, num_cycles_speech)
47 |
48 | interleaved_blocks = []
49 | interleaved_mask_blocks = []
50 | speech_positions_mask_blocks = []
51 |
52 |
53 | if total_cycles > 0:
54 | text_main = text_embeds[:, :total_cycles * n, :].reshape(b, total_cycles, n, d)
55 | speech_main = speech_embeds[:, :total_cycles * m, :].reshape(b, total_cycles, m, d)
56 |
57 | text_mask_main = text_mask[:, :total_cycles * n].reshape(b, total_cycles, n)
58 | speech_mask_main = speech_mask[:, :total_cycles * m].reshape(b, total_cycles, m)
59 |
60 | interleaved_main = torch.cat([text_main, speech_main], dim=2).reshape(b, total_cycles * (n + m), d)
61 | interleaved_blocks.append(interleaved_main)
62 |
63 | interleaved_mask_main = torch.cat([text_mask_main, speech_mask_main], dim=2).reshape(b, total_cycles * (n + m))
64 | interleaved_mask_blocks.append(interleaved_mask_main)
65 |
66 |
67 | text_zero_mask = torch.zeros_like(text_mask_main)
68 | speech_one_main = torch.ones_like(speech_mask_main)
69 |
70 | speech_positions_main = torch.cat([text_zero_mask, speech_one_main], dim=2).reshape(b, total_cycles * (n + m))
71 | speech_positions_mask_blocks.append(speech_positions_main)
72 |
73 |
74 | remaining_text = text_embeds[:, total_cycles * n:, :]
75 | remaining_speech = speech_embeds[:, total_cycles * m:, :]
76 |
77 | remaining_mask_text = text_mask[:, total_cycles * n:]
78 | remaining_mask_speech = speech_mask[:, total_cycles * m:]
79 |
80 |
81 | remaining_num_text = remaining_text.size(1)
82 | remaining_num_speech = remaining_speech.size(1)
83 |
84 | assert (remaining_num_text < n) or (remaining_num_speech < m)
85 |
86 | interleaved_blocks.append(remaining_text[:, :n, :])
87 | interleaved_blocks.append(remaining_speech)
88 | interleaved_blocks.append(remaining_text[:, n:, :])
89 |
90 | interleaved_mask_blocks.append(remaining_mask_text[:, :n])
91 | interleaved_mask_blocks.append(remaining_mask_speech)
92 | interleaved_mask_blocks.append(remaining_mask_text[:, n:])
93 |
94 | speech_positions_mask_blocks.append(torch.zeros_like(remaining_mask_text[:, :n]))
95 | speech_positions_mask_blocks.append(torch.ones_like(remaining_mask_speech))
96 | speech_positions_mask_blocks.append(torch.zeros_like(remaining_mask_text[:, n:]))
97 |
98 | interleaved = torch.cat(interleaved_blocks, dim=1)
99 | interleaved_mask = torch.cat(interleaved_mask_blocks, dim=1)
100 | speech_positions_mask = torch.cat(speech_positions_mask_blocks, dim=1)
101 |
102 | assert interleaved.size(1) == (t1 + t2) == interleaved_mask.size(1) == speech_positions_mask.size(1)
103 |
104 | return interleaved, interleaved_mask, speech_positions_mask
105 |
106 |
107 |
108 | def get_previous_non_pad_indices(attention_mask):
109 |     """
110 |     Find the previous non-pad index for each position.
111 | 
112 |     Parameters:
113 |         attention_mask (torch.Tensor): shape (batch_size, seq_length), 0 for pad, 1 for non-pad.
114 | 
115 |     Returns:
116 |         torch.Tensor: shape (batch_size, seq_length), each element is the index of the previous non-pad position, or -1 if there is none.
117 |     """
118 |
119 | batch_size, seq_length = attention_mask.shape
120 |
121 |
122 | indices = torch.arange(seq_length, device=attention_mask.device).unsqueeze(0).expand(batch_size, -1) # (batch_size, seq_length)
123 |
124 |
125 | mask = attention_mask == 1
126 | indices = torch.where(mask, indices, torch.full_like(indices, -1))
127 |
128 |
129 | shifted_indices = torch.cat([torch.full((batch_size, 1), -1, device=attention_mask.device), indices[:, :-1]], dim=1)
130 |
131 |
132 | previous_non_pad = torch.cummax(shifted_indices, dim=1).values
133 |
134 | return previous_non_pad
135 |
136 |
137 |
138 | class SpeechLlamaConfig(LlamaConfig):
139 | model_type = "speech_llama"
140 |
141 | def __init__(
142 | self,
143 | vae_embed_dim=128,
144 | diffloss_d=3,
145 | diffloss_w=1024,
146 | training_cfg=0.0,
147 | noise_channels=128,
148 | stream_n=5,
149 | stream_m=45,
150 | **kwargs,
151 | ):
152 |
153 | self.vae_embed_dim = vae_embed_dim
154 | self.diffloss_d = diffloss_d
155 | self.diffloss_w = diffloss_w
156 | self.training_cfg = training_cfg
157 | self.noise_channels = noise_channels
158 | self.stream_n = stream_n
159 | self.stream_m = stream_m
160 |
161 |
162 | super().__init__(**kwargs)
163 |
164 |
165 |
166 |
167 |
168 | class SpeechLlamaForCausalLM(LlamaForCausalLM):
169 | config_class = SpeechLlamaConfig
170 |
171 | def __init__(self, config):
172 | super(LlamaForCausalLM, self).__init__(config)
173 | self.model = LlamaModel(config)
174 | self.codec = None
175 |
176 | # --------------------------------------------------------------------------
177 | # Speech Embedding
178 | self.token_embed_dim = config.vae_embed_dim
179 | self.hidden_size = config.hidden_size
180 | self.z_proj = nn.Linear(self.token_embed_dim, self.hidden_size, bias=True)
181 | self.z_proj_ln = nn.LayerNorm(self.hidden_size, eps=1e-6)
182 | self.eos_head = nn.Linear(self.hidden_size, 1, bias=False)
183 | self.embed_mean = None
184 | self.embed_std = None
185 | self.training_cfg = config.training_cfg
186 |
187 | # --------------------------------------------------------------------------
188 | # Score Loss
189 | self.scoreloss = ScoreLossZ(
190 | target_channels=self.token_embed_dim,
191 | z_channels=self.hidden_size,
192 | width=config.diffloss_w,
193 | depth=config.diffloss_d,
194 | noise_channels=config.noise_channels,
195 | )
196 |
197 | # --------------------------------------------------------------------------
198 | # BCE Loss
199 |         pos_weight = torch.Tensor([100.]) # the positive (EOS) class is weighted 100x in the BCE loss
200 | self.bceloss = torch.nn.BCEWithLogitsLoss(reduction="none", pos_weight=pos_weight)
201 |
202 | self.stream_n = config.stream_n
203 | self.stream_m = config.stream_m
204 |
205 | # Initialize weights and apply final processing
206 |         self.post_init() # TODO: check whether this affects the init of LlamaModel
207 |
208 | def initialize_codec(self, model_args):
209 | if hasattr(model_args, "codec"):
210 | self.codec = EncodecModel.from_pretrained(model_args.codec, torch_dtype=torch.float32) # keep encodec model in fp32
211 | else:
212 | self.codec = EncodecModel.from_pretrained(model_args, torch_dtype=torch.float32)
213 | for param in self.codec.parameters():
214 | param.requires_grad = False
215 |
216 |
217 |
218 | def _init_weights(self, m):
219 | if isinstance(m, nn.Linear):
220 | # we use xavier_uniform following official JAX ViT:
221 | torch.nn.init.xavier_uniform_(m.weight)
222 | if isinstance(m, nn.Linear) and m.bias is not None:
223 | nn.init.constant_(m.bias, 0)
224 | elif isinstance(m, nn.LayerNorm):
225 | if m.bias is not None:
226 | nn.init.constant_(m.bias, 0)
227 | if m.weight is not None:
228 | nn.init.constant_(m.weight, 1.0)
229 |
230 |
231 | def prepare_inputs_labels_for_multimodal(
232 | self,
233 | input_ids,
234 | position_ids,
235 | attention_mask,
236 | past_key_values,
237 | audio_inputs,
238 | ):
239 | if self.training_cfg > 0.0:
240 | bsz = attention_mask.size(0)
241 | random_mask = torch.rand(bsz) < self.training_cfg
242 | attention_mask[random_mask] = 0
243 | attention_mask[random_mask, 0] = 1
244 | input_ids[random_mask, 0] = 2
245 |
246 |
247 | text_inputs_embeds = self.model.embed_tokens(input_ids)
248 |
249 | with torch.no_grad():
250 |             encoder_outputs = self.codec.encode(audio_inputs["input_values"], audio_inputs["padding_mask"], bandwidth=6) # audio_codes: (1, b, r, t); leading 1 because the audio is encoded as a single chunk
251 |             speech_inputs_embeds = self.codec.quantizer.decode(encoder_outputs.audio_codes[0].transpose(0, 1)) # (b, d, t), always fp32
252 |
253 | speech_attention_mask = audio_inputs["padding_mask"][..., ::320]
254 | assert speech_inputs_embeds.size(-1) == speech_attention_mask.size(-1)
255 |         speech_inputs_embeds = speech_inputs_embeds.transpose(1,2).to(self.model.dtype) # (b, t, d); cast to the model dtype to support full bf16 training
256 |
257 |
258 | net_speech_inputs_embeds = self.z_proj(speech_inputs_embeds)
259 | new_inputs_embeds, new_attention_mask, speech_positions_mask = interleave_embeddings_and_mask_efficient(text_inputs_embeds, net_speech_inputs_embeds, attention_mask, speech_attention_mask, self.stream_n, self.stream_m)
260 |
261 | new_labels = speech_inputs_embeds
262 |
263 | return None, position_ids, new_attention_mask, past_key_values, new_inputs_embeds, new_labels, speech_attention_mask, speech_positions_mask
264 |
265 | def forward(
266 | self,
267 | input_ids: torch.LongTensor = None,
268 | attention_mask: Optional[torch.Tensor] = None,
269 | position_ids: Optional[torch.LongTensor] = None,
270 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
271 | inputs_embeds: Optional[torch.FloatTensor] = None,
272 | labels: Optional[torch.LongTensor] = None,
273 | use_cache: Optional[bool] = None,
274 | output_attentions: Optional[bool] = None,
275 | output_hidden_states: Optional[bool] = None,
276 | return_dict: Optional[bool] = None,
277 | cache_position: Optional[torch.LongTensor] = None,
278 | num_logits_to_keep: int = 0,
279 | audio_inputs: Optional[Dict[str, Any]] = None,
280 | speech_inputs_embeds: Optional[torch.FloatTensor] = None,
281 | speech_attention_mask: Optional[torch.Tensor] = None,
282 | ) -> Union[Tuple, CausalLMOutputWithPast]:
283 |
284 |
285 | if audio_inputs is not None:
286 | (
287 | _,
288 | position_ids,
289 | attention_mask,
290 | past_key_values,
291 | inputs_embeds,
292 | labels,
293 | speech_attention_mask,
294 | speech_positions_mask,
295 | ) = self.prepare_inputs_labels_for_multimodal(
296 | input_ids,
297 | position_ids,
298 | attention_mask,
299 | past_key_values,
300 | audio_inputs,
301 | )
302 | else:
303 | assert not ((input_ids is None) and (inputs_embeds is None))
304 | if input_ids is not None and inputs_embeds is None:
305 | inputs_embeds = self.model.embed_tokens(input_ids)
306 | elif input_ids is None and inputs_embeds is not None:
307 | inputs_embeds = self.z_proj(inputs_embeds)
308 | else:
309 | inputs_embeds = torch.cat([self.z_proj(inputs_embeds), self.model.embed_tokens(input_ids)], dim=1)
310 |
311 | inputs_embeds = self.z_proj_ln(inputs_embeds)
312 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
313 | output_hidden_states = (
314 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
315 | )
316 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
317 |
318 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
319 | outputs = self.model(
320 | input_ids=None,
321 | attention_mask=attention_mask,
322 | position_ids=position_ids,
323 | past_key_values=past_key_values,
324 | inputs_embeds=inputs_embeds,
325 | use_cache=use_cache,
326 | output_attentions=output_attentions,
327 | output_hidden_states=output_hidden_states,
328 | return_dict=return_dict,
329 | cache_position=cache_position,
330 | )
331 |
332 | hidden_states = outputs[0]
333 | logits = hidden_states[:, -num_logits_to_keep:, :]
334 |
335 | loss = None
336 |
337 | if labels is not None:
338 |
339 | bsz, speech_len, _ = labels.shape
340 | feat_dim = logits.size(-1)
341 | generate_index = get_previous_non_pad_indices(attention_mask)
342 | selected_indices = generate_index.masked_select((speech_positions_mask == 1)) # bsz * speech_len
343 | selected_indices = selected_indices.view(bsz, speech_len) # bsz, speech_len
344 | selected_indices_expanded = selected_indices.unsqueeze(-1).expand(-1, -1, feat_dim) # bsz * speech_len * feat_dim
345 |
346 | z = torch.gather(logits, dim=1, index=selected_indices_expanded) # bsz * speech_len * feat_dim
347 |
348 |
349 | labels = labels.reshape(bsz * speech_len, -1)
350 | mask = speech_attention_mask.reshape(bsz * speech_len)
351 | loss = self.scoreloss(z=z.reshape(bsz * speech_len, -1), target=labels, mask=mask)
352 |
353 | current_index = torch.arange(attention_mask.size(1), device=attention_mask.device).unsqueeze(0).expand(bsz, -1) # (bsz, full_seq_length)
354 | selected_current_indices = current_index.masked_select((speech_positions_mask == 1)) # bsz * speech_len
355 | selected_current_indices = selected_current_indices.view(bsz, speech_len)[:, -1:] # bsz, 1
356 | selected_current_indices_expanded = selected_current_indices.unsqueeze(-1).expand(-1, -1, feat_dim) # bsz * 1 * feat_dim
357 |
358 | z_last = torch.gather(logits, dim=1, index=selected_current_indices_expanded) # bsz * 1 * feat_dim
359 |
360 | z = torch.cat([z, z_last], dim=1)
361 |
362 | eos_score = self.eos_head(z).squeeze(-1).float() #bsz, speech_len+1
363 |
364 | non_pad_counts = speech_attention_mask.sum(dim=1)
365 | eos_labels = torch.zeros(bsz, speech_len + 1)
366 | eos_labels[torch.arange(bsz), non_pad_counts] = 1 #bsz, speech_len+1
367 | eos_labels = eos_labels.to(eos_score.device)
368 |
369 | eos_loss = self.bceloss(eos_score, eos_labels) #bsz, speech_len+1
370 |
371 |             # check that the BCE pos_weight broadcasts correctly over (bsz, speech_len + 1)
372 | ones_column = torch.ones(bsz, 1).to(speech_attention_mask.device)
373 | loss_mask = torch.cat((ones_column, speech_attention_mask), dim=1) #bsz, speech_len+1
374 |
375 | eos_loss = (eos_loss * loss_mask).sum() / loss_mask.sum()
376 |
377 | loss = eos_loss + loss
378 |
379 | if not return_dict:
380 | output = (logits,) + outputs[1:]
381 | return (loss,) + output if loss is not None else output
382 |
383 | return CausalLMOutputWithPast(
384 | loss=loss,
385 | logits=logits,
386 | past_key_values=outputs.past_key_values,
387 | hidden_states=outputs.hidden_states,
388 | attentions=outputs.attentions,
389 | )
390 |
391 |     # Exclude the frozen codec weights when saving checkpoints
392 | def state_dict(self, *args, **kwargs):
393 | state_dict = super().state_dict(*args, **kwargs)
394 | codec_keys = [k for k in state_dict if 'codec' in k]
395 | for key in codec_keys:
396 | del state_dict[key]
397 | return state_dict
398 |
399 |     # Drop codec weights from incoming state dicts and load the rest non-strictly
400 | def load_state_dict(self, state_dict, strict=True):
401 | codec_keys = [k for k in state_dict if 'codec' in k]
402 | for key in codec_keys:
403 | del state_dict[key]
404 | return super().load_state_dict(state_dict, strict=False)
405 |
406 |
407 | def prepare_inputs_for_generation(
408 | self,
409 | input_ids,
410 | inputs_embeds=None,
411 | turn=0,
412 | num_write_turn=0,
413 | past_key_values=None,
414 | attention_mask=None,
415 | cache_position=None,
416 | position_ids=None,
417 | use_cache=True,
418 | num_logits_to_keep=None,
419 | **kwargs,
420 | ):
421 |         if num_write_turn != 0:
422 | assert cache_position.shape[0] == 1
423 | assert past_key_values is not None
424 | inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :, :]
425 |
426 | if attention_mask is not None and position_ids is None:
427 | # create position_ids on the fly for batch generation
428 | position_ids = attention_mask.long().cumsum(-1) - 1
429 | position_ids.masked_fill_(attention_mask == 0, 1)
430 |
431 | position_ids = position_ids[:, -inputs_embeds.shape[1] :]
432 | position_ids = position_ids.clone(memory_format=torch.contiguous_format)
433 |
434 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
435 |
436 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
437 | raise NotImplementedError
438 |
439 | if num_logits_to_keep is not None:
440 | model_inputs["num_logits_to_keep"] = num_logits_to_keep
441 |
442 | model_inputs.update(
443 | {
444 | "position_ids": position_ids,
445 | "cache_position": cache_position,
446 | "past_key_values": past_key_values,
447 | "use_cache": use_cache,
448 | "attention_mask": attention_mask,
449 | }
450 | )
451 |
452 | else:
453 | if inputs_embeds is not None:
454 | inputs_embeds = inputs_embeds[:, -1 :, :]
455 |
456 | if attention_mask is not None and position_ids is None:
457 | # create position_ids on the fly for batch generation
458 | position_ids = attention_mask.long().cumsum(-1) - 1
459 | position_ids.masked_fill_(attention_mask == 0, 1)
460 |
461 | position_ids = position_ids[:, -cache_position.shape[0] :]
462 | position_ids = position_ids.clone(memory_format=torch.contiguous_format)
463 |
464 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": inputs_embeds}
465 |
466 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
467 | raise NotImplementedError
468 |
469 | if num_logits_to_keep is not None:
470 | model_inputs["num_logits_to_keep"] = num_logits_to_keep
471 |
472 | model_inputs.update(
473 | {
474 | "position_ids": position_ids,
475 | "cache_position": cache_position,
476 | "past_key_values": past_key_values,
477 | "use_cache": use_cache,
478 | "attention_mask": attention_mask,
479 | }
480 | )
481 | return model_inputs
482 |
483 |
484 | def prepare_inputs_for_generation_cfg(
485 | self,
486 | input_ids,
487 | inputs_embeds=None,
488 | past_key_values=None,
489 | attention_mask=None,
490 | cache_position=None,
491 | position_ids=None,
492 | use_cache=True,
493 | num_logits_to_keep=None,
494 | **kwargs,
495 | ):
496 | if cache_position[0] == 0:
497 | if attention_mask is not None and position_ids is None:
498 | # create position_ids on the fly for batch generation
499 | position_ids = attention_mask.long().cumsum(-1) - 1
500 | position_ids.masked_fill_(attention_mask == 0, 1)
501 |
502 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
503 |
504 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
505 | raise NotImplementedError
506 |
507 | if num_logits_to_keep is not None:
508 | model_inputs["num_logits_to_keep"] = num_logits_to_keep
509 |
510 | model_inputs.update(
511 | {
512 | "position_ids": position_ids,
513 | "cache_position": cache_position,
514 | "past_key_values": past_key_values,
515 | "use_cache": use_cache,
516 | "attention_mask": attention_mask,
517 | }
518 | )
519 | else:
520 | if past_key_values is not None:
521 | inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :, :]
522 |
523 | if attention_mask is not None and position_ids is None:
524 | # create position_ids on the fly for batch generation
525 | position_ids = attention_mask.long().cumsum(-1) - 1
526 | position_ids.masked_fill_(attention_mask == 0, 1)
527 | if past_key_values:
528 | position_ids = position_ids[:, -inputs_embeds.shape[1] :]
529 | position_ids = position_ids.clone(memory_format=torch.contiguous_format)
530 |
531 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
532 |
533 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
534 | raise NotImplementedError
535 |
536 | if num_logits_to_keep is not None:
537 | model_inputs["num_logits_to_keep"] = num_logits_to_keep
538 |
539 | model_inputs.update(
540 | {
541 | "position_ids": position_ids,
542 | "cache_position": cache_position,
543 | "past_key_values": past_key_values,
544 | "use_cache": use_cache,
545 | "attention_mask": attention_mask,
546 | }
547 | )
548 |
549 | return model_inputs
550 |
551 |
552 |
553 | def _get_initial_cache_position_after_each_read(self, input_ids, model_kwargs):
554 | """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
555 | # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
556 | cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
557 |
558 | past_length = 0
559 | if model_kwargs.get("past_key_values") is not None:
560 | cache = model_kwargs["past_key_values"]
561 | past_length = 0
562 | if not isinstance(cache, Cache):
563 | past_length = cache[0][0].shape[2]
564 | elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
565 | past_length = cache.get_seq_length()
566 |
567 | # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
568 | # end-to-end compilation will yield bad results because `cache_position` will be incorrect.
569 | if not is_torchdynamo_compiling():
570 | cache_position = cache_position + past_length
571 | if model_kwargs.get("cache_position", None) is not None:
572 | model_kwargs["cache_position"] = torch.cat([model_kwargs["cache_position"], cache_position + 1], dim=0)
573 | else:
574 | model_kwargs["cache_position"] = cache_position
575 | return model_kwargs
576 |
577 |
578 | def _get_initial_cache_position_cfg(self, input_ids, model_kwargs):
579 | """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
580 | # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
581 | cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
582 |
583 | past_length = 0
584 | if model_kwargs.get("past_key_values") is not None:
585 | cache = model_kwargs["past_key_values"]
586 | past_length = 0
587 | if not isinstance(cache, Cache):
588 | past_length = cache[0][0].shape[2]
589 | elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
590 | past_length = cache.get_seq_length()
591 |
592 | # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
593 | # end-to-end compilation will yield bad results because `cache_position` will be incorrect.
594 | if not is_torchdynamo_compiling():
595 | cache_position = cache_position[past_length:]
596 |
597 | model_kwargs["cache_position"] = cache_position
598 | return model_kwargs
599 |
600 |
601 | def _sample(
602 | self,
603 | input_ids,
604 | logits_processor,
605 | stopping_criteria,
606 | generation_config,
607 | synced_gpus,
608 | streamer,
609 | **model_kwargs,
610 | ):
611 | r"""
612 | Parameters:
613 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
614 | The sequence used as a prompt for the generation.
615 | logits_processor (`LogitsProcessorList`):
616 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
617 | used to modify the prediction scores of the language modeling head applied at each generation step.
618 | stopping_criteria (`StoppingCriteriaList`):
619 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
620 | used to tell if the generation loop should stop.
621 | generation_config ([`~generation.GenerationConfig`]):
622 | The generation configuration to be used as parametrization of the decoding method.
623 | synced_gpus (`bool`):
624 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
625 | streamer (`BaseStreamer`, *optional*):
626 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed
627 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
628 | model_kwargs:
629 | Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
630 | an encoder-decoder model the kwargs should include `encoder_outputs`.
631 |
632 | Return:
633 | [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
634 | A `torch.LongTensor` containing the generated tokens (default behaviour) or a
635 | [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
636 | `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
637 | `model.config.is_encoder_decoder=True`.
638 | """
639 | # init values
640 | pad_token_id = generation_config._pad_token_tensor
641 | bos_token_id = generation_config._bos_token_tensor
642 | eos_token_id = generation_config._eos_token_tensor
643 | output_attentions = generation_config.output_attentions
644 | output_hidden_states = generation_config.output_hidden_states
645 | output_scores = generation_config.output_scores
646 | output_logits = generation_config.output_logits
647 | return_dict_in_generate = generation_config.return_dict_in_generate
648 | max_length = generation_config.max_length
649 | has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
650 | do_sample = generation_config.do_sample
651 |
652 | # init attention / hidden states / scores tuples
653 | scores = () if (return_dict_in_generate and output_scores) else None
654 | raw_logits = () if (return_dict_in_generate and output_logits) else None
655 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
656 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None
657 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
658 |
659 |
660 |
661 | # keep track of which sequences are already finished
662 | batch_size = input_ids.shape[0]
663 | assert batch_size == 1 # only support batch size 1
664 | turn = 0
665 | read_finished = False
666 | this_peer_finished = False
667 | unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
668 | last_src_read_len = 0
669 | cur_len = 0
670 | complete_attention_mask = model_kwargs.pop("attention_mask")
671 | model_kwargs["attention_mask"] = torch.zeros((batch_size, 0), dtype=complete_attention_mask.dtype, device=complete_attention_mask.device)
672 | inputs_embeds = None
673 | generate_ids = torch.zeros((batch_size, 0), dtype=input_ids.dtype, device=input_ids.device)
674 |
675 |
676 | if self.infer_cfg != 1.0:
677 | model_kwargs_cfg = copy.deepcopy(model_kwargs)
678 | model_kwargs_cfg["attention_mask"] = torch.ones((batch_size, 1), dtype=complete_attention_mask.dtype, device=complete_attention_mask.device)
679 | input_ids_cfg = torch.ones((batch_size, 1), dtype=input_ids.dtype, device=input_ids.device)
680 | input_ids_cfg[:, 0] = 2
681 | inputs_embeds_cfg = None
682 | model_kwargs_cfg = self._get_initial_cache_position_cfg(input_ids_cfg, model_kwargs_cfg)
683 |
684 |
685 |         while not read_finished:
686 |
687 | src_read_len = min(input_ids.shape[1], (turn + 1) * self.stream_n)
688 | if src_read_len == input_ids.shape[1]:
689 | read_finished = True
690 |
691 |
692 | this_turn_input_ids = input_ids[:, last_src_read_len:src_read_len]
693 | model_kwargs = self._get_initial_cache_position_after_each_read(this_turn_input_ids, model_kwargs)
694 | model_kwargs["attention_mask"] = torch.cat([model_kwargs["attention_mask"], complete_attention_mask[:, last_src_read_len:src_read_len]], dim=1)
695 |
696 | #self.prompt_length = input_ids.shape[1] if inputs_embeds is None else input_ids.shape[1] + inputs_embeds.shape[1]
697 |
698 | cur_len = cur_len + src_read_len - last_src_read_len
699 |
700 |
701 | last_src_read_len = src_read_len
702 | num_write_in_turn = 0
703 |
704 |
705 | while self._has_unfinished_sequences(
706 | this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
707 | ):
708 | if (not read_finished) and (num_write_in_turn == self.stream_m):
709 | break
710 |
711 | # prepare model inputs
712 | model_inputs = self.prepare_inputs_for_generation(this_turn_input_ids, inputs_embeds, turn, num_write_in_turn, **model_kwargs)
713 | if self.infer_cfg != 1.0:
714 | model_inputs_cfg = self.prepare_inputs_for_generation_cfg(input_ids_cfg, inputs_embeds_cfg, **model_kwargs_cfg)
715 |
716 |
717 | # prepare variable output controls (note: some models won't accept all output controls)
718 | model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
719 | model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
720 |
721 | # forward pass to get next token
722 | outputs = self(**model_inputs, return_dict=True)
723 | if self.infer_cfg != 1.0:
724 | outputs_cfg = self(**model_inputs_cfg, return_dict=True)
725 |
726 |
727 | if synced_gpus and this_peer_finished:
728 | continue # don't waste resources running the code we don't need
729 |
730 | # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
731 | # (the clone itself is always small)
732 | next_token_logits = outputs.logits.clone()[:, -1, :]
733 | eos_next_token_logits = next_token_logits.clone()
734 | if self.infer_cfg != 1.0:
735 | next_token_logits_cfg = outputs_cfg.logits.clone()[:, -1, :]
736 | next_token_logits = next_token_logits_cfg + self.infer_cfg * (next_token_logits - next_token_logits_cfg)
737 | # Store scores, attentions and hidden_states when required
738 | if return_dict_in_generate:
739 | if output_logits:
740 | raw_logits += (next_token_logits,)
741 | if output_attentions:
742 | decoder_attentions += (
743 | (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
744 | )
745 | if self.config.is_encoder_decoder:
746 | cross_attentions += (outputs.cross_attentions,)
747 |
748 | if output_hidden_states:
749 | decoder_hidden_states += (
750 | (outputs.decoder_hidden_states,)
751 | if self.config.is_encoder_decoder
752 | else (outputs.hidden_states,)
753 | )
754 |
755 | # token selection
756 | if do_sample:
757 | next_embeds = self.scoreloss.sample(next_token_logits, temperature=1.0) # bsz, dim
758 |
759 | next_actions = torch.sigmoid(self.eos_head(eos_next_token_logits)) >= 0.5 # 0: continue, 1: stop
760 | # if read_finished:
761 | # next_tokens = torch.where(next_actions == 0, bos_token_id, eos_token_id)
762 | # else:
763 | # next_tokens = torch.full((batch_size,), bos_token_id ,device=next_actions.device)
764 | next_tokens = torch.where(next_actions == 0, bos_token_id, eos_token_id)
765 | else:
766 | raise NotImplementedError
767 |
768 |
769 | # finished sentences should have their next token be a padding token
770 | if has_eos_stopping_criteria:
771 | next_tokens = next_tokens * unfinished_sequences.unsqueeze(1) + pad_token_id * (1 - unfinished_sequences.unsqueeze(1))
772 |
773 | # update generated ids, model inputs, and length for next step
774 | if inputs_embeds is not None:
775 | inputs_embeds = torch.cat([inputs_embeds, next_embeds[:, None, :]], dim=1)
776 | else:
777 | inputs_embeds = next_embeds[:, None, :]
778 |
779 | generate_ids = torch.cat([generate_ids, next_tokens], dim=-1)
780 |
781 | if self.infer_cfg != 1.0:
782 | if inputs_embeds_cfg is not None:
783 | inputs_embeds_cfg = torch.cat([inputs_embeds_cfg, next_embeds[:, None, :]], dim=1)
784 | else:
785 | inputs_embeds_cfg = next_embeds[:, None, :]
786 |
787 | input_ids_cfg = torch.cat([input_ids_cfg, next_tokens], dim=-1)
788 |
789 |
790 | model_kwargs = self._update_model_kwargs_for_generation(
791 | outputs,
792 | model_kwargs,
793 | )
794 |
795 |
796 | if self.infer_cfg != 1.0:
797 | model_kwargs_cfg = self._update_model_kwargs_for_generation(outputs_cfg, model_kwargs_cfg)
798 |
799 |
800 |
801 | unfinished_sequences = unfinished_sequences & ~stopping_criteria(generate_ids, None)
802 | this_peer_finished = unfinished_sequences.max() == 0
803 | cur_len += 1
804 | num_write_in_turn += 1
805 |
806 | # This is needed to properly delete outputs.logits which may be very large for first iteration
807 | # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
808 | del outputs
809 |
810 | turn += 1
811 |
812 | if return_dict_in_generate:
813 | return ModifiedGenerateDecoderOnlyOutput(
814 | sequences=generate_ids,
815 | features=inputs_embeds,
816 | scores=scores,
817 | logits=raw_logits,
818 | attentions=decoder_attentions,
819 | hidden_states=decoder_hidden_states,
820 | past_key_values=model_kwargs.get("past_key_values"),
821 | )
822 | else:
823 | return (generate_ids, inputs_embeds)
824 |
825 | def _update_model_kwargs_for_generation(
826 | self,
827 | outputs,
828 | model_kwargs: Dict[str, Any],
829 | num_new_tokens: int = 1,
830 | ) -> Dict[str, Any]:
831 | # update past_key_values keeping its naming used in model code
832 | cache_name, cache = self._extract_past_from_model_output(outputs)
833 | model_kwargs[cache_name] = cache
834 |
835 | # update attention mask
836 | if "attention_mask" in model_kwargs:
837 | attention_mask = model_kwargs["attention_mask"]
838 | model_kwargs["attention_mask"] = torch.cat(
839 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
840 | )
841 |
842 | if model_kwargs.get("use_cache", True):
843 | model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
844 | else:
845 | raise NotImplementedError
846 |
847 | return model_kwargs
848 |
--------------------------------------------------------------------------------
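The interleaving in `prepare_inputs_labels_for_multimodal` above follows an `n`-text / `m`-speech cycle controlled by `stream_n` and `stream_m` (5 and 45 by default). A toy call to the helper, assuming the `sled` package is importable; the shapes and the shortened `m` are arbitrary and only chosen so that neither stream is an exact multiple of its chunk size:

```python
import torch
from sled.sled_stream import interleave_embeddings_and_mask_efficient

b, d = 1, 4
text_embeds = torch.randn(b, 12, d)    # 12 text positions
speech_embeds = torch.randn(b, 20, d)  # 20 speech latent frames
text_mask = torch.ones(b, 12)
speech_mask = torch.ones(b, 20)

interleaved, interleaved_mask, speech_positions = interleave_embeddings_and_mask_efficient(
    text_embeds, speech_embeds, text_mask, speech_mask, n=5, m=10
)

# Two full (5 text + 10 speech) cycles, then the 2 leftover text positions:
# layout along dim 1 is [5T, 10S, 5T, 10S, 2T], total length 32 = 12 + 20
print(interleaved.shape)    # torch.Size([1, 32, 4])
print(speech_positions[0])  # 1.0 at speech positions, 0.0 at text positions
```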
/sled/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List, Optional
3 |
4 | from torch.utils.data import Dataset
5 | from transformers import Trainer
6 | from transformers.trainer import has_length
7 | from transformers.trainer_pt_utils import LengthGroupedSampler
8 |
9 | class SpeechLengthGroupedSampler(LengthGroupedSampler):
10 | r"""
11 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
12 | keeping a bit of randomness.
13 | """
14 |
15 | def __init__(
16 | self,
17 | batch_size: int,
18 | dataset: Optional[Dataset] = None,
19 | generator=None,
20 | ):
21 |
22 | self.batch_size = batch_size
23 |
24 | lengths = [len(feature["audio"]["array"]) for feature in dataset]
25 | self.lengths = lengths
26 | self.generator = generator
27 |
28 |
29 |
30 |
31 | class SpeechLlamaTrainer(Trainer):
32 |
33 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
34 | if self.train_dataset is None or not has_length(self.train_dataset):
35 | return None
36 |
37 |
38 | if self.args.group_by_speech_length:
39 | return SpeechLengthGroupedSampler(
40 | self.args.train_batch_size * self.args.gradient_accumulation_steps,
41 | dataset=self.train_dataset,
42 | )
43 | else:
44 | return super()._get_train_sampler()
--------------------------------------------------------------------------------
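`_get_train_sampler` only switches to the grouped sampler when `self.args.group_by_speech_length` is set, which is not a standard `TrainingArguments` field. A hypothetical declaration (the actual one presumably lives in the training scripts) might look like:

```python
from dataclasses import dataclass, field
from transformers import TrainingArguments

@dataclass
class SpeechTrainingArguments(TrainingArguments):
    # Hypothetical field name matching the attribute read by SpeechLlamaTrainer.
    group_by_speech_length: bool = field(
        default=True,
        metadata={"help": "Group samples of similar audio length into the same batch."},
    )
```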
/sled/trainer_libriheavy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import List, Optional
3 |
4 | from torch.utils.data import Dataset
5 | from transformers import Trainer
6 | from transformers.trainer import has_length
7 | from transformers.trainer_pt_utils import LengthGroupedSampler
8 |
9 | class SpeechLengthGroupedSampler(LengthGroupedSampler):
10 | r"""
11 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
12 | keeping a bit of randomness.
13 | """
14 |
15 | def __init__(
16 | self,
17 | batch_size: int,
18 | dataset: Optional[Dataset] = None,
19 | generator=None,
20 | ):
21 |
22 | self.batch_size = batch_size
23 |
24 | lengths = [int(feature["duration"] * 75) for feature in dataset]
25 | self.lengths = lengths
26 | self.generator = generator
27 |
28 |
29 |
30 |
31 | class SpeechLlamaTrainer(Trainer):
32 |
33 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
34 | if self.train_dataset is None or not has_length(self.train_dataset):
35 | return None
36 |
37 |
38 | if self.args.group_by_speech_length:
39 | return SpeechLengthGroupedSampler(
40 | self.args.train_batch_size * self.args.gradient_accumulation_steps,
41 | dataset=self.train_dataset,
42 | )
43 | else:
44 | return super()._get_train_sampler()
45 |
--------------------------------------------------------------------------------
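The two trainers differ only in the length proxy used for bucketing: `trainer.py` counts raw waveform samples, while `trainer_libriheavy.py` uses `duration * 75` codec frames. A short sketch of that arithmetic, assuming the 24 kHz EnCodec codec with a 320-sample hop (the `[..., ::320]` stride in `sled_stream.py`); the constants and helper names are illustrative:

```python
CODEC_SAMPLE_RATE = 24_000
CODEC_HOP = 320
FRAMES_PER_SECOND = CODEC_SAMPLE_RATE // CODEC_HOP  # 75, matching `duration * 75`

def latent_frames(duration_s: float) -> int:
    """Length proxy used by trainer_libriheavy.py (codec latent frames)."""
    return int(duration_s * FRAMES_PER_SECOND)

def waveform_samples(duration_s: float, sample_rate: int = CODEC_SAMPLE_RATE) -> int:
    """Length proxy used by trainer.py (raw waveform samples at the dataset's own rate)."""
    return int(duration_s * sample_rate)

STREAM_M = 45  # default speech frames written per streaming cycle (SpeechLlamaConfig)
print(latent_frames(10.0))            # 750 frames for a 10-second utterance
print(STREAM_M / FRAMES_PER_SECOND)   # 0.6 seconds of audio generated per write cycle
```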