├── old_out.wav ├── general_out.wav ├── LJ003-0259-synth.wav ├── LJ016-0073-synth.wav ├── LICENSE ├── g2p_util.py ├── encodec_util.py ├── .gitignore ├── README.md ├── attend.py ├── ljspeech.py ├── train.py ├── megabyte.py └── main.ipynb /old_out.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiscellaneousStuff/PhoneLM/HEAD/old_out.wav -------------------------------------------------------------------------------- /general_out.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiscellaneousStuff/PhoneLM/HEAD/general_out.wav -------------------------------------------------------------------------------- /LJ003-0259-synth.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiscellaneousStuff/PhoneLM/HEAD/LJ003-0259-synth.wav -------------------------------------------------------------------------------- /LJ016-0073-synth.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiscellaneousStuff/PhoneLM/HEAD/LJ016-0073-synth.wav -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MiscellaneousStuff 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /g2p_util.py: -------------------------------------------------------------------------------- 1 | from g2p_en import G2p 2 | 3 | import torch 4 | import random 5 | import string 6 | from functools import cache 7 | from tqdm import tqdm 8 | 9 | @cache 10 | def _get_model(): 11 | return G2p() 12 | 13 | @cache 14 | def _get_graphs(path): 15 | with open(path, "r") as f: 16 | graphs = f.read() 17 | return graphs 18 | 19 | def encode_text(graphs: str) -> list[str]: 20 | g2p = _get_model() 21 | phones = g2p(graphs) 22 | ignored = {" ", *string.punctuation} 23 | return ["_" if p in ignored else p for p in phones] 24 | 25 | def encode_text_direct(text): 26 | g2p = _get_model() 27 | phones = g2p(text) 28 | ignored = {" ", *string.punctuation} 29 | return ["_" if p in ignored else p for p in phones] 30 | 31 | @torch.no_grad() 32 | def write_phones(folder, suffix=".normalized.txt"): 33 | paths = list(folder.rglob(f"*{suffix}")) 34 | random.shuffle(paths) 35 | 36 | for path in tqdm(paths): 37 | phone_path = path.with_name(path.stem.split(".")[0] + ".phn.txt") 38 | if phone_path.exists(): 39 | continue 40 | graphs = _get_graphs(path) 41 | phones = encode_text(graphs) 42 | with open(phone_path, "w") as f: 43 | f.write(" ".join(phones)) -------------------------------------------------------------------------------- /encodec_util.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import soundfile 4 | from pathlib import Path 5 | 6 | import torch 7 | from torch import Tensor 8 | import torchaudio 9 | 10 | from functools import cache 11 | from tqdm import tqdm 12 | from einops import rearrange 13 | 14 | from encodec import EncodecModel 15 | from encodec.utils import convert_audio 16 | 17 | SAMPLE_RATE = 24_000 18 | BANDWIDTH = 1.5 # 6.0 19 | 20 | @cache 21 | def _load_model(bandwidth=6.0, device="cuda"): 22 | # Instantiate a pretrained EnCodec model 23 | assert SAMPLE_RATE == 24_000 24 | model = EncodecModel.encodec_model_24khz() 25 | model.set_target_bandwidth(bandwidth) 26 | model.to(device) 27 | return model 28 | 29 | def unload_model(): 30 | return _load_model.cache_clear() 31 | 32 | @torch.inference_mode() 33 | def decode(codes: Tensor, bandwidth=6.0, device="cuda"): 34 | """ 35 | Args: 36 | codes: (b q t) 37 | """ 38 | assert codes.dim() == 3 39 | model = _load_model(bandwidth, device) 40 | return model.decode([(codes, None)]), model.sample_rate 41 | 42 | def decode_to_file(resps: Tensor, path: Path): 43 | assert resps.dim() == 2, f"Require shape (t q), but got {resps.shape}." 
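# (t q) -> (1 q t): add a batch dimension and move the codebook axis first, giving the (b q t) layout that decode() expects.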
44 | resps = rearrange(resps, "t q -> 1 q t") 45 | wavs, sr = decode(codes=resps, bandwidth=BANDWIDTH) 46 | soundfile.write(str(path), wavs.cpu()[0, 0], sr) 47 | 48 | def _replace_file_extension(path, suffix): 49 | return (path.parent / path.name.split(".")[0]).with_suffix(suffix) 50 | 51 | @torch.inference_mode() 52 | def encode(wav: Tensor, sr: int, bandwidth=6.0, device="cuda"): 53 | """ 54 | Args: 55 | wav: (t) 56 | sr: int 57 | """ 58 | model = _load_model(bandwidth, device) 59 | wav = wav.unsqueeze(0) 60 | wav = convert_audio(wav, sr, model.sample_rate, model.channels) 61 | wav = wav.to(device) 62 | encoded_frames = model.encode(wav) 63 | qnt = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # (b q t) 64 | return qnt 65 | 66 | def encode_from_file(path, bandwidth=6.0, device="cuda"): 67 | wav, sr = torchaudio.load(str(path)) 68 | if wav.shape[0] == 2: 69 | wav = wav[:1] 70 | return encode(wav, sr, bandwidth, device) 71 | 72 | def quantize_audio(folder, suffix=".wav"): 73 | paths = [*folder.rglob(f"*{suffix}")] 74 | random.shuffle(paths) 75 | 76 | for path in tqdm(paths): 77 | out_path = _replace_file_extension(path, ".qnt.pt") 78 | if out_path.exists(): 79 | continue 80 | qnt = encode_from_file(path, BANDWIDTH) 81 | print(qnt.shape) 82 | torch.save(qnt.cpu(), out_path) 83 | 84 | def decode_files(folder, suffix=".qnt.pt"): 85 | paths = [*folder.rglob(f"*{suffix}")] 86 | random.shuffle(paths) 87 | 88 | for path in tqdm(paths): 89 | out_path = _replace_file_extension(path, ".qt.wav") 90 | if out_path.exists(): 91 | continue 92 | fi = rearrange(torch.load(path).squeeze(0).cuda(), "q t -> t q") 93 | decode_to_file(fi, out_path) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | # Ignore datasets (LJSpeech, Personal Audio) 163 | /data 164 | test.m4a 165 | out.wav 166 | 167 | # Ignore checkpoints 168 | *.pt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PhoneLM 2 | 3 | ## About 4 | 5 | UPDATE (04/09/2023): Model overfitting on single and multiple samples works. 6 | Generalisation seems to be harder, probably because the original MegaByte model from 7 | lucidrains is trained to predict the entire sequence from token 1 to token n-1. Need to 8 | change the training method to only predict the response and not the prompt, otherwise the model 9 | has to pointlessly learn to predict the prompt as well. 10 | 11 | UPDATE: Generalisation training seems somewhat promising. The model consistently outputs 12 | the correct number of audio tokens and can deal with the temporal context somewhat well. 13 | However, the main issue seems to be more with the "spatial" component of predicting the sequence, 14 | i.e., predicting the correct codebook codes per timestep. 15 | 16 | Text to speech using phonemes as inputs and audio codec codes as outputs. Loosely based on MegaByte, VALL-E and Encodec. 17 | 18 | ## Method 19 | 20 | - [x] Use [G2P](https://github.com/Kyubyong/g2p/) to encode text. 21 | - [x] Use [encodec](https://github.com/facebookresearch/encodec) to 22 | encode and decode audio. 23 | - [x] Custom LJSpeech dataloader to include phonemes and encodec audio codes 24 | 25 | ### LJSpeech 26 | 27 | - [x] Overfit model on one sample from LJSpeech 28 | - [x] Combine token space of text and audio codec codes 29 | - `LJ016-0073-synth.wav` The initial "Mr. Cope" can just about be made out 30 | - Using only 2 codebooks seems to be too aggressive. 31 | - `LJ003-0259-synth.wav` "And attracted attention by their". Two codebooks is possible. 32 | The main issue is sequence length. 33 | - Scaling up sequence length is easier than scaling up codebook size. This is for the 34 | arrangement [time_1_code_1, time_1_code_2, ...]. 35 | Perhaps [time_1_code_1, time_2_code_1, ...] might perform better? That is, synthesize all of codebook 1, then all of codebook 2. 36 | - Longer duration prompts and audio targets seem to perform worse. Will try experimenting 37 | with shorter prompts (try to stick to roughly 3-second audio snippets). 38 | - [-] Generalise (Using either 1 second prompt + clip, or 1.5 sec prompt and clip) 39 | - [x] Get any prompt-to-audio working (even if unintelligible and using clamping) 40 | - [-] Get any coherent output 41 | 42 | 47 | 48 | ## Inspiration 49 | 50 | This model is loosely based on the VALL-E paper by Microsoft. It uses the 51 | MegaByte-inspired model from [Lucidrains](https://github.com/lucidrains/MEGABYTE-pytorch) 52 | as the Transformer Decoder model. Just as in VALL-E, a user's text prompt is converted 53 | into phonemes using [G2P](https://github.com/Kyubyong/g2p/) (Grapheme-to-phoneme), 54 | and then the [encodec](https://github.com/facebookresearch/encodec) audio codec codes 55 | are predicted. However, unlike VALL-E, only an autoregressive model is used. The VALL-E 56 | paper uses an autoregressive (AR) model to accept phonemes and audio codec code snippets of 57 | a source audio and uses that to predict the first codebook's codes. The rest of the codebook 58 | codes are then predicted by a separate non-autoregressive (NAR) model: once the AR model is finished, the NAR model accepts the entire sequence 59 | and predicts all of the codebook 2 to codebook N codes.
However, this increases 60 | the complexity of the approach, as two models are now required, and raises the possibility 61 | that the NAR model cannot attend to all past inputs (unlike the AR model), which can reduce 62 | output audio quality and may lead to repeated outputs. In practice, the use of phonemes 63 | as input into VALL-E may alleviate this; this approach, however, explores just predicting 64 | the entire sequence auto-regressively (across all codebooks at once). 65 | 66 | This is inspired by the fact that the authors of the original [MegaByte](https://arxiv.org/pdf/2305.07185.pdf) 67 | paper perform autoregressive audio prediction on raw audio data. They 68 | treat the audio files as just raw byte sequences, train a model on 2TB 69 | worth of audio, and find that compared to vanilla Transformer or Perceiver architectures, 70 | it scores a lower bpb (bits per byte). In principle, this means that the model is more efficient and accurate 71 | at modelling raw audio byte sequences than other approaches. Another benefit of the method 72 | is that the patch-based auto-regressive generation may be well suited to the codebooks used 73 | by [encodec](https://github.com/facebookresearch/encodec). As the patch size can be set to 4 74 | (for 4 codebooks, each of which can take 1 of 1024 values), the local model of the 75 | MegaByte architecture can focus on modelling individual audio codec elements and the global model 76 | can focus on the larger context. Hopefully this greatly improves audio quality compared to 77 | VALL-E while being much simpler to train. -------------------------------------------------------------------------------- /attend.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from functools import wraps 3 | from packaging import version 4 | 5 | import torch 6 | from torch import nn, einsum 7 | import torch.nn.functional as F 8 | 9 | from einops import rearrange 10 | 11 | # constants 12 | 13 | EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) 14 | 15 | # helpers 16 | 17 | def exists(val): 18 | return val is not None 19 | 20 | def once(fn): 21 | called = False 22 | @wraps(fn) 23 | def inner(x): 24 | nonlocal called 25 | if called: 26 | return 27 | called = True 28 | return fn(x) 29 | return inner 30 | 31 | print_once = once(print) 32 | 33 | # main class 34 | 35 | class Attend(nn.Module): 36 | def __init__( 37 | self, 38 | causal = False, 39 | dropout = 0., 40 | flash = False 41 | ): 42 | super().__init__() 43 | self.dropout = dropout 44 | self.attn_dropout = nn.Dropout(dropout) 45 | 46 | self.causal = causal 47 | self.flash = flash 48 | assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 49 | 50 | # determine efficient attention configs for cuda and cpu 51 | 52 | self.cpu_config = EfficientAttentionConfig(True, True, True) 53 | self.cuda_config = None 54 | 55 | if not torch.cuda.is_available() or not flash: 56 | return 57 | 58 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 59 | 60 | if device_properties.major == 8 and device_properties.minor == 0: 61 | print_once('A100+ GPU detected, using flash attention if input tensor is on cuda') 62 | self.cuda_config = EfficientAttentionConfig(True, False, False) 63 | else: 64 | print_once('Non-A100 GPU detected, using math or mem efficient attention
if input tensor is on cuda') 65 | self.cuda_config = EfficientAttentionConfig(False, True, True) 66 | 67 | def get_mask(self, i, j, device): 68 | return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) 69 | 70 | def flash_attn(self, q, k, v, mask = None, attn_bias = None): 71 | _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device 72 | 73 | # single headed key / values 74 | 75 | if k.ndim == 3: 76 | k = rearrange(k, 'b n d -> b 1 n d') 77 | 78 | if v.ndim == 3: 79 | v = rearrange(v, 'b n d -> b 1 n d') 80 | 81 | # Check if mask exists and expand to compatible shape 82 | # The mask is B L, so it would have to be expanded to B H N L 83 | 84 | if exists(mask) and mask.ndim != 4: 85 | mask = rearrange(mask, 'b j -> b 1 1 j') 86 | mask = mask.expand(-1, heads, q_len, -1) 87 | 88 | # Check if there is a compatible device for flash attention 89 | 90 | config = self.cuda_config if is_cuda else self.cpu_config 91 | 92 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale 93 | 94 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 95 | out = F.scaled_dot_product_attention( 96 | q, k, v, 97 | attn_mask = mask, 98 | dropout_p = self.dropout if self.training else 0., 99 | is_causal = self.causal 100 | ) 101 | 102 | return out 103 | 104 | def forward(self, q, k, v, mask = None): 105 | """ 106 | einstein notation 107 | b - batch 108 | h - heads 109 | n, i, j - sequence length (base sequence length, source, target) 110 | d - feature dimension 111 | """ 112 | 113 | q_len, k_len, device = q.shape[-2], k.shape[-2], q.device 114 | 115 | scale = q.shape[-1] ** -0.5 116 | 117 | kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' 118 | 119 | if self.flash: 120 | return self.flash_attn(q, k, v, mask = mask) 121 | 122 | # similarity 123 | 124 | sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale 125 | 126 | # causal mask 127 | 128 | if self.causal: 129 | causal_mask = self.get_mask(q_len, k_len, device) 130 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 131 | 132 | # attention 133 | 134 | attn = sim.softmax(dim=-1) 135 | attn = self.attn_dropout(attn) 136 | 137 | # aggregate values 138 | 139 | out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) 140 | 141 | return out -------------------------------------------------------------------------------- /ljspeech.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified version of torchaudio dataset class for LJSpeech. 3 | Instead of returning the waveform, it uses encodec to return the audio file codes. 4 | """ 5 | 6 | import csv 7 | import os 8 | from pathlib import Path 9 | from typing import Tuple, Union 10 | 11 | import torch 12 | import torchaudio 13 | from torch import Tensor 14 | from torch.utils.data import Dataset 15 | 16 | from g2p_en import G2p 17 | from g2p_util import encode_text_direct, _get_model 18 | from encodec_util import encode, decode 19 | 20 | _RELEASE_CONFIGS = { 21 | "release1": { 22 | "folder_in_archive": "wavs", 23 | "url": "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2", 24 | "checksum": "be1a30453f28eb8dd26af4101ae40cbf2c50413b1bb21936cbcdc6fae3de8aa5", 25 | } 26 | } 27 | 28 | class LJSPEECH(Dataset): 29 | """*LJSpeech-1.1* :cite:`ljspeech17` dataset. 30 | 31 | Args: 32 | root (str or Path): Path to the directory where the dataset is found. 33 | folder_in_archive (str, optional): 34 | The top-level directory of the dataset. (default: ``"wavs"``). 
35 | """ 36 | 37 | def __init__( 38 | self, 39 | root: Union[str, Path], 40 | encodec_bandwidth: float = 6.0, 41 | folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"], 42 | max_prompt_length: int = 60 43 | ) -> None: 44 | 45 | self._parse_filesystem(root, folder_in_archive, max_prompt_length) 46 | self.encodec_bandwidth = encodec_bandwidth 47 | self.phone_dict = _get_model().phonemes + ["_"] 48 | 49 | def _parse_filesystem( 50 | self, 51 | root: str, 52 | folder_in_archive: str, 53 | max_prompt_length: int) -> None: 54 | 55 | root = Path(root) 56 | 57 | basename = os.path.basename(_RELEASE_CONFIGS["release1"]["url"]) 58 | 59 | basename = Path(basename.split(".tar.bz2")[0]) 60 | folder_in_archive = basename / folder_in_archive 61 | 62 | self._path = root / folder_in_archive 63 | self._metadata_path = root / basename / "metadata.csv" 64 | 65 | if not os.path.exists(self._path): 66 | raise RuntimeError( 67 | f"The path {self._path} doesn't exist. " 68 | "Please check the ``root`` path" 69 | ) 70 | 71 | with open(self._metadata_path, "r", newline="", encoding="utf-8") as metadata: 72 | flist = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE) 73 | self._flist = list(flist) 74 | if max_prompt_length: 75 | self._flist = [item 76 | for item in self._flist 77 | if len(item[2]) <= max_prompt_length] 78 | # if max_prompt_length: 79 | # self._flist = [item 80 | # for item in self._flist 81 | # if len(item[1]) <= max_prompt_length] 82 | # for i in range(len(self._flist)): 83 | # item = self._flist[i][1] 84 | # self._flist[i][1] = " ".join(item.split(" ")[0:max_prompt_length//6]) 85 | 86 | def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]: 87 | """Load the n-th sample from the dataset. 88 | 89 | Args: 90 | n (int): The index of the sample to be loaded 91 | 92 | Returns: 93 | Tuple of the following items; 94 | 95 | Tensor: 96 | Waveform 97 | int: 98 | Sample rate 99 | str: 100 | Transcript 101 | str: 102 | Normalized Transcript 103 | """ 104 | line = self._flist[n] 105 | fileid, transcript, normalized_transcript = line 106 | fileid_audio = self._path / (fileid + ".wav") 107 | 108 | # Load audio 109 | waveform, sample_rate = torchaudio.load(fileid_audio) 110 | 111 | # G2P and Encodec 112 | phones = encode_text_direct(normalized_transcript) 113 | phone_ids = torch.tensor( 114 | [self.phone_dict.index(phone) 115 | if phone in self.phone_dict else 0 116 | for phone in phones]).long().cuda() 117 | codes = encode(waveform, sample_rate, self.encodec_bandwidth) 118 | 119 | return ( 120 | fileid_audio, 121 | waveform, 122 | sample_rate, 123 | transcript, 124 | normalized_transcript, 125 | phones, 126 | phone_ids, 127 | codes 128 | ) 129 | 130 | def __len__(self) -> int: 131 | return len(self._flist) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import megabyte 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange 5 | 6 | import torchaudio 7 | from ljspeech import LJSPEECH 8 | DATASET_PATH = "./data/LJSpeech/" 9 | 10 | from torch.utils.data import DataLoader, SubsetRandomSampler 11 | from sklearn.model_selection import train_test_split 12 | 13 | from einops import rearrange 14 | from tqdm import tqdm 15 | from encodec_util import decode_to_file 16 | 17 | import torch.optim as optim 18 | import torch.nn.functional as F 19 | from torch.cuda.amp.grad_scaler import GradScaler 20 | 21 | import math 22 | import random 23 | 
import numpy as np 24 | 25 | from g2p_util import _get_model 26 | 27 | SEED = 42 28 | 29 | random.seed(SEED) 30 | torch.manual_seed(SEED) 31 | np.random.seed(SEED) 32 | 33 | MAX_LR = 1e-3 # 1e-3 34 | MIN_LR = 1e-4 35 | # MAX_LR = 1e-2 36 | WEIGHT_DECAY = 1e-4 37 | GRAD_CLIP = 0.1 38 | WARMUP_ITERS = 500 # 500 # Taken from from MegaByte paper 39 | EPOCHS = 100 # 100 # 200 # 100 # 1000 40 | PRINT_INTERVAL = 1 41 | SEQ_LEN = 1024 # 512 42 | BATCH_SIZE = 48 # 32 # 96 # 128 # 64 43 | DECAY_LR = True 44 | NUM_BATCHES = None # EPOCHS * # int(1e5) 45 | 46 | BANDWIDTH_IDX = 0 # original VALL-E 47 | CODEBOOKS = [2, 4, 8, 16, 32] 48 | BANDWIDTHS = [1.5, 3.0, 6.0, 12.0, 24.0] 49 | BANDWIDTH = BANDWIDTHS[BANDWIDTH_IDX] 50 | CODEBOOK = CODEBOOKS[BANDWIDTH_IDX] 51 | MAX_CLIP_LENGTH = int(5) 52 | MAX_PROMPT_LENGTH = int(30 * MAX_CLIP_LENGTH) 53 | 54 | VALIDATE_EVERY = 1 55 | 56 | AMP = True 57 | SAVE = True 58 | 59 | def get_reserved_mem_gb(): 60 | device = torch.cuda.current_device() 61 | reserved = torch.cuda.memory_reserved(device) 62 | reserved_gb = reserved / 1024 / 1024 / 1024 63 | return reserved_gb 64 | 65 | # Taken from: https://github.com/karpathy/nanoGPT/blob/master/train.py 66 | # Learning rate decay scheduler (Cosine with Warmup) 67 | def get_lr(it): 68 | # 1) Linear warmup for warmup_iters steps 69 | if it < WARMUP_ITERS: 70 | return MAX_LR * it / WARMUP_ITERS 71 | 72 | # 2) If it > lr_decay_iters, return min learning rate 73 | if it > NUM_BATCHES: 74 | return MIN_LR 75 | 76 | # 3) In between, use cosine decay down to min learning rate 77 | decay_ratio = (it - WARMUP_ITERS) / (NUM_BATCHES - WARMUP_ITERS) 78 | assert 0 <= decay_ratio <= 1 79 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 80 | return MIN_LR + coeff * (MAX_LR - MIN_LR) 81 | 82 | class PhoneLM(nn.Module): 83 | def __init__(self, n_phone_tokens, n_audio_tokens): 84 | super(PhoneLM, self).__init__() 85 | self.megabyte = megabyte.MEGABYTE( 86 | heads = 16, # 1, 87 | dim_head = 64, # 16, 88 | num_tokens = n_phone_tokens + n_audio_tokens + 4, 89 | dim = (1024, 256, 128), # (32, 32, 32), # (768, 256, 128)# Dg, Dl1, Dl2 90 | depth = (6, 4, 2), # (12, 4, 2), # (6, 4, 2) 91 | max_seq_len = (SEQ_LEN // 16, 4, 4), # (32, 4, 4), # (128, 4, 4), # (32, 4, 4), # 512 92 | flash_attn = False) 93 | 94 | def forward(self, x, debug=False, return_loss=True): 95 | x = self.megabyte(x, return_loss=return_loss) 96 | return x 97 | 98 | def get_params(self): 99 | o = [param.numel() for param in self.parameters() if param.requires_grad] 100 | o = sum(o) 101 | return o 102 | 103 | def generate(self, *args): 104 | return self.megabyte.generate(*args) 105 | 106 | def multi_encode( 107 | phone_tokens, 108 | audio_tokens, 109 | n_phone_tokens, 110 | n_audio_tokens, 111 | max_clip_length=1.0): 112 | """NOTE: 75 steps per second for 24kHz in `encodec. 
113 | Set `max_clip_length` to 0 for original clip length.""" 114 | 115 | # Start text token, end text token, start audio token, end audio token 116 | ETT, EAT = [n_phone_tokens + n_audio_tokens + i 117 | for i in range(2)] 118 | ETT = torch.tensor([ETT]).long().cuda() 119 | EAT = torch.tensor([EAT]).long().cuda() 120 | 121 | if max_clip_length > 0: 122 | #print("pre audio_tokens.shape", audio_tokens.shape) 123 | audio_tokens = audio_tokens[:, :, :int(max_clip_length * 75)] 124 | #print("post audio_tokens.shape", audio_tokens.shape) 125 | audio_tokens = rearrange(audio_tokens.squeeze(0), "q s -> (q s)") 126 | #print("post einops audio_tokens.shape", audio_tokens.shape) 127 | 128 | # offset phone tokens past audio tokens 129 | phone_tokens += n_audio_tokens 130 | 131 | #print("phone_tokens.shape:", phone_tokens.shape) 132 | #print("audio_tokens.shape:", audio_tokens.shape) 133 | 134 | device = torch.cuda.current_device() 135 | phone_tokens = phone_tokens.to(device) 136 | # phone_tokens = torch.cat((phone_tokens), dim=0).to(device) 137 | audio_tokens = torch.cat((audio_tokens, EAT), dim=0).to(device) 138 | combined_tokens = torch.cat((phone_tokens, ETT, audio_tokens), dim=0).to(device) 139 | return phone_tokens, audio_tokens, combined_tokens 140 | 141 | def generate_audio(sample, 142 | n_phone_tokens, 143 | n_audio_tokens, 144 | audio_path="./out.wav"): 145 | ETT, EAT = [n_phone_tokens + n_audio_tokens + i 146 | for i in range(2)] 147 | ST_S = [ETT, EAT] 148 | print("ETT, EAT ids:", ST_S) 149 | seq = sample.cpu().tolist()[0] 150 | print("seq:", seq) 151 | # all special tokens in list 152 | if all(st_t in seq for st_t in ST_S) and len(seq) >= len(ST_S) + 2: 153 | # text_tokens = seq[seq.index(STT + 1):seq.index(ETT - 1)] 154 | audio_tokens = seq[seq.index(ETT)+1:seq.index(EAT)] 155 | print(seq.index(ETT), seq.index(EAT), len(audio_tokens)) 156 | audio_tokens = torch.tensor(audio_tokens).cuda() 157 | audio_tokens = rearrange( 158 | audio_tokens, 159 | '(t q) -> t q', 160 | q=1, # CODEBOOK, 161 | t=audio_tokens.size(0) // 1) # t=audio_tokens.size(0) // CODEBOOK) 162 | print("audio_tokens.shape:", audio_tokens, audio_tokens.shape) 163 | decode_to_file(audio_tokens, audio_path) 164 | return True 165 | else: 166 | return False 167 | 168 | def generate(model, prompt): 169 | model.eval() 170 | 171 | prompt = prompt.unsqueeze(0) 172 | sample = model.generate(prompt) 173 | sample = sample.flatten(1) 174 | # print("sample:", sample, sample.shape) 175 | 176 | return prompt, sample 177 | 178 | def collate_fn(dataset, batch): 179 | """ 180 | batch := [ 181 | fileid_audio, 182 | waveform, 183 | sample_rate, 184 | transcript, 185 | normalized_transcript, 186 | phones, 187 | phone_ids, 188 | codes 189 | ] * batch_size 190 | """ 191 | # print("collate len batch:", len(batch)) 192 | items = [] 193 | for item in batch: 194 | _, _, test_inp = multi_encode( 195 | item[-2], 196 | item[-1], 197 | n_phone_tokens=len(dataset.phone_dict), 198 | n_audio_tokens=1024, 199 | max_clip_length=MAX_CLIP_LENGTH) 200 | 201 | padding_len = max(0, SEQ_LEN - test_inp.size(0)) 202 | n_test_inp = F.pad(test_inp, (0, padding_len)) 203 | # print("n_test_inp.shape:", n_test_inp.shape) 204 | items.append(n_test_inp) 205 | 206 | out = torch.stack(items).cuda() 207 | # print(out.shape) 208 | return out 209 | # nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) 210 | 211 | if __name__ == "__main__": 212 | dataset = LJSPEECH("./data/LJSpeech", 213 | encodec_bandwidth=BANDWIDTH, 214 | max_prompt_length=MAX_PROMPT_LENGTH) 215 | 
print("LJSpeech Dataset Slice:", len(dataset)) 216 | 217 | indices = list(range(len(dataset))) 218 | train_indices, test_indices = train_test_split(indices, test_size=0.1, random_state=42) 219 | 220 | train_sampler = SubsetRandomSampler(train_indices) 221 | test_sampler = SubsetRandomSampler(test_indices) 222 | eval_sampler = SubsetRandomSampler(test_indices) 223 | 224 | train_loader = DataLoader( 225 | dataset, 226 | batch_size=BATCH_SIZE, 227 | sampler=train_sampler, 228 | collate_fn=lambda batch: collate_fn(dataset, batch)) 229 | test_loader = DataLoader( 230 | dataset, 231 | batch_size=BATCH_SIZE, 232 | sampler=test_sampler, 233 | collate_fn=lambda batch: collate_fn(dataset, batch)) 234 | eval_loader = DataLoader( 235 | dataset, 236 | batch_size=1, 237 | sampler=test_sampler, 238 | collate_fn=lambda batch: batch) 239 | 240 | print("len(train_loader), len(test_loader):", len(train_loader), len(test_loader)) 241 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 242 | model = PhoneLM( 243 | n_phone_tokens=len(dataset.phone_dict), 244 | n_audio_tokens=1024).to(device) 245 | 246 | best_val = float("inf") 247 | 248 | NUM_BATCHES = EPOCHS * int(math.ceil(len(dataset) / BATCH_SIZE)) 249 | print("NUM_BATCHES:", NUM_BATCHES) 250 | 251 | # print("Model params:", model.megabyte.get_num_params()) 252 | 253 | # item = next(iter(train_loader))[1] 254 | # item_phone_tokens = item[-2] 255 | # item_audio_tokens = item[-1] 256 | # item_phone_tokens.shape, item_audio_tokens.shape 257 | # print(item[0], item[3]) 258 | 259 | # print("item_audio_tokens:", item_audio_tokens) 260 | # print("item_phone_tokens:", 261 | # item_phone_tokens, 262 | # [_get_model().phonemes[ph_id] 263 | # for ph_id in item_phone_tokens 264 | # if ph_id < len(_get_model().phonemes)]) 265 | 266 | # phone_prompt, audio_target, test_inp = multi_encode( 267 | # item_phone_tokens, 268 | # item_audio_tokens, 269 | # n_phone_tokens=len(dataset.phone_dict), 270 | # n_audio_tokens=1024, 271 | # max_clip_length=MAX_CLIP_LENGTH) 272 | 273 | optimizer = optim.Adam( 274 | model.parameters(), 275 | lr=MAX_LR) 276 | 277 | scaler = GradScaler() 278 | 279 | def train(model, trainloader): 280 | model.train() 281 | 282 | batch = next(iter(trainloader)) 283 | 284 | with torch.autocast( 285 | enabled=AMP, 286 | dtype=torch.bfloat16, 287 | device_type="cuda"): 288 | #padding_len = max(0, SEQ_LEN - test_inp.size(0)) 289 | #n_test_inp = F.pad(test_inp, (0, padding_len)) 290 | #batch = n_test_inp.unsqueeze(0) 291 | # print(batch.shape) 292 | loss = model(batch, return_loss=True) 293 | # loss = model(next(trainloader), return_loss=True) 294 | # loss.backward() 295 | return loss 296 | 297 | def test(model, test_loader, dataset): 298 | model.eval() 299 | 300 | batch = next(iter(test_loader)) 301 | with torch.no_grad(): 302 | loss = model(batch, return_loss = True) 303 | return loss 304 | 305 | pbar = tqdm(range(EPOCHS), mininterval=10., desc='training') 306 | batch_idx = 0 307 | for i in pbar: 308 | mem_gb = get_reserved_mem_gb() 309 | 310 | lr = get_lr(batch_idx) if DECAY_LR else MAX_LR 311 | 312 | for b in range(len(train_loader)): 313 | # Set LR 314 | lr = get_lr(batch_idx) if DECAY_LR else MAX_LR 315 | for param_group in optimizer.param_groups: 316 | param_group['lr'] = lr 317 | 318 | loss = train(model, train_loader) 319 | scaler.scale(loss).backward() 320 | scaler.step(optimizer) 321 | scaler.update() 322 | batch_idx += 1 323 | pbar.set_description( 324 | f"Reserved Memory (GB): {mem_gb}, loss: {loss.item()}, lr: {lr}, batch: 
{b}/{len(train_loader)}") 325 | 326 | # Validate every `n` steps (because it's time consuming) 327 | if i % VALIDATE_EVERY == 0: 328 | vloss = test(model, test_loader, dataset) 329 | print(f'validation loss: {vloss.item()}') 330 | pbar.set_description( 331 | f"Reserved Memory (GB): {mem_gb}, loss: {loss.item()}, vloss: {vloss.item()}, lr: {lr}") 332 | 333 | # Save best model every `n` steps. Set this to be high as the models are huge 334 | if vloss < best_val: 335 | best_val = vloss 336 | if SAVE: 337 | torch.save( 338 | model.state_dict(), 339 | f"./megabyte_{i}_{vloss}.pt") 340 | torch.save( 341 | optimizer.state_dict(), 342 | f"./megabyte_{i}_{vloss}_optim.pt") 343 | 344 | if i % PRINT_INTERVAL == 0: 345 | print(f"Reserved Memory (GB): {mem_gb}, loss: {loss.item()}, lr: {lr}") 346 | pbar.set_description(f"Reserved Memory (GB): {mem_gb}, loss: {loss.item()}, lr: {lr}") 347 | 348 | item = next(iter(eval_loader))[0] # [0] 349 | # print("item.shape:", item.shape) 350 | print("Generative Prompt:", item[-4]) 351 | item_phone_tokens = item[-2] 352 | item_audio_tokens = item[-1] 353 | # print("Item, Aud, Phone:", item_phone_tokens, item_audio_tokens) 354 | phone_prompt, _, _ = multi_encode( 355 | item_phone_tokens, 356 | item_audio_tokens, 357 | n_phone_tokens=len(dataset.phone_dict), 358 | n_audio_tokens=1024, 359 | max_clip_length=MAX_CLIP_LENGTH) 360 | prompt, sample = generate(model, phone_prompt) 361 | 362 | model.eval() 363 | 364 | try: 365 | out = generate_audio( 366 | sample, 367 | n_phone_tokens=len(dataset.phone_dict), 368 | n_audio_tokens=1024) 369 | 370 | # ETT, EAT ids: [1099, 1100] 371 | values = sample.cpu().numpy() 372 | ETT_S = np.where(values == 1099) 373 | EAT_S = np.where(values == 1100) 374 | 375 | print("ETT_S, EAT_S:", ETT_S, EAT_S) 376 | 377 | print(out) 378 | 379 | except Exception as e: 380 | print("Failure generating audio:", e) -------------------------------------------------------------------------------- /megabyte.py: -------------------------------------------------------------------------------- 1 | import math 2 | import functools 3 | from itertools import zip_longest 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn, einsum 8 | 9 | from einops import rearrange, reduce, repeat, pack, unpack 10 | from einops.layers.torch import Rearrange 11 | 12 | from beartype import beartype 13 | from beartype.typing import Tuple, Union 14 | 15 | from attend import Attend 16 | 17 | from tqdm import tqdm 18 | 19 | # helpers 20 | 21 | def exists(val): 22 | return val is not None 23 | 24 | def default(val, d): 25 | return val if exists(val) else d 26 | 27 | def pack_one(t, pattern): 28 | return pack([t], pattern) 29 | 30 | def unpack_one(t, ps, pattern): 31 | return unpack(t, ps, pattern)[0] 32 | 33 | def remainder_to_mult(num, mult): 34 | return (mult - num % mult) % mult 35 | 36 | def cast_tuple(t, length = 1): 37 | return t if isinstance(t, tuple) else ((t,) * length) 38 | 39 | def reduce_mult(nums): 40 | return functools.reduce(lambda x, y: x * y, nums, 1) 41 | 42 | # tensor helpers 43 | 44 | def log(t, eps = 1e-20): 45 | return torch.log(t.clamp(min = eps)) 46 | 47 | def gumbel_noise(t): 48 | noise = torch.zeros_like(t).uniform_(0, 1) 49 | return -log(-log(noise)) 50 | 51 | def gumbel_sample(t, temperature = 1., dim = -1): 52 | return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim) 53 | 54 | def top_k(logits, thres = 0.5): 55 | num_logits = logits.shape[-1] 56 | k = max(int((1 - thres) * num_logits), 1) 57 | val, ind = 
torch.topk(logits, k) 58 | probs = torch.full_like(logits, float('-inf')) 59 | probs.scatter_(1, ind, val) 60 | return probs 61 | 62 | # token shift, from Peng et al of RWKV 63 | 64 | def token_shift(t): 65 | t, t_shift = t.chunk(2, dim = -1) 66 | t_shift = F.pad(t_shift, (0, 0, 1, -1)) 67 | return torch.cat((t, t_shift), dim = -1) 68 | 69 | # rotary positional embedding 70 | 71 | class RotaryEmbedding(nn.Module): 72 | def __init__(self, dim, theta = 10000): 73 | super().__init__() 74 | inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) 75 | self.register_buffer("inv_freq", inv_freq) 76 | 77 | @property 78 | def device(self): 79 | return next(self.buffers()).device 80 | 81 | def forward(self, seq_len): 82 | t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq) 83 | freqs = torch.einsum('i , j -> i j', t, self.inv_freq) 84 | freqs = torch.cat((freqs, freqs), dim = -1) 85 | return freqs 86 | 87 | def rotate_half(x): 88 | x1, x2 = x.chunk(2, dim=-1) 89 | return torch.cat((-x2, x1), dim=-1) 90 | 91 | def apply_rotary_pos_emb(pos, t): 92 | return t * pos.cos() + rotate_half(t) * pos.sin() 93 | 94 | # norm 95 | 96 | class RMSNorm(nn.Module): 97 | def __init__(self, dim, eps = 1e-8): 98 | super().__init__() 99 | self.scale = dim ** -0.5 100 | self.eps = eps 101 | self.g = nn.Parameter(torch.ones(dim)) 102 | 103 | def forward(self, x): 104 | norm = torch.norm(x, dim = -1, keepdim = True) * self.scale 105 | return x / norm.clamp(min = self.eps) * self.g 106 | 107 | # helper classes 108 | def FeedForward(*, dim, mult = 4, dropout = 0.): 109 | return nn.Sequential( 110 | RMSNorm(dim), 111 | nn.Linear(dim, dim * mult), 112 | nn.GELU(), 113 | nn.Dropout(dropout), 114 | nn.Linear(dim * mult, dim) 115 | ) 116 | 117 | class Attention(nn.Module): 118 | def __init__( 119 | self, 120 | *, 121 | dim, 122 | dim_head = 64, 123 | heads = 8, 124 | dropout = 0., 125 | flash = False 126 | ): 127 | super().__init__() 128 | self.scale = dim_head ** -0.5 129 | self.heads = heads 130 | inner_dim = dim_head * heads 131 | 132 | self.attend = Attend( 133 | causal = True, 134 | flash = flash, 135 | dropout = dropout 136 | ) 137 | 138 | self.dropout = nn.Dropout(dropout) 139 | self.norm = RMSNorm(dim) 140 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 141 | self.to_kv = nn.Linear(dim, dim_head * 2, bias = False) 142 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 143 | 144 | def forward(self, x, rotary_emb = None): 145 | h, device = self.heads, x.device 146 | 147 | x = self.norm(x) 148 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)) 149 | q = rearrange(q, 'b n (h d) -> b h n d', h = h) 150 | 151 | if exists(rotary_emb): 152 | q, k = map(lambda t: apply_rotary_pos_emb(rotary_emb, t), (q, k)) 153 | 154 | out = self.attend(q, k, v) 155 | 156 | out = rearrange(out, 'b h n d -> b n (h d)') 157 | return self.to_out(out) 158 | 159 | class Transformer(nn.Module): 160 | def __init__( 161 | self, 162 | *, 163 | dim, 164 | layers, 165 | dim_head = 64, 166 | heads = 8, 167 | attn_dropout = 0., 168 | ff_dropout = 0., 169 | ff_mult = 4, 170 | rel_pos = True, 171 | flash_attn = False 172 | ): 173 | super().__init__() 174 | self.rotary_emb = RotaryEmbedding(dim_head) if rel_pos else None 175 | self.layers = nn.ModuleList([]) 176 | 177 | for _ in range(layers): 178 | self.layers.append(nn.ModuleList([ 179 | Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn), 180 | FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout) 181 | 
])) 182 | 183 | self.norm = RMSNorm(dim) 184 | 185 | def forward(self, x): 186 | n = x.shape[-2] 187 | rotary_emb = self.rotary_emb(n) if exists(self.rotary_emb) else None 188 | 189 | for attn, ff in self.layers: 190 | x = attn(token_shift(x), rotary_emb = rotary_emb) + x 191 | x = ff(token_shift(x)) + x 192 | 193 | return self.norm(x) 194 | 195 | # main class 196 | class MEGABYTE(nn.Module): 197 | @beartype 198 | def __init__( 199 | self, 200 | *, 201 | num_tokens, 202 | dim: Union[Tuple, int], 203 | depth: Tuple, 204 | max_seq_len: Tuple, 205 | dim_head = 64, 206 | heads = 8, 207 | attn_dropout = 0., 208 | ff_mult = 4, 209 | ff_dropout = 0., 210 | pad_id = 0, 211 | rel_pos = False, 212 | pos_emb = False, 213 | flash_attn = False 214 | ): 215 | super().__init__() 216 | 217 | # simplified configuration for each stage of the hierarchy 218 | # depth = (2, 2, 4) would translate to depth 2 at first stage, depth 2 second stage, depth 4 third 219 | # max_seq_len = (16, 8, 4) would translate to max sequence length of 16 at first stage, length of 8 at second stage, length of 4 for last 220 | 221 | assert isinstance(depth, tuple) and isinstance(max_seq_len, tuple) 222 | assert len(depth) == len(max_seq_len) 223 | 224 | self.stages = len(depth) 225 | dim = cast_tuple(dim, self.stages) 226 | 227 | assert len(dim) == self.stages 228 | 229 | coarsest_dim, *_, fine_dim = dim 230 | 231 | self.max_seq_len = max_seq_len 232 | 233 | self.start_tokens = nn.ParameterList([nn.Parameter(torch.randn(h_dim)) for h_dim, seq_len in zip(dim, max_seq_len)]) 234 | self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)]) if pos_emb else None 235 | 236 | self.token_embs = nn.ModuleList([]) 237 | 238 | patch_size = 1 239 | self.token_embs.append(nn.Embedding(num_tokens, fine_dim)) 240 | 241 | for dim_out, seq_len in zip(reversed(dim[:-1]), reversed(max_seq_len[1:])): 242 | patch_size *= seq_len 243 | 244 | self.token_embs.append(nn.Sequential( 245 | nn.Embedding(num_tokens, fine_dim), 246 | Rearrange('... r d -> ... (r d)'), 247 | nn.LayerNorm(patch_size * fine_dim), 248 | nn.Linear(patch_size * fine_dim, dim_out), 249 | nn.LayerNorm(dim_out) 250 | )) 251 | 252 | self.transformers = nn.ModuleList([]) 253 | self.to_next_transformer_projections = nn.ModuleList([]) 254 | 255 | for h_dim, next_h_dim, stage_depth, next_seq_len in zip_longest(dim, dim[1:], depth, max_seq_len[1:]): 256 | self.transformers.append(Transformer( 257 | dim = h_dim, 258 | layers = stage_depth, 259 | dim_head = dim_head, 260 | heads = heads, 261 | attn_dropout = attn_dropout, 262 | ff_dropout = ff_dropout, 263 | ff_mult = ff_mult, 264 | rel_pos = rel_pos, 265 | flash_attn = flash_attn 266 | )) 267 | 268 | proj = nn.Identity() 269 | 270 | if exists(next_h_dim) and next_h_dim != dim: 271 | proj = nn.Sequential( 272 | Rearrange('b ... d -> b (...) d'), 273 | nn.Linear(h_dim, next_h_dim * next_seq_len), 274 | Rearrange('b m (n d) -> (b m) n d', n = next_seq_len) 275 | ) 276 | 277 | self.to_next_transformer_projections.append(proj) 278 | 279 | self.to_logits = nn.Linear(fine_dim, num_tokens) 280 | self.pad_id = pad_id 281 | 282 | # report number of parameters 283 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 284 | 285 | self.apply(self._init_weights) 286 | 287 | def get_num_params(self, non_embedding=True): 288 | """ 289 | Return the number of parameters in the model. 290 | For non-embedding count (default), the position embeddings get subtracted. 
291 | The token embeddings would too, except due to the parameter sharing these 292 | params are actually used as weights in the final layer, so we include them. 293 | """ 294 | n_params = sum([p.numel() for p in self.parameters()]) 295 | # if non_embedding: 296 | # n_params -= self.transformer.wpe.weight.numel() 297 | return n_params 298 | 299 | # def _init_weights(self, module): 300 | # if isinstance(module, nn.Linear): 301 | # # torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 302 | # torch.nn.init.normal_(module.weight, mean=0.0, std=0.006) 303 | 304 | # if module.bias is not None: 305 | # torch.nn.init.zeros_(module.bias) 306 | # elif isinstance(module, nn.Embedding): 307 | # # torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 308 | # torch.nn.init.normal_(module.weight, mean=0.0, std=0.006) 309 | 310 | def _init_weights(self, module): 311 | # Weight initialisation from MEGABYTE paper, A.1 Training Details 312 | if isinstance(module, nn.Linear): # or isinstance(module, nn.Embedding): 313 | # Init with normal distribution 314 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.006) 315 | 316 | # Truncate weights to lie within two standard deviations 317 | with torch.no_grad(): 318 | module.weight[module.weight > 0.012] = 0.012 319 | module.weight[module.weight < -0.012] = -0.012 320 | 321 | # Bias is initialized to zero if exists 322 | if module.bias is not None: 323 | torch.nn.init.zeros_(module.bias) 324 | 325 | def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1): 326 | total_seq_len = reduce_mult(self.max_seq_len) 327 | device = next(self.parameters()).device 328 | 329 | if not exists(prime): 330 | prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device) 331 | 332 | seq = prime 333 | batch = seq.shape[0] 334 | 335 | for _ in tqdm(range(total_seq_len - seq.shape[-1])): 336 | logits = self.forward(seq)[:, -1] 337 | logits = top_k(logits, thres = filter_thres) 338 | sampled = gumbel_sample(logits, dim = -1, temperature = temperature) 339 | seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1) 340 | 341 | return seq.reshape(batch, *self.max_seq_len) 342 | 343 | def forward_empty(self, batch_size): 344 | # take care of special case 345 | # where you sample from input of 0 (start token only) 346 | 347 | prev_stage_tokens_repr = None 348 | 349 | for stage_start_tokens, transformer, proj in zip(self.start_tokens, self.transformers, self.to_next_transformer_projections): 350 | tokens = repeat(stage_start_tokens, 'd -> b 1 d', b = batch_size) 351 | 352 | if exists(prev_stage_tokens_repr): 353 | tokens = tokens + prev_stage_tokens_repr[..., :tokens.shape[-2], :] 354 | 355 | tokens = transformer(tokens) 356 | prev_stage_tokens_repr = proj(tokens) 357 | 358 | return self.to_logits(tokens) 359 | 360 | def forward(self, ids, return_loss = False): 361 | batch = ids.shape[0] 362 | 363 | assert ids.ndim in {2, self.stages + 1} 364 | flattened_dims = ids.ndim == 2 365 | ids_orig_ndim = ids.ndim 366 | 367 | if ids.numel() == 0: 368 | return self.forward_empty(ids.shape[0]) 369 | 370 | if flattened_dims: 371 | # allow for ids to be given in the shape of (batch, seq) 372 | # in which case it will be auto-padded to the next nearest multiple of depth seq len 373 | seq_len = ids.shape[-1] 374 | multiple_of = reduce_mult(self.max_seq_len[1:]) 375 | padding = remainder_to_mult(seq_len, multiple_of) 376 | ids = F.pad(ids, (0, padding), value = self.pad_id) 377 | ids = ids.reshape(batch, -1, *self.max_seq_len[1:]) 
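# e.g. with max_seq_len = (64, 4, 4) (the train.py configuration), a flattened (batch, 1024) input is padded to a multiple of 4 * 4 = 16 and reshaped to (batch, 64, 4, 4) before the hierarchical stages below run.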
378 | 379 | b, *prec_dims, device = *ids.shape, ids.device 380 | 381 | # check some dimensions 382 | 383 | assert prec_dims[0] <= self.max_seq_len[0], 'the first dimension of your axial autoregressive transformer must be less than the first tuple element of max_seq_len (like any autoregressive transformer)' 384 | assert tuple(prec_dims[1:]) == tuple(self.max_seq_len[1:]), 'all subsequent dimensions must match exactly' 385 | 386 | # get tokens for all hierarchical stages, reducing by appropriate dimensions 387 | # and adding the absolute positional embeddings 388 | 389 | tokens_at_stages = [] 390 | pos_embs = default(self.pos_embs, (None,)) 391 | 392 | for ind, pos_emb, token_emb in zip_longest(range(len(prec_dims)), pos_embs, self.token_embs): 393 | is_first = ind == 0 394 | 395 | tokens = token_emb(ids) 396 | 397 | if exists(pos_emb): 398 | positions = pos_emb(torch.arange(tokens.shape[-2], device = device)) 399 | tokens = tokens + positions 400 | 401 | tokens_at_stages.insert(0, tokens) 402 | 403 | if is_first: 404 | continue 405 | 406 | ids = rearrange(ids, '... m n -> ... (m n)') 407 | 408 | # the un-pixelshuffled representations of the previous hierarchy, starts with None 409 | 410 | prev_stage_tokens_repr = None 411 | 412 | # spatial tokens is tokens with depth pos reduced along depth dimension + spatial positions 413 | 414 | for stage_start_tokens, stage_tokens, transformer, proj in zip(self.start_tokens, tokens_at_stages, self.transformers, self.to_next_transformer_projections): 415 | stage_tokens, ps = pack_one(stage_tokens, '* n d') 416 | stage_start_tokens = repeat(stage_start_tokens, 'f -> b 1 f', b = stage_tokens.shape[0]) 417 | 418 | # concat start token 419 | 420 | stage_tokens = torch.cat(( 421 | stage_start_tokens, 422 | stage_tokens, 423 | ), dim = -2) 424 | 425 | # sum the previous hierarchy's representation 426 | 427 | if exists(prev_stage_tokens_repr): 428 | prev_stage_tokens_repr = F.pad(prev_stage_tokens_repr, (0, 0, 1, 0), value = 0.) 429 | stage_tokens = stage_tokens + prev_stage_tokens_repr 430 | 431 | attended = transformer(stage_tokens) 432 | 433 | attended = unpack_one(attended, ps, '* n d') 434 | 435 | # project for next stage in the hierarchy 436 | 437 | prev_stage_tokens_repr = proj(attended[..., :-1, :]) 438 | 439 | # project to logits 440 | 441 | logits = self.to_logits(attended) 442 | 443 | start_tokens = logits[(slice(None), *((0,) * (logits.ndim - 2)), slice(None))] 444 | start_tokens = rearrange(start_tokens, 'b d -> b 1 d') 445 | 446 | logits = logits[..., 1:, :] 447 | 448 | if not return_loss: 449 | 450 | if flattened_dims: 451 | logits = rearrange(logits, 'b ... c -> b (...) c') 452 | logits = logits[:, :seq_len] 453 | 454 | return logits 455 | 456 | logits = rearrange(logits, 'b ... c -> b (...) c') 457 | logits = torch.cat((start_tokens, logits), dim = -2) 458 | 459 | preds = rearrange(logits, 'b n c -> b c n') 460 | labels = rearrange(ids, 'b ... 
-> b (...)') 461 | 462 | loss = F.cross_entropy( 463 | preds[..., :-1], 464 | labels, 465 | ignore_index = self.pad_id 466 | ) 467 | 468 | return loss -------------------------------------------------------------------------------- /main.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# PhoneLM" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Test `G2P` and `Encodec`" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "!pip install g2p_en encodec" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "### `G2P`" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 2, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "from g2p_en import G2p" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "import torch\n", 49 | "import random\n", 50 | "import string\n", 51 | "from functools import cache\n", 52 | "from tqdm import tqdm\n", 53 | "\n", 54 | "@cache\n", 55 | "def _get_model():\n", 56 | " return G2p()\n", 57 | "\n", 58 | "@cache\n", 59 | "def _get_graphs(path):\n", 60 | " with open(path, \"r\") as f:\n", 61 | " graphs = f.read()\n", 62 | " return graphs\n", 63 | "\n", 64 | "def encode(graphs: str) -> list[str]:\n", 65 | " g2p = _get_model()\n", 66 | " phones = g2p(graphs)\n", 67 | " ignored = {\" \", *string.punctuation}\n", 68 | " return [\"_\" if p in ignored else p for p in phones]\n", 69 | "\n", 70 | "@torch.no_grad()\n", 71 | "def write_phones(folder, suffix=\".normalized.txt\"):\n", 72 | " print(\"ello?\")\n", 73 | " paths = list(folder.rglob(f\"*{suffix}\"))\n", 74 | " random.shuffle(paths)\n", 75 | "\n", 76 | " print(\"paths:\", paths)\n", 77 | " for path in tqdm(paths):\n", 78 | " phone_path = path.with_name(path.stem.split(\".\")[0] + \".phn.txt\")\n", 79 | " if phone_path.exists():\n", 80 | " continue\n", 81 | " print(\"?\")\n", 82 | " graphs = _get_graphs(path)\n", 83 | " phones = encode(graphs)\n", 84 | " with open(phone_path, \"w\") as f:\n", 85 | " f.write(\" \".join(phones))" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "ello?\n", 98 | "paths: [WindowsPath('data/text/test.normalized.txt')]\n" 99 | ] 100 | }, 101 | { 102 | "name": "stderr", 103 | "output_type": "stream", 104 | "text": [ 105 | "100%|██████████| 1/1 [00:00 1 q t\")\n", 168 | " wavs, sr = decode(codes=resps, bandwidth=BANDWIDTH)\n", 169 | " soundfile.write(str(path), wavs.cpu()[0, 0], sr)\n", 170 | "\n", 171 | "def _replace_file_extension(path, suffix):\n", 172 | " return (path.parent / path.name.split(\".\")[0]).with_suffix(suffix)\n", 173 | "\n", 174 | "@torch.inference_mode()\n", 175 | "def encode(wav: Tensor, sr: int, bandwidth=6.0, device=\"cuda\"):\n", 176 | " \"\"\"\n", 177 | " Args:\n", 178 | " wav: (t)\n", 179 | " sr: int\n", 180 | " \"\"\"\n", 181 | " model = _load_model(bandwidth, device)\n", 182 | " wav = wav.unsqueeze(0)\n", 183 | " wav = convert_audio(wav, sr, model.sample_rate, model.channels)\n", 184 | " wav = wav.to(device)\n", 185 | " encoded_frames = model.encode(wav)\n", 186 | " qnt = torch.cat([encoded[0] for encoded in 
encoded_frames], dim=-1) # (b q t)\n", 187 | " return qnt\n", 188 | "\n", 189 | "def encode_from_file(path, bandwidth=6.0, device=\"cuda\"):\n", 190 | " wav, sr = torchaudio.load(str(path))\n", 191 | " if wav.shape[0] == 2:\n", 192 | " wav = wav[:1]\n", 193 | " return encode(wav, sr, bandwidth, device)\n", 194 | "\n", 195 | "def quantize_audio(folder, suffix=\".wav\"):\n", 196 | " paths = [*folder.rglob(f\"*{suffix}\")]\n", 197 | " random.shuffle(paths)\n", 198 | "\n", 199 | " for path in tqdm(paths):\n", 200 | " out_path = _replace_file_extension(path, \".qnt.pt\")\n", 201 | " if out_path.exists():\n", 202 | " continue\n", 203 | " qnt = encode_from_file(path, BANDWIDTH)\n", 204 | " print(qnt.shape)\n", 205 | " torch.save(qnt.cpu(), out_path)\n", 206 | "\n", 207 | "def decode_files(folder, suffix=\".qnt.pt\"):\n", 208 | " paths = [*folder.rglob(f\"*{suffix}\")]\n", 209 | " random.shuffle(paths)\n", 210 | "\n", 211 | " for path in tqdm(paths):\n", 212 | " out_path = _replace_file_extension(path, \".qt.wav\")\n", 213 | " if out_path.exists():\n", 214 | " continue\n", 215 | " fi = rearrange(torch.load(path).squeeze(0).cuda(), \"q t -> t q\")\n", 216 | " decode_to_file(fi, out_path)" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 3, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "# from pathlib import Path\n", 226 | "# quantize_audio(Path(\"./data/audio\"))\n", 227 | "# decode_files(Path(\"./data/audio\"))" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 4, 233 | "metadata": {}, 234 | "outputs": [], 235 | "source": [ 236 | "# torch.load(\"data/audio/test.qnt.pt\").shape" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "#### Generate Audio from Tensor" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 9, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "audio_tensor = torch.tensor([[1019, 662],\n", 253 | " [ 598, 25],\n", 254 | " [ 321, 463],\n", 255 | " [1063, 575],\n", 256 | " [ 745, 727],\n", 257 | " [1073, 344],\n", 258 | " [1098, 344],\n", 259 | " [1046, 959],\n", 260 | " [1062, 874],\n", 261 | " [1059, 804],\n", 262 | " [1038, 1010],\n", 263 | " [1081, 577],\n", 264 | " [1098, 323],\n", 265 | " [1049, 858],\n", 266 | " [1034, 278],\n", 267 | " [1098, 469],\n", 268 | " [1069, 626],\n", 269 | " [1034, 482],\n", 270 | " [1071, 398],\n", 271 | " [1063, 858],\n", 272 | " [1083, 443],\n", 273 | " [1034, 418],\n", 274 | " [1072, 632],\n", 275 | " [1075, 914],\n", 276 | " [1098, 1010],\n", 277 | " [1094, 357],\n", 278 | " [1087, 898],\n", 279 | " [1084, 702],\n", 280 | " [1099, 654],\n", 281 | " [ 835, 364],\n", 282 | " [ 208, 416],\n", 283 | " [ 987, 722],\n", 284 | " [ 872, 708],\n", 285 | " [ 994, 399],\n", 286 | " [ 264, 648],\n", 287 | " [ 264, 1007],\n", 288 | " [1001, 961],\n", 289 | " [ 598, 320],\n", 290 | " [ 360, 993],\n", 291 | " [ 879, 747],\n", 292 | " [ 325, 700],\n", 293 | " [ 52, 770],\n", 294 | " [ 257, 268],\n", 295 | " [ 257, 824],\n", 296 | " [ 819, 662],\n", 297 | " [ 709, 567],\n", 298 | " [ 656, 662],\n", 299 | " [ 43, 602],\n", 300 | " [1038, 742],\n", 301 | " [ 24, 964],\n", 302 | " [1098, 289],\n", 303 | " [1099, 722],\n", 304 | " [ 855, 870],\n", 305 | " [ 25, 561],\n", 306 | " [ 472, 519],\n", 307 | " [ 472, 754],\n", 308 | " [ 475, 1038],\n", 309 | " [ 404, 857],\n", 310 | " [ 331, 913],\n", 311 | " [ 574, 434],\n", 312 | " [ 537, 154],\n", 313 | " [1022, 612],\n", 314 | " [ 324, 321],\n", 
315 | " [ 937, 563],\n", 316 | " [ 230, 1001],\n", 317 | " [ 912, 563],\n", 318 | " [ 912, 807],\n", 319 | " [ 928, 99],\n", 320 | " [ 928, 99],\n", 321 | " [ 942, 228],\n", 322 | " [ 604, 772],\n", 323 | " [ 904, 94],\n", 324 | " [ 472, 1063],\n", 325 | " [ 52, 812],\n", 326 | " [ 52, 645],\n", 327 | " [ 52, 697],\n", 328 | " [ 257, 387],\n", 329 | " [ 52, 362],\n", 330 | " [ 935, 247],\n", 331 | " [ 983, 65],\n", 332 | " [ 683, 874],\n", 333 | " [ 155, 518],\n", 334 | " [ 30, 822],\n", 335 | " [ 855, 467],\n", 336 | " [ 904, 909],\n", 337 | " [ 904, 529],\n", 338 | " [ 904, 852],\n", 339 | " [ 855, 399],\n", 340 | " [ 855, 470],\n", 341 | " [ 855, 1023],\n", 342 | " [ 106, 870],\n", 343 | " [ 176, 580],\n", 344 | " [ 574, 669],\n", 345 | " [ 502, 888],\n", 346 | " [ 588, 708],\n", 347 | " [ 782, 700],\n", 348 | " [ 588, 743],\n", 349 | " [ 890, 417],\n", 350 | " [ 373, 822],\n", 351 | " [ 160, 514],\n", 352 | " [ 47, 455],\n", 353 | " [ 47, 328],\n", 354 | " [ 47, 259],\n", 355 | " [ 909, 971],\n", 356 | " [1023, 962],\n", 357 | " [ 577, 367]]).cuda()" 358 | ] 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 10, 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "audio_tensor = torch.clamp(audio_tensor, min=0, max=1023)" 367 | ] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "execution_count": 11, 372 | "metadata": {}, 373 | "outputs": [ 374 | { 375 | "data": { 376 | "text/plain": [ 377 | "torch.Size([106, 2])" 378 | ] 379 | }, 380 | "execution_count": 11, 381 | "metadata": {}, 382 | "output_type": "execute_result" 383 | } 384 | ], 385 | "source": [ 386 | "audio_tensor.shape" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 12, 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "decode_to_file(audio_tensor, \"general_out.wav\")" 396 | ] 397 | }, 398 | { 399 | "cell_type": "markdown", 400 | "metadata": {}, 401 | "source": [ 402 | "## Dataset" 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "metadata": {}, 408 | "source": [ 409 | "### LJSpeech" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 1, 415 | "metadata": {}, 416 | "outputs": [], 417 | "source": [ 418 | "BANDWIDTH_IDX = 1 # original VALL-E\n", 419 | "CODEBOOKS = [2, 4, 8, 16, 32]\n", 420 | "BANDWIDTHS = [1.5, 3.0, 6.0, 12.0, 24.0]\n", 421 | "BANDWIDTH = BANDWIDTHS[BANDWIDTH_IDX]\n", 422 | "CODEBOOK = CODEBOOKS[BANDWIDTH_IDX]\n", 423 | "\n", 424 | "import torchaudio\n", 425 | "from ljspeech import LJSPEECH\n", 426 | "DATASET_PATH = \"./data/LJSpeech/\"\n", 427 | "dataset = LJSPEECH(\n", 428 | " \"./data/LJSpeech\",\n", 429 | " encodec_bandwidth=BANDWIDTH)" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": 2, 435 | "metadata": {}, 436 | "outputs": [ 437 | { 438 | "data": { 439 | "text/plain": [ 440 | "1919" 441 | ] 442 | }, 443 | "execution_count": 2, 444 | "metadata": {}, 445 | "output_type": "execute_result" 446 | } 447 | ], 448 | "source": [ 449 | "len(dataset)" 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": 3, 455 | "metadata": {}, 456 | "outputs": [ 457 | { 458 | "data": { 459 | "text/plain": [ 460 | "torch.Size([1, 4, 143])" 461 | ] 462 | }, 463 | "execution_count": 3, 464 | "metadata": {}, 465 | "output_type": "execute_result" 466 | } 467 | ], 468 | "source": [ 469 | "dataset[0][-1].shape" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": 4, 475 | "metadata": {}, 476 | "outputs": [ 477 | { 478 | "data": { 479 | 
"text/plain": [ 480 | "'\\nfileid_audio,\\nwaveform,\\nsample_rate,\\ntranscript,\\nnormalized_transcript,\\nphones,\\nphone_ids,\\ncodes\\n'" 481 | ] 482 | }, 483 | "execution_count": 4, 484 | "metadata": {}, 485 | "output_type": "execute_result" 486 | } 487 | ], 488 | "source": [ 489 | "import torch\n", 490 | "device = torch.cuda.current_device()\n", 491 | "\n", 492 | "\"\"\"\n", 493 | "fileid_audio,\n", 494 | "waveform,\n", 495 | "sample_rate,\n", 496 | "transcript,\n", 497 | "normalized_transcript,\n", 498 | "phones,\n", 499 | "phone_ids,\n", 500 | "codes\n", 501 | "\"\"\"\n", 502 | "\n", 503 | "# def collate_fn(batch) -> torch.tensor:\n", 504 | "# audio_tokens = []\n", 505 | "# phone_tokens = []\n", 506 | "\n", 507 | "# for item in batch:\n", 508 | "# cur_aud_tok = torch.tensor(item[7], device=device)\n", 509 | "# cur_phonemes = torch.tensor(item[6], device=device)\n", 510 | "# audio_tokens.append(cur_aud_tok)\n", 511 | "# phone_tokens.append(cur_phonemes)\n", 512 | "\n", 513 | "# # audio_tokens = torch.tensor(phone_tokens, device=device)\n", 514 | "# audio_tokens = nn.utils.rnn.pad_sequence(audio_tokens, batch_first=True)\n", 515 | "# # phone_tokens = torch.tensor(phone_tokens, device=device)\n", 516 | "# phone_tokens = nn.utils.rnn.pad_sequence(phone_tokens, batch_first=True)\n", 517 | "\n", 518 | "# return batch, phone_tokens, audio_tokens" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": 2, 524 | "metadata": {}, 525 | "outputs": [], 526 | "source": [ 527 | "import torch\n", 528 | "import torchaudio\n", 529 | "from torch.utils.data import DataLoader, SubsetRandomSampler\n", 530 | "from sklearn.model_selection import train_test_split\n", 531 | "\n", 532 | "indices = list(range(len(dataset)))\n", 533 | "train_indices, test_indices = train_test_split(indices, test_size=0.1, random_state=42)\n", 534 | "\n", 535 | "train_sampler = SubsetRandomSampler(train_indices)\n", 536 | "test_sampler = SubsetRandomSampler(test_indices)\n", 537 | "\n", 538 | "train_loader = DataLoader(dataset, batch_size=1, sampler=train_sampler, collate_fn=lambda x: x)\n", 539 | "test_loader = DataLoader(dataset, batch_size=1, sampler=test_sampler, collate_fn=lambda x: x)\n", 540 | "\n", 541 | "# train_loader = DataLoader(dataset, batch_size=32, sampler=train_sampler, collate_fn=collate_fn)\n", 542 | "# test_loader = DataLoader(dataset, batch_size=32, sampler=test_sampler, collate_fn=collate_fn)" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 3, 548 | "metadata": {}, 549 | "outputs": [], 550 | "source": [ 551 | "# it = next(iter(train_loader))\n", 552 | "# it" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": 4, 558 | "metadata": {}, 559 | "outputs": [ 560 | { 561 | "data": { 562 | "text/plain": [ 563 | "(1727, 192)" 564 | ] 565 | }, 566 | "execution_count": 4, 567 | "metadata": {}, 568 | "output_type": "execute_result" 569 | } 570 | ], 571 | "source": [ 572 | "len(train_loader), len(test_loader)" 573 | ] 574 | }, 575 | { 576 | "cell_type": "code", 577 | "execution_count": 5, 578 | "metadata": {}, 579 | "outputs": [], 580 | "source": [ 581 | "item = next(iter(train_loader))" 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "execution_count": 6, 587 | "metadata": {}, 588 | "outputs": [ 589 | { 590 | "data": { 591 | "text/plain": [ 592 | "[(WindowsPath('data/LJSpeech/LJSpeech-1.1/wavs/LJ050-0137.wav'),\n", 593 | " tensor([[ 0.0003, 0.0004, 0.0005, ..., -0.0008, -0.0008, -0.0007]]),\n", 594 | " 22050,\n", 595 | " 'FBI, and 
the Secret Service.',\n", 596 | " 'FBI, and the Secret Service.',\n", 597 | " ['B',\n", 598 | " 'AY1',\n", 599 | " '_',\n", 600 | " '_',\n", 601 | " '_',\n", 602 | " 'AH0',\n", 603 | " 'N',\n", 604 | " 'D',\n", 605 | " '_',\n", 606 | " 'DH',\n", 607 | " 'AH0',\n", 608 | " '_',\n", 609 | " 'S',\n", 610 | " 'IY1',\n", 611 | " 'K',\n", 612 | " 'R',\n", 613 | " 'AH0',\n", 614 | " 'T',\n", 615 | " '_',\n", 616 | " 'S',\n", 617 | " 'ER1',\n", 618 | " 'V',\n", 619 | " 'AH0',\n", 620 | " 'S',\n", 621 | " '_',\n", 622 | " '_'],\n", 623 | " tensor([22, 20, 74, 74, 74, 10, 48, 24, 74, 25, 10, 74, 58, 42, 45, 57, 10, 60,\n", 624 | " 74, 58, 30, 69, 10, 58, 74, 74], device='cuda:0'),\n", 625 | " tensor([[[ 865, 59, 309, 392, 695, 361, 706, 913, 822, 325, 176,\n", 626 | " 438, 438, 360, 360, 176, 176, 106, 257, 106, 106, 408,\n", 627 | " 63, 913, 801, 908, 801, 611, 530, 151, 944, 971, 347,\n", 628 | " 523, 855, 25, 593, 695, 723, 683, 169, 203, 760, 683,\n", 629 | " 240, 925, 925, 20, 162, 216, 216, 216, 793, 793, 901,\n", 630 | " 402, 216, 216, 291, 495, 881, 495, 598, 860, 136, 699,\n", 631 | " 430, 855, 835, 876, 738, 408, 106, 738, 106, 738, 106,\n", 632 | " 738, 738, 738, 738, 738, 738, 738, 106, 738, 408, 408,\n", 633 | " 408, 408, 408, 408, 408, 408, 408, 408, 408, 408, 408,\n", 634 | " 408, 408, 408, 408, 677, 804, 1006, 588, 659, 788, 222,\n", 635 | " 645, 1021, 645, 1022, 208, 860, 598, 1001, 208, 325, 934,\n", 636 | " 890, 784, 944, 148, 148, 574, 574, 47, 574, 160, 574,\n", 637 | " 574, 574, 574, 53, 433, 945, 530, 611, 344, 208, 339,\n", 638 | " 106, 1017, 430, 1017, 91, 475, 323, 936, 987, 23, 151,\n", 639 | " 148, 106, 1019, 160, 574, 47, 47, 160, 47, 574, 574,\n", 640 | " 463, 25, 324, 392, 476, 658, 694, 658, 983, 185, 185,\n", 641 | " 890, 879, 879, 395, 325, 185, 185, 523, 983, 208, 30,\n", 642 | " 779, 432, 276, 463, 148, 148, 160, 160, 160, 463, 373,\n", 643 | " 160, 160, 160, 148, 463, 25, 738, 106, 835, 855, 408,\n", 644 | " 835, 738, 738, 408, 738],\n", 645 | " [ 687, 570, 869, 271, 98, 559, 1011, 696, 222, 516, 959,\n", 646 | " 948, 836, 785, 1023, 928, 993, 993, 913, 928, 928, 928,\n", 647 | " 182, 844, 662, 662, 846, 846, 378, 414, 336, 559, 964,\n", 648 | " 1016, 800, 420, 222, 27, 549, 549, 555, 549, 4, 549,\n", 649 | " 549, 1001, 1001, 390, 905, 596, 295, 685, 435, 834, 527,\n", 650 | " 592, 14, 446, 704, 602, 602, 357, 966, 824, 259, 880,\n", 651 | " 877, 870, 700, 404, 700, 424, 424, 518, 913, 518, 913,\n", 652 | " 913, 544, 424, 544, 544, 544, 544, 913, 518, 424, 424,\n", 653 | " 424, 424, 424, 424, 424, 424, 518, 518, 424, 913, 913,\n", 654 | " 518, 424, 518, 518, 791, 833, 565, 177, 75, 400, 444,\n", 655 | " 668, 598, 745, 289, 345, 414, 948, 259, 841, 896, 748,\n", 656 | " 252, 307, 564, 211, 452, 71, 4, 973, 646, 910, 160,\n", 657 | " 1010, 758, 754, 185, 727, 140, 517, 16, 74, 857, 516,\n", 658 | " 857, 857, 758, 758, 984, 961, 721, 351, 182, 265, 673,\n", 659 | " 336, 964, 424, 973, 1010, 541, 541, 857, 973, 910, 993,\n", 660 | " 974, 222, 615, 243, 527, 441, 307, 269, 349, 444, 945,\n", 661 | " 269, 496, 236, 363, 458, 534, 880, 0, 458, 700, 560,\n", 662 | " 388, 964, 580, 909, 541, 857, 36, 541, 160, 1010, 1010,\n", 663 | " 209, 209, 471, 133, 92, 957, 518, 913, 518, 857, 424,\n", 664 | " 424, 913, 913, 424, 544],\n", 665 | " [ 970, 52, 538, 538, 659, 344, 287, 476, 225, 698, 843,\n", 666 | " 711, 361, 814, 188, 893, 551, 893, 893, 893, 893, 989,\n", 667 | " 868, 951, 381, 435, 47, 47, 365, 918, 510, 255, 618,\n", 668 | " 711, 852, 767, 116, 94, 500, 979, 500, 
302, 694, 707,\n", 669 | " 132, 641, 641, 864, 413, 864, 542, 296, 405, 560, 64,\n", 670 | " 672, 467, 864, 823, 911, 906, 868, 431, 1012, 551, 918,\n", 671 | " 198, 428, 1000, 675, 653, 982, 678, 982, 937, 36, 937,\n", 672 | " 937, 786, 678, 786, 1007, 653, 653, 937, 982, 786, 786,\n", 673 | " 36, 36, 786, 653, 653, 36, 36, 36, 36, 36, 937,\n", 674 | " 786, 36, 786, 36, 711, 255, 627, 413, 737, 235, 532,\n", 675 | " 572, 532, 606, 23, 110, 875, 599, 928, 880, 675, 798,\n", 676 | " 451, 997, 883, 907, 907, 242, 706, 814, 451, 326, 451,\n", 677 | " 432, 590, 432, 451, 911, 803, 451, 451, 565, 933, 508,\n", 678 | " 819, 982, 915, 982, 618, 705, 898, 705, 510, 316, 832,\n", 679 | " 486, 934, 653, 601, 432, 702, 864, 451, 228, 601, 590,\n", 680 | " 227, 93, 324, 416, 940, 386, 180, 644, 688, 798, 908,\n", 681 | " 450, 856, 225, 982, 626, 493, 335, 758, 316, 759, 675,\n", 682 | " 406, 711, 769, 918, 977, 526, 267, 432, 432, 1005, 936,\n", 683 | " 813, 541, 432, 936, 937, 907, 730, 1005, 1005, 36, 653,\n", 684 | " 982, 653, 786, 36, 36],\n", 685 | " [ 866, 864, 397, 357, 402, 424, 723, 612, 577, 863, 261,\n", 686 | " 940, 440, 638, 940, 453, 443, 255, 1019, 919, 673, 215,\n", 687 | " 963, 928, 453, 601, 981, 547, 12, 962, 444, 797, 733,\n", 688 | " 960, 529, 597, 612, 118, 562, 870, 110, 402, 104, 103,\n", 689 | " 50, 157, 608, 697, 791, 249, 885, 204, 457, 104, 137,\n", 690 | " 200, 755, 233, 627, 882, 855, 823, 939, 659, 326, 384,\n", 691 | " 222, 74, 762, 793, 793, 741, 673, 673, 673, 866, 673,\n", 692 | " 741, 673, 859, 673, 673, 673, 673, 673, 741, 741, 741,\n", 693 | " 741, 741, 741, 741, 741, 741, 741, 741, 741, 741, 741,\n", 694 | " 741, 741, 741, 741, 721, 410, 50, 259, 16, 282, 791,\n", 695 | " 282, 885, 885, 315, 875, 675, 524, 714, 962, 255, 110,\n", 696 | " 295, 638, 515, 990, 762, 651, 443, 651, 376, 180, 940,\n", 697 | " 601, 601, 440, 940, 326, 440, 107, 644, 440, 440, 255,\n", 698 | " 962, 580, 859, 838, 440, 780, 766, 49, 461, 448, 721,\n", 699 | " 1001, 558, 443, 519, 721, 834, 440, 612, 440, 440, 200,\n", 700 | " 255, 1022, 960, 787, 851, 190, 712, 692, 818, 522, 1007,\n", 701 | " 757, 224, 686, 940, 734, 517, 125, 924, 589, 140, 49,\n", 702 | " 274, 962, 962, 242, 443, 961, 854, 398, 418, 778, 778,\n", 703 | " 75, 443, 651, 434, 916, 1016, 318, 366, 651, 762, 741,\n", 704 | " 673, 866, 866, 741, 673]]], device='cuda:0'))]" 705 | ] 706 | }, 707 | "execution_count": 6, 708 | "metadata": {}, 709 | "output_type": "execute_result" 710 | } 711 | ], 712 | "source": [ 713 | "item" 714 | ] 715 | }, 716 | { 717 | "cell_type": "markdown", 718 | "metadata": {}, 719 | "source": [ 720 | "## Model" 721 | ] 722 | }, 723 | { 724 | "cell_type": "code", 725 | "execution_count": 7, 726 | "metadata": {}, 727 | "outputs": [], 728 | "source": [ 729 | "import megabyte\n", 730 | "import torch\n", 731 | "import torch.nn as nn\n", 732 | "from einops import rearrange\n", 733 | "\n", 734 | "def get_reserved_mem_gb():\n", 735 | " device = torch.cuda.current_device()\n", 736 | " reserved = torch.cuda.memory_reserved(device)\n", 737 | " reserved_gb = reserved / 1024 / 1024 / 1024\n", 738 | " return reserved_gb\n", 739 | "\n", 740 | "class PhoneLM(nn.Module):\n", 741 | " def __init__(self, n_phone_tokens, n_audio_tokens):\n", 742 | " super(PhoneLM, self).__init__()\n", 743 | " self.megabyte = megabyte.MEGABYTE(\n", 744 | " heads = 8, # 1,\n", 745 | " dim_head = 32, # 16,\n", 746 | " num_tokens = n_phone_tokens + n_audio_tokens + 4,\n", 747 | " dim = (768, 256, 128), # (32, 32, 32), # (768, 256, 128)# Dg, 
Dl1, Dl2\n", 748 | " depth = (6, 4, 2), # (6, 4, 2)\n", 749 | " max_seq_len = (32, 4, 4), # (128, 4, 4), # , # 512\n", 750 | " flash_attn = False)\n", 751 | "\n", 752 | " def forward(self, x, debug=False, return_loss=True):\n", 753 | " x = self.megabyte(x, return_loss=return_loss)\n", 754 | " return x\n", 755 | " \n", 756 | " def get_params(self):\n", 757 | " o = [param.numel() for param in self.parameters() if param.requires_grad]\n", 758 | " o = sum(o)\n", 759 | " return o\n", 760 | " \n", 761 | " def generate(self, *args):\n", 762 | " return self.megabyte.generate(*args)\n", 763 | " \n", 764 | "def multi_encode(\n", 765 | " phone_tokens,\n", 766 | " audio_tokens,\n", 767 | " n_phone_tokens,\n", 768 | " n_audio_tokens,\n", 769 | " max_clip_length=1.0):\n", 770 | " \"\"\"NOTE: 75 steps per second for 24kHz in `encodec.\n", 771 | " Set `max_clip_length` to 0 for original clip length.\"\"\"\n", 772 | "\n", 773 | " # Start text token, end text token, start audio token, end audio token\n", 774 | " ETT, EAT = [n_phone_tokens + n_audio_tokens + i\n", 775 | " for i in range(2)]\n", 776 | " ETT = torch.tensor([ETT]).long().cuda()\n", 777 | " EAT = torch.tensor([EAT]).long().cuda()\n", 778 | "\n", 779 | " if max_clip_length:\n", 780 | " #print(\"pre audio_tokens.shape\", audio_tokens.shape)\n", 781 | " audio_tokens = audio_tokens[:, :, :int(max_clip_length * 75)]\n", 782 | " audio_tokens = rearrange(audio_tokens.squeeze(0), \"q s -> (q s)\")\n", 783 | " #print(\"post audio_tokens.shape\", audio_tokens.shape)\n", 784 | " \n", 785 | " # offset phone tokens past audio tokens\n", 786 | " phone_tokens += n_audio_tokens\n", 787 | " \n", 788 | " #print(\"phone_tokens.shape:\", phone_tokens.shape)\n", 789 | " #print(\"audio_tokens.shape:\", audio_tokens.shape)\n", 790 | " \n", 791 | " device = torch.cuda.current_device()\n", 792 | " phone_tokens = torch.cat((phone_tokens, ETT), dim=0).to(device)\n", 793 | " audio_tokens = torch.cat((audio_tokens, EAT,), dim=0).to(device)\n", 794 | " combined_tokens = torch.cat((phone_tokens, audio_tokens), dim=0).to(device)\n", 795 | " return phone_tokens, audio_tokens, combined_tokens" 796 | ] 797 | }, 798 | { 799 | "cell_type": "code", 800 | "execution_count": 8, 801 | "metadata": {}, 802 | "outputs": [], 803 | "source": [ 804 | "from einops import rearrange\n", 805 | "\n", 806 | "from encodec_util import decode_to_file\n", 807 | "\n", 808 | "\"\"\"\n", 809 | "EinopsError: Error while processing rearrange-reduction pattern \"(t q) -> t q\".\n", 810 | " Input tensor shape: torch.Size([75]). 
Additional info: {'q': 4, 't': 75}.\n", 811 | " Shape mismatch, 75 != 300\n", 812 | "\"\"\"\n", 813 | "\n", 814 | "def generate_audio(sample,\n", 815 | " n_phone_tokens,\n", 816 | " n_audio_tokens,\n", 817 | " audio_path=\"./out.wav\"):\n", 818 | " ETT, EAT = [n_phone_tokens + n_audio_tokens + i\n", 819 | " for i in range(2)]\n", 820 | " ST_S = [ETT, EAT]\n", 821 | " print(\"ETT, EAT ids:\", ST_S)\n", 822 | " seq = sample.cpu().tolist()[0]\n", 823 | " print(\"seq:\", seq)\n", 824 | " # all special tokens in list\n", 825 | " if all(st_t in seq for st_t in ST_S) and len(seq) >= len(ST_S) + 2:\n", 826 | " # text_tokens = seq[seq.index(STT + 1):seq.index(ETT - 1)]\n", 827 | " audio_tokens = seq[seq.index(ETT)+1:seq.index(EAT)]\n", 828 | " print(seq.index(ETT), seq.index(EAT), len(audio_tokens))\n", 829 | " audio_tokens = torch.tensor(audio_tokens).cuda()\n", 830 | " audio_tokens = rearrange(\n", 831 | " audio_tokens,\n", 832 | " '(t q) -> t q',\n", 833 | " q=1, # CODEBOOK,\n", 834 | " t=audio_tokens.size(0) // 1) # t=audio_tokens.size(0) // CODEBOOK)\n", 835 | " print(\"audio_tokens.shape:\", audio_tokens, audio_tokens.shape)\n", 836 | " decode_to_file(audio_tokens, audio_path)\n", 837 | " return True\n", 838 | " else:\n", 839 | " return False" 840 | ] 841 | }, 842 | { 843 | "cell_type": "markdown", 844 | "metadata": {}, 845 | "source": [ 846 | "## PhoneLM - LJSpeech (Overfit Multi)" 847 | ] 848 | }, 849 | { 850 | "cell_type": "markdown", 851 | "metadata": {}, 852 | "source": [ 853 | "### Train" 854 | ] 855 | }, 856 | { 857 | "cell_type": "code", 858 | "execution_count": 9, 859 | "metadata": {}, 860 | "outputs": [ 861 | { 862 | "name": "stdout", 863 | "output_type": "stream", 864 | "text": [ 865 | "number of parameters: 37.30M\n" 866 | ] 867 | }, 868 | { 869 | "data": { 870 | "text/plain": [ 871 | "37302863" 872 | ] 873 | }, 874 | "execution_count": 9, 875 | "metadata": {}, 876 | "output_type": "execute_result" 877 | } 878 | ], 879 | "source": [ 880 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 881 | "\n", 882 | "model = PhoneLM(\n", 883 | " n_phone_tokens=len(dataset.phone_dict),\n", 884 | " n_audio_tokens=1024).to(device)\n", 885 | "\n", 886 | "model.megabyte.get_num_params()" 887 | ] 888 | }, 889 | { 890 | "cell_type": "code", 891 | "execution_count": 10, 892 | "metadata": {}, 893 | "outputs": [], 894 | "source": [ 895 | "item = next(iter(train_loader))\n", 896 | "# item_phone_tokens = item[-2]\n", 897 | "# # item_audio_tokens = item[-1]\n", 898 | "# item_audio_tokens = item[-1][:, 0, :] # Only keep primary coarse tokens, for now\n", 899 | "# item_audio_tokens = item_audio_tokens.unsqueeze(0)\n", 900 | "# item_phone_tokens.shape, item_audio_tokens.shape" 901 | ] 902 | }, 903 | { 904 | "cell_type": "code", 905 | "execution_count": 11, 906 | "metadata": {}, 907 | "outputs": [ 908 | { 909 | "data": { 910 | "text/plain": [ 911 | "1" 912 | ] 913 | }, 914 | "execution_count": 11, 915 | "metadata": {}, 916 | "output_type": "execute_result" 917 | } 918 | ], 919 | "source": [ 920 | "len(item)" 921 | ] 922 | }, 923 | { 924 | "cell_type": "code", 925 | "execution_count": 12, 926 | "metadata": {}, 927 | "outputs": [ 928 | { 929 | "data": { 930 | "text/plain": [ 931 | "'to taking the descriptions of newly-arrived prisoners.'" 932 | ] 933 | }, 934 | "execution_count": 12, 935 | "metadata": {}, 936 | "output_type": "execute_result" 937 | } 938 | ], 939 | "source": [ 940 | "item[0][3]" 941 | ] 942 | }, 943 | { 944 | "cell_type": "code", 945 | "execution_count": 13, 946 | 
"metadata": {}, 947 | "outputs": [], 948 | "source": [ 949 | "# item" 950 | ] 951 | }, 952 | { 953 | "cell_type": "code", 954 | "execution_count": 14, 955 | "metadata": {}, 956 | "outputs": [], 957 | "source": [ 958 | "# phone_prompt, audio_target, test_inp = multi_encode(\n", 959 | "# item_phone_tokens,\n", 960 | "# item_audio_tokens,\n", 961 | "# n_phone_tokens=len(dataset.phone_dict),\n", 962 | "# n_audio_tokens=1024,\n", 963 | "# max_clip_length=5)\n", 964 | "# test_inp.shape" 965 | ] 966 | }, 967 | { 968 | "cell_type": "markdown", 969 | "metadata": {}, 970 | "source": [ 971 | "### Training Process" 972 | ] 973 | }, 974 | { 975 | "cell_type": "code", 976 | "execution_count": 15, 977 | "metadata": {}, 978 | "outputs": [], 979 | "source": [ 980 | "from tqdm.notebook import tqdm" 981 | ] 982 | }, 983 | { 984 | "cell_type": "code", 985 | "execution_count": 16, 986 | "metadata": {}, 987 | "outputs": [], 988 | "source": [ 989 | "import torch.optim as optim\n", 990 | "\n", 991 | "epochs = 10\n", 992 | "\n", 993 | "MAX_LR = 1e-2\n", 994 | "# MAX_LR = 1e-2\n", 995 | "WEIGHT_DECAY = 1e-4\n", 996 | "GRAD_CLIP = 0.1\n", 997 | "\n", 998 | "optimizer = optim.Adam(\n", 999 | " model.parameters(),\n", 1000 | " lr=MAX_LR)\n", 1001 | " #,weight_decay=WEIGHT_DECAY)\n", 1002 | "\n", 1003 | "# def get_lr(optimizer):\n", 1004 | "# for param_group in optimizer.param_groups:\n", 1005 | "# return param_group['lr']\n", 1006 | "\n", 1007 | "# sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, MAX_LR, epochs=epochs, \n", 1008 | "# steps_per_epoch=len(trainloader))" 1009 | ] 1010 | }, 1011 | { 1012 | "cell_type": "code", 1013 | "execution_count": 17, 1014 | "metadata": {}, 1015 | "outputs": [ 1016 | { 1017 | "data": { 1018 | "text/plain": [ 1019 | "1" 1020 | ] 1021 | }, 1022 | "execution_count": 17, 1023 | "metadata": {}, 1024 | "output_type": "execute_result" 1025 | } 1026 | ], 1027 | "source": [ 1028 | "len(item)" 1029 | ] 1030 | }, 1031 | { 1032 | "cell_type": "code", 1033 | "execution_count": 19, 1034 | "metadata": {}, 1035 | "outputs": [ 1036 | { 1037 | "ename": "RuntimeError", 1038 | "evalue": "CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n", 1039 | "output_type": "error", 1040 | "traceback": [ 1041 | "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", 1042 | "\u001b[1;31mRuntimeError\u001b[0m Traceback (most recent call last)", 1043 | "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_20876/3201252558.py\u001b[0m in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 70\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mEPOCHS\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 71\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 72\u001b[1;33m \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 73\u001b[0m 
\u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 74\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 1044 | "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_20876/3201252558.py\u001b[0m in \u001b[0;36mtrain\u001b[1;34m(model, trainloader)\u001b[0m\n\u001b[0;32m 51\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 52\u001b[0m \u001b[1;31m# print(\"batch:\", batch.shape, batch)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 53\u001b[1;33m \u001b[0mloss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreturn_loss\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 54\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 55\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1045 | "\u001b[1;32mc:\\Users\\win8t\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1499\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1501\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1502\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1046 | "\u001b[1;32m~\\AppData\\Local\\Temp/ipykernel_20876/3374421788.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x, debug, return_loss)\u001b[0m\n\u001b[0;32m 23\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 24\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdebug\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreturn_loss\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 25\u001b[1;33m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmegabyte\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreturn_loss\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mreturn_loss\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 
26\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 27\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 1047 | "\u001b[1;32mc:\\Users\\win8t\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1499\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1501\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1502\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1048 | "\u001b[1;32mc:\\Users\\win8t\\OneDrive\\Desktop\\projects\\PhoneLM\\megabyte.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, ids, return_loss)\u001b[0m\n\u001b[0;32m 429\u001b[0m \u001b[0mstage_tokens\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mstage_tokens\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mprev_stage_tokens_repr\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 430\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 431\u001b[1;33m \u001b[0mattended\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtransformer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstage_tokens\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 432\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 433\u001b[0m \u001b[0mattended\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0munpack_one\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mattended\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mps\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'* n d'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1049 | "\u001b[1;32mc:\\Users\\win8t\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1499\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1501\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1502\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1503\u001b[0m 
\u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1050 | "\u001b[1;32mc:\\Users\\win8t\\OneDrive\\Desktop\\projects\\PhoneLM\\megabyte.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 188\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 189\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mattn\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mff\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 190\u001b[1;33m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mattn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtoken_shift\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mrotary_emb\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mrotary_emb\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 191\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mff\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtoken_shift\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 192\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 1051 | "\u001b[1;32mc:\\Users\\win8t\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1499\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1501\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1502\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1052 | "\u001b[1;32mc:\\Users\\win8t\\OneDrive\\Desktop\\projects\\PhoneLM\\megabyte.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x, rotary_emb)\u001b[0m\n\u001b[0;32m 145\u001b[0m \u001b[0mh\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdevice\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mheads\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 146\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 147\u001b[1;33m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 148\u001b[0m \u001b[0mq\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mk\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto_q\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m*\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto_kv\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mchunk\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdim\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 149\u001b[0m \u001b[0mq\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mrearrange\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mq\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'b n (h d) -> b h n d'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mh\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mh\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1053 | "\u001b[1;32mc:\\Users\\win8t\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1499\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1501\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1502\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1054 | "\u001b[1;32mc:\\Users\\win8t\\OneDrive\\Desktop\\projects\\PhoneLM\\megabyte.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 102\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 103\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 104\u001b[1;33m \u001b[0mnorm\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdim\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkeepdim\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m*\u001b[0m 
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mscale\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 105\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m/\u001b[0m \u001b[0mnorm\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mclamp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmin\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0meps\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mg\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 106\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n", 1055 | "\u001b[1;32mc:\\Users\\win8t\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\functional.py\u001b[0m in \u001b[0;36mnorm\u001b[1;34m(input, p, dim, keepdim, out, dtype)\u001b[0m\n\u001b[0;32m 1499\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mp\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m\"fro\"\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mdim\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mint\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m<=\u001b[0m \u001b[1;36m2\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1500\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mout\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1501\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvector_norm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_dim\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkeepdim\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1502\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1503\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mvector_norm\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_dim\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mkeepdim\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mdtype\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mout\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mout\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", 1056 | "\u001b[1;31mRuntimeError\u001b[0m: CUDA error: device-side assert triggered\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n" 1057 | ] 1058 | } 1059 | ], 1060 | "source": [ 1061 | "import torch.nn.functional as F\n", 1062 | "\n", 1063 | "EPOCHS = 1000\n", 1064 | "PRINT_INTERVAL = 100\n", 1065 | "\n", 1066 | "seq_len = 512 # 2048\n", 1067 | "\n", 1068 | 
"# phone_prompt, audio_target, test_inp = multi_encode(\n", 1069 | "# item[0][6], # phone tokens\n", 1070 | "# item[0][7], # audio tokens\n", 1071 | "# n_phone_tokens=len(dataset.phone_dict),\n", 1072 | "# n_audio_tokens=1024,\n", 1073 | "# max_clip_length=1)\n", 1074 | "\n", 1075 | "# model = PhoneLM(\n", 1076 | "# n_phone_tokens=len(dataset.phone_dict),\n", 1077 | "# n_audio_tokens=1024).to(device)\n", 1078 | "\n", 1079 | "prompt = None\n", 1080 | "\n", 1081 | "def create_seq(item_phone_tokens, item_audio_tokens):\n", 1082 | " global prompt\n", 1083 | " phone_prompt, audio_target, test_inp = multi_encode(\n", 1084 | " item_phone_tokens,\n", 1085 | " item_audio_tokens,\n", 1086 | " n_phone_tokens=len(dataset.phone_dict),\n", 1087 | " n_audio_tokens=1024,\n", 1088 | " max_clip_length=1)\n", 1089 | " prompt = phone_prompt\n", 1090 | " padding_len = max(0, seq_len - test_inp.size(0))\n", 1091 | " n_test_inp = F.pad(test_inp, (0, padding_len))\n", 1092 | " cur_item = n_test_inp\n", 1093 | " # cur_item = n_test_inp.unsqueeze(0)\n", 1094 | " return cur_item\n", 1095 | "\n", 1096 | "def create_batch(batch):\n", 1097 | " rnn_batch = []\n", 1098 | " for item in batch:\n", 1099 | " item_phone_tokens = item[6]\n", 1100 | " item_audio_tokens = item[7]\n", 1101 | " seq = create_seq(item_phone_tokens, item_audio_tokens)\n", 1102 | " rnn_batch.append(seq)\n", 1103 | " rnn_batch = nn.utils.rnn.pad_sequence(rnn_batch, batch_first=True)\n", 1104 | " return rnn_batch\n", 1105 | "\n", 1106 | "cur_batch = item\n", 1107 | "batch = create_batch(cur_batch)\n", 1108 | "\n", 1109 | "def train(model, trainloader):\n", 1110 | " model.train()\n", 1111 | "\n", 1112 | " # print(\"batch:\", batch.shape, batch)\n", 1113 | " loss = model(batch, return_loss=True)\n", 1114 | " loss.backward()\n", 1115 | " return loss\n", 1116 | "\n", 1117 | "# def train(model, trainloader):\n", 1118 | "# model.train()\n", 1119 | " \n", 1120 | "# padding_len = max(0, seq_len - test_inp.size(0))\n", 1121 | "# n_test_inp = F.pad(test_inp, (0, padding_len))\n", 1122 | "# batch = n_test_inp.unsqueeze(0)\n", 1123 | "# print(\"batch.shape:\", batch.shape, batch)\n", 1124 | "# loss = model(batch, return_loss=True)\n", 1125 | "# # loss = model(next(trainloader), return_loss=True)\n", 1126 | "# loss.backward()\n", 1127 | "# return loss\n", 1128 | "\n", 1129 | "# pbar = tqdm.tqdm(EPOCHS, mininterval=10., desc='training')\n", 1130 | "for epoch in range(EPOCHS):\n", 1131 | " optimizer.zero_grad()\n", 1132 | " loss = train(model, train_loader)\n", 1133 | " optimizer.step()\n", 1134 | " \n", 1135 | " mem_gb = get_reserved_mem_gb()\n", 1136 | " if epoch % PRINT_INTERVAL == 0:\n", 1137 | " print(f\"Reserved Memory (GB): {mem_gb}, loss: {loss.item()}\")\n", 1138 | " # pbar.set_description(f\"Reserved Memory (GB): {mem_gb}, loss: {loss.item()}\")" 1139 | ] 1140 | }, 1141 | { 1142 | "cell_type": "code", 1143 | "execution_count": 22, 1144 | "metadata": {}, 1145 | "outputs": [ 1146 | { 1147 | "name": "stdout", 1148 | "output_type": "stream", 1149 | "text": [ 1150 | "torch.Size([1, 512])\n" 1151 | ] 1152 | } 1153 | ], 1154 | "source": [ 1155 | "print(batch.shape)" 1156 | ] 1157 | }, 1158 | { 1159 | "cell_type": "markdown", 1160 | "metadata": {}, 1161 | "source": [ 1162 | "### Evaluate" 1163 | ] 1164 | }, 1165 | { 1166 | "cell_type": "code", 1167 | "execution_count": 23, 1168 | "metadata": {}, 1169 | "outputs": [ 1170 | { 1171 | "data": { 1172 | "text/plain": [ 1173 | "torch.Size([31])" 1174 | ] 1175 | }, 1176 | "execution_count": 23, 1177 | "metadata": {}, 1178 | 
"output_type": "execute_result" 1179 | } 1180 | ], 1181 | "source": [ 1182 | "prompt.shape" 1183 | ] 1184 | }, 1185 | { 1186 | "cell_type": "code", 1187 | "execution_count": 34, 1188 | "metadata": {}, 1189 | "outputs": [ 1190 | { 1191 | "data": { 1192 | "text/plain": [ 1193 | "tensor([[1059, 1081, 1057, 1097, 1053, 1098, 1084, 1075, 1070, 1048, 1098, 1049,\n", 1194 | " 1034, 1098, 1069, 1034, 1071, 1063, 1083, 1034, 1072, 1098, 1098, 1098,\n", 1195 | " 1069, 1094, 1075, 1084, 1098, 1098, 1099]], device='cuda:0')" 1196 | ] 1197 | }, 1198 | "execution_count": 34, 1199 | "metadata": {}, 1200 | "output_type": "execute_result" 1201 | } 1202 | ], 1203 | "source": [ 1204 | "prompt" 1205 | ] 1206 | }, 1207 | { 1208 | "cell_type": "code", 1209 | "execution_count": 39, 1210 | "metadata": {}, 1211 | "outputs": [ 1212 | { 1213 | "name": "stdout", 1214 | "output_type": "stream", 1215 | "text": [ 1216 | "FUCKING HELLO?\n", 1217 | "prompt: torch.Size([1, 31])\n" 1218 | ] 1219 | }, 1220 | { 1221 | "name": "stderr", 1222 | "output_type": "stream", 1223 | "text": [ 1224 | "100%|██████████| 481/481 [00:06<00:00, 74.27it/s]" 1225 | ] 1226 | }, 1227 | { 1228 | "name": "stdout", 1229 | "output_type": "stream", 1230 | "text": [ 1231 | "sample: tensor([[1059, 1081, 1057, 1097, 1053, 1098, 1084, 1075, 1070, 1048, 1098, 1049,\n", 1232 | " 1034, 1098, 1069, 1034, 1071, 1063, 1083, 1034, 1072, 1098, 1098, 1098,\n", 1233 | " 1069, 1094, 1075, 1084, 1098, 1098, 1099, 835, 160, 438, 488, 887,\n", 1234 | " 203, 503, 441, 6, 81, 727, 141, 908, 908, 502, 303, 148,\n", 1235 | " 103, 496, 145, 731, 977, 259, 582, 808, 921, 432, 779, 779,\n", 1236 | " 472, 472, 331, 103, 887, 457, 987, 501, 921, 197, 197, 931,\n", 1237 | " 928, 881, 834, 432, 604, 491, 373, 994, 782, 834, 408, 855,\n", 1238 | " 855, 798, 176, 798, 537, 779, 936, 457, 751, 651, 687, 751,\n", 1239 | " 790, 686, 994, 57, 916, 751, 699, 145, 145, 148, 1010, 4,\n", 1240 | " 973, 984, 1002, 3, 472, 185, 662, 662, 471, 471, 106, 734,\n", 1241 | " 893, 802, 285, 1010, 812, 272, 529, 930, 486, 323, 632, 466,\n", 1242 | " 930, 601, 646, 924, 913, 160, 1010, 857, 984, 668, 399, 399,\n", 1243 | " 399, 710, 27, 710, 710, 564, 765, 888, 770, 404, 405, 420,\n", 1244 | " 708, 752, 928, 71, 216, 43, 471, 801, 404, 112, 214, 48,\n", 1245 | " 425, 1011, 48, 767, 792, 792, 869, 268, 471, 200, 1010, 857,\n", 1246 | " 857, 970, 457, 212, 188, 212, 54, 90, 573, 999, 298, 977,\n", 1247 | " 469, 977, 977, 587, 977, 590, 747, 590, 590, 906, 775, 536,\n", 1248 | " 624, 264, 432, 705, 647, 832, 1007, 861, 758, 633, 1015, 1007,\n", 1249 | " 759, 34, 733, 733, 741, 659, 757, 932, 626, 829, 832, 547,\n", 1250 | " 752, 701, 955, 242, 705, 1000, 710, 907, 80, 893, 438, 893,\n", 1251 | " 932, 549, 55, 442, 426, 831, 581, 475, 977, 93, 93, 712,\n", 1252 | " 298, 432, 933, 242, 268, 919, 261, 787, 255, 263, 787, 787,\n", 1253 | " 791, 19, 238, 490, 416, 935, 698, 588, 889, 411, 778, 779,\n", 1254 | " 859, 817, 162, 418, 211, 306, 255, 714, 993, 875, 493, 418,\n", 1255 | " 762, 1019, 178, 504, 414, 605, 499, 706, 628, 214, 78, 558,\n", 1256 | " 1019, 1000, 1019, 874, 110, 804, 381, 529, 866, 36, 803, 940,\n", 1257 | " 180, 941, 733, 96, 803, 335, 524, 714, 224, 733, 499, 769,\n", 1258 | " 204, 273, 994, 249, 882, 940, 148, 1100, 993, 875, 493, 418,\n", 1259 | " 762, 1019, 178, 504, 414, 605, 499, 706, 628, 214, 78, 558,\n", 1260 | " 1019, 1000, 1019, 874, 110, 804, 381, 529, 866, 36, 803, 940,\n", 1261 | " 180, 941, 733, 96, 803, 335, 524, 714, 224, 733, 499, 769,\n", 1262 | " 204, 273, 994, 
249, 882, 940, 148, 1100, 993, 875, 493, 418,\n", 1263 | " 762, 1019, 178, 504, 414, 605, 499, 706, 628, 214, 78, 558,\n", 1264 | " 1019, 1000, 1019, 874, 110, 804, 381, 529, 866, 36, 803, 940,\n", 1265 | " 180, 941, 733, 96, 803, 335, 524, 714, 224, 733, 499, 769,\n", 1266 | " 204, 273, 994, 249, 882, 940, 148, 1100, 573, 999, 298, 977,\n", 1267 | " 469, 977, 977, 587, 977, 590, 747, 590, 590, 906, 775, 536,\n", 1268 | " 624, 264, 432, 705, 647, 832, 1007, 861, 758, 633, 1015, 1007,\n", 1269 | " 759, 34, 733, 733, 741, 659, 757, 932, 626, 829, 832, 547,\n", 1270 | " 752, 701, 955, 242, 705, 1000, 710, 907, 80, 893, 438, 893,\n", 1271 | " 932, 549, 55, 442, 426, 831, 581, 475, 977, 93, 93, 712,\n", 1272 | " 298, 432, 933, 242, 268, 919, 261, 787, 255, 263, 787, 787,\n", 1273 | " 791, 19, 238, 490, 416, 935, 698, 588]], device='cuda:0') torch.Size([1, 512])\n" 1274 | ] 1275 | }, 1276 | { 1277 | "name": "stderr", 1278 | "output_type": "stream", 1279 | "text": [ 1280 | "\n" 1281 | ] 1282 | } 1283 | ], 1284 | "source": [ 1285 | "def generate(model, inp):\n", 1286 | " model.eval()\n", 1287 | "\n", 1288 | " # inp = inp.unsqueeze(0)\n", 1289 | " sample = model.generate(inp)\n", 1290 | " sample = sample.flatten(1)\n", 1291 | " print(\"sample:\", sample, sample.shape)\n", 1292 | "\n", 1293 | " return prompt, sample\n", 1294 | "\n", 1295 | "prompt, sample = generate(model, prompt)" 1296 | ] 1297 | }, 1298 | { 1299 | "cell_type": "code", 1300 | "execution_count": 40, 1301 | "metadata": {}, 1302 | "outputs": [ 1303 | { 1304 | "data": { 1305 | "text/plain": [ 1306 | "torch.Size([1, 512])" 1307 | ] 1308 | }, 1309 | "execution_count": 40, 1310 | "metadata": {}, 1311 | "output_type": "execute_result" 1312 | } 1313 | ], 1314 | "source": [ 1315 | "sample.shape" 1316 | ] 1317 | }, 1318 | { 1319 | "cell_type": "code", 1320 | "execution_count": 41, 1321 | "metadata": {}, 1322 | "outputs": [ 1323 | { 1324 | "name": "stdout", 1325 | "output_type": "stream", 1326 | "text": [ 1327 | "1099 1100\n" 1328 | ] 1329 | } 1330 | ], 1331 | "source": [ 1332 | "ETT, EAT = [len(dataset.phone_dict) + 1024 + i\n", 1333 | " for i in range(2)]\n", 1334 | "print(ETT, EAT)\n", 1335 | "# sample.index(STT)" 1336 | ] 1337 | }, 1338 | { 1339 | "cell_type": "code", 1340 | "execution_count": 42, 1341 | "metadata": {}, 1342 | "outputs": [ 1343 | { 1344 | "name": "stdout", 1345 | "output_type": "stream", 1346 | "text": [ 1347 | "ETT, EAT ids: [1099, 1100]\n", 1348 | "seq: [1059, 1081, 1057, 1097, 1053, 1098, 1084, 1075, 1070, 1048, 1098, 1049, 1034, 1098, 1069, 1034, 1071, 1063, 1083, 1034, 1072, 1098, 1098, 1098, 1069, 1094, 1075, 1084, 1098, 1098, 1099, 835, 160, 438, 488, 887, 203, 503, 441, 6, 81, 727, 141, 908, 908, 502, 303, 148, 103, 496, 145, 731, 977, 259, 582, 808, 921, 432, 779, 779, 472, 472, 331, 103, 887, 457, 987, 501, 921, 197, 197, 931, 928, 881, 834, 432, 604, 491, 373, 994, 782, 834, 408, 855, 855, 798, 176, 798, 537, 779, 936, 457, 751, 651, 687, 751, 790, 686, 994, 57, 916, 751, 699, 145, 145, 148, 1010, 4, 973, 984, 1002, 3, 472, 185, 662, 662, 471, 471, 106, 734, 893, 802, 285, 1010, 812, 272, 529, 930, 486, 323, 632, 466, 930, 601, 646, 924, 913, 160, 1010, 857, 984, 668, 399, 399, 399, 710, 27, 710, 710, 564, 765, 888, 770, 404, 405, 420, 708, 752, 928, 71, 216, 43, 471, 801, 404, 112, 214, 48, 425, 1011, 48, 767, 792, 792, 869, 268, 471, 200, 1010, 857, 857, 970, 457, 212, 188, 212, 54, 90, 573, 999, 298, 977, 469, 977, 977, 587, 977, 590, 747, 590, 590, 906, 775, 536, 624, 264, 432, 705, 647, 832, 1007, 861, 758, 
633, 1015, 1007, 759, 34, 733, 733, 741, 659, 757, 932, 626, 829, 832, 547, 752, 701, 955, 242, 705, 1000, 710, 907, 80, 893, 438, 893, 932, 549, 55, 442, 426, 831, 581, 475, 977, 93, 93, 712, 298, 432, 933, 242, 268, 919, 261, 787, 255, 263, 787, 787, 791, 19, 238, 490, 416, 935, 698, 588, 889, 411, 778, 779, 859, 817, 162, 418, 211, 306, 255, 714, 993, 875, 493, 418, 762, 1019, 178, 504, 414, 605, 499, 706, 628, 214, 78, 558, 1019, 1000, 1019, 874, 110, 804, 381, 529, 866, 36, 803, 940, 180, 941, 733, 96, 803, 335, 524, 714, 224, 733, 499, 769, 204, 273, 994, 249, 882, 940, 148, 1100, 993, 875, 493, 418, 762, 1019, 178, 504, 414, 605, 499, 706, 628, 214, 78, 558, 1019, 1000, 1019, 874, 110, 804, 381, 529, 866, 36, 803, 940, 180, 941, 733, 96, 803, 335, 524, 714, 224, 733, 499, 769, 204, 273, 994, 249, 882, 940, 148, 1100, 993, 875, 493, 418, 762, 1019, 178, 504, 414, 605, 499, 706, 628, 214, 78, 558, 1019, 1000, 1019, 874, 110, 804, 381, 529, 866, 36, 803, 940, 180, 941, 733, 96, 803, 335, 524, 714, 224, 733, 499, 769, 204, 273, 994, 249, 882, 940, 148, 1100, 573, 999, 298, 977, 469, 977, 977, 587, 977, 590, 747, 590, 590, 906, 775, 536, 624, 264, 432, 705, 647, 832, 1007, 861, 758, 633, 1015, 1007, 759, 34, 733, 733, 741, 659, 757, 932, 626, 829, 832, 547, 752, 701, 955, 242, 705, 1000, 710, 907, 80, 893, 438, 893, 932, 549, 55, 442, 426, 831, 581, 475, 977, 93, 93, 712, 298, 432, 933, 242, 268, 919, 261, 787, 255, 263, 787, 787, 791, 19, 238, 490, 416, 935, 698, 588]\n", 1349 | "30 331 300\n", 1350 | "audio_tokens.shape: tensor([[ 835],\n", 1351 | " [ 160],\n", 1352 | " [ 438],\n", 1353 | " [ 488],\n", 1354 | " [ 887],\n", 1355 | " [ 203],\n", 1356 | " [ 503],\n", 1357 | " [ 441],\n", 1358 | " [ 6],\n", 1359 | " [ 81],\n", 1360 | " [ 727],\n", 1361 | " [ 141],\n", 1362 | " [ 908],\n", 1363 | " [ 908],\n", 1364 | " [ 502],\n", 1365 | " [ 303],\n", 1366 | " [ 148],\n", 1367 | " [ 103],\n", 1368 | " [ 496],\n", 1369 | " [ 145],\n", 1370 | " [ 731],\n", 1371 | " [ 977],\n", 1372 | " [ 259],\n", 1373 | " [ 582],\n", 1374 | " [ 808],\n", 1375 | " [ 921],\n", 1376 | " [ 432],\n", 1377 | " [ 779],\n", 1378 | " [ 779],\n", 1379 | " [ 472],\n", 1380 | " [ 472],\n", 1381 | " [ 331],\n", 1382 | " [ 103],\n", 1383 | " [ 887],\n", 1384 | " [ 457],\n", 1385 | " [ 987],\n", 1386 | " [ 501],\n", 1387 | " [ 921],\n", 1388 | " [ 197],\n", 1389 | " [ 197],\n", 1390 | " [ 931],\n", 1391 | " [ 928],\n", 1392 | " [ 881],\n", 1393 | " [ 834],\n", 1394 | " [ 432],\n", 1395 | " [ 604],\n", 1396 | " [ 491],\n", 1397 | " [ 373],\n", 1398 | " [ 994],\n", 1399 | " [ 782],\n", 1400 | " [ 834],\n", 1401 | " [ 408],\n", 1402 | " [ 855],\n", 1403 | " [ 855],\n", 1404 | " [ 798],\n", 1405 | " [ 176],\n", 1406 | " [ 798],\n", 1407 | " [ 537],\n", 1408 | " [ 779],\n", 1409 | " [ 936],\n", 1410 | " [ 457],\n", 1411 | " [ 751],\n", 1412 | " [ 651],\n", 1413 | " [ 687],\n", 1414 | " [ 751],\n", 1415 | " [ 790],\n", 1416 | " [ 686],\n", 1417 | " [ 994],\n", 1418 | " [ 57],\n", 1419 | " [ 916],\n", 1420 | " [ 751],\n", 1421 | " [ 699],\n", 1422 | " [ 145],\n", 1423 | " [ 145],\n", 1424 | " [ 148],\n", 1425 | " [1010],\n", 1426 | " [ 4],\n", 1427 | " [ 973],\n", 1428 | " [ 984],\n", 1429 | " [1002],\n", 1430 | " [ 3],\n", 1431 | " [ 472],\n", 1432 | " [ 185],\n", 1433 | " [ 662],\n", 1434 | " [ 662],\n", 1435 | " [ 471],\n", 1436 | " [ 471],\n", 1437 | " [ 106],\n", 1438 | " [ 734],\n", 1439 | " [ 893],\n", 1440 | " [ 802],\n", 1441 | " [ 285],\n", 1442 | " [1010],\n", 1443 | " [ 812],\n", 1444 | " [ 272],\n", 1445 | " [ 
529],\n", 1446 | " [ 930],\n", 1447 | " [ 486],\n", 1448 | " [ 323],\n", 1449 | " [ 632],\n", 1450 | " [ 466],\n", 1451 | " [ 930],\n", 1452 | " [ 601],\n", 1453 | " [ 646],\n", 1454 | " [ 924],\n", 1455 | " [ 913],\n", 1456 | " [ 160],\n", 1457 | " [1010],\n", 1458 | " [ 857],\n", 1459 | " [ 984],\n", 1460 | " [ 668],\n", 1461 | " [ 399],\n", 1462 | " [ 399],\n", 1463 | " [ 399],\n", 1464 | " [ 710],\n", 1465 | " [ 27],\n", 1466 | " [ 710],\n", 1467 | " [ 710],\n", 1468 | " [ 564],\n", 1469 | " [ 765],\n", 1470 | " [ 888],\n", 1471 | " [ 770],\n", 1472 | " [ 404],\n", 1473 | " [ 405],\n", 1474 | " [ 420],\n", 1475 | " [ 708],\n", 1476 | " [ 752],\n", 1477 | " [ 928],\n", 1478 | " [ 71],\n", 1479 | " [ 216],\n", 1480 | " [ 43],\n", 1481 | " [ 471],\n", 1482 | " [ 801],\n", 1483 | " [ 404],\n", 1484 | " [ 112],\n", 1485 | " [ 214],\n", 1486 | " [ 48],\n", 1487 | " [ 425],\n", 1488 | " [1011],\n", 1489 | " [ 48],\n", 1490 | " [ 767],\n", 1491 | " [ 792],\n", 1492 | " [ 792],\n", 1493 | " [ 869],\n", 1494 | " [ 268],\n", 1495 | " [ 471],\n", 1496 | " [ 200],\n", 1497 | " [1010],\n", 1498 | " [ 857],\n", 1499 | " [ 857],\n", 1500 | " [ 970],\n", 1501 | " [ 457],\n", 1502 | " [ 212],\n", 1503 | " [ 188],\n", 1504 | " [ 212],\n", 1505 | " [ 54],\n", 1506 | " [ 90],\n", 1507 | " [ 573],\n", 1508 | " [ 999],\n", 1509 | " [ 298],\n", 1510 | " [ 977],\n", 1511 | " [ 469],\n", 1512 | " [ 977],\n", 1513 | " [ 977],\n", 1514 | " [ 587],\n", 1515 | " [ 977],\n", 1516 | " [ 590],\n", 1517 | " [ 747],\n", 1518 | " [ 590],\n", 1519 | " [ 590],\n", 1520 | " [ 906],\n", 1521 | " [ 775],\n", 1522 | " [ 536],\n", 1523 | " [ 624],\n", 1524 | " [ 264],\n", 1525 | " [ 432],\n", 1526 | " [ 705],\n", 1527 | " [ 647],\n", 1528 | " [ 832],\n", 1529 | " [1007],\n", 1530 | " [ 861],\n", 1531 | " [ 758],\n", 1532 | " [ 633],\n", 1533 | " [1015],\n", 1534 | " [1007],\n", 1535 | " [ 759],\n", 1536 | " [ 34],\n", 1537 | " [ 733],\n", 1538 | " [ 733],\n", 1539 | " [ 741],\n", 1540 | " [ 659],\n", 1541 | " [ 757],\n", 1542 | " [ 932],\n", 1543 | " [ 626],\n", 1544 | " [ 829],\n", 1545 | " [ 832],\n", 1546 | " [ 547],\n", 1547 | " [ 752],\n", 1548 | " [ 701],\n", 1549 | " [ 955],\n", 1550 | " [ 242],\n", 1551 | " [ 705],\n", 1552 | " [1000],\n", 1553 | " [ 710],\n", 1554 | " [ 907],\n", 1555 | " [ 80],\n", 1556 | " [ 893],\n", 1557 | " [ 438],\n", 1558 | " [ 893],\n", 1559 | " [ 932],\n", 1560 | " [ 549],\n", 1561 | " [ 55],\n", 1562 | " [ 442],\n", 1563 | " [ 426],\n", 1564 | " [ 831],\n", 1565 | " [ 581],\n", 1566 | " [ 475],\n", 1567 | " [ 977],\n", 1568 | " [ 93],\n", 1569 | " [ 93],\n", 1570 | " [ 712],\n", 1571 | " [ 298],\n", 1572 | " [ 432],\n", 1573 | " [ 933],\n", 1574 | " [ 242],\n", 1575 | " [ 268],\n", 1576 | " [ 919],\n", 1577 | " [ 261],\n", 1578 | " [ 787],\n", 1579 | " [ 255],\n", 1580 | " [ 263],\n", 1581 | " [ 787],\n", 1582 | " [ 787],\n", 1583 | " [ 791],\n", 1584 | " [ 19],\n", 1585 | " [ 238],\n", 1586 | " [ 490],\n", 1587 | " [ 416],\n", 1588 | " [ 935],\n", 1589 | " [ 698],\n", 1590 | " [ 588],\n", 1591 | " [ 889],\n", 1592 | " [ 411],\n", 1593 | " [ 778],\n", 1594 | " [ 779],\n", 1595 | " [ 859],\n", 1596 | " [ 817],\n", 1597 | " [ 162],\n", 1598 | " [ 418],\n", 1599 | " [ 211],\n", 1600 | " [ 306],\n", 1601 | " [ 255],\n", 1602 | " [ 714],\n", 1603 | " [ 993],\n", 1604 | " [ 875],\n", 1605 | " [ 493],\n", 1606 | " [ 418],\n", 1607 | " [ 762],\n", 1608 | " [1019],\n", 1609 | " [ 178],\n", 1610 | " [ 504],\n", 1611 | " [ 414],\n", 1612 | " [ 605],\n", 1613 | " [ 499],\n", 1614 | " [ 706],\n", 1615 | " 
[ 628],\n", 1616 | " [ 214],\n", 1617 | " [ 78],\n", 1618 | " [ 558],\n", 1619 | " [1019],\n", 1620 | " [1000],\n", 1621 | " [1019],\n", 1622 | " [ 874],\n", 1623 | " [ 110],\n", 1624 | " [ 804],\n", 1625 | " [ 381],\n", 1626 | " [ 529],\n", 1627 | " [ 866],\n", 1628 | " [ 36],\n", 1629 | " [ 803],\n", 1630 | " [ 940],\n", 1631 | " [ 180],\n", 1632 | " [ 941],\n", 1633 | " [ 733],\n", 1634 | " [ 96],\n", 1635 | " [ 803],\n", 1636 | " [ 335],\n", 1637 | " [ 524],\n", 1638 | " [ 714],\n", 1639 | " [ 224],\n", 1640 | " [ 733],\n", 1641 | " [ 499],\n", 1642 | " [ 769],\n", 1643 | " [ 204],\n", 1644 | " [ 273],\n", 1645 | " [ 994],\n", 1646 | " [ 249],\n", 1647 | " [ 882],\n", 1648 | " [ 940],\n", 1649 | " [ 148]], device='cuda:0') torch.Size([300, 1])\n" 1650 | ] 1651 | } 1652 | ], 1653 | "source": [ 1654 | "out = generate_audio(\n", 1655 | " sample,\n", 1656 | " n_phone_tokens=len(dataset.phone_dict),\n", 1657 | " n_audio_tokens=1024)" 1658 | ] 1659 | }, 1660 | { 1661 | "cell_type": "code", 1662 | "execution_count": 43, 1663 | "metadata": {}, 1664 | "outputs": [ 1665 | { 1666 | "data": { 1667 | "text/plain": [ 1668 | "True" 1669 | ] 1670 | }, 1671 | "execution_count": 43, 1672 | "metadata": {}, 1673 | "output_type": "execute_result" 1674 | } 1675 | ], 1676 | "source": [ 1677 | "out" 1678 | ] 1679 | }, 1680 | { 1681 | "cell_type": "markdown", 1682 | "metadata": {}, 1683 | "source": [ 1684 | "## PhoneLM - LJSpeech (Generalise)" 1685 | ] 1686 | }, 1687 | { 1688 | "cell_type": "code", 1689 | "execution_count": null, 1690 | "metadata": {}, 1691 | "outputs": [], 1692 | "source": [] 1693 | } 1694 | ], 1695 | "metadata": { 1696 | "kernelspec": { 1697 | "display_name": "Python 3", 1698 | "language": "python", 1699 | "name": "python3" 1700 | }, 1701 | "language_info": { 1702 | "codemirror_mode": { 1703 | "name": "ipython", 1704 | "version": 3 1705 | }, 1706 | "file_extension": ".py", 1707 | "mimetype": "text/x-python", 1708 | "name": "python", 1709 | "nbconvert_exporter": "python", 1710 | "pygments_lexer": "ipython3", 1711 | "version": "3.9.9" 1712 | }, 1713 | "orig_nbformat": 4 1714 | }, 1715 | "nbformat": 4, 1716 | "nbformat_minor": 2 1717 | } 1718 | --------------------------------------------------------------------------------
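
A minimal sketch (not part of the repository files above) of how the notebook's `multi_encode` lays out a single training sequence for the MEGABYTE model: phone ids are offset past the 1024 EnCodec code ids, an end-of-text marker separates them from the flattened audio codes, and an end-of-audio marker closes the clip. The token counts, marker ids (1099/1100) and the 75-frames-per-second rate are taken from the notebook cells above; the helper name `build_sequence` is hypothetical.

import torch
from einops import rearrange

# Values taken from the notebook above.
N_PHONE_TOKENS = 75    # len(dataset.phone_dict)
N_AUDIO_TOKENS = 1024  # EnCodec codebook size
ETT = N_PHONE_TOKENS + N_AUDIO_TOKENS      # end-of-text marker (1099 above)
EAT = N_PHONE_TOKENS + N_AUDIO_TOKENS + 1  # end-of-audio marker (1100 above)

def build_sequence(phone_ids, codes, max_clip_length=1.0):
    """phone_ids: (p,) long tensor; codes: (1, q, t) EnCodec codes.
    Returns the flat token sequence fed to MEGABYTE, mirroring multi_encode."""
    phones = phone_ids + N_AUDIO_TOKENS               # shift phones past the audio-code ids
    codes = codes[:, :, : int(max_clip_length * 75)]  # 75 EnCodec frames per second at 24 kHz
    audio = rearrange(codes.squeeze(0), "q t -> (q t)")
    return torch.cat([phones,
                      torch.tensor([ETT], dtype=phones.dtype),
                      audio,
                      torch.tensor([EAT], dtype=audio.dtype)])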