├── .gitignore ├── README.md ├── barkify ├── bark │ ├── __init__.py │ ├── api.py │ ├── generation.py │ ├── model.py │ └── model_fine.py ├── datas │ ├── __init__.py │ ├── data.py │ └── tokenizer.py ├── pl_model.py └── utils.py ├── configs └── barkify.yaml ├── infer.ipynb ├── process.ipynb ├── requirements.txt └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | runs/ 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # datasets & checkpoints 134 | work_env/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # barkify 2 | Barkify: an unofficial repo for training Bark, a text-prompted generative audio model by suno-ai. 3 | 4 | Bark has two GPT-style models, which are compatible with prompting and other tricks from NLP. Bark achieves great real-world TTS results, but its repo doesn't provide a training recipe. We want to run experiments with and train this model, so here we release our basic training code, which may serve as a training guide for the open-source community. 5 | 6 | ## Process dataset 7 | We run our experiments on LJSpeech. Follow the instructions in `process.ipynb`.
8 | For Chinese, we tested on a famous streamer named `峰哥亡命天涯`. The results are acceptable but worse than those of our other TTS repo. 9 | For English, we tested on the LibriTTS dataset. It works well, and the basic items in our roadmap have been verified. 10 | 11 | ## Training 12 | Stage 1 maps text to semantic tokens, and stage 2 maps semantic tokens to acoustic tokens.
13 | Configure the parameters in `configs/barkify.yaml`. We use one A100 to train our model (both S1 & S2). 14 | ``` 15 | # training stage 1 or 2 16 | python trainer.py start_path=/path/to/your/work_env stage=1 name= 17 | python trainer.py start_path=/path/to/your/work_env stage=2 name= 18 | ``` 19 | 20 | ## Inference 21 | Use `infer.ipynb` directly and follow the instructions to run inference with your model. 22 | 23 | ## Roadmap 24 | We have already achieved the following items and will release our code soon. 25 | - [x] Construct basic training code for a Bark-like generative model 26 | - [x] Test the single-speaker scenario 27 | - [x] Test the multi-speaker scenario 28 | - [x] Test speaker semantic prompting 29 | - [x] Test speech/audio acoustic prompting 30 | - [x] Test variable-length data (as we use a fixed length now) 31 | 32 | The following items are data-hungry or require massive GPU resources.
33 | So we welcome sponsors or collaborators who can help us complete them.
34 | You could contact us by QQ: 3284494602 or email us at 3284494602@qq.com 35 | 36 | - [ ] Long-form generation(which may be longer than 1min.) 37 | - [ ] Support more language(especially for ZH) 38 | - [ ] Paralanguage modeling in the text input 39 | - [ ] Speaker generation by text prompts 40 | - [ ] Emotion/Timbre/Rhythm controlling by text/acoustic prompts 41 | - [ ] Add/Remove background noise(which might be important for downstream tasks) 42 | 43 | ## Appreciation 44 | - [bark](https://github.com/suno-ai/bark/) is a transformer-based text-to-audio model. 45 | - [Vall-E](https://github.com/lifeiteng/vall-e) is an unofficial PyTorch implementation of VALL-E. 46 | -------------------------------------------------------------------------------- /barkify/bark/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import GPT, GPTConfig 2 | from .model_fine import FineGPT, FineGPTConfig 3 | 4 | from .generation import SAMPLE_RATE, preload_models 5 | from .generation import generate_fine, codec_decode 6 | 7 | def create_infer_model(model_config): 8 | gptconf = GPTConfig(**model_config) 9 | return GPT(gptconf).eval() -------------------------------------------------------------------------------- /barkify/bark/api.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union 2 | 3 | import numpy as np 4 | 5 | from .generation import codec_decode, generate_coarse, generate_fine, generate_text_semantic 6 | 7 | 8 | def text_to_semantic( 9 | text: str, 10 | history_prompt: Optional[Union[Dict, str]] = None, 11 | temp: float = 0.7, 12 | silent: bool = False, 13 | ): 14 | """Generate semantic array from text. 
15 | 16 | Args: 17 | text: text to be turned into audio 18 | history_prompt: history choice for audio cloning 19 | temp: generation temperature (1.0 more diverse, 0.0 more conservative) 20 | silent: disable progress bar 21 | 22 | Returns: 23 | numpy semantic array to be fed into `semantic_to_waveform` 24 | """ 25 | x_semantic = generate_text_semantic( 26 | text, 27 | history_prompt=history_prompt, 28 | temp=temp, 29 | silent=silent, 30 | use_kv_caching=True 31 | ) 32 | return x_semantic 33 | 34 | 35 | def semantic_to_waveform( 36 | semantic_tokens: np.ndarray, 37 | history_prompt: Optional[Union[Dict, str]] = None, 38 | temp: float = 0.7, 39 | silent: bool = False, 40 | output_full: bool = False, 41 | ): 42 | """Generate audio array from semantic input. 43 | 44 | Args: 45 | semantic_tokens: semantic token output from `text_to_semantic` 46 | history_prompt: history choice for audio cloning 47 | temp: generation temperature (1.0 more diverse, 0.0 more conservative) 48 | silent: disable progress bar 49 | output_full: return full generation to be used as a history prompt 50 | 51 | Returns: 52 | numpy audio array at sample frequency 24khz 53 | """ 54 | coarse_tokens = generate_coarse( 55 | semantic_tokens, 56 | history_prompt=history_prompt, 57 | temp=temp, 58 | silent=silent, 59 | use_kv_caching=True 60 | ) 61 | fine_tokens = generate_fine( 62 | coarse_tokens, 63 | history_prompt=history_prompt, 64 | temp=0.5, 65 | ) 66 | audio_arr = codec_decode(fine_tokens) 67 | if output_full: 68 | full_generation = { 69 | "semantic_prompt": semantic_tokens, 70 | "coarse_prompt": coarse_tokens, 71 | "fine_prompt": fine_tokens, 72 | } 73 | return full_generation, audio_arr 74 | return audio_arr 75 | 76 | 77 | def save_as_prompt(filepath, full_generation): 78 | assert(filepath.endswith(".npz")) 79 | assert(isinstance(full_generation, dict)) 80 | assert("semantic_prompt" in full_generation) 81 | assert("coarse_prompt" in full_generation) 82 | assert("fine_prompt" in full_generation) 83 
| np.savez(filepath, **full_generation) 84 | 85 | 86 | def generate_audio( 87 | text: str, 88 | history_prompt: Optional[Union[Dict, str]] = None, 89 | text_temp: float = 0.7, 90 | waveform_temp: float = 0.7, 91 | silent: bool = False, 92 | output_full: bool = False, 93 | ): 94 | """Generate audio array from input text. 95 | 96 | Args: 97 | text: text to be turned into audio 98 | history_prompt: history choice for audio cloning 99 | text_temp: generation temperature (1.0 more diverse, 0.0 more conservative) 100 | waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative) 101 | silent: disable progress bar 102 | output_full: return full generation to be used as a history prompt 103 | 104 | Returns: 105 | numpy audio array at sample frequency 24khz 106 | """ 107 | semantic_tokens = text_to_semantic( 108 | text, 109 | history_prompt=history_prompt, 110 | temp=text_temp, 111 | silent=silent, 112 | ) 113 | out = semantic_to_waveform( 114 | semantic_tokens, 115 | history_prompt=history_prompt, 116 | temp=waveform_temp, 117 | silent=silent, 118 | output_full=output_full, 119 | ) 120 | if output_full: 121 | full_generation, audio_arr = out 122 | return full_generation, audio_arr 123 | else: 124 | audio_arr = out 125 | return audio_arr 126 | -------------------------------------------------------------------------------- /barkify/bark/generation.py: -------------------------------------------------------------------------------- 1 | # Copy from bark 2 | 3 | import contextlib 4 | import gc 5 | import os 6 | import re 7 | 8 | from encodec import EncodecModel 9 | import funcy 10 | import logging 11 | import numpy as np 12 | from scipy.special import softmax 13 | import torch 14 | import torch.nn.functional as F 15 | import tqdm 16 | from transformers import BertTokenizer 17 | from huggingface_hub import hf_hub_download 18 | 19 | from .model import GPTConfig, GPT 20 | from .model_fine import FineGPT, FineGPTConfig 21 | 22 | if ( 23 | 
torch.cuda.is_available() and 24 | hasattr(torch.cuda, "amp") and 25 | hasattr(torch.cuda.amp, "autocast") and 26 | hasattr(torch.cuda, "is_bf16_supported") and 27 | torch.cuda.is_bf16_supported() 28 | ): 29 | autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16) 30 | else: 31 | @contextlib.contextmanager 32 | def autocast(): 33 | yield 34 | 35 | 36 | # hold models in global scope to lazy load 37 | global models 38 | models = {} 39 | 40 | global models_devices 41 | models_devices = {} 42 | 43 | 44 | CONTEXT_WINDOW_SIZE = 1024 45 | 46 | SEMANTIC_RATE_HZ = 49.9 47 | SEMANTIC_VOCAB_SIZE = 10_000 48 | 49 | CODEBOOK_SIZE = 1024 50 | N_COARSE_CODEBOOKS = 2 51 | N_FINE_CODEBOOKS = 8 52 | COARSE_RATE_HZ = 75 53 | 54 | SAMPLE_RATE = 24_000 55 | 56 | 57 | SUPPORTED_LANGS = [ 58 | ("English", "en"), 59 | ("German", "de"), 60 | ("Spanish", "es"), 61 | ("French", "fr"), 62 | ("Hindi", "hi"), 63 | ("Italian", "it"), 64 | ("Japanese", "ja"), 65 | ("Korean", "ko"), 66 | ("Polish", "pl"), 67 | ("Portuguese", "pt"), 68 | ("Russian", "ru"), 69 | ("Turkish", "tr"), 70 | ("Chinese", "zh"), 71 | ] 72 | 73 | ALLOWED_PROMPTS = {"announcer"} 74 | for _, lang in SUPPORTED_LANGS: 75 | for prefix in ("", f"v2{os.path.sep}"): 76 | for n in range(10): 77 | ALLOWED_PROMPTS.add(f"{prefix}{lang}_speaker_{n}") 78 | 79 | 80 | logger = logging.getLogger(__name__) 81 | 82 | 83 | CUR_PATH = os.path.dirname(os.path.abspath(__file__)) 84 | 85 | 86 | default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache") 87 | CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0") 88 | 89 | 90 | def _cast_bool_env_var(s): 91 | return s.lower() in ('true', '1', 't') 92 | 93 | 94 | USE_SMALL_MODELS = _cast_bool_env_var(os.environ.get("SUNO_USE_SMALL_MODELS", "False")) 95 | GLOBAL_ENABLE_MPS = _cast_bool_env_var(os.environ.get("SUNO_ENABLE_MPS", "False")) 96 | OFFLOAD_CPU = _cast_bool_env_var(os.environ.get("SUNO_OFFLOAD_CPU", "False")) 97 | 98 | 99 | 
REMOTE_MODEL_PATHS = { 100 | "text_small": { 101 | "repo_id": "suno/bark", 102 | "file_name": "text.pt", 103 | }, 104 | "coarse_small": { 105 | "repo_id": "suno/bark", 106 | "file_name": "coarse.pt", 107 | }, 108 | "fine_small": { 109 | "repo_id": "suno/bark", 110 | "file_name": "fine.pt", 111 | }, 112 | "text": { 113 | "repo_id": "suno/bark", 114 | "file_name": "text_2.pt", 115 | }, 116 | "coarse": { 117 | "repo_id": "suno/bark", 118 | "file_name": "coarse_2.pt", 119 | }, 120 | "fine": { 121 | "repo_id": "suno/bark", 122 | "file_name": "fine_2.pt", 123 | }, 124 | } 125 | 126 | 127 | if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cuda.is_available(): 128 | logger.warning( 129 | "torch version does not support flash attention. You will get faster" + 130 | " inference speed by upgrade torch to newest nightly version." 131 | ) 132 | 133 | 134 | def _grab_best_device(use_gpu=True): 135 | if torch.cuda.device_count() > 0 and use_gpu: 136 | device = "cuda" 137 | elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS: 138 | device = "mps" 139 | else: 140 | device = "cpu" 141 | return device 142 | 143 | 144 | def _get_ckpt_path(model_type, use_small=False): 145 | key = model_type 146 | if use_small or USE_SMALL_MODELS: 147 | key += "_small" 148 | return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"]) 149 | 150 | 151 | def _download(from_hf_path, file_name): 152 | os.makedirs(CACHE_DIR, exist_ok=True) 153 | hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR) 154 | 155 | 156 | class InferenceContext: 157 | def __init__(self, benchmark=False): 158 | # we can't expect inputs to be the same length, so disable benchmarking by default 159 | self._chosen_cudnn_benchmark = benchmark 160 | self._cudnn_benchmark = None 161 | 162 | def __enter__(self): 163 | self._cudnn_benchmark = torch.backends.cudnn.benchmark 164 | torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark 165 | 166 | def 
__exit__(self, exc_type, exc_value, exc_traceback): 167 | torch.backends.cudnn.benchmark = self._cudnn_benchmark 168 | 169 | 170 | if torch.cuda.is_available(): 171 | torch.backends.cuda.matmul.allow_tf32 = True 172 | torch.backends.cudnn.allow_tf32 = True 173 | 174 | 175 | @contextlib.contextmanager 176 | def _inference_mode(): 177 | with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast(): 178 | yield 179 | 180 | 181 | def _clear_cuda_cache(): 182 | if torch.cuda.is_available(): 183 | torch.cuda.empty_cache() 184 | torch.cuda.synchronize() 185 | 186 | 187 | def clean_models(model_key=None): 188 | global models 189 | model_keys = [model_key] if model_key is not None else models.keys() 190 | for k in model_keys: 191 | if k in models: 192 | del models[k] 193 | _clear_cuda_cache() 194 | gc.collect() 195 | 196 | 197 | def _load_model(ckpt_path, device, use_small=False, model_type="text"): 198 | if model_type == "text": 199 | ConfigClass = GPTConfig 200 | ModelClass = GPT 201 | elif model_type == "coarse": 202 | ConfigClass = GPTConfig 203 | ModelClass = GPT 204 | elif model_type == "fine": 205 | ConfigClass = FineGPTConfig 206 | ModelClass = FineGPT 207 | else: 208 | raise NotImplementedError() 209 | model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type 210 | model_info = REMOTE_MODEL_PATHS[model_key] 211 | if not os.path.exists(ckpt_path): 212 | logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.") 213 | _download(model_info["repo_id"], model_info["file_name"]) 214 | checkpoint = torch.load(ckpt_path, map_location=device) 215 | # this is a hack 216 | model_args = checkpoint["model_args"] 217 | if "input_vocab_size" not in model_args: 218 | model_args["input_vocab_size"] = model_args["vocab_size"] 219 | model_args["output_vocab_size"] = model_args["vocab_size"] 220 | del model_args["vocab_size"] 221 | gptconf = ConfigClass(**checkpoint["model_args"]) 222 | model = ModelClass(gptconf) 223 | 
state_dict = checkpoint["model"] 224 | # fixup checkpoint 225 | unwanted_prefix = "_orig_mod." 226 | for k, v in list(state_dict.items()): 227 | if k.startswith(unwanted_prefix): 228 | state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) 229 | extra_keys = set(state_dict.keys()) - set(model.state_dict().keys()) 230 | extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")]) 231 | missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) 232 | missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")]) 233 | if len(extra_keys) != 0: 234 | raise ValueError(f"extra keys found: {extra_keys}") 235 | if len(missing_keys) != 0: 236 | raise ValueError(f"missing keys: {missing_keys}") 237 | model.load_state_dict(state_dict, strict=False) 238 | n_params = model.get_num_params() 239 | val_loss = checkpoint["best_val_loss"].item() 240 | logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss") 241 | model.eval() 242 | model.to(device) 243 | del checkpoint, state_dict 244 | _clear_cuda_cache() 245 | if model_type == "text": 246 | tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased") 247 | return { 248 | "model": model, 249 | "tokenizer": tokenizer, 250 | } 251 | return model 252 | 253 | 254 | def _load_codec_model(device): 255 | model = EncodecModel.encodec_model_24khz() 256 | model.set_target_bandwidth(6.0) 257 | model.eval() 258 | model.to(device) 259 | _clear_cuda_cache() 260 | return model 261 | 262 | 263 | def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"): 264 | _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small) 265 | if model_type not in ("text", "coarse", "fine"): 266 | raise NotImplementedError() 267 | global models 268 | global models_devices 269 | device = _grab_best_device(use_gpu=use_gpu) 270 | model_key = f"{model_type}" 271 | if OFFLOAD_CPU: 272 | models_devices[model_key] = device 273 | device 
= "cpu" 274 | if model_key not in models or force_reload: 275 | ckpt_path = _get_ckpt_path(model_type, use_small=use_small) 276 | clean_models(model_key=model_key) 277 | model = _load_model_f(ckpt_path, device) 278 | models[model_key] = model 279 | if model_type == "text": 280 | models[model_key]["model"].to(device) 281 | else: 282 | models[model_key].to(device) 283 | return models[model_key] 284 | 285 | 286 | def load_codec_model(use_gpu=True, force_reload=False): 287 | global models 288 | global models_devices 289 | device = _grab_best_device(use_gpu=use_gpu) 290 | if device == "mps": 291 | # encodec doesn't support mps 292 | device = "cpu" 293 | model_key = "codec" 294 | if OFFLOAD_CPU: 295 | models_devices[model_key] = device 296 | device = "cpu" 297 | if model_key not in models or force_reload: 298 | clean_models(model_key=model_key) 299 | model = _load_codec_model(device) 300 | models[model_key] = model 301 | models[model_key].to(device) 302 | return models[model_key] 303 | 304 | 305 | def preload_models( 306 | text_use_gpu=True, 307 | text_use_small=False, 308 | coarse_use_gpu=True, 309 | coarse_use_small=False, 310 | fine_use_gpu=True, 311 | fine_use_small=False, 312 | codec_use_gpu=True, 313 | force_reload=False, 314 | ): 315 | """Load all the necessary models for the pipeline.""" 316 | if _grab_best_device() == "cpu" and ( 317 | text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu 318 | ): 319 | logger.warning("No GPU being used. 
Careful, inference might be very slow!") 320 | _ = load_model( 321 | model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload 322 | ) 323 | _ = load_model( 324 | model_type="coarse", 325 | use_gpu=coarse_use_gpu, 326 | use_small=coarse_use_small, 327 | force_reload=force_reload, 328 | ) 329 | _ = load_model( 330 | model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload 331 | ) 332 | _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload) 333 | 334 | 335 | #### 336 | # Generation Functionality 337 | #### 338 | 339 | 340 | def _tokenize(tokenizer, text): 341 | return tokenizer.encode(text, add_special_tokens=False) 342 | 343 | 344 | def _detokenize(tokenizer, enc_text): 345 | return tokenizer.decode(enc_text) 346 | 347 | 348 | def _normalize_whitespace(text): 349 | return re.sub(r"\s+", " ", text).strip() 350 | 351 | 352 | TEXT_ENCODING_OFFSET = 10_048 353 | SEMANTIC_PAD_TOKEN = 10_000 354 | TEXT_PAD_TOKEN = 129_595 355 | SEMANTIC_INFER_TOKEN = 129_599 356 | 357 | 358 | def _load_history_prompt(history_prompt_input): 359 | if isinstance(history_prompt_input, str) and history_prompt_input.endswith(".npz"): 360 | history_prompt = np.load(history_prompt_input) 361 | elif isinstance(history_prompt_input, str): 362 | # make sure this works on non-ubuntu 363 | history_prompt_input = os.path.join(*history_prompt_input.split("/")) 364 | if history_prompt_input not in ALLOWED_PROMPTS: 365 | raise ValueError("history prompt not found") 366 | history_prompt = np.load( 367 | os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt_input}.npz") 368 | ) 369 | elif isinstance(history_prompt_input, dict): 370 | assert("semantic_prompt" in history_prompt_input) 371 | assert("coarse_prompt" in history_prompt_input) 372 | assert("fine_prompt" in history_prompt_input) 373 | history_prompt = history_prompt_input 374 | else: 375 | raise ValueError("history prompt format unrecognized") 376 | return 
history_prompt 377 | 378 | 379 | def generate_text_semantic( 380 | text, 381 | history_prompt=None, 382 | temp=0.7, 383 | top_k=None, 384 | top_p=None, 385 | silent=False, 386 | min_eos_p=0.2, 387 | max_gen_duration_s=None, 388 | allow_early_stop=True, 389 | use_kv_caching=False, 390 | ): 391 | """Generate semantic tokens from text.""" 392 | assert isinstance(text, str) 393 | text = _normalize_whitespace(text) 394 | assert len(text.strip()) > 0 395 | if history_prompt is not None: 396 | history_prompt = _load_history_prompt(history_prompt) 397 | semantic_history = history_prompt["semantic_prompt"] 398 | assert ( 399 | isinstance(semantic_history, np.ndarray) 400 | and len(semantic_history.shape) == 1 401 | and len(semantic_history) > 0 402 | and semantic_history.min() >= 0 403 | and semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1 404 | ) 405 | else: 406 | semantic_history = None 407 | # load models if not yet exist 408 | global models 409 | global models_devices 410 | if "text" not in models: 411 | preload_models() 412 | model_container = models["text"] 413 | model = model_container["model"] 414 | tokenizer = model_container["tokenizer"] 415 | encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET 416 | if OFFLOAD_CPU: 417 | model.to(models_devices["text"]) 418 | device = next(model.parameters()).device 419 | if len(encoded_text) > 256: 420 | p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1) 421 | logger.warning(f"warning, text too long, lopping of last {p}%") 422 | encoded_text = encoded_text[:256] 423 | encoded_text = np.pad( 424 | encoded_text, 425 | (0, 256 - len(encoded_text)), 426 | constant_values=TEXT_PAD_TOKEN, 427 | mode="constant", 428 | ) 429 | if semantic_history is not None: 430 | semantic_history = semantic_history.astype(np.int64) 431 | # lop off if history is too long, pad if needed 432 | semantic_history = semantic_history[-256:] 433 | semantic_history = np.pad( 434 | semantic_history, 435 | (0, 256 - 
len(semantic_history)), 436 | constant_values=SEMANTIC_PAD_TOKEN, 437 | mode="constant", 438 | ) 439 | else: 440 | semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256) 441 | x = torch.from_numpy( 442 | np.hstack([ 443 | encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN]) 444 | ]).astype(np.int64) 445 | )[None] 446 | assert x.shape[1] == 256 + 256 + 1 447 | with _inference_mode(): 448 | x = x.to(device) 449 | n_tot_steps = 768 450 | # custom tqdm updates since we don't know when eos will occur 451 | pbar = tqdm.tqdm(disable=silent, total=100) 452 | pbar_state = 0 453 | tot_generated_duration_s = 0 454 | kv_cache = None 455 | for n in range(n_tot_steps): 456 | if use_kv_caching and kv_cache is not None: 457 | x_input = x[:, [-1]] 458 | else: 459 | x_input = x 460 | logits, kv_cache = model( 461 | x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache 462 | ) 463 | relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE] 464 | if allow_early_stop: 465 | relevant_logits = torch.hstack( 466 | (relevant_logits, logits[0, 0, [SEMANTIC_PAD_TOKEN]]) # eos 467 | ) 468 | if top_p is not None: 469 | # faster to convert to numpy 470 | logits_device = relevant_logits.device 471 | logits_dtype = relevant_logits.type() 472 | relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy() 473 | sorted_indices = np.argsort(relevant_logits)[::-1] 474 | sorted_logits = relevant_logits[sorted_indices] 475 | cumulative_probs = np.cumsum(softmax(sorted_logits)) 476 | sorted_indices_to_remove = cumulative_probs > top_p 477 | sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy() 478 | sorted_indices_to_remove[0] = False 479 | relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf 480 | relevant_logits = torch.from_numpy(relevant_logits) 481 | relevant_logits = relevant_logits.to(logits_device).type(logits_dtype) 482 | if top_k is not None: 483 | v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1))) 484 
| relevant_logits[relevant_logits < v[-1]] = -float("Inf") 485 | probs = F.softmax(relevant_logits / temp, dim=-1) 486 | # multinomial bugged on mps: shuttle to cpu if necessary 487 | inf_device = probs.device 488 | if probs.device.type == "mps": 489 | probs = probs.to("cpu") 490 | item_next = torch.multinomial(probs, num_samples=1) 491 | probs = probs.to(inf_device) 492 | item_next = item_next.to(inf_device) 493 | if allow_early_stop and ( 494 | item_next == SEMANTIC_VOCAB_SIZE 495 | or (min_eos_p is not None and probs[-1] >= min_eos_p) 496 | ): 497 | # eos found, so break 498 | pbar.update(100 - pbar_state) 499 | break 500 | x = torch.cat((x, item_next[None]), dim=1) 501 | tot_generated_duration_s += 1 / SEMANTIC_RATE_HZ 502 | if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s: 503 | pbar.update(100 - pbar_state) 504 | break 505 | if n == n_tot_steps - 1: 506 | pbar.update(100 - pbar_state) 507 | break 508 | del logits, relevant_logits, probs, item_next 509 | req_pbar_state = np.min([100, int(round(100 * n / n_tot_steps))]) 510 | if req_pbar_state > pbar_state: 511 | pbar.update(req_pbar_state - pbar_state) 512 | pbar_state = req_pbar_state 513 | pbar.close() 514 | out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :] 515 | if OFFLOAD_CPU: 516 | model.to("cpu") 517 | assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE) 518 | _clear_cuda_cache() 519 | return out 520 | 521 | 522 | def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE): 523 | assert len(arr.shape) == 2 524 | arr = arr.copy() 525 | if offset_size is not None: 526 | for n in range(1, arr.shape[0]): 527 | arr[n, :] += offset_size * n 528 | flat_arr = arr.ravel("F") 529 | return flat_arr 530 | 531 | 532 | COARSE_SEMANTIC_PAD_TOKEN = 12_048 533 | COARSE_INFER_TOKEN = 12_050 534 | 535 | 536 | def generate_coarse( 537 | x_semantic, 538 | history_prompt=None, 539 | temp=0.7, 540 | top_k=None, 541 | top_p=None, 542 | silent=False, 543 | max_coarse_history=630, # min 60 
(faster), max 630 (more context) 544 | sliding_window_len=60, 545 | use_kv_caching=False, 546 | ): 547 | """Generate coarse audio codes from semantic tokens.""" 548 | assert ( 549 | isinstance(x_semantic, np.ndarray) 550 | and len(x_semantic.shape) == 1 551 | and len(x_semantic) > 0 552 | and x_semantic.min() >= 0 553 | and x_semantic.max() <= SEMANTIC_VOCAB_SIZE - 1 554 | ) 555 | assert 60 <= max_coarse_history <= 630 556 | assert max_coarse_history + sliding_window_len <= 1024 - 256 557 | semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS 558 | max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio)) 559 | if history_prompt is not None: 560 | history_prompt = _load_history_prompt(history_prompt) 561 | x_semantic_history = history_prompt["semantic_prompt"] 562 | x_coarse_history = history_prompt["coarse_prompt"] 563 | assert ( 564 | isinstance(x_semantic_history, np.ndarray) 565 | and len(x_semantic_history.shape) == 1 566 | and len(x_semantic_history) > 0 567 | and x_semantic_history.min() >= 0 568 | and x_semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1 569 | and isinstance(x_coarse_history, np.ndarray) 570 | and len(x_coarse_history.shape) == 2 571 | and x_coarse_history.shape[0] == N_COARSE_CODEBOOKS 572 | and x_coarse_history.shape[-1] >= 0 573 | and x_coarse_history.min() >= 0 574 | and x_coarse_history.max() <= CODEBOOK_SIZE - 1 575 | and ( 576 | round(x_coarse_history.shape[-1] / len(x_semantic_history), 1) 577 | == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1) 578 | ) 579 | ) 580 | x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE 581 | # trim histories correctly 582 | n_semantic_hist_provided = np.min( 583 | [ 584 | max_semantic_history, 585 | len(x_semantic_history) - len(x_semantic_history) % 2, 586 | int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)), 587 | ] 588 | ) 589 | n_coarse_hist_provided = int(round(n_semantic_hist_provided * 
semantic_to_coarse_ratio)) 590 | x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32) 591 | x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32) 592 | # TODO: bit of a hack for time alignment (sounds better) 593 | x_coarse_history = x_coarse_history[:-2] 594 | else: 595 | x_semantic_history = np.array([], dtype=np.int32) 596 | x_coarse_history = np.array([], dtype=np.int32) 597 | # load models if not yet exist 598 | global models 599 | global models_devices 600 | if "coarse" not in models: 601 | preload_models() 602 | model = models["coarse"] 603 | if OFFLOAD_CPU: 604 | model.to(models_devices["coarse"]) 605 | device = next(model.parameters()).device 606 | # start loop 607 | n_steps = int( 608 | round( 609 | np.floor(len(x_semantic) * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS) 610 | * N_COARSE_CODEBOOKS 611 | ) 612 | ) 613 | assert n_steps > 0 and n_steps % N_COARSE_CODEBOOKS == 0 614 | x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32) 615 | x_coarse = x_coarse_history.astype(np.int32) 616 | base_semantic_idx = len(x_semantic_history) 617 | with _inference_mode(): 618 | x_semantic_in = torch.from_numpy(x_semantic)[None].to(device) 619 | x_coarse_in = torch.from_numpy(x_coarse)[None].to(device) 620 | n_window_steps = int(np.ceil(n_steps / sliding_window_len)) 621 | n_step = 0 622 | for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent): 623 | semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio)) 624 | # pad from right side 625 | x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :] 626 | x_in = x_in[:, :256] 627 | x_in = F.pad( 628 | x_in, 629 | (0, 256 - x_in.shape[-1]), 630 | "constant", 631 | COARSE_SEMANTIC_PAD_TOKEN, 632 | ) 633 | x_in = torch.hstack( 634 | [ 635 | x_in, 636 | torch.tensor([COARSE_INFER_TOKEN])[None].to(device), 637 | x_coarse_in[:, -max_coarse_history:], 638 | ] 639 | ) 640 | kv_cache = 
None 641 | for _ in range(sliding_window_len): 642 | if n_step >= n_steps: 643 | continue 644 | is_major_step = n_step % N_COARSE_CODEBOOKS == 0 645 | 646 | if use_kv_caching and kv_cache is not None: 647 | x_input = x_in[:, [-1]] 648 | else: 649 | x_input = x_in 650 | 651 | logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache) 652 | logit_start_idx = ( 653 | SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE 654 | ) 655 | logit_end_idx = ( 656 | SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE 657 | ) 658 | relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx] 659 | if top_p is not None: 660 | # faster to convert to numpy 661 | logits_device = relevant_logits.device 662 | logits_dtype = relevant_logits.type() 663 | relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy() 664 | sorted_indices = np.argsort(relevant_logits)[::-1] 665 | sorted_logits = relevant_logits[sorted_indices] 666 | cumulative_probs = np.cumsum(softmax(sorted_logits)) 667 | sorted_indices_to_remove = cumulative_probs > top_p 668 | sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy() 669 | sorted_indices_to_remove[0] = False 670 | relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf 671 | relevant_logits = torch.from_numpy(relevant_logits) 672 | relevant_logits = relevant_logits.to(logits_device).type(logits_dtype) 673 | if top_k is not None: 674 | v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1))) 675 | relevant_logits[relevant_logits < v[-1]] = -float("Inf") 676 | probs = F.softmax(relevant_logits / temp, dim=-1) 677 | # multinomial bugged on mps: shuttle to cpu if necessary 678 | inf_device = probs.device 679 | if probs.device.type == "mps": 680 | probs = probs.to("cpu") 681 | item_next = torch.multinomial(probs, num_samples=1) 682 | probs = probs.to(inf_device) 683 | item_next = item_next.to(inf_device) 684 | item_next += logit_start_idx 685 | x_coarse_in = 
torch.cat((x_coarse_in, item_next[None]), dim=1) 686 | x_in = torch.cat((x_in, item_next[None]), dim=1) 687 | del logits, relevant_logits, probs, item_next 688 | n_step += 1 689 | del x_in 690 | del x_semantic_in 691 | if OFFLOAD_CPU: 692 | model.to("cpu") 693 | gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :] 694 | del x_coarse_in 695 | assert len(gen_coarse_arr) == n_steps 696 | gen_coarse_audio_arr = gen_coarse_arr.reshape(-1, N_COARSE_CODEBOOKS).T - SEMANTIC_VOCAB_SIZE 697 | for n in range(1, N_COARSE_CODEBOOKS): 698 | gen_coarse_audio_arr[n, :] -= n * CODEBOOK_SIZE 699 | _clear_cuda_cache() 700 | return gen_coarse_audio_arr 701 | 702 | 703 | def generate_fine( 704 | x_coarse_gen, 705 | history_prompt=None, 706 | temp=0.5, 707 | silent=True, 708 | ): 709 | """Generate full audio codes from coarse audio codes.""" 710 | assert ( 711 | isinstance(x_coarse_gen, np.ndarray) 712 | and len(x_coarse_gen.shape) == 2 713 | and 1 <= x_coarse_gen.shape[0] <= N_FINE_CODEBOOKS - 1 714 | and x_coarse_gen.shape[1] > 0 715 | and x_coarse_gen.min() >= 0 716 | and x_coarse_gen.max() <= CODEBOOK_SIZE - 1 717 | ) 718 | if history_prompt is not None: 719 | history_prompt = _load_history_prompt(history_prompt) 720 | x_fine_history = history_prompt["fine_prompt"] 721 | assert ( 722 | isinstance(x_fine_history, np.ndarray) 723 | and len(x_fine_history.shape) == 2 724 | and x_fine_history.shape[0] == N_FINE_CODEBOOKS 725 | and x_fine_history.shape[1] >= 0 726 | and x_fine_history.min() >= 0 727 | and x_fine_history.max() <= CODEBOOK_SIZE - 1 728 | ) 729 | else: 730 | x_fine_history = None 731 | n_coarse = x_coarse_gen.shape[0] 732 | # load models if not yet exist 733 | global models 734 | global models_devices 735 | if "fine" not in models: 736 | preload_models(text_use_small=True, coarse_use_small=True, fine_use_small=True) 737 | 738 | model = models["fine"] 739 | if OFFLOAD_CPU: 740 | model.to(models_devices["fine"]) 741 | device = 
next(model.parameters()).device 742 | # make input arr 743 | in_arr = np.vstack( 744 | [ 745 | x_coarse_gen, 746 | np.zeros((N_FINE_CODEBOOKS - n_coarse, x_coarse_gen.shape[1])) 747 | + CODEBOOK_SIZE, # padding 748 | ] 749 | ).astype(np.int32) 750 | # prepend history if available (max 512) 751 | if x_fine_history is not None: 752 | x_fine_history = x_fine_history.astype(np.int32) 753 | in_arr = np.hstack( 754 | [ 755 | x_fine_history[:, -512:].astype(np.int32), 756 | in_arr, 757 | ] 758 | ) 759 | n_history = x_fine_history[:, -512:].shape[1] 760 | else: 761 | n_history = 0 762 | n_remove_from_end = 0 763 | # need to pad if too short (since non-causal model) 764 | if in_arr.shape[1] < 1024: 765 | n_remove_from_end = 1024 - in_arr.shape[1] 766 | in_arr = np.hstack( 767 | [ 768 | in_arr, 769 | np.zeros((N_FINE_CODEBOOKS, n_remove_from_end), dtype=np.int32) + CODEBOOK_SIZE, 770 | ] 771 | ) 772 | # we can be lazy about fractional loop and just keep overwriting codebooks 773 | n_loops = np.max([0, int(np.ceil((x_coarse_gen.shape[1] - (1024 - n_history)) / 512))]) + 1 774 | with _inference_mode(): 775 | in_arr = torch.tensor(in_arr.T).to(device) 776 | for n in tqdm.tqdm(range(n_loops), disable=silent): 777 | start_idx = np.min([n * 512, in_arr.shape[0] - 1024]) 778 | start_fill_idx = np.min([n_history + n * 512, in_arr.shape[0] - 512]) 779 | rel_start_fill_idx = start_fill_idx - start_idx 780 | in_buffer = in_arr[start_idx : start_idx + 1024, :][None] 781 | for nn in range(n_coarse, N_FINE_CODEBOOKS): 782 | logits = model(nn, in_buffer) 783 | if temp is None: 784 | relevant_logits = logits[0, rel_start_fill_idx:, :CODEBOOK_SIZE] 785 | codebook_preds = torch.argmax(relevant_logits, -1) 786 | else: 787 | relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp 788 | probs = F.softmax(relevant_logits, dim=-1) 789 | # multinomial bugged on mps: shuttle to cpu if necessary 790 | inf_device = probs.device 791 | if probs.device.type == "mps": 792 | probs = probs.to("cpu") 793 | 
codebook_preds = torch.hstack( 794 | [ 795 | torch.multinomial(probs[nnn], num_samples=1).to(inf_device) 796 | for nnn in range(rel_start_fill_idx, 1024) 797 | ] 798 | ) 799 | in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds 800 | del logits, codebook_preds 801 | # transfer over info into model_in and convert to numpy 802 | for nn in range(n_coarse, N_FINE_CODEBOOKS): 803 | in_arr[ 804 | start_fill_idx : start_fill_idx + (1024 - rel_start_fill_idx), nn 805 | ] = in_buffer[0, rel_start_fill_idx:, nn] 806 | del in_buffer 807 | gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T 808 | del in_arr 809 | if OFFLOAD_CPU: 810 | model.to("cpu") 811 | gen_fine_arr = gen_fine_arr[:, n_history:] 812 | if n_remove_from_end > 0: 813 | gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end] 814 | assert gen_fine_arr.shape[-1] == x_coarse_gen.shape[-1] 815 | _clear_cuda_cache() 816 | return gen_fine_arr 817 | 818 | 819 | def codec_decode(fine_tokens): 820 | """Turn quantized audio codes into audio array using encodec.""" 821 | # load models if not yet exist 822 | global models 823 | global models_devices 824 | if "codec" not in models: 825 | preload_models() 826 | model = models["codec"] 827 | if OFFLOAD_CPU: 828 | model.to(models_devices["codec"]) 829 | device = next(model.parameters()).device 830 | arr = torch.from_numpy(fine_tokens)[None] 831 | arr = arr.to(device) 832 | arr = arr.transpose(0, 1) 833 | emb = model.quantizer.decode(arr) 834 | out = model.decoder(emb) 835 | audio_arr = out.detach().cpu().numpy().squeeze() 836 | del arr, emb, out 837 | if OFFLOAD_CPU: 838 | model.to("cpu") 839 | return audio_arr 840 | -------------------------------------------------------------------------------- /barkify/bark/model.py: -------------------------------------------------------------------------------- 1 | # Copy from bark 2 | 3 | """ 4 | Much of this code is adapted from Andrej Karpathy's NanoGPT 5 | (https://github.com/karpathy/nanoGPT) 6 | """ 7 | import math 8 | from 
dataclasses import dataclass 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | 14 | class LayerNorm(nn.Module): 15 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ 16 | 17 | def __init__(self, ndim, bias): 18 | super().__init__() 19 | self.weight = nn.Parameter(torch.ones(ndim)) 20 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 21 | 22 | def forward(self, input): 23 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 24 | 25 | class CausalSelfAttention(nn.Module): 26 | 27 | def __init__(self, config): 28 | super().__init__() 29 | assert config.n_embd % config.n_head == 0 30 | # key, query, value projections for all heads, but in a batch 31 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 32 | # output projection 33 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 34 | # regularization 35 | self.attn_dropout = nn.Dropout(config.dropout) 36 | self.resid_dropout = nn.Dropout(config.dropout) 37 | self.n_head = config.n_head 38 | self.n_embd = config.n_embd 39 | self.dropout = config.dropout 40 | # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary 41 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 42 | if not self.flash: 43 | # print("WARNING: using slow attention. 
Flash Attention atm needs PyTorch nightly and dropout=0.0") 44 | # causal mask to ensure that attention is only applied to the left in the input sequence 45 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 46 | .view(1, 1, config.block_size, config.block_size)) 47 | 48 | def forward(self, x, past_kv=None, use_cache=False): 49 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 50 | 51 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 52 | q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) 53 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 54 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 55 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 56 | 57 | if past_kv is not None: 58 | past_key = past_kv[0] 59 | past_value = past_kv[1] 60 | k = torch.cat((past_key, k), dim=-2) 61 | v = torch.cat((past_value, v), dim=-2) 62 | 63 | FULL_T = k.shape[-2] 64 | 65 | if use_cache is True: 66 | present = (k, v) 67 | else: 68 | present = None 69 | 70 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 71 | if self.flash: 72 | # efficient attention using Flash Attention CUDA kernels 73 | if past_kv is not None: 74 | # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains 75 | # the query for the last token. scaled_dot_product_attention interprets this as the first token in the 76 | # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so 77 | # to work around this we set is_causal=False. 
78 | is_causal = False 79 | else: 80 | is_causal = True 81 | 82 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal) 83 | else: 84 | # manual implementation of attention 85 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 86 | att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf')) 87 | att = F.softmax(att, dim=-1) 88 | att = self.attn_dropout(att) 89 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 90 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 91 | 92 | # output projection 93 | y = self.resid_dropout(self.c_proj(y)) 94 | return (y, present) 95 | 96 | class MLP(nn.Module): 97 | 98 | def __init__(self, config): 99 | super().__init__() 100 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 101 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 102 | self.dropout = nn.Dropout(config.dropout) 103 | self.gelu = nn.GELU() 104 | 105 | def forward(self, x): 106 | x = self.c_fc(x) 107 | x = self.gelu(x) 108 | x = self.c_proj(x) 109 | x = self.dropout(x) 110 | return x 111 | 112 | class Block(nn.Module): 113 | 114 | def __init__(self, config, layer_idx): 115 | super().__init__() 116 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 117 | self.attn = CausalSelfAttention(config) 118 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 119 | self.mlp = MLP(config) 120 | self.layer_idx = layer_idx 121 | 122 | def forward(self, x, past_kv=None, use_cache=False): 123 | attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache) 124 | x = x + attn_output 125 | x = x + self.mlp(self.ln_2(x)) 126 | return (x, prev_kvs) 127 | 128 | @dataclass 129 | class GPTConfig: 130 | block_size: int = 1024 131 | input_vocab_size: int = 10_048 132 | output_vocab_size: int = 10_048 133 | use_extra_input: bool = False 134 | extra_input_dim: int = 
512 135 | n_layer: int = 12 136 | n_head: int = 12 137 | n_embd: int = 768 138 | dropout: float = 0.0 139 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 140 | 141 | class GPT(nn.Module): 142 | 143 | def __init__(self, config): 144 | super().__init__() 145 | assert config.input_vocab_size is not None 146 | assert config.output_vocab_size is not None 147 | assert config.block_size is not None 148 | self.config = config 149 | 150 | self.transformer = nn.ModuleDict(dict( 151 | wte = nn.Embedding(config.input_vocab_size, config.n_embd), 152 | wpe = nn.Embedding(config.block_size, config.n_embd), 153 | extra_proj = nn.Linear(config.extra_input_dim, config.n_embd) if config.use_extra_input else None, 154 | drop = nn.Dropout(config.dropout), 155 | h = nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]), 156 | ln_f = LayerNorm(config.n_embd, bias=config.bias), 157 | )) 158 | self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False) 159 | 160 | def get_num_params(self, non_embedding=True): 161 | """ 162 | Return the number of parameters in the model. 163 | For non-embedding count (default), the position embeddings get subtracted. 164 | The token embeddings would too, except due to the parameter sharing these 165 | params are actually used as weights in the final layer, so we include them. 
166 | """ 167 | n_params = sum(p.numel() for p in self.parameters()) 168 | if non_embedding: 169 | n_params -= self.transformer.wte.weight.numel() 170 | n_params -= self.transformer.wpe.weight.numel() 171 | return n_params 172 | 173 | def forward(self, idx, extra=None, merge_context=False, past_kv=None, position_ids=None, use_cache=False): 174 | device = idx.device 175 | b, t = idx.size() 176 | if past_kv is not None: 177 | assert t == 1 178 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 179 | else: 180 | if merge_context: 181 | assert(idx.shape[1] >= 256+256+1) 182 | t = idx.shape[1] - 256 183 | else: 184 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 185 | 186 | # forward the GPT model itself 187 | if merge_context: 188 | tok_emb = torch.cat([ 189 | self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]), 190 | self.transformer.wte(idx[:,256+256:]) 191 | ], dim=1) 192 | else: 193 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 194 | 195 | if past_kv is None: 196 | past_length = 0 197 | past_kv = tuple([None] * len(self.transformer.h)) 198 | else: 199 | past_length = past_kv[0][0].size(-2) 200 | 201 | if position_ids is None: 202 | position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device) 203 | position_ids = position_ids.unsqueeze(0) # shape (1, t) 204 | assert position_ids.shape == (1, t) 205 | 206 | pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd) 207 | 208 | if self.config.use_extra_input: 209 | assert extra is not None 210 | x = self.transformer.drop(tok_emb + pos_emb + self.transformer.extra_proj(extra)) 211 | else: 212 | x = self.transformer.drop(tok_emb + pos_emb) 213 | 214 | new_kv = () if use_cache else None 215 | 216 | for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)): 217 | x, kv = block(x, 
past_kv=past_layer_kv, use_cache=use_cache) 218 | 219 | if use_cache: 220 | new_kv = new_kv + (kv,) 221 | 222 | x = self.transformer.ln_f(x) 223 | 224 | # inference-time mini-optimization: only forward the lm_head on the very last position 225 | if use_cache == False: 226 | logits = self.lm_head(x[:, :, :]) # changes: output full logits 227 | else: 228 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 229 | 230 | return (logits, new_kv) 231 | -------------------------------------------------------------------------------- /barkify/bark/model_fine.py: -------------------------------------------------------------------------------- 1 | # Copy from bark 2 | 3 | """ 4 | Much of this code is adapted from Andrej Karpathy's NanoGPT 5 | (https://github.com/karpathy/nanoGPT) 6 | """ 7 | from dataclasses import dataclass 8 | import math 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | 14 | from .model import GPT, GPTConfig, MLP 15 | 16 | 17 | class NonCausalSelfAttention(nn.Module): 18 | def __init__(self, config): 19 | super().__init__() 20 | assert config.n_embd % config.n_head == 0 21 | # key, query, value projections for all heads, but in a batch 22 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 23 | # output projection 24 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 25 | # regularization 26 | self.attn_dropout = nn.Dropout(config.dropout) 27 | self.resid_dropout = nn.Dropout(config.dropout) 28 | self.n_head = config.n_head 29 | self.n_embd = config.n_embd 30 | self.dropout = config.dropout 31 | # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary 32 | self.flash = ( 33 | hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0 34 | ) 35 | 36 | def forward(self, x): 37 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 38 
| 39 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 40 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 41 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 42 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 43 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 44 | 45 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 46 | if self.flash: 47 | # efficient attention using Flash Attention CUDA kernels 48 | y = torch.nn.functional.scaled_dot_product_attention( 49 | q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False 50 | ) 51 | else: 52 | # manual implementation of attention 53 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 54 | att = F.softmax(att, dim=-1) 55 | att = self.attn_dropout(att) 56 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 57 | y = ( 58 | y.transpose(1, 2).contiguous().view(B, T, C) 59 | ) # re-assemble all head outputs side by side 60 | 61 | # output projection 62 | y = self.resid_dropout(self.c_proj(y)) 63 | return y 64 | 65 | 66 | class FineBlock(nn.Module): 67 | def __init__(self, config): 68 | super().__init__() 69 | self.ln_1 = nn.LayerNorm(config.n_embd) 70 | self.attn = NonCausalSelfAttention(config) 71 | self.ln_2 = nn.LayerNorm(config.n_embd) 72 | self.mlp = MLP(config) 73 | 74 | def forward(self, x): 75 | x = x + self.attn(self.ln_1(x)) 76 | x = x + self.mlp(self.ln_2(x)) 77 | return x 78 | 79 | 80 | class FineGPT(GPT): 81 | def __init__(self, config): 82 | super().__init__(config) 83 | del self.lm_head 84 | self.config = config 85 | self.n_codes_total = config.n_codes_total 86 | self.transformer = nn.ModuleDict( 87 | dict( 88 | wtes=nn.ModuleList( 89 | [ 90 | nn.Embedding(config.input_vocab_size, config.n_embd) 91 | for _ in range(config.n_codes_total) 92 | ] 93 | ), 94 | 
wpe=nn.Embedding(config.block_size, config.n_embd), 95 | drop=nn.Dropout(config.dropout), 96 | h=nn.ModuleList([FineBlock(config) for _ in range(config.n_layer)]), 97 | ln_f=nn.LayerNorm(config.n_embd), 98 | ) 99 | ) 100 | self.lm_heads = nn.ModuleList( 101 | [ 102 | nn.Linear(config.n_embd, config.output_vocab_size, bias=False) 103 | for _ in range(config.n_codes_given, self.n_codes_total) 104 | ] 105 | ) 106 | for i in range(self.n_codes_total - config.n_codes_given): 107 | self.transformer.wtes[i + 1].weight = self.lm_heads[i].weight 108 | 109 | def forward(self, pred_idx, idx): 110 | device = idx.device 111 | b, t, codes = idx.size() 112 | assert ( 113 | t <= self.config.block_size 114 | ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 115 | assert pred_idx > 0, "cannot predict 0th codebook" 116 | assert codes == self.n_codes_total, (b, t, codes) 117 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) 118 | 119 | # forward the GPT model itself 120 | tok_embs = [ 121 | wte(idx[:, :, i]).unsqueeze(-1) for i, wte in enumerate(self.transformer.wtes) 122 | ] # token embeddings of shape (b, t, n_embd) 123 | tok_emb = torch.cat(tok_embs, dim=-1) 124 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) 125 | x = tok_emb[:, :, :, : pred_idx + 1].sum(dim=-1) 126 | x = self.transformer.drop(x + pos_emb) 127 | for block in self.transformer.h: 128 | x = block(x) 129 | x = self.transformer.ln_f(x) 130 | logits = self.lm_heads[pred_idx - self.config.n_codes_given](x) 131 | return logits 132 | 133 | def get_num_params(self, non_embedding=True): 134 | """ 135 | Return the number of parameters in the model. 136 | For non-embedding count (default), the position embeddings get subtracted. 137 | The token embeddings would too, except due to the parameter sharing these 138 | params are actually used as weights in the final layer, so we include them. 
139 | """ 140 | n_params = sum(p.numel() for p in self.parameters()) 141 | if non_embedding: 142 | for wte in self.transformer.wtes: 143 | n_params -= wte.weight.numel() 144 | n_params -= self.transformer.wpe.weight.numel() 145 | return n_params 146 | 147 | 148 | @dataclass 149 | class FineGPTConfig(GPTConfig): 150 | n_codes_total: int = 8 151 | n_codes_given: int = 1 152 | -------------------------------------------------------------------------------- /barkify/datas/__init__.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.utils.data import DataLoader 3 | 4 | from .data import Dataset 5 | from .data import Text2semanticCollateFn, Semantic2coarseCollateFn 6 | stage_collate = [Text2semanticCollateFn, Semantic2coarseCollateFn] 7 | 8 | from .tokenizer import ZHTokenizer, PhonemeTokenizer 9 | 10 | def StageDataloader(params, stage=1, file='train') -> DataLoader: 11 | 12 | tokenizer = PhonemeTokenizer() # TODO: add more tokenizer 13 | dataset = Dataset(file=file, tokenizer=tokenizer, **params.dataset) 14 | collate_fn = functools.partial(stage_collate[int(stage) - 1], **params.collate_fn) 15 | 16 | loader = DataLoader(dataset, collate_fn=collate_fn, **params.dataloader) 17 | return loader -------------------------------------------------------------------------------- /barkify/datas/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os.path import join as pjoin 3 | 4 | import random 5 | import torch 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from .tokenizer import SplitTokenizer 10 | 11 | class Dataset: 12 | def __init__(self, start_path, file, 13 | tokenizer: SplitTokenizer = None, add_prompt = False 14 | ): 15 | 16 | self.start_path = start_path 17 | self.tokenizer = tokenizer 18 | self.add_prompt = add_prompt # add prompt for semantic or acoustic 19 | 20 | with open(pjoin(start_path, 'meta', file+'.json')) as f: 21 | 
self._datas = [json.loads(i) for i in f.readlines()] 22 | g2p_already = any([data.get('g2p', False) for data in self._datas]) 23 | 24 | if not g2p_already: 25 | print("g2p all texts.") 26 | for data in tqdm(self._datas): 27 | if not data.get('g2p', False): 28 | data['g2p'] = self.tokenizer.g2p(data['text']) 29 | 30 | with open(pjoin(start_path, 'meta', file+'.json'), "w") as f: 31 | for data in self._datas: 32 | f.writelines(json.dumps(data, ensure_ascii=False)+'\n') 33 | 34 | def __getitem__(self, idx): 35 | batch = {} 36 | data = self._datas[idx] 37 | 38 | batch['name'] = data['name'] 39 | batch['text'] = self.tokenizer.token2id(data['g2p']) 40 | 41 | semantic_path = pjoin(self.start_path, 'semantic_idx', data['name']) 42 | codec_path = pjoin(self.start_path, 'encodec_idx', data['name']) 43 | batch['semantic'] = torch.from_numpy(np.load(semantic_path)) 44 | batch['encodec'] = torch.from_numpy(np.load(codec_path)) 45 | 46 | if self.add_prompt: 47 | raise NotImplementedError 48 | 49 | return batch 50 | 51 | def __len__(self): 52 | return len(self._datas) 53 | 54 | def Text2semanticCollateFn( 55 | batches, 56 | text_window=512, semantic_window=512, # set training window size. 57 | text_token_num=210, fixed_length=True, ign_idx=-100, **kwargs 58 | ): 59 | ''' 60 | logits not following bark: text(210) + pad_text(1) + eos(1) + infer(1) + semantic(2048) 61 | returns: 62 | input: [text with pad(512), infer(1), semantic with pad(512), eos(1)] 63 | tgt: [ign_idx(512), ign_idx(1), (semantic, ign_idx)(512), eos(1)] 64 | ''' 65 | 66 | text_pad_token, semantic_pad_token = text_token_num, text_token_num + 1 67 | infer_token, text_offset = text_token_num + 2, text_token_num + 3 68 | 69 | semantic, text = [], [] 70 | semantic_length, text_length = [], [] 71 | names = [] 72 | 73 | for batch in batches: 74 | names.append(batch['name']) 75 | 76 | _text = np.array(batch['text']) 77 | if fixed_length: 78 | _text = _text[-text_window:] # TODO: drop long text? 
79 | text_length.append(len(_text)) 80 | 81 | _text = np.pad( 82 | _text, (0, text_window - len(_text)), 83 | constant_values = text_pad_token 84 | ) 85 | 86 | _semantic = batch['semantic'] + text_offset # different from bark, we add offset to semantic tokens rather than text. 87 | if fixed_length: 88 | _semantic = _semantic[-semantic_window:] 89 | semantic_length.append(len(_semantic)) 90 | 91 | _semantic = np.pad( 92 | _semantic, (0, semantic_window - len(_semantic)), 93 | constant_values = semantic_pad_token 94 | ) 95 | 96 | semantic.append(_semantic), text.append(_text) 97 | 98 | text, semantic = torch.from_numpy(np.stack(text, axis=0)), torch.from_numpy(np.stack(semantic, axis=0)) 99 | text_length, semantic_length = torch.tensor(text_length), torch.tensor(semantic_length) 100 | 101 | B = text.shape[0] 102 | inputs = torch.cat([text, torch.ones(B, 1) * infer_token, semantic, torch.ones(B, 1) * semantic_pad_token], -1) 103 | tgt = torch.cat([torch.ones(B, text_window + 1) * ign_idx, semantic, torch.ones(B, 1) * semantic_pad_token], -1) 104 | tgt[:, text_window+1:][(torch.arange(0, semantic_window + 1) > semantic_length[:, None])] = -100 105 | 106 | batch = dict(names = names, 107 | input_ids = inputs.long(), labels = tgt.long(), 108 | semantic_length = semantic_length, text_length = text_length 109 | ) 110 | return batch 111 | 112 | def Semantic2coarseCollateFn( 113 | batches, Q_size, 114 | semantic_window=256, semantic_to_coarse_ratio=3, 115 | semantic_token_num=2048, coarse_num=1024, ign_idx=-100, 116 | slice_range = 60, # window size in bark 117 | **kwargs 118 | ): 119 | ''' 120 | logits following bark: semantic(2048) + pad(1) + infer(1) + coarse(1024*Q) 121 | returns: 122 | input: [semantic with pad(256), infer(1), coarse with pad(768)] 123 | tgt: [ign_idx(256), ign_idx(1), (coarse, ign_idx)(768)] 124 | ''' 125 | 126 | semantic_pad_token = coarse_pad_token = semantic_token_num 127 | infer_token = semantic_token_num + 1 128 | 129 | semantic_window = 
semantic_window // 2 * 2 130 | acoustic_pad = int(np.ceil((semantic_window * semantic_to_coarse_ratio) / Q_size) * Q_size) 131 | semantic, coarse = [], [] 132 | semantic_length, coarse_length = [], [] 133 | names = [] 134 | 135 | for batch in batches: 136 | names.append(batch['name']) 137 | semantic_start = random.randint(0, np.max([0, len(batch['semantic']) - (semantic_window - slice_range)])) 138 | semantic_start = (semantic_start // Q_size) * Q_size # add 139 | _semantic = batch['semantic'][semantic_start:semantic_start+semantic_window] 140 | semantic_length.append(len(_semantic)) 141 | 142 | _semantic = np.pad( 143 | _semantic, (0, semantic_window - len(_semantic)), 144 | constant_values = semantic_pad_token 145 | ) 146 | 147 | coarse_start = semantic_start * semantic_to_coarse_ratio 148 | coarse_end = coarse_start + acoustic_pad 149 | 150 | _coarse = batch['encodec'][:Q_size] + 2 + semantic_pad_token # add pad and infer token 151 | for i in range(1, _coarse.shape[0]): 152 | _coarse[i:] += coarse_num 153 | _coarse = _coarse.T.reshape(-1)[int(coarse_start) : int(coarse_end)] 154 | coarse_length.append(len(_coarse)) 155 | 156 | _coarse = np.pad( 157 | _coarse, (0, acoustic_pad - len(_coarse)), 158 | constant_values = coarse_pad_token 159 | ) 160 | 161 | semantic.append(_semantic), coarse.append(_coarse) 162 | 163 | semantic, coarse = torch.from_numpy(np.stack(semantic, axis=0)), torch.from_numpy(np.stack(coarse, axis=0)) 164 | semantic_length, coarse_length = torch.tensor(semantic_length), torch.tensor(coarse_length) 165 | 166 | inputs = torch.cat([semantic, torch.ones(semantic.shape[0], 1) * infer_token, coarse], -1) 167 | tgt = torch.cat([torch.ones(semantic.shape[0], semantic_window + 1) * ign_idx, coarse], -1) 168 | tgt[(tgt == semantic_pad_token)|(tgt == coarse_pad_token)] = ign_idx 169 | 170 | batch = dict(names = names, 171 | input_ids = inputs.long(), labels = tgt.long(), 172 | semantic_length = semantic_length, coarse_length = coarse_length 173 | ) 174 | 
return batch 175 | 176 | if __name__ == "__main__": 177 | from barkify.datas.tokenizer import ZHTokenizer 178 | dataset = Dataset('/root/intern/bark/barkify/work_env', 'train', ZHTokenizer()) 179 | print(dataset[0]['text']) -------------------------------------------------------------------------------- /barkify/datas/tokenizer.py: -------------------------------------------------------------------------------- 1 | # TODO: integrate with NeMo tokenizer? 2 | import string 3 | from pypinyin import lazy_pinyin, Style 4 | from g2p_en import G2p 5 | 6 | class SplitTokenizer: 7 | def __init__(self, **kwargs): 8 | self._token2id = {i:idx + 1 for idx, i in enumerate(ENGLISH_PRONUNCIATION_LIST)} 9 | self._id2token = {idx + 1:i for idx, i in enumerate(ENGLISH_PRONUNCIATION_LIST)} 10 | 11 | def __call__(self, text): 12 | text = self.g2p(text) 13 | return self.token2id(text) 14 | 15 | def g2p(self, text): 16 | # text: I'm Tom. -> [i, ', m, , t, o, m, .] 17 | text = text.lower() 18 | return " ".join(["" if i == ' ' else i for i in list(text)]) 19 | 20 | def token2id(self, text): 21 | token = [] 22 | for i in text.split(): 23 | if i == ' ': 24 | token.append(self._token2id['']) 25 | elif self._token2id.get(i, None): 26 | token.append(self._token2id[i]) 27 | 28 | return token 29 | 30 | class PhonemeTokenizer(SplitTokenizer): 31 | def __init__(self, **kwargs): 32 | self._g2p = G2p() 33 | self._token2id = {i:idx + 1 for idx, i in enumerate(self._g2p.phonemes)} 34 | self._id2token = {idx + 1:i for idx, i in enumerate(self._g2p.phonemes)} 35 | 36 | def g2p(self, text): 37 | text = self._g2p(text) 38 | return " ".join(["" if i == ' ' else i for i in text]) 39 | 40 | class ZHTokenizer(SplitTokenizer): 41 | def __init__(self, **kwargs): 42 | # A basic english and chinese G2P tokenizer 43 | self._token2id = {i:idx + 1 for idx, i in enumerate(PINYIN_PRONUNCIATION_LIST)} 44 | self._id2token = {idx + 1:i for idx, i in enumerate(PINYIN_PRONUNCIATION_LIST)} 45 | 46 | def g2p(self, text): 47 | 
text = text.lower() 48 | initials = lazy_pinyin(text, neutral_tone_with_five=False, style=Style.INITIALS, strict=False) 49 | finals = lazy_pinyin(text, neutral_tone_with_five=False, style=Style.FINALS_TONE3) 50 | 51 | text_phone = [] 52 | for _o in zip(initials, finals): 53 | if _o[0] != _o[1] and _o[0] != '': 54 | _o = ['@'+i for i in _o] 55 | text_phone.extend(_o) 56 | elif _o[0] != _o[1] and _o[0] == '': 57 | text_phone.append('@'+_o[1]) 58 | else: 59 | text_phone.extend(["" if i == ' ' else i for i in list(_o[0])]) 60 | 61 | return " ".join(text_phone) 62 | 63 | ENGLISH_PRONUNCIATION_LIST = list(string.ascii_lowercase) + list(",.?!") + [''] 64 | PINYIN_PRONUNCIATION_LIST = ENGLISH_PRONUNCIATION_LIST + [ 65 | # PINYIN_PRONUNCIATION_LIST = [ 66 | '@w', '@uo3', '@y', '@i3', '@j', '@ing1', '@b', '@a3', '@o1', '@l', '@a1', '@h', '@ei1', '@e', 67 | '@n', '@iou4', '@sh', '@i4', '@ie2', '@z', '@ai4', '@a4', '@m', '@g', '@ou3', '@t', '@uo2', '@i', '@q', 68 | '@vn2', '@uo1', '@zh', '@i1', '@d', '@ao4', '@uei4', '@a2', '@a', '@e4', '@ing3', '@ei', '@u4', '@uan2', 69 | '@f', '@an4', '@en2', '@c', '@uo4', '@uei1', '@iou1', '@ei2', '@e2', '@r', '@en4', '@eng1', '@e1', '@en1', 70 | '@ou4', '@ang4', '@p', '@eng2', '@ong4', '@u2', '@iang2', '@van1', '@ian1', '@ei4', '@er3', '@ia1', '@ou2', 71 | '@ao3', '@ou1', '@er2', '@s', '@i2', '@v4', '@x', '@ian4', '@ong1', '@uan3', '@uang2', '@ing4', '@ch', '@vn3', 72 | '@uen1', '@ai1', '@an3', '@eng4', '@ing2', '@ve4', '@k', '@ang3', '@en3', '@ai2', '@ian3', '@er4', '@ai3', 73 | '@uai4', '@ian2', '@ao1', '@eng3', '@ia4', '@n2', '@ang1', '@ie3', '@uen3', '@iou3', '@ei3', '@in4', '@v3', 74 | '@uen4', '@an2', '@iang1', '@in1', '@u3', '@ve2', '@e3', '@iang4', '@ia', '@an1', '@in3', '@iao4', '@ang2', 75 | '@vn1', '@iao3', '@u1', '@ie1', '@ie4', '@v2', '@uei2', '@iong1', '@iao1', '@o2', '@uei3', '@in2', '@iong4', 76 | '@ve1', '@uang1', '@iang3', '@uan4', '@iou2', '@en', '@uan1', '@ia2', '@ua1', '@ong3', '@van4', '@van2', '@uang3', 77 | '@iao2', 
'@ua4', '@ong2', '@uen2', '@iong3', '@er', '@v1', '@uang4', '@ia3', '@ve3', '@ua2', '@van3', '@ao2', 78 | '@o4', '@ua3', '@vn4', '@iong2', '@io1', '@uai1', '@ou', '@uai2', '@ua', '@ueng1', '@o', '@uai3', '@o3', '@uo', 79 | ] + list("、,。") 80 | # ] + list(string.ascii_lowercase) + list("、,。") 81 | 82 | if __name__ == "__main__": 83 | tokenizer = ZHTokenizer() 84 | print(tokenizer.g2p("I'm tom.")) 85 | print(tokenizer("I'm tom.")) 86 | 87 | tokenizer = PhonemeTokenizer() 88 | print(tokenizer.g2p("I'm tom.")) 89 | print(tokenizer("I'm tom.")) 90 | -------------------------------------------------------------------------------- /barkify/pl_model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | 6 | import pytorch_lightning as pl 7 | from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup 8 | 9 | from .bark import GPT, GPTConfig 10 | 11 | class NanoGPT(pl.LightningModule): 12 | def __init__( 13 | self, 14 | model_config, 15 | lr, 16 | min_lr, 17 | weight_decay, 18 | warmup_iters, 19 | max_iters, 20 | lr_strategy = 'constant', 21 | **kwargs 22 | ): 23 | super().__init__() 24 | self.save_hyperparameters() 25 | self.criterion = nn.CrossEntropyLoss() 26 | self._create_model() 27 | 28 | def _create_model(self): 29 | gptconf = GPTConfig(**self.hparams.model_config) 30 | self.model = GPT(gptconf) 31 | 32 | def forward(self, x,) -> torch.Tensor: 33 | logits, _ = self.model(x, use_cache = False) 34 | return logits 35 | 36 | @torch.no_grad() 37 | def infer(self, x, *args, **kwargs): 38 | return self.forward(x, *args, **kwargs) 39 | 40 | def optimizer_step(self, *args, **kwargs) -> None: 41 | super().optimizer_step(*args, **kwargs) 42 | if self.optimizers().param_groups[0]['lr'] < self.hparams.min_lr and self.lr_schedulers().last_epoch > self.hparams.warmup_iters: 43 | 
self.optimizers().param_groups[0]['lr'] = self.hparams.min_lr 44 | self.lr_schedulers().last_epoch -= 1 45 | 46 | def configure_optimizers(self): 47 | optimizer = optim.AdamW(self.model.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay) 48 | # We don't return the lr scheduler because we need to apply it per iteration, not per epoch 49 | if self.hparams.lr_strategy == 'constant': 50 | lr_scheduler = get_constant_schedule_with_warmup( 51 | optimizer=optimizer, 52 | num_warmup_steps=self.hparams.warmup_iters, 53 | ) 54 | elif self.hparams.lr_strategy == 'cosine': 55 | lr_scheduler = get_cosine_schedule_with_warmup( 56 | optimizer=optimizer, 57 | num_warmup_steps=self.hparams.warmup_iters, 58 | num_training_steps=self.hparams.max_iters, 59 | ) 60 | elif self.hparams.lr_strategy == 'linear': 61 | lr_scheduler = get_linear_schedule_with_warmup( 62 | optimizer=optimizer, 63 | num_warmup_steps=self.hparams.warmup_iters, 64 | num_training_steps=self.hparams.max_iters, 65 | ) 66 | else: 67 | raise NotImplementedError 68 | 69 | return [optimizer], [{"scheduler": lr_scheduler, "interval": "step"}] 70 | 71 | def on_train_start(self): 72 | self.logger.log_hyperparams(self.hparams) 73 | 74 | def training_step(self, batch, batch_idx): 75 | if 'extra' in batch: 76 | logits = self.forward(batch['input_ids'], extra=batch['extra'], merge_context = True) 77 | else: 78 | logits = self.forward(batch['input_ids']) 79 | loss = self.criterion(logits[:,:-1,:].reshape(-1, logits.shape[-1]), batch['labels'][:,1:].reshape(-1)) 80 | top1_acc, top10_acc = self.get_top10_acc(logits[:1,:-1,:], batch['labels'][:1,1:]) 81 | 82 | self.log("train_loss", loss, prog_bar=True) 83 | self.log("train_1acc", top1_acc) 84 | self.log("train_10acc", top10_acc) 85 | 86 | #self.log("lr", self.lr_schedulers().get_last_lr()[0]) 87 | self.log("lr", self.optimizers().param_groups[0]['lr']) 88 | 89 | return loss 90 | 91 | def validation_step(self, batch, batch_idx): 92 | logits = 
self.forward(batch['input_ids']) 93 | loss = self.criterion(logits[:,:-1,:].reshape(-1, logits.shape[-1]), batch['labels'][:,1:].reshape(-1)) 94 | top1_acc, top10_acc = self.get_top10_acc(logits[:32,:-1,:], batch['labels'][:32,1:]) #FIXME: B > 32 may raise error? 95 | 96 | self.log("val_loss", loss) 97 | self.log("val_1acc", top1_acc) 98 | self.log("val_10acc", top10_acc) 99 | 100 | def test_step(self, batch, batch_idx): 101 | raise NotImplementedError 102 | 103 | def get_top10_acc(self, logits, labels): 104 | with torch.no_grad(): 105 | top10 = logits.topk(dim=-1, k=10).indices 106 | top1 = logits.topk(dim=-1, k=1).indices 107 | label = labels.unsqueeze(-1) 108 | 109 | total_item = (label != -100).sum() 110 | top10_acc = ((top10 == label).any(-1, keepdim=True) & (label != -100)).sum() / total_item 111 | top1_acc = ((top1 == label).any(-1, keepdim=True) & (label != -100)).sum() / total_item 112 | top10_acc, top1_acc = top10_acc.item(), top1_acc.item() 113 | 114 | return top1_acc, top10_acc -------------------------------------------------------------------------------- /barkify/utils.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | 3 | def Bestckpt(exp_root_dir): 4 | ckpts = glob(f"{exp_root_dir}/lightning_logs/*/checkpoints/*") 5 | ckpts = sorted(ckpts, key=lambda x:x.split("/")[-1].split("=")[1].split("-")[0]) 6 | ckpts = ckpts[-1] if len(ckpts) > 0 else None 7 | return ckpts -------------------------------------------------------------------------------- /configs/barkify.yaml: -------------------------------------------------------------------------------- 1 | name: ljspeech 2 | stage: 1 3 | start_path: ??? 
4 | 5 | common: 6 | trainer: 7 | accelerator: "gpu" 8 | # strategy: "deepspeed" 9 | # strategy: "ddp_find_unused_parameters_false" 10 | strategy: null 11 | precision: 16 # 32 12 | 13 | ckpt: 14 | mode: "min" 15 | monitor: "val_loss" 16 | save_top_k: 3 17 | every_n_epochs: 10 18 | save_weights_only: false 19 | 20 | stage1: 21 | ############# create dataset ############# 22 | tokenizer: null 23 | dataset: 24 | start_path: ${start_path} 25 | add_prompt: false 26 | 27 | collate_fn: 28 | text_window: 256 29 | semantic_window: 768 30 | text_token_num: 210 31 | 32 | dataloader: 33 | batch_size: 128 34 | shuffle: True 35 | num_workers: 32 36 | persistent_workers: True 37 | pin_memory: True 38 | 39 | ############## model params ############## 40 | model: 41 | n_layer: 6 42 | n_head: 4 43 | n_embd: 512 44 | block_size: 1026 # 256(text_window) + 768(semantic) + 2(infer token, eos) 45 | bias: False 46 | dropout: 0.1 47 | input_vocab_size: 2261 # 210(pinyin) + 2048(semantic) + 3(infer, pad_text, pad_semantic) 48 | output_vocab_size: 2261 49 | 50 | ############## optim params ############## 51 | optim: 52 | lr: 1e-4 53 | min_lr: 2e-5 54 | weight_decay: 1e-3 55 | warmup_iters: 1000 56 | max_iters: 10000 # depends on your GPUs. 
57 | lr_strategy: 'cosine' 58 | gradient_clip: 1 59 | 60 | stage2: 61 | ############# create dataset ############# 62 | tokenizer: null 63 | dataset: 64 | start_path: ${start_path} 65 | add_prompt: false 66 | 67 | collate_fn: 68 | Q_size: 2 # num of acoustic token to predict 69 | semantic_to_coarse_ratio: 3 70 | semantic_window: 256 71 | semantic_token_num: 2048 72 | coarse_num: 1024 73 | slice_range: 60 # shift window size at inference in bark's source code 74 | 75 | dataloader: 76 | batch_size: 128 77 | shuffle: True 78 | num_workers: 32 79 | persistent_workers: True 80 | pin_memory: True 81 | 82 | ############## model params ############## 83 | model: 84 | n_layer: 6 85 | n_head: 4 86 | n_embd: 512 87 | block_size: 1025 # 256(semantic_window) + 256*3(window X ratio) + 1(infer token) 88 | bias: False 89 | dropout: 0.1 90 | input_vocab_size: 4098 # 2048(semantic) + 1024*2(coarse X Q_size) + 2(infer token, padding) 91 | output_vocab_size: 4098 92 | 93 | ############## optim params ############## 94 | optim: 95 | lr: 1e-4 96 | min_lr: 2e-5 97 | weight_decay: 1e-3 98 | warmup_iters: 1000 99 | max_iters: 20000 # depends on your GPUs. 
100 | lr_strategy: 'cosine' 101 | gradient_clip: 1 102 | 103 | 104 | hydra: 105 | run: 106 | dir: ${start_path}/${name}/stage_${stage} -------------------------------------------------------------------------------- /infer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b5becc02", 6 | "metadata": {}, 7 | "source": [ 8 | "## Infer" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "7a4f55af", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import tqdm\n", 19 | "import torch\n", 20 | "import torch.nn.functional as F\n", 21 | "from IPython.display import Audio\n", 22 | "from scipy.io.wavfile import write as write_wav\n", 23 | "\n", 24 | "import barkify.bark as bark \n", 25 | "from barkify.utils import Bestckpt\n", 26 | "from barkify.bark import create_infer_model\n", 27 | "from barkify.datas import PhonemeTokenizer\n", 28 | "\n", 29 | "from omegaconf import OmegaConf\n", 30 | "x_dict = OmegaConf.load(\"configs/barkify.yaml\")\n", 31 | "\n", 32 | "start_path = \"../work_env\" # your data folder." 
33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "df6dc569", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "TEXT_INPUT_LEN = x_dict.stage1.collate_fn.text_window\n", 43 | "TEXT_TOKEN_NUM = x_dict.stage1.collate_fn.text_token_num\n", 44 | "SEMANTIC_EOS_TOKEN, SEMANTIC_INFER_TOKEN = TEXT_TOKEN_NUM+1, TEXT_TOKEN_NUM+2\n", 45 | "\n", 46 | "COARSE_BOOK = x_dict.stage2.collate_fn.Q_size\n", 47 | "SEMANTIC_TOKEN_NUM = x_dict.stage2.collate_fn.semantic_token_num\n", 48 | "SEMANTIC_INPUT_LEN = x_dict.stage2.collate_fn.semantic_window\n", 49 | "CODEC_TOKEN_NUM = x_dict.stage2.collate_fn.coarse_num\n", 50 | "COARSE_INFER_TOKEN = SEMANTIC_TOKEN_NUM + 1\n", 51 | "\n", 52 | "stage1_model = create_infer_model(x_dict.stage1.model).cuda()\n", 53 | "stage2_model = create_infer_model(x_dict.stage2.model).cuda()\n", 54 | "\n", 55 | "tokenizer = PhonemeTokenizer()" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "276b89a2", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "ckpt = torch.load(Bestckpt(f\"{start_path}/{x_dict.name}/stage_1\"))['state_dict']\n", 66 | "stage1_model.load_state_dict({\".\".join(i.split(\"model.\")[1:]):ckpt[i] for i in ckpt})\n", 67 | "\n", 68 | "ckpt = torch.load(Bestckpt(f\"{start_path}/{x_dict.name}/stage_2\"))['state_dict']\n", 69 | "stage2_model.load_state_dict({\".\".join(i.split(\"model.\")[1:]):ckpt[i] for i in ckpt})" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "eb2e1bfe", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "def generate_stage1(\n", 80 | " x, \n", 81 | " model,\n", 82 | " tempature = 0.60,\n", 83 | " max_steps = 512,\n", 84 | "):\n", 85 | "\n", 86 | " kv_cache = None\n", 87 | "\n", 88 | " x = F.pad(x, (0, TEXT_INPUT_LEN-x.shape[1]), mode='constant', value=TEXT_TOKEN_NUM)\n", 89 | " x = torch.cat([\n", 90 | " x, \n", 91 | " torch.tensor([SEMANTIC_INFER_TOKEN], 
dtype=x.dtype, device=x.device)[None]\n", 92 | " ], dim=1)\n", 93 | " \n", 94 | " text_len = x.shape[1]\n", 95 | "\n", 96 | " for _ in tqdm.trange(max_steps):\n", 97 | " \n", 98 | " if kv_cache is not None:\n", 99 | " x_input = x[:, [-1]]\n", 100 | " else:\n", 101 | " x_input = x\n", 102 | "\n", 103 | " logits, kv_cache = model(x_input, use_cache=True, past_kv=kv_cache)\n", 104 | "\n", 105 | " relevant_logits = torch.hstack(\n", 106 | " (logits[0, 0, TEXT_TOKEN_NUM+3:], logits[0, 0, [SEMANTIC_EOS_TOKEN]])\n", 107 | " )\n", 108 | "\n", 109 | " probs = F.softmax(relevant_logits / tempature, dim=-1)\n", 110 | " item_next = torch.multinomial(probs, num_samples=1)\n", 111 | "\n", 112 | " if item_next == len(relevant_logits) - 1:\n", 113 | " break\n", 114 | "\n", 115 | " x = torch.cat((x, item_next[None]+TEXT_TOKEN_NUM+3), dim=1)\n", 116 | " \n", 117 | " return x[:, text_len:] - TEXT_TOKEN_NUM - 3" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "5f7f9642", 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "def generate_stage2(\n", 128 | " x, \n", 129 | " model,\n", 130 | " tempature = 0.6,\n", 131 | " max_steps = 768\n", 132 | "):\n", 133 | "\n", 134 | " kv_cache = None\n", 135 | " \n", 136 | " x = F.pad(x, (0, SEMANTIC_INPUT_LEN-x.shape[1]), mode='constant', value=SEMANTIC_TOKEN_NUM)\n", 137 | " x = torch.cat([\n", 138 | " x, \n", 139 | " torch.tensor([COARSE_INFER_TOKEN], dtype=x.dtype, device=x.device)[None]\n", 140 | " ], dim=1)\n", 141 | " \n", 142 | " semantic_len = x.shape[1]\n", 143 | "\n", 144 | " for i in tqdm.trange(max_steps):\n", 145 | "\n", 146 | " Q = i % COARSE_BOOK\n", 147 | " if kv_cache is not None:\n", 148 | " x_input = x[:, [-1]]\n", 149 | " else:\n", 150 | " x_input = x\n", 151 | "\n", 152 | " logits, kv_cache = model(x_input, use_cache=True, past_kv=kv_cache)\n", 153 | " start = SEMANTIC_TOKEN_NUM + 2 + Q * CODEC_TOKEN_NUM\n", 154 | " relevant_logits = logits[0, 0, start : start + 
CODEC_TOKEN_NUM]\n", 155 | " \n", 156 | " probs = F.softmax(relevant_logits / tempature, dim=-1)\n", 157 | " item_next = torch.multinomial(probs, num_samples=1)\n", 158 | " x = torch.cat((x, item_next[None]+start), dim=1)\n", 159 | " \n", 160 | " output = x[:, semantic_len:]\n", 161 | " for Q in range(COARSE_BOOK):\n", 162 | " output[:, Q::COARSE_BOOK] -= (SEMANTIC_TOKEN_NUM + 2 + Q * CODEC_TOKEN_NUM)\n", 163 | " \n", 164 | " return output.reshape(-1, COARSE_BOOK).T" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "id": "9b93eeec", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "tgt_text = \"At a given signal, they reenacted the event. Baker's movements were timed with a stopwatch.\"\n", 175 | "\n", 176 | "tokens = tokenizer(tgt_text)\n", 177 | "dummy_tokenized = torch.tensor([tokens]).cuda()\n", 178 | "dummy_semantic = generate_stage1(dummy_tokenized, model=stage1_model)\n", 179 | "dummy_coarse = generate_stage2(dummy_semantic, model=stage2_model)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "id": "9499e142", 186 | "metadata": { 187 | "scrolled": true 188 | }, 189 | "outputs": [], 190 | "source": [ 191 | "dummy_fine = bark.generate_fine(dummy_coarse.detach().cpu().numpy(), history_prompt=None)\n", 192 | "audio_array = bark.codec_decode(dummy_fine)\n", 193 | "\n", 194 | "# play text in notebook\n", 195 | "Audio(audio_array, rate=24000)\n", 196 | "\n", 197 | "# write_wav(\"bark_generation.wav\", 24000, audio_array)" 198 | ] 199 | } 200 | ], 201 | "metadata": { 202 | "kernelspec": { 203 | "display_name": "Python 3 (ipykernel)", 204 | "language": "python", 205 | "name": "python3" 206 | }, 207 | "language_info": { 208 | "codemirror_mode": { 209 | "name": "ipython", 210 | "version": 3 211 | }, 212 | "file_extension": ".py", 213 | "mimetype": "text/x-python", 214 | "name": "python", 215 | "nbconvert_exporter": "python", 216 | "pygments_lexer": "ipython3", 217 | 
"version": "3.8.15" 218 | }, 219 | "toc": { 220 | "base_numbering": 1, 221 | "nav_menu": {}, 222 | "number_sections": true, 223 | "sideBar": true, 224 | "skip_h1_title": false, 225 | "title_cell": "Table of Contents", 226 | "title_sidebar": "Contents", 227 | "toc_cell": false, 228 | "toc_position": { 229 | "height": "calc(100% - 180px)", 230 | "left": "10px", 231 | "top": "150px", 232 | "width": "165px" 233 | }, 234 | "toc_section_display": true, 235 | "toc_window_display": false 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 5 240 | } 241 | -------------------------------------------------------------------------------- /process.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b863990e", 6 | "metadata": {}, 7 | "source": [ 8 | "# Barkify: an unofficial repo for training 'bark'-like generative models\n", 9 | "\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "id": "bd7a97cb", 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from glob import glob\n", 20 | "from tqdm import tqdm\n", 21 | "import os\n", 22 | "import shutil\n", 23 | "import json\n", 24 | "import soundfile as sf\n", 25 | "import subprocess\n", 26 | "import numpy as np\n", 27 | "import re\n", 28 | "\n", 29 | "import random\n", 30 | "from IPython.display import Audio\n", 31 | "import IPython.display as iply\n", 32 | "\n", 33 | "from pqdm.processes import pqdm\n", 34 | "from pqdm.threads import pqdm as pqdmT\n", 35 | "\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "from matplotlib.pyplot import imshow\n", 38 | "\n", 39 | "import torch\n", 40 | "import torch.nn.functional as F\n", 41 | "\n", 42 | "start_path = \"../work_env\"\n", 43 | "start_path += '/'\n", 44 | "\n", 45 | "cuda_devices = [0,1,2,3]*3\n", 46 | "NJOB = len(cuda_devices)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "id": 
"cbc9c004", 
53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "def run_multiprocess(x):\n", 57 | " DEVICE, IJOB = x\n", 58 | " subprocess.call(f\"CUDA_VISIBLE_DEVICES={DEVICE} \"+\n", 59 | " f\"nohup python {os.path.join(start_path, 'tmp', 'temp.py')} {IJOB} \"+\n", 60 | " f\"> {os.path.join(start_path, 'tmp','tmp_'+str(IJOB))} 2>&1\",\n", 61 | " shell=True)\n", 62 | " \n", 63 | "def write_tmp(_script):\n", 64 | " with open(os.path.join(start_path, \"tmp\", \"temp.py\"), \"w\") as f:\n", 65 | " f.write(_script)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "id": "6397855b", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "# put all wav to /raw/ folder\n", 76 | "os.makedirs(start_path + \"wavs\", exist_ok=True)\n", 77 | "\n", 78 | "def run(x):\n", 79 | " # \n", 80 | " name = x.replace(\"/raw/\",\"/wavs/\")\n", 81 | " return subprocess.run(f\"ffmpeg -hide_banner -loglevel panic -y -i '{x}' -ac 1 -ar 24000 {name}\",\n", 82 | " shell=True)\n", 83 | "wavs_name = glob(os.path.join(start_path, \"raw/*.wav\"))\n", 84 | "res = pqdm(wavs_name, run, n_jobs=128)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "id": "a0d74e12", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "os.makedirs(start_path + \"wavs16k\", exist_ok=True)\n", 95 | "\n", 96 | "def run(x):\n", 97 | " # \n", 98 | " name = x.replace(\"/raw/\",\"/wavs16k/\")\n", 99 | " return subprocess.run(f\"ffmpeg -hide_banner -loglevel panic -y -i '{x}' -ac 1 -ar 16000 {name}\",\n", 100 | " shell=True)\n", 101 | "wavs_name = glob(os.path.join(start_path, \"raw/*.wav\"))\n", 102 | "res = pqdm(wavs_name, run, n_jobs=128)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "3acacc46", 108 | "metadata": {}, 109 | "source": [ 110 | "## Fetch features from wav2vec2-xlsr and encodec\n", 111 | "根据[meta的论文](https://arxiv.org/pdf/2105.11084.pdf),w2v中15-18层都取得了不错的结果.
\n", 112 | "在我们的实验中证实,w2v2-xlsr中第15层与bark的semantic idx的相关性最高.\n", 113 | "\n", 114 | "在[audioLM](https://arxiv.org/pdf/2209.03143.pdf)中,使用了w2v-bert的7层作为特征.在w2v中,相关性也非常高.
\n", 115 | "我们使用第15层作为实验的特征层.\n", 116 | "\n", 117 | "另外,我们测试了bark代码中,coarse2fine的部分.我们发现,coarse和fine均从encodec中直接得到.
\n", 118 | "因此,如果没有特殊需求,不建议重新训练.\n", 119 | "\n", 120 | "1. Fetch w2v2 hiddens.\n", 121 | "2. Cluster them by codes from fairseq.\n", 122 | "3. Dump cluster idxs to numpy files.\n", 123 | "4. Fetch discrete indices from encodec." 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "id": "fd82893a", 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "LAYER = 15 # use 15th layer.\n", 134 | "Hubert = False # use hubert feature or use w2v2-xlsr feature\n", 135 | "\n", 136 | "clusters = 2048 # semantic idx nums.\n", 137 | "dtype = 32 if clusters > 65535 else 16\n", 138 | "\n", 139 | "percent = 0.1 # use 10% datas for clustering. \n", 140 | "n_init = 10 # when this is larger than 10, some error may occur.\n", 141 | "\n", 142 | "params = dict( # params for clusting. \n", 143 | " init='k-means++', max_iter=100, batch_size=10000, \n", 144 | " tol=0, max_no_improvement=100, n_init=n_init, reassignment_ratio=0,\n", 145 | " compute_labels=False, verbose=100\n", 146 | " )\n", 147 | "params['n_clusters'] = clusters" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "c03a7c36", 153 | "metadata": {}, 154 | "source": [ 155 | "### Fetch semantic hiddens\n" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "60b61553", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "# process hubert/w2v2 data such that the Hz of semantic idx equals to 50 rather than 49.9\n", 166 | "def process_audio(path):\n", 167 | " wav, fs = sf.read(path)\n", 168 | " safe_length = (wav.shape[0] // 640) * 640 + 160\n", 169 | " wav = wav[:safe_length]\n", 170 | " wav = np.pad(wav, (safe_length - wav.shape[0], 0))\n", 171 | " sf.write(path, wav, fs)\n", 172 | " \n", 173 | "wavs_name = glob(start_path + \"/wavs16k/*.wav\")\n", 174 | "res = pqdmT(wavs_name, process_audio, n_jobs=32)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "id": 
"b3c3ef1f", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "os.makedirs(start_path + \"tmp\", exist_ok=True)\n", 185 | "os.makedirs(start_path + \"feats\", exist_ok=True)\n", 186 | "\n", 187 | "_script =f'''\n", 188 | "from transformers import AutoProcessor, AutoModelForPreTraining, AutoModel\n", 189 | "if {Hubert}:\n", 190 | " model = AutoModel.from_pretrained(\"TencentGameMate/chinese-hubert-large\")\n", 191 | "else:\n", 192 | " model = AutoModelForPreTraining.from_pretrained(\"facebook/wav2vec2-large-xlsr-53\")\n", 193 | "processor = AutoProcessor.from_pretrained(\"facebook/wav2vec2-base-960h\")\n", 194 | "print(\"downloaded!\")\n", 195 | "'''\n", 196 | "write_tmp(_script) # download it.\n", 197 | "run_multiprocess((0, 0))\n", 198 | " \n", 199 | "_script += f'''\n", 200 | "import os\n", 201 | "import sys\n", 202 | "import torch\n", 203 | "import torch.nn.functional as F\n", 204 | "import numpy as np\n", 205 | "import soundfile as sf\n", 206 | "\n", 207 | "from tqdm import tqdm\n", 208 | "from glob import glob\n", 209 | "\n", 210 | "device = 'cuda:0'\n", 211 | "model = model.to(device)\n", 212 | "\n", 213 | "start_path = '{start_path}'\n", 214 | "NJOB={NJOB}\n", 215 | "meta = glob(start_path+'/wavs16k/*.wav')\n", 216 | "slice_len = (len(meta) + NJOB - 1) // NJOB\n", 217 | "meta = meta[int(sys.argv[1])*slice_len : (int(sys.argv[1])+1)*slice_len]\n", 218 | "\n", 219 | "for _dir in tqdm(meta):\n", 220 | " audio, fs = sf.read(_dir)\n", 221 | " assert fs == 16000\n", 222 | " inputs = processor(audio, sampling_rate=16000, return_tensors=\"pt\")\n", 223 | " for i in inputs:\n", 224 | " inputs[i] = inputs[i].cuda()\n", 225 | " with torch.no_grad():\n", 226 | " hidden = model(**inputs, output_hidden_states=True)['hidden_states'][{LAYER} - 1]\n", 227 | " hidden = F.layer_norm(hidden, hidden.shape)\n", 228 | " np.save(_dir.replace(\"wavs16k\",\"feats\").replace(\".wav\",\"\"), hidden[0].cpu().numpy())\n", 229 | "\n", 230 | "print(\"Finish!\")\n", 231 | 
"'''" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "id": "02f292ef", 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "write_tmp(_script)\n", 242 | "res = pqdm(list(zip(cuda_devices, list(range(NJOB)))), run_multiprocess, n_jobs=NJOB)" 243 | ] 244 | }, 245 | { 246 | "cell_type": "markdown", 247 | "id": "e1cd99b9", 248 | "metadata": {}, 249 | "source": [ 250 | "### Clustering" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "id": "ced8ded9", 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "# Basically from: https://github.com/facebookresearch/fairseq/tree/main/examples/hubert/simple_kmeans\n", 261 | "# TODO: It seems pretty slow. Maybe try faiss or conduct PCA before clustering?\n", 262 | "\n", 263 | "os.makedirs(start_path + \"tmp\", exist_ok=True)\n", 264 | "os.makedirs(start_path + \"assets\", exist_ok=True)\n", 265 | "\n", 266 | "_script = f'''\n", 267 | "import os\n", 268 | "import sys\n", 269 | "import numpy as np\n", 270 | "import random\n", 271 | "\n", 272 | "from tqdm import tqdm\n", 273 | "from glob import glob\n", 274 | "import joblib\n", 275 | "\n", 276 | "from sklearn.cluster import MiniBatchKMeans\n", 277 | "\n", 278 | "params = {params}\n", 279 | "kmeans = MiniBatchKMeans(**params)\n", 280 | "\n", 281 | "start_path = '{start_path}'\n", 282 | "meta = glob(start_path+'/feats/*.npy')\n", 283 | "random.shuffle(meta)\n", 284 | "meta = meta[ : int(len(meta)*{percent})]\n", 285 | "meta = np.concatenate(\n", 286 | " [np.load(i) for i in meta], axis = 0\n", 287 | ")\n", 288 | "print(\"concated.\")\n", 289 | "\n", 290 | "kmeans.fit(meta)\n", 291 | "joblib.dump(kmeans, start_path + \"assets/km_model.joblib\")\n", 292 | "\n", 293 | "inertia = -kmeans.score(meta) / len(meta)\n", 294 | "print(\"total inertia: %.5f\" % inertia)\n", 295 | "print(\"Finish!\")\n", 296 | "'''" 297 | ] 298 | }, 299 | { 300 | 
"cell_type": "code", 301 | "execution_count": 
null, 302 | "id": "6594087e", 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "write_tmp(_script)\n", 307 | "\n", 308 | "# Run the clustering script with BLAS/OpenMP threads capped; without a thread limit, some error may occur.\n", 309 | "!OPENBLAS_NUM_THREADS=16 OMP_NUM_THREADS=16 python {start_path + '/tmp/temp.py'}" 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "id": "9396094f", 315 | "metadata": {}, 316 | "source": [ 317 | "### Infer semantic indices\n" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "id": "cd16049f", 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [ 327 | "os.makedirs(start_path + \"tmp\", exist_ok=True)\n", 328 | "os.makedirs(start_path + \"semantic_idx\", exist_ok=True)\n", 329 | "\n", 330 | "_script = f'''\n", 331 | "import os\n", 332 | "import sys\n", 333 | "import numpy as np\n", 334 | "import random\n", 335 | "\n", 336 | "from tqdm import tqdm\n", 337 | "from glob import glob\n", 338 | "import joblib\n", 339 | "\n", 340 | "class ApplyKmeans(object):\n", 341 | " def __init__(self, km_path):\n", 342 | " self.km_model = joblib.load(km_path)\n", 343 | " self.C_np = self.km_model.cluster_centers_.transpose()\n", 344 | " self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)\n", 345 | "\n", 346 | " def __call__(self, x):\n", 347 | " dist = (\n", 348 | " (x ** 2).sum(1, keepdims=True)\n", 349 | " - 2 * np.matmul(x, self.C_np)\n", 350 | " + self.Cnorm_np\n", 351 | " )\n", 352 | " return np.argmin(dist, axis=1)\n", 353 | " \n", 354 | "\n", 355 | "start_path = '{start_path}'\n", 356 | "NJOB={NJOB}\n", 357 | "meta = glob(start_path+'/feats/*.npy')\n", 358 | "slice_len = (len(meta) + NJOB - 1) // NJOB\n", 359 | "meta = meta[int(sys.argv[1])*slice_len : (int(sys.argv[1])+1)*slice_len]\n", 360 | "\n", 361 | "apply_kmeans = ApplyKmeans(start_path + '/assets/km_model.joblib')\n", 362 | "\n", 363 | "for _dir in tqdm(meta):\n", 364 | " 
_idxs = apply_kmeans(np.load(_dir)).astype(np.int{dtype})\n", 365 | 
np.save(_dir.replace(\"feats\",\"semantic_idx\"), _idxs)\n", 366 | " \n", 367 | "print(\"Finish!\")\n", 368 | "'''" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "id": "9d96be9b", 375 | "metadata": {}, 376 | "outputs": [], 377 | "source": [ 378 | "write_tmp(_script)\n", 379 | "res = pqdm(list(zip(cuda_devices, list(range(NJOB)))), run_multiprocess, n_jobs=NJOB) " 380 | ] 381 | }, 382 | { 383 | "cell_type": "markdown", 384 | "id": "40a05724", 385 | "metadata": {}, 386 | "source": [ 387 | "### Fetch discrete indices from encodec\n" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "id": "30a2ac29", 394 | "metadata": {}, 395 | "outputs": [], 396 | "source": [ 397 | "os.makedirs(start_path + \"tmp\", exist_ok=True)\n", 398 | "os.makedirs(start_path + \"encodec_idx\", exist_ok=True)\n", 399 | "\n", 400 | "_script ='''\n", 401 | "from encodec import EncodecModel\n", 402 | "from encodec.utils import convert_audio\n", 403 | "\n", 404 | "model = EncodecModel.encodec_model_24khz()\n", 405 | "model.set_target_bandwidth(6.0)\n", 406 | "print(\"downloaded!\")\n", 407 | "'''\n", 408 | "write_tmp(_script) # download it.\n", 409 | "run_multiprocess((0, 0))\n", 410 | " \n", 411 | "_script += f'''\n", 412 | "import os\n", 413 | "import sys\n", 414 | "import torch\n", 415 | "import numpy as np\n", 416 | "import torchaudio\n", 417 | "\n", 418 | "from tqdm import tqdm\n", 419 | "from glob import glob\n", 420 | "\n", 421 | "device = 'cuda:0'\n", 422 | "model = model.to(device)\n", 423 | "\n", 424 | "start_path = '{start_path}'\n", 425 | "NJOB={NJOB}\n", 426 | "meta = glob(start_path+'/wavs/*.wav')\n", 427 | "slice_len = (len(meta) + NJOB - 1) // NJOB\n", 428 | "meta = meta[int(sys.argv[1])*slice_len : (int(sys.argv[1])+1)*slice_len]\n", 429 | "\n", 430 | "for _dir in tqdm(meta):\n", 431 | " wav, sr = torchaudio.load(_dir)\n", 432 | " wav = wav[:, :wav.shape[-1] - int(160*1.5)]\n", 433 | " # wav = 
convert_audio(wav, sr, model.sample_rate, model.channels)\n", 434 | " wav = wav.unsqueeze(0).cuda()\n", 435 | " with torch.no_grad():\n", 436 | " encoded_frames = model.encode(wav)[0][0][0]\n", 437 | " np.save(_dir.replace(\"wavs\",\"encodec_idx\").replace(\".wav\",\"\"), encoded_frames.cpu().numpy().astype(np.int16))\n", 438 | "\n", 439 | "print(\"Finish!\")\n", 440 | "'''" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "id": "5741ee0c", 447 | "metadata": {}, 448 | "outputs": [], 449 | "source": [ 450 | "write_tmp(_script)\n", 451 | "res = pqdm(list(zip(cuda_devices, list(range(NJOB)))), run_multiprocess, n_jobs=NJOB) " 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "id": "85b774eb", 457 | "metadata": {}, 458 | "source": [ 459 | "## Prepare dataset" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": null, 465 | "id": "3bf05fff", 466 | "metadata": {}, 467 | "outputs": [], 468 | "source": [ 469 | "# min_length = 2 # at least 2 seconds.\n", 470 | "# _min_length = int(np.floor(min_length * 50))\n", 471 | "for_eval = 128\n", 472 | "\n", 473 | "os.makedirs(start_path + \"meta\", exist_ok=True)" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": null, 479 | "id": "acb150c3", 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [ 483 | "import pandas as pd\n", 484 | "import json\n", 485 | "datas = pd.read_csv(start_path+\"meta/metadata.csv\",sep=\"|\", header=None)\n", 486 | "datas = datas.dropna()\n", 487 | "datas = datas.values\n", 488 | "np.random.shuffle(datas)\n", 489 | "\n", 490 | "with open(start_path + \"meta/train.json\",\"w\") as f:\n", 491 | " for i in datas[:-for_eval]:\n", 492 | " line = json.dumps({\"name\": i[0] +\".npy\", \"text\": i[1]})\n", 493 | " f.writelines(line+\"\\n\")\n", 494 | "\n", 495 | "with open(start_path + \"meta/eval.json\",\"w\") as f:\n", 496 | " for i in datas[-for_eval:]:\n", 497 | " line = json.dumps({\"name\": i[0] 
+\".npy\", \"text\": i[1]})\n", 498 | " f.writelines(line+\"\\n\")" 499 | ] 500 | } 501 | ], 502 | "metadata": { 503 | "kernelspec": { 504 | "display_name": "Python 3 (ipykernel)", 505 | "language": "python", 506 | "name": "python3" 507 | }, 508 | "language_info": { 509 | "codemirror_mode": { 510 | "name": "ipython", 511 | "version": 3 512 | }, 513 | "file_extension": ".py", 514 | "mimetype": "text/x-python", 515 | "name": "python", 516 | "nbconvert_exporter": "python", 517 | "pygments_lexer": "ipython3", 518 | "version": "3.8.15" 519 | }, 520 | "toc": { 521 | "base_numbering": 1, 522 | "nav_menu": {}, 523 | "number_sections": true, 524 | "sideBar": true, 525 | "skip_h1_title": false, 526 | "title_cell": "Table of Contents", 527 | "title_sidebar": "Contents", 528 | "toc_cell": false, 529 | "toc_position": { 530 | "height": "calc(100% - 180px)", 531 | "left": "10px", 532 | "top": "150px", 533 | "width": "181.663px" 534 | }, 535 | "toc_section_display": true, 536 | "toc_window_display": false 537 | } 538 | }, 539 | "nbformat": 4, 540 | "nbformat_minor": 5 541 | } 542 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | funcy 2 | numpy 3 | scipy 4 | tqdm 5 | torch==1.13.0 6 | pytorch_lightning==1.9.0 7 | encodec 8 | transformers -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import hydra 4 | from glob import glob 5 | 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | 9 | from barkify.datas import StageDataloader 10 | from barkify.pl_model import NanoGPT 11 | from barkify.utils import Bestckpt 12 | 13 | @hydra.main(config_path='configs', config_name='barkify') 14 | def main(cfg=None): 15 | 16 | exp_name = f'stage_{cfg.stage}' 17 | exp_cfg = 
cfg[f'stage{cfg.stage}'] 18 | exp_root_dir = f'{cfg.start_path}/{cfg.name}/{exp_name}' 19 | 20 | # define datas 21 | train_loader = StageDataloader(exp_cfg, cfg.stage, 'train') 22 | val_loader = StageDataloader(exp_cfg, cfg.stage, 'eval') 23 | 24 | # define model 25 | model = NanoGPT(model_config = exp_cfg.model, **exp_cfg.optim) 26 | 27 | # define trainer 28 | trainer = pl.Trainer( 29 | default_root_dir=exp_root_dir, 30 | callbacks = ModelCheckpoint(**cfg.common.ckpt), 31 | max_steps = exp_cfg.optim.max_iters, 32 | gradient_clip_val = exp_cfg.optim.gradient_clip, 33 | **cfg.common.trainer 34 | ) 35 | trainer.logger._default_hp_metric = None # Optional logging argument that we don't need 36 | 37 | # load best ckpt 38 | ckpt = Bestckpt(exp_root_dir) 39 | trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt) 40 | 41 | if __name__ == '__main__': 42 | torch.set_float32_matmul_precision('medium') 43 | main() 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | --------------------------------------------------------------------------------