├── src └── slam_llm │ ├── __init__.py │ ├── utils │ ├── compute_utils.py │ ├── __init__.py │ ├── num2word.py │ ├── metric.py │ ├── whisper_tn.py │ ├── preprocess_text.py │ ├── llm_tn.py │ ├── model_utils.py │ ├── compute_aac_metrics.py │ ├── fsdp_utils.py │ ├── dataset_utils.py │ ├── memory_utils.py │ ├── config_utils.py │ ├── custom_utils.py │ └── checkpoint_handler.py │ ├── data │ ├── __init__.py │ ├── concatenator.py │ └── sampler.py │ ├── datasets │ ├── __init__.py │ └── audio_dataset.py │ ├── inference │ ├── __init__.py │ ├── model_utils.py │ ├── chat_utils.py │ ├── checkpoint_converter_fsdp_hf.py │ └── safety_utils.py │ ├── policies │ ├── __init__.py │ ├── activation_checkpointing_functions.py │ ├── mixed_precision.py │ ├── wrapping.py │ └── anyprecision_optimizer.py │ ├── pipeline │ ├── inference.py │ ├── inference_batch.py │ └── finetune.py │ └── models │ ├── projector.py │ └── encoder.py ├── evaluation ├── qwen2.5-3B.xlsx └── test_metric.py ├── examples └── st_covost2 │ ├── image │ ├── prompt.png │ └── framework.jpg │ ├── conf │ ├── prompt.yaml │ └── ds_config.json │ ├── finetune_asr.py │ ├── scripts │ ├── infer_hf.sh │ ├── all.sh │ └── infer_all.sh │ ├── asr_config.py │ ├── inference_asr_batch.py │ ├── dataset │ ├── fleurs_dataset.py │ └── srt_dataset.py │ └── model │ └── slam_model_st.py ├── setup.py ├── models └── README.md ├── requirements.txt ├── .gitignore ├── .github ├── ISSUE_TEMPLATE │ ├── feature-request.yml │ └── bug.yml ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── spellcheck.yml ├── LICENSE ├── pyproject.toml └── README.md /src/slam_llm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /evaluation/qwen2.5-3B.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxduir/LLM-SRT/HEAD/evaluation/qwen2.5-3B.xlsx -------------------------------------------------------------------------------- /examples/st_covost2/image/prompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxduir/LLM-SRT/HEAD/examples/st_covost2/image/prompt.png -------------------------------------------------------------------------------- /examples/st_covost2/image/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxduir/LLM-SRT/HEAD/examples/st_covost2/image/framework.jpg -------------------------------------------------------------------------------- /src/slam_llm/utils/compute_utils.py: -------------------------------------------------------------------------------- 1 | 2 | def calculate_output_length_1d(L_in, kernel_size, stride, padding=0): 3 | return (L_in + 2 * padding - kernel_size) // stride + 1 -------------------------------------------------------------------------------- /src/slam_llm/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -------------------------------------------------------------------------------- /src/slam_llm/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -------------------------------------------------------------------------------- /src/slam_llm/inference/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="slam-llm", 5 | version="0.0.1", 6 | packages=find_packages("src"), # keep in sync with the Hatch wheel config in pyproject.toml 7 | package_dir={"": "src"}, 8 | install_requires=open("requirements.txt").read().splitlines(), 9 | ) -------------------------------------------------------------------------------- /examples/st_covost2/conf/prompt.yaml: -------------------------------------------------------------------------------- 1 | dataset_config: 2 | # we put the prompt here because the Hydra override in the shell script only supports a small subset of characters 3 | # prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. " 4 | # prompt: "" 5 | -------------------------------------------------------------------------------- /src/slam_llm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from slam_llm.utils.memory_utils import MemoryTrace 5 | from slam_llm.utils.dataset_utils import * 6 | from slam_llm.utils.fsdp_utils import fsdp_auto_wrap_policy 7 | from slam_llm.utils.train_utils import * -------------------------------------------------------------------------------- /src/slam_llm/policies/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | from slam_llm.policies.mixed_precision import * 5 | from slam_llm.policies.wrapping import * 6 | from slam_llm.policies.activation_checkpointing_functions import apply_fsdp_checkpointing 7 | from slam_llm.policies.anyprecision_optimizer import AnyPrecisionAdamW 8 | -------------------------------------------------------------------------------- /examples/st_covost2/conf/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 4, 3 | "gradient_accumulation_steps": 1, 4 | "optimizer": { 5 | "type": "Adam", 6 | "params": { 7 | "lr": 1e-4 8 | } 9 | }, 10 | "fp16": { 11 | "enabled": true 12 | }, 13 | "zero_optimization": { 14 | "stage": 3, 15 | "offload_optimizer": { 16 | "device": "cpu" 17 | } 18 | } 19 | } -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Download Model 3 | Encoder | Adapter | LLM 4 | |---|---|--- 5 | [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) | [q-former+mlp](https://huggingface.co/yxdu/llm-srt) | [Qwen2.5-3B](https://huggingface.co/Qwen/Qwen2.5-3B) 6 | ``` 7 | cd models/ 8 | 9 | git lfs clone https://huggingface.co/yxdu/llm-srt 10 | git lfs clone https://huggingface.co/openai/whisper-large-v3 11 | # for 3B model (support 15 languages) 12 | git lfs clone https://huggingface.co/Qwen/Qwen2.5-3B 13 | cd .. 14 | ``` 15 | 16 | 17 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | appdirs 3 | loralib 4 | bitsandbytes 5 | black 6 | black[jupyter] 7 | datasets 8 | fire 9 | peft 10 | sentencepiece 11 | py7zr 12 | scipy 13 | optimum 14 | wandb 15 | hydra-core>=1.3.2 16 | openai-whisper 17 | wandb 18 | soundfile 19 | evaluate 20 | transformers 21 | datasets==3.6.0 22 | sacrebleu 23 | jiwer 24 | librosa 25 | unbabel-comet 26 | nltk 27 | openpyxl 28 | torch==2.4.0 29 | torchaudio==2.4.0 30 | torchvision==0.19.0 31 | debugpy 32 | importlib_metadata==4.13.0 33 | setuptools>=60.0.0 34 | evaluate 35 | torchcodec -------------------------------------------------------------------------------- /src/slam_llm/utils/num2word.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | from num2words import num2words 4 | 5 | file = sys.argv[1] 6 | out_file = sys.argv[2] 7 | 8 | with open(file) as f: 9 | lines = f.readlines() 10 | 11 | with open(out_file, "w") as fw: 12 | for line in lines: 13 | key, content = line.strip().split(maxsplit=1) 14 | new_content = "" 15 | for ct in content.split(): 16 | if ct.isdigit(): 17 | ct = num2words(ct) 18 | new_content += ct + " " 19 | fw.write(key + " " + new_content + "\n") -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | .ipynb_checkpoints 4 | .vscode 5 | debug.py 6 | .idea/* 7 | transformers 8 | wandb/ 9 | log/ 10 | *.log 11 | outputs/ 12 | data/ 13 | examples/vsr_LRS3/scripts/decode_avhubert_vo_vicuna_7b_noself.sh 14 | examples/asr_librispeech/scripts/decode_hubert_xtralarge_linear_vicuna_7b_copy.sh 15 | examples/vsr_LRS3/scripts/decode_avhubert_vo_vicuna_7b_copy.sh 16 | scripts_all 17 | examples/hotwords_librispeech 18 | 
examples/asr_librispeech/scripts/decode_hubert_xtralarge_linear_vicuna_7b_debug.sh 19 | *.pt 20 | *.jsonl 21 | models/llm-srt 22 | models/Qwen2.5-3B 23 | models/whisper-large-v3 -------------------------------------------------------------------------------- /src/slam_llm/utils/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def compute_accuracy(pad_outputs, pad_targets, ignore_label): 4 | """Calculate accuracy. 5 | 6 | Args: 7 | pad_outputs (LongTensor): Prediction tensors (B, Lmax). 8 | pad_targets (LongTensor): Target label tensors (B, Lmax). 9 | ignore_label (int): Ignore label id. 10 | 11 | Returns: 12 | float: Accuracy value (0.0 - 1.0). 13 | 14 | """ 15 | mask = pad_targets != ignore_label 16 | numerator = torch.sum( 17 | pad_outputs.masked_select(mask) == pad_targets.masked_select(mask) 18 | ) 19 | denominator = torch.sum(mask) 20 | return numerator.float() / denominator.float() #(FIX:MZY):return torch.Tensor type -------------------------------------------------------------------------------- /src/slam_llm/utils/whisper_tn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import re 4 | import string 5 | from whisper_normalizer.english import EnglishTextNormalizer 6 | 7 | english_normalizer = EnglishTextNormalizer() 8 | 9 | def normalize_text(srcfn, dstfn): 10 | with open(srcfn, "r") as f_read, open(dstfn, "w") as f_write: 11 | all_lines = f_read.readlines() 12 | for line in all_lines: 13 | line = line.strip() 14 | line_arr = line.split() 15 | key = line_arr[0] 16 | conts = " ".join(line_arr[1:]) 17 | normalized_conts = english_normalizer(conts) 18 | f_write.write("{0}\t{1}\n".format(key, normalized_conts)) 19 | 20 | if __name__ == "__main__": 21 | srcfn = sys.argv[1] 22 | dstfn = sys.argv[2] 23 | normalize_text(srcfn, dstfn) -------------------------------------------------------------------------------- /src/slam_llm/policies/activation_checkpointing_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | from functools import partial 5 | 6 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 7 | checkpoint_wrapper, 8 | CheckpointImpl, 9 | apply_activation_checkpointing, 10 | ) 11 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 12 | 13 | non_reentrant_wrapper = partial( 14 | checkpoint_wrapper, 15 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 16 | ) 17 | 18 | check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer) 19 | 20 | 21 | def apply_fsdp_checkpointing(model): 22 | """apply activation checkpointing to model 23 | returns None as model is updated directly 24 | """ 25 | print(f"--> applying fsdp activation checkpointing...") 26 | 27 | apply_activation_checkpointing( 28 | model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn 29 | ) 30 | -------------------------------------------------------------------------------- /src/slam_llm/utils/preprocess_text.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import re 4 | import string 5 | 6 | in_f = sys.argv[1] 7 | out_f = sys.argv[2] 8 | 9 | 10 | with open(in_f, "r", encoding="utf-8") as f: 11 | lines = f.readlines() 12 | 13 | with open(out_f, "w", encoding="utf-8") as f: 14 | for line in lines: 15 | outs = line.strip().split("\t", 1) 16 | if len(outs) == 2: 17 | idx, text = outs 18 | text = re.sub("<|", "", text) 19 | text = re.sub("|>", "", text) 20 | text = re.sub("—", "", text) 21 | # text = re.sub("", "", text) 22 | # text = re.sub("@@", "", text) 23 | # text = re.sub("@", "", text) 24 | # text = re.sub("", "", text) 25 | # text = re.sub(" ", "", text) 26 | # text = text.lower() 27 | translator = str.maketrans('', '', string.punctuation.replace("'", "")) 28 | result = text.translate(translator) 29 | text = result.upper() 30 | else: 31 | idx = outs[0] 32 | text = " " 33 | 34 | # text = [x for x in text] 35 | # text = " ".join(text) 36 | out = "{} {}\n".format(idx, text) 37 | f.write(out) 38 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: 🚀 Feature request 2 | description: Submit a proposal/request for a new slam-llm feature 3 | 4 | body: 5 | - type: textarea 6 | id: feature-pitch 7 | attributes: 8 | label: 🚀 The feature, motivation and pitch 9 | description: > 10 | A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. 11 | validations: 12 | required: true 13 | 14 | - type: textarea 15 | id: alternatives 16 | attributes: 17 | label: Alternatives 18 | description: > 19 | A description of any alternative solutions or features you've considered, if any. 20 | 21 | - type: textarea 22 | id: additional-context 23 | attributes: 24 | label: Additional context 25 | description: > 26 | Add any other context or screenshots about the feature request. 27 | 28 | - type: markdown 29 | attributes: 30 | value: > 31 | Thanks for contributing 🎉! -------------------------------------------------------------------------------- /src/slam_llm/inference/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 3 | 4 | from peft import PeftModel 5 | from transformers import LlamaForCausalLM, LlamaConfig 6 | 7 | # Function to load the main model for text generation 8 | def load_model(model_name, quantization): 9 | model = LlamaForCausalLM.from_pretrained( 10 | model_name, 11 | return_dict=True, 12 | load_in_8bit=quantization, 13 | device_map="auto", 14 | low_cpu_mem_usage=True, 15 | ) 16 | return model 17 | 18 | 19 | # Function to load the PeftModel for performance optimization 20 | def load_peft_model(model, peft_model): 21 | peft_model = PeftModel.from_pretrained(model, peft_model) 22 | return peft_model 23 | 24 | # Loading the model from config to load FSDP checkpoints into that 25 | def load_llama_from_config(config_path): 26 | model_config = LlamaConfig.from_pretrained(config_path) 27 | model = LlamaForCausalLM(config=model_config) 28 | return model 29 | 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ziyang Ma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/slam_llm/policies/mixed_precision.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import torch 5 | 6 | from torch.distributed.fsdp import ( 7 | MixedPrecision, 8 | ) 9 | 10 | # requires grad scaler in main loop 11 | fpSixteen = MixedPrecision( 12 | param_dtype=torch.float16, 13 | # Gradient communication precision. 14 | reduce_dtype=torch.float16, 15 | # Buffer precision. 16 | buffer_dtype=torch.float16, 17 | ) 18 | 19 | bfSixteen = MixedPrecision( 20 | param_dtype=torch.bfloat16, 21 | # Gradient communication precision. 22 | reduce_dtype=torch.bfloat16, 23 | # Buffer precision. 
24 | buffer_dtype=torch.bfloat16, 25 | cast_forward_inputs=True, 26 | ) 27 | 28 | bfSixteen_mixed = MixedPrecision( 29 | param_dtype=torch.float32, 30 | reduce_dtype=torch.bfloat16, 31 | buffer_dtype=torch.bfloat16, 32 | ) 33 | 34 | fp32_policy = MixedPrecision( 35 | param_dtype=torch.float32, 36 | reduce_dtype=torch.float32, 37 | buffer_dtype=torch.float32, 38 | ) 39 | -------------------------------------------------------------------------------- /src/slam_llm/policies/wrapping.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import functools 5 | 6 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 7 | from torch.distributed.fsdp.wrap import ( 8 | transformer_auto_wrap_policy, 9 | size_based_auto_wrap_policy, 10 | ) 11 | 12 | 13 | def get_size_policy(min_params=1e8): 14 | num_wrap_policy = functools.partial( 15 | size_based_auto_wrap_policy, min_num_params=min_params 16 | ) 17 | return num_wrap_policy 18 | 19 | 20 | def get_llama_wrapper(): 21 | """we register our main layer class and use the fsdp transformer wrapping policy 22 | ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers 23 | """ 24 | # ==== use new transformer wrapper 25 | 26 | llama_auto_wrap_policy = functools.partial( 27 | transformer_auto_wrap_policy, 28 | transformer_layer_cls={ 29 | LlamaDecoderLayer, 30 | }, 31 | ) 32 | 33 | return llama_auto_wrap_policy 34 | -------------------------------------------------------------------------------- /src/slam_llm/utils/llm_tn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import re 4 | import string 5 | from whisper_normalizer.english import EnglishTextNormalizer 6 | 7 | english_normalizer = EnglishTextNormalizer() 8 | 9 | def reduce_repeated_words(text): 10 | pattern ="." 11 | for i in range(1, 50): 12 | p = pattern * i 13 | text = re.sub(f'({p})' + r'\1{4,200}', r'\1', text) 14 | for i in range (50, 100): 15 | p = pattern * i 16 | text = re.sub(f'({p})' + r'\1{3,200}', r'\1', text) 17 | return text 18 | 19 | def normalize_text(srcfn, dstfn): 20 | with open(srcfn, "r") as f_read, open(dstfn, "w") as f_write: 21 | all_lines = f_read.readlines() 22 | for line in all_lines: 23 | line = line.strip() 24 | line_arr = line.split() 25 | key = line_arr[0] 26 | conts = " ".join(line_arr[1:]) 27 | normalized_conts = english_normalizer(conts) 28 | reduced_conts = reduce_repeated_words(normalized_conts) 29 | f_write.write("{0}\t{1}\n".format(key, reduced_conts)) 30 | 31 | if __name__ == "__main__": 32 | srcfn = sys.argv[1] 33 | dstfn = sys.argv[2] 34 | normalize_text(srcfn, dstfn) -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling", "hatch-requirements-txt"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "slam-llm" 7 | version = "0.0.1" 8 | description = "SLAM-LLM is a deep learning toolkit that allows researchers and developers to train custom multimodal large language model (MLLM), focusing on Speech, Language, Audio, Music processing. We provide detailed recipes for training and high-performance checkpoints for inference." 
9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: Other/Proprietary License", 14 | "Operating System :: OS Independent", 15 | ] 16 | dynamic = ["dependencies"] 17 | 18 | [project.optional-dependencies] 19 | vllm = ["vllm"] 20 | tests = ["pytest-mock"] 21 | auditnlg = ["auditnlg"] 22 | 23 | [project.urls] 24 | "Homepage" = "https://github.com/ddlBoJack/SLAM-LLM" 25 | "Bug Tracker" = "https://github.com/ddlBoJack/SLAM-LLM/issues" 26 | 27 | [tool.hatch.build] 28 | exclude = [ 29 | "dist/*", 30 | ] 31 | 32 | [tool.hatch.build.targets.wheel] 33 | packages = ["src/slam_llm"] 34 | 35 | [tool.hatch.metadata.hooks.requirements_txt] 36 | files = ["requirements.txt"] -------------------------------------------------------------------------------- /src/slam_llm/data/concatenator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | from tqdm import tqdm 5 | from itertools import chain 6 | 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class ConcatDataset(Dataset): 11 | def __init__(self, dataset, chunk_size=4096): 12 | self.dataset = dataset 13 | self.chunk_size = chunk_size 14 | 15 | self.samples = [] 16 | 17 | buffer = { 18 | "input_ids": [], 19 | "attention_mask": [], 20 | "labels": [], 21 | } 22 | 23 | for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True): 24 | buffer = {k: v + sample[k] for k,v in buffer.items()} 25 | 26 | while len(next(iter(buffer.values()))) > self.chunk_size: 27 | self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()}) 28 | buffer = {k: v[self.chunk_size:] for k,v in buffer.items()} 29 | 30 | def __getitem__(self, idx): 31 | return self.samples[idx] 32 | 33 | def __len__(self): 34 | return len(self.samples) 35 | -------------------------------------------------------------------------------- /src/slam_llm/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | from slam_llm.utils.dataset_utils import load_module_from_py_file 2 | from pathlib import Path 3 | 4 | def get_custom_model_factory(model_config, logger): 5 | costom_model_path = model_config.get( 6 | "file", None 7 | ) 8 | if costom_model_path is None: 9 | from slam_llm.models.slam_model import model_factory 10 | return model_factory 11 | 12 | if ":" in costom_model_path: 13 | module_path, func_name = costom_model_path.split(":") 14 | else: 15 | module_path, func_name = costom_model_path, "model_factory" 16 | 17 | if not module_path.endswith(".py"): 18 | raise ValueError(f"Dataset file {module_path} is not a .py file.") 19 | 20 | module_path = Path(module_path) 21 | if not module_path.is_file(): 22 | raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.") 23 | 24 | module = load_module_from_py_file(module_path.as_posix()) 25 | try: 26 | return getattr(module, func_name) 27 | except AttributeError as e: 28 | logger.info(f"It seems like the given method name ({func_name}) is not present in the model .py file ({module_path.as_posix()}).") 29 | raise e 30 | 31 | -------------------------------------------------------------------------------- /src/slam_llm/utils/compute_aac_metrics.py: -------------------------------------------------------------------------------- 1 | from 
aac_metrics import evaluate 2 | import sys 3 | 4 | def compute_wer(ref_file, 5 | hyp_file): 6 | pred_captions = [] 7 | gt_captions = [] 8 | 9 | with open(hyp_file, 'r') as hyp_reader: 10 | for line in hyp_reader: 11 | key = line.strip().split()[0] 12 | value = line.strip().split()[1:] 13 | pred_captions.append(" ".join(value)) # aac_metrics expects plain strings, not token lists 14 | with open(ref_file, 'r') as ref_reader: 15 | for line in ref_reader: 16 | key = line.strip().split()[0] 17 | value = line.strip().split()[1:] 18 | gt_captions.append(" ".join(value)) 19 | 20 | print('Used lines:', len(pred_captions)) 21 | candidates: list[str] = pred_captions 22 | mult_references: list[list[str]] = [[gt] for gt in gt_captions] 23 | 24 | corpus_scores, _ = evaluate(candidates, mult_references) 25 | print(corpus_scores) 26 | # dict containing the score of each metric: "bleu_1", "bleu_2", "bleu_3", "bleu_4", "rouge_l", "meteor", "cider_d", "spice", "spider" 27 | # {"bleu_1": tensor(0.4278), "bleu_2": ..., ...} 28 | 29 | 30 | if __name__ == '__main__': 31 | if len(sys.argv) != 3: 32 | print("usage : python compute_aac_metrics.py test.ref test.hyp") 33 | sys.exit(0) 34 | 35 | ref_file = sys.argv[1] 36 | hyp_file = sys.argv[2] 37 | compute_wer(ref_file, hyp_file) -------------------------------------------------------------------------------- /src/slam_llm/utils/fsdp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | def fsdp_auto_wrap_policy(model, transformer_layer_name): 5 | import functools 6 | 7 | from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy 8 | 9 | from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder 10 | 11 | def lambda_policy_fn(module): 12 | if ( 13 | len(list(module.named_children())) == 0 14 | and getattr(module, "weight", None) is not None 15 | and module.weight.requires_grad 16 | ): 17 | return True 18 | return False 19 | 20 | lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) 21 | transformer_wrap_policy = functools.partial( 22 | transformer_auto_wrap_policy, 23 | transformer_layer_cls=( 24 | PrefixEncoder, 25 | PromptEncoder, 26 | PromptEmbedding, 27 | transformer_layer_name, 28 | # FullyShardedDataParallelPlugin.get_module_class_from_name( 29 | # model, transformer_layer_name 30 | # ), 31 | ), 32 | ) 33 | 34 | auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) 35 | return auto_wrap_policy -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # What does this PR do? 2 | 3 | 11 | 12 | 13 | 14 | Fixes # (issue) 15 | 16 | 17 | ## Feature/Issue validation/testing 18 | 19 | Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced. 20 | Please also list any relevant details for your test configuration. 21 | 22 | - [ ] Test A 23 | Logs for Test A 24 | 25 | - [ ] Test B 26 | Logs for Test B 27 | 28 | 29 | ## Before submitting 30 | - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). 
31 | - [ ] Did you read the [contributor guideline](https://github.com/facebookresearch/llama-recipes/blob/main/CONTRIBUTING.md#pull-requests), 32 | Pull Request section? 33 | - [ ] Was this discussed/approved via a Github issue? Please add a link 34 | to it if that's the case. 35 | - [ ] Did you make sure to update the documentation with your changes? 36 | - [ ] Did you write any new necessary tests? 37 | 38 | Thanks for contributing 🎉! 39 | -------------------------------------------------------------------------------- /examples/st_covost2/finetune_asr.py: -------------------------------------------------------------------------------- 1 | from slam_llm.pipeline.finetune import main as train 2 | 3 | import hydra 4 | import logging 5 | from typing import Optional 6 | from dataclasses import dataclass, field 7 | from omegaconf import DictConfig, ListConfig, OmegaConf 8 | from asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig 9 | 10 | @dataclass 11 | class RunConfig: 12 | dataset_config: DataConfig = field(default_factory=DataConfig) 13 | model_config: ModelConfig = field(default_factory=ModelConfig) 14 | train_config: TrainConfig = field(default_factory=TrainConfig) 15 | log_config: LogConfig = field(default_factory=LogConfig) 16 | fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) 17 | debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) 18 | metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) 19 | ckpt_path: Optional[str] = field( 20 | default=None, metadata={"help": "The path to projector checkpoint"} 21 | ) 22 | 23 | @hydra.main(config_name=None, version_base=None) 24 | def main_hydra(cfg: DictConfig): 25 | run_config = RunConfig() 26 | cfg = OmegaConf.merge(run_config, cfg) 27 | def to_plain_list(cfg_item): 28 | if isinstance(cfg_item, ListConfig): 29 | return OmegaConf.to_container(cfg_item, resolve=True) 30 | elif isinstance(cfg_item, DictConfig): 31 | return {k: to_plain_list(v) for k, v in cfg_item.items()} 32 | else: 33 | return cfg_item 34 | 35 | # kwargs = to_plain_list(cfg) 36 | kwargs = cfg 37 | log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) 38 | 39 | logging.basicConfig(level=log_level) 40 | 41 | if kwargs.get("debug", False): 42 | import pdb; 43 | pdb.set_trace() 44 | 45 | train(kwargs) 46 | 47 | 48 | if __name__ == "__main__": 49 | main_hydra() 50 | -------------------------------------------------------------------------------- /.github/workflows/spellcheck.yml: -------------------------------------------------------------------------------- 1 | name: SpellCheck 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | jobs: 11 | build: 12 | runs-on: ubuntu-20.04 13 | name: Lint changed files 14 | steps: 15 | - uses: actions/checkout@v3 16 | with: 17 | fetch-depth: 0 # OR "2" -> To retrieve the preceding commit. 
18 | 19 | - name: Check links in all markdown files 20 | uses: gaurav-nelson/github-action-markdown-link-check@1.0.13 21 | with: 22 | use-verbose-mode: 'yes' 23 | config-file: "scripts/markdown_link_check_config.json" 24 | 25 | - name: Get changed files 26 | id: changed-files 27 | uses: tj-actions/changed-files@v29.0.4 28 | with: 29 | 30 | files: | 31 | **/*.py 32 | 33 | spellcheck: 34 | runs-on: ubuntu-20.04 35 | steps: 36 | - uses: actions/checkout@v3 37 | 38 | - name: Install dependencies 39 | run: | 40 | sudo apt-get install aspell aspell-en 41 | pip install pyspelling 42 | 43 | - name: Get changed files 44 | id: changed-files 45 | uses: tj-actions/changed-files@v29.0.4 46 | with: 47 | files: | 48 | **/*.md 49 | 50 | - name: Check spellings 51 | run: | 52 | sources="" 53 | for file in ${{ steps.changed-files.outputs.all_changed_files }}; do 54 | sources="${sources} -S $file" 55 | done 56 | if [ ! "$sources" ]; then 57 | echo "No files to spellcheck" 58 | else 59 | pyspelling -c $GITHUB_WORKSPACE/scripts/spellcheck_conf/spellcheck.yaml --name Markdown $sources 60 | fi 61 | 62 | - name: In the case of misspellings 63 | if: ${{ failure() }} 64 | run: | 65 | echo "Please fix the misspellings. If you are sure about some of them, " 66 | echo "so append those to scripts/spellcheck_conf/wordlist.txt" 67 | -------------------------------------------------------------------------------- /src/slam_llm/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import importlib 5 | from functools import partial 6 | from pathlib import Path 7 | 8 | import torch 9 | 10 | import logging 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def load_module_from_py_file(py_file: str) -> object: 15 | """ 16 | This method loads a module from a py file which is not in the Python path 17 | """ 18 | module_name = Path(py_file).name 19 | loader = importlib.machinery.SourceFileLoader(module_name, py_file) 20 | spec = importlib.util.spec_from_loader(module_name, loader) 21 | module = importlib.util.module_from_spec(spec) 22 | 23 | loader.exec_module(module) 24 | 25 | return module 26 | 27 | 28 | def get_custom_dataset(dataset_config, tokenizer, split: str): 29 | if ":" in dataset_config.file: 30 | module_path, func_name = dataset_config.file.split(":") 31 | else: 32 | module_path, func_name = dataset_config.file, "get_custom_dataset" 33 | 34 | if not module_path.endswith(".py"): 35 | raise ValueError(f"Dataset file {module_path} is not a .py file.") 36 | 37 | module_path = Path(module_path) 38 | if not module_path.is_file(): 39 | raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.") 40 | 41 | module = load_module_from_py_file(module_path.as_posix()) 42 | try: 43 | return getattr(module, func_name)(dataset_config, tokenizer, split) 44 | except AttributeError as e: 45 | logger.info(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).") 46 | raise e 47 | 48 | 49 | def get_preprocessed_dataset( 50 | tokenizer, dataset_config, split: str = "train" 51 | ) -> torch.utils.data.Dataset: 52 | 53 | return get_custom_dataset( 54 | dataset_config, 55 | tokenizer, 56 | split, 57 | ) 58 | -------------------------------------------------------------------------------- 
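For orientation, `get_custom_dataset` above resolves `dataset_config.file` as a path to a `.py` file, optionally suffixed with `:function_name` (defaulting to `get_custom_dataset`), and calls the resolved function with `(dataset_config, tokenizer, split)`. The following is a minimal sketch of a custom dataset module that would satisfy this contract; the file name, the JSONL layout, and the `train_data_path`/`val_data_path` fields it reads are illustrative assumptions, not interfaces defined by this file.

```
# my_dataset.py -- hypothetical module, referenced e.g. as ++dataset_config.file=path/to/my_dataset.py:get_custom_dataset
import json

from torch.utils.data import Dataset


class JsonlSpeechDataset(Dataset):
    """Toy dataset: one JSON object per line with assumed "audio" and "text" keys."""

    def __init__(self, dataset_config, tokenizer, split):
        # pick the split file from the (assumed) config fields
        path = dataset_config.train_data_path if split == "train" else dataset_config.val_data_path
        with open(path, "r", encoding="utf-8") as f:
            self.items = [json.loads(line) for line in f]
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        item = self.items[idx]
        return {"audio_path": item["audio"], "target": item["text"]}


def get_custom_dataset(dataset_config, tokenizer, split: str):
    # entry point looked up by slam_llm.utils.dataset_utils.get_custom_dataset
    return JsonlSpeechDataset(dataset_config, tokenizer, split)
```

The recipes under examples/st_covost2/dataset follow the same entry-point convention, with the path and function name passed via the `++dataset_config.file=...` override in the shell scripts.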
/src/slam_llm/data/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import random 5 | from itertools import islice 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class LengthBasedBatchSampler(torch.utils.data.BatchSampler): 12 | def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None: 13 | if isinstance(next(iter(data_source)), dict): 14 | first_key = next(iter(next(iter(data_source)).keys())) 15 | self.lengths = [len(d[first_key]) for d in data_source] 16 | else: 17 | self.lengths = [len(d) for d in data_source] 18 | self.batch_size = batch_size 19 | self.drop_last = drop_last 20 | self.shuffle = shuffle 21 | 22 | def __iter__(self): 23 | ids = np.argsort(self.lengths) 24 | if self.drop_last: 25 | ids = ids[:len(ids) // self.batch_size * self.batch_size] 26 | 27 | batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)] 28 | 29 | if self.shuffle: 30 | random.shuffle(batches) 31 | 32 | for b in batches: 33 | yield b 34 | 35 | def __len__(self): 36 | if self.drop_last: 37 | return len(self.lengths) // self.batch_size 38 | else: 39 | return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0) 40 | 41 | 42 | class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler): 43 | def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None: 44 | random.seed(seed) 45 | self.batch_sampler = LengthBasedBatchSampler( 46 | data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle 47 | ) 48 | self.num_replicas = num_replicas 49 | self.rank = rank 50 | 51 | def __iter__(self): 52 | max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas 53 | return islice(self.batch_sampler, self.rank, max_length, self.num_replicas) 54 | 55 | def __len__(self): 56 | return len(self.batch_sampler) // self.num_replicas 57 | -------------------------------------------------------------------------------- /src/slam_llm/inference/chat_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import json 5 | from typing import List, Literal, TypedDict 6 | 7 | 8 | Role = Literal["user", "assistant"] 9 | 10 | 11 | class Message(TypedDict): 12 | role: Role 13 | content: str 14 | 15 | 16 | Dialog = List[Message] 17 | 18 | B_INST, E_INST = "[INST]", "[/INST]" 19 | B_SYS, E_SYS = "<>\n", "\n<>\n\n" 20 | def format_tokens(dialogs, tokenizer): 21 | prompt_tokens = [] 22 | for dialog in dialogs: 23 | if dialog[0]["role"] == "system": 24 | dialog = [ 25 | { 26 | "role": dialog[1]["role"], 27 | "content": B_SYS 28 | + dialog[0]["content"] 29 | + E_SYS 30 | + dialog[1]["content"], 31 | } 32 | ] + dialog[2:] 33 | assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( 34 | [msg["role"] == "assistant" for msg in dialog[1::2]] 35 | ), ( 36 | "model only supports 'system','user' and 'assistant' roles, " 37 | "starting with user and alternating (u/a/u/a/u...)" 38 | ) 39 | """ 40 | Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs. 
41 | Here, we are adding it manually. 42 | """ 43 | dialog_tokens: List[int] = sum( 44 | [ 45 | tokenizer.encode( 46 | f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", 47 | ) + [tokenizer.eos_token_id] 48 | for prompt, answer in zip(dialog[::2], dialog[1::2]) 49 | ], 50 | [], 51 | ) 52 | assert ( 53 | dialog[-1]["role"] == "user" 54 | ), f"Last message must be from user, got {dialog[-1]['role']}" 55 | dialog_tokens += tokenizer.encode( 56 | f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", 57 | ) 58 | prompt_tokens.append(dialog_tokens) 59 | return prompt_tokens 60 | 61 | 62 | def read_dialogs_from_file(file_path): 63 | with open(file_path, 'r') as file: 64 | dialogs = json.load(file) 65 | return dialogs 66 | -------------------------------------------------------------------------------- /src/slam_llm/utils/memory_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import gc 5 | import psutil 6 | import threading 7 | 8 | import torch 9 | 10 | def byte2gb(x): 11 | return int(x / 2**30) 12 | # This context manager is used to track the peak memory usage of the process 13 | class MemoryTrace: 14 | def __enter__(self): 15 | gc.collect() 16 | torch.cuda.empty_cache() 17 | torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero 18 | self.begin = byte2gb(torch.cuda.memory_allocated()) 19 | self.process = psutil.Process() 20 | self.cpu_begin = byte2gb(self.cpu_mem_used()) 21 | self.peak_monitoring = True 22 | peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) 23 | peak_monitor_thread.daemon = True 24 | peak_monitor_thread.start() 25 | return self 26 | 27 | def cpu_mem_used(self): 28 | """get resident set size memory for the current process""" 29 | return self.process.memory_info().rss 30 | 31 | def peak_monitor_func(self): 32 | self.cpu_peak = -1 33 | 34 | while True: 35 | self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) 36 | 37 | # can't sleep or will not catch the peak right (this comment is here on purpose) 38 | # time.sleep(0.001) # 1msec 39 | 40 | if not self.peak_monitoring: 41 | break 42 | 43 | def __exit__(self, *exc): 44 | self.peak_monitoring = False 45 | 46 | gc.collect() 47 | torch.cuda.empty_cache() 48 | self.end = byte2gb(torch.cuda.memory_allocated()) 49 | self.peak = byte2gb(torch.cuda.max_memory_allocated()) 50 | cuda_info = torch.cuda.memory_stats() 51 | self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"]) 52 | self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0) 53 | self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"]) 54 | self.m_cuda_ooms = cuda_info.get("num_ooms", 0) 55 | self.used = byte2gb(self.end - self.begin) 56 | self.peaked = byte2gb(self.peak - self.begin) 57 | self.max_reserved = byte2gb(torch.cuda.max_memory_reserved()) 58 | 59 | self.cpu_end = self.cpu_mem_used() 60 | self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin) 61 | self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin) 62 | # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug.yml: -------------------------------------------------------------------------------- 1 | name: 🐛 Bug Report 2 | description: Create a report to help us reproduce and fix the bug 3 
| 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the 9 | existing and past issues](https://github.com/ddlBoJack/SLAM-LLM/issues). 10 | 11 | - type: textarea 12 | id: system-info 13 | attributes: 14 | label: System Info 15 | description: | 16 | Please share your system info with us. You can use the following command to capture your environment information 17 | python -m "torch.utils.collect_env" 18 | 19 | placeholder: | 20 | PyTorch version, CUDA version, GPU type, #num of GPUs... 21 | validations: 22 | required: true 23 | 24 | - type: checkboxes 25 | id: information-scripts-examples 26 | attributes: 27 | label: Information 28 | description: 'The problem arises when using:' 29 | options: 30 | - label: "The official example scripts" 31 | - label: "My own modified scripts" 32 | 33 | - type: textarea 34 | id: bug-description 35 | attributes: 36 | label: 🐛 Describe the bug 37 | description: | 38 | Please provide a clear and concise description of what the bug is. 39 | 40 | Provide the exact command(s) that you ran with the settings eg using FSDP and PEFT or pure FSDP. 41 | 42 | Please also paste or describe the results you observe instead of the expected results. 43 | placeholder: | 44 | A clear and concise description of what the bug is. 45 | 46 | ```python 47 | # Command that you used for running the examples 48 | ``` 49 | Description of the results 50 | validations: 51 | required: true 52 | 53 | - type: textarea 54 | attributes: 55 | label: Error logs 56 | description: | 57 | If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. 58 | 59 | placeholder: | 60 | ``` 61 | The error message you got, with the full traceback. 62 | ``` 63 | 64 | validations: 65 | required: true 66 | 67 | 68 | - type: textarea 69 | id: expected-behavior 70 | validations: 71 | required: true 72 | attributes: 73 | label: Expected behavior 74 | description: "A clear and concise description of what you would expect to happen." 75 | 76 | - type: markdown 77 | attributes: 78 | value: > 79 | Thanks for contributing 🎉! 80 | -------------------------------------------------------------------------------- /src/slam_llm/inference/checkpoint_converter_fsdp_hf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
3 | 4 | # from accelerate import init_empty_weights, load_checkpoint_and_dispatch 5 | 6 | import fire 7 | import os 8 | import sys 9 | import yaml 10 | 11 | from transformers import LlamaTokenizer 12 | 13 | from slam_llm.inference.model_utils import load_llama_from_config 14 | 15 | # Get the current file's directory 16 | current_directory = os.path.dirname(os.path.abspath(__file__)) 17 | 18 | # Get the parent directory 19 | parent_directory = os.path.dirname(current_directory) 20 | 21 | # Append the parent directory to sys.path 22 | sys.path.append(parent_directory) 23 | from model_checkpointing import load_sharded_model_single_gpu 24 | 25 | def main( 26 | fsdp_checkpoint_path="", # Path to FSDP Sharded model checkpoints 27 | consolidated_model_path="", # Path to save the HF converted model checkpoints 28 | HF_model_path_or_name="" # Path/ name of the HF model that include config.json and tokenizer_config.json (e.g. meta-llama/Llama-2-7b-chat-hf) 29 | ): 30 | 31 | try: 32 | file_name = 'train_params.yaml' 33 | # Combine the directory and file name to create the full path 34 | train_params_path = os.path.join(fsdp_checkpoint_path, file_name) 35 | # Open the file 36 | with open(train_params_path, 'r') as file: 37 | # Load the YAML data 38 | data = yaml.safe_load(file) 39 | 40 | # Access the 'model_name' field 41 | HF_model_path_or_name = data.get('model_name') 42 | 43 | print(f"Model name: {HF_model_path_or_name}") 44 | except FileNotFoundError: 45 | print(f"The file {train_params_path} does not exist.") 46 | HF_model_path_or_name = input("Please enter the model name: ") 47 | print(f"Model name: {HF_model_path_or_name}") 48 | except Exception as e: 49 | print(f"An error occurred: {e}") 50 | 51 | 52 | #load the HF model definition from config 53 | model_def = load_llama_from_config(HF_model_path_or_name) 54 | print("model is loaded from config") 55 | #load the FSDP sharded checkpoints into the model 56 | model = load_sharded_model_single_gpu(model_def, fsdp_checkpoint_path) 57 | print("model is loaded from FSDP checkpoints") 58 | #loading the tokenizer form the model_path 59 | tokenizer = LlamaTokenizer.from_pretrained(HF_model_path_or_name) 60 | tokenizer.save_pretrained(consolidated_model_path) 61 | #save the FSDP sharded checkpoints in HF format 62 | model.save_pretrained(consolidated_model_path) 63 | print(f"HuggingFace model checkpoints has been saved in {consolidated_model_path}") 64 | if __name__ == "__main__": 65 | fire.Fire(main) 66 | -------------------------------------------------------------------------------- /src/slam_llm/pipeline/inference.py: -------------------------------------------------------------------------------- 1 | # import fire 2 | import logging 3 | import random 4 | import torch 5 | # import argparse 6 | from slam_llm.models.slam_model import slam_model 7 | # config 8 | # from llama_recipes.configs import fsdp_config as FSDP_CONFIG 9 | # from llama_recipes.configs import train_config as TRAIN_CONFIG 10 | # from llama_recipes.configs import model_config as MODEL_CONFIG 11 | 12 | from slam_llm.utils.model_utils import get_custom_model_factory 13 | 14 | import hydra 15 | from omegaconf import DictConfig, ListConfig, OmegaConf 16 | 17 | 18 | @hydra.main(config_name=None, version_base=None) 19 | def main_hydra(cfg: DictConfig): 20 | def to_plain_list(cfg_item): 21 | if isinstance(cfg_item, ListConfig): 22 | return OmegaConf.to_container(cfg_item, resolve=True) 23 | elif isinstance(cfg_item, DictConfig): 24 | return {k: to_plain_list(v) for k, v in cfg_item.items()} 
25 | else: 26 | return cfg_item 27 | 28 | # kwargs = to_plain_list(cfg) 29 | kwargs = cfg 30 | log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) 31 | 32 | logging.basicConfig(level=log_level) 33 | 34 | if kwargs.get("debug", False): 35 | import pdb; 36 | pdb.set_trace() 37 | 38 | main(kwargs) 39 | 40 | def main(kwargs: DictConfig): 41 | 42 | # Update the configuration for the training and sharding process 43 | # train_config, fsdp_config, model_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG() 44 | # update_config((train_config, fsdp_config, model_config), **kwargs) 45 | train_config, fsdp_config, model_config, log_config, dataset_config = kwargs.train_config, \ 46 | kwargs.fsdp_config, \ 47 | kwargs.model_config, \ 48 | kwargs.log_config, \ 49 | kwargs.dataset_config 50 | 51 | del kwargs.train_config 52 | del kwargs.fsdp_config 53 | del kwargs.model_config 54 | del kwargs.log_config 55 | del kwargs.dataset_config 56 | 57 | # Set the seeds for reproducibility 58 | torch.cuda.manual_seed(train_config.seed) 59 | torch.manual_seed(train_config.seed) 60 | random.seed(train_config.seed) 61 | 62 | model_factory = get_custom_model_factory(model_config, logger) 63 | model, tokenizer = model_factory(train_config, model_config, **kwargs) 64 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device. 65 | model.to(device) 66 | model.eval() 67 | 68 | while True: 69 | print("=====================================") 70 | wav_path = input("Your Wav Path:\n") 71 | prompt = input("Your Prompt:\n") 72 | # wav_path = kwargs.get('wav_path') 73 | # prompt = kwargs.get('prompt') 74 | try: 75 | model_outputs = model.inference(wav_path, prompt) 76 | output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True) 77 | print(output_text) 78 | except: 79 | continue 80 | 81 | 82 | 83 | if __name__ == "__main__": 84 | main_hydra() -------------------------------------------------------------------------------- /examples/st_covost2/scripts/infer_hf.sh: -------------------------------------------------------------------------------- 1 | export MASTER_ADDR=localhost 2 | # export TOKENIZERS_PARALLELISM=false 3 | export MASTER_PORT=12345 4 | export WANDB_MODE=offline 5 | export CUDA_VISIBLE_DEVICES=0 6 | if [ -n "$CUDA_VISIBLE_DEVICES" ]; then 7 | gpu_count=$(echo "$CUDA_VISIBLE_DEVICES" | awk -F',' '{print NF}') 8 | elif command -v nvidia-smi &> /dev/null; then 9 | gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) 10 | else 11 | gpu_count=1 # 12 | fi 13 | echo "GPU number: $gpu_count" 14 | 15 | # get the current script's directory 16 | current_script=$(readlink -f "$0") 17 | current_dir=$(dirname "$current_script") 18 | code=$(realpath "$current_dir/../../../../LLM-SRT") 19 | echo "Code path: ${code}" 20 | cd ${code} 21 | source=fleurs_enzh 22 | beam=5 23 | mode=srt 24 | validnum=-1 25 | 26 | peft=true 27 | if [ "$peft" = "true" ]; then 28 | freeze_llm="false" 29 | else 30 | freeze_llm="true" 31 | fi 32 | encoder_path_hf=${code}/models/whisper-large-v3 33 | llm_path=${code}/models/Qwen2.5-3B 34 | 35 | 36 | llm_name=$(basename "$llm_path") 37 | checkpoint_dir=${code}/models/llm-srt/qwen2.5-3b.pt 38 | 39 | 40 | ckpt_name=$checkpoint_dir 41 | echo "find .pt file: $ckpt_name" 42 | 43 | decode_log="${code}/${source}_${mode}_${llm_name}.jsonl" 44 | 45 | echo "Decode log saved to: ${decode_log}" 46 | 47 | if [ "$gpu_count" -gt 1 ]; then 48 | enable_ddp=true 49 | else 50 | 
enable_ddp=false 51 | fi 52 | 53 | # Inference 54 | torchrun \ 55 | --nnodes 1 \ 56 | --nproc_per_node ${gpu_count} \ 57 | --master_port=29503 \ 58 | ${code}/examples/st_covost2/inference_asr_batch.py \ 59 | --config-path "conf" \ 60 | --config-name "prompt.yaml" \ 61 | ++train_config.enable_fsdp=false \ 62 | ++train_config.enable_ddp=true \ 63 | ++fsdp_config.pure_bf16=true \ 64 | ++model_config.llm_name=$llm_name \ 65 | ++model_config.llm_path=$llm_path \ 66 | ++model_config.llm_dim=2048 \ 67 | ++model_config.query_len=80 \ 68 | ++model_config.encoder_name=whisper \ 69 | ++model_config.encoder_projector_ds_rate=5 \ 70 | ++model_config.encoder_path=$speech_encoder_path \ 71 | ++model_config.encoder_path_hf=$encoder_path_hf \ 72 | ++model_config.encoder_dim=1280 \ 73 | ++model_config.encoder_projector=q-former \ 74 | ++dataset_config.dataset=st_dataset \ 75 | ++dataset_config.file=examples/st_covost2/dataset/fleurs_dataset.py:get_speech_dataset \ 76 | ++dataset_config.val_data_path=$val_data_path \ 77 | ++dataset_config.input_type=mel \ 78 | ++dataset_config.fix_length_audio=80 \ 79 | ++dataset_config.mel_size=128 \ 80 | ++dataset_config.inference_mode=true \ 81 | ++dataset_config.source=$source \ 82 | ++train_config.model_name=asr \ 83 | ++train_config.freeze_encoder=true \ 84 | ++train_config.freeze_llm=true \ 85 | ++train_config.batching_strategy=custom \ 86 | ++train_config.num_epochs=1 \ 87 | ++train_config.enable_ddp=$enable_ddp \ 88 | ++train_config.val_batch_size=32 \ 89 | ++train_config.num_workers_dataloader=16 \ 90 | ++log_config.decode_log=$decode_log \ 91 | ++ckpt_path=$ckpt_name \ 92 | ++train_config.use_peft=true 93 | -------------------------------------------------------------------------------- /examples/st_covost2/scripts/all.sh: -------------------------------------------------------------------------------- 1 | export TOKENIZERS_PARALLELISM=false 2 | export WANDB_MODE=offline 3 | # export HYDRA_FULL_ERROR=1 4 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 5 | if command -v nvidia-smi &> /dev/null; then 6 | gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) 7 | fi 8 | if [ -n "$CUDA_VISIBLE_DEVICES" ]; then 9 | gpu_count=$(echo "$CUDA_VISIBLE_DEVICES" | awk -F',' '{print NF}') 10 | fi 11 | 12 | 13 | mode=srt 14 | 15 | 16 | 17 | 18 | echo "GPU number: $gpu_count" 19 | current_script=$(readlink -f "$0") 20 | current_dir=$(dirname "$current_script") 21 | code=$(realpath "$current_dir/../../../../LLM-SRT") 22 | cd ${code} 23 | source=fleurs 24 | validnum=-1 25 | peft=true 26 | 27 | 28 | if [ "$peft" = "true" ]; then 29 | freeze_llm="false" 30 | else 31 | freeze_llm="true" 32 | fi 33 | 34 | checkpoint_dir=${code}/models/output/qwen2.5-3B-mlp-15-srt-4 35 | output_dir=${code}/models/output/qwen2.5-3B-mlp-15-srt-5 36 | 37 | 38 | encoder_path_hf=${code}/models/whisper-large-v3 39 | llm_path=${code}/models/Qwen2.5-3B 40 | llm_name=$(basename "$llm_path") 41 | train_data_path=${code}/data/fleurs/wavs/train.jsonl 42 | val_data_path=${code}/data/fleurs/wavs/validation.jsonl 43 | 44 | 45 | 46 | max_epoch=$(ls -d ${checkpoint_dir}/asr_epoch_*_step_* | sed -n 's/.*asr_epoch_\([0-9]*\)_step_\([0-9]*\).*/\1/p' | sort -n | tail -1) 47 | max_step=$(ls -d ${checkpoint_dir}/asr_epoch_${max_epoch}_step_* | sed -n 's/.*asr_epoch_[0-9]*_step_\([0-9]*\).*/\1/p' | sort -n | tail -1) 48 | 49 | final_path="${checkpoint_dir}/asr_epoch_${max_epoch}_step_${max_step}" 50 | ckpt_name=$final_path/model.pt 51 | echo "find .pt file: $ckpt_name" 52 | 53 | 54 | 55 | 56 | 57 | 58 | hydra_args=" 
59 | hydra.run.dir=$output_dir \ 60 | ++model_config.llm_name=Qwen2.5-7B \ 61 | ++model_config.llm_path=$llm_path \ 62 | ++model_config.llm_dim=2048 \ 63 | ++model_config.encoder_name=whisper \ 64 | ++model_config.encoder_projector_ds_rate=5 \ 65 | ++model_config.encoder_path=$speech_encoder_path \ 66 | ++model_config.encoder_path_hf=$encoder_path_hf \ 67 | ++model_config.encoder_dim=1280 \ 68 | ++model_config.encoder_projector=q-former \ 69 | ++model_config.query_len=80 \ 70 | ++dataset_config.dataset=srt_dataset \ 71 | ++dataset_config.file=examples/st_covost2/dataset/srt_dataset.py:get_speech_dataset \ 72 | ++dataset_config.train_data_path=$train_data_path \ 73 | ++dataset_config.val_data_path=$val_data_path \ 74 | ++dataset_config.input_type=mel \ 75 | ++dataset_config.mel_size=128 \ 76 | ++dataset_config.fix_length_audio=80 \ 77 | ++dataset_config.source=$source \ 78 | ++dataset_config.mode=$mode \ 79 | ++train_config.model_name=asr \ 80 | ++train_config.num_epochs=10 \ 81 | ++train_config.freeze_encoder=true \ 82 | ++train_config.freeze_llm=$freeze_llm \ 83 | ++train_config.batching_strategy=custom \ 84 | ++train_config.gradient_accumulation_steps=1 \ 85 | ++train_config.warmup_steps=200 \ 86 | ++train_config.total_steps=200000 \ 87 | ++train_config.lr=1e-4 \ 88 | ++train_config.batch_size_training=8 \ 89 | ++train_config.val_batch_size=16 \ 90 | ++train_config.num_workers_dataloader=8 \ 91 | ++train_config.output_dir=$output_dir \ 92 | ++metric=acc \ 93 | ++train_config.use_fp16=false \ 94 | ++dataset_config.validnum=$validnum \ 95 | ++train_config.use_fast_kernels=false \ 96 | ++ckpt_path=$ckpt_name \ 97 | " 98 | 99 | 100 | 101 | torchrun \ 102 | --nnodes 1 \ 103 | --nproc_per_node ${gpu_count} \ 104 | --master_port=29504 \ 105 | ${code}/examples/st_covost2/finetune_asr.py \ 106 | --config-path "conf" \ 107 | --config-name "prompt.yaml" \ 108 | ++train_config.enable_fsdp=false \ 109 | ++train_config.enable_ddp=true \ 110 | ++fsdp_config.pure_bf16=true \ 111 | ++log_config.use_wandb=true \ 112 | ++log_config.wandb_project_name=fleur-tts \ 113 | ++log_config.wandb_exp_name=yxduir \ 114 | ++train_config.validation_interval=1000 \ 115 | ++train_config.use_peft=${peft} \ 116 | $hydra_args 117 | -------------------------------------------------------------------------------- /examples/st_covost2/scripts/infer_all.sh: -------------------------------------------------------------------------------- 1 | export MASTER_ADDR=localhost 2 | # export TOKENIZERS_PARALLELISM=false 3 | export MASTER_PORT=12348 4 | export WANDB_MODE=offline 5 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 6 | 7 | 8 | 9 | # GPU num 10 | if [ -n "$CUDA_VISIBLE_DEVICES" ]; then 11 | gpu_count=$(echo "$CUDA_VISIBLE_DEVICES" | awk -F',' '{print NF}') 12 | elif command -v nvidia-smi &> /dev/null; then 13 | gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) 14 | else 15 | gpu_count=1 # 默认值 16 | fi 17 | 18 | echo "GPU number: $gpu_count" 19 | 20 | # get the current script's directory 21 | current_script=$(readlink -f "$0") 22 | current_dir=$(dirname "$current_script") 23 | code=$(realpath "$current_dir/../../../../LLM-SRT") 24 | echo "Code path: ${code}" 25 | cd ${code} 26 | source=wmt24_all 27 | mode=srt 28 | validnum=-2 29 | 30 | peft=true 31 | if [ "$peft" = "true" ]; then 32 | freeze_llm="false" 33 | else 34 | freeze_llm="true" 35 | fi 36 | 37 | 38 | checkpoint_dir=${code}/models/llm-srt/qwen2.5-3b.pt 39 | encoder_path_hf=/data_a100/models/whisper-large-v3 40 | llm_path=/data_a100/models/Qwen2.5-3B 41 | 
llm_name=$(basename "$llm_path") 42 | 43 | val_data_path=${code}/data/fleurs/wavs/test.jsonl 44 | 45 | 46 | # decode_log file path 47 | val_data_dir=$(dirname "${val_data_path}") 48 | val_data_basename=$(basename "${val_data_path}" .jsonl) 49 | decode_log="${code}/${val_data_basename}_${mode}_${llm_name}.jsonl" 50 | 51 | 52 | # max_epoch=$(ls -d ${checkpoint_dir}/asr_epoch_*_step_* | sed -n 's/.*asr_epoch_\([0-9]*\)_step_\([0-9]*\).*/\1/p' | sort -n | tail -1) 53 | # max_step=$(ls -d ${checkpoint_dir}/asr_epoch_${max_epoch}_step_* | sed -n 's/.*asr_epoch_[0-9]*_step_\([0-9]*\).*/\1/p' | sort -n | tail -1) 54 | 55 | # final_path="${checkpoint_dir}/asr_epoch_${max_epoch}_step_${max_step}" 56 | # ckpt_name=$final_path/model.pt 57 | 58 | ckpt_name=$checkpoint_dir 59 | 60 | echo "find .pt file: $ckpt_name" 61 | 62 | 63 | 64 | echo "Decode log saved to: ${decode_log}" 65 | if [ "$gpu_count" -gt 1 ]; then 66 | enable_ddp=true 67 | else 68 | enable_ddp=false 69 | fi 70 | 71 | # Inference 72 | torchrun \ 73 | --nnodes 1 \ 74 | --nproc_per_node ${gpu_count} \ 75 | --master_port=29508 \ 76 | ${code}/examples/st_covost2/inference_asr_batch.py \ 77 | --config-path "conf" \ 78 | --config-name "prompt.yaml" \ 79 | ++train_config.enable_fsdp=false \ 80 | ++train_config.enable_ddp=true \ 81 | ++fsdp_config.pure_bf16=true \ 82 | ++model_config.llm_name=$llm_name \ 83 | ++model_config.llm_path=$llm_path \ 84 | ++model_config.llm_dim=2048 \ 85 | ++model_config.query_len=80 \ 86 | ++model_config.encoder_name=whisper \ 87 | ++model_config.encoder_projector_ds_rate=5 \ 88 | ++model_config.encoder_path=$speech_encoder_path \ 89 | ++model_config.encoder_path_hf=$encoder_path_hf \ 90 | ++model_config.encoder_dim=1280 \ 91 | ++model_config.encoder_projector=q-former \ 92 | ++dataset_config.dataset=st_dataset \ 93 | ++dataset_config.file=examples/st_covost2/dataset/srt_dataset.py:get_speech_dataset \ 94 | ++dataset_config.val_data_path=$val_data_path \ 95 | ++dataset_config.input_type=mel \ 96 | ++dataset_config.fix_length_audio=80 \ 97 | ++dataset_config.mel_size=128 \ 98 | ++dataset_config.inference_mode=true \ 99 | ++dataset_config.source=$source \ 100 | ++dataset_config.mode=$mode \ 101 | ++dataset_config.validnum=$validnum \ 102 | ++train_config.model_name=asr \ 103 | ++train_config.freeze_encoder=true \ 104 | ++train_config.freeze_llm=$freeze_llm \ 105 | ++train_config.batching_strategy=custom \ 106 | ++train_config.num_epochs=1 \ 107 | ++train_config.val_batch_size=8 \ 108 | ++train_config.enable_ddp=$enable_ddp \ 109 | ++train_config.num_workers_dataloader=16 \ 110 | ++log_config.decode_log=$decode_log \ 111 | ++ckpt_path=$ckpt_name \ 112 | ++train_config.use_peft=${peft} \ 113 | -------------------------------------------------------------------------------- /src/slam_llm/models/projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class EncoderProjectorConcat(nn.Module): 6 | def __init__(self, config): 7 | super().__init__() 8 | self.k = config.encoder_projector_ds_rate 9 | self.encoder_dim = config.encoder_dim 10 | self.llm_dim = config.llm_dim 11 | self.linear1 = nn.Linear(self.encoder_dim * self.k, 2048) 12 | self.relu = nn.ReLU() 13 | self.linear2 = nn.Linear(2048, config.llm_dim) 14 | 15 | def forward(self, x): 16 | batch_size, seq_len, dim = x.size() 17 | num_frames_to_discard = seq_len % self.k 18 | if num_frames_to_discard > 0: 19 | x = x[:, 
:-num_frames_to_discard, :] 20 | seq_len = x.size(1) 21 | 22 | x = x.contiguous() 23 | x = x.view(batch_size, seq_len // self.k, dim * self.k) 24 | x = self.linear1(x) 25 | x = self.relu(x) 26 | x = self.linear2(x) 27 | return x 28 | 29 | class EncoderProjectorCov1d(nn.Module): 30 | def __init__(self, config): 31 | super().__init__() 32 | self.k = config.encoder_projector_ds_rate 33 | self.encoder_dim = config.encoder_dim 34 | self.llm_dim = config.llm_dim 35 | self.conv1d = nn.Conv1d(in_channels=self.encoder_dim, out_channels=self.encoder_dim, kernel_size=self.k, stride=self.k, padding=0) 36 | self.linear1 = nn.Linear(self.encoder_dim, 2048) 37 | self.relu1 = nn.ReLU() 38 | self.linear2 = nn.Linear(2048, self.llm_dim) 39 | self.relu2 = nn.ReLU() 40 | 41 | def forward(self, x): 42 | x = x.transpose(1, 2) 43 | x = self.conv1d(x) 44 | x = x.transpose(1, 2) 45 | x = self.relu1(x) 46 | x = self.linear1(x) 47 | x = self.relu2(x) 48 | x = self.linear2(x) 49 | return x 50 | 51 | class EncoderProjectorQFormer(nn.Module): 52 | def __init__(self, config): 53 | super().__init__() 54 | self.encoder_dim = config.encoder_dim 55 | self.llm_dim = config.llm_dim 56 | from transformers import Blip2QFormerConfig, Blip2QFormerModel 57 | configuration = Blip2QFormerConfig() 58 | configuration.encoder_hidden_size = self.encoder_dim 59 | configuration.num_hidden_layers = config.qformer_layers 60 | 61 | 62 | self.query_len = int(config.get("query_len", 80)) 63 | self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size)) 64 | self.query.data.normal_(mean=0.0, std=1.0) 65 | self.qformer = Blip2QFormerModel(configuration) 66 | 67 | # project from encoder dim 1280 to the LLM dim, e.g. 2048 for the 3B model 68 | # or 1280 -> 5120 for the 32B model 69 | if self.llm_dim <= 1536: 70 | self.linear = nn.Linear(configuration.hidden_size, self.llm_dim) 71 | self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5) 72 | elif self.llm_dim <= 2560: 73 | self.linear1 = nn.Linear(configuration.hidden_size, 1536) # Q-Former hidden size -> 1536 74 | self.relu = nn.ReLU() # activation 75 | self.linear2 = nn.Linear(1536, self.llm_dim) # 1536 -> llm_dim 76 | self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5) # final normalization 77 | else: 78 | self.linear1 = nn.Linear(configuration.hidden_size, 2560) # Q-Former hidden size -> 2560 79 | self.relu = nn.ReLU() # activation 80 | self.linear2 = nn.Linear(2560, self.llm_dim) # 2560 -> llm_dim (e.g. 5120) 81 | self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5) # final normalization 82 | 83 | 84 | 85 | def forward(self, x, atts): 86 | query = self.query.expand(x.shape[0], -1, -1) 87 | 88 | query_output = self.qformer( 89 | query_embeds=query, 90 | encoder_hidden_states=x, 91 | encoder_attention_mask=atts, 92 | return_dict=True, 93 | ) 94 | 95 | if self.llm_dim <= 1536: 96 | query_proj = self.norm(self.linear(query_output.last_hidden_state)) 97 | else: 98 | x = self.linear1(query_output.last_hidden_state) # Q-Former hidden size -> intermediate dim 99 | x = self.relu(x) # activation 100 | x = self.linear2(x) # intermediate dim -> llm_dim 101 | query_proj = self.norm(x) # LayerNorm 102 | 103 | 104 | return query_proj -------------------------------------------------------------------------------- /src/slam_llm/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
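# NOTE (illustrative usage sketch, not part of the original module; names outside this file are assumptions):
# with a train_config like the dataclasses in examples/st_covost2/asr_config.py, the helpers below are typically used as
#     peft_config = generate_peft_config(train_config)            # -> peft.LoraConfig(r=8, lora_alpha=32, ...)
#     model = get_peft_model(model, peft_config)                   # standard PEFT wrapping step (done elsewhere)
#     kwargs = get_dataloader_kwargs(train_config, dataset, tokenizer, mode="train")
#     loader = torch.utils.data.DataLoader(dataset, **kwargs)      # sampler/collator picked by batching_strategy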
3 | 4 | import inspect 5 | # from dataclasses import asdict 6 | 7 | import torch.distributed as dist 8 | from torch.utils.data import DistributedSampler 9 | from peft import ( 10 | LoraConfig, 11 | AdaptionPromptConfig, 12 | PrefixTuningConfig, 13 | ) 14 | from transformers import default_data_collator 15 | from transformers.data import DataCollatorForSeq2Seq 16 | 17 | # from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config 18 | from slam_llm.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler 19 | 20 | from omegaconf import OmegaConf 21 | 22 | import logging 23 | logger = logging.getLogger(__name__) 24 | 25 | # def update_config(config, **kwargs): 26 | # if isinstance(config, (tuple, list)): 27 | # for c in config: 28 | # update_config(c, **kwargs) 29 | # else: 30 | # for k, v in kwargs.items(): 31 | # if hasattr(config, k): 32 | # setattr(config, k, v) 33 | # elif "." in k: 34 | # # allow --some_config.some_param=True 35 | # config_name, param_name = k.split(".") 36 | # if type(config).__name__ == config_name: 37 | # if hasattr(config, param_name): 38 | # setattr(config, param_name, v) 39 | # else: 40 | # # In case of specialized config we can warm user 41 | # logger.warning(f"Warning: {config_name} does not accept parameter: {k}") 42 | # elif isinstance(config, train_config): 43 | # logger.warning(f"Warning: unknown parameter {k}") 44 | 45 | 46 | def generate_peft_config(train_config): 47 | # configs = (lora_config, llama_adapter_config, prefix_config) 48 | # peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig) 49 | peft_configs = {"lora": LoraConfig, 50 | "llama_adapter": AdaptionPromptConfig, 51 | "prefix": PrefixTuningConfig 52 | } 53 | # names = tuple(c.__name__.rstrip("_config") for c in configs) 54 | # 55 | # assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}" 56 | # 57 | # config = configs[names.index(train_config.peft_method)]() 58 | config = train_config.peft_config 59 | 60 | params = OmegaConf.to_container(config, resolve=True) 61 | # peft_config = peft_configs[names.index(train_config.peft_method)](**params) 62 | params.pop("peft_method", None) #(FIX:MZY): remove peft_method from params to avoid error 63 | peft_config = peft_configs[config.get("peft_method", "lora")](**params) 64 | 65 | return peft_config 66 | 67 | 68 | def get_dataloader_kwargs(train_config, dataset, tokenizer, mode): 69 | kwargs = {} 70 | batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size 71 | if train_config.batching_strategy == "padding": 72 | if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed: 73 | kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler( 74 | dataset, 75 | batch_size=batch_size, 76 | rank=dist.get_rank(), 77 | num_replicas=dist.get_world_size(), 78 | shuffle=mode=="train", 79 | ) 80 | else: 81 | kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train") 82 | kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer) 83 | elif train_config.batching_strategy == "packing": 84 | if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed: 85 | kwargs["sampler"] = DistributedSampler( 86 | dataset, 87 | rank=dist.get_rank(), 88 | num_replicas=dist.get_world_size(), 89 | shuffle=mode=="train", 90 | ) 91 | kwargs["batch_size"] = batch_size 92 | kwargs["drop_last"] = True 93 | kwargs["collate_fn"] 
= default_data_collator 94 | else: 95 | # raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}") 96 | if train_config.enable_fsdp or train_config.enable_ddp or train_config.enable_deepspeed: 97 | kwargs["sampler"] = DistributedSampler( 98 | dataset, 99 | rank=dist.get_rank(), 100 | num_replicas=dist.get_world_size(), 101 | shuffle=mode=="train", 102 | ) 103 | kwargs["batch_size"] = batch_size 104 | kwargs["drop_last"] = True 105 | kwargs["collate_fn"] = dataset.collator 106 | logger.info(f"Using batching strategy: {train_config.batching_strategy}") 107 | 108 | return kwargs 109 | -------------------------------------------------------------------------------- /src/slam_llm/pipeline/inference_batch.py: -------------------------------------------------------------------------------- 1 | # import fire 2 | import random 3 | import torch 4 | import logging 5 | # import argparse 6 | from slam_llm.models.slam_model import slam_model 7 | # config 8 | # from llama_recipes.configs import fsdp_config as FSDP_CONFIG 9 | # from llama_recipes.configs import train_config as TRAIN_CONFIG 10 | # from llama_recipes.configs import model_config as MODEL_CONFIG 11 | # from llama_recipes.configs import log_config as LOG_CONFIG 12 | 13 | from slam_llm.utils.model_utils import get_custom_model_factory 14 | from slam_llm.utils.dataset_utils import get_preprocessed_dataset 15 | import os 16 | import logging 17 | from tqdm import tqdm 18 | 19 | import hydra 20 | from omegaconf import DictConfig, ListConfig, OmegaConf 21 | 22 | 23 | @hydra.main(config_name=None, version_base=None) 24 | def main_hydra(cfg: DictConfig): 25 | def to_plain_list(cfg_item): 26 | if isinstance(cfg_item, ListConfig): 27 | return OmegaConf.to_container(cfg_item, resolve=True) 28 | elif isinstance(cfg_item, DictConfig): 29 | return {k: to_plain_list(v) for k, v in cfg_item.items()} 30 | else: 31 | return cfg_item 32 | 33 | # kwargs = to_plain_list(cfg) 34 | kwargs = cfg 35 | log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) 36 | 37 | logging.basicConfig(level=log_level) 38 | 39 | if kwargs.get("debug", False): 40 | import pdb; 41 | pdb.set_trace() 42 | 43 | main(kwargs) 44 | 45 | 46 | def main(kwargs: DictConfig): 47 | 48 | # Update the configuration for the training and sharding process 49 | # train_config, fsdp_config, model_config, log_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG(), LOG_CONFIG() 50 | # update_config((train_config, fsdp_config, model_config, log_config), **kwargs) 51 | train_config, fsdp_config, model_config, log_config, dataset_config = kwargs.train_config, \ 52 | kwargs.fsdp_config, \ 53 | kwargs.model_config, \ 54 | kwargs.log_config, \ 55 | kwargs.dataset_config 56 | 57 | OmegaConf.set_struct(kwargs,False) 58 | del kwargs["train_config"] 59 | del kwargs["fsdp_config"] 60 | del kwargs["model_config"] 61 | del kwargs["log_config"] 62 | del kwargs["dataset_config"] 63 | OmegaConf.set_struct(kwargs,True) 64 | 65 | # Set log 66 | if not os.path.exists(os.path.dirname(log_config.log_file)): 67 | os.makedirs(os.path.dirname(log_config.log_file), exist_ok=True) 68 | logging.basicConfig( 69 | level=logging.INFO, 70 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 71 | datefmt="%Y-%m-%d %H:%M:%S", 72 | filemode='w' 73 | ) 74 | 75 | logger = logging.getLogger() 76 | logger.setLevel(logging.INFO) 77 | 78 | file_handler = logging.FileHandler(filename=log_config.log_file, mode='w') 79 | file_handler.setLevel(logging.INFO) 80 | file_formatter = 
logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 81 | file_handler.setFormatter(file_formatter) 82 | 83 | logger.handlers[0].setLevel(logging.INFO) 84 | console_formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 85 | logger.handlers[0].setFormatter(console_formatter) 86 | 87 | logger.addHandler(file_handler) 88 | 89 | logger.info("train_config: {}".format(train_config)) 90 | logger.info("fsdp_config: {}".format(fsdp_config)) 91 | logger.info("model_config: {}".format(model_config)) 92 | 93 | 94 | # Set the seeds for reproducibility 95 | torch.cuda.manual_seed(train_config.seed) 96 | torch.manual_seed(train_config.seed) 97 | random.seed(train_config.seed) 98 | 99 | model_factory = get_custom_model_factory(model_config, logger) 100 | model, tokenizer = model_factory(train_config, model_config, **kwargs) 101 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device. 102 | model.to(device) 103 | model.eval() 104 | 105 | # dataset_config = generate_dataset_config(train_config, kwargs) 106 | logger.info("dataset_config: {}".format(dataset_config)) 107 | dataset_test = get_preprocessed_dataset( 108 | tokenizer, 109 | dataset_config, 110 | split="test", 111 | ) 112 | if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: 113 | logger.info(f"--> Training Set Length = {len(dataset_test)}") 114 | 115 | test_dataloader = torch.utils.data.DataLoader( 116 | dataset_test, 117 | num_workers=train_config.num_workers_dataloader, 118 | pin_memory=True, 119 | shuffle=False, 120 | batch_size=train_config.val_batch_size, 121 | drop_last=False, 122 | collate_fn=dataset_test.collator 123 | ) 124 | 125 | 126 | logger.info("=====================================") 127 | pred_path = kwargs.get('decode_log') + "_pred" 128 | gt_path = kwargs.get('decode_log') + "_gt" 129 | with open(pred_path, "w") as pred, open(gt_path, "w") as gt: 130 | for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)): 131 | for key in batch.keys(): 132 | batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key] 133 | model_outputs = model.generate(**batch) 134 | output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True) 135 | for key, text, target in zip(batch["keys"], output_text, batch["targets"]): 136 | pred.write(key + "\t" + text.replace("\n", " ") + "\n") 137 | gt.write(key + "\t" + target + "\n") 138 | 139 | 140 | if __name__ == "__main__": 141 | main_hydra() -------------------------------------------------------------------------------- /examples/st_covost2/asr_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional, List 3 | from torch.distributed.fsdp import ShardingStrategy 4 | from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType 5 | 6 | @dataclass 7 | class ModelConfig: 8 | file: str = "examples/st_covost2/model/slam_model_st.py" 9 | llm_name: str = "vicuna-13b-v1.5" 10 | llm_path: str = "PATH/to/LLAMA/7B" 11 | llm_type: str = "decoder_only" 12 | llm_dim: int = 3584 13 | encoder_path_hf: Optional[str] = None 14 | encoder_name: Optional[str] = None 15 | encoder_ds_rate: int = 2 16 | encoder_path: Optional[str] = None 17 | encoder_dim: int = 1280 18 | encoder_projector: str = "linear" 19 | 
encoder_projector_ds_rate: int = 5 20 | modal: str = "audio" 21 | normalize: Optional[bool] = field(default=False, metadata={ 22 | "help": "whether input is normalized, used for models such as wavlm" 23 | }) 24 | encoder_type: str = field(default="finetune", metadata={ 25 | "help": "whether model is only pretrained or finetuned, used for models such as hubert" 26 | }) 27 | query_len: Optional[str] = None 28 | qformer_layers: int = 8 29 | beam: int = 1 30 | 31 | 32 | 33 | 34 | @dataclass 35 | class PeftConfig: 36 | peft_method: str = "lora" # None , llama_adapter, prefix 37 | r: int = 8 38 | lora_alpha: int = 32 39 | target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ]) 40 | bias: str = "none" 41 | task_type: str = "CAUSAL_LM" 42 | lora_dropout: float = 0.05 43 | inference_mode: bool = False 44 | 45 | @dataclass 46 | class TrainConfig: 47 | model_name:str = "PATH/to/LLAMA/7B" 48 | enable_ddp:bool = False 49 | enable_deepspeed:bool = False 50 | enable_fsdp:bool = False 51 | low_cpu_fsdp:bool = False 52 | run_validation:bool = True 53 | batch_size_training:int = 4 54 | batching_strategy:str = field(default="packing", metadata={ 55 | "help":"alternative: padding" 56 | }) # 57 | context_length:int = 4096 58 | gradient_accumulation_steps:int = 1 59 | num_epochs:int = 3 60 | num_workers_dataloader:int = 1 61 | warmup_steps:int = 1000 62 | total_steps:int = 100000 63 | validation_interval:int = 1000 64 | lr:float = 1e-4 65 | weight_decay:float = 0.01 66 | gamma:float = 0.85 67 | seed:int = 42 68 | use_fp16:bool = False 69 | mixed_precision:bool = True 70 | val_batch_size:int = 1 71 | 72 | use_peft:bool = False 73 | peft_config:PeftConfig = field(default_factory=PeftConfig) 74 | output_dir:str = "PATH/to/save/PEFT/model" 75 | freeze_layers:bool = False 76 | num_freeze_layers:int = 1 77 | quantization:bool = False 78 | one_gpu:bool = False 79 | save_model:bool = True 80 | dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP 81 | dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP 82 | save_optimizer:bool = False # will be used if using FSDP 83 | use_fast_kernels:bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels 84 | run_test_during_validation:bool = False 85 | run_test_during_validation_file:str = "test.wav" 86 | run_test_during_validation_prompt:str = "<|ASR|>" 87 | freeze_llm:bool = field(default=False, metadata={ 88 | "help": "whether to freeze llm when finetuning, should be true when use peft finetuning" 89 | }) 90 | freeze_encoder:bool = False 91 | 92 | @dataclass 93 | class DataConfig: 94 | dataset: str = "st_dataset" 95 | file: str = "examples/st_covost2/dataset/st_dataset.py:get_speech_dataset" 96 | train_data_path: Optional[str] = None 97 | val_data_path: Optional[str] = None 98 | train_split: str = "train" 99 | test_split:str = "test" 100 | prompt: Optional[str] = None 101 | data_path: Optional[str] = None 102 | max_words: Optional[int] = None 103 | max_mel: Optional[float] = None 104 | fix_length_audio: int = -1 105 | inference_mode:bool = False 106 | input_type: str = field(default="raw", metadata={ 107 | "help":"Use raw when input is wav, mel when for whisper" 108 | }) 109 | mel_size: int = field(default=80, metadata={ 110 | "help": "80 for whisper large v1 and v2, 128 for v3" 111 | }) 112 | normalize: Optional[bool] = field(default=False, metadata={ 113 | "help": "whether input is normalized, used for models such as wavlm" 114 
| }) 115 | bf16:bool = True 116 | fp16:bool = True 117 | source: Optional[str] = None 118 | mode: Optional[str] = None 119 | validnum: int = -1 120 | 121 | @dataclass 122 | class FSDPConfig: 123 | mixed_precision: bool = True 124 | use_fp16: bool = False 125 | sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD # HYBRID_SHARD "Full Shard within a node DDP cross Nodes", SHARD_GRAD_OP "Shard only Gradients and Optimizer States", NO_SHARD "Similar to DDP". 126 | checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. 127 | fsdp_activation_checkpointing: bool = True 128 | fsdp_cpu_offload: bool = False 129 | pure_bf16: bool = False 130 | optimizer: str = "AdamW" 131 | 132 | @dataclass 133 | class LogConfig: 134 | use_wandb: bool = False 135 | wandb_dir: str = "test_wandb" 136 | wandb_entity_name: str = "yxduir" 137 | wandb_project_name: str = "project_name" 138 | wandb_exp_name: str = "exp_name" 139 | log_file: str = "./test.log" 140 | log_interval: int = 50 141 | decode_log: str = "./test.log" 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## [ACL 2025 Main] Making LLMs Better Many-to-Many Speech-to-Text Translators with Curriculum Learning 4 | 5 | **LLM-SRT paper**: [https://arxiv.org/abs/2409.19510](https://arxiv.org/abs/2409.19510); 6 | 7 | **MCAT paper**: [https://arxiv.org/abs/2512.01512v1](https://arxiv.org/abs/2512.01512v1); 8 | 9 | This project is a subproject of [**SLAM-LLM**](https://github.com/X-LANCE/SLAM-LLM). 10 | 11 | ✅ **Current Version LLM-SRT (v1.0)** 12 | - **Supported 15 Languages**: Chinese (zho), English (eng), Japanese (jpn), Korean (kor), German (deu), French (fra), Indonesian (ind), Italian (ita), Dutch (nld), Portuguese (por), Russian (rus), Spanish (spa), Thai (tha), Vietnamese (vie), Cantonese (yue) 13 | - **210 Translation Directions** - Supports all 210 possible translation directions (15×14 language pairs) 14 | 15 | 🚀 **MCAT (v2.0)**: Code and Model: [https://github.com/yxduir/m2m-70](https://github.com/yxduir/m2m-70) 16 | - **Supported 70 Languages**: Afrikaans (afr), Amharic (amh), Arabic (ara), Assamese (asm), Azerbaijani (azj), Belarusian (bel), Bengali (ben), Bosnian (bos), Bulgarian (bul), Catalan (cat), Czech (ces), Chinese (cmn), Welsh (cym), Danish (dan), German (deu), Greek (ell), English (eng), Estonian (est), Persian (fas), Finnish (fin), French (fra), Galician (glg), Gujarati (guj), Hebrew (heb), Hindi (hin), Croatian (hrv), Hungarian (hun), Armenian (hye), Indonesian (ind), Icelandic (isl), Italian (ita), Javanese (jav), Japanese (jpn), Kannada (kan), Georgian (kat), Kazakh (kaz), Khmer (khm), Kyrgyz (kir), Korean (kor), Lao (lao), Latvian (lav), Lithuanian (lit), Malayalam (mal), Macedonian (mkd), Malay (msa), Burmese (mya), Dutch (nld), Norwegian (nob), Nepali (npi), Punjabi (pan), Polish (pol), Portuguese (por), Romanian (ron), Russian (rus), Slovak (slk), Slovenian (slv), Spanish (spa), Serbian (srp), Swedish (swe), Swahili (swh), Tamil (tam), Telugu (tel), Tagalog (tgl), Thai (tha), Turkish (tur), Ukrainian (ukr), Urdu (urd), Uzbek (uzb), Vietnamese (vie), Cantonese (yue) 17 | - **4830 Translation Directions** - Supports all 4830 possible translation directions (70×69 language pairs) 18 | 19 | 20 | ## Installation 21 | ``` 22 | sudo apt-get install python3-setuptools 23 | 24 | 
conda create -n llm-srt python=3.10 -y 25 | conda activate llm-srt 26 | 27 | git clone https://github.com/yxduir/LLM-SRT 28 | cd LLM-SRT 29 | 30 | pip install -r requirements.txt 31 | pip install -e . 32 | sudo apt install ffmpeg 33 | ``` 34 | 35 | ## Download Model 36 | Encoder | Adapter | LLM 37 | |---|---|--- 38 | [whisper-large-v3](https://huggingface.co/openai/whisper-large-v3) | [q-former+mlp](https://huggingface.co/yxdu/llm-srt) | [Qwen2.5-3B](https://huggingface.co/Qwen/Qwen2.5-3B) 39 | ``` 40 | cd models/ 41 | 42 | # Total 29G, for 3B model (support 15 languages) 43 | hf download yxdu/llm-srt --local-dir llm-srt 44 | hf download openai/whisper-large-v3 --local-dir whisper-large-v3 45 | hf download Qwen/Qwen2.5-3B --local-dir Qwen2.5-3B 46 | 47 | cd .. 48 | ``` 49 | 50 | ## Infer Demo 51 | This is an automatic inference script for the fleurs dataset from English (eng) to Chinese (zho). 52 | ``` 53 | bash examples/st_covost2/scripts/infer_hf.sh 54 | ``` 55 | 56 | ## Train Dataset 57 | If you want to train your own model, you can download the following datasets. 58 | ``` 59 | [Common Voice](https://commonvoice.mozilla.org/en/datasets) 60 | 61 | [Fleurs](https://huggingface.co/datasets/google/fleurs) 62 | ``` 63 | 64 | 65 | 66 | ## Data preparation 67 | You need to prepare the data jsonl in this format. 68 | | audio | source | prompt | gt | 69 | |------------|------------------|----------------------------|---------------| 70 | | audio_path | `{name}_{src}_{tgt}` | `<\|{src}\|><\|{tgt}\|>`| `transcription{prompt}translation` | 71 | ``` 72 | {"audio": "eng/test/139.wav", "source": "fleurs_eng_zho", "prompt": "<|eng|><|zho|>", "gt": "They have feet with scales and claws, they lay eggs, and they walk on their two back legs like a T-Rex.<|eng|><|zho|>它们脚上有鳞片和爪子,会产卵,还像霸王龙一样用两条后腿走路。"} 73 | {"audio": "deu/test/0.wav", "source": "fleurs_deu_ara", "prompt": "<|deu|><|ara|>", "gt": "Für die besten Aussichten auf Hongkong sollten Sie die Insel verlassen und zum gegenüberliegenden Ufer von Kowloon fahren.<|deu|><|ara|>لكي تحظى بأفضل المشاهد لهونج كونج، غادر الجزيرة واتجه إلى واجهة كولون البحرية في الجهة المقابلة."} 74 | {"audio": "jpn/test/485.wav", "source": "fleurs_jpn_ita", "prompt": "<|jpn|><|ita|>", "gt": "これらの結晶の組成は、赤外分光法(FTIR)で比較すると、患部のペットの尿中に見られるものと一致します。<|jpn|><|ita|>Al confronto mediante spettroscopia infrarossa (FT-IR), la composizione di questi cristalli corrisponde a quella individuata nell'urina degli animali da compagnia che ne sono colpiti."} 75 | ``` 76 | ## Training and Inference 77 | You can use the following scripts to perform training and inference separately. 78 | For all.sh, you can modify the training task based on the 'mode' keyword: asr, smt, srt. 
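Before launching training, make sure the files pointed to by `train_data_path` and `val_data_path` follow the jsonl format from the Data preparation section above. The sketch below is illustrative (the values and the output file name are placeholders):
```
import json

# Placeholder values: replace with your own audio path, language codes, transcription and translation.
audio = "eng/test/139.wav"
src, tgt = "eng", "zho"
transcription = "They have feet with scales and claws ..."
translation = "它们脚上有鳞片和爪子……"

prompt = f"<|{src}|><|{tgt}|>"                      # <|{src}|><|{tgt}|>
line = {
    "audio": audio,
    "source": f"fleurs_{src}_{tgt}",                # {name}_{src}_{tgt}
    "prompt": prompt,
    "gt": f"{transcription}{prompt}{translation}",  # transcription{prompt}translation
}

with open("train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(line, ensure_ascii=False) + "\n")
```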
79 | ``` 80 | #train 81 | bash examples/st_covost2/scripts/all.sh 82 | 83 | 84 | #infer 85 | bash examples/st_covost2/scripts/infer_all.sh 86 | bash examples/st_covost2/scripts/infer_hf.sh 87 | ``` 88 | 89 | 90 | ## Citation 91 | ``` 92 | @article{du2025speech2text, 93 | title = {Making LLMs Better Many-to-Many Speech-to-Text Translators with Curriculum Learning}, 94 | author = {Du, Yexing and Pan, Youcheng and Ma, Ziyang and Yang, Bo and Yang, Yifang and Deng, Keqi and Chen, Xie and Xiang, Yang and Liu, Ming and Qin, Bing}, 95 | booktitle = {Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (ACL 2025)}, 96 | year = {2025}, 97 | } 98 | @misc{du2025mcatscalingmanytomanyspeechtotext, 99 | title={MCAT: Scaling Many-to-Many Speech-to-Text Translation with MLLMs to 70 Languages}, 100 | author={Yexing Du and Kaiyuan Liu and Youcheng Pan and Bo Yang and Keqi Deng and Xie Chen and Yang Xiang and Ming Liu and Bin Qin and YaoWei Wang}, 101 | year={2025}, 102 | eprint={2512.01512}, 103 | archivePrefix={arXiv}, 104 | primaryClass={cs.CL}, 105 | url={https://arxiv.org/abs/2512.01512}, 106 | } 107 | ``` -------------------------------------------------------------------------------- /src/slam_llm/inference/safety_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | import os 5 | import torch 6 | import warnings 7 | 8 | 9 | # Class for performing safety checks using AuditNLG library 10 | class AuditNLGSensitiveTopics(object): 11 | def __init__(self): 12 | pass 13 | 14 | def __call__(self, output_text): 15 | try: 16 | from auditnlg.safety.exam import safety_scores 17 | except ImportError as e: 18 | print("Could not import optional dependency: auditnlg\nPlease install manually with:\n pip install auditnlg\nFollowed by:\npip install -r requirements.txt") 19 | raise e 20 | 21 | 22 | data = [{"output": output_text}] 23 | 24 | result = safety_scores(data=data, method="sensitive_topics") 25 | scores = result[1]["all_scores"][0] 26 | is_safe = scores["pred_class"] == "none" 27 | report = "" 28 | if not is_safe: 29 | report += f"Predicted class: {scores['pred_class']}\n" 30 | report += "|" + "|".join(f"{n:^10}" for n in [list(k.keys())[0] for k in scores["class_scores"]]) + "|\n" 31 | report += "|" + "|".join(f"{n:^10.5}" for n in [list(k.values())[0] for k in scores["class_scores"]]) + "|\n" 32 | return "Sensitive Topics", is_safe, report 33 | 34 | 35 | class SalesforceSafetyChecker(object): 36 | def __init__(self): 37 | pass 38 | 39 | def __call__(self, output_text): 40 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig 41 | 42 | config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base") 43 | tokenizer = AutoTokenizer.from_pretrained("Salesforce/safety-flan-t5-base") 44 | model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/safety-flan-t5-base", config=config) 45 | model.eval() 46 | 47 | prefix = "Is the field safe or unsafe?" 48 | input_ids = tokenizer(prefix + " " + output_text + " ", return_tensors="pt").input_ids 49 | 50 | if len(input_ids[0]) > 512: 51 | warnings.warn( 52 | "Input length is > 512 token. Safety check result could be incorrect." 
53 | ) 54 | 55 | with torch.no_grad(): 56 | outputs = model.generate( 57 | input_ids, 58 | output_scores=True, 59 | return_dict_in_generate=True, 60 | max_new_tokens=20, 61 | ) 62 | 63 | is_safe = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).split(" ")[0] == "safe" 64 | 65 | report = "" 66 | if not is_safe: 67 | true_false_ids = tokenizer("true false").input_ids[:2] 68 | keys = ["toxicity", "hate", "identity", "violence", "physical", "sexual", "profanity", "biased"] 69 | scores = {} 70 | for k, i in zip(keys, range(3,20,2)): 71 | scores[k] = round(outputs.scores[i][0,true_false_ids].softmax(dim=0)[0].item(), 5) 72 | 73 | report += "|" + "|".join(f"{n:^10}" for n in scores.keys()) + "|\n" 74 | report += "|" + "|".join(f"{n:^10}" for n in scores.values()) + "|\n" 75 | return "Salesforce Content Safety Flan T5 Base", is_safe, report 76 | 77 | 78 | def get_total_length(self, data): 79 | prefix = "Is the field safe or unsafe " 80 | input_sample = " {output} ".format(**data[0]) 81 | 82 | return len(self.tokenizer(prefix + input_sample)["input_ids"]) 83 | 84 | 85 | # Class for performing safety checks using Azure Content Safety service 86 | class AzureSaftyChecker(object): 87 | def __init__(self): 88 | try: 89 | from azure.ai.contentsafety import ContentSafetyClient 90 | from azure.core.credentials import AzureKeyCredential 91 | 92 | key = os.environ["CONTENT_SAFETY_KEY"] 93 | endpoint = os.environ["CONTENT_SAFETY_ENDPOINT"] 94 | except ImportError: 95 | raise Exception( 96 | "Could not import required package azure-ai-contentsafety. Install with: pip install azure-ai-contentsafety" 97 | ) 98 | except KeyError: 99 | raise Exception( 100 | "Environment variables not set. Please set CONTENT_SAFETY_KEY and CONTENT_SAFETY_ENDPOINT." 101 | ) 102 | 103 | self.client = ContentSafetyClient(endpoint, AzureKeyCredential(key)) 104 | 105 | def __call__(self, output_text): 106 | from azure.core.exceptions import HttpResponseError 107 | from azure.ai.contentsafety.models import AnalyzeTextOptions, TextCategory 108 | 109 | print(len(output_text)) 110 | if len(output_text) > 1000: 111 | raise Exception("Input length to safety check is too long (>1000).") 112 | 113 | categories = [ 114 | TextCategory.VIOLENCE, 115 | TextCategory.SELF_HARM, 116 | TextCategory.SEXUAL, 117 | TextCategory.HATE, 118 | ] 119 | 120 | request = AnalyzeTextOptions(text=output_text, categories=categories) 121 | 122 | try: 123 | response = self.client.analyze_text(request) 124 | except HttpResponseError as e: 125 | print("Analyze text failed.") 126 | if e.error: 127 | print(f"Error code: {e.error.code}") 128 | print(f"Error message: {e.error.message}") 129 | raise 130 | print(e) 131 | raise e 132 | 133 | levels = {0: "Safe", 2: "Low", 4: "Medium", 6: "High"} 134 | 135 | severities = [ 136 | getattr(response, c.name.lower() + "_result").severity for c in categories 137 | ] 138 | 139 | DEFAULT_LEVELS = [0, 0, 0, 0] 140 | 141 | is_safe = all([s <= l for s, l in zip(severities, DEFAULT_LEVELS)]) 142 | 143 | report = "" 144 | if not is_safe: 145 | report = "|" + "|".join(f"{c.name:^10}" for c in categories) + "|\n" 146 | report += "|" + "|".join(f"{levels[s]:^10}" for s in severities) + "|\n" 147 | 148 | return "Azure Content Saftey API", is_safe, report 149 | 150 | 151 | # Function to load the PeftModel for performance optimization 152 | # Function to determine which safety checker to use based on the options selected 153 | def get_safety_checker(enable_azure_content_safety, 154 | enable_sensitive_topics, 155 | 
enable_salesforce_content_safety, 156 | ): 157 | safety_checker = [] 158 | if enable_azure_content_safety: 159 | safety_checker.append(AzureSaftyChecker()) 160 | if enable_sensitive_topics: 161 | safety_checker.append(AuditNLGSensitiveTopics()) 162 | if enable_salesforce_content_safety: 163 | safety_checker.append(SalesforceSafetyChecker()) 164 | return safety_checker 165 | 166 | 167 | 168 | 169 | 170 | -------------------------------------------------------------------------------- /src/slam_llm/models/encoder.py: -------------------------------------------------------------------------------- 1 | import types 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from dataclasses import dataclass 6 | 7 | 8 | class WhisperWrappedEncoder: 9 | 10 | @classmethod 11 | def load(cls, model_config): 12 | 13 | def extract_variable_length_features(self, x: torch.Tensor): 14 | """ 15 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) 16 | the mel spectrogram of the audio 17 | """ 18 | x = F.gelu(self.conv1(x)) 19 | x = F.gelu(self.conv2(x)) 20 | x = x.permute(0, 2, 1) 21 | 22 | # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" 23 | # x = (x + self.positional_embedding).to(x.dtype) 24 | x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype) 25 | 26 | # iterate over the transformer blocks 27 | for block in self.blocks: 28 | # cast the input to float32 before the block applies layer_norm 29 | y = x.float() # cast to float32 30 | x = block(y) # run the block 31 | x = x.to(torch.bfloat16) # cast back to bfloat16 32 | 33 | # final layer_norm 34 | x = self.ln_post(x.float()).to(torch.bfloat16) # run in float32, return bfloat16 35 | return x 36 | 37 | if model_config.encoder_path_hf is not None: 38 | from transformers import WhisperModel 39 | encoder = WhisperModel.from_pretrained(model_config.encoder_path_hf,torch_dtype=torch.bfloat16).encoder 40 | else: 41 | import whisper 42 | encoder = whisper.load_model(name=model_config.encoder_path, device='cpu').encoder 43 | encoder.extract_variable_length_features = types.MethodType(extract_variable_length_features, encoder) 44 | return encoder 45 | 46 | 47 | class BEATsEncoder: 48 | 49 | @classmethod 50 | def load(cls, model_config): 51 | from .BEATs.BEATs import BEATs, BEATsConfig 52 | checkpoint = torch.load(model_config.encoder_path) 53 | cfg = BEATsConfig(checkpoint['cfg']) 54 | BEATs_model = BEATs(cfg) 55 | BEATs_model.load_state_dict(checkpoint['model']) 56 | 57 | return BEATs_model 58 | 59 | 60 | @dataclass 61 | class UserDirModule: 62 | user_dir: str 63 | 64 | class EATEncoder: 65 | 66 | @classmethod 67 | def load(cls, model_config): 68 | import fairseq 69 | model_path = UserDirModule(model_config.encoder_fairseq_dir) 70 | fairseq.utils.import_user_module(model_path) 71 | EATEncoder, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path]) 72 | EATEncoder = EATEncoder[0] 73 | 74 | return EATEncoder 75 | 76 | def extract_features(self, source, padding_mask): 77 | return self.model.extract_features(source, padding_mask = padding_mask, mask=False, remove_extra_tokens = False)['x'] 78 | 79 | class CLAPEncoder: 80 | 81 | @classmethod 82 | def load(cls, model_config): 83 | from .CLAP.ase_model import ASE 84 | import ruamel.yaml as yaml 85 | with open(model_config.clap_config, 'r') as f: 86 | clap_config = yaml.safe_load(f) 87 | clap_config['pd_text_support'] = model_config.get("pd_text_support", None) 88 | model = ASE(clap_config) 89 | checkpoint = torch.load(model_config.encoder_path)['model'] 90 | 
model.load_state_dict(checkpoint) 91 | return model 92 | 93 | class SpatialASTEncoder: 94 | @classmethod 95 | def load(cls, model_config): 96 | from functools import partial 97 | from .SpatialAST import SpatialAST 98 | binaural_encoder = SpatialAST.BinauralEncoder( 99 | num_classes=355, drop_path_rate=0.1, num_cls_tokens=3, 100 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, 101 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6) 102 | ) 103 | 104 | checkpoint = torch.load(model_config.encoder_ckpt, map_location='cpu') 105 | binaural_encoder.load_state_dict(checkpoint['model'], strict=False) 106 | return binaural_encoder 107 | 108 | class WavLMEncoder(nn.Module): 109 | def __init__(self, config, model): 110 | super().__init__() 111 | self.config = config 112 | self.model = model 113 | 114 | @classmethod 115 | def load(cls, model_config): 116 | from .wavlm.WavLM import WavLM, WavLMConfig 117 | checkpoint = torch.load(model_config.encoder_path) 118 | cfg = WavLMConfig(checkpoint['cfg']) 119 | WavLM_model = WavLM(cfg) 120 | WavLM_model.load_state_dict(checkpoint['model']) 121 | assert model_config.normalize == cfg.normalize, "normalize flag in config and model checkpoint do not match" 122 | 123 | return cls(cfg, WavLM_model) 124 | 125 | def extract_features(self, source, padding_mask): 126 | return self.model.extract_features(source, padding_mask)[0] 127 | 128 | class AVHubertEncoder: 129 | 130 | @classmethod 131 | def load(cls, model_config): 132 | import fairseq 133 | from .avhubert import hubert_pretraining, hubert, hubert_asr 134 | models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path]) 135 | model = models[0] 136 | return model 137 | 138 | class HubertEncoder: 139 | 140 | @classmethod 141 | def load(cls, model_config): 142 | import fairseq 143 | models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path]) 144 | model = models[0] 145 | if model_config.encoder_type == "pretrain": 146 | pass 147 | elif model_config.encoder_type == "finetune": 148 | model.w2v_encoder.proj = None 149 | model.w2v_encoder.apply_mask = False 150 | else: 151 | assert model_config.encoder_type in ["pretrain", "finetune"], "input_type must be one of [pretrain, finetune]" 152 | return model 153 | 154 | 155 | class HfTextEncoder: 156 | 157 | @classmethod 158 | def load(cls, model_config): 159 | from transformers import AutoModel 160 | model = AutoModel.from_pretrained(model_config.encoder_path) 161 | return model 162 | 163 | class MusicFMEncoder(nn.Module): 164 | def __init__(self, config, model): 165 | super().__init__() 166 | self.config = config 167 | self.model = model 168 | 169 | @classmethod 170 | def load(cls, model_config): 171 | from .musicfm.model.musicfm_25hz import MusicFM25Hz 172 | model = MusicFM25Hz( 173 | stat_path = model_config.encoder_stat_path, 174 | model_path = model_config.encoder_path, 175 | w2v2_config_path = model_config.get('encoder_config_path', "facebook/wav2vec2-conformer-rope-large-960h-ft") 176 | ) 177 | return cls(model_config, model) 178 | 179 | def extract_features(self, source, padding_mask=None): 180 | _, hidden_states = self.model.get_predictions(source) 181 | out = hidden_states[self.config.encoder_layer_idx] 182 | return out 183 | 184 | class Emotion2vecEncoder: 185 | 186 | @classmethod 187 | def load(cls, model_config): 188 | import fairseq 189 | model_path = UserDirModule(model_config.encoder_fairseq_dir) 190 | fairseq.utils.import_user_module(model_path) 191 
| model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path]) 192 | model = model[0] 193 | 194 | return model -------------------------------------------------------------------------------- /src/slam_llm/policies/anyprecision_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | # AnyPrecisionAdamW: a flexible precision AdamW optimizer 5 | # with optional Kahan summation for high precision weight updates. 6 | # Allows direct control over momentum, variance and auxiliary compensation 7 | # buffer dtypes. 8 | # Optional Kahan summation is used to offset precision reduction for 9 | # the weight updates. This allows full training in BFloat16 (equal or 10 | # better than FP32 results in many cases) due to high precision weight upates. 11 | 12 | import torch 13 | from torch.optim.optimizer import Optimizer 14 | 15 | 16 | class AnyPrecisionAdamW(Optimizer): 17 | def __init__( 18 | self, 19 | params, 20 | lr=1e-3, 21 | betas=(0.9, 0.999), 22 | eps=1e-8, 23 | weight_decay=0.0, 24 | use_kahan_summation=False, 25 | momentum_dtype=torch.bfloat16, 26 | variance_dtype=torch.bfloat16, 27 | compensation_buffer_dtype=torch.bfloat16, 28 | ): 29 | """ 30 | Args: 31 | params (iterable): iterable of parameters to optimize or dicts defining 32 | parameter groups 33 | lr (float, optional): learning rate (default: 1e-3) 34 | betas (Tuple[float, float], optional): coefficients used for computing 35 | running averages of gradient and its square (default: (0.9, 0.999)) 36 | eps (float, optional): term added to the denominator to improve 37 | numerical stability (default: 1e-8) 38 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 39 | 40 | # Any Precision specific 41 | use_kahan_summation = creates auxiliary buffer to ensure high precision 42 | model param updates (default: False) 43 | momentum_dtype = dtype for momentum (default: BFloat32) 44 | variance_dtype = dtype for uncentered variance (default: BFloat16) 45 | compensation_buffer_dtype = dtype for Kahan summation 46 | buffer (default: BFloat16) 47 | 48 | # Usage 49 | This optimizer implements optimizer states, and Kahan summation 50 | for high precision updates, all in user controlled dtypes. 51 | Defaults are variance in BF16, Momentum in FP32. 52 | This can be run in FSDP mixed precision, amp, or full precision, 53 | depending on what training pipeline you wish to work with. 54 | 55 | Setting to use_kahan_summation = False, and changing momentum and 56 | variance dtypes to FP32, reverts this to a standard AdamW optimizer. 57 | 58 | """ 59 | defaults = dict( 60 | lr=lr, 61 | betas=betas, 62 | eps=eps, 63 | weight_decay=weight_decay, 64 | use_kahan_summation=use_kahan_summation, 65 | momentum_dtype=momentum_dtype, 66 | variance_dtype=variance_dtype, 67 | compensation_buffer_dtype=compensation_buffer_dtype, 68 | ) 69 | 70 | super().__init__(params, defaults) 71 | 72 | @torch.no_grad() 73 | def step(self, closure=None): 74 | """Performs a single optimization step. 75 | Args: 76 | closure (callable, optional): A closure that reevaluates the model 77 | and returns the loss. 78 | """ 79 | 80 | if closure is not None: 81 | with torch.enable_grad(): 82 | # to fix linter, we do not keep the returned loss for use atm. 
83 | closure() 84 | 85 | for group in self.param_groups: 86 | 87 | beta1, beta2 = group["betas"] 88 | lr = group["lr"] 89 | weight_decay = group["weight_decay"] 90 | eps = group["eps"] 91 | use_kahan_summation = group["use_kahan_summation"] 92 | 93 | momentum_dtype = group["momentum_dtype"] 94 | variance_dtype = group["variance_dtype"] 95 | compensation_buffer_dtype = group["compensation_buffer_dtype"] 96 | 97 | for p in group["params"]: 98 | if p.grad is None: 99 | continue 100 | 101 | if p.grad.is_sparse: 102 | raise RuntimeError( 103 | "AnyPrecisionAdamW does not support sparse gradients" 104 | ) 105 | 106 | state = self.state[p] 107 | 108 | # State initialization 109 | if len(state) == 0: 110 | 111 | state["step"] = torch.tensor(0.0) 112 | 113 | # momentum - EMA of gradient values 114 | state["exp_avg"] = torch.zeros_like( 115 | p, 116 | dtype=momentum_dtype, 117 | ) 118 | 119 | # variance uncentered - EMA of squared gradient values 120 | state["exp_avg_sq"] = torch.zeros_like( 121 | p, 122 | dtype=variance_dtype, 123 | ) 124 | 125 | # optional Kahan summation - accumulated error tracker 126 | if use_kahan_summation: 127 | state["compensation"] = torch.zeros_like( 128 | p, 129 | dtype=compensation_buffer_dtype, 130 | ) 131 | 132 | # main processing ------------------------- 133 | 134 | # update the steps for each param group update 135 | state["step"] += 1 136 | step = state["step"] 137 | 138 | exp_avg = state["exp_avg"] 139 | exp_avg_sq = state["exp_avg_sq"] 140 | 141 | grad = p.grad 142 | 143 | # weight decay, AdamW style 144 | if weight_decay: 145 | p.data.mul_(1 - lr * weight_decay) 146 | 147 | # update momentum 148 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 149 | 150 | # update uncentered variance 151 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 152 | 153 | # adjust using bias1 154 | bias_correction1 = 1 - beta1**step 155 | 156 | step_size = lr / bias_correction1 157 | 158 | # adjust using bias2 159 | denom_correction = (1 - beta2**step) ** 0.5 # avoids math import 160 | 161 | centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_( 162 | eps, alpha=1 163 | ) 164 | 165 | # lr update to compensation 166 | if use_kahan_summation: 167 | compensation = state["compensation"] 168 | 169 | compensation.addcdiv_(exp_avg, centered_variance, value=-step_size) 170 | 171 | # update weights with compensation (Kahan summation) 172 | # save error back to compensation for next iteration 173 | temp_buffer = p.detach().clone() 174 | p.data.add_(compensation) 175 | compensation.add_(temp_buffer.sub_(p.data)) 176 | 177 | else: 178 | # usual AdamW updates 179 | p.data.addcdiv_(exp_avg, centered_variance, value=-step_size) -------------------------------------------------------------------------------- /src/slam_llm/datasets/audio_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import random 3 | from torchaudio.transforms import Resample 4 | import json, yaml 5 | import copy 6 | 7 | import numpy as np 8 | from scipy import signal 9 | import soundfile as sf 10 | 11 | import torch 12 | import torchaudio 13 | from torch.utils.data import Dataset 14 | from slam_llm.utils.compute_utils import calculate_output_length_1d 15 | from slam_llm.models.BEATs.BEATs import BEATs 16 | from slam_llm.models.EAT.EAT import EAT_preprocess 17 | 18 | 19 | class AudioDatasetJsonl(torch.utils.data.Dataset): 20 | 21 | def __init__(self, 22 | dataset_config, 23 | tokenizer=None, 24 | split='train', 25 | ): 26 | 
super().__init__() 27 | self.dataset_config = dataset_config 28 | self.tokenizer = tokenizer 29 | # data_parallel_size = dist.get_world_size() 30 | data_parallel_size = 1 31 | 32 | # self.data_list = contents 33 | self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss 34 | self.prompt_template = "USER: {}\n ASSISTANT:" 35 | self.answer_template = "{}" 36 | self.fix_length_audio = dataset_config.fix_length_audio 37 | self.inference_mode = dataset_config.get("inference_mode", False) 38 | self.input_type = dataset_config.get("input_type", None) 39 | self.split = split 40 | self.model_name = dataset_config.get("model_name", "beats") 41 | 42 | self.data_list = [] 43 | if split == "train": 44 | with open(dataset_config.train_data_path, encoding='utf-8') as fin: 45 | for line in fin: 46 | data_dict = json.loads(line.strip()) 47 | self.data_list.append(data_dict) 48 | else: 49 | with open(dataset_config.val_data_path, encoding='utf-8') as fin: 50 | for line in fin: 51 | data_dict = json.loads(line.strip()) 52 | self.data_list.append(data_dict) 53 | 54 | # # debug 55 | # with open(dataset_config.train_data_path, encoding='utf-8') as fin: 56 | # for line in fin: 57 | # data_dict = json.loads(line.strip()) 58 | # self.data_list.append(data_dict) 59 | # if split == "train": 60 | # self.data_list = self.data_list[:80] 61 | # else: 62 | # self.data_list = self.data_list[80:100] 63 | 64 | def get_source_len(self, data_dict): 65 | return data_dict["source_len"] 66 | 67 | def get_target_len(self, data_dict): 68 | 69 | return data_dict["target_len"] if "target_len" in data_dict else 0 70 | 71 | def __len__(self): 72 | return len(self.data_list) 73 | 74 | def __getitem__(self, index): 75 | data_dict = self.data_list[index] 76 | audio_path = data_dict.get("source") 77 | target = data_dict.get("target", None) 78 | task = data_dict.get("prompt", "AAC") 79 | key = data_dict.get("key", None) 80 | 81 | # audio_raw, sample_rate = torchaudio.load(audio_path) 82 | try: 83 | audio_raw, sample_rate = torchaudio.load(audio_path) 84 | if audio_raw.shape[1] == 0: 85 | raise ValueError("Empty audio file") 86 | resampler = Resample(orig_freq=sample_rate, new_freq=16000) 87 | audio_raw = resampler(audio_raw) 88 | 89 | except (FileNotFoundError, ValueError, RuntimeError): 90 | audio_raw = torch.zeros(1, 16000) 91 | 92 | # assert sample_rate == 16e3, "Sample rate should be 16kHz, but got {} in file {}".format(sample_rate,audio_path) 93 | if self.model_name == "beats": 94 | audio_mel = BEATs.preprocess(audio_raw[0], fbank_mean=self.dataset_config.fbank_mean, fbank_std=self.dataset_config.fbank_std) 95 | elif self.model_name == "eat": 96 | audio_mel = EAT_preprocess(source=audio_raw[0],norm_mean=self.dataset_config.fbank_mean,norm_std=self.dataset_config.fbank_std, 97 | target_length=self.dataset_config.target_length,fixed_length=self.dataset_config.fixed_length,random_crop=self.dataset_config.random_crop) 98 | else: 99 | pass 100 | 101 | # prompt = "Describe the audio you hear. Output the audio caption directly without redundant content. Ensure that the output is not duplicated. " 102 | # prompt = "Describe the audio you hear. 
" 103 | prompt = self.dataset_config.prompt + ' ' 104 | 105 | 106 | prompt = self.prompt_template.format(prompt) 107 | answer = self.answer_template.format(target) 108 | 109 | prompt_ids = self.tokenizer.encode(prompt) 110 | 111 | prompt_length = len(prompt_ids) 112 | if self.model_name == "beats": 113 | audio_length = (audio_mel.shape[0] + 1) // 2 # ad-hoc for beats for 2x downsample from mel to feats 114 | 115 | elif self.model_name == "eat": 116 | audio_length = audio_mel.shape[0] // 2 + 1 # ad-hoc for eat for 2x downsample from mel to feats 117 | audio_length = audio_length // self.dataset_config.encoder_projector_ds_rate # ad-hoc for 5x fc downsample 118 | # audio_length = calculate_output_length_1d(audio_length, 5, 5, 0) # ad-hoc for 5x cov1d downsample 119 | if self.fix_length_audio > 0: 120 | audio_length = self.fix_length_audio 121 | audio_pseudo = torch.full((audio_length,), -1) # placeholder 122 | 123 | if self.inference_mode: 124 | prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) 125 | example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt] 126 | example_mask = example_ids.ge(-1) # [True,True] 127 | 128 | return { 129 | "input_ids": example_ids, 130 | "attention_mask": example_mask, 131 | "audio": audio_raw if self.input_type == "raw" else None, 132 | "audio_mel": audio_mel if self.input_type == "mel" else None, 133 | "audio_length": audio_length, 134 | "key": key, 135 | "target": target, 136 | } 137 | 138 | example = prompt + answer # FIX(MZY): avoid putting a bos token before answer. 139 | example_ids = self.tokenizer.encode(example) # [prompt,answer] 140 | example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] 141 | example_ids = torch.tensor( 142 | example_ids, dtype=torch.int64 143 | ) 144 | example_ids = torch.cat((audio_pseudo, example_ids)) # [audio,prompt,answer,eos] 145 | 146 | labels_ids = copy.deepcopy(example_ids) # [audio,prompt,answer,eos] 147 | labels_ids[:audio_length + prompt_length] = -1 # [-1,-1,answer,eos]; 148 | example_mask = example_ids.ge(-1) # FIX(GZF): [True,True,True,True] 149 | 150 | label_mask = labels_ids.ge(0) # [False,False,True,True] 151 | example_ids[~example_mask] = 0 # [audio,prompt,answer,eos] 152 | labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,answer,eos] 153 | 154 | return { 155 | "input_ids": example_ids, 156 | "labels": labels_ids, 157 | "attention_mask": example_mask, 158 | 'audio_mel': audio_mel, 159 | 'audio_length': audio_length, 160 | "target": target, 161 | } 162 | 163 | def pad(self, sequence, max_length, padding_idx=0): 164 | if isinstance(sequence, (int, list, tuple)): 165 | if len(sequence) < max_length: 166 | sequence = sequence + [padding_idx] * (max_length - len(sequence)) 167 | else: 168 | sequence = sequence[:max_length] 169 | elif isinstance(sequence, torch.Tensor): 170 | if len(sequence) < max_length: 171 | sequence = torch.cat( 172 | (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) 173 | else: 174 | sequence = sequence[:max_length] 175 | else: 176 | raise Exception("Type mismatch during padding!") 177 | return sequence 178 | 179 | def collator(self, samples): 180 | assert samples is not None 181 | input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) 182 | input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) 183 | for s in samples]) 184 | attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) 185 | for s in samples]) 186 | 187 | 
audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples]) 188 | audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0) 189 | for s in samples]) 190 | audio_mel_mask = torch.zeros(len(samples), audio_mel_max_length) 191 | for line, sample in enumerate(samples): 192 | audio_mel_mask[line, :sample['audio_mel'].shape[0]] = 1 193 | modality_mask = torch.zeros_like(attention_mask) 194 | for line, sample in enumerate(samples): 195 | modality_mask[line, :sample['audio_length']] = 1 196 | 197 | targets = [s['target'] for s in samples] 198 | if self.inference_mode: 199 | keys = [s['key'] for s in samples] 200 | 201 | return { 202 | "input_ids": input_ids, 203 | "attention_mask": attention_mask, 204 | "audio_mel": audio_mel if self.input_type == "mel" else None, 205 | "modality_mask": modality_mask, 206 | "keys": keys, 207 | "targets": targets 208 | } 209 | 210 | labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX) 211 | for s in samples]) 212 | return { 213 | 'input_ids': input_ids, 214 | 'labels': labels, 215 | 'attention_mask': attention_mask, 216 | 'audio_mel': audio_mel, 217 | 'audio_mel_mask': audio_mel_mask, 218 | 'modality_mask': modality_mask 219 | } 220 | 221 | 222 | 223 | def get_audio_dataset(dataset_config, tokenizer, split): 224 | dataset = AudioDatasetJsonl(dataset_config, tokenizer, split) 225 | 226 | return dataset -------------------------------------------------------------------------------- /src/slam_llm/utils/custom_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import cv2 8 | import torch 9 | import random 10 | import numpy as np 11 | from typing import Dict, List, Optional, Tuple 12 | 13 | def load_video(path): 14 | for i in range(3): 15 | try: 16 | cap = cv2.VideoCapture(path) 17 | frames = [] 18 | while True: 19 | ret, frame = cap.read() 20 | if ret: 21 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) 22 | frames.append(frame) 23 | else: 24 | break 25 | frames = np.stack(frames) 26 | return frames 27 | except Exception: 28 | print(f"failed loading {path} ({i} / 3)") 29 | if i == 2: 30 | raise ValueError(f"Unable to load {path}") 31 | 32 | 33 | class Compose(object): 34 | """Compose several preprocess together. 35 | Args: 36 | preprocess (list of ``Preprocess`` objects): list of preprocess to compose. 37 | """ 38 | 39 | def __init__(self, preprocess): 40 | self.preprocess = preprocess 41 | 42 | def __call__(self, sample): 43 | for t in self.preprocess: 44 | sample = t(sample) 45 | return sample 46 | 47 | def __repr__(self): 48 | format_string = self.__class__.__name__ + '(' 49 | for t in self.preprocess: 50 | format_string += '\n' 51 | format_string += ' {0}'.format(t) 52 | format_string += '\n)' 53 | return format_string 54 | 55 | 56 | class Normalize(object): 57 | """Normalize a ndarray image with mean and standard deviation. 58 | """ 59 | 60 | def __init__(self, mean, std): 61 | self.mean = mean 62 | self.std = std 63 | 64 | def __call__(self, frames): 65 | """ 66 | Args: 67 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 68 | Returns: 69 | Tensor: Normalized Tensor image. 
70 | """ 71 | frames = (frames - self.mean) / self.std 72 | return frames 73 | 74 | def __repr__(self): 75 | return self.__class__.__name__+'(mean={0}, std={1})'.format(self.mean, self.std) 76 | 77 | class CenterCrop(object): 78 | """Crop the given image at the center 79 | """ 80 | def __init__(self, size): 81 | self.size = size 82 | 83 | def __call__(self, frames): 84 | """ 85 | Args: 86 | img (numpy.ndarray): Images to be cropped. 87 | Returns: 88 | numpy.ndarray: Cropped image. 89 | """ 90 | t, h, w = frames.shape 91 | th, tw = self.size 92 | delta_w = int(round((w - tw))/2.) 93 | delta_h = int(round((h - th))/2.) 94 | frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw] 95 | return frames 96 | 97 | 98 | class RandomCrop(object): 99 | """Crop the given image at the center 100 | """ 101 | 102 | def __init__(self, size): 103 | self.size = size 104 | 105 | def __call__(self, frames): 106 | """ 107 | Args: 108 | img (numpy.ndarray): Images to be cropped. 109 | Returns: 110 | numpy.ndarray: Cropped image. 111 | """ 112 | t, h, w = frames.shape 113 | th, tw = self.size 114 | delta_w = random.randint(0, w-tw) 115 | delta_h = random.randint(0, h-th) 116 | frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw] 117 | return frames 118 | 119 | def __repr__(self): 120 | return self.__class__.__name__ + '(size={0})'.format(self.size) 121 | 122 | class HorizontalFlip(object): 123 | """Flip image horizontally. 124 | """ 125 | 126 | def __init__(self, flip_ratio): 127 | self.flip_ratio = flip_ratio 128 | 129 | def __call__(self, frames): 130 | """ 131 | Args: 132 | img (numpy.ndarray): Images to be flipped with a probability flip_ratio 133 | Returns: 134 | numpy.ndarray: Cropped image. 135 | """ 136 | t, h, w = frames.shape 137 | if random.random() < self.flip_ratio: 138 | for index in range(t): 139 | frames[index] = cv2.flip(frames[index], 1) 140 | return frames 141 | 142 | def compute_mask_indices( 143 | shape: Tuple[int, int], 144 | padding_mask: Optional[torch.Tensor], 145 | mask_prob: float, 146 | mask_length: int, 147 | mask_type: str = "static", 148 | mask_other: float = 0.0, 149 | min_masks: int = 0, 150 | no_overlap: bool = False, 151 | min_space: int = 0, 152 | ) -> np.ndarray: 153 | """ 154 | Computes random mask spans for a given shape 155 | Args: 156 | shape: the the shape for which to compute masks. 157 | should be of size 2 where first element is batch size and 2nd is timesteps 158 | padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements 159 | mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by 160 | number of timesteps divided by length of mask span to mask approximately this percentage of all elements. 161 | however due to overlaps, the actual number will be smaller (unless no_overlap is True) 162 | mask_type: how to compute mask lengths 163 | static = fixed size 164 | uniform = sample from uniform distribution [mask_other, mask_length*2] 165 | normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element 166 | poisson = sample from possion distribution with lambda = mask length 167 | min_masks: minimum number of masked spans 168 | no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping 169 | min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans 170 | """ 171 | 172 | bsz, all_sz = shape 173 | mask = np.full((bsz, all_sz), False) 174 | 175 | all_num_mask = int( 176 | # add a random number for probabilistic rounding 177 | mask_prob * all_sz / float(mask_length) 178 | + np.random.rand() 179 | ) 180 | 181 | all_num_mask = max(min_masks, all_num_mask) 182 | 183 | mask_idcs = [] 184 | for i in range(bsz): 185 | if padding_mask is not None: 186 | sz = all_sz - padding_mask[i].long().sum().item() 187 | num_mask = int( 188 | # add a random number for probabilistic rounding 189 | mask_prob * sz / float(mask_length) 190 | + np.random.rand() 191 | ) 192 | num_mask = max(min_masks, num_mask) 193 | else: 194 | sz = all_sz 195 | num_mask = all_num_mask 196 | 197 | if mask_type == "static": 198 | lengths = np.full(num_mask, mask_length) 199 | elif mask_type == "uniform": 200 | lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) 201 | elif mask_type == "normal": 202 | lengths = np.random.normal(mask_length, mask_other, size=num_mask) 203 | lengths = [max(1, int(round(x))) for x in lengths] 204 | elif mask_type == "poisson": 205 | lengths = np.random.poisson(mask_length, size=num_mask) 206 | lengths = [int(round(x)) for x in lengths] 207 | else: 208 | raise Exception("unknown mask selection " + mask_type) 209 | 210 | if sum(lengths) == 0: 211 | lengths[0] = min(mask_length, sz - 1) 212 | 213 | if no_overlap: 214 | mask_idc = [] 215 | 216 | def arrange(s, e, length, keep_length): 217 | span_start = np.random.randint(s, e - length) 218 | mask_idc.extend(span_start + i for i in range(length)) 219 | 220 | new_parts = [] 221 | if span_start - s - min_space >= keep_length: 222 | new_parts.append((s, span_start - min_space + 1)) 223 | if e - span_start - keep_length - min_space > keep_length: 224 | new_parts.append((span_start + length + min_space, e)) 225 | return new_parts 226 | 227 | parts = [(0, sz)] 228 | min_length = min(lengths) 229 | for length in sorted(lengths, reverse=True): 230 | lens = np.fromiter( 231 | (e - s if e - s >= length + min_space else 0 for s, e in parts), 232 | np.int, 233 | ) 234 | l_sum = np.sum(lens) 235 | if l_sum == 0: 236 | break 237 | probs = lens / np.sum(lens) 238 | c = np.random.choice(len(parts), p=probs) 239 | s, e = parts.pop(c) 240 | parts.extend(arrange(s, e, length, min_length)) 241 | mask_idc = np.asarray(mask_idc) 242 | else: 243 | min_len = min(lengths) 244 | if sz - min_len <= num_mask: 245 | min_len = sz - num_mask - 1 246 | 247 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) 248 | 249 | mask_idc = np.asarray( 250 | [ 251 | mask_idc[j] + offset 252 | for j in range(len(mask_idc)) 253 | for offset in range(lengths[j]) 254 | ] 255 | ) 256 | 257 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 258 | 259 | min_len = min([len(m) for m in mask_idcs]) 260 | batch_indexes, starts, ends = [], [], [] 261 | for i, mask_idc in enumerate(mask_idcs): 262 | if len(mask_idc) > min_len: 263 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 264 | mask[i, mask_idc] = True 265 | vals, run_starts, run_lengths = find_runs(mask[i]) 266 | start_indices, lengths = run_starts[vals == True], run_lengths[vals == 
True] 267 | starts.append(start_indices) 268 | ends.append(start_indices+lengths) 269 | batch_indexes.append(np.zeros([len(start_indices)])+i) 270 | return mask, np.concatenate(starts).astype(np.int64), np.concatenate(ends).astype(np.int64), np.concatenate(batch_indexes).astype(np.int64) 271 | 272 | def find_runs(x): 273 | """Find runs of consecutive items in an array.""" 274 | 275 | # ensure array 276 | x = np.asanyarray(x) 277 | if x.ndim != 1: 278 | raise ValueError('only 1D array supported') 279 | n = x.shape[0] 280 | 281 | # handle empty array 282 | if n == 0: 283 | return np.array([]), np.array([]), np.array([]) 284 | 285 | else: 286 | # find run starts 287 | loc_run_start = np.empty(n, dtype=bool) 288 | loc_run_start[0] = True 289 | np.not_equal(x[:-1], x[1:], out=loc_run_start[1:]) 290 | run_starts = np.nonzero(loc_run_start)[0] 291 | 292 | # find run values 293 | run_values = x[loc_run_start] 294 | 295 | # find run lengths 296 | run_lengths = np.diff(np.append(run_starts, n)) 297 | 298 | return run_values, run_starts, run_lengths 299 | -------------------------------------------------------------------------------- /evaluation/test_metric.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from evaluate import load 4 | from transformers.models.whisper.english_normalizer import BasicTextNormalizer 5 | import sacrebleu 6 | import pandas as pd 7 | from sacrebleu.metrics import BLEU, CHRF, TER 8 | import os 9 | import csv 10 | from comet import download_model, load_from_checkpoint 11 | import os 12 | from evaluate import load 13 | from transformers.models.whisper.english_normalizer import BasicTextNormalizer 14 | normalizer = BasicTextNormalizer() 15 | import torch 16 | import openpyxl 17 | from comet import download_model, load_from_checkpoint 18 | import torch 19 | import argparse 20 | import nltk 21 | from nltk.translate import meteor_score 22 | from nltk import word_tokenize 23 | 24 | device_num = 0 25 | 26 | src_langs = ['ara','arz','cmn','zsm', 'ben', 'ces', 'deu', 'eng', 'fas', 'fra', 'heb', 'hin', 'ind', 'ita', 'jpn', 'khm', 'kor', 'lao', 'msa', 'mya', 'nld', 'pol', 'por', 'rus', 'spa', 'tha', 'tgl', 'tur', 'urd', 'vie', 'zho','yue','oci','mon','khk','yue'] 27 | tgt_langs = ['ara','zsm', 'ben', 'ces', 'deu', 'eng', 'fas', 'fra', 'heb', 'hin', 'ind', 'ita', 'jpn', 'khm', 'kor', 'lao', 'msa', 'mya', 'nld', 'pol', 'por', 'rus', 'spa', 'tha', 'tgl', 'tur', 'urd', 'vie', 'zho','yue'] 28 | 29 | file_path = "../fleurs_enzh_srt_Qwen2.5-3B.jsonl" 30 | # file_path = "../test_srt_Qwen2.5-3B.jsonl" 31 | 32 | 33 | 34 | test_metrics_all = ["idx","iso3","iso2","resource","bleu","spbleu","comet","meteor","xcomet","cometwiki","wer","cer"] 35 | 36 | test_metrics = ["idx","iso3","iso2","bleu","spbleu","comet","xcomet","cometwiki"] 37 | 38 | test_metrics = ["idx","iso3","iso2","bleu","spbleu","comet","meteor"] 39 | test_metrics = ["idx","iso3","iso2","bleu","spbleu","comet"] 40 | 41 | 42 | 43 | ISO3_TO_ISO2_MAPPING = { 44 | 'ara': 'ar', # Arabic 45 | 'arz': 'ar', # Arabic (Egypt) 46 | 'ben': 'bn', # Bengali 47 | 'ces': 'cs', # Czech 48 | 'deu': 'de', # German 49 | 'eng': 'en', # English 50 | 'spa': 'es', # Spanish 51 | 'fas': 'fa', # Persian 52 | 'pes': 'fa', # Persian 53 | 'fra': 'fr', # French 54 | 'heb': 'he', # Hebrew 55 | 'hin': 'hi', # Hindi 56 | 'ind': 'id', # Indonesian 57 | 'ita': 'it', # Italian 58 | 'jpn': 'ja', # Japanese 59 | 'khm': 'km', # Khmer 60 | 'kor': 'ko', # Korean 61 | 'lao': 'lo', # 
Lao 62 | 'msa': 'ms', # Malay 63 | 'zsm': 'ms', # Malay 64 | 'mya': 'my', # Burmese 65 | 'nld': 'nl', # Dutch 66 | 'pol': 'pl', # Polish 67 | 'por': 'pt', # Portuguese 68 | 'rus': 'ru', # Russian 69 | 'tha': 'th', # Thai 70 | 'tgl': 'tl', # Tagalog 71 | 'tur': 'tr', # Turkish 72 | 'urd': 'ur', # Urdu 73 | 'vie': 'vi', # Vietnamese 74 | 'zho': 'zh', # Chinese 75 | 'cmn': 'zh', # Mandarin Chinese 76 | 'yue': 'ye', # Cantonese 77 | 'ceb': 'ce', # Cebuan 78 | 'oci': 'oc', # Occitan 79 | 'mon': 'mn', # Mongolian 80 | 'khk': 'mn', # Mongolian (Khalkha) 81 | } 82 | 83 | 84 | ISO3_TO_LOW = { 85 | 'ara': '1', # Arabic 86 | 'arz': '1', # Arabic (Egypt) 87 | 'ben': '2', # Bengali 88 | 'ces': '1', # Czech 89 | 'deu': '1', # German 90 | 'eng': '1', # English 91 | 'spa': '1', # Spanish 92 | 'fas': '1', # Persian 93 | 'pes': '1', # Persian 94 | 'fra': '1', # French 95 | 'heb': '2', # Hebrew 96 | 'hin': '1', # Hindi 97 | 'ind': '2', # Indonesian 98 | 'ita': '1', # Italian 99 | 'jpn': '1', # Japanese 100 | 'khm': '3', # Khmer 101 | 'kor': '1', # Korean 102 | 'lao': '3', # Lao 103 | 'msa': '2', # Malay 104 | 'zsm': '2', # Malay 105 | 'mya': '3', # Burmese 106 | 'nld': '1', # Dutch 107 | 'pol': '1', # Polish 108 | 'por': '1', # Portuguese 109 | 'rus': '1', # Russian 110 | 'tha': '2', # Thai 111 | 'tgl': '2', # Tagalog 112 | 'tur': '1', # Turkish 113 | 'urd': '2', # Urdu 114 | 'vie': '1', # Vietnamese 115 | 'zho': '1', # Chinese 116 | 'cmn': '1', # Mandarin Chinese 117 | 'yue': '1', # Cantonese 118 | 'ceb': '1', # Cebuan 119 | 'oci': '3', # Occitan 120 | 'mon': '3', # Mongolian 121 | 'khk': '3', # Mongolian (Khalkha) 122 | } 123 | 124 | 125 | if "wer" in test_metrics: 126 | wer = load("wer") 127 | 128 | if "cer" in test_metrics: 129 | cer = load("cer") 130 | 131 | if "comet" in test_metrics: 132 | comet_model_path = download_model("Unbabel/wmt22-comet-da") 133 | comet_model = load_from_checkpoint(comet_model_path).half() 134 | 135 | if "xcomet" in test_metrics: 136 | xmodel_path = download_model("Unbabel/XCOMET-XXL") 137 | xcomet_model = load_from_checkpoint(xmodel_path).half() 138 | 139 | if "cometwiki" in test_metrics: 140 | cometwikiz_model_path = download_model("Unbabel/wmt23-cometkiwi-da-xxl") 141 | cometwiki_model = load_from_checkpoint(cometwikiz_model_path).half() 142 | 143 | 144 | lang_groups = defaultdict(lambda: defaultdict(lambda: {"asr_gt": [], "asr_re": [], "st_gt": [], "st_re": []})) 145 | 146 | count = 0 147 | with open(file_path, 'r', encoding='utf-8') as file: 148 | for line in file: 149 | data = json.loads(line.strip()) 150 | gt = data.get("gt", "") 151 | prompt = data.get("prompt","") 152 | response = data.get("response", "") 153 | 154 | 155 | src_lang = gt.split("|>")[0].split("<|")[-1] 156 | 157 | tgt_lang = gt.split("<|")[-1].split("|>")[0] 158 | 159 | if src_lang == tgt_lang or src_lang not in src_langs or tgt_lang not in tgt_langs: 160 | continue 161 | 162 | prompt = f"<|{src_lang}|><|{tgt_lang}|>" 163 | split_responses = response.split(prompt) 164 | if len(split_responses) == 2: 165 | asr_re, st_re = split_responses 166 | else: 167 | # continue 168 | print(count,response) 169 | print(count,gt) 170 | count +=1 171 | 172 | asr_re = response 173 | st_re = response.split("|>")[-1] if "|>" in response else response 174 | if len(st_re)==0: 175 | st_re = response 176 | 177 | lang_groups[src_lang][tgt_lang]["asr_gt"].append(gt.split(prompt)[0]) 178 | lang_groups[src_lang][tgt_lang]["asr_re"].append(asr_re) 179 | lang_groups[src_lang][tgt_lang]["st_gt"].append(gt.split(prompt)[1]) 180 | 
lang_groups[src_lang][tgt_lang]["st_re"].append(st_re) 181 | 182 | results = {} 183 | idx = 1 184 | for src_lang in sorted(lang_groups.keys()): 185 | tgt_lang_data = lang_groups[src_lang] 186 | for tgt_lang in sorted(tgt_lang_data.keys()): 187 | data = tgt_lang_data[tgt_lang] 188 | 189 | iso3 = f"{src_lang}_{tgt_lang}" 190 | iso2 = f"{ISO3_TO_ISO2_MAPPING[src_lang]}_{ISO3_TO_ISO2_MAPPING[tgt_lang]}" 191 | sources = [s.strip() for s in data["asr_gt"]] 192 | asr_predictions = [a.strip() for a in data["asr_re"]] 193 | predictions = [p.strip() for p in data["st_re"]] 194 | references = [r.strip() for r in data["st_gt"]] 195 | 196 | cer_lang = ["tha", "jpn", "kor", "zho", "yue", "cmn","lao","mya"] 197 | 198 | if src_lang in cer_lang and "cer" in test_metrics: 199 | normalized_predictions = [normalizer(pred) for pred in asr_predictions] 200 | normalized_references = [normalizer(ref) for ref in sources] 201 | cer_score = cer.compute(predictions=normalized_predictions, references=normalized_references)*100 202 | else: 203 | cer_score = 0 204 | 205 | if src_lang not in cer_lang and "wer" in test_metrics: 206 | normalized_predictions = [normalizer(pred) for pred in asr_predictions] 207 | normalized_references = [normalizer(ref) for ref in sources] 208 | wer_score = wer.compute(predictions=normalized_predictions, references=normalized_references)*100 209 | else: 210 | wer_score = 0 211 | 212 | if "bleu" in test_metrics: 213 | tokenize_method = "char" if tgt_lang in ["tha", "jpn", "kor", "zho", "yue", "cmn","lao","mya"] else "13a" 214 | bleu = BLEU(tokenize=tokenize_method) 215 | bleu_score = bleu.corpus_score(predictions, [references]).score 216 | else: 217 | bleu_score = 0 218 | 219 | if "spbleu" in test_metrics: 220 | spbleu = BLEU(tokenize='flores200') 221 | spbleu_score = spbleu.corpus_score(predictions, [references]).score 222 | else: 223 | spbleu_score = 0 224 | 225 | if "comet" in test_metrics: 226 | comet_data = [{'src': s, 'mt': p, 'ref': r} for s, p, r in zip(sources, predictions, references)] 227 | comet_score = comet_model.predict(comet_data, batch_size=512,devices=[device_num])['system_score']*100 228 | else: 229 | comet_score = 0 230 | 231 | if "xcomet" in test_metrics: 232 | comet_data = [{'src': s, 'mt': p} for s, p in zip(sources, predictions)] 233 | xcomet_score = xcomet_model.predict(comet_data, batch_size=16,devices=[device_num])['system_score']*100 234 | else: 235 | xcomet_score = 0 236 | 237 | if "cometwiki" in test_metrics: 238 | comet_data = [{'src': s, 'mt': p} for s, p in zip(sources, predictions)] 239 | cometwiki_score = cometwiki_model.predict(comet_data, batch_size=16,devices=[device_num])['system_score']*100 240 | else: 241 | cometwiki_score = 0 242 | 243 | if "meteor" in test_metrics: 244 | meteor_lang = { 245 | "tha": "th", 246 | "jpn": "ja", 247 | "kor": "ko", 248 | "zho": "zh", 249 | "yue": "zh", 250 | "cmn": "zh", 251 | "lao": "lo", 252 | "mya": "my", 253 | }.get(tgt_lang, "english") 254 | 255 | 256 | 257 | meteor_scores = [] 258 | for pred, ref in zip(predictions, references): 259 | pred_tokens = word_tokenize(pred, language=meteor_lang) 260 | ref_tokens = word_tokenize(ref, language=meteor_lang) 261 | 262 | score = meteor_score.meteor_score([ref_tokens], pred_tokens) * 100 263 | meteor_scores.append(score) 264 | 265 | meteor_value = sum(meteor_scores) / len(meteor_scores) 266 | else: 267 | meteor_value = 0 268 | results[idx] = { 269 | "idx": idx, 270 | "iso3":iso3, 271 | "iso2":iso2, 272 | "resource":ISO3_TO_LOW[tgt_lang], 273 | "bleu": round(bleu_score, 2), 274 | 
"spbleu": round(spbleu_score, 2), 275 | "comet": round(comet_score, 2), 276 | "meteor": round(meteor_value , 2), 277 | "xcomet": round(xcomet_score, 2), 278 | "cometwiki": round(cometwiki_score, 2), 279 | "wer": round(wer_score, 2), 280 | "cer": round(cer_score, 2), 281 | } 282 | print(results[idx]) 283 | idx +=1 284 | 285 | 286 | output_xlsx = file_path.split("/")[-1].split(".")[0] + ".xlsx" 287 | wb = openpyxl.Workbook() 288 | ws = wb.active 289 | ws.append(test_metrics_all) 290 | 291 | for key, scores in results.items(): 292 | ws.append([scores["idx"],scores["iso3"],scores["iso2"],scores["resource"], scores["bleu"], scores["spbleu"], scores["comet"], scores["meteor"],scores["xcomet"],scores["cometwiki"],scores["wer"],scores["cer"]]) 293 | 294 | wb.save(output_xlsx) 295 | print(f"result saved in {output_xlsx}") -------------------------------------------------------------------------------- /examples/st_covost2/inference_asr_batch.py: -------------------------------------------------------------------------------- 1 | 2 | import hydra 3 | import logging 4 | from dataclasses import dataclass, field 5 | from omegaconf import DictConfig, ListConfig, OmegaConf 6 | from typing import Optional 7 | from asr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig 8 | # import fire 9 | import random 10 | import torch 11 | import logging 12 | import sacrebleu 13 | # import argparse 14 | import itertools 15 | import json 16 | import time 17 | from slam_llm.models.slam_model import slam_model 18 | 19 | 20 | 21 | 22 | 23 | # config 24 | # from llama_recipes.configs import fsdp_config as FSDP_CONFIG 25 | # from llama_recipes.configs import train_config as TRAIN_CONFIG 26 | # from llama_recipes.configs import model_config as MODEL_CONFIG 27 | # from llama_recipes.configs import log_config as LOG_CONFIG 28 | from slam_llm.utils.train_utils import ( 29 | train, 30 | freeze_transformer_layers, 31 | setup, 32 | setup_environ_flags, 33 | clear_gpu_cache, 34 | get_policies 35 | ) 36 | from slam_llm.utils.model_utils import get_custom_model_factory 37 | from slam_llm.utils.dataset_utils import get_preprocessed_dataset 38 | import os 39 | import logging 40 | from tqdm import tqdm 41 | from model.slam_model_st import model_factory 42 | from transformers import AutoTokenizer,AutoConfig,AutoModel 43 | 44 | import hydra 45 | from omegaconf import DictConfig, ListConfig, OmegaConf 46 | 47 | from slam_llm.utils.model_utils import get_custom_model_factory 48 | 49 | class InferenceSampler(torch.utils.data.sampler.Sampler): 50 | 51 | def __init__(self, size): 52 | self._size = int(size) 53 | assert size > 0 54 | self._rank = torch.distributed.get_rank() 55 | self._world_size = torch.distributed.get_world_size() 56 | self._local_indices = self._get_local_indices(size, self._world_size, 57 | self._rank) 58 | 59 | @staticmethod 60 | def _get_local_indices(total_size, world_size, rank): 61 | shard_size = total_size // world_size 62 | left = total_size % world_size 63 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 64 | 65 | begin = sum(shard_sizes[:rank]) 66 | end = min(sum(shard_sizes[:rank + 1]), total_size) 67 | return range(begin, end) 68 | 69 | def __iter__(self): 70 | yield from self._local_indices 71 | 72 | def __len__(self): 73 | return len(self._local_indices) 74 | 75 | def Inference(kwargs: DictConfig): 76 | 77 | # Update the configuration for the training and sharding process 78 | train_config, fsdp_config, model_config, log_config, dataset_config,ckpt_path = 
kwargs.train_config, \ 79 | kwargs.fsdp_config, \ 80 | kwargs.model_config, \ 81 | kwargs.log_config, \ 82 | kwargs.dataset_config, \ 83 | kwargs.ckpt_path 84 | 85 | OmegaConf.set_struct(kwargs,False) 86 | del kwargs["train_config"] 87 | del kwargs["fsdp_config"] 88 | del kwargs["model_config"] 89 | del kwargs["log_config"] 90 | del kwargs["dataset_config"] 91 | OmegaConf.set_struct(kwargs,True) 92 | 93 | 94 | # Set log 95 | if not os.path.exists(os.path.dirname(log_config.log_file)): 96 | os.makedirs(os.path.dirname(log_config.log_file), exist_ok=True) 97 | logging.basicConfig( 98 | level=logging.INFO, 99 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 100 | datefmt="%Y-%m-%d %H:%M:%S", 101 | filemode='w' 102 | ) 103 | 104 | logger = logging.getLogger() 105 | logger.setLevel(logging.INFO) 106 | 107 | file_handler = logging.FileHandler(filename=log_config.log_file, mode='w') 108 | file_handler.setLevel(logging.INFO) 109 | file_formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 110 | file_handler.setFormatter(file_formatter) 111 | 112 | logger.handlers[0].setLevel(logging.INFO) 113 | console_formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 114 | logger.handlers[0].setFormatter(console_formatter) 115 | 116 | logger.addHandler(file_handler) 117 | 118 | 119 | 120 | # Set the seeds for reproducibility 121 | torch.cuda.manual_seed(train_config.seed) 122 | torch.manual_seed(train_config.seed) 123 | random.seed(train_config.seed) 124 | 125 | 126 | 127 | 128 | if train_config.enable_fsdp or train_config.enable_ddp: 129 | setup() 130 | local_rank = int(os.environ["LOCAL_RANK"]) 131 | rank = int(os.environ["RANK"]) 132 | world_size = int(os.environ["WORLD_SIZE"]) 133 | else: 134 | local_rank = 0 135 | rank = 0 136 | world_size = 1 137 | print("local_rank: ",local_rank) 138 | print("rank: ",rank) 139 | print("world_size: ",world_size) 140 | 141 | 142 | if torch.distributed.is_initialized(): 143 | torch.cuda.set_device(local_rank) 144 | clear_gpu_cache(local_rank) 145 | setup_environ_flags(rank) 146 | 147 | if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: 148 | logger.info("train_config: {}".format(train_config)) 149 | logger.info("fsdp_config: {}".format(fsdp_config)) 150 | logger.info("model_config: {}".format(model_config)) 151 | logger.info("log_config: {}".format(log_config)) 152 | 153 | model_factory = get_custom_model_factory(model_config, logger) 154 | model, tokenizer = model_factory(train_config, model_config, **kwargs) 155 | 156 | 157 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # FIX(MZY): put the whole model to device. 
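    # Half precision below is purely to cut inference memory; this script casts to
    # fp16 and leaves bf16 commented out. A hedged alternative (an assumption, not
    # part of the original script) would be to pick the dtype from hardware support:
    #     infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    #     model.to(infer_dtype)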
158 | # model.to(torch.bfloat16) 159 | model.to(torch.float16) 160 | 161 | dataset_config["fp16"]=True 162 | model.to(device) 163 | model.eval() 164 | tokenizer.padding_side = 'left' 165 | 166 | 167 | 168 | 169 | dataset_test = get_preprocessed_dataset( 170 | tokenizer, 171 | dataset_config, 172 | split="test", 173 | ) 174 | if world_size > 1: 175 | test_sampler = InferenceSampler(len(dataset_test)) 176 | else: 177 | from torch.utils.data import SequentialSampler 178 | test_sampler = SequentialSampler(dataset_test) 179 | 180 | test_dataloader = torch.utils.data.DataLoader( 181 | dataset_test, 182 | sampler=test_sampler, 183 | num_workers=train_config.num_workers_dataloader, 184 | pin_memory=True, 185 | shuffle=False, 186 | batch_size=train_config.val_batch_size, 187 | drop_last=False, 188 | prefetch_factor=10, 189 | persistent_workers=False, 190 | collate_fn=dataset_test.collator 191 | ) 192 | 193 | gts = [] 194 | sources = [] 195 | rets = [] 196 | audio_paths = [] 197 | prompts = [] 198 | 199 | for step, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)): 200 | 201 | for key in batch.keys(): 202 | batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key] 203 | 204 | model_outputs = model.generate(**batch) 205 | 206 | # print(model_outputs) 207 | output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False, skip_special_tokens=True) 208 | 209 | for key, audio_path ,prompt,text, target in zip(batch["keys"],batch["audio_paths"],batch["prompts"], output_text, batch["targets"]): 210 | # print("Prediction: ",key,text) 211 | # print("Ground Truth:",key,target) 212 | print(key,"pred: ",text) 213 | print(key,"gold: ",target) 214 | 215 | source = "eng" 216 | 217 | audio_paths.append(audio_path) 218 | rets.append(text) 219 | gts.append(target) 220 | sources.append(source) 221 | prompts.append(prompt) 222 | 223 | if world_size > 1: 224 | torch.distributed.barrier() 225 | merged_gts = [None for _ in range(world_size)] 226 | torch.distributed.all_gather_object(merged_gts, gts) 227 | merged_gts = [None for _ in range(world_size)] 228 | merged_sources = [None for _ in range(world_size)] 229 | merged_responses = [None for _ in range(world_size)] 230 | merged_audio_paths = [None for _ in range(world_size)] 231 | merged_prompts = [None for _ in range(world_size)] 232 | torch.distributed.all_gather_object(merged_gts, gts) 233 | torch.distributed.all_gather_object(merged_sources, sources) 234 | torch.distributed.all_gather_object(merged_responses, rets) 235 | torch.distributed.all_gather_object(merged_audio_paths, audio_paths) 236 | torch.distributed.all_gather_object(merged_prompts, prompts) 237 | 238 | merged_gts = [_ for _ in itertools.chain.from_iterable(merged_gts)] 239 | merged_sources = [_ for _ in itertools.chain.from_iterable(merged_sources)] 240 | merged_responses = [_ for _ in itertools.chain.from_iterable(merged_responses)] 241 | merged_audio_paths = [_ for _ in itertools.chain.from_iterable(merged_audio_paths)] 242 | merged_prompts = [_ for _ in itertools.chain.from_iterable(merged_prompts)] 243 | else: 244 | merged_gts = gts 245 | merged_responses = rets 246 | merged_sources = sources 247 | merged_audio_paths = audio_paths 248 | merged_prompts = prompts 249 | 250 | 251 | 252 | 253 | if world_size > 1: 254 | if torch.distributed.get_rank() == 0: 255 | results_file = log_config.decode_log 256 | with open(results_file, 'w') as f: 257 | for gt, response, source, audio_path, prompt in zip( 258 | merged_gts, merged_responses, 
merged_sources, merged_audio_paths, merged_prompts 259 | ): 260 | result = { 261 | 'gt': gt, 262 | 'response': response, 263 | 'source': source, 264 | "audio_path": audio_path, 265 | "prompt": prompt, 266 | } 267 | f.write(json.dumps(result, ensure_ascii=False) + '\n') 268 | print(f"Results saved to: {results_file}") 269 | torch.distributed.barrier() 270 | else: 271 | results_file = log_config.decode_log 272 | with open(results_file, 'w') as f: 273 | for gt, response, source, audio_path, prompt in zip( 274 | merged_gts, merged_responses, merged_sources, merged_audio_paths, merged_prompts 275 | ): 276 | result = { 277 | 'gt': gt, 278 | 'response': response, 279 | 'source': source, 280 | "audio_path": audio_path, 281 | "prompt": prompt, 282 | } 283 | f.write(json.dumps(result, ensure_ascii=False) + '\n') 284 | print(f"Results saved to: {results_file}") 285 | 286 | 287 | @dataclass 288 | class RunConfig: 289 | dataset_config: DataConfig = field(default_factory=DataConfig) 290 | model_config: ModelConfig = field(default_factory=ModelConfig) 291 | train_config: TrainConfig = field(default_factory=TrainConfig) 292 | log_config: LogConfig = field(default_factory=LogConfig) 293 | fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) 294 | debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) 295 | metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) 296 | decode_log: str = field( 297 | default="output/decode_log", 298 | metadata={"help": "The prefix for the decode output"}, 299 | ) 300 | ckpt_path: str = field( 301 | default="output/model.pt", metadata={"help": "The path to projector checkpoint"} 302 | ) 303 | peft_ckpt: Optional[str] = field( 304 | default=None, 305 | metadata={ 306 | "help": "The path to peft checkpoint, should be a directory including adapter_config.json" 307 | }, 308 | ) 309 | 310 | 311 | @hydra.main(config_name=None, version_base=None) 312 | def main_hydra(cfg: DictConfig): 313 | run_config = RunConfig() 314 | cfg = OmegaConf.merge(run_config, cfg) 315 | # kwargs = to_plain_list(cfg) 316 | log_level = getattr(logging, cfg.get("log_level", "INFO").upper()) 317 | 318 | logging.basicConfig(level=log_level) 319 | 320 | if cfg.get("debug", False): 321 | import pdb 322 | 323 | pdb.set_trace() 324 | 325 | Inference(cfg) 326 | 327 | 328 | if __name__ == "__main__": 329 | main_hydra() 330 | -------------------------------------------------------------------------------- /src/slam_llm/utils/checkpoint_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | import os 4 | from pathlib import Path 5 | from datetime import datetime 6 | import torch 7 | import time 8 | from collections import OrderedDict 9 | 10 | from torch.distributed.fsdp import ( 11 | FullyShardedDataParallel as FSDP, 12 | StateDictType, 13 | FullStateDictConfig, # general model non-sharded, non-flattened params 14 | LocalStateDictConfig, # flattened params, usable only by FSDP 15 | # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. 
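    # (FullStateDictConfig is instantiated below as fullstate_save_policy with
    #  offload_to_cpu=True and rank0_only=True, so FULL_STATE_DICT saves gather the
    #  whole, unflattened model to CPU on rank 0 only; LocalStateDictConfig keeps
    #  the FSDP-flattened shards instead.)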
16 | ) 17 | 18 | from torch.distributed._shard.checkpoint import ( 19 | FileSystemReader, 20 | FileSystemWriter, 21 | save_state_dict, 22 | load_state_dict, 23 | ) 24 | from torch.distributed.checkpoint.default_planner import ( 25 | DefaultSavePlanner, 26 | DefaultLoadPlanner, 27 | ) 28 | 29 | 30 | from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType 31 | import torch.distributed._shard.checkpoint as dist_cp 32 | import torch.distributed as dist 33 | 34 | 35 | import logging 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | def get_date_of_run(): 40 | """create date and time for file save uniqueness 41 | example: 2022-05-07-08:31:12_PM' 42 | """ 43 | date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p") 44 | logger.info(f"--> current date and time of run = {date_of_run}") 45 | return date_of_run 46 | 47 | 48 | # create singleton saving policies to avoid making over and over 49 | fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) 50 | 51 | 52 | def load_model_sharded(model, rank, cfg): 53 | # torch.manual_seed(103) 54 | folder_name = ( 55 | cfg.dist_checkpoint_root_folder 56 | + "/" 57 | + cfg.dist_checkpoint_folder 58 | + "-" 59 | + cfg.model_name 60 | ) 61 | 62 | load_dir = Path.cwd() / folder_name 63 | 64 | if not load_dir.exists(): 65 | if rank == 0: 66 | logger.info(f"No sharded_state_dict checkpoint directory found...skipping") 67 | return 68 | if rank == 0: 69 | logger.info(f"loading model from model path: {load_dir} ") 70 | reader = FileSystemReader(load_dir) 71 | 72 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 73 | checkpoint = {"model": model.state_dict()} 74 | if rank == 0: 75 | ck = checkpoint.keys() 76 | logger.info(f" checkpoint key len = {len(ck)} and \n keys = {ck}") 77 | 78 | dist_cp.load_state_dict( 79 | state_dict=checkpoint, 80 | storage_reader=reader, 81 | ) 82 | if rank == 0: 83 | logger.info(f"checkpoint after load_state_dict()") 84 | ck = checkpoint.keys() 85 | logger.info(f" checkpoint key len = {len(ck)} and \n keys = {ck}") 86 | model.load_state_dict(checkpoint["model"]) 87 | if rank == 0: 88 | logger.info(f"Sharded state checkpoint loaded from {load_dir}") 89 | 90 | 91 | def save_model_and_optimizer_sharded(model, rank, cfg,optim=None): 92 | """save model and optimizer via sharded_state_dict to save_dir""" 93 | 94 | folder_name = ( 95 | cfg.dist_checkpoint_root_folder 96 | + "/" 97 | + cfg.dist_checkpoint_folder 98 | + "-" 99 | + cfg.model_name 100 | ) 101 | 102 | save_dir = Path.cwd() / folder_name 103 | if rank == 0: 104 | logger.info(f"Saving model to {save_dir}") 105 | 106 | distributed_writer = dist_cp.FileSystemWriter( 107 | save_dir, 108 | ) 109 | t0 = time.perf_counter() 110 | 111 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 112 | 113 | state_dict = {"model": model.state_dict()} 114 | if optim is not None: 115 | state_dict["optim"] = FSDP.optim_state_dict(model, optim) 116 | 117 | dist_cp.save_state_dict( 118 | state_dict=state_dict, 119 | storage_writer=distributed_writer, 120 | planner=DefaultSavePlanner(), 121 | 122 | ) 123 | dist.barrier() 124 | t1 = time.perf_counter() 125 | if rank == 0: 126 | logger.info(f"Sharded state checkpoint saved to {save_dir}") 127 | logger.info( 128 | f"Checkpoint Time = {t1-t0:.4f}\n" 129 | ) 130 | def save_model_checkpoint( 131 | model, 132 | optimizer, 133 | rank, 134 | cfg, 135 | epoch=1, 136 | ): 137 | """saving model via rank0 cpu streaming and full_state_dict""" 138 | 139 | with 
FSDP.state_dict_type( 140 | model, StateDictType.FULL_STATE_DICT, fullstate_save_policy 141 | ): 142 | cpu_state = model.state_dict() 143 | 144 | logger.info(f"saving process: rank {rank} done w model state_dict\n") 145 | 146 | 147 | if rank == 0: 148 | logger.info(f"--> saving model ...") 149 | # create save path 150 | folder_name = ( 151 | cfg.dist_checkpoint_root_folder 152 | + "/" 153 | + cfg.dist_checkpoint_folder 154 | + "-" 155 | + cfg.model_name 156 | ) 157 | save_dir = Path.cwd() / folder_name 158 | save_dir.mkdir(parents=True, exist_ok=True) 159 | save_name = cfg.model_name + "-" + str(epoch) + ".pt" 160 | save_full_path = str(save_dir) + "/" + save_name 161 | 162 | # save model 163 | torch.save(cpu_state, save_full_path) 164 | 165 | 166 | logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") 167 | 168 | def save_model_checkpoint_deepspeed(model, cfg, checkpoint_name="checkpoint"): 169 | logger.info(f"--> saving model ...") 170 | save_dir = os.path.join(cfg.output_dir, checkpoint_name) 171 | os.makedirs(save_dir, exist_ok=True) 172 | # save_full_path = os.path.join(save_dir, "model.pt") 173 | save_full_path = save_dir 174 | model.save_checkpoint(save_dir=save_full_path, exclude_frozen_parameters=True) 175 | logger.info(f"encoder saved at {save_full_path}") 176 | 177 | def save_model_checkpoint_peft(model, optimizer, rank, cfg, checkpoint_name="checkpoint", save_trainable_only=True): 178 | logger.info(f"--> saving model ...") 179 | save_dir = os.path.join(cfg.output_dir, checkpoint_name) 180 | os.makedirs(save_dir, exist_ok=True) 181 | save_full_path = os.path.join(save_dir, "model.pt") 182 | if cfg.enable_ddp: 183 | model = model.module 184 | cpu_state = model.state_dict() 185 | if save_trainable_only: 186 | state_dict = OrderedDict() 187 | for name, para in model.named_parameters(): 188 | if para.requires_grad: 189 | state_dict[name] = cpu_state[name] 190 | else: 191 | state_dict = cpu_state 192 | torch.save(state_dict, save_full_path) 193 | logger.info(f"encoder saved at {save_full_path}") 194 | 195 | def save_model_checkpoint_peft_full_shard(model, optimizer, rank, cfg, epoch=0): 196 | with FSDP.state_dict_type( 197 | model, StateDictType.FULL_STATE_DICT, fullstate_save_policy 198 | ): 199 | cpu_state = model.state_dict() 200 | logger.info(f"saving process: rank {rank} done w model state_dict\n") 201 | 202 | if rank == 0: 203 | logger.info(f"--> saving model ...") 204 | save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch+1)) 205 | os.makedirs(save_dir, exist_ok=True) 206 | 207 | if not cfg.freeze_llm: 208 | llm_dict = {} 209 | for key in cpu_state.keys(): 210 | if key.startswith("llm."): 211 | llm_dict[key] = cpu_state[key] 212 | model.llm.save_pretrained(save_directory=save_dir, state_dict=llm_dict) 213 | logger.info(f"llm saved at {save_dir}") 214 | 215 | save_full_path = os.path.join(save_dir, "model.pt") 216 | encoder_dict = {} 217 | if not cfg.freeze_encoder: 218 | for key in cpu_state.keys(): 219 | if key.startswith("encoder."): 220 | encoder_dict[key] = cpu_state[key] 221 | for key in cpu_state.keys(): 222 | if key.startswith("encoder_projector."): 223 | encoder_dict[key] = cpu_state[key] 224 | torch.save(encoder_dict, save_full_path) 225 | logger.info(f"encoder saved at {save_full_path}") 226 | 227 | logger.info(f"model checkpoint saved for epoch {epoch+1}\n") 228 | 229 | dist.barrier() 230 | 231 | def load_model_checkpoint(model, rank, cfg): 232 | """load local checkpoint to rank0 cpu 233 | must be called * before * passing to 
FSDP""" 234 | 235 | if rank != 0: 236 | return 237 | 238 | # where is the checkpoint at... 239 | full_state_dict_model_path = ( 240 | Path.cwd() / cfg.checkpoint_folder / cfg.checkpoint_model_filename 241 | ) 242 | # is it present... 243 | if not full_state_dict_model_path.is_file(): 244 | logger.info( 245 | f"model checkpoint {full_state_dict_model_path} not present. Returning..." 246 | ) 247 | return 248 | 249 | 250 | model_checkpoint = torch.load(full_state_dict_model_path) 251 | # integrate into loaded model 252 | model.load_state_dict(model_checkpoint) 253 | 254 | 255 | logger.info(f"model checkpoint loaded to rank0 cpu") 256 | 257 | 258 | def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1): 259 | """save optimizer state via full state dict""" 260 | 261 | 262 | logger.info(f"--> optim state call on rank {rank}\n") 263 | 264 | # pull all sharded optimizer states to rank0 cpu... 265 | 266 | optim_state = FSDP.full_optim_state_dict(model, optimizer) 267 | 268 | 269 | logger.info(f"optim state dict ready on {rank} and len of {len(optim_state)}\n") 270 | 271 | if rank == 0: 272 | folder_name = ( 273 | cfg.dist_checkpoint_root_folder 274 | + "/" 275 | + cfg.dist_checkpoint_folder 276 | + "-" 277 | + cfg.model_name 278 | ) 279 | save_dir = Path.cwd() / folder_name 280 | save_dir.mkdir(parents=True, exist_ok=True) 281 | 282 | opt_save_name = ( 283 | "optimizer" + "-" + cfg.model_name + "-" + str(epoch) + ".pt" 284 | ) 285 | opt_save_full_path = save_dir / opt_save_name 286 | 287 | logger.info(f"--> saving optimizer state...") 288 | 289 | torch.save(optim_state, opt_save_full_path) 290 | 291 | logger.info(f"--> saved {opt_save_full_path} to disk") 292 | 293 | 294 | def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank): 295 | """load an fsdp optimizer full_state checkpoint using scatter method 296 | this ensures only rank 0 loads the optimizer state dict and scatters to other ranks 297 | """ 298 | 299 | 300 | if not optimizer_checkpoint_path.is_file(): 301 | logger.info( 302 | f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. 
" 303 | ) 304 | return 305 | 306 | full_osd = None 307 | 308 | if rank == 0: 309 | full_osd = torch.load(optimizer_checkpoint_path) 310 | 311 | # called from all ranks, though only rank0 has a valid param for full_osd 312 | sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model) 313 | 314 | logger.info(f"optimizer shard loaded on rank {rank}") 315 | 316 | def load_sharded_model_single_gpu(model,model_path): 317 | 318 | reader = FileSystemReader(model_path) 319 | 320 | state_dict = { 321 | "model": model.state_dict() 322 | } 323 | 324 | dist_cp.load_state_dict( 325 | state_dict=state_dict, 326 | storage_reader= FileSystemReader(model_path), 327 | no_dist=True, 328 | ) 329 | 330 | model.load_state_dict(state_dict["model"]) 331 | 332 | logger.info(f"Sharded state checkpoint loaded from {model_path}") 333 | return model 334 | -------------------------------------------------------------------------------- /src/slam_llm/pipeline/finetune.py: -------------------------------------------------------------------------------- 1 | # os 2 | import os 3 | import fire 4 | import random 5 | import importlib 6 | 7 | # nn 8 | import torch 9 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 10 | 11 | # opt 12 | import torch.optim as optim 13 | from torch.optim.lr_scheduler import StepLR 14 | from torch.distributed.fsdp import ( 15 | FullyShardedDataParallel as FSDP, 16 | ) 17 | from torch.nn.parallel import DistributedDataParallel as DDP 18 | 19 | from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload 20 | from slam_llm.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing 21 | 22 | # util 23 | from slam_llm.utils import fsdp_auto_wrap_policy 24 | from slam_llm.utils.config_utils import get_dataloader_kwargs 25 | 26 | from slam_llm.utils.dataset_utils import get_preprocessed_dataset 27 | from slam_llm.data.concatenator import ConcatDataset 28 | 29 | from slam_llm.utils.model_utils import get_custom_model_factory 30 | from slam_llm.utils.train_utils import ( 31 | train, 32 | freeze_transformer_layers, 33 | setup, 34 | setup_environ_flags, 35 | clear_gpu_cache, 36 | get_policies 37 | ) 38 | 39 | import sys 40 | import logging 41 | import wandb 42 | 43 | import hydra 44 | from omegaconf import DictConfig, ListConfig, OmegaConf 45 | from pathlib import Path 46 | 47 | @hydra.main(config_name=None, version_base=None) 48 | def main_hydra(cfg: DictConfig): 49 | def to_plain_list(cfg_item): 50 | if isinstance(cfg_item, ListConfig): 51 | return OmegaConf.to_container(cfg_item, resolve=True) 52 | elif isinstance(cfg_item, DictConfig): 53 | return {k: to_plain_list(v) for k, v in cfg_item.items()} 54 | else: 55 | return cfg_item 56 | 57 | # kwargs = to_plain_list(cfg) 58 | kwargs = cfg 59 | log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) 60 | 61 | logging.basicConfig(level=log_level) 62 | 63 | if kwargs.get("debug", False): 64 | import pdb; 65 | pdb.set_trace() 66 | 67 | main(kwargs) 68 | 69 | 70 | def main(kwargs: DictConfig): 71 | # Update the configuration for the training and sharding process 72 | # train_config, fsdp_config, model_config, log_config = TRAIN_CONFIG(), FSDP_CONFIG(), MODEL_CONFIG(), LOG_CONFIG() 73 | # update_config((train_config, fsdp_config, model_config, log_config), **kwargs) 74 | 75 | train_config, fsdp_config, model_config, log_config, dataset_config = kwargs.train_config, \ 76 | kwargs.fsdp_config, \ 77 | kwargs.model_config, \ 78 | kwargs.log_config, \ 79 | kwargs.dataset_config 80 | 81 | fsdp_config.use_fp16 = 
train_config.use_fp16 82 | OmegaConf.set_struct(kwargs,False) 83 | del kwargs["train_config"] 84 | del kwargs["fsdp_config"] 85 | del kwargs["model_config"] 86 | del kwargs["log_config"] 87 | del kwargs["dataset_config"] 88 | OmegaConf.set_struct(kwargs,True) 89 | 90 | # Set log 91 | if not os.path.exists(os.path.dirname(log_config.log_file)): 92 | os.makedirs(os.path.dirname(log_config.log_file), exist_ok=True) 93 | logging.basicConfig( 94 | level=logging.INFO, 95 | format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 96 | datefmt="%Y-%m-%d %H:%M:%S", 97 | filemode='w' 98 | ) 99 | 100 | logger = logging.getLogger() 101 | logger.setLevel(logging.INFO) 102 | 103 | file_handler = logging.FileHandler(filename=log_config.log_file, mode='w') 104 | file_handler.setLevel(logging.INFO) 105 | file_formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 106 | file_handler.setFormatter(file_formatter) 107 | 108 | logger.handlers[0].setLevel(logging.INFO) 109 | console_formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 110 | logger.handlers[0].setFormatter(console_formatter) 111 | 112 | logger.addHandler(file_handler) 113 | 114 | 115 | # Set the seeds for reproducibility 116 | torch.cuda.manual_seed(train_config.seed) 117 | torch.manual_seed(train_config.seed) 118 | random.seed(train_config.seed) 119 | 120 | if train_config.enable_fsdp or train_config.enable_ddp: 121 | setup() 122 | # torchrun specific 123 | local_rank = int(os.environ["LOCAL_RANK"]) 124 | rank = int(os.environ["RANK"]) 125 | world_size = int(os.environ["WORLD_SIZE"]) 126 | logger.info(f"local_rank: {local_rank}, rank: {rank}, world_size: {world_size}") 127 | 128 | if torch.distributed.is_initialized(): 129 | torch.cuda.set_device(local_rank) 130 | clear_gpu_cache(local_rank) 131 | setup_environ_flags(rank) 132 | 133 | if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: 134 | logger.info("train_config: {}".format(train_config)) 135 | logger.info("fsdp_config: {}".format(fsdp_config)) 136 | logger.info("model_config: {}".format(model_config)) 137 | logger.info("log_config: {}".format(log_config)) 138 | 139 | # Set wandb 140 | if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: 141 | if log_config.use_wandb: 142 | if not os.path.exists(log_config.wandb_dir): 143 | os.makedirs(log_config.wandb_dir, exist_ok=True) 144 | wandb_config={"train_config": train_config, "fsdp_config": fsdp_config, "model_config": model_config, "log_config": log_config} 145 | wandb.init(dir=log_config.wandb_dir, entity=log_config.wandb_entity_name, project=log_config.wandb_project_name,name=log_config.wandb_exp_name ,config=wandb_config) 146 | 147 | 148 | model_factory = get_custom_model_factory(model_config, logger) 149 | model, tokenizer = model_factory(train_config, model_config, **kwargs) 150 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 151 | 152 | 153 | # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled 154 | if (train_config.enable_fsdp or train_config.enable_ddp) and fsdp_config.pure_bf16: 155 | model.to(torch.bfloat16) 156 | 157 | #setting up FSDP if enable_fsdp is enabled 158 | if train_config.enable_fsdp: 159 | if not train_config.use_peft and train_config.freeze_layers: 160 | 161 | freeze_transformer_layers(train_config.num_freeze_layers) 162 | # from torch.distributed.fsdp import ShardingStrategy 163 | # fsdp_config.sharding_strategy = 
getattr(ShardingStrategy, fsdp_config.sharding_strategy) 164 | mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank) 165 | my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer) 166 | 167 | model = FSDP( 168 | model, 169 | auto_wrap_policy= my_auto_wrapping_policy, #(FIX:MZY): Using my_auto_wrapping_policy whether peft or not. This will avoid model shard type check error of requires_grad mismatching. 170 | cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, 171 | mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, 172 | sharding_strategy=fsdp_config.sharding_strategy, 173 | device_id=torch.cuda.current_device(), 174 | limit_all_gathers=True, 175 | sync_module_states=train_config.low_cpu_fsdp, 176 | param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) 177 | if train_config.low_cpu_fsdp and rank != 0 else None, 178 | ) 179 | if fsdp_config.fsdp_activation_checkpointing: 180 | apply_fsdp_checkpointing(model) 181 | elif train_config.enable_ddp: 182 | model = model.cuda(local_rank) 183 | model = DDP(model, device_ids=[local_rank], 184 | find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False)) 185 | elif not train_config.quantization: 186 | model.to(device) 187 | 188 | # dataset_config = generate_dataset_config(train_config, kwargs) 189 | logger.info("dataset_config: {}".format(dataset_config)) 190 | if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: 191 | if log_config.use_wandb: 192 | wandb.config.update({"dataset_config": dataset_config}) 193 | 194 | # Load and preprocess the dataset for training and validation 195 | dataset_train = get_preprocessed_dataset( 196 | tokenizer, 197 | dataset_config, 198 | split="train", 199 | ) 200 | if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: 201 | logger.info(f"--> Training Set Length = {len(dataset_train)}") 202 | dataset_val = get_preprocessed_dataset( 203 | tokenizer, 204 | dataset_config, 205 | split="val", 206 | ) 207 | if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: 208 | logger.info(f"--> Validation Set Length = {len(dataset_val)}") 209 | if train_config.batching_strategy == "packing": 210 | dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length) 211 | 212 | train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train") 213 | 214 | # Create DataLoaders for the training and validation dataset 215 | train_dataloader = torch.utils.data.DataLoader( 216 | dataset_train, 217 | num_workers=train_config.num_workers_dataloader, 218 | prefetch_factor=10, 219 | pin_memory=True, 220 | **train_dl_kwargs, 221 | ) 222 | 223 | eval_dataloader = None 224 | if train_config.run_validation: 225 | if train_config.batching_strategy == "packing": 226 | dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length) 227 | 228 | val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val") 229 | 230 | eval_dataloader = torch.utils.data.DataLoader( 231 | dataset_val, 232 | num_workers=train_config.num_workers_dataloader, 233 | pin_memory=True, 234 | prefetch_factor=10, 235 | **val_dl_kwargs, 236 | ) 237 | 238 | # Initialize the optimizer and learning rate scheduler 239 | if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision": 240 | optimizer = AnyPrecisionAdamW( 241 | model.parameters(), 242 | lr=train_config.lr, 243 | 
momentum_dtype=torch.bfloat16, 244 | variance_dtype=torch.bfloat16, 245 | use_kahan_summation=False, 246 | weight_decay=train_config.weight_decay, 247 | ) 248 | else: 249 | optimizer = optim.AdamW( 250 | model.parameters(), 251 | lr=train_config.lr, 252 | weight_decay=train_config.weight_decay, 253 | ) 254 | # scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma) 255 | scheduler = torch.optim.lr_scheduler.LambdaLR( 256 | optimizer, 257 | lr_lambda=lambda step: ( 258 | min(step / train_config.warmup_steps, 1) if step < train_config.warmup_steps 259 | else max(0.0, 1 - (step - train_config.warmup_steps) / (train_config.total_steps - train_config.warmup_steps)) 260 | # else 1 261 | ) 262 | ) 263 | 264 | # Start the training process 265 | results = train( 266 | model, 267 | train_dataloader, 268 | eval_dataloader, 269 | tokenizer, 270 | optimizer, 271 | scheduler, 272 | train_config.gradient_accumulation_steps, 273 | train_config, 274 | log_config, 275 | fsdp_config if train_config.enable_fsdp else None, 276 | local_rank if train_config.enable_fsdp or train_config.enable_ddp else None, 277 | rank if train_config.enable_fsdp or train_config.enable_ddp else None, 278 | ) 279 | if not (train_config.enable_fsdp or train_config.enable_ddp) or rank==0: 280 | [logger.info(f'Key: {k}, Value: {v}') for k, v in results.items()] 281 | 282 | if not (train_config.enable_fsdp or train_config.enable_ddp) or rank == 0: 283 | if log_config.use_wandb: 284 | wandb.finish() 285 | 286 | if __name__ == "__main__": 287 | main_hydra() -------------------------------------------------------------------------------- /examples/st_covost2/dataset/fleurs_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import random 3 | import json, yaml 4 | import copy 5 | from transformers import AutoFeatureExtractor, WhisperModel,WhisperFeatureExtractor 6 | import numpy as np 7 | from scipy import signal 8 | import soundfile as sf 9 | import torch.distributed as dist 10 | import os 11 | import torch 12 | import torchaudio 13 | from torch.utils.data import Dataset 14 | import whisper 15 | from slam_llm.utils.compute_utils import calculate_output_length_1d 16 | from datasets import load_dataset,load_from_disk 17 | from datasets import Audio 18 | 19 | 20 | class SpeechDatasetJsonl(torch.utils.data.Dataset): 21 | 22 | def __init__(self, 23 | dataset_config, 24 | tokenizer=None, 25 | split='train', 26 | ): 27 | super().__init__() 28 | self.mel_size = dataset_config.get("mel_size", 80) # 80 for whisper large v1 and v2, 128 for large v3 29 | 30 | 31 | ds = load_dataset("yxdu/fleurs_en_test")["test"] 32 | ds = ds.cast_column("audio", Audio(sampling_rate=16000)) 33 | print(ds) 34 | 35 | 36 | self.ds = ds 37 | self.tokenizer = tokenizer 38 | self.dataset_config = dataset_config 39 | 40 | self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss 41 | self.prompt = dataset_config.get("prompt", None) 42 | self.bf16 = dataset_config.get("bf16", True) 43 | self.source = dataset_config.get("source", None) 44 | 45 | self.answer_template = "{}" 46 | self.fix_length_audio = dataset_config.get("fix_length_audio", -1) 47 | self.inference_mode = dataset_config.get("inference_mode", False) 48 | self.normalize = dataset_config.get("normalize", False) 49 | self.input_type = dataset_config.get("input_type", None) 50 | assert self.input_type in ["raw", "mel"], "input_type must be one of [raw, mel]" 51 | 52 | self.data_list = [] 53 | self.count = 0 54 | 55 | 56 | 
self.printed = False 57 | 58 | 59 | def __len__(self): 60 | print(len(self.ds)) 61 | return len(self.ds) 62 | 63 | def __getitem__(self, index): 64 | 65 | data_dict = self.ds[index] 66 | 67 | 68 | prompt = "<|eng|><|zho|>" 69 | source = "fleurs_eng_zho" 70 | target = data_dict["raw_transcription"]+prompt+data_dict["sentence_zho_Hans"] 71 | 72 | 73 | 74 | if not self.printed: 75 | print(prompt) 76 | print(target) 77 | self.printed = True 78 | 79 | key = data_dict.get("key", str(index)) 80 | 81 | audio_raw = whisper.pad_or_trim(data_dict["audio"]["array"]) 82 | audio_raw = torch.tensor(audio_raw, dtype=torch.float32) 83 | audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=self.mel_size).permute(1, 0) 84 | 85 | if self.bf16: 86 | audio_mel = audio_mel.to(torch.bfloat16) 87 | 88 | 89 | if self.fix_length_audio > 0: 90 | audio_length = self.fix_length_audio 91 | 92 | audio_pseudo = torch.full((audio_length,), -1) # placeholder 93 | 94 | prompt_ids = self.tokenizer.encode(prompt) 95 | prompt_length = len(prompt_ids) 96 | 97 | 98 | if self.inference_mode: 99 | audio_mel = audio_mel.to(torch.float16) 100 | audio_path = "audio" 101 | 102 | prompt_ids = self.tokenizer.encode(prompt) 103 | 104 | prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) 105 | example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt] 106 | example_mask = example_ids.ge(-1) # [True,True] 107 | 108 | return { 109 | "input_ids": example_ids, 110 | "attention_mask": example_mask, 111 | "audio": audio_raw if self.input_type == "raw" else None, 112 | "audio_mel": audio_mel if self.input_type == "mel" else None, 113 | "audio_length": audio_length, 114 | "audio_path":audio_path, 115 | "key": key, 116 | "target": target, 117 | "audio_path":audio_path, 118 | "prompt_id":prompt_ids, 119 | "prompt":prompt, 120 | "source":source, 121 | "prompt_length": prompt_length, 122 | } 123 | 124 | 125 | answer = self.answer_template.format(target) 126 | example = prompt + answer # FIX(MZY): avoid putting a bos token before answer. 
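        # Training example layout: [audio placeholder, prompt, answer, eos]. The
        # audio placeholder positions (filled with -1 here and later exposed to the
        # model via modality_mask) and the prompt tokens are masked to IGNORE_INDEX
        # in labels_ids below, so the cross-entropy loss is computed only on the
        # answer (ASR transcript + "<|eng|><|zho|>" + Chinese translation) and the
        # trailing eos token.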
127 | 128 | example_ids = self.tokenizer.encode(example) # [prompt,answer] 129 | example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] 130 | example_ids = torch.tensor( 131 | example_ids, dtype=torch.int64) 132 | 133 | example_ids = torch.cat((audio_pseudo, example_ids)) # [audio,prompt,answer,eos] 134 | 135 | labels_ids = copy.deepcopy(example_ids) # [audio,prompt,answer,eos] 136 | labels_ids[:audio_length + prompt_length] = -1 # [-1,-1,answer,eos]; 137 | example_mask = example_ids.ge(-1) # FIX(GZF): [True,True,True,True] 138 | 139 | label_mask = labels_ids.ge(0) # [False,False,True,True] 140 | example_ids[~example_mask] = 0 # [audio,prompt,answer,eos] 141 | labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,answer,eos] 142 | 143 | 144 | 145 | return { 146 | "input_ids": example_ids, 147 | "labels": labels_ids, 148 | "attention_mask": example_mask, 149 | "audio_mel": audio_mel if self.input_type == "mel" else None, 150 | "audio_length": audio_length, 151 | "prompt_length": prompt_length, 152 | } 153 | 154 | def pad(self, sequence, max_length, padding_idx=0): 155 | if isinstance(sequence, (int, list, tuple)): 156 | if len(sequence) < max_length: 157 | sequence = sequence + [padding_idx] * (max_length - len(sequence)) 158 | else: 159 | sequence = sequence[:max_length] 160 | elif isinstance(sequence, torch.Tensor): 161 | if len(sequence) < max_length: 162 | sequence = torch.cat( 163 | (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) 164 | else: 165 | sequence = sequence[:max_length] 166 | elif isinstance(sequence, np.ndarray): 167 | if len(sequence) < max_length: 168 | sequence = np.concatenate( 169 | (sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx))) 170 | else: 171 | sequence = sequence[:max_length] 172 | else: 173 | raise Exception("Type mismatch during padding!") 174 | return sequence 175 | 176 | @classmethod 177 | def padding(cls, sequence, padding_length, padding_idx=0, padding_side="right"): 178 | if isinstance(sequence, (int, list, tuple)): 179 | if padding_length >= 0: 180 | sequence = sequence + [padding_idx] * padding_length 181 | else: 182 | sequence = sequence[:padding_length] 183 | elif isinstance(sequence, torch.Tensor): 184 | if sequence.ndimension() == 2: 185 | if padding_length >= 0: 186 | sequence = torch.nn.functional.pad(sequence, (0, padding_length)) 187 | else: 188 | sequence = sequence[:, :padding_length] 189 | else: 190 | if padding_length >= 0: 191 | if padding_side == "left": 192 | sequence = torch.cat((torch.full(([padding_length] + list(sequence.size())[1:]), padding_idx), sequence)) 193 | else: 194 | sequence = torch.cat((sequence, torch.full(([padding_length] + list(sequence.size())[1:]), padding_idx))) 195 | else: 196 | sequence = sequence[:padding_length] 197 | elif isinstance(sequence, np.ndarray): 198 | if padding_length >= 0: 199 | sequence = np.concatenate( 200 | (sequence, np.full((padding_length,) + sequence.shape[1:], padding_idx))) 201 | else: 202 | sequence = sequence[:padding_length] 203 | else: 204 | raise Exception("Type mismatch during padding!") 205 | return sequence 206 | 207 | def collator(self, samples): 208 | assert samples is not None 209 | input_prompt_lengths = [s["audio_length"] + s['prompt_length'] for s in samples] #[120, 48, 82, 42] 210 | input_answer_lengths = [len(s["input_ids"]) - s["audio_length"] - s['prompt_length'] for s in samples] #[0, 0, 0, 0] 211 | 212 | input_prompt_max_length = max(input_prompt_lengths) 213 | 
input_answer_max_length = max(input_answer_lengths) 214 | 215 | input_ids = torch.stack([ 216 | self.padding( 217 | self.padding(samples[index]["input_ids"], input_prompt_max_length - input_prompt_lengths[index], self.tokenizer.pad_token_id, padding_side="left"), 218 | input_answer_max_length - input_answer_lengths[index], self.tokenizer.pad_token_id 219 | ) for index in range(len(samples)) 220 | ]) 221 | 222 | attention_mask = torch.stack([ 223 | self.padding( 224 | self.padding(samples[index]["attention_mask"], input_prompt_max_length - input_prompt_lengths[index], False, padding_side="left"), 225 | input_answer_max_length - input_answer_lengths[index], False 226 | ) for index in range(len(samples)) 227 | ]) 228 | 229 | 230 | if self.input_type == "raw": 231 | audio_raw_max_length = max([s['audio'].shape[0] for s in samples]) 232 | audio_raw = torch.stack([self.pad(s['audio'], audio_raw_max_length, 0) 233 | for s in samples]) 234 | audio_mask = torch.zeros(len(samples), audio_raw_max_length) 235 | for line, sample in enumerate(samples): 236 | audio_mask[line, :sample['audio'].shape[0]] = 1 237 | elif self.input_type == "mel": 238 | audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples]) 239 | audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0) 240 | for s in samples]) 241 | audio_mel_post_mask = torch.zeros(len(samples), (audio_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats 242 | for line, sample in enumerate(samples): 243 | audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1 244 | 245 | modality_mask = torch.zeros_like(attention_mask) 246 | for index in range(len(samples)): 247 | padding_left = input_prompt_max_length - input_prompt_lengths[index] 248 | modality_mask[index, padding_left:padding_left+samples[index]["audio_length"]] = True 249 | 250 | if self.inference_mode: 251 | keys = [s['key'] for s in samples] 252 | targets = [s['target'] for s in samples] 253 | prompts = [s['prompt'] for s in samples] 254 | audio_paths = [s['audio_path'] for s in samples] 255 | 256 | return { 257 | "input_ids": input_ids, 258 | "attention_mask": attention_mask, 259 | "audio": audio_raw if self.input_type == "raw" else None, 260 | "audio_mask": audio_mask if self.input_type == "raw" else None, 261 | "audio_mel": audio_mel if self.input_type == "mel" else None, 262 | "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, 263 | "modality_mask": modality_mask, 264 | "keys": keys, 265 | "targets": targets, 266 | "audio_paths": audio_paths, 267 | "prompts": prompts, 268 | } 269 | 270 | labels = torch.stack([ 271 | self.padding( 272 | self.padding(samples[index]['labels'], input_prompt_max_length - input_prompt_lengths[index], self.IGNORE_INDEX, padding_side="left"), 273 | input_answer_max_length - input_answer_lengths[index], self.IGNORE_INDEX) 274 | for index in range(len(samples)) 275 | ]) 276 | 277 | return { 278 | "input_ids": input_ids, 279 | "labels": labels, 280 | "attention_mask": attention_mask, 281 | "audio": audio_raw if self.input_type == "raw" else None, 282 | "audio_mask": audio_mask if self.input_type == "raw" else None, 283 | "audio_mel": audio_mel if self.input_type == "mel" else None, 284 | "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, 285 | "modality_mask": modality_mask 286 | } 287 | 288 | 289 | 290 | 291 | def get_speech_dataset(dataset_config, tokenizer, split): 292 | dataset = SpeechDatasetJsonl(dataset_config, tokenizer, split) 
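    # Illustrative usage (a minimal sketch; the batch_size value is only an example): the dataset's own
    # collator must be handed to the DataLoader so that left-padding of the audio/prompt span is applied, e.g.
    #   ds = get_speech_dataset(dataset_config, tokenizer, "train")
    #   loader = torch.utils.data.DataLoader(ds, batch_size=4, collate_fn=ds.collator)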
293 | 294 | return dataset -------------------------------------------------------------------------------- /examples/st_covost2/dataset/srt_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import random 3 | import json, yaml 4 | import copy 5 | import os 6 | import numpy as np 7 | from scipy import signal 8 | import soundfile as sf 9 | import librosa 10 | import torch 11 | import torchaudio 12 | from torch.utils.data import Dataset 13 | import whisper 14 | from slam_llm.utils.compute_utils import calculate_output_length_1d 15 | 16 | 17 | class SpeechDatasetJsonl(torch.utils.data.Dataset): 18 | 19 | def __init__(self, 20 | dataset_config, 21 | tokenizer=None, 22 | split='train', 23 | ): 24 | super().__init__() 25 | self.dataset_config = dataset_config 26 | self.tokenizer = tokenizer 27 | self.mode = dataset_config.get("mode", "srt") 28 | 29 | self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss 30 | self.prompt = dataset_config.get("prompt", "") 31 | self.bf16 = dataset_config.get("bf16", True) 32 | self.fp16 = dataset_config.get("fp16", False) 33 | self.mel_size = dataset_config.get("mel_size", 128) # 80 for whisper large v1 and v2, 128 for large v3 34 | self.source = dataset_config.get("source", "eng") 35 | 36 | self.answer_template = "{}" 37 | self.fix_length_audio = dataset_config.get("fix_length_audio", 80) 38 | self.inference_mode = dataset_config.get("inference_mode", False) 39 | self.normalize = dataset_config.get("normalize", False) 40 | self.validnum = dataset_config.get("validnum", -2) 41 | self.input_type = dataset_config.get("input_type", "mel") 42 | assert self.input_type in ["raw", "mel"], "input_type must be one of [raw, mel]" 43 | self.data_dir = os.path.dirname(dataset_config.get("val_data_path"))+"/" 44 | print(self.data_dir) 45 | 46 | src_lang = ['ara', 'ben', 'ces', 'deu', 'eng', 'fas', 'fra', 'heb', 'hin', 'ind', 'ita', 'jpn', 'khm', 'kor', 'lao', 'msa', 'mya', 'nld', 'pol', 'por', 'rus', 'spa', 'tha', 'tgl', 'tur', 'urd', 'vie', 'zho'] 47 | # src = self.source.split("_")[-1] 48 | # src_lang = [src] 49 | # src_lang = ['eng', 'deu', 'fra', 'spa', 'por', 'ita', 'nld', 'rus', 'jpn', 'kor', 'vie', 'ind','tha',"zho","yue"] 50 | # src_lang = ['eng', 'deu', 'fra', 'spa', 'por', 'ita', 'nld', 'rus', 'jpn', 'kor', 'vie', 'ind','tha',"zho"] 51 | # src_lang = ['zho'] 52 | # src_lang = ['eng',"zho","jpn","kor"] 53 | # src_lang = ['spa'] 54 | # src_lang = ['zho'] 55 | src_lang = ['eng'] 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | # src_lang = ['eng', 'deu', 'fra', 'spa', 'por', 'ita', 'nld', 'rus', 'jpn', 'kor', 'vie', 'ind','tha',"zho","yue"] 64 | 65 | 66 | 67 | 68 | # eng_Lant 69 | tgt_lang = ['ara', 'ben', 'ces', 'deu', 'eng', 'fas', 'fra', 'heb', 'hin', 'ind', 'ita', 'jpn', 'khm', 'kor', 'lao', 'msa', 'mya', 'nld', 'pol', 'por', 'rus', 'spa', 'tha', 'tgl', 'tur', 'urd', 'vie', 'zho'] 70 | 71 | tgt_lang = ['ara','zsm', 'ben', 'ces', 'deu', 'eng', 'fas', 'fra', 'heb', 'hin', 'ind', 'ita', 'jpn', 'khm', 'kor', 'lao', 'msa', 'mya', 'nld', 'pol', 'por', 'rus', 'spa', 'tha', 'tgl', 'tur', 'urd', 'vie', 'zho'] 72 | 73 | 74 | 75 | # tgt_lang = ["pes","tur","hin","tgl","arb","zsm","ces"] 76 | # tgt_lang = ['eng', 'deu', 'fra', 'spa', 'por', 'ita', 'nld', 'rus', 'jpn', 'kor', 'vie', 'ind','tha',"zho","yue"] 77 | 78 | # tgt_lang = ['eng', 'deu', 'fra', 'spa', 'por', 'ita', 'nld', 'rus', 'jpn', 'kor', 'vie', 'ind','tha',"zho","yue"] 79 | 80 | # tgt_lang = ['deu', 'fra', 'rus', 'jpn', "zho", "eng"] 81 | 82 
| tgt_lang = ['zho'] 83 | # tgt_lang = ['eng'] 84 | 85 | 86 | 87 | # tgt_lang = ['jpn'] 88 | 89 | # tgt_lang = ['jpn', "zho","yue"] 90 | # tgt_lang = ["eng"] 91 | 92 | 93 | # Set the random seed so that results are reproducible 94 | random_seed = 42 # can be replaced with any integer 95 | random.seed(random_seed) 96 | 97 | 98 | self.data_list = [] 99 | self.count = 0 100 | 101 | if split == "train": 102 | with open(dataset_config.get("train_data_path"), encoding='utf-8') as fin: 103 | for line in fin: 104 | data_dict = json.loads(line.strip()) 105 | data_source = data_dict["source"] 106 | if self.source == data_source: 107 | self.data_list.append(data_dict) 108 | elif self.source == "all": 109 | self.data_list.append(data_dict) 110 | elif data_source.split("_")[-2] in src_lang and data_source.split("_")[-1] in tgt_lang: 111 | self.data_list.append(data_dict) 112 | # Shuffle the training data 113 | random.shuffle(self.data_list) 114 | else: 115 | with open(dataset_config.get("val_data_path"), encoding='utf-8') as fin: 116 | for line in fin: 117 | data_dict = json.loads(line.strip()) 118 | data_source = data_dict["source"] 119 | if self.source == data_source: 120 | self.data_list.append(data_dict) 121 | elif self.source == "all": 122 | self.data_list.append(data_dict) 123 | elif data_source.split("_")[-2] in src_lang and data_source.split("_")[-1] in tgt_lang: 124 | self.data_list.append(data_dict) 125 | if self.validnum == -1: 126 | random.shuffle(self.data_list) 127 | # if len(self.data_list)>50000: 128 | # self.data_list=self.data_list[:50000] 129 | elif self.validnum == -2: 130 | pass 131 | else: 132 | self.data_list = random.sample(self.data_list, self.validnum) 133 | 134 | 135 | 136 | 137 | self.printed = False # flag so that the sample prompt/target is printed only once 138 | print(split, len(self.data_list)) 139 | 140 | 141 | def __len__(self): 142 | return len(self.data_list) 143 | 144 | def __getitem__(self, index): 145 | data_dict = self.data_list[index] 146 | 147 | audio_path = data_dict.get("audio", "") 148 | if not audio_path.startswith('/'): 149 | audio_path = self.data_dir + audio_path 150 | 151 | 152 | 153 | 154 | prompt = data_dict.get("prompt") 155 | target = data_dict.get("gt") 156 | source = data_dict.get("source") 157 | 158 | if self.mode == "smt": 159 | prompt = target.split(prompt)[0] + prompt 160 | if self.validnum == -1: 161 | target = target.split(prompt)[1] 162 | elif self.mode == "asr": 163 | prompt = prompt[:7] 164 | target = target.split(prompt)[0] 165 | elif self.mode == "asrmmt": 166 | prompt = data_dict.get("asr").split(prompt)[0] + prompt 167 | 168 | if not self.printed: 169 | print(prompt) 170 | print(target) 171 | self.printed = True 172 | 173 | 174 | key = data_dict.get("key", str(index)) 175 | 176 | audio_raw = whisper.load_audio(audio_path) 177 | # audio_raw, sr = librosa.load(audio_path, sr=None) # sr=None ensures we get the original sample rate 178 | # Resample audio to 16000 Hz if the sample rate is different 179 | # if sr != 16000: 180 | # audio_raw = librosa.resample(audio_raw, orig_sr=sr, target_sr=16000) 181 | # sr = 16000 # Update the sample rate to 16000 182 | 183 | if self.input_type == "raw": 184 | audio_raw = torch.from_numpy(audio_raw) 185 | if self.normalize: 186 | audio_raw = torch.nn.functional.layer_norm(audio_raw, audio_raw.shape) 187 | audio_length = len(audio_raw) // 320 # ad-hoc for fairseq 320x downsample 188 | audio_length = audio_length // 5 # ad-hoc for 5x fc downsample 189 | elif self.input_type == "mel": 190 | audio_raw = whisper.pad_or_trim(audio_raw) 191 | audio_mel = whisper.log_mel_spectrogram(audio_raw, 
n_mels=self.mel_size).permute(1, 0) 192 | 193 | 194 | if self.fix_length_audio > 0: 195 | audio_length = self.fix_length_audio 196 | audio_pseudo = torch.full((audio_length,), -1) # placeholder 197 | prompt_id = self.tokenizer.encode(prompt) 198 | prompt_length = len(prompt_id) 199 | 200 | 201 | if self.inference_mode: 202 | if self.input_type == "mel": # guard: audio_mel is only defined for mel inputs 203 | audio_mel = audio_mel.to(torch.float16) 204 | 205 | prompt_id = torch.tensor(prompt_id, dtype=torch.int64) 206 | example_ids = torch.cat((audio_pseudo, prompt_id)) # [audio,prompt] 207 | example_mask = example_ids.ge(-1) # [True,True] 208 | 209 | return { 210 | "input_ids": example_ids, 211 | "attention_mask": example_mask, 212 | "audio": audio_raw if self.input_type == "raw" else None, 213 | "audio_mel": audio_mel if self.input_type == "mel" else None, 214 | "audio_length": audio_length, 215 | "audio_path": audio_path, 216 | "key": key, 217 | "target": target, 218 | 219 | "prompt_id": prompt_id, 220 | "prompt": prompt, 221 | "source": source, 222 | "prompt_length": prompt_length, 223 | } 224 | 225 | if self.bf16 and self.input_type == "mel": 226 | audio_mel = audio_mel.to(torch.bfloat16) 227 | answer = self.answer_template.format(target) 228 | example = prompt + answer # FIX(MZY): avoid putting a bos token before answer. 229 | 230 | example_ids = self.tokenizer.encode(example) # [prompt,answer] 231 | example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] 232 | example_ids = torch.tensor( 233 | example_ids, dtype=torch.int64) 234 | 235 | example_ids = torch.cat((audio_pseudo, example_ids)) # [audio,prompt,answer,eos] 236 | 237 | labels_ids = copy.deepcopy(example_ids) # [audio,prompt,answer,eos] 238 | labels_ids[:audio_length + prompt_length] = -1 # [-1,-1,answer,eos] 239 | example_mask = example_ids.ge(-1) # FIX(GZF): [True,True,True,True] 240 | 241 | label_mask = labels_ids.ge(0) # [False,False,True,True] 242 | example_ids[~example_mask] = 0 # [audio,prompt,answer,eos] 243 | labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,answer,eos] 244 | 245 | 246 | 247 | return { 248 | "input_ids": example_ids, 249 | "labels": labels_ids, 250 | "attention_mask": example_mask, 251 | "audio": audio_raw if self.input_type == "raw" else None, 252 | "audio_mel": audio_mel if self.input_type == "mel" else None, 253 | "audio_length": audio_length, 254 | "prompt_length": prompt_length, 255 | } 256 | 257 | def pad(self, sequence, max_length, padding_idx=0): 258 | if isinstance(sequence, (int, list, tuple)): 259 | if len(sequence) < max_length: 260 | sequence = sequence + [padding_idx] * (max_length - len(sequence)) 261 | else: 262 | sequence = sequence[:max_length] 263 | elif isinstance(sequence, torch.Tensor): 264 | if len(sequence) < max_length: 265 | sequence = torch.cat( 266 | (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) 267 | else: 268 | sequence = sequence[:max_length] 269 | elif isinstance(sequence, np.ndarray): 270 | if len(sequence) < max_length: 271 | sequence = np.concatenate( 272 | (sequence, np.full((max_length - len(sequence),) + sequence.shape[1:], padding_idx))) 273 | else: 274 | sequence = sequence[:max_length] 275 | else: 276 | raise Exception("Type mismatch during padding!") 277 | return sequence 278 | 279 | @classmethod 280 | def padding(cls, sequence, padding_length, padding_idx=0, padding_side="right"): 281 | if isinstance(sequence, (int, list, tuple)): 282 | if padding_length >= 0: 283 | sequence = sequence + [padding_idx] * padding_length 284 | else: 285 | sequence = sequence[:padding_length] 
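        # Sign convention used by padding(): a non-negative padding_length appends that many padding_idx entries
        # (on the chosen side), while a negative padding_length truncates the sequence instead.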
286 | elif isinstance(sequence, torch.Tensor): 287 | if sequence.ndimension() == 2: 288 | if padding_length >= 0: 289 | sequence = torch.nn.functional.pad(sequence, (0, padding_length)) 290 | else: 291 | sequence = sequence[:, :padding_length] 292 | else: 293 | if padding_length >= 0: 294 | if padding_side == "left": 295 | sequence = torch.cat((torch.full(([padding_length] + list(sequence.size())[1:]), padding_idx), sequence)) 296 | else: 297 | sequence = torch.cat((sequence, torch.full(([padding_length] + list(sequence.size())[1:]), padding_idx))) 298 | else: 299 | sequence = sequence[:padding_length] 300 | elif isinstance(sequence, np.ndarray): 301 | if padding_length >= 0: 302 | sequence = np.concatenate( 303 | (sequence, np.full((padding_length,) + sequence.shape[1:], padding_idx))) 304 | else: 305 | sequence = sequence[:padding_length] 306 | else: 307 | raise Exception("Type mismatch during padding!") 308 | return sequence 309 | 310 | def collator(self, samples): 311 | assert samples is not None 312 | input_prompt_lengths = [s["audio_length"] + s['prompt_length'] for s in samples] #[120, 48, 82, 42] 313 | input_answer_lengths = [len(s["input_ids"]) - s["audio_length"] - s['prompt_length'] for s in samples] #[0, 0, 0, 0] 314 | 315 | input_prompt_max_length = max(input_prompt_lengths) 316 | input_answer_max_length = max(input_answer_lengths) 317 | 318 | input_ids = torch.stack([ 319 | self.padding( 320 | self.padding(samples[index]["input_ids"], input_prompt_max_length - input_prompt_lengths[index], self.tokenizer.pad_token_id, padding_side="left"), 321 | input_answer_max_length - input_answer_lengths[index], self.tokenizer.pad_token_id 322 | ) for index in range(len(samples)) 323 | ]) 324 | 325 | attention_mask = torch.stack([ 326 | self.padding( 327 | self.padding(samples[index]["attention_mask"], input_prompt_max_length - input_prompt_lengths[index], False, padding_side="left"), 328 | input_answer_max_length - input_answer_lengths[index], False 329 | ) for index in range(len(samples)) 330 | ]) 331 | 332 | 333 | if self.input_type == "raw": 334 | audio_raw_max_length = max([s['audio'].shape[0] for s in samples]) 335 | audio_raw = torch.stack([self.pad(s['audio'], audio_raw_max_length, 0) 336 | for s in samples]) 337 | audio_mask = torch.zeros(len(samples), audio_raw_max_length) 338 | for line, sample in enumerate(samples): 339 | audio_mask[line, :sample['audio'].shape[0]] = 1 340 | elif self.input_type == "mel": 341 | audio_mel_max_length = max([s['audio_mel'].shape[0] for s in samples]) 342 | audio_mel = torch.stack([self.pad(s['audio_mel'], audio_mel_max_length, 0) 343 | for s in samples]) 344 | audio_mel_post_mask = torch.zeros(len(samples), (audio_mel_max_length + 1) // 2) # ad-hoc for whisper for 2x downsample from mel to feats 345 | for line, sample in enumerate(samples): 346 | audio_mel_post_mask[line, :(sample['audio_mel'].shape[0] + 1) // 2] = 1 347 | 348 | modality_mask = torch.zeros_like(attention_mask) 349 | for index in range(len(samples)): 350 | padding_left = input_prompt_max_length - input_prompt_lengths[index] 351 | modality_mask[index, padding_left:padding_left+samples[index]["audio_length"]] = True 352 | 353 | if self.inference_mode: 354 | keys = [s['key'] for s in samples] 355 | targets = [s['target'] for s in samples] 356 | audio_paths = [s['audio_path'] for s in samples] 357 | prompts = [s['prompt'] for s in samples] 358 | prompt_ids = [s['prompt_id'] for s in samples] 359 | 360 | return { 361 | "input_ids": input_ids, 362 | "attention_mask": attention_mask, 363 
| "audio": audio_raw if self.input_type == "raw" else None, 364 | "audio_mask": audio_mask if self.input_type == "raw" else None, 365 | "audio_mel": audio_mel if self.input_type == "mel" else None, 366 | "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, 367 | "modality_mask": modality_mask, 368 | "keys": keys, 369 | "targets": targets, 370 | "audio_paths": audio_paths, 371 | "prompts": prompts, 372 | "prompt_ids": prompt_ids, 373 | } 374 | 375 | labels = torch.stack([ 376 | self.padding( 377 | self.padding(samples[index]['labels'], input_prompt_max_length - input_prompt_lengths[index], self.IGNORE_INDEX, padding_side="left"), 378 | input_answer_max_length - input_answer_lengths[index], self.IGNORE_INDEX) 379 | for index in range(len(samples)) 380 | ]) 381 | 382 | return { 383 | "input_ids": input_ids, 384 | "labels": labels, 385 | "attention_mask": attention_mask, 386 | "audio": audio_raw if self.input_type == "raw" else None, 387 | "audio_mask": audio_mask if self.input_type == "raw" else None, 388 | "audio_mel": audio_mel if self.input_type == "mel" else None, 389 | "audio_mel_post_mask": audio_mel_post_mask if self.input_type == "mel" else None, 390 | "modality_mask": modality_mask 391 | } 392 | 393 | 394 | 395 | 396 | def get_speech_dataset(dataset_config, tokenizer, split): 397 | dataset = SpeechDatasetJsonl(dataset_config, tokenizer, split) 398 | 399 | return dataset 400 | -------------------------------------------------------------------------------- /examples/st_covost2/model/slam_model_st.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | import torch 4 | import soundfile as sf 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.distributed as dist 8 | from typing import List, Optional, Tuple, Union 9 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, AutoModelForSeq2SeqLM, T5ForConditionalGeneration 10 | from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training 11 | 12 | from slam_llm.utils.config_utils import generate_peft_config 13 | from slam_llm.utils.train_utils import print_module_size, print_model_size 14 | from peft import PeftModel, PeftConfig 15 | from torch.nn import CrossEntropyLoss 16 | from slam_llm.utils.metric import compute_accuracy 17 | from transformers import SeamlessM4Tv2ForSpeechToText,SeamlessM4Tv2ForTextToText 18 | import logging 19 | 20 | from transformers import StoppingCriteria, StoppingCriteriaList 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | 26 | def model_factory(train_config, model_config, **kwargs): 27 | tokenizer = setup_tokenizer(train_config, model_config, **kwargs) 28 | 29 | encoder = setup_encoder(train_config, model_config, **kwargs) 30 | 31 | # llm 32 | llm = setup_llm(train_config, model_config, **kwargs) 33 | 34 | 35 | 36 | # projector 37 | encoder_projector = setup_encoder_projector( 38 | train_config, model_config, **kwargs 39 | ) 40 | model = slam_model( 41 | encoder, 42 | llm, 43 | encoder_projector, 44 | tokenizer, 45 | train_config, 46 | model_config, 47 | **kwargs, 48 | ) 49 | 50 | 51 | ckpt_path = kwargs.get("ckpt_path", None) #FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft) 52 | print(ckpt_path) 53 | if ckpt_path is not None: 54 | logger.info("loading other parts from: {}".format(ckpt_path)) 55 | ckpt_dict = torch.load(ckpt_path, 
map_location="cpu") 56 | model.load_state_dict(ckpt_dict, strict=False) 57 | 58 | print_model_size(model, train_config, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) 59 | # print(model) 60 | return model, tokenizer 61 | 62 | 63 | def setup_tokenizer(train_config, model_config, **kwargs): 64 | # Load the tokenizer and add special tokens 65 | if "vallex" in model_config.llm_name.lower(): 66 | return None 67 | elif "mupt" in model_config.llm_name.lower(): 68 | tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path, 69 | trust_remote_code=True, 70 | use_fast=False) 71 | else: 72 | tokenizer = AutoTokenizer.from_pretrained(model_config.llm_path, trust_remote_code=True) 73 | 74 | 75 | tokenizer.pad_token_id = tokenizer.eos_token_id 76 | print("tokenizer.pad_token_id:", tokenizer.pad_token_id) 77 | print("tokenizer.eos_token_id:", tokenizer.eos_token_id) 78 | 79 | print("eos_token:", tokenizer.eos_token) # prints e.g. "" or "<|endoftext|>" 80 | print("pad_token:", tokenizer.pad_token) # prints e.g. "" or "<|endoftext|>" 81 | return tokenizer 82 | 83 | 84 | def setup_encoder(train_config, model_config, **kwargs): 85 | encoder_list = model_config.encoder_name.split(",") if model_config.encoder_name else [] 86 | if len(encoder_list) == 0: 87 | return None 88 | if len(encoder_list) == 1: 89 | encoder_name = encoder_list[0] 90 | if encoder_name == "whisper" or encoder_name == "qwen-audio": 91 | from slam_llm.models.encoder import WhisperWrappedEncoder 92 | encoder = WhisperWrappedEncoder.load(model_config) 93 | if encoder_name == "beats": 94 | from slam_llm.models.encoder import BEATsEncoder 95 | encoder = BEATsEncoder.load(model_config) 96 | if encoder_name == "eat": 97 | from slam_llm.models.encoder import EATEncoder 98 | encoder = EATEncoder.load(model_config) 99 | if encoder_name == "SpatialAST": 100 | from slam_llm.models.encoder import SpatialASTEncoder 101 | encoder = SpatialASTEncoder.load(model_config) 102 | if encoder_name == "wavlm": 103 | from slam_llm.models.encoder import WavLMEncoder 104 | encoder = WavLMEncoder.load(model_config) 105 | if encoder_name == "av_hubert": 106 | from slam_llm.models.encoder import AVHubertEncoder 107 | encoder = AVHubertEncoder.load(model_config) 108 | if encoder_name == "hubert": 109 | from slam_llm.models.encoder import HubertEncoder 110 | encoder = HubertEncoder.load(model_config) 111 | if encoder_name == "musicfm": 112 | from slam_llm.models.encoder import MusicFMEncoder 113 | encoder = MusicFMEncoder.load(model_config) 114 | 115 | if "llama" in encoder_name.lower(): 116 | from slam_llm.models.encoder import HfTextEncoder 117 | encoder = HfTextEncoder.load(model_config) 118 | print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) 119 | 120 | if train_config.freeze_encoder: 121 | for name, param in encoder.named_parameters(): 122 | param.requires_grad = False 123 | encoder.eval() 124 | print_module_size(encoder, encoder_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) 125 | 126 | return encoder 127 | 128 | def setup_llm(train_config, model_config, **kwargs): 129 | from pkg_resources import packaging 130 | use_cache = False if train_config.enable_fsdp or train_config.enable_ddp else None 131 | 132 | model = AutoModelForCausalLM.from_pretrained( 133 | model_config.llm_path, 134 | load_in_8bit=True if train_config.quantization else None, 135 | device_map="auto" if train_config.quantization else None, 136 | use_cache=use_cache, 
137 | # attn_implementation="flash_attention_2" if train_config.use_fast_kernels else None, 138 | attn_implementation="eager", 139 | torch_dtype=torch.bfloat16, 140 | trust_remote_code=True 141 | ) 142 | # print(model) 143 | 144 | 145 | print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) 146 | 147 | # Prepare the model for int8 training if quantization is enabled 148 | if train_config.quantization: 149 | model = prepare_model_for_kbit_training(model) 150 | 151 | if train_config.freeze_llm: # TODO: to test official `freeze_layers` and `num_freeze_layers` 152 | for name, param in model.named_parameters(): 153 | param.requires_grad = False 154 | model.eval() 155 | 156 | if kwargs.get("peft_ckpt", None): # (FIX:MZY): reload will get wrong results when decoding 157 | logger.info("loading peft_ckpt from: {}".format(kwargs.get("peft_ckpt"))) 158 | model = PeftModel.from_pretrained(model=model, model_id=kwargs.get("peft_ckpt"), is_trainable=True) 159 | model.print_trainable_parameters() 160 | elif train_config.use_peft: 161 | logger.info("setup peft...") 162 | peft_config = generate_peft_config(train_config) 163 | print(peft_config) 164 | model = get_peft_model(model, peft_config) 165 | model.print_trainable_parameters() 166 | 167 | # print(model) 168 | print_module_size(model, model_config.llm_name, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) 169 | return model 170 | 171 | def setup_encoder_projector(train_config, model_config, **kwargs): 172 | if model_config.encoder_projector == "linear": 173 | from slam_llm.models.projector import EncoderProjectorConcat 174 | encoder_projector = EncoderProjectorConcat(model_config) 175 | elif model_config.encoder_projector == "cov1d-linear": 176 | from slam_llm.models.projector import EncoderProjectorCov1d 177 | encoder_projector = EncoderProjectorCov1d(model_config) 178 | elif model_config.encoder_projector == "q-former": 179 | from slam_llm.models.projector import EncoderProjectorQFormer 180 | encoder_projector = EncoderProjectorQFormer(model_config) 181 | else: 182 | return None 183 | print_module_size(encoder_projector, model_config.encoder_projector, int(os.environ["RANK"]) if train_config.enable_fsdp or train_config.enable_ddp else 0) 184 | return encoder_projector 185 | 186 | 187 | class slam_model(nn.Module): 188 | def __init__( 189 | self, 190 | encoder: nn.Module, 191 | llm: nn.Module, 192 | encoder_projector: nn.Module, 193 | tokenizer, 194 | train_config, 195 | model_config, 196 | **kwargs 197 | ): 198 | super().__init__() 199 | # modality encoder 200 | self.encoder = encoder 201 | 202 | 203 | # llm 204 | self.llm = llm 205 | 206 | # projector 207 | self.encoder_projector = encoder_projector 208 | 209 | # tokenizer 210 | self.tokenizer = tokenizer 211 | 212 | self.metric = kwargs.get("metric", "acc") 213 | 214 | self.train_config = train_config 215 | self.model_config = model_config 216 | 217 | if train_config.get("enable_deepspeed", False): # run LayerNorm in fp32 and cast back, for numerical stability under DeepSpeed 218 | def new_forward(self, input): 219 | output = F.layer_norm( 220 | input.float(), 221 | self.normalized_shape, 222 | self.weight.float() if self.weight is not None else None, 223 | self.bias.float() if self.bias is not None else None, 224 | self.eps, 225 | ) 226 | return output.type_as(input) 227 | for item in self.modules(): 228 | if isinstance(item, nn.LayerNorm): 229 | item.forward = types.MethodType(new_forward, item) 230 | 231 | 232 | 233 | def forward(self, 234 | input_ids: 
torch.LongTensor = None, 235 | attention_mask: Optional[torch.Tensor] = None, 236 | position_ids: Optional[torch.LongTensor] = None, 237 | past_key_values: Optional[List[torch.FloatTensor]] = None, 238 | inputs_embeds: Optional[torch.FloatTensor] = None, 239 | labels: Optional[torch.LongTensor] = None, 240 | use_cache: Optional[bool] = None, 241 | output_attentions: Optional[bool] = None, 242 | output_hidden_states: Optional[bool] = None, 243 | return_dict: Optional[bool] = None, 244 | **kwargs, 245 | ): 246 | audio_mel = kwargs.get("audio_mel", None) 247 | audio_mel_mask = kwargs.get("audio_mel_mask", None) 248 | audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None) # 2x downsample for whisper 249 | 250 | audio = kwargs.get("audio", None) 251 | audio_mask = kwargs.get("audio_mask", None) 252 | visual = kwargs.get("visual", None) 253 | visual_mask = kwargs.get("visual_mask", None) 254 | 255 | 256 | # for text encoder 257 | instruct_ids = kwargs.get("instruct_ids", None) 258 | instruct_mask = kwargs.get("instruct_mask", None) 259 | 260 | modality_mask = kwargs.get("modality_mask", None) 261 | 262 | 263 | 264 | encoder_outs = None 265 | if audio_mel is not None or audio is not None or visual is not None: 266 | if self.train_config.freeze_encoder: # freeze encoder 267 | self.encoder.eval() 268 | if self.model_config.encoder_path_hf is not None: 269 | encoder_outs = self.encoder(audio_mel.permute(0, 2, 1)).last_hidden_state # bs*seq*dim 270 | elif self.model_config.encoder_name == "whisper": 271 | encoder_outs = self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) # bs*seq*dim 272 | if self.model_config.encoder_name == "beats": 273 | encoder_outs, audio_mel_post_mask = self.encoder.extract_features(audio_mel, audio_mel_mask) # bs*seq*dim 274 | if self.model_config.encoder_name == "eat": 275 | encoder_outs = self.encoder.model.extract_features(audio_mel.unsqueeze(dim=1), padding_mask = None, mask=False, remove_extra_tokens = False)['x'] 276 | if self.model_config.encoder_name == "SpatialAST": 277 | encoder_outs = self.encoder(audio) # output: [bs, seq_len=3+512, dim=768] 278 | if self.model_config.encoder_name == "wavlm": 279 | encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask) #(FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask 280 | if self.model_config.encoder_name == "hubert": 281 | results = self.encoder(source = audio, padding_mask = 1-audio_mask) 282 | if self.model_config.encoder_type == "pretrain": 283 | encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"] 284 | if self.model_config.encoder_type == "finetune": 285 | encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"] 286 | encoder_outs = encoder_outs.transpose(0, 1) 287 | if self.model_config.encoder_name == "av_hubert": 288 | results = self.encoder(source={'video':visual, 'audio':audio}, padding_mask=visual_mask) # bs*seq*dim 289 | encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"] 290 | encoder_outs = encoder_outs.transpose(0, 1) 291 | audio_mel_post_mask = (~audio_mel_post_mask).float() 292 | if self.model_config.encoder_name == 'musicfm': 293 | encoder_outs = self.encoder.extract_features(audio, padding_mask = None) # MusicFM doesn't support padding mask 294 | if self.encoder is None: 295 | encoder_outs = audio_mel if audio_mel is not None else audio 296 | 297 | if self.model_config.encoder_projector == "q-former": 298 | encoder_outs = self.encoder_projector(encoder_outs, 
audio_mel_post_mask) 299 | if self.model_config.encoder_projector == "linear": 300 | encoder_outs = self.encoder_projector(encoder_outs) 301 | if self.model_config.encoder_projector == "cov1d-linear": 302 | encoder_outs = self.encoder_projector(encoder_outs) 303 | 304 | if instruct_ids is not None: 305 | if self.encoder is not None: 306 | encoder_outs = self.encoder(input_ids=instruct_ids, attention_mask=instruct_mask).last_hidden_state 307 | 308 | if self.model_config.encoder_projector == "q-former": 309 | encoder_outs = self.encoder_projector(encoder_outs, instruct_mask) 310 | if self.model_config.encoder_projector == "linear": 311 | encoder_outs = self.encoder_projector(encoder_outs) 312 | 313 | 314 | if input_ids is not None: 315 | input_ids[input_ids == -1] = 0 316 | if isinstance(self.llm, T5ForConditionalGeneration): 317 | inputs_embeds = self.llm.shared(input_ids) 318 | else: 319 | if hasattr(self.llm.model, "embed_tokens"): 320 | inputs_embeds = self.llm.model.embed_tokens(input_ids) 321 | elif hasattr(self.llm.model.model, "embed_tokens"): 322 | inputs_embeds = self.llm.model.model.embed_tokens(input_ids) 323 | else: 324 | inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids) 325 | 326 | if modality_mask is not None: 327 | modality_mask_start_indices = (modality_mask == True).float().argmax(dim=1) 328 | modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist() 329 | 330 | encoder_outs_pad = torch.zeros_like(inputs_embeds) 331 | for i in range(encoder_outs.shape[0]): 332 | encoder_outs_pad[ 333 | i, modality_mask_start_indices[i]:modality_mask_start_indices[i]+modality_lengths[i] 334 | ] = encoder_outs[i][:modality_lengths[i]] 335 | 336 | inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None]) 337 | 338 | if kwargs.get("inference_mode", False): 339 | return inputs_embeds, attention_mask 340 | 341 | # print(inputs_embeds.shape) 342 | model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels,) 343 | acc = -1 344 | if self.metric: 345 | with torch.no_grad(): 346 | preds = torch.argmax(input=model_outputs.logits, dim=-1) 347 | acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100) 348 | 349 | 350 | return model_outputs, acc 351 | 352 | @torch.no_grad() 353 | def generate(self, 354 | input_ids: torch.LongTensor = None, 355 | attention_mask: Optional[torch.Tensor] = None, 356 | position_ids: Optional[torch.LongTensor] = None, 357 | past_key_values: Optional[List[torch.FloatTensor]] = None, 358 | inputs_embeds: Optional[torch.FloatTensor] = None, 359 | labels: Optional[torch.LongTensor] = None, 360 | use_cache: Optional[bool] = None, 361 | output_attentions: Optional[bool] = None, 362 | output_hidden_states: Optional[bool] = None, 363 | return_dict: Optional[bool] = None, 364 | beam: Optional[int] = 1, 365 | **kwargs, 366 | ): 367 | kwargs["inference_mode"] = True 368 | 369 | inputs_embeds, attention_mask = self.forward( 370 | input_ids=input_ids, 371 | attention_mask=attention_mask, 372 | position_ids=position_ids, 373 | past_key_values=past_key_values, 374 | inputs_embeds=inputs_embeds, 375 | labels=labels, 376 | use_cache=use_cache, 377 | output_attentions=output_attentions, 378 | output_hidden_states=output_hidden_states, 379 | return_dict=return_dict, 380 | **kwargs, 381 | ) 382 | model_outputs = self.llm.generate( 383 | inputs_embeds=inputs_embeds, 384 | max_new_tokens=kwargs.get("max_new_tokens",400), 385 | 
num_beams=kwargs.get("num_beams", 5), 386 | do_sample=kwargs.get("do_sample", False), 387 | min_length=kwargs.get("min_new_tokens", 10), 388 | top_p=kwargs.get("top_p", 1.0), 389 | repetition_penalty=kwargs.get("repetition_penalty", 1.0), 390 | length_penalty=kwargs.get("length_penalty", 1.0), 391 | temperature=kwargs.get("temperature", 1.0), 392 | no_repeat_ngram_size=5, 393 | early_stopping=True, 394 | attention_mask=attention_mask, 395 | eos_token_id=self.tokenizer.eos_token_id, 396 | bos_token_id=self.tokenizer.eos_token_id, 397 | pad_token_id=self.tokenizer.pad_token_id, 398 | ) 399 | 400 | 401 | return model_outputs --------------------------------------------------------------------------------
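Illustrative usage (a minimal sketch, assuming `train_config` and `model_config` objects as used above, a `batch` produced by `SpeechDatasetJsonl.collator` in inference mode, and a hypothetical checkpoint path; device placement is omitted):

    model, tokenizer = model_factory(train_config, model_config, ckpt_path="output/model.pt")  # ckpt_path is a placeholder
    model.eval()
    outputs = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        audio_mel=batch["audio_mel"],
        audio_mel_post_mask=batch["audio_mel_post_mask"],
        modality_mask=batch["modality_mask"],
        max_new_tokens=200,  # example value; the code above defaults to 400
    )
    texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)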