├── vall_e ├── __init__.py ├── emb │ ├── __init__.py │ ├── codecs │ │ ├── __init__.py │ │ ├── vocos.py │ │ ├── encodec.py │ │ └── dac.py │ └── g2p.py ├── utils │ ├── ext │ │ ├── __init__.py │ │ ├── unsloth.py │ │ └── muon.py │ ├── __init__.py │ ├── distributed.py │ ├── io.py │ ├── sampler.py │ └── ml.py ├── metrics.py ├── models │ ├── arch │ │ └── __init__.py │ ├── __init__.py │ └── lora.py ├── plot.py ├── engines │ └── deepspeed.py └── __main__.py ├── test.wav ├── vall-e.png ├── data ├── qnt.dac ├── qnt.enc ├── qnt.nem ├── noise.enc ├── demo │ └── index.template.html └── tongue_twisters.txt ├── scripts ├── run.sh ├── setup.sh ├── process_nscripter.py ├── prepare_librilight.py ├── deduplicate_librilight_libritts.py ├── process_seed-tts.py ├── parse_ppp.py ├── cleanup_dataset.py ├── train_tokenizer.py ├── process_emilia.py └── process_libritts.py ├── docs ├── demo.md ├── plot.md ├── export.md ├── metrics.md ├── engines.md ├── utils.md ├── webui.md ├── inferenece.md └── samplers.md ├── .gitignore ├── vall_e.cpp ├── include │ ├── ggml-blas.h │ ├── ggml-opencl.h │ ├── ops.h │ ├── llama-cpp.h │ ├── ggml-vulkan.h │ ├── ggml-rpc.h │ ├── utils.h │ ├── ggml-kompute.h │ ├── ggml-cuda.h │ ├── ggml-cpp.h │ ├── ggml-sycl.h │ ├── llama-impl.h │ ├── ggml-metal.h │ ├── lstm.h │ ├── ggml-alloc.h │ ├── encoder.h │ ├── decoder.h │ ├── espeak-ng │ │ ├── encoding.h │ │ └── espeak_ng.h │ ├── quantizer.h │ ├── llama-vocab.h │ ├── ggml-cann.h │ ├── ggml-cpu.h │ └── encodec.h ├── Makefile ├── README.md ├── vall_e-impl.h └── vall_e.h ├── README.md └── setup.py /vall_e/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vall_e/emb/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vall_e/utils/ext/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vall_e/emb/codecs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vall_e/emb/codecs/vocos.py: -------------------------------------------------------------------------------- 1 | from vocos import Vocos -------------------------------------------------------------------------------- /test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/e-c-k-e-r/vall-e/HEAD/test.wav -------------------------------------------------------------------------------- /vall-e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/e-c-k-e-r/vall-e/HEAD/vall-e.png -------------------------------------------------------------------------------- /data/qnt.dac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/e-c-k-e-r/vall-e/HEAD/data/qnt.dac -------------------------------------------------------------------------------- /data/qnt.enc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/e-c-k-e-r/vall-e/HEAD/data/qnt.enc -------------------------------------------------------------------------------- /data/qnt.nem: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/e-c-k-e-r/vall-e/HEAD/data/qnt.nem -------------------------------------------------------------------------------- /data/noise.enc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/e-c-k-e-r/vall-e/HEAD/data/noise.enc -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | until $@; do echo retrying; done 4 | -------------------------------------------------------------------------------- /vall_e/emb/codecs/encodec.py: -------------------------------------------------------------------------------- 1 | from encodec import EncodecModel 2 | from encodec.utils import convert_audio -------------------------------------------------------------------------------- /docs/demo.md: -------------------------------------------------------------------------------- 1 | # `demo.py` 2 | 3 | This script handles generating demo pages for comparing against other TTS solutions. 4 | 5 | As this is for my own internal use, documentation at the moment is sparing. 6 | 7 | To be filled. -------------------------------------------------------------------------------- /data/demo/index.template.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 |

VALL-E Demo 7 | ${PREAMBLE} 8 | ${TABLES} 9 | Settings used: ${SETTINGS}

10 | 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | /data 3 | /training 4 | /venv 5 | /*.egg-info 6 | /vall_e/version.py 7 | /.cache 8 | /voices 9 | /wandb 10 | /.nltk 11 | /vall_e.cpp/data 12 | /vall_e.cpp/include 13 | /vall_e.cpp/lib 14 | /vall_e.cpp/*.o 15 | /vall_e.cpp/vall_e 16 | -------------------------------------------------------------------------------- /docs/plot.md: -------------------------------------------------------------------------------- 1 | # `plot.py` 2 | 3 | Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot --yaml="./training/config.yaml"` 4 | 5 | You can specify what X and Y labels you want to plot against by passing `--xs tokens_processed --ys loss.nll stats.acc` -------------------------------------------------------------------------------- /vall_e/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import ( 2 | dispatch_attribute, 3 | flatten_dict, 4 | gather_attribute, 5 | load_state_dict_non_strict, 6 | setup_logging, 7 | to_device, 8 | tree_map, 9 | do_gc, 10 | set_seed, 11 | passes_policy, 12 | get_devices, 13 | truncate_json, 14 | timer, 15 | prune_missing, 16 | clamp, 17 | md5_hash, 18 | convert_kwargs, 19 | coerce_dtype, 20 | mean, 21 | logit_normalization, 22 | ) -------------------------------------------------------------------------------- /scripts/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 -m venv venv 4 | source ./venv/bin/activate 5 | pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118 / cu124 6 | pip3 install -e . 7 | 8 | mkdir -p ./training/valle/ckpt/ar+nar-llama-8/ 9 | wget -P ./training/valle/ckpt/ar+nar-llama-8/ "https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.sft" 10 | wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/resolve/main/models/config.llama.yaml" 11 | -------------------------------------------------------------------------------- /docs/export.md: -------------------------------------------------------------------------------- 1 | # `export.py` 2 | 3 | To export the models, run: `python -m vall_e.export --yaml=./training/config.yaml`. 4 | 5 | This will export the latest checkpoints, for example, under `./training/ckpt/ar+nar-retnet-8/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used, and training stats. 6 | 7 | Desite being called `fp32.sft` or `fp32.pth`, you can export it to a different precision type with `--dtype=float16|bfloat16|float32`. 8 | 9 | You can also export to `safetensors` with `--format=sft`, and `fp32.sft` will be exported instead. 
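As a concrete example combining the flags above: `python3 -m vall_e.export --yaml=./training/config.yaml --dtype=bfloat16 --format=sft` should write the checkpoint as a bfloat16 `safetensors` file (still named `fp32.sft`, per the naming note above) under the model's checkpoint directory, e.g. `./training/ckpt/ar+nar-retnet-8/`.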
-------------------------------------------------------------------------------- /vall_e.cpp/include/ggml-blas.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | #include "ggml-backend.h" 5 | 6 | 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | 11 | // backend API 12 | GGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void); 13 | 14 | GGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend); 15 | 16 | // number of threads used for conversion to float 17 | // for openblas and blis, this will also set the number of threads used for blas operations 18 | GGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads); 19 | 20 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void); 21 | 22 | 23 | #ifdef __cplusplus 24 | } 25 | #endif 26 | -------------------------------------------------------------------------------- /vall_e.cpp/include/ggml-opencl.h: -------------------------------------------------------------------------------- 1 | #ifndef GGML_OPENCL_H 2 | #define GGML_OPENCL_H 3 | 4 | #include "ggml.h" 5 | #include "ggml-backend.h" 6 | 7 | #ifdef __cplusplus 8 | extern "C" { 9 | #endif 10 | 11 | // 12 | // backend API 13 | // 14 | GGML_BACKEND_API ggml_backend_t ggml_backend_opencl_init(void); 15 | GGML_BACKEND_API bool ggml_backend_is_opencl(ggml_backend_t backend); 16 | 17 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void); 18 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void); 19 | 20 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_opencl_reg(void); 21 | 22 | #ifdef __cplusplus 23 | } 24 | #endif 25 | 26 | #endif // GGML_OPENCL_H 27 | -------------------------------------------------------------------------------- /vall_e.cpp/include/ops.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | 5 | struct ggml_tensor *pad_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, 6 | int padding_left, int padding_right); 7 | 8 | struct ggml_tensor *unpad_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, 9 | int padding_left, int padding_right); 10 | 11 | struct ggml_tensor *strided_conv_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, 12 | struct ggml_tensor *conv_w, struct ggml_tensor *conv_b, 13 | int stride); 14 | 15 | struct ggml_tensor *strided_conv_transpose_1d(struct ggml_context *ctx0, struct ggml_tensor *inp, 16 | struct ggml_tensor *conv_w, struct ggml_tensor *conv_b, 17 | int stride); 18 | -------------------------------------------------------------------------------- /vall_e.cpp/include/llama-cpp.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifndef __cplusplus 4 | #error "This header is for C++ only" 5 | #endif 6 | 7 | #include 8 | 9 | #include "llama.h" 10 | 11 | struct llama_model_deleter { 12 | void operator()(llama_model * model) { llama_model_free(model); } 13 | }; 14 | 15 | struct llama_context_deleter { 16 | void operator()(llama_context * context) { llama_free(context); } 17 | }; 18 | 19 | struct llama_sampler_deleter { 20 | void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); } 21 | }; 22 | 23 | struct llama_adapter_lora_deleter { 24 | void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } 25 | }; 26 | 27 | typedef std::unique_ptr llama_model_ptr; 28 | 
typedef std::unique_ptr llama_context_ptr; 29 | typedef std::unique_ptr llama_sampler_ptr; 30 | typedef std::unique_ptr llama_adapter_lora_ptr; 31 | -------------------------------------------------------------------------------- /vall_e.cpp/Makefile: -------------------------------------------------------------------------------- 1 | ifeq ($(PREFIX),) 2 | PREFIX := /usr/local 3 | endif 4 | 5 | CXX = g++ 6 | 7 | INCS += -I./include 8 | LIBS += -L./lib 9 | 10 | LINKS += -lggml -lggml-base -lllama -lencodec -lespeak-ng 11 | FLAGS += -march=native -O3 -DVALL_E_EXPORTS 12 | 13 | SRCS := $(shell find ./ -name "*.cpp") 14 | OBJS += $(patsubst %.cpp,%.o,$(SRCS)) 15 | 16 | TARGET = vall_e 17 | TARGET_LIB = lib$(TARGET).so 18 | TARGET_HEADER = $(TARGET).h 19 | 20 | 21 | %.o: %.cpp 22 | $(CXX) $(FLAGS) $(INCS) -c $< -o $@ 23 | 24 | $(TARGET): $(OBJS) 25 | $(CXX) $(FLAGS) $(OBJS) $(LIBS) $(INCS) $(LINKS) -o $(TARGET) 26 | 27 | $(TARGET_LIB): $(OBJS) 28 | $(CXX) $(FLAGS) $(OBJS) $(LIBS) $(INCS) $(LINKS) -o $(TARGET_LIB) 29 | 30 | all: $(TARGET_LIB) $(TARGET) 31 | 32 | lib: $(TARGET_LIB) 33 | 34 | install: 35 | cp $(TARGET) $(PREFIX)/bin/$(TARGET) 36 | -cp $(TARGET_LIB) $(PREFIX)/lib/$(TARGET_LIB) 37 | cp $(TARGET_HEADER) $(PREFIX)/include/$(TARGET_HEADER) 38 | 39 | clean: 40 | @-rm -f $(OBJS) 41 | @-rm -f $(TARGET) 42 | @-rm -f $(TARGET_LIB) -------------------------------------------------------------------------------- /vall_e.cpp/include/ggml-vulkan.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | #include "ggml-backend.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | #define GGML_VK_NAME "Vulkan" 11 | #define GGML_VK_MAX_DEVICES 16 12 | 13 | // backend API 14 | GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num); 15 | 16 | GGML_BACKEND_API bool ggml_backend_is_vk(ggml_backend_t backend); 17 | GGML_BACKEND_API int ggml_backend_vk_get_device_count(void); 18 | GGML_BACKEND_API void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size); 19 | GGML_BACKEND_API void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total); 20 | 21 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num); 22 | // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU 23 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void); 24 | 25 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_vk_reg(void); 26 | 27 | #ifdef __cplusplus 28 | } 29 | #endif 30 | -------------------------------------------------------------------------------- /vall_e.cpp/include/ggml-rpc.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | #include "ggml-backend.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | #define GGML_RPC_MAX_SERVERS 16 11 | 12 | // backend API 13 | GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint); 14 | GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend); 15 | 16 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint); 17 | 18 | GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total); 19 | 20 | GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, 21 | const char * cache_dir, 22 | size_t 
free_mem, size_t total_mem); 23 | 24 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); 25 | 26 | GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint); 27 | 28 | #ifdef __cplusplus 29 | } 30 | #endif 31 | -------------------------------------------------------------------------------- /vall_e.cpp/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #define MAX(a, b) ((a) > (b) ? (a) : (b)) 6 | #define MIN(a, b) ((a) < (b) ? (a) : (b)) 7 | 8 | const size_t MB = 1024 * 1024; 9 | 10 | template 11 | void read_safe(std::ifstream &infile, T &dest) { 12 | infile.read((char *)&dest, sizeof(T)); 13 | } 14 | 15 | int32_t get_num_codebooks(float bandwidth, int hop_length, float sample_rate) { 16 | // The number of codebooks is determined by the bandwidth selected. 17 | // Supported bandwidths are 1.5kbps (n_q = 2), 3 kbps (n_q = 4), 6 kbps (n_q = 8), 18 | // 12 kbps (n_q = 16) and 24kbps (n_q = 32). 19 | return (int32_t)ceilf(1000 * bandwidth / (ceilf(sample_rate / hop_length) * 10)); 20 | } 21 | 22 | int32_t get_bandwidth_per_quantizer(int bins, float frame_rate) { 23 | return log2f((float)bins) * frame_rate; 24 | } 25 | 26 | int32_t get_num_quantizers_for_bandwidth(int bins, float frame_rate, float bandwidth) { 27 | float bw_per_q = get_bandwidth_per_quantizer(bins, frame_rate); 28 | int32_t n_q = MAX(1, floorf(bandwidth * 1000 / bw_per_q)); 29 | return n_q; 30 | } 31 | -------------------------------------------------------------------------------- /scripts/process_nscripter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Handles processing NScripter's 0.u file to clean up the pile of audio clips it has 3 | 4 | * to-do: also grab transcriptions 5 | """ 6 | 7 | import os 8 | import re 9 | import json 10 | import argparse 11 | import torch 12 | import shutil 13 | import torchaudio 14 | import numpy as np 15 | 16 | from tqdm.auto import tqdm 17 | from pathlib import Path 18 | 19 | def process( 20 | input_file=Path("./assets/0.u"), 21 | wav_dir=Path("./arc/"), 22 | output_dir=Path("./dataset/"), 23 | ): 24 | file = open(input_file, encoding='utf-8').read() 25 | 26 | names = {} 27 | aliases = {} 28 | lines = file.split('\n') 29 | 30 | for line in lines: 31 | if not line.startswith('stralias'): 32 | continue 33 | # ick 34 | try: 35 | key, path = re.findall(r'^stralias (.+?),"(.+?)"$', line)[0] 36 | name = key.split("_")[0] 37 | if name not in names: 38 | (output_dir / name).mkdir(parents=True, exist_ok=True) 39 | names[name] = True 40 | 41 | aliases[key] = Path(path) 42 | except Exception as e: 43 | pass 44 | 45 | for k, v in aliases.items(): 46 | name = k.split("_")[0] 47 | 48 | 49 | print(aliases) 50 | 51 | if __name__ == "__main__": 52 | process() -------------------------------------------------------------------------------- /vall_e.cpp/include/ggml-kompute.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | #include "ggml-backend.h" 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | #ifdef __cplusplus 11 | extern "C" { 12 | #endif 13 | 14 | #define GGML_KOMPUTE_MAX_DEVICES 16 15 | 16 | struct ggml_vk_device { 17 | int index; 18 | int type; // same as VkPhysicalDeviceType 19 | size_t heapSize; 20 | const char * name; 21 | const char * vendor; 22 | int subgroupSize; 23 | uint64_t bufferAlignment; 24 | uint64_t maxAlloc; 25 | }; 26 | 27 | 
struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count); 28 | bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name); 29 | bool ggml_vk_has_vulkan(void); 30 | bool ggml_vk_has_device(void); 31 | struct ggml_vk_device ggml_vk_current_device(void); 32 | 33 | // 34 | // backend API 35 | // 36 | 37 | // forward declaration 38 | typedef struct ggml_backend * ggml_backend_t; 39 | 40 | GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device); 41 | 42 | GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend); 43 | 44 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device); 45 | 46 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void); 47 | 48 | #ifdef __cplusplus 49 | } 50 | #endif 51 | -------------------------------------------------------------------------------- /scripts/prepare_librilight.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Handles processing `facebookresearch/libri-light`'s unlabeled audio into a friendlier hierarchy 3 | """ 4 | 5 | import os 6 | import json 7 | 8 | datasets = ["small", "medium", "large", "duplicate"] 9 | output_dataset = "LibriLight-4K" 10 | 11 | for input_dataset in datasets: 12 | if not os.path.isdir(f'./{input_dataset}/'): 13 | continue 14 | 15 | for speaker_id in os.listdir(f'./{input_dataset}/'): 16 | if not os.path.isdir(f'./{input_dataset}/{speaker_id}/'): 17 | continue 18 | 19 | for book_name in os.listdir(f'./{input_dataset}/{speaker_id}/'): 20 | subid = 0 21 | 22 | for filename in os.listdir(f'./{input_dataset}/{speaker_id}/{book_name}'): 23 | if filename[-5:] != ".json": 24 | continue 25 | 26 | basename = filename[:-5] 27 | 28 | json_path = f'./{input_dataset}/{speaker_id}/{book_name}/{basename}.json' 29 | flac_path = f'./{input_dataset}/{speaker_id}/{book_name}/{basename}.flac' 30 | 31 | j = json.load(open(json_path, 'r', encoding="utf-8")) 32 | id = j['book_meta']['id'] 33 | 34 | json_id_path = f'./{output_dataset}/{speaker_id}/{speaker_id}_{id}_{subid}.json' 35 | flac_id_path = f'./{output_dataset}/{speaker_id}/{speaker_id}_{id}_{subid}.flac' 36 | 37 | os.makedirs(f'./{output_dataset}/{speaker_id}/', exist_ok=True) 38 | os.rename(json_path, json_id_path) 39 | os.rename(flac_path, flac_id_path) 40 | 41 | subid += 1 42 | -------------------------------------------------------------------------------- /data/tongue_twisters.txt: -------------------------------------------------------------------------------- 1 | Six sick hicks nick six slick bricks with picks and sticks. 2 | Fresh French fried fly fritters. 3 | Rory the warrior and Roger the worrier were reared wrongly in a rural brewery. 4 | Which wrist watches are Swiss wrist watches? 5 | Fred fed Ted bread and Ted fed Fred bread. 6 | The 33 thieves thought that they thrilled the throne throughout Thursday. 7 | You know New York, you need New York, you know you need unique New York. 8 | Lesser leather never weathered wetter weather better. 9 | The sixth sick sheikh’s sixth sheep’s sick. 10 | A skunk sat on a stump and thunk the stump stunk, but the stump thunk the skunk stunk. 11 | Thirty-three thirsty, thundering thoroughbreds thumped Mr. Thurber on Thursday. 12 | Wayne went to wales to watch walruses. 13 | Seventy-seven benevolent elephants. 14 | Send toast to ten tense stout saints’ ten tall tents. 15 | I slit the sheet, the sheet I slit, and on the slitted sheet I sit. 
16 | Give papa a cup of proper coffee in a copper coffee cup. 17 | She sells seashells by the seashore. 18 | Peter Piper picked a peck of pickled peppers. How many pickled peppers did Peter Piper pick? 19 | Pad kid poured curd pulled cod. 20 | Fuzzy Wuzzy was a bear. Fuzzy Wuzzy had no hair. Fuzzy Wuzzy wasn’t very fuzzy, was he? 21 | Supercalifragilisticexpialidocious. 22 | How much wood would a woodchuck chuck if a woodchuck could chuck wood? He would chuck, he would, as much as he could, and chuck as much wood as a woodchuck would if a woodchuck could chuck wood. 23 | Buffalo buffalo Buffalo buffalo buffalo buffalo Buffalo buffalo. -------------------------------------------------------------------------------- /docs/metrics.md: -------------------------------------------------------------------------------- 1 | # `metrics.py` 2 | 3 | This file provides helper functions for computing objective metrics, such as word-error rate (WER), character-error rate (CER), phoneme-error rate (PER), and speaker similarity (SIM-O). 4 | 5 | ## WER / CER 6 | 7 | Word-error rate (WER) is simply computed by transcribing the requested input, and comparing its transcription against the target transcription. 8 | * The transcription is cleaned up and normalized to account for inconsistencies between transcriptions with `openai/whisper-large-v3` with the nuances of English. 9 | * Languages without spaces between words (Chinese, Japanese) should not rely on this, and instead rely on the CER. 10 | 11 | Character-error rate (CER) does the same thing as WER, but on a character basis rather than a word basis. 12 | 13 | Phoneme-error rate (PER) does the same thing as CER, but on the phonemized transcription instead. As this is a speech model, this metric is more correct than the prior metrics, but this isn't a universal metric for comparison, as most models don't report this. 14 | 15 | All rates are un-normalized because I think that's the right way to go about it? Papers aren't clear that they do this, but the error rates are even more unusually low without this. 16 | 17 | ## SIM-O 18 | 19 | Speaker similarity (SIM-O) is computed by obtaining the embedding of each speaker (the output audio and the input prompt), and computing the cosine similarity between those two embeddings. 20 | 21 | These embeddings are obtained through a finetune of WavLM-large geared towards speaker verification. 
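A minimal sketch of the calculations described above, mirroring the logic in `vall_e/metrics.py` (transcription and speaker-embedding extraction are assumed to happen upstream):

```python
import torch.nn.functional as F
from torcheval.metrics.functional import word_error_rate
from torchmetrics.functional.text import char_error_rate

def error_rates(transcription: str, reference: str) -> tuple[float, float]:
    # the library metrics are normalized by reference length, so multiply
    # that length back in to report the un-normalized scores described above
    wer = word_error_rate([transcription], [reference]).item() * len(reference.split())
    cer = char_error_rate([transcription], [reference]).item() * len(reference)
    return wer, cer

def sim_o(audio_emb, reference_emb) -> float:
    # cosine similarity between the output audio's and the prompt's speaker embeddings
    return F.cosine_similarity(audio_emb, reference_emb, dim=-1).item()
```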
-------------------------------------------------------------------------------- /scripts/deduplicate_librilight_libritts.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Helper script to try and detect any duplications between LibriLight and LibriTTS (I don't think there were any) 3 | """ 4 | 5 | import os 6 | import json 7 | 8 | librilight_dir = "LibriLight-6K" 9 | libritts_dir = "LibriTTS-Train" 10 | 11 | librilight_data = {} 12 | libritts_data = {} 13 | 14 | for speaker_id in os.listdir(f'./{librilight_dir}/'): 15 | for filename in os.listdir(f'./{librilight_dir}/{speaker_id}'): 16 | parts = filename.split("_") 17 | book_id = parts[1] 18 | subid = parts[2] 19 | 20 | if speaker_id not in librilight_data: 21 | librilight_data[speaker_id] = {} 22 | if book_id not in librilight_data[speaker_id]: 23 | librilight_data[speaker_id][book_id] = [] 24 | librilight_data[speaker_id][book_id].append(subid) 25 | 26 | for speaker_id in os.listdir(f'./{libritts_dir}/'): 27 | for filename in os.listdir(f'./{libritts_dir}/{speaker_id}'): 28 | parts = filename.split("_") 29 | book_id = parts[1] 30 | subid = parts[2] 31 | 32 | if speaker_id not in libritts_data: 33 | libritts_data[speaker_id] = {} 34 | if book_id not in libritts_data[speaker_id]: 35 | libritts_data[speaker_id][book_id] = [] 36 | libritts_data[speaker_id][book_id].append(subid) 37 | 38 | duplicates = [] 39 | 40 | for speaker_id, books in libritts_data.items(): 41 | if speaker_id not in librilight_data: 42 | continue 43 | for book_id, _ in books.items(): 44 | if book_id not in librilight_data[speaker_id]: 45 | continue 46 | print(f'Duplicate: {speaker_id}/{book_id}') 47 | duplicates.append(f'{speaker_id}/{book_id}') 48 | 49 | print("Duplicates:", duplicates) -------------------------------------------------------------------------------- /docs/engines.md: -------------------------------------------------------------------------------- 1 | # `engines/*` 2 | 3 | This folder contains the necessary abstractions for handling training of models through either a local (`base`) backend, or additional wrappers (like DeepSpeed, and in the future Accelerate and Lightning). 4 | 5 | This architecture is partially lifted from the original implementation, but expanded for both my needs and modularity for other backends. 6 | 7 | An `Engine` is just a wrapper that contains training metadata for the loaded module. 8 | 9 | An `Engines` is a dict of `Engine`s, and extra functions to allow iterating through its contents, allowing for simultaneous loading and training of engines for a shared dataloader iteration. 10 | 11 | ## `__init__.py` 12 | 13 | This script handles the bulk of loading a model and wrapping the model with the requested engine type. 14 | 15 | The checkpoint or weight path is automatically deduced, as well as pre-processing the state dict (if requested) before loading it. 16 | * resizing modules from the weights to the requested configuration in the YAML is done here. 17 | * replacing modules with quantized versions or LoRAs are applied here. 18 | * the requested optimizer, and params to freeze, for a model is applied here. 19 | 20 | ## `base.py` 21 | 22 | The internal (`local`) implementation of orchestrating training. The basics are handled here, from automatic-mixed-precision, gradient accumulation, loss scaling, etc. 23 | 24 | Functions for other backends are also defined here, such as the training step function. 
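As a purely illustrative sketch of the relationship described above (the names and signatures here are assumptions for explanation, not the actual classes in this repo), an `Engines` can be pictured as a dict of `Engine`s all stepped against the same batch:

```python
# illustrative only -- not the actual API of this repo
from torch import nn, optim

class Engine:
    """Wraps a module together with its training metadata."""
    def __init__(self, module: nn.Module, optimizer: optim.Optimizer):
        self.module = module
        self.optimizer = optimizer

class Engines(dict):
    """A dict of Engine instances trained against a shared dataloader iteration."""
    def step(self, batch) -> dict:
        losses = {}
        for name, engine in self.items():
            loss = engine.module(**batch)   # assumed: the module returns a scalar loss
            loss.backward()
            engine.optimizer.step()
            engine.optimizer.zero_grad()
            losses[name] = loss.item()
        return losses
```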
25 | 26 | ## `deepspeed.py` 27 | 28 | A backend relying on `deepspeed` for its orchestration, which offers additional features that can be defined under `cfg.trainer.deepspeed`. -------------------------------------------------------------------------------- /scripts/process_seed-tts.py: -------------------------------------------------------------------------------- 1 | """ 2 | Handles processing seed-tts-eval's dataset into something to be used for vall_e.demo 3 | 4 | Reads from meta.lst, a text file where each utterance is formatted as: 5 | 6 | ||| 7 | """ 8 | 9 | import os 10 | import json 11 | import argparse 12 | import torch 13 | import shutil 14 | import torchaudio 15 | import numpy as np 16 | 17 | from tqdm.auto import tqdm 18 | from pathlib import Path 19 | 20 | def process( 21 | input_dir=Path("./seedtts_testset/zh/"), 22 | list_name="./hardcase.lst", 23 | wav_dir="./wavs/", 24 | output_dir=Path("./dataset/seed-tts-eval-hard/"), 25 | ): 26 | language = "auto" 27 | 28 | if "en" in str(input_dir): 29 | language = "en" 30 | elif "zh" in str(input_dir): 31 | language = "zh" 32 | 33 | output_dir.mkdir(parents=True, exist_ok=True) 34 | 35 | # read manifest 36 | lines = open(input_dir / list_name).read() 37 | lines = lines.split("\n") 38 | # split it even further 39 | for line in lines: 40 | if not line: 41 | continue 42 | filename, prompt_text, prompt_wav, text = line.split("|") 43 | 44 | (output_dir / filename / "out").mkdir(parents=True, exist_ok=True) 45 | 46 | open( output_dir / filename / "prompt.txt", "w", encoding="utf-8" ).write( text ) 47 | open( output_dir / filename / "language.txt", "w", encoding="utf-8" ).write( language ) 48 | 49 | reference_wav = (input_dir / wav_dir / filename).with_suffix(".wav") 50 | if not reference_wav.exists(): 51 | continue 52 | 53 | shutil.copy(reference_wav, output_dir / filename / "reference.wav" ) 54 | shutil.copy(input_dir / prompt_wav, output_dir / filename / "prompt.wav" ) 55 | 56 | if __name__ == "__main__": 57 | process() -------------------------------------------------------------------------------- /vall_e.cpp/include/ggml-cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | #include "ggml-backend.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | #ifdef GGML_USE_HIP 11 | #define GGML_CUDA_NAME "ROCm" 12 | #define GGML_CUBLAS_NAME "hipBLAS" 13 | #elif defined(GGML_USE_MUSA) 14 | #define GGML_CUDA_NAME "MUSA" 15 | #define GGML_CUBLAS_NAME "muBLAS" 16 | #else 17 | #define GGML_CUDA_NAME "CUDA" 18 | #define GGML_CUBLAS_NAME "cuBLAS" 19 | #endif 20 | #define GGML_CUDA_MAX_DEVICES 16 21 | 22 | // backend API 23 | GGML_BACKEND_API ggml_backend_t ggml_backend_cuda_init(int device); 24 | 25 | GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend); 26 | 27 | // device buffer 28 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device); 29 | 30 | // split tensor buffer that splits matrices by rows across multiple devices 31 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split); 32 | 33 | // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU 34 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void); 35 | 36 | GGML_BACKEND_API int ggml_backend_cuda_get_device_count(void); 37 | GGML_BACKEND_API void ggml_backend_cuda_get_device_description(int device, char * description, 
size_t description_size); 38 | GGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total); 39 | 40 | GGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size); 41 | GGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer); 42 | 43 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void); 44 | 45 | #ifdef __cplusplus 46 | } 47 | #endif 48 | -------------------------------------------------------------------------------- /vall_e.cpp/include/ggml-cpp.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifndef __cplusplus 4 | #error "This header is for C++ only" 5 | #endif 6 | 7 | #include "ggml.h" 8 | #include "ggml-alloc.h" 9 | #include "ggml-backend.h" 10 | #include "gguf.h" 11 | #include 12 | 13 | // Smart pointers for ggml types 14 | 15 | // ggml 16 | 17 | struct ggml_context_deleter { void operator()(ggml_context * ctx) { ggml_free(ctx); } }; 18 | struct gguf_context_deleter { void operator()(gguf_context * ctx) { gguf_free(ctx); } }; 19 | 20 | typedef std::unique_ptr ggml_context_ptr; 21 | typedef std::unique_ptr gguf_context_ptr; 22 | 23 | // ggml-alloc 24 | 25 | struct ggml_gallocr_deleter { void operator()(ggml_gallocr_t galloc) { ggml_gallocr_free(galloc); } }; 26 | 27 | typedef std::unique_ptr ggml_gallocr_ptr; 28 | 29 | // ggml-backend 30 | 31 | struct ggml_backend_deleter { void operator()(ggml_backend_t backend) { ggml_backend_free(backend); } }; 32 | struct ggml_backend_buffer_deleter { void operator()(ggml_backend_buffer_t buffer) { ggml_backend_buffer_free(buffer); } }; 33 | struct ggml_backend_event_deleter { void operator()(ggml_backend_event_t event) { ggml_backend_event_free(event); } }; 34 | struct ggml_backend_sched_deleter { void operator()(ggml_backend_sched_t sched) { ggml_backend_sched_free(sched); } }; 35 | 36 | typedef std::unique_ptr ggml_backend_ptr; 37 | typedef std::unique_ptr ggml_backend_buffer_ptr; 38 | typedef std::unique_ptr ggml_backend_event_ptr; 39 | typedef std::unique_ptr ggml_backend_sched_ptr; 40 | -------------------------------------------------------------------------------- /vall_e.cpp/include/ggml-sycl.h: -------------------------------------------------------------------------------- 1 | // 2 | // MIT license 3 | // Copyright (C) 2024 Intel Corporation 4 | // SPDX-License-Identifier: MIT 5 | // 6 | 7 | #pragma once 8 | 9 | #include "ggml.h" 10 | #include "ggml-backend.h" 11 | 12 | #define GGML_SYCL_NAME "SYCL" 13 | #define GGML_SYCL_MAX_DEVICES 48 14 | 15 | #ifdef __cplusplus 16 | extern "C" { 17 | #endif 18 | 19 | // backend API 20 | GGML_BACKEND_API ggml_backend_t ggml_backend_sycl_init(int device); 21 | 22 | GGML_BACKEND_API bool ggml_backend_is_sycl(ggml_backend_t backend); 23 | 24 | // devide buffer 25 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device); 26 | 27 | // split tensor buffer that splits matrices by rows across multiple devices 28 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split); 29 | 30 | // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU 31 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void); 32 | 33 | GGML_BACKEND_API void ggml_backend_sycl_print_sycl_devices(void); 34 | GGML_BACKEND_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len); 35 | GGML_BACKEND_API void 
ggml_backend_sycl_get_device_description(int device, 36 | char *description, 37 | size_t description_size); 38 | GGML_BACKEND_API int ggml_backend_sycl_get_device_count(); 39 | GGML_BACKEND_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total); 40 | 41 | // SYCL doesn't support registering host memory, keep here for reference 42 | // GGML_BACKEND_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size); 43 | // GGML_BACKEND_API void ggml_backend_sycl_unregister_host_buffer(void * buffer); 44 | 45 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_sycl_reg(void); 46 | 47 | #ifdef __cplusplus 48 | } 49 | #endif 50 | -------------------------------------------------------------------------------- /vall_e.cpp/include/llama-impl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" // for ggml_log_level 4 | 5 | #include 6 | #include 7 | 8 | #ifdef __GNUC__ 9 | # if defined(__MINGW32__) && !defined(__clang__) 10 | # define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) 11 | # else 12 | # define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) 13 | # endif 14 | #else 15 | # define LLAMA_ATTRIBUTE_FORMAT(...) 16 | #endif 17 | 18 | // 19 | // logging 20 | // 21 | 22 | LLAMA_ATTRIBUTE_FORMAT(2, 3) 23 | void llama_log_internal (ggml_log_level level, const char * format, ...); 24 | void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data); 25 | 26 | #define LLAMA_LOG(...) llama_log_internal(GGML_LOG_LEVEL_NONE , __VA_ARGS__) 27 | #define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) 28 | #define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) 29 | #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) 30 | #define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) 31 | #define LLAMA_LOG_CONT(...) llama_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__) 32 | 33 | // 34 | // helpers 35 | // 36 | 37 | template 38 | struct no_init { 39 | T value; 40 | no_init() { /* do nothing */ } 41 | }; 42 | 43 | struct time_meas { 44 | time_meas(int64_t & t_acc, bool disable = false); 45 | ~time_meas(); 46 | 47 | const int64_t t_start_us; 48 | 49 | int64_t & t_acc; 50 | }; 51 | 52 | void replace_all(std::string & s, const std::string & search, const std::string & replace); 53 | 54 | // TODO: rename to llama_format ? 
55 | LLAMA_ATTRIBUTE_FORMAT(1, 2) 56 | std::string format(const char * fmt, ...); 57 | 58 | std::string llama_format_tensor_shape(const std::vector & ne); 59 | std::string llama_format_tensor_shape(const struct ggml_tensor * t); 60 | 61 | std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); 62 | -------------------------------------------------------------------------------- /vall_e/metrics.py: -------------------------------------------------------------------------------- 1 | # handles objective metric calculations, such as WER and SIM-O 2 | 3 | #from .emb.transcribe import transcribe 4 | from .emb.similar import speaker_similarity_embedding 5 | from .emb.transcribe import transcribe 6 | from .emb.g2p import detect_language, coerce_to_hiragana, encode 7 | from .data import normalize_text 8 | 9 | import torch.nn.functional as F 10 | 11 | from pathlib import Path 12 | from torcheval.metrics.functional import word_error_rate 13 | from torchmetrics.functional.text import char_error_rate 14 | 15 | import warnings 16 | warnings.simplefilter(action='ignore', category=FutureWarning) 17 | warnings.simplefilter(action='ignore', category=UserWarning) 18 | 19 | def wer( audio, reference, language="auto", phonemize=True, normalize=True, **transcription_kwargs ): 20 | if language == "auto": 21 | language = detect_language( reference ) 22 | 23 | transcription = transcribe( audio, language=language, align=False, **transcription_kwargs ) 24 | 25 | if language == "auto": 26 | language = transcription["language"] 27 | 28 | transcription = transcription["text"] 29 | 30 | # reference audio needs transcribing too 31 | if isinstance( reference, Path ): 32 | reference = transcribe( reference, language=language, align=False, **transcription_kwargs )["text"] 33 | 34 | if language == "ja": 35 | transcription = coerce_to_hiragana( transcription ) 36 | reference = coerce_to_hiragana( reference ) 37 | 38 | if phonemize: 39 | transcription = encode( transcription, language=language ) 40 | reference = encode( reference, language=language ) 41 | elif normalize: 42 | transcription = normalize_text( transcription, language=language ) 43 | reference = normalize_text( reference, language=language ) 44 | 45 | wer_score = word_error_rate([transcription], [reference]).item() 46 | # un-normalize 47 | wer_score *= len(reference.split()) 48 | 49 | cer_score = char_error_rate([transcription], [reference]).item() 50 | # un-normalize 51 | cer_score *= len(reference) 52 | 53 | return wer_score, cer_score 54 | 55 | def sim_o( audio, reference, **kwargs ): 56 | audio_emb = speaker_similarity_embedding( audio, **kwargs ) 57 | reference_emb = speaker_similarity_embedding( reference, **kwargs ) 58 | 59 | return F.cosine_similarity( audio_emb, reference_emb, dim=-1 ).item() -------------------------------------------------------------------------------- /vall_e/models/arch/__init__.py: -------------------------------------------------------------------------------- 1 | AVAILABLE_ARCHES = [] 2 | ERROR_ARCHES = {} 3 | 4 | try: 5 | from .llama import Config as LlamaConfig, Model as LlamaModel, Attention as LlamaAttention, AVAILABLE_ATTENTIONS 6 | AVAILABLE_ARCHES.append("llama") 7 | except Exception as e: 8 | ERROR_ARCHES["llama"] = e 9 | AVAILABLE_ATTENTIONS = [] 10 | pass 11 | 12 | """ 13 | try: 14 | from .transformer import SinusoidalEmbedding, Block as TransformerBlock 15 | AVAILABLE_ARCHES.append("transformer") 16 | except Exception as e: 17 | ERROR_ARCHES["transformer"] = e 18 | pass 19 | 20 | try: 21 | from .retnet 
import RetNetDecoder, RetNetConfig 22 | AVAILABLE_ARCHES.append("retnet") 23 | except Exception as e: 24 | ERROR_ARCHES["retnet"] = e 25 | pass 26 | 27 | try: 28 | from .retnet_syncdoth.retnet_ts import RetNetDecoder as RetNetDecoder_TS, RetNetConfig as RetNetConfig_TS 29 | AVAILABLE_ARCHES.append("retnet-ts") 30 | except Exception as e: 31 | ERROR_ARCHES["retnet-ts"] = e 32 | pass 33 | 34 | try: 35 | from .retnet_syncdoth.retnet_hf import RetNetDecoder as RetNetDecoder_HF, RetNetConfig as RetNetConfig_HF, RetNetForCausalLM 36 | AVAILABLE_ARCHES.append("retnet-hf") 37 | except Exception as e: 38 | ERROR_ARCHES["retnet-hf"] = e 39 | pass 40 | 41 | try: 42 | from .bitnet import BitNetTransformer 43 | AVAILABLE_ARCHES.append("bitnet") 44 | except Exception as e: 45 | ERROR_ARCHES["bitnet"] = e 46 | pass 47 | 48 | try: 49 | from .mixtral import MixtralModel, MixtralConfig, MixtralAttention, MixtralAttention_Adapted, MixtralModel_Adapted, load_balancing_loss_func 50 | AVAILABLE_ARCHES.append("mixtral") 51 | except Exception as e: 52 | ERROR_ARCHES["mixtral"] = e 53 | 54 | try: 55 | from .mamba import MambaModel, Mamba2Model, MambaConfig, Mamba2Config 56 | AVAILABLE_ARCHES.append("mamba") 57 | AVAILABLE_ARCHES.append("mamba2") 58 | except Exception as e: 59 | ERROR_ARCHES["mamba"] = e 60 | ERROR_ARCHES["mamba2"] = e 61 | """ 62 | """ 63 | try: 64 | from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig 65 | AVAILABLE_ARCHES.append("mamba") 66 | AVAILABLE_ARCHES.append("mamba2") 67 | except Exception as e: 68 | ERROR_ARCHES["mamba"] = e 69 | ERROR_ARCHES["mamba2"] = e 70 | 71 | try: 72 | from .mamba_vasqu import Mamba2Model_HF, Mamba2Config_HF 73 | AVAILABLE_ARCHES.append("mamba2-hf") 74 | except Exception as e: 75 | ERROR_ARCHES["mamba2-hf"] = e 76 | """ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | # VALL'E 6 | 7 | An unofficial PyTorch implementation of [VALL-E](https://vall-e-demo.ecker.tech/) (last updated: `2025.05.30`), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder. 8 | 9 | A demo is available on HuggingFace [here](https://huggingface.co/spaces/ecker/vall-e). 10 | 11 | ## Requirements 12 | 13 | Besides a working PyTorch environment, the only hard requirement is [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/) for phonemizing text: 14 | - Linux users can consult their package managers on installing `espeak`/`espeak-ng`. 15 | - Windows users are required to install [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/releases/tag/1.51#Assets). 16 | + additionally, you may be required to set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll`. 17 | - In the future, an internal homebrew to replace this would be fantastic. 18 | 19 | ## Install 20 | 21 | Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`. 22 | 23 | This repo is tested under Python versions `3.10.9`, `3.11.3`, and `3.12.3`. 24 | 25 | ### Additional Implementations 26 | 27 | An "HF"-ified version of the model is available as [`ecker/vall-e@hf`](https://huggingface.co/ecker/vall-e/tree/hf), but it does require some additional efforts (see the `__main__` of [`./vall_e/models/base.py`](./vall_e/models/base.py) for details). 28 | 29 | Additionally, [`vall_e.cpp`](./vall_e.cpp/) is available. Consult its README for more details. 30 | 31 | ## Pre-Trained Model 32 | 33 | Pre-trained weights can be acquired from 34 | * [here](https://huggingface.co/ecker/vall-e) or automatically when either inferencing or running the web UI. 35 | * `./scripts/setup.sh`, a script to setup a proper environment and download the weights. This will also automatically create a `venv`. 36 | * when inferencing, either through the web UI or CLI, if no model is passed, the default model will download automatically instead, and should automatically update. 37 | 38 | ## Documentation 39 | 40 | The provided documentation under [./docs/](./docs/) should provide thorough coverage over most, if not all, of this project. 41 | 42 | Markdown files should correspond directly to their respective file or folder under `./vall_e/`. 
-------------------------------------------------------------------------------- /vall_e/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | # https://github.com/enhuiz/pytorch-training-utilities 3 | """ 4 | 5 | import os 6 | import socket 7 | 8 | from functools import cache, wraps 9 | from typing import Callable 10 | 11 | import torch 12 | import torch.distributed as dist 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | 15 | def get_free_port(): 16 | sock = socket.socket() 17 | sock.bind(("", 0)) 18 | return sock.getsockname()[1] 19 | 20 | 21 | _distributed_initialized = False 22 | def init_distributed( fn, *args, **kwargs ): 23 | torch.cuda.set_device(local_rank()) 24 | fn(*args, **kwargs) 25 | _distributed_initialized = True 26 | 27 | def distributed_initialized(): 28 | return _distributed_initialized 29 | 30 | def cleanup_distributed(): 31 | dist.barrier() 32 | dist.destroy_process_group() 33 | 34 | @cache 35 | def fix_unset_envs(): 36 | envs = dict( 37 | RANK="0", 38 | WORLD_SIZE="1", 39 | MASTER_ADDR="localhost", 40 | MASTER_PORT=str(get_free_port()), 41 | LOCAL_RANK="0", 42 | ) 43 | 44 | for key in envs: 45 | value = os.getenv(key) 46 | if value is not None: 47 | return 48 | 49 | for key, value in envs.items(): 50 | os.environ[key] = value 51 | 52 | 53 | def local_rank(): 54 | return int(os.getenv("LOCAL_RANK", 0)) 55 | 56 | def global_rank(): 57 | return int(os.getenv("RANK", 0)) 58 | 59 | def world_size(): 60 | return int(os.getenv("WORLD_SIZE", 1)) 61 | 62 | 63 | def is_local_leader(): 64 | return local_rank() == 0 65 | 66 | 67 | def is_global_leader(): 68 | return global_rank() == 0 69 | 70 | 71 | def local_leader_only(fn=None, *, default=None) -> Callable: 72 | def wrapper(fn): 73 | @wraps(fn) 74 | def wrapped(*args, **kwargs): 75 | if is_local_leader(): 76 | return fn(*args, **kwargs) 77 | return default 78 | 79 | return wrapped 80 | 81 | if fn is None: 82 | return wrapper 83 | 84 | return wrapper(fn) 85 | 86 | 87 | def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable: 88 | def wrapper(fn): 89 | @wraps(fn) 90 | def wrapped(*args, **kwargs): 91 | if is_global_leader(): 92 | return fn(*args, **kwargs) 93 | return default 94 | 95 | return wrapped 96 | 97 | if fn is None: 98 | return wrapper 99 | 100 | return wrapper(fn) 101 | 102 | def ddp_model(model): 103 | return DDP(model.to(device='cuda'), [local_rank()], find_unused_parameters=True) -------------------------------------------------------------------------------- /vall_e.cpp/include/ggml-metal.h: -------------------------------------------------------------------------------- 1 | // Note: this description is outdated 2 | // 3 | // An interface allowing to compute ggml_cgraph with Metal 4 | // 5 | // This is a fully functional interface that extends ggml with GPU support for Apple devices. 6 | // A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.) 7 | // 8 | // How it works? 9 | // 10 | // As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this 11 | // interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you 12 | // use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.) 13 | // 14 | // You only need to make sure that all memory buffers that you used during the graph creation 15 | // are mapped to the device memory with the ggml_metal_add_buffer() function. 
This mapping is 16 | // used during the graph evaluation to determine the arguments of the compute kernels. 17 | // 18 | // Synchronization between device and host memory (for example for input and output tensors) 19 | // is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions. 20 | // 21 | 22 | #pragma once 23 | 24 | #include "ggml.h" 25 | #include "ggml-backend.h" 26 | 27 | #include 28 | #include 29 | 30 | struct ggml_tensor; 31 | struct ggml_cgraph; 32 | 33 | #ifdef __cplusplus 34 | extern "C" { 35 | #endif 36 | 37 | // 38 | // backend API 39 | // user-code should use only these functions 40 | // 41 | 42 | GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void); 43 | 44 | GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend); 45 | 46 | GGML_DEPRECATED( 47 | GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size), 48 | "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713"); 49 | 50 | GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); 51 | 52 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); 53 | 54 | // helper to check if the device supports a specific family 55 | // ideally, the user code should be doing these checks 56 | // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf 57 | GGML_BACKEND_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family); 58 | 59 | // capture all command buffers committed the next time `ggml_backend_graph_compute` is called 60 | GGML_BACKEND_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend); 61 | 62 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_metal_reg(void); 63 | 64 | #ifdef __cplusplus 65 | } 66 | #endif 67 | -------------------------------------------------------------------------------- /vall_e.cpp/README.md: -------------------------------------------------------------------------------- 1 | # vall_e.cpp 2 | 3 | This is an implementation that makes use of [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [encodec.cpp](https://github.com/PABannier/encodec.cpp). 4 | 5 | Model weights can: 6 | * be found at [`ecker/vall-e@gguf`](https://huggingface.co/ecker/vall-e/tree/gguf) 7 | * converted with `vall_e.export --yaml=./model_path/config.yaml --hf`, then running `python3 /path/to/your/llama.cpp/convert_hf_to_gguf ./model_path/hf/` 8 | 9 | ## Build 10 | 11 | Populate `./include/` with the `ggml`, `llama.cpp`, and `encodec.cpp` headers (although these headers should already be provided). 12 | * `encodec.cpp` and `llama.cpp` need to have their CMake files generated with `-DBUILD_SHARED_LIBS=On` passed. 13 | 14 | Populate `./lib/` with the compiled libraries of `llama.cpp`, `encodec.cpp`, and `espeak-ng` (if not already in your `LD_LIBRARY_PATH`). 15 | 16 | Run `make`. 17 | * `make lib` will generate the shared library (rename the `.so` to `.dll` under Windows). 18 | 19 | ### Required Modifications 20 | 21 | [`encodec.cpp`](https://github.com/PABannier/encodec.cpp) requires updating its GGML copy to the latest version, which requires a few lines to get the CPU backend working (per my [fork](https://github.com/e-c-k-e-r/encodec.cpp)). 
22 | 23 | [`llama.cpp`](https://github.com/ggerganov/llama.cpp) only possible modification needs to ensure that a non-causal attention mask is used; everything necessary can be hacked together with clever tricks. 24 | * initially written on commit `9ba399dfa7f115effc63d48e6860a94c9faa31b2`, updated to commit `7a84777f42a9b3ba47db5d20b7662f8ddf92f652` 25 | 26 | ## To-Do 27 | 28 | * [x] converted model to GGUF 29 | * [x] convert it without modifying any of the existing code, as the tokenizer requires some care 30 | * [x] basic framework 31 | * [x] load the quantized model 32 | * [x] orchestrate the required embeddings 33 | * [x] juggle the output head / classifier properly 34 | * [x] phonemize text 35 | * with the help of espeak-ng 36 | * [x] tokenize phonemes 37 | * tokenize with `llama_tokenize` instead of a homebrewed method because the tokenizer is being a huge thorn 38 | * [x] load audio from disk 39 | * [x] encode audio 40 | * [x] sum embeddings for the `prom` and prior `resp`s 41 | * [x] working `AR` output 42 | * [x] `AR` sampling 43 | * [x] working `NAR-len` output 44 | * [x] `NAR-len` sampling 45 | * [ ] proper scoring 46 | * [x] working `NAR` output 47 | * [x] `NAR` sampling 48 | * [x] decode audio to disk 49 | * [x] a functional CLI 50 | * [x] actually make it work 51 | * [x] clean up to make the code usable elsewhere 52 | * [x] configured to allow for being used as a lib 53 | * (I do need to validate this in my engine project, but that's in MSYS2) 54 | * [ ] feature parity with the PyTorch version 55 | * [ ] vocos 56 | * [ ] additional tasks 57 | * [ ] `stt` 58 | * [x] `ns` / `sr` 59 | * [ ] samplers -------------------------------------------------------------------------------- /scripts/parse_ppp.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Helper script to parse PPP dataset into a friendlier hierarchy 3 | """ 4 | 5 | import os 6 | import json 7 | import torch 8 | 9 | from tqdm.auto import tqdm 10 | from pathlib import Path 11 | from vall_e.emb.g2p import encode as valle_phonemize 12 | from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension 13 | 14 | target = "in" 15 | 16 | audio_map = {} 17 | text_map = {} 18 | 19 | data = {} 20 | 21 | for season in os.listdir(f"./{target}/"): 22 | if not os.path.isdir(f"./{target}/{season}/"): 23 | continue 24 | 25 | for episode in os.listdir(f"./{target}/{season}/"): 26 | if not os.path.isdir(f"./{target}/{season}/{episode}/"): 27 | continue 28 | 29 | for filename in os.listdir(f"./{target}/{season}/{episode}/"): 30 | path = f'./{target}/{season}/{episode}/{filename}' 31 | attrs = filename.split("_") 32 | timestamp = f'{attrs[0]}h{attrs[1]}m{attrs[2]}s' 33 | 34 | key = f'{episode}_{timestamp}' 35 | 36 | if filename[-5:] == ".flac": 37 | name = attrs[3] 38 | emotion = attrs[4] 39 | quality = attrs[5] 40 | 41 | audio_map[key] = { 42 | "path": path, 43 | 'episode': episode, 44 | "name": name, 45 | "emotion": emotion, 46 | "quality": quality, 47 | "timestamp": timestamp, 48 | } 49 | 50 | elif filename[-4:] == ".txt": 51 | text_map[key] = open(path, encoding="utf-8").read() 52 | txts = {} 53 | wavs = [] 54 | 55 | for key, entry in audio_map.items(): 56 | path = entry['path'] 57 | name = entry['name'] 58 | emotion = entry['emotion'] 59 | quality = entry['quality'] 60 | episode = entry['episode'] 61 | path = entry['path'] 62 | timestamp = entry['timestamp'] 63 | transcription = text_map[key] 64 | if name not in data: 65 | data[name] = {} 66 | 
os.makedirs(f'./training/{name}/', exist_ok=True) 67 | os.makedirs(f'./voices/{name}/', exist_ok=True) 68 | 69 | key = f'{episode}_{timestamp}.flac' 70 | os.rename(path, f'./voices/{name}/{key}') 71 | 72 | data[name][key] = { 73 | "segments": [], 74 | "language": "en", 75 | "text": transcription, 76 | "misc": { 77 | "emotion": emotion, 78 | "quality": quality, 79 | "timestamp": timestamp, 80 | "episode": episode, 81 | } 82 | } 83 | 84 | path = f'./voices/{name}/{key}' 85 | txts[path] = transcription 86 | wavs.append(Path(path)) 87 | 88 | for name in data.keys(): 89 | open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) ) 90 | 91 | # to-do: update to "The Proper Way" 92 | # for now it can just be fed back into "The Proper Way"" 93 | """ 94 | device = "cuda" 95 | for key, text in tqdm(txts.items(), desc="Phonemizing..."): 96 | path = Path(key) 97 | phones = valle_phonemize(text) 98 | open(_replace_file_extension(path, ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones)) 99 | 100 | for path in tqdm(wavs, desc="Quantizing..."): 101 | qnt = valle_quantize(path, device=device) 102 | torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt")) 103 | """ -------------------------------------------------------------------------------- /vall_e.cpp/include/lstm.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | #include "ggml-alloc.h" 5 | 6 | #include "ops.h" 7 | 8 | struct encodec_lstm { 9 | struct ggml_tensor *l0_ih_w; 10 | struct ggml_tensor *l0_hh_w; 11 | 12 | struct ggml_tensor *l0_ih_b; 13 | struct ggml_tensor *l0_hh_b; 14 | 15 | struct ggml_tensor *l1_ih_w; 16 | struct ggml_tensor *l1_hh_w; 17 | 18 | struct ggml_tensor *l1_ih_b; 19 | struct ggml_tensor *l1_hh_b; 20 | }; 21 | 22 | struct ggml_tensor *forward_pass_lstm_unilayer(struct ggml_context *ctx0, 23 | struct ggml_tensor *inp, 24 | struct ggml_tensor *weight_ih, 25 | struct ggml_tensor *weight_hh, 26 | struct ggml_tensor *bias_ih, 27 | struct ggml_tensor *bias_hh, 28 | char *prefix) { 29 | const int seq_length = inp->ne[0]; 30 | const int input_dim = inp->ne[1]; 31 | const int hidden_dim = weight_ih->ne[1] / 4; 32 | 33 | char ct_name[10]; 34 | char ht_name[10]; 35 | 36 | snprintf(ct_name, 10, "%s_ct", prefix); 37 | snprintf(ht_name, 10, "%s_ht", prefix); 38 | 39 | struct ggml_tensor *hs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length); 40 | ggml_set_input(hs); 41 | 42 | struct ggml_tensor *c_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim); 43 | ggml_set_input(c_t); 44 | ggml_set_name(c_t, ct_name); 45 | 46 | struct ggml_tensor *h_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim); 47 | ggml_set_input(h_t); 48 | ggml_set_name(h_t, ht_name); 49 | 50 | struct ggml_tensor *current = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); 51 | 52 | for (int t = 0; t < seq_length; t++) { 53 | struct ggml_tensor *x_t = ggml_view_1d(ctx0, current, input_dim, t * current->nb[1]); 54 | 55 | struct ggml_tensor *inp_gates = ggml_mul_mat(ctx0, weight_ih, x_t); 56 | inp_gates = ggml_add(ctx0, inp_gates, bias_ih); 57 | 58 | struct ggml_tensor *hid_gates = ggml_mul_mat(ctx0, weight_hh, h_t); 59 | hid_gates = ggml_add(ctx0, hid_gates, bias_hh); 60 | 61 | struct ggml_tensor *out_gates = ggml_add(ctx0, inp_gates, hid_gates); 62 | 63 | struct ggml_tensor *i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 0 * sizeof(float) * hidden_dim)); 64 | struct ggml_tensor *f_t = 
ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 1 * sizeof(float) * hidden_dim)); 65 | struct ggml_tensor *g_t = ggml_tanh(ctx0 , ggml_view_1d(ctx0, out_gates, hidden_dim, 2 * sizeof(float) * hidden_dim)); 66 | struct ggml_tensor *o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 3 * sizeof(float) * hidden_dim)); 67 | 68 | c_t = ggml_add(ctx0, ggml_mul(ctx0, f_t, c_t), ggml_mul(ctx0, i_t, g_t)); 69 | 70 | h_t = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_t)); 71 | 72 | hs = ggml_set_1d(ctx0, hs, h_t, t * hs->nb[1]); 73 | } 74 | 75 | hs = ggml_cont(ctx0, ggml_transpose(ctx0, hs)); 76 | 77 | return hs; 78 | } 79 | -------------------------------------------------------------------------------- /vall_e.cpp/include/ggml-alloc.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | 5 | #ifdef __cplusplus 6 | extern "C" { 7 | #endif 8 | 9 | typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t; 10 | typedef struct ggml_backend_buffer * ggml_backend_buffer_t; 11 | typedef struct ggml_backend * ggml_backend_t; 12 | 13 | // Tensor allocator 14 | struct ggml_tallocr { 15 | ggml_backend_buffer_t buffer; 16 | void * base; 17 | size_t alignment; 18 | size_t offset; 19 | }; 20 | 21 | GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer); 22 | GGML_API enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor); 23 | 24 | // Graph allocator 25 | /* 26 | Example usage: 27 | ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type()); 28 | 29 | // optional: create a worst-case graph and reserve the buffers to avoid reallocations 30 | ggml_gallocr_reserve(galloc, build_graph(max_batch)); 31 | 32 | // allocate the graph 33 | struct ggml_cgraph * graph = build_graph(batch); 34 | ggml_gallocr_alloc_graph(galloc, graph); 35 | 36 | printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0)); 37 | 38 | // evaluate the graph 39 | ggml_backend_graph_compute(backend, graph); 40 | */ 41 | 42 | // special tensor flags for use with the graph allocator: 43 | // ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses 44 | // ggml_set_output(): output tensors are never freed and never overwritten 45 | 46 | typedef struct ggml_gallocr * ggml_gallocr_t; 47 | 48 | GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft); 49 | GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs); 50 | GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc); 51 | 52 | // pre-allocate buffers from a measure graph - does not allocate or modify the graph 53 | // call with a worst-case graph to avoid buffer reallocations 54 | // not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed 55 | // returns false if the buffer allocation failed 56 | GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph); 57 | GGML_API bool ggml_gallocr_reserve_n( 58 | ggml_gallocr_t galloc, 59 | struct ggml_cgraph * graph, 60 | const int * node_buffer_ids, 61 | const int * leaf_buffer_ids); 62 | 63 | // automatic reallocation if the topology changes when using a single buffer 64 | // returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers) 65 | GGML_API bool 
ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph); 66 | 67 | GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id); 68 | 69 | // Utils 70 | // Create a buffer and allocate all the tensors in a ggml_context 71 | GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); 72 | GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend); 73 | 74 | #ifdef __cplusplus 75 | } 76 | #endif 77 | -------------------------------------------------------------------------------- /docs/utils.md: -------------------------------------------------------------------------------- 1 | # `utils/*` 2 | 3 | This folder contains helper utilities for either training or general functions of the program. 4 | 5 | These scripts are to remain agnostic to any model, to allow for reuse for other applications. 6 | 7 | ## `utils/distributed.py` 8 | 9 | This script contains the necessary code needed to utilize distributed training. 10 | 11 | Attributions are noted at the top. 12 | 13 | ## `utils/io.py` 14 | 15 | This script contains the necessary code for loading and storing state dicts, through pickles (`.pt`) or SafeTensors (`.sft`), and offers parity for each storage type. 16 | 17 | Additionally, some JSON helper functions are provided here. 18 | 19 | ## `utils/pattern.py` 20 | 21 | This script contains (unused) code related to formatting sequences of audio codes into different pattern types. 22 | 23 | Attributions are noted at the top. 24 | 25 | ## `utils/sampler.py` 26 | 27 | This script contains code to handle sampling from a list of indices. 28 | * `PoolSampler` has a master list of indices "in the marble bag" that are sampled without replacement. 29 | * `OrderedSampler` will output indices from 0 to `length`, in order. 30 | * `BatchedOrderedSampler` does the above, but will output lists of indices instead. 31 | * `RandomSampler` will output indices from 0 to `length`, randomly. 32 | 33 | Each sampler can load and store a state dict. 34 | 35 | 36 | ## `utils/utils.py` 37 | 38 | This script contains additional helper functions that do not require a dedicated file. 39 | 40 | ## `utils/train.py` 41 | 42 | This script handles the necessary code for training, such as: 43 | * iterating through a dataloader 44 | * iterating through an `Engines` to train each underlying `Engine` 45 | * printing training metrics 46 | * invoking `save`, `eval`, `export` every X iterations 47 | * handling stdin commands, such as `save`, `export`, `eval`, and `quit` 48 | 49 | ## `utils/wrapper.py` 50 | 51 | This script contains optimizations and additional code that require injecting or replacing modules. 52 | 53 | Most configurations are offered through `cfg.optimization`. 54 | 55 | ## `utils/ext/` 56 | 57 | This folder contains external code that can't be nicely referenced under a package. 58 | 59 | Proper attribution is noted at the top of each file. 60 | 61 | ### `utils/ext/apollo.py` 62 | 63 | This script contains [APOLLO](https://github.com/zhuhanqing/APOLLO), an optimizer that achieves ADAMW-like performance with very little memory cost. 64 | 65 | In testing, this seems to work fine, and the memory gains (in comparison to Prodigyopt) under the normal-specced model allows you to double the batch size. 
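As a point of reference, below is a minimal sketch of what dropping an APOLLO-style optimizer in for AdamW could look like on a generic torch module. The `apollo_torch` import path and its keyword arguments are assumptions for illustration only (this is not this repo's actual wiring); consult the upstream APOLLO README for the real API.

```python
# Minimal sketch: prefer an APOLLO-style optimizer when available, otherwise fall back to AdamW.
# NOTE: `apollo_torch.APOLLOAdamW` and its kwargs are assumed names for illustration only.
import torch

def build_optimizer(model: torch.nn.Module, lr: float = 1.0e-3):
    params = [p for p in model.parameters() if p.requires_grad]
    try:
        from apollo_torch import APOLLOAdamW  # hypothetical import path
        return APOLLOAdamW(params, lr=lr)     # low-rank/projection kwargs omitted; see upstream docs
    except ImportError:
        return torch.optim.AdamW(params, lr=lr)
```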
66 | 67 | It's definitely usable under extremely low VRAM environments, and specifying `apollo-mini` will further shrink the memory requirements (but robustness is yet to be personally tested). 68 | 69 | However, after a while, it seemed to cause some steps to either cause gradient overflow or NaNs that persist even when swapping back to `prodigyopt` (but I do not know if it's at the fault of `APOLLO` or just the model eventually hitting a point of instability). 70 | 71 | ### `utils/ext/unsloth.py` 72 | 73 | This script contains Unsloth, a VRAM-saving optimization that offloads the input tensors to CPU on a backwards pass. 74 | 75 | This is mostly unncessary, as inputs are rather small themselves, but is offered nonetheless if needed through `cfg.optimizations.unsloth = True` -------------------------------------------------------------------------------- /docs/webui.md: -------------------------------------------------------------------------------- 1 | # `webui.py` 2 | 3 | A Gradio-based web UI is accessible by running `python3 -m vall_e.webui`. You can, optionally, pass: 4 | 5 | * `--yaml=./path/to/your/config.yaml`: will load the targeted YAML 6 | * `--model=./path/to/your/model.sft`: will load the targeted model weights 7 | * `--listen 0.0.0.0:7860`: will set the web UI to listen to all IPs at port 7860. Replace the IP and Port to your preference. 8 | 9 | ## Inference 10 | 11 | Synthesizing speech is simple: 12 | 13 | * `Text`: 14 | * `Input Prompt`: The guiding text prompt. Each segment will be its own generated audio to be stitched together at the end. 15 | * `Audio`: 16 | * `Audio Input`: The transcription of the audio will be inserted into the `Text/Input Prompt` box. 17 | * For `vc` task, this will serve as the guidance reference audio as well. 18 | 19 | * `Audio Input`: The reference audio for the synthesis. Under Gradio, you can trim your clip accordingly, but leaving it as-is works fine. 20 | - A properly trained model can inference without a prompt to generate a random voice (without even needing to generate a random prompt itself). 21 | * `Output`: The resultant audio. 22 | * `Inference`: Button to start generating the audio. 23 | * `Basic Settings`: Basic sampler settings for most uses. 24 | * `Max Steps`: Number of demasking steps to perform for RVQ level 0. For the `NAR-len` modality. 25 | * `Max Duration`: Maximum duration the output audio will be. 26 | * `Input Prompt Repeat/Trim Length`: The audio prompt will be this duration length, as it will either be trimmed down or repeated (although repeating might cause more harm). 27 | * `Language (Text)`: The language of the input text for phonemizing. 28 | * `Language (Output)`: The target language for the output audio. Some checkpoints of the model might ignore this due to how it was trained, unfortunately. Some models might steer the output accent. 29 | * `Task`: The task to perform (in order): Text-To-Speech, Speech Removal, Noise Reduction, Voice Conversion. 30 | * `Text Delimiter`: How to split the `Text/Input Prompt`. Sentences will split by sentences, while lines will split by new lines. 31 | * `(Rolling) Context History`: Paired with the above, the previous N utterances will serve as the prefix to extend the generation on, allowing for consistency and stability across pieces. 32 | * `Sampler Settings`: Advanced sampler settings that are common for most text LLMs, but needs experimentation. 33 | * `Experimental Settings`: Settings used for testing. `cfg.experimental=True` enables this tab. 
34 | 35 | All the additional knobs have a description that can be correlated to the inferencing CLI flags. 36 | 37 | Speech-To-Text phoneme transcriptions for models that support it can be done using the `Speech-to-Text` tab. 38 | 39 | ## Dataset 40 | 41 | This tab currently only features exploring a dataset already prepared and referenced in your `config.yaml`. You can select a registered voice, and have it randomly sample an utterance. 42 | 43 | In the future, this *should* contain the necessary niceties to process raw audio into a dataset to train/finetune through, without needing to invoke the above commands to prepare the dataset. 44 | 45 | ## Settings 46 | 47 | So far, this only allows you to load a different model under a different dtype, device, and/or attention mechanism. without needing to restart. The previous model should seamlessly unload, and the new one will load in place. -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | from pathlib import Path 4 | from datetime import datetime 5 | from setuptools import setup, find_packages 6 | 7 | def shell(*args): 8 | out = subprocess.check_output(args) 9 | return out.decode("ascii").strip() 10 | 11 | def write_version(version_core, pre_release=True): 12 | if pre_release: 13 | time = shell("git", "log", "-1", "--format=%cd", "--date=iso") 14 | time = datetime.strptime(time, "%Y-%m-%d %H:%M:%S %z") 15 | time = time.strftime("%Y%m%d%H%M%S") 16 | version = f"{version_core}-dev{time}" 17 | else: 18 | version = version_core 19 | 20 | with open(Path("vall_e", "version.py"), "w") as f: 21 | f.write('__version__ = "{}"\n'.format(version)) 22 | 23 | return version 24 | 25 | with open("README.md", "r") as f: 26 | long_description = f.read() 27 | 28 | platform_dependencies = [] 29 | 30 | if sys.platform.startswith("win"): 31 | platform_dependencies += ["psutil"] 32 | else: 33 | platform_dependencies += ["deepspeed>=0.7.7"] 34 | 35 | setup( 36 | name="vall-e", 37 | python_requires=">=3.10.0", 38 | version=write_version("0.0.1"), 39 | description="An unofficial implementation of the audio LM VALL-E", 40 | author="ecker", 41 | author_email="mrq@ecker.tech", 42 | long_description=long_description, 43 | long_description_content_type="text/markdown", 44 | packages=find_packages(), 45 | install_requires= 46 | platform_dependencies + [ 47 | # logging niceties 48 | "coloredlogs>=15.0.1", # barely required 49 | "humanize>=4.4.0", # not really required 50 | "matplotlib>=3.6.0", # only required for plotting 51 | "pandas>=1.5.0", # not really required 52 | 53 | # boiler plate niceties 54 | #"diskcache>=5.4.0", 55 | "einops>=0.6.0", # could be replaced 56 | "tqdm", 57 | 58 | # HF bloat 59 | "tokenizers", 60 | "transformers", 61 | "safetensors", 62 | 63 | # training bloat 64 | "auraloss[all]", # [all] is needed for MelSTFTLoss 65 | "h5py", 66 | "prodigyopt @ git+https://github.com/konstmish/prodigy", 67 | 68 | # practically the reason to use python 69 | "numpy", 70 | "torch>=1.13.0", 71 | "torchaudio>=0.13.0", 72 | "torchmetrics", 73 | 74 | # core foundations 75 | "phonemizer>=2.1.0", 76 | "encodec>=0.1.1", 77 | "vocos", 78 | 79 | # for the web UI 80 | "gradio", 81 | "nltk", # for parsing text inputs down to pieces 82 | "langdetect", # for detecting the language of a text 83 | "sounddevice", # for raw playback 84 | ], 85 | extras_require = { 86 | "all": [ 87 | # retnet backend (even though two 
internal copies exist) 88 | "torchscale @ git+https://git.ecker.tech/mrq/torchscale", 89 | # bitnet 90 | "bitnet", 91 | # mamba 92 | "causal-conv1d", 93 | "mamba-ssm", 94 | 95 | # 96 | "torcheval", 97 | 98 | # attention helpers 99 | "xformers", 100 | "sageattention==1.0.6", 101 | # "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it 102 | 103 | # other audio backend that doesn't prove fruitful 104 | "descript-audio-codec", 105 | 106 | # nemo (to-do: cut this down) 107 | "nemo-toolkit", 108 | "hydra-core", 109 | "lightning", 110 | "sentencepiece" 111 | ] 112 | }, 113 | url="https://git.ecker.tech/mrq/vall-e", 114 | ) 115 | -------------------------------------------------------------------------------- /scripts/cleanup_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Helper script to clean up transcription metadata, whatever that entailed. 3 | """ 4 | 5 | import os 6 | import json 7 | import torch 8 | import torchaudio 9 | 10 | from tqdm.auto import tqdm 11 | from pathlib import Path 12 | 13 | input_dataset = "training/metadata" 14 | output_dataset = "training/metadata-cleaned" 15 | 16 | def pad(num, zeroes): 17 | return str(num).zfill(zeroes+1) 18 | 19 | for dataset_name in os.listdir(f'./{input_dataset}/'): 20 | if not os.path.isdir(f'./{input_dataset}/{dataset_name}/'): 21 | print("Is not dir:", f'./{input_dataset}/{dataset_name}/') 22 | continue 23 | 24 | for speaker_id in tqdm(os.listdir(f'./{input_dataset}/{dataset_name}/'), desc=f"Processing speaker: {dataset_name}"): 25 | if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'): 26 | print("Is not dir:", f'./{input_dataset}/{dataset_name}/{speaker_id}') 27 | continue 28 | 29 | inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/whisper.json') 30 | outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json') 31 | 32 | if not inpath.exists(): 33 | continue 34 | 35 | if outpath.exists(): 36 | continue 37 | 38 | os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True) 39 | 40 | try: 41 | in_metadata = json.loads(open(inpath, 'r', encoding='utf-8').read()) 42 | except Exception as e: 43 | print("Failed to open metadata file:", inpath) 44 | continue 45 | 46 | out_metadata = {} 47 | speaker_metadatas = {} 48 | 49 | for filename, result in in_metadata.items(): 50 | language = result["language"] if "language" in result else "en" 51 | out_metadata[filename] = { 52 | "segments": [], 53 | "language": language, 54 | "text": "", 55 | "start": 0, 56 | "end": 0, 57 | } 58 | segments = [] 59 | text = [] 60 | start = 0 61 | end = 0 62 | diarized = False 63 | 64 | for segment in result["segments"]: 65 | # diarize split 66 | if "speaker" in segment: 67 | diarized = True 68 | speaker_id = segment["speaker"] 69 | if speaker_id not in speaker_metadatas: 70 | speaker_metadatas[speaker_id] = {} 71 | 72 | if filename not in speaker_metadatas[speaker_id]: 73 | speaker_metadatas[speaker_id][filename] = { 74 | "segments": [], 75 | "language": language, 76 | "text": "", 77 | "start": 0, 78 | "end": 0, 79 | } 80 | 81 | speaker_metadatas[speaker_id][filename]["segments"].append( segment ) 82 | else: 83 | segments.append( segment ) 84 | 85 | text.append( segment["text"] ) 86 | start = min( start, segment["start"] ) 87 | end = max( end, segment["end"] ) 88 | 89 | out_metadata[filename]["segments"] = segments 90 | out_metadata[filename]["text"] = " ".join(text).strip() 
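			# aggregate window for this file: "start" was initialised to 0 above, so min() keeps it at 0 for non-negative segment timestamps, while "end" reflects the last segment's end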
91 | out_metadata[filename]["start"] = start 92 | out_metadata[filename]["end"] = end 93 | 94 | if len(segments) == 0: 95 | del out_metadata[filename] 96 | 97 | open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata)) 98 | 99 | for speaker_id, out_metadata in speaker_metadatas.items(): 100 | os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True) 101 | outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json') 102 | 103 | open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata)) -------------------------------------------------------------------------------- /vall_e.cpp/include/encoder.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "ggml.h" 6 | #include "lstm.h" 7 | 8 | // res + downsample block at some ratio 9 | struct encodec_encoder_block { 10 | // conv1 11 | struct ggml_tensor *conv_1_w; 12 | struct ggml_tensor *conv_1_b; 13 | 14 | // conv2 15 | struct ggml_tensor *conv_2_w; 16 | struct ggml_tensor *conv_2_b; 17 | 18 | // shortcut 19 | struct ggml_tensor *conv_sc_w; 20 | struct ggml_tensor *conv_sc_b; 21 | 22 | // downsampling layers 23 | struct ggml_tensor *ds_conv_w; 24 | struct ggml_tensor *ds_conv_b; 25 | }; 26 | 27 | struct encodec_encoder { 28 | struct ggml_tensor *init_conv_w; 29 | struct ggml_tensor *init_conv_b; 30 | 31 | encodec_lstm lstm; 32 | 33 | struct ggml_tensor *final_conv_w; 34 | struct ggml_tensor *final_conv_b; 35 | 36 | std::vector blocks; 37 | }; 38 | 39 | struct ggml_tensor *encodec_forward_encoder( 40 | const struct encodec_encoder *encoder, struct ggml_context *ctx0, 41 | struct ggml_tensor *inp, const int * ratios, const int kernel_size, const int res_kernel_size, 42 | const int stride) { 43 | 44 | if (!inp) { 45 | fprintf(stderr, "%s: null input tensor\n", __func__); 46 | return NULL; 47 | } 48 | 49 | struct ggml_tensor *inpL = strided_conv_1d( 50 | ctx0, inp, encoder->init_conv_w, encoder->init_conv_b, stride); 51 | 52 | for (int layer_ix = 0; layer_ix < 4; layer_ix++) { 53 | encodec_encoder_block block = encoder->blocks[layer_ix]; 54 | 55 | struct ggml_tensor *current = inpL; 56 | 57 | // shortcut 58 | struct ggml_tensor *shortcut = strided_conv_1d( 59 | ctx0, inpL, block.conv_sc_w, block.conv_sc_b, stride); 60 | 61 | // conv1 62 | current = ggml_elu(ctx0, current); 63 | 64 | current = strided_conv_1d( 65 | ctx0, current, block.conv_1_w, block.conv_1_b, stride); 66 | 67 | // conv2 68 | current = ggml_elu(ctx0, current); 69 | 70 | current = strided_conv_1d( 71 | ctx0, current, block.conv_2_w, block.conv_2_b, stride); 72 | 73 | // residual connection 74 | inpL = ggml_add(ctx0, current, shortcut); 75 | 76 | // downsampling layers 77 | inpL = ggml_elu(ctx0, inpL); 78 | 79 | inpL = strided_conv_1d( 80 | ctx0, inpL, block.ds_conv_w, block.ds_conv_b, ratios[3 - layer_ix]); 81 | } 82 | 83 | // lstm 84 | { 85 | struct ggml_tensor *cur = inpL; 86 | 87 | const encodec_lstm lstm = encoder->lstm; 88 | 89 | // first lstm layer 90 | char l0_prefix[7] = "enc_l0"; 91 | struct ggml_tensor *hs1 = forward_pass_lstm_unilayer( 92 | ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b, l0_prefix); 93 | 94 | // second lstm layer 95 | char l1_prefix[7] = "enc_l1"; 96 | struct ggml_tensor *out = forward_pass_lstm_unilayer( 97 | ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b, l1_prefix); 98 | 99 | inpL = ggml_add(ctx0, inpL, out); 100 | } 101 | 102 | // final conv 103 | inpL = ggml_elu(ctx0, inpL); 
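    // the projection below maps the LSTM-enriched features down to the latent dimension consumed by the quantizer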
104 | 105 | struct ggml_tensor *encoded_inp = strided_conv_1d( 106 | ctx0, inpL, encoder->final_conv_w, encoder->final_conv_b, stride); 107 | 108 | return encoded_inp; 109 | } 110 | -------------------------------------------------------------------------------- /vall_e/models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import requests 4 | import time 5 | 6 | from tqdm import tqdm 7 | from pathlib import Path 8 | 9 | _logger = logging.getLogger(__name__) 10 | 11 | # to-do: implement automatically downloading model 12 | DEFAULT_MODEL_NAME = os.environ.get("VALLE_DEFAULT_MODEL_NAME", "ar+nar-len-llama-8.sft") 13 | DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent / 'data/models' 14 | DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / DEFAULT_MODEL_NAME 15 | DEFAULT_MODEL_URLS = { 16 | 'ar+nar-len-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-len-llama-8/ckpt/fp32.sft', 17 | 'nemo-larger-44khz-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/nemo-larger-44khz-llama-8/fp32.sft', 18 | 'wavlm_large_finetune.pth': 'https://huggingface.co/Dongchao/UniAudio/resolve/main/wavlm_large_finetune.pth', 19 | } 20 | 21 | if not DEFAULT_MODEL_PATH.exists() and Path(f"./data/models/{DEFAULT_MODEL_NAME}").exists(): 22 | DEFAULT_MODEL_DIR = Path('./data/models') 23 | DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / DEFAULT_MODEL_NAME 24 | 25 | # kludge, probably better to use HF's model downloader function 26 | # to-do: write to a temp file then copy so downloads can be interrupted 27 | def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ): 28 | name = save_path.name 29 | url = DEFAULT_MODEL_URLS[name] if name in DEFAULT_MODEL_URLS else None 30 | if url is None: 31 | raise Exception(f'Model requested for download but not defined: {name}') 32 | 33 | if not save_path.parent.exists(): 34 | save_path.parent.mkdir(parents=True, exist_ok=True) 35 | 36 | headers = {} 37 | # check if modified 38 | if save_path.exists(): 39 | headers = {"If-Modified-Since": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(save_path.stat().st_mtime))} 40 | 41 | r = requests.get(url, headers=headers, stream=True) 42 | 43 | # not modified 44 | if r.status_code == 304: 45 | r.close() 46 | return 47 | 48 | # to-do: validate lengths match 49 | 50 | content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length']) 51 | with open(save_path, 'wb') as f: 52 | bar = tqdm( unit='B', unit_scale=True, unit_divisor=1024, total=content_length, desc=f"Downloading: {name}" ) 53 | for chunk in r.iter_content(chunk_size=chunkSize): 54 | if not chunk: 55 | continue 56 | bar.update( len(chunk)) 57 | f.write(chunk) 58 | bar.close() 59 | 60 | r.close() 61 | 62 | 63 | def get_model(config, training=True, **model_kwargs): 64 | # crunge 65 | if config.version < 7: 66 | from .ar_nar import AR_NAR 67 | ModelClass = AR_NAR 68 | else: 69 | from .ar_nar_v2 import AR_NAR_V2 70 | ModelClass = AR_NAR_V2 71 | 72 | cfg_kwargs = dict( 73 | n_phn_tokens=config.phoneme_tokens, 74 | n_audio_tokens=config.audio_tokens, 75 | n_text_tokens=config.text_tokens, 76 | d_model=config.dim, 77 | n_heads=config.heads, 78 | n_layers=config.layers, 79 | n_experts=config.experts, 80 | 81 | p_dropout=config.dropout, 82 | 83 | l_padding = config.input_alignment, 84 | 85 | training = training, 86 | config = config, 87 | ) 88 | 89 | name = config.name 90 | model = 
ModelClass(**(cfg_kwargs | model_kwargs)) 91 | 92 | _logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters") 93 | 94 | return model 95 | 96 | def get_models(models, training=True, **model_kwargs): 97 | return { model.full_name: get_model(model, training=training, **model_kwargs) for model in models } 98 | -------------------------------------------------------------------------------- /vall_e.cpp/include/decoder.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "ggml.h" 6 | #include "ggml-alloc.h" 7 | #include "ggml-backend.h" 8 | 9 | #include "lstm.h" 10 | #include "utils.h" 11 | 12 | 13 | struct encodec_decoder_block { 14 | // upsampling layers 15 | struct ggml_tensor *us_conv_w; 16 | struct ggml_tensor *us_conv_b; 17 | 18 | // conv1 19 | struct ggml_tensor *conv_1_w; 20 | struct ggml_tensor *conv_1_b; 21 | 22 | // conv2 23 | struct ggml_tensor *conv_2_w; 24 | struct ggml_tensor *conv_2_b; 25 | 26 | // shortcut 27 | struct ggml_tensor *conv_sc_w; 28 | struct ggml_tensor *conv_sc_b; 29 | }; 30 | 31 | struct encodec_decoder { 32 | struct ggml_tensor *init_conv_w; 33 | struct ggml_tensor *init_conv_b; 34 | 35 | encodec_lstm lstm; 36 | 37 | struct ggml_tensor *final_conv_w; 38 | struct ggml_tensor *final_conv_b; 39 | 40 | std::vector blocks; 41 | }; 42 | 43 | struct ggml_tensor *encodec_forward_decoder( 44 | const struct encodec_decoder *decoder, struct ggml_context *ctx0, 45 | struct ggml_tensor *quantized_out, const int *ratios, const int kernel_size, const int res_kernel_size, 46 | const int stride) { 47 | 48 | if (!quantized_out) { 49 | fprintf(stderr, "%s: null input tensor\n", __func__); 50 | return NULL; 51 | } 52 | 53 | struct ggml_tensor *inpL = strided_conv_1d( 54 | ctx0, quantized_out, decoder->init_conv_w, decoder->init_conv_b, stride); 55 | 56 | // lstm 57 | { 58 | struct ggml_tensor *cur = inpL; 59 | 60 | const encodec_lstm lstm = decoder->lstm; 61 | 62 | // first lstm layer 63 | char l0_prefix[7] = "dec_l0"; 64 | struct ggml_tensor *hs1 = forward_pass_lstm_unilayer( 65 | ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b, l0_prefix); 66 | 67 | // second lstm layer 68 | char l1_prefix[7] = "dec_l1"; 69 | struct ggml_tensor *out = forward_pass_lstm_unilayer( 70 | ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b, l1_prefix); 71 | 72 | inpL = ggml_add(ctx0, inpL, out); 73 | } 74 | 75 | for (int layer_ix = 0; layer_ix < 4; layer_ix++) { 76 | encodec_decoder_block block = decoder->blocks[layer_ix]; 77 | 78 | // upsampling layers 79 | inpL = ggml_elu(ctx0, inpL); 80 | 81 | inpL = strided_conv_transpose_1d( 82 | ctx0, inpL, block.us_conv_w, block.us_conv_b, ratios[layer_ix]); 83 | 84 | struct ggml_tensor *current = inpL; 85 | 86 | // shortcut 87 | struct ggml_tensor *shortcut = strided_conv_1d( 88 | ctx0, inpL, block.conv_sc_w, block.conv_sc_b, stride); 89 | 90 | // conv1 91 | current = ggml_elu(ctx0, current); 92 | 93 | current = strided_conv_1d( 94 | ctx0, current, block.conv_1_w, block.conv_1_b, stride); 95 | 96 | // conv2 97 | current = ggml_elu(ctx0, current); 98 | 99 | current = strided_conv_1d( 100 | ctx0, current, block.conv_2_w, block.conv_2_b, stride); 101 | 102 | // residual connection 103 | inpL = ggml_add(ctx0, current, shortcut); 104 | } 105 | 106 | // final conv 107 | inpL = ggml_elu(ctx0, inpL); 108 | 109 | struct ggml_tensor *decoded_inp = strided_conv_1d( 110 | ctx0, inpL, 
decoder->final_conv_w, decoder->final_conv_b, stride); 111 | 112 | return decoded_inp; 113 | } 114 | -------------------------------------------------------------------------------- /vall_e.cpp/include/espeak-ng/encoding.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2017 Reece H. Dunn 3 | * 4 | * This program is free software; you can redistribute it and/or modify 5 | * it under the terms of the GNU General Public License as published by 6 | * the Free Software Foundation; either version 3 of the License, or 7 | * (at your option) any later version. 8 | * 9 | * This program is distributed in the hope that it will be useful, 10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | * GNU General Public License for more details. 13 | * 14 | * You should have received a copy of the GNU General Public License 15 | * along with this program; if not, see: . 16 | */ 17 | #ifndef ESPEAK_NG_ENCODING_H 18 | #define ESPEAK_NG_ENCODING_H 19 | 20 | #include 21 | 22 | #ifdef __cplusplus 23 | extern "C" 24 | { 25 | #endif 26 | 27 | typedef enum 28 | { 29 | ESPEAKNG_ENCODING_UNKNOWN, 30 | ESPEAKNG_ENCODING_US_ASCII, 31 | ESPEAKNG_ENCODING_ISO_8859_1, 32 | ESPEAKNG_ENCODING_ISO_8859_2, 33 | ESPEAKNG_ENCODING_ISO_8859_3, 34 | ESPEAKNG_ENCODING_ISO_8859_4, 35 | ESPEAKNG_ENCODING_ISO_8859_5, 36 | ESPEAKNG_ENCODING_ISO_8859_6, 37 | ESPEAKNG_ENCODING_ISO_8859_7, 38 | ESPEAKNG_ENCODING_ISO_8859_8, 39 | ESPEAKNG_ENCODING_ISO_8859_9, 40 | ESPEAKNG_ENCODING_ISO_8859_10, 41 | ESPEAKNG_ENCODING_ISO_8859_11, 42 | // ISO-8859-12 is not a valid encoding. 43 | ESPEAKNG_ENCODING_ISO_8859_13, 44 | ESPEAKNG_ENCODING_ISO_8859_14, 45 | ESPEAKNG_ENCODING_ISO_8859_15, 46 | ESPEAKNG_ENCODING_ISO_8859_16, 47 | ESPEAKNG_ENCODING_KOI8_R, 48 | ESPEAKNG_ENCODING_ISCII, 49 | ESPEAKNG_ENCODING_UTF_8, 50 | ESPEAKNG_ENCODING_ISO_10646_UCS_2, 51 | } espeak_ng_ENCODING; 52 | 53 | ESPEAK_NG_API espeak_ng_ENCODING 54 | espeak_ng_EncodingFromName(const char *encoding); 55 | 56 | typedef struct espeak_ng_TEXT_DECODER_ espeak_ng_TEXT_DECODER; 57 | 58 | ESPEAK_NG_API espeak_ng_TEXT_DECODER * 59 | create_text_decoder(void); 60 | 61 | ESPEAK_NG_API void 62 | destroy_text_decoder(espeak_ng_TEXT_DECODER *decoder); 63 | 64 | ESPEAK_NG_API espeak_ng_STATUS 65 | text_decoder_decode_string(espeak_ng_TEXT_DECODER *decoder, 66 | const char *string, 67 | int length, 68 | espeak_ng_ENCODING encoding); 69 | 70 | ESPEAK_NG_API espeak_ng_STATUS 71 | text_decoder_decode_string_auto(espeak_ng_TEXT_DECODER *decoder, 72 | const char *string, 73 | int length, 74 | espeak_ng_ENCODING encoding); 75 | 76 | ESPEAK_NG_API espeak_ng_STATUS 77 | text_decoder_decode_wstring(espeak_ng_TEXT_DECODER *decoder, 78 | const wchar_t *string, 79 | int length); 80 | 81 | ESPEAK_NG_API espeak_ng_STATUS 82 | text_decoder_decode_string_multibyte(espeak_ng_TEXT_DECODER *decoder, 83 | const void *input, 84 | espeak_ng_ENCODING encoding, 85 | int flags); 86 | 87 | ESPEAK_NG_API int 88 | text_decoder_eof(espeak_ng_TEXT_DECODER *decoder); 89 | 90 | ESPEAK_NG_API uint32_t 91 | text_decoder_getc(espeak_ng_TEXT_DECODER *decoder); 92 | 93 | ESPEAK_NG_API uint32_t 94 | text_decoder_peekc(espeak_ng_TEXT_DECODER *decoder); 95 | 96 | ESPEAK_NG_API const void * 97 | text_decoder_get_buffer(espeak_ng_TEXT_DECODER *decoder); 98 | 99 | #ifdef __cplusplus 100 | } 101 | #endif 102 | 103 | #endif 104 | 
-------------------------------------------------------------------------------- /scripts/train_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Helper script to grab all phonemes through parsed dataset metadata to find the "best" tokenizer dict 3 | """ 4 | 5 | import os 6 | import json 7 | import torch 8 | import torchaudio 9 | 10 | from tqdm.auto import tqdm 11 | from pathlib import Path 12 | 13 | from tokenizers import Tokenizer 14 | from tokenizers.models import BPE, Unigram, WordLevel, WordPiece 15 | from tokenizers.trainers import BpeTrainer 16 | from tokenizers.pre_tokenizers import Whitespace 17 | from tokenizers.processors import TemplateProcessing 18 | 19 | from vall_e.config import cfg 20 | from vall_e.utils.io import json_read 21 | from vall_e.emb.g2p import coerce_to_hiragana 22 | 23 | input_metadata = "training/metadata/" 24 | 25 | output_file = Path("./training/tokenizer_pretraining_data.json") 26 | tokenizer_data = [] 27 | 28 | def pad(num, zeroes): 29 | return str(num).zfill(zeroes+1) 30 | 31 | def add( dir, type="training", audios=True, texts=True ): 32 | name = str(dir) 33 | name = name.replace(str(cfg.data_dir), "") 34 | speaker_name = name 35 | """ 36 | if "LibriTTS-R" in speaker_name: 37 | speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox") 38 | """ 39 | 40 | metadata_path = cfg.metadata_dir / f'{speaker_name}.json' 41 | metadata = json_read( metadata_path, default={} ) 42 | 43 | for k, entry in metadata.items(): 44 | if "text" not in entry: 45 | continue 46 | 47 | language = entry.get('language','auto') 48 | text = entry['text'] 49 | tokenizer_data.append( text ) 50 | 51 | if output_file.exists(): 52 | tokenizer_data = json.loads(open(str(output_file), "r", encoding="utf-8").read()) 53 | else: 54 | # training 55 | for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"): 56 | try: 57 | add( data_dir, type="training" ) 58 | except Exception as e: 59 | pass 60 | 61 | # validation 62 | for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'): 63 | try: 64 | add( data_dir, type="validation" ) 65 | except Exception as e: 66 | pass 67 | """ 68 | for dataset_name in os.listdir(f'./{input_metadata}/'): 69 | if not os.path.isdir(f'./{input_metadata}/{dataset_name}/'): 70 | continue 71 | 72 | for speaker_id in tqdm(os.listdir(f'./{input_metadata}/{dataset_name}/'), desc="Processing speaker"): 73 | if not os.path.isdir(f'./{input_metadata}/{dataset_name}/{speaker_id}'): 74 | continue 75 | 76 | for id in os.listdir(f'./{input_metadata}/{dataset_name}/{speaker_id}/'): 77 | if ".json" not in id: 78 | continue 79 | 80 | metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/{id}') 81 | metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read()) 82 | 83 | if "text" not in metadata: 84 | continue 85 | 86 | tokenizer_data.append( f'{"".join(metadata["text"])}' ) 87 | 88 | open(output_file, 'w', encoding='utf-8').write(json.dumps(tokenizer_data)) 89 | """ 90 | 91 | unk_token = "" 92 | spl_tokens = [unk_token, "", "", "", ""] 93 | 94 | trainer = BpeTrainer(special_tokens = spl_tokens, vocab_size = 32768, max_token_length=1, min_frequency=len(tokenizer_data)) 95 | tokenizer = Tokenizer(BPE(unk_token = unk_token)) 96 | tokenizer.pre_tokenizer = Whitespace() # takes 2 hours to process without this, we'll just manually add spaces as a token 97 | tokenizer.post_processor = TemplateProcessing( 98 | single=" $A ", 99 | special_tokens=[("", 1), ("", 
2)], 100 | ) 101 | 102 | tokenizer.train_from_iterator(tokenizer_data, trainer=trainer) 103 | tokenizer.save("./training/tokenizer_training_data.json") -------------------------------------------------------------------------------- /vall_e.cpp/vall_e-impl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // stores all the backend stuff 4 | 5 | // external deps 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions 12 | #define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch (or using LLAMA arch) 13 | 14 | #if !LLAMA_CPP_EXTENDED 15 | #include "llama_hack.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd 16 | #endif 17 | 18 | // to-do: clean up spaghetti enums 19 | const int EMBEDDING_MODE_PROM = 0; 20 | const int EMBEDDING_MODE_RESP_AR_NAR = 1; 21 | const int EMBEDDING_MODE_RESP_NAR_LEN = 2; 22 | 23 | const int INFERENCE_MODE_LEN = 0; 24 | const int INFERENCE_MODE_AR = 1; 25 | const int INFERENCE_MODE_NAR_DEMASK = 2; 26 | const int INFERENCE_MODE_NAR = 3; 27 | 28 | // stores metadata for inputs/outputs 29 | struct io_t { 30 | std::string name; 31 | uint32_t start; 32 | uint32_t end; 33 | int32_t head_idx = -1; 34 | 35 | int32_t n_embd = 0; 36 | int32_t n_vocab = 0; 37 | 38 | std::vector embds = {}; 39 | ggml_tensor* head = NULL; 40 | }; 41 | 42 | // stores the mappings between tokens, input embeddings, and output heads 43 | struct io_map_t { 44 | // model's original params 45 | int32_t n_embd = 0; 46 | int32_t n_vocab = 0; 47 | 48 | // mapping 49 | std::unordered_map io = {}; 50 | // context to store slices 51 | ggml_context* ctx = NULL; 52 | }; 53 | // used for top-k (mainly for demasking) 54 | struct score_t { 55 | int32_t idx; 56 | float value; 57 | 58 | bool operator<( const score_t& that ) const { return this->value < that.value; } 59 | }; 60 | // handles storing metadata for token merges 61 | struct merge_entry_t { 62 | std::u32string pre; 63 | std::u32string post; 64 | std::u32string resolved; 65 | 66 | token_t pre_token; 67 | token_t post_token; 68 | token_t resolved_token; 69 | }; 70 | 71 | // helper tensor functions 72 | std::vector read_2d_tensor( struct ggml_tensor* tensor ); 73 | //ggml_tensor* view_2d_tensor( ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 ); // cringe method to keep in my pocket 74 | ggml_tensor* view_2d_tensor( ggml_context* ctx, ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 ); 75 | void print_tokens( const std::vector& tokens, const std::string& prefix = "Tokens: " ); 76 | 77 | std::vector> map_embeddings( const std::vector& tokens, int n_embd, const float* embds ); 78 | std::vector> sum_embeddings( const vall_e_audio_codes_t& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM ); 79 | std::vector soft_max( int n_logits, const float* logits ); 80 | 81 | // batch and inferencing 82 | void batch_add( llama_batch& batch, token_t id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector & seq_ids = {0} ); 83 | void fill_batch( llama_batch& batch, vall_e_inputs_t& input, io_map_t& inputs_map, int mode ); 84 | std::vector generate( vall_e_context_t* ctx, vall_e_inputs_t& input, int max_tokens, int mode, bool verbose = true ); 85 | 86 | // (handles text) 87 | std::vector phonemize( vall_e_context_t* ctx, const std::string& text, const 
std::string& language = "auto" ); 88 | 89 | // model-accessing helpers 90 | const io_t& vall_e_inputs_map_get_embeddings( io_map_t& inputs_map, const std::string& name ); 91 | const float* vall_e_inputs_map_get_embeddings_p( io_map_t& inputs_map, const std::string& name ); 92 | int32_t vall_e_inputs_map_get_classifier_idx( io_map_t& inputs_map, const std::string& name ); 93 | void vall_e_inputs_map_init( io_map_t&, llama_model* model ); -------------------------------------------------------------------------------- /vall_e/utils/ext/unsloth.py: -------------------------------------------------------------------------------- 1 | # lifted from https://gist.github.com/pszemraj/e88ff24ab296b6d89057376b299b368a 2 | # to-do: make this work with LoRAs, it complains 3 | 4 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | import torch 19 | import transformers 20 | import inspect 21 | 22 | 23 | class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function): 24 | """ 25 | Saves VRAM by smartly offloading to RAM. 26 | Tiny hit to performance, since we mask the movement via non blocking calls. 27 | """ 28 | 29 | @staticmethod 30 | @torch.cuda.amp.custom_fwd 31 | def forward(ctx, forward_function, hidden_states, *args): 32 | saved_hidden_states = hidden_states.to("cpu", non_blocking=True) 33 | with torch.no_grad(): 34 | output = forward_function(hidden_states, *args) 35 | ctx.save_for_backward(saved_hidden_states) 36 | ctx.forward_function = forward_function 37 | ctx.args = args 38 | 39 | return output 40 | 41 | pass 42 | 43 | @staticmethod 44 | @torch.cuda.amp.custom_bwd 45 | def backward(ctx, dY): 46 | (hidden_states,) = ctx.saved_tensors 47 | hidden_states = hidden_states.to("cuda", non_blocking=True).detach() 48 | hidden_states.requires_grad = True 49 | with torch.enable_grad(): 50 | (output,) = ctx.forward_function(hidden_states, *ctx.args) 51 | torch.autograd.backward(output, dY) 52 | return ( 53 | None, 54 | hidden_states.grad, 55 | ) + ( 56 | None, 57 | ) * len(ctx.args) 58 | 59 | pass 60 | 61 | 62 | pass 63 | 64 | 65 | def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): 66 | #assert gradient_checkpointing_kwargs == None 67 | gradient_checkpointing_kwargs = None 68 | if not self.supports_gradient_checkpointing: 69 | raise ValueError( 70 | f"{self.__class__.__name__} does not support gradient checkpointing." 
71 | ) 72 | 73 | gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply 74 | # For old GC format (transformers < 4.35.0) for models that live on the Hub 75 | # we will fall back to the overwritten `_set_gradient_checkpointing` method 76 | _is_using_old_format = ( 77 | "value" in inspect.signature(self._set_gradient_checkpointing).parameters 78 | ) 79 | 80 | if not _is_using_old_format: 81 | self._set_gradient_checkpointing( 82 | enable=True, gradient_checkpointing_func=gradient_checkpointing_func 83 | ) 84 | else: 85 | raise NotImplementedError() 86 | 87 | if getattr(self, "_hf_peft_config_loaded", False): 88 | # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True 89 | # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 90 | # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate 91 | # the gradients to make sure the gradient flows. 92 | self.enable_input_require_grads() 93 | 94 | 95 | def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch(): 96 | transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = ( 97 | new_gradient_checkpointing_enable 98 | ) -------------------------------------------------------------------------------- /vall_e.cpp/include/quantizer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "ggml.h" 7 | #include "ggml-alloc.h" 8 | #include "ggml-backend.h" 9 | 10 | #include "utils.h" 11 | 12 | struct encodec_quant_block { 13 | struct ggml_tensor *embed; 14 | }; 15 | 16 | struct encodec_quantizer { 17 | std::vector blocks; 18 | }; 19 | 20 | struct ggml_tensor *encodec_forward_quantizer_encode( 21 | const struct encodec_quantizer *quantizer, struct ggml_context *ctx0, 22 | struct ggml_tensor *encoded_inp, const int n_bins, const int sr, const int bandwidth, 23 | const int hop_length) { 24 | 25 | if (!encoded_inp) { 26 | fprintf(stderr, "%s: null input tensor\n", __func__); 27 | return NULL; 28 | } 29 | 30 | const int frame_rate = (int)ceilf(sr / hop_length); 31 | const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth); 32 | 33 | const int seq_length = encoded_inp->ne[0]; 34 | 35 | struct ggml_tensor *codes = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, seq_length, n_q); 36 | ggml_set_input(codes); 37 | 38 | struct ggml_tensor *inpL = ggml_cont(ctx0, ggml_transpose(ctx0, encoded_inp)); 39 | struct ggml_tensor *residual = inpL; 40 | struct ggml_tensor *indices; 41 | 42 | for (int i = 0; i < n_q; i++) { 43 | encodec_quant_block block = quantizer->blocks[i]; 44 | 45 | // compute distance 46 | // [seq_length, n_bins] 47 | struct ggml_tensor *dp = ggml_scale( 48 | ctx0, ggml_mul_mat(ctx0, block.embed, residual), -2.0f); 49 | 50 | // [n_bins] 51 | struct ggml_tensor *sqr_embed = ggml_sqr(ctx0, block.embed); 52 | struct ggml_tensor *sqr_embed_nrm = ggml_sum_rows(ctx0, sqr_embed); 53 | 54 | // [seq_length] 55 | struct ggml_tensor *sqr_inp = ggml_sqr(ctx0, residual); 56 | struct ggml_tensor *sqr_inp_nrm = ggml_sum_rows(ctx0, sqr_inp); 57 | 58 | // [seq_length, n_bins] 59 | struct ggml_tensor *dist = ggml_add(ctx0, ggml_repeat(ctx0, sqr_inp_nrm, dp), dp); 60 | dist = ggml_add(ctx0, ggml_repeat(ctx0, ggml_transpose(ctx0, sqr_embed_nrm), dist), dist); 61 | dist = ggml_neg(ctx0, dist); 62 | 63 | // take the argmax over 
the column dimension 64 | // [seq_length] 65 | indices = ggml_argmax(ctx0, dist); 66 | 67 | // look up in embedding table 68 | struct ggml_tensor *quantized = ggml_get_rows(ctx0, block.embed, indices); 69 | 70 | residual = ggml_sub(ctx0, residual, quantized); 71 | 72 | codes = ggml_set_1d(ctx0, codes, indices, i * codes->nb[1]); 73 | } 74 | 75 | return codes; 76 | } 77 | 78 | struct ggml_tensor *encodec_forward_quantizer_decode( 79 | const struct encodec_quantizer *quantizer, struct ggml_context *ctx0, 80 | struct ggml_tensor *codes, const int hidden_dim, const int n_bins, const int sr, const int bandwidth, 81 | const int hop_length) { 82 | 83 | if (!codes) { 84 | fprintf(stderr, "%s: null input tensor\n", __func__); 85 | return NULL; 86 | } 87 | 88 | const int seq_length = codes->ne[0]; 89 | 90 | const int frame_rate = (int)ceilf(sr / hop_length); 91 | const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth); 92 | 93 | assert(n_q == codes->ne[1]); 94 | 95 | struct ggml_tensor *quantized_out = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length); 96 | ggml_set_input(quantized_out); 97 | ggml_set_name(quantized_out, "quantized_out"); 98 | 99 | for (int i = 0; i < n_q; i++) { 100 | encodec_quant_block block = quantizer->blocks[i]; 101 | 102 | struct ggml_tensor *indices = ggml_view_1d(ctx0, codes, seq_length, i * codes->nb[1]); 103 | struct ggml_tensor *quantized = ggml_get_rows(ctx0, block.embed, indices); 104 | 105 | quantized_out = ggml_add(ctx0, quantized_out, quantized); 106 | } 107 | 108 | quantized_out = ggml_cont(ctx0, ggml_transpose(ctx0, quantized_out)); 109 | 110 | return quantized_out; 111 | } 112 | -------------------------------------------------------------------------------- /vall_e/utils/io.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | 4 | from pathlib import Path 5 | from safetensors import safe_open as sft_load 6 | from safetensors.torch import save_file as sft_save 7 | 8 | try: 9 | use_orjson = True 10 | import orjson as json 11 | except: 12 | import json 13 | 14 | from .utils import truncate_json 15 | 16 | def json_stringify( data, truncate=False, pretty=False, raw=False ): 17 | if truncate: 18 | return truncate_json( json.dumps( data ) ) 19 | if pretty: 20 | s = json.dumps( data, option=json.OPT_INDENT_2 ) if use_orjson else json.dumps( data, indent='\t' ) 21 | return s if raw and use_orjson else s.decode('utf-8') 22 | return json.dumps( data ) 23 | 24 | def json_parse( string ): 25 | return json.loads( string ) 26 | 27 | def json_read( path, default=None ): 28 | path = coerce_path( path ) 29 | 30 | if not path.exists(): 31 | return default 32 | 33 | with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" ) ) as f: 34 | return json_parse( f.read() ) 35 | 36 | def json_write( data, path, **kwargs ): 37 | path = coerce_path( path ) 38 | 39 | with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" ) ) as f: 40 | f.write( json_stringify( data, raw=use_orjson, **kwargs ) ) 41 | 42 | def coerce_path( path ): 43 | return path if isinstance( path, Path ) else Path(path) 44 | 45 | def pick_path( path, *suffixes ): 46 | suffixes = [*suffixes] 47 | 48 | for suffix in suffixes: 49 | p = path.with_suffix( suffix ) 50 | if p.exists(): 51 | return p 52 | 53 | return path 54 | 55 | def is_dict_of( d, t ): 56 | if not isinstance( d, dict ): 57 | return False 58 | 59 | return all([ isinstance(v, torch.Tensor) for v 
in d.values() ]) 60 | 61 | # handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors 62 | def state_dict_to_tensor_metadata( data: dict, module_key=None ): 63 | metadata = {} 64 | 65 | # is a state_dict, no need to coerce 66 | if is_dict_of( data, torch.Tensor ): 67 | return data, metadata 68 | 69 | # is maybe a dict with a state dict + metadata, coerce it 70 | target = module_key 71 | if not target: 72 | for k, v in data.items(): 73 | # is a dict of tensors, our target 74 | if is_dict_of( v, torch.Tensor ): 75 | target = k 76 | continue # continue to iterate to grab other metadata 77 | 78 | # not a dict of tensors, put it as metadata 79 | try: 80 | metadata[k] = json_stringify(v) if any([isinstance( v, dict ), isinstance( v, list )]) else v 81 | 82 | if isinstance( metadata[k], bytes ): 83 | metadata[k] = metadata[k].decode('utf-8') 84 | except Exception as e: 85 | pass 86 | 87 | if not target: 88 | raise Exception(f'Requesting to save safetensors of a state dict, but state dict contains no key of torch.Tensor: {path}') 89 | 90 | return data[target], metadata 91 | 92 | def torch_save( data, path, module_key=None ): 93 | path = coerce_path(path) 94 | ext = path.suffix 95 | 96 | if ext in [".safetensor", ".safetensors", ".sft"]: 97 | data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key ) 98 | 99 | if metadata is None: 100 | metadata = {} 101 | 102 | return sft_save( data, path, { k: v for k, v in metadata.items() if v is not None } ) 103 | 104 | return torch.save( data, path ) 105 | 106 | def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module" ): 107 | path = coerce_path(path) 108 | ext = path.suffix 109 | 110 | if ext in [".safetensor", ".safetensors", ".sft"]: 111 | state_dict = {} 112 | with sft_load(path, framework=framework, device=device) as f: 113 | for k in f.keys(): 114 | state_dict[k] = f.get_tensor(k) 115 | 116 | if load_metadata: 117 | metadata = f.metadata() 118 | for k, v in metadata.items(): 119 | try: 120 | metadata[k] = json.loads( v ) 121 | except Exception as e: 122 | pass 123 | state_dict = { module_key: state_dict } | metadata 124 | 125 | return state_dict 126 | 127 | return torch.load( path, map_location=torch.device(device), weights_only=not unsafe ) -------------------------------------------------------------------------------- /vall_e/plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import json 5 | import time 6 | import re 7 | from pathlib import Path 8 | 9 | import matplotlib.pyplot as plt 10 | import pandas as pd 11 | 12 | from .config import cfg 13 | 14 | def plot(paths, args): 15 | dfs = [] 16 | 17 | for path in paths: 18 | with open(path, "r") as f: 19 | text = f.read() 20 | 21 | rows = [] 22 | 23 | pattern = r"(\{.+?\})\.\n" 24 | 25 | for row in re.findall(pattern, text, re.DOTALL): 26 | try: 27 | row = json.loads(row) 28 | except Exception as e: 29 | continue 30 | 31 | for model in args.models: 32 | if f'{model.name}.{args.xs}' not in row: 33 | continue 34 | rows.append(row) 35 | break 36 | 37 | df = pd.DataFrame(rows) 38 | 39 | if "name" in df: 40 | df["name"] = df["name"].fillna("train") 41 | else: 42 | df["name"] = "train" 43 | 44 | df["group"] = str(path.parents[args.group_level]) 45 | df["group"] = df["group"] + "/" + df["name"] 46 | 47 | dfs.append(df) 48 | 49 | df = pd.concat(dfs) 50 | 51 | if args.min_x is not None: 52 
| for model in args.models: 53 | df = df[args.min_x < df[f'{model.name}.{args.xs}']] 54 | 55 | if args.max_x is not None: 56 | for model in args.models: 57 | df = df[df[f'{model.name}.{args.xs}'] < args.max_x] 58 | 59 | for gtag, gdf in sorted( 60 | df.groupby("group"), 61 | key=lambda p: (p[0].split("/")[-1], p[0]), 62 | ): 63 | for model in args.models: 64 | x = f'{model.name}.{args.xs}' 65 | for ys in args.ys: 66 | y = f'{model.name}.{ys}' 67 | 68 | if gdf[y].isna().all(): 69 | continue 70 | 71 | if args.min_y is not None: 72 | gdf = gdf[args.min_y < gdf[y]] 73 | if args.max_y is not None: 74 | gdf = gdf[gdf[y] < args.max_y] 75 | 76 | if args.ewm: 77 | gdf[y] = gdf[y].ewm(args.ewm).mean() 78 | elif args.rolling: 79 | gdf[y] = gdf[y].rolling(args.rolling).mean() 80 | 81 | gdf.plot( 82 | x=x, 83 | y=y, 84 | label=f"{y}", 85 | ax=plt.gca(), 86 | marker="x" if len(gdf) < 100 else None, 87 | alpha=0.7, 88 | ) 89 | 90 | plt.gca().legend( 91 | #loc="center left", 92 | fancybox=True, 93 | shadow=True, 94 | #bbox_to_anchor=(1.04, 0.5), 95 | ) 96 | 97 | def plot_sample_metrics( metrics, filename=None ): 98 | """ 99 | fig = plt.figure() 100 | fig.set_figwidth( 16 * len(metrics) // cfg.dataset.frames_per_second ) 101 | """ 102 | 103 | data = { key: [ e[0][key] for e in metrics ] for key in metrics[0][0].keys() } 104 | 105 | df = pd.DataFrame(data) 106 | df.plot() 107 | 108 | plt.gca().legend( 109 | #loc="center left", 110 | fancybox=True, 111 | shadow=True, 112 | #bbox_to_anchor=(1.04, 0.5), 113 | ) 114 | 115 | if not filename: 116 | filename = f'{time.time()}.png' 117 | 118 | out_path = cfg.rel_path / "metrics" / filename 119 | out_path.parent.mkdir(parents=True, exist_ok=True) 120 | plt.savefig(out_path, bbox_inches="tight") 121 | 122 | if __name__ == "__main__": 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument("--xs", default="engine_step") 125 | parser.add_argument("--ys", nargs="+", default="") 126 | parser.add_argument("--model", nargs="+", default="*") 127 | 128 | parser.add_argument("--min-x", type=float, default=-float("inf")) 129 | parser.add_argument("--min-y", type=float, default=-float("inf")) 130 | parser.add_argument("--max-x", type=float, default=float("inf")) 131 | parser.add_argument("--max-y", type=float, default=float("inf")) 132 | 133 | parser.add_argument("--ewm", type=int, default=1024) 134 | parser.add_argument("--rolling", type=int, default=None) 135 | 136 | parser.add_argument("--size", type=str, default=None) 137 | 138 | parser.add_argument("--filename", default="log.txt") 139 | parser.add_argument("--group-level", default=1) 140 | args, unknown = parser.parse_known_args() 141 | 142 | path = cfg.rel_path / "logs" 143 | paths = path.rglob(f"./*/{args.filename}") 144 | 145 | args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ] 146 | 147 | if args.ys == "": 148 | args.ys = ["loss.nll"] 149 | 150 | if args.size: 151 | width, height = args.size.split("x") 152 | plt.figure(figsize=(int(width), int(height))) 153 | 154 | plot(paths, args) 155 | 156 | out_path = cfg.rel_path / "metrics.png" 157 | plt.savefig(out_path, bbox_inches="tight") -------------------------------------------------------------------------------- /vall_e.cpp/include/llama-vocab.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "llama.h" 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | struct LLM_KV; 10 | struct llama_model_loader; 11 | 12 | struct 
llama_vocab { 13 | struct token_data { 14 | std::string text; 15 | float score; 16 | llama_token_attr attr; 17 | }; 18 | 19 | llama_vocab(); 20 | ~llama_vocab(); 21 | 22 | void load(llama_model_loader & ml, const LLM_KV & kv); 23 | 24 | enum llama_vocab_type get_type() const; 25 | enum llama_vocab_pre_type get_pre_type() const; 26 | 27 | uint32_t n_tokens() const; 28 | uint32_t n_token_types() const; 29 | 30 | std::string type_name() const; 31 | 32 | bool is_normal (llama_token id) const; 33 | bool is_unknown (llama_token id) const; 34 | bool is_control (llama_token id) const; 35 | bool is_byte (llama_token id) const; 36 | bool is_user_defined(llama_token id) const; 37 | bool is_unused (llama_token id) const; 38 | bool is_eog (llama_token id) const; 39 | 40 | uint8_t token_to_byte(llama_token id) const; 41 | llama_token byte_to_token(uint8_t ch) const; 42 | 43 | llama_token text_to_token(const std::string & text) const; 44 | 45 | const token_data & get_token_data(llama_token id) const; 46 | 47 | const char * token_get_text (llama_token id) const; 48 | float token_get_score(llama_token id) const; 49 | llama_token_attr token_get_attr (llama_token id) const; 50 | 51 | llama_token token_bos() const; 52 | llama_token token_eos() const; 53 | llama_token token_eot() const; 54 | llama_token token_eom() const; 55 | llama_token token_unk() const; 56 | llama_token token_sep() const; 57 | llama_token token_nl () const; 58 | llama_token token_pad() const; 59 | 60 | llama_token token_prefix() const; 61 | llama_token token_middle() const; 62 | llama_token token_suffix() const; 63 | 64 | llama_token token_fim_pre() const; 65 | llama_token token_fim_suf() const; 66 | llama_token token_fim_mid() const; 67 | llama_token token_fim_pad() const; 68 | llama_token token_fim_rep() const; 69 | llama_token token_fim_sep() const; 70 | 71 | bool get_add_space_prefix () const; 72 | bool get_add_bos () const; 73 | bool get_add_eos () const; 74 | bool get_ignore_merges () const; 75 | bool get_clean_spaces () const; 76 | bool get_remove_extra_whitespaces () const; 77 | bool get_escape_whitespaces () const; 78 | bool get_treat_whitespace_as_suffix() const; 79 | 80 | int max_token_len() const; 81 | 82 | int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; 83 | 84 | int32_t tokenize( 85 | const char * text, 86 | int32_t text_len, 87 | llama_token * tokens, 88 | int32_t n_tokens_max, 89 | bool add_special, 90 | bool parse_special) const; 91 | 92 | std::vector tokenize( 93 | const std::string & raw_text, 94 | bool add_special, 95 | bool parse_special = false) const; 96 | 97 | // does not write null-terminator to buf 98 | int32_t token_to_piece( 99 | llama_token token, 100 | char * buf, 101 | int32_t length, 102 | int32_t lstrip, 103 | bool special) const; 104 | 105 | // use cached data 106 | const std::string & token_to_piece(llama_token token) const; 107 | 108 | int32_t detokenize( 109 | const llama_token * tokens, 110 | int32_t n_tokens, 111 | char * text, 112 | int32_t text_len_max, 113 | bool remove_special, 114 | bool unparse_special) const; 115 | 116 | std::string detokenize( 117 | const std::vector & tokens, 118 | bool special) const; 119 | 120 | void print_info() const; 121 | 122 | private: 123 | struct impl; 124 | std::unique_ptr pimpl; 125 | }; 126 | -------------------------------------------------------------------------------- /vall_e/emb/g2p.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import 
string 4 | import torch 5 | 6 | from functools import cache 7 | from pathlib import Path 8 | from phonemizer import phonemize 9 | from phonemizer.backend import BACKENDS 10 | 11 | from tqdm import tqdm 12 | 13 | try: 14 | import pykakasi 15 | except Exception as e: 16 | pykakasi = None 17 | print(f'Error while importing pykakasi: {str(e)}') 18 | pass 19 | 20 | try: 21 | import langdetect 22 | except Exception as e: 23 | langdetect = None 24 | print(f'Error while importing langdetect: {str(e)}') 25 | 26 | def detect_language( text ): 27 | if not text: 28 | return "en" # to-do: map to a null language 29 | 30 | if langdetect is None: 31 | raise Exception('langdetect is not installed.') 32 | return langdetect.detect( text ) 33 | 34 | def _get_graphs(path): 35 | with open(path, "r") as f: 36 | graphs = f.read() 37 | return graphs 38 | 39 | def coerce_to_hiragana( runes, sep="" ): 40 | if pykakasi is None: 41 | raise Exception('pykakasi is not installed.') 42 | 43 | kks = pykakasi.kakasi() 44 | result = kks.convert( runes ) 45 | return sep.join([ res['hira'] for res in result ]) 46 | 47 | def coerce_language( lang ): 48 | # bottle of water vs bo'oh'o'wa'er 49 | if lang == "en": 50 | lang = "en-us" 51 | # quebec probably 52 | if lang == "fr": 53 | return "fr-fr" 54 | # phonemizer/espeak used to have zh refer to mandarin, but was renamed to cmn 55 | # cmn outputs cringe, but not cmn-latn-pinyin 56 | # also just coerces any of the dialects into this (to avoid crimes) 57 | if lang[:2] == "zh": 58 | return "cmn-latn-pinyin" 59 | """ 60 | things to consider in the future 61 | en-uk or en-gb 62 | es-la vs es-es 63 | pt-br vs pt-pt 64 | """ 65 | return lang 66 | 67 | cached_backends = {} 68 | def _get_backend( language="en-us", backend="espeak", punctuation=True, stress=True, strip=True ): 69 | key = f'{language}_{backend}' 70 | if key in cached_backends: 71 | return cached_backends[key] 72 | 73 | if backend == 'espeak': 74 | phonemizer = BACKENDS[backend]( language, preserve_punctuation=punctuation, with_stress=stress) 75 | elif backend == 'espeak-mbrola': 76 | phonemizer = BACKENDS[backend]( language ) 77 | else: 78 | phonemizer = BACKENDS[backend]( language, preserve_punctuation=punctuation ) 79 | 80 | cached_backends[key] = phonemizer 81 | return phonemizer 82 | 83 | 84 | def encode(text: str, language="auto", backend="auto", punctuation=True, stress=True, strip=True) -> list[str]: 85 | if language == "auto": 86 | language = detect_language( text ) 87 | 88 | language = coerce_language( language ) 89 | 90 | # 91 | if backend == "auto": 92 | # Convert to hiragana, as espeak does not like kanji 93 | if language[:2] == "ja": 94 | text = coerce_to_hiragana( text ) 95 | 96 | # "zh" => "cmn-latn-pinyin" 97 | elif language == "zh": 98 | language = "cmn-latn-pinyin" 99 | 100 | 101 | if not backend or backend == "auto": 102 | backend = "espeak" # if language[:2] != "en" else "festival" 103 | 104 | backend = _get_backend(language=language, backend=backend, stress=stress, strip=strip, punctuation=punctuation) 105 | if backend is not None: 106 | phonemes = backend.phonemize( [ text ], strip=strip ) 107 | else: 108 | phonemes = phonemize( [ text ], language=language, strip=strip, preserve_punctuation=punctuation, with_stress=stress ) 109 | 110 | if not len(phonemes): 111 | raise Exception(f"Failed to phonemize, received empty string: {text}") 112 | 113 | phonemes = phonemes[0] 114 | 115 | # remap tones 116 | # technically they can be kept in place and just update the tokenizer, but this would be a bit confusing 
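	# e.g. with the table below, a tone digit such as "3" in the phonemizer output is rewritten to "ˊ", so tones survive as dedicated symbols rather than numerals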
117 | if language == "cmn-latn-pinyin": 118 | tones = { 119 | "1": "ˇ", 120 | "2": "ˉ", 121 | "3": "ˊ", 122 | "4": "ˋ", 123 | "5": "_", 124 | } 125 | for k, v in tones.items(): 126 | phonemes = phonemes.replace(k, v) 127 | 128 | return phonemes 129 | 130 | # Helper function to debug phonemizer 131 | if __name__ == "__main__": 132 | parser = argparse.ArgumentParser() 133 | 134 | parser.add_argument("string", type=str) 135 | parser.add_argument("--language", type=str, default="en-us") 136 | parser.add_argument("--backend", type=str, default="auto") 137 | parser.add_argument("--no-punctuation", action="store_true") 138 | parser.add_argument("--no-stress", action="store_true") 139 | parser.add_argument("--no-strip", action="store_true") 140 | 141 | args = parser.parse_args() 142 | 143 | phonemes = encode( args.string, language=args.language, backend=args.backend, punctuation=not args.no_punctuation, stress=not args.no_stress, strip=not args.no_strip ) 144 | print( phonemes ) -------------------------------------------------------------------------------- /vall_e.cpp/include/ggml-cann.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2023-2024 The ggml authors 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a copy 5 | * of this software and associated documentation files (the "Software"), to 6 | * deal in the Software without restriction, including without limitation the 7 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 8 | * sell copies of the Software, and to permit persons to whom the Software is 9 | * furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 20 | * IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include "ggml-backend.h" 26 | #include "ggml.h" 27 | 28 | #ifdef __cplusplus 29 | extern "C" { 30 | #endif 31 | 32 | /** 33 | * @brief Maximum number of CANN devices supported. 34 | */ 35 | #define GGML_CANN_MAX_DEVICES 16 36 | 37 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cann_reg(void); 38 | 39 | /** 40 | * @brief Initializes the CANN backend for a specified device. 41 | * 42 | * This function initializes the CANN backend for the given device. 43 | * It verifies the device index, allocates a context, and creates a backend 44 | * instance. 45 | * 46 | * @param device The index of the device to initialize. 47 | * @return A pointer to the initialized backend instance, or nullptr on failure. 48 | */ 49 | GGML_BACKEND_API ggml_backend_t ggml_backend_cann_init(int32_t device); 50 | 51 | /** 52 | * @brief Checks if a given backend is a CANN backend. 53 | * 54 | * This function verifies if the provided backend is a CANN backend by comparing 55 | * its GUID with the CANN backend's GUID. 56 | * 57 | * @param backend The backend instance to check. 58 | * @return True if the backend is a CANN backend, false otherwise. 
59 | */ 60 | GGML_BACKEND_API bool ggml_backend_is_cann(ggml_backend_t backend); 61 | 62 | /** 63 | * @brief Retrieves the CANN buffer type for a specified device. 64 | * 65 | * This function initializes and returns the buffer type interface associated 66 | * with the given device. It ensures thread-safe access using a mutex. 67 | * 68 | * @param device The device index for which to retrieve the buffer type. 69 | * @return A pointer to the buffer type interface for the specified device, or 70 | * nullptr if the device index is out of range. 71 | */ 72 | GGML_BACKEND_API ggml_backend_buffer_type_t 73 | ggml_backend_cann_buffer_type(int32_t device); 74 | 75 | /** 76 | * @brief Retrieves the number of CANN devices available. 77 | * 78 | * This function returns the number of CANN devices available based on 79 | * information obtained from `ggml_cann_info()`. 80 | * 81 | * @return The number of CANN devices available. 82 | */ 83 | GGML_BACKEND_API int32_t ggml_backend_cann_get_device_count(void); 84 | 85 | /** 86 | * @brief pinned host buffer for use with the CPU backend for faster copies between CPU and NPU. 87 | * 88 | * @return A pointer to the host buffer type interface. 89 | */ 90 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void); 91 | 92 | /** 93 | * @brief Retrieves the description of a specific CANN device. 94 | * 95 | * This function sets the specified device, retrieves the SoC name, 96 | * and writes it into the provided description buffer. 97 | * 98 | * @param device The device index to retrieve the description for. 99 | * @param description Pointer to a buffer where the description will be written. 100 | * @param description_size Size of the description buffer. 101 | */ 102 | GGML_BACKEND_API void ggml_backend_cann_get_device_description( 103 | int32_t device, char* description, size_t description_size); 104 | 105 | /** 106 | * @brief Retrieves the memory information of a specific CANN device. 107 | * 108 | * This function sets the specified device, retrieves the free and total 109 | * memory information of the specified type (ACL_HBM_MEM), and stores them 110 | * in the provided pointers. 111 | * 112 | * @param device The device index to retrieve memory information for. 113 | * @param free Pointer to a variable where the free memory size will be stored. 114 | * @param total Pointer to a variable where the total memory size will be 115 | * stored. 116 | */ 117 | GGML_BACKEND_API void ggml_backend_cann_get_device_memory(int32_t device, 118 | size_t* free, 119 | size_t* total); 120 | 121 | #ifdef __cplusplus 122 | } 123 | #endif 124 | -------------------------------------------------------------------------------- /docs/inferenece.md: -------------------------------------------------------------------------------- 1 | # `inference.py` 2 | 3 | This script handles everything the higher level functions of inferencing the model for various tasks for the end user. 4 | 5 | For invoking this model in another Python package, refer to `webui.py` and `demo.py` on how to use this outside of this scope. 6 | 7 | `__main__.py` invokes this according to the below arguments. 8 | 9 | ## Synthesis 10 | 11 | To synthesize speech: `python -m vall_e --yaml=` (or `--model=`) 12 | 13 | Some additional flags you can pass are: 14 | * `--language`: specifies the language for guiding guide inferencing when the model is trained against that language. Use `auto` to automatically deduce this. 15 | * `--text-language`: the language to phonemize the input text under. 
Leave blank to tie it to the above value. 16 | * `--task`: task to perform. Defaults to `tts`, but accepts `stt` for transcriptions. 17 | * `--max-duration`: maximum token-duration for inferencing through the AR aspect of the model. Every second corresponds to 75 steps. 18 | * `--max-steps`: maximum steps for inferencing through the NAR-len aspect of the model. 19 | * `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`) 20 | * `--ar-temperature`: sampling temperature to use for the AR/NAR pass. 0 enables greedy sampling. 21 | * For the AR, ~1.0 is *fine*, but lowering the temperature adheres better to the prosody of the input prompt. 22 | * For the AR, low temperatures require a repetition penalty to prevent outputs from degenerating. 23 | * For the NAR, greedy sampling is best, but can be raised to 0.2. 24 | * `--input-prompt-length`: the duration of the input prompt (~6 seconds is fine, longer durations lead to slower generations for "better" accuracy). 0 does not repeat/trim. 25 | * If a prompt is shorter than the given duration, it's repeated to the duration size. 26 | 27 | And some experimental sampling flags you can use too (your mileage will ***definitely*** vary, but most of these are bandaids for a bad AR): 28 | * `--input-prompt-prefix`: (AR only) treats the input prompt as the initial response prefix, but... 29 | * the transcription of the prompt needs to be in the input text prompt. 30 | * doesn't perform all that well (I belive the model needs to be trained a bit on this, as `tts-c`). 31 | * `--min-temperature`: triggers the dynamic temperature pathway, adjusting the temperature based on the confidence of the best token. Acceptable values are between `[0.0, (n)ar-temperature)`. 32 | + This simply uplifts the [original implementation](https://github.com/kalomaze/koboldcpp/blob/dynamic-temp/llama.cpp#L5132) to perform it. 33 | * `--top-p`: limits the sampling pool to top sum of values that equal `P`% probability in the probability distribution. 34 | * `--top-k`: limits the sampling pool to the top `K` values in the probability distribution. 35 | * `--min-p`: only logits above `P`% probability are considered for sampling (or something, I'm still unsure how this differs from top-p). 36 | * `--repetition-penalty`: modifies the probability of tokens if they have appeared before. In the context of audio generation, this is a very iffy parameter to use. 37 | * `--repetition-penalty-decay`: modifies the above factor applied to scale based on how far away it is in the past sequence. 38 | * `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finnicky due to the AR already being well correlated with the length. 39 | * `--beam-width`: (AR only) specifies the number of branches to search through for beam sampling. 40 | + This is a very naive implementation that's effectively just greedy sampling across `B` spaces. 41 | * `--mirostat-tau`: (AR only) the "surprise value" when performing mirostat sampling. 42 | + This simply uplifts the [original implementation](https://github.com/basusourya/mirostat/blob/master/mirostat.py) to perform it. 43 | + **!**NOTE**!**: This is incompatible with beam search sampling (for the meantime at least). 44 | * `--mirostat-eta`: (AR only) the "learning rate" during mirostat sampling applied to the maximum surprise. 45 | * `--dry-multiplier`: (AR only) performs DRY sampling, the scalar factor. 
46 | * `--dry-base`: (AR only) for DRY sampling, the base of the exponent factor. 47 | * `--dry-allowed-length`: (AR only) for DRY sampling, the window to perform DRY sampling within. 48 | 49 | Some arguments are able to be prefixed with `ar-` and `nar-` to only use that setting for its respective pass. At the moment through the CLI, this includes: 50 | * `temperature` 51 | 52 | ### Speech-to-Text 53 | 54 | The `ar+nar-tts+stt-llama-8` (now the reference model) model has received additional training for a speech-to-text task against EnCodec-encoded audio. 55 | 56 | Currently, the model only transcribes back into the IPA phonemes it was trained against, as an additional model or external program is required to translate the IPA phonemes back into text. 57 | * this does make a model that can phonemize text, and unphonemize text, more desirable in the future to replace espeak (having an additional task to handle this requires additional embeddings, output heads, and possible harm to the model as actual text is not a modality the model is trained on). 58 | * it seems to really want to only transcribe the first sentence for a given utterance. I imagine this is simply a problem with how it was trained. -------------------------------------------------------------------------------- /docs/samplers.md: -------------------------------------------------------------------------------- 1 | # `sampler.py` 2 | 3 | This script contains all the samplers used during inferencing. 4 | 5 | While I do expose these samplers for end-user use, I don't like to rely on these, as exotic samplers are always bandaids to the underlying model. 6 | 7 | Most of these sampler functions do what's written on the tin, but for clarity: 8 | 9 | ## Samplers 10 | 11 | When sampling, the output logits are picked for sampling according to the current inference mode. For the AR, only the last token (or last `causal_size` tokens) are used for sampling, while the NAR relies on the previous sequence to determine how many tokens to sample in parallel. 12 | 13 | As the model is trained more, low temperatures are preferred over high temperatures for the AR, while greedy sampling is almost always preferred for the NAR. 14 | 15 | Greedy sampling is enabled when the sampling temperature is <= 0, where the most likely token is picked. 16 | 17 | ### Repetition Penalty 18 | 19 | This function (`reptition_penalize`) applies a penalty to target logits to avoid repetitive output. 20 | 21 | This is implemented by penalizing tokens in the future from repeating the currently iterated token. 22 | * This distinction is required to penalize for the NAR, while the AR only penalizes the single token being inferenced. 23 | 24 | An optional value can also be passed to factor in how far away that token is. 25 | 26 | Implicitly, this is only limited to 75 tokens in the past (one second of audio under EnCodec), and will apply more than once. 27 | 28 | For low temperatures, this is almost necessary, as no rep-pen will have the output be garbled or a mess, and very low rep-pen will have unstable output. 29 | 30 | ### Length Penalty 31 | 32 | This function (`length_penalize`) applies a penalty to the audio stop token (or any other specific token) based on the current length of the sequence. 33 | 34 | This can be either a negative or a positive, to restrain or inhibit the stop token from appearing. 35 | 36 | ### Ban Tokens 37 | 38 | This function (`ban_tokens`) bans a token from appearing. 39 | 40 | Since this is an audio LM, there's no useful purpose for this. 
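For reference, the mechanism is just masking logits before sampling; a minimal sketch of the idea (assuming logits shaped `[seq_len, vocab]` — not the exact implementation used here):

```python
import torch

def ban_tokens(logits: torch.Tensor, ids: list[int]) -> torch.Tensor:
	# banned ids get -inf, so softmax assigns them zero probability and they can never be sampled
	logits[..., ids] = -float("inf")
	return logits
```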
41 | 42 | However, for some models, this is useful for banning the stop token used for the AR, when sampling output from the NAR, if the classifier / LM head / output projections are shared between the two. 43 | 44 | ### Top-K / Top-P 45 | 46 | This function (`top_k_top_p_filtering`) filters the logits to only allow the top-K probability of tokens to be sampled, and/or the top-P probable tokens to be sampled. 47 | 48 | This may be helpful with higher temperatured sampling to offer some variety, but not allow outputs to be *too* chaotic, in theory. 49 | 50 | ### Min-P 51 | 52 | This function (`min_p_filtering`) filters out tokens that are under the min-P% probability. 53 | 54 | ### Dynamic Temperature 55 | 56 | This function (`dynamic_temperature`) implements an early version of dynamic temperature per [this external PR](https://github.com/LostRuins/koboldcpp/pull/464). 57 | 58 | To reiterate, this is an early implementation, as I recall it changing after I have already implemented this. 59 | 60 | In theory, this allows the model to sample under higher temperatures when able, but I still need to test this the more the model receives training. 61 | 62 | ### Mirostat 63 | 64 | This function (`mirostat_sample`) implements mirostat sampling. From what I understand, this modifies the logits based on "surprise" factor. 65 | 66 | This may be an early implementation, as this was implemented a while back. 67 | 68 | This *sometimes* helps the output a lot for some states of the model, but I don't try to rely on this too much. 69 | 70 | ### DRY Sampling 71 | 72 | This function (`dry_sampling`) implements DRY sampling, a replacement to naive repetition penalizing. 73 | 74 | I'm still not too sure what's so good about it, since it just seems like rep-pen with a different coat of paint, and for audio it doesn't seem to be too helpful? 75 | 76 | ### Entropix 77 | 78 | This function (`sample_entropix`) implements entropix sampling, a sampler that aids in Chain-of-Thought for text LLMs by adjusting sampling parameters according to the logits and attentions' entropy and varentropy. 79 | 80 | The huge caveat is that this requires tuning the parameters and thresholds per model, and in testing it doesn't seem like the metrics are consistent enough to rely on this. Acquiring the right attention scores is pretty much a dark art in its own right, as it does not map perfectly to naive attention, much less any other attention mechanism, under `transformers`' LLaMA. 81 | 82 | Additionally, one state requires injecting a CoT token, which doesn't have an analog in the audio domain. 83 | 84 | However, this does seem to serve as a good basis to expand upon this and sample according to the entropy/varentropy of the model's current state. 85 | 86 | ### Classifier-Free Guidance 87 | 88 | While this isn't a direct sampler type used, a helper function is provided to perform classifier-free guidance, given a positive (the primary) logits, and a negative (the null) logits. While the `NAR-len` modality requires this at the moment, it can easily be adapted everything else. 89 | 90 | Rescaling is also applied to avoid clipping the logits. 91 | 92 | Due to the logits being the full sequence, and the input lengths differing, a list of lengths are required to be passed to only modify the last N logits. 
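
As a rough illustration (a sketch only — the names and the exact rescaling formulation are assumptions, not this repo's helper verbatim):

```python
import torch

def guided_logits( positive: list[torch.Tensor], negative: list[torch.Tensor], lengths: list[int], strength: float = 3.0, rescale: float = 0.75 ) -> list[torch.Tensor]:
	guided = []
	for pos, neg, length in zip(positive, negative, lengths):
		# only the last `length` positions of each sequence are modified
		pos_t, neg_t = pos[-length:], neg[-length:]
		# standard classifier-free guidance: push the conditional logits away from the null logits
		cfg = neg_t + (pos_t - neg_t) * strength
		# rescale toward the conditional logits' statistics to avoid clipping, then interpolate back
		rescaled = cfg * (pos_t.std() / cfg.std())
		guided.append(rescale * rescaled + (1.0 - rescale) * cfg)
	return guided
```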
-------------------------------------------------------------------------------- /vall_e/engines/deepspeed.py: -------------------------------------------------------------------------------- 1 | """ 2 | # https://github.com/enhuiz/pytorch-training-utilities 3 | """ 4 | 5 | # to-do: replace this 6 | # to-do: swap out deepspeed 7 | 8 | from ..config import cfg 9 | from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device 10 | 11 | import logging 12 | import time 13 | import torch 14 | import torch.distributed 15 | 16 | from torch import Tensor 17 | from torch.distributed import all_reduce 18 | from typing import Any, Protocol 19 | 20 | from .base import TrainFeeder 21 | 22 | _logger = logging.getLogger(__name__) 23 | 24 | from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed as init_deepspeed_dist 25 | from deepspeed.accelerator import get_accelerator 26 | 27 | from ..utils.distributed import init_distributed, distributed_initialized 28 | from ..utils import ml 29 | 30 | from ..models.lora import freeze_non_lora_weights 31 | 32 | if not distributed_initialized() and cfg.trainer.backend == "deepspeed": 33 | init_distributed(init_deepspeed_dist) 34 | 35 | class Engine(DeepSpeedEngine): 36 | def __init__(self, *args, **kwargs): 37 | self.hyper_config = kwargs.pop('hyper_config', None) 38 | 39 | kwargs['config'] = cfg.trainer.deepspeed.ds_cfg 40 | kwargs['config_class'] = DeepSpeedConfig(kwargs['config']) 41 | 42 | stats = { 43 | "global_step": 0, 44 | "micro_step": 0, 45 | "global_samples": 0, 46 | "tokens_processed": 0, 47 | } 48 | 49 | # kwargs['stats'] = None will return None when popped 50 | maybe_stats = kwargs.pop('stats', stats) 51 | if maybe_stats is not None: 52 | stats = maybe_stats 53 | 54 | super().__init__(None, *args, **kwargs) 55 | 56 | self.global_steps = stats["global_step"] 57 | self.micro_steps = stats["micro_step"] 58 | self.global_samples = stats["global_samples"] 59 | self.tokens_processed = stats["tokens_processed"] 60 | 61 | self._frozen_params = set() 62 | self.current_batch_size = 0 63 | 64 | def freeze(self, freeze_all=True): 65 | # freeze non-LoRA params if requested 66 | if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None: 67 | frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings ) 68 | for param in frozen_params: 69 | self._frozen_params.add( param ) 70 | 71 | return 72 | 73 | if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"): 74 | raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None") 75 | 76 | for name, param in self.module.named_parameters(): 77 | if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params): 78 | param.requires_grad_(False) 79 | self._frozen_params.add(param) 80 | 81 | def unfreeze(self): 82 | for param in self._frozen_params: 83 | param.requires_grad_(True) 84 | self._frozen_params.clear() 85 | 86 | @property 87 | def _training(self): 88 | return self.hyper_config.training 89 | 90 | @property 91 | def _teacher(self): 92 | return self.hyper_config.teacher 93 | 94 | @property 95 | def global_step(self): 96 | return self.global_steps 97 | 98 | @property 99 | def micro_step(self): 100 | return self.micro_steps 101 | 102 | @property 103 | def batch_size(self): 104 | return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size 105 | 106 | def gather_attribute(self, *args, **kwargs): 107 | 
return gather_attribute(self.module, *args, **kwargs) 108 | 109 | def dispatch_attribute(self, *args, **kwargs): 110 | return dispatch_attribute(self.module, *args, **kwargs) 111 | 112 | def set_lr(self, lr): 113 | try: 114 | if hasattr(self.optimizer, 'param_groups'): 115 | for param_group in self.optimizer.param_groups: 116 | param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr 117 | else: 118 | self.optimizer.set_lr(lr) 119 | except Exception as e: 120 | _logger.warning(str(e)) 121 | 122 | # cur_scale, because _get_loss_scale has a typo in the def and I can't be assed to inject a fix into it or push a PR 123 | def get_loss_scale(self): 124 | if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None: 125 | return 1.0 126 | 127 | return self.optimizer.cur_scale 128 | 129 | def set_loss_scale(self, value): 130 | if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None: 131 | return 132 | 133 | self.optimizer.cur_scale = value 134 | 135 | # we'll just have to live with the LoRA weights living within our main weights 136 | # they're easy to extract anyways 137 | def load_checkpoint(self, load_dir, **kwargs ): 138 | # override to load the lora instead 139 | if cfg.lora is not None: 140 | load_dir = cfg.ckpt_dir / cfg.lora.full_name 141 | 142 | return super().load_checkpoint( load_dir, **kwargs ) 143 | 144 | def save_checkpoint(self, save_dir, **kwargs ): 145 | # override to save the lora instead 146 | if cfg.lora is not None: 147 | save_dir = cfg.ckpt_dir / cfg.lora.full_name 148 | 149 | return super().save_checkpoint( save_dir, **kwargs ) 150 | 151 | def traverse(self, *args, **kwargs): 152 | with ml.autocast(): 153 | self.forward(*args, **kwargs) 154 | 155 | losses = self.gather_attribute("loss") 156 | loss = torch.stack([*losses.values()]).sum() 157 | 158 | stats = {} 159 | stats |= {k: v.item() for k, v in losses.items()} 160 | stats |= self.gather_attribute("scalar") 161 | 162 | """ 163 | if torch.isnan(loss).any(): 164 | self.max_nan_losses = self.max_nan_losses - 1 165 | if self.max_nan_losses < 0: 166 | raise RuntimeError("Too many NaN losses detected.") 167 | 168 | return stats 169 | """ 170 | 171 | self.backward(loss) 172 | self.step() 173 | 174 | return stats -------------------------------------------------------------------------------- /vall_e/emb/codecs/dac.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from dac import DACFile 4 | from audiotools import AudioSignal 5 | from dac.utils import load_model as load_dac_model 6 | 7 | from typing import Union 8 | from pathlib import Path 9 | """ 10 | Patch decode to skip things related to the metadata (namely the waveform trimming) 11 | So far it seems the raw waveform can just be returned without any post-processing 12 | A smart implementation would just reuse the values from the input prompt 13 | """ 14 | from dac.model.base import CodecMixin 15 | 16 | @torch.no_grad() 17 | def CodecMixin_compress( 18 | self, 19 | audio_path_or_signal: Union[str, Path, AudioSignal], 20 | win_duration: float = 1.0, 21 | verbose: bool = False, 22 | normalize_db: float = -16, 23 | n_quantizers: int = None, 24 | ) -> DACFile: 25 | """Processes an audio signal from a file or AudioSignal object into 26 | discrete codes. This function processes the signal in short windows, 27 | using constant GPU memory. 
28 | 29 | Parameters 30 | ---------- 31 | audio_path_or_signal : Union[str, Path, AudioSignal] 32 | audio signal to reconstruct 33 | win_duration : float, optional 34 | window duration in seconds, by default 5.0 35 | verbose : bool, optional 36 | by default False 37 | normalize_db : float, optional 38 | normalize db, by default -16 39 | 40 | Returns 41 | ------- 42 | DACFile 43 | Object containing compressed codes and metadata 44 | required for decompression 45 | """ 46 | audio_signal = audio_path_or_signal 47 | if isinstance(audio_signal, (str, Path)): 48 | audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) 49 | 50 | self.eval() 51 | original_padding = self.padding 52 | original_device = audio_signal.device 53 | 54 | audio_signal = audio_signal.clone() 55 | original_sr = audio_signal.sample_rate 56 | 57 | resample_fn = audio_signal.resample 58 | loudness_fn = audio_signal.loudness 59 | 60 | # If audio is > 10 minutes long, use the ffmpeg versions 61 | if audio_signal.signal_duration >= 10 * 60 * 60: 62 | resample_fn = audio_signal.ffmpeg_resample 63 | loudness_fn = audio_signal.ffmpeg_loudness 64 | 65 | original_length = audio_signal.signal_length 66 | resample_fn(self.sample_rate) 67 | input_db = loudness_fn() 68 | 69 | if normalize_db is not None: 70 | audio_signal.normalize(normalize_db) 71 | audio_signal.ensure_max_of_audio() 72 | 73 | nb, nac, nt = audio_signal.audio_data.shape 74 | audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) 75 | win_duration = ( 76 | audio_signal.signal_duration if win_duration is None else win_duration 77 | ) 78 | 79 | if audio_signal.signal_duration <= win_duration: 80 | # Unchunked compression (used if signal length < win duration) 81 | self.padding = True 82 | n_samples = nt 83 | hop = nt 84 | else: 85 | # Chunked inference 86 | self.padding = False 87 | # Zero-pad signal on either side by the delay 88 | audio_signal.zero_pad(self.delay, self.delay) 89 | n_samples = int(win_duration * self.sample_rate) 90 | # Round n_samples to nearest hop length multiple 91 | n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) 92 | hop = self.get_output_length(n_samples) 93 | 94 | codes = [] 95 | range_fn = range if not verbose else tqdm.trange 96 | 97 | for i in range_fn(0, nt, hop): 98 | x = audio_signal[..., i : i + n_samples] 99 | x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) 100 | 101 | audio_data = x.audio_data.to(self.device) 102 | audio_data = self.preprocess(audio_data, self.sample_rate) 103 | with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp): 104 | _, c, _, _, _ = self.encode(audio_data, n_quantizers) 105 | codes.append(c.to(original_device)) 106 | chunk_length = c.shape[-1] 107 | 108 | codes = torch.cat(codes, dim=-1) 109 | 110 | dac_file = DACFile( 111 | codes=codes, 112 | chunk_length=chunk_length, 113 | original_length=original_length, 114 | input_db=input_db, 115 | channels=nac, 116 | sample_rate=original_sr, 117 | padding=self.padding, 118 | dac_version="1.0.0", 119 | #dac_version=SUPPORTED_VERSIONS[-1], 120 | ) 121 | 122 | if n_quantizers is not None: 123 | codes = codes[:, :n_quantizers, :] 124 | 125 | self.padding = original_padding 126 | return dac_file 127 | 128 | @torch.no_grad() 129 | def CodecMixin_decompress( 130 | self, 131 | obj: Union[str, Path, DACFile], 132 | verbose: bool = False, 133 | ) -> AudioSignal: 134 | self.eval() 135 | if isinstance(obj, (str, Path)): 136 | obj = DACFile.load(obj) 137 | 138 | original_padding = self.padding 139 
| self.padding = obj.padding 140 | 141 | range_fn = range if not verbose else tqdm.trange 142 | codes = obj.codes 143 | original_device = codes.device 144 | chunk_length = obj.chunk_length 145 | recons = [] 146 | 147 | for i in range_fn(0, codes.shape[-1], chunk_length): 148 | c = codes[..., i : i + chunk_length].to(self.device) 149 | z = self.quantizer.from_codes(c)[0] 150 | r = self.decode(z) 151 | recons.append(r.to(original_device)) 152 | 153 | recons = torch.cat(recons, dim=-1) 154 | recons = AudioSignal(recons, self.sample_rate) 155 | 156 | # to-do, original implementation 157 | if not hasattr(obj, "dummy") or not obj.dummy: 158 | resample_fn = recons.resample 159 | loudness_fn = recons.loudness 160 | 161 | # If audio is > 10 minutes long, use the ffmpeg versions 162 | if recons.signal_duration >= 10 * 60 * 60: 163 | resample_fn = recons.ffmpeg_resample 164 | loudness_fn = recons.ffmpeg_loudness 165 | 166 | recons.normalize(obj.input_db) 167 | resample_fn(obj.sample_rate) 168 | recons = recons[..., : obj.original_length] 169 | loudness_fn() 170 | recons.audio_data = recons.audio_data.reshape( 171 | -1, obj.channels, obj.original_length 172 | ) 173 | self.padding = original_padding 174 | return recons 175 | 176 | CodecMixin.compress = CodecMixin_compress 177 | CodecMixin.decompress = CodecMixin_decompress -------------------------------------------------------------------------------- /vall_e/utils/sampler.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any 3 | import random 4 | 5 | import torch 6 | from torch.utils.data import Sampler 7 | 8 | from .distributed import global_rank, local_rank, world_size 9 | 10 | # Randomly picks an index from an array of indices 11 | class PoolSampler(): 12 | def __init__( self, pool = [], keep_all = False, shuffle = False ): 13 | self.length = len(pool) 14 | self.shuffle = shuffle 15 | self.global_pool = pool if keep_all else None 16 | self.global_indices = [ i for i in range(self.length) ] 17 | self.reset() 18 | 19 | def reset(self): 20 | self.current_pool = [ i for i in self.global_indices ] 21 | if self.shuffle: 22 | random.shuffle(self.current_pool) 23 | 24 | def sample(self, pool = None): 25 | if pool is None: 26 | pool = self.global_pool 27 | # check if we need to reset 28 | index = random.choice( self.current_pool ) 29 | # remove from pool 30 | self.current_pool.remove(index) 31 | # reset if needed 32 | if len(self.current_pool) == 0: 33 | self.reset() 34 | # map indices to our real values 35 | return pool[index] if pool is not None else index 36 | 37 | def __len__(self): 38 | return self.length # len(self.current_pool) 39 | 40 | def __iter__(self): 41 | while len(self.current_pool) > 0: 42 | yield self.sample() 43 | 44 | def __call__(self, *args, **kwargs): 45 | return self.sample(*args, **kwargs) 46 | 47 | def index(self): 48 | return len(self.global_indices) - len(self.current_pool) 49 | 50 | def get_state(self): 51 | return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool } 52 | 53 | def set_state(self, state): 54 | self.length = state["length"] 55 | self.global_pool = state["global_pool"] 56 | self.global_indices = state["global_indices"] 57 | self.current_pool = state["current_pool"] 58 | 59 | # "Samples" through a fixed sequence from 0 to length 60 | # Necessary for our "shuffle+sort by duration+interleave" sampling method 61 | # Allows saving and loading 
state 62 | class OrderedSampler(Sampler): 63 | def __init__( self, length ): 64 | self.position = 0 65 | self.length = length 66 | 67 | def __len__(self): 68 | return self.length 69 | 70 | def __iter__(self): 71 | if self.position >= self.length: 72 | self.position = 0 73 | 74 | while self.position < self.length: 75 | yield self.position 76 | self.position += 1 77 | 78 | def index(self): 79 | return self.position 80 | 81 | def get_state(self): 82 | return { "position": self.position, "length": self.length } 83 | 84 | def set_state(self, state): 85 | self.position = state["position"] 86 | self.length = state["length"] 87 | 88 | # Like the above, but will batch based on token count 89 | class BatchedOrderedSampler(Sampler): 90 | def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False, drop_last=True, use_max_size=True ): 91 | self.position = 0 92 | self.batches = [] 93 | self.shuffle = shuffle 94 | 95 | assert max_duration != 0 and max_batch_size != 0, "max_duration and max_batch_size cannot both be 0" 96 | 97 | current_batch = [] 98 | current_index = 0 99 | current_duration = 0 100 | 101 | for key, bucket in buckets.items(): 102 | for path, duration in bucket: 103 | # flush 104 | should_flush = False 105 | if max_duration > 0 and current_duration + duration > max_duration: 106 | should_flush = True 107 | elif max_batch_size > 0 and len(current_batch) >= max_batch_size: 108 | should_flush = True 109 | 110 | if should_flush and len(current_batch) > 0: 111 | self.batches.append( current_batch ) 112 | current_batch = [] 113 | current_duration = 0 114 | 115 | current_batch.append( current_index ) 116 | current_index += 1 117 | # as long as durations are ordered, this assertion is always true 118 | if use_max_size: 119 | current_duration = duration * len(current_batch) 120 | else: 121 | current_duration += duration 122 | 123 | if not drop_last and current_batch: 124 | self.batches.append( current_batch ) 125 | 126 | if self.shuffle: 127 | random.shuffle(self.batches) 128 | 129 | def __len__(self): 130 | return len(self.batches) 131 | 132 | def __iter__(self): 133 | if self.position >= len(self.batches): 134 | self.position = 0 135 | if self.shuffle: 136 | random.shuffle(self.batches) 137 | 138 | while self.position < len(self.batches): 139 | yield self.batches[self.position] 140 | self.position += 1 141 | 142 | def index(self): 143 | return self.position 144 | 145 | def get_state(self): 146 | return { "position": self.position, "batches": self.batches } 147 | 148 | def set_state(self, state): 149 | self.position = state["position"] 150 | self.batches = state["batches"] 151 | 152 | # Randomly samples indices from a given sequence from 0 to length 153 | # Allows saving and loading state 154 | class RandomSampler(Sampler): 155 | def __init__( self, length ): 156 | self.position = 0 157 | self.length = length 158 | 159 | self.generator = torch.Generator() 160 | self.perm = torch.randperm(self.length, generator=self.generator) 161 | 162 | def __len__(self): 163 | return self.length 164 | 165 | def __iter__(self): 166 | if self.position >= self.length: 167 | self.position = 0 168 | self.perm = torch.randperm(self.length, generator=self.generator) 169 | 170 | while self.position < self.length: 171 | yield self.perm[self.position] 172 | self.position += 1 173 | 174 | def index(self): 175 | return self.position 176 | 177 | def get_state(self): 178 | return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() } 179 | 180 | def 
set_state(self, state): 181 | self.position = state["position"] 182 | self.length = state["length"] 183 | self.perm = state["perm"] 184 | self.generator.set_state(state["generator"]) -------------------------------------------------------------------------------- /vall_e/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from .inference import TTS 4 | from .config import cfg 5 | 6 | def path_list(arg): 7 | if not arg: 8 | return None 9 | return [Path(p) for p in arg.split(";")] 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser("VALL-E TTS") 13 | parser.add_argument("text") 14 | parser.add_argument("references", type=path_list, default=None) 15 | parser.add_argument("--language", type=str, default="auto") 16 | parser.add_argument("--text-language", type=str, default=None) 17 | parser.add_argument("--task", type=str, default="tts") 18 | parser.add_argument("--modality", type=str, default="auto") 19 | parser.add_argument("--out-path", type=Path, default=None) 20 | 21 | parser.add_argument("--split-text-by", type=str, default="\n") 22 | parser.add_argument("--context-history", type=int, default=0) 23 | parser.add_argument("--no-phonemize", action='store_true') 24 | 25 | parser.add_argument("--yaml", type=Path, default=None) 26 | parser.add_argument("--model", type=Path, default=None) 27 | parser.add_argument("--lora", type=Path, default=None) 28 | 29 | parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second) 30 | parser.add_argument("--max-steps", type=int, default=25) 31 | parser.add_argument("--max-levels", type=int, default=7) 32 | 33 | parser.add_argument("--ar-temperature", type=float, default=1.0) 34 | parser.add_argument("--nar-temperature", type=float, default=0.0) 35 | parser.add_argument("--min-ar-temperature", type=float, default=-1.0) 36 | parser.add_argument("--min-nar-temperature", type=float, default=-1.0) 37 | parser.add_argument("--input-prompt-length", type=float, default=3.0) 38 | parser.add_argument("--input-prompt-prefix", action="store_true") 39 | parser.add_argument("--prefix-silence", type=float, default=0.0) 40 | parser.add_argument("--cfg-strength", type=float, default=0.0) 41 | parser.add_argument("--cfg-rescale", type=float, default=0.75) 42 | 43 | parser.add_argument("--top-p", type=float, default=1.0) 44 | parser.add_argument("--top-k", type=int, default=0) 45 | parser.add_argument("--top-no", type=float, default=0.0) 46 | parser.add_argument("--min-p", type=float, default=0.0) 47 | parser.add_argument("--repetition-penalty", type=float, default=1.0) 48 | parser.add_argument("--repetition-penalty-decay", type=float, default=0.0) 49 | parser.add_argument("--length-penalty", type=float, default=0.0) 50 | parser.add_argument("--beam-width", type=int, default=0) 51 | 52 | parser.add_argument("--mirostat-tau", type=float, default=0) 53 | parser.add_argument("--mirostat-eta", type=float, default=0) 54 | 55 | parser.add_argument("--dry-multiplier", type=float, default=0) 56 | parser.add_argument("--dry-base", type=float, default=1.75) 57 | parser.add_argument("--dry-allowed-length", type=int, default=2) 58 | 59 | parser.add_argument("--entropix-sampling", action="store_true") 60 | 61 | parser.add_argument("--layer-skip", action="store_true") 62 | parser.add_argument("--layer-skip-exit-layer", type=int, default=None) 63 | parser.add_argument("--layer-skip-entropy-threshold", type=int, default=0.1) 64 | 
parser.add_argument("--layer-skip-varentropy-threshold", type=int, default=0.1) 65 | parser.add_argument("--refine-on-stop", action="store_true") 66 | 67 | # experimental settings 68 | parser.add_argument("--load-from-artifact", type=Path, default=None) 69 | parser.add_argument("--denoise-start", type=float, default=0.0) 70 | 71 | parser.add_argument("--seed", type=int, default=None) 72 | 73 | parser.add_argument("--device", type=str, default=None) 74 | parser.add_argument("--amp", action="store_true") 75 | parser.add_argument("--dtype", type=str, default=None) 76 | parser.add_argument("--attention", type=str, default=None) 77 | parser.add_argument("--play", action="store_true") 78 | args = parser.parse_args() 79 | 80 | config = None 81 | 82 | if args.yaml: 83 | config = args.yaml 84 | elif args.model: 85 | config = args.model 86 | 87 | tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention ) 88 | 89 | sampling_kwargs = dict( 90 | split_text_by=args.split_text_by, 91 | context_history=args.context_history, 92 | phonemize=not args.no_phonemize, 93 | max_steps=args.max_steps, 94 | max_levels=args.max_levels, 95 | max_duration=args.max_duration, 96 | ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature, 97 | min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature, 98 | top_p=args.top_p, top_k=args.top_k, top_no=args.top_no,min_p=args.min_p, 99 | repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay, 100 | length_penalty=args.length_penalty, 101 | beam_width=args.beam_width, 102 | mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta, 103 | dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length, 104 | entropix_sampling=args.entropix_sampling, 105 | layer_skip=args.layer_skip, 106 | layer_skip_exit_layer=args.layer_skip_exit_layer, 107 | layer_skip_entropy_threshold=args.layer_skip_entropy_threshold, 108 | layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold, 109 | refine_on_stop=args.refine_on_stop, 110 | denoise_start=args.denoise_start, 111 | input_prompt_length=args.input_prompt_length, 112 | input_prompt_prefix=args.input_prompt_prefix, 113 | prefix_silence=args.prefix_silence, 114 | cfg_strength=args.cfg_strength, 115 | cfg_rescale=args.cfg_rescale, 116 | ) 117 | 118 | output = tts.inference( 119 | text=args.text, 120 | references=args.references, 121 | text_language=args.text_language, 122 | language=args.language, 123 | task=args.task, 124 | modality=args.modality, 125 | out_path=args.out_path, 126 | play=args.play, 127 | 128 | input_prompt_length=args.input_prompt_length, 129 | load_from_artifact=args.load_from_artifact, 130 | 131 | sampling_kwargs=sampling_kwargs, 132 | 133 | seed=args.seed, 134 | ) 135 | 136 | if isinstance( output, str ): 137 | print( output ) 138 | 139 | if __name__ == "__main__": 140 | main() 141 | -------------------------------------------------------------------------------- /vall_e.cpp/vall_e.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // C++ deps 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | // handles defining platform specific macros and import/export decorators (copied from my engine's uf/config.h) 11 | #if defined(_WIN32) || defined(__WIN32__) || defined(__CYGWIN__) 12 | // Windows 13 | #define VALL_E_ENV "Windows" 14 | #define 
VALL_E_ENV_WINDOWS 1 15 | #define VALL_E_ENV_HEADER "windows.h" 16 | #if defined(__CYGWIN__) 17 | #define to_string(var) string(var) 18 | #endif 19 | #ifndef _WIN32_WINNT 20 | #define _WIN32_WINNT 0x0600 21 | #endif 22 | #ifndef WINVER 23 | #define WINVER 0x0600 24 | #endif 25 | 26 | #define VALL_E_IO_ROOT "./data/" 27 | #elif defined(linux) || defined(__linux) 28 | // Linux 29 | #define VALL_E_ENV "Linux" 30 | #define VALL_E_ENV_LINUX 1 31 | #define VALL_E_ENV_HEADER "linux.h" 32 | 33 | #define VALL_E_IO_ROOT "./data/" 34 | #elif defined(__APPLE__) || defined(MACOSX) || defined(macintosh) || defined(Macintosh) 35 | // MacOS 36 | #define VALL_E_ENV "OSX" 37 | #define VALL_E_ENV_OSX 1 38 | #define VALL_E_ENV_HEADER "osx.h" 39 | 40 | #define VALL_E_IO_ROOT "./data/" 41 | #elif defined(__FreeBSD__) || defined(__FreeBSD_kernel__) 42 | // FreeBSD 43 | #define VALL_E_ENV "FreeBSD" 44 | #define VALL_E_ENV_FREEBSD 1 45 | #define VALL_E_ENV_HEADER "freebsd.h" 46 | 47 | #define VALL_E_IO_ROOT "./data/" 48 | #elif defined(__sh__) 49 | // Dreamcast 50 | #define VALL_E_ENV "Dreamcast" 51 | #define VALL_E_ENV_DREAMCAST 1 52 | #define VALL_E_ENV_HEADER "dreamcast.h" 53 | #include VALL_E_ENV_HEADER 54 | 55 | #define _arch_dreamcast 56 | 57 | #define VALL_E_IO_ROOT "/cd/" 58 | #else 59 | // Unsupported system 60 | #define VALL_E_ENV "Unknown" 61 | #define VALL_E_ENV_UNKNOWN 1 62 | #define VALL_E_ENV_HEADER "unknown.h" 63 | #warning Using "unknown" 64 | #error No support 65 | #endif 66 | 67 | #if !defined(VALL_E_STATIC) 68 | #if defined(VALL_E_ENV_WINDOWS) 69 | // Windows compilers need specific (and different) keywords for export and import 70 | #define VALL_E_API_EXPORT __declspec(dllexport) 71 | #define VALL_E_API_IMPORT __declspec(dllimport) 72 | // For Visual C++ compilers, we also need to turn off this annoying C4251 warning 73 | #ifdef _MSC_VER 74 | #pragma warning(disable : 4251) 75 | #endif 76 | #else // Linux, FreeBSD, Mac OS X 77 | #if __GNUC__ >= 4 78 | // GCC 4 has special keywords for showing/hidding symbols, 79 | // the same keyword is used for both importing and exporting 80 | #define VALL_E_API_EXPORT __attribute__ ((__visibility__ ("default"))) 81 | #define VALL_E_API_IMPORT __attribute__ ((__visibility__ ("default"))) 82 | #else 83 | // GCC < 4 has no mechanism to explicitely hide symbols, everything's exported 84 | #define VALL_E_API_EXPORT 85 | #define VALL_E_API_IMPORT 86 | #endif 87 | #endif 88 | #else 89 | // Static build doesn't need import/export macros 90 | #define VALL_E_API_EXPORT 91 | #define VALL_E_API_IMPORT 92 | #endif 93 | 94 | #ifdef VALL_E_EXPORTS 95 | #define VALL_E_API VALL_E_API_EXPORT 96 | #else 97 | #define VALL_E_API VALL_E_API_IMPORT 98 | #endif 99 | 100 | typedef llama_token token_t; 101 | typedef std::vector> vall_e_audio_codes_t; 102 | 103 | const int ENCODEC_FRAMES_PER_SECOND = 75; 104 | const int MAX_DURATION = ENCODEC_FRAMES_PER_SECOND * 12; 105 | const int CTX_SIZE = 2048; 106 | const int N_THREADS = 8; 107 | const int N_GPU_LAYERS = 99; 108 | 109 | const int MODALITY_AR_NAR = 0; 110 | const int MODALITY_NAR_LEN = 1; 111 | 112 | // forward declarations 113 | struct io_map_t; 114 | struct llama_model; 115 | struct llama_context; 116 | struct encodec_context; 117 | 118 | // model-specific parameters 119 | struct vall_e_context_params_t { 120 | std::string model_path = "./data/vall_e.gguf"; 121 | std::string encodec_path = "./data/encodec.bin"; 122 | int32_t gpu_layers = N_GPU_LAYERS; 123 | int32_t n_threads = N_THREADS; 124 | int32_t ctx_size = CTX_SIZE; 125 | 
bool verbose = false; 126 | }; 127 | // inference-specific arguments 128 | struct vall_e_args_t { 129 | std::string text = "Hello world."; 130 | std::string prompt_path = "./data/prom.wav"; 131 | std::string output_path = "./data/resp.wav"; 132 | std::string language = "en"; 133 | std::string task = "tts"; 134 | int modality = MODALITY_NAR_LEN; 135 | int max_steps = 30; 136 | int max_duration = MAX_DURATION; 137 | }; 138 | // stores everything needed for vall_e.cpp at runtime 139 | struct vall_e_context_t { 140 | vall_e_context_params_t params; 141 | 142 | io_map_t* io_map = NULL; // pointer for reasons 143 | 144 | struct { 145 | llama_model* model = NULL; 146 | llama_context* ctx = NULL; 147 | } llama; 148 | 149 | struct { 150 | encodec_context* ctx; 151 | } encodec; 152 | }; 153 | // stores the raw inputs to be fed 154 | struct vall_e_inputs_t { 155 | std::string task = "tts"; 156 | std::string lang = "en"; 157 | 158 | token_t rvq_l = 0; 159 | 160 | std::vector phn = {}; 161 | vall_e_audio_codes_t prom = {}; 162 | vall_e_audio_codes_t resp = {}; 163 | }; 164 | 165 | // encodec helpers 166 | VALL_E_API std::vector read_audio_from_disk( const std::string& path ); 167 | VALL_E_API void write_audio_to_disk( const std::vector& waveform, const std::string& path ); 168 | 169 | VALL_E_API std::vector> encode_audio( struct encodec_context* ectx, const std::vector& waveform ); 170 | VALL_E_API std::vector decode_audio( struct encodec_context* ectx, const vall_e_audio_codes_t& codes_2d ); 171 | 172 | // context management 173 | VALL_E_API void vall_e_print_usage( char** argv, const vall_e_context_params_t& params, const vall_e_args_t& args ); 174 | VALL_E_API bool vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args ); 175 | VALL_E_API vall_e_context_t* vall_e_load( const vall_e_context_params_t& params ); 176 | VALL_E_API vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& lang = "auto", const std::string& task = "tts" ); 177 | VALL_E_API vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int max_steps, int max_duration, int modality = MODALITY_NAR_LEN ); 178 | VALL_E_API void vall_e_free( vall_e_context_t* ctx ); 179 | -------------------------------------------------------------------------------- /scripts/process_emilia.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Handles processing audio provided through --input-audio of adequately annotated transcriptions provided through --input-metadata (through transcribe.py) 3 | # Outputs NumPy objects containing quantized audio and adequate metadata for use of loading in the trainer through --output-dataset 4 | """ 5 | 6 | import os 7 | import json 8 | import argparse 9 | import torch 10 | import torchaudio 11 | import numpy as np 12 | 13 | from tqdm.auto import tqdm 14 | from pathlib import Path 15 | 16 | from vall_e.config import cfg 17 | 18 | from vall_e.emb.g2p import encode as phonemize 19 | from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio 20 | 21 | from vall_e.emb.process import pad, load_audio, process_items, process_jobs 22 | 23 | def process( 24 | audio_backend="encodec", 25 | input_audio="Emilia", 26 | output_dataset="training", 27 | raise_exceptions=False, 28 | stride=0, 29 | stride_offset=0, 30 | slice="auto", 31 | batch_size=1, 32 | low_memory=False, 33 | 34 | device="cuda", 35 | 
dtype="float16", 36 | amp=False, 37 | ): 38 | # prepare from args 39 | cfg.device = device 40 | cfg.set_audio_backend(audio_backend) 41 | audio_extension = cfg.audio_backend_extension 42 | 43 | cfg.inference.weight_dtype = dtype # "bfloat16" 44 | cfg.inference.amp = amp # False 45 | 46 | dtype = cfg.inference.dtype if not amp else None 47 | 48 | output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training" 49 | 50 | language_map = {} # k = group, v = language 51 | 52 | ignore_groups = [] # skip these groups 53 | ignore_speakers = [] # skip these speakers 54 | 55 | only_groups = [] # only process these groups 56 | only_speakers = [] # only process these speakers 57 | 58 | always_slice_groups = [] # always slice from this group 59 | 60 | missing = { 61 | "transcription": [], 62 | "audio": [] 63 | } 64 | dataset = [] 65 | 66 | # Layout: ./Emilia/JA/JA-B000000/JA_B00000_S00000_W000000.{json|mp3} 67 | for language in sorted(os.listdir(f'./{input_audio}/')): 68 | if not os.path.isdir(f'./{input_audio}/{language}/'): 69 | print("Is not dir:", f'./{input_audio}/{language}/') 70 | continue 71 | 72 | if language in ignore_groups: 73 | continue 74 | 75 | if only_groups and language not in only_groups: 76 | continue 77 | 78 | group_name = "Emilia" 79 | 80 | for speaker_group in tqdm(process_items(os.listdir(f'./{input_audio}/{language}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {language}"): 81 | if not os.path.isdir(f'./{input_audio}/{language}/{speaker_group}'): 82 | print("Is not dir:", f'./{input_audio}/{language}/{speaker_group}') 83 | continue 84 | 85 | if speaker_group in ignore_speakers: 86 | continue 87 | if only_speakers and speaker_group not in only_speakers: 88 | continue 89 | 90 | if f'{group_name}/{speaker_group}' not in dataset: 91 | dataset.append(f'{group_name}/{speaker_group}') 92 | 93 | txts = [] 94 | wavs = [] 95 | 96 | for filename in os.listdir(f'./{input_audio}/{language}/{speaker_group}'): 97 | if ".mp3" not in filename: 98 | continue 99 | 100 | inpath = Path(f'./{input_audio}/{language}/{speaker_group}/{filename}') 101 | jsonpath = _replace_file_extension(inpath, ".json") 102 | if not inpath.exists() or not jsonpath.exists(): 103 | missing["audio"].append(str(inpath)) 104 | continue 105 | 106 | extension = os.path.splitext(filename)[-1][1:] 107 | fname = filename.replace(f'.{extension}', "") 108 | 109 | waveform, sample_rate = None, None 110 | metadata = json.load(open(jsonpath, "r", encoding="utf-8")) 111 | if "text" not in metadata: 112 | continue 113 | speaker_id = metadata["speaker"] 114 | outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}').with_suffix(audio_extension) 115 | os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True) 116 | 117 | if _replace_file_extension(outpath, audio_extension).exists(): 118 | continue 119 | 120 | text = metadata["text"] 121 | 122 | if waveform is None: 123 | waveform, sample_rate = load_audio(inpath) 124 | 125 | jobs.append(( outpath, waveform, sample_rate, text, language.lower() )) 126 | 127 | # processes audio files one at a time 128 | process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None ) 129 | jobs = [] 130 | 131 | open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) 132 | 
open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset)) 133 | 134 | def main(): 135 | parser = argparse.ArgumentParser() 136 | 137 | parser.add_argument("--audio-backend", type=str, default="encodec") 138 | parser.add_argument("--dtype", type=str, default="bfloat16") 139 | parser.add_argument("--amp", action="store_true") 140 | parser.add_argument("--input-audio", type=str, default="Emilia") 141 | parser.add_argument("--output-dataset", type=str, default="training/dataset") 142 | parser.add_argument("--device", type=str, default="cuda") 143 | parser.add_argument("--raise-exceptions", action="store_true") 144 | parser.add_argument("--stride", type=int, default=0) 145 | parser.add_argument("--stride-offset", type=int, default=0) 146 | parser.add_argument("--slice", type=str, default="auto") 147 | parser.add_argument("--low-memory", action="store_true") 148 | parser.add_argument("--batch-size", type=int, default=0) 149 | 150 | args = parser.parse_args() 151 | 152 | # do some assumption magic 153 | # to-do: find a nice way to spawn multiple processes where tqdm plays nicely 154 | if args.device.isnumeric(): 155 | args.stride = torch.cuda.device_count() 156 | args.stride_offset = int(args.device) 157 | args.device = f'cuda:{args.device}' 158 | 159 | process( 160 | audio_backend=args.audio_backend, 161 | input_audio=args.input_audio, 162 | output_dataset=args.output_dataset, 163 | raise_exceptions=args.raise_exceptions, 164 | stride=args.stride, 165 | stride_offset=args.stride_offset, 166 | slice=args.slice, 167 | batch_size=args.batch_size, 168 | low_memory=args.low_memory, 169 | 170 | device=args.device, 171 | dtype=args.dtype, 172 | amp=args.amp, 173 | ) 174 | 175 | if __name__ == "__main__": 176 | main() -------------------------------------------------------------------------------- /vall_e/utils/ext/muon.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py 2 | # because it combines both param types and makes life easier with DeepSpeed 3 | 4 | import os 5 | import math 6 | import torch 7 | import torch.distributed as dist 8 | 9 | @torch.compile 10 | def zeropower_via_newtonschulz5(G, steps): 11 | """ 12 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 13 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 14 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 15 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 16 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 17 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 18 | performance at all relative to UV^T, where USV^T = G is the SVD. 
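    Concretely, each iteration below computes A = X Xᵀ and updates X ← a·X + (b·A + c·A·A) X, with (a, b, c) = (3.4445, -4.7750, 2.0315).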
19 | """ 20 | assert len(G.shape) == 2 21 | a, b, c = (3.4445, -4.7750, 2.0315) 22 | X = G.bfloat16() 23 | if G.size(0) > G.size(1): 24 | X = X.T 25 | # Ensure spectral norm is at most 1 26 | X = X / (X.norm() + 1e-7) 27 | # Perform the NS iterations 28 | for _ in range(steps): 29 | A = X @ X.T 30 | B = ( 31 | b * A + c * A @ A 32 | ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng 33 | X = a * X + B @ X 34 | 35 | if G.size(0) > G.size(1): 36 | X = X.T 37 | return X 38 | 39 | 40 | class Muon(torch.optim.Optimizer): 41 | """ 42 | Muon - MomentUm Orthogonalized by Newton-schulz 43 | 44 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post- 45 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal 46 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has 47 | the advantage that it can be stably run in bfloat16 on the GPU. 48 | 49 | Some warnings: 50 | - We believe this optimizer is unlikely to work well for training with small batch size. 51 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this. 52 | 53 | Arguments: 54 | muon_params: The parameters to be optimized by Muon. 55 | lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default) 56 | momentum: The momentum used by the internal SGD. (0.95 is a good default) 57 | nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended) 58 | ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough) 59 | adamw_params: The parameters to be optimized by AdamW. Any parameters in `muon_params` which are 60 | {0, 1}-D or are detected as being the embed or lm_head will be optimized by AdamW as well. 61 | adamw_lr: The learning rate for the internal AdamW. 62 | adamw_betas: The betas for the internal AdamW. 63 | adamw_eps: The epsilon for the internal AdamW. 64 | adamw_wd: The weight decay for the internal AdamW. 65 | """ 66 | 67 | def __init__( 68 | self, 69 | params=None, 70 | lr=1e-3, 71 | wd=0.1, 72 | momentum=0.95, 73 | nesterov=True, 74 | ns_steps=5, 75 | betas=(0.95, 0.95), 76 | eps=1e-8, 77 | ): 78 | 79 | defaults = dict( 80 | lr=lr, 81 | wd=wd, 82 | momentum=momentum, 83 | nesterov=nesterov, 84 | ns_steps=ns_steps, 85 | betas=betas, 86 | eps=eps, 87 | muon=False, 88 | ) 89 | 90 | super().__init__(params, defaults) 91 | 92 | def adjust_lr_for_muon(self, lr, param_shape): 93 | A, B = param_shape[:2] 94 | # We adjust the learning rate and weight decay based on the size of the parameter matrix 95 | # as describted in the paper 96 | adjusted_ratio = 0.2 * math.sqrt(max(A, B)) 97 | adjusted_lr = lr * adjusted_ratio 98 | return adjusted_lr 99 | 100 | def step(self, closure=None): 101 | """Perform a single optimization step. 102 | 103 | Args: 104 | closure (Callable, optional): A closure that reevaluates the model 105 | and returns the loss. 
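        Example:
            A minimal, illustrative way to construct the optimizer so that this step routes
            2-D weight matrices through Muon and everything else through the AdamW fallback.
            The param-group layout here is a sketch under that assumption, not the upstream
            toy_train.py setup:

                muon_params  = [p for p in model.parameters() if p.requires_grad and p.ndim >= 2]
                other_params = [p for p in model.parameters() if p.requires_grad and p.ndim < 2]
                opt = Muon([
                    dict(params=muon_params,  muon=True,  lr=2e-2, momentum=0.95),
                    dict(params=other_params, muon=False, lr=3e-4),
                ])
                loss.backward()
                opt.step()
                opt.zero_grad()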
106 | """ 107 | loss = None 108 | if closure is not None: 109 | with torch.enable_grad(): 110 | loss = closure() 111 | 112 | for group in self.param_groups: 113 | 114 | ############################ 115 | # Muon # 116 | ############################ 117 | if group["muon"]: 118 | # import pdb; pdb.set_trace() 119 | lr = group["lr"] 120 | wd = group["wd"] 121 | momentum = group["momentum"] 122 | 123 | # generate weight updates in distributed fashion 124 | for p in group["params"]: 125 | # sanity check 126 | g = p.grad 127 | if g is None: 128 | continue 129 | if g.ndim > 2: 130 | g = g.view(g.size(0), -1) 131 | assert g is not None 132 | 133 | # calc update 134 | state = self.state[p] 135 | if "momentum_buffer" not in state: 136 | state["momentum_buffer"] = torch.zeros_like(g) 137 | buf = state["momentum_buffer"] 138 | buf.mul_(momentum).add_(g) 139 | if group["nesterov"]: 140 | g = g.add(buf, alpha=momentum) 141 | else: 142 | g = buf 143 | u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]) 144 | 145 | # scale update 146 | adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) 147 | 148 | # apply weight decay 149 | p.data.mul_(1 - lr * wd) 150 | 151 | # apply update 152 | p.data.add_(u, alpha=-adjusted_lr) 153 | 154 | ############################ 155 | # AdamW backup # 156 | ############################ 157 | else: 158 | lr = group['lr'] 159 | beta1, beta2 = group["betas"] 160 | eps = group["eps"] 161 | weight_decay = group["wd"] 162 | 163 | for p in group["params"]: 164 | g = p.grad 165 | if g is None: 166 | continue 167 | state = self.state[p] 168 | if "step" not in state: 169 | state["step"] = 0 170 | state["moment1"] = torch.zeros_like(g) 171 | state["moment2"] = torch.zeros_like(g) 172 | state["step"] += 1 173 | step = state["step"] 174 | buf1 = state["moment1"] 175 | buf2 = state["moment2"] 176 | buf1.lerp_(g, 1 - beta1) 177 | buf2.lerp_(g.square(), 1 - beta2) 178 | 179 | g = buf1 / (eps + buf2.sqrt()) 180 | 181 | bias_correction1 = 1 - beta1**step 182 | bias_correction2 = 1 - beta2**step 183 | scale = bias_correction1 / bias_correction2**0.5 184 | p.data.mul_(1 - lr * weight_decay) 185 | p.data.add_(g, alpha=-lr / scale) 186 | 187 | return loss -------------------------------------------------------------------------------- /scripts/process_libritts.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Handles processing audio provided through --input-audio of adequately annotated transcriptions provided through --input-metadata (through transcribe.py) 3 | # Outputs NumPy objects containing quantized audio and adequate metadata for use of loading in the trainer through --output-dataset 4 | """ 5 | 6 | import os 7 | import json 8 | import argparse 9 | import torch 10 | import torchaudio 11 | import numpy as np 12 | 13 | from tqdm.auto import tqdm 14 | from pathlib import Path 15 | 16 | from vall_e.config import cfg 17 | 18 | from vall_e.emb.g2p import encode as phonemize 19 | from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio 20 | 21 | from vall_e.emb.process import pad, load_audio, process_items, process_jobs 22 | 23 | 24 | def process( 25 | audio_backend="encodec", 26 | input_audio="LibriTTS_R", 27 | output_dataset="training", 28 | raise_exceptions=False, 29 | stride=0, 30 | stride_offset=0, 31 | slice="auto", 32 | batch_size=1, 33 | low_memory=False, 34 | 35 | device="cuda", 36 | dtype="float16", 37 | amp=False, 38 | ): 39 | # prepare from args 40 | cfg.device = device 41 | 
cfg.set_audio_backend(audio_backend) 42 | audio_extension = cfg.audio_backend_extension 43 | 44 | cfg.inference.weight_dtype = dtype # "bfloat16" 45 | cfg.inference.amp = amp # False 46 | 47 | dtype = cfg.inference.dtype if not amp else None 48 | 49 | output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training" 50 | 51 | language_map = {} # k = group, v = language 52 | 53 | ignore_groups = [] # skip these groups 54 | ignore_speakers = [] # skip these speakers 55 | 56 | only_groups = [] # only process these groups 57 | only_speakers = [] # only process these speakers 58 | 59 | always_slice_groups = [] # always slice from this group 60 | 61 | missing = { 62 | "transcription": [], 63 | "audio": [] 64 | } 65 | dataset = [] 66 | 67 | # Layout: ./LibriTTS_R/train-clean-100/103/1241 68 | for group_name in sorted(os.listdir(f'./{input_audio}/')): 69 | if not os.path.isdir(f'./{input_audio}/{group_name}/'): 70 | print("Is not dir:", f'./{input_audio}/{group_name}/') 71 | continue 72 | 73 | if group_name in ignore_groups: 74 | continue 75 | if only_groups and group_name not in only_groups: 76 | continue 77 | 78 | for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"): 79 | if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'): 80 | print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}') 81 | continue 82 | 83 | if speaker_id in ignore_speakers: 84 | continue 85 | if only_speakers and speaker_id not in only_speakers: 86 | continue 87 | 88 | os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True) 89 | 90 | if f'{group_name}/{speaker_id}' not in dataset: 91 | dataset.append(f'{group_name}/{speaker_id}') 92 | 93 | txts = [] 94 | wavs = [] 95 | 96 | for book_id in os.listdir(f'./{input_audio}/{group_name}/{speaker_id}'): 97 | if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}'): 98 | print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}/{book_id}') 99 | continue 100 | 101 | for filename in os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}'): 102 | if ".wav" not in filename: 103 | continue 104 | 105 | inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}/{filename}') 106 | textpath = _replace_file_extension(inpath, ".original.txt") 107 | if not inpath.exists() or not textpath.exists(): 108 | missing["audio"].append(str(inpath)) 109 | continue 110 | 111 | extension = os.path.splitext(filename)[-1][1:] 112 | fname = filename.replace(f'.{extension}', "") 113 | 114 | waveform, sample_rate = None, None 115 | language = "en" 116 | 117 | outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}') 118 | text = open(textpath, "r", encoding="utf-8").read() 119 | 120 | if len(text) == 0: 121 | continue 122 | 123 | if _replace_file_extension(outpath, audio_extension).exists(): 124 | continue 125 | 126 | if waveform is None: 127 | waveform, sample_rate = load_audio(inpath) 128 | 129 | jobs.append(( outpath, waveform, sample_rate, text, language )) 130 | 131 | # processes audio files one at a time 132 | if low_memory: 133 | process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None ) 134 | jobs = [] 135 | 136 | # processes all audio files for a given speaker 137 | if not 
low_memory: 138 | process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None ) 139 | jobs = [] 140 | 141 | open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing)) 142 | open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset)) 143 | 144 | def main(): 145 | parser = argparse.ArgumentParser() 146 | 147 | parser.add_argument("--audio-backend", type=str, default="encodec") 148 | parser.add_argument("--dtype", type=str, default="bfloat16") 149 | parser.add_argument("--amp", action="store_true") 150 | parser.add_argument("--input-audio", type=str, default="LibriTTS_R") 151 | parser.add_argument("--output-dataset", type=str, default="training/dataset") 152 | parser.add_argument("--device", type=str, default="cuda") 153 | parser.add_argument("--raise-exceptions", action="store_true") 154 | parser.add_argument("--stride", type=int, default=0) 155 | parser.add_argument("--stride-offset", type=int, default=0) 156 | parser.add_argument("--slice", type=str, default="auto") 157 | parser.add_argument("--low-memory", action="store_true") 158 | parser.add_argument("--batch-size", type=int, default=0) 159 | 160 | args = parser.parse_args() 161 | 162 | # do some assumption magic 163 | # to-do: find a nice way to spawn multiple processes where tqdm plays nicely 164 | if args.device.isnumeric(): 165 | args.stride = torch.cuda.device_count() 166 | args.stride_offset = int(args.device) 167 | args.device = f'cuda:{args.device}' 168 | 169 | process( 170 | audio_backend=args.audio_backend, 171 | input_audio=args.input_audio, 172 | output_dataset=args.output_dataset, 173 | raise_exceptions=args.raise_exceptions, 174 | stride=args.stride, 175 | stride_offset=args.stride_offset, 176 | slice=args.slice, 177 | batch_size=args.batch_size, 178 | low_memory=args.low_memory, 179 | 180 | device=args.device, 181 | dtype=args.dtype, 182 | amp=args.amp, 183 | ) 184 | 185 | if __name__ == "__main__": 186 | main() -------------------------------------------------------------------------------- /vall_e/utils/ml.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | 3 | import math 4 | import logging 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.optim.lr_scheduler import _LRScheduler 9 | 10 | from ..config import cfg 11 | 12 | _logger = logging.getLogger(__name__) 13 | 14 | Embedding = torch.nn.Embedding 15 | Linear = torch.nn.Linear 16 | 17 | Adam = torch.optim.Adam 18 | AdamW = torch.optim.AdamW 19 | SGD = torch.optim.SGD 20 | Adagrad = torch.optim.Adagrad 21 | Adafactor = torch.optim.Adafactor 22 | 23 | OneCycleLR = torch.optim.lr_scheduler.OneCycleLR 24 | CosineAnnealingLR = torch.optim.lr_scheduler.CosineAnnealingLR 25 | LambdaLR = torch.optim.lr_scheduler.LambdaLR 26 | 27 | # implements Noam scheduling 28 | # it's cringe 29 | class NoamLR(_LRScheduler): 30 | def __init__(self, optimizer, warmup_steps, d_model=1024, last_epoch=-1): 31 | self.base_factor = d_model ** (-0.5) 32 | self.warmup_steps = warmup_steps 33 | 34 | super().__init__(optimizer, last_epoch) 35 | 36 | def get_lr(self): 37 | step = max(1, self.last_epoch) 38 | scale = self.base_factor * min(step ** (-0.5), step * self.warmup_steps ** (-1.5)) 39 | 40 | return [base_lr * scale for base_lr in self.base_lrs] 41 | 42 | # gradually warms up LR then holds or decays 43 | class WarmupLR(_LRScheduler): 44 | def 
__init__(self, optimizer, warmup_steps, decay_factor=0.0, last_epoch=-1): 45 | self.warmup_steps = warmup_steps 46 | self.decay_factor = decay_factor 47 | 48 | super().__init__(optimizer, last_epoch) 49 | 50 | def get_lr(self): 51 | step = self.last_epoch + 1 52 | scale = 1 53 | if step < self.warmup_steps: 54 | scale = float(step) / float(max(1, self.warmup_steps)) 55 | elif self.decay_factor != 0: 56 | scale = (1.0 - self.decay_factor) ** (step - self.warmup_steps) 57 | 58 | return [base_lr * scale for base_lr in self.base_lrs] 59 | 60 | # https://github.com/kyegomez/BitNet 61 | if cfg.optimizations.bitnet: 62 | from bitnet import BitLinear 63 | 64 | if cfg.optimizations.bitsandbytes: 65 | import bitsandbytes as bnb 66 | 67 | if cfg.optimizations.linear: 68 | 69 | if cfg.optimizations.bitnet: 70 | Linear = BitLinear 71 | else: 72 | Linear = bnb.nn.Linear8bitLt 73 | 74 | if cfg.optimizations.embedding: 75 | Embedding = bnb.nn.StableEmbedding 76 | """ 77 | Embedding.forward = lambda self, input: ( self.norm(F.embedding( 78 | input, 79 | self.weight, 80 | self.padding_idx, 81 | self.max_norm, 82 | self.norm_type, 83 | self.scale_grad_by_freq, 84 | self.sparse, 85 | )).to(self.weight.dtype) ) 86 | """ 87 | 88 | if cfg.optimizations.optimizers: 89 | Adam = bnb.optim.Adam8bit 90 | AdamW = bnb.optim.AdamW8bit 91 | SGD = bnb.optim.SGD8bit 92 | Adagrad = bnb.optim.Adagrad8bit 93 | 94 | elif cfg.optimizations.dadaptation: 95 | import dadaptation 96 | 97 | if cfg.optimizations.optimizers: 98 | Adam = dadaptation.DAdaptAdam 99 | AdamW = dadaptation.DAdaptAdam 100 | SGD = dadaptation.DAdaptSGD 101 | AdaGrad = dadaptation.DAdaptAdaGrad 102 | 103 | if cfg.optimizations.fp8: 104 | import transformer_engine.pytorch as te 105 | 106 | Linear = te.Linear 107 | 108 | @contextmanager 109 | def autocast(): 110 | yield te.fp8_autocast(enabled=True) 111 | else: 112 | @contextmanager 113 | def autocast(): 114 | yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp) 115 | 116 | if cfg.optimizations.injects: 117 | if cfg.optimizations.linear: 118 | torch.nn.Linear = Linear 119 | 120 | if cfg.optimizations.embedding: 121 | torch.nn.Embedding = Embedding 122 | 123 | if cfg.optimizations.optimizers: 124 | torch.optim.Adam = Adam 125 | torch.optim.AdamW = AdamW 126 | torch.optim.SGD = SGD 127 | 128 | if cfg.optimizations.unsloth: 129 | try: 130 | from .ext.unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch 131 | #apply_unsloth_offloaded_gradient_checkpoint_monkey_patch() 132 | except Exception as e: 133 | _logger.warning(f'Error while importing Unsloth: {str(e)}') 134 | pass 135 | 136 | class Optimizers(torch.optim.Optimizer): 137 | def __init__(self, opts): 138 | self.opts = opts 139 | 140 | def step(self, *args, **kwargs): 141 | for opt in self.opts: 142 | opt.step(*args, **kwargs) 143 | 144 | def zero_grad(self, *args, **kwargs): 145 | for opt in self.opts: 146 | opt.zero_grad(*args, **kwargs) 147 | 148 | @property 149 | def param_groups(self): 150 | l = [] 151 | for opt in self.opts: 152 | l += opt.param_groups 153 | return l 154 | 155 | def state_dict(self): 156 | states = [] 157 | for i, opt in enumerate( self.opts ): 158 | states.append( opt.state_dict() ) 159 | 160 | return states 161 | 162 | def load_state_dict(self, state_dict): 163 | for opt, state in zip( self.opts, state_dict ): 164 | opt.load_state_dict( state ) 165 | 166 | try: 167 | from .ext.apollo import Apollo 168 | except Exception as e: 169 | _logger.warning(f'Error while importing APOLLO: {str(e)}') 170 
| pass 171 | 172 | try: 173 | from .ext.muon import Muon 174 | except Exception as e: 175 | _logger.warning(f'Error while importing Muon: {str(e)}') 176 | pass 177 | 178 | # https://github.com/konstmish/prodigy 179 | try: 180 | from prodigyopt import Prodigy 181 | except Exception as e: 182 | _logger.warning(f'Error while importing Prodigyopt: {str(e)}') 183 | pass 184 | 185 | # https://github.com/facebookresearch/schedule_free/ 186 | try: 187 | import schedulefree 188 | except Exception as e: 189 | _logger.warning(f'Error while importing Schedule_Free: {str(e)}') 190 | pass 191 | 192 | # backwards compat 193 | from .utils import ( 194 | autocast_forward, 195 | replace_linear as replace_linear_old, 196 | replace_embedding as replace_embedding_old, 197 | replace_attention, 198 | resize_weight, 199 | offload_model, 200 | ) 201 | 202 | # wrapped here so we can maintain default args 203 | def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ): 204 | return replace_linear_old( model, klass, target, verbose ) 205 | def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ): 206 | return replace_embedding_old( model, klass, target, verbose ) 207 | 208 | Embedding.forward = autocast_forward(Embedding.forward) 209 | 210 | AVAILABLE_COMPILE_BACKENDS = [] 211 | 212 | try: 213 | AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends() 214 | except Exception as e: 215 | pass 216 | 217 | def compile_model(model, backend="auto"): 218 | if not backend or backend == "auto": 219 | backend = AVAILABLE_COMPILE_BACKENDS[0] 220 | 221 | if backend not in AVAILABLE_COMPILE_BACKENDS: 222 | return torch.compile(model) 223 | 224 | return torch.compile(model, backend=backend) 225 | 226 | 227 | if cfg.optimizations.tensorrt: 228 | try: 229 | import torch_tensorrt 230 | AVAILABLE_COMPILE_BACKENDS.append("tensorrt") 231 | except Exception as e: 232 | _logger.warning(f'Error while importing TensorRT: {str(e)}') 233 | pass -------------------------------------------------------------------------------- /vall_e.cpp/include/ggml-cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ggml.h" 4 | #include "ggml-backend.h" 5 | 6 | #ifdef __cplusplus 7 | extern "C" { 8 | #endif 9 | 10 | // the compute plan that needs to be prepared for ggml_graph_compute() 11 | // since https://github.com/ggml-org/ggml/issues/287 12 | struct ggml_cplan { 13 | size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()` 14 | uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` 15 | 16 | int n_threads; 17 | struct ggml_threadpool * threadpool; 18 | 19 | // abort ggml_graph_compute when true 20 | ggml_abort_callback abort_callback; 21 | void * abort_callback_data; 22 | }; 23 | 24 | // numa strategies 25 | enum ggml_numa_strategy { 26 | GGML_NUMA_STRATEGY_DISABLED = 0, 27 | GGML_NUMA_STRATEGY_DISTRIBUTE = 1, 28 | GGML_NUMA_STRATEGY_ISOLATE = 2, 29 | GGML_NUMA_STRATEGY_NUMACTL = 3, 30 | GGML_NUMA_STRATEGY_MIRROR = 4, 31 | GGML_NUMA_STRATEGY_COUNT 32 | }; 33 | 34 | GGML_BACKEND_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems 35 | GGML_BACKEND_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node 36 | 37 | GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); 38 | GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * 
ctx, float value); 39 | 40 | GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); 41 | GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); 42 | 43 | GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); 44 | GGML_BACKEND_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); 45 | 46 | GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); 47 | GGML_BACKEND_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value); 48 | 49 | GGML_BACKEND_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); 50 | GGML_BACKEND_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); 51 | 52 | GGML_BACKEND_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); 53 | GGML_BACKEND_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); 54 | 55 | GGML_BACKEND_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params); 56 | GGML_BACKEND_API void ggml_threadpool_free (struct ggml_threadpool * threadpool); 57 | GGML_BACKEND_API int ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool); 58 | GGML_BACKEND_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool); 59 | GGML_BACKEND_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool); 60 | 61 | // ggml_graph_plan() has to be called before ggml_graph_compute() 62 | // when plan.work_size > 0, caller must allocate memory for plan.work_data 63 | GGML_BACKEND_API struct ggml_cplan ggml_graph_plan( 64 | const struct ggml_cgraph * cgraph, 65 | int n_threads, /* = GGML_DEFAULT_N_THREADS */ 66 | struct ggml_threadpool * threadpool /* = NULL */ ); 67 | GGML_BACKEND_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); 68 | 69 | // same as ggml_graph_compute() but the work data is allocated as a part of the context 70 | // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data 71 | GGML_BACKEND_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads); 72 | 73 | // 74 | // system info 75 | // 76 | 77 | // x86 78 | GGML_BACKEND_API int ggml_cpu_has_sse3 (void); 79 | GGML_BACKEND_API int ggml_cpu_has_ssse3 (void); 80 | GGML_BACKEND_API int ggml_cpu_has_avx (void); 81 | GGML_BACKEND_API int ggml_cpu_has_avx_vnni (void); 82 | GGML_BACKEND_API int ggml_cpu_has_avx2 (void); 83 | GGML_BACKEND_API int ggml_cpu_has_bmi2 (void); 84 | GGML_BACKEND_API int ggml_cpu_has_f16c (void); 85 | GGML_BACKEND_API int ggml_cpu_has_fma (void); 86 | GGML_BACKEND_API int ggml_cpu_has_avx512 (void); 87 | GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void); 88 | GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void); 89 | GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void); 90 | GGML_BACKEND_API int ggml_cpu_has_amx_int8 (void); 91 | // ARM 92 | GGML_BACKEND_API int ggml_cpu_has_neon (void); 93 | GGML_BACKEND_API int ggml_cpu_has_arm_fma (void); 94 | GGML_BACKEND_API int ggml_cpu_has_fp16_va (void); 95 | GGML_BACKEND_API int ggml_cpu_has_dotprod (void); 96 | GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void); 97 | GGML_BACKEND_API int ggml_cpu_has_sve (void); 98 | GGML_BACKEND_API int ggml_cpu_get_sve_cnt 
(void); // sve vector length in bytes 99 | GGML_BACKEND_API int ggml_cpu_has_sme (void); 100 | // other 101 | GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); 102 | GGML_BACKEND_API int ggml_cpu_has_vsx (void); 103 | GGML_BACKEND_API int ggml_cpu_has_vxe (void); 104 | GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void); 105 | GGML_BACKEND_API int ggml_cpu_has_llamafile (void); 106 | 107 | // Internal types and functions exposed for tests and benchmarks 108 | 109 | typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx, 110 | const void * GGML_RESTRICT y, size_t by, int nrc); 111 | 112 | struct ggml_type_traits_cpu { 113 | ggml_from_float_t from_float; 114 | ggml_vec_dot_t vec_dot; 115 | enum ggml_type vec_dot_type; 116 | int64_t nrows; // number of rows to process simultaneously 117 | }; 118 | 119 | GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type); 120 | 121 | GGML_BACKEND_API void ggml_cpu_init(void); 122 | 123 | // 124 | // CPU backend 125 | // 126 | 127 | GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void); 128 | 129 | GGML_BACKEND_API bool ggml_backend_is_cpu (ggml_backend_t backend); 130 | GGML_BACKEND_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads); 131 | GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); 132 | GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); 133 | 134 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); 135 | 136 | #ifdef __cplusplus 137 | } 138 | #endif 139 | -------------------------------------------------------------------------------- /vall_e.cpp/include/encodec.h: -------------------------------------------------------------------------------- 1 | /* 2 | ╞══════════════════════════════════════════════════════════════════════════════╡ 3 | │ Copyright 2024 Pierre-Antoine Bannier │ 4 | │ │ 5 | │ Permission to use, copy, modify, and/or distribute this software for │ 6 | │ any purpose with or without fee is hereby granted, provided that the │ 7 | │ above copyright notice and this permission notice appear in all copies. │ 8 | │ │ 9 | │ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │ 10 | │ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │ 11 | │ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │ 12 | │ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │ 13 | │ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │ 14 | │ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │ 15 | │ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │ 16 | │ PERFORMANCE OF THIS SOFTWARE. │ 17 | ╚─────────────────────────────────────────────────────────────────────────────*/ 18 | /* 19 | * This file contains the declarations of the structs and functions used in the encodec library. 20 | * The library provides functionality for audio compression and decompression using a custom model. 21 | * The model consists of an encoder, a quantizer and a decoder, each with their own set of parameters. 22 | * The library also provides functions for loading and freeing the model, as well as compressing and decompressing audio data. 
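 *
 * A minimal round-trip sketch using only the functions declared in this header (the model
 * path, bandwidth value and thread counts are illustrative only; pcm / n_samples stand for
 * the caller's raw float samples):
 *
 *     struct encodec_context *ectx = encodec_load_model("encodec.bin", 0, 0);
 *     encodec_set_target_bandwidth(ectx, 6);
 *     if (encodec_compress_audio(ectx, pcm, n_samples, 4)) {
 *         const int32_t *codes   = encodec_get_codes(ectx);
 *         const int      n_codes = encodec_get_codes_size(ectx);
 *         if (encodec_decompress_audio(ectx, codes, n_codes, 4)) {
 *             float *audio = encodec_get_audio(ectx);
 *             int    n_out = encodec_get_audio_size(ectx);
 *         }
 *     }
 *     encodec_free(ectx);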
23 | * 24 | */ 25 | #pragma once 26 | 27 | #include "ggml-alloc.h" 28 | #include "ggml-backend.h" 29 | #include "ggml.h" 30 | 31 | #ifdef __cplusplus 32 | extern "C" { 33 | #endif 34 | struct encodec_context; 35 | 36 | struct encodec_statistics { 37 | // The time taken to load the model. 38 | int64_t t_load_us; 39 | // The time taken to compute the model. 40 | int64_t t_compute_us; 41 | }; 42 | 43 | /** 44 | * Loads an encodec model from the specified file path. 45 | * 46 | * @param model_path The file path to the encodec model. 47 | * @param offset The offset (in bytes) to the start of the model in the file. 48 | * @param n_gpu_layers The number of GPU layers to use. 49 | * @return A pointer to the encodec context struct. 50 | */ 51 | struct encodec_context *encodec_load_model( 52 | const char *model_path, 53 | const int offset, 54 | int n_gpu_layers); 55 | 56 | /** 57 | * Sets the target bandwidth for the given encodec context. 58 | * 59 | * @param ectx The encodec context to set the target bandwidth for. 60 | * @param bandwidth The target bandwidth to set, in bits per second. 61 | */ 62 | void encodec_set_target_bandwidth( 63 | struct encodec_context *ectx, 64 | int bandwidth); 65 | 66 | /** 67 | * Sets the sample rate for the given encodec context. 68 | * 69 | * @param ectx The encodec context to set the target bandwidth for. 70 | * @param sample_rate The sample rate to set. 71 | */ 72 | void encodec_set_sample_rate( 73 | struct encodec_context *ectx, 74 | int sample_rate); 75 | 76 | /** 77 | * Reconstructs audio from raw audio data using the specified encodec context. 78 | * 79 | * @param ectx The encodec context to use for reconstruction. 80 | * @param raw_audio The raw audio data to reconstruct. 81 | * @param n_samples The number of samples in the raw audio buffer. 82 | * @param n_threads The number of threads to use for reconstruction. 83 | * @return True if the reconstruction was successful, false otherwise. 84 | */ 85 | bool encodec_reconstruct_audio( 86 | struct encodec_context *ectx, 87 | const float *raw_audio, 88 | const int n_samples, 89 | int n_threads); 90 | 91 | /** 92 | * Compresses audio data using the specified encodec context. 93 | * 94 | * @param ectx The encodec context to use for compression. 95 | * @param raw_audio The raw audio data to compress. 96 | * @param n_samples The number of samples in the raw audio buffer. 97 | * @param n_threads The number of threads to use for compression. 98 | * @return True if the compression was successful, false otherwise. 99 | */ 100 | bool encodec_compress_audio( 101 | struct encodec_context *ectx, 102 | const float *raw_audio, 103 | const int n_samples, 104 | int n_threads); 105 | 106 | /** 107 | * Decompresses audio data using the specified encodec context. 108 | * 109 | * @param ectx The encodec context to use for decompression. 110 | * @param codes The compressed audio data to decompress. 111 | * @param n_codes The number of codes in the codes buffer. 112 | * @param n_threads The number of threads to use for decompression. 113 | * @return True if the audio data was successfully decompressed, false otherwise. 114 | */ 115 | bool encodec_decompress_audio( 116 | struct encodec_context *ectx, 117 | const int32_t *codes, 118 | const int n_codes, 119 | int n_threads); 120 | 121 | /** 122 | * Gets the audio data from the given encodec context. 123 | * 124 | * @param ectx The encodec context to get the audio data from. 125 | * @return A pointer to the audio data. 
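 *
 * @see encodec_get_audio_size() for the size of the returned buffer.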
126 | */ 127 | float * encodec_get_audio( 128 | struct encodec_context *ectx); 129 | 130 | /** 131 | * Gets the size of the audio data from the given encodec context. 132 | * 133 | * @param ectx The encodec context to get the audio size from. 134 | * @return The size of the audio data. 135 | */ 136 | int encodec_get_audio_size( 137 | struct encodec_context *ectx); 138 | 139 | /** 140 | * Gets the code data from the given encodec context. 141 | * 142 | * @param ectx The encodec context to get the code data from. 143 | * @return A pointer to the code data. 144 | */ 145 | int32_t * encodec_get_codes( 146 | struct encodec_context *ectx); 147 | 148 | /** 149 | * Gets the size of the code data from the given encodec context. 150 | * 151 | * @param ectx The encodec context to get the code size from. 152 | * @return The size of the code data. 153 | */ 154 | int encodec_get_codes_size( 155 | struct encodec_context *ectx); 156 | 157 | /** 158 | * Gets the statistics for the given encodec context. 159 | * 160 | * @param ectx The encodec context to get the statistics for. 161 | * @return A pointer to the statistics struct. 162 | */ 163 | const struct encodec_statistics* encodec_get_statistics( 164 | struct encodec_context *ectx); 165 | 166 | /** 167 | * Reset the statistics for the given encodec context. 168 | * 169 | * @param ectx The encodec context to reset the statistics for. 170 | */ 171 | void encodec_reset_statistics( 172 | struct encodec_context *ectx); 173 | 174 | /** 175 | * @brief Frees the memory allocated for an encodec context. 176 | * 177 | * @param ectx The encodec context to free. 178 | */ 179 | void encodec_free( 180 | struct encodec_context *ectx); 181 | 182 | #ifdef __cplusplus 183 | } 184 | #endif -------------------------------------------------------------------------------- /vall_e/models/lora.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py 2 | from functools import partial 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn.utils.parametrize as parametrize 6 | 7 | from transformers.pytorch_utils import Conv1D 8 | 9 | from torch import Tensor, nn 10 | 11 | import math 12 | from typing import Optional, List 13 | 14 | from ..utils import passes_policy 15 | 16 | # LoRA Linear for replacement 17 | # Pros: simple, just needs to reuse the replace_linear and copy weights 18 | # Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you) 19 | class LoRALinear(nn.Linear): 20 | def __init__( 21 | self, 22 | 23 | in_features: int, 24 | out_features: int, 25 | bias: bool = True, 26 | 27 | rank: int = 4, 28 | alpha: int = 1, 29 | 30 | dropout: float = 0.1, 31 | merge_weights: bool = False, 32 | **kwargs, 33 | ): 34 | super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs) 35 | 36 | self.rank = rank 37 | self.alpha = alpha 38 | self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x 39 | self.merge_weights = merge_weights 40 | self.merged = False 41 | self.enabled = True 42 | 43 | self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) ) 44 | self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) ) 45 | self.scaling = self.alpha / self.rank 46 | 47 | self.weight.requires_grad = False 48 | 49 | self.reset_parameters() 50 | 51 | def reset_parameters(self): 52 | super().reset_parameters() 53 | # super 
silly but necessary because nn.Linear's constructor calls this 54 | if hasattr(self, 'lora_A'): 55 | nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) ) 56 | nn.init.zeros_( self.lora_B ) 57 | 58 | def train(self, mode: bool = True): 59 | super().train(mode) 60 | 61 | # training, separate lora from base weights 62 | if mode and self.merge_weights and self.merged: 63 | self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling 64 | self.merged = False 65 | 66 | # not training, merge lora to base weights 67 | if not mode and self.merge_weights and not self.merged: 68 | self.weight.data += (self.lora_B @ self.lora_A) * self.scaling 69 | self.merged = True 70 | 71 | def forward(self, x: torch.Tensor): 72 | if not self.merged and self.enabled: 73 | result = F.linear(x, self.weight, bias=self.bias) 74 | result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling 75 | return result 76 | 77 | return F.linear(x, self.weight, bias=self.bias) 78 | 79 | @classmethod 80 | def from_linear( cls, layer, device = None, dtype = None, **kwargs ): 81 | if device is None: 82 | device = layer.weight.device 83 | if dtype is None: 84 | dtype = layer.weight.dtype 85 | return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype) 86 | 87 | # Uses parametrization to inject LoRA weights 88 | # Pros: should work with any Linears 89 | # Cons: TBD 90 | class ParameterizedLoRA(nn.Module): 91 | def __init__( 92 | self, 93 | 94 | in_features: int, 95 | out_features: int, 96 | bias: bool = True, 97 | 98 | rank: int = 4, 99 | alpha: int = 1, 100 | 101 | dropout: float = 0.1, 102 | 103 | device = None, 104 | dtype = None 105 | ): 106 | super().__init__() 107 | self.rank = rank 108 | self.alpha = alpha 109 | self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x 110 | 111 | self.lora_B = nn.Parameter( torch.zeros( (out_features, rank) ) ).to( device=device, dtype=dtype ) 112 | self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype ) 113 | self.scaling = self.alpha / self.rank 114 | self.enabled = True 115 | 116 | self.reset_parameters() 117 | 118 | def reset_parameters(self): 119 | nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) ) 120 | nn.init.zeros_( self.lora_B ) 121 | 122 | def forward(self, x: torch.Tensor): 123 | if self.enabled: 124 | return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling 125 | return x 126 | 127 | @classmethod 128 | def from_linear( cls, layer, device = None, dtype = None, **kwargs ): 129 | if device is None: 130 | device = layer.weight.device 131 | if dtype is None: 132 | dtype = layer.weight.dtype 133 | # swap because we're feeding the output as our input 134 | # M$'s LoRA class arranges things to where this isn't necessary 135 | return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype) 136 | 137 | @classmethod 138 | def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ): 139 | if device is None: 140 | device = layer.weight.device 141 | if dtype is None: 142 | dtype = layer.weight.dtype 143 | 144 | in_channels, out_channels = layer.weight.shape 145 | # swap because we're feeding the output as our input 146 | # M$'s LoRA class arranges things to where this isn't necessary 147 | return cls( in_features = out_channels, out_features = in_channels, bias = 
layer.bias is not None, **kwargs ).to(device=device, dtype=dtype) 148 | 149 | def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ): 150 | device = next(model.parameters()).device 151 | dtype = next(model.parameters()).dtype 152 | 153 | modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ] 154 | 155 | for *parent, k in modules: 156 | name = '.'.join(parent) 157 | layer = getattr( model.get_submodule(name), k ) 158 | 159 | if isinstance( layer, nn.Linear ): 160 | target = nn.Linear 161 | klass = ParameterizedLoRA if use_parametrize else LoRALinear 162 | replacer = klass.from_linear 163 | elif isinstance( layer, nn.Conv1d ): 164 | target = nn.Conv1d 165 | klass = ParameterizedLoRA if use_parametrize else LoRAConv1d 166 | replacer = klass.from_conv1d 167 | elif isinstance( layer, Conv1D ): 168 | target = Conv1D 169 | klass = ParameterizedLoRA if use_parametrize else LoRAConv1d 170 | replacer = klass.from_conv1d 171 | else: 172 | continue 173 | 174 | replacement = replacer( layer, device=device, dtype=dtype, **kwargs ) 175 | 176 | if use_parametrize: 177 | parametrize.register_parametrization( layer, "weight", replacement ) 178 | else: 179 | setattr( model.get_submodule(name), k, replacement ) 180 | 181 | return enable_lora( model ) 182 | 183 | def enable_lora( model, mode = True ): 184 | for name, module in model.named_modules(): 185 | if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ): 186 | continue 187 | module.enabled = mode 188 | return model 189 | 190 | def disable_lora( model ): 191 | return enable_lora( model, False ) 192 | 193 | def freeze_non_lora_weights( model, embeddings = False ): 194 | frozen_params = [] 195 | 196 | for name, param in model.named_parameters(): 197 | should = 'lora_' in name or (embeddings and "_emb" in name) 198 | 199 | param.requires_grad_(should) 200 | 201 | if not should: 202 | frozen_params.append( param ) 203 | 204 | return frozen_params 205 | 206 | def lora_get_state_dict( state_dict, split = True ): 207 | lora = { name: param for name, param in state_dict.items() if "lora_" in name } 208 | if not split: 209 | return lora 210 | 211 | return lora, { name: param for name, param in state_dict.items() if "lora_" not in name } 212 | 213 | def lora_load_state_dict( model, state_dict ): 214 | return model.load_state_dict( state_dict, strict = False ) -------------------------------------------------------------------------------- /vall_e.cpp/include/espeak-ng/espeak_ng.h: -------------------------------------------------------------------------------- 1 | /* eSpeak NG API. 2 | * 3 | * Copyright (C) 2015-2017 Reece H. Dunn 4 | * 5 | * This program is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU General Public License 16 | * along with this program. If not, see . 
17 | */ 18 | 19 | #ifndef ESPEAK_NG_H 20 | #define ESPEAK_NG_H 21 | 22 | #include 23 | 24 | #ifdef __cplusplus 25 | extern "C" 26 | { 27 | #endif 28 | 29 | #if defined(_WIN32) || defined(_WIN64) 30 | #ifdef LIBESPEAK_NG_EXPORT 31 | #define ESPEAK_NG_API __declspec(dllexport) 32 | #else 33 | #define ESPEAK_NG_API __declspec(dllimport) 34 | #endif 35 | #else 36 | #define ESPEAK_NG_API 37 | #endif 38 | 39 | #define ESPEAKNG_DEFAULT_VOICE "en" 40 | 41 | typedef enum { 42 | ENS_GROUP_MASK = 0x70000000, 43 | ENS_GROUP_ERRNO = 0x00000000, /* Values 0-255 map to errno error codes. */ 44 | ENS_GROUP_ESPEAK_NG = 0x10000000, /* eSpeak NG error codes. */ 45 | 46 | /* eSpeak NG 1.49.0 */ 47 | ENS_OK = 0, 48 | ENS_COMPILE_ERROR = 0x100001FF, 49 | ENS_VERSION_MISMATCH = 0x100002FF, 50 | ENS_FIFO_BUFFER_FULL = 0x100003FF, 51 | ENS_NOT_INITIALIZED = 0x100004FF, 52 | ENS_AUDIO_ERROR = 0x100005FF, 53 | ENS_VOICE_NOT_FOUND = 0x100006FF, 54 | ENS_MBROLA_NOT_FOUND = 0x100007FF, 55 | ENS_MBROLA_VOICE_NOT_FOUND = 0x100008FF, 56 | ENS_EVENT_BUFFER_FULL = 0x100009FF, 57 | ENS_NOT_SUPPORTED = 0x10000AFF, 58 | ENS_UNSUPPORTED_PHON_FORMAT = 0x10000BFF, 59 | ENS_NO_SPECT_FRAMES = 0x10000CFF, 60 | ENS_EMPTY_PHONEME_MANIFEST = 0x10000DFF, 61 | ENS_SPEECH_STOPPED = 0x10000EFF, 62 | 63 | /* eSpeak NG 1.49.2 */ 64 | ENS_UNKNOWN_PHONEME_FEATURE = 0x10000FFF, 65 | ENS_UNKNOWN_TEXT_ENCODING = 0x100010FF, 66 | } espeak_ng_STATUS; 67 | 68 | typedef enum { 69 | ENOUTPUT_MODE_SYNCHRONOUS = 0x0001, 70 | ENOUTPUT_MODE_SPEAK_AUDIO = 0x0002, 71 | } espeak_ng_OUTPUT_MODE; 72 | 73 | typedef enum { 74 | ENGENDER_UNKNOWN = 0, 75 | ENGENDER_MALE = 1, 76 | ENGENDER_FEMALE = 2, 77 | ENGENDER_NEUTRAL = 3, 78 | } espeak_ng_VOICE_GENDER; 79 | 80 | typedef struct 81 | { 82 | void (*outputPhoSymbol)(char* pho_code,int pho_type); 83 | void (*outputSilence)(short echo_tail); 84 | void (*outputVoiced)(short sample); 85 | void (*outputUnvoiced)(short sample); 86 | } espeak_ng_OUTPUT_HOOKS; 87 | 88 | /* eSpeak NG 1.49.0 */ 89 | 90 | typedef struct espeak_ng_ERROR_CONTEXT_ *espeak_ng_ERROR_CONTEXT; 91 | 92 | ESPEAK_NG_API void 93 | espeak_ng_ClearErrorContext(espeak_ng_ERROR_CONTEXT *context); 94 | 95 | ESPEAK_NG_API void 96 | espeak_ng_GetStatusCodeMessage(espeak_ng_STATUS status, 97 | char *buffer, 98 | size_t length); 99 | 100 | ESPEAK_NG_API void 101 | espeak_ng_PrintStatusCodeMessage(espeak_ng_STATUS status, 102 | FILE *out, 103 | espeak_ng_ERROR_CONTEXT context); 104 | 105 | ESPEAK_NG_API void 106 | espeak_ng_InitializePath(const char *path); 107 | 108 | ESPEAK_NG_API espeak_ng_STATUS 109 | espeak_ng_Initialize(espeak_ng_ERROR_CONTEXT *context); 110 | 111 | ESPEAK_NG_API espeak_ng_STATUS 112 | espeak_ng_InitializeOutput(espeak_ng_OUTPUT_MODE output_mode, 113 | int buffer_length, 114 | const char *device); 115 | 116 | ESPEAK_NG_API int 117 | espeak_ng_GetSampleRate(void); 118 | 119 | ESPEAK_NG_API espeak_ng_STATUS 120 | espeak_ng_SetParameter(espeak_PARAMETER parameter, 121 | int value, 122 | int relative); 123 | 124 | ESPEAK_NG_API espeak_ng_STATUS 125 | espeak_ng_SetPhonemeEvents(int enable, int ipa); 126 | 127 | ESPEAK_NG_API espeak_ng_STATUS 128 | espeak_ng_SetPunctuationList(const wchar_t *punctlist); 129 | 130 | ESPEAK_NG_API espeak_ng_STATUS 131 | espeak_ng_SetVoiceByName(const char *name); 132 | 133 | ESPEAK_NG_API espeak_ng_STATUS 134 | espeak_ng_SetVoiceByFile(const char *filename); 135 | 136 | ESPEAK_NG_API espeak_ng_STATUS 137 | espeak_ng_SetVoiceByProperties(espeak_VOICE *voice_selector); 138 | 139 | ESPEAK_NG_API espeak_ng_STATUS 140 
| espeak_ng_Synthesize(const void *text, 141 | size_t size, 142 | unsigned int position, 143 | espeak_POSITION_TYPE position_type, 144 | unsigned int end_position, 145 | unsigned int flags, 146 | unsigned int *unique_identifier, 147 | void *user_data); 148 | 149 | ESPEAK_NG_API espeak_ng_STATUS 150 | espeak_ng_SynthesizeMark(const void *text, 151 | size_t size, 152 | const char *index_mark, 153 | unsigned int end_position, 154 | unsigned int flags, 155 | unsigned int *unique_identifier, 156 | void *user_data); 157 | 158 | ESPEAK_NG_API espeak_ng_STATUS 159 | espeak_ng_SpeakKeyName(const char *key_name); 160 | 161 | ESPEAK_NG_API espeak_ng_STATUS 162 | espeak_ng_SpeakCharacter(wchar_t character); 163 | 164 | ESPEAK_NG_API espeak_ng_STATUS 165 | espeak_ng_Cancel(void); 166 | 167 | ESPEAK_NG_API espeak_ng_STATUS 168 | espeak_ng_Synchronize(void); 169 | 170 | ESPEAK_NG_API espeak_ng_STATUS 171 | espeak_ng_Terminate(void); 172 | 173 | ESPEAK_NG_API espeak_ng_STATUS 174 | espeak_ng_CompileDictionary(const char *dsource, 175 | const char *dict_name, 176 | FILE *log, 177 | int flags, 178 | espeak_ng_ERROR_CONTEXT *context); 179 | 180 | ESPEAK_NG_API espeak_ng_STATUS 181 | espeak_ng_CompileMbrolaVoice(const char *path, 182 | FILE *log, 183 | espeak_ng_ERROR_CONTEXT *context); 184 | 185 | ESPEAK_NG_API espeak_ng_STATUS 186 | espeak_ng_CompilePhonemeData(long rate, 187 | FILE *log, 188 | espeak_ng_ERROR_CONTEXT *context); 189 | 190 | ESPEAK_NG_API espeak_ng_STATUS 191 | espeak_ng_CompileIntonation(FILE *log, 192 | espeak_ng_ERROR_CONTEXT *context); 193 | 194 | 195 | ESPEAK_NG_API espeak_ng_STATUS 196 | espeak_ng_CompileIntonationPath(const char *source_path, 197 | const char *destination_path, 198 | FILE *log, 199 | espeak_ng_ERROR_CONTEXT *context); 200 | 201 | /* eSpeak NG 1.49.1 */ 202 | 203 | ESPEAK_NG_API espeak_ng_STATUS 204 | espeak_ng_CompilePhonemeDataPath(long rate, 205 | const char *source_path, 206 | const char *destination_path, 207 | FILE *log, 208 | espeak_ng_ERROR_CONTEXT *context); 209 | 210 | ESPEAK_NG_API espeak_ng_STATUS 211 | espeak_ng_SetOutputHooks(espeak_ng_OUTPUT_HOOKS* hooks); 212 | ESPEAK_NG_API espeak_ng_STATUS 213 | espeak_ng_SetConstF0(int f0); 214 | 215 | ESPEAK_NG_API espeak_ng_STATUS 216 | espeak_ng_SetRandSeed(long seed); 217 | 218 | 219 | #ifdef __cplusplus 220 | } 221 | #endif 222 | 223 | #endif 224 | --------------------------------------------------------------------------------