├── README.md
├── bin
│   ├── __init__.py
│   └── infer.py
├── data
│   ├── .DS_Store
│   ├── __init__.py
│   ├── collation.py
│   ├── datamodule.py
│   ├── dataset.py
│   ├── fbank.py
│   ├── input_strategies.py
│   ├── speechtokenizer.py
│   └── tokenizer.py
├── images
│   ├── README.md
│   └── overview.png
├── inference.sh
├── models
│   ├── .DS_Store
│   ├── __init__.py
│   ├── macros.py
│   ├── transformer.py
│   ├── uslm.py
│   ├── valle.py
│   └── visualizer.py
├── modules
│   ├── .DS_Store
│   ├── __init__.py
│   ├── activation.py
│   ├── embedding.py
│   ├── optim.py
│   ├── scaling.py
│   ├── scheduler.py
│   └── transformer.py
├── prompts
│   ├── 1580_141083_000002_000002.normalized.txt
│   └── 1580_141083_000002_000002.wav
├── requirements.txt
├── setup.py
└── utils
    ├── .DS_Store
    ├── __init__.py
    └── symbol_table.py
/README.md:
--------------------------------------------------------------------------------
1 | # USLM: Unified Speech Language Model
2 |
3 |
4 | ## Introduction
5 | Built upon [SpeechTokenizer](https://github.com/ZhangXInFD/SpeechTokenizer), USLM consists of autoregressive and non-autoregressive models that hierarchically model the information in speech. The autoregressive (AR) model captures the content information by modeling tokens from the first RVQ quantizer. The non-autoregressive (NAR) model complements the AR model with paralinguistic information by generating tokens from the subsequent quantizers conditioned on the first-layer tokens.
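Conceptually, generation proceeds in two passes over the RVQ token hierarchy. The sketch below is illustrative only (it is not the implementation in `models/uslm.py`; `ar_model` and `nar_model` are hypothetical stand-ins):

```python
# Conceptual sketch of hierarchical generation, assuming 8 RVQ quantizers
# as in SpeechTokenizer. `ar_model` and `nar_model` are placeholders.
import torch

def generate(ar_model, nar_model, text_tokens: torch.Tensor, num_quantizers: int = 8):
    # AR pass: tokens of the first quantizer carry the content.
    first_layer = ar_model(text_tokens)                       # [B, T]
    codes = [first_layer]
    # NAR pass: quantizers 2..8 add paralinguistic detail, each conditioned
    # on the text and the first-layer tokens.
    for q in range(1, num_quantizers):
        codes.append(nar_model(text_tokens, first_layer, layer=q))
    return torch.stack(codes, dim=0)                          # [num_quantizers, B, T]
```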
6 |
7 | ![Overview](images/overview.png)
8 |
13 | ## Installation
14 |
15 | To get up and running quickly just follow the steps below:
16 |
17 | ``` bash
18 | # PyTorch
19 | pip install torch==1.13.1 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
20 | pip install torchmetrics==0.11.1
21 | # fbank
22 | pip install librosa==0.8.1
23 |
24 | # phonemizer pypinyin
25 | apt-get install espeak-ng
26 | ## OSX: brew install espeak
27 | pip install phonemizer==3.2.1 pypinyin==0.48.0
28 |
29 | # lhotse update to newest version
30 | # https://github.com/lhotse-speech/lhotse/pull/956
31 | # https://github.com/lhotse-speech/lhotse/pull/960
32 | pip uninstall lhotse
33 | pip install git+https://github.com/lhotse-speech/lhotse
34 |
35 | # k2
36 | # find the right version in https://huggingface.co/csukuangfj/k2
37 | pip install https://huggingface.co/csukuangfj/k2/resolve/main/cuda/k2-1.23.4.dev20230224+cuda11.6.torch1.13.1-cp310-cp310-linux_x86_64.whl
38 |
39 | # icefall
40 | git clone https://github.com/k2-fsa/icefall
41 | cd icefall
42 | pip install -r requirements.txt
43 | export PYTHONPATH=`pwd`/../icefall:$PYTHONPATH
44 | echo "export PYTHONPATH=`pwd`/../icefall:\$PYTHONPATH" >> ~/.zshrc
45 | echo "export PYTHONPATH=`pwd`/../icefall:\$PYTHONPATH" >> ~/.bashrc
46 | cd -
47 | source ~/.zshrc
48 |
49 | #SpeechTokenizer
50 | pip install -U speechtokenizer
51 |
52 | # uslm
53 | git clone https://github.com/0nutation/USLM
54 | cd USLM
55 | pip install -e .
56 | ```
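After installation, a quick way to confirm that the main dependencies resolve (a minimal check, assuming all steps above completed and `icefall` is on `PYTHONPATH`):

```python
# Import check only; prints the PyTorch version and CUDA availability.
import torch, torchaudio, lhotse, k2, phonemizer, speechtokenizer, icefall, uslm
print("torch", torch.__version__, "| cuda:", torch.cuda.is_available())
```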
57 |
58 | ## USLM Models
59 | This version of USLM was trained only on the LibriTTS dataset, so its performance is limited by the available training data.
60 |
61 | | Model | Dataset | Description |
62 | |:----|:----:|:----|
63 | | [USLM_libri](https://huggingface.co/fnlp/USLM/tree/main/USLM_libritts) | LibriTTS | USLM trained on the LibriTTS dataset |
64 |
65 |
66 | ## Zero-shot TTS Using USLM
67 | Download pre-trained SpeechTokenizer models:
68 | ``` bash
69 | st_dir="ckpt/speechtokenizer/"
70 | mkdir -p ${st_dir}
71 | cd ${st_dir}
72 | wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/SpeechTokenizer.pt"
73 | wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/config.json"
74 | cd -
75 | ```
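Optionally, you can verify the downloaded checkpoint from Python by round-tripping one of the prompt utterances (a minimal sketch mirroring `data/speechtokenizer.py`; the output filename is arbitrary):

```python
import torch
import torchaudio
from encodec.utils import convert_audio
from speechtokenizer import SpeechTokenizer

# Load the files downloaded into ckpt/speechtokenizer/ above.
st = SpeechTokenizer.load_from_checkpoint(
    "ckpt/speechtokenizer/config.json", "ckpt/speechtokenizer/SpeechTokenizer.pt"
)
st.eval()

wav, sr = torchaudio.load("prompts/1580_141083_000002_000002.wav")
wav = convert_audio(wav, sr, st.sample_rate, 1).unsqueeze(0)  # [1, 1, T]
with torch.no_grad():
    codes = st.encode(wav)    # RVQ token indices, [n_q, B, T']
    recon = st.decode(codes)  # waveform reconstructed from all quantizers
torchaudio.save("speechtokenizer_recon.wav", recon.squeeze(0).cpu(), st.sample_rate)
```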
76 |
77 | Download pre-trained USLM models:
78 | ``` bash
79 | uslm_dir="ckpt/uslm/"
80 | mkdir -p ${uslm_dir}
81 | cd ${uslm_dir}
82 | wget "https://huggingface.co/fnlp/USLM/resolve/main/USLM_libritts/USLM.pt"
83 | wget "https://huggingface.co/fnlp/USLM/resolve/main/USLM_libritts/unique_text_tokens.k2symbols"
84 | cd -
85 | ```
86 |
87 | Inference:
88 | ``` bash
89 | out_dir="output/"
90 | mkdir -p ${out_dir}
91 |
92 | python3 bin/infer.py --output-dir ${out_dir}/ \
93 | --model-name uslm --norm-first true --add-prenet false \
94 | --share-embedding true \
95 | --audio-extractor SpeechTokenizer \
96 | --speechtokenizer-dir "${st_dir}" \
97 | --checkpoint=${uslm_dir}/USLM.pt \
98 | --text-tokens "${uslm_dir}/unique_text_tokens.k2symbols" \
99 | --text-prompts "mr Soames was a tall, spare man, of a nervous and excitable temperament." \
100 | --audio-prompts prompts/1580_141083_000002_000002.wav \
101 | --text "Begin with the fundamental steps of the process. This will give you a solid foundation to build upon and boost your confidence. "
102 | ```
103 |
104 | Or you can directly run `inference.sh`:
105 | ``` bash
106 | bash inference.sh
107 | ```
108 |
109 | ## Acknowledgements
110 | [VALL-E](https://github.com/lifeiteng/vall-e): The codebase we build upon.
111 |
112 | ## Citation
113 | If you use this code or these results in your paper, please cite our work as:
114 | ```Tex
115 | @misc{zhang2023speechtokenizer,
116 | title={SpeechTokenizer: Unified Speech Tokenizer for Speech Language Models},
117 | author={Xin Zhang and Dong Zhang and Shimin Li and Yaqian Zhou and Xipeng Qiu},
118 | year={2023},
119 | eprint={2308.16692},
120 | archivePrefix={arXiv},
121 | primaryClass={cs.CL}
122 | }
123 | ```
124 |
--------------------------------------------------------------------------------
/bin/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/bin/__init__.py
--------------------------------------------------------------------------------
/bin/infer.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Fri. Sept. 8 00:43:42 2023
3 | @author: Dong Zhang
4 | """
5 |
6 | import argparse
7 | import logging
8 | import os
9 | from pathlib import Path
10 | from tqdm import tqdm
11 | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
12 |
13 | import torch
14 | import torchaudio
15 |
16 | from uslm.data import (
17 | AudioTokenizer,
18 | TextTokenizer,
19 | tokenize_audio,
20 | tokenize_text,
21 | Speechtokenizer,
22 | sttokenize_audio
23 | )
24 | from uslm.data.collation import get_text_token_collater
25 | from uslm.models import add_model_arguments, get_model
26 |
27 |
28 | def circular_padding(x, tgt_len):  # tile x circularly along the last (time) dim to exactly tgt_len frames
29 | if x.size(-1) >= tgt_len:
30 | return x[:, :, :tgt_len]
31 | t = tgt_len // x.size(-1)
32 | r = tgt_len % x.size(-1)
33 | tgt = x.repeat(1, 1, t)
34 |     tgt = torch.cat([tgt, x[:, :, :r]], dim=-1)
35 | return tgt
36 |
37 |
38 | def get_args():
39 | parser = argparse.ArgumentParser()
40 |
41 | parser.add_argument(
42 | "--text-prompts",
43 | type=str,
44 | default="",
45 | help="Text prompts which are separated by |.",
46 | )
47 |
48 | parser.add_argument(
49 | "--audio-prompts",
50 | type=str,
51 | default="",
52 | help="Audio prompts which are separated by | and should be aligned with --text-prompts.",
53 | )
54 |
55 | parser.add_argument(
56 | "--text",
57 | type=str,
58 | default="To get up and running quickly just follow the steps below.",
59 | help="Text to be synthesized.",
60 | )
61 |
62 | parser.add_argument(
63 | "--audio-extractor",
64 | type=str,
65 |         default="EnCodec",
66 |         help="EnCodec or SpeechTokenizer or Fbank",
67 | )
68 |
69 | # model
70 | add_model_arguments(parser)
71 |
72 | parser.add_argument(
73 | "--text-tokens",
74 | type=str,
75 | default="data/tokenized/unique_text_tokens.k2symbols",
76 | help="Path to the unique text tokens file.",
77 | )
78 | parser.add_argument(
79 | "--text-extractor",
80 | type=str,
81 | default="espeak",
82 | help="espeak or pypinyin or pypinyin_initials_finals",
83 | )
84 |
85 | parser.add_argument(
86 | "--checkpoint",
87 | type=str,
88 | default="exp/vallf_nano_full/checkpoint-100000.pt",
89 | help="Path to the saved checkpoint.",
90 | )
91 |
92 | parser.add_argument(
93 | "--output-dir",
94 | type=Path,
95 | default=Path("infer/demo"),
96 |         help="Directory where the generated audio will be written.",
97 | )
98 |
99 | parser.add_argument(
100 | "--top-k",
101 | type=int,
102 | default=-100,
103 |         help="Whether the AR decoder uses top-k sampling (enabled if > 0).",
104 | )
105 |
106 | parser.add_argument(
107 | "--temperature",
108 | type=float,
109 | default=1.0,
110 |         help="The temperature of the AR decoder's top-k sampling.",
111 | )
112 |
113 | parser.add_argument(
114 | "--without-nar",
115 | default=False,
116 |         help="If set, skip the NAR stage and take the residual quantizer layers from the audio prompt instead.",
117 | )
118 | parser.add_argument(
119 | "--speechtokenizer-dir",
120 | type=str,
121 | default="False",
122 |         help="Directory containing the SpeechTokenizer checkpoint and config.json.",
123 | )
124 |
125 | return parser.parse_args()
126 |
127 |
128 | @torch.no_grad()
129 | def main():
130 | args = get_args()
131 | text_tokenizer = TextTokenizer(backend=args.text_extractor)
132 | text_collater = get_text_token_collater(args.text_tokens)
133 |
134 | if args.audio_extractor == "EnCodec":
135 | sr = 24000
136 | audio_tokenizer = AudioTokenizer()
137 | tokenize_a = tokenize_audio
138 | elif args.audio_extractor == "SpeechTokenizer":
139 | sr = 16000
140 | audio_tokenizer = Speechtokenizer(ckpt_dir=args.speechtokenizer_dir)
141 | tokenize_a = sttokenize_audio
142 |
143 |
144 | device = torch.device("cpu")
145 | if torch.cuda.is_available():
146 | device = torch.device("cuda", 0)
147 |
148 | model = get_model(args)
149 | if args.checkpoint:
150 | checkpoint = torch.load(args.checkpoint, map_location=device)
151 | missing_keys, unexpected_keys = model.load_state_dict(
152 | checkpoint["model"], strict=True
153 | )
154 | assert not missing_keys
155 |
156 |
157 | model.to(device)
158 | model.eval()
159 |
160 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
161 |
162 |     if os.path.isfile(args.text):  # demo mode: each line is "prompt_text|prompt_audio|text|output_audio_path"
163 |
164 | with open(args.text) as f:
165 | for line in tqdm(f):
166 | fields = line.strip().split("|")
167 | assert len(fields) == 4
168 | prompt_text, prompt_audio, text, audio_path = fields
169 | logging.info(f"synthesize text: {text}")
170 | os.makedirs(os.path.dirname(audio_path),exist_ok=True)
171 | text_tokens, text_tokens_lens = text_collater(
172 | [
173 | tokenize_text(
174 | text_tokenizer, text=f"{prompt_text} {text}".strip()
175 | )
176 | ]
177 | )
178 | _, enroll_x_lens = text_collater(
179 | [
180 | tokenize_text(
181 | text_tokenizer, text=f"{prompt_text}".strip()
182 | )
183 | ]
184 | )
185 |
186 | audio_prompts = tokenize_a(audio_tokenizer, prompt_audio)
187 | if args.audio_extractor == "SpeechTokenizer":
188 | audio_prompts = audio_prompts.permute(1, 2, 0).to(device) #[b,t,8]
189 | elif args.audio_extractor == "EnCodec":
190 | audio_prompts = audio_prompts[0][0].transpose(2,1)
191 |
192 | # synthesis
193 | encoded_frames = model.inference(
194 | text_tokens.to(device),
195 | text_tokens_lens.to(device),
196 | audio_prompts,
197 | enroll_x_lens=enroll_x_lens,
198 | top_k=args.top_k,
199 | temperature=args.temperature,
200 | )
201 |
202 | if args.audio_extractor == "SpeechTokenizer":
203 | code_generated = encoded_frames.permute(2,0,1) #[8,b,T]
204 | else:
205 | code_generated = [(encoded_frames.transpose(2, 1), None)]
206 |
207 |
208 | if args.without_nar:
209 | audio_prompts = circular_padding(audio_prompts.permute(2,0,1), code_generated.shape[-1])
210 | code_generated = torch.cat((code_generated[:1,:,:], audio_prompts[1:4,:,:]),dim=0)
211 |
212 | samples = audio_tokenizer.decode(
213 | code_generated
214 | )
215 | # store
216 | torchaudio.save(audio_path, samples[0].cpu(), sr)
217 | return
218 |
219 |
220 |
221 |
222 | text_prompts = " ".join(args.text_prompts.split("|"))
223 |
224 | audio_prompts = []
225 | if args.audio_prompts:
226 | for n, audio_file in enumerate(args.audio_prompts.split("|")):
227 | encoded_frames = tokenize_a(audio_tokenizer, audio_file)
228 | if False:
229 | samples = audio_tokenizer.decode(encoded_frames)
230 | torchaudio.save(
231 | f"{args.output_dir}/p{n}.wav", samples[0], sr
232 | )
233 |
234 | if args.audio_extractor == "EnCodec":
235 | audio_prompts.append(encoded_frames[0][0])
236 | elif args.audio_extractor == "SpeechTokenizer":
237 | audio_prompts.append(encoded_frames.permute(1,0,2))
238 |
239 |
240 | assert len(args.text_prompts.split("|")) == len(audio_prompts)
241 | audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1)
242 | audio_prompts = audio_prompts.to(device)
243 |
244 |
245 |
246 | for n, text in enumerate(args.text.split("|")):
247 | logging.info(f"synthesize text: {text}")
248 | text_tokens, text_tokens_lens = text_collater(
249 | [
250 | tokenize_text(
251 | text_tokenizer, text=f"{text_prompts} {text}".strip()
252 | )
253 | ]
254 | )
255 |
256 | # synthesis
257 |
258 | enroll_x_lens = None
259 | if text_prompts:
260 | _, enroll_x_lens = text_collater(
261 | [
262 | tokenize_text(
263 | text_tokenizer, text=f"{text_prompts}".strip()
264 | )
265 | ]
266 | )
267 | encoded_frames = model.inference(
268 | text_tokens.to(device),
269 | text_tokens_lens.to(device),
270 | audio_prompts,
271 | enroll_x_lens=enroll_x_lens,
272 | top_k=args.top_k,
273 | temperature=args.temperature,
274 | )
275 |
276 |
277 | if audio_prompts != []:
278 | if args.audio_extractor == "SpeechTokenizer":
279 | code_generated = encoded_frames.permute(2,0,1)
280 | else:
281 | code_generated = [(encoded_frames.transpose(2, 1), None)]
282 | samples = audio_tokenizer.decode(
283 | code_generated
284 | )
285 | # store
286 | idx = args.audio_prompts.split('|')[n]
287 | torchaudio.save(
288 | f"{args.output_dir}/gen-{os.path.basename(idx).replace('flac','wav')}", samples[0].cpu(), sr
289 | )
290 | else: # Transformer
291 | pass
292 |
293 |
294 | torch.set_num_threads(1)
295 | torch.set_num_interop_threads(1)
296 | torch._C._jit_set_profiling_executor(False)
297 | torch._C._jit_set_profiling_mode(False)
298 | torch._C._set_graph_executor_optimize(False)
299 | if __name__ == "__main__":
300 | formatter = (
301 | "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
302 | )
303 | logging.basicConfig(format=formatter, level=logging.INFO)
304 | main()
305 |
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/data/.DS_Store
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .datamodule import *
2 | from .tokenizer import *
3 | from .collation import *
4 | from .speechtokenizer import *
5 |
--------------------------------------------------------------------------------
/data/collation.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List, Tuple
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from uslm.utils import SymbolTable
8 |
9 |
10 | class TextTokenCollater:
11 | """Collate list of text tokens
12 |
13 | Map sentences to integers. Sentences are padded to equal length.
14 | Beginning and end-of-sequence symbols can be added.
15 |
16 | Example:
17 | >>> token_collater = TextTokenCollater(text_tokens)
18 | >>> tokens_batch, tokens_lens = token_collater(text)
19 |
20 | Returns:
21 | tokens_batch: IntTensor of shape (B, L)
22 | B: batch dimension, number of input sentences
23 | L: length of the longest sentence
24 | tokens_lens: IntTensor of shape (B,)
25 |             Length of each sentence after adding <bos> and <eos>
26 |             but before padding.
27 | """
28 |
29 | def __init__(
30 | self,
31 | text_tokens: List[str],
32 | add_eos: bool = True,
33 | add_bos: bool = True,
34 |         pad_symbol: str = "<pad>",
35 |         bos_symbol: str = "<bos>",
36 |         eos_symbol: str = "<eos>",
37 | ):
38 | self.pad_symbol = pad_symbol
39 |
40 | self.add_eos = add_eos
41 | self.add_bos = add_bos
42 |
43 | self.bos_symbol = bos_symbol
44 | self.eos_symbol = eos_symbol
45 |
46 | unique_tokens = (
47 | [pad_symbol]
48 | + ([bos_symbol] if add_bos else [])
49 | + ([eos_symbol] if add_eos else [])
50 | + sorted(text_tokens)
51 | )
52 |
53 | self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
54 | self.idx2token = [token for token in unique_tokens]
55 |
56 | def index(
57 | self, tokens_list: List[str]
58 | ) -> Tuple[torch.Tensor, torch.Tensor]:
59 | seqs, seq_lens = [], []
60 | for tokens in tokens_list:
61 | assert (
62 | all([True if s in self.token2idx else False for s in tokens])
63 | is True
64 | )
65 | seq = (
66 | ([self.bos_symbol] if self.add_bos else [])
67 | + list(tokens)
68 | + ([self.eos_symbol] if self.add_eos else [])
69 | )
70 | seqs.append(seq)
71 | seq_lens.append(len(seq))
72 |
73 | max_len = max(seq_lens)
74 | for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
75 | seq.extend([self.pad_symbol] * (max_len - seq_len))
76 |
77 | tokens = torch.from_numpy(
78 | np.array(
79 | [[self.token2idx[token] for token in seq] for seq in seqs],
80 | dtype=np.int64,
81 | )
82 | )
83 | tokens_lens = torch.IntTensor(seq_lens)
84 |
85 | return tokens, tokens_lens
86 |
87 | def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
88 | tokens_seqs = [[p for p in text] for text in texts]
89 | max_len = len(max(tokens_seqs, key=len))
90 |
91 | seqs = [
92 | ([self.bos_symbol] if self.add_bos else [])
93 | + list(seq)
94 | + ([self.eos_symbol] if self.add_eos else [])
95 | + [self.pad_symbol] * (max_len - len(seq))
96 | for seq in tokens_seqs
97 | ]
98 |
99 | # seqs_filtered = []
100 | # flag=True
101 | # for seq in seqs:
102 | # for token in seq:
103 | # if token not in self.token2idx:
104 | # flag = False
105 | # break
106 | # else:
107 | # flag = True
108 | # if flag:
109 | # seqs_filtered.append(seq)
110 | # seqs = seqs_filtered
111 |
112 | tokens_batch = torch.from_numpy(
113 | np.array(
114 | [[self.token2idx[token] for token in seq] for seq in seqs],
115 | dtype=np.int64,
116 | )
117 | )
118 |
119 | tokens_lens = torch.IntTensor(
120 | [
121 | len(seq) + int(self.add_eos) + int(self.add_bos)
122 | for seq in tokens_seqs
123 | ]
124 | )
125 |
126 | return tokens_batch, tokens_lens
127 |
128 |
129 | def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
130 | text_tokens_path = Path(text_tokens_file)
131 | unique_tokens = SymbolTable.from_file(text_tokens_path)
132 | collater = TextTokenCollater(
133 | unique_tokens.symbols, add_bos=True, add_eos=True
134 | )
135 | return collater
136 |
--------------------------------------------------------------------------------
/data/datamodule.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # See ../../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | import argparse
19 | import inspect
20 | import logging
21 | from functools import lru_cache
22 | from pathlib import Path
23 | from typing import Any, Dict, Optional
24 |
25 | import torch
26 | from icefall.utils import str2bool
27 | from lhotse import CutSet, load_manifest_lazy
28 | from lhotse.dataset import (
29 | CutConcatenate,
30 | DynamicBucketingSampler,
31 | PrecomputedFeatures,
32 | SingleCutSampler,
33 | SpecAugment,
34 | )
35 | from lhotse.dataset.input_strategies import OnTheFlyFeatures
36 | from lhotse.utils import fix_random_seed
37 | from torch.utils.data import DataLoader
38 |
39 | from uslm.data.collation import get_text_token_collater
40 | from uslm.data.dataset import SpeechSynthesisDataset
41 | from uslm.data.fbank import get_fbank_extractor
42 | from uslm.data.input_strategies import PromptedPrecomputedFeatures
43 |
44 | PrecomputedFeatures = PrecomputedFeatures
45 |
46 |
47 | class _SeedWorkers:
48 | def __init__(self, seed: int):
49 | self.seed = seed
50 |
51 | def __call__(self, worker_id: int):
52 | fix_random_seed(self.seed + worker_id)
53 |
54 |
55 | def _get_input_strategy(input_strategy, dataset, cuts):
56 | if input_strategy == "PromptedPrecomputedFeatures":
57 | return PromptedPrecomputedFeatures(dataset, cuts)
58 |
59 | return eval(input_strategy)()
60 |
61 |
62 | class TtsDataModule:
63 | """
64 | DataModule for VALL-E TTS experiments.
65 | It assumes there is always one train and valid dataloader.
66 |
67 | It contains all the common data pipeline modules used in TTS
68 | experiments, e.g.:
69 | - dynamic batch size,
70 | - bucketing samplers,
71 | - cut concatenation[not used & tested yet],
72 | - augmentation[not used & tested yet],
73 | - on-the-fly feature extraction[not used & tested yet]
74 |
75 | This class should be derived for specific corpora used in TTS tasks.
76 | """
77 |
78 | def __init__(self, args: argparse.Namespace):
79 | self.args = args
80 |
81 | @classmethod
82 | def add_arguments(cls, parser: argparse.ArgumentParser):
83 | group = parser.add_argument_group(
84 | title="TTS data related options",
85 | description="These options are used for the preparation of "
86 | "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
87 | "effective batch sizes, sampling strategies, applied data "
88 | "augmentations, etc.",
89 | )
90 | group.add_argument(
91 | "--manifest-dir",
92 | type=Path,
93 | default=Path("data/tokenized"),
94 | help="Path to directory with train/valid/test cuts.",
95 | )
96 | group.add_argument(
97 | "--max-duration",
98 | type=int,
99 | default=40.0,
100 | help="Maximum pooled recordings duration (seconds) in a "
101 | "single batch. You can reduce it if it causes CUDA OOM.",
102 | )
103 | group.add_argument(
104 | "--bucketing-sampler",
105 | type=str2bool,
106 | default=True,
107 | help="When enabled, the batches will come from buckets of "
108 | "similar duration (saves padding frames).",
109 | )
110 | group.add_argument(
111 | "--num-buckets",
112 | type=int,
113 | default=10,
114 |             help="The number of buckets for the DynamicBucketingSampler "
115 | "(you might want to increase it for larger datasets).",
116 | )
117 | group.add_argument(
118 | "--concatenate-cuts",
119 | type=str2bool,
120 | default=False,
121 | help="When enabled, utterances (cuts) will be concatenated "
122 | "to minimize the amount of padding.",
123 | )
124 | group.add_argument(
125 | "--duration-factor",
126 | type=float,
127 | default=1.0,
128 | help="Determines the maximum duration of a concatenated cut "
129 | "relative to the duration of the longest cut in a batch.",
130 | )
131 | group.add_argument(
132 | "--gap",
133 | type=float,
134 | default=0.1,
135 | help="The amount of padding (in seconds) inserted between "
136 | "concatenated cuts. This padding is filled with noise when "
137 | "noise augmentation is used.",
138 | )
139 | group.add_argument(
140 | "--on-the-fly-feats",
141 | type=str2bool,
142 | default=False,
143 | help="When enabled, use on-the-fly cut mixing and feature "
144 | "extraction. Will drop existing precomputed feature manifests "
145 | "if available.",
146 | )
147 | group.add_argument(
148 | "--shuffle",
149 | type=str2bool,
150 | default=True,
151 | help="When enabled (=default), the examples will be "
152 | "shuffled for each epoch.",
153 | )
154 | group.add_argument(
155 | "--drop-last",
156 | type=str2bool,
157 | default=False,
158 | help="Whether to drop last batch. Used by sampler.",
159 | )
160 | group.add_argument(
161 | "--return-cuts",
162 | type=str2bool,
163 | default=True,
164 | help="When enabled, each batch will have the "
165 | "field: batch['supervisions']['cut'] with the cuts that "
166 | "were used to construct it.",
167 | )
168 |
169 | group.add_argument(
170 | "--num-workers",
171 | type=int,
172 | default=8,
173 | help="The number of training dataloader workers that "
174 | "collect the batches.",
175 | )
176 |
177 | group.add_argument(
178 | "--enable-spec-aug",
179 | type=str2bool,
180 | default=False,
181 | help="When enabled, use SpecAugment for training dataset.",
182 | )
183 |
184 | group.add_argument(
185 | "--spec-aug-time-warp-factor",
186 | type=int,
187 | default=80,
188 | help="Used only when --enable-spec-aug is True. "
189 | "It specifies the factor for time warping in SpecAugment. "
190 | "Larger values mean more warping. "
191 | "A value less than 1 means to disable time warp.",
192 | )
193 |
194 | group.add_argument(
195 | "--input-strategy",
196 | type=str,
197 | default="PrecomputedFeatures",
198 | help="AudioSamples or PrecomputedFeatures or PromptedPrecomputedFeatures",
199 | )
200 |
201 | group.add_argument(
202 | "--dataset",
203 | type=str,
204 | default="libritts",
205 | help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
206 | )
207 |
208 | parser.add_argument(
209 | "--text-tokens",
210 | type=str,
211 | default="data/tokenized/unique_text_tokens.k2symbols",
212 | help="Path to the unique text tokens file",
213 | )
214 |
215 | parser.add_argument(
216 | "--sampling-rate",
217 | type=int,
218 | default=24000,
219 | help="""Audio sampling rate.""",
220 | )
221 |
222 | def train_dataloaders(
223 | self,
224 | cuts_train: CutSet,
225 | sampler_state_dict: Optional[Dict[str, Any]] = None,
226 | ) -> DataLoader:
227 | """
228 | Args:
229 | cuts_train:
230 | CutSet for training.
231 | sampler_state_dict:
232 | The state dict for the training sampler.
233 | """
234 | transforms = []
235 |
236 | if self.args.concatenate_cuts:
237 | logging.info(
238 | f"Using cut concatenation with duration factor "
239 | f"{self.args.duration_factor} and gap {self.args.gap}."
240 | )
241 | # Cut concatenation should be the first transform in the list,
242 | # so that if we e.g. mix noise in, it will fill the gaps between
243 | # different utterances.
244 | transforms = [
245 | CutConcatenate(
246 | duration_factor=self.args.duration_factor, gap=self.args.gap
247 | )
248 | ] + transforms
249 |
250 | input_transforms = []
251 | if self.args.enable_spec_aug:
252 | logging.info("Enable SpecAugment")
253 | logging.info(
254 | f"Time warp factor: {self.args.spec_aug_time_warp_factor}"
255 | )
256 | # Set the value of num_frame_masks according to Lhotse's version.
257 | # In different Lhotse's versions, the default of num_frame_masks is
258 | # different.
259 | num_frame_masks = 10
260 | num_frame_masks_parameter = inspect.signature(
261 | SpecAugment.__init__
262 | ).parameters["num_frame_masks"]
263 | if num_frame_masks_parameter.default == 1:
264 | num_frame_masks = 2
265 | logging.info(f"Num frame mask: {num_frame_masks}")
266 | input_transforms.append(
267 | SpecAugment(
268 | time_warp_factor=self.args.spec_aug_time_warp_factor,
269 | num_frame_masks=num_frame_masks,
270 | features_mask_size=27,
271 | num_feature_masks=2,
272 | frames_mask_size=100,
273 | )
274 | )
275 | else:
276 | logging.info("Disable SpecAugment")
277 |
278 | logging.info("About to create train dataset")
279 | if self.args.on_the_fly_feats:
280 | # NOTE: the PerturbSpeed transform should be added only if we
281 | # remove it from data prep stage.
282 | # Add on-the-fly speed perturbation; since originally it would
283 | # have increased epoch size by 3, we will apply prob 2/3 and use
284 | # 3x more epochs.
285 | # Speed perturbation probably should come first before
286 | # concatenation, but in principle the transforms order doesn't have
287 | # to be strict (e.g. could be randomized)
288 | # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
289 | # Drop feats to be on the safe side.
290 | train = SpeechSynthesisDataset(
291 | get_text_token_collater(self.args.text_tokens),
292 | cut_transforms=transforms,
293 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
294 | feature_transforms=input_transforms,
295 | )
296 | else:
297 | train = SpeechSynthesisDataset(
298 | get_text_token_collater(self.args.text_tokens),
299 | feature_input_strategy=_get_input_strategy(
300 | self.args.input_strategy, self.args.dataset, cuts_train
301 | ),
302 | cut_transforms=transforms,
303 | feature_transforms=input_transforms,
304 | )
305 |
306 | if self.args.bucketing_sampler:
307 | logging.info("Using DynamicBucketingSampler")
308 | train_sampler = DynamicBucketingSampler(
309 | cuts_train,
310 | max_duration=self.args.max_duration,
311 | shuffle=self.args.shuffle,
312 | num_buckets=self.args.num_buckets,
313 | drop_last=True,
314 | )
315 | else:
316 | logging.info(
317 |                 "Using SingleCutSampler and sort by duration (ascending=True)."
318 | )
319 | cuts_train = cuts_train.to_eager().sort_by_duration(ascending=True)
320 | train_sampler = SingleCutSampler(
321 | cuts_train,
322 | max_duration=self.args.max_duration,
323 | shuffle=self.args.shuffle,
324 | )
325 | logging.info("About to create train dataloader")
326 |
327 | if sampler_state_dict is not None:
328 | logging.info("Loading sampler state dict")
329 | train_sampler.load_state_dict(sampler_state_dict)
330 |
331 | # 'seed' is derived from the current random state, which will have
332 | # previously been set in the main process.
333 | seed = torch.randint(0, 100000, ()).item()
334 | worker_init_fn = _SeedWorkers(seed)
335 |
336 | train_dl = DataLoader(
337 | train,
338 | sampler=train_sampler,
339 | batch_size=None,
340 | num_workers=self.args.num_workers,
341 | persistent_workers=False,
342 | worker_init_fn=worker_init_fn,
343 | )
344 |
345 | return train_dl
346 |
347 | def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
348 | logging.info("About to create dev dataset")
349 | if self.args.on_the_fly_feats:
350 | validate = SpeechSynthesisDataset(
351 | get_text_token_collater(self.args.text_tokens),
352 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor()),
353 | cut_transforms=[],
354 | )
355 | else:
356 | validate = SpeechSynthesisDataset(
357 | get_text_token_collater(self.args.text_tokens),
358 | feature_input_strategy=_get_input_strategy(
359 | self.args.input_strategy, self.args.dataset, cuts_valid
360 | ),
361 | cut_transforms=[],
362 | )
363 | valid_sampler = DynamicBucketingSampler(
364 | cuts_valid,
365 | max_duration=self.args.max_duration,
366 | shuffle=False,
367 | drop_last=True,
368 | )
369 | logging.info("About to create dev dataloader")
370 | valid_dl = DataLoader(
371 | validate,
372 | sampler=valid_sampler,
373 | batch_size=None,
374 | num_workers=4,
375 | persistent_workers=False,
376 | )
377 |
378 | return valid_dl
379 |
380 | def test_dataloaders(self, cuts: CutSet) -> DataLoader:
381 | logging.debug("About to create test dataset")
382 | test = SpeechSynthesisDataset(
383 | get_text_token_collater(self.args.text_tokens),
384 | feature_input_strategy=OnTheFlyFeatures(get_fbank_extractor())
385 | if self.args.on_the_fly_feats
386 | else _get_input_strategy(
387 | self.args.input_strategy, self.args.dataset, cuts
388 | ),
389 | cut_transforms=[],
390 | )
391 | sampler = DynamicBucketingSampler(
392 | cuts,
393 | max_duration=self.args.max_duration,
394 | shuffle=False,
395 | drop_last=True,
396 | )
397 | logging.debug("About to create test dataloader")
398 | test_dl = DataLoader(
399 | test,
400 | batch_size=None,
401 | sampler=sampler,
402 | num_workers=self.args.num_workers,
403 | )
404 | return test_dl
405 |
406 | @lru_cache()
407 | def train_cuts(self) -> CutSet:
408 | logging.info("About to get train cuts")
409 | return load_manifest_lazy(
410 | self.args.manifest_dir / "cuts_train.jsonl.gz"
411 | )
412 |
413 | @lru_cache()
414 | def dev_cuts(self) -> CutSet:
415 | logging.info("About to get dev cuts")
416 | return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")
417 |
418 | @lru_cache()
419 | def test_cuts(self) -> CutSet:
420 | logging.info("About to get test cuts")
421 | return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")
422 |
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # See ../../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """
18 | modified from lhotse.dataset.speech_synthesis.py
19 | """
20 |
21 | from typing import Callable, Dict, List, Sequence, Union
22 |
23 | import torch
24 | from lhotse import validate
25 | from lhotse.cut import CutSet
26 | from lhotse.dataset.collation import collate_audio
27 | from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
28 | from lhotse.utils import ifnone
29 |
30 | from uslm.data.collation import TextTokenCollater
31 |
32 |
33 | class SpeechSynthesisDataset(torch.utils.data.Dataset):
34 | """
35 |     The PyTorch Dataset for the speech synthesis (e.g. TTS) task.
36 | Each item in this dataset is a dict of:
37 |
38 | .. code-block::
39 |
40 | {
41 | 'audio': (B x NumSamples) float tensor
42 | 'audio_lens': (B, ) int tensor
43 | 'text': str
44 | 'audio_features': (B x NumFrames x NumFeatures) float tensor
45 | 'audio_features_lens': (B, ) int tensor
46 | 'text_tokens': (B x NumTextTokens) long tensor
47 | 'text_tokens_lens': (B, ) int tensor
48 | }
49 | """
50 |
51 | def __init__(
52 | self,
53 | text_token_collater: TextTokenCollater,
54 | cut_transforms: List[Callable[[CutSet], CutSet]] = None,
55 | feature_input_strategy: BatchIO = PrecomputedFeatures(),
56 | feature_transforms: Union[Sequence[Callable], Callable] = None,
57 | ) -> None:
58 | super().__init__()
59 |
60 | self.text_token_collater = text_token_collater
61 | self.cut_transforms = ifnone(cut_transforms, [])
62 | self.feature_input_strategy = feature_input_strategy
63 |
64 | if feature_transforms is None:
65 | feature_transforms = []
66 | elif not isinstance(feature_transforms, Sequence):
67 | feature_transforms = [feature_transforms]
68 |
69 | assert all(
70 | isinstance(transform, Callable) for transform in feature_transforms
71 | ), "Feature transforms must be Callable"
72 | self.feature_transforms = feature_transforms
73 |
74 | def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
75 | validate_for_tts(cuts)
76 |
77 | for transform in self.cut_transforms:
78 | cuts = transform(cuts)
79 |
80 | if False: # not used
81 | audio, audio_lens = collate_audio(cuts)
82 | else: # for sharing tokenized features in different machines
83 | audio, audio_lens = None, None
84 |
85 | audio_features, audio_features_lens = self.feature_input_strategy(cuts)
86 |
87 | for transform in self.feature_transforms:
88 | audio_features = transform(audio_features)
89 |
90 | text_tokens, text_tokens_lens = self.text_token_collater(
91 | [cut.supervisions[0].custom["tokens"]["text"] for cut in cuts]
92 | )
93 |
94 | return {
95 | "utt_id": [cut.id for cut in cuts],
96 | "text": [cut.supervisions[0].text for cut in cuts],
97 | "audio": audio,
98 | "audio_lens": audio_lens,
99 | "audio_features": audio_features,
100 | "audio_features_lens": audio_features_lens,
101 | "text_tokens": text_tokens,
102 | "text_tokens_lens": text_tokens_lens,
103 | }
104 |
105 |
106 | def validate_for_tts(cuts: CutSet) -> None:
107 | validate(cuts)
108 | for cut in cuts:
109 | assert (
110 | len(cut.supervisions) == 1
111 | ), "Only the Cuts with single supervision are supported."
112 |
--------------------------------------------------------------------------------
/data/fbank.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # See ../../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
18 | from dataclasses import asdict, dataclass
19 | from typing import Any, Dict, Optional, Union
20 |
21 | import numpy as np
22 | import torch
23 | from lhotse.features.base import FeatureExtractor
24 | from lhotse.utils import EPSILON, Seconds, compute_num_frames
25 | from librosa.filters import mel as librosa_mel_fn
26 |
27 |
28 | @dataclass
29 | class BigVGANFbankConfig:
30 |     # Spectrogram-related part
31 | # Note that frame_length and frame_shift will be converted to milliseconds before torchaudio/Kaldi sees them
32 | frame_length: Seconds = 1024 / 24000.0
33 | frame_shift: Seconds = 256 / 24000.0
34 | remove_dc_offset: bool = True
35 | round_to_power_of_two: bool = True
36 |
37 | # Fbank-related part
38 | low_freq: float = 0.0
39 | high_freq: float = 12000.0
40 | num_mel_bins: int = 100
41 | use_energy: bool = False
42 |
43 | def to_dict(self) -> Dict[str, Any]:
44 | return asdict(self)
45 |
46 | @staticmethod
47 | def from_dict(data: Dict[str, Any]) -> "BigVGANFbankConfig":
48 | return BigVGANFbankConfig(**data)
49 |
50 |
51 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
52 | return torch.log(torch.clamp(x, min=clip_val) * C)
53 |
54 |
55 | def spectral_normalize_torch(magnitudes):
56 | output = dynamic_range_compression_torch(magnitudes)
57 | return output
58 |
59 |
60 | # https://github.com/NVIDIA/BigVGAN
61 | # bigvgan_24khz_100band https://drive.google.com/drive/folders/1EpxX6AsxjCbbk0mmAhE0td6eYiABr8Oz
62 | class BigVGANFbank(FeatureExtractor):
63 | name = "fbank"
64 | config_type = BigVGANFbankConfig
65 |
66 | def __init__(self, config: Optional[Any] = None):
67 | super(BigVGANFbank, self).__init__(config)
68 | sampling_rate = 24000
69 | self.mel_basis = torch.from_numpy(
70 | librosa_mel_fn(
71 | sampling_rate,
72 | 1024,
73 | self.config.num_mel_bins,
74 | self.config.low_freq,
75 | self.config.high_freq,
76 | ).astype(np.float32)
77 | )
78 | self.hann_window = torch.hann_window(1024)
79 |
80 | def _feature_fn(self, samples, **kwargs):
81 | win_length, n_fft = 1024, 1024
82 | hop_size = 256
83 | if True:
84 | sampling_rate = 24000
85 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
86 | expected_num_frames = compute_num_frames(
87 | duration=duration,
88 | frame_shift=self.frame_shift,
89 | sampling_rate=sampling_rate,
90 | )
91 | pad_size = (
92 | (expected_num_frames - 1) * hop_size
93 | + win_length
94 | - samples.shape[-1]
95 | )
96 | assert pad_size >= 0
97 |
98 | y = torch.nn.functional.pad(
99 | samples,
100 | (0, pad_size),
101 | mode="constant",
102 | )
103 | else:
104 | y = torch.nn.functional.pad(
105 | samples,
106 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
107 | mode="reflect",
108 | )
109 |
110 | y = y.squeeze(1)
111 |
112 | # complex tensor as default, then use view_as_real for future pytorch compatibility
113 | spec = torch.stft(
114 | y,
115 | n_fft,
116 | hop_length=hop_size,
117 | win_length=win_length,
118 | window=self.hann_window,
119 | center=False,
120 | pad_mode="reflect",
121 | normalized=False,
122 | onesided=True,
123 | return_complex=True,
124 | )
125 | spec = torch.view_as_real(spec)
126 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
127 |
128 | spec = torch.matmul(self.mel_basis, spec)
129 | spec = spectral_normalize_torch(spec)
130 |
131 | return spec.transpose(2, 1).squeeze(0)
132 |
133 | def extract(
134 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
135 | ) -> np.ndarray:
136 | assert sampling_rate == 24000
137 | params = asdict(self.config)
138 | params.update({"sample_frequency": sampling_rate, "snip_edges": False})
139 | params["frame_shift"] *= 1000.0
140 | params["frame_length"] *= 1000.0
141 | if not isinstance(samples, torch.Tensor):
142 | samples = torch.from_numpy(samples)
143 | # Torchaudio Kaldi feature extractors expect the channel dimension to be first.
144 | if len(samples.shape) == 1:
145 | samples = samples.unsqueeze(0)
146 | features = self._feature_fn(samples, **params).to(torch.float32)
147 | return features.numpy()
148 |
149 | @property
150 | def frame_shift(self) -> Seconds:
151 | return self.config.frame_shift
152 |
153 | def feature_dim(self, sampling_rate: int) -> int:
154 | return self.config.num_mel_bins
155 |
156 | @staticmethod
157 | def mix(
158 | features_a: np.ndarray,
159 | features_b: np.ndarray,
160 | energy_scaling_factor_b: float,
161 | ) -> np.ndarray:
162 | return np.log(
163 | np.maximum(
164 | # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
165 | EPSILON,
166 | np.exp(features_a)
167 | + energy_scaling_factor_b * np.exp(features_b),
168 | )
169 | )
170 |
171 | @staticmethod
172 | def compute_energy(features: np.ndarray) -> float:
173 | return float(np.sum(np.exp(features)))
174 |
175 |
176 | def get_fbank_extractor() -> BigVGANFbank:
177 | return BigVGANFbank(BigVGANFbankConfig())
178 |
179 |
180 | if __name__ == "__main__":
181 | extractor = BigVGANFbank(BigVGANFbankConfig())
182 |
183 | samples = torch.from_numpy(np.random.random([1000]).astype(np.float32))
184 | samples = torch.clip(samples, -1.0, 1.0)
185 | fbank = extractor.extract(samples, 24000.0)
186 | print(f"fbank {fbank.shape}")
187 |
188 | from scipy.io.wavfile import read
189 |
190 | MAX_WAV_VALUE = 32768.0
191 |
192 | sampling_rate, samples = read(
193 | "egs/libritts/prompts/5639_40744_000000_000002.wav"
194 | )
195 | print(f"samples: [{samples.min()}, {samples.max()}]")
196 | fbank = extractor.extract(samples.astype(np.float32) / MAX_WAV_VALUE, 24000)
197 | print(f"fbank {fbank.shape}")
198 |
199 | import matplotlib.pyplot as plt
200 |
201 | _ = plt.figure(figsize=(18, 10))
202 | plt.imshow(
203 | X=fbank.transpose(1, 0),
204 | cmap=plt.get_cmap("jet"),
205 | aspect="auto",
206 | interpolation="nearest",
207 | )
208 | plt.gca().invert_yaxis()
209 | plt.savefig("egs/libritts/prompts/5639_40744_000000_000002.png")
210 | plt.close()
211 |
212 | print("fbank test PASS!")
213 |
--------------------------------------------------------------------------------
/data/input_strategies.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import defaultdict
3 | from concurrent.futures import ThreadPoolExecutor
4 | from typing import Tuple, Type
5 |
6 | from lhotse import CutSet
7 | from lhotse.dataset.collation import collate_features
8 | from lhotse.dataset.input_strategies import (
9 | ExecutorType,
10 | PrecomputedFeatures,
11 | _get_executor,
12 | )
13 | from lhotse.utils import fastcopy
14 |
15 |
16 | class PromptedFeatures:
17 | def __init__(self, prompts, features):
18 | self.prompts = prompts
19 | self.features = features
20 |
21 | def to(self, device):
22 | return PromptedFeatures(
23 | self.prompts.to(device), self.features.to(device)
24 | )
25 |
26 | def sum(self):
27 | return self.features.sum()
28 |
29 | @property
30 | def ndim(self):
31 | return self.features.ndim
32 |
33 | @property
34 | def data(self):
35 | return (self.prompts, self.features)
36 |
37 |
38 | class PromptedPrecomputedFeatures(PrecomputedFeatures):
39 | """
40 | :class:`InputStrategy` that reads pre-computed features, whose manifests
41 | are attached to cuts, from disk.
42 |
43 |     Besides the pre-computed features of each cut, it also returns prompt features drawn from a neighbouring utterance (e.g. a preceding or following cut of the same speaker).
44 |
45 | .. automethod:: __call__
46 | """
47 |
48 | def __init__(
49 | self,
50 | dataset: str,
51 | cuts: CutSet,
52 | num_workers: int = 0,
53 | executor_type: Type[ExecutorType] = ThreadPoolExecutor,
54 | ) -> None:
55 | super(PromptedPrecomputedFeatures, self).__init__(
56 | num_workers, executor_type
57 | )
58 |
59 | self.utt2neighbors = defaultdict(lambda: [])
60 |
61 | if dataset.lower() == "libritts":
62 | # 909_131041_000013_000002
63 | # 909_131041_000013_000003
64 | speaker2utts = defaultdict(lambda: [])
65 |
66 | utt2cut = {}
67 | for cut in cuts:
68 | speaker = cut.supervisions[0].speaker
69 | speaker2utts[speaker].append(cut.id)
70 | utt2cut[cut.id] = cut
71 |
72 | for spk in speaker2utts:
73 | uttids = sorted(speaker2utts[spk])
74 | # Using the property of sorted keys to find previous utterance
75 |                 # The keys have the structure speaker_book_x_y, e.g. 1089_134691_000004_000001
76 | if len(uttids) == 1:
77 | self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
78 | continue
79 |
80 | utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
81 | utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
82 |
83 | for utt in utt2prevutt:
84 | self.utt2neighbors[utt].append(utt2cut[utt2prevutt[utt]])
85 |
86 | for utt in utt2postutt:
87 | self.utt2neighbors[utt].append(utt2cut[utt2postutt[utt]])
88 | elif dataset.lower() == "ljspeech":
89 | utt2cut = {}
90 | uttids = []
91 | for cut in cuts:
92 | uttids.append(cut.id)
93 | utt2cut[cut.id] = cut
94 |
95 | if len(uttids) == 1:
96 | self.utt2neighbors[uttids[0]].append(utt2cut[uttids[0]])
97 | else:
98 | # Using the property of sorted keys to find previous utterance
99 |                 # The keys have the structure LJ001-0010
100 | utt2prevutt = dict(zip(uttids, [uttids[1]] + uttids[:-1]))
101 | utt2postutt = dict(zip(uttids[:-1], uttids[1:]))
102 |
103 | for utt in utt2postutt:
104 | postutt = utt2postutt[utt]
105 | if utt[:5] == postutt[:5]:
106 | self.utt2neighbors[utt].append(utt2cut[postutt])
107 |
108 | for utt in utt2prevutt:
109 | prevutt = utt2prevutt[utt]
110 | if utt[:5] == prevutt[:5] or not self.utt2neighbors[utt]:
111 | self.utt2neighbors[utt].append(utt2cut[prevutt])
112 | else:
113 | raise ValueError
114 |
115 | def __call__(
116 | self, cuts: CutSet
117 | ) -> Tuple[PromptedFeatures, PromptedFeatures]:
118 | """
119 | Reads the pre-computed features from disk/other storage.
120 | The returned shape is``(B, T, F) => (batch_size, num_frames, num_features)``.
121 |
122 | :return: a tensor with collated features, and a tensor of ``num_frames`` of each cut before padding.
123 | """
124 | features, features_lens = collate_features(
125 | cuts,
126 | executor=_get_executor(
127 | self.num_workers, executor_type=self._executor_type
128 | ),
129 | )
130 |
131 | prompts_cuts = []
132 | for k, cut in enumerate(cuts):
133 | prompts_cut = random.choice(self.utt2neighbors[cut.id])
134 | prompts_cuts.append(fastcopy(prompts_cut, id=f"{cut.id}-{str(k)}"))
135 |
136 | mini_duration = min([cut.duration for cut in prompts_cuts] + [3.0])
137 | # prompts_cuts = CutSet.from_cuts(prompts_cuts).truncate(
138 | # max_duration=mini_duration,
139 | # offset_type="random",
140 | # preserve_id=True,
141 | # )
142 | prompts_cuts = CutSet(
143 | cuts={k: cut for k, cut in enumerate(prompts_cuts)}
144 | ).truncate(
145 | max_duration=mini_duration,
146 | offset_type="random",
147 | preserve_id=False,
148 | )
149 |
150 | prompts, prompts_lens = collate_features(
151 | prompts_cuts,
152 | executor=_get_executor(
153 | self.num_workers, executor_type=self._executor_type
154 | ),
155 | )
156 |
157 | return PromptedFeatures(prompts, features), PromptedFeatures(
158 | prompts_lens, features_lens
159 | )
160 |
--------------------------------------------------------------------------------
/data/speechtokenizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Fri. Sept. 8 00:40:25 2023
3 | @author: Dong Zhang
4 | """
5 |
6 | import re
7 | from dataclasses import asdict, dataclass
8 | from typing import Any, Dict, List, Optional, Pattern, Union
9 |
10 | import numpy as np
11 | import torch
12 | import torchaudio
13 | from encodec.utils import convert_audio
14 | from lhotse.features import FeatureExtractor
15 | from lhotse.utils import Seconds, compute_num_frames
16 | import json
17 | import os
18 | from speechtokenizer import SpeechTokenizer
19 |
20 |
21 |
22 | class AttrDict(dict):
23 | def __init__(self, *args, **kwargs):
24 | super(AttrDict, self).__init__(*args, **kwargs)
25 | self.__dict__ = self
26 |
27 |
28 | class Speechtokenizer:
29 | """SpeechTokenizer"""
30 |
31 | def __init__(
32 | self,
33 | ckpt_dir: str = "",
34 | device: Any = None,
35 | ) -> None:
36 |
37 | config_path = os.path.join(ckpt_dir, "config.json")
38 | ckpt_path = os.path.join(ckpt_dir, "SpeechTokenizer.pt")
39 | # load model
40 | model = SpeechTokenizer.load_from_checkpoint(config_path, ckpt_path)
41 | model.eval()
42 |
43 | if not device:
44 | device = torch.device("cpu")
45 | if torch.cuda.is_available():
46 | device = torch.device("cuda:0")
47 |
48 | self._device = device
49 | self.model = model.to(device)
50 | self.sample_rate = model.sample_rate
51 | self.channels = 1
52 |
53 | @property
54 | def device(self):
55 | return self._device
56 |
57 | def encode(self, wav: torch.Tensor) -> torch.Tensor:
58 | return self.model.encode(wav.to(self.device))
59 |
60 | def decode(self, frames: torch.Tensor) -> torch.Tensor:
61 | return self.model.decode(frames)
62 |
63 |
64 | def sttokenize_audio(tokenizer: Speechtokenizer, audio_path: str):
65 | # Load and pre-process the audio waveform
66 | wav, sr = torchaudio.load(audio_path)
67 | wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
68 | wav = wav.unsqueeze(0)
69 |
70 |     # Extract discrete codes with SpeechTokenizer
71 | with torch.no_grad():
72 | encoded_frames = tokenizer.encode(wav)
73 | return encoded_frames
74 |
75 |
76 | @dataclass
77 | class STAudioTokenConfig:
78 | frame_shift: Seconds = 320.0 / 16000
79 | num_quantizers: int = 8
80 |
81 | def to_dict(self) -> Dict[str, Any]:
82 | return asdict(self)
83 |
84 | @staticmethod
85 | def from_dict(data: Dict[str, Any]) -> "STAudioTokenConfig":
86 | return STAudioTokenConfig(**data)
87 |
88 |
89 | class STAudioTokenExtractor(FeatureExtractor):
90 | name = "speechtokenizer"
91 | config_type = STAudioTokenConfig
92 |
93 | def __init__(self, config: Optional[Any] = None):
94 | super(STAudioTokenExtractor, self).__init__(config)
95 | self.tokenizer = Speechtokenizer()
96 |
97 | def extract(
98 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
99 | ) -> np.ndarray:
100 | if not isinstance(samples, torch.Tensor):
101 | samples = torch.from_numpy(samples)
102 | if sampling_rate != self.tokenizer.sample_rate:
103 | samples = convert_audio(
104 | samples,
105 | sampling_rate,
106 | self.tokenizer.sample_rate,
107 | self.tokenizer.channels,
108 | )
109 | if len(samples.shape) == 2:
110 | samples = samples.unsqueeze(0)
111 | else:
112 | raise ValueError()
113 |
114 | device = self.tokenizer.device
115 | encoded_frames = self.tokenizer.encode(samples.detach().to(device))
116 | codes = encoded_frames.permute(1,0,2) # [B, n_q, T]
117 | if True:
118 | duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
119 | expected_num_frames = compute_num_frames(
120 | duration=duration,
121 | frame_shift=self.frame_shift,
122 | sampling_rate=sampling_rate,
123 | )
124 | assert abs(codes.shape[-1] - expected_num_frames) <= 1
125 | codes = codes[..., :expected_num_frames]
126 | return codes.cpu().squeeze(0).permute(1, 0).numpy()
127 |
128 | @property
129 | def frame_shift(self) -> Seconds:
130 | return self.config.frame_shift
131 |
132 | def feature_dim(self, sampling_rate: int) -> int:
133 | return self.config.num_quantizers
134 |
135 | def pad_tensor_list(self, tensor_list, device, padding_value=0):
136 |         # Compute the length of each tensor
137 |         lengths = [tensor.shape[0] for tensor in tensor_list]
138 |         # Pad to a common length with torch.nn.utils.rnn.pad_sequence
139 | tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
140 | padded_tensor = torch.nn.utils.rnn.pad_sequence(
141 | tensor_list, batch_first=True, padding_value=padding_value
142 | )
143 | return padded_tensor, lengths
144 |
145 | def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
146 | samples = [wav.squeeze() for wav in samples]
147 | device = self.tokenizer.device
148 | samples, lengths = self.pad_tensor_list(samples, device)
149 | samples = samples.unsqueeze(1)
150 |
151 | if not isinstance(samples, torch.Tensor):
152 | samples = torch.from_numpy(samples)
153 | if len(samples.shape) != 3:
154 | raise ValueError()
155 | if sampling_rate != self.tokenizer.sample_rate:
156 | samples = [
157 | convert_audio(
158 | wav,
159 | sampling_rate,
160 | self.tokenizer.sample_rate,
161 | self.tokenizer.channels,
162 | )
163 | for wav in samples
164 | ]
165 |         # Extract discrete codes with SpeechTokenizer
166 | with torch.no_grad():
167 | encoded_frames = self.tokenizer.encode(samples.detach().to(device))
168 | encoded_frames = encoded_frames.permute(1,0,2) # [B, n_q, T]
169 | batch_codes = []
170 | for b, length in enumerate(lengths):
171 | codes = encoded_frames[b]
172 | duration = round(length / sampling_rate, ndigits=12)
173 | expected_num_frames = compute_num_frames(
174 | duration=duration,
175 | frame_shift=self.frame_shift,
176 | sampling_rate=sampling_rate,
177 | )
178 | batch_codes.append(codes[..., :expected_num_frames])
179 | return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 | if __name__ == "__main__":
191 |
192 | audio_path = r'valle/examples/libritts/prompts/8455_210777_000067_000000.wav'
193 | speechtokenizer = Speechtokenizer()
194 |
195 | wav, sr = torchaudio.load(audio_path)
196 |     wav = convert_audio(wav, sr, speechtokenizer.sample_rate, speechtokenizer.channels)
197 |     wav = wav.unsqueeze(0)
198 |
199 |     # Extract discrete codes with SpeechTokenizer
200 | with torch.no_grad():
201 | tokens = speechtokenizer.encode(wav)
202 |
203 | reconstructed = speechtokenizer.decode(tokens)
204 |
205 |
206 |
--------------------------------------------------------------------------------
/data/tokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import re
17 | from dataclasses import asdict, dataclass
18 | from typing import Any, Dict, List, Optional, Pattern, Union
19 |
20 | import numpy as np
21 | import torch
22 | import torchaudio
23 | from encodec import EncodecModel
24 | from encodec.utils import convert_audio
25 | from lhotse.features import FeatureExtractor
26 | from lhotse.utils import Seconds, compute_num_frames
27 | from phonemizer.backend import EspeakBackend
28 | from phonemizer.backend.espeak.language_switch import LanguageSwitch
29 | from phonemizer.backend.espeak.words_mismatch import WordMismatch
30 | from phonemizer.punctuation import Punctuation
31 | from phonemizer.separator import Separator
32 |
33 | try:
34 | from pypinyin import Style, pinyin
35 | from pypinyin.style._utils import get_finals, get_initials
36 | except Exception:
37 | pass
38 |
39 |
40 | class PypinyinBackend:
41 |     """PypinyinBackend for Chinese. Most of the code is adapted from ESPnet.
42 |     There are two modes, pinyin and initials_finals: one produces strings
43 |     like "ni1 hao3", the other like "n i1 h ao3".
44 | """
45 |
46 | def __init__(
47 | self,
48 | backend="initials_finals",
49 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
50 | ) -> None:
51 | self.backend = backend
52 | self.punctuation_marks = punctuation_marks
53 |
54 | def phonemize(
55 | self, text: List[str], separator: Separator, strip=True, njobs=1
56 | ) -> List[str]:
57 | assert isinstance(text, List)
58 | phonemized = []
59 | for _text in text:
60 | _text = re.sub(" +", " ", _text.strip())
61 | _text = _text.replace(" ", separator.word)
62 | phones = []
63 | if self.backend == "pypinyin":
64 | for n, py in enumerate(
65 | pinyin(
66 | _text, style=Style.TONE3, neutral_tone_with_five=True
67 | )
68 | ):
69 | if all([c in self.punctuation_marks for c in py[0]]):
70 | if len(phones):
71 | assert phones[-1] == separator.syllable
72 | phones.pop(-1)
73 |
74 | phones.extend(list(py[0]))
75 | else:
76 | phones.extend([py[0], separator.syllable])
77 | elif self.backend == "pypinyin_initials_finals":
78 | for n, py in enumerate(
79 | pinyin(
80 | _text, style=Style.TONE3, neutral_tone_with_five=True
81 | )
82 | ):
83 | if all([c in self.punctuation_marks for c in py[0]]):
84 | if len(phones):
85 | assert phones[-1] == separator.syllable
86 | phones.pop(-1)
87 | phones.extend(list(py[0]))
88 | else:
89 | if py[0][-1].isalnum():
90 | initial = get_initials(py[0], strict=False)
91 | if py[0][-1].isdigit():
92 | final = (
93 | get_finals(py[0][:-1], strict=False)
94 | + py[0][-1]
95 | )
96 | else:
97 | final = get_finals(py[0], strict=False)
98 | phones.extend(
99 | [
100 | initial,
101 | separator.phone,
102 | final,
103 | separator.syllable,
104 | ]
105 | )
106 | else:
107 |                     raise ValueError(f"Unexpected pinyin token: {py[0]}")
108 | else:
109 | raise NotImplementedError
110 | phonemized.append(
111 | "".join(phones).rstrip(f"{separator.word}{separator.syllable}")
112 | )
113 | return phonemized
114 |
115 |
116 | class TextTokenizer:
117 | """Phonemize Text."""
118 |
119 | def __init__(
120 | self,
121 | language="en-us",
122 | backend="espeak",
123 | separator=Separator(word="_", syllable="-", phone="|"),
124 | preserve_punctuation=True,
125 | punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
126 | with_stress: bool = False,
127 | tie: Union[bool, str] = False,
128 | language_switch: LanguageSwitch = "keep-flags",
129 | words_mismatch: WordMismatch = "ignore",
130 | ) -> None:
131 | if backend == "espeak":
132 | phonemizer = EspeakBackend(
133 | language,
134 | punctuation_marks=punctuation_marks,
135 | preserve_punctuation=preserve_punctuation,
136 | with_stress=with_stress,
137 | tie=tie,
138 | language_switch=language_switch,
139 | words_mismatch=words_mismatch,
140 | )
141 | elif backend in ["pypinyin", "pypinyin_initials_finals"]:
142 | phonemizer = PypinyinBackend(
143 | backend=backend,
144 | punctuation_marks=punctuation_marks + separator.word,
145 | )
146 | else:
147 | raise NotImplementedError(f"{backend}")
148 |
149 | self.backend = phonemizer
150 | self.separator = separator
151 |
152 | def to_list(self, phonemized: str) -> List[str]:
153 | fields = []
154 | for word in phonemized.split(self.separator.word):
155 | # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.
156 | pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
157 | fields.extend(
158 | [p for p in pp if p != self.separator.phone]
159 | + [self.separator.word]
160 | )
161 | assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
162 | self.separator.phone
163 | )
164 | return fields[:-1]
165 |
166 | def __call__(self, text, strip=True) -> List[List[str]]:
167 | if isinstance(text, str):
168 | text = [text]
169 |
170 | phonemized = self.backend.phonemize(
171 | text, separator=self.separator, strip=strip, njobs=1
172 | )
173 | return [self.to_list(p) for p in phonemized]
174 |
175 |
176 | def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]:
177 | phonemes = tokenizer([text.strip()])
178 | return phonemes[0] # k2symbols
179 |
180 |
181 | def remove_encodec_weight_norm(model):
182 | from encodec.modules import SConv1d
183 | from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
184 | from torch.nn.utils import remove_weight_norm
185 |
186 | encoder = model.encoder.model
187 | for key in encoder._modules:
188 | if isinstance(encoder._modules[key], SEANetResnetBlock):
189 | remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
190 | block_modules = encoder._modules[key].block._modules
191 | for skey in block_modules:
192 | if isinstance(block_modules[skey], SConv1d):
193 | remove_weight_norm(block_modules[skey].conv.conv)
194 | elif isinstance(encoder._modules[key], SConv1d):
195 | remove_weight_norm(encoder._modules[key].conv.conv)
196 |
197 | decoder = model.decoder.model
198 | for key in decoder._modules:
199 | if isinstance(decoder._modules[key], SEANetResnetBlock):
200 | remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
201 | block_modules = decoder._modules[key].block._modules
202 | for skey in block_modules:
203 | if isinstance(block_modules[skey], SConv1d):
204 | remove_weight_norm(block_modules[skey].conv.conv)
205 | elif isinstance(decoder._modules[key], SConvTranspose1d):
206 | remove_weight_norm(decoder._modules[key].convtr.convtr)
207 | elif isinstance(decoder._modules[key], SConv1d):
208 | remove_weight_norm(decoder._modules[key].conv.conv)
209 |
210 |
211 | class AudioTokenizer:
212 |     """EnCodec-based audio tokenizer (24 kHz model at 6 kbps)."""
213 |
214 | def __init__(
215 | self,
216 | device: Any = None,
217 | ) -> None:
218 | # Instantiate a pretrained EnCodec model
219 | model = EncodecModel.encodec_model_24khz()
220 | model.set_target_bandwidth(6.0)
221 | remove_encodec_weight_norm(model)
222 |
223 | if not device:
224 | device = torch.device("cpu")
225 | if torch.cuda.is_available():
226 | device = torch.device("cuda:0")
227 |
228 | self._device = device
229 |
230 | self.codec = model.to(device)
231 | self.sample_rate = model.sample_rate
232 | self.channels = model.channels
233 |
234 | @property
235 | def device(self):
236 | return self._device
237 |
238 | def encode(self, wav: torch.Tensor) -> torch.Tensor:
239 | return self.codec.encode(wav.to(self.device))
240 |
241 | def decode(self, frames: torch.Tensor) -> torch.Tensor:
242 | return self.codec.decode(frames)
243 |
244 |
245 | def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str):
246 | # Load and pre-process the audio waveform
247 | wav, sr = torchaudio.load(audio_path)
248 | wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
249 | wav = wav.unsqueeze(0)
250 |
251 | # Extract discrete codes from EnCodec
252 | with torch.no_grad():
253 | encoded_frames = tokenizer.encode(wav)
254 | return encoded_frames
255 |
256 |
257 | @dataclass
258 | class AudioTokenConfig:
259 | frame_shift: Seconds = 320.0 / 24000
260 | num_quantizers: int = 8
261 |
262 | def to_dict(self) -> Dict[str, Any]:
263 | return asdict(self)
264 |
265 | @staticmethod
266 | def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig":
267 | return AudioTokenConfig(**data)
268 |
269 |
270 | class AudioTokenExtractor(FeatureExtractor):
271 | name = "encodec"
272 | config_type = AudioTokenConfig
273 |
274 | def __init__(self, config: Optional[Any] = None):
275 | super(AudioTokenExtractor, self).__init__(config)
276 | self.tokenizer = AudioTokenizer()
277 |
278 | def extract(
279 | self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
280 | ) -> np.ndarray:
281 | if not isinstance(samples, torch.Tensor):
282 | samples = torch.from_numpy(samples)
283 | if sampling_rate != self.tokenizer.sample_rate:
284 | samples = convert_audio(
285 | samples,
286 | sampling_rate,
287 | self.tokenizer.sample_rate,
288 | self.tokenizer.channels,
289 | )
290 | if len(samples.shape) == 2:
291 | samples = samples.unsqueeze(0)
292 | else:
293 | raise ValueError()
294 |
295 | device = self.tokenizer.device
296 | encoded_frames = self.tokenizer.encode(samples.detach().to(device))
297 | codes = encoded_frames[0][0] # [B, n_q, T]
298 |         # Trim the codes to the number of frames expected for this duration.
299 |         duration = round(samples.shape[-1] / sampling_rate, ndigits=12)
300 |         expected_num_frames = compute_num_frames(
301 |             duration=duration,
302 |             frame_shift=self.frame_shift,
303 |             sampling_rate=sampling_rate,
304 |         )
305 |         assert abs(codes.shape[-1] - expected_num_frames) <= 1
306 |         codes = codes[..., :expected_num_frames]
307 | return codes.cpu().squeeze(0).permute(1, 0).numpy()
308 |
309 | @property
310 | def frame_shift(self) -> Seconds:
311 | return self.config.frame_shift
312 |
313 | def feature_dim(self, sampling_rate: int) -> int:
314 | return self.config.num_quantizers
315 |
316 | def pad_tensor_list(self, tensor_list, device, padding_value=0):
317 |         # Record the length of each tensor
318 |         lengths = [tensor.shape[0] for tensor in tensor_list]
319 |         # Pad to a common length with pad_sequence
320 | tensor_list = [torch.Tensor(t).to(device) for t in tensor_list]
321 | padded_tensor = torch.nn.utils.rnn.pad_sequence(
322 | tensor_list, batch_first=True, padding_value=padding_value
323 | )
324 | return padded_tensor, lengths
325 |
326 | def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray:
327 | samples = [wav.squeeze() for wav in samples]
328 | device = self.tokenizer.device
329 | samples, lengths = self.pad_tensor_list(samples, device)
330 | samples = samples.unsqueeze(1)
331 |
332 | if not isinstance(samples, torch.Tensor):
333 | samples = torch.from_numpy(samples)
334 | if len(samples.shape) != 3:
335 | raise ValueError()
336 |
337 | if sampling_rate != self.tokenizer.sample_rate:
338 | samples = [
339 | convert_audio(
340 | wav.cpu(),
341 | sampling_rate,
342 | self.tokenizer.sample_rate,
343 | self.tokenizer.channels,
344 | ).cuda()
345 | for wav in samples
346 | ]
347 | samples = torch.stack(samples, 0)
348 |
349 | # Extract discrete codes from EnCodec
350 | with torch.no_grad():
351 | encoded_frames = self.tokenizer.encode(samples.detach().to(device))
352 | encoded_frames = encoded_frames[0][0] # [B, n_q, T]
353 | batch_codes = []
354 | for b, length in enumerate(lengths):
355 | codes = encoded_frames[b]
356 | duration = round(length / sampling_rate, ndigits=12)
357 | expected_num_frames = compute_num_frames(
358 | duration=duration,
359 | frame_shift=self.frame_shift,
360 | sampling_rate=sampling_rate,
361 | )
362 | batch_codes.append(codes[..., :expected_num_frames])
363 | return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes]
364 |
365 |
366 | if __name__ == "__main__":
367 | model = EncodecModel.encodec_model_24khz()
368 | model.set_target_bandwidth(6.0)
369 |
370 | samples = torch.from_numpy(np.random.random([1, 1, 16000])).type(
371 | torch.float32
372 | )
373 | codes_raw = model.encode(samples)
374 |
375 | remove_encodec_weight_norm(model)
376 | codes_norm = model.encode(samples)
377 |
378 | assert torch.allclose(codes_raw[0][0], codes_norm[0][0])
379 | print(codes_raw[0][0].shape)
380 |
--------------------------------------------------------------------------------
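
The `TextTokenizer` and `AudioTokenizer` defined in this file are the text and EnCodec front-ends that `bin/infer.py` builds on (the SpeechTokenizer counterpart lives in data/speechtokenizer.py). Below is a minimal usage sketch, not taken from the repository itself; it assumes the package has been installed with `pip install -e .` so that it is importable as `uslm`, that espeak-ng and the `encodec` package are available, and it reuses the prompt wav shipped under `prompts/`.

```python
from uslm.data.tokenizer import (  # import path assumed from the package layout
    AudioTokenizer,
    TextTokenizer,
    tokenize_audio,
    tokenize_text,
)

# Phonemize English text into k2-style symbols; "_" marks word boundaries.
text_tokenizer = TextTokenizer(backend="espeak", language="en-us")
phonemes = tokenize_text(text_tokenizer, "Begin with the fundamental steps of the process.")

# Encode a prompt waveform into discrete EnCodec codes (24 kHz model at 6 kbps).
audio_tokenizer = AudioTokenizer()
encoded_frames = tokenize_audio(audio_tokenizer, "prompts/1580_141083_000002_000002.wav")
codes = encoded_frames[0][0]  # [B, n_q, T] integer codes

print(len(phonemes), codes.shape)
```
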
/images/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/images/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/images/overview.png
--------------------------------------------------------------------------------
/inference.sh:
--------------------------------------------------------------------------------
1 |
2 | audio_extractor="SpeechTokenizer"
3 |
4 | st_dir="ckpt/speechtokenizer/"
5 | uslm_dir="ckpt/uslm/"
6 | out_dir="output/"
7 |
8 | mkdir -p ${st_dir}
9 | mkdir -p ${uslm_dir}
10 | mkdir -p ${out_dir}
11 |
12 | if [ ! -e "${st_dir}/config.json" ];then
13 | cd ${st_dir}
14 | wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/SpeechTokenizer.pt"
15 | wget "https://huggingface.co/fnlp/SpeechTokenizer/resolve/main/speechtokenizer_hubert_avg/config.json"
16 | cd -
17 | fi
18 |
19 | if [ ! -e "${uslm_dir}/USLM.pt" ];then
20 | cd ${uslm_dir}
21 | wget "https://huggingface.co/fnlp/USLM/resolve/main/USLM_libritts/USLM.pt"
22 | wget "https://huggingface.co/fnlp/USLM/resolve/main/USLM_libritts/unique_text_tokens.k2symbols"
23 | cd -
24 | fi
25 |
26 |
27 | python3 bin/infer.py --output-dir ${out_dir}/ \
28 | --model-name uslm --norm-first true --add-prenet false \
29 |     --share-embedding true \
30 | --audio-extractor "${audio_extractor}" \
31 | --speechtokenizer-dir "${st_dir}" \
32 | --checkpoint=${uslm_dir}/best-valid-loss.pt \
33 | --text-tokens "${uslm_dir}/unique_text_tokens.k2symbols" \
34 | --text-prompts "mr Soames was a tall, spare man, of a nervous and excitable temperament." \
35 | --audio-prompts prompts/1580_141083_000002_000002.wav \
36 |     --text "Begin with the fundamental steps of the process. This will give you a solid foundation to build upon and boost your confidence. "
37 |
38 |
--------------------------------------------------------------------------------
/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/models/.DS_Store
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | import faulthandler
3 | faulthandler.enable()
4 |
5 | import argparse
6 |
7 | import torch.nn as nn
8 | from icefall.utils import AttributeDict, str2bool
9 |
10 | from .macros import (
11 | NUM_AUDIO_TOKENS,
12 | NUM_MEL_BINS,
13 | NUM_SPEAKER_CLASSES,
14 | NUM_TEXT_TOKENS,
15 | SPEAKER_EMBEDDING_DIM,
16 | )
17 | from .transformer import Transformer
18 | from uslm.models.valle import VALLE, VALLF
19 | from uslm.models.uslm import USLM
20 | from .visualizer import visualize
21 |
22 |
23 | def add_model_arguments(parser: argparse.ArgumentParser):
24 | parser.add_argument(
25 | "--model-name",
26 | type=str,
27 | default="USLM",
28 | help="USLM, VALL-E, VALL-F, Transformer.",
29 | )
30 | parser.add_argument(
31 | "--decoder-dim",
32 | type=int,
33 | default=1024,
34 | help="Embedding dimension in the decoder model.",
35 | )
36 | parser.add_argument(
37 | "--nhead",
38 | type=int,
39 | default=16,
40 | help="Number of attention heads in the Decoder layers.",
41 | )
42 | parser.add_argument(
43 | "--num-decoder-layers",
44 | type=int,
45 | default=12,
46 | help="Number of Decoder layers.",
47 | )
48 | parser.add_argument(
49 | "--scale-factor",
50 | type=float,
51 | default=1.0,
52 | help="Model scale factor which will be assigned different meanings in different models.",
53 | )
54 | parser.add_argument(
55 | "--norm-first",
56 | type=str2bool,
57 | default=True,
58 | help="Pre or Post Normalization.",
59 | )
60 | parser.add_argument(
61 | "--add-prenet",
62 | type=str2bool,
63 | default=False,
64 |         help="Whether to add a PreNet after the inputs.",
65 | )
66 |
67 | # VALL-E & F
68 | parser.add_argument(
69 | "--prefix-mode",
70 | type=int,
71 | default=0,
72 | help="The mode for how to prefix VALL-E NAR Decoder, "
73 | "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.",
74 | )
75 | parser.add_argument(
76 | "--share-embedding",
77 | type=str2bool,
78 | default=True,
79 | help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.",
80 | )
81 | parser.add_argument(
82 | "--prepend-bos",
83 | type=str2bool,
84 | default=False,
85 |         help="Whether to prepend a BOS token to the AR Decoder's acoustic token inputs.",
86 | )
87 | parser.add_argument(
88 | "--num-quantizers",
89 | type=int,
90 | default=8,
91 | help="Number of Audio/Semantic quantization layers.",
92 | )
93 |
94 | # Transformer
95 | parser.add_argument(
96 | "--scaling-xformers",
97 | type=str2bool,
98 | default=False,
99 | help="Apply Reworked Conformer scaling on Transformers.",
100 | )
101 |
102 |
103 | def get_model(params: AttributeDict) -> nn.Module:
104 | if params.model_name.lower() in ["USLM", "uslm"]:
105 | model = USLM(
106 | params.decoder_dim,
107 | params.nhead,
108 | params.num_decoder_layers,
109 | norm_first=params.norm_first,
110 | add_prenet=params.add_prenet,
111 | prefix_mode=params.prefix_mode,
112 | share_embedding=params.share_embedding,
113 | nar_scale_factor=params.scale_factor,
114 | prepend_bos=params.prepend_bos,
115 | num_quantizers=params.num_quantizers,
116 | )
117 | elif params.model_name.lower() in ["vall-f", "vallf"]:
118 | model = VALLF(
119 | params.decoder_dim,
120 | params.nhead,
121 | params.num_decoder_layers,
122 | norm_first=params.norm_first,
123 | add_prenet=params.add_prenet,
124 | prefix_mode=params.prefix_mode,
125 | share_embedding=params.share_embedding,
126 | nar_scale_factor=params.scale_factor,
127 | prepend_bos=params.prepend_bos,
128 | num_quantizers=params.num_quantizers,
129 | )
130 | elif params.model_name.lower() in ["vall-e", "valle"]:
131 | model = VALLE(
132 | params.decoder_dim,
133 | params.nhead,
134 | params.num_decoder_layers,
135 | norm_first=params.norm_first,
136 | add_prenet=params.add_prenet,
137 | prefix_mode=params.prefix_mode,
138 | share_embedding=params.share_embedding,
139 | nar_scale_factor=params.scale_factor,
140 | prepend_bos=params.prepend_bos,
141 | num_quantizers=params.num_quantizers,
142 | )
143 | else:
144 | assert params.model_name in ["Transformer"]
145 | model = Transformer(
146 | params.decoder_dim,
147 | params.nhead,
148 | params.num_decoder_layers,
149 | norm_first=params.norm_first,
150 | add_prenet=params.add_prenet,
151 | scaling_xformers=params.scaling_xformers,
152 | )
153 |
154 | return model
155 |
--------------------------------------------------------------------------------
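
`get_model` is normally driven by the CLI flags registered by `add_model_arguments`. The sketch below, not taken from the repository itself, builds a USLM instance programmatically with the same values the flags default to; it assumes icefall is on `PYTHONPATH` and that the package is importable as `uslm`.

```python
from icefall.utils import AttributeDict
from uslm.models import get_model  # import path assumed from the package layout

params = AttributeDict(
    {
        "model_name": "uslm",
        "decoder_dim": 1024,
        "nhead": 16,
        "num_decoder_layers": 12,
        "norm_first": True,
        "add_prenet": False,
        "prefix_mode": 0,
        "share_embedding": True,
        "scale_factor": 1.0,
        "prepend_bos": False,
        "num_quantizers": 8,
        "scaling_xformers": False,
    }
)
model = get_model(params)  # -> USLM(d_model=1024, nhead=16, num_layers=12, ...)
```
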
/models/macros.py:
--------------------------------------------------------------------------------
1 | # Text
2 | NUM_TEXT_TOKENS = 512
3 |
4 | # Audio
5 | NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins
6 | NUM_MEL_BINS = 100 # BigVGAN bigvgan_24khz_100band
7 |
8 |
9 | # Speaker
10 | NUM_SPEAKER_CLASSES = 4096
11 | SPEAKER_EMBEDDING_DIM = 64
12 |
--------------------------------------------------------------------------------
/models/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from functools import partial
16 | from typing import Any, Dict, List, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | from icefall.utils import make_pad_mask
22 | from torchmetrics.classification import BinaryAccuracy
23 |
24 | from uslm.models.valle import Transpose
25 | from uslm.modules.embedding import SinePositionalEmbedding, TokenEmbedding
26 | from uslm.modules.scaling import BalancedDoubleSwish, ScaledLinear
27 | from uslm.modules.transformer import (
28 | BalancedBasicNorm,
29 | IdentityNorm,
30 | TransformerDecoderLayer,
31 | TransformerEncoder,
32 | TransformerEncoderLayer,
33 | )
34 |
35 | from .macros import NUM_MEL_BINS, NUM_TEXT_TOKENS
36 | from .visualizer import visualize
37 |
38 | IdentityNorm = IdentityNorm
39 |
40 |
41 | class Transformer(nn.Module):
42 |     """A seq2seq Transformer TTS model used for debugging (no StopPredictor or SpeakerEmbedding).
43 | Neural Speech Synthesis with Transformer Network
44 | https://arxiv.org/abs/1809.08895
45 | """
46 |
47 | def __init__(
48 | self,
49 | d_model: int,
50 | nhead: int,
51 | num_layers: int,
52 | norm_first: bool = True,
53 | add_prenet: bool = False,
54 | scaling_xformers: bool = False,
55 | ):
56 | """
57 | Args:
58 | d_model:
59 | The number of expected features in the input (required).
60 | nhead:
61 | The number of heads in the multiheadattention models (required).
62 | num_layers:
63 | The number of sub-decoder-layers in the decoder (required).
64 | """
65 | super().__init__()
66 | self.text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x
67 |
68 | if add_prenet:
69 | self.encoder_prenet = nn.Sequential(
70 | Transpose(),
71 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
72 | nn.BatchNorm1d(d_model),
73 | nn.ReLU(),
74 | nn.Dropout(0.5),
75 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
76 | nn.BatchNorm1d(d_model),
77 | nn.ReLU(),
78 | nn.Dropout(0.5),
79 | nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
80 | nn.BatchNorm1d(d_model),
81 | nn.ReLU(),
82 | nn.Dropout(0.5),
83 | Transpose(),
84 | nn.Linear(d_model, d_model),
85 | )
86 |
87 | self.decoder_prenet = nn.Sequential(
88 | nn.Linear(NUM_MEL_BINS, 256),
89 | nn.ReLU(),
90 | nn.Dropout(0.5),
91 | nn.Linear(256, 256),
92 | nn.ReLU(),
93 | nn.Dropout(0.5),
94 | nn.Linear(256, d_model),
95 | )
96 |
97 | assert scaling_xformers is False # TODO: update this block
98 | else:
99 | self.encoder_prenet = nn.Identity()
100 | if scaling_xformers:
101 | self.decoder_prenet = ScaledLinear(NUM_MEL_BINS, d_model)
102 | else:
103 | self.decoder_prenet = nn.Linear(NUM_MEL_BINS, d_model)
104 |
105 | self.encoder_position = SinePositionalEmbedding(
106 | d_model,
107 | dropout=0.1,
108 | scale=False,
109 | )
110 | self.decoder_position = SinePositionalEmbedding(
111 | d_model, dropout=0.1, scale=False
112 | )
113 |
114 | if scaling_xformers:
115 | self.encoder = TransformerEncoder(
116 | TransformerEncoderLayer(
117 | d_model,
118 | nhead,
119 | dim_feedforward=d_model * 4,
120 | dropout=0.1,
121 | batch_first=True,
122 | norm_first=norm_first,
123 | linear1_self_attention_cls=ScaledLinear,
124 | linear2_self_attention_cls=partial(
125 | ScaledLinear, initial_scale=0.01
126 | ),
127 | linear1_feedforward_cls=ScaledLinear,
128 | linear2_feedforward_cls=partial(
129 | ScaledLinear, initial_scale=0.01
130 | ),
131 | activation=partial(
132 | BalancedDoubleSwish,
133 | channel_dim=-1,
134 | max_abs=10.0,
135 | min_prob=0.25,
136 | ),
137 | layer_norm_cls=IdentityNorm,
138 | ),
139 | num_layers=num_layers,
140 | norm=BalancedBasicNorm(d_model) if norm_first else None,
141 | )
142 |
143 | self.decoder = nn.TransformerDecoder(
144 | TransformerDecoderLayer(
145 | d_model,
146 | nhead,
147 | dim_feedforward=d_model * 4,
148 | dropout=0.1,
149 | batch_first=True,
150 | norm_first=norm_first,
151 | linear1_self_attention_cls=ScaledLinear,
152 | linear2_self_attention_cls=partial(
153 | ScaledLinear, initial_scale=0.01
154 | ),
155 | linear1_feedforward_cls=ScaledLinear,
156 | linear2_feedforward_cls=partial(
157 | ScaledLinear, initial_scale=0.01
158 | ),
159 | activation=partial(
160 | BalancedDoubleSwish,
161 | channel_dim=-1,
162 | max_abs=10.0,
163 | min_prob=0.25,
164 | ),
165 | layer_norm_cls=IdentityNorm,
166 | ),
167 | num_layers=num_layers,
168 | norm=BalancedBasicNorm(d_model) if norm_first else None,
169 | )
170 |
171 | self.predict_layer = ScaledLinear(d_model, NUM_MEL_BINS)
172 | self.stop_layer = nn.Linear(d_model, 1)
173 | else:
174 | self.encoder = nn.TransformerEncoder(
175 | nn.TransformerEncoderLayer(
176 | d_model,
177 | nhead,
178 | dim_feedforward=d_model * 4,
179 | activation=F.relu,
180 | dropout=0.1,
181 | batch_first=True,
182 | norm_first=norm_first,
183 | ),
184 | num_layers=num_layers,
185 | norm=nn.LayerNorm(d_model) if norm_first else None,
186 | )
187 |
188 | self.decoder = nn.TransformerDecoder(
189 | nn.TransformerDecoderLayer(
190 | d_model,
191 | nhead,
192 | dim_feedforward=d_model * 4,
193 | activation=F.relu,
194 | dropout=0.1,
195 | batch_first=True,
196 | norm_first=norm_first,
197 | ),
198 | num_layers=num_layers,
199 | norm=nn.LayerNorm(d_model) if norm_first else None,
200 | )
201 |
202 | self.predict_layer = nn.Linear(d_model, NUM_MEL_BINS)
203 | self.stop_layer = nn.Linear(d_model, 1)
204 |
205 | self.stop_accuracy_metric = BinaryAccuracy(
206 | threshold=0.5, multidim_average="global"
207 | )
208 |
209 | # self.apply(self._init_weights)
210 |
211 | # def _init_weights(self, module):
212 | # if isinstance(module, (nn.Linear)):
213 | # module.weight.data.normal_(mean=0.0, std=0.02)
214 | # if isinstance(module, nn.Linear) and module.bias is not None:
215 | # module.bias.data.zero_()
216 | # elif isinstance(module, nn.LayerNorm):
217 | # module.bias.data.zero_()
218 | # module.weight.data.fill_(1.0)
219 | # elif isinstance(module, nn.Embedding):
220 | # module.weight.data.normal_(mean=0.0, std=0.02)
221 |
222 | def forward(
223 | self,
224 | x: torch.Tensor,
225 | x_lens: torch.Tensor,
226 | y: torch.Tensor,
227 | y_lens: torch.Tensor,
228 | reduction: str = "sum",
229 | train_stage: int = 0,
230 | **kwargs,
231 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
232 | """
233 | Args:
234 | x:
235 | A 2-D tensor of shape (N, S).
236 | x_lens:
237 | A 1-D tensor of shape (N,). It contains the number of tokens in `x`
238 | before padding.
239 | y:
240 | A 3-D tensor of shape (N, T, 8).
241 | y_lens:
242 | A 1-D tensor of shape (N,). It contains the number of tokens in `x`
243 | before padding.
244 | train_stage:
245 | Not used in this model.
246 | Returns:
247 |           Return the encoder/decoder outputs, the total loss (MSE plus weighted stop loss) and metrics.
248 | """
249 | del train_stage
250 |
251 | assert x.ndim == 2, x.shape
252 | assert x_lens.ndim == 1, x_lens.shape
253 | assert y.ndim == 3, y.shape
254 | assert y_lens.ndim == 1, y_lens.shape
255 |
256 | assert torch.all(x_lens > 0)
257 |
258 | # NOTE: x has been padded in TextTokenCollater
259 | x_mask = make_pad_mask(x_lens).to(x.device)
260 |
261 | x = self.text_embedding(x)
262 | x = self.encoder_prenet(x)
263 | x = self.encoder_position(x)
264 | x = self.encoder(x, src_key_padding_mask=x_mask)
265 |
266 | total_loss, metrics = 0.0, {}
267 |
268 | y_mask = make_pad_mask(y_lens).to(y.device)
269 | y_mask_float = y_mask.type(torch.float32)
270 | data_mask = 1.0 - y_mask_float.unsqueeze(-1)
271 |
272 | # Training
273 | # AR Decoder
274 | def pad_y(y):
275 | y = F.pad(y, (0, 0, 1, 0, 0, 0), value=0).detach()
276 | # inputs, targets
277 | return y[:, :-1], y[:, 1:]
278 |
279 | y, targets = pad_y(y * data_mask) # mask padding as zeros
280 |
281 | y_emb = self.decoder_prenet(y)
282 | y_pos = self.decoder_position(y_emb)
283 |
284 | y_len = y_lens.max()
285 | tgt_mask = torch.triu(
286 | torch.ones(y_len, y_len, device=y.device, dtype=torch.bool),
287 | diagonal=1,
288 | )
289 | y_dec = self.decoder(
290 | y_pos,
291 | x,
292 | tgt_mask=tgt_mask,
293 | memory_key_padding_mask=x_mask,
294 | )
295 |
296 | predict = self.predict_layer(y_dec)
297 | # loss
298 | total_loss = F.mse_loss(predict, targets, reduction=reduction)
299 |
300 | logits = self.stop_layer(y_dec).squeeze(-1)
301 | stop_loss = F.binary_cross_entropy_with_logits(
302 | logits,
303 | y_mask_float.detach(),
304 | weight=1.0 + y_mask_float.detach() * 4.0,
305 | reduction=reduction,
306 | )
307 | metrics["stop_loss"] = stop_loss.detach()
308 |
309 | stop_accuracy = self.stop_accuracy_metric(
310 | (torch.sigmoid(logits) >= 0.5).type(torch.int64),
311 | y_mask.type(torch.int64),
312 | )
313 | # icefall MetricsTracker.norm_items()
314 | metrics["stop_accuracy"] = stop_accuracy.item() * y_lens.sum().type(
315 | torch.float32
316 | )
317 |
318 | return ((x, predict), total_loss + 100.0 * stop_loss, metrics)
319 |
320 | def inference(
321 | self,
322 | x: torch.Tensor,
323 | x_lens: torch.Tensor,
324 | y: Any = None,
325 | **kwargs,
326 | ) -> torch.Tensor:
327 | """
328 | Args:
329 | x:
330 | A 2-D tensor of shape (1, S).
331 | x_lens:
332 | A 1-D tensor of shape (1,). It contains the number of tokens in `x`
333 | before padding.
334 | Returns:
335 | Return the predicted audio code matrix and cross-entropy loss.
336 | """
337 | assert x.ndim == 2, x.shape
338 | assert x_lens.ndim == 1, x_lens.shape
339 |
340 | assert torch.all(x_lens > 0)
341 |
342 | x_mask = make_pad_mask(x_lens).to(x.device)
343 |
344 | x = self.text_embedding(x)
345 | x = self.encoder_prenet(x)
346 | x = self.encoder_position(x)
347 | x = self.encoder(x, src_key_padding_mask=x_mask)
348 |
349 | x_mask = make_pad_mask(x_lens).to(x.device)
350 |
351 | # AR Decoder
352 | # TODO: Managing decoder steps avoid repetitive computation
353 | y = torch.zeros(
354 | [x.shape[0], 1, NUM_MEL_BINS], dtype=torch.float32, device=x.device
355 | )
356 | while True:
357 | y_emb = self.decoder_prenet(y)
358 | y_pos = self.decoder_position(y_emb)
359 |
360 | tgt_mask = torch.triu(
361 | torch.ones(
362 | y.shape[1], y.shape[1], device=y.device, dtype=torch.bool
363 | ),
364 | diagonal=1,
365 | )
366 |
367 | y_dec = self.decoder(
368 | y_pos,
369 | x,
370 | tgt_mask=tgt_mask,
371 | memory_mask=None,
372 | memory_key_padding_mask=x_mask,
373 | )
374 | predict = self.predict_layer(y_dec[:, -1:])
375 |
376 | logits = self.stop_layer(y_dec[:, -1:]) > 0 # sigmoid(0.0) = 0.5
377 | if y.shape[1] > x_lens.max() * 10 or all(logits.cpu().numpy()):
378 | print(
379 | f"TransformerTTS EOS [Text {x_lens[0]} -> Audio {y.shape[1]}]"
380 | )
381 | break
382 |
383 | y = torch.concat([y, predict], dim=1)
384 |
385 | return y[:, 1:]
386 |
387 | def visualize(
388 | self,
389 | predicts: Tuple[torch.Tensor],
390 | batch: Dict[str, Union[List, torch.Tensor]],
391 | output_dir: str,
392 | limit: int = 4,
393 | ) -> None:
394 | visualize(predicts, batch, output_dir, limit=limit)
395 |
--------------------------------------------------------------------------------
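
A small smoke-test sketch for the debug `Transformer` above, not taken from the repository itself: it feeds dummy tensors with the shapes documented in `forward()`, i.e. (N, S) text tokens and (N, T, NUM_MEL_BINS) fbank targets. The import paths and the deliberately tiny configuration are assumptions made for illustration.

```python
import torch

from uslm.models.macros import NUM_MEL_BINS, NUM_TEXT_TOKENS  # assumed import path
from uslm.models.transformer import Transformer

# Tiny configuration purely for a shape/loss check; real setups use larger dims.
model = Transformer(d_model=256, nhead=4, num_layers=2)

x = torch.randint(0, NUM_TEXT_TOKENS, (2, 15))  # (N, S) text token ids
x_lens = torch.tensor([15, 12])
y = torch.randn(2, 40, NUM_MEL_BINS)            # (N, T, 100) fbank targets
y_lens = torch.tensor([40, 33])                 # y_lens.max() must equal T

(encoder_out, predict), loss, metrics = model(x, x_lens, y, y_lens, reduction="sum")
print(loss.item(), metrics["stop_loss"].item(), float(metrics["stop_accuracy"]))
```
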
/models/uslm.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Fri. Sept. 8 00:45:46 2023
3 | @author: Dong Zhang
4 | """
5 |
6 |
7 | import faulthandler
8 | faulthandler.enable()
9 |
10 |
11 | import random
12 | from typing import Dict, Iterator, List, Tuple, Union
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from torchmetrics.classification import MulticlassAccuracy
18 | from icefall.utils import make_pad_mask
19 |
20 |
21 | from uslm.data.input_strategies import PromptedFeatures
22 | from uslm.modules.embedding import SinePositionalEmbedding, TokenEmbedding
23 | from uslm.modules.transformer import (
24 | AdaptiveLayerNorm,
25 | LayerNorm,
26 | TransformerDecoderLayer,
27 | TransformerEncoder,
28 | TransformerEncoderLayer,
29 | )
30 | from uslm.models.valle import VALLF, top_k_top_p_filtering, topk_sampling, Transpose
31 |
32 | from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
33 | from .visualizer import visualize
34 |
35 |
36 | class USLM(VALLF):
37 | def __init__(
38 | self,
39 | d_model: int,
40 | nhead: int,
41 | num_layers: int,
42 | norm_first: bool = True,
43 | add_prenet: bool = False,
44 | prefix_mode: int = 0,
45 | share_embedding: bool = True,
46 | nar_scale_factor: float = 1.0,
47 | **kwargs,
48 | ):
49 | """
50 | Args:
51 | d_model:
52 | The number of expected features in the input (required).
53 | nhead:
54 | The number of heads in the multiheadattention models (required).
55 | num_layers:
56 | The number of sub-decoder-layers in the decoder (required).
57 | """
58 | super(USLM, self).__init__(
59 | d_model,
60 | nhead,
61 | num_layers,
62 | norm_first=norm_first,
63 | add_prenet=add_prenet,
64 | decoder_cls=TransformerEncoder,
65 | decoder_layer_cls=TransformerEncoderLayer,
66 | prefix_mode=prefix_mode,
67 | share_embedding=share_embedding,
68 | nar_scale_factor=nar_scale_factor,
69 | **kwargs,
70 | )
71 |
72 | def forward(
73 | self,
74 | x: torch.Tensor,
75 | x_lens: torch.Tensor,
76 | y: Union[torch.Tensor, PromptedFeatures],
77 | y_lens: Union[torch.Tensor, PromptedFeatures],
78 | reduction: str = "sum",
79 | train_stage: int = 0,
80 | **kwargs,
81 | ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
82 | """
83 | Args:
84 | x:
85 | A 2-D tensor of shape (N, S).
86 | x_lens:
87 | A 1-D tensor of shape (N,). It contains the number of tokens in `x`
88 | before padding.
89 | y:
90 | A 3-D tensor of shape (N, T, 8).
91 | y_lens:
92 | A 1-D tensor of shape (N,). It contains the number of tokens in `x`
93 | before padding.
94 | train_stage:
95 | 0: AR & NAR modules, 1: AR modules, 2: NAR modules
96 | Returns:
97 | Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy.
98 | """
99 | assert x.ndim == 2, x.shape
100 | assert x_lens.ndim == 1, x_lens.shape
101 |
102 | y_prompts_codes = None
103 | if isinstance(y, PromptedFeatures):
104 | y_prompts_codes, y = y.data
105 | prompts_len, y_lens = y_lens.data
106 | assert prompts_len.min() == prompts_len.max()
107 | assert self.prefix_mode == 4
108 | y_prompts_codes = y_prompts_codes.type(torch.int64)
109 |
110 | assert y.ndim == 3, y.shape
111 | assert y_lens.ndim == 1, y_lens.shape
112 |
113 | # NOTE: x has been padded in TextTokenCollater
114 | x_mask = make_pad_mask(x_lens).to(x.device)
115 | y_mask = make_pad_mask(y_lens).to(y.device)
116 | y_mask_int = y_mask.type(torch.int64)
117 |
118 | text = x
119 | codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))
120 |
121 | y, targets = self.pad_y_eos(
122 | codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS
123 | )
124 |
125 | x_len = x_lens.max()
126 |
127 | metrics = {}
128 | total_loss = 0.0
129 |
130 | xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
131 | if self.ar_audio_prepend_bos:
132 | ar_xy_padding_mask = torch.concat(
133 | [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1
134 | )
135 | else:
136 | ar_xy_padding_mask = xy_padding_mask
137 | # AR Decoder
138 | if train_stage in [0, 1]:
139 | x = self.ar_text_embedding(text)
140 | x = self.ar_text_prenet(x)
141 | x = self.ar_text_position(x)
142 |
143 | y_len = y_lens.max() + int(self.ar_audio_prepend_bos)
144 |
145 | x_attn_mask = F.pad(
146 | torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
147 | (0, y_len),
148 | value=True,
149 | )
150 | y_attn_mask = F.pad(
151 | torch.triu(
152 | torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
153 | diagonal=1,
154 | ),
155 | (x_len, 0),
156 | value=False,
157 | )
158 | xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)
159 |
160 | # merge key padding and attention masks
161 | bsz, src_len = x.shape[0], x_len + y_len
162 | _xy_padding_mask = (
163 | ar_xy_padding_mask.view(bsz, 1, 1, src_len)
164 | .expand(-1, self.num_heads, -1, -1)
165 | .reshape(bsz * self.num_heads, 1, src_len)
166 | )
167 | xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
168 |
169 | new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
170 | new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
171 | xy_attn_mask = new_attn_mask
172 |
173 | y_emb = self.ar_audio_embedding(y)
174 | y_emb = self.ar_audio_prenet(y_emb)
175 | y_pos = self.ar_audio_position(y_emb)
176 |
177 | xy_pos = torch.concat([x, y_pos], dim=1)
178 |
179 | xy_dec, _ = self.ar_decoder(
180 | (xy_pos, None),
181 | mask=xy_attn_mask,
182 | # src_key_padding_mask=xy_padding_mask,
183 | # is_causal=True,
184 | )
185 | logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
186 | # loss
187 | total_loss = F.cross_entropy(logits, targets, reduction=reduction)
188 |
189 | metrics["ArTop10Accuracy"] = self.ar_accuracy_metric(
190 | logits.detach(), targets
191 | ).item() * y_lens.sum().type(torch.float32)
192 |
193 | if self.num_quantizers == 1:
194 | return ((x, codes), total_loss, metrics)
195 |
196 | # Non-AR Decoders
197 | if self.ar_audio_prepend_bos:
198 | y = y[:, 1:]
199 | if train_stage in [0, 2]:
200 | num_nar_layers = self.num_quantizers - 1
201 |
202 |             # layer 2: contains most of the timbre information
203 | nar_stage = 1
204 |
205 | x = self.nar_text_embedding(text)
206 | x = self.nar_text_prenet(x)
207 | x = self.nar_text_position(x)
208 |
209 | y_emb, prefix_len = self._prepare_prompts(
210 | y, y_lens, codes, nar_stage, y_prompts_codes
211 | )
212 |
213 | y_len = y_lens.max()
214 | targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int
215 | if self.prefix_mode in [2, 4]:
216 | xy_padding_mask = torch.concat(
217 | [
218 | x_mask,
219 | F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False),
220 | ],
221 | dim=1,
222 | )
223 | elif self.prefix_mode == 1:
224 | targets = targets[:, prefix_len:]
225 |
226 | y_pos = self.nar_audio_prenet(y_emb)
227 | y_pos = self.nar_audio_position(y_pos)
228 | xy_pos = torch.concat([x, y_pos], dim=1)
229 | xy_dec, _ = self.nar_decoder(
230 | (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
231 | src_key_padding_mask=xy_padding_mask,
232 | # is_causal=False,
233 | )
234 | xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
235 | if self.prefix_mode == 4:
236 | prefix_len = 0 # reset for Top10Accuracy metric
237 | logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(
238 | 0, 2, 1
239 | )
240 |
241 | # loss
242 | total_length = (y_lens).sum().type(torch.float32)
243 | total_loss += (
244 | F.cross_entropy(
245 | logits,
246 | targets,
247 | ignore_index=NUM_AUDIO_TOKENS,
248 | reduction=reduction,
249 | )
250 | * (total_length / (total_length - prefix_len * x.shape[0]))
251 | )
252 | metrics["NarL2Top10Accuracy"] = (
253 | self.nar_accuracy_metric(
254 | F.pad(
255 | logits.detach(),
256 | (0, 0, 0, 1, 0, 0),
257 | value=logits.min().cpu().item(),
258 | ),
259 | targets,
260 | ).item()
261 | * total_length
262 | )
263 |
264 |
265 |             # layers 3-8: the remaining quantizers
266 | nar_stage = self.rng.choices(
267 | [_k for _k in range(2, self.num_quantizers)],
268 | weights=[1.0 / (num_nar_layers-1)] * (num_nar_layers-1),
269 | k=1,
270 | )[0]
271 |
272 | x = self.nar_text_embedding(text)
273 | x = self.nar_text_prenet(x)
274 | x = self.nar_text_position(x)
275 |
276 | y_emb, prefix_len = self._prepare_prompts(
277 | y, y_lens, codes, nar_stage, y_prompts_codes
278 | )
279 |
280 | y_len = y_lens.max()
281 | targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int
282 | if self.prefix_mode in [2, 4]:
283 | xy_padding_mask = torch.concat(
284 | [
285 | x_mask,
286 | F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False),
287 | ],
288 | dim=1,
289 | )
290 | elif self.prefix_mode == 1:
291 | targets = targets[:, prefix_len:]
292 |
293 | y_pos = self.nar_audio_prenet(y_emb)
294 | y_pos = self.nar_audio_position(y_pos)
295 | xy_pos = torch.concat([x, y_pos], dim=1)
296 | xy_dec, _ = self.nar_decoder(
297 | (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
298 | src_key_padding_mask=xy_padding_mask,
299 | # is_causal=False,
300 | )
301 | xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
302 | if self.prefix_mode == 4:
303 | prefix_len = 0 # reset for Top10Accuracy metric
304 | logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(
305 | 0, 2, 1
306 | )
307 |
308 | # loss
309 | total_length = (y_lens).sum().type(torch.float32)
310 | total_loss += (
311 | F.cross_entropy(
312 | logits,
313 | targets,
314 | ignore_index=NUM_AUDIO_TOKENS,
315 | reduction=reduction,
316 | )
317 | * (total_length / (total_length - prefix_len * x.shape[0]))
318 | )
319 | metrics["NarL3-8Top10Accuracy"] = (
320 | self.nar_accuracy_metric(
321 | F.pad(
322 | logits.detach(),
323 | (0, 0, 0, 1, 0, 0),
324 | value=logits.min().cpu().item(),
325 | ),
326 | targets,
327 | ).item()
328 | * total_length
329 | )
330 |
331 | if train_stage == 0:
332 | total_loss = total_loss / 2.0
333 |
334 | return ((x, codes), total_loss, metrics)
335 |
336 | def inference(
337 | self,
338 | x: torch.Tensor,
339 | x_lens: torch.Tensor,
340 | y: torch.Tensor,
341 | enroll_x_lens: torch.Tensor,
342 | top_k: int = -100,
343 | temperature: float = 1.0,
344 | ) -> torch.Tensor:
345 | """
346 | Args:
347 | x:
348 | A 2-D tensor of shape (1, S).
349 | x_lens:
350 | A 1-D tensor of shape (1,). It contains the number of tokens in `x`
351 | before padding.
352 | y:
353 | A 3-D tensor of shape (1, T, 8).
354 | top_k: (`optional`) int
355 | The number of highest probability tokens to keep for top-k-filtering. Default to -100.
356 | temperature: (`optional`) float
357 | The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
358 | Returns:
359 | Return the predicted audio code matrix.
360 | """
361 | assert x.ndim == 2, x.shape
362 | assert x_lens.ndim == 1, x_lens.shape
363 | assert y.ndim == 3, y.shape
364 | assert y.shape[0] == 1, y.shape
365 |
366 | assert torch.all(x_lens > 0)
367 |
368 | # NOTE: x has been padded in TextTokenCollater
369 | text = x
370 | x = self.ar_text_embedding(text)
371 | x = self.ar_text_prenet(x)
372 | x = self.ar_text_position(x)
373 |
374 | text_len = x_lens.max()
375 | prompts = y
376 | prefix_len = y.shape[1]
377 |
378 | # AR Decoder
379 | # TODO: Managing decoder steps avoid repetitive computation
380 | y = prompts[..., 0]
381 | if self.ar_audio_prepend_bos:
382 | y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)
383 |
384 | x_len = x_lens.max()
385 | x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
386 |
387 | while True:
388 | y_emb = self.ar_audio_embedding(y)
389 | y_emb = self.ar_audio_prenet(y_emb)
390 | y_pos = self.ar_audio_position(y_emb)
391 | xy_pos = torch.concat([x, y_pos], dim=1)
392 |
393 | y_len = y.shape[1]
394 | x_attn_mask_pad = F.pad(
395 | x_attn_mask,
396 | (0, y_len),
397 | value=True,
398 | )
399 | y_attn_mask = F.pad(
400 | torch.triu(
401 | torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
402 | ),
403 | (x_len, 0),
404 | value=False,
405 | )
406 | xy_attn_mask = torch.concat(
407 | [x_attn_mask_pad, y_attn_mask], dim=0
408 | ).to(y.device)
409 |
410 | xy_dec, _ = self.ar_decoder(
411 | (xy_pos, None),
412 | mask=xy_attn_mask,
413 | )
414 | logits = self.ar_predict_layer(xy_dec[:, -1])
415 | samples = topk_sampling(
416 | logits, top_k=top_k, top_p=1.0, temperature=temperature
417 | )
418 |
419 | if (
420 | torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS
421 | or samples[0, 0] == NUM_AUDIO_TOKENS
422 | or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
423 | ):
424 | if prompts.shape[1] == y.shape[1]:
425 |                     raise RuntimeError(
426 |                         "A well-trained model should not emit EOS before generating any new tokens."
427 | )
428 |
429 | print(f"USLM EOS [{prompts.shape[1]} -> {y.shape[1]}]")
430 | break
431 |
432 | y = torch.concat([y, samples], dim=1)
433 |
434 | codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
435 | if self.num_quantizers == 1:
436 | return torch.stack(codes, dim=-1)
437 |
438 | # Non-AR Decoders
439 | y_emb = self.nar_audio_embeddings[0](
440 | y[:, int(self.ar_audio_prepend_bos) :]
441 | )
442 |
443 | if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes
444 | enrolled_len = enroll_x_lens.max().item()
445 | # SOS + Synthesis Text + EOS
446 | text = torch.concat(
447 | [
448 | text[:, :1],
449 | text[:, enrolled_len - 1 :],
450 | ],
451 | dim=1,
452 | )
453 | text_len = text_len - (enrolled_len - 2)
454 | assert text.shape[0] == 1
455 |
456 | x = self.nar_text_embedding(text)
457 | x = self.nar_text_prenet(x)
458 | x = self.nar_text_position(x)
459 |
460 | if self.prefix_mode == 0:
461 | for i, (predict_layer, embedding_layer) in enumerate(
462 | zip(
463 | self.nar_predict_layers,
464 | self.nar_audio_embeddings[1:],
465 | )
466 | ):
467 | y_pos = self.nar_audio_prenet(y_emb)
468 | y_pos = self.nar_audio_position(y_pos)
469 | xy_pos = torch.concat([x, y_pos], dim=1)
470 |
471 | xy_dec, _ = self.nar_decoder(
472 | (xy_pos, self.nar_stage_embeddings[i].weight)
473 | )
474 | logits = predict_layer(xy_dec[:, text_len + prefix_len :])
475 |
476 | samples = torch.argmax(logits, dim=-1)
477 | codes.append(samples)
478 |
479 | if i < self.num_quantizers - 2:
480 | y_emb[:, :prefix_len] += embedding_layer(
481 | prompts[..., i + 1]
482 | )
483 | y_emb[:, prefix_len:] += embedding_layer(samples)
484 | else:
485 | for j in range(1, self.num_quantizers):
486 | y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
487 | prompts[..., j]
488 | )
489 |
490 | for i, (predict_layer, embedding_layer) in enumerate(
491 | zip(
492 | self.nar_predict_layers,
493 | self.nar_audio_embeddings[1:],
494 | )
495 | ):
496 | y_pos = self.nar_audio_prenet(y_emb)
497 | y_pos = self.nar_audio_position(y_pos)
498 | xy_pos = torch.concat([x, y_pos], dim=1)
499 |
500 | xy_dec, _ = self.nar_decoder(
501 | (xy_pos, self.nar_stage_embeddings[i].weight)
502 | )
503 | logits = predict_layer(xy_dec[:, text_len + prefix_len :])
504 |
505 | samples = torch.argmax(logits, dim=-1)
506 | codes.append(samples)
507 |
508 | if i < self.num_quantizers - 2:
509 | y_emb[:, prefix_len:] += embedding_layer(samples)
510 |
511 | assert len(codes) == self.num_quantizers
512 | return torch.stack(codes, dim=-1)
--------------------------------------------------------------------------------
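
A hedged sketch of a single AR-only training step through `USLM.forward`, not taken from the repository itself, using the tensor shapes documented above: (N, S) phoneme ids and (N, T, 8) SpeechTokenizer codes. The tiny dimensions and import paths are assumptions for illustration; the CLI defaults in models/__init__.py are d_model=1024, nhead=16 and 12 layers, and the VALLF base class in models/valle.py is assumed to accept `num_quantizers` as a keyword argument.

```python
import torch

from uslm.models.macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS  # assumed import path
from uslm.models.uslm import USLM

# Deliberately tiny model for a quick shape/loss check.
model = USLM(d_model=128, nhead=4, num_layers=2, num_quantizers=8)

x = torch.randint(0, NUM_TEXT_TOKENS, (2, 11))      # (N, S) phoneme ids
x_lens = torch.tensor([11, 9])
y = torch.randint(0, NUM_AUDIO_TOKENS, (2, 50, 8))  # (N, T, 8) SpeechTokenizer codes
y_lens = torch.tensor([50, 42])                     # y_lens.max() must equal T

# train_stage=1 exercises only the AR branch (first-quantizer / content tokens).
(_, codes), loss, metrics = model(x, x_lens, y, y_lens, train_stage=1)
print(loss.item(), float(metrics["ArTop10Accuracy"]))
```
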
/models/visualizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # See ../../../../LICENSE for clarification regarding multiple authors
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 |
19 | from typing import Dict, List, Tuple, Union
20 |
21 | import matplotlib.pyplot as plt
22 | import numpy as np
23 | import torch
24 |
25 |
26 | def visualize(
27 | predicts: Tuple[torch.Tensor],
28 | batch: Dict[str, Union[List, torch.Tensor]],
29 | output_dir: str,
30 | limit: int = 4,
31 | ) -> None:
32 | text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
33 | text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
34 | audio_features = batch["audio_features"].to("cpu").detach().numpy()
35 | audio_features_lens = (
36 | batch["audio_features_lens"].to("cpu").detach().numpy()
37 | )
38 | assert text_tokens.ndim == 2
39 |
40 | utt_ids, texts = batch["utt_id"], batch["text"]
41 |
42 | encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
43 | decoder_outputs = predicts[1]
44 | if isinstance(decoder_outputs, list):
45 | decoder_outputs = decoder_outputs[-1]
46 | decoder_outputs = (
47 | decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
48 | )
49 |
50 | vmin, vmax = 0, 1024 # Encodec
51 | if decoder_outputs.dtype == np.float32:
52 | vmin, vmax = -6, 0 # Fbank
53 |
54 | num_figures = 3
55 | for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
56 | _ = plt.figure(figsize=(14, 8 * num_figures))
57 |
58 | S = text_tokens_lens[b]
59 | T = audio_features_lens[b]
60 |
61 | # encoder
62 | plt.subplot(num_figures, 1, 1)
63 | plt.title(f"Text: {text}")
64 | plt.imshow(
65 | X=np.transpose(encoder_outputs[b]),
66 | cmap=plt.get_cmap("jet"),
67 | aspect="auto",
68 | interpolation="nearest",
69 | )
70 | plt.gca().invert_yaxis()
71 | plt.axvline(x=S - 0.4, linewidth=2, color="r")
72 | plt.xlabel("Encoder Output")
73 | plt.colorbar()
74 |
75 | # decoder
76 | plt.subplot(num_figures, 1, 2)
77 | plt.imshow(
78 | X=np.transpose(decoder_outputs[b]),
79 | cmap=plt.get_cmap("jet"),
80 | aspect="auto",
81 | interpolation="nearest",
82 | vmin=vmin,
83 | vmax=vmax,
84 | )
85 | plt.gca().invert_yaxis()
86 | plt.axvline(x=T - 0.4, linewidth=2, color="r")
87 | plt.xlabel("Decoder Output")
88 | plt.colorbar()
89 |
90 | # target
91 | plt.subplot(num_figures, 1, 3)
92 | plt.imshow(
93 | X=np.transpose(audio_features[b]),
94 | cmap=plt.get_cmap("jet"),
95 | aspect="auto",
96 | interpolation="nearest",
97 | vmin=vmin,
98 | vmax=vmax,
99 | )
100 | plt.gca().invert_yaxis()
101 | plt.axvline(x=T - 0.4, linewidth=2, color="r")
102 | plt.xlabel("Decoder Target")
103 | plt.colorbar()
104 |
105 | plt.savefig(f"{output_dir}/{utt_id}.png")
106 | plt.close()
107 |
--------------------------------------------------------------------------------
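
A small sketch, not taken from the repository itself, showing the batch keys `visualize` expects; all values are random stand-ins for a single utterance and the import path is assumed from the package layout.

```python
import torch

from uslm.models.visualizer import visualize  # assumed import path

batch = {
    "utt_id": ["demo_utt"],
    "text": ["a demo utterance"],
    "text_tokens": torch.randint(0, 512, (1, 12)),
    "text_tokens_lens": torch.tensor([12]),
    "audio_features": torch.randn(1, 80, 100),  # e.g. fbank frames (T, 100)
    "audio_features_lens": torch.tensor([80]),
}
# predicts[0]: encoder outputs; predicts[1]: decoder outputs (float dtype -> fbank color range).
predicts = (torch.randn(1, 12, 128), torch.randn(1, 80, 100))

visualize(predicts, batch, output_dir=".", limit=1)  # writes ./demo_utt.png
```
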
/modules/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/modules/.DS_Store
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/modules/__init__.py
--------------------------------------------------------------------------------
/modules/activation.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import Linear, Module
6 | from torch.nn import functional as F
7 | from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
8 | from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
9 | from torch.nn.parameter import Parameter
10 |
11 |
12 | class MultiheadAttention(Module):
13 | r"""Allows the model to jointly attend to information
14 | from different representation subspaces as described in the paper:
15 |     `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
16 |
17 | Multi-Head Attention is defined as:
18 |
19 | .. math::
20 | \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
21 |
22 | where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
23 |
24 | ``forward()`` will use a special optimized implementation if all of the following
25 | conditions are met:
26 |
27 | - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This
28 | restriction will be loosened in the future.)
29 | - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad``
30 | - training is disabled (using ``.eval()``)
31 | - dropout is 0
32 | - ``add_bias_kv`` is ``False``
33 | - ``add_zero_attn`` is ``False``
34 | - ``batch_first`` is ``True`` and the input is batched
35 | - ``kdim`` and ``vdim`` are equal to ``embed_dim``
36 | - at most one of ``key_padding_mask`` or ``attn_mask`` is passed
37 |     - if a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ is passed, neither ``key_padding_mask``
38 | nor ``attn_mask`` is passed
39 |
40 | If the optimized implementation is in use, a
41 |     `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_ can be passed for
42 | ``query``/``key``/``value`` to represent padding more efficiently than using a
43 |     padding mask. In this case, a `NestedTensor <https://pytorch.org/docs/stable/nested.html>`_
44 | will be returned, and an additional speedup proportional to the fraction of the input
45 | that is padding can be expected.
46 |
47 | Args:
48 | embed_dim: Total dimension of the model.
49 | num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split
50 | across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``).
51 | dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout).
52 | bias: If specified, adds bias to input / output projection layers. Default: ``True``.
53 | add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``.
54 | add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1.
55 | Default: ``False``.
56 | kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``).
57 | vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``).
58 | batch_first: If ``True``, then the input and output tensors are provided
59 | as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
60 |
61 | Examples::
62 |
63 | >>> # xdoctest: +SKIP
64 | >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
65 | >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
66 |
67 | """
68 | __constants__ = ["batch_first"]
69 | bias_k: Optional[torch.Tensor]
70 | bias_v: Optional[torch.Tensor]
71 |
72 | def __init__(
73 | self,
74 | embed_dim,
75 | num_heads,
76 | dropout=0.0,
77 | bias=True,
78 | add_bias_kv=False,
79 | add_zero_attn=False,
80 | kdim=None,
81 | vdim=None,
82 | batch_first=False,
83 | linear1_cls=Linear,
84 | linear2_cls=Linear,
85 | device=None,
86 | dtype=None,
87 | ) -> None:
88 | factory_kwargs = {"device": device, "dtype": dtype}
89 | super(MultiheadAttention, self).__init__()
90 | self.embed_dim = embed_dim
91 | self.kdim = kdim if kdim is not None else embed_dim
92 | self.vdim = vdim if vdim is not None else embed_dim
93 | self._qkv_same_embed_dim = (
94 | self.kdim == embed_dim and self.vdim == embed_dim
95 | )
96 |
97 | self.num_heads = num_heads
98 | self.dropout = dropout
99 | self.batch_first = batch_first
100 | self.head_dim = embed_dim // num_heads
101 | assert (
102 | self.head_dim * num_heads == self.embed_dim
103 | ), "embed_dim must be divisible by num_heads"
104 |
105 | if add_bias_kv:
106 | self.bias_k = Parameter(
107 | torch.empty((1, 1, embed_dim), **factory_kwargs)
108 | )
109 | self.bias_v = Parameter(
110 | torch.empty((1, 1, embed_dim), **factory_kwargs)
111 | )
112 | else:
113 | self.bias_k = self.bias_v = None
114 |
115 | if linear1_cls == Linear:
116 | if not self._qkv_same_embed_dim:
117 | self.q_proj_weight = Parameter(
118 | torch.empty((embed_dim, embed_dim), **factory_kwargs)
119 | )
120 | self.k_proj_weight = Parameter(
121 | torch.empty((embed_dim, self.kdim), **factory_kwargs)
122 | )
123 | self.v_proj_weight = Parameter(
124 | torch.empty((embed_dim, self.vdim), **factory_kwargs)
125 | )
126 | self.register_parameter("in_proj_weight", None)
127 | else:
128 | self.in_proj_weight = Parameter(
129 | torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
130 | )
131 | self.register_parameter("q_proj_weight", None)
132 | self.register_parameter("k_proj_weight", None)
133 | self.register_parameter("v_proj_weight", None)
134 |
135 | if bias:
136 | self.in_proj_bias = Parameter(
137 | torch.empty(3 * embed_dim, **factory_kwargs)
138 | )
139 | else:
140 | self.register_parameter("in_proj_bias", None)
141 | self.out_proj = NonDynamicallyQuantizableLinear(
142 | embed_dim, embed_dim, bias=bias, **factory_kwargs
143 | )
144 |
145 | self._reset_parameters()
146 | else:
147 | if not self._qkv_same_embed_dim:
148 | raise NotImplementedError
149 | else:
150 | self.in_proj_linear = linear1_cls(
151 | embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
152 | )
153 | self.in_proj_weight = self.in_proj_linear.weight
154 |
155 | self.register_parameter("q_proj_weight", None)
156 | self.register_parameter("k_proj_weight", None)
157 | self.register_parameter("v_proj_weight", None)
158 |
159 | if bias:
160 | self.in_proj_bias = self.in_proj_linear.bias
161 | else:
162 | self.register_parameter("in_proj_bias", None)
163 |
164 | self.out_proj = linear2_cls(
165 | embed_dim, embed_dim, bias=bias, **factory_kwargs
166 | )
167 |
168 | if self.bias_k is not None:
169 | xavier_normal_(self.bias_k)
170 | if self.bias_v is not None:
171 | xavier_normal_(self.bias_v)
172 |
173 | self.add_zero_attn = add_zero_attn
174 |
175 | def _reset_parameters(self):
176 | if self._qkv_same_embed_dim:
177 | xavier_uniform_(self.in_proj_weight)
178 | else:
179 | xavier_uniform_(self.q_proj_weight)
180 | xavier_uniform_(self.k_proj_weight)
181 | xavier_uniform_(self.v_proj_weight)
182 |
183 | if self.in_proj_bias is not None:
184 | constant_(self.in_proj_bias, 0.0)
185 | constant_(self.out_proj.bias, 0.0)
186 |
187 | if self.bias_k is not None:
188 | xavier_normal_(self.bias_k)
189 | if self.bias_v is not None:
190 | xavier_normal_(self.bias_v)
191 |
192 | def __setstate__(self, state):
193 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0
194 | if "_qkv_same_embed_dim" not in state:
195 | state["_qkv_same_embed_dim"] = True
196 |
197 | super(MultiheadAttention, self).__setstate__(state)
198 |
199 | def forward(
200 | self,
201 | query: Tensor,
202 | key: Tensor,
203 | value: Tensor,
204 | key_padding_mask: Optional[Tensor] = None,
205 | need_weights: bool = True,
206 | attn_mask: Optional[Tensor] = None,
207 | average_attn_weights: bool = True,
208 | ) -> Tuple[Tensor, Optional[Tensor]]:
209 | r"""
210 | Args:
211 | query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False``
212 | or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length,
213 | :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``.
214 | Queries are compared against key-value pairs to produce the output.
215 | See "Attention Is All You Need" for more details.
216 | key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False``
217 | or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length,
218 | :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``.
219 | See "Attention Is All You Need" for more details.
220 | value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when
221 | ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source
222 | sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``.
223 | See "Attention Is All You Need" for more details.
224 | key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key``
225 | to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`.
226 | Binary and byte masks are supported.
227 | For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for
228 | the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value.
229 | need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.
230 | Default: ``True``.
231 | attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape
232 | :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size,
233 | :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be
234 | broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch.
235 | Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the
236 | corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the
237 | corresponding position is not allowed to attend. For a float mask, the mask values will be added to
238 | the attention weight.
239 | average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
240 | heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
241 | effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads)
242 |
243 | Outputs:
244 | - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched,
245 | :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``,
246 | where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the
247 | embedding dimension ``embed_dim``.
248 | - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``,
249 | returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or
250 | :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and
251 | :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
252 | head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`.
253 |
254 | .. note::
255 | `batch_first` argument is ignored for unbatched inputs.
256 | """
257 | is_batched = query.dim() == 3
258 | if key_padding_mask is not None:
259 | _kpm_dtype = key_padding_mask.dtype
260 | if _kpm_dtype != torch.bool and not torch.is_floating_point(
261 | key_padding_mask
262 | ):
263 | raise AssertionError(
264 | "only bool and floating types of key_padding_mask are supported"
265 | )
266 | why_not_fast_path = ""
267 | if not is_batched:
268 | why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
269 | elif query is not key or key is not value:
270 | # When lifting this restriction, don't forget to either
271 | # enforce that the dtypes all match or test cases where
272 | # they don't!
273 | why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
274 | elif (
275 | self.in_proj_bias is not None
276 | and query.dtype != self.in_proj_bias.dtype
277 | ):
278 | why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
279 | elif (
280 | self.in_proj_weight is not None
281 | and query.dtype != self.in_proj_weight.dtype
282 | ):
283 | # this case will fail anyway, but at least they'll get a useful error message.
284 | why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
285 | elif self.training:
286 | why_not_fast_path = "training is enabled"
287 | elif not self.batch_first:
288 | why_not_fast_path = "batch_first was not True"
289 | elif self.bias_k is not None:
290 | why_not_fast_path = "self.bias_k was not None"
291 | elif self.bias_v is not None:
292 | why_not_fast_path = "self.bias_v was not None"
293 | elif self.dropout:
294 | why_not_fast_path = f"dropout was {self.dropout}, required zero"
295 | elif self.add_zero_attn:
296 | why_not_fast_path = "add_zero_attn was enabled"
297 | elif not self._qkv_same_embed_dim:
298 | why_not_fast_path = "_qkv_same_embed_dim was not True"
299 | elif attn_mask is not None:
300 | why_not_fast_path = "attn_mask was not None"
301 | elif query.is_nested and key_padding_mask is not None:
302 | why_not_fast_path = (
303 | "key_padding_mask is not supported with NestedTensor input"
304 | )
305 | elif self.num_heads % 2 == 1:
306 | why_not_fast_path = "num_heads is odd"
307 | elif torch.is_autocast_enabled():
308 | why_not_fast_path = "autocast is enabled"
309 |
310 | if not why_not_fast_path:
311 | tensor_args = (
312 | query,
313 | key,
314 | value,
315 | self.in_proj_weight,
316 | self.in_proj_bias,
317 | self.out_proj.weight,
318 | self.out_proj.bias,
319 | )
320 | # We have to use list comprehensions below because TorchScript does not support
321 | # generator expressions.
322 | if torch.overrides.has_torch_function(tensor_args):
323 | why_not_fast_path = "some Tensor argument has_torch_function"
324 | elif not all(
325 | [
326 | (x is None or x.is_cuda or "cpu" in str(x.device))
327 | for x in tensor_args
328 | ]
329 | ):
330 | why_not_fast_path = (
331 | "some Tensor argument is neither CUDA nor CPU"
332 | )
333 | elif torch.is_grad_enabled() and any(
334 | [x is not None and x.requires_grad for x in tensor_args]
335 | ):
336 | why_not_fast_path = (
337 | "grad is enabled and at least one of query or the "
338 | "input/output projection weights or biases requires_grad"
339 | )
340 | if not why_not_fast_path:
341 | return torch._native_multi_head_attention(
342 | query,
343 | key,
344 | value,
345 | self.embed_dim,
346 | self.num_heads,
347 | self.in_proj_weight,
348 | self.in_proj_bias,
349 | self.out_proj.weight,
350 | self.out_proj.bias,
351 | key_padding_mask
352 | if key_padding_mask is not None
353 | else attn_mask,
354 | need_weights,
355 | average_attn_weights,
356 | 1
357 | if key_padding_mask is not None
358 | else 0
359 | if attn_mask is not None
360 | else None,
361 | )
362 |
363 | any_nested = query.is_nested or key.is_nested or value.is_nested
364 | assert not any_nested, (
365 | "MultiheadAttention does not support NestedTensor outside of its fast path. "
366 | + f"The fast path was not hit because {why_not_fast_path}"
367 | )
368 |
369 | if self.batch_first and is_batched:
370 | # make sure that the transpose op does not affect the "is" property
371 | if key is value:
372 | if query is key:
373 | query = key = value = query.transpose(1, 0)
374 | else:
375 | query, key = [x.transpose(1, 0) for x in (query, key)]
376 | value = key
377 | else:
378 | query, key, value = [
379 | x.transpose(1, 0) for x in (query, key, value)
380 | ]
381 |
382 | if not self._qkv_same_embed_dim:
383 | attn_output, attn_output_weights = F.multi_head_attention_forward(
384 | query,
385 | key,
386 | value,
387 | self.embed_dim,
388 | self.num_heads,
389 | self.in_proj_weight,
390 | self.in_proj_bias,
391 | self.bias_k,
392 | self.bias_v,
393 | self.add_zero_attn,
394 | self.dropout,
395 | self.out_proj.weight,
396 | self.out_proj.bias,
397 | training=self.training,
398 | key_padding_mask=key_padding_mask,
399 | need_weights=need_weights,
400 | attn_mask=attn_mask,
401 | use_separate_proj_weight=True,
402 | q_proj_weight=self.q_proj_weight,
403 | k_proj_weight=self.k_proj_weight,
404 | v_proj_weight=self.v_proj_weight,
405 | average_attn_weights=average_attn_weights,
406 | )
407 | else:
408 | attn_output, attn_output_weights = F.multi_head_attention_forward(
409 | query,
410 | key,
411 | value,
412 | self.embed_dim,
413 | self.num_heads,
414 | self.in_proj_weight,
415 | self.in_proj_bias,
416 | self.bias_k,
417 | self.bias_v,
418 | self.add_zero_attn,
419 | self.dropout,
420 | self.out_proj.weight,
421 | self.out_proj.bias,
422 | training=self.training,
423 | key_padding_mask=key_padding_mask,
424 | need_weights=need_weights,
425 | attn_mask=attn_mask,
426 | average_attn_weights=average_attn_weights,
427 | )
428 | if self.batch_first and is_batched:
429 | return attn_output.transpose(1, 0), attn_output_weights
430 | else:
431 | return attn_output, attn_output_weights
432 |
--------------------------------------------------------------------------------
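A minimal, self-contained usage sketch of the attention interface documented above. It uses the stock `torch.nn.MultiheadAttention`, which shares the same `forward()` signature, rather than the custom class, so the `linear1_cls`/`linear2_cls` hooks are not exercised; shapes and mask values are illustrative only.

```python
import torch
import torch.nn as nn

# Stock module with the same forward() signature as the class above.
attn = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

x = torch.randn(2, 100, 512)                   # (N, L, E) because batch_first=True
pad_mask = torch.zeros(2, 100, dtype=torch.bool)
pad_mask[:, 90:] = True                        # treat the last 10 positions as padding

out, weights = attn(x, x, x, key_padding_mask=pad_mask, need_weights=True)
print(out.shape)      # torch.Size([2, 100, 512])
print(weights.shape)  # torch.Size([2, 100, 100]) -- averaged across heads by default
```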
/modules/embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 (authors: Feiteng Li)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 |
21 | class TokenEmbedding(nn.Module):
22 | def __init__(
23 | self,
24 | dim_model: int,
25 | vocab_size: int,
26 | dropout: float = 0.0,
27 | ):
28 | super().__init__()
29 |
30 | self.vocab_size = vocab_size
31 | self.dim_model = dim_model
32 |
33 | self.dropout = torch.nn.Dropout(p=dropout)
34 | self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model)
35 |
36 | @property
37 | def weight(self) -> torch.Tensor:
38 | return self.word_embeddings.weight
39 |
40 | def embedding(self, index: int) -> torch.Tensor:
41 | return self.word_embeddings.weight[index : index + 1]
42 |
43 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
44 |         x = self.word_embeddings(x)
45 |         x = self.dropout(x)
46 | 
47 |         return x
48 |
49 |
50 | class SinePositionalEmbedding(nn.Module):
51 | def __init__(
52 | self,
53 | dim_model: int,
54 | dropout: float = 0.0,
55 | scale: bool = False,
56 | alpha: bool = False,
57 | ):
58 | super().__init__()
59 | self.dim_model = dim_model
60 | self.x_scale = math.sqrt(dim_model) if scale else 1.0
61 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
62 | self.dropout = torch.nn.Dropout(p=dropout)
63 |
64 | self.reverse = False
65 | self.pe = None
66 | self.extend_pe(torch.tensor(0.0).expand(1, 4000))
67 |
68 | def extend_pe(self, x):
69 | """Reset the positional encodings."""
70 | if self.pe is not None:
71 | if self.pe.size(1) >= x.size(1):
72 | if self.pe.dtype != x.dtype or self.pe.device != x.device:
73 | self.pe = self.pe.to(dtype=x.dtype, device=x.device)
74 | return
75 | pe = torch.zeros(x.size(1), self.dim_model)
76 | if self.reverse:
77 | position = torch.arange(
78 | x.size(1) - 1, -1, -1.0, dtype=torch.float32
79 | ).unsqueeze(1)
80 | else:
81 | position = torch.arange(
82 | 0, x.size(1), dtype=torch.float32
83 | ).unsqueeze(1)
84 | div_term = torch.exp(
85 | torch.arange(0, self.dim_model, 2, dtype=torch.float32)
86 | * -(math.log(10000.0) / self.dim_model)
87 | )
88 | pe[:, 0::2] = torch.sin(position * div_term)
89 | pe[:, 1::2] = torch.cos(position * div_term)
90 | pe = pe.unsqueeze(0)
91 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach()
92 |
93 | def forward(self, x: torch.Tensor) -> torch.Tensor:
94 | self.extend_pe(x)
95 | output = x.unsqueeze(-1) if x.ndim == 2 else x
96 | output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)]
97 | return self.dropout(output)
98 |
--------------------------------------------------------------------------------
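A standalone sketch of the table that `SinePositionalEmbedding.extend_pe` caches: even channels hold `sin(position * div_term)`, odd channels `cos(position * div_term)`, stored as a `(1, T, D)` buffer that `forward` combines with the input via `x_scale` and `alpha`. The function below simply restates that computation outside the module.

```python
import math
import torch

def sinusoidal_pe(length: int, dim_model: int) -> torch.Tensor:
    # Mirrors SinePositionalEmbedding.extend_pe (reverse=False branch).
    pe = torch.zeros(length, dim_model)
    position = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim_model, 2, dtype=torch.float32)
        * -(math.log(10000.0) / dim_model)
    )
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe.unsqueeze(0)  # (1, T, D), same layout as self.pe

print(sinusoidal_pe(4000, 512).shape)  # torch.Size([1, 4000, 512])
```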
/modules/scheduler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright 2023 (authors: Feiteng Li)
3 | #
4 | # See ../../../../LICENSE for clarification regarding multiple authors
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 |
18 |
19 | import torch
20 |
21 | from uslm.modules.optim import Eden
22 |
23 |
24 | def calc_lr(step, dim_embed, warmup_steps):
25 | return dim_embed ** (-0.5) * min(
26 | step ** (-0.5), step * warmup_steps ** (-1.5)
27 | )
28 |
29 |
30 | class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
31 | def __init__(
32 | self,
33 | base_lr: float,
34 | optimizer: torch.optim.Optimizer,
35 | dim_embed: int,
36 | warmup_steps: int,
37 | last_epoch: int = -1,
38 | verbose: bool = False,
39 | ) -> None:
40 |
41 | self.dim_embed = dim_embed
42 | self.base_lr = base_lr
43 | self.warmup_steps = warmup_steps
44 | self.num_param_groups = len(optimizer.param_groups)
45 |
46 | super().__init__(optimizer, last_epoch, verbose)
47 |
48 | def get_lr(self) -> float:
49 | lr = self.base_lr * calc_lr(
50 | self._step_count, self.dim_embed, self.warmup_steps
51 | )
52 | return [lr] * self.num_param_groups
53 |
54 | def set_step(self, step: int):
55 | self._step_count = step
56 |
57 |
58 | def get_scheduler(params, optimizer):
59 | if params.scheduler_name.lower() == "eden":
60 | scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps)
61 | elif params.scheduler_name.lower() == "noam":
62 | scheduler = NoamScheduler(
63 | params.base_lr,
64 | optimizer,
65 | params.decoder_dim,
66 | warmup_steps=params.warmup_steps,
67 | )
68 | # scheduler.set_step(params.start_batch or params.batch_idx_train)
69 | elif params.scheduler_name.lower() == "cosine":
70 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
71 |             optimizer,
72 |             T_max=params.warmup_steps,
73 | eta_min=params.base_lr,
74 | )
75 | else:
76 | raise NotImplementedError(f"{params.scheduler_name}")
77 |
78 | return scheduler
79 |
--------------------------------------------------------------------------------
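As a quick sanity check of the Noam schedule above: the learning rate rises linearly for `warmup_steps` steps, peaks at `step == warmup_steps`, and then decays as `step ** -0.5`. A minimal sketch reusing the same `calc_lr` formula with made-up hyper-parameters:

```python
def calc_lr(step, dim_embed, warmup_steps):
    return dim_embed ** (-0.5) * min(
        step ** (-0.5), step * warmup_steps ** (-1.5)
    )

base_lr, dim_embed, warmup_steps = 1.0, 1024, 4000  # illustrative values
for step in (1, 1000, 4000, 8000, 40000):
    print(step, base_lr * calc_lr(step, dim_embed, warmup_steps))
# The printed rate peaks at step == 4000 and then shrinks roughly as step ** -0.5.
```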
/modules/transformer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import numbers
3 | from functools import partial
4 | from typing import Any, Callable, List, Optional, Tuple, Union
5 |
6 | import torch
7 | from torch import Tensor, nn
8 | from torch.nn import functional as F
9 |
10 | from .activation import MultiheadAttention
11 | from .scaling import ActivationBalancer, BalancedDoubleSwish
12 | from .scaling import BasicNorm as _BasicNorm
13 |
14 | _shape_t = Union[int, List[int], torch.Size]
15 |
16 |
17 | class LayerNorm(nn.Module):
18 | __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
19 | normalized_shape: Tuple[int, ...]
20 | eps: float
21 | elementwise_affine: bool
22 |
23 | def __init__(
24 | self,
25 | normalized_shape: _shape_t,
26 | eps: float = 1e-5,
27 | elementwise_affine: bool = True,
28 | device=None,
29 | dtype=None,
30 | ) -> None:
31 | factory_kwargs = {"device": device, "dtype": dtype}
32 | super(LayerNorm, self).__init__()
33 | if isinstance(normalized_shape, numbers.Integral):
34 | # mypy error: incompatible types in assignment
35 | normalized_shape = (normalized_shape,) # type: ignore[assignment]
36 | self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
37 | self.eps = eps
38 | self.elementwise_affine = elementwise_affine
39 | if self.elementwise_affine:
40 | self.weight = nn.Parameter(
41 | torch.empty(self.normalized_shape, **factory_kwargs)
42 | )
43 | self.bias = nn.Parameter(
44 | torch.empty(self.normalized_shape, **factory_kwargs)
45 | )
46 | else:
47 | self.register_parameter("weight", None)
48 | self.register_parameter("bias", None)
49 |
50 | self.reset_parameters()
51 |
52 | def reset_parameters(self) -> None:
53 | if self.elementwise_affine:
54 | nn.init.ones_(self.weight)
55 | nn.init.zeros_(self.bias)
56 |
57 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
58 | if isinstance(input, tuple):
59 | input, embedding = input
60 | return (
61 | F.layer_norm(
62 | input,
63 | self.normalized_shape,
64 | self.weight,
65 | self.bias,
66 | self.eps,
67 | ),
68 | embedding,
69 | )
70 |
71 | assert embedding is None
72 | return F.layer_norm(
73 | input, self.normalized_shape, self.weight, self.bias, self.eps
74 | )
75 |
76 | def extra_repr(self) -> str:
77 | return (
78 | "{normalized_shape}, eps={eps}, "
79 | "elementwise_affine={elementwise_affine}".format(**self.__dict__)
80 | )
81 |
82 |
83 | class AdaptiveLayerNorm(nn.Module):
84 | r"""Adaptive Layer Normalization"""
85 |
86 | def __init__(self, d_model, norm) -> None:
87 | super(AdaptiveLayerNorm, self).__init__()
88 | self.project_layer = nn.Linear(d_model, 2 * d_model)
89 | self.norm = norm
90 | self.d_model = d_model
91 | self.eps = self.norm.eps
92 |
93 | def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
94 | if isinstance(input, tuple):
95 | input, embedding = input
96 | weight, bias = torch.split(
97 | self.project_layer(embedding),
98 | split_size_or_sections=self.d_model,
99 | dim=-1,
100 | )
101 | return (weight * self.norm(input) + bias, embedding)
102 |
103 | weight, bias = torch.split(
104 | self.project_layer(embedding),
105 | split_size_or_sections=self.d_model,
106 | dim=-1,
107 | )
108 | return weight * self.norm(input) + bias
109 |
110 |
111 | class BasicNorm(_BasicNorm):
112 | def __init__(
113 | self,
114 | d_model: int,
115 | eps: float = 1e-5,
116 | device=None,
117 | dtype=None,
118 | ):
119 | super(BasicNorm, self).__init__(d_model, eps=eps)
120 |
121 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
122 | if isinstance(input, tuple):
123 | input, embedding = input
124 | return (
125 | super(BasicNorm, self).forward(input),
126 | embedding,
127 | )
128 |
129 | assert embedding is None
130 | return super(BasicNorm, self).forward(input)
131 |
132 |
133 | class BalancedBasicNorm(nn.Module):
134 | def __init__(
135 | self,
136 | d_model: int,
137 | eps: float = 1e-5,
138 | device=None,
139 | dtype=None,
140 | ):
141 | super(BalancedBasicNorm, self).__init__()
142 | self.balancer = ActivationBalancer(
143 | d_model,
144 | channel_dim=-1,
145 | min_positive=0.45,
146 | max_positive=0.55,
147 | max_abs=6.0,
148 | )
149 | self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)
150 |
151 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
152 | if isinstance(input, tuple):
153 | input, embedding = input
154 | return self.norm((self.balancer(input), embedding))
155 |
156 | assert embedding is None
157 | return self.norm(self.balancer(input))
158 |
159 |
160 | class IdentityNorm(nn.Module):
161 | def __init__(
162 | self,
163 | d_model: int,
164 | eps: float = 1e-5,
165 | device=None,
166 | dtype=None,
167 | ) -> None:
168 | super(IdentityNorm, self).__init__()
169 |
170 | def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
171 | if isinstance(input, tuple):
172 | return input
173 |
174 | assert embedding is None
175 | return input
176 |
177 |
178 | class TransformerEncoderLayer(nn.Module):
179 | __constants__ = ["batch_first", "norm_first"]
180 |
181 | def __init__(
182 | self,
183 | d_model: int,
184 | nhead: int,
185 | dim_feedforward: int = 2048,
186 | dropout: float = 0.1,
187 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
188 | batch_first: bool = False,
189 | norm_first: bool = False,
190 | device=None,
191 | dtype=None,
192 | linear1_self_attention_cls: nn.Module = nn.Linear,
193 | linear2_self_attention_cls: nn.Module = nn.Linear,
194 | linear1_feedforward_cls: nn.Module = nn.Linear,
195 | linear2_feedforward_cls: nn.Module = nn.Linear,
196 | layer_norm_cls: nn.Module = LayerNorm,
197 | layer_norm_eps: float = 1e-5,
198 | adaptive_layer_norm=False,
199 | ) -> None:
200 | factory_kwargs = {"device": device, "dtype": dtype}
201 | super(TransformerEncoderLayer, self).__init__()
202 | self.self_attn = MultiheadAttention(
203 | d_model,
204 | nhead,
205 | dropout=dropout,
206 | batch_first=batch_first,
207 | linear1_cls=linear1_self_attention_cls,
208 | linear2_cls=linear2_self_attention_cls,
209 | **factory_kwargs,
210 | )
211 |
212 | # Implementation of Feedforward model
213 | self.linear1 = linear1_feedforward_cls(
214 | d_model, dim_feedforward, **factory_kwargs
215 | )
216 | self.dropout = nn.Dropout(dropout)
217 | self.linear2 = linear2_feedforward_cls(
218 | dim_feedforward, d_model, **factory_kwargs
219 | )
220 |
221 | self.norm_first = norm_first
222 | self.dropout1 = nn.Dropout(dropout)
223 | self.dropout2 = nn.Dropout(dropout)
224 |
225 | # Legacy string support for activation function.
226 | if isinstance(activation, str):
227 | activation = _get_activation_fn(activation)
228 | elif isinstance(activation, partial):
229 | activation = activation(d_model)
230 | elif activation == BalancedDoubleSwish:
231 | activation = BalancedDoubleSwish(d_model)
232 |
233 | # # We can't test self.activation in forward() in TorchScript,
234 | # # so stash some information about it instead.
235 | # if activation is F.relu or isinstance(activation, torch.nn.ReLU):
236 | # self.activation_relu_or_gelu = 1
237 | # elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
238 | # self.activation_relu_or_gelu = 2
239 | # else:
240 | # self.activation_relu_or_gelu = 0
241 | self.activation = activation
242 |
243 | norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs)
244 | if layer_norm_cls == IdentityNorm:
245 | norm2 = BalancedBasicNorm(
246 | d_model, eps=layer_norm_eps, **factory_kwargs
247 | )
248 | else:
249 | norm2 = layer_norm_cls(
250 | d_model, eps=layer_norm_eps, **factory_kwargs
251 | )
252 |
253 | if adaptive_layer_norm:
254 | self.norm1 = AdaptiveLayerNorm(d_model, norm1)
255 | self.norm2 = AdaptiveLayerNorm(d_model, norm2)
256 | else:
257 | self.norm1 = norm1
258 | self.norm2 = norm2
259 |
260 | def __setstate__(self, state):
261 | super(TransformerEncoderLayer, self).__setstate__(state)
262 | if not hasattr(self, "activation"):
263 | self.activation = F.relu
264 |
265 | def forward(
266 | self,
267 | src: Tensor,
268 | src_mask: Optional[Tensor] = None,
269 | src_key_padding_mask: Optional[Tensor] = None,
270 | ) -> Tensor:
271 | r"""Pass the input through the encoder layer.
272 |
273 | Args:
274 | src: the sequence to the encoder layer (required).
275 | src_mask: the mask for the src sequence (optional).
276 | src_key_padding_mask: the mask for the src keys per batch (optional).
277 |
278 | Shape:
279 | see the docs in Transformer class.
280 | """
281 | x, stage_embedding = src, None
282 | is_src_tuple = False
283 | if isinstance(src, tuple):
284 | x, stage_embedding = src
285 | is_src_tuple = True
286 |
287 | if src_key_padding_mask is not None:
288 | _skpm_dtype = src_key_padding_mask.dtype
289 | if _skpm_dtype != torch.bool and not torch.is_floating_point(
290 | src_key_padding_mask
291 | ):
292 | raise AssertionError(
293 | "only bool and floating types of key_padding_mask are supported"
294 | )
295 |
296 | if self.norm_first:
297 | x = x + self._sa_block(
298 | self.norm1(x, stage_embedding),
299 | src_mask,
300 | src_key_padding_mask,
301 | )
302 | x = x + self._ff_block(self.norm2(x, stage_embedding))
303 | else:
304 | x = self.norm1(
305 | x + self._sa_block(x, src_mask, src_key_padding_mask),
306 | stage_embedding,
307 | )
308 | x = self.norm2(x + self._ff_block(x), stage_embedding)
309 |
310 | if is_src_tuple:
311 | return (x, stage_embedding)
312 | return x
313 |
314 | # self-attention block
315 | def _sa_block(
316 | self,
317 | x: Tensor,
318 | attn_mask: Optional[Tensor],
319 | key_padding_mask: Optional[Tensor],
320 | ) -> Tensor:
321 | x = self.self_attn(
322 | x,
323 | x,
324 | x,
325 | attn_mask=attn_mask,
326 | key_padding_mask=key_padding_mask,
327 | need_weights=False,
328 | )[0]
329 | return self.dropout1(x)
330 |
331 | # feed forward block
332 | def _ff_block(self, x: Tensor) -> Tensor:
333 | x = self.linear2(self.dropout(self.activation(self.linear1(x))))
334 | return self.dropout2(x)
335 |
336 |
337 | class TransformerEncoder(nn.Module):
338 | r"""TransformerEncoder is a stack of N encoder layers. Users can build the
339 | BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
340 |
341 | Args:
342 | encoder_layer: an instance of the TransformerEncoderLayer() class (required).
343 | num_layers: the number of sub-encoder-layers in the encoder (required).
344 | norm: the layer normalization component (optional).
345 |         enable_nested_tensor: not supported here; unlike the stock
346 |             ``torch.nn.TransformerEncoder``, this implementation does not accept this
347 |             argument or convert padded inputs to nested tensors.
348 |
349 | Examples::
350 | >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
351 | >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
352 | >>> src = torch.rand(10, 32, 512)
353 | >>> out = transformer_encoder(src)
354 | """
355 | __constants__ = ["norm"]
356 |
357 | def __init__(self, encoder_layer, num_layers, norm=None):
358 | super(TransformerEncoder, self).__init__()
359 | self.layers = _get_clones(encoder_layer, num_layers)
360 | self.num_layers = num_layers
361 | self.norm = norm
362 |
363 | def forward(
364 | self,
365 | src: Tensor,
366 | mask: Optional[Tensor] = None,
367 | src_key_padding_mask: Optional[Tensor] = None,
368 | return_layer_states: bool = False,
369 | ) -> Tensor:
370 | r"""Pass the input through the encoder layers in turn.
371 |
372 | Args:
373 | src: the sequence to the encoder (required).
374 | mask: the mask for the src sequence (optional).
375 | src_key_padding_mask: the mask for the src keys per batch (optional).
376 | return_layer_states: return layers' state (optional).
377 |
378 | Shape:
379 | see the docs in Transformer class.
380 | """
381 | if return_layer_states:
382 | layer_states = [] # layers' output
383 | output = src
384 | for mod in self.layers:
385 | output = mod(
386 | output,
387 | src_mask=mask,
388 | src_key_padding_mask=src_key_padding_mask,
389 | )
390 | layer_states.append(output[0])
391 |
392 | if self.norm is not None:
393 | output = self.norm(output)
394 |
395 | return layer_states, output
396 |
397 | output = src
398 | for mod in self.layers:
399 | output = mod(
400 | output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
401 | )
402 |
403 | if self.norm is not None:
404 | output = self.norm(output)
405 |
406 | return output
407 |
408 |
409 | class TransformerDecoderLayer(nn.Module):
410 | __constants__ = ["batch_first", "norm_first"]
411 |
412 | def __init__(
413 | self,
414 | d_model: int,
415 | nhead: int,
416 | dim_feedforward: int = 2048,
417 | dropout: float = 0.1,
418 | activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
419 | linear1_self_attention_cls: nn.Module = nn.Linear,
420 | linear2_self_attention_cls: nn.Module = nn.Linear,
421 | linear1_feedforward_cls: nn.Module = nn.Linear,
422 | linear2_feedforward_cls: nn.Module = nn.Linear,
423 | batch_first: bool = False,
424 | norm_first: bool = False,
425 | device=None,
426 | dtype=None,
427 | layer_norm_cls: nn.Module = LayerNorm,
428 | layer_norm_eps: float = 1e-5,
429 | adaptive_layer_norm=False,
430 | ) -> None:
431 | factory_kwargs = {"device": device, "dtype": dtype}
432 | super(TransformerDecoderLayer, self).__init__()
433 | self.self_attn = MultiheadAttention(
434 | d_model,
435 | nhead,
436 | dropout=dropout,
437 | batch_first=batch_first,
438 | linear1_cls=linear1_self_attention_cls,
439 | linear2_cls=linear2_self_attention_cls,
440 | **factory_kwargs,
441 | )
442 | self.multihead_attn = MultiheadAttention(
443 | d_model,
444 | nhead,
445 | dropout=dropout,
446 | batch_first=batch_first,
447 | linear1_cls=linear1_self_attention_cls,
448 | linear2_cls=linear2_self_attention_cls,
449 | **factory_kwargs,
450 | )
451 | # Implementation of Feedforward model
452 | self.linear1 = linear1_feedforward_cls(
453 | d_model, dim_feedforward, **factory_kwargs
454 | )
455 | self.dropout = nn.Dropout(dropout)
456 | self.linear2 = linear2_feedforward_cls(
457 | dim_feedforward, d_model, **factory_kwargs
458 | )
459 |
460 | self.norm_first = norm_first
461 | self.dropout1 = nn.Dropout(dropout)
462 | self.dropout2 = nn.Dropout(dropout)
463 | self.dropout3 = nn.Dropout(dropout)
464 |
465 | # Legacy string support for activation function.
466 | if isinstance(activation, str):
467 | self.activation = _get_activation_fn(activation)
468 | elif isinstance(activation, partial):
469 | self.activation = activation(d_model)
470 | elif activation == BalancedDoubleSwish:
471 | self.activation = BalancedDoubleSwish(d_model)
472 | else:
473 | self.activation = activation
474 |
475 | if adaptive_layer_norm:
476 | norm1 = layer_norm_cls(
477 | d_model, eps=layer_norm_eps, **factory_kwargs
478 | )
479 | norm2 = layer_norm_cls(
480 | d_model, eps=layer_norm_eps, **factory_kwargs
481 | )
482 | norm3 = layer_norm_cls(
483 | d_model, eps=layer_norm_eps, **factory_kwargs
484 | )
485 |
486 | self.norm1 = AdaptiveLayerNorm(d_model, norm1)
487 | self.norm2 = AdaptiveLayerNorm(d_model, norm2)
488 | self.norm3 = AdaptiveLayerNorm(d_model, norm3)
489 | else:
490 | self.norm1 = layer_norm_cls(
491 | d_model, eps=layer_norm_eps, **factory_kwargs
492 | )
493 | self.norm2 = layer_norm_cls(
494 | d_model, eps=layer_norm_eps, **factory_kwargs
495 | )
496 | if layer_norm_cls == IdentityNorm:
497 | self.norm3 = BalancedBasicNorm(
498 | d_model, eps=layer_norm_eps, **factory_kwargs
499 | )
500 | else:
501 | self.norm3 = layer_norm_cls(
502 | d_model, eps=layer_norm_eps, **factory_kwargs
503 | )
504 |
505 | def forward(
506 | self,
507 | tgt: Tensor,
508 | memory: Tensor,
509 | tgt_mask: Optional[Tensor] = None,
510 | memory_mask: Optional[Tensor] = None,
511 | tgt_key_padding_mask: Optional[Tensor] = None,
512 | memory_key_padding_mask: Optional[Tensor] = None,
513 | ) -> Tensor:
514 | r"""Pass the inputs (and mask) through the decoder layer.
515 |
516 | Args:
517 | tgt: the sequence to the decoder layer (required).
518 | memory: the sequence from the last layer of the encoder (required).
519 | tgt_mask: the mask for the tgt sequence (optional).
520 | memory_mask: the mask for the memory sequence (optional).
521 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
522 | memory_key_padding_mask: the mask for the memory keys per batch (optional).
523 |
524 | Shape:
525 | see the docs in Transformer class.
526 | """
527 | tgt_is_tuple = False
528 | if isinstance(tgt, tuple):
529 | x, stage_embedding = tgt
530 | tgt_is_tuple = True
531 | else:
532 | x, stage_embedding = tgt, None
533 |
534 | if self.norm_first:
535 | x = x + self._sa_block(
536 | self.norm1(x, stage_embedding), tgt_mask, tgt_key_padding_mask
537 | )
538 | x = x + self._mha_block(
539 | self.norm2(x, stage_embedding),
540 | memory,
541 | memory_mask,
542 | memory_key_padding_mask,
543 | )
544 | x = x + self._ff_block(self.norm3(x, stage_embedding))
545 | else:
546 | x = self.norm1(
547 | x + self._sa_block(x, tgt_mask, tgt_key_padding_mask),
548 | stage_embedding,
549 | )
550 | x = self.norm2(
551 | x
552 | + self._mha_block(
553 | x, memory, memory_mask, memory_key_padding_mask
554 | ),
555 | stage_embedding,
556 | )
557 | x = self.norm3(x + self._ff_block(x), stage_embedding)
558 |
559 | if tgt_is_tuple:
560 | return (x, stage_embedding)
561 | return x
562 |
563 | # self-attention block
564 | def _sa_block(
565 | self,
566 | x: Tensor,
567 | attn_mask: Optional[Tensor],
568 | key_padding_mask: Optional[Tensor],
569 | ) -> Tensor:
570 | x = self.self_attn(
571 | x,
572 | x,
573 | x,
574 | attn_mask=attn_mask,
575 | key_padding_mask=key_padding_mask,
576 | need_weights=False,
577 | )[0]
578 | return self.dropout1(x)
579 |
580 | # multihead attention block
581 | def _mha_block(
582 | self,
583 | x: Tensor,
584 | mem: Tensor,
585 | attn_mask: Optional[Tensor],
586 | key_padding_mask: Optional[Tensor],
587 | ) -> Tensor:
588 | x = self.multihead_attn(
589 | x,
590 | mem,
591 | mem,
592 | attn_mask=attn_mask,
593 | key_padding_mask=key_padding_mask,
594 | need_weights=False,
595 | )[0]
596 | return self.dropout2(x)
597 |
598 | # feed forward block
599 | def _ff_block(self, x: Tensor) -> Tensor:
600 | x = self.linear2(self.dropout(self.activation(self.linear1(x))))
601 | return self.dropout3(x)
602 |
603 |
604 | def _get_clones(module, N):
605 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
606 |
607 |
608 | def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:
609 | if activation == "relu":
610 | return F.relu
611 | elif activation == "gelu":
612 | return F.gelu
613 |
614 | raise RuntimeError(
615 | "activation should be relu/gelu, not {}".format(activation)
616 | )
617 |
--------------------------------------------------------------------------------
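`AdaptiveLayerNorm` above conditions normalization on a stage embedding: the embedding is projected to `2 * d_model`, split into a per-channel scale and shift, and applied to the normalized input. A standalone sketch of that computation with plain PyTorch modules (shapes are illustrative):

```python
import torch
import torch.nn as nn

d_model = 512
norm = nn.LayerNorm(d_model)
project = nn.Linear(d_model, 2 * d_model)

x = torch.randn(2, 100, d_model)               # (N, T, D) hidden states
stage_embedding = torch.randn(2, 1, d_model)   # one conditioning vector per sequence

# Same scale-and-shift as AdaptiveLayerNorm.forward.
weight, bias = torch.split(project(stage_embedding), d_model, dim=-1)
out = weight * norm(x) + bias
print(out.shape)  # torch.Size([2, 100, 512])
```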
/prompts/1580_141083_000002_000002.normalized.txt:
--------------------------------------------------------------------------------
1 | mr Soames was a tall, spare man, of a nervous and excitable temperament.
--------------------------------------------------------------------------------
/prompts/1580_141083_000002_000002.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/prompts/1580_141083_000002_000002.wav
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | import os
3 | import sys
4 | from pathlib import Path
5 | from subprocess import DEVNULL, PIPE, run
6 |
7 | from setuptools import find_packages, setup
8 |
9 | project_root = Path(__file__).parent
10 |
11 | # modified from https://github.com/lhotse-speech/lhotse/blob/master/setup.py
12 |
13 |
14 |
15 | if sys.version_info < (3,):
16 | # fmt: off
17 | print(
18 |         "Python 2 has reached end-of-life and is no longer supported by USLM."
19 | )
20 | # fmt: on
21 | sys.exit(-1)
22 |
23 | if sys.version_info < (3, 7):
24 | print(
25 |         "Python 3.6 has reached end-of-life on December 23rd, 2021 "
26 |         "and is no longer supported by USLM."
27 | )
28 | sys.exit(-1)
29 |
30 |
31 |
32 |
33 | install_requires = [
34 | "encodec",
35 | "phonemizer",
36 | ]
37 |
38 | try:
39 |     # If the user already installed PyTorch, make sure they have torchaudio too.
40 | # Otherwise, we'll just install the latest versions from PyPI for the user.
41 | import torch
42 |
43 | try:
44 | import torchaudio
45 | except ImportError:
46 | raise ValueError(
47 | "We detected that you have already installed PyTorch, but haven't installed torchaudio. "
48 | "Unfortunately we can't detect the compatible torchaudio version for you; "
49 | "you will have to install it manually. "
50 | "For instructions, please refer either to https://pytorch.org/get-started/locally/ "
51 | "or https://github.com/pytorch/audio#dependencies"
52 | )
53 | except ImportError:
54 | install_requires.extend(["torch", "torchaudio"])
55 |
56 | docs_require = (
57 | (project_root / "requirements.txt").read_text().splitlines()
58 | )
59 | tests_require = [
60 | # "pytest==7.1.3",
61 | # "pytest-forked==1.4.0",
62 | # "pytest-xdist==2.5.0",
63 | # "pytest-cov==4.0.0",
64 | ]
65 | workflow_requires = []
66 | dev_requires = sorted(
67 | docs_require
68 | + tests_require
69 | + workflow_requires
70 | + ["jupyterlab", "matplotlib"]
71 | )
72 | all_requires = sorted(dev_requires)
73 |
74 | if os.environ.get("READTHEDOCS", False):
75 | # When building documentation, omit torchaudio installation and mock it instead.
76 | # This works around the inability to install libsoundfile1 in read-the-docs env,
77 | # which caused the documentation builds to silently crash.
78 | install_requires = [
79 | req
80 | for req in install_requires
81 | if not any(req.startswith(dep) for dep in ["torchaudio", "SoundFile"])
82 | ]
83 |
84 | setup(
85 | name="USLM",
86 | version='0.1.0',
87 | python_requires=">=3.7.0",
88 | description="USLM: Unified Speech Language Models",
89 | author="Dong Zhang",
90 | author_email="dongzhang22@fudan.edu.cn",
91 | long_description=(project_root / "README.md").read_text(encoding="utf-8"),
92 | long_description_content_type="text/markdown",
93 | license="Apache-2.0 License",
94 | packages=find_packages(exclude=["test", "test.*"]),
95 | include_package_data=True,
96 | entry_points={},
97 | install_requires=install_requires,
98 | extras_require={
99 | "docs": docs_require,
100 | "tests": tests_require,
101 | "dev": dev_requires,
102 | "all": all_requires,
103 | },
104 | classifiers=[
105 |         "Development Status :: 4 - Beta",
106 | "Programming Language :: Python :: 3.7",
107 | "Programming Language :: Python :: 3.8",
108 | "Programming Language :: Python :: 3.9",
109 | "Programming Language :: Python :: 3.10",
110 | "Intended Audience :: Science/Research",
111 | "Operating System :: POSIX :: Linux",
112 | "Operating System :: MacOS :: MacOS X",
113 | "License :: OSI Approved :: Apache Software License",
114 | "Topic :: Multimedia :: Sound/Audio :: Speech",
115 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
116 | "Topic :: Software Development :: Libraries :: Python Modules",
117 | "Typing :: Typed",
118 | ],
119 | )
120 |
--------------------------------------------------------------------------------
/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/0nutation/USLM/b6cbe07feaeb142fa395b8b98836c534fec16bc9/utils/.DS_Store
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from icefall.utils import make_pad_mask
4 |
5 | from .symbol_table import SymbolTable
6 |
7 | make_pad_mask = make_pad_mask
8 | SymbolTable = SymbolTable
9 |
10 |
11 | class Transpose(nn.Identity):
12 | """(N, T, D) -> (N, D, T)"""
13 |
14 | def forward(self, input: torch.Tensor) -> torch.Tensor:
15 | return input.transpose(1, 2)
16 |
--------------------------------------------------------------------------------
/utils/symbol_table.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang)
2 | #
3 | # See ../../../LICENSE for clarification regarding multiple authors
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from dataclasses import dataclass
18 | from dataclasses import field
19 | from typing import Dict
20 | from typing import Generic
21 | from typing import List
22 | from typing import Optional
23 | from typing import TypeVar
24 | from typing import Union
25 |
26 | Symbol = TypeVar('Symbol')
27 |
28 |
29 | # Disable __repr__ otherwise it could freeze e.g. Jupyter.
30 | @dataclass(repr=False)
31 | class SymbolTable(Generic[Symbol]):
32 | '''SymbolTable that maps symbol IDs, found on the FSA arcs to
33 | actual objects. These objects can be arbitrary Python objects
34 | that can serve as keys in a dictionary (i.e. they need to be
35 | hashable and immutable).
36 |
37 |     The SymbolTable can only be written to or read from disk if the
38 |     symbols are strings.
39 | '''
40 | _id2sym: Dict[int, Symbol] = field(default_factory=dict)
41 | '''Map an integer to a symbol.
42 | '''
43 |
44 | _sym2id: Dict[Symbol, int] = field(default_factory=dict)
45 | '''Map a symbol to an integer.
46 | '''
47 |
48 | _next_available_id: int = 1
49 | '''A helper internal field that helps adding new symbols
50 | to the table efficiently.
51 | '''
52 |
53 | eps: Symbol = ''
54 | '''Null symbol, always mapped to index 0.
55 | '''
56 |
57 | def __post_init__(self):
58 | for idx, sym in self._id2sym.items():
59 | assert self._sym2id[sym] == idx
60 | assert idx >= 0
61 |
62 | for sym, idx in self._sym2id.items():
63 | assert idx >= 0
64 | assert self._id2sym[idx] == sym
65 |
66 | if 0 not in self._id2sym:
67 | self._id2sym[0] = self.eps
68 | self._sym2id[self.eps] = 0
69 | else:
70 | assert self._id2sym[0] == self.eps
71 | assert self._sym2id[self.eps] == 0
72 |
73 | self._next_available_id = max(self._id2sym) + 1
74 |
75 | @staticmethod
76 | def from_str(s: str) -> 'SymbolTable':
77 | '''Build a symbol table from a string.
78 |
79 | The string consists of lines. Every line has two fields separated
80 | by space(s), tab(s) or both. The first field is the symbol and the
81 | second the integer id of the symbol.
82 |
83 | Args:
84 | s:
85 | The input string with the format described above.
86 | Returns:
87 | An instance of :class:`SymbolTable`.
88 | '''
89 | id2sym: Dict[int, str] = dict()
90 | sym2id: Dict[str, int] = dict()
91 |
92 | for line in s.split('\n'):
93 | fields = line.split()
94 | if len(fields) == 0:
95 | continue # skip empty lines
96 | assert len(fields) == 2, \
97 | f'Expect a line with 2 fields. Given: {len(fields)}'
98 | sym, idx = fields[0], int(fields[1])
99 | assert sym not in sym2id, f'Duplicated symbol {sym}'
100 | assert idx not in id2sym, f'Duplicated id {idx}'
101 | id2sym[idx] = sym
102 | sym2id[sym] = idx
103 |
104 | eps = id2sym.get(0, '')
105 |
106 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)
107 |
108 | @staticmethod
109 | def from_file(filename: str) -> 'SymbolTable':
110 | '''Build a symbol table from file.
111 |
112 | Every line in the symbol table file has two fields separated by
113 | space(s), tab(s) or both. The following is an example file:
114 |
115 | .. code-block::
116 |
117 | 0
118 | a 1
119 | b 2
120 | c 3
121 |
122 | Args:
123 | filename:
124 | Name of the symbol table file. Its format is documented above.
125 |
126 | Returns:
127 | An instance of :class:`SymbolTable`.
128 |
129 | '''
130 | with open(filename, 'r', encoding='utf-8') as f:
131 | return SymbolTable.from_str(f.read().strip())
132 |
133 | def to_str(self) -> str:
134 | '''
135 | Returns:
136 | Return a string representation of this object. You can pass
137 | it to the method ``from_str`` to recreate an identical object.
138 | '''
139 | s = ''
140 | for idx, symbol in sorted(self._id2sym.items()):
141 | s += f'{symbol} {idx}\n'
142 | return s
143 |
144 | def to_file(self, filename: str):
145 | '''Serialize the SymbolTable to a file.
146 |
147 | Every line in the symbol table file has two fields separated by
148 | space(s), tab(s) or both. The following is an example file:
149 |
150 | .. code-block::
151 |
152 | 0
153 | a 1
154 | b 2
155 | c 3
156 |
157 | Args:
158 | filename:
159 | Name of the symbol table file. Its format is documented above.
160 | '''
161 | with open(filename, 'w') as f:
162 | for idx, symbol in sorted(self._id2sym.items()):
163 | print(symbol, idx, file=f)
164 |
165 | def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
166 | '''Add a new symbol to the SymbolTable.
167 |
168 | Args:
169 | symbol:
170 | The symbol to be added.
171 | index:
172 | Optional int id to which the symbol should be assigned.
173 | If it is not available, a ValueError will be raised.
174 |
175 | Returns:
176 | The int id to which the symbol has been assigned.
177 | '''
178 | # Already in the table? Return its ID.
179 | if symbol in self._sym2id:
180 | return self._sym2id[symbol]
181 | # Specific ID not provided - use next available.
182 | if index is None:
183 | index = self._next_available_id
184 | # Specific ID provided but not available.
185 | if index in self._id2sym:
186 | raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - "
187 | f"already occupied by {self._id2sym[index]}")
188 | self._sym2id[symbol] = index
189 | self._id2sym[index] = symbol
190 |
191 | # Update next available ID if needed
192 | if self._next_available_id <= index:
193 | self._next_available_id = index + 1
194 |
195 | return index
196 |
197 | def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
198 | '''Get a symbol for an id or get an id for a symbol
199 |
200 | Args:
201 | k:
202 | If it is an id, it tries to find the symbol corresponding
203 | to the id; if it is a symbol, it tries to find the id
204 | corresponding to the symbol.
205 |
206 | Returns:
207 | An id or a symbol depending on the given `k`.
208 | '''
209 | if isinstance(k, int):
210 | return self._id2sym[k]
211 | else:
212 | return self._sym2id[k]
213 |
214 | def merge(self, other: 'SymbolTable') -> 'SymbolTable':
215 | '''Create a union of two SymbolTables.
216 | Raises an AssertionError if the same IDs are occupied by
217 | different symbols.
218 |
219 | Args:
220 | other:
221 | A symbol table to merge with ``self``.
222 |
223 | Returns:
224 | A new symbol table.
225 | '''
226 | self._check_compatible(other)
227 |
228 | id2sym = {**self._id2sym, **other._id2sym}
229 | sym2id = {**self._sym2id, **other._sym2id}
230 |
231 | return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=self.eps)
232 |
233 | def _check_compatible(self, other: 'SymbolTable') -> None:
234 | # Epsilon compatibility
235 | assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \
236 | f'{self.eps} != {other.eps}'
237 | # IDs compatibility
238 | common_ids = set(self._id2sym).intersection(other._id2sym)
239 | for idx in common_ids:
240 | assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \
241 | f'self[idx] = "{self[idx]}", ' \
242 | f'other[idx] = "{other[idx]}"'
243 | # Symbols compatibility
244 | common_symbols = set(self._sym2id).intersection(other._sym2id)
245 | for sym in common_symbols:
246 |             assert self[sym] == other[sym], f'ID conflict for symbol: {sym}, ' \
247 | f'self[sym] = "{self[sym]}", ' \
248 | f'other[sym] = "{other[sym]}"'
249 |
250 | def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
251 | return self.get(item)
252 |
253 | def __contains__(self, item: Union[int, Symbol]) -> bool:
254 | if isinstance(item, int):
255 | return item in self._id2sym
256 | else:
257 | return item in self._sym2id
258 |
259 | def __len__(self) -> int:
260 | return len(self._id2sym)
261 |
262 | def __eq__(self, other: 'SymbolTable') -> bool:
263 | if len(self) != len(other):
264 | return False
265 |
266 | for s in self.symbols:
267 | if self[s] != other[s]:
268 | return False
269 |
270 | return True
271 |
272 | @property
273 | def ids(self) -> List[int]:
274 | '''Returns a list of integer IDs corresponding to the symbols.
275 | '''
276 | ans = list(self._id2sym.keys())
277 | ans.sort()
278 | return ans
279 |
280 | @property
281 | def symbols(self) -> List[Symbol]:
282 | '''Returns a list of symbols (e.g., strings) corresponding to
283 | the integer IDs.
284 | '''
285 | ans = list(self._sym2id.keys())
286 | ans.sort()
287 | return ans
288 |
--------------------------------------------------------------------------------
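A short usage sketch of `SymbolTable`: build a table from a string, add a new symbol, and look entries up in both directions. The import path is an assumption based on the package layout above (`setup.py` packages the project and `modules/scheduler.py` imports from `uslm.*`).

```python
from uslm.utils.symbol_table import SymbolTable  # assumed import path

table = SymbolTable.from_str("a 1\nb 2\nc 3")

idx = table.add("d")         # assigned the next available id
print(idx)                   # 4
print(table["d"], table[2])  # 4 b
print("b" in table)          # True

print(table.to_str())
#  0        <- epsilon symbol, always mapped to id 0
# a 1
# b 2
# c 3
# d 4
```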