├── .gitignore
├── README.md
├── barkify
│   ├── bark
│   │   ├── __init__.py
│   │   ├── api.py
│   │   ├── generation.py
│   │   ├── model.py
│   │   └── model_fine.py
│   ├── datas
│   │   ├── __init__.py
│   │   ├── data.py
│   │   └── tokenizer.py
│   ├── pl_model.py
│   └── utils.py
├── configs
│   └── barkify.yaml
├── infer.ipynb
├── process.ipynb
├── requirements.txt
└── trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | runs/
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | # datasets & checkpoints
134 | work_env/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # barkify
2 | Barkify: an unofficial repo for training Bark, a text-prompted generative audio model by suno-ai.
3 |
4 | Bark has two GPT-style models, which makes it compatible with prompting and other tricks from NLP. Bark achieves a great real-world TTS result, but the repo itself doesn't include a training recipe. We want to conduct experiments on and train this model, so here we release our basic training code, which may serve as a training guide for the open-source community.
5 |
6 | ## Process dataset
7 | We do our experiments on LJSpeech. Follow the instructions in `process.ipynb`.
8 | For Chinese, we test a famous streamer named `峰哥亡命天涯`. It shows an acceptable result, but worse than our other TTS repo.
9 | For English, we test the LibriTTS dataset. It works fine, and the basic items in our roadmap have been verified.
10 |
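For reference, `barkify/datas/data.py` expects the processed work_env to look roughly like the sketch below (the directory and key names come from that file; the exact file naming is produced by `process.ipynb`):
```
work_env/
├── meta/
│   └── train.json      # one JSON object per line with "name", "text" and (after g2p) "g2p"
├── semantic_idx/       # per-utterance semantic token arrays, read with np.load
└── encodec_idx/        # per-utterance EnCodec code arrays, read with np.load
```
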
11 | ## Training
12 | Stage 1 stands for text-to-semantic and stage 2 stands for semantic-to-acoustic.
13 | You should configure the parameters in `configs/barkify.yaml`. We use one A100 to train our models (both S1 & S2).
14 | ```
15 | # training stage 1 or 2
16 | python trainer.py start_path=/path/to/your/work_env stage=1 name=
17 | python trainer.py start_path=/path/to/your/work_env stage=2 name=
18 | ```
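
Under the hood, each stage builds its dataloader via `StageDataloader` in `barkify/datas/__init__.py`. Below is a minimal sketch of that wiring, assuming an OmegaConf-style config; the keys shown are only illustrative and the real values live in `configs/barkify.yaml`.
```python
from omegaconf import OmegaConf
from barkify.datas import StageDataloader

# Nested groups mirror barkify/datas/__init__.py: params.dataset -> Dataset,
# params.collate_fn -> the stage collate function, params.dataloader -> torch DataLoader.
params = OmegaConf.create({
    "dataset": {"start_path": "/path/to/your/work_env", "add_prompt": False},
    "collate_fn": {"text_window": 512, "semantic_window": 512},  # stage-1 collate kwargs
    "dataloader": {"batch_size": 8, "shuffle": True, "num_workers": 4},
})
train_loader = StageDataloader(params, stage=1, file="train")  # stage=2 switches to the coarse collate
```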
19 |
20 | ## Inference
21 | Directly use `infer.ipynb` and follow the instructions to run inference with your model.
22 |
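The notebook ultimately builds the model through `create_infer_model` in `barkify/bark/__init__.py`. A minimal sketch, assuming illustrative config values (the real ones come from `configs/barkify.yaml`) and a hypothetical checkpoint path:
```python
import torch
from barkify.bark import create_infer_model

# Field names follow GPTConfig in barkify/bark/model.py; the values are illustrative.
model_config = dict(n_layer=12, n_head=12, n_embd=768, block_size=1024,
                    input_vocab_size=10_048, output_vocab_size=10_048)
model = create_infer_model(model_config)  # returns a GPT in eval() mode

# How the trained weights are restored depends on how trainer.py / pl_model.py saved them;
# the path and state-dict layout below are assumptions.
# state = torch.load("/path/to/your/checkpoint.pt", map_location="cpu")
# model.load_state_dict(state)
```
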
23 | ## Roadmap
24 | We have already achieved the following items and will release our code soon.
25 | - [x] Construct a basic training code for bark-like generative model
26 | - [x] Test one speaker scenario
27 | - [x] Test multi speaker scenario
28 | - [x] Test speaker semantic prompting
29 | - [x] Test speech/audio acoustic prompting
30 | - [x] Test variable-length data (as we use a fixed length now)
31 |
32 | The following items are pretty data-hungry or rely on massive GPUs,
33 | so we are open to any sponsors or collaborators who can help finish these jobs.
34 | You can contact us via QQ: 3284494602 or email us at 3284494602@qq.com
35 |
36 | - [ ] Long-form generation (which may be longer than 1 min)
37 | - [ ] Support more languages (especially ZH)
38 | - [ ] Paralanguage modeling in the text input
39 | - [ ] Speaker generation by text prompts
40 | - [ ] Emotion/Timbre/Rhythm controlling by text/acoustic prompts
41 | - [ ] Add/Remove background noise (which might be important for downstream tasks)
42 |
43 | ## Appreciation
44 | - [bark](https://github.com/suno-ai/bark/) is a transformer-based text-to-audio model.
45 | - [Vall-E](https://github.com/lifeiteng/vall-e) is an unofficial PyTorch implementation of VALL-E.
46 |
--------------------------------------------------------------------------------
/barkify/bark/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import GPT, GPTConfig
2 | from .model_fine import FineGPT, FineGPTConfig
3 |
4 | from .generation import SAMPLE_RATE, preload_models
5 | from .generation import generate_fine, codec_decode
6 |
7 | def create_infer_model(model_config):
8 | gptconf = GPTConfig(**model_config)
9 | return GPT(gptconf).eval()
--------------------------------------------------------------------------------
/barkify/bark/api.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Union
2 |
3 | import numpy as np
4 |
5 | from .generation import codec_decode, generate_coarse, generate_fine, generate_text_semantic
6 |
7 |
8 | def text_to_semantic(
9 | text: str,
10 | history_prompt: Optional[Union[Dict, str]] = None,
11 | temp: float = 0.7,
12 | silent: bool = False,
13 | ):
14 | """Generate semantic array from text.
15 |
16 | Args:
17 | text: text to be turned into audio
18 | history_prompt: history choice for audio cloning
19 | temp: generation temperature (1.0 more diverse, 0.0 more conservative)
20 | silent: disable progress bar
21 |
22 | Returns:
23 | numpy semantic array to be fed into `semantic_to_waveform`
24 | """
25 | x_semantic = generate_text_semantic(
26 | text,
27 | history_prompt=history_prompt,
28 | temp=temp,
29 | silent=silent,
30 | use_kv_caching=True
31 | )
32 | return x_semantic
33 |
34 |
35 | def semantic_to_waveform(
36 | semantic_tokens: np.ndarray,
37 | history_prompt: Optional[Union[Dict, str]] = None,
38 | temp: float = 0.7,
39 | silent: bool = False,
40 | output_full: bool = False,
41 | ):
42 | """Generate audio array from semantic input.
43 |
44 | Args:
45 | semantic_tokens: semantic token output from `text_to_semantic`
46 | history_prompt: history choice for audio cloning
47 | temp: generation temperature (1.0 more diverse, 0.0 more conservative)
48 | silent: disable progress bar
49 | output_full: return full generation to be used as a history prompt
50 |
51 | Returns:
52 | numpy audio array at sample frequency 24khz
53 | """
54 | coarse_tokens = generate_coarse(
55 | semantic_tokens,
56 | history_prompt=history_prompt,
57 | temp=temp,
58 | silent=silent,
59 | use_kv_caching=True
60 | )
61 | fine_tokens = generate_fine(
62 | coarse_tokens,
63 | history_prompt=history_prompt,
64 | temp=0.5,
65 | )
66 | audio_arr = codec_decode(fine_tokens)
67 | if output_full:
68 | full_generation = {
69 | "semantic_prompt": semantic_tokens,
70 | "coarse_prompt": coarse_tokens,
71 | "fine_prompt": fine_tokens,
72 | }
73 | return full_generation, audio_arr
74 | return audio_arr
75 |
76 |
77 | def save_as_prompt(filepath, full_generation):
78 | assert(filepath.endswith(".npz"))
79 | assert(isinstance(full_generation, dict))
80 | assert("semantic_prompt" in full_generation)
81 | assert("coarse_prompt" in full_generation)
82 | assert("fine_prompt" in full_generation)
83 | np.savez(filepath, **full_generation)
84 |
85 |
86 | def generate_audio(
87 | text: str,
88 | history_prompt: Optional[Union[Dict, str]] = None,
89 | text_temp: float = 0.7,
90 | waveform_temp: float = 0.7,
91 | silent: bool = False,
92 | output_full: bool = False,
93 | ):
94 | """Generate audio array from input text.
95 |
96 | Args:
97 | text: text to be turned into audio
98 | history_prompt: history choice for audio cloning
99 | text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
100 | waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
101 | silent: disable progress bar
102 | output_full: return full generation to be used as a history prompt
103 |
104 | Returns:
105 | numpy audio array at sample frequency 24khz
106 | """
107 | semantic_tokens = text_to_semantic(
108 | text,
109 | history_prompt=history_prompt,
110 | temp=text_temp,
111 | silent=silent,
112 | )
113 | out = semantic_to_waveform(
114 | semantic_tokens,
115 | history_prompt=history_prompt,
116 | temp=waveform_temp,
117 | silent=silent,
118 | output_full=output_full,
119 | )
120 | if output_full:
121 | full_generation, audio_arr = out
122 | return full_generation, audio_arr
123 | else:
124 | audio_arr = out
125 | return audio_arr
126 |
--------------------------------------------------------------------------------
/barkify/bark/generation.py:
--------------------------------------------------------------------------------
1 | # Copied from bark
2 |
3 | import contextlib
4 | import gc
5 | import os
6 | import re
7 |
8 | from encodec import EncodecModel
9 | import funcy
10 | import logging
11 | import numpy as np
12 | from scipy.special import softmax
13 | import torch
14 | import torch.nn.functional as F
15 | import tqdm
16 | from transformers import BertTokenizer
17 | from huggingface_hub import hf_hub_download
18 |
19 | from .model import GPTConfig, GPT
20 | from .model_fine import FineGPT, FineGPTConfig
21 |
22 | if (
23 | torch.cuda.is_available() and
24 | hasattr(torch.cuda, "amp") and
25 | hasattr(torch.cuda.amp, "autocast") and
26 | hasattr(torch.cuda, "is_bf16_supported") and
27 | torch.cuda.is_bf16_supported()
28 | ):
29 | autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
30 | else:
31 | @contextlib.contextmanager
32 | def autocast():
33 | yield
34 |
35 |
36 | # hold models in global scope to lazy load
37 | global models
38 | models = {}
39 |
40 | global models_devices
41 | models_devices = {}
42 |
43 |
44 | CONTEXT_WINDOW_SIZE = 1024
45 |
46 | SEMANTIC_RATE_HZ = 49.9
47 | SEMANTIC_VOCAB_SIZE = 10_000
48 |
49 | CODEBOOK_SIZE = 1024
50 | N_COARSE_CODEBOOKS = 2
51 | N_FINE_CODEBOOKS = 8
52 | COARSE_RATE_HZ = 75
53 |
54 | SAMPLE_RATE = 24_000
55 |
56 |
57 | SUPPORTED_LANGS = [
58 | ("English", "en"),
59 | ("German", "de"),
60 | ("Spanish", "es"),
61 | ("French", "fr"),
62 | ("Hindi", "hi"),
63 | ("Italian", "it"),
64 | ("Japanese", "ja"),
65 | ("Korean", "ko"),
66 | ("Polish", "pl"),
67 | ("Portuguese", "pt"),
68 | ("Russian", "ru"),
69 | ("Turkish", "tr"),
70 | ("Chinese", "zh"),
71 | ]
72 |
73 | ALLOWED_PROMPTS = {"announcer"}
74 | for _, lang in SUPPORTED_LANGS:
75 | for prefix in ("", f"v2{os.path.sep}"):
76 | for n in range(10):
77 | ALLOWED_PROMPTS.add(f"{prefix}{lang}_speaker_{n}")
78 |
79 |
80 | logger = logging.getLogger(__name__)
81 |
82 |
83 | CUR_PATH = os.path.dirname(os.path.abspath(__file__))
84 |
85 |
86 | default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
87 | CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
88 |
89 |
90 | def _cast_bool_env_var(s):
91 | return s.lower() in ('true', '1', 't')
92 |
93 |
94 | USE_SMALL_MODELS = _cast_bool_env_var(os.environ.get("SUNO_USE_SMALL_MODELS", "False"))
95 | GLOBAL_ENABLE_MPS = _cast_bool_env_var(os.environ.get("SUNO_ENABLE_MPS", "False"))
96 | OFFLOAD_CPU = _cast_bool_env_var(os.environ.get("SUNO_OFFLOAD_CPU", "False"))
97 |
98 |
99 | REMOTE_MODEL_PATHS = {
100 | "text_small": {
101 | "repo_id": "suno/bark",
102 | "file_name": "text.pt",
103 | },
104 | "coarse_small": {
105 | "repo_id": "suno/bark",
106 | "file_name": "coarse.pt",
107 | },
108 | "fine_small": {
109 | "repo_id": "suno/bark",
110 | "file_name": "fine.pt",
111 | },
112 | "text": {
113 | "repo_id": "suno/bark",
114 | "file_name": "text_2.pt",
115 | },
116 | "coarse": {
117 | "repo_id": "suno/bark",
118 | "file_name": "coarse_2.pt",
119 | },
120 | "fine": {
121 | "repo_id": "suno/bark",
122 | "file_name": "fine_2.pt",
123 | },
124 | }
125 |
126 |
127 | if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cuda.is_available():
128 | logger.warning(
129 | "torch version does not support flash attention. You will get faster" +
130 | " inference speed by upgrade torch to newest nightly version."
131 | )
132 |
133 |
134 | def _grab_best_device(use_gpu=True):
135 | if torch.cuda.device_count() > 0 and use_gpu:
136 | device = "cuda"
137 | elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS:
138 | device = "mps"
139 | else:
140 | device = "cpu"
141 | return device
142 |
143 |
144 | def _get_ckpt_path(model_type, use_small=False):
145 | key = model_type
146 | if use_small or USE_SMALL_MODELS:
147 | key += "_small"
148 | return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])
149 |
150 |
151 | def _download(from_hf_path, file_name):
152 | os.makedirs(CACHE_DIR, exist_ok=True)
153 | hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
154 |
155 |
156 | class InferenceContext:
157 | def __init__(self, benchmark=False):
158 | # we can't expect inputs to be the same length, so disable benchmarking by default
159 | self._chosen_cudnn_benchmark = benchmark
160 | self._cudnn_benchmark = None
161 |
162 | def __enter__(self):
163 | self._cudnn_benchmark = torch.backends.cudnn.benchmark
164 | torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark
165 |
166 | def __exit__(self, exc_type, exc_value, exc_traceback):
167 | torch.backends.cudnn.benchmark = self._cudnn_benchmark
168 |
169 |
170 | if torch.cuda.is_available():
171 | torch.backends.cuda.matmul.allow_tf32 = True
172 | torch.backends.cudnn.allow_tf32 = True
173 |
174 |
175 | @contextlib.contextmanager
176 | def _inference_mode():
177 | with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast():
178 | yield
179 |
180 |
181 | def _clear_cuda_cache():
182 | if torch.cuda.is_available():
183 | torch.cuda.empty_cache()
184 | torch.cuda.synchronize()
185 |
186 |
187 | def clean_models(model_key=None):
188 | global models
189 | model_keys = [model_key] if model_key is not None else models.keys()
190 | for k in model_keys:
191 | if k in models:
192 | del models[k]
193 | _clear_cuda_cache()
194 | gc.collect()
195 |
196 |
197 | def _load_model(ckpt_path, device, use_small=False, model_type="text"):
198 | if model_type == "text":
199 | ConfigClass = GPTConfig
200 | ModelClass = GPT
201 | elif model_type == "coarse":
202 | ConfigClass = GPTConfig
203 | ModelClass = GPT
204 | elif model_type == "fine":
205 | ConfigClass = FineGPTConfig
206 | ModelClass = FineGPT
207 | else:
208 | raise NotImplementedError()
209 | model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
210 | model_info = REMOTE_MODEL_PATHS[model_key]
211 | if not os.path.exists(ckpt_path):
212 | logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
213 | _download(model_info["repo_id"], model_info["file_name"])
214 | checkpoint = torch.load(ckpt_path, map_location=device)
215 | # this is a hack
216 | model_args = checkpoint["model_args"]
217 | if "input_vocab_size" not in model_args:
218 | model_args["input_vocab_size"] = model_args["vocab_size"]
219 | model_args["output_vocab_size"] = model_args["vocab_size"]
220 | del model_args["vocab_size"]
221 | gptconf = ConfigClass(**checkpoint["model_args"])
222 | model = ModelClass(gptconf)
223 | state_dict = checkpoint["model"]
224 | # fixup checkpoint
225 | unwanted_prefix = "_orig_mod."
226 | for k, v in list(state_dict.items()):
227 | if k.startswith(unwanted_prefix):
228 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
229 | extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
230 | extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")])
231 | missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
232 | missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")])
233 | if len(extra_keys) != 0:
234 | raise ValueError(f"extra keys found: {extra_keys}")
235 | if len(missing_keys) != 0:
236 | raise ValueError(f"missing keys: {missing_keys}")
237 | model.load_state_dict(state_dict, strict=False)
238 | n_params = model.get_num_params()
239 | val_loss = checkpoint["best_val_loss"].item()
240 | logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
241 | model.eval()
242 | model.to(device)
243 | del checkpoint, state_dict
244 | _clear_cuda_cache()
245 | if model_type == "text":
246 | tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
247 | return {
248 | "model": model,
249 | "tokenizer": tokenizer,
250 | }
251 | return model
252 |
253 |
254 | def _load_codec_model(device):
255 | model = EncodecModel.encodec_model_24khz()
256 | model.set_target_bandwidth(6.0)
257 | model.eval()
258 | model.to(device)
259 | _clear_cuda_cache()
260 | return model
261 |
262 |
263 | def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
264 | _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
265 | if model_type not in ("text", "coarse", "fine"):
266 | raise NotImplementedError()
267 | global models
268 | global models_devices
269 | device = _grab_best_device(use_gpu=use_gpu)
270 | model_key = f"{model_type}"
271 | if OFFLOAD_CPU:
272 | models_devices[model_key] = device
273 | device = "cpu"
274 | if model_key not in models or force_reload:
275 | ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
276 | clean_models(model_key=model_key)
277 | model = _load_model_f(ckpt_path, device)
278 | models[model_key] = model
279 | if model_type == "text":
280 | models[model_key]["model"].to(device)
281 | else:
282 | models[model_key].to(device)
283 | return models[model_key]
284 |
285 |
286 | def load_codec_model(use_gpu=True, force_reload=False):
287 | global models
288 | global models_devices
289 | device = _grab_best_device(use_gpu=use_gpu)
290 | if device == "mps":
291 | # encodec doesn't support mps
292 | device = "cpu"
293 | model_key = "codec"
294 | if OFFLOAD_CPU:
295 | models_devices[model_key] = device
296 | device = "cpu"
297 | if model_key not in models or force_reload:
298 | clean_models(model_key=model_key)
299 | model = _load_codec_model(device)
300 | models[model_key] = model
301 | models[model_key].to(device)
302 | return models[model_key]
303 |
304 |
305 | def preload_models(
306 | text_use_gpu=True,
307 | text_use_small=False,
308 | coarse_use_gpu=True,
309 | coarse_use_small=False,
310 | fine_use_gpu=True,
311 | fine_use_small=False,
312 | codec_use_gpu=True,
313 | force_reload=False,
314 | ):
315 | """Load all the necessary models for the pipeline."""
316 | if _grab_best_device() == "cpu" and (
317 | text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
318 | ):
319 | logger.warning("No GPU being used. Careful, inference might be very slow!")
320 | _ = load_model(
321 | model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
322 | )
323 | _ = load_model(
324 | model_type="coarse",
325 | use_gpu=coarse_use_gpu,
326 | use_small=coarse_use_small,
327 | force_reload=force_reload,
328 | )
329 | _ = load_model(
330 | model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
331 | )
332 | _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
333 |
334 |
335 | ####
336 | # Generation Functionality
337 | ####
338 |
339 |
340 | def _tokenize(tokenizer, text):
341 | return tokenizer.encode(text, add_special_tokens=False)
342 |
343 |
344 | def _detokenize(tokenizer, enc_text):
345 | return tokenizer.decode(enc_text)
346 |
347 |
348 | def _normalize_whitespace(text):
349 | return re.sub(r"\s+", " ", text).strip()
350 |
351 |
352 | TEXT_ENCODING_OFFSET = 10_048
353 | SEMANTIC_PAD_TOKEN = 10_000
354 | TEXT_PAD_TOKEN = 129_595
355 | SEMANTIC_INFER_TOKEN = 129_599
356 |
357 |
358 | def _load_history_prompt(history_prompt_input):
359 | if isinstance(history_prompt_input, str) and history_prompt_input.endswith(".npz"):
360 | history_prompt = np.load(history_prompt_input)
361 | elif isinstance(history_prompt_input, str):
362 | # make sure this works on non-ubuntu
363 | history_prompt_input = os.path.join(*history_prompt_input.split("/"))
364 | if history_prompt_input not in ALLOWED_PROMPTS:
365 | raise ValueError("history prompt not found")
366 | history_prompt = np.load(
367 | os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt_input}.npz")
368 | )
369 | elif isinstance(history_prompt_input, dict):
370 | assert("semantic_prompt" in history_prompt_input)
371 | assert("coarse_prompt" in history_prompt_input)
372 | assert("fine_prompt" in history_prompt_input)
373 | history_prompt = history_prompt_input
374 | else:
375 | raise ValueError("history prompt format unrecognized")
376 | return history_prompt
377 |
378 |
379 | def generate_text_semantic(
380 | text,
381 | history_prompt=None,
382 | temp=0.7,
383 | top_k=None,
384 | top_p=None,
385 | silent=False,
386 | min_eos_p=0.2,
387 | max_gen_duration_s=None,
388 | allow_early_stop=True,
389 | use_kv_caching=False,
390 | ):
391 | """Generate semantic tokens from text."""
392 | assert isinstance(text, str)
393 | text = _normalize_whitespace(text)
394 | assert len(text.strip()) > 0
395 | if history_prompt is not None:
396 | history_prompt = _load_history_prompt(history_prompt)
397 | semantic_history = history_prompt["semantic_prompt"]
398 | assert (
399 | isinstance(semantic_history, np.ndarray)
400 | and len(semantic_history.shape) == 1
401 | and len(semantic_history) > 0
402 | and semantic_history.min() >= 0
403 | and semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1
404 | )
405 | else:
406 | semantic_history = None
407 | # load models if not yet exist
408 | global models
409 | global models_devices
410 | if "text" not in models:
411 | preload_models()
412 | model_container = models["text"]
413 | model = model_container["model"]
414 | tokenizer = model_container["tokenizer"]
415 | encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
416 | if OFFLOAD_CPU:
417 | model.to(models_devices["text"])
418 | device = next(model.parameters()).device
419 | if len(encoded_text) > 256:
420 | p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
421 | logger.warning(f"warning, text too long, lopping of last {p}%")
422 | encoded_text = encoded_text[:256]
423 | encoded_text = np.pad(
424 | encoded_text,
425 | (0, 256 - len(encoded_text)),
426 | constant_values=TEXT_PAD_TOKEN,
427 | mode="constant",
428 | )
429 | if semantic_history is not None:
430 | semantic_history = semantic_history.astype(np.int64)
431 | # lop off if history is too long, pad if needed
432 | semantic_history = semantic_history[-256:]
433 | semantic_history = np.pad(
434 | semantic_history,
435 | (0, 256 - len(semantic_history)),
436 | constant_values=SEMANTIC_PAD_TOKEN,
437 | mode="constant",
438 | )
439 | else:
440 | semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256)
441 | x = torch.from_numpy(
442 | np.hstack([
443 | encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])
444 | ]).astype(np.int64)
445 | )[None]
446 | assert x.shape[1] == 256 + 256 + 1
447 | with _inference_mode():
448 | x = x.to(device)
449 | n_tot_steps = 768
450 | # custom tqdm updates since we don't know when eos will occur
451 | pbar = tqdm.tqdm(disable=silent, total=100)
452 | pbar_state = 0
453 | tot_generated_duration_s = 0
454 | kv_cache = None
455 | for n in range(n_tot_steps):
456 | if use_kv_caching and kv_cache is not None:
457 | x_input = x[:, [-1]]
458 | else:
459 | x_input = x
460 | logits, kv_cache = model(
461 | x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
462 | )
463 | relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
464 | if allow_early_stop:
465 | relevant_logits = torch.hstack(
466 | (relevant_logits, logits[0, 0, [SEMANTIC_PAD_TOKEN]]) # eos
467 | )
468 | if top_p is not None:
469 | # faster to convert to numpy
470 | logits_device = relevant_logits.device
471 | logits_dtype = relevant_logits.type()
472 | relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
473 | sorted_indices = np.argsort(relevant_logits)[::-1]
474 | sorted_logits = relevant_logits[sorted_indices]
475 | cumulative_probs = np.cumsum(softmax(sorted_logits))
476 | sorted_indices_to_remove = cumulative_probs > top_p
477 | sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
478 | sorted_indices_to_remove[0] = False
479 | relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
480 | relevant_logits = torch.from_numpy(relevant_logits)
481 | relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
482 | if top_k is not None:
483 | v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
484 | relevant_logits[relevant_logits < v[-1]] = -float("Inf")
485 | probs = F.softmax(relevant_logits / temp, dim=-1)
486 | # multinomial bugged on mps: shuttle to cpu if necessary
487 | inf_device = probs.device
488 | if probs.device.type == "mps":
489 | probs = probs.to("cpu")
490 | item_next = torch.multinomial(probs, num_samples=1)
491 | probs = probs.to(inf_device)
492 | item_next = item_next.to(inf_device)
493 | if allow_early_stop and (
494 | item_next == SEMANTIC_VOCAB_SIZE
495 | or (min_eos_p is not None and probs[-1] >= min_eos_p)
496 | ):
497 | # eos found, so break
498 | pbar.update(100 - pbar_state)
499 | break
500 | x = torch.cat((x, item_next[None]), dim=1)
501 | tot_generated_duration_s += 1 / SEMANTIC_RATE_HZ
502 | if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s:
503 | pbar.update(100 - pbar_state)
504 | break
505 | if n == n_tot_steps - 1:
506 | pbar.update(100 - pbar_state)
507 | break
508 | del logits, relevant_logits, probs, item_next
509 | req_pbar_state = np.min([100, int(round(100 * n / n_tot_steps))])
510 | if req_pbar_state > pbar_state:
511 | pbar.update(req_pbar_state - pbar_state)
512 | pbar_state = req_pbar_state
513 | pbar.close()
514 | out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
515 | if OFFLOAD_CPU:
516 | model.to("cpu")
517 | assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
518 | _clear_cuda_cache()
519 | return out
520 |
521 |
522 | def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):
523 | assert len(arr.shape) == 2
524 | arr = arr.copy()
525 | if offset_size is not None:
526 | for n in range(1, arr.shape[0]):
527 | arr[n, :] += offset_size * n
528 | flat_arr = arr.ravel("F")
529 | return flat_arr
530 |
531 |
532 | COARSE_SEMANTIC_PAD_TOKEN = 12_048
533 | COARSE_INFER_TOKEN = 12_050
534 |
535 |
536 | def generate_coarse(
537 | x_semantic,
538 | history_prompt=None,
539 | temp=0.7,
540 | top_k=None,
541 | top_p=None,
542 | silent=False,
543 | max_coarse_history=630, # min 60 (faster), max 630 (more context)
544 | sliding_window_len=60,
545 | use_kv_caching=False,
546 | ):
547 | """Generate coarse audio codes from semantic tokens."""
548 | assert (
549 | isinstance(x_semantic, np.ndarray)
550 | and len(x_semantic.shape) == 1
551 | and len(x_semantic) > 0
552 | and x_semantic.min() >= 0
553 | and x_semantic.max() <= SEMANTIC_VOCAB_SIZE - 1
554 | )
555 | assert 60 <= max_coarse_history <= 630
556 | assert max_coarse_history + sliding_window_len <= 1024 - 256
557 | semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
558 | max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
559 | if history_prompt is not None:
560 | history_prompt = _load_history_prompt(history_prompt)
561 | x_semantic_history = history_prompt["semantic_prompt"]
562 | x_coarse_history = history_prompt["coarse_prompt"]
563 | assert (
564 | isinstance(x_semantic_history, np.ndarray)
565 | and len(x_semantic_history.shape) == 1
566 | and len(x_semantic_history) > 0
567 | and x_semantic_history.min() >= 0
568 | and x_semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1
569 | and isinstance(x_coarse_history, np.ndarray)
570 | and len(x_coarse_history.shape) == 2
571 | and x_coarse_history.shape[0] == N_COARSE_CODEBOOKS
572 | and x_coarse_history.shape[-1] >= 0
573 | and x_coarse_history.min() >= 0
574 | and x_coarse_history.max() <= CODEBOOK_SIZE - 1
575 | and (
576 | round(x_coarse_history.shape[-1] / len(x_semantic_history), 1)
577 | == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
578 | )
579 | )
580 | x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
581 | # trim histories correctly
582 | n_semantic_hist_provided = np.min(
583 | [
584 | max_semantic_history,
585 | len(x_semantic_history) - len(x_semantic_history) % 2,
586 | int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)),
587 | ]
588 | )
589 | n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))
590 | x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32)
591 | x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32)
592 | # TODO: bit of a hack for time alignment (sounds better)
593 | x_coarse_history = x_coarse_history[:-2]
594 | else:
595 | x_semantic_history = np.array([], dtype=np.int32)
596 | x_coarse_history = np.array([], dtype=np.int32)
597 | # load models if not yet exist
598 | global models
599 | global models_devices
600 | if "coarse" not in models:
601 | preload_models()
602 | model = models["coarse"]
603 | if OFFLOAD_CPU:
604 | model.to(models_devices["coarse"])
605 | device = next(model.parameters()).device
606 | # start loop
607 | n_steps = int(
608 | round(
609 | np.floor(len(x_semantic) * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS)
610 | * N_COARSE_CODEBOOKS
611 | )
612 | )
613 | assert n_steps > 0 and n_steps % N_COARSE_CODEBOOKS == 0
614 | x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32)
615 | x_coarse = x_coarse_history.astype(np.int32)
616 | base_semantic_idx = len(x_semantic_history)
617 | with _inference_mode():
618 | x_semantic_in = torch.from_numpy(x_semantic)[None].to(device)
619 | x_coarse_in = torch.from_numpy(x_coarse)[None].to(device)
620 | n_window_steps = int(np.ceil(n_steps / sliding_window_len))
621 | n_step = 0
622 | for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
623 | semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio))
624 | # pad from right side
625 | x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :]
626 | x_in = x_in[:, :256]
627 | x_in = F.pad(
628 | x_in,
629 | (0, 256 - x_in.shape[-1]),
630 | "constant",
631 | COARSE_SEMANTIC_PAD_TOKEN,
632 | )
633 | x_in = torch.hstack(
634 | [
635 | x_in,
636 | torch.tensor([COARSE_INFER_TOKEN])[None].to(device),
637 | x_coarse_in[:, -max_coarse_history:],
638 | ]
639 | )
640 | kv_cache = None
641 | for _ in range(sliding_window_len):
642 | if n_step >= n_steps:
643 | continue
644 | is_major_step = n_step % N_COARSE_CODEBOOKS == 0
645 |
646 | if use_kv_caching and kv_cache is not None:
647 | x_input = x_in[:, [-1]]
648 | else:
649 | x_input = x_in
650 |
651 | logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
652 | logit_start_idx = (
653 | SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
654 | )
655 | logit_end_idx = (
656 | SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE
657 | )
658 | relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
659 | if top_p is not None:
660 | # faster to convert to numpy
661 | logits_device = relevant_logits.device
662 | logits_dtype = relevant_logits.type()
663 | relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
664 | sorted_indices = np.argsort(relevant_logits)[::-1]
665 | sorted_logits = relevant_logits[sorted_indices]
666 | cumulative_probs = np.cumsum(softmax(sorted_logits))
667 | sorted_indices_to_remove = cumulative_probs > top_p
668 | sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
669 | sorted_indices_to_remove[0] = False
670 | relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
671 | relevant_logits = torch.from_numpy(relevant_logits)
672 | relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
673 | if top_k is not None:
674 | v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
675 | relevant_logits[relevant_logits < v[-1]] = -float("Inf")
676 | probs = F.softmax(relevant_logits / temp, dim=-1)
677 | # multinomial bugged on mps: shuttle to cpu if necessary
678 | inf_device = probs.device
679 | if probs.device.type == "mps":
680 | probs = probs.to("cpu")
681 | item_next = torch.multinomial(probs, num_samples=1)
682 | probs = probs.to(inf_device)
683 | item_next = item_next.to(inf_device)
684 | item_next += logit_start_idx
685 | x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
686 | x_in = torch.cat((x_in, item_next[None]), dim=1)
687 | del logits, relevant_logits, probs, item_next
688 | n_step += 1
689 | del x_in
690 | del x_semantic_in
691 | if OFFLOAD_CPU:
692 | model.to("cpu")
693 | gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
694 | del x_coarse_in
695 | assert len(gen_coarse_arr) == n_steps
696 | gen_coarse_audio_arr = gen_coarse_arr.reshape(-1, N_COARSE_CODEBOOKS).T - SEMANTIC_VOCAB_SIZE
697 | for n in range(1, N_COARSE_CODEBOOKS):
698 | gen_coarse_audio_arr[n, :] -= n * CODEBOOK_SIZE
699 | _clear_cuda_cache()
700 | return gen_coarse_audio_arr
701 |
702 |
703 | def generate_fine(
704 | x_coarse_gen,
705 | history_prompt=None,
706 | temp=0.5,
707 | silent=True,
708 | ):
709 | """Generate full audio codes from coarse audio codes."""
710 | assert (
711 | isinstance(x_coarse_gen, np.ndarray)
712 | and len(x_coarse_gen.shape) == 2
713 | and 1 <= x_coarse_gen.shape[0] <= N_FINE_CODEBOOKS - 1
714 | and x_coarse_gen.shape[1] > 0
715 | and x_coarse_gen.min() >= 0
716 | and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
717 | )
718 | if history_prompt is not None:
719 | history_prompt = _load_history_prompt(history_prompt)
720 | x_fine_history = history_prompt["fine_prompt"]
721 | assert (
722 | isinstance(x_fine_history, np.ndarray)
723 | and len(x_fine_history.shape) == 2
724 | and x_fine_history.shape[0] == N_FINE_CODEBOOKS
725 | and x_fine_history.shape[1] >= 0
726 | and x_fine_history.min() >= 0
727 | and x_fine_history.max() <= CODEBOOK_SIZE - 1
728 | )
729 | else:
730 | x_fine_history = None
731 | n_coarse = x_coarse_gen.shape[0]
732 | # load models if not yet exist
733 | global models
734 | global models_devices
735 | if "fine" not in models:
736 | preload_models(text_use_small=True, coarse_use_small=True, fine_use_small=True)
737 |
738 | model = models["fine"]
739 | if OFFLOAD_CPU:
740 | model.to(models_devices["fine"])
741 | device = next(model.parameters()).device
742 | # make input arr
743 | in_arr = np.vstack(
744 | [
745 | x_coarse_gen,
746 | np.zeros((N_FINE_CODEBOOKS - n_coarse, x_coarse_gen.shape[1]))
747 | + CODEBOOK_SIZE, # padding
748 | ]
749 | ).astype(np.int32)
750 | # prepend history if available (max 512)
751 | if x_fine_history is not None:
752 | x_fine_history = x_fine_history.astype(np.int32)
753 | in_arr = np.hstack(
754 | [
755 | x_fine_history[:, -512:].astype(np.int32),
756 | in_arr,
757 | ]
758 | )
759 | n_history = x_fine_history[:, -512:].shape[1]
760 | else:
761 | n_history = 0
762 | n_remove_from_end = 0
763 | # need to pad if too short (since non-causal model)
764 | if in_arr.shape[1] < 1024:
765 | n_remove_from_end = 1024 - in_arr.shape[1]
766 | in_arr = np.hstack(
767 | [
768 | in_arr,
769 | np.zeros((N_FINE_CODEBOOKS, n_remove_from_end), dtype=np.int32) + CODEBOOK_SIZE,
770 | ]
771 | )
772 | # we can be lazy about fractional loop and just keep overwriting codebooks
773 | n_loops = np.max([0, int(np.ceil((x_coarse_gen.shape[1] - (1024 - n_history)) / 512))]) + 1
774 | with _inference_mode():
775 | in_arr = torch.tensor(in_arr.T).to(device)
776 | for n in tqdm.tqdm(range(n_loops), disable=silent):
777 | start_idx = np.min([n * 512, in_arr.shape[0] - 1024])
778 | start_fill_idx = np.min([n_history + n * 512, in_arr.shape[0] - 512])
779 | rel_start_fill_idx = start_fill_idx - start_idx
780 | in_buffer = in_arr[start_idx : start_idx + 1024, :][None]
781 | for nn in range(n_coarse, N_FINE_CODEBOOKS):
782 | logits = model(nn, in_buffer)
783 | if temp is None:
784 | relevant_logits = logits[0, rel_start_fill_idx:, :CODEBOOK_SIZE]
785 | codebook_preds = torch.argmax(relevant_logits, -1)
786 | else:
787 | relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
788 | probs = F.softmax(relevant_logits, dim=-1)
789 | # multinomial bugged on mps: shuttle to cpu if necessary
790 | inf_device = probs.device
791 | if probs.device.type == "mps":
792 | probs = probs.to("cpu")
793 | codebook_preds = torch.hstack(
794 | [
795 | torch.multinomial(probs[nnn], num_samples=1).to(inf_device)
796 | for nnn in range(rel_start_fill_idx, 1024)
797 | ]
798 | )
799 | in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
800 | del logits, codebook_preds
801 | # transfer over info into model_in and convert to numpy
802 | for nn in range(n_coarse, N_FINE_CODEBOOKS):
803 | in_arr[
804 | start_fill_idx : start_fill_idx + (1024 - rel_start_fill_idx), nn
805 | ] = in_buffer[0, rel_start_fill_idx:, nn]
806 | del in_buffer
807 | gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
808 | del in_arr
809 | if OFFLOAD_CPU:
810 | model.to("cpu")
811 | gen_fine_arr = gen_fine_arr[:, n_history:]
812 | if n_remove_from_end > 0:
813 | gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
814 | assert gen_fine_arr.shape[-1] == x_coarse_gen.shape[-1]
815 | _clear_cuda_cache()
816 | return gen_fine_arr
817 |
818 |
819 | def codec_decode(fine_tokens):
820 | """Turn quantized audio codes into audio array using encodec."""
821 | # load models if not yet exist
822 | global models
823 | global models_devices
824 | if "codec" not in models:
825 | preload_models()
826 | model = models["codec"]
827 | if OFFLOAD_CPU:
828 | model.to(models_devices["codec"])
829 | device = next(model.parameters()).device
830 | arr = torch.from_numpy(fine_tokens)[None]
831 | arr = arr.to(device)
832 | arr = arr.transpose(0, 1)
833 | emb = model.quantizer.decode(arr)
834 | out = model.decoder(emb)
835 | audio_arr = out.detach().cpu().numpy().squeeze()
836 | del arr, emb, out
837 | if OFFLOAD_CPU:
838 | model.to("cpu")
839 | return audio_arr
840 |
--------------------------------------------------------------------------------
/barkify/bark/model.py:
--------------------------------------------------------------------------------
1 | # Copied from bark
2 |
3 | """
4 | Much of this code is adapted from Andrej Karpathy's NanoGPT
5 | (https://github.com/karpathy/nanoGPT)
6 | """
7 | import math
8 | from dataclasses import dataclass
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch.nn import functional as F
13 |
14 | class LayerNorm(nn.Module):
15 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
16 |
17 | def __init__(self, ndim, bias):
18 | super().__init__()
19 | self.weight = nn.Parameter(torch.ones(ndim))
20 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
21 |
22 | def forward(self, input):
23 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
24 |
25 | class CausalSelfAttention(nn.Module):
26 |
27 | def __init__(self, config):
28 | super().__init__()
29 | assert config.n_embd % config.n_head == 0
30 | # key, query, value projections for all heads, but in a batch
31 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
32 | # output projection
33 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
34 | # regularization
35 | self.attn_dropout = nn.Dropout(config.dropout)
36 | self.resid_dropout = nn.Dropout(config.dropout)
37 | self.n_head = config.n_head
38 | self.n_embd = config.n_embd
39 | self.dropout = config.dropout
40 | # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
41 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
42 | if not self.flash:
43 | # print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
44 | # causal mask to ensure that attention is only applied to the left in the input sequence
45 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
46 | .view(1, 1, config.block_size, config.block_size))
47 |
48 | def forward(self, x, past_kv=None, use_cache=False):
49 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
50 |
51 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
52 | q, k ,v = self.c_attn(x).split(self.n_embd, dim=2)
53 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
54 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
55 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
56 |
57 | if past_kv is not None:
58 | past_key = past_kv[0]
59 | past_value = past_kv[1]
60 | k = torch.cat((past_key, k), dim=-2)
61 | v = torch.cat((past_value, v), dim=-2)
62 |
63 | FULL_T = k.shape[-2]
64 |
65 | if use_cache is True:
66 | present = (k, v)
67 | else:
68 | present = None
69 |
70 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
71 | if self.flash:
72 | # efficient attention using Flash Attention CUDA kernels
73 | if past_kv is not None:
74 | # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
75 | # the query for the last token. scaled_dot_product_attention interprets this as the first token in the
76 | # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
77 | # to work around this we set is_causal=False.
78 | is_causal = False
79 | else:
80 | is_causal = True
81 |
82 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
83 | else:
84 | # manual implementation of attention
85 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
86 | att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
87 | att = F.softmax(att, dim=-1)
88 | att = self.attn_dropout(att)
89 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
90 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
91 |
92 | # output projection
93 | y = self.resid_dropout(self.c_proj(y))
94 | return (y, present)
95 |
96 | class MLP(nn.Module):
97 |
98 | def __init__(self, config):
99 | super().__init__()
100 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
101 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
102 | self.dropout = nn.Dropout(config.dropout)
103 | self.gelu = nn.GELU()
104 |
105 | def forward(self, x):
106 | x = self.c_fc(x)
107 | x = self.gelu(x)
108 | x = self.c_proj(x)
109 | x = self.dropout(x)
110 | return x
111 |
112 | class Block(nn.Module):
113 |
114 | def __init__(self, config, layer_idx):
115 | super().__init__()
116 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
117 | self.attn = CausalSelfAttention(config)
118 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
119 | self.mlp = MLP(config)
120 | self.layer_idx = layer_idx
121 |
122 | def forward(self, x, past_kv=None, use_cache=False):
123 | attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
124 | x = x + attn_output
125 | x = x + self.mlp(self.ln_2(x))
126 | return (x, prev_kvs)
127 |
128 | @dataclass
129 | class GPTConfig:
130 | block_size: int = 1024
131 | input_vocab_size: int = 10_048
132 | output_vocab_size: int = 10_048
133 | use_extra_input: bool = False
134 | extra_input_dim: int = 512
135 | n_layer: int = 12
136 | n_head: int = 12
137 | n_embd: int = 768
138 | dropout: float = 0.0
139 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
140 |
141 | class GPT(nn.Module):
142 |
143 | def __init__(self, config):
144 | super().__init__()
145 | assert config.input_vocab_size is not None
146 | assert config.output_vocab_size is not None
147 | assert config.block_size is not None
148 | self.config = config
149 |
150 | self.transformer = nn.ModuleDict(dict(
151 | wte = nn.Embedding(config.input_vocab_size, config.n_embd),
152 | wpe = nn.Embedding(config.block_size, config.n_embd),
153 | extra_proj = nn.Linear(config.extra_input_dim, config.n_embd) if config.use_extra_input else None,
154 | drop = nn.Dropout(config.dropout),
155 | h = nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]),
156 | ln_f = LayerNorm(config.n_embd, bias=config.bias),
157 | ))
158 | self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
159 |
160 | def get_num_params(self, non_embedding=True):
161 | """
162 | Return the number of parameters in the model.
163 | For non-embedding count (default), the position embeddings get subtracted.
164 | The token embeddings would too, except due to the parameter sharing these
165 | params are actually used as weights in the final layer, so we include them.
166 | """
167 | n_params = sum(p.numel() for p in self.parameters())
168 | if non_embedding:
169 | n_params -= self.transformer.wte.weight.numel()
170 | n_params -= self.transformer.wpe.weight.numel()
171 | return n_params
172 |
173 | def forward(self, idx, extra=None, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
174 | device = idx.device
175 | b, t = idx.size()
176 | if past_kv is not None:
177 | assert t == 1
178 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
179 | else:
180 | if merge_context:
181 | assert(idx.shape[1] >= 256+256+1)
182 | t = idx.shape[1] - 256
183 | else:
184 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
185 |
186 | # forward the GPT model itself
187 | if merge_context:
188 | tok_emb = torch.cat([
189 | self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
190 | self.transformer.wte(idx[:,256+256:])
191 | ], dim=1)
192 | else:
193 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
194 |
195 | if past_kv is None:
196 | past_length = 0
197 | past_kv = tuple([None] * len(self.transformer.h))
198 | else:
199 | past_length = past_kv[0][0].size(-2)
200 |
201 | if position_ids is None:
202 | position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
203 | position_ids = position_ids.unsqueeze(0) # shape (1, t)
204 | assert position_ids.shape == (1, t)
205 |
206 | pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
207 |
208 | if self.config.use_extra_input:
209 | assert extra is not None
210 | x = self.transformer.drop(tok_emb + pos_emb + self.transformer.extra_proj(extra))
211 | else:
212 | x = self.transformer.drop(tok_emb + pos_emb)
213 |
214 | new_kv = () if use_cache else None
215 |
216 | for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
217 | x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)
218 |
219 | if use_cache:
220 | new_kv = new_kv + (kv,)
221 |
222 | x = self.transformer.ln_f(x)
223 |
224 | # inference-time mini-optimization: only forward the lm_head on the very last position
225 | if use_cache == False:
226 | logits = self.lm_head(x[:, :, :]) # changes: output full logits
227 | else:
228 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
229 |
230 | return (logits, new_kv)
231 |
--------------------------------------------------------------------------------
/barkify/bark/model_fine.py:
--------------------------------------------------------------------------------
1 | # Copied from bark
2 |
3 | """
4 | Much of this code is adapted from Andrej Karpathy's NanoGPT
5 | (https://github.com/karpathy/nanoGPT)
6 | """
7 | from dataclasses import dataclass
8 | import math
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch.nn import functional as F
13 |
14 | from .model import GPT, GPTConfig, MLP
15 |
16 |
17 | class NonCausalSelfAttention(nn.Module):
18 | def __init__(self, config):
19 | super().__init__()
20 | assert config.n_embd % config.n_head == 0
21 | # key, query, value projections for all heads, but in a batch
22 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
23 | # output projection
24 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
25 | # regularization
26 | self.attn_dropout = nn.Dropout(config.dropout)
27 | self.resid_dropout = nn.Dropout(config.dropout)
28 | self.n_head = config.n_head
29 | self.n_embd = config.n_embd
30 | self.dropout = config.dropout
31 | # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
32 | self.flash = (
33 | hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
34 | )
35 |
36 | def forward(self, x):
37 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
38 |
39 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim
40 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
41 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
42 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
43 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
44 |
45 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
46 | if self.flash:
47 | # efficient attention using Flash Attention CUDA kernels
48 | y = torch.nn.functional.scaled_dot_product_attention(
49 | q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False
50 | )
51 | else:
52 | # manual implementation of attention
53 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
54 | att = F.softmax(att, dim=-1)
55 | att = self.attn_dropout(att)
56 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
57 | y = (
58 | y.transpose(1, 2).contiguous().view(B, T, C)
59 | ) # re-assemble all head outputs side by side
60 |
61 | # output projection
62 | y = self.resid_dropout(self.c_proj(y))
63 | return y
64 |
65 |
66 | class FineBlock(nn.Module):
67 | def __init__(self, config):
68 | super().__init__()
69 | self.ln_1 = nn.LayerNorm(config.n_embd)
70 | self.attn = NonCausalSelfAttention(config)
71 | self.ln_2 = nn.LayerNorm(config.n_embd)
72 | self.mlp = MLP(config)
73 |
74 | def forward(self, x):
75 | x = x + self.attn(self.ln_1(x))
76 | x = x + self.mlp(self.ln_2(x))
77 | return x
78 |
79 |
80 | class FineGPT(GPT):
81 | def __init__(self, config):
82 | super().__init__(config)
83 | del self.lm_head
84 | self.config = config
85 | self.n_codes_total = config.n_codes_total
86 | self.transformer = nn.ModuleDict(
87 | dict(
88 | wtes=nn.ModuleList(
89 | [
90 | nn.Embedding(config.input_vocab_size, config.n_embd)
91 | for _ in range(config.n_codes_total)
92 | ]
93 | ),
94 | wpe=nn.Embedding(config.block_size, config.n_embd),
95 | drop=nn.Dropout(config.dropout),
96 | h=nn.ModuleList([FineBlock(config) for _ in range(config.n_layer)]),
97 | ln_f=nn.LayerNorm(config.n_embd),
98 | )
99 | )
100 | self.lm_heads = nn.ModuleList(
101 | [
102 | nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
103 | for _ in range(config.n_codes_given, self.n_codes_total)
104 | ]
105 | )
106 | for i in range(self.n_codes_total - config.n_codes_given):
107 | self.transformer.wtes[i + 1].weight = self.lm_heads[i].weight
108 |
109 | def forward(self, pred_idx, idx):
110 | device = idx.device
111 | b, t, codes = idx.size()
112 | assert (
113 | t <= self.config.block_size
114 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
115 | assert pred_idx > 0, "cannot predict 0th codebook"
116 | assert codes == self.n_codes_total, (b, t, codes)
117 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
118 |
119 | # forward the GPT model itself
120 | tok_embs = [
121 | wte(idx[:, :, i]).unsqueeze(-1) for i, wte in enumerate(self.transformer.wtes)
122 | ] # token embeddings of shape (b, t, n_embd)
123 | tok_emb = torch.cat(tok_embs, dim=-1)
124 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
125 | x = tok_emb[:, :, :, : pred_idx + 1].sum(dim=-1)
126 | x = self.transformer.drop(x + pos_emb)
127 | for block in self.transformer.h:
128 | x = block(x)
129 | x = self.transformer.ln_f(x)
130 | logits = self.lm_heads[pred_idx - self.config.n_codes_given](x)
131 | return logits
132 |
133 | def get_num_params(self, non_embedding=True):
134 | """
135 | Return the number of parameters in the model.
136 | For non-embedding count (default), the position embeddings get subtracted.
137 | The token embeddings would too, except due to the parameter sharing these
138 | params are actually used as weights in the final layer, so we include them.
139 | """
140 | n_params = sum(p.numel() for p in self.parameters())
141 | if non_embedding:
142 | for wte in self.transformer.wtes:
143 | n_params -= wte.weight.numel()
144 | n_params -= self.transformer.wpe.weight.numel()
145 | return n_params
146 |
147 |
148 | @dataclass
149 | class FineGPTConfig(GPTConfig):
150 | n_codes_total: int = 8
151 | n_codes_given: int = 1
152 |
--------------------------------------------------------------------------------
/barkify/datas/__init__.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from torch.utils.data import DataLoader
3 |
4 | from .data import Dataset
5 | from .data import Text2semanticCollateFn, Semantic2coarseCollateFn
6 | stage_collate = [Text2semanticCollateFn, Semantic2coarseCollateFn]
7 |
8 | from .tokenizer import ZHTokenizer, PhonemeTokenizer
9 |
10 | def StageDataloader(params, stage=1, file='train') -> DataLoader:
11 |
12 | tokenizer = PhonemeTokenizer() # TODO: add more tokenizer
13 | dataset = Dataset(file=file, tokenizer=tokenizer, **params.dataset)
14 | collate_fn = functools.partial(stage_collate[int(stage) - 1], **params.collate_fn)
15 |
16 | loader = DataLoader(dataset, collate_fn=collate_fn, **params.dataloader)
17 | return loader
--------------------------------------------------------------------------------
/barkify/datas/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | from os.path import join as pjoin
3 |
4 | import random
5 | import torch
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | from .tokenizer import SplitTokenizer
10 |
11 | class Dataset:
12 | def __init__(self, start_path, file,
13 | tokenizer: SplitTokenizer = None, add_prompt = False
14 | ):
15 |
16 | self.start_path = start_path
17 | self.tokenizer = tokenizer
18 | self.add_prompt = add_prompt # add prompt for semantic or acoustic
19 |
20 | with open(pjoin(start_path, 'meta', file+'.json')) as f:
21 | self._datas = [json.loads(i) for i in f.readlines()]
22 | g2p_already = any([data.get('g2p', False) for data in self._datas])
23 |
24 | if not g2p_already:
25 | print("g2p all texts.")
26 | for data in tqdm(self._datas):
27 | if not data.get('g2p', False):
28 | data['g2p'] = self.tokenizer.g2p(data['text'])
29 |
30 | with open(pjoin(start_path, 'meta', file+'.json'), "w") as f:
31 | for data in self._datas:
32 | f.writelines(json.dumps(data, ensure_ascii=False)+'\n')
33 |
34 | def __getitem__(self, idx):
35 | batch = {}
36 | data = self._datas[idx]
37 |
38 | batch['name'] = data['name']
39 | batch['text'] = self.tokenizer.token2id(data['g2p'])
40 |
41 | semantic_path = pjoin(self.start_path, 'semantic_idx', data['name'])
42 | codec_path = pjoin(self.start_path, 'encodec_idx', data['name'])
43 | batch['semantic'] = torch.from_numpy(np.load(semantic_path))
44 | batch['encodec'] = torch.from_numpy(np.load(codec_path))
45 |
46 | if self.add_prompt:
47 | raise NotImplementedError
48 |
49 | return batch
50 |
51 | def __len__(self):
52 | return len(self._datas)
53 |
54 | def Text2semanticCollateFn(
55 | batches,
56 | text_window=512, semantic_window=512, # set training window size.
57 | text_token_num=210, fixed_length=True, ign_idx=-100, **kwargs
58 | ):
59 | '''
60 |     Vocabulary layout (differs from bark): text(210) + text_pad(1) + eos/semantic_pad(1) + infer(1) + semantic(2048)
61 | returns:
62 | input: [text with pad(512), infer(1), semantic with pad(512), eos(1)]
63 | tgt: [ign_idx(512), ign_idx(1), (semantic, ign_idx)(512), eos(1)]
64 | '''
65 |
66 | text_pad_token, semantic_pad_token = text_token_num, text_token_num + 1
67 | infer_token, text_offset = text_token_num + 2, text_token_num + 3
68 |
69 | semantic, text = [], []
70 | semantic_length, text_length = [], []
71 | names = []
72 |
73 | for batch in batches:
74 | names.append(batch['name'])
75 |
76 | _text = np.array(batch['text'])
77 | if fixed_length:
78 | _text = _text[-text_window:] # TODO: drop long text?
79 | text_length.append(len(_text))
80 |
81 | _text = np.pad(
82 | _text, (0, text_window - len(_text)),
83 | constant_values = text_pad_token
84 | )
85 |
86 |         _semantic = batch['semantic'] + text_offset # unlike bark, we offset the semantic tokens rather than the text tokens.
87 | if fixed_length:
88 | _semantic = _semantic[-semantic_window:]
89 | semantic_length.append(len(_semantic))
90 |
91 | _semantic = np.pad(
92 | _semantic, (0, semantic_window - len(_semantic)),
93 | constant_values = semantic_pad_token
94 | )
95 |
96 | semantic.append(_semantic), text.append(_text)
97 |
98 | text, semantic = torch.from_numpy(np.stack(text, axis=0)), torch.from_numpy(np.stack(semantic, axis=0))
99 | text_length, semantic_length = torch.tensor(text_length), torch.tensor(semantic_length)
100 |
101 | B = text.shape[0]
102 | inputs = torch.cat([text, torch.ones(B, 1) * infer_token, semantic, torch.ones(B, 1) * semantic_pad_token], -1)
103 | tgt = torch.cat([torch.ones(B, text_window + 1) * ign_idx, semantic, torch.ones(B, 1) * semantic_pad_token], -1)
104 |     tgt[:, text_window+1:][(torch.arange(0, semantic_window + 1) > semantic_length[:, None])] = ign_idx
105 |
106 | batch = dict(names = names,
107 | input_ids = inputs.long(), labels = tgt.long(),
108 | semantic_length = semantic_length, text_length = text_length
109 | )
110 | return batch
111 |
112 | def Semantic2coarseCollateFn(
113 | batches, Q_size,
114 | semantic_window=256, semantic_to_coarse_ratio=3,
115 | semantic_token_num=2048, coarse_num=1024, ign_idx=-100,
116 | slice_range = 60, # window size in bark
117 | **kwargs
118 | ):
119 | '''
120 |     Vocabulary layout (follows bark): semantic(2048) + pad(1) + infer(1) + coarse(1024*Q)
121 | returns:
122 | input: [semantic with pad(256), infer(1), coarse with pad(768)]
123 | tgt: [ign_idx(256), ign_idx(1), (coarse, ign_idx)(768)]
124 | '''
125 |
126 | semantic_pad_token = coarse_pad_token = semantic_token_num
127 | infer_token = semantic_token_num + 1
128 |
129 | semantic_window = semantic_window // 2 * 2
130 | acoustic_pad = int(np.ceil((semantic_window * semantic_to_coarse_ratio) / Q_size) * Q_size)
131 | semantic, coarse = [], []
132 | semantic_length, coarse_length = [], []
133 | names = []
134 |
135 | for batch in batches:
136 | names.append(batch['name'])
137 | semantic_start = random.randint(0, np.max([0, len(batch['semantic']) - (semantic_window - slice_range)]))
138 |         semantic_start = (semantic_start // Q_size) * Q_size # align the slice so the coarse stream starts on codebook 0
139 | _semantic = batch['semantic'][semantic_start:semantic_start+semantic_window]
140 | semantic_length.append(len(_semantic))
141 |
142 | _semantic = np.pad(
143 | _semantic, (0, semantic_window - len(_semantic)),
144 | constant_values = semantic_pad_token
145 | )
146 |
147 | coarse_start = semantic_start * semantic_to_coarse_ratio
148 | coarse_end = coarse_start + acoustic_pad
149 |
150 | _coarse = batch['encodec'][:Q_size] + 2 + semantic_pad_token # add pad and infer token
151 | for i in range(1, _coarse.shape[0]):
152 | _coarse[i:] += coarse_num
153 | _coarse = _coarse.T.reshape(-1)[int(coarse_start) : int(coarse_end)]
154 | coarse_length.append(len(_coarse))
155 |
156 | _coarse = np.pad(
157 | _coarse, (0, acoustic_pad - len(_coarse)),
158 | constant_values = coarse_pad_token
159 | )
160 |
161 | semantic.append(_semantic), coarse.append(_coarse)
162 |
163 | semantic, coarse = torch.from_numpy(np.stack(semantic, axis=0)), torch.from_numpy(np.stack(coarse, axis=0))
164 | semantic_length, coarse_length = torch.tensor(semantic_length), torch.tensor(coarse_length)
165 |
166 | inputs = torch.cat([semantic, torch.ones(semantic.shape[0], 1) * infer_token, coarse], -1)
167 | tgt = torch.cat([torch.ones(semantic.shape[0], semantic_window + 1) * ign_idx, coarse], -1)
168 | tgt[(tgt == semantic_pad_token)|(tgt == coarse_pad_token)] = ign_idx
169 |
170 | batch = dict(names = names,
171 | input_ids = inputs.long(), labels = tgt.long(),
172 | semantic_length = semantic_length, coarse_length = coarse_length
173 | )
174 | return batch
175 |
176 | if __name__ == "__main__":
177 | from barkify.datas.tokenizer import ZHTokenizer
178 | dataset = Dataset('/root/intern/bark/barkify/work_env', 'train', ZHTokenizer())
179 | print(dataset[0]['text'])
--------------------------------------------------------------------------------
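
The docstring layout of `Text2semanticCollateFn` is easiest to see on a toy batch. Below is a sketch (assuming `barkify` is importable and the tokenizer dependencies `pypinyin`/`g2p_en` are installed, since `data.py` imports the tokenizer module); the special values follow directly from `text_token_num=210`: text_pad=210, eos/semantic_pad=211, infer=212, semantic offset=213.

```python
# Toy illustration of the stage-1 token layout (sketch, not part of the repo).
import torch
from barkify.datas import Text2semanticCollateFn

toy_batch = [{
    "name": "utt0.npy",
    "text": [3, 7, 9],                               # phoneme ids from the tokenizer
    "semantic": torch.tensor([5, 5, 8]),             # clustered w2v2 indices
    "encodec": torch.zeros(8, 9, dtype=torch.long),  # unused in stage 1
}]
out = Text2semanticCollateFn(toy_batch, text_window=8, semantic_window=8, text_token_num=210)

# input_ids = [text padded with 210] + [infer=212] + [semantic+213 padded with 211] + [eos=211]
print(out["input_ids"].shape)   # torch.Size([1, 18])  (8 + 1 + 8 + 1)
print(out["input_ids"][0, :9])  # tensor([  3,   7,   9, 210, 210, 210, 210, 210, 212])
```
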
/barkify/datas/tokenizer.py:
--------------------------------------------------------------------------------
1 | # TODO: integrate with NeMo tokenizer?
2 | import string
3 | from pypinyin import lazy_pinyin, Style
4 | from g2p_en import G2p
5 |
6 | class SplitTokenizer:
7 | def __init__(self, **kwargs):
8 | self._token2id = {i:idx + 1 for idx, i in enumerate(ENGLISH_PRONUNCIATION_LIST)}
9 | self._id2token = {idx + 1:i for idx, i in enumerate(ENGLISH_PRONUNCIATION_LIST)}
10 |
11 | def __call__(self, text):
12 | text = self.g2p(text)
13 | return self.token2id(text)
14 |
15 | def g2p(self, text):
16 | # text: I'm Tom. -> [i, ', m, , t, o, m, .]
17 | text = text.lower()
18 | return " ".join(["" if i == ' ' else i for i in list(text)])
19 |
20 | def token2id(self, text):
21 | token = []
22 | for i in text.split():
23 | if i == ' ':
24 | token.append(self._token2id[''])
25 | elif self._token2id.get(i, None):
26 | token.append(self._token2id[i])
27 |
28 | return token
29 |
30 | class PhonemeTokenizer(SplitTokenizer):
31 | def __init__(self, **kwargs):
32 | self._g2p = G2p()
33 | self._token2id = {i:idx + 1 for idx, i in enumerate(self._g2p.phonemes)}
34 | self._id2token = {idx + 1:i for idx, i in enumerate(self._g2p.phonemes)}
35 |
36 | def g2p(self, text):
37 | text = self._g2p(text)
38 | return " ".join(["" if i == ' ' else i for i in text])
39 |
40 | class ZHTokenizer(SplitTokenizer):
41 | def __init__(self, **kwargs):
42 |         # A basic English and Chinese G2P tokenizer
43 | self._token2id = {i:idx + 1 for idx, i in enumerate(PINYIN_PRONUNCIATION_LIST)}
44 | self._id2token = {idx + 1:i for idx, i in enumerate(PINYIN_PRONUNCIATION_LIST)}
45 |
46 | def g2p(self, text):
47 | text = text.lower()
48 | initials = lazy_pinyin(text, neutral_tone_with_five=False, style=Style.INITIALS, strict=False)
49 | finals = lazy_pinyin(text, neutral_tone_with_five=False, style=Style.FINALS_TONE3)
50 |
51 | text_phone = []
52 | for _o in zip(initials, finals):
53 | if _o[0] != _o[1] and _o[0] != '':
54 | _o = ['@'+i for i in _o]
55 | text_phone.extend(_o)
56 | elif _o[0] != _o[1] and _o[0] == '':
57 | text_phone.append('@'+_o[1])
58 | else:
59 | text_phone.extend(["" if i == ' ' else i for i in list(_o[0])])
60 |
61 | return " ".join(text_phone)
62 |
63 | ENGLISH_PRONUNCIATION_LIST = list(string.ascii_lowercase) + list(",.?!") + ['']
64 | PINYIN_PRONUNCIATION_LIST = ENGLISH_PRONUNCIATION_LIST + [
65 | # PINYIN_PRONUNCIATION_LIST = [
66 | '@w', '@uo3', '@y', '@i3', '@j', '@ing1', '@b', '@a3', '@o1', '@l', '@a1', '@h', '@ei1', '@e',
67 | '@n', '@iou4', '@sh', '@i4', '@ie2', '@z', '@ai4', '@a4', '@m', '@g', '@ou3', '@t', '@uo2', '@i', '@q',
68 | '@vn2', '@uo1', '@zh', '@i1', '@d', '@ao4', '@uei4', '@a2', '@a', '@e4', '@ing3', '@ei', '@u4', '@uan2',
69 | '@f', '@an4', '@en2', '@c', '@uo4', '@uei1', '@iou1', '@ei2', '@e2', '@r', '@en4', '@eng1', '@e1', '@en1',
70 | '@ou4', '@ang4', '@p', '@eng2', '@ong4', '@u2', '@iang2', '@van1', '@ian1', '@ei4', '@er3', '@ia1', '@ou2',
71 | '@ao3', '@ou1', '@er2', '@s', '@i2', '@v4', '@x', '@ian4', '@ong1', '@uan3', '@uang2', '@ing4', '@ch', '@vn3',
72 | '@uen1', '@ai1', '@an3', '@eng4', '@ing2', '@ve4', '@k', '@ang3', '@en3', '@ai2', '@ian3', '@er4', '@ai3',
73 | '@uai4', '@ian2', '@ao1', '@eng3', '@ia4', '@n2', '@ang1', '@ie3', '@uen3', '@iou3', '@ei3', '@in4', '@v3',
74 | '@uen4', '@an2', '@iang1', '@in1', '@u3', '@ve2', '@e3', '@iang4', '@ia', '@an1', '@in3', '@iao4', '@ang2',
75 | '@vn1', '@iao3', '@u1', '@ie1', '@ie4', '@v2', '@uei2', '@iong1', '@iao1', '@o2', '@uei3', '@in2', '@iong4',
76 | '@ve1', '@uang1', '@iang3', '@uan4', '@iou2', '@en', '@uan1', '@ia2', '@ua1', '@ong3', '@van4', '@van2', '@uang3',
77 | '@iao2', '@ua4', '@ong2', '@uen2', '@iong3', '@er', '@v1', '@uang4', '@ia3', '@ve3', '@ua2', '@van3', '@ao2',
78 | '@o4', '@ua3', '@vn4', '@iong2', '@io1', '@uai1', '@ou', '@uai2', '@ua', '@ueng1', '@o', '@uai3', '@o3', '@uo',
79 | ] + list("、,。")
80 | # ] + list(string.ascii_lowercase) + list("、,。")
81 |
82 | if __name__ == "__main__":
83 | tokenizer = ZHTokenizer()
84 | print(tokenizer.g2p("I'm tom."))
85 | print(tokenizer("I'm tom."))
86 |
87 | tokenizer = PhonemeTokenizer()
88 | print(tokenizer.g2p("I'm tom."))
89 | print(tokenizer("I'm tom."))
90 |
--------------------------------------------------------------------------------
/barkify/pl_model.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 |
6 | import pytorch_lightning as pl
7 | from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup
8 |
9 | from .bark import GPT, GPTConfig
10 |
11 | class NanoGPT(pl.LightningModule):
12 | def __init__(
13 | self,
14 | model_config,
15 | lr,
16 | min_lr,
17 | weight_decay,
18 | warmup_iters,
19 | max_iters,
20 | lr_strategy = 'constant',
21 | **kwargs
22 | ):
23 | super().__init__()
24 | self.save_hyperparameters()
25 | self.criterion = nn.CrossEntropyLoss()
26 | self._create_model()
27 |
28 | def _create_model(self):
29 | gptconf = GPTConfig(**self.hparams.model_config)
30 | self.model = GPT(gptconf)
31 |
32 | def forward(self, x,) -> torch.Tensor:
33 | logits, _ = self.model(x, use_cache = False)
34 | return logits
35 |
36 | @torch.no_grad()
37 | def infer(self, x, *args, **kwargs):
38 | return self.forward(x, *args, **kwargs)
39 |
40 | def optimizer_step(self, *args, **kwargs) -> None:
41 | super().optimizer_step(*args, **kwargs)
42 | if self.optimizers().param_groups[0]['lr'] < self.hparams.min_lr and self.lr_schedulers().last_epoch > self.hparams.warmup_iters:
43 | self.optimizers().param_groups[0]['lr'] = self.hparams.min_lr
44 | self.lr_schedulers().last_epoch -= 1
45 |
46 | def configure_optimizers(self):
47 | optimizer = optim.AdamW(self.model.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
48 |         # Return the scheduler with interval="step" so it is applied per iteration, not per epoch
49 | if self.hparams.lr_strategy == 'constant':
50 | lr_scheduler = get_constant_schedule_with_warmup(
51 | optimizer=optimizer,
52 | num_warmup_steps=self.hparams.warmup_iters,
53 | )
54 | elif self.hparams.lr_strategy == 'cosine':
55 | lr_scheduler = get_cosine_schedule_with_warmup(
56 | optimizer=optimizer,
57 | num_warmup_steps=self.hparams.warmup_iters,
58 | num_training_steps=self.hparams.max_iters,
59 | )
60 | elif self.hparams.lr_strategy == 'linear':
61 | lr_scheduler = get_linear_schedule_with_warmup(
62 | optimizer=optimizer,
63 | num_warmup_steps=self.hparams.warmup_iters,
64 | num_training_steps=self.hparams.max_iters,
65 | )
66 | else:
67 | raise NotImplementedError
68 |
69 | return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}]
70 |
71 | def on_train_start(self):
72 | self.logger.log_hyperparams(self.hparams)
73 |
74 | def training_step(self, batch, batch_idx):
75 | if 'extra' in batch:
76 | logits = self.forward(batch['input_ids'], extra=batch['extra'], merge_context = True)
77 | else:
78 | logits = self.forward(batch['input_ids'])
79 | loss = self.criterion(logits[:,:-1,:].reshape(-1, logits.shape[-1]), batch['labels'][:,1:].reshape(-1))
80 | top1_acc, top10_acc = self.get_top10_acc(logits[:1,:-1,:], batch['labels'][:1,1:])
81 |
82 | self.log("train_loss", loss, prog_bar=True)
83 | self.log("train_1acc", top1_acc)
84 | self.log("train_10acc", top10_acc)
85 |
86 | #self.log("lr", self.lr_schedulers().get_last_lr()[0])
87 | self.log("lr", self.optimizers().param_groups[0]['lr'])
88 |
89 | return loss
90 |
91 | def validation_step(self, batch, batch_idx):
92 | logits = self.forward(batch['input_ids'])
93 | loss = self.criterion(logits[:,:-1,:].reshape(-1, logits.shape[-1]), batch['labels'][:,1:].reshape(-1))
94 | top1_acc, top10_acc = self.get_top10_acc(logits[:32,:-1,:], batch['labels'][:32,1:]) #FIXME: B > 32 may raise error?
95 |
96 | self.log("val_loss", loss)
97 | self.log("val_1acc", top1_acc)
98 | self.log("val_10acc", top10_acc)
99 |
100 | def test_step(self, batch, batch_idx):
101 | raise NotImplementedError
102 |
103 | def get_top10_acc(self, logits, labels):
104 | with torch.no_grad():
105 | top10 = logits.topk(dim=-1, k=10).indices
106 | top1 = logits.topk(dim=-1, k=1).indices
107 | label = labels.unsqueeze(-1)
108 |
109 | total_item = (label != -100).sum()
110 | top10_acc = ((top10 == label).any(-1, keepdim=True) & (label != -100)).sum() / total_item
111 | top1_acc = ((top1 == label).any(-1, keepdim=True) & (label != -100)).sum() / total_item
112 | top10_acc, top1_acc = top10_acc.item(), top1_acc.item()
113 |
114 | return top1_acc, top10_acc
--------------------------------------------------------------------------------
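
A sketch of building the LightningModule directly from the stage-1 config (assuming the repo root is importable); the `optim` block maps onto the constructor arguments, with `gradient_clip` absorbed by `**kwargs` because it is consumed by `pl.Trainer` instead:

```python
# Sketch only: build the stage-1 NanoGPT module from configs/barkify.yaml.
from omegaconf import OmegaConf
from barkify.pl_model import NanoGPT

cfg = OmegaConf.load("configs/barkify.yaml")
model = NanoGPT(model_config=cfg.stage1.model, **cfg.stage1.optim)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
```
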
/barkify/utils.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 |
3 | def Bestckpt(exp_root_dir):
4 | ckpts = glob(f"{exp_root_dir}/lightning_logs/*/checkpoints/*")
5 |     ckpts = sorted(ckpts, key=lambda x: int(x.split("/")[-1].split("=")[1].split("-")[0])) # sort by epoch numerically, not lexicographically
6 | ckpts = ckpts[-1] if len(ckpts) > 0 else None
7 | return ckpts
--------------------------------------------------------------------------------
/configs/barkify.yaml:
--------------------------------------------------------------------------------
1 | name: ljspeech
2 | stage: 1
3 | start_path: ???
4 |
5 | common:
6 | trainer:
7 | accelerator: "gpu"
8 | # strategy: "deepspeed"
9 | # strategy: "ddp_find_unused_parameters_false"
10 | strategy: null
11 | precision: 16 # 32
12 |
13 | ckpt:
14 | mode: "min"
15 | monitor: "val_loss"
16 | save_top_k: 3
17 | every_n_epochs: 10
18 | save_weights_only: false
19 |
20 | stage1:
21 | ############# create dataset #############
22 | tokenizer: null
23 | dataset:
24 | start_path: ${start_path}
25 | add_prompt: false
26 |
27 | collate_fn:
28 | text_window: 256
29 | semantic_window: 768
30 | text_token_num: 210
31 |
32 | dataloader:
33 | batch_size: 128
34 | shuffle: True
35 | num_workers: 32
36 | persistent_workers: True
37 | pin_memory: True
38 |
39 | ############## model params ##############
40 | model:
41 | n_layer: 6
42 | n_head: 4
43 | n_embd: 512
44 | block_size: 1026 # 256(text_window) + 768(semantic) + 2(infer token, eos)
45 | bias: False
46 | dropout: 0.1
47 | input_vocab_size: 2261 # 210(pinyin) + 2048(semantic) + 3(infer, pad_text, pad_semantic)
48 | output_vocab_size: 2261
49 |
50 | ############## optim params ##############
51 | optim:
52 | lr: 1e-4
53 | min_lr: 2e-5
54 | weight_decay: 1e-3
55 | warmup_iters: 1000
56 | max_iters: 10000 # depends on your GPUs.
57 | lr_strategy: 'cosine'
58 | gradient_clip: 1
59 |
60 | stage2:
61 | ############# create dataset #############
62 | tokenizer: null
63 | dataset:
64 | start_path: ${start_path}
65 | add_prompt: false
66 |
67 | collate_fn:
68 |     Q_size: 2 # number of coarse codebooks to predict
69 |     semantic_to_coarse_ratio: 3 # coarse tokens per semantic token (75 Hz Encodec frames x Q_size / 50 Hz semantic)
70 | semantic_window: 256
71 | semantic_token_num: 2048
72 | coarse_num: 1024
73 |     slice_range: 60 # sliding-window shift size used at inference in bark's source code
74 |
75 | dataloader:
76 | batch_size: 128
77 | shuffle: True
78 | num_workers: 32
79 | persistent_workers: True
80 | pin_memory: True
81 |
82 | ############## model params ##############
83 | model:
84 | n_layer: 6
85 | n_head: 4
86 | n_embd: 512
87 | block_size: 1025 # 256(semantic_window) + 256*3(window X ratio) + 1(infer token)
88 | bias: False
89 | dropout: 0.1
90 | input_vocab_size: 4098 # 2048(semantic) + 1024*2(coarse X Q_size) + 2(infer token, padding)
91 | output_vocab_size: 4098
92 |
93 | ############## optim params ##############
94 | optim:
95 | lr: 1e-4
96 | min_lr: 2e-5
97 | weight_decay: 1e-3
98 | warmup_iters: 1000
99 | max_iters: 20000 # depends on your GPUs.
100 | lr_strategy: 'cosine'
101 | gradient_clip: 1
102 |
103 |
104 | hydra:
105 | run:
106 | dir: ${start_path}/${name}/stage_${stage}
--------------------------------------------------------------------------------
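
The inline comments above encode the size arithmetic; if `text_window`, `semantic_window` or `Q_size` change, `block_size` and the vocabulary sizes must move with them. A small sanity-check sketch (the literal `2048` in stage 1 is the semantic vocabulary, which only appears explicitly in the stage-2 section):

```python
# Sanity-check sketch for the derived sizes in configs/barkify.yaml.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/barkify.yaml")
s1, s2 = cfg.stage1, cfg.stage2

# stage 1: text window + semantic window + (infer, eos); text vocab + semantic vocab + 3 special tokens
assert s1.model.block_size == s1.collate_fn.text_window + s1.collate_fn.semantic_window + 2
assert s1.model.input_vocab_size == s1.collate_fn.text_token_num + 2048 + 3

# stage 2: semantic window + interleaved coarse window + infer; semantic vocab + Q_size coarse codebooks + 2 special tokens
assert s2.model.block_size == s2.collate_fn.semantic_window * (1 + s2.collate_fn.semantic_to_coarse_ratio) + 1
assert s2.model.input_vocab_size == s2.collate_fn.semantic_token_num + s2.collate_fn.Q_size * s2.collate_fn.coarse_num + 2
```
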
/infer.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b5becc02",
6 | "metadata": {},
7 | "source": [
8 | "## Infer"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "7a4f55af",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import tqdm\n",
19 | "import torch\n",
20 | "import torch.nn.functional as F\n",
21 | "from IPython.display import Audio\n",
22 | "from scipy.io.wavfile import write as write_wav\n",
23 | "\n",
24 | "import barkify.bark as bark \n",
25 | "from barkify.utils import Bestckpt\n",
26 | "from barkify.bark import create_infer_model\n",
27 | "from barkify.datas import PhonemeTokenizer\n",
28 | "\n",
29 | "from omegaconf import OmegaConf\n",
30 | "x_dict = OmegaConf.load(\"configs/barkify.yaml\")\n",
31 | "\n",
32 | "start_path = \"../work_env\" # your data folder."
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "id": "df6dc569",
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "TEXT_INPUT_LEN = x_dict.stage1.collate_fn.text_window\n",
43 | "TEXT_TOKEN_NUM = x_dict.stage1.collate_fn.text_token_num\n",
44 | "SEMANTIC_EOS_TOKEN, SEMANTIC_INFER_TOKEN = TEXT_TOKEN_NUM+1, TEXT_TOKEN_NUM+2\n",
45 | "\n",
46 | "COARSE_BOOK = x_dict.stage2.collate_fn.Q_size\n",
47 | "SEMANTIC_TOKEN_NUM = x_dict.stage2.collate_fn.semantic_token_num\n",
48 | "SEMANTIC_INPUT_LEN = x_dict.stage2.collate_fn.semantic_window\n",
49 | "CODEC_TOKEN_NUM = x_dict.stage2.collate_fn.coarse_num\n",
50 | "COARSE_INFER_TOKEN = SEMANTIC_TOKEN_NUM + 1\n",
51 | "\n",
52 | "stage1_model = create_infer_model(x_dict.stage1.model).cuda()\n",
53 | "stage2_model = create_infer_model(x_dict.stage2.model).cuda()\n",
54 | "\n",
55 | "tokenizer = PhonemeTokenizer()"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "id": "276b89a2",
62 | "metadata": {},
63 | "outputs": [],
64 | "source": [
65 | "ckpt = torch.load(Bestckpt(f\"{start_path}/{x_dict.name}/stage_1\"))['state_dict']\n",
66 | "stage1_model.load_state_dict({\".\".join(i.split(\"model.\")[1:]):ckpt[i] for i in ckpt})\n",
67 | "\n",
68 | "ckpt = torch.load(Bestckpt(f\"{start_path}/{x_dict.name}/stage_2\"))['state_dict']\n",
69 | "stage2_model.load_state_dict({\".\".join(i.split(\"model.\")[1:]):ckpt[i] for i in ckpt})"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "id": "eb2e1bfe",
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "def generate_stage1(\n",
80 | " x, \n",
81 | " model,\n",
82 | " tempature = 0.60,\n",
83 | " max_steps = 512,\n",
84 | "):\n",
85 | "\n",
86 | " kv_cache = None\n",
87 | "\n",
88 | " x = F.pad(x, (0, TEXT_INPUT_LEN-x.shape[1]), mode='constant', value=TEXT_TOKEN_NUM)\n",
89 | " x = torch.cat([\n",
90 | " x, \n",
91 | " torch.tensor([SEMANTIC_INFER_TOKEN], dtype=x.dtype, device=x.device)[None]\n",
92 | " ], dim=1)\n",
93 | " \n",
94 | " text_len = x.shape[1]\n",
95 | "\n",
96 | " for _ in tqdm.trange(max_steps):\n",
97 | " \n",
98 | " if kv_cache is not None:\n",
99 | " x_input = x[:, [-1]]\n",
100 | " else:\n",
101 | " x_input = x\n",
102 | "\n",
103 | " logits, kv_cache = model(x_input, use_cache=True, past_kv=kv_cache)\n",
104 | "\n",
105 | " relevant_logits = torch.hstack(\n",
106 | " (logits[0, 0, TEXT_TOKEN_NUM+3:], logits[0, 0, [SEMANTIC_EOS_TOKEN]])\n",
107 | " )\n",
108 | "\n",
109 | " probs = F.softmax(relevant_logits / tempature, dim=-1)\n",
110 | " item_next = torch.multinomial(probs, num_samples=1)\n",
111 | "\n",
112 | " if item_next == len(relevant_logits) - 1:\n",
113 | " break\n",
114 | "\n",
115 | " x = torch.cat((x, item_next[None]+TEXT_TOKEN_NUM+3), dim=1)\n",
116 | " \n",
117 | " return x[:, text_len:] - TEXT_TOKEN_NUM - 3"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "id": "5f7f9642",
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "def generate_stage2(\n",
128 | " x, \n",
129 | " model,\n",
130 | " tempature = 0.6,\n",
131 | " max_steps = 768\n",
132 | "):\n",
133 | "\n",
134 | " kv_cache = None\n",
135 | " \n",
136 | " x = F.pad(x, (0, SEMANTIC_INPUT_LEN-x.shape[1]), mode='constant', value=SEMANTIC_TOKEN_NUM)\n",
137 | " x = torch.cat([\n",
138 | " x, \n",
139 | " torch.tensor([COARSE_INFER_TOKEN], dtype=x.dtype, device=x.device)[None]\n",
140 | " ], dim=1)\n",
141 | " \n",
142 | " semantic_len = x.shape[1]\n",
143 | "\n",
144 | " for i in tqdm.trange(max_steps):\n",
145 | "\n",
146 | " Q = i % COARSE_BOOK\n",
147 | " if kv_cache is not None:\n",
148 | " x_input = x[:, [-1]]\n",
149 | " else:\n",
150 | " x_input = x\n",
151 | "\n",
152 | " logits, kv_cache = model(x_input, use_cache=True, past_kv=kv_cache)\n",
153 | " start = SEMANTIC_TOKEN_NUM + 2 + Q * CODEC_TOKEN_NUM\n",
154 | " relevant_logits = logits[0, 0, start : start + CODEC_TOKEN_NUM]\n",
155 | " \n",
156 | " probs = F.softmax(relevant_logits / tempature, dim=-1)\n",
157 | " item_next = torch.multinomial(probs, num_samples=1)\n",
158 | " x = torch.cat((x, item_next[None]+start), dim=1)\n",
159 | " \n",
160 | " output = x[:, semantic_len:]\n",
161 | " for Q in range(COARSE_BOOK):\n",
162 | " output[:, Q::COARSE_BOOK] -= (SEMANTIC_TOKEN_NUM + 2 + Q * CODEC_TOKEN_NUM)\n",
163 | " \n",
164 | " return output.reshape(-1, COARSE_BOOK).T"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": null,
170 | "id": "9b93eeec",
171 | "metadata": {},
172 | "outputs": [],
173 | "source": [
174 | "tgt_text = \"At a given signal, they reenacted the event. Baker's movements were timed with a stopwatch.\"\n",
175 | "\n",
176 | "tokens = tokenizer(tgt_text)\n",
177 | "dummy_tokenized = torch.tensor([tokens]).cuda()\n",
178 | "dummy_semantic = generate_stage1(dummy_tokenized, model=stage1_model)\n",
179 | "dummy_coarse = generate_stage2(dummy_semantic, model=stage2_model)"
180 | ]
181 | },
182 | {
183 | "cell_type": "code",
184 | "execution_count": null,
185 | "id": "9499e142",
186 | "metadata": {
187 | "scrolled": true
188 | },
189 | "outputs": [],
190 | "source": [
191 | "dummy_fine = bark.generate_fine(dummy_coarse.detach().cpu().numpy(), history_prompt=None)\n",
192 | "audio_array = bark.codec_decode(dummy_fine)\n",
193 | "\n",
194 | "# play text in notebook\n",
195 | "Audio(audio_array, rate=24000)\n",
196 | "\n",
197 | "# write_wav(\"bark_generation.wav\", 24000, audio_array)"
198 | ]
199 | }
200 | ],
201 | "metadata": {
202 | "kernelspec": {
203 | "display_name": "Python 3 (ipykernel)",
204 | "language": "python",
205 | "name": "python3"
206 | },
207 | "language_info": {
208 | "codemirror_mode": {
209 | "name": "ipython",
210 | "version": 3
211 | },
212 | "file_extension": ".py",
213 | "mimetype": "text/x-python",
214 | "name": "python",
215 | "nbconvert_exporter": "python",
216 | "pygments_lexer": "ipython3",
217 | "version": "3.8.15"
218 | },
219 | "toc": {
220 | "base_numbering": 1,
221 | "nav_menu": {},
222 | "number_sections": true,
223 | "sideBar": true,
224 | "skip_h1_title": false,
225 | "title_cell": "Table of Contents",
226 | "title_sidebar": "Contents",
227 | "toc_cell": false,
228 | "toc_position": {
229 | "height": "calc(100% - 180px)",
230 | "left": "10px",
231 | "top": "150px",
232 | "width": "165px"
233 | },
234 | "toc_section_display": true,
235 | "toc_window_display": false
236 | }
237 | },
238 | "nbformat": 4,
239 | "nbformat_minor": 5
240 | }
241 |
--------------------------------------------------------------------------------
/process.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "b863990e",
6 | "metadata": {},
7 | "source": [
8 | "# Barkify: an unoffical repo for training 'bark' like generative model\n",
9 | "\n"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "id": "bd7a97cb",
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "from glob import glob\n",
20 | "from tqdm import tqdm\n",
21 | "import os\n",
22 | "import shutil\n",
23 | "import json\n",
24 | "import soundfile as sf\n",
25 | "import subprocess\n",
26 | "import numpy as np\n",
27 | "import re\n",
28 | "\n",
29 | "import random\n",
30 | "from IPython.display import Audio\n",
31 | "import IPython.display as iply\n",
32 | "\n",
33 | "from pqdm.processes import pqdm\n",
34 | "from pqdm.threads import pqdm as pqdmT\n",
35 | "\n",
36 | "import matplotlib.pyplot as plt\n",
37 | "from matplotlib.pyplot import imshow\n",
38 | "\n",
39 | "import torch\n",
40 | "import torch.nn.functional as F\n",
41 | "\n",
42 | "start_path = \"../work_env\"\n",
43 | "start_path += '/'\n",
44 | "\n",
45 | "cuda_devices = [0,1,2,3]*3\n",
46 | "NJOB = len(cuda_devices)"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "id": "cbc9c004",
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "def run_multiprocess(x):\n",
57 | " DEVICE, IJOB = x\n",
58 | " subprocess.call(f\"CUDA_VISIBLE_DEVICES={DEVICE} \"+\n",
59 | " f\"nohup python {os.path.join(start_path, 'tmp', 'temp.py')} {IJOB} \"+\n",
60 | " f\"> {os.path.join(start_path, 'tmp','tmp_'+str(IJOB))} 2>&1\",\n",
61 | " shell=True)\n",
62 | " \n",
63 | "def write_tmp(_script):\n",
64 | " with open(os.path.join(start_path, \"tmp\", \"temp.py\"), \"w\") as f:\n",
65 | " f.write(_script)"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "id": "6397855b",
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "# put all wav to /raw/ folder\n",
76 | "os.makedirs(start_path + \"wavs\", exist_ok=True)\n",
77 | "\n",
78 | "def run(x):\n",
79 | " # \n",
80 | " name = x.replace(\"/raw/\",\"/wavs/\")\n",
81 | " return subprocess.run(f\"ffmpeg -hide_banner -loglevel panic -y -i '{x}' -ac 1 -ar 24000 {name}\",\n",
82 | " shell=True)\n",
83 | "wavs_name = glob(os.path.join(start_path, \"raw/*.wav\"))\n",
84 | "res = pqdm(wavs_name, run, n_jobs=128)"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "id": "a0d74e12",
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "os.makedirs(start_path + \"wavs16k\", exist_ok=True)\n",
95 | "\n",
96 | "def run(x):\n",
97 | " # \n",
98 | " name = x.replace(\"/raw/\",\"/wavs16k/\")\n",
99 | " return subprocess.run(f\"ffmpeg -hide_banner -loglevel panic -y -i '{x}' -ac 1 -ar 16000 {name}\",\n",
100 | " shell=True)\n",
101 | "wavs_name = glob(os.path.join(start_path, \"raw/*.wav\"))\n",
102 | "res = pqdm(wavs_name, run, n_jobs=128)"
103 | ]
104 | },
105 | {
106 | "cell_type": "markdown",
107 | "id": "3acacc46",
108 | "metadata": {},
109 | "source": [
110 | "## Fetch features from wav2vec2-xlsr and encodec\n",
111 | "根据[meta的论文](https://arxiv.org/pdf/2105.11084.pdf),w2v中15-18层都取得了不错的结果.
\n",
112 | "在我们的实验中证实,w2v2-xlsr中第15层与bark的semantic idx的相关性最高.\n",
113 | "\n",
114 | "在[audioLM](https://arxiv.org/pdf/2209.03143.pdf)中,使用了w2v-bert的7层作为特征.在w2v中,相关性也非常高.
\n",
115 | "我们使用第15层作为实验的特征层.\n",
116 | "\n",
117 | "另外,我们测试了bark代码中,coarse2fine的部分.我们发现,coarse和fine均从encodec中直接得到.
\n",
118 | "因此,如果没有特殊需求,不建议重新训练.\n",
119 | "\n",
120 | "1. Fetch w2v2 hiddens.\n",
121 | "2. Cluster them by codes from fairseq.\n",
122 | "3. Dump cluster idxs to numpy files.\n",
123 | "4. Fetch discrete indices from encodec."
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "id": "fd82893a",
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "LAYER = 15 # use 15th layer.\n",
134 | "Hubert = False # use hubert feature or use w2v2-xlsr feature\n",
135 | "\n",
136 | "clusters = 2048 # semantic idx nums.\n",
137 | "dtype = 32 if clusters > 65535 else 16\n",
138 | "\n",
139 | "percent = 0.1 # use 10% datas for clustering. \n",
140 | "n_init = 10 # when this is larger than 10, some error may occur.\n",
141 | "\n",
142 | "params = dict( # params for clusting. \n",
143 | " init='k-means++', max_iter=100, batch_size=10000, \n",
144 | " tol=0, max_no_improvement=100, n_init=n_init, reassignment_ratio=0,\n",
145 | " compute_labels=False, verbose=100\n",
146 | " )\n",
147 | "params['n_clusters'] = clusters"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "id": "c03a7c36",
153 | "metadata": {},
154 | "source": [
155 | "### Fetch semantic hiddens\n"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "id": "60b61553",
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "# process hubert/w2v2 data such that the Hz of semantic idx equals to 50 rather than 49.9\n",
166 | "def process_audio(path):\n",
167 | " wav, fs = sf.read(path)\n",
168 | " safe_length = (wav.shape[0] // 640) * 640 + 160\n",
169 | " wav = wav[:safe_length]\n",
170 | " wav = np.pad(wav, (safe_length - wav.shape[0], 0))\n",
171 | " sf.write(path, wav, fs)\n",
172 | " \n",
173 | "wavs_name = glob(start_path + \"/wavs16k/*.wav\")\n",
174 | "res = pqdmT(wavs_name, process_audio, n_jobs=32)"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": null,
180 | "id": "b3c3ef1f",
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "os.makedirs(start_path + \"tmp\", exist_ok=True)\n",
185 | "os.makedirs(start_path + \"feats\", exist_ok=True)\n",
186 | "\n",
187 | "_script =f'''\n",
188 | "from transformers import AutoProcessor, AutoModelForPreTraining, AutoModel\n",
189 | "if {Hubert}:\n",
190 | " model = AutoModel.from_pretrained(\"TencentGameMate/chinese-hubert-large\")\n",
191 | "else:\n",
192 | " model = AutoModelForPreTraining.from_pretrained(\"facebook/wav2vec2-large-xlsr-53\")\n",
193 | "processor = AutoProcessor.from_pretrained(\"facebook/wav2vec2-base-960h\")\n",
194 | "print(\"downloaded!\")\n",
195 | "'''\n",
196 | "write_tmp(_script) # download it.\n",
197 | "run_multiprocess((0, 0))\n",
198 | " \n",
199 | "_script += f'''\n",
200 | "import os\n",
201 | "import sys\n",
202 | "import torch\n",
203 | "import torch.nn.functional as F\n",
204 | "import numpy as np\n",
205 | "import soundfile as sf\n",
206 | "\n",
207 | "from tqdm import tqdm\n",
208 | "from glob import glob\n",
209 | "\n",
210 | "device = 'cuda:0'\n",
211 | "model = model.to(device)\n",
212 | "\n",
213 | "start_path = '{start_path}'\n",
214 | "NJOB={NJOB}\n",
215 | "meta = glob(start_path+'/wavs16k/*.wav')\n",
216 | "slice_len = (len(meta) + NJOB - 1) // NJOB\n",
217 | "meta = meta[int(sys.argv[1])*slice_len : (int(sys.argv[1])+1)*slice_len]\n",
218 | "\n",
219 | "for _dir in tqdm(meta):\n",
220 | " audio, fs = sf.read(_dir)\n",
221 | " assert fs == 16000\n",
222 | " inputs = processor(audio, sampling_rate=16000, return_tensors=\"pt\")\n",
223 | " for i in inputs:\n",
224 | " inputs[i] = inputs[i].cuda()\n",
225 | " with torch.no_grad():\n",
226 | " hidden = model(**inputs, output_hidden_states=True)['hidden_states'][{LAYER} - 1]\n",
227 | " hidden = F.layer_norm(hidden, hidden.shape)\n",
228 | " np.save(_dir.replace(\"wavs16k\",\"feats\").replace(\".wav\",\"\"), hidden[0].cpu().numpy())\n",
229 | "\n",
230 | "print(\"Finish!\")\n",
231 | "'''"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": null,
237 | "id": "02f292ef",
238 | "metadata": {},
239 | "outputs": [],
240 | "source": [
241 | "write_tmp(_script)\n",
242 | "res = pqdm(list(zip(cuda_devices, list(range(NJOB)))), run_multiprocess, n_jobs=NJOB)"
243 | ]
244 | },
245 | {
246 | "cell_type": "markdown",
247 | "id": "e1cd99b9",
248 | "metadata": {},
249 | "source": [
250 | "### Clustering"
251 | ]
252 | },
253 | {
254 | "cell_type": "code",
255 | "execution_count": null,
256 | "id": "ced8ded9",
257 | "metadata": {},
258 | "outputs": [],
259 | "source": [
260 | "# Basicly from: https://github.com/facebookresearch/fairseq/tree/main/examples/hubert/simple_kmeans\n",
261 | "# TODO: It seems pretty slow. Maybe try faiss or conduct PCA before clustering?\n",
262 | "\n",
263 | "os.makedirs(start_path + \"tmp\", exist_ok=True)\n",
264 | "os.makedirs(start_path + \"assets\", exist_ok=True)\n",
265 | "\n",
266 | "_script = f'''\n",
267 | "import os\n",
268 | "import sys\n",
269 | "import numpy as np\n",
270 | "import random\n",
271 | "\n",
272 | "from tqdm import tqdm\n",
273 | "from glob import glob\n",
274 | "import joblib\n",
275 | "\n",
276 | "from sklearn.cluster import MiniBatchKMeans\n",
277 | "\n",
278 | "params = {params}\n",
279 | "kmeans = MiniBatchKMeans(**params)\n",
280 | "\n",
281 | "start_path = '{start_path}'\n",
282 | "meta = glob(start_path+'/feats/*.npy')\n",
283 | "random.shuffle(meta)\n",
284 | "meta = meta[ : int(len(meta)*{percent})]\n",
285 | "meta = np.concatenate(\n",
286 | " [np.load(i) for i in meta], axis = 0\n",
287 | ")\n",
288 | "print(\"concated.\")\n",
289 | "\n",
290 | "kmeans.fit(meta)\n",
291 | "joblib.dump(kmeans, start_path + \"assets/km_model.joblib\")\n",
292 | "\n",
293 | "inertia = -kmeans.score(meta) / len(meta)\n",
294 | "print(\"total intertia: %.5f\", inertia)\n",
295 | "print(\"Finish!\")\n",
296 | "'''"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": null,
302 | "id": "6594087e",
303 | "metadata": {},
304 | "outputs": [],
305 | "source": [
306 | "write_tmp(_script)\n",
307 | "\n",
308 | "# without thread limit, some error may occur.\n",
309 | "!echo OPENBLAS_NUM_THREADS=16 OMP_NUM_THREADS=16 python {start_path + '/tmp/temp.py'}"
310 | ]
311 | },
312 | {
313 | "cell_type": "markdown",
314 | "id": "9396094f",
315 | "metadata": {},
316 | "source": [
317 | "### Infer semantic indices\n"
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": null,
323 | "id": "cd16049f",
324 | "metadata": {},
325 | "outputs": [],
326 | "source": [
327 | "os.makedirs(start_path + \"tmp\", exist_ok=True)\n",
328 | "os.makedirs(start_path + \"semantic_idx\", exist_ok=True)\n",
329 | "\n",
330 | "_script = f'''\n",
331 | "import os\n",
332 | "import sys\n",
333 | "import numpy as np\n",
334 | "import random\n",
335 | "\n",
336 | "from tqdm import tqdm\n",
337 | "from glob import glob\n",
338 | "import joblib\n",
339 | "\n",
340 | "class ApplyKmeans(object):\n",
341 | " def __init__(self, km_path):\n",
342 | " self.km_model = joblib.load(km_path)\n",
343 | " self.C_np = self.km_model.cluster_centers_.transpose()\n",
344 | " self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)\n",
345 | "\n",
346 | " def __call__(self, x):\n",
347 | " dist = (\n",
348 | " (x ** 2).sum(1, keepdims=True)\n",
349 | " - 2 * np.matmul(x, self.C_np)\n",
350 | " + self.Cnorm_np\n",
351 | " )\n",
352 | " return np.argmin(dist, axis=1)\n",
353 | " \n",
354 | "\n",
355 | "start_path = '{start_path}'\n",
356 | "NJOB={NJOB}\n",
357 | "meta = glob(start_path+'/feats/*.npy')\n",
358 | "slice_len = (len(meta) + NJOB - 1) // NJOB\n",
359 | "meta = meta[int(sys.argv[1])*slice_len : (int(sys.argv[1])+1)*slice_len]\n",
360 | "\n",
361 | "apply_kmeans = ApplyKmeans(start_path + '/assets/km_model.joblib')\n",
362 | "\n",
363 | "for _dir in tqdm(meta):\n",
364 | " _idxs = apply_kmeans(np.load(_dir)).astype(np.int{dtype})\n",
365 | " np.save(_dir.replace(\"feats\",\"semantic_idx\"), _idxs)\n",
366 | " \n",
367 | "print(\"Finish!\")\n",
368 | "'''"
369 | ]
370 | },
371 | {
372 | "cell_type": "code",
373 | "execution_count": null,
374 | "id": "9d96be9b",
375 | "metadata": {},
376 | "outputs": [],
377 | "source": [
378 | "write_tmp(_script)\n",
379 | "res = pqdm(list(zip(cuda_devices, list(range(NJOB)))), run_multiprocess, n_jobs=NJOB) "
380 | ]
381 | },
382 | {
383 | "cell_type": "markdown",
384 | "id": "40a05724",
385 | "metadata": {},
386 | "source": [
387 | "### Fetch discrete indices from encodec\n"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": null,
393 | "id": "30a2ac29",
394 | "metadata": {},
395 | "outputs": [],
396 | "source": [
397 | "os.makedirs(start_path + \"tmp\", exist_ok=True)\n",
398 | "os.makedirs(start_path + \"encodec_idx\", exist_ok=True)\n",
399 | "\n",
400 | "_script ='''\n",
401 | "from encodec import EncodecModel\n",
402 | "from encodec.utils import convert_audio\n",
403 | "\n",
404 | "model = EncodecModel.encodec_model_24khz()\n",
405 | "model.set_target_bandwidth(6.0)\n",
406 | "print(\"downloaded!\")\n",
407 | "'''\n",
408 | "write_tmp(_script) # download it.\n",
409 | "run_multiprocess((0, 0))\n",
410 | " \n",
411 | "_script += f'''\n",
412 | "import os\n",
413 | "import sys\n",
414 | "import torch\n",
415 | "import numpy as np\n",
416 | "import torchaudio\n",
417 | "\n",
418 | "from tqdm import tqdm\n",
419 | "from glob import glob\n",
420 | "\n",
421 | "device = 'cuda:0'\n",
422 | "model = model.to(device)\n",
423 | "\n",
424 | "start_path = '{start_path}'\n",
425 | "NJOB={NJOB}\n",
426 | "meta = glob(start_path+'/wavs/*.wav')\n",
427 | "slice_len = (len(meta) + NJOB - 1) // NJOB\n",
428 | "meta = meta[int(sys.argv[1])*slice_len : (int(sys.argv[1])+1)*slice_len]\n",
429 | "\n",
430 | "for _dir in tqdm(meta):\n",
431 | " wav, sr = torchaudio.load(_dir)\n",
432 | " wav = wav[:, :wav.shape[-1] - int(160*1.5)]\n",
433 | " # wav = convert_audio(wav, sr, model.sample_rate, model.channels)\n",
434 | " wav = wav.unsqueeze(0).cuda()\n",
435 | " with torch.no_grad():\n",
436 | " encoded_frames = model.encode(wav)[0][0][0]\n",
437 | " np.save(_dir.replace(\"wavs\",\"encodec_idx\").replace(\".wav\",\"\"), encoded_frames.cpu().numpy().astype(np.int16))\n",
438 | "\n",
439 | "print(\"Finish!\")\n",
440 | "'''"
441 | ]
442 | },
443 | {
444 | "cell_type": "code",
445 | "execution_count": null,
446 | "id": "5741ee0c",
447 | "metadata": {},
448 | "outputs": [],
449 | "source": [
450 | "write_tmp(_script)\n",
451 | "res = pqdm(list(zip(cuda_devices, list(range(NJOB)))), run_multiprocess, n_jobs=NJOB) "
452 | ]
453 | },
454 | {
455 | "cell_type": "markdown",
456 | "id": "85b774eb",
457 | "metadata": {},
458 | "source": [
459 | "## Prepare dataset"
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "execution_count": null,
465 | "id": "3bf05fff",
466 | "metadata": {},
467 | "outputs": [],
468 | "source": [
469 | "# min_length = 2 # at least 2 seconds.\n",
470 | "# _min_length = int(np.floor(min_length * 50))\n",
471 | "for_eval = 128\n",
472 | "\n",
473 | "os.makedirs(start_path + \"meta\", exist_ok=True)"
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": null,
479 | "id": "acb150c3",
480 | "metadata": {},
481 | "outputs": [],
482 | "source": [
483 | "import pandas as pd\n",
484 | "import json\n",
485 | "datas = pd.read_csv(start_path+\"meta/metadata.csv\",sep=\"|\", header=None)\n",
486 | "datas = datas.dropna()\n",
487 | "datas = datas.values\n",
488 | "np.random.shuffle(datas)\n",
489 | "\n",
490 | "with open(start_path + \"meta/train.json\",\"w\") as f:\n",
491 | " for i in datas[:-for_eval]:\n",
492 | " line = json.dumps({\"name\": i[0] +\".npy\", \"text\": i[1]})\n",
493 | " f.writelines(line+\"\\n\")\n",
494 | "\n",
495 | "with open(start_path + \"meta/eval.json\",\"w\") as f:\n",
496 | " for i in datas[-for_eval:]:\n",
497 | " line = json.dumps({\"name\": i[0] +\".npy\", \"text\": i[1]})\n",
498 | " f.writelines(line+\"\\n\")"
499 | ]
500 | }
501 | ],
502 | "metadata": {
503 | "kernelspec": {
504 | "display_name": "Python 3 (ipykernel)",
505 | "language": "python",
506 | "name": "python3"
507 | },
508 | "language_info": {
509 | "codemirror_mode": {
510 | "name": "ipython",
511 | "version": 3
512 | },
513 | "file_extension": ".py",
514 | "mimetype": "text/x-python",
515 | "name": "python",
516 | "nbconvert_exporter": "python",
517 | "pygments_lexer": "ipython3",
518 | "version": "3.8.15"
519 | },
520 | "toc": {
521 | "base_numbering": 1,
522 | "nav_menu": {},
523 | "number_sections": true,
524 | "sideBar": true,
525 | "skip_h1_title": false,
526 | "title_cell": "Table of Contents",
527 | "title_sidebar": "Contents",
528 | "toc_cell": false,
529 | "toc_position": {
530 | "height": "calc(100% - 180px)",
531 | "left": "10px",
532 | "top": "150px",
533 | "width": "181.663px"
534 | },
535 | "toc_section_display": true,
536 | "toc_window_display": false
537 | }
538 | },
539 | "nbformat": 4,
540 | "nbformat_minor": 5
541 | }
542 |
--------------------------------------------------------------------------------
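
After the pipeline above, each utterance has a semantic index stream at roughly 50 Hz and an Encodec index matrix at roughly 75 Hz, which together with `Q_size: 2` is where `semantic_to_coarse_ratio: 3` in the config comes from. A quick check sketch (`work_env/` and `utt.npy` are placeholders for your data folder and any processed utterance):

```python
# Sketch: verify the ~2:3 semantic/Encodec frame alignment for one processed utterance.
import numpy as np

semantic = np.load("work_env/semantic_idx/utt.npy")  # ~50 Hz cluster ids, shape (T_sem,)
encodec = np.load("work_env/encodec_idx/utt.npy")    # ~75 Hz codes, shape (n_q, T_enc)

ratio = encodec.shape[-1] / semantic.shape[0]
print(f"semantic={semantic.shape[0]}, encodec={encodec.shape[-1]}, ratio={ratio:.2f} (expect ~1.5)")
```
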
/requirements.txt:
--------------------------------------------------------------------------------
1 | funcy
2 | numpy
3 | scipy
4 | tqdm
5 | torch==1.13.0
6 | pytorch_lightning==1.9.0
7 | encodec
8 | transformers
9 | omegaconf
10 | hydra-core
11 | pypinyin
12 | g2p-en
13 | # used by the preprocessing/inference notebooks:
14 | soundfile
15 | torchaudio==0.13.0
16 | scikit-learn
17 | joblib
18 | pandas
19 | pqdm
20 | matplotlib
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import hydra
4 | from glob import glob
5 |
6 | import pytorch_lightning as pl
7 | from pytorch_lightning.callbacks import ModelCheckpoint
8 |
9 | from barkify.datas import StageDataloader
10 | from barkify.pl_model import NanoGPT
11 | from barkify.utils import Bestckpt
12 |
13 | @hydra.main(config_path='configs', config_name='barkify')
14 | def main(cfg=None):
15 |
16 | exp_name = f'stage_{cfg.stage}'
17 | exp_cfg = cfg[f'stage{cfg.stage}']
18 | exp_root_dir = f'{cfg.start_path}/{cfg.name}/{exp_name}'
19 |
20 |     # define data loaders
21 | train_loader = StageDataloader(exp_cfg, cfg.stage, 'train')
22 | val_loader = StageDataloader(exp_cfg, cfg.stage, 'eval')
23 |
24 | # define model
25 | model = NanoGPT(model_config = exp_cfg.model, **exp_cfg.optim)
26 |
27 | # define trainer
28 | trainer = pl.Trainer(
29 | default_root_dir=exp_root_dir,
30 | callbacks = ModelCheckpoint(**cfg.common.ckpt),
31 | max_steps = exp_cfg.optim.max_iters,
32 | gradient_clip_val = exp_cfg.optim.gradient_clip,
33 | **cfg.common.trainer
34 | )
35 | trainer.logger._default_hp_metric = None # Optional logging argument that we don't need
36 |
37 | # load best ckpt
38 | ckpt = Bestckpt(exp_root_dir)
39 | trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt)
40 |
41 | if __name__ == '__main__':
42 | torch.set_float32_matmul_precision('medium')
43 | main()
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
--------------------------------------------------------------------------------
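
`trainer.py` is a plain Hydra entry point, so the data folder and stage are selected with standard Hydra overrides, for example `python trainer.py start_path=/path/to/work_env stage=1` for the text-to-semantic model and then `stage=2` for the semantic-to-coarse model (the path is a placeholder). Outputs land under `${start_path}/${name}/stage_${stage}` as configured in `configs/barkify.yaml`, which is also where `Bestckpt` looks for an existing checkpoint to resume from.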