├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── config
│   ├── LibriTTS
│   │   ├── ar-quarter.yml
│   │   ├── ar.yml
│   │   ├── nar-quarter.yml
│   │   └── nar.yml
│   └── test
│       ├── ar.yml
│       └── nar.yml
├── data
│   └── test
│       ├── test.normalized.txt
│       ├── test.phn.txt
│       ├── test.qnt.pt
│       ├── test.wav
│       ├── test2.phn.txt
│       └── test2.qnt.pt
├── scripts
│   ├── plot.py
│   └── run.sh
├── setup.py
├── vall-e.png
└── vall_e
    ├── __init__.py
    ├── __main__.py
    ├── config.py
    ├── data.py
    ├── emb
    │   ├── __init__.py
    │   ├── g2p.py
    │   └── qnt.py
    ├── export.py
    ├── sampler.py
    ├── train.py
    └── vall_e
        ├── __init__.py
        ├── ar.py
        ├── base.py
        └── nar.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | /data
3 | /logs
4 | /ckpts
5 | /.cache
6 | /config
7 | /*.egg-info
8 | /vall_e/version.py
9 | /build
10 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "mini_vall_e/utils"]
2 | path = vall_e/utils
3 | url = https://github.com/enhuiz/pytorch-training-utils.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Zhe Niu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![VALL-E](vall-e.png)
2 |
3 |
4 |
5 | # VALL-E
6 |
7 | An unofficial PyTorch implementation of [VALL-E](https://valle-demo.github.io/), based on the [EnCodec](https://github.com/facebookresearch/encodec) tokenizer.
8 |
9 | [Buy Me a Coffee](https://www.buymeacoffee.com/enhuiz)
10 |
11 | ## Get Started
12 |
13 | > A toy Google Colab example: [Open in Colab](https://colab.research.google.com/drive/1wEze0kQ0gt9B3bQmmbtbSXCoCTpq5vg-?usp=sharing).
14 | > Please note that this example overfits a single utterance in `data/test` and is not usable for general synthesis.
15 | > The pretrained model is yet to come.
16 |
17 | ### Requirements
18 |
19 | Since the trainer is based on [DeepSpeed](https://github.com/microsoft/DeepSpeed#requirements), installing this package requires a GPU that DeepSpeed has been developed and tested against, as well as a pre-installed CUDA or ROCm compiler.
20 |
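If you are unsure whether your environment qualifies, DeepSpeed ships an environment report tool that you can run after installation to check which ops it can build:

```
ds_report
```
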
21 | ### Install
22 |
23 | ```
24 | pip install git+https://github.com/enhuiz/vall-e
25 | ```
26 |
27 | Or you may clone the repository with:
28 |
29 | ```
30 | git clone --recurse-submodules https://github.com/enhuiz/vall-e.git
31 | ```
32 |
33 | Note that the code is only tested under `Python 3.10.7`.
34 |
35 | ### Train
36 |
37 | 1. Put your data into a folder, e.g. `data/your_data`. Audio files should be named with the suffix `.wav` and text files with `.normalized.txt` (an example layout is sketched after these steps).
38 |
39 | 2. Quantize the data:
40 |
41 | ```
42 | python -m vall_e.emb.qnt data/your_data
43 | ```
44 |
45 | 3. Generate phonemes based on the text:
46 |
47 | ```
48 | python -m vall_e.emb.g2p data/your_data
49 | ```
50 |
51 | 4. Customize your configuration by creating `config/your_data/ar.yml` and `config/your_data/nar.yml`. Refer to the example configs in `config/test` and to `vall_e/config.py` for details. You may also choose a different model preset (`ar`, `nar`, or their `-half`/`-quarter` variants); see `vall_e/vall_e/__init__.py`.
52 |
53 | 5. Train the AR or NAR model using the following scripts:
54 |
55 | ```
56 | python -m vall_e.train yaml=config/your_data/ar_or_nar.yml
57 | ```
58 |
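For reference, here is a hypothetical `data/your_data` layout (file names are only illustrative). It assumes the default `spkr_name_getter` (`lambda p: p.parts[-2]`, i.e. the speaker name is the utterance's parent folder; the LibriTTS configs override this with `p.parts[-3]` to match its nested layout). The `.qnt.pt` and `.phn.txt` files are produced by steps 2 and 3, and each speaker needs at least two utterances so that a prompt can be drawn from a different utterance of the same speaker:

```
data/your_data/
├── speaker1
│   ├── utt1.wav
│   ├── utt1.normalized.txt
│   ├── utt1.phn.txt
│   ├── utt1.qnt.pt
│   ├── utt2.wav
│   └── ...
└── speaker2
    └── ...
```
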
59 | You may quit training at any time by typing `quit` in your CLI; the latest checkpoint will be saved automatically.
60 |
61 | ### Export
62 |
63 | Both trained models need to be exported before synthesis. To export either of them, run:
64 |
65 | ```
66 | python -m vall_e.export zoo/ar_or_nar.pt yaml=config/your_data/ar_or_nar.yml
67 | ```
68 |
69 | This will export the latest checkpoint.
70 |
71 | ### Synthesis
72 |
73 | ```
74 | python -m vall_e <text> <ref_path> <out_path> --ar-ckpt zoo/ar.pt --nar-ckpt zoo/nar.pt
75 | ```
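
For example, using the toy utterance under `data/test` as the speaker reference (the output path is arbitrary):

```
python -m vall_e "hello world" data/test/test.wav out.wav --ar-ckpt zoo/ar.pt --nar-ckpt zoo/nar.pt
```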
76 |
77 | ## TODO
78 |
79 | - [x] AR model for the first quantizer
80 | - [x] Audio decoding from tokens
81 | - [x] NAR model for the remaining quantizers
82 | - [x] Trainers for both models
83 | - [x] Implement AdaLN for NAR model.
84 | - [x] Sample-wise quantization level sampling for NAR training.
85 | - [ ] Pre-trained checkpoint and demos on LibriTTS
86 | - [x] Synthesis CLI
87 |
88 | ## Notice
89 |
90 | - [EnCodec](https://github.com/facebookresearch/encodec) is licensed under CC-BY-NC 4.0. If you use this code to quantize or decode audio with EnCodec, please adhere to the terms of its license.
91 |
92 | ## Citations
93 |
94 | ```bibtex
95 | @article{wang2023neural,
96 | title={Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers},
97 | author={Wang, Chengyi and Chen, Sanyuan and Wu, Yu and Zhang, Ziqiang and Zhou, Long and Liu, Shujie and Chen, Zhuo and Liu, Yanqing and Wang, Huaming and Li, Jinyu and others},
98 | journal={arXiv preprint arXiv:2301.02111},
99 | year={2023}
100 | }
101 | ```
102 |
103 | ```bibtex
104 | @article{defossez2022highfi,
105 | title={High Fidelity Neural Audio Compression},
106 | author={Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi},
107 | journal={arXiv preprint arXiv:2210.13438},
108 | year={2022}
109 | }
110 | ```
111 |
--------------------------------------------------------------------------------
/config/LibriTTS/ar-quarter.yml:
--------------------------------------------------------------------------------
1 | data_dirs: [data/LibriTTS/]
2 | spkr_name_getter: "lambda p: p.parts[-3]"
3 |
4 | model: ar-quarter
5 | batch_size: 48
6 | eval_batch_size: 48
7 | eval_every: 10_000
8 |
--------------------------------------------------------------------------------
/config/LibriTTS/ar.yml:
--------------------------------------------------------------------------------
1 | data_dirs: [data/LibriTTS/]
2 | spkr_name_getter: "lambda p: p.parts[-3]"
3 |
4 | model: ar
5 | batch_size: 24
6 | eval_batch_size: 24
7 | eval_every: 10_000
8 |
9 | sampling_temperature: 1.0
10 |
--------------------------------------------------------------------------------
/config/LibriTTS/nar-quarter.yml:
--------------------------------------------------------------------------------
1 | data_dirs: [data/LibriTTS/]
2 | spkr_name_getter: "lambda p: p.parts[-3]"
3 |
4 | model: nar-quarter
5 | batch_size: 48
6 | eval_batch_size: 48
7 |
--------------------------------------------------------------------------------
/config/LibriTTS/nar.yml:
--------------------------------------------------------------------------------
1 | data_dirs: [data/LibriTTS/]
2 | spkr_name_getter: "lambda p: p.parts[-3]"
3 |
4 | model: nar
5 | batch_size: 24
6 | eval_batch_size: 24
7 | eval_every: 1_000
8 |
9 | sampling_temperature: 0.2
10 |
--------------------------------------------------------------------------------
/config/test/ar.yml:
--------------------------------------------------------------------------------
1 | data_dirs: [data/test]
2 |
3 | model: ar-quarter
4 | batch_size: 1
5 | eval_batch_size: 1
6 | save_ckpt_every: 500
7 | eval_every: 500
8 | max_iter: 1000
9 |
--------------------------------------------------------------------------------
/config/test/nar.yml:
--------------------------------------------------------------------------------
1 | data_dirs: [data/test]
2 |
3 | model: nar-quarter
4 | batch_size: 1
5 | eval_batch_size: 1
6 | save_ckpt_every: 500
7 | eval_every: 500
8 | max_iter: 1000
9 |
--------------------------------------------------------------------------------
/data/test/test.normalized.txt:
--------------------------------------------------------------------------------
1 | hello world
2 |
--------------------------------------------------------------------------------
/data/test/test.phn.txt:
--------------------------------------------------------------------------------
1 | HH AH0 L OW1 _ W ER1 L D
2 |
--------------------------------------------------------------------------------
/data/test/test.qnt.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/enhuiz/vall-e/3476d393d2133fa9b50d5ad999ca13b95fc22060/data/test/test.qnt.pt
--------------------------------------------------------------------------------
/data/test/test.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/enhuiz/vall-e/3476d393d2133fa9b50d5ad999ca13b95fc22060/data/test/test.wav
--------------------------------------------------------------------------------
/data/test/test2.phn.txt:
--------------------------------------------------------------------------------
1 | test.phn.txt
--------------------------------------------------------------------------------
/data/test/test2.qnt.pt:
--------------------------------------------------------------------------------
1 | test.qnt.pt
--------------------------------------------------------------------------------
/scripts/plot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
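# Plot one or more logged stats (e.g. nll) against global_step, parsed from the
# JSON lines that the trainer writes to its log files.
# Hypothetical invocation: python scripts/plot.py nll --log-dir logs
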
3 | import argparse
4 | import json
5 | import re
6 | from pathlib import Path
7 |
8 | import matplotlib.pyplot as plt
9 | import pandas as pd
10 |
11 |
12 | def plot(paths, args):
13 | dfs = []
14 |
15 | for path in paths:
16 | with open(path, "r") as f:
17 | text = f.read()
18 |
19 | rows = []
20 |
21 | pattern = r"(\{.+?\})"
22 |
23 | for row in re.findall(pattern, text, re.DOTALL):
24 | try:
25 | row = json.loads(row)
26 | except Exception as e:
27 | continue
28 |
29 | if "global_step" in row:
30 | rows.append(row)
31 |
32 | df = pd.DataFrame(rows)
33 |
34 | if "name" in df:
35 | df["name"] = df["name"].fillna("train")
36 | else:
37 | df["name"] = "train"
38 |
39 | df["group"] = str(path.parents[args.group_level])
40 | df["group"] = df["group"] + "/" + df["name"]
41 |
42 | dfs.append(df)
43 |
44 | df = pd.concat(dfs)
45 |
46 |     if args.max_x is not None:
47 |         df = df[df["global_step"] < args.max_x]
48 |
49 | for gtag, gdf in sorted(
50 | df.groupby("group"),
51 | key=lambda p: (p[0].split("/")[-1], p[0]),
52 | ):
53 | for y in args.ys:
54 | gdf = gdf.sort_values("global_step")
55 |
56 | if gdf[y].isna().all():
57 | continue
58 |
59 | if args.max_y is not None:
60 | gdf = gdf[gdf[y] < args.max_y]
61 |
62 | gdf[y] = gdf[y].ewm(10).mean()
63 |
64 | gdf.plot(
65 | x="global_step",
66 | y=y,
67 | label=f"{gtag}/{y}",
68 | ax=plt.gca(),
69 | marker="x" if len(gdf) < 100 else None,
70 | alpha=0.7,
71 | )
72 |
73 | plt.gca().legend(
74 | loc="center left",
75 | fancybox=True,
76 | shadow=True,
77 | bbox_to_anchor=(1.04, 0.5),
78 | )
79 |
80 |
81 | def main():
82 | parser = argparse.ArgumentParser()
83 | parser.add_argument("ys", nargs="+")
84 | parser.add_argument("--log-dir", default="logs", type=Path)
85 | parser.add_argument("--out-dir", default="logs", type=Path)
86 | parser.add_argument("--filename", default="log.txt")
87 | parser.add_argument("--max-x", type=float, default=float("inf"))
88 | parser.add_argument("--max-y", type=float, default=float("inf"))
89 |     parser.add_argument("--group-level", default=1, type=int)
90 | parser.add_argument("--filter", default=None)
91 | args = parser.parse_args()
92 |
93 | paths = args.log_dir.rglob(f"**/{args.filename}")
94 |
95 | if args.filter:
96 | paths = filter(lambda p: re.match(".*" + args.filter + ".*", str(p)), paths)
97 |
98 | plot(paths, args)
99 |
100 | name = "-".join(args.ys)
101 | out_path = (args.out_dir / name).with_suffix(".png")
102 | plt.savefig(out_path, bbox_inches="tight")
103 |
104 |
105 | if __name__ == "__main__":
106 | main()
107 |
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
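# Keep re-running the given command until it exits successfully,
# e.g.: ./scripts/run.sh python -m vall_e.train yaml=config/your_data/ar.yml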
3 | until "$@"; do echo retrying; done
4 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from pathlib import Path
3 | from datetime import datetime
4 | from setuptools import setup, find_packages
5 |
6 |
7 | def shell(*args):
8 | out = subprocess.check_output(args)
9 | return out.decode("ascii").strip()
10 |
11 |
12 | def write_version(version_core, pre_release=True):
13 | if pre_release:
14 | time = shell("git", "log", "-1", "--format=%cd", "--date=iso")
15 | time = datetime.strptime(time, "%Y-%m-%d %H:%M:%S %z")
16 | time = time.strftime("%Y%m%d%H%M%S")
17 | version = f"{version_core}-dev{time}"
18 | else:
19 | version = version_core
20 |
21 | with open(Path("vall_e", "version.py"), "w") as f:
22 | f.write('__version__ = "{}"\n'.format(version))
23 |
24 | return version
25 |
26 |
27 | with open("README.md", "r") as f:
28 | long_description = f.read()
29 |
30 | setup(
31 | name="vall-e",
32 | python_requires=">=3.10.0",
33 | version=write_version("0.0.1"),
34 | description="An unofficial toy implementation of the audio LM VALL-E",
35 | author="enhuiz",
36 | author_email="niuzhe.nz@outlook.com",
37 | long_description=long_description,
38 | long_description_content_type="text/markdown",
39 | packages=find_packages(),
40 | install_requires=[
41 | "coloredlogs>=15.0.1",
42 | "deepspeed>=0.7.7",
43 | "diskcache>=5.4.0",
44 | "einops>=0.6.0",
45 | "encodec>=0.1.1",
46 | "g2p_en>=2.1.0",
47 | "humanize>=4.4.0",
48 | "matplotlib>=3.6.0",
49 | "numpy>=1.23.3",
50 | "omegaconf>=2.2.3",
51 | "openTSNE>=0.6.2",
52 | "pandas>=1.5.0",
53 | "soundfile>=0.11.0",
54 | "torch>=1.13.0",
55 | "torchaudio>=0.13.0",
56 | "tqdm>=4.64.1",
57 | ],
58 | url="https://github.com/enhuiz/vall-e",
59 | )
60 |
--------------------------------------------------------------------------------
/vall-e.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/enhuiz/vall-e/3476d393d2133fa9b50d5ad999ca13b95fc22060/vall-e.png
--------------------------------------------------------------------------------
/vall_e/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/enhuiz/vall-e/3476d393d2133fa9b50d5ad999ca13b95fc22060/vall_e/__init__.py
--------------------------------------------------------------------------------
/vall_e/__main__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import torch
5 | from einops import rearrange
6 |
7 | from .emb import g2p, qnt
8 | from .utils import to_device
9 |
10 |
11 | def main():
12 | parser = argparse.ArgumentParser("VALL-E TTS")
13 | parser.add_argument("text")
14 | parser.add_argument("reference", type=Path)
15 | parser.add_argument("out_path", type=Path)
16 | parser.add_argument("--ar-ckpt", type=Path, default="zoo/ar.pt")
17 | parser.add_argument("--nar-ckpt", type=Path, default="zoo/nar.pt")
18 | parser.add_argument("--device", default="cuda")
19 | args = parser.parse_args()
20 |
21 | ar = torch.load(args.ar_ckpt).to(args.device)
22 | nar = torch.load(args.nar_ckpt).to(args.device)
23 |
24 | symmap = ar.phone_symmap
25 |
26 | proms = qnt.encode_from_file(args.reference)
27 | proms = rearrange(proms, "1 l t -> t l")
28 |
29 | phns = torch.tensor([symmap[p] for p in g2p.encode(args.text)])
30 |
31 | proms = to_device(proms, args.device)
32 | phns = to_device(phns, args.device)
33 |
34 | resp_list = ar(text_list=[phns], proms_list=[proms])
35 | resps_list = [r.unsqueeze(-1) for r in resp_list]
36 |
37 | resps_list = nar(text_list=[phns], proms_list=[proms], resps_list=resps_list)
38 | qnt.decode_to_file(resps=resps_list[0], path=args.out_path)
39 | print(args.out_path, "saved.")
40 |
41 |
42 | if __name__ == "__main__":
43 | main()
44 |
--------------------------------------------------------------------------------
/vall_e/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from functools import cached_property
3 | from pathlib import Path
4 |
5 | import diskcache
6 |
7 | from .utils import Config as ConfigBase
8 |
9 |
10 | @dataclass(frozen=True)
11 | class Config(ConfigBase):
12 | data_root: Path = Path("data")
13 | data_dirs: list[Path] = field(default_factory=lambda: [])
14 |
15 | @property
16 | def sample_rate(self):
17 | return 24_000
18 |
19 | p_additional_prompt: float = 0.8
20 | max_prompts: int = 3
21 |
22 | max_num_val: int = 20
23 | max_val_ar_steps: int = 300
24 |
25 | token_dim: int = 256
26 | num_tokens: int = 1024
27 |
28 | nj: int = 8
29 | batch_size: int = 32
30 | eval_batch_size: int = 32
31 | warmup_min_lr: float = 1e-6
32 | warmup_max_lr: float = 2e-4
33 | dis_warmup_max_lr: float = 4e-4
34 | warmup_num_steps: int = 1_000
35 | max_iter: int = 1_000_000
36 | gradient_clipping: float = 100
37 | eval_every: int = 2_000
38 | save_ckpt_every: int = 2_000
39 |
40 | model: str = "ar-quarter"
41 | spkr_name_getter: str = "lambda p: p.parts[-2]"
42 |
43 | min_phones: int = 10
44 | max_phones: int = 50
45 |
46 | use_fp16: bool = True
47 | gradient_accumulation_steps: int = 1
48 | sampling_temperature: float = 1.0
49 |
50 | cache_dataloader: bool = False
51 |
52 | @cached_property
53 | def get_spkr(self):
54 | return eval(self.spkr_name_getter)
55 |
56 | @property
57 | def fp16_cfg(self):
58 | return {
59 | "enabled": self.use_fp16,
60 | }
61 |
62 | @property
63 | def ds_cfg(self):
64 | return {
65 | "train_micro_batch_size_per_gpu": self.batch_size,
66 | "gradient_accumulation_steps": self.gradient_accumulation_steps,
67 | "optimizer": {
68 | "type": "Adam",
69 | "lr": self.warmup_min_lr,
70 | },
71 | "scheduler": {
72 | "type": "WarmupDecayLR",
73 | "params": {
74 | "warmup_min_lr": self.warmup_min_lr,
75 | "warmup_max_lr": self.warmup_max_lr,
76 | "warmup_num_steps": self.warmup_num_steps,
77 | "total_num_steps": self.max_iter,
78 | "warmup_type": "linear",
79 | },
80 | },
81 | "gradient_clipping": self.gradient_clipping,
82 | "fp16": self.fp16_cfg,
83 | }
84 |
85 | @property
86 | def cache_dir(self):
87 | return ".cache" / self.relpath
88 |
89 | @cached_property
90 | def diskcache(self):
91 | if self.cache_dataloader:
92 | return diskcache.Cache(self.cache_dir).memoize
93 | return lambda: lambda x: x
94 |
95 |
96 | cfg = Config.from_cli()
97 |
98 | if __name__ == "__main__":
99 | print(cfg)
100 |
--------------------------------------------------------------------------------
/vall_e/data.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import random
4 | from collections import defaultdict
5 | from functools import cache, cached_property
6 | from itertools import groupby, zip_longest
7 | from typing import Any
8 |
9 | import numpy as np
10 | import torch
11 | from torch import Tensor
12 | from torch.utils.data import DataLoader, Dataset
13 | from tqdm import tqdm
14 |
15 | from .config import cfg
16 | from .sampler import Sampler
17 |
18 | torch.multiprocessing.set_sharing_strategy("file_system")
19 |
20 | _logger = logging.getLogger(__name__)
21 |
22 |
23 | def _replace_file_extension(path, suffix):
24 | return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
25 |
26 |
27 | def _get_quant_path(path):
28 | return _replace_file_extension(path, ".qnt.pt")
29 |
30 |
31 | def _load_quants(path) -> Tensor:
32 | """
33 | Returns:
34 | quants: (t q)
35 | """
36 | path = _get_quant_path(path)
37 | return torch.load(path)[0].t()
38 |
39 |
40 | @cache
41 | def _get_phones(path):
42 | path = _replace_file_extension(path, ".phn.txt")
43 | with open(path, "r", encoding="utf8") as f:
44 | content = f.read()
45 |     return ["<s>"] + content.split() + ["</s>"]
46 |
47 |
48 | def _interleaved_reorder(l, fn):
49 | groups = defaultdict(list)
50 | for e in l:
51 | groups[fn(e)].append(e)
52 | groups = {k: groups[k] for k in sorted(groups)}
53 | for interleaved in zip_longest(*groups.values()):
54 | for value in interleaved:
55 | if value is not None:
56 | yield value
57 |
58 |
59 | @cache
60 | def _validate(path, min_phones, max_phones):
61 | phones = _get_phones(path)
62 | unique_phones = list(set(phones))
63 | if len(unique_phones) == 0:
64 | return False
65 | if len(unique_phones) == 1 and unique_phones[0] == "_":
66 | return False
67 | if len(phones) < min_phones:
68 | return False
69 | if len(phones) > max_phones:
70 | return False
71 | return True
72 |
73 |
74 | class VALLEDatset(Dataset):
75 | def __init__(
76 | self,
77 | paths,
78 | phone_symmap=None,
79 | spkr_symmap=None,
80 | min_phones=cfg.min_phones,
81 | max_phones=cfg.max_phones,
82 | training=False,
83 | extra_paths_by_spkr_name: dict[str, list] = {},
84 | ):
85 | super().__init__()
86 | self._head = None
87 | self.min_phones = min_phones
88 | self.max_phones = max_phones
89 |
90 | self.paths = [
91 | path for path in paths if _validate(path, self.min_phones, self.max_phones)
92 | ]
93 |
94 | self.spkr_symmap = spkr_symmap or self._get_spkr_symmap()
95 | self.phone_symmap = phone_symmap or self._get_phone_symmap()
96 | self.training = training
97 |
98 | self.paths_by_spkr_name = self._get_paths_by_spkr_name(extra_paths_by_spkr_name)
99 |
100 | self.paths = [
101 | p for p in self.paths if len(self.paths_by_spkr_name[cfg.get_spkr(p)]) > 1
102 | ]
103 |
104 | if len(self.paths) == 0 and training:
105 | raise ValueError("No valid path is found for training.")
106 |
107 | if training:
108 | self.sampler = Sampler(self.paths, [cfg.get_spkr])
109 | else:
110 | self.sampler = None
111 |
112 | def _get_paths_by_spkr_name(self, extra_paths_by_spkr_name: dict[str, list]):
113 | ret = defaultdict(list)
114 | for path in self.paths:
115 | if _get_quant_path(path).exists():
116 | ret[cfg.get_spkr(path)].append(path)
117 | for k, v in extra_paths_by_spkr_name.items():
118 | ret[k].extend(v)
119 | return {**ret}
120 |
121 | @cached_property
122 | def phones(self):
123 | return sorted(set().union(*[_get_phones(path) for path in self.paths]))
124 |
125 | def _get_phone_symmap(self):
126 | # Note that we use phone symmap starting from 1 so that we can safely pad 0.
127 | return {s: i for i, s in enumerate(self.phones, 1)}
128 |
129 | @cached_property
130 | def spkrs(self):
131 | return sorted({cfg.get_spkr(path) for path in self.paths})
132 |
133 | def _get_spkr_symmap(self):
134 | return {s: i for i, s in enumerate(self.spkrs)}
135 |
136 | def sample_prompts(self, spkr_name, ignore):
137 | prom_list = []
138 |
139 | choices = set(self.paths_by_spkr_name[spkr_name]) - {ignore}
140 | choices = [*choices]
141 |
142 | if len(choices) == 0:
143 | raise ValueError(
144 | f"Failed to find another different utterance for {spkr_name}."
145 | )
146 |
147 | for _ in range(cfg.max_prompts):
148 | path = random.choice(choices)
149 | prom_list.append(_load_quants(path))
150 | if random.random() > cfg.p_additional_prompt:
151 | break
152 |
153 | prom = torch.cat(prom_list)
154 |
155 | return prom
156 |
157 | def __getitem__(self, index):
158 | if self.training:
159 | assert self.sampler is not None
160 | path = self.sampler.sample()
161 | else:
162 | path = self.paths[index]
163 |
164 | spkr_name = cfg.get_spkr(path)
165 | text = torch.tensor([*map(self.phone_symmap.get, _get_phones(path))])
166 | proms = self.sample_prompts(spkr_name, ignore=path)
167 | resps = _load_quants(path)
168 | resp = resps[..., 0]
169 |
170 | return dict(
171 | path=path,
172 | spkr_name=spkr_name,
173 | text=text,
174 | proms=proms,
175 | resps=resps,
176 | resp=resp,
177 | )
178 |
179 | def head_(self, n):
180 | self._head = n
181 |
182 | def training_(self, value):
183 | self.training = value
184 |
185 | def interleaved_reorder_(self, fn):
186 | self.paths = [*_interleaved_reorder(self.paths, fn)]
187 |
188 | def __len__(self):
189 | return min(len(self.paths), self._head or len(self.paths))
190 |
191 |
192 | def collate_fn(samples: list[dict]):
193 | batch: dict[str, Any] = {k: [s[k] for s in samples] for k in samples[0]}
194 | return batch
195 |
196 |
197 | def _seed_worker(worker_id):
198 | worker_seed = torch.initial_seed() % 2**32
199 | np.random.seed(worker_seed)
200 | random.seed(worker_seed)
201 |
202 |
203 | def _create_dataloader(dataset, training):
204 | return DataLoader(
205 | dataset=dataset,
206 | batch_size=cfg.batch_size if training else cfg.eval_batch_size,
207 | shuffle=training,
208 | drop_last=training,
209 | num_workers=cfg.nj,
210 | collate_fn=collate_fn,
211 | persistent_workers=True,
212 | worker_init_fn=_seed_worker,
213 | )
214 |
215 |
216 | def _load_train_val_paths():
217 | paths = []
218 | train_paths = []
219 | val_paths = []
220 |
221 | for data_dir in cfg.data_dirs:
222 | paths.extend(tqdm(data_dir.rglob("*.qnt.pt")))
223 |
224 | if len(paths) == 0:
225 | raise RuntimeError(f"Failed to find any .qnt.pt file in {cfg.data_dirs}.")
226 |
227 | pairs = sorted([(cfg.get_spkr(p), p) for p in paths])
228 | del paths
229 |
230 | for _, group in groupby(pairs, lambda pair: pair[0]):
231 | paths = sorted([p for _, p in group])
232 | random.seed(0)
233 | random.shuffle(paths)
234 | n = round(len(paths) * 0.95)
235 | train_paths.extend(paths[:n])
236 | val_paths.extend(paths[n:])
237 |
238 | train_paths, val_paths = map(sorted, [train_paths, val_paths])
239 |
240 | return train_paths, val_paths
241 |
242 |
243 | @cfg.diskcache()
244 | def create_datasets():
245 | train_paths, val_paths = _load_train_val_paths()
246 |
247 | train_dataset = VALLEDatset(
248 | train_paths,
249 | training=True,
250 | )
251 |
252 | val_dataset = VALLEDatset(
253 | val_paths,
254 | train_dataset.phone_symmap,
255 | train_dataset.spkr_symmap,
256 | extra_paths_by_spkr_name=train_dataset.paths_by_spkr_name,
257 | )
258 |
259 | val_dataset.interleaved_reorder_(cfg.get_spkr)
260 | val_dataset.head_(cfg.max_num_val)
261 |
262 | return train_dataset, val_dataset
263 |
264 |
265 | def create_train_val_dataloader():
266 | train_dataset, val_dataset = create_datasets()
267 |
268 | train_dl = _create_dataloader(train_dataset, training=True)
269 | val_dl = _create_dataloader(val_dataset, training=False)
270 |
271 | _logger.info(str(train_dataset.phone_symmap))
272 | _logger.info(str(train_dataset.spkr_symmap))
273 |
274 | _logger.info(f"#samples (train): {len(train_dataset)}.")
275 | _logger.info(f"#samples (val): {len(val_dataset)}.")
276 |
277 | subtrain_dataset = copy.deepcopy(train_dataset)
278 | subtrain_dataset.interleaved_reorder_(cfg.get_spkr)
279 | subtrain_dataset.head_(cfg.max_num_val)
280 | subtrain_dataset.training_(False)
281 | subtrain_dl = _create_dataloader(subtrain_dataset, training=False)
282 | assert isinstance(subtrain_dl.dataset, VALLEDatset)
283 |
284 | return train_dl, subtrain_dl, val_dl
285 |
286 |
287 | if __name__ == "__main__":
288 | train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
289 | sample = train_dl.dataset[0]
290 | print(sample)
291 |
--------------------------------------------------------------------------------
/vall_e/emb/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/enhuiz/vall-e/3476d393d2133fa9b50d5ad999ca13b95fc22060/vall_e/emb/__init__.py
--------------------------------------------------------------------------------
/vall_e/emb/g2p.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import string
4 | from functools import cache
5 | from pathlib import Path
6 |
7 | import torch
8 | from g2p_en import G2p
9 | from tqdm import tqdm
10 |
11 |
12 | @cache
13 | def _get_model():
14 | return G2p()
15 |
16 |
17 | @cache
18 | def _get_graphs(path):
19 | with open(path, "r") as f:
20 | graphs = f.read()
21 | return graphs
22 |
23 |
24 | def encode(graphs: str) -> list[str]:
25 | g2p = _get_model()
26 | phones = g2p(graphs)
27 | ignored = {" ", *string.punctuation}
28 | return ["_" if p in ignored else p for p in phones]
29 |
30 |
31 | @torch.no_grad()
32 | def main():
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument("folder", type=Path)
35 | parser.add_argument("--suffix", type=str, default=".normalized.txt")
36 | args = parser.parse_args()
37 |
38 | paths = list(args.folder.rglob(f"*{args.suffix}"))
39 | random.shuffle(paths)
40 |
41 | for path in tqdm(paths):
42 | phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt")
43 | if phone_path.exists():
44 | continue
45 | graphs = _get_graphs(path)
46 | phones = encode(graphs)
47 | with open(phone_path, "w") as f:
48 | f.write(" ".join(phones))
49 |
50 |
51 | if __name__ == "__main__":
52 | main()
53 |
--------------------------------------------------------------------------------
/vall_e/emb/qnt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | from functools import cache
4 | from pathlib import Path
5 |
6 | import soundfile
7 | import torch
8 | import torchaudio
9 | from einops import rearrange
10 | from encodec import EncodecModel
11 | from encodec.utils import convert_audio
12 | from torch import Tensor
13 | from tqdm import tqdm
14 |
15 | from ..config import cfg
16 |
17 |
18 | @cache
19 | def _load_model(device="cuda"):
20 | # Instantiate a pretrained EnCodec model
21 | assert cfg.sample_rate == 24_000
22 | model = EncodecModel.encodec_model_24khz()
23 | model.set_target_bandwidth(6.0)
24 | model.to(device)
25 | return model
26 |
27 |
28 | def unload_model():
29 | return _load_model.cache_clear()
30 |
31 |
32 | @torch.inference_mode()
33 | def decode(codes: Tensor, device="cuda"):
34 | """
35 | Args:
36 | codes: (b q t)
37 | """
38 | assert codes.dim() == 3
39 | model = _load_model(device)
40 | return model.decode([(codes, None)]), model.sample_rate
41 |
42 |
43 | def decode_to_file(resps: Tensor, path: Path):
44 | assert resps.dim() == 2, f"Require shape (t q), but got {resps.shape}."
45 | resps = rearrange(resps, "t q -> 1 q t")
46 | wavs, sr = decode(resps)
47 | soundfile.write(str(path), wavs.cpu()[0, 0], sr)
48 |
49 |
50 | def _replace_file_extension(path, suffix):
51 | return (path.parent / path.name.split(".")[0]).with_suffix(suffix)
52 |
53 |
54 | @torch.inference_mode()
55 | def encode(wav: Tensor, sr: int, device="cuda"):
56 | """
57 | Args:
58 | wav: (t)
59 | sr: int
60 | """
61 | model = _load_model(device)
62 | wav = wav.unsqueeze(0)
63 | wav = convert_audio(wav, sr, model.sample_rate, model.channels)
64 | wav = wav.to(device)
65 | encoded_frames = model.encode(wav)
66 | qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b q t)
67 | return qnt
68 |
69 |
70 | def encode_from_file(path, device="cuda"):
71 | wav, sr = torchaudio.load(str(path))
72 | if wav.shape[0] == 2:
73 | wav = wav[:1]
74 | return encode(wav, sr, device)
75 |
76 |
77 | def main():
78 | parser = argparse.ArgumentParser()
79 | parser.add_argument("folder", type=Path)
80 | parser.add_argument("--suffix", default=".wav")
81 | args = parser.parse_args()
82 |
83 | paths = [*args.folder.rglob(f"*{args.suffix}")]
84 | random.shuffle(paths)
85 |
86 | for path in tqdm(paths):
87 | out_path = _replace_file_extension(path, ".qnt.pt")
88 | if out_path.exists():
89 | continue
90 | qnt = encode_from_file(path)
91 | torch.save(qnt.cpu(), out_path)
92 |
93 |
94 | if __name__ == "__main__":
95 | main()
96 |
--------------------------------------------------------------------------------
/vall_e/export.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 |
5 | from .data import VALLEDatset, create_train_val_dataloader
6 | from .train import load_engines
7 |
8 |
9 | def main():
10 | parser = argparse.ArgumentParser("Save trained model to path.")
11 | parser.add_argument("path")
12 | args = parser.parse_args()
13 |
14 | engine = load_engines()
15 | model = engine["model"].module.cpu()
16 | train_dl, *_ = create_train_val_dataloader()
17 | assert isinstance(train_dl.dataset, VALLEDatset)
18 | model.phone_symmap = train_dl.dataset.phone_symmap
19 | model.spkr_symmap = train_dl.dataset.spkr_symmap
20 | torch.save(model, args.path)
21 | print(args.path, "saved.")
22 |
23 |
24 | if __name__ == "__main__":
25 | main()
26 |
--------------------------------------------------------------------------------
/vall_e/sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | A sampler that balances data by key_fns.
3 |
4 | MIT License
5 |
6 | Copyright (c) 2023 Zhe Niu
7 |
8 | niuzhe.nz@outlook.com
9 | """
10 |
11 | import random
12 |
13 |
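# Example of how this is used in vall_e/data.py (balance path sampling by speaker):
#
#   sampler = Sampler(paths, [cfg.get_spkr])
#   path = sampler.sample()
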
14 | class Sampler:
15 | def __init__(self, l, key_fns):
16 | self.tree = self._build(l, key_fns)
17 |
18 |     def _build(self, l, key_fns) -> dict | list:
19 | if not key_fns:
20 | return l
21 |
22 | tree = {}
23 |
24 | key_fn, *key_fns = key_fns
25 |
26 | for x in l:
27 | k = key_fn(x)
28 |
29 | if k in tree:
30 | tree[k].append(x)
31 | else:
32 | tree[k] = [x]
33 |
34 | for k in tree:
35 | tree[k] = self._build(tree[k], key_fns)
36 |
37 | return tree
38 |
39 | def _sample(self, tree: dict | list):
40 | if isinstance(tree, list):
41 | ret = random.choice(tree)
42 | else:
43 | key = random.choice([*tree.keys()])
44 | ret = self._sample(tree[key])
45 | return ret
46 |
47 | def sample(self):
48 | return self._sample(self.tree)
49 |
--------------------------------------------------------------------------------
/vall_e/train.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from collections import defaultdict
4 |
5 | import torch
6 | from tqdm import tqdm
7 |
8 | from .config import cfg
9 | from .data import create_train_val_dataloader
10 | from .emb import qnt
11 | from .utils import setup_logging, to_device, trainer
12 | from .vall_e import get_model
13 |
14 | _logger = logging.getLogger(__name__)
15 |
16 |
17 | def load_engines():
18 | model = get_model(cfg.model)
19 |
20 | engines = dict(
21 | model=trainer.Engine(
22 | model=model,
23 | config=cfg.ds_cfg,
24 | ),
25 | )
26 |
27 | return trainer.load_engines(engines, cfg)
28 |
29 |
30 | def main():
31 | setup_logging(cfg.log_dir)
32 |
33 | train_dl, subtrain_dl, val_dl = create_train_val_dataloader()
34 |
35 | def train_feeder(engines, batch, name):
36 | model = engines["model"]
37 |
38 | if cfg.model.startswith("ar"):
39 | _ = model(
40 | text_list=batch["text"],
41 | proms_list=batch["proms"],
42 | resp_list=batch["resp"],
43 | )
44 | elif cfg.model.startswith("nar"):
45 | _ = model(
46 | text_list=batch["text"],
47 | proms_list=batch["proms"],
48 | resps_list=batch["resps"],
49 | )
50 | else:
51 | raise NotImplementedError(cfg.model)
52 |
53 | losses = model.gather_attribute("loss")
54 |
55 | loss = torch.stack([*losses.values()]).sum()
56 |
57 | stats = {}
58 | stats |= {k: v.item() for k, v in losses.items()}
59 | stats |= engines.gather_attribute("scalar")
60 |
61 | return loss, stats
62 |
63 | @torch.inference_mode()
64 | def run_eval(engines, name, dl):
65 | log_dir = cfg.log_dir / str(engines.global_step) / name
66 |
67 | model = engines["model"]
68 | log_dir = cfg.log_dir / str(engines.global_step) / name
69 | stats = defaultdict(list)
70 | for batch in tqdm(dl):
71 | batch: dict = to_device(batch, cfg.device)
72 |
73 | if cfg.model.startswith("ar"):
74 | resp_list = model(
75 | text_list=batch["text"],
76 | proms_list=batch["proms"],
77 | max_steps=cfg.max_val_ar_steps,
78 | sampling_temperature=cfg.sampling_temperature,
79 | )
80 | resps_list = [r.unsqueeze(-1) for r in resp_list]
81 | elif cfg.model.startswith("nar"):
82 | resps_list = model(
83 | text_list=batch["text"],
84 | proms_list=batch["proms"],
85 | resps_list=[r.unsqueeze(-1) for r in batch["resp"]],
86 | sampling_temperature=cfg.sampling_temperature,
87 | )
88 | else:
89 | raise NotImplementedError(cfg.model)
90 |
91 | losses = model.gather_attribute("loss")
92 | batch_stats = {k: v.item() for k, v in losses.items()}
93 | for k, v in batch_stats.items():
94 | stats[k].append(v)
95 |
96 | for path, ref, hyp in zip(batch["path"], batch["resps"], resps_list):
97 | relpath = path.relative_to(cfg.data_root)
98 | hyp_path = (log_dir / "hyp" / relpath).with_suffix(".wav")
99 | ref_path = (log_dir / "ref" / relpath).with_suffix(".wav")
100 | hyp_path.parent.mkdir(parents=True, exist_ok=True)
101 | ref_path.parent.mkdir(parents=True, exist_ok=True)
102 | qnt.decode_to_file(ref, ref_path)
103 | if len(hyp) > 0:
104 | qnt.decode_to_file(hyp, hyp_path)
105 |
106 | qnt.unload_model()
107 |
108 | stats = {k: sum(v) / len(v) for k, v in stats.items()}
109 | stats["global_step"] = engines.global_step
110 | stats["name"] = name
111 | _logger.info(f"Eval: {stats}.")
112 |
113 | _logger.info(f"{json.dumps(stats)}.")
114 |
115 | def eval_fn(engines):
116 | run_eval(engines, "subtrain", subtrain_dl)
117 | run_eval(engines, "val", val_dl)
118 |
119 | trainer.train(
120 | engines_loader=load_engines,
121 | train_dl=train_dl,
122 | train_feeder=train_feeder,
123 | eval_fn=eval_fn,
124 | )
125 |
126 |
127 | if __name__ == "__main__":
128 | main()
129 |
--------------------------------------------------------------------------------
/vall_e/vall_e/__init__.py:
--------------------------------------------------------------------------------
1 | from ..config import cfg
2 | from .ar import AR
3 | from .nar import NAR
4 |
5 |
6 | def get_model(name):
7 | name = name.lower()
8 |
9 | if name.startswith("ar"):
10 | Model = AR
11 | elif name.startswith("nar"):
12 | Model = NAR
13 | else:
14 | raise ValueError("Model name should start with AR or NAR.")
15 |
16 | if "-quarter" in name:
17 | model = Model(
18 | cfg.num_tokens,
19 | d_model=256,
20 | n_heads=4,
21 | n_layers=12,
22 | )
23 | elif "-half" in name:
24 | model = Model(
25 | cfg.num_tokens,
26 | d_model=512,
27 | n_heads=8,
28 | n_layers=12,
29 | )
30 | else:
31 | if name not in ["ar", "nar"]:
32 | raise NotImplementedError(name)
33 |
34 | model = Model(
35 | cfg.num_tokens,
36 | d_model=1024,
37 | n_heads=16,
38 | n_layers=12,
39 | )
40 |
41 | return model
42 |
--------------------------------------------------------------------------------
/vall_e/vall_e/ar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from einops import rearrange
3 | from torch import Tensor
4 | from tqdm import trange
5 |
6 | from .base import Base
7 |
8 |
9 | class AR(Base):
10 | @property
11 | def n_resp_levels(self):
12 | return 1
13 |
14 | @property
15 | def casual(self):
16 | return True
17 |
18 | @property
19 | def use_stop_token(self):
20 | return True
21 |
22 | @property
23 | def norm_type(self):
24 | return "ln"
25 |
26 | @property
27 | def resp_loss_only(self):
28 | return False
29 |
30 | def _prune(self, l: Tensor):
31 | indices = (l == self.stop_token).nonzero()
32 | if len(indices) == 0:
33 | return l
34 | return l[: indices.min().item()]
35 |
36 | @staticmethod
37 | def _unsqueeze_list(x_list, axis=-1):
38 | return [x.unsqueeze(dim=axis) for x in x_list]
39 |
40 | def forward(
41 | self,
42 | text_list: list[Tensor],
43 | proms_list: list[Tensor],
44 | resp_list: list[Tensor] | None = None,
45 | max_steps: int = 1000,
46 | sampling_temperature: float = 1.0,
47 | ):
48 | if resp_list is not None:
49 | return super().forward(
50 | text_list,
51 | proms_list,
52 | self._unsqueeze_list(resp_list),
53 | resp_list,
54 | quant_levels=None,
55 | shift_targ_list=True,
56 | return_all_resp=False,
57 | )
58 | else:
59 | return self._generate(
60 | text_list,
61 | proms_list,
62 | max_steps,
63 | sampling_temperature,
64 | )
65 |
66 | def _generate(
67 | self,
68 | text_list: list[Tensor],
69 | proms_list: list[Tensor],
70 | max_steps: int,
71 | sampling_temperature: float,
72 | ):
73 | device = text_list[0].device
74 | resp_list: list[Tensor] = [
75 | torch.zeros(0, device=device).long() for _ in text_list
76 | ]
77 | stopped = torch.zeros(len(text_list), device=device).bool()
78 | for _ in trange(max_steps):
79 | r = super().forward(
80 | text_list,
81 | proms_list,
82 | self._unsqueeze_list(resp_list),
83 | sampling_temperature=sampling_temperature,
84 | )
85 | stopped |= r == self.stop_token
86 | for i, ri in enumerate(r):
87 | resp_list[i] = torch.cat([resp_list[i], ri[None]])
88 | if stopped.all().item():
89 | break
90 | pruned = [self._prune(r) for r in resp_list]
91 | return pruned
92 |
93 |
94 | def example_usage():
95 | from functools import partial
96 |
97 | import soundfile
98 | from einops import repeat
99 |
100 | device = "cuda"
101 |
102 | qnt = torch.load("data/test/test.qnt.pt")[0, 0].to(device)
103 | num_qnts = 1024
104 |
105 | model = AR(num_qnts).to(device)
106 |
107 | text_list = [
108 | torch.tensor([1, 2, 3], device=device),
109 | torch.tensor([2, 3], device=device),
110 | ]
111 |
112 | x8 = partial(repeat, pattern="t -> t l", l=8)
113 | proms_list = [
114 | x8(torch.tensor([1, 2, 3], device=device)),
115 | x8(torch.tensor([2, 3], device=device)),
116 | ]
117 |
118 | resp_list = [
119 | torch.tensor([1, 2, 3], device=device),
120 | qnt.to(device),
121 | ]
122 |
123 | out = model(text_list, proms_list, max_steps=200)
124 |
125 | print(out)
126 |
127 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
128 |
129 | for i in range(100):
130 | optimizer.zero_grad()
131 | _ = model(text_list, proms_list, resp_list)
132 |
133 | losses = model.loss
134 | sum(losses.values()).backward()
135 | optimizer.step()
136 |
137 | if i % 20 == 0:
138 | print(f"iter={i}, {losses}.")
139 |
140 | out = model(text_list, proms_list, max_steps=200)
141 |
142 | print(qnt)
143 | print(out)
144 |
145 | from ..emb.qnt import decode
146 |
147 | codes = rearrange(out[1], "t -> 1 1 t")
148 | wavs, sr = decode(codes)
149 | soundfile.write("data/test/test.ar.recon.wav", wavs.cpu()[0, 0], sr)
150 |
151 |
152 | if __name__ == "__main__":
153 | example_usage()
154 |
--------------------------------------------------------------------------------
/vall_e/vall_e/base.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 | from typing import Literal, overload
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from einops import rearrange
8 | from torch import Tensor, einsum, nn
9 | from torch.distributions import Categorical
10 | from torch.nn.utils.rnn import pad_sequence
11 | from torch.utils.checkpoint import checkpoint
12 |
13 |
14 | def _create_mask(l, device):
15 | """1 is valid region and 0 is invalid."""
16 | seq = torch.arange(max(l), device=device).unsqueeze(0) # (1 t)
17 | stop = torch.tensor(l, device=device).unsqueeze(1) # (b 1)
18 | return (seq < stop).float() # (b t)
19 |
20 |
21 | def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
22 | """
23 | Args:
24 | x_list: [(t d)]
25 | Returns:
26 | x: (? ? ?)
27 | m: (? ? ?), same as x
28 | """
29 | l = list(map(len, x_list))
30 | x = rearrange(pad_sequence(x_list), pattern)
31 | m = _create_mask(l, x_list[0].device)
32 | m = m.t().unsqueeze(-1) # (t b 1)
33 | m = rearrange(m, pattern)
34 | m = m.to(x)
35 | return x, m
36 |
37 |
38 | class SinusodialEmbedding(nn.Module):
39 | def __init__(self, d_model):
40 | super().__init__()
41 | self.d_model = d_model
42 | exponent = torch.arange(self.d_half, dtype=torch.float32)
43 | exponent = exponent / self.d_half
44 | omega = torch.exp(-math.log(1e4) * exponent)
45 | self.omega: torch.Tensor
46 | self.register_buffer("omega", omega, persistent=False)
47 |
48 | @property
49 | def d_half(self):
50 | assert self.d_model % 2 == 0, "Only support even d_model."
51 | return self.d_model // 2
52 |
53 | def forward(self, x):
54 | """
55 | Args:
56 | x: (...)
57 | Returns:
58 | pe: (... d)
59 | """
60 | omega = self.omega
61 |
62 | while omega.dim() <= x.dim():
63 | omega = omega.unsqueeze(0) # (... d)
64 |
65 | x = x.unsqueeze(-1) # (... 1)
66 | x = omega * x
67 | x = torch.cat([x.sin(), x.cos()], dim=-1)
68 |
69 | return x
70 |
71 | def get_pe(self, n: int):
72 | """
73 | Args:
74 | n: int
75 | Returns:
76 | pe: (n d)
77 | """
78 | device = self.omega.device
79 | return self.forward(torch.arange(n, device=device))
80 |
81 | def add_pe(self, x):
82 | """
83 | Args:
84 | x: (b t c)
85 | """
86 | e = self.get_pe(x.shape[1]) # t d
87 | e = e[None] # b t d
88 | x = x + e
89 | return x
90 |
91 |
92 | class Attention(nn.Module):
93 | def __init__(self, d_model, n_heads, casual):
94 | super().__init__()
95 | assert d_model % n_heads == 0
96 | dim_head = d_model // n_heads
97 | self.casual = casual
98 | self.n_heads = n_heads
99 | self.scale = dim_head**-0.5
100 | self.to_qkv = nn.Linear(d_model, d_model * 3, bias=False)
101 | self.to_out = nn.Linear(d_model, d_model)
102 |
103 | def forward(self, x, m):
104 | """
105 | Args:
106 | x: (b t c)
107 | m: (b t c), 1 is data, 0 is padding
108 | Returns:
109 | x: (b t c)
110 | """
111 | h = self.n_heads
112 |
113 | q, k, v = self.to_qkv(x).chunk(3, dim=-1)
114 | q, k, v = map(lambda t: rearrange(t, "b t (h d) -> b t h d", h=h), (q, k, v))
115 |
116 | e = einsum("b i h d, b j h d -> b i j h", q, k)
117 | e = e * self.scale
118 |
119 | kpm = m.unsqueeze(1) * m.unsqueeze(2) # b i j 1
120 |
121 | if self.casual:
122 | kpm = kpm.squeeze(-1).tril().unsqueeze(-1) # b i j 1
123 |
124 | e = e.masked_fill(kpm == 0, -torch.finfo(e.dtype).max)
125 | a = e.softmax(dim=2) # Normalize on j, i.e. key
126 |
127 | o = einsum("b i j h, b j h d -> b i h d", a, v)
128 | o = o.flatten(-2)
129 | o = self.to_out(o) # b t c
130 |
131 | o = o * m
132 |
133 | return o
134 |
135 |
136 | class AdaLN(nn.Module):
137 | def __init__(self, d_model, n_levels, eps=1e-5, k=0.1, c=2):
138 | super().__init__()
139 | self.eps = eps
140 | self.emb = nn.Embedding(n_levels, d_model * 2)
141 | self.k = k
142 | self.c = c
143 | nn.init.zeros_(self.emb.weight)
144 |
145 | def forward(self, x, l):
146 | logγ, β = self.emb(l).unsqueeze(1).chunk(2, dim=-1)
147 |
148 | h = F.layer_norm(x, x.shape[-1:], eps=self.eps)
149 |
150 | # The initial implementation (https://github.com/enhuiz/vall-e/blob/fbf023448c08e55c0422eefed7fc234cf8b76680/vall_e/vall_e/base.py#L135)
151 | # performed worse than vanilla LayerNorm.
152 | # The authors mentioned another AdaNorm paper (https://openreview.net/pdf?id=HyxndNrxLB) as they introduce AdaLN.
153 | # Did they use AdaNorm inside AdaLN? (as follows)
154 | h = self.c * (1 - (self.k * h).detach()) * h
155 |
156 | y = logγ.exp() * h + β
157 |
158 | return y
159 |
160 |
161 | class PrenormResidual(nn.Module):
162 | def __init__(
163 | self,
164 | block,
165 | d_model,
166 | p_dropout,
167 | requires_mask=False,
168 | norm_type="ln",
169 | n_levels: int | None = None,
170 | ):
171 | super().__init__()
172 | self.block = block
173 | self.requires_mask = requires_mask
174 | self.norm_type = norm_type
175 | if norm_type == "ln":
176 | self.norm = nn.LayerNorm(d_model)
177 | elif norm_type == "adaln":
178 | assert n_levels is not None
179 | self.norm = AdaLN(d_model, n_levels)
180 | else:
181 | raise NotImplementedError(norm_type)
182 | self.dropout = nn.Dropout(p_dropout)
183 |
184 | def forward(self, x, m, l):
185 | """
186 | Args:
187 | x: input (b t d)
188 | m: mask (b t 1), 1 is valuable and 0 is padding
189 | l: level to use, required only for AdaLN
190 | """
191 | nopts = {"l": l} if self.norm_type == "adaln" else {}
192 | bopts = {"m": m} if self.requires_mask else {}
193 | x = x + self.dropout(self.block(self.norm(x, **nopts) * m, **bopts))
194 | return x * m
195 |
196 |
197 | class Block(nn.Sequential):
198 | def __init__(self, d_model, n_heads, p_dropout, casual, norm_type, n_levels):
199 | super().__init__()
200 | self.attn = PrenormResidual(
201 | Attention(d_model, n_heads, casual),
202 | d_model=d_model,
203 | p_dropout=p_dropout,
204 | requires_mask=True,
205 | norm_type=norm_type,
206 | n_levels=n_levels,
207 | )
208 | self.ffn = PrenormResidual(
209 | nn.Sequential(
210 | nn.Linear(d_model, d_model * 4),
211 | nn.GELU(),
212 | nn.Dropout(p_dropout),
213 | nn.Linear(d_model * 4, d_model),
214 | ),
215 | d_model=d_model,
216 | p_dropout=p_dropout,
217 | norm_type=norm_type,
218 | n_levels=n_levels,
219 | )
220 |
221 | def forward(self, x, m, l):
222 | """
223 | Args:
224 | x: (b t c)
225 | m: (b t 1)
226 | l: (b)
227 | """
228 | poor_in_vram = True
229 | if x.requires_grad and poor_in_vram:
230 | x = checkpoint(self.attn, x, m, l)
231 | else:
232 | x = self.attn(x, m, l)
233 | x = self.ffn(x, m, l)
234 | return x
235 |
236 |
237 | class Embedding(nn.Embedding):
238 | def forward(self, x_list: list[Tensor]) -> list[Tensor]:
239 | if len(x_list) == 0:
240 | return []
241 | return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
242 |
243 |
244 | class MultiEmbedding(nn.Module):
245 | """
246 | This embedding sums embeddings on different levels.
247 | """
248 |
249 | def __init__(self, max_n_levels, n_tokens, token_dim):
250 | super().__init__()
251 | self.max_n_levels = max_n_levels
252 | self.n_tokens = n_tokens
253 | self.weight = nn.Parameter(torch.randn(max_n_levels, n_tokens, token_dim))
254 |
255 | def forward(self, x_list: list[Tensor]) -> list[Tensor]:
256 | if len(x_list) == 0:
257 | return []
258 |
259 | w = self.weight
260 |
261 | padded_x_list = []
262 |
263 | for xi in x_list:
264 | xi = F.one_hot(xi, num_classes=self.n_tokens) # t l' k
265 | xi = F.pad(xi, (0, 0, 0, w.shape[0] - xi.shape[1])) # t l k
266 | padded_x_list.append(xi.to(w))
267 |
268 | x = torch.cat(padded_x_list) # n l k
269 | x = einsum("l k d, n l k -> n d", w, x)
270 |
271 | x_list = x.split([*map(len, x_list)])
272 |
273 | return x_list
274 |
275 |
276 | def _join(x: tuple[Tensor], sep: Tensor):
277 | """
278 | Args:
279 | x: (k t d)
280 | sep: (d)
281 | """
282 | ret = x[0]
283 | for i in range(1, len(x)):
284 | ret = torch.cat((ret, sep[None], x[i]), dim=0)
285 | return ret
286 |
287 |
288 | class Base(nn.Module):
289 | @property
290 | def casual(self) -> bool:
291 | raise NotImplementedError
292 |
293 | @property
294 | def n_resp_levels(self) -> int:
295 | raise NotImplementedError
296 |
297 | @property
298 | def use_stop_token(self) -> bool:
299 | raise NotImplementedError
300 |
301 | @property
302 | def norm_type(self):
303 | raise NotImplementedError
304 |
305 | @property
306 | def n_prom_levels(self) -> int:
307 | return 8
308 |
309 | @property
310 | def resp_loss_only(self):
311 | raise NotImplementedError
312 |
313 | def __init__(
314 | self,
315 | n_tokens: int,
316 | d_model: int = 512,
317 | n_heads: int = 8,
318 | n_layers: int = 12,
319 | p_dropout: float = 0.1,
320 | ):
321 | super().__init__()
322 | self.n_tokens = n_tokens
323 |
324 | casual = self.casual
325 |
326 | # +1 to include the stop token
327 | n_stop_tokens = 1 if self.use_stop_token else 0
328 | n_resp_tokens = n_tokens + n_stop_tokens
329 |
330 | self.text_emb = Embedding(n_tokens, d_model)
331 |
332 | # Here I simply use all prom levels
333 | self.proms_emb = MultiEmbedding(self.n_prom_levels, n_tokens, d_model)
334 | self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
335 |
336 | self.sin_emb = SinusodialEmbedding(d_model)
337 |
338 | self.sep = nn.Parameter(torch.randn(d_model))
339 |
340 | blocks = [
341 | Block(
342 | d_model=d_model,
343 | n_heads=n_heads,
344 | p_dropout=p_dropout,
345 | casual=casual,
346 | norm_type=self.norm_type,
347 | n_levels=self.n_resp_levels,
348 | )
349 | for _ in range(n_layers)
350 | ]
351 |
352 | self.blocks = nn.ModuleList(blocks)
353 |
354 | self.classifier = nn.Linear(d_model, n_resp_tokens)
355 |
356 | @property
357 | def stop_token(self):
358 | if not self.use_stop_token:
359 | raise ValueError("Not using stop token!")
360 | return self.n_tokens
361 |
362 | @property
363 | def ignore_index(self):
364 | return -100
365 |
366 | @staticmethod
367 | def _samplewise_merge_tensors(*l, sep: Tensor | None):
368 | if sep is None:
369 | cat = torch.cat
370 | else:
371 | cat = partial(_join, sep=sep)
372 | return [*map(cat, zip(*l))]
373 |
374 | @overload
375 | def forward(
376 | self,
377 | text_list: list[Tensor],
378 | proms_list: list[Tensor],
379 | resps_list: list[Tensor],
380 | targ_list: list[Tensor] | None = None,
381 | quant_levels: Tensor | None = None,
382 | shift_targ_list: bool = False,
383 | return_all_resp: Literal[False] = False,
384 | sampling_temperature: float = 1.0,
385 | ) -> Tensor:
386 | ...
387 |
388 | @overload
389 | def forward(
390 | self,
391 | text_list: list[Tensor],
392 | proms_list: list[Tensor],
393 | resps_list: list[Tensor],
394 | targ_list: list[Tensor] | None = None,
395 | quant_levels: Tensor | None = None,
396 | shift_targ_list: bool = False,
397 | return_all_resp: Literal[True] = True,
398 | sampling_temperature: float = 1.0,
399 | ) -> list[Tensor]:
400 | ...
401 |
402 | def forward(
403 | self,
404 | text_list: list[Tensor],
405 | proms_list: list[Tensor],
406 | resps_list: list[Tensor],
407 | targ_list: list[Tensor] | None = None,
408 | quant_levels: Tensor | None = None,
409 | shift_targ_list: bool = False,
410 | return_all_resp: bool = False,
411 | sampling_temperature: float = 1.0,
412 | ):
413 | """
414 | Args:
415 | text_list: [t] * b
416 | proms_list: [t' l] * b, l quantization levels.
417 | resps_list: [t'' l] * b, l quantization levels.
418 | targ_list: [t''] * b, one quantization level only, when given, loss will be computed
419 | quant_levels: specify which quant_levels to feed forward, used in NAR mode.
420 | shift_targ_list: whether to shift target list when computing loss. True if AR.
421 | return_all_resp: True if NAR.
422 | sampling_temperature: a lower temperature makes the result more robust but less diverse.
423 | Returns:
424 | y: sampled tokens
425 | """
426 | x_list = self._samplewise_merge_tensors(
427 | self.text_emb(text_list),
428 | self.proms_emb(proms_list),
429 | self.resps_emb(resps_list),
430 | sep=self.sep,
431 | )
432 |
433 | x, m = list_to_tensor(x_list)
434 | x = self.sin_emb.add_pe(x)
435 |
436 | for block in self.blocks:
437 | x = block(x, m, quant_levels)
438 |
439 | h = self.classifier(x) * m
440 |
441 | # Remove padding
442 | h_list = [hi[:li] for hi, li in zip(h, map(len, x_list))]
443 |
444 | if targ_list is not None:
445 | if any([l == 0 for l in map(len, targ_list)]):
446 | raise ValueError("Cannot compute loss given empty targ_list.")
447 |
448 | device = h.device
449 |
450 | ignore_sep = torch.tensor(self.ignore_index, device=device)
451 |
452 | # Ignore prom in the target
453 | prom_list = [
454 | torch.full_like(t[..., 0], self.ignore_index) for t in proms_list
455 | ]
456 |
457 | text_prom_list = self._samplewise_merge_tensors(
458 | text_list, prom_list, sep=ignore_sep
459 | )
460 |
461 |             # Shift targets one step earlier, since it is the future token that must be predicted.
462 |             # If the loss should only cover the responses, set all text/prom targets to the ignore index.
463 | for i in range(len(text_prom_list)):
464 | if self.resp_loss_only:
465 | text_prom_list[i][:] = self.ignore_index
466 | else:
467 | text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
468 | text_prom_list[i][-1] = self.ignore_index
469 |
470 | if shift_targ_list:
471 | # Also make target earlier if in autoregressive mode
472 | targ_list = [*targ_list]
473 | for i in range(len(targ_list)):
474 | targ_list[i] = targ_list[i].roll(-1, dims=0)
475 | targ_list[i][-1] = self.stop_token
476 |
477 | y_list = self._samplewise_merge_tensors(
478 | text_prom_list, targ_list, sep=ignore_sep
479 | )
480 |
481 | self.loss = dict(
482 | nll=F.cross_entropy(
483 | torch.cat(h_list),
484 | torch.cat(y_list),
485 | ignore_index=self.ignore_index,
486 | )
487 | )
488 |
489 | if return_all_resp:
490 | logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
491 | ret = [
492 | Categorical(logits=hi / sampling_temperature).sample() for hi in logits
493 | ]
494 | else:
495 | logits = torch.stack([hi[-1] for hi in h_list])
496 | ret = Categorical(logits=logits / sampling_temperature).sample()
497 |
498 | return ret
499 |
--------------------------------------------------------------------------------
/vall_e/vall_e/nar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 | from .base import Base
5 |
6 |
7 | class NAR(Base):
8 | @property
9 | def n_resp_levels(self):
10 | return 7
11 |
12 | @property
13 | def casual(self):
14 | return False
15 |
16 | @property
17 | def use_stop_token(self):
18 | return False
19 |
20 | @property
21 | def norm_type(self):
22 | return "adaln"
23 |
24 | @property
25 | def resp_loss_only(self):
26 | return True
27 |
28 | def forward(
29 | self,
30 | text_list: list[Tensor],
31 | proms_list: list[Tensor],
32 | resps_list: list[Tensor],
33 | sampling_temperature: float = 0.2,
34 | ):
35 | """
36 | Args:
37 | text_list: [t] * b
38 | proms_list: [t' l] * b, l=8
39 | resps_list: [t'' l] * b, l=1 or 8, 1 for testing and 8 for training.
40 | Returns:
41 |             [t'' l], l=8 if testing; an empty list is returned during training.
42 | """
43 |
44 | n_levels_set = {r.shape[-1] for r in resps_list}
45 |
46 | if len(n_levels_set) > 1:
47 | raise ValueError(f"Please give only one level, got {n_levels_set}.")
48 |
49 | n_levels = next(iter(n_levels_set))
50 |
51 | device = text_list[0].device
52 |
53 | if n_levels == self.n_resp_levels + 1:
54 | assert resps_list is not None
55 |
56 | quant_levels = torch.randint(0, self.n_resp_levels, (len(resps_list),))
57 |
58 | prev_list = [o[..., : l + 1] for o, l in zip(resps_list, quant_levels)]
59 | targ_list = [o[..., l + 1] for o, l in zip(resps_list, quant_levels)]
60 |
61 | quant_levels = quant_levels.to(device=device)
62 |
63 | _ = super().forward(
64 | text_list,
65 | proms_list,
66 | prev_list,
67 | targ_list,
68 | return_all_resp=True,
69 | shift_targ_list=False,
70 | quant_levels=quant_levels,
71 | )
72 |
73 | # Yes, just nothing as we are training
74 | prev_list = []
75 | else:
76 | prev_list = resps_list
77 |
78 | while True:
79 | level = prev_list[0].shape[-1] - 1
80 |
81 | if level >= self.n_resp_levels:
82 | break
83 |
84 | quant_levels = torch.full((len(text_list),), level, device=device)
85 |
86 | resp_list = super().forward(
87 | text_list,
88 | proms_list,
89 | prev_list,
90 | return_all_resp=True,
91 | shift_targ_list=False,
92 | quant_levels=quant_levels,
93 | sampling_temperature=sampling_temperature,
94 | )
95 |
96 | prev_list = [
97 | torch.cat([rs, r.unsqueeze(-1)], dim=-1)
98 | for rs, r in zip(prev_list, resp_list)
99 | ]
100 |
101 | return prev_list
102 |
103 |
104 | def example_usage():
105 | from functools import partial
106 | from pathlib import Path
107 |
108 | from einops import repeat
109 |
110 | from ..emb.qnt import decode_to_file
111 | from ..utils import gather_attribute
112 |
113 | device = "cuda"
114 |
115 | resps = torch.load("data/test/test.qnt.pt")[0].to(device)
116 | num_qnts = 1024
117 |
118 | model = NAR(num_qnts).to(device)
119 |
120 | text_list = [
121 | torch.tensor([2, 3], device=device),
122 | ]
123 |
124 | x8 = partial(repeat, pattern="t -> t l", l=8)
125 | proms_list = [
126 | x8(torch.tensor([2, 3], device=device)),
127 | ]
128 |
129 | resps_x1_list = [
130 | resps[:1].t().to(device),
131 | ]
132 |
133 | resps_x8_list = [
134 | resps.t().to(device),
135 | ]
136 |
137 | codes = model(
138 | text_list,
139 | proms_list,
140 | resps_list=resps_x1_list,
141 | sampling_temperature=0.2,
142 | )[0]
143 |
144 | decode_to_file(
145 | codes,
146 | Path("data/test/test.nar.init.wav"),
147 | )
148 |
149 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
150 |
151 | for i in range(200):
152 | optimizer.zero_grad()
153 |
154 | _ = model(text_list, proms_list, resps_list=resps_x8_list)
155 |
156 | losses = gather_attribute(model, "loss")
157 | loss = sum(losses.values())
158 | loss.backward()
159 |
160 | optimizer.step()
161 |
162 | if i % 20 == 0:
163 | stats = {k: v.item() for k, v in losses.items()}
164 | stats["loss"] = loss.item()
165 | print(f"iter={i}, {stats}.")
166 |
167 | for i in range(1, 8):
168 | resps_list = [
169 | resps[:i].t().to(device),
170 | ]
171 |
172 | codes = model(
173 | text_list,
174 | proms_list,
175 | resps_list=resps_list,
176 | sampling_temperature=0.2,
177 | )[0]
178 |
179 | decode_to_file(
180 | codes,
181 | Path(f"data/test/test.nar.1-{i}.wav"),
182 | )
183 |
184 |
185 | if __name__ == "__main__":
186 | example_usage()
187 |
--------------------------------------------------------------------------------