├── README.md ├── bin ├── __init__.py └── infer.py ├── data ├── .DS_Store ├── __init__.py ├── collation.py ├── datamodule.py ├── dataset.py ├── fbank.py ├── input_strategies.py ├── speechtokenizer.py └── tokenizer.py ├── images ├── README.md └── overview.png ├── inference.sh ├── models ├── .DS_Store ├── __init__.py ├── macros.py ├── transformer.py ├── uslm.py ├── valle.py └── visualizer.py ├── modules ├── .DS_Store ├── __init__.py ├── activation.py ├── embedding.py ├── optim.py ├── scaling.py ├── scheduler.py └── transformer.py ├── prompts ├── 1580_141083_000002_000002.normalized.txt └── 1580_141083_000002_000002.wav ├── requirements.txt ├── setup.py └── utils ├── .DS_Store ├── __init__.py └── symbol_table.py /README.md: -------------------------------------------------------------------------------- 1 | # USLM: Unified Speech Language Model 2 | 3 | 4 | ## Introduction 5 | Build upon [SpeechTokenizer](https://github.com/ZhangXInFD/SpeechTokenizer), USLM consists of autoregressive and non-autoregressive models, it can hierarchically model information in speech. The autoregressive (AR) model captures the content information by modeling tokens from the first RVQ quantizer. The non-autoregressive (NAR) model complements paralinguistic information for the AR model by generating tokens from the subsequent quantizers conditioned on the first-layer tokens. 6 | 7 |
8 |

9 |
10 | Overview 11 |

12 | 13 | ## Installation 14 | 15 | To get up and running quickly just follow the steps below: 16 | 17 | ``` 18 | # PyTorch 19 | pip install torch==1.13.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116 20 | pip install torchmetrics==0.11.1 21 | # fbank 22 | pip install librosa==0.8.1 23 | 24 | # phonemizer pypinyin 25 | apt-get install espeak-ng 26 | ## OSX: brew install espeak 27 | pip install phonemizer==3.2.1 pypinyin==0.48.0 28 | 29 | # lhotse update to newest version 30 | # https://github.com/lhotse-speech/lhotse/pull/956 31 | # https://github.com/lhotse-speech/lhotse/pull/960 32 | pip uninstall lhotse 33 | pip install git+https://github.com/lhotse-speech/lhotse 34 | 35 | # k2 36 | # find the right version in https://huggingface.co/csukuangfj/k2 37 | pip install https://huggingface.co/csukuangfj/k2/resolve/main/cuda/k2-1.23.4.dev20230224+cuda11.6.torch1.13.1-cp310-cp310-linux_x86_64.whl 38 | 39 | # icefall 40 | git clone https://github.com/k2-fsa/icefall 41 | cd icefall 42 | pip install -r requirements.txt 43 | export PYTHONPATH=`pwd`/../icefall:$PYTHONPATH 44 | echo "export PYTHONPATH=`pwd`/../icefall:\$PYTHONPATH" >> ~/.zshrc 45 | echo "export PYTHONPATH=`pwd`/../icefall:\$PYTHONPATH" >> ~/.bashrc 46 | cd - 47 | source ~/.zshrc 48 | 49 | #SpeechTokenizer 50 | pip install -U speechtokenizer 51 | 52 | # uslm 53 | git clone https://github.com/0nutation/USLM 54 | cd USLM 55 | pip install -e . 56 | ``` 57 | 58 | ## USLM Models 59 | This version of USLM is trained on the LibriTTS dataset, so the performance is not optimal due to data limitations. 60 | 61 | | Model| Dataset |Discription| 62 | |:----|:----:|:----| 63 | |[USLM_libri](https://huggingface.co/fnlp/USLM/tree/main/USLM_libritts)|LibriTTS|USLM trained on LibriTTS dataset | 64 | 65 | 66 | ## Zero-shot TTS Using USLM 67 | Download pre-trained SpeechTokenizer models: 68 | ``` bash 69 | st_dir="ckpt/speechtokenizer/" 70 | mkdir -p ${st_dir} 71 | cd ${st_dir} 72 | wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/SpeechTokenizer.pt" 73 | wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/config.json" 74 | cd - 75 | ``` 76 | 77 | Download pre-trained USLM models: 78 | ``` bash 79 | uslm_dir="ckpt/uslm/" 80 | mkdir -p ${uslm_dir} 81 | cd ${uslm_dir} 82 | wget "https://huggingface.co/fnlp/USLM/resolve/main/USLM_libritts/USLM.pt" 83 | wget "https://huggingface.co/fnlp/USLM/resolve/main/USLM_libritts/unique_text_tokens.k2symbols" 84 | cd - 85 | ``` 86 | 87 | Inference: 88 | ``` bash 89 | out_dir="output/" 90 | mkdir -p ${out_dir} 91 | 92 | python3 bin/infer.py --output-dir ${out_dir}/ \ 93 | --model-name uslm --norm-first true --add-prenet false \ 94 | --share-embedding true --norm-first true --add-prenet false \ 95 | --audio-extractor SpeechTokenizer \ 96 | --speechtokenizer-dir "${st_dir}" \ 97 | --checkpoint=${uslm_dir}/USLM.pt \ 98 | --text-tokens "${uslm_dir}/unique_text_tokens.k2symbols" \ 99 | --text-prompts "mr Soames was a tall, spare man, of a nervous and excitable temperament." \ 100 | --audio-prompts prompts/1580_141083_000002_000002.wav \ 101 | --text "Begin with the fundamental steps of the process. This will give you a solid foundation to build upon and boost your confidence. " \ 102 | ``` 103 | 104 | or you can directly run inference.sh 105 | ``` bash 106 | bash inference.sh 107 | ``` 108 | 109 | ## Acknowledge 110 | [VALL-E](https://github.com/lifeiteng/vall-e): The codebase we build upon. 111 | 112 | ## Citation 113 | If you use this code or result in your paper, please cite our work as: 114 | ```Tex 115 | @misc{zhang2023speechtokenizer, 116 | title={SpeechTokenizer: Unified Speech Tokenizer for Speech Language Models}, 117 | author={Xin Zhang and Dong Zhang and Shimin Li and Yaqian Zhou and Xipeng Qiu}, 118 | year={2023}, 119 | eprint={2308.16692}, 120 | archivePrefix={arXiv}, 121 | primaryClass={cs.CL} 122 | } 123 | ``` 124 | -------------------------------------------------------------------------------- /bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/bin/__init__.py -------------------------------------------------------------------------------- /bin/infer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Fri. Sept. 8 00:43:42 2023 3 | @author: Dong Zhang 4 | """ 5 | 6 | import argparse 7 | import logging 8 | import os 9 | from pathlib import Path 10 | from tqdm import tqdm 11 | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" 12 | 13 | import torch 14 | import torchaudio 15 | 16 | from uslm.data import ( 17 | AudioTokenizer, 18 | TextTokenizer, 19 | tokenize_audio, 20 | tokenize_text, 21 | Speechtokenizer, 22 | sttokenize_audio 23 | ) 24 | from uslm.data.collation import get_text_token_collater 25 | from uslm.models import add_model_arguments, get_model 26 | 27 | 28 | def circular_padding(x, tgt_len): 29 | if x.size(-1) >= tgt_len: 30 | return x[:, :, :tgt_len] 31 | t = tgt_len // x.size(-1) 32 | r = tgt_len % x.size(-1) 33 | tgt = x.repeat(1, 1, t) 34 | tgt = torch.cat([tgt, x[:, :, :r]], axis=-1) 35 | return tgt 36 | 37 | 38 | def get_args(): 39 | parser = argparse.ArgumentParser() 40 | 41 | parser.add_argument( 42 | "--text-prompts", 43 | type=str, 44 | default="", 45 | help="Text prompts which are separated by |.", 46 | ) 47 | 48 | parser.add_argument( 49 | "--audio-prompts", 50 | type=str, 51 | default="", 52 | help="Audio prompts which are separated by | and should be aligned with --text-prompts.", 53 | ) 54 | 55 | parser.add_argument( 56 | "--text", 57 | type=str, 58 | default="To get up and running quickly just follow the steps below.", 59 | help="Text to be synthesized.", 60 | ) 61 | 62 | parser.add_argument( 63 | "--audio-extractor", 64 | type=str, 65 | default="Encodec", 66 | help="Encodec or SpeechTokenizer or Fbank", 67 | ) 68 | 69 | # model 70 | add_model_arguments(parser) 71 | 72 | parser.add_argument( 73 | "--text-tokens", 74 | type=str, 75 | default="data/tokenized/unique_text_tokens.k2symbols", 76 | help="Path to the unique text tokens file.", 77 | ) 78 | parser.add_argument( 79 | "--text-extractor", 80 | type=str, 81 | default="espeak", 82 | help="espeak or pypinyin or pypinyin_initials_finals", 83 | ) 84 | 85 | parser.add_argument( 86 | "--checkpoint", 87 | type=str, 88 | default="exp/vallf_nano_full/checkpoint-100000.pt", 89 | help="Path to the saved checkpoint.", 90 | ) 91 | 92 | parser.add_argument( 93 | "--output-dir", 94 | type=Path, 95 | default=Path("infer/demo"), 96 | help="Path to the tokenized files.", 97 | ) 98 | 99 | parser.add_argument( 100 | "--top-k", 101 | type=int, 102 | default=-100, 103 | help="Whether AR Decoder do top_k(if > 0) sampling.", 104 | ) 105 | 106 | parser.add_argument( 107 | "--temperature", 108 | type=float, 109 | default=1.0, 110 | help="The temperature of AR Decoder top_k sampling.", 111 | ) 112 | 113 | parser.add_argument( 114 | "--without-nar", 115 | default=False, 116 | help="without nar", 117 | ) 118 | parser.add_argument( 119 | "--speechtokenizer-dir", 120 | type=str, 121 | default="False", 122 | help="dirname of speechtokenizer models", 123 | ) 124 | 125 | return parser.parse_args() 126 | 127 | 128 | @torch.no_grad() 129 | def main(): 130 | args = get_args() 131 | text_tokenizer = TextTokenizer(backend=args.text_extractor) 132 | text_collater = get_text_token_collater(args.text_tokens) 133 | 134 | if args.audio_extractor == "EnCodec": 135 | sr = 24000 136 | audio_tokenizer = AudioTokenizer() 137 | tokenize_a = tokenize_audio 138 | elif args.audio_extractor == "SpeechTokenizer": 139 | sr = 16000 140 | audio_tokenizer = Speechtokenizer(ckpt_dir=args.speechtokenizer_dir) 141 | tokenize_a = sttokenize_audio 142 | 143 | 144 | device = torch.device("cpu") 145 | if torch.cuda.is_available(): 146 | device = torch.device("cuda", 0) 147 | 148 | model = get_model(args) 149 | if args.checkpoint: 150 | checkpoint = torch.load(args.checkpoint, map_location=device) 151 | missing_keys, unexpected_keys = model.load_state_dict( 152 | checkpoint["model"], strict=True 153 | ) 154 | assert not missing_keys 155 | 156 | 157 | model.to(device) 158 | model.eval() 159 | 160 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 161 | 162 | if os.path.isfile(args.text): # for demos 163 | 164 | with open(args.text) as f: 165 | for line in tqdm(f): 166 | fields = line.strip().split("|") 167 | assert len(fields) == 4 168 | prompt_text, prompt_audio, text, audio_path = fields 169 | logging.info(f"synthesize text: {text}") 170 | os.makedirs(os.path.dirname(audio_path),exist_ok=True) 171 | text_tokens, text_tokens_lens = text_collater( 172 | [ 173 | tokenize_text( 174 | text_tokenizer, text=f"{prompt_text} {text}".strip() 175 | ) 176 | ] 177 | ) 178 | _, enroll_x_lens = text_collater( 179 | [ 180 | tokenize_text( 181 | text_tokenizer, text=f"{prompt_text}".strip() 182 | ) 183 | ] 184 | ) 185 | 186 | audio_prompts = tokenize_a(audio_tokenizer, prompt_audio) 187 | if args.audio_extractor == "SpeechTokenizer": 188 | audio_prompts = audio_prompts.permute(1, 2, 0).to(device) #[b,t,8] 189 | elif args.audio_extractor == "EnCodec": 190 | audio_prompts = audio_prompts[0][0].transpose(2,1) 191 | 192 | # synthesis 193 | encoded_frames = model.inference( 194 | text_tokens.to(device), 195 | text_tokens_lens.to(device), 196 | audio_prompts, 197 | enroll_x_lens=enroll_x_lens, 198 | top_k=args.top_k, 199 | temperature=args.temperature, 200 | ) 201 | 202 | if args.audio_extractor == "SpeechTokenizer": 203 | code_generated = encoded_frames.permute(2,0,1) #[8,b,T] 204 | else: 205 | code_generated = [(encoded_frames.transpose(2, 1), None)] 206 | 207 | 208 | if args.without_nar: 209 | audio_prompts = circular_padding(audio_prompts.permute(2,0,1), code_generated.shape[-1]) 210 | code_generated = torch.cat((code_generated[:1,:,:], audio_prompts[1:4,:,:]),dim=0) 211 | 212 | samples = audio_tokenizer.decode( 213 | code_generated 214 | ) 215 | # store 216 | torchaudio.save(audio_path, samples[0].cpu(), sr) 217 | return 218 | 219 | 220 | 221 | 222 | text_prompts = " ".join(args.text_prompts.split("|")) 223 | 224 | audio_prompts = [] 225 | if args.audio_prompts: 226 | for n, audio_file in enumerate(args.audio_prompts.split("|")): 227 | encoded_frames = tokenize_a(audio_tokenizer, audio_file) 228 | if False: 229 | samples = audio_tokenizer.decode(encoded_frames) 230 | torchaudio.save( 231 | f"{args.output_dir}/p{n}.wav", samples[0], sr 232 | ) 233 | 234 | if args.audio_extractor == "EnCodec": 235 | audio_prompts.append(encoded_frames[0][0]) 236 | elif args.audio_extractor == "SpeechTokenizer": 237 | audio_prompts.append(encoded_frames.permute(1,0,2)) 238 | 239 | 240 | assert len(args.text_prompts.split("|")) == len(audio_prompts) 241 | audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1) 242 | audio_prompts = audio_prompts.to(device) 243 | 244 | 245 | 246 | for n, text in enumerate(args.text.split("|")): 247 | logging.info(f"synthesize text: {text}") 248 | text_tokens, text_tokens_lens = text_collater( 249 | [ 250 | tokenize_text( 251 | text_tokenizer, text=f"{text_prompts} {text}".strip() 252 | ) 253 | ] 254 | ) 255 | 256 | # synthesis 257 | 258 | enroll_x_lens = None 259 | if text_prompts: 260 | _, enroll_x_lens = text_collater( 261 | [ 262 | tokenize_text( 263 | text_tokenizer, text=f"{text_prompts}".strip() 264 | ) 265 | ] 266 | ) 267 | encoded_frames = model.inference( 268 | text_tokens.to(device), 269 | text_tokens_lens.to(device), 270 | audio_prompts, 271 | enroll_x_lens=enroll_x_lens, 272 | top_k=args.top_k, 273 | temperature=args.temperature, 274 | ) 275 | 276 | 277 | if audio_prompts != []: 278 | if args.audio_extractor == "SpeechTokenizer": 279 | code_generated = encoded_frames.permute(2,0,1) 280 | else: 281 | code_generated = [(encoded_frames.transpose(2, 1), None)] 282 | samples = audio_tokenizer.decode( 283 | code_generated 284 | ) 285 | # store 286 | idx = args.audio_prompts.split('|')[n] 287 | torchaudio.save( 288 | f"{args.output_dir}/gen-{os.path.basename(idx).replace('flac','wav')}", samples[0].cpu(), sr 289 | ) 290 | else: # Transformer 291 | pass 292 | 293 | 294 | torch.set_num_threads(1) 295 | torch.set_num_interop_threads(1) 296 | torch._C._jit_set_profiling_executor(False) 297 | torch._C._jit_set_profiling_mode(False) 298 | torch._C._set_graph_executor_optimize(False) 299 | if __name__ == "__main__": 300 | formatter = ( 301 | "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" 302 | ) 303 | logging.basicConfig(format=formatter, level=logging.INFO) 304 | main() 305 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/data/.DS_Store -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datamodule import * 2 | from .tokenizer import * 3 | from .collation import * 4 | from .speechtokenizer import * 5 | -------------------------------------------------------------------------------- /data/collation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from uslm.utils import SymbolTable 8 | 9 | 10 | class TextTokenCollater: 11 | """Collate list of text tokens 12 | 13 | Map sentences to integers. Sentences are padded to equal length. 14 | Beginning and end-of-sequence symbols can be added. 15 | 16 | Example: 17 | >>> token_collater = TextTokenCollater(text_tokens) 18 | >>> tokens_batch, tokens_lens = token_collater(text) 19 | 20 | Returns: 21 | tokens_batch: IntTensor of shape (B, L) 22 | B: batch dimension, number of input sentences 23 | L: length of the longest sentence 24 | tokens_lens: IntTensor of shape (B,) 25 | Length of each sentence after adding and 26 | but before padding. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | text_tokens: List[str], 32 | add_eos: bool = True, 33 | add_bos: bool = True, 34 | pad_symbol: str = "", 35 | bos_symbol: str = "", 36 | eos_symbol: str = "", 37 | ): 38 | self.pad_symbol = pad_symbol 39 | 40 | self.add_eos = add_eos 41 | self.add_bos = add_bos 42 | 43 | self.bos_symbol = bos_symbol 44 | self.eos_symbol = eos_symbol 45 | 46 | unique_tokens = ( 47 | [pad_symbol] 48 | + ([bos_symbol] if add_bos else []) 49 | + ([eos_symbol] if add_eos else []) 50 | + sorted(text_tokens) 51 | ) 52 | 53 | self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} 54 | self.idx2token = [token for token in unique_tokens] 55 | 56 | def index( 57 | self, tokens_list: List[str] 58 | ) -> Tuple[torch.Tensor, torch.Tensor]: 59 | seqs, seq_lens = [], [] 60 | for tokens in tokens_list: 61 | assert ( 62 | all([True if s in self.token2idx else False for s in tokens]) 63 | is True 64 | ) 65 | seq = ( 66 | ([self.bos_symbol] if self.add_bos else []) 67 | + list(tokens) 68 | + ([self.eos_symbol] if self.add_eos else []) 69 | ) 70 | seqs.append(seq) 71 | seq_lens.append(len(seq)) 72 | 73 | max_len = max(seq_lens) 74 | for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)): 75 | seq.extend([self.pad_symbol] * (max_len - seq_len)) 76 | 77 | tokens = torch.from_numpy( 78 | np.array( 79 | [[self.token2idx[token] for token in seq] for seq in seqs], 80 | dtype=np.int64, 81 | ) 82 | ) 83 | tokens_lens = torch.IntTensor(seq_lens) 84 | 85 | return tokens, tokens_lens 86 | 87 | def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: 88 | tokens_seqs = [[p for p in text] for text in texts] 89 | max_len = len(max(tokens_seqs, key=len)) 90 | 91 | seqs = [ 92 | ([self.bos_symbol] if self.add_bos else []) 93 | + list(seq) 94 | + ([self.eos_symbol] if self.add_eos else []) 95 | + [self.pad_symbol] * (max_len - len(seq)) 96 | for seq in tokens_seqs 97 | ] 98 | 99 | # seqs_filtered = [] 100 | # flag=True 101 | # for seq in seqs: 102 | # for token in seq: 103 | # if token not in self.token2idx: 104 | # flag = False 105 | # break 106 | # else: 107 | # flag = True 108 | # if flag: 109 | # seqs_filtered.append(seq) 110 | # seqs = seqs_filtered 111 | 112 | tokens_batch = torch.from_numpy( 113 | np.array( 114 | [[self.token2idx[token] for token in seq] for seq in seqs], 115 | dtype=np.int64, 116 | ) 117 | ) 118 | 119 | tokens_lens = torch.IntTensor( 120 | [ 121 | len(seq) + int(self.add_eos) + int(self.add_bos) 122 | for seq in tokens_seqs 123 | ] 124 | ) 125 | 126 | return tokens_batch, tokens_lens 127 | 128 | 129 | def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater: 130 | text_tokens_path = Path(text_tokens_file) 131 | unique_tokens = SymbolTable.from_file(text_tokens_path) 132 | collater = TextTokenCollater( 133 | unique_tokens.symbols, add_bos=True, add_eos=True 134 | ) 135 | return collater 136 | -------------------------------------------------------------------------------- /data/datamodule.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # See ../../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | import argparse 19 | import inspect 20 | import logging 21 | from functools import lru_cache 22 | from pathlib import Path 23 | from typing import Any, Dict, Optional 24 | 25 | import torch 26 | from icefall.utils import str2bool 27 | from lhotse import CutSet, load_manifest_lazy 28 | from lhotse.dataset import ( 29 | CutConcatenate, 30 | DynamicBucketingSampler, 31 | PrecomputedFeatures, 32 | SingleCutSampler, 33 | SpecAugment, 34 | ) 35 | from lhotse.dataset.input_strategies import OnTheFlyFeatures 36 | from lhotse.utils import fix_random_seed 37 | from torch.utils.data import DataLoader 38 | 39 | from uslm.data.collation import get_text_token_collater 40 | from uslm.data.dataset import SpeechSynthesisDataset 41 | from uslm.data.fbank import get_fbank_extractor 42 | from uslm.data.input_strategies import PromptedPrecomputedFeatures 43 | 44 | PrecomputedFeatures = PrecomputedFeatures 45 | 46 | 47 | class _SeedWorkers: 48 | def __init__(self, seed: int): 49 | self.seed = seed 50 | 51 | def __call__(self, worker_id: int): 52 | fix_random_seed(self.seed + worker_id) 53 | 54 | 55 | def _get_input_strategy(input_strategy, dataset, cuts): 56 | if input_strategy == "PromptedPrecomputedFeatures": 57 | return PromptedPrecomputedFeatures(dataset, cuts) 58 | 59 | return eval(input_strategy)() 60 | 61 | 62 | class TtsDataModule: 63 | """ 64 | DataModule for VALL-E TTS experiments. 65 | It assumes there is always one train and valid dataloader. 66 | 67 | It contains all the common data pipeline modules used in TTS 68 | experiments, e.g.: 69 | - dynamic batch size, 70 | - bucketing samplers, 71 | - cut concatenation[not used & tested yet], 72 | - augmentation[not used & tested yet], 73 | - on-the-fly feature extraction[not used & tested yet] 74 | 75 | This class should be derived for specific corpora used in TTS tasks. 76 | """ 77 | 78 | def __init__(self, args: argparse.Namespace): 79 | self.args = args 80 | 81 | @classmethod 82 | def add_arguments(cls, parser: argparse.ArgumentParser): 83 | group = parser.add_argument_group( 84 | title="TTS data related options", 85 | description="These options are used for the preparation of " 86 | "PyTorch DataLoaders from Lhotse CutSet's -- they control the " 87 | "effective batch sizes, sampling strategies, applied data " 88 | "augmentations, etc.", 89 | ) 90 | group.add_argument( 91 | "--manifest-dir", 92 | type=Path, 93 | default=Path("data/tokenized"), 94 | help="Path to directory with train/valid/test cuts.", 95 | ) 96 | group.add_argument( 97 | "--max-duration", 98 | type=int, 99 | default=40.0, 100 | help="Maximum pooled recordings duration (seconds) in a " 101 | "single batch. You can reduce it if it causes CUDA OOM.", 102 | ) 103 | group.add_argument( 104 | "--bucketing-sampler", 105 | type=str2bool, 106 | default=True, 107 | help="When enabled, the batches will come from buckets of " 108 | "similar duration (saves padding frames).", 109 | ) 110 | group.add_argument( 111 | "--num-buckets", 112 | type=int, 113 | default=10, 114 | help="The number of buckets for the DynamicBucketingSampler" 115 | "(you might want to increase it for larger datasets).", 116 | ) 117 | group.add_argument( 118 | "--concatenate-cuts", 119 | type=str2bool, 120 | default=False, 121 | help="When enabled, utterances (cuts) will be concatenated " 122 | "to minimize the amount of padding.", 123 | ) 124 | group.add_argument( 125 | "--duration-factor", 126 | type=float, 127 | default=1.0, 128 | help="Determines the maximum duration of a concatenated cut " 129 | "relative to the duration of the longest cut in a batch.", 130 | ) 131 | group.add_argument( 132 | "--gap", 133 | type=float, 134 | default=0.1, 135 | help="The amount of padding (in seconds) inserted between " 136 | "concatenated cuts. This padding is filled with noise when " 137 | "noise augmentation is used.", 138 | ) 139 | group.add_argument( 140 | "--on-the-fly-feats", 141 | type=str2bool, 142 | default=False, 143 | help="When enabled, use on-the-fly cut mixing and feature " 144 | "extraction. Will drop existing precomputed feature manifests " 145 | "if available.", 146 | ) 147 | group.add_argument( 148 | "--shuffle", 149 | type=str2bool, 150 | default=True, 151 | help="When enabled (=default), the examples will be " 152 | "shuffled for each epoch.", 153 | ) 154 | group.add_argument( 155 | "--drop-last", 156 | type=str2bool, 157 | default=False, 158 | help="Whether to drop last batch. Used by sampler.", 159 | ) 160 | group.add_argument( 161 | "--return-cuts", 162 | type=str2bool, 163 | default=True, 164 | help="When enabled, each batch will have the " 165 | "field: batch['supervisions']['cut'] with the cuts that " 166 | "were used to construct it.", 167 | ) 168 | 169 | group.add_argument( 170 | "--num-workers", 171 | type=int, 172 | default=8, 173 | help="The number of training dataloader workers that " 174 | "collect the batches.", 175 | ) 176 | 177 | group.add_argument( 178 | "--enable-spec-aug", 179 | type=str2bool, 180 | default=False, 181 | help="When enabled, use SpecAugment for training dataset.", 182 | ) 183 | 184 | group.add_argument( 185 | "--spec-aug-time-warp-factor", 186 | type=int, 187 | default=80, 188 | help="Used only when --enable-spec-aug is True. " 189 | "It specifies the factor for time warping in SpecAugment. " 190 | "Larger values mean more warping. " 191 | "A value less than 1 means to disable time warp.", 192 | ) 193 | 194 | group.add_argument( 195 | "--input-strategy", 196 | type=str, 197 | default="PrecomputedFeatures", 198 | help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures", 199 | ) 200 | 201 | group.add_argument( 202 | "--dataset", 203 | type=str, 204 | default="libritts", 205 | help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.", 206 | ) 207 | 208 | parser.add_argument( 209 | "--text-tokens", 210 | type=str, 211 | default="data/tokenized/unique_text_tokens.k2symbols", 212 | help="Path to the unique text tokens file", 213 | ) 214 | 215 | parser.add_argument( 216 | "--sampling-rate", 217 | type=int, 218 | default=24000, 219 | help="""Audio sampling rate.""", 220 | ) 221 | 222 | def train_dataloaders( 223 | self, 224 | cuts_train: CutSet, 225 | sampler_state_dict: Optional[Dict[str, Any]] = None, 226 | ) -> DataLoader: 227 | """ 228 | Args: 229 | cuts_train: 230 | CutSet for training. 231 | sampler_state_dict: 232 | The state dict for the training sampler. 233 | """ 234 | transforms = [] 235 | 236 | if self.args.concatenate_cuts: 237 | logging.info( 238 | f"Using cut concatenation with duration factor " 239 | f"{self.args.duration_factor} and gap {self.args.gap}." 240 | ) 241 | # Cut concatenation should be the first transform in the list, 242 | # so that if we e.g. mix noise in, it will fill the gaps between 243 | # different utterances. 244 | transforms = [ 245 | CutConcatenate( 246 | duration_factor=self.args.duration_factor, gap=self.args.gap 247 | ) 248 | ] + transforms 249 | 250 | input_transforms = [] 251 | if self.args.enable_spec_aug: 252 | logging.info("Enable SpecAugment") 253 | logging.info( 254 | f"Time warp factor: {self.args.spec_aug_time_warp_factor}" 255 | ) 256 | # Set the value of num_frame_masks according to Lhotse's version. 257 | # In different Lhotse's versions, the default of num_frame_masks is 258 | # different. 259 | num_frame_masks = 10 260 | num_frame_masks_parameter = inspect.signature( 261 | SpecAugment.__init__ 262 | ).parameters["num_frame_masks"] 263 | if num_frame_masks_parameter.default == 1: 264 | num_frame_masks = 2 265 | logging.info(f"Num frame mask: {num_frame_masks}") 266 | input_transforms.append( 267 | SpecAugment( 268 | time_warp_factor=self.args.spec_aug_time_warp_factor, 269 | num_frame_masks=num_frame_masks, 270 | features_mask_size=27, 271 | num_feature_masks=2, 272 | frames_mask_size=100, 273 | ) 274 | ) 275 | else: 276 | logging.info("Disable SpecAugment") 277 | 278 | logging.info("About to create train dataset") 279 | if self.args.on_the_fly_feats: 280 | # NOTE: the PerturbSpeed transform should be added only if we 281 | # remove it from data prep stage. 282 | # Add on-the-fly speed perturbation; since originally it would 283 | # have increased epoch size by 3, we will apply prob 2/3 and use 284 | # 3x more epochs. 285 | # Speed perturbation probably should come first before 286 | # concatenation, but in principle the transforms order doesn't have 287 | # to be strict (e.g. could be randomized) 288 | # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa 289 | # Drop feats to be on the safe side. 290 | train = SpeechSynthesisDataset( 291 | get_text_token_collater(self.args.text_tokens), 292 | cut_transforms=transforms, 293 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()), 294 | feature_transforms=input_transforms, 295 | ) 296 | else: 297 | train = SpeechSynthesisDataset( 298 | get_text_token_collater(self.args.text_tokens), 299 | feature_input_strategy=_get_input_strategy( 300 | self.args.input_strategy, self.args.dataset, cuts_train 301 | ), 302 | cut_transforms=transforms, 303 | feature_transforms=input_transforms, 304 | ) 305 | 306 | if self.args.bucketing_sampler: 307 | logging.info("Using DynamicBucketingSampler") 308 | train_sampler = DynamicBucketingSampler( 309 | cuts_train, 310 | max_duration=self.args.max_duration, 311 | shuffle=self.args.shuffle, 312 | num_buckets=self.args.num_buckets, 313 | drop_last=True, 314 | ) 315 | else: 316 | logging.info( 317 | "Using SingleCutSampler and sort by duraton(ascending=True)." 318 | ) 319 | cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True) 320 | train_sampler = SingleCutSampler( 321 | cuts_train, 322 | max_duration=self.args.max_duration, 323 | shuffle=self.args.shuffle, 324 | ) 325 | logging.info("About to create train dataloader") 326 | 327 | if sampler_state_dict is not None: 328 | logging.info("Loading sampler state dict") 329 | train_sampler.load_state_dict(sampler_state_dict) 330 | 331 | # 'seed' is derived from the current random state, which will have 332 | # previously been set in the main process. 333 | seed = torch.randint(0, 100000, ()).item() 334 | worker_init_fn = _SeedWorkers(seed) 335 | 336 | train_dl = DataLoader( 337 | train, 338 | sampler=train_sampler, 339 | batch_size=None, 340 | num_workers=self.args.num_workers, 341 | persistent_workers=False, 342 | worker_init_fn=worker_init_fn, 343 | ) 344 | 345 | return train_dl 346 | 347 | def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: 348 | logging.info("About to create dev dataset") 349 | if self.args.on_the_fly_feats: 350 | validate = SpeechSynthesisDataset( 351 | get_text_token_collater(self.args.text_tokens), 352 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()), 353 | cut_transforms=[], 354 | ) 355 | else: 356 | validate = SpeechSynthesisDataset( 357 | get_text_token_collater(self.args.text_tokens), 358 | feature_input_strategy=_get_input_strategy( 359 | self.args.input_strategy, self.args.dataset, cuts_valid 360 | ), 361 | cut_transforms=[], 362 | ) 363 | valid_sampler = DynamicBucketingSampler( 364 | cuts_valid, 365 | max_duration=self.args.max_duration, 366 | shuffle=False, 367 | drop_last=True, 368 | ) 369 | logging.info("About to create dev dataloader") 370 | valid_dl = DataLoader( 371 | validate, 372 | sampler=valid_sampler, 373 | batch_size=None, 374 | num_workers=4, 375 | persistent_workers=False, 376 | ) 377 | 378 | return valid_dl 379 | 380 | def test_dataloaders(self, cuts: CutSet) -> DataLoader: 381 | logging.debug("About to create test dataset") 382 | test = SpeechSynthesisDataset( 383 | get_text_token_collater(self.args.text_tokens), 384 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()) 385 | if self.args.on_the_fly_feats 386 | else _get_input_strategy( 387 | self.args.input_strategy, self.args.dataset, cuts 388 | ), 389 | cut_transforms=[], 390 | ) 391 | sampler = DynamicBucketingSampler( 392 | cuts, 393 | max_duration=self.args.max_duration, 394 | shuffle=False, 395 | drop_last=True, 396 | ) 397 | logging.debug("About to create test dataloader") 398 | test_dl = DataLoader( 399 | test, 400 | batch_size=None, 401 | sampler=sampler, 402 | num_workers=self.args.num_workers, 403 | ) 404 | return test_dl 405 | 406 | @lru_cache() 407 | def train_cuts(self) -> CutSet: 408 | logging.info("About to get train cuts") 409 | return load_manifest_lazy( 410 | self.args.manifest_dir / "cuts_train.jsonl.gz" 411 | ) 412 | 413 | @lru_cache() 414 | def dev_cuts(self) -> CutSet: 415 | logging.info("About to get dev cuts") 416 | return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") 417 | 418 | @lru_cache() 419 | def test_cuts(self) -> CutSet: 420 | logging.info("About to get test cuts") 421 | return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz") 422 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # See ../../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """ 18 | modified from lhoste.dataset.speech_synthesis.py 19 | """ 20 | 21 | from typing import Callable, Dict, List, Sequence, Union 22 | 23 | import torch 24 | from lhotse import validate 25 | from lhotse.cut import CutSet 26 | from lhotse.dataset.collation import collate_audio 27 | from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures 28 | from lhotse.utils import ifnone 29 | 30 | from uslm.data.collation import TextTokenCollater 31 | 32 | 33 | class SpeechSynthesisDataset(torch.utils.data.Dataset): 34 | """ 35 | The PyTorch Dataset for the speech synthesis(e.g. TTS) task. 36 | Each item in this dataset is a dict of: 37 | 38 | .. code-block:: 39 | 40 | { 41 | 'audio': (B x NumSamples) float tensor 42 | 'audio_lens': (B, ) int tensor 43 | 'text': str 44 | 'audio_features': (B x NumFrames x NumFeatures) float tensor 45 | 'audio_features_lens': (B, ) int tensor 46 | 'text_tokens': (B x NumTextTokens) long tensor 47 | 'text_tokens_lens': (B, ) int tensor 48 | } 49 | """ 50 | 51 | def __init__( 52 | self, 53 | text_token_collater: TextTokenCollater, 54 | cut_transforms: List[Callable[[CutSet], CutSet]] = None, 55 | feature_input_strategy: BatchIO = PrecomputedFeatures(), 56 | feature_transforms: Union[Sequence[Callable], Callable] = None, 57 | ) -> None: 58 | super().__init__() 59 | 60 | self.text_token_collater = text_token_collater 61 | self.cut_transforms = ifnone(cut_transforms, []) 62 | self.feature_input_strategy = feature_input_strategy 63 | 64 | if feature_transforms is None: 65 | feature_transforms = [] 66 | elif not isinstance(feature_transforms, Sequence): 67 | feature_transforms = [feature_transforms] 68 | 69 | assert all( 70 | isinstance(transform, Callable) for transform in feature_transforms 71 | ), "Feature transforms must be Callable" 72 | self.feature_transforms = feature_transforms 73 | 74 | def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: 75 | validate_for_tts(cuts) 76 | 77 | for transform in self.cut_transforms: 78 | cuts = transform(cuts) 79 | 80 | if False: # not used 81 | audio, audio_lens = collate_audio(cuts) 82 | else: # for sharing tokenized features in different machines 83 | audio, audio_lens = None, None 84 | 85 | audio_features, audio_features_lens = self.feature_input_strategy(cuts) 86 | 87 | for transform in self.feature_transforms: 88 | audio_features = transform(audio_features) 89 | 90 | text_tokens, text_tokens_lens = self.text_token_collater( 91 | [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts] 92 | ) 93 | 94 | return { 95 | "utt_id": [cut.id for cut in cuts], 96 | "text": [cut.supervisions[0].text for cut in cuts], 97 | "audio": audio, 98 | "audio_lens": audio_lens, 99 | "audio_features": audio_features, 100 | "audio_features_lens": audio_features_lens, 101 | "text_tokens": text_tokens, 102 | "text_tokens_lens": text_tokens_lens, 103 | } 104 | 105 | 106 | def validate_for_tts(cuts: CutSet) -> None: 107 | validate(cuts) 108 | for cut in cuts: 109 | assert ( 110 | len(cut.supervisions) == 1 111 | ), "Only the Cuts with single supervision are supported." 112 | -------------------------------------------------------------------------------- /data/fbank.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # See ../../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | from dataclasses import asdict, dataclass 19 | from typing import Any, Dict, Optional, Union 20 | 21 | import numpy as np 22 | import torch 23 | from lhotse.features.base import FeatureExtractor 24 | from lhotse.utils import EPSILON, Seconds, compute_num_frames 25 | from librosa.filters import mel as librosa_mel_fn 26 | 27 | 28 | @dataclass 29 | class BigVGANFbankConfig: 30 | # Spectogram-related part 31 | # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them 32 | frame_length: Seconds = 1024 / 24000.0 33 | frame_shift: Seconds = 256 / 24000.0 34 | remove_dc_offset: bool = True 35 | round_to_power_of_two: bool = True 36 | 37 | # Fbank-related part 38 | low_freq: float = 0.0 39 | high_freq: float = 12000.0 40 | num_mel_bins: int = 100 41 | use_energy: bool = False 42 | 43 | def to_dict(self) -> Dict[str, Any]: 44 | return asdict(self) 45 | 46 | @staticmethod 47 | def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig": 48 | return BigVGANFbankConfig(**data) 49 | 50 | 51 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 52 | return torch.log(torch.clamp(x, min=clip_val) * C) 53 | 54 | 55 | def spectral_normalize_torch(magnitudes): 56 | output = dynamic_range_compression_torch(magnitudes) 57 | return output 58 | 59 | 60 | # https://github.com/NVIDIA/BigVGAN 61 | # bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz 62 | class BigVGANFbank(FeatureExtractor): 63 | name = "fbank" 64 | config_type = BigVGANFbankConfig 65 | 66 | def __init__(self, config: Optional[Any] = None): 67 | super(BigVGANFbank, self).__init__(config) 68 | sampling_rate = 24000 69 | self.mel_basis = torch.from_numpy( 70 | librosa_mel_fn( 71 | sampling_rate, 72 | 1024, 73 | self.config.num_mel_bins, 74 | self.config.low_freq, 75 | self.config.high_freq, 76 | ).astype(np.float32) 77 | ) 78 | self.hann_window = torch.hann_window(1024) 79 | 80 | def _feature_fn(self, samples, **kwargs): 81 | win_length, n_fft = 1024, 1024 82 | hop_size = 256 83 | if True: 84 | sampling_rate = 24000 85 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12) 86 | expected_num_frames = compute_num_frames( 87 | duration=duration, 88 | frame_shift=self.frame_shift, 89 | sampling_rate=sampling_rate, 90 | ) 91 | pad_size = ( 92 | (expected_num_frames - 1) * hop_size 93 | + win_length 94 | - samples.shape[-1] 95 | ) 96 | assert pad_size >= 0 97 | 98 | y = torch.nn.functional.pad( 99 | samples, 100 | (0, pad_size), 101 | mode="constant", 102 | ) 103 | else: 104 | y = torch.nn.functional.pad( 105 | samples, 106 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), 107 | mode="reflect", 108 | ) 109 | 110 | y = y.squeeze(1) 111 | 112 | # complex tensor as default, then use view_as_real for future pytorch compatibility 113 | spec = torch.stft( 114 | y, 115 | n_fft, 116 | hop_length=hop_size, 117 | win_length=win_length, 118 | window=self.hann_window, 119 | center=False, 120 | pad_mode="reflect", 121 | normalized=False, 122 | onesided=True, 123 | return_complex=True, 124 | ) 125 | spec = torch.view_as_real(spec) 126 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 127 | 128 | spec = torch.matmul(self.mel_basis, spec) 129 | spec = spectral_normalize_torch(spec) 130 | 131 | return spec.transpose(2, 1).squeeze(0) 132 | 133 | def extract( 134 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int 135 | ) -> np.ndarray: 136 | assert sampling_rate == 24000 137 | params = asdict(self.config) 138 | params.update({"sample_frequency": sampling_rate, "snip_edges": False}) 139 | params["frame_shift"] *= 1000.0 140 | params["frame_length"] *= 1000.0 141 | if not isinstance(samples, torch.Tensor): 142 | samples = torch.from_numpy(samples) 143 | # Torchaudio Kaldi feature extractors expect the channel dimension to be first. 144 | if len(samples.shape) == 1: 145 | samples = samples.unsqueeze(0) 146 | features = self._feature_fn(samples, **params).to(torch.float32) 147 | return features.numpy() 148 | 149 | @property 150 | def frame_shift(self) -> Seconds: 151 | return self.config.frame_shift 152 | 153 | def feature_dim(self, sampling_rate: int) -> int: 154 | return self.config.num_mel_bins 155 | 156 | @staticmethod 157 | def mix( 158 | features_a: np.ndarray, 159 | features_b: np.ndarray, 160 | energy_scaling_factor_b: float, 161 | ) -> np.ndarray: 162 | return np.log( 163 | np.maximum( 164 | # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0) 165 | EPSILON, 166 | np.exp(features_a) 167 | + energy_scaling_factor_b * np.exp(features_b), 168 | ) 169 | ) 170 | 171 | @staticmethod 172 | def compute_energy(features: np.ndarray) -> float: 173 | return float(np.sum(np.exp(features))) 174 | 175 | 176 | def get_fbank_extractor() -> BigVGANFbank: 177 | return BigVGANFbank(BigVGANFbankConfig()) 178 | 179 | 180 | if __name__ == "__main__": 181 | extractor = BigVGANFbank(BigVGANFbankConfig()) 182 | 183 | samples = torch.from_numpy(np.random.random([1000]).astype(np.float32)) 184 | samples = torch.clip(samples, -1.0, 1.0) 185 | fbank = extractor.extract(samples, 24000.0) 186 | print(f"fbank {fbank.shape}") 187 | 188 | from scipy.io.wavfile import read 189 | 190 | MAX_WAV_VALUE = 32768.0 191 | 192 | sampling_rate, samples = read( 193 | "egs/libritts/prompts/5639_40744_000000_000002.wav" 194 | ) 195 | print(f"samples: [{samples.min()}, {samples.max()}]") 196 | fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000) 197 | print(f"fbank {fbank.shape}") 198 | 199 | import matplotlib.pyplot as plt 200 | 201 | _ = plt.figure(figsize=(18, 10)) 202 | plt.imshow( 203 | X=fbank.transpose(1, 0), 204 | cmap=plt.get_cmap("jet"), 205 | aspect="auto", 206 | interpolation="nearest", 207 | ) 208 | plt.gca().invert_yaxis() 209 | plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png") 210 | plt.close() 211 | 212 | print("fbank test PASS!") 213 | -------------------------------------------------------------------------------- /data/input_strategies.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from concurrent.futures import ThreadPoolExecutor 4 | from typing import Tuple, Type 5 | 6 | from lhotse import CutSet 7 | from lhotse.dataset.collation import collate_features 8 | from lhotse.dataset.input_strategies import ( 9 | ExecutorType, 10 | PrecomputedFeatures, 11 | _get_executor, 12 | ) 13 | from lhotse.utils import fastcopy 14 | 15 | 16 | class PromptedFeatures: 17 | def __init__(self, prompts, features): 18 | self.prompts = prompts 19 | self.features = features 20 | 21 | def to(self, device): 22 | return PromptedFeatures( 23 | self.prompts.to(device), self.features.to(device) 24 | ) 25 | 26 | def sum(self): 27 | return self.features.sum() 28 | 29 | @property 30 | def ndim(self): 31 | return self.features.ndim 32 | 33 | @property 34 | def data(self): 35 | return (self.prompts, self.features) 36 | 37 | 38 | class PromptedPrecomputedFeatures(PrecomputedFeatures): 39 | """ 40 | :class:`InputStrategy` that reads pre-computed features, whose manifests 41 | are attached to cuts, from disk. 42 | 43 | It automatically pads the feature matrices with pre or post feature. 44 | 45 | .. automethod:: __call__ 46 | """ 47 | 48 | def __init__( 49 | self, 50 | dataset: str, 51 | cuts: CutSet, 52 | num_workers: int = 0, 53 | executor_type: Type[ExecutorType] = ThreadPoolExecutor, 54 | ) -> None: 55 | super(PromptedPrecomputedFeatures, self).__init__( 56 | num_workers, executor_type 57 | ) 58 | 59 | self.utt2neighbors = defaultdict(lambda: []) 60 | 61 | if dataset.lower() == "libritts": 62 | # 909_131041_000013_000002 63 | # 909_131041_000013_000003 64 | speaker2utts = defaultdict(lambda: []) 65 | 66 | utt2cut = {} 67 | for cut in cuts: 68 | speaker = cut.supervisions[0].speaker 69 | speaker2utts[speaker].append(cut.id) 70 | utt2cut[cut.id] = cut 71 | 72 | for spk in speaker2utts: 73 | uttids = sorted(speaker2utts[spk]) 74 | # Using the property of sorted keys to find previous utterance 75 | # The keys has structure speaker_book_x_y e.g. 1089_134691_000004_000001 76 | if len(uttids) == 1: 77 | self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]]) 78 | continue 79 | 80 | utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1])) 81 | utt2postutt = dict(zip(uttids[:-1], uttids[1:])) 82 | 83 | for utt in utt2prevutt: 84 | self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]]) 85 | 86 | for utt in utt2postutt: 87 | self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]]) 88 | elif dataset.lower() == "ljspeech": 89 | utt2cut = {} 90 | uttids = [] 91 | for cut in cuts: 92 | uttids.append(cut.id) 93 | utt2cut[cut.id] = cut 94 | 95 | if len(uttids) == 1: 96 | self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]]) 97 | else: 98 | # Using the property of sorted keys to find previous utterance 99 | # The keys has structure: LJ001-0010 100 | utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1])) 101 | utt2postutt = dict(zip(uttids[:-1], uttids[1:])) 102 | 103 | for utt in utt2postutt: 104 | postutt = utt2postutt[utt] 105 | if utt[:5] == postutt[:5]: 106 | self.utt2neighbors[utt].append(utt2cut[postutt]) 107 | 108 | for utt in utt2prevutt: 109 | prevutt = utt2prevutt[utt] 110 | if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]: 111 | self.utt2neighbors[utt].append(utt2cut[prevutt]) 112 | else: 113 | raise ValueError 114 | 115 | def __call__( 116 | self, cuts: CutSet 117 | ) -> Tuple[PromptedFeatures, PromptedFeatures]: 118 | """ 119 | Reads the pre-computed features from disk/other storage. 120 | The returned shape is``(B, T, F) => (batch_size, num_frames, num_features)``. 121 | 122 | :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding. 123 | """ 124 | features, features_lens = collate_features( 125 | cuts, 126 | executor=_get_executor( 127 | self.num_workers, executor_type=self._executor_type 128 | ), 129 | ) 130 | 131 | prompts_cuts = [] 132 | for k, cut in enumerate(cuts): 133 | prompts_cut = random.choice(self.utt2neighbors[cut.id]) 134 | prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}")) 135 | 136 | mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0]) 137 | # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate( 138 | # max_duration=mini_duration, 139 | # offset_type="random", 140 | # preserve_id=True, 141 | # ) 142 | prompts_cuts = CutSet( 143 | cuts={k: cut for k, cut in enumerate(prompts_cuts)} 144 | ).truncate( 145 | max_duration=mini_duration, 146 | offset_type="random", 147 | preserve_id=False, 148 | ) 149 | 150 | prompts, prompts_lens = collate_features( 151 | prompts_cuts, 152 | executor=_get_executor( 153 | self.num_workers, executor_type=self._executor_type 154 | ), 155 | ) 156 | 157 | return PromptedFeatures(prompts, features), PromptedFeatures( 158 | prompts_lens, features_lens 159 | ) 160 | -------------------------------------------------------------------------------- /data/speechtokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Fri. Sept. 8 00:40:25 2023 3 | @author: Dong Zhang 4 | """ 5 | 6 | import re 7 | from dataclasses import asdict, dataclass 8 | from typing import Any, Dict, List, Optional, Pattern, Union 9 | 10 | import numpy as np 11 | import torch 12 | import torchaudio 13 | from encodec.utils import convert_audio 14 | from lhotse.features import FeatureExtractor 15 | from lhotse.utils import Seconds, compute_num_frames 16 | import json 17 | import os 18 | from speechtokenizer import SpeechTokenizer 19 | 20 | 21 | 22 | class AttrDict(dict): 23 | def __init__(self, *args, **kwargs): 24 | super(AttrDict, self).__init__(*args, **kwargs) 25 | self.__dict__ = self 26 | 27 | 28 | class Speechtokenizer: 29 | """SpeechTokenizer""" 30 | 31 | def __init__( 32 | self, 33 | ckpt_dir: str = "", 34 | device: Any = None, 35 | ) -> None: 36 | 37 | config_path = os.path.join(ckpt_dir, "config.json") 38 | ckpt_path = os.path.join(ckpt_dir, "SpeechTokenizer.pt") 39 | # load model 40 | model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path) 41 | model.eval() 42 | 43 | if not device: 44 | device = torch.device("cpu") 45 | if torch.cuda.is_available(): 46 | device = torch.device("cuda:0") 47 | 48 | self._device = device 49 | self.model = model.to(device) 50 | self.sample_rate = model.sample_rate 51 | self.channels = 1 52 | 53 | @property 54 | def device(self): 55 | return self._device 56 | 57 | def encode(self, wav: torch.Tensor) -> torch.Tensor: 58 | return self.model.encode(wav.to(self.device)) 59 | 60 | def decode(self, frames: torch.Tensor) -> torch.Tensor: 61 | return self.model.decode(frames) 62 | 63 | 64 | def sttokenize_audio(tokenizer: Speechtokenizer, audio_path: str): 65 | # Load and pre-process the audio waveform 66 | wav, sr = torchaudio.load(audio_path) 67 | wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) 68 | wav = wav.unsqueeze(0) 69 | 70 | # Extract discrete codes from EnCodec 71 | with torch.no_grad(): 72 | encoded_frames = tokenizer.encode(wav) 73 | return encoded_frames 74 | 75 | 76 | @dataclass 77 | class STAudioTokenConfig: 78 | frame_shift: Seconds = 320.0 / 16000 79 | num_quantizers: int = 8 80 | 81 | def to_dict(self) -> Dict[str, Any]: 82 | return asdict(self) 83 | 84 | @staticmethod 85 | def from_dict(data: Dict[str, Any]) -> "STAudioTokenConfig": 86 | return STAudioTokenConfig(**data) 87 | 88 | 89 | class STAudioTokenExtractor(FeatureExtractor): 90 | name = "speechtokenizer" 91 | config_type = STAudioTokenConfig 92 | 93 | def __init__(self, config: Optional[Any] = None): 94 | super(STAudioTokenExtractor, self).__init__(config) 95 | self.tokenizer = Speechtokenizer() 96 | 97 | def extract( 98 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int 99 | ) -> np.ndarray: 100 | if not isinstance(samples, torch.Tensor): 101 | samples = torch.from_numpy(samples) 102 | if sampling_rate != self.tokenizer.sample_rate: 103 | samples = convert_audio( 104 | samples, 105 | sampling_rate, 106 | self.tokenizer.sample_rate, 107 | self.tokenizer.channels, 108 | ) 109 | if len(samples.shape) == 2: 110 | samples = samples.unsqueeze(0) 111 | else: 112 | raise ValueError() 113 | 114 | device = self.tokenizer.device 115 | encoded_frames = self.tokenizer.encode(samples.detach().to(device)) 116 | codes = encoded_frames.permute(1,0,2) # [B, n_q, T] 117 | if True: 118 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12) 119 | expected_num_frames = compute_num_frames( 120 | duration=duration, 121 | frame_shift=self.frame_shift, 122 | sampling_rate=sampling_rate, 123 | ) 124 | assert abs(codes.shape[-1] - expected_num_frames) <= 1 125 | codes = codes[..., :expected_num_frames] 126 | return codes.cpu().squeeze(0).permute(1, 0).numpy() 127 | 128 | @property 129 | def frame_shift(self) -> Seconds: 130 | return self.config.frame_shift 131 | 132 | def feature_dim(self, sampling_rate: int) -> int: 133 | return self.config.num_quantizers 134 | 135 | def pad_tensor_list(self, tensor_list, device, padding_value=0): 136 | # 计算每个张量的长度 137 | lengths = [tensor.shape[0] for tensor in tensor_list] 138 | # 使用pad_sequence函数进行填充 139 | tensor_list = [torch.Tensor(t).to(device) for t in tensor_list] 140 | padded_tensor = torch.nn.utils.rnn.pad_sequence( 141 | tensor_list, batch_first=True, padding_value=padding_value 142 | ) 143 | return padded_tensor, lengths 144 | 145 | def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray: 146 | samples = [wav.squeeze() for wav in samples] 147 | device = self.tokenizer.device 148 | samples, lengths = self.pad_tensor_list(samples, device) 149 | samples = samples.unsqueeze(1) 150 | 151 | if not isinstance(samples, torch.Tensor): 152 | samples = torch.from_numpy(samples) 153 | if len(samples.shape) != 3: 154 | raise ValueError() 155 | if sampling_rate != self.tokenizer.sample_rate: 156 | samples = [ 157 | convert_audio( 158 | wav, 159 | sampling_rate, 160 | self.tokenizer.sample_rate, 161 | self.tokenizer.channels, 162 | ) 163 | for wav in samples 164 | ] 165 | # Extract discrete codes from EnCodec 166 | with torch.no_grad(): 167 | encoded_frames = self.tokenizer.encode(samples.detach().to(device)) 168 | encoded_frames = encoded_frames.permute(1,0,2) # [B, n_q, T] 169 | batch_codes = [] 170 | for b, length in enumerate(lengths): 171 | codes = encoded_frames[b] 172 | duration = round(length / sampling_rate, ndigits=12) 173 | expected_num_frames = compute_num_frames( 174 | duration=duration, 175 | frame_shift=self.frame_shift, 176 | sampling_rate=sampling_rate, 177 | ) 178 | batch_codes.append(codes[..., :expected_num_frames]) 179 | return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes] 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | if __name__ == "__main__": 191 | 192 | audio_path = r'valle/examples/libritts/prompts/8455_210777_000067_000000.wav' 193 | speechtokenizer = Speechtokenizer() 194 | 195 | wav, sr = torchaudio.load(audio_path) 196 | wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) 197 | wav = wav.unsqueeze(0) 198 | 199 | # Extract discrete codes from EnCodec 200 | with torch.no_grad(): 201 | tokens = speechtokenizer.encode(wav) 202 | 203 | reconstructed = speechtokenizer.decode(tokens) 204 | 205 | 206 | -------------------------------------------------------------------------------- /data/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import re 17 | from dataclasses import asdict, dataclass 18 | from typing import Any, Dict, List, Optional, Pattern, Union 19 | 20 | import numpy as np 21 | import torch 22 | import torchaudio 23 | from encodec import EncodecModel 24 | from encodec.utils import convert_audio 25 | from lhotse.features import FeatureExtractor 26 | from lhotse.utils import Seconds, compute_num_frames 27 | from phonemizer.backend import EspeakBackend 28 | from phonemizer.backend.espeak.language_switch import LanguageSwitch 29 | from phonemizer.backend.espeak.words_mismatch import WordMismatch 30 | from phonemizer.punctuation import Punctuation 31 | from phonemizer.separator import Separator 32 | 33 | try: 34 | from pypinyin import Style, pinyin 35 | from pypinyin.style._utils import get_finals, get_initials 36 | except Exception: 37 | pass 38 | 39 | 40 | class PypinyinBackend: 41 | """PypinyinBackend for Chinese. Most codes is referenced from espnet. 42 | There are two types pinyin or initials_finals, one is 43 | just like "ni1 hao3", the other is like "n i1 h ao3". 44 | """ 45 | 46 | def __init__( 47 | self, 48 | backend="initials_finals", 49 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), 50 | ) -> None: 51 | self.backend = backend 52 | self.punctuation_marks = punctuation_marks 53 | 54 | def phonemize( 55 | self, text: List[str], separator: Separator, strip=True, njobs=1 56 | ) -> List[str]: 57 | assert isinstance(text, List) 58 | phonemized = [] 59 | for _text in text: 60 | _text = re.sub(" +", " ", _text.strip()) 61 | _text = _text.replace(" ", separator.word) 62 | phones = [] 63 | if self.backend == "pypinyin": 64 | for n, py in enumerate( 65 | pinyin( 66 | _text, style=Style.TONE3, neutral_tone_with_five=True 67 | ) 68 | ): 69 | if all([c in self.punctuation_marks for c in py[0]]): 70 | if len(phones): 71 | assert phones[-1] == separator.syllable 72 | phones.pop(-1) 73 | 74 | phones.extend(list(py[0])) 75 | else: 76 | phones.extend([py[0], separator.syllable]) 77 | elif self.backend == "pypinyin_initials_finals": 78 | for n, py in enumerate( 79 | pinyin( 80 | _text, style=Style.TONE3, neutral_tone_with_five=True 81 | ) 82 | ): 83 | if all([c in self.punctuation_marks for c in py[0]]): 84 | if len(phones): 85 | assert phones[-1] == separator.syllable 86 | phones.pop(-1) 87 | phones.extend(list(py[0])) 88 | else: 89 | if py[0][-1].isalnum(): 90 | initial = get_initials(py[0], strict=False) 91 | if py[0][-1].isdigit(): 92 | final = ( 93 | get_finals(py[0][:-1], strict=False) 94 | + py[0][-1] 95 | ) 96 | else: 97 | final = get_finals(py[0], strict=False) 98 | phones.extend( 99 | [ 100 | initial, 101 | separator.phone, 102 | final, 103 | separator.syllable, 104 | ] 105 | ) 106 | else: 107 | assert ValueError 108 | else: 109 | raise NotImplementedError 110 | phonemized.append( 111 | "".join(phones).rstrip(f"{separator.word}{separator.syllable}") 112 | ) 113 | return phonemized 114 | 115 | 116 | class TextTokenizer: 117 | """Phonemize Text.""" 118 | 119 | def __init__( 120 | self, 121 | language="en-us", 122 | backend="espeak", 123 | separator=Separator(word="_", syllable="-", phone="|"), 124 | preserve_punctuation=True, 125 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), 126 | with_stress: bool = False, 127 | tie: Union[bool, str] = False, 128 | language_switch: LanguageSwitch = "keep-flags", 129 | words_mismatch: WordMismatch = "ignore", 130 | ) -> None: 131 | if backend == "espeak": 132 | phonemizer = EspeakBackend( 133 | language, 134 | punctuation_marks=punctuation_marks, 135 | preserve_punctuation=preserve_punctuation, 136 | with_stress=with_stress, 137 | tie=tie, 138 | language_switch=language_switch, 139 | words_mismatch=words_mismatch, 140 | ) 141 | elif backend in ["pypinyin", "pypinyin_initials_finals"]: 142 | phonemizer = PypinyinBackend( 143 | backend=backend, 144 | punctuation_marks=punctuation_marks + separator.word, 145 | ) 146 | else: 147 | raise NotImplementedError(f"{backend}") 148 | 149 | self.backend = phonemizer 150 | self.separator = separator 151 | 152 | def to_list(self, phonemized: str) -> List[str]: 153 | fields = [] 154 | for word in phonemized.split(self.separator.word): 155 | # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. 156 | pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) 157 | fields.extend( 158 | [p for p in pp if p != self.separator.phone] 159 | + [self.separator.word] 160 | ) 161 | assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( 162 | self.separator.phone 163 | ) 164 | return fields[:-1] 165 | 166 | def __call__(self, text, strip=True) -> List[List[str]]: 167 | if isinstance(text, str): 168 | text = [text] 169 | 170 | phonemized = self.backend.phonemize( 171 | text, separator=self.separator, strip=strip, njobs=1 172 | ) 173 | return [self.to_list(p) for p in phonemized] 174 | 175 | 176 | def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]: 177 | phonemes = tokenizer([text.strip()]) 178 | return phonemes[0] # k2symbols 179 | 180 | 181 | def remove_encodec_weight_norm(model): 182 | from encodec.modules import SConv1d 183 | from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock 184 | from torch.nn.utils import remove_weight_norm 185 | 186 | encoder = model.encoder.model 187 | for key in encoder._modules: 188 | if isinstance(encoder._modules[key], SEANetResnetBlock): 189 | remove_weight_norm(encoder._modules[key].shortcut.conv.conv) 190 | block_modules = encoder._modules[key].block._modules 191 | for skey in block_modules: 192 | if isinstance(block_modules[skey], SConv1d): 193 | remove_weight_norm(block_modules[skey].conv.conv) 194 | elif isinstance(encoder._modules[key], SConv1d): 195 | remove_weight_norm(encoder._modules[key].conv.conv) 196 | 197 | decoder = model.decoder.model 198 | for key in decoder._modules: 199 | if isinstance(decoder._modules[key], SEANetResnetBlock): 200 | remove_weight_norm(decoder._modules[key].shortcut.conv.conv) 201 | block_modules = decoder._modules[key].block._modules 202 | for skey in block_modules: 203 | if isinstance(block_modules[skey], SConv1d): 204 | remove_weight_norm(block_modules[skey].conv.conv) 205 | elif isinstance(decoder._modules[key], SConvTranspose1d): 206 | remove_weight_norm(decoder._modules[key].convtr.convtr) 207 | elif isinstance(decoder._modules[key], SConv1d): 208 | remove_weight_norm(decoder._modules[key].conv.conv) 209 | 210 | 211 | class AudioTokenizer: 212 | """EnCodec audio.""" 213 | 214 | def __init__( 215 | self, 216 | device: Any = None, 217 | ) -> None: 218 | # Instantiate a pretrained EnCodec model 219 | model = EncodecModel.encodec_model_24khz() 220 | model.set_target_bandwidth(6.0) 221 | remove_encodec_weight_norm(model) 222 | 223 | if not device: 224 | device = torch.device("cpu") 225 | if torch.cuda.is_available(): 226 | device = torch.device("cuda:0") 227 | 228 | self._device = device 229 | 230 | self.codec = model.to(device) 231 | self.sample_rate = model.sample_rate 232 | self.channels = model.channels 233 | 234 | @property 235 | def device(self): 236 | return self._device 237 | 238 | def encode(self, wav: torch.Tensor) -> torch.Tensor: 239 | return self.codec.encode(wav.to(self.device)) 240 | 241 | def decode(self, frames: torch.Tensor) -> torch.Tensor: 242 | return self.codec.decode(frames) 243 | 244 | 245 | def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str): 246 | # Load and pre-process the audio waveform 247 | wav, sr = torchaudio.load(audio_path) 248 | wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) 249 | wav = wav.unsqueeze(0) 250 | 251 | # Extract discrete codes from EnCodec 252 | with torch.no_grad(): 253 | encoded_frames = tokenizer.encode(wav) 254 | return encoded_frames 255 | 256 | 257 | @dataclass 258 | class AudioTokenConfig: 259 | frame_shift: Seconds = 320.0 / 24000 260 | num_quantizers: int = 8 261 | 262 | def to_dict(self) -> Dict[str, Any]: 263 | return asdict(self) 264 | 265 | @staticmethod 266 | def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig": 267 | return AudioTokenConfig(**data) 268 | 269 | 270 | class AudioTokenExtractor(FeatureExtractor): 271 | name = "encodec" 272 | config_type = AudioTokenConfig 273 | 274 | def __init__(self, config: Optional[Any] = None): 275 | super(AudioTokenExtractor, self).__init__(config) 276 | self.tokenizer = AudioTokenizer() 277 | 278 | def extract( 279 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int 280 | ) -> np.ndarray: 281 | if not isinstance(samples, torch.Tensor): 282 | samples = torch.from_numpy(samples) 283 | if sampling_rate != self.tokenizer.sample_rate: 284 | samples = convert_audio( 285 | samples, 286 | sampling_rate, 287 | self.tokenizer.sample_rate, 288 | self.tokenizer.channels, 289 | ) 290 | if len(samples.shape) == 2: 291 | samples = samples.unsqueeze(0) 292 | else: 293 | raise ValueError() 294 | 295 | device = self.tokenizer.device 296 | encoded_frames = self.tokenizer.encode(samples.detach().to(device)) 297 | codes = encoded_frames[0][0] # [B, n_q, T] 298 | if True: 299 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12) 300 | expected_num_frames = compute_num_frames( 301 | duration=duration, 302 | frame_shift=self.frame_shift, 303 | sampling_rate=sampling_rate, 304 | ) 305 | assert abs(codes.shape[-1] - expected_num_frames) <= 1 306 | codes = codes[..., :expected_num_frames] 307 | return codes.cpu().squeeze(0).permute(1, 0).numpy() 308 | 309 | @property 310 | def frame_shift(self) -> Seconds: 311 | return self.config.frame_shift 312 | 313 | def feature_dim(self, sampling_rate: int) -> int: 314 | return self.config.num_quantizers 315 | 316 | def pad_tensor_list(self, tensor_list, device, padding_value=0): 317 | # 计算每个张量的长度 318 | lengths = [tensor.shape[0] for tensor in tensor_list] 319 | # 使用pad_sequence函数进行填充 320 | tensor_list = [torch.Tensor(t).to(device) for t in tensor_list] 321 | padded_tensor = torch.nn.utils.rnn.pad_sequence( 322 | tensor_list, batch_first=True, padding_value=padding_value 323 | ) 324 | return padded_tensor, lengths 325 | 326 | def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray: 327 | samples = [wav.squeeze() for wav in samples] 328 | device = self.tokenizer.device 329 | samples, lengths = self.pad_tensor_list(samples, device) 330 | samples = samples.unsqueeze(1) 331 | 332 | if not isinstance(samples, torch.Tensor): 333 | samples = torch.from_numpy(samples) 334 | if len(samples.shape) != 3: 335 | raise ValueError() 336 | 337 | if sampling_rate != self.tokenizer.sample_rate: 338 | samples = [ 339 | convert_audio( 340 | wav.cpu(), 341 | sampling_rate, 342 | self.tokenizer.sample_rate, 343 | self.tokenizer.channels, 344 | ).cuda() 345 | for wav in samples 346 | ] 347 | samples = torch.stack(samples, 0) 348 | 349 | # Extract discrete codes from EnCodec 350 | with torch.no_grad(): 351 | encoded_frames = self.tokenizer.encode(samples.detach().to(device)) 352 | encoded_frames = encoded_frames[0][0] # [B, n_q, T] 353 | batch_codes = [] 354 | for b, length in enumerate(lengths): 355 | codes = encoded_frames[b] 356 | duration = round(length / sampling_rate, ndigits=12) 357 | expected_num_frames = compute_num_frames( 358 | duration=duration, 359 | frame_shift=self.frame_shift, 360 | sampling_rate=sampling_rate, 361 | ) 362 | batch_codes.append(codes[..., :expected_num_frames]) 363 | return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes] 364 | 365 | 366 | if __name__ == "__main__": 367 | model = EncodecModel.encodec_model_24khz() 368 | model.set_target_bandwidth(6.0) 369 | 370 | samples = torch.from_numpy(np.random.random([1, 1, 16000])).type( 371 | torch.float32 372 | ) 373 | codes_raw = model.encode(samples) 374 | 375 | remove_encodec_weight_norm(model) 376 | codes_norm = model.encode(samples) 377 | 378 | assert torch.allclose(codes_raw[0][0], codes_norm[0][0]) 379 | print(codes_raw[0][0].shape) 380 | -------------------------------------------------------------------------------- /images/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/images/overview.png -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | 2 | audio_extractor="SpeechTokenizer" 3 | 4 | st_dir="ckpt/speechtokenizer/" 5 | uslm_dir="ckpt/uslm/" 6 | out_dir="output/" 7 | 8 | mkdir -p ${st_dir} 9 | mkdir -p ${uslm_dir} 10 | mkdir -p ${out_dir} 11 | 12 | if [ ! -e "${st_dir}/config.json" ];then 13 | cd ${st_dir} 14 | wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/SpeechTokenizer.pt" 15 | wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/config.json" 16 | cd - 17 | fi 18 | 19 | if [ ! -e "${uslm_dir}/USLM.pt" ];then 20 | cd ${uslm_dir} 21 | wget "https://huggingface.co/fnlp/USLM/resolve/main/USLM_libritts/USLM.pt" 22 | wget "https://huggingface.co/fnlp/USLM/resolve/main/USLM_libritts/unique_text_tokens.k2symbols" 23 | cd - 24 | fi 25 | 26 | 27 | python3 bin/infer.py --output-dir ${out_dir}/ \ 28 | --model-name uslm --norm-first true --add-prenet false \ 29 | --share-embedding true --norm-first true --add-prenet false \ 30 | --audio-extractor "${audio_extractor}" \ 31 | --speechtokenizer-dir "${st_dir}" \ 32 | --checkpoint=${uslm_dir}/best-valid-loss.pt \ 33 | --text-tokens "${uslm_dir}/unique_text_tokens.k2symbols" \ 34 | --text-prompts "mr Soames was a tall, spare man, of a nervous and excitable temperament." \ 35 | --audio-prompts prompts/1580_141083_000002_000002.wav \ 36 | --text "Begin with the fundamental steps of the process. This will give you a solid foundation to build upon and boost your confidence. " \ 37 | 38 | -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/models/.DS_Store -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import faulthandler 3 | faulthandler.enable() 4 | 5 | import argparse 6 | 7 | import torch.nn as nn 8 | from icefall.utils import AttributeDict, str2bool 9 | 10 | from .macros import ( 11 | NUM_AUDIO_TOKENS, 12 | NUM_MEL_BINS, 13 | NUM_SPEAKER_CLASSES, 14 | NUM_TEXT_TOKENS, 15 | SPEAKER_EMBEDDING_DIM, 16 | ) 17 | from .transformer import Transformer 18 | from uslm.models.valle import VALLE, VALLF 19 | from uslm.models.uslm import USLM 20 | from .visualizer import visualize 21 | 22 | 23 | def add_model_arguments(parser: argparse.ArgumentParser): 24 | parser.add_argument( 25 | "--model-name", 26 | type=str, 27 | default="USLM", 28 | help="USLM, VALL-E, VALL-F, Transformer.", 29 | ) 30 | parser.add_argument( 31 | "--decoder-dim", 32 | type=int, 33 | default=1024, 34 | help="Embedding dimension in the decoder model.", 35 | ) 36 | parser.add_argument( 37 | "--nhead", 38 | type=int, 39 | default=16, 40 | help="Number of attention heads in the Decoder layers.", 41 | ) 42 | parser.add_argument( 43 | "--num-decoder-layers", 44 | type=int, 45 | default=12, 46 | help="Number of Decoder layers.", 47 | ) 48 | parser.add_argument( 49 | "--scale-factor", 50 | type=float, 51 | default=1.0, 52 | help="Model scale factor which will be assigned different meanings in different models.", 53 | ) 54 | parser.add_argument( 55 | "--norm-first", 56 | type=str2bool, 57 | default=True, 58 | help="Pre or Post Normalization.", 59 | ) 60 | parser.add_argument( 61 | "--add-prenet", 62 | type=str2bool, 63 | default=False, 64 | help="Whether add PreNet after Inputs.", 65 | ) 66 | 67 | # VALL-E & F 68 | parser.add_argument( 69 | "--prefix-mode", 70 | type=int, 71 | default=0, 72 | help="The mode for how to prefix VALL-E NAR Decoder, " 73 | "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.", 74 | ) 75 | parser.add_argument( 76 | "--share-embedding", 77 | type=str2bool, 78 | default=True, 79 | help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.", 80 | ) 81 | parser.add_argument( 82 | "--prepend-bos", 83 | type=str2bool, 84 | default=False, 85 | help="Whether prepend to the acoustic tokens -> AR Decoder inputs.", 86 | ) 87 | parser.add_argument( 88 | "--num-quantizers", 89 | type=int, 90 | default=8, 91 | help="Number of Audio/Semantic quantization layers.", 92 | ) 93 | 94 | # Transformer 95 | parser.add_argument( 96 | "--scaling-xformers", 97 | type=str2bool, 98 | default=False, 99 | help="Apply Reworked Conformer scaling on Transformers.", 100 | ) 101 | 102 | 103 | def get_model(params: AttributeDict) -> nn.Module: 104 | if params.model_name.lower() in ["USLM", "uslm"]: 105 | model = USLM( 106 | params.decoder_dim, 107 | params.nhead, 108 | params.num_decoder_layers, 109 | norm_first=params.norm_first, 110 | add_prenet=params.add_prenet, 111 | prefix_mode=params.prefix_mode, 112 | share_embedding=params.share_embedding, 113 | nar_scale_factor=params.scale_factor, 114 | prepend_bos=params.prepend_bos, 115 | num_quantizers=params.num_quantizers, 116 | ) 117 | elif params.model_name.lower() in ["vall-f", "vallf"]: 118 | model = VALLF( 119 | params.decoder_dim, 120 | params.nhead, 121 | params.num_decoder_layers, 122 | norm_first=params.norm_first, 123 | add_prenet=params.add_prenet, 124 | prefix_mode=params.prefix_mode, 125 | share_embedding=params.share_embedding, 126 | nar_scale_factor=params.scale_factor, 127 | prepend_bos=params.prepend_bos, 128 | num_quantizers=params.num_quantizers, 129 | ) 130 | elif params.model_name.lower() in ["vall-e", "valle"]: 131 | model = VALLE( 132 | params.decoder_dim, 133 | params.nhead, 134 | params.num_decoder_layers, 135 | norm_first=params.norm_first, 136 | add_prenet=params.add_prenet, 137 | prefix_mode=params.prefix_mode, 138 | share_embedding=params.share_embedding, 139 | nar_scale_factor=params.scale_factor, 140 | prepend_bos=params.prepend_bos, 141 | num_quantizers=params.num_quantizers, 142 | ) 143 | else: 144 | assert params.model_name in ["Transformer"] 145 | model = Transformer( 146 | params.decoder_dim, 147 | params.nhead, 148 | params.num_decoder_layers, 149 | norm_first=params.norm_first, 150 | add_prenet=params.add_prenet, 151 | scaling_xformers=params.scaling_xformers, 152 | ) 153 | 154 | return model 155 | -------------------------------------------------------------------------------- /models/macros.py: -------------------------------------------------------------------------------- 1 | # Text 2 | NUM_TEXT_TOKENS = 512 3 | 4 | # Audio 5 | NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins 6 | NUM_MEL_BINS = 100 # BigVGAN bigvgan_24khz_100band 7 | 8 | 9 | # Speaker 10 | NUM_SPEAKER_CLASSES = 4096 11 | SPEAKER_EMBEDDING_DIM = 64 12 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from functools import partial 16 | from typing import Any, Dict, List, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from icefall.utils import make_pad_mask 22 | from torchmetrics.classification import BinaryAccuracy 23 | 24 | from uslm.models.valle import Transpose 25 | from uslm.modules.embedding import SinePositionalEmbedding, TokenEmbedding 26 | from uslm.modules.scaling import BalancedDoubleSwish, ScaledLinear 27 | from uslm.modules.transformer import ( 28 | BalancedBasicNorm, 29 | IdentityNorm, 30 | TransformerDecoderLayer, 31 | TransformerEncoder, 32 | TransformerEncoderLayer, 33 | ) 34 | 35 | from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS 36 | from .visualizer import visualize 37 | 38 | IdentityNorm = IdentityNorm 39 | 40 | 41 | class Transformer(nn.Module): 42 | """It implements seq2seq Transformer TTS for debug(No StopPredictor and SpeakerEmbeding) 43 | Neural Speech Synthesis with Transformer Network 44 | https://arxiv.org/abs/1809.08895 45 | """ 46 | 47 | def __init__( 48 | self, 49 | d_model: int, 50 | nhead: int, 51 | num_layers: int, 52 | norm_first: bool = True, 53 | add_prenet: bool = False, 54 | scaling_xformers: bool = False, 55 | ): 56 | """ 57 | Args: 58 | d_model: 59 | The number of expected features in the input (required). 60 | nhead: 61 | The number of heads in the multiheadattention models (required). 62 | num_layers: 63 | The number of sub-decoder-layers in the decoder (required). 64 | """ 65 | super().__init__() 66 | self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x 67 | 68 | if add_prenet: 69 | self.encoder_prenet = nn.Sequential( 70 | Transpose(), 71 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), 72 | nn.BatchNorm1d(d_model), 73 | nn.ReLU(), 74 | nn.Dropout(0.5), 75 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), 76 | nn.BatchNorm1d(d_model), 77 | nn.ReLU(), 78 | nn.Dropout(0.5), 79 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), 80 | nn.BatchNorm1d(d_model), 81 | nn.ReLU(), 82 | nn.Dropout(0.5), 83 | Transpose(), 84 | nn.Linear(d_model, d_model), 85 | ) 86 | 87 | self.decoder_prenet = nn.Sequential( 88 | nn.Linear(NUM_MEL_BINS, 256), 89 | nn.ReLU(), 90 | nn.Dropout(0.5), 91 | nn.Linear(256, 256), 92 | nn.ReLU(), 93 | nn.Dropout(0.5), 94 | nn.Linear(256, d_model), 95 | ) 96 | 97 | assert scaling_xformers is False # TODO: update this block 98 | else: 99 | self.encoder_prenet = nn.Identity() 100 | if scaling_xformers: 101 | self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model) 102 | else: 103 | self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model) 104 | 105 | self.encoder_position = SinePositionalEmbedding( 106 | d_model, 107 | dropout=0.1, 108 | scale=False, 109 | ) 110 | self.decoder_position = SinePositionalEmbedding( 111 | d_model, dropout=0.1, scale=False 112 | ) 113 | 114 | if scaling_xformers: 115 | self.encoder = TransformerEncoder( 116 | TransformerEncoderLayer( 117 | d_model, 118 | nhead, 119 | dim_feedforward=d_model * 4, 120 | dropout=0.1, 121 | batch_first=True, 122 | norm_first=norm_first, 123 | linear1_self_attention_cls=ScaledLinear, 124 | linear2_self_attention_cls=partial( 125 | ScaledLinear, initial_scale=0.01 126 | ), 127 | linear1_feedforward_cls=ScaledLinear, 128 | linear2_feedforward_cls=partial( 129 | ScaledLinear, initial_scale=0.01 130 | ), 131 | activation=partial( 132 | BalancedDoubleSwish, 133 | channel_dim=-1, 134 | max_abs=10.0, 135 | min_prob=0.25, 136 | ), 137 | layer_norm_cls=IdentityNorm, 138 | ), 139 | num_layers=num_layers, 140 | norm=BalancedBasicNorm(d_model) if norm_first else None, 141 | ) 142 | 143 | self.decoder = nn.TransformerDecoder( 144 | TransformerDecoderLayer( 145 | d_model, 146 | nhead, 147 | dim_feedforward=d_model * 4, 148 | dropout=0.1, 149 | batch_first=True, 150 | norm_first=norm_first, 151 | linear1_self_attention_cls=ScaledLinear, 152 | linear2_self_attention_cls=partial( 153 | ScaledLinear, initial_scale=0.01 154 | ), 155 | linear1_feedforward_cls=ScaledLinear, 156 | linear2_feedforward_cls=partial( 157 | ScaledLinear, initial_scale=0.01 158 | ), 159 | activation=partial( 160 | BalancedDoubleSwish, 161 | channel_dim=-1, 162 | max_abs=10.0, 163 | min_prob=0.25, 164 | ), 165 | layer_norm_cls=IdentityNorm, 166 | ), 167 | num_layers=num_layers, 168 | norm=BalancedBasicNorm(d_model) if norm_first else None, 169 | ) 170 | 171 | self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS) 172 | self.stop_layer = nn.Linear(d_model, 1) 173 | else: 174 | self.encoder = nn.TransformerEncoder( 175 | nn.TransformerEncoderLayer( 176 | d_model, 177 | nhead, 178 | dim_feedforward=d_model * 4, 179 | activation=F.relu, 180 | dropout=0.1, 181 | batch_first=True, 182 | norm_first=norm_first, 183 | ), 184 | num_layers=num_layers, 185 | norm=nn.LayerNorm(d_model) if norm_first else None, 186 | ) 187 | 188 | self.decoder = nn.TransformerDecoder( 189 | nn.TransformerDecoderLayer( 190 | d_model, 191 | nhead, 192 | dim_feedforward=d_model * 4, 193 | activation=F.relu, 194 | dropout=0.1, 195 | batch_first=True, 196 | norm_first=norm_first, 197 | ), 198 | num_layers=num_layers, 199 | norm=nn.LayerNorm(d_model) if norm_first else None, 200 | ) 201 | 202 | self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS) 203 | self.stop_layer = nn.Linear(d_model, 1) 204 | 205 | self.stop_accuracy_metric = BinaryAccuracy( 206 | threshold=0.5, multidim_average="global" 207 | ) 208 | 209 | # self.apply(self._init_weights) 210 | 211 | # def _init_weights(self, module): 212 | # if isinstance(module, (nn.Linear)): 213 | # module.weight.data.normal_(mean=0.0, std=0.02) 214 | # if isinstance(module, nn.Linear) and module.bias is not None: 215 | # module.bias.data.zero_() 216 | # elif isinstance(module, nn.LayerNorm): 217 | # module.bias.data.zero_() 218 | # module.weight.data.fill_(1.0) 219 | # elif isinstance(module, nn.Embedding): 220 | # module.weight.data.normal_(mean=0.0, std=0.02) 221 | 222 | def forward( 223 | self, 224 | x: torch.Tensor, 225 | x_lens: torch.Tensor, 226 | y: torch.Tensor, 227 | y_lens: torch.Tensor, 228 | reduction: str = "sum", 229 | train_stage: int = 0, 230 | **kwargs, 231 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: 232 | """ 233 | Args: 234 | x: 235 | A 2-D tensor of shape (N, S). 236 | x_lens: 237 | A 1-D tensor of shape (N,). It contains the number of tokens in `x` 238 | before padding. 239 | y: 240 | A 3-D tensor of shape (N, T, 8). 241 | y_lens: 242 | A 1-D tensor of shape (N,). It contains the number of tokens in `x` 243 | before padding. 244 | train_stage: 245 | Not used in this model. 246 | Returns: 247 | Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy. 248 | """ 249 | del train_stage 250 | 251 | assert x.ndim == 2, x.shape 252 | assert x_lens.ndim == 1, x_lens.shape 253 | assert y.ndim == 3, y.shape 254 | assert y_lens.ndim == 1, y_lens.shape 255 | 256 | assert torch.all(x_lens > 0) 257 | 258 | # NOTE: x has been padded in TextTokenCollater 259 | x_mask = make_pad_mask(x_lens).to(x.device) 260 | 261 | x = self.text_embedding(x) 262 | x = self.encoder_prenet(x) 263 | x = self.encoder_position(x) 264 | x = self.encoder(x, src_key_padding_mask=x_mask) 265 | 266 | total_loss, metrics = 0.0, {} 267 | 268 | y_mask = make_pad_mask(y_lens).to(y.device) 269 | y_mask_float = y_mask.type(torch.float32) 270 | data_mask = 1.0 - y_mask_float.unsqueeze(-1) 271 | 272 | # Training 273 | # AR Decoder 274 | def pad_y(y): 275 | y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach() 276 | # inputs, targets 277 | return y[:, :-1], y[:, 1:] 278 | 279 | y, targets = pad_y(y * data_mask) # mask padding as zeros 280 | 281 | y_emb = self.decoder_prenet(y) 282 | y_pos = self.decoder_position(y_emb) 283 | 284 | y_len = y_lens.max() 285 | tgt_mask = torch.triu( 286 | torch.ones(y_len, y_len, device=y.device, dtype=torch.bool), 287 | diagonal=1, 288 | ) 289 | y_dec = self.decoder( 290 | y_pos, 291 | x, 292 | tgt_mask=tgt_mask, 293 | memory_key_padding_mask=x_mask, 294 | ) 295 | 296 | predict = self.predict_layer(y_dec) 297 | # loss 298 | total_loss = F.mse_loss(predict, targets, reduction=reduction) 299 | 300 | logits = self.stop_layer(y_dec).squeeze(-1) 301 | stop_loss = F.binary_cross_entropy_with_logits( 302 | logits, 303 | y_mask_float.detach(), 304 | weight=1.0 + y_mask_float.detach() * 4.0, 305 | reduction=reduction, 306 | ) 307 | metrics["stop_loss"] = stop_loss.detach() 308 | 309 | stop_accuracy = self.stop_accuracy_metric( 310 | (torch.sigmoid(logits) >= 0.5).type(torch.int64), 311 | y_mask.type(torch.int64), 312 | ) 313 | # icefall MetricsTracker.norm_items() 314 | metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type( 315 | torch.float32 316 | ) 317 | 318 | return ((x, predict), total_loss + 100.0 * stop_loss, metrics) 319 | 320 | def inference( 321 | self, 322 | x: torch.Tensor, 323 | x_lens: torch.Tensor, 324 | y: Any = None, 325 | **kwargs, 326 | ) -> torch.Tensor: 327 | """ 328 | Args: 329 | x: 330 | A 2-D tensor of shape (1, S). 331 | x_lens: 332 | A 1-D tensor of shape (1,). It contains the number of tokens in `x` 333 | before padding. 334 | Returns: 335 | Return the predicted audio code matrix and cross-entropy loss. 336 | """ 337 | assert x.ndim == 2, x.shape 338 | assert x_lens.ndim == 1, x_lens.shape 339 | 340 | assert torch.all(x_lens > 0) 341 | 342 | x_mask = make_pad_mask(x_lens).to(x.device) 343 | 344 | x = self.text_embedding(x) 345 | x = self.encoder_prenet(x) 346 | x = self.encoder_position(x) 347 | x = self.encoder(x, src_key_padding_mask=x_mask) 348 | 349 | x_mask = make_pad_mask(x_lens).to(x.device) 350 | 351 | # AR Decoder 352 | # TODO: Managing decoder steps avoid repetitive computation 353 | y = torch.zeros( 354 | [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device 355 | ) 356 | while True: 357 | y_emb = self.decoder_prenet(y) 358 | y_pos = self.decoder_position(y_emb) 359 | 360 | tgt_mask = torch.triu( 361 | torch.ones( 362 | y.shape[1], y.shape[1], device=y.device, dtype=torch.bool 363 | ), 364 | diagonal=1, 365 | ) 366 | 367 | y_dec = self.decoder( 368 | y_pos, 369 | x, 370 | tgt_mask=tgt_mask, 371 | memory_mask=None, 372 | memory_key_padding_mask=x_mask, 373 | ) 374 | predict = self.predict_layer(y_dec[:, -1:]) 375 | 376 | logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5 377 | if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()): 378 | print( 379 | f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]" 380 | ) 381 | break 382 | 383 | y = torch.concat([y, predict], dim=1) 384 | 385 | return y[:, 1:] 386 | 387 | def visualize( 388 | self, 389 | predicts: Tuple[torch.Tensor], 390 | batch: Dict[str, Union[List, torch.Tensor]], 391 | output_dir: str, 392 | limit: int = 4, 393 | ) -> None: 394 | visualize(predicts, batch, output_dir, limit=limit) 395 | -------------------------------------------------------------------------------- /models/uslm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Fri. Sept. 8 00:45:46 2023 3 | @author: Dong Zhang 4 | """ 5 | 6 | 7 | import faulthandler 8 | faulthandler.enable() 9 | 10 | 11 | import random 12 | from typing import Dict, Iterator, List, Tuple, Union 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from torchmetrics.classification import MulticlassAccuracy 18 | from icefall.utils import make_pad_mask 19 | 20 | 21 | from uslm.data.input_strategies import PromptedFeatures 22 | from uslm.modules.embedding import SinePositionalEmbedding, TokenEmbedding 23 | from uslm.modules.transformer import ( 24 | AdaptiveLayerNorm, 25 | LayerNorm, 26 | TransformerDecoderLayer, 27 | TransformerEncoder, 28 | TransformerEncoderLayer, 29 | ) 30 | from uslm.models.valle import VALLF, top_k_top_p_filtering, topk_sampling, Transpose 31 | 32 | from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS 33 | from .visualizer import visualize 34 | 35 | 36 | class USLM(VALLF): 37 | def __init__( 38 | self, 39 | d_model: int, 40 | nhead: int, 41 | num_layers: int, 42 | norm_first: bool = True, 43 | add_prenet: bool = False, 44 | prefix_mode: int = 0, 45 | share_embedding: bool = True, 46 | nar_scale_factor: float = 1.0, 47 | **kwargs, 48 | ): 49 | """ 50 | Args: 51 | d_model: 52 | The number of expected features in the input (required). 53 | nhead: 54 | The number of heads in the multiheadattention models (required). 55 | num_layers: 56 | The number of sub-decoder-layers in the decoder (required). 57 | """ 58 | super(USLM, self).__init__( 59 | d_model, 60 | nhead, 61 | num_layers, 62 | norm_first=norm_first, 63 | add_prenet=add_prenet, 64 | decoder_cls=TransformerEncoder, 65 | decoder_layer_cls=TransformerEncoderLayer, 66 | prefix_mode=prefix_mode, 67 | share_embedding=share_embedding, 68 | nar_scale_factor=nar_scale_factor, 69 | **kwargs, 70 | ) 71 | 72 | def forward( 73 | self, 74 | x: torch.Tensor, 75 | x_lens: torch.Tensor, 76 | y: Union[torch.Tensor, PromptedFeatures], 77 | y_lens: Union[torch.Tensor, PromptedFeatures], 78 | reduction: str = "sum", 79 | train_stage: int = 0, 80 | **kwargs, 81 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: 82 | """ 83 | Args: 84 | x: 85 | A 2-D tensor of shape (N, S). 86 | x_lens: 87 | A 1-D tensor of shape (N,). It contains the number of tokens in `x` 88 | before padding. 89 | y: 90 | A 3-D tensor of shape (N, T, 8). 91 | y_lens: 92 | A 1-D tensor of shape (N,). It contains the number of tokens in `x` 93 | before padding. 94 | train_stage: 95 | 0: AR & NAR modules, 1: AR modules, 2: NAR modules 96 | Returns: 97 | Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy. 98 | """ 99 | assert x.ndim == 2, x.shape 100 | assert x_lens.ndim == 1, x_lens.shape 101 | 102 | y_prompts_codes = None 103 | if isinstance(y, PromptedFeatures): 104 | y_prompts_codes, y = y.data 105 | prompts_len, y_lens = y_lens.data 106 | assert prompts_len.min() == prompts_len.max() 107 | assert self.prefix_mode == 4 108 | y_prompts_codes = y_prompts_codes.type(torch.int64) 109 | 110 | assert y.ndim == 3, y.shape 111 | assert y_lens.ndim == 1, y_lens.shape 112 | 113 | # NOTE: x has been padded in TextTokenCollater 114 | x_mask = make_pad_mask(x_lens).to(x.device) 115 | y_mask = make_pad_mask(y_lens).to(y.device) 116 | y_mask_int = y_mask.type(torch.int64) 117 | 118 | text = x 119 | codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1)) 120 | 121 | y, targets = self.pad_y_eos( 122 | codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS 123 | ) 124 | 125 | x_len = x_lens.max() 126 | 127 | metrics = {} 128 | total_loss = 0.0 129 | 130 | xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) 131 | if self.ar_audio_prepend_bos: 132 | ar_xy_padding_mask = torch.concat( 133 | [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1 134 | ) 135 | else: 136 | ar_xy_padding_mask = xy_padding_mask 137 | # AR Decoder 138 | if train_stage in [0, 1]: 139 | x = self.ar_text_embedding(text) 140 | x = self.ar_text_prenet(x) 141 | x = self.ar_text_position(x) 142 | 143 | y_len = y_lens.max() + int(self.ar_audio_prepend_bos) 144 | 145 | x_attn_mask = F.pad( 146 | torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), 147 | (0, y_len), 148 | value=True, 149 | ) 150 | y_attn_mask = F.pad( 151 | torch.triu( 152 | torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), 153 | diagonal=1, 154 | ), 155 | (x_len, 0), 156 | value=False, 157 | ) 158 | xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) 159 | 160 | # merge key padding and attention masks 161 | bsz, src_len = x.shape[0], x_len + y_len 162 | _xy_padding_mask = ( 163 | ar_xy_padding_mask.view(bsz, 1, 1, src_len) 164 | .expand(-1, self.num_heads, -1, -1) 165 | .reshape(bsz * self.num_heads, 1, src_len) 166 | ) 167 | xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) 168 | 169 | new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) 170 | new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) 171 | xy_attn_mask = new_attn_mask 172 | 173 | y_emb = self.ar_audio_embedding(y) 174 | y_emb = self.ar_audio_prenet(y_emb) 175 | y_pos = self.ar_audio_position(y_emb) 176 | 177 | xy_pos = torch.concat([x, y_pos], dim=1) 178 | 179 | xy_dec, _ = self.ar_decoder( 180 | (xy_pos, None), 181 | mask=xy_attn_mask, 182 | # src_key_padding_mask=xy_padding_mask, 183 | # is_causal=True, 184 | ) 185 | logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1) 186 | # loss 187 | total_loss = F.cross_entropy(logits, targets, reduction=reduction) 188 | 189 | metrics["ArTop10Accuracy"] = self.ar_accuracy_metric( 190 | logits.detach(), targets 191 | ).item() * y_lens.sum().type(torch.float32) 192 | 193 | if self.num_quantizers == 1: 194 | return ((x, codes), total_loss, metrics) 195 | 196 | # Non-AR Decoders 197 | if self.ar_audio_prepend_bos: 198 | y = y[:, 1:] 199 | if train_stage in [0, 2]: 200 | num_nar_layers = self.num_quantizers - 1 201 | 202 | #layer2: contain most timbre information 203 | nar_stage = 1 204 | 205 | x = self.nar_text_embedding(text) 206 | x = self.nar_text_prenet(x) 207 | x = self.nar_text_position(x) 208 | 209 | y_emb, prefix_len = self._prepare_prompts( 210 | y, y_lens, codes, nar_stage, y_prompts_codes 211 | ) 212 | 213 | y_len = y_lens.max() 214 | targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int 215 | if self.prefix_mode in [2, 4]: 216 | xy_padding_mask = torch.concat( 217 | [ 218 | x_mask, 219 | F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False), 220 | ], 221 | dim=1, 222 | ) 223 | elif self.prefix_mode == 1: 224 | targets = targets[:, prefix_len:] 225 | 226 | y_pos = self.nar_audio_prenet(y_emb) 227 | y_pos = self.nar_audio_position(y_pos) 228 | xy_pos = torch.concat([x, y_pos], dim=1) 229 | xy_dec, _ = self.nar_decoder( 230 | (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight), 231 | src_key_padding_mask=xy_padding_mask, 232 | # is_causal=False, 233 | ) 234 | xy_dec = xy_dec[:, x_lens.max() + prefix_len :] 235 | if self.prefix_mode == 4: 236 | prefix_len = 0 # reset for Top10Accuracy metric 237 | logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute( 238 | 0, 2, 1 239 | ) 240 | 241 | # loss 242 | total_length = (y_lens).sum().type(torch.float32) 243 | total_loss += ( 244 | F.cross_entropy( 245 | logits, 246 | targets, 247 | ignore_index=NUM_AUDIO_TOKENS, 248 | reduction=reduction, 249 | ) 250 | * (total_length / (total_length - prefix_len * x.shape[0])) 251 | ) 252 | metrics["NarL2Top10Accuracy"] = ( 253 | self.nar_accuracy_metric( 254 | F.pad( 255 | logits.detach(), 256 | (0, 0, 0, 1, 0, 0), 257 | value=logits.min().cpu().item(), 258 | ), 259 | targets, 260 | ).item() 261 | * total_length 262 | ) 263 | 264 | 265 | #layer3-8 266 | nar_stage = self.rng.choices( 267 | [_k for _k in range(2, self.num_quantizers)], 268 | weights=[1.0 / (num_nar_layers-1)] * (num_nar_layers-1), 269 | k=1, 270 | )[0] 271 | 272 | x = self.nar_text_embedding(text) 273 | x = self.nar_text_prenet(x) 274 | x = self.nar_text_position(x) 275 | 276 | y_emb, prefix_len = self._prepare_prompts( 277 | y, y_lens, codes, nar_stage, y_prompts_codes 278 | ) 279 | 280 | y_len = y_lens.max() 281 | targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int 282 | if self.prefix_mode in [2, 4]: 283 | xy_padding_mask = torch.concat( 284 | [ 285 | x_mask, 286 | F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False), 287 | ], 288 | dim=1, 289 | ) 290 | elif self.prefix_mode == 1: 291 | targets = targets[:, prefix_len:] 292 | 293 | y_pos = self.nar_audio_prenet(y_emb) 294 | y_pos = self.nar_audio_position(y_pos) 295 | xy_pos = torch.concat([x, y_pos], dim=1) 296 | xy_dec, _ = self.nar_decoder( 297 | (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight), 298 | src_key_padding_mask=xy_padding_mask, 299 | # is_causal=False, 300 | ) 301 | xy_dec = xy_dec[:, x_lens.max() + prefix_len :] 302 | if self.prefix_mode == 4: 303 | prefix_len = 0 # reset for Top10Accuracy metric 304 | logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute( 305 | 0, 2, 1 306 | ) 307 | 308 | # loss 309 | total_length = (y_lens).sum().type(torch.float32) 310 | total_loss += ( 311 | F.cross_entropy( 312 | logits, 313 | targets, 314 | ignore_index=NUM_AUDIO_TOKENS, 315 | reduction=reduction, 316 | ) 317 | * (total_length / (total_length - prefix_len * x.shape[0])) 318 | ) 319 | metrics["NarL3-8Top10Accuracy"] = ( 320 | self.nar_accuracy_metric( 321 | F.pad( 322 | logits.detach(), 323 | (0, 0, 0, 1, 0, 0), 324 | value=logits.min().cpu().item(), 325 | ), 326 | targets, 327 | ).item() 328 | * total_length 329 | ) 330 | 331 | if train_stage == 0: 332 | total_loss = total_loss / 2.0 333 | 334 | return ((x, codes), total_loss, metrics) 335 | 336 | def inference( 337 | self, 338 | x: torch.Tensor, 339 | x_lens: torch.Tensor, 340 | y: torch.Tensor, 341 | enroll_x_lens: torch.Tensor, 342 | top_k: int = -100, 343 | temperature: float = 1.0, 344 | ) -> torch.Tensor: 345 | """ 346 | Args: 347 | x: 348 | A 2-D tensor of shape (1, S). 349 | x_lens: 350 | A 1-D tensor of shape (1,). It contains the number of tokens in `x` 351 | before padding. 352 | y: 353 | A 3-D tensor of shape (1, T, 8). 354 | top_k: (`optional`) int 355 | The number of highest probability tokens to keep for top-k-filtering. Default to -100. 356 | temperature: (`optional`) float 357 | The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. 358 | Returns: 359 | Return the predicted audio code matrix. 360 | """ 361 | assert x.ndim == 2, x.shape 362 | assert x_lens.ndim == 1, x_lens.shape 363 | assert y.ndim == 3, y.shape 364 | assert y.shape[0] == 1, y.shape 365 | 366 | assert torch.all(x_lens > 0) 367 | 368 | # NOTE: x has been padded in TextTokenCollater 369 | text = x 370 | x = self.ar_text_embedding(text) 371 | x = self.ar_text_prenet(x) 372 | x = self.ar_text_position(x) 373 | 374 | text_len = x_lens.max() 375 | prompts = y 376 | prefix_len = y.shape[1] 377 | 378 | # AR Decoder 379 | # TODO: Managing decoder steps avoid repetitive computation 380 | y = prompts[..., 0] 381 | if self.ar_audio_prepend_bos: 382 | y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1) 383 | 384 | x_len = x_lens.max() 385 | x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) 386 | 387 | while True: 388 | y_emb = self.ar_audio_embedding(y) 389 | y_emb = self.ar_audio_prenet(y_emb) 390 | y_pos = self.ar_audio_position(y_emb) 391 | xy_pos = torch.concat([x, y_pos], dim=1) 392 | 393 | y_len = y.shape[1] 394 | x_attn_mask_pad = F.pad( 395 | x_attn_mask, 396 | (0, y_len), 397 | value=True, 398 | ) 399 | y_attn_mask = F.pad( 400 | torch.triu( 401 | torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1 402 | ), 403 | (x_len, 0), 404 | value=False, 405 | ) 406 | xy_attn_mask = torch.concat( 407 | [x_attn_mask_pad, y_attn_mask], dim=0 408 | ).to(y.device) 409 | 410 | xy_dec, _ = self.ar_decoder( 411 | (xy_pos, None), 412 | mask=xy_attn_mask, 413 | ) 414 | logits = self.ar_predict_layer(xy_dec[:, -1]) 415 | samples = topk_sampling( 416 | logits, top_k=top_k, top_p=1.0, temperature=temperature 417 | ) 418 | 419 | if ( 420 | torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS 421 | or samples[0, 0] == NUM_AUDIO_TOKENS 422 | or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16 423 | ): 424 | if prompts.shape[1] == y.shape[1]: 425 | raise SyntaxError( 426 | "well trained model shouldn't reach here." 427 | ) 428 | 429 | print(f"USLM EOS [{prompts.shape[1]} -> {y.shape[1]}]") 430 | break 431 | 432 | y = torch.concat([y, samples], dim=1) 433 | 434 | codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]] 435 | if self.num_quantizers == 1: 436 | return torch.stack(codes, dim=-1) 437 | 438 | # Non-AR Decoders 439 | y_emb = self.nar_audio_embeddings[0]( 440 | y[:, int(self.ar_audio_prepend_bos) :] 441 | ) 442 | 443 | if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes 444 | enrolled_len = enroll_x_lens.max().item() 445 | # SOS + Synthesis Text + EOS 446 | text = torch.concat( 447 | [ 448 | text[:, :1], 449 | text[:, enrolled_len - 1 :], 450 | ], 451 | dim=1, 452 | ) 453 | text_len = text_len - (enrolled_len - 2) 454 | assert text.shape[0] == 1 455 | 456 | x = self.nar_text_embedding(text) 457 | x = self.nar_text_prenet(x) 458 | x = self.nar_text_position(x) 459 | 460 | if self.prefix_mode == 0: 461 | for i, (predict_layer, embedding_layer) in enumerate( 462 | zip( 463 | self.nar_predict_layers, 464 | self.nar_audio_embeddings[1:], 465 | ) 466 | ): 467 | y_pos = self.nar_audio_prenet(y_emb) 468 | y_pos = self.nar_audio_position(y_pos) 469 | xy_pos = torch.concat([x, y_pos], dim=1) 470 | 471 | xy_dec, _ = self.nar_decoder( 472 | (xy_pos, self.nar_stage_embeddings[i].weight) 473 | ) 474 | logits = predict_layer(xy_dec[:, text_len + prefix_len :]) 475 | 476 | samples = torch.argmax(logits, dim=-1) 477 | codes.append(samples) 478 | 479 | if i < self.num_quantizers - 2: 480 | y_emb[:, :prefix_len] += embedding_layer( 481 | prompts[..., i + 1] 482 | ) 483 | y_emb[:, prefix_len:] += embedding_layer(samples) 484 | else: 485 | for j in range(1, self.num_quantizers): 486 | y_emb[:, :prefix_len] += self.nar_audio_embeddings[j]( 487 | prompts[..., j] 488 | ) 489 | 490 | for i, (predict_layer, embedding_layer) in enumerate( 491 | zip( 492 | self.nar_predict_layers, 493 | self.nar_audio_embeddings[1:], 494 | ) 495 | ): 496 | y_pos = self.nar_audio_prenet(y_emb) 497 | y_pos = self.nar_audio_position(y_pos) 498 | xy_pos = torch.concat([x, y_pos], dim=1) 499 | 500 | xy_dec, _ = self.nar_decoder( 501 | (xy_pos, self.nar_stage_embeddings[i].weight) 502 | ) 503 | logits = predict_layer(xy_dec[:, text_len + prefix_len :]) 504 | 505 | samples = torch.argmax(logits, dim=-1) 506 | codes.append(samples) 507 | 508 | if i < self.num_quantizers - 2: 509 | y_emb[:, prefix_len:] += embedding_layer(samples) 510 | 511 | assert len(codes) == self.num_quantizers 512 | return torch.stack(codes, dim=-1) -------------------------------------------------------------------------------- /models/visualizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # See ../../../../LICENSE for clarification regarding multiple authors 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | 19 | from typing import Dict, List, Tuple, Union 20 | 21 | import matplotlib.pyplot as plt 22 | import numpy as np 23 | import torch 24 | 25 | 26 | def visualize( 27 | predicts: Tuple[torch.Tensor], 28 | batch: Dict[str, Union[List, torch.Tensor]], 29 | output_dir: str, 30 | limit: int = 4, 31 | ) -> None: 32 | text_tokens = batch["text_tokens"].to("cpu").detach().numpy() 33 | text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy() 34 | audio_features = batch["audio_features"].to("cpu").detach().numpy() 35 | audio_features_lens = ( 36 | batch["audio_features_lens"].to("cpu").detach().numpy() 37 | ) 38 | assert text_tokens.ndim == 2 39 | 40 | utt_ids, texts = batch["utt_id"], batch["text"] 41 | 42 | encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy() 43 | decoder_outputs = predicts[1] 44 | if isinstance(decoder_outputs, list): 45 | decoder_outputs = decoder_outputs[-1] 46 | decoder_outputs = ( 47 | decoder_outputs.to("cpu").type(torch.float32).detach().numpy() 48 | ) 49 | 50 | vmin, vmax = 0, 1024 # Encodec 51 | if decoder_outputs.dtype == np.float32: 52 | vmin, vmax = -6, 0 # Fbank 53 | 54 | num_figures = 3 55 | for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])): 56 | _ = plt.figure(figsize=(14, 8 * num_figures)) 57 | 58 | S = text_tokens_lens[b] 59 | T = audio_features_lens[b] 60 | 61 | # encoder 62 | plt.subplot(num_figures, 1, 1) 63 | plt.title(f"Text: {text}") 64 | plt.imshow( 65 | X=np.transpose(encoder_outputs[b]), 66 | cmap=plt.get_cmap("jet"), 67 | aspect="auto", 68 | interpolation="nearest", 69 | ) 70 | plt.gca().invert_yaxis() 71 | plt.axvline(x=S - 0.4, linewidth=2, color="r") 72 | plt.xlabel("Encoder Output") 73 | plt.colorbar() 74 | 75 | # decoder 76 | plt.subplot(num_figures, 1, 2) 77 | plt.imshow( 78 | X=np.transpose(decoder_outputs[b]), 79 | cmap=plt.get_cmap("jet"), 80 | aspect="auto", 81 | interpolation="nearest", 82 | vmin=vmin, 83 | vmax=vmax, 84 | ) 85 | plt.gca().invert_yaxis() 86 | plt.axvline(x=T - 0.4, linewidth=2, color="r") 87 | plt.xlabel("Decoder Output") 88 | plt.colorbar() 89 | 90 | # target 91 | plt.subplot(num_figures, 1, 3) 92 | plt.imshow( 93 | X=np.transpose(audio_features[b]), 94 | cmap=plt.get_cmap("jet"), 95 | aspect="auto", 96 | interpolation="nearest", 97 | vmin=vmin, 98 | vmax=vmax, 99 | ) 100 | plt.gca().invert_yaxis() 101 | plt.axvline(x=T - 0.4, linewidth=2, color="r") 102 | plt.xlabel("Decoder Target") 103 | plt.colorbar() 104 | 105 | plt.savefig(f"{output_dir}/{utt_id}.png") 106 | plt.close() 107 | -------------------------------------------------------------------------------- /modules/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/modules/.DS_Store -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/modules/__init__.py -------------------------------------------------------------------------------- /modules/activation.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Linear, Module 6 | from torch.nn import functional as F 7 | from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ 8 | from torch.nn.modules.linear import NonDynamicallyQuantizableLinear 9 | from torch.nn.parameter import Parameter 10 | 11 | 12 | class MultiheadAttention(Module): 13 | r"""Allows the model to jointly attend to information 14 | from different representation subspaces as described in the paper: 15 | `Attention Is All You Need `_. 16 | 17 | Multi-Head Attention is defined as: 18 | 19 | .. math:: 20 | \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O 21 | 22 | where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. 23 | 24 | ``forward()`` will use a special optimized implementation if all of the following 25 | conditions are met: 26 | 27 | - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This 28 | restriction will be loosened in the future.) 29 | - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` 30 | - training is disabled (using ``.eval()``) 31 | - dropout is 0 32 | - ``add_bias_kv`` is ``False`` 33 | - ``add_zero_attn`` is ``False`` 34 | - ``batch_first`` is ``True`` and the input is batched 35 | - ``kdim`` and ``vdim`` are equal to ``embed_dim`` 36 | - at most one of ``key_padding_mask`` or ``attn_mask`` is passed 37 | - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` 38 | nor ``attn_mask`` is passed 39 | 40 | If the optimized implementation is in use, a 41 | `NestedTensor `_ can be passed for 42 | ``query``/``key``/``value`` to represent padding more efficiently than using a 43 | padding mask. In this case, a `NestedTensor `_ 44 | will be returned, and an additional speedup proportional to the fraction of the input 45 | that is padding can be expected. 46 | 47 | Args: 48 | embed_dim: Total dimension of the model. 49 | num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split 50 | across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). 51 | dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). 52 | bias: If specified, adds bias to input / output projection layers. Default: ``True``. 53 | add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. 54 | add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. 55 | Default: ``False``. 56 | kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). 57 | vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). 58 | batch_first: If ``True``, then the input and output tensors are provided 59 | as (batch, seq, feature). Default: ``False`` (seq, batch, feature). 60 | 61 | Examples:: 62 | 63 | >>> # xdoctest: +SKIP 64 | >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) 65 | >>> attn_output, attn_output_weights = multihead_attn(query, key, value) 66 | 67 | """ 68 | __constants__ = ["batch_first"] 69 | bias_k: Optional[torch.Tensor] 70 | bias_v: Optional[torch.Tensor] 71 | 72 | def __init__( 73 | self, 74 | embed_dim, 75 | num_heads, 76 | dropout=0.0, 77 | bias=True, 78 | add_bias_kv=False, 79 | add_zero_attn=False, 80 | kdim=None, 81 | vdim=None, 82 | batch_first=False, 83 | linear1_cls=Linear, 84 | linear2_cls=Linear, 85 | device=None, 86 | dtype=None, 87 | ) -> None: 88 | factory_kwargs = {"device": device, "dtype": dtype} 89 | super(MultiheadAttention, self).__init__() 90 | self.embed_dim = embed_dim 91 | self.kdim = kdim if kdim is not None else embed_dim 92 | self.vdim = vdim if vdim is not None else embed_dim 93 | self._qkv_same_embed_dim = ( 94 | self.kdim == embed_dim and self.vdim == embed_dim 95 | ) 96 | 97 | self.num_heads = num_heads 98 | self.dropout = dropout 99 | self.batch_first = batch_first 100 | self.head_dim = embed_dim // num_heads 101 | assert ( 102 | self.head_dim * num_heads == self.embed_dim 103 | ), "embed_dim must be divisible by num_heads" 104 | 105 | if add_bias_kv: 106 | self.bias_k = Parameter( 107 | torch.empty((1, 1, embed_dim), **factory_kwargs) 108 | ) 109 | self.bias_v = Parameter( 110 | torch.empty((1, 1, embed_dim), **factory_kwargs) 111 | ) 112 | else: 113 | self.bias_k = self.bias_v = None 114 | 115 | if linear1_cls == Linear: 116 | if not self._qkv_same_embed_dim: 117 | self.q_proj_weight = Parameter( 118 | torch.empty((embed_dim, embed_dim), **factory_kwargs) 119 | ) 120 | self.k_proj_weight = Parameter( 121 | torch.empty((embed_dim, self.kdim), **factory_kwargs) 122 | ) 123 | self.v_proj_weight = Parameter( 124 | torch.empty((embed_dim, self.vdim), **factory_kwargs) 125 | ) 126 | self.register_parameter("in_proj_weight", None) 127 | else: 128 | self.in_proj_weight = Parameter( 129 | torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) 130 | ) 131 | self.register_parameter("q_proj_weight", None) 132 | self.register_parameter("k_proj_weight", None) 133 | self.register_parameter("v_proj_weight", None) 134 | 135 | if bias: 136 | self.in_proj_bias = Parameter( 137 | torch.empty(3 * embed_dim, **factory_kwargs) 138 | ) 139 | else: 140 | self.register_parameter("in_proj_bias", None) 141 | self.out_proj = NonDynamicallyQuantizableLinear( 142 | embed_dim, embed_dim, bias=bias, **factory_kwargs 143 | ) 144 | 145 | self._reset_parameters() 146 | else: 147 | if not self._qkv_same_embed_dim: 148 | raise NotImplementedError 149 | else: 150 | self.in_proj_linear = linear1_cls( 151 | embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs 152 | ) 153 | self.in_proj_weight = self.in_proj_linear.weight 154 | 155 | self.register_parameter("q_proj_weight", None) 156 | self.register_parameter("k_proj_weight", None) 157 | self.register_parameter("v_proj_weight", None) 158 | 159 | if bias: 160 | self.in_proj_bias = self.in_proj_linear.bias 161 | else: 162 | self.register_parameter("in_proj_bias", None) 163 | 164 | self.out_proj = linear2_cls( 165 | embed_dim, embed_dim, bias=bias, **factory_kwargs 166 | ) 167 | 168 | if self.bias_k is not None: 169 | xavier_normal_(self.bias_k) 170 | if self.bias_v is not None: 171 | xavier_normal_(self.bias_v) 172 | 173 | self.add_zero_attn = add_zero_attn 174 | 175 | def _reset_parameters(self): 176 | if self._qkv_same_embed_dim: 177 | xavier_uniform_(self.in_proj_weight) 178 | else: 179 | xavier_uniform_(self.q_proj_weight) 180 | xavier_uniform_(self.k_proj_weight) 181 | xavier_uniform_(self.v_proj_weight) 182 | 183 | if self.in_proj_bias is not None: 184 | constant_(self.in_proj_bias, 0.0) 185 | constant_(self.out_proj.bias, 0.0) 186 | 187 | if self.bias_k is not None: 188 | xavier_normal_(self.bias_k) 189 | if self.bias_v is not None: 190 | xavier_normal_(self.bias_v) 191 | 192 | def __setstate__(self, state): 193 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0 194 | if "_qkv_same_embed_dim" not in state: 195 | state["_qkv_same_embed_dim"] = True 196 | 197 | super(MultiheadAttention, self).__setstate__(state) 198 | 199 | def forward( 200 | self, 201 | query: Tensor, 202 | key: Tensor, 203 | value: Tensor, 204 | key_padding_mask: Optional[Tensor] = None, 205 | need_weights: bool = True, 206 | attn_mask: Optional[Tensor] = None, 207 | average_attn_weights: bool = True, 208 | ) -> Tuple[Tensor, Optional[Tensor]]: 209 | r""" 210 | Args: 211 | query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` 212 | or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, 213 | :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. 214 | Queries are compared against key-value pairs to produce the output. 215 | See "Attention Is All You Need" for more details. 216 | key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` 217 | or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, 218 | :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. 219 | See "Attention Is All You Need" for more details. 220 | value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when 221 | ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source 222 | sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. 223 | See "Attention Is All You Need" for more details. 224 | key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` 225 | to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. 226 | Binary and byte masks are supported. 227 | For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for 228 | the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. 229 | need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. 230 | Default: ``True``. 231 | attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape 232 | :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, 233 | :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be 234 | broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. 235 | Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the 236 | corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the 237 | corresponding position is not allowed to attend. For a float mask, the mask values will be added to 238 | the attention weight. 239 | average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across 240 | heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an 241 | effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) 242 | 243 | Outputs: 244 | - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, 245 | :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, 246 | where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the 247 | embedding dimension ``embed_dim``. 248 | - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, 249 | returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or 250 | :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and 251 | :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per 252 | head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. 253 | 254 | .. note:: 255 | `batch_first` argument is ignored for unbatched inputs. 256 | """ 257 | is_batched = query.dim() == 3 258 | if key_padding_mask is not None: 259 | _kpm_dtype = key_padding_mask.dtype 260 | if _kpm_dtype != torch.bool and not torch.is_floating_point( 261 | key_padding_mask 262 | ): 263 | raise AssertionError( 264 | "only bool and floating types of key_padding_mask are supported" 265 | ) 266 | why_not_fast_path = "" 267 | if not is_batched: 268 | why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" 269 | elif query is not key or key is not value: 270 | # When lifting this restriction, don't forget to either 271 | # enforce that the dtypes all match or test cases where 272 | # they don't! 273 | why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" 274 | elif ( 275 | self.in_proj_bias is not None 276 | and query.dtype != self.in_proj_bias.dtype 277 | ): 278 | why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" 279 | elif ( 280 | self.in_proj_weight is not None 281 | and query.dtype != self.in_proj_weight.dtype 282 | ): 283 | # this case will fail anyway, but at least they'll get a useful error message. 284 | why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" 285 | elif self.training: 286 | why_not_fast_path = "training is enabled" 287 | elif not self.batch_first: 288 | why_not_fast_path = "batch_first was not True" 289 | elif self.bias_k is not None: 290 | why_not_fast_path = "self.bias_k was not None" 291 | elif self.bias_v is not None: 292 | why_not_fast_path = "self.bias_v was not None" 293 | elif self.dropout: 294 | why_not_fast_path = f"dropout was {self.dropout}, required zero" 295 | elif self.add_zero_attn: 296 | why_not_fast_path = "add_zero_attn was enabled" 297 | elif not self._qkv_same_embed_dim: 298 | why_not_fast_path = "_qkv_same_embed_dim was not True" 299 | elif attn_mask is not None: 300 | why_not_fast_path = "attn_mask was not None" 301 | elif query.is_nested and key_padding_mask is not None: 302 | why_not_fast_path = ( 303 | "key_padding_mask is not supported with NestedTensor input" 304 | ) 305 | elif self.num_heads % 2 == 1: 306 | why_not_fast_path = "num_heads is odd" 307 | elif torch.is_autocast_enabled(): 308 | why_not_fast_path = "autocast is enabled" 309 | 310 | if not why_not_fast_path: 311 | tensor_args = ( 312 | query, 313 | key, 314 | value, 315 | self.in_proj_weight, 316 | self.in_proj_bias, 317 | self.out_proj.weight, 318 | self.out_proj.bias, 319 | ) 320 | # We have to use list comprehensions below because TorchScript does not support 321 | # generator expressions. 322 | if torch.overrides.has_torch_function(tensor_args): 323 | why_not_fast_path = "some Tensor argument has_torch_function" 324 | elif not all( 325 | [ 326 | (x is None or x.is_cuda or "cpu" in str(x.device)) 327 | for x in tensor_args 328 | ] 329 | ): 330 | why_not_fast_path = ( 331 | "some Tensor argument is neither CUDA nor CPU" 332 | ) 333 | elif torch.is_grad_enabled() and any( 334 | [x is not None and x.requires_grad for x in tensor_args] 335 | ): 336 | why_not_fast_path = ( 337 | "grad is enabled and at least one of query or the " 338 | "input/output projection weights or biases requires_grad" 339 | ) 340 | if not why_not_fast_path: 341 | return torch._native_multi_head_attention( 342 | query, 343 | key, 344 | value, 345 | self.embed_dim, 346 | self.num_heads, 347 | self.in_proj_weight, 348 | self.in_proj_bias, 349 | self.out_proj.weight, 350 | self.out_proj.bias, 351 | key_padding_mask 352 | if key_padding_mask is not None 353 | else attn_mask, 354 | need_weights, 355 | average_attn_weights, 356 | 1 357 | if key_padding_mask is not None 358 | else 0 359 | if attn_mask is not None 360 | else None, 361 | ) 362 | 363 | any_nested = query.is_nested or key.is_nested or value.is_nested 364 | assert not any_nested, ( 365 | "MultiheadAttention does not support NestedTensor outside of its fast path. " 366 | + f"The fast path was not hit because {why_not_fast_path}" 367 | ) 368 | 369 | if self.batch_first and is_batched: 370 | # make sure that the transpose op does not affect the "is" property 371 | if key is value: 372 | if query is key: 373 | query = key = value = query.transpose(1, 0) 374 | else: 375 | query, key = [x.transpose(1, 0) for x in (query, key)] 376 | value = key 377 | else: 378 | query, key, value = [ 379 | x.transpose(1, 0) for x in (query, key, value) 380 | ] 381 | 382 | if not self._qkv_same_embed_dim: 383 | attn_output, attn_output_weights = F.multi_head_attention_forward( 384 | query, 385 | key, 386 | value, 387 | self.embed_dim, 388 | self.num_heads, 389 | self.in_proj_weight, 390 | self.in_proj_bias, 391 | self.bias_k, 392 | self.bias_v, 393 | self.add_zero_attn, 394 | self.dropout, 395 | self.out_proj.weight, 396 | self.out_proj.bias, 397 | training=self.training, 398 | key_padding_mask=key_padding_mask, 399 | need_weights=need_weights, 400 | attn_mask=attn_mask, 401 | use_separate_proj_weight=True, 402 | q_proj_weight=self.q_proj_weight, 403 | k_proj_weight=self.k_proj_weight, 404 | v_proj_weight=self.v_proj_weight, 405 | average_attn_weights=average_attn_weights, 406 | ) 407 | else: 408 | attn_output, attn_output_weights = F.multi_head_attention_forward( 409 | query, 410 | key, 411 | value, 412 | self.embed_dim, 413 | self.num_heads, 414 | self.in_proj_weight, 415 | self.in_proj_bias, 416 | self.bias_k, 417 | self.bias_v, 418 | self.add_zero_attn, 419 | self.dropout, 420 | self.out_proj.weight, 421 | self.out_proj.bias, 422 | training=self.training, 423 | key_padding_mask=key_padding_mask, 424 | need_weights=need_weights, 425 | attn_mask=attn_mask, 426 | average_attn_weights=average_attn_weights, 427 | ) 428 | if self.batch_first and is_batched: 429 | return attn_output.transpose(1, 0), attn_output_weights 430 | else: 431 | return attn_output, attn_output_weights 432 | -------------------------------------------------------------------------------- /modules/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 (authors: Feiteng Li) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class TokenEmbedding(nn.Module): 22 | def __init__( 23 | self, 24 | dim_model: int, 25 | vocab_size: int, 26 | dropout: float = 0.0, 27 | ): 28 | super().__init__() 29 | 30 | self.vocab_size = vocab_size 31 | self.dim_model = dim_model 32 | 33 | self.dropout = torch.nn.Dropout(p=dropout) 34 | self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) 35 | 36 | @property 37 | def weight(self) -> torch.Tensor: 38 | return self.word_embeddings.weight 39 | 40 | def embedding(self, index: int) -> torch.Tensor: 41 | return self.word_embeddings.weight[index : index + 1] 42 | 43 | def forward(self, x: torch.Tensor): 44 | X = self.word_embeddings(x) 45 | X = self.dropout(X) 46 | 47 | return X 48 | 49 | 50 | class SinePositionalEmbedding(nn.Module): 51 | def __init__( 52 | self, 53 | dim_model: int, 54 | dropout: float = 0.0, 55 | scale: bool = False, 56 | alpha: bool = False, 57 | ): 58 | super().__init__() 59 | self.dim_model = dim_model 60 | self.x_scale = math.sqrt(dim_model) if scale else 1.0 61 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) 62 | self.dropout = torch.nn.Dropout(p=dropout) 63 | 64 | self.reverse = False 65 | self.pe = None 66 | self.extend_pe(torch.tensor(0.0).expand(1, 4000)) 67 | 68 | def extend_pe(self, x): 69 | """Reset the positional encodings.""" 70 | if self.pe is not None: 71 | if self.pe.size(1) >= x.size(1): 72 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 73 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 74 | return 75 | pe = torch.zeros(x.size(1), self.dim_model) 76 | if self.reverse: 77 | position = torch.arange( 78 | x.size(1) - 1, -1, -1.0, dtype=torch.float32 79 | ).unsqueeze(1) 80 | else: 81 | position = torch.arange( 82 | 0, x.size(1), dtype=torch.float32 83 | ).unsqueeze(1) 84 | div_term = torch.exp( 85 | torch.arange(0, self.dim_model, 2, dtype=torch.float32) 86 | * -(math.log(10000.0) / self.dim_model) 87 | ) 88 | pe[:, 0::2] = torch.sin(position * div_term) 89 | pe[:, 1::2] = torch.cos(position * div_term) 90 | pe = pe.unsqueeze(0) 91 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach() 92 | 93 | def forward(self, x: torch.Tensor) -> torch.Tensor: 94 | self.extend_pe(x) 95 | output = x.unsqueeze(-1) if x.ndim == 2 else x 96 | output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] 97 | return self.dropout(output) 98 | -------------------------------------------------------------------------------- /modules/scheduler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2023 (authors: Feiteng Li) 3 | # 4 | # See ../../../../LICENSE for clarification regarding multiple authors 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | 19 | import torch 20 | 21 | from uslm.modules.optim import Eden 22 | 23 | 24 | def calc_lr(step, dim_embed, warmup_steps): 25 | return dim_embed ** (-0.5) * min( 26 | step ** (-0.5), step * warmup_steps ** (-1.5) 27 | ) 28 | 29 | 30 | class NoamScheduler(torch.optim.lr_scheduler._LRScheduler): 31 | def __init__( 32 | self, 33 | base_lr: float, 34 | optimizer: torch.optim.Optimizer, 35 | dim_embed: int, 36 | warmup_steps: int, 37 | last_epoch: int = -1, 38 | verbose: bool = False, 39 | ) -> None: 40 | 41 | self.dim_embed = dim_embed 42 | self.base_lr = base_lr 43 | self.warmup_steps = warmup_steps 44 | self.num_param_groups = len(optimizer.param_groups) 45 | 46 | super().__init__(optimizer, last_epoch, verbose) 47 | 48 | def get_lr(self) -> float: 49 | lr = self.base_lr * calc_lr( 50 | self._step_count, self.dim_embed, self.warmup_steps 51 | ) 52 | return [lr] * self.num_param_groups 53 | 54 | def set_step(self, step: int): 55 | self._step_count = step 56 | 57 | 58 | def get_scheduler(params, optimizer): 59 | if params.scheduler_name.lower() == "eden": 60 | scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps) 61 | elif params.scheduler_name.lower() == "noam": 62 | scheduler = NoamScheduler( 63 | params.base_lr, 64 | optimizer, 65 | params.decoder_dim, 66 | warmup_steps=params.warmup_steps, 67 | ) 68 | # scheduler.set_step(params.start_batch or params.batch_idx_train) 69 | elif params.scheduler_name.lower() == "cosine": 70 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 71 | params.warmup_steps, 72 | optimizer, 73 | eta_min=params.base_lr, 74 | ) 75 | else: 76 | raise NotImplementedError(f"{params.scheduler_name}") 77 | 78 | return scheduler 79 | -------------------------------------------------------------------------------- /modules/transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numbers 3 | from functools import partial 4 | from typing import Any, Callable, List, Optional, Tuple, Union 5 | 6 | import torch 7 | from torch import Tensor, nn 8 | from torch.nn import functional as F 9 | 10 | from .activation import MultiheadAttention 11 | from .scaling import ActivationBalancer, BalancedDoubleSwish 12 | from .scaling import BasicNorm as _BasicNorm 13 | 14 | _shape_t = Union[int, List[int], torch.Size] 15 | 16 | 17 | class LayerNorm(nn.Module): 18 | __constants__ = ["normalized_shape", "eps", "elementwise_affine"] 19 | normalized_shape: Tuple[int, ...] 20 | eps: float 21 | elementwise_affine: bool 22 | 23 | def __init__( 24 | self, 25 | normalized_shape: _shape_t, 26 | eps: float = 1e-5, 27 | elementwise_affine: bool = True, 28 | device=None, 29 | dtype=None, 30 | ) -> None: 31 | factory_kwargs = {"device": device, "dtype": dtype} 32 | super(LayerNorm, self).__init__() 33 | if isinstance(normalized_shape, numbers.Integral): 34 | # mypy error: incompatible types in assignment 35 | normalized_shape = (normalized_shape,) # type: ignore[assignment] 36 | self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] 37 | self.eps = eps 38 | self.elementwise_affine = elementwise_affine 39 | if self.elementwise_affine: 40 | self.weight = nn.Parameter( 41 | torch.empty(self.normalized_shape, **factory_kwargs) 42 | ) 43 | self.bias = nn.Parameter( 44 | torch.empty(self.normalized_shape, **factory_kwargs) 45 | ) 46 | else: 47 | self.register_parameter("weight", None) 48 | self.register_parameter("bias", None) 49 | 50 | self.reset_parameters() 51 | 52 | def reset_parameters(self) -> None: 53 | if self.elementwise_affine: 54 | nn.init.ones_(self.weight) 55 | nn.init.zeros_(self.bias) 56 | 57 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor: 58 | if isinstance(input, tuple): 59 | input, embedding = input 60 | return ( 61 | F.layer_norm( 62 | input, 63 | self.normalized_shape, 64 | self.weight, 65 | self.bias, 66 | self.eps, 67 | ), 68 | embedding, 69 | ) 70 | 71 | assert embedding is None 72 | return F.layer_norm( 73 | input, self.normalized_shape, self.weight, self.bias, self.eps 74 | ) 75 | 76 | def extra_repr(self) -> str: 77 | return ( 78 | "{normalized_shape}, eps={eps}, " 79 | "elementwise_affine={elementwise_affine}".format(**self.__dict__) 80 | ) 81 | 82 | 83 | class AdaptiveLayerNorm(nn.Module): 84 | r"""Adaptive Layer Normalization""" 85 | 86 | def __init__(self, d_model, norm) -> None: 87 | super(AdaptiveLayerNorm, self).__init__() 88 | self.project_layer = nn.Linear(d_model, 2 * d_model) 89 | self.norm = norm 90 | self.d_model = d_model 91 | self.eps = self.norm.eps 92 | 93 | def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: 94 | if isinstance(input, tuple): 95 | input, embedding = input 96 | weight, bias = torch.split( 97 | self.project_layer(embedding), 98 | split_size_or_sections=self.d_model, 99 | dim=-1, 100 | ) 101 | return (weight * self.norm(input) + bias, embedding) 102 | 103 | weight, bias = torch.split( 104 | self.project_layer(embedding), 105 | split_size_or_sections=self.d_model, 106 | dim=-1, 107 | ) 108 | return weight * self.norm(input) + bias 109 | 110 | 111 | class BasicNorm(_BasicNorm): 112 | def __init__( 113 | self, 114 | d_model: int, 115 | eps: float = 1e-5, 116 | device=None, 117 | dtype=None, 118 | ): 119 | super(BasicNorm, self).__init__(d_model, eps=eps) 120 | 121 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor: 122 | if isinstance(input, tuple): 123 | input, embedding = input 124 | return ( 125 | super(BasicNorm, self).forward(input), 126 | embedding, 127 | ) 128 | 129 | assert embedding is None 130 | return super(BasicNorm, self).forward(input) 131 | 132 | 133 | class BalancedBasicNorm(nn.Module): 134 | def __init__( 135 | self, 136 | d_model: int, 137 | eps: float = 1e-5, 138 | device=None, 139 | dtype=None, 140 | ): 141 | super(BalancedBasicNorm, self).__init__() 142 | self.balancer = ActivationBalancer( 143 | d_model, 144 | channel_dim=-1, 145 | min_positive=0.45, 146 | max_positive=0.55, 147 | max_abs=6.0, 148 | ) 149 | self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype) 150 | 151 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor: 152 | if isinstance(input, tuple): 153 | input, embedding = input 154 | return self.norm((self.balancer(input), embedding)) 155 | 156 | assert embedding is None 157 | return self.norm(self.balancer(input)) 158 | 159 | 160 | class IdentityNorm(nn.Module): 161 | def __init__( 162 | self, 163 | d_model: int, 164 | eps: float = 1e-5, 165 | device=None, 166 | dtype=None, 167 | ) -> None: 168 | super(IdentityNorm, self).__init__() 169 | 170 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor: 171 | if isinstance(input, tuple): 172 | return input 173 | 174 | assert embedding is None 175 | return input 176 | 177 | 178 | class TransformerEncoderLayer(nn.Module): 179 | __constants__ = ["batch_first", "norm_first"] 180 | 181 | def __init__( 182 | self, 183 | d_model: int, 184 | nhead: int, 185 | dim_feedforward: int = 2048, 186 | dropout: float = 0.1, 187 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 188 | batch_first: bool = False, 189 | norm_first: bool = False, 190 | device=None, 191 | dtype=None, 192 | linear1_self_attention_cls: nn.Module = nn.Linear, 193 | linear2_self_attention_cls: nn.Module = nn.Linear, 194 | linear1_feedforward_cls: nn.Module = nn.Linear, 195 | linear2_feedforward_cls: nn.Module = nn.Linear, 196 | layer_norm_cls: nn.Module = LayerNorm, 197 | layer_norm_eps: float = 1e-5, 198 | adaptive_layer_norm=False, 199 | ) -> None: 200 | factory_kwargs = {"device": device, "dtype": dtype} 201 | super(TransformerEncoderLayer, self).__init__() 202 | self.self_attn = MultiheadAttention( 203 | d_model, 204 | nhead, 205 | dropout=dropout, 206 | batch_first=batch_first, 207 | linear1_cls=linear1_self_attention_cls, 208 | linear2_cls=linear2_self_attention_cls, 209 | **factory_kwargs, 210 | ) 211 | 212 | # Implementation of Feedforward model 213 | self.linear1 = linear1_feedforward_cls( 214 | d_model, dim_feedforward, **factory_kwargs 215 | ) 216 | self.dropout = nn.Dropout(dropout) 217 | self.linear2 = linear2_feedforward_cls( 218 | dim_feedforward, d_model, **factory_kwargs 219 | ) 220 | 221 | self.norm_first = norm_first 222 | self.dropout1 = nn.Dropout(dropout) 223 | self.dropout2 = nn.Dropout(dropout) 224 | 225 | # Legacy string support for activation function. 226 | if isinstance(activation, str): 227 | activation = _get_activation_fn(activation) 228 | elif isinstance(activation, partial): 229 | activation = activation(d_model) 230 | elif activation == BalancedDoubleSwish: 231 | activation = BalancedDoubleSwish(d_model) 232 | 233 | # # We can't test self.activation in forward() in TorchScript, 234 | # # so stash some information about it instead. 235 | # if activation is F.relu or isinstance(activation, torch.nn.ReLU): 236 | # self.activation_relu_or_gelu = 1 237 | # elif activation is F.gelu or isinstance(activation, torch.nn.GELU): 238 | # self.activation_relu_or_gelu = 2 239 | # else: 240 | # self.activation_relu_or_gelu = 0 241 | self.activation = activation 242 | 243 | norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) 244 | if layer_norm_cls == IdentityNorm: 245 | norm2 = BalancedBasicNorm( 246 | d_model, eps=layer_norm_eps, **factory_kwargs 247 | ) 248 | else: 249 | norm2 = layer_norm_cls( 250 | d_model, eps=layer_norm_eps, **factory_kwargs 251 | ) 252 | 253 | if adaptive_layer_norm: 254 | self.norm1 = AdaptiveLayerNorm(d_model, norm1) 255 | self.norm2 = AdaptiveLayerNorm(d_model, norm2) 256 | else: 257 | self.norm1 = norm1 258 | self.norm2 = norm2 259 | 260 | def __setstate__(self, state): 261 | super(TransformerEncoderLayer, self).__setstate__(state) 262 | if not hasattr(self, "activation"): 263 | self.activation = F.relu 264 | 265 | def forward( 266 | self, 267 | src: Tensor, 268 | src_mask: Optional[Tensor] = None, 269 | src_key_padding_mask: Optional[Tensor] = None, 270 | ) -> Tensor: 271 | r"""Pass the input through the encoder layer. 272 | 273 | Args: 274 | src: the sequence to the encoder layer (required). 275 | src_mask: the mask for the src sequence (optional). 276 | src_key_padding_mask: the mask for the src keys per batch (optional). 277 | 278 | Shape: 279 | see the docs in Transformer class. 280 | """ 281 | x, stage_embedding = src, None 282 | is_src_tuple = False 283 | if isinstance(src, tuple): 284 | x, stage_embedding = src 285 | is_src_tuple = True 286 | 287 | if src_key_padding_mask is not None: 288 | _skpm_dtype = src_key_padding_mask.dtype 289 | if _skpm_dtype != torch.bool and not torch.is_floating_point( 290 | src_key_padding_mask 291 | ): 292 | raise AssertionError( 293 | "only bool and floating types of key_padding_mask are supported" 294 | ) 295 | 296 | if self.norm_first: 297 | x = x + self._sa_block( 298 | self.norm1(x, stage_embedding), 299 | src_mask, 300 | src_key_padding_mask, 301 | ) 302 | x = x + self._ff_block(self.norm2(x, stage_embedding)) 303 | else: 304 | x = self.norm1( 305 | x + self._sa_block(x, src_mask, src_key_padding_mask), 306 | stage_embedding, 307 | ) 308 | x = self.norm2(x + self._ff_block(x), stage_embedding) 309 | 310 | if is_src_tuple: 311 | return (x, stage_embedding) 312 | return x 313 | 314 | # self-attention block 315 | def _sa_block( 316 | self, 317 | x: Tensor, 318 | attn_mask: Optional[Tensor], 319 | key_padding_mask: Optional[Tensor], 320 | ) -> Tensor: 321 | x = self.self_attn( 322 | x, 323 | x, 324 | x, 325 | attn_mask=attn_mask, 326 | key_padding_mask=key_padding_mask, 327 | need_weights=False, 328 | )[0] 329 | return self.dropout1(x) 330 | 331 | # feed forward block 332 | def _ff_block(self, x: Tensor) -> Tensor: 333 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 334 | return self.dropout2(x) 335 | 336 | 337 | class TransformerEncoder(nn.Module): 338 | r"""TransformerEncoder is a stack of N encoder layers. Users can build the 339 | BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. 340 | 341 | Args: 342 | encoder_layer: an instance of the TransformerEncoderLayer() class (required). 343 | num_layers: the number of sub-encoder-layers in the encoder (required). 344 | norm: the layer normalization component (optional). 345 | enable_nested_tensor: if True, input will automatically convert to nested tensor 346 | (and convert back on output). This will improve the overall performance of 347 | TransformerEncoder when padding rate is high. Default: ``True`` (enabled). 348 | 349 | Examples:: 350 | >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) 351 | >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) 352 | >>> src = torch.rand(10, 32, 512) 353 | >>> out = transformer_encoder(src) 354 | """ 355 | __constants__ = ["norm"] 356 | 357 | def __init__(self, encoder_layer, num_layers, norm=None): 358 | super(TransformerEncoder, self).__init__() 359 | self.layers = _get_clones(encoder_layer, num_layers) 360 | self.num_layers = num_layers 361 | self.norm = norm 362 | 363 | def forward( 364 | self, 365 | src: Tensor, 366 | mask: Optional[Tensor] = None, 367 | src_key_padding_mask: Optional[Tensor] = None, 368 | return_layer_states: bool = False, 369 | ) -> Tensor: 370 | r"""Pass the input through the encoder layers in turn. 371 | 372 | Args: 373 | src: the sequence to the encoder (required). 374 | mask: the mask for the src sequence (optional). 375 | src_key_padding_mask: the mask for the src keys per batch (optional). 376 | return_layer_states: return layers' state (optional). 377 | 378 | Shape: 379 | see the docs in Transformer class. 380 | """ 381 | if return_layer_states: 382 | layer_states = [] # layers' output 383 | output = src 384 | for mod in self.layers: 385 | output = mod( 386 | output, 387 | src_mask=mask, 388 | src_key_padding_mask=src_key_padding_mask, 389 | ) 390 | layer_states.append(output[0]) 391 | 392 | if self.norm is not None: 393 | output = self.norm(output) 394 | 395 | return layer_states, output 396 | 397 | output = src 398 | for mod in self.layers: 399 | output = mod( 400 | output, src_mask=mask, src_key_padding_mask=src_key_padding_mask 401 | ) 402 | 403 | if self.norm is not None: 404 | output = self.norm(output) 405 | 406 | return output 407 | 408 | 409 | class TransformerDecoderLayer(nn.Module): 410 | __constants__ = ["batch_first", "norm_first"] 411 | 412 | def __init__( 413 | self, 414 | d_model: int, 415 | nhead: int, 416 | dim_feedforward: int = 2048, 417 | dropout: float = 0.1, 418 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, 419 | linear1_self_attention_cls: nn.Module = nn.Linear, 420 | linear2_self_attention_cls: nn.Module = nn.Linear, 421 | linear1_feedforward_cls: nn.Module = nn.Linear, 422 | linear2_feedforward_cls: nn.Module = nn.Linear, 423 | batch_first: bool = False, 424 | norm_first: bool = False, 425 | device=None, 426 | dtype=None, 427 | layer_norm_cls: nn.Module = LayerNorm, 428 | layer_norm_eps: float = 1e-5, 429 | adaptive_layer_norm=False, 430 | ) -> None: 431 | factory_kwargs = {"device": device, "dtype": dtype} 432 | super(TransformerDecoderLayer, self).__init__() 433 | self.self_attn = MultiheadAttention( 434 | d_model, 435 | nhead, 436 | dropout=dropout, 437 | batch_first=batch_first, 438 | linear1_cls=linear1_self_attention_cls, 439 | linear2_cls=linear2_self_attention_cls, 440 | **factory_kwargs, 441 | ) 442 | self.multihead_attn = MultiheadAttention( 443 | d_model, 444 | nhead, 445 | dropout=dropout, 446 | batch_first=batch_first, 447 | linear1_cls=linear1_self_attention_cls, 448 | linear2_cls=linear2_self_attention_cls, 449 | **factory_kwargs, 450 | ) 451 | # Implementation of Feedforward model 452 | self.linear1 = linear1_feedforward_cls( 453 | d_model, dim_feedforward, **factory_kwargs 454 | ) 455 | self.dropout = nn.Dropout(dropout) 456 | self.linear2 = linear2_feedforward_cls( 457 | dim_feedforward, d_model, **factory_kwargs 458 | ) 459 | 460 | self.norm_first = norm_first 461 | self.dropout1 = nn.Dropout(dropout) 462 | self.dropout2 = nn.Dropout(dropout) 463 | self.dropout3 = nn.Dropout(dropout) 464 | 465 | # Legacy string support for activation function. 466 | if isinstance(activation, str): 467 | self.activation = _get_activation_fn(activation) 468 | elif isinstance(activation, partial): 469 | self.activation = activation(d_model) 470 | elif activation == BalancedDoubleSwish: 471 | self.activation = BalancedDoubleSwish(d_model) 472 | else: 473 | self.activation = activation 474 | 475 | if adaptive_layer_norm: 476 | norm1 = layer_norm_cls( 477 | d_model, eps=layer_norm_eps, **factory_kwargs 478 | ) 479 | norm2 = layer_norm_cls( 480 | d_model, eps=layer_norm_eps, **factory_kwargs 481 | ) 482 | norm3 = layer_norm_cls( 483 | d_model, eps=layer_norm_eps, **factory_kwargs 484 | ) 485 | 486 | self.norm1 = AdaptiveLayerNorm(d_model, norm1) 487 | self.norm2 = AdaptiveLayerNorm(d_model, norm2) 488 | self.norm3 = AdaptiveLayerNorm(d_model, norm3) 489 | else: 490 | self.norm1 = layer_norm_cls( 491 | d_model, eps=layer_norm_eps, **factory_kwargs 492 | ) 493 | self.norm2 = layer_norm_cls( 494 | d_model, eps=layer_norm_eps, **factory_kwargs 495 | ) 496 | if layer_norm_cls == IdentityNorm: 497 | self.norm3 = BalancedBasicNorm( 498 | d_model, eps=layer_norm_eps, **factory_kwargs 499 | ) 500 | else: 501 | self.norm3 = layer_norm_cls( 502 | d_model, eps=layer_norm_eps, **factory_kwargs 503 | ) 504 | 505 | def forward( 506 | self, 507 | tgt: Tensor, 508 | memory: Tensor, 509 | tgt_mask: Optional[Tensor] = None, 510 | memory_mask: Optional[Tensor] = None, 511 | tgt_key_padding_mask: Optional[Tensor] = None, 512 | memory_key_padding_mask: Optional[Tensor] = None, 513 | ) -> Tensor: 514 | r"""Pass the inputs (and mask) through the decoder layer. 515 | 516 | Args: 517 | tgt: the sequence to the decoder layer (required). 518 | memory: the sequence from the last layer of the encoder (required). 519 | tgt_mask: the mask for the tgt sequence (optional). 520 | memory_mask: the mask for the memory sequence (optional). 521 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 522 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 523 | 524 | Shape: 525 | see the docs in Transformer class. 526 | """ 527 | tgt_is_tuple = False 528 | if isinstance(tgt, tuple): 529 | x, stage_embedding = tgt 530 | tgt_is_tuple = True 531 | else: 532 | x, stage_embedding = tgt, None 533 | 534 | if self.norm_first: 535 | x = x + self._sa_block( 536 | self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask 537 | ) 538 | x = x + self._mha_block( 539 | self.norm2(x, stage_embedding), 540 | memory, 541 | memory_mask, 542 | memory_key_padding_mask, 543 | ) 544 | x = x + self._ff_block(self.norm3(x, stage_embedding)) 545 | else: 546 | x = self.norm1( 547 | x + self._sa_block(x, tgt_mask, tgt_key_padding_mask), 548 | stage_embedding, 549 | ) 550 | x = self.norm2( 551 | x 552 | + self._mha_block( 553 | x, memory, memory_mask, memory_key_padding_mask 554 | ), 555 | stage_embedding, 556 | ) 557 | x = self.norm3(x + self._ff_block(x), stage_embedding) 558 | 559 | if tgt_is_tuple: 560 | return (x, stage_embedding) 561 | return x 562 | 563 | # self-attention block 564 | def _sa_block( 565 | self, 566 | x: Tensor, 567 | attn_mask: Optional[Tensor], 568 | key_padding_mask: Optional[Tensor], 569 | ) -> Tensor: 570 | x = self.self_attn( 571 | x, 572 | x, 573 | x, 574 | attn_mask=attn_mask, 575 | key_padding_mask=key_padding_mask, 576 | need_weights=False, 577 | )[0] 578 | return self.dropout1(x) 579 | 580 | # multihead attention block 581 | def _mha_block( 582 | self, 583 | x: Tensor, 584 | mem: Tensor, 585 | attn_mask: Optional[Tensor], 586 | key_padding_mask: Optional[Tensor], 587 | ) -> Tensor: 588 | x = self.multihead_attn( 589 | x, 590 | mem, 591 | mem, 592 | attn_mask=attn_mask, 593 | key_padding_mask=key_padding_mask, 594 | need_weights=False, 595 | )[0] 596 | return self.dropout2(x) 597 | 598 | # feed forward block 599 | def _ff_block(self, x: Tensor) -> Tensor: 600 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 601 | return self.dropout3(x) 602 | 603 | 604 | def _get_clones(module, N): 605 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 606 | 607 | 608 | def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: 609 | if activation == "relu": 610 | return F.relu 611 | elif activation == "gelu": 612 | return F.gelu 613 | 614 | raise RuntimeError( 615 | "activation should be relu/gelu, not {}".format(activation) 616 | ) 617 | -------------------------------------------------------------------------------- /prompts/1580_141083_000002_000002.normalized.txt: -------------------------------------------------------------------------------- 1 | mr Soames was a tall, spare man, of a nervous and excitable temperament. -------------------------------------------------------------------------------- /prompts/1580_141083_000002_000002.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/prompts/1580_141083_000002_000002.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import sys 4 | from pathlib import Path 5 | from subprocess import DEVNULL, PIPE, run 6 | 7 | from setuptools import find_packages, setup 8 | 9 | project_root = Path(__file__).parent 10 | 11 | # modified from https://github.com/lhotse-speech/lhotse/blob/master/setup.py 12 | 13 | 14 | 15 | if sys.version_info < (3,): 16 | # fmt: off 17 | print( 18 | "Python 2 has reached end-of-life and is no longer supported by valle." 19 | ) 20 | # fmt: on 21 | sys.exit(-1) 22 | 23 | if sys.version_info < (3, 7): 24 | print( 25 | "Python 3.6 has reached end-of-life on December 31st, 2021 " 26 | "and is no longer supported by valle." 27 | ) 28 | sys.exit(-1) 29 | 30 | 31 | 32 | 33 | install_requires = [ 34 | "encodec", 35 | "phonemizer", 36 | ] 37 | 38 | try: 39 | # If the user already installed PyTorch, make sure he has torchaudio too. 40 | # Otherwise, we'll just install the latest versions from PyPI for the user. 41 | import torch 42 | 43 | try: 44 | import torchaudio 45 | except ImportError: 46 | raise ValueError( 47 | "We detected that you have already installed PyTorch, but haven't installed torchaudio. " 48 | "Unfortunately we can't detect the compatible torchaudio version for you; " 49 | "you will have to install it manually. " 50 | "For instructions, please refer either to https://pytorch.org/get-started/locally/ " 51 | "or https://github.com/pytorch/audio#dependencies" 52 | ) 53 | except ImportError: 54 | install_requires.extend(["torch", "torchaudio"]) 55 | 56 | docs_require = ( 57 | (project_root / "requirements.txt").read_text().splitlines() 58 | ) 59 | tests_require = [ 60 | # "pytest==7.1.3", 61 | # "pytest-forked==1.4.0", 62 | # "pytest-xdist==2.5.0", 63 | # "pytest-cov==4.0.0", 64 | ] 65 | workflow_requires = [""] 66 | dev_requires = sorted( 67 | docs_require 68 | + tests_require 69 | + workflow_requires 70 | + ["jupyterlab", "matplotlib"] 71 | ) 72 | all_requires = sorted(dev_requires) 73 | 74 | if os.environ.get("READTHEDOCS", False): 75 | # When building documentation, omit torchaudio installation and mock it instead. 76 | # This works around the inability to install libsoundfile1 in read-the-docs env, 77 | # which caused the documentation builds to silently crash. 78 | install_requires = [ 79 | req 80 | for req in install_requires 81 | if not any(req.startswith(dep) for dep in ["torchaudio", "SoundFile"]) 82 | ] 83 | 84 | setup( 85 | name="USLM", 86 | version='0.1.0', 87 | python_requires=">=3.7.0", 88 | description="USLM: Unified Speech Language Models", 89 | author="Dong Zhang", 90 | author_email="dongzhang22@fudan.edu.cn", 91 | long_description=(project_root / "README.md").read_text(encoding="utf-8"), 92 | long_description_content_type="text/markdown", 93 | license="Apache-2.0 License", 94 | packages=find_packages(exclude=["test", "test.*"]), 95 | include_package_data=True, 96 | entry_points={}, 97 | install_requires=install_requires, 98 | extras_require={ 99 | "docs": docs_require, 100 | "tests": tests_require, 101 | "dev": dev_requires, 102 | "all": all_requires, 103 | }, 104 | classifiers=[ 105 | "Development Status :: 1 - Beta", 106 | "Programming Language :: Python :: 3.7", 107 | "Programming Language :: Python :: 3.8", 108 | "Programming Language :: Python :: 3.9", 109 | "Programming Language :: Python :: 3.10", 110 | "Intended Audience :: Science/Research", 111 | "Operating System :: POSIX :: Linux", 112 | "Operating System :: MacOS :: MacOS X", 113 | "License :: OSI Approved :: Apache Software License", 114 | "Topic :: Multimedia :: Sound/Audio :: Speech", 115 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 116 | "Topic :: Software Development :: Libraries :: Python Modules", 117 | "Typing :: Typed", 118 | ], 119 | ) 120 | -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/utils/.DS_Store -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from icefall.utils import make_pad_mask 4 | 5 | from .symbol_table import SymbolTable 6 | 7 | make_pad_mask = make_pad_mask 8 | SymbolTable = SymbolTable 9 | 10 | 11 | class Transpose(nn.Identity): 12 | """(N, T, D) -> (N, D, T)""" 13 | 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | return input.transpose(1, 2) 16 | -------------------------------------------------------------------------------- /utils/symbol_table.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) 2 | # 3 | # See ../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from dataclasses import dataclass 18 | from dataclasses import field 19 | from typing import Dict 20 | from typing import Generic 21 | from typing import List 22 | from typing import Optional 23 | from typing import TypeVar 24 | from typing import Union 25 | 26 | Symbol = TypeVar('Symbol') 27 | 28 | 29 | # Disable __repr__ otherwise it could freeze e.g. Jupyter. 30 | @dataclass(repr=False) 31 | class SymbolTable(Generic[Symbol]): 32 | '''SymbolTable that maps symbol IDs, found on the FSA arcs to 33 | actual objects. These objects can be arbitrary Python objects 34 | that can serve as keys in a dictionary (i.e. they need to be 35 | hashable and immutable). 36 | 37 | The SymbolTable can only be read to/written from disk if the 38 | symbols are strings. 39 | ''' 40 | _id2sym: Dict[int, Symbol] = field(default_factory=dict) 41 | '''Map an integer to a symbol. 42 | ''' 43 | 44 | _sym2id: Dict[Symbol, int] = field(default_factory=dict) 45 | '''Map a symbol to an integer. 46 | ''' 47 | 48 | _next_available_id: int = 1 49 | '''A helper internal field that helps adding new symbols 50 | to the table efficiently. 51 | ''' 52 | 53 | eps: Symbol = '' 54 | '''Null symbol, always mapped to index 0. 55 | ''' 56 | 57 | def __post_init__(self): 58 | for idx, sym in self._id2sym.items(): 59 | assert self._sym2id[sym] == idx 60 | assert idx >= 0 61 | 62 | for sym, idx in self._sym2id.items(): 63 | assert idx >= 0 64 | assert self._id2sym[idx] == sym 65 | 66 | if 0 not in self._id2sym: 67 | self._id2sym[0] = self.eps 68 | self._sym2id[self.eps] = 0 69 | else: 70 | assert self._id2sym[0] == self.eps 71 | assert self._sym2id[self.eps] == 0 72 | 73 | self._next_available_id = max(self._id2sym) + 1 74 | 75 | @staticmethod 76 | def from_str(s: str) -> 'SymbolTable': 77 | '''Build a symbol table from a string. 78 | 79 | The string consists of lines. Every line has two fields separated 80 | by space(s), tab(s) or both. The first field is the symbol and the 81 | second the integer id of the symbol. 82 | 83 | Args: 84 | s: 85 | The input string with the format described above. 86 | Returns: 87 | An instance of :class:`SymbolTable`. 88 | ''' 89 | id2sym: Dict[int, str] = dict() 90 | sym2id: Dict[str, int] = dict() 91 | 92 | for line in s.split('\n'): 93 | fields = line.split() 94 | if len(fields) == 0: 95 | continue # skip empty lines 96 | assert len(fields) == 2, \ 97 | f'Expect a line with 2 fields. Given: {len(fields)}' 98 | sym, idx = fields[0], int(fields[1]) 99 | assert sym not in sym2id, f'Duplicated symbol {sym}' 100 | assert idx not in id2sym, f'Duplicated id {idx}' 101 | id2sym[idx] = sym 102 | sym2id[sym] = idx 103 | 104 | eps = id2sym.get(0, '') 105 | 106 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps) 107 | 108 | @staticmethod 109 | def from_file(filename: str) -> 'SymbolTable': 110 | '''Build a symbol table from file. 111 | 112 | Every line in the symbol table file has two fields separated by 113 | space(s), tab(s) or both. The following is an example file: 114 | 115 | .. code-block:: 116 | 117 | 0 118 | a 1 119 | b 2 120 | c 3 121 | 122 | Args: 123 | filename: 124 | Name of the symbol table file. Its format is documented above. 125 | 126 | Returns: 127 | An instance of :class:`SymbolTable`. 128 | 129 | ''' 130 | with open(filename, 'r', encoding='utf-8') as f: 131 | return SymbolTable.from_str(f.read().strip()) 132 | 133 | def to_str(self) -> str: 134 | ''' 135 | Returns: 136 | Return a string representation of this object. You can pass 137 | it to the method ``from_str`` to recreate an identical object. 138 | ''' 139 | s = '' 140 | for idx, symbol in sorted(self._id2sym.items()): 141 | s += f'{symbol} {idx}\n' 142 | return s 143 | 144 | def to_file(self, filename: str): 145 | '''Serialize the SymbolTable to a file. 146 | 147 | Every line in the symbol table file has two fields separated by 148 | space(s), tab(s) or both. The following is an example file: 149 | 150 | .. code-block:: 151 | 152 | 0 153 | a 1 154 | b 2 155 | c 3 156 | 157 | Args: 158 | filename: 159 | Name of the symbol table file. Its format is documented above. 160 | ''' 161 | with open(filename, 'w') as f: 162 | for idx, symbol in sorted(self._id2sym.items()): 163 | print(symbol, idx, file=f) 164 | 165 | def add(self, symbol: Symbol, index: Optional[int] = None) -> int: 166 | '''Add a new symbol to the SymbolTable. 167 | 168 | Args: 169 | symbol: 170 | The symbol to be added. 171 | index: 172 | Optional int id to which the symbol should be assigned. 173 | If it is not available, a ValueError will be raised. 174 | 175 | Returns: 176 | The int id to which the symbol has been assigned. 177 | ''' 178 | # Already in the table? Return its ID. 179 | if symbol in self._sym2id: 180 | return self._sym2id[symbol] 181 | # Specific ID not provided - use next available. 182 | if index is None: 183 | index = self._next_available_id 184 | # Specific ID provided but not available. 185 | if index in self._id2sym: 186 | raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - " 187 | f"already occupied by {self._id2sym[index]}") 188 | self._sym2id[symbol] = index 189 | self._id2sym[index] = symbol 190 | 191 | # Update next available ID if needed 192 | if self._next_available_id <= index: 193 | self._next_available_id = index + 1 194 | 195 | return index 196 | 197 | def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]: 198 | '''Get a symbol for an id or get an id for a symbol 199 | 200 | Args: 201 | k: 202 | If it is an id, it tries to find the symbol corresponding 203 | to the id; if it is a symbol, it tries to find the id 204 | corresponding to the symbol. 205 | 206 | Returns: 207 | An id or a symbol depending on the given `k`. 208 | ''' 209 | if isinstance(k, int): 210 | return self._id2sym[k] 211 | else: 212 | return self._sym2id[k] 213 | 214 | def merge(self, other: 'SymbolTable') -> 'SymbolTable': 215 | '''Create a union of two SymbolTables. 216 | Raises an AssertionError if the same IDs are occupied by 217 | different symbols. 218 | 219 | Args: 220 | other: 221 | A symbol table to merge with ``self``. 222 | 223 | Returns: 224 | A new symbol table. 225 | ''' 226 | self._check_compatible(other) 227 | 228 | id2sym = {**self._id2sym, **other._id2sym} 229 | sym2id = {**self._sym2id, **other._sym2id} 230 | 231 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps) 232 | 233 | def _check_compatible(self, other: 'SymbolTable') -> None: 234 | # Epsilon compatibility 235 | assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \ 236 | f'{self.eps} != {other.eps}' 237 | # IDs compatibility 238 | common_ids = set(self._id2sym).intersection(other._id2sym) 239 | for idx in common_ids: 240 | assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \ 241 | f'self[idx] = "{self[idx]}", ' \ 242 | f'other[idx] = "{other[idx]}"' 243 | # Symbols compatibility 244 | common_symbols = set(self._sym2id).intersection(other._sym2id) 245 | for sym in common_symbols: 246 | assert self[sym] == other[sym], f'ID conflict for id: {sym}, ' \ 247 | f'self[sym] = "{self[sym]}", ' \ 248 | f'other[sym] = "{other[sym]}"' 249 | 250 | def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]: 251 | return self.get(item) 252 | 253 | def __contains__(self, item: Union[int, Symbol]) -> bool: 254 | if isinstance(item, int): 255 | return item in self._id2sym 256 | else: 257 | return item in self._sym2id 258 | 259 | def __len__(self) -> int: 260 | return len(self._id2sym) 261 | 262 | def __eq__(self, other: 'SymbolTable') -> bool: 263 | if len(self) != len(other): 264 | return False 265 | 266 | for s in self.symbols: 267 | if self[s] != other[s]: 268 | return False 269 | 270 | return True 271 | 272 | @property 273 | def ids(self) -> List[int]: 274 | '''Returns a list of integer IDs corresponding to the symbols. 275 | ''' 276 | ans = list(self._id2sym.keys()) 277 | ans.sort() 278 | return ans 279 | 280 | @property 281 | def symbols(self) -> List[Symbol]: 282 | '''Returns a list of symbols (e.g., strings) corresponding to 283 | the integer IDs. 284 | ''' 285 | ans = list(self._sym2id.keys()) 286 | ans.sort() 287 | return ans 288 | --------------------------------------------------------------------------------