├── .gitignore ├── .vscode └── launch.json ├── 22k_to_16k.sh ├── LICENSE ├── Llama-3.2-1B-Instruct └── config.json ├── README.md ├── data.json ├── images └── model.png ├── make_data.py ├── omni_speech ├── arguments.py ├── constants.py ├── conversation.py ├── datasets │ ├── __init__.py │ └── preprocess.py ├── infer │ ├── convert_jsonl_to_txt.py │ ├── gen_answer_data │ │ ├── answer.json │ │ ├── question.json │ │ └── wavs │ ├── infer.py │ ├── run.sh │ └── unit2wav.sh ├── model │ ├── .ipynb_checkpoints │ │ └── omni_speech_arch-checkpoint.py │ ├── __init__.py │ ├── builder.py │ ├── language_model │ │ ├── .ipynb_checkpoints │ │ │ └── omni_speech_llama-checkpoint.py │ │ ├── omni_speech2s_llama.py │ │ └── omni_speech_llama.py │ ├── omni_speech_arch.py │ ├── speech_encoder │ │ ├── builder.py │ │ └── speech_encoder.py │ ├── speech_generator │ │ ├── builder.py │ │ ├── generation.py │ │ └── speech_generator.py │ └── speech_projector │ │ ├── builder.py │ │ └── speech_projector.py ├── serve │ ├── __init__.py │ ├── controller.py │ ├── examples │ │ ├── helpful_base_1.wav │ │ ├── helpful_base_2.wav │ │ ├── helpful_base_3.wav │ │ ├── helpful_base_4.wav │ │ ├── helpful_base_5.wav │ │ ├── vicuna_1.wav │ │ ├── vicuna_2.wav │ │ ├── vicuna_3.wav │ │ ├── vicuna_4.wav │ │ └── vicuna_5.wav │ ├── gradio_web_server.py │ └── model_worker.py ├── train │ ├── run.sh │ ├── stage1.py │ └── stage2.py └── utils.py ├── pyproject.toml ├── r_100.txt ├── requirements.txt └── wavs ├── sft_1.wav ├── sft_10.wav ├── sft_100.wav ├── sft_11.wav ├── sft_12.wav ├── sft_13.wav ├── sft_14.wav ├── sft_15.wav ├── sft_16.wav ├── sft_17.wav ├── sft_18.wav ├── sft_19.wav ├── sft_2.wav ├── sft_20.wav ├── sft_21.wav ├── sft_22.wav ├── sft_23.wav ├── sft_24.wav ├── sft_25.wav ├── sft_26.wav ├── sft_27.wav ├── sft_28.wav ├── sft_29.wav ├── sft_3.wav ├── sft_30.wav ├── sft_31.wav ├── sft_32.wav ├── sft_33.wav ├── sft_34.wav ├── sft_35.wav ├── sft_36.wav ├── sft_37.wav ├── sft_38.wav ├── sft_39.wav ├── sft_4.wav ├── sft_40.wav ├── sft_41.wav ├── sft_42.wav ├── sft_43.wav ├── sft_44.wav ├── sft_45.wav ├── sft_46.wav ├── sft_47.wav ├── sft_48.wav ├── sft_49.wav ├── sft_5.wav ├── sft_50.wav ├── sft_51.wav ├── sft_52.wav ├── sft_53.wav ├── sft_54.wav ├── sft_55.wav ├── sft_56.wav ├── sft_57.wav ├── sft_58.wav ├── sft_59.wav ├── sft_6.wav ├── sft_60.wav ├── sft_61.wav ├── sft_62.wav ├── sft_63.wav ├── sft_64.wav ├── sft_65.wav ├── sft_66.wav ├── sft_67.wav ├── sft_68.wav ├── sft_69.wav ├── sft_7.wav ├── sft_70.wav ├── sft_71.wav ├── sft_72.wav ├── sft_73.wav ├── sft_74.wav ├── sft_75.wav ├── sft_76.wav ├── sft_77.wav ├── sft_78.wav ├── sft_79.wav ├── sft_8.wav ├── sft_80.wav ├── sft_81.wav ├── sft_82.wav ├── sft_83.wav ├── sft_84.wav ├── sft_85.wav ├── sft_86.wav ├── sft_87.wav ├── sft_88.wav ├── sft_89.wav ├── sft_9.wav ├── sft_90.wav ├── sft_91.wav ├── sft_92.wav ├── sft_93.wav ├── sft_94.wav ├── sft_95.wav ├── sft_96.wav ├── sft_97.wav ├── sft_98.wav └── sft_99.wav /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | llama_omni.egg-info/ 3 | models/ 4 | vocoder/ 5 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // 使用 IntelliSense 了解相关属性。 3 | // 悬停以查看现有属性的描述。 4 | // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: llama", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": true, 14 | "args": [ 15 | "--model-path", "Llama-3.2-1B-Instruct", 16 | "--question-file", "data.json", 17 | "--answer-file", "answer.json", 18 | "--num-chunks", "1", 19 | "--chunk-idx", "0", 20 | "--temperature", "0", 21 | "--conv-mode", "llama_3", 22 | "--input_type", "mel", 23 | "--mel_size", "128" 24 | ], 25 | "python":"/opt/conda/bin/python" 26 | } 27 | ] 28 | } -------------------------------------------------------------------------------- /22k_to_16k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 定义输入和输出目录 4 | input_dir="wavs_22k" 5 | output_dir="wavs" 6 | 7 | # 创建输出目录(如果不存在) 8 | mkdir -p "$output_dir" 9 | 10 | # 遍历输入目录中的所有文件 11 | for input_file in "$input_dir"/*; do 12 | # 获取文件的基本名称 13 | base_name=$(basename "$input_file") 14 | # 定义输出文件的路径 15 | output_file="$output_dir/$base_name" 16 | 17 | # 使用 ffmpeg 转换音频采样率为 16kHz 18 | ffmpeg -i "$input_file" -ar 16000 "$output_file" 19 | done 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Llama-3.2-1B-Instruct/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "Llama-3.2-1B-Omni", 3 | "architectures": [ 4 | "OmniSpeechLlamaForCausalLM" 5 | ], 6 | "attention_bias": false, 7 | "attention_dropout": 0.0, 8 | "bos_token_id": 128000, 9 | "ctc_decoder_config": "(2,4096,32,11008)", 10 | "ctc_loss_weight": 1.0, 11 | "ctc_upsample_factor": 25, 12 | "eos_token_id": [ 13 | 128001, 14 | 128008, 15 | 128009 16 | ], 17 | "freeze_speech_projector": false, 18 | "head_dim": 64, 19 | "hidden_act": "silu", 20 | "hidden_size": 2048, 21 | "initializer_range": 0.02, 22 | "intermediate_size": 8192, 23 | "max_position_embeddings": 131072, 24 | "mlp_bias": false, 25 | "model_type": "omni_speech_llama", 26 | "num_attention_heads": 32, 27 | "num_hidden_layers": 16, 28 | "num_key_value_heads": 8, 29 | "pretraining_tp": 1, 30 | "rms_norm_eps": 1e-05, 31 | "rope_scaling": { 32 | "factor": 32.0, 33 | "high_freq_factor": 4.0, 34 | "low_freq_factor": 1.0, 35 | "original_max_position_embeddings": 8192, 36 | "rope_type": "llama3" 37 | }, 38 | "rope_theta": 500000.0, 39 | "speech_encoder": "models/speech_encoder/large-v3.pt", 40 | "speech_encoder_ds_rate": 5, 41 | "speech_encoder_hidden_size": 1280, 42 | "speech_encoder_type": "whisper", 43 | "speech_generator_type": "ctc", 44 | "speech_normalize": false, 45 | "speech_projector_lr": 0.001, 46 | "speech_projector_type": "linear", 47 | "tie_word_embeddings": false, 48 | "tokenizer_model_max_length": 2048, 49 | "tokenizer_padding_side": "right", 50 | "torch_dtype": "bfloat16", 51 | "transformers_version": "4.43.4", 52 | "tune_speech_projector": true, 53 | "unit_vocab_size": 1000, 54 | "use_cache": true, 55 | "vocab_size": 128256 56 | } 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **LLama-Omni Training Code Recurrence(llama-omni训练代码复现)** 2 | 3 | 1.根据LLama-omni给的方法进行环境安装(https://github.com/ictnlp/LLaMA-Omni) 4 | 5 | 2.wavs下是根据论文方法生成的100条数据(指令相同,模型用的Qwen,其他保持一致),用于阶段一和阶段二的训练 6 | 7 | 3.whisper模型放到models下(需要自己下载);一阶段我用的是1B模型(Llama-3.2-1B-Instruct),下载后放到当前目录下,记得修改config.json的关于whisper配置,否则会报错;如果跑的起8B,可直接用原论文的权重,或原生8B进行修改。 8 | 9 | 4.vocoder是音频生成模块的权重(已包含)。 10 | 11 | 5.二阶段的数据在omni_speech/infer/gen_answer_data/answer.json,用wavs下的question音频生成的回复生成的token。 12 | 13 | 6.两个阶段的精度用的都是bf16,loss可以正常下降,单卡3090,一阶段由于是1B可以多卡跑起来,二阶段由于设备有限,只在单卡上跑了跑,可以收敛到2。fp16会有loss nan的问题。 14 | 15 | 7.启动方法: 16 | 17 | 一阶段:bash omni_speech/train/run.sh 18 | 19 | 二阶段:python omni_speech/train/stage2.py 20 | 21 | PS: 22 | 23 | 1.感谢LLama-Omni的工作!也感谢另一位小伙伴@EDGSCOUT-li,我们一起复现这篇论文的训练过程,由于我们打算去复现freeze_omni(https://github.com/VITA-MLLM/Freeze-Omni)。 24 | 25 | 如果有issue看到后会及时回复,并对该项目进行更新。 26 | 27 | 2.由于资源限制还没有用更多的数据训练好一个能work的模型并进行发布,后续可能会结合freeze的方法训练好一个中文端到端对话的模型。 28 | 29 | 3.希望能对做端到端语音对话的朋友有一点点帮助 -------------------------------------------------------------------------------- /images/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/images/model.png -------------------------------------------------------------------------------- /make_data.py: -------------------------------------------------------------------------------- 1 | # temple="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\ 2 | # You are a helpful language and speech assistant.You are able to understand the speech content that the user provides,\ 3 | # and assist the user with a variety of tasks using natural language.<|eot_id|>\ 4 | # <|start_header_id|>user<|end_header_id|>Please answer the questions in the user’s input speech.<|eot_id|>\ 5 | # <|start_header_id|>assistant<|end_header_id|><|end_of_text|>" 6 | 7 | 8 | 9 | #读取目录下所有wav文件,并打印出来路径 10 | import os 11 | import json 12 | wav_dir = 'wavs' 13 | wav_files = os.listdir(wav_dir) 14 | 15 | #打开文件并读取内容,和wav_files同时遍历 16 | with open('r_100.txt', 'r', encoding='utf-8') as f: 17 | responses = f.readlines() 18 | with open("data.json", "w", encoding="utf-8") as file: 19 | saved_array = [] 20 | for wav,response in zip(wav_files,responses): 21 | #以json格式保存os.path.join(wav_dir,wav),response.strip() 22 | data={"id":wav.split('.')[0], 23 | "speech":os.path.join(wav_dir,wav), 24 | "conversations":[ 25 | { 26 | "from": "human", 27 | "value": "\nPlease directly answer the questions in the user's speech." 28 | }, 29 | { "from": "assistant", 30 | "value": response.strip() 31 | } 32 | ] 33 | } 34 | saved_array.append(data) 35 | json.dump(saved_array, file, indent=4, ensure_ascii=False) 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /omni_speech/arguments.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Optional 5 | 6 | 7 | @dataclass 8 | class ModelArguments: 9 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 10 | version: Optional[str] = field(default="v0") 11 | freeze_backbone: bool = field(default=False) 12 | tune_speech_projector: bool = field(default=False) 13 | tune_speech_encoder: bool = field(default=False) 14 | tune_speech_generator_only: bool = field(default=False) 15 | speech_encoder_type: Optional[str] = field(default=None) 16 | speech_encoder: Optional[str] = field(default=None) 17 | pretrain_speech_projector: Optional[str] = field(default=None) 18 | speech_projector_type: Optional[str] = field(default='linear') 19 | speech_generator_type: Optional[str] = field(default='ctc') 20 | ctc_decoder_config: str = "(2,4096,32,11008)" 21 | ctc_upsample_factor: int = 1 22 | ctc_loss_weight: float = 1.0 23 | unit_vocab_size: int = 1000 24 | speech_encoder_ds_rate: int = 5 25 | speech_encoder_hidden_size: int = 1280 26 | 27 | 28 | @dataclass 29 | class DataArguments: 30 | data_path: str = field(default=None, 31 | metadata={"help": "Path to the training data."}) 32 | is_multimodal: bool = False 33 | input_type: str = field(default="mel") 34 | speech_normalize: bool = False 35 | mel_size: int = 128 36 | has_tgt_units: bool = False 37 | 38 | 39 | @dataclass 40 | class TrainingArguments(transformers.TrainingArguments): 41 | cache_dir: Optional[str] = field(default=None) 42 | optim: str = field(default="adamw_torch") 43 | freeze_speech_projector: bool = field(default=False) 44 | model_max_length: int = field( 45 | default=512, 46 | metadata={ 47 | "help": 48 | "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 49 | }, 50 | ) 51 | double_quant: bool = field( 52 | default=True, 53 | metadata={"help": "Compress the quantization statistics through double quantization."} 54 | ) 55 | quant_type: str = field( 56 | default="nf4", 57 | metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} 58 | ) 59 | bits: int = field( 60 | default=16, 61 | metadata={"help": "How many bits to use."} 62 | ) 63 | lora_enable: bool = False 64 | lora_r: int = 64 65 | lora_alpha: int = 16 66 | lora_dropout: float = 0.05 67 | lora_weight_path: str = "" 68 | lora_bias: str = "none" 69 | speech_projector_lr: Optional[float] = None 70 | group_by_modality_length: bool = field(default=False) -------------------------------------------------------------------------------- /omni_speech/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | SPEECH_TOKEN_INDEX = -200 9 | DEFAULT_SPEECH_TOKEN = "" -------------------------------------------------------------------------------- /omni_speech/conversation.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright: 2 | # Copyright 2023 Haotian Liu 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import dataclasses 17 | from enum import auto, Enum 18 | from typing import List, Any, Union, Tuple 19 | import base64 20 | from io import BytesIO 21 | from PIL import Image 22 | 23 | 24 | class SeparatorStyle(Enum): 25 | """Different separator style.""" 26 | TWO = auto() 27 | PLAIN = auto() 28 | LLAMA_2 = auto() 29 | LLAMA_3 = auto() 30 | 31 | 32 | @dataclasses.dataclass 33 | class Conversation: 34 | """A class that keeps all conversation history.""" 35 | system: str 36 | roles: List[str] 37 | messages: List[List[str]] 38 | offset: int 39 | sep_style: SeparatorStyle = SeparatorStyle.PLAIN 40 | sep: str = "###" 41 | sep2: str = None 42 | version: str = "Unknown" 43 | 44 | tokenizer_id: str = "" 45 | tokenizer: Any = None 46 | # Stop criteria (the default one is EOS token) 47 | stop_str: Union[str, List[str]] = None 48 | # Stops generation if meeting any token in this list 49 | stop_token_ids: List[int] = None 50 | 51 | skip_next: bool = False 52 | 53 | def get_prompt(self): 54 | messages = self.messages 55 | 56 | if self.sep_style == SeparatorStyle.TWO: 57 | seps = [self.sep, self.sep2] 58 | ret = self.system + seps[0] 59 | for i, (role, message) in enumerate(messages): 60 | if message: 61 | if type(message) is tuple: 62 | message = message[0] 63 | ret += role + ": " + message + seps[i % 2] 64 | else: 65 | ret += role + ":" 66 | elif self.sep_style == SeparatorStyle.LLAMA_3: 67 | wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg 68 | ret = "<|begin_of_text|>" + wrap_sys(self.system) 69 | for i, (role, message) in enumerate(messages): 70 | if message: 71 | if type(message) is tuple: 72 | message = message[0] 73 | ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" 74 | ret += message.strip() + self.sep2 75 | else: 76 | ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" 77 | return ret 78 | elif self.sep_style == SeparatorStyle.LLAMA_2: 79 | wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg 80 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]" 81 | ret = "" 82 | 83 | for i, (role, message) in enumerate(messages): 84 | if i == 0: 85 | assert message, "first message should not be none" 86 | assert role == self.roles[0], "first message should come from user" 87 | if message: 88 | if type(message) is tuple: 89 | message, _, _ = message 90 | if i == 0: 91 | message = wrap_sys(self.system) + message 92 | if i % 2 == 0: 93 | message = wrap_inst(message) 94 | ret += self.sep + message 95 | else: 96 | ret += " " + message + " " + self.sep2 97 | else: 98 | ret += "" 99 | ret = ret.lstrip(self.sep) 100 | elif self.sep_style == SeparatorStyle.PLAIN: 101 | seps = [self.sep, self.sep2] 102 | ret = self.system 103 | for i, (role, message) in enumerate(messages): 104 | if message: 105 | if type(message) is tuple: 106 | message, _, _ = message 107 | ret += message + seps[i % 2] 108 | else: 109 | ret += "" 110 | else: 111 | raise ValueError(f"Invalid style: {self.sep_style}") 112 | 113 | return ret 114 | 115 | def append_message(self, role, message): 116 | self.messages.append([role, message]) 117 | 118 | def to_gradio_chatbot(self): 119 | ret = [] 120 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 121 | if i % 2 == 0: 122 | if type(msg) is tuple: 123 | msg, speech = msg 124 | ret.append([msg, None]) 125 | else: 126 | ret.append([msg, None]) 127 | else: 128 | ret[-1][-1] = msg 129 | return ret 130 | 131 | def copy(self): 132 | return Conversation( 133 | system=self.system, 134 | roles=self.roles, 135 | messages=[[x, y] for x, y in self.messages], 136 | offset=self.offset, 137 | sep_style=self.sep_style, 138 | sep=self.sep, 139 | sep2=self.sep2, 140 | version=self.version) 141 | 142 | def dict(self): 143 | if len(self.get_images()) > 0: 144 | return { 145 | "system": self.system, 146 | "roles": self.roles, 147 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 148 | "offset": self.offset, 149 | "sep": self.sep, 150 | "sep2": self.sep2, 151 | } 152 | return { 153 | "system": self.system, 154 | "roles": self.roles, 155 | "messages": self.messages, 156 | "offset": self.offset, 157 | "sep": self.sep, 158 | "sep2": self.sep2, 159 | } 160 | 161 | conv_vicuna_v1 = Conversation( 162 | system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.", 163 | roles=("USER", "ASSISTANT"), 164 | version="v1", 165 | messages=[], 166 | offset=0, 167 | sep_style=SeparatorStyle.TWO, 168 | sep=" ", 169 | sep2="", 170 | ) 171 | 172 | conv_llama_2 = Conversation( 173 | system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.", 174 | roles=("USER", "ASSISTANT"), 175 | version="llama_v2", 176 | messages=[], 177 | offset=0, 178 | sep_style=SeparatorStyle.LLAMA_2, 179 | sep="", 180 | sep2="", 181 | ) 182 | 183 | conv_llama_3 = Conversation( 184 | system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.", 185 | roles=("user", "assistant"), 186 | version="llama_v3", 187 | messages=[], 188 | offset=0, 189 | sep_style=SeparatorStyle.LLAMA_3, 190 | sep="", 191 | sep2="<|eot_id|>" 192 | ) 193 | 194 | conv_plain = Conversation( 195 | system="", 196 | roles=("", ""), 197 | messages=( 198 | ), 199 | offset=0, 200 | sep_style=SeparatorStyle.PLAIN, 201 | sep="", 202 | ) 203 | 204 | 205 | default_conversation = conv_llama_3 206 | conv_templates = { 207 | "v1": conv_vicuna_v1, 208 | "plain": conv_plain, 209 | "llama_2": conv_llama_2, 210 | "llama_3": conv_llama_3, 211 | } 212 | 213 | 214 | if __name__ == "__main__": 215 | print(default_conversation.get_prompt()) 216 | -------------------------------------------------------------------------------- /omni_speech/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/omni_speech/datasets/__init__.py -------------------------------------------------------------------------------- /omni_speech/datasets/preprocess.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright: 2 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 3 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 4 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import copy 19 | import torch 20 | import transformers 21 | import tokenizers 22 | 23 | from typing import Dict, Sequence 24 | 25 | from omni_speech.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN 26 | from omni_speech import conversation as conversation_lib 27 | from omni_speech.model import * 28 | from omni_speech.arguments import DataArguments 29 | from omni_speech.constants import SPEECH_TOKEN_INDEX 30 | 31 | from packaging import version 32 | 33 | IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') 34 | 35 | 36 | def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None): 37 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 38 | 39 | def insert_separator(X, sep): 40 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 41 | 42 | input_ids = [] 43 | offset = 0 44 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 45 | offset = 1 46 | input_ids.append(prompt_chunks[0][0]) 47 | 48 | for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)): 49 | input_ids.extend(x[offset:]) 50 | 51 | if return_tensors is not None: 52 | if return_tensors == 'pt': 53 | return torch.tensor(input_ids, dtype=torch.long) 54 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 55 | return input_ids 56 | 57 | 58 | def preprocess_multimodal( 59 | sources: Sequence[str], 60 | data_args: DataArguments 61 | ) -> Dict: 62 | is_multimodal = data_args.is_multimodal 63 | if not is_multimodal: 64 | return sources 65 | 66 | for source in sources: 67 | for sentence in source: 68 | if DEFAULT_SPEECH_TOKEN in sentence['value']: 69 | sentence['value'] = sentence['value'].replace(DEFAULT_SPEECH_TOKEN, '').strip() 70 | sentence['value'] = DEFAULT_SPEECH_TOKEN + '\n' + sentence['value'] 71 | sentence['value'] = sentence['value'].strip() 72 | 73 | return sources 74 | 75 | 76 | def preprocess_llama_2( 77 | sources, 78 | tokenizer: transformers.PreTrainedTokenizer, 79 | has_speech: bool = False 80 | ) -> Dict: 81 | conv = conversation_lib.default_conversation.copy() 82 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 83 | 84 | # Apply prompt templates 85 | conversations = [] 86 | for i, source in enumerate(sources): 87 | if roles[source[0]["from"]] != conv.roles[0]: 88 | # Skip the first one if it is not from human 89 | source = source[1:] 90 | 91 | conv.messages = [] 92 | for j, sentence in enumerate(source): 93 | role = roles[sentence["from"]] 94 | assert role == conv.roles[j % 2], f"{i}" 95 | conv.append_message(role, sentence["value"]) 96 | conversations.append(conv.get_prompt()) 97 | 98 | # Tokenize conversations 99 | 100 | if has_speech: 101 | input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) 102 | else: 103 | input_ids = tokenizer( 104 | conversations, 105 | return_tensors="pt", 106 | padding="longest", 107 | max_length=tokenizer.model_max_length, 108 | truncation=True, 109 | ).input_ids 110 | 111 | targets = input_ids.clone() 112 | 113 | assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2 114 | 115 | # Mask targets 116 | sep = "[/INST] " 117 | for conversation, target in zip(conversations, targets): 118 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 119 | 120 | rounds = conversation.split(conv.sep2) 121 | cur_len = 1 122 | target[:cur_len] = IGNORE_INDEX 123 | for i, rou in enumerate(rounds): 124 | if rou == "": 125 | break 126 | 127 | parts = rou.split(sep) 128 | if len(parts) != 2: 129 | break 130 | parts[0] += sep 131 | 132 | if has_speech: 133 | round_len = len(tokenizer_speech_token(rou, tokenizer)) 134 | instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2 135 | else: 136 | round_len = len(tokenizer(rou).input_ids) 137 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 138 | 139 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 140 | 141 | cur_len += round_len 142 | target[cur_len:] = IGNORE_INDEX 143 | 144 | if cur_len < tokenizer.model_max_length: 145 | if cur_len != total_len: 146 | target[:] = IGNORE_INDEX 147 | print( 148 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 149 | f" (ignored)" 150 | ) 151 | 152 | return dict( 153 | input_ids=input_ids, 154 | labels=targets, 155 | ) 156 | 157 | 158 | def preprocess_llama_3( 159 | sources, 160 | tokenizer: transformers.PreTrainedTokenizer, 161 | has_speech: bool = False 162 | ) -> Dict: 163 | conv = conversation_lib.default_conversation.copy() 164 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 165 | 166 | # Apply prompt templates 167 | conversations = [] 168 | for i, source in enumerate(sources): 169 | if roles[source[0]["from"]] != conv.roles[0]: 170 | # Skip the first one if it is not from human 171 | source = source[1:] 172 | 173 | assert len(source) == 2, "now only support single-turn conversation" 174 | 175 | conv.messages = [] 176 | for j, sentence in enumerate(source): 177 | role = roles[sentence["from"]] 178 | assert role == conv.roles[j % 2], f"{i}" 179 | conv.append_message(role, sentence["value"]) 180 | conversations.append(conv.get_prompt()) 181 | 182 | # Tokenize conversations 183 | 184 | if has_speech: 185 | input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) 186 | else: 187 | input_ids = tokenizer( 188 | conversations, 189 | return_tensors="pt", 190 | padding="longest", 191 | max_length=tokenizer.model_max_length, 192 | truncation=True, 193 | ).input_ids 194 | 195 | targets = input_ids.clone() 196 | 197 | assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3 198 | 199 | # Mask targets 200 | sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>\n\n" 201 | for conversation, target in zip(conversations, targets): 202 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 203 | 204 | cur_len = 1 205 | target[:cur_len] = IGNORE_INDEX 206 | parts = conversation.split(sep) 207 | parts[0] += sep 208 | 209 | if has_speech: 210 | conversation_len = len(tokenizer_speech_token(conversation, tokenizer)) 211 | instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 1 212 | else: 213 | conversation_len = len(tokenizer(conversation).input_ids) 214 | instruction_len = len(tokenizer(parts[0]).input_ids) - 1 215 | 216 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 217 | cur_len += conversation_len 218 | target[cur_len:] = IGNORE_INDEX 219 | 220 | # if cur_len < tokenizer.model_max_length: 221 | # if cur_len != total_len: 222 | # target[:] = IGNORE_INDEX 223 | # print( 224 | # f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 225 | # f" (ignored)" 226 | # ) 227 | 228 | return dict( 229 | input_ids=input_ids, 230 | labels=targets, 231 | ) 232 | 233 | 234 | def preprocess_v1( 235 | sources, 236 | tokenizer: transformers.PreTrainedTokenizer, 237 | has_speech: bool = False 238 | ) -> Dict: 239 | conv = conversation_lib.default_conversation.copy() 240 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 241 | 242 | # Apply prompt templates 243 | conversations = [] 244 | for i, source in enumerate(sources): 245 | if roles[source[0]["from"]] != conv.roles[0]: 246 | # Skip the first one if it is not from human 247 | source = source[1:] 248 | 249 | conv.messages = [] 250 | for j, sentence in enumerate(source): 251 | role = roles[sentence["from"]] 252 | assert role == conv.roles[j % 2], f"{i}" 253 | conv.append_message(role, sentence["value"]) 254 | conversations.append(conv.get_prompt()) 255 | 256 | # Tokenize conversations 257 | 258 | if has_speech: 259 | input_ids = torch.stack([tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) 260 | else: 261 | input_ids = tokenizer( 262 | conversations, 263 | return_tensors="pt", 264 | padding="longest", 265 | max_length=tokenizer.model_max_length, 266 | truncation=True, 267 | ).input_ids 268 | 269 | targets = input_ids.clone() 270 | 271 | assert conv.sep_style == conversation_lib.SeparatorStyle.TWO 272 | 273 | # Mask targets 274 | sep = conv.sep + conv.roles[1] + ": " 275 | for conversation, target in zip(conversations, targets): 276 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 277 | 278 | rounds = conversation.split(conv.sep2) 279 | cur_len = 1 280 | target[:cur_len] = IGNORE_INDEX 281 | for i, rou in enumerate(rounds): 282 | if rou == "": 283 | break 284 | 285 | parts = rou.split(sep) 286 | if len(parts) != 2: 287 | break 288 | parts[0] += sep 289 | 290 | if has_speech: 291 | round_len = len(tokenizer_speech_token(rou, tokenizer)) 292 | instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2 293 | else: 294 | round_len = len(tokenizer(rou).input_ids) 295 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 296 | 297 | # FIXME: tokenizer bug 298 | if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: 299 | round_len -= 1 300 | instruction_len -= 1 301 | 302 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 303 | 304 | cur_len += round_len 305 | target[cur_len:] = IGNORE_INDEX 306 | 307 | if cur_len < tokenizer.model_max_length: 308 | if cur_len != total_len: 309 | target[:] = IGNORE_INDEX 310 | print( 311 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 312 | f" (ignored)" 313 | ) 314 | 315 | return dict( 316 | input_ids=input_ids, 317 | labels=targets, 318 | ) 319 | 320 | 321 | def preprocess_plain( 322 | sources: Sequence[str], 323 | tokenizer: transformers.PreTrainedTokenizer, 324 | ) -> Dict: 325 | # add end signal and concatenate together 326 | conversations = [] 327 | for source in sources: 328 | assert len(source) == 2 329 | assert DEFAULT_SPEECH_TOKEN in source[0]['value'] 330 | source[0]['value'] = DEFAULT_SPEECH_TOKEN 331 | conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep 332 | conversations.append(conversation) 333 | # tokenize conversations 334 | input_ids = [tokenizer_speech_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] 335 | targets = copy.deepcopy(input_ids) 336 | for target, source in zip(targets, sources): 337 | tokenized_len = len(tokenizer_speech_token(source[0]['value'], tokenizer)) 338 | target[:tokenized_len] = IGNORE_INDEX 339 | 340 | return dict(input_ids=input_ids, labels=targets) 341 | 342 | 343 | def preprocess( 344 | sources: Sequence[str], 345 | tokenizer: transformers.PreTrainedTokenizer, 346 | has_speech: bool = False 347 | ) -> Dict: 348 | """ 349 | Given a list of sources, each is a conversation list. This transform: 350 | 1. Add signal '### ' at the beginning each sentence, with end signal '\n'; 351 | 2. Concatenate conversations together; 352 | 3. Tokenize the concatenated conversation; 353 | 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX. 354 | """ 355 | if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: 356 | return preprocess_plain(sources, tokenizer) 357 | if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2: 358 | return preprocess_llama_2(sources, tokenizer, has_speech=has_speech) 359 | if conversation_lib.default_conversation.version.startswith("v1"): 360 | return preprocess_v1(sources, tokenizer, has_speech=has_speech) 361 | if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_3: 362 | return preprocess_llama_3(sources, tokenizer, has_speech=has_speech) 363 | raise NotImplementedError -------------------------------------------------------------------------------- /omni_speech/infer/convert_jsonl_to_txt.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | 4 | input_file = sys.argv[1] 5 | output_file = sys.argv[2] 6 | 7 | with open(input_file, "r") as fin, open(output_file, "w") as fout: 8 | data = fin.readlines() 9 | for line in data: 10 | item = json.loads(line) 11 | prediction_units = item["prediction_units"] 12 | if prediction_units != "": 13 | fout.write(prediction_units + "\n") 14 | else: 15 | fout.write("0\n") 16 | -------------------------------------------------------------------------------- /omni_speech/infer/gen_answer_data/wavs: -------------------------------------------------------------------------------- 1 | ../../../wavs -------------------------------------------------------------------------------- /omni_speech/infer/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | import whisper 8 | 9 | from omni_speech.constants import SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN 10 | from omni_speech.conversation import conv_templates, SeparatorStyle 11 | from omni_speech.model.builder import load_pretrained_model 12 | from omni_speech.utils import disable_torch_init 13 | from omni_speech.datasets.preprocess import tokenizer_speech_token 14 | from torch.utils.data import Dataset, DataLoader 15 | 16 | import math 17 | 18 | 19 | def split_list(lst, n): 20 | """Split a list into n (roughly) equal-sized chunks""" 21 | chunk_size = math.ceil(len(lst) / n) # integer division 22 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 23 | 24 | 25 | def get_chunk(lst, n, k): 26 | chunks = split_list(lst, n) 27 | return chunks[k] 28 | 29 | 30 | # Custom dataset class 31 | class CustomDataset(Dataset): 32 | def __init__(self, questions, tokenizer, model_config, input_type, mel_size): 33 | self.questions = questions 34 | self.tokenizer = tokenizer 35 | self.model_config = model_config 36 | self.input_type = input_type 37 | self.mel_size = mel_size 38 | 39 | def __getitem__(self, index): 40 | item = self.questions[index] 41 | speech_file = item["speech"] 42 | qs = item["conversations"][0]["value"] 43 | 44 | conv = conv_templates[args.conv_mode].copy() 45 | conv.append_message(conv.roles[0], qs) 46 | conv.append_message(conv.roles[1], None) 47 | prompt = conv.get_prompt() 48 | 49 | speech = whisper.load_audio(speech_file) 50 | if self.input_type == "raw": 51 | speech = torch.from_numpy(speech) 52 | if self.model_config.speech_normalize: 53 | speech = torch.nn.functional.layer_norm(speech, speech.shape) 54 | elif self.input_type == "mel": 55 | speech = whisper.pad_or_trim(speech) 56 | speech = whisper.log_mel_spectrogram(speech, n_mels=self.mel_size).permute(1, 0) 57 | print(prompt) 58 | input_ids = tokenizer_speech_token(prompt, self.tokenizer, return_tensors='pt') 59 | 60 | return input_ids, speech, torch.LongTensor([speech.shape[0]]) 61 | 62 | def __len__(self): 63 | return len(self.questions) 64 | 65 | 66 | def collate_fn(batch): 67 | input_ids, speech_tensors, speech_lengths = zip(*batch) 68 | input_ids = torch.stack(input_ids, dim=0) 69 | speech_tensors = torch.stack(speech_tensors, dim=0) 70 | speech_lengths = torch.stack(speech_lengths, dim=0) 71 | return input_ids, speech_tensors, speech_lengths 72 | 73 | 74 | def ctc_postprocess(tokens, blank): 75 | _toks = tokens.squeeze(0).tolist() 76 | deduplicated_toks = [v for i, v in enumerate(_toks) if i == 0 or v != _toks[i - 1]] 77 | hyp = [v for v in deduplicated_toks if v != blank] #官方493 222 78 | hyp = " ".join(list(map(str, hyp))) #1918 547 79 | return hyp 80 | 81 | # DataLoader 82 | def create_data_loader(questions, tokenizer, model_config, input_type, mel_size, batch_size=1, num_workers=4): 83 | assert batch_size == 1, "batch_size must be 1" 84 | dataset = CustomDataset(questions, tokenizer, model_config, input_type, mel_size) 85 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn) 86 | return data_loader 87 | 88 | 89 | def eval_model(args): 90 | # Model 91 | disable_torch_init() 92 | model_path = os.path.expanduser(args.model_path) 93 | tokenizer, model, context_len = load_pretrained_model(model_path, args.model_base, is_lora=args.is_lora, s2s=args.s2s) 94 | #tokenizer长度128000+256 model:OmniSpeech2SLlamaForCausalLM(OmniSpeechLlamaModel、speech_encoder、speech_projector、speech_generator) 95 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 96 | #{'id': 'helpful_base_1', 'speech': 'omni_speech/infer/examples/question_wav/helpful_base_1.wav', 'conversations': [{'from': 'human', 'value': "\nPlease directly answer the questions in the user's speech."}]} 97 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 98 | answers_file = os.path.expanduser(args.answer_file) 99 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 100 | ans_file = open(answers_file, "w") 101 | 102 | data_loader = create_data_loader(questions, tokenizer, model.config, args.input_type, args.mel_size) 103 | 104 | for (input_ids, speech_tensor, speech_length), item in tqdm(zip(data_loader, questions), total=len(questions)): 105 | #torch.Size([1, 62]),torch.Size([1, 3000, 128]) #tensor([[3000]]) 106 | idx = item["id"] 107 | try: 108 | answer = item["conversations"][1]["value"] 109 | except: 110 | answer = None 111 | input_ids = input_ids.to(device='cuda', non_blocking=True) 112 | speech_tensor = speech_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True) 113 | speech_length = speech_length.to(device='cuda', non_blocking=True) 114 | 115 | with torch.inference_mode(): 116 | if args.s2s: 117 | outputs = model.generate( 118 | input_ids, 119 | speech=speech_tensor, 120 | speech_lengths=speech_length, 121 | do_sample=True if args.temperature > 0 else False, 122 | temperature=args.temperature, 123 | top_p=args.top_p, 124 | num_beams=args.num_beams, 125 | max_new_tokens=args.max_new_tokens, 126 | use_cache=True, 127 | pad_token_id=128004, 128 | streaming_unit_gen=False, 129 | ) 130 | output_ids, output_units = outputs 131 | else: 132 | outputs = model.generate( 133 | input_ids, 134 | speech=speech_tensor, 135 | speech_lengths=speech_length, 136 | do_sample=True if args.temperature > 0 else False, 137 | temperature=args.temperature, 138 | top_p=args.top_p, 139 | num_beams=args.num_beams, 140 | max_new_tokens=args.max_new_tokens, 141 | use_cache=True, 142 | pad_token_id=128004, 143 | ) 144 | output_ids = outputs 145 | 146 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 147 | if args.s2s: 148 | output_units = ctc_postprocess(output_units, blank=model.config.unit_vocab_size) 149 | #547个token 150 | print(f"H-{idx}\t{outputs}") 151 | print(f"T-{idx}\t{answer}") 152 | if args.s2s: 153 | print(f"U-{idx}\t{output_units}") 154 | 155 | if args.s2s: 156 | ans_file.write(json.dumps({"question_id": idx, "prediction": outputs, "prediction_units": output_units, "answer": answer}) + "\n") 157 | else: 158 | ans_file.write(json.dumps({"question_id": idx, "prediction": outputs, "answer": answer}) + "\n") 159 | # ans_file.flush() 160 | ans_file.close() 161 | 162 | 163 | if __name__ == "__main__": 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 166 | parser.add_argument("--model-base", type=str, default=None) 167 | parser.add_argument("--question-file", type=str) 168 | parser.add_argument("--answer-file", type=str) 169 | parser.add_argument("--conv-mode", type=str, default="v1") 170 | parser.add_argument("--num-chunks", type=int, default=1) 171 | parser.add_argument("--chunk-idx", type=int, default=0) 172 | parser.add_argument("--temperature", type=float, default=0) 173 | parser.add_argument("--top_p", type=float, default=None) 174 | parser.add_argument("--num_beams", type=int, default=1) 175 | parser.add_argument("--max_new_tokens", type=int, default=256) 176 | parser.add_argument("--input_type", type=str, default="raw") 177 | parser.add_argument("--mel_size", type=int, default=128) 178 | parser.add_argument("--s2s", action="store_true", default=False) 179 | parser.add_argument("--is_lora", action="store_true", default=False) 180 | args = parser.parse_args() 181 | eval_model(args) 182 | -------------------------------------------------------------------------------- /omni_speech/infer/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ROOT=$1 4 | 5 | VOCODER_CKPT=vocoder/g_00500000 6 | VOCODER_CFG=vocoder/config.json 7 | 8 | python omni_speech/infer/infer.py \ 9 | --model-path saves/stage2_new/checkpoint-500 \ 10 | --question-file $ROOT/question.json \ 11 | --answer-file $ROOT/answer.json \ 12 | --num-chunks 1 \ 13 | --chunk-idx 0 \ 14 | --temperature 0 \ 15 | --conv-mode llama_3 \ 16 | --input_type mel \ 17 | --mel_size 128 \ 18 | --s2s 19 | python omni_speech/infer/convert_jsonl_to_txt.py $ROOT/answer.json $ROOT/answer.unit 20 | python fairseq/examples/speech_to_speech/generate_waveform_from_code.py \ 21 | --in-code-file $ROOT/answer.unit \ 22 | --vocoder $VOCODER_CKPT --vocoder-cfg $VOCODER_CFG \ 23 | --results-path $ROOT/answer_wav/ --dur-prediction 24 | 25 | # bash omni_speech/infer/run.sh omni_speech/infer/examples 26 | -------------------------------------------------------------------------------- /omni_speech/infer/unit2wav.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ROOT=omni_speech/infer/unit2wav_samples 4 | 5 | VOCODER_CKPT=vocoder/g_00500000 6 | VOCODER_CFG=vocoder/config.json 7 | 8 | python omni_speech/infer/convert_jsonl_to_txt.py $ROOT/answer.json $ROOT/answer.unit 9 | python fairseq/examples/speech_to_speech/generate_waveform_from_code.py \ 10 | --in-code-file $ROOT/answer.unit \ 11 | --vocoder $VOCODER_CKPT --vocoder-cfg $VOCODER_CFG \ 12 | --results-path $ROOT/answer_wav/ --dur-prediction -------------------------------------------------------------------------------- /omni_speech/model/.ipynb_checkpoints/omni_speech_arch-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/haotian-liu/LLaVA. We modify the code to support speech input. Below is the original copyright: 2 | # Copyright 2023 Haotian Liu 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from abc import ABC, abstractmethod 17 | 18 | import torch 19 | 20 | from .speech_encoder.builder import build_speech_encoder 21 | from .speech_projector.builder import build_speech_projector 22 | from omni_speech.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX 23 | from omni_speech.utils import lengths_to_padding_mask 24 | 25 | 26 | class OmniSpeechMetaModel: 27 | 28 | def __init__(self, config): 29 | super(OmniSpeechMetaModel, self).__init__(config) 30 | 31 | if hasattr(config, "speech_encoder"): 32 | self.speech_encoder = build_speech_encoder(config) 33 | self.speech_projector = build_speech_projector(config) 34 | 35 | def get_speech_encoder(self): 36 | speech_encoder = getattr(self, 'speech_encoder', None) 37 | if type(speech_encoder) is list: 38 | speech_encoder = speech_encoder[0] 39 | return speech_encoder 40 | 41 | def initialize_speech_modules(self, model_args, fsdp=None): 42 | self.config.speech_encoder = getattr(model_args, "speech_encoder", None) 43 | self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None) 44 | self.config.speech_projector_type = getattr(model_args, 'speech_projector_type', 'linear') 45 | self.config.speech_encoder_ds_rate = getattr(model_args, 'speech_encoder_ds_rate', 5) 46 | self.config.speech_encoder_hidden_size = getattr(model_args, 'speech_encoder_hidden_size', 1280) 47 | 48 | if self.get_speech_encoder() is None: 49 | speech_encoder = build_speech_encoder(self.config) 50 | if fsdp is not None and len(fsdp) > 0: 51 | self.speech_encoder = [speech_encoder] 52 | else: 53 | self.speech_encoder = speech_encoder 54 | 55 | if getattr(self, 'speech_projector', None) is None: 56 | self.speech_projector = build_speech_projector(self.config) 57 | else: 58 | # In case it is frozen by LoRA 59 | for p in self.speech_projector.parameters(): 60 | p.requires_grad = True 61 | 62 | if model_args.pretrain_speech_projector is not None: 63 | pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location='cpu') 64 | def get_w(weights, keyword): 65 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} 66 | 67 | self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, 'speech_projector')) 68 | 69 | 70 | class OmniSpeechMetaForCausalLM(ABC): 71 | 72 | @abstractmethod 73 | def get_model(self): 74 | pass 75 | 76 | def get_speech_encoder(self): 77 | return self.get_model().get_speech_encoder() 78 | 79 | def get_speech_projector(self): 80 | return self.get_model().speech_projector 81 | 82 | def encode_speech(self, speech, speech_lengths): 83 | speech_encoder_type = self.config.speech_encoder_type 84 | speech_encoder = self.get_speech_encoder() 85 | if "whisper" in speech_encoder_type.lower(): 86 | print(speech,speech_encoder) 87 | encoder_outs = speech_encoder(speech.permute(0, 2, 1)) 88 | speech_lengths = (speech_lengths + 1) // 2 89 | else: 90 | raise ValueError(f'Unknown speech encoder: {speech_encoder}') 91 | speech_projector_type = self.config.speech_projector_type 92 | speech_projector = self.get_speech_projector() 93 | if speech_projector_type == "linear": 94 | encoder_outs = speech_projector(encoder_outs) 95 | speech_lengths = speech_lengths // speech_projector.k 96 | else: 97 | raise ValueError(f'Unknown speech projector: {speech_projector_type}') 98 | speech_features = [encoder_outs[i, :speech_lengths[i]] for i in range(len(encoder_outs))] 99 | return speech_features 100 | 101 | def prepare_inputs_labels_for_speech_and_text( 102 | self, input_ids, position_ids, attention_mask, past_key_values, labels, 103 | speech, speech_lengths 104 | ): 105 | speech_encoder = self.get_speech_encoder() 106 | if speech_encoder is None or speech is None or input_ids.shape[1] == 1: 107 | return input_ids, position_ids, attention_mask, past_key_values, None, labels 108 | 109 | speech_features = self.encode_speech(speech, speech_lengths) 110 | 111 | # Let's just add dummy tensors if they do not exist, 112 | # it is a headache to deal with None all the time. 113 | # But it is not ideal, and if you have a better idea, 114 | # please open an issue / submit a PR, thanks. 115 | _labels = labels 116 | _position_ids = position_ids 117 | _attention_mask = attention_mask 118 | if attention_mask is None: 119 | attention_mask = torch.ones_like(input_ids, dtype=torch.bool) 120 | else: 121 | attention_mask = attention_mask.bool() 122 | if position_ids is None: 123 | position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) 124 | if labels is None: 125 | labels = torch.full_like(input_ids, IGNORE_INDEX) 126 | 127 | # remove the padding using attention_mask -- FIXME 128 | _input_ids = input_ids 129 | input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] 130 | labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] 131 | 132 | new_input_embeds = [] 133 | new_labels = [] 134 | cur_speech_idx = 0 135 | for batch_idx, cur_input_ids in enumerate(input_ids): 136 | num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum() #-200是语音的标记,出现几次有几个语音 137 | if num_speech == 0: 138 | cur_speech_features = speech_features[cur_speech_idx] 139 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) 140 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0]], dim=0) 141 | new_input_embeds.append(cur_input_embeds) 142 | new_labels.append(labels[batch_idx]) 143 | cur_speech_idx += 1 144 | continue 145 | 146 | speech_token_indices = [-1] + torch.where(cur_input_ids == SPEECH_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] 147 | #[-1, 45, 62] 45是token -200的位置 148 | cur_input_ids_nospeech = [] 149 | cur_labels = labels[batch_idx] #都是-100,长度是62 150 | cur_labels_nospeech = [] 151 | for i in range(len(speech_token_indices) - 1): 152 | cur_input_ids_nospeech.append(cur_input_ids[speech_token_indices[i]+1:speech_token_indices[i+1]]) #cur_input_ids[0:45]、cur_input_ids[46:62] 153 | cur_labels_nospeech.append(cur_labels[speech_token_indices[i]+1:speech_token_indices[i+1]]) #len 45的-100列表、 154 | split_sizes = [x.shape[0] for x in cur_labels_nospeech] 155 | cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech)) #Embedding(128256, 4096) torch.Size([61, 4096]) 156 | cur_input_embeds_no_speech = torch.split(cur_input_embeds, split_sizes, dim=0) 157 | cur_new_input_embeds = [] 158 | cur_new_labels = [] 159 | 160 | for i in range(num_speech + 1): #上面两段文本,中间加了一次音频特征 161 | cur_new_input_embeds.append(cur_input_embeds_no_speech[i]) 162 | cur_new_labels.append(cur_labels_nospeech[i]) 163 | if i < num_speech: 164 | cur_speech_features = speech_features[cur_speech_idx] #torch.Size([300, 4096]) 165 | cur_speech_idx += 1 166 | cur_new_input_embeds.append(cur_speech_features) 167 | cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) 168 | #填充和cur_speech_features等长的-100列表 169 | cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] 170 | 171 | cur_new_input_embeds = torch.cat(cur_new_input_embeds) #torch.Size([361, 4096]) 172 | cur_new_labels = torch.cat(cur_new_labels) #torch.Size([361]) 173 | 174 | new_input_embeds.append(cur_new_input_embeds) 175 | new_labels.append(cur_new_labels) 176 | 177 | # Truncate sequences to max length as speech features can make the sequence longer 178 | tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) 179 | if tokenizer_model_max_length is not None: 180 | new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] 181 | new_labels = [x[:tokenizer_model_max_length] for x in new_labels] 182 | 183 | # Combine them 184 | max_len = max(x.shape[0] for x in new_input_embeds) 185 | batch_size = len(new_input_embeds) 186 | 187 | new_input_embeds_padded = [] 188 | new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) 189 | attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) 190 | position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) 191 | 192 | for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): 193 | cur_len = cur_new_embed.shape[0] 194 | if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": 195 | new_input_embeds_padded.append(torch.cat(( 196 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), 197 | cur_new_embed 198 | ), dim=0)) 199 | if cur_len > 0: 200 | new_labels_padded[i, -cur_len:] = cur_new_labels 201 | attention_mask[i, -cur_len:] = True 202 | position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) 203 | else: 204 | new_input_embeds_padded.append(torch.cat(( 205 | cur_new_embed, 206 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) 207 | ), dim=0)) 208 | if cur_len > 0: 209 | new_labels_padded[i, :cur_len] = cur_new_labels 210 | attention_mask[i, :cur_len] = True 211 | position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) 212 | 213 | new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) 214 | 215 | if _labels is None: 216 | new_labels = None 217 | else: 218 | new_labels = new_labels_padded 219 | 220 | if _attention_mask is None: 221 | attention_mask = None 222 | else: 223 | attention_mask = attention_mask.to(dtype=_attention_mask.dtype) 224 | 225 | if _position_ids is None: 226 | position_ids = None 227 | 228 | return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels -------------------------------------------------------------------------------- /omni_speech/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.omni_speech_llama import OmniSpeechLlamaForCausalLM, OmniSpeechConfig 2 | from .language_model.omni_speech2s_llama import OmniSpeech2SLlamaForCausalLM -------------------------------------------------------------------------------- /omni_speech/model/builder.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/haotian-liu/LLaVA. We modify the code to support speech input. Below is the original copyright: 2 | # Copyright 2023 Haotian Liu 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from omni_speech.model import * 23 | from omni_speech.model.speech_encoder.builder import build_speech_encoder 24 | import os 25 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 26 | 27 | def load_pretrained_model(model_path, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs): 28 | if load_8bit: 29 | kwargs['load_in_8bit'] = True 30 | elif load_4bit: 31 | kwargs['load_in_4bit'] = True 32 | kwargs['quantization_config'] = BitsAndBytesConfig( 33 | load_in_4bit=True, 34 | bnb_4bit_compute_dtype=torch.float16, 35 | bnb_4bit_use_double_quant=True, 36 | bnb_4bit_quant_type='nf4' 37 | ) 38 | else: 39 | kwargs['torch_dtype'] = torch.float16 40 | 41 | if use_flash_attn: 42 | kwargs['attn_implementation'] = 'flash_attention_2' 43 | 44 | model_cls = OmniSpeech2SLlamaForCausalLM if s2s else OmniSpeechLlamaForCausalLM 45 | 46 | # Load OmniSpeech model 47 | if is_lora: 48 | assert model_base is not None, "model_base is required for LoRA models." 49 | from omni_speech.model.language_model.omni_speech_llama import OmniSpeechConfig 50 | lora_cfg_pretrained = OmniSpeechConfig.from_pretrained(model_path) 51 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 52 | print('Loading OmniSpeech from base model...') 53 | model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs) 54 | print('Loading additional OmniSpeech weights...') 55 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 56 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 57 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 58 | if any(k.startswith('model.model.') for k in non_lora_trainables): 59 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 60 | model.load_state_dict(non_lora_trainables, strict=False) 61 | 62 | from peft import PeftModel 63 | print('Loading LoRA weights...') 64 | model = PeftModel.from_pretrained(model, model_path) 65 | print('Merging LoRA weights...') 66 | model = model.merge_and_unload() 67 | print('Model is loaded...') 68 | elif model_base is not None: 69 | print('Loading OmniSpeech from base model...') 70 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 71 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 72 | model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs) 73 | 74 | speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu') 75 | speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()} 76 | model.load_state_dict(speech_projector_weights, strict=False) 77 | model = model.to(device=device) 78 | else: 79 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 80 | model = model_cls.from_pretrained( 81 | model_path, 82 | low_cpu_mem_usage=False, 83 | **kwargs 84 | ) 85 | 86 | model.get_model().speech_encoder = build_speech_encoder(model.config) 87 | model.get_model().speech_encoder.to(dtype=torch.float16) 88 | model = model.to(device=device) 89 | 90 | if hasattr(model.config, "max_sequence_length"): 91 | context_len = model.config.max_sequence_length 92 | else: 93 | context_len = 2048 94 | 95 | return tokenizer, model, context_len 96 | 97 | def create_model(model_path, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs): 98 | if load_8bit: 99 | kwargs['load_in_8bit'] = True 100 | elif load_4bit: 101 | kwargs['load_in_4bit'] = True 102 | kwargs['quantization_config'] = BitsAndBytesConfig( 103 | load_in_4bit=True, 104 | bnb_4bit_compute_dtype=torch.float16, 105 | bnb_4bit_use_double_quant=True, 106 | bnb_4bit_quant_type='nf4' 107 | ) 108 | else: 109 | kwargs['torch_dtype'] = torch.bfloat16 110 | 111 | if use_flash_attn: 112 | kwargs['attn_implementation'] = 'flash_attention_2' 113 | model_cls = OmniSpeech2SLlamaForCausalLM if s2s else OmniSpeechLlamaForCausalLM 114 | if is_lora: 115 | assert model_base is not None, "model_base is required for LoRA models." 116 | from omni_speech.model.language_model.omni_speech_llama import OmniSpeechConfig 117 | lora_cfg_pretrained = OmniSpeechConfig.from_pretrained(model_path) 118 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 119 | print('Loading OmniSpeech from base model...') 120 | model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs) 121 | print('Loading additional OmniSpeech weights...') 122 | # if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 123 | # non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 124 | # non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 125 | # if any(k.startswith('model.model.') for k in non_lora_trainables): 126 | # non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 127 | #model.load_state_dict(non_lora_trainables, strict=False) 128 | 129 | from peft import PeftModel 130 | print('Loading LoRA weights...') 131 | model = PeftModel.from_pretrained(model, model_path) 132 | print('Merging LoRA weights...') 133 | model = model.merge_and_unload() 134 | print('Model is loaded...') 135 | 136 | else: 137 | tokenizer = AutoTokenizer.from_pretrained(model_path,padding_side="right",padding=True, use_fast=False) 138 | model = model_cls.from_pretrained( 139 | model_path, 140 | low_cpu_mem_usage=False, 141 | **kwargs, 142 | 143 | ) 144 | 145 | model.initialize_speech_generator(model.config) 146 | model = model.to(device=device,dtype=torch.bfloat16) 147 | model.get_model().speech_encoder = build_speech_encoder(model.config) 148 | model.get_model().speech_encoder.to(device=device, dtype=torch.bfloat16) 149 | #冻住speech_encoder的参数 150 | for param in model.get_model().speech_encoder.parameters(): 151 | param.requires_grad = False 152 | 153 | if hasattr(model.config, "max_sequence_length"): 154 | context_len = model.config.max_sequence_length 155 | else: 156 | context_len = 2048 157 | 158 | return tokenizer, model, context_len 159 | -------------------------------------------------------------------------------- /omni_speech/model/language_model/.ipynb_checkpoints/omni_speech_llama-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/haotian-liu/LLaVA. We modify the code to support speech input. Below is the original copyright: 2 | # Copyright 2023 Haotian Liu 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..omni_speech_arch import OmniSpeechMetaModel, OmniSpeechMetaForCausalLM 28 | 29 | 30 | class OmniSpeechConfig(LlamaConfig): 31 | model_type = "omni_speech_llama" 32 | 33 | 34 | class OmniSpeechLlamaModel(OmniSpeechMetaModel, LlamaModel): 35 | config_class = OmniSpeechConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(OmniSpeechLlamaModel, self).__init__(config) 39 | 40 | 41 | class OmniSpeechLlamaForCausalLM(LlamaForCausalLM, OmniSpeechMetaForCausalLM): 42 | config_class = OmniSpeechConfig 43 | 44 | def __init__(self, config): 45 | 46 | super(LlamaForCausalLM, self).__init__(config) 47 | self.model = OmniSpeechLlamaModel(config) 48 | self.pretraining_tp = config.pretraining_tp 49 | self.vocab_size = config.vocab_size 50 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 51 | 52 | # Initialize weights and apply final processing 53 | self.post_init() 54 | 55 | def get_model(self): 56 | return self.model 57 | 58 | def forward( 59 | self, 60 | input_ids: torch.LongTensor = None, 61 | attention_mask: Optional[torch.Tensor] = None, 62 | position_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[List[torch.FloatTensor]] = None, 64 | inputs_embeds: Optional[torch.FloatTensor] = None, 65 | labels: Optional[torch.LongTensor] = None, 66 | use_cache: Optional[bool] = None, 67 | output_attentions: Optional[bool] = None, 68 | output_hidden_states: Optional[bool] = None, 69 | speech: Optional[torch.FloatTensor] = None, 70 | speech_lengths: Optional[torch.LongTensor] = None, 71 | return_dict: Optional[bool] = None, 72 | cache_position: Optional[torch.LongTensor] = None, 73 | ) -> Union[Tuple, CausalLMOutputWithPast]: 74 | if inputs_embeds is None: #inputs_embeds none 75 | ( 76 | input_ids, 77 | position_ids, 78 | attention_mask, 79 | past_key_values, 80 | inputs_embeds, 81 | labels 82 | ) = self.prepare_inputs_labels_for_speech_and_text( 83 | input_ids, 84 | position_ids, 85 | attention_mask, 86 | past_key_values, 87 | labels, 88 | speech, 89 | speech_lengths 90 | ) 91 | 92 | return super().forward( 93 | input_ids=input_ids, #none 94 | attention_mask=attention_mask, #none 95 | position_ids=position_ids, #none 96 | past_key_values=past_key_values, #none 97 | inputs_embeds=inputs_embeds, #tesnor[1,361,2048] 98 | labels=labels, #none 99 | use_cache=use_cache, #True 100 | output_attentions=output_attentions, #none 101 | output_hidden_states=output_hidden_states, #none 102 | return_dict=return_dict #none 103 | ) 104 | 105 | @torch.no_grad() 106 | def generate( 107 | self, 108 | inputs: Optional[torch.Tensor] = None, 109 | speech: Optional[torch.Tensor] = None, 110 | speech_lengths: Optional[torch.Tensor] = None, 111 | **kwargs, 112 | ) -> Union[GenerateOutput, torch.LongTensor]: 113 | position_ids = kwargs.pop("position_ids", None) 114 | attention_mask = kwargs.pop("attention_mask", None) 115 | if "inputs_embeds" in kwargs: 116 | raise NotImplementedError("`inputs_embeds` is not supported") 117 | 118 | if speech is not None: 119 | ( 120 | inputs, 121 | position_ids, 122 | attention_mask, 123 | _, 124 | inputs_embeds, 125 | _ 126 | ) = self.prepare_inputs_labels_for_speech_and_text( 127 | inputs, 128 | position_ids, 129 | attention_mask, 130 | None, 131 | None, 132 | speech, 133 | speech_lengths 134 | ) 135 | else: 136 | inputs_embeds = self.get_model().embed_tokens(inputs) 137 | 138 | return super().generate( 139 | position_ids=position_ids, 140 | attention_mask=attention_mask, 141 | inputs_embeds=inputs_embeds, 142 | **kwargs 143 | ) 144 | 145 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 146 | inputs_embeds=None, **kwargs): 147 | speech = kwargs.pop("speech", None) 148 | speech_lengths = kwargs.pop("speech_lengths", None) 149 | inputs = super().prepare_inputs_for_generation( 150 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 151 | ) 152 | if speech is not None: 153 | inputs['speech'] = speech 154 | inputs['speech_lengths'] = speech_lengths 155 | return inputs 156 | # 157 | AutoConfig.register("omni_speech_llama", OmniSpeechConfig) 158 | AutoModelForCausalLM.register(OmniSpeechConfig, OmniSpeechLlamaForCausalLM) 159 | -------------------------------------------------------------------------------- /omni_speech/model/language_model/omni_speech2s_llama.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from transformers import AutoConfig, AutoModelForCausalLM, \ 7 | LlamaConfig 8 | 9 | from transformers.modeling_outputs import CausalLMOutputWithPast 10 | from transformers.generation.utils import GenerateOutput 11 | 12 | from omni_speech.model.language_model.omni_speech_llama import OmniSpeechLlamaForCausalLM 13 | from omni_speech.model.speech_generator.builder import build_speech_generator 14 | from omni_speech.model.speech_generator.generation import GenerationWithCTC 15 | 16 | 17 | class OmniSpeech2SConfig(LlamaConfig): 18 | model_type = "omni_speech2s_llama" 19 | 20 | 21 | class OmniSpeech2SLlamaForCausalLM(OmniSpeechLlamaForCausalLM, GenerationWithCTC): 22 | config_class = OmniSpeech2SConfig 23 | 24 | def __init__(self, config): 25 | super().__init__(config) 26 | 27 | # Initialize weights and apply final processing 28 | self.post_init() 29 | if hasattr(config, "speech_generator_type"): 30 | self.speech_generator = build_speech_generator(config) 31 | 32 | def initialize_speech_generator(self, model_args): 33 | self.config.speech_generator_type = getattr(model_args, 'speech_generator_type', 'ctc') 34 | self.config.ctc_decoder_config = getattr(model_args, 'ctc_decoder_config', '(4,4096,32,11008)') 35 | self.config.ctc_upsample_factor = getattr(model_args, 'ctc_upsample_factor', 1) 36 | self.config.ctc_loss_weight = getattr(model_args, 'ctc_loss_weight', 1.0) 37 | self.config.unit_vocab_size = getattr(model_args, 'unit_vocab_size', 1000) 38 | self.tune_speech_generator_only = getattr(model_args, 'tune_speech_generator_only', True) 39 | if getattr(self, "speech_generator", None) is None: 40 | self.speech_generator = build_speech_generator(self.config) 41 | 42 | def forward( 43 | self, 44 | input_ids: torch.LongTensor = None, 45 | attention_mask: Optional[torch.Tensor] = None, 46 | position_ids: Optional[torch.LongTensor] = None, 47 | past_key_values: Optional[List[torch.FloatTensor]] = None, 48 | inputs_embeds: Optional[torch.FloatTensor] = None, 49 | labels: Optional[torch.LongTensor] = None, 50 | use_cache: Optional[bool] = None, 51 | output_attentions: Optional[bool] = None, 52 | output_hidden_states: Optional[bool] = None, 53 | speech: Optional[torch.FloatTensor] = None, 54 | speech_lengths: Optional[torch.LongTensor] = None, 55 | tgt_units: Optional[torch.LongTensor] = None, 56 | return_dict: Optional[bool] = None, 57 | cache_position: Optional[torch.LongTensor] = None, 58 | ) -> Union[Tuple, CausalLMOutputWithPast]: 59 | 60 | if inputs_embeds is None: 61 | ( 62 | input_ids, 63 | position_ids, 64 | attention_mask, 65 | past_key_values, 66 | inputs_embeds, 67 | labels 68 | ) = self.prepare_inputs_labels_for_speech_and_text( 69 | input_ids, 70 | position_ids, 71 | attention_mask, 72 | past_key_values, 73 | labels, 74 | speech, 75 | speech_lengths 76 | ) 77 | 78 | if self.training: 79 | if self.tune_speech_generator_only: 80 | with torch.no_grad(): 81 | llama_output = super(OmniSpeechLlamaForCausalLM, self).forward( 82 | input_ids=input_ids, 83 | attention_mask=attention_mask, 84 | position_ids=position_ids, 85 | past_key_values=past_key_values, 86 | inputs_embeds=inputs_embeds, 87 | labels=labels, 88 | use_cache=use_cache, 89 | output_attentions=output_attentions, 90 | output_hidden_states=True, 91 | return_dict=return_dict 92 | ) 93 | loss = self.speech_generator(llama_output['hidden_states'][-1], labels, tgt_units) 94 | else: 95 | llama_output = super(OmniSpeechLlamaForCausalLM, self).forward( 96 | input_ids=input_ids, 97 | attention_mask=attention_mask, 98 | position_ids=position_ids, 99 | past_key_values=past_key_values, 100 | inputs_embeds=inputs_embeds, 101 | labels=labels, 102 | use_cache=use_cache, 103 | output_attentions=output_attentions, 104 | output_hidden_states=True, 105 | return_dict=return_dict 106 | ) 107 | lm_loss = llama_output.loss 108 | ctc_loss = self.speech_generator(llama_output['hidden_states'][-1], labels, tgt_units) 109 | loss = lm_loss + ctc_loss * self.config.ctc_loss_weight 110 | else: 111 | llama_output = super(OmniSpeechLlamaForCausalLM, self).forward( 112 | input_ids=input_ids, 113 | attention_mask=attention_mask, 114 | position_ids=position_ids, 115 | past_key_values=past_key_values, 116 | inputs_embeds=inputs_embeds, 117 | labels=labels, 118 | use_cache=use_cache, 119 | output_attentions=output_attentions, 120 | output_hidden_states=True, 121 | return_dict=return_dict 122 | ) 123 | loss = llama_output.loss 124 | 125 | return CausalLMOutputWithPast( 126 | loss=loss, 127 | logits=llama_output.logits, 128 | past_key_values=llama_output.past_key_values, 129 | hidden_states=llama_output.hidden_states, 130 | attentions=llama_output.attentions 131 | ) 132 | 133 | @torch.no_grad() 134 | def generate( 135 | self, 136 | inputs: Optional[torch.Tensor] = None, 137 | speech: Optional[torch.Tensor] = None, 138 | speech_lengths: Optional[torch.Tensor] = None, 139 | streaming_unit_gen=False, 140 | **kwargs, 141 | ) -> Union[GenerateOutput, torch.LongTensor]: 142 | position_ids = kwargs.pop("position_ids", None) 143 | attention_mask = kwargs.pop("attention_mask", None) 144 | if "inputs_embeds" in kwargs: 145 | raise NotImplementedError("`inputs_embeds` is not supported") 146 | 147 | if speech is not None: 148 | ( 149 | inputs, 150 | position_ids, 151 | attention_mask, 152 | _, 153 | inputs_embeds, 154 | _ 155 | ) = self.prepare_inputs_labels_for_speech_and_text( 156 | inputs, 157 | position_ids, 158 | attention_mask, 159 | None, 160 | None, 161 | speech, 162 | speech_lengths 163 | ) 164 | else: 165 | inputs_embeds = self.get_model().embed_tokens(inputs) 166 | outputs = GenerationWithCTC.generate( 167 | self, 168 | position_ids=position_ids, 169 | attention_mask=attention_mask, 170 | inputs_embeds=inputs_embeds, 171 | output_hidden_states=True, 172 | return_dict_in_generate=True, 173 | streaming_unit_gen=streaming_unit_gen, 174 | **kwargs 175 | ) 176 | #odict_keys(['sequences', 'hidden_states', 'past_key_values']) 177 | # sequences:torch.Size([1, 52]) hidden_states:torch.Size([1, 52, 4096]) past_key_values:torch.Size([1, 52, 4096]) 178 | hidden_states = outputs['hidden_states'] 179 | hidden_states = torch.cat([hidden_states[0][-1][:, -1:, :]] + [hidden_states[i][-1] for i in range(1, len(hidden_states))], dim=1) 180 | #torch.Size([1, 47, 4096]) 181 | ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0)) #torch.Size([1, 1300]) 182 | 183 | return outputs.sequences, ctc_pred 184 | 185 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 186 | inputs_embeds=None, **kwargs): 187 | speech = kwargs.pop("speech", None) 188 | speech_lengths = kwargs.pop("speech_lengths", None) 189 | inputs = super().prepare_inputs_for_generation( 190 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 191 | ) 192 | if speech is not None: 193 | inputs['speech'] = speech 194 | inputs['speech_lengths'] = speech_lengths 195 | return inputs 196 | 197 | AutoConfig.register("omni_speech2s_llama", OmniSpeech2SConfig) 198 | AutoModelForCausalLM.register(OmniSpeech2SConfig, OmniSpeech2SLlamaForCausalLM) 199 | -------------------------------------------------------------------------------- /omni_speech/model/language_model/omni_speech_llama.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/haotian-liu/LLaVA. We modify the code to support speech input. Below is the original copyright: 2 | # Copyright 2023 Haotian Liu 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..omni_speech_arch import OmniSpeechMetaModel, OmniSpeechMetaForCausalLM 28 | 29 | 30 | class OmniSpeechConfig(LlamaConfig): 31 | model_type = "omni_speech_llama" 32 | 33 | 34 | class OmniSpeechLlamaModel(OmniSpeechMetaModel, LlamaModel): 35 | config_class = OmniSpeechConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(OmniSpeechLlamaModel, self).__init__(config) 39 | 40 | 41 | class OmniSpeechLlamaForCausalLM(LlamaForCausalLM, OmniSpeechMetaForCausalLM): 42 | config_class = OmniSpeechConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = OmniSpeechLlamaModel(config) #llm和speech_ecoder 和projector 47 | self.pretraining_tp = config.pretraining_tp 48 | self.vocab_size = config.vocab_size 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | speech: Optional[torch.FloatTensor] = None, 69 | speech_lengths: Optional[torch.LongTensor] = None, 70 | return_dict: Optional[bool] = None, 71 | cache_position: Optional[torch.LongTensor] = None, 72 | ) -> Union[Tuple, CausalLMOutputWithPast]: 73 | if inputs_embeds is None: #inputs_embeds none 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_speech_and_text( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | speech, 88 | speech_lengths 89 | ) 90 | result=super().forward( 91 | input_ids=input_ids, #none 92 | attention_mask=attention_mask, #none 93 | position_ids=position_ids, #none 94 | past_key_values=past_key_values, #none 95 | inputs_embeds=inputs_embeds, #tesnor[1,361,2048] 96 | labels=labels, #none 97 | use_cache=use_cache, #True 98 | output_attentions=output_attentions, #none 99 | output_hidden_states=output_hidden_states, #none 100 | return_dict=return_dict #none 101 | ) 102 | return result 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | speech: Optional[torch.Tensor] = None, 109 | speech_lengths: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if speech is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_speech_and_text( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | speech, 132 | speech_lengths 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | speech = kwargs.pop("speech", None) 147 | speech_lengths = kwargs.pop("speech_lengths", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if speech is not None: 152 | inputs['speech'] = speech 153 | inputs['speech_lengths'] = speech_lengths 154 | return inputs 155 | # 156 | AutoConfig.register("omni_speech_llama", OmniSpeechConfig) 157 | AutoModelForCausalLM.register(OmniSpeechConfig, OmniSpeechLlamaForCausalLM) 158 | -------------------------------------------------------------------------------- /omni_speech/model/omni_speech_arch.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/haotian-liu/LLaVA. We modify the code to support speech input. Below is the original copyright: 2 | # Copyright 2023 Haotian Liu 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from abc import ABC, abstractmethod 17 | 18 | import torch 19 | 20 | from .speech_encoder.builder import build_speech_encoder 21 | from .speech_projector.builder import build_speech_projector 22 | from omni_speech.constants import IGNORE_INDEX, SPEECH_TOKEN_INDEX 23 | from omni_speech.utils import lengths_to_padding_mask 24 | 25 | 26 | class OmniSpeechMetaModel: 27 | 28 | def __init__(self, config): 29 | super(OmniSpeechMetaModel, self).__init__(config) 30 | 31 | if hasattr(config, "speech_encoder"): 32 | self.speech_encoder = build_speech_encoder(config) 33 | self.speech_projector = build_speech_projector(config) 34 | 35 | def get_speech_encoder(self): 36 | speech_encoder = getattr(self, 'speech_encoder', None) 37 | if type(speech_encoder) is list: 38 | speech_encoder = speech_encoder[0] 39 | return speech_encoder 40 | 41 | def initialize_speech_modules(self, model_args, fsdp=None): 42 | self.config.speech_encoder = getattr(model_args, "speech_encoder", None) 43 | self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None) 44 | self.config.speech_projector_type = getattr(model_args, 'speech_projector_type', 'linear') 45 | self.config.speech_encoder_ds_rate = getattr(model_args, 'speech_encoder_ds_rate', 5) 46 | self.config.speech_encoder_hidden_size = getattr(model_args, 'speech_encoder_hidden_size', 1280) 47 | if self.get_speech_encoder() is None: 48 | speech_encoder = build_speech_encoder(self.config) 49 | if fsdp is not None and len(fsdp) > 0: 50 | self.speech_encoder = [speech_encoder] 51 | else: 52 | self.speech_encoder = speech_encoder 53 | 54 | if getattr(self, 'speech_projector', None) is None: 55 | self.speech_projector = build_speech_projector(self.config) 56 | else: 57 | # In case it is frozen by LoRA 58 | for p in self.speech_projector.parameters(): 59 | p.requires_grad = True 60 | 61 | if model_args.pretrain_speech_projector is not None: 62 | pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location='cpu') 63 | def get_w(weights, keyword): 64 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} 65 | 66 | self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, 'speech_projector')) 67 | 68 | 69 | class OmniSpeechMetaForCausalLM(ABC): 70 | 71 | @abstractmethod 72 | def get_model(self): 73 | pass 74 | 75 | def get_speech_encoder(self): 76 | return self.get_model().get_speech_encoder() 77 | 78 | def get_speech_projector(self): 79 | return self.get_model().speech_projector 80 | 81 | def encode_speech(self, speech, speech_lengths): 82 | speech_encoder_type = self.config.speech_encoder_type 83 | speech_encoder = self.get_speech_encoder() 84 | if "whisper" in speech_encoder_type.lower(): 85 | # speech=speech.half() 86 | # speech_encoder.half() 87 | encoder_outs = speech_encoder(speech.permute(0, 2, 1)) 88 | speech_lengths = (speech_lengths + 1) // 2 89 | else: 90 | raise ValueError(f'Unknown speech encoder: {speech_encoder}') 91 | speech_projector_type = self.config.speech_projector_type 92 | speech_projector = self.get_speech_projector() 93 | if speech_projector_type == "linear": 94 | encoder_outs = speech_projector(encoder_outs) 95 | speech_lengths = speech_lengths // speech_projector.k 96 | else: 97 | raise ValueError(f'Unknown speech projector: {speech_projector_type}') 98 | speech_features = [encoder_outs[i, :speech_lengths[i]] for i in range(len(encoder_outs))] 99 | return speech_features 100 | 101 | def prepare_inputs_labels_for_speech_and_text( 102 | self, input_ids, position_ids, attention_mask, past_key_values, labels, 103 | speech, speech_lengths 104 | ): 105 | speech_encoder = self.get_speech_encoder() 106 | if speech_encoder is None or speech is None or input_ids.shape[1] == 1: 107 | return input_ids, position_ids, attention_mask, past_key_values, None, labels 108 | 109 | speech_features = self.encode_speech(speech, speech_lengths) 110 | #print("speech_features:", speech_features[0]) 111 | # Let's just add dummy tensors if they do not exist, 112 | # it is a headache to deal with None all the time. 113 | # But it is not ideal, and if you have a better idea, 114 | # please open an issue / submit a PR, thanks. 115 | _labels = labels 116 | _position_ids = position_ids 117 | _attention_mask = attention_mask 118 | if attention_mask is None: 119 | attention_mask = torch.ones_like(input_ids, dtype=torch.bool) 120 | else: 121 | attention_mask = attention_mask.bool() 122 | if position_ids is None: 123 | position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) 124 | if labels is None: 125 | labels = torch.full_like(input_ids, IGNORE_INDEX) 126 | 127 | # remove the padding using attention_mask -- FIXME 128 | _input_ids = input_ids 129 | input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] 130 | labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] 131 | 132 | new_input_embeds = [] 133 | new_labels = [] 134 | cur_speech_idx = 0 135 | for batch_idx, cur_input_ids in enumerate(input_ids): 136 | num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum() #-200是语音的标记,出现几次有几个语音 137 | if num_speech == 0: 138 | cur_speech_features = speech_features[cur_speech_idx] 139 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) 140 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0]], dim=0) 141 | new_input_embeds.append(cur_input_embeds) 142 | new_labels.append(labels[batch_idx]) 143 | cur_speech_idx += 1 144 | continue 145 | 146 | speech_token_indices = [-1] + torch.where(cur_input_ids == SPEECH_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] 147 | #[-1, 45, 62] 45是token -200的位置 148 | cur_input_ids_nospeech = [] 149 | cur_labels = labels[batch_idx] #都是-100,长度是62 150 | cur_labels_nospeech = [] 151 | for i in range(len(speech_token_indices) - 1): 152 | cur_input_ids_nospeech.append(cur_input_ids[speech_token_indices[i]+1:speech_token_indices[i+1]]) #cur_input_ids[0:45]、cur_input_ids[46:62] 153 | cur_labels_nospeech.append(cur_labels[speech_token_indices[i]+1:speech_token_indices[i+1]]) #len 45的-100列表、[46:62] -100 154 | split_sizes = [x.shape[0] for x in cur_labels_nospeech] 155 | cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech)) #Embedding(128256, 4096) torch.Size([61, 4096]) 156 | cur_input_embeds_no_speech = torch.split(cur_input_embeds, split_sizes, dim=0) 157 | cur_new_input_embeds = [] 158 | cur_new_labels = [] 159 | 160 | for i in range(num_speech + 1): #上面两段文本,中间加了一次音频特征 161 | cur_new_input_embeds.append(cur_input_embeds_no_speech[i]) 162 | cur_new_labels.append(cur_labels_nospeech[i]) 163 | if i < num_speech: 164 | cur_speech_features = speech_features[cur_speech_idx] #torch.Size([300, 4096]) 165 | cur_speech_idx += 1 166 | cur_new_input_embeds.append(cur_speech_features) 167 | cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) 168 | #填充和cur_speech_features等长的-100列表 169 | cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] 170 | 171 | cur_new_input_embeds = torch.cat(cur_new_input_embeds) #torch.Size([361, 4096]) 172 | cur_new_labels = torch.cat(cur_new_labels) #torch.Size([361]) 173 | 174 | new_input_embeds.append(cur_new_input_embeds) 175 | new_labels.append(cur_new_labels) 176 | # Truncate sequences to max length as speech features can make the sequence longer 177 | tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) 178 | if tokenizer_model_max_length is not None: 179 | new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] 180 | new_labels = [x[:tokenizer_model_max_length] for x in new_labels] 181 | 182 | # Combine them 183 | max_len = max(x.shape[0] for x in new_input_embeds) 184 | batch_size = len(new_input_embeds) 185 | 186 | new_input_embeds_padded = [] 187 | new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) 188 | attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) 189 | position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) 190 | for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): 191 | cur_len = cur_new_embed.shape[0] 192 | if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": 193 | new_input_embeds_padded.append(torch.cat(( 194 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), 195 | cur_new_embed 196 | ), dim=0)) 197 | if cur_len > 0: 198 | new_labels_padded[i, -cur_len:] = cur_new_labels 199 | attention_mask[i, -cur_len:] = True 200 | position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) 201 | else: 202 | new_input_embeds_padded.append(torch.cat(( 203 | cur_new_embed, 204 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) 205 | ), dim=0)) 206 | if cur_len > 0: 207 | new_labels_padded[i, :cur_len] = cur_new_labels 208 | attention_mask[i, :cur_len] = True 209 | position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) 210 | 211 | new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) 212 | 213 | if _labels is None: 214 | new_labels = None 215 | else: 216 | new_labels = new_labels_padded 217 | 218 | if _attention_mask is None: 219 | attention_mask = None 220 | else: 221 | attention_mask = attention_mask.to(dtype=_attention_mask.dtype) 222 | 223 | if _position_ids is None: 224 | position_ids = None 225 | #return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels 226 | # print(", ".join(map(str, new_labels))) 227 | return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels -------------------------------------------------------------------------------- /omni_speech/model/speech_encoder/builder.py: -------------------------------------------------------------------------------- 1 | from .speech_encoder import WhisperWrappedEncoder 2 | 3 | 4 | def build_speech_encoder(config): 5 | speech_encoder_type = getattr(config, 'speech_encoder_type', None) 6 | if "whisper" in speech_encoder_type.lower(): 7 | return WhisperWrappedEncoder.load(config) 8 | 9 | raise ValueError(f'Unknown speech encoder: {speech_encoder_type}') 10 | -------------------------------------------------------------------------------- /omni_speech/model/speech_encoder/speech_encoder.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/encoder.py 2 | 3 | import types 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class WhisperWrappedEncoder: 10 | 11 | @classmethod 12 | def load(cls, model_config): 13 | 14 | def replace_layer_norm(module): 15 | from whisper.model import LayerNorm 16 | for name, child in module.named_children(): 17 | if isinstance(child, LayerNorm): 18 | old_params = child.state_dict() 19 | new_layer_norm = nn.LayerNorm(child.normalized_shape, eps=child.eps, elementwise_affine=child.elementwise_affine) 20 | new_layer_norm.load_state_dict(old_params) 21 | setattr(module, name, new_layer_norm) 22 | else: 23 | replace_layer_norm(child) 24 | 25 | import whisper 26 | encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder 27 | replace_layer_norm(encoder) 28 | 29 | return encoder -------------------------------------------------------------------------------- /omni_speech/model/speech_generator/builder.py: -------------------------------------------------------------------------------- 1 | from .speech_generator import SpeechGeneratorCTC 2 | 3 | 4 | def build_speech_generator(config): 5 | generator_type = getattr(config, 'speech_generator_type', 'ctc') 6 | if generator_type == 'ctc': 7 | return SpeechGeneratorCTC(config) 8 | 9 | raise ValueError(f'Unknown generator type: {generator_type}') 10 | -------------------------------------------------------------------------------- /omni_speech/model/speech_generator/speech_generator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 7 | from omni_speech.constants import IGNORE_INDEX 8 | 9 | 10 | def lengths_to_padding_mask(lens): 11 | bsz, max_lens = lens.size(0), torch.max(lens).item() 12 | mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) 13 | mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) 14 | return mask 15 | 16 | 17 | def _uniform_assignment(src_lens, tgt_lens): 18 | tgt_indices = torch.arange(torch.max(tgt_lens)).expand(len(tgt_lens), -1).to(tgt_lens.device) 19 | ratio = tgt_lens / src_lens 20 | index_t = (tgt_indices / ratio.view(-1, 1)).long() 21 | return index_t 22 | 23 | 24 | class SpeechGeneratorCTC(nn.Module): 25 | def __init__(self, config): 26 | super().__init__() 27 | n_layers, n_dims, n_heads, n_inter_dims = list(map(int, config.ctc_decoder_config[1:-1].split(","))) 28 | _config = copy.deepcopy(config) 29 | _config.hidden_size = n_dims 30 | _config.num_hidden_layers = n_layers 31 | _config.num_attention_heads = n_heads 32 | _config.num_key_value_heads = n_heads 33 | _config.intermediate_size = n_inter_dims 34 | _config._attn_implementation = "flash_attention_2" 35 | self.upsample_factor = config.ctc_upsample_factor 36 | self.input_proj = nn.Linear(config.hidden_size, n_dims) 37 | self.layers = nn.ModuleList( 38 | [LlamaDecoderLayer(_config, layer_idx) for layer_idx in range(n_layers)] 39 | ) 40 | self.unit_vocab_size = config.unit_vocab_size 41 | self.output_proj = nn.Linear(n_dims, config.unit_vocab_size + 1) 42 | 43 | def upsample(self, reps, tgt_units=None): 44 | src_lens = torch.LongTensor([len(rep) for rep in reps]).to(reps[0].device) 45 | up_lens = src_lens * self.upsample_factor 46 | if tgt_units is not None: 47 | tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1) 48 | up_lens = torch.max(up_lens, tgt_lens) 49 | reps = torch.nn.utils.rnn.pad_sequence(reps, batch_first=True) 50 | padding_mask = lengths_to_padding_mask(up_lens) 51 | mapped_inputs = _uniform_assignment(src_lens, up_lens).masked_fill( 52 | padding_mask, 0 53 | ) 54 | copied_reps = torch.gather( 55 | reps, 56 | 1, 57 | mapped_inputs.unsqueeze(-1).expand( 58 | *mapped_inputs.size(), reps.size(-1) 59 | ), 60 | ) 61 | copied_reps = copied_reps.masked_fill(padding_mask.unsqueeze(-1), 0) 62 | position_ids = torch.arange(0, max(up_lens)).unsqueeze(0).expand(len(reps), -1).to(device=copied_reps.device) 63 | return copied_reps, ~padding_mask, position_ids 64 | 65 | def forward(self, tgt_reps, labels, tgt_units): 66 | tgt_label_reps = [] 67 | for tgt_rep, label in zip(tgt_reps, labels): 68 | tgt_label_reps.append(tgt_rep[label != IGNORE_INDEX]) 69 | hidden_states, attention_mask, position_ids = self.upsample(tgt_label_reps, tgt_units) 70 | hidden_states = self.input_proj(hidden_states) 71 | for layer in self.layers: 72 | layer_outputs = layer( 73 | hidden_states, 74 | attention_mask=attention_mask, 75 | position_ids=position_ids, 76 | ) 77 | hidden_states = layer_outputs[0] 78 | ctc_logits = self.output_proj(hidden_states) 79 | ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32) 80 | ctc_lens = attention_mask.long().sum(dim=-1) 81 | ctc_tgt_lens = tgt_units.ne(IGNORE_INDEX).long().sum(dim=-1) 82 | ctc_tgt_mask = ~lengths_to_padding_mask(ctc_tgt_lens) 83 | 84 | padding_vector = torch.full((tgt_units.shape[0], tgt_units.shape[1]-ctc_tgt_mask.shape[1]), False, dtype=torch.bool).to(ctc_tgt_mask.device) 85 | ctc_tgt_mask = torch.cat((ctc_tgt_mask, padding_vector), dim=1) 86 | 87 | ctc_tgt_flat = tgt_units.masked_select(ctc_tgt_mask) 88 | ctc_loss = F.ctc_loss( 89 | ctc_lprobs.transpose(0, 1), 90 | ctc_tgt_flat, 91 | ctc_lens, 92 | ctc_tgt_lens, 93 | reduction="sum", 94 | zero_infinity=True, 95 | blank=self.unit_vocab_size 96 | ) 97 | ctc_loss /= ctc_tgt_lens.sum().item() 98 | return ctc_loss 99 | 100 | def predict(self, tgt_reps): #torch.Size([47, 4096]) 101 | hidden_states, attention_mask, position_ids = self.upsample([tgt_reps]) 102 | #torch.Size([1, 1175, 4096]) #torch.Size([1, 1175]) 103 | hidden_states = self.input_proj(hidden_states) 104 | for layer in self.layers: 105 | layer_outputs = layer( 106 | hidden_states, 107 | attention_mask=attention_mask, 108 | position_ids=position_ids, 109 | ) 110 | hidden_states = layer_outputs[0] 111 | #torch.Size([1, 1175, 4096]) 112 | ctc_logits = self.output_proj(hidden_states) #torch.Size([1, 1175, 1001]) 113 | ctc_lprobs = F.log_softmax(ctc_logits.float(), dim=-1, dtype=torch.float32) 114 | ctc_pred = ctc_lprobs.argmax(dim=-1).masked_fill_(~attention_mask, self.unit_vocab_size) 115 | return ctc_pred -------------------------------------------------------------------------------- /omni_speech/model/speech_projector/builder.py: -------------------------------------------------------------------------------- 1 | from .speech_projector import EncoderProjectorConcat 2 | 3 | 4 | def build_speech_projector(config): 5 | projector_type = getattr(config, 'speech_projector_type', 'linear') 6 | if projector_type == 'linear': 7 | return EncoderProjectorConcat(config) 8 | 9 | raise ValueError(f'Unknown projector type: {projector_type}') 10 | -------------------------------------------------------------------------------- /omni_speech/model/speech_projector/speech_projector.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/projector.py 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class EncoderProjectorConcat(nn.Module): 9 | def __init__(self, config): 10 | super().__init__() 11 | self.k = config.speech_encoder_ds_rate 12 | self.encoder_dim = config.speech_encoder_hidden_size 13 | self.llm_dim = config.hidden_size 14 | self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048) 15 | self.relu = nn.ReLU() 16 | self.linear2 = nn.Linear(2048, config.hidden_size) 17 | 18 | def forward(self, x): 19 | batch_size, seq_len, dim = x.size() 20 | num_frames_to_discard = seq_len % self.k #5的倍数 21 | if num_frames_to_discard > 0: 22 | x = x[:, :-num_frames_to_discard, :] 23 | seq_len = x.size(1) 24 | 25 | x = x.contiguous() 26 | x = x.view(batch_size, seq_len // self.k, dim * self.k) 27 | x = self.linear1(x) 28 | x = self.relu(x) 29 | x = self.linear2(x) 30 | return x 31 | -------------------------------------------------------------------------------- /omni_speech/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/omni_speech/serve/__init__.py -------------------------------------------------------------------------------- /omni_speech/serve/controller.py: -------------------------------------------------------------------------------- 1 | """ 2 | A controller manages distributed workers. 3 | It sends worker addresses to clients. 4 | """ 5 | import argparse 6 | import asyncio 7 | import dataclasses 8 | from enum import Enum, auto 9 | import json 10 | import logging 11 | import time 12 | from typing import List, Union 13 | import threading 14 | 15 | from fastapi import FastAPI, Request 16 | from fastapi.responses import StreamingResponse 17 | import numpy as np 18 | import requests 19 | import uvicorn 20 | 21 | from omni_speech.constants import CONTROLLER_HEART_BEAT_EXPIRATION 22 | from omni_speech.utils import build_logger, server_error_msg 23 | 24 | 25 | logger = build_logger("controller", "controller.log") 26 | 27 | 28 | class DispatchMethod(Enum): 29 | LOTTERY = auto() 30 | SHORTEST_QUEUE = auto() 31 | 32 | @classmethod 33 | def from_str(cls, name): 34 | if name == "lottery": 35 | return cls.LOTTERY 36 | elif name == "shortest_queue": 37 | return cls.SHORTEST_QUEUE 38 | else: 39 | raise ValueError(f"Invalid dispatch method") 40 | 41 | 42 | @dataclasses.dataclass 43 | class WorkerInfo: 44 | model_names: List[str] 45 | speed: int 46 | queue_length: int 47 | check_heart_beat: bool 48 | last_heart_beat: str 49 | 50 | 51 | def heart_beat_controller(controller): 52 | while True: 53 | time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) 54 | controller.remove_stable_workers_by_expiration() 55 | 56 | 57 | class Controller: 58 | def __init__(self, dispatch_method: str): 59 | # Dict[str -> WorkerInfo] 60 | self.worker_info = {} 61 | self.dispatch_method = DispatchMethod.from_str(dispatch_method) 62 | 63 | self.heart_beat_thread = threading.Thread( 64 | target=heart_beat_controller, args=(self,), daemon=True) 65 | self.heart_beat_thread.start() 66 | 67 | logger.info("Init controller") 68 | 69 | def register_worker(self, worker_name: str, check_heart_beat: bool, 70 | worker_status: dict): 71 | if worker_name not in self.worker_info: 72 | logger.info(f"Register a new worker: {worker_name}") 73 | else: 74 | logger.info(f"Register an existing worker: {worker_name}") 75 | 76 | if not worker_status: 77 | worker_status = self.get_worker_status(worker_name) 78 | if not worker_status: 79 | return False 80 | 81 | self.worker_info[worker_name] = WorkerInfo( 82 | worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], 83 | check_heart_beat, time.time()) 84 | 85 | logger.info(f"Register done: {worker_name}, {worker_status}") 86 | return True 87 | 88 | def get_worker_status(self, worker_name: str): 89 | try: 90 | r = requests.post(worker_name + "/worker_get_status", timeout=5) 91 | except requests.exceptions.RequestException as e: 92 | logger.error(f"Get status fails: {worker_name}, {e}") 93 | return None 94 | 95 | if r.status_code != 200: 96 | logger.error(f"Get status fails: {worker_name}, {r}") 97 | return None 98 | 99 | return r.json() 100 | 101 | def remove_worker(self, worker_name: str): 102 | del self.worker_info[worker_name] 103 | 104 | def refresh_all_workers(self): 105 | old_info = dict(self.worker_info) 106 | self.worker_info = {} 107 | 108 | for w_name, w_info in old_info.items(): 109 | if not self.register_worker(w_name, w_info.check_heart_beat, None): 110 | logger.info(f"Remove stale worker: {w_name}") 111 | 112 | def list_models(self): 113 | model_names = set() 114 | 115 | for w_name, w_info in self.worker_info.items(): 116 | model_names.update(w_info.model_names) 117 | 118 | return list(model_names) 119 | 120 | def get_worker_address(self, model_name: str): 121 | if self.dispatch_method == DispatchMethod.LOTTERY: 122 | worker_names = [] 123 | worker_speeds = [] 124 | for w_name, w_info in self.worker_info.items(): 125 | if model_name in w_info.model_names: 126 | worker_names.append(w_name) 127 | worker_speeds.append(w_info.speed) 128 | worker_speeds = np.array(worker_speeds, dtype=np.float32) 129 | norm = np.sum(worker_speeds) 130 | if norm < 1e-4: 131 | return "" 132 | worker_speeds = worker_speeds / norm 133 | if True: # Directly return address 134 | pt = np.random.choice(np.arange(len(worker_names)), 135 | p=worker_speeds) 136 | worker_name = worker_names[pt] 137 | return worker_name 138 | 139 | # Check status before returning 140 | while True: 141 | pt = np.random.choice(np.arange(len(worker_names)), 142 | p=worker_speeds) 143 | worker_name = worker_names[pt] 144 | 145 | if self.get_worker_status(worker_name): 146 | break 147 | else: 148 | self.remove_worker(worker_name) 149 | worker_speeds[pt] = 0 150 | norm = np.sum(worker_speeds) 151 | if norm < 1e-4: 152 | return "" 153 | worker_speeds = worker_speeds / norm 154 | continue 155 | return worker_name 156 | elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: 157 | worker_names = [] 158 | worker_qlen = [] 159 | for w_name, w_info in self.worker_info.items(): 160 | if model_name in w_info.model_names: 161 | worker_names.append(w_name) 162 | worker_qlen.append(w_info.queue_length / w_info.speed) 163 | if len(worker_names) == 0: 164 | return "" 165 | min_index = np.argmin(worker_qlen) 166 | w_name = worker_names[min_index] 167 | self.worker_info[w_name].queue_length += 1 168 | logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") 169 | return w_name 170 | else: 171 | raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") 172 | 173 | def receive_heart_beat(self, worker_name: str, queue_length: int): 174 | if worker_name not in self.worker_info: 175 | logger.info(f"Receive unknown heart beat. {worker_name}") 176 | return False 177 | 178 | self.worker_info[worker_name].queue_length = queue_length 179 | self.worker_info[worker_name].last_heart_beat = time.time() 180 | logger.info(f"Receive heart beat. {worker_name}") 181 | return True 182 | 183 | def remove_stable_workers_by_expiration(self): 184 | expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION 185 | to_delete = [] 186 | for worker_name, w_info in self.worker_info.items(): 187 | if w_info.check_heart_beat and w_info.last_heart_beat < expire: 188 | to_delete.append(worker_name) 189 | 190 | for worker_name in to_delete: 191 | self.remove_worker(worker_name) 192 | 193 | def worker_api_generate_stream(self, params): 194 | worker_addr = self.get_worker_address(params["model"]) 195 | if not worker_addr: 196 | logger.info(f"no worker: {params['model']}") 197 | ret = { 198 | "text": server_error_msg, 199 | "error_code": 2, 200 | } 201 | yield json.dumps(ret).encode() + b"\0" 202 | 203 | try: 204 | response = requests.post(worker_addr + "/worker_generate_stream", 205 | json=params, stream=True, timeout=5) 206 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 207 | if chunk: 208 | yield chunk + b"\0" 209 | except requests.exceptions.RequestException as e: 210 | logger.info(f"worker timeout: {worker_addr}") 211 | ret = { 212 | "text": server_error_msg, 213 | "error_code": 3, 214 | } 215 | yield json.dumps(ret).encode() + b"\0" 216 | 217 | 218 | # Let the controller act as a worker to achieve hierarchical 219 | # management. This can be used to connect isolated sub networks. 220 | def worker_api_get_status(self): 221 | model_names = set() 222 | speed = 0 223 | queue_length = 0 224 | 225 | for w_name in self.worker_info: 226 | worker_status = self.get_worker_status(w_name) 227 | if worker_status is not None: 228 | model_names.update(worker_status["model_names"]) 229 | speed += worker_status["speed"] 230 | queue_length += worker_status["queue_length"] 231 | 232 | return { 233 | "model_names": list(model_names), 234 | "speed": speed, 235 | "queue_length": queue_length, 236 | } 237 | 238 | 239 | app = FastAPI() 240 | 241 | 242 | @app.post("/register_worker") 243 | async def register_worker(request: Request): 244 | data = await request.json() 245 | controller.register_worker( 246 | data["worker_name"], data["check_heart_beat"], 247 | data.get("worker_status", None)) 248 | 249 | 250 | @app.post("/refresh_all_workers") 251 | async def refresh_all_workers(): 252 | models = controller.refresh_all_workers() 253 | 254 | 255 | @app.post("/list_models") 256 | async def list_models(): 257 | models = controller.list_models() 258 | return {"models": models} 259 | 260 | 261 | @app.post("/get_worker_address") 262 | async def get_worker_address(request: Request): 263 | data = await request.json() 264 | addr = controller.get_worker_address(data["model"]) 265 | return {"address": addr} 266 | 267 | 268 | @app.post("/receive_heart_beat") 269 | async def receive_heart_beat(request: Request): 270 | data = await request.json() 271 | exist = controller.receive_heart_beat( 272 | data["worker_name"], data["queue_length"]) 273 | return {"exist": exist} 274 | 275 | 276 | @app.post("/worker_generate_stream") 277 | async def worker_api_generate_stream(request: Request): 278 | params = await request.json() 279 | generator = controller.worker_api_generate_stream(params) 280 | return StreamingResponse(generator) 281 | 282 | 283 | @app.post("/worker_get_status") 284 | async def worker_api_get_status(request: Request): 285 | return controller.worker_api_get_status() 286 | 287 | 288 | if __name__ == "__main__": 289 | parser = argparse.ArgumentParser() 290 | parser.add_argument("--host", type=str, default="localhost") 291 | parser.add_argument("--port", type=int, default=21001) 292 | parser.add_argument("--dispatch-method", type=str, choices=[ 293 | "lottery", "shortest_queue"], default="shortest_queue") 294 | args = parser.parse_args() 295 | logger.info(f"args: {args}") 296 | 297 | controller = Controller(args.dispatch_method) 298 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") -------------------------------------------------------------------------------- /omni_speech/serve/examples/helpful_base_1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/omni_speech/serve/examples/helpful_base_1.wav -------------------------------------------------------------------------------- /omni_speech/serve/examples/helpful_base_2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/omni_speech/serve/examples/helpful_base_2.wav -------------------------------------------------------------------------------- /omni_speech/serve/examples/helpful_base_3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/omni_speech/serve/examples/helpful_base_3.wav -------------------------------------------------------------------------------- /omni_speech/serve/examples/helpful_base_4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/omni_speech/serve/examples/helpful_base_4.wav -------------------------------------------------------------------------------- /omni_speech/serve/examples/helpful_base_5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/omni_speech/serve/examples/helpful_base_5.wav -------------------------------------------------------------------------------- /omni_speech/serve/examples/vicuna_1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/omni_speech/serve/examples/vicuna_1.wav -------------------------------------------------------------------------------- /omni_speech/serve/examples/vicuna_2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/omni_speech/serve/examples/vicuna_2.wav -------------------------------------------------------------------------------- /omni_speech/serve/examples/vicuna_3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/omni_speech/serve/examples/vicuna_3.wav -------------------------------------------------------------------------------- /omni_speech/serve/examples/vicuna_4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/omni_speech/serve/examples/vicuna_4.wav -------------------------------------------------------------------------------- /omni_speech/serve/examples/vicuna_5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/omni_speech/serve/examples/vicuna_5.wav -------------------------------------------------------------------------------- /omni_speech/serve/gradio_web_server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import os 5 | import time 6 | import torch 7 | import torchaudio 8 | 9 | import gradio as gr 10 | import numpy as np 11 | import requests 12 | import soundfile as sf 13 | 14 | from omni_speech.conversation import default_conversation, conv_templates 15 | from omni_speech.constants import LOGDIR 16 | from omni_speech.utils import build_logger, server_error_msg 17 | from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder 18 | 19 | 20 | logger = build_logger("gradio_web_server", "gradio_web_server.log") 21 | 22 | vocoder = None 23 | 24 | headers = {"User-Agent": "LLaMA-Omni Client"} 25 | 26 | no_change_btn = gr.Button() 27 | enable_btn = gr.Button(interactive=True) 28 | disable_btn = gr.Button(interactive=False) 29 | 30 | 31 | def get_conv_log_filename(): 32 | t = datetime.datetime.now() 33 | name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") 34 | return name 35 | 36 | 37 | def get_model_list(): 38 | ret = requests.post(args.controller_url + "/refresh_all_workers") 39 | assert ret.status_code == 200 40 | ret = requests.post(args.controller_url + "/list_models") 41 | models = ret.json()["models"] 42 | logger.info(f"Models: {models}") 43 | return models 44 | 45 | 46 | get_window_url_params = """ 47 | function() { 48 | const params = new URLSearchParams(window.location.search); 49 | url_params = Object.fromEntries(params); 50 | console.log(url_params); 51 | return url_params; 52 | } 53 | """ 54 | 55 | 56 | def load_demo(url_params, request: gr.Request): 57 | logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") 58 | 59 | dropdown_update = gr.Dropdown(visible=True) 60 | if "model" in url_params: 61 | model = url_params["model"] 62 | if model in models: 63 | dropdown_update = gr.Dropdown(value=model, visible=True) 64 | 65 | state = default_conversation.copy() 66 | return state, dropdown_update 67 | 68 | 69 | def load_demo_refresh_model_list(request: gr.Request): 70 | logger.info(f"load_demo. ip: {request.client.host}") 71 | models = get_model_list() 72 | state = default_conversation.copy() 73 | dropdown_update = gr.Dropdown( 74 | choices=models, 75 | value=models[0] if len(models) > 0 else "" 76 | ) 77 | return state, dropdown_update 78 | 79 | 80 | def clear_history(request: gr.Request): 81 | logger.info(f"clear_history. ip: {request.client.host}") 82 | state = default_conversation.copy() 83 | return (state, None, "", "", None) 84 | 85 | 86 | def add_speech(state, speech, request: gr.Request): 87 | text = "Please directly answer the questions in the user's speech." 88 | text = '\n' + text 89 | text = (text, speech) 90 | state = default_conversation.copy() 91 | state.append_message(state.roles[0], text) 92 | state.append_message(state.roles[1], None) 93 | state.skip_next = False 94 | return (state) 95 | 96 | 97 | def http_bot(state, model_selector, temperature, top_p, max_new_tokens, chunk_size, request: gr.Request): 98 | logger.info(f"http_bot. ip: {request.client.host}") 99 | start_tstamp = time.time() 100 | model_name = model_selector 101 | 102 | if state.skip_next: 103 | # This generate call is skipped due to invalid inputs 104 | yield (state, "", "", None) 105 | return 106 | 107 | if len(state.messages) == state.offset + 2: 108 | # First round of conversation 109 | template_name = "llama_3" 110 | new_state = conv_templates[template_name].copy() 111 | new_state.append_message(new_state.roles[0], state.messages[-2][1]) 112 | new_state.append_message(new_state.roles[1], None) 113 | state = new_state 114 | 115 | # Query worker address 116 | controller_url = args.controller_url 117 | ret = requests.post(controller_url + "/get_worker_address", 118 | json={"model": model_name}) 119 | worker_addr = ret.json()["address"] 120 | logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") 121 | 122 | # No available worker 123 | if worker_addr == "": 124 | state.messages[-1][-1] = server_error_msg 125 | yield (state, "", "", None) 126 | return 127 | 128 | # Construct prompt 129 | prompt = state.get_prompt() 130 | 131 | sr, audio = state.messages[0][1][1] 132 | resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000) 133 | audio = torch.tensor(audio.astype(np.float32)).unsqueeze(0) 134 | audio = resampler(audio).squeeze(0).numpy() 135 | audio /= 32768.0 136 | audio = audio.tolist() 137 | # Make requests 138 | pload = { 139 | "model": model_name, 140 | "prompt": prompt, 141 | "temperature": float(temperature), 142 | "top_p": float(top_p), 143 | "max_new_tokens": min(int(max_new_tokens), 1500), 144 | "stop": state.sep2, 145 | "audio": audio, 146 | } 147 | 148 | yield (state, "", "", None) 149 | 150 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 151 | 152 | try: 153 | # Stream output 154 | response = requests.post(worker_addr + "/worker_generate_stream", 155 | headers=headers, json=pload, stream=True, timeout=10) 156 | num_generated_units = 0 157 | wav_list = [] 158 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 159 | if chunk: 160 | data = json.loads(chunk.decode()) 161 | if data["error_code"] == 0: 162 | output = data["text"][len(prompt):].strip() 163 | output_unit = list(map(int, data["unit"].strip().split())) 164 | state.messages[-1][-1] = (output, data["unit"].strip()) 165 | 166 | # vocoder 167 | new_units = output_unit[num_generated_units:] 168 | if len(new_units) >= chunk_size: 169 | num_generated_units = len(output_unit) 170 | x = {"code": torch.LongTensor(new_units).view(1, -1).cuda()} 171 | wav = vocoder(x, True) 172 | wav_list.append(wav.detach().cpu().numpy()) 173 | 174 | if len(wav_list) > 0: 175 | wav_full = np.concatenate(wav_list) 176 | return_value = (16000, wav_full) 177 | else: 178 | return_value = None 179 | 180 | yield (state, state.messages[-1][-1][0], state.messages[-1][-1][1], return_value) 181 | else: 182 | output = data["text"] + f" (error_code: {data['error_code']})" 183 | state.messages[-1][-1] = output 184 | yield (state, "", "", None) 185 | return 186 | time.sleep(0.03) 187 | except requests.exceptions.RequestException as e: 188 | state.messages[-1][-1] = server_error_msg 189 | yield (state, "", "", None) 190 | return 191 | 192 | if num_generated_units < len(output_unit): 193 | new_units = output_unit[num_generated_units:] 194 | num_generated_units = len(output_unit) 195 | x = { 196 | "code": torch.LongTensor(new_units).view(1, -1).cuda() 197 | } 198 | wav = vocoder(x, True) 199 | wav_list.append(wav.detach().cpu().numpy()) 200 | 201 | if len(wav_list) > 0: 202 | wav_full = np.concatenate(wav_list) 203 | return_value = (16000, wav_full) 204 | else: 205 | return_value = None 206 | 207 | yield (state, state.messages[-1][-1][0], state.messages[-1][-1][1], return_value) 208 | 209 | finish_tstamp = time.time() 210 | logger.info(f"{output}") 211 | logger.info(f"{output_unit}") 212 | 213 | 214 | title_markdown = (""" 215 | # 🎧 LLaMA-Omni: Seamless Speech Interaction with Large Language Models 216 | """) 217 | 218 | block_css = """ 219 | 220 | #buttons button { 221 | min-width: min(120px,100%); 222 | } 223 | 224 | """ 225 | 226 | def build_demo(embed_mode, vocoder, cur_dir=None, concurrency_count=10): 227 | with gr.Blocks(title="LLaMA-Omni Speech Chatbot", theme=gr.themes.Default(), css=block_css) as demo: 228 | state = gr.State() 229 | 230 | if not embed_mode: 231 | gr.Markdown(title_markdown) 232 | 233 | with gr.Row(elem_id="model_selector_row"): 234 | model_selector = gr.Dropdown( 235 | choices=models, 236 | value=models[0] if len(models) > 0 else "", 237 | interactive=True, 238 | show_label=False, 239 | container=False) 240 | 241 | with gr.Row(): 242 | audio_input_box = gr.Audio(sources=["upload", "microphone"], label="Speech Input") 243 | with gr.Accordion("Parameters", open=True) as parameter_row: 244 | temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature",) 245 | top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",) 246 | max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max Output Tokens",) 247 | chunk_size = gr.Slider(minimum=10, maximum=500, value=40, step=10, interactive=True, label="Chunk Size",) 248 | 249 | if cur_dir is None: 250 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 251 | gr.Examples(examples=[ 252 | [f"{cur_dir}/examples/vicuna_1.wav"], 253 | [f"{cur_dir}/examples/vicuna_2.wav"], 254 | [f"{cur_dir}/examples/vicuna_3.wav"], 255 | [f"{cur_dir}/examples/vicuna_4.wav"], 256 | [f"{cur_dir}/examples/vicuna_5.wav"], 257 | [f"{cur_dir}/examples/helpful_base_1.wav"], 258 | [f"{cur_dir}/examples/helpful_base_2.wav"], 259 | [f"{cur_dir}/examples/helpful_base_3.wav"], 260 | [f"{cur_dir}/examples/helpful_base_4.wav"], 261 | [f"{cur_dir}/examples/helpful_base_5.wav"], 262 | ], inputs=[audio_input_box]) 263 | 264 | with gr.Row(): 265 | submit_btn = gr.Button(value="Send", variant="primary") 266 | clear_btn = gr.Button(value="Clear") 267 | 268 | text_output_box = gr.Textbox(label="Text Output", type="text") 269 | unit_output_box = gr.Textbox(label="Unit Output", type="text") 270 | audio_output_box = gr.Audio(label="Speech Output") 271 | 272 | url_params = gr.JSON(visible=False) 273 | 274 | submit_btn.click( 275 | add_speech, 276 | [state, audio_input_box], 277 | [state] 278 | ).then( 279 | http_bot, 280 | [state, model_selector, temperature, top_p, max_output_tokens, chunk_size], 281 | [state, text_output_box, unit_output_box, audio_output_box], 282 | concurrency_limit=concurrency_count 283 | ) 284 | 285 | clear_btn.click( 286 | clear_history, 287 | None, 288 | [state, audio_input_box, text_output_box, unit_output_box, audio_output_box], 289 | queue=False 290 | ) 291 | 292 | if args.model_list_mode == "once": 293 | demo.load( 294 | load_demo, 295 | [url_params], 296 | [state, model_selector], 297 | js=get_window_url_params 298 | ) 299 | elif args.model_list_mode == "reload": 300 | demo.load( 301 | load_demo_refresh_model_list, 302 | None, 303 | [state, model_selector], 304 | queue=False 305 | ) 306 | else: 307 | raise ValueError(f"Unknown model list mode: {args.model_list_mode}") 308 | 309 | return demo 310 | 311 | 312 | def build_vocoder(args): 313 | global vocoder 314 | if args.vocoder is None: 315 | return None 316 | with open(args.vocoder_cfg) as f: 317 | vocoder_cfg = json.load(f) 318 | vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg).cuda() 319 | 320 | 321 | if __name__ == "__main__": 322 | parser = argparse.ArgumentParser() 323 | parser.add_argument("--host", type=str, default="0.0.0.0") 324 | parser.add_argument("--port", type=int) 325 | parser.add_argument("--controller-url", type=str, default="http://localhost:21001") 326 | parser.add_argument("--concurrency-count", type=int, default=16) 327 | parser.add_argument("--model-list-mode", type=str, default="once", 328 | choices=["once", "reload"]) 329 | parser.add_argument("--share", action="store_true") 330 | parser.add_argument("--moderate", action="store_true") 331 | parser.add_argument("--embed", action="store_true") 332 | parser.add_argument("--vocoder", type=str) 333 | parser.add_argument("--vocoder-cfg", type=str) 334 | args = parser.parse_args() 335 | logger.info(f"args: {args}") 336 | 337 | models = get_model_list() 338 | build_vocoder(args) 339 | 340 | logger.info(args) 341 | demo = build_demo(args.embed, vocoder, concurrency_count=args.concurrency_count) 342 | demo.queue( 343 | api_open=False 344 | ).launch( 345 | server_name=args.host, 346 | server_port=args.port, 347 | share=args.share 348 | ) -------------------------------------------------------------------------------- /omni_speech/serve/model_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model worker executes the model. 3 | """ 4 | import argparse 5 | import asyncio 6 | import json 7 | import time 8 | import threading 9 | import uuid 10 | 11 | from fastapi import FastAPI, Request, BackgroundTasks 12 | from fastapi.responses import StreamingResponse 13 | import requests 14 | import torch 15 | import uvicorn 16 | import whisper 17 | import numpy as np 18 | from functools import partial 19 | 20 | from transformers import PreTrainedTokenizer 21 | 22 | from omni_speech.constants import WORKER_HEART_BEAT_INTERVAL 23 | from omni_speech.utils import (build_logger, server_error_msg, 24 | pretty_print_semaphore) 25 | from omni_speech.model.builder import load_pretrained_model 26 | from omni_speech.constants import SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN 27 | from omni_speech.datasets.preprocess import tokenizer_speech_token 28 | from transformers import TextIteratorStreamer 29 | from threading import Thread 30 | 31 | 32 | GB = 1 << 30 33 | 34 | worker_id = str(uuid.uuid4())[:6] 35 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log") 36 | global_counter = 0 37 | 38 | model_semaphore = None 39 | 40 | 41 | def heart_beat_worker(controller): 42 | 43 | while True: 44 | time.sleep(WORKER_HEART_BEAT_INTERVAL) 45 | controller.send_heart_beat() 46 | 47 | 48 | def load_speech(audio, input_type, mel_size, speech_normalize): 49 | speech = np.array(audio, dtype=np.float32) 50 | if input_type == "raw": 51 | speech = torch.from_numpy(speech) 52 | if speech_normalize: 53 | speech = torch.nn.functional.layer_norm(speech, speech.shape) 54 | elif input_type == "mel": 55 | speech = whisper.pad_or_trim(speech) 56 | speech = whisper.log_mel_spectrogram(speech, n_mels=mel_size).permute(1, 0) 57 | return speech 58 | 59 | 60 | def build_unit_tokenizer(vocab_size): 61 | import os 62 | from transformers import BertTokenizer 63 | with open("unit_vocab.txt", "w") as f: 64 | for i in range(vocab_size + 1): 65 | f.write(str(i) + "\n") 66 | tokenizer = BertTokenizer(vocab_file="unit_vocab.txt") 67 | os.remove("unit_vocab.txt") 68 | return tokenizer 69 | 70 | 71 | class ModelWorker: 72 | def __init__(self, controller_addr, worker_addr, 73 | worker_id, no_register, 74 | model_path, model_base, model_name, 75 | load_8bit, load_4bit, device, input_type, mel_size, s2s, is_lora, use_flash_attn=False): 76 | self.controller_addr = controller_addr 77 | self.worker_addr = worker_addr 78 | self.worker_id = worker_id 79 | self.device = device 80 | self.model_name = model_name 81 | self.input_type = input_type 82 | self.mel_size = mel_size 83 | self.tokenizer, self.model, self.context_len = load_pretrained_model( 84 | model_path, model_base, is_lora=is_lora, s2s=s2s, load_8bit=load_8bit, load_4bit=load_4bit, device=self.device, use_flash_attn=use_flash_attn) 85 | self.unit_tokenizer = build_unit_tokenizer(self.model.config.unit_vocab_size) 86 | 87 | if not no_register: 88 | self.register_to_controller() 89 | self.heart_beat_thread = threading.Thread( 90 | target=heart_beat_worker, args=(self,), daemon=True) 91 | self.heart_beat_thread.start() 92 | 93 | def register_to_controller(self): 94 | logger.info("Register to controller") 95 | 96 | url = self.controller_addr + "/register_worker" 97 | data = { 98 | "worker_name": self.worker_addr, 99 | "check_heart_beat": True, 100 | "worker_status": self.get_status() 101 | } 102 | r = requests.post(url, json=data) 103 | assert r.status_code == 200 104 | 105 | def send_heart_beat(self): 106 | logger.info(f"Send heart beat. Models: {[self.model_name]}. " 107 | f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " 108 | f"global_counter: {global_counter}") 109 | 110 | url = self.controller_addr + "/receive_heart_beat" 111 | 112 | while True: 113 | try: 114 | ret = requests.post(url, json={ 115 | "worker_name": self.worker_addr, 116 | "queue_length": self.get_queue_length()}, timeout=5) 117 | exist = ret.json()["exist"] 118 | break 119 | except requests.exceptions.RequestException as e: 120 | logger.error(f"heart beat error: {e}") 121 | time.sleep(5) 122 | 123 | if not exist: 124 | self.register_to_controller() 125 | 126 | def get_queue_length(self): 127 | if model_semaphore is None: 128 | return 0 129 | else: 130 | return args.limit_model_concurrency - model_semaphore._value + (len( 131 | model_semaphore._waiters) if model_semaphore._waiters is not None else 0) 132 | 133 | def get_status(self): 134 | return { 135 | "model_names": [self.model_name], 136 | "speed": 1, 137 | "queue_length": self.get_queue_length(), 138 | } 139 | 140 | @torch.inference_mode() 141 | def generate_stream(self, params): 142 | tokenizer, model = self.tokenizer, self.model 143 | 144 | prompt = params["prompt"] 145 | ori_prompt = prompt 146 | audio = params.get("audio", None) 147 | if audio is not None and len(audio) > 0: 148 | speech = load_speech(audio, self.input_type, self.mel_size, self.model.config.speech_normalize) 149 | speech_length = torch.LongTensor([speech.shape[0]]).unsqueeze(0).to(self.device) 150 | speech_tensor = speech.unsqueeze(0).to(self.device, dtype=torch.float16) 151 | speech_args = {"speech": speech_tensor, "speech_lengths": speech_length} 152 | else: 153 | speech = None 154 | speech_args = {} 155 | 156 | temperature = float(params.get("temperature", 1.0)) 157 | top_p = float(params.get("top_p", 1.0)) 158 | max_context_length = getattr(model.config, 'max_position_embeddings', 2048) 159 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) 160 | stop_str = params.get("stop", None) 161 | do_sample = True if temperature > 0.001 else False 162 | 163 | input_ids = tokenizer_speech_token(prompt, tokenizer, return_tensors='pt').unsqueeze(0).to(self.device) 164 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) 165 | streamer_unit = TextIteratorStreamer(self.unit_tokenizer, skip_prompt=False, skip_special_tokens=True, timeout=15) 166 | 167 | # max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) 168 | 169 | if max_new_tokens < 1: 170 | yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" 171 | return 172 | 173 | thread = Thread(target=model.generate, kwargs=dict( 174 | inputs=input_ids, 175 | do_sample=do_sample, 176 | temperature=temperature, 177 | top_p=top_p, 178 | max_new_tokens=max_new_tokens, 179 | streamer=streamer, 180 | streamer_unit=streamer_unit, 181 | streaming_unit_gen=True, 182 | use_cache=True, 183 | **speech_args 184 | )) 185 | thread.start() 186 | 187 | generated_text = ori_prompt 188 | for new_text in streamer: 189 | generated_text += new_text 190 | generated_unit = " ".join(map(str, streamer_unit.token_cache)) 191 | if generated_text.endswith(stop_str): 192 | generated_text = generated_text[:-len(stop_str)] 193 | yield json.dumps({"text": generated_text, "unit": generated_unit, "error_code": 0}).encode() + b"\0" 194 | 195 | def generate_stream_gate(self, params): 196 | try: 197 | for x in self.generate_stream(params): 198 | yield x 199 | except ValueError as e: 200 | print("Caught ValueError:", e) 201 | ret = { 202 | "text": server_error_msg, 203 | "error_code": 1, 204 | } 205 | yield json.dumps(ret).encode() + b"\0" 206 | except torch.cuda.CudaError as e: 207 | print("Caught torch.cuda.CudaError:", e) 208 | ret = { 209 | "text": server_error_msg, 210 | "error_code": 1, 211 | } 212 | yield json.dumps(ret).encode() + b"\0" 213 | except Exception as e: 214 | print("Caught Unknown Error", e) 215 | ret = { 216 | "text": server_error_msg, 217 | "error_code": 1, 218 | } 219 | yield json.dumps(ret).encode() + b"\0" 220 | 221 | 222 | app = FastAPI() 223 | 224 | 225 | def release_model_semaphore(fn=None): 226 | model_semaphore.release() 227 | if fn is not None: 228 | fn() 229 | 230 | 231 | @app.post("/worker_generate_stream") 232 | async def generate_stream(request: Request): 233 | global model_semaphore, global_counter 234 | global_counter += 1 235 | params = await request.json() 236 | 237 | if model_semaphore is None: 238 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) 239 | await model_semaphore.acquire() 240 | worker.send_heart_beat() 241 | generator = worker.generate_stream_gate(params) 242 | background_tasks = BackgroundTasks() 243 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) 244 | return StreamingResponse(generator, background=background_tasks) 245 | 246 | 247 | @app.post("/worker_get_status") 248 | async def get_status(request: Request): 249 | return worker.get_status() 250 | 251 | 252 | if __name__ == "__main__": 253 | parser = argparse.ArgumentParser() 254 | parser.add_argument("--host", type=str, default="localhost") 255 | parser.add_argument("--port", type=int, default=21002) 256 | parser.add_argument("--worker-address", type=str, 257 | default="http://localhost:21002") 258 | parser.add_argument("--controller-address", type=str, 259 | default="http://localhost:21001") 260 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 261 | parser.add_argument("--model-base", type=str, default=None) 262 | parser.add_argument("--model-name", type=str) 263 | parser.add_argument("--device", type=str, default="cuda") 264 | parser.add_argument("--limit-model-concurrency", type=int, default=5) 265 | parser.add_argument("--stream-interval", type=int, default=1) 266 | parser.add_argument("--no-register", action="store_true") 267 | parser.add_argument("--load-8bit", action="store_true") 268 | parser.add_argument("--load-4bit", action="store_true") 269 | parser.add_argument("--use-flash-attn", action="store_true") 270 | parser.add_argument("--input-type", type=str, default="mel") 271 | parser.add_argument("--mel-size", type=int, default=128) 272 | parser.add_argument("--s2s", action="store_true", default=False) 273 | parser.add_argument("--is-lora", action="store_true", default=False) 274 | args = parser.parse_args() 275 | logger.info(f"args: {args}") 276 | 277 | worker = ModelWorker(args.controller_address, 278 | args.worker_address, 279 | worker_id, 280 | args.no_register, 281 | args.model_path, 282 | args.model_base, 283 | args.model_name, 284 | args.load_8bit, 285 | args.load_4bit, 286 | args.device, 287 | args.input_type, 288 | args.mel_size, 289 | args.s2s, 290 | args.is_lora, 291 | use_flash_attn=args.use_flash_attn) 292 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") -------------------------------------------------------------------------------- /omni_speech/train/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ROOT=$1 4 | 5 | VOCODER_CKPT=vocoder/g_00500000 6 | VOCODER_CFG=vocoder/config.json 7 | 8 | python omni_speech/train/stage1.py \ 9 | --model-path Llama-3.2-1B-Instruct \ 10 | --question-file data.json \ 11 | --answer-file answer.json \ 12 | --num-chunks 1 \ 13 | --chunk-idx 0 \ 14 | --temperature 0 \ 15 | --conv-mode llama_3 \ 16 | --input_type mel \ 17 | --mel_size 128 \ 18 | 19 | -------------------------------------------------------------------------------- /omni_speech/train/stage1.py: -------------------------------------------------------------------------------- 1 | from omni_speech.model.builder import load_pretrained_model,create_model 2 | import argparse 3 | import os 4 | import torch 5 | from torch.utils.data import Dataset, DataLoader 6 | import whisper 7 | from omni_speech.conversation import conv_templates 8 | import ipdb 9 | import math 10 | import json 11 | from tqdm import tqdm 12 | from omni_speech.datasets.preprocess import tokenizer_speech_token 13 | from transformers import DataCollatorForLanguageModeling 14 | from transformers import TrainingArguments 15 | from transformers import Trainer 16 | from tqdm import tqdm 17 | import torch.optim as optim 18 | import torch.optim as optim 19 | from transformers import DataCollatorForSeq2Seq 20 | from torch.nn.utils.rnn import pad_sequence 21 | 22 | # Custom dataset class 23 | 24 | def collate_fn(batch): 25 | for i in range(len(batch)): 26 | batch[i]= batch[i].values() 27 | 28 | input_ids,labels,speech_tensors,speech_lengths = zip(*batch) 29 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=128009) 30 | labels = pad_sequence(labels, batch_first=True, padding_value=128009) 31 | 32 | speech_tensors = torch.stack(speech_tensors, dim=0) 33 | speech_lengths = torch.stack(speech_lengths, dim=0) 34 | return {"input_ids":input_ids,"labels":labels, "speech":speech_tensors, "speech_lengths":speech_lengths} 35 | 36 | class CustomDataset(Dataset): 37 | def __init__(self, questions, tokenizer, model_config, input_type, mel_size): 38 | self.questions = questions 39 | self.tokenizer = tokenizer 40 | self.model_config = model_config 41 | self.input_type = input_type 42 | self.mel_size = mel_size 43 | 44 | def __getitem__(self, index): 45 | item = self.questions[index] 46 | speech_file = item["speech"] 47 | qs = item["conversations"][0]["value"] 48 | re = item["conversations"][1]["value"] 49 | 50 | conv = conv_templates[args.conv_mode].copy() 51 | conv.append_message(conv.roles[0], qs) 52 | conv.append_message(conv.roles[1], re) 53 | prompt = conv.get_prompt() 54 | 55 | speech = whisper.load_audio(speech_file) 56 | if self.input_type == "raw": 57 | speech = torch.from_numpy(speech) 58 | if self.model_config.speech_normalize: 59 | speech = torch.nn.functional.layer_norm(speech, speech.shape) 60 | elif self.input_type == "mel": 61 | speech = whisper.pad_or_trim(speech) 62 | speech = whisper.log_mel_spectrogram(speech, n_mels=self.mel_size).permute(1, 0) 63 | input_ids = tokenizer_speech_token(prompt, self.tokenizer, return_tensors='pt') 64 | ret=dict(input_ids=input_ids,labels=input_ids, speech=speech.to(torch.bfloat16), speech_lengths=torch.LongTensor([speech.shape[0]])) 65 | return ret 66 | def __len__(self): 67 | return len(self.questions) 68 | 69 | # DataLoader 70 | def create_data_loader(questions, tokenizer, model_config, input_type, mel_size, batch_size=2, num_workers=1): 71 | # assert batch_size == 1, "batch_size must be 1" 72 | 73 | dataset = CustomDataset(questions, tokenizer, model_config, input_type, mel_size) 74 | #data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn) 75 | return dataset 76 | 77 | 78 | def split_list(lst, n): 79 | """Split a list into n (roughly) equal-sized chunks""" 80 | chunk_size = math.ceil(len(lst) / n) # integer division 81 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 82 | 83 | 84 | def get_chunk(lst, n, k): 85 | chunks = split_list(lst, n) 86 | return chunks[k] 87 | 88 | 89 | def train_model(args): 90 | # 设置每张卡的device 91 | 92 | 93 | device = 'cuda' if torch.cuda.is_available() else 'cpu' # 设置 device,能用 cuda 就用 cuda,苹果 M 系列可以用 mps 94 | 95 | model_path = os.path.expanduser(args.model_path) 96 | tokenizer, model, context_len = create_model(model_path, args.model_base, is_lora=args.is_lora, s2s=args.s2s) 97 | 98 | 99 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 100 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) #chunk 1 chunk-idx 0 取list中的多少进行测试 101 | data_loader = create_data_loader(questions, tokenizer, model.config, args.input_type, args.mel_size) 102 | 103 | 104 | from transformers import Trainer, TrainingArguments 105 | # 初始化Trainer 106 | training_args = TrainingArguments( 107 | output_dir='saves', # 输出路径,包括模型检查点、中间文件等 108 | overwrite_output_dir=True, # 是否覆写 output_dir 109 | do_train=True, # 是否做训练 110 | do_eval=False, # 是否做评估 111 | eval_steps=1, # 评估步骤间隔 112 | per_device_train_batch_size=2, # 每设备批次 113 | gradient_accumulation_steps=6, # 梯度累计步大小,省显存,但小模型没必要,用 1 收敛比较快 114 | learning_rate=1e-4, 115 | weight_decay=0.01, 116 | adam_beta2=0.95, 117 | warmup_ratio=0.01, 118 | lr_scheduler_type='cosine', # 学习率调度策略,LLM 训练一般都用余弦 119 | logging_steps=1, # 打印步骤间隔 120 | report_to=None, # 日志输出目标,不想用 wandb 可以设置为 None 121 | num_train_epochs=50, # 训练轮数,2 ~ 3 即可 122 | save_steps=1000, # 检查点保存步骤间隔 123 | save_total_limit=2, # output_dir 内留存的检查点最大数目 124 | seed=3407, # 随机种子 125 | bf16=True # 是否开启混合精度训练 126 | 127 | ) 128 | tokenizer.pad_token = tokenizer.eos_token 129 | trainer = Trainer( 130 | model=model, 131 | tokenizer=tokenizer, 132 | args=training_args, 133 | train_dataset=data_loader, 134 | eval_dataset=data_loader, 135 | data_collator=collate_fn 136 | ) 137 | trainer.train() 138 | 139 | 140 | 141 | if __name__ == "__main__": 142 | parser = argparse.ArgumentParser() 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 145 | parser.add_argument("--model-base", type=str, default=None) 146 | parser.add_argument("--question-file", type=str) 147 | parser.add_argument("--answer-file", type=str) 148 | parser.add_argument("--conv-mode", type=str, default="v1") 149 | parser.add_argument("--num-chunks", type=int, default=1) 150 | parser.add_argument("--chunk-idx", type=int, default=0) 151 | parser.add_argument("--temperature", type=float, default=0) 152 | parser.add_argument("--top_p", type=float, default=None) 153 | parser.add_argument("--num_beams", type=int, default=1) 154 | parser.add_argument("--max_new_tokens", type=int, default=256) 155 | parser.add_argument("--input_type", type=str, default="raw") 156 | parser.add_argument("--mel_size", type=int, default=128) 157 | parser.add_argument("--s2s", action="store_true", default=False) 158 | parser.add_argument("--is_lora", action="store_true", default=False) 159 | args = parser.parse_args() 160 | train_model(args) -------------------------------------------------------------------------------- /omni_speech/train/stage2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | # import os 4 | #os.environ['CUDA_VISIBLE_DEVICES'] = "0" #(代表仅使用第0,1号GPU) 5 | import torch 6 | from torch.utils.data import Dataset, DataLoader 7 | import whisper 8 | # import ipdb 9 | import math 10 | import json 11 | from tqdm import tqdm 12 | from omni_speech.conversation import conv_templates 13 | from omni_speech.model.builder import load_pretrained_model,create_model 14 | from omni_speech.datasets.preprocess import tokenizer_speech_token 15 | from transformers import DataCollatorForLanguageModeling 16 | from transformers import TrainingArguments 17 | from transformers import Trainer 18 | from tqdm import tqdm 19 | import torch.optim as optim 20 | # from memory_profiler import profile 21 | import torch.optim as optim 22 | from transformers import DataCollatorForSeq2Seq 23 | import os 24 | from torch.nn.utils.rnn import pad_sequence 25 | 26 | # Custom dataset class 27 | 28 | def collate_fn(batch_data): 29 | for i in range(len(batch_data)): 30 | batch_data[i] = batch_data[i].values() 31 | input_ids,labels,speech_tensors, tgt_units,speech_lengths = zip(*batch_data) 32 | 33 | # input_idspad为llama的<|eot_id|> 34 | input_ids = pad_sequence(input_ids, batch_first=True, padding_value=128009) 35 | labels = pad_sequence(labels, batch_first=True, padding_value=-100) 36 | tgt_units = pad_sequence(tgt_units, batch_first=True, padding_value=-100) 37 | # input_ids = torch.stack(input_ids, dim=0) 38 | # labels = torch.stack(labels, dim=0) 39 | speech_tensors = torch.stack(speech_tensors, dim=0) 40 | speech_lengths = torch.stack(speech_lengths, dim=0) 41 | #转fp16 42 | 43 | ret=dict(input_ids=input_ids,labels=labels, speech=speech_tensors.bfloat16(), tgt_units = tgt_units, speech_lengths=speech_lengths) 44 | return ret 45 | 46 | class CustomDataset(Dataset): 47 | def __init__(self, questions, responses, tokenizer, model_config, input_type, mel_size): 48 | self.questions = questions 49 | self.responses = responses 50 | self.tokenizer = tokenizer 51 | self.model_config = model_config 52 | self.input_type = input_type 53 | self.mel_size = mel_size 54 | 55 | 56 | # def get_tgt_unit(self, file_path): 57 | # unique_data_list = [] 58 | # with open(file_path, 'r', encoding='utf-8') as file: 59 | # for line in file: 60 | # line = line.strip() 61 | # parts = line.split('<') 62 | # result = [part for part in parts if part and '>' in part] 63 | # # 移除元素末尾的 '>' 64 | # result = [part.split('>')[0] for part in result] 65 | # line_list = [int(item) for item in result] 66 | # #unique_data = [line_list[i] for i in range(len(line_list)) if i == 0 or line_list[i] != line_list[i-1]] 67 | # unique_data_list.append(line_list) 68 | # # return torch.tensor(unique_data_list) 69 | # return unique_data_list 70 | 71 | def __getitem__(self, index): 72 | #tgt_unit = torch.tensor(self.tgt_unit[index]) 73 | responses = self.responses[index] 74 | prediction = responses['prediction'] 75 | tgt_unit = responses['prediction_units'] 76 | tgt_unit = torch.tensor([int(item) for item in tgt_unit.split(' ')]) 77 | item = self.questions[index] 78 | speech_file = item["speech"] 79 | qs = item["conversations"][0]["value"] 80 | ans = item["conversations"][1]["value"] 81 | # llm_gt = self.llm_gt[index] 82 | conv = conv_templates[args.conv_mode].copy() 83 | conv.append_message(conv.roles[0], qs) 84 | conv.append_message(conv.roles[1], prediction) 85 | prompt = conv.get_prompt() 86 | 87 | 88 | speech = whisper.load_audio(speech_file) 89 | if self.input_type == "raw": 90 | speech = torch.from_numpy(speech) 91 | if self.model_config.speech_normalize: 92 | speech = torch.nn.functional.layer_norm(speech, speech.shape) 93 | elif self.input_type == "mel": 94 | speech = whisper.pad_or_trim(speech) 95 | speech = whisper.log_mel_spectrogram(speech, n_mels=self.mel_size).permute(1, 0) 96 | input_ids_ = tokenizer_speech_token(prompt, self.tokenizer, return_tensors='pt') 97 | input_ids = input_ids_.tolist() 98 | # 处理 input_ids 和 labels,仅训练answer部分的loss 99 | split_markers = [128006, 78191, 128007, 271] 100 | last_marker_index = -1 101 | 102 | for i in range(len(input_ids) - len(split_markers) + 1): 103 | if input_ids[i:i + len(split_markers)] == split_markers: 104 | last_marker_index = i + len(split_markers) 105 | break 106 | if last_marker_index != -1: 107 | list1 = input_ids[:last_marker_index] 108 | list2 = input_ids[last_marker_index:] 109 | 110 | labels = len(list1) * [-100] + list2 111 | labels = torch.tensor(labels, device=input_ids_.device, dtype=input_ids_.dtype) 112 | ret=dict(input_ids=input_ids_,labels=labels, speech=speech, tgt_units=tgt_unit ,speech_lengths=torch.LongTensor([speech.shape[0]])) 113 | # ret=dict(input_ids=input_ids,labels=None, speech=speech, tgt_units=tgt_unit ,speech_lengths=torch.LongTensor([speech.shape[0]])) 114 | return ret 115 | def __len__(self): 116 | return len(self.questions) 117 | 118 | # DataLoader 119 | def create_data_loader(questions, responses,tokenizer, model_config, input_type, mel_size, batch_size=1, num_workers=1): 120 | assert batch_size == 1, "batch_size must be 1" 121 | 122 | dataset = CustomDataset(questions,responses, tokenizer, model_config, input_type, mel_size) 123 | #data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn) 124 | return dataset 125 | 126 | 127 | def split_list(lst, n): 128 | """Split a list into n (roughly) equal-sized chunks""" 129 | chunk_size = math.ceil(len(lst) / n) # integer division 130 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 131 | 132 | 133 | def get_chunk(lst, n, k): 134 | chunks = split_list(lst, n) 135 | return chunks[k] 136 | 137 | 138 | def train_model(args): 139 | device = 'cuda' if torch.cuda.is_available() else 'cpu' # 设置 device,能用 cuda 就用 cuda,苹果 M 系列可以用 mps 140 | #local_rank = torch.distributed.get_rank() 141 | #torch.cuda.set_device(local_rank) 142 | #device = torch.device(f'cuda:{local_rank}') 143 | model_path = os.path.expanduser(args.model_path) 144 | tokenizer, model, context_len = create_model(model_path, args.model_base, device=device, is_lora=args.is_lora, s2s=args.s2s) 145 | 146 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 147 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) #chunk 1 chunk-idx 0 取list中的多少进行测试 148 | with open(os.path.expanduser(args.answer_file), "r") as f: 149 | responses = f.readlines() 150 | for i in range(len(responses)): 151 | responses[i] = json.loads(responses[i]) 152 | data_loader = create_data_loader(questions,responses, tokenizer, model.config, args.input_type, args.mel_size) 153 | 154 | 155 | # optimizer = optim.Adam(model.parameters(), lr=0.00001) 156 | # 学习率变大 157 | optimizer = optim.Adam(model.speech_generator.parameters() , lr=1e-4) 158 | # optimizer = optim.SGD(model.parameters(), lr=0.001) 159 | 160 | 161 | 162 | 163 | # 初始化Trainer 164 | model.train() 165 | training_args = TrainingArguments( 166 | output_dir='saves/stage2_fp16', # 输出路径,包括模型检查点、中间文件等 167 | overwrite_output_dir=True, # 是否覆写 output_dir 168 | do_train=True, # 是否做训练 169 | do_eval=False, # 是否做评估 170 | eval_steps=100, # 评估步骤间隔 171 | per_device_train_batch_size=1, # 每设备批次 172 | gradient_accumulation_steps=8, # 梯度累计步大小,省显存,但小模型没必要,用 1 收敛比较快 173 | learning_rate=3e-5, # 学习率大小 174 | lr_scheduler_type='cosine', # 学习率调度策略,LLM 训练一般都用余弦 175 | bf16=torch.cuda.is_bf16_supported(), # 尝试配置 bf16 176 | fp16=not torch.cuda.is_bf16_supported(), # bf16 不行就上 fp16 177 | half_precision_backend='cuda_amp', 178 | logging_steps=1, # 打印步骤间隔 179 | report_to=None, # 日志输出目标,不想用 wandb 可以设置为 None 180 | num_train_epochs=1000, # 训练轮数,2 ~ 3 即可 181 | save_steps=1000, # 检查点保存步骤间隔 182 | save_total_limit=100, # output_dir 内留存的检查点最大数目 183 | seed=3407, # 随机种子 184 | max_grad_norm=1.0, 185 | 186 | ) 187 | tokenizer.pad_token = tokenizer.eos_token 188 | trainer = Trainer( 189 | model=model, 190 | tokenizer=tokenizer, 191 | args=training_args, 192 | train_dataset=data_loader, 193 | eval_dataset=data_loader, 194 | data_collator=collate_fn, 195 | optimizers=(optimizer, None) 196 | ) 197 | # with torch.no_grad: 198 | trainer.train() 199 | 200 | 201 | if __name__ == "__main__": 202 | 203 | parser = argparse.ArgumentParser() 204 | parser = argparse.ArgumentParser() 205 | #parser.add_argument("--model-path", type=str, default="") 206 | parser.add_argument("--model-path", type=str, default="Llama-3.1-8B-Omni") 207 | parser.add_argument("--model-base", type=str, default='Llama-3.1-8B-Omni') 208 | # parser.add_argument("--question-file", type=str, default="./omni_speech/infer/examples/question.json") 209 | parser.add_argument("--question-file", type=str, default="data.json") 210 | parser.add_argument("--answer-file", type=str, default="omni_speech/infer/gen_answer_data/answer.json") 211 | parser.add_argument("--conv-mode", type=str, default="llama_3") 212 | parser.add_argument("--num-chunks", type=int, default=1) 213 | parser.add_argument("--chunk-idx", type=int, default=0) 214 | parser.add_argument("--temperature", type=float, default=0) 215 | parser.add_argument("--top_p", type=float, default=None) 216 | parser.add_argument("--num_beams", type=int, default=1) 217 | parser.add_argument("--max_new_tokens", type=int, default=256) 218 | parser.add_argument("--input_type", type=str, default="mel") 219 | parser.add_argument("--mel_size", type=int, default=128) 220 | parser.add_argument("--s2s", action="store_true", default=True) 221 | parser.add_argument("--is_lora",type=bool, default=False) 222 | #parser.add_argument("--local-rank") 223 | args = parser.parse_args() 224 | train_model(args) 225 | -------------------------------------------------------------------------------- /omni_speech/utils.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright: 2 | # Copyright 2023 Haotian Liu 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import sys 18 | import torch 19 | import logging 20 | import logging.handlers 21 | import transformers 22 | 23 | from omni_speech.constants import LOGDIR 24 | 25 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 26 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 27 | 28 | handler = None 29 | 30 | 31 | def build_logger(logger_name, logger_filename): 32 | global handler 33 | 34 | formatter = logging.Formatter( 35 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 36 | datefmt="%Y-%m-%d %H:%M:%S", 37 | ) 38 | 39 | # Set the format of root handlers 40 | if not logging.getLogger().handlers: 41 | logging.basicConfig(level=logging.INFO) 42 | logging.getLogger().handlers[0].setFormatter(formatter) 43 | 44 | # Redirect stdout and stderr to loggers 45 | stdout_logger = logging.getLogger("stdout") 46 | stdout_logger.setLevel(logging.INFO) 47 | sl = StreamToLogger(stdout_logger, logging.INFO) 48 | sys.stdout = sl 49 | 50 | stderr_logger = logging.getLogger("stderr") 51 | stderr_logger.setLevel(logging.ERROR) 52 | sl = StreamToLogger(stderr_logger, logging.ERROR) 53 | sys.stderr = sl 54 | 55 | # Get logger 56 | logger = logging.getLogger(logger_name) 57 | logger.setLevel(logging.INFO) 58 | 59 | # Add a file handler for all loggers 60 | if handler is None: 61 | os.makedirs(LOGDIR, exist_ok=True) 62 | filename = os.path.join(LOGDIR, logger_filename) 63 | handler = logging.handlers.TimedRotatingFileHandler( 64 | filename, when='D', utc=True, encoding='UTF-8') 65 | handler.setFormatter(formatter) 66 | 67 | for name, item in logging.root.manager.loggerDict.items(): 68 | if isinstance(item, logging.Logger): 69 | item.addHandler(handler) 70 | 71 | return logger 72 | 73 | 74 | class StreamToLogger(object): 75 | """ 76 | Fake file-like stream object that redirects writes to a logger instance. 77 | """ 78 | def __init__(self, logger, log_level=logging.INFO): 79 | self.terminal = sys.stdout 80 | self.logger = logger 81 | self.log_level = log_level 82 | self.linebuf = '' 83 | 84 | def __getattr__(self, attr): 85 | return getattr(self.terminal, attr) 86 | 87 | def write(self, buf): 88 | temp_linebuf = self.linebuf + buf 89 | self.linebuf = '' 90 | for line in temp_linebuf.splitlines(True): 91 | # From the io.TextIOWrapper docs: 92 | # On output, if newline is None, any '\n' characters written 93 | # are translated to the system default line separator. 94 | # By default sys.stdout.write() expects '\n' newlines and then 95 | # translates them so this is still cross platform. 96 | if line[-1] == '\n': 97 | self.logger.log(self.log_level, line.rstrip()) 98 | else: 99 | self.linebuf += line 100 | 101 | def flush(self): 102 | if self.linebuf != '': 103 | self.logger.log(self.log_level, self.linebuf.rstrip()) 104 | self.linebuf = '' 105 | 106 | 107 | def maybe_zero_3(param, ignore_status=False, name=None): 108 | from deepspeed import zero 109 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 110 | if hasattr(param, "ds_id"): 111 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 112 | if not ignore_status: 113 | logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") 114 | with zero.GatheredParameters([param]): 115 | param = param.data.detach().cpu().clone() 116 | else: 117 | param = param.detach().cpu().clone() 118 | return param 119 | 120 | 121 | # Borrowed from peft.utils.get_peft_model_state_dict 122 | def get_peft_state_maybe_zero_3(named_params, bias): 123 | if bias == "none": 124 | to_return = {k: t for k, t in named_params if "lora_" in k} 125 | elif bias == "all": 126 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} 127 | elif bias == "lora_only": 128 | to_return = {} 129 | maybe_lora_bias = {} 130 | lora_bias_names = set() 131 | for k, t in named_params: 132 | if "lora_" in k: 133 | to_return[k] = t 134 | bias_name = k.split("lora_")[0] + "bias" 135 | lora_bias_names.add(bias_name) 136 | elif "bias" in k: 137 | maybe_lora_bias[k] = t 138 | for k, t in maybe_lora_bias: 139 | if bias_name in lora_bias_names: 140 | to_return[bias_name] = t 141 | else: 142 | raise NotImplementedError 143 | to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} 144 | return to_return 145 | 146 | 147 | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): 148 | to_return = {k: t for k, t in named_params if "lora_" not in k} 149 | if require_grad_only: 150 | to_return = {k: t for k, t in to_return.items() if t.requires_grad} 151 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} 152 | return to_return 153 | 154 | 155 | def get_speech_projector_state_maybe_zero_3(named_params, keys_to_match): 156 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 157 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} 158 | return to_return 159 | 160 | 161 | def find_all_linear_names(model): 162 | cls = torch.nn.Linear 163 | lora_module_names = set() 164 | speech_keywords = ['speech_projector', 'speech_encoder'] 165 | for name, module in model.named_modules(): 166 | if any(speech_keyword in name for speech_keyword in speech_keywords): 167 | continue 168 | if isinstance(module, cls): 169 | names = name.split('.') 170 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 171 | 172 | if 'lm_head' in lora_module_names: # needed for 16-bit 173 | lora_module_names.remove('lm_head') 174 | return list(lora_module_names) 175 | 176 | 177 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, 178 | output_dir: str): 179 | """Collects the state dict and dump to disk.""" 180 | 181 | if getattr(trainer.args, "tune_speech_projector", False): 182 | # Only save projector 183 | keys_to_match = ['speech_projector'] 184 | if getattr(trainer.args, "use_im_start_end", False): 185 | keys_to_match.extend(['embed_tokens', 'embed_in']) 186 | 187 | weight_to_save = get_speech_projector_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) 188 | trainer.model.config.save_pretrained(output_dir) 189 | 190 | current_folder = output_dir.split('/')[-1] 191 | parent_folder = os.path.dirname(output_dir) 192 | if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: 193 | if current_folder.startswith('checkpoint-'): 194 | speech_projector_folder = os.path.join(parent_folder, "speech_projector") 195 | os.makedirs(speech_projector_folder, exist_ok=True) 196 | torch.save(weight_to_save, os.path.join(speech_projector_folder, f'{current_folder}.bin')) 197 | else: 198 | torch.save(weight_to_save, os.path.join(output_dir, f'speech_projector.bin')) 199 | return 200 | 201 | if trainer.deepspeed: 202 | torch.cuda.synchronize() 203 | trainer.save_model(output_dir) 204 | return 205 | 206 | state_dict = trainer.model.state_dict() 207 | if trainer.args.should_save: 208 | cpu_state_dict = { 209 | key: value.cpu() 210 | for key, value in state_dict.items() 211 | } 212 | del state_dict 213 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 214 | 215 | 216 | def lengths_to_padding_mask(lens): 217 | bsz, max_lens = lens.size(0), torch.max(lens).item() 218 | mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) 219 | mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) 220 | return mask 221 | 222 | 223 | def lengths_to_mask(lens): 224 | return ~lengths_to_padding_mask(lens) 225 | 226 | 227 | def disable_torch_init(): 228 | """ 229 | Disable the redundant torch default initialization to accelerate model creation. 230 | """ 231 | import torch 232 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 233 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 234 | 235 | 236 | def get_model_name_from_path(model_path): 237 | model_path = model_path.strip("/") 238 | model_paths = model_path.split("/") 239 | if model_paths[-1].startswith('checkpoint-'): 240 | return model_paths[-2] + "_" + model_paths[-1] 241 | else: 242 | return model_paths[-1] 243 | 244 | 245 | def violates_moderation(text): 246 | """ 247 | Check whether the text violates OpenAI moderation API. 248 | """ 249 | url = "https://api.openai.com/v1/moderations" 250 | headers = {"Content-Type": "application/json", 251 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 252 | text = text.replace("\n", "") 253 | data = "{" + '"input": ' + f'"{text}"' + "}" 254 | data = data.encode("utf-8") 255 | try: 256 | ret = requests.post(url, headers=headers, data=data, timeout=5) 257 | flagged = ret.json()["results"][0]["flagged"] 258 | except requests.exceptions.RequestException as e: 259 | flagged = False 260 | except KeyError as e: 261 | flagged = False 262 | 263 | return flagged 264 | 265 | 266 | def pretty_print_semaphore(semaphore): 267 | if semaphore is None: 268 | return "None" 269 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llama-omni" 7 | version = "1.0.0" 8 | description = "Towards GPT-4o like large speech-language model." 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.1.2", "torchvision==0.16.2", "torchaudio==2.1.2", 17 | "transformers==4.43.4", "tokenizers==0.19.1", "sentencepiece==0.1.99", "shortuuid", 18 | "accelerate==0.33.0", "peft==0.11.1", "bitsandbytes==0.43.1", 19 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2", 20 | "gradio==4.43.0", "gradio_client==1.3.0", 21 | "requests", "httpx==0.27.2", "uvicorn", "fastapi", "soundfile", 22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 23 | "openai-whisper", "setuptools==59.5.0", "omegaconf==2.0.6", 24 | ] 25 | 26 | [project.optional-dependencies] 27 | train = ["deepspeed==0.12.6", "ninja", "wandb", "tensorboardX"] 28 | build = ["build", "twine"] 29 | 30 | [tool.setuptools.packages.find] 31 | exclude = ["data", "checkpoints", "logs", "models", "fairseq", "flash-attention"] 32 | 33 | [tool.wheel] 34 | exclude = ["data", "checkpoints", "logs", "models", "fairseq", "flash-attention"] 35 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.33.0 3 | aiofiles==23.2.1 4 | annotated-types==0.7.0 5 | antlr4-python3-runtime==4.8 6 | anyio==4.5.0 7 | argon2-cffi==23.1.0 8 | argon2-cffi-bindings==21.2.0 9 | arrow==1.3.0 10 | asttokens==2.4.1 11 | async-lru==2.0.4 12 | attrs==24.2.0 13 | babel==2.16.0 14 | beautifulsoup4==4.12.3 15 | bitarray==2.9.2 16 | bitsandbytes==0.43.1 17 | black==22.3.0 18 | bleach==6.1.0 19 | blessed==1.20.0 20 | certifi==2024.8.30 21 | cffi==1.17.1 22 | charset-normalizer==3.3.2 23 | click==8.1.7 24 | colorama==0.4.6 25 | coloredlogs==15.0.1 26 | comm==0.2.2 27 | contourpy==1.3.0 28 | cycler==0.12.1 29 | Cython==3.0.11 30 | debugpy==1.8.7 31 | decorator==5.1.1 32 | defusedxml==0.7.1 33 | dill==0.3.8 34 | docopt==0.6.2 35 | einops==0.6.1 36 | einops-exts==0.0.4 37 | exceptiongroup==1.2.2 38 | executing==2.1.0 39 | fastapi==0.112.4 40 | fastjsonschema==2.20.0 41 | ffmpy==0.4.0 42 | filelock==3.16.1 43 | flake8==5.0.4 44 | flash-attn==2.6.3 45 | flatbuffers==24.3.25 46 | fonttools==4.53.1 47 | fqdn==1.5.1 48 | fsspec==2024.9.0 49 | gpustat==1.1.1 50 | gradio==4.43.0 51 | gradio_client==1.3.0 52 | grpcio==1.66.1 53 | h11==0.14.0 54 | httpcore==1.0.5 55 | httpx==0.27.2 56 | huggingface-hub==0.25.0 57 | humanfriendly==10.0 58 | hydra-core==1.0.7 59 | idna==3.10 60 | importlib_resources==6.4.5 61 | ipdb==0.13.13 62 | ipykernel==6.29.5 63 | ipython==8.27.0 64 | ipywidgets==8.1.5 65 | isoduration==20.11.0 66 | isort==5.10.1 67 | jedi==0.19.1 68 | Jinja2==3.1.4 69 | joblib==1.4.2 70 | json5==0.9.25 71 | jsonpointer==3.0.0 72 | jsonschema==4.23.0 73 | jsonschema-specifications==2024.10.1 74 | jupyter==1.1.1 75 | jupyter-console==6.6.3 76 | jupyter-events==0.10.0 77 | jupyter-lsp==2.2.5 78 | jupyter_client==8.6.3 79 | jupyter_core==5.7.2 80 | jupyter_server==2.14.2 81 | jupyter_server_terminals==0.5.3 82 | jupyterlab==4.2.5 83 | jupyterlab_pygments==0.3.0 84 | jupyterlab_server==2.27.3 85 | jupyterlab_widgets==3.0.13 86 | kaldi-decoder==0.2.7 87 | kaldialign==0.9.1 88 | kaldifst==1.7.12 89 | kaldilm==1.15.1 90 | kiwisolver==1.4.7 91 | latex2mathml==3.77.0 92 | llvmlite==0.43.0 93 | lxml==5.3.0 94 | Markdown==3.7 95 | markdown-it-py==3.0.0 96 | markdown2==2.5.0 97 | MarkupSafe==2.1.5 98 | matplotlib==3.9.2 99 | matplotlib-inline==0.1.7 100 | mccabe==0.7.0 101 | mdurl==0.1.2 102 | mistune==3.0.2 103 | more-itertools==10.5.0 104 | mpmath==1.3.0 105 | mypy-extensions==1.0.0 106 | nbclient==0.10.0 107 | nbconvert==7.16.4 108 | nbformat==5.10.4 109 | nest-asyncio==1.6.0 110 | networkx==3.3 111 | notebook==7.2.2 112 | notebook_shim==0.2.4 113 | num2words==0.5.13 114 | numba==0.60.0 115 | numpy==1.26.4 116 | nvidia-cublas-cu12==12.1.3.1 117 | nvidia-cuda-cupti-cu12==12.1.105 118 | nvidia-cuda-nvrtc-cu12==12.1.105 119 | nvidia-cuda-runtime-cu12==12.1.105 120 | nvidia-cudnn-cu12==8.9.2.26 121 | nvidia-cufft-cu12==11.0.2.54 122 | nvidia-curand-cu12==10.3.2.106 123 | nvidia-cusolver-cu12==11.4.5.107 124 | nvidia-cusparse-cu12==12.1.0.106 125 | nvidia-ml-py==12.560.30 126 | nvidia-nccl-cu12==2.18.1 127 | nvidia-nvjitlink-cu12==12.6.68 128 | nvidia-nvtx-cu12==12.1.105 129 | omegaconf==2.0.6 130 | onnx==1.16.2 131 | onnxconverter-common==1.14.0 132 | onnxoptimizer==0.3.13 133 | onnxruntime==1.19.2 134 | onnxsim==0.4.36 135 | openai-whisper==20231117 136 | orjson==3.10.7 137 | overrides==7.7.0 138 | packaging==24.1 139 | pandas==2.2.2 140 | pandocfilters==1.5.1 141 | parso==0.8.4 142 | pathspec==0.12.1 143 | peft==0.11.1 144 | pexpect==4.9.0 145 | pillow==10.4.0 146 | platformdirs==4.3.6 147 | portalocker==2.10.1 148 | prometheus_client==0.21.0 149 | prompt_toolkit==3.0.47 150 | protobuf==3.20.2 151 | psutil==6.0.0 152 | ptyprocess==0.7.0 153 | pure_eval==0.2.3 154 | pycantonese==3.4.0 155 | pycodestyle==2.9.1 156 | pycparser==2.22 157 | pydantic==2.9.2 158 | pydantic_core==2.23.4 159 | pydub==0.25.1 160 | pyflakes==2.5.0 161 | Pygments==2.18.0 162 | pylangacq==0.16.2 163 | pyparsing==3.1.4 164 | pypinyin==0.50.0 165 | python-dateutil==2.9.0.post0 166 | python-json-logger==2.0.7 167 | python-multipart==0.0.9 168 | pytz==2024.2 169 | PyYAML==6.0.2 170 | pyzmq==26.2.0 171 | referencing==0.35.1 172 | regex==2024.9.11 173 | requests==2.32.3 174 | rfc3339-validator==0.1.4 175 | rfc3986-validator==0.1.1 176 | rich==13.8.1 177 | rpds-py==0.20.0 178 | ruff==0.6.5 179 | sacrebleu==2.4.3 180 | safetensors==0.4.5 181 | scikit-learn==1.2.2 182 | scipy==1.14.1 183 | semantic-version==2.10.0 184 | Send2Trash==1.8.3 185 | sentencepiece==0.1.99 186 | shellingham==1.5.4 187 | shortuuid==1.0.13 188 | six==1.16.0 189 | sniffio==1.3.1 190 | soundfile==0.12.1 191 | soupsieve==2.6 192 | stack-data==0.6.3 193 | starlette==0.38.5 194 | svgwrite==1.4.3 195 | sympy==1.13.3 196 | tabulate==0.9.0 197 | tensorboard==2.17.1 198 | tensorboard-data-server==0.7.2 199 | tensorboardX==2.6.2.2 200 | terminado==0.18.1 201 | threadpoolctl==3.5.0 202 | tiktoken==0.7.0 203 | timm==0.6.13 204 | tinycss2==1.4.0 205 | tokenizers==0.19.1 206 | tomli==2.0.1 207 | tomlkit==0.12.0 208 | torch==2.1.2 209 | torchaudio==2.1.2 210 | torchvision==0.16.2 211 | tornado==6.4.1 212 | tqdm==4.66.5 213 | traitlets==5.14.3 214 | transformers==4.43.4 215 | triton==2.1.0 216 | typeguard==4.3.0 217 | typer==0.12.5 218 | types-python-dateutil==2.9.0.20241003 219 | typing_extensions==4.12.2 220 | tzdata==2024.1 221 | uri-template==1.3.0 222 | urllib3==2.2.3 223 | uvicorn==0.30.6 224 | wavedrom==2.0.3.post3 225 | wcwidth==0.2.13 226 | webcolors==24.8.0 227 | webencodings==0.5.1 228 | websocket-client==1.8.0 229 | websockets==12.0 230 | Werkzeug==3.0.4 231 | widgetsnbextension==4.0.13 232 | wordseg==0.0.2 233 | -------------------------------------------------------------------------------- /wavs/sft_1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_1.wav -------------------------------------------------------------------------------- /wavs/sft_10.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_10.wav -------------------------------------------------------------------------------- /wavs/sft_100.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_100.wav -------------------------------------------------------------------------------- /wavs/sft_11.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_11.wav -------------------------------------------------------------------------------- /wavs/sft_12.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_12.wav -------------------------------------------------------------------------------- /wavs/sft_13.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_13.wav -------------------------------------------------------------------------------- /wavs/sft_14.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_14.wav -------------------------------------------------------------------------------- /wavs/sft_15.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_15.wav -------------------------------------------------------------------------------- /wavs/sft_16.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_16.wav -------------------------------------------------------------------------------- /wavs/sft_17.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_17.wav -------------------------------------------------------------------------------- /wavs/sft_18.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_18.wav -------------------------------------------------------------------------------- /wavs/sft_19.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_19.wav -------------------------------------------------------------------------------- /wavs/sft_2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_2.wav -------------------------------------------------------------------------------- /wavs/sft_20.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_20.wav -------------------------------------------------------------------------------- /wavs/sft_21.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_21.wav -------------------------------------------------------------------------------- /wavs/sft_22.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_22.wav -------------------------------------------------------------------------------- /wavs/sft_23.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_23.wav -------------------------------------------------------------------------------- /wavs/sft_24.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_24.wav -------------------------------------------------------------------------------- /wavs/sft_25.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_25.wav -------------------------------------------------------------------------------- /wavs/sft_26.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_26.wav -------------------------------------------------------------------------------- /wavs/sft_27.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_27.wav -------------------------------------------------------------------------------- /wavs/sft_28.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_28.wav -------------------------------------------------------------------------------- /wavs/sft_29.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_29.wav -------------------------------------------------------------------------------- /wavs/sft_3.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_3.wav -------------------------------------------------------------------------------- /wavs/sft_30.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_30.wav -------------------------------------------------------------------------------- /wavs/sft_31.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_31.wav -------------------------------------------------------------------------------- /wavs/sft_32.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_32.wav -------------------------------------------------------------------------------- /wavs/sft_33.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_33.wav -------------------------------------------------------------------------------- /wavs/sft_34.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_34.wav -------------------------------------------------------------------------------- /wavs/sft_35.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_35.wav -------------------------------------------------------------------------------- /wavs/sft_36.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_36.wav -------------------------------------------------------------------------------- /wavs/sft_37.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_37.wav -------------------------------------------------------------------------------- /wavs/sft_38.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_38.wav -------------------------------------------------------------------------------- /wavs/sft_39.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_39.wav -------------------------------------------------------------------------------- /wavs/sft_4.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_4.wav -------------------------------------------------------------------------------- /wavs/sft_40.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_40.wav -------------------------------------------------------------------------------- /wavs/sft_41.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_41.wav -------------------------------------------------------------------------------- /wavs/sft_42.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_42.wav -------------------------------------------------------------------------------- /wavs/sft_43.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_43.wav -------------------------------------------------------------------------------- /wavs/sft_44.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_44.wav -------------------------------------------------------------------------------- /wavs/sft_45.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_45.wav -------------------------------------------------------------------------------- /wavs/sft_46.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_46.wav -------------------------------------------------------------------------------- /wavs/sft_47.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_47.wav -------------------------------------------------------------------------------- /wavs/sft_48.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_48.wav -------------------------------------------------------------------------------- /wavs/sft_49.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_49.wav -------------------------------------------------------------------------------- /wavs/sft_5.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_5.wav -------------------------------------------------------------------------------- /wavs/sft_50.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_50.wav -------------------------------------------------------------------------------- /wavs/sft_51.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_51.wav -------------------------------------------------------------------------------- /wavs/sft_52.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_52.wav -------------------------------------------------------------------------------- /wavs/sft_53.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_53.wav -------------------------------------------------------------------------------- /wavs/sft_54.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_54.wav -------------------------------------------------------------------------------- /wavs/sft_55.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_55.wav -------------------------------------------------------------------------------- /wavs/sft_56.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_56.wav -------------------------------------------------------------------------------- /wavs/sft_57.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_57.wav -------------------------------------------------------------------------------- /wavs/sft_58.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_58.wav -------------------------------------------------------------------------------- /wavs/sft_59.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_59.wav -------------------------------------------------------------------------------- /wavs/sft_6.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_6.wav -------------------------------------------------------------------------------- /wavs/sft_60.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_60.wav -------------------------------------------------------------------------------- /wavs/sft_61.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_61.wav -------------------------------------------------------------------------------- /wavs/sft_62.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_62.wav -------------------------------------------------------------------------------- /wavs/sft_63.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_63.wav -------------------------------------------------------------------------------- /wavs/sft_64.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_64.wav -------------------------------------------------------------------------------- /wavs/sft_65.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_65.wav -------------------------------------------------------------------------------- /wavs/sft_66.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_66.wav -------------------------------------------------------------------------------- /wavs/sft_67.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_67.wav -------------------------------------------------------------------------------- /wavs/sft_68.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_68.wav -------------------------------------------------------------------------------- /wavs/sft_69.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_69.wav -------------------------------------------------------------------------------- /wavs/sft_7.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_7.wav -------------------------------------------------------------------------------- /wavs/sft_70.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_70.wav -------------------------------------------------------------------------------- /wavs/sft_71.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_71.wav -------------------------------------------------------------------------------- /wavs/sft_72.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_72.wav -------------------------------------------------------------------------------- /wavs/sft_73.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_73.wav -------------------------------------------------------------------------------- /wavs/sft_74.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_74.wav -------------------------------------------------------------------------------- /wavs/sft_75.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_75.wav -------------------------------------------------------------------------------- /wavs/sft_76.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_76.wav -------------------------------------------------------------------------------- /wavs/sft_77.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_77.wav -------------------------------------------------------------------------------- /wavs/sft_78.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_78.wav -------------------------------------------------------------------------------- /wavs/sft_79.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_79.wav -------------------------------------------------------------------------------- /wavs/sft_8.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_8.wav -------------------------------------------------------------------------------- /wavs/sft_80.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_80.wav -------------------------------------------------------------------------------- /wavs/sft_81.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_81.wav -------------------------------------------------------------------------------- /wavs/sft_82.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_82.wav -------------------------------------------------------------------------------- /wavs/sft_83.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_83.wav -------------------------------------------------------------------------------- /wavs/sft_84.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_84.wav -------------------------------------------------------------------------------- /wavs/sft_85.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_85.wav -------------------------------------------------------------------------------- /wavs/sft_86.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_86.wav -------------------------------------------------------------------------------- /wavs/sft_87.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_87.wav -------------------------------------------------------------------------------- /wavs/sft_88.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_88.wav -------------------------------------------------------------------------------- /wavs/sft_89.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_89.wav -------------------------------------------------------------------------------- /wavs/sft_9.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_9.wav -------------------------------------------------------------------------------- /wavs/sft_90.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_90.wav -------------------------------------------------------------------------------- /wavs/sft_91.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_91.wav -------------------------------------------------------------------------------- /wavs/sft_92.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_92.wav -------------------------------------------------------------------------------- /wavs/sft_93.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_93.wav -------------------------------------------------------------------------------- /wavs/sft_94.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_94.wav -------------------------------------------------------------------------------- /wavs/sft_95.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_95.wav -------------------------------------------------------------------------------- /wavs/sft_96.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_96.wav -------------------------------------------------------------------------------- /wavs/sft_97.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_97.wav -------------------------------------------------------------------------------- /wavs/sft_98.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_98.wav -------------------------------------------------------------------------------- /wavs/sft_99.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wntg/LLaMA-Omni/cf4c32fb4ee5cb620c58520db562e6e13384ecb3/wavs/sft_99.wav --------------------------------------------------------------------------------