├── finetune_codes ├── utils.py ├── __init__.py ├── demo_data │ └── audio_understanding │ │ ├── audios │ │ ├── librispeech_1263-139804-0001.flac │ │ ├── librispeech_1263-139804-0002.flac │ │ ├── librispeech_1263-139804-0004.flac │ │ └── librispeech_1263-139804-0005.flac │ │ ├── data.jsonl │ │ └── prepare_librispeech_asrtask.py ├── check_sft_infer.py ├── ds_config_zero2.json ├── ds_config_zero3.json ├── extract_semantic_codes.py ├── configuration_moonshot_kimia.py ├── README.md ├── finetune_ds.sh ├── model.py └── datasets.py ├── kimia_infer ├── __init__.py ├── api │ ├── __init__.py │ ├── prompt_manager.py │ └── kimia.py ├── models │ ├── __init__.py │ ├── tokenizer │ │ ├── __init__.py │ │ ├── whisper_Lv3 │ │ │ ├── mel_filters.npz │ │ │ └── whisper.py │ │ ├── glm4_tokenizer.py │ │ └── glm4_utils.py │ └── detokenizer │ │ ├── vocoder │ │ ├── alias_free_activation │ │ │ ├── __init__.py │ │ │ ├── cuda │ │ │ │ ├── __init__.py │ │ │ │ ├── compat.h │ │ │ │ ├── anti_alias_activation.cpp │ │ │ │ ├── activation1d.py │ │ │ │ ├── load.py │ │ │ │ ├── type_shim.h │ │ │ │ └── anti_alias_activation_cuda.cu │ │ │ └── torch │ │ │ │ ├── __init__.py │ │ │ │ ├── act.py │ │ │ │ ├── resample.py │ │ │ │ └── filter.py │ │ ├── utils.py │ │ └── activations.py │ │ ├── flow_matching │ │ ├── scheduler.py │ │ ├── ode_wrapper.py │ │ ├── dit_block.py │ │ └── model.py │ │ ├── bigvgan_wrapper.py │ │ └── semantic_fm_prefix_streaming.py └── utils │ ├── __init__.py │ ├── special_tokens.py │ ├── data.py │ └── sampler.py ├── test_audios ├── multiturn │ ├── case2 │ │ ├── multiturn_a1.txt │ │ ├── multiturn_a1.wav │ │ ├── multiturn_q1.wav │ │ └── multiturn_q2.wav │ └── case1 │ │ ├── multiturn_a1.txt │ │ ├── multiturn_a1.wav │ │ ├── multiturn_q1.wav │ │ └── multiturn_q2.wav ├── asr_example.wav └── qa_example.wav ├── assets ├── kimia_logo.png ├── kimia_report.pdf ├── kimia_framework.png └── kimia_radar_chart.png ├── .gitmodules ├── requirements.txt ├── Dockerfile ├── pyproject.toml ├── infer.py ├── .gitignore └── finetune.py /finetune_codes/utils.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kimia_infer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /finetune_codes/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kimia_infer/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kimia_infer/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kimia_infer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kimia_infer/models/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_audios/multiturn/case2/multiturn_a1.txt: -------------------------------------------------------------------------------- 1 | 当然可以,这很简单。一二三四五六七八九十。 
-------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/alias_free_activation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/alias_free_activation/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_audios/multiturn/case1/multiturn_a1.txt: -------------------------------------------------------------------------------- 1 | 当然可以,李白的诗很多,比如这句:“床前明月光,疑是地上霜。举头望明月,低头思故乡。 -------------------------------------------------------------------------------- /assets/kimia_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/assets/kimia_logo.png -------------------------------------------------------------------------------- /assets/kimia_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/assets/kimia_report.pdf -------------------------------------------------------------------------------- /assets/kimia_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/assets/kimia_framework.png -------------------------------------------------------------------------------- /assets/kimia_radar_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/assets/kimia_radar_chart.png -------------------------------------------------------------------------------- /test_audios/asr_example.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/test_audios/asr_example.wav -------------------------------------------------------------------------------- /test_audios/qa_example.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/test_audios/qa_example.wav -------------------------------------------------------------------------------- /test_audios/multiturn/case1/multiturn_a1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/test_audios/multiturn/case1/multiturn_a1.wav -------------------------------------------------------------------------------- /test_audios/multiturn/case1/multiturn_q1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/test_audios/multiturn/case1/multiturn_q1.wav -------------------------------------------------------------------------------- /test_audios/multiturn/case1/multiturn_q2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/test_audios/multiturn/case1/multiturn_q2.wav -------------------------------------------------------------------------------- /test_audios/multiturn/case2/multiturn_a1.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/test_audios/multiturn/case2/multiturn_a1.wav -------------------------------------------------------------------------------- /test_audios/multiturn/case2/multiturn_q1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/test_audios/multiturn/case2/multiturn_q1.wav -------------------------------------------------------------------------------- /test_audios/multiturn/case2/multiturn_q2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/test_audios/multiturn/case2/multiturn_q2.wav -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "kimia_infer/models/tokenizer/glm4"] 2 | path = kimia_infer/models/tokenizer/glm4 3 | url = https://github.com/THUDM/GLM-4-Voice.git 4 | -------------------------------------------------------------------------------- /kimia_infer/models/tokenizer/whisper_Lv3/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/kimia_infer/models/tokenizer/whisper_Lv3/mel_filters.npz -------------------------------------------------------------------------------- /finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0001.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0001.flac -------------------------------------------------------------------------------- /finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0002.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0002.flac -------------------------------------------------------------------------------- /finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0004.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0004.flac -------------------------------------------------------------------------------- /finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0005.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MoonshotAI/Kimi-Audio/HEAD/finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0005.flac -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/alias_free_activation/torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
3 | 4 | from .filter import * 5 | from .resample import * 6 | from .act import * 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.6.0 2 | torchaudio==2.6.0 3 | packaging 4 | jinja2 5 | openai-whisper 6 | jsonlines 7 | pandas 8 | validators 9 | sty 10 | transformers 11 | librosa 12 | accelerate 13 | aiohttp 14 | colorama 15 | omegaconf==2.3.0 16 | sox 17 | six==1.16.0 18 | hyperpyyaml 19 | conformer==0.3.2 20 | diffusers 21 | pillow 22 | sentencepiece 23 | easydict 24 | fire 25 | ujson 26 | cairosvg 27 | immutabledict 28 | rich 29 | wget 30 | gdown 31 | datasets 32 | torchdyn==1.0.6 33 | huggingface_hub 34 | loguru 35 | decord 36 | blobfile 37 | timm 38 | sacrebleu==1.5.1 39 | soundfile 40 | tqdm 41 | flash_attn==2.7.4.post1 42 | deepspeed==0.16.9 -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 2 | 3 | WORKDIR /app 4 | 5 | COPY ./requirements.txt /app/ 6 | RUN apt-get update && apt-get install -y \ 7 | python3.10 \ 8 | python3.10-dev \ 9 | curl \ 10 | sox \ 11 | openssh-server \ 12 | ffmpeg \ 13 | libgl1-mesa-glx \ 14 | git \ 15 | ninja-build \ 16 | && rm -rf /var/lib/apt/lists/* 17 | 18 | # Install pip 19 | RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py \ 20 | && python3.10 get-pip.py \ 21 | && rm get-pip.py 22 | RUN pip install -r requirements.txt 23 | RUN pip install flash-attn --no-build-isolation 24 | 25 | # alias python3 as python 26 | RUN ln -s /usr/bin/python3 /usr/bin/python 27 | 28 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/alias_free_activation/torch/act.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory.
3 | 4 | import torch.nn as nn 5 | from .resample import UpSample1d, DownSample1d 6 | 7 | 8 | class Activation1d(nn.Module): 9 | def __init__( 10 | self, 11 | activation, 12 | up_ratio: int = 2, 13 | down_ratio: int = 2, 14 | up_kernel_size: int = 12, 15 | down_kernel_size: int = 12, 16 | ): 17 | super().__init__() 18 | self.up_ratio = up_ratio 19 | self.down_ratio = down_ratio 20 | self.act = activation 21 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 22 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 23 | 24 | # x: [B,C,T] 25 | def forward(self, x): 26 | x = self.upsample(x) 27 | x = self.act(x) 28 | x = self.downsample(x) 29 | 30 | return x 31 | -------------------------------------------------------------------------------- /finetune_codes/check_sft_infer.py: -------------------------------------------------------------------------------- 1 | from kimia_infer.api.kimia import KimiAudio 2 | 3 | 4 | model = KimiAudio(model_path="output/finetuned_hf_for_inference", load_detokenizer=False) 5 | 6 | 7 | sampling_params = { 8 | "audio_temperature": 0.8, 9 | "audio_top_k": 10, 10 | "text_temperature": 0.0, 11 | "text_top_k": 5, 12 | "audio_repetition_penalty": 1.0, 13 | "audio_repetition_window_size": 64, 14 | "text_repetition_penalty": 1.0, 15 | "text_repetition_window_size": 16, 16 | } 17 | 18 | messages = [ 19 | {"role": "user", "message_type": "text", "content": "Please transcribe the spoken content into written text."}, 20 | { 21 | "role": "user", 22 | "message_type": "audio", 23 | "content": "finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0001.flac", 24 | }, 25 | ] 26 | 27 | wav, text = model.generate(messages, **sampling_params, output_type="text") 28 | print(">>> output text: ", text) 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # ---------- pyproject.toml ---------- 2 | [build-system] 3 | requires = ["setuptools", "wheel", "torch"] 4 | build-backend = "setuptools.build_meta" 5 | 6 | [project] 7 | name = "kimi-audio" 8 | version = "0.1.0" 9 | description = "Inference library for the Kimi‑Audio foundation model" 10 | readme = "README.md" 11 | license = {text = "MIT"} 12 | authors = [{name = "MoonshotAI", email = "contact@moonshot.ai"}] 13 | 14 | dependencies = [ 15 | "torch", 16 | "torchaudio", 17 | "flash-attn", 18 | "soundfile", 19 | "librosa", 20 | "tqdm", 21 | "loguru", 22 | "huggingface_hub", 23 | "transformers", 24 | "conformer", 25 | "diffusers", 26 | "tiktoken", 27 | "ninja", 28 | "timm", 29 | "torchdyn" 30 | ] 31 | 32 | [tool.setuptools] 33 | include-package-data = true 34 | 35 | [tool.setuptools.packages.find] 36 | where = ["."] 37 | include = ["kimia_infer*"] 38 | 39 | [tool.setuptools.package-data] 40 | "kimia_infer" = ["**/*"] 41 | 42 | [project.urls] 43 | Repository = "https://github.com/MoonshotAI/Kimi-Audio" 44 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/alias_free_activation/cuda/compat.h: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | /*This code is copied from NVIDIA apex: 18 | * https://github.com/NVIDIA/apex 19 | * with minor changes. */ 20 | 21 | #ifndef TORCH_CHECK 22 | #define TORCH_CHECK AT_CHECK 23 | #endif 24 | 25 | #ifdef VERSION_GE_1_3 26 | #define DATA_PTR data_ptr 27 | #else 28 | #define DATA_PTR data 29 | #endif 30 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation.cpp: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include <torch/extension.h> 18 | 19 | extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta); 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)"); 23 | } -------------------------------------------------------------------------------- /kimia_infer/models/tokenizer/glm4_tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import librosa 3 | import os 4 | 5 | from transformers import WhisperFeatureExtractor 6 | from .glm4.speech_tokenizer.modeling_whisper import WhisperVQEncoder 7 | from .glm4_utils import extract_speech_token 8 | from torch import nn 9 | 10 | 11 | class Glm4Tokenizer(nn.Module): 12 | def __init__(self, tokenizer_path): 13 | super().__init__() 14 | self.whisper_model = WhisperVQEncoder.from_pretrained(tokenizer_path).eval() 15 | self.feature_extractor = WhisperFeatureExtractor.from_pretrained(tokenizer_path) 16 | 17 | def tokenize(self, speech=None, audio_path=None, sr=16000): 18 | if audio_path: 19 | audio, sr = librosa.load(audio_path, sr=16000) 20 | audio = torch.tensor(audio).unsqueeze(0) 21 | audio_info = (audio, sr) 22 | else: 23 | assert speech is not None 24 | assert sr 25 | if isinstance(speech, list): 26 | speech = torch.tensor(speech).unsqueeze(0) 27 | if len(speech.shape) == 1: 28 | speech = speech.unsqueeze(0) 29 | audio_info = (speech, sr) 30 | 31 | audio_tokens = extract_speech_token( 32 | self.whisper_model, self.feature_extractor, [audio_info] 33 | )[0] 34 | audio_tokens = torch.tensor(audio_tokens).unsqueeze(0) 35 | return audio_tokens 36 | --------------------------------------------------------------------------------
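The `Glm4Tokenizer` above wraps the GLM-4-Voice Whisper-VQ encoder and turns a 16 kHz waveform into a `[1, T]` tensor of discrete semantic token ids. Below is a minimal usage sketch; the checkpoint id `THUDM/glm-4-voice-tokenizer` and the silent placeholder waveform are illustrative assumptions rather than values taken from this repository (only `test_audios/asr_example.wav` ships with the repo).

```python
import torch

from kimia_infer.models.tokenizer.glm4_tokenizer import Glm4Tokenizer

# Assumed checkpoint location: the GLM-4-Voice tokenizer weights, e.g. the
# Hugging Face id "THUDM/glm-4-voice-tokenizer" or a local copy of it;
# substitute whatever path your setup actually uses.
tokenizer = Glm4Tokenizer("THUDM/glm-4-voice-tokenizer").eval()

with torch.no_grad():
    # From a file on disk; librosa resamples it to 16 kHz inside tokenize().
    tokens = tokenizer.tokenize(audio_path="test_audios/asr_example.wav")
    # Or from an in-memory waveform (list or 1-D tensor) already at 16 kHz.
    waveform = torch.zeros(16000)  # one second of silence as a placeholder
    tokens_from_array = tokenizer.tokenize(speech=waveform, sr=16000)

print(tokens.shape)  # -> [1, T] tensor of discrete semantic token ids
```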
/finetune_codes/ds_config_zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "scheduler": { 24 | "type": "WarmupLR", 25 | "params": { 26 | "warmup_min_lr": "auto", 27 | "warmup_max_lr": "auto", 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | 32 | "zero_optimization": { 33 | "stage": 2, 34 | "offload_optimizer": { 35 | "device": "none", 36 | "pin_memory": true 37 | }, 38 | "allgather_partitions": true, 39 | "allgather_bucket_size": 2e8, 40 | "overlap_comm": true, 41 | "reduce_scatter": true, 42 | "reduce_bucket_size": 2e8, 43 | "contiguous_gradients": true 44 | }, 45 | 46 | "gradient_accumulation_steps": "auto", 47 | "gradient_clipping": "auto", 48 | "steps_per_print": 100, 49 | "train_batch_size": "auto", 50 | "train_micro_batch_size_per_gpu": "auto", 51 | "wall_clock_breakdown": false 52 | } -------------------------------------------------------------------------------- /kimia_infer/utils/special_tokens.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class ExtraTokens: 6 | msg_end: int 7 | user_msg_start: int 8 | assistant_msg_start: int 9 | 10 | media_begin: int 11 | media_end: int 12 | 13 | kimia_text_blank: int 14 | kimia_text_eos: int 15 | 16 | kimia_user_msg_start: int 17 | kimia_assistant_msg_start: int 18 | 19 | kimia_speech_ct_id: int 20 | kimia_speech_ctd_id: int 21 | 22 | pad: int 23 | 24 | 25 | def instantiate_extra_tokens(tokenizer): 26 | if hasattr(tokenizer, "special_tokens"): 27 | map_fn = lambda x: tokenizer.special_tokens[x] 28 | elif hasattr(tokenizer, "convert_tokens_to_ids"): 29 | map_fn = lambda x: tokenizer.convert_tokens_to_ids(x) 30 | else: 31 | raise ValueError(f"Invalid tokenizer type: {type(tokenizer)}") 32 | return ExtraTokens( 33 | msg_end=map_fn("<|im_msg_end|>"), # 0 34 | user_msg_start=map_fn("<|im_user_msg_start|>"), # 1 35 | assistant_msg_start=map_fn("<|im_assistant_msg_start|>"), # 2 36 | media_begin=map_fn("<|im_media_begin|>"), # 13 37 | media_end=map_fn("<|im_media_end|>"), # 15 38 | kimia_text_blank=map_fn("<|im_kimia_text_blank|>"), # 18 39 | kimia_text_eos=map_fn("<|im_kimia_text_eos|>"), # 19 40 | kimia_user_msg_start=map_fn("<|im_kimia_user_msg_start|>"), # 22 41 | kimia_assistant_msg_start=map_fn("<|im_kimia_assistant_msg_start|>"), # 23 42 | kimia_speech_ct_id=map_fn("<|im_kimia_speech_ct_id|>"), # 27 43 | kimia_speech_ctd_id=map_fn("<|im_kimia_speech_ctd_id|>"), # 28 44 | pad=tokenizer.pad_id, 45 | ) 46 | -------------------------------------------------------------------------------- /finetune_codes/ds_config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "scheduler": { 
24 | "type": "WarmupLR", 25 | "params": { 26 | "warmup_min_lr": "auto", 27 | "warmup_max_lr": "auto", 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | 32 | "zero_optimization": { 33 | "stage": 3, 34 | "offload_optimizer": { 35 | "device": "none", 36 | "pin_memory": true 37 | }, 38 | "offload_param": { 39 | "device": "none", 40 | "pin_memory": true 41 | }, 42 | "overlap_comm": true, 43 | "contiguous_gradients": true, 44 | "sub_group_size": 1e9, 45 | "reduce_bucket_size": "auto", 46 | "stage3_prefetch_bucket_size": "auto", 47 | "stage3_param_persistence_threshold": "auto", 48 | "stage3_max_live_parameters": 1e9, 49 | "stage3_max_reuse_distance": 1e9, 50 | "stage3_gather_16bit_weights_on_model_save": true 51 | }, 52 | 53 | "gradient_accumulation_steps": "auto", 54 | "gradient_clipping": "auto", 55 | "steps_per_print": 100, 56 | "train_batch_size": "auto", 57 | "train_micro_batch_size_per_gpu": "auto", 58 | "wall_clock_breakdown": false 59 | } -------------------------------------------------------------------------------- /finetune_codes/extract_semantic_codes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | 5 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 6 | from huggingface_hub import snapshot_download 7 | import tqdm 8 | 9 | from kimia_infer.api.prompt_manager import KimiAPromptManager 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--model_name_or_path", type=str, default="moonshotai/Kimi-Audio-7B") 15 | parser.add_argument("--input_file", type=str, required=True) 16 | parser.add_argument("--output_file", type=str, required=True) 17 | args = parser.parse_args() 18 | 19 | if os.path.exists(args.model_name_or_path): 20 | # local path 21 | cache_path = args.model_name_or_path 22 | else: 23 | # cache everything if model_path is a model-id 24 | cache_path = snapshot_download(args.model_name_or_path) 25 | 26 | # load model config 27 | model_config = AutoConfig.from_pretrained(cache_path, trust_remote_code=True) 28 | 29 | prompt_manager = KimiAPromptManager( 30 | model_path=cache_path, kimia_token_offset=model_config.kimia_token_offset, kimia_text_audiodelaytokens=model_config.kimia_mimo_audiodelaytokens 31 | ) 32 | 33 | with open(args.input_file, "r") as f, open(args.output_file, "w") as f_out: 34 | lines = f.readlines() 35 | for line in tqdm.tqdm(lines): 36 | data = json.loads(line) 37 | 38 | for msg in data["conversation"]: 39 | if msg["message_type"] == "audio": 40 | audio_path = msg["content"] 41 | audio_tokens = prompt_manager._tokenize_audio(audio_path) 42 | msg["audio_tokens"] = audio_tokens 43 | 44 | f_out.write(json.dumps(data, ensure_ascii=False) + "\n") 45 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/alias_free_activation/torch/resample.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
3 | 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | from .filter import LowPassFilter1d 7 | from .filter import kaiser_sinc_filter1d 8 | 9 | 10 | class UpSample1d(nn.Module): 11 | def __init__(self, ratio=2, kernel_size=None): 12 | super().__init__() 13 | self.ratio = ratio 14 | self.kernel_size = ( 15 | int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 16 | ) 17 | self.stride = ratio 18 | self.pad = self.kernel_size // ratio - 1 19 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 20 | self.pad_right = ( 21 | self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 22 | ) 23 | filter = kaiser_sinc_filter1d( 24 | cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size 25 | ) 26 | self.register_buffer("filter", filter) 27 | 28 | # x: [B, C, T] 29 | def forward(self, x): 30 | _, C, _ = x.shape 31 | 32 | x = F.pad(x, (self.pad, self.pad), mode="replicate") 33 | x = self.ratio * F.conv_transpose1d( 34 | x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C 35 | ) 36 | x = x[..., self.pad_left : -self.pad_right] 37 | 38 | return x 39 | 40 | 41 | class DownSample1d(nn.Module): 42 | def __init__(self, ratio=2, kernel_size=None): 43 | super().__init__() 44 | self.ratio = ratio 45 | self.kernel_size = ( 46 | int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 47 | ) 48 | self.lowpass = LowPassFilter1d( 49 | cutoff=0.5 / ratio, 50 | half_width=0.6 / ratio, 51 | stride=ratio, 52 | kernel_size=self.kernel_size, 53 | ) 54 | 55 | def forward(self, x): 56 | xx = self.lowpass(x) 57 | 58 | return xx 59 | -------------------------------------------------------------------------------- /finetune_codes/demo_data/audio_understanding/data.jsonl: -------------------------------------------------------------------------------- 1 | {"task_type": "understanding", "conversation": [{"role": "user", "message_type": "text", "content": "Please transcribe the spoken content into written text."}, {"role": "user", "message_type": "audio", "content": "finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0001.flac"}, {"role": "assistant", "message_type": "text", "content": "they are nearly as large as we and relatively much lighter in weight all the periods of physical growth are correspondingly decreased children walk four or five weeks after birth"}]} 2 | {"task_type": "understanding", "conversation": [{"role": "user", "message_type": "text", "content": "Please transcribe the spoken content into written text."}, {"role": "user", "message_type": "audio", "content": "finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0002.flac"}, {"role": "assistant", "message_type": "text", "content": "and are capable of receiving regular instruction at the age of five months strange as it may seem this sphere which for convenience we will call brief"}]} 3 | {"task_type": "understanding", "conversation": [{"role": "user", "message_type": "text", "content": "Please transcribe the spoken content into written text."}, {"role": "user", "message_type": "audio", "content": "finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0004.flac"}, {"role": "assistant", "message_type": "text", "content": "on this world of brief all vegetables mature in periods so short that one marvels when he hears it think of cereals reaching maturity in seven or eight of our days or during one day of brief"}]} 4 | {"task_type": "understanding", "conversation": [{"role": "user", 
"message_type": "text", "content": "Please transcribe the spoken content into written text."}, {"role": "user", "message_type": "audio", "content": "finetune_codes/demo_data/audio_understanding/audios/librispeech_1263-139804-0005.flac"}, {"role": "assistant", "message_type": "text", "content": "early in the morning certain crops are planted and are harvested at night two or more days are required for maturing other crops actually the people of brief raise their crops with less labor than is required amongst us"}]} 5 | -------------------------------------------------------------------------------- /finetune_codes/configuration_moonshot_kimia.py: -------------------------------------------------------------------------------- 1 | from transformers.models.qwen2.configuration_qwen2 import Qwen2Config 2 | 3 | 4 | class KimiAudioConfig(Qwen2Config): 5 | def __init__( 6 | self, 7 | vocab_size=163840, 8 | hidden_size=4096, 9 | intermediate_size=11008, 10 | num_hidden_layers=32, 11 | num_attention_heads=32, 12 | num_key_value_heads=None, 13 | hidden_act="silu", 14 | initializer_range=0.02, 15 | rms_norm_eps=1e-6, 16 | use_cache=True, 17 | rope_theta=10000.0, 18 | rope_scaling=None, 19 | tie_word_embeddings=False, 20 | kimia_mimo_layers: int = 6, 21 | kimia_mimo_audiodelaytokens: int = 5, 22 | kimia_mimo_transformer_from_layer_index: int = 21, 23 | kimia_audio_output_vocab: int = 16896, 24 | kimia_text_output_vocab: int = 152064, 25 | num_audio_special_tokens: int = 512, 26 | num_base_tokens: int = 151643, 27 | kimia_token_offset: int = 152064, 28 | use_whisper_feature: bool = True, 29 | kimia_adaptor_input_dim: int = 5120, 30 | kimia_media_begin: int = 151661, 31 | kimia_media_end: int = 151663, 32 | **kwargs, 33 | ): 34 | super().__init__( 35 | vocab_size=vocab_size, 36 | hidden_size=hidden_size, 37 | intermediate_size=intermediate_size, 38 | num_hidden_layers=num_hidden_layers, 39 | num_attention_heads=num_attention_heads, 40 | num_key_value_heads=num_key_value_heads, 41 | hidden_act=hidden_act, 42 | initializer_range=initializer_range, 43 | rms_norm_eps=rms_norm_eps, 44 | use_cache=use_cache, 45 | tie_word_embeddings=tie_word_embeddings, 46 | rope_theta=rope_theta, 47 | rope_scaling=rope_scaling, 48 | **kwargs, 49 | ) 50 | 51 | self.kimia_mimo_layers = kimia_mimo_layers 52 | self.kimia_mimo_audiodelaytokens = kimia_mimo_audiodelaytokens 53 | # vocab 54 | self.kimia_mimo_transformer_from_layer_index = ( 55 | kimia_mimo_transformer_from_layer_index 56 | ) 57 | self.kimia_audio_output_vocab = kimia_audio_output_vocab 58 | self.kimia_text_output_vocab = kimia_text_output_vocab 59 | self.num_audio_special_tokens = num_audio_special_tokens 60 | self.num_base_tokens = num_base_tokens 61 | self.kimia_token_offset = kimia_token_offset 62 | self.use_whisper_feature = use_whisper_feature 63 | self.kimia_adaptor_input_dim = kimia_adaptor_input_dim 64 | # special tokens 65 | self.kimia_media_begin = kimia_media_begin 66 | self.kimia_media_end = kimia_media_end -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/alias_free_activation/cuda/activation1d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | 4 | import torch 5 | import torch.nn as nn 6 | from ..torch.resample import UpSample1d, DownSample1d 7 | 8 | # load fused CUDA kernel: this enables importing anti_alias_activation_cuda 9 | from . 
import load 10 | 11 | anti_alias_activation_cuda = load.load() 12 | 13 | 14 | class FusedAntiAliasActivation(torch.autograd.Function): 15 | """ 16 | Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs. 17 | The hyperparameters are hard-coded in the kernel to maximize speed. 18 | NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters. 19 | """ 20 | 21 | @staticmethod 22 | def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta): 23 | activation_results = anti_alias_activation_cuda.forward( 24 | inputs, up_ftr, down_ftr, alpha, beta 25 | ) 26 | 27 | return activation_results 28 | 29 | @staticmethod 30 | def backward(ctx, output_grads): 31 | raise NotImplementedError 32 | return output_grads, None, None 33 | 34 | 35 | class Activation1d(nn.Module): 36 | def __init__( 37 | self, 38 | activation, 39 | up_ratio: int = 2, 40 | down_ratio: int = 2, 41 | up_kernel_size: int = 12, 42 | down_kernel_size: int = 12, 43 | fused: bool = True, 44 | ): 45 | super().__init__() 46 | self.up_ratio = up_ratio 47 | self.down_ratio = down_ratio 48 | self.act = activation 49 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 50 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 51 | 52 | self.fused = fused  # Whether to use fused CUDA kernel or not 53 | 54 | def forward(self, x): 55 | if not self.fused: 56 | x = self.upsample(x) 57 | x = self.act(x) 58 | x = self.downsample(x) 59 | return x 60 | else: 61 | if self.act.__class__.__name__ == "Snake": 62 | beta = self.act.alpha.data  # Snake uses same params for alpha and beta 63 | else: 64 | beta = ( 65 | self.act.beta.data 66 | )  # Snakebeta uses different params for alpha and beta 67 | alpha = self.act.alpha.data 68 | if ( 69 | not self.act.alpha_logscale 70 | ):  # Exp baked into cuda kernel, cancel it out with a log 71 | alpha = torch.log(alpha) 72 | beta = torch.log(beta) 73 | 74 | x = FusedAntiAliasActivation.apply( 75 | x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta 76 | ) 77 | return x 78 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/alias_free_activation/cuda/load.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 NVIDIA CORPORATION. 2 | # Licensed under the MIT license. 3 | 4 | import os 5 | import pathlib 6 | import subprocess 7 | 8 | from torch.utils import cpp_extension 9 | 10 | """ 11 | Setting this param to a list has a problem of generating different compilation commands (with a different order of architectures) and leading to recompilation of fused kernels. 12 | Set it to an empty string to avoid recompilation and assign arch flags explicitly in extra_cuda_cflags below 13 | """ 14 | os.environ["TORCH_CUDA_ARCH_LIST"] = "" 15 | 16 | 17 | def load(): 18 | # Check if cuda 11 is installed for compute capability 8.0 19 | cc_flag = [] 20 | _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME) 21 | if int(bare_metal_major) >= 11: 22 | cc_flag.append("-gencode") 23 | cc_flag.append("arch=compute_80,code=sm_80") 24 | 25 | # Build path 26 | srcpath = pathlib.Path(__file__).parent.absolute() 27 | buildpath = srcpath / "build" 28 | _create_build_dir(buildpath) 29 | 30 | # Helper function to build the kernels.
31 | def _cpp_extention_load_helper(name, sources, extra_cuda_flags): 32 | return cpp_extension.load( 33 | name=name, 34 | sources=sources, 35 | build_directory=buildpath, 36 | extra_cflags=[ 37 | "-O3", 38 | ], 39 | extra_cuda_cflags=[ 40 | "-O3", 41 | "-gencode", 42 | "arch=compute_70,code=sm_70", 43 | "--use_fast_math", 44 | ] 45 | + extra_cuda_flags 46 | + cc_flag, 47 | verbose=True, 48 | ) 49 | 50 | extra_cuda_flags = [ 51 | "-U__CUDA_NO_HALF_OPERATORS__", 52 | "-U__CUDA_NO_HALF_CONVERSIONS__", 53 | "--expt-relaxed-constexpr", 54 | "--expt-extended-lambda", 55 | ] 56 | 57 | sources = [ 58 | srcpath / "anti_alias_activation.cpp", 59 | srcpath / "anti_alias_activation_cuda.cu", 60 | ] 61 | anti_alias_activation_cuda = _cpp_extention_load_helper( 62 | "anti_alias_activation_cuda", sources, extra_cuda_flags 63 | ) 64 | 65 | return anti_alias_activation_cuda 66 | 67 | 68 | def _get_cuda_bare_metal_version(cuda_dir): 69 | raw_output = subprocess.check_output( 70 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True 71 | ) 72 | output = raw_output.split() 73 | release_idx = output.index("release") + 1 74 | release = output[release_idx].split(".") 75 | bare_metal_major = release[0] 76 | bare_metal_minor = release[1][0] 77 | 78 | return raw_output, bare_metal_major, bare_metal_minor 79 | 80 | 81 | def _create_build_dir(buildpath): 82 | try: 83 | os.mkdir(buildpath) 84 | except OSError: 85 | if not os.path.isdir(buildpath): 86 | print(f"Creation of the build directory {buildpath} failed") 87 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/flow_matching/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import abstractmethod, ABC 3 | 4 | try: 5 | from torchdyn.core import NeuralODE 6 | 7 | NEURALODE_INSTALLED = True 8 | except ImportError: 9 | NEURALODE_INSTALLED = False 10 | 11 | 12 | class SchedulerBase(ABC): 13 | def __init__(self) -> None: 14 | pass 15 | 16 | @abstractmethod 17 | def set_timesteps(self): 18 | pass 19 | 20 | @abstractmethod 21 | def step(self): 22 | pass 23 | 24 | @abstractmethod 25 | def add_noise(self): 26 | pass 27 | 28 | 29 | class StreamingFlowMatchingScheduler(SchedulerBase): 30 | def __init__( 31 | self, 32 | timesteps=1000, 33 | sigma_min=1e-4, 34 | ) -> None: 35 | super().__init__() 36 | 37 | self.sigma_min = sigma_min 38 | self.timesteps = timesteps 39 | self.t_min = 0 40 | self.t_max = 1 - self.sigma_min 41 | 42 | self.neural_ode = None 43 | 44 | def set_timesteps(self, timesteps=15): 45 | self.timesteps = timesteps 46 | 47 | def step(self, xt, predicted_v): 48 | 49 | h = (self.t_max - self.t_min) / self.timesteps 50 | h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device) 51 | 52 | xt = xt + h * predicted_v 53 | return xt 54 | 55 | def sample(self, ode_wrapper, time_steps, xt, verbose=False, x0=None): 56 | h = (self.t_max - self.t_min) / self.timesteps 57 | h = h * torch.ones(xt.shape[0], dtype=xt.dtype, device=xt.device) 58 | 59 | if verbose: 60 | gt_v = x0 - xt 61 | 62 | for t in time_steps: 63 | predicted_v = ode_wrapper(t, xt) 64 | if verbose: 65 | dist = torch.mean(torch.nn.functional.l1_loss(gt_v, predicted_v)) 66 | print("Time: {}, Distance: {}".format(t, dist)) 67 | xt = xt + h * predicted_v 68 | return xt 69 | 70 | def sample_by_neuralode(self, ode_wrapper, time_steps, xt, verbose=False, x0=None): 71 | if not NEURALODE_INSTALLED: 72 | raise ImportError("NeuralODE is not installed, please install it 
first.") 73 | 74 | if self.neural_ode is None: 75 | self.neural_ode = NeuralODE( 76 | ode_wrapper, 77 | solver="euler", 78 | sensitivity="adjoint", 79 | atol=self.sigma_min, 80 | rtol=self.sigma_min, 81 | ) 82 | 83 | eval_points, traj = self.neural_ode(xt, time_steps) 84 | return traj[-1] 85 | 86 | def add_noise( 87 | self, 88 | original_samples: torch.FloatTensor, 89 | noise: torch.FloatTensor, 90 | timesteps: torch.IntTensor, 91 | ): 92 | ut = original_samples - (1 - self.sigma_min) * noise  # unrelated to the gradient of ut 93 | t_unsqueeze = timesteps.unsqueeze(1).unsqueeze(1).float() / self.timesteps 94 | x_noisy = ( 95 | t_unsqueeze * original_samples 96 | + (1.0 - (1 - self.sigma_min) * t_unsqueeze) * noise 97 | ) 98 | return x_noisy, ut 99 | -------------------------------------------------------------------------------- /finetune_codes/README.md: -------------------------------------------------------------------------------- 1 | # Finetune Kimi-Audio 2 | 3 | ## 1. Data 4 | 5 | We provide the demo data for each SFT task. You can prepare your own data in the same format. 6 | 7 | The data file is a JSONL file; each line is one sample in JSON format. 8 | 9 | 10 | ### Audio Understanding 11 | 12 | The data format is as follows (we use the ASR task as an example): 13 | 14 | ```json 15 | { 16 | "task_type": "understanding", 17 | "conversation": [ 18 | { 19 | "role": "user", 20 | "message_type": "text", 21 | "content": "Please transcribe the spoken content into written text." 22 | }, 23 | { 24 | "role": "user", 25 | "message_type": "audio", 26 | "content": # Audio Path 27 | }, 28 | { 29 | "role": "assistant", 30 | "message_type": "text", 31 | "content": # Transcript 32 | } 33 | ] 34 | } 35 | ``` 36 | 37 | * Librispeech ASR task as an example: 38 | ``` bash 39 | python finetune_codes/demo_data/audio_understanding/prepare_librispeech_asrtask.py --output_dir "output/data/librispeech" 40 | ``` 41 | 42 | Note: The file `finetune_codes/demo_data/audio_understanding/data.jsonl` is the demo data for the Librispeech ASR task, and it does not contain enough data for finetuning. You can prepare your own data in the same format (or run the script `prepare_librispeech_asrtask.py` to generate the data) and put it in the `finetune_codes/demo_data/audio_understanding/` directory. 43 | 44 | ## 2. Finetune 45 | 46 | 1. Download the pretrained model and save it in `output/pretrained_hf`. 47 | 48 | ``` bash 49 | CUDA_VISIBLE_DEVICES=0 python -m finetune_codes.model --model_name "moonshotai/Kimi-Audio-7B" --output_dir "output/pretrained_hf" 50 | ``` 51 | 52 | 2. Preprocess the data and extract the semantic tokens. 53 | ```bash 54 | CUDA_VISIBLE_DEVICES=0 python -m finetune_codes.extract_semantic_codes --input_file "finetune_codes/demo_data/audio_understanding/data.jsonl" --output_file "finetune_codes/demo_data/audio_understanding/data_with_semantic_codes.jsonl" 55 | ``` 56 | 57 | 3. Finetune the model. 58 | 59 | You can use the following command to finetune the model. 60 | 61 | ```bash 62 | bash finetune_codes/finetune_ds.sh \ 63 | --model_path "output/pretrained_hf" \ 64 | --data "finetune_codes/demo_data/audio_understanding/data_with_semantic_codes.jsonl" 65 | ``` 66 | 67 | 4. Convert the finetuned model for inference.
68 | ```bash 69 | CUDA_VISIBLE_DEVICES=0 python -m finetune_codes.model --model_name "moonshotai/Kimi-Audio-7B" \ 70 | --action "export_model" \ 71 | --input_dir "output/kimiaudio_ckpts" \ 72 | --output_dir "output/finetuned_hf_for_inference" 73 | ``` 74 | 75 | You can infer with the finetuned model by running: 76 | ```bash 77 | CUDA_VISIBLE_DEVICES=0 python -m finetune_codes.check_sft_infer 78 | ``` 79 | 80 | # Note 81 | 82 | In this example, we support the ASR task. For other tasks such as speech conversation or text-to-speech, you might need to change the `tokenize_message` function in `finetune_codes/datasets.py`. 83 | 84 | The hyper-parameters in `finetune_codes/finetune_ds.sh` should be tuned for a new task because of differences in task type and dataset size. -------------------------------------------------------------------------------- /finetune_codes/finetune_ds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | DIR=`pwd` 4 | 5 | # Guide: 6 | # This script supports distributed training on multi-gpu workers (as well as single-worker training). 7 | # Please set the options below according to the comments. 8 | # For multi-gpu workers training, these options should be manually set for each worker. 9 | # After setting the options, please run the script on each worker. 10 | 11 | # Number of GPUs per GPU worker 12 | GPUS_PER_NODE=$(python -c 'import torch; print(torch.cuda.device_count())') 13 | 14 | # Number of GPU workers, for single-worker training, please set to 1 15 | NNODES=${NNODES:-1} 16 | 17 | # The rank of this worker, should be in {0, ..., WORKER_CNT-1}, for single-worker training, please set to 0 18 | NODE_RANK=${NODE_RANK:-0} 19 | 20 | # The ip address of the rank-0 worker, for single-worker training, please set to localhost 21 | MASTER_ADDR=${MASTER_ADDR:-localhost} 22 | 23 | # The port for communication 24 | MASTER_PORT=${MASTER_PORT:-6001} 25 | 26 | MODEL="moonshotai/Kimi-Audio-7B" # Set the path if you do not want to load from huggingface directly 27 | 28 | PRETRAINED_MODEL_PATH="" 29 | 30 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. 31 | # See the section for finetuning in README for more information. 32 | DATA="" 33 | 34 | function usage() { 35 | echo ' 36 | Usage: bash finetune/finetune_ds.sh [-m MODEL_PATH] [-d DATA_PATH] 37 | ' 38 | } 39 | 40 | while [[ "$1" != "" ]]; do 41 | case $1 in 42 | -m | --model_path ) 43 | shift 44 | PRETRAINED_MODEL_PATH=$1 45 | ;; 46 | -d | --data ) 47 | shift 48 | DATA=$1 49 | ;; 50 | -h | --help ) 51 | usage 52 | exit 0 53 | ;; 54 | * ) 55 | echo "Unknown argument ${1}" 56 | exit 1 57 | ;; 58 | esac 59 | shift 60 | done 61 | 62 | # check if data exists 63 | if [ ! -f "$DATA" ]; then 64 | echo "Error: DATA file does not exist" 65 | exit 1 66 | fi 67 | 68 | # check if model_path exists 69 | if [ ! 
-d "$PRETRAINED_MODEL_PATH" ]; then 70 | echo "Error: PRETRAINED_MODEL_PATH does not exist" 71 | exit 1 72 | fi 73 | 74 | echo "PRETRAINED_MODEL_PATH: $PRETRAINED_MODEL_PATH" 75 | echo "DATA: $DATA" 76 | 77 | DISTRIBUTED_ARGS=" 78 | --nproc_per_node $GPUS_PER_NODE \ 79 | --nnodes $NNODES \ 80 | --node_rank $NODE_RANK \ 81 | --master_addr $MASTER_ADDR \ 82 | --master_port $MASTER_PORT 83 | " 84 | 85 | echo "start finetune" 86 | echo "DISTRIBUTED_ARGS: $DISTRIBUTED_ARGS" 87 | 88 | torchrun $DISTRIBUTED_ARGS finetune.py \ 89 | --model_name_or_path $MODEL \ 90 | --model_path $PRETRAINED_MODEL_PATH \ 91 | --data_path $DATA \ 92 | --eval_ratio 0.05 \ 93 | --bf16 True \ 94 | --output_dir output/kimiaudio_ckpts \ 95 | --num_train_epochs 5 \ 96 | --per_device_train_batch_size 1 \ 97 | --per_device_eval_batch_size 1 \ 98 | --gradient_accumulation_steps 1 \ 99 | --evaluation_strategy "no" \ 100 | --save_strategy "steps" \ 101 | --save_steps 1000 \ 102 | --save_total_limit 10 \ 103 | --learning_rate 1e-5 \ 104 | --weight_decay 0.1 \ 105 | --adam_beta2 0.95 \ 106 | --warmup_ratio 0.01 \ 107 | --lr_scheduler_type "cosine" \ 108 | --logging_steps 1 \ 109 | --report_to "none" \ 110 | --model_max_length 512 \ 111 | --gradient_checkpointing True \ 112 | --lazy_preprocess True \ 113 | --deepspeed finetune_codes/ds_config_zero3.json -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | from kimia_infer.api.kimia import KimiAudio 2 | import os 3 | import soundfile as sf 4 | import argparse 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--model_path", type=str, default="moonshotai/Kimi-Audio-7B-Instruct") 9 | args = parser.parse_args() 10 | 11 | model = KimiAudio( 12 | model_path=args.model_path, 13 | load_detokenizer=True, 14 | ) 15 | 16 | sampling_params = { 17 | "audio_temperature": 0.8, 18 | "audio_top_k": 10, 19 | "text_temperature": 0.0, 20 | "text_top_k": 5, 21 | "audio_repetition_penalty": 1.0, 22 | "audio_repetition_window_size": 64, 23 | "text_repetition_penalty": 1.0, 24 | "text_repetition_window_size": 16, 25 | } 26 | 27 | messages = [ 28 | {"role": "user", "message_type": "text", "content": "请将音频内容转换为文字。"}, 29 | { 30 | "role": "user", 31 | "message_type": "audio", 32 | "content": "test_audios/asr_example.wav", 33 | }, 34 | ] 35 | 36 | wav, text = model.generate(messages, **sampling_params, output_type="text") 37 | print(">>> output text: ", text) 38 | 39 | output_dir = "test_audios/output" 40 | os.makedirs(output_dir, exist_ok=True) 41 | # audio2audio 42 | messages = [ 43 | { 44 | "role": "user", 45 | "message_type": "audio", 46 | "content": "test_audios/qa_example.wav", 47 | } 48 | ] 49 | 50 | wav, text = model.generate(messages, **sampling_params, output_type="both") 51 | sf.write( 52 | os.path.join(output_dir, "output.wav"), 53 | wav.detach().cpu().view(-1).numpy(), 54 | 24000, 55 | ) 56 | print(">>> output text: ", text) 57 | 58 | 59 | # audio2audio multiturn 60 | messages = [ 61 | { 62 | "role": "user", 63 | "message_type": "audio", 64 | "content": "test_audios/multiturn/case1/multiturn_q1.wav", 65 | }, 66 | { 67 | "role": "assistant", 68 | "message_type": "audio-text", 69 | "content": ["test_audios/multiturn/case1/multiturn_a1.wav", "当然可以,李白的诗很多,比如这句:“床前明月光,疑是地上霜。举头望明月,低头思故乡。"] 70 | }, 71 | { 72 | "role": "user", 73 | "message_type": "audio", 74 | "content": "test_audios/multiturn/case1/multiturn_q2.wav", 75 | } 76 | ] 
77 | wav, text = model.generate(messages, **sampling_params, output_type="both") 78 | sf.write( 79 | os.path.join(output_dir, "case_1_multiturn_a2.wav"), 80 | wav.detach().cpu().view(-1).numpy(), 81 | 24000, 82 | ) 83 | print(">>> output text: ", text) 84 | 85 | 86 | messages = [ 87 | { 88 | "role": "user", 89 | "message_type": "audio", 90 | "content": "test_audios/multiturn/case2/multiturn_q1.wav", 91 | }, 92 | { 93 | "role": "assistant", 94 | "message_type": "audio-text", 95 | "content": ["test_audios/multiturn/case2/multiturn_a1.wav", "当然可以,这很简单。一二三四五六七八九十。"] 96 | }, 97 | { 98 | "role": "user", 99 | "message_type": "audio", 100 | "content": "test_audios/multiturn/case2/multiturn_q2.wav", 101 | } 102 | ] 103 | wav, text = model.generate(messages, **sampling_params, output_type="both") 104 | sf.write( 105 | os.path.join(output_dir, "case_2_multiturn_a2.wav"), 106 | wav.detach().cpu().view(-1).numpy(), 107 | 24000, 108 | ) 109 | print(">>> output text: ", text) 110 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/bigvgan_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | 5 | import librosa 6 | import torch 7 | 8 | from .vocoder.bigvgan import BigVGAN 9 | from .vocoder.utils import get_melspec, AttrDict, load_checkpoint 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class BigVGANWrapper: 15 | def __init__( 16 | self, vocoder: BigVGAN, device: torch.device, h: AttrDict, dtype=None 17 | ) -> None: 18 | self.vocoder = vocoder.to(device) 19 | if dtype is not None: 20 | self.vocoder = self.vocoder.to(dtype) 21 | self.vocoder = self.vocoder.eval() 22 | self.device = device 23 | self.h = h 24 | 25 | def to_dtype(self, dtype): 26 | self.vocoder = self.vocoder.to(dtype) 27 | 28 | def extract_mel_from_wav(self, wav_path=None, wav_data=None): 29 | """ 30 | params: 31 | wav_path: str, path of the wav, should be 24k 32 | wav_data: torch.tensor or numpy array, shape [T], wav data, should be 24k 33 | return: 34 | mel: [T, num_mels], torch.tensor 35 | """ 36 | if wav_data is None: 37 | wav_data, _ = librosa.load(wav_path, sr=self.h["sampling_rate"]) 38 | 39 | wav_data = torch.tensor(wav_data).unsqueeze(0) 40 | 41 | mel = get_melspec( 42 | y=wav_data, 43 | n_fft=self.h["n_fft"], 44 | num_mels=self.h["num_mels"], 45 | sampling_rate=self.h["sampling_rate"], 46 | hop_size=self.h["hop_size"], 47 | win_size=self.h["win_size"], 48 | fmin=self.h["fmin"], 49 | fmax=self.h["fmax"], 50 | ) 51 | return mel.squeeze(0).transpose(0, 1) 52 | 53 | @torch.inference_mode() 54 | def extract_mel_from_wav_batch(self, wav_data): 55 | """ 56 | params: 57 | wav_data: torch.tensor or numpy array, shape [Batch, T], wav data, should be 24k 58 | return: 59 | mel: [Batch, T, num_mels], torch.tensor 60 | """ 61 | 62 | wav_data = torch.tensor(wav_data) 63 | 64 | mel = get_melspec( 65 | y=wav_data, 66 | n_fft=self.h["n_fft"], 67 | num_mels=self.h["num_mels"], 68 | sampling_rate=self.h["sampling_rate"], 69 | hop_size=self.h["hop_size"], 70 | win_size=self.h["win_size"], 71 | fmin=self.h["fmin"], 72 | fmax=self.h["fmax"], 73 | ) 74 | return mel.transpose(1, 2) 75 | 76 | def decode_mel(self, mel): 77 | """ 78 | params: 79 | mel: [T, num_mels], torch.tensor 80 | return: 81 | wav: [1, T], torch.tensor 82 | """ 83 | mel = mel.transpose(0, 1).unsqueeze(0).to(self.device) 84 | wav = self.vocoder(mel) 85 | return wav.squeeze(0) 86 | 87 | def decode_mel_batch(self, mel): 88
| """ 89 | params: 90 | mel: [B, T, num_mels], torch.tensor 91 | return: 92 | wav: [B, 1, T], torch.tensor 93 | """ 94 | mel = mel.transpose(1, 2).to(self.device) 95 | wav = self.vocoder(mel) 96 | return wav 97 | 98 | @classmethod 99 | def from_pretrained(cls, model_config, ckpt_path, device): 100 | with open(model_config) as f: 101 | data = f.read() 102 | json_config = json.loads(data) 103 | h = AttrDict(json_config) 104 | vocoder = BigVGAN(h, True) 105 | state_dict_g = load_checkpoint(ckpt_path, "cpu") 106 | vocoder.load_state_dict(state_dict_g["generator"]) 107 | 108 | logger.info(">>> Load vocoder from {}".format(ckpt_path)) 109 | return cls(vocoder, device, h) 110 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/utils.py: -------------------------------------------------------------------------------- 1 | from librosa.filters import mel as librosa_mel_fn 2 | import torch 3 | import os 4 | 5 | mel_basis_cache = {} 6 | hann_window_cache = {} 7 | 8 | 9 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 10 | return torch.log(torch.clamp(x, min=clip_val) * C) 11 | 12 | 13 | def spectral_normalize_torch(magnitudes): 14 | return dynamic_range_compression_torch(magnitudes) 15 | 16 | 17 | def get_melspec( 18 | y: torch.Tensor, 19 | n_fft: int, 20 | num_mels: int, 21 | sampling_rate: int, 22 | hop_size: int, 23 | win_size: int, 24 | fmin: int, 25 | fmax: int = None, 26 | center: bool = False, 27 | ) -> torch.Tensor: 28 | """ 29 | Calculate the mel spectrogram of an input signal. 30 | This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft). 31 | 32 | Args: 33 | y (torch.Tensor): Input signal. 34 | n_fft (int): FFT size. 35 | num_mels (int): Number of mel bins. 36 | sampling_rate (int): Sampling rate of the input signal. 37 | hop_size (int): Hop size for STFT. 38 | win_size (int): Window size for STFT. 39 | fmin (int): Minimum frequency for mel filterbank. 40 | fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn 41 | center (bool): Whether to pad the input to center the frames. Default is False. 42 | 43 | Returns: 44 | torch.Tensor: Mel spectrogram. 
45 | """ 46 | if torch.min(y) < -1.0: 47 | print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}") 48 | if torch.max(y) > 1.0: 49 | print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}") 50 | 51 | device = y.device 52 | key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}" 53 | 54 | if key not in mel_basis_cache: 55 | mel = librosa_mel_fn( 56 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 57 | ) 58 | mel_basis_cache[key] = torch.from_numpy(mel).float().to(device) 59 | hann_window_cache[key] = torch.hann_window(win_size).to(device) 60 | 61 | mel_basis = mel_basis_cache[key] 62 | hann_window = hann_window_cache[key] 63 | 64 | padding = (n_fft - hop_size) // 2 65 | y = torch.nn.functional.pad( 66 | y.unsqueeze(1), (padding, padding), mode="reflect" 67 | ).squeeze(1) 68 | 69 | spec = torch.stft( 70 | y, 71 | n_fft, 72 | hop_length=hop_size, 73 | win_length=win_size, 74 | window=hann_window, 75 | center=center, 76 | pad_mode="reflect", 77 | normalized=False, 78 | onesided=True, 79 | return_complex=True, 80 | ) 81 | spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9) 82 | 83 | mel_spec = torch.matmul(mel_basis, spec) 84 | mel_spec = spectral_normalize_torch(mel_spec) 85 | 86 | return mel_spec 87 | 88 | 89 | class AttrDict(dict): 90 | def __init__(self, *args, **kwargs): 91 | super(AttrDict, self).__init__(*args, **kwargs) 92 | self.__dict__ = self 93 | 94 | 95 | def load_checkpoint(filepath, device): 96 | assert os.path.isfile(filepath) 97 | print(f"Loading '{filepath}'") 98 | checkpoint_dict = torch.load(filepath, map_location=device, weights_only=True) 99 | print("Complete.") 100 | return checkpoint_dict 101 | 102 | 103 | def init_weights(m, mean=0.0, std=0.01): 104 | classname = m.__class__.__name__ 105 | if classname.find("Conv") != -1: 106 | m.weight.data.normal_(mean, std) 107 | 108 | 109 | def get_padding(kernel_size, dilation=1): 110 | return int((kernel_size * dilation - dilation) / 2) 111 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/alias_free_activation/torch/filter.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import math 8 | 9 | if "sinc" in dir(torch): 10 | sinc = torch.sinc 11 | else: 12 | # This code is adopted from adefossez's julius.core.sinc under the MIT License 13 | # https://adefossez.github.io/julius/julius/core.html 14 | # LICENSE is in incl_licenses directory. 15 | def sinc(x: torch.Tensor): 16 | """ 17 | Implementation of sinc, i.e. sin(pi * x) / (pi * x) 18 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 19 | """ 20 | return torch.where( 21 | x == 0, 22 | torch.tensor(1.0, device=x.device, dtype=x.dtype), 23 | torch.sin(math.pi * x) / math.pi / x, 24 | ) 25 | 26 | 27 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License 28 | # https://adefossez.github.io/julius/julius/lowpass.html 29 | # LICENSE is in incl_licenses directory. 
30 | def kaiser_sinc_filter1d( 31 | cutoff, half_width, kernel_size 32 | ): # return filter [1,1,kernel_size] 33 | even = kernel_size % 2 == 0 34 | half_size = kernel_size // 2 35 | 36 | # For kaiser window 37 | delta_f = 4 * half_width 38 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 39 | if A > 50.0: 40 | beta = 0.1102 * (A - 8.7) 41 | elif A >= 21.0: 42 | beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) 43 | else: 44 | beta = 0.0 45 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) 46 | 47 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio 48 | if even: 49 | time = torch.arange(-half_size, half_size) + 0.5 50 | else: 51 | time = torch.arange(kernel_size) - half_size 52 | if cutoff == 0: 53 | filter_ = torch.zeros_like(time) 54 | else: 55 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) 56 | """ 57 | Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal. 58 | """ 59 | filter_ /= filter_.sum() 60 | filter = filter_.view(1, 1, kernel_size) 61 | 62 | return filter 63 | 64 | 65 | class LowPassFilter1d(nn.Module): 66 | def __init__( 67 | self, 68 | cutoff=0.5, 69 | half_width=0.6, 70 | stride: int = 1, 71 | padding: bool = True, 72 | padding_mode: str = "replicate", 73 | kernel_size: int = 12, 74 | ): 75 | """ 76 | kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible. 77 | """ 78 | super().__init__() 79 | if cutoff < -0.0: 80 | raise ValueError("Minimum cutoff must be larger than zero.") 81 | if cutoff > 0.5: 82 | raise ValueError("A cutoff above 0.5 does not make sense.") 83 | self.kernel_size = kernel_size 84 | self.even = kernel_size % 2 == 0 85 | self.pad_left = kernel_size // 2 - int(self.even) 86 | self.pad_right = kernel_size // 2 87 | self.stride = stride 88 | self.padding = padding 89 | self.padding_mode = padding_mode 90 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) 91 | self.register_buffer("filter", filter) 92 | 93 | # Input [B, C, T] 94 | def forward(self, x): 95 | _, C, _ = x.shape 96 | 97 | if self.padding: 98 | x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) 99 | out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 100 | 101 | return out 102 | -------------------------------------------------------------------------------- /kimia_infer/utils/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class KimiAContent: 5 | def __init__( 6 | self, audio_token_ids=None, text_token_ids=None, is_continuous_mask=None, audio_token_loss_mask=None, text_token_loss_mask=None 7 | ): 8 | self.audio_token_ids: list[int] = audio_token_ids or [] 9 | self.text_token_ids: list[int] = text_token_ids or [] 10 | self.is_continuous_mask: list[int] = is_continuous_mask or [] 11 | 12 | self.audio_token_loss_mask: list[int] = audio_token_loss_mask or [] 13 | self.text_token_loss_mask: list[int] = text_token_loss_mask or [] 14 | 15 | self.continuous_feature = [] 16 | 17 | def audio_append(self, index: int, is_continuous: bool = False, audio_token_loss_mask: bool = False): 18 | self.audio_token_ids.append(index) 19 | self.is_continuous_mask.append(is_continuous) 20 | self.audio_token_loss_mask.append(audio_token_loss_mask) 21 | 22 | def text_append(self, index: int, text_token_loss_mask: bool = False): 23 | self.text_token_ids.append(index) 24 | self.text_token_loss_mask.append(text_token_loss_mask) 25 | 26 | def 
audio_extend(self, ids: list[int], is_continuous: bool = False, audio_token_loss_mask: bool = False): 27 | self.audio_token_ids.extend(ids) 28 | self.is_continuous_mask.extend([is_continuous] * len(ids)) 29 | self.audio_token_loss_mask.extend([audio_token_loss_mask] * len(ids)) 30 | 31 | def text_extend(self, ids: list[int], text_token_loss_mask: bool = False): 32 | self.text_token_ids.extend(ids) 33 | self.text_token_loss_mask.extend([text_token_loss_mask] * len(ids)) 34 | 35 | def audio_prepend(self, index: int, is_continuous: bool = False, audio_token_loss_mask: bool = False): 36 | self.audio_token_ids = [index] + self.audio_token_ids 37 | self.is_continuous_mask = [is_continuous] + self.is_continuous_mask 38 | self.audio_token_loss_mask = [audio_token_loss_mask] + self.audio_token_loss_mask 39 | 40 | def text_prepend(self, index: int, text_token_loss_mask: bool = False): 41 | self.text_token_ids = [index] + self.text_token_ids 42 | self.text_token_loss_mask = [text_token_loss_mask] + self.text_token_loss_mask 43 | 44 | def audio_pretend(self, ids: list[int], is_continuous: bool = False, audio_token_loss_mask: bool = False): 45 | self.audio_token_ids = ids + self.audio_token_ids 46 | self.is_continuous_mask = [is_continuous] * len(ids) + self.is_continuous_mask 47 | self.audio_token_loss_mask = [audio_token_loss_mask] * len(ids) + self.audio_token_loss_mask 48 | 49 | def text_pretend(self, ids: list[int], text_token_loss_mask: bool = False ): 50 | self.text_token_ids = ids + self.text_token_ids 51 | self.text_token_loss_mask = [text_token_loss_mask] * len(ids) + self.text_token_loss_mask 52 | 53 | def merge(self, other: "KimiAContent"): 54 | self.audio_token_ids.extend(other.audio_token_ids) 55 | self.text_token_ids.extend(other.text_token_ids) 56 | self.is_continuous_mask.extend(other.is_continuous_mask) 57 | self.audio_token_loss_mask.extend(other.audio_token_loss_mask) 58 | self.text_token_loss_mask.extend(other.text_token_loss_mask) 59 | self.continuous_feature.extend(other.continuous_feature) 60 | 61 | def to_tensor(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 62 | return ( 63 | torch.tensor([self.audio_token_ids], dtype=torch.long), 64 | torch.tensor([self.text_token_ids], dtype=torch.long), 65 | torch.tensor([self.is_continuous_mask], dtype=torch.bool), 66 | torch.tensor([self.audio_token_loss_mask], dtype=torch.bool), 67 | torch.tensor([self.text_token_loss_mask], dtype=torch.bool), 68 | ) 69 | 70 | def is_valid(self): 71 | return ( 72 | len(self.audio_token_ids) 73 | == len(self.text_token_ids) 74 | == len(self.is_continuous_mask) 75 | == len(self.audio_token_loss_mask) 76 | == len(self.text_token_loss_mask) 77 | ) 78 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python_goose script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
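Referring back to the KimiAContent container in kimia_infer/utils/data.py above (the .gitignore listing continues below): it keeps the audio-token and text-token streams in lockstep, together with the continuity mask and the per-stream loss masks. A minimal sketch using made-up placeholder token ids rather than real vocabulary entries:

    from kimia_infer.utils.data import KimiAContent

    msg = KimiAContent()

    # A text position: a real text token on the text stream, a blank on the audio stream
    # (0 stands in for the kimia_text_blank special token here).
    msg.audio_append(0)
    msg.text_append(2001, text_token_loss_mask=True)

    # Three audio tokens flagged as continuous (backed by Whisper features at model time),
    # with the text stream padded by blanks to stay aligned.
    msg.audio_extend([5, 6, 7], is_continuous=True)
    msg.text_extend([0, 0, 0])

    assert msg.is_valid()  # every stream and every mask must share one length

    # to_tensor() returns five [1, T] tensors: audio ids, text ids, continuity mask,
    # audio loss mask, text loss mask.
    audio_ids, text_ids, cont_mask, audio_loss_mask, text_loss_mask = msg.to_tensor()
    print(audio_ids.shape, text_ids.shape)  # torch.Size([1, 4]) torch.Size([1, 4])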
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python_goose-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .env_* 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | .idea/ 162 | data 163 | logs 164 | *.zip 165 | conf 166 | .DS_Store 167 | .ruff_cache 168 | .log 169 | *.parquet 170 | *.progress 171 | # Vscode 172 | .vscode 173 | 174 | *.safetensors 175 | *.model 176 | *.pt 177 | *.pth 178 | test_audios/output 179 | output/ 180 | output -------------------------------------------------------------------------------- /kimia_infer/models/tokenizer/glm4_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import glob 4 | import math 5 | import tarfile 6 | import torch 7 | import torchaudio 8 | import safetensors 9 | from .glm4.speech_tokenizer.configuration_whisper import WhisperVQConfig 10 | from .glm4.speech_tokenizer.modeling_whisper import WhisperVQEncoder, WhisperVQForConditionalGeneration 11 | from transformers import WhisperFeatureExtractor, WhisperTokenizerFast 12 | 13 | 14 | def load_quantize_encoder(model_path): 15 | config = WhisperVQConfig.from_pretrained(model_path) 16 | config.quantize_encoder_only = True 17 | model = WhisperVQEncoder(config) 18 | state_dict = {} 19 | for path in glob.glob(os.path.join(model_path, "model*.safetensors")): 20 | with safetensors.safe_open(path, framework="pt", device="cpu") as f: 21 | for key in f.keys(): 22 | if key.startswith("model.encoder."): 23 | new_key = key[len("model.encoder."):] 24 | if new_key.startswith("layer_norm"): 25 | continue 26 | if new_key.startswith("layers"): 27 | layer_id = int(new_key.split(".")[1]) 28 | if layer_id >= config.quantize_position: 29 | continue 30 | state_dict[new_key] = f.get_tensor(key) 31 | model.load_state_dict(state_dict) 32 | model.eval() 33 | model.cuda() 34 | return model 35 | 36 | 37 | _resample_buffer: dict[int, torchaudio.transforms.Resample] = {} 38 | 39 | 40 | def extract_speech_token(model: WhisperVQEncoder, feature_extractor: WhisperFeatureExtractor, utts): 41 | dtype = model.conv1.weight.dtype 42 | with torch.no_grad(): 43 | audios, indices = [], [] 44 | for idx, utt in enumerate(utts): 45 | if isinstance(utt, tuple): 46 | audio, sample_rate = utt 47 | else: 48 | audio, sample_rate = torchaudio.load(utt) 49 | audio = audio.to(torch.cuda.current_device()) 50 | if sample_rate != 16000: 51 | if sample_rate not in _resample_buffer: 52 | _resample_buffer[sample_rate] = torchaudio.transforms.Resample( 53 | orig_freq=sample_rate, 54 | new_freq=16000 55 | ).to(torch.cuda.current_device()) 56 | audio = _resample_buffer[sample_rate](audio) 57 | # if audio.shape[0] > 1: 58 | # audio = audio[:1] 59 | audio = audio[0] 60 | audio = audio.cpu().numpy() 61 | time_step = 0 62 | while time_step * 16000 < audio.shape[0]: 63 | audio_segment = audio[time_step * 16000: (time_step + 30) * 16000] 64 | audios.append(audio_segment) 65 | indices.append(idx) 66 | time_step += 30 67 | pooling_kernel_size = model.config.pooling_kernel_size or 1 68 | stride = model.conv1.stride[0] * model.conv2.stride[0] * pooling_kernel_size * feature_extractor.hop_length 69 | all_speech_tokens = [[] for _ in range(len(utts))] 70 | batch_size = 128 71 | for start in range(0, len(audios), batch_size): 72 | features = feature_extractor(audios[start: start + batch_size], sampling_rate=16000, 73 | return_attention_mask=True, return_tensors="pt", device=torch.cuda.current_device(), 74 | padding="longest", pad_to_multiple_of=stride) 75 | features["input_features"] = features["input_features"].to(torch.cuda.current_device()).to(dtype) 76 | features["attention_mask"] = features["attention_mask"].to(torch.cuda.current_device()) 77 | # 
import ipdb; ipdb.set_trace() 78 | outputs = model(**features) 79 | speech_tokens = outputs.quantized_token_ids 80 | attention_mask = features.attention_mask[:, ::model.conv1.stride[0] * model.conv2.stride[0]] 81 | attention_mask = attention_mask[:, ::model.config.pooling_kernel_size] 82 | assert attention_mask.shape == speech_tokens.shape 83 | for i in range(len(speech_tokens)): 84 | idx = indices[start + i] 85 | speech_token = speech_tokens[i][attention_mask[i].bool()].tolist() 86 | all_speech_tokens[idx].extend(speech_token) 87 | return all_speech_tokens 88 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, sin, pow 3 | from torch.nn import Parameter 4 | 5 | 6 | class Snake(nn.Module): 7 | """ 8 | Implementation of a sine-based periodic activation function 9 | Shape: 10 | - Input: (B, C, T) 11 | - Output: (B, C, T), same shape as the input 12 | Parameters: 13 | - alpha - trainable parameter 14 | References: 15 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 16 | https://arxiv.org/abs/2006.08195 17 | Examples: 18 | >>> a1 = snake(256) 19 | >>> x = torch.randn(256) 20 | >>> x = a1(x) 21 | """ 22 | 23 | def __init__( 24 | self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False 25 | ): 26 | """ 27 | Initialization. 28 | INPUT: 29 | - in_features: shape of the input 30 | - alpha: trainable parameter 31 | alpha is initialized to 1 by default, higher values = higher-frequency. 32 | alpha will be trained along with the rest of your model. 33 | """ 34 | super(Snake, self).__init__() 35 | self.in_features = in_features 36 | 37 | # Initialize alpha 38 | self.alpha_logscale = alpha_logscale 39 | if self.alpha_logscale: # Log scale alphas initialized to zeros 40 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 41 | else: # Linear scale alphas initialized to ones 42 | self.alpha = Parameter(torch.ones(in_features) * alpha) 43 | 44 | self.alpha.requires_grad = alpha_trainable 45 | 46 | self.no_div_by_zero = 0.000000001 47 | 48 | def forward(self, x): 49 | """ 50 | Forward pass of the function. 51 | Applies the function to the input elementwise. 52 | Snake ∶= x + 1/a * sin^2 (xa) 53 | """ 54 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T] 55 | if self.alpha_logscale: 56 | alpha = torch.exp(alpha) 57 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 58 | 59 | return x 60 | 61 | 62 | class SnakeBeta(nn.Module): 63 | """ 64 | A modified Snake function which uses separate parameters for the magnitude of the periodic components 65 | Shape: 66 | - Input: (B, C, T) 67 | - Output: (B, C, T), same shape as the input 68 | Parameters: 69 | - alpha - trainable parameter that controls frequency 70 | - beta - trainable parameter that controls magnitude 71 | References: 72 | - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 73 | https://arxiv.org/abs/2006.08195 74 | Examples: 75 | >>> a1 = snakebeta(256) 76 | >>> x = torch.randn(256) 77 | >>> x = a1(x) 78 | """ 79 | 80 | def __init__( 81 | self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False 82 | ): 83 | """ 84 | Initialization. 
85 | INPUT: 86 | - in_features: shape of the input 87 | - alpha - trainable parameter that controls frequency 88 | - beta - trainable parameter that controls magnitude 89 | alpha is initialized to 1 by default, higher values = higher-frequency. 90 | beta is initialized to 1 by default, higher values = higher-magnitude. 91 | alpha will be trained along with the rest of your model. 92 | """ 93 | super(SnakeBeta, self).__init__() 94 | self.in_features = in_features 95 | 96 | # Initialize alpha 97 | self.alpha_logscale = alpha_logscale 98 | if self.alpha_logscale: # Log scale alphas initialized to zeros 99 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 100 | self.beta = Parameter(torch.zeros(in_features) * alpha) 101 | else: # Linear scale alphas initialized to ones 102 | self.alpha = Parameter(torch.ones(in_features) * alpha) 103 | self.beta = Parameter(torch.ones(in_features) * alpha) 104 | 105 | self.alpha.requires_grad = alpha_trainable 106 | self.beta.requires_grad = alpha_trainable 107 | 108 | self.no_div_by_zero = 0.000000001 109 | 110 | def forward(self, x): 111 | """ 112 | Forward pass of the function. 113 | Applies the function to the input elementwise. 114 | SnakeBeta ∶= x + 1/b * sin^2 (xa) 115 | """ 116 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # Line up with x to [B, C, T] 117 | beta = self.beta.unsqueeze(0).unsqueeze(-1) 118 | if self.alpha_logscale: 119 | alpha = torch.exp(alpha) 120 | beta = torch.exp(beta) 121 | x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 122 | 123 | return x 124 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/alias_free_activation/cuda/type_shim.h: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include "compat.h" 19 | 20 | #define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \ 21 | switch (TYPE) \ 22 | { \ 23 | case at::ScalarType::Float: \ 24 | { \ 25 | using scalar_t = float; \ 26 | __VA_ARGS__; \ 27 | break; \ 28 | } \ 29 | case at::ScalarType::Half: \ 30 | { \ 31 | using scalar_t = at::Half; \ 32 | __VA_ARGS__; \ 33 | break; \ 34 | } \ 35 | case at::ScalarType::BFloat16: \ 36 | { \ 37 | using scalar_t = at::BFloat16; \ 38 | __VA_ARGS__; \ 39 | break; \ 40 | } \ 41 | default: \ 42 | AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ 43 | } 44 | 45 | #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) 
\ 46 | switch (TYPEIN) \ 47 | { \ 48 | case at::ScalarType::Float: \ 49 | { \ 50 | using scalar_t_in = float; \ 51 | switch (TYPEOUT) \ 52 | { \ 53 | case at::ScalarType::Float: \ 54 | { \ 55 | using scalar_t_out = float; \ 56 | __VA_ARGS__; \ 57 | break; \ 58 | } \ 59 | case at::ScalarType::Half: \ 60 | { \ 61 | using scalar_t_out = at::Half; \ 62 | __VA_ARGS__; \ 63 | break; \ 64 | } \ 65 | case at::ScalarType::BFloat16: \ 66 | { \ 67 | using scalar_t_out = at::BFloat16; \ 68 | __VA_ARGS__; \ 69 | break; \ 70 | } \ 71 | default: \ 72 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ 73 | } \ 74 | break; \ 75 | } \ 76 | case at::ScalarType::Half: \ 77 | { \ 78 | using scalar_t_in = at::Half; \ 79 | using scalar_t_out = at::Half; \ 80 | __VA_ARGS__; \ 81 | break; \ 82 | } \ 83 | case at::ScalarType::BFloat16: \ 84 | { \ 85 | using scalar_t_in = at::BFloat16; \ 86 | using scalar_t_out = at::BFloat16; \ 87 | __VA_ARGS__; \ 88 | break; \ 89 | } \ 90 | default: \ 91 | AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ 92 | } 93 | -------------------------------------------------------------------------------- /finetune_codes/demo_data/audio_understanding/prepare_librispeech_asrtask.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import json 4 | import argparse 5 | 6 | class LibrispeechtrainDownloader: 7 | """Downloader for Librispeech dataset.""" 8 | 9 | def __init__(self, output_dir: str): 10 | self.output_dir = output_dir 11 | self.metadata = {} 12 | 13 | def download(self) -> bool: 14 | librispeech_dir = os.path.join(self.output_dir, "librispeech") 15 | if not os.path.exists(librispeech_dir): 16 | os.makedirs(librispeech_dir, exist_ok=True) 17 | try: 18 | original_dir = os.getcwd() 19 | os.chdir(librispeech_dir) 20 | 21 | # Download train-clean-100 dataset 22 | download_success = os.system("wget https://us.openslr.org/resources/12/train-clean-100.tar.gz -O train-clean-100.tar.gz") 23 | if download_success != 0: 24 | raise RuntimeError("Failed to download train-clean-100 dataset") 25 | # Download train-clean-360 dataset 26 | download_success = os.system("wget https://us.openslr.org/resources/12/train-clean-360.tar.gz -O train-clean-360.tar.gz") 27 | if download_success != 0: 28 | raise RuntimeError("Failed to download train-clean-360 dataset") 29 | # Download train-other-500 dataset 30 | download_success = os.system("wget https://us.openslr.org/resources/12/train-other-500.tar.gz -O train-other-500.tar.gz") 31 | if download_success != 0: 32 | raise RuntimeError("Failed to download train-other-500 dataset") 33 | 34 | # Extract the tar.gz dataset 35 | extract_success = os.system("tar -xzf train-clean-100.tar.gz") 36 | if extract_success != 0: 37 | raise RuntimeError("Failed to extract train-clean-100 dataset") 38 | extract_success = os.system("tar -xzf train-clean-360.tar.gz") 39 | if extract_success != 0: 40 | raise RuntimeError("Failed to extract train-clean-360 dataset") 41 | extract_success = os.system("tar -xzf train-other-500.tar.gz") 42 | if extract_success != 0: 43 | raise RuntimeError("Failed to extract train-other-500 dataset") 44 | 45 | # Restore original directory 46 | os.chdir(original_dir) 47 | 48 | except Exception as e: 49 | print(f"Error downloading librispeech dataset: {str(e)}") 50 | return False 51 | else: 52 | print("librispeech dataset already downloaded") 53 | 54 | self.metadata["librispeech"] = [] 55 | index = 0 56 | metadata_path = os.path.join(self.output_dir, 
f"librispeech.jsonl") 57 | if os.path.exists(metadata_path): 58 | print(f"Skipping librispeech dataset because it already exists") 59 | return True 60 | 61 | question = "Please transcribe the spoken content into written text." 62 | 63 | subsets = ["train-clean-100", "train-clean-360", "train-other-500"] 64 | 65 | index = 0 66 | for subset in subsets: 67 | subset_dir = os.path.join(self.output_dir, "librispeech/LibriSpeech", subset) 68 | for spk_folder in tqdm.tqdm(os.listdir(subset_dir)): 69 | for chapter_folder in os.listdir(os.path.join(subset_dir, spk_folder)): 70 | # get all the flac files in the chapter_folder 71 | flac_files = [f for f in os.listdir(os.path.join(subset_dir, spk_folder, chapter_folder)) if f.endswith(".flac")] 72 | transcript_path = os.path.join(subset_dir, spk_folder, chapter_folder, f"{spk_folder}-{chapter_folder}.trans.txt") 73 | transcript_dict = {} 74 | with open(transcript_path, 'r', encoding="utf-8") as f: 75 | for line in f: 76 | parts = line.strip().split(" ", 1) 77 | assert len(parts) == 2, f"Invalid line: {line}" 78 | flac_file = parts[0] 79 | transcript = parts[1] 80 | transcript_dict[flac_file] = transcript 81 | for flac_file in flac_files: 82 | audio_path = os.path.join(subset_dir, spk_folder, chapter_folder, flac_file) 83 | transcript = transcript_dict[flac_file.split(".")[0]] 84 | assert os.path.exists(audio_path), f"Audio file {audio_path} does not exist" 85 | 86 | self.metadata["librispeech"].append({ 87 | "task_type": "understanding", 88 | 89 | "conversation": [ 90 | { 91 | "role": "user", 92 | 'message_type': 'text', 93 | "content": question 94 | }, 95 | { 96 | "role": "user", 97 | 'message_type': 'audio', 98 | 'content': audio_path 99 | }, 100 | { 101 | "role": "assistant", 102 | "message_type": "text", 103 | "content": transcript.lower() 104 | } 105 | ] 106 | }) 107 | index += 1 108 | 109 | with open(metadata_path, 'w', encoding="utf-8") as f: 110 | for metadata in self.metadata["librispeech"]: 111 | f.write(json.dumps(metadata, ensure_ascii=False) + '\n') 112 | print(f"Completed processing LibriSpeech dataset. 
Metadata saved to {metadata_path}") 113 | return True 114 | 115 | 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser() 118 | parser.add_argument("--output_dir", type=str, default="output/data/librispeech") 119 | args = parser.parse_args() 120 | downloader = LibrispeechtrainDownloader(output_dir=args.output_dir) 121 | downloader.download() 122 | 123 | 124 | -------------------------------------------------------------------------------- /finetune_codes/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from typing import Optional, List 4 | import shutil 5 | import torch 6 | from transformers import AutoModelForCausalLM 7 | from huggingface_hub import snapshot_download 8 | 9 | from kimia_infer.models.tokenizer.whisper_Lv3.whisper import WhisperEncoder 10 | from .modeling_kimia import MoonshotKimiaForCausalLM 11 | 12 | 13 | class KimiAudioModel(MoonshotKimiaForCausalLM): 14 | def __init__(self, config): 15 | super().__init__(config) 16 | self.whisper_model = WhisperEncoder("openai/whisper-large-v3", mel_batch_size=20, unfreeze_online_whisper_model=True) 17 | 18 | @classmethod 19 | def init_from_pretrained(cls, model_name_or_path, model_load_kwargs): 20 | if os.path.exists(model_name_or_path): 21 | # local path 22 | cache_path = model_name_or_path 23 | else: 24 | # cache everything if model_path is a model-id 25 | cache_path = snapshot_download(model_name_or_path) 26 | 27 | audio_model = AutoModelForCausalLM.from_pretrained( 28 | cache_path, 29 | device_map=None, 30 | torch_dtype=torch.bfloat16, trust_remote_code=True, **model_load_kwargs, 31 | ) 32 | 33 | whisper_model = WhisperEncoder( 34 | os.path.join(cache_path, "whisper-large-v3"), mel_batch_size=20, unfreeze_online_whisper_model=True 35 | ) 36 | kimia_model = cls(audio_model.config) 37 | 38 | # merge audio model and whisper model's state dict 39 | pretrained_state_dict = audio_model.state_dict() 40 | 41 | for n, p in whisper_model.state_dict().items(): 42 | pretrained_state_dict["whisper_model." 
+ n] = p 43 | 44 | kimia_model.load_state_dict(pretrained_state_dict) 45 | 46 | return kimia_model 47 | 48 | @staticmethod 49 | def export_model(input_dir, output_dir): 50 | print("Loading model from {}".format(input_dir)) 51 | kimiaudio = KimiAudioModel.from_pretrained(input_dir) 52 | 53 | print("Saving Kimi-Audio LM to {}".format(output_dir)) 54 | audio_model = MoonshotKimiaForCausalLM(kimiaudio.config) 55 | audio_model_state_dict = {k: v for k, v in kimiaudio.state_dict().items() if not k.startswith("whisper_model")} 56 | audio_model.load_state_dict(audio_model_state_dict) 57 | 58 | audio_model.save_pretrained(output_dir) 59 | 60 | shutil.copyfile("finetune_codes/configuration_moonshot_kimia.py", os.path.join(output_dir, "configuration_moonshot_kimia.py")) 61 | shutil.copyfile("finetune_codes/modeling_kimia.py", os.path.join(output_dir, "modeling_moonshot_kimia.py")) 62 | 63 | from kimia_infer.models.tokenizer.whisper_Lv3.whisper import WhisperModel 64 | 65 | whisper_model = WhisperModel.from_pretrained("openai/whisper-large-v3") 66 | 67 | kimiaudio_whisper_encoder_state_dict = {k.replace("speech_encoder.", "encoder."): v for k, v in kimiaudio.whisper_model.state_dict().items() if k.startswith("speech_encoder")} 68 | 69 | missing_keys, unexpected_keys = whisper_model.load_state_dict(kimiaudio_whisper_encoder_state_dict, strict=False) 70 | assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}" 71 | 72 | for k in missing_keys: 73 | assert k.startswith("decoder"), f"Missing keys: {k}" 74 | 75 | whisper_model.save_pretrained(os.path.join(output_dir, "whisper-large-v3")) 76 | 77 | print("Exported Kimi-Audio LM and Whisper model to {}".format(output_dir)) 78 | 79 | 80 | def forward( 81 | self, 82 | input_ids: torch.LongTensor = None, 83 | text_input_ids: torch.LongTensor = None, 84 | whisper_input_feature: Optional[torch.FloatTensor] = None, 85 | is_continuous_mask: Optional[torch.Tensor] = None, 86 | attention_mask: Optional[torch.Tensor] = None, 87 | position_ids: Optional[torch.LongTensor] = None, 88 | past_key_values: Optional[List[torch.FloatTensor]] = None, 89 | inputs_embeds: Optional[torch.FloatTensor] = None, 90 | labels: Optional[torch.LongTensor] = None, 91 | use_cache: Optional[bool] = None, 92 | output_attentions: Optional[bool] = None, 93 | output_hidden_states: Optional[bool] = None, 94 | generation_mode: Optional[bool] = None, 95 | return_dict: Optional[bool] = None, 96 | ): 97 | whisper_input_feats = torch.from_numpy(whisper_input_feature[0]).unsqueeze(0)[:, :].to(torch.cuda.current_device()) 98 | whisper_feats = self.whisper_model(whisper_input_feats) 99 | whisper_feats = whisper_feats.reshape( 100 | whisper_feats.shape[0], 101 | int(whisper_feats.shape[1] // 4), 102 | whisper_feats.shape[2] * 4, 103 | ) 104 | return super().forward( 105 | input_ids=input_ids, 106 | text_input_ids=text_input_ids, 107 | whisper_input_feature=whisper_feats, 108 | is_continuous_mask=is_continuous_mask, 109 | attention_mask=attention_mask, 110 | position_ids=position_ids, 111 | past_key_values=past_key_values, 112 | inputs_embeds=inputs_embeds, 113 | labels=labels, 114 | use_cache=use_cache, 115 | output_attentions=output_attentions, 116 | output_hidden_states=output_hidden_states, 117 | generation_mode=generation_mode, 118 | return_dict=return_dict, 119 | ) 120 | 121 | 122 | if __name__ == "__main__": 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument("--model_name", type=str, default="moonshotai/Kimi-Audio-7B") 125 | parser.add_argument("--action", type=str, 
choices=["init_from_pretrained", "export_model"], default="init_from_pretrained") 126 | parser.add_argument("--output_dir", type=str, default="output/pretrained_hf") 127 | parser.add_argument("--input_dir", type=str, default="output/finetuned_hf") 128 | args = parser.parse_args() 129 | 130 | if args.action == "init_from_pretrained": 131 | 132 | model = KimiAudioModel.init_from_pretrained(args.model_name, model_load_kwargs={}) 133 | 134 | os.makedirs(args.output_dir, exist_ok=True) 135 | # save model 136 | model.save_pretrained(args.output_dir) 137 | elif args.action == "export_model": 138 | KimiAudioModel.export_model(args.input_dir, args.output_dir) 139 | else: 140 | raise ValueError(f"Invalid action: {args.action}") -------------------------------------------------------------------------------- /kimia_infer/utils/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class KimiASampler: 5 | def __init__( 6 | self, 7 | audio_top_k: int, 8 | audio_temperature: float, 9 | audio_repetition_penalty: float, 10 | audio_repetition_window_size: int, 11 | text_top_k: int, 12 | text_temperature: float, 13 | text_repetition_penalty: float, 14 | text_repetition_window_size: int, 15 | ): 16 | self.audio_top_k = audio_top_k 17 | self.audio_temperature = audio_temperature 18 | self.text_top_k = text_top_k 19 | self.text_temperature = text_temperature 20 | 21 | self.audio_repetition_penalty = audio_repetition_penalty 22 | self.audio_repetition_window_size = audio_repetition_window_size 23 | self.text_repetition_penalty = text_repetition_penalty 24 | self.text_repetition_window_size = text_repetition_window_size 25 | 26 | def sample_audio_logits( 27 | self, logits: torch.Tensor, recent_tokens=None 28 | ) -> torch.Tensor: 29 | """Sample from audio logits with top-k, temperature and repetition penalty. 
30 | 31 | Args: 32 | logits: Logits tensor of shape [batch_size, seq_len, vocab_size] or [batch_size, vocab_size] 33 | recent_tokens: Optional tensor of recent tokens for repetition penalty 34 | 35 | Returns: 36 | Sampled token ids 37 | """ 38 | # Take the last token's logits if we have a sequence dimension 39 | if len(logits.shape) == 3: 40 | logits = logits[:, -1] 41 | 42 | # Apply repetition penalty if needed 43 | if ( 44 | self.audio_repetition_penalty > 1.0 45 | and recent_tokens is not None 46 | and len(recent_tokens) > self.audio_repetition_window_size 47 | ): 48 | logits = logits[0] # Assumes batch size of 1 for repetition penalty 49 | recent_window = recent_tokens[-self.audio_repetition_window_size :].long() 50 | 51 | # Gather scores of recent tokens 52 | scores = torch.gather(logits, dim=0, index=recent_window) 53 | 54 | # Apply penalty: if score < 0 multiply by penalty, otherwise divide by penalty 55 | scores = torch.where( 56 | scores < 0, 57 | scores * self.audio_repetition_penalty, 58 | scores / self.audio_repetition_penalty, 59 | ) 60 | 61 | # Put the penalized scores back 62 | logits.scatter_(dim=0, index=recent_window, src=scores) 63 | logits = logits.unsqueeze(0) # Add batch dimension back 64 | 65 | # Convert to probabilities with softmax 66 | logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) 67 | 68 | # Apply temperature scaling if not greedy 69 | if self.audio_temperature > 1e-6: 70 | logprobs = logprobs / self.audio_temperature 71 | 72 | # Apply top-k sampling 73 | if self.audio_top_k > 0: 74 | # Get probabilities from logprobs 75 | probs = torch.exp(logprobs) 76 | 77 | # Select top-k probabilities and indices 78 | top_k_probs, top_k_indices = torch.topk(probs, self.audio_top_k, dim=-1) 79 | 80 | # Sample from the top-k distribution 81 | sampled_indices = torch.multinomial(top_k_probs, num_samples=1).squeeze( 82 | 1 83 | ) 84 | next_token = top_k_indices.gather( 85 | -1, sampled_indices.unsqueeze(-1) 86 | ).squeeze(-1) 87 | else: 88 | # Sample from the full distribution 89 | next_token = torch.multinomial( 90 | torch.exp(logprobs), num_samples=1 91 | ).squeeze(1) 92 | else: 93 | # Greedy sampling (temperature = 0) 94 | next_token = torch.argmax(logprobs, dim=-1) 95 | 96 | return next_token 97 | 98 | def sample_text_logits( 99 | self, logits: torch.Tensor, recent_tokens=None 100 | ) -> torch.Tensor: 101 | """Sample from text logits with top-k, temperature and repetition penalty. 
102 | 103 | Args: 104 | logits: Logits tensor of shape [batch_size, seq_len, vocab_size] or [batch_size, vocab_size] 105 | recent_tokens: Optional tensor of recent tokens for repetition penalty 106 | 107 | Returns: 108 | Sampled token ids 109 | """ 110 | # Take the last token's logits if we have a sequence dimension 111 | if len(logits.shape) == 3: 112 | logits = logits[:, -1] 113 | 114 | # Apply repetition penalty if needed 115 | if ( 116 | self.text_repetition_penalty > 1.0 117 | and recent_tokens is not None 118 | and len(recent_tokens) > self.text_repetition_window_size 119 | ): 120 | logits = logits[0] # Assumes batch size of 1 for repetition penalty 121 | recent_window = recent_tokens[-self.text_repetition_window_size :].long() 122 | 123 | # Gather scores of recent tokens 124 | scores = torch.gather(logits, dim=0, index=recent_window) 125 | 126 | # Apply penalty: if score < 0 multiply by penalty, otherwise divide by penalty 127 | scores = torch.where( 128 | scores < 0, 129 | scores * self.text_repetition_penalty, 130 | scores / self.text_repetition_penalty, 131 | ) 132 | 133 | # Put the penalized scores back 134 | logits.scatter_(dim=0, index=recent_window, src=scores) 135 | logits = logits.unsqueeze(0) # Add batch dimension back 136 | 137 | # Convert to probabilities with softmax 138 | logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) 139 | 140 | # Apply temperature scaling if not greedy 141 | if self.text_temperature > 1e-6: 142 | logprobs = logprobs / self.text_temperature 143 | 144 | # Apply top-k sampling 145 | if self.text_top_k > 0: 146 | # Get probabilities from logprobs 147 | probs = torch.exp(logprobs) 148 | 149 | # Select top-k probabilities and indices 150 | top_k_probs, top_k_indices = torch.topk(probs, self.text_top_k, dim=-1) 151 | 152 | # Sample from the top-k distribution 153 | sampled_indices = torch.multinomial(top_k_probs, num_samples=1).squeeze( 154 | 1 155 | ) 156 | next_token = top_k_indices.gather( 157 | -1, sampled_indices.unsqueeze(-1) 158 | ).squeeze(-1) 159 | else: 160 | # Sample from the full distribution 161 | next_token = torch.multinomial( 162 | torch.exp(logprobs), num_samples=1 163 | ).squeeze(1) 164 | else: 165 | # Greedy sampling (temperature = 0) 166 | next_token = torch.argmax(logprobs, dim=-1) 167 | 168 | return next_token 169 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | # This code is based on the revised code from fastchat based on tatsu-lab/stanford_alpaca and QwenLM/Qwen. 
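A usage sketch for the KimiASampler defined in kimia_infer/utils/sampler.py above, before finetune.py continues below. The vocabulary size and every hyperparameter value here are assumptions chosen only to make the call shapes concrete:

    import torch
    from kimia_infer.utils.sampler import KimiASampler

    sampler = KimiASampler(
        audio_top_k=10, audio_temperature=0.8,
        audio_repetition_penalty=1.1, audio_repetition_window_size=64,
        text_top_k=5, text_temperature=0.0,        # 0.0 means greedy decoding for text
        text_repetition_penalty=1.1, text_repetition_window_size=16,
    )

    vocab_size = 32000                             # placeholder, not the real vocabulary size
    logits = torch.randn(1, vocab_size)            # last-step logits, [batch, vocab]
    recent = torch.randint(0, vocab_size, (100,))  # previously generated token ids

    audio_token = sampler.sample_audio_logits(logits, recent_tokens=recent)
    text_token = sampler.sample_text_logits(logits, recent_tokens=recent)
    print(audio_token.item(), text_token.item())

Note that the repetition-penalty path assumes a micro batch of one, as the comments inside the sampler itself point out.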
2 | 3 | 4 | from dataclasses import dataclass, field 5 | import json 6 | import logging 7 | import os 8 | from typing import Dict, Optional 9 | 10 | import torch 11 | from deepspeed import zero 12 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 13 | import transformers 14 | from transformers import Trainer, AutoTokenizer 15 | from transformers.integrations import deepspeed 16 | from transformers.trainer_pt_utils import LabelSmoother 17 | from accelerate.utils import DistributedType 18 | from huggingface_hub import snapshot_download 19 | 20 | from finetune_codes.model import KimiAudioModel 21 | from finetune_codes.datasets import LazySupervisedDataset 22 | 23 | logging.basicConfig(level=logging.INFO) 24 | logger = logging.getLogger(__name__) 25 | 26 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 27 | 28 | 29 | @dataclass 30 | class ModelArguments: 31 | model_name_or_path: Optional[str] = field(default="moonshotai/Kimi-Audio-7B") 32 | model_path: str = field( 33 | default=None, metadata={"help": "Path to the pretrained model."} 34 | ) 35 | 36 | @dataclass 37 | class DataArguments: 38 | data_path: str = field( 39 | default=None, metadata={"help": "Path to the training data."} 40 | ) 41 | eval_ratio: float = field( 42 | default=0.05, metadata={"help": "Ratio of evaluation data."} 43 | ) 44 | lazy_preprocess: bool = False 45 | 46 | 47 | @dataclass 48 | class TrainingArguments(transformers.TrainingArguments): 49 | cache_dir: Optional[str] = field(default=None) 50 | optim: str = field(default="adamw_torch") 51 | dataloader_pin_memory: bool = field(default=False) 52 | model_max_length: int = field( 53 | default=8192, 54 | metadata={ 55 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 56 | }, 57 | ) 58 | 59 | 60 | 61 | def maybe_zero_3(param): 62 | if hasattr(param, "ds_id"): 63 | assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE 64 | with zero.GatheredParameters([param]): 65 | param = param.data.detach().cpu().clone() 66 | else: 67 | param = param.detach().cpu().clone() 68 | return param 69 | 70 | 71 | local_rank = None 72 | 73 | def rank0_print(*args): 74 | if local_rank == 0: 75 | print(*args) 76 | 77 | 78 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 79 | """Collects the state dict and dump to disk.""" 80 | # check if zero3 mode enabled 81 | if deepspeed.is_deepspeed_zero3_enabled(): 82 | state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict() 83 | else: 84 | state_dict = trainer.model.state_dict() 85 | if trainer.args.should_save and trainer.args.local_rank == 0: 86 | trainer._save(output_dir, state_dict=state_dict) 87 | 88 | 89 | 90 | 91 | def make_supervised_data_module( 92 | whisper_model, text_tokenizer, data_args, max_len, kimia_token_offset, 93 | ) -> Dict: 94 | """Make dataset and collator for supervised fine-tuning.""" 95 | dataset_cls = LazySupervisedDataset 96 | rank0_print("Loading data...") 97 | 98 | with open(data_args.data_path, "r") as f: 99 | lines = f.readlines() 100 | all_data = [json.loads(line) for line in lines] 101 | 102 | if data_args.eval_ratio > 0: 103 | eval_data = all_data[:int(len(all_data) * data_args.eval_ratio)] 104 | train_data = all_data[int(len(all_data) * data_args.eval_ratio):] 105 | assert len(eval_data) > 0, "No evaluation data found" 106 | assert len(train_data) > 0, "No training data found" 107 | else: 108 | eval_data = None 109 | train_data = all_data 110 | 111 | train_dataset = dataset_cls(train_data, 
whisper_model=whisper_model, text_tokenizer=text_tokenizer, max_len=max_len, kimia_token_offset=kimia_token_offset) 112 | 113 | if eval_data: 114 | eval_dataset = dataset_cls(eval_data, whisper_model=whisper_model, text_tokenizer=text_tokenizer, max_len=max_len, kimia_token_offset=kimia_token_offset) 115 | else: 116 | eval_dataset = None 117 | 118 | return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) 119 | 120 | 121 | def compute_loss(outputs, labels, num_items_in_batch=None): 122 | 123 | audio_logits, text_logits = outputs.logits 124 | 125 | audio_labels, text_labels, audio_loss_mask, text_loss_mask = labels 126 | assert audio_labels.shape[0] == 1, print("we only support micro batch size 1 for demo purpose") 127 | 128 | audio_loss = torch.nn.functional.cross_entropy(audio_logits.view(-1, audio_logits.shape[-1]), audio_labels.view(-1), reduction="none") 129 | text_loss = torch.nn.functional.cross_entropy(text_logits.view(-1, text_logits.shape[-1]), text_labels.view(-1), reduction="none") 130 | 131 | 132 | audio_loss = (audio_loss * audio_loss_mask.view(-1)).sum() / (audio_loss_mask.view(-1).sum() + 1e-4) 133 | text_loss = (text_loss * text_loss_mask.view(-1)).sum() / (text_loss_mask.view(-1).sum() + 1e-4) 134 | loss = audio_loss + text_loss 135 | return loss 136 | 137 | 138 | def train(): 139 | global local_rank 140 | 141 | parser = transformers.HfArgumentParser( 142 | (ModelArguments, DataArguments, TrainingArguments) 143 | ) 144 | ( 145 | model_args, 146 | data_args, 147 | training_args, 148 | ) = parser.parse_args_into_dataclasses() 149 | 150 | # This serves for single-gpu qlora. 151 | if getattr(training_args, 'deepspeed', None) and int(os.environ.get("WORLD_SIZE", 1))==1: 152 | training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED 153 | 154 | local_rank = training_args.local_rank 155 | 156 | model_load_kwargs = { 157 | 'low_cpu_mem_usage': not deepspeed.is_deepspeed_zero3_enabled(), 158 | } 159 | 160 | logger.info(f"Loading kimi-audio main model") 161 | 162 | if os.path.exists(model_args.model_name_or_path): 163 | # local path 164 | cache_path = model_args.model_name_or_path 165 | else: 166 | # cache everything if model_path is a model-id 167 | cache_path = snapshot_download(model_args.model_name_or_path) 168 | 169 | logger.info(f"Looking for resources in {cache_path}") 170 | # check if model_path exists 171 | if not os.path.exists(model_args.model_path): 172 | raise ValueError(f"Model path {model_args.model_path} does not exist") 173 | model = KimiAudioModel.from_pretrained(model_args.model_path, 174 | device_map=None, 175 | **model_load_kwargs) 176 | 177 | text_tokenizer = AutoTokenizer.from_pretrained( 178 | cache_path, trust_remote_code=True 179 | ) 180 | 181 | # Load data 182 | data_module = make_supervised_data_module( 183 | whisper_model=model.whisper_model, text_tokenizer=text_tokenizer, 184 | data_args=data_args, max_len=training_args.model_max_length, kimia_token_offset=model.config.kimia_token_offset 185 | ) 186 | 187 | # Start trainner 188 | trainer = Trainer( 189 | model=model, args=training_args, 190 | compute_loss_func=compute_loss, 191 | data_collator=data_module["train_dataset"].collate_fn, 192 | **data_module 193 | ) 194 | 195 | trainer.train() 196 | trainer.save_state() 197 | 198 | safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) 199 | 200 | if __name__ == "__main__": 201 | train() -------------------------------------------------------------------------------- 
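The compute_loss function above takes a masked mean of the cross-entropy for each stream and sums the two. A standalone toy re-computation with placeholder shapes (micro batch size 1, four positions, tiny vocabularies), mirroring that arithmetic line by line:

    import torch

    audio_logits = torch.randn(1, 4, 16)   # [B, T, audio_vocab]
    text_logits = torch.randn(1, 4, 8)     # [B, T, text_vocab]
    audio_labels = torch.randint(0, 16, (1, 4))
    text_labels = torch.randint(0, 8, (1, 4))
    audio_loss_mask = torch.tensor([[0.0, 1.0, 1.0, 0.0]])   # only masked-in positions count
    text_loss_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])

    audio_loss = torch.nn.functional.cross_entropy(
        audio_logits.view(-1, audio_logits.shape[-1]), audio_labels.view(-1), reduction="none")
    text_loss = torch.nn.functional.cross_entropy(
        text_logits.view(-1, text_logits.shape[-1]), text_labels.view(-1), reduction="none")

    # Per-stream masked mean; the 1e-4 keeps a fully masked stream from dividing by zero.
    audio_loss = (audio_loss * audio_loss_mask.view(-1)).sum() / (audio_loss_mask.view(-1).sum() + 1e-4)
    text_loss = (text_loss * text_loss_mask.view(-1)).sum() / (text_loss_mask.view(-1).sum() + 1e-4)
    loss = audio_loss + text_loss
    print(loss.item())

These masks are the audio_token_loss_mask and text_token_loss_mask streams accumulated by KimiAContent during tokenization.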
/kimia_infer/models/tokenizer/whisper_Lv3/whisper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from subprocess import CalledProcessError, run, Popen, PIPE 6 | import os 7 | from functools import lru_cache 8 | from typing import Optional, Union 9 | from .modeling_whisper import WhisperModel 10 | 11 | # hard-coded audio hyperparameters 12 | SAMPLE_RATE = 16000 13 | N_FFT = 400 14 | N_MELS = 120 15 | HOP_LENGTH = 160 16 | CHUNK_LENGTH = 30 17 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk 18 | 19 | 20 | def load_bytesio_audio(content, sr: int = SAMPLE_RATE): 21 | cmd = [ 22 | "ffmpeg", 23 | "-nostdin", 24 | "-threads", 25 | "0", 26 | "-i", 27 | "pipe:", 28 | "-f", 29 | "s16le", 30 | "-ac", 31 | "1", 32 | "-acodec", 33 | "pcm_s16le", 34 | "-ar", 35 | str(sr), 36 | "pipe:", 37 | ] 38 | p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, bufsize=-1) 39 | out, _ = p.communicate(input=content) 40 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 41 | 42 | 43 | def load_audio(file: str, sr: int = SAMPLE_RATE): 44 | """ 45 | Open an audio file and read as mono waveform, resampling as necessary 46 | 47 | Parameters 48 | ---------- 49 | file: str 50 | The audio file to open 51 | 52 | sr: int 53 | The sample rate to resample the audio if necessary 54 | 55 | Returns 56 | ------- 57 | A NumPy array containing the audio waveform, in float32 dtype. 58 | """ 59 | 60 | # This launches a subprocess to decode audio while down-mixing 61 | # and resampling as necessary. Requires the ffmpeg CLI in PATH. 62 | # fmt: off 63 | cmd = ["ffmpeg", "-nostdin", "-threads", "0", "-i", file, "-f", "s16le", "-ac", "1", "-acodec", "pcm_s16le", "-ar", str(sr), "-"] 64 | # fmt: on 65 | try: 66 | out = run(cmd, capture_output=True, check=True).stdout 67 | except CalledProcessError as e: 68 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 69 | 70 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 71 | 72 | 73 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 74 | """ 75 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 76 | """ 77 | if torch.is_tensor(array): 78 | if array.shape[axis] > length: 79 | array = array.index_select( 80 | dim=axis, index=torch.arange(length, device=array.device) 81 | ) 82 | 83 | if array.shape[axis] < length: 84 | pad_widths = [(0, 0)] * array.ndim 85 | pad_widths[axis] = (0, length - array.shape[axis]) 86 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 87 | else: 88 | if array.shape[axis] > length: 89 | array = array.take(indices=range(length), axis=axis) 90 | 91 | if array.shape[axis] < length: 92 | pad_widths = [(0, 0)] * array.ndim 93 | pad_widths[axis] = (0, length - array.shape[axis]) 94 | array = np.pad(array, pad_widths) 95 | 96 | return array 97 | 98 | 99 | @lru_cache(maxsize=None) 100 | def mel_filters(device, n_mels: int = 128) -> torch.Tensor: 101 | """ 102 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
103 | Allows decoupling librosa dependency; saved using: 104 | 105 | np.savez_compressed( 106 | "mel_filters.npz", 107 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 108 | ) 109 | """ 110 | with np.load( 111 | os.path.join(os.path.dirname(__file__), "mel_filters.npz") # todo 112 | # os.path.join("assets", "mel_filters.npz") 113 | ) as f: 114 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 115 | 116 | 117 | def log_mel_spectrogram( 118 | audio: Union[str, np.ndarray, torch.Tensor], 119 | n_mels: int = 128, 120 | padding: int = 0, 121 | device: Optional[Union[str, torch.device]] = None, 122 | ): 123 | """ 124 | Compute the log-Mel spectrogram of 125 | 126 | Parameters 127 | ---------- 128 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 129 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 130 | 131 | n_mels: int 132 | The number of Mel-frequency filters, only 80 is supported 133 | 134 | padding: int 135 | Number of zero samples to pad to the right 136 | 137 | device: Optional[Union[str, torch.device]] 138 | If given, the audio tensor is moved to this device before STFT 139 | 140 | Returns 141 | ------- 142 | torch.Tensor, shape = (80, n_frames) 143 | A Tensor that contains the Mel spectrogram 144 | """ 145 | if not torch.is_tensor(audio): 146 | if isinstance(audio, str): 147 | audio = load_audio(audio) 148 | audio = torch.from_numpy(audio) 149 | 150 | if device is not None: 151 | audio = audio.to(device) 152 | if padding > 0: 153 | audio = F.pad(audio, (0, padding)) 154 | window = torch.hann_window(N_FFT).to(audio.device) 155 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 156 | magnitudes = stft[..., :-1].abs() ** 2 157 | 158 | filters = mel_filters(audio.device, n_mels) 159 | mel_spec = filters @ magnitudes 160 | 161 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 162 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 163 | log_spec = (log_spec + 4.0) / 4.0 164 | return log_spec 165 | 166 | 167 | class WhisperEncoder(nn.Module): 168 | def __init__( 169 | self, model_path, mel_batch_size=40, unfreeze_online_whisper_model=False 170 | ): 171 | super().__init__() 172 | self.speech_encoder = WhisperModel.from_pretrained(model_path).encoder 173 | self.unfreeze_online_whisper_model = unfreeze_online_whisper_model 174 | if not self.unfreeze_online_whisper_model: 175 | self.speech_encoder.eval() 176 | self.mel_batch_size = mel_batch_size 177 | 178 | def forward(self, audio, kimia_whisper_clip_silence=False): 179 | if isinstance(audio, torch.Tensor): 180 | audio = audio[0] 181 | audio = audio.cpu().numpy() 182 | time_step = 0 183 | audios = [] 184 | while time_step * 16000 < audio.shape[0]: 185 | audio_segment = audio[time_step * 16000 : (time_step + 30) * 16000] 186 | audios.append(audio_segment) 187 | time_step += 30 188 | 189 | final_audio_embedding = [] 190 | 191 | for audio_segment in audios: 192 | # import pdb; pdb.set_trace() 193 | assert audio_segment.shape[0] <= 480000 194 | L = audio_segment.shape[0] 195 | token_len = (L - 1) // ( 196 | 160 * 8 197 | ) + 1 # to match huggingface logic, with use attention mask to control the length and the slice with mask[:, ::160], also match the glm4 12.5 logic 198 | 199 | pad_audio = pad_or_trim(audio_segment.flatten()) 200 | mel = log_mel_spectrogram(pad_audio) # torch.Size([80, 3000]) 201 | assert mel.shape[1] == 3000 202 | if kimia_whisper_clip_silence: 203 | input_seq_lens_list = [token_len * 4] 204 | input_seq_lens = 
torch.LongTensor(input_seq_lens_list).to( 205 | torch.cuda.current_device() 206 | ) 207 | audio_embedding = self.speech_encoder( 208 | mel.unsqueeze(0).to(torch.cuda.current_device()).to(torch.bfloat16), 209 | return_dict=True, 210 | input_seq_lens=input_seq_lens, 211 | ).last_hidden_state 212 | else: 213 | audio_embedding = self.speech_encoder( 214 | mel.unsqueeze(0).to(torch.cuda.current_device()).to(torch.bfloat16), 215 | return_dict=True, 216 | ).last_hidden_state 217 | # audio_embedding: [1, 3000, 1280] 218 | audio_embedding = audio_embedding[:, : token_len * 4, :] 219 | final_audio_embedding.append(audio_embedding) 220 | 221 | final_audio_embedding = torch.cat(final_audio_embedding, dim=1) 222 | return final_audio_embedding 223 | 224 | @torch.no_grad() 225 | def tokenize_waveform(self, audio, kimia_whisper_clip_silence=False): 226 | audio_embedding = self.forward(audio, kimia_whisper_clip_silence) 227 | return audio_embedding.cpu() 228 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/flow_matching/ode_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import lru_cache 4 | import copy 5 | 6 | 7 | @lru_cache(maxsize=1) 8 | def get_cached_zeros(numel, device="cpu", dtype=torch.float32): 9 | return torch.zeros(numel, device=device, dtype=dtype) 10 | 11 | 12 | class StreamingODEWrapperForPrefix(nn.Module): 13 | def __init__( 14 | self, 15 | net, 16 | x_mask, 17 | x_cond, 18 | use_cfg=False, 19 | use_cfg_rescale=True, 20 | cfg_init=1.0, 21 | cfg_scale=4.0, 22 | cfg_schedule="linear", 23 | cfg_token_id=0, 24 | ): 25 | super(StreamingODEWrapperForPrefix, self).__init__() 26 | self.net = net 27 | self.x_mask = x_mask 28 | self.x_cond = x_cond 29 | 30 | assert use_cfg == False, "cfg is not supported in streaming detokenizer" 31 | 32 | self.use_cfg = use_cfg 33 | self.use_cfg_rescale = use_cfg_rescale 34 | self.cfg_init = cfg_init 35 | self.cfg_scale = cfg_scale 36 | self.cfg_token_id = cfg_token_id 37 | self.cfg_schedule = cfg_schedule 38 | self.position_ids = None 39 | self.seq_len = None 40 | 41 | self.incremental_state = {} 42 | self.kv_cache_tokens = 0 43 | self.cu_seqlens = None 44 | self.cu_maxlen = None 45 | 46 | self.cu_seqlens_k = None 47 | self.cu_maxlen_k = None 48 | self.previous_seqlen = None 49 | 50 | def clear_all_states(self): 51 | self.incremental_state = {} 52 | self.kv_cache_tokens = 0 53 | self.cu_seqlens = None 54 | self.cu_maxlen = None 55 | 56 | self.cu_seqlens_k = None 57 | self.cu_maxlen_k = None 58 | self.previous_seqlen = None 59 | 60 | def state_dict(self): 61 | return { 62 | "incremental_state": copy.deepcopy(self.incremental_state), 63 | "kv_cache_tokens": copy.deepcopy(self.kv_cache_tokens), 64 | "cu_seqlens": copy.deepcopy(self.cu_seqlens), 65 | "cu_maxlen": copy.deepcopy(self.cu_maxlen), 66 | "cu_seqlens_k": copy.deepcopy(self.cu_seqlens_k), 67 | "cu_maxlen_k": copy.deepcopy(self.cu_maxlen_k), 68 | "previous_seqlen": copy.deepcopy(self.previous_seqlen), 69 | } 70 | 71 | def load_state_dict(self, state_dict): 72 | self.incremental_state = state_dict["incremental_state"] 73 | self.kv_cache_tokens = state_dict["kv_cache_tokens"] 74 | self.cu_seqlens = state_dict["cu_seqlens"] 75 | self.cu_maxlen = state_dict["cu_maxlen"] 76 | self.cu_seqlens_k = state_dict["cu_seqlens_k"] 77 | self.cu_maxlen_k = state_dict["cu_maxlen_k"] 78 | self.previous_seqlen = state_dict["previous_seqlen"] 79 | 80 | def 
set_conditions(self, x_mask, x_cond, start_position_id, cache={}): 81 | if not self.use_cfg: 82 | self.x_mask = x_mask 83 | self.x_cond = x_cond 84 | else: 85 | self.x_cond = torch.cat((x_cond, x_cond), dim=0) 86 | self.x_mask = torch.cat((x_mask, x_mask), dim=0) 87 | 88 | position_ids_cur = [ 89 | i 90 | for i in range(start_position_id, self.x_cond.shape[1] + start_position_id) 91 | ] 92 | position_ids = torch.tensor([position_ids_cur]) 93 | 94 | if not self.use_cfg: 95 | self.position_ids = position_ids.to(self.x_cond.device).long() 96 | self.seq_len = ( 97 | torch.Tensor([position_ids.shape[1]]).to(self.x_cond.device).long() 98 | ) 99 | else: 100 | self.position_ids = ( 101 | torch.cat((position_ids, position_ids), dim=0) 102 | .to(self.x_cond.device) 103 | .long() 104 | ) 105 | self.seq_len = ( 106 | torch.Tensor([position_ids.shape[1], position_ids.shape[1]]) 107 | .to(self.x_cond.device) 108 | .long() 109 | ) 110 | 111 | cu_seqlens = torch.cumsum(self.seq_len, dim=0) 112 | self.cu_seqlens = torch.cat( 113 | [torch.Tensor([0]).to(cu_seqlens.device), cu_seqlens], dim=0 114 | ).int() 115 | self.cu_maxlen = self.seq_len.cpu().max() 116 | 117 | if self.cu_seqlens_k is None: 118 | self.cu_seqlens_k = self.cu_seqlens 119 | self.cu_maxlen_k = self.cu_maxlen 120 | previous_seqlen = self.seq_len 121 | else: 122 | previous_seqlen_old = cache["previous_seqlen"] 123 | previous_seqlen = previous_seqlen_old + self.seq_len 124 | # calculate cu_seqlens_k 125 | cu_seqlens_k = torch.cumsum(previous_seqlen, dim=0) 126 | self.cu_seqlens_k = torch.cat( 127 | [torch.Tensor([0]).to(cu_seqlens_k.device), cu_seqlens_k], dim=0 128 | ).int() 129 | self.cu_maxlen_k = previous_seqlen.cpu().max() 130 | self.previous_seqlen = previous_seqlen 131 | ret_cache = {"previous_seqlen": previous_seqlen} 132 | return ret_cache 133 | 134 | def update_incremental_state( 135 | self, 136 | reserve_kv_cache_tokens=0, 137 | max_kv_cache_tokens=900, 138 | condition_cache={"previous_seqlen"}, 139 | ): 140 | 141 | assert ( 142 | reserve_kv_cache_tokens <= max_kv_cache_tokens 143 | ), "reserve_kv_cache_tokens should be less than or equal to max_kv_cache_tokens" 144 | 145 | for layer_idx, layer_cache in self.incremental_state.items(): 146 | # update attention kv cache 147 | layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"]["cur_k"] 148 | layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"]["cur_v"] 149 | 150 | self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1] 151 | 152 | if self.kv_cache_tokens > max_kv_cache_tokens: 153 | # drop old tokens from reserve kv cache tokens to max_kv_cache_tokens 154 | reserve_tokens_excludeprompt = ( 155 | max_kv_cache_tokens - reserve_kv_cache_tokens 156 | ) 157 | 158 | if reserve_kv_cache_tokens == 0: 159 | layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"][ 160 | "prev_k" 161 | ][:, -reserve_tokens_excludeprompt:] 162 | layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"][ 163 | "prev_v" 164 | ][:, -reserve_tokens_excludeprompt:] 165 | elif reserve_tokens_excludeprompt == 0: 166 | layer_cache["attn_kvcache"]["prev_k"] = layer_cache["attn_kvcache"][ 167 | "prev_k" 168 | ][:, :reserve_kv_cache_tokens] 169 | layer_cache["attn_kvcache"]["prev_v"] = layer_cache["attn_kvcache"][ 170 | "prev_v" 171 | ][:, :reserve_kv_cache_tokens] 172 | else: 173 | layer_cache["attn_kvcache"]["prev_k"] = torch.cat( 174 | [ 175 | layer_cache["attn_kvcache"]["prev_k"][ 176 | :, :reserve_kv_cache_tokens 177 | ], 178 | 
layer_cache["attn_kvcache"]["prev_k"][ 179 | :, -reserve_tokens_excludeprompt: 180 | ], 181 | ], 182 | dim=1, 183 | ) 184 | 185 | layer_cache["attn_kvcache"]["prev_v"] = torch.cat( 186 | [ 187 | layer_cache["attn_kvcache"]["prev_v"][ 188 | :, :reserve_kv_cache_tokens 189 | ], 190 | layer_cache["attn_kvcache"]["prev_v"][ 191 | :, -reserve_tokens_excludeprompt: 192 | ], 193 | ], 194 | dim=1, 195 | ) 196 | 197 | bsz = layer_cache["attn_kvcache"]["prev_k"].shape[0] 198 | self.previous_seqlen = ( 199 | torch.Tensor( 200 | [ 201 | layer_cache["attn_kvcache"]["prev_k"].shape[1] 202 | for i in range(bsz) 203 | ] 204 | ) 205 | .to(layer_cache["attn_kvcache"]["prev_k"].device) 206 | .long() 207 | ) 208 | condition_cache["previous_seqlen"] = self.previous_seqlen 209 | self.kv_cache_tokens = layer_cache["attn_kvcache"]["prev_k"].shape[1] 210 | 211 | # clear current cache 212 | layer_cache["attn_kvcache"].pop("cur_k") 213 | layer_cache["attn_kvcache"].pop("cur_v") 214 | 215 | def forward(self, t, x, args=None): 216 | # t = torch.tensor([t * 1000] * x.shape[0], device=x.device, dtype=x.dtype).long() 217 | t = ( 218 | get_cached_zeros(x.shape[0], device=x.device, dtype=torch.long) 219 | + (t * 1000).long() 220 | ) 221 | 222 | if self.use_cfg: 223 | raise NotImplementedError("cfg is not supported in streaming detokenizer.") 224 | else: 225 | pred_noise = self.net( 226 | x=x, 227 | condition=self.x_cond, 228 | t=t, 229 | position_ids=self.position_ids, 230 | cu_seqlens=self.cu_seqlens, 231 | cu_maxlen=self.cu_maxlen, 232 | cu_seqlens_k=self.cu_seqlens_k, 233 | cu_maxlen_k=self.cu_maxlen_k, 234 | incremental_state=self.incremental_state, 235 | nopadding=True, 236 | mask=None, 237 | seq_len=None, 238 | ) 239 | return pred_noise 240 | -------------------------------------------------------------------------------- /finetune_codes/datasets.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from functools import lru_cache 3 | import torch 4 | from typing import Dict, List 5 | from kimia_infer.utils.special_tokens import instantiate_extra_tokens 6 | from kimia_infer.utils.data import KimiAContent 7 | import librosa 8 | 9 | class LazySupervisedDataset(Dataset): 10 | """Dataset for supervised fine-tuning.""" 11 | 12 | def __init__(self, raw_data_list, whisper_model, text_tokenizer, max_len: int, kimia_token_offset: int): 13 | super(LazySupervisedDataset, self).__init__() 14 | self.whisper_model = whisper_model 15 | self.max_len = max_len 16 | 17 | print("There are {} samples in the dataset".format(len(raw_data_list))) 18 | self.whisper_model = whisper_model 19 | 20 | print(f"Loading text tokenizer") 21 | self.text_tokenizer = text_tokenizer 22 | 23 | self.extra_tokens = instantiate_extra_tokens(self.text_tokenizer) 24 | 25 | self.pad_token = self.extra_tokens.pad 26 | self.kimia_token_offset = kimia_token_offset 27 | self.raw_data = raw_data_list 28 | 29 | self.cached_data_dict = {} 30 | 31 | def __len__(self): 32 | return len(self.raw_data) 33 | 34 | def extract_whisper_feat(self, wav: str): 35 | wav = librosa.load(wav, sr=16000)[0] 36 | # if isinstance(wav, str): 37 | # wav = librosa.load(wav, sr=16000)[0] 38 | 39 | # wav_tensor = torch.tensor(wav).unsqueeze(0)[:, :] 40 | # elif isinstance(wav, torch.Tensor): 41 | # wav_tensor = wav 42 | # else: 43 | # raise ValueError(f"Invalid wav type: {type(wav)}") 44 | 45 | # wav_tensor = wav_tensor.to(torch.cuda.current_device()) 46 | # continous_feature = self.whisper_model(wav_tensor) 47 | # 
continous_feature = continous_feature.reshape( 48 | # continous_feature.shape[0], 49 | # int(continous_feature.shape[1] // 4), 50 | # continous_feature.shape[2] * 4, 51 | # ) 52 | return wav 53 | 54 | def _tokenize_text(self, text): 55 | if text is None: 56 | return None 57 | token_ids = self.text_tokenizer.encode(text, bos=False, eos=False) 58 | return token_ids 59 | 60 | def tokenize_message( 61 | self, 62 | message, 63 | tokenize_role=True, 64 | has_ct_token=False, 65 | has_msg_end_token=False, 66 | extract_whisper_feature=False, 67 | output_type: str = "text", 68 | ): 69 | kimia_content_msg = KimiAContent() 70 | 71 | role = message["role"] 72 | 73 | has_loss = role == "assistant" 74 | 75 | if tokenize_role: 76 | if role == "user": 77 | kimia_content_msg.audio_append(self.extra_tokens.kimia_user_msg_start) 78 | kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank) 79 | elif role == "assistant": 80 | kimia_content_msg.audio_append( 81 | self.extra_tokens.kimia_assistant_msg_start 82 | ) 83 | kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank) 84 | else: 85 | raise NotImplementedError(f"role: {role}") 86 | 87 | if message["message_type"] == "text": 88 | text = message["content"] 89 | text_tokens = self._tokenize_text(text) 90 | 91 | kimia_content_msg.text_extend(text_tokens, has_loss) 92 | kimia_content_msg.audio_extend( 93 | [self.extra_tokens.kimia_text_blank] * len(text_tokens) 94 | ) 95 | 96 | if role == "assistant": 97 | kimia_content_msg.text_append(self.extra_tokens.kimia_text_eos, has_loss) # eos for text stream 98 | kimia_content_msg.audio_append(self.extra_tokens.kimia_text_blank, audio_token_loss_mask=False) 99 | 100 | elif message["message_type"] == "audio": 101 | speech_tokens = message["audio_tokens"] 102 | 103 | kimia_content_msg.audio_append(self.extra_tokens.media_begin) 104 | kimia_content_msg.audio_extend(speech_tokens, is_continuous=True, audio_token_loss_mask=has_loss) 105 | kimia_content_msg.audio_append(self.extra_tokens.media_end, audio_token_loss_mask=has_loss) # EOS for audio stream 106 | kimia_content_msg.text_extend( 107 | [self.extra_tokens.kimia_text_blank] * (len(speech_tokens) + 2) 108 | ) 109 | 110 | if has_ct_token: 111 | if output_type == "text": 112 | kimia_content_msg.audio_append(self.extra_tokens.kimia_speech_ct_id) 113 | else: 114 | kimia_content_msg.audio_append( 115 | self.extra_tokens.kimia_speech_ctd_id 116 | ) 117 | kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank) 118 | 119 | if extract_whisper_feature: 120 | whisper_feature = self.extract_whisper_feat(message["content"]) 121 | kimia_content_msg.continuous_feature.append(whisper_feature) 122 | elif message["message_type"] == None: 123 | pass 124 | else: 125 | raise NotImplementedError(f"message_type: {message['message_type']}") 126 | 127 | if has_msg_end_token: 128 | kimia_content_msg.audio_append(self.extra_tokens.msg_end, audio_token_loss_mask=False) 129 | kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank) 130 | 131 | assert ( 132 | kimia_content_msg.is_valid() 133 | ), f"kimia_content_msg is not valid: {kimia_content_msg}" 134 | 135 | return kimia_content_msg 136 | 137 | def tokenize_conversation( 138 | self, messages: List[Dict], output_type: str = "text", add_assistant_start_msg: bool = True 139 | ) -> KimiAContent: 140 | """ 141 | messages: List[Dict] 142 | messages[i] = { 143 | "role": "user" | "assistant" | "system", 144 | "content": str 145 | } 146 | """ 147 | assert output_type in ["text", "both"] 148 | 149 | msgs: 
List[KimiAContent] = [] 150 | tokenize_role = True 151 | has_ct_token = False 152 | has_msg_end_token = False 153 | 154 | previous_role = None 155 | for msg_idx, message in enumerate(messages): 156 | assert message["role"] in ["user", "assistant"] 157 | 158 | if previous_role is None: 159 | tokenize_role = True 160 | else: 161 | if message["role"] == previous_role: 162 | tokenize_role = False 163 | else: 164 | tokenize_role = True 165 | 166 | if msg_idx == len(messages) - 1: 167 | has_ct_token = True 168 | has_msg_end_token = True 169 | else: 170 | if messages[msg_idx + 1]["role"] != message["role"]: 171 | has_ct_token = True 172 | has_msg_end_token = True 173 | else: 174 | has_ct_token = False 175 | has_msg_end_token = False 176 | 177 | previous_role = message["role"] 178 | 179 | msg = self.tokenize_message( 180 | message=message, 181 | tokenize_role=tokenize_role, 182 | has_ct_token=has_ct_token, 183 | has_msg_end_token=has_msg_end_token, 184 | extract_whisper_feature=True, 185 | output_type=output_type, 186 | ) 187 | msgs.append(msg) 188 | 189 | if add_assistant_start_msg: 190 | assistant_start_msg = self.tokenize_message( 191 | message={ 192 | "role": "assistant", 193 | "message_type": None, 194 | }, 195 | tokenize_role=True, 196 | has_ct_token=False, 197 | has_msg_end_token=False, 198 | ) 199 | 200 | msgs.append(assistant_start_msg) 201 | 202 | ret_msg = msgs[0] 203 | 204 | for msg in msgs[1:]: 205 | ret_msg.merge(msg) 206 | 207 | return ret_msg 208 | 209 | @lru_cache(maxsize=None) 210 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 211 | 212 | task_type = self.raw_data[i]["task_type"] 213 | conversation = self.raw_data[i]["conversation"] 214 | 215 | output_type = "text" if task_type == "understanding" else "both" 216 | 217 | tokenized_conversation = self.tokenize_conversation(conversation, output_type=output_type, add_assistant_start_msg=False) 218 | 219 | audio_input_ids, text_input_ids, is_continuous_mask, audio_token_loss_mask, text_token_loss_mask = tokenized_conversation.to_tensor() 220 | 221 | audio_features = tokenized_conversation.continuous_feature 222 | 223 | audio_labels = torch.cat((audio_input_ids[:, 1:], audio_input_ids.new_full((1, 1), self.pad_token)), dim=1) 224 | text_labels = torch.cat((text_input_ids[:, 1:], text_input_ids.new_full((1, 1), self.pad_token)), dim=1) 225 | audio_loss_mask = torch.cat((audio_token_loss_mask[:, 1:], audio_token_loss_mask.new_full((1, 1), False)), dim=1) 226 | text_loss_mask = torch.cat((text_token_loss_mask[:, 1:], text_token_loss_mask.new_full((1, 1), False)), dim=1) 227 | 228 | ret = dict( 229 | input_ids=audio_input_ids, 230 | text_input_ids=text_input_ids, 231 | whisper_input_feature=audio_features, 232 | is_continuous_mask=is_continuous_mask, 233 | labels=( 234 | audio_labels, 235 | text_labels, 236 | audio_loss_mask, 237 | text_loss_mask, 238 | ), 239 | ) 240 | 241 | return ret 242 | 243 | @staticmethod 244 | def collate_fn(batch): 245 | assert len(batch) == 1, "micro batch size is 1 for demo" 246 | 247 | return batch[0] 248 | 249 | 250 | -------------------------------------------------------------------------------- /kimia_infer/api/prompt_manager.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | import os 3 | 4 | import librosa 5 | import torch 6 | from loguru import logger 7 | from transformers import AutoTokenizer 8 | 9 | 10 | from kimia_infer.models.tokenizer.whisper_Lv3.whisper import WhisperEncoder 11 | from 
kimia_infer.models.tokenizer.glm4_tokenizer import Glm4Tokenizer 12 | from kimia_infer.utils.data import KimiAContent 13 | from kimia_infer.utils.special_tokens import instantiate_extra_tokens 14 | 15 | class KimiAPromptManager: 16 | def __init__(self, model_path: str, kimia_token_offset: int, kimia_text_audiodelaytokens: int): 17 | self.audio_tokenizer = Glm4Tokenizer("THUDM/glm-4-voice-tokenizer") 18 | self.audio_tokenizer = self.audio_tokenizer.to(torch.cuda.current_device()) 19 | 20 | logger.info(f"Looking for resources in {model_path}") 21 | logger.info(f"Loading whisper model") 22 | 23 | self.whisper_model = WhisperEncoder( 24 | os.path.join(model_path, "whisper-large-v3"), mel_batch_size=20 25 | ) 26 | self.whisper_model = self.whisper_model.to(torch.cuda.current_device()) 27 | self.whisper_model = self.whisper_model.bfloat16() 28 | self.whisper_model.eval() 29 | 30 | logger.info(f"Loading text tokenizer") 31 | if os.path.exists(model_path) and os.path.exists(os.path.join(model_path, "tokenizer_config.json")): 32 | self.text_tokenizer = AutoTokenizer.from_pretrained( 33 | model_path, trust_remote_code=True 34 | ) 35 | else: 36 | logger.info(f"Can not find text tokenizer in {model_path}, Loading default text tokenizer from moonshotai/Kimi-Audio-7B-Instruct") 37 | self.text_tokenizer = AutoTokenizer.from_pretrained( 38 | "moonshotai/Kimi-Audio-7B-Instruct", trust_remote_code=True 39 | ) 40 | 41 | self.extra_tokens = instantiate_extra_tokens(self.text_tokenizer) 42 | 43 | self.kimia_text_audiodelaytokens = kimia_text_audiodelaytokens 44 | 45 | self.kimia_token_offset = kimia_token_offset 46 | 47 | def _tokenize_text(self, text): 48 | if text is None: 49 | return None 50 | token_ids = self.text_tokenizer.encode(text, bos=False, eos=False) 51 | return token_ids 52 | 53 | def _tokenize_audio(self, wav_path): 54 | wav_tokens = self.audio_tokenizer.tokenize(audio_path=wav_path) 55 | wav_tokens = wav_tokens + self.kimia_token_offset 56 | wav_tokens_list = wav_tokens.squeeze(0).cpu().numpy().tolist() 57 | return wav_tokens_list 58 | 59 | def extract_whisper_feat(self, wav: torch.Tensor | str): 60 | if isinstance(wav, str): 61 | wav = librosa.load(wav, sr=16000)[0] 62 | 63 | wav_tensor = torch.tensor(wav).unsqueeze(0)[:, :] 64 | elif isinstance(wav, torch.Tensor): 65 | wav_tensor = wav 66 | else: 67 | raise ValueError(f"Invalid wav type: {type(wav)}") 68 | assert self.whisper_model is not None 69 | wav_tensor = wav_tensor.to(torch.cuda.current_device()) 70 | continous_feature = self.whisper_model.tokenize_waveform(wav_tensor) 71 | continous_feature = continous_feature.reshape( 72 | continous_feature.shape[0], 73 | int(continous_feature.shape[1] // 4), 74 | continous_feature.shape[2] * 4, 75 | ) 76 | return continous_feature 77 | 78 | def tokenize_message( 79 | self, 80 | message, 81 | tokenize_role=True, 82 | has_ct_token=False, 83 | has_msg_end_token=False, 84 | extract_whisper_feature=False, 85 | output_type: str = "text", 86 | ): 87 | kimia_content_msg = KimiAContent() 88 | 89 | role = message["role"] 90 | 91 | has_loss = role == "assistant" 92 | 93 | if tokenize_role: 94 | if role == "user": 95 | kimia_content_msg.audio_append(self.extra_tokens.kimia_user_msg_start) 96 | kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank) 97 | elif role == "assistant": 98 | kimia_content_msg.audio_append( 99 | self.extra_tokens.kimia_assistant_msg_start 100 | ) 101 | kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank) 102 | else: 103 | raise NotImplementedError(f"role: {role}") 
104 | 105 | if message["message_type"] == "text": 106 | text = message["content"] 107 | text_tokens = self._tokenize_text(text) 108 | 109 | kimia_content_msg.text_extend(text_tokens, has_loss) 110 | kimia_content_msg.audio_extend( 111 | [self.extra_tokens.kimia_text_blank] * len(text_tokens) 112 | ) 113 | 114 | if role == "assistant": 115 | kimia_content_msg.text_append(self.extra_tokens.kimia_text_eos, has_loss) # eos for text stream 116 | kimia_content_msg.audio_append(self.extra_tokens.kimia_text_blank, audio_token_loss_mask=False) 117 | 118 | elif message["message_type"] == "audio": 119 | if "audio_tokens" in message: 120 | speech_tokens = message["audio_tokens"] 121 | else: 122 | audio_path = message["content"] 123 | speech_tokens = self._tokenize_audio(audio_path) 124 | 125 | kimia_content_msg.audio_append(self.extra_tokens.media_begin) 126 | kimia_content_msg.audio_extend(speech_tokens, is_continuous=True, audio_token_loss_mask=has_loss) 127 | kimia_content_msg.audio_append(self.extra_tokens.media_end, audio_token_loss_mask=has_loss) # EOS for audio stream 128 | kimia_content_msg.text_extend( 129 | [self.extra_tokens.kimia_text_blank] * (len(speech_tokens) + 2) 130 | ) 131 | 132 | if has_ct_token: 133 | if output_type == "text": 134 | kimia_content_msg.audio_append(self.extra_tokens.kimia_speech_ct_id) 135 | else: 136 | kimia_content_msg.audio_append( 137 | self.extra_tokens.kimia_speech_ctd_id 138 | ) 139 | kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank) 140 | 141 | if extract_whisper_feature: 142 | whisper_feature = self.extract_whisper_feat(audio_path) 143 | kimia_content_msg.continuous_feature.append(whisper_feature) 144 | elif message["message_type"] == "audio-text": 145 | audio_path, text = message["content"] 146 | speech_tokens = self._tokenize_audio(audio_path) 147 | text_tokens = self._tokenize_text(text) 148 | 149 | kimia_content_msg.audio_extend([self.extra_tokens.kimia_text_blank] * self.kimia_text_audiodelaytokens) 150 | kimia_content_msg.audio_extend(speech_tokens, is_continuous=False) 151 | kimia_content_msg.text_extend(text_tokens) 152 | text_pad_tokens = (self.kimia_text_audiodelaytokens + len(speech_tokens) - len(text_tokens)) * [self.extra_tokens.kimia_text_blank] 153 | kimia_content_msg.text_extend(text_pad_tokens) 154 | 155 | elif message["message_type"] == None: 156 | pass 157 | else: 158 | raise NotImplementedError(f"message_type: {message['message_type']}") 159 | 160 | if has_msg_end_token: 161 | kimia_content_msg.audio_append(self.extra_tokens.msg_end, audio_token_loss_mask=False) 162 | kimia_content_msg.text_append(self.extra_tokens.kimia_text_blank) 163 | 164 | assert ( 165 | kimia_content_msg.is_valid() 166 | ), f"kimia_content_msg is not valid: {kimia_content_msg}" 167 | 168 | return kimia_content_msg 169 | 170 | def get_prompt( 171 | self, messages: List[Dict], output_type: str = "text", add_assistant_start_msg: bool = True 172 | ) -> KimiAContent: 173 | """ 174 | messages: List[Dict] 175 | messages[i] = { 176 | "role": "user" | "assistant" | "system", 177 | "content": str 178 | } 179 | """ 180 | assert output_type in ["text", "both"] 181 | 182 | msgs: List[KimiAContent] = [] 183 | tokenize_role = True 184 | has_ct_token = False 185 | has_msg_end_token = False 186 | 187 | previous_role = None 188 | for msg_idx, message in enumerate(messages): 189 | assert message["role"] in ["user", "assistant"] 190 | 191 | if previous_role is None: 192 | tokenize_role = True 193 | else: 194 | if message["role"] == previous_role: 195 | tokenize_role = 
False 196 | else: 197 | tokenize_role = True 198 | 199 | if msg_idx == len(messages) - 1: 200 | has_ct_token = True 201 | has_msg_end_token = True 202 | else: 203 | if messages[msg_idx + 1]["role"] != message["role"]: 204 | has_ct_token = True 205 | has_msg_end_token = True 206 | else: 207 | has_ct_token = False 208 | has_msg_end_token = False 209 | 210 | previous_role = message["role"] 211 | 212 | msg = self.tokenize_message( 213 | message=message, 214 | tokenize_role=tokenize_role, 215 | has_ct_token=has_ct_token, 216 | has_msg_end_token=has_msg_end_token, 217 | extract_whisper_feature=True, 218 | output_type=output_type, 219 | ) 220 | msgs.append(msg) 221 | 222 | if add_assistant_start_msg: 223 | assistant_start_msg = self.tokenize_message( 224 | message={ 225 | "role": "assistant", 226 | "message_type": None, 227 | }, 228 | tokenize_role=True, 229 | has_ct_token=False, 230 | has_msg_end_token=False, 231 | ) 232 | 233 | msgs.append(assistant_start_msg) 234 | 235 | ret_msg = msgs[0] 236 | 237 | for msg in msgs[1:]: 238 | ret_msg.merge(msg) 239 | 240 | return ret_msg 241 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/flow_matching/dit_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from flash_attn import flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func 10 | 11 | 12 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 13 | # x shape: bsz, seqlen, self.n_local_heads, self.head_hidden_dim / 2 14 | # the last shape is "self.hidden_dim / 2" because we convert to complex 15 | assert x.ndim == 4 16 | assert freqs_cis.shape == ( 17 | x.shape[0], 18 | x.shape[1], 19 | x.shape[-1], 20 | ), f"x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}" 21 | 22 | # reshape freq cis to match and apply pointwise multiply 23 | # new shape: bsz, seq_len, 1, self.head_hidden_dim / 2 24 | shape = [x.shape[0], x.shape[1], 1, x.shape[-1]] 25 | return freqs_cis.view(*shape) 26 | 27 | 28 | def apply_rotary_emb( 29 | xq: torch.Tensor, 30 | xk: torch.Tensor, 31 | freqs_cis: torch.Tensor, 32 | ): 33 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 34 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 35 | 36 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 37 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 38 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 39 | return xq_out.type_as(xq), xk_out.type_as(xk) 40 | 41 | 42 | class Attention(nn.Module): 43 | 44 | def __init__( 45 | self, 46 | dim: int, 47 | num_heads: int = 8, 48 | qkv_bias: bool = False, 49 | qk_norm: bool = False, 50 | attn_drop: float = 0.0, 51 | proj_drop: float = 0.0, 52 | norm_layer: nn.Module = nn.LayerNorm, 53 | flash_attention: bool = True, 54 | ) -> None: 55 | super().__init__() 56 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 57 | self.num_heads = num_heads 58 | self.head_dim = dim // num_heads 59 | self.scale = self.head_dim**-0.5 60 | self.fused_attn = flash_attention 61 | 62 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 63 | self.qk_norm = qk_norm 64 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 65 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 66 | self.attn_drop = nn.Dropout(attn_drop) 67 | self.proj = nn.Linear(dim, dim) 
68 | self.proj_drop = nn.Dropout(proj_drop) 69 | 70 | def forward( 71 | self, 72 | x: torch.Tensor, 73 | seq_len, 74 | cu_seqlens, 75 | max_seqlen, 76 | cu_seqlens_k, 77 | max_seqlen_k, 78 | rotary_pos_emb=None, 79 | incremental_state=None, 80 | nopadding=True, 81 | ) -> torch.Tensor: 82 | B, N, C = x.shape 83 | 84 | if self.fused_attn: 85 | if nopadding: 86 | qkv = self.qkv(x) 87 | qkv = qkv.view(B * N, self.num_heads * 3, self.head_dim) 88 | q, k, v = qkv.split([self.num_heads] * 3, dim=1) 89 | q, k = self.q_norm(q), self.k_norm(k) 90 | 91 | q = q.view(B, N, self.num_heads, self.head_dim) 92 | k = k.view(B, N, self.num_heads, self.head_dim) 93 | v = v.view(B, N, self.num_heads, self.head_dim) 94 | 95 | if rotary_pos_emb is not None: 96 | q, k = apply_rotary_emb(q, k, rotary_pos_emb) 97 | 98 | if incremental_state is not None: 99 | if "prev_k" in incremental_state: 100 | prev_k = incremental_state["prev_k"] 101 | k = torch.cat([prev_k, k], dim=1) 102 | 103 | if "cur_k" not in incremental_state: 104 | incremental_state["cur_k"] = {} 105 | incremental_state["cur_k"] = k 106 | 107 | if "prev_v" in incremental_state: 108 | prev_v = incremental_state["prev_v"] 109 | v = torch.cat([prev_v, v], dim=1) 110 | 111 | if "cur_v" not in incremental_state: 112 | incremental_state["cur_v"] = {} 113 | incremental_state["cur_v"] = v 114 | 115 | q = q.view(B * N, self.num_heads, self.head_dim) 116 | k = k.view(-1, self.num_heads, self.head_dim) 117 | v = v.view(-1, self.num_heads, self.head_dim) 118 | 119 | x = flash_attn_varlen_func( 120 | q=q, 121 | k=k, 122 | v=v, 123 | cu_seqlens_q=cu_seqlens, 124 | cu_seqlens_k=cu_seqlens_k, 125 | max_seqlen_q=max_seqlen, 126 | max_seqlen_k=max_seqlen_k, 127 | dropout_p=self.attn_drop.p if self.training else 0.0, 128 | ) 129 | else: 130 | 131 | if incremental_state is not None: 132 | raise NotImplementedError( 133 | "It is designed for batching inference. AR-chunk is not supported currently." 134 | ) 135 | 136 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) 137 | if self.qk_norm: 138 | q, k, v = qkv.unbind(2) 139 | q, k = self.q_norm(q), self.k_norm(k) 140 | # re-bind 141 | qkv = torch.stack((q, k, v), dim=2) 142 | 143 | # pack qkv with seq_len 144 | qkv_collect = [] 145 | for i in range(qkv.shape[0]): 146 | qkv_collect.append(qkv[i, : seq_len[i], :, :, :]) 147 | 148 | qkv = torch.cat(qkv_collect, dim=0) 149 | 150 | x = flash_attn_varlen_qkvpacked_func( 151 | qkv=qkv, 152 | cu_seqlens=cu_seqlens, 153 | max_seqlen=max_seqlen, 154 | dropout_p=self.attn_drop.p if self.training else 0.0, 155 | ) 156 | 157 | # unpack and pad 0 158 | x_collect = [] 159 | for i in range(B): 160 | x_collect.append(x[cu_seqlens[i] : cu_seqlens[i + 1], :, :]) 161 | x = torch.nn.utils.rnn.pad_sequence( 162 | x_collect, batch_first=True, padding_value=0 163 | ) 164 | 165 | else: 166 | q = q * self.scale 167 | attn = q @ k.transpose(-2, -1) 168 | attn = attn.softmax(dim=-1) 169 | attn = self.attn_drop(attn) 170 | x = attn @ v 171 | x = x.transpose(1, 2) 172 | 173 | x = x.reshape(B, N, C) 174 | x = self.proj(x) 175 | x = self.proj_drop(x) 176 | return x 177 | 178 | 179 | def modulate(x, shift, scale): 180 | return x * (1 + scale) + shift 181 | 182 | 183 | class FinalLayer(nn.Module): 184 | """ 185 | The final layer of DiT. 
186 | """ 187 | 188 | def __init__(self, hidden_size, out_channels): 189 | super().__init__() 190 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 191 | self.linear = nn.Linear(hidden_size, out_channels, bias=True) 192 | self.adaLN_modulation = nn.Sequential( 193 | nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) 194 | ) 195 | 196 | def forward(self, x, c): 197 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=2) 198 | x = modulate(self.norm_final(x), shift, scale) 199 | x = self.linear(x) 200 | return x 201 | 202 | 203 | class DiTBlock(nn.Module): 204 | """ 205 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 206 | """ 207 | 208 | def __init__( 209 | self, 210 | hidden_size, 211 | num_heads, 212 | mlp_ratio=4.0, 213 | ffn_type="conv1d_conv1d", 214 | ffn_gated_glu=True, 215 | ffn_act_layer="gelu", 216 | ffn_conv_kernel_size=5, 217 | **block_kwargs, 218 | ): 219 | super().__init__() 220 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 221 | self.attn = Attention( 222 | hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs 223 | ) 224 | 225 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 226 | 227 | if ffn_type == "vanilla_mlp": 228 | from timm.models.vision_transformer import Mlp 229 | 230 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 231 | approx_gelu = lambda: nn.GELU(approximate="tanh") 232 | self.mlp = Mlp( 233 | in_features=hidden_size, 234 | hidden_features=mlp_hidden_dim, 235 | act_layer=approx_gelu, 236 | drop=0, 237 | ) 238 | else: 239 | raise NotImplementedError(f"FFN type {ffn_type} is not implemented") 240 | 241 | self.adaLN_modulation = nn.Sequential( 242 | nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) 243 | ) 244 | 245 | def forward( 246 | self, 247 | x, 248 | c, 249 | seq_len, 250 | cu_seqlens, 251 | cu_maxlen, 252 | cu_seqlens_k, 253 | cu_maxlen_k, 254 | mask, 255 | rotary_pos_emb=None, 256 | incremental_state=None, 257 | nopadding=True, 258 | ): 259 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 260 | self.adaLN_modulation(c).chunk(6, dim=2) 261 | ) 262 | 263 | x_ = modulate(self.norm1(x), shift_msa, scale_msa) 264 | 265 | if incremental_state is not None: 266 | if "attn_kvcache" not in incremental_state: 267 | incremental_state["attn_kvcache"] = {} 268 | inc_attn = incremental_state["attn_kvcache"] 269 | else: 270 | inc_attn = None 271 | 272 | x_ = self.attn( 273 | x_, 274 | seq_len=seq_len, 275 | cu_seqlens=cu_seqlens, 276 | max_seqlen=cu_maxlen, 277 | cu_seqlens_k=cu_seqlens_k, 278 | max_seqlen_k=cu_maxlen_k, 279 | rotary_pos_emb=rotary_pos_emb, 280 | incremental_state=inc_attn, 281 | nopadding=nopadding, 282 | ) 283 | 284 | if not nopadding: 285 | x_ = x_ * mask[:, :, None] 286 | 287 | x = x + gate_msa * x_ 288 | 289 | x_ = modulate(self.norm2(x), shift_mlp, scale_mlp) 290 | 291 | x_ = self.mlp(x_) 292 | 293 | if not nopadding: 294 | x_ = x_ * mask[:, :, None] 295 | 296 | x = x + gate_mlp * x_ 297 | return x 298 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/vocoder/alias_free_activation/cuda/anti_alias_activation_cuda.cu: -------------------------------------------------------------------------------- 1 | /* coding=utf-8 2 | * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include "type_shim.h" 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | 31 | namespace 32 | { 33 | // Hard-coded hyperparameters 34 | // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and 35 | constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4; 36 | constexpr int BUFFER_SIZE = 32; 37 | constexpr int FILTER_SIZE = 12; 38 | constexpr int HALF_FILTER_SIZE = 6; 39 | constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl 40 | constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl 41 | constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl 42 | 43 | template 44 | __global__ void anti_alias_activation_forward( 45 | output_t *dst, 46 | const input_t *src, 47 | const input_t *up_ftr, 48 | const input_t *down_ftr, 49 | const input_t *alpha, 50 | const input_t *beta, 51 | int batch_size, 52 | int channels, 53 | int seq_len) 54 | { 55 | // Up and downsample filters 56 | input_t up_filter[FILTER_SIZE]; 57 | input_t down_filter[FILTER_SIZE]; 58 | 59 | // Load data from global memory including extra indices reserved for replication paddings 60 | input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0}; 61 | input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0}; 62 | 63 | // Output stores downsampled output before writing to dst 64 | output_t output[BUFFER_SIZE]; 65 | 66 | // blockDim/threadIdx = (128, 1, 1) 67 | // gridDim/blockIdx = (seq_blocks, channels, batches) 68 | int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); 69 | int local_offset = threadIdx.x * BUFFER_SIZE; 70 | int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset; 71 | 72 | // intermediate have double the seq_len 73 | int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2; 74 | int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset; 75 | 76 | // Get values needed for replication padding before moving pointer 77 | const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); 78 | input_t seq_left_most_value = right_most_pntr[0]; 79 | input_t seq_right_most_value = right_most_pntr[seq_len - 1]; 80 | 81 | // Move src and dst pointers 82 | src += block_offset + local_offset; 83 | dst += block_offset + local_offset; 84 | 85 | // Alpha and beta values for snake activatons. 
Applies exp by default 86 | alpha = alpha + blockIdx.y; 87 | input_t alpha_val = expf(alpha[0]); 88 | beta = beta + blockIdx.y; 89 | input_t beta_val = expf(beta[0]); 90 | 91 | #pragma unroll 92 | for (int it = 0; it < FILTER_SIZE; it += 1) 93 | { 94 | up_filter[it] = up_ftr[it]; 95 | down_filter[it] = down_ftr[it]; 96 | } 97 | 98 | // Apply replication padding for upsampling, matching torch impl 99 | #pragma unroll 100 | for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1) 101 | { 102 | int element_index = seq_offset + it; // index for element 103 | if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD)) 104 | { 105 | elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value; 106 | } 107 | if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD)) 108 | { 109 | elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value; 110 | } 111 | if ((element_index >= 0) && (element_index < seq_len)) 112 | { 113 | elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it]; 114 | } 115 | } 116 | 117 | // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later 118 | #pragma unroll 119 | for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1) 120 | { 121 | input_t acc = 0.0; 122 | int element_index = intermediate_seq_offset + it; // index for intermediate 123 | #pragma unroll 124 | for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1) 125 | { 126 | if ((element_index + f_idx) >= 0) 127 | { 128 | acc += up_filter[f_idx] * elements[it + f_idx]; 129 | } 130 | } 131 | intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc; 132 | } 133 | 134 | // Apply activation function. 
It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later 135 | double no_div_by_zero = 0.000000001; 136 | #pragma unroll 137 | for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1) 138 | { 139 | intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val); 140 | } 141 | 142 | // Apply replication padding before downsampling conv from intermediates 143 | #pragma unroll 144 | for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1) 145 | { 146 | intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT]; 147 | } 148 | #pragma unroll 149 | for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1) 150 | { 151 | intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1]; 152 | } 153 | 154 | // Apply downsample strided convolution (assuming stride=2) from intermediates 155 | #pragma unroll 156 | for (int it = 0; it < BUFFER_SIZE; it += 1) 157 | { 158 | input_t acc = 0.0; 159 | #pragma unroll 160 | for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1) 161 | { 162 | // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation 163 | acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT]; 164 | } 165 | output[it] = acc; 166 | } 167 | 168 | // Write output to dst 169 | #pragma unroll 170 | for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG) 171 | { 172 | int element_index = seq_offset + it; 173 | if (element_index < seq_len) 174 | { 175 | dst[it] = output[it]; 176 | } 177 | } 178 | 179 | } 180 | 181 | template 182 | void dispatch_anti_alias_activation_forward( 183 | output_t *dst, 184 | const input_t *src, 185 | const input_t *up_ftr, 186 | const input_t *down_ftr, 187 | const input_t *alpha, 188 | const input_t *beta, 189 | int batch_size, 190 | int channels, 191 | int seq_len) 192 | { 193 | if (seq_len == 0) 194 | { 195 | return; 196 | } 197 | else 198 | { 199 | // Use 128 threads per block to maximimize gpu utilization 200 | constexpr int threads_per_block = 128; 201 | constexpr int seq_len_per_block = 4096; 202 | int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block; 203 | dim3 blocks(blocks_per_seq_len, channels, batch_size); 204 | dim3 threads(threads_per_block, 1, 1); 205 | 206 | anti_alias_activation_forward 207 | <<>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len); 208 | } 209 | } 210 | } 211 | 212 | extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta) 213 | { 214 | // Input is a 3d tensor with dimensions [batches, channels, seq_len] 215 | const int batches = input.size(0); 216 | const int channels = input.size(1); 217 | const int seq_len = input.size(2); 218 | 219 | // Output 220 | auto act_options = input.options().requires_grad(false); 221 | 222 | torch::Tensor anti_alias_activation_results = 223 | torch::empty({batches, channels, seq_len}, act_options); 224 | 225 | void *input_ptr = static_cast(input.data_ptr()); 226 | void *up_filter_ptr = static_cast(up_filter.data_ptr()); 227 | 
void *down_filter_ptr = static_cast(down_filter.data_ptr()); 228 | void *alpha_ptr = static_cast(alpha.data_ptr()); 229 | void *beta_ptr = static_cast(beta.data_ptr()); 230 | void *anti_alias_activation_results_ptr = static_cast(anti_alias_activation_results.data_ptr()); 231 | 232 | DISPATCH_FLOAT_HALF_AND_BFLOAT( 233 | input.scalar_type(), 234 | "dispatch anti alias activation_forward", 235 | dispatch_anti_alias_activation_forward( 236 | reinterpret_cast(anti_alias_activation_results_ptr), 237 | reinterpret_cast(input_ptr), 238 | reinterpret_cast(up_filter_ptr), 239 | reinterpret_cast(down_filter_ptr), 240 | reinterpret_cast(alpha_ptr), 241 | reinterpret_cast(beta_ptr), 242 | batches, 243 | channels, 244 | seq_len);); 245 | return anti_alias_activation_results; 246 | } -------------------------------------------------------------------------------- /kimia_infer/api/kimia.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tqdm 4 | import torch 5 | from loguru import logger 6 | from huggingface_hub import cached_assets_path 7 | from transformers import AutoModelForCausalLM 8 | 9 | from kimia_infer.models.detokenizer import get_audio_detokenizer 10 | from .prompt_manager import KimiAPromptManager 11 | from kimia_infer.utils.sampler import KimiASampler 12 | from huggingface_hub import snapshot_download 13 | 14 | class KimiAudio(object): 15 | def __init__(self, model_path: str, load_detokenizer: bool = True): 16 | logger.info(f"Loading kimi-audio main model") 17 | 18 | if os.path.exists(model_path): 19 | # local path 20 | cache_path = model_path 21 | else: 22 | # cache everything if model_path is a model-id 23 | cache_path = snapshot_download(model_path) 24 | 25 | logger.info(f"Looking for resources in {cache_path}") 26 | logger.info(f"Loading whisper model") 27 | self.alm = AutoModelForCausalLM.from_pretrained( 28 | cache_path, torch_dtype=torch.bfloat16, trust_remote_code=True 29 | ) 30 | self.alm = self.alm.to(torch.cuda.current_device()) 31 | 32 | model_config = self.alm.config 33 | self.kimia_text_audiodelaytokens = model_config.kimia_mimo_audiodelaytokens 34 | self.kimia_token_offset = model_config.kimia_token_offset 35 | 36 | self.prompt_manager = KimiAPromptManager( 37 | model_path=cache_path, kimia_token_offset=self.kimia_token_offset, kimia_text_audiodelaytokens=self.kimia_text_audiodelaytokens 38 | ) 39 | 40 | if load_detokenizer: 41 | logger.info(f"Loading detokenizer") 42 | # need to compile extension modules for the first time, it may take several minutes. 
43 | self.detokenizer = get_audio_detokenizer(cache_path) 44 | else: 45 | # in this case, you're not allowed to generate audio(wav) 46 | self.detokenizer = None 47 | 48 | self.extra_tokens = self.prompt_manager.extra_tokens 49 | self.eod_ids = [self.extra_tokens.msg_end, self.extra_tokens.media_end] 50 | 51 | @torch.inference_mode() 52 | def _generate_loop( 53 | self, 54 | audio_input_ids: torch.Tensor, # input audio tokens 55 | text_input_ids: torch.Tensor = None, # input text tokens if use multi-input 56 | max_new_tokens: int = 50, 57 | audio_top_k: int = 5, 58 | audio_temperature: float = 0.0, 59 | audio_repetition_penalty: float = 1.0, 60 | audio_repetition_window_size: int = 64, 61 | text_top_k: int = 5, 62 | text_temperature: float = 0.0, 63 | text_repetition_penalty: float = 1.0, 64 | text_repetition_window_size: int = 16, 65 | is_continuous_mask: torch.Tensor = None, 66 | continous_feature: torch.Tensor = None, 67 | output_type: str = "text", 68 | ): 69 | 70 | sampler = KimiASampler( 71 | audio_top_k=audio_top_k, 72 | audio_temperature=audio_temperature, 73 | audio_repetition_penalty=audio_repetition_penalty, 74 | audio_repetition_window_size=audio_repetition_window_size, 75 | text_top_k=text_top_k, 76 | text_temperature=text_temperature, 77 | text_repetition_penalty=text_repetition_penalty, 78 | text_repetition_window_size=text_repetition_window_size, 79 | ) 80 | 81 | text_stream_is_finished = False 82 | previous_audio_tokens = torch.zeros( 83 | (4096,), 84 | dtype=torch.int, 85 | device=torch.cuda.current_device(), 86 | ) 87 | text_previous_tokens = torch.zeros( 88 | (4096,), 89 | dtype=torch.int, 90 | device=torch.cuda.current_device(), 91 | ) 92 | 93 | decoder_input_audio_ids = audio_input_ids.clone() 94 | decoder_input_text_ids = text_input_ids.clone() 95 | decoder_position_ids = ( 96 | torch.arange( 97 | 0, decoder_input_audio_ids.shape[1], device=torch.cuda.current_device() 98 | ) 99 | .unsqueeze(0) 100 | .long() 101 | ) 102 | decoder_input_whisper_feature = continous_feature 103 | decoder_is_continuous_mask = is_continuous_mask 104 | past_key_values = None 105 | 106 | last_position_id = decoder_input_audio_ids.shape[1] - 1 107 | 108 | valid_text_length = 0 109 | valid_audio_length = 0 110 | 111 | for i in tqdm.tqdm( 112 | range(max_new_tokens), desc="Generating tokens", disable=False 113 | ): 114 | audio_logits, text_logits, past_key_values = self.alm.forward( 115 | input_ids=decoder_input_audio_ids, 116 | text_input_ids=decoder_input_text_ids, 117 | whisper_input_feature=decoder_input_whisper_feature, 118 | is_continuous_mask=decoder_is_continuous_mask, 119 | position_ids=decoder_position_ids, 120 | past_key_values=past_key_values, 121 | return_dict=False, 122 | ) 123 | 124 | # Sample text token using the sampler 125 | next_token_text = sampler.sample_text_logits( 126 | text_logits, recent_tokens=text_previous_tokens[:i] if i > 0 else None 127 | ) 128 | 129 | # Sample audio token using the sampler 130 | next_audio_token = sampler.sample_audio_logits( 131 | audio_logits, recent_tokens=previous_audio_tokens[:i] if i > 0 else None 132 | ) 133 | 134 | if text_stream_is_finished: 135 | next_token_text.fill_(self.extra_tokens.kimia_text_blank) 136 | elif next_token_text.item() == self.extra_tokens.kimia_text_eos: 137 | text_stream_is_finished = True 138 | else: 139 | valid_text_length += 1 140 | 141 | text_previous_tokens[i : i + 1] = next_token_text 142 | 143 | if i < self.kimia_text_audiodelaytokens: 144 | next_audio_token.fill_(self.extra_tokens.kimia_text_blank) 145 | else: 
146 | if output_type == "text": 147 | next_audio_token.fill_(self.extra_tokens.kimia_text_blank) 148 | else: 149 | valid_audio_length += 1 150 | 151 | previous_audio_tokens[i : i + 1] = next_audio_token 152 | 153 | audio_stream_is_finished = next_audio_token.item() in self.eod_ids 154 | 155 | if ( 156 | output_type == "text" 157 | and text_stream_is_finished 158 | or output_type == "both" 159 | and audio_stream_is_finished 160 | ): 161 | return_text_tokens = ( 162 | text_previous_tokens[:valid_text_length] 163 | .detach() 164 | .cpu() 165 | .numpy() 166 | .tolist() 167 | ) 168 | return_audio_tokens = ( 169 | previous_audio_tokens[ 170 | self.kimia_text_audiodelaytokens : valid_audio_length 171 | + self.kimia_text_audiodelaytokens 172 | ] 173 | .detach() 174 | .cpu() 175 | .numpy() 176 | .tolist() 177 | ) 178 | return return_audio_tokens, return_text_tokens 179 | else: 180 | decoder_input_audio_ids = next_audio_token.unsqueeze(1) 181 | decoder_input_text_ids = next_token_text.unsqueeze(1) 182 | 183 | decoder_position_ids = ( 184 | torch.zeros(1, 1, device=torch.cuda.current_device()) 185 | .fill_(last_position_id + 1) 186 | .long() 187 | .view(1, 1) 188 | ) 189 | last_position_id += 1 190 | 191 | decoder_input_whisper_feature = None 192 | decoder_is_continuous_mask = None 193 | 194 | return_text_tokens = ( 195 | text_previous_tokens[:valid_text_length].detach().cpu().numpy().tolist() 196 | ) 197 | return_audio_tokens = ( 198 | previous_audio_tokens[ 199 | self.kimia_text_audiodelaytokens : valid_audio_length 200 | + self.kimia_text_audiodelaytokens 201 | ] 202 | .detach() 203 | .cpu() 204 | .numpy() 205 | .tolist() 206 | ) 207 | return return_audio_tokens, return_text_tokens 208 | 209 | @torch.inference_mode() 210 | def generate( 211 | self, 212 | chats: list[dict], 213 | output_type="text", 214 | audio_temperature=0.0, 215 | audio_top_k=5, 216 | text_temperature=0.0, 217 | text_top_k=5, 218 | audio_repetition_penalty=1.0, 219 | audio_repetition_window_size=64, 220 | text_repetition_penalty=1.0, 221 | text_repetition_window_size=16, 222 | max_new_tokens=-1, 223 | ): 224 | ## TODO: add a check function to verify that the input history format is valid 225 | ## For example, for an ASR task the order must be: text-instruction/audio-instruction + audio-content; as I understand it, the content and the instruction cannot be swapped 226 | ## An assistant turn must be preceded by a user turn, and so on; it would be best to add such a check 227 | 228 | assert output_type in ["text", "both"] 229 | 230 | history = self.prompt_manager.get_prompt(chats, output_type=output_type) 231 | 232 | audio_input_ids, text_input_ids, is_continuous_mask, _, _ = history.to_tensor() 233 | audio_features = history.continuous_feature 234 | 235 | generated_wav_tokens = [] 236 | generated_text_tokens = [] 237 | 238 | if output_type == "both": 239 | max_new_tokens = int(12.5 * 120) - audio_input_ids.shape[1] 240 | else: 241 | if max_new_tokens == -1: 242 | max_new_tokens = 7500 - audio_input_ids.shape[1] 243 | 244 | audio_input_ids = audio_input_ids.to(torch.cuda.current_device()) 245 | text_input_ids = text_input_ids.to(torch.cuda.current_device()) 246 | is_continuous_mask = is_continuous_mask.to(torch.cuda.current_device()) 247 | audio_features = [f.to(torch.cuda.current_device()) for f in audio_features] 248 | 249 | generated_wav_tokens, generated_text_tokens = self._generate_loop( 250 | audio_input_ids=audio_input_ids, 251 | text_input_ids=text_input_ids, 252 | max_new_tokens=max_new_tokens, 253 | audio_temperature=audio_temperature, 254 | audio_top_k=audio_top_k, 255 | audio_repetition_penalty=audio_repetition_penalty, 256 | audio_repetition_window_size=audio_repetition_window_size, 257 | 
text_top_k=text_top_k, 258 | text_temperature=text_temperature, 259 | text_repetition_penalty=text_repetition_penalty, 260 | text_repetition_window_size=text_repetition_window_size, 261 | is_continuous_mask=is_continuous_mask, 262 | continous_feature=audio_features, 263 | output_type=output_type, 264 | ) 265 | 266 | generated_wav_tokens = [ 267 | t for t in generated_wav_tokens if t >= self.kimia_token_offset 268 | ] # filter out the illegal tokens 269 | 270 | generated_wav_tokens = torch.tensor(generated_wav_tokens).unsqueeze(0) 271 | generated_wav_tokens = generated_wav_tokens - self.kimia_token_offset 272 | 273 | generated_text_tokens = [ 274 | t for t in generated_text_tokens if t < self.kimia_token_offset 275 | ] 276 | generated_text = self.detokenize_text(generated_text_tokens) 277 | if self.detokenizer is not None and output_type == "both": 278 | generated_wav = self.detokenize_audio(generated_wav_tokens) 279 | else: 280 | generated_wav = None 281 | 282 | return generated_wav, generated_text 283 | 284 | def detokenize_audio(self, audio_tokens): 285 | if self.detokenizer is None: 286 | raise ValueError("Detokenizer is not initialized") 287 | self.detokenizer.clear_states() 288 | chunk_size = 30 # hard-coded right now 289 | first_chunk_size = 30 290 | cache_speech_collection = [] 291 | audio_tokens = audio_tokens.to(torch.cuda.current_device()) 292 | audio_tokens = audio_tokens.long() 293 | num_audio_tokens = audio_tokens.size(1) 294 | first_chunk_semantic_tokens = audio_tokens[:, :first_chunk_size] 295 | gen_speech = self.detokenizer.detokenize_streaming( 296 | first_chunk_semantic_tokens, 297 | is_final=(num_audio_tokens <= first_chunk_size), 298 | upsample_factor=4, 299 | ) 300 | cache_speech_collection.append(gen_speech) 301 | 302 | if num_audio_tokens > first_chunk_size: 303 | res_semantic_tokens = audio_tokens[:, first_chunk_size:] 304 | for i in range(0, res_semantic_tokens.size(1), chunk_size): 305 | chunk_semantic_tokens = res_semantic_tokens[:, i : i + chunk_size] 306 | gen_speech = self.detokenizer.detokenize_streaming( 307 | chunk_semantic_tokens, 308 | upsample_factor=4, 309 | is_final=(i + chunk_size >= res_semantic_tokens.size(1)), 310 | ) 311 | cache_speech_collection.append(gen_speech) 312 | 313 | gen_speech = torch.cat(cache_speech_collection, dim=-1) 314 | return gen_speech 315 | 316 | def detokenize_text(self, text_tokens): 317 | valid_text_ids = [] 318 | for x in text_tokens: 319 | if x == self.extra_tokens.kimia_text_eos: 320 | break 321 | valid_text_ids.append(x) 322 | return self.prompt_manager.text_tokenizer.decode(valid_text_ids) 323 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/flow_matching/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from .dit_block import DiTBlock, FinalLayer 5 | 6 | 7 | def precompute_freqs_cis( 8 | dim: int, 9 | end: int, 10 | theta: float = 10000.0, 11 | interpolation_factor: int = 1, 12 | max_seq_length: int = 4096, 13 | ): 14 | print( 15 | f"using rope base theta = {theta}, interpolation factor = {interpolation_factor}" 16 | ) 17 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 18 | 19 | # ROPE type-A extention 20 | # we choose to use interpolation rather than extrapolation for better position encoding 21 | # for scale purposes, t should be a float tensor 22 | t = torch.arange(end, device=freqs.device).float() 23 | scale = 1.0 
/ float(interpolation_factor) 24 | t *= scale 25 | 26 | freqs = torch.outer(t, freqs).float() # type: ignore 27 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 28 | 29 | # Sometimes, we don't need so many rope emb as seq_len is smaller than max_pos_emb 30 | # e.g. rope 1M but seqlen 32k, this will cause gpu memory waste 31 | if max_seq_length < end: 32 | freqs_cis = freqs_cis[:max_seq_length,].clone() 33 | return freqs_cis 34 | 35 | 36 | class TimestepEmbedder(nn.Module): 37 | """ 38 | Embeds scalar timesteps into vector representations. 39 | """ 40 | 41 | def __init__(self, hidden_size, frequency_embedding_size=256): 42 | super().__init__() 43 | self.mlp = nn.Sequential( 44 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 45 | nn.SiLU(), 46 | nn.Linear(hidden_size, hidden_size, bias=True), 47 | ) 48 | self.frequency_embedding_size = frequency_embedding_size 49 | 50 | @staticmethod 51 | def timestep_embedding(t, dim, max_period=10000): 52 | """ 53 | Create sinusoidal timestep embeddings. 54 | :param t: a 1-D Tensor of N indices, one per batch element. 55 | These may be fractional. 56 | :param dim: the dimension of the output. 57 | :param max_period: controls the minimum frequency of the embeddings. 58 | :return: an (N, D) Tensor of positional embeddings. 59 | """ 60 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 61 | half = dim // 2 62 | freqs = ( 63 | torch.exp( 64 | -math.log(max_period) 65 | * torch.arange(start=0, end=half, dtype=torch.float32) 66 | / half 67 | ) 68 | .float() 69 | .to(device=t.device) 70 | ) 71 | args = t[:, None].float() * freqs[None] 72 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 73 | if dim % 2: 74 | embedding = torch.cat( 75 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 76 | ) 77 | return embedding 78 | 79 | def forward(self, t): 80 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 81 | t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) 82 | return t_emb 83 | 84 | 85 | class SinusoidalPositionalEmbedding(nn.Module): 86 | """This module produces sinusoidal positional embeddings of any length. 87 | 88 | Padding symbols are ignored. 89 | """ 90 | 91 | def __init__(self, embedding_dim, padding_idx, init_size=1024): 92 | super().__init__() 93 | self.embedding_dim = embedding_dim 94 | self.padding_idx = padding_idx 95 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 96 | init_size, 97 | embedding_dim, 98 | padding_idx, 99 | ) 100 | self.register_buffer("_float_tensor", torch.FloatTensor(1)) 101 | 102 | @staticmethod 103 | def get_embedding(num_embeddings, embedding_dim, padding_idx=None): 104 | """Build sinusoidal embeddings. 105 | 106 | This matches the implementation in tensor2tensor, but differs slightly 107 | from the description in Section 3.5 of "Attention Is All You Need". 
108 | """ 109 | half_dim = embedding_dim // 2 # d/2 110 | emb = math.log(10000) / (half_dim - 1) # 2*log(10000)/(d-2) 111 | emb = torch.exp( 112 | torch.arange(half_dim, dtype=torch.float) * -emb 113 | ) # -2i/(d-2)*log(10000); i from 0 to (d-2)/2; shape: (d/2, ) 114 | emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze( 115 | 1 116 | ) * emb.unsqueeze( 117 | 0 118 | ) # pos/[1000 ** (2i/(d-2))]; shape: (num_embeddings, d/2) 119 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view( 120 | num_embeddings, -1 121 | ) # shape: (num_embeddings, d) 122 | if embedding_dim % 2 == 1: 123 | # zero pad 124 | emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) 125 | if padding_idx is not None: 126 | emb[padding_idx, :] = 0 127 | return emb 128 | 129 | def forward(self, input, incremental_state=None, timestep=None, **kwargs): 130 | """Input is expected to be of size [bsz x seqlen].""" 131 | bsz, seq_len = input.shape[:2] 132 | max_pos = self.padding_idx + 1 + seq_len 133 | if self.weights is None or max_pos > self.weights.size(0): 134 | # recompute/expand embeddings if needed 135 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 136 | max_pos, 137 | self.embedding_dim, 138 | self.padding_idx, 139 | ) 140 | self.weights = self.weights.to(self._float_tensor) 141 | 142 | if incremental_state is not None: 143 | # positions is the same for every token when decoding a single step 144 | pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len 145 | return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) 146 | 147 | positions = self.make_positions(input, self.padding_idx) 148 | return ( 149 | self.weights.index_select(0, positions.view(-1)) 150 | .view(bsz, seq_len, -1) 151 | .detach() 152 | ) # (B, T, dim) 153 | 154 | def max_positions(self): 155 | """Maximum number of supported positions.""" 156 | return int(1e5) # an arbitrary large number 157 | 158 | def make_positions(self, tensor, padding_idx): 159 | """Replace non-padding symbols with their position numbers. 160 | 161 | Position numbers begin at padding_idx+1. Padding symbols are ignored. 162 | """ 163 | # The series of casts and type-conversions here are carefully 164 | # balanced to both work with ONNX export and XLA. In particular XLA 165 | # prefers ints, cumsum defaults to output longs, and ONNX doesn't know 166 | # how to handle the dtype kwarg in cumsum. 167 | mask = tensor.ne(padding_idx).int() 168 | return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx 169 | 170 | 171 | class DiTPrefix(nn.Module): 172 | """ 173 | Diffusion model with a Transformer backbone. 
174 | """ 175 | 176 | def __init__( 177 | self, 178 | input_size, 179 | output_size, 180 | semantic_vocab_size, 181 | hidden_size=1024, 182 | depth=12, 183 | num_heads=4, 184 | # mlp related 185 | mlp_ratio=4.0, 186 | ffn_type="conv1d_conv1d", 187 | ffn_gated_glu=True, 188 | ffn_act_layer="gelu", 189 | ffn_conv_kernel_size=5, 190 | # rope 191 | use_rope=False, 192 | rope_params={ 193 | "max_position_embeddings": 4096, 194 | "rope_base": 10000.0, 195 | "rope_interpolation_factor": 1.0, 196 | }, 197 | position_embedding_type="sincos", 198 | max_seq_len=4096, 199 | prompt_cfg_dropout=0.0, 200 | ): 201 | super().__init__() 202 | self.num_heads = num_heads 203 | 204 | self.prompt_cfg_dropout = prompt_cfg_dropout 205 | 206 | self.t_embedder = TimestepEmbedder(hidden_size) 207 | 208 | self.semantic_token_embedding = nn.Embedding(semantic_vocab_size, hidden_size) 209 | 210 | self.input_linear = nn.Linear(input_size, hidden_size) 211 | 212 | # position embedding 213 | if position_embedding_type == "learnable": 214 | self.position_embedding = nn.Embedding(max_seq_len + 1, hidden_size) 215 | elif position_embedding_type == "sincos": 216 | self.position_embedding = SinusoidalPositionalEmbedding( 217 | hidden_size, 0, max_seq_len + 1 218 | ) 219 | elif position_embedding_type == "skip": 220 | self.position_embedding = None 221 | else: 222 | raise NotImplementedError( 223 | "Position embedding type: {} not implemented.".format( 224 | position_embedding_type 225 | ) 226 | ) 227 | 228 | self.use_rope = use_rope 229 | 230 | if self.use_rope: 231 | 232 | assert ( 233 | hidden_size % num_heads == 0 234 | ), "Hidden size must be divisible by num_heads for rope position embedding." 235 | rope_dim = hidden_size // num_heads 236 | 237 | self.rotary_pos_emb = precompute_freqs_cis( 238 | rope_dim, 239 | rope_params["max_position_embeddings"], 240 | theta=rope_params["rope_base"], 241 | interpolation_factor=rope_params["rope_interpolation_factor"], 242 | max_seq_length=max_seq_len, 243 | ) 244 | 245 | self.blocks = nn.ModuleList( 246 | [ 247 | DiTBlock( 248 | hidden_size, 249 | num_heads, 250 | mlp_ratio=mlp_ratio, 251 | ffn_type=ffn_type, 252 | ffn_conv_kernel_size=ffn_conv_kernel_size, 253 | ffn_gated_glu=ffn_gated_glu, 254 | ffn_act_layer=ffn_act_layer, 255 | ) 256 | for _ in range(depth) 257 | ] 258 | ) 259 | self.final_layer = FinalLayer(hidden_size, output_size) 260 | self.initialize_weights() 261 | 262 | def initialize_weights(self): 263 | # Initialize transformer layers: 264 | def _basic_init(module): 265 | if isinstance(module, nn.Linear): 266 | torch.nn.init.xavier_uniform_(module.weight) 267 | if module.bias is not None: 268 | nn.init.constant_(module.bias, 0) 269 | 270 | self.apply(_basic_init) 271 | 272 | # Initialize timestep embedding MLP: 273 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 274 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 275 | 276 | # Zero-out adaLN modulation layers in DiT blocks: 277 | for block in self.blocks: 278 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 279 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 280 | 281 | # Zero-out output layers: 282 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 283 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 284 | nn.init.constant_(self.final_layer.linear.weight, 0) 285 | nn.init.constant_(self.final_layer.linear.bias, 0) 286 | 287 | def forward( 288 | self, 289 | x, 290 | position_ids, 291 | t, 292 | condition, 293 | seq_len, 294 | cu_seqlens, 295 | 
cu_maxlen, 296 | cu_seqlens_k, 297 | cu_maxlen_k, 298 | mask, 299 | incremental_state=None, 300 | nopadding=True, 301 | ): 302 | """ 303 | Forward pass of DiT. 304 | x: (N, T, C) tensor of inputs (latent representations of speech) 305 | position_ids: (N, T) tensor of positional indices 306 | t: (N,) tensor of diffusion timesteps 307 | condition: (N, T) tensor of semantic tokens 308 | seq_len: (N,) tensor of sequence lengths 309 | """ 310 | 311 | condition = self.semantic_token_embedding(condition) # (N, T, D) 312 | 313 | x = self.input_linear(x) 314 | 315 | if self.position_embedding is not None: 316 | position_emb = self.position_embedding(position_ids) 317 | x = x + position_emb 318 | 319 | # ROPE 320 | if self.use_rope: 321 | bsz, seqlen = position_ids.shape 322 | if self.rotary_pos_emb.device != position_ids.device: 323 | self.rotary_pos_emb = self.rotary_pos_emb.to(position_ids.device) 324 | rotary_pos_emb = torch.zeros( 325 | (bsz, seqlen, self.rotary_pos_emb.shape[1]), 326 | dtype=self.rotary_pos_emb.dtype, 327 | device=self.rotary_pos_emb.device, 328 | ) 329 | for b in range(bsz): 330 | cur_rope = rotary_pos_emb[b] 331 | cur_position_ids = position_ids[b] 332 | cur_rope[:] = self.rotary_pos_emb[cur_position_ids] 333 | else: 334 | rotary_pos_emb = None 335 | 336 | t = self.t_embedder(t) # (N, D) 337 | c = t.unsqueeze(1) + condition # (N, T, D) 338 | 339 | for block_idx, block in enumerate(self.blocks): 340 | # x = block(x, c, attn_mask) # (N, T, D) 341 | # XXX mask could be None because we always use full mask 342 | 343 | if incremental_state is not None: 344 | if block_idx not in incremental_state: 345 | incremental_state[block_idx] = {} 346 | incr = incremental_state[block_idx] 347 | else: 348 | incr = None 349 | 350 | x = block( 351 | x=x, 352 | c=c, 353 | seq_len=seq_len, 354 | cu_seqlens=cu_seqlens, 355 | cu_maxlen=cu_maxlen, 356 | cu_seqlens_k=cu_seqlens_k, 357 | cu_maxlen_k=cu_maxlen_k, 358 | mask=mask, 359 | rotary_pos_emb=rotary_pos_emb, 360 | incremental_state=incr, 361 | nopadding=nopadding, 362 | ) 363 | 364 | x = self.final_layer(x, c) # (N, T, C) 365 | return x 366 | -------------------------------------------------------------------------------- /kimia_infer/models/detokenizer/semantic_fm_prefix_streaming.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import logging 3 | import time 4 | 5 | import os 6 | import torch 7 | 8 | from .flow_matching.ode_wrapper import StreamingODEWrapperForPrefix 9 | from .flow_matching.model import DiTPrefix 10 | from .flow_matching.scheduler import StreamingFlowMatchingScheduler 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class StreamingSemanticFMWrapper: 17 | def __init__( 18 | self, 19 | speech_model: DiTPrefix, 20 | max_kv_cache_tokens=900, 21 | max_prompt_chunk=2, 22 | use_cfg=True, 23 | use_cfg_rescale=True, 24 | cfg_init=1.5, 25 | cfg_scale=7.5, 26 | cfg_schedule="linear", 27 | cfg_token_id=0, 28 | normalize_mel=False, 29 | mel_mean=None, 30 | mel_std=None, 31 | device: torch.device = torch.device("cpu"), 32 | ) -> None: 33 | 34 | self.dtype = torch.bfloat16 35 | self.speech_model = speech_model.to(device).to(self.dtype) 36 | self.speech_model = self.speech_model.eval() 37 | self.device = device 38 | self.normalize_mel = normalize_mel 39 | self.mel_mean = mel_mean 40 | self.mel_std = mel_std 41 | 42 | self.use_cfg = use_cfg 43 | self.use_cfg_rescale = use_cfg_rescale 44 | self.cfg_init = cfg_init 45 | self.cfg_scale = cfg_scale 46 | self.cfg_schedule = 
cfg_schedule 47 | 48 | self.incremental_state = {} 49 | self.condition_cache = {"previous_seqlen": 0} 50 | 51 | logger.info( 52 | f">>> SemanticFMWrapper initialized with use_cfg={use_cfg}, use_cfg_rescale={use_cfg_rescale}, cfg_init={cfg_init}, cfg_scale={cfg_scale}, cfg_schedule={cfg_schedule}" 53 | ) 54 | 55 | self.scheduler = StreamingFlowMatchingScheduler() 56 | self.ode_wrapper = StreamingODEWrapperForPrefix( 57 | net=self.speech_model, 58 | x_mask=None, 59 | x_cond=None, 60 | use_cfg=use_cfg, 61 | use_cfg_rescale=use_cfg_rescale, 62 | cfg_init=cfg_init, 63 | cfg_scale=cfg_scale, 64 | cfg_schedule=cfg_schedule, 65 | cfg_token_id=cfg_token_id, 66 | ) 67 | 68 | self.max_kv_cache_tokens = max_kv_cache_tokens 69 | self.max_prompt_chunk = max_prompt_chunk 70 | self.reserve_kv_cache_tokens = 0 71 | 72 | @torch.inference_mode() 73 | def infer_chunk( 74 | self, 75 | xt_chunk, 76 | semantic_tokens_chunk, 77 | start_position_id, 78 | cache=None, 79 | look_ahead_tokens=0, 80 | ode_steps=15, 81 | verbose=False, 82 | ode_solver="neural_ode_euler", 83 | ): 84 | """ 85 | semantic_tokens: [T_1], torch.LongTensor 86 | xt: [T_2, 80], torch.Tensor, DO NOT normalize it outside 87 | ode_steps: int, number of ode steps, default 15 88 | verbose: bool, default False 89 | ode_solver: str, ode solver, expected in ("neural_ode_euler", "naive_euler"), default "neural_ode_euler" 90 | """ 91 | bs = 1 92 | 93 | self.scheduler.set_timesteps(ode_steps) 94 | 95 | semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device) 96 | xt_chunk = xt_chunk.unsqueeze(0).to(self.device).to(self.dtype) 97 | 98 | t_span = torch.linspace(0, 1, self.scheduler.timesteps) 99 | 100 | x_mask = torch.zeros(bs, xt_chunk.shape[1], device=self.device).bool() 101 | 102 | cache_ret = self.ode_wrapper.set_conditions( 103 | x_mask=x_mask, 104 | x_cond=semantic_tokens_chunk, 105 | start_position_id=start_position_id, 106 | cache=self.condition_cache, 107 | ) 108 | 109 | if verbose: 110 | t_start = time.time() 111 | if ode_solver == "neural_ode_euler": 112 | x_t = self.scheduler.sample_by_neuralode( 113 | self.ode_wrapper, time_steps=t_span, xt=xt_chunk, verbose=False 114 | ) 115 | elif ode_solver == "naive_euler": 116 | x_t = self.scheduler.sample( 117 | ode_wrapper=self.ode_wrapper, 118 | time_steps=t_span, 119 | xt=xt_chunk, 120 | verbose=False, 121 | ) 122 | else: 123 | raise NotImplementedError( 124 | "ode_solver should be in ('neural_ode_euler', 'naive_euler')" 125 | ) 126 | 127 | if look_ahead_tokens > 0: 128 | semantic_tokens_left = semantic_tokens_chunk.view(-1)[-look_ahead_tokens:] 129 | cache["semantic_token"] = semantic_tokens_left 130 | x_t_ret = x_t[:, :-look_ahead_tokens, :] 131 | else: 132 | x_t_ret = x_t 133 | 134 | if look_ahead_tokens > 0: 135 | x_mask = torch.zeros( 136 | bs, xt_chunk.shape[1] - look_ahead_tokens, device=self.device 137 | ).bool() 138 | self.condition_cache = self.ode_wrapper.set_conditions( 139 | x_mask=x_mask, 140 | x_cond=semantic_tokens_chunk[:, :-look_ahead_tokens], 141 | start_position_id=start_position_id, 142 | cache=self.condition_cache, 143 | ) 144 | self.ode_wrapper(torch.Tensor([0.999]).to(x_t_ret.device), x_t_ret) 145 | else: 146 | self.condition_cache = cache_ret 147 | 148 | if verbose: 149 | t_end = time.time() 150 | logger.info(f"[ODE Chunk] Time cost: {t_end - t_start}") 151 | 152 | if self.normalize_mel: 153 | x_t_ret = x_t_ret * self.mel_std + self.mel_mean 154 | return x_t_ret.squeeze(0) 155 | 156 | @torch.inference_mode() 157 | def infer_mel( 158 | self, 159 | 
semantic_tokens, 160 | ode_steps=15, 161 | chunk_size=150, 162 | verbose=False, 163 | ode_solver="neural_ode_euler", 164 | ): 165 | """ 166 | semantic_tokens: [T_1], torch.LongTensor 167 | prompt: [T_2, 80], torch.Tensor, DO NOT normalize it outside 168 | prompt_semantic_tokens, [T_2], torch.LongTensor 169 | ode_steps: int, number of ode steps, default 15 170 | verbose: bool, default False 171 | ode_solver: str, ode solver, expected in ("neural_ode_euler", "naive_euler"), default "neural_ode_euler" 172 | """ 173 | assert semantic_tokens.dim() == 1 174 | 175 | x_t = torch.randn(semantic_tokens.shape[0], 80).to(self.device).to(self.dtype) 176 | 177 | seq_len = semantic_tokens.shape[0] 178 | 179 | num_chunks = seq_len // chunk_size 180 | if seq_len % chunk_size != 0: 181 | num_chunks += 1 182 | 183 | x_pred_collect = [] 184 | 185 | if verbose: 186 | t_start = time.time() 187 | 188 | for chunk_id in range(num_chunks): 189 | start = chunk_id * chunk_size 190 | end = min(start + chunk_size, seq_len) 191 | semantic_tokens_chunk = semantic_tokens[start:end] 192 | x_t_chunk = x_t[start:end, :] 193 | 194 | x_pred = self.infer_chunk( 195 | xt_chunk=x_t_chunk, 196 | semantic_tokens_chunk=semantic_tokens_chunk, 197 | start_position_id=self.start_position_id, 198 | ode_steps=ode_steps, 199 | verbose=verbose, 200 | ode_solver=ode_solver, 201 | ) 202 | self.start_position_id += end - start 203 | self.update_incremental_state() 204 | 205 | x_pred_collect.append(x_pred) 206 | 207 | if verbose: 208 | t_end = time.time() 209 | logger.info(f"[ODE] Time cost: {t_end - t_start}") 210 | 211 | x_pred = torch.cat(x_pred_collect, dim=0) 212 | 213 | return x_pred 214 | 215 | def clear_all_states(self): 216 | self.start_position_id = 0 217 | self.condition_cache = {"previous_seqlen": 0} 218 | self.ode_wrapper.clear_all_states() 219 | 220 | def state_dict(self): 221 | return { 222 | "start_position_id": self.start_position_id, 223 | "ode_wrapper": self.ode_wrapper.state_dict(), 224 | "condition_cache": self.condition_cache, 225 | } 226 | 227 | def load_state_dict(self, state_dict): 228 | if state_dict is not None: 229 | self.start_position_id = state_dict["start_position_id"] 230 | self.ode_wrapper.load_state_dict(state_dict["ode_wrapper"]) 231 | self.condition_cache = state_dict["condition_cache"] 232 | 233 | def update_incremental_state(self): 234 | self.ode_wrapper.update_incremental_state( 235 | reserve_kv_cache_tokens=0, 236 | max_kv_cache_tokens=self.max_kv_cache_tokens, 237 | condition_cache=self.condition_cache, 238 | ) 239 | 240 | @torch.inference_mode() 241 | def prefill(self, mel, semantic_token, chunk_size=150, verbose=False): 242 | """ 243 | mel: [T, 80], torch.Tensor 244 | semantic_token: [T], torch.LongTensor 245 | chunk_size: int, default 150 246 | """ 247 | assert mel.dim() == 2 248 | assert semantic_token.dim() == 1 249 | assert ( 250 | semantic_token.shape[0] == mel.shape[0] 251 | ), "Semantic token and mel shape mismatch" 252 | seq_len = mel.shape[0] 253 | num_chunks = min(seq_len // chunk_size, self.max_prompt_chunk) 254 | start_pos = seq_len - num_chunks * chunk_size 255 | 256 | res_mel = mel[:start_pos, :] 257 | res_semantic_token = semantic_token[:start_pos] 258 | self.prefill_chunk( 259 | res_mel, res_semantic_token, start_position_id=self.start_position_id 260 | ) 261 | self.start_position_id += start_pos 262 | self.update_incremental_state() 263 | self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens 264 | 265 | if verbose: 266 | logger.info("Prefilling prompt with {} 
chunks".format(num_chunks)) 267 | start_time = time.time() 268 | 269 | for chunk_id in range(num_chunks): 270 | start = start_pos + chunk_id * chunk_size 271 | end = start + chunk_size 272 | mel_chunk = mel[start:end, :] 273 | semantic_token_chunk = semantic_token[start:end] 274 | 275 | self.prefill_chunk( 276 | mel_chunk, 277 | semantic_token_chunk, 278 | start_position_id=self.start_position_id, 279 | ) 280 | self.start_position_id += end - start 281 | 282 | self.update_incremental_state() 283 | self.reserve_kv_cache_tokens += self.ode_wrapper.kv_cache_tokens 284 | 285 | if verbose: 286 | logger.info( 287 | "Prefilling done in {:.2f} seconds".format(time.time() - start_time) 288 | ) 289 | 290 | def prefill_chunk(self, mel_chunk, semantic_tokens_chunk, start_position_id=0): 291 | """ 292 | mel_chunk: [T, 80], torch.Tensor, T is the chunk size 293 | semantic_tokens_chunk: [T], torch.LongTensor 294 | start_position_id: int, default 0 295 | """ 296 | bs = 1 297 | 298 | semantic_tokens_chunk = semantic_tokens_chunk.unsqueeze(0).to(self.device) 299 | mel_chunk = mel_chunk.unsqueeze(0).to(self.device).to(self.dtype) 300 | 301 | if self.normalize_mel: 302 | mel_chunk = (mel_chunk - self.mel_mean) / self.mel_std 303 | 304 | x_mask = torch.zeros(bs, mel_chunk.shape[1], device=self.device).bool() 305 | 306 | self.condition_cache = self.ode_wrapper.set_conditions( 307 | x_mask=x_mask, 308 | x_cond=semantic_tokens_chunk, 309 | start_position_id=start_position_id, 310 | cache=self.condition_cache, 311 | ) 312 | 313 | x_t = torch.Tensor([0.999]).to(self.device) 314 | 315 | self.ode_wrapper(x_t, mel_chunk) 316 | 317 | @classmethod 318 | def from_pretrained( 319 | cls, 320 | model_config, 321 | ckpt_path, 322 | device, 323 | max_prompt_chunk=2, 324 | max_kv_cache_tokens=900, 325 | use_cfg=True, 326 | use_cfg_rescale=True, 327 | cfg_init=1.5, 328 | cfg_scale=7.5, 329 | cfg_schedule="linear", 330 | ): 331 | 332 | # open yaml file 333 | with open(model_config, "r") as f: 334 | config = yaml.safe_load(f) 335 | model_config = config["model"]["dit"] 336 | dit = DiTPrefix( 337 | input_size=model_config["input_size"], 338 | semantic_vocab_size=model_config["semantic_vocab_size"] + 1, 339 | hidden_size=model_config["hidden_size"], 340 | depth=model_config["depth"], 341 | num_heads=model_config["num_heads"], 342 | mlp_ratio=model_config["mlp_ratio"], 343 | ffn_type=model_config.get("ffn_type", "conv1d_conv1d"), 344 | ffn_gated_glu=model_config.get("ffn_gated_glu", True), 345 | ffn_act_layer=model_config.get("ffn_act_layer", "gelu"), 346 | ffn_conv_kernel_size=model_config.get("ffn_conv_kernel_size", 5), 347 | use_rope=model_config.get("use_rope", False), 348 | rope_params=model_config.get( 349 | "rope_params", 350 | { 351 | "max_position_embeddings": 4096, 352 | "rope_base": 10000, 353 | "rope_interpolation_factor": 1, 354 | }, 355 | ), 356 | position_embedding_type=model_config["position_embedding_type"], 357 | max_seq_len=model_config["max_seq_len"], 358 | output_size=model_config["input_size"], 359 | prompt_cfg_dropout=0, 360 | ) 361 | cfg_semantic_token_id = model_config["semantic_vocab_size"] 362 | 363 | # load state_dict 364 | state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)[ 365 | "state_dict" 366 | ] 367 | speech_model_params = { 368 | k.replace("speech_model.", ""): v 369 | for k, v in state_dict.items() 370 | if "speech_model" in k 371 | } 372 | dit.load_state_dict(speech_model_params, strict=True) 373 | logger.info(f">>> Loaded checkpoint from {ckpt_path}") 374 | 375 | return 
cls( 376 | speech_model=dit, 377 | device=device, 378 | normalize_mel=config["normalize_mel"], 379 | mel_mean=config["mel_mean"], 380 | mel_std=config["mel_std"], 381 | max_prompt_chunk=max_prompt_chunk, 382 | max_kv_cache_tokens=max_kv_cache_tokens, 383 | use_cfg=use_cfg, 384 | use_cfg_rescale=use_cfg_rescale, 385 | cfg_init=cfg_init, 386 | cfg_scale=cfg_scale, 387 | cfg_schedule=cfg_schedule, 388 | cfg_token_id=cfg_semantic_token_id, 389 | ) 390 | --------------------------------------------------------------------------------
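A minimal shape check for the sinusoidal table built by SinusoidalPositionalEmbedding.get_embedding above (a sketch, assuming the class is importable from kimia_infer.models.detokenizer.flow_matching.model; the sizes 16 and 8 are illustrative, not values used by the model):

import torch

from kimia_infer.models.detokenizer.flow_matching.model import SinusoidalPositionalEmbedding

# Build a table for 16 positions of dimension 8, reserving index 0 for padding
# (positional arguments: num_embeddings, embedding_dim, padding_idx).
table = SinusoidalPositionalEmbedding.get_embedding(16, 8, 0)

assert table.shape == (16, 8)    # (num_embeddings, d): sin half then cos half
assert torch.all(table[0] == 0)  # padding row is zeroed out
# Lowest-frequency pair at position 1: the first column of the sin half is sin(1).
assert torch.isclose(table[1, 0], torch.sin(torch.tensor(1.0)))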
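And a minimal end-to-end sketch of the streaming detokenizer wrapper defined in semantic_fm_prefix_streaming.py (the config/checkpoint paths, prompt length, and the 16384-entry semantic vocabulary below are illustrative assumptions, not values shipped with the repository):

import torch

from kimia_infer.models.detokenizer.semantic_fm_prefix_streaming import (
    StreamingSemanticFMWrapper,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical paths: a YAML config with a model.dit section plus mel statistics,
# and a checkpoint whose state_dict carries the speech_model.* weights.
wrapper = StreamingSemanticFMWrapper.from_pretrained(
    model_config="detokenizer_config.yaml",
    ckpt_path="detokenizer.ckpt",
    device=device,
)

# Reset start_position_id, the condition cache and the ODE wrapper's KV cache
# before starting a new utterance.
wrapper.clear_all_states()

# Optional prompt: aligned mel frames [T, 80] and semantic tokens [T] prime the
# KV cache; prefill keeps at most max_prompt_chunk trailing chunks of chunk_size.
prompt_mel = torch.randn(320, 80)                # dummy prompt, illustrative only
prompt_tokens = torch.randint(0, 16384, (320,))  # vocabulary size is an assumption
wrapper.prefill(prompt_mel, prompt_tokens, chunk_size=150)

# Generate mel frames for new semantic tokens; infer_mel calls infer_chunk per chunk
# and advances start_position_id / the incremental state internally.
new_tokens = torch.randint(0, 16384, (450,))
mel = wrapper.infer_mel(new_tokens, ode_steps=15, chunk_size=150)
print(mel.shape)  # expected: (450, 80), one 80-dim mel frame per semantic token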