├── ChatTTS
│   ├── __init__.py
│   ├── utils
│   │   ├── io_utils.py
│   │   ├── gpu_utils.py
│   │   └── infer_utils.py
│   ├── experimental
│   │   └── llm.py
│   ├── infer
│   │   └── api.py
│   ├── model
│   │   ├── dvae.py
│   │   └── gpt.py
│   └── core.py
├── OpenVoice
│   ├── checkpoints
│   │   ├── converter
│   │   │   ├── config_json
│   │   │   └── checkpiont_pth
│   │   └── openvoice_v1_in_here
│   ├── text
│   │   ├── cleaners.py
│   │   ├── symbols.py
│   │   ├── __init__.py
│   │   ├── english.py
│   │   └── mandarin.py
│   ├── utils
│   │   ├── se_extractor.py
│   │   ├── commons.py
│   │   ├── utils.py
│   │   ├── mel_processing.py
│   │   ├── transforms.py
│   │   ├── attentions.py
│   │   ├── models.py
│   │   └── modules.py
│   └── api.py
├── requirements.txt
├── README.md
├── gitattributes
├── app.py
└── LICENSE
/ChatTTS/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import Chat
--------------------------------------------------------------------------------
/OpenVoice/checkpoints/converter/config_json:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/OpenVoice/checkpoints/openvoice_v1_in_here:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/OpenVoice/checkpoints/converter/checkpiont_pth:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/ChatTTS/utils/io_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import logging
4 |
5 | def get_latest_modified_file(directory):
6 | logger = logging.getLogger(__name__)
7 |
8 | files = [os.path.join(directory, f) for f in os.listdir(directory)]
9 | if not files:
10 | logger.log(logging.WARNING, f'No files found in the directory: {directory}')
11 | return None
12 | latest_file = max(files, key=os.path.getmtime)
13 |
14 | return latest_file
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # PyTorch and related libraries
2 | torch
3 | torchvision
4 | torchaudio
5 |
6 | # Hugging Face transformers library
7 | transformers
8 |
9 | # Configuration management with OmegaConf
10 | omegaconf
11 |
12 | # Interactive widgets for Jupyter Notebooks
13 | ipywidgets
14 |
15 | # Gradio for creating web UIs
16 | gradio
17 |
18 | # Vector quantization for PyTorch
19 | vector_quantize_pytorch
20 |
21 | # Hugging Face Hub client
22 | huggingface_hub
23 |
24 | vocos
25 |
26 | # OpenVoice
27 | librosa
28 | faster-whisper
29 | pydub
30 | wavmark
31 | numpy
32 | eng_to_ipa
33 | inflect
34 | unidecode
35 | whisper-timestamped
36 | openai
37 | python-dotenv
38 | pypinyin
39 | jieba
40 | cn2an
41 | edge_tts
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | __ChatTTS x OpenVoice__
2 |
3 | Enhance the authenticity of generated speech by using ChatTTS for natural voice generation, complemented by the voice timbre simulation module from OpenVoice for seamless tone color transplantation.
4 |
5 | Try it out on Hugging Face Spaces:
6 | https://huggingface.co/spaces/Hilley/ChatTTS-OpenVoice
7 |
8 |
9 |
10 |
11 |
12 | ---
13 | __Notice:__
14 |
15 | Download the OpenVoice checkpoints and place them in the __./OpenVoice/checkpoints__ folder.
16 |
17 | __OpenVoice Checkpoint:__ https://huggingface.co/myshell-ai/OpenVoice/tree/main/checkpoints
18 |
19 |
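20 | For example, the checkpoints can be fetched with `huggingface_hub` (a minimal sketch; it downloads only the `checkpoints/` folder of the OpenVoice model repo so the layout matches the tree above):
21 | 
22 | ```python
23 | from huggingface_hub import snapshot_download
24 | 
25 | # Pull only the checkpoints/ folder of myshell-ai/OpenVoice into ./OpenVoice
26 | snapshot_download(
27 |     repo_id="myshell-ai/OpenVoice",
28 |     allow_patterns=["checkpoints/*"],
29 |     local_dir="./OpenVoice",
30 | )
31 | ```
32 | 
33 | A minimal end-to-end sketch of the two-stage pipeline, condensed from `app.py` (the reference audio path and output file names are placeholders):
34 | 
35 | ```python
36 | import numpy as np
37 | import soundfile
38 | import torch
39 | 
40 | import ChatTTS
41 | from OpenVoice.api import ToneColorConverter
42 | from OpenVoice.utils import se_extractor
43 | 
44 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
45 | 
46 | # 1. Generate base speech with ChatTTS using a random speaker embedding
47 | chat = ChatTTS.Chat()
48 | chat.load_models()
49 | wav = chat.infer(["Hello, this is a quick test."],
50 |                  params_infer_code={'spk_emb': torch.randn(768)})
51 | soundfile.write("tmp.wav", np.array(wav[0]).flatten(), 24000)
52 | 
53 | # 2. Shift the tone color towards the reference speaker with OpenVoice
54 | converter = ToneColorConverter("OpenVoice/checkpoints/converter/config.json", device=device)
55 | converter.load_ckpt("OpenVoice/checkpoints/converter/checkpoint.pth")
56 | source_se, _ = se_extractor.get_se("tmp.wav", converter, target_dir="processed", vad=True)
57 | target_se, _ = se_extractor.get_se("reference_speaker.mp3", converter, target_dir="processed", vad=True)
58 | converter.convert(audio_src_path="tmp.wav", src_se=source_se,
59 |                   tgt_se=target_se, output_path="output.wav")
60 | ```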
--------------------------------------------------------------------------------
/OpenVoice/text/cleaners.py:
--------------------------------------------------------------------------------
1 | import re
2 | from .english import english_to_lazy_ipa, english_to_ipa2, english_to_lazy_ipa2
3 | from .mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo, chinese_to_romaji, chinese_to_lazy_ipa, chinese_to_ipa, chinese_to_ipa2
4 |
5 | def cjke_cleaners2(text):
6 |     # Only [ZH] and [EN] segments are handled here: the Japanese/Korean
7 |     # converters are not bundled with this repository.
8 |     text = re.sub(r'\[ZH\](.*?)\[ZH\]',
9 |                   lambda x: chinese_to_ipa(x.group(1))+' ', text)
10 |     text = re.sub(r'\[EN\](.*?)\[EN\]',
11 |                   lambda x: english_to_ipa2(x.group(1))+' ', text)
12 |     text = re.sub(r'\s+$', '', text)
13 |     text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
14 |     return text
--------------------------------------------------------------------------------
/ChatTTS/utils/gpu_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import logging
4 |
5 | def select_device(min_memory = 2048):
6 | logger = logging.getLogger(__name__)
7 | if torch.cuda.is_available():
8 | available_gpus = []
9 | for i in range(torch.cuda.device_count()):
10 | props = torch.cuda.get_device_properties(i)
11 | free_memory = props.total_memory - torch.cuda.memory_reserved(i)
12 | available_gpus.append((i, free_memory))
13 | selected_gpu, max_free_memory = max(available_gpus, key=lambda x: x[1])
14 | device = torch.device(f'cuda:{selected_gpu}')
15 | free_memory_mb = max_free_memory / (1024 * 1024)
16 | if free_memory_mb < min_memory:
17 | logger.log(logging.WARNING, f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left.')
18 | device = torch.device('cpu')
19 | else:
20 |         logger.log(logging.WARNING, 'No GPU found, using CPU instead.')
21 | device = torch.device('cpu')
22 |
23 | return device
24 |
--------------------------------------------------------------------------------
/gitattributes:
--------------------------------------------------------------------------------
1 | *.7z filter=lfs diff=lfs merge=lfs -text
2 | *.arrow filter=lfs diff=lfs merge=lfs -text
3 | *.bin filter=lfs diff=lfs merge=lfs -text
4 | *.bz2 filter=lfs diff=lfs merge=lfs -text
5 | *.ckpt filter=lfs diff=lfs merge=lfs -text
6 | *.ftz filter=lfs diff=lfs merge=lfs -text
7 | *.gz filter=lfs diff=lfs merge=lfs -text
8 | *.h5 filter=lfs diff=lfs merge=lfs -text
9 | *.joblib filter=lfs diff=lfs merge=lfs -text
10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text
11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text
12 | *.model filter=lfs diff=lfs merge=lfs -text
13 | *.msgpack filter=lfs diff=lfs merge=lfs -text
14 | *.npy filter=lfs diff=lfs merge=lfs -text
15 | *.npz filter=lfs diff=lfs merge=lfs -text
16 | *.onnx filter=lfs diff=lfs merge=lfs -text
17 | *.ot filter=lfs diff=lfs merge=lfs -text
18 | *.parquet filter=lfs diff=lfs merge=lfs -text
19 | *.pb filter=lfs diff=lfs merge=lfs -text
20 | *.pickle filter=lfs diff=lfs merge=lfs -text
21 | *.pkl filter=lfs diff=lfs merge=lfs -text
22 | *.pt filter=lfs diff=lfs merge=lfs -text
23 | *.pth filter=lfs diff=lfs merge=lfs -text
24 | *.rar filter=lfs diff=lfs merge=lfs -text
25 | *.safetensors filter=lfs diff=lfs merge=lfs -text
26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27 | *.tar.* filter=lfs diff=lfs merge=lfs -text
28 | *.tar filter=lfs diff=lfs merge=lfs -text
29 | *.tflite filter=lfs diff=lfs merge=lfs -text
30 | *.tgz filter=lfs diff=lfs merge=lfs -text
31 | *.wasm filter=lfs diff=lfs merge=lfs -text
32 | *.xz filter=lfs diff=lfs merge=lfs -text
33 | *.zip filter=lfs diff=lfs merge=lfs -text
34 | *.zst filter=lfs diff=lfs merge=lfs -text
35 | *tfevents* filter=lfs diff=lfs merge=lfs -text
36 |
--------------------------------------------------------------------------------
/ChatTTS/utils/infer_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | class CustomRepetitionPenaltyLogitsProcessorRepeat():
7 |
8 | def __init__(self, penalty: float, max_input_ids, past_window):
9 | if not isinstance(penalty, float) or not (penalty > 0):
10 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
11 |
12 | self.penalty = penalty
13 | self.max_input_ids = max_input_ids
14 | self.past_window = past_window
15 |
16 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
17 |
18 | input_ids = input_ids[:, -self.past_window:]
19 | freq = F.one_hot(input_ids, scores.size(1)).sum(1)
20 | freq[self.max_input_ids:] = 0
21 | alpha = self.penalty**freq
22 | scores = torch.where(scores < 0, scores*alpha, scores/alpha)
23 |
24 | return scores
25 |
26 | class CustomRepetitionPenaltyLogitsProcessor():
27 |
28 | def __init__(self, penalty: float, max_input_ids, past_window):
29 | if not isinstance(penalty, float) or not (penalty > 0):
30 | raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
31 |
32 | self.penalty = penalty
33 | self.max_input_ids = max_input_ids
34 | self.past_window = past_window
35 |
36 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
37 |
38 | input_ids = input_ids[:, -self.past_window:]
39 | score = torch.gather(scores, 1, input_ids)
40 | _score = score.detach().clone()
41 | score = torch.where(score < 0, score * self.penalty, score / self.penalty)
42 | score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
43 | scores.scatter_(1, input_ids, score)
44 |
45 | return scores
--------------------------------------------------------------------------------
/ChatTTS/experimental/llm.py:
--------------------------------------------------------------------------------
1 |
2 | from openai import OpenAI
3 |
4 | prompt_dict = {
5 | 'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"},
6 | {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
7 | {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
8 | 'deepseek': [
9 | {"role": "system", "content": "You are a helpful assistant"},
10 | {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
11 | {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
12 | 'deepseek_TN': [
13 | {"role": "system", "content": "You are a helpful assistant"},
14 | {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"},
15 | {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"},
16 | {"role": "user", "content": "We paid $123 for this desk."},
17 | {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."},
18 | {"role": "user", "content": "详询请拨打010-724654"},
19 | {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"},
20 | {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"},
21 | {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"},
22 | ],
23 | }
24 |
25 | class llm_api:
26 | def __init__(self, api_key, base_url, model):
27 | self.client = OpenAI(
28 | api_key = api_key,
29 | base_url = base_url,
30 | )
31 | self.model = model
32 | def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs):
33 |
34 | completion = self.client.chat.completions.create(
35 | model = self.model,
36 | messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},],
37 | temperature = temperature,
38 | **kwargs
39 | )
40 | return completion.choices[0].message.content
41 |
--------------------------------------------------------------------------------
/OpenVoice/text/symbols.py:
--------------------------------------------------------------------------------
1 | '''
2 | Defines the set of symbols used in text input to the model.
3 | '''
4 |
5 | # japanese_cleaners
6 | # _pad = '_'
7 | # _punctuation = ',.!?-'
8 | # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
9 |
10 |
11 | '''# japanese_cleaners2
12 | _pad = '_'
13 | _punctuation = ',.!?-~…'
14 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
15 | '''
16 |
17 |
18 | '''# korean_cleaners
19 | _pad = '_'
20 | _punctuation = ',.!?…~'
21 | _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
22 | '''
23 |
24 | '''# chinese_cleaners
25 | _pad = '_'
26 | _punctuation = ',。!?—…'
27 | _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
28 | '''
29 |
30 | # # zh_ja_mixture_cleaners
31 | # _pad = '_'
32 | # _punctuation = ',.!?-~…'
33 | # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
34 |
35 |
36 | '''# sanskrit_cleaners
37 | _pad = '_'
38 | _punctuation = '।'
39 | _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
40 | '''
41 |
42 | '''# cjks_cleaners
43 | _pad = '_'
44 | _punctuation = ',.!?-~…'
45 | _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
46 | '''
47 |
48 | '''# thai_cleaners
49 | _pad = '_'
50 | _punctuation = '.!? '
51 | _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
52 | '''
53 |
54 | # # cjke_cleaners2
55 | _pad = '_'
56 | _punctuation = ',.!?-~…'
57 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
58 |
59 |
60 | '''# shanghainese_cleaners
61 | _pad = '_'
62 | _punctuation = ',.!?…'
63 | _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
64 | '''
65 |
66 | '''# chinese_dialect_cleaners
67 | _pad = '_'
68 | _punctuation = ',.!?~…─'
69 | _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
70 | '''
71 |
72 | # Export all symbols:
73 | symbols = [_pad] + list(_punctuation) + list(_letters)
74 |
75 | # Special symbol ids
76 | SPACE_ID = symbols.index(" ")
77 |
78 | num_ja_tones = 1
79 | num_kr_tones = 1
80 | num_zh_tones = 6
81 | num_en_tones = 4
82 |
83 | language_tone_start_map = {
84 | "ZH": 0,
85 | "JP": num_zh_tones,
86 | "EN": num_zh_tones + num_ja_tones,
87 | 'KR': num_zh_tones + num_ja_tones + num_en_tones,
88 | }
--------------------------------------------------------------------------------
/OpenVoice/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | from . import cleaners
3 | from .symbols import symbols, language_tone_start_map
4 |
5 | # Mappings from symbol to numeric ID and vice versa:
6 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
7 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
8 |
9 |
10 | def text_to_sequence(text, symbols, cleaner_names):
11 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
12 | Args:
13 | text: string to convert to a sequence
14 | cleaner_names: names of the cleaner functions to run the text through
15 | Returns:
16 | List of integers corresponding to the symbols in the text
17 | '''
18 | sequence = []
19 | symbol_to_id = {s: i for i, s in enumerate(symbols)}
20 | clean_text = _clean_text(text, cleaner_names)
21 | print(clean_text)
22 | print(f" length:{len(clean_text)}")
23 | for symbol in clean_text:
24 | if symbol not in symbol_to_id.keys():
25 | continue
26 | symbol_id = symbol_to_id[symbol]
27 | sequence += [symbol_id]
28 | print(f" length:{len(sequence)}")
29 | return sequence
30 |
31 |
32 | def cleaned_text_to_sequence(cleaned_text, symbols):
33 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
34 | Args:
35 | text: string to convert to a sequence
36 | Returns:
37 | List of integers corresponding to the symbols in the text
38 | '''
39 | symbol_to_id = {s: i for i, s in enumerate(symbols)}
40 | sequence = [symbol_to_id[symbol] for symbol in cleaned_text if symbol in symbol_to_id.keys()]
41 | return sequence
42 |
43 | def cleaned_text_to_sequence_vits2(cleaned_text, tones, language, symbols, languages):
44 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
45 | Args:
46 | text: string to convert to a sequence
47 | Returns:
48 | List of integers corresponding to the symbols in the text
49 | """
50 | symbol_to_id = {s: i for i, s in enumerate(symbols)}
51 | language_id_map = {s: i for i, s in enumerate(languages)}
52 | phones = [symbol_to_id[symbol] for symbol in cleaned_text]
53 | tone_start = language_tone_start_map[language]
54 | tones = [i + tone_start for i in tones]
55 | lang_id = language_id_map[language]
56 | lang_ids = [lang_id for i in phones]
57 | return phones, tones, lang_ids
58 |
59 |
60 | def sequence_to_text(sequence):
61 | '''Converts a sequence of IDs back to a string'''
62 | result = ''
63 | for symbol_id in sequence:
64 | s = _id_to_symbol[symbol_id]
65 | result += s
66 | return result
67 |
68 |
69 | def _clean_text(text, cleaner_names):
70 | for name in cleaner_names:
71 | cleaner = getattr(cleaners, name)
72 | if not cleaner:
73 | raise Exception('Unknown cleaner: %s' % name)
74 | text = cleaner(text)
75 | return text
76 |
--------------------------------------------------------------------------------
/ChatTTS/infer/api.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn.functional as F
4 | from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
5 | from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
6 |
7 | def infer_code(
8 | models,
9 | text,
10 | spk_emb = None,
11 | top_P = 0.7,
12 | top_K = 20,
13 | temperature = 0.3,
14 | repetition_penalty = 1.05,
15 | max_new_token = 2048,
16 | **kwargs
17 | ):
18 |
19 | device = next(models['gpt'].parameters()).device
20 |
21 | if not isinstance(text, list):
22 | text = [text]
23 |
24 | if not isinstance(temperature, list):
25 | temperature = [temperature] * models['gpt'].num_vq
26 |
27 | if spk_emb is not None:
28 | text = [f'[Stts][spk_emb]{i}[uv_break][Ptts]' for i in text]
29 | else:
30 | text = [f'[Stts][empty_spk]{i}[uv_break][Ptts]' for i in text]
31 |
32 | text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
33 | input_ids = text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq)
34 | text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
35 |
36 | inputs = {
37 | 'input_ids': input_ids,
38 | 'text_mask': text_mask,
39 | 'attention_mask': text_token['attention_mask'],
40 | }
41 |
42 | emb = models['gpt'].get_emb(**inputs)
43 | if spk_emb is not None:
44 | emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \
45 | F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12)
46 |
47 | num_code = models['gpt'].emb_code[0].num_embeddings - 1
48 |
49 | LogitsWarpers = []
50 | if top_P is not None:
51 | LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
52 | if top_K is not None:
53 | LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
54 |
55 | LogitsProcessors = []
56 | if repetition_penalty is not None and repetition_penalty != 1:
57 | LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\
58 | repetition_penalty, num_code, 16))
59 |
60 | result = models['gpt'].generate(
61 | emb, inputs['input_ids'],
62 | temperature = torch.tensor(temperature, device=device),
63 | attention_mask = inputs['attention_mask'],
64 | LogitsWarpers = LogitsWarpers,
65 | LogitsProcessors = LogitsProcessors,
66 | eos_token = num_code,
67 | max_new_token = max_new_token,
68 | infer_text = False,
69 | **kwargs
70 | )
71 |
72 | return result
73 |
74 |
75 | def refine_text(
76 | models,
77 | text,
78 | top_P = 0.7,
79 | top_K = 20,
80 | temperature = 0.7,
81 | repetition_penalty = 1.0,
82 | max_new_token = 384,
83 | prompt = '',
84 | **kwargs
85 | ):
86 |
87 | device = next(models['gpt'].parameters()).device
88 |
89 | if not isinstance(text, list):
90 | text = [text]
91 |
92 | assert len(text), 'text should not be empty'
93 |
94 | text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
95 | text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
96 | text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
97 |
98 | inputs = {
99 | 'input_ids': text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq),
100 | 'text_mask': text_mask,
101 | 'attention_mask': text_token['attention_mask'],
102 | }
103 |
104 | LogitsWarpers = []
105 | if top_P is not None:
106 | LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
107 | if top_K is not None:
108 | LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
109 |
110 | LogitsProcessors = []
111 | if repetition_penalty is not None and repetition_penalty != 1:
112 | LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))
113 |
114 | result = models['gpt'].generate(
115 | models['gpt'].get_emb(**inputs), inputs['input_ids'],
116 | temperature = torch.tensor([temperature,], device=device),
117 | attention_mask = inputs['attention_mask'],
118 | LogitsWarpers = LogitsWarpers,
119 | LogitsProcessors = LogitsProcessors,
120 | eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
121 | max_new_token = max_new_token,
122 | infer_text = True,
123 | **kwargs
124 | )
125 | return result
--------------------------------------------------------------------------------
/OpenVoice/utils/se_extractor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | from glob import glob
5 | import numpy as np
6 | from pydub import AudioSegment
7 | from faster_whisper import WhisperModel
8 | from whisper_timestamped.transcribe import get_audio_tensor, get_vad_segments
9 |
10 | model_size = "medium"
11 | # Run on GPU with FP16
12 | model = None
13 | def split_audio_whisper(audio_path, target_dir='processed'):
14 | global model
15 | if model is None:
16 | model = WhisperModel(model_size, device="cuda", compute_type="float16")
17 | audio = AudioSegment.from_file(audio_path)
18 | max_len = len(audio)
19 |
20 | audio_name = os.path.basename(audio_path).rsplit('.', 1)[0]
21 | target_folder = os.path.join(target_dir, audio_name)
22 |
23 | segments, info = model.transcribe(audio_path, beam_size=5, word_timestamps=True)
24 | segments = list(segments)
25 |
26 | # create directory
27 | os.makedirs(target_folder, exist_ok=True)
28 | wavs_folder = os.path.join(target_folder, 'wavs')
29 | os.makedirs(wavs_folder, exist_ok=True)
30 |
31 | # segments
32 | s_ind = 0
33 | start_time = None
34 |
35 | for k, w in enumerate(segments):
36 | # process with the time
37 | if k == 0:
38 | start_time = max(0, w.start)
39 |
40 | end_time = w.end
41 |
42 | # calculate confidence
43 | if len(w.words) > 0:
44 | confidence = sum([s.probability for s in w.words]) / len(w.words)
45 | else:
46 | confidence = 0.
47 | # clean text
48 | text = w.text.replace('...', '')
49 |
50 | # left 0.08s for each audios
51 | audio_seg = audio[int( start_time * 1000) : min(max_len, int(end_time * 1000) + 80)]
52 |
53 | # segment file name
54 | fname = f"{audio_name}_seg{s_ind}.wav"
55 |
56 | # filter out the segment shorter than 1.5s and longer than 20s
57 | save = audio_seg.duration_seconds > 1.5 and \
58 | audio_seg.duration_seconds < 20. and \
59 | len(text) >= 2 and len(text) < 200
60 |
61 | if save:
62 | output_file = os.path.join(wavs_folder, fname)
63 | audio_seg.export(output_file, format='wav')
64 |
65 | if k < len(segments) - 1:
66 | start_time = max(0, segments[k+1].start - 0.08)
67 |
68 | s_ind = s_ind + 1
69 | return wavs_folder
70 |
71 |
72 | def split_audio_vad(audio_path, target_dir, split_seconds=10.0):
73 | SAMPLE_RATE = 16000
74 | audio_vad = get_audio_tensor(audio_path)
75 | segments = get_vad_segments(
76 | audio_vad,
77 | output_sample=True,
78 | min_speech_duration=0.1,
79 | min_silence_duration=1,
80 | method="silero",
81 | )
82 | segments = [(seg["start"], seg["end"]) for seg in segments]
83 | segments = [(float(s) / SAMPLE_RATE, float(e) / SAMPLE_RATE) for s,e in segments]
84 | print(segments)
85 | audio_active = AudioSegment.silent(duration=0)
86 | audio = AudioSegment.from_file(audio_path)
87 |
88 | for start_time, end_time in segments:
89 | audio_active += audio[int( start_time * 1000) : int(end_time * 1000)]
90 |
91 | audio_dur = audio_active.duration_seconds
92 | print(f'after vad: dur = {audio_dur}')
93 | audio_name = os.path.basename(audio_path).rsplit('.', 1)[0]
94 | target_folder = os.path.join(target_dir, audio_name)
95 | wavs_folder = os.path.join(target_folder, 'wavs')
96 | os.makedirs(wavs_folder, exist_ok=True)
97 | start_time = 0.
98 | count = 0
99 | num_splits = int(np.round(audio_dur / split_seconds))
100 | assert num_splits > 0, 'input audio is too short'
101 | interval = audio_dur / num_splits
102 |
103 | for i in range(num_splits):
104 | end_time = min(start_time + interval, audio_dur)
105 | if i == num_splits - 1:
106 | end_time = audio_dur
107 | output_file = f"{wavs_folder}/{audio_name}_seg{count}.wav"
108 | audio_seg = audio_active[int(start_time * 1000): int(end_time * 1000)]
109 | audio_seg.export(output_file, format='wav')
110 | start_time = end_time
111 | count += 1
112 | return wavs_folder
113 |
114 |
115 |
116 |
117 |
118 | def get_se(audio_path, vc_model, target_dir='processed', vad=True):
119 | device = vc_model.device
120 |
121 | audio_name = os.path.basename(audio_path).rsplit('.', 1)[0]
122 | se_path = os.path.join(target_dir, audio_name, 'se.pth')
123 |
124 | if os.path.isfile(se_path):
125 | se = torch.load(se_path).to(device)
126 | return se, audio_name
127 | if os.path.isdir(audio_path):
128 | wavs_folder = audio_path
129 | elif vad:
130 | wavs_folder = split_audio_vad(audio_path, target_dir)
131 | else:
132 | wavs_folder = split_audio_whisper(audio_path, target_dir)
133 |
134 | audio_segs = glob(f'{wavs_folder}/*.wav')
135 | if len(audio_segs) == 0:
136 | raise NotImplementedError('No audio segments found!')
137 |
138 | return vc_model.extract_se(audio_segs, se_save_path=se_path), audio_name
139 |
140 |
--------------------------------------------------------------------------------
/ChatTTS/model/dvae.py:
--------------------------------------------------------------------------------
1 | import math
2 | from einops import rearrange
3 | from vector_quantize_pytorch import GroupedResidualFSQ
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | class ConvNeXtBlock(nn.Module):
10 | def __init__(
11 | self,
12 | dim: int,
13 | intermediate_dim: int,
14 | kernel, dilation,
15 | layer_scale_init_value: float = 1e-6,
16 | ):
17 | # ConvNeXt Block copied from Vocos.
18 | super().__init__()
19 | self.dwconv = nn.Conv1d(dim, dim,
20 | kernel_size=kernel, padding=dilation*(kernel//2),
21 | dilation=dilation, groups=dim
22 | ) # depthwise conv
23 |
24 | self.norm = nn.LayerNorm(dim, eps=1e-6)
25 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
26 | self.act = nn.GELU()
27 | self.pwconv2 = nn.Linear(intermediate_dim, dim)
28 | self.gamma = (
29 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
30 | if layer_scale_init_value > 0
31 | else None
32 | )
33 |
34 | def forward(self, x: torch.Tensor, cond = None) -> torch.Tensor:
35 | residual = x
36 | x = self.dwconv(x)
37 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
38 | x = self.norm(x)
39 | x = self.pwconv1(x)
40 | x = self.act(x)
41 | x = self.pwconv2(x)
42 | if self.gamma is not None:
43 | x = self.gamma * x
44 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
45 |
46 | x = residual + x
47 | return x
48 |
49 |
50 |
51 | class GFSQ(nn.Module):
52 |
53 | def __init__(self,
54 | dim, levels, G, R, eps=1e-5, transpose = True
55 | ):
56 | super(GFSQ, self).__init__()
57 | self.quantizer = GroupedResidualFSQ(
58 | dim=dim,
59 | levels=levels,
60 | num_quantizers=R,
61 | groups=G,
62 | )
63 | self.n_ind = math.prod(levels)
64 | self.eps = eps
65 | self.transpose = transpose
66 | self.G = G
67 | self.R = R
68 |
69 | def _embed(self, x):
70 | if self.transpose:
71 | x = x.transpose(1,2)
72 | x = rearrange(
73 | x, "b t (g r) -> g b t r", g = self.G, r = self.R,
74 | )
75 | feat = self.quantizer.get_output_from_indices(x)
76 | return feat.transpose(1,2) if self.transpose else feat
77 |
78 | def forward(self, x,):
79 | if self.transpose:
80 | x = x.transpose(1,2)
81 | feat, ind = self.quantizer(x)
82 | ind = rearrange(
83 | ind, "g b t r ->b t (g r)",
84 | )
85 | embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
86 | e_mean = torch.mean(embed_onehot, dim=[0,1])
87 | e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
88 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
89 |
90 | return (
91 | torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
92 | feat.transpose(1,2) if self.transpose else feat,
93 | perplexity,
94 | None,
95 | ind.transpose(1,2) if self.transpose else ind,
96 | )
97 |
98 | class DVAEDecoder(nn.Module):
99 | def __init__(self, idim, odim,
100 | n_layer = 12, bn_dim = 64, hidden = 256,
101 | kernel = 7, dilation = 2, up = False
102 | ):
103 | super().__init__()
104 | self.up = up
105 | self.conv_in = nn.Sequential(
106 | nn.Conv1d(idim, bn_dim, 3, 1, 1), nn.GELU(),
107 | nn.Conv1d(bn_dim, hidden, 3, 1, 1)
108 | )
109 | self.decoder_block = nn.ModuleList([
110 | ConvNeXtBlock(hidden, hidden* 4, kernel, dilation,)
111 | for _ in range(n_layer)])
112 | self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
113 |
114 | def forward(self, input, conditioning=None):
115 | # B, T, C
116 | x = input.transpose(1, 2)
117 | x = self.conv_in(x)
118 | for f in self.decoder_block:
119 | x = f(x, conditioning)
120 |
121 | x = self.conv_out(x)
122 | return x.transpose(1, 2)
123 |
124 |
125 | class DVAE(nn.Module):
126 | def __init__(
127 | self, decoder_config, vq_config, dim=512
128 | ):
129 | super().__init__()
130 | self.register_buffer('coef', torch.randn(1, 100, 1))
131 |
132 | self.decoder = DVAEDecoder(**decoder_config)
133 | self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
134 | if vq_config is not None:
135 | self.vq_layer = GFSQ(**vq_config)
136 | else:
137 | self.vq_layer = None
138 |
139 | def forward(self, inp):
140 |
141 | if self.vq_layer is not None:
142 | vq_feats = self.vq_layer._embed(inp)
143 | else:
144 | vq_feats = inp.detach().clone()
145 |
146 | temp = torch.chunk(vq_feats, 2, dim=1) # flatten trick :)
147 | temp = torch.stack(temp, -1)
148 | vq_feats = temp.reshape(*temp.shape[:2], -1)
149 |
150 | vq_feats = vq_feats.transpose(1, 2)
151 | dec_out = self.decoder(input=vq_feats)
152 | dec_out = self.out_conv(dec_out.transpose(1, 2))
153 | mel = dec_out * self.coef
154 |
155 | return mel
156 |
--------------------------------------------------------------------------------
/OpenVoice/utils/commons.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.nn import functional as F
4 |
5 |
6 | def init_weights(m, mean=0.0, std=0.01):
7 | classname = m.__class__.__name__
8 | if classname.find("Conv") != -1:
9 | m.weight.data.normal_(mean, std)
10 |
11 |
12 | def get_padding(kernel_size, dilation=1):
13 | return int((kernel_size * dilation - dilation) / 2)
14 |
15 |
16 | def convert_pad_shape(pad_shape):
17 | layer = pad_shape[::-1]
18 | pad_shape = [item for sublist in layer for item in sublist]
19 | return pad_shape
20 |
21 |
22 | def intersperse(lst, item):
23 | result = [item] * (len(lst) * 2 + 1)
24 | result[1::2] = lst
25 | return result
26 |
27 |
28 | def kl_divergence(m_p, logs_p, m_q, logs_q):
29 | """KL(P||Q)"""
30 | kl = (logs_q - logs_p) - 0.5
31 | kl += (
32 | 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
33 | )
34 | return kl
35 |
36 |
37 | def rand_gumbel(shape):
38 | """Sample from the Gumbel distribution, protect from overflows."""
39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40 | return -torch.log(-torch.log(uniform_samples))
41 |
42 |
43 | def rand_gumbel_like(x):
44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45 | return g
46 |
47 |
48 | def slice_segments(x, ids_str, segment_size=4):
49 | ret = torch.zeros_like(x[:, :, :segment_size])
50 | for i in range(x.size(0)):
51 | idx_str = ids_str[i]
52 | idx_end = idx_str + segment_size
53 | ret[i] = x[i, :, idx_str:idx_end]
54 | return ret
55 |
56 |
57 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
58 | b, d, t = x.size()
59 | if x_lengths is None:
60 | x_lengths = t
61 | ids_str_max = x_lengths - segment_size + 1
62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63 | ret = slice_segments(x, ids_str, segment_size)
64 | return ret, ids_str
65 |
66 |
67 | def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
68 | position = torch.arange(length, dtype=torch.float)
69 | num_timescales = channels // 2
70 | log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
71 | num_timescales - 1
72 | )
73 | inv_timescales = min_timescale * torch.exp(
74 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
75 | )
76 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
77 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
78 | signal = F.pad(signal, [0, 0, 0, channels % 2])
79 | signal = signal.view(1, channels, length)
80 | return signal
81 |
82 |
83 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
84 | b, channels, length = x.size()
85 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
86 | return x + signal.to(dtype=x.dtype, device=x.device)
87 |
88 |
89 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
90 | b, channels, length = x.size()
91 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
92 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
93 |
94 |
95 | def subsequent_mask(length):
96 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
97 | return mask
98 |
99 |
100 | @torch.jit.script
101 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
102 | n_channels_int = n_channels[0]
103 | in_act = input_a + input_b
104 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
105 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
106 | acts = t_act * s_act
107 | return acts
108 |
109 |
110 | def convert_pad_shape(pad_shape):
111 | layer = pad_shape[::-1]
112 | pad_shape = [item for sublist in layer for item in sublist]
113 | return pad_shape
114 |
115 |
116 | def shift_1d(x):
117 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
118 | return x
119 |
120 |
121 | def sequence_mask(length, max_length=None):
122 | if max_length is None:
123 | max_length = length.max()
124 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
125 | return x.unsqueeze(0) < length.unsqueeze(1)
126 |
127 |
128 | def generate_path(duration, mask):
129 | """
130 | duration: [b, 1, t_x]
131 | mask: [b, 1, t_y, t_x]
132 | """
133 |
134 | b, _, t_y, t_x = mask.shape
135 | cum_duration = torch.cumsum(duration, -1)
136 |
137 | cum_duration_flat = cum_duration.view(b * t_x)
138 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
139 | path = path.view(b, t_x, t_y)
140 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
141 | path = path.unsqueeze(1).transpose(2, 3) * mask
142 | return path
143 |
144 |
145 | def clip_grad_value_(parameters, clip_value, norm_type=2):
146 | if isinstance(parameters, torch.Tensor):
147 | parameters = [parameters]
148 | parameters = list(filter(lambda p: p.grad is not None, parameters))
149 | norm_type = float(norm_type)
150 | if clip_value is not None:
151 | clip_value = float(clip_value)
152 |
153 | total_norm = 0
154 | for p in parameters:
155 | param_norm = p.grad.data.norm(norm_type)
156 | total_norm += param_norm.item() ** norm_type
157 | if clip_value is not None:
158 | p.grad.data.clamp_(min=-clip_value, max=clip_value)
159 | total_norm = total_norm ** (1.0 / norm_type)
160 | return total_norm
161 |
--------------------------------------------------------------------------------
/OpenVoice/text/english.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | '''
14 |
15 |
16 | # Regular expression matching whitespace:
17 |
18 |
19 | import re
20 | import inflect
21 | from unidecode import unidecode
22 | import eng_to_ipa as ipa
23 | _inflect = inflect.engine()
24 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
25 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
26 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
27 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
28 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
29 | _number_re = re.compile(r'[0-9]+')
30 |
31 | # List of (regular expression, replacement) pairs for abbreviations:
32 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
33 | ('mrs', 'misess'),
34 | ('mr', 'mister'),
35 | ('dr', 'doctor'),
36 | ('st', 'saint'),
37 | ('co', 'company'),
38 | ('jr', 'junior'),
39 | ('maj', 'major'),
40 | ('gen', 'general'),
41 | ('drs', 'doctors'),
42 | ('rev', 'reverend'),
43 | ('lt', 'lieutenant'),
44 | ('hon', 'honorable'),
45 | ('sgt', 'sergeant'),
46 | ('capt', 'captain'),
47 | ('esq', 'esquire'),
48 | ('ltd', 'limited'),
49 | ('col', 'colonel'),
50 | ('ft', 'fort'),
51 | ]]
52 |
53 |
54 | # List of (ipa, lazy ipa) pairs:
55 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
56 | ('r', 'ɹ'),
57 | ('æ', 'e'),
58 | ('ɑ', 'a'),
59 | ('ɔ', 'o'),
60 | ('ð', 'z'),
61 | ('θ', 's'),
62 | ('ɛ', 'e'),
63 | ('ɪ', 'i'),
64 | ('ʊ', 'u'),
65 | ('ʒ', 'ʥ'),
66 | ('ʤ', 'ʥ'),
67 | ('ˈ', '↓'),
68 | ]]
69 |
70 | # List of (ipa, lazy ipa2) pairs:
71 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
72 | ('r', 'ɹ'),
73 | ('ð', 'z'),
74 | ('θ', 's'),
75 | ('ʒ', 'ʑ'),
76 | ('ʤ', 'dʑ'),
77 | ('ˈ', '↓'),
78 | ]]
79 |
80 | # List of (ipa, ipa2) pairs
81 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
82 | ('r', 'ɹ'),
83 | ('ʤ', 'dʒ'),
84 | ('ʧ', 'tʃ')
85 | ]]
86 |
87 |
88 | def expand_abbreviations(text):
89 | for regex, replacement in _abbreviations:
90 | text = re.sub(regex, replacement, text)
91 | return text
92 |
93 |
94 | def collapse_whitespace(text):
95 | return re.sub(r'\s+', ' ', text)
96 |
97 |
98 | def _remove_commas(m):
99 | return m.group(1).replace(',', '')
100 |
101 |
102 | def _expand_decimal_point(m):
103 | return m.group(1).replace('.', ' point ')
104 |
105 |
106 | def _expand_dollars(m):
107 | match = m.group(1)
108 | parts = match.split('.')
109 | if len(parts) > 2:
110 | return match + ' dollars' # Unexpected format
111 | dollars = int(parts[0]) if parts[0] else 0
112 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
113 | if dollars and cents:
114 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
115 | cent_unit = 'cent' if cents == 1 else 'cents'
116 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
117 | elif dollars:
118 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
119 | return '%s %s' % (dollars, dollar_unit)
120 | elif cents:
121 | cent_unit = 'cent' if cents == 1 else 'cents'
122 | return '%s %s' % (cents, cent_unit)
123 | else:
124 | return 'zero dollars'
125 |
126 |
127 | def _expand_ordinal(m):
128 | return _inflect.number_to_words(m.group(0))
129 |
130 |
131 | def _expand_number(m):
132 | num = int(m.group(0))
133 | if num > 1000 and num < 3000:
134 | if num == 2000:
135 | return 'two thousand'
136 | elif num > 2000 and num < 2010:
137 | return 'two thousand ' + _inflect.number_to_words(num % 100)
138 | elif num % 100 == 0:
139 | return _inflect.number_to_words(num // 100) + ' hundred'
140 | else:
141 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
142 | else:
143 | return _inflect.number_to_words(num, andword='')
144 |
145 |
146 | def normalize_numbers(text):
147 | text = re.sub(_comma_number_re, _remove_commas, text)
148 | text = re.sub(_pounds_re, r'\1 pounds', text)
149 | text = re.sub(_dollars_re, _expand_dollars, text)
150 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
151 | text = re.sub(_ordinal_re, _expand_ordinal, text)
152 | text = re.sub(_number_re, _expand_number, text)
153 | return text
154 |
155 |
156 | def mark_dark_l(text):
157 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
158 |
159 |
160 | def english_to_ipa(text):
161 | text = unidecode(text).lower()
162 | text = expand_abbreviations(text)
163 | text = normalize_numbers(text)
164 | phonemes = ipa.convert(text)
165 | phonemes = collapse_whitespace(phonemes)
166 | return phonemes
167 |
168 |
169 | def english_to_lazy_ipa(text):
170 | text = english_to_ipa(text)
171 | for regex, replacement in _lazy_ipa:
172 | text = re.sub(regex, replacement, text)
173 | return text
174 |
175 |
176 | def english_to_ipa2(text):
177 | text = english_to_ipa(text)
178 | text = mark_dark_l(text)
179 | for regex, replacement in _ipa_to_ipa2:
180 | text = re.sub(regex, replacement, text)
181 | return text.replace('...', '…')
182 |
183 |
184 | def english_to_lazy_ipa2(text):
185 | text = english_to_ipa(text)
186 | for regex, replacement in _lazy_ipa2:
187 | text = re.sub(regex, replacement, text)
188 | return text
189 |
--------------------------------------------------------------------------------
/OpenVoice/utils/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import json
3 | import numpy as np
4 |
5 |
6 | def get_hparams_from_file(config_path):
7 | with open(config_path, "r", encoding="utf-8") as f:
8 | data = f.read()
9 | config = json.loads(data)
10 |
11 | hparams = HParams(**config)
12 | return hparams
13 |
14 | class HParams:
15 | def __init__(self, **kwargs):
16 | for k, v in kwargs.items():
17 | if type(v) == dict:
18 | v = HParams(**v)
19 | self[k] = v
20 |
21 | def keys(self):
22 | return self.__dict__.keys()
23 |
24 | def items(self):
25 | return self.__dict__.items()
26 |
27 | def values(self):
28 | return self.__dict__.values()
29 |
30 | def __len__(self):
31 | return len(self.__dict__)
32 |
33 | def __getitem__(self, key):
34 | return getattr(self, key)
35 |
36 | def __setitem__(self, key, value):
37 | return setattr(self, key, value)
38 |
39 | def __contains__(self, key):
40 | return key in self.__dict__
41 |
42 | def __repr__(self):
43 | return self.__dict__.__repr__()
44 |
45 |
46 | def string_to_bits(string, pad_len=8):
47 | # Convert each character to its ASCII value
48 | ascii_values = [ord(char) for char in string]
49 |
50 | # Convert ASCII values to binary representation
51 | binary_values = [bin(value)[2:].zfill(8) for value in ascii_values]
52 |
53 | # Convert binary strings to integer arrays
54 | bit_arrays = [[int(bit) for bit in binary] for binary in binary_values]
55 |
56 | # Convert list of arrays to NumPy array
57 | numpy_array = np.array(bit_arrays)
58 | numpy_array_full = np.zeros((pad_len, 8), dtype=numpy_array.dtype)
59 | numpy_array_full[:, 2] = 1
60 | max_len = min(pad_len, len(numpy_array))
61 | numpy_array_full[:max_len] = numpy_array[:max_len]
62 | return numpy_array_full
63 |
64 |
65 | def bits_to_string(bits_array):
66 | # Convert each row of the array to a binary string
67 | binary_values = [''.join(str(bit) for bit in row) for row in bits_array]
68 |
69 | # Convert binary strings to ASCII values
70 | ascii_values = [int(binary, 2) for binary in binary_values]
71 |
72 | # Convert ASCII values to characters
73 | output_string = ''.join(chr(value) for value in ascii_values)
74 |
75 | return output_string
76 |
77 |
78 | def split_sentence(text, min_len=10, language_str='[EN]'):
79 | if language_str in ['EN']:
80 | sentences = split_sentences_latin(text, min_len=min_len)
81 | else:
82 | sentences = split_sentences_zh(text, min_len=min_len)
83 | return sentences
84 |
85 | def split_sentences_latin(text, min_len=10):
86 | """Split Long sentences into list of short ones
87 |
88 | Args:
89 | str: Input sentences.
90 |
91 | Returns:
92 | List[str]: list of output sentences.
93 | """
94 | # deal with dirty sentences
95 | text = re.sub('[。!?;]', '.', text)
96 | text = re.sub('[,]', ',', text)
97 | text = re.sub('[“”]', '"', text)
98 | text = re.sub('[‘’]', "'", text)
99 | text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
100 | text = re.sub('[\n\t ]+', ' ', text)
101 | text = re.sub('([,.!?;])', r'\1 $#!', text)
102 | # split
103 | sentences = [s.strip() for s in text.split('$#!')]
104 | if len(sentences[-1]) == 0: del sentences[-1]
105 |
106 | new_sentences = []
107 | new_sent = []
108 | count_len = 0
109 | for ind, sent in enumerate(sentences):
110 | # print(sent)
111 | new_sent.append(sent)
112 | count_len += len(sent.split(" "))
113 | if count_len > min_len or ind == len(sentences) - 1:
114 | count_len = 0
115 | new_sentences.append(' '.join(new_sent))
116 | new_sent = []
117 | return merge_short_sentences_latin(new_sentences)
118 |
119 |
120 | def merge_short_sentences_latin(sens):
121 | """Avoid short sentences by merging them with the following sentence.
122 |
123 | Args:
124 | List[str]: list of input sentences.
125 |
126 | Returns:
127 | List[str]: list of output sentences.
128 | """
129 | sens_out = []
130 | for s in sens:
131 |         # If the previous sentence is too short, merge it with
132 |         # the current sentence.
133 | if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
134 | sens_out[-1] = sens_out[-1] + " " + s
135 | else:
136 | sens_out.append(s)
137 | try:
138 | if len(sens_out[-1].split(" ")) <= 2:
139 | sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
140 | sens_out.pop(-1)
141 | except:
142 | pass
143 | return sens_out
144 |
145 | def split_sentences_zh(text, min_len=10):
146 | text = re.sub('[。!?;]', '.', text)
147 | text = re.sub('[,]', ',', text)
148 |     # Replace newlines, tabs and spaces in the text with a single space
149 |     text = re.sub('[\n\t ]+', ' ', text)
150 |     # Insert a split marker after each punctuation mark
151 |     text = re.sub('([,.!?;])', r'\1 $#!', text)
152 |     # Split into sentences and strip surrounding whitespace
153 | # sentences = [s.strip() for s in re.split('(。|!|?|;)', text)]
154 | sentences = [s.strip() for s in text.split('$#!')]
155 | if len(sentences[-1]) == 0: del sentences[-1]
156 |
157 | new_sentences = []
158 | new_sent = []
159 | count_len = 0
160 | for ind, sent in enumerate(sentences):
161 | new_sent.append(sent)
162 | count_len += len(sent)
163 | if count_len > min_len or ind == len(sentences) - 1:
164 | count_len = 0
165 | new_sentences.append(' '.join(new_sent))
166 | new_sent = []
167 | return merge_short_sentences_zh(new_sentences)
168 |
169 |
170 | def merge_short_sentences_zh(sens):
171 | # return sens
172 | """Avoid short sentences by merging them with the following sentence.
173 |
174 | Args:
175 | List[str]: list of input sentences.
176 |
177 | Returns:
178 | List[str]: list of output sentences.
179 | """
180 | sens_out = []
181 | for s in sens:
182 |         # If the previous sentence is too short, merge it with
183 |         # the current sentence.
184 | if len(sens_out) > 0 and len(sens_out[-1]) <= 2:
185 | sens_out[-1] = sens_out[-1] + " " + s
186 | else:
187 | sens_out.append(s)
188 | try:
189 | if len(sens_out[-1]) <= 2:
190 | sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
191 | sens_out.pop(-1)
192 | except:
193 | pass
194 | return sens_out
--------------------------------------------------------------------------------
/ChatTTS/core.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import logging
4 | from omegaconf import OmegaConf
5 |
6 | import torch
7 | from vocos import Vocos
8 | from .model.dvae import DVAE
9 | from .model.gpt import GPT_warpper
10 | from .utils.gpu_utils import select_device
11 | from .utils.io_utils import get_latest_modified_file
12 | from .infer.api import refine_text, infer_code
13 |
14 | from huggingface_hub import snapshot_download
15 |
16 | logging.basicConfig(level = logging.INFO)
17 |
18 |
19 | class Chat:
20 | def __init__(self, ):
21 | self.pretrain_models = {}
22 | self.logger = logging.getLogger(__name__)
23 |
24 | def check_model(self, level = logging.INFO, use_decoder = False):
25 | not_finish = False
26 | check_list = ['vocos', 'gpt', 'tokenizer']
27 |
28 | if use_decoder:
29 | check_list.append('decoder')
30 | else:
31 | check_list.append('dvae')
32 |
33 | for module in check_list:
34 | if module not in self.pretrain_models:
35 | self.logger.log(logging.WARNING, f'{module} not initialized.')
36 | not_finish = True
37 |
38 | if not not_finish:
39 | self.logger.log(level, f'All initialized.')
40 |
41 | return not not_finish
42 |
43 | def load_models(self, source='huggingface', force_redownload=False, local_path=''):
44 | if source == 'huggingface':
45 | hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
46 | try:
47 | download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
48 | except:
49 | download_path = None
50 | if download_path is None or force_redownload:
51 | self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
52 | download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
53 | else:
54 | self.logger.log(logging.INFO, f'Load from cache: {download_path}')
55 | self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})
56 | elif source == 'local':
57 | self.logger.log(logging.INFO, f'Load from local: {local_path}')
58 | self._load(**{k: os.path.join(local_path, v) for k, v in OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()})
59 |
60 | def _load(
61 | self,
62 | vocos_config_path: str = None,
63 | vocos_ckpt_path: str = None,
64 | dvae_config_path: str = None,
65 | dvae_ckpt_path: str = None,
66 | gpt_config_path: str = None,
67 | gpt_ckpt_path: str = None,
68 | decoder_config_path: str = None,
69 | decoder_ckpt_path: str = None,
70 | tokenizer_path: str = None,
71 | device: str = None
72 | ):
73 | if not device:
74 | device = select_device(4096)
75 | self.logger.log(logging.INFO, f'use {device}')
76 |
77 | if vocos_config_path:
78 | vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
79 | assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
80 | vocos.load_state_dict(torch.load(vocos_ckpt_path))
81 | self.pretrain_models['vocos'] = vocos
82 | self.logger.log(logging.INFO, 'vocos loaded.')
83 |
84 | if dvae_config_path:
85 | cfg = OmegaConf.load(dvae_config_path)
86 | dvae = DVAE(**cfg).to(device).eval()
87 | assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
88 | dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
89 | self.pretrain_models['dvae'] = dvae
90 | self.logger.log(logging.INFO, 'dvae loaded.')
91 |
92 | if gpt_config_path:
93 | cfg = OmegaConf.load(gpt_config_path)
94 | gpt = GPT_warpper(**cfg).to(device).eval()
95 | assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
96 | gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
97 | self.pretrain_models['gpt'] = gpt
98 | self.logger.log(logging.INFO, 'gpt loaded.')
99 |
100 | if decoder_config_path:
101 | cfg = OmegaConf.load(decoder_config_path)
102 | decoder = DVAE(**cfg).to(device).eval()
103 | assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
104 | decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
105 | self.pretrain_models['decoder'] = decoder
106 | self.logger.log(logging.INFO, 'decoder loaded.')
107 |
108 | if tokenizer_path:
109 | tokenizer = torch.load(tokenizer_path, map_location='cpu')
110 | tokenizer.padding_side = 'left'
111 | self.pretrain_models['tokenizer'] = tokenizer
112 | self.logger.log(logging.INFO, 'tokenizer loaded.')
113 |
114 | self.check_model()
115 |
116 | def infer(
117 | self,
118 | text,
119 | skip_refine_text=False,
120 | refine_text_only=False,
121 | params_refine_text={},
122 | params_infer_code={},
123 | use_decoder=False
124 | ):
125 |
126 | assert self.check_model(use_decoder=use_decoder)
127 |
128 | if not skip_refine_text:
129 | text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
130 | text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
131 | text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
132 | if refine_text_only:
133 | return text
134 |
135 | text = [params_infer_code.get('prompt', '') + i for i in text]
136 | params_infer_code.pop('prompt', '')
137 | result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
138 |
139 | if use_decoder:
140 | mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
141 | else:
142 | mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
143 |
144 | wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
145 |
146 | return wav
147 |
148 |
149 |
150 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | # import spaces
2 | import os
3 | import random
4 | import argparse
5 |
6 | import torch
7 | import gradio as gr
8 | import numpy as np
9 |
10 | import ChatTTS
11 |
12 | from OpenVoice.utils import se_extractor
13 | from OpenVoice.api import ToneColorConverter
14 | import soundfile
15 |
16 | print("loading ChatTTS model...")
17 | chat = ChatTTS.Chat()
18 | chat.load_models()
19 |
20 |
21 | def generate_seed():
22 | new_seed = random.randint(1, 100000000)
23 | return {
24 | "__type__": "update",
25 | "value": new_seed
26 | }
27 |
28 | # @spaces.GPU
29 | def chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None):
30 |
31 | torch.manual_seed(audio_seed_input)
32 | rand_spk = torch.randn(768)
33 | params_infer_code = {
34 | 'spk_emb': rand_spk,
35 | 'temperature': temperature,
36 | 'top_P': top_P,
37 | 'top_K': top_K,
38 | }
39 | params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}
40 |
41 | torch.manual_seed(text_seed_input)
42 |
43 | if refine_text_flag:
44 | if refine_text_input:
45 | params_refine_text['prompt'] = refine_text_input
46 | text = chat.infer(text,
47 | skip_refine_text=False,
48 | refine_text_only=True,
49 | params_refine_text=params_refine_text,
50 | params_infer_code=params_infer_code
51 | )
52 | print("Text has been refined!")
53 |
54 | wav = chat.infer(text,
55 | skip_refine_text=True,
56 | params_refine_text=params_refine_text,
57 | params_infer_code=params_infer_code
58 | )
59 |
60 | audio_data = np.array(wav[0]).flatten()
61 | sample_rate = 24000
62 | text_data = text[0] if isinstance(text, list) else text
63 |
64 | if output_path is None:
65 | return [(sample_rate, audio_data), text_data]
66 | else:
67 | soundfile.write(output_path, audio_data, sample_rate)
68 | return text_data
69 |
70 | # OpenVoice Clone
71 | ckpt_converter = 'OpenVoice/checkpoints/converter'
72 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
73 |
74 | tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
75 | tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
76 |
77 | def generate_audio(text, audio_ref, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input):
78 | save_path = "output.wav"
79 |
80 |     if audio_ref:  # an empty reference comes through as None or ""
81 | # Run the base speaker tts
82 | src_path = "tmp.wav"
83 | text_data = chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, src_path)
84 | print("Ready for voice cloning!")
85 |
86 | source_se, audio_name = se_extractor.get_se(src_path, tone_color_converter, target_dir='processed', vad=True)
87 | reference_speaker = audio_ref
88 | target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
89 |
90 | print("Get voices segment!")
91 |
92 | # Run the tone color converter
93 | # convert from file
94 | tone_color_converter.convert(
95 | audio_src_path=src_path,
96 | src_se=source_se,
97 | tgt_se=target_se,
98 | output_path=save_path)
99 | else:
100 |         text_data = chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, save_path)
101 |
102 | print("Finished!")
103 |
104 | return [save_path, text_data]
105 |
106 |
107 | with gr.Blocks() as demo:
108 | gr.Markdown("# 🥳 ChatTTS x OpenVoice 🥳")
109 |     gr.Markdown("## 🌟 Make it sound super natural and switch it to any voice you want, nailing the mood and tone too! 🌟")
110 |
111 | default_text = "Today a man knocked on my door and asked for a small donation toward the local swimming pool. I gave him a glass of water."
112 | text_input = gr.Textbox(label="Input Text", lines=4, placeholder="Please Input Text...", value=default_text)
113 |
114 |
115 | default_refine_text = "[oral_2][laugh_0][break_6]"
116 |     refine_text_checkbox = gr.Checkbox(label="Refine text", info="Each tag takes a strength from 0-10: 'oral' adds filler words, 'laugh' adds laughter, and 'break' adds pauses.", value=True)
117 |     refine_text_input = gr.Textbox(label="Refine Prompt", lines=1, placeholder="e.g. [oral_2][laugh_0][break_6]", value=default_refine_text)
118 | with gr.Column():
119 | voice_ref = gr.Audio(label="Reference Audio", type="filepath", value="Examples/speaker.mp3")
120 |
121 | with gr.Row():
122 | temperature_slider = gr.Slider(minimum=0.00001, maximum=1.0, step=0.00001, value=0.3, label="Audio temperature")
123 | top_p_slider = gr.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.7, label="top_P")
124 | top_k_slider = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_K")
125 |
126 | with gr.Row():
127 | audio_seed_input = gr.Number(value=42, label="Speaker Seed")
128 | generate_audio_seed = gr.Button("\U0001F3B2")
129 | text_seed_input = gr.Number(value=42, label="Text Seed")
130 | generate_text_seed = gr.Button("\U0001F3B2")
131 |
132 | generate_button = gr.Button("Generate")
133 |
134 | text_output = gr.Textbox(label="Refined Text", interactive=False)
135 | audio_output = gr.Audio(label="Output Audio")
136 |
137 | generate_audio_seed.click(generate_seed,
138 | inputs=[],
139 | outputs=audio_seed_input)
140 |
141 | generate_text_seed.click(generate_seed,
142 | inputs=[],
143 | outputs=text_seed_input)
144 |
145 | generate_button.click(generate_audio,
146 | inputs=[text_input, voice_ref, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input],
147 | outputs=[audio_output,text_output])
148 |
149 | parser = argparse.ArgumentParser(description='ChatTTS-OpenVoice Launch')
150 | parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
151 | parser.add_argument('--server_port', type=int, default=8080, help='Server port')
152 | args = parser.parse_args()
153 |
154 | # demo.launch(server_name=args.server_name, server_port=args.server_port, inbrowser=True)
155 |
156 | if __name__ == '__main__':
157 | demo.launch()
158 |
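The same pipeline can be run headlessly by calling generate_audio directly (a sketch; the reference clip path is illustrative):

    out_path, refined_text = generate_audio(
        text="Today a man knocked on my door and asked for a small donation toward the local swimming pool.",
        audio_ref="Examples/speaker.mp3",   # short, clean recording of the target voice
        temperature=0.3, top_P=0.7, top_K=20,
        audio_seed_input=42, text_seed_input=42,
        refine_text_flag=True, refine_text_input="[oral_2][laugh_0][break_6]",
    )
    print(out_path, refined_text)           # "output.wav" plus the refined text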
--------------------------------------------------------------------------------
/OpenVoice/utils/mel_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | import librosa  # spectrogram_torch_conv below uses librosa.util.pad_center
4 | from librosa.filters import mel as librosa_mel_fn
5 | MAX_WAV_VALUE = 32768.0
6 |
7 |
8 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
9 | """
10 | PARAMS
11 | ------
12 | C: compression factor
13 | """
14 | return torch.log(torch.clamp(x, min=clip_val) * C)
15 |
16 |
17 | def dynamic_range_decompression_torch(x, C=1):
18 | """
19 | PARAMS
20 | ------
21 | C: compression factor used to compress
22 | """
23 | return torch.exp(x) / C
24 |
25 |
26 | def spectral_normalize_torch(magnitudes):
27 | output = dynamic_range_compression_torch(magnitudes)
28 | return output
29 |
30 |
31 | def spectral_de_normalize_torch(magnitudes):
32 | output = dynamic_range_decompression_torch(magnitudes)
33 | return output
34 |
35 |
36 | mel_basis = {}
37 | hann_window = {}
38 |
39 |
40 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
41 | if torch.min(y) < -1.1:
42 | print("min value is ", torch.min(y))
43 | if torch.max(y) > 1.1:
44 | print("max value is ", torch.max(y))
45 |
46 | global hann_window
47 | dtype_device = str(y.dtype) + "_" + str(y.device)
48 | wnsize_dtype_device = str(win_size) + "_" + dtype_device
49 | if wnsize_dtype_device not in hann_window:
50 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
51 | dtype=y.dtype, device=y.device
52 | )
53 |
54 | y = torch.nn.functional.pad(
55 | y.unsqueeze(1),
56 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
57 | mode="reflect",
58 | )
59 | y = y.squeeze(1)
60 |
61 | spec = torch.stft(
62 | y,
63 | n_fft,
64 | hop_length=hop_size,
65 | win_length=win_size,
66 | window=hann_window[wnsize_dtype_device],
67 | center=center,
68 | pad_mode="reflect",
69 | normalized=False,
70 | onesided=True,
71 | return_complex=False,
72 | )
73 |
74 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
75 | return spec
76 |
77 |
78 | def spectrogram_torch_conv(y, n_fft, sampling_rate, hop_size, win_size, center=False):
79 | # if torch.min(y) < -1.:
80 | # print('min value is ', torch.min(y))
81 | # if torch.max(y) > 1.:
82 | # print('max value is ', torch.max(y))
83 |
84 | global hann_window
85 | dtype_device = str(y.dtype) + '_' + str(y.device)
86 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
87 | if wnsize_dtype_device not in hann_window:
88 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
89 |
90 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
91 |
92 | # ******************** original ************************#
93 | # y = y.squeeze(1)
94 | # spec1 = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
95 | # center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
96 |
97 | # ******************** ConvSTFT ************************#
98 | freq_cutoff = n_fft // 2 + 1
99 | fourier_basis = torch.view_as_real(torch.fft.fft(torch.eye(n_fft)))
100 | forward_basis = fourier_basis[:freq_cutoff].permute(2, 0, 1).reshape(-1, 1, fourier_basis.shape[1])
101 | forward_basis = forward_basis * torch.as_tensor(librosa.util.pad_center(torch.hann_window(win_size), size=n_fft)).float()
102 |
103 | import torch.nn.functional as F
104 |
105 | # if center:
106 | # signal = F.pad(y[:, None, None, :], (n_fft // 2, n_fft // 2, 0, 0), mode = 'reflect').squeeze(1)
107 | assert center is False
108 |
109 | forward_transform_squared = F.conv1d(y, forward_basis.to(y.device), stride = hop_size)
110 | spec2 = torch.stack([forward_transform_squared[:, :freq_cutoff, :], forward_transform_squared[:, freq_cutoff:, :]], dim = -1)
111 |
112 |
113 | # ******************** Verification ************************#
114 | spec1 = torch.stft(y.squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
115 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
116 | assert torch.allclose(spec1, spec2, atol=1e-4)
117 |
118 | spec = torch.sqrt(spec2.pow(2).sum(-1) + 1e-6)
119 | return spec
120 |
121 |
122 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
123 | global mel_basis
124 | dtype_device = str(spec.dtype) + "_" + str(spec.device)
125 | fmax_dtype_device = str(fmax) + "_" + dtype_device
126 | if fmax_dtype_device not in mel_basis:
127 |         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)  # keyword args for librosa >= 0.10 compatibility
128 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
129 | dtype=spec.dtype, device=spec.device
130 | )
131 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
132 | spec = spectral_normalize_torch(spec)
133 | return spec
134 |
135 |
136 | def mel_spectrogram_torch(
137 | y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
138 | ):
139 | if torch.min(y) < -1.0:
140 | print("min value is ", torch.min(y))
141 | if torch.max(y) > 1.0:
142 | print("max value is ", torch.max(y))
143 |
144 | global mel_basis, hann_window
145 | dtype_device = str(y.dtype) + "_" + str(y.device)
146 | fmax_dtype_device = str(fmax) + "_" + dtype_device
147 | wnsize_dtype_device = str(win_size) + "_" + dtype_device
148 | if fmax_dtype_device not in mel_basis:
149 |         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)  # keyword args for librosa >= 0.10 compatibility
150 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
151 | dtype=y.dtype, device=y.device
152 | )
153 | if wnsize_dtype_device not in hann_window:
154 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
155 | dtype=y.dtype, device=y.device
156 | )
157 |
158 | y = torch.nn.functional.pad(
159 | y.unsqueeze(1),
160 | (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
161 | mode="reflect",
162 | )
163 | y = y.squeeze(1)
164 |
165 | spec = torch.stft(
166 | y,
167 | n_fft,
168 | hop_length=hop_size,
169 | win_length=win_size,
170 | window=hann_window[wnsize_dtype_device],
171 | center=center,
172 | pad_mode="reflect",
173 | normalized=False,
174 | onesided=True,
175 | return_complex=False,
176 | )
177 |
178 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
179 |
180 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
181 | spec = spectral_normalize_torch(spec)
182 |
183 | return spec
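For reference, mel_spectrogram_torch defined above can be called directly on a mono waveform in [-1, 1]; the FFT/hop/window sizes below are typical VITS-style values, not necessarily the ones in OpenVoice's config.json:

    import torch

    y = torch.zeros(1, 44100)                # (batch, samples): 2 s of silence at 22.05 kHz
    mel = mel_spectrogram_torch(y, n_fft=1024, num_mels=80, sampling_rate=22050,
                                hop_size=256, win_size=1024, fmin=0, fmax=None)
    print(mel.shape)                         # (1, 80, n_frames)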
--------------------------------------------------------------------------------
/OpenVoice/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import numpy as np
5 |
6 |
7 | DEFAULT_MIN_BIN_WIDTH = 1e-3
8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3
9 | DEFAULT_MIN_DERIVATIVE = 1e-3
10 |
11 |
12 | def piecewise_rational_quadratic_transform(
13 | inputs,
14 | unnormalized_widths,
15 | unnormalized_heights,
16 | unnormalized_derivatives,
17 | inverse=False,
18 | tails=None,
19 | tail_bound=1.0,
20 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22 | min_derivative=DEFAULT_MIN_DERIVATIVE,
23 | ):
24 | if tails is None:
25 | spline_fn = rational_quadratic_spline
26 | spline_kwargs = {}
27 | else:
28 | spline_fn = unconstrained_rational_quadratic_spline
29 | spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30 |
31 | outputs, logabsdet = spline_fn(
32 | inputs=inputs,
33 | unnormalized_widths=unnormalized_widths,
34 | unnormalized_heights=unnormalized_heights,
35 | unnormalized_derivatives=unnormalized_derivatives,
36 | inverse=inverse,
37 | min_bin_width=min_bin_width,
38 | min_bin_height=min_bin_height,
39 | min_derivative=min_derivative,
40 | **spline_kwargs
41 | )
42 | return outputs, logabsdet
43 |
44 |
45 | def searchsorted(bin_locations, inputs, eps=1e-6):
46 | bin_locations[..., -1] += eps
47 | return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48 |
49 |
50 | def unconstrained_rational_quadratic_spline(
51 | inputs,
52 | unnormalized_widths,
53 | unnormalized_heights,
54 | unnormalized_derivatives,
55 | inverse=False,
56 | tails="linear",
57 | tail_bound=1.0,
58 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60 | min_derivative=DEFAULT_MIN_DERIVATIVE,
61 | ):
62 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63 | outside_interval_mask = ~inside_interval_mask
64 |
65 | outputs = torch.zeros_like(inputs)
66 | logabsdet = torch.zeros_like(inputs)
67 |
68 | if tails == "linear":
69 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70 | constant = np.log(np.exp(1 - min_derivative) - 1)
71 | unnormalized_derivatives[..., 0] = constant
72 | unnormalized_derivatives[..., -1] = constant
73 |
74 | outputs[outside_interval_mask] = inputs[outside_interval_mask]
75 | logabsdet[outside_interval_mask] = 0
76 | else:
77 | raise RuntimeError("{} tails are not implemented.".format(tails))
78 |
79 | (
80 | outputs[inside_interval_mask],
81 | logabsdet[inside_interval_mask],
82 | ) = rational_quadratic_spline(
83 | inputs=inputs[inside_interval_mask],
84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87 | inverse=inverse,
88 | left=-tail_bound,
89 | right=tail_bound,
90 | bottom=-tail_bound,
91 | top=tail_bound,
92 | min_bin_width=min_bin_width,
93 | min_bin_height=min_bin_height,
94 | min_derivative=min_derivative,
95 | )
96 |
97 | return outputs, logabsdet
98 |
99 |
100 | def rational_quadratic_spline(
101 | inputs,
102 | unnormalized_widths,
103 | unnormalized_heights,
104 | unnormalized_derivatives,
105 | inverse=False,
106 | left=0.0,
107 | right=1.0,
108 | bottom=0.0,
109 | top=1.0,
110 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112 | min_derivative=DEFAULT_MIN_DERIVATIVE,
113 | ):
114 | if torch.min(inputs) < left or torch.max(inputs) > right:
115 | raise ValueError("Input to a transform is not within its domain")
116 |
117 | num_bins = unnormalized_widths.shape[-1]
118 |
119 | if min_bin_width * num_bins > 1.0:
120 | raise ValueError("Minimal bin width too large for the number of bins")
121 | if min_bin_height * num_bins > 1.0:
122 | raise ValueError("Minimal bin height too large for the number of bins")
123 |
124 | widths = F.softmax(unnormalized_widths, dim=-1)
125 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126 | cumwidths = torch.cumsum(widths, dim=-1)
127 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128 | cumwidths = (right - left) * cumwidths + left
129 | cumwidths[..., 0] = left
130 | cumwidths[..., -1] = right
131 | widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132 |
133 | derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134 |
135 | heights = F.softmax(unnormalized_heights, dim=-1)
136 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137 | cumheights = torch.cumsum(heights, dim=-1)
138 | cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139 | cumheights = (top - bottom) * cumheights + bottom
140 | cumheights[..., 0] = bottom
141 | cumheights[..., -1] = top
142 | heights = cumheights[..., 1:] - cumheights[..., :-1]
143 |
144 | if inverse:
145 | bin_idx = searchsorted(cumheights, inputs)[..., None]
146 | else:
147 | bin_idx = searchsorted(cumwidths, inputs)[..., None]
148 |
149 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151 |
152 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153 | delta = heights / widths
154 | input_delta = delta.gather(-1, bin_idx)[..., 0]
155 |
156 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158 |
159 | input_heights = heights.gather(-1, bin_idx)[..., 0]
160 |
161 | if inverse:
162 | a = (inputs - input_cumheights) * (
163 | input_derivatives + input_derivatives_plus_one - 2 * input_delta
164 | ) + input_heights * (input_delta - input_derivatives)
165 | b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166 | input_derivatives + input_derivatives_plus_one - 2 * input_delta
167 | )
168 | c = -input_delta * (inputs - input_cumheights)
169 |
170 | discriminant = b.pow(2) - 4 * a * c
171 | assert (discriminant >= 0).all()
172 |
173 | root = (2 * c) / (-b - torch.sqrt(discriminant))
174 | outputs = root * input_bin_widths + input_cumwidths
175 |
176 | theta_one_minus_theta = root * (1 - root)
177 | denominator = input_delta + (
178 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179 | * theta_one_minus_theta
180 | )
181 | derivative_numerator = input_delta.pow(2) * (
182 | input_derivatives_plus_one * root.pow(2)
183 | + 2 * input_delta * theta_one_minus_theta
184 | + input_derivatives * (1 - root).pow(2)
185 | )
186 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187 |
188 | return outputs, -logabsdet
189 | else:
190 | theta = (inputs - input_cumwidths) / input_bin_widths
191 | theta_one_minus_theta = theta * (1 - theta)
192 |
193 | numerator = input_heights * (
194 | input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195 | )
196 | denominator = input_delta + (
197 | (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198 | * theta_one_minus_theta
199 | )
200 | outputs = input_cumheights + numerator / denominator
201 |
202 | derivative_numerator = input_delta.pow(2) * (
203 | input_derivatives_plus_one * theta.pow(2)
204 | + 2 * input_delta * theta_one_minus_theta
205 | + input_derivatives * (1 - theta).pow(2)
206 | )
207 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208 |
209 | return outputs, logabsdet
210 |
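A quick check of the spline's invertibility (shapes follow the signatures above; with tails='linear' the derivative tensor carries num_bins - 1 entries and is padded internally):

    import torch

    torch.manual_seed(0)
    num_bins = 10
    x = torch.randn(8)                       # values outside the tail bound pass through unchanged
    w = torch.randn(8, num_bins)             # unnormalized widths
    h = torch.randn(8, num_bins)             # unnormalized heights
    d = torch.randn(8, num_bins - 1)         # unnormalized derivatives

    y, logabsdet = piecewise_rational_quadratic_transform(x, w, h, d, tails='linear', tail_bound=5.0)
    x_rec, _ = piecewise_rational_quadratic_transform(y, w, h, d, inverse=True, tails='linear', tail_bound=5.0)
    assert torch.allclose(x, x_rec, atol=1e-4)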
--------------------------------------------------------------------------------
/OpenVoice/text/mandarin.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import re
4 | from pypinyin import lazy_pinyin, BOPOMOFO
5 | import jieba
6 | import cn2an
7 | import logging
8 |
9 |
10 | # List of (Latin alphabet, bopomofo) pairs:
11 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
12 | ('a', 'ㄟˉ'),
13 | ('b', 'ㄅㄧˋ'),
14 | ('c', 'ㄙㄧˉ'),
15 | ('d', 'ㄉㄧˋ'),
16 | ('e', 'ㄧˋ'),
17 | ('f', 'ㄝˊㄈㄨˋ'),
18 | ('g', 'ㄐㄧˋ'),
19 | ('h', 'ㄝˇㄑㄩˋ'),
20 | ('i', 'ㄞˋ'),
21 | ('j', 'ㄐㄟˋ'),
22 | ('k', 'ㄎㄟˋ'),
23 | ('l', 'ㄝˊㄛˋ'),
24 | ('m', 'ㄝˊㄇㄨˋ'),
25 | ('n', 'ㄣˉ'),
26 | ('o', 'ㄡˉ'),
27 | ('p', 'ㄆㄧˉ'),
28 | ('q', 'ㄎㄧㄡˉ'),
29 | ('r', 'ㄚˋ'),
30 | ('s', 'ㄝˊㄙˋ'),
31 | ('t', 'ㄊㄧˋ'),
32 | ('u', 'ㄧㄡˉ'),
33 | ('v', 'ㄨㄧˉ'),
34 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),
35 | ('x', 'ㄝˉㄎㄨˋㄙˋ'),
36 | ('y', 'ㄨㄞˋ'),
37 | ('z', 'ㄗㄟˋ')
38 | ]]
39 |
40 | # List of (bopomofo, romaji) pairs:
41 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [
42 | ('ㄅㄛ', 'p⁼wo'),
43 | ('ㄆㄛ', 'pʰwo'),
44 | ('ㄇㄛ', 'mwo'),
45 | ('ㄈㄛ', 'fwo'),
46 | ('ㄅ', 'p⁼'),
47 | ('ㄆ', 'pʰ'),
48 | ('ㄇ', 'm'),
49 | ('ㄈ', 'f'),
50 | ('ㄉ', 't⁼'),
51 | ('ㄊ', 'tʰ'),
52 | ('ㄋ', 'n'),
53 | ('ㄌ', 'l'),
54 | ('ㄍ', 'k⁼'),
55 | ('ㄎ', 'kʰ'),
56 | ('ㄏ', 'h'),
57 | ('ㄐ', 'ʧ⁼'),
58 | ('ㄑ', 'ʧʰ'),
59 | ('ㄒ', 'ʃ'),
60 | ('ㄓ', 'ʦ`⁼'),
61 | ('ㄔ', 'ʦ`ʰ'),
62 | ('ㄕ', 's`'),
63 | ('ㄖ', 'ɹ`'),
64 | ('ㄗ', 'ʦ⁼'),
65 | ('ㄘ', 'ʦʰ'),
66 | ('ㄙ', 's'),
67 | ('ㄚ', 'a'),
68 | ('ㄛ', 'o'),
69 | ('ㄜ', 'ə'),
70 | ('ㄝ', 'e'),
71 | ('ㄞ', 'ai'),
72 | ('ㄟ', 'ei'),
73 | ('ㄠ', 'au'),
74 | ('ㄡ', 'ou'),
75 | ('ㄧㄢ', 'yeNN'),
76 | ('ㄢ', 'aNN'),
77 | ('ㄧㄣ', 'iNN'),
78 | ('ㄣ', 'əNN'),
79 | ('ㄤ', 'aNg'),
80 | ('ㄧㄥ', 'iNg'),
81 | ('ㄨㄥ', 'uNg'),
82 | ('ㄩㄥ', 'yuNg'),
83 | ('ㄥ', 'əNg'),
84 | ('ㄦ', 'əɻ'),
85 | ('ㄧ', 'i'),
86 | ('ㄨ', 'u'),
87 | ('ㄩ', 'ɥ'),
88 | ('ˉ', '→'),
89 | ('ˊ', '↑'),
90 | ('ˇ', '↓↑'),
91 | ('ˋ', '↓'),
92 | ('˙', ''),
93 | (',', ','),
94 | ('。', '.'),
95 | ('!', '!'),
96 | ('?', '?'),
97 | ('—', '-')
98 | ]]
99 |
100 | # List of (romaji, ipa) pairs:
101 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [
102 | ('ʃy', 'ʃ'),
103 | ('ʧʰy', 'ʧʰ'),
104 | ('ʧ⁼y', 'ʧ⁼'),
105 | ('NN', 'n'),
106 | ('Ng', 'ŋ'),
107 | ('y', 'j'),
108 | ('h', 'x')
109 | ]]
110 |
111 | # List of (bopomofo, ipa) pairs:
112 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
113 | ('ㄅㄛ', 'p⁼wo'),
114 | ('ㄆㄛ', 'pʰwo'),
115 | ('ㄇㄛ', 'mwo'),
116 | ('ㄈㄛ', 'fwo'),
117 | ('ㄅ', 'p⁼'),
118 | ('ㄆ', 'pʰ'),
119 | ('ㄇ', 'm'),
120 | ('ㄈ', 'f'),
121 | ('ㄉ', 't⁼'),
122 | ('ㄊ', 'tʰ'),
123 | ('ㄋ', 'n'),
124 | ('ㄌ', 'l'),
125 | ('ㄍ', 'k⁼'),
126 | ('ㄎ', 'kʰ'),
127 | ('ㄏ', 'x'),
128 | ('ㄐ', 'tʃ⁼'),
129 | ('ㄑ', 'tʃʰ'),
130 | ('ㄒ', 'ʃ'),
131 | ('ㄓ', 'ts`⁼'),
132 | ('ㄔ', 'ts`ʰ'),
133 | ('ㄕ', 's`'),
134 | ('ㄖ', 'ɹ`'),
135 | ('ㄗ', 'ts⁼'),
136 | ('ㄘ', 'tsʰ'),
137 | ('ㄙ', 's'),
138 | ('ㄚ', 'a'),
139 | ('ㄛ', 'o'),
140 | ('ㄜ', 'ə'),
141 | ('ㄝ', 'ɛ'),
142 | ('ㄞ', 'aɪ'),
143 | ('ㄟ', 'eɪ'),
144 | ('ㄠ', 'ɑʊ'),
145 | ('ㄡ', 'oʊ'),
146 | ('ㄧㄢ', 'jɛn'),
147 | ('ㄩㄢ', 'ɥæn'),
148 | ('ㄢ', 'an'),
149 | ('ㄧㄣ', 'in'),
150 | ('ㄩㄣ', 'ɥn'),
151 | ('ㄣ', 'ən'),
152 | ('ㄤ', 'ɑŋ'),
153 | ('ㄧㄥ', 'iŋ'),
154 | ('ㄨㄥ', 'ʊŋ'),
155 | ('ㄩㄥ', 'jʊŋ'),
156 | ('ㄥ', 'əŋ'),
157 | ('ㄦ', 'əɻ'),
158 | ('ㄧ', 'i'),
159 | ('ㄨ', 'u'),
160 | ('ㄩ', 'ɥ'),
161 | ('ˉ', '→'),
162 | ('ˊ', '↑'),
163 | ('ˇ', '↓↑'),
164 | ('ˋ', '↓'),
165 | ('˙', ''),
166 | (',', ','),
167 | ('。', '.'),
168 | ('!', '!'),
169 | ('?', '?'),
170 | ('—', '-')
171 | ]]
172 |
173 | # List of (bopomofo, ipa2) pairs:
174 | _bopomofo_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
175 | ('ㄅㄛ', 'pwo'),
176 | ('ㄆㄛ', 'pʰwo'),
177 | ('ㄇㄛ', 'mwo'),
178 | ('ㄈㄛ', 'fwo'),
179 | ('ㄅ', 'p'),
180 | ('ㄆ', 'pʰ'),
181 | ('ㄇ', 'm'),
182 | ('ㄈ', 'f'),
183 | ('ㄉ', 't'),
184 | ('ㄊ', 'tʰ'),
185 | ('ㄋ', 'n'),
186 | ('ㄌ', 'l'),
187 | ('ㄍ', 'k'),
188 | ('ㄎ', 'kʰ'),
189 | ('ㄏ', 'h'),
190 | ('ㄐ', 'tɕ'),
191 | ('ㄑ', 'tɕʰ'),
192 | ('ㄒ', 'ɕ'),
193 | ('ㄓ', 'tʂ'),
194 | ('ㄔ', 'tʂʰ'),
195 | ('ㄕ', 'ʂ'),
196 | ('ㄖ', 'ɻ'),
197 | ('ㄗ', 'ts'),
198 | ('ㄘ', 'tsʰ'),
199 | ('ㄙ', 's'),
200 | ('ㄚ', 'a'),
201 | ('ㄛ', 'o'),
202 | ('ㄜ', 'ɤ'),
203 | ('ㄝ', 'ɛ'),
204 | ('ㄞ', 'aɪ'),
205 | ('ㄟ', 'eɪ'),
206 | ('ㄠ', 'ɑʊ'),
207 | ('ㄡ', 'oʊ'),
208 | ('ㄧㄢ', 'jɛn'),
209 | ('ㄩㄢ', 'yæn'),
210 | ('ㄢ', 'an'),
211 | ('ㄧㄣ', 'in'),
212 | ('ㄩㄣ', 'yn'),
213 | ('ㄣ', 'ən'),
214 | ('ㄤ', 'ɑŋ'),
215 | ('ㄧㄥ', 'iŋ'),
216 | ('ㄨㄥ', 'ʊŋ'),
217 | ('ㄩㄥ', 'jʊŋ'),
218 | ('ㄥ', 'ɤŋ'),
219 | ('ㄦ', 'əɻ'),
220 | ('ㄧ', 'i'),
221 | ('ㄨ', 'u'),
222 | ('ㄩ', 'y'),
223 | ('ˉ', '˥'),
224 | ('ˊ', '˧˥'),
225 | ('ˇ', '˨˩˦'),
226 | ('ˋ', '˥˩'),
227 | ('˙', ''),
228 | (',', ','),
229 | ('。', '.'),
230 | ('!', '!'),
231 | ('?', '?'),
232 | ('—', '-')
233 | ]]
234 |
235 |
236 | def number_to_chinese(text):
237 | numbers = re.findall(r'\d+(?:\.?\d+)?', text)
238 | for number in numbers:
239 | text = text.replace(number, cn2an.an2cn(number), 1)
240 | return text
241 |
242 |
243 | def chinese_to_bopomofo(text):
244 | text = text.replace('、', ',').replace(';', ',').replace(':', ',')
245 | words = jieba.lcut(text, cut_all=False)
246 | text = ''
247 | for word in words:
248 | bopomofos = lazy_pinyin(word, BOPOMOFO)
249 | if not re.search('[\u4e00-\u9fff]', word):
250 | text += word
251 | continue
252 | for i in range(len(bopomofos)):
253 | bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
254 | if text != '':
255 | text += ' '
256 | text += ''.join(bopomofos)
257 | return text
258 |
259 |
260 | def latin_to_bopomofo(text):
261 | for regex, replacement in _latin_to_bopomofo:
262 | text = re.sub(regex, replacement, text)
263 | return text
264 |
265 |
266 | def bopomofo_to_romaji(text):
267 | for regex, replacement in _bopomofo_to_romaji:
268 | text = re.sub(regex, replacement, text)
269 | return text
270 |
271 |
272 | def bopomofo_to_ipa(text):
273 | for regex, replacement in _bopomofo_to_ipa:
274 | text = re.sub(regex, replacement, text)
275 | return text
276 |
277 |
278 | def bopomofo_to_ipa2(text):
279 | for regex, replacement in _bopomofo_to_ipa2:
280 | text = re.sub(regex, replacement, text)
281 | return text
282 |
283 |
284 | def chinese_to_romaji(text):
285 | text = number_to_chinese(text)
286 | text = chinese_to_bopomofo(text)
287 | text = latin_to_bopomofo(text)
288 | text = bopomofo_to_romaji(text)
289 | text = re.sub('i([aoe])', r'y\1', text)
290 | text = re.sub('u([aoəe])', r'w\1', text)
291 | text = re.sub('([ʦsɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
292 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
293 | text = re.sub('([ʦs][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
294 | return text
295 |
296 |
297 | def chinese_to_lazy_ipa(text):
298 | text = chinese_to_romaji(text)
299 | for regex, replacement in _romaji_to_ipa:
300 | text = re.sub(regex, replacement, text)
301 | return text
302 |
303 |
304 | def chinese_to_ipa(text):
305 | text = number_to_chinese(text)
306 | text = chinese_to_bopomofo(text)
307 | text = latin_to_bopomofo(text)
308 | text = bopomofo_to_ipa(text)
309 | text = re.sub('i([aoe])', r'j\1', text)
310 | text = re.sub('u([aoəe])', r'w\1', text)
311 | text = re.sub('([sɹ]`[⁼ʰ]?)([→↓↑ ]+|$)',
312 | r'\1ɹ`\2', text).replace('ɻ', 'ɹ`')
313 | text = re.sub('([s][⁼ʰ]?)([→↓↑ ]+|$)', r'\1ɹ\2', text)
314 | return text
315 |
316 |
317 | def chinese_to_ipa2(text):
318 | text = number_to_chinese(text)
319 | text = chinese_to_bopomofo(text)
320 | text = latin_to_bopomofo(text)
321 | text = bopomofo_to_ipa2(text)
322 | text = re.sub(r'i([aoe])', r'j\1', text)
323 | text = re.sub(r'u([aoəe])', r'w\1', text)
324 | text = re.sub(r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2', text)
325 | text = re.sub(r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2', text)
326 | return text
327 |
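Example usage of the converters above (the exact IPA output depends on the installed jieba and pypinyin versions, so it is shown rather than asserted):

    sample = '今天气温23度。'
    print(number_to_chinese(sample))         # '23' -> '二十三'
    print(chinese_to_ipa(sample))            # IPA with tone arrows (→ ↑ ↓)
    print(chinese_to_ipa2(sample))           # IPA with tone letters (˥ ˧˥ ˨˩˦ ˥˩)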
--------------------------------------------------------------------------------
/OpenVoice/api.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import re
4 | import soundfile
5 | from .utils import utils
6 | from .utils import commons
7 | import os
8 | import librosa
9 | from .text import text_to_sequence
10 | from .utils.mel_processing import spectrogram_torch
11 | from .utils.models import SynthesizerTrn
12 |
13 |
14 | class OpenVoiceBaseClass(object):
15 | def __init__(self,
16 | config_path,
17 |                  device='cpu'):
18 |         # Defaults to CPU; the app passes device='cuda:0' when a GPU is available.
19 |         if 'cuda' in device:
20 |             assert torch.cuda.is_available()
21 |
22 | hps = utils.get_hparams_from_file(config_path)
23 |
24 | model = SynthesizerTrn(
25 | len(getattr(hps, 'symbols', [])),
26 | hps.data.filter_length // 2 + 1,
27 | n_speakers=hps.data.n_speakers,
28 | **hps.model,
29 | ).to(device)
30 |
31 | model.eval()
32 | self.model = model
33 | self.hps = hps
34 | self.device = device
35 |
36 | def load_ckpt(self, ckpt_path):
37 | checkpoint_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
38 | a, b = self.model.load_state_dict(checkpoint_dict['model'], strict=False)
39 | print("Loaded checkpoint '{}'".format(ckpt_path))
40 | print('missing/unexpected keys:', a, b)
41 |
42 |
43 | class BaseSpeakerTTS(OpenVoiceBaseClass):
44 | language_marks = {
45 | "english": "EN",
46 | "chinese": "ZH",
47 | }
48 |
49 | @staticmethod
50 | def get_text(text, hps, is_symbol):
51 | text_norm = text_to_sequence(text, hps.symbols, [] if is_symbol else hps.data.text_cleaners)
52 | if hps.data.add_blank:
53 | text_norm = commons.intersperse(text_norm, 0)
54 | text_norm = torch.LongTensor(text_norm)
55 | return text_norm
56 |
57 | @staticmethod
58 | def audio_numpy_concat(segment_data_list, sr, speed=1.):
59 | audio_segments = []
60 | for segment_data in segment_data_list:
61 | audio_segments += segment_data.reshape(-1).tolist()
62 | audio_segments += [0] * int((sr * 0.05)/speed)
63 | audio_segments = np.array(audio_segments).astype(np.float32)
64 | return audio_segments
65 |
66 | @staticmethod
67 | def split_sentences_into_pieces(text, language_str):
68 | texts = utils.split_sentence(text, language_str=language_str)
69 |         print(" > Text split into sentences.")
70 | print('\n'.join(texts))
71 | print(" > ===========================")
72 | return texts
73 |
74 | def tts(self, text, output_path, speaker, language='English', speed=1.0):
75 | mark = self.language_marks.get(language.lower(), None)
76 | assert mark is not None, f"language {language} is not supported"
77 |
78 | texts = self.split_sentences_into_pieces(text, mark)
79 |
80 | audio_list = []
81 | for t in texts:
82 | t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
83 | t = f'[{mark}]{t}[{mark}]'
84 | stn_tst = self.get_text(t, self.hps, False)
85 | device = self.device
86 | speaker_id = self.hps.speakers[speaker]
87 | with torch.no_grad():
88 | x_tst = stn_tst.unsqueeze(0).to(device)
89 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
90 | sid = torch.LongTensor([speaker_id]).to(device)
91 | audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6,
92 | length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
93 | audio_list.append(audio)
94 | audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
95 |
96 | if output_path is None:
97 | return audio
98 | else:
99 | soundfile.write(output_path, audio, self.hps.data.sampling_rate)
100 |
101 |
102 | class ToneColorConverter(OpenVoiceBaseClass):
103 | def __init__(self, *args, **kwargs):
104 | super().__init__(*args, **kwargs)
105 |
106 | if kwargs.get('enable_watermark', True):
107 | import wavmark
108 | self.watermark_model = wavmark.load_model().to(self.device)
109 | else:
110 | self.watermark_model = None
111 |
112 |
113 |
114 | def extract_se(self, ref_wav_list, se_save_path=None):
115 | if isinstance(ref_wav_list, str):
116 | ref_wav_list = [ref_wav_list]
117 |
118 | device = self.device
119 | hps = self.hps
120 | gs = []
121 |
122 | for fname in ref_wav_list:
123 | audio_ref, sr = librosa.load(fname, sr=hps.data.sampling_rate)
124 | y = torch.FloatTensor(audio_ref)
125 | y = y.to(device)
126 | y = y.unsqueeze(0)
127 | y = spectrogram_torch(y, hps.data.filter_length,
128 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
129 | center=False).to(device)
130 | with torch.no_grad():
131 | g = self.model.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
132 | gs.append(g.detach())
133 | gs = torch.stack(gs).mean(0)
134 |
135 | if se_save_path is not None:
136 | os.makedirs(os.path.dirname(se_save_path), exist_ok=True)
137 | torch.save(gs.cpu(), se_save_path)
138 |
139 | return gs
140 |
141 | def convert(self, audio_src_path, src_se, tgt_se, output_path=None, tau=0.3, message="@Hilley-MyShell"):
142 | hps = self.hps
143 | # load audio
144 | audio, sample_rate = librosa.load(audio_src_path, sr=hps.data.sampling_rate)
145 | audio = torch.tensor(audio).float()
146 |
147 | with torch.no_grad():
148 | y = torch.FloatTensor(audio).to(self.device)
149 | y = y.unsqueeze(0)
150 | spec = spectrogram_torch(y, hps.data.filter_length,
151 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length,
152 | center=False).to(self.device)
153 | spec_lengths = torch.LongTensor([spec.size(-1)]).to(self.device)
154 | audio = self.model.voice_conversion(spec, spec_lengths, sid_src=src_se, sid_tgt=tgt_se, tau=tau)[0][
155 | 0, 0].data.cpu().float().numpy()
156 | audio = self.add_watermark(audio, message)
157 | if output_path is None:
158 | return audio
159 | else:
160 | soundfile.write(output_path, audio, hps.data.sampling_rate)
161 |
162 | def add_watermark(self, audio, message):
163 | if self.watermark_model is None:
164 | return audio
165 | device = self.device
166 | bits = utils.string_to_bits(message).reshape(-1)
167 | n_repeat = len(bits) // 32
168 |
169 | K = 16000
170 | coeff = 2
171 | for n in range(n_repeat):
172 | trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
173 | if len(trunck) != K:
174 |                 print('Audio too short, failed to add watermark')
175 | break
176 | message_npy = bits[n * 32: (n + 1) * 32]
177 |
178 | with torch.no_grad():
179 | signal = torch.FloatTensor(trunck).to(device)[None]
180 | message_tensor = torch.FloatTensor(message_npy).to(device)[None]
181 | signal_wmd_tensor = self.watermark_model.encode(signal, message_tensor)
182 | signal_wmd_npy = signal_wmd_tensor.detach().cpu().squeeze()
183 | audio[(coeff * n) * K: (coeff * n + 1) * K] = signal_wmd_npy
184 | return audio
185 |
186 | def detect_watermark(self, audio, n_repeat):
187 | bits = []
188 | K = 16000
189 | coeff = 2
190 | for n in range(n_repeat):
191 | trunck = audio[(coeff * n) * K: (coeff * n + 1) * K]
192 | if len(trunck) != K:
193 |                 print('Audio too short, failed to detect watermark')
194 | return 'Fail'
195 | with torch.no_grad():
196 | signal = torch.FloatTensor(trunck).to(self.device).unsqueeze(0)
197 | message_decoded_npy = (self.watermark_model.decode(signal) >= 0.5).int().detach().cpu().numpy().squeeze()
198 | bits.append(message_decoded_npy)
199 | bits = np.stack(bits).reshape(-1, 8)
200 | message = utils.bits_to_string(bits)
201 | return message
202 |
203 |
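A condensed sketch of tone-color conversion with the classes above (paths are illustrative; app.py instead goes through se_extractor.get_se, which adds VAD-based segmentation before the embedding extraction):

    converter = ToneColorConverter('OpenVoice/checkpoints/converter/config.json', device='cpu')
    converter.load_ckpt('OpenVoice/checkpoints/converter/checkpoint.pth')

    src_se = converter.extract_se('tmp.wav')                  # timbre of the ChatTTS output
    tgt_se = converter.extract_se('Examples/speaker.mp3')     # timbre of the reference speaker
    converter.convert(audio_src_path='tmp.wav', src_se=src_se, tgt_se=tgt_se,
                      output_path='output.wav', tau=0.3)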
--------------------------------------------------------------------------------
/ChatTTS/model/gpt.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
3 |
4 | import logging
5 | from tqdm import tqdm
6 | from einops import rearrange
7 | from transformers.cache_utils import Cache
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.nn.utils.parametrize as P
13 | from torch.nn.utils.parametrizations import weight_norm
14 | from transformers import LlamaModel, LlamaConfig
15 |
16 |
17 | class LlamaMLP(nn.Module):
18 | def __init__(self, hidden_size, intermediate_size):
19 | super().__init__()
20 | self.hidden_size = hidden_size
21 | self.intermediate_size = intermediate_size
22 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
23 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
24 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
25 | self.act_fn = F.silu
26 |
27 | def forward(self, x):
28 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
29 | return down_proj
30 |
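A shape check for the SwiGLU-style block above (768/3072 are illustrative sizes, not the ones from ChatTTS's config):

    mlp = LlamaMLP(hidden_size=768, intermediate_size=3072)
    out = mlp(torch.randn(2, 16, 768))
    print(out.shape)                          # torch.Size([2, 16, 768])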
31 |
32 | class GPT_warpper(nn.Module):
33 | def __init__(
34 | self,
35 | gpt_config,
36 | num_audio_tokens,
37 | num_text_tokens,
38 | num_vq=4,
39 | **kwargs,
40 | ):
41 | super().__init__()
42 |
43 | self.logger = logging.getLogger(__name__)
44 | self.gpt = self.build_model(gpt_config)
45 | self.model_dim = self.gpt.config.hidden_size
46 |
47 | self.num_vq = num_vq
48 | self.emb_code = nn.ModuleList([nn.Embedding(num_audio_tokens, self.model_dim) for i in range(self.num_vq)])
49 | self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
50 | self.head_text = weight_norm(nn.Linear(self.model_dim, num_text_tokens, bias=False), name='weight')
51 | self.head_code = nn.ModuleList([weight_norm(nn.Linear(self.model_dim, num_audio_tokens, bias=False), name='weight') for i in range(self.num_vq)])
52 |
53 | def build_model(self, config):
54 |
55 | configuration = LlamaConfig(**config)
56 | model = LlamaModel(configuration)
57 | del model.embed_tokens
58 |
59 | return model
60 |
61 | def get_emb(self, input_ids, text_mask, **kwargs):
62 |
63 | emb_text = self.emb_text(input_ids[text_mask][:, 0])
64 |
65 | emb_code = [self.emb_code[i](input_ids[~text_mask][:, i]) for i in range(self.num_vq)]
66 | emb_code = torch.stack(emb_code, 2).sum(2)
67 |
68 | emb = torch.zeros((input_ids.shape[:-1])+(emb_text.shape[-1],), device=emb_text.device, dtype=emb_text.dtype)
69 | emb[text_mask] = emb_text
70 | emb[~text_mask] = emb_code.to(emb.dtype)
71 |
72 | return emb
73 |
74 | def prepare_inputs_for_generation(
75 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
76 | ):
77 | # With static cache, the `past_key_values` is None
78 |         # TODO joao: standardize interface for the different Cache classes and remove this `if`
79 | has_static_cache = False
80 | if past_key_values is None:
81 | past_key_values = getattr(self.gpt.layers[0].self_attn, "past_key_value", None)
82 | has_static_cache = past_key_values is not None
83 |
84 | past_length = 0
85 | if past_key_values is not None:
86 | if isinstance(past_key_values, Cache):
87 | past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
88 | max_cache_length = (
89 | torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
90 | if past_key_values.get_max_length() is not None
91 | else None
92 | )
93 | cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
94 | # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
95 | else:
96 | cache_length = past_length = past_key_values[0][0].shape[2]
97 | max_cache_length = None
98 |
99 | # Keep only the unprocessed tokens:
100 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
101 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
102 | # input)
103 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
104 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
105 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
106 | # input_ids based on the past_length.
107 | elif past_length < input_ids.shape[1]:
108 | input_ids = input_ids[:, past_length:]
109 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
110 |
111 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
112 | if (
113 | max_cache_length is not None
114 | and attention_mask is not None
115 | and cache_length + input_ids.shape[1] > max_cache_length
116 | ):
117 | attention_mask = attention_mask[:, -max_cache_length:]
118 |
119 | position_ids = kwargs.get("position_ids", None)
120 | if attention_mask is not None and position_ids is None:
121 | # create position_ids on the fly for batch generation
122 | position_ids = attention_mask.long().cumsum(-1) - 1
123 | position_ids.masked_fill_(attention_mask == 0, 1)
124 | if past_key_values:
125 | position_ids = position_ids[:, -input_ids.shape[1] :]
126 |
127 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
128 | if inputs_embeds is not None and past_key_values is None:
129 | model_inputs = {"inputs_embeds": inputs_embeds}
130 | else:
131 | # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
132 | # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
133 | # TODO: use `next_tokens` directly instead.
134 | model_inputs = {"input_ids": input_ids.contiguous()}
135 |
136 | input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
137 | if cache_position is None:
138 | cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
139 | else:
140 | cache_position = cache_position[-input_length:]
141 |
142 | if has_static_cache:
143 | past_key_values = None
144 |
145 | model_inputs.update(
146 | {
147 | "position_ids": position_ids,
148 | "cache_position": cache_position,
149 | "past_key_values": past_key_values,
150 | "use_cache": kwargs.get("use_cache"),
151 | "attention_mask": attention_mask,
152 | }
153 | )
154 | return model_inputs
155 |
156 | def generate(
157 | self,
158 | emb,
159 | inputs_ids,
160 | temperature,
161 | eos_token,
162 | attention_mask = None,
163 | max_new_token = 2048,
164 | min_new_token = 0,
165 | LogitsWarpers = [],
166 | LogitsProcessors = [],
167 | infer_text=False,
168 | return_attn=False,
169 | return_hidden=False,
170 | ):
171 |
172 | with torch.no_grad():
173 |
174 | attentions = []
175 | hiddens = []
176 |
177 | start_idx, end_idx = inputs_ids.shape[1], torch.zeros(inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long)
178 | finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
179 |
180 | temperature = temperature[None].expand(inputs_ids.shape[0], -1)
181 | temperature = rearrange(temperature, "b n -> (b n) 1")
182 |
183 | attention_mask_cache = torch.ones((inputs_ids.shape[0], inputs_ids.shape[1]+max_new_token,), dtype=torch.bool, device=inputs_ids.device)
184 | if attention_mask is not None:
185 | attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask
186 |
187 | for i in tqdm(range(max_new_token)):
188 |
189 | model_input = self.prepare_inputs_for_generation(inputs_ids,
190 | outputs.past_key_values if i!=0 else None,
191 | attention_mask_cache[:, :inputs_ids.shape[1]], use_cache=True)
192 |
193 | if i == 0:
194 | model_input['inputs_embeds'] = emb
195 | else:
196 | if infer_text:
197 | model_input['inputs_embeds'] = self.emb_text(model_input['input_ids'][:,:,0])
198 | else:
199 | code_emb = [self.emb_code[i](model_input['input_ids'][:,:,i]) for i in range(self.num_vq)]
200 | model_input['inputs_embeds'] = torch.stack(code_emb, 3).sum(3)
201 |
202 | model_input['input_ids'] = None
203 | outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
204 | attentions.append(outputs.attentions)
205 | hidden_states = outputs[0] # 🐻
206 | if return_hidden:
207 | hiddens.append(hidden_states[:, -1])
208 |
209 | with P.cached():
210 | if infer_text:
211 | logits = self.head_text(hidden_states)
212 | else:
213 | logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
214 |
215 | logits = logits[:, -1].float()
216 |
217 | if not infer_text:
218 | logits = rearrange(logits, "b c n -> (b n) c")
219 | logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
220 | else:
221 | logits_token = inputs_ids[:, start_idx:, 0]
222 |
223 | logits = logits / temperature
224 |
225 | for logitsProcessors in LogitsProcessors:
226 | logits = logitsProcessors(logits_token, logits)
227 |
228 | for logitsWarpers in LogitsWarpers:
229 | logits = logitsWarpers(logits_token, logits)
230 |
231 | if i < min_new_token:
232 | logits[:, eos_token] = -torch.inf
233 |
234 | scores = F.softmax(logits, dim=-1)
235 |
236 | idx_next = torch.multinomial(scores, num_samples=1)
237 |
238 | if not infer_text:
239 | idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
240 | finish = finish | (idx_next == eos_token).any(1)
241 | inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
242 | else:
243 | finish = finish | (idx_next == eos_token).any(1)
244 | inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq)], 1)
245 |
246 | end_idx = end_idx + (~finish).int()
247 |
248 | if finish.all():
249 | break
250 |
251 | inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
252 | inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
253 |
254 | if return_hidden:
255 | hiddens = torch.stack(hiddens, 1)
256 | hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
257 |
258 | if not finish.all():
259 |             self.logger.warning(f'Incomplete result; hit max_new_token: {max_new_token}')
260 |
261 | return {
262 | 'ids': inputs_ids,
263 | 'attentions': attentions,
264 | 'hiddens':hiddens,
265 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Creative Commons Attribution-NonCommercial 4.0 International Public
2 | License
3 |
4 | By exercising the Licensed Rights (defined below), You accept and agree
5 | to be bound by the terms and conditions of this Creative Commons
6 | Attribution-NonCommercial 4.0 International Public License ("Public
7 | License"). To the extent this Public License may be interpreted as a
8 | contract, You are granted the Licensed Rights in consideration of Your
9 | acceptance of these terms and conditions, and the Licensor grants You
10 | such rights in consideration of benefits the Licensor receives from
11 | making the Licensed Material available under these terms and
12 | conditions.
13 |
14 |
15 | Section 1 -- Definitions.
16 |
17 | a. Adapted Material means material subject to Copyright and Similar
18 | Rights that is derived from or based upon the Licensed Material
19 | and in which the Licensed Material is translated, altered,
20 | arranged, transformed, or otherwise modified in a manner requiring
21 | permission under the Copyright and Similar Rights held by the
22 | Licensor. For purposes of this Public License, where the Licensed
23 | Material is a musical work, performance, or sound recording,
24 | Adapted Material is always produced where the Licensed Material is
25 | synched in timed relation with a moving image.
26 |
27 | b. Adapter's License means the license You apply to Your Copyright
28 | and Similar Rights in Your contributions to Adapted Material in
29 | accordance with the terms and conditions of this Public License.
30 |
31 | c. Copyright and Similar Rights means copyright and/or similar rights
32 | closely related to copyright including, without limitation,
33 | performance, broadcast, sound recording, and Sui Generis Database
34 | Rights, without regard to how the rights are labeled or
35 | categorized. For purposes of this Public License, the rights
36 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
37 | Rights.
38 | d. Effective Technological Measures means those measures that, in the
39 | absence of proper authority, may not be circumvented under laws
40 | fulfilling obligations under Article 11 of the WIPO Copyright
41 | Treaty adopted on December 20, 1996, and/or similar international
42 | agreements.
43 |
44 | e. Exceptions and Limitations means fair use, fair dealing, and/or
45 | any other exception or limitation to Copyright and Similar Rights
46 | that applies to Your use of the Licensed Material.
47 |
48 | f. Licensed Material means the artistic or literary work, database,
49 | or other material to which the Licensor applied this Public
50 | License.
51 |
52 | g. Licensed Rights means the rights granted to You subject to the
53 | terms and conditions of this Public License, which are limited to
54 | all Copyright and Similar Rights that apply to Your use of the
55 | Licensed Material and that the Licensor has authority to license.
56 |
57 | h. Licensor means the individual(s) or entity(ies) granting rights
58 | under this Public License.
59 |
60 | i. NonCommercial means not primarily intended for or directed towards
61 | commercial advantage or monetary compensation. For purposes of
62 | this Public License, the exchange of the Licensed Material for
63 | other material subject to Copyright and Similar Rights by digital
64 | file-sharing or similar means is NonCommercial provided there is
65 | no payment of monetary compensation in connection with the
66 | exchange.
67 |
68 | j. Share means to provide material to the public by any means or
69 | process that requires permission under the Licensed Rights, such
70 | as reproduction, public display, public performance, distribution,
71 | dissemination, communication, or importation, and to make material
72 | available to the public including in ways that members of the
73 | public may access the material from a place and at a time
74 | individually chosen by them.
75 |
76 | k. Sui Generis Database Rights means rights other than copyright
77 | resulting from Directive 96/9/EC of the European Parliament and of
78 | the Council of 11 March 1996 on the legal protection of databases,
79 | as amended and/or succeeded, as well as other essentially
80 | equivalent rights anywhere in the world.
81 |
82 | l. You means the individual or entity exercising the Licensed Rights
83 | under this Public License. Your has a corresponding meaning.
84 |
85 |
86 | Section 2 -- Scope.
87 |
88 | a. License grant.
89 |
90 | 1. Subject to the terms and conditions of this Public License,
91 | the Licensor hereby grants You a worldwide, royalty-free,
92 | non-sublicensable, non-exclusive, irrevocable license to
93 | exercise the Licensed Rights in the Licensed Material to:
94 |
95 | a. reproduce and Share the Licensed Material, in whole or
96 | in part, for NonCommercial purposes only; and
97 |
98 | b. produce, reproduce, and Share Adapted Material for
99 | NonCommercial purposes only.
100 |
101 | 2. Exceptions and Limitations. For the avoidance of doubt, where
102 | Exceptions and Limitations apply to Your use, this Public
103 | License does not apply, and You do not need to comply with
104 | its terms and conditions.
105 |
106 | 3. Term. The term of this Public License is specified in Section
107 | 6(a).
108 |
109 | 4. Media and formats; technical modifications allowed. The
110 | Licensor authorizes You to exercise the Licensed Rights in
111 | all media and formats whether now known or hereafter created,
112 | and to make technical modifications necessary to do so. The
113 | Licensor waives and/or agrees not to assert any right or
114 | authority to forbid You from making technical modifications
115 | necessary to exercise the Licensed Rights, including
116 | technical modifications necessary to circumvent Effective
117 | Technological Measures. For purposes of this Public License,
118 | simply making modifications authorized by this Section 2(a)
119 | (4) never produces Adapted Material.
120 |
121 | 5. Downstream recipients.
122 |
123 | a. Offer from the Licensor -- Licensed Material. Every
124 | recipient of the Licensed Material automatically
125 | receives an offer from the Licensor to exercise the
126 | Licensed Rights under the terms and conditions of this
127 | Public License.
128 |
129 | b. No downstream restrictions. You may not offer or impose
130 | any additional or different terms or conditions on, or
131 | apply any Effective Technological Measures to, the
132 | Licensed Material if doing so restricts exercise of the
133 | Licensed Rights by any recipient of the Licensed
134 | Material.
135 |
136 | 6. No endorsement. Nothing in this Public License constitutes or
137 | may be construed as permission to assert or imply that You
138 | are, or that Your use of the Licensed Material is, connected
139 | with, or sponsored, endorsed, or granted official status by,
140 | the Licensor or others designated to receive attribution as
141 | provided in Section 3(a)(1)(A)(i).
142 |
143 | b. Other rights.
144 |
145 | 1. Moral rights, such as the right of integrity, are not
146 | licensed under this Public License, nor are publicity,
147 | privacy, and/or other similar personality rights; however, to
148 | the extent possible, the Licensor waives and/or agrees not to
149 | assert any such rights held by the Licensor to the limited
150 | extent necessary to allow You to exercise the Licensed
151 | Rights, but not otherwise.
152 |
153 | 2. Patent and trademark rights are not licensed under this
154 | Public License.
155 |
156 | 3. To the extent possible, the Licensor waives any right to
157 | collect royalties from You for the exercise of the Licensed
158 | Rights, whether directly or through a collecting society
159 | under any voluntary or waivable statutory or compulsory
160 | licensing scheme. In all other cases the Licensor expressly
161 | reserves any right to collect such royalties, including when
162 | the Licensed Material is used other than for NonCommercial
163 | purposes.
164 |
165 |
166 | Section 3 -- License Conditions.
167 |
168 | Your exercise of the Licensed Rights is expressly made subject to the
169 | following conditions.
170 |
171 | a. Attribution.
172 |
173 | 1. If You Share the Licensed Material (including in modified
174 | form), You must:
175 |
176 | a. retain the following if it is supplied by the Licensor
177 | with the Licensed Material:
178 |
179 | i. identification of the creator(s) of the Licensed
180 | Material and any others designated to receive
181 | attribution, in any reasonable manner requested by
182 | the Licensor (including by pseudonym if
183 | designated);
184 |
185 | ii. a copyright notice;
186 |
187 | iii. a notice that refers to this Public License;
188 |
189 | iv. a notice that refers to the disclaimer of
190 | warranties;
191 |
192 | v. a URI or hyperlink to the Licensed Material to the
193 | extent reasonably practicable;
194 |
195 | b. indicate if You modified the Licensed Material and
196 | retain an indication of any previous modifications; and
197 |
198 | c. indicate the Licensed Material is licensed under this
199 | Public License, and include the text of, or the URI or
200 | hyperlink to, this Public License.
201 |
202 | 2. You may satisfy the conditions in Section 3(a)(1) in any
203 | reasonable manner based on the medium, means, and context in
204 | which You Share the Licensed Material. For example, it may be
205 | reasonable to satisfy the conditions by providing a URI or
206 | hyperlink to a resource that includes the required
207 | information.
208 |
209 | 3. If requested by the Licensor, You must remove any of the
210 | information required by Section 3(a)(1)(A) to the extent
211 | reasonably practicable.
212 |
213 | 4. If You Share Adapted Material You produce, the Adapter's
214 | License You apply must not prevent recipients of the Adapted
215 | Material from complying with this Public License.
216 |
217 |
218 | Section 4 -- Sui Generis Database Rights.
219 |
220 | Where the Licensed Rights include Sui Generis Database Rights that
221 | apply to Your use of the Licensed Material:
222 |
223 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
224 | to extract, reuse, reproduce, and Share all or a substantial
225 | portion of the contents of the database for NonCommercial purposes
226 | only;
227 |
228 | b. if You include all or a substantial portion of the database
229 | contents in a database in which You have Sui Generis Database
230 | Rights, then the database in which You have Sui Generis Database
231 | Rights (but not its individual contents) is Adapted Material; and
232 |
233 | c. You must comply with the conditions in Section 3(a) if You Share
234 | all or a substantial portion of the contents of the database.
235 |
236 | For the avoidance of doubt, this Section 4 supplements and does not
237 | replace Your obligations under this Public License where the Licensed
238 | Rights include other Copyright and Similar Rights.
239 |
240 |
241 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
242 |
243 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
244 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
245 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
246 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
247 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
248 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
249 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
250 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
251 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
252 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
253 |
254 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
255 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
256 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
257 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
258 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
259 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
260 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
261 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
262 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
263 |
264 | c. The disclaimer of warranties and limitation of liability provided
265 | above shall be interpreted in a manner that, to the extent
266 | possible, most closely approximates an absolute disclaimer and
267 | waiver of all liability.
268 |
269 |
270 | Section 6 -- Term and Termination.
271 |
272 | a. This Public License applies for the term of the Copyright and
273 | Similar Rights licensed here. However, if You fail to comply with
274 | this Public License, then Your rights under this Public License
275 | terminate automatically.
276 |
277 | b. Where Your right to use the Licensed Material has terminated under
278 | Section 6(a), it reinstates:
279 |
280 | 1. automatically as of the date the violation is cured, provided
281 | it is cured within 30 days of Your discovery of the
282 | violation; or
283 |
284 | 2. upon express reinstatement by the Licensor.
285 |
286 | For the avoidance of doubt, this Section 6(b) does not affect any
287 | right the Licensor may have to seek remedies for Your violations
288 | of this Public License.
289 |
290 | c. For the avoidance of doubt, the Licensor may also offer the
291 | Licensed Material under separate terms or conditions or stop
292 | distributing the Licensed Material at any time; however, doing so
293 | will not terminate this Public License.
294 |
295 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
296 | License.
297 |
298 |
299 | Section 7 -- Other Terms and Conditions.
300 |
301 | a. The Licensor shall not be bound by any additional or different
302 | terms or conditions communicated by You unless expressly agreed.
303 |
304 | b. Any arrangements, understandings, or agreements regarding the
305 | Licensed Material not stated herein are separate from and
306 | independent of the terms and conditions of this Public License.
307 |
308 |
309 | Section 8 -- Interpretation.
310 |
311 | a. For the avoidance of doubt, this Public License does not, and
312 | shall not be interpreted to, reduce, limit, restrict, or impose
313 | conditions on any use of the Licensed Material that could lawfully
314 | be made without permission under this Public License.
315 |
316 | b. To the extent possible, if any provision of this Public License is
317 | deemed unenforceable, it shall be automatically reformed to the
318 | minimum extent necessary to make it enforceable. If the provision
319 | cannot be reformed, it shall be severed from this Public License
320 | without affecting the enforceability of the remaining terms and
321 | conditions.
322 |
323 | c. No term or condition of this Public License will be waived and no
324 | failure to comply consented to unless expressly agreed to by the
325 | Licensor.
326 |
327 | d. Nothing in this Public License constitutes or may be interpreted
328 | as a limitation upon, or waiver of, any privileges and immunities
329 | that apply to the Licensor or You, including from the legal
330 | processes of any jurisdiction or authority.
331 |
332 | =======================================================================
333 |
334 |
--------------------------------------------------------------------------------
/OpenVoice/utils/attentions.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | from . import commons
7 | import logging
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class LayerNorm(nn.Module):
13 | def __init__(self, channels, eps=1e-5):
14 | super().__init__()
15 | self.channels = channels
16 | self.eps = eps
17 |
18 | self.gamma = nn.Parameter(torch.ones(channels))
19 | self.beta = nn.Parameter(torch.zeros(channels))
20 |
21 | def forward(self, x):
22 | x = x.transpose(1, -1)
23 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
24 | return x.transpose(1, -1)
25 |
26 |
27 | @torch.jit.script
28 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
29 | n_channels_int = n_channels[0]
30 | in_act = input_a + input_b
31 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
32 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
33 | acts = t_act * s_act
34 | return acts
35 |
36 |
37 | class Encoder(nn.Module):
38 | def __init__(
39 | self,
40 | hidden_channels,
41 | filter_channels,
42 | n_heads,
43 | n_layers,
44 | kernel_size=1,
45 | p_dropout=0.0,
46 | window_size=4,
47 | isflow=True,
48 | **kwargs
49 | ):
50 | super().__init__()
51 | self.hidden_channels = hidden_channels
52 | self.filter_channels = filter_channels
53 | self.n_heads = n_heads
54 | self.n_layers = n_layers
55 | self.kernel_size = kernel_size
56 | self.p_dropout = p_dropout
57 | self.window_size = window_size
58 | # if isflow:
59 | # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
60 | # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
61 | # self.cond_layer = weight_norm(cond_layer, name='weight')
62 | # self.gin_channels = 256
63 | self.cond_layer_idx = self.n_layers
64 | if "gin_channels" in kwargs:
65 | self.gin_channels = kwargs["gin_channels"]
66 | if self.gin_channels != 0:
67 | self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
68 | # vits2 says 3rd block, so idx is 2 by default
69 | self.cond_layer_idx = (
70 | kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
71 | )
72 | # logging.debug(self.gin_channels, self.cond_layer_idx)
73 | assert (
74 | self.cond_layer_idx < self.n_layers
75 | ), "cond_layer_idx should be less than n_layers"
76 | self.drop = nn.Dropout(p_dropout)
77 | self.attn_layers = nn.ModuleList()
78 | self.norm_layers_1 = nn.ModuleList()
79 | self.ffn_layers = nn.ModuleList()
80 | self.norm_layers_2 = nn.ModuleList()
81 |
82 | for i in range(self.n_layers):
83 | self.attn_layers.append(
84 | MultiHeadAttention(
85 | hidden_channels,
86 | hidden_channels,
87 | n_heads,
88 | p_dropout=p_dropout,
89 | window_size=window_size,
90 | )
91 | )
92 | self.norm_layers_1.append(LayerNorm(hidden_channels))
93 | self.ffn_layers.append(
94 | FFN(
95 | hidden_channels,
96 | hidden_channels,
97 | filter_channels,
98 | kernel_size,
99 | p_dropout=p_dropout,
100 | )
101 | )
102 | self.norm_layers_2.append(LayerNorm(hidden_channels))
103 |
104 | def forward(self, x, x_mask, g=None):
105 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
106 | x = x * x_mask
107 | for i in range(self.n_layers):
108 | if i == self.cond_layer_idx and g is not None:
109 | g = self.spk_emb_linear(g.transpose(1, 2))
110 | g = g.transpose(1, 2)
111 | x = x + g
112 | x = x * x_mask
113 | y = self.attn_layers[i](x, x, attn_mask)
114 | y = self.drop(y)
115 | x = self.norm_layers_1[i](x + y)
116 |
117 | y = self.ffn_layers[i](x, x_mask)
118 | y = self.drop(y)
119 | x = self.norm_layers_2[i](x + y)
120 | x = x * x_mask
121 | return x
122 |
123 |
124 | class Decoder(nn.Module):
125 | def __init__(
126 | self,
127 | hidden_channels,
128 | filter_channels,
129 | n_heads,
130 | n_layers,
131 | kernel_size=1,
132 | p_dropout=0.0,
133 | proximal_bias=False,
134 | proximal_init=True,
135 | **kwargs
136 | ):
137 | super().__init__()
138 | self.hidden_channels = hidden_channels
139 | self.filter_channels = filter_channels
140 | self.n_heads = n_heads
141 | self.n_layers = n_layers
142 | self.kernel_size = kernel_size
143 | self.p_dropout = p_dropout
144 | self.proximal_bias = proximal_bias
145 | self.proximal_init = proximal_init
146 |
147 | self.drop = nn.Dropout(p_dropout)
148 | self.self_attn_layers = nn.ModuleList()
149 | self.norm_layers_0 = nn.ModuleList()
150 | self.encdec_attn_layers = nn.ModuleList()
151 | self.norm_layers_1 = nn.ModuleList()
152 | self.ffn_layers = nn.ModuleList()
153 | self.norm_layers_2 = nn.ModuleList()
154 | for i in range(self.n_layers):
155 | self.self_attn_layers.append(
156 | MultiHeadAttention(
157 | hidden_channels,
158 | hidden_channels,
159 | n_heads,
160 | p_dropout=p_dropout,
161 | proximal_bias=proximal_bias,
162 | proximal_init=proximal_init,
163 | )
164 | )
165 | self.norm_layers_0.append(LayerNorm(hidden_channels))
166 | self.encdec_attn_layers.append(
167 | MultiHeadAttention(
168 | hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
169 | )
170 | )
171 | self.norm_layers_1.append(LayerNorm(hidden_channels))
172 | self.ffn_layers.append(
173 | FFN(
174 | hidden_channels,
175 | hidden_channels,
176 | filter_channels,
177 | kernel_size,
178 | p_dropout=p_dropout,
179 | causal=True,
180 | )
181 | )
182 | self.norm_layers_2.append(LayerNorm(hidden_channels))
183 |
184 | def forward(self, x, x_mask, h, h_mask):
185 | """
186 | x: decoder input
187 | h: encoder output
188 | """
189 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
190 | device=x.device, dtype=x.dtype
191 | )
192 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
193 | x = x * x_mask
194 | for i in range(self.n_layers):
195 | y = self.self_attn_layers[i](x, x, self_attn_mask)
196 | y = self.drop(y)
197 | x = self.norm_layers_0[i](x + y)
198 |
199 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
200 | y = self.drop(y)
201 | x = self.norm_layers_1[i](x + y)
202 |
203 | y = self.ffn_layers[i](x, x_mask)
204 | y = self.drop(y)
205 | x = self.norm_layers_2[i](x + y)
206 | x = x * x_mask
207 | return x
208 |
209 |
210 | class MultiHeadAttention(nn.Module):
211 | def __init__(
212 | self,
213 | channels,
214 | out_channels,
215 | n_heads,
216 | p_dropout=0.0,
217 | window_size=None,
218 | heads_share=True,
219 | block_length=None,
220 | proximal_bias=False,
221 | proximal_init=False,
222 | ):
223 | super().__init__()
224 | assert channels % n_heads == 0
225 |
226 | self.channels = channels
227 | self.out_channels = out_channels
228 | self.n_heads = n_heads
229 | self.p_dropout = p_dropout
230 | self.window_size = window_size
231 | self.heads_share = heads_share
232 | self.block_length = block_length
233 | self.proximal_bias = proximal_bias
234 | self.proximal_init = proximal_init
235 | self.attn = None
236 |
237 | self.k_channels = channels // n_heads
238 | self.conv_q = nn.Conv1d(channels, channels, 1)
239 | self.conv_k = nn.Conv1d(channels, channels, 1)
240 | self.conv_v = nn.Conv1d(channels, channels, 1)
241 | self.conv_o = nn.Conv1d(channels, out_channels, 1)
242 | self.drop = nn.Dropout(p_dropout)
243 |
244 | if window_size is not None:
245 | n_heads_rel = 1 if heads_share else n_heads
246 | rel_stddev = self.k_channels**-0.5
247 | self.emb_rel_k = nn.Parameter(
248 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
249 | * rel_stddev
250 | )
251 | self.emb_rel_v = nn.Parameter(
252 | torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
253 | * rel_stddev
254 | )
255 |
256 | nn.init.xavier_uniform_(self.conv_q.weight)
257 | nn.init.xavier_uniform_(self.conv_k.weight)
258 | nn.init.xavier_uniform_(self.conv_v.weight)
259 | if proximal_init:
260 | with torch.no_grad():
261 | self.conv_k.weight.copy_(self.conv_q.weight)
262 | self.conv_k.bias.copy_(self.conv_q.bias)
263 |
264 | def forward(self, x, c, attn_mask=None):
265 | q = self.conv_q(x)
266 | k = self.conv_k(c)
267 | v = self.conv_v(c)
268 |
269 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
270 |
271 | x = self.conv_o(x)
272 | return x
273 |
274 | def attention(self, query, key, value, mask=None):
275 | # reshape [b, d, t] -> [b, n_h, t, d_k]
276 | b, d, t_s, t_t = (*key.size(), query.size(2))
277 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
278 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
279 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
280 |
281 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
282 | if self.window_size is not None:
283 | assert (
284 | t_s == t_t
285 | ), "Relative attention is only available for self-attention."
286 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
287 | rel_logits = self._matmul_with_relative_keys(
288 | query / math.sqrt(self.k_channels), key_relative_embeddings
289 | )
290 | scores_local = self._relative_position_to_absolute_position(rel_logits)
291 | scores = scores + scores_local
292 | if self.proximal_bias:
293 | assert t_s == t_t, "Proximal bias is only available for self-attention."
294 | scores = scores + self._attention_bias_proximal(t_s).to(
295 | device=scores.device, dtype=scores.dtype
296 | )
297 | if mask is not None:
298 | scores = scores.masked_fill(mask == 0, -1e4)
299 | if self.block_length is not None:
300 | assert (
301 | t_s == t_t
302 | ), "Local attention is only available for self-attention."
303 | block_mask = (
304 | torch.ones_like(scores)
305 | .triu(-self.block_length)
306 | .tril(self.block_length)
307 | )
308 | scores = scores.masked_fill(block_mask == 0, -1e4)
309 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
310 | p_attn = self.drop(p_attn)
311 | output = torch.matmul(p_attn, value)
312 | if self.window_size is not None:
313 | relative_weights = self._absolute_position_to_relative_position(p_attn)
314 | value_relative_embeddings = self._get_relative_embeddings(
315 | self.emb_rel_v, t_s
316 | )
317 | output = output + self._matmul_with_relative_values(
318 | relative_weights, value_relative_embeddings
319 | )
320 | output = (
321 | output.transpose(2, 3).contiguous().view(b, d, t_t)
322 | ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
323 | return output, p_attn
324 |
325 | def _matmul_with_relative_values(self, x, y):
326 | """
327 | x: [b, h, l, m]
328 | y: [h or 1, m, d]
329 | ret: [b, h, l, d]
330 | """
331 | ret = torch.matmul(x, y.unsqueeze(0))
332 | return ret
333 |
334 | def _matmul_with_relative_keys(self, x, y):
335 | """
336 | x: [b, h, l, d]
337 | y: [h or 1, m, d]
338 | ret: [b, h, l, m]
339 | """
340 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
341 | return ret
342 |
343 | def _get_relative_embeddings(self, relative_embeddings, length):
344 | # the relative-embedding table spans a window of 2 * self.window_size + 1 positions
345 | # Pad first before slice to avoid using cond ops.
346 | pad_length = max(length - (self.window_size + 1), 0)
347 | slice_start_position = max((self.window_size + 1) - length, 0)
348 | slice_end_position = slice_start_position + 2 * length - 1
349 | if pad_length > 0:
350 | padded_relative_embeddings = F.pad(
351 | relative_embeddings,
352 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
353 | )
354 | else:
355 | padded_relative_embeddings = relative_embeddings
356 | used_relative_embeddings = padded_relative_embeddings[
357 | :, slice_start_position:slice_end_position
358 | ]
359 | return used_relative_embeddings
360 |
361 | def _relative_position_to_absolute_position(self, x):
362 | """
363 | x: [b, h, l, 2*l-1]
364 | ret: [b, h, l, l]
365 | """
366 | batch, heads, length, _ = x.size()
367 | # Concat columns of pad to shift from relative to absolute indexing.
368 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
369 |
370 | # Concat extra elements so to add up to shape (len+1, 2*len-1).
371 | x_flat = x.view([batch, heads, length * 2 * length])
372 | x_flat = F.pad(
373 | x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
374 | )
375 |
376 | # Reshape and slice out the padded elements.
377 | x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
378 | :, :, :length, length - 1 :
379 | ]
380 | return x_final
381 |
382 | def _absolute_position_to_relative_position(self, x):
383 | """
384 | x: [b, h, l, l]
385 | ret: [b, h, l, 2*l-1]
386 | """
387 | batch, heads, length, _ = x.size()
388 | # pad along column
389 | x = F.pad(
390 | x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
391 | )
392 | x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
393 | # add 0's in the beginning that will skew the elements after reshape
394 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
395 | x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
396 | return x_final
397 |
398 | def _attention_bias_proximal(self, length):
399 | """Bias for self-attention to encourage attention to close positions.
400 | Args:
401 | length: an integer scalar.
402 | Returns:
403 | a Tensor with shape [1, 1, length, length]
404 | """
405 | r = torch.arange(length, dtype=torch.float32)
406 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
407 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
408 |
409 |
410 | class FFN(nn.Module):
411 | def __init__(
412 | self,
413 | in_channels,
414 | out_channels,
415 | filter_channels,
416 | kernel_size,
417 | p_dropout=0.0,
418 | activation=None,
419 | causal=False,
420 | ):
421 | super().__init__()
422 | self.in_channels = in_channels
423 | self.out_channels = out_channels
424 | self.filter_channels = filter_channels
425 | self.kernel_size = kernel_size
426 | self.p_dropout = p_dropout
427 | self.activation = activation
428 | self.causal = causal
429 |
430 | if causal:
431 | self.padding = self._causal_padding
432 | else:
433 | self.padding = self._same_padding
434 |
435 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
436 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
437 | self.drop = nn.Dropout(p_dropout)
438 |
439 | def forward(self, x, x_mask):
440 | x = self.conv_1(self.padding(x * x_mask))
441 | if self.activation == "gelu":
442 | x = x * torch.sigmoid(1.702 * x)
443 | else:
444 | x = torch.relu(x)
445 | x = self.drop(x)
446 | x = self.conv_2(self.padding(x * x_mask))
447 | return x * x_mask
448 |
449 | def _causal_padding(self, x):
450 | if self.kernel_size == 1:
451 | return x
452 | pad_l = self.kernel_size - 1
453 | pad_r = 0
454 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
455 | x = F.pad(x, commons.convert_pad_shape(padding))
456 | return x
457 |
458 | def _same_padding(self, x):
459 | if self.kernel_size == 1:
460 | return x
461 | pad_l = (self.kernel_size - 1) // 2
462 | pad_r = self.kernel_size // 2
463 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
464 | x = F.pad(x, commons.convert_pad_shape(padding))
465 | return x
466 |
--------------------------------------------------------------------------------
/OpenVoice/utils/models.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | from . import commons
7 | from . import modules
8 | from . import attentions
9 |
10 | from torch.nn import Conv1d, ConvTranspose1d, Conv2d
11 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12 |
13 | from .commons import init_weights, get_padding
14 |
15 |
16 | class TextEncoder(nn.Module):
17 | def __init__(self,
18 | n_vocab,
19 | out_channels,
20 | hidden_channels,
21 | filter_channels,
22 | n_heads,
23 | n_layers,
24 | kernel_size,
25 | p_dropout):
26 | super().__init__()
27 | self.n_vocab = n_vocab
28 | self.out_channels = out_channels
29 | self.hidden_channels = hidden_channels
30 | self.filter_channels = filter_channels
31 | self.n_heads = n_heads
32 | self.n_layers = n_layers
33 | self.kernel_size = kernel_size
34 | self.p_dropout = p_dropout
35 |
36 | self.emb = nn.Embedding(n_vocab, hidden_channels)
37 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
38 |
39 | self.encoder = attentions.Encoder(
40 | hidden_channels,
41 | filter_channels,
42 | n_heads,
43 | n_layers,
44 | kernel_size,
45 | p_dropout)
46 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
47 |
48 | def forward(self, x, x_lengths):
49 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
50 | x = torch.transpose(x, 1, -1) # [b, h, t]
51 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
52 |
53 | x = self.encoder(x * x_mask, x_mask)
54 | stats = self.proj(x) * x_mask
55 |
56 | m, logs = torch.split(stats, self.out_channels, dim=1)
57 | return x, m, logs, x_mask
58 |
59 |
60 | class DurationPredictor(nn.Module):
61 | def __init__(
62 | self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
63 | ):
64 | super().__init__()
65 |
66 | self.in_channels = in_channels
67 | self.filter_channels = filter_channels
68 | self.kernel_size = kernel_size
69 | self.p_dropout = p_dropout
70 | self.gin_channels = gin_channels
71 |
72 | self.drop = nn.Dropout(p_dropout)
73 | self.conv_1 = nn.Conv1d(
74 | in_channels, filter_channels, kernel_size, padding=kernel_size // 2
75 | )
76 | self.norm_1 = modules.LayerNorm(filter_channels)
77 | self.conv_2 = nn.Conv1d(
78 | filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
79 | )
80 | self.norm_2 = modules.LayerNorm(filter_channels)
81 | self.proj = nn.Conv1d(filter_channels, 1, 1)
82 |
83 | if gin_channels != 0:
84 | self.cond = nn.Conv1d(gin_channels, in_channels, 1)
85 |
86 | def forward(self, x, x_mask, g=None):
87 | x = torch.detach(x)
88 | if g is not None:
89 | g = torch.detach(g)
90 | x = x + self.cond(g)
91 | x = self.conv_1(x * x_mask)
92 | x = torch.relu(x)
93 | x = self.norm_1(x)
94 | x = self.drop(x)
95 | x = self.conv_2(x * x_mask)
96 | x = torch.relu(x)
97 | x = self.norm_2(x)
98 | x = self.drop(x)
99 | x = self.proj(x * x_mask)
100 | return x * x_mask
101 |
102 | class StochasticDurationPredictor(nn.Module):
103 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
104 | super().__init__()
105 | filter_channels = in_channels  # TODO: this override should be removed in a future version
106 | self.in_channels = in_channels
107 | self.filter_channels = filter_channels
108 | self.kernel_size = kernel_size
109 | self.p_dropout = p_dropout
110 | self.n_flows = n_flows
111 | self.gin_channels = gin_channels
112 |
113 | self.log_flow = modules.Log()
114 | self.flows = nn.ModuleList()
115 | self.flows.append(modules.ElementwiseAffine(2))
116 | for i in range(n_flows):
117 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
118 | self.flows.append(modules.Flip())
119 |
120 | self.post_pre = nn.Conv1d(1, filter_channels, 1)
121 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
122 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
123 | self.post_flows = nn.ModuleList()
124 | self.post_flows.append(modules.ElementwiseAffine(2))
125 | for i in range(4):
126 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
127 | self.post_flows.append(modules.Flip())
128 |
129 | self.pre = nn.Conv1d(in_channels, filter_channels, 1)
130 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
131 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
132 | if gin_channels != 0:
133 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
134 |
135 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
136 | x = torch.detach(x)
137 | x = self.pre(x)
138 | if g is not None:
139 | g = torch.detach(g)
140 | x = x + self.cond(g)
141 | x = self.convs(x, x_mask)
142 | x = self.proj(x) * x_mask
143 |
144 | if not reverse:
145 | flows = self.flows
146 | assert w is not None
147 |
148 | logdet_tot_q = 0
149 | h_w = self.post_pre(w)
150 | h_w = self.post_convs(h_w, x_mask)
151 | h_w = self.post_proj(h_w) * x_mask
152 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
153 | z_q = e_q
154 | for flow in self.post_flows:
155 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
156 | logdet_tot_q += logdet_q
157 | z_u, z1 = torch.split(z_q, [1, 1], 1)
158 | u = torch.sigmoid(z_u) * x_mask
159 | z0 = (w - u) * x_mask
160 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
161 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
162 |
163 | logdet_tot = 0
164 | z0, logdet = self.log_flow(z0, x_mask)
165 | logdet_tot += logdet
166 | z = torch.cat([z0, z1], 1)
167 | for flow in flows:
168 | z, logdet = flow(z, x_mask, g=x, reverse=reverse)
169 | logdet_tot = logdet_tot + logdet
170 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
171 | return nll + logq # [b]
172 | else:
173 | flows = list(reversed(self.flows))
174 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow
175 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
176 | for flow in flows:
177 | z = flow(z, x_mask, g=x, reverse=reverse)
178 | z0, z1 = torch.split(z, [1, 1], 1)
179 | logw = z0
180 | return logw
181 |
182 | class PosteriorEncoder(nn.Module):
183 | def __init__(
184 | self,
185 | in_channels,
186 | out_channels,
187 | hidden_channels,
188 | kernel_size,
189 | dilation_rate,
190 | n_layers,
191 | gin_channels=0,
192 | ):
193 | super().__init__()
194 | self.in_channels = in_channels
195 | self.out_channels = out_channels
196 | self.hidden_channels = hidden_channels
197 | self.kernel_size = kernel_size
198 | self.dilation_rate = dilation_rate
199 | self.n_layers = n_layers
200 | self.gin_channels = gin_channels
201 |
202 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
203 | self.enc = modules.WN(
204 | hidden_channels,
205 | kernel_size,
206 | dilation_rate,
207 | n_layers,
208 | gin_channels=gin_channels,
209 | )
210 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
211 |
212 | def forward(self, x, x_lengths, g=None, tau=1.0):
213 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
214 | x.dtype
215 | )
216 | x = self.pre(x) * x_mask
217 | x = self.enc(x, x_mask, g=g)
218 | stats = self.proj(x) * x_mask
219 | m, logs = torch.split(stats, self.out_channels, dim=1)
220 | z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask
221 | return z, m, logs, x_mask
222 |
223 |
224 | class Generator(torch.nn.Module):
225 | def __init__(
226 | self,
227 | initial_channel,
228 | resblock,
229 | resblock_kernel_sizes,
230 | resblock_dilation_sizes,
231 | upsample_rates,
232 | upsample_initial_channel,
233 | upsample_kernel_sizes,
234 | gin_channels=0,
235 | ):
236 | super(Generator, self).__init__()
237 | self.num_kernels = len(resblock_kernel_sizes)
238 | self.num_upsamples = len(upsample_rates)
239 | self.conv_pre = Conv1d(
240 | initial_channel, upsample_initial_channel, 7, 1, padding=3
241 | )
242 | resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
243 |
244 | self.ups = nn.ModuleList()
245 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
246 | self.ups.append(
247 | weight_norm(
248 | ConvTranspose1d(
249 | upsample_initial_channel // (2**i),
250 | upsample_initial_channel // (2 ** (i + 1)),
251 | k,
252 | u,
253 | padding=(k - u) // 2,
254 | )
255 | )
256 | )
257 |
258 | self.resblocks = nn.ModuleList()
259 | for i in range(len(self.ups)):
260 | ch = upsample_initial_channel // (2 ** (i + 1))
261 | for j, (k, d) in enumerate(
262 | zip(resblock_kernel_sizes, resblock_dilation_sizes)
263 | ):
264 | self.resblocks.append(resblock(ch, k, d))
265 |
266 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
267 | self.ups.apply(init_weights)
268 |
269 | if gin_channels != 0:
270 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
271 |
272 | def forward(self, x, g=None):
273 | x = self.conv_pre(x)
274 | if g is not None:
275 | x = x + self.cond(g)
276 |
277 | for i in range(self.num_upsamples):
278 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
279 | x = self.ups[i](x)
280 | xs = None
281 | for j in range(self.num_kernels):
282 | if xs is None:
283 | xs = self.resblocks[i * self.num_kernels + j](x)
284 | else:
285 | xs += self.resblocks[i * self.num_kernels + j](x)
286 | x = xs / self.num_kernels
287 | x = F.leaky_relu(x)
288 | x = self.conv_post(x)
289 | x = torch.tanh(x)
290 |
291 | return x
292 |
293 | def remove_weight_norm(self):
294 | print("Removing weight norm...")
295 | for layer in self.ups:
296 | remove_weight_norm(layer)
297 | for layer in self.resblocks:
298 | layer.remove_weight_norm()
299 |
300 |
301 | class ReferenceEncoder(nn.Module):
302 | """
303 | inputs --- [N, Ty/r, n_mels*r] mels
304 | outputs --- [N, ref_enc_gru_size]
305 | """
306 |
307 | def __init__(self, spec_channels, gin_channels=0, layernorm=True):
308 | super().__init__()
309 | self.spec_channels = spec_channels
310 | ref_enc_filters = [32, 32, 64, 64, 128, 128]
311 | K = len(ref_enc_filters)
312 | filters = [1] + ref_enc_filters
313 | convs = [
314 | weight_norm(
315 | nn.Conv2d(
316 | in_channels=filters[i],
317 | out_channels=filters[i + 1],
318 | kernel_size=(3, 3),
319 | stride=(2, 2),
320 | padding=(1, 1),
321 | )
322 | )
323 | for i in range(K)
324 | ]
325 | self.convs = nn.ModuleList(convs)
326 |
327 | out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
328 | self.gru = nn.GRU(
329 | input_size=ref_enc_filters[-1] * out_channels,
330 | hidden_size=256 // 2,
331 | batch_first=True,
332 | )
333 | self.proj = nn.Linear(128, gin_channels)
334 | if layernorm:
335 | self.layernorm = nn.LayerNorm(self.spec_channels)
336 | else:
337 | self.layernorm = None
338 |
339 | def forward(self, inputs, mask=None):
340 | N = inputs.size(0)
341 |
342 | out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
343 | if self.layernorm is not None:
344 | out = self.layernorm(out)
345 |
346 | for conv in self.convs:
347 | out = conv(out)
348 | # out = wn(out)
349 | out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
350 |
351 | out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
352 | T = out.size(1)
353 | N = out.size(0)
354 | out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
355 |
356 | self.gru.flatten_parameters()
357 | memory, out = self.gru(out) # out --- [1, N, 128]
358 |
359 | return self.proj(out.squeeze(0))
360 |
361 | def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
362 | for i in range(n_convs):
363 | L = (L - kernel_size + 2 * pad) // stride + 1
364 | return L
365 |
366 |
367 | class ResidualCouplingBlock(nn.Module):
368 | def __init__(self,
369 | channels,
370 | hidden_channels,
371 | kernel_size,
372 | dilation_rate,
373 | n_layers,
374 | n_flows=4,
375 | gin_channels=0):
376 | super().__init__()
377 | self.channels = channels
378 | self.hidden_channels = hidden_channels
379 | self.kernel_size = kernel_size
380 | self.dilation_rate = dilation_rate
381 | self.n_layers = n_layers
382 | self.n_flows = n_flows
383 | self.gin_channels = gin_channels
384 |
385 | self.flows = nn.ModuleList()
386 | for i in range(n_flows):
387 | self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
388 | self.flows.append(modules.Flip())
389 |
390 | def forward(self, x, x_mask, g=None, reverse=False):
391 | if not reverse:
392 | for flow in self.flows:
393 | x, _ = flow(x, x_mask, g=g, reverse=reverse)
394 | else:
395 | for flow in reversed(self.flows):
396 | x = flow(x, x_mask, g=g, reverse=reverse)
397 | return x
398 |
399 | class SynthesizerTrn(nn.Module):
400 | """
401 | Synthesizer for Training
402 | """
403 |
404 | def __init__(
405 | self,
406 | n_vocab,
407 | spec_channels,
408 | inter_channels,
409 | hidden_channels,
410 | filter_channels,
411 | n_heads,
412 | n_layers,
413 | kernel_size,
414 | p_dropout,
415 | resblock,
416 | resblock_kernel_sizes,
417 | resblock_dilation_sizes,
418 | upsample_rates,
419 | upsample_initial_channel,
420 | upsample_kernel_sizes,
421 | n_speakers=256,
422 | gin_channels=256,
423 | **kwargs
424 | ):
425 | super().__init__()
426 |
427 | self.dec = Generator(
428 | inter_channels,
429 | resblock,
430 | resblock_kernel_sizes,
431 | resblock_dilation_sizes,
432 | upsample_rates,
433 | upsample_initial_channel,
434 | upsample_kernel_sizes,
435 | gin_channels=gin_channels,
436 | )
437 | self.enc_q = PosteriorEncoder(
438 | spec_channels,
439 | inter_channels,
440 | hidden_channels,
441 | 5,
442 | 1,
443 | 16,
444 | gin_channels=gin_channels,
445 | )
446 |
447 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
448 |
449 | self.n_speakers = n_speakers
450 | if n_speakers == 0:
451 | self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
452 | else:
453 | self.enc_p = TextEncoder(n_vocab,
454 | inter_channels,
455 | hidden_channels,
456 | filter_channels,
457 | n_heads,
458 | n_layers,
459 | kernel_size,
460 | p_dropout)
461 | self.sdp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
462 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
463 | self.emb_g = nn.Embedding(n_speakers, gin_channels)
464 |
465 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., sdp_ratio=0.2, max_len=None):
466 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
467 | if self.n_speakers > 0:
468 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
469 | else:
470 | g = None
471 |
472 | logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * sdp_ratio \
473 | + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
474 |
475 | w = torch.exp(logw) * x_mask * length_scale
476 | w_ceil = torch.ceil(w)
477 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
478 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
479 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
480 | attn = commons.generate_path(w_ceil, attn_mask)
481 |
482 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
483 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
484 |
485 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
486 | z = self.flow(z_p, y_mask, g=g, reverse=True)
487 | o = self.dec((z * y_mask)[:,:,:max_len], g=g)
488 | return o, attn, y_mask, (z, z_p, m_p, logs_p)
489 |
490 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
491 | g_src = sid_src  # source speaker (tone color) embedding
492 | g_tgt = sid_tgt  # target speaker (tone color) embedding
493 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau)  # posterior encoder: source spectrogram -> latent z
494 | z_p = self.flow(z, y_mask, g=g_src)  # forward flow conditioned on the source embedding
495 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)  # inverse flow conditioned on the target embedding
496 | o_hat = self.dec(z_hat * y_mask, g=g_tgt)  # decode the converted waveform with the target embedding
497 | return o_hat, y_mask, (z, z_p, z_hat)
498 |
--------------------------------------------------------------------------------
/OpenVoice/utils/modules.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 |
6 | from torch.nn import Conv1d
7 | from torch.nn.utils import weight_norm, remove_weight_norm
8 |
9 | from . import commons
10 | from .commons import init_weights, get_padding
11 | from .transforms import piecewise_rational_quadratic_transform
12 | from .attentions import Encoder
13 |
14 | LRELU_SLOPE = 0.1
15 |
16 |
17 | class LayerNorm(nn.Module):
18 | def __init__(self, channels, eps=1e-5):
19 | super().__init__()
20 | self.channels = channels
21 | self.eps = eps
22 |
23 | self.gamma = nn.Parameter(torch.ones(channels))
24 | self.beta = nn.Parameter(torch.zeros(channels))
25 |
26 | def forward(self, x):
27 | x = x.transpose(1, -1)
28 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
29 | return x.transpose(1, -1)
30 |
31 |
32 | class ConvReluNorm(nn.Module):
33 | def __init__(
34 | self,
35 | in_channels,
36 | hidden_channels,
37 | out_channels,
38 | kernel_size,
39 | n_layers,
40 | p_dropout,
41 | ):
42 | super().__init__()
43 | self.in_channels = in_channels
44 | self.hidden_channels = hidden_channels
45 | self.out_channels = out_channels
46 | self.kernel_size = kernel_size
47 | self.n_layers = n_layers
48 | self.p_dropout = p_dropout
49 | assert n_layers > 1, "Number of layers should be larger than 1."
50 |
51 | self.conv_layers = nn.ModuleList()
52 | self.norm_layers = nn.ModuleList()
53 | self.conv_layers.append(
54 | nn.Conv1d(
55 | in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
56 | )
57 | )
58 | self.norm_layers.append(LayerNorm(hidden_channels))
59 | self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
60 | for _ in range(n_layers - 1):
61 | self.conv_layers.append(
62 | nn.Conv1d(
63 | hidden_channels,
64 | hidden_channels,
65 | kernel_size,
66 | padding=kernel_size // 2,
67 | )
68 | )
69 | self.norm_layers.append(LayerNorm(hidden_channels))
70 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
71 | self.proj.weight.data.zero_()
72 | self.proj.bias.data.zero_()
73 |
74 | def forward(self, x, x_mask):
75 | x_org = x
76 | for i in range(self.n_layers):
77 | x = self.conv_layers[i](x * x_mask)
78 | x = self.norm_layers[i](x)
79 | x = self.relu_drop(x)
80 | x = x_org + self.proj(x)
81 | return x * x_mask
82 |
83 |
84 | class DDSConv(nn.Module):
85 | """
86 | Dilated and Depth-Separable Convolution
87 | """
88 |
89 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
90 | super().__init__()
91 | self.channels = channels
92 | self.kernel_size = kernel_size
93 | self.n_layers = n_layers
94 | self.p_dropout = p_dropout
95 |
96 | self.drop = nn.Dropout(p_dropout)
97 | self.convs_sep = nn.ModuleList()
98 | self.convs_1x1 = nn.ModuleList()
99 | self.norms_1 = nn.ModuleList()
100 | self.norms_2 = nn.ModuleList()
101 | for i in range(n_layers):
102 | dilation = kernel_size**i
103 | padding = (kernel_size * dilation - dilation) // 2
104 | self.convs_sep.append(
105 | nn.Conv1d(
106 | channels,
107 | channels,
108 | kernel_size,
109 | groups=channels,
110 | dilation=dilation,
111 | padding=padding,
112 | )
113 | )
114 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
115 | self.norms_1.append(LayerNorm(channels))
116 | self.norms_2.append(LayerNorm(channels))
117 |
118 | def forward(self, x, x_mask, g=None):
119 | if g is not None:
120 | x = x + g
121 | for i in range(self.n_layers):
122 | y = self.convs_sep[i](x * x_mask)
123 | y = self.norms_1[i](y)
124 | y = F.gelu(y)
125 | y = self.convs_1x1[i](y)
126 | y = self.norms_2[i](y)
127 | y = F.gelu(y)
128 | y = self.drop(y)
129 | x = x + y
130 | return x * x_mask
131 |
132 |
133 | class WN(torch.nn.Module):
134 | def __init__(
135 | self,
136 | hidden_channels,
137 | kernel_size,
138 | dilation_rate,
139 | n_layers,
140 | gin_channels=0,
141 | p_dropout=0,
142 | ):
143 | super(WN, self).__init__()
144 | assert kernel_size % 2 == 1
145 | self.hidden_channels = hidden_channels
146 | self.kernel_size = kernel_size
147 | self.dilation_rate = dilation_rate
148 | self.n_layers = n_layers
149 | self.gin_channels = gin_channels
150 | self.p_dropout = p_dropout
151 |
152 | self.in_layers = torch.nn.ModuleList()
153 | self.res_skip_layers = torch.nn.ModuleList()
154 | self.drop = nn.Dropout(p_dropout)
155 |
156 | if gin_channels != 0:
157 | cond_layer = torch.nn.Conv1d(
158 | gin_channels, 2 * hidden_channels * n_layers, 1
159 | )
160 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
161 |
162 | for i in range(n_layers):
163 | dilation = dilation_rate**i
164 | padding = int((kernel_size * dilation - dilation) / 2)
165 | in_layer = torch.nn.Conv1d(
166 | hidden_channels,
167 | 2 * hidden_channels,
168 | kernel_size,
169 | dilation=dilation,
170 | padding=padding,
171 | )
172 | in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
173 | self.in_layers.append(in_layer)
174 |
175 | # last one is not necessary
176 | if i < n_layers - 1:
177 | res_skip_channels = 2 * hidden_channels
178 | else:
179 | res_skip_channels = hidden_channels
180 |
181 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
182 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
183 | self.res_skip_layers.append(res_skip_layer)
184 |
185 | def forward(self, x, x_mask, g=None, **kwargs):
186 | output = torch.zeros_like(x)
187 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
188 |
189 | if g is not None:
190 | g = self.cond_layer(g)
191 |
192 | for i in range(self.n_layers):
193 | x_in = self.in_layers[i](x)
194 | if g is not None:
195 | cond_offset = i * 2 * self.hidden_channels
196 | g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
197 | else:
198 | g_l = torch.zeros_like(x_in)
199 |
200 | acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
201 | acts = self.drop(acts)
202 |
203 | res_skip_acts = self.res_skip_layers[i](acts)
204 | if i < self.n_layers - 1:
205 | res_acts = res_skip_acts[:, : self.hidden_channels, :]
206 | x = (x + res_acts) * x_mask
207 | output = output + res_skip_acts[:, self.hidden_channels :, :]
208 | else:
209 | output = output + res_skip_acts
210 | return output * x_mask
211 |
212 | def remove_weight_norm(self):
213 | if self.gin_channels != 0:
214 | torch.nn.utils.remove_weight_norm(self.cond_layer)
215 | for l in self.in_layers:
216 | torch.nn.utils.remove_weight_norm(l)
217 | for l in self.res_skip_layers:
218 | torch.nn.utils.remove_weight_norm(l)
219 |
220 |
221 | class ResBlock1(torch.nn.Module):
222 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
223 | super(ResBlock1, self).__init__()
224 | self.convs1 = nn.ModuleList(
225 | [
226 | weight_norm(
227 | Conv1d(
228 | channels,
229 | channels,
230 | kernel_size,
231 | 1,
232 | dilation=dilation[0],
233 | padding=get_padding(kernel_size, dilation[0]),
234 | )
235 | ),
236 | weight_norm(
237 | Conv1d(
238 | channels,
239 | channels,
240 | kernel_size,
241 | 1,
242 | dilation=dilation[1],
243 | padding=get_padding(kernel_size, dilation[1]),
244 | )
245 | ),
246 | weight_norm(
247 | Conv1d(
248 | channels,
249 | channels,
250 | kernel_size,
251 | 1,
252 | dilation=dilation[2],
253 | padding=get_padding(kernel_size, dilation[2]),
254 | )
255 | ),
256 | ]
257 | )
258 | self.convs1.apply(init_weights)
259 |
260 | self.convs2 = nn.ModuleList(
261 | [
262 | weight_norm(
263 | Conv1d(
264 | channels,
265 | channels,
266 | kernel_size,
267 | 1,
268 | dilation=1,
269 | padding=get_padding(kernel_size, 1),
270 | )
271 | ),
272 | weight_norm(
273 | Conv1d(
274 | channels,
275 | channels,
276 | kernel_size,
277 | 1,
278 | dilation=1,
279 | padding=get_padding(kernel_size, 1),
280 | )
281 | ),
282 | weight_norm(
283 | Conv1d(
284 | channels,
285 | channels,
286 | kernel_size,
287 | 1,
288 | dilation=1,
289 | padding=get_padding(kernel_size, 1),
290 | )
291 | ),
292 | ]
293 | )
294 | self.convs2.apply(init_weights)
295 |
296 | def forward(self, x, x_mask=None):
297 | for c1, c2 in zip(self.convs1, self.convs2):
298 | xt = F.leaky_relu(x, LRELU_SLOPE)
299 | if x_mask is not None:
300 | xt = xt * x_mask
301 | xt = c1(xt)
302 | xt = F.leaky_relu(xt, LRELU_SLOPE)
303 | if x_mask is not None:
304 | xt = xt * x_mask
305 | xt = c2(xt)
306 | x = xt + x
307 | if x_mask is not None:
308 | x = x * x_mask
309 | return x
310 |
311 | def remove_weight_norm(self):
312 | for l in self.convs1:
313 | remove_weight_norm(l)
314 | for l in self.convs2:
315 | remove_weight_norm(l)
316 |
317 |
318 | class ResBlock2(torch.nn.Module):
319 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
320 | super(ResBlock2, self).__init__()
321 | self.convs = nn.ModuleList(
322 | [
323 | weight_norm(
324 | Conv1d(
325 | channels,
326 | channels,
327 | kernel_size,
328 | 1,
329 | dilation=dilation[0],
330 | padding=get_padding(kernel_size, dilation[0]),
331 | )
332 | ),
333 | weight_norm(
334 | Conv1d(
335 | channels,
336 | channels,
337 | kernel_size,
338 | 1,
339 | dilation=dilation[1],
340 | padding=get_padding(kernel_size, dilation[1]),
341 | )
342 | ),
343 | ]
344 | )
345 | self.convs.apply(init_weights)
346 |
347 | def forward(self, x, x_mask=None):
348 | for c in self.convs:
349 | xt = F.leaky_relu(x, LRELU_SLOPE)
350 | if x_mask is not None:
351 | xt = xt * x_mask
352 | xt = c(xt)
353 | x = xt + x
354 | if x_mask is not None:
355 | x = x * x_mask
356 | return x
357 |
358 | def remove_weight_norm(self):
359 | for l in self.convs:
360 | remove_weight_norm(l)
361 |
362 |
363 | class Log(nn.Module):
364 | def forward(self, x, x_mask, reverse=False, **kwargs):
365 | if not reverse:
366 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
367 | logdet = torch.sum(-y, [1, 2])
368 | return y, logdet
369 | else:
370 | x = torch.exp(x) * x_mask
371 | return x
372 |
373 |
374 | class Flip(nn.Module):
375 | def forward(self, x, *args, reverse=False, **kwargs):
376 | x = torch.flip(x, [1])
377 | if not reverse:
378 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
379 | return x, logdet
380 | else:
381 | return x
382 |
383 |
384 | class ElementwiseAffine(nn.Module):
385 | def __init__(self, channels):
386 | super().__init__()
387 | self.channels = channels
388 | self.m = nn.Parameter(torch.zeros(channels, 1))
389 | self.logs = nn.Parameter(torch.zeros(channels, 1))
390 |
391 | def forward(self, x, x_mask, reverse=False, **kwargs):
392 | if not reverse:
393 | y = self.m + torch.exp(self.logs) * x
394 | y = y * x_mask
395 | logdet = torch.sum(self.logs * x_mask, [1, 2])
396 | return y, logdet
397 | else:
398 | x = (x - self.m) * torch.exp(-self.logs) * x_mask
399 | return x
400 |
401 |
402 | class ResidualCouplingLayer(nn.Module):
403 | def __init__(
404 | self,
405 | channels,
406 | hidden_channels,
407 | kernel_size,
408 | dilation_rate,
409 | n_layers,
410 | p_dropout=0,
411 | gin_channels=0,
412 | mean_only=False,
413 | ):
414 | assert channels % 2 == 0, "channels should be divisible by 2"
415 | super().__init__()
416 | self.channels = channels
417 | self.hidden_channels = hidden_channels
418 | self.kernel_size = kernel_size
419 | self.dilation_rate = dilation_rate
420 | self.n_layers = n_layers
421 | self.half_channels = channels // 2
422 | self.mean_only = mean_only
423 |
424 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
425 | self.enc = WN(
426 | hidden_channels,
427 | kernel_size,
428 | dilation_rate,
429 | n_layers,
430 | p_dropout=p_dropout,
431 | gin_channels=gin_channels,
432 | )
433 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
434 | self.post.weight.data.zero_()
435 | self.post.bias.data.zero_()
436 |
437 | def forward(self, x, x_mask, g=None, reverse=False):
438 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
439 | h = self.pre(x0) * x_mask
440 | h = self.enc(h, x_mask, g=g)
441 | stats = self.post(h) * x_mask
442 | if not self.mean_only:
443 | m, logs = torch.split(stats, [self.half_channels] * 2, 1)
444 | else:
445 | m = stats
446 | logs = torch.zeros_like(m)
447 |
448 | if not reverse:
449 | x1 = m + x1 * torch.exp(logs) * x_mask
450 | x = torch.cat([x0, x1], 1)
451 | logdet = torch.sum(logs, [1, 2])
452 | return x, logdet
453 | else:
454 | x1 = (x1 - m) * torch.exp(-logs) * x_mask
455 | x = torch.cat([x0, x1], 1)
456 | return x
457 |
458 |
459 | class ConvFlow(nn.Module):
460 | def __init__(
461 | self,
462 | in_channels,
463 | filter_channels,
464 | kernel_size,
465 | n_layers,
466 | num_bins=10,
467 | tail_bound=5.0,
468 | ):
469 | super().__init__()
470 | self.in_channels = in_channels
471 | self.filter_channels = filter_channels
472 | self.kernel_size = kernel_size
473 | self.n_layers = n_layers
474 | self.num_bins = num_bins
475 | self.tail_bound = tail_bound
476 | self.half_channels = in_channels // 2
477 |
478 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
479 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
480 | self.proj = nn.Conv1d(
481 | filter_channels, self.half_channels * (num_bins * 3 - 1), 1
482 | )
483 | self.proj.weight.data.zero_()
484 | self.proj.bias.data.zero_()
485 |
486 | def forward(self, x, x_mask, g=None, reverse=False):
487 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
488 | h = self.pre(x0)
489 | h = self.convs(h, x_mask, g=g)
490 | h = self.proj(h) * x_mask
491 |
492 | b, c, t = x0.shape
493 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
494 |
495 | unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
496 | unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
497 | self.filter_channels
498 | )
499 | unnormalized_derivatives = h[..., 2 * self.num_bins :]
500 |
501 | x1, logabsdet = piecewise_rational_quadratic_transform(
502 | x1,
503 | unnormalized_widths,
504 | unnormalized_heights,
505 | unnormalized_derivatives,
506 | inverse=reverse,
507 | tails="linear",
508 | tail_bound=self.tail_bound,
509 | )
510 |
511 | x = torch.cat([x0, x1], 1) * x_mask
512 | logdet = torch.sum(logabsdet * x_mask, [1, 2])
513 | if not reverse:
514 | return x, logdet
515 | else:
516 | return x
517 |
518 |
519 | class TransformerCouplingLayer(nn.Module):
520 | def __init__(
521 | self,
522 | channels,
523 | hidden_channels,
524 | kernel_size,
525 | n_layers,
526 | n_heads,
527 | p_dropout=0,
528 | filter_channels=0,
529 | mean_only=False,
530 | wn_sharing_parameter=None,
531 | gin_channels=0,
532 | ):
533 | assert n_layers == 3, n_layers
534 | assert channels % 2 == 0, "channels should be divisible by 2"
535 | super().__init__()
536 | self.channels = channels
537 | self.hidden_channels = hidden_channels
538 | self.kernel_size = kernel_size
539 | self.n_layers = n_layers
540 | self.half_channels = channels // 2
541 | self.mean_only = mean_only
542 |
543 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
544 | self.enc = (
545 | Encoder(
546 | hidden_channels,
547 | filter_channels,
548 | n_heads,
549 | n_layers,
550 | kernel_size,
551 | p_dropout,
552 | isflow=True,
553 | gin_channels=gin_channels,
554 | )
555 | if wn_sharing_parameter is None
556 | else wn_sharing_parameter
557 | )
558 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
559 | self.post.weight.data.zero_()
560 | self.post.bias.data.zero_()
561 |
562 | def forward(self, x, x_mask, g=None, reverse=False):
563 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
564 | h = self.pre(x0) * x_mask
565 | h = self.enc(h, x_mask, g=g)
566 | stats = self.post(h) * x_mask
567 | if not self.mean_only:
568 | m, logs = torch.split(stats, [self.half_channels] * 2, 1)
569 | else:
570 | m = stats
571 | logs = torch.zeros_like(m)
572 |
573 | if not reverse:
574 | x1 = m + x1 * torch.exp(logs) * x_mask
575 | x = torch.cat([x0, x1], 1)
576 | logdet = torch.sum(logs, [1, 2])
577 | return x, logdet
578 | else:
579 | x1 = (x1 - m) * torch.exp(-logs) * x_mask
580 | x = torch.cat([x0, x1], 1)
581 | return x
582 |
599 |
--------------------------------------------------------------------------------