├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── config ├── LibriTTS │ ├── ar-quarter.yml │ ├── ar.yml │ ├── nar-quarter.yml │ └── nar.yml └── test │ ├── ar.yml │ └── nar.yml ├── data └── test │ ├── test.normalized.txt │ ├── test.phn.txt │ ├── test.qnt.pt │ ├── test.wav │ ├── test2.phn.txt │ └── test2.qnt.pt ├── scripts ├── plot.py └── run.sh ├── setup.py ├── vall-e.png └── vall_e ├── __init__.py ├── __main__.py ├── config.py ├── data.py ├── emb ├── __init__.py ├── g2p.py └── qnt.py ├── export.py ├── sampler.py ├── train.py └── vall_e ├── __init__.py ├── ar.py ├── base.py └── nar.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | /data 3 | /logs 4 | /ckpts 5 | /.cache 6 | /config 7 | /*.egg-info 8 | /vall_e/version.py 9 | /build 10 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "mini_vall_e/utils"] 2 | path = vall_e/utils 3 | url = https://github.com/enhuiz/pytorch-training-utils.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Zhe Niu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

<img src="./vall-e.png"></img> 2 | 3 |
4 | 5 | # VALL-E 6 | 7 | An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), based on the [EnCodec](https://github.com/facebookresearch/encodec) tokenizer. 8 | 9 | [!["Buy Me A Coffee"](https://www.buymeacoffee.com/assets/img/custom_images/orange_img.png)](https://www.buymeacoffee.com/enhuiz) 10 | 11 | ## Get Started 12 | 13 | > A toy Google Colab example: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1wEze0kQ0gt9B3bQmmbtbSXCoCTpq5vg-?usp=sharing). 14 | > Please note that this example overfits a single utterance under `data/test` and is not usable for general synthesis. 15 | > The pretrained model is yet to come. 16 | 17 | ### Requirements 18 | 19 | Since the trainer is based on [DeepSpeed](https://github.com/microsoft/DeepSpeed#requirements), you will need a GPU that DeepSpeed has been developed and tested against, as well as a CUDA or ROCm compiler pre-installed, in order to install this package. 20 | 21 | ### Install 22 | 23 | ``` 24 | pip install git+https://github.com/enhuiz/vall-e 25 | ``` 26 | 27 | Or you may clone the repository by: 28 | 29 | ``` 30 | git clone --recurse-submodules https://github.com/enhuiz/vall-e.git 31 | ``` 32 | 33 | Note that the code is only tested under `Python 3.10.7`. 34 | 35 | ### Train 36 | 37 | 1. Put your data into a folder, e.g. `data/your_data`. Audio files should be named with the suffix `.wav` and text files with `.normalized.txt`. 38 | 39 | 2. Quantize the data: 40 | 41 | ``` 42 | python -m vall_e.emb.qnt data/your_data 43 | ``` 44 | 45 | 3. Generate phonemes based on the text: 46 | 47 | ``` 48 | python -m vall_e.emb.g2p data/your_data 49 | ``` 50 | 51 | 4. Customize your configuration by creating `config/your_data/ar.yml` and `config/your_data/nar.yml`. Refer to the example configs in `config/test` and `vall_e/config.py` for details. You may choose different model presets; check `vall_e/vall_e/__init__.py`. 52 | 53 | 5. Train the AR or NAR model using the following script: 54 | 55 | ``` 56 | python -m vall_e.train yaml=config/your_data/ar_or_nar.yml 57 | ``` 58 | 59 | You may quit your training at any time by typing `quit` in your CLI. The latest checkpoint will be automatically saved. 60 | 61 | ### Export 62 | 63 | Both trained models need to be exported before synthesis. To export either of them to a given path, run: 64 | 65 | ``` 66 | python -m vall_e.export zoo/ar_or_nar.pt yaml=config/your_data/ar_or_nar.yml 67 | ``` 68 | 69 | This will export the latest checkpoint. 70 | 71 | ### Synthesis 72 | 73 | ``` 74 | python -m vall_e <text> <ref_path> <out_path> --ar-ckpt zoo/ar.pt --nar-ckpt zoo/nar.pt 75 | ``` 76 | 77 | ## TODO 78 | 79 | - [x] AR model for the first quantizer 80 | - [x] Audio decoding from tokens 81 | - [x] NAR model for the rest of the quantizers 82 | - [x] Trainers for both models 83 | - [x] Implement AdaLN for NAR model. 84 | - [x] Sample-wise quantization level sampling for NAR training. 85 | - [ ] Pre-trained checkpoint and demos on LibriTTS 86 | - [x] Synthesis CLI 87 | 88 | ## Notice 89 | 90 | - [EnCodec](https://github.com/facebookresearch/encodec) is licensed under CC-BY-NC 4.0. If you use the code to generate audio quantization or perform decoding, it is important to adhere to the terms of their license.
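For reference, below is a minimal sketch of what the synthesis command in the Get Started section does internally, following `vall_e/__main__.py`. The checkpoint and file paths (`zoo/ar.pt`, `ref.wav`, `out.wav`) and the text are placeholders, and both models must have been exported first.

```
import torch
from einops import rearrange

from vall_e.emb import g2p, qnt
from vall_e.utils import to_device

device = "cuda"
ar = torch.load("zoo/ar.pt").to(device)    # exported AR model, carries phone_symmap
nar = torch.load("zoo/nar.pt").to(device)  # exported NAR model

# Quantize the reference audio with EnCodec and phonemize the text.
proms = rearrange(qnt.encode_from_file("ref.wav"), "1 l t -> t l")
phns = torch.tensor([ar.phone_symmap[p] for p in g2p.encode("hello world")])
proms, phns = to_device(proms, device), to_device(phns, device)

# The AR model predicts the first quantizer level; the NAR model fills in the rest.
resp_list = ar(text_list=[phns], proms_list=[proms])
resps_list = nar(text_list=[phns], proms_list=[proms],
                 resps_list=[r.unsqueeze(-1) for r in resp_list])

qnt.decode_to_file(resps=resps_list[0], path="out.wav")
```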
91 | 92 | ## Citations 93 | 94 | ```bibtex 95 | @article{wang2023neural, 96 | title={Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers}, 97 | author={Wang, Chengyi and Chen, Sanyuan and Wu, Yu and Zhang, Ziqiang and Zhou, Long and Liu, Shujie and Chen, Zhuo and Liu, Yanqing and Wang, Huaming and Li, Jinyu and others}, 98 | journal={arXiv preprint arXiv:2301.02111}, 99 | year={2023} 100 | } 101 | ``` 102 | 103 | ```bibtex 104 | @article{defossez2022highfi, 105 | title={High Fidelity Neural Audio Compression}, 106 | author={Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi}, 107 | journal={arXiv preprint arXiv:2210.13438}, 108 | year={2022} 109 | } 110 | ``` 111 | -------------------------------------------------------------------------------- /config/LibriTTS/ar-quarter.yml: -------------------------------------------------------------------------------- 1 | data_dirs: [data/LibriTTS/] 2 | spkr_name_getter: "lambda p: p.parts[-3]" 3 | 4 | model: ar-quarter 5 | batch_size: 48 6 | eval_batch_size: 48 7 | eval_every: 10_000 8 | -------------------------------------------------------------------------------- /config/LibriTTS/ar.yml: -------------------------------------------------------------------------------- 1 | data_dirs: [data/LibriTTS/] 2 | spkr_name_getter: "lambda p: p.parts[-3]" 3 | 4 | model: ar 5 | batch_size: 24 6 | eval_batch_size: 24 7 | eval_every: 10_000 8 | 9 | sampling_temperature: 1.0 10 | -------------------------------------------------------------------------------- /config/LibriTTS/nar-quarter.yml: -------------------------------------------------------------------------------- 1 | data_dirs: [data/LibriTTS/] 2 | spkr_name_getter: "lambda p: p.parts[-3]" 3 | 4 | model: nar-quarter 5 | batch_size: 48 6 | eval_batch_size: 48 7 | -------------------------------------------------------------------------------- /config/LibriTTS/nar.yml: -------------------------------------------------------------------------------- 1 | data_dirs: [data/LibriTTS/] 2 | spkr_name_getter: "lambda p: p.parts[-3]" 3 | 4 | model: nar 5 | batch_size: 24 6 | eval_batch_size: 24 7 | eval_every: 1_000 8 | 9 | sampling_temperature: 0.2 10 | -------------------------------------------------------------------------------- /config/test/ar.yml: -------------------------------------------------------------------------------- 1 | data_dirs: [data/test] 2 | 3 | model: ar-quarter 4 | batch_size: 1 5 | eval_batch_size: 1 6 | save_ckpt_every: 500 7 | eval_every: 500 8 | max_iter: 1000 9 | -------------------------------------------------------------------------------- /config/test/nar.yml: -------------------------------------------------------------------------------- 1 | data_dirs: [data/test] 2 | 3 | model: nar-quarter 4 | batch_size: 1 5 | eval_batch_size: 1 6 | save_ckpt_every: 500 7 | eval_every: 500 8 | max_iter: 1000 9 | -------------------------------------------------------------------------------- /data/test/test.normalized.txt: -------------------------------------------------------------------------------- 1 | hello world 2 | -------------------------------------------------------------------------------- /data/test/test.phn.txt: -------------------------------------------------------------------------------- 1 | HH AH0 L OW1 _ W ER1 L D 2 | -------------------------------------------------------------------------------- /data/test/test.qnt.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/enhuiz/vall-e/3476d393d2133fa9b50d5ad999ca13b95fc22060/data/test/test.qnt.pt -------------------------------------------------------------------------------- /data/test/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enhuiz/vall-e/3476d393d2133fa9b50d5ad999ca13b95fc22060/data/test/test.wav -------------------------------------------------------------------------------- /data/test/test2.phn.txt: -------------------------------------------------------------------------------- 1 | test.phn.txt -------------------------------------------------------------------------------- /data/test/test2.qnt.pt: -------------------------------------------------------------------------------- 1 | test.qnt.pt -------------------------------------------------------------------------------- /scripts/plot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | import json 5 | import re 6 | from pathlib import Path 7 | 8 | import matplotlib.pyplot as plt 9 | import pandas as pd 10 | 11 | 12 | def plot(paths, args): 13 | dfs = [] 14 | 15 | for path in paths: 16 | with open(path, "r") as f: 17 | text = f.read() 18 | 19 | rows = [] 20 | 21 | pattern = r"(\{.+?\})" 22 | 23 | for row in re.findall(pattern, text, re.DOTALL): 24 | try: 25 | row = json.loads(row) 26 | except Exception as e: 27 | continue 28 | 29 | if "global_step" in row: 30 | rows.append(row) 31 | 32 | df = pd.DataFrame(rows) 33 | 34 | if "name" in df: 35 | df["name"] = df["name"].fillna("train") 36 | else: 37 | df["name"] = "train" 38 | 39 | df["group"] = str(path.parents[args.group_level]) 40 | df["group"] = df["group"] + "/" + df["name"] 41 | 42 | dfs.append(df) 43 | 44 | df = pd.concat(dfs) 45 | 46 | if args.max_x is not None: 47 | df = df[df["global_step"] < args.max_x] 48 | 49 | for gtag, gdf in sorted( 50 | df.groupby("group"), 51 | key=lambda p: (p[0].split("/")[-1], p[0]), 52 | ): 53 | for y in args.ys: 54 | gdf = gdf.sort_values("global_step") 55 | 56 | if gdf[y].isna().all(): 57 | continue 58 | 59 | if args.max_y is not None: 60 | gdf = gdf[gdf[y] < args.max_y] 61 | 62 | gdf[y] = gdf[y].ewm(10).mean() 63 | 64 | gdf.plot( 65 | x="global_step", 66 | y=y, 67 | label=f"{gtag}/{y}", 68 | ax=plt.gca(), 69 | marker="x" if len(gdf) < 100 else None, 70 | alpha=0.7, 71 | ) 72 | 73 | plt.gca().legend( 74 | loc="center left", 75 | fancybox=True, 76 | shadow=True, 77 | bbox_to_anchor=(1.04, 0.5), 78 | ) 79 | 80 | 81 | def main(): 82 | parser = argparse.ArgumentParser() 83 | parser.add_argument("ys", nargs="+") 84 | parser.add_argument("--log-dir", default="logs", type=Path) 85 | parser.add_argument("--out-dir", default="logs", type=Path) 86 | parser.add_argument("--filename", default="log.txt") 87 | parser.add_argument("--max-x", type=float, default=float("inf")) 88 | parser.add_argument("--max-y", type=float, default=float("inf")) 89 | parser.add_argument("--group-level", default=1, type=int) 90 | parser.add_argument("--filter", default=None) 91 | args = parser.parse_args() 92 | 93 | paths = args.log_dir.rglob(f"**/{args.filename}") 94 | 95 | if args.filter: 96 | paths = filter(lambda p: re.match(".*" + args.filter + ".*", str(p)), paths) 97 | 98 | plot(paths, args) 99 | 100 | name = "-".join(args.ys) 101 | out_path = (args.out_dir / name).with_suffix(".png") 102 | plt.savefig(out_path, bbox_inches="tight") 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 |
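As a side note, `scripts/plot.py` works by scanning each `log.txt` for JSON objects embedded in the text and keeping those that carry a `global_step` key (the trainer writes such dicts, e.g. via the `json.dumps(stats)` call in `vall_e/train.py`). A small self-contained sketch of that extraction step, using a made-up log excerpt:

```
import json
import re

# Hypothetical log excerpt; only the embedded JSON objects matter.
text = 'info {"global_step": 100, "nll": 4.2} ... info {"global_step": 200, "nll": 3.8}'

rows = []
for chunk in re.findall(r"(\{.+?\})", text, re.DOTALL):
    try:
        row = json.loads(chunk)
    except Exception:
        continue
    if "global_step" in row:
        rows.append(row)

print(rows)  # [{'global_step': 100, 'nll': 4.2}, {'global_step': 200, 'nll': 3.8}]
```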
-------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | until $@; do echo retrying; done 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from pathlib import Path 3 | from datetime import datetime 4 | from setuptools import setup, find_packages 5 | 6 | 7 | def shell(*args): 8 | out = subprocess.check_output(args) 9 | return out.decode("ascii").strip() 10 | 11 | 12 | def write_version(version_core, pre_release=True): 13 | if pre_release: 14 | time = shell("git", "log", "-1", "--format=%cd", "--date=iso") 15 | time = datetime.strptime(time, "%Y-%m-%d %H:%M:%S %z") 16 | time = time.strftime("%Y%m%d%H%M%S") 17 | version = f"{version_core}-dev{time}" 18 | else: 19 | version = version_core 20 | 21 | with open(Path("vall_e", "version.py"), "w") as f: 22 | f.write('__version__ = "{}"\n'.format(version)) 23 | 24 | return version 25 | 26 | 27 | with open("README.md", "r") as f: 28 | long_description = f.read() 29 | 30 | setup( 31 | name="vall-e", 32 | python_requires=">=3.10.0", 33 | version=write_version("0.0.1"), 34 | description="An unofficial toy implementation of the audio LM VALL-E", 35 | author="enhuiz", 36 | author_email="niuzhe.nz@outlook.com", 37 | long_description=long_description, 38 | long_description_content_type="text/markdown", 39 | packages=find_packages(), 40 | install_requires=[ 41 | "coloredlogs>=15.0.1", 42 | "deepspeed>=0.7.7", 43 | "diskcache>=5.4.0", 44 | "einops>=0.6.0", 45 | "encodec>=0.1.1", 46 | "g2p_en>=2.1.0", 47 | "humanize>=4.4.0", 48 | "matplotlib>=3.6.0", 49 | "numpy>=1.23.3", 50 | "omegaconf>=2.2.3", 51 | "openTSNE>=0.6.2", 52 | "pandas>=1.5.0", 53 | "soundfile>=0.11.0", 54 | "torch>=1.13.0", 55 | "torchaudio>=0.13.0", 56 | "tqdm>=4.64.1", 57 | ], 58 | url="https://github.com/enhuiz/vall-e", 59 | ) 60 | -------------------------------------------------------------------------------- /vall-e.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enhuiz/vall-e/3476d393d2133fa9b50d5ad999ca13b95fc22060/vall-e.png -------------------------------------------------------------------------------- /vall_e/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enhuiz/vall-e/3476d393d2133fa9b50d5ad999ca13b95fc22060/vall_e/__init__.py -------------------------------------------------------------------------------- /vall_e/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | from einops import rearrange 6 | 7 | from .emb import g2p, qnt 8 | from .utils import to_device 9 | 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser("VALL-E TTS") 13 | parser.add_argument("text") 14 | parser.add_argument("reference", type=Path) 15 | parser.add_argument("out_path", type=Path) 16 | parser.add_argument("--ar-ckpt", type=Path, default="zoo/ar.pt") 17 | parser.add_argument("--nar-ckpt", type=Path, default="zoo/nar.pt") 18 | parser.add_argument("--device", default="cuda") 19 | args = parser.parse_args() 20 | 21 | ar = torch.load(args.ar_ckpt).to(args.device) 22 | nar = torch.load(args.nar_ckpt).to(args.device) 23 | 24 | symmap = 
ar.phone_symmap 25 | 26 | proms = qnt.encode_from_file(args.reference) 27 | proms = rearrange(proms, "1 l t -> t l") 28 | 29 | phns = torch.tensor([symmap[p] for p in g2p.encode(args.text)]) 30 | 31 | proms = to_device(proms, args.device) 32 | phns = to_device(phns, args.device) 33 | 34 | resp_list = ar(text_list=[phns], proms_list=[proms]) 35 | resps_list = [r.unsqueeze(-1) for r in resp_list] 36 | 37 | resps_list = nar(text_list=[phns], proms_list=[proms], resps_list=resps_list) 38 | qnt.decode_to_file(resps=resps_list[0], path=args.out_path) 39 | print(args.out_path, "saved.") 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /vall_e/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from functools import cached_property 3 | from pathlib import Path 4 | 5 | import diskcache 6 | 7 | from .utils import Config as ConfigBase 8 | 9 | 10 | @dataclass(frozen=True) 11 | class Config(ConfigBase): 12 | data_root: Path = Path("data") 13 | data_dirs: list[Path] = field(default_factory=lambda: []) 14 | 15 | @property 16 | def sample_rate(self): 17 | return 24_000 18 | 19 | p_additional_prompt: float = 0.8 20 | max_prompts: int = 3 21 | 22 | max_num_val: int = 20 23 | max_val_ar_steps: int = 300 24 | 25 | token_dim: int = 256 26 | num_tokens: int = 1024 27 | 28 | nj: int = 8 29 | batch_size: int = 32 30 | eval_batch_size: int = 32 31 | warmup_min_lr: float = 1e-6 32 | warmup_max_lr: float = 2e-4 33 | dis_warmup_max_lr: float = 4e-4 34 | warmup_num_steps: int = 1_000 35 | max_iter: int = 1_000_000 36 | gradient_clipping: float = 100 37 | eval_every: int = 2_000 38 | save_ckpt_every: int = 2_000 39 | 40 | model: str = "ar-quarter" 41 | spkr_name_getter: str = "lambda p: p.parts[-2]" 42 | 43 | min_phones: int = 10 44 | max_phones: int = 50 45 | 46 | use_fp16: bool = True 47 | gradient_accumulation_steps: int = 1 48 | sampling_temperature: float = 1.0 49 | 50 | cache_dataloader: bool = False 51 | 52 | @cached_property 53 | def get_spkr(self): 54 | return eval(self.spkr_name_getter) 55 | 56 | @property 57 | def fp16_cfg(self): 58 | return { 59 | "enabled": self.use_fp16, 60 | } 61 | 62 | @property 63 | def ds_cfg(self): 64 | return { 65 | "train_micro_batch_size_per_gpu": self.batch_size, 66 | "gradient_accumulation_steps": self.gradient_accumulation_steps, 67 | "optimizer": { 68 | "type": "Adam", 69 | "lr": self.warmup_min_lr, 70 | }, 71 | "scheduler": { 72 | "type": "WarmupDecayLR", 73 | "params": { 74 | "warmup_min_lr": self.warmup_min_lr, 75 | "warmup_max_lr": self.warmup_max_lr, 76 | "warmup_num_steps": self.warmup_num_steps, 77 | "total_num_steps": self.max_iter, 78 | "warmup_type": "linear", 79 | }, 80 | }, 81 | "gradient_clipping": self.gradient_clipping, 82 | "fp16": self.fp16_cfg, 83 | } 84 | 85 | @property 86 | def cache_dir(self): 87 | return ".cache" / self.relpath 88 | 89 | @cached_property 90 | def diskcache(self): 91 | if self.cache_dataloader: 92 | return diskcache.Cache(self.cache_dir).memoize 93 | return lambda: lambda x: x 94 | 95 | 96 | cfg = Config.from_cli() 97 | 98 | if __name__ == "__main__": 99 | print(cfg) 100 | -------------------------------------------------------------------------------- /vall_e/data.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import random 4 | from collections import defaultdict 5 | from functools 
import cache, cached_property 6 | from itertools import groupby, zip_longest 7 | from typing import Any 8 | 9 | import numpy as np 10 | import torch 11 | from torch import Tensor 12 | from torch.utils.data import DataLoader, Dataset 13 | from tqdm import tqdm 14 | 15 | from .config import cfg 16 | from .sampler import Sampler 17 | 18 | torch.multiprocessing.set_sharing_strategy("file_system") 19 | 20 | _logger = logging.getLogger(__name__) 21 | 22 | 23 | def _replace_file_extension(path, suffix): 24 | return (path.parent / path.name.split(".")[0]).with_suffix(suffix) 25 | 26 | 27 | def _get_quant_path(path): 28 | return _replace_file_extension(path, ".qnt.pt") 29 | 30 | 31 | def _load_quants(path) -> Tensor: 32 | """ 33 | Returns: 34 | quants: (t q) 35 | """ 36 | path = _get_quant_path(path) 37 | return torch.load(path)[0].t() 38 | 39 | 40 | @cache 41 | def _get_phones(path): 42 | path = _replace_file_extension(path, ".phn.txt") 43 | with open(path, "r", encoding="utf8") as f: 44 | content = f.read() 45 | return [""] + content.split() + [""] 46 | 47 | 48 | def _interleaved_reorder(l, fn): 49 | groups = defaultdict(list) 50 | for e in l: 51 | groups[fn(e)].append(e) 52 | groups = {k: groups[k] for k in sorted(groups)} 53 | for interleaved in zip_longest(*groups.values()): 54 | for value in interleaved: 55 | if value is not None: 56 | yield value 57 | 58 | 59 | @cache 60 | def _validate(path, min_phones, max_phones): 61 | phones = _get_phones(path) 62 | unique_phones = list(set(phones)) 63 | if len(unique_phones) == 0: 64 | return False 65 | if len(unique_phones) == 1 and unique_phones[0] == "_": 66 | return False 67 | if len(phones) < min_phones: 68 | return False 69 | if len(phones) > max_phones: 70 | return False 71 | return True 72 | 73 | 74 | class VALLEDatset(Dataset): 75 | def __init__( 76 | self, 77 | paths, 78 | phone_symmap=None, 79 | spkr_symmap=None, 80 | min_phones=cfg.min_phones, 81 | max_phones=cfg.max_phones, 82 | training=False, 83 | extra_paths_by_spkr_name: dict[str, list] = {}, 84 | ): 85 | super().__init__() 86 | self._head = None 87 | self.min_phones = min_phones 88 | self.max_phones = max_phones 89 | 90 | self.paths = [ 91 | path for path in paths if _validate(path, self.min_phones, self.max_phones) 92 | ] 93 | 94 | self.spkr_symmap = spkr_symmap or self._get_spkr_symmap() 95 | self.phone_symmap = phone_symmap or self._get_phone_symmap() 96 | self.training = training 97 | 98 | self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name) 99 | 100 | self.paths = [ 101 | p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1 102 | ] 103 | 104 | if len(self.paths) == 0 and training: 105 | raise ValueError("No valid path is found for training.") 106 | 107 | if training: 108 | self.sampler = Sampler(self.paths, [cfg.get_spkr]) 109 | else: 110 | self.sampler = None 111 | 112 | def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]): 113 | ret = defaultdict(list) 114 | for path in self.paths: 115 | if _get_quant_path(path).exists(): 116 | ret[cfg.get_spkr(path)].append(path) 117 | for k, v in extra_paths_by_spkr_name.items(): 118 | ret[k].extend(v) 119 | return {**ret} 120 | 121 | @cached_property 122 | def phones(self): 123 | return sorted(set().union(*[_get_phones(path) for path in self.paths])) 124 | 125 | def _get_phone_symmap(self): 126 | # Note that we use phone symmap starting from 1 so that we can safely pad 0. 
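# The mapping is also attached to the exported model (see export.py) and reused by the synthesis CLI (vall_e/__main__.py).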
127 | return {s: i for i, s in enumerate(self.phones, 1)} 128 | 129 | @cached_property 130 | def spkrs(self): 131 | return sorted({cfg.get_spkr(path) for path in self.paths}) 132 | 133 | def _get_spkr_symmap(self): 134 | return {s: i for i, s in enumerate(self.spkrs)} 135 | 136 | def sample_prompts(self, spkr_name, ignore): 137 | prom_list = [] 138 | 139 | choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore} 140 | choices = [*choices] 141 | 142 | if len(choices) == 0: 143 | raise ValueError( 144 | f"Failed to find another different utterance for {spkr_name}." 145 | ) 146 | 147 | for _ in range(cfg.max_prompts): 148 | path = random.choice(choices) 149 | prom_list.append(_load_quants(path)) 150 | if random.random() > cfg.p_additional_prompt: 151 | break 152 | 153 | prom = torch.cat(prom_list) 154 | 155 | return prom 156 | 157 | def __getitem__(self, index): 158 | if self.training: 159 | assert self.sampler is not None 160 | path = self.sampler.sample() 161 | else: 162 | path = self.paths[index] 163 | 164 | spkr_name = cfg.get_spkr(path) 165 | text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))]) 166 | proms = self.sample_prompts(spkr_name, ignore=path) 167 | resps = _load_quants(path) 168 | resp = resps[..., 0] 169 | 170 | return dict( 171 | path=path, 172 | spkr_name=spkr_name, 173 | text=text, 174 | proms=proms, 175 | resps=resps, 176 | resp=resp, 177 | ) 178 | 179 | def head_(self, n): 180 | self._head = n 181 | 182 | def training_(self, value): 183 | self.training = value 184 | 185 | def interleaved_reorder_(self, fn): 186 | self.paths = [*_interleaved_reorder(self.paths, fn)] 187 | 188 | def __len__(self): 189 | return min(len(self.paths), self._head or len(self.paths)) 190 | 191 | 192 | def collate_fn(samples: list[dict]): 193 | batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]} 194 | return batch 195 | 196 | 197 | def _seed_worker(worker_id): 198 | worker_seed = torch.initial_seed() % 2**32 199 | np.random.seed(worker_seed) 200 | random.seed(worker_seed) 201 | 202 | 203 | def _create_dataloader(dataset, training): 204 | return DataLoader( 205 | dataset=dataset, 206 | batch_size=cfg.batch_size if training else cfg.eval_batch_size, 207 | shuffle=training, 208 | drop_last=training, 209 | num_workers=cfg.nj, 210 | collate_fn=collate_fn, 211 | persistent_workers=True, 212 | worker_init_fn=_seed_worker, 213 | ) 214 | 215 | 216 | def _load_train_val_paths(): 217 | paths = [] 218 | train_paths = [] 219 | val_paths = [] 220 | 221 | for data_dir in cfg.data_dirs: 222 | paths.extend(tqdm(data_dir.rglob("*.qnt.pt"))) 223 | 224 | if len(paths) == 0: 225 | raise RuntimeError(f"Failed to find any .qnt.pt file in {cfg.data_dirs}.") 226 | 227 | pairs = sorted([(cfg.get_spkr(p), p) for p in paths]) 228 | del paths 229 | 230 | for _, group in groupby(pairs, lambda pair: pair[0]): 231 | paths = sorted([p for _, p in group]) 232 | random.seed(0) 233 | random.shuffle(paths) 234 | n = round(len(paths) * 0.95) 235 | train_paths.extend(paths[:n]) 236 | val_paths.extend(paths[n:]) 237 | 238 | train_paths, val_paths = map(sorted, [train_paths, val_paths]) 239 | 240 | return train_paths, val_paths 241 | 242 | 243 | @cfg.diskcache() 244 | def create_datasets(): 245 | train_paths, val_paths = _load_train_val_paths() 246 | 247 | train_dataset = VALLEDatset( 248 | train_paths, 249 | training=True, 250 | ) 251 | 252 | val_dataset = VALLEDatset( 253 | val_paths, 254 | train_dataset.phone_symmap, 255 | train_dataset.spkr_symmap, 256 | 
extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name, 257 | ) 258 | 259 | val_dataset.interleaved_reorder_(cfg.get_spkr) 260 | val_dataset.head_(cfg.max_num_val) 261 | 262 | return train_dataset, val_dataset 263 | 264 | 265 | def create_train_val_dataloader(): 266 | train_dataset, val_dataset = create_datasets() 267 | 268 | train_dl = _create_dataloader(train_dataset, training=True) 269 | val_dl = _create_dataloader(val_dataset, training=False) 270 | 271 | _logger.info(str(train_dataset.phone_symmap)) 272 | _logger.info(str(train_dataset.spkr_symmap)) 273 | 274 | _logger.info(f"#samples (train): {len(train_dataset)}.") 275 | _logger.info(f"#samples (val): {len(val_dataset)}.") 276 | 277 | subtrain_dataset = copy.deepcopy(train_dataset) 278 | subtrain_dataset.interleaved_reorder_(cfg.get_spkr) 279 | subtrain_dataset.head_(cfg.max_num_val) 280 | subtrain_dataset.training_(False) 281 | subtrain_dl = _create_dataloader(subtrain_dataset, training=False) 282 | assert isinstance(subtrain_dl.dataset, VALLEDatset) 283 | 284 | return train_dl, subtrain_dl, val_dl 285 | 286 | 287 | if __name__ == "__main__": 288 | train_dl, subtrain_dl, val_dl = create_train_val_dataloader() 289 | sample = train_dl.dataset[0] 290 | print(sample) 291 | -------------------------------------------------------------------------------- /vall_e/emb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enhuiz/vall-e/3476d393d2133fa9b50d5ad999ca13b95fc22060/vall_e/emb/__init__.py -------------------------------------------------------------------------------- /vall_e/emb/g2p.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import string 4 | from functools import cache 5 | from pathlib import Path 6 | 7 | import torch 8 | from g2p_en import G2p 9 | from tqdm import tqdm 10 | 11 | 12 | @cache 13 | def _get_model(): 14 | return G2p() 15 | 16 | 17 | @cache 18 | def _get_graphs(path): 19 | with open(path, "r") as f: 20 | graphs = f.read() 21 | return graphs 22 | 23 | 24 | def encode(graphs: str) -> list[str]: 25 | g2p = _get_model() 26 | phones = g2p(graphs) 27 | ignored = {" ", *string.punctuation} 28 | return ["_" if p in ignored else p for p in phones] 29 | 30 | 31 | @torch.no_grad() 32 | def main(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("folder", type=Path) 35 | parser.add_argument("--suffix", type=str, default=".normalized.txt") 36 | args = parser.parse_args() 37 | 38 | paths = list(args.folder.rglob(f"*{args.suffix}")) 39 | random.shuffle(paths) 40 | 41 | for path in tqdm(paths): 42 | phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt") 43 | if phone_path.exists(): 44 | continue 45 | graphs = _get_graphs(path) 46 | phones = encode(graphs) 47 | with open(phone_path, "w") as f: 48 | f.write(" ".join(phones)) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /vall_e/emb/qnt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from functools import cache 4 | from pathlib import Path 5 | 6 | import soundfile 7 | import torch 8 | import torchaudio 9 | from einops import rearrange 10 | from encodec import EncodecModel 11 | from encodec.utils import convert_audio 12 | from torch import Tensor 13 | from tqdm import tqdm 14 | 15 | from ..config import cfg 16 | 17 | 18 | @cache 19 
| def _load_model(device="cuda"): 20 | # Instantiate a pretrained EnCodec model 21 | assert cfg.sample_rate == 24_000 22 | model = EncodecModel.encodec_model_24khz() 23 | model.set_target_bandwidth(6.0) 24 | model.to(device) 25 | return model 26 | 27 | 28 | def unload_model(): 29 | return _load_model.cache_clear() 30 | 31 | 32 | @torch.inference_mode() 33 | def decode(codes: Tensor, device="cuda"): 34 | """ 35 | Args: 36 | codes: (b q t) 37 | """ 38 | assert codes.dim() == 3 39 | model = _load_model(device) 40 | return model.decode([(codes, None)]), model.sample_rate 41 | 42 | 43 | def decode_to_file(resps: Tensor, path: Path): 44 | assert resps.dim() == 2, f"Require shape (t q), but got {resps.shape}." 45 | resps = rearrange(resps, "t q -> 1 q t") 46 | wavs, sr = decode(resps) 47 | soundfile.write(str(path), wavs.cpu()[0, 0], sr) 48 | 49 | 50 | def _replace_file_extension(path, suffix): 51 | return (path.parent / path.name.split(".")[0]).with_suffix(suffix) 52 | 53 | 54 | @torch.inference_mode() 55 | def encode(wav: Tensor, sr: int, device="cuda"): 56 | """ 57 | Args: 58 | wav: (t) 59 | sr: int 60 | """ 61 | model = _load_model(device) 62 | wav = wav.unsqueeze(0) 63 | wav = convert_audio(wav, sr, model.sample_rate, model.channels) 64 | wav = wav.to(device) 65 | encoded_frames = model.encode(wav) 66 | qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b q t) 67 | return qnt 68 | 69 | 70 | def encode_from_file(path, device="cuda"): 71 | wav, sr = torchaudio.load(str(path)) 72 | if wav.shape[0] == 2: 73 | wav = wav[:1] 74 | return encode(wav, sr, device) 75 | 76 | 77 | def main(): 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument("folder", type=Path) 80 | parser.add_argument("--suffix", default=".wav") 81 | args = parser.parse_args() 82 | 83 | paths = [*args.folder.rglob(f"*{args.suffix}")] 84 | random.shuffle(paths) 85 | 86 | for path in tqdm(paths): 87 | out_path = _replace_file_extension(path, ".qnt.pt") 88 | if out_path.exists(): 89 | continue 90 | qnt = encode_from_file(path) 91 | torch.save(qnt.cpu(), out_path) 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /vall_e/export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | from .data import VALLEDatset, create_train_val_dataloader 6 | from .train import load_engines 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser("Save trained model to path.") 11 | parser.add_argument("path") 12 | args = parser.parse_args() 13 | 14 | engine = load_engines() 15 | model = engine["model"].module.cpu() 16 | train_dl, *_ = create_train_val_dataloader() 17 | assert isinstance(train_dl.dataset, VALLEDatset) 18 | model.phone_symmap = train_dl.dataset.phone_symmap 19 | model.spkr_symmap = train_dl.dataset.spkr_symmap 20 | torch.save(model, args.path) 21 | print(args.path, "saved.") 22 | 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /vall_e/sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | A sampler that balances data by key_fns. 
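Each key_fn partitions the data into groups (e.g. by speaker); sampling picks a group uniformly at random at each level before picking an item inside it, so rare groups are drawn as often as common ones.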
3 | 4 | MIT License 5 | 6 | Copyright (c) 2023 Zhe Niu 7 | 8 | niuzhe.nz@outlook.com 9 | """ 10 | 11 | import random 12 | 13 | 14 | class Sampler: 15 | def __init__(self, l, key_fns): 16 | self.tree = self._build(l, key_fns) 17 | 18 | def _build(self, l, key_fns) -> dict[dict, list]: 19 | if not key_fns: 20 | return l 21 | 22 | tree = {} 23 | 24 | key_fn, *key_fns = key_fns 25 | 26 | for x in l: 27 | k = key_fn(x) 28 | 29 | if k in tree: 30 | tree[k].append(x) 31 | else: 32 | tree[k] = [x] 33 | 34 | for k in tree: 35 | tree[k] = self._build(tree[k], key_fns) 36 | 37 | return tree 38 | 39 | def _sample(self, tree: dict | list): 40 | if isinstance(tree, list): 41 | ret = random.choice(tree) 42 | else: 43 | key = random.choice([*tree.keys()]) 44 | ret = self._sample(tree[key]) 45 | return ret 46 | 47 | def sample(self): 48 | return self._sample(self.tree) 49 | -------------------------------------------------------------------------------- /vall_e/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from collections import defaultdict 4 | 5 | import torch 6 | from tqdm import tqdm 7 | 8 | from .config import cfg 9 | from .data import create_train_val_dataloader 10 | from .emb import qnt 11 | from .utils import setup_logging, to_device, trainer 12 | from .vall_e import get_model 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | def load_engines(): 18 | model = get_model(cfg.model) 19 | 20 | engines = dict( 21 | model=trainer.Engine( 22 | model=model, 23 | config=cfg.ds_cfg, 24 | ), 25 | ) 26 | 27 | return trainer.load_engines(engines, cfg) 28 | 29 | 30 | def main(): 31 | setup_logging(cfg.log_dir) 32 | 33 | train_dl, subtrain_dl, val_dl = create_train_val_dataloader() 34 | 35 | def train_feeder(engines, batch, name): 36 | model = engines["model"] 37 | 38 | if cfg.model.startswith("ar"): 39 | _ = model( 40 | text_list=batch["text"], 41 | proms_list=batch["proms"], 42 | resp_list=batch["resp"], 43 | ) 44 | elif cfg.model.startswith("nar"): 45 | _ = model( 46 | text_list=batch["text"], 47 | proms_list=batch["proms"], 48 | resps_list=batch["resps"], 49 | ) 50 | else: 51 | raise NotImplementedError(cfg.model) 52 | 53 | losses = model.gather_attribute("loss") 54 | 55 | loss = torch.stack([*losses.values()]).sum() 56 | 57 | stats = {} 58 | stats |= {k: v.item() for k, v in losses.items()} 59 | stats |= engines.gather_attribute("scalar") 60 | 61 | return loss, stats 62 | 63 | @torch.inference_mode() 64 | def run_eval(engines, name, dl): 65 | log_dir = cfg.log_dir / str(engines.global_step) / name 66 | 67 | model = engines["model"] 68 | log_dir = cfg.log_dir / str(engines.global_step) / name 69 | stats = defaultdict(list) 70 | for batch in tqdm(dl): 71 | batch: dict = to_device(batch, cfg.device) 72 | 73 | if cfg.model.startswith("ar"): 74 | resp_list = model( 75 | text_list=batch["text"], 76 | proms_list=batch["proms"], 77 | max_steps=cfg.max_val_ar_steps, 78 | sampling_temperature=cfg.sampling_temperature, 79 | ) 80 | resps_list = [r.unsqueeze(-1) for r in resp_list] 81 | elif cfg.model.startswith("nar"): 82 | resps_list = model( 83 | text_list=batch["text"], 84 | proms_list=batch["proms"], 85 | resps_list=[r.unsqueeze(-1) for r in batch["resp"]], 86 | sampling_temperature=cfg.sampling_temperature, 87 | ) 88 | else: 89 | raise NotImplementedError(cfg.model) 90 | 91 | losses = model.gather_attribute("loss") 92 | batch_stats = {k: v.item() for k, v in losses.items()} 93 | for k, v in batch_stats.items(): 94 | 
stats[k].append(v) 95 | 96 | for path, ref, hyp in zip(batch["path"], batch["resps"], resps_list): 97 | relpath = path.relative_to(cfg.data_root) 98 | hyp_path = (log_dir / "hyp" / relpath).with_suffix(".wav") 99 | ref_path = (log_dir / "ref" / relpath).with_suffix(".wav") 100 | hyp_path.parent.mkdir(parents=True, exist_ok=True) 101 | ref_path.parent.mkdir(parents=True, exist_ok=True) 102 | qnt.decode_to_file(ref, ref_path) 103 | if len(hyp) > 0: 104 | qnt.decode_to_file(hyp, hyp_path) 105 | 106 | qnt.unload_model() 107 | 108 | stats = {k: sum(v) / len(v) for k, v in stats.items()} 109 | stats["global_step"] = engines.global_step 110 | stats["name"] = name 111 | _logger.info(f"Eval: {stats}.") 112 | 113 | _logger.info(f"{json.dumps(stats)}.") 114 | 115 | def eval_fn(engines): 116 | run_eval(engines, "subtrain", subtrain_dl) 117 | run_eval(engines, "val", val_dl) 118 | 119 | trainer.train( 120 | engines_loader=load_engines, 121 | train_dl=train_dl, 122 | train_feeder=train_feeder, 123 | eval_fn=eval_fn, 124 | ) 125 | 126 | 127 | if __name__ == "__main__": 128 | main() 129 | -------------------------------------------------------------------------------- /vall_e/vall_e/__init__.py: -------------------------------------------------------------------------------- 1 | from ..config import cfg 2 | from .ar import AR 3 | from .nar import NAR 4 | 5 | 6 | def get_model(name): 7 | name = name.lower() 8 | 9 | if name.startswith("ar"): 10 | Model = AR 11 | elif name.startswith("nar"): 12 | Model = NAR 13 | else: 14 | raise ValueError("Model name should start with AR or NAR.") 15 | 16 | if "-quarter" in name: 17 | model = Model( 18 | cfg.num_tokens, 19 | d_model=256, 20 | n_heads=4, 21 | n_layers=12, 22 | ) 23 | elif "-half" in name: 24 | model = Model( 25 | cfg.num_tokens, 26 | d_model=512, 27 | n_heads=8, 28 | n_layers=12, 29 | ) 30 | else: 31 | if name not in ["ar", "nar"]: 32 | raise NotImplementedError(name) 33 | 34 | model = Model( 35 | cfg.num_tokens, 36 | d_model=1024, 37 | n_heads=16, 38 | n_layers=12, 39 | ) 40 | 41 | return model 42 | -------------------------------------------------------------------------------- /vall_e/vall_e/ar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import Tensor 4 | from tqdm import trange 5 | 6 | from .base import Base 7 | 8 | 9 | class AR(Base): 10 | @property 11 | def n_resp_levels(self): 12 | return 1 13 | 14 | @property 15 | def casual(self): 16 | return True 17 | 18 | @property 19 | def use_stop_token(self): 20 | return True 21 | 22 | @property 23 | def norm_type(self): 24 | return "ln" 25 | 26 | @property 27 | def resp_loss_only(self): 28 | return False 29 | 30 | def _prune(self, l: Tensor): 31 | indices = (l == self.stop_token).nonzero() 32 | if len(indices) == 0: 33 | return l 34 | return l[: indices.min().item()] 35 | 36 | @staticmethod 37 | def _unsqueeze_list(x_list, axis=-1): 38 | return [x.unsqueeze(dim=axis) for x in x_list] 39 | 40 | def forward( 41 | self, 42 | text_list: list[Tensor], 43 | proms_list: list[Tensor], 44 | resp_list: list[Tensor] | None = None, 45 | max_steps: int = 1000, 46 | sampling_temperature: float = 1.0, 47 | ): 48 | if resp_list is not None: 49 | return super().forward( 50 | text_list, 51 | proms_list, 52 | self._unsqueeze_list(resp_list), 53 | resp_list, 54 | quant_levels=None, 55 | shift_targ_list=True, 56 | return_all_resp=False, 57 | ) 58 | else: 59 | return self._generate( 60 | text_list, 61 | proms_list, 62 | 
max_steps, 63 | sampling_temperature, 64 | ) 65 | 66 | def _generate( 67 | self, 68 | text_list: list[Tensor], 69 | proms_list: list[Tensor], 70 | max_steps: int, 71 | sampling_temperature: float, 72 | ): 73 | device = text_list[0].device 74 | resp_list: list[Tensor] = [ 75 | torch.zeros(0, device=device).long() for _ in text_list 76 | ] 77 | stopped = torch.zeros(len(text_list), device=device).bool() 78 | for _ in trange(max_steps): 79 | r = super().forward( 80 | text_list, 81 | proms_list, 82 | self._unsqueeze_list(resp_list), 83 | sampling_temperature=sampling_temperature, 84 | ) 85 | stopped |= r == self.stop_token 86 | for i, ri in enumerate(r): 87 | resp_list[i] = torch.cat([resp_list[i], ri[None]]) 88 | if stopped.all().item(): 89 | break 90 | pruned = [self._prune(r) for r in resp_list] 91 | return pruned 92 | 93 | 94 | def example_usage(): 95 | from functools import partial 96 | 97 | import soundfile 98 | from einops import repeat 99 | 100 | device = "cuda" 101 | 102 | qnt = torch.load("data/test/test.qnt.pt")[0, 0].to(device) 103 | num_qnts = 1024 104 | 105 | model = AR(num_qnts).to(device) 106 | 107 | text_list = [ 108 | torch.tensor([1, 2, 3], device=device), 109 | torch.tensor([2, 3], device=device), 110 | ] 111 | 112 | x8 = partial(repeat, pattern="t -> t l", l=8) 113 | proms_list = [ 114 | x8(torch.tensor([1, 2, 3], device=device)), 115 | x8(torch.tensor([2, 3], device=device)), 116 | ] 117 | 118 | resp_list = [ 119 | torch.tensor([1, 2, 3], device=device), 120 | qnt.to(device), 121 | ] 122 | 123 | out = model(text_list, proms_list, max_steps=200) 124 | 125 | print(out) 126 | 127 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 128 | 129 | for i in range(100): 130 | optimizer.zero_grad() 131 | _ = model(text_list, proms_list, resp_list) 132 | 133 | losses = model.loss 134 | sum(losses.values()).backward() 135 | optimizer.step() 136 | 137 | if i % 20 == 0: 138 | print(f"iter={i}, {losses}.") 139 | 140 | out = model(text_list, proms_list, max_steps=200) 141 | 142 | print(qnt) 143 | print(out) 144 | 145 | from ..emb.qnt import decode 146 | 147 | codes = rearrange(out[1], "t -> 1 1 t") 148 | wavs, sr = decode(codes) 149 | soundfile.write("data/test/test.ar.recon.wav", wavs.cpu()[0, 0], sr) 150 | 151 | 152 | if __name__ == "__main__": 153 | example_usage() 154 | -------------------------------------------------------------------------------- /vall_e/vall_e/base.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | from typing import Literal, overload 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from einops import rearrange 8 | from torch import Tensor, einsum, nn 9 | from torch.distributions import Categorical 10 | from torch.nn.utils.rnn import pad_sequence 11 | from torch.utils.checkpoint import checkpoint 12 | 13 | 14 | def _create_mask(l, device): 15 | """1 is valid region and 0 is invalid.""" 16 | seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t) 17 | stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1) 18 | return (seq < stop).float() # (b t) 19 | 20 | 21 | def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"): 22 | """ 23 | Args: 24 | x_list: [(t d)] 25 | Returns: 26 | x: (? ? ?) 27 | m: (? ? 
?), same as x 28 | """ 29 | l = list(map(len, x_list)) 30 | x = rearrange(pad_sequence(x_list), pattern) 31 | m = _create_mask(l, x_list[0].device) 32 | m = m.t().unsqueeze(-1) # (t b 1) 33 | m = rearrange(m, pattern) 34 | m = m.to(x) 35 | return x, m 36 | 37 | 38 | class SinusodialEmbedding(nn.Module): 39 | def __init__(self, d_model): 40 | super().__init__() 41 | self.d_model = d_model 42 | exponent = torch.arange(self.d_half, dtype=torch.float32) 43 | exponent = exponent / self.d_half 44 | omega = torch.exp(-math.log(1e4) * exponent) 45 | self.omega: torch.Tensor 46 | self.register_buffer("omega", omega, persistent=False) 47 | 48 | @property 49 | def d_half(self): 50 | assert self.d_model % 2 == 0, "Only support even d_model." 51 | return self.d_model // 2 52 | 53 | def forward(self, x): 54 | """ 55 | Args: 56 | x: (...) 57 | Returns: 58 | pe: (... d) 59 | """ 60 | omega = self.omega 61 | 62 | while omega.dim() <= x.dim(): 63 | omega = omega.unsqueeze(0) # (... d) 64 | 65 | x = x.unsqueeze(-1) # (... 1) 66 | x = omega * x 67 | x = torch.cat([x.sin(), x.cos()], dim=-1) 68 | 69 | return x 70 | 71 | def get_pe(self, n: int): 72 | """ 73 | Args: 74 | n: int 75 | Returns: 76 | pe: (n d) 77 | """ 78 | device = self.omega.device 79 | return self.forward(torch.arange(n, device=device)) 80 | 81 | def add_pe(self, x): 82 | """ 83 | Args: 84 | x: (b t c) 85 | """ 86 | e = self.get_pe(x.shape[1]) # t d 87 | e = e[None] # b t d 88 | x = x + e 89 | return x 90 | 91 | 92 | class Attention(nn.Module): 93 | def __init__(self, d_model, n_heads, casual): 94 | super().__init__() 95 | assert d_model % n_heads == 0 96 | dim_head = d_model // n_heads 97 | self.casual = casual 98 | self.n_heads = n_heads 99 | self.scale = dim_head**-0.5 100 | self.to_qkv = nn.Linear(d_model, d_model * 3, bias=False) 101 | self.to_out = nn.Linear(d_model, d_model) 102 | 103 | def forward(self, x, m): 104 | """ 105 | Args: 106 | x: (b t c) 107 | m: (b t c), 1 is data, 0 is padding 108 | Returns: 109 | x: (b t c) 110 | """ 111 | h = self.n_heads 112 | 113 | q, k, v = self.to_qkv(x).chunk(3, dim=-1) 114 | q, k, v = map(lambda t: rearrange(t, "b t (h d) -> b t h d", h=h), (q, k, v)) 115 | 116 | e = einsum("b i h d, b j h d -> b i j h", q, k) 117 | e = e * self.scale 118 | 119 | kpm = m.unsqueeze(1) * m.unsqueeze(2) # b i j 1 120 | 121 | if self.casual: 122 | kpm = kpm.squeeze(-1).tril().unsqueeze(-1) # b i j 1 123 | 124 | e = e.masked_fill(kpm == 0, -torch.finfo(e.dtype).max) 125 | a = e.softmax(dim=2) # Normalize on j, i.e. key 126 | 127 | o = einsum("b i j h, b j h d -> b i h d", a, v) 128 | o = o.flatten(-2) 129 | o = self.to_out(o) # b t c 130 | 131 | o = o * m 132 | 133 | return o 134 | 135 | 136 | class AdaLN(nn.Module): 137 | def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2): 138 | super().__init__() 139 | self.eps = eps 140 | self.emb = nn.Embedding(n_levels, d_model * 2) 141 | self.k = k 142 | self.c = c 143 | nn.init.zeros_(self.emb.weight) 144 | 145 | def forward(self, x, l): 146 | logγ, β = self.emb(l).unsqueeze(1).chunk(2, dim=-1) 147 | 148 | h = F.layer_norm(x, x.shape[-1:], eps=self.eps) 149 | 150 | # The initial implementation (https://github.com/enhuiz/vall-e/blob/fbf023448c08e55c0422eefed7fc234cf8b76680/vall_e/vall_e/base.py#L135) 151 | # performed worse than vanilla LayerNorm. 152 | # The authors mentioned another AdaNorm paper (https://openreview.net/pdf?id=HyxndNrxLB) as they introduce AdaLN. 153 | # Did they use AdaNorm inside AdaLN? 
(as follows) 154 | h = self.c * (1 - (self.k * h).detach()) * h 155 | 156 | y = logγ.exp() * h + β 157 | 158 | return y 159 | 160 | 161 | class PrenormResidual(nn.Module): 162 | def __init__( 163 | self, 164 | block, 165 | d_model, 166 | p_dropout, 167 | requires_mask=False, 168 | norm_type="ln", 169 | n_levels: int | None = None, 170 | ): 171 | super().__init__() 172 | self.block = block 173 | self.requires_mask = requires_mask 174 | self.norm_type = norm_type 175 | if norm_type == "ln": 176 | self.norm = nn.LayerNorm(d_model) 177 | elif norm_type == "adaln": 178 | assert n_levels is not None 179 | self.norm = AdaLN(d_model, n_levels) 180 | else: 181 | raise NotImplementedError(norm_type) 182 | self.dropout = nn.Dropout(p_dropout) 183 | 184 | def forward(self, x, m, l): 185 | """ 186 | Args: 187 | x: input (b t d) 188 | m: mask (b t 1), 1 is valuable and 0 is padding 189 | l: level to use, required only for AdaLN 190 | """ 191 | nopts = {"l": l} if self.norm_type == "adaln" else {} 192 | bopts = {"m": m} if self.requires_mask else {} 193 | x = x + self.dropout(self.block(self.norm(x, **nopts) * m, **bopts)) 194 | return x * m 195 | 196 | 197 | class Block(nn.Sequential): 198 | def __init__(self, d_model, n_heads, p_dropout, casual, norm_type, n_levels): 199 | super().__init__() 200 | self.attn = PrenormResidual( 201 | Attention(d_model, n_heads, casual), 202 | d_model=d_model, 203 | p_dropout=p_dropout, 204 | requires_mask=True, 205 | norm_type=norm_type, 206 | n_levels=n_levels, 207 | ) 208 | self.ffn = PrenormResidual( 209 | nn.Sequential( 210 | nn.Linear(d_model, d_model * 4), 211 | nn.GELU(), 212 | nn.Dropout(p_dropout), 213 | nn.Linear(d_model * 4, d_model), 214 | ), 215 | d_model=d_model, 216 | p_dropout=p_dropout, 217 | norm_type=norm_type, 218 | n_levels=n_levels, 219 | ) 220 | 221 | def forward(self, x, m, l): 222 | """ 223 | Args: 224 | x: (b t c) 225 | m: (b t 1) 226 | l: (b) 227 | """ 228 | poor_in_vram = True 229 | if x.requires_grad and poor_in_vram: 230 | x = checkpoint(self.attn, x, m, l) 231 | else: 232 | x = self.attn(x, m, l) 233 | x = self.ffn(x, m, l) 234 | return x 235 | 236 | 237 | class Embedding(nn.Embedding): 238 | def forward(self, x_list: list[Tensor]) -> list[Tensor]: 239 | if len(x_list) == 0: 240 | return [] 241 | return super().forward(torch.cat(x_list)).split([*map(len, x_list)]) 242 | 243 | 244 | class MultiEmbedding(nn.Module): 245 | """ 246 | This embedding sums embeddings on different levels. 
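Each input is a (t, l) tensor of codes over l quantization levels; every level has its own embedding table, and the per-level embeddings are summed into a single (t, d) sequence.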
247 | """ 248 | 249 | def __init__(self, max_n_levels, n_tokens, token_dim): 250 | super().__init__() 251 | self.max_n_levels = max_n_levels 252 | self.n_tokens = n_tokens 253 | self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim)) 254 | 255 | def forward(self, x_list: list[Tensor]) -> list[Tensor]: 256 | if len(x_list) == 0: 257 | return [] 258 | 259 | w = self.weight 260 | 261 | padded_x_list = [] 262 | 263 | for xi in x_list: 264 | xi = F.one_hot(xi, num_classes=self.n_tokens) # t l' k 265 | xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k 266 | padded_x_list.append(xi.to(w)) 267 | 268 | x = torch.cat(padded_x_list) # n l k 269 | x = einsum("l k d, n l k -> n d", w, x) 270 | 271 | x_list = x.split([*map(len, x_list)]) 272 | 273 | return x_list 274 | 275 | 276 | def _join(x: tuple[Tensor], sep: Tensor): 277 | """ 278 | Args: 279 | x: (k t d) 280 | sep: (d) 281 | """ 282 | ret = x[0] 283 | for i in range(1, len(x)): 284 | ret = torch.cat((ret, sep[None], x[i]), dim=0) 285 | return ret 286 | 287 | 288 | class Base(nn.Module): 289 | @property 290 | def casual(self) -> bool: 291 | raise NotImplementedError 292 | 293 | @property 294 | def n_resp_levels(self) -> int: 295 | raise NotImplementedError 296 | 297 | @property 298 | def use_stop_token(self) -> bool: 299 | raise NotImplementedError 300 | 301 | @property 302 | def norm_type(self): 303 | raise NotImplementedError 304 | 305 | @property 306 | def n_prom_levels(self) -> int: 307 | return 8 308 | 309 | @property 310 | def resp_loss_only(self): 311 | raise NotImplementedError 312 | 313 | def __init__( 314 | self, 315 | n_tokens: int, 316 | d_model: int = 512, 317 | n_heads: int = 8, 318 | n_layers: int = 12, 319 | p_dropout: float = 0.1, 320 | ): 321 | super().__init__() 322 | self.n_tokens = n_tokens 323 | 324 | casual = self.casual 325 | 326 | # +1 to include the stop token 327 | n_stop_tokens = 1 if self.use_stop_token else 0 328 | n_resp_tokens = n_tokens + n_stop_tokens 329 | 330 | self.text_emb = Embedding(n_tokens, d_model) 331 | 332 | # Here I simply use all prom levels 333 | self.proms_emb = MultiEmbedding(self.n_prom_levels, n_tokens, d_model) 334 | self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model) 335 | 336 | self.sin_emb = SinusodialEmbedding(d_model) 337 | 338 | self.sep = nn.Parameter(torch.randn(d_model)) 339 | 340 | blocks = [ 341 | Block( 342 | d_model=d_model, 343 | n_heads=n_heads, 344 | p_dropout=p_dropout, 345 | casual=casual, 346 | norm_type=self.norm_type, 347 | n_levels=self.n_resp_levels, 348 | ) 349 | for _ in range(n_layers) 350 | ] 351 | 352 | self.blocks = nn.ModuleList(blocks) 353 | 354 | self.classifier = nn.Linear(d_model, n_resp_tokens) 355 | 356 | @property 357 | def stop_token(self): 358 | if not self.use_stop_token: 359 | raise ValueError("Not using stop token!") 360 | return self.n_tokens 361 | 362 | @property 363 | def ignore_index(self): 364 | return -100 365 | 366 | @staticmethod 367 | def _samplewise_merge_tensors(*l, sep: Tensor | None): 368 | if sep is None: 369 | cat = torch.cat 370 | else: 371 | cat = partial(_join, sep=sep) 372 | return [*map(cat, zip(*l))] 373 | 374 | @overload 375 | def forward( 376 | self, 377 | text_list: list[Tensor], 378 | proms_list: list[Tensor], 379 | resps_list: list[Tensor], 380 | targ_list: list[Tensor] | None = None, 381 | quant_levels: Tensor | None = None, 382 | shift_targ_list: bool = False, 383 | return_all_resp: Literal[False] = False, 384 | sampling_temperature: float = 1.0, 385 | ) -> Tensor: 386 | ... 
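# The overloads only refine the return type: return_all_resp=False (AR sampling) yields one sampled token per sequence, return_all_resp=True (NAR) yields a list of per-sequence token tensors.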
387 | 388 | @overload 389 | def forward( 390 | self, 391 | text_list: list[Tensor], 392 | proms_list: list[Tensor], 393 | resps_list: list[Tensor], 394 | targ_list: list[Tensor] | None = None, 395 | quant_levels: Tensor | None = None, 396 | shift_targ_list: bool = False, 397 | return_all_resp: Literal[True] = True, 398 | sampling_temperature: float = 1.0, 399 | ) -> list[Tensor]: 400 | ... 401 | 402 | def forward( 403 | self, 404 | text_list: list[Tensor], 405 | proms_list: list[Tensor], 406 | resps_list: list[Tensor], 407 | targ_list: list[Tensor] | None = None, 408 | quant_levels: Tensor | None = None, 409 | shift_targ_list: bool = False, 410 | return_all_resp: bool = False, 411 | sampling_temperature: float = 1.0, 412 | ): 413 | """ 414 | Args: 415 | text_list: [t] * b 416 | proms_list: [t' l] * b, l quantization levels. 417 | resps_list: [t'' l] * b, l quantization levels. 418 | targ_list: [t''] * b, one quantization level only, when given, loss will be computed 419 | quant_levels: specify which quant_levels to feed forward, used in NAR mode. 420 | shift_targ_list: whether to shift target list when computing loss. True if AR. 421 | return_all_resp: True if NAR. 422 | sampling_temperature: a lower temperature makes the result more robust but less diverse. 423 | Returns: 424 | y: sampled tokens 425 | """ 426 | x_list = self._samplewise_merge_tensors( 427 | self.text_emb(text_list), 428 | self.proms_emb(proms_list), 429 | self.resps_emb(resps_list), 430 | sep=self.sep, 431 | ) 432 | 433 | x, m = list_to_tensor(x_list) 434 | x = self.sin_emb.add_pe(x) 435 | 436 | for block in self.blocks: 437 | x = block(x, m, quant_levels) 438 | 439 | h = self.classifier(x) * m 440 | 441 | # Remove padding 442 | h_list = [hi[:li] for hi, li in zip(h, map(len, x_list))] 443 | 444 | if targ_list is not None: 445 | if any([l == 0 for l in map(len, targ_list)]): 446 | raise ValueError("Cannot compute loss given empty targ_list.") 447 | 448 | device = h.device 449 | 450 | ignore_sep = torch.tensor(self.ignore_index, device=device) 451 | 452 | # Ignore prom in the target 453 | prom_list = [ 454 | torch.full_like(t[..., 0], self.ignore_index) for t in proms_list 455 | ] 456 | 457 | text_prom_list = self._samplewise_merge_tensors( 458 | text_list, prom_list, sep=ignore_sep 459 | ) 460 | 461 | # Make every token earlier as it is future that is unknown 462 | # If we don't want compute loss, set all to ignored 463 | for i in range(len(text_prom_list)): 464 | if self.resp_loss_only: 465 | text_prom_list[i][:] = self.ignore_index 466 | else: 467 | text_prom_list[i] = text_prom_list[i].roll(-1, dims=0) 468 | text_prom_list[i][-1] = self.ignore_index 469 | 470 | if shift_targ_list: 471 | # Also make target earlier if in autoregressive mode 472 | targ_list = [*targ_list] 473 | for i in range(len(targ_list)): 474 | targ_list[i] = targ_list[i].roll(-1, dims=0) 475 | targ_list[i][-1] = self.stop_token 476 | 477 | y_list = self._samplewise_merge_tensors( 478 | text_prom_list, targ_list, sep=ignore_sep 479 | ) 480 | 481 | self.loss = dict( 482 | nll=F.cross_entropy( 483 | torch.cat(h_list), 484 | torch.cat(y_list), 485 | ignore_index=self.ignore_index, 486 | ) 487 | ) 488 | 489 | if return_all_resp: 490 | logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))] 491 | ret = [ 492 | Categorical(logits=hi / sampling_temperature).sample() for hi in logits 493 | ] 494 | else: 495 | logits = torch.stack([hi[-1] for hi in h_list]) 496 | ret = Categorical(logits=logits / sampling_temperature).sample() 497 | 498 | return 
ret 499 | -------------------------------------------------------------------------------- /vall_e/vall_e/nar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from .base import Base 5 | 6 | 7 | class NAR(Base): 8 | @property 9 | def n_resp_levels(self): 10 | return 7 11 | 12 | @property 13 | def casual(self): 14 | return False 15 | 16 | @property 17 | def use_stop_token(self): 18 | return False 19 | 20 | @property 21 | def norm_type(self): 22 | return "adaln" 23 | 24 | @property 25 | def resp_loss_only(self): 26 | return True 27 | 28 | def forward( 29 | self, 30 | text_list: list[Tensor], 31 | proms_list: list[Tensor], 32 | resps_list: list[Tensor], 33 | sampling_temperature: float = 0.2, 34 | ): 35 | """ 36 | Args: 37 | text_list: [t] * b 38 | proms_list: [t' l] * b, l=8 39 | resps_list: [t'' l] * b, l=1 or 8, 1 for testing and 8 for training. 40 | Returns: 41 | [t'' l], l=8 if testing. empty list will be returned during training. 42 | """ 43 | 44 | n_levels_set = {r.shape[-1] for r in resps_list} 45 | 46 | if len(n_levels_set) > 1: 47 | raise ValueError(f"Please give only one level, got {n_levels_set}.") 48 | 49 | n_levels = next(iter(n_levels_set)) 50 | 51 | device = text_list[0].device 52 | 53 | if n_levels == self.n_resp_levels + 1: 54 | assert resps_list is not None 55 | 56 | quant_levels = torch.randint(0, self.n_resp_levels, (len(resps_list),)) 57 | 58 | prev_list = [o[..., : l + 1] for o, l in zip(resps_list, quant_levels)] 59 | targ_list = [o[..., l + 1] for o, l in zip(resps_list, quant_levels)] 60 | 61 | quant_levels = quant_levels.to(device=device) 62 | 63 | _ = super().forward( 64 | text_list, 65 | proms_list, 66 | prev_list, 67 | targ_list, 68 | return_all_resp=True, 69 | shift_targ_list=False, 70 | quant_levels=quant_levels, 71 | ) 72 | 73 | # Yes, just nothing as we are training 74 | prev_list = [] 75 | else: 76 | prev_list = resps_list 77 | 78 | while True: 79 | level = prev_list[0].shape[-1] - 1 80 | 81 | if level >= self.n_resp_levels: 82 | break 83 | 84 | quant_levels = torch.full((len(text_list),), level, device=device) 85 | 86 | resp_list = super().forward( 87 | text_list, 88 | proms_list, 89 | prev_list, 90 | return_all_resp=True, 91 | shift_targ_list=False, 92 | quant_levels=quant_levels, 93 | sampling_temperature=sampling_temperature, 94 | ) 95 | 96 | prev_list = [ 97 | torch.cat([rs, r.unsqueeze(-1)], dim=-1) 98 | for rs, r in zip(prev_list, resp_list) 99 | ] 100 | 101 | return prev_list 102 | 103 | 104 | def example_usage(): 105 | from functools import partial 106 | from pathlib import Path 107 | 108 | from einops import repeat 109 | 110 | from ..emb.qnt import decode_to_file 111 | from ..utils import gather_attribute 112 | 113 | device = "cuda" 114 | 115 | resps = torch.load("data/test/test.qnt.pt")[0].to(device) 116 | num_qnts = 1024 117 | 118 | model = NAR(num_qnts).to(device) 119 | 120 | text_list = [ 121 | torch.tensor([2, 3], device=device), 122 | ] 123 | 124 | x8 = partial(repeat, pattern="t -> t l", l=8) 125 | proms_list = [ 126 | x8(torch.tensor([2, 3], device=device)), 127 | ] 128 | 129 | resps_x1_list = [ 130 | resps[:1].t().to(device), 131 | ] 132 | 133 | resps_x8_list = [ 134 | resps.t().to(device), 135 | ] 136 | 137 | codes = model( 138 | text_list, 139 | proms_list, 140 | resps_list=resps_x1_list, 141 | sampling_temperature=0.2, 142 | )[0] 143 | 144 | decode_to_file( 145 | codes, 146 | Path("data/test/test.nar.init.wav"), 147 | ) 148 | 149 | optimizer = 
torch.optim.Adam(model.parameters(), lr=1e-4) 150 | 151 | for i in range(200): 152 | optimizer.zero_grad() 153 | 154 | _ = model(text_list, proms_list, resps_list=resps_x8_list) 155 | 156 | losses = gather_attribute(model, "loss") 157 | loss = sum(losses.values()) 158 | loss.backward() 159 | 160 | optimizer.step() 161 | 162 | if i % 20 == 0: 163 | stats = {k: v.item() for k, v in losses.items()} 164 | stats["loss"] = loss.item() 165 | print(f"iter={i}, {stats}.") 166 | 167 | for i in range(1, 8): 168 | resps_list = [ 169 | resps[:i].t().to(device), 170 | ] 171 | 172 | codes = model( 173 | text_list, 174 | proms_list, 175 | resps_list=resps_list, 176 | sampling_temperature=0.2, 177 | )[0] 178 | 179 | decode_to_file( 180 | codes, 181 | Path(f"data/test/test.nar.1-{i}.wav"), 182 | ) 183 | 184 | 185 | if __name__ == "__main__": 186 | example_usage() 187 | --------------------------------------------------------------------------------
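A closing note on NAR training: `NAR.forward` implements the "sample-wise quantization level sampling" item from the README TODO list; for each sample it draws a level `l`, conditions on levels `0..l`, and predicts level `l+1`. A minimal, self-contained sketch of just that slicing, on dummy codes and with no model involved:

```
import torch

# One utterance worth of dummy EnCodec codes: (t, 8), i.e. 8 quantizer levels of 1024 codes.
resps_list = [torch.randint(0, 1024, (100, 8))]

# Pick a level per sample, feed levels 0..l as input and use level l+1 as the target,
# mirroring the training branch of NAR.forward.
quant_levels = torch.randint(0, 7, (len(resps_list),))
prev_list = [o[..., : l + 1] for o, l in zip(resps_list, quant_levels)]
targ_list = [o[..., l + 1] for o, l in zip(resps_list, quant_levels)]

print(int(quant_levels[0]), prev_list[0].shape, targ_list[0].shape)
```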