├── .gitattributes ├── .gitignore ├── Dockerfile ├── Dockerfile-vllm ├── LICENSE ├── README.md ├── README_CN.md ├── README_JP.md ├── __init__.py ├── app.py ├── assets ├── Step-Audio.pdf ├── architecture.png ├── logo.png ├── pipeline.png ├── rlhf.png ├── stepeval_radar_chart.png └── yuewen.jpeg ├── call_vllm_chat.py ├── cosyvoice ├── __init__.py ├── cli │ ├── __init__.py │ ├── cosyvoice.py │ ├── frontend.py │ └── model.py ├── flow │ ├── decoder.py │ ├── flow.py │ ├── flow_matching.py │ └── length_regulator.py ├── hifigan │ ├── f0_predictor.py │ └── generator.py ├── matcha │ ├── audio.py │ ├── decoder.py │ ├── flow_matching.py │ └── transformer.py ├── transformer │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── convolution.py │ ├── decoder.py │ ├── decoder_layer.py │ ├── embedding.py │ ├── encoder.py │ ├── encoder_layer.py │ ├── label_smoothing_loss.py │ ├── positionwise_feed_forward.py │ └── subsampling.py └── utils │ ├── __init__.py │ ├── audio.py │ ├── class_utils.py │ ├── common.py │ ├── executor.py │ ├── file_utils.py │ ├── frontend_utils.py │ ├── mask.py │ ├── scheduler.py │ └── train_utils.py ├── examples ├── clone_wav_lixueqin.wav ├── clone_wav_yuqian.wav ├── emotional_control1.wav ├── emotional_control2.wav ├── multilingual1.wav ├── multilingual2.wav ├── multilingual_singing.wav ├── prompt_wav_lixueqin.wav ├── prompt_wav_yuqian.wav ├── prompt_wav_zhaobenshan.wav ├── rap.wav ├── singing.wav ├── speed_control1.wav ├── speed_control2.wav └── tone_control.wav ├── funasr_detach ├── __init__.py ├── auto │ ├── __init__.py │ ├── auto_frontend.py │ ├── auto_model.py │ └── auto_tokenizer.py ├── bin │ ├── __init__.py │ ├── compute_audio_cmvn.py │ ├── inference.py │ ├── tokenize_text.py │ └── train.py ├── datasets │ ├── __init__.py │ └── audio_datasets │ │ ├── __init__.py │ │ ├── datasets.py │ │ ├── index_ds.py │ │ ├── preprocessor.py │ │ ├── samplers.py │ │ └── scp2jsonl.py ├── download │ ├── __init__.py │ ├── download_dataset_from_hub.py │ ├── download_from_hub.py │ ├── file.py │ ├── name_maps_from_hub.py │ └── runtime_sdk_download_tool.py ├── frontends │ ├── __init__.py │ ├── default.py │ ├── eend_ola_feature.py │ ├── fused.py │ ├── s3prl.py │ ├── utils │ │ ├── __init__.py │ │ ├── beamformer.py │ │ ├── complex_utils.py │ │ ├── dnn_beamformer.py │ │ ├── dnn_wpe.py │ │ ├── feature_transform.py │ │ ├── frontend.py │ │ ├── log_mel.py │ │ ├── mask_estimator.py │ │ └── stft.py │ ├── wav_frontend.py │ └── windowing.py ├── losses │ ├── __init__.py │ └── label_smoothing_loss.py ├── metrics │ ├── __init__.py │ ├── common.py │ ├── compute_acc.py │ ├── compute_eer.py │ ├── compute_min_dcf.py │ └── compute_wer.py ├── models │ ├── __init__.py │ ├── bat │ │ ├── __init__.py │ │ └── model.py │ ├── bicif_paraformer │ │ ├── __init__.py │ │ ├── cif_predictor.py │ │ ├── model.py │ │ └── template.yaml │ ├── branchformer │ │ ├── __init__.py │ │ ├── cgmlp.py │ │ ├── encoder.py │ │ ├── fastformer.py │ │ ├── model.py │ │ └── template.yaml │ ├── campplus │ │ ├── __init__.py │ │ ├── cluster_backend.py │ │ ├── components.py │ │ ├── model.py │ │ ├── template.yaml │ │ └── utils.py │ ├── conformer │ │ ├── __init__.py │ │ ├── encoder.py │ │ ├── model.py │ │ └── template.yaml │ ├── contextual_paraformer │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── model.py │ │ └── template.yaml │ ├── ct_transformer │ │ ├── __init__.py │ │ ├── model.py │ │ ├── template.yaml │ │ └── utils.py │ ├── ct_transformer_streaming │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── encoder.py │ │ ├── model.py │ │ └── template.yaml │ ├── 
ctc │ │ ├── __init__.py │ │ └── ctc.py │ ├── data2vec │ │ ├── __init__.py │ │ ├── data2vec.py │ │ ├── data2vec_encoder.py │ │ ├── data_utils.py │ │ ├── ema_module.py │ │ ├── grad_multiply.py │ │ ├── multihead_attention.py │ │ ├── quant_noise.py │ │ ├── utils.py │ │ └── wav2vec2.py │ ├── e_branchformer │ │ ├── __init__.py │ │ ├── encoder.py │ │ ├── model.py │ │ └── template.yaml │ ├── eend │ │ ├── __init__.py │ │ ├── e2e_diar_eend_ola.py │ │ ├── eend_ola_dataloader.py │ │ ├── encoder.py │ │ ├── encoder_decoder_attractor.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── feature.py │ │ │ ├── kaldi_data.py │ │ │ ├── losses.py │ │ │ ├── power.py │ │ │ └── report.py │ ├── emotion2vec │ │ ├── __init__.py │ │ ├── audio.py │ │ ├── base.py │ │ ├── fairseq_modules.py │ │ ├── model.py │ │ ├── modules.py │ │ ├── template.yaml │ │ └── timm_modules.py │ ├── eres2net │ │ ├── __init__.py │ │ ├── eres2net.py │ │ ├── eres2net_aug.py │ │ └── fusion.py │ ├── fsmn_vad_streaming │ │ ├── __init__.py │ │ ├── encoder.py │ │ ├── model.py │ │ └── template.yaml │ ├── language_model │ │ ├── __init__.py │ │ ├── rnn │ │ │ ├── __init__.py │ │ │ ├── argument.py │ │ │ ├── attentions.py │ │ │ ├── decoders.py │ │ │ └── encoders.py │ │ ├── seq_rnn_lm.py │ │ ├── transformer_encoder.py │ │ └── transformer_lm.py │ ├── lora │ │ ├── __init__.py │ │ ├── layers.py │ │ └── utils.py │ ├── mfcca │ │ ├── __init__.py │ │ ├── e2e_asr_mfcca.py │ │ ├── encoder_layer_mfcca.py │ │ └── mfcca_encoder.py │ ├── model_hf │ │ └── __init__.py │ ├── monotonic_aligner │ │ ├── __init__.py │ │ ├── model.py │ │ └── template.yaml │ ├── mossformer │ │ ├── __init__.py │ │ ├── e2e_ss.py │ │ ├── mossformer.py │ │ ├── mossformer_decoder.py │ │ └── mossformer_encoder.py │ ├── normalize │ │ ├── __init__.py │ │ ├── global_mvn.py │ │ └── utterance_mvn.py │ ├── paraformer │ │ ├── __init__.py │ │ ├── cif_predictor.py │ │ ├── decoder.py │ │ ├── model.py │ │ ├── search.py │ │ └── template.yaml │ ├── paraformer_streaming │ │ ├── __init__.py │ │ ├── model.py │ │ └── template.yaml │ ├── rwkv_bat │ │ ├── __init__.py │ │ ├── cuda_decoder │ │ │ ├── wkv_cuda.cu │ │ │ └── wkv_op.cpp │ │ ├── cuda_encoder │ │ │ ├── wkv_cuda.cu │ │ │ └── wkv_op.cpp │ │ ├── rwkv.py │ │ ├── rwkv_attention.py │ │ ├── rwkv_encoder.py │ │ ├── rwkv_feed_forward.py │ │ └── rwkv_subsampling.py │ ├── sa_asr │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── beam_search_sa_asr.py │ │ ├── e2e_sa_asr.py │ │ └── transformer_decoder.py │ ├── sanm │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── model.py │ │ ├── positionwise_feed_forward.py │ │ └── template.yaml │ ├── scama │ │ ├── __init__.py │ │ ├── beam_search.py │ │ ├── chunk_utilis.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── model.py │ │ ├── template.yaml │ │ └── utils.py │ ├── seaco_paraformer │ │ ├── __init__.py │ │ ├── model.py │ │ └── template.yaml │ ├── sond │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── e2e_diar_sond.py │ │ ├── encoder │ │ │ ├── __init__.py │ │ │ ├── ci_scorers.py │ │ │ ├── conv_encoder.py │ │ │ ├── ecapa_tdnn_encoder.py │ │ │ ├── fsmn_encoder.py │ │ │ ├── resnet34_encoder.py │ │ │ └── self_attention_encoder.py │ │ ├── label_aggregation.py │ │ ├── pooling │ │ │ ├── __init__.py │ │ │ ├── pooling_layers.py │ │ │ └── statistic_pooling.py │ │ └── sv_decoder.py │ ├── specaug │ │ ├── __init__.py │ │ ├── mask_along_axis.py │ │ ├── profileaug.py │ │ ├── specaug.py │ │ └── time_warp.py │ ├── transducer │ │ ├── __init__.py │ │ ├── beam_search_transducer.py │ │ ├── joint_network.py │ │ ├── model.py │ │ 
├── rnn_decoder.py │ │ └── rnnt_decoder.py │ ├── transformer │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── decoder.py │ │ ├── embedding.py │ │ ├── encoder.py │ │ ├── layer_norm.py │ │ ├── model.py │ │ ├── positionwise_feed_forward.py │ │ ├── scorers │ │ │ ├── __init__.py │ │ │ ├── ctc.py │ │ │ ├── ctc_prefix_score.py │ │ │ ├── length_bonus.py │ │ │ └── scorer_interface.py │ │ ├── search.py │ │ ├── template.yaml │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── add_sos_eos.py │ │ │ ├── dynamic_conv.py │ │ │ ├── dynamic_conv2d.py │ │ │ ├── lightconv.py │ │ │ ├── lightconv2d.py │ │ │ ├── mask.py │ │ │ ├── multi_layer_conv.py │ │ │ ├── nets_utils.py │ │ │ ├── repeat.py │ │ │ ├── subsampling.py │ │ │ ├── subsampling_without_posenc.py │ │ │ └── vgg2l.py │ ├── uniasr │ │ ├── __init__.py │ │ ├── beam_search.py │ │ ├── model.py │ │ └── template.yaml │ ├── whisper │ │ ├── __init__.py │ │ ├── model.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── assets │ │ │ ├── gpt2 │ │ │ │ ├── merges.txt │ │ │ │ ├── special_tokens_map.json │ │ │ │ ├── tokenizer_config.json │ │ │ │ └── vocab.json │ │ │ ├── mel_filters.npz │ │ │ └── multilingual │ │ │ │ ├── added_tokens.json │ │ │ │ ├── merges.txt │ │ │ │ ├── special_tokens_map.json │ │ │ │ ├── tokenizer_config.json │ │ │ │ └── vocab.json │ │ │ ├── audio.py │ │ │ ├── decoding.py │ │ │ ├── tokenizer.py │ │ │ ├── transcribe.py │ │ │ └── utils.py │ └── xvector │ │ ├── __init__.py │ │ └── e2e_sv.py ├── optimizers │ ├── __init__.py │ ├── fairseq_adam.py │ └── sgd.py ├── register.py ├── schedulers │ ├── __init__.py │ ├── abs_scheduler.py │ ├── noam_lr.py │ ├── tri_stage_scheduler.py │ └── warmup_lr.py ├── tokenizer │ ├── __init__.py │ ├── abs_tokenizer.py │ ├── build_tokenizer.py │ ├── char_tokenizer.py │ ├── cleaner.py │ ├── korean_cleaner.py │ ├── phoneme_tokenizer.py │ ├── sentencepiece_tokenizer.py │ ├── token_id_converter.py │ └── word_tokenizer.py ├── train_utils │ ├── __init__.py │ ├── add_gradient_noise.py │ ├── average_nbest_models.py │ ├── device_funcs.py │ ├── forward_adaptor.py │ ├── initialize.py │ ├── load_pretrained_model.py │ ├── model_summary.py │ ├── recursive_op.py │ ├── set_all_random_seed.py │ └── trainer.py ├── utils │ ├── __init__.py │ ├── datadir_writer.py │ ├── load_utils.py │ ├── misc.py │ ├── postprocess_utils.py │ ├── prepare_data.py │ ├── speaker_utils.py │ ├── timestamp_tools.py │ ├── types.py │ └── vad_utils.py └── version.txt ├── offline_inference.py ├── requirements-vllm.txt ├── requirements.txt ├── speakers ├── TingtingRAP_prompt.wav ├── Tingting_prompt.wav ├── Tingting哼唱_prompt.wav └── speakers_info.json ├── stepaudio.py ├── tokenizer.py ├── tts.py ├── tts_app.py ├── tts_inference.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | examples filter=lfs diff=lfs merge=lfs -text 2 | speakers/nezha_prompt.wav filter=lfs diff=lfs merge=lfs -text 3 | speakers/nezhaRAP_prompt.wav filter=lfs diff=lfs merge=lfs -text 4 | speakers/nezha哼唱_prompt.wav filter=lfs diff=lfs merge=lfs -text 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | output/ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04 2 | 3 | ENV TZ=Asia/Shanghai 4 | RUN ln -snf /usr/share/zoneinfo/$TZ 
/etc/localtime \ 5 | && echo $TZ > /etc/timezone 6 | 7 | RUN apt-get update \ 8 | && apt-get install -y build-essential \ 9 | && apt-get install -y wget \ 10 | && apt-get install -y software-properties-common curl zip unzip git-lfs awscli libssl-dev openssh-server vim \ 11 | && apt-get install -y net-tools iputils-ping iproute2 \ 12 | && apt-get clean \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | RUN apt-get install --reinstall ca-certificates && update-ca-certificates \ 16 | && apt-get clean \ 17 | && rm -rf /var/lib/apt/lists/* 18 | 19 | RUN add-apt-repository -y 'ppa:deadsnakes/ppa' && apt update 20 | RUN apt install python3.10 python3.10-dev python3.10-distutils python3.10-venv -y \ 21 | && apt-get clean \ 22 | && rm -rf /var/lib/apt/lists/* 23 | 24 | RUN wget -qO- https://bootstrap.pypa.io/get-pip.py | python3.10 25 | RUN ln -s /usr/bin/python3.10 /usr/bin/python 26 | RUN pip uninstall -y Pillow && pip install pillow 27 | 28 | COPY requirements.txt /tmp/requirements.txt 29 | RUN pip3 install -r /tmp/requirements.txt 30 | RUN pip3 install onnxruntime-gpu==1.17.0 --index-url=https://pkgs.dev.azure.com/onnxruntime/onnxruntime/_packaging/onnxruntime-cuda-12/pypi/simple --force-reinstall --no-deps 31 | -------------------------------------------------------------------------------- /Dockerfile-vllm: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04 2 | 3 | ENV TZ=Asia/Shanghai 4 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime \ 5 | && echo $TZ > /etc/timezone 6 | 7 | RUN apt-get update \ 8 | && apt-get install -y build-essential \ 9 | && apt-get install -y wget \ 10 | && apt-get install -y software-properties-common curl zip unzip git-lfs awscli libssl-dev openssh-server vim \ 11 | && apt-get install -y net-tools iputils-ping iproute2 \ 12 | && apt-get clean \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | RUN apt-get install --reinstall ca-certificates && update-ca-certificates \ 16 | && apt-get clean \ 17 | && rm -rf /var/lib/apt/lists/* 18 | 19 | RUN add-apt-repository -y 'ppa:deadsnakes/ppa' && apt update 20 | RUN apt install python3.10 python3.10-dev python3.10-distutils python3.10-venv -y \ 21 | && apt-get clean \ 22 | && rm -rf /var/lib/apt/lists/* 23 | 24 | RUN wget -qO- https://bootstrap.pypa.io/get-pip.py | python3.10 25 | RUN ln -s /usr/bin/python3.10 /usr/bin/python 26 | RUN pip uninstall -y Pillow && pip install pillow 27 | 28 | COPY requirements-vllm.txt /tmp/requirements.txt 29 | RUN pip3 install -r /tmp/requirements.txt 30 | # update vllm 31 | RUN VLLM_PYTHON_DIR=$(pip3 show vllm | grep Location | awk '{print $2}')/vllm \ 32 | && git clone -b add-step1-model https://github.com/stepfun-ai/vllm.git /tmp/vllm \ 33 | && cd /tmp/vllm/vllm \ 34 | && find . 
-name '*.py' -exec cp -v --parents -t $VLLM_PYTHON_DIR {} + \ 35 | && rm -rf /tmp/vllm -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/__init__.py -------------------------------------------------------------------------------- /assets/Step-Audio.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/assets/Step-Audio.pdf -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/assets/architecture.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/assets/logo.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/assets/pipeline.png -------------------------------------------------------------------------------- /assets/rlhf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/assets/rlhf.png -------------------------------------------------------------------------------- /assets/stepeval_radar_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/assets/stepeval_radar_chart.png -------------------------------------------------------------------------------- /assets/yuewen.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/assets/yuewen.jpeg -------------------------------------------------------------------------------- /call_vllm_chat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from openai import OpenAI 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument( 8 | "--server-url", 9 | type=str, 10 | default="http://127.0.0.1:8000", 11 | help="VLLM server url.", 12 | ) 13 | parser.add_argument( 14 | "--model-name", type=str, default="step-audio-chat", help="Model name." 
15 | ) 16 | args = parser.parse_args() 17 | 18 | server_url = args.server_url + "/v1" # for chat route 19 | client = OpenAI(base_url=server_url, api_key="whatever") 20 | 21 | messages = [ 22 | { 23 | "role": "system", 24 | "content": "You are an AI designed for conversation, currently unable to connect to the internet.", 25 | }, 26 | {"role": "user", "content": "Introduce yourself."}, 27 | ] 28 | completion = client.chat.completions.create( 29 | model=args.model_name, 30 | messages=messages, 31 | ) 32 | res = completion.choices[0].message.content 33 | print(res) 34 | -------------------------------------------------------------------------------- /cosyvoice/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/cosyvoice/__init__.py -------------------------------------------------------------------------------- /cosyvoice/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/cosyvoice/cli/__init__.py -------------------------------------------------------------------------------- /cosyvoice/cli/cosyvoice.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
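# A brief, hypothetical usage sketch for the CosyVoice wrapper defined below. The class loads a
# checkpoint directory (cosyvoice.yaml, campplus.onnx, speech_tokenizer_v1.onnx, flow.pt, hift.pt)
# and exposes token_to_wav_offline(), where the "flow" model maps discrete speech tokens to a mel
# spectrogram and the "hift" vocoder (a HiFiGAN-style generator, see cosyvoice/hifigan/) renders
# that mel into a waveform returned as a CPU tensor. The directory path and the input tensors in
# this sketch are placeholders for illustration, not values shipped with this repo:
#
#     from cosyvoice.cli.cosyvoice import CosyVoice
#
#     vocoder = CosyVoice("/path/to/cosyvoice_model_dir")
#     wav = vocoder.token_to_wav_offline(
#         speech_token, speech_feat, speech_feat_len,      # generated tokens + prompt mel features
#         prompt_token, prompt_token_len, embedding,       # prompt tokens + speaker embedding
#     )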
14 | import os 15 | import uuid 16 | import time 17 | from tqdm import tqdm 18 | import torch 19 | import torchaudio 20 | from hyperpyyaml import load_hyperpyyaml 21 | from cosyvoice.cli.frontend import CosyVoiceFrontEnd 22 | from cosyvoice.cli.model import CosyVoiceModel 23 | 24 | 25 | class CosyVoice: 26 | 27 | def __init__( 28 | self, 29 | model_dir, 30 | ): 31 | self.model_dir = model_dir 32 | with open("{}/cosyvoice.yaml".format(model_dir), "r") as f: 33 | configs = load_hyperpyyaml(f) 34 | self.frontend = CosyVoiceFrontEnd( 35 | configs["feat_extractor"], 36 | "{}/campplus.onnx".format(model_dir), 37 | "{}/speech_tokenizer_v1.onnx".format(model_dir), 38 | ) 39 | self.model = CosyVoiceModel(configs["flow"], configs["hift"]) 40 | self.model.load( 41 | "{}/flow.pt".format(model_dir), 42 | "{}/hift.pt".format(model_dir), 43 | ) 44 | self.model.flow = self.model.flow.to(torch.bfloat16) 45 | del configs 46 | 47 | def token_to_wav_offline( 48 | self, 49 | speech_token, 50 | speech_feat, 51 | speech_feat_len, 52 | prompt_token, 53 | prompt_token_len, 54 | embedding, 55 | ): 56 | tts_mel = self.model.flow.inference( 57 | token=speech_token.to(self.model.device), 58 | token_len=torch.tensor([speech_token.size(1)], dtype=torch.int32).to( 59 | self.model.device 60 | ), 61 | prompt_token=prompt_token.to(self.model.device), 62 | prompt_token_len=prompt_token_len.to(self.model.device), 63 | prompt_feat=speech_feat.to(self.model.device), 64 | prompt_feat_len=speech_feat_len.to(self.model.device), 65 | embedding=embedding.to(self.model.device), 66 | ) 67 | tts_speech = self.model.hift.inference(mel=tts_mel.float())[0].cpu() 68 | return tts_speech 69 | -------------------------------------------------------------------------------- /cosyvoice/cli/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | 16 | 17 | class CosyVoiceModel: 18 | 19 | def __init__( 20 | self, 21 | flow: torch.nn.Module, 22 | hift: torch.nn.Module, 23 | ): 24 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | self.flow = flow 26 | self.hift = hift 27 | 28 | def load(self, flow_model, hift_model): 29 | self.flow.load_state_dict(torch.load(flow_model, map_location=self.device)) 30 | self.flow.to(self.device).eval() 31 | self.hift.load_state_dict(torch.load(hift_model, map_location=self.device)) 32 | self.hift.to(self.device).eval() 33 | -------------------------------------------------------------------------------- /cosyvoice/flow/length_regulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Tuple 15 | import torch.nn as nn 16 | import torch 17 | from torch.nn import functional as F 18 | from cosyvoice.utils.mask import make_pad_mask 19 | 20 | 21 | class InterpolateRegulator(nn.Module): 22 | def __init__( 23 | self, 24 | channels: int, 25 | sampling_ratios: Tuple, 26 | out_channels: int = None, 27 | groups: int = 1, 28 | ): 29 | super().__init__() 30 | self.sampling_ratios = sampling_ratios 31 | out_channels = out_channels or channels 32 | model = nn.ModuleList([]) 33 | if len(sampling_ratios) > 0: 34 | for _ in sampling_ratios: 35 | module = nn.Conv1d(channels, channels, 3, 1, 1) 36 | norm = nn.GroupNorm(groups, channels) 37 | act = nn.Mish() 38 | model.extend([module, norm, act]) 39 | model.append(nn.Conv1d(channels, out_channels, 1, 1)) 40 | self.model = nn.Sequential(*model) 41 | 42 | def forward(self, x, ylens=None): 43 | # x in (B, T, D) 44 | mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) 45 | x = F.interpolate( 46 | x.transpose(1, 2).contiguous(), size=ylens.max(), mode="linear" 47 | ) 48 | out = self.model(x).transpose(1, 2).contiguous() 49 | olens = ylens 50 | return out * mask, olens 51 | 52 | def inference(self, x1, x2, mel_len1, mel_len2): 53 | # x in (B, T, D) 54 | x2 = F.interpolate( 55 | x2.transpose(1, 2).contiguous(), size=mel_len2, mode="linear" 56 | ) 57 | if x1.shape[1] != 0: 58 | x1 = F.interpolate( 59 | x1.transpose(1, 2).contiguous(), size=mel_len1, mode="linear" 60 | ) 61 | x = torch.concat([x1, x2], dim=2) 62 | else: 63 | x = x2 64 | out = self.model(x).transpose(1, 2).contiguous() 65 | return out, mel_len1 + mel_len2 66 | -------------------------------------------------------------------------------- /cosyvoice/hifigan/f0_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
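# A small, self-contained sanity check for the ConvRNNF0Predictor defined below. The stack of
# padded Conv1d layers keeps the frame axis unchanged, so an 80-bin mel input of shape
# (batch, mel_bins, frames) yields a non-negative F0-like contour of shape (batch, frames).
# The batch and frame sizes here are arbitrary illustration values.

import torch
from cosyvoice.hifigan.f0_predictor import ConvRNNF0Predictor

predictor = ConvRNNF0Predictor(in_channels=80, cond_channels=512)
mel = torch.randn(2, 80, 100)    # (batch, mel_bins, frames)
f0 = predictor(mel)
assert f0.shape == (2, 100)      # frame count preserved by the padded convolutions
assert torch.all(f0 >= 0)        # torch.abs() in forward() guarantees non-negative output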
14 | import torch 15 | import torch.nn as nn 16 | from torch.nn.utils import weight_norm 17 | 18 | 19 | class ConvRNNF0Predictor(nn.Module): 20 | def __init__( 21 | self, num_class: int = 1, in_channels: int = 80, cond_channels: int = 512 22 | ): 23 | super().__init__() 24 | 25 | self.num_class = num_class 26 | self.condnet = nn.Sequential( 27 | weight_norm( 28 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) 29 | ), 30 | nn.ELU(), 31 | weight_norm( 32 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 33 | ), 34 | nn.ELU(), 35 | weight_norm( 36 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 37 | ), 38 | nn.ELU(), 39 | weight_norm( 40 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 41 | ), 42 | nn.ELU(), 43 | weight_norm( 44 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 45 | ), 46 | nn.ELU(), 47 | ) 48 | self.classifier = nn.Linear( 49 | in_features=cond_channels, out_features=self.num_class 50 | ) 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | x = self.condnet(x) 54 | x = x.transpose(1, 2) 55 | return torch.abs(self.classifier(x).squeeze(-1)) 56 | -------------------------------------------------------------------------------- /cosyvoice/matcha/audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | from scipy.io.wavfile import read 6 | 7 | MAX_WAV_VALUE = 32768.0 8 | 9 | 10 | def load_wav(full_path): 11 | sampling_rate, data = read(full_path) 12 | return data, sampling_rate 13 | 14 | 15 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 16 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 17 | 18 | 19 | def dynamic_range_decompression(x, C=1): 20 | return np.exp(x) / C 21 | 22 | 23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 24 | return torch.log(torch.clamp(x, min=clip_val) * C) 25 | 26 | 27 | def dynamic_range_decompression_torch(x, C=1): 28 | return torch.exp(x) / C 29 | 30 | 31 | def spectral_normalize_torch(magnitudes): 32 | output = dynamic_range_compression_torch(magnitudes) 33 | return output 34 | 35 | 36 | def spectral_de_normalize_torch(magnitudes): 37 | output = dynamic_range_decompression_torch(magnitudes) 38 | return output 39 | 40 | 41 | mel_basis = {} 42 | hann_window = {} 43 | 44 | 45 | def mel_spectrogram( 46 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False 47 | ): 48 | if torch.min(y) < -1.0: 49 | print("min value is ", torch.min(y)) 50 | if torch.max(y) > 1.0: 51 | print("max value is ", torch.max(y)) 52 | 53 | global mel_basis, hann_window # pylint: disable=global-statement 54 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis: 55 | mel = librosa_mel_fn( 56 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 57 | ) 58 | mel_basis[str(fmax) + "_" + str(y.device)] = ( 59 | torch.from_numpy(mel).float().to(y.device) 60 | ) 61 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 62 | 63 | y = torch.nn.functional.pad( 64 | y.unsqueeze(1), 65 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 66 | mode="reflect", 67 | ) 68 | y = y.squeeze(1) 69 | 70 | spec = torch.view_as_real( 71 | torch.stft( 72 | y, 73 | n_fft, 74 | hop_length=hop_size, 75 | win_length=win_size, 76 | window=hann_window[str(y.device)], 77 | center=center, 78 | pad_mode="reflect", 79 | normalized=False, 80 | 
onesided=True, 81 | return_complex=True, 82 | ) 83 | ) 84 | 85 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 86 | 87 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 88 | spec = spectral_normalize_torch(spec) 89 | 90 | return spec 91 | -------------------------------------------------------------------------------- /cosyvoice/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/cosyvoice/transformer/__init__.py -------------------------------------------------------------------------------- /cosyvoice/transformer/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) 2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo) 3 | # 2020 Mobvoi Inc (Binbin Zhang) 4 | # 2024 Alibaba Inc (Xiang Lyu) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """Swish() activation function for Conformer.""" 18 | 19 | import torch 20 | from torch import nn, sin, pow 21 | from torch.nn import Parameter 22 | 23 | 24 | class Swish(torch.nn.Module): 25 | """Construct an Swish object.""" 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | """Return Swish activation function.""" 29 | return x * torch.sigmoid(x) 30 | 31 | 32 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 33 | # LICENSE is in incl_licenses directory. 34 | class Snake(nn.Module): 35 | """ 36 | Implementation of a sine-based periodic activation function 37 | Shape: 38 | - Input: (B, C, T) 39 | - Output: (B, C, T), same shape as the input 40 | Parameters: 41 | - alpha - trainable parameter 42 | References: 43 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 44 | https://arxiv.org/abs/2006.08195 45 | Examples: 46 | >>> a1 = snake(256) 47 | >>> x = torch.randn(256) 48 | >>> x = a1(x) 49 | """ 50 | 51 | def __init__( 52 | self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False 53 | ): 54 | """ 55 | Initialization. 56 | INPUT: 57 | - in_features: shape of the input 58 | - alpha: trainable parameter 59 | alpha is initialized to 1 by default, higher values = higher-frequency. 60 | alpha will be trained along with the rest of your model. 61 | """ 62 | super(Snake, self).__init__() 63 | self.in_features = in_features 64 | 65 | # initialize alpha 66 | self.alpha_logscale = alpha_logscale 67 | if self.alpha_logscale: # log scale alphas initialized to zeros 68 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 69 | else: # linear scale alphas initialized to ones 70 | self.alpha = Parameter(torch.ones(in_features) * alpha) 71 | 72 | self.alpha.requires_grad = alpha_trainable 73 | 74 | self.no_div_by_zero = 0.000000001 75 | 76 | def forward(self, x): 77 | """ 78 | Forward pass of the function. 
79 | Applies the function to the input elementwise. 80 | Snake ∶= x + 1/a * sin^2 (xa) 81 | """ 82 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 83 | if self.alpha_logscale: 84 | alpha = torch.exp(alpha) 85 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 86 | 87 | return x 88 | -------------------------------------------------------------------------------- /cosyvoice/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/cosyvoice/utils/__init__.py -------------------------------------------------------------------------------- /cosyvoice/utils/audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | from scipy.io.wavfile import read 6 | 7 | MAX_WAV_VALUE = 32768.0 8 | 9 | 10 | def load_wav(full_path): 11 | sampling_rate, data = read(full_path) 12 | return data, sampling_rate 13 | 14 | 15 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 16 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 17 | 18 | 19 | def dynamic_range_decompression(x, C=1): 20 | return np.exp(x) / C 21 | 22 | 23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 24 | return torch.log(torch.clamp(x, min=clip_val) * C) 25 | 26 | 27 | def dynamic_range_decompression_torch(x, C=1): 28 | return torch.exp(x) / C 29 | 30 | 31 | def spectral_normalize_torch(magnitudes): 32 | output = dynamic_range_compression_torch(magnitudes) 33 | return output 34 | 35 | 36 | def spectral_de_normalize_torch(magnitudes): 37 | output = dynamic_range_decompression_torch(magnitudes) 38 | return output 39 | 40 | 41 | mel_basis = {} 42 | hann_window = {} 43 | 44 | 45 | def mel_spectrogram( 46 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False 47 | ): 48 | # if torch.min(y) < -1.0: 49 | # print("min value is ", torch.min(y)) 50 | # if torch.max(y) > 1.0: 51 | # print("max value is ", torch.max(y)) 52 | 53 | global mel_basis, hann_window # pylint: disable=global-statement 54 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis: 55 | mel = librosa_mel_fn( 56 | sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax 57 | ) 58 | mel_basis[str(fmax) + "_" + str(y.device)] = ( 59 | torch.from_numpy(mel).float().to(y.device) 60 | ) 61 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 62 | 63 | y = torch.nn.functional.pad( 64 | y.unsqueeze(1), 65 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 66 | mode="reflect", 67 | ) 68 | y = y.squeeze(1) 69 | 70 | spec = torch.view_as_real( 71 | torch.stft( 72 | y, 73 | n_fft, 74 | hop_length=hop_size, 75 | win_length=win_size, 76 | window=hann_window[str(y.device)], 77 | center=center, 78 | pad_mode="reflect", 79 | normalized=False, 80 | onesided=True, 81 | return_complex=True, 82 | ) 83 | ) 84 | 85 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 86 | 87 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 88 | spec = spectral_normalize_torch(spec) 89 | 90 | return spec 91 | -------------------------------------------------------------------------------- /cosyvoice/utils/class_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright [2023-11-28] 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 
| # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import torch 16 | 17 | from cosyvoice.transformer.activation import Swish 18 | from cosyvoice.transformer.subsampling import ( 19 | LinearNoSubsampling, 20 | EmbedinigNoSubsampling, 21 | Conv1dSubsampling2, 22 | Conv2dSubsampling4, 23 | Conv2dSubsampling6, 24 | Conv2dSubsampling8, 25 | ) 26 | from cosyvoice.transformer.embedding import ( 27 | PositionalEncoding, 28 | RelPositionalEncoding, 29 | WhisperPositionalEncoding, 30 | LearnablePositionalEncoding, 31 | NoPositionalEncoding, 32 | ) 33 | from cosyvoice.transformer.attention import ( 34 | MultiHeadedAttention, 35 | RelPositionMultiHeadedAttention, 36 | ) 37 | from cosyvoice.transformer.embedding import ( 38 | EspnetRelPositionalEncoding, 39 | ) 40 | from cosyvoice.transformer.subsampling import ( 41 | LegacyLinearNoSubsampling, 42 | ) 43 | 44 | 45 | COSYVOICE_ACTIVATION_CLASSES = { 46 | "hardtanh": torch.nn.Hardtanh, 47 | "tanh": torch.nn.Tanh, 48 | "relu": torch.nn.ReLU, 49 | "selu": torch.nn.SELU, 50 | "swish": getattr(torch.nn, "SiLU", Swish), 51 | "gelu": torch.nn.GELU, 52 | } 53 | 54 | COSYVOICE_SUBSAMPLE_CLASSES = { 55 | "linear": LinearNoSubsampling, 56 | "linear_legacy": LegacyLinearNoSubsampling, 57 | "embed": EmbedinigNoSubsampling, 58 | "conv1d2": Conv1dSubsampling2, 59 | "conv2d": Conv2dSubsampling4, 60 | "conv2d6": Conv2dSubsampling6, 61 | "conv2d8": Conv2dSubsampling8, 62 | "paraformer_dummy": torch.nn.Identity, 63 | } 64 | 65 | COSYVOICE_EMB_CLASSES = { 66 | "embed": PositionalEncoding, 67 | "abs_pos": PositionalEncoding, 68 | "rel_pos": RelPositionalEncoding, 69 | "rel_pos_espnet": EspnetRelPositionalEncoding, 70 | "no_pos": NoPositionalEncoding, 71 | "abs_pos_whisper": WhisperPositionalEncoding, 72 | "embed_learnable_pe": LearnablePositionalEncoding, 73 | } 74 | 75 | COSYVOICE_ATTENTION_CLASSES = { 76 | "selfattn": MultiHeadedAttention, 77 | "rel_selfattn": RelPositionMultiHeadedAttention, 78 | } 79 | -------------------------------------------------------------------------------- /cosyvoice/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
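# A quick usage example for the load_wav helper defined below: it reads a file with torchaudio,
# averages any multi-channel audio down to mono, and resamples to target_sr when the source rate
# differs. The prompt wavs under examples/ (tracked via git-lfs, per .gitattributes) are handy
# inputs; the 16 kHz target below is only an illustrative choice, not a rate mandated by this repo.

from cosyvoice.utils.file_utils import load_wav

speech = load_wav("examples/prompt_wav_yuqian.wav", 16000)
print(speech.shape)   # torch.Size([1, num_samples]); always a single mono channel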
15 | 16 | import json 17 | import torchaudio 18 | import logging 19 | 20 | logging.getLogger("matplotlib").setLevel(logging.WARNING) 21 | logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(message)s") 22 | 23 | 24 | def read_lists(list_file): 25 | lists = [] 26 | with open(list_file, "r", encoding="utf8") as fin: 27 | for line in fin: 28 | lists.append(line.strip()) 29 | return lists 30 | 31 | 32 | def read_json_lists(list_file): 33 | lists = read_lists(list_file) 34 | results = {} 35 | for fn in lists: 36 | with open(fn, "r", encoding="utf8") as fin: 37 | results.update(json.load(fin)) 38 | return results 39 | 40 | 41 | def load_wav(wav, target_sr): 42 | speech, sample_rate = torchaudio.load(wav) 43 | speech = speech.mean(dim=0, keepdim=True) 44 | if sample_rate != target_sr: 45 | # assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) 46 | speech = torchaudio.transforms.Resample( 47 | orig_freq=sample_rate, new_freq=target_sr 48 | )(speech) 49 | return speech 50 | -------------------------------------------------------------------------------- /examples/clone_wav_lixueqin.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/clone_wav_lixueqin.wav -------------------------------------------------------------------------------- /examples/clone_wav_yuqian.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/clone_wav_yuqian.wav -------------------------------------------------------------------------------- /examples/emotional_control1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/emotional_control1.wav -------------------------------------------------------------------------------- /examples/emotional_control2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/emotional_control2.wav -------------------------------------------------------------------------------- /examples/multilingual1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/multilingual1.wav -------------------------------------------------------------------------------- /examples/multilingual2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/multilingual2.wav -------------------------------------------------------------------------------- /examples/multilingual_singing.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/multilingual_singing.wav -------------------------------------------------------------------------------- /examples/prompt_wav_lixueqin.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/prompt_wav_lixueqin.wav -------------------------------------------------------------------------------- /examples/prompt_wav_yuqian.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/prompt_wav_yuqian.wav -------------------------------------------------------------------------------- /examples/prompt_wav_zhaobenshan.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/prompt_wav_zhaobenshan.wav -------------------------------------------------------------------------------- /examples/rap.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/rap.wav -------------------------------------------------------------------------------- /examples/singing.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/singing.wav -------------------------------------------------------------------------------- /examples/speed_control1.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/speed_control1.wav -------------------------------------------------------------------------------- /examples/speed_control2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/speed_control2.wav -------------------------------------------------------------------------------- /examples/tone_control.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/examples/tone_control.wav -------------------------------------------------------------------------------- /funasr_detach/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize funasr package.""" 2 | 3 | import os 4 | import pkgutil 5 | import importlib 6 | 7 | dirname = os.path.dirname(__file__) 8 | version_file = os.path.join(dirname, "version.txt") 9 | with open(version_file, "r") as f: 10 | __version__ = f.read().strip() 11 | 12 | 13 | import importlib 14 | import pkgutil 15 | 16 | 17 | def import_submodules(package, recursive=True): 18 | if isinstance(package, str): 19 | package = importlib.import_module(package) 20 | results = {} 21 | for loader, name, is_pkg in pkgutil.walk_packages( 22 | package.__path__, package.__name__ + "." 
23 | ): 24 | try: 25 | results[name] = importlib.import_module(name) 26 | except Exception as e: 27 | # 如果想要看到导入错误的具体信息,可以取消注释下面的行 28 | # print(f"Failed to import {name}: {e}") 29 | pass 30 | if recursive and is_pkg: 31 | results.update(import_submodules(name)) 32 | return results 33 | 34 | 35 | import_submodules(__name__) 36 | 37 | from funasr_detach.auto.auto_model import AutoModel 38 | from funasr_detach.auto.auto_frontend import AutoFrontend 39 | -------------------------------------------------------------------------------- /funasr_detach/auto/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/auto/__init__.py -------------------------------------------------------------------------------- /funasr_detach/auto/auto_frontend.py: -------------------------------------------------------------------------------- 1 | import time 2 | import logging 3 | from tqdm import tqdm 4 | 5 | from funasr_detach.register import tables 6 | from funasr_detach.download.download_from_hub import download_model 7 | from funasr_detach.utils.load_utils import load_audio_text_image_video, extract_fbank 8 | from funasr_detach.auto.auto_model import prepare_data_iterator 9 | from funasr_detach.auto.auto_model import prepare_data_iterator 10 | 11 | 12 | class AutoFrontend: 13 | def __init__(self, **kwargs): 14 | assert "model" in kwargs 15 | if "model_conf" not in kwargs: 16 | logging.info( 17 | "download models from model hub: {}".format( 18 | kwargs.get("model_hub", "ms") 19 | ) 20 | ) 21 | kwargs = download_model(**kwargs) 22 | 23 | # build frontend 24 | frontend = kwargs.get("frontend", None) 25 | if frontend is not None: 26 | frontend_class = tables.frontend_classes.get(frontend) 27 | frontend = frontend_class(**kwargs["frontend_conf"]) 28 | 29 | self.frontend = frontend 30 | if "frontend" in kwargs: 31 | del kwargs["frontend"] 32 | self.kwargs = kwargs 33 | 34 | def __call__(self, input, input_len=None, kwargs=None, **cfg): 35 | 36 | kwargs = self.kwargs if kwargs is None else kwargs 37 | kwargs.update(cfg) 38 | 39 | key_list, data_list = prepare_data_iterator(input, input_len=input_len) 40 | batch_size = kwargs.get("batch_size", 1) 41 | device = kwargs.get("device", "cpu") 42 | if device == "cpu": 43 | batch_size = 1 44 | 45 | meta_data = {} 46 | 47 | result_list = [] 48 | num_samples = len(data_list) 49 | pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True) 50 | 51 | time0 = time.perf_counter() 52 | for beg_idx in range(0, num_samples, batch_size): 53 | end_idx = min(num_samples, beg_idx + batch_size) 54 | data_batch = data_list[beg_idx:end_idx] 55 | key_batch = key_list[beg_idx:end_idx] 56 | 57 | # extract fbank feats 58 | time1 = time.perf_counter() 59 | audio_sample_list = load_audio_text_image_video( 60 | data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000) 61 | ) 62 | time2 = time.perf_counter() 63 | meta_data["load_data"] = f"{time2 - time1:0.3f}" 64 | speech, speech_lengths = extract_fbank( 65 | audio_sample_list, 66 | data_type=kwargs.get("data_type", "sound"), 67 | frontend=self.frontend, 68 | **kwargs, 69 | ) 70 | time3 = time.perf_counter() 71 | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" 72 | meta_data["batch_data_time"] = ( 73 | speech_lengths.sum().item() 74 | * self.frontend.frame_shift 75 | * self.frontend.lfr_n 76 | / 1000 77 | ) 78 | 79 | speech.to(device=device), speech_lengths.to(device=device) 80 | 
batch = {"input": speech, "input_len": speech_lengths, "key": key_batch} 81 | result_list.append(batch) 82 | 83 | pbar.update(1) 84 | description = f"{meta_data}, " 85 | pbar.set_description(description) 86 | 87 | time_end = time.perf_counter() 88 | pbar.set_description(f"time escaped total: {time_end - time0:0.3f}") 89 | 90 | return result_list 91 | -------------------------------------------------------------------------------- /funasr_detach/auto/auto_tokenizer.py: -------------------------------------------------------------------------------- 1 | class AutoTokenizer: 2 | """ 3 | Undo 4 | """ 5 | 6 | def __init__(self): 7 | pass 8 | -------------------------------------------------------------------------------- /funasr_detach/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/bin/__init__.py -------------------------------------------------------------------------------- /funasr_detach/bin/inference.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import logging 3 | from omegaconf import DictConfig, OmegaConf, ListConfig 4 | 5 | from funasr_detach.auto.auto_model import AutoModel 6 | 7 | 8 | @hydra.main(config_name=None, version_base=None) 9 | def main_hydra(cfg: DictConfig): 10 | def to_plain_list(cfg_item): 11 | if isinstance(cfg_item, ListConfig): 12 | return OmegaConf.to_container(cfg_item, resolve=True) 13 | elif isinstance(cfg_item, DictConfig): 14 | return {k: to_plain_list(v) for k, v in cfg_item.items()} 15 | else: 16 | return cfg_item 17 | 18 | kwargs = to_plain_list(cfg) 19 | log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) 20 | 21 | logging.basicConfig(level=log_level) 22 | 23 | if kwargs.get("debug", False): 24 | import pdb 25 | 26 | pdb.set_trace() 27 | model = AutoModel(**kwargs) 28 | res = model.generate(input=kwargs["input"]) 29 | print(res) 30 | 31 | 32 | if __name__ == "__main__": 33 | main_hydra() 34 | -------------------------------------------------------------------------------- /funasr_detach/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/datasets/__init__.py -------------------------------------------------------------------------------- /funasr_detach/datasets/audio_datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/datasets/audio_datasets/__init__.py -------------------------------------------------------------------------------- /funasr_detach/datasets/audio_datasets/preprocessor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import logging 5 | import concurrent.futures 6 | import librosa 7 | import torch.distributed as dist 8 | from typing import Collection 9 | import torch 10 | import torchaudio 11 | from torch import nn 12 | import random 13 | import re 14 | from funasr_detach.tokenizer.cleaner import TextCleaner 15 | from funasr_detach.register import tables 16 | 17 | 18 | @tables.register("preprocessor_classes", "SpeechPreprocessSpeedPerturb") 19 | class SpeechPreprocessSpeedPerturb(nn.Module): 20 | def 
__init__(self, speed_perturb: list = None, **kwargs): 21 | super().__init__() 22 | self.speed_perturb = speed_perturb 23 | 24 | def forward(self, waveform, fs, **kwargs): 25 | if self.speed_perturb is None: 26 | return waveform 27 | speed = random.choice(self.speed_perturb) 28 | if speed != 1.0: 29 | if not isinstance(waveform, torch.Tensor): 30 | waveform = torch.tensor(waveform) 31 | waveform, _ = torchaudio.sox_effects.apply_effects_tensor( 32 | waveform.view(1, -1), fs, [["speed", str(speed)], ["rate", str(fs)]] 33 | ) 34 | waveform = waveform.view(-1) 35 | 36 | return waveform 37 | 38 | 39 | @tables.register("preprocessor_classes", "TextPreprocessSegDict") 40 | class TextPreprocessSegDict(nn.Module): 41 | def __init__( 42 | self, 43 | seg_dict: str = None, 44 | text_cleaner: Collection[str] = None, 45 | split_with_space: bool = False, 46 | **kwargs 47 | ): 48 | super().__init__() 49 | 50 | self.text_cleaner = TextCleaner(text_cleaner) 51 | 52 | def forward(self, text, **kwargs): 53 | text = self.text_cleaner(text) 54 | 55 | return text 56 | -------------------------------------------------------------------------------- /funasr_detach/download/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/download/__init__.py -------------------------------------------------------------------------------- /funasr_detach/download/download_dataset_from_hub.py: -------------------------------------------------------------------------------- 1 | def download_dataset(): 2 | pass 3 | 4 | 5 | def download_dataset_from_ms(**kwargs): 6 | from modelscope.msdatasets import MsDataset 7 | 8 | dataset_name = kwargs.get( 9 | "dataset_name", "speech_asr/speech_asr_aishell1_trainsets" 10 | ) 11 | subset_name = kwargs.get("subset_name", "default") 12 | split = kwargs.get("split", "train") 13 | data_dump_dir = kwargs.get("data_dump_dir", None) 14 | ds = MsDataset.load( 15 | dataset_name=dataset_name, 16 | subset_name=subset_name, 17 | split=split, 18 | cache_dir=data_dump_dir, 19 | ) 20 | -------------------------------------------------------------------------------- /funasr_detach/download/name_maps_from_hub.py: -------------------------------------------------------------------------------- 1 | name_maps_ms = { 2 | "paraformer-zh": "damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", 3 | "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020", 4 | "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020", 5 | "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", 6 | "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", 7 | "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large", 8 | "ct-punc-c": "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", 9 | "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline", 10 | "cam++": "damo/speech_campplus_sv_zh-cn_16k-common", 11 | } 12 | 13 | name_maps_hf = {} 14 | -------------------------------------------------------------------------------- /funasr_detach/download/runtime_sdk_download_tool.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | 5 | from funasr_detach.utils.types import str2bool 6 | 7 | 8 | def main(): 9 | parser = 
argparse.ArgumentParser() 10 | parser.add_argument("--model-name", type=str, required=True) 11 | parser.add_argument("--export-dir", type=str, required=True) 12 | parser.add_argument( 13 | "--export", type=str2bool, default=True, help="whether to export model" 14 | ) 15 | parser.add_argument("--type", type=str, default="onnx", help='["onnx", "torch"]') 16 | parser.add_argument("--device", type=str, default="cpu", help='["cpu", "cuda"]') 17 | parser.add_argument( 18 | "--quantize", type=str2bool, default=False, help="export quantized model" 19 | ) 20 | parser.add_argument( 21 | "--fallback-num", type=int, default=0, help="amp fallback number" 22 | ) 23 | parser.add_argument("--audio_in", type=str, default=None, help='["wav", "wav.scp"]') 24 | parser.add_argument( 25 | "--model_revision", type=str, default=None, help="model_revision" 26 | ) 27 | parser.add_argument("--calib_num", type=int, default=200, help="calib max num") 28 | args = parser.parse_args() 29 | 30 | model_dir = args.model_name 31 | if not Path(args.model_name).exists(): 32 | from modelscope.hub.snapshot_download import snapshot_download 33 | 34 | try: 35 | model_dir = snapshot_download( 36 | args.model_name, cache_dir=args.export_dir, revision=args.model_revision 37 | ) 38 | except: 39 | raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format( 40 | model_dir 41 | ) 42 | if args.export: 43 | model_file = os.path.join(model_dir, "model.onnx") 44 | if args.quantize: 45 | model_file = os.path.join(model_dir, "model_quant.onnx") 46 | if not os.path.exists(model_file): 47 | print(".onnx is not exist, begin to export onnx") 48 | from funasr_detach.bin.export_model import ModelExport 49 | 50 | export_model = ModelExport( 51 | cache_dir=args.export_dir, 52 | onnx=True, 53 | device="cpu", 54 | quant=args.quantize, 55 | ) 56 | export_model.export(model_dir) 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /funasr_detach/frontends/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/frontends/__init__.py -------------------------------------------------------------------------------- /funasr_detach/frontends/eend_ola_feature.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita) 2 | # Licensed under the MIT license. 
3 | # 4 | # This module is for computing audio features 5 | 6 | import librosa 7 | import numpy as np 8 | 9 | 10 | def transform(Y, dtype=np.float32): 11 | Y = np.abs(Y) 12 | n_fft = 2 * (Y.shape[1] - 1) 13 | sr = 8000 14 | n_mels = 23 15 | mel_basis = librosa.filters.mel(sr, n_fft, n_mels) 16 | Y = np.dot(Y**2, mel_basis.T) 17 | Y = np.log10(np.maximum(Y, 1e-10)) 18 | mean = np.mean(Y, axis=0) 19 | Y = Y - mean 20 | return Y.astype(dtype) 21 | 22 | 23 | def subsample(Y, T, subsampling=1): 24 | Y_ss = Y[::subsampling] 25 | T_ss = T[::subsampling] 26 | return Y_ss, T_ss 27 | 28 | 29 | def splice(Y, context_size=0): 30 | Y_pad = np.pad(Y, [(context_size, context_size), (0, 0)], "constant") 31 | Y_spliced = np.lib.stride_tricks.as_strided( 32 | np.ascontiguousarray(Y_pad), 33 | (Y.shape[0], Y.shape[1] * (2 * context_size + 1)), 34 | (Y.itemsize * Y.shape[1], Y.itemsize), 35 | writeable=False, 36 | ) 37 | return Y_spliced 38 | 39 | 40 | def stft(data, frame_size=1024, frame_shift=256): 41 | fft_size = 1 << (frame_size - 1).bit_length() 42 | if len(data) % frame_shift == 0: 43 | return librosa.stft( 44 | data, n_fft=fft_size, win_length=frame_size, hop_length=frame_shift 45 | ).T[:-1] 46 | else: 47 | return librosa.stft( 48 | data, n_fft=fft_size, win_length=frame_size, hop_length=frame_shift 49 | ).T 50 | -------------------------------------------------------------------------------- /funasr_detach/frontends/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize sub package.""" 2 | -------------------------------------------------------------------------------- /funasr_detach/frontends/utils/beamformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_complex import functional as FC 3 | from torch_complex.tensor import ComplexTensor 4 | 5 | 6 | def get_power_spectral_density_matrix( 7 | xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15 8 | ) -> ComplexTensor: 9 | """Return cross-channel power spectral density (PSD) matrix 10 | 11 | Args: 12 | xs (ComplexTensor): (..., F, C, T) 13 | mask (torch.Tensor): (..., F, C, T) 14 | normalization (bool): 15 | eps (float): 16 | Returns 17 | psd (ComplexTensor): (..., F, C, C) 18 | 19 | """ 20 | # outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C, C_2) 21 | psd_Y = FC.einsum("...ct,...et->...tce", [xs, xs.conj()]) 22 | 23 | # Averaging mask along C: (..., C, T) -> (..., T) 24 | mask = mask.mean(dim=-2) 25 | 26 | # Normalized mask along T: (..., T) 27 | if normalization: 28 | # If assuming the tensor is padded with zero, the summation along 29 | # the time axis is same regardless of the padding length. 30 | mask = mask / (mask.sum(dim=-1, keepdim=True) + eps) 31 | 32 | # psd: (..., T, C, C) 33 | psd = psd_Y * mask[..., None, None] 34 | # (..., T, C, C) -> (..., C, C) 35 | psd = psd.sum(dim=-3) 36 | 37 | return psd 38 | 39 | 40 | def get_mvdr_vector( 41 | psd_s: ComplexTensor, 42 | psd_n: ComplexTensor, 43 | reference_vector: torch.Tensor, 44 | eps: float = 1e-15, 45 | ) -> ComplexTensor: 46 | """Return the MVDR(Minimum Variance Distortionless Response) vector: 47 | 48 | h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u 49 | 50 | Reference: 51 | On optimal frequency-domain multichannel linear filtering 52 | for noise reduction; M. 
Souden et al., 2010; 53 | https://ieeexplore.ieee.org/document/5089420 54 | 55 | Args: 56 | psd_s (ComplexTensor): (..., F, C, C) 57 | psd_n (ComplexTensor): (..., F, C, C) 58 | reference_vector (torch.Tensor): (..., C) 59 | eps (float): 60 | Returns: 61 | beamform_vector (ComplexTensor)r: (..., F, C) 62 | """ 63 | # Add eps 64 | C = psd_n.size(-1) 65 | eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device) 66 | shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C] 67 | eye = eye.view(*shape) 68 | psd_n += eps * eye 69 | 70 | # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3) 71 | numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s]) 72 | # ws: (..., C, C) / (...,) -> (..., C, C) 73 | ws = numerator / (FC.trace(numerator)[..., None, None] + eps) 74 | # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1) 75 | beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector]) 76 | return beamform_vector 77 | 78 | 79 | def apply_beamforming_vector( 80 | beamform_vector: ComplexTensor, mix: ComplexTensor 81 | ) -> ComplexTensor: 82 | # (..., C) x (..., C, T) -> (..., T) 83 | es = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix]) 84 | return es 85 | -------------------------------------------------------------------------------- /funasr_detach/frontends/utils/dnn_wpe.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from pytorch_wpe import wpe_one_iteration 4 | import torch 5 | from torch_complex.tensor import ComplexTensor 6 | 7 | from funasr_detach.frontends.utils.mask_estimator import MaskEstimator 8 | from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask 9 | 10 | 11 | class DNN_WPE(torch.nn.Module): 12 | def __init__( 13 | self, 14 | wtype: str = "blstmp", 15 | widim: int = 257, 16 | wlayers: int = 3, 17 | wunits: int = 300, 18 | wprojs: int = 320, 19 | dropout_rate: float = 0.0, 20 | taps: int = 5, 21 | delay: int = 3, 22 | use_dnn_mask: bool = True, 23 | iterations: int = 1, 24 | normalization: bool = False, 25 | ): 26 | super().__init__() 27 | self.iterations = iterations 28 | self.taps = taps 29 | self.delay = delay 30 | 31 | self.normalization = normalization 32 | self.use_dnn_mask = use_dnn_mask 33 | 34 | self.inverse_power = True 35 | 36 | if self.use_dnn_mask: 37 | self.mask_est = MaskEstimator( 38 | wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1 39 | ) 40 | 41 | def forward( 42 | self, data: ComplexTensor, ilens: torch.LongTensor 43 | ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]: 44 | """The forward function 45 | 46 | Notation: 47 | B: Batch 48 | C: Channel 49 | T: Time or Sequence length 50 | F: Freq or Some dimension of the feature vector 51 | 52 | Args: 53 | data: (B, C, T, F) 54 | ilens: (B,) 55 | Returns: 56 | data: (B, C, T, F) 57 | ilens: (B,) 58 | """ 59 | # (B, T, C, F) -> (B, F, C, T) 60 | enhanced = data = data.permute(0, 3, 2, 1) 61 | mask = None 62 | 63 | for i in range(self.iterations): 64 | # Calculate power: (..., C, T) 65 | power = enhanced.real**2 + enhanced.imag**2 66 | if i == 0 and self.use_dnn_mask: 67 | # mask: (B, F, C, T) 68 | (mask,), _ = self.mask_est(enhanced, ilens) 69 | if self.normalization: 70 | # Normalize along T 71 | mask = mask / mask.sum(dim=-1)[..., None] 72 | # (..., C, T) * (..., C, T) -> (..., C, T) 73 | power = power * mask 74 | 75 | # Averaging along the channel axis: (..., C, T) -> (..., T) 76 | power = power.mean(dim=-2) 77 | 78 | # enhanced: (..., C, T) -> (..., C, 
T) 79 | enhanced = wpe_one_iteration( 80 | data.contiguous(), 81 | power, 82 | taps=self.taps, 83 | delay=self.delay, 84 | inverse_power=self.inverse_power, 85 | ) 86 | 87 | enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0) 88 | 89 | # (B, F, C, T) -> (B, T, C, F) 90 | enhanced = enhanced.permute(0, 3, 2, 1) 91 | if mask is not None: 92 | mask = mask.transpose(-1, -3) 93 | return enhanced, ilens, mask 94 | -------------------------------------------------------------------------------- /funasr_detach/frontends/utils/log_mel.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import torch 3 | from typing import Tuple 4 | 5 | from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask 6 | 7 | 8 | class LogMel(torch.nn.Module): 9 | """Convert STFT to fbank feats 10 | 11 | The arguments is same as librosa.filters.mel 12 | 13 | Args: 14 | fs: number > 0 [scalar] sampling rate of the incoming signal 15 | n_fft: int > 0 [scalar] number of FFT components 16 | n_mels: int > 0 [scalar] number of Mel bands to generate 17 | fmin: float >= 0 [scalar] lowest frequency (in Hz) 18 | fmax: float >= 0 [scalar] highest frequency (in Hz). 19 | If `None`, use `fmax = fs / 2.0` 20 | htk: use HTK formula instead of Slaney 21 | """ 22 | 23 | def __init__( 24 | self, 25 | fs: int = 16000, 26 | n_fft: int = 512, 27 | n_mels: int = 80, 28 | fmin: float = None, 29 | fmax: float = None, 30 | htk: bool = False, 31 | log_base: float = None, 32 | ): 33 | super().__init__() 34 | 35 | fmin = 0 if fmin is None else fmin 36 | fmax = fs / 2 if fmax is None else fmax 37 | _mel_options = dict( 38 | sr=fs, 39 | n_fft=n_fft, 40 | n_mels=n_mels, 41 | fmin=fmin, 42 | fmax=fmax, 43 | htk=htk, 44 | ) 45 | self.mel_options = _mel_options 46 | self.log_base = log_base 47 | 48 | # Note(kamo): The mel matrix of librosa is different from kaldi. 
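        # librosa.filters.mel returns a (n_mels, 1 + n_fft // 2) matrix; it is
        # transposed below and stored with register_buffer so it follows the
        # module's device/dtype and is saved in state_dict without being a
        # trainable parameter.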
49 | melmat = librosa.filters.mel(**_mel_options) 50 | # melmat: (D2, D1) -> (D1, D2) 51 | self.register_buffer("melmat", torch.from_numpy(melmat.T).float()) 52 | 53 | def extra_repr(self): 54 | return ", ".join(f"{k}={v}" for k, v in self.mel_options.items()) 55 | 56 | def forward( 57 | self, 58 | feat: torch.Tensor, 59 | ilens: torch.Tensor = None, 60 | ) -> Tuple[torch.Tensor, torch.Tensor]: 61 | # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2) 62 | mel_feat = torch.matmul(feat, self.melmat) 63 | mel_feat = torch.clamp(mel_feat, min=1e-10) 64 | 65 | if self.log_base is None: 66 | logmel_feat = mel_feat.log() 67 | elif self.log_base == 2.0: 68 | logmel_feat = mel_feat.log2() 69 | elif self.log_base == 10.0: 70 | logmel_feat = mel_feat.log10() 71 | else: 72 | logmel_feat = mel_feat.log() / torch.log(self.log_base) 73 | 74 | # Zero padding 75 | if ilens is not None: 76 | logmel_feat = logmel_feat.masked_fill( 77 | make_pad_mask(ilens, logmel_feat, 1), 0.0 78 | ) 79 | else: 80 | ilens = feat.new_full( 81 | [feat.size(0)], fill_value=feat.size(1), dtype=torch.long 82 | ) 83 | return logmel_feat, ilens 84 | -------------------------------------------------------------------------------- /funasr_detach/frontends/utils/mask_estimator.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from torch.nn import functional as F 6 | from torch_complex.tensor import ComplexTensor 7 | 8 | from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask 9 | from funasr_detach.models.language_model.rnn.encoders import RNN 10 | from funasr_detach.models.language_model.rnn.encoders import RNNP 11 | 12 | 13 | class MaskEstimator(torch.nn.Module): 14 | def __init__(self, type, idim, layers, units, projs, dropout, nmask=1): 15 | super().__init__() 16 | subsample = np.ones(layers + 1, dtype=np.int32) 17 | 18 | typ = type.lstrip("vgg").rstrip("p") 19 | if type[-1] == "p": 20 | self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ) 21 | else: 22 | self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ) 23 | 24 | self.type = type 25 | self.nmask = nmask 26 | self.linears = torch.nn.ModuleList( 27 | [torch.nn.Linear(projs, idim) for _ in range(nmask)] 28 | ) 29 | 30 | def forward( 31 | self, xs: ComplexTensor, ilens: torch.LongTensor 32 | ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]: 33 | """The forward function 34 | 35 | Args: 36 | xs: (B, F, C, T) 37 | ilens: (B,) 38 | Returns: 39 | hs (torch.Tensor): The hidden vector (B, F, C, T) 40 | masks: A tuple of the masks. 
(B, F, C, T) 41 | ilens: (B,) 42 | """ 43 | assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0)) 44 | _, _, C, input_length = xs.size() 45 | # (B, F, C, T) -> (B, C, T, F) 46 | xs = xs.permute(0, 2, 3, 1) 47 | 48 | # Calculate amplitude: (B, C, T, F) -> (B, C, T, F) 49 | xs = (xs.real**2 + xs.imag**2) ** 0.5 50 | # xs: (B, C, T, F) -> xs: (B * C, T, F) 51 | xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1)) 52 | # ilens: (B,) -> ilens_: (B * C) 53 | ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1) 54 | 55 | # xs: (B * C, T, F) -> xs: (B * C, T, D) 56 | xs, _, _ = self.brnn(xs, ilens_) 57 | # xs: (B * C, T, D) -> xs: (B, C, T, D) 58 | xs = xs.view(-1, C, xs.size(-2), xs.size(-1)) 59 | 60 | masks = [] 61 | for linear in self.linears: 62 | # xs: (B, C, T, D) -> mask:(B, C, T, F) 63 | mask = linear(xs) 64 | 65 | mask = torch.sigmoid(mask) 66 | # Zero padding 67 | mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0) 68 | 69 | # (B, C, T, F) -> (B, F, C, T) 70 | mask = mask.permute(0, 3, 1, 2) 71 | 72 | # Take cares of multi gpu cases: If input_length > max(ilens) 73 | if mask.size(-1) < input_length: 74 | mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0) 75 | masks.append(mask) 76 | 77 | return tuple(masks), ilens 78 | -------------------------------------------------------------------------------- /funasr_detach/frontends/windowing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 2020, Technische Universität München; Ludwig Kürzinger 3 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 4 | 5 | """Sliding Window for raw audio input data.""" 6 | 7 | import torch 8 | import torch.nn as nn 9 | from typing import Tuple 10 | 11 | 12 | class SlidingWindow(nn.Module): 13 | """Sliding Window. 14 | Provides a sliding window over a batched continuous raw audio tensor. 15 | Optionally, provides padding (Currently not implemented). 16 | Combine this module with a pre-encoder compatible with raw audio data, 17 | for example Sinc convolutions. 18 | Known issues: 19 | Output length is calculated incorrectly if audio shorter than win_length. 20 | WARNING: trailing values are discarded - padding not implemented yet. 21 | There is currently no additional window function applied to input values. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | win_length: int = 400, 27 | hop_length: int = 160, 28 | channels: int = 1, 29 | padding: int = None, 30 | fs=None, 31 | ): 32 | """Initialize. 33 | Args: 34 | win_length: Length of frame. 35 | hop_length: Relative starting point of next frame. 36 | channels: Number of input channels. 37 | padding: Padding (placeholder, currently not implemented). 38 | fs: Sampling rate (placeholder for compatibility, not used). 39 | """ 40 | super().__init__() 41 | self.fs = fs 42 | self.win_length = win_length 43 | self.hop_length = hop_length 44 | self.channels = channels 45 | self.padding = padding 46 | 47 | def forward( 48 | self, input: torch.Tensor, input_lengths: torch.Tensor 49 | ) -> Tuple[torch.Tensor, torch.Tensor]: 50 | """Apply a sliding window on the input. 51 | Args: 52 | input: Input (B, T, C*D) or (B, T*C*D), with D=C=1. 53 | input_lengths: Input lengths within batch. 54 | Returns: 55 | Tensor: Output with dimensions (B, T, C, D), with D=win_length. 56 | Tensor: Output lengths within batch. 
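        Example (illustrative; values chosen only for demonstration):
            >>> sw = SlidingWindow(win_length=400, hop_length=160, channels=1)
            >>> wav = torch.randn(2, 16000, 1)
            >>> out, out_lens = sw(wav, torch.tensor([16000, 16000]))
            >>> tuple(out.shape)  # (B, T', C, D)
            (2, 98, 1, 400)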
57 | """ 58 | input_size = input.size() 59 | B = input_size[0] 60 | T = input_size[1] 61 | C = self.channels 62 | D = self.win_length 63 | # (B, T, C) --> (T, B, C) 64 | continuous = input.view(B, T, C).permute(1, 0, 2) 65 | windowed = continuous.unfold(0, D, self.hop_length) 66 | # (T, B, C, D) --> (B, T, C, D) 67 | output = windowed.permute(1, 0, 2, 3).contiguous() 68 | # After unfold(), windowed lengths change: 69 | output_lengths = (input_lengths - self.win_length) // self.hop_length + 1 70 | return output, output_lengths 71 | 72 | def output_size(self) -> int: 73 | """Return output length of feature dimension D, i.e. the window length.""" 74 | return self.win_length 75 | -------------------------------------------------------------------------------- /funasr_detach/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/losses/__init__.py -------------------------------------------------------------------------------- /funasr_detach/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/metrics/__init__.py -------------------------------------------------------------------------------- /funasr_detach/metrics/compute_acc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def th_accuracy(pad_outputs, pad_targets, ignore_label): 5 | """Calculate accuracy. 6 | 7 | Args: 8 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D). 9 | pad_targets (LongTensor): Target label tensors (B, Lmax, D). 10 | ignore_label (int): Ignore label id. 11 | 12 | Returns: 13 | float: Accuracy value (0.0 - 1.0). 
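    Example (illustrative; hypothetical values, targets given as (B, Lmax) label ids
    as the code below consumes them):
        >>> logits = torch.randn(6, 5)                        # (B*Lmax, D) with B=2, Lmax=3, D=5
        >>> targets = torch.tensor([[1, 2, -1], [3, 4, -1]])  # (B, Lmax); -1 marks padding
        >>> 0.0 <= th_accuracy(logits, targets, ignore_label=-1) <= 1.0
        True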
14 | 15 | """ 16 | pad_pred = pad_outputs.view( 17 | pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1) 18 | ).argmax(2) 19 | mask = pad_targets != ignore_label 20 | numerator = torch.sum( 21 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask) 22 | ) 23 | denominator = torch.sum(mask) 24 | return float(numerator) / float(denominator) 25 | -------------------------------------------------------------------------------- /funasr_detach/metrics/compute_eer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import roc_curve 3 | import argparse 4 | 5 | 6 | def _compute_eer(label, pred, positive_label=1): 7 | """ 8 | Python compute equal error rate (eer) 9 | ONLY tested on binary classification 10 | 11 | :param label: ground-truth label, should be a 1-d list or np.array, each element represents the ground-truth label of one sample 12 | :param pred: model prediction, should be a 1-d list or np.array, each element represents the model prediction of one sample 13 | :param positive_label: the class that is viewed as positive class when computing EER 14 | :return: equal error rate (EER) 15 | """ 16 | 17 | # all fpr, tpr, fnr, fnr, threshold are lists (in the format of np.array) 18 | fpr, tpr, threshold = roc_curve(label, pred, pos_label=positive_label) 19 | fnr = 1 - tpr 20 | 21 | # the threshold of fnr == fpr 22 | eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))] 23 | 24 | # theoretically eer from fpr and eer from fnr should be identical but they can be slightly differ in reality 25 | eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))] 26 | eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))] 27 | 28 | # return the mean of eer from fpr and from fnr 29 | eer = (eer_1 + eer_2) / 2 30 | return eer, eer_threshold 31 | 32 | 33 | def compute_eer(trials_path, scores_path): 34 | labels = [] 35 | for one_line in open(trials_path, "r"): 36 | labels.append(one_line.strip().rsplit(" ", 1)[-1] == "target") 37 | labels = np.array(labels, dtype=int) 38 | 39 | scores = [] 40 | for one_line in open(scores_path, "r"): 41 | scores.append(float(one_line.strip().rsplit(" ", 1)[-1])) 42 | scores = np.array(scores, dtype=float) 43 | 44 | eer, threshold = _compute_eer(labels, scores) 45 | return eer, threshold 46 | 47 | 48 | def main(): 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("trials", help="trial list") 51 | parser.add_argument("scores", help="score file, normalized to [0, 1]") 52 | args = parser.parse_args() 53 | 54 | eer, threshold = compute_eer(args.trials, args.scores) 55 | print("EER is {:.4f} at threshold {:.4f}".format(eer * 100.0, threshold)) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /funasr_detach/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/bat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/bat/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/bat/model.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | import time 7 | import torch 8 | import logging 9 | from contextlib import contextmanager 10 | from typing import Dict, Optional, Tuple 11 | from distutils.version import LooseVersion 12 | 13 | from funasr_detach.register import tables 14 | from funasr_detach.utils import postprocess_utils 15 | from funasr_detach.utils.datadir_writer import DatadirWriter 16 | from funasr_detach.models.transducer.model import Transducer 17 | from funasr_detach.train_utils.device_funcs import force_gatherable 18 | from funasr_detach.models.transformer.scorers.ctc import CTCPrefixScorer 19 | from funasr_detach.losses.label_smoothing_loss import LabelSmoothingLoss 20 | from funasr_detach.models.transformer.scorers.length_bonus import LengthBonus 21 | from funasr_detach.models.transformer.utils.nets_utils import get_transducer_task_io 22 | from funasr_detach.utils.load_utils import load_audio_text_image_video, extract_fbank 23 | from funasr_detach.models.transducer.beam_search_transducer import BeamSearchTransducer 24 | 25 | 26 | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): 27 | from torch.cuda.amp import autocast 28 | else: 29 | # Nothing to do if torch<1.6.0 30 | @contextmanager 31 | def autocast(enabled=True): 32 | yield 33 | 34 | 35 | @tables.register("model_classes", "BAT") # TODO: BAT training 36 | class BAT(Transducer): 37 | pass 38 | -------------------------------------------------------------------------------- /funasr_detach/models/bicif_paraformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/bicif_paraformer/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/branchformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/branchformer/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/branchformer/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from funasr_detach.models.transformer.model import Transformer 4 | from funasr_detach.register import tables 5 | 6 | 7 | @tables.register("model_classes", "Branchformer") 8 | class Branchformer(Transformer): 9 | """CTC-attention hybrid Encoder-Decoder model""" 10 | 11 | def __init__( 12 | self, 13 | *args, 14 | **kwargs, 15 | ): 16 | 17 | super().__init__(*args, **kwargs) 18 | -------------------------------------------------------------------------------- /funasr_detach/models/branchformer/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to configure a model file. 2 | # You can modify the configuration according to your own requirements. 
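# A config like this normally ships inside a model directory and is loaded for
# you; a minimal sketch (illustrative only: the directory path is hypothetical
# and this assumes the detached FunASR fork exposes the usual AutoModel entry
# point):
#   from funasr_detach.auto.auto_model import AutoModel
#   model = AutoModel(model="/path/to/branchformer_model_dir")
#   result = model.generate(input="speech.wav")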
3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | # network architecture 9 | model: Branchformer 10 | model_conf: 11 | ctc_weight: 0.3 12 | lsm_weight: 0.1 # label smoothing option 13 | length_normalized_loss: false 14 | 15 | # encoder 16 | encoder: BranchformerEncoder 17 | encoder_conf: 18 | output_size: 256 19 | use_attn: true 20 | attention_heads: 4 21 | attention_layer_type: rel_selfattn 22 | pos_enc_layer_type: rel_pos 23 | rel_pos_type: latest 24 | use_cgmlp: true 25 | cgmlp_linear_units: 2048 26 | cgmlp_conv_kernel: 31 27 | use_linear_after_conv: false 28 | gate_activation: identity 29 | merge_method: concat 30 | cgmlp_weight: 0.5 # used only if merge_method is "fixed_ave" 31 | attn_branch_drop_rate: 0.0 # used only if merge_method is "learned_ave" 32 | num_blocks: 24 33 | dropout_rate: 0.1 34 | positional_dropout_rate: 0.1 35 | attention_dropout_rate: 0.1 36 | input_layer: conv2d 37 | stochastic_depth_rate: 0.0 38 | 39 | # decoder 40 | decoder: TransformerDecoder 41 | decoder_conf: 42 | attention_heads: 4 43 | linear_units: 2048 44 | num_blocks: 6 45 | dropout_rate: 0.1 46 | positional_dropout_rate: 0.1 47 | self_attention_dropout_rate: 0. 48 | src_attention_dropout_rate: 0. 49 | 50 | 51 | # frontend related 52 | frontend: WavFrontend 53 | frontend_conf: 54 | fs: 16000 55 | window: hamming 56 | n_mels: 80 57 | frame_length: 25 58 | frame_shift: 10 59 | dither: 0.0 60 | lfr_m: 1 61 | lfr_n: 1 62 | 63 | specaug: SpecAug 64 | specaug_conf: 65 | apply_time_warp: true 66 | time_warp_window: 5 67 | time_warp_mode: bicubic 68 | apply_freq_mask: true 69 | freq_mask_width_range: 70 | - 0 71 | - 30 72 | num_freq_mask: 2 73 | apply_time_mask: true 74 | time_mask_width_range: 75 | - 0 76 | - 40 77 | num_time_mask: 2 78 | 79 | train_conf: 80 | accum_grad: 1 81 | grad_clip: 5 82 | max_epoch: 150 83 | keep_nbest_models: 10 84 | log_interval: 50 85 | 86 | optim: adam 87 | optim_conf: 88 | lr: 0.001 89 | weight_decay: 0.000001 90 | scheduler: warmuplr 91 | scheduler_conf: 92 | warmup_steps: 35000 93 | 94 | dataset: AudioDataset 95 | dataset_conf: 96 | index_ds: IndexDSJsonl 97 | batch_sampler: DynamicBatchLocalShuffleSampler 98 | batch_type: example # example or length 99 | batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; 100 | max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, 101 | buffer_size: 500 102 | shuffle: True 103 | num_workers: 4 104 | 105 | tokenizer: CharTokenizer 106 | tokenizer_conf: 107 | unk_symbol: 108 | split_with_space: true 109 | 110 | 111 | ctc_conf: 112 | dropout_rate: 0.0 113 | ctc_type: builtin 114 | reduce: true 115 | ignore_nan_grad: true 116 | normalize: null 117 | -------------------------------------------------------------------------------- /funasr_detach/models/campplus/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/campplus/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/campplus/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to configure a model file. 2 | # You can modify the configuration according to your own requirements. 
3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | # network architecture 9 | model: CAMPPlus 10 | model_conf: 11 | feat_dim: 80 12 | embedding_size: 192 13 | growth_rate: 32 14 | bn_size: 4 15 | init_channels: 128 16 | config_str: 'batchnorm-relu' 17 | memory_efficient: True 18 | output_level: 'segment' 19 | 20 | # frontend related 21 | frontend: WavFrontend 22 | frontend_conf: 23 | fs: 16000 24 | -------------------------------------------------------------------------------- /funasr_detach/models/conformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/conformer/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/conformer/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | from funasr_detach.models.transformer.model import Transformer 6 | from funasr_detach.register import tables 7 | 8 | 9 | @tables.register("model_classes", "Conformer") 10 | class Conformer(Transformer): 11 | """CTC-attention hybrid Encoder-Decoder model""" 12 | 13 | def __init__( 14 | self, 15 | *args, 16 | **kwargs, 17 | ): 18 | 19 | super().__init__(*args, **kwargs) 20 | -------------------------------------------------------------------------------- /funasr_detach/models/conformer/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to configure a model file. 2 | # You can modify the configuration according to your own requirements. 
3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | # network architecture 9 | model: Conformer 10 | model_conf: 11 | ctc_weight: 0.3 12 | lsm_weight: 0.1 # label smoothing option 13 | length_normalized_loss: false 14 | 15 | # encoder 16 | encoder: ConformerEncoder 17 | encoder_conf: 18 | output_size: 256 19 | attention_heads: 4 20 | linear_units: 2048 21 | num_blocks: 12 22 | dropout_rate: 0.1 23 | positional_dropout_rate: 0.1 24 | attention_dropout_rate: 0.0 25 | input_layer: conv2d 26 | normalize_before: true 27 | pos_enc_layer_type: rel_pos 28 | selfattention_layer_type: rel_selfattn 29 | activation_type: swish 30 | macaron_style: true 31 | use_cnn_module: true 32 | cnn_module_kernel: 15 33 | 34 | # decoder 35 | decoder: TransformerDecoder 36 | decoder_conf: 37 | attention_heads: 4 38 | linear_units: 2048 39 | num_blocks: 6 40 | dropout_rate: 0.1 41 | positional_dropout_rate: 0.1 42 | self_attention_dropout_rate: 0.0 43 | src_attention_dropout_rate: 0.0 44 | 45 | 46 | # frontend related 47 | frontend: WavFrontend 48 | frontend_conf: 49 | fs: 16000 50 | window: hamming 51 | n_mels: 80 52 | frame_length: 25 53 | frame_shift: 10 54 | dither: 0.0 55 | lfr_m: 1 56 | lfr_n: 1 57 | 58 | specaug: SpecAug 59 | specaug_conf: 60 | apply_time_warp: true 61 | time_warp_window: 5 62 | time_warp_mode: bicubic 63 | apply_freq_mask: true 64 | freq_mask_width_range: 65 | - 0 66 | - 30 67 | num_freq_mask: 2 68 | apply_time_mask: true 69 | time_mask_width_range: 70 | - 0 71 | - 40 72 | num_time_mask: 2 73 | 74 | train_conf: 75 | accum_grad: 1 76 | grad_clip: 5 77 | max_epoch: 150 78 | val_scheduler_criterion: 79 | - valid 80 | - acc 81 | best_model_criterion: 82 | - - valid 83 | - acc 84 | - max 85 | keep_nbest_models: 10 86 | log_interval: 50 87 | 88 | optim: adam 89 | optim_conf: 90 | lr: 0.0005 91 | scheduler: warmuplr 92 | scheduler_conf: 93 | warmup_steps: 30000 94 | 95 | dataset: AudioDataset 96 | dataset_conf: 97 | index_ds: IndexDSJsonl 98 | batch_sampler: DynamicBatchLocalShuffleSampler 99 | batch_type: example # example or length 100 | batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; 101 | max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, 102 | buffer_size: 500 103 | shuffle: True 104 | num_workers: 0 105 | 106 | tokenizer: CharTokenizer 107 | tokenizer_conf: 108 | unk_symbol: 109 | split_with_space: true 110 | 111 | 112 | ctc_conf: 113 | dropout_rate: 0.0 114 | ctc_type: builtin 115 | reduce: true 116 | ignore_nan_grad: true 117 | normalize: null 118 | -------------------------------------------------------------------------------- /funasr_detach/models/contextual_paraformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/contextual_paraformer/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/contextual_paraformer/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to configure a model file. 2 | # You can modify the configuration according to your own requirements. 
3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | # network architecture 9 | model: ContextualParaformer 10 | model_conf: 11 | ctc_weight: 0.0 12 | lsm_weight: 0.1 13 | length_normalized_loss: true 14 | predictor_weight: 1.0 15 | predictor_bias: 1 16 | sampling_ratio: 0.75 17 | inner_dim: 512 18 | 19 | # encoder 20 | encoder: SANMEncoder 21 | encoder_conf: 22 | output_size: 512 23 | attention_heads: 4 24 | linear_units: 2048 25 | num_blocks: 50 26 | dropout_rate: 0.1 27 | positional_dropout_rate: 0.1 28 | attention_dropout_rate: 0.1 29 | input_layer: pe 30 | pos_enc_class: SinusoidalPositionEncoder 31 | normalize_before: true 32 | kernel_size: 11 33 | sanm_shfit: 0 34 | selfattention_layer_type: sanm 35 | 36 | 37 | # decoder 38 | decoder: ContextualParaformerDecoder 39 | decoder_conf: 40 | attention_heads: 4 41 | linear_units: 2048 42 | num_blocks: 16 43 | dropout_rate: 0.1 44 | positional_dropout_rate: 0.1 45 | self_attention_dropout_rate: 0.1 46 | src_attention_dropout_rate: 0.1 47 | att_layer_num: 16 48 | kernel_size: 11 49 | sanm_shfit: 0 50 | 51 | predictor: CifPredictorV2 52 | predictor_conf: 53 | idim: 512 54 | threshold: 1.0 55 | l_order: 1 56 | r_order: 1 57 | tail_threshold: 0.45 58 | 59 | # frontend related 60 | frontend: WavFrontend 61 | frontend_conf: 62 | fs: 16000 63 | window: hamming 64 | n_mels: 80 65 | frame_length: 25 66 | frame_shift: 10 67 | lfr_m: 7 68 | lfr_n: 6 69 | 70 | specaug: SpecAugLFR 71 | specaug_conf: 72 | apply_time_warp: false 73 | time_warp_window: 5 74 | time_warp_mode: bicubic 75 | apply_freq_mask: true 76 | freq_mask_width_range: 77 | - 0 78 | - 30 79 | lfr_rate: 6 80 | num_freq_mask: 1 81 | apply_time_mask: true 82 | time_mask_width_range: 83 | - 0 84 | - 12 85 | num_time_mask: 1 86 | 87 | train_conf: 88 | accum_grad: 1 89 | grad_clip: 5 90 | max_epoch: 150 91 | val_scheduler_criterion: 92 | - valid 93 | - acc 94 | best_model_criterion: 95 | - - valid 96 | - acc 97 | - max 98 | keep_nbest_models: 10 99 | log_interval: 50 100 | 101 | optim: adam 102 | optim_conf: 103 | lr: 0.0005 104 | scheduler: warmuplr 105 | scheduler_conf: 106 | warmup_steps: 30000 107 | 108 | dataset: AudioDataset 109 | dataset_conf: 110 | index_ds: IndexDSJsonl 111 | batch_sampler: DynamicBatchLocalShuffleSampler 112 | batch_type: example # example or length 113 | batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; 114 | max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, 115 | buffer_size: 500 116 | shuffle: True 117 | num_workers: 0 118 | 119 | tokenizer: CharTokenizer 120 | tokenizer_conf: 121 | unk_symbol: 122 | split_with_space: true 123 | 124 | ctc_conf: 125 | dropout_rate: 0.0 126 | ctc_type: builtin 127 | reduce: true 128 | ignore_nan_grad: true 129 | normalize: null -------------------------------------------------------------------------------- /funasr_detach/models/ct_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/ct_transformer/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/ct_transformer/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to 
configure a model file. 2 | # You can modify the configuration according to your own requirements. 3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | model: CTTransformer 9 | model_conf: 10 | ignore_id: 0 11 | embed_unit: 256 12 | att_unit: 256 13 | dropout_rate: 0.1 14 | punc_list: 15 | - 16 | - _ 17 | - ',' 18 | - 。 19 | - '?' 20 | - 、 21 | punc_weight: 22 | - 1.0 23 | - 1.0 24 | - 1.0 25 | - 1.0 26 | - 1.0 27 | - 1.0 28 | sentence_end_id: 3 29 | 30 | encoder: SANMEncoder 31 | encoder_conf: 32 | input_size: 256 33 | output_size: 256 34 | attention_heads: 8 35 | linear_units: 1024 36 | num_blocks: 4 37 | dropout_rate: 0.1 38 | positional_dropout_rate: 0.1 39 | attention_dropout_rate: 0.0 40 | input_layer: pe 41 | pos_enc_class: SinusoidalPositionEncoder 42 | normalize_before: true 43 | kernel_size: 11 44 | sanm_shfit: 0 45 | selfattention_layer_type: sanm 46 | padding_idx: 0 47 | 48 | tokenizer: CharTokenizer 49 | tokenizer_conf: 50 | unk_symbol: 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /funasr_detach/models/ct_transformer_streaming/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/ct_transformer_streaming/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/ct_transformer_streaming/attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | import torch 7 | from funasr_detach.models.sanm.attention import MultiHeadedAttentionSANM 8 | 9 | 10 | class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM): 11 | def __init__(self, *args, **kwargs): 12 | super().__init__(*args, **kwargs) 13 | 14 | def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None): 15 | q_h, k_h, v_h, v = self.forward_qkv(x) 16 | fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk) 17 | q_h = q_h * self.d_k ** (-0.5) 18 | scores = torch.matmul(q_h, k_h.transpose(-2, -1)) 19 | att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder) 20 | return att_outs + fsmn_memory 21 | -------------------------------------------------------------------------------- /funasr_detach/models/ct_transformer_streaming/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to configure a model file. 2 | # You can modify the configuration according to your own requirements. 3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | model: CTTransformerStreaming 9 | model_conf: 10 | ignore_id: 0 11 | embed_unit: 256 12 | att_unit: 256 13 | dropout_rate: 0.1 14 | punc_list: 15 | - 16 | - _ 17 | - , 18 | - 。 19 | - ? 
20 | - 、 21 | punc_weight: 22 | - 1.0 23 | - 1.0 24 | - 1.0 25 | - 1.0 26 | - 1.0 27 | - 1.0 28 | sentence_end_id: 3 29 | 30 | encoder: SANMVadEncoder 31 | encoder_conf: 32 | input_size: 256 33 | output_size: 256 34 | attention_heads: 8 35 | linear_units: 1024 36 | num_blocks: 3 37 | dropout_rate: 0.1 38 | positional_dropout_rate: 0.1 39 | attention_dropout_rate: 0.0 40 | input_layer: pe 41 | pos_enc_class: SinusoidalPositionEncoder 42 | normalize_before: true 43 | kernel_size: 11 44 | sanm_shfit: 5 45 | selfattention_layer_type: sanm 46 | padding_idx: 0 47 | 48 | tokenizer: CharTokenizer 49 | tokenizer_conf: 50 | unk_symbol: -------------------------------------------------------------------------------- /funasr_detach/models/ctc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/ctc/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/data2vec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/data2vec/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/data2vec/grad_multiply.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | 8 | 9 | class GradMultiply(torch.autograd.Function): 10 | @staticmethod 11 | def forward(ctx, x, scale): 12 | ctx.scale = scale 13 | res = x.new(x) 14 | return res 15 | 16 | @staticmethod 17 | def backward(ctx, grad): 18 | return grad * ctx.scale, None 19 | -------------------------------------------------------------------------------- /funasr_detach/models/e_branchformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/e_branchformer/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/e_branchformer/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from funasr_detach.models.transformer.model import Transformer 4 | from funasr_detach.register import tables 5 | 6 | 7 | @tables.register("model_classes", "EBranchformer") 8 | class EBranchformer(Transformer): 9 | """CTC-attention hybrid Encoder-Decoder model""" 10 | 11 | def __init__( 12 | self, 13 | *args, 14 | **kwargs, 15 | ): 16 | 17 | super().__init__(*args, **kwargs) 18 | -------------------------------------------------------------------------------- /funasr_detach/models/e_branchformer/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to configure a model file. 2 | # You can modify the configuration according to your own requirements. 
3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | # network architecture 9 | model: Branchformer 10 | model_conf: 11 | ctc_weight: 0.3 12 | lsm_weight: 0.1 # label smoothing option 13 | length_normalized_loss: false 14 | 15 | # encoder 16 | encoder: EBranchformerEncoder 17 | encoder_conf: 18 | output_size: 256 19 | attention_heads: 4 20 | attention_layer_type: rel_selfattn 21 | pos_enc_layer_type: rel_pos 22 | rel_pos_type: latest 23 | cgmlp_linear_units: 1024 24 | cgmlp_conv_kernel: 31 25 | use_linear_after_conv: false 26 | gate_activation: identity 27 | num_blocks: 12 28 | dropout_rate: 0.1 29 | positional_dropout_rate: 0.1 30 | attention_dropout_rate: 0.1 31 | input_layer: conv2d 32 | layer_drop_rate: 0.0 33 | linear_units: 1024 34 | positionwise_layer_type: linear 35 | use_ffn: true 36 | macaron_ffn: true 37 | merge_conv_kernel: 31 38 | 39 | # decoder 40 | decoder: TransformerDecoder 41 | decoder_conf: 42 | attention_heads: 4 43 | linear_units: 2048 44 | num_blocks: 6 45 | dropout_rate: 0.1 46 | positional_dropout_rate: 0.1 47 | self_attention_dropout_rate: 0. 48 | src_attention_dropout_rate: 0. 49 | 50 | 51 | # frontend related 52 | frontend: WavFrontend 53 | frontend_conf: 54 | fs: 16000 55 | window: hamming 56 | n_mels: 80 57 | frame_length: 25 58 | frame_shift: 10 59 | dither: 0.0 60 | lfr_m: 1 61 | lfr_n: 1 62 | 63 | specaug: SpecAug 64 | specaug_conf: 65 | apply_time_warp: true 66 | time_warp_window: 5 67 | time_warp_mode: bicubic 68 | apply_freq_mask: true 69 | freq_mask_width_range: 70 | - 0 71 | - 30 72 | num_freq_mask: 2 73 | apply_time_mask: true 74 | time_mask_width_range: 75 | - 0 76 | - 40 77 | num_time_mask: 2 78 | 79 | train_conf: 80 | accum_grad: 1 81 | grad_clip: 5 82 | max_epoch: 180 83 | keep_nbest_models: 10 84 | log_interval: 50 85 | 86 | optim: adam 87 | optim_conf: 88 | lr: 0.001 89 | weight_decay: 0.000001 90 | scheduler: warmuplr 91 | scheduler_conf: 92 | warmup_steps: 35000 93 | 94 | dataset: AudioDataset 95 | dataset_conf: 96 | index_ds: IndexDSJsonl 97 | batch_sampler: DynamicBatchLocalShuffleSampler 98 | batch_type: example # example or length 99 | batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; 100 | max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, 101 | buffer_size: 500 102 | shuffle: True 103 | num_workers: 4 104 | 105 | tokenizer: CharTokenizer 106 | tokenizer_conf: 107 | unk_symbol: 108 | split_with_space: true 109 | 110 | 111 | ctc_conf: 112 | dropout_rate: 0.0 113 | ctc_type: builtin 114 | reduce: true 115 | ignore_nan_grad: true 116 | normalize: null 117 | -------------------------------------------------------------------------------- /funasr_detach/models/eend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/eend/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/eend/eend_ola_dataloader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import kaldiio 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import Dataset 8 | 9 | 10 | def custom_collate(batch): 11 | keys, speech, speaker_labels, orders = zip(*batch) 
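    # Each field arrives as a tuple of per-utterance numpy arrays; convert them
    # to float32 (speech, speaker labels) and int64 (orders) tensors, but keep
    # them as Python lists rather than stacking, since utterances in a batch
    # can differ in length.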
12 | speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech] 13 | speaker_labels = [ 14 | torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels 15 | ] 16 | orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders] 17 | batch = dict(speech=speech, speaker_labels=speaker_labels, orders=orders) 18 | 19 | return keys, batch 20 | 21 | 22 | class EENDOLADataset(Dataset): 23 | def __init__( 24 | self, 25 | data_file, 26 | ): 27 | self.data_file = data_file 28 | with open(data_file) as f: 29 | lines = f.readlines() 30 | self.samples = [line.strip().split() for line in lines] 31 | logging.info("total samples: {}".format(len(self.samples))) 32 | 33 | def __len__(self): 34 | return len(self.samples) 35 | 36 | def __getitem__(self, idx): 37 | key, speech_path, speaker_label_path = self.samples[idx] 38 | speech = kaldiio.load_mat(speech_path) 39 | speaker_label = kaldiio.load_mat(speaker_label_path).reshape( 40 | speech.shape[0], -1 41 | ) 42 | 43 | order = np.arange(speech.shape[0]) 44 | np.random.shuffle(order) 45 | 46 | return key, speech, speaker_label, order 47 | 48 | 49 | class EENDOLADataLoader: 50 | def __init__(self, data_file, batch_size, shuffle=True, num_workers=8): 51 | dataset = EENDOLADataset(data_file) 52 | self.data_loader = DataLoader( 53 | dataset, 54 | batch_size=batch_size, 55 | collate_fn=custom_collate, 56 | shuffle=shuffle, 57 | num_workers=num_workers, 58 | ) 59 | 60 | def build_iter(self, epoch): 61 | return self.data_loader 62 | -------------------------------------------------------------------------------- /funasr_detach/models/eend/encoder_decoder_attractor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | 7 | class EncoderDecoderAttractor(nn.Module): 8 | 9 | def __init__(self, n_units, encoder_dropout=0.1, decoder_dropout=0.1): 10 | super(EncoderDecoderAttractor, self).__init__() 11 | self.enc0_dropout = nn.Dropout(encoder_dropout) 12 | self.encoder = nn.LSTM( 13 | n_units, n_units, 1, batch_first=True, dropout=encoder_dropout 14 | ) 15 | self.dec0_dropout = nn.Dropout(decoder_dropout) 16 | self.decoder = nn.LSTM( 17 | n_units, n_units, 1, batch_first=True, dropout=decoder_dropout 18 | ) 19 | self.counter = nn.Linear(n_units, 1) 20 | self.n_units = n_units 21 | 22 | def forward_core(self, xs, zeros): 23 | ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).to(torch.int64) 24 | xs = [self.enc0_dropout(x) for x in xs] 25 | xs = nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=-1) 26 | xs = nn.utils.rnn.pack_padded_sequence( 27 | xs, ilens, batch_first=True, enforce_sorted=False 28 | ) 29 | _, (hx, cx) = self.encoder(xs) 30 | zlens = torch.from_numpy(np.array([z.shape[0] for z in zeros])).to(torch.int64) 31 | max_zlen = torch.max(zlens).to(torch.int).item() 32 | zeros = [self.enc0_dropout(z) for z in zeros] 33 | zeros = nn.utils.rnn.pad_sequence(zeros, batch_first=True, padding_value=-1) 34 | zeros = nn.utils.rnn.pack_padded_sequence( 35 | zeros, zlens, batch_first=True, enforce_sorted=False 36 | ) 37 | attractors, (_, _) = self.decoder(zeros, (hx, cx)) 38 | attractors = nn.utils.rnn.pad_packed_sequence( 39 | attractors, batch_first=True, padding_value=-1, total_length=max_zlen 40 | )[0] 41 | attractors = [ 42 | att[: zlens[i].to(torch.int).item()] for i, att in enumerate(attractors) 43 | ] 44 | return attractors 45 | 46 | def forward(self, xs, 
n_speakers): 47 | zeros = [ 48 | torch.zeros(n_spk + 1, self.n_units).to(torch.float32).to(xs[0].device) 49 | for n_spk in n_speakers 50 | ] 51 | attractors = self.forward_core(xs, zeros) 52 | labels = torch.cat( 53 | [ 54 | torch.from_numpy(np.array([[1] * n_spk + [0]], np.float32)) 55 | for n_spk in n_speakers 56 | ], 57 | dim=1, 58 | ) 59 | labels = labels.to(xs[0].device) 60 | logit = torch.cat( 61 | [ 62 | self.counter(att).view(-1, n_spk + 1) 63 | for att, n_spk in zip(attractors, n_speakers) 64 | ], 65 | dim=1, 66 | ) 67 | loss = F.binary_cross_entropy(torch.sigmoid(logit), labels) 68 | 69 | attractors = [att[slice(0, att.shape[0] - 1)] for att in attractors] 70 | return loss, attractors 71 | 72 | def estimate(self, xs, max_n_speakers=15): 73 | zeros = [ 74 | torch.zeros(max_n_speakers, self.n_units).to(torch.float32).to(xs[0].device) 75 | for _ in xs 76 | ] 77 | attractors = self.forward_core(xs, zeros) 78 | probs = [torch.sigmoid(torch.flatten(self.counter(att))) for att in attractors] 79 | return attractors, probs 80 | -------------------------------------------------------------------------------- /funasr_detach/models/eend/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/eend/utils/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/eend/utils/losses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from scipy.optimize import linear_sum_assignment 5 | 6 | 7 | def standard_loss(ys, ts): 8 | losses = [ 9 | F.binary_cross_entropy(torch.sigmoid(y), t) * len(y) for y, t in zip(ys, ts) 10 | ] 11 | loss = torch.sum(torch.stack(losses)) 12 | n_frames = ( 13 | torch.from_numpy(np.array(np.sum([t.shape[0] for t in ts]))) 14 | .to(torch.float32) 15 | .to(ys[0].device) 16 | ) 17 | loss = loss / n_frames 18 | return loss 19 | 20 | 21 | def fast_batch_pit_n_speaker_loss(ys, ts): 22 | with torch.no_grad(): 23 | bs = len(ys) 24 | indices = [] 25 | for b in range(bs): 26 | y = ys[b].transpose(0, 1) 27 | t = ts[b].transpose(0, 1) 28 | C, _ = t.shape 29 | y = y[:, None, :].repeat(1, C, 1) 30 | t = t[None, :, :].repeat(C, 1, 1) 31 | bce_loss = F.binary_cross_entropy( 32 | torch.sigmoid(y), t, reduction="none" 33 | ).mean(-1) 34 | C = bce_loss.cpu() 35 | indices.append(linear_sum_assignment(C)) 36 | labels_perm = [t[:, idx[1]] for t, idx in zip(ts, indices)] 37 | 38 | return labels_perm 39 | 40 | 41 | def cal_power_loss(logits, power_ts): 42 | losses = [ 43 | F.cross_entropy(input=logit, target=power_t.to(torch.long)) * len(logit) 44 | for logit, power_t in zip(logits, power_ts) 45 | ] 46 | loss = torch.sum(torch.stack(losses)) 47 | n_frames = ( 48 | torch.from_numpy(np.array(np.sum([power_t.shape[0] for power_t in power_ts]))) 49 | .to(torch.float32) 50 | .to(power_ts[0].device) 51 | ) 52 | loss = loss / n_frames 53 | return loss 54 | -------------------------------------------------------------------------------- /funasr_detach/models/emotion2vec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/emotion2vec/__init__.py 
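A minimal usage sketch for the permutation-invariant-training helpers defined in funasr_detach/models/eend/utils/losses.py above; shapes and values are made up for illustration (2 speakers, variable-length utterances), and it assumes torch, numpy and scipy are installed:

import torch

from funasr_detach.models.eend.utils.losses import (
    fast_batch_pit_n_speaker_loss,
    standard_loss,
)

# Per-utterance frame-level speaker logits (T, C) and 0/1 activity labels (T, C).
ys = [torch.randn(50, 2), torch.randn(70, 2)]
ts = [torch.randint(0, 2, (50, 2)).float(), torch.randint(0, 2, (70, 2)).float()]

labels_perm = fast_batch_pit_n_speaker_loss(ys, ts)  # labels reordered to the best speaker permutation
loss = standard_loss(ys, labels_perm)                # frame-averaged BCE under that permutation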
-------------------------------------------------------------------------------- /funasr_detach/models/eres2net/__init__.py: -------------------------------------------------------------------------------- 1 | from .eres2net import ERes2Net 2 | from .eres2net_aug import ERes2NetAug 3 | -------------------------------------------------------------------------------- /funasr_detach/models/eres2net/fusion.py: -------------------------------------------------------------------------------- 1 | # Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class AFF(nn.Module): 9 | 10 | def __init__(self, channels=64, r=4): 11 | super(AFF, self).__init__() 12 | inter_channels = int(channels // r) 13 | 14 | self.local_att = nn.Sequential( 15 | nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0), 16 | nn.BatchNorm2d(inter_channels), 17 | nn.SiLU(inplace=True), 18 | nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), 19 | nn.BatchNorm2d(channels), 20 | ) 21 | 22 | def forward(self, x, ds_y): 23 | xa = torch.cat((x, ds_y), dim=1) 24 | x_att = self.local_att(xa) 25 | x_att = 1.0 + torch.tanh(x_att) 26 | xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att) 27 | 28 | return xo 29 | -------------------------------------------------------------------------------- /funasr_detach/models/fsmn_vad_streaming/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/fsmn_vad_streaming/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/fsmn_vad_streaming/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to configure a model file. 2 | # You can modify the configuration according to your own requirements. 
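# Duration-style fields below (max_*_silence_time, *_time_thres, lookback and
# lookahead times, max_single_segment_time) appear to be in milliseconds,
# consistent with the explicitly suffixed fields (window_size_ms, frame_in_ms,
# frame_length_ms).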
3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | # network architecture 9 | model: FsmnVADStreaming 10 | model_conf: 11 | sample_rate: 16000 12 | detect_mode: 1 13 | snr_mode: 0 14 | max_end_silence_time: 800 15 | max_start_silence_time: 3000 16 | do_start_point_detection: True 17 | do_end_point_detection: True 18 | window_size_ms: 200 19 | sil_to_speech_time_thres: 150 20 | speech_to_sil_time_thres: 150 21 | speech_2_noise_ratio: 1.0 22 | do_extend: 1 23 | lookback_time_start_point: 200 24 | lookahead_time_end_point: 100 25 | max_single_segment_time: 60000 26 | snr_thres: -100.0 27 | noise_frame_num_used_for_snr: 100 28 | decibel_thres: -100.0 29 | speech_noise_thres: 0.6 30 | fe_prior_thres: 0.0001 31 | silence_pdf_num: 1 32 | sil_pdf_ids: [0] 33 | speech_noise_thresh_low: -0.1 34 | speech_noise_thresh_high: 0.3 35 | output_frame_probs: False 36 | frame_in_ms: 10 37 | frame_length_ms: 25 38 | 39 | encoder: FSMN 40 | encoder_conf: 41 | input_dim: 400 42 | input_affine_dim: 140 43 | fsmn_layers: 4 44 | linear_dim: 250 45 | proj_dim: 128 46 | lorder: 20 47 | rorder: 0 48 | lstride: 1 49 | rstride: 0 50 | output_affine_dim: 140 51 | output_dim: 248 52 | 53 | frontend: WavFrontend 54 | frontend_conf: 55 | fs: 16000 56 | window: hamming 57 | n_mels: 80 58 | frame_length: 25 59 | frame_shift: 10 60 | dither: 0.0 61 | lfr_m: 5 62 | lfr_n: 1 63 | -------------------------------------------------------------------------------- /funasr_detach/models/language_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/language_model/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/language_model/rnn/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize sub package.""" 2 | -------------------------------------------------------------------------------- /funasr_detach/models/lora/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/lora/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/lora/utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. 
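# Helpers below: parameters whose names contain "lora_" (or "cif") stay
# trainable while everything else is frozen, and lora_state_dict() keeps only
# the "lora_" entries (optionally plus biases, per the `bias` mode) so that
# checkpoints hold just the adapter weights.
# Typical use (illustrative sketch; `model` is any nn.Module wrapping LoRA layers):
#   mark_only_lora_as_trainable(model, bias="lora_only")
#   ...  # fine-tune
#   torch.save(lora_state_dict(model, bias="lora_only"), "lora_adapter.pt")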
4 | # ------------------------------------------------------------------------------------------ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from typing import Dict 9 | 10 | from .layers import LoRALayer 11 | 12 | 13 | def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: 14 | for n, p in model.named_parameters(): 15 | if "lora_" not in n and "cif" not in n: 16 | p.requires_grad = False 17 | if bias == "none": 18 | return 19 | elif bias == "all": 20 | for n, p in model.named_parameters(): 21 | if "bias" in n: 22 | p.requires_grad = True 23 | elif bias == "lora_only": 24 | for m in model.modules(): 25 | if isinstance(m, LoRALayer) and hasattr(m, "bias") and m.bias is not None: 26 | m.bias.requires_grad = True 27 | else: 28 | raise NotImplementedError 29 | 30 | 31 | def lora_state_dict(model: nn.Module, bias: str = "none") -> Dict[str, torch.Tensor]: 32 | my_state_dict = model.state_dict() 33 | if bias == "none": 34 | return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k} 35 | elif bias == "all": 36 | return { 37 | k: my_state_dict[k] for k in my_state_dict if "lora_" in k or "bias" in k 38 | } 39 | elif bias == "lora_only": 40 | to_return = {} 41 | for k in my_state_dict: 42 | if "lora_" in k: 43 | to_return[k] = my_state_dict[k] 44 | bias_name = k.split("lora_")[0] + "bias" 45 | if bias_name in my_state_dict: 46 | to_return[bias_name] = my_state_dict[bias_name] 47 | return to_return 48 | else: 49 | raise NotImplementedError 50 | -------------------------------------------------------------------------------- /funasr_detach/models/mfcca/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/mfcca/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/model_hf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/model_hf/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/monotonic_aligner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/monotonic_aligner/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/monotonic_aligner/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to configure a model file. 2 | # You can modify the configuration according to your own requirements. 
3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | # network architecture 9 | model: MonotonicAligner 10 | model_conf: 11 | length_normalized_loss: False 12 | predictor_bias: 1 13 | 14 | # encoder 15 | encoder: SANMEncoder 16 | encoder_conf: 17 | output_size: 320 18 | attention_heads: 4 19 | linear_units: 1280 20 | num_blocks: 30 21 | dropout_rate: 0.1 22 | positional_dropout_rate: 0.1 23 | attention_dropout_rate: 0.1 24 | input_layer: pe 25 | pos_enc_class: SinusoidalPositionEncoder 26 | normalize_before: true 27 | kernel_size: 11 28 | sanm_shfit: 0 29 | selfattention_layer_type: sanm 30 | 31 | predictor: CifPredictorV3 32 | predictor_conf: 33 | idim: 320 34 | threshold: 1.0 35 | l_order: 1 36 | r_order: 1 37 | tail_threshold: 0.45 38 | smooth_factor2: 0.25 39 | noise_threshold2: 0.01 40 | upsample_times: 3 41 | use_cif1_cnn: false 42 | upsample_type: cnn_blstm 43 | 44 | # frontend related 45 | frontend: WavFrontend 46 | frontend_conf: 47 | fs: 16000 48 | window: hamming 49 | n_mels: 80 50 | frame_length: 25 51 | frame_shift: 10 52 | lfr_m: 7 53 | lfr_n: 6 54 | 55 | specaug: SpecAugLFR 56 | specaug_conf: 57 | apply_time_warp: false 58 | time_warp_window: 5 59 | time_warp_mode: bicubic 60 | apply_freq_mask: true 61 | freq_mask_width_range: 62 | - 0 63 | - 30 64 | lfr_rate: 6 65 | num_freq_mask: 1 66 | apply_time_mask: true 67 | time_mask_width_range: 68 | - 0 69 | - 12 70 | num_time_mask: 1 71 | 72 | train_conf: 73 | accum_grad: 1 74 | grad_clip: 5 75 | max_epoch: 150 76 | val_scheduler_criterion: 77 | - valid 78 | - acc 79 | best_model_criterion: 80 | - - valid 81 | - acc 82 | - max 83 | keep_nbest_models: 10 84 | log_interval: 50 85 | 86 | optim: adam 87 | optim_conf: 88 | lr: 0.0005 89 | scheduler: warmuplr 90 | scheduler_conf: 91 | warmup_steps: 30000 92 | 93 | dataset: AudioDataset 94 | dataset_conf: 95 | index_ds: IndexDSJsonl 96 | batch_sampler: DynamicBatchLocalShuffleSampler 97 | batch_type: example # example or length 98 | batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; 99 | max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, 100 | buffer_size: 500 101 | shuffle: True 102 | num_workers: 0 103 | 104 | tokenizer: CharTokenizer 105 | tokenizer_conf: 106 | unk_symbol: 107 | split_with_space: true 108 | 109 | ctc_conf: 110 | dropout_rate: 0.0 111 | ctc_type: builtin 112 | reduce: true 113 | ignore_nan_grad: true 114 | 115 | normalize: null 116 | -------------------------------------------------------------------------------- /funasr_detach/models/mossformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/mossformer/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/mossformer/e2e_ss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import copy 6 | from funasr_detach.models.base_model import FunASRModel 7 | from funasr_detach.models.encoder.mossformer_encoder import ( 8 | MossFormerEncoder, 9 | MossFormer_MaskNet, 10 | ) 11 | from funasr_detach.models.decoder.mossformer_decoder import MossFormerDecoder 12 | 13 | 14 | class 
MossFormer(FunASRModel): 15 | """The MossFormer model for separating input mixed speech into different speaker's speech. 16 | 17 | Arguments 18 | --------- 19 | in_channels : int 20 | Number of channels at the output of the encoder. 21 | out_channels : int 22 | Number of channels that would be inputted to the intra and inter blocks. 23 | num_blocks : int 24 | Number of layers of Dual Computation Block. 25 | norm : str 26 | Normalization type. 27 | num_spks : int 28 | Number of sources (speakers). 29 | skip_around_intra : bool 30 | Skip connection around intra. 31 | use_global_pos_enc : bool 32 | Global positional encodings. 33 | max_length : int 34 | Maximum sequence length. 35 | kernel_size: int 36 | Encoder and decoder kernel size 37 | """ 38 | 39 | def __init__( 40 | self, 41 | in_channels=512, 42 | out_channels=512, 43 | num_blocks=24, 44 | kernel_size=16, 45 | norm="ln", 46 | num_spks=2, 47 | skip_around_intra=True, 48 | use_global_pos_enc=True, 49 | max_length=20000, 50 | ): 51 | super(MossFormer, self).__init__() 52 | self.num_spks = num_spks 53 | # Encoding 54 | self.enc = MossFormerEncoder( 55 | kernel_size=kernel_size, out_channels=in_channels, in_channels=1 56 | ) 57 | 58 | ##Compute Mask 59 | self.mask_net = MossFormer_MaskNet( 60 | in_channels=in_channels, 61 | out_channels=out_channels, 62 | num_blocks=num_blocks, 63 | norm=norm, 64 | num_spks=num_spks, 65 | skip_around_intra=skip_around_intra, 66 | use_global_pos_enc=use_global_pos_enc, 67 | max_length=max_length, 68 | ) 69 | self.dec = MossFormerDecoder( 70 | in_channels=out_channels, 71 | out_channels=1, 72 | kernel_size=kernel_size, 73 | stride=kernel_size // 2, 74 | bias=False, 75 | ) 76 | 77 | def forward(self, input): 78 | x = self.enc(input) 79 | mask = self.mask_net(x) 80 | x = torch.stack([x] * self.num_spks) 81 | sep_x = x * mask 82 | 83 | # Decoding 84 | est_source = torch.cat( 85 | [self.dec(sep_x[i]).unsqueeze(-1) for i in range(self.num_spks)], 86 | dim=-1, 87 | ) 88 | T_origin = input.size(1) 89 | T_est = est_source.size(1) 90 | if T_origin > T_est: 91 | est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est)) 92 | else: 93 | est_source = est_source[:, :T_origin, :] 94 | 95 | out = [] 96 | for spk in range(self.num_spks): 97 | out.append(est_source[:, :, spk]) 98 | return out 99 | -------------------------------------------------------------------------------- /funasr_detach/models/mossformer/mossformer_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MossFormerDecoder(nn.ConvTranspose1d): 6 | """A decoder layer that consists of ConvTranspose1d. 7 | 8 | Arguments 9 | --------- 10 | kernel_size : int 11 | Length of filters. 12 | in_channels : int 13 | Number of input channels. 14 | out_channels : int 15 | Number of output channels. 16 | 17 | 18 | Example 19 | --------- 20 | >>> x = torch.randn(2, 100, 1000) 21 | >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1) 22 | >>> h = decoder(x) 23 | >>> h.shape 24 | torch.Size([2, 1003]) 25 | """ 26 | 27 | def __init__(self, *args, **kwargs): 28 | super(MossFormerDecoder, self).__init__(*args, **kwargs) 29 | 30 | def forward(self, x): 31 | """Return the decoded output. 32 | 33 | Arguments 34 | --------- 35 | x : torch.Tensor 36 | Input tensor with dimensionality [B, N, L]. 
37 | where, B = Batchsize, 38 | N = number of filters 39 | L = time points 40 | """ 41 | 42 | if x.dim() not in [2, 3]: 43 | raise RuntimeError("{} accept 3/4D tensor as input".format(self.__name__)) 44 | x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) 45 | 46 | if torch.squeeze(x).dim() == 1: 47 | x = torch.squeeze(x, dim=1) 48 | else: 49 | x = torch.squeeze(x) 50 | return x 51 | -------------------------------------------------------------------------------- /funasr_detach/models/normalize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/normalize/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/normalize/utterance_mvn.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask 6 | from funasr_detach.register import tables 7 | 8 | 9 | @tables.register("normalize_classes", "UtteranceMVN") 10 | class UtteranceMVN(torch.nn.Module): 11 | def __init__( 12 | self, 13 | norm_means: bool = True, 14 | norm_vars: bool = False, 15 | eps: float = 1.0e-20, 16 | ): 17 | super().__init__() 18 | self.norm_means = norm_means 19 | self.norm_vars = norm_vars 20 | self.eps = eps 21 | 22 | def extra_repr(self): 23 | return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}" 24 | 25 | def forward( 26 | self, x: torch.Tensor, ilens: torch.Tensor = None 27 | ) -> Tuple[torch.Tensor, torch.Tensor]: 28 | """Forward function 29 | 30 | Args: 31 | x: (B, L, ...) 32 | ilens: (B,) 33 | 34 | """ 35 | return utterance_mvn( 36 | x, 37 | ilens, 38 | norm_means=self.norm_means, 39 | norm_vars=self.norm_vars, 40 | eps=self.eps, 41 | ) 42 | 43 | 44 | def utterance_mvn( 45 | x: torch.Tensor, 46 | ilens: torch.Tensor = None, 47 | norm_means: bool = True, 48 | norm_vars: bool = False, 49 | eps: float = 1.0e-20, 50 | ) -> Tuple[torch.Tensor, torch.Tensor]: 51 | """Apply utterance mean and variance normalization 52 | 53 | Args: 54 | x: (B, T, D), assumed zero padded 55 | ilens: (B,) 56 | norm_means: 57 | norm_vars: 58 | eps: 59 | 60 | """ 61 | if ilens is None: 62 | ilens = x.new_full([x.size(0)], x.size(1)) 63 | ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)]) 64 | # Zero padding 65 | if x.requires_grad: 66 | x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0) 67 | else: 68 | x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0) 69 | # mean: (B, 1, D) 70 | mean = x.sum(dim=1, keepdim=True) / ilens_ 71 | 72 | if norm_means: 73 | x -= mean 74 | 75 | if norm_vars: 76 | var = x.pow(2).sum(dim=1, keepdim=True) / ilens_ 77 | std = torch.clamp(var.sqrt(), min=eps) 78 | x = x / std.sqrt() 79 | return x, ilens 80 | else: 81 | if norm_vars: 82 | y = x - mean 83 | y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0) 84 | var = y.pow(2).sum(dim=1, keepdim=True) / ilens_ 85 | std = torch.clamp(var.sqrt(), min=eps) 86 | x /= std 87 | return x, ilens 88 | -------------------------------------------------------------------------------- /funasr_detach/models/paraformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/paraformer/__init__.py 
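A minimal usage sketch for the UtteranceMVN module above; the class and import path come from this repository, while the batch shape and sequence lengths are illustrative assumptions:

import torch
from funasr_detach.models.normalize.utterance_mvn import UtteranceMVN

# Two zero-padded 80-dim feature sequences: feats is (B, T, D), lengths is (B,)  (illustrative shapes)
feats = torch.randn(2, 100, 80)
lengths = torch.tensor([100, 73])

mvn = UtteranceMVN(norm_means=True, norm_vars=False)
normed, lengths = mvn(feats, lengths)  # removes the per-utterance mean computed over the valid (unpadded) frames
print(normed.shape)  # torch.Size([2, 100, 80])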
-------------------------------------------------------------------------------- /funasr_detach/models/paraformer/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to configure a model file. 2 | # You can modify the configuration according to your own requirements. 3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | # network architecture 9 | model: Paraformer 10 | model_conf: 11 | ctc_weight: 0.0 12 | lsm_weight: 0.1 13 | length_normalized_loss: true 14 | predictor_weight: 1.0 15 | predictor_bias: 1 16 | sampling_ratio: 0.75 17 | 18 | # encoder 19 | encoder: SANMEncoder 20 | encoder_conf: 21 | output_size: 512 22 | attention_heads: 4 23 | linear_units: 2048 24 | num_blocks: 50 25 | dropout_rate: 0.1 26 | positional_dropout_rate: 0.1 27 | attention_dropout_rate: 0.1 28 | input_layer: pe 29 | pos_enc_class: SinusoidalPositionEncoder 30 | normalize_before: true 31 | kernel_size: 11 32 | sanm_shfit: 0 33 | selfattention_layer_type: sanm 34 | 35 | # decoder 36 | decoder: ParaformerSANMDecoder 37 | decoder_conf: 38 | attention_heads: 4 39 | linear_units: 2048 40 | num_blocks: 16 41 | dropout_rate: 0.1 42 | positional_dropout_rate: 0.1 43 | self_attention_dropout_rate: 0.1 44 | src_attention_dropout_rate: 0.1 45 | att_layer_num: 16 46 | kernel_size: 11 47 | sanm_shfit: 0 48 | 49 | predictor: CifPredictorV2 50 | predictor_conf: 51 | idim: 512 52 | threshold: 1.0 53 | l_order: 1 54 | r_order: 1 55 | tail_threshold: 0.45 56 | 57 | # frontend related 58 | frontend: WavFrontend 59 | frontend_conf: 60 | fs: 16000 61 | window: hamming 62 | n_mels: 80 63 | frame_length: 25 64 | frame_shift: 10 65 | lfr_m: 7 66 | lfr_n: 6 67 | 68 | specaug: SpecAugLFR 69 | specaug_conf: 70 | apply_time_warp: false 71 | time_warp_window: 5 72 | time_warp_mode: bicubic 73 | apply_freq_mask: true 74 | freq_mask_width_range: 75 | - 0 76 | - 30 77 | lfr_rate: 6 78 | num_freq_mask: 1 79 | apply_time_mask: true 80 | time_mask_width_range: 81 | - 0 82 | - 12 83 | num_time_mask: 1 84 | 85 | train_conf: 86 | accum_grad: 1 87 | grad_clip: 5 88 | max_epoch: 150 89 | keep_nbest_models: 10 90 | avg_nbest_model: 5 91 | log_interval: 50 92 | 93 | optim: adam 94 | optim_conf: 95 | lr: 0.0005 96 | scheduler: warmuplr 97 | scheduler_conf: 98 | warmup_steps: 30000 99 | 100 | dataset: AudioDataset 101 | dataset_conf: 102 | index_ds: IndexDSJsonl 103 | batch_sampler: DynamicBatchLocalShuffleSampler 104 | batch_type: example # example or length 105 | batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; 106 | max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, 107 | buffer_size: 500 108 | shuffle: True 109 | num_workers: 0 110 | 111 | tokenizer: CharTokenizer 112 | tokenizer_conf: 113 | unk_symbol: 114 | split_with_space: true 115 | 116 | 117 | ctc_conf: 118 | dropout_rate: 0.0 119 | ctc_type: builtin 120 | reduce: true 121 | ignore_nan_grad: true 122 | normalize: null 123 | -------------------------------------------------------------------------------- /funasr_detach/models/paraformer_streaming/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/paraformer_streaming/__init__.py 
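Each template.yaml above is a plain YAML recipe, so it can be inspected directly before being handed to the training or inference entry points. A small sketch, assuming the repository root as the working directory; the field names are taken from the Paraformer template shown above:

import yaml

with open("funasr_detach/models/paraformer/template.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"])                                       # Paraformer
print(cfg["encoder"], cfg["encoder_conf"]["num_blocks"])  # SANMEncoder 50
print(cfg["frontend_conf"]["n_mels"])                     # 80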
-------------------------------------------------------------------------------- /funasr_detach/models/rwkv_bat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/rwkv_bat/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/rwkv_bat/cuda_decoder/wkv_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Based on https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_op.cpp 3 | Function signatures were modified based on https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/rwkv/wkv_op.cpp 4 | 5 | */ 6 | 7 | #include <torch/extension.h> 8 | 9 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); 10 | 11 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); 12 | 13 | void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { 14 | const int B = k.size(0); 15 | const int T = k.size(1); 16 | const int C = k.size(2); 17 | 18 | cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>()); 19 | } 20 | 21 | void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { 22 | const int B = k.size(0); 23 | const int T = k.size(1); 24 | const int C = k.size(2); 25 | 26 | cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>()); 27 | } 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("forward", &forward, "wkv forward"); 31 | m.def("backward", &backward, "wkv backward"); 32 | } 33 | 34 | TORCH_LIBRARY(wkv_decoder, m) { 35 | m.def("forward", forward); 36 | m.def("backward", backward); 37 | } 38 | -------------------------------------------------------------------------------- /funasr_detach/models/rwkv_bat/cuda_encoder/wkv_op.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Based on https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_op.cpp 3 | Function signatures were modified based on https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/rwkv/wkv_op.cpp 4 | 5 | */ 6 | 7 | #include <torch/extension.h> 8 | 9 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); 10 | 11 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); 12 | 13 | void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { 14 | const int B = k.size(0); 15 | const int T = k.size(1); 16 | const int C = k.size(2); 17 | 18 | cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>()); 19 | } 20 | 21 | void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { 22 | const int B = k.size(0); 23 | const int T = k.size(1); 24 | const int C = k.size(2); 25 | 26 | cuda_backward(B, T, C, w.data_ptr<float>(),
u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>()); 27 | } 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("forward", &forward, "wkv forward"); 31 | m.def("backward", &backward, "wkv backward"); 32 | } 33 | 34 | TORCH_LIBRARY(wkv_encoder, m) { 35 | m.def("forward", forward); 36 | m.def("backward", backward); 37 | } 38 | -------------------------------------------------------------------------------- /funasr_detach/models/rwkv_bat/rwkv_feed_forward.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | import torch 7 | from typing import List, Optional, Tuple 8 | 9 | 10 | class FeedForward(torch.nn.Module): 11 | """FeedForward module definition. 12 | 13 | Args: 14 | size: Input/Output size. 15 | hidden_size: Hidden size. 16 | block_id: Block index. 17 | num_blocks: Number of blocks in the architecture. 18 | 19 | """ 20 | 21 | def __init__( 22 | self, 23 | size: int, 24 | hidden_size: int, 25 | block_id: int, 26 | dropout_rate: float, 27 | num_blocks: int, 28 | ) -> None: 29 | """Construct a FeedForward object.""" 30 | super().__init__() 31 | 32 | self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1)) 33 | 34 | self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size)) 35 | self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size)) 36 | 37 | self.proj_key = torch.nn.Linear(size, hidden_size, bias=True) 38 | self.proj_value = torch.nn.Linear(hidden_size, size, bias=True) 39 | self.proj_receptance = torch.nn.Linear(size, size, bias=True) 40 | 41 | self.block_id = block_id 42 | 43 | self.reset_parameters(size, block_id, num_blocks) 44 | self.dropout = torch.nn.Dropout(p=dropout_rate) 45 | 46 | def reset_parameters(self, size: int, block_id: int, num_blocks: int) -> None: 47 | """Reset module parameters. 48 | 49 | Args: 50 | size: Block size. 51 | block_id: Block index. 52 | num_blocks: Number of blocks in the architecture. 53 | 54 | """ 55 | ratio_1_to_almost0 = 1.0 - (block_id / num_blocks) 56 | 57 | time_weight = torch.ones(1, 1, size) 58 | 59 | for i in range(size): 60 | time_weight[0, 0, i] = i / size 61 | 62 | with torch.no_grad(): 63 | self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) 64 | self.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0) 65 | 66 | def forward( 67 | self, x: torch.Tensor, state: Optional[List[torch.Tensor]] = None 68 | ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: 69 | """Compute channel mixing. 70 | 71 | Args: 72 | x: FeedForward input sequences. (B, U, size) 73 | state: Decoder hidden state. [5 x (B, 1, size, N)] 74 | 75 | Returns: 76 | x: FeedForward output sequences. (B, U, size) 77 | state: Decoder hidden state. 
[5 x (B, 1, size, N)] 78 | 79 | """ 80 | shifted_x = ( 81 | self.time_shift(x) if state is None else state[0][..., self.block_id] 82 | ) 83 | 84 | key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key) 85 | receptance = x * self.time_mix_receptance + shifted_x * ( 86 | 1 - self.time_mix_receptance 87 | ) 88 | 89 | key = torch.square(torch.relu(self.proj_key(key))) 90 | value = self.proj_value(self.dropout(key)) 91 | receptance = torch.sigmoid(self.proj_receptance(receptance)) 92 | 93 | if state is not None: 94 | state[0][..., self.block_id] = x 95 | 96 | x = receptance * value 97 | 98 | return x, state 99 | -------------------------------------------------------------------------------- /funasr_detach/models/sa_asr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/sa_asr/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/sa_asr/attention.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | """Multi-Head Attention layer definition.""" 6 | 7 | import math 8 | 9 | import numpy 10 | import torch 11 | from torch import nn 12 | from typing import Optional, Tuple 13 | 14 | import torch.nn.functional as F 15 | from funasr_detach.models.transformer.utils.nets_utils import make_pad_mask 16 | import funasr_detach.models.lora.layers as lora 17 | 18 | 19 | class CosineDistanceAttention(nn.Module): 20 | """Compute Cosine Distance between spk decoder output and speaker profile 21 | Args: 22 | profile_path: speaker profile file path (.npy file) 23 | """ 24 | 25 | def __init__(self): 26 | super().__init__() 27 | self.softmax = nn.Softmax(dim=-1) 28 | 29 | def forward(self, spk_decoder_out, profile, profile_lens=None): 30 | """ 31 | Args: 32 | spk_decoder_out(torch.Tensor):(B, L, D) 33 | spk_profiles(torch.Tensor):(B, N, D) 34 | """ 35 | x = spk_decoder_out.unsqueeze(2) # (B, L, 1, D) 36 | if profile_lens is not None: 37 | 38 | mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device) 39 | min_value = float( 40 | numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min 41 | ) 42 | weights_not_softmax = F.cosine_similarity( 43 | x, profile.unsqueeze(1), dim=-1 44 | ).masked_fill(mask, min_value) 45 | weights = self.softmax(weights_not_softmax).masked_fill( 46 | mask, 0.0 47 | ) # (B, L, N) 48 | else: 49 | x = x[:, -1:, :, :] 50 | weights_not_softmax = F.cosine_similarity( 51 | x, profile.unsqueeze(1).to(x.device), dim=-1 52 | ) 53 | weights = self.softmax(weights_not_softmax) # (B, 1, N) 54 | spk_embedding = torch.matmul(weights, profile.to(weights.device)) # (B, L, D) 55 | 56 | return spk_embedding, weights 57 | -------------------------------------------------------------------------------- /funasr_detach/models/sanm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/sanm/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/sanm/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR 
(https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | import logging 7 | 8 | import torch 9 | 10 | from funasr_detach.models.transformer.model import Transformer 11 | from funasr_detach.register import tables 12 | 13 | 14 | @tables.register("model_classes", "SANM") 15 | class SANM(Transformer): 16 | """ 17 | Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin 18 | San-m: Memory equipped self-attention for end-to-end speech recognition 19 | https://arxiv.org/abs/2006.01713 20 | """ 21 | 22 | def __init__( 23 | self, 24 | *args, 25 | **kwargs, 26 | ): 27 | 28 | super().__init__(*args, **kwargs) 29 | -------------------------------------------------------------------------------- /funasr_detach/models/sanm/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | """Positionwise feed forward layer definition.""" 6 | 7 | import torch 8 | 9 | from funasr_detach.models.transformer.layer_norm import LayerNorm 10 | 11 | 12 | class PositionwiseFeedForwardDecoderSANM(torch.nn.Module): 13 | """Positionwise feed forward layer. 14 | 15 | Args: 16 | idim (int): Input dimenstion. 17 | hidden_units (int): The number of hidden units. 18 | dropout_rate (float): Dropout rate. 19 | 20 | """ 21 | 22 | def __init__( 23 | self, idim, hidden_units, dropout_rate, adim=None, activation=torch.nn.ReLU() 24 | ): 25 | """Construct an PositionwiseFeedForward object.""" 26 | super(PositionwiseFeedForwardDecoderSANM, self).__init__() 27 | self.w_1 = torch.nn.Linear(idim, hidden_units) 28 | self.w_2 = torch.nn.Linear( 29 | hidden_units, idim if adim is None else adim, bias=False 30 | ) 31 | self.dropout = torch.nn.Dropout(dropout_rate) 32 | self.activation = activation 33 | self.norm = LayerNorm(hidden_units) 34 | 35 | def forward(self, x): 36 | """Forward function.""" 37 | return self.w_2(self.norm(self.dropout(self.activation(self.w_1(x))))) 38 | -------------------------------------------------------------------------------- /funasr_detach/models/sanm/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to configure a model file. 2 | # You can modify the configuration according to your own requirements. 
3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | # network architecture 9 | model: SANM 10 | model_conf: 11 | ctc_weight: 0.0 12 | lsm_weight: 0.1 13 | length_normalized_loss: true 14 | 15 | # encoder 16 | encoder: SANMEncoder 17 | encoder_conf: 18 | output_size: 512 19 | attention_heads: 4 20 | linear_units: 2048 21 | num_blocks: 50 22 | dropout_rate: 0.1 23 | positional_dropout_rate: 0.1 24 | attention_dropout_rate: 0.1 25 | input_layer: pe 26 | pos_enc_class: SinusoidalPositionEncoder 27 | normalize_before: true 28 | kernel_size: 11 29 | sanm_shfit: 0 30 | selfattention_layer_type: sanm 31 | 32 | # decoder 33 | decoder: FsmnDecoder 34 | decoder_conf: 35 | attention_heads: 4 36 | linear_units: 2048 37 | num_blocks: 16 38 | dropout_rate: 0.1 39 | positional_dropout_rate: 0.1 40 | self_attention_dropout_rate: 0.1 41 | src_attention_dropout_rate: 0.1 42 | att_layer_num: 16 43 | kernel_size: 11 44 | sanm_shfit: 0 45 | 46 | 47 | 48 | # frontend related 49 | frontend: WavFrontend 50 | frontend_conf: 51 | fs: 16000 52 | window: hamming 53 | n_mels: 80 54 | frame_length: 25 55 | frame_shift: 10 56 | lfr_m: 7 57 | lfr_n: 6 58 | 59 | specaug: SpecAugLFR 60 | specaug_conf: 61 | apply_time_warp: false 62 | time_warp_window: 5 63 | time_warp_mode: bicubic 64 | apply_freq_mask: true 65 | freq_mask_width_range: 66 | - 0 67 | - 30 68 | lfr_rate: 6 69 | num_freq_mask: 1 70 | apply_time_mask: true 71 | time_mask_width_range: 72 | - 0 73 | - 12 74 | num_time_mask: 1 75 | 76 | train_conf: 77 | accum_grad: 1 78 | grad_clip: 5 79 | max_epoch: 150 80 | val_scheduler_criterion: 81 | - valid 82 | - acc 83 | best_model_criterion: 84 | - - valid 85 | - acc 86 | - max 87 | keep_nbest_models: 10 88 | avg_nbest_model: 5 89 | log_interval: 50 90 | 91 | optim: adam 92 | optim_conf: 93 | lr: 0.0005 94 | scheduler: warmuplr 95 | scheduler_conf: 96 | warmup_steps: 30000 97 | 98 | dataset: AudioDataset 99 | dataset_conf: 100 | index_ds: IndexDSJsonl 101 | batch_sampler: DynamicBatchLocalShuffleSampler 102 | batch_type: example # example or length 103 | batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; 104 | max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, 105 | buffer_size: 500 106 | shuffle: True 107 | num_workers: 0 108 | 109 | tokenizer: CharTokenizer 110 | tokenizer_conf: 111 | unk_symbol: 112 | split_with_space: true 113 | 114 | 115 | ctc_conf: 116 | dropout_rate: 0.0 117 | ctc_type: builtin 118 | reduce: true 119 | ignore_nan_grad: true 120 | 121 | normalize: null 122 | -------------------------------------------------------------------------------- /funasr_detach/models/scama/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/scama/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/scama/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to configure a model file. 2 | # You can modify the configuration according to your own requirements. 
3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | # network architecture 9 | model: SCAMA 10 | model_conf: 11 | ctc_weight: 0.0 12 | lsm_weight: 0.1 13 | length_normalized_loss: true 14 | 15 | # encoder 16 | encoder: SANMEncoderChunkOpt 17 | encoder_conf: 18 | output_size: 512 19 | attention_heads: 4 20 | linear_units: 2048 21 | num_blocks: 50 22 | dropout_rate: 0.1 23 | positional_dropout_rate: 0.1 24 | attention_dropout_rate: 0.1 25 | input_layer: pe 26 | pos_enc_class: SinusoidalPositionEncoder 27 | normalize_before: true 28 | kernel_size: 11 29 | sanm_shfit: 0 30 | selfattention_layer_type: sanm 31 | 32 | # decoder 33 | decoder: FsmnDecoderSCAMAOpt 34 | decoder_conf: 35 | attention_heads: 4 36 | linear_units: 2048 37 | num_blocks: 16 38 | dropout_rate: 0.1 39 | positional_dropout_rate: 0.1 40 | self_attention_dropout_rate: 0.1 41 | src_attention_dropout_rate: 0.1 42 | att_layer_num: 16 43 | kernel_size: 11 44 | sanm_shfit: 0 45 | 46 | predictor: CifPredictorV2 47 | predictor_conf: 48 | idim: 512 49 | threshold: 1.0 50 | l_order: 1 51 | r_order: 1 52 | tail_threshold: 0.45 53 | 54 | # frontend related 55 | frontend: WavFrontend 56 | frontend_conf: 57 | fs: 16000 58 | window: hamming 59 | n_mels: 80 60 | frame_length: 25 61 | frame_shift: 10 62 | lfr_m: 7 63 | lfr_n: 6 64 | 65 | specaug: SpecAugLFR 66 | specaug_conf: 67 | apply_time_warp: false 68 | time_warp_window: 5 69 | time_warp_mode: bicubic 70 | apply_freq_mask: true 71 | freq_mask_width_range: 72 | - 0 73 | - 30 74 | lfr_rate: 6 75 | num_freq_mask: 1 76 | apply_time_mask: true 77 | time_mask_width_range: 78 | - 0 79 | - 12 80 | num_time_mask: 1 81 | 82 | train_conf: 83 | accum_grad: 1 84 | grad_clip: 5 85 | max_epoch: 150 86 | val_scheduler_criterion: 87 | - valid 88 | - acc 89 | best_model_criterion: 90 | - - valid 91 | - acc 92 | - max 93 | keep_nbest_models: 10 94 | avg_nbest_model: 5 95 | log_interval: 50 96 | 97 | optim: adam 98 | optim_conf: 99 | lr: 0.0005 100 | scheduler: warmuplr 101 | scheduler_conf: 102 | warmup_steps: 30000 103 | 104 | dataset: AudioDataset 105 | dataset_conf: 106 | index_ds: IndexDSJsonl 107 | batch_sampler: DynamicBatchLocalShuffleSampler 108 | batch_type: example # example or length 109 | batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; 110 | max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, 111 | buffer_size: 500 112 | shuffle: True 113 | num_workers: 0 114 | 115 | tokenizer: CharTokenizer 116 | tokenizer_conf: 117 | unk_symbol: 118 | split_with_space: true 119 | 120 | 121 | ctc_conf: 122 | dropout_rate: 0.0 123 | ctc_type: builtin 124 | reduce: true 125 | ignore_nan_grad: true 126 | 127 | normalize: null 128 | -------------------------------------------------------------------------------- /funasr_detach/models/scama/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | import numpy as np 5 | from torch.nn import functional as F 6 | 7 | 8 | def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None): 9 | if maxlen is None: 10 | maxlen = lengths.max() 11 | row_vector = torch.arange(0, maxlen, 1).to(lengths.device) 12 | matrix = torch.unsqueeze(lengths, dim=-1) 13 | mask = row_vector < matrix 14 | mask = mask.detach() 15 | 16 | return mask.type(dtype).to(device) if device is not None else mask.type(dtype) 
17 | 18 | 19 | def apply_cmvn(inputs, mvn): 20 | device = inputs.device 21 | dtype = inputs.dtype 22 | frame, dim = inputs.shape 23 | meams = np.tile(mvn[0:1, :dim], (frame, 1)) 24 | vars = np.tile(mvn[1:2, :dim], (frame, 1)) 25 | inputs -= torch.from_numpy(meams).type(dtype).to(device) 26 | inputs *= torch.from_numpy(vars).type(dtype).to(device) 27 | 28 | return inputs.type(torch.float32) 29 | 30 | 31 | def drop_and_add( 32 | inputs: torch.Tensor, 33 | outputs: torch.Tensor, 34 | training: bool, 35 | dropout_rate: float = 0.1, 36 | stoch_layer_coeff: float = 1.0, 37 | ): 38 | 39 | outputs = F.dropout(outputs, p=dropout_rate, training=training, inplace=True) 40 | outputs *= stoch_layer_coeff 41 | 42 | input_dim = inputs.size(-1) 43 | output_dim = outputs.size(-1) 44 | 45 | if input_dim == output_dim: 46 | outputs += inputs 47 | return outputs 48 | 49 | 50 | def proc_tf_vocab(vocab_path): 51 | with open(vocab_path, encoding="utf-8") as f: 52 | token_list = [line.rstrip() for line in f] 53 | if "" not in token_list: 54 | token_list.append("") 55 | return token_list 56 | 57 | 58 | def gen_config_for_tfmodel(config_path, vocab_path, output_dir): 59 | token_list = proc_tf_vocab(vocab_path) 60 | with open(config_path, encoding="utf-8") as f: 61 | config = yaml.safe_load(f) 62 | 63 | config["token_list"] = token_list 64 | 65 | if not os.path.exists(output_dir): 66 | os.makedirs(output_dir) 67 | 68 | with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f: 69 | yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False) 70 | 71 | 72 | class NoAliasSafeDumper(yaml.SafeDumper): 73 | # Disable anchor/alias in yaml because looks ugly 74 | def ignore_aliases(self, data): 75 | return True 76 | 77 | 78 | def yaml_no_alias_safe_dump(data, stream=None, **kwargs): 79 | """Safe-dump in yaml with no anchor/alias""" 80 | return yaml.dump( 81 | data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs 82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | import sys 87 | 88 | config_path = sys.argv[1] 89 | vocab_path = sys.argv[2] 90 | output_dir = sys.argv[3] 91 | gen_config_for_tfmodel(config_path, vocab_path, output_dir) 92 | -------------------------------------------------------------------------------- /funasr_detach/models/seaco_paraformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/seaco_paraformer/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/sond/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/sond/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/sond/encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/sond/encoder/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/sond/encoder/ci_scorers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | 5 | class DotScorer(torch.nn.Module): 6 | def 
__init__(self): 7 | super().__init__() 8 | 9 | def forward( 10 | self, 11 | xs_pad: torch.Tensor, 12 | spk_emb: torch.Tensor, 13 | ): 14 | # xs_pad: B, T, D 15 | # spk_emb: B, N, D 16 | scores = torch.matmul(xs_pad, spk_emb.transpose(1, 2)) 17 | return scores 18 | 19 | def convert_tf2torch(self, var_dict_tf, var_dict_torch): 20 | return {} 21 | 22 | 23 | class CosScorer(torch.nn.Module): 24 | def __init__(self): 25 | super().__init__() 26 | 27 | def forward( 28 | self, 29 | xs_pad: torch.Tensor, 30 | spk_emb: torch.Tensor, 31 | ): 32 | # xs_pad: B, T, D 33 | # spk_emb: B, N, D 34 | scores = F.cosine_similarity(xs_pad.unsqueeze(2), spk_emb.unsqueeze(1), dim=-1) 35 | return scores 36 | 37 | def convert_tf2torch(self, var_dict_tf, var_dict_torch): 38 | return {} 39 | -------------------------------------------------------------------------------- /funasr_detach/models/sond/pooling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/sond/pooling/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/sond/sv_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from funasr_detach.models.decoder.abs_decoder import AbsDecoder 4 | 5 | 6 | class DenseDecoder(AbsDecoder): 7 | def __init__( 8 | self, 9 | vocab_size, 10 | encoder_output_size, 11 | num_nodes_resnet1: int = 256, 12 | num_nodes_last_layer: int = 256, 13 | batchnorm_momentum: float = 0.5, 14 | ): 15 | super(DenseDecoder, self).__init__() 16 | self.resnet1_dense = torch.nn.Linear(encoder_output_size, num_nodes_resnet1) 17 | self.resnet1_bn = torch.nn.BatchNorm1d( 18 | num_nodes_resnet1, eps=1e-3, momentum=batchnorm_momentum 19 | ) 20 | 21 | self.resnet2_dense = torch.nn.Linear(num_nodes_resnet1, num_nodes_last_layer) 22 | self.resnet2_bn = torch.nn.BatchNorm1d( 23 | num_nodes_last_layer, eps=1e-3, momentum=batchnorm_momentum 24 | ) 25 | 26 | self.output_dense = torch.nn.Linear( 27 | num_nodes_last_layer, vocab_size, bias=False 28 | ) 29 | 30 | def forward(self, features): 31 | embeddings = {} 32 | features = self.resnet1_dense(features) 33 | embeddings["resnet1_dense"] = features 34 | features = F.relu(features) 35 | features = self.resnet1_bn(features) 36 | 37 | features = self.resnet2_dense(features) 38 | embeddings["resnet2_dense"] = features 39 | features = F.relu(features) 40 | features = self.resnet2_bn(features) 41 | 42 | features = self.output_dense(features) 43 | return features, embeddings 44 | -------------------------------------------------------------------------------- /funasr_detach/models/specaug/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/specaug/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/specaug/time_warp.py: -------------------------------------------------------------------------------- 1 | """Time warp module.""" 2 | 3 | import torch 4 | 5 | from funasr_detach.models.transformer.utils.nets_utils import pad_list 6 | 7 | DEFAULT_TIME_WARP_MODE = "bicubic" 8 | 9 | 10 | def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE): 11 | """Time 
warping using torch.interpolate. 12 | 13 | Args: 14 | x: (Batch, Time, Freq) 15 | window: time warp parameter 16 | mode: Interpolate mode 17 | """ 18 | 19 | # bicubic supports 4D or more dimension tensor 20 | org_size = x.size() 21 | if x.dim() == 3: 22 | # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq) 23 | x = x[:, None] 24 | 25 | t = x.shape[2] 26 | if t - window <= window: 27 | return x.view(*org_size) 28 | 29 | center = torch.randint(window, t - window, (1,))[0] 30 | warped = torch.randint(center - window, center + window, (1,))[0] + 1 31 | 32 | # left: (Batch, Channel, warped, Freq) 33 | # right: (Batch, Channel, time - warped, Freq) 34 | left = torch.nn.functional.interpolate( 35 | x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False 36 | ) 37 | right = torch.nn.functional.interpolate( 38 | x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=False 39 | ) 40 | 41 | if x.requires_grad: 42 | x = torch.cat([left, right], dim=-2) 43 | else: 44 | x[:, :, :warped] = left 45 | x[:, :, warped:] = right 46 | 47 | return x.view(*org_size) 48 | 49 | 50 | class TimeWarp(torch.nn.Module): 51 | """Time warping using torch.interpolate. 52 | 53 | Args: 54 | window: time warp parameter 55 | mode: Interpolate mode 56 | """ 57 | 58 | def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE): 59 | super().__init__() 60 | self.window = window 61 | self.mode = mode 62 | 63 | def extra_repr(self): 64 | return f"window={self.window}, mode={self.mode}" 65 | 66 | def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None): 67 | """Forward function. 68 | 69 | Args: 70 | x: (Batch, Time, Freq) 71 | x_lengths: (Batch,) 72 | """ 73 | 74 | if x_lengths is None or all(le == x_lengths[0] for le in x_lengths): 75 | # Note that applying same warping for each sample 76 | y = time_warp(x, window=self.window, mode=self.mode) 77 | else: 78 | # FIXME(kamo): I have no idea to batchify Timewarp 79 | ys = [] 80 | for i in range(x.size(0)): 81 | _y = time_warp( 82 | x[i][None, : x_lengths[i]], 83 | window=self.window, 84 | mode=self.mode, 85 | )[0] 86 | ys.append(_y) 87 | y = pad_list(ys, 0.0) 88 | 89 | return y, x_lengths 90 | -------------------------------------------------------------------------------- /funasr_detach/models/transducer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/transducer/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/transducer/joint_network.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | import torch 7 | 8 | from funasr_detach.register import tables 9 | from funasr_detach.models.transformer.utils.nets_utils import get_activation 10 | 11 | 12 | @tables.register("joint_network_classes", "joint_network") 13 | class JointNetwork(torch.nn.Module): 14 | """Transducer joint network module. 15 | 16 | Args: 17 | output_size: Output size. 18 | encoder_size: Encoder output size. 19 | decoder_size: Decoder output size.. 20 | joint_space_size: Joint space size. 21 | joint_act_type: Type of activation for joint network. 
22 | **activation_parameters: Parameters for the activation function. 23 | 24 | """ 25 | 26 | def __init__( 27 | self, 28 | output_size: int, 29 | encoder_size: int, 30 | decoder_size: int, 31 | joint_space_size: int = 256, 32 | joint_activation_type: str = "tanh", 33 | ) -> None: 34 | """Construct a JointNetwork object.""" 35 | super().__init__() 36 | 37 | self.lin_enc = torch.nn.Linear(encoder_size, joint_space_size) 38 | self.lin_dec = torch.nn.Linear(decoder_size, joint_space_size, bias=False) 39 | 40 | self.lin_out = torch.nn.Linear(joint_space_size, output_size) 41 | 42 | self.joint_activation = get_activation(joint_activation_type) 43 | 44 | def forward( 45 | self, 46 | enc_out: torch.Tensor, 47 | dec_out: torch.Tensor, 48 | project_input: bool = True, 49 | ) -> torch.Tensor: 50 | """Joint computation of encoder and decoder hidden state sequences. 51 | 52 | Args: 53 | enc_out: Expanded encoder output state sequences (B, T, 1, D_enc) 54 | dec_out: Expanded decoder output state sequences (B, 1, U, D_dec) 55 | 56 | Returns: 57 | joint_out: Joint output state sequences. (B, T, U, D_out) 58 | 59 | """ 60 | if project_input: 61 | joint_out = self.joint_activation( 62 | self.lin_enc(enc_out) + self.lin_dec(dec_out) 63 | ) 64 | else: 65 | joint_out = self.joint_activation(enc_out + dec_out) 66 | return self.lin_out(joint_out) 67 | -------------------------------------------------------------------------------- /funasr_detach/models/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/transformer/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Positionwise feed forward layer definition.""" 8 | 9 | import torch 10 | 11 | from funasr_detach.models.transformer.layer_norm import LayerNorm 12 | 13 | 14 | class PositionwiseFeedForward(torch.nn.Module): 15 | """Positionwise feed forward layer. 16 | 17 | Args: 18 | idim (int): Input dimenstion. 19 | hidden_units (int): The number of hidden units. 20 | dropout_rate (float): Dropout rate. 
21 | 22 | """ 23 | 24 | def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()): 25 | """Construct an PositionwiseFeedForward object.""" 26 | super(PositionwiseFeedForward, self).__init__() 27 | self.w_1 = torch.nn.Linear(idim, hidden_units) 28 | self.w_2 = torch.nn.Linear(hidden_units, idim) 29 | self.dropout = torch.nn.Dropout(dropout_rate) 30 | self.activation = activation 31 | 32 | def forward(self, x): 33 | """Forward function.""" 34 | return self.w_2(self.dropout(self.activation(self.w_1(x)))) 35 | -------------------------------------------------------------------------------- /funasr_detach/models/transformer/scorers/__init__.py: -------------------------------------------------------------------------------- 1 | """Initialize sub package.""" 2 | -------------------------------------------------------------------------------- /funasr_detach/models/transformer/scorers/length_bonus.py: -------------------------------------------------------------------------------- 1 | """Length bonus module.""" 2 | 3 | from typing import Any 4 | from typing import List 5 | from typing import Tuple 6 | 7 | import torch 8 | 9 | from funasr_detach.models.transformer.scorers.scorer_interface import ( 10 | BatchScorerInterface, 11 | ) 12 | 13 | 14 | class LengthBonus(BatchScorerInterface): 15 | """Length bonus in beam search.""" 16 | 17 | def __init__(self, n_vocab: int): 18 | """Initialize class. 19 | 20 | Args: 21 | n_vocab (int): The number of tokens in vocabulary for beam search 22 | 23 | """ 24 | self.n = n_vocab 25 | 26 | def score(self, y, state, x): 27 | """Score new token. 28 | 29 | Args: 30 | y (torch.Tensor): 1D torch.int64 prefix tokens. 31 | state: Scorer state for prefix tokens 32 | x (torch.Tensor): 2D encoder feature that generates ys. 33 | 34 | Returns: 35 | tuple[torch.Tensor, Any]: Tuple of 36 | torch.float32 scores for next token (n_vocab) 37 | and None 38 | 39 | """ 40 | return torch.tensor([1.0], device=x.device, dtype=x.dtype).expand(self.n), None 41 | 42 | def batch_score( 43 | self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor 44 | ) -> Tuple[torch.Tensor, List[Any]]: 45 | """Score new token batch. 46 | 47 | Args: 48 | ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). 49 | states (List[Any]): Scorer states for prefix tokens. 50 | xs (torch.Tensor): 51 | The encoder feature that generates ys (n_batch, xlen, n_feat). 52 | 53 | Returns: 54 | tuple[torch.Tensor, List[Any]]: Tuple of 55 | batchfied scores for next token with shape of `(n_batch, n_vocab)` 56 | and next state list for ys. 57 | 58 | """ 59 | return ( 60 | torch.tensor([1.0], device=xs.device, dtype=xs.dtype).expand( 61 | ys.shape[0], self.n 62 | ), 63 | None, 64 | ) 65 | -------------------------------------------------------------------------------- /funasr_detach/models/transformer/template.yaml: -------------------------------------------------------------------------------- 1 | # This is an example that demonstrates how to configure a model file. 2 | # You can modify the configuration according to your own requirements. 
3 | 4 | # to print the register_table: 5 | # from funasr.register import tables 6 | # tables.print() 7 | 8 | # network architecture 9 | model: Transformer 10 | model_conf: 11 | ctc_weight: 0.3 12 | lsm_weight: 0.1 # label smoothing option 13 | length_normalized_loss: false 14 | 15 | # encoder 16 | encoder: TransformerEncoder 17 | encoder_conf: 18 | output_size: 256 # dimension of attention 19 | attention_heads: 4 20 | linear_units: 2048 # the number of units of position-wise feed forward 21 | num_blocks: 12 # the number of encoder blocks 22 | dropout_rate: 0.1 23 | positional_dropout_rate: 0.1 24 | attention_dropout_rate: 0.0 25 | input_layer: conv2d # encoder architecture type 26 | normalize_before: true 27 | 28 | # decoder 29 | decoder: TransformerDecoder 30 | decoder_conf: 31 | attention_heads: 4 32 | linear_units: 2048 33 | num_blocks: 6 34 | dropout_rate: 0.1 35 | positional_dropout_rate: 0.1 36 | self_attention_dropout_rate: 0.0 37 | src_attention_dropout_rate: 0.0 38 | 39 | 40 | # frontend related 41 | frontend: WavFrontend 42 | frontend_conf: 43 | fs: 16000 44 | window: hamming 45 | n_mels: 80 46 | frame_length: 25 47 | frame_shift: 10 48 | lfr_m: 1 49 | lfr_n: 1 50 | 51 | specaug: SpecAug 52 | specaug_conf: 53 | apply_time_warp: true 54 | time_warp_window: 5 55 | time_warp_mode: bicubic 56 | apply_freq_mask: true 57 | freq_mask_width_range: 58 | - 0 59 | - 30 60 | num_freq_mask: 2 61 | apply_time_mask: true 62 | time_mask_width_range: 63 | - 0 64 | - 40 65 | num_time_mask: 2 66 | 67 | train_conf: 68 | accum_grad: 1 69 | grad_clip: 5 70 | max_epoch: 150 71 | val_scheduler_criterion: 72 | - valid 73 | - acc 74 | best_model_criterion: 75 | - - valid 76 | - acc 77 | - max 78 | keep_nbest_models: 10 79 | log_interval: 50 80 | 81 | optim: adam 82 | optim_conf: 83 | lr: 0.002 84 | scheduler: warmuplr 85 | scheduler_conf: 86 | warmup_steps: 30000 87 | 88 | dataset: AudioDataset 89 | dataset_conf: 90 | index_ds: IndexDSJsonl 91 | batch_sampler: DynamicBatchLocalShuffleSampler 92 | batch_type: example # example or length 93 | batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len; 94 | max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length, 95 | buffer_size: 500 96 | shuffle: True 97 | num_workers: 0 98 | 99 | tokenizer: CharTokenizer 100 | tokenizer_conf: 101 | unk_symbol: 102 | split_with_space: true 103 | 104 | 105 | ctc_conf: 106 | dropout_rate: 0.0 107 | ctc_type: builtin 108 | reduce: true 109 | ignore_nan_grad: true 110 | normalize: null 111 | -------------------------------------------------------------------------------- /funasr_detach/models/transformer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/transformer/utils/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/transformer/utils/add_sos_eos.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | # Copyright 2019 Shigeki Karita 5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 6 | 7 | """Unility functions for Transformer.""" 8 | 9 | import torch 10 | from funasr_detach.models.transformer.utils.nets_utils import pad_list 11 | 12 | 13 | def add_sos_eos(ys_pad, sos, eos, 
ignore_id): 14 | """Add and labels. 15 | 16 | :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) 17 | :param int sos: index of 18 | :param int eos: index of 19 | :param int ignore_id: index of padding 20 | :return: padded tensor (B, Lmax) 21 | :rtype: torch.Tensor 22 | :return: padded tensor (B, Lmax) 23 | :rtype: torch.Tensor 24 | """ 25 | 26 | _sos = ys_pad.new([sos]) 27 | _eos = ys_pad.new([eos]) 28 | ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys 29 | ys_in = [torch.cat([_sos, y], dim=0) for y in ys] 30 | ys_out = [torch.cat([y, _eos], dim=0) for y in ys] 31 | return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) 32 | -------------------------------------------------------------------------------- /funasr_detach/models/transformer/utils/mask.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Shigeki Karita 2 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 3 | 4 | """Mask module.""" 5 | 6 | import torch 7 | 8 | 9 | def subsequent_mask(size, device="cpu", dtype=torch.bool): 10 | """Create mask for subsequent steps (size, size). 11 | 12 | :param int size: size of mask 13 | :param str device: "cpu" or "cuda" or torch.Tensor.device 14 | :param torch.dtype dtype: result dtype 15 | :rtype: torch.Tensor 16 | >>> subsequent_mask(3) 17 | [[1, 0, 0], 18 | [1, 1, 0], 19 | [1, 1, 1]] 20 | """ 21 | ret = torch.ones(size, size, device=device, dtype=dtype) 22 | return torch.tril(ret, out=ret) 23 | 24 | 25 | def target_mask(ys_in_pad, ignore_id): 26 | """Create mask for decoder self-attention. 27 | 28 | :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) 29 | :param int ignore_id: index of padding 30 | :param torch.dtype dtype: result dtype 31 | :rtype: torch.Tensor (B, Lmax, Lmax) 32 | """ 33 | ys_mask = ys_in_pad != ignore_id 34 | m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0) 35 | return ys_mask.unsqueeze(-2) & m 36 | 37 | 38 | def vad_mask(size, vad_pos, device="cpu", dtype=torch.bool): 39 | """Create mask for decoder self-attention. 40 | 41 | :param int size: size of mask 42 | :param int vad_pos: index of vad index 43 | :param str device: "cpu" or "cuda" or torch.Tensor.device 44 | :param torch.dtype dtype: result dtype 45 | :rtype: torch.Tensor (B, Lmax, Lmax) 46 | """ 47 | ret = torch.ones(size, size, device=device, dtype=dtype) 48 | if vad_pos <= 0 or vad_pos >= size: 49 | return ret 50 | sub_corner = torch.zeros(vad_pos - 1, size - vad_pos, device=device, dtype=dtype) 51 | ret[0 : vad_pos - 1, vad_pos:] = sub_corner 52 | return ret 53 | -------------------------------------------------------------------------------- /funasr_detach/models/transformer/utils/subsampling_without_posenc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Emiru Tsunoo 2 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) 3 | 4 | """Subsampling layer definition.""" 5 | 6 | import math 7 | import torch 8 | 9 | 10 | class Conv2dSubsamplingWOPosEnc(torch.nn.Module): 11 | """Convolutional 2D subsampling. 12 | 13 | Args: 14 | idim (int): Input dimension. 15 | odim (int): Output dimension. 16 | dropout_rate (float): Dropout rate. 
17 | kernels (list): kernel sizes 18 | strides (list): stride sizes 19 | 20 | """ 21 | 22 | def __init__(self, idim, odim, dropout_rate, kernels, strides): 23 | """Construct an Conv2dSubsamplingWOPosEnc object.""" 24 | assert len(kernels) == len(strides) 25 | super().__init__() 26 | conv = [] 27 | olen = idim 28 | for i, (k, s) in enumerate(zip(kernels, strides)): 29 | conv += [ 30 | torch.nn.Conv2d(1 if i == 0 else odim, odim, k, s), 31 | torch.nn.ReLU(), 32 | ] 33 | olen = math.floor((olen - k) / s + 1) 34 | self.conv = torch.nn.Sequential(*conv) 35 | self.out = torch.nn.Linear(odim * olen, odim) 36 | self.strides = strides 37 | self.kernels = kernels 38 | 39 | def forward(self, x, x_mask): 40 | """Subsample x. 41 | 42 | Args: 43 | x (torch.Tensor): Input tensor (#batch, time, idim). 44 | x_mask (torch.Tensor): Input mask (#batch, 1, time). 45 | 46 | Returns: 47 | torch.Tensor: Subsampled tensor (#batch, time', odim), 48 | where time' = time // 4. 49 | torch.Tensor: Subsampled mask (#batch, 1, time'), 50 | where time' = time // 4. 51 | 52 | """ 53 | x = x.unsqueeze(1) # (b, c, t, f) 54 | x = self.conv(x) 55 | b, c, t, f = x.size() 56 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) 57 | if x_mask is None: 58 | return x, None 59 | for k, s in zip(self.kernels, self.strides): 60 | x_mask = x_mask[:, :, : -k + 1 : s] 61 | return x, x_mask 62 | -------------------------------------------------------------------------------- /funasr_detach/models/transformer/utils/vgg2l.py: -------------------------------------------------------------------------------- 1 | """VGG2L module definition for custom encoder.""" 2 | 3 | from typing import Tuple, Union 4 | 5 | import torch 6 | 7 | 8 | class VGG2L(torch.nn.Module): 9 | """VGG2L module for custom encoder. 10 | 11 | Args: 12 | idim: Input dimension. 13 | odim: Output dimension. 14 | pos_enc: Positional encoding class. 15 | 16 | """ 17 | 18 | def __init__(self, idim: int, odim: int, pos_enc: torch.nn.Module = None): 19 | """Construct a VGG2L object.""" 20 | super().__init__() 21 | 22 | self.vgg2l = torch.nn.Sequential( 23 | torch.nn.Conv2d(1, 64, 3, stride=1, padding=1), 24 | torch.nn.ReLU(), 25 | torch.nn.Conv2d(64, 64, 3, stride=1, padding=1), 26 | torch.nn.ReLU(), 27 | torch.nn.MaxPool2d((3, 2)), 28 | torch.nn.Conv2d(64, 128, 3, stride=1, padding=1), 29 | torch.nn.ReLU(), 30 | torch.nn.Conv2d(128, 128, 3, stride=1, padding=1), 31 | torch.nn.ReLU(), 32 | torch.nn.MaxPool2d((2, 2)), 33 | ) 34 | 35 | if pos_enc is not None: 36 | self.output = torch.nn.Sequential( 37 | torch.nn.Linear(128 * ((idim // 2) // 2), odim), pos_enc 38 | ) 39 | else: 40 | self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim) 41 | 42 | def forward(self, feats: torch.Tensor, feats_mask: torch.Tensor) -> Union[ 43 | Tuple[torch.Tensor, torch.Tensor], 44 | Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], 45 | ]: 46 | """Forward VGG2L bottleneck. 47 | 48 | Args: 49 | feats: Feature sequences. (B, F, D_feats) 50 | feats_mask: Mask of feature sequences. (B, 1, F) 51 | 52 | Returns: 53 | vgg_output: VGG output sequences. 54 | (B, sub(F), D_out) or ((B, sub(F), D_out), (B, sub(F), D_att)) 55 | vgg_mask: Mask of VGG output sequences. 
(B, 1, sub(F)) 56 | 57 | """ 58 | feats = feats.unsqueeze(1) 59 | vgg_output = self.vgg2l(feats) 60 | 61 | b, c, t, f = vgg_output.size() 62 | 63 | vgg_output = self.output( 64 | vgg_output.transpose(1, 2).contiguous().view(b, t, c * f) 65 | ) 66 | 67 | if feats_mask is not None: 68 | vgg_mask = self.create_new_mask(feats_mask) 69 | else: 70 | vgg_mask = feats_mask 71 | 72 | return vgg_output, vgg_mask 73 | 74 | def create_new_mask(self, feats_mask: torch.Tensor) -> torch.Tensor: 75 | """Create a subsampled mask of feature sequences. 76 | 77 | Args: 78 | feats_mask: Mask of feature sequences. (B, 1, F) 79 | 80 | Returns: 81 | vgg_mask: Mask of VGG2L output sequences. (B, 1, sub(F)) 82 | 83 | """ 84 | vgg1_t_len = feats_mask.size(2) - (feats_mask.size(2) % 3) 85 | vgg_mask = feats_mask[:, :, :vgg1_t_len][:, :, ::3] 86 | 87 | vgg2_t_len = vgg_mask.size(2) - (vgg_mask.size(2) % 2) 88 | vgg_mask = vgg_mask[:, :, :vgg2_t_len][:, :, ::2] 89 | 90 | return vgg_mask 91 | -------------------------------------------------------------------------------- /funasr_detach/models/uniasr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/uniasr/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/whisper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/whisper/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/whisper/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/whisper/utils/__init__.py -------------------------------------------------------------------------------- /funasr_detach/models/whisper/utils/assets/gpt2/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /funasr_detach/models/whisper/utils/assets/gpt2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /funasr_detach/models/whisper/utils/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/whisper/utils/assets/mel_filters.npz -------------------------------------------------------------------------------- /funasr_detach/models/whisper/utils/assets/multilingual/added_tokens.json: -------------------------------------------------------------------------------- 1 | {"<|endoftext|>": 50257} 2 | -------------------------------------------------------------------------------- 
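Editor's note: the assets/gpt2 and assets/multilingual directories hold standard Hugging Face GPT-2 tokenizer configuration files (special_tokens_map.json, tokenizer_config.json, added_tokens.json); older openai-whisper releases load exactly this kind of directory through transformers' GPT2TokenizerFast. A minimal sketch of how such a directory is typically consumed, assuming the accompanying vocab.json/merges.txt BPE files (not shown in this listing) sit alongside the configs:

from transformers import GPT2TokenizerFast  # transformers is pinned in requirements.txt

# Illustrative only: load the vendored multilingual assets as a GPT-2 tokenizer.
assets_dir = "funasr_detach/models/whisper/utils/assets/multilingual"
tok = GPT2TokenizerFast.from_pretrained(assets_dir)
print(tok.eos_token)         # "<|endoftext|>", per special_tokens_map.json
print(tok.model_max_length)  # 1024, per tokenizer_config.json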
/funasr_detach/models/whisper/utils/assets/multilingual/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /funasr_detach/models/whisper/utils/assets/multilingual/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"} -------------------------------------------------------------------------------- /funasr_detach/models/xvector/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/models/xvector/__init__.py -------------------------------------------------------------------------------- /funasr_detach/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from funasr_detach.optimizers.fairseq_adam import FairseqAdam 3 | from funasr_detach.optimizers.sgd import SGD 4 | 5 | optim_classes = dict( 6 | adam=torch.optim.Adam, 7 | fairseq_adam=FairseqAdam, 8 | adamw=torch.optim.AdamW, 9 | sgd=SGD, 10 | adadelta=torch.optim.Adadelta, 11 | adagrad=torch.optim.Adagrad, 12 | adamax=torch.optim.Adamax, 13 | asgd=torch.optim.ASGD, 14 | lbfgs=torch.optim.LBFGS, 15 | rmsprop=torch.optim.RMSprop, 16 | rprop=torch.optim.Rprop, 17 | ) 18 | -------------------------------------------------------------------------------- /funasr_detach/optimizers/sgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class SGD(torch.optim.SGD): 5 | """Thin inheritance of torch.optim.SGD to bind the required arguments, 'lr' 6 | 7 | Note that 8 | the arguments of the optimizer invoked by AbsTask.main() 9 | must have default value except for 'param'. 10 | 11 | I can't understand why only SGD.lr doesn't have the default value. 
12 | """ 13 | 14 | def __init__( 15 | self, 16 | params, 17 | lr: float = 0.1, 18 | momentum: float = 0.0, 19 | dampening: float = 0.0, 20 | weight_decay: float = 0.0, 21 | nesterov: bool = False, 22 | ): 23 | super().__init__( 24 | params, 25 | lr=lr, 26 | momentum=momentum, 27 | dampening=dampening, 28 | weight_decay=weight_decay, 29 | nesterov=nesterov, 30 | ) 31 | -------------------------------------------------------------------------------- /funasr_detach/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing 3 | import torch.nn 4 | import torch.optim 5 | 6 | from funasr_detach.schedulers.noam_lr import NoamLR 7 | from funasr_detach.schedulers.tri_stage_scheduler import TriStageLR 8 | from funasr_detach.schedulers.warmup_lr import WarmupLR 9 | 10 | scheduler_classes = dict( 11 | ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau, 12 | lambdalr=torch.optim.lr_scheduler.LambdaLR, 13 | steplr=torch.optim.lr_scheduler.StepLR, 14 | multisteplr=torch.optim.lr_scheduler.MultiStepLR, 15 | exponentiallr=torch.optim.lr_scheduler.ExponentialLR, 16 | CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR, 17 | noamlr=NoamLR, 18 | warmuplr=WarmupLR, 19 | tri_stage=TriStageLR, 20 | cycliclr=torch.optim.lr_scheduler.CyclicLR, 21 | onecyclelr=torch.optim.lr_scheduler.OneCycleLR, 22 | CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, 23 | ) 24 | -------------------------------------------------------------------------------- /funasr_detach/schedulers/abs_scheduler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from abc import abstractmethod 3 | 4 | import torch.optim.lr_scheduler as L 5 | 6 | 7 | class AbsScheduler(ABC): 8 | @abstractmethod 9 | def step(self, epoch: int = None): 10 | pass 11 | 12 | @abstractmethod 13 | def state_dict(self): 14 | pass 15 | 16 | @abstractmethod 17 | def load_state_dict(self, state): 18 | pass 19 | 20 | 21 | # If you need to define custom scheduler, please inherit these classes 22 | class AbsBatchStepScheduler(AbsScheduler): 23 | @abstractmethod 24 | def step(self, epoch: int = None): 25 | pass 26 | 27 | @abstractmethod 28 | def state_dict(self): 29 | pass 30 | 31 | @abstractmethod 32 | def load_state_dict(self, state): 33 | pass 34 | 35 | 36 | class AbsEpochStepScheduler(AbsScheduler): 37 | @abstractmethod 38 | def step(self, epoch: int = None): 39 | pass 40 | 41 | @abstractmethod 42 | def state_dict(self): 43 | pass 44 | 45 | @abstractmethod 46 | def load_state_dict(self, state): 47 | pass 48 | 49 | 50 | class AbsValEpochStepScheduler(AbsEpochStepScheduler): 51 | @abstractmethod 52 | def step(self, val, epoch: int = None): 53 | pass 54 | 55 | @abstractmethod 56 | def state_dict(self): 57 | pass 58 | 59 | @abstractmethod 60 | def load_state_dict(self, state): 61 | pass 62 | 63 | 64 | # Create alias type to check the type 65 | # Note(kamo): Currently PyTorch doesn't provide the base class 66 | # to judge these classes. 
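# (Editor's note, not part of the upstream file.) ABCMeta.register() marks the
# torch built-in schedulers below as *virtual* subclasses of the Abs*Scheduler
# ABCs, so isinstance()/issubclass() checks succeed even though the torch
# classes neither inherit from the ABCs nor implement their abstract methods.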
67 | AbsValEpochStepScheduler.register(L.ReduceLROnPlateau) 68 | for s in [ 69 | L.ReduceLROnPlateau, 70 | L.LambdaLR, 71 | L.StepLR, 72 | L.MultiStepLR, 73 | L.MultiStepLR, 74 | L.ExponentialLR, 75 | L.CosineAnnealingLR, 76 | ]: 77 | AbsEpochStepScheduler.register(s) 78 | 79 | AbsBatchStepScheduler.register(L.CyclicLR) 80 | for s in [ 81 | L.OneCycleLR, 82 | L.CosineAnnealingWarmRestarts, 83 | ]: 84 | AbsBatchStepScheduler.register(s) 85 | -------------------------------------------------------------------------------- /funasr_detach/schedulers/noam_lr.py: -------------------------------------------------------------------------------- 1 | """Noam learning rate scheduler module.""" 2 | 3 | from typing import Union 4 | import warnings 5 | 6 | import torch 7 | from torch.optim.lr_scheduler import _LRScheduler 8 | 9 | from funasr_detach.schedulers.abs_scheduler import AbsBatchStepScheduler 10 | 11 | 12 | class NoamLR(_LRScheduler, AbsBatchStepScheduler): 13 | """The LR scheduler proposed by Noam 14 | 15 | Ref: 16 | "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf 17 | 18 | FIXME(kamo): PyTorch doesn't provide _LRScheduler as public class, 19 | thus the behaviour isn't guaranteed at forward PyTorch version. 20 | 21 | NOTE(kamo): The "model_size" in original implementation is derived from 22 | the model, but in this implementation, this parameter is a constant value. 23 | You need to change it if the model is changed. 24 | 25 | """ 26 | 27 | def __init__( 28 | self, 29 | optimizer: torch.optim.Optimizer, 30 | model_size: Union[int, float] = 320, 31 | warmup_steps: Union[int, float] = 25000, 32 | last_epoch: int = -1, 33 | ): 34 | self.model_size = model_size 35 | self.warmup_steps = warmup_steps 36 | 37 | lr = list(optimizer.param_groups)[0]["lr"] 38 | new_lr = self.lr_for_WarmupLR(lr) 39 | warnings.warn( 40 | f"NoamLR is deprecated. 
" 41 | f"Use WarmupLR(warmup_steps={warmup_steps}) with Optimizer(lr={new_lr})", 42 | ) 43 | 44 | # __init__() must be invoked before setting field 45 | # because step() is also invoked in __init__() 46 | super().__init__(optimizer, last_epoch) 47 | 48 | def lr_for_WarmupLR(self, lr: float) -> float: 49 | return lr / self.model_size**0.5 / self.warmup_steps**0.5 50 | 51 | def __repr__(self): 52 | return ( 53 | f"{self.__class__.__name__}(model_size={self.model_size}, " 54 | f"warmup_steps={self.warmup_steps})" 55 | ) 56 | 57 | def get_lr(self): 58 | step_num = self.last_epoch + 1 59 | return [ 60 | lr 61 | * self.model_size**-0.5 62 | * min(step_num**-0.5, step_num * self.warmup_steps**-1.5) 63 | for lr in self.base_lrs 64 | ] 65 | -------------------------------------------------------------------------------- /funasr_detach/schedulers/warmup_lr.py: -------------------------------------------------------------------------------- 1 | """Warm up learning rate scheduler module.""" 2 | 3 | from typing import Union 4 | 5 | import torch 6 | from torch.optim.lr_scheduler import _LRScheduler 7 | 8 | from funasr_detach.schedulers.abs_scheduler import AbsBatchStepScheduler 9 | 10 | 11 | class WarmupLR(_LRScheduler, AbsBatchStepScheduler): 12 | """The WarmupLR scheduler 13 | 14 | This scheduler is almost same as NoamLR Scheduler except for following difference: 15 | 16 | NoamLR: 17 | lr = optimizer.lr * model_size ** -0.5 18 | * min(step ** -0.5, step * warmup_step ** -1.5) 19 | WarmupLR: 20 | lr = optimizer.lr * warmup_step ** 0.5 21 | * min(step ** -0.5, step * warmup_step ** -1.5) 22 | 23 | Note that the maximum lr equals to optimizer.lr in this scheduler. 24 | 25 | """ 26 | 27 | def __init__( 28 | self, 29 | optimizer: torch.optim.Optimizer, 30 | warmup_steps: Union[int, float] = 25000, 31 | last_epoch: int = -1, 32 | ): 33 | self.warmup_steps = warmup_steps 34 | 35 | # __init__() must be invoked before setting field 36 | # because step() is also invoked in __init__() 37 | super().__init__(optimizer, last_epoch) 38 | 39 | def __repr__(self): 40 | return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" 41 | 42 | def get_lr(self): 43 | step_num = self.last_epoch + 1 44 | return [ 45 | lr 46 | * self.warmup_steps**0.5 47 | * min(step_num**-0.5, step_num * self.warmup_steps**-1.5) 48 | for lr in self.base_lrs 49 | ] 50 | -------------------------------------------------------------------------------- /funasr_detach/tokenizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/tokenizer/__init__.py -------------------------------------------------------------------------------- /funasr_detach/tokenizer/build_tokenizer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Iterable 3 | from typing import Union 4 | 5 | 6 | from funasr_detach.tokenizer.abs_tokenizer import AbsTokenizer 7 | from funasr_detach.tokenizer.char_tokenizer import CharTokenizer 8 | from funasr_detach.tokenizer.phoneme_tokenizer import PhonemeTokenizer 9 | from funasr_detach.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer 10 | from funasr_detach.tokenizer.word_tokenizer import WordTokenizer 11 | 12 | 13 | def build_tokenizer( 14 | token_type: str, 15 | bpemodel: Union[Path, str, Iterable[str]] = None, 16 | non_linguistic_symbols: Union[Path, str, Iterable[str]] = 
None, 17 | remove_non_linguistic_symbols: bool = False, 18 | space_symbol: str = "", 19 | delimiter: str = None, 20 | g2p_type: str = None, 21 | ) -> AbsTokenizer: 22 | """A helper function to instantiate Tokenizer""" 23 | if token_type == "bpe": 24 | if bpemodel is None: 25 | raise ValueError('bpemodel is required if token_type = "bpe"') 26 | 27 | if remove_non_linguistic_symbols: 28 | raise RuntimeError( 29 | "remove_non_linguistic_symbols is not implemented for token_type=bpe" 30 | ) 31 | return SentencepiecesTokenizer(bpemodel) 32 | 33 | elif token_type == "word": 34 | if remove_non_linguistic_symbols and non_linguistic_symbols is not None: 35 | return WordTokenizer( 36 | delimiter=delimiter, 37 | non_linguistic_symbols=non_linguistic_symbols, 38 | remove_non_linguistic_symbols=True, 39 | ) 40 | else: 41 | return WordTokenizer(delimiter=delimiter) 42 | 43 | elif token_type == "char": 44 | return CharTokenizer( 45 | non_linguistic_symbols=non_linguistic_symbols, 46 | space_symbol=space_symbol, 47 | remove_non_linguistic_symbols=remove_non_linguistic_symbols, 48 | ) 49 | 50 | elif token_type == "phn": 51 | return PhonemeTokenizer( 52 | g2p_type=g2p_type, 53 | non_linguistic_symbols=non_linguistic_symbols, 54 | space_symbol=space_symbol, 55 | remove_non_linguistic_symbols=remove_non_linguistic_symbols, 56 | ) 57 | 58 | else: 59 | raise ValueError( 60 | f"token_mode must be one of bpe, word, char or phn: " f"{token_type}" 61 | ) 62 | -------------------------------------------------------------------------------- /funasr_detach/tokenizer/cleaner.py: -------------------------------------------------------------------------------- 1 | from typing import Collection 2 | 3 | from jaconv import jaconv 4 | 5 | # import tacotron_cleaner.cleaners 6 | 7 | try: 8 | from vietnamese_cleaner import vietnamese_cleaners 9 | except ImportError: 10 | vietnamese_cleaners = None 11 | 12 | 13 | class TextCleaner: 14 | """Text cleaner. 15 | 16 | Examples: 17 | >>> cleaner = TextCleaner("tacotron") 18 | >>> cleaner("(Hello-World); & jr. 
& dr.") 19 | 'HELLO WORLD, AND JUNIOR AND DOCTOR' 20 | 21 | """ 22 | 23 | def __init__(self, cleaner_types: Collection[str] = None): 24 | 25 | if cleaner_types is None: 26 | self.cleaner_types = [] 27 | elif isinstance(cleaner_types, str): 28 | self.cleaner_types = [cleaner_types] 29 | else: 30 | self.cleaner_types = list(cleaner_types) 31 | 32 | def __call__(self, text: str) -> str: 33 | for t in self.cleaner_types: 34 | if t == "tacotron": 35 | # text = tacotron_cleaner.cleaners.custom_english_cleaners(text) 36 | pass 37 | elif t == "jaconv": 38 | text = jaconv.normalize(text) 39 | elif t == "vietnamese": 40 | if vietnamese_cleaners is None: 41 | raise RuntimeError("Please install underthesea") 42 | text = vietnamese_cleaners.vietnamese_cleaner(text) 43 | elif t == "korean_cleaner": 44 | text = KoreanCleaner.normalize_text(text) 45 | else: 46 | raise RuntimeError(f"Not supported: type={t}") 47 | 48 | return text 49 | -------------------------------------------------------------------------------- /funasr_detach/tokenizer/korean_cleaner.py: -------------------------------------------------------------------------------- 1 | # Referenced from https://github.com/hccho2/Tacotron-Wavenet-Vocoder-Korean 2 | 3 | import re 4 | 5 | 6 | class KoreanCleaner: 7 | @classmethod 8 | def _normalize_numbers(cls, text): 9 | number_to_kor = { 10 | "0": "영", 11 | "1": "일", 12 | "2": "이", 13 | "3": "삼", 14 | "4": "사", 15 | "5": "오", 16 | "6": "육", 17 | "7": "칠", 18 | "8": "팔", 19 | "9": "구", 20 | } 21 | new_text = "".join( 22 | number_to_kor[char] if char in number_to_kor.keys() else char 23 | for char in text 24 | ) 25 | return new_text 26 | 27 | @classmethod 28 | def _normalize_english_text(cls, text): 29 | upper_alphabet_to_kor = { 30 | "A": "에이", 31 | "B": "비", 32 | "C": "씨", 33 | "D": "디", 34 | "E": "이", 35 | "F": "에프", 36 | "G": "지", 37 | "H": "에이치", 38 | "I": "아이", 39 | "J": "제이", 40 | "K": "케이", 41 | "L": "엘", 42 | "M": "엠", 43 | "N": "엔", 44 | "O": "오", 45 | "P": "피", 46 | "Q": "큐", 47 | "R": "알", 48 | "S": "에스", 49 | "T": "티", 50 | "U": "유", 51 | "V": "브이", 52 | "W": "더블유", 53 | "X": "엑스", 54 | "Y": "와이", 55 | "Z": "지", 56 | } 57 | new_text = re.sub("[a-z]+", lambda x: str.upper(x.group()), text) 58 | new_text = "".join( 59 | ( 60 | upper_alphabet_to_kor[char] 61 | if char in upper_alphabet_to_kor.keys() 62 | else char 63 | ) 64 | for char in new_text 65 | ) 66 | 67 | return new_text 68 | 69 | @classmethod 70 | def normalize_text(cls, text): 71 | # stage 0 : text strip 72 | text = text.strip() 73 | 74 | # stage 1 : normalize numbers 75 | text = cls._normalize_numbers(text) 76 | 77 | # stage 2 : normalize english text 78 | text = cls._normalize_english_text(text) 79 | return text 80 | -------------------------------------------------------------------------------- /funasr_detach/tokenizer/sentencepiece_tokenizer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Iterable 3 | from typing import List 4 | from typing import Union 5 | 6 | import sentencepiece as spm 7 | 8 | from funasr_detach.tokenizer.abs_tokenizer import BaseTokenizer 9 | from funasr_detach.register import tables 10 | 11 | 12 | @tables.register("tokenizer_classes", "SentencepiecesTokenizer") 13 | class SentencepiecesTokenizer(BaseTokenizer): 14 | def __init__(self, bpemodel: Union[Path, str], **kwargs): 15 | super().__init__(**kwargs) 16 | self.bpemodel = str(bpemodel) 17 | # NOTE(kamo): 18 | # Don't build SentencePieceProcessor in __init__() 19 | # 
because it's not picklable and it may cause following error, 20 | # "TypeError: can't pickle SwigPyObject objects", 21 | # when giving it as argument of "multiprocessing.Process()". 22 | self.sp = None 23 | 24 | def __repr__(self): 25 | return f'{self.__class__.__name__}(model="{self.bpemodel}")' 26 | 27 | def _build_sentence_piece_processor(self): 28 | # Build SentencePieceProcessor lazily. 29 | if self.sp is None: 30 | self.sp = spm.SentencePieceProcessor() 31 | self.sp.load(self.bpemodel) 32 | 33 | def text2tokens(self, line: str) -> List[str]: 34 | self._build_sentence_piece_processor() 35 | return self.sp.EncodeAsPieces(line) 36 | 37 | def tokens2text(self, tokens: Iterable[str]) -> str: 38 | self._build_sentence_piece_processor() 39 | return self.sp.DecodePieces(list(tokens)) 40 | 41 | def encode(self, line: str) -> List[int]: 42 | self._build_sentence_piece_processor() 43 | return self.sp.EncodeAsIds(line) 44 | 45 | def decode(self, line: List[int]): 46 | self._build_sentence_piece_processor() 47 | return self.sp.DecodeIds(line) 48 | -------------------------------------------------------------------------------- /funasr_detach/tokenizer/token_id_converter.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict 3 | from typing import Iterable 4 | from typing import List 5 | from typing import Union 6 | 7 | import numpy as np 8 | 9 | 10 | class TokenIDConverter: 11 | def __init__( 12 | self, 13 | token_list: Union[Path, str, Iterable[str]], 14 | unk_symbol: str = "", 15 | ): 16 | 17 | if isinstance(token_list, (Path, str)): 18 | token_list = Path(token_list) 19 | self.token_list_repr = str(token_list) 20 | self.token_list: List[str] = [] 21 | 22 | with token_list.open("r", encoding="utf-8") as f: 23 | for idx, line in enumerate(f): 24 | line = line.rstrip() 25 | self.token_list.append(line) 26 | 27 | else: 28 | self.token_list: List[str] = list(token_list) 29 | self.token_list_repr = "" 30 | for i, t in enumerate(self.token_list): 31 | if i == 3: 32 | break 33 | self.token_list_repr += f"{t}, " 34 | self.token_list_repr += f"... 
(NVocab={(len(self.token_list))})" 35 | 36 | self.token2id: Dict[str, int] = {} 37 | for i, t in enumerate(self.token_list): 38 | if t in self.token2id: 39 | raise RuntimeError(f'Symbol "{t}" is duplicated') 40 | self.token2id[t] = i 41 | 42 | self.unk_symbol = unk_symbol 43 | if self.unk_symbol not in self.token2id: 44 | raise RuntimeError( 45 | f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list" 46 | ) 47 | self.unk_id = self.token2id[self.unk_symbol] 48 | 49 | def get_num_vocabulary_size(self) -> int: 50 | return len(self.token_list) 51 | 52 | def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]: 53 | if isinstance(integers, np.ndarray) and integers.ndim != 1: 54 | raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}") 55 | return [self.token_list[i] for i in integers] 56 | 57 | def tokens2ids(self, tokens: Iterable[str]) -> List[int]: 58 | return [self.token2id.get(i, self.unk_id) for i in tokens] 59 | -------------------------------------------------------------------------------- /funasr_detach/tokenizer/word_tokenizer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Iterable 3 | from typing import List 4 | from typing import Union 5 | import warnings 6 | 7 | 8 | from funasr_detach.tokenizer.abs_tokenizer import AbsTokenizer 9 | 10 | 11 | class WordTokenizer(AbsTokenizer): 12 | def __init__( 13 | self, 14 | delimiter: str = None, 15 | non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, 16 | remove_non_linguistic_symbols: bool = False, 17 | ): 18 | self.delimiter = delimiter 19 | 20 | if not remove_non_linguistic_symbols and non_linguistic_symbols is not None: 21 | warnings.warn( 22 | "non_linguistic_symbols is only used " 23 | "when remove_non_linguistic_symbols = True" 24 | ) 25 | 26 | if non_linguistic_symbols is None: 27 | self.non_linguistic_symbols = set() 28 | elif isinstance(non_linguistic_symbols, (Path, str)): 29 | non_linguistic_symbols = Path(non_linguistic_symbols) 30 | try: 31 | with non_linguistic_symbols.open("r", encoding="utf-8") as f: 32 | self.non_linguistic_symbols = set(line.rstrip() for line in f) 33 | except FileNotFoundError: 34 | warnings.warn(f"{non_linguistic_symbols} doesn't exist.") 35 | self.non_linguistic_symbols = set() 36 | else: 37 | self.non_linguistic_symbols = set(non_linguistic_symbols) 38 | self.remove_non_linguistic_symbols = remove_non_linguistic_symbols 39 | 40 | def __repr__(self): 41 | return f'{self.__class__.__name__}(delimiter="{self.delimiter}")' 42 | 43 | def text2tokens(self, line: str) -> List[str]: 44 | tokens = [] 45 | for t in line.split(self.delimiter): 46 | if self.remove_non_linguistic_symbols and t in self.non_linguistic_symbols: 47 | continue 48 | tokens.append(t) 49 | return tokens 50 | 51 | def tokens2text(self, tokens: Iterable[str]) -> str: 52 | if self.delimiter is None: 53 | delimiter = " " 54 | else: 55 | delimiter = self.delimiter 56 | return delimiter.join(tokens) 57 | -------------------------------------------------------------------------------- /funasr_detach/train_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/train_utils/__init__.py -------------------------------------------------------------------------------- /funasr_detach/train_utils/add_gradient_noise.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def add_gradient_noise( 5 | model: torch.nn.Module, 6 | iteration: int, 7 | duration: float = 100, 8 | eta: float = 1.0, 9 | scale_factor: float = 0.55, 10 | ): 11 | """Adds noise from a standard normal distribution to the gradients. 12 | 13 | The standard deviation (`sigma`) is controlled 14 | by the three hyper-parameters below. 15 | `sigma` goes to zero (no noise) with more iterations. 16 | 17 | Args: 18 | model: Model. 19 | iteration: Number of iterations. 20 | duration: {100, 1000}: Number of durations to control 21 | the interval of the `sigma` change. 22 | eta: {0.01, 0.3, 1.0}: The magnitude of `sigma`. 23 | scale_factor: {0.55}: The scale of `sigma`. 24 | """ 25 | interval = (iteration // duration) + 1 26 | sigma = eta / interval**scale_factor 27 | for param in model.parameters(): 28 | if param.grad is not None: 29 | _shape = param.grad.size() 30 | noise = sigma * torch.randn(_shape).to(param.device) 31 | param.grad += noise 32 | -------------------------------------------------------------------------------- /funasr_detach/train_utils/device_funcs.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def to_device(data, device=None, dtype=None, non_blocking=False, copy=False): 9 | """Change the device of object recursively""" 10 | if isinstance(data, dict): 11 | return { 12 | k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items() 13 | } 14 | elif dataclasses.is_dataclass(data) and not isinstance(data, type): 15 | return type(data)( 16 | *[ 17 | to_device(v, device, dtype, non_blocking, copy) 18 | for v in dataclasses.astuple(data) 19 | ] 20 | ) 21 | # maybe namedtuple. I don't know the correct way to judge namedtuple. 22 | elif isinstance(data, tuple) and type(data) is not tuple: 23 | return type(data)( 24 | *[to_device(o, device, dtype, non_blocking, copy) for o in data] 25 | ) 26 | elif isinstance(data, (list, tuple)): 27 | return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data) 28 | elif isinstance(data, np.ndarray): 29 | return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy) 30 | elif isinstance(data, torch.Tensor): 31 | return data.to(device, dtype, non_blocking, copy) 32 | else: 33 | return data 34 | 35 | 36 | def force_gatherable(data, device): 37 | """Change object to gatherable in torch.nn.DataParallel recursively 38 | 39 | The difference from to_device() is changing to torch.Tensor if float or int 40 | value is found. 41 | 42 | The restriction to the returned value in DataParallel: 43 | The object must be 44 | - torch.cuda.Tensor 45 | - 1 or more dimension. 0-dimension-tensor sends warning. 46 | or a list, tuple, dict. 
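Python floats and ints are therefore wrapped into 1-element tensors on the
    target device, 0-dim tensors are expanded to 1-dim, None is passed through,
    and unrecognized types are returned unchanged with a warning.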
47 | 48 | """ 49 | if isinstance(data, dict): 50 | return {k: force_gatherable(v, device) for k, v in data.items()} 51 | # DataParallel can't handle NamedTuple well 52 | elif isinstance(data, tuple) and type(data) is not tuple: 53 | return type(data)(*[force_gatherable(o, device) for o in data]) 54 | elif isinstance(data, (list, tuple, set)): 55 | return type(data)(force_gatherable(v, device) for v in data) 56 | elif isinstance(data, np.ndarray): 57 | return force_gatherable(torch.from_numpy(data), device) 58 | elif isinstance(data, torch.Tensor): 59 | if data.dim() == 0: 60 | # To 1-dim array 61 | data = data[None] 62 | return data.to(device) 63 | elif isinstance(data, float): 64 | return torch.tensor([data], dtype=torch.float, device=device) 65 | elif isinstance(data, int): 66 | return torch.tensor([data], dtype=torch.long, device=device) 67 | elif data is None: 68 | return None 69 | else: 70 | warnings.warn(f"{type(data)} may not be gatherable by DataParallel") 71 | return data 72 | -------------------------------------------------------------------------------- /funasr_detach/train_utils/forward_adaptor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class ForwardAdaptor(torch.nn.Module): 5 | """Wrapped module to parallelize specified method 6 | 7 | torch.nn.DataParallel parallelizes only "forward()" 8 | and, maybe, the method having the other name can't be applied 9 | except for wrapping the module just like this class. 10 | 11 | Examples: 12 | >>> class A(torch.nn.Module): 13 | ... def foo(self, x): 14 | ... ... 15 | >>> model = A() 16 | >>> model = ForwardAdaptor(model, "foo") 17 | >>> model = torch.nn.DataParallel(model, device_ids=[0, 1]) 18 | >>> x = torch.randn(2, 10) 19 | >>> model(x) 20 | """ 21 | 22 | def __init__(self, module: torch.nn.Module, name: str): 23 | super().__init__() 24 | self.module = module 25 | self.name = name 26 | if not hasattr(module, name): 27 | raise ValueError(f"{module} doesn't have {name}") 28 | 29 | def forward(self, *args, **kwargs): 30 | func = getattr(self.module, self.name) 31 | return func(*args, **kwargs) 32 | -------------------------------------------------------------------------------- /funasr_detach/train_utils/initialize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """Initialize modules for espnet2 neural networks.""" 4 | 5 | import math 6 | import torch 7 | 8 | 9 | def initialize(model: torch.nn.Module, init: str): 10 | """Initialize weights of a neural network module. 11 | 12 | Parameters are initialized using the given method or distribution. 13 | 14 | Custom initialization routines can be implemented into submodules 15 | as function `espnet_initialization_fn` within the custom module. 16 | 17 | Args: 18 | model: Target. 19 | init: Method of initialization. 
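Supported values: "xavier_uniform", "xavier_normal", "kaiming_uniform",
            "kaiming_normal"; anything else raises ValueError.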
20 | """ 21 | 22 | # weight init 23 | for p in model.parameters(): 24 | if p.dim() > 1: 25 | if init == "xavier_uniform": 26 | torch.nn.init.xavier_uniform_(p.data) 27 | elif init == "xavier_normal": 28 | torch.nn.init.xavier_normal_(p.data) 29 | elif init == "kaiming_uniform": 30 | torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") 31 | elif init == "kaiming_normal": 32 | torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") 33 | else: 34 | raise ValueError("Unknown initialization: " + init) 35 | # bias init 36 | for p in model.parameters(): 37 | if p.dim() == 1: 38 | p.data.zero_() 39 | 40 | # reset some modules with default init 41 | for m in model.modules(): 42 | if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm)): 43 | m.reset_parameters() 44 | if hasattr(m, "espnet_initialization_fn"): 45 | m.espnet_initialization_fn() 46 | 47 | # TODO(xkc): Hacking s3prl_frontend and wav2vec2encoder initialization 48 | if getattr(model, "encoder", None) and getattr( 49 | model.encoder, "reload_pretrained_parameters", None 50 | ): 51 | model.encoder.reload_pretrained_parameters() 52 | if getattr(model, "frontend", None) and getattr( 53 | model.frontend, "reload_pretrained_parameters", None 54 | ): 55 | model.frontend.reload_pretrained_parameters() 56 | -------------------------------------------------------------------------------- /funasr_detach/train_utils/model_summary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def get_human_readable_count(number: int) -> str: 6 | """Return human_readable_count 7 | 8 | Originated from: 9 | https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py 10 | 11 | Abbreviates an integer number with K, M, B, T for thousands, millions, 12 | billions and trillions, respectively. 13 | Examples: 14 | >>> get_human_readable_count(123) 15 | '123 ' 16 | >>> get_human_readable_count(1234) # (one thousand) 17 | '1 K' 18 | >>> get_human_readable_count(2e6) # (two million) 19 | '2 M' 20 | >>> get_human_readable_count(3e9) # (three billion) 21 | '3 B' 22 | >>> get_human_readable_count(4e12) # (four trillion) 23 | '4 T' 24 | >>> get_human_readable_count(5e15) # (more than trillion) 25 | '5,000 T' 26 | Args: 27 | number: a positive integer number 28 | Return: 29 | A string formatted according to the pattern described above. 
30 | """ 31 | assert number >= 0 32 | labels = [" ", "K", "M", "B", "T"] 33 | num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1) 34 | num_groups = int(np.ceil(num_digits / 3)) 35 | num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions 36 | shift = -3 * (num_groups - 1) 37 | number = number * (10**shift) 38 | index = num_groups - 1 39 | return f"{number:.2f} {labels[index]}" 40 | 41 | 42 | def to_bytes(dtype) -> int: 43 | # torch.float16 -> 16 44 | return int(str(dtype)[-2:]) // 8 45 | 46 | 47 | def model_summary(model: torch.nn.Module) -> str: 48 | message = "Model structure:\n" 49 | message += str(model) 50 | tot_params = sum(p.numel() for p in model.parameters()) 51 | num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 52 | percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params) 53 | tot_params = get_human_readable_count(tot_params) 54 | num_params = get_human_readable_count(num_params) 55 | message += "\n\nModel summary:\n" 56 | message += f" Class Name: {model.__class__.__name__}\n" 57 | message += f" Total Number of model parameters: {tot_params}\n" 58 | message += ( 59 | f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n" 60 | ) 61 | 62 | dtype = next(iter(model.parameters())).dtype 63 | message += f" Type: {dtype}" 64 | return message 65 | -------------------------------------------------------------------------------- /funasr_detach/train_utils/recursive_op.py: -------------------------------------------------------------------------------- 1 | """Torch utility module.""" 2 | 3 | import torch 4 | 5 | if torch.distributed.is_available(): 6 | from torch.distributed import ReduceOp 7 | 8 | 9 | def recursive_sum(obj, weight: torch.Tensor, distributed: bool = False): 10 | assert weight.dim() == 1, weight.size() 11 | if isinstance(obj, (tuple, list)): 12 | return type(obj)(recursive_sum(v, weight, distributed) for v in obj) 13 | elif isinstance(obj, dict): 14 | return {k: recursive_sum(v, weight, distributed) for k, v in obj.items()} 15 | elif isinstance(obj, torch.Tensor): 16 | assert obj.size() == weight.size(), (obj.size(), weight.size()) 17 | obj = (obj * weight.type(obj.dtype)).sum() 18 | if distributed: 19 | torch.distributed.all_reduce(obj, op=ReduceOp.SUM) 20 | return obj 21 | elif obj is None: 22 | return None 23 | else: 24 | raise ValueError(type(obj)) 25 | 26 | 27 | def recursive_divide(a, b: torch.Tensor): 28 | if isinstance(a, (tuple, list)): 29 | return type(a)(recursive_divide(v, b) for v in a) 30 | elif isinstance(a, dict): 31 | return {k: recursive_divide(v, b) for k, v in a.items()} 32 | elif isinstance(a, torch.Tensor): 33 | assert a.size() == b.size(), (a.size(), b.size()) 34 | return a / b.type(a.dtype) 35 | elif a is None: 36 | return None 37 | else: 38 | raise ValueError(type(a)) 39 | 40 | 41 | def recursive_average(obj, weight: torch.Tensor, distributed: bool = False): 42 | obj = recursive_sum(obj, weight, distributed) 43 | weight = weight.sum() 44 | if distributed: 45 | torch.distributed.all_reduce(weight, op=ReduceOp.SUM) 46 | # Normalize weight to be sum-to-1 47 | obj = recursive_divide(obj, weight) 48 | return obj, weight 49 | -------------------------------------------------------------------------------- /funasr_detach/train_utils/set_all_random_seed.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def set_all_random_seed(seed: int): 8 | 
random.seed(seed) 9 | np.random.seed(seed) 10 | torch.random.manual_seed(seed) 11 | -------------------------------------------------------------------------------- /funasr_detach/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/funasr_detach/utils/__init__.py -------------------------------------------------------------------------------- /funasr_detach/utils/datadir_writer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Union 3 | import warnings 4 | 5 | 6 | class DatadirWriter: 7 | """Writer class to create kaldi like data directory. 8 | 9 | Examples: 10 | >>> with DatadirWriter("output") as writer: 11 | ... # output/sub.txt is created here 12 | ... subwriter = writer["sub.txt"] 13 | ... # Write "uttidA some/where/a.wav" 14 | ... subwriter["uttidA"] = "some/where/a.wav" 15 | ... subwriter["uttidB"] = "some/where/b.wav" 16 | 17 | """ 18 | 19 | def __init__(self, p: Union[Path, str]): 20 | self.path = Path(p) 21 | self.chilidren = {} 22 | self.fd = None 23 | self.has_children = False 24 | self.keys = set() 25 | 26 | def __enter__(self): 27 | return self 28 | 29 | def __getitem__(self, key: str) -> "DatadirWriter": 30 | if self.fd is not None: 31 | raise RuntimeError("This writer points out a file") 32 | 33 | if key not in self.chilidren: 34 | w = DatadirWriter((self.path / key)) 35 | self.chilidren[key] = w 36 | self.has_children = True 37 | 38 | retval = self.chilidren[key] 39 | return retval 40 | 41 | def __setitem__(self, key: str, value: str): 42 | if self.has_children: 43 | raise RuntimeError("This writer points out a directory") 44 | if key in self.keys: 45 | warnings.warn(f"Duplicated: {key}") 46 | 47 | if self.fd is None: 48 | self.path.parent.mkdir(parents=True, exist_ok=True) 49 | self.fd = self.path.open("w", encoding="utf-8") 50 | 51 | self.keys.add(key) 52 | self.fd.write(f"{key} {value}\n") 53 | self.fd.flush() 54 | 55 | def __exit__(self, exc_type, exc_val, exc_tb): 56 | self.close() 57 | 58 | def close(self): 59 | if self.has_children: 60 | prev_child = None 61 | for child in self.chilidren.values(): 62 | child.close() 63 | if prev_child is not None and prev_child.keys != child.keys: 64 | warnings.warn( 65 | f"Ids are mismatching between " 66 | f"{prev_child.path} and {child.path}" 67 | ) 68 | prev_child = child 69 | 70 | elif self.fd is not None: 71 | self.fd.close() 72 | -------------------------------------------------------------------------------- /funasr_detach/utils/misc.py: -------------------------------------------------------------------------------- 1 | import io 2 | from collections import OrderedDict 3 | import numpy as np 4 | 5 | 6 | def statistic_model_parameters(model, prefix=None): 7 | var_dict = model.state_dict() 8 | numel = 0 9 | for i, key in enumerate( 10 | sorted(list([x for x in var_dict.keys() if "num_batches_tracked" not in x])) 11 | ): 12 | if prefix is None or key.startswith(prefix): 13 | numel += var_dict[key].numel() 14 | return numel 15 | 16 | 17 | def int2vec(x, vec_dim=8, dtype=np.int32): 18 | b = ("{:0" + str(vec_dim) + "b}").format(x) 19 | # little-endian order: lower bit first 20 | return (np.array(list(b)[::-1]) == "1").astype(dtype) 21 | 22 | 23 | def seq2arr(seq, vec_dim=8): 24 | return np.row_stack([int2vec(int(x), vec_dim) for x in seq]) 25 | 26 | 27 | def load_scp_as_dict(scp_path, value_type="str", 
kv_sep=" "): 28 | with io.open(scp_path, "r", encoding="utf-8") as f: 29 | ret_dict = OrderedDict() 30 | for one_line in f.readlines(): 31 | one_line = one_line.strip() 32 | pos = one_line.find(kv_sep) 33 | key, value = one_line[:pos], one_line[pos + 1 :] 34 | if value_type == "list": 35 | value = value.split(" ") 36 | ret_dict[key] = value 37 | return ret_dict 38 | 39 | 40 | def load_scp_as_list(scp_path, value_type="str", kv_sep=" "): 41 | with io.open(scp_path, "r", encoding="utf8") as f: 42 | ret_dict = [] 43 | for one_line in f.readlines(): 44 | one_line = one_line.strip() 45 | pos = one_line.find(kv_sep) 46 | key, value = one_line[:pos], one_line[pos + 1 :] 47 | if value_type == "list": 48 | value = value.split(" ") 49 | ret_dict.append((key, value)) 50 | return ret_dict 51 | -------------------------------------------------------------------------------- /funasr_detach/utils/vad_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.utils.rnn import pad_sequence 3 | 4 | 5 | def slice_padding_fbank(speech, speech_lengths, vad_segments): 6 | speech_list = [] 7 | speech_lengths_list = [] 8 | for i, segment in enumerate(vad_segments): 9 | 10 | bed_idx = int(segment[0][0] * 16) 11 | end_idx = min(int(segment[0][1] * 16), speech_lengths[0]) 12 | speech_i = speech[0, bed_idx:end_idx] 13 | speech_lengths_i = end_idx - bed_idx 14 | speech_list.append(speech_i) 15 | speech_lengths_list.append(speech_lengths_i) 16 | feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0) 17 | speech_lengths_pad = torch.Tensor(speech_lengths_list).int() 18 | return feats_pad, speech_lengths_pad 19 | 20 | 21 | def slice_padding_audio_samples(speech, speech_lengths, vad_segments): 22 | speech_list = [] 23 | speech_lengths_list = [] 24 | for i, segment in enumerate(vad_segments): 25 | bed_idx = int(segment[0][0] * 16) 26 | end_idx = min(int(segment[0][1] * 16), speech_lengths) 27 | speech_i = speech[bed_idx:end_idx] 28 | speech_lengths_i = end_idx - bed_idx 29 | speech_list.append(speech_i) 30 | speech_lengths_list.append(speech_lengths_i) 31 | 32 | return speech_list, speech_lengths_list 33 | -------------------------------------------------------------------------------- /funasr_detach/version.txt: -------------------------------------------------------------------------------- 1 | 1.0.8 2 | -------------------------------------------------------------------------------- /offline_inference.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | import argparse 3 | from stepaudio import StepAudio 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser(description="StepAudio Offline Inference") 8 | parser.add_argument( 9 | "--model-path", type=str, required=True, help="Base path for model files" 10 | ) 11 | args = parser.parse_args() 12 | 13 | model = StepAudio( 14 | tokenizer_path=f"{args.model_path}/Step-Audio-Tokenizer", 15 | tts_path=f"{args.model_path}/Step-Audio-TTS-3B", 16 | llm_path=f"{args.model_path}/Step-Audio-Chat", 17 | ) 18 | 19 | # example for text input 20 | text, audio, sr = model( 21 | [{"role": "user", "content": "你好,我是你的朋友,我叫小明,你叫什么名字?"}], 22 | "Tingting", 23 | ) 24 | torchaudio.save("output/output_e2e_tqta.wav", audio, sr) 25 | 26 | # example for audio input 27 | text, audio, sr = model( 28 | [ 29 | { 30 | "role": "user", 31 | "content": {"type": "audio", "audio": "output/output_e2e_tqta.wav"}, 32 | } 33 | ], 34 | "Tingting", 35 | ) 36 | 
torchaudio.save("output/output_e2e_aqta.wav", audio, sr) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /requirements-vllm.txt: -------------------------------------------------------------------------------- 1 | torch==2.5.1 2 | torchaudio==2.5.1 3 | torchvision 4 | transformers==4.48.3 5 | accelerate==1.3.0 6 | openai-whisper 7 | onnxruntime-gpu==1.19.0 8 | omegaconf==2.3.0 9 | librosa==0.10.2.post1 10 | sox==1.5.0 11 | modelscope 12 | numpy==1.26.4 13 | six==1.16.0 14 | hyperpyyaml 15 | conformer==0.3.2 16 | diffusers 17 | pillow 18 | sentencepiece 19 | funasr>=1.1.3 20 | protobuf==5.29.3 21 | gradio>=5.16.0 22 | vllm==0.7.2 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.3.1 2 | torchaudio==2.3.1 3 | torchvision==0.18.1 4 | transformers==4.48.3 5 | accelerate==1.3.0 6 | openai-whisper==20231117 7 | onnxruntime-gpu==1.17.0 8 | omegaconf==2.3.0 9 | librosa==0.10.2.post1 10 | sox==1.5.0 11 | modelscope 12 | numpy==1.26.4 13 | six==1.16.0 14 | hyperpyyaml 15 | conformer==0.3.2 16 | diffusers 17 | pillow 18 | sentencepiece 19 | funasr>=1.1.3 20 | protobuf==5.29.3 21 | gradio>=5.16.0 22 | -------------------------------------------------------------------------------- /speakers/TingtingRAP_prompt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/speakers/TingtingRAP_prompt.wav -------------------------------------------------------------------------------- /speakers/Tingting_prompt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/speakers/Tingting_prompt.wav -------------------------------------------------------------------------------- /speakers/Tingting哼唱_prompt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepfun-ai/Step-Audio/d67e374fd6aff418835418e2e32989ca2db00dff/speakers/Tingting哼唱_prompt.wav -------------------------------------------------------------------------------- /speakers/speakers_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "TingtingRAP": "(RAP)远远甩开的笑他是陆行龟 他曾跌倒也曾吃过灰 他说有福的人才会多吃亏 他的爸爸让他小心交友可他偏偏钻进个垃圾堆 他说他明白How to play", 3 | "Tingting哼唱": "(哼唱)你从一座叫 我 的小镇经过 刚好屋顶的雪化成雨飘落", 4 | "Tingting": "那等我们到海洋馆之后,给妈妈买个礼物,好不好呀?" 
5 | } -------------------------------------------------------------------------------- /tts_inference.py: -------------------------------------------------------------------------------- 1 | import torchaudio 2 | import argparse 3 | from tts import StepAudioTTS 4 | from tokenizer import StepAudioTokenizer 5 | from utils import load_audio 6 | import os 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(description="StepAudio Offline Inference") 11 | parser.add_argument( 12 | "--model-path", type=str, required=True, help="Base path for model files" 13 | ) 14 | parser.add_argument( 15 | "--synthesis-type", type=str, default="tts", help="Use tts or Clone for Synthesis" 16 | ) 17 | parser.add_argument( 18 | "--output-path", type=str, required=True, help="Output path for synthesis audios" 19 | ) 20 | args = parser.parse_args() 21 | os.makedirs(f"{args.output_path}", exist_ok=True) 22 | 23 | encoder = StepAudioTokenizer(f"{args.model_path}/Step-Audio-Tokenizer") 24 | tts_engine = StepAudioTTS(f"{args.model_path}/Step-Audio-TTS-3B", encoder) 25 | 26 | if args.synthesis_type == "tts": 27 | text = "(RAP)我踏上自由的征途,追逐那遥远的梦想,挣脱束缚的枷锁,让心灵随风飘荡,每一步都充满力量,每一刻都无比闪亮,自由的信念在燃烧,照亮我前进的方向!" 28 | output_audio, sr = tts_engine(text, "Tingting") 29 | torchaudio.save(f"{args.output_path}/output_tts.wav", output_audio, sr) 30 | else: 31 | clone_speaker = {"speaker":"test","prompt_text":"叫做秋风起蟹脚痒,啊,什么意思呢?就是说这秋风一起啊,螃蟹就该上市了。", "wav_path":"examples/prompt_wav_yuqian.wav"} 32 | text_clone = "人活一辈子,生老病死,总得是有高峰,有低谷,有顺境,有逆境,每个人都差不多。要不老话怎么讲,三十年河东,三十年河西呢。" 33 | output_audio, sr = tts_engine(text_clone, "",clone_speaker) 34 | torchaudio.save(f"{args.output_path}/output_clone.wav", output_audio, sr) 35 | 36 | if __name__ == "__main__": 37 | main() 38 | --------------------------------------------------------------------------------
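Editor's note: taken together, tts_inference.py and speakers/speakers_info.json suggest a simple batch loop over the built-in voices. The sketch below reuses only the calls shown in tts_inference.py (StepAudioTokenizer, StepAudioTTS, torchaudio.save); the model path, output directory, and test sentence are placeholders to adapt locally.

import json
import os

import torchaudio

from tokenizer import StepAudioTokenizer
from tts import StepAudioTTS

MODEL_PATH = "path/to/downloaded/models"  # placeholder, same layout as in tts_inference.py
OUTPUT_PATH = "output/batch_tts"          # placeholder
os.makedirs(OUTPUT_PATH, exist_ok=True)

encoder = StepAudioTokenizer(f"{MODEL_PATH}/Step-Audio-Tokenizer")
tts_engine = StepAudioTTS(f"{MODEL_PATH}/Step-Audio-TTS-3B", encoder)

# speakers_info.json maps each built-in speaker name to its prompt transcript.
with open("speakers/speakers_info.json", encoding="utf-8") as f:
    speakers = json.load(f)

text = "Hello from Step-Audio, this is a quick test of the built-in voices."
for name in speakers:
    audio, sr = tts_engine(text, name)
    torchaudio.save(f"{OUTPUT_PATH}/{name}.wav", audio, sr)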