├── vall_e
│   ├── __init__.py
│   ├── emb
│   │   ├── __init__.py
│   │   ├── codecs
│   │   │   ├── __init__.py
│   │   │   ├── vocos.py
│   │   │   ├── encodec.py
│   │   │   └── dac.py
│   │   └── g2p.py
│   ├── utils
│   │   ├── ext
│   │   │   ├── __init__.py
│   │   │   ├── unsloth.py
│   │   │   └── muon.py
│   │   ├── __init__.py
│   │   ├── distributed.py
│   │   ├── io.py
│   │   ├── sampler.py
│   │   └── ml.py
│   ├── metrics.py
│   ├── models
│   │   ├── arch
│   │   │   └── __init__.py
│   │   ├── __init__.py
│   │   └── lora.py
│   ├── plot.py
│   ├── engines
│   │   └── deepspeed.py
│   └── __main__.py
├── test.wav
├── vall-e.png
├── data
│   ├── qnt.dac
│   ├── qnt.enc
│   ├── qnt.nem
│   ├── noise.enc
│   ├── demo
│   │   └── index.template.html
│   └── tongue_twisters.txt
├── scripts
│   ├── run.sh
│   ├── setup.sh
│   ├── process_nscripter.py
│   ├── prepare_librilight.py
│   ├── deduplicate_librilight_libritts.py
│   ├── process_seed-tts.py
│   ├── parse_ppp.py
│   ├── cleanup_dataset.py
│   ├── train_tokenizer.py
│   ├── process_emilia.py
│   └── process_libritts.py
├── docs
│   ├── demo.md
│   ├── plot.md
│   ├── export.md
│   ├── metrics.md
│   ├── engines.md
│   ├── utils.md
│   ├── webui.md
│   ├── inferenece.md
│   └── samplers.md
├── .gitignore
├── vall_e.cpp
│   ├── include
│   │   ├── ggml-blas.h
│   │   ├── ggml-opencl.h
│   │   ├── ops.h
│   │   ├── llama-cpp.h
│   │   ├── ggml-vulkan.h
│   │   ├── ggml-rpc.h
│   │   ├── utils.h
│   │   ├── ggml-kompute.h
│   │   ├── ggml-cuda.h
│   │   ├── ggml-cpp.h
│   │   ├── ggml-sycl.h
│   │   ├── llama-impl.h
│   │   ├── ggml-metal.h
│   │   ├── lstm.h
│   │   ├── ggml-alloc.h
│   │   ├── encoder.h
│   │   ├── decoder.h
│   │   ├── espeak-ng
│   │   │   ├── encoding.h
│   │   │   └── espeak_ng.h
│   │   ├── quantizer.h
│   │   ├── llama-vocab.h
│   │   ├── ggml-cann.h
│   │   ├── ggml-cpu.h
│   │   └── encodec.h
│   ├── Makefile
│   ├── README.md
│   ├── vall_e-impl.h
│   └── vall_e.h
├── README.md
└── setup.py
/vall_e/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/vall_e/emb/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/vall_e/utils/ext/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/vall_e/emb/codecs/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/vall_e/emb/codecs/vocos.py:
--------------------------------------------------------------------------------
1 | from vocos import Vocos
--------------------------------------------------------------------------------
/test.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-c-k-e-r/vall-e/HEAD/test.wav
--------------------------------------------------------------------------------
/vall-e.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-c-k-e-r/vall-e/HEAD/vall-e.png
--------------------------------------------------------------------------------
/data/qnt.dac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-c-k-e-r/vall-e/HEAD/data/qnt.dac
--------------------------------------------------------------------------------
/data/qnt.enc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-c-k-e-r/vall-e/HEAD/data/qnt.enc
--------------------------------------------------------------------------------
/data/qnt.nem:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-c-k-e-r/vall-e/HEAD/data/qnt.nem
--------------------------------------------------------------------------------
/data/noise.enc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-c-k-e-r/vall-e/HEAD/data/noise.enc
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # re-run the given command until it exits successfully
4 | until "$@"; do echo retrying; done
5 |
--------------------------------------------------------------------------------
/vall_e/emb/codecs/encodec.py:
--------------------------------------------------------------------------------
1 | from encodec import EncodecModel
2 | from encodec.utils import convert_audio
--------------------------------------------------------------------------------
/docs/demo.md:
--------------------------------------------------------------------------------
1 | # `demo.py`
2 |
3 | This script handles generating demo pages for comparing against other TTS solutions.
4 |
5 | As this is for my own internal use, documentation at the moment is sparse.
6 |
7 | To be filled.
--------------------------------------------------------------------------------
/data/demo/index.template.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | VALL-E Demo
7 | ${PREAMBLE}
8 | ${TABLES}
9 | Settings used:
${SETTINGS}
10 |
11 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | /data
3 | /training
4 | /venv
5 | /*.egg-info
6 | /vall_e/version.py
7 | /.cache
8 | /voices
9 | /wandb
10 | /.nltk
11 | /vall_e.cpp/data
12 | /vall_e.cpp/include
13 | /vall_e.cpp/lib
14 | /vall_e.cpp/*.o
15 | /vall_e.cpp/vall_e
16 |
--------------------------------------------------------------------------------
/docs/plot.md:
--------------------------------------------------------------------------------
1 | # `plot.py`
2 |
3 | Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot --yaml="./training/config.yaml"`
4 |
5 | You can specify what X and Y labels you want to plot against by passing `--xs tokens_processed --ys loss.nll stats.acc`
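For example, plotting loss and accuracy against the number of tokens processed:

```sh
python3 -m vall_e.plot --yaml="./training/config.yaml" --xs tokens_processed --ys loss.nll stats.acc
```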
--------------------------------------------------------------------------------
/vall_e/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import (
2 | dispatch_attribute,
3 | flatten_dict,
4 | gather_attribute,
5 | load_state_dict_non_strict,
6 | setup_logging,
7 | to_device,
8 | tree_map,
9 | do_gc,
10 | set_seed,
11 | passes_policy,
12 | get_devices,
13 | truncate_json,
14 | timer,
15 | prune_missing,
16 | clamp,
17 | md5_hash,
18 | convert_kwargs,
19 | coerce_dtype,
20 | mean,
21 | logit_normalization,
22 | )
--------------------------------------------------------------------------------
/scripts/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python3 -m venv venv
4 | source ./venv/bin/activate
5 | pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 # or cu118 / cu124
6 | pip3 install -e .
7 |
8 | mkdir -p ./training/valle/ckpt/ar+nar-llama-8/
9 | wget -P ./training/valle/ckpt/ar+nar-llama-8/ "https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-llama-8/fp32.sft"
10 | wget -P ./training/valle/ "https://huggingface.co/ecker/vall-e/resolve/main/models/config.llama.yaml"
11 |
--------------------------------------------------------------------------------
/docs/export.md:
--------------------------------------------------------------------------------
1 | # `export.py`
2 |
3 | To export the models, run: `python -m vall_e.export --yaml=./training/config.yaml`.
4 |
5 | This will export the latest checkpoints, for example, under `./training/ckpt/ar+nar-retnet-8/fp32.pth`, to be loaded on any system with PyTorch, and will include additional metadata, such as the symmap used, and training stats.
6 |
7 | Despite being called `fp32.sft` or `fp32.pth`, you can export it to a different precision type with `--dtype=float16|bfloat16|float32`.
8 |
9 | You can also export to `safetensors` with `--format=sft`, and `fp32.sft` will be exported instead.
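For example, exporting the latest checkpoint as a `bfloat16` `safetensors` file (paths are illustrative) might look like:

```sh
python -m vall_e.export --yaml=./training/config.yaml --dtype=bfloat16 --format=sft
```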
--------------------------------------------------------------------------------
/vall_e.cpp/include/ggml-blas.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ggml.h"
4 | #include "ggml-backend.h"
5 |
6 |
7 | #ifdef __cplusplus
8 | extern "C" {
9 | #endif
10 |
11 | // backend API
12 | GGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void);
13 |
14 | GGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend);
15 |
16 | // number of threads used for conversion to float
17 | // for openblas and blis, this will also set the number of threads used for blas operations
18 | GGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
19 |
20 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void);
21 |
22 |
23 | #ifdef __cplusplus
24 | }
25 | #endif
26 |
--------------------------------------------------------------------------------
/vall_e.cpp/include/ggml-opencl.h:
--------------------------------------------------------------------------------
1 | #ifndef GGML_OPENCL_H
2 | #define GGML_OPENCL_H
3 |
4 | #include "ggml.h"
5 | #include "ggml-backend.h"
6 |
7 | #ifdef __cplusplus
8 | extern "C" {
9 | #endif
10 |
11 | //
12 | // backend API
13 | //
14 | GGML_BACKEND_API ggml_backend_t ggml_backend_opencl_init(void);
15 | GGML_BACKEND_API bool ggml_backend_is_opencl(ggml_backend_t backend);
16 |
17 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void);
18 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void);
19 |
20 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_opencl_reg(void);
21 |
22 | #ifdef __cplusplus
23 | }
24 | #endif
25 |
26 | #endif // GGML_OPENCL_H
27 |
--------------------------------------------------------------------------------
/vall_e.cpp/include/ops.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ggml.h"
4 |
5 | struct ggml_tensor *pad_1d(struct ggml_context *ctx0, struct ggml_tensor *inp,
6 | int padding_left, int padding_right);
7 |
8 | struct ggml_tensor *unpad_1d(struct ggml_context *ctx0, struct ggml_tensor *inp,
9 | int padding_left, int padding_right);
10 |
11 | struct ggml_tensor *strided_conv_1d(struct ggml_context *ctx0, struct ggml_tensor *inp,
12 | struct ggml_tensor *conv_w, struct ggml_tensor *conv_b,
13 | int stride);
14 |
15 | struct ggml_tensor *strided_conv_transpose_1d(struct ggml_context *ctx0, struct ggml_tensor *inp,
16 | struct ggml_tensor *conv_w, struct ggml_tensor *conv_b,
17 | int stride);
18 |
--------------------------------------------------------------------------------
/vall_e.cpp/include/llama-cpp.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #ifndef __cplusplus
4 | #error "This header is for C++ only"
5 | #endif
6 |
7 | #include <memory>
8 |
9 | #include "llama.h"
10 |
11 | struct llama_model_deleter {
12 | void operator()(llama_model * model) { llama_model_free(model); }
13 | };
14 |
15 | struct llama_context_deleter {
16 | void operator()(llama_context * context) { llama_free(context); }
17 | };
18 |
19 | struct llama_sampler_deleter {
20 | void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); }
21 | };
22 |
23 | struct llama_adapter_lora_deleter {
24 | void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
25 | };
26 |
27 | typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
28 | typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
29 | typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
30 | typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;
31 |
--------------------------------------------------------------------------------
/vall_e.cpp/Makefile:
--------------------------------------------------------------------------------
1 | ifeq ($(PREFIX),)
2 | PREFIX := /usr/local
3 | endif
4 |
5 | CXX = g++
6 |
7 | INCS += -I./include
8 | LIBS += -L./lib
9 |
10 | LINKS += -lggml -lggml-base -lllama -lencodec -lespeak-ng
11 | FLAGS += -march=native -O3 -DVALL_E_EXPORTS
12 |
13 | SRCS := $(shell find ./ -name "*.cpp")
14 | OBJS += $(patsubst %.cpp,%.o,$(SRCS))
15 |
16 | TARGET = vall_e
17 | TARGET_LIB = lib$(TARGET).so
18 | TARGET_HEADER = $(TARGET).h
19 |
20 |
21 | %.o: %.cpp
22 | $(CXX) $(FLAGS) $(INCS) -c $< -o $@
23 |
24 | $(TARGET): $(OBJS)
25 | $(CXX) $(FLAGS) $(OBJS) $(LIBS) $(INCS) $(LINKS) -o $(TARGET)
26 |
27 | $(TARGET_LIB): $(OBJS)
28 | $(CXX) $(FLAGS) $(OBJS) $(LIBS) $(INCS) $(LINKS) -o $(TARGET_LIB)
29 |
30 | all: $(TARGET_LIB) $(TARGET)
31 |
32 | lib: $(TARGET_LIB)
33 |
34 | install:
35 | cp $(TARGET) $(PREFIX)/bin/$(TARGET)
36 | -cp $(TARGET_LIB) $(PREFIX)/lib/$(TARGET_LIB)
37 | cp $(TARGET_HEADER) $(PREFIX)/include/$(TARGET_HEADER)
38 |
39 | clean:
40 | @-rm -f $(OBJS)
41 | @-rm -f $(TARGET)
42 | @-rm -f $(TARGET_LIB)
--------------------------------------------------------------------------------
/vall_e.cpp/include/ggml-vulkan.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ggml.h"
4 | #include "ggml-backend.h"
5 |
6 | #ifdef __cplusplus
7 | extern "C" {
8 | #endif
9 |
10 | #define GGML_VK_NAME "Vulkan"
11 | #define GGML_VK_MAX_DEVICES 16
12 |
13 | // backend API
14 | GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
15 |
16 | GGML_BACKEND_API bool ggml_backend_is_vk(ggml_backend_t backend);
17 | GGML_BACKEND_API int ggml_backend_vk_get_device_count(void);
18 | GGML_BACKEND_API void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
19 | GGML_BACKEND_API void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
20 |
21 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
22 | // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
23 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
24 |
25 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_vk_reg(void);
26 |
27 | #ifdef __cplusplus
28 | }
29 | #endif
30 |
--------------------------------------------------------------------------------
/vall_e.cpp/include/ggml-rpc.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ggml.h"
4 | #include "ggml-backend.h"
5 |
6 | #ifdef __cplusplus
7 | extern "C" {
8 | #endif
9 |
10 | #define GGML_RPC_MAX_SERVERS 16
11 |
12 | // backend API
13 | GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
14 | GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
15 |
16 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
17 |
18 | GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
19 |
20 | GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
21 | const char * cache_dir,
22 | size_t free_mem, size_t total_mem);
23 |
24 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
25 |
26 | GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
27 |
28 | #ifdef __cplusplus
29 | }
30 | #endif
31 |
--------------------------------------------------------------------------------
/vall_e.cpp/include/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <fstream> // std::ifstream
4 |
5 | #define MAX(a, b) ((a) > (b) ? (a) : (b))
6 | #define MIN(a, b) ((a) < (b) ? (a) : (b))
7 |
8 | const size_t MB = 1024 * 1024;
9 |
10 | template <typename T>
11 | void read_safe(std::ifstream &infile, T &dest) {
12 | infile.read((char *)&dest, sizeof(T));
13 | }
14 |
15 | int32_t get_num_codebooks(float bandwidth, int hop_length, float sample_rate) {
16 | // The number of codebooks is determined by the bandwidth selected.
17 | // Supported bandwidths are 1.5kbps (n_q = 2), 3 kbps (n_q = 4), 6 kbps (n_q = 8),
18 | // 12 kbps (n_q = 16) and 24kbps (n_q = 32).
19 | return (int32_t)ceilf(1000 * bandwidth / (ceilf(sample_rate / hop_length) * 10));
20 | }
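// worked example (assuming EnCodec's usual 24 kHz / hop-length-320 configuration): the frame rate is
// ceilf(24000 / 320) = 75 frames/s, so each 10-bit codebook costs 750 bits/s, and a 6 kbps bandwidth
// yields ceilf(1000 * 6 / 750) = 8 codebooks, matching the table above.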
21 |
22 | int32_t get_bandwidth_per_quantizer(int bins, float frame_rate) {
23 | return log2f((float)bins) * frame_rate;
24 | }
25 |
26 | int32_t get_num_quantizers_for_bandwidth(int bins, float frame_rate, float bandwidth) {
27 | float bw_per_q = get_bandwidth_per_quantizer(bins, frame_rate);
28 | int32_t n_q = MAX(1, floorf(bandwidth * 1000 / bw_per_q));
29 | return n_q;
30 | }
31 |
--------------------------------------------------------------------------------
/scripts/process_nscripter.py:
--------------------------------------------------------------------------------
1 | """
2 | Handles processing NScripter's 0.u file to clean up the pile of audio clips it has
3 |
4 | * to-do: also grab transcriptions
5 | """
6 |
7 | import os
8 | import re
9 | import json
10 | import argparse
11 | import torch
12 | import shutil
13 | import torchaudio
14 | import numpy as np
15 |
16 | from tqdm.auto import tqdm
17 | from pathlib import Path
18 |
19 | def process(
20 | input_file=Path("./assets/0.u"),
21 | wav_dir=Path("./arc/"),
22 | output_dir=Path("./dataset/"),
23 | ):
24 | file = open(input_file, encoding='utf-8').read()
25 |
26 | names = {}
27 | aliases = {}
28 | lines = file.split('\n')
29 |
30 | for line in lines:
31 | if not line.startswith('stralias'):
32 | continue
33 | # ick
34 | try:
35 | key, path = re.findall(r'^stralias (.+?),"(.+?)"$', line)[0]
36 | name = key.split("_")[0]
37 | if name not in names:
38 | (output_dir / name).mkdir(parents=True, exist_ok=True)
39 | names[name] = True
40 |
41 | aliases[key] = Path(path)
42 | except Exception as e:
43 | pass
44 |
45 | for k, v in aliases.items():
46 | name = k.split("_")[0]
47 |
48 |
49 | print(aliases)
50 |
51 | if __name__ == "__main__":
52 | process()
--------------------------------------------------------------------------------
/vall_e.cpp/include/ggml-kompute.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ggml.h"
4 | #include "ggml-backend.h"
5 |
6 | #include <stdbool.h>
7 | #include <stddef.h>
8 | #include <stdint.h>
9 |
10 | #ifdef __cplusplus
11 | extern "C" {
12 | #endif
13 |
14 | #define GGML_KOMPUTE_MAX_DEVICES 16
15 |
16 | struct ggml_vk_device {
17 | int index;
18 | int type; // same as VkPhysicalDeviceType
19 | size_t heapSize;
20 | const char * name;
21 | const char * vendor;
22 | int subgroupSize;
23 | uint64_t bufferAlignment;
24 | uint64_t maxAlloc;
25 | };
26 |
27 | struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
28 | bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
29 | bool ggml_vk_has_vulkan(void);
30 | bool ggml_vk_has_device(void);
31 | struct ggml_vk_device ggml_vk_current_device(void);
32 |
33 | //
34 | // backend API
35 | //
36 |
37 | // forward declaration
38 | typedef struct ggml_backend * ggml_backend_t;
39 |
40 | GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device);
41 |
42 | GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend);
43 |
44 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
45 |
46 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
47 |
48 | #ifdef __cplusplus
49 | }
50 | #endif
51 |
--------------------------------------------------------------------------------
/scripts/prepare_librilight.py:
--------------------------------------------------------------------------------
1 | """
2 | # Handles processing `facebookresearch/libri-light`'s unlabeled audio into a friendlier hierarchy
3 | """
4 |
5 | import os
6 | import json
7 |
8 | datasets = ["small", "medium", "large", "duplicate"]
9 | output_dataset = "LibriLight-4K"
10 |
11 | for input_dataset in datasets:
12 | if not os.path.isdir(f'./{input_dataset}/'):
13 | continue
14 |
15 | for speaker_id in os.listdir(f'./{input_dataset}/'):
16 | if not os.path.isdir(f'./{input_dataset}/{speaker_id}/'):
17 | continue
18 |
19 | for book_name in os.listdir(f'./{input_dataset}/{speaker_id}/'):
20 | subid = 0
21 |
22 | for filename in os.listdir(f'./{input_dataset}/{speaker_id}/{book_name}'):
23 | if filename[-5:] != ".json":
24 | continue
25 |
26 | basename = filename[:-5]
27 |
28 | json_path = f'./{input_dataset}/{speaker_id}/{book_name}/{basename}.json'
29 | flac_path = f'./{input_dataset}/{speaker_id}/{book_name}/{basename}.flac'
30 |
31 | j = json.load(open(json_path, 'r', encoding="utf-8"))
32 | id = j['book_meta']['id']
33 |
34 | json_id_path = f'./{output_dataset}/{speaker_id}/{speaker_id}_{id}_{subid}.json'
35 | flac_id_path = f'./{output_dataset}/{speaker_id}/{speaker_id}_{id}_{subid}.flac'
36 |
37 | os.makedirs(f'./{output_dataset}/{speaker_id}/', exist_ok=True)
38 | os.rename(json_path, json_id_path)
39 | os.rename(flac_path, flac_id_path)
40 |
41 | subid += 1
42 |
--------------------------------------------------------------------------------
/data/tongue_twisters.txt:
--------------------------------------------------------------------------------
1 | Six sick hicks nick six slick bricks with picks and sticks.
2 | Fresh French fried fly fritters.
3 | Rory the warrior and Roger the worrier were reared wrongly in a rural brewery.
4 | Which wrist watches are Swiss wrist watches?
5 | Fred fed Ted bread and Ted fed Fred bread.
6 | The 33 thieves thought that they thrilled the throne throughout Thursday.
7 | You know New York, you need New York, you know you need unique New York.
8 | Lesser leather never weathered wetter weather better.
9 | The sixth sick sheikh’s sixth sheep’s sick.
10 | A skunk sat on a stump and thunk the stump stunk, but the stump thunk the skunk stunk.
11 | Thirty-three thirsty, thundering thoroughbreds thumped Mr. Thurber on Thursday.
12 | Wayne went to wales to watch walruses.
13 | Seventy-seven benevolent elephants.
14 | Send toast to ten tense stout saints’ ten tall tents.
15 | I slit the sheet, the sheet I slit, and on the slitted sheet I sit.
16 | Give papa a cup of proper coffee in a copper coffee cup.
17 | She sells seashells by the seashore.
18 | Peter Piper picked a peck of pickled peppers. How many pickled peppers did Peter Piper pick?
19 | Pad kid poured curd pulled cod.
20 | Fuzzy Wuzzy was a bear. Fuzzy Wuzzy had no hair. Fuzzy Wuzzy wasn’t very fuzzy, was he?
21 | Supercalifragilisticexpialidocious.
22 | How much wood would a woodchuck chuck if a woodchuck could chuck wood? He would chuck, he would, as much as he could, and chuck as much wood as a woodchuck would if a woodchuck could chuck wood.
23 | Buffalo buffalo Buffalo buffalo buffalo buffalo Buffalo buffalo.
--------------------------------------------------------------------------------
/docs/metrics.md:
--------------------------------------------------------------------------------
1 | # `metrics.py`
2 |
3 | This file provides helper functions for computing objective metrics, such as word-error rate (WER), character-error rate (CER), phoneme-error rate (PER), and speaker similarity (SIM-O).
4 |
5 | ## WER / CER
6 |
7 | Word-error rate (WER) is simply computed by transcribing the requested input, and comparing its transcription against the target transcription.
8 | * The transcription is cleaned up and normalized to account for inconsistencies between `openai/whisper-large-v3` transcriptions and the nuances of English.
9 | * Languages without spaces between words (Chinese, Japanese) should not rely on this, and instead rely on the CER.
10 |
11 | Character-error rate (CER) does the same thing as WER, but on a character basis rather than a word basis.
12 |
13 | Phoneme-error rate (PER) does the same thing as CER, but on the phonemized transcription instead. As this is a speech model, this metric is more correct than the prior metrics, but this isn't a universal metric for comparison, as most models don't report this.
14 |
15 | All rates are left un-normalized, as I think that's the right way to go about it; papers aren't clear on whether they do the same, but the error rates come out even more unusually low without un-normalizing.
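As a rough sketch of the un-normalization (mirroring `vall_e/metrics.py`, with made-up strings):

```python
from torcheval.metrics.functional import word_error_rate

transcription = "the quick brown fox"
reference = "the quick brown foxes"

# torcheval returns a rate normalized by the reference length;
# scaling by the reference word count recovers the raw error count
wer_score = word_error_rate([transcription], [reference]).item() * len(reference.split())
```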
16 |
17 | ## SIM-O
18 |
19 | Speaker similarity (SIM-O) is computed by obtaining the embedding of each speaker (the output audio and the input prompt), and computing the cosine similarity between those two embeddings.
20 |
21 | These embeddings are obtained through a finetune of WavLM-large geared towards speaker verification.
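A minimal sketch of the comparison itself, with stand-in embeddings (the real ones come from that speaker-verification model, via `vall_e.emb.similar`):

```python
import torch
import torch.nn.functional as F

audio_emb = torch.randn(256)      # embedding of the output audio (stand-in)
reference_emb = torch.randn(256)  # embedding of the input prompt (stand-in)

sim_o = F.cosine_similarity(audio_emb, reference_emb, dim=-1).item()
```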
--------------------------------------------------------------------------------
/scripts/deduplicate_librilight_libritts.py:
--------------------------------------------------------------------------------
1 | """
2 | # Helper script to try and detect any duplications between LibriLight and LibriTTS (I don't think there were any)
3 | """
4 |
5 | import os
6 | import json
7 |
8 | librilight_dir = "LibriLight-6K"
9 | libritts_dir = "LibriTTS-Train"
10 |
11 | librilight_data = {}
12 | libritts_data = {}
13 |
14 | for speaker_id in os.listdir(f'./{librilight_dir}/'):
15 | for filename in os.listdir(f'./{librilight_dir}/{speaker_id}'):
16 | parts = filename.split("_")
17 | book_id = parts[1]
18 | subid = parts[2]
19 |
20 | if speaker_id not in librilight_data:
21 | librilight_data[speaker_id] = {}
22 | if book_id not in librilight_data[speaker_id]:
23 | librilight_data[speaker_id][book_id] = []
24 | librilight_data[speaker_id][book_id].append(subid)
25 |
26 | for speaker_id in os.listdir(f'./{libritts_dir}/'):
27 | for filename in os.listdir(f'./{libritts_dir}/{speaker_id}'):
28 | parts = filename.split("_")
29 | book_id = parts[1]
30 | subid = parts[2]
31 |
32 | if speaker_id not in libritts_data:
33 | libritts_data[speaker_id] = {}
34 | if book_id not in libritts_data[speaker_id]:
35 | libritts_data[speaker_id][book_id] = []
36 | libritts_data[speaker_id][book_id].append(subid)
37 |
38 | duplicates = []
39 |
40 | for speaker_id, books in libritts_data.items():
41 | if speaker_id not in librilight_data:
42 | continue
43 | for book_id, _ in books.items():
44 | if book_id not in librilight_data[speaker_id]:
45 | continue
46 | print(f'Duplicate: {speaker_id}/{book_id}')
47 | duplicates.append(f'{speaker_id}/{book_id}')
48 |
49 | print("Duplicates:", duplicates)
--------------------------------------------------------------------------------
/docs/engines.md:
--------------------------------------------------------------------------------
1 | # `engines/*`
2 |
3 | This folder contains the necessary abstractions for handling training of models through either a local (`base`) backend, or additional wrappers (like DeepSpeed, and in the future Accelerate and Lightning).
4 |
5 | This architecture is partially lifted from the original implementation, but expanded for both my needs and modularity for other backends.
6 |
7 | An `Engine` is just a wrapper that contains training metadata for the loaded module.
8 |
9 | An `Engines` is a dict of `Engine`s with extra functions for iterating through its contents, allowing engines to be loaded and trained simultaneously against a shared dataloader iteration (see the sketch below).
10 |
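As a conceptual sketch only (the real classes live under `./vall_e/engines/` and carry far more state):

```python
class Engine:
    def __init__(self, module, optimizer=None):
        self.module = module        # the wrapped model
        self.optimizer = optimizer  # plus training metadata (step count, stats, etc.)

class Engines(dict):
    def step(self, batch):
        # every engine sees the same dataloader iteration
        for name, engine in self.items():
            engine.module(**batch)
```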
11 | ## `__init__.py`
12 |
13 | This script handles the bulk of loading a model and wrapping the model with the requested engine type.
14 |
15 | The checkpoint or weight path is automatically deduced, and the state dict is pre-processed (if requested) before loading it.
16 | * resizing modules from the weights to the requested configuration in the YAML is done here.
17 | * replacing modules with quantized versions or LoRAs is applied here.
18 | * the requested optimizer, and the parameters to freeze, for a model are applied here.
19 |
20 | ## `base.py`
21 |
22 | The internal (`local`) implementation for orchestrating training. The basics are handled here: automatic mixed precision, gradient accumulation, loss scaling, and so on.
23 |
24 | Functions for other backends are also defined here, such as the training step function.
25 |
26 | ## `deepspeed.py`
27 |
28 | A backend relying on `deepspeed` for its orchestration, which offers additional features that can be defined under `cfg.trainer.deepspeed`.
--------------------------------------------------------------------------------
/scripts/process_seed-tts.py:
--------------------------------------------------------------------------------
1 | """
2 | Handles processing seed-tts-eval's dataset into something to be used for vall_e.demo
3 |
4 | Reads from meta.lst, a text file where each utterance is formatted as:
5 |
6 | <filename>|<prompt_text>|<prompt_wav>|<text>
7 | """
8 |
9 | import os
10 | import json
11 | import argparse
12 | import torch
13 | import shutil
14 | import torchaudio
15 | import numpy as np
16 |
17 | from tqdm.auto import tqdm
18 | from pathlib import Path
19 |
20 | def process(
21 | input_dir=Path("./seedtts_testset/zh/"),
22 | list_name="./hardcase.lst",
23 | wav_dir="./wavs/",
24 | output_dir=Path("./dataset/seed-tts-eval-hard/"),
25 | ):
26 | language = "auto"
27 |
28 | if "en" in str(input_dir):
29 | language = "en"
30 | elif "zh" in str(input_dir):
31 | language = "zh"
32 |
33 | output_dir.mkdir(parents=True, exist_ok=True)
34 |
35 | # read manifest
36 | lines = open(input_dir / list_name).read()
37 | lines = lines.split("\n")
38 | # split it even further
39 | for line in lines:
40 | if not line:
41 | continue
42 | filename, prompt_text, prompt_wav, text = line.split("|")
43 |
44 | (output_dir / filename / "out").mkdir(parents=True, exist_ok=True)
45 |
46 | open( output_dir / filename / "prompt.txt", "w", encoding="utf-8" ).write( text )
47 | open( output_dir / filename / "language.txt", "w", encoding="utf-8" ).write( language )
48 |
49 | reference_wav = (input_dir / wav_dir / filename).with_suffix(".wav")
50 | if not reference_wav.exists():
51 | continue
52 |
53 | shutil.copy(reference_wav, output_dir / filename / "reference.wav" )
54 | shutil.copy(input_dir / prompt_wav, output_dir / filename / "prompt.wav" )
55 |
56 | if __name__ == "__main__":
57 | process()
--------------------------------------------------------------------------------
/vall_e.cpp/include/ggml-cuda.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ggml.h"
4 | #include "ggml-backend.h"
5 |
6 | #ifdef __cplusplus
7 | extern "C" {
8 | #endif
9 |
10 | #ifdef GGML_USE_HIP
11 | #define GGML_CUDA_NAME "ROCm"
12 | #define GGML_CUBLAS_NAME "hipBLAS"
13 | #elif defined(GGML_USE_MUSA)
14 | #define GGML_CUDA_NAME "MUSA"
15 | #define GGML_CUBLAS_NAME "muBLAS"
16 | #else
17 | #define GGML_CUDA_NAME "CUDA"
18 | #define GGML_CUBLAS_NAME "cuBLAS"
19 | #endif
20 | #define GGML_CUDA_MAX_DEVICES 16
21 |
22 | // backend API
23 | GGML_BACKEND_API ggml_backend_t ggml_backend_cuda_init(int device);
24 |
25 | GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend);
26 |
27 | // device buffer
28 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
29 |
30 | // split tensor buffer that splits matrices by rows across multiple devices
31 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
32 |
33 | // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
34 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
35 |
36 | GGML_BACKEND_API int ggml_backend_cuda_get_device_count(void);
37 | GGML_BACKEND_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
38 | GGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
39 |
40 | GGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
41 | GGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer);
42 |
43 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void);
44 |
45 | #ifdef __cplusplus
46 | }
47 | #endif
48 |
--------------------------------------------------------------------------------
/vall_e.cpp/include/ggml-cpp.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #ifndef __cplusplus
4 | #error "This header is for C++ only"
5 | #endif
6 |
7 | #include "ggml.h"
8 | #include "ggml-alloc.h"
9 | #include "ggml-backend.h"
10 | #include "gguf.h"
11 | #include <memory>
12 |
13 | // Smart pointers for ggml types
14 |
15 | // ggml
16 |
17 | struct ggml_context_deleter { void operator()(ggml_context * ctx) { ggml_free(ctx); } };
18 | struct gguf_context_deleter { void operator()(gguf_context * ctx) { gguf_free(ctx); } };
19 |
20 | typedef std::unique_ptr<ggml_context, ggml_context_deleter> ggml_context_ptr;
21 | typedef std::unique_ptr<gguf_context, gguf_context_deleter> gguf_context_ptr;
22 |
23 | // ggml-alloc
24 |
25 | struct ggml_gallocr_deleter { void operator()(ggml_gallocr_t galloc) { ggml_gallocr_free(galloc); } };
26 |
27 | typedef std::unique_ptr<ggml_gallocr, ggml_gallocr_deleter> ggml_gallocr_ptr;
28 |
29 | // ggml-backend
30 |
31 | struct ggml_backend_deleter { void operator()(ggml_backend_t backend) { ggml_backend_free(backend); } };
32 | struct ggml_backend_buffer_deleter { void operator()(ggml_backend_buffer_t buffer) { ggml_backend_buffer_free(buffer); } };
33 | struct ggml_backend_event_deleter { void operator()(ggml_backend_event_t event) { ggml_backend_event_free(event); } };
34 | struct ggml_backend_sched_deleter { void operator()(ggml_backend_sched_t sched) { ggml_backend_sched_free(sched); } };
35 |
36 | typedef std::unique_ptr<ggml_backend, ggml_backend_deleter> ggml_backend_ptr;
37 | typedef std::unique_ptr<ggml_backend_buffer, ggml_backend_buffer_deleter> ggml_backend_buffer_ptr;
38 | typedef std::unique_ptr<ggml_backend_event, ggml_backend_event_deleter> ggml_backend_event_ptr;
39 | typedef std::unique_ptr<ggml_backend_sched, ggml_backend_sched_deleter> ggml_backend_sched_ptr;
40 |
--------------------------------------------------------------------------------
/vall_e.cpp/include/ggml-sycl.h:
--------------------------------------------------------------------------------
1 | //
2 | // MIT license
3 | // Copyright (C) 2024 Intel Corporation
4 | // SPDX-License-Identifier: MIT
5 | //
6 |
7 | #pragma once
8 |
9 | #include "ggml.h"
10 | #include "ggml-backend.h"
11 |
12 | #define GGML_SYCL_NAME "SYCL"
13 | #define GGML_SYCL_MAX_DEVICES 48
14 |
15 | #ifdef __cplusplus
16 | extern "C" {
17 | #endif
18 |
19 | // backend API
20 | GGML_BACKEND_API ggml_backend_t ggml_backend_sycl_init(int device);
21 |
22 | GGML_BACKEND_API bool ggml_backend_is_sycl(ggml_backend_t backend);
23 |
24 | // device buffer
25 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
26 |
27 | // split tensor buffer that splits matrices by rows across multiple devices
28 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
29 |
30 | // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
31 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
32 |
33 | GGML_BACKEND_API void ggml_backend_sycl_print_sycl_devices(void);
34 | GGML_BACKEND_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len);
35 | GGML_BACKEND_API void ggml_backend_sycl_get_device_description(int device,
36 | char *description,
37 | size_t description_size);
38 | GGML_BACKEND_API int ggml_backend_sycl_get_device_count();
39 | GGML_BACKEND_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
40 |
41 | // SYCL doesn't support registering host memory, keep here for reference
42 | // GGML_BACKEND_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
43 | // GGML_BACKEND_API void ggml_backend_sycl_unregister_host_buffer(void * buffer);
44 |
45 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_sycl_reg(void);
46 |
47 | #ifdef __cplusplus
48 | }
49 | #endif
50 |
--------------------------------------------------------------------------------
/vall_e.cpp/include/llama-impl.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ggml.h" // for ggml_log_level
4 |
5 | #include <string>
6 | #include <vector>
7 |
8 | #ifdef __GNUC__
9 | # if defined(__MINGW32__) && !defined(__clang__)
10 | # define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
11 | # else
12 | # define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
13 | # endif
14 | #else
15 | # define LLAMA_ATTRIBUTE_FORMAT(...)
16 | #endif
17 |
18 | //
19 | // logging
20 | //
21 |
22 | LLAMA_ATTRIBUTE_FORMAT(2, 3)
23 | void llama_log_internal (ggml_log_level level, const char * format, ...);
24 | void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data);
25 |
26 | #define LLAMA_LOG(...) llama_log_internal(GGML_LOG_LEVEL_NONE , __VA_ARGS__)
27 | #define LLAMA_LOG_INFO(...) llama_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__)
28 | #define LLAMA_LOG_WARN(...) llama_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__)
29 | #define LLAMA_LOG_ERROR(...) llama_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
30 | #define LLAMA_LOG_DEBUG(...) llama_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
31 | #define LLAMA_LOG_CONT(...) llama_log_internal(GGML_LOG_LEVEL_CONT , __VA_ARGS__)
32 |
33 | //
34 | // helpers
35 | //
36 |
37 | template <typename T>
38 | struct no_init {
39 | T value;
40 | no_init() { /* do nothing */ }
41 | };
42 |
43 | struct time_meas {
44 | time_meas(int64_t & t_acc, bool disable = false);
45 | ~time_meas();
46 |
47 | const int64_t t_start_us;
48 |
49 | int64_t & t_acc;
50 | };
51 |
52 | void replace_all(std::string & s, const std::string & search, const std::string & replace);
53 |
54 | // TODO: rename to llama_format ?
55 | LLAMA_ATTRIBUTE_FORMAT(1, 2)
56 | std::string format(const char * fmt, ...);
57 |
58 | std::string llama_format_tensor_shape(const std::vector<int64_t> & ne);
59 | std::string llama_format_tensor_shape(const struct ggml_tensor * t);
60 |
61 | std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
62 |
--------------------------------------------------------------------------------
/vall_e/metrics.py:
--------------------------------------------------------------------------------
1 | # handles objective metric calculations, such as WER and SIM-O
2 |
3 | #from .emb.transcribe import transcribe
4 | from .emb.similar import speaker_similarity_embedding
5 | from .emb.transcribe import transcribe
6 | from .emb.g2p import detect_language, coerce_to_hiragana, encode
7 | from .data import normalize_text
8 |
9 | import torch.nn.functional as F
10 |
11 | from pathlib import Path
12 | from torcheval.metrics.functional import word_error_rate
13 | from torchmetrics.functional.text import char_error_rate
14 |
15 | import warnings
16 | warnings.simplefilter(action='ignore', category=FutureWarning)
17 | warnings.simplefilter(action='ignore', category=UserWarning)
18 |
19 | def wer( audio, reference, language="auto", phonemize=True, normalize=True, **transcription_kwargs ):
20 | if language == "auto":
21 | language = detect_language( reference )
22 |
23 | transcription = transcribe( audio, language=language, align=False, **transcription_kwargs )
24 |
25 | if language == "auto":
26 | language = transcription["language"]
27 |
28 | transcription = transcription["text"]
29 |
30 | # reference audio needs transcribing too
31 | if isinstance( reference, Path ):
32 | reference = transcribe( reference, language=language, align=False, **transcription_kwargs )["text"]
33 |
34 | if language == "ja":
35 | transcription = coerce_to_hiragana( transcription )
36 | reference = coerce_to_hiragana( reference )
37 |
38 | if phonemize:
39 | transcription = encode( transcription, language=language )
40 | reference = encode( reference, language=language )
41 | elif normalize:
42 | transcription = normalize_text( transcription, language=language )
43 | reference = normalize_text( reference, language=language )
44 |
45 | wer_score = word_error_rate([transcription], [reference]).item()
46 | # un-normalize
47 | wer_score *= len(reference.split())
48 |
49 | cer_score = char_error_rate([transcription], [reference]).item()
50 | # un-normalize
51 | cer_score *= len(reference)
52 |
53 | return wer_score, cer_score
54 |
55 | def sim_o( audio, reference, **kwargs ):
56 | audio_emb = speaker_similarity_embedding( audio, **kwargs )
57 | reference_emb = speaker_similarity_embedding( reference, **kwargs )
58 |
59 | return F.cosine_similarity( audio_emb, reference_emb, dim=-1 ).item()
--------------------------------------------------------------------------------
/vall_e/models/arch/__init__.py:
--------------------------------------------------------------------------------
1 | AVAILABLE_ARCHES = []
2 | ERROR_ARCHES = {}
3 |
4 | try:
5 | from .llama import Config as LlamaConfig, Model as LlamaModel, Attention as LlamaAttention, AVAILABLE_ATTENTIONS
6 | AVAILABLE_ARCHES.append("llama")
7 | except Exception as e:
8 | ERROR_ARCHES["llama"] = e
9 | AVAILABLE_ATTENTIONS = []
10 | pass
11 |
12 | """
13 | try:
14 | from .transformer import SinusoidalEmbedding, Block as TransformerBlock
15 | AVAILABLE_ARCHES.append("transformer")
16 | except Exception as e:
17 | ERROR_ARCHES["transformer"] = e
18 | pass
19 |
20 | try:
21 | from .retnet import RetNetDecoder, RetNetConfig
22 | AVAILABLE_ARCHES.append("retnet")
23 | except Exception as e:
24 | ERROR_ARCHES["retnet"] = e
25 | pass
26 |
27 | try:
28 | from .retnet_syncdoth.retnet_ts import RetNetDecoder as RetNetDecoder_TS, RetNetConfig as RetNetConfig_TS
29 | AVAILABLE_ARCHES.append("retnet-ts")
30 | except Exception as e:
31 | ERROR_ARCHES["retnet-ts"] = e
32 | pass
33 |
34 | try:
35 | from .retnet_syncdoth.retnet_hf import RetNetDecoder as RetNetDecoder_HF, RetNetConfig as RetNetConfig_HF, RetNetForCausalLM
36 | AVAILABLE_ARCHES.append("retnet-hf")
37 | except Exception as e:
38 | ERROR_ARCHES["retnet-hf"] = e
39 | pass
40 |
41 | try:
42 | from .bitnet import BitNetTransformer
43 | AVAILABLE_ARCHES.append("bitnet")
44 | except Exception as e:
45 | ERROR_ARCHES["bitnet"] = e
46 | pass
47 |
48 | try:
49 | from .mixtral import MixtralModel, MixtralConfig, MixtralAttention, MixtralAttention_Adapted, MixtralModel_Adapted, load_balancing_loss_func
50 | AVAILABLE_ARCHES.append("mixtral")
51 | except Exception as e:
52 | ERROR_ARCHES["mixtral"] = e
53 |
54 | try:
55 | from .mamba import MambaModel, Mamba2Model, MambaConfig, Mamba2Config
56 | AVAILABLE_ARCHES.append("mamba")
57 | AVAILABLE_ARCHES.append("mamba2")
58 | except Exception as e:
59 | ERROR_ARCHES["mamba"] = e
60 | ERROR_ARCHES["mamba2"] = e
61 | """
62 | """
63 | try:
64 | from .mamba import MambaMixelModel, MambaLMHeadModel, MambaConfig
65 | AVAILABLE_ARCHES.append("mamba")
66 | AVAILABLE_ARCHES.append("mamba2")
67 | except Exception as e:
68 | ERROR_ARCHES["mamba"] = e
69 | ERROR_ARCHES["mamba2"] = e
70 |
71 | try:
72 | from .mamba_vasqu import Mamba2Model_HF, Mamba2Config_HF
73 | AVAILABLE_ARCHES.append("mamba2-hf")
74 | except Exception as e:
75 | ERROR_ARCHES["mamba2-hf"] = e
76 | """
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # VALL'E
6 |
7 | An unofficial PyTorch implementation of [VALL-E](https://vall-e-demo.ecker.tech/) (last updated: `2025.05.30`), utilizing the [EnCodec](https://github.com/facebookresearch/encodec) encoder/decoder.
8 |
9 | A demo is available on HuggingFace [here](https://huggingface.co/spaces/ecker/vall-e).
10 |
11 | ## Requirements
12 |
13 | Besides a working PyTorch environment, the only hard requirement is [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/) for phonemizing text:
14 | - Linux users can consult their package managers on installing `espeak`/`espeak-ng`.
15 | - Windows users are required to install [`espeak-ng`](https://github.com/espeak-ng/espeak-ng/releases/tag/1.51#Assets).
16 | + additionally, you may be required to set the `PHONEMIZER_ESPEAK_LIBRARY` environment variable to specify the path to `libespeak-ng.dll` (see the example below this list).
17 | - In the future, an internal homebrew to replace this would be fantastic.
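For example, assuming the installer's default location (adjust to wherever `espeak-ng` actually landed):

```bat
set PHONEMIZER_ESPEAK_LIBRARY=C:\Program Files\eSpeak NG\libespeak-ng.dll
```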
18 |
19 | ## Install
20 |
21 | Simply run `pip install git+https://git.ecker.tech/mrq/vall-e` or `pip install git+https://github.com/e-c-k-e-r/vall-e`.
22 |
23 | This repo is tested under Python versions `3.10.9`, `3.11.3`, and `3.12.3`.
24 |
25 | ### Additional Implementations
26 |
27 | An "HF"-ified version of the model is available as [`ecker/vall-e@hf`](https://huggingface.co/ecker/vall-e/tree/hf), but it does require some additional efforts (see the `__main__` of [`./vall_e/models/base.py`](./vall_e/models/base.py) for details).
28 |
29 | Additionally, [`vall_e.cpp`](./vall_e.cpp/) is available. Consult its README for more details.
30 |
31 | ## Pre-Trained Model
32 |
33 | Pre-trained weights can be acquired from
34 | * [here](https://huggingface.co/ecker/vall-e) or automatically when either inferencing or running the web UI.
35 | * `./scripts/setup.sh`, a script to setup a proper environment and download the weights. This will also automatically create a `venv`.
36 | * when inferencing, either through the web UI or CLI, if no model is passed, the default model will download automatically instead, and should automatically update.
37 |
38 | ## Documentation
39 |
40 | The provided documentation under [./docs/](./docs/) should provide thorough coverage over most, if not all, of this project.
41 |
42 | Markdown files should correspond directly to their respective file or folder under `./vall_e/`.
--------------------------------------------------------------------------------
/vall_e/utils/distributed.py:
--------------------------------------------------------------------------------
1 | """
2 | # https://github.com/enhuiz/pytorch-training-utilities
3 | """
4 |
5 | import os
6 | import socket
7 |
8 | from functools import cache, wraps
9 | from typing import Callable
10 |
11 | import torch
12 | import torch.distributed as dist
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 |
15 | def get_free_port():
16 | sock = socket.socket()
17 | sock.bind(("", 0))
18 | return sock.getsockname()[1]
19 |
20 |
21 | _distributed_initialized = False
22 | def init_distributed( fn, *args, **kwargs ):
23 | 	global _distributed_initialized # required so the assignment below updates the module-level flag
24 | 	torch.cuda.set_device(local_rank())
25 | 	fn(*args, **kwargs)
26 | 	_distributed_initialized = True
26 |
27 | def distributed_initialized():
28 | return _distributed_initialized
29 |
30 | def cleanup_distributed():
31 | dist.barrier()
32 | dist.destroy_process_group()
33 |
34 | @cache
35 | def fix_unset_envs():
36 | envs = dict(
37 | RANK="0",
38 | WORLD_SIZE="1",
39 | MASTER_ADDR="localhost",
40 | MASTER_PORT=str(get_free_port()),
41 | LOCAL_RANK="0",
42 | )
43 |
44 | for key in envs:
45 | value = os.getenv(key)
46 | if value is not None:
47 | return
48 |
49 | for key, value in envs.items():
50 | os.environ[key] = value
51 |
52 |
53 | def local_rank():
54 | return int(os.getenv("LOCAL_RANK", 0))
55 |
56 | def global_rank():
57 | return int(os.getenv("RANK", 0))
58 |
59 | def world_size():
60 | return int(os.getenv("WORLD_SIZE", 1))
61 |
62 |
63 | def is_local_leader():
64 | return local_rank() == 0
65 |
66 |
67 | def is_global_leader():
68 | return global_rank() == 0
69 |
70 |
71 | def local_leader_only(fn=None, *, default=None) -> Callable:
72 | def wrapper(fn):
73 | @wraps(fn)
74 | def wrapped(*args, **kwargs):
75 | if is_local_leader():
76 | return fn(*args, **kwargs)
77 | return default
78 |
79 | return wrapped
80 |
81 | if fn is None:
82 | return wrapper
83 |
84 | return wrapper(fn)
85 |
86 |
87 | def global_leader_only(fn: Callable | None = None, *, default=None) -> Callable:
88 | def wrapper(fn):
89 | @wraps(fn)
90 | def wrapped(*args, **kwargs):
91 | if is_global_leader():
92 | return fn(*args, **kwargs)
93 | return default
94 |
95 | return wrapped
96 |
97 | if fn is None:
98 | return wrapper
99 |
100 | return wrapper(fn)
101 |
102 | def ddp_model(model):
103 | return DDP(model.to(device='cuda'), [local_rank()], find_unused_parameters=True)
--------------------------------------------------------------------------------
/vall_e.cpp/include/ggml-metal.h:
--------------------------------------------------------------------------------
1 | // Note: this description is outdated
2 | //
3 | // An interface allowing to compute ggml_cgraph with Metal
4 | //
5 | // This is a fully functional interface that extends ggml with GPU support for Apple devices.
6 | // A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, etc.)
7 | //
8 | // How it works?
9 | //
10 | // As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
11 | // interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
12 | // use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
13 | //
14 | // You only need to make sure that all memory buffers that you used during the graph creation
15 | // are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
16 | // used during the graph evaluation to determine the arguments of the compute kernels.
17 | //
18 | // Synchronization between device and host memory (for example for input and output tensors)
19 | // is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
20 | //
21 |
22 | #pragma once
23 |
24 | #include "ggml.h"
25 | #include "ggml-backend.h"
26 |
27 | #include <stddef.h>
28 | #include <stdbool.h>
29 |
30 | struct ggml_tensor;
31 | struct ggml_cgraph;
32 |
33 | #ifdef __cplusplus
34 | extern "C" {
35 | #endif
36 |
37 | //
38 | // backend API
39 | // user-code should use only these functions
40 | //
41 |
42 | GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
43 |
44 | GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
45 |
46 | GGML_DEPRECATED(
47 | GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
48 | "obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
49 |
50 | GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
51 |
52 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
53 |
54 | // helper to check if the device supports a specific family
55 | // ideally, the user code should be doing these checks
56 | // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
57 | GGML_BACKEND_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
58 |
59 | // capture all command buffers committed the next time `ggml_backend_graph_compute` is called
60 | GGML_BACKEND_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
61 |
62 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_metal_reg(void);
63 |
64 | #ifdef __cplusplus
65 | }
66 | #endif
67 |
--------------------------------------------------------------------------------
/vall_e.cpp/README.md:
--------------------------------------------------------------------------------
1 | # vall_e.cpp
2 |
3 | This is an implementation that makes use of [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [encodec.cpp](https://github.com/PABannier/encodec.cpp).
4 |
5 | Model weights can:
6 | * be found at [`ecker/vall-e@gguf`](https://huggingface.co/ecker/vall-e/tree/gguf)
7 | * converted with `vall_e.export --yaml=./model_path/config.yaml --hf`, then running `python3 /path/to/your/llama.cpp/convert_hf_to_gguf ./model_path/hf/`
8 |
9 | ## Build
10 |
11 | Populate `./include/` with the `ggml`, `llama.cpp`, and `encodec.cpp` headers (although these headers should already be provided).
12 | * `encodec.cpp` and `llama.cpp` need to have their CMake files generated with `-DBUILD_SHARED_LIBS=On` passed.
13 |
14 | Populate `./lib/` with the compiled libraries of `llama.cpp`, `encodec.cpp`, and `espeak-ng` (if not already in your `LD_LIBRARY_PATH`).
15 |
16 | Run `make` (a condensed sketch of the full sequence follows below).
17 | * `make lib` will generate the shared library (rename the `.so` to `.dll` under Windows).
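A condensed sketch of the above, assuming sibling checkouts of `llama.cpp` and `encodec.cpp` (paths are illustrative):

```sh
# build the dependencies as shared libraries
cmake -S ../llama.cpp -B ../llama.cpp/build -DBUILD_SHARED_LIBS=On && cmake --build ../llama.cpp/build
cmake -S ../encodec.cpp -B ../encodec.cpp/build -DBUILD_SHARED_LIBS=On && cmake --build ../encodec.cpp/build

# copy their headers into ./include/ and the built libraries into ./lib/, then:
make        # builds the vall_e executable
make lib    # builds the shared library
```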
18 |
19 | ### Required Modifications
20 |
21 | [`encodec.cpp`](https://github.com/PABannier/encodec.cpp) requires updating its GGML copy to the latest version, which requires a few lines to get the CPU backend working (per my [fork](https://github.com/e-c-k-e-r/encodec.cpp)).
22 |
23 | The only modification [`llama.cpp`](https://github.com/ggerganov/llama.cpp) needs is to ensure that a non-causal attention mask is used; everything else necessary can be hacked together with clever tricks.
24 | * initially written on commit `9ba399dfa7f115effc63d48e6860a94c9faa31b2`, updated to commit `7a84777f42a9b3ba47db5d20b7662f8ddf92f652`
25 |
26 | ## To-Do
27 |
28 | * [x] converted model to GGUF
29 | * [x] convert it without modifying any of the existing code, as the tokenizer requires some care
30 | * [x] basic framework
31 | * [x] load the quantized model
32 | * [x] orchestrate the required embeddings
33 | * [x] juggle the output head / classifier properly
34 | * [x] phonemize text
35 | * with the help of espeak-ng
36 | * [x] tokenize phonemes
37 | * tokenize with `llama_tokenize` instead of a homebrewed method because the tokenizer is being a huge thorn
38 | * [x] load audio from disk
39 | * [x] encode audio
40 | * [x] sum embeddings for the `prom` and prior `resp`s
41 | * [x] working `AR` output
42 | * [x] `AR` sampling
43 | * [x] working `NAR-len` output
44 | * [x] `NAR-len` sampling
45 | * [ ] proper scoring
46 | * [x] working `NAR` output
47 | * [x] `NAR` sampling
48 | * [x] decode audio to disk
49 | * [x] a functional CLI
50 | * [x] actually make it work
51 | * [x] clean up to make the code usable elsewhere
52 | * [x] configured to allow for being used as a lib
53 | * (I do need to validate this in my engine project, but that's in MSYS2)
54 | * [ ] feature parity with the PyTorch version
55 | * [ ] vocos
56 | * [ ] additional tasks
57 | * [ ] `stt`
58 | * [x] `ns` / `sr`
59 | * [ ] samplers
--------------------------------------------------------------------------------
/scripts/parse_ppp.py:
--------------------------------------------------------------------------------
1 | """
2 | # Helper script to parse PPP dataset into a friendlier hierarchy
3 | """
4 |
5 | import os
6 | import json
7 | import torch
8 |
9 | from tqdm.auto import tqdm
10 | from pathlib import Path
11 | from vall_e.emb.g2p import encode as valle_phonemize
12 | from vall_e.emb.qnt import encode_from_file as valle_quantize, _replace_file_extension
13 |
14 | target = "in"
15 |
16 | audio_map = {}
17 | text_map = {}
18 |
19 | data = {}
20 |
21 | for season in os.listdir(f"./{target}/"):
22 | if not os.path.isdir(f"./{target}/{season}/"):
23 | continue
24 |
25 | for episode in os.listdir(f"./{target}/{season}/"):
26 | if not os.path.isdir(f"./{target}/{season}/{episode}/"):
27 | continue
28 |
29 | for filename in os.listdir(f"./{target}/{season}/{episode}/"):
30 | path = f'./{target}/{season}/{episode}/{filename}'
31 | attrs = filename.split("_")
32 | timestamp = f'{attrs[0]}h{attrs[1]}m{attrs[2]}s'
33 |
34 | key = f'{episode}_{timestamp}'
35 |
36 | if filename[-5:] == ".flac":
37 | name = attrs[3]
38 | emotion = attrs[4]
39 | quality = attrs[5]
40 |
41 | audio_map[key] = {
42 | "path": path,
43 | 'episode': episode,
44 | "name": name,
45 | "emotion": emotion,
46 | "quality": quality,
47 | "timestamp": timestamp,
48 | }
49 |
50 | elif filename[-4:] == ".txt":
51 | text_map[key] = open(path, encoding="utf-8").read()
52 | txts = {}
53 | wavs = []
54 |
55 | for key, entry in audio_map.items():
56 | path = entry['path']
57 | name = entry['name']
58 | emotion = entry['emotion']
59 | quality = entry['quality']
60 | episode = entry['episode']
61 | path = entry['path']
62 | timestamp = entry['timestamp']
63 | transcription = text_map[key]
64 | if name not in data:
65 | data[name] = {}
66 | os.makedirs(f'./training/{name}/', exist_ok=True)
67 | os.makedirs(f'./voices/{name}/', exist_ok=True)
68 |
69 | key = f'{episode}_{timestamp}.flac'
70 | os.rename(path, f'./voices/{name}/{key}')
71 |
72 | data[name][key] = {
73 | "segments": [],
74 | "language": "en",
75 | "text": transcription,
76 | "misc": {
77 | "emotion": emotion,
78 | "quality": quality,
79 | "timestamp": timestamp,
80 | "episode": episode,
81 | }
82 | }
83 |
84 | path = f'./voices/{name}/{key}'
85 | txts[path] = transcription
86 | wavs.append(Path(path))
87 |
88 | for name in data.keys():
89 | open(f"./training/{name}/whisper.json", "w", encoding="utf-8").write( json.dumps( data[name], indent='\t' ) )
90 |
91 | # to-do: update to "The Proper Way"
92 | # for now it can just be fed back into "The Proper Way"
93 | """
94 | device = "cuda"
95 | for key, text in tqdm(txts.items(), desc="Phonemizing..."):
96 | path = Path(key)
97 | phones = valle_phonemize(text)
98 | open(_replace_file_extension(path, ".phn.txt"), "w", encoding="utf-8").write(" ".join(phones))
99 |
100 | for path in tqdm(wavs, desc="Quantizing..."):
101 | qnt = valle_quantize(path, device=device)
102 | torch.save(qnt.cpu(), _replace_file_extension(path, ".qnt.pt"))
103 | """
--------------------------------------------------------------------------------
/vall_e.cpp/include/lstm.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ggml.h"
4 | #include "ggml-alloc.h"
5 |
6 | #include "ops.h"
7 |
8 | struct encodec_lstm {
9 | struct ggml_tensor *l0_ih_w;
10 | struct ggml_tensor *l0_hh_w;
11 |
12 | struct ggml_tensor *l0_ih_b;
13 | struct ggml_tensor *l0_hh_b;
14 |
15 | struct ggml_tensor *l1_ih_w;
16 | struct ggml_tensor *l1_hh_w;
17 |
18 | struct ggml_tensor *l1_ih_b;
19 | struct ggml_tensor *l1_hh_b;
20 | };
21 |
22 | struct ggml_tensor *forward_pass_lstm_unilayer(struct ggml_context *ctx0,
23 | struct ggml_tensor *inp,
24 | struct ggml_tensor *weight_ih,
25 | struct ggml_tensor *weight_hh,
26 | struct ggml_tensor *bias_ih,
27 | struct ggml_tensor *bias_hh,
28 | char *prefix) {
29 | const int seq_length = inp->ne[0];
30 | const int input_dim = inp->ne[1];
31 | const int hidden_dim = weight_ih->ne[1] / 4;
32 |
33 | char ct_name[10];
34 | char ht_name[10];
35 |
36 | snprintf(ct_name, 10, "%s_ct", prefix);
37 | snprintf(ht_name, 10, "%s_ht", prefix);
38 |
39 | struct ggml_tensor *hs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length);
40 | ggml_set_input(hs);
41 |
42 | struct ggml_tensor *c_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim);
43 | ggml_set_input(c_t);
44 | ggml_set_name(c_t, ct_name);
45 |
46 | struct ggml_tensor *h_t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hidden_dim);
47 | ggml_set_input(h_t);
48 | ggml_set_name(h_t, ht_name);
49 |
50 | struct ggml_tensor *current = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
51 |
52 | for (int t = 0; t < seq_length; t++) {
53 | struct ggml_tensor *x_t = ggml_view_1d(ctx0, current, input_dim, t * current->nb[1]);
54 |
55 | struct ggml_tensor *inp_gates = ggml_mul_mat(ctx0, weight_ih, x_t);
56 | inp_gates = ggml_add(ctx0, inp_gates, bias_ih);
57 |
58 | struct ggml_tensor *hid_gates = ggml_mul_mat(ctx0, weight_hh, h_t);
59 | hid_gates = ggml_add(ctx0, hid_gates, bias_hh);
60 |
61 | struct ggml_tensor *out_gates = ggml_add(ctx0, inp_gates, hid_gates);
62 |
63 | struct ggml_tensor *i_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 0 * sizeof(float) * hidden_dim));
64 | struct ggml_tensor *f_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 1 * sizeof(float) * hidden_dim));
65 | struct ggml_tensor *g_t = ggml_tanh(ctx0 , ggml_view_1d(ctx0, out_gates, hidden_dim, 2 * sizeof(float) * hidden_dim));
66 | struct ggml_tensor *o_t = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, out_gates, hidden_dim, 3 * sizeof(float) * hidden_dim));
67 |
68 | c_t = ggml_add(ctx0, ggml_mul(ctx0, f_t, c_t), ggml_mul(ctx0, i_t, g_t));
69 |
70 | h_t = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_t));
71 |
72 | hs = ggml_set_1d(ctx0, hs, h_t, t * hs->nb[1]);
73 | }
74 |
75 | hs = ggml_cont(ctx0, ggml_transpose(ctx0, hs));
76 |
77 | return hs;
78 | }
79 |
--------------------------------------------------------------------------------
/vall_e.cpp/include/ggml-alloc.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ggml.h"
4 |
5 | #ifdef __cplusplus
6 | extern "C" {
7 | #endif
8 |
9 | typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
10 | typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
11 | typedef struct ggml_backend * ggml_backend_t;
12 |
13 | // Tensor allocator
14 | struct ggml_tallocr {
15 | ggml_backend_buffer_t buffer;
16 | void * base;
17 | size_t alignment;
18 | size_t offset;
19 | };
20 |
21 | GGML_API struct ggml_tallocr ggml_tallocr_new(ggml_backend_buffer_t buffer);
22 | GGML_API enum ggml_status ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tensor);
23 |
24 | // Graph allocator
25 | /*
26 | Example usage:
27 | ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
28 |
29 | // optional: create a worst-case graph and reserve the buffers to avoid reallocations
30 | ggml_gallocr_reserve(galloc, build_graph(max_batch));
31 |
32 | // allocate the graph
33 | struct ggml_cgraph * graph = build_graph(batch);
34 | ggml_gallocr_alloc_graph(galloc, graph);
35 |
36 | printf("compute buffer size: %zu bytes\n", ggml_gallocr_get_buffer_size(galloc, 0));
37 |
38 | // evaluate the graph
39 | ggml_backend_graph_compute(backend, graph);
40 | */
41 |
42 | // special tensor flags for use with the graph allocator:
43 | // ggml_set_input(): all input tensors are allocated at the beginning of the graph in non-overlapping addresses
44 | // ggml_set_output(): output tensors are never freed and never overwritten
45 |
46 | typedef struct ggml_gallocr * ggml_gallocr_t;
47 |
48 | GGML_API ggml_gallocr_t ggml_gallocr_new(ggml_backend_buffer_type_t buft);
49 | GGML_API ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs);
50 | GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
51 |
52 | // pre-allocate buffers from a measure graph - does not allocate or modify the graph
53 | // call with a worst-case graph to avoid buffer reallocations
54 | // not strictly required for single buffer usage: ggml_gallocr_alloc_graph will reallocate the buffers automatically if needed
55 | // returns false if the buffer allocation failed
56 | GGML_API bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
57 | GGML_API bool ggml_gallocr_reserve_n(
58 | ggml_gallocr_t galloc,
59 | struct ggml_cgraph * graph,
60 | const int * node_buffer_ids,
61 | const int * leaf_buffer_ids);
62 |
63 | // automatic reallocation if the topology changes when using a single buffer
64 | // returns false if using multiple buffers and a re-allocation is needed (call ggml_gallocr_reserve_n first to set the node buffers)
65 | GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph);
66 |
67 | GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
68 |
69 | // Utils
70 | // Create a buffer and allocate all the tensors in a ggml_context
71 | GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
72 | GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend);
73 |
74 | #ifdef __cplusplus
75 | }
76 | #endif
77 |
--------------------------------------------------------------------------------
/docs/utils.md:
--------------------------------------------------------------------------------
1 | # `utils/*`
2 |
3 | This folder contains helper utilities for training and for the general functions of the program.
4 |
5 | These scripts are to remain agnostic to any model, to allow for reuse in other applications.
6 |
7 | ## `utils/distributed.py`
8 |
9 | This script contains the necessary code needed to utilize distributed training.
10 |
11 | Attributions are noted at the top.
12 |
13 | ## `utils/io.py`
14 |
15 | This script contains the necessary code for loading and storing state dicts, through pickles (`.pt`) or SafeTensors (`.sft`), and offers parity for each storage type.
16 |
17 | Additionally, some JSON helper functions are provided here.
18 |
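Below is a minimal usage sketch of these helpers, following the `torch_save` / `torch_load` / `json_read` / `json_write` functions defined in `utils/io.py`; the file paths and the toy module are purely illustrative.

```python
import torch

from vall_e.utils.io import torch_save, torch_load, json_read, json_write

# the file extension selects the backend: .pt uses pickle, .sft/.safetensors uses SafeTensors
model = torch.nn.Linear(4, 4)
torch_save( { "module": model.state_dict() }, "./data/example.sft" )
state_dict = torch_load( "./data/example.sft", device="cpu" )

# JSON helpers (orjson is used when available)
metadata = json_read( "./data/metadata.json", default={} )
json_write( metadata, "./data/metadata.json", pretty=True )
```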
19 | ## `utils/pattern.py`
20 |
21 | This script contains (unused) code related to formatting sequences of audio codes into different pattern types.
22 |
23 | Attributions are noted at the top.
24 |
25 | ## `utils/sampler.py`
26 |
27 | This script contains code to handle sampling from a list of indices.
28 | * `PoolSampler` has a master list of indices "in the marble bag" that are sampled without replacement.
29 | * `OrderedSampler` will output indices from 0 to `length`, in order.
30 | * `BatchedOrderedSampler` does the above, but will output lists of indices instead.
31 | * `RandomSampler` will output indices from 0 to `length`, randomly.
32 |
33 | Each sampler can load and store a state dict.
34 |
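As a rough, standalone illustration of the "marble bag" behavior described above (this is not the actual `PoolSampler` class or its API, just a sketch of the idea):

```python
import random

# standalone sketch of the "marble bag" idea: sample indices without replacement,
# refilling the bag once it empties; not the actual PoolSampler implementation
class MarbleBag:
	def __init__(self, length):
		self.length = length
		self.pool = []

	def sample(self):
		if not self.pool:
			self.pool = list(range(self.length))
			random.shuffle(self.pool)
		return self.pool.pop()

	def get_state(self):
		return { "length": self.length, "pool": self.pool }

	def set_state(self, state):
		self.length = state["length"]
		self.pool = state["pool"]

bag = MarbleBag(10)
indices = [ bag.sample() for _ in range(10) ] # every index 0..9 appears exactly once
```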
35 |
36 | ## `utils/utils.py`
37 |
38 | This script contains additional helper functions that do not require a dedicated file.
39 |
40 | ## `utils/train.py`
41 |
42 | This script handles the necessary code for training, such as:
43 | * iterating through a dataloader
44 | * iterating through an `Engines` to train each underlying `Engine`
45 | * printing training metrics
46 | * invoking `save`, `eval`, `export` every X iterations
47 | * handling stdin commands, such as `save`, `export`, `eval`, and `quit`
48 |
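A standalone sketch of how stdin commands can be polled between training steps is below; it is illustrative only and not the actual `utils/train.py` implementation (the `step_fn` / `save_fn` / `eval_fn` callables are placeholders).

```python
import sys
import queue
import threading

# read stdin on a background thread so the training loop can react to
# commands like "save", "eval", or "quit" between steps
commands = queue.Queue()

def _read_stdin():
	for line in sys.stdin:
		commands.put(line.strip())

threading.Thread(target=_read_stdin, daemon=True).start()

def train(dataloader, step_fn, save_fn, eval_fn, save_every=1000, eval_every=1000):
	for iteration, batch in enumerate(dataloader, start=1):
		stats = step_fn(batch)
		print(f"iteration={iteration} stats={stats}")

		# periodic checkpoints / evaluations
		if iteration % save_every == 0:
			save_fn()
		if iteration % eval_every == 0:
			eval_fn()

		# drain any commands typed into stdin
		while not commands.empty():
			command = commands.get()
			if command == "save":
				save_fn()
			elif command == "eval":
				eval_fn()
			elif command == "quit":
				save_fn()
				return
```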
49 | ## `utils/wrapper.py`
50 |
51 | This script contains optimizations and additional code that require injecting or replacing modules.
52 |
53 | Most configurations are offered through `cfg.optimization`.
54 |
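As a generic, standalone sketch of what injecting or replacing modules can look like (the `PatchedLinear` class here is purely illustrative and not one of the actual optimizations; the real replacements are selected through the configuration):

```python
from torch import nn

# walk a model and swap nn.Linear modules for a patched replacement,
# copying the original weights over; illustrative only
class PatchedLinear(nn.Linear):
	def forward(self, x):
		# a stand-in for whatever an optimized module would do differently
		return super().forward(x)

def replace_linears(model: nn.Module):
	for name, child in model.named_children():
		if isinstance(child, nn.Linear):
			replacement = PatchedLinear(child.in_features, child.out_features, bias=child.bias is not None)
			replacement.load_state_dict(child.state_dict())
			setattr(model, name, replacement)
		else:
			replace_linears(child)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
replace_linears(model)
```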
55 | ## `utils/ext/`
56 |
57 | This folder contains external code that can't be nicely referenced under a package.
58 |
59 | Proper attribution is noted at the top of each file.
60 |
61 | ### `utils/ext/apollo.py`
62 |
63 | This script contains [APOLLO](https://github.com/zhuhanqing/APOLLO), an optimizer that achieves ADAMW-like performance with very little memory cost.
64 |
65 | In testing, this seems to work fine, and the memory gains (in comparison to Prodigyopt) under the normal-specced model allow you to double the batch size.
66 |
67 | It's definitely usable in extremely low-VRAM environments, and specifying `apollo-mini` will further shrink the memory requirements (though its robustness has yet to be personally tested).
68 |
69 | However, after a while it seemed to cause some steps to overflow gradients or produce NaNs that persist even when swapping back to `prodigyopt` (though I do not know whether this is the fault of `APOLLO` or just the model eventually hitting a point of instability).
70 |
71 | ### `utils/ext/unsloth.py`
72 |
73 | This script contains Unsloth, a VRAM-saving optimization that offloads input tensors to the CPU during the backward pass.
74 |
75 | This is mostly unnecessary, as the inputs themselves are rather small, but it is offered nonetheless, if needed, through `cfg.optimizations.unsloth = True`.
--------------------------------------------------------------------------------
/docs/webui.md:
--------------------------------------------------------------------------------
1 | # `webui.py`
2 |
3 | A Gradio-based web UI is accessible by running `python3 -m vall_e.webui`. You can, optionally, pass:
4 |
5 | * `--yaml=./path/to/your/config.yaml`: will load the targeted YAML
6 | * `--model=./path/to/your/model.sft`: will load the targeted model weights
7 | * `--listen 0.0.0.0:7860`: will set the web UI to listen on all IPs at port 7860. Replace the IP and port with your preference.
8 |
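For example, `python3 -m vall_e.webui --yaml=./path/to/your/config.yaml --listen 0.0.0.0:7860` (the path here is a placeholder) will load that YAML and serve the UI on port 7860.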
9 | ## Inference
10 |
11 | Synthesizing speech is simple:
12 |
13 | * `Text`:
14 | * `Input Prompt`: The guiding text prompt. Each segment will be its own generated audio to be stitched together at the end.
15 | * `Audio`:
16 | * `Audio Input`: The transcription of the audio will be inserted into the `Text/Input Prompt` box.
17 | * For the `vc` task, this will also serve as the guiding reference audio.
18 |
19 | * `Audio Input`: The reference audio for the synthesis. Under Gradio, you can trim your clip accordingly, but leaving it as-is works fine.
20 | - A properly trained model can perform inference without a prompt to generate a random voice (without even needing to generate a random prompt itself).
21 | * `Output`: The resultant audio.
22 | * `Inference`: Button to start generating the audio.
23 | * `Basic Settings`: Basic sampler settings for most uses.
24 | * `Max Steps`: Number of demasking steps to perform for RVQ level 0. For the `NAR-len` modality.
25 | * `Max Duration`: The maximum duration of the output audio.
26 | * `Input Prompt Repeat/Trim Length`: The audio prompt will be trimmed down or repeated to match this duration (although repeating might cause more harm than good).
27 | * `Language (Text)`: The language of the input text for phonemizing.
28 | * `Language (Output)`: The target language for the output audio. Some checkpoints of the model might ignore this due to how it was trained, unfortunately. Some models might steer the output accent.
29 | * `Task`: The task to perform (in order): Text-To-Speech, Speech Removal, Noise Reduction, Voice Conversion.
30 | * `Text Delimiter`: How to split the `Text/Input Prompt`. `Sentences` splits by sentence, while `Lines` splits by newline.
31 | * `(Rolling) Context History`: Paired with the above, the previous N utterances will serve as the prefix to extend the generation on, allowing for consistency and stability across pieces.
32 | * `Sampler Settings`: Advanced sampler settings that are common for most text LLMs, but need experimentation.
33 | * `Experimental Settings`: Settings used for testing. `cfg.experimental=True` enables this tab.
34 |
35 | All the additional knobs have a description that can be correlated to the inferencing CLI flags.
36 |
37 | Speech-To-Text phoneme transcriptions for models that support it can be done using the `Speech-to-Text` tab.
38 |
39 | ## Dataset
40 |
41 | This tab currently only features exploring a dataset already prepared and referenced in your `config.yaml`. You can select a registered voice, and have it randomly sample an utterance.
42 |
43 | In the future, this *should* contain the necessary niceties to process raw audio into a dataset to train/finetune through, without needing to invoke the above commands to prepare the dataset.
44 |
45 | ## Settings
46 |
47 | So far, this only allows you to load a different model under a different dtype, device, and/or attention mechanism without needing to restart. The previous model should seamlessly unload, and the new one will load in its place.
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys
3 | from pathlib import Path
4 | from datetime import datetime
5 | from setuptools import setup, find_packages
6 |
7 | def shell(*args):
8 | out = subprocess.check_output(args)
9 | return out.decode("ascii").strip()
10 |
11 | def write_version(version_core, pre_release=True):
12 | if pre_release:
13 | time = shell("git", "log", "-1", "--format=%cd", "--date=iso")
14 | time = datetime.strptime(time, "%Y-%m-%d %H:%M:%S %z")
15 | time = time.strftime("%Y%m%d%H%M%S")
16 | version = f"{version_core}-dev{time}"
17 | else:
18 | version = version_core
19 |
20 | with open(Path("vall_e", "version.py"), "w") as f:
21 | f.write('__version__ = "{}"\n'.format(version))
22 |
23 | return version
24 |
25 | with open("README.md", "r") as f:
26 | long_description = f.read()
27 |
28 | platform_dependencies = []
29 |
30 | if sys.platform.startswith("win"):
31 | platform_dependencies += ["psutil"]
32 | else:
33 | platform_dependencies += ["deepspeed>=0.7.7"]
34 |
35 | setup(
36 | name="vall-e",
37 | python_requires=">=3.10.0",
38 | version=write_version("0.0.1"),
39 | description="An unofficial implementation of the audio LM VALL-E",
40 | author="ecker",
41 | author_email="mrq@ecker.tech",
42 | long_description=long_description,
43 | long_description_content_type="text/markdown",
44 | packages=find_packages(),
45 | install_requires=
46 | platform_dependencies + [
47 | # logging niceties
48 | "coloredlogs>=15.0.1", # barely required
49 | "humanize>=4.4.0", # not really required
50 | "matplotlib>=3.6.0", # only required for plotting
51 | "pandas>=1.5.0", # not really required
52 |
53 | # boilerplate niceties
54 | #"diskcache>=5.4.0",
55 | "einops>=0.6.0", # could be replaced
56 | "tqdm",
57 |
58 | # HF bloat
59 | "tokenizers",
60 | "transformers",
61 | "safetensors",
62 |
63 | # training bloat
64 | "auraloss[all]", # [all] is needed for MelSTFTLoss
65 | "h5py",
66 | "prodigyopt @ git+https://github.com/konstmish/prodigy",
67 |
68 | # practically the reason to use python
69 | "numpy",
70 | "torch>=1.13.0",
71 | "torchaudio>=0.13.0",
72 | "torchmetrics",
73 |
74 | # core foundations
75 | "phonemizer>=2.1.0",
76 | "encodec>=0.1.1",
77 | "vocos",
78 |
79 | # for the web UI
80 | "gradio",
81 | "nltk", # for parsing text inputs down to pieces
82 | "langdetect", # for detecting the language of a text
83 | "sounddevice", # for raw playback
84 | ],
85 | extras_require = {
86 | "all": [
87 | # retnet backend (even though two internal copies exist)
88 | "torchscale @ git+https://git.ecker.tech/mrq/torchscale",
89 | # bitnet
90 | "bitnet",
91 | # mamba
92 | "causal-conv1d",
93 | "mamba-ssm",
94 |
95 | #
96 | "torcheval",
97 |
98 | # attention helpers
99 | "xformers",
100 | "sageattention==1.0.6",
101 | # "flash-attn" --no-build-isolation # commented out right now because I want to query this for Volta freaks like me who can't use it
102 |
103 | # another audio backend that didn't prove fruitful
104 | "descript-audio-codec",
105 |
106 | # nemo (to-do: cut this down)
107 | "nemo-toolkit",
108 | "hydra-core",
109 | "lightning",
110 | "sentencepiece"
111 | ]
112 | },
113 | url="https://git.ecker.tech/mrq/vall-e",
114 | )
115 |
--------------------------------------------------------------------------------
/scripts/cleanup_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | # Helper script to clean up transcription metadata, whatever that entailed.
3 | """
4 |
5 | import os
6 | import json
7 | import torch
8 | import torchaudio
9 |
10 | from tqdm.auto import tqdm
11 | from pathlib import Path
12 |
13 | input_dataset = "training/metadata"
14 | output_dataset = "training/metadata-cleaned"
15 |
16 | def pad(num, zeroes):
17 | return str(num).zfill(zeroes+1)
18 |
19 | for dataset_name in os.listdir(f'./{input_dataset}/'):
20 | if not os.path.isdir(f'./{input_dataset}/{dataset_name}/'):
21 | print("Is not dir:", f'./{input_dataset}/{dataset_name}/')
22 | continue
23 |
24 | for speaker_id in tqdm(os.listdir(f'./{input_dataset}/{dataset_name}/'), desc=f"Processing speaker: {dataset_name}"):
25 | if not os.path.isdir(f'./{input_dataset}/{dataset_name}/{speaker_id}'):
26 | print("Is not dir:", f'./{input_dataset}/{dataset_name}/{speaker_id}')
27 | continue
28 |
29 | inpath = Path(f'./{input_dataset}/{dataset_name}/{speaker_id}/whisper.json')
30 | outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
31 |
32 | if not inpath.exists():
33 | continue
34 |
35 | if outpath.exists():
36 | continue
37 |
38 | os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
39 |
40 | try:
41 | in_metadata = json.loads(open(inpath, 'r', encoding='utf-8').read())
42 | except Exception as e:
43 | print("Failed to open metadata file:", inpath)
44 | continue
45 |
46 | out_metadata = {}
47 | speaker_metadatas = {}
48 |
49 | for filename, result in in_metadata.items():
50 | language = result["language"] if "language" in result else "en"
51 | out_metadata[filename] = {
52 | "segments": [],
53 | "language": language,
54 | "text": "",
55 | "start": 0,
56 | "end": 0,
57 | }
58 | segments = []
59 | text = []
60 | start = 0
61 | end = 0
62 | diarized = False
63 |
64 | for segment in result["segments"]:
65 | # diarize split
66 | if "speaker" in segment:
67 | diarized = True
68 | speaker_id = segment["speaker"]
69 | if speaker_id not in speaker_metadatas:
70 | speaker_metadatas[speaker_id] = {}
71 |
72 | if filename not in speaker_metadatas[speaker_id]:
73 | speaker_metadatas[speaker_id][filename] = {
74 | "segments": [],
75 | "language": language,
76 | "text": "",
77 | "start": 0,
78 | "end": 0,
79 | }
80 |
81 | speaker_metadatas[speaker_id][filename]["segments"].append( segment )
82 | else:
83 | segments.append( segment )
84 |
85 | text.append( segment["text"] )
86 | start = min( start, segment["start"] )
87 | end = max( end, segment["end"] )
88 |
89 | out_metadata[filename]["segments"] = segments
90 | out_metadata[filename]["text"] = " ".join(text).strip()
91 | out_metadata[filename]["start"] = start
92 | out_metadata[filename]["end"] = end
93 |
94 | if len(segments) == 0:
95 | del out_metadata[filename]
96 |
97 | open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata))
98 |
99 | for speaker_id, out_metadata in speaker_metadatas.items():
100 | os.makedirs(f'./{output_dataset}/{dataset_name}/{speaker_id}/', exist_ok=True)
101 | outpath = Path(f'./{output_dataset}/{dataset_name}/{speaker_id}/whisper.json')
102 |
103 | open(outpath, 'w', encoding='utf-8').write(json.dumps(out_metadata))
--------------------------------------------------------------------------------
/vall_e.cpp/include/encoder.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <vector>
4 |
5 | #include "ggml.h"
6 | #include "lstm.h"
7 |
8 | // res + downsample block at some ratio
9 | struct encodec_encoder_block {
10 | // conv1
11 | struct ggml_tensor *conv_1_w;
12 | struct ggml_tensor *conv_1_b;
13 |
14 | // conv2
15 | struct ggml_tensor *conv_2_w;
16 | struct ggml_tensor *conv_2_b;
17 |
18 | // shortcut
19 | struct ggml_tensor *conv_sc_w;
20 | struct ggml_tensor *conv_sc_b;
21 |
22 | // downsampling layers
23 | struct ggml_tensor *ds_conv_w;
24 | struct ggml_tensor *ds_conv_b;
25 | };
26 |
27 | struct encodec_encoder {
28 | struct ggml_tensor *init_conv_w;
29 | struct ggml_tensor *init_conv_b;
30 |
31 | encodec_lstm lstm;
32 |
33 | struct ggml_tensor *final_conv_w;
34 | struct ggml_tensor *final_conv_b;
35 |
36 | std::vector<encodec_encoder_block> blocks;
37 | };
38 |
39 | struct ggml_tensor *encodec_forward_encoder(
40 | const struct encodec_encoder *encoder, struct ggml_context *ctx0,
41 | struct ggml_tensor *inp, const int * ratios, const int kernel_size, const int res_kernel_size,
42 | const int stride) {
43 |
44 | if (!inp) {
45 | fprintf(stderr, "%s: null input tensor\n", __func__);
46 | return NULL;
47 | }
48 |
49 | struct ggml_tensor *inpL = strided_conv_1d(
50 | ctx0, inp, encoder->init_conv_w, encoder->init_conv_b, stride);
51 |
52 | for (int layer_ix = 0; layer_ix < 4; layer_ix++) {
53 | encodec_encoder_block block = encoder->blocks[layer_ix];
54 |
55 | struct ggml_tensor *current = inpL;
56 |
57 | // shortcut
58 | struct ggml_tensor *shortcut = strided_conv_1d(
59 | ctx0, inpL, block.conv_sc_w, block.conv_sc_b, stride);
60 |
61 | // conv1
62 | current = ggml_elu(ctx0, current);
63 |
64 | current = strided_conv_1d(
65 | ctx0, current, block.conv_1_w, block.conv_1_b, stride);
66 |
67 | // conv2
68 | current = ggml_elu(ctx0, current);
69 |
70 | current = strided_conv_1d(
71 | ctx0, current, block.conv_2_w, block.conv_2_b, stride);
72 |
73 | // residual connection
74 | inpL = ggml_add(ctx0, current, shortcut);
75 |
76 | // downsampling layers
77 | inpL = ggml_elu(ctx0, inpL);
78 |
79 | inpL = strided_conv_1d(
80 | ctx0, inpL, block.ds_conv_w, block.ds_conv_b, ratios[3 - layer_ix]);
81 | }
82 |
83 | // lstm
84 | {
85 | struct ggml_tensor *cur = inpL;
86 |
87 | const encodec_lstm lstm = encoder->lstm;
88 |
89 | // first lstm layer
90 | char l0_prefix[7] = "enc_l0";
91 | struct ggml_tensor *hs1 = forward_pass_lstm_unilayer(
92 | ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b, l0_prefix);
93 |
94 | // second lstm layer
95 | char l1_prefix[7] = "enc_l1";
96 | struct ggml_tensor *out = forward_pass_lstm_unilayer(
97 | ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b, l1_prefix);
98 |
99 | inpL = ggml_add(ctx0, inpL, out);
100 | }
101 |
102 | // final conv
103 | inpL = ggml_elu(ctx0, inpL);
104 |
105 | struct ggml_tensor *encoded_inp = strided_conv_1d(
106 | ctx0, inpL, encoder->final_conv_w, encoder->final_conv_b, stride);
107 |
108 | return encoded_inp;
109 | }
110 |
--------------------------------------------------------------------------------
/vall_e/models/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import requests
4 | import time
5 |
6 | from tqdm import tqdm
7 | from pathlib import Path
8 |
9 | _logger = logging.getLogger(__name__)
10 |
11 | # to-do: implement automatically downloading model
12 | DEFAULT_MODEL_NAME = os.environ.get("VALLE_DEFAULT_MODEL_NAME", "ar+nar-len-llama-8.sft")
13 | DEFAULT_MODEL_DIR = Path(__file__).parent.parent.parent / 'data/models'
14 | DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / DEFAULT_MODEL_NAME
15 | DEFAULT_MODEL_URLS = {
16 | 'ar+nar-len-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/ar%2Bnar-len-llama-8/ckpt/fp32.sft',
17 | 'nemo-larger-44khz-llama-8.sft': 'https://huggingface.co/ecker/vall-e/resolve/main/models/ckpt/nemo-larger-44khz-llama-8/fp32.sft',
18 | 'wavlm_large_finetune.pth': 'https://huggingface.co/Dongchao/UniAudio/resolve/main/wavlm_large_finetune.pth',
19 | }
20 |
21 | if not DEFAULT_MODEL_PATH.exists() and Path(f"./data/models/{DEFAULT_MODEL_NAME}").exists():
22 | DEFAULT_MODEL_DIR = Path('./data/models')
23 | DEFAULT_MODEL_PATH = DEFAULT_MODEL_DIR / DEFAULT_MODEL_NAME
24 |
25 | # kludge, probably better to use HF's model downloader function
26 | # to-do: write to a temp file then copy so downloads can be interrupted
27 | def download_model( save_path=DEFAULT_MODEL_PATH, chunkSize = 1024 ):
28 | name = save_path.name
29 | url = DEFAULT_MODEL_URLS[name] if name in DEFAULT_MODEL_URLS else None
30 | if url is None:
31 | raise Exception(f'Model requested for download but not defined: {name}')
32 |
33 | if not save_path.parent.exists():
34 | save_path.parent.mkdir(parents=True, exist_ok=True)
35 |
36 | headers = {}
37 | # check if modified
38 | if save_path.exists():
39 | headers = {"If-Modified-Since": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(save_path.stat().st_mtime))}
40 |
41 | r = requests.get(url, headers=headers, stream=True)
42 |
43 | # not modified
44 | if r.status_code == 304:
45 | r.close()
46 | return
47 |
48 | # to-do: validate lengths match
49 |
50 | content_length = int(r.headers['Content-Length'] if 'Content-Length' in r.headers else r.headers['content-length'])
51 | with open(save_path, 'wb') as f:
52 | bar = tqdm( unit='B', unit_scale=True, unit_divisor=1024, total=content_length, desc=f"Downloading: {name}" )
53 | for chunk in r.iter_content(chunk_size=chunkSize):
54 | if not chunk:
55 | continue
56 | bar.update( len(chunk))
57 | f.write(chunk)
58 | bar.close()
59 |
60 | r.close()
61 |
62 |
63 | def get_model(config, training=True, **model_kwargs):
64 | # crunge
65 | if config.version < 7:
66 | from .ar_nar import AR_NAR
67 | ModelClass = AR_NAR
68 | else:
69 | from .ar_nar_v2 import AR_NAR_V2
70 | ModelClass = AR_NAR_V2
71 |
72 | cfg_kwargs = dict(
73 | n_phn_tokens=config.phoneme_tokens,
74 | n_audio_tokens=config.audio_tokens,
75 | n_text_tokens=config.text_tokens,
76 | d_model=config.dim,
77 | n_heads=config.heads,
78 | n_layers=config.layers,
79 | n_experts=config.experts,
80 |
81 | p_dropout=config.dropout,
82 |
83 | l_padding = config.input_alignment,
84 |
85 | training = training,
86 | config = config,
87 | )
88 |
89 | name = config.name
90 | model = ModelClass(**(cfg_kwargs | model_kwargs))
91 |
92 | _logger.info(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
93 |
94 | return model
95 |
96 | def get_models(models, training=True, **model_kwargs):
97 | return { model.full_name: get_model(model, training=training, **model_kwargs) for model in models }
98 |
--------------------------------------------------------------------------------
/vall_e.cpp/include/decoder.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <vector>
4 |
5 | #include "ggml.h"
6 | #include "ggml-alloc.h"
7 | #include "ggml-backend.h"
8 |
9 | #include "lstm.h"
10 | #include "utils.h"
11 |
12 |
13 | struct encodec_decoder_block {
14 | // upsampling layers
15 | struct ggml_tensor *us_conv_w;
16 | struct ggml_tensor *us_conv_b;
17 |
18 | // conv1
19 | struct ggml_tensor *conv_1_w;
20 | struct ggml_tensor *conv_1_b;
21 |
22 | // conv2
23 | struct ggml_tensor *conv_2_w;
24 | struct ggml_tensor *conv_2_b;
25 |
26 | // shortcut
27 | struct ggml_tensor *conv_sc_w;
28 | struct ggml_tensor *conv_sc_b;
29 | };
30 |
31 | struct encodec_decoder {
32 | struct ggml_tensor *init_conv_w;
33 | struct ggml_tensor *init_conv_b;
34 |
35 | encodec_lstm lstm;
36 |
37 | struct ggml_tensor *final_conv_w;
38 | struct ggml_tensor *final_conv_b;
39 |
40 | std::vector<encodec_decoder_block> blocks;
41 | };
42 |
43 | struct ggml_tensor *encodec_forward_decoder(
44 | const struct encodec_decoder *decoder, struct ggml_context *ctx0,
45 | struct ggml_tensor *quantized_out, const int *ratios, const int kernel_size, const int res_kernel_size,
46 | const int stride) {
47 |
48 | if (!quantized_out) {
49 | fprintf(stderr, "%s: null input tensor\n", __func__);
50 | return NULL;
51 | }
52 |
53 | struct ggml_tensor *inpL = strided_conv_1d(
54 | ctx0, quantized_out, decoder->init_conv_w, decoder->init_conv_b, stride);
55 |
56 | // lstm
57 | {
58 | struct ggml_tensor *cur = inpL;
59 |
60 | const encodec_lstm lstm = decoder->lstm;
61 |
62 | // first lstm layer
63 | char l0_prefix[7] = "dec_l0";
64 | struct ggml_tensor *hs1 = forward_pass_lstm_unilayer(
65 | ctx0, cur, lstm.l0_ih_w, lstm.l0_hh_w, lstm.l0_ih_b, lstm.l0_hh_b, l0_prefix);
66 |
67 | // second lstm layer
68 | char l1_prefix[7] = "dec_l1";
69 | struct ggml_tensor *out = forward_pass_lstm_unilayer(
70 | ctx0, hs1, lstm.l1_ih_w, lstm.l1_hh_w, lstm.l1_ih_b, lstm.l1_hh_b, l1_prefix);
71 |
72 | inpL = ggml_add(ctx0, inpL, out);
73 | }
74 |
75 | for (int layer_ix = 0; layer_ix < 4; layer_ix++) {
76 | encodec_decoder_block block = decoder->blocks[layer_ix];
77 |
78 | // upsampling layers
79 | inpL = ggml_elu(ctx0, inpL);
80 |
81 | inpL = strided_conv_transpose_1d(
82 | ctx0, inpL, block.us_conv_w, block.us_conv_b, ratios[layer_ix]);
83 |
84 | struct ggml_tensor *current = inpL;
85 |
86 | // shortcut
87 | struct ggml_tensor *shortcut = strided_conv_1d(
88 | ctx0, inpL, block.conv_sc_w, block.conv_sc_b, stride);
89 |
90 | // conv1
91 | current = ggml_elu(ctx0, current);
92 |
93 | current = strided_conv_1d(
94 | ctx0, current, block.conv_1_w, block.conv_1_b, stride);
95 |
96 | // conv2
97 | current = ggml_elu(ctx0, current);
98 |
99 | current = strided_conv_1d(
100 | ctx0, current, block.conv_2_w, block.conv_2_b, stride);
101 |
102 | // residual connection
103 | inpL = ggml_add(ctx0, current, shortcut);
104 | }
105 |
106 | // final conv
107 | inpL = ggml_elu(ctx0, inpL);
108 |
109 | struct ggml_tensor *decoded_inp = strided_conv_1d(
110 | ctx0, inpL, decoder->final_conv_w, decoder->final_conv_b, stride);
111 |
112 | return decoded_inp;
113 | }
114 |
--------------------------------------------------------------------------------
/vall_e.cpp/include/espeak-ng/encoding.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2017 Reece H. Dunn
3 | *
4 | * This program is free software; you can redistribute it and/or modify
5 | * it under the terms of the GNU General Public License as published by
6 | * the Free Software Foundation; either version 3 of the License, or
7 | * (at your option) any later version.
8 | *
9 | * This program is distributed in the hope that it will be useful,
10 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | * GNU General Public License for more details.
13 | *
14 | * You should have received a copy of the GNU General Public License
15 | * along with this program; if not, see: <http://www.gnu.org/licenses/>.
16 | */
17 | #ifndef ESPEAK_NG_ENCODING_H
18 | #define ESPEAK_NG_ENCODING_H
19 |
20 | #include <stdint.h>
21 |
22 | #ifdef __cplusplus
23 | extern "C"
24 | {
25 | #endif
26 |
27 | typedef enum
28 | {
29 | ESPEAKNG_ENCODING_UNKNOWN,
30 | ESPEAKNG_ENCODING_US_ASCII,
31 | ESPEAKNG_ENCODING_ISO_8859_1,
32 | ESPEAKNG_ENCODING_ISO_8859_2,
33 | ESPEAKNG_ENCODING_ISO_8859_3,
34 | ESPEAKNG_ENCODING_ISO_8859_4,
35 | ESPEAKNG_ENCODING_ISO_8859_5,
36 | ESPEAKNG_ENCODING_ISO_8859_6,
37 | ESPEAKNG_ENCODING_ISO_8859_7,
38 | ESPEAKNG_ENCODING_ISO_8859_8,
39 | ESPEAKNG_ENCODING_ISO_8859_9,
40 | ESPEAKNG_ENCODING_ISO_8859_10,
41 | ESPEAKNG_ENCODING_ISO_8859_11,
42 | // ISO-8859-12 is not a valid encoding.
43 | ESPEAKNG_ENCODING_ISO_8859_13,
44 | ESPEAKNG_ENCODING_ISO_8859_14,
45 | ESPEAKNG_ENCODING_ISO_8859_15,
46 | ESPEAKNG_ENCODING_ISO_8859_16,
47 | ESPEAKNG_ENCODING_KOI8_R,
48 | ESPEAKNG_ENCODING_ISCII,
49 | ESPEAKNG_ENCODING_UTF_8,
50 | ESPEAKNG_ENCODING_ISO_10646_UCS_2,
51 | } espeak_ng_ENCODING;
52 |
53 | ESPEAK_NG_API espeak_ng_ENCODING
54 | espeak_ng_EncodingFromName(const char *encoding);
55 |
56 | typedef struct espeak_ng_TEXT_DECODER_ espeak_ng_TEXT_DECODER;
57 |
58 | ESPEAK_NG_API espeak_ng_TEXT_DECODER *
59 | create_text_decoder(void);
60 |
61 | ESPEAK_NG_API void
62 | destroy_text_decoder(espeak_ng_TEXT_DECODER *decoder);
63 |
64 | ESPEAK_NG_API espeak_ng_STATUS
65 | text_decoder_decode_string(espeak_ng_TEXT_DECODER *decoder,
66 | const char *string,
67 | int length,
68 | espeak_ng_ENCODING encoding);
69 |
70 | ESPEAK_NG_API espeak_ng_STATUS
71 | text_decoder_decode_string_auto(espeak_ng_TEXT_DECODER *decoder,
72 | const char *string,
73 | int length,
74 | espeak_ng_ENCODING encoding);
75 |
76 | ESPEAK_NG_API espeak_ng_STATUS
77 | text_decoder_decode_wstring(espeak_ng_TEXT_DECODER *decoder,
78 | const wchar_t *string,
79 | int length);
80 |
81 | ESPEAK_NG_API espeak_ng_STATUS
82 | text_decoder_decode_string_multibyte(espeak_ng_TEXT_DECODER *decoder,
83 | const void *input,
84 | espeak_ng_ENCODING encoding,
85 | int flags);
86 |
87 | ESPEAK_NG_API int
88 | text_decoder_eof(espeak_ng_TEXT_DECODER *decoder);
89 |
90 | ESPEAK_NG_API uint32_t
91 | text_decoder_getc(espeak_ng_TEXT_DECODER *decoder);
92 |
93 | ESPEAK_NG_API uint32_t
94 | text_decoder_peekc(espeak_ng_TEXT_DECODER *decoder);
95 |
96 | ESPEAK_NG_API const void *
97 | text_decoder_get_buffer(espeak_ng_TEXT_DECODER *decoder);
98 |
99 | #ifdef __cplusplus
100 | }
101 | #endif
102 |
103 | #endif
104 |
--------------------------------------------------------------------------------
/scripts/train_tokenizer.py:
--------------------------------------------------------------------------------
1 | """
2 | # Helper script to grab all phonemes through parsed dataset metadata to find the "best" tokenizer dict
3 | """
4 |
5 | import os
6 | import json
7 | import torch
8 | import torchaudio
9 |
10 | from tqdm.auto import tqdm
11 | from pathlib import Path
12 |
13 | from tokenizers import Tokenizer
14 | from tokenizers.models import BPE, Unigram, WordLevel, WordPiece
15 | from tokenizers.trainers import BpeTrainer
16 | from tokenizers.pre_tokenizers import Whitespace
17 | from tokenizers.processors import TemplateProcessing
18 |
19 | from vall_e.config import cfg
20 | from vall_e.utils.io import json_read
21 | from vall_e.emb.g2p import coerce_to_hiragana
22 |
23 | input_metadata = "training/metadata/"
24 |
25 | output_file = Path("./training/tokenizer_pretraining_data.json")
26 | tokenizer_data = []
27 |
28 | def pad(num, zeroes):
29 | return str(num).zfill(zeroes+1)
30 |
31 | def add( dir, type="training", audios=True, texts=True ):
32 | name = str(dir)
33 | name = name.replace(str(cfg.data_dir), "")
34 | speaker_name = name
35 | """
36 | if "LibriTTS-R" in speaker_name:
37 | speaker_name = speaker_name.replace("LibriTTS-R", "LibriVox")
38 | """
39 |
40 | metadata_path = cfg.metadata_dir / f'{speaker_name}.json'
41 | metadata = json_read( metadata_path, default={} )
42 |
43 | for k, entry in metadata.items():
44 | if "text" not in entry:
45 | continue
46 |
47 | language = entry.get('language','auto')
48 | text = entry['text']
49 | tokenizer_data.append( text )
50 |
51 | if output_file.exists():
52 | tokenizer_data = json.loads(open(str(output_file), "r", encoding="utf-8").read())
53 | else:
54 | # training
55 | for data_dir in tqdm(sorted(cfg.dataset.training), desc="Processing Training"):
56 | try:
57 | add( data_dir, type="training" )
58 | except Exception as e:
59 | pass
60 |
61 | # validation
62 | for data_dir in tqdm(sorted(cfg.dataset.validation), desc='Processing Validation'):
63 | try:
64 | add( data_dir, type="validation" )
65 | except Exception as e:
66 | pass
67 | """
68 | for dataset_name in os.listdir(f'./{input_metadata}/'):
69 | if not os.path.isdir(f'./{input_metadata}/{dataset_name}/'):
70 | continue
71 |
72 | for speaker_id in tqdm(os.listdir(f'./{input_metadata}/{dataset_name}/'), desc="Processing speaker"):
73 | if not os.path.isdir(f'./{input_metadata}/{dataset_name}/{speaker_id}'):
74 | continue
75 |
76 | for id in os.listdir(f'./{input_metadata}/{dataset_name}/{speaker_id}/'):
77 | if ".json" not in id:
78 | continue
79 |
80 | metadata_path = Path(f'./{input_metadata}/{dataset_name}/{speaker_id}/{id}')
81 | metadata = json.loads(open(metadata_path, "r", encoding="utf-8").read())
82 |
83 | if "text" not in metadata:
84 | continue
85 |
86 | tokenizer_data.append( f'{"".join(metadata["text"])}' )
87 |
88 | open(output_file, 'w', encoding='utf-8').write(json.dumps(tokenizer_data))
89 | """
90 |
91 | unk_token = "<unk>"
92 | spl_tokens = [unk_token, "<bos>", "<eos>", "<pad>", "<mask>"]
93 |
94 | trainer = BpeTrainer(special_tokens = spl_tokens, vocab_size = 32768, max_token_length=1, min_frequency=len(tokenizer_data))
95 | tokenizer = Tokenizer(BPE(unk_token = unk_token))
96 | tokenizer.pre_tokenizer = Whitespace() # takes 2 hours to process without this, we'll just manually add spaces as a token
97 | tokenizer.post_processor = TemplateProcessing(
98 | single="<bos> $A <eos>",
99 | special_tokens=[("<bos>", 1), ("<eos>", 2)],
100 | )
101 |
102 | tokenizer.train_from_iterator(tokenizer_data, trainer=trainer)
103 | tokenizer.save("./training/tokenizer_training_data.json")
--------------------------------------------------------------------------------
/vall_e.cpp/vall_e-impl.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | // stores all the backend stuff
4 |
5 | // external deps
6 | #include
7 | #include
8 | #include
9 | #include
10 |
11 | #define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
12 | #define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch (or using LLAMA arch)
13 |
14 | #if !LLAMA_CPP_EXTENDED
15 | #include "llama_hack.h" // cringe hotfix but I have to do this until llama.cpp's API exposes the tok_embd
16 | #endif
17 |
18 | // to-do: clean up spaghetti enums
19 | const int EMBEDDING_MODE_PROM = 0;
20 | const int EMBEDDING_MODE_RESP_AR_NAR = 1;
21 | const int EMBEDDING_MODE_RESP_NAR_LEN = 2;
22 |
23 | const int INFERENCE_MODE_LEN = 0;
24 | const int INFERENCE_MODE_AR = 1;
25 | const int INFERENCE_MODE_NAR_DEMASK = 2;
26 | const int INFERENCE_MODE_NAR = 3;
27 |
28 | // stores metadata for inputs/outputs
29 | struct io_t {
30 | std::string name;
31 | uint32_t start;
32 | uint32_t end;
33 | int32_t head_idx = -1;
34 |
35 | int32_t n_embd = 0;
36 | int32_t n_vocab = 0;
37 |
38 | std::vector<float> embds = {};
39 | ggml_tensor* head = NULL;
40 | };
41 |
42 | // stores the mappings between tokens, input embeddings, and output heads
43 | struct io_map_t {
44 | // model's original params
45 | int32_t n_embd = 0;
46 | int32_t n_vocab = 0;
47 |
48 | // mapping
49 | std::unordered_map<std::string, io_t> io = {};
50 | // context to store slices
51 | ggml_context* ctx = NULL;
52 | };
53 | // used for top-k (mainly for demasking)
54 | struct score_t {
55 | int32_t idx;
56 | float value;
57 |
58 | bool operator<( const score_t& that ) const { return this->value < that.value; }
59 | };
60 | // handles storing metadata for token merges
61 | struct merge_entry_t {
62 | std::u32string pre;
63 | std::u32string post;
64 | std::u32string resolved;
65 |
66 | token_t pre_token;
67 | token_t post_token;
68 | token_t resolved_token;
69 | };
70 |
71 | // helper tensor functions
72 | std::vector<float> read_2d_tensor( struct ggml_tensor* tensor );
73 | //ggml_tensor* view_2d_tensor( ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 ); // cringe method to keep in my pocket
74 | ggml_tensor* view_2d_tensor( ggml_context* ctx, ggml_tensor* tensor, int32_t start, int32_t end, int32_t dim = 0 );
75 | void print_tokens( const std::vector<token_t>& tokens, const std::string& prefix = "Tokens: " );
76 |
77 | std::vector<std::vector<float>> map_embeddings( const std::vector<token_t>& tokens, int n_embd, const float* embds );
78 | std::vector<std::vector<float>> sum_embeddings( const vall_e_audio_codes_t& input, int n_embd, int rvq_l, const float** embds, int mode = EMBEDDING_MODE_PROM );
79 | std::vector<float> soft_max( int n_logits, const float* logits );
80 |
81 | // batch and inferencing
82 | void batch_add( llama_batch& batch, token_t id, int n_embd, const float* embds, llama_pos pos, bool output, const std::vector<llama_seq_id> & seq_ids = {0} );
83 | void fill_batch( llama_batch& batch, vall_e_inputs_t& input, io_map_t& inputs_map, int mode );
84 | std::vector<token_t> generate( vall_e_context_t* ctx, vall_e_inputs_t& input, int max_tokens, int mode, bool verbose = true );
85 |
86 | // (handles text)
87 | std::vector<token_t> phonemize( vall_e_context_t* ctx, const std::string& text, const std::string& language = "auto" );
88 |
89 | // model-accessing helpers
90 | const io_t& vall_e_inputs_map_get_embeddings( io_map_t& inputs_map, const std::string& name );
91 | const float* vall_e_inputs_map_get_embeddings_p( io_map_t& inputs_map, const std::string& name );
92 | int32_t vall_e_inputs_map_get_classifier_idx( io_map_t& inputs_map, const std::string& name );
93 | void vall_e_inputs_map_init( io_map_t&, llama_model* model );
--------------------------------------------------------------------------------
/vall_e/utils/ext/unsloth.py:
--------------------------------------------------------------------------------
1 | # lifted from https://gist.github.com/pszemraj/e88ff24ab296b6d89057376b299b368a
2 | # to-do: make this work with LoRAs, it complains
3 |
4 | # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 | import torch
19 | import transformers
20 | import inspect
21 |
22 |
23 | class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
24 | """
25 | Saves VRAM by smartly offloading to RAM.
26 | Tiny hit to performance, since we mask the movement via non blocking calls.
27 | """
28 |
29 | @staticmethod
30 | @torch.cuda.amp.custom_fwd
31 | def forward(ctx, forward_function, hidden_states, *args):
32 | saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
33 | with torch.no_grad():
34 | output = forward_function(hidden_states, *args)
35 | ctx.save_for_backward(saved_hidden_states)
36 | ctx.forward_function = forward_function
37 | ctx.args = args
38 |
39 | return output
40 |
41 | pass
42 |
43 | @staticmethod
44 | @torch.cuda.amp.custom_bwd
45 | def backward(ctx, dY):
46 | (hidden_states,) = ctx.saved_tensors
47 | hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
48 | hidden_states.requires_grad = True
49 | with torch.enable_grad():
50 | (output,) = ctx.forward_function(hidden_states, *ctx.args)
51 | torch.autograd.backward(output, dY)
52 | return (
53 | None,
54 | hidden_states.grad,
55 | ) + (
56 | None,
57 | ) * len(ctx.args)
58 |
59 | pass
60 |
61 |
62 | pass
63 |
64 |
65 | def new_gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
66 | #assert gradient_checkpointing_kwargs == None
67 | gradient_checkpointing_kwargs = None
68 | if not self.supports_gradient_checkpointing:
69 | raise ValueError(
70 | f"{self.__class__.__name__} does not support gradient checkpointing."
71 | )
72 |
73 | gradient_checkpointing_func = Unsloth_Offloaded_Gradient_Checkpointer.apply
74 | # For old GC format (transformers < 4.35.0) for models that live on the Hub
75 | # we will fall back to the overwritten `_set_gradient_checkpointing` method
76 | _is_using_old_format = (
77 | "value" in inspect.signature(self._set_gradient_checkpointing).parameters
78 | )
79 |
80 | if not _is_using_old_format:
81 | self._set_gradient_checkpointing(
82 | enable=True, gradient_checkpointing_func=gradient_checkpointing_func
83 | )
84 | else:
85 | raise NotImplementedError()
86 |
87 | if getattr(self, "_hf_peft_config_loaded", False):
88 | # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
89 | # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
90 | # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
91 | # the gradients to make sure the gradient flows.
92 | self.enable_input_require_grads()
93 |
94 |
95 | def apply_unsloth_offloaded_gradient_checkpoint_monkey_patch():
96 | transformers.modeling_utils.PreTrainedModel.gradient_checkpointing_enable = (
97 | new_gradient_checkpointing_enable
98 | )
--------------------------------------------------------------------------------
/vall_e.cpp/include/quantizer.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <cmath>
4 | #include <vector>
5 |
6 | #include "ggml.h"
7 | #include "ggml-alloc.h"
8 | #include "ggml-backend.h"
9 |
10 | #include "utils.h"
11 |
12 | struct encodec_quant_block {
13 | struct ggml_tensor *embed;
14 | };
15 |
16 | struct encodec_quantizer {
17 | std::vector<encodec_quant_block> blocks;
18 | };
19 |
20 | struct ggml_tensor *encodec_forward_quantizer_encode(
21 | const struct encodec_quantizer *quantizer, struct ggml_context *ctx0,
22 | struct ggml_tensor *encoded_inp, const int n_bins, const int sr, const int bandwidth,
23 | const int hop_length) {
24 |
25 | if (!encoded_inp) {
26 | fprintf(stderr, "%s: null input tensor\n", __func__);
27 | return NULL;
28 | }
29 |
30 | const int frame_rate = (int)ceilf(sr / hop_length);
31 | const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth);
32 |
33 | const int seq_length = encoded_inp->ne[0];
34 |
35 | struct ggml_tensor *codes = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, seq_length, n_q);
36 | ggml_set_input(codes);
37 |
38 | struct ggml_tensor *inpL = ggml_cont(ctx0, ggml_transpose(ctx0, encoded_inp));
39 | struct ggml_tensor *residual = inpL;
40 | struct ggml_tensor *indices;
41 |
42 | for (int i = 0; i < n_q; i++) {
43 | encodec_quant_block block = quantizer->blocks[i];
44 |
45 | // compute distance
46 | // [seq_length, n_bins]
47 | struct ggml_tensor *dp = ggml_scale(
48 | ctx0, ggml_mul_mat(ctx0, block.embed, residual), -2.0f);
49 |
50 | // [n_bins]
51 | struct ggml_tensor *sqr_embed = ggml_sqr(ctx0, block.embed);
52 | struct ggml_tensor *sqr_embed_nrm = ggml_sum_rows(ctx0, sqr_embed);
53 |
54 | // [seq_length]
55 | struct ggml_tensor *sqr_inp = ggml_sqr(ctx0, residual);
56 | struct ggml_tensor *sqr_inp_nrm = ggml_sum_rows(ctx0, sqr_inp);
57 |
58 | // [seq_length, n_bins]
59 | struct ggml_tensor *dist = ggml_add(ctx0, ggml_repeat(ctx0, sqr_inp_nrm, dp), dp);
60 | dist = ggml_add(ctx0, ggml_repeat(ctx0, ggml_transpose(ctx0, sqr_embed_nrm), dist), dist);
61 | dist = ggml_neg(ctx0, dist);
62 |
63 | // take the argmax over the column dimension
64 | // [seq_length]
65 | indices = ggml_argmax(ctx0, dist);
66 |
67 | // look up in embedding table
68 | struct ggml_tensor *quantized = ggml_get_rows(ctx0, block.embed, indices);
69 |
70 | residual = ggml_sub(ctx0, residual, quantized);
71 |
72 | codes = ggml_set_1d(ctx0, codes, indices, i * codes->nb[1]);
73 | }
74 |
75 | return codes;
76 | }
77 |
78 | struct ggml_tensor *encodec_forward_quantizer_decode(
79 | const struct encodec_quantizer *quantizer, struct ggml_context *ctx0,
80 | struct ggml_tensor *codes, const int hidden_dim, const int n_bins, const int sr, const int bandwidth,
81 | const int hop_length) {
82 |
83 | if (!codes) {
84 | fprintf(stderr, "%s: null input tensor\n", __func__);
85 | return NULL;
86 | }
87 |
88 | const int seq_length = codes->ne[0];
89 |
90 | const int frame_rate = (int)ceilf(sr / hop_length);
91 | const int n_q = get_num_quantizers_for_bandwidth(n_bins, frame_rate, bandwidth);
92 |
93 | assert(n_q == codes->ne[1]);
94 |
95 | struct ggml_tensor *quantized_out = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hidden_dim, seq_length);
96 | ggml_set_input(quantized_out);
97 | ggml_set_name(quantized_out, "quantized_out");
98 |
99 | for (int i = 0; i < n_q; i++) {
100 | encodec_quant_block block = quantizer->blocks[i];
101 |
102 | struct ggml_tensor *indices = ggml_view_1d(ctx0, codes, seq_length, i * codes->nb[1]);
103 | struct ggml_tensor *quantized = ggml_get_rows(ctx0, block.embed, indices);
104 |
105 | quantized_out = ggml_add(ctx0, quantized_out, quantized);
106 | }
107 |
108 | quantized_out = ggml_cont(ctx0, ggml_transpose(ctx0, quantized_out));
109 |
110 | return quantized_out;
111 | }
112 |
--------------------------------------------------------------------------------
/vall_e/utils/io.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 |
4 | from pathlib import Path
5 | from safetensors import safe_open as sft_load
6 | from safetensors.torch import save_file as sft_save
7 |
8 | try:
9 | import orjson as json
10 | use_orjson = True
11 | except:
12 | use_orjson = False
13 |
14 | from .utils import truncate_json
15 |
16 | def json_stringify( data, truncate=False, pretty=False, raw=False ):
17 | if truncate:
18 | return truncate_json( json.dumps( data ) )
19 | if pretty:
20 | s = json.dumps( data, option=json.OPT_INDENT_2 ) if use_orjson else json.dumps( data, indent='\t' )
21 | return s if raw and use_orjson else s.decode('utf-8')
22 | return json.dumps( data )
23 |
24 | def json_parse( string ):
25 | return json.loads( string )
26 |
27 | def json_read( path, default=None ):
28 | path = coerce_path( path )
29 |
30 | if not path.exists():
31 | return default
32 |
33 | with (open( str(path), "rb" ) if use_orjson else open( str(path), "r", encoding="utf-8" ) ) as f:
34 | return json_parse( f.read() )
35 |
36 | def json_write( data, path, **kwargs ):
37 | path = coerce_path( path )
38 |
39 | with (open( str(path), "wb" ) if use_orjson else open( str(path), "w", encoding="utf-8" ) ) as f:
40 | f.write( json_stringify( data, raw=use_orjson, **kwargs ) )
41 |
42 | def coerce_path( path ):
43 | return path if isinstance( path, Path ) else Path(path)
44 |
45 | def pick_path( path, *suffixes ):
46 | suffixes = [*suffixes]
47 |
48 | for suffix in suffixes:
49 | p = path.with_suffix( suffix )
50 | if p.exists():
51 | return p
52 |
53 | return path
54 |
55 | def is_dict_of( d, t ):
56 | if not isinstance( d, dict ):
57 | return False
58 |
59 | return all([ isinstance(v, t) for v in d.values() ])
60 |
61 | # handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors
62 | def state_dict_to_tensor_metadata( data: dict, module_key=None ):
63 | metadata = {}
64 |
65 | # is a state_dict, no need to coerce
66 | if is_dict_of( data, torch.Tensor ):
67 | return data, metadata
68 |
69 | # is maybe a dict with a state dict + metadata, coerce it
70 | target = module_key
71 | if not target:
72 | for k, v in data.items():
73 | # is a dict of tensors, our target
74 | if is_dict_of( v, torch.Tensor ):
75 | target = k
76 | continue # continue to iterate to grab other metadata
77 |
78 | # not a dict of tensors, put it as metadata
79 | try:
80 | metadata[k] = json_stringify(v) if any([isinstance( v, dict ), isinstance( v, list )]) else v
81 |
82 | if isinstance( metadata[k], bytes ):
83 | metadata[k] = metadata[k].decode('utf-8')
84 | except Exception as e:
85 | pass
86 |
87 | if not target:
88 | raise Exception('Requesting to save safetensors of a state dict, but the state dict contains no dict of torch.Tensor values')
89 |
90 | return data[target], metadata
91 |
92 | def torch_save( data, path, module_key=None ):
93 | path = coerce_path(path)
94 | ext = path.suffix
95 |
96 | if ext in [".safetensor", ".safetensors", ".sft"]:
97 | data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )
98 |
99 | if metadata is None:
100 | metadata = {}
101 |
102 | return sft_save( data, path, { k: v for k, v in metadata.items() if v is not None } )
103 |
104 | return torch.save( data, path )
105 |
106 | def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module" ):
107 | path = coerce_path(path)
108 | ext = path.suffix
109 |
110 | if ext in [".safetensor", ".safetensors", ".sft"]:
111 | state_dict = {}
112 | with sft_load(path, framework=framework, device=device) as f:
113 | for k in f.keys():
114 | state_dict[k] = f.get_tensor(k)
115 |
116 | if load_metadata:
117 | metadata = f.metadata()
118 | for k, v in metadata.items():
119 | try:
120 | metadata[k] = json.loads( v )
121 | except Exception as e:
122 | pass
123 | state_dict = { module_key: state_dict } | metadata
124 |
125 | return state_dict
126 |
127 | return torch.load( path, map_location=torch.device(device), weights_only=not unsafe )
--------------------------------------------------------------------------------
/vall_e/plot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import argparse
4 | import json
5 | import time
6 | import re
7 | from pathlib import Path
8 |
9 | import matplotlib.pyplot as plt
10 | import pandas as pd
11 |
12 | from .config import cfg
13 |
14 | def plot(paths, args):
15 | dfs = []
16 |
17 | for path in paths:
18 | with open(path, "r") as f:
19 | text = f.read()
20 |
21 | rows = []
22 |
23 | pattern = r"(\{.+?\})\.\n"
24 |
25 | for row in re.findall(pattern, text, re.DOTALL):
26 | try:
27 | row = json.loads(row)
28 | except Exception as e:
29 | continue
30 |
31 | for model in args.models:
32 | if f'{model.name}.{args.xs}' not in row:
33 | continue
34 | rows.append(row)
35 | break
36 |
37 | df = pd.DataFrame(rows)
38 |
39 | if "name" in df:
40 | df["name"] = df["name"].fillna("train")
41 | else:
42 | df["name"] = "train"
43 |
44 | df["group"] = str(path.parents[args.group_level])
45 | df["group"] = df["group"] + "/" + df["name"]
46 |
47 | dfs.append(df)
48 |
49 | df = pd.concat(dfs)
50 |
51 | if args.min_x is not None:
52 | for model in args.models:
53 | df = df[args.min_x < df[f'{model.name}.{args.xs}']]
54 |
55 | if args.max_x is not None:
56 | for model in args.models:
57 | df = df[df[f'{model.name}.{args.xs}'] < args.max_x]
58 |
59 | for gtag, gdf in sorted(
60 | df.groupby("group"),
61 | key=lambda p: (p[0].split("/")[-1], p[0]),
62 | ):
63 | for model in args.models:
64 | x = f'{model.name}.{args.xs}'
65 | for ys in args.ys:
66 | y = f'{model.name}.{ys}'
67 |
68 | if gdf[y].isna().all():
69 | continue
70 |
71 | if args.min_y is not None:
72 | gdf = gdf[args.min_y < gdf[y]]
73 | if args.max_y is not None:
74 | gdf = gdf[gdf[y] < args.max_y]
75 |
76 | if args.ewm:
77 | gdf[y] = gdf[y].ewm(args.ewm).mean()
78 | elif args.rolling:
79 | gdf[y] = gdf[y].rolling(args.rolling).mean()
80 |
81 | gdf.plot(
82 | x=x,
83 | y=y,
84 | label=f"{y}",
85 | ax=plt.gca(),
86 | marker="x" if len(gdf) < 100 else None,
87 | alpha=0.7,
88 | )
89 |
90 | plt.gca().legend(
91 | #loc="center left",
92 | fancybox=True,
93 | shadow=True,
94 | #bbox_to_anchor=(1.04, 0.5),
95 | )
96 |
97 | def plot_sample_metrics( metrics, filename=None ):
98 | """
99 | fig = plt.figure()
100 | fig.set_figwidth( 16 * len(metrics) // cfg.dataset.frames_per_second )
101 | """
102 |
103 | data = { key: [ e[0][key] for e in metrics ] for key in metrics[0][0].keys() }
104 |
105 | df = pd.DataFrame(data)
106 | df.plot()
107 |
108 | plt.gca().legend(
109 | #loc="center left",
110 | fancybox=True,
111 | shadow=True,
112 | #bbox_to_anchor=(1.04, 0.5),
113 | )
114 |
115 | if not filename:
116 | filename = f'{time.time()}.png'
117 |
118 | out_path = cfg.rel_path / "metrics" / filename
119 | out_path.parent.mkdir(parents=True, exist_ok=True)
120 | plt.savefig(out_path, bbox_inches="tight")
121 |
122 | if __name__ == "__main__":
123 | parser = argparse.ArgumentParser()
124 | parser.add_argument("--xs", default="engine_step")
125 | parser.add_argument("--ys", nargs="+", default="")
126 | parser.add_argument("--model", nargs="+", default="*")
127 |
128 | parser.add_argument("--min-x", type=float, default=-float("inf"))
129 | parser.add_argument("--min-y", type=float, default=-float("inf"))
130 | parser.add_argument("--max-x", type=float, default=float("inf"))
131 | parser.add_argument("--max-y", type=float, default=float("inf"))
132 |
133 | parser.add_argument("--ewm", type=int, default=1024)
134 | parser.add_argument("--rolling", type=int, default=None)
135 |
136 | parser.add_argument("--size", type=str, default=None)
137 |
138 | parser.add_argument("--filename", default="log.txt")
139 | 	parser.add_argument("--group-level", type=int, default=1)
140 | args, unknown = parser.parse_known_args()
141 |
142 | path = cfg.rel_path / "logs"
143 | paths = path.rglob(f"./*/{args.filename}")
144 |
145 | args.models = [ model for model in cfg.model.get() if model.training and (args.model == "*" or model.name in args.model) ]
146 |
147 | if args.ys == "":
148 | args.ys = ["loss.nll"]
149 |
150 | if args.size:
151 | width, height = args.size.split("x")
152 | plt.figure(figsize=(int(width), int(height)))
153 |
154 | plot(paths, args)
155 |
156 | out_path = cfg.rel_path / "metrics.png"
157 | plt.savefig(out_path, bbox_inches="tight")
--------------------------------------------------------------------------------
/vall_e.cpp/include/llama-vocab.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "llama.h"
4 |
5 | #include <string>
6 | #include <vector>
7 | #include <memory>
8 |
9 | struct LLM_KV;
10 | struct llama_model_loader;
11 |
12 | struct llama_vocab {
13 | struct token_data {
14 | std::string text;
15 | float score;
16 | llama_token_attr attr;
17 | };
18 |
19 | llama_vocab();
20 | ~llama_vocab();
21 |
22 | void load(llama_model_loader & ml, const LLM_KV & kv);
23 |
24 | enum llama_vocab_type get_type() const;
25 | enum llama_vocab_pre_type get_pre_type() const;
26 |
27 | uint32_t n_tokens() const;
28 | uint32_t n_token_types() const;
29 |
30 | std::string type_name() const;
31 |
32 | bool is_normal (llama_token id) const;
33 | bool is_unknown (llama_token id) const;
34 | bool is_control (llama_token id) const;
35 | bool is_byte (llama_token id) const;
36 | bool is_user_defined(llama_token id) const;
37 | bool is_unused (llama_token id) const;
38 | bool is_eog (llama_token id) const;
39 |
40 | uint8_t token_to_byte(llama_token id) const;
41 | llama_token byte_to_token(uint8_t ch) const;
42 |
43 | llama_token text_to_token(const std::string & text) const;
44 |
45 | const token_data & get_token_data(llama_token id) const;
46 |
47 | const char * token_get_text (llama_token id) const;
48 | float token_get_score(llama_token id) const;
49 | llama_token_attr token_get_attr (llama_token id) const;
50 |
51 | llama_token token_bos() const;
52 | llama_token token_eos() const;
53 | llama_token token_eot() const;
54 | llama_token token_eom() const;
55 | llama_token token_unk() const;
56 | llama_token token_sep() const;
57 | llama_token token_nl () const;
58 | llama_token token_pad() const;
59 |
60 | llama_token token_prefix() const;
61 | llama_token token_middle() const;
62 | llama_token token_suffix() const;
63 |
64 | llama_token token_fim_pre() const;
65 | llama_token token_fim_suf() const;
66 | llama_token token_fim_mid() const;
67 | llama_token token_fim_pad() const;
68 | llama_token token_fim_rep() const;
69 | llama_token token_fim_sep() const;
70 |
71 | bool get_add_space_prefix () const;
72 | bool get_add_bos () const;
73 | bool get_add_eos () const;
74 | bool get_ignore_merges () const;
75 | bool get_clean_spaces () const;
76 | bool get_remove_extra_whitespaces () const;
77 | bool get_escape_whitespaces () const;
78 | bool get_treat_whitespace_as_suffix() const;
79 |
80 | int max_token_len() const;
81 |
82 | int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
83 |
84 | int32_t tokenize(
85 | const char * text,
86 | int32_t text_len,
87 | llama_token * tokens,
88 | int32_t n_tokens_max,
89 | bool add_special,
90 | bool parse_special) const;
91 |
92 | 	std::vector<llama_token> tokenize(
93 | const std::string & raw_text,
94 | bool add_special,
95 | bool parse_special = false) const;
96 |
97 | // does not write null-terminator to buf
98 | int32_t token_to_piece(
99 | llama_token token,
100 | char * buf,
101 | int32_t length,
102 | int32_t lstrip,
103 | bool special) const;
104 |
105 | // use cached data
106 | const std::string & token_to_piece(llama_token token) const;
107 |
108 | int32_t detokenize(
109 | const llama_token * tokens,
110 | int32_t n_tokens,
111 | char * text,
112 | int32_t text_len_max,
113 | bool remove_special,
114 | bool unparse_special) const;
115 |
116 | std::string detokenize(
117 | 			const std::vector<llama_token> & tokens,
118 | bool special) const;
119 |
120 | void print_info() const;
121 |
122 | private:
123 | struct impl;
124 | 	std::unique_ptr<impl> pimpl;
125 | };
126 |
--------------------------------------------------------------------------------
/vall_e/emb/g2p.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import string
4 | import torch
5 |
6 | from functools import cache
7 | from pathlib import Path
8 | from phonemizer import phonemize
9 | from phonemizer.backend import BACKENDS
10 |
11 | from tqdm import tqdm
12 |
13 | try:
14 | import pykakasi
15 | except Exception as e:
16 | pykakasi = None
17 | print(f'Error while importing pykakasi: {str(e)}')
18 | pass
19 |
20 | try:
21 | import langdetect
22 | except Exception as e:
23 | langdetect = None
24 | print(f'Error while importing langdetect: {str(e)}')
25 |
26 | def detect_language( text ):
27 | if not text:
28 | return "en" # to-do: map to a null language
29 |
30 | if langdetect is None:
31 | raise Exception('langdetect is not installed.')
32 | return langdetect.detect( text )
33 |
34 | def _get_graphs(path):
35 | with open(path, "r") as f:
36 | graphs = f.read()
37 | return graphs
38 |
39 | def coerce_to_hiragana( runes, sep="" ):
40 | if pykakasi is None:
41 | raise Exception('pykakasi is not installed.')
42 |
43 | kks = pykakasi.kakasi()
44 | result = kks.convert( runes )
45 | return sep.join([ res['hira'] for res in result ])
46 |
47 | def coerce_language( lang ):
48 | # bottle of water vs bo'oh'o'wa'er
49 | if lang == "en":
50 | lang = "en-us"
51 | # quebec probably
52 | if lang == "fr":
53 | return "fr-fr"
54 | # phonemizer/espeak used to have zh refer to mandarin, but was renamed to cmn
55 | # cmn outputs cringe, but not cmn-latn-pinyin
56 | # also just coerces any of the dialects into this (to avoid crimes)
57 | if lang[:2] == "zh":
58 | return "cmn-latn-pinyin"
59 | """
60 | things to consider in the future
61 | en-uk or en-gb
62 | es-la vs es-es
63 | pt-br vs pt-pt
64 | """
65 | return lang
66 |
67 | cached_backends = {}
68 | def _get_backend( language="en-us", backend="espeak", punctuation=True, stress=True, strip=True ):
69 | key = f'{language}_{backend}'
70 | if key in cached_backends:
71 | return cached_backends[key]
72 |
73 | if backend == 'espeak':
74 | phonemizer = BACKENDS[backend]( language, preserve_punctuation=punctuation, with_stress=stress)
75 | elif backend == 'espeak-mbrola':
76 | phonemizer = BACKENDS[backend]( language )
77 | else:
78 | phonemizer = BACKENDS[backend]( language, preserve_punctuation=punctuation )
79 |
80 | cached_backends[key] = phonemizer
81 | return phonemizer
82 |
83 |
84 | def encode(text: str, language="auto", backend="auto", punctuation=True, stress=True, strip=True) -> list[str]:
85 | if language == "auto":
86 | language = detect_language( text )
87 |
88 | language = coerce_language( language )
89 |
90 | #
91 | if backend == "auto":
92 | # Convert to hiragana, as espeak does not like kanji
93 | if language[:2] == "ja":
94 | text = coerce_to_hiragana( text )
95 |
96 | # "zh" => "cmn-latn-pinyin"
97 | elif language == "zh":
98 | language = "cmn-latn-pinyin"
99 |
100 |
101 | if not backend or backend == "auto":
102 | backend = "espeak" # if language[:2] != "en" else "festival"
103 |
104 | backend = _get_backend(language=language, backend=backend, stress=stress, strip=strip, punctuation=punctuation)
105 | if backend is not None:
106 | phonemes = backend.phonemize( [ text ], strip=strip )
107 | else:
108 | phonemes = phonemize( [ text ], language=language, strip=strip, preserve_punctuation=punctuation, with_stress=stress )
109 |
110 | if not len(phonemes):
111 | raise Exception(f"Failed to phonemize, received empty string: {text}")
112 |
113 | phonemes = phonemes[0]
114 |
115 | # remap tones
116 | # technically they can be kept in place and just update the tokenizer, but this would be a bit confusing
117 | if language == "cmn-latn-pinyin":
118 | tones = {
119 | "1": "ˇ",
120 | "2": "ˉ",
121 | "3": "ˊ",
122 | "4": "ˋ",
123 | "5": "_",
124 | }
125 | for k, v in tones.items():
126 | phonemes = phonemes.replace(k, v)
127 |
128 | return phonemes
129 |
130 | # Helper function to debug phonemizer
131 | if __name__ == "__main__":
132 | parser = argparse.ArgumentParser()
133 |
134 | parser.add_argument("string", type=str)
135 | parser.add_argument("--language", type=str, default="en-us")
136 | parser.add_argument("--backend", type=str, default="auto")
137 | parser.add_argument("--no-punctuation", action="store_true")
138 | parser.add_argument("--no-stress", action="store_true")
139 | parser.add_argument("--no-strip", action="store_true")
140 |
141 | args = parser.parse_args()
142 |
143 | phonemes = encode( args.string, language=args.language, backend=args.backend, punctuation=not args.no_punctuation, stress=not args.no_stress, strip=not args.no_strip )
144 | print( phonemes )
--------------------------------------------------------------------------------
/vall_e.cpp/include/ggml-cann.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2023-2024 The ggml authors
3 | *
4 | * Permission is hereby granted, free of charge, to any person obtaining a copy
5 | * of this software and associated documentation files (the "Software"), to
6 | * deal in the Software without restriction, including without limitation the
7 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8 | * sell copies of the Software, and to permit persons to whom the Software is
9 | * furnished to do so, subject to the following conditions:
10 | *
11 | * The above copyright notice and this permission notice shall be included in
12 | * all copies or substantial portions of the Software.
13 | *
14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20 | * IN THE SOFTWARE.
21 | */
22 |
23 | #pragma once
24 |
25 | #include "ggml-backend.h"
26 | #include "ggml.h"
27 |
28 | #ifdef __cplusplus
29 | extern "C" {
30 | #endif
31 |
32 | /**
33 | * @brief Maximum number of CANN devices supported.
34 | */
35 | #define GGML_CANN_MAX_DEVICES 16
36 |
37 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cann_reg(void);
38 |
39 | /**
40 | * @brief Initializes the CANN backend for a specified device.
41 | *
42 | * This function initializes the CANN backend for the given device.
43 | * It verifies the device index, allocates a context, and creates a backend
44 | * instance.
45 | *
46 | * @param device The index of the device to initialize.
47 | * @return A pointer to the initialized backend instance, or nullptr on failure.
48 | */
49 | GGML_BACKEND_API ggml_backend_t ggml_backend_cann_init(int32_t device);
50 |
51 | /**
52 | * @brief Checks if a given backend is a CANN backend.
53 | *
54 | * This function verifies if the provided backend is a CANN backend by comparing
55 | * its GUID with the CANN backend's GUID.
56 | *
57 | * @param backend The backend instance to check.
58 | * @return True if the backend is a CANN backend, false otherwise.
59 | */
60 | GGML_BACKEND_API bool ggml_backend_is_cann(ggml_backend_t backend);
61 |
62 | /**
63 | * @brief Retrieves the CANN buffer type for a specified device.
64 | *
65 | * This function initializes and returns the buffer type interface associated
66 | * with the given device. It ensures thread-safe access using a mutex.
67 | *
68 | * @param device The device index for which to retrieve the buffer type.
69 | * @return A pointer to the buffer type interface for the specified device, or
70 | * nullptr if the device index is out of range.
71 | */
72 | GGML_BACKEND_API ggml_backend_buffer_type_t
73 | ggml_backend_cann_buffer_type(int32_t device);
74 |
75 | /**
76 | * @brief Retrieves the number of CANN devices available.
77 | *
78 | * This function returns the number of CANN devices available based on
79 | * information obtained from `ggml_cann_info()`.
80 | *
81 | * @return The number of CANN devices available.
82 | */
83 | GGML_BACKEND_API int32_t ggml_backend_cann_get_device_count(void);
84 |
85 | /**
86 | * @brief pinned host buffer for use with the CPU backend for faster copies between CPU and NPU.
87 | *
88 | * @return A pointer to the host buffer type interface.
89 | */
90 | GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void);
91 |
92 | /**
93 | * @brief Retrieves the description of a specific CANN device.
94 | *
95 | * This function sets the specified device, retrieves the SoC name,
96 | * and writes it into the provided description buffer.
97 | *
98 | * @param device The device index to retrieve the description for.
99 | * @param description Pointer to a buffer where the description will be written.
100 | * @param description_size Size of the description buffer.
101 | */
102 | GGML_BACKEND_API void ggml_backend_cann_get_device_description(
103 | int32_t device, char* description, size_t description_size);
104 |
105 | /**
106 | * @brief Retrieves the memory information of a specific CANN device.
107 | *
108 | * This function sets the specified device, retrieves the free and total
109 | * memory information of the specified type (ACL_HBM_MEM), and stores them
110 | * in the provided pointers.
111 | *
112 | * @param device The device index to retrieve memory information for.
113 | * @param free Pointer to a variable where the free memory size will be stored.
114 | * @param total Pointer to a variable where the total memory size will be
115 | * stored.
116 | */
117 | GGML_BACKEND_API void ggml_backend_cann_get_device_memory(int32_t device,
118 | size_t* free,
119 | size_t* total);
120 |
121 | #ifdef __cplusplus
122 | }
123 | #endif
124 |
--------------------------------------------------------------------------------
/docs/inferenece.md:
--------------------------------------------------------------------------------
1 | # `inference.py`
2 |
3 | This script handles the higher-level functions for inferencing the model for various tasks for the end user.
4 |
5 | For invoking this model from another Python package, refer to `webui.py` and `demo.py` for how to use this outside of this scope; a short example is also given at the end of this document.
6 |
7 | `__main__.py` invokes this according to the below arguments.
8 |
9 | ## Synthesis
10 |
11 | To synthesize speech: `python -m vall_e <text> <ref_path> --yaml=<yaml_path>` (or `--model=<model_path>`)
12 |
13 | Some additional flags you can pass are:
14 | * `--language`: specifies the language to guide inferencing with, when the model is trained against that language. Use `auto` to automatically deduce this.
15 | * `--text-language`: the language to phonemize the input text under. Leave blank to tie it to the above value.
16 | * `--task`: task to perform. Defaults to `tts`, but accepts `stt` for transcriptions.
17 | * `--max-duration`: maximum token-duration for inferencing through the AR aspect of the model. Every second corresponds to 75 steps.
18 | * `--max-steps`: maximum steps for inferencing through the NAR-len aspect of the model.
19 | * `--device`: device to use (default: `cuda`, examples: `cuda:0`, `cuda:1`, `cpu`)
20 | * `--ar-temperature`: sampling temperature to use for the AR/NAR pass. 0 enables greedy sampling.
21 | * For the AR, ~1.0 is *fine*, but lowering the temperature adheres better to the prosody of the input prompt.
22 | * For the AR, low temperatures require a repetition penalty to prevent outputs from degenerating.
23 | * For the NAR, greedy sampling is best, but can be raised to 0.2.
24 | * `--input-prompt-length`: the duration of the input prompt (~6 seconds is fine, longer durations lead to slower generations for "better" accuracy). 0 does not repeat/trim.
25 | * If a prompt is shorter than the given duration, it's repeated to the duration size.
26 |
27 | And some experimental sampling flags you can use too (your mileage will ***definitely*** vary, but most of these are bandaids for a bad AR):
28 | * `--input-prompt-prefix`: (AR only) treats the input prompt as the initial response prefix, but...
29 | * the transcription of the prompt needs to be in the input text prompt.
30 |   doesn't perform all that well (I believe the model needs to be trained a bit on this, as `tts-c`).
31 | * `--min-temperature`: triggers the dynamic temperature pathway, adjusting the temperature based on the confidence of the best token. Acceptable values are between `[0.0, (n)ar-temperature)`.
32 | + This simply uplifts the [original implementation](https://github.com/kalomaze/koboldcpp/blob/dynamic-temp/llama.cpp#L5132) to perform it.
33 | * `--top-p`: limits the sampling pool to top sum of values that equal `P`% probability in the probability distribution.
34 | * `--top-k`: limits the sampling pool to the top `K` values in the probability distribution.
35 | * `--min-p`: only logits above `P`% probability are considered for sampling (or something, I'm still unsure how this differs from top-p).
36 | * `--repetition-penalty`: modifies the probability of tokens if they have appeared before. In the context of audio generation, this is a very iffy parameter to use.
37 | * `--repetition-penalty-decay`: modifies the above factor applied to scale based on how far away it is in the past sequence.
38 | * `--length-penalty`: (AR only) modifies the probability of the stop token based on the current sequence length. This is ***very*** finnicky due to the AR already being well correlated with the length.
39 | * `--beam-width`: (AR only) specifies the number of branches to search through for beam sampling.
40 | + This is a very naive implementation that's effectively just greedy sampling across `B` spaces.
41 | * `--mirostat-tau`: (AR only) the "surprise value" when performing mirostat sampling.
42 | + This simply uplifts the [original implementation](https://github.com/basusourya/mirostat/blob/master/mirostat.py) to perform it.
43 | + **!**NOTE**!**: This is incompatible with beam search sampling (for the meantime at least).
44 | * `--mirostat-eta`: (AR only) the "learning rate" during mirostat sampling applied to the maximum surprise.
45 | * `--dry-multiplier`: (AR only) performs DRY sampling, the scalar factor.
46 | * `--dry-base`: (AR only) for DRY sampling, the base of the exponent factor.
47 | * `--dry-allowed-length`: (AR only) for DRY sampling, the window to perform DRY sampling within.
48 |
49 | Some arguments can be prefixed with `ar-` and `nar-` to only use that setting for the respective pass. At the moment, through the CLI, this includes:
50 | * `temperature`
51 |
52 | ### Speech-to-Text
53 |
54 | The `ar+nar-tts+stt-llama-8` (now the reference model) model has received additional training for a speech-to-text task against EnCodec-encoded audio.
55 |
56 | Currently, the model only transcribes back into the IPA phonemes it was trained against, as an additional model or external program is required to translate the IPA phonemes back into text.
57 | * this does make a model that can phonemize text, and unphonemize text, more desirable in the future to replace espeak (having an additional task to handle this requires additional embeddings, output heads, and possible harm to the model as actual text is not a modality the model is trained on).
58 | * it seems to really want to only transcribe the first sentence for a given utterance. I imagine this is simply a problem with how it was trained.
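59 | 
60 | ## Example Usage
61 | 
62 | An illustrative CLI invocation (the text and paths are placeholders): `python -m vall_e "The birch canoe slid on the smooth planks." ./reference.wav --language auto --ar-temperature 0.8 --top-p 0.9 --out-path ./output.wav`
63 | 
64 | For invoking the model from another Python package, the sketch below simply mirrors what `__main__.py` does; treat it as a minimal sketch, as the exact keyword arguments accepted by `TTS` and `TTS.inference` may differ between versions, and the config/audio paths are placeholders.
65 | 
66 | ```python
67 | from pathlib import Path
68 | from vall_e.inference import TTS
69 | 
70 | # point config at a model checkpoint (--model) or a YAML (--yaml)
71 | tts = TTS( config=Path("./training/config.yaml"), device="cuda", amp=True )
72 | 
73 | output = tts.inference(
74 | 	text="The birch canoe slid on the smooth planks.",
75 | 	references=[ Path("./reference.wav") ],
76 | 	language="auto",
77 | 	task="tts",
78 | 	out_path=Path("./output.wav"),
79 | 	sampling_kwargs=dict( ar_temperature=0.8, top_p=0.9 ),
80 | )
81 | ```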
--------------------------------------------------------------------------------
/docs/samplers.md:
--------------------------------------------------------------------------------
1 | # `sampler.py`
2 |
3 | This script contains all the samplers used during inferencing.
4 |
5 | While I do expose these samplers for end-user use, I don't like to rely on these, as exotic samplers are always bandaids to the underlying model.
6 |
7 | Most of these sampler functions do what's written on the tin, but for clarity:
8 |
9 | ## Samplers
10 |
11 | When sampling, the output logits are picked for sampling according to the current inference mode. For the AR, only the last token (or last `causal_size` tokens) are used for sampling, while the NAR relies on the previous sequence to determine how many tokens to sample in parallel.
12 |
13 | As the model is trained more, low temperatures are preferred over high temperatures for the AR, while greedy sampling is almost always preferred for the NAR.
14 |
15 | Greedy sampling is enabled when the sampling temperature is <= 0, where the most likely token is picked.
16 |
17 | ### Repetition Penalty
18 |
19 | This function (`reptition_penalize`) applies a penalty to target logits to avoid repetitive output.
20 |
21 | This is implemented by penalizing tokens in the future from repeating the currently iterated token.
22 | * This distinction is required to penalize for the NAR, while the AR only penalizes the single token being inferenced.
23 |
24 | An optional value can also be passed to factor in how far away that token is.
25 |
26 | Implicitly, this is only limited to 75 tokens in the past (one second of audio under EnCodec), and will apply more than once.
27 |
28 | For low temperatures this is almost necessary, as having no repetition penalty will leave the output garbled or a mess, and a very low repetition penalty will leave the output unstable.
29 |
30 | ### Length Penalty
31 |
32 | This function (`length_penalize`) applies a penalty to the audio stop token (or any other specific token) based on the current length of the sequence.
33 |
34 | This can be either negative or positive, to either encourage or inhibit the stop token from appearing.
35 |
36 | ### Ban Tokens
37 |
38 | This function (`ban_tokens`) bans a token from appearing.
39 |
40 | Since this is an audio LM, there's no useful purpose for this.
41 |
42 | However, for some models, this is useful for banning the stop token used for the AR, when sampling output from the NAR, if the classifier / LM head / output projections are shared between the two.
43 |
44 | ### Top-K / Top-P
45 |
46 | This function (`top_k_top_p_filtering`) filters the logits to only allow the top-K probability of tokens to be sampled, and/or the top-P probable tokens to be sampled.
47 |
48 | This may be helpful with higher temperatured sampling to offer some variety, but not allow outputs to be *too* chaotic, in theory.
49 |
50 | ### Min-P
51 |
52 | This function (`min_p_filtering`) filters out tokens that are under the min-P% probability.
53 |
54 | ### Dynamic Temperature
55 |
56 | This function (`dynamic_temperature`) implements an early version of dynamic temperature per [this external PR](https://github.com/LostRuins/koboldcpp/pull/464).
57 |
58 | To reiterate, this is an early implementation, as I recall it changing after I had already implemented it.
59 |
60 | In theory, this allows the model to sample under higher temperatures when able, but I still need to test this more as the model receives more training.
61 |
62 | ### Mirostat
63 |
64 | This function (`mirostat_sample`) implements mirostat sampling. From what I understand, this modifies the logits based on a "surprise" factor.
65 |
66 | This may be an early implementation, as this was implemented a while back.
67 |
68 | This *sometimes* helps the output a lot for some states of the model, but I don't try to rely on this too much.
69 |
70 | ### DRY Sampling
71 |
72 | This function (`dry_sampling`) implements DRY sampling, a replacement to naive repetition penalizing.
73 |
74 | I'm still not too sure what's so good about it, since it just seems like rep-pen with a different coat of paint, and for audio it doesn't seem to be too helpful?
75 |
76 | ### Entropix
77 |
78 | This function (`sample_entropix`) implements entropix sampling, a sampler that aids in Chain-of-Thought for text LLMs by adjusting sampling parameters according to the logits and attentions' entropy and varentropy.
79 |
80 | The huge caveat is that this requires tuning the parameters and thresholds per model, and in testing it doesn't seem like the metrics are consistent enough to rely on this. Acquiring the right attention scores is pretty much a dark art in its own right, as it does not map perfectly to naive attention, much less any other attention mechanism, under `transformers`' LLaMA.
81 |
82 | Additionally, one state requires injecting a CoT token, which doesn't have an analog in the audio domain.
83 |
84 | However, this does seem to serve as a good basis to expand upon this and sample according to the entropy/varentropy of the model's current state.
85 |
86 | ### Classifier-Free Guidance
87 |
88 | While this isn't a direct sampler type used, a helper function is provided to perform classifier-free guidance, given positive (the primary) logits and negative (the null) logits. While the `NAR-len` modality requires this at the moment, it can easily be adapted to everything else.
89 |
90 | Rescaling is also applied to avoid clipping the logits.
91 |
92 | Due to the logits being the full sequence, and the input lengths differing, a list of lengths is required to be passed to only modify the last N logits; a sketch of the general recipe follows below.
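93 | 
94 | The following is a minimal, illustrative sketch rather than the repo's actual helper (notably, the per-sequence length bookkeeping mentioned above is omitted); the `cfg_strength` / `cfg_rescale` names are simply borrowed from the CLI flags.
95 | 
96 | ```python
97 | import torch
98 | 
99 | def guide( positive: torch.Tensor, negative: torch.Tensor, cfg_strength: float, cfg_rescale: float = 0.75 ) -> torch.Tensor:
100 | 	# push the positive (conditioned) logits away from the negative (null) logits
101 | 	merged = negative + ( positive - negative ) * cfg_strength
102 | 	if cfg_rescale <= 0:
103 | 		return merged
104 | 	# rescale back toward the positive logits' statistics to avoid clipping the logits
105 | 	rescaled = merged * ( positive.std() / merged.std() )
106 | 	return cfg_rescale * rescaled + ( 1.0 - cfg_rescale ) * merged
107 | ```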
--------------------------------------------------------------------------------
/vall_e/engines/deepspeed.py:
--------------------------------------------------------------------------------
1 | """
2 | # https://github.com/enhuiz/pytorch-training-utilities
3 | """
4 |
5 | # to-do: replace this
6 | # to-do: swap out deepspeed
7 |
8 | from ..config import cfg
9 | from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
10 |
11 | import logging
12 | import time
13 | import torch
14 | import torch.distributed
15 |
16 | from torch import Tensor
17 | from torch.distributed import all_reduce
18 | from typing import Any, Protocol
19 |
20 | from .base import TrainFeeder
21 |
22 | _logger = logging.getLogger(__name__)
23 |
24 | from deepspeed import DeepSpeedEngine, DeepSpeedConfig, comm as dist, init_distributed as init_deepspeed_dist
25 | from deepspeed.accelerator import get_accelerator
26 |
27 | from ..utils.distributed import init_distributed, distributed_initialized
28 | from ..utils import ml
29 |
30 | from ..models.lora import freeze_non_lora_weights
31 |
32 | if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
33 | init_distributed(init_deepspeed_dist)
34 |
35 | class Engine(DeepSpeedEngine):
36 | def __init__(self, *args, **kwargs):
37 | self.hyper_config = kwargs.pop('hyper_config', None)
38 |
39 | kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
40 | kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
41 |
42 | stats = {
43 | "global_step": 0,
44 | "micro_step": 0,
45 | "global_samples": 0,
46 | "tokens_processed": 0,
47 | }
48 |
49 | # kwargs['stats'] = None will return None when popped
50 | maybe_stats = kwargs.pop('stats', stats)
51 | if maybe_stats is not None:
52 | stats = maybe_stats
53 |
54 | super().__init__(None, *args, **kwargs)
55 |
56 | self.global_steps = stats["global_step"]
57 | self.micro_steps = stats["micro_step"]
58 | self.global_samples = stats["global_samples"]
59 | self.tokens_processed = stats["tokens_processed"]
60 |
61 | self._frozen_params = set()
62 | self.current_batch_size = 0
63 |
64 | def freeze(self, freeze_all=True):
65 | # freeze non-LoRA params if requested
66 | if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
67 | frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
68 | for param in frozen_params:
69 | self._frozen_params.add( param )
70 |
71 | return
72 |
73 | if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
74 | raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
75 |
76 | for name, param in self.module.named_parameters():
77 | if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
78 | param.requires_grad_(False)
79 | self._frozen_params.add(param)
80 |
81 | def unfreeze(self):
82 | for param in self._frozen_params:
83 | param.requires_grad_(True)
84 | self._frozen_params.clear()
85 |
86 | @property
87 | def _training(self):
88 | return self.hyper_config.training
89 |
90 | @property
91 | def _teacher(self):
92 | return self.hyper_config.teacher
93 |
94 | @property
95 | def global_step(self):
96 | return self.global_steps
97 |
98 | @property
99 | def micro_step(self):
100 | return self.micro_steps
101 |
102 | @property
103 | def batch_size(self):
104 | return self.current_batch_size if self.current_batch_size > 0 else cfg.hyperparameters.batch_size
105 |
106 | def gather_attribute(self, *args, **kwargs):
107 | return gather_attribute(self.module, *args, **kwargs)
108 |
109 | def dispatch_attribute(self, *args, **kwargs):
110 | return dispatch_attribute(self.module, *args, **kwargs)
111 |
112 | def set_lr(self, lr):
113 | try:
114 | if hasattr(self.optimizer, 'param_groups'):
115 | for param_group in self.optimizer.param_groups:
116 | param_group["d_coeff" if "d_coeff" in param_group else "lr"] = lr
117 | else:
118 | self.optimizer.set_lr(lr)
119 | except Exception as e:
120 | _logger.warning(str(e))
121 |
122 | # cur_scale, because _get_loss_scale has a typo in the def and I can't be assed to inject a fix into it or push a PR
123 | def get_loss_scale(self):
124 | if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None:
125 | return 1.0
126 |
127 | return self.optimizer.cur_scale
128 |
129 | def set_loss_scale(self, value):
130 | if not hasattr(self.optimizer, "cur_scale") or self.optimizer.cur_scale is None:
131 | return
132 |
133 | self.optimizer.cur_scale = value
134 |
135 | # we'll just have to live with the LoRA weights living within our main weights
136 | # they're easy to extract anyways
137 | def load_checkpoint(self, load_dir, **kwargs ):
138 | # override to load the lora instead
139 | if cfg.lora is not None:
140 | load_dir = cfg.ckpt_dir / cfg.lora.full_name
141 |
142 | return super().load_checkpoint( load_dir, **kwargs )
143 |
144 | def save_checkpoint(self, save_dir, **kwargs ):
145 | # override to save the lora instead
146 | if cfg.lora is not None:
147 | save_dir = cfg.ckpt_dir / cfg.lora.full_name
148 |
149 | return super().save_checkpoint( save_dir, **kwargs )
150 |
151 | def traverse(self, *args, **kwargs):
152 | with ml.autocast():
153 | self.forward(*args, **kwargs)
154 |
155 | losses = self.gather_attribute("loss")
156 | loss = torch.stack([*losses.values()]).sum()
157 |
158 | stats = {}
159 | stats |= {k: v.item() for k, v in losses.items()}
160 | stats |= self.gather_attribute("scalar")
161 |
162 | """
163 | if torch.isnan(loss).any():
164 | self.max_nan_losses = self.max_nan_losses - 1
165 | if self.max_nan_losses < 0:
166 | raise RuntimeError("Too many NaN losses detected.")
167 |
168 | return stats
169 | """
170 |
171 | self.backward(loss)
172 | self.step()
173 |
174 | return stats
--------------------------------------------------------------------------------
/vall_e/emb/codecs/dac.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math, tqdm # used below (math.ceil, tqdm.trange)
3 | from dac import DACFile
4 | from audiotools import AudioSignal
5 | from dac.utils import load_model as load_dac_model
6 |
7 | from typing import Union
8 | from pathlib import Path
9 | """
10 | Patch decode to skip things related to the metadata (namely the waveform trimming)
11 | So far it seems the raw waveform can just be returned without any post-processing
12 | A smart implementation would just reuse the values from the input prompt
13 | """
14 | from dac.model.base import CodecMixin
15 | from ...config import cfg # cfg.inference.dtype / cfg.inference.amp are referenced below
16 | @torch.no_grad()
17 | def CodecMixin_compress(
18 | self,
19 | audio_path_or_signal: Union[str, Path, AudioSignal],
20 | win_duration: float = 1.0,
21 | verbose: bool = False,
22 | normalize_db: float = -16,
23 | n_quantizers: int = None,
24 | ) -> DACFile:
25 | """Processes an audio signal from a file or AudioSignal object into
26 | discrete codes. This function processes the signal in short windows,
27 | using constant GPU memory.
28 |
29 | Parameters
30 | ----------
31 | audio_path_or_signal : Union[str, Path, AudioSignal]
32 | audio signal to reconstruct
33 | win_duration : float, optional
34 | window duration in seconds, by default 5.0
35 | verbose : bool, optional
36 | by default False
37 | normalize_db : float, optional
38 | normalize db, by default -16
39 |
40 | Returns
41 | -------
42 | DACFile
43 | Object containing compressed codes and metadata
44 | required for decompression
45 | """
46 | audio_signal = audio_path_or_signal
47 | if isinstance(audio_signal, (str, Path)):
48 | audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
49 |
50 | self.eval()
51 | original_padding = self.padding
52 | original_device = audio_signal.device
53 |
54 | audio_signal = audio_signal.clone()
55 | original_sr = audio_signal.sample_rate
56 |
57 | resample_fn = audio_signal.resample
58 | loudness_fn = audio_signal.loudness
59 |
60 | # If audio is > 10 minutes long, use the ffmpeg versions
61 | if audio_signal.signal_duration >= 10 * 60 * 60:
62 | resample_fn = audio_signal.ffmpeg_resample
63 | loudness_fn = audio_signal.ffmpeg_loudness
64 |
65 | original_length = audio_signal.signal_length
66 | resample_fn(self.sample_rate)
67 | input_db = loudness_fn()
68 |
69 | if normalize_db is not None:
70 | audio_signal.normalize(normalize_db)
71 | audio_signal.ensure_max_of_audio()
72 |
73 | nb, nac, nt = audio_signal.audio_data.shape
74 | audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
75 | win_duration = (
76 | audio_signal.signal_duration if win_duration is None else win_duration
77 | )
78 |
79 | if audio_signal.signal_duration <= win_duration:
80 | # Unchunked compression (used if signal length < win duration)
81 | self.padding = True
82 | n_samples = nt
83 | hop = nt
84 | else:
85 | # Chunked inference
86 | self.padding = False
87 | # Zero-pad signal on either side by the delay
88 | audio_signal.zero_pad(self.delay, self.delay)
89 | n_samples = int(win_duration * self.sample_rate)
90 | # Round n_samples to nearest hop length multiple
91 | n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
92 | hop = self.get_output_length(n_samples)
93 |
94 | codes = []
95 | range_fn = range if not verbose else tqdm.trange
96 |
97 | for i in range_fn(0, nt, hop):
98 | x = audio_signal[..., i : i + n_samples]
99 | x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
100 |
101 | audio_data = x.audio_data.to(self.device)
102 | audio_data = self.preprocess(audio_data, self.sample_rate)
103 | with torch.autocast("cuda", dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
104 | _, c, _, _, _ = self.encode(audio_data, n_quantizers)
105 | codes.append(c.to(original_device))
106 | chunk_length = c.shape[-1]
107 |
108 | codes = torch.cat(codes, dim=-1)
109 |
110 | dac_file = DACFile(
111 | codes=codes,
112 | chunk_length=chunk_length,
113 | original_length=original_length,
114 | input_db=input_db,
115 | channels=nac,
116 | sample_rate=original_sr,
117 | padding=self.padding,
118 | dac_version="1.0.0",
119 | #dac_version=SUPPORTED_VERSIONS[-1],
120 | )
121 |
122 | if n_quantizers is not None:
123 | codes = codes[:, :n_quantizers, :]
124 |
125 | self.padding = original_padding
126 | return dac_file
127 |
128 | @torch.no_grad()
129 | def CodecMixin_decompress(
130 | self,
131 | obj: Union[str, Path, DACFile],
132 | verbose: bool = False,
133 | ) -> AudioSignal:
134 | self.eval()
135 | if isinstance(obj, (str, Path)):
136 | obj = DACFile.load(obj)
137 |
138 | original_padding = self.padding
139 | self.padding = obj.padding
140 |
141 | range_fn = range if not verbose else tqdm.trange
142 | codes = obj.codes
143 | original_device = codes.device
144 | chunk_length = obj.chunk_length
145 | recons = []
146 |
147 | for i in range_fn(0, codes.shape[-1], chunk_length):
148 | c = codes[..., i : i + chunk_length].to(self.device)
149 | z = self.quantizer.from_codes(c)[0]
150 | r = self.decode(z)
151 | recons.append(r.to(original_device))
152 |
153 | recons = torch.cat(recons, dim=-1)
154 | recons = AudioSignal(recons, self.sample_rate)
155 |
156 | # to-do, original implementation
157 | if not hasattr(obj, "dummy") or not obj.dummy:
158 | resample_fn = recons.resample
159 | loudness_fn = recons.loudness
160 |
161 | # If audio is > 10 minutes long, use the ffmpeg versions
162 | if recons.signal_duration >= 10 * 60 * 60:
163 | resample_fn = recons.ffmpeg_resample
164 | loudness_fn = recons.ffmpeg_loudness
165 |
166 | recons.normalize(obj.input_db)
167 | resample_fn(obj.sample_rate)
168 | recons = recons[..., : obj.original_length]
169 | loudness_fn()
170 | recons.audio_data = recons.audio_data.reshape(
171 | -1, obj.channels, obj.original_length
172 | )
173 | self.padding = original_padding
174 | return recons
175 |
176 | CodecMixin.compress = CodecMixin_compress
177 | CodecMixin.decompress = CodecMixin_decompress
--------------------------------------------------------------------------------
/vall_e/utils/sampler.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any
3 | import random
4 |
5 | import torch
6 | from torch.utils.data import Sampler
7 |
8 | from .distributed import global_rank, local_rank, world_size
9 |
10 | # Randomly picks an index from an array of indices
11 | class PoolSampler():
12 | def __init__( self, pool = [], keep_all = False, shuffle = False ):
13 | self.length = len(pool)
14 | self.shuffle = shuffle
15 | self.global_pool = pool if keep_all else None
16 | self.global_indices = [ i for i in range(self.length) ]
17 | self.reset()
18 |
19 | def reset(self):
20 | self.current_pool = [ i for i in self.global_indices ]
21 | if self.shuffle:
22 | random.shuffle(self.current_pool)
23 |
24 | def sample(self, pool = None):
25 | if pool is None:
26 | pool = self.global_pool
27 | # check if we need to reset
28 | index = random.choice( self.current_pool )
29 | # remove from pool
30 | self.current_pool.remove(index)
31 | # reset if needed
32 | if len(self.current_pool) == 0:
33 | self.reset()
34 | # map indices to our real values
35 | return pool[index] if pool is not None else index
36 |
37 | def __len__(self):
38 | return self.length # len(self.current_pool)
39 |
40 | def __iter__(self):
41 | while len(self.current_pool) > 0:
42 | yield self.sample()
43 |
44 | def __call__(self, *args, **kwargs):
45 | return self.sample(*args, **kwargs)
46 |
47 | def index(self):
48 | return len(self.global_indices) - len(self.current_pool)
49 |
50 | def get_state(self):
51 | return { "length": self.length, "global_pool": self.global_pool, "global_indices": self.global_indices, "current_pool": self.current_pool }
52 |
53 | def set_state(self, state):
54 | self.length = state["length"]
55 | self.global_pool = state["global_pool"]
56 | self.global_indices = state["global_indices"]
57 | self.current_pool = state["current_pool"]
58 |
59 | # "Samples" through a fixed sequence from 0 to length
60 | # Necessary for our "shuffle+sort by duration+interleave" sampling method
61 | # Allows saving and loading state
62 | class OrderedSampler(Sampler):
63 | def __init__( self, length ):
64 | self.position = 0
65 | self.length = length
66 |
67 | def __len__(self):
68 | return self.length
69 |
70 | def __iter__(self):
71 | if self.position >= self.length:
72 | self.position = 0
73 |
74 | while self.position < self.length:
75 | yield self.position
76 | self.position += 1
77 |
78 | def index(self):
79 | return self.position
80 |
81 | def get_state(self):
82 | return { "position": self.position, "length": self.length }
83 |
84 | def set_state(self, state):
85 | self.position = state["position"]
86 | self.length = state["length"]
87 |
88 | # Like the above, but will batch based on token count
89 | class BatchedOrderedSampler(Sampler):
90 | def __init__( self, buckets, max_duration=0, max_batch_size=0, shuffle=False, drop_last=True, use_max_size=True ):
91 | self.position = 0
92 | self.batches = []
93 | self.shuffle = shuffle
94 |
95 | 		assert max_duration != 0 or max_batch_size != 0, "max_duration and max_batch_size cannot both be 0"
96 |
97 | current_batch = []
98 | current_index = 0
99 | current_duration = 0
100 |
101 | for key, bucket in buckets.items():
102 | for path, duration in bucket:
103 | # flush
104 | should_flush = False
105 | if max_duration > 0 and current_duration + duration > max_duration:
106 | should_flush = True
107 | elif max_batch_size > 0 and len(current_batch) >= max_batch_size:
108 | should_flush = True
109 |
110 | if should_flush and len(current_batch) > 0:
111 | self.batches.append( current_batch )
112 | current_batch = []
113 | current_duration = 0
114 |
115 | current_batch.append( current_index )
116 | current_index += 1
117 | # as long as durations are ordered, this assertion is always true
118 | if use_max_size:
119 | current_duration = duration * len(current_batch)
120 | else:
121 | current_duration += duration
122 |
123 | if not drop_last and current_batch:
124 | self.batches.append( current_batch )
125 |
126 | if self.shuffle:
127 | random.shuffle(self.batches)
128 |
129 | def __len__(self):
130 | return len(self.batches)
131 |
132 | def __iter__(self):
133 | if self.position >= len(self.batches):
134 | self.position = 0
135 | if self.shuffle:
136 | random.shuffle(self.batches)
137 |
138 | while self.position < len(self.batches):
139 | yield self.batches[self.position]
140 | self.position += 1
141 |
142 | def index(self):
143 | return self.position
144 |
145 | def get_state(self):
146 | return { "position": self.position, "batches": self.batches }
147 |
148 | def set_state(self, state):
149 | self.position = state["position"]
150 | self.batches = state["batches"]
151 |
152 | # Randomly samples indices from a given sequence from 0 to length
153 | # Allows saving and loading state
154 | class RandomSampler(Sampler):
155 | def __init__( self, length ):
156 | self.position = 0
157 | self.length = length
158 |
159 | self.generator = torch.Generator()
160 | self.perm = torch.randperm(self.length, generator=self.generator)
161 |
162 | def __len__(self):
163 | return self.length
164 |
165 | def __iter__(self):
166 | if self.position >= self.length:
167 | self.position = 0
168 | self.perm = torch.randperm(self.length, generator=self.generator)
169 |
170 | while self.position < self.length:
171 | yield self.perm[self.position]
172 | self.position += 1
173 |
174 | def index(self):
175 | return self.position
176 |
177 | def get_state(self):
178 | return { "position": self.position, "length": self.length, "perm": self.perm, "generator": self.generator.get_state() }
179 |
180 | def set_state(self, state):
181 | self.position = state["position"]
182 | self.length = state["length"]
183 | self.perm = state["perm"]
184 | self.generator.set_state(state["generator"])
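185 | 
186 | # Illustrative usage (not part of the module): each sampler above exposes get_state()/set_state(),
187 | # so a dataloader's position can be checkpointed mid-epoch and resumed later.
188 | if __name__ == "__main__":
189 | 	sampler = OrderedSampler( 10 )
190 | 	it = iter( sampler )
191 | 	consumed = [ next( it ) for _ in range( 4 ) ] # [0, 1, 2, 3]
192 | 	state = sampler.get_state()
193 | 
194 | 	resumed = OrderedSampler( 10 )
195 | 	resumed.set_state( state )
196 | 	print( consumed, list( resumed ) ) # the resumed sampler continues from the saved position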
--------------------------------------------------------------------------------
/vall_e/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from .inference import TTS
4 | from .config import cfg
5 |
6 | def path_list(arg):
7 | if not arg:
8 | return None
9 | return [Path(p) for p in arg.split(";")]
10 |
11 | def main():
12 | parser = argparse.ArgumentParser("VALL-E TTS")
13 | parser.add_argument("text")
14 | parser.add_argument("references", type=path_list, default=None)
15 | parser.add_argument("--language", type=str, default="auto")
16 | parser.add_argument("--text-language", type=str, default=None)
17 | parser.add_argument("--task", type=str, default="tts")
18 | parser.add_argument("--modality", type=str, default="auto")
19 | parser.add_argument("--out-path", type=Path, default=None)
20 |
21 | parser.add_argument("--split-text-by", type=str, default="\n")
22 | parser.add_argument("--context-history", type=int, default=0)
23 | parser.add_argument("--no-phonemize", action='store_true')
24 |
25 | parser.add_argument("--yaml", type=Path, default=None)
26 | parser.add_argument("--model", type=Path, default=None)
27 | parser.add_argument("--lora", type=Path, default=None)
28 |
29 | parser.add_argument("--max-duration", type=int, default=12 * cfg.dataset.frames_per_second)
30 | parser.add_argument("--max-steps", type=int, default=25)
31 | parser.add_argument("--max-levels", type=int, default=7)
32 |
33 | parser.add_argument("--ar-temperature", type=float, default=1.0)
34 | parser.add_argument("--nar-temperature", type=float, default=0.0)
35 | parser.add_argument("--min-ar-temperature", type=float, default=-1.0)
36 | parser.add_argument("--min-nar-temperature", type=float, default=-1.0)
37 | parser.add_argument("--input-prompt-length", type=float, default=3.0)
38 | parser.add_argument("--input-prompt-prefix", action="store_true")
39 | parser.add_argument("--prefix-silence", type=float, default=0.0)
40 | parser.add_argument("--cfg-strength", type=float, default=0.0)
41 | parser.add_argument("--cfg-rescale", type=float, default=0.75)
42 |
43 | parser.add_argument("--top-p", type=float, default=1.0)
44 | parser.add_argument("--top-k", type=int, default=0)
45 | parser.add_argument("--top-no", type=float, default=0.0)
46 | parser.add_argument("--min-p", type=float, default=0.0)
47 | parser.add_argument("--repetition-penalty", type=float, default=1.0)
48 | parser.add_argument("--repetition-penalty-decay", type=float, default=0.0)
49 | parser.add_argument("--length-penalty", type=float, default=0.0)
50 | parser.add_argument("--beam-width", type=int, default=0)
51 |
52 | parser.add_argument("--mirostat-tau", type=float, default=0)
53 | parser.add_argument("--mirostat-eta", type=float, default=0)
54 |
55 | parser.add_argument("--dry-multiplier", type=float, default=0)
56 | parser.add_argument("--dry-base", type=float, default=1.75)
57 | parser.add_argument("--dry-allowed-length", type=int, default=2)
58 |
59 | parser.add_argument("--entropix-sampling", action="store_true")
60 |
61 | parser.add_argument("--layer-skip", action="store_true")
62 | parser.add_argument("--layer-skip-exit-layer", type=int, default=None)
63 | 	parser.add_argument("--layer-skip-entropy-threshold", type=float, default=0.1)
64 | 	parser.add_argument("--layer-skip-varentropy-threshold", type=float, default=0.1)
65 | parser.add_argument("--refine-on-stop", action="store_true")
66 |
67 | # experimental settings
68 | parser.add_argument("--load-from-artifact", type=Path, default=None)
69 | parser.add_argument("--denoise-start", type=float, default=0.0)
70 |
71 | parser.add_argument("--seed", type=int, default=None)
72 |
73 | parser.add_argument("--device", type=str, default=None)
74 | parser.add_argument("--amp", action="store_true")
75 | parser.add_argument("--dtype", type=str, default=None)
76 | parser.add_argument("--attention", type=str, default=None)
77 | parser.add_argument("--play", action="store_true")
78 | args = parser.parse_args()
79 |
80 | config = None
81 |
82 | if args.yaml:
83 | config = args.yaml
84 | elif args.model:
85 | config = args.model
86 |
87 | tts = TTS( config=config, lora=args.lora, device=args.device, dtype=args.dtype, amp=args.amp, attention=args.attention )
88 |
89 | sampling_kwargs = dict(
90 | split_text_by=args.split_text_by,
91 | context_history=args.context_history,
92 | phonemize=not args.no_phonemize,
93 | max_steps=args.max_steps,
94 | max_levels=args.max_levels,
95 | max_duration=args.max_duration,
96 | ar_temperature=args.ar_temperature, nar_temperature=args.nar_temperature,
97 | min_ar_temperature=args.min_ar_temperature, min_nar_temperature=args.min_nar_temperature,
98 | 		top_p=args.top_p, top_k=args.top_k, top_no=args.top_no, min_p=args.min_p,
99 | repetition_penalty=args.repetition_penalty, repetition_penalty_decay=args.repetition_penalty_decay,
100 | length_penalty=args.length_penalty,
101 | beam_width=args.beam_width,
102 | mirostat_tau=args.mirostat_tau, mirostat_eta=args.mirostat_eta,
103 | dry_multiplier=args.dry_multiplier, dry_base=args.dry_base, dry_allowed_length=args.dry_allowed_length,
104 | entropix_sampling=args.entropix_sampling,
105 | layer_skip=args.layer_skip,
106 | layer_skip_exit_layer=args.layer_skip_exit_layer,
107 | layer_skip_entropy_threshold=args.layer_skip_entropy_threshold,
108 | layer_skip_varentropy_threshold=args.layer_skip_varentropy_threshold,
109 | refine_on_stop=args.refine_on_stop,
110 | denoise_start=args.denoise_start,
111 | input_prompt_length=args.input_prompt_length,
112 | input_prompt_prefix=args.input_prompt_prefix,
113 | prefix_silence=args.prefix_silence,
114 | cfg_strength=args.cfg_strength,
115 | cfg_rescale=args.cfg_rescale,
116 | )
117 |
118 | output = tts.inference(
119 | text=args.text,
120 | references=args.references,
121 | text_language=args.text_language,
122 | language=args.language,
123 | task=args.task,
124 | modality=args.modality,
125 | out_path=args.out_path,
126 | play=args.play,
127 |
128 | input_prompt_length=args.input_prompt_length,
129 | load_from_artifact=args.load_from_artifact,
130 |
131 | sampling_kwargs=sampling_kwargs,
132 |
133 | seed=args.seed,
134 | )
135 |
136 | if isinstance( output, str ):
137 | print( output )
138 |
139 | if __name__ == "__main__":
140 | main()
141 |
--------------------------------------------------------------------------------
/vall_e.cpp/vall_e.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | // C++ deps
4 | #include <string>
5 | #include <vector>
6 | #include <unordered_map>
7 | 
8 | #include <llama.h>
9 |
10 | // handles defining platform specific macros and import/export decorators (copied from my engine's uf/config.h)
11 | #if defined(_WIN32) || defined(__WIN32__) || defined(__CYGWIN__)
12 | // Windows
13 | #define VALL_E_ENV "Windows"
14 | #define VALL_E_ENV_WINDOWS 1
15 | #define VALL_E_ENV_HEADER "windows.h"
16 | #if defined(__CYGWIN__)
17 | #define to_string(var) string(var)
18 | #endif
19 | #ifndef _WIN32_WINNT
20 | #define _WIN32_WINNT 0x0600
21 | #endif
22 | #ifndef WINVER
23 | #define WINVER 0x0600
24 | #endif
25 |
26 | #define VALL_E_IO_ROOT "./data/"
27 | #elif defined(linux) || defined(__linux)
28 | // Linux
29 | #define VALL_E_ENV "Linux"
30 | #define VALL_E_ENV_LINUX 1
31 | #define VALL_E_ENV_HEADER "linux.h"
32 |
33 | #define VALL_E_IO_ROOT "./data/"
34 | #elif defined(__APPLE__) || defined(MACOSX) || defined(macintosh) || defined(Macintosh)
35 | // MacOS
36 | #define VALL_E_ENV "OSX"
37 | #define VALL_E_ENV_OSX 1
38 | #define VALL_E_ENV_HEADER "osx.h"
39 |
40 | #define VALL_E_IO_ROOT "./data/"
41 | #elif defined(__FreeBSD__) || defined(__FreeBSD_kernel__)
42 | // FreeBSD
43 | #define VALL_E_ENV "FreeBSD"
44 | #define VALL_E_ENV_FREEBSD 1
45 | #define VALL_E_ENV_HEADER "freebsd.h"
46 |
47 | #define VALL_E_IO_ROOT "./data/"
48 | #elif defined(__sh__)
49 | // Dreamcast
50 | #define VALL_E_ENV "Dreamcast"
51 | #define VALL_E_ENV_DREAMCAST 1
52 | #define VALL_E_ENV_HEADER "dreamcast.h"
53 | #include VALL_E_ENV_HEADER
54 |
55 | #define _arch_dreamcast
56 |
57 | #define VALL_E_IO_ROOT "/cd/"
58 | #else
59 | // Unsupported system
60 | #define VALL_E_ENV "Unknown"
61 | #define VALL_E_ENV_UNKNOWN 1
62 | #define VALL_E_ENV_HEADER "unknown.h"
63 | #warning Using "unknown"
64 | #error No support
65 | #endif
66 |
67 | #if !defined(VALL_E_STATIC)
68 | #if defined(VALL_E_ENV_WINDOWS)
69 | // Windows compilers need specific (and different) keywords for export and import
70 | #define VALL_E_API_EXPORT __declspec(dllexport)
71 | #define VALL_E_API_IMPORT __declspec(dllimport)
72 | // For Visual C++ compilers, we also need to turn off this annoying C4251 warning
73 | #ifdef _MSC_VER
74 | #pragma warning(disable : 4251)
75 | #endif
76 | #else // Linux, FreeBSD, Mac OS X
77 | #if __GNUC__ >= 4
78 | 			// GCC 4 has special keywords for showing/hiding symbols,
79 | // the same keyword is used for both importing and exporting
80 | #define VALL_E_API_EXPORT __attribute__ ((__visibility__ ("default")))
81 | #define VALL_E_API_IMPORT __attribute__ ((__visibility__ ("default")))
82 | #else
83 | 			// GCC < 4 has no mechanism to explicitly hide symbols, everything's exported
84 | #define VALL_E_API_EXPORT
85 | #define VALL_E_API_IMPORT
86 | #endif
87 | #endif
88 | #else
89 | // Static build doesn't need import/export macros
90 | #define VALL_E_API_EXPORT
91 | #define VALL_E_API_IMPORT
92 | #endif
93 |
94 | #ifdef VALL_E_EXPORTS
95 | #define VALL_E_API VALL_E_API_EXPORT
96 | #else
97 | #define VALL_E_API VALL_E_API_IMPORT
98 | #endif
99 |
100 | typedef llama_token token_t;
101 | typedef std::vector<std::vector<token_t>> vall_e_audio_codes_t;
102 |
103 | const int ENCODEC_FRAMES_PER_SECOND = 75;
104 | const int MAX_DURATION = ENCODEC_FRAMES_PER_SECOND * 12;
105 | const int CTX_SIZE = 2048;
106 | const int N_THREADS = 8;
107 | const int N_GPU_LAYERS = 99;
108 |
109 | const int MODALITY_AR_NAR = 0;
110 | const int MODALITY_NAR_LEN = 1;
111 |
112 | // forward declarations
113 | struct io_map_t;
114 | struct llama_model;
115 | struct llama_context;
116 | struct encodec_context;
117 |
118 | // model-specific parameters
119 | struct vall_e_context_params_t {
120 | std::string model_path = "./data/vall_e.gguf";
121 | std::string encodec_path = "./data/encodec.bin";
122 | int32_t gpu_layers = N_GPU_LAYERS;
123 | int32_t n_threads = N_THREADS;
124 | int32_t ctx_size = CTX_SIZE;
125 | bool verbose = false;
126 | };
127 | // inference-specific arguments
128 | struct vall_e_args_t {
129 | std::string text = "Hello world.";
130 | std::string prompt_path = "./data/prom.wav";
131 | std::string output_path = "./data/resp.wav";
132 | std::string language = "en";
133 | std::string task = "tts";
134 | int modality = MODALITY_NAR_LEN;
135 | int max_steps = 30;
136 | int max_duration = MAX_DURATION;
137 | };
138 | // stores everything needed for vall_e.cpp at runtime
139 | struct vall_e_context_t {
140 | vall_e_context_params_t params;
141 |
142 | io_map_t* io_map = NULL; // pointer for reasons
143 |
144 | struct {
145 | llama_model* model = NULL;
146 | llama_context* ctx = NULL;
147 | } llama;
148 |
149 | struct {
150 | encodec_context* ctx;
151 | } encodec;
152 | };
153 | // stores the raw inputs to be fed
154 | struct vall_e_inputs_t {
155 | std::string task = "tts";
156 | std::string lang = "en";
157 |
158 | token_t rvq_l = 0;
159 |
160 | 	std::vector<token_t> phn = {};
161 | vall_e_audio_codes_t prom = {};
162 | vall_e_audio_codes_t resp = {};
163 | };
164 |
165 | // encodec helpers
166 | VALL_E_API std::vector<float> read_audio_from_disk( const std::string& path );
167 | VALL_E_API void write_audio_to_disk( const std::vector<float>& waveform, const std::string& path );
168 |
169 | VALL_E_API std::vector<std::vector<int32_t>> encode_audio( struct encodec_context* ectx, const std::vector<float>& waveform );
170 | VALL_E_API std::vector<float> decode_audio( struct encodec_context* ectx, const vall_e_audio_codes_t& codes_2d );
171 |
172 | // context management
173 | VALL_E_API void vall_e_print_usage( char** argv, const vall_e_context_params_t& params, const vall_e_args_t& args );
174 | VALL_E_API bool vall_e_args_parse( int argc, char** argv, vall_e_context_params_t& params, vall_e_args_t& args );
175 | VALL_E_API vall_e_context_t* vall_e_load( const vall_e_context_params_t& params );
176 | VALL_E_API vall_e_inputs_t vall_e_prepare_inputs( vall_e_context_t* ctx, const std::string& text, const std::string& prompt_path, const std::string& lang = "auto", const std::string& task = "tts" );
177 | VALL_E_API vall_e_audio_codes_t vall_e_generate( vall_e_context_t* ctx, vall_e_inputs_t& inputs, int max_steps, int max_duration, int modality = MODALITY_NAR_LEN );
178 | VALL_E_API void vall_e_free( vall_e_context_t* ctx );
179 |
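// A minimal end-to-end sketch, assuming the declarations above are wired together in call order
// (all names are taken from this header; the flow itself is an assumption):
//
//   vall_e_context_params_t params; // model_path / encodec_path defaults as above
//   vall_e_args_t args;             // text, prompt_path, output_path, language, ...
//   vall_e_context_t* ctx = vall_e_load( params );
//   vall_e_inputs_t inputs = vall_e_prepare_inputs( ctx, args.text, args.prompt_path, args.language, args.task );
//   vall_e_audio_codes_t codes = vall_e_generate( ctx, inputs, args.max_steps, args.max_duration, args.modality );
//   write_audio_to_disk( decode_audio( ctx->encodec.ctx, codes ), args.output_path );
//   vall_e_free( ctx );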
--------------------------------------------------------------------------------
/scripts/process_emilia.py:
--------------------------------------------------------------------------------
1 | """
2 | # Processes audio provided through --input-audio, using adequately annotated transcriptions provided through --input-metadata (produced by transcribe.py)
3 | # Outputs NumPy objects containing the quantized audio and adequate metadata for loading in the trainer through --output-dataset
4 | """
5 |
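# Example invocation: a minimal sketch assuming the argparse defaults below and the Emilia layout
# noted further down (./Emilia/JA/JA-B000000/...):
#   python3 ./scripts/process_emilia.py --input-audio Emilia --output-dataset training/dataset --audio-backend encodec --device cuda --dtype bfloat16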
6 | import os
7 | import json
8 | import argparse
9 | import torch
10 | import torchaudio
11 | import numpy as np
12 |
13 | from tqdm.auto import tqdm
14 | from pathlib import Path
15 |
16 | from vall_e.config import cfg
17 |
18 | from vall_e.emb.g2p import encode as phonemize
19 | from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio
20 |
21 | from vall_e.emb.process import pad, load_audio, process_items, process_jobs
22 |
23 | def process(
24 | audio_backend="encodec",
25 | input_audio="Emilia",
26 | output_dataset="training",
27 | raise_exceptions=False,
28 | stride=0,
29 | stride_offset=0,
30 | slice="auto",
31 | batch_size=1,
32 | low_memory=False,
33 |
34 | device="cuda",
35 | dtype="float16",
36 | amp=False,
37 | ):
38 | # prepare from args
39 | cfg.device = device
40 | cfg.set_audio_backend(audio_backend)
41 | audio_extension = cfg.audio_backend_extension
42 |
43 | cfg.inference.weight_dtype = dtype # "bfloat16"
44 | cfg.inference.amp = amp # False
45 |
46 | dtype = cfg.inference.dtype if not amp else None
47 |
48 | output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
49 |
50 | language_map = {} # k = group, v = language
51 |
52 | ignore_groups = [] # skip these groups
53 | ignore_speakers = [] # skip these speakers
54 |
55 | only_groups = [] # only process these groups
56 | only_speakers = [] # only process these speakers
57 |
58 | always_slice_groups = [] # always slice from this group
59 |
60 | missing = {
61 | "transcription": [],
62 | "audio": []
63 | }
64 | dataset = []
65 |
66 | # Layout: ./Emilia/JA/JA-B000000/JA_B00000_S00000_W000000.{json|mp3}
67 | for language in sorted(os.listdir(f'./{input_audio}/')):
68 | if not os.path.isdir(f'./{input_audio}/{language}/'):
69 | print("Is not dir:", f'./{input_audio}/{language}/')
70 | continue
71 |
72 | if language in ignore_groups:
73 | continue
74 |
75 | if only_groups and language not in only_groups:
76 | continue
77 |
78 | group_name = "Emilia"
79 |
80 | for speaker_group in tqdm(process_items(os.listdir(f'./{input_audio}/{language}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {language}"):
81 | if not os.path.isdir(f'./{input_audio}/{language}/{speaker_group}'):
82 | print("Is not dir:", f'./{input_audio}/{language}/{speaker_group}')
83 | continue
84 |
85 | if speaker_group in ignore_speakers:
86 | continue
87 | if only_speakers and speaker_group not in only_speakers:
88 | continue
89 |
90 | if f'{group_name}/{speaker_group}' not in dataset:
91 | dataset.append(f'{group_name}/{speaker_group}')
92 |
93 | txts = []
94 | 				wavs = []
95 | 				jobs = []
96 | for filename in os.listdir(f'./{input_audio}/{language}/{speaker_group}'):
97 | if ".mp3" not in filename:
98 | continue
99 |
100 | inpath = Path(f'./{input_audio}/{language}/{speaker_group}/{filename}')
101 | jsonpath = _replace_file_extension(inpath, ".json")
102 | if not inpath.exists() or not jsonpath.exists():
103 | missing["audio"].append(str(inpath))
104 | continue
105 |
106 | extension = os.path.splitext(filename)[-1][1:]
107 | fname = filename.replace(f'.{extension}', "")
108 |
109 | waveform, sample_rate = None, None
110 | metadata = json.load(open(jsonpath, "r", encoding="utf-8"))
111 | if "text" not in metadata:
112 | continue
113 | speaker_id = metadata["speaker"]
114 | outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}').with_suffix(audio_extension)
115 | os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True)
116 |
117 | if _replace_file_extension(outpath, audio_extension).exists():
118 | continue
119 |
120 | text = metadata["text"]
121 |
122 | if waveform is None:
123 | waveform, sample_rate = load_audio(inpath)
124 |
125 | jobs.append(( outpath, waveform, sample_rate, text, language.lower() ))
126 |
127 | 			# process all queued audio files for this speaker group
128 | process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
129 | jobs = []
130 |
131 | open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
132 | open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
133 |
134 | def main():
135 | parser = argparse.ArgumentParser()
136 |
137 | parser.add_argument("--audio-backend", type=str, default="encodec")
138 | parser.add_argument("--dtype", type=str, default="bfloat16")
139 | parser.add_argument("--amp", action="store_true")
140 | parser.add_argument("--input-audio", type=str, default="Emilia")
141 | parser.add_argument("--output-dataset", type=str, default="training/dataset")
142 | parser.add_argument("--device", type=str, default="cuda")
143 | parser.add_argument("--raise-exceptions", action="store_true")
144 | parser.add_argument("--stride", type=int, default=0)
145 | parser.add_argument("--stride-offset", type=int, default=0)
146 | parser.add_argument("--slice", type=str, default="auto")
147 | parser.add_argument("--low-memory", action="store_true")
148 | parser.add_argument("--batch-size", type=int, default=0)
149 |
150 | args = parser.parse_args()
151 |
152 | # do some assumption magic
153 | # to-do: find a nice way to spawn multiple processes where tqdm plays nicely
154 | if args.device.isnumeric():
155 | args.stride = torch.cuda.device_count()
156 | args.stride_offset = int(args.device)
157 | args.device = f'cuda:{args.device}'
158 |
159 | process(
160 | audio_backend=args.audio_backend,
161 | input_audio=args.input_audio,
162 | output_dataset=args.output_dataset,
163 | raise_exceptions=args.raise_exceptions,
164 | stride=args.stride,
165 | stride_offset=args.stride_offset,
166 | slice=args.slice,
167 | batch_size=args.batch_size,
168 | low_memory=args.low_memory,
169 |
170 | device=args.device,
171 | dtype=args.dtype,
172 | amp=args.amp,
173 | )
174 |
175 | if __name__ == "__main__":
176 | main()
--------------------------------------------------------------------------------
/vall_e/utils/ext/muon.py:
--------------------------------------------------------------------------------
1 | # From https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
2 | # because it combines both param types and makes life easier with DeepSpeed
3 |
4 | import os
5 | import math
6 | import torch
7 | import torch.distributed as dist
8 |
9 | @torch.compile
10 | def zeropower_via_newtonschulz5(G, steps):
11 | """
12 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
13 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
14 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
15 | zero even beyond the point where the iteration no longer converges all the way to one everywhere
16 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
17 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
18 | performance at all relative to UV^T, where USV^T = G is the SVD.
19 | """
20 | assert len(G.shape) == 2
21 | a, b, c = (3.4445, -4.7750, 2.0315)
22 | X = G.bfloat16()
23 | if G.size(0) > G.size(1):
24 | X = X.T
25 | # Ensure spectral norm is at most 1
26 | X = X / (X.norm() + 1e-7)
27 | # Perform the NS iterations
28 | for _ in range(steps):
29 | A = X @ X.T
30 | B = (
31 | b * A + c * A @ A
32 | ) # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
33 | X = a * X + B @ X
34 |
35 | if G.size(0) > G.size(1):
36 | X = X.T
37 | return X
38 |
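# A minimal sanity-check sketch, assuming only the imports above: per the docstring, the singular
# values of the result should land roughly in the [0.5, 1.5] band, i.e. G is near-orthogonalized.
#
#   G = torch.randn(256, 128)
#   X = zeropower_via_newtonschulz5(G, steps=5)
#   print(torch.linalg.svdvals(X.float()))  # values clustered around 1.0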
39 |
40 | class Muon(torch.optim.Optimizer):
41 | """
42 | Muon - MomentUm Orthogonalized by Newton-schulz
43 |
44 | Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
45 | processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
46 | matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
47 | the advantage that it can be stably run in bfloat16 on the GPU.
48 |
49 | Some warnings:
50 | - We believe this optimizer is unlikely to work well for training with small batch size.
51 | - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
52 |
53 |     Arguments:
54 |         params: The parameters (or parameter groups) to be optimized. Groups flagged with
55 |             `muon=True` use the Muon update; all other groups ({0, 1}-D parameters, embeddings,
56 |             lm_head, etc.) fall back to the internal AdamW update.
57 |         lr: The learning rate. The updates will have spectral norm of `lr`. (0.02 is a good default)
58 |         wd: The weight decay, applied in both the Muon and AdamW branches.
59 |         momentum: The momentum used by the internal SGD. (0.95 is a good default)
60 |         nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
61 |         ns_steps: The number of Newton-Schulz iterations to run. (5 is probably always enough)
62 |         betas: The betas for the internal AdamW branch.
63 |         eps: The epsilon for the internal AdamW branch.
64 |         Note: the per-group `muon` flag defaults to False; see `defaults` in __init__.
65 | """
66 |
67 | def __init__(
68 | self,
69 | params=None,
70 | lr=1e-3,
71 | wd=0.1,
72 | momentum=0.95,
73 | nesterov=True,
74 | ns_steps=5,
75 | betas=(0.95, 0.95),
76 | eps=1e-8,
77 | ):
78 |
79 | defaults = dict(
80 | lr=lr,
81 | wd=wd,
82 | momentum=momentum,
83 | nesterov=nesterov,
84 | ns_steps=ns_steps,
85 | betas=betas,
86 | eps=eps,
87 | muon=False,
88 | )
89 |
90 | super().__init__(params, defaults)
91 |
92 | def adjust_lr_for_muon(self, lr, param_shape):
93 | A, B = param_shape[:2]
94 | # We adjust the learning rate and weight decay based on the size of the parameter matrix
95 | 		# as described in the paper
96 | adjusted_ratio = 0.2 * math.sqrt(max(A, B))
97 | adjusted_lr = lr * adjusted_ratio
98 | return adjusted_lr
99 |
100 | def step(self, closure=None):
101 | """Perform a single optimization step.
102 |
103 | Args:
104 | closure (Callable, optional): A closure that reevaluates the model
105 | and returns the loss.
106 | """
107 | loss = None
108 | if closure is not None:
109 | with torch.enable_grad():
110 | loss = closure()
111 |
112 | for group in self.param_groups:
113 |
114 | ############################
115 | # Muon #
116 | ############################
117 | if group["muon"]:
118 | # import pdb; pdb.set_trace()
119 | lr = group["lr"]
120 | wd = group["wd"]
121 | momentum = group["momentum"]
122 |
123 | # generate weight updates in distributed fashion
124 | for p in group["params"]:
125 | # sanity check
126 | g = p.grad
127 | if g is None:
128 | continue
129 | if g.ndim > 2:
130 | g = g.view(g.size(0), -1)
131 | assert g is not None
132 |
133 | # calc update
134 | state = self.state[p]
135 | if "momentum_buffer" not in state:
136 | state["momentum_buffer"] = torch.zeros_like(g)
137 | buf = state["momentum_buffer"]
138 | buf.mul_(momentum).add_(g)
139 | if group["nesterov"]:
140 | g = g.add(buf, alpha=momentum)
141 | else:
142 | g = buf
143 | u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
144 |
145 | # scale update
146 | adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
147 |
148 | # apply weight decay
149 | p.data.mul_(1 - lr * wd)
150 |
151 | # apply update
152 | p.data.add_(u, alpha=-adjusted_lr)
153 |
154 | ############################
155 | # AdamW backup #
156 | ############################
157 | else:
158 | lr = group['lr']
159 | beta1, beta2 = group["betas"]
160 | eps = group["eps"]
161 | weight_decay = group["wd"]
162 |
163 | for p in group["params"]:
164 | g = p.grad
165 | if g is None:
166 | continue
167 | state = self.state[p]
168 | if "step" not in state:
169 | state["step"] = 0
170 | state["moment1"] = torch.zeros_like(g)
171 | state["moment2"] = torch.zeros_like(g)
172 | state["step"] += 1
173 | step = state["step"]
174 | buf1 = state["moment1"]
175 | buf2 = state["moment2"]
176 | buf1.lerp_(g, 1 - beta1)
177 | buf2.lerp_(g.square(), 1 - beta2)
178 |
179 | g = buf1 / (eps + buf2.sqrt())
180 |
181 | bias_correction1 = 1 - beta1**step
182 | bias_correction2 = 1 - beta2**step
183 | scale = bias_correction1 / bias_correction2**0.5
184 | p.data.mul_(1 - lr * weight_decay)
185 | p.data.add_(g, alpha=-lr / scale)
186 |
187 | return loss
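
# A minimal usage sketch, assuming parameter groups are flagged via the `muon` key defined in
# `defaults` above: 2-D weight matrices take the Muon branch, everything else ({0, 1}-D parameters,
# embeddings, lm_head) falls through to the AdamW branch.
#
#   muon_params  = [p for p in model.parameters() if p.ndim >= 2]
#   other_params = [p for p in model.parameters() if p.ndim <  2]
#   optimizer = Muon([
#       dict(params=muon_params,  lr=2e-2, muon=True),
#       dict(params=other_params, lr=3e-4, muon=False),
#   ])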
--------------------------------------------------------------------------------
/scripts/process_libritts.py:
--------------------------------------------------------------------------------
1 | """
2 | # Processes audio provided through --input-audio, using adequately annotated transcriptions provided through --input-metadata (produced by transcribe.py)
3 | # Outputs NumPy objects containing the quantized audio and adequate metadata for loading in the trainer through --output-dataset
4 | """
5 |
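# Example invocation: a minimal sketch assuming the argparse defaults below and the LibriTTS_R
# layout noted further down (./LibriTTS_R/train-clean-100/103/1241/...):
#   python3 ./scripts/process_libritts.py --input-audio LibriTTS_R --output-dataset training/dataset --audio-backend encodec --device cuda --dtype bfloat16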
6 | import os
7 | import json
8 | import argparse
9 | import torch
10 | import torchaudio
11 | import numpy as np
12 |
13 | from tqdm.auto import tqdm
14 | from pathlib import Path
15 |
16 | from vall_e.config import cfg
17 |
18 | from vall_e.emb.g2p import encode as phonemize
19 | from vall_e.emb.qnt import encode as quantize, _replace_file_extension, convert_audio
20 |
21 | from vall_e.emb.process import pad, load_audio, process_items, process_jobs
22 |
23 |
24 | def process(
25 | audio_backend="encodec",
26 | input_audio="LibriTTS_R",
27 | output_dataset="training",
28 | raise_exceptions=False,
29 | stride=0,
30 | stride_offset=0,
31 | slice="auto",
32 | batch_size=1,
33 | low_memory=False,
34 |
35 | device="cuda",
36 | dtype="float16",
37 | amp=False,
38 | ):
39 | # prepare from args
40 | cfg.device = device
41 | cfg.set_audio_backend(audio_backend)
42 | audio_extension = cfg.audio_backend_extension
43 |
44 | cfg.inference.weight_dtype = dtype # "bfloat16"
45 | cfg.inference.amp = amp # False
46 |
47 | dtype = cfg.inference.dtype if not amp else None
48 |
49 | output_dataset = f"{output_dataset}/{'2' if cfg.sample_rate == 24_000 else '4'}{'8' if cfg.sample_rate == 48_000 else '4'}KHz-{cfg.audio_backend}" # "training"
50 |
51 | language_map = {} # k = group, v = language
52 |
53 | ignore_groups = [] # skip these groups
54 | ignore_speakers = [] # skip these speakers
55 |
56 | only_groups = [] # only process these groups
57 | only_speakers = [] # only process these speakers
58 |
59 | always_slice_groups = [] # always slice from this group
60 |
61 | missing = {
62 | "transcription": [],
63 | "audio": []
64 | }
65 | dataset = []
66 |
67 | # Layout: ./LibriTTS_R/train-clean-100/103/1241
68 | for group_name in sorted(os.listdir(f'./{input_audio}/')):
69 | if not os.path.isdir(f'./{input_audio}/{group_name}/'):
70 | print("Is not dir:", f'./{input_audio}/{group_name}/')
71 | continue
72 |
73 | if group_name in ignore_groups:
74 | continue
75 | if only_groups and group_name not in only_groups:
76 | continue
77 |
78 | for speaker_id in tqdm(process_items(os.listdir(f'./{input_audio}/{group_name}/'), stride=stride, stride_offset=stride_offset), desc=f"Processing speaker in {group_name}"):
79 | if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}'):
80 | print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}')
81 | continue
82 |
83 | if speaker_id in ignore_speakers:
84 | continue
85 | if only_speakers and speaker_id not in only_speakers:
86 | continue
87 |
88 | os.makedirs(f'./{output_dataset}/{group_name}/{speaker_id}/', exist_ok=True)
89 |
90 | if f'{group_name}/{speaker_id}' not in dataset:
91 | dataset.append(f'{group_name}/{speaker_id}')
92 |
93 | txts = []
94 | 				wavs = []
95 | 				jobs = []
96 | for book_id in os.listdir(f'./{input_audio}/{group_name}/{speaker_id}'):
97 | if not os.path.isdir(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}'):
98 | print("Is not dir:", f'./{input_audio}/{group_name}/{speaker_id}/{book_id}')
99 | continue
100 |
101 | for filename in os.listdir(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}'):
102 | if ".wav" not in filename:
103 | continue
104 |
105 | inpath = Path(f'./{input_audio}/{group_name}/{speaker_id}/{book_id}/{filename}')
106 | textpath = _replace_file_extension(inpath, ".original.txt")
107 | if not inpath.exists() or not textpath.exists():
108 | missing["audio"].append(str(inpath))
109 | continue
110 |
111 | extension = os.path.splitext(filename)[-1][1:]
112 | fname = filename.replace(f'.{extension}', "")
113 |
114 | waveform, sample_rate = None, None
115 | language = "en"
116 |
117 | outpath = Path(f'./{output_dataset}/{group_name}/{speaker_id}/{fname}.{extension}')
118 | text = open(textpath, "r", encoding="utf-8").read()
119 |
120 | if len(text) == 0:
121 | continue
122 |
123 | if _replace_file_extension(outpath, audio_extension).exists():
124 | continue
125 |
126 | if waveform is None:
127 | waveform, sample_rate = load_audio(inpath)
128 |
129 | jobs.append(( outpath, waveform, sample_rate, text, language ))
130 |
131 | # processes audio files one at a time
132 | if low_memory:
133 | process_jobs( jobs, device=device, speaker_id=f'{speaker_id}/{filename}', raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
134 | jobs = []
135 |
136 | # processes all audio files for a given speaker
137 | if not low_memory:
138 | process_jobs( jobs, device=device, speaker_id=speaker_id, raise_exceptions=raise_exceptions, batch_size=batch_size, dtype=dtype if not amp else None )
139 | jobs = []
140 |
141 | open(f"./{output_dataset}/missing.json", 'w', encoding='utf-8').write(json.dumps(missing))
142 | open(f"./{output_dataset}/dataset.json", 'w', encoding='utf-8').write(json.dumps(dataset))
143 |
144 | def main():
145 | parser = argparse.ArgumentParser()
146 |
147 | parser.add_argument("--audio-backend", type=str, default="encodec")
148 | parser.add_argument("--dtype", type=str, default="bfloat16")
149 | parser.add_argument("--amp", action="store_true")
150 | parser.add_argument("--input-audio", type=str, default="LibriTTS_R")
151 | parser.add_argument("--output-dataset", type=str, default="training/dataset")
152 | parser.add_argument("--device", type=str, default="cuda")
153 | parser.add_argument("--raise-exceptions", action="store_true")
154 | parser.add_argument("--stride", type=int, default=0)
155 | parser.add_argument("--stride-offset", type=int, default=0)
156 | parser.add_argument("--slice", type=str, default="auto")
157 | parser.add_argument("--low-memory", action="store_true")
158 | parser.add_argument("--batch-size", type=int, default=0)
159 |
160 | args = parser.parse_args()
161 |
162 | # do some assumption magic
163 | # to-do: find a nice way to spawn multiple processes where tqdm plays nicely
164 | if args.device.isnumeric():
165 | args.stride = torch.cuda.device_count()
166 | args.stride_offset = int(args.device)
167 | args.device = f'cuda:{args.device}'
168 |
169 | process(
170 | audio_backend=args.audio_backend,
171 | input_audio=args.input_audio,
172 | output_dataset=args.output_dataset,
173 | raise_exceptions=args.raise_exceptions,
174 | stride=args.stride,
175 | stride_offset=args.stride_offset,
176 | slice=args.slice,
177 | batch_size=args.batch_size,
178 | low_memory=args.low_memory,
179 |
180 | device=args.device,
181 | dtype=args.dtype,
182 | amp=args.amp,
183 | )
184 |
185 | if __name__ == "__main__":
186 | main()
--------------------------------------------------------------------------------
/vall_e/utils/ml.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 |
3 | import math
4 | import logging
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.optim.lr_scheduler import _LRScheduler
9 |
10 | from ..config import cfg
11 |
12 | _logger = logging.getLogger(__name__)
13 |
14 | Embedding = torch.nn.Embedding
15 | Linear = torch.nn.Linear
16 |
17 | Adam = torch.optim.Adam
18 | AdamW = torch.optim.AdamW
19 | SGD = torch.optim.SGD
20 | Adagrad = torch.optim.Adagrad
21 | Adafactor = torch.optim.Adafactor
22 |
23 | OneCycleLR = torch.optim.lr_scheduler.OneCycleLR
24 | CosineAnnealingLR = torch.optim.lr_scheduler.CosineAnnealingLR
25 | LambdaLR = torch.optim.lr_scheduler.LambdaLR
26 |
27 | # implements Noam scheduling
28 | # it's cringe
29 | class NoamLR(_LRScheduler):
30 | def __init__(self, optimizer, warmup_steps, d_model=1024, last_epoch=-1):
31 | self.base_factor = d_model ** (-0.5)
32 | self.warmup_steps = warmup_steps
33 |
34 | super().__init__(optimizer, last_epoch)
35 |
36 | def get_lr(self):
37 | step = max(1, self.last_epoch)
38 | scale = self.base_factor * min(step ** (-0.5), step * self.warmup_steps ** (-1.5))
39 |
40 | return [base_lr * scale for base_lr in self.base_lrs]
41 |
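# For reference, get_lr above amounts to:
#   lr(step) = base_lr * d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)
# i.e. linear warmup over `warmup_steps` steps followed by inverse-square-root decay.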
42 | # gradually warms up LR then holds or decays
43 | class WarmupLR(_LRScheduler):
44 | def __init__(self, optimizer, warmup_steps, decay_factor=0.0, last_epoch=-1):
45 | self.warmup_steps = warmup_steps
46 | self.decay_factor = decay_factor
47 |
48 | super().__init__(optimizer, last_epoch)
49 |
50 | def get_lr(self):
51 | step = self.last_epoch + 1
52 | scale = 1
53 | if step < self.warmup_steps:
54 | scale = float(step) / float(max(1, self.warmup_steps))
55 | elif self.decay_factor != 0:
56 | scale = (1.0 - self.decay_factor) ** (step - self.warmup_steps)
57 |
58 | return [base_lr * scale for base_lr in self.base_lrs]
59 |
60 | # https://github.com/kyegomez/BitNet
61 | if cfg.optimizations.bitnet:
62 | from bitnet import BitLinear
63 |
64 | if cfg.optimizations.bitsandbytes:
65 | import bitsandbytes as bnb
66 |
67 | if cfg.optimizations.linear:
68 |
69 | if cfg.optimizations.bitnet:
70 | Linear = BitLinear
71 | else:
72 | Linear = bnb.nn.Linear8bitLt
73 |
74 | if cfg.optimizations.embedding:
75 | Embedding = bnb.nn.StableEmbedding
76 | """
77 | Embedding.forward = lambda self, input: ( self.norm(F.embedding(
78 | input,
79 | self.weight,
80 | self.padding_idx,
81 | self.max_norm,
82 | self.norm_type,
83 | self.scale_grad_by_freq,
84 | self.sparse,
85 | )).to(self.weight.dtype) )
86 | """
87 |
88 | if cfg.optimizations.optimizers:
89 | Adam = bnb.optim.Adam8bit
90 | AdamW = bnb.optim.AdamW8bit
91 | SGD = bnb.optim.SGD8bit
92 | Adagrad = bnb.optim.Adagrad8bit
93 |
94 | elif cfg.optimizations.dadaptation:
95 | import dadaptation
96 |
97 | if cfg.optimizations.optimizers:
98 | Adam = dadaptation.DAdaptAdam
99 | AdamW = dadaptation.DAdaptAdam
100 | SGD = dadaptation.DAdaptSGD
101 | 		Adagrad = dadaptation.DAdaptAdaGrad
102 |
103 | if cfg.optimizations.fp8:
104 | import transformer_engine.pytorch as te
105 |
106 | Linear = te.Linear
107 |
108 | @contextmanager
109 | def autocast():
110 | yield te.fp8_autocast(enabled=True)
111 | else:
112 | @contextmanager
113 | def autocast():
114 | yield torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp)
115 |
116 | if cfg.optimizations.injects:
117 | if cfg.optimizations.linear:
118 | torch.nn.Linear = Linear
119 |
120 | if cfg.optimizations.embedding:
121 | torch.nn.Embedding = Embedding
122 |
123 | if cfg.optimizations.optimizers:
124 | torch.optim.Adam = Adam
125 | torch.optim.AdamW = AdamW
126 | torch.optim.SGD = SGD
127 |
128 | if cfg.optimizations.unsloth:
129 | try:
130 | from .ext.unsloth import apply_unsloth_offloaded_gradient_checkpoint_monkey_patch
131 | #apply_unsloth_offloaded_gradient_checkpoint_monkey_patch()
132 | except Exception as e:
133 | _logger.warning(f'Error while importing Unsloth: {str(e)}')
134 | pass
135 |
136 | class Optimizers(torch.optim.Optimizer):
137 | def __init__(self, opts):
138 | self.opts = opts
139 |
140 | def step(self, *args, **kwargs):
141 | for opt in self.opts:
142 | opt.step(*args, **kwargs)
143 |
144 | def zero_grad(self, *args, **kwargs):
145 | for opt in self.opts:
146 | opt.zero_grad(*args, **kwargs)
147 |
148 | @property
149 | def param_groups(self):
150 | l = []
151 | for opt in self.opts:
152 | l += opt.param_groups
153 | return l
154 |
155 | def state_dict(self):
156 | states = []
157 | for i, opt in enumerate( self.opts ):
158 | states.append( opt.state_dict() )
159 |
160 | return states
161 |
162 | def load_state_dict(self, state_dict):
163 | for opt, state in zip( self.opts, state_dict ):
164 | opt.load_state_dict( state )
165 |
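# A minimal usage sketch (`model.head` / `model.embed` are hypothetical module names): Optimizers
# drives several underlying optimizers behind one step()/zero_grad() interface.
#
#   opt = Optimizers([
#       AdamW(model.head.parameters(), lr=1e-4),
#       SGD(model.embed.parameters(), lr=1e-3),
#   ])
#   opt.zero_grad(); loss.backward(); opt.step()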
166 | try:
167 | from .ext.apollo import Apollo
168 | except Exception as e:
169 | _logger.warning(f'Error while importing APOLLO: {str(e)}')
170 | pass
171 |
172 | try:
173 | from .ext.muon import Muon
174 | except Exception as e:
175 | _logger.warning(f'Error while importing Muon: {str(e)}')
176 | pass
177 |
178 | # https://github.com/konstmish/prodigy
179 | try:
180 | from prodigyopt import Prodigy
181 | except Exception as e:
182 | _logger.warning(f'Error while importing Prodigyopt: {str(e)}')
183 | pass
184 |
185 | # https://github.com/facebookresearch/schedule_free/
186 | try:
187 | import schedulefree
188 | except Exception as e:
189 | _logger.warning(f'Error while importing Schedule_Free: {str(e)}')
190 | pass
191 |
192 | # backwards compat
193 | from .utils import (
194 | autocast_forward,
195 | replace_linear as replace_linear_old,
196 | replace_embedding as replace_embedding_old,
197 | replace_attention,
198 | resize_weight,
199 | offload_model,
200 | )
201 |
202 | # wrapped here so we can maintain default args
203 | def replace_linear( model, klass=Linear, target=torch.nn.Linear, verbose=False ):
204 | return replace_linear_old( model, klass, target, verbose )
205 | def replace_embedding( model, klass=Embedding, target=torch.nn.Embedding, verbose=False ):
206 | return replace_embedding_old( model, klass, target, verbose )
207 |
208 | Embedding.forward = autocast_forward(Embedding.forward)
209 |
210 | AVAILABLE_COMPILE_BACKENDS = []
211 |
212 | try:
213 | AVAILABLE_COMPILE_BACKENDS += torch._dynamo.list_backends()
214 | except Exception as e:
215 | pass
216 |
217 | def compile_model(model, backend="auto"):
218 | if not backend or backend == "auto":
219 | 		backend = AVAILABLE_COMPILE_BACKENDS[0] if AVAILABLE_COMPILE_BACKENDS else None
220 |
221 | if backend not in AVAILABLE_COMPILE_BACKENDS:
222 | return torch.compile(model)
223 |
224 | return torch.compile(model, backend=backend)
225 |
226 |
227 | if cfg.optimizations.tensorrt:
228 | try:
229 | import torch_tensorrt
230 | AVAILABLE_COMPILE_BACKENDS.append("tensorrt")
231 | except Exception as e:
232 | _logger.warning(f'Error while importing TensorRT: {str(e)}')
233 | pass
--------------------------------------------------------------------------------
/vall_e.cpp/include/ggml-cpu.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "ggml.h"
4 | #include "ggml-backend.h"
5 |
6 | #ifdef __cplusplus
7 | extern "C" {
8 | #endif
9 |
10 | // the compute plan that needs to be prepared for ggml_graph_compute()
11 | // since https://github.com/ggml-org/ggml/issues/287
12 | struct ggml_cplan {
13 | size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()`
14 | uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()`
15 |
16 | int n_threads;
17 | struct ggml_threadpool * threadpool;
18 |
19 | // abort ggml_graph_compute when true
20 | ggml_abort_callback abort_callback;
21 | void * abort_callback_data;
22 | };
23 |
24 | // numa strategies
25 | enum ggml_numa_strategy {
26 | GGML_NUMA_STRATEGY_DISABLED = 0,
27 | GGML_NUMA_STRATEGY_DISTRIBUTE = 1,
28 | GGML_NUMA_STRATEGY_ISOLATE = 2,
29 | GGML_NUMA_STRATEGY_NUMACTL = 3,
30 | GGML_NUMA_STRATEGY_MIRROR = 4,
31 | GGML_NUMA_STRATEGY_COUNT
32 | };
33 |
34 | GGML_BACKEND_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
35 | GGML_BACKEND_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
36 |
37 | GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
38 | GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
39 |
40 | GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
41 | GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
42 |
43 | GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
44 | GGML_BACKEND_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
45 |
46 | GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
47 | GGML_BACKEND_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
48 |
49 | GGML_BACKEND_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
50 | GGML_BACKEND_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
51 |
52 | GGML_BACKEND_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
53 | GGML_BACKEND_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
54 |
55 | GGML_BACKEND_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params);
56 | GGML_BACKEND_API void ggml_threadpool_free (struct ggml_threadpool * threadpool);
57 | GGML_BACKEND_API int ggml_threadpool_get_n_threads (struct ggml_threadpool * threadpool);
58 | GGML_BACKEND_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool);
59 | GGML_BACKEND_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool);
60 |
61 | // ggml_graph_plan() has to be called before ggml_graph_compute()
62 | // when plan.work_size > 0, caller must allocate memory for plan.work_data
63 | GGML_BACKEND_API struct ggml_cplan ggml_graph_plan(
64 | const struct ggml_cgraph * cgraph,
65 | int n_threads, /* = GGML_DEFAULT_N_THREADS */
66 | struct ggml_threadpool * threadpool /* = NULL */ );
67 | GGML_BACKEND_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
68 |
69 | // same as ggml_graph_compute() but the work data is allocated as a part of the context
70 | // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
71 | GGML_BACKEND_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
72 |
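// A minimal usage sketch, assuming the caller allocates work_data as documented above:
//
//   struct ggml_cplan plan = ggml_graph_plan(cgraph, /*n_threads=*/4, /*threadpool=*/NULL);
//   uint8_t * work = plan.work_size > 0 ? malloc(plan.work_size) : NULL;
//   plan.work_data = work;
//   ggml_graph_compute(cgraph, &plan);
//   free(work);
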
73 | //
74 | // system info
75 | //
76 |
77 | // x86
78 | GGML_BACKEND_API int ggml_cpu_has_sse3 (void);
79 | GGML_BACKEND_API int ggml_cpu_has_ssse3 (void);
80 | GGML_BACKEND_API int ggml_cpu_has_avx (void);
81 | GGML_BACKEND_API int ggml_cpu_has_avx_vnni (void);
82 | GGML_BACKEND_API int ggml_cpu_has_avx2 (void);
83 | GGML_BACKEND_API int ggml_cpu_has_bmi2 (void);
84 | GGML_BACKEND_API int ggml_cpu_has_f16c (void);
85 | GGML_BACKEND_API int ggml_cpu_has_fma (void);
86 | GGML_BACKEND_API int ggml_cpu_has_avx512 (void);
87 | GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void);
88 | GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void);
89 | GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void);
90 | GGML_BACKEND_API int ggml_cpu_has_amx_int8 (void);
91 | // ARM
92 | GGML_BACKEND_API int ggml_cpu_has_neon (void);
93 | GGML_BACKEND_API int ggml_cpu_has_arm_fma (void);
94 | GGML_BACKEND_API int ggml_cpu_has_fp16_va (void);
95 | GGML_BACKEND_API int ggml_cpu_has_dotprod (void);
96 | GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
97 | GGML_BACKEND_API int ggml_cpu_has_sve (void);
98 | GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes
99 | GGML_BACKEND_API int ggml_cpu_has_sme (void);
100 | // other
101 | GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
102 | GGML_BACKEND_API int ggml_cpu_has_vsx (void);
103 | GGML_BACKEND_API int ggml_cpu_has_vxe (void);
104 | GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
105 | GGML_BACKEND_API int ggml_cpu_has_llamafile (void);
106 |
107 | // Internal types and functions exposed for tests and benchmarks
108 |
109 | typedef void (*ggml_vec_dot_t) (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
110 | const void * GGML_RESTRICT y, size_t by, int nrc);
111 |
112 | struct ggml_type_traits_cpu {
113 | ggml_from_float_t from_float;
114 | ggml_vec_dot_t vec_dot;
115 | enum ggml_type vec_dot_type;
116 | int64_t nrows; // number of rows to process simultaneously
117 | };
118 |
119 | GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
120 |
121 | GGML_BACKEND_API void ggml_cpu_init(void);
122 |
123 | //
124 | // CPU backend
125 | //
126 |
127 | GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void);
128 |
129 | GGML_BACKEND_API bool ggml_backend_is_cpu (ggml_backend_t backend);
130 | GGML_BACKEND_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
131 | GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
132 | GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
133 |
134 | GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
135 |
136 | #ifdef __cplusplus
137 | }
138 | #endif
139 |
--------------------------------------------------------------------------------
/vall_e.cpp/include/encodec.h:
--------------------------------------------------------------------------------
1 | /*
2 | ╞══════════════════════════════════════════════════════════════════════════════╡
3 | │ Copyright 2024 Pierre-Antoine Bannier │
4 | │ │
5 | │ Permission to use, copy, modify, and/or distribute this software for │
6 | │ any purpose with or without fee is hereby granted, provided that the │
7 | │ above copyright notice and this permission notice appear in all copies. │
8 | │ │
9 | │ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL │
10 | │ WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED │
11 | │ WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE │
12 | │ AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL │
13 | │ DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR │
14 | │ PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER │
15 | │ TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR │
16 | │ PERFORMANCE OF THIS SOFTWARE. │
17 | ╚─────────────────────────────────────────────────────────────────────────────*/
18 | /*
19 | * This file contains the declarations of the structs and functions used in the encodec library.
20 | * The library provides functionality for audio compression and decompression using a custom model.
21 | * The model consists of an encoder, a quantizer and a decoder, each with their own set of parameters.
22 | * The library also provides functions for loading and freeing the model, as well as compressing and decompressing audio data.
23 | *
24 | */
25 | #pragma once
26 |
27 | #include "ggml-alloc.h"
28 | #include "ggml-backend.h"
29 | #include "ggml.h"
30 |
31 | #ifdef __cplusplus
32 | extern "C" {
33 | #endif
34 | struct encodec_context;
35 |
36 | struct encodec_statistics {
37 | // The time taken to load the model.
38 | int64_t t_load_us;
39 | // The time taken to compute the model.
40 | int64_t t_compute_us;
41 | };
42 |
43 | /**
44 | * Loads an encodec model from the specified file path.
45 | *
46 | * @param model_path The file path to the encodec model.
47 | * @param offset The offset (in bytes) to the start of the model in the file.
48 | * @param n_gpu_layers The number of GPU layers to use.
49 | * @return A pointer to the encodec context struct.
50 | */
51 | struct encodec_context *encodec_load_model(
52 | const char *model_path,
53 | const int offset,
54 | int n_gpu_layers);
55 |
56 | /**
57 | * Sets the target bandwidth for the given encodec context.
58 | *
59 | * @param ectx The encodec context to set the target bandwidth for.
60 | * @param bandwidth The target bandwidth to set, in bits per second.
61 | */
62 | void encodec_set_target_bandwidth(
63 | struct encodec_context *ectx,
64 | int bandwidth);
65 |
66 | /**
67 | * Sets the sample rate for the given encodec context.
68 | *
69 | * @param ectx The encodec context to set the target bandwidth for.
70 | * @param sample_rate The sample rate to set.
71 | */
72 | void encodec_set_sample_rate(
73 | struct encodec_context *ectx,
74 | int sample_rate);
75 |
76 | /**
77 | * Reconstructs audio from raw audio data using the specified encodec context.
78 | *
79 | * @param ectx The encodec context to use for reconstruction.
80 | * @param raw_audio The raw audio data to reconstruct.
81 | * @param n_samples The number of samples in the raw audio buffer.
82 | * @param n_threads The number of threads to use for reconstruction.
83 | * @return True if the reconstruction was successful, false otherwise.
84 | */
85 | bool encodec_reconstruct_audio(
86 | struct encodec_context *ectx,
87 | const float *raw_audio,
88 | const int n_samples,
89 | int n_threads);
90 |
91 | /**
92 | * Compresses audio data using the specified encodec context.
93 | *
94 | * @param ectx The encodec context to use for compression.
95 | * @param raw_audio The raw audio data to compress.
96 | * @param n_samples The number of samples in the raw audio buffer.
97 | * @param n_threads The number of threads to use for compression.
98 | * @return True if the compression was successful, false otherwise.
99 | */
100 | bool encodec_compress_audio(
101 | struct encodec_context *ectx,
102 | const float *raw_audio,
103 | const int n_samples,
104 | int n_threads);
105 |
106 | /**
107 | * Decompresses audio data using the specified encodec context.
108 | *
109 | * @param ectx The encodec context to use for decompression.
110 | * @param codes The compressed audio data to decompress.
111 | * @param n_codes The number of codes in the codes buffer.
112 | * @param n_threads The number of threads to use for decompression.
113 | * @return True if the audio data was successfully decompressed, false otherwise.
114 | */
115 | bool encodec_decompress_audio(
116 | struct encodec_context *ectx,
117 | const int32_t *codes,
118 | const int n_codes,
119 | int n_threads);
120 |
121 | /**
122 | * Gets the audio data from the given encodec context.
123 | *
124 | * @param ectx The encodec context to get the audio data from.
125 | * @return A pointer to the audio data.
126 | */
127 | float * encodec_get_audio(
128 | struct encodec_context *ectx);
129 |
130 | /**
131 | * Gets the size of the audio data from the given encodec context.
132 | *
133 | * @param ectx The encodec context to get the audio size from.
134 | * @return The size of the audio data.
135 | */
136 | int encodec_get_audio_size(
137 | struct encodec_context *ectx);
138 |
139 | /**
140 | * Gets the code data from the given encodec context.
141 | *
142 | * @param ectx The encodec context to get the code data from.
143 | * @return A pointer to the code data.
144 | */
145 | int32_t * encodec_get_codes(
146 | struct encodec_context *ectx);
147 |
148 | /**
149 | * Gets the size of the code data from the given encodec context.
150 | *
151 | * @param ectx The encodec context to get the code size from.
152 | * @return The size of the code data.
153 | */
154 | int encodec_get_codes_size(
155 | struct encodec_context *ectx);
156 |
157 | /**
158 | * Gets the statistics for the given encodec context.
159 | *
160 | * @param ectx The encodec context to get the statistics for.
161 | * @return A pointer to the statistics struct.
162 | */
163 | const struct encodec_statistics* encodec_get_statistics(
164 | struct encodec_context *ectx);
165 |
166 | /**
167 | * Reset the statistics for the given encodec context.
168 | *
169 | * @param ectx The encodec context to reset the statistics for.
170 | */
171 | void encodec_reset_statistics(
172 | struct encodec_context *ectx);
173 |
174 | /**
175 | * @brief Frees the memory allocated for an encodec context.
176 | *
177 | * @param ectx The encodec context to free.
178 | */
179 | void encodec_free(
180 | struct encodec_context *ectx);
181 |
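/*
 * A minimal round-trip sketch, assuming the declarations above are used in this order (the
 * ordering itself is an assumption, not something this header guarantees):
 *
 *   struct encodec_context *ectx = encodec_load_model("./data/encodec.bin", 0, n_gpu_layers);
 *   encodec_set_target_bandwidth(ectx, bandwidth);
 *   encodec_compress_audio(ectx, raw_audio, n_samples, n_threads);
 *   int32_t *codes = encodec_get_codes(ectx);
 *   int n_codes    = encodec_get_codes_size(ectx);
 *   encodec_decompress_audio(ectx, codes, n_codes, n_threads);
 *   float *audio   = encodec_get_audio(ectx);
 *   int n_audio    = encodec_get_audio_size(ectx);
 *   encodec_free(ectx);
 */
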
182 | #ifdef __cplusplus
183 | }
184 | #endif
--------------------------------------------------------------------------------
/vall_e/models/lora.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
2 | from functools import partial
3 | import torch
4 | import torch.nn.functional as F
5 | import torch.nn.utils.parametrize as parametrize
6 |
7 | from transformers.pytorch_utils import Conv1D
8 |
9 | from torch import Tensor, nn
10 |
11 | import math
12 | from typing import Optional, List
13 |
14 | from ..utils import passes_policy
15 |
16 | # LoRA Linear for replacement
17 | # Pros: simple, just needs to reuse the replace_linear and copy weights
18 | # Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
19 | class LoRALinear(nn.Linear):
20 | def __init__(
21 | self,
22 |
23 | in_features: int,
24 | out_features: int,
25 | bias: bool = True,
26 |
27 | rank: int = 4,
28 | alpha: int = 1,
29 |
30 | dropout: float = 0.1,
31 | merge_weights: bool = False,
32 | **kwargs,
33 | ):
34 | super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)
35 |
36 | self.rank = rank
37 | self.alpha = alpha
38 | self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
39 | self.merge_weights = merge_weights
40 | self.merged = False
41 | self.enabled = True
42 |
43 | self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
44 | self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
45 | self.scaling = self.alpha / self.rank
46 |
47 | self.weight.requires_grad = False
48 |
49 | self.reset_parameters()
50 |
51 | def reset_parameters(self):
52 | super().reset_parameters()
53 | # super silly but necessary because nn.Linear's constructor calls this
54 | if hasattr(self, 'lora_A'):
55 | nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
56 | nn.init.zeros_( self.lora_B )
57 |
58 | def train(self, mode: bool = True):
59 | super().train(mode)
60 |
61 | # training, separate lora from base weights
62 | if mode and self.merge_weights and self.merged:
63 | self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
64 | self.merged = False
65 |
66 | # not training, merge lora to base weights
67 | if not mode and self.merge_weights and not self.merged:
68 | self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
69 | self.merged = True
70 |
71 | def forward(self, x: torch.Tensor):
72 | if not self.merged and self.enabled:
73 | result = F.linear(x, self.weight, bias=self.bias)
74 | result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
75 | return result
76 |
77 | return F.linear(x, self.weight, bias=self.bias)
78 |
79 | @classmethod
80 | def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
81 | if device is None:
82 | device = layer.weight.device
83 | if dtype is None:
84 | dtype = layer.weight.dtype
85 | return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
86 |
87 | # Uses parametrization to inject LoRA weights
88 | # Pros: should work with any Linears
89 | # Cons: TBD
90 | class ParameterizedLoRA(nn.Module):
91 | def __init__(
92 | self,
93 |
94 | in_features: int,
95 | out_features: int,
96 | bias: bool = True,
97 |
98 | rank: int = 4,
99 | alpha: int = 1,
100 |
101 | dropout: float = 0.1,
102 |
103 | device = None,
104 | dtype = None
105 | ):
106 | super().__init__()
107 | self.rank = rank
108 | self.alpha = alpha
109 | self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
110 |
111 | self.lora_B = nn.Parameter( torch.zeros( (out_features, rank) ) ).to( device=device, dtype=dtype )
112 | self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
113 | self.scaling = self.alpha / self.rank
114 | self.enabled = True
115 |
116 | self.reset_parameters()
117 |
118 | def reset_parameters(self):
119 | nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
120 | nn.init.zeros_( self.lora_B )
121 |
122 | def forward(self, x: torch.Tensor):
123 | if self.enabled:
124 | return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
125 | return x
126 |
127 | @classmethod
128 | def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
129 | if device is None:
130 | device = layer.weight.device
131 | if dtype is None:
132 | dtype = layer.weight.dtype
133 | # swap because we're feeding the output as our input
134 | # M$'s LoRA class arranges things to where this isn't necessary
135 | return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
136 |
137 | @classmethod
138 | def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
139 | if device is None:
140 | device = layer.weight.device
141 | if dtype is None:
142 | dtype = layer.weight.dtype
143 |
144 | in_channels, out_channels = layer.weight.shape
145 | # swap because we're feeding the output as our input
146 | # M$'s LoRA class arranges things to where this isn't necessary
147 | return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
148 |
149 | def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
150 | device = next(model.parameters()).device
151 | dtype = next(model.parameters()).dtype
152 |
153 | modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ]
154 |
155 | for *parent, k in modules:
156 | name = '.'.join(parent)
157 | layer = getattr( model.get_submodule(name), k )
158 |
159 | if isinstance( layer, nn.Linear ):
160 | target = nn.Linear
161 | klass = ParameterizedLoRA if use_parametrize else LoRALinear
162 | replacer = klass.from_linear
163 | elif isinstance( layer, nn.Conv1d ):
164 | target = nn.Conv1d
165 | klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
166 | replacer = klass.from_conv1d
167 | elif isinstance( layer, Conv1D ):
168 | target = Conv1D
169 | klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
170 | replacer = klass.from_conv1d
171 | else:
172 | continue
173 |
174 | replacement = replacer( layer, device=device, dtype=dtype, **kwargs )
175 |
176 | if use_parametrize:
177 | parametrize.register_parametrization( layer, "weight", replacement )
178 | else:
179 | setattr( model.get_submodule(name), k, replacement )
180 |
181 | return enable_lora( model )
182 |
183 | def enable_lora( model, mode = True ):
184 | for name, module in model.named_modules():
185 | if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ):
186 | continue
187 | module.enabled = mode
188 | return model
189 |
190 | def disable_lora( model ):
191 | return enable_lora( model, False )
192 |
193 | def freeze_non_lora_weights( model, embeddings = False ):
194 | frozen_params = []
195 |
196 | for name, param in model.named_parameters():
197 | should = 'lora_' in name or (embeddings and "_emb" in name)
198 |
199 | param.requires_grad_(should)
200 |
201 | if not should:
202 | frozen_params.append( param )
203 |
204 | return frozen_params
205 |
206 | def lora_get_state_dict( state_dict, split = True ):
207 | lora = { name: param for name, param in state_dict.items() if "lora_" in name }
208 | if not split:
209 | return lora
210 |
211 | return lora, { name: param for name, param in state_dict.items() if "lora_" not in name }
212 |
213 | def lora_load_state_dict( model, state_dict ):
214 | return model.load_state_dict( state_dict, strict = False )
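
# A minimal usage sketch, assuming the default policy and the LoRALinear replacement path above:
# inject adapters, freeze the base weights, train, then split the adapter weights out for saving.
#
#   model = apply_lora( model, rank=4, alpha=1, use_parametrize=False )
#   freeze_non_lora_weights( model )
#   # ... training loop ...
#   lora_sd, base_sd = lora_get_state_dict( model.state_dict(), split=True )
#   # later: lora_load_state_dict( model, lora_sd )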
--------------------------------------------------------------------------------
/vall_e.cpp/include/espeak-ng/espeak_ng.h:
--------------------------------------------------------------------------------
1 | /* eSpeak NG API.
2 | *
3 | * Copyright (C) 2015-2017 Reece H. Dunn
4 | *
5 | * This program is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU General Public License
16 |  * along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | */
18 |
19 | #ifndef ESPEAK_NG_H
20 | #define ESPEAK_NG_H
21 |
22 | #include <espeak-ng/speak_lib.h>
23 |
24 | #ifdef __cplusplus
25 | extern "C"
26 | {
27 | #endif
28 |
29 | #if defined(_WIN32) || defined(_WIN64)
30 | #ifdef LIBESPEAK_NG_EXPORT
31 | #define ESPEAK_NG_API __declspec(dllexport)
32 | #else
33 | #define ESPEAK_NG_API __declspec(dllimport)
34 | #endif
35 | #else
36 | #define ESPEAK_NG_API
37 | #endif
38 |
39 | #define ESPEAKNG_DEFAULT_VOICE "en"
40 |
41 | typedef enum {
42 | ENS_GROUP_MASK = 0x70000000,
43 | ENS_GROUP_ERRNO = 0x00000000, /* Values 0-255 map to errno error codes. */
44 | ENS_GROUP_ESPEAK_NG = 0x10000000, /* eSpeak NG error codes. */
45 |
46 | /* eSpeak NG 1.49.0 */
47 | ENS_OK = 0,
48 | ENS_COMPILE_ERROR = 0x100001FF,
49 | ENS_VERSION_MISMATCH = 0x100002FF,
50 | ENS_FIFO_BUFFER_FULL = 0x100003FF,
51 | ENS_NOT_INITIALIZED = 0x100004FF,
52 | ENS_AUDIO_ERROR = 0x100005FF,
53 | ENS_VOICE_NOT_FOUND = 0x100006FF,
54 | ENS_MBROLA_NOT_FOUND = 0x100007FF,
55 | ENS_MBROLA_VOICE_NOT_FOUND = 0x100008FF,
56 | ENS_EVENT_BUFFER_FULL = 0x100009FF,
57 | ENS_NOT_SUPPORTED = 0x10000AFF,
58 | ENS_UNSUPPORTED_PHON_FORMAT = 0x10000BFF,
59 | ENS_NO_SPECT_FRAMES = 0x10000CFF,
60 | ENS_EMPTY_PHONEME_MANIFEST = 0x10000DFF,
61 | ENS_SPEECH_STOPPED = 0x10000EFF,
62 |
63 | /* eSpeak NG 1.49.2 */
64 | ENS_UNKNOWN_PHONEME_FEATURE = 0x10000FFF,
65 | ENS_UNKNOWN_TEXT_ENCODING = 0x100010FF,
66 | } espeak_ng_STATUS;
67 |
68 | typedef enum {
69 | ENOUTPUT_MODE_SYNCHRONOUS = 0x0001,
70 | ENOUTPUT_MODE_SPEAK_AUDIO = 0x0002,
71 | } espeak_ng_OUTPUT_MODE;
72 |
73 | typedef enum {
74 | ENGENDER_UNKNOWN = 0,
75 | ENGENDER_MALE = 1,
76 | ENGENDER_FEMALE = 2,
77 | ENGENDER_NEUTRAL = 3,
78 | } espeak_ng_VOICE_GENDER;
79 |
80 | typedef struct
81 | {
82 | void (*outputPhoSymbol)(char* pho_code,int pho_type);
83 | void (*outputSilence)(short echo_tail);
84 | void (*outputVoiced)(short sample);
85 | void (*outputUnvoiced)(short sample);
86 | } espeak_ng_OUTPUT_HOOKS;
87 |
88 | /* eSpeak NG 1.49.0 */
89 |
90 | typedef struct espeak_ng_ERROR_CONTEXT_ *espeak_ng_ERROR_CONTEXT;
91 |
92 | ESPEAK_NG_API void
93 | espeak_ng_ClearErrorContext(espeak_ng_ERROR_CONTEXT *context);
94 |
95 | ESPEAK_NG_API void
96 | espeak_ng_GetStatusCodeMessage(espeak_ng_STATUS status,
97 | char *buffer,
98 | size_t length);
99 |
100 | ESPEAK_NG_API void
101 | espeak_ng_PrintStatusCodeMessage(espeak_ng_STATUS status,
102 | FILE *out,
103 | espeak_ng_ERROR_CONTEXT context);
104 |
105 | ESPEAK_NG_API void
106 | espeak_ng_InitializePath(const char *path);
107 |
108 | ESPEAK_NG_API espeak_ng_STATUS
109 | espeak_ng_Initialize(espeak_ng_ERROR_CONTEXT *context);
110 |
111 | ESPEAK_NG_API espeak_ng_STATUS
112 | espeak_ng_InitializeOutput(espeak_ng_OUTPUT_MODE output_mode,
113 | int buffer_length,
114 | const char *device);
115 |
116 | ESPEAK_NG_API int
117 | espeak_ng_GetSampleRate(void);
118 |
119 | ESPEAK_NG_API espeak_ng_STATUS
120 | espeak_ng_SetParameter(espeak_PARAMETER parameter,
121 | int value,
122 | int relative);
123 |
124 | ESPEAK_NG_API espeak_ng_STATUS
125 | espeak_ng_SetPhonemeEvents(int enable, int ipa);
126 |
127 | ESPEAK_NG_API espeak_ng_STATUS
128 | espeak_ng_SetPunctuationList(const wchar_t *punctlist);
129 |
130 | ESPEAK_NG_API espeak_ng_STATUS
131 | espeak_ng_SetVoiceByName(const char *name);
132 |
133 | ESPEAK_NG_API espeak_ng_STATUS
134 | espeak_ng_SetVoiceByFile(const char *filename);
135 |
136 | ESPEAK_NG_API espeak_ng_STATUS
137 | espeak_ng_SetVoiceByProperties(espeak_VOICE *voice_selector);
138 |
139 | ESPEAK_NG_API espeak_ng_STATUS
140 | espeak_ng_Synthesize(const void *text,
141 | size_t size,
142 | unsigned int position,
143 | espeak_POSITION_TYPE position_type,
144 | unsigned int end_position,
145 | unsigned int flags,
146 | unsigned int *unique_identifier,
147 | void *user_data);
148 |
149 | ESPEAK_NG_API espeak_ng_STATUS
150 | espeak_ng_SynthesizeMark(const void *text,
151 | size_t size,
152 | const char *index_mark,
153 | unsigned int end_position,
154 | unsigned int flags,
155 | unsigned int *unique_identifier,
156 | void *user_data);
157 |
158 | ESPEAK_NG_API espeak_ng_STATUS
159 | espeak_ng_SpeakKeyName(const char *key_name);
160 |
161 | ESPEAK_NG_API espeak_ng_STATUS
162 | espeak_ng_SpeakCharacter(wchar_t character);
163 |
164 | ESPEAK_NG_API espeak_ng_STATUS
165 | espeak_ng_Cancel(void);
166 |
167 | ESPEAK_NG_API espeak_ng_STATUS
168 | espeak_ng_Synchronize(void);
169 |
170 | ESPEAK_NG_API espeak_ng_STATUS
171 | espeak_ng_Terminate(void);
172 |
173 | ESPEAK_NG_API espeak_ng_STATUS
174 | espeak_ng_CompileDictionary(const char *dsource,
175 | const char *dict_name,
176 | FILE *log,
177 | int flags,
178 | espeak_ng_ERROR_CONTEXT *context);
179 |
180 | ESPEAK_NG_API espeak_ng_STATUS
181 | espeak_ng_CompileMbrolaVoice(const char *path,
182 | FILE *log,
183 | espeak_ng_ERROR_CONTEXT *context);
184 |
185 | ESPEAK_NG_API espeak_ng_STATUS
186 | espeak_ng_CompilePhonemeData(long rate,
187 | FILE *log,
188 | espeak_ng_ERROR_CONTEXT *context);
189 |
190 | ESPEAK_NG_API espeak_ng_STATUS
191 | espeak_ng_CompileIntonation(FILE *log,
192 | espeak_ng_ERROR_CONTEXT *context);
193 |
194 |
195 | ESPEAK_NG_API espeak_ng_STATUS
196 | espeak_ng_CompileIntonationPath(const char *source_path,
197 | const char *destination_path,
198 | FILE *log,
199 | espeak_ng_ERROR_CONTEXT *context);
200 |
201 | /* eSpeak NG 1.49.1 */
202 |
203 | ESPEAK_NG_API espeak_ng_STATUS
204 | espeak_ng_CompilePhonemeDataPath(long rate,
205 | const char *source_path,
206 | const char *destination_path,
207 | FILE *log,
208 | espeak_ng_ERROR_CONTEXT *context);
209 |
210 | ESPEAK_NG_API espeak_ng_STATUS
211 | espeak_ng_SetOutputHooks(espeak_ng_OUTPUT_HOOKS* hooks);
212 | ESPEAK_NG_API espeak_ng_STATUS
213 | espeak_ng_SetConstF0(int f0);
214 |
215 | ESPEAK_NG_API espeak_ng_STATUS
216 | espeak_ng_SetRandSeed(long seed);
217 |
218 |
219 | #ifdef __cplusplus
220 | }
221 | #endif
222 |
223 | #endif
224 |
--------------------------------------------------------------------------------