├── .gitignore ├── .gitmodules ├── Dockerfile ├── LICENSE.txt ├── README.md ├── app.py ├── asset ├── InspireMusic-24kHz.png ├── InspireMusic.png ├── QR.jpg ├── dingding.png ├── dingtalk.png └── logo.png ├── docker-compose.yml ├── envs ├── inspiremusic.yaml ├── logo.png └── requirements.py38cu118.txt ├── examples └── music_generation │ ├── batch_infer.sh │ ├── batch_infer_1.5b_long.sh │ ├── conf │ ├── ds_stage2.json │ ├── inspiremusic.fromscratch.yaml │ ├── inspiremusic.yaml │ ├── inspiremusic_1.5b.yaml │ ├── inspiremusic_1.5b_long.yaml │ ├── inspiremusic_1.5b_long_infer.yaml │ ├── inspiremusic_base.yaml │ ├── inspiremusic_base_24khz.yaml │ └── inspiremusic_infer.yaml │ ├── data │ ├── dataset_example │ │ ├── text │ │ └── wav.scp │ └── samples │ │ ├── parquet │ │ ├── data.list │ │ └── parquet_000000000.tar │ │ └── text │ ├── infer.sh │ ├── infer_1.5b_long.sh │ ├── inspiremusic │ ├── local │ ├── download_and_untar.sh │ └── prepare_data.py │ ├── path.sh │ ├── run.sh │ └── tools ├── inspiremusic ├── __init__.py ├── bin │ ├── export_jit.py │ ├── export_onnx.py │ ├── flow_only_infer.py │ ├── inference.py │ └── train.py ├── cli │ ├── __init__.py │ ├── frontend.py │ ├── inference.py │ ├── inspiremusic.py │ └── model.py ├── dataset │ ├── __init__.py │ ├── dataset.py │ └── processor.py ├── flow │ ├── decoder.py │ ├── flow.py │ ├── flow_matching.py │ └── length_regulator.py ├── hifigan │ ├── discriminator.py │ ├── f0_predictor.py │ ├── generator.py │ └── hifigan.py ├── llm │ └── llm.py ├── metrics │ ├── clap_score.py │ ├── openl3_fd.py │ └── passt_kld.py ├── music_tokenizer │ ├── __init__.py │ ├── env.py │ ├── meldataset.py │ ├── models.py │ └── vqvae.py ├── text │ ├── abs_tokenizer.py │ └── tokenizer.py ├── transformer │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── convolution.py │ ├── decoder.py │ ├── decoder_layer.py │ ├── embedding.py │ ├── encoder.py │ ├── encoder_layer.py │ ├── label_smoothing_loss.py │ ├── positionwise_feed_forward.py │ ├── qwen_encoder.py │ └── subsampling.py ├── utils │ ├── __init__.py │ ├── audio_utils.py │ ├── binary.py │ ├── class_utils.py │ ├── common.py │ ├── data_utils.py │ ├── executor.py │ ├── file_utils.py │ ├── frontend_utils.py │ ├── hinter.py │ ├── losses.py │ ├── mask.py │ ├── scheduler.py │ ├── tokenizer_utils.py │ ├── train_utils.py │ └── utils.py ├── version.txt └── wavtokenizer │ ├── __init__.py │ ├── decoder │ ├── __init__.py │ ├── dataset.py │ ├── discriminator_dac.py │ ├── discriminators.py │ ├── experiment.py │ ├── feature_extractors.py │ ├── heads.py │ ├── helpers.py │ ├── loss.py │ ├── models.py │ ├── modules.py │ ├── pretrained.py │ ├── pretrained_model.py │ └── spectral_ops.py │ └── encoder │ ├── __init__.py │ ├── distrib.py │ ├── model.py │ ├── modules │ ├── __init__.py │ ├── conv.py │ ├── lstm.py │ ├── norm.py │ ├── seanet.py │ └── transformer.py │ ├── msstftd.py │ ├── quantization │ ├── __init__.py │ ├── ac.py │ ├── core_vq.py │ └── vq.py │ └── utils.py ├── requirements.txt ├── setup.py └── tools ├── extract_acoustic_token.py ├── extract_embedding.py ├── extract_semantic_token.py ├── extract_speech_token.py └── make_parquet_list.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea/ 6 | ./examples/music_generation/exp/ 7 | ./examples/music_generation/data/ 8 | ./pretrained_models/ 9 | */__pycache__/ 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | 
build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | *.DS_Store 34 | .DS_Store 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | #InspireMusic 171 | exp/ 172 | pretrained_models/ 173 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/Matcha-TTS"] 2 | path = third_party/Matcha-TTS 3 | url = https://github.com/shivammehta25/Matcha-TTS.git -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use the NVIDIA NGC PyTorch GPU base image (Ubuntu-based, with CUDA and Python preinstalled) 2 | FROM nvcr.io/nvidia/pytorch:24.08-py3 3 | 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | 6 | # metainformation 7 | LABEL org.opencontainers.image.source="https://github.com/FunAudioLLM/InspireMusic" 8 | LABEL org.opencontainers.image.licenses="Apache License 2.0" 9 | 10 | # Set the working directory 11 | WORKDIR /workspace/InspireMusic 12 | # Clone the repository into the container at /workspace/InspireMusic 13 | RUN git clone https://github.com/FunAudioLLM/InspireMusic.git . 14 | 15 | # install library dependencies 16 | RUN apt-get update && apt-get install -y ffmpeg sox libsox-dev git && apt-get clean 17 | RUN pip install -r requirements.txt 18 | 19 | # install flash attention 20 | RUN pip install flash-attn 21 | -------------------------------------------------------------------------------- /asset/InspireMusic-24kHz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/asset/InspireMusic-24kHz.png -------------------------------------------------------------------------------- /asset/InspireMusic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/asset/InspireMusic.png -------------------------------------------------------------------------------- /asset/QR.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/asset/QR.jpg -------------------------------------------------------------------------------- /asset/dingding.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/asset/dingding.png -------------------------------------------------------------------------------- /asset/dingtalk.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/asset/dingtalk.png -------------------------------------------------------------------------------- /asset/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/asset/logo.png -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | inspire-music: 3 | build: 4 | context: . 5 | dockerfile: Dockerfile 6 | image: inspire-music 7 | container_name: inspire-music 8 | runtime: nvidia 9 | deploy: 10 | resources: 11 | reservations: 12 | devices: 13 | - driver: nvidia 14 | count: all 15 | capabilities: [gpu] 16 | volumes: 17 | - ./pretrained_models:/pretrained_models 18 | - .:/workspace/InspireMusic 19 | stdin_open: true 20 | tty: true 21 | -------------------------------------------------------------------------------- /envs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/envs/logo.png -------------------------------------------------------------------------------- /envs/requirements.py38cu118.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu118 2 | conformer==0.3.2 3 | deepspeed==0.14.2; sys_platform == 'linux' 4 | diffusers==0.27.2 5 | gdown==5.1.0 6 | gradio==4.32.2 7 | grpcio==1.57.0 8 | grpcio-tools==1.57.0 9 | hydra-core==1.3.2 10 | HyperPyYAML==1.2.2 11 | inflect==7.3.1 12 | librosa==0.10.2 13 | lightning==2.2.4 14 | matplotlib==3.7.5 15 | modelscope==1.15.0 16 | networkx==3.1 17 | omegaconf==2.3.0 18 | onnx==1.16.0 19 | onnxruntime-gpu==1.16.0; sys_platform == 'linux' 20 | onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' 21 | openai-whisper==20231117 22 | protobuf==4.25 23 | pydantic==2.7.0 24 | rich==13.7.1 25 | soundfile==0.12.1 26 | tensorboard==2.14.0 27 | torch==2.0.1 28 | torchaudio==2.0.2 29 | uvicorn==0.30.0 30 | wget==3.2 31 | fastapi==0.111.0 32 | fastapi-cli==0.0.4 33 | WeTextProcessing==1.0.3 34 | transformers 35 | accelerate 36 | huggingface-hub==0.25.2 -------------------------------------------------------------------------------- /examples/music_generation/batch_infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Alibaba Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | . ./path.sh || exit 1; 16 | 17 | export TOKENIZERS_PARALLELISM=False 18 | 19 | model_name="InspireMusic-Base" 20 | pretrained_model_dir=../../pretrained_models/${model_name} 21 | dataset_name=samples 22 | 23 | # batch inference normal mode 24 | echo "Run inference." 25 | expr_name="inspiremusic_${dataset_name}" 26 | for task in 'text-to-music' 'continuation'; do 27 | python inspiremusic/bin/inference.py --task $task \ 28 | --gpu 0 \ 29 | --config conf/inspiremusic_infer.yaml \ 30 | --prompt_data data/${dataset_name}/parquet/data.list \ 31 | --flow_model $pretrained_model_dir/flow.pt \ 32 | --llm_model $pretrained_model_dir/llm.pt \ 33 | --music_tokenizer $pretrained_model_dir/music_tokenizer \ 34 | --wavtokenizer $pretrained_model_dir/wavtokenizer \ 35 | --chorus verse \ 36 | --output_sample_rate 48000 \ 37 | --min_generate_audio_seconds 5.0 \ 38 | --max_generate_audio_seconds 30.0 \ 39 | --batch \ 40 | --result_dir `pwd`/exp/${model_name}/${task}_${expr_name} 41 | # if use InspireMusic-xxxx-24kHz model, please set output sample rate to 24kHz 42 | # --output_sample_rate 24000 \ 43 | # use fast inference mode 44 | # --fast # fast mode without flow matching 45 | echo `pwd`/exp/${model_name}/${task}_${expr_name} 46 | done 47 | -------------------------------------------------------------------------------- /examples/music_generation/batch_infer_1.5b_long.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Alibaba Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | . ./path.sh || exit 1; 16 | 17 | convert_path() { 18 | if command -v cygpath &> /dev/null; then 19 | cygpath -w "$1" 20 | else 21 | echo "$1" 22 | fi 23 | } 24 | 25 | # Detect the operating system 26 | OS="$(uname -s)" 27 | 28 | export TOKENIZERS_PARALLELISM=False 29 | 30 | model_name="InspireMusic-1.5B-Long" 31 | pretrained_model_dir=../../pretrained_models/${model_name} 32 | dataset_name=samples 33 | expr_name="inspiremusic_${dataset_name}" 34 | 35 | echo "Run inference." 
36 | 37 | case "$OS" in 38 | Linux*|Darwin*) 39 | echo "Running on Unix-like OS: $OS" 40 | # Use Unix-style paths 41 | # batch inference normal mode 42 | for task in 'text-to-music' 'continuation'; do 43 | python inspiremusic/bin/inference.py --task $task \ 44 | --gpu 0 \ 45 | --config conf/inspiremusic_1.5b_long_infer.yaml \ 46 | --prompt_data data/${dataset_name}/parquet/data.list \ 47 | --flow_model $pretrained_model_dir/flow.pt \ 48 | --llm_model $pretrained_model_dir/llm.pt \ 49 | --music_tokenizer $pretrained_model_dir/music_tokenizer \ 50 | --wavtokenizer $pretrained_model_dir/wavtokenizer \ 51 | --chorus default \ 52 | --output_sample_rate 48000 \ 53 | --min_generate_audio_seconds 5.0 \ 54 | --max_generate_audio_seconds 30.0 \ 55 | --batch \ 56 | --result_dir `pwd`/exp/${model_name}/${task}_${expr_name} 57 | # if use InspireMusic-xxxx-24kHz model, please set output sample rate to 24kHz 58 | # --output_sample_rate 24000 \ 59 | # use fast inference mode 60 | # --fast # fast mode without flow matching 61 | echo `pwd`/exp/${model_name}/${task}_${expr_name} 62 | done 63 | ;; 64 | CYGWIN*|MINGW*|MSYS*) 65 | echo "Running on Windows-like OS: $OS" 66 | # Use Windows-style paths 67 | pretrained_model_dir=$(convert_path "$pretrained_model_dir") 68 | # batch inference normal mode 69 | for task in 'text-to-music' 'continuation'; do 70 | python inspiremusic\bin\inference.py --task $task \ 71 | --gpu 0 \ 72 | --config conf\inspiremusic_1.5b_long_infer.yaml \ 73 | --prompt_data data\${dataset_name}\parquet\data.list \ 74 | --flow_model $pretrained_model_dir\flow.pt \ 75 | --llm_model $pretrained_model_dir\llm.pt \ 76 | --music_tokenizer $pretrained_model_dir\music_tokenizer \ 77 | --wavtokenizer $pretrained_model_dir\wavtokenizer \ 78 | --chorus default \ 79 | --output_sample_rate 48000 \ 80 | --min_generate_audio_seconds 5.0 \ 81 | --max_generate_audio_seconds 30.0 \ 82 | --batch \ 83 | --result_dir exp\${model_name}\${task}_${expr_name} 84 | # if use InspireMusic-xxxx-24kHz model, please set output sample rate to 24kHz 85 | # --output_sample_rate 24000 \ 86 | # use fast inference mode 87 | # --fast # fast mode without flow matching 88 | echo exp\${model_name}\${task}_${expr_name} 89 | done 90 | ;; 91 | *) 92 | echo "Unknown OS: $OS" 93 | exit 1 94 | ;; 95 | esac 96 | -------------------------------------------------------------------------------- /examples/music_generation/conf/ds_stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 1, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 5, 6 | "fp16": { 7 | "enabled": false, 8 | "auto_cast": false, 9 | "loss_scale": 0, 10 | "initial_scale_power": 16, 11 | "loss_scale_window": 256, 12 | "hysteresis": 2, 13 | "consecutive_hysteresis": false, 14 | "min_loss_scale": 1 15 | }, 16 | "bf16": { 17 | "enabled": false 18 | }, 19 | "zero_force_ds_cpu_optimizer": false, 20 | "zero_optimization": { 21 | "stage": 2, 22 | "offload_optimizer": { 23 | "device": "none", 24 | "pin_memory": true 25 | }, 26 | "allgather_partitions": true, 27 | "allgather_bucket_size": 5e8, 28 | "overlap_comm": false, 29 | "reduce_scatter": true, 30 | "reduce_bucket_size": 5e8, 31 | "contiguous_gradients" : true 32 | }, 33 | "optimizer": { 34 | "type": "AdamW", 35 | "params": { 36 | "lr": 0.001, 37 | "weight_decay": 0.0001, 38 | "torch_adam": true, 39 | "adam_w_mode": true 40 | } 41 | } 42 | } 
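The ds_stage2.json file above is a standard DeepSpeed configuration: ZeRO stage-2 partitioning of optimizer states and gradients, an AdamW optimizer defined in the JSON itself, fp16/bf16 disabled, and gradient clipping at 5. As a rough, hedged illustration only (the repository's real training entry point is inspiremusic/bin/train.py driven by run.sh, which may wire DeepSpeed differently), a config file like this is typically handed to deepspeed.initialize; the toy model below is a hypothetical stand-in, not part of InspireMusic:

# Minimal sketch; launch with the DeepSpeed launcher (e.g. `deepspeed sketch_train.py`)
# so that RANK/WORLD_SIZE and the distributed backend are set up.
import deepspeed
import torch

# Hypothetical stand-in model; InspireMusic would pass its LLM or flow module here instead.
model = torch.nn.Linear(896, 4096)

# DeepSpeed builds the AdamW optimizer and the ZeRO stage-2 partitioning from the JSON,
# so only the model and its parameters are passed in, not an optimizer object.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="conf/ds_stage2.json",
)

# A typical step then runs through the engine rather than plain PyTorch calls:
#   loss = model_engine(batch).sum()
#   model_engine.backward(loss)
#   model_engine.step()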
-------------------------------------------------------------------------------- /examples/music_generation/conf/inspiremusic.fromscratch.yaml: -------------------------------------------------------------------------------- 1 | # set random seed, so that you may reproduce your result. 2 | __set_seed1: !apply:random.seed [1024] 3 | __set_seed2: !apply:numpy.random.seed [1024] 4 | __set_seed3: !apply:torch.manual_seed [1024] 5 | __set_seed4: !apply:torch.cuda.manual_seed_all [1024] 6 | 7 | # fixed params 8 | sample_rate: 24000 9 | text_encoder_input_size: 512 10 | llm_input_size: 896 11 | llm_output_size: 896 12 | 13 | basemodel_path: '../../pretrained_models/InspireMusic-Base/' 14 | generator_path: '../../pretrained_models/InspireMusic-Base/music_tokenizer' 15 | 16 | # model params 17 | # for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. 18 | # for system/third_party class/function, we do not require this. 19 | llm: !new:inspiremusic.llm.llm.LLM 20 | text_encoder_input_size: !ref 21 | llm_input_size: !ref 22 | llm_output_size: !ref 23 | audio_token_size: 4096 24 | length_normalized_loss: True 25 | lsm_weight: 0 26 | text_encoder_conf: 27 | name: "none" 28 | llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder 29 | input_size: !ref 30 | pretrain_path: !ref 31 | 32 | sampling: !name:inspiremusic.utils.common.ras_sampling 33 | top_p: 0.8 34 | top_k: 50 35 | win_size: 10 36 | tau_r: 0.1 37 | train_cfg_ratio: 0.2 38 | infer_cfg_ratio: 7.0 39 | flow: !new:inspiremusic.flow.flow.MaskedDiff 40 | input_size: 256 41 | output_size: 80 42 | output_type: 'mel' 43 | vocab_size: 4096 44 | input_frame_rate: 75 45 | only_mask_loss: True 46 | encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder 47 | output_size: 512 48 | attention_heads: 4 49 | linear_units: 1024 50 | num_blocks: 3 51 | dropout_rate: 0.1 52 | positional_dropout_rate: 0.1 53 | attention_dropout_rate: 0.1 54 | normalize_before: True 55 | input_layer: 'linear' 56 | pos_enc_layer_type: 'rel_pos_espnet' 57 | selfattention_layer_type: 'rel_selfattn' 58 | input_size: 256 59 | use_cnn_module: False 60 | macaron_style: False 61 | length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator 62 | channels: 512 63 | sampling_ratios: [1, 1, 1, 1] 64 | decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM 65 | in_channels: 240 66 | cfm_params: !new:omegaconf.DictConfig 67 | content: 68 | sigma_min: 1e-06 69 | solver: 'euler' 70 | t_scheduler: 'cosine' 71 | training_cfg_rate: 0.2 72 | inference_cfg_rate: 0.7 73 | reg_loss_type: 'l1' 74 | estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder 75 | in_channels: 1024 76 | out_channels: 512 77 | channels: [256, 256] 78 | dropout: 0.0 79 | attention_head_dim: 64 80 | n_blocks: 4 81 | num_mid_blocks: 8 82 | num_heads: 8 83 | act_fn: 'gelu' 84 | generator_model_dir: !ref 85 | 86 | hift: !new:inspiremusic.hifigan.generator.HiFTGenerator 87 | in_channels: 80 88 | base_channels: 512 89 | nb_harmonics: 8 90 | sampling_rate: !ref 91 | nsf_alpha: 0.1 92 | nsf_sigma: 0.003 93 | nsf_voiced_threshold: 10 94 | upsample_rates: [8, 8] 95 | upsample_kernel_sizes: [16, 16] 96 | istft_params: 97 | n_fft: 16 98 | hop_len: 4 99 | resblock_kernel_sizes: [3, 7, 11] 100 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 101 | source_resblock_kernel_sizes: [7, 11] 102 | source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]] 103 | lrelu_slope: 0.1 104 | 
audio_limit: 0.99 105 | f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor 106 | num_class: 1 107 | in_channels: 80 108 | cond_channels: 512 109 | 110 | wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator 111 | 112 | # processor functions 113 | parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener 114 | get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer 115 | tokenizer_path: !ref 116 | tokenizer_name: "qwen-2.0" 117 | allowed_special: 'all' 118 | tokenize: !name:inspiremusic.dataset.processor.tokenize 119 | get_tokenizer: !ref 120 | allowed_special: !ref 121 | filter: !name:inspiremusic.dataset.processor.filter 122 | max_length: 28000 123 | min_length: 0 124 | token_max_length: 200 125 | token_min_length: 1 126 | resample: !name:inspiremusic.dataset.processor.resample 127 | resample_rate: !ref 128 | feat_extractor: !name:matcha.utils.audio.mel_spectrogram 129 | n_fft: 1024 130 | num_mels: 128 131 | sampling_rate: !ref 132 | hop_size: 256 133 | win_size: 1024 134 | fmin: 0 135 | fmax: 24000 136 | center: False 137 | compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank 138 | feat_extractor: !ref 139 | parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding 140 | normalize: True 141 | shuffle: !name:inspiremusic.dataset.processor.shuffle 142 | shuffle_size: 1000 143 | sort: !name:inspiremusic.dataset.processor.sort 144 | sort_size: 500 # sort_size should be less than shuffle_size 145 | batch: !name:inspiremusic.dataset.processor.batch 146 | batch_type: 'dynamic' 147 | max_frames_in_batch: 12000 148 | padding: !name:inspiremusic.dataset.processor.padding 149 | 150 | # dataset processor pipeline 151 | data_pipeline: [ 152 | !ref , 153 | !ref , 154 | !ref , 155 | !ref , 156 | !ref , 157 | !ref , 158 | !ref , 159 | ] 160 | 161 | # train conf 162 | train_conf: 163 | optim: adam 164 | optim_conf: 165 | lr: 0.001 # change to 0.001 if you want to train flow from scratch 166 | scheduler: warmuplr 167 | scheduler_conf: 168 | warmup_steps: 5000 169 | max_epoch: 200 170 | grad_clip: 5 171 | accum_grad: 2 172 | log_interval: 100 173 | save_per_step: 1000 174 | -------------------------------------------------------------------------------- /examples/music_generation/conf/inspiremusic_1.5b.yaml: -------------------------------------------------------------------------------- 1 | # set random seed, so that you may reproduce your result. 2 | __set_seed1: !apply:random.seed [1024] 3 | __set_seed2: !apply:numpy.random.seed [1024] 4 | __set_seed3: !apply:torch.manual_seed [1024] 5 | __set_seed4: !apply:torch.cuda.manual_seed_all [1024] 6 | 7 | # fixed params 8 | sample_rate: 24000 9 | text_encoder_input_size: 512 10 | llm_input_size: 1536 11 | llm_output_size: 1536 12 | 13 | basemodel_path: '../../pretrained_models/InspireMusic-1.5B/' 14 | generator_path: '../../pretrained_models/InspireMusic-1.5B/music_tokenizer' 15 | 16 | # model params 17 | # for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. 18 | # for system/third_party class/function, we do not require this. 
19 | llm: !new:inspiremusic.llm.llm.LLM 20 | text_encoder_input_size: !ref 21 | llm_input_size: !ref 22 | llm_output_size: !ref 23 | audio_token_size: 4096 24 | length_normalized_loss: True 25 | lsm_weight: 0 26 | text_encoder_conf: 27 | name: "none" 28 | llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder 29 | input_size: !ref 30 | pretrain_path: !ref 31 | 32 | sampling: !name:inspiremusic.utils.common.topk_sampling 33 | top_k: 350 34 | train_cfg_ratio: 0.2 35 | infer_cfg_ratio: 3.0 36 | flow: !new:inspiremusic.flow.flow.MaskedDiff 37 | input_size: 256 38 | output_size: 80 39 | output_type: 'mel' 40 | vocab_size: 4096 41 | input_frame_rate: 75 42 | only_mask_loss: True 43 | encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder 44 | output_size: 512 45 | attention_heads: 4 46 | linear_units: 1024 47 | num_blocks: 3 48 | dropout_rate: 0.1 49 | positional_dropout_rate: 0.1 50 | attention_dropout_rate: 0.1 51 | normalize_before: True 52 | input_layer: 'linear' 53 | pos_enc_layer_type: 'rel_pos_espnet' 54 | selfattention_layer_type: 'rel_selfattn' 55 | input_size: 256 56 | use_cnn_module: False 57 | macaron_style: False 58 | length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator 59 | channels: 512 60 | sampling_ratios: [1, 1, 1, 1] 61 | decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM 62 | in_channels: 240 63 | cfm_params: !new:omegaconf.DictConfig 64 | content: 65 | sigma_min: 1e-06 66 | solver: 'euler' 67 | t_scheduler: 'cosine' 68 | training_cfg_rate: 0.2 69 | inference_cfg_rate: 0.7 70 | reg_loss_type: 'l1' 71 | estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder 72 | in_channels: 1024 73 | out_channels: 512 74 | channels: [256, 256] 75 | dropout: 0.0 76 | attention_head_dim: 64 77 | n_blocks: 4 78 | num_mid_blocks: 8 79 | num_heads: 8 80 | act_fn: 'gelu' 81 | generator_model_dir: !ref 82 | 83 | hift: !new:inspiremusic.hifigan.generator.HiFTGenerator 84 | in_channels: 80 85 | base_channels: 512 86 | nb_harmonics: 8 87 | sampling_rate: !ref 88 | nsf_alpha: 0.1 89 | nsf_sigma: 0.003 90 | nsf_voiced_threshold: 10 91 | upsample_rates: [8, 8] 92 | upsample_kernel_sizes: [16, 16] 93 | istft_params: 94 | n_fft: 16 95 | hop_len: 4 96 | resblock_kernel_sizes: [3, 7, 11] 97 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 98 | source_resblock_kernel_sizes: [7, 11] 99 | source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]] 100 | lrelu_slope: 0.1 101 | audio_limit: 0.99 102 | f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor 103 | num_class: 1 104 | in_channels: 80 105 | cond_channels: 512 106 | 107 | wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator 108 | 109 | # processor functions 110 | parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener 111 | get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer 112 | tokenizer_path: !ref 113 | tokenizer_name: "qwen-2.5" 114 | allowed_special: 'all' 115 | tokenize: !name:inspiremusic.dataset.processor.tokenize 116 | get_tokenizer: !ref 117 | allowed_special: !ref 118 | filter: !name:inspiremusic.dataset.processor.filter 119 | max_length: 28000 120 | min_length: 0 121 | token_max_length: 200 122 | token_min_length: 1 123 | resample: !name:inspiremusic.dataset.processor.resample 124 | resample_rate: !ref 125 | feat_extractor: !name:matcha.utils.audio.mel_spectrogram 126 | n_fft: 1024 127 | num_mels: 128 128 | sampling_rate: !ref 129 | hop_size: 256 130 | win_size: 1024 131 | fmin: 0 132 | fmax: 24000 133 | center: False 
134 | compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank 135 | feat_extractor: !ref 136 | parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding 137 | normalize: True 138 | shuffle: !name:inspiremusic.dataset.processor.shuffle 139 | shuffle_size: 1000 140 | sort: !name:inspiremusic.dataset.processor.sort 141 | sort_size: 500 # sort_size should be less than shuffle_size 142 | batch: !name:inspiremusic.dataset.processor.batch 143 | batch_type: 'dynamic' 144 | max_frames_in_batch: 10000 # llm 12000 145 | padding: !name:inspiremusic.dataset.processor.padding 146 | 147 | # dataset processor pipeline 148 | data_pipeline: [ 149 | !ref , 150 | !ref , 151 | !ref , 152 | !ref , 153 | !ref , 154 | !ref , 155 | !ref , 156 | ] 157 | 158 | 159 | # train conf 160 | train_conf: 161 | optim: adam 162 | optim_conf: 163 | lr: 0.0001 # change to 0.001 if you want to train flow from scratch 164 | scheduler: warmuplr 165 | scheduler_conf: 166 | warmup_steps: 5000 167 | max_epoch: 200 168 | grad_clip: 5 169 | accum_grad: 2 170 | log_interval: 100 171 | save_per_step: 500 172 | -------------------------------------------------------------------------------- /examples/music_generation/conf/inspiremusic_1.5b_long.yaml: -------------------------------------------------------------------------------- 1 | # set random seed, so that you may reproduce your result. 2 | __set_seed1: !apply:random.seed [1024] 3 | __set_seed2: !apply:numpy.random.seed [1024] 4 | __set_seed3: !apply:torch.manual_seed [1024] 5 | __set_seed4: !apply:torch.cuda.manual_seed_all [1024] 6 | 7 | # fixed params 8 | sample_rate: 24000 9 | text_encoder_input_size: 512 10 | llm_input_size: 1536 11 | llm_output_size: 1536 12 | 13 | basemodel_path: '../../pretrained_models/InspireMusic-1.5B-Long/' 14 | generator_path: '../../pretrained_models/InspireMusic-1.5B-Long/music_tokenizer' 15 | 16 | # model params 17 | # for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. 18 | # for system/third_party class/function, we do not require this. 
19 | llm: !new:inspiremusic.llm.llm.LLM 20 | text_encoder_input_size: !ref 21 | llm_input_size: !ref 22 | llm_output_size: !ref 23 | audio_token_size: 4096 24 | length_normalized_loss: True 25 | lsm_weight: 0 26 | text_encoder_conf: 27 | name: "none" 28 | llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder 29 | input_size: !ref 30 | pretrain_path: !ref 31 | 32 | sampling: !name:inspiremusic.utils.common.topk_sampling 33 | top_k: 350 34 | train_cfg_ratio: 0.2 35 | infer_cfg_ratio: 3.0 36 | flow: !new:inspiremusic.flow.flow.MaskedDiff 37 | input_size: 256 38 | output_size: 80 39 | output_type: 'mel' 40 | vocab_size: 4096 41 | input_frame_rate: 75 42 | only_mask_loss: True 43 | encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder 44 | output_size: 512 45 | attention_heads: 4 46 | linear_units: 1024 47 | num_blocks: 3 48 | dropout_rate: 0.1 49 | positional_dropout_rate: 0.1 50 | attention_dropout_rate: 0.1 51 | normalize_before: True 52 | input_layer: 'linear' 53 | pos_enc_layer_type: 'rel_pos_espnet' 54 | selfattention_layer_type: 'rel_selfattn' 55 | input_size: 256 56 | use_cnn_module: False 57 | macaron_style: False 58 | length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator 59 | channels: 512 60 | sampling_ratios: [1, 1, 1, 1] 61 | decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM 62 | in_channels: 240 63 | cfm_params: !new:omegaconf.DictConfig 64 | content: 65 | sigma_min: 1e-06 66 | solver: 'euler' 67 | t_scheduler: 'cosine' 68 | training_cfg_rate: 0.2 69 | inference_cfg_rate: 0.7 70 | reg_loss_type: 'l1' 71 | estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder 72 | in_channels: 1024 73 | out_channels: 512 74 | channels: [256, 256] 75 | dropout: 0.0 76 | attention_head_dim: 64 77 | n_blocks: 4 78 | num_mid_blocks: 8 79 | num_heads: 8 80 | act_fn: 'gelu' 81 | generator_model_dir: !ref 82 | 83 | hift: !new:inspiremusic.hifigan.generator.HiFTGenerator 84 | in_channels: 80 85 | base_channels: 512 86 | nb_harmonics: 8 87 | sampling_rate: !ref 88 | nsf_alpha: 0.1 89 | nsf_sigma: 0.003 90 | nsf_voiced_threshold: 10 91 | upsample_rates: [8, 8] 92 | upsample_kernel_sizes: [16, 16] 93 | istft_params: 94 | n_fft: 16 95 | hop_len: 4 96 | resblock_kernel_sizes: [3, 7, 11] 97 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 98 | source_resblock_kernel_sizes: [7, 11] 99 | source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]] 100 | lrelu_slope: 0.1 101 | audio_limit: 0.99 102 | f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor 103 | num_class: 1 104 | in_channels: 80 105 | cond_channels: 512 106 | 107 | wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator 108 | 109 | # processor functions 110 | parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener 111 | get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer 112 | tokenizer_path: !ref 113 | tokenizer_name: "qwen-2.5" 114 | allowed_special: 'all' 115 | tokenize: !name:inspiremusic.dataset.processor.tokenize 116 | get_tokenizer: !ref 117 | allowed_special: !ref 118 | filter: !name:inspiremusic.dataset.processor.filter 119 | max_length: 28000 120 | min_length: 0 121 | token_max_length: 200 122 | token_min_length: 1 123 | resample: !name:inspiremusic.dataset.processor.resample 124 | resample_rate: !ref 125 | feat_extractor: !name:matcha.utils.audio.mel_spectrogram 126 | n_fft: 1024 127 | num_mels: 128 128 | sampling_rate: !ref 129 | hop_size: 256 130 | win_size: 1024 131 | fmin: 0 132 | fmax: 24000 133 | center: False 
134 | compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank 135 | feat_extractor: !ref 136 | parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding 137 | normalize: True 138 | shuffle: !name:inspiremusic.dataset.processor.shuffle 139 | shuffle_size: 1000 140 | sort: !name:inspiremusic.dataset.processor.sort 141 | sort_size: 500 # sort_size should be less than shuffle_size 142 | batch: !name:inspiremusic.dataset.processor.batch 143 | batch_type: 'dynamic' 144 | max_frames_in_batch: 10000 # llm 12000 145 | padding: !name:inspiremusic.dataset.processor.padding 146 | 147 | # dataset processor pipeline 148 | data_pipeline: [ 149 | !ref , 150 | !ref , 151 | !ref , 152 | !ref , 153 | !ref , 154 | !ref , 155 | !ref , 156 | ] 157 | 158 | 159 | # train conf 160 | train_conf: 161 | optim: adam 162 | optim_conf: 163 | lr: 0.0001 # change to 0.001 if you want to train flow from scratch 164 | scheduler: warmuplr 165 | scheduler_conf: 166 | warmup_steps: 5000 167 | max_epoch: 200 168 | grad_clip: 5 169 | accum_grad: 2 170 | log_interval: 100 171 | save_per_step: 500 172 | -------------------------------------------------------------------------------- /examples/music_generation/conf/inspiremusic_1.5b_long_infer.yaml: -------------------------------------------------------------------------------- 1 | # set random seed, so that you may reproduce your result. 2 | __set_seed1: !apply:random.seed [1024] 3 | __set_seed2: !apply:numpy.random.seed [1024] 4 | __set_seed3: !apply:torch.manual_seed [1024] 5 | __set_seed4: !apply:torch.cuda.manual_seed_all [1024] 6 | 7 | # fixed params 8 | sample_rate: 24000 9 | text_encoder_input_size: 512 10 | llm_input_size: 1536 11 | llm_output_size: 1536 12 | 13 | basemodel_path: '../../pretrained_models/InspireMusic-1.5B-Long/' 14 | generator_path: '../../pretrained_models/InspireMusic-1.5B-Long/music_tokenizer' 15 | 16 | # model params 17 | # for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. 18 | # for system/third_party class/function, we do not require this. 
19 | llm: !new:inspiremusic.llm.llm.LLM 20 | text_encoder_input_size: !ref 21 | llm_input_size: !ref 22 | llm_output_size: !ref 23 | audio_token_size: 4096 24 | length_normalized_loss: True 25 | lsm_weight: 0 26 | text_encoder_conf: 27 | name: "none" 28 | llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder 29 | input_size: !ref 30 | pretrain_path: !ref 31 | 32 | sampling: !name:inspiremusic.utils.common.topk_sampling 33 | top_k: 350 34 | train_cfg_ratio: 0.2 35 | infer_cfg_ratio: 3.0 36 | flow: !new:inspiremusic.flow.flow.MaskedDiff 37 | input_size: 256 38 | output_size: 80 39 | output_type: 'mel' 40 | vocab_size: 4096 41 | input_frame_rate: 75 42 | only_mask_loss: True 43 | encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder 44 | output_size: 512 45 | attention_heads: 4 46 | linear_units: 1024 47 | num_blocks: 3 48 | dropout_rate: 0.1 49 | positional_dropout_rate: 0.1 50 | attention_dropout_rate: 0.1 51 | normalize_before: True 52 | input_layer: 'linear' 53 | pos_enc_layer_type: 'rel_pos_espnet' 54 | selfattention_layer_type: 'rel_selfattn' 55 | input_size: 256 56 | use_cnn_module: False 57 | macaron_style: False 58 | length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator 59 | channels: 512 60 | sampling_ratios: [1, 1, 1, 1] 61 | decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM 62 | in_channels: 240 63 | cfm_params: !new:omegaconf.DictConfig 64 | content: 65 | sigma_min: 1e-06 66 | solver: 'euler' 67 | t_scheduler: 'cosine' 68 | training_cfg_rate: 0.2 69 | inference_cfg_rate: 0.7 70 | reg_loss_type: 'l1' 71 | estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder 72 | in_channels: 1024 73 | out_channels: 512 74 | channels: [256, 256] 75 | dropout: 0.0 76 | attention_head_dim: 64 77 | n_blocks: 4 78 | num_mid_blocks: 8 79 | num_heads: 8 80 | act_fn: 'gelu' 81 | generator_model_dir: !ref 82 | 83 | hift: !new:inspiremusic.hifigan.generator.HiFTGenerator 84 | in_channels: 80 85 | base_channels: 512 86 | nb_harmonics: 8 87 | sampling_rate: !ref 88 | nsf_alpha: 0.1 89 | nsf_sigma: 0.003 90 | nsf_voiced_threshold: 10 91 | upsample_rates: [8, 8] 92 | upsample_kernel_sizes: [16, 16] 93 | istft_params: 94 | n_fft: 16 95 | hop_len: 4 96 | resblock_kernel_sizes: [3, 7, 11] 97 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 98 | source_resblock_kernel_sizes: [7, 11] 99 | source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]] 100 | lrelu_slope: 0.1 101 | audio_limit: 0.99 102 | f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor 103 | num_class: 1 104 | in_channels: 80 105 | cond_channels: 512 106 | 107 | wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator 108 | 109 | # processor functions 110 | parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener 111 | get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer 112 | tokenizer_path: !ref 113 | tokenizer_name: "qwen-2.5" 114 | allowed_special: 'all' 115 | tokenize: !name:inspiremusic.dataset.processor.tokenize 116 | get_tokenizer: !ref 117 | allowed_special: !ref 118 | filter: !name:inspiremusic.dataset.processor.filter 119 | max_length: 28000 120 | min_length: 0 121 | token_max_length: 200 122 | token_min_length: 1 123 | resample: !name:inspiremusic.dataset.processor.resample 124 | resample_rate: !ref 125 | feat_extractor: !name:matcha.utils.audio.mel_spectrogram 126 | n_fft: 1024 127 | num_mels: 128 128 | sampling_rate: !ref 129 | hop_size: 256 130 | win_size: 1024 131 | fmin: 0 132 | fmax: 24000 133 | center: False 
134 | compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank 135 | feat_extractor: !ref 136 | parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding 137 | normalize: True 138 | shuffle: !name:inspiremusic.dataset.processor.shuffle 139 | shuffle_size: 1000 140 | sort: !name:inspiremusic.dataset.processor.sort 141 | sort_size: 500 # sort_size should be less than shuffle_size 142 | batch: !name:inspiremusic.dataset.processor.batch 143 | batch_type: 'static' 144 | batch_size: 8 145 | padding: !name:inspiremusic.dataset.processor.padding 146 | 147 | # dataset processor pipeline 148 | data_pipeline: [ 149 | !ref , 150 | !ref , 151 | !ref , 152 | !ref , 153 | !ref , 154 | !ref , 155 | !ref , 156 | ] 157 | 158 | 159 | # train conf 160 | train_conf: 161 | optim: adam 162 | optim_conf: 163 | lr: 0.0001 # change to 0.001 if you want to train flow from scratch 164 | scheduler: warmuplr 165 | scheduler_conf: 166 | warmup_steps: 5000 167 | max_epoch: 200 168 | grad_clip: 5 169 | accum_grad: 2 170 | log_interval: 100 171 | save_per_step: 500 172 | -------------------------------------------------------------------------------- /examples/music_generation/conf/inspiremusic_base.yaml: -------------------------------------------------------------------------------- 1 | # set random seed, so that you may reproduce your result. 2 | __set_seed1: !apply:random.seed [1024] 3 | __set_seed2: !apply:numpy.random.seed [1024] 4 | __set_seed3: !apply:torch.manual_seed [1024] 5 | __set_seed4: !apply:torch.cuda.manual_seed_all [1024] 6 | 7 | # fixed params 8 | sample_rate: 24000 9 | target_sample_rate: 48000 10 | text_encoder_input_size: 512 11 | llm_input_size: 896 12 | llm_output_size: 896 13 | 14 | basemodel_path: '../../pretrained_models/InspireMusic-Base/' 15 | generator_path: '../../pretrained_models/InspireMusic-Base/music_tokenizer' 16 | 17 | # model params 18 | # for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. 19 | # for system/third_party class/function, we do not require this. 
20 | llm: !new:inspiremusic.llm.llm.LLM 21 | text_encoder_input_size: !ref 22 | llm_input_size: !ref 23 | llm_output_size: !ref 24 | audio_token_size: 4096 25 | length_normalized_loss: True 26 | lsm_weight: 0 27 | text_encoder_conf: 28 | name: "none" 29 | llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder 30 | input_size: !ref 31 | pretrain_path: !ref 32 | 33 | sampling: !name:inspiremusic.utils.common.topk_sampling 34 | top_k: 350 35 | train_cfg_ratio: 0.2 36 | infer_cfg_ratio: 3.0 37 | flow: !new:inspiremusic.flow.flow.MaskedDiff 38 | input_size: 256 39 | output_size: 80 40 | output_type: 'mel' 41 | vocab_size: 4096 42 | input_frame_rate: 75 43 | only_mask_loss: True 44 | encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder 45 | output_size: 512 46 | attention_heads: 4 47 | linear_units: 1024 48 | num_blocks: 3 49 | dropout_rate: 0.1 50 | positional_dropout_rate: 0.1 51 | attention_dropout_rate: 0.1 52 | normalize_before: True 53 | input_layer: 'linear' 54 | pos_enc_layer_type: 'rel_pos_espnet' 55 | selfattention_layer_type: 'rel_selfattn' 56 | input_size: 256 57 | use_cnn_module: False 58 | macaron_style: False 59 | length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator 60 | channels: 512 61 | sampling_ratios: [1, 1, 1, 1] 62 | decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM 63 | in_channels: 240 64 | cfm_params: !new:omegaconf.DictConfig 65 | content: 66 | sigma_min: 1e-06 67 | solver: 'euler' 68 | t_scheduler: 'cosine' 69 | training_cfg_rate: 0.2 70 | inference_cfg_rate: 0.7 71 | reg_loss_type: 'l1' 72 | estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder 73 | in_channels: 1024 74 | out_channels: 512 75 | channels: [256, 256] 76 | dropout: 0.0 77 | attention_head_dim: 64 78 | n_blocks: 4 79 | num_mid_blocks: 8 80 | num_heads: 8 81 | act_fn: 'gelu' 82 | generator_model_dir: !ref 83 | 84 | hift: !new:inspiremusic.hifigan.generator.HiFTGenerator 85 | in_channels: 80 86 | base_channels: 512 87 | nb_harmonics: 8 88 | sampling_rate: !ref 89 | nsf_alpha: 0.1 90 | nsf_sigma: 0.003 91 | nsf_voiced_threshold: 10 92 | upsample_rates: [8, 8] 93 | upsample_kernel_sizes: [16, 16] 94 | istft_params: 95 | n_fft: 16 96 | hop_len: 4 97 | resblock_kernel_sizes: [3, 7, 11] 98 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 99 | source_resblock_kernel_sizes: [7, 11] 100 | source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]] 101 | lrelu_slope: 0.1 102 | audio_limit: 0.99 103 | f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor 104 | num_class: 1 105 | in_channels: 80 106 | cond_channels: 512 107 | 108 | wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator 109 | 110 | # processor functions 111 | parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener 112 | get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer 113 | tokenizer_path: !ref 114 | tokenizer_name: "qwen-2.0" 115 | allowed_special: 'all' 116 | tokenize: !name:inspiremusic.dataset.processor.tokenize 117 | get_tokenizer: !ref 118 | allowed_special: !ref 119 | filter: !name:inspiremusic.dataset.processor.filter 120 | max_length: 20000 121 | min_length: 1 122 | token_max_length: 200 123 | token_min_length: 1 124 | max_acoustic_length: 20000 125 | min_acoustic_length: 1800 126 | mode: 'train_flow' 127 | 128 | resample: !name:inspiremusic.dataset.processor.resample 129 | resample_rate: !ref 130 | 131 | feat_extractor: !name:matcha.utils.audio.mel_spectrogram 132 | n_fft: 1024 133 | num_mels: 128 134 | 
sampling_rate: !ref 135 | hop_size: 256 136 | win_size: 1024 137 | fmin: 0 138 | fmax: 24000 139 | center: False 140 | compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank 141 | feat_extractor: !ref 142 | parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding 143 | normalize: True 144 | shuffle: !name:inspiremusic.dataset.processor.shuffle 145 | shuffle_size: 1000 146 | sort: !name:inspiremusic.dataset.processor.sort 147 | sort_size: 500 # sort_size should be less than shuffle_size 148 | batch: !name:inspiremusic.dataset.processor.batch 149 | batch_type: 'dynamic' 150 | max_frames_in_batch: 15500 # llm 12000 151 | # batch_type: 'static' 152 | # batch_size: 2 # llm 12000 153 | padding: !name:inspiremusic.dataset.processor.padding 154 | mode: 'train' 155 | 156 | # dataset processor pipeline 157 | data_pipeline: [ 158 | !ref , 159 | !ref , 160 | !ref , 161 | !ref , 162 | !ref , 163 | !ref , 164 | !ref , 165 | ] 166 | 167 | 168 | # train conf 169 | train_conf: 170 | optim: adam 171 | optim_conf: 172 | lr: 0.0001 # change to 0.001 if you want to train flow from scratch 173 | scheduler: warmuplr 174 | scheduler_conf: 175 | warmup_steps: 500 176 | max_epoch: 200 177 | grad_clip: 5 178 | accum_grad: 2 179 | log_interval: 100 180 | save_per_step: 500 181 | -------------------------------------------------------------------------------- /examples/music_generation/conf/inspiremusic_infer.yaml: -------------------------------------------------------------------------------- 1 | # set random seed, so that you may reproduce your result. 2 | __set_seed1: !apply:random.seed [1024] 3 | __set_seed2: !apply:numpy.random.seed [1024] 4 | __set_seed3: !apply:torch.manual_seed [1024] 5 | __set_seed4: !apply:torch.cuda.manual_seed_all [1024] 6 | 7 | # fixed params 8 | sample_rate: 24000 9 | target_sample_rate: 48000 10 | text_encoder_input_size: 512 11 | llm_input_size: 896 12 | llm_output_size: 896 13 | 14 | basemodel_path: '../../pretrained_models/InspireMusic-Base/' 15 | generator_path: '../../pretrained_models/InspireMusic-Base/music_tokenizer' 16 | 17 | # model params 18 | # for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. 19 | # for system/third_party class/function, we do not require this. 
20 | llm: !new:inspiremusic.llm.llm.LLM 21 | text_encoder_input_size: !ref 22 | llm_input_size: !ref 23 | llm_output_size: !ref 24 | audio_token_size: 4096 25 | length_normalized_loss: True 26 | lsm_weight: 0 27 | text_encoder_conf: 28 | name: "none" 29 | llm: !new:inspiremusic.transformer.qwen_encoder.QwenEmbeddingEncoder 30 | input_size: !ref 31 | pretrain_path: !ref 32 | 33 | sampling: !name:inspiremusic.utils.common.topk_sampling 34 | top_k: 350 35 | 36 | train_cfg_ratio: 0.2 37 | infer_cfg_ratio: 3.0 38 | flow: !new:inspiremusic.flow.flow.MaskedDiff 39 | input_size: 256 40 | output_size: 80 41 | output_type: 'mel' 42 | vocab_size: 4096 43 | input_frame_rate: 75 44 | only_mask_loss: True 45 | encoder: !new:inspiremusic.transformer.encoder.ConformerEncoder 46 | output_size: 512 47 | attention_heads: 4 48 | linear_units: 1024 49 | num_blocks: 3 50 | dropout_rate: 0.1 51 | positional_dropout_rate: 0.1 52 | attention_dropout_rate: 0.1 53 | normalize_before: True 54 | input_layer: 'linear' 55 | pos_enc_layer_type: 'rel_pos_espnet' 56 | selfattention_layer_type: 'rel_selfattn' 57 | input_size: 256 58 | use_cnn_module: False 59 | macaron_style: False 60 | length_regulator: !new:inspiremusic.flow.length_regulator.InterpolateRegulator 61 | channels: 512 62 | sampling_ratios: [1, 1, 1, 1] 63 | decoder: !new:inspiremusic.flow.flow_matching.ConditionalCFM 64 | in_channels: 240 65 | cfm_params: !new:omegaconf.DictConfig 66 | content: 67 | sigma_min: 1e-06 68 | solver: 'euler' 69 | t_scheduler: 'cosine' 70 | training_cfg_rate: 0.2 71 | inference_cfg_rate: 0.7 72 | reg_loss_type: 'l1' 73 | estimator: !new:inspiremusic.flow.decoder.ConditionalDecoder 74 | in_channels: 1024 75 | out_channels: 512 76 | channels: [256, 256] 77 | dropout: 0.0 78 | attention_head_dim: 64 79 | n_blocks: 4 80 | num_mid_blocks: 8 81 | num_heads: 8 82 | act_fn: 'gelu' 83 | generator_model_dir: !ref 84 | 85 | hift: !new:inspiremusic.hifigan.generator.HiFTGenerator 86 | in_channels: 80 87 | base_channels: 512 88 | nb_harmonics: 8 89 | sampling_rate: !ref 90 | nsf_alpha: 0.1 91 | nsf_sigma: 0.003 92 | nsf_voiced_threshold: 10 93 | upsample_rates: [8, 8] 94 | upsample_kernel_sizes: [16, 16] 95 | istft_params: 96 | n_fft: 16 97 | hop_len: 4 98 | resblock_kernel_sizes: [3, 7, 11] 99 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 100 | source_resblock_kernel_sizes: [7, 11] 101 | source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]] 102 | lrelu_slope: 0.1 103 | audio_limit: 0.99 104 | f0_predictor: !new:inspiremusic.hifigan.f0_predictor.ConvRNNF0Predictor 105 | num_class: 1 106 | in_channels: 80 107 | cond_channels: 512 108 | 109 | wavtokenizer: !new:inspiremusic.hifigan.generator.HiFTGenerator 110 | 111 | # processor functions 112 | parquet_opener: !name:inspiremusic.dataset.processor.parquet_opener 113 | get_tokenizer: !name:inspiremusic.text.tokenizer.get_tokenizer 114 | tokenizer_path: !ref 115 | tokenizer_name: "qwen-2.0" 116 | allowed_special: 'all' 117 | tokenize: !name:inspiremusic.dataset.processor.tokenize 118 | get_tokenizer: !ref 119 | allowed_special: !ref 120 | filter: !name:inspiremusic.dataset.processor.filter 121 | max_length: 20000 122 | min_length: 1 123 | token_max_length: 200 124 | token_min_length: 1 125 | max_acoustic_length: 20000 126 | min_acoustic_length: 1800 127 | mode: 'train_flow' 128 | 129 | resample: !name:inspiremusic.dataset.processor.resample 130 | resample_rate: !ref 131 | 132 | feat_extractor: !name:matcha.utils.audio.mel_spectrogram 133 | n_fft: 1024 134 | num_mels: 128 135 | 
sampling_rate: !ref 136 | hop_size: 256 137 | win_size: 1024 138 | fmin: 0 139 | fmax: 24000 140 | center: False 141 | compute_fbank: !name:inspiremusic.dataset.processor.compute_fbank 142 | feat_extractor: !ref 143 | parse_embedding: !name:inspiremusic.dataset.processor.parse_embedding 144 | normalize: True 145 | shuffle: !name:inspiremusic.dataset.processor.shuffle 146 | shuffle_size: 1000 147 | sort: !name:inspiremusic.dataset.processor.sort 148 | sort_size: 500 # sort_size should be less than shuffle_size 149 | batch: !name:inspiremusic.dataset.processor.batch 150 | batch_type: 'static' 151 | batch_size: 16 152 | padding: !name:inspiremusic.dataset.processor.padding 153 | mode: 'train' 154 | 155 | # dataset processor pipeline 156 | data_pipeline: [ 157 | !ref , 158 | !ref , 159 | !ref , 160 | !ref , 161 | !ref , 162 | !ref , 163 | !ref , 164 | ] 165 | 166 | 167 | # train conf 168 | train_conf: 169 | optim: adam 170 | optim_conf: 171 | lr: 0.0001 # change to 0.001 if you want to train flow from scratch 172 | scheduler: warmuplr 173 | scheduler_conf: 174 | warmup_steps: 500 175 | max_epoch: 200 176 | grad_clip: 5 177 | accum_grad: 2 178 | log_interval: 100 179 | save_per_step: 500 180 | -------------------------------------------------------------------------------- /examples/music_generation/data/dataset_example/text: -------------------------------------------------------------------------------- 1 | electro_1 <|90.00|><|chorus|><|A dynamic blend of electronic beats and drum and bass rhythms.|><|120.00|> 2 | jazz_1 <|30.00|><|verse1|><|A smooth blend of contemporary jazz with soulful undertones, evoke a relaxed and sophisticated atmosphere.|><|60.00|> 3 | instrumental_1 <|0.00|><|intro|><|A soothing piano instrumental with a melancholic feel, evoke a sense of longing, complemented by light and serene instrumental solos.|><|30.00|> -------------------------------------------------------------------------------- /examples/music_generation/data/dataset_example/wav.scp: -------------------------------------------------------------------------------- 1 | electro_1 dataset/example/electro_1.wav 2 | jazz_1 dataset/example/jazz_1.wav 3 | instrumental_1 dataset/example/instrumental_1.wav -------------------------------------------------------------------------------- /examples/music_generation/data/samples/parquet/data.list: -------------------------------------------------------------------------------- 1 | data/samples/parquet/parquet_000000000.tar 2 | -------------------------------------------------------------------------------- /examples/music_generation/data/samples/parquet/parquet_000000000.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/examples/music_generation/data/samples/parquet/parquet_000000000.tar -------------------------------------------------------------------------------- /examples/music_generation/data/samples/text: -------------------------------------------------------------------------------- 1 | 1 <|30.0|><|verse|><|Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.|><|60.0|> 2 | 2 <|0.0|><|intro|><|A delightful collection of classical keyboard music, purely instrumental, exuding a timeless and elegant charm.|><|30.0|> 3 | 3 <|120.0|><|chorus|><|The instrumental rap track exudes a classic boom bap vibe, characterized by its French hip-hop roots and a 
smooth, rhythmic flow.|><|150.0|> 4 | 4 <|300.0|><|outro|><|The music exudes a vibrant and sophisticated jazz ambiance, characterized by the rich, dynamic sounds of a big band ensemble. With instrumental purity and a touch of classical influence, it offers a captivating listening experience.|><|330.0|> -------------------------------------------------------------------------------- /examples/music_generation/infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Alibaba Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | . ./path.sh || exit 1; 16 | 17 | export TOKENIZERS_PARALLELISM=False 18 | 19 | model_name="InspireMusic-Base" 20 | pretrained_model_dir=../../pretrained_models/${model_name} 21 | dataset_name=samples 22 | 23 | # inference normal mode 24 | echo "Run inference." 25 | expr_name="inspiremusic_${dataset_name}" 26 | for task in 'text-to-music' 'continuation'; do 27 | python inspiremusic/bin/inference.py --task $task \ 28 | --gpu 0 \ 29 | --config conf/inspiremusic.yaml \ 30 | --prompt_data data/${dataset_name}/parquet/data.list \ 31 | --flow_model $pretrained_model_dir/flow.pt \ 32 | --llm_model $pretrained_model_dir/llm.pt \ 33 | --music_tokenizer $pretrained_model_dir/music_tokenizer \ 34 | --wavtokenizer $pretrained_model_dir/wavtokenizer \ 35 | --chorus verse \ 36 | --output_sample_rate 48000 \ 37 | --min_generate_audio_seconds 5.0 \ 38 | --max_generate_audio_seconds 30.0 \ 39 | --result_dir `pwd`/exp/${model_name}/${task}_${expr_name} 40 | # if use InspireMusic-xxxx-24kHz model, please set output sample rate to 24kHz 41 | # --output_sample_rate 24000 \ 42 | # use fast inference mode 43 | # --fast # fast mode without flow matching 44 | echo `pwd`/exp/${model_name}/${task}_${expr_name} 45 | done 46 | -------------------------------------------------------------------------------- /examples/music_generation/infer_1.5b_long.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Alibaba Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | . 
./path.sh || exit 1;
16 | 
17 | convert_path() { 
18 |     if command -v cygpath &> /dev/null; then 
19 |         cygpath -w "$1" 
20 |     else 
21 |         echo "$1" 
22 |     fi 
23 | } 
24 | 
25 | # Detect the operating system 
26 | OS="$(uname -s)" 
27 | 
28 | export TOKENIZERS_PARALLELISM=False 
29 | 
30 | model_name="InspireMusic-1.5B-Long" 
31 | pretrained_model_dir=../../pretrained_models/${model_name} 
32 | dataset_name=samples 
33 | expr_name="inspiremusic_${dataset_name}" 
34 | 
35 | echo "Run inference." 
36 | 
37 | case "$OS" in 
38 |   Linux*|Darwin*) 
39 |     echo "Running on Unix-like OS: $OS" 
40 |     # Use Unix-style paths 
41 |     # inference normal mode 
42 |     for task in 'text-to-music' 'continuation'; do 
43 |       python inspiremusic/bin/inference.py --task $task \ 
44 |           --gpu 0 \ 
45 |           --config conf/inspiremusic_1.5b_long.yaml \ 
46 |           --prompt_data data/${dataset_name}/parquet/data.list \ 
47 |           --flow_model $pretrained_model_dir/flow.pt \ 
48 |           --llm_model $pretrained_model_dir/llm.pt \ 
49 |           --music_tokenizer $pretrained_model_dir/music_tokenizer \ 
50 |           --wavtokenizer $pretrained_model_dir/wavtokenizer \ 
51 |           --chorus default \ 
52 |           --output_sample_rate 48000 \ 
53 |           --min_generate_audio_seconds 5.0 \ 
54 |           --max_generate_audio_seconds 180.0 \ 
55 |           --result_dir `pwd`/exp/${model_name}/${task}_${expr_name} 
56 |       # if using an InspireMusic-xxxx-24kHz model, set the output sample rate to 24000 
57 |       # --output_sample_rate 24000 \ 
58 |       # use fast inference mode 
59 |       # --fast # fast mode without flow matching 
60 |       echo `pwd`/exp/${model_name}/${task}_${expr_name} 
61 |     done 
62 |     ;; 
63 |   CYGWIN*|MINGW*|MSYS*) 
64 |     echo "Running on Windows-like OS: $OS" 
65 |     # Forward slashes also work for Python on Windows; only the model dir is converted with cygpath 
66 |     pretrained_model_dir=$(convert_path "$pretrained_model_dir") 
67 |     # inference normal mode 
68 |     for task in 'text-to-music' 'continuation'; do 
69 |       python inspiremusic/bin/inference.py --task $task \ 
70 |           --gpu 0 \ 
71 |           --config conf/inspiremusic_1.5b_long.yaml \ 
72 |           --prompt_data data/${dataset_name}/parquet/data.list \ 
73 |           --flow_model $pretrained_model_dir/flow.pt \ 
74 |           --llm_model $pretrained_model_dir/llm.pt \ 
75 |           --music_tokenizer $pretrained_model_dir/music_tokenizer \ 
76 |           --wavtokenizer $pretrained_model_dir/wavtokenizer \ 
77 |           --chorus default \ 
78 |           --output_sample_rate 48000 \ 
79 |           --min_generate_audio_seconds 5.0 \ 
80 |           --max_generate_audio_seconds 180.0 \ 
81 |           --result_dir exp/${model_name}/${task}_${expr_name} 
82 |       # if using an InspireMusic-xxxx-24kHz model, set the output sample rate to 24000 
83 |       # --output_sample_rate 24000 \ 
84 |       # use fast inference mode 
85 |       # --fast # fast mode without flow matching 
86 |       echo exp/${model_name}/${task}_${expr_name} 
87 |     done 
88 |     ;; 
89 |   *) 
90 |     echo "Unknown OS: $OS" 
91 |     exit 1 
92 |     ;; 
93 | esac 
94 | 
-------------------------------------------------------------------------------- /examples/music_generation/inspiremusic: -------------------------------------------------------------------------------- 
1 | ../../inspiremusic/ 
-------------------------------------------------------------------------------- /examples/music_generation/local/download_and_untar.sh: -------------------------------------------------------------------------------- 
1 | #!/bin/bash 
2 | 
3 | # Copyright 2014 Johns Hopkins University (author: Daniel Povey) 
4 | # Apache 2.0 
5 | 
6 | remove_archive=false 
7 | 
8 | if [ "$1" == --remove-archive ]; then 
9 |   remove_archive=true 
10 |   shift 
11 | fi 
12 | 
13 | if [ $# -ne 3 ]; then 
14 |   echo "Usage: $0 [--remove-archive] " 
15 |   echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11
dev-clean" 16 | echo "With --remove-archive it will remove the archive after successfully un-tarring it." 17 | echo " can be one of: dev-clean, test-clean, dev-other, test-other," 18 | echo " train-clean-100, train-clean-360, train-other-500." 19 | exit 1 20 | fi 21 | 22 | data=$1 23 | url=$2 24 | part=$3 25 | 26 | if [ ! -d "$data" ]; then 27 | echo "$0: no such directory $data" 28 | exit 1 29 | fi 30 | 31 | part_ok=false 32 | list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500" 33 | for x in $list; do 34 | if [ "$part" == $x ]; then part_ok=true; fi 35 | done 36 | if ! $part_ok; then 37 | echo "$0: expected to be one of $list, but got '$part'" 38 | exit 1 39 | fi 40 | 41 | if [ -z "$url" ]; then 42 | echo "$0: empty URL base." 43 | exit 1 44 | fi 45 | 46 | if [ -f $data/LibriTTS/$part/.complete ]; then 47 | echo "$0: data part $part was already successfully extracted, nothing to do." 48 | exit 0 49 | fi 50 | 51 | 52 | # sizes of the archive files in bytes. This is some older versions. 53 | sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128" 54 | # sizes_new is the archive file sizes of the final release. Some of these sizes are of 55 | # things we probably won't download. 56 | sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606" 57 | 58 | if [ -f $data/$part.tar.gz ]; then 59 | size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}') 60 | size_ok=false 61 | for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done 62 | if ! $size_ok; then 63 | echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size" 64 | echo "does not equal the size of one of the archives." 65 | rm $data/$part.tar.gz 66 | else 67 | echo "$data/$part.tar.gz exists and appears to be complete." 68 | fi 69 | fi 70 | 71 | if [ ! -f $data/$part.tar.gz ]; then 72 | if ! which wget >/dev/null; then 73 | echo "$0: wget is not installed." 74 | exit 1 75 | fi 76 | full_url=$url/$part.tar.gz 77 | echo "$0: downloading data from $full_url. This may take some time, please be patient." 78 | 79 | if ! wget -P $data --no-check-certificate $full_url; then 80 | echo "$0: error executing wget $full_url" 81 | exit 1 82 | fi 83 | fi 84 | 85 | if ! tar -C $data -xvzf $data/$part.tar.gz; then 86 | echo "$0: error un-tarring archive $data/$part.tar.gz" 87 | exit 1 88 | fi 89 | 90 | touch $data/LibriTTS/$part/.complete 91 | 92 | echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz" 93 | 94 | if $remove_archive; then 95 | echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied." 96 | rm $data/$part.tar.gz 97 | fi 98 | -------------------------------------------------------------------------------- /examples/music_generation/local/prepare_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc All Rights Reserved. 
2 | 
3 | import argparse 
4 | import logging 
5 | import glob 
6 | import os 
7 | from tqdm import tqdm 
8 | 
9 | 
10 | logger = logging.getLogger() 
11 | 
12 | 
13 | def main(): 
14 |     wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir))) 
15 | 
16 |     utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {} 
17 |     for wav in tqdm(wavs): 
18 |         txt = wav.replace('.wav', '.normalized.txt') 
19 |         if not os.path.exists(txt): 
20 |             logger.warning('{} does not exist'.format(txt)) 
21 |             continue 
22 |         with open(txt) as f: 
23 |             content = ''.join(l.replace('\n', '') for l in f.readlines()) 
24 |         utt = os.path.basename(wav).replace('.wav', '') 
25 |         spk = utt.split('_')[0] 
26 |         utt2wav[utt] = wav 
27 |         utt2text[utt] = content 
28 |         utt2spk[utt] = spk 
29 |         if spk not in spk2utt: 
30 |             spk2utt[spk] = [] 
31 |         spk2utt[spk].append(utt) 
32 | 
33 |     with open('{}/wav.scp'.format(args.des_dir), 'w') as f: 
34 |         for k, v in utt2wav.items(): 
35 |             f.write('{} {}\n'.format(k, v)) 
36 |     with open('{}/text'.format(args.des_dir), 'w') as f: 
37 |         for k, v in utt2text.items(): 
38 |             f.write('{} {}\n'.format(k, v)) 
39 |     with open('{}/utt2spk'.format(args.des_dir), 'w') as f: 
40 |         for k, v in utt2spk.items(): 
41 |             f.write('{} {}\n'.format(k, v)) 
42 |     with open('{}/spk2utt'.format(args.des_dir), 'w') as f: 
43 |         for k, v in spk2utt.items(): 
44 |             f.write('{} {}\n'.format(k, ' '.join(v))) 
45 |     return 
46 | 
47 | 
48 | if __name__ == "__main__": 
49 |     parser = argparse.ArgumentParser() 
50 |     parser.add_argument('--src_dir', 
51 |                         type=str) 
52 |     parser.add_argument('--des_dir', 
53 |                         type=str) 
54 |     args = parser.parse_args() 
55 |     main() 
56 | 
-------------------------------------------------------------------------------- /examples/music_generation/path.sh: -------------------------------------------------------------------------------- 
1 | # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C 
2 | export PYTHONIOENCODING=UTF-8 
3 | export PYTHONPATH=../../:../../third_party/Matcha-TTS:$PYTHONPATH 
4 | 
5 | #!/bin/bash 
6 | export MAIN_ROOT=`realpath ${PWD}/../../../` 
7 | 
8 | export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH} 
9 | export BIN_DIR=${MAIN_ROOT}/inspiremusic 
10 | 
-------------------------------------------------------------------------------- /examples/music_generation/tools: -------------------------------------------------------------------------------- 
1 | ../../tools 
-------------------------------------------------------------------------------- /inspiremusic/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/inspiremusic/__init__.py 
-------------------------------------------------------------------------------- /inspiremusic/bin/export_jit.py: -------------------------------------------------------------------------------- 
1 | # Copyright (c) 2024 Alibaba Inc 
2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 
4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 
6 | # 
7 | #     http://www.apache.org/licenses/LICENSE-2.0 
8 | # 
9 | # Unless required by applicable law or agreed to in writing, software 
10 | # distributed under the License is distributed on an "AS IS" BASIS, 
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 
13 | # limitations under the License.
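#
# This tool freezes parts of a pretrained InspireMusic model to TorchScript.
# The resulting archives are plain torch.jit modules; a minimal sketch of
# loading them back (file names follow the save() calls below, the model_dir
# value is illustrative):
#
#   import torch
#   model_dir = 'pretrained_models/InspireMusic'
#   text_encoder = torch.jit.load('{}/llm.text_encoder.fp16.zip'.format(model_dir))
#   flow_encoder = torch.jit.load('{}/flow.encoder.fp32.zip'.format(model_dir))
#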
14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | import os 21 | import sys 22 | import torch 23 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 24 | sys.path.append('{}/../..'.format(ROOT_DIR)) 25 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) 26 | from inspiremusic.cli.inspiremusic import InspireMusic 27 | 28 | 29 | def get_args(): 30 | parser = argparse.ArgumentParser(description='export your model for deployment') 31 | parser.add_argument('--model_dir', 32 | type=str, 33 | default='pretrained_models/InspireMusic', 34 | help='local path') 35 | args = parser.parse_args() 36 | print(args) 37 | return args 38 | 39 | 40 | def main(): 41 | args = get_args() 42 | logging.basicConfig(level=logging.DEBUG, 43 | format='%(asctime)s %(levelname)s %(message)s') 44 | 45 | torch._C._jit_set_fusion_strategy([('STATIC', 1)]) 46 | torch._C._jit_set_profiling_mode(False) 47 | torch._C._jit_set_profiling_executor(False) 48 | 49 | inspiremusic = InspireMusic(args.model_dir, load_jit=False, load_onnx=False) 50 | 51 | # 1. export llm text_encoder 52 | llm_text_encoder = inspiremusic.model.llm.text_encoder.half() 53 | script = torch.jit.script(llm_text_encoder) 54 | script = torch.jit.freeze(script) 55 | script = torch.jit.optimize_for_inference(script) 56 | script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir)) 57 | 58 | # 2. export llm llm 59 | llm_llm = inspiremusic.model.llm.llm.half() 60 | script = torch.jit.script(llm_llm) 61 | script = torch.jit.freeze(script, preserved_attrs=['forward_chunk']) 62 | script = torch.jit.optimize_for_inference(script) 63 | script.save('{}/llm.llm.fp16.zip'.format(args.model_dir)) 64 | 65 | # 3. export flow encoder 66 | flow_encoder = inspiremusic.model.flow.encoder 67 | script = torch.jit.script(flow_encoder) 68 | script = torch.jit.freeze(script) 69 | script = torch.jit.optimize_for_inference(script) 70 | script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /inspiremusic/bin/export_onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) 2 | # Copyright (c) 2024 Alibaba Inc 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
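#
# This tool exports the flow decoder estimator to ONNX and checks it against
# the PyTorch module. The exported graph can also be run standalone with
# onnxruntime; a minimal sketch (input names and shapes follow
# get_dummy_input() below, the path and out_channels value are illustrative):
#
#   import numpy as np
#   import onnxruntime
#   sess = onnxruntime.InferenceSession('pretrained_models/InspireMusic/flow.decoder.estimator.fp32.onnx')
#   b, c, t = 1, 80, 256  # batch, out_channels (checkpoint dependent), seq_len
#   feeds = {'x': np.random.rand(b, c, t).astype(np.float32),
#            'mask': np.ones((b, 1, t), dtype=np.float32),
#            'mu': np.random.rand(b, c, t).astype(np.float32),
#            't': np.random.rand(b).astype(np.float32),
#            'spks': np.random.rand(b, c).astype(np.float32),
#            'cond': np.random.rand(b, c, t).astype(np.float32)}
#   estimator_out = sess.run(None, feeds)[0]
#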
15 | 16 | from __future__ import print_function 17 | 18 | import argparse 19 | import logging 20 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 21 | import os 22 | import sys 23 | import onnxruntime 24 | import random 25 | import torch 26 | from tqdm import tqdm 27 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 28 | sys.path.append('{}/../..'.format(ROOT_DIR)) 29 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) 30 | from inspiremusic.cli.inspiremusic import InspireMusic 31 | 32 | 33 | def get_dummy_input(batch_size, seq_len, out_channels, device): 34 | x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 35 | mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) 36 | mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 37 | t = torch.rand((batch_size), dtype=torch.float32, device=device) 38 | spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) 39 | cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 40 | return x, mask, mu, t, spks, cond 41 | 42 | 43 | def get_args(): 44 | parser = argparse.ArgumentParser(description='export your model for deployment') 45 | parser.add_argument('--model_dir', 46 | type=str, 47 | default='pretrained_models/InspireMusic', 48 | help='local path') 49 | args = parser.parse_args() 50 | print(args) 51 | return args 52 | 53 | 54 | def main(): 55 | args = get_args() 56 | logging.basicConfig(level=logging.DEBUG, 57 | format='%(asctime)s %(levelname)s %(message)s') 58 | 59 | inspiremusic = InspireMusic(args.model_dir, load_jit=False, load_onnx=False) 60 | 61 | # 1. export flow decoder estimator 62 | estimator = inspiremusic.model.flow.decoder.estimator 63 | 64 | device = inspiremusic.model.device 65 | batch_size, seq_len = 1, 256 66 | out_channels = inspiremusic.model.flow.decoder.estimator.out_channels 67 | x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) 68 | torch.onnx.export( 69 | estimator, 70 | (x, mask, mu, t, spks, cond), 71 | '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), 72 | export_params=True, 73 | opset_version=18, 74 | do_constant_folding=True, 75 | input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], 76 | output_names=['estimator_out'], 77 | dynamic_axes={ 78 | 'x': {0: 'batch_size', 2: 'seq_len'}, 79 | 'mask': {0: 'batch_size', 2: 'seq_len'}, 80 | 'mu': {0: 'batch_size', 2: 'seq_len'}, 81 | 'cond': {0: 'batch_size', 2: 'seq_len'}, 82 | 't': {0: 'batch_size'}, 83 | 'spks': {0: 'batch_size'}, 84 | 'estimator_out': {0: 'batch_size', 2: 'seq_len'}, 85 | } 86 | ) 87 | 88 | # 2. 
test computation consistency 89 | option = onnxruntime.SessionOptions() 90 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 91 | option.intra_op_num_threads = 1 92 | providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] 93 | estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), 94 | sess_options=option, providers=providers) 95 | 96 | for _ in tqdm(range(10)): 97 | x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device) 98 | output_pytorch = estimator(x, mask, mu, t, spks, cond) 99 | ort_inputs = { 100 | 'x': x.cpu().numpy(), 101 | 'mask': mask.cpu().numpy(), 102 | 'mu': mu.cpu().numpy(), 103 | 't': t.cpu().numpy(), 104 | 'spks': spks.cpu().numpy(), 105 | 'cond': cond.cpu().numpy() 106 | } 107 | output_onnx = estimator_onnx.run(None, ort_inputs)[0] 108 | torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /inspiremusic/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/inspiremusic/cli/__init__.py -------------------------------------------------------------------------------- /inspiremusic/cli/frontend.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
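#
# The front end converts a text prompt plus timing/structure metadata into the
# model_input dict consumed by InspireMusicModel. A minimal sketch of the
# text-to-music path (the caption text is illustrative; chorus labels such as
# 'intro', 'verse' and 'chorus' mirror the sample data files under
# examples/music_generation/data):
#
#   frontend = InspireMusicFrontEnd(...)  # built by the CLI from configs and model paths
#   model_input = frontend.frontend_text_to_music(
#       'A soothing piano instrumental with a melancholic feel.',
#       time_start=0.0, time_end=30.0, chorus='intro')
#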
14 | from functools import partial 15 | import torch 16 | from typing import Callable 17 | import re 18 | import inflect 19 | from inspiremusic.cli.model import InspireMusicModel 20 | from inspiremusic.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph 21 | from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer 22 | 23 | class InspireMusicFrontEnd: 24 | def __init__(self, 25 | configs: Callable, 26 | get_tokenizer: Callable, 27 | llm_model: str, 28 | flow_model: str, 29 | music_tokenizer_dir: str, 30 | audio_tokenizer_dir: str, 31 | instruct: bool = False, 32 | dtype: str = "fp16", 33 | fast: bool = False, 34 | fp16: bool = True, 35 | allowed_special: str = 'all'): 36 | self.tokenizer = get_tokenizer() 37 | self.audio_tokenizer_dir = audio_tokenizer_dir 38 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 39 | 40 | self.bandwidth_id = torch.tensor([0]).to(self.device) 41 | self.wavtokenizer = WavTokenizer.from_pretrained_feat(f"{audio_tokenizer_dir}/config.yaml", f"{audio_tokenizer_dir}/model.pt").to(self.device) 42 | 43 | self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], dtype, fast, fp16) 44 | self.model = self.model.load(llm_model, flow_model, music_tokenizer_dir, audio_tokenizer_dir) 45 | 46 | self.instruct = instruct 47 | self.allowed_special = allowed_special 48 | self.inflect_parser = inflect.engine() 49 | 50 | def _extract_text_token(self, text): 51 | text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special) 52 | text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device) 53 | text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device) 54 | return text_token, text_token_len 55 | 56 | def _extract_audio_token(self, audio, sample_rate=24000): 57 | audio = torch.tensor(audio, dtype=torch.float32, device=self.device) 58 | _, audio_token = self.wavtokenizer.encode_infer(audio, bandwidth_id=self.bandwidth_id) 59 | audio_token = audio_token.squeeze(0) 60 | audio_token_len = torch.tensor([audio_token.shape[1]], dtype=torch.int32, device=self.device) 61 | return audio_token, audio_token_len 62 | 63 | def text_normalize(self, text, split=True): 64 | text = text.strip() 65 | if contains_chinese(text): 66 | text = text.replace("\n", "") 67 | text = replace_blank(text) 68 | text = replace_corner_mark(text) 69 | text = text.replace(".", "、") 70 | text = text.replace(" - ", ",") 71 | text = remove_bracket(text) 72 | text = re.sub(r'[,,]+$', '。', text) 73 | texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False)) 74 | else: 75 | text = spell_out_number(text, self.inflect_parser) 76 | texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False)) 77 | if split is False: 78 | return text 79 | return texts 80 | 81 | def frontend_text_to_music(self, text, time_start, time_end, chorus): 82 | text_token, text_token_len = self._extract_text_token(text) 83 | model_input = {"text": text, "audio_token": None, "audio_token_len": None, 84 | "text_token": text_token, "text_token_len": text_token_len, 85 | "embeddings": [time_start, time_end, chorus], "raw_text":text} 86 | return model_input 87 | 88 | def frontend_continuation(self, text, 
audio, time_start, time_end, chorus, target_sr=24000): 89 | if text is None: 90 | text_token = None 91 | text_token_len = None 92 | else: 93 | text_token, text_token_len = self._extract_text_token(text) 94 | audio_token, audio_token_len = self._extract_audio_token(audio, target_sr) 95 | model_input = {"text": text, "audio_token": audio_token, "audio_token_len": audio_token_len, 96 | "text_token": text_token, "text_token_len": text_token_len, 97 | "embeddings": [time_start, time_end, chorus], "raw_text":text} 98 | return model_input 99 | 100 | -------------------------------------------------------------------------------- /inspiremusic/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/inspiremusic/dataset/__init__.py -------------------------------------------------------------------------------- /inspiremusic/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 2024 Alibaba Inc 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import random 17 | import json 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.distributed as dist 23 | from torch.utils.data import IterableDataset 24 | from inspiremusic.utils.file_utils import read_lists, read_json_lists 25 | 26 | class Processor(IterableDataset): 27 | 28 | def __init__(self, source, f, *args, **kw): 29 | assert callable(f) 30 | self.source = source 31 | self.f = f 32 | self.args = args 33 | self.kw = kw 34 | 35 | def set_epoch(self, epoch): 36 | self.source.set_epoch(epoch) 37 | 38 | def __iter__(self): 39 | """ Return an iterator over the source dataset processed by the 40 | given processor. 
41 | """ 42 | assert self.source is not None 43 | assert callable(self.f) 44 | return self.f(iter(self.source), *self.args, **self.kw) 45 | 46 | def apply(self, f): 47 | assert callable(f) 48 | return Processor(self, f, *self.args, **self.kw) 49 | 50 | 51 | class DistributedSampler: 52 | 53 | def __init__(self, shuffle=True, partition=True): 54 | self.epoch = -1 55 | self.update() 56 | self.shuffle = shuffle 57 | self.partition = partition 58 | 59 | def update(self): 60 | assert dist.is_available() 61 | if dist.is_initialized(): 62 | self.rank = dist.get_rank() 63 | self.world_size = dist.get_world_size() 64 | else: 65 | self.rank = 0 66 | self.world_size = 1 67 | worker_info = torch.utils.data.get_worker_info() 68 | if worker_info is None: 69 | self.worker_id = 0 70 | self.num_workers = 1 71 | else: 72 | self.worker_id = worker_info.id 73 | self.num_workers = worker_info.num_workers 74 | return dict(rank=self.rank, 75 | world_size=self.world_size, 76 | worker_id=self.worker_id, 77 | num_workers=self.num_workers) 78 | 79 | def set_epoch(self, epoch): 80 | self.epoch = epoch 81 | 82 | def sample(self, data): 83 | """ Sample data according to rank/world_size/num_workers 84 | 85 | Args: 86 | data(List): input data list 87 | 88 | Returns: 89 | List: data list after sample 90 | """ 91 | data = list(range(len(data))) 92 | # force datalist even 93 | 94 | if self.partition: 95 | if self.shuffle: 96 | random.Random(self.epoch).shuffle(data) 97 | if len(data) < self.world_size: 98 | print(len(data), self.world_size) 99 | data = data * math.ceil(self.world_size / len(data)) 100 | data = data[:self.world_size] 101 | data = data[self.rank::self.world_size] 102 | if len(data) < self.num_workers: 103 | data = data * math.ceil(self.num_workers / len(data)) 104 | data = data[:self.num_workers] 105 | data = data[self.worker_id::self.num_workers] 106 | return data 107 | 108 | 109 | class DataList(IterableDataset): 110 | 111 | def __init__(self, lists, shuffle=True, partition=True): 112 | self.lists = lists 113 | self.sampler = DistributedSampler(shuffle, partition) 114 | 115 | def set_epoch(self, epoch): 116 | self.sampler.set_epoch(epoch) 117 | 118 | def __iter__(self): 119 | sampler_info = self.sampler.update() 120 | indexes = self.sampler.sample(self.lists) 121 | for index in indexes: 122 | data = dict(src=self.lists[index]) 123 | data.update(sampler_info) 124 | yield data 125 | 126 | 127 | def Dataset(data_list_file, 128 | data_pipeline, 129 | mode='train', 130 | shuffle=True, 131 | partition=True 132 | ): 133 | """ Construct dataset from arguments 134 | 135 | We have two shuffle stage in the Dataset. The first is global 136 | shuffle at shards tar/raw file level. The second is global shuffle 137 | at training samples level. 
138 | 139 | Args: 140 | data_type(str): raw/shard 141 | tokenizer (BaseTokenizer): tokenizer to tokenize 142 | partition(bool): whether to do data partition in terms of rank 143 | """ 144 | assert mode in ['train', 'inference', 'processing'] 145 | lists = read_lists(data_list_file) 146 | 147 | dataset = DataList(lists, 148 | shuffle=shuffle, 149 | partition=partition) 150 | 151 | for func in data_pipeline: 152 | dataset = Processor(dataset, func, mode=mode) 153 | 154 | return dataset 155 | -------------------------------------------------------------------------------- /inspiremusic/flow/flow.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import logging 15 | import random 16 | from typing import Dict, Optional 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import functional as F 20 | from omegaconf import DictConfig 21 | from inspiremusic.utils.mask import make_pad_mask 22 | from inspiremusic.music_tokenizer.vqvae import VQVAE 23 | 24 | class MaskedDiff(torch.nn.Module): 25 | def __init__(self, 26 | input_size: int = 512, 27 | output_size: int = 128, 28 | output_type: str = "mel", 29 | vocab_size: int = 4096, 30 | input_frame_rate: int = 50, 31 | only_mask_loss: bool = True, 32 | encoder: torch.nn.Module = None, 33 | length_regulator: torch.nn.Module = None, 34 | decoder: torch.nn.Module = None, 35 | decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 36 | 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 37 | 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 38 | 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 39 | 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, 40 | mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 128, 'sampling_rate': 48000, 41 | 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 48000}, 42 | generator_model_dir: str = "../../pretrained_models/InspireMusic-Base/music_tokenizer", 43 | num_codebooks: int = 4 44 | ): 45 | super().__init__() 46 | self.input_size = input_size 47 | self.output_size = output_size 48 | self.decoder_conf = decoder_conf 49 | self.mel_feat_conf = mel_feat_conf 50 | self.vocab_size = vocab_size 51 | self.output_type = output_type 52 | self.input_frame_rate = input_frame_rate 53 | logging.info(f"input frame rate={self.input_frame_rate}") 54 | self.input_embedding = nn.Embedding(vocab_size, input_size) 55 | 56 | self.encoder = encoder 57 | self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) 58 | self.decoder = decoder 59 | self.length_regulator = length_regulator 60 | self.only_mask_loss = only_mask_loss 61 | self.quantizer = VQVAE( f'{generator_model_dir}/config.json', 62 | f'{generator_model_dir}/model.pt',with_encoder=True).quantizer 63 | self.quantizer.eval() 64 | self.num_codebooks = num_codebooks 65 | 
self.cond = None 66 | self.interpolate = False 67 | 68 | def forward( 69 | self, 70 | batch: dict, 71 | device: torch.device, 72 | ) -> Dict[str, Optional[torch.Tensor]]: 73 | 74 | audio_token = batch['acoustic_token'].to(device) 75 | audio_token_len = batch['acoustic_token_len'].to(device) 76 | audio_token = audio_token.view(audio_token.size(0),-1,self.num_codebooks) 77 | if "semantic_token" not in batch: 78 | token = audio_token[:,:,0] 79 | token_len = (audio_token_len/self.num_codebooks).long() 80 | 81 | else: 82 | token = batch['semantic_token'].to(device) 83 | token_len = batch['semantic_token_len'].to(device) 84 | 85 | with torch.no_grad(): 86 | feat = self.quantizer.embed(audio_token) 87 | feat_len = (audio_token_len/self.num_codebooks).long() 88 | 89 | token = self.input_embedding(token) 90 | h, h_lengths = self.encoder(token, token_len) 91 | h, h_lengths = self.length_regulator(h, feat_len) 92 | 93 | # get conditions 94 | if self.cond: 95 | conds = torch.zeros(feat.shape, device=token.device) 96 | for i, j in enumerate(feat_len): 97 | if random.random() < 0.5: 98 | continue 99 | index = random.randint(0, int(0.3 * j)) 100 | conds[i, :index] = feat[i, :index] 101 | conds = conds.transpose(1, 2) 102 | else: 103 | conds = None 104 | 105 | mask = (~make_pad_mask(feat_len)).to(h) 106 | 107 | loss, _ = self.decoder.compute_loss( 108 | feat, 109 | mask.unsqueeze(1), 110 | h.transpose(1, 2).contiguous(), 111 | None, 112 | cond=conds 113 | ) 114 | 115 | return {'loss': loss} 116 | 117 | @torch.inference_mode() 118 | def inference(self, 119 | token, 120 | token_len, 121 | sample_rate): 122 | assert token.shape[0] == 1 123 | 124 | token = self.input_embedding(torch.clamp(token, min=0)) 125 | h, h_lengths = self.encoder(token, token_len) 126 | 127 | if sample_rate == 48000: 128 | token_len = 2 * token_len 129 | 130 | h, h_lengths = self.length_regulator(h, token_len) 131 | 132 | # get conditions 133 | conds = None 134 | 135 | mask = (~make_pad_mask(token_len)).to(h) 136 | feat = self.decoder( 137 | mu=h.transpose(1, 2).contiguous(), 138 | mask=mask.unsqueeze(1), 139 | spks=None, 140 | cond=conds, 141 | n_timesteps=10 142 | ) 143 | return feat -------------------------------------------------------------------------------- /inspiremusic/flow/length_regulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
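#
# The regulator below linearly interpolates a token-rate sequence (B, T, D) to
# a target mel-frame length and refines it with a small conv stack. A minimal
# usage sketch (channel size and lengths are illustrative):
#
#   import torch
#   reg = InterpolateRegulator(channels=512, sampling_ratios=[1, 1])
#   x = torch.randn(2, 50, 512)            # 50 tokens per item
#   ylens = torch.tensor([200, 180])       # target mel lengths
#   out, olens = reg(x, ylens)             # out: (2, 200, 512), zero-masked past each length
#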
14 | from typing import Tuple 15 | import torch.nn as nn 16 | import torch 17 | from torch.nn import functional as F 18 | from inspiremusic.utils.mask import make_pad_mask 19 | 20 | 21 | class InterpolateRegulator(nn.Module): 22 | def __init__( 23 | self, 24 | channels: int, 25 | sampling_ratios: Tuple, 26 | out_channels: int = None, 27 | groups: int = 1, 28 | ): 29 | super().__init__() 30 | self.sampling_ratios = sampling_ratios 31 | out_channels = out_channels or channels 32 | model = nn.ModuleList([]) 33 | if len(sampling_ratios) > 0: 34 | for _ in sampling_ratios: 35 | module = nn.Conv1d(channels, channels, 3, 1, 1) 36 | norm = nn.GroupNorm(groups, channels) 37 | act = nn.Mish() 38 | model.extend([module, norm, act]) 39 | model.append( 40 | nn.Conv1d(channels, out_channels, 1, 1) 41 | ) 42 | self.model = nn.Sequential(*model) 43 | 44 | def forward(self, x, ylens=None): 45 | # x in (B, T, D) 46 | mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) 47 | x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear') 48 | out = self.model(x).transpose(1, 2).contiguous() 49 | olens = ylens 50 | return out * mask, olens 51 | 52 | def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50): 53 | # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel 54 | # x in (B, T, D) 55 | if x2.shape[1] > 40: 56 | x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 57 | x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2, 58 | mode='linear') 59 | x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 60 | x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2) 61 | else: 62 | x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear') 63 | if x1.shape[1] != 0: 64 | x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear') 65 | x = torch.concat([x1, x2], dim=2) 66 | else: 67 | x = x2 68 | out = self.model(x).transpose(1, 2).contiguous() 69 | return out, mel_len1 + mel_len2 70 | -------------------------------------------------------------------------------- /inspiremusic/hifigan/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils import weight_norm 4 | from typing import List, Optional, Tuple 5 | from einops import rearrange 6 | from torchaudio.transforms import Spectrogram 7 | 8 | 9 | class MultipleDiscriminator(nn.Module): 10 | def __init__( 11 | self, mpd: nn.Module, mrd: nn.Module 12 | ): 13 | super().__init__() 14 | self.mpd = mpd 15 | self.mrd = mrd 16 | 17 | def forward(self, y: torch.Tensor, y_hat: torch.Tensor): 18 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], [] 19 | this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1)) 20 | y_d_rs += this_y_d_rs 21 | y_d_gs += this_y_d_gs 22 | fmap_rs += this_fmap_rs 23 | fmap_gs += this_fmap_gs 24 | this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat) 25 | y_d_rs += this_y_d_rs 26 | y_d_gs += this_y_d_gs 27 | fmap_rs += this_fmap_rs 28 | fmap_gs += this_fmap_gs 29 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 30 | 31 | 32 | class MultiResolutionDiscriminator(nn.Module): 33 | def __init__( 34 | self, 35 | fft_sizes: 
Tuple[int, ...] = (2048, 1024, 512), 36 | num_embeddings: Optional[int] = None, 37 | ): 38 | """ 39 | Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. 40 | Additionally, it allows incorporating conditional information with a learned embeddings table. 41 | 42 | Args: 43 | fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). 44 | num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. 45 | Defaults to None. 46 | """ 47 | 48 | super().__init__() 49 | self.discriminators = nn.ModuleList( 50 | [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes] 51 | ) 52 | 53 | def forward( 54 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None 55 | ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: 56 | y_d_rs = [] 57 | y_d_gs = [] 58 | fmap_rs = [] 59 | fmap_gs = [] 60 | 61 | for d in self.discriminators: 62 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) 63 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) 64 | y_d_rs.append(y_d_r) 65 | fmap_rs.append(fmap_r) 66 | y_d_gs.append(y_d_g) 67 | fmap_gs.append(fmap_g) 68 | 69 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 70 | 71 | 72 | class DiscriminatorR(nn.Module): 73 | def __init__( 74 | self, 75 | window_length: int, 76 | num_embeddings: Optional[int] = None, 77 | channels: int = 32, 78 | hop_factor: float = 0.25, 79 | bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)), 80 | ): 81 | super().__init__() 82 | self.window_length = window_length 83 | self.hop_factor = hop_factor 84 | self.spec_fn = Spectrogram( 85 | n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None 86 | ) 87 | n_fft = window_length // 2 + 1 88 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 89 | self.bands = bands 90 | convs = lambda: nn.ModuleList( 91 | [ 92 | weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), 93 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 94 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 95 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 96 | weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), 97 | ] 98 | ) 99 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 100 | 101 | if num_embeddings is not None: 102 | self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels) 103 | torch.nn.init.zeros_(self.emb.weight) 104 | 105 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) 106 | 107 | def spectrogram(self, x): 108 | # Remove DC offset 109 | x = x - x.mean(dim=-1, keepdims=True) 110 | # Peak normalize the volume of input audio 111 | x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 112 | x = self.spec_fn(x) 113 | x = torch.view_as_real(x) 114 | x = rearrange(x, "b f t c -> b c t f") 115 | # Split into bands 116 | x_bands = [x[..., b[0]: b[1]] for b in self.bands] 117 | return x_bands 118 | 119 | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): 120 | x_bands = self.spectrogram(x) 121 | fmap = [] 122 | x = [] 123 | for band, stack in zip(x_bands, self.band_convs): 124 | for i, layer in enumerate(stack): 125 | band = layer(band) 126 | band = 
torch.nn.functional.leaky_relu(band, 0.1) 127 | if i > 0: 128 | fmap.append(band) 129 | x.append(band) 130 | x = torch.cat(x, dim=-1) 131 | if cond_embedding_id is not None: 132 | emb = self.emb(cond_embedding_id) 133 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) 134 | else: 135 | h = 0 136 | x = self.conv_post(x) 137 | fmap.append(x) 138 | x += h 139 | 140 | return x, fmap 141 | -------------------------------------------------------------------------------- /inspiremusic/hifigan/f0_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn.utils import weight_norm 17 | 18 | 19 | class ConvRNNF0Predictor(nn.Module): 20 | def __init__(self, 21 | num_class: int = 1, 22 | in_channels: int = 80, 23 | cond_channels: int = 512 24 | ): 25 | super().__init__() 26 | 27 | self.num_class = num_class 28 | self.condnet = nn.Sequential( 29 | weight_norm( 30 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) 31 | ), 32 | nn.ELU(), 33 | weight_norm( 34 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 35 | ), 36 | nn.ELU(), 37 | weight_norm( 38 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 39 | ), 40 | nn.ELU(), 41 | weight_norm( 42 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 43 | ), 44 | nn.ELU(), 45 | weight_norm( 46 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 47 | ), 48 | nn.ELU(), 49 | ) 50 | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | x = self.condnet(x) 54 | x = x.transpose(1, 2) 55 | return torch.abs(self.classifier(x).squeeze(-1)) 56 | -------------------------------------------------------------------------------- /inspiremusic/hifigan/hifigan.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss 6 | from inspiremusic.utils.losses import tpr_loss, mel_loss 7 | 8 | class HiFiGan(nn.Module): 9 | def __init__(self, generator, discriminator, mel_spec_transform, 10 | multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0, 11 | tpr_loss_weight=1.0, tpr_loss_tau=0.04): 12 | super(HiFiGan, self).__init__() 13 | self.generator = generator 14 | self.discriminator = discriminator 15 | self.mel_spec_transform = mel_spec_transform 16 | self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight 17 | self.feat_match_loss_weight = feat_match_loss_weight 18 | self.tpr_loss_weight = tpr_loss_weight 19 | self.tpr_loss_tau = tpr_loss_tau 20 | 21 | def forward( 22 | self, 23 | batch: dict, 24 | device: 
torch.device, 25 | ) -> Dict[str, Optional[torch.Tensor]]: 26 | if batch['turn'] == 'generator': 27 | return self.forward_generator(batch, device) 28 | else: 29 | return self.forward_discriminator(batch, device) 30 | 31 | def forward_generator(self, batch, device): 32 | real_speech = batch['speech'].to(device) 33 | pitch_feat = batch['pitch_feat'].to(device) 34 | # 1. calculate generator outputs 35 | generated_speech, generated_f0 = self.generator(batch, device) 36 | # 2. calculate discriminator outputs 37 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) 38 | # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional] 39 | loss_gen, _ = generator_loss(y_d_gs) 40 | loss_fm = feature_loss(fmap_rs, fmap_gs) 41 | loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform) 42 | if self.tpr_loss_weight != 0: 43 | loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) 44 | else: 45 | loss_tpr = torch.zeros(1).to(device) 46 | loss_f0 = F.l1_loss(generated_f0, pitch_feat) 47 | loss = loss_gen + self.feat_match_loss_weight * loss_fm + \ 48 | self.multi_mel_spectral_recon_loss_weight * loss_mel + \ 49 | self.tpr_loss_weight * loss_tpr + loss_f0 50 | return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0} 51 | 52 | def forward_discriminator(self, batch, device): 53 | real_speech = batch['speech'].to(device) 54 | # 1. calculate generator outputs 55 | with torch.no_grad(): 56 | generated_speech, generated_f0 = self.generator(batch, device) 57 | # 2. calculate discriminator outputs 58 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) 59 | # 3. calculate discriminator losses, tpr losses [Optional] 60 | loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs) 61 | if self.tpr_loss_weight != 0: 62 | loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) 63 | else: 64 | loss_tpr = torch.zeros(1).to(device) 65 | loss = loss_disc + self.tpr_loss_weight * loss_tpr 66 | return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr} 67 | -------------------------------------------------------------------------------- /inspiremusic/music_tokenizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/inspiremusic/music_tokenizer/__init__.py -------------------------------------------------------------------------------- /inspiremusic/music_tokenizer/env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
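#
# AttrDict simply exposes dict keys as attributes; elsewhere in the package the
# music tokenizer config.json is loaded this way (the path and key name are
# illustrative):
#
#   import json
#   with open('music_tokenizer/config.json') as f:
#       h = AttrDict(json.loads(f.read()))
#   # h.sampling_rate instead of h['sampling_rate']
#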
14 | 15 | import os 16 | import shutil 17 | 18 | 19 | class AttrDict(dict): 20 | def __init__(self, *args, **kwargs): 21 | super(AttrDict, self).__init__(*args, **kwargs) 22 | self.__dict__ = self 23 | 24 | 25 | def build_env(config, config_name, path): 26 | t_path = os.path.join(path, config_name) 27 | if config != t_path: 28 | os.makedirs(path, exist_ok=True) 29 | shutil.copyfile(config, os.path.join(path, config_name)) 30 | -------------------------------------------------------------------------------- /inspiremusic/music_tokenizer/vqvae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import json 16 | 17 | import torch 18 | import torch.nn as nn 19 | from inspiremusic.music_tokenizer.env import AttrDict 20 | from inspiremusic.music_tokenizer.models import Encoder 21 | from inspiremusic.music_tokenizer.models import Generator 22 | from inspiremusic.music_tokenizer.models import Quantizer 23 | 24 | 25 | class VQVAE(nn.Module): 26 | def __init__(self, 27 | config_path, 28 | ckpt_path, 29 | with_encoder=False): 30 | super(VQVAE, self).__init__() 31 | ckpt = torch.load(ckpt_path) 32 | with open(config_path) as f: 33 | data = f.read() 34 | json_config = json.loads(data) 35 | self.h = AttrDict(json_config) 36 | self.quantizer = Quantizer(self.h) 37 | self.generator = Generator(self.h) 38 | self.generator.load_state_dict(ckpt['generator']) 39 | self.quantizer.load_state_dict(ckpt['quantizer']) 40 | if with_encoder: 41 | self.encoder = Encoder(self.h) 42 | self.encoder.load_state_dict(ckpt['encoder']) 43 | 44 | def forward(self, x): 45 | # x is the codebook 46 | # x.shape (B, T, Nq) 47 | quant_emb = self.quantizer.embed(x) 48 | return self.generator(quant_emb) 49 | 50 | def encode(self, x): 51 | batch_size = x.size(0) 52 | if len(x.shape) == 3 and x.shape[-1] == 1: 53 | x = x.squeeze(-1) 54 | c = self.encoder(x.unsqueeze(1)) 55 | q, loss_q, c = self.quantizer(c) 56 | c = [code.reshape(batch_size, -1) for code in c] 57 | # shape: [N, T, 4] 58 | return torch.stack(c, -1) 59 | -------------------------------------------------------------------------------- /inspiremusic/text/abs_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
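#
# Concrete tokenizers implement text2tokens/tokens2text; the real
# implementation in this package is QwenTokenizer in tokenizer.py. A minimal
# illustrative subclass satisfying the interface:
#
#   class WhitespaceTokenizer(AbsTokenizer):
#       def text2tokens(self, line):
#           return line.split()
#
#       def tokens2text(self, tokens):
#           return ' '.join(tokens)
#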
14 | 15 | from abc import ABC 16 | from abc import abstractmethod 17 | from typing import Iterable 18 | from typing import List 19 | 20 | 21 | class AbsTokenizer(ABC): 22 | @abstractmethod 23 | def text2tokens(self, line: str) -> List[str]: 24 | raise NotImplementedError 25 | 26 | @abstractmethod 27 | def tokens2text(self, tokens: Iterable[str]) -> str: 28 | raise NotImplementedError 29 | 30 | 31 | 32 | def encode(self, line: str, **kwargs) -> List[str]: 33 | 34 | return self.text2tokens(line) -------------------------------------------------------------------------------- /inspiremusic/text/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import copy 16 | import os 17 | import re 18 | from typing import Iterable, List, Union 19 | import numpy as np 20 | import torch 21 | 22 | from inspiremusic.text.abs_tokenizer import AbsTokenizer 23 | from transformers import AutoTokenizer 24 | 25 | def get_tokenizer(tokenizer_name, tokenizer_path): 26 | if "qwen" in tokenizer_name: 27 | return QwenTokenizer(tokenizer_path,skip_special_tokens=True) 28 | else: 29 | return None 30 | 31 | class QwenTokenizer(AbsTokenizer): 32 | def __init__( 33 | self, 34 | token_path: str, 35 | skip_special_tokens: bool = True, 36 | ): 37 | super().__init__() 38 | # NOTE: non-chat model, all these special tokens keep randomly initialized. 39 | special_tokens = { 40 | 'eos_token': '<|endoftext|>', 41 | 'pad_token': '<|endoftext|>', 42 | 'additional_special_tokens': [ 43 | '<|im_start|>', '<|im_end|>', '<|endofprompt|>', 44 | '[breath]', '', '', '[noise]', 45 | '[laughter]', '[cough]', '[clucking]', '[accent]', 46 | '[quick_breath]', 47 | ] 48 | } 49 | self.tokenizer = AutoTokenizer.from_pretrained(token_path) 50 | self.tokenizer.add_special_tokens(special_tokens) 51 | self.skip_special_tokens = skip_special_tokens 52 | 53 | def get_vocab_size(self): 54 | return self.tokenizer.vocab_size 55 | 56 | def text2tokens(self, line: str) -> List: 57 | tokens = self.tokenizer([line], return_tensors="pt") 58 | tokens = tokens["input_ids"][0].cpu().tolist() 59 | return tokens 60 | 61 | def tokens2text(self, tokens) -> str: 62 | tokens = torch.tensor(tokens, dtype=torch.int64) 63 | text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0] 64 | return text 65 | 66 | 67 | 68 | def get_qwen_vocab_size(token_type: str): 69 | if "qwen1.5" in token_type.lower() or "qwen2.0" in token_type.lower() or "qwen2.5" in token_type.lower(): 70 | # 293 for special and extra tokens, including endoftext, im_start, im_end, endofprompt and others in the future. 71 | # model.vocab_size = 151936, tokenizer.vocab_size = 151643 72 | # NOTE: the first three special tokens (endoftext, im_start, im_end) are trained in Chat series models, 73 | # others are kept in random initialization state. 
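# i.e. 151643 (tokenizer.vocab_size) + 293 (special/extra tokens) = 151936, matching model.vocab_size noted above.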
74 | return 151643 + 293 75 | else: 76 | raise ValueError(f"Unknown tokenizer {token_type}") -------------------------------------------------------------------------------- /inspiremusic/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/inspiremusic/transformer/__init__.py -------------------------------------------------------------------------------- /inspiremusic/transformer/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) 2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo) 3 | # 2020 Mobvoi Inc (Binbin Zhang) 4 | # 2024 Alibaba Inc (Xiang Lyu) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """Swish() activation function for Conformer.""" 18 | 19 | import torch 20 | from torch import nn, sin, pow 21 | from torch.nn import Parameter 22 | 23 | 24 | class Swish(torch.nn.Module): 25 | """Construct an Swish object.""" 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | """Return Swish activation function.""" 29 | return x * torch.sigmoid(x) 30 | 31 | 32 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 33 | # LICENSE is in incl_licenses directory. 34 | class Snake(nn.Module): 35 | ''' 36 | Implementation of a sine-based periodic activation function 37 | Shape: 38 | - Input: (B, C, T) 39 | - Output: (B, C, T), same shape as the input 40 | Parameters: 41 | - alpha - trainable parameter 42 | References: 43 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 44 | https://arxiv.org/abs/2006.08195 45 | Examples: 46 | >>> a1 = snake(256) 47 | >>> x = torch.randn(256) 48 | >>> x = a1(x) 49 | ''' 50 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 51 | ''' 52 | Initialization. 53 | INPUT: 54 | - in_features: shape of the input 55 | - alpha: trainable parameter 56 | alpha is initialized to 1 by default, higher values = higher-frequency. 57 | alpha will be trained along with the rest of your model. 58 | ''' 59 | super(Snake, self).__init__() 60 | self.in_features = in_features 61 | 62 | # initialize alpha 63 | self.alpha_logscale = alpha_logscale 64 | if self.alpha_logscale: # log scale alphas initialized to zeros 65 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 66 | else: # linear scale alphas initialized to ones 67 | self.alpha = Parameter(torch.ones(in_features) * alpha) 68 | 69 | self.alpha.requires_grad = alpha_trainable 70 | 71 | self.no_div_by_zero = 0.000000001 72 | 73 | def forward(self, x): 74 | ''' 75 | Forward pass of the function. 76 | Applies the function to the input elementwise. 
77 | Snake ∶= x + 1/a * sin^2 (xa) 78 | ''' 79 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 80 | if self.alpha_logscale: 81 | alpha = torch.exp(alpha) 82 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 83 | 84 | return x -------------------------------------------------------------------------------- /inspiremusic/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 2024 Alibaba Inc (Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """ConvolutionModule definition.""" 17 | 18 | from typing import Tuple 19 | 20 | import torch 21 | from torch import nn 22 | 23 | 24 | class ConvolutionModule(nn.Module): 25 | """ConvolutionModule in Conformer model.""" 26 | 27 | def __init__(self, 28 | channels: int, 29 | kernel_size: int = 15, 30 | activation: nn.Module = nn.ReLU(), 31 | norm: str = "batch_norm", 32 | causal: bool = False, 33 | bias: bool = True): 34 | """Construct an ConvolutionModule object. 35 | Args: 36 | channels (int): The number of channels of conv layers. 37 | kernel_size (int): Kernel size of conv layers. 38 | causal (int): Whether use causal convolution or not 39 | """ 40 | super().__init__() 41 | 42 | self.pointwise_conv1 = nn.Conv1d( 43 | channels, 44 | 2 * channels, 45 | kernel_size=1, 46 | stride=1, 47 | padding=0, 48 | bias=bias, 49 | ) 50 | # self.lorder is used to distinguish if it's a causal convolution, 51 | # if self.lorder > 0: it's a causal convolution, the input will be 52 | # padded with self.lorder frames on the left in forward. 53 | # else: it's a symmetrical convolution 54 | if causal: 55 | padding = 0 56 | self.lorder = kernel_size - 1 57 | else: 58 | # kernel_size should be an odd number for none causal convolution 59 | assert (kernel_size - 1) % 2 == 0 60 | padding = (kernel_size - 1) // 2 61 | self.lorder = 0 62 | self.depthwise_conv = nn.Conv1d( 63 | channels, 64 | channels, 65 | kernel_size, 66 | stride=1, 67 | padding=padding, 68 | groups=channels, 69 | bias=bias, 70 | ) 71 | 72 | assert norm in ['batch_norm', 'layer_norm'] 73 | if norm == "batch_norm": 74 | self.use_layer_norm = False 75 | self.norm = nn.BatchNorm1d(channels) 76 | else: 77 | self.use_layer_norm = True 78 | self.norm = nn.LayerNorm(channels) 79 | 80 | self.pointwise_conv2 = nn.Conv1d( 81 | channels, 82 | channels, 83 | kernel_size=1, 84 | stride=1, 85 | padding=0, 86 | bias=bias, 87 | ) 88 | self.activation = activation 89 | 90 | def forward( 91 | self, 92 | x: torch.Tensor, 93 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 94 | cache: torch.Tensor = torch.zeros((0, 0, 0)), 95 | ) -> Tuple[torch.Tensor, torch.Tensor]: 96 | """Compute convolution module. 97 | Args: 98 | x (torch.Tensor): Input tensor (#batch, time, channels). 
99 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 100 | (0, 0, 0) means fake mask. 101 | cache (torch.Tensor): left context cache, it is only 102 | used in causal convolution (#batch, channels, cache_t), 103 | (0, 0, 0) meas fake cache. 104 | Returns: 105 | torch.Tensor: Output tensor (#batch, time, channels). 106 | """ 107 | # exchange the temporal dimension and the feature dimension 108 | x = x.transpose(1, 2) # (#batch, channels, time) 109 | 110 | # mask batch padding 111 | if mask_pad.size(2) > 0: # time > 0 112 | x.masked_fill_(~mask_pad, 0.0) 113 | 114 | if self.lorder > 0: 115 | if cache.size(2) == 0: # cache_t == 0 116 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 117 | else: 118 | assert cache.size(0) == x.size(0) # equal batch 119 | assert cache.size(1) == x.size(1) # equal channel 120 | x = torch.cat((cache, x), dim=2) 121 | assert (x.size(2) > self.lorder) 122 | new_cache = x[:, :, -self.lorder:] 123 | else: 124 | # It's better we just return None if no cache is required, 125 | # However, for JIT export, here we just fake one tensor instead of 126 | # None. 127 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 128 | 129 | # GLU mechanism 130 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 131 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 132 | 133 | # 1D Depthwise Conv 134 | x = self.depthwise_conv(x) 135 | if self.use_layer_norm: 136 | x = x.transpose(1, 2) 137 | x = self.activation(self.norm(x)) 138 | if self.use_layer_norm: 139 | x = x.transpose(1, 2) 140 | x = self.pointwise_conv2(x) 141 | # mask batch padding 142 | if mask_pad.size(2) > 0: # time > 0 143 | x.masked_fill_(~mask_pad, 0.0) 144 | 145 | return x.transpose(1, 2), new_cache 146 | -------------------------------------------------------------------------------- /inspiremusic/transformer/decoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Decoder self-attention layer definition.""" 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | from torch import nn 20 | 21 | 22 | class DecoderLayer(nn.Module): 23 | """Single decoder layer module. 24 | 25 | Args: 26 | size (int): Input dimension. 27 | self_attn (torch.nn.Module): Self-attention module instance. 28 | `MultiHeadedAttention` instance can be used as the argument. 29 | src_attn (torch.nn.Module): Inter-attention module instance. 30 | `MultiHeadedAttention` instance can be used as the argument. 31 | If `None` is passed, Inter-attention is not used, such as 32 | CIF, GPT, and other decoder only model. 33 | feed_forward (torch.nn.Module): Feed-forward module instance. 34 | `PositionwiseFeedForward` instance can be used as the argument. 35 | dropout_rate (float): Dropout rate. 36 | normalize_before (bool): 37 | True: use layer_norm before each sub-block. 
38 | False: to use layer_norm after each sub-block. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | size: int, 44 | self_attn: nn.Module, 45 | src_attn: Optional[nn.Module], 46 | feed_forward: nn.Module, 47 | dropout_rate: float, 48 | normalize_before: bool = True, 49 | ): 50 | """Construct an DecoderLayer object.""" 51 | super().__init__() 52 | self.size = size 53 | self.self_attn = self_attn 54 | self.src_attn = src_attn 55 | self.feed_forward = feed_forward 56 | self.norm1 = nn.LayerNorm(size, eps=1e-5) 57 | self.norm2 = nn.LayerNorm(size, eps=1e-5) 58 | self.norm3 = nn.LayerNorm(size, eps=1e-5) 59 | self.dropout = nn.Dropout(dropout_rate) 60 | self.normalize_before = normalize_before 61 | 62 | def forward( 63 | self, 64 | tgt: torch.Tensor, 65 | tgt_mask: torch.Tensor, 66 | memory: torch.Tensor, 67 | memory_mask: torch.Tensor, 68 | cache: Optional[torch.Tensor] = None 69 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 70 | """Compute decoded features. 71 | 72 | Args: 73 | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). 74 | tgt_mask (torch.Tensor): Mask for input tensor 75 | (#batch, maxlen_out). 76 | memory (torch.Tensor): Encoded memory 77 | (#batch, maxlen_in, size). 78 | memory_mask (torch.Tensor): Encoded memory mask 79 | (#batch, maxlen_in). 80 | cache (torch.Tensor): cached tensors. 81 | (#batch, maxlen_out - 1, size). 82 | 83 | Returns: 84 | torch.Tensor: Output tensor (#batch, maxlen_out, size). 85 | torch.Tensor: Mask for output tensor (#batch, maxlen_out). 86 | torch.Tensor: Encoded memory (#batch, maxlen_in, size). 87 | torch.Tensor: Encoded memory mask (#batch, maxlen_in). 88 | 89 | """ 90 | residual = tgt 91 | if self.normalize_before: 92 | tgt = self.norm1(tgt) 93 | 94 | if cache is None: 95 | tgt_q = tgt 96 | tgt_q_mask = tgt_mask 97 | else: 98 | # compute only the last frame query keeping dim: max_time_out -> 1 99 | assert cache.shape == ( 100 | tgt.shape[0], 101 | tgt.shape[1] - 1, 102 | self.size, 103 | ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" 104 | tgt_q = tgt[:, -1:, :] 105 | residual = residual[:, -1:, :] 106 | tgt_q_mask = tgt_mask[:, -1:, :] 107 | 108 | x = residual + self.dropout( 109 | self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) 110 | if not self.normalize_before: 111 | x = self.norm1(x) 112 | 113 | if self.src_attn is not None: 114 | residual = x 115 | if self.normalize_before: 116 | x = self.norm2(x) 117 | x = residual + self.dropout( 118 | self.src_attn(x, memory, memory, memory_mask)[0]) 119 | if not self.normalize_before: 120 | x = self.norm2(x) 121 | 122 | residual = x 123 | if self.normalize_before: 124 | x = self.norm3(x) 125 | x = residual + self.dropout(self.feed_forward(x)) 126 | if not self.normalize_before: 127 | x = self.norm3(x) 128 | 129 | if cache is not None: 130 | x = torch.cat([cache, x], dim=1) 131 | 132 | return x, tgt_mask, memory, memory_mask 133 | -------------------------------------------------------------------------------- /inspiremusic/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Label smoothing module.""" 16 | 17 | import torch 18 | from torch import nn 19 | 20 | 21 | class LabelSmoothingLoss(nn.Module): 22 | """Label-smoothing loss. 23 | 24 | In a standard CE loss, the label's data distribution is: 25 | [0,1,2] -> 26 | [ 27 | [1.0, 0.0, 0.0], 28 | [0.0, 1.0, 0.0], 29 | [0.0, 0.0, 1.0], 30 | ] 31 | 32 | In the smoothing version CE Loss,some probabilities 33 | are taken from the true label prob (1.0) and are divided 34 | among other labels. 35 | 36 | e.g. 37 | smoothing=0.1 38 | [0,1,2] -> 39 | [ 40 | [0.9, 0.05, 0.05], 41 | [0.05, 0.9, 0.05], 42 | [0.05, 0.05, 0.9], 43 | ] 44 | 45 | Args: 46 | size (int): the number of class 47 | padding_idx (int): padding class id which will be ignored for loss 48 | smoothing (float): smoothing rate (0.0 means the conventional CE) 49 | normalize_length (bool): 50 | normalize loss by sequence length if True 51 | normalize loss by batch size if False 52 | """ 53 | 54 | def __init__(self, 55 | size: int, 56 | padding_idx: int, 57 | smoothing: float, 58 | normalize_length: bool = False): 59 | """Construct an LabelSmoothingLoss object.""" 60 | super(LabelSmoothingLoss, self).__init__() 61 | self.criterion = nn.KLDivLoss(reduction="none") 62 | self.padding_idx = padding_idx 63 | self.confidence = 1.0 - smoothing 64 | self.smoothing = smoothing 65 | self.size = size 66 | self.normalize_length = normalize_length 67 | 68 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 69 | """Compute loss between x and target. 70 | 71 | The model outputs and data labels tensors are flatten to 72 | (batch*seqlen, class) shape and a mask is applied to the 73 | padding part which should not be calculated for loss. 
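        A minimal shape sketch (values here are illustrative):

            >>> criterion = LabelSmoothingLoss(size=5, padding_idx=-1, smoothing=0.1)
            >>> x = torch.randn(2, 7, 5)                # (batch, seqlen, class)
            >>> target = torch.randint(0, 5, (2, 7))    # (batch, seqlen)
            >>> loss = criterion(x, target)             # scalar tensor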
74 | 75 | Args: 76 | x (torch.Tensor): prediction (batch, seqlen, class) 77 | target (torch.Tensor): 78 | target signal masked with self.padding_id (batch, seqlen) 79 | Returns: 80 | loss (torch.Tensor) : The KL loss, scalar float value 81 | """ 82 | assert x.size(2) == self.size 83 | batch_size = x.size(0) 84 | x = x.view(-1, self.size) 85 | target = target.view(-1) 86 | # use zeros_like instead of torch.no_grad() for true_dist, 87 | # since no_grad() can not be exported by JIT 88 | true_dist = torch.zeros_like(x) 89 | true_dist.fill_(self.smoothing / (self.size - 1)) 90 | ignore = target == self.padding_idx # (B,) 91 | 92 | total = len(target) - ignore.sum().item() 93 | target = target.masked_fill(ignore, 0) # avoid -1 index 94 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 95 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 96 | denom = total if self.normalize_length else batch_size 97 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 98 | -------------------------------------------------------------------------------- /inspiremusic/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Positionwise feed forward layer definition.""" 16 | 17 | import torch 18 | 19 | 20 | class PositionwiseFeedForward(torch.nn.Module): 21 | """Positionwise feed forward layer. 22 | 23 | FeedForward are appied on each position of the sequence. 24 | The output dim is same with the input dim. 25 | 26 | Args: 27 | idim (int): Input dimenstion. 28 | hidden_units (int): The number of hidden units. 29 | dropout_rate (float): Dropout rate. 30 | activation (torch.nn.Module): Activation function 31 | """ 32 | 33 | def __init__( 34 | self, 35 | idim: int, 36 | hidden_units: int, 37 | dropout_rate: float, 38 | activation: torch.nn.Module = torch.nn.ReLU(), 39 | ): 40 | """Construct a PositionwiseFeedForward object.""" 41 | super(PositionwiseFeedForward, self).__init__() 42 | self.w_1 = torch.nn.Linear(idim, hidden_units) 43 | self.activation = activation 44 | self.dropout = torch.nn.Dropout(dropout_rate) 45 | self.w_2 = torch.nn.Linear(hidden_units, idim) 46 | 47 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 48 | """Forward function. 49 | 50 | Args: 51 | xs: input tensor (B, L, D) 52 | Returns: 53 | output tensor, (B, L, D) 54 | """ 55 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 56 | 57 | 58 | class MoEFFNLayer(torch.nn.Module): 59 | """ 60 | Mixture of expert with Positionwise feed forward layer 61 | See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf 62 | The output dim is same with the input dim. 63 | 64 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 65 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 66 | Args: 67 | n_expert: number of expert. 
68 | n_expert_per_token: The actual number of experts used for each frame 69 | idim (int): Input dimenstion. 70 | hidden_units (int): The number of hidden units. 71 | dropout_rate (float): Dropout rate. 72 | activation (torch.nn.Module): Activation function 73 | """ 74 | 75 | def __init__( 76 | self, 77 | n_expert: int, 78 | n_expert_per_token: int, 79 | idim: int, 80 | hidden_units: int, 81 | dropout_rate: float, 82 | activation: torch.nn.Module = torch.nn.ReLU(), 83 | ): 84 | super(MoEFFNLayer, self).__init__() 85 | self.gate = torch.nn.Linear(idim, n_expert, bias=False) 86 | self.experts = torch.nn.ModuleList( 87 | PositionwiseFeedForward(idim, hidden_units, dropout_rate, 88 | activation) for _ in range(n_expert)) 89 | self.n_expert_per_token = n_expert_per_token 90 | 91 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 92 | """Foward function. 93 | Args: 94 | xs: input tensor (B, L, D) 95 | Returns: 96 | output tensor, (B, L, D) 97 | 98 | """ 99 | B, L, D = xs.size( 100 | ) # batch size, sequence length, embedding dimension (idim) 101 | xs = xs.view(-1, D) # (B*L, D) 102 | router = self.gate(xs) # (B*L, n_expert) 103 | logits, indices = torch.topk( 104 | router, self.n_expert_per_token 105 | ) # probs:(B*L, n_expert), indices: (B*L, n_expert) 106 | weights = torch.nn.functional.softmax( 107 | logits, dim=1, 108 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) 109 | output = torch.zeros_like(xs) # (B*L, D) 110 | for i, expert in enumerate(self.experts): 111 | mask = indices == i 112 | batch_idx, ith_expert = torch.where(mask) 113 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( 114 | xs[batch_idx]) 115 | return output.view(B, L, D) 116 | -------------------------------------------------------------------------------- /inspiremusic/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/inspiremusic/utils/__init__.py -------------------------------------------------------------------------------- /inspiremusic/utils/binary.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """Raw binary format for Encodec compressed audio. Actual compression API is in `encodec.compress`.""" 7 | import io 8 | import json 9 | import struct 10 | import typing as tp 11 | 12 | # format is `ECDC` magic code, followed by the header size as uint32. 13 | # Then an uint8 indicates the protocol version (0.) 14 | # The header is then provided as json and should contain all required 15 | # informations for decoding. A raw stream of bytes is then provided 16 | # and should be interpretable using the json header. 
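# Concretely, struct format '!4sBI' yields a fixed 4 + 1 + 4 = 9 byte prefix
# (magic, version, metadata length) ahead of the JSON metadata. A minimal
# round-trip sketch of that layout, relying on the write_ecdc_header /
# read_ecdc_header helpers defined further down:
def _example_ecdc_header_roundtrip():
    buf = io.BytesIO()
    metadata = {'model': 'encodec_24khz', 'channels': 1}  # illustrative keys only
    write_ecdc_header(buf, metadata)
    buf.seek(0)
    assert read_ecdc_header(buf) == metadata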
17 | _encodec_header_struct = struct.Struct('!4sBI') 18 | _ENCODEC_MAGIC = b'ECDC' 19 | 20 | 21 | def write_ecdc_header(fo: tp.IO[bytes], metadata: tp.Any): 22 | meta_dumped = json.dumps(metadata).encode('utf-8') 23 | version = 0 24 | header = _encodec_header_struct.pack(_ENCODEC_MAGIC, version, 25 | len(meta_dumped)) 26 | fo.write(header) 27 | fo.write(meta_dumped) 28 | fo.flush() 29 | 30 | 31 | def _read_exactly(fo: tp.IO[bytes], size: int) -> bytes: 32 | buf = b"" 33 | while len(buf) < size: 34 | new_buf = fo.read(size) 35 | if not new_buf: 36 | raise EOFError("Impossible to read enough data from the stream, " 37 | f"{size} bytes remaining.") 38 | buf += new_buf 39 | size -= len(new_buf) 40 | return buf 41 | 42 | 43 | def read_ecdc_header(fo: tp.IO[bytes]): 44 | header_bytes = _read_exactly(fo, _encodec_header_struct.size) 45 | magic, version, meta_size = _encodec_header_struct.unpack(header_bytes) 46 | if magic != _ENCODEC_MAGIC: 47 | raise ValueError("File is not in ECDC format.") 48 | if version != 0: 49 | raise ValueError("Version not supported.") 50 | meta_bytes = _read_exactly(fo, meta_size) 51 | return json.loads(meta_bytes.decode('utf-8')) 52 | 53 | 54 | class BitPacker: 55 | """Simple bit packer to handle ints with a non standard width, e.g. 10 bits. 56 | Note that for some bandwidth (1.5, 3), the codebook representation 57 | will not cover an integer number of bytes. 58 | 59 | Args: 60 | bits (int): number of bits per value that will be pushed. 61 | fo (IO[bytes]): file-object to push the bytes to. 62 | """ 63 | 64 | def __init__(self, bits: int, fo: tp.IO[bytes]): 65 | self._current_value = 0 66 | self._current_bits = 0 67 | self.bits = bits 68 | self.fo = fo 69 | 70 | def push(self, value: int): 71 | """Push a new value to the stream. This will immediately 72 | write as many uint8 as possible to the underlying file-object.""" 73 | self._current_value += (value << self._current_bits) 74 | self._current_bits += self.bits 75 | while self._current_bits >= 8: 76 | lower_8bits = self._current_value & 0xff 77 | self._current_bits -= 8 78 | self._current_value >>= 8 79 | self.fo.write(bytes([lower_8bits])) 80 | 81 | def flush(self): 82 | """Flushes the remaining partial uint8, call this at the end 83 | of the stream to encode.""" 84 | if self._current_bits: 85 | self.fo.write(bytes([self._current_value])) 86 | self._current_value = 0 87 | self._current_bits = 0 88 | self.fo.flush() 89 | 90 | 91 | class BitUnpacker: 92 | """BitUnpacker does the opposite of `BitPacker`. 93 | 94 | Args: 95 | bits (int): number of bits of the values to decode. 96 | fo (IO[bytes]): file-object to push the bytes to. 97 | """ 98 | 99 | def __init__(self, bits: int, fo: tp.IO[bytes]): 100 | self.bits = bits 101 | self.fo = fo 102 | self._mask = (1 << bits) - 1 103 | self._current_value = 0 104 | self._current_bits = 0 105 | 106 | def pull(self) -> tp.Optional[int]: 107 | """ 108 | Pull a single value from the stream, potentially reading some 109 | extra bytes from the underlying file-object. 110 | Returns `None` when reaching the end of the stream. 
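        Example of the packer/unpacker pair round-tripping three 10-bit values
        (a small sketch; 30 bits of payload come back out of 4 written bytes):

            >>> buf = io.BytesIO()
            >>> packer = BitPacker(bits=10, fo=buf)
            >>> for value in [3, 1023, 512]:
            ...     packer.push(value)
            >>> packer.flush()
            >>> _ = buf.seek(0)
            >>> unpacker = BitUnpacker(bits=10, fo=buf)
            >>> [unpacker.pull() for _ in range(3)]
            [3, 1023, 512]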
111 | """ 112 | while self._current_bits < self.bits: 113 | buf = self.fo.read(1) 114 | if not buf: 115 | return None 116 | character = buf[0] 117 | self._current_value += character << self._current_bits 118 | self._current_bits += 8 119 | 120 | out = self._current_value & self._mask 121 | self._current_value >>= self.bits 122 | self._current_bits -= self.bits 123 | return out 124 | 125 | 126 | def test(): 127 | import torch 128 | torch.manual_seed(1234) 129 | for rep in range(4): 130 | length: int = torch.randint(10, 2_000, (1, )).item() 131 | bits: int = torch.randint(1, 16, (1, )).item() 132 | tokens: tp.List[int] = torch.randint(2**bits, (length, )).tolist() 133 | rebuilt: tp.List[int] = [] 134 | buf = io.BytesIO() 135 | packer = BitPacker(bits, buf) 136 | for token in tokens: 137 | packer.push(token) 138 | packer.flush() 139 | buf.seek(0) 140 | unpacker = BitUnpacker(bits, buf) 141 | while True: 142 | value = unpacker.pull() 143 | if value is None: 144 | break 145 | rebuilt.append(value) 146 | assert len(rebuilt) >= len(tokens), (len(rebuilt), len(tokens)) 147 | # The flushing mechanism might lead to "ghost" values at the end of the stream. 148 | assert len(rebuilt) <= len(tokens) + 8 // bits, (len(rebuilt), 149 | len(tokens), bits) 150 | for idx, (a, b) in enumerate(zip(tokens, rebuilt)): 151 | assert a == b, (idx, a, b) 152 | 153 | 154 | if __name__ == '__main__': 155 | test() 156 | -------------------------------------------------------------------------------- /inspiremusic/utils/class_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright [2023-11-28] 2 | # 2024 Alibaba Inc 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | import torch 16 | 17 | from inspiremusic.transformer.activation import Swish 18 | from inspiremusic.transformer.subsampling import ( 19 | LinearNoSubsampling, 20 | EmbedinigNoSubsampling, 21 | Conv1dSubsampling2, 22 | Conv2dSubsampling4, 23 | Conv2dSubsampling6, 24 | Conv2dSubsampling8, 25 | ) 26 | from inspiremusic.transformer.embedding import (PositionalEncoding, 27 | RelPositionalEncoding, 28 | WhisperPositionalEncoding, 29 | LearnablePositionalEncoding, 30 | NoPositionalEncoding) 31 | from inspiremusic.transformer.attention import (MultiHeadedAttention, 32 | RelPositionMultiHeadedAttention) 33 | from inspiremusic.transformer.embedding import EspnetRelPositionalEncoding 34 | from inspiremusic.transformer.subsampling import LegacyLinearNoSubsampling 35 | 36 | 37 | INSPIREMUSIC_ACTIVATION_CLASSES = { 38 | "hardtanh": torch.nn.Hardtanh, 39 | "tanh": torch.nn.Tanh, 40 | "relu": torch.nn.ReLU, 41 | "selu": torch.nn.SELU, 42 | "swish": getattr(torch.nn, "SiLU", Swish), 43 | "gelu": torch.nn.GELU, 44 | } 45 | 46 | INSPIREMUSIC_SUBSAMPLE_CLASSES = { 47 | "linear": LinearNoSubsampling, 48 | "linear_legacy": LegacyLinearNoSubsampling, 49 | "embed": EmbedinigNoSubsampling, 50 | "conv1d2": Conv1dSubsampling2, 51 | "conv2d": Conv2dSubsampling4, 52 | "conv2d6": Conv2dSubsampling6, 53 | "conv2d8": Conv2dSubsampling8, 54 | 'paraformer_dummy': torch.nn.Identity 55 | } 56 | 57 | INSPIREMUSIC_EMB_CLASSES = { 58 | "embed": PositionalEncoding, 59 | "abs_pos": PositionalEncoding, 60 | "rel_pos": RelPositionalEncoding, 61 | "rel_pos_espnet": EspnetRelPositionalEncoding, 62 | "no_pos": NoPositionalEncoding, 63 | "abs_pos_whisper": WhisperPositionalEncoding, 64 | "embed_learnable_pe": LearnablePositionalEncoding, 65 | } 66 | 67 | INSPIREMUSIC_ATTENTION_CLASSES = { 68 | "selfattn": MultiHeadedAttention, 69 | "rel_selfattn": RelPositionMultiHeadedAttention, 70 | } 71 | 72 | -------------------------------------------------------------------------------- /inspiremusic/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
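# rich_captions() below builds the structured caption string attached to each
# training sample, wrapping every field in <|...|> markers. For example
# (illustrative values):
#   rich_captions(text="uplifting synth pop", chorus="chorus", end_time=27.5)
#   returns '<|0.0|><|chorus|><|uplifting synth pop|><|27.5|>'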
14 | from torch.utils.data import DataLoader 15 | from inspiremusic.dataset.dataset import Dataset 16 | import numpy as np 17 | import librosa 18 | 19 | def audio_process_dataset_and_dataloader(args, configs): 20 | input_dataset = Dataset(args.input_data, data_pipeline=configs['data_pipeline'], mode='processing', shuffle=True, partition=True) 21 | # do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts 22 | input_data_loader = DataLoader(input_dataset, 23 | batch_size=None, 24 | pin_memory=args.pin_memory, 25 | num_workers=args.num_workers, 26 | prefetch_factor=args.prefetch) 27 | return input_dataset, input_data_loader 28 | 29 | def is_silent(wav_path, threshold=0.01, frame_length=2048, hop_length=512): 30 | y, sr = librosa.load(wav_path, sr=None) 31 | rms = librosa.feature.rms(y=y, frame_length=frame_length, hop_length=hop_length)[0] 32 | silent_frames = np.sum(rms < threshold) / len(rms) 33 | silence_fraction_threshold = 0.95 34 | return silent_frames >= silence_fraction_threshold 35 | 36 | def rich_captions(text=None, tags=None, lyrics=None, chorus="verse", start_time=0.0, end_time=30.0): 37 | if text is None and tags is None and lyrics is None: 38 | return None 39 | else: 40 | if start_time is None: 41 | start_time = 0.0 42 | if end_time is None: 43 | end_time = 30.0 44 | if chorus is None: 45 | chorus = "verse" 46 | captions = f"<|{start_time:.1f}|><|{chorus}|>" 47 | if tags is not None: 48 | captions += f"<|{tags}|>" 49 | if text is not None: 50 | captions += f"<|{text}|>" 51 | if lyrics is not None: 52 | captions += f"<|lyrics|><|{lyrics}|>" 53 | captions += f"<|{end_time:.1f}|>" 54 | return captions 55 | 56 | def process_tags(infile, outfile, timefile = None): 57 | key_list = [] 58 | with open(infile, "r") as f: 59 | for line in f: 60 | sec = line.strip() 61 | key_list.append(sec) 62 | f.close() 63 | if timefile is None: 64 | with open(outfile, 'w') as f: 65 | for k in key_list: 66 | parts = k.rsplit('_', 1) 67 | text = parts[0].replace('_', ' ') + ', ' + parts[1] 68 | caption = rich_captions(text, None, None) 69 | if caption is not None: 70 | f.write("%s\t%s\n" %(k, caption)) 71 | f.close() 72 | else: 73 | times = {} 74 | with open(timefile, "r") as f: 75 | for line in f: 76 | sec = line.strip().split("\t") 77 | if len(sec) == 2 : 78 | times[sec[0]] = sec[1] 79 | f.close() 80 | 81 | with open(outfile, 'w') as f: 82 | for k in key_list: 83 | parts = k.rsplit('_', 1) 84 | text = parts[0].replace('_', ' ') + ', ' + parts[1] 85 | if k in times.keys(): 86 | caption = rich_captions(text, None, None, "verse", 0.0, float(times[k])) 87 | if caption is not None: 88 | f.write("%s\t%s\n" %(k, caption)) 89 | f.close() 90 | 91 | def process_trans(infile, outfile): 92 | trans = {} 93 | with open(infile, "r") as f: 94 | for line in f: 95 | sec = line.strip().split("\t") 96 | if len(sec) == 2: 97 | trans[sec[0]] = sec[1] 98 | else: 99 | print(line) 100 | f.close() 101 | with open(outfile, 'w') as f: 102 | for k, v in trans.items(): 103 | f.write("%s\t%s\n" %(k, rich_captions(v))) 104 | f.close() -------------------------------------------------------------------------------- /inspiremusic/utils/executor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 2024 Alibaba Inc 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | import logging
17 | from contextlib import nullcontext
18 | import os
19 | 
20 | import torch
21 | import torch.distributed as dist
22 | 
23 | from inspiremusic.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, inspiremusic_join
24 | from torch.amp import autocast
25 | 
26 | class Executor:
27 |     def __init__(self):
28 |         self.step = 0
29 |         self.epoch = 0
30 |         self.rank = int(os.environ.get('RANK', 0))
31 |         # pick the fastest accelerator available to this rank
32 |         if torch.cuda.is_available():
33 |             self.device = torch.device('cuda:{}'.format(self.rank))
34 |         elif torch.backends.mps.is_available():
35 |             self.device = torch.device('mps')
36 |         elif torch.xpu.is_available():
37 |             self.device = torch.device('xpu')
38 |         else:
39 |             self.device = torch.device('cpu')
40 |     def train_one_epoch(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=None):
41 |         ''' Train one epoch
42 |         '''
43 | 
44 |         lr = optimizer.param_groups[0]['lr']
45 |         logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
46 |         logging.info('using accumulate grad, new batch size is {} times'
47 |                      ' larger than before'.format(info_dict['accum_grad']))
48 |         # A context manager to be used in conjunction with an instance of
49 |         # torch.nn.parallel.DistributedDataParallel to be able to train
50 |         # with uneven inputs across participating processes.
51 |         model.train()
52 |         model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
53 |         with model_context():
54 |             for batch_idx, batch_dict in enumerate(train_data_loader):
55 |                 info_dict["tag"] = "TRAIN"
56 |                 info_dict["step"] = self.step
57 |                 info_dict["epoch"] = self.epoch
58 |                 info_dict["batch_idx"] = batch_idx
59 |                 if inspiremusic_join(group_join, info_dict):
60 |                     break
61 | 
62 |                 # Disable gradient synchronizations across DDP processes.
63 |                 # Within this context, gradients will be accumulated on module
64 |                 # variables, which will later be synchronized.
65 |                 if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
66 |                     context = model.no_sync
67 |                 # Used for single gpu training and DDP gradient synchronization
68 |                 # processes.
69 | else: 70 | context = nullcontext 71 | 72 | with context(): 73 | with autocast(device_type='cuda', enabled=scaler is not None): 74 | info_dict = batch_forward(model, batch_dict, info_dict, scaler) 75 | info_dict = batch_backward(model, info_dict, scaler) 76 | 77 | info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict, scaler) 78 | log_per_step(writer, info_dict) 79 | # NOTE specify save_per_step in inspiremusic.yaml if you want to enable step save 80 | if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ 81 | (batch_idx + 1) % info_dict["accum_grad"] == 0: 82 | dist.barrier() 83 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False, scaler=scaler) 84 | model.train() 85 | if (batch_idx + 1) % info_dict["accum_grad"] == 0: 86 | self.step += 1 87 | dist.barrier() 88 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True, scaler=scaler) 89 | 90 | @torch.inference_mode() 91 | def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True, capped_at=5, scaler=None): 92 | ''' Cross validation on 93 | ''' 94 | logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank)) 95 | model.eval() 96 | total_num_utts, total_loss_dict = 0, {} # avoid division by 0 97 | stop = capped_at 98 | for batch_idx, batch_dict in enumerate(cv_data_loader): 99 | info_dict["tag"] = "CV" 100 | info_dict["step"] = self.step 101 | info_dict["epoch"] = self.epoch 102 | info_dict["batch_idx"] = batch_idx 103 | 104 | num_utts = len(batch_dict["utts"]) 105 | total_num_utts += num_utts 106 | 107 | if capped_at>0: 108 | if stop <= 0: 109 | continue 110 | else: 111 | stop -= 1 112 | 113 | with autocast(device_type='cuda', enabled=scaler is not None): 114 | info_dict = batch_forward(model, batch_dict, info_dict, scaler) 115 | 116 | for k, v in info_dict['loss_dict'].items(): 117 | if k not in total_loss_dict: 118 | total_loss_dict[k] = [] 119 | total_loss_dict[k].append(v.item() * num_utts) 120 | log_per_step(None, info_dict) 121 | 122 | for k, v in total_loss_dict.items(): 123 | total_loss_dict[k] = sum(v) / total_num_utts 124 | info_dict['loss_dict'] = total_loss_dict 125 | log_per_save(writer, info_dict) 126 | model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1) 127 | save_model(model, model_name, info_dict) 128 | -------------------------------------------------------------------------------- /inspiremusic/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 2024 Alibaba Inc 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
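# The readers below expect Kaldi-style listing files: read_scp() splits each line
# on a space ("<utt_id> <wav_path>"), read_trans() splits on a tab
# ("<utt_id><TAB><caption>"), and read_lists() returns the raw stripped lines.
# Illustrative entries (paths and captions are placeholders):
#   utt_0001 /path/to/utt_0001.wav
#   utt_0001<TAB><|0.0|><|verse|><|slow piano ballad|><|30.0|>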
15 | 
16 | import json
17 | import torchaudio
18 | import logging
19 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
20 | logging.basicConfig(level=logging.DEBUG,
21 |                     format='%(asctime)s %(levelname)s %(message)s')
22 | 
23 | def read_trans(list_file):
24 |     trans = {}
25 |     with open(list_file, 'r', encoding='utf8') as fin:
26 |         for line in fin:
27 |             sec = line.strip().split("\t")
28 |             if len(sec) > 1:
29 |                 if sec[0] not in trans.keys():
30 |                     trans[sec[0]] = sec[1]
31 |     return trans
32 | 
33 | def read_scp(list_file):
34 |     scp = {}
35 |     with open(list_file, 'r', encoding='utf8') as fin:
36 |         for line in fin:
37 |             sec = line.strip().split(" ")
38 |             if len(sec) > 1:
39 |                 if sec[0] not in scp.keys():
40 |                     scp[sec[0]] = sec[1]
41 |     return scp
42 | 
43 | def read_lists(list_file):
44 |     lists = []
45 |     with open(list_file, 'r', encoding='utf8') as fin:
46 |         for line in fin:
47 |             lists.append(line.strip())
48 |     return lists
49 | 
50 | 
51 | def read_json_lists(list_file):
52 |     lists = read_lists(list_file)
53 |     results = {}
54 |     for fn in lists:
55 |         with open(fn, 'r', encoding='utf8') as fin:
56 |             results.update(json.load(fin))
57 |     return results
58 | 
59 | 
60 | def load_wav(wav, target_sr):
61 |     audio, sample_rate = torchaudio.load(wav)
62 |     audio = audio.mean(dim=0, keepdim=True)  # downmix to mono, keep shape (1, T)
63 |     if sample_rate != target_sr:
64 |         assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
65 |         audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(audio)
66 |     return audio
67 | 
68 | 
69 | def speed_change(waveform, sample_rate, speed_factor: str):
70 |     effects = [
71 |         ["tempo", speed_factor],  # speed_factor
72 |         ["rate", f"{sample_rate}"]
73 |     ]
74 |     augmented_waveform, new_sample_rate = torchaudio.sox_effects.apply_effects_tensor(
75 |         waveform,
76 |         sample_rate,
77 |         effects
78 |     )
79 |     return augmented_waveform, new_sample_rate
80 | 
--------------------------------------------------------------------------------
/inspiremusic/utils/frontend_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
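# The helpers below normalize prompt text before tokenization: contains_chinese()
# decides which length and punctuation rules apply, spell_out_number() expands
# digits through an inflect parser (e.g. "24" becomes "twenty-four"), and
# split_paragraph() cuts long text at sentence punctuation while keeping each
# chunk within token_max_n tokens.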
14 | 15 | import re 16 | chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+') 17 | 18 | 19 | # whether contain chinese character 20 | def contains_chinese(text): 21 | return bool(chinese_char_pattern.search(text)) 22 | 23 | 24 | # replace special symbol 25 | def replace_corner_mark(text): 26 | text = text.replace('²', '平方') 27 | text = text.replace('³', '立方') 28 | return text 29 | 30 | 31 | # remove meaningless symbol 32 | def remove_bracket(text): 33 | text = text.replace('(', '').replace(')', '') 34 | text = text.replace('【', '').replace('】', '') 35 | text = text.replace('`', '').replace('`', '') 36 | text = text.replace("——", " ") 37 | return text 38 | 39 | 40 | # spell Arabic numerals 41 | def spell_out_number(text: str, inflect_parser): 42 | new_text = [] 43 | st = None 44 | for i, c in enumerate(text): 45 | if not c.isdigit(): 46 | if st is not None: 47 | num_str = inflect_parser.number_to_words(text[st: i]) 48 | new_text.append(num_str) 49 | st = None 50 | new_text.append(c) 51 | else: 52 | if st is None: 53 | st = i 54 | if st is not None and st < len(text): 55 | num_str = inflect_parser.number_to_words(text[st:]) 56 | new_text.append(num_str) 57 | return ''.join(new_text) 58 | 59 | 60 | # split paragrah logic: 61 | # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len 62 | # 2. cal sentence len according to lang 63 | # 3. split sentence according to puncatation 64 | def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False): 65 | def calc_utt_length(_text: str): 66 | if lang == "zh": 67 | return len(_text) 68 | else: 69 | return len(tokenize(_text)) 70 | 71 | def should_merge(_text: str): 72 | if lang == "zh": 73 | return len(_text) < merge_len 74 | else: 75 | return len(tokenize(_text)) < merge_len 76 | 77 | if lang == "zh": 78 | pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';'] 79 | else: 80 | pounc = ['.', '?', '!', ';', ':'] 81 | if comma_split: 82 | pounc.extend([',', ',']) 83 | st = 0 84 | utts = [] 85 | for i, c in enumerate(text): 86 | if c in pounc: 87 | if len(text[st: i]) > 0: 88 | utts.append(text[st: i] + c) 89 | if i + 1 < len(text) and text[i + 1] in ['"', '”']: 90 | tmp = utts.pop(-1) 91 | utts.append(tmp + text[i + 1]) 92 | st = i + 2 93 | else: 94 | st = i + 1 95 | if len(utts) == 0: 96 | if lang == "zh": 97 | utts.append(text + '。') 98 | else: 99 | utts.append(text + '.') 100 | final_utts = [] 101 | cur_utt = "" 102 | for utt in utts: 103 | if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n: 104 | final_utts.append(cur_utt) 105 | cur_utt = "" 106 | cur_utt = cur_utt + utt 107 | if len(cur_utt) > 0: 108 | if should_merge(cur_utt) and len(final_utts) != 0: 109 | final_utts[-1] = final_utts[-1] + cur_utt 110 | else: 111 | final_utts.append(cur_utt) 112 | 113 | return final_utts 114 | 115 | 116 | # remove blank between chinese character 117 | def replace_blank(text: str): 118 | out_str = [] 119 | for i, c in enumerate(text): 120 | if c == " ": 121 | if ((text[i + 1].isascii() and text[i + 1] != " ") and 122 | (text[i - 1].isascii() and text[i - 1] != " ")): 123 | out_str.append(c) 124 | else: 125 | out_str.append(c) 126 | return "".join(out_str) 127 | -------------------------------------------------------------------------------- /inspiremusic/utils/hinter.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch.distributed 3 | import logging 4 | 
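# hint_once() below is intended for hot paths such as per-batch loops: the first
# call with a given uid logs the message, later calls with the same uid are
# no-ops, and passing rank=0 restricts logging to the rank-0 worker when
# torch.distributed is initialized. A typical call (message text is illustrative):
#   hint_once("flow cache disabled for streaming inference", uid="flow_cache", rank=0)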
5 | HINTED = set() 6 | 7 | 8 | def hint_once(content, uid, rank=None): 9 | if (rank is None) or (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == rank: 10 | if uid not in HINTED: 11 | logging.info(content, stacklevel=3) 12 | HINTED.add(uid) -------------------------------------------------------------------------------- /inspiremusic/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def tpr_loss(disc_real_outputs, disc_generated_outputs, tau): 6 | loss = 0 7 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 8 | m_DG = torch.median((dr - dg)) 9 | L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) 10 | loss += tau - F.relu(tau - L_rel) 11 | return loss 12 | 13 | 14 | def mel_loss(real_speech, generated_speech, mel_transforms): 15 | loss = 0 16 | for transform in mel_transforms: 17 | mel_r = transform(real_speech) 18 | mel_g = transform(generated_speech) 19 | loss += F.l1_loss(mel_g, mel_r) 20 | return loss 21 | -------------------------------------------------------------------------------- /inspiremusic/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | 5 | def download_model(repo_url: str, output_dir: str = None, token: str = None): 6 | try: 7 | if token: 8 | repo_url = repo_url.replace("https://", f"https://USER:{token}@") 9 | else: 10 | repo_url = f"https://www.modelscope.cn/models/iic/{repo_url}" 11 | 12 | cmd = ["git", "clone", repo_url] 13 | if output_dir: 14 | cmd.append(output_dir) 15 | 16 | result = subprocess.run( 17 | cmd, 18 | check=True, 19 | capture_output=True, 20 | text=True 21 | ) 22 | print("Success:", result.stdout) 23 | except subprocess.CalledProcessError as e: 24 | print("Error:", e.stderr) 25 | 26 | def align_trans_scp_file(trans, scp): 27 | trans_dict = {} 28 | with open(trans, 'r') as f: 29 | for line in f: 30 | sec = line.strip().split("\t") 31 | trans_dict[sec[0]] = sec[1] 32 | scp_dict = {} 33 | with open(scp, 'r') as f: 34 | for line in f: 35 | sec = line.strip().split(" ") 36 | scp_dict[sec[0]] = sec[1] 37 | with open("text", "w") as f: 38 | for k, v in scp_dict.items(): 39 | f.write("%s\t%s\n"%(k,trans_dict[k])) -------------------------------------------------------------------------------- /inspiremusic/version.txt: -------------------------------------------------------------------------------- 1 | v0.1 -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/inspiremusic/wavtokenizer/__init__.py -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/decoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/InspireMusic/0aefb55b58b8df91accdc4aed20c99e44741e0a7/inspiremusic/wavtokenizer/decoder/__init__.py -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/decoder/dataset.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import numpy as np 4 | import torch 5 | import torchaudio 6 | 
from pytorch_lightning import LightningDataModule 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | import soundfile 10 | # import librosa 11 | import random 12 | 13 | torch.set_num_threads(1) 14 | 15 | 16 | @dataclass 17 | class DataConfig: 18 | filelist_path: str 19 | sampling_rate: int 20 | num_samples: int 21 | batch_size: int 22 | num_workers: int 23 | 24 | def collate_fn(batch): 25 | batch = [item for item in batch if item is not None] 26 | return torch.stack(batch, dim=0) 27 | 28 | class VocosDataModule(LightningDataModule): 29 | def __init__(self, train_params: DataConfig, val_params: DataConfig): 30 | super().__init__() 31 | self.train_config = train_params 32 | self.val_config = val_params 33 | 34 | def _get_dataloder(self, cfg: DataConfig, train: bool): 35 | dataset = VocosDataset(cfg, train=train) 36 | dataloader = DataLoader( 37 | dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers, shuffle=train, pin_memory=True, collate_fn=collate_fn 38 | ) 39 | return dataloader 40 | 41 | def train_dataloader(self) -> DataLoader: 42 | return self._get_dataloder(self.train_config, train=True) 43 | 44 | def val_dataloader(self) -> DataLoader: 45 | return self._get_dataloder(self.val_config, train=False) 46 | 47 | 48 | class VocosDataset(Dataset): 49 | def __init__(self, cfg: DataConfig, train: bool): 50 | with open(cfg.filelist_path) as f: 51 | self.filelist = f.read().splitlines() 52 | self.sampling_rate = cfg.sampling_rate 53 | self.num_samples = cfg.num_samples 54 | self.train = train 55 | 56 | def __len__(self) -> int: 57 | return len(self.filelist) 58 | 59 | def __getitem__(self, index: int) -> torch.Tensor: 60 | audio_path = self.filelist[index] 61 | # y, sr = torchaudio.load(audio_path) 62 | # print(audio_path,"111") 63 | try: 64 | y1, sr = soundfile.read(audio_path) 65 | # y1, sr = librosa.load(audio_path,sr=None) 66 | y = torch.tensor(y1).float().unsqueeze(0) 67 | # if y.size(0) > 1: 68 | # # mix to mono 69 | # y = y.mean(dim=0, keepdim=True) 70 | if y.ndim > 2: 71 | # mix to mono 72 | # print("有问题哈,数据处理部分") 73 | # y = y.mean(dim=-1, keepdim=False) 74 | random_channel = random.randint(0, y.size(-1) - 1) 75 | y = y[:, :, random_channel] 76 | 77 | gain = np.random.uniform(-1, -6) if self.train else -3 78 | y, _ = torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]]) 79 | if sr != self.sampling_rate: 80 | y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate) 81 | if y.size(-1) < self.num_samples: 82 | pad_length = self.num_samples - y.size(-1) 83 | padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1)) 84 | y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1) 85 | elif self.train: 86 | start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1) 87 | y = y[:, start : start + self.num_samples] 88 | else: 89 | # During validation, take always the first segment for determinism 90 | y = y[:, : self.num_samples] 91 | 92 | return y[0] 93 | except Exception as e: 94 | print(f"Error processing file {audio_path} at index {index}: {e}") 95 | # 这里可以继续选择抛出异常,或者返回一个 None 表示无效数据 96 | return None 97 | 98 | # def __getitem__(self, index: int) -> torch.Tensor: 99 | # audio_path = self.filelist[index] 100 | # try: 101 | # y, sr = torchaudio.load(audio_path) 102 | # if y.size(0) > 1: 103 | # # 随机选择一个通道 104 | # random_channel = random.randint(0, y.size(0) - 1) 105 | # y = y[random_channel, :].unsqueeze(0) # 保持返回值为 (1, T) 的形式 106 | # # gain = np.random.uniform(-1, -6) if self.train else -3 107 | # # y, _ = 
torchaudio.sox_effects.apply_effects_tensor(y, sr, [["norm", f"{gain:.2f}"]]) 108 | # if sr != self.sampling_rate: 109 | # y = torchaudio.functional.resample(y, orig_freq=sr, new_freq=self.sampling_rate) 110 | # if y.size(-1) < self.num_samples: 111 | # pad_length = self.num_samples - y.size(-1) 112 | # padding_tensor = y.repeat(1, 1 + pad_length // y.size(-1)) 113 | # y = torch.cat((y, padding_tensor[:, :pad_length]), dim=1) 114 | # elif self.train: 115 | # start = np.random.randint(low=0, high=y.size(-1) - self.num_samples + 1) 116 | # y = y[:, start: start + self.num_samples] 117 | # else: 118 | # # During validation, take always the first segment for determinism 119 | # y = y[:, :self.num_samples] 120 | # return y[0] 121 | # except Exception as e: 122 | # print(f"Error processing file {audio_path} at index {index}: {e}") 123 | # # 这里可以继续选择抛出异常,或者返回一个 None 表示无效数据 124 | # return None -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/decoder/helpers.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import torch 4 | from matplotlib import pyplot as plt 5 | from pytorch_lightning import Callback 6 | 7 | matplotlib.use("Agg") 8 | 9 | 10 | def save_figure_to_numpy(fig: plt.Figure) -> np.ndarray: 11 | """ 12 | Save a matplotlib figure to a numpy array. 13 | 14 | Args: 15 | fig (Figure): Matplotlib figure object. 16 | 17 | Returns: 18 | ndarray: Numpy array representing the figure. 19 | """ 20 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 21 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 22 | return data 23 | 24 | 25 | def plot_spectrogram_to_numpy(spectrogram: np.ndarray) -> np.ndarray: 26 | """ 27 | Plot a spectrogram and convert it to a numpy array. 28 | 29 | Args: 30 | spectrogram (ndarray): Spectrogram data. 31 | 32 | Returns: 33 | ndarray: Numpy array representing the plotted spectrogram. 34 | """ 35 | spectrogram = spectrogram.astype(np.float32) 36 | fig, ax = plt.subplots(figsize=(12, 3)) 37 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 38 | plt.colorbar(im, ax=ax) 39 | plt.xlabel("Frames") 40 | plt.ylabel("Channels") 41 | plt.tight_layout() 42 | 43 | fig.canvas.draw() 44 | data = save_figure_to_numpy(fig) 45 | plt.close() 46 | return data 47 | 48 | 49 | class GradNormCallback(Callback): 50 | """ 51 | Callback to log the gradient norm. 52 | """ 53 | 54 | def on_after_backward(self, trainer, model): 55 | model.log("grad_norm", gradient_norm(model)) 56 | 57 | 58 | def gradient_norm(model: torch.nn.Module, norm_type: float = 2.0) -> torch.Tensor: 59 | """ 60 | Compute the gradient norm. 61 | 62 | Args: 63 | model (Module): PyTorch model. 64 | norm_type (float, optional): Type of the norm. Defaults to 2.0. 65 | 66 | Returns: 67 | Tensor: Gradient norm. 
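    Example (a small sketch; the returned value depends on the random input):

        >>> layer = torch.nn.Linear(4, 2)
        >>> layer(torch.randn(8, 4)).sum().backward()
        >>> total_norm = gradient_norm(layer)   # L2 norm over weight and bias grads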
68 | """ 69 | grads = [p.grad for p in model.parameters() if p.grad is not None] 70 | total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type) 71 | return total_norm 72 | -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/decoder/loss.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | import torchaudio 5 | from torch import nn 6 | 7 | from decoder.modules import safe_log 8 | 9 | import torch.nn.functional as F 10 | 11 | 12 | class MelSpecReconstructionLoss(nn.Module): 13 | """ 14 | L1 distance between the mel-scaled magnitude spectrograms of the ground truth sample and the generated sample 15 | """ 16 | 17 | def __init__( 18 | self, sample_rate: int = 24000, n_fft: int = 1024, hop_length: int = 256, n_mels: int = 100, 19 | ): 20 | super().__init__() 21 | self.mel_spec = torchaudio.transforms.MelSpectrogram( 22 | sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels, center=True, power=1, 23 | ) 24 | 25 | def forward(self, y_hat, y) -> torch.Tensor: 26 | """ 27 | Args: 28 | y_hat (Tensor): Predicted audio waveform. 29 | y (Tensor): Ground truth audio waveform. 30 | 31 | Returns: 32 | Tensor: L1 loss between the mel-scaled magnitude spectrograms. 33 | """ 34 | mel_hat = safe_log(self.mel_spec(y_hat)) 35 | mel = safe_log(self.mel_spec(y)) 36 | 37 | loss = torch.nn.functional.l1_loss(mel, mel_hat) 38 | 39 | return loss 40 | 41 | 42 | class GeneratorLoss(nn.Module): 43 | """ 44 | Generator Loss module. Calculates the loss for the generator based on discriminator outputs. 45 | """ 46 | 47 | def forward(self, disc_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]: 48 | """ 49 | Args: 50 | disc_outputs (List[Tensor]): List of discriminator outputs. 51 | 52 | Returns: 53 | Tuple[Tensor, List[Tensor]]: Tuple containing the total loss and a list of loss values from 54 | the sub-discriminators 55 | """ 56 | loss = 0 57 | gen_losses = [] 58 | for dg in disc_outputs: 59 | l = torch.mean(torch.clamp(1 - dg, min=0)) 60 | gen_losses.append(l) 61 | loss += l 62 | 63 | return loss, gen_losses 64 | 65 | 66 | class DiscriminatorLoss(nn.Module): 67 | """ 68 | Discriminator Loss module. Calculates the loss for the discriminator based on real and generated outputs. 69 | """ 70 | 71 | def forward( 72 | self, disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor] 73 | ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: 74 | """ 75 | Args: 76 | disc_real_outputs (List[Tensor]): List of discriminator outputs for real samples. 77 | disc_generated_outputs (List[Tensor]): List of discriminator outputs for generated samples. 78 | 79 | Returns: 80 | Tuple[Tensor, List[Tensor], List[Tensor]]: A tuple containing the total loss, a list of loss values from 81 | the sub-discriminators for real outputs, and a list of 82 | loss values for generated outputs. 83 | """ 84 | loss = 0 85 | r_losses = [] 86 | g_losses = [] 87 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 88 | r_loss = torch.mean(torch.clamp(1 - dr, min=0)) 89 | g_loss = torch.mean(torch.clamp(1 + dg, min=0)) 90 | loss += r_loss + g_loss 91 | r_losses.append(r_loss.item()) 92 | g_losses.append(g_loss.item()) 93 | 94 | return loss, r_losses, g_losses 95 | 96 | 97 | class FeatureMatchingLoss(nn.Module): 98 | """ 99 | Feature Matching Loss module. 
Calculates the feature matching loss between feature maps of the sub-discriminators. 100 | """ 101 | 102 | def forward(self, fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor: 103 | """ 104 | Args: 105 | fmap_r (List[List[Tensor]]): List of feature maps from real samples. 106 | fmap_g (List[List[Tensor]]): List of feature maps from generated samples. 107 | 108 | Returns: 109 | Tensor: The calculated feature matching loss. 110 | """ 111 | loss = 0 112 | for dr, dg in zip(fmap_r, fmap_g): 113 | for rl, gl in zip(dr, dg): 114 | loss += torch.mean(torch.abs(rl - gl)) 115 | 116 | return loss 117 | 118 | class DACGANLoss(nn.Module): 119 | """ 120 | Computes a discriminator loss, given a discriminator on 121 | generated waveforms/spectrograms compared to ground truth 122 | waveforms/spectrograms. Computes the loss for both the 123 | discriminator and the generator in separate functions. 124 | """ 125 | 126 | def __init__(self, discriminator): 127 | super().__init__() 128 | self.discriminator = discriminator 129 | 130 | def forward(self, fake, real): 131 | # d_fake = self.discriminator(fake.audio_data) 132 | # d_real = self.discriminator(real.audio_data) 133 | d_fake = self.discriminator(fake) 134 | d_real = self.discriminator(real) 135 | return d_fake, d_real 136 | 137 | def discriminator_loss(self, fake, real): 138 | d_fake, d_real = self.forward(fake.clone().detach(), real) 139 | 140 | loss_d = 0 141 | for x_fake, x_real in zip(d_fake, d_real): 142 | loss_d += torch.mean(x_fake[-1] ** 2) 143 | loss_d += torch.mean((1 - x_real[-1]) ** 2) 144 | return loss_d 145 | 146 | def generator_loss(self, fake, real): 147 | d_fake, d_real = self.forward(fake, real) 148 | 149 | loss_g = 0 150 | for x_fake in d_fake: 151 | loss_g += torch.mean((1 - x_fake[-1]) ** 2) 152 | 153 | loss_feature = 0 154 | 155 | for i in range(len(d_fake)): 156 | for j in range(len(d_fake[i]) - 1): 157 | loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) 158 | return loss_g, loss_feature 159 | 160 | -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # flake8: noqa 7 | 8 | """EnCodec neural audio codec.""" 9 | 10 | __version__ = "0.1.2a3" 11 | 12 | from .model import EncodecModel 13 | -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/encoder/distrib.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
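# The helpers below wrap torch.distributed so the same training code runs with or
# without a process group: rank()/world_size() fall back to 0/1 when distributed
# training is not initialized, broadcast_tensors() pushes the rank-0 parameters to
# every worker at startup, and sync_grad()/average_metrics() reduce gradients and
# metric dictionaries across workers.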
6 | 7 | """Torch distributed utilities.""" 8 | 9 | import typing as tp 10 | 11 | import torch 12 | 13 | 14 | def rank(): 15 | if torch.distributed.is_initialized(): 16 | return torch.distributed.get_rank() 17 | else: 18 | return 0 19 | 20 | 21 | def world_size(): 22 | if torch.distributed.is_initialized(): 23 | return torch.distributed.get_world_size() 24 | else: 25 | return 1 26 | 27 | 28 | def is_distributed(): 29 | return world_size() > 1 30 | 31 | 32 | def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): 33 | if is_distributed(): 34 | return torch.distributed.all_reduce(tensor, op) 35 | 36 | 37 | def _is_complex_or_float(tensor): 38 | return torch.is_floating_point(tensor) or torch.is_complex(tensor) 39 | 40 | 41 | def _check_number_of_params(params: tp.List[torch.Tensor]): 42 | # utility function to check that the number of params in all workers is the same, 43 | # and thus avoid a deadlock with distributed all reduce. 44 | if not is_distributed() or not params: 45 | return 46 | tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long) 47 | all_reduce(tensor) 48 | if tensor.item() != len(params) * world_size(): 49 | # If the workers do not all have the same number of params, 50 | # at least one of them will fail this check. 51 | raise RuntimeError(f"Mismatch in number of params: ours is {len(params)}, " 52 | "at least one worker has a different one.") 53 | 54 | 55 | def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0): 56 | """Broadcast the tensors from the given parameters to all workers. 57 | This can be used to ensure that all workers have the same model to start with. 58 | """ 59 | if not is_distributed(): 60 | return 61 | tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)] 62 | _check_number_of_params(tensors) 63 | handles = [] 64 | for tensor in tensors: 65 | handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True) 66 | handles.append(handle) 67 | for handle in handles: 68 | handle.wait() 69 | 70 | 71 | def sync_buffer(buffers, average=True): 72 | """ 73 | Sync buffers across workers. If average is False, broadcast from rank 0 instead of averaging. 74 | """ 75 | if not is_distributed(): 76 | return 77 | handles = [] 78 | for buffer in buffers: 79 | if torch.is_floating_point(buffer.data): 80 | if average: 81 | handle = torch.distributed.all_reduce( 82 | buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True) 83 | else: 84 | handle = torch.distributed.broadcast( 85 | buffer.data, src=0, async_op=True) 86 | handles.append((buffer, handle)) 87 | for buffer, handle in handles: 88 | handle.wait() 89 | if average: 90 | buffer.data /= world_size() 91 | 92 | 93 | def sync_grad(params): 94 | """ 95 | Simpler alternative to DistributedDataParallel, that doesn't rely 96 | on any black magic. For simple models it can also be as fast. 97 | Just call this on your model parameters after the call to backward! 98 | """ 99 | if not is_distributed(): 100 | return 101 | handles = [] 102 | for p in params: 103 | if p.grad is not None: 104 | handle = torch.distributed.all_reduce( 105 | p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True) 106 | handles.append((p, handle)) 107 | for p, handle in handles: 108 | handle.wait() 109 | p.grad.data /= world_size() 110 | 111 | 112 | def average_metrics(metrics: tp.Dict[str, float], count=1.): 113 | """Average a dictionary of metrics across all workers, using the optional 114 | `count` as unnormalized weight.
115 | """ 116 | if not is_distributed(): 117 | return metrics 118 | keys, values = zip(*metrics.items()) 119 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 120 | tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32) 121 | tensor *= count 122 | all_reduce(tensor) 123 | averaged = (tensor[:-1] / tensor[-1]).cpu().tolist() 124 | return dict(zip(keys, averaged)) 125 | -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/encoder/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Torch modules.""" 8 | 9 | # flake8: noqa 10 | from .conv import ( 11 | pad1d, 12 | unpad1d, 13 | NormConv1d, 14 | NormConvTranspose1d, 15 | NormConv2d, 16 | NormConvTranspose2d, 17 | SConv1d, 18 | SConvTranspose1d, 19 | ) 20 | from .lstm import SLSTM 21 | from .seanet import SEANetEncoder, SEANetDecoder 22 | from .transformer import StreamingTransformerEncoder 23 | -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/encoder/modules/lstm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """LSTM layers module.""" 8 | 9 | from torch import nn 10 | 11 | 12 | class SLSTM(nn.Module): 13 | """ 14 | LSTM without worrying about the hidden state, nor the layout of the data. 15 | Expects input as convolutional layout. 16 | """ 17 | def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True): 18 | super().__init__() 19 | self.skip = skip 20 | self.lstm = nn.LSTM(dimension, dimension, num_layers) 21 | 22 | # def forward(self, x): 23 | # x = x.permute(2, 0, 1) 24 | # y, _ = self.lstm(x) 25 | # if self.skip: 26 | # y = y + x 27 | # y = y.permute(1, 2, 0) 28 | # return y 29 | 30 | # Changed the permute order relative to the commented-out version above. 31 | def forward(self, x): 32 | # # Insert a reshape here if needed. 33 | # x = x.reshape(x.shape) 34 | x1 = x.permute(2, 0, 1) 35 | y, _ = self.lstm(x1) 36 | y = y.permute(1, 2, 0) 37 | if self.skip: 38 | y = y + x 39 | return y 40 | -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/encoder/modules/norm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Normalization modules.""" 8 | 9 | import typing as tp 10 | 11 | import einops 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class ConvLayerNorm(nn.LayerNorm): 17 | """ 18 | Convolution-friendly LayerNorm that moves channels to last dimensions 19 | before running the normalization and moves them back to original position right after. 20 | """ 21 | def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs): 22 | super().__init__(normalized_shape, **kwargs) 23 | 24 | def forward(self, x): 25 | x = einops.rearrange(x, 'b ... t -> b t ...') 26 | x = super().forward(x) 27 | x = einops.rearrange(x, 'b t ... -> b ... t') 28 | return x 29 | -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/encoder/modules/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """A streamable transformer.""" 8 | 9 | import typing as tp 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000): 17 | """Create time embedding for the given positions, target dimension `dim`. 18 | """ 19 | # We aim for BTC format 20 | assert dim % 2 == 0 21 | half_dim = dim // 2 22 | adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1) 23 | phase = positions / (max_period ** (adim / (half_dim - 1))) 24 | return torch.cat([ 25 | torch.cos(phase), 26 | torch.sin(phase), 27 | ], dim=-1) 28 | 29 | 30 | class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer): 31 | def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore 32 | if self.norm_first: 33 | sa_input = self.norm1(x) 34 | x = x + self._sa_block(sa_input, x_past, past_context) 35 | x = x + self._ff_block(self.norm2(x)) 36 | else: 37 | sa_input = x 38 | x = self.norm1(x + self._sa_block(sa_input, x_past, past_context)) 39 | x = self.norm2(x + self._ff_block(x)) 40 | 41 | return x, sa_input 42 | 43 | # self-attention block 44 | def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int): # type: ignore 45 | _, T, _ = x.shape 46 | _, H, _ = x_past.shape 47 | 48 | queries = x 49 | keys = torch.cat([x_past, x], dim=1) 50 | values = keys 51 | 52 | queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1) 53 | keys_pos = torch.arange(T + H, device=x.device).view(1, -1) 54 | delta = queries_pos - keys_pos 55 | valid_access = (delta >= 0) & (delta <= past_context) 56 | x = self.self_attn(queries, keys, values, 57 | attn_mask=~valid_access, 58 | need_weights=False)[0] 59 | return self.dropout1(x) 60 | 61 | 62 | class StreamingTransformerEncoder(nn.Module): 63 | """TransformerEncoder with streaming support. 64 | 65 | Args: 66 | dim (int): dimension of the data. 67 | hidden_scale (int): intermediate dimension of FF module is this times the dimension. 68 | num_heads (int): number of heads. 69 | num_layers (int): number of layers. 70 | max_period (float): maximum period of cosines in the positional embedding. 71 | past_context (int or None): receptive field for the causal mask, infinite if None. 72 | gelu (bool): if true, uses GELU activations, otherwise ReLU. 73 | norm_in (bool): normalize the input. 74 | dropout (float): dropout probability. 75 | **kwargs: See `nn.TransformerEncoderLayer`.
76 | """ 77 | def __init__(self, dim, hidden_scale: float = 4., num_heads: int = 8, num_layers: int = 5, 78 | max_period: float = 10000, past_context: int = 1000, gelu: bool = True, 79 | norm_in: bool = True, dropout: float = 0., **kwargs): 80 | super().__init__() 81 | assert dim % num_heads == 0 82 | hidden_dim = int(dim * hidden_scale) 83 | 84 | self.max_period = max_period 85 | self.past_context = past_context 86 | activation: tp.Any = F.gelu if gelu else F.relu 87 | 88 | self.norm_in: nn.Module 89 | if norm_in: 90 | self.norm_in = nn.LayerNorm(dim) 91 | else: 92 | self.norm_in = nn.Identity() 93 | 94 | self.layers = nn.ModuleList() 95 | for idx in range(num_layers): 96 | self.layers.append( 97 | StreamingTransformerEncoderLayer( 98 | dim, num_heads, hidden_dim, 99 | activation=activation, batch_first=True, dropout=dropout, **kwargs)) 100 | 101 | def forward(self, x: torch.Tensor, 102 | states: tp.Optional[tp.List[torch.Tensor]] = None, 103 | offset: tp.Union[int, torch.Tensor] = 0): 104 | B, T, C = x.shape 105 | if states is None: 106 | states = [torch.zeros_like(x[:, :1]) for _ in range(1 + len(self.layers))] 107 | 108 | positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset 109 | pos_emb = create_sin_embedding(positions, C, max_period=self.max_period) 110 | 111 | new_state: tp.List[torch.Tensor] = [] 112 | x = self.norm_in(x) 113 | x = x + pos_emb 114 | 115 | for layer_state, layer in zip(states, self.layers): 116 | x, new_layer_state = layer(x, layer_state, self.past_context) 117 | new_layer_state = torch.cat([layer_state, new_layer_state], dim=1) 118 | new_state.append(new_layer_state[:, -self.past_context:, :]) 119 | return x, new_state, offset + T 120 | -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/encoder/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # flake8: noqa 8 | from .vq import QuantizedResult, ResidualVectorQuantizer 9 | -------------------------------------------------------------------------------- /inspiremusic/wavtokenizer/encoder/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Various utilities.""" 8 | 9 | from hashlib import sha256 10 | from pathlib import Path 11 | import typing as tp 12 | 13 | import torch 14 | import torchaudio 15 | 16 | 17 | def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int): 18 | # Generic overlap add, with linear fade-in/fade-out, supporting complex scenario 19 | # e.g., more than 2 frames per position. 20 | # The core idea is to use a weight function that is a triangle, 21 | # with a maximum value at the middle of the segment. 22 | # We use this weighting when summing the frames, and divide by the sum of weights 23 | # for each positions at the end. Thus: 24 | # - if a frame is the only one to cover a position, the weighting is a no-op. 25 | # - if 2 frames cover a position: 26 | # ... ... 27 | # / \/ \ 28 | # / /\ \ 29 | # S T , i.e. S offset of second frame starts, T end of first frame. 
30 | # Then the weight function for each one is: (t - S), (T - t), with `t` a given offset. 31 | # After the final normalization, the weight of the second frame at position `t` is 32 | # (t - S) / (t - S + (T - t)) = (t - S) / (T - S), which is exactly what we want. 33 | # 34 | # - if more than 2 frames overlap at a given point, we hope that by induction 35 | # something sensible happens. 36 | assert len(frames) 37 | device = frames[0].device 38 | dtype = frames[0].dtype 39 | shape = frames[0].shape[:-1] 40 | total_size = stride * (len(frames) - 1) + frames[-1].shape[-1] 41 | 42 | frame_length = frames[0].shape[-1] 43 | t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1: -1] 44 | weight = 0.5 - (t - 0.5).abs() 45 | 46 | sum_weight = torch.zeros(total_size, device=device, dtype=dtype) 47 | out = torch.zeros(*shape, total_size, device=device, dtype=dtype) 48 | offset: int = 0 49 | 50 | for frame in frames: 51 | frame_length = frame.shape[-1] 52 | out[..., offset:offset + frame_length] += weight[:frame_length] * frame 53 | sum_weight[offset:offset + frame_length] += weight[:frame_length] 54 | offset += stride 55 | assert sum_weight.min() > 0 56 | return out / sum_weight 57 | 58 | 59 | def _get_checkpoint_url(root_url: str, checkpoint: str): 60 | if not root_url.endswith('/'): 61 | root_url += '/' 62 | return root_url + checkpoint 63 | 64 | 65 | def _check_checksum(path: Path, checksum: str): 66 | sha = sha256() 67 | with open(path, 'rb') as file: 68 | while True: 69 | buf = file.read(2**20) 70 | if not buf: 71 | break 72 | sha.update(buf) 73 | actual_checksum = sha.hexdigest()[:len(checksum)] 74 | if actual_checksum != checksum: 75 | raise RuntimeError(f'Invalid checksum for file {path}, ' 76 | f'expected {checksum} but got {actual_checksum}') 77 | 78 | 79 | def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int): 80 | assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions" 81 | assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo." 
82 | *shape, channels, length = wav.shape 83 | if target_channels == 1: 84 | wav = wav.mean(-2, keepdim=True) 85 | elif target_channels == 2: 86 | wav = wav.expand(*shape, target_channels, length) 87 | elif channels == 1: 88 | wav = wav.expand(target_channels, -1) 89 | else: 90 | raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}") 91 | wav = torchaudio.transforms.Resample(sr, target_sr)(wav) 92 | return wav 93 | 94 | 95 | def save_audio(wav: torch.Tensor, path: tp.Union[Path, str], 96 | sample_rate: int, rescale: bool = False): 97 | limit = 0.99 98 | mx = wav.abs().max() 99 | if rescale: 100 | wav = wav * min(limit / mx, 1) 101 | else: 102 | wav = wav.clamp(-limit, limit) 103 | torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16) 104 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | conformer 2 | diffusers 3 | deepspeed 4 | gdown 5 | gradio 6 | grpcio 7 | grpcio-tools 8 | hydra-core==1.3.2 9 | hyperpyyaml 10 | inflect 11 | librosa 12 | lightning 13 | matplotlib 14 | modelscope 15 | networkx 16 | omegaconf==2.3.0 17 | onnx 18 | onnxruntime-gpu 19 | onnxruntime 20 | protobuf 21 | pydantic 22 | rich 23 | soundfile 24 | tensorboard 25 | torch 26 | torchaudio 27 | uvicorn 28 | fastapi 29 | fastapi-cli 30 | WeTextProcessing 31 | transformers==4.46.3 32 | accelerate 33 | huggingface-hub 34 | pystoi 35 | tqdm 36 | pystoi 37 | einops 38 | scipy 39 | accelerate 40 | peft==0.13.2 41 | flash-attn==2.7.4.post1 42 | wget 43 | pyarrow 44 | antlr4-python3-runtime -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc (authors: Chong Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """InspireMusic setup script.""" 17 | 18 | import os 19 | 20 | from setuptools import find_packages, setup 21 | 22 | requirements = { 23 | "install": [ 24 | "setuptools", 25 | "conformer==0.3.2", 26 | "diffusers==0.27.2", 27 | "gdown==5.1.0", 28 | "gradio==5.5.0", 29 | "grpcio==1.57.0", 30 | "grpcio-tools==1.57.0", 31 | "hydra-core==1.3.2", 32 | "HyperPyYAML==1.2.2", 33 | "inflect==7.3.1", 34 | "librosa==0.10.2", 35 | "lightning==2.2.4", 36 | "matplotlib==3.7.5", 37 | "modelscope==1.15.0", 38 | "networkx==3.1", 39 | "omegaconf==2.3.0", 40 | "onnx==1.17.0", 41 | "protobuf==4.25", 42 | "pydantic==2.7.0", 43 | "rich==13.7.1", 44 | "soundfile==0.12.1", 45 | "tensorboard==2.14.0", 46 | "torch==2.0.1", 47 | "torchaudio==2.0.2", 48 | "uvicorn==0.30.0", 49 | "wget==3.2", 50 | "fastapi==0.111.0", 51 | "fastapi-cli==0.0.4", 52 | "WeTextProcessing==1.0.3", 53 | "accelerate", 54 | "huggingface-hub==0.25.2", 55 | "julius", 56 | "onnxruntime-gpu==1.16.0", 57 | "onnxruntime==1.16.0", 58 | "transformers", 59 | ], 60 | # train: The modules invoked when training only. 61 | "train": [ 62 | "deepspeed==0.14.2", 63 | ], 64 | # all: Modules that should only be installed optionally. 65 | # Please consider moving them to "install" when appropriate. 66 | "all": [ 67 | # NOTE(kamo): Append modules requiring specific pytorch version or torch>2.0 68 | "transformers", 69 | "openai-whisper==20231117", 70 | ], 71 | "setup": [ 72 | "numpy", 73 | ], 74 | "test": [ 75 | "pytest>=3.3.0", 76 | ], 77 | } 78 | requirements["all"].extend(requirements["train"]) 79 | requirements["test"].extend(requirements["train"]) 80 | 81 | install_requires = requirements["install"] 82 | setup_requires = requirements["setup"] 83 | tests_require = requirements["test"] 84 | extras_require = {k: v for k, v in requirements.items() if k not in ["install", "setup"]} 85 | 86 | dirname = os.path.dirname(__file__) 87 | version_file = os.path.join(dirname, "inspiremusic", "version.txt") 88 | with open(version_file, "r") as f: 89 | version = f.read().strip() 90 | setup( 91 | name="inspiremusic", 92 | version=version, 93 | url="https://github.com/FunAudioLLM/InspireMusic.git", 94 | author="Tongyi Lab, Alibaba Group", 95 | author_email="chong.zhang@alibaba-inc.com", 96 | description="InspireMusic: A Framework for Music, Audio and Song Generation", 97 | long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(), 98 | long_description_content_type="text/markdown", 99 | license="Apache Software License", 100 | packages=find_packages(include=["inspiremusic*"]), 101 | package_data={"inspiremusic": ["version.txt"]}, 102 | install_requires=install_requires, 103 | setup_requires=setup_requires, 104 | tests_require=tests_require, 105 | extras_require=extras_require, 106 | python_requires=">=3.8.0", 107 | classifiers=[ 108 | "Programming Language :: Python", 109 | "Programming Language :: Python :: 3", 110 | "Programming Language :: Python :: 3.8", 111 | "Development Status :: 5 - Production/Stable", 112 | "Intended Audience :: Science/Research", 113 | "Operating System :: POSIX :: Linux", 114 | "License :: OSI Approved :: Apache Software License", 115 | "Topic :: Software Development :: Libraries :: Python Modules", 116 | ], 117 | entry_points={ 118 | "console_scripts": [ 119 | "inspiremusic = inspiremusic.bin.inference:main", 120 | "inspiremusic-train = inspiremusic.bin.train:main", 121 | ] 122 | }, 123 | ) -------------------------------------------------------------------------------- /tools/extract_acoustic_token.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import argparse 16 | import logging 17 | import torch 18 | from tqdm import tqdm 19 | import numpy as np 20 | import torchaudio 21 | from inspiremusic.utils.audio_utils import normalize, split_wav_into_chunks 22 | from inspiremusic.music_tokenizer.vqvae import VQVAE 23 | import time 24 | 25 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 26 | 27 | def main(args): 28 | audio_min_length = 1.0 29 | audio_max_length = 30.0 30 | max_chunk_size = int(args.sample_rate * audio_max_length) 31 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | utt2wav = {} 33 | with open('{}/wav.scp'.format(args.dir)) as f: 34 | for l in f: 35 | l = l.replace('\n', '').split() 36 | utt2wav[l[0]] = l[1] 37 | 38 | model = VQVAE(args.config_path, args.ckpt_path, with_encoder=True) 39 | model.cuda() 40 | model.eval() 41 | 42 | utt2acoustic_token = {} 43 | start_time = time.time() 44 | for utt in tqdm(utt2wav.keys()): 45 | audio, sample_rate = torchaudio.load(utt2wav[utt]) 46 | if sample_rate != args.sample_rate: 47 | audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=args.sample_rate)(audio) 48 | audio_length = audio.shape[1] 49 | if audio_length > args.sample_rate * audio_min_length: 50 | if audio_length > max_chunk_size: 51 | wav_chunks = split_wav_into_chunks(audio_length, audio, max_chunk_size) 52 | for chunk in wav_chunks: 53 | chunk = torch.tensor(chunk, dtype=torch.float32).to(device) 54 | acoustic_token = model.encode(chunk) 55 | if acoustic_token.is_cuda: 56 | acoustic_token = acoustic_token.cpu() 57 | acoustic_token = acoustic_token.numpy().astype(np.int16) 58 | if utt not in utt2acoustic_token.keys(): 59 | utt2acoustic_token[utt] = acoustic_token 60 | else: 61 | utt2acoustic_token[utt] = np.concatenate((utt2acoustic_token[utt], acoustic_token), axis=1) 62 | else: 63 | audio = torch.tensor(audio, dtype=torch.float32).to(device) 64 | acoustic_token = model.encode(audio) 65 | if acoustic_token.is_cuda: 66 | acoustic_token = acoustic_token.cpu() 67 | acoustic_token = acoustic_token.numpy().astype(np.int16) 68 | utt2acoustic_token[utt] = acoustic_token 69 | else: 70 | logging.warning('This audio length is too short.') 71 | 72 | torch.save(utt2acoustic_token, '{}/utt2acoustic_token.pt'.format(args.dir)) 73 | logging.info('spend time {}'.format(time.time() - start_time)) 74 | 75 | 76 | if __name__ == "__main__": 77 | parser = argparse.ArgumentParser() 78 | parser.add_argument('--dir', 79 | type=str) 80 | parser.add_argument('--config_path', 81 | type=str, default="pretrained_models/InspireMusic-Base/music_tokenizer/config.json") 82 | parser.add_argument('--ckpt_path', 83 | type=str, default="pretrained_models/InspireMusic-Base/music_tokenizer/model.pt") 84 | 
parser.add_argument('--sample_rate', 85 | default=24000, 86 | type=int) 87 | args = parser.parse_args() 88 | 89 | main(args) 90 | -------------------------------------------------------------------------------- /tools/extract_embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import argparse 16 | import torch 17 | import torchaudio 18 | from tqdm import tqdm 19 | import onnxruntime 20 | import torchaudio.compliance.kaldi as kaldi 21 | 22 | 23 | def main(args): 24 | utt2wav, utt2spk = {}, {} 25 | with open('{}/wav.scp'.format(args.dir)) as f: 26 | for l in f: 27 | l = l.replace('\n', '').split() 28 | utt2wav[l[0]] = l[1] 29 | with open('{}/utt2spk'.format(args.dir)) as f: 30 | for l in f: 31 | l = l.replace('\n', '').split() 32 | utt2spk[l[0]] = l[1] 33 | 34 | option = onnxruntime.SessionOptions() 35 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 36 | option.intra_op_num_threads = 1 37 | providers = ["CPUExecutionProvider"] 38 | ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) 39 | 40 | utt2embedding, spk2embedding = {}, {} 41 | for utt in tqdm(utt2wav.keys()): 42 | audio, sample_rate = torchaudio.load(utt2wav[utt]) 43 | if sample_rate != 16000: 44 | audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) 45 | feat = kaldi.fbank(audio, 46 | num_mel_bins=80, 47 | dither=0, 48 | sample_frequency=16000) 49 | feat = feat - feat.mean(dim=0, keepdim=True) 50 | embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() 51 | utt2embedding[utt] = embedding 52 | spk = utt2spk[utt] 53 | if spk not in spk2embedding: 54 | spk2embedding[spk] = [] 55 | spk2embedding[spk].append(embedding) 56 | for k, v in spk2embedding.items(): 57 | spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist() 58 | 59 | torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir)) 60 | torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir)) 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('--dir', 66 | type=str) 67 | parser.add_argument('--onnx_path', 68 | type=str) 69 | args = parser.parse_args() 70 | main(args) 71 | -------------------------------------------------------------------------------- /tools/extract_semantic_token.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import argparse 16 | import logging 17 | import torch 18 | from tqdm import tqdm 19 | import numpy as np 20 | import torchaudio 21 | import time 22 | import os 23 | from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer 24 | from inspiremusic.utils.audio_utils import split_wav_into_chunks 25 | 26 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 27 | def main(args): 28 | audio_min_length = 1.0 29 | audio_max_length = 30.0 30 | max_chunk_size = int(args.sample_rate * audio_max_length) 31 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | utt2wav = {} 33 | with open('{}/wav.scp'.format(args.dir)) as f: 34 | for l in f: 35 | l = l.replace('\n', '').split() 36 | utt2wav[l[0]] = l[1] 37 | 38 | wavtokenizer = WavTokenizer.from_pretrained_feat(args.config_path, args.ckpt_path).to(device) 39 | bandwidth_id = torch.tensor([0]).to(device) 40 | start_time = time.time() 41 | utt2semantic_token = {} 42 | for utt in tqdm(utt2wav.keys()): 43 | audio, sample_rate = torchaudio.load(utt2wav[utt]) 44 | 45 | if sample_rate != args.sample_rate: 46 | audio = torchaudio.functional.resample(audio, orig_freq=sample_rate, new_freq=args.sample_rate) 47 | audio_length = audio.shape[1] 48 | if audio_length > args.sample_rate * audio_min_length: 49 | if audio_length > max_chunk_size: 50 | wav_batch = split_wav_into_chunks(audio_length, audio, max_chunk_size) 51 | for chunk in wav_batch: 52 | chunk = torch.tensor(chunk, dtype=torch.float32).to(device) 53 | _, semantic_token = wavtokenizer.encode_infer(chunk, bandwidth_id=bandwidth_id) 54 | if semantic_token.is_cuda: 55 | semantic_token = semantic_token.cpu() 56 | semantic_token = semantic_token.squeeze(0).numpy().astype(np.int16) 57 | if utt not in utt2semantic_token.keys(): 58 | utt2semantic_token[utt] = semantic_token 59 | else: 60 | utt2semantic_token[utt] = np.concatenate((utt2semantic_token[utt], semantic_token), axis=1) 61 | else: 62 | audio = torch.tensor(audio, dtype=torch.float32).to(device) 63 | _, semantic_token = wavtokenizer.encode_infer(audio, bandwidth_id=bandwidth_id) 64 | if semantic_token.is_cuda: 65 | semantic_token = semantic_token.cpu() 66 | semantic_token = semantic_token.squeeze(0).numpy().astype(np.int16) 67 | utt2semantic_token[utt] = semantic_token 68 | else: 69 | logging.warning('This audio length is too short.') 70 | 71 | torch.save(utt2semantic_token, '{}/utt2semantic_token.pt'.format(args.dir)) 72 | logging.info('spend time {}'.format(time.time() - start_time)) 73 | 74 | 75 | def reconstruct(semantic_token_file, config_path, ckpt_path, outdir, sample_rate=24000): 76 | if not os.path.isdir(outdir): 77 | os.makedirs(outdir, exist_ok=True) 78 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 79 | bandwidth_id = torch.tensor([0]).to(device) 80 | wavtokenizer = WavTokenizer.from_pretrained_feat(config_path, ckpt_path).to(device) 81 | utt2semantic_token = torch.load(semantic_token_file) 82 | for utt in tqdm(utt2semantic_token.keys()): 83 | token = utt2semantic_token[utt] 84 | new_tensor = 
torch.tensor(token).to(device).unsqueeze(0) 85 | features = wavtokenizer.codes_to_features(new_tensor) 86 | wav = wavtokenizer.decode(features, bandwidth_id=bandwidth_id) 87 | wav = wav.cpu().detach() 88 | torchaudio.save(outdir + "/" + utt + ".wav", wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16) 89 | 90 | if __name__ == "__main__": 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument('--dir', 93 | type=str) 94 | parser.add_argument('--config_path', 95 | type=str, default="pretrained_models/InspireMusic-Base/wavtokenizer/config.yaml") 96 | parser.add_argument('--ckpt_path', 97 | type=str, default="pretrained_models/InspireMusic-Base/wavtokenizer/model.pt") 98 | parser.add_argument('--sample_rate', 99 | default=24000, 100 | type=int) 101 | parser.add_argument('--outwavdir', 102 | type=str, default="./exp/wavs") 103 | 104 | args = parser.parse_args() 105 | 106 | main(args) 107 | -------------------------------------------------------------------------------- /tools/extract_speech_token.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import argparse 16 | import logging 17 | import torch 18 | from tqdm import tqdm 19 | import onnxruntime 20 | import numpy as np 21 | import torchaudio 22 | import whisper 23 | 24 | 25 | def main(args): 26 | utt2wav = {} 27 | with open('{}/wav.scp'.format(args.dir)) as f: 28 | for l in f: 29 | l = l.replace('\n', '').split() 30 | utt2wav[l[0]] = l[1] 31 | 32 | option = onnxruntime.SessionOptions() 33 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 34 | option.intra_op_num_threads = 1 35 | providers = ["CUDAExecutionProvider"] 36 | ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) 37 | 38 | utt2speech_token = {} 39 | for utt in tqdm(utt2wav.keys()): 40 | audio, sample_rate = torchaudio.load(utt2wav[utt]) 41 | if sample_rate != 16000: 42 | audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) 43 | if audio.shape[1] / 16000 > 30: 44 | logging.warning('do not support extract speech token for audio longer than 30s') 45 | speech_token = [] 46 | else: 47 | feat = whisper.log_mel_spectrogram(audio, n_mels=128) 48 | speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(), 49 | ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() 50 | utt2speech_token[utt] = speech_token 51 | torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir)) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('--dir', 57 | type=str) 58 | parser.add_argument('--onnx_path', 59 | type=str) 60 | args = parser.parse_args() 61 | main(args) 62 | 
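# --- Editorial usage note (not part of the original sources) ---------------------------------
# The extraction tools above share the same Kaldi-style input convention: a data directory
# containing a `wav.scp` file with one "<utt-id> <wav-path>" pair per line (plus `utt2spk`
# for tools/extract_embedding.py). Assuming the default pretrained model paths under
# ./pretrained_models exist, a typical invocation might look like the following; the data
# directory name and the ONNX model path are illustrative placeholders, not values from the repo:
#
#   python tools/extract_acoustic_token.py --dir data/train
#   python tools/extract_semantic_token.py --dir data/train
#   python tools/extract_speech_token.py   --dir data/train --onnx_path <speech_tokenizer.onnx>
#
# Each script writes a dictionary keyed by utterance id next to the wav.scp, e.g.
# data/train/utt2acoustic_token.pt, which can be inspected with
#   tokens = torch.load('data/train/utt2acoustic_token.pt')
# ----------------------------------------------------------------------------------------------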
--------------------------------------------------------------------------------
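Appendix (editorial): the hinge-style adversarial losses defined in inspiremusic/wavtokenizer/decoder/loss.py are meant to be combined by a training loop that is not part of that file. Below is a minimal, hedged sketch of one such step. The toy generator, the ToyDiscriminator, the loss weights, and the import path are illustrative assumptions, not code from the repository; inside the repo, loss.py itself imports `decoder.modules`, so `inspiremusic/wavtokenizer` may need to be on PYTHONPATH for the import to resolve.

import torch
import torch.nn.functional as F
from torch import nn

# Losses defined in inspiremusic/wavtokenizer/decoder/loss.py (import path is an assumption
# about how the package is made importable in your environment).
from inspiremusic.wavtokenizer.decoder.loss import (
    MelSpecReconstructionLoss,
    GeneratorLoss,
    DiscriminatorLoss,
    FeatureMatchingLoss,
)


class ToyDiscriminator(nn.Module):
    """Stand-in for the real sub-discriminators: returns per-discriminator logits and feature maps."""

    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([nn.Conv1d(1, 8, 15, stride=4), nn.Conv1d(8, 1, 15, stride=4)])

    def forward(self, x):
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), 0.1)
            fmap.append(x)
        # One "sub-discriminator": (list of logits, list of feature-map lists).
        return [x], [fmap]


def adversarial_step(generator, disc, real):
    """Illustrative generator/discriminator loss computation on a batch of waveforms (B, 1, T)."""
    mel_loss = MelSpecReconstructionLoss(sample_rate=24000)
    gen_adv_loss = GeneratorLoss()
    disc_adv_loss = DiscriminatorLoss()
    fm_loss = FeatureMatchingLoss()

    fake = generator(real)

    # Discriminator side: hinge loss on real vs. detached fake outputs.
    d_real, fmap_r = disc(real)
    d_fake_detached, _ = disc(fake.detach())
    loss_d, _, _ = disc_adv_loss(d_real, d_fake_detached)

    # Generator side: hinge loss + feature matching + mel reconstruction (weights are illustrative).
    d_fake, fmap_g = disc(fake)
    loss_adv, _ = gen_adv_loss(d_fake)
    loss_g = loss_adv + 2.0 * fm_loss(fmap_r, fmap_g) + 45.0 * mel_loss(fake.squeeze(1), real.squeeze(1))
    return loss_d, loss_g


if __name__ == "__main__":
    toy_generator = nn.Conv1d(1, 1, 9, padding=4)  # placeholder "generator"
    toy_disc = ToyDiscriminator()
    waveform = torch.randn(2, 1, 24000)
    loss_d, loss_g = adversarial_step(toy_generator, toy_disc, waveform)
    print(loss_d.item(), loss_g.item())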
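Editorial note: encoder/distrib.py describes sync_grad as a lightweight alternative to DistributedDataParallel ("just call this on your model parameters after the call to backward"). The sketch below illustrates that pattern; the import path, the toy model, the optimizer, and the objective are assumptions for illustration. No process group is initialized here, so every helper degrades to a no-op and the script also runs as a plain single-process step.

import torch
from torch import nn

# Distributed helpers from inspiremusic/wavtokenizer/encoder/distrib.py (import path is an
# assumption about how the package is installed).
from inspiremusic.wavtokenizer.encoder import distrib


def train_step(model, optimizer, batch):
    optimizer.zero_grad()
    loss = model(batch).pow(2).mean()  # placeholder objective
    loss.backward()
    # Average gradients across workers; a no-op when torch.distributed is not initialized.
    distrib.sync_grad(model.parameters())
    optimizer.step()
    return loss.detach()


if __name__ == "__main__":
    model = nn.Linear(16, 1)
    # Typically done once at start-up so every worker begins from the same weights.
    distrib.broadcast_tensors(model.parameters())
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(3):
        print(train_step(model, optimizer, torch.randn(8, 16)).item())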