├── .gitattributes
├── demo.wav
├── tensorboard.png
├── .gitmodules
├── requirements.txt
├── Dockerfile
├── preprocess.py
├── multiproc.py
├── loss_function.py
├── timerthread.py
├── text
│   ├── symbols.py
│   ├── LICENSE
│   ├── cmudict.py
│   ├── numbers.py
│   ├── __init__.py
│   └── cleaners.py
├── utils.py
├── LICENSE
├── plotting_utils.py
├── logger.py
├── .gitignore
├── audio_processing.py
├── ui.py
├── hparams.py
├── layers.py
├── data_utils.py
├── loss_scaler.py
├── README.md
├── stft.py
├── switch.py
├── distributed.py
├── filelists
│   └── ljs_audio_text_val_filelist.txt
├── train.py
├── nvidia_tacotron_TTS_Layout.py
├── model.py
├── utils_hparam.py
└── gui.py

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-language=Python

--------------------------------------------------------------------------------
/demo.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lokkelvin2/tacotron2-tts-GUI/HEAD/demo.wav

--------------------------------------------------------------------------------
/tensorboard.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lokkelvin2/tacotron2-tts-GUI/HEAD/tensorboard.png

--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "waveglow"]
2 | 	path = waveglow
3 | 	url = https://github.com/lokkelvin2/waveglow_GUI
4 | 	branch = gui

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib
2 | requests
3 | numpy
4 | inflect
5 | librosa
6 | scipy
7 | Unidecode
8 | pillow
9 | pygame
10 | pyqt5==5.15.0
11 | numba==0.48
12 | tqdm
13 | num2words  # used by preprocess.py

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:nightly-devel-cuda10.0-cudnn7
2 | ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}
3 | 
4 | RUN apt-get update -y
5 | 
6 | RUN pip install numpy scipy matplotlib librosa==0.6.0 tensorflow tensorboardX inflect==0.2.5 Unidecode==1.0.22 pillow jupyter
7 | 
8 | ADD apex /apex/
9 | WORKDIR /apex/
10 | RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .

--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import re
2 | import num2words
3 | import textwrap
4 | 
5 | _MAXIMUM_ALLOWED_LENGTH = 140
6 | 
7 | 
8 | def add_fullstop(text):
9 |     out = text+'.' if text[-1]!='.' else text
10 |     return out
11 | 
12 | def break_long_sentences(text):
13 |     return textwrap.wrap(text, _MAXIMUM_ALLOWED_LENGTH, break_long_words=False)
14 | 
15 | def preprocess_text(text):
16 |     '''
17 |     Takes in a string, replaces numbers with words, and wraps the string into
18 |     a list of lines of at most _MAXIMUM_ALLOWED_LENGTH characters.
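    Example (illustrative, assuming num2words defaults):
        >>> preprocess_text('I have 3 cats')
        ['I have three cats.']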
19 | ''' 20 | text = re.sub(r"(\d+)", lambda x: num2words.num2words(int(x.group(0))), text) 21 | lines = break_long_sentences(text) 22 | if lines: 23 | lines[-1] = add_fullstop(lines[-1]) 24 | return lines -------------------------------------------------------------------------------- /multiproc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import sys 4 | import subprocess 5 | 6 | argslist = list(sys.argv)[1:] 7 | num_gpus = torch.cuda.device_count() 8 | argslist.append('--n_gpus={}'.format(num_gpus)) 9 | workers = [] 10 | job_id = time.strftime("%Y_%m_%d-%H%M%S") 11 | argslist.append("--group_name=group_{}".format(job_id)) 12 | 13 | for i in range(num_gpus): 14 | argslist.append('--rank={}'.format(i)) 15 | stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i), 16 | "w") 17 | print(argslist) 18 | p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) 19 | workers.append(p) 20 | argslist = argslist[:-1] 21 | 22 | for p in workers: 23 | p.wait() 24 | -------------------------------------------------------------------------------- /loss_function.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Tacotron2Loss(nn.Module): 5 | def __init__(self): 6 | super(Tacotron2Loss, self).__init__() 7 | 8 | def forward(self, model_output, targets): 9 | mel_target, gate_target = targets[0], targets[1] 10 | mel_target.requires_grad = False 11 | gate_target.requires_grad = False 12 | gate_target = gate_target.view(-1, 1) 13 | 14 | mel_out, mel_out_postnet, gate_out, _ = model_output 15 | gate_out = gate_out.view(-1, 1) 16 | mel_loss = nn.MSELoss()(mel_out, mel_target) + \ 17 | nn.MSELoss()(mel_out_postnet, mel_target) 18 | gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) 19 | return mel_loss + gate_loss 20 | -------------------------------------------------------------------------------- /timerthread.py: -------------------------------------------------------------------------------- 1 | # from https://stackoverflow.com/a/14369192 2 | from PyQt5 import QtCore, QtGui 3 | from PyQt5.QtCore import QObject, pyqtSignal, pyqtSlot, QThread 4 | import time 5 | 6 | class timerThread(QThread): 7 | timeElapsed = pyqtSignal(int) 8 | 9 | def __init__(self, timeoffset, parent=None): 10 | super(timerThread, self).__init__(parent) 11 | self.timeoffset = timeoffset 12 | self.timeStart = None 13 | 14 | 15 | def start(self, timeStart): 16 | self.timeStart = timeStart 17 | 18 | return super(timerThread, self).start() 19 | 20 | def run(self): 21 | while self.parent().isRunning(): 22 | self.timeElapsed.emit(time.time() - self.timeStart + self.timeoffset) 23 | time.sleep(1) 24 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' 7 | from text import cmudict 8 | 9 | _pad = '_' 10 | _punctuation = '!\'(),.:;? 
' 11 | _special = '-' 12 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 13 | 14 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 15 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 16 | 17 | # Export all symbols: 18 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet 19 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io.wavfile import read 3 | import torch 4 | 5 | 6 | def get_mask_from_lengths(lengths): 7 | max_len = torch.max(lengths).item() 8 | ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) 9 | mask = (ids < lengths.unsqueeze(1)).bool() 10 | return mask 11 | 12 | 13 | def load_wav_to_torch(full_path): 14 | sampling_rate, data = read(full_path) 15 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 16 | 17 | 18 | def load_filepaths_and_text(filename, split="|"): 19 | with open(filename, encoding='utf-8') as f: 20 | filepaths_and_text = [line.strip().split(split) for line in f] 21 | return filepaths_and_text 22 | 23 | 24 | def to_gpu(x): 25 | x = x.contiguous() 26 | 27 | if torch.cuda.is_available(): 28 | x = x.cuda(non_blocking=True) 29 | return torch.autograd.Variable(x) 30 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, NVIDIA Corporation 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /plotting_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pylab as plt 4 | import numpy as np 5 | 6 | 7 | def save_figure_to_numpy(fig): 8 | # save it to a numpy array. 9 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 10 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 11 | return data 12 | 13 | 14 | def plot_alignment_to_numpy(alignment, info=None): 15 | fig, ax = plt.subplots(figsize=(6, 4)) 16 | im = ax.imshow(alignment, aspect='auto', origin='lower', 17 | interpolation='none') 18 | fig.colorbar(im, ax=ax) 19 | xlabel = 'Decoder timestep' 20 | if info is not None: 21 | xlabel += '\n\n' + info 22 | plt.xlabel(xlabel) 23 | plt.ylabel('Encoder timestep') 24 | plt.tight_layout() 25 | 26 | fig.canvas.draw() 27 | data = save_figure_to_numpy(fig) 28 | plt.close() 29 | return data 30 | 31 | 32 | def plot_spectrogram_to_numpy(spectrogram): 33 | fig, ax = plt.subplots(figsize=(12, 3)) 34 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 35 | interpolation='none') 36 | plt.colorbar(im, ax=ax) 37 | plt.xlabel("Frames") 38 | plt.ylabel("Channels") 39 | plt.tight_layout() 40 | 41 | fig.canvas.draw() 42 | data = save_figure_to_numpy(fig) 43 | plt.close() 44 | return data 45 | 46 | 47 | def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): 48 | fig, ax = plt.subplots(figsize=(12, 3)) 49 | ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5, 50 | color='green', marker='+', s=1, label='target') 51 | ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5, 52 | color='red', marker='.', s=1, label='predicted') 53 | 54 | plt.xlabel("Frames (Green target, Red predicted)") 55 | plt.ylabel("Gate State") 56 | plt.tight_layout() 57 | 58 | fig.canvas.draw() 59 | data = save_figure_to_numpy(fig) 60 | plt.close() 61 | return data 62 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.utils.tensorboard import SummaryWriter 4 | from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy 5 | from plotting_utils import plot_gate_outputs_to_numpy 6 | 7 | 8 | class Tacotron2Logger(SummaryWriter): 9 | def __init__(self, logdir): 10 | super(Tacotron2Logger, 
self).__init__(logdir) 11 | 12 | def log_training(self, reduced_loss, grad_norm, learning_rate, duration, 13 | iteration): 14 | self.add_scalar("training.loss", reduced_loss, iteration) 15 | self.add_scalar("grad.norm", grad_norm, iteration) 16 | self.add_scalar("learning.rate", learning_rate, iteration) 17 | self.add_scalar("duration", duration, iteration) 18 | 19 | def log_validation(self, reduced_loss, model, y, y_pred, iteration): 20 | self.add_scalar("validation.loss", reduced_loss, iteration) 21 | _, mel_outputs, gate_outputs, alignments = y_pred 22 | mel_targets, gate_targets = y 23 | 24 | # plot distribution of parameters 25 | for tag, value in model.named_parameters(): 26 | tag = tag.replace('.', '/') 27 | self.add_histogram(tag, value.data.cpu().numpy(), iteration) 28 | 29 | # plot alignment, mel target and predicted, gate target and predicted 30 | idx = random.randint(0, alignments.size(0) - 1) 31 | self.add_image( 32 | "alignment", 33 | plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T), 34 | iteration, dataformats='HWC') 35 | self.add_image( 36 | "mel_target", 37 | plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()), 38 | iteration, dataformats='HWC') 39 | self.add_image( 40 | "mel_predicted", 41 | plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()), 42 | iteration, dataformats='HWC') 43 | self.add_image( 44 | "gate", 45 | plot_gate_outputs_to_numpy( 46 | gate_targets[idx].data.cpu().numpy(), 47 | torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()), 48 | iteration, dataformats='HWC') 49 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | '''Thin wrapper around CMUDict data. 
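    Example (illustrative; the dictionary file path is hypothetical):
        >>> CMUDict('cmudict-0.7b').lookup('hello')
        ['HH AH0 L OW1']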
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 21 | def __init__(self, file_or_path, keep_ambiguous=True): 22 | if isinstance(file_or_path, str): 23 | with open(file_or_path, encoding='latin-1') as f: 24 | entries = _parse_cmudict(f) 25 | else: 26 | entries = _parse_cmudict(file_or_path) 27 | if not keep_ambiguous: 28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 29 | self._entries = entries 30 | 31 | 32 | def __len__(self): 33 | return len(self._entries) 34 | 35 | 36 | def lookup(self, word): 37 | '''Returns list of ARPAbet pronunciations of the given word.''' 38 | return self._entries.get(word.upper()) 39 | 40 | 41 | 42 | _alt_re = re.compile(r'\([0-9]+\)') 43 | 44 | 45 | def _parse_cmudict(file): 46 | cmudict = {} 47 | for line in file: 48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 49 | parts = line.split(' ') 50 | word = re.sub(_alt_re, '', parts[0]) 51 | pronunciation = _get_pronunciation(parts[1]) 52 | if pronunciation: 53 | if word in cmudict: 54 | cmudict[word].append(pronunciation) 55 | else: 56 | cmudict[word] = [pronunciation] 57 | return cmudict 58 | 59 | 60 | def _get_pronunciation(s): 61 | parts = s.strip().split(' ') 62 | for part in parts: 63 | if part not in _valid_symbol_set: 64 | return None 65 | return ' '.join(parts) 66 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 60 | else: 61 | return _inflect.number_to_words(num, andword='') 62 | 63 | 64 | def normalize_numbers(text): 65 | text = 
re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r'\1 pounds', text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | ''' 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | if not m: 34 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 35 | break 36 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 37 | sequence += _arpabet_to_sequence(m.group(2)) 38 | text = m.group(3) 39 | 40 | return sequence 41 | 42 | 43 | def sequence_to_text(sequence): 44 | '''Converts a sequence of IDs back to a string''' 45 | result = '' 46 | for symbol_id in sequence: 47 | if symbol_id in _id_to_symbol: 48 | s = _id_to_symbol[symbol_id] 49 | # Enclose ARPAbet back in curly braces: 50 | if len(s) > 1 and s[0] == '@': 51 | s = '{%s}' % s[1:] 52 | result += s 53 | return result.replace('}{', ' ') 54 | 55 | 56 | def _clean_text(text, cleaner_names): 57 | for name in cleaner_names: 58 | cleaner = getattr(cleaners, name) 59 | if not cleaner: 60 | raise Exception('Unknown cleaner: %s' % name) 61 | text = cleaner(text) 62 | return text 63 | 64 | 65 | def _symbols_to_sequence(symbols): 66 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 67 | 68 | 69 | def _arpabet_to_sequence(text): 70 | return _symbols_to_sequence(['@' + s for s in text.split()]) 71 | 72 | 73 | def _should_keep_symbol(s): 74 | return s in _symbol_to_id and s is not '_' and s is not '~' 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # 
PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | *,cover 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | virtualenv/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | 140 | # pytype static type analyzer 141 | .pytype/ 142 | 143 | # Cython debug symbols 144 | cython_debug/ 145 | 146 | # Others 147 | .ipynb_checkpoints/ 148 | .vs/ 149 | srt2fcpxml.py 150 | secrets.py 151 | .vscode/ 152 | target/ 153 | -------------------------------------------------------------------------------- /audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800, 8 | n_fft=800, dtype=np.float32, norm=None): 9 | """ 10 | # from librosa 0.6 11 | Compute the sum-square envelope of a window function at a given hop length. 12 | 13 | This is used to estimate modulation effects induced by windowing 14 | observations in short-time fourier transforms. 15 | 16 | Parameters 17 | ---------- 18 | window : string, tuple, number, callable, or list-like 19 | Window specification, as in `get_window` 20 | 21 | n_frames : int > 0 22 | The number of analysis frames 23 | 24 | hop_length : int > 0 25 | The number of samples to advance between frames 26 | 27 | win_length : [optional] 28 | The length of the window function. By default, this matches `n_fft`. 
29 | 30 | n_fft : int > 0 31 | The length of each analysis frame. 32 | 33 | dtype : np.dtype 34 | The data type of the output 35 | 36 | Returns 37 | ------- 38 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 39 | The sum-squared envelope of the window function 40 | """ 41 | if win_length is None: 42 | win_length = n_fft 43 | 44 | n = n_fft + hop_length * (n_frames - 1) 45 | x = np.zeros(n, dtype=dtype) 46 | 47 | # Compute the squared window at the desired length 48 | win_sq = get_window(window, win_length, fftbins=True) 49 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 50 | win_sq = librosa_util.pad_center(win_sq, n_fft) 51 | 52 | # Fill the envelope 53 | for i in range(n_frames): 54 | sample = i * hop_length 55 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 56 | return x 57 | 58 | 59 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 60 | """ 61 | PARAMS 62 | ------ 63 | magnitudes: spectrogram magnitudes 64 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 65 | """ 66 | 67 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 68 | angles = angles.astype(np.float32) 69 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 70 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 71 | 72 | for i in range(n_iters): 73 | _, angles = stft_fn.transform(signal) 74 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 75 | return signal 76 | 77 | 78 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 79 | """ 80 | PARAMS 81 | ------ 82 | C: compression factor 83 | """ 84 | return torch.log(torch.clamp(x, min=clip_val) * C) 85 | 86 | 87 | def dynamic_range_decompression(x, C=1): 88 | """ 89 | PARAMS 90 | ------ 91 | C: compression factor used to compress 92 | """ 93 | return torch.exp(x) / C 94 | -------------------------------------------------------------------------------- /ui.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from PyQt5 import QtCore, QtGui, QtWidgets 4 | from switch import Switch 5 | import torch 6 | 7 | class Ui_extras(object): 8 | def drawGpuSwitch(self, MainWindow): 9 | MainWindow.GpuSwitch = Switch(thumb_radius=8, track_radius=10, show_text = False) 10 | MainWindow.horizontalLayout.addWidget(MainWindow.GpuSwitch) 11 | MainWindow.GpuSwitch.setEnabled(torch.cuda.is_available()) 12 | MainWindow.use_cuda = False 13 | MainWindow.GpuSwitch.setToolTip("
CUDA installed: {}
".format(torch.cuda.is_available())) 14 | 15 | def initWidgets(self, MainWindow): 16 | MainWindow.TTSStopButton.setDisabled(True) 17 | MainWindow.progressBar2Label.setText('') 18 | MainWindow.progressBarLabel.setText('') 19 | MainWindow.ClientStopBtn.setDisabled(True) 20 | MainWindow.ClientSkipBtn.setDisabled(True) 21 | MainWindow.TTModelCombo.setDisabled(True) 22 | MainWindow.WGModelCombo.setDisabled(True) 23 | MainWindow.TTSDialogButton.setDisabled(True) 24 | MainWindow.tab_2.setDisabled(True) 25 | MainWindow.log_window2.ensureCursorVisible() 26 | MainWindow.label_10.setDisabled(True) 27 | MainWindow.OptLimitCpuCombo.setDisabled(True) 28 | 29 | MainWindow.OptLimitCpuCombo.addItems( 30 | [str(i) for i in range(1,torch.get_num_threads()+1)]) 31 | MainWindow.OptLimitCpuCombo.setCurrentIndex(torch.get_num_threads()-1) 32 | 33 | def setUpconnections(self,MainWindow): 34 | # Static widget signals 35 | MainWindow.TTModelCombo.currentIndexChanged.connect(MainWindow.set_reload_model_flag) 36 | MainWindow.WGModelCombo.currentIndexChanged.connect(MainWindow.set_reload_model_flag) 37 | MainWindow.TTSDialogButton.clicked.connect(MainWindow.start_synthesis) 38 | MainWindow.TTSStopButton.clicked.connect(MainWindow.skip_infer_playback) 39 | MainWindow.LoadTTButton.clicked.connect(MainWindow.add_TTmodel_path) 40 | MainWindow.LoadWGButton.clicked.connect(MainWindow.add_WGmodel_path) 41 | MainWindow.ClientSkipBtn.clicked.connect(MainWindow.skip_eventloop) 42 | MainWindow.ClientStartBtn.clicked.connect(MainWindow.start_eventloop) 43 | MainWindow.ClientStopBtn.clicked.connect(MainWindow.stop_eventloop) 44 | MainWindow.OptLimitCpuBtn.stateChanged.connect(MainWindow.toggle_cpu_limit) 45 | MainWindow.OptLimitCpuCombo.currentIndexChanged.connect(MainWindow.change_cpu_limit) 46 | MainWindow.OptApproveDonoBtn.stateChanged.connect(MainWindow.toggle_approve_dono) 47 | MainWindow.OptBlockNumberBtn.stateChanged.connect(MainWindow.toggle_block_number) 48 | MainWindow.OptDonoNameAmountBtn.stateChanged.connect(MainWindow.toggle_dono_amount) 49 | # Instantiated widget signals 50 | MainWindow.GpuSwitch.toggled.connect(MainWindow.set_cuda) 51 | # Instantiated signals 52 | MainWindow.signals.progress.connect(MainWindow.update_log_bar) 53 | MainWindow.signals.elapsed.connect(MainWindow.on_elapsed) 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from text import symbols 2 | 3 | 4 | def create_hparams(hparams_string=None, verbose=False): 5 | """Create model hyperparameters. 
Parse nondefault from given string.""" 6 | 7 | from utils_hparam import HParams 8 | hparams = HParams( 9 | ################################ 10 | # Experiment Parameters # 11 | ################################ 12 | epochs=500, 13 | iters_per_checkpoint=1000, 14 | seed=1234, 15 | dynamic_loss_scaling=True, 16 | fp16_run=False, 17 | distributed_run=False, 18 | dist_backend="nccl", 19 | dist_url="tcp://localhost:54321", 20 | cudnn_enabled=True, 21 | cudnn_benchmark=False, 22 | ignore_layers=['embedding.weight'], 23 | 24 | ################################ 25 | # Data Parameters # 26 | ################################ 27 | load_mel_from_disk=False, 28 | training_files='filelists/ljs_audio_text_train_filelist.txt', 29 | validation_files='filelists/ljs_audio_text_val_filelist.txt', 30 | text_cleaners=['english_cleaners'], 31 | 32 | ################################ 33 | # Audio Parameters # 34 | ################################ 35 | max_wav_value=32768.0, 36 | sampling_rate=22050, 37 | filter_length=1024, 38 | hop_length=256, 39 | win_length=1024, 40 | n_mel_channels=80, 41 | mel_fmin=0.0, 42 | mel_fmax=8000.0, 43 | 44 | ################################ 45 | # Model Parameters # 46 | ################################ 47 | n_symbols=len(symbols), 48 | symbols_embedding_dim=512, 49 | 50 | # Encoder parameters 51 | encoder_kernel_size=5, 52 | encoder_n_convolutions=3, 53 | encoder_embedding_dim=512, 54 | 55 | # Decoder parameters 56 | n_frames_per_step=1, # currently only 1 is supported 57 | decoder_rnn_dim=1024, 58 | prenet_dim=256, 59 | max_decoder_steps=1000, 60 | gate_threshold=0.5, 61 | p_attention_dropout=0.1, 62 | p_decoder_dropout=0.1, 63 | 64 | # Attention parameters 65 | attention_rnn_dim=1024, 66 | attention_dim=128, 67 | 68 | # Location Layer parameters 69 | attention_location_n_filters=32, 70 | attention_location_kernel_size=31, 71 | 72 | # Mel-post processing network parameters 73 | postnet_embedding_dim=512, 74 | postnet_kernel_size=5, 75 | postnet_n_convolutions=5, 76 | 77 | ################################ 78 | # Optimization Hyperparameters # 79 | ################################ 80 | use_saved_learning_rate=False, 81 | learning_rate=1e-3, 82 | weight_decay=1e-6, 83 | grad_clip_thresh=1.0, 84 | batch_size=64, 85 | mask_padding=True # set model's padded outputs to padded values 86 | ) 87 | 88 | if hparams_string: 89 | print('Parsing command line hparams: ', hparams_string) 90 | hparams.parse(hparams_string) 91 | 92 | if verbose: 93 | print('Final parsed hparams: ', hparams.values()) 94 | 95 | return hparams 96 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from librosa.filters import mel as librosa_mel_fn 3 | from audio_processing import dynamic_range_compression 4 | from audio_processing import dynamic_range_decompression 5 | from stft import STFT 6 | 7 | 8 | class LinearNorm(torch.nn.Module): 9 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 10 | super(LinearNorm, self).__init__() 11 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 12 | 13 | torch.nn.init.xavier_uniform_( 14 | self.linear_layer.weight, 15 | gain=torch.nn.init.calculate_gain(w_init_gain)) 16 | 17 | def forward(self, x): 18 | return self.linear_layer(x) 19 | 20 | 21 | class ConvNorm(torch.nn.Module): 22 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 23 | padding=None, dilation=1, bias=True, 
w_init_gain='linear'): 24 | super(ConvNorm, self).__init__() 25 | if padding is None: 26 | assert(kernel_size % 2 == 1) 27 | padding = int(dilation * (kernel_size - 1) / 2) 28 | 29 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 30 | kernel_size=kernel_size, stride=stride, 31 | padding=padding, dilation=dilation, 32 | bias=bias) 33 | 34 | torch.nn.init.xavier_uniform_( 35 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) 36 | 37 | def forward(self, signal): 38 | conv_signal = self.conv(signal) 39 | return conv_signal 40 | 41 | 42 | class TacotronSTFT(torch.nn.Module): 43 | def __init__(self, filter_length=1024, hop_length=256, win_length=1024, 44 | n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, 45 | mel_fmax=8000.0): 46 | super(TacotronSTFT, self).__init__() 47 | self.n_mel_channels = n_mel_channels 48 | self.sampling_rate = sampling_rate 49 | self.stft_fn = STFT(filter_length, hop_length, win_length) 50 | mel_basis = librosa_mel_fn( 51 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) 52 | mel_basis = torch.from_numpy(mel_basis).float() 53 | self.register_buffer('mel_basis', mel_basis) 54 | 55 | def spectral_normalize(self, magnitudes): 56 | output = dynamic_range_compression(magnitudes) 57 | return output 58 | 59 | def spectral_de_normalize(self, magnitudes): 60 | output = dynamic_range_decompression(magnitudes) 61 | return output 62 | 63 | def mel_spectrogram(self, y): 64 | """Computes mel-spectrograms from a batch of waves 65 | PARAMS 66 | ------ 67 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] 68 | 69 | RETURNS 70 | ------- 71 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) 72 | """ 73 | assert(torch.min(y.data) >= -1) 74 | assert(torch.max(y.data) <= 1) 75 | 76 | magnitudes, phases = self.stft_fn.transform(y) 77 | magnitudes = magnitudes.data 78 | mel_output = torch.matmul(self.mel_basis, magnitudes) 79 | mel_output = self.spectral_normalize(mel_output) 80 | return mel_output 81 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 
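
Example (illustrative; english_cleaners is defined at the bottom of this module):
    >>> english_cleaners('Dr. Smith paid $5')
    'doctor smith paid five dollars'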
13 | ''' 14 | 15 | import re 16 | from unidecode import unidecode 17 | from .numbers import normalize_numbers 18 | 19 | 20 | # Regular expression matching whitespace: 21 | _whitespace_re = re.compile(r'\s+') 22 | 23 | # List of (regular expression, replacement) pairs for abbreviations: 24 | _abbreviations = [(re.compile('\\b%s(\\.|\\b)' % x[0], re.IGNORECASE), x[1]) for x in [ 25 | ('mrs', 'misess'), 26 | ('mr', 'mister'), 27 | ('dr', 'doctor'), 28 | ('st', 'saint'), 29 | ('co', 'company'), 30 | ('jr', 'junior'), 31 | ('maj', 'major'), 32 | ('gen', 'general'), 33 | ('drs', 'doctors'), 34 | ('rev', 'reverend'), 35 | ('lt', 'lieutenant'), 36 | ('hon', 'honorable'), 37 | ('sgt', 'sergeant'), 38 | ('capt', 'captain'), 39 | ('esq', 'esquire'), 40 | ('ltd', 'limited'), 41 | ('col', 'colonel'), 42 | ('ft', 'fort'), 43 | ('bn', 'billion'), 44 | ]] 45 | 46 | _currency = [(re.compile('\\b%s' % x[0], re.IGNORECASE), x[1]) for x in [ 47 | ('aud', 'australia dollar'), 48 | ('brl', 'brazil real'), 49 | ('cad', 'canada Dollar'), 50 | ('czk', 'czech koruna'), 51 | ('dkk', 'denmark krone'), 52 | ('eur', 'euro'), 53 | ('hkd', 'hong kong dollar'), 54 | ('huf', 'hungary forint'), 55 | ('ils', 'israel new shekel'), 56 | ('jpy', 'japan yen'), 57 | ('myr', 'malaysia ringgit'), 58 | ('mxn', 'mexico peso'), 59 | ('nok', 'norway kroner'), 60 | ('nzd', 'new zealand dollar'), 61 | ('php', 'philippine peso'), 62 | ('gbp', 'great britain pound'), 63 | ('rub', 'russia rouble'), 64 | ('sgd', 'singapore dollar'), 65 | ('sek', 'sweden krona'), 66 | ('chf', 'switzerland franc'), 67 | ('twd', 'taiwan dollar'), 68 | ('thb', 'thailand baht'), 69 | ('try', 'turkish new lira'), 70 | ('usd', 'us dollar'), 71 | ]] 72 | 73 | def expand_currency(text): 74 | for regex, replacement in _currency: 75 | text = re.sub(regex, replacement, text) 76 | return text 77 | 78 | def expand_abbreviations(text): 79 | for regex, replacement in _abbreviations: 80 | text = re.sub(regex, replacement, text) 81 | return text 82 | 83 | 84 | def expand_numbers(text): 85 | return normalize_numbers(text) 86 | 87 | 88 | def lowercase(text): 89 | return text.lower() 90 | 91 | 92 | def collapse_whitespace(text): 93 | return re.sub(_whitespace_re, ' ', text) 94 | 95 | 96 | def convert_to_ascii(text): 97 | return unidecode(text) 98 | 99 | def fullstop_short_phrases(text): 100 | if len(text) < 10: 101 | if text[-1]!='.': 102 | return text + '.' 
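    # The branches above implicitly return None when the phrase needs no
    # change; presumably the intended fallthrough is to return the text
    # unchanged:
    return text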
103 | 
104 | 
105 | def basic_cleaners(text):
106 |     '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
107 |     text = lowercase(text)
108 |     text = collapse_whitespace(text)
109 |     return text
110 | 
111 | 
112 | def transliteration_cleaners(text):
113 |     '''Pipeline for non-English text that transliterates to ASCII.'''
114 |     text = convert_to_ascii(text)
115 |     text = lowercase(text)
116 |     text = collapse_whitespace(text)
117 |     return text
118 | 
119 | 
120 | def english_cleaners(text):
121 |     '''Pipeline for English text, including number and abbreviation expansion.'''
122 |     text = convert_to_ascii(text)
123 |     text = lowercase(text)
124 |     text = expand_numbers(text)
125 |     text = expand_abbreviations(text)
126 |     text = collapse_whitespace(text)
127 |     return text
128 | 
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torch.utils.data
5 | 
6 | import layers
7 | from utils import load_wav_to_torch, load_filepaths_and_text
8 | from text import text_to_sequence
9 | 
10 | 
11 | class TextMelLoader(torch.utils.data.Dataset):
12 |     """
13 |         1) loads audio,text pairs
14 |         2) normalizes text and converts them to sequences of one-hot vectors
15 |         3) computes mel-spectrograms from audio files.
16 |     """
17 |     def __init__(self, audiopaths_and_text, hparams):
18 |         self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
19 |         self.text_cleaners = hparams.text_cleaners
20 |         self.max_wav_value = hparams.max_wav_value
21 |         self.sampling_rate = hparams.sampling_rate
22 |         self.load_mel_from_disk = hparams.load_mel_from_disk
23 |         self.stft = layers.TacotronSTFT(
24 |             hparams.filter_length, hparams.hop_length, hparams.win_length,
25 |             hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
26 |             hparams.mel_fmax)
27 |         random.seed(hparams.seed)
28 |         random.shuffle(self.audiopaths_and_text)
29 | 
30 |     def get_mel_text_pair(self, audiopath_and_text):
31 |         # separate filename and text
32 |         audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
33 |         text = self.get_text(text)
34 |         mel = self.get_mel(audiopath)
35 |         return (text, mel)
36 | 
37 |     def get_mel(self, filename):
38 |         if not self.load_mel_from_disk:
39 |             audio, sampling_rate = load_wav_to_torch(filename)
40 |             if sampling_rate != self.stft.sampling_rate:
41 |                 raise ValueError("{} SR doesn't match target {} SR".format(
42 |                     sampling_rate, self.stft.sampling_rate))
43 |             audio_norm = audio / self.max_wav_value
44 |             audio_norm = audio_norm.unsqueeze(0)
45 |             audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
46 |             melspec = self.stft.mel_spectrogram(audio_norm)
47 |             melspec = torch.squeeze(melspec, 0)
48 |         else:
49 |             melspec = torch.from_numpy(np.load(filename))
50 |             assert melspec.size(0) == self.stft.n_mel_channels, (
51 |                 'Mel dimension mismatch: given {}, expected {}'.format(
52 |                     melspec.size(0), self.stft.n_mel_channels))
53 | 
54 |         return melspec
55 | 
56 |     def get_text(self, text):
57 |         text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
58 |         return text_norm
59 | 
60 |     def __getitem__(self, index):
61 |         return self.get_mel_text_pair(self.audiopaths_and_text[index])
62 | 
63 |     def __len__(self):
64 |         return len(self.audiopaths_and_text)
65 | 
66 | 
67 | class TextMelCollate():
68 |     """ Zero-pads model inputs and targets based on number of frames per step
69 |     """
70 |     def 
__init__(self, n_frames_per_step): 71 | self.n_frames_per_step = n_frames_per_step 72 | 73 | def __call__(self, batch): 74 | """Collate's training batch from normalized text and mel-spectrogram 75 | PARAMS 76 | ------ 77 | batch: [text_normalized, mel_normalized] 78 | """ 79 | # Right zero-pad all one-hot text sequences to max input length 80 | input_lengths, ids_sorted_decreasing = torch.sort( 81 | torch.LongTensor([len(x[0]) for x in batch]), 82 | dim=0, descending=True) 83 | max_input_len = input_lengths[0] 84 | 85 | text_padded = torch.LongTensor(len(batch), max_input_len) 86 | text_padded.zero_() 87 | for i in range(len(ids_sorted_decreasing)): 88 | text = batch[ids_sorted_decreasing[i]][0] 89 | text_padded[i, :text.size(0)] = text 90 | 91 | # Right zero-pad mel-spec 92 | num_mels = batch[0][1].size(0) 93 | max_target_len = max([x[1].size(1) for x in batch]) 94 | if max_target_len % self.n_frames_per_step != 0: 95 | max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step 96 | assert max_target_len % self.n_frames_per_step == 0 97 | 98 | # include mel padded and gate padded 99 | mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len) 100 | mel_padded.zero_() 101 | gate_padded = torch.FloatTensor(len(batch), max_target_len) 102 | gate_padded.zero_() 103 | output_lengths = torch.LongTensor(len(batch)) 104 | for i in range(len(ids_sorted_decreasing)): 105 | mel = batch[ids_sorted_decreasing[i]][1] 106 | mel_padded[i, :, :mel.size(1)] = mel 107 | gate_padded[i, mel.size(1)-1:] = 1 108 | output_lengths[i] = mel.size(1) 109 | 110 | return text_padded, input_lengths, mel_padded, gate_padded, \ 111 | output_lengths 112 | -------------------------------------------------------------------------------- /loss_scaler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LossScaler: 4 | 5 | def __init__(self, scale=1): 6 | self.cur_scale = scale 7 | 8 | # `params` is a list / generator of torch.Variable 9 | def has_overflow(self, params): 10 | return False 11 | 12 | # `x` is a torch.Tensor 13 | def _has_inf_or_nan(x): 14 | return False 15 | 16 | # `overflow` is boolean indicating whether we overflowed in gradient 17 | def update_scale(self, overflow): 18 | pass 19 | 20 | @property 21 | def loss_scale(self): 22 | return self.cur_scale 23 | 24 | def scale_gradient(self, module, grad_in, grad_out): 25 | return tuple(self.loss_scale * g for g in grad_in) 26 | 27 | def backward(self, loss): 28 | scaled_loss = loss*self.loss_scale 29 | scaled_loss.backward() 30 | 31 | class DynamicLossScaler: 32 | 33 | def __init__(self, 34 | init_scale=2**32, 35 | scale_factor=2., 36 | scale_window=1000): 37 | self.cur_scale = init_scale 38 | self.cur_iter = 0 39 | self.last_overflow_iter = -1 40 | self.scale_factor = scale_factor 41 | self.scale_window = scale_window 42 | 43 | # `params` is a list / generator of torch.Variable 44 | def has_overflow(self, params): 45 | # return False 46 | for p in params: 47 | if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): 48 | return True 49 | 50 | return False 51 | 52 | # `x` is a torch.Tensor 53 | def _has_inf_or_nan(x): 54 | cpu_sum = float(x.float().sum()) 55 | if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: 56 | return True 57 | return False 58 | 59 | # `overflow` is boolean indicating whether we overflowed in gradient 60 | def update_scale(self, overflow): 61 | if overflow: 62 | #self.cur_scale /= self.scale_factor 63 | 
self.cur_scale = max(self.cur_scale/self.scale_factor, 1)
64 |             self.last_overflow_iter = self.cur_iter
65 |         else:
66 |             if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
67 |                 self.cur_scale *= self.scale_factor
68 |         # self.cur_scale = 1
69 |         self.cur_iter += 1
70 | 
71 |     @property
72 |     def loss_scale(self):
73 |         return self.cur_scale
74 | 
75 |     def scale_gradient(self, module, grad_in, grad_out):
76 |         return tuple(self.loss_scale * g for g in grad_in)
77 | 
78 |     def backward(self, loss):
79 |         scaled_loss = loss*self.loss_scale
80 |         scaled_loss.backward()
81 | 
82 | ##############################################################
83 | # Example usage below here -- assuming it's in a separate file
84 | ##############################################################
85 | if __name__ == "__main__":
86 |     import torch
87 |     from torch.autograd import Variable
88 |     from dynamic_loss_scaler import DynamicLossScaler
89 | 
90 |     # N is batch size; D_in is input dimension;
91 |     # H is hidden dimension; D_out is output dimension.
92 |     N, D_in, H, D_out = 64, 1000, 100, 10
93 | 
94 |     # Create random Tensors to hold inputs and outputs, and wrap them in Variables.
95 |     x = Variable(torch.randn(N, D_in), requires_grad=False)
96 |     y = Variable(torch.randn(N, D_out), requires_grad=False)
97 | 
98 |     w1 = Variable(torch.randn(D_in, H), requires_grad=True)
99 |     w2 = Variable(torch.randn(H, D_out), requires_grad=True)
100 |     parameters = [w1, w2]
101 | 
102 |     learning_rate = 1e-6
103 |     optimizer = torch.optim.SGD(parameters, lr=learning_rate)
104 |     loss_scaler = DynamicLossScaler()
105 | 
106 |     for t in range(500):
107 |         y_pred = x.mm(w1).clamp(min=0).mm(w2)
108 |         loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
109 |         print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
110 |         print('Iter {} scaled loss: {}'.format(t, loss.item()))
111 |         print('Iter {} unscaled loss: {}'.format(t, loss.item() / loss_scaler.loss_scale))
112 | 
113 |         # Run backprop
114 |         optimizer.zero_grad()
115 |         loss.backward()
116 | 
117 |         # Check for overflow (has_overflow is an instance method)
118 |         has_overflow = loss_scaler.has_overflow(parameters)
119 | 
120 |         # If no overflow, unscale grad and update as usual
121 |         if not has_overflow:
122 |             for param in parameters:
123 |                 param.grad.data.mul_(1. / loss_scaler.loss_scale)
124 |             optimizer.step()
125 |         # Otherwise, don't do anything -- ie, skip iteration
126 |         else:
127 |             print('OVERFLOW!')
128 | 
129 |         # Update loss scale for next iteration
130 |         loss_scaler.update_scale(has_overflow)
131 | 
132 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## GUI Work in Progress (update 4 August 2020)
2 | GUI wrapper for speech synthesis. Allows CPU-only synthesis via a toggleable switch. A portable exe file is available (it runs on CPU only).
3 | 
4 | Also plays TTS donation alerts from Stream Elements.
5 | 
6 | Main UI | Stream Elements integration
7 | ------------ | -------------
8 | *(screenshot)* | *(screenshot)*
9 | # Overview
10 | A machine-learning-based text-to-speech program with a user-friendly GUI. The target audience includes Twitch streamers and content creators looking for an open-source TTS program. The aim of this software is to make TTS synthesis accessible offline (no coding experience, GPU, or Colab required) in a portable exe.
11 | 12 | ## Features 13 | * Reads donations from Stream Elements automatically 14 | * PyQt5 wrapper for NVIDIA/tacotron2 & /waveglow 15 | 16 | ## Download Link 17 | A portable executable can be found at the [Releases](https://github.com/lokkelvin2/tacotron2-tts-GUI/releases) page, or directly [here](https://github.com/lokkelvin2/tacotron2-tts-GUI/releases/download/v0.3/nvidia_waveglow-v0.3.1_x86_64.exe). Download a pretrained *Tacotron 2* and *Waveglow* model from below. 18 | 19 | Warning: the portable executable runs on CPU which leads to a >10x speed slowdown compared to running it on GPU. 20 | 21 | # Building from source 22 | ## Requirements 23 | * Python >=3.7 24 | * librosa 25 | * numpy 26 | * PyQt5==5.15.0 27 | * requests 28 | * tqdm 29 | * matplotlib 30 | * scipy 31 | * num2words 32 | * pygame 33 | 34 | [PyTorch 1.0](https://pytorch.org/) 35 | 36 | ## To Run 37 | ``` 38 | python gui.py 39 | ``` 40 | ## License 41 | * NVIDIA/tacotron2 & waveglow: BSD-3-Clause License 42 | 43 | ## Notes 44 | * TTS code from [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2) 45 | * Partial GUI code from [https://github.com/CorentinJ/Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) and layout inspired by u/realstreamer's Forsen TTS [https://www.youtube.com/watch?v=kL2tglbcDCo](https://www.youtube.com/watch?v=kL2tglbcDCo) 46 | 47 | 48 | # Original Repo: 49 | 50 | # Tacotron 2 (without wavenet) 51 | 52 | PyTorch implementation of [Natural TTS Synthesis By Conditioning 53 | Wavenet On Mel Spectrogram Predictions](https://arxiv.org/pdf/1712.05884.pdf). 54 | 55 | This implementation includes **distributed** and **automatic mixed precision** support 56 | and uses the [LJSpeech dataset](https://keithito.com/LJ-Speech-Dataset/). 57 | 58 | Distributed and Automatic Mixed Precision support relies on NVIDIA's [Apex] and [AMP]. 59 | 60 | Visit our [website] for audio samples using our published [Tacotron 2] and 61 | [WaveGlow] models. 62 | 63 | ![Alignment, Predicted Mel Spectrogram, Target Mel Spectrogram](tensorboard.png) 64 | 65 | 66 | ## Pre-requisites 67 | 1. NVIDIA GPU + CUDA cuDNN 68 | 69 | ## Setup 70 | 1. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/) 71 | 2. Clone this repo: `git clone https://github.com/NVIDIA/tacotron2.git` 72 | 3. CD into this repo: `cd tacotron2` 73 | 4. Initialize submodule: `git submodule init; git submodule update` 74 | 5. Update .wav paths: `sed -i -- 's,DUMMY,ljs_dataset_folder/wavs,g' filelists/*.txt` 75 | - Alternatively, set `load_mel_from_disk=True` in `hparams.py` and update mel-spectrogram paths 76 | 6. Install [PyTorch 1.0] 77 | 7. Install [Apex] 78 | 8. Install python requirements or build docker image 79 | - Install python requirements: `pip install -r requirements.txt` 80 | 81 | ## Training 82 | 1. `python train.py --output_directory=outdir --log_directory=logdir` 83 | 2. (OPTIONAL) `tensorboard --logdir=outdir/logdir` 84 | 85 | ## Training using a pre-trained model 86 | Training using a pre-trained model can lead to faster convergence 87 | By default, the dataset dependent text embedding layers are [ignored] 88 | 89 | 1. Download our published [Tacotron 2] model 90 | 2. `python train.py --output_directory=outdir --log_directory=logdir -c tacotron2_statedict.pt --warm_start` 91 | 92 | ## Multi-GPU (distributed) and Automatic Mixed Precision Training 93 | 1. 
`python -m multiproc train.py --output_directory=outdir --log_directory=logdir --hparams=distributed_run=True,fp16_run=True` 94 | 95 | ## Inference demo 96 | 1. Download our published [Tacotron 2] model 97 | 2. Download our published [WaveGlow] model 98 | 3. `jupyter notebook --ip=127.0.0.1 --port=31337` 99 | 4. Load inference.ipynb 100 | 101 | N.b. When performing Mel-Spectrogram to Audio synthesis, make sure Tacotron 2 102 | and the Mel decoder were trained on the same mel-spectrogram representation. 103 | 104 | 105 | ## Related repos 106 | [WaveGlow](https://github.com/NVIDIA/WaveGlow) Faster than real time Flow-based 107 | Generative Network for Speech Synthesis 108 | 109 | [nv-wavenet](https://github.com/NVIDIA/nv-wavenet/) Faster than real time 110 | WaveNet. 111 | 112 | ## Acknowledgements 113 | This implementation uses code from the following repos: [Keith 114 | Ito](https://github.com/keithito/tacotron/), [Prem 115 | Seetharaman](https://github.com/pseeth/pytorch-stft) as described in our code. 116 | 117 | We are inspired by [Ryuchi Yamamoto's](https://github.com/r9y9/tacotron_pytorch) 118 | Tacotron PyTorch implementation. 119 | 120 | We are thankful to the Tacotron 2 paper authors, specially Jonathan Shen, Yuxuan 121 | Wang and Zongheng Yang. 122 | 123 | 124 | [WaveGlow]: https://drive.google.com/open?id=1rpK8CzAAirq9sWZhe9nlfvxMF1dRgFbF 125 | [Tacotron 2]: https://drive.google.com/file/d/1c5ZTuT7J08wLUoVZ2KkUs_VdZuJ86ZqA/view?usp=sharing 126 | [pytorch 1.0]: https://github.com/pytorch/pytorch#installation 127 | [website]: https://nv-adlr.github.io/WaveGlow 128 | [ignored]: https://github.com/NVIDIA/tacotron2/blob/master/hparams.py#L22 129 | [Apex]: https://github.com/nvidia/apex 130 | [AMP]: https://github.com/NVIDIA/apex/tree/master/apex/amp 131 | -------------------------------------------------------------------------------- /stft.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2017, Prem Seetharaman 5 | All rights reserved. 6 | 7 | * Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, this 14 | list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | """ 32 | 33 | import torch 34 | import numpy as np 35 | import torch.nn.functional as F 36 | from torch.autograd import Variable 37 | from scipy.signal import get_window 38 | from librosa.util import pad_center, tiny 39 | from audio_processing import window_sumsquare 40 | 41 | 42 | class STFT(torch.nn.Module): 43 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" 44 | def __init__(self, filter_length=800, hop_length=200, win_length=800, 45 | window='hann',use_cuda = True): 46 | super(STFT, self).__init__() 47 | self.device = torch.device('cuda' if use_cuda else 'cpu') 48 | self.filter_length = filter_length 49 | self.hop_length = hop_length 50 | self.win_length = win_length 51 | self.window = window 52 | self.forward_transform = None 53 | scale = self.filter_length / self.hop_length 54 | fourier_basis = np.fft.fft(np.eye(self.filter_length)) 55 | 56 | cutoff = int((self.filter_length / 2 + 1)) 57 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), 58 | np.imag(fourier_basis[:cutoff, :])]) 59 | 60 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) 61 | inverse_basis = torch.FloatTensor( 62 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]) 63 | 64 | if window is not None: 65 | assert(filter_length >= win_length) 66 | # get window and zero center pad it to filter_length 67 | fft_window = get_window(window, win_length, fftbins=True) 68 | fft_window = pad_center(fft_window, filter_length) 69 | fft_window = torch.from_numpy(fft_window).float() 70 | 71 | # window the bases 72 | forward_basis *= fft_window 73 | inverse_basis *= fft_window 74 | 75 | self.register_buffer('forward_basis', forward_basis.float()) 76 | self.register_buffer('inverse_basis', inverse_basis.float()) 77 | 78 | def transform(self, input_data): 79 | num_batches = input_data.size(0) 80 | num_samples = input_data.size(1) 81 | 82 | self.num_samples = num_samples 83 | 84 | # similar to librosa, reflect-pad the input 85 | input_data = input_data.view(num_batches, 1, num_samples) 86 | input_data = F.pad( 87 | input_data.unsqueeze(1), 88 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), 89 | mode='reflect') 90 | input_data = input_data.squeeze(1) 91 | 92 | forward_transform = F.conv1d( 93 | input_data.to(self.device), 94 | Variable(self.forward_basis, requires_grad=False).to(self.device), 95 | stride=self.hop_length, 96 | padding=0) 97 | 98 | cutoff = int((self.filter_length / 2) + 1) 99 | real_part = forward_transform[:, :cutoff, :] 100 | imag_part = forward_transform[:, cutoff:, :] 101 | 102 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 103 | phase = torch.autograd.Variable( 104 | torch.atan2(imag_part.data, real_part.data)) 105 | 106 | return magnitude, phase 107 | 108 | def inverse(self, magnitude, phase): 109 | recombine_magnitude_phase = torch.cat( 110 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) 111 | 112 | inverse_transform = F.conv_transpose1d( 113 | recombine_magnitude_phase, 114 
| Variable(self.inverse_basis, requires_grad=False), 115 | stride=self.hop_length, 116 | padding=0) 117 | 118 | if self.window is not None: 119 | window_sum = window_sumsquare( 120 | self.window, magnitude.size(-1), hop_length=self.hop_length, 121 | win_length=self.win_length, n_fft=self.filter_length, 122 | dtype=np.float32) 123 | # remove modulation effects 124 | approx_nonzero_indices = torch.from_numpy( 125 | np.where(window_sum > tiny(window_sum))[0]) 126 | window_sum = torch.autograd.Variable( 127 | torch.from_numpy(window_sum), requires_grad=False) 128 | window_sum = window_sum.to(self.device) if magnitude.is_cuda else window_sum 129 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] 130 | 131 | # scale by hop ratio 132 | inverse_transform *= float(self.filter_length) / self.hop_length 133 | 134 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] 135 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] 136 | 137 | return inverse_transform 138 | 139 | def forward(self, input_data): 140 | self.magnitude, self.phase = self.transform(input_data) 141 | reconstruction = self.inverse(self.magnitude, self.phase) 142 | return reconstruction 143 | -------------------------------------------------------------------------------- /switch.py: -------------------------------------------------------------------------------- 1 | # Taken from https://stackoverflow.com/a/51825815 2 | from PyQt5.QtCore import QPropertyAnimation, QRectF, QSize, Qt, pyqtProperty 3 | from PyQt5.QtGui import QPainter 4 | from PyQt5.QtWidgets import ( 5 | QAbstractButton, 6 | QApplication, 7 | QHBoxLayout, 8 | QSizePolicy, 9 | QWidget, 10 | ) 11 | 12 | 13 | class Switch(QAbstractButton): 14 | def __init__(self, parent=None, track_radius=10, thumb_radius=8, show_text = True): 15 | super().__init__(parent=parent) 16 | self.setCheckable(True) 17 | self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) 18 | 19 | self._track_radius = track_radius 20 | self._thumb_radius = thumb_radius 21 | 22 | self._margin = max(0, self._thumb_radius - self._track_radius) 23 | self._base_offset = max(self._thumb_radius, self._track_radius) 24 | self._end_offset = { 25 | True: lambda: self.width() - self._base_offset, 26 | False: lambda: self._base_offset, 27 | } 28 | self._offset = self._base_offset 29 | 30 | palette = self.palette() 31 | if not show_text: 32 | self._track_color = { 33 | True: palette.highlight(), 34 | False: palette.dark(), 35 | } 36 | self._thumb_color = { 37 | True: palette.highlight(), 38 | False: palette.light(), 39 | } 40 | self._text_color = { 41 | True: palette.highlightedText().color(), 42 | False: palette.dark().color(), 43 | } 44 | self._thumb_text = { 45 | True: '', 46 | False: '', 47 | } 48 | self._track_opacity = 0.5 49 | else: 50 | self._thumb_color = { 51 | True: palette.highlightedText(), 52 | False: palette.light(), 53 | } 54 | self._track_color = { 55 | True: palette.highlight(), 56 | False: palette.dark(), 57 | } 58 | self._text_color = { 59 | True: palette.highlight().color(), 60 | False: palette.dark().color(), 61 | } 62 | self._thumb_text = { 63 | True: '✔', 64 | False: '✕', 65 | #True: 'On', 66 | #False: 'Off', 67 | } 68 | self._track_opacity = 1 69 | 70 | @pyqtProperty(int) 71 | def offset(self): 72 | return self._offset 73 | 74 | @offset.setter 75 | def offset(self, value): 76 | self._offset = value 77 | self.update() 78 | 79 | def sizeHint(self): # pylint: disable=invalid-name 80 | return QSize( 81 | 4 * 
self._track_radius + 2 * self._margin, 82 | 2 * self._track_radius + 2 * self._margin, 83 | ) 84 | 85 | def setChecked(self, checked): 86 | super().setChecked(checked) 87 | self.offset = self._end_offset[checked]() 88 | 89 | def resizeEvent(self, event): 90 | super().resizeEvent(event) 91 | self.offset = self._end_offset[self.isChecked()]() 92 | 93 | def paintEvent(self, event): # pylint: disable=invalid-name, unused-argument 94 | p = QPainter(self) 95 | p.setRenderHint(QPainter.Antialiasing, True) 96 | p.setPen(Qt.NoPen) 97 | track_opacity = self._track_opacity 98 | thumb_opacity = 1.0 99 | text_opacity = 1.0 100 | if self.isEnabled(): 101 | track_brush = self._track_color[self.isChecked()] 102 | thumb_brush = self._thumb_color[self.isChecked()] 103 | text_color = self._text_color[self.isChecked()] 104 | else: 105 | track_opacity *= 0.8 106 | track_brush = self.palette().shadow() 107 | thumb_brush = self.palette().mid() 108 | text_color = self.palette().shadow().color() 109 | 110 | p.setBrush(track_brush) 111 | p.setOpacity(track_opacity) 112 | p.drawRoundedRect( 113 | self._margin, 114 | self._margin, 115 | self.width() - 2 * self._margin, 116 | self.height() - 2 * self._margin, 117 | self._track_radius, 118 | self._track_radius, 119 | ) 120 | p.setBrush(thumb_brush) 121 | p.setOpacity(thumb_opacity) 122 | p.drawEllipse( 123 | self.offset - self._thumb_radius, 124 | self._base_offset - self._thumb_radius, 125 | 2 * self._thumb_radius, 126 | 2 * self._thumb_radius, 127 | ) 128 | p.setPen(text_color) 129 | p.setOpacity(text_opacity) 130 | font = p.font() 131 | font.setPixelSize(1.5 * self._thumb_radius) 132 | p.setFont(font) 133 | p.drawText( 134 | QRectF( 135 | self.offset - self._thumb_radius, 136 | self._base_offset - self._thumb_radius, 137 | 2 * self._thumb_radius, 138 | 2 * self._thumb_radius, 139 | ), 140 | Qt.AlignCenter, 141 | self._thumb_text[self.isChecked()], 142 | ) 143 | 144 | def mouseReleaseEvent(self, event): # pylint: disable=invalid-name 145 | super().mouseReleaseEvent(event) 146 | if event.button() == Qt.LeftButton: 147 | anim = QPropertyAnimation(self, b'offset', self) 148 | anim.setDuration(120) 149 | anim.setStartValue(self.offset) 150 | anim.setEndValue(self._end_offset[self.isChecked()]()) 151 | anim.start() 152 | 153 | def enterEvent(self, event): # pylint: disable=invalid-name 154 | self.setCursor(Qt.PointingHandCursor) 155 | super().enterEvent(event) 156 | 157 | 158 | def main(): 159 | app = QApplication([]) 160 | 161 | # Thumb size < track size (Gitlab style) 162 | s1 = Switch() 163 | s1.toggled.connect(lambda c: print('toggled', c)) 164 | s1.clicked.connect(lambda c: print('clicked', c)) 165 | s1.pressed.connect(lambda: print('pressed')) 166 | s1.released.connect(lambda: print('released')) 167 | s2 = Switch() 168 | s2.setEnabled(False) 169 | 170 | # Thumb size > track size (Android style) 171 | s3 = Switch(thumb_radius=11, track_radius=8) 172 | s4 = Switch(thumb_radius=8, track_radius=10, show_text = False) 173 | 174 | #s4.setEnabled(False) 175 | 176 | l = QHBoxLayout() 177 | l.addWidget(s1) 178 | l.addWidget(s2) 179 | l.addWidget(s3) 180 | l.addWidget(s4) 181 | w = QWidget() 182 | w.setLayout(l) 183 | w.show() 184 | 185 | app.exec() 186 | 187 | 188 | if __name__ == '__main__': 189 | main() -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.nn.modules import Module 4 
| from torch.autograd import Variable 5 | 6 | def _flatten_dense_tensors(tensors): 7 | """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of 8 | same dense type. 9 | Since inputs are dense, the resulting tensor will be a concatenated 1D 10 | buffer. Element-wise operation on this buffer will be equivalent to 11 | operating individually. 12 | Arguments: 13 | tensors (Iterable[Tensor]): dense tensors to flatten. 14 | Returns: 15 | A contiguous 1D buffer containing input tensors. 16 | """ 17 | if len(tensors) == 1: 18 | return tensors[0].contiguous().view(-1) 19 | flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0) 20 | return flat 21 | 22 | def _unflatten_dense_tensors(flat, tensors): 23 | """View a flat buffer using the sizes of tensors. Assume that tensors are of 24 | same dense type, and that flat is given by _flatten_dense_tensors. 25 | Arguments: 26 | flat (Tensor): flattened dense tensors to unflatten. 27 | tensors (Iterable[Tensor]): dense tensors whose sizes will be used to 28 | unflatten flat. 29 | Returns: 30 | Unflattened dense tensors with sizes same as tensors and values from 31 | flat. 32 | """ 33 | outputs = [] 34 | offset = 0 35 | for tensor in tensors: 36 | numel = tensor.numel() 37 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) 38 | offset += numel 39 | return tuple(outputs) 40 | 41 | 42 | ''' 43 | This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py 44 | launcher included with this example. It assumes that your run is using multiprocess with 1 45 | GPU/process, that the model is on the correct device, and that torch.set_device has been 46 | used to set the device. 47 | 48 | Parameters are broadcasted to the other processes on initialization of DistributedDataParallel, 49 | and will be allreduced at the finish of the backward pass. 50 | ''' 51 | class DistributedDataParallel(Module): 52 | 53 | def __init__(self, module): 54 | super(DistributedDataParallel, self).__init__() 55 | #fallback for PyTorch 0.3 56 | if not hasattr(dist, '_backend'): 57 | self.warn_on_half = True 58 | else: 59 | self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 60 | 61 | self.module = module 62 | 63 | for p in self.module.state_dict().values(): 64 | if not torch.is_tensor(p): 65 | continue 66 | dist.broadcast(p, 0) 67 | 68 | def allreduce_params(): 69 | if(self.needs_reduction): 70 | self.needs_reduction = False 71 | buckets = {} 72 | for param in self.module.parameters(): 73 | if param.requires_grad and param.grad is not None: 74 | tp = type(param.data) 75 | if tp not in buckets: 76 | buckets[tp] = [] 77 | buckets[tp].append(param) 78 | if self.warn_on_half: 79 | if torch.cuda.HalfTensor in buckets: 80 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 81 | " It is recommended to use the NCCL backend in this case. 
This currently requires" + 82 | "PyTorch built from top of tree master.") 83 | self.warn_on_half = False 84 | 85 | for tp in buckets: 86 | bucket = buckets[tp] 87 | grads = [param.grad.data for param in bucket] 88 | coalesced = _flatten_dense_tensors(grads) 89 | dist.all_reduce(coalesced) 90 | coalesced /= dist.get_world_size() 91 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 92 | buf.copy_(synced) 93 | 94 | for param in list(self.module.parameters()): 95 | def allreduce_hook(*unused): 96 | param._execution_engine.queue_callback(allreduce_params) 97 | if param.requires_grad: 98 | param.register_hook(allreduce_hook) 99 | 100 | def forward(self, *inputs, **kwargs): 101 | self.needs_reduction = True 102 | return self.module(*inputs, **kwargs) 103 | 104 | ''' 105 | def _sync_buffers(self): 106 | buffers = list(self.module._all_buffers()) 107 | if len(buffers) > 0: 108 | # cross-node buffer sync 109 | flat_buffers = _flatten_dense_tensors(buffers) 110 | dist.broadcast(flat_buffers, 0) 111 | for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): 112 | buf.copy_(synced) 113 | def train(self, mode=True): 114 | # Clear NCCL communicator and CUDA event cache of the default group ID, 115 | # These cache will be recreated at the later call. This is currently a 116 | # work-around for a potential NCCL deadlock. 117 | if dist._backend == dist.dist_backend.NCCL: 118 | dist._clear_group_cache() 119 | super(DistributedDataParallel, self).train(mode) 120 | self.module.train(mode) 121 | ''' 122 | ''' 123 | Modifies existing model to do gradient allreduce, but doesn't change class 124 | so you don't need "module" 125 | ''' 126 | def apply_gradient_allreduce(module): 127 | if not hasattr(dist, '_backend'): 128 | module.warn_on_half = True 129 | else: 130 | module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 131 | 132 | for p in module.state_dict().values(): 133 | if not torch.is_tensor(p): 134 | continue 135 | dist.broadcast(p, 0) 136 | 137 | def allreduce_params(): 138 | if(module.needs_reduction): 139 | module.needs_reduction = False 140 | buckets = {} 141 | for param in module.parameters(): 142 | if param.requires_grad and param.grad is not None: 143 | tp = param.data.dtype 144 | if tp not in buckets: 145 | buckets[tp] = [] 146 | buckets[tp].append(param) 147 | if module.warn_on_half: 148 | if torch.cuda.HalfTensor in buckets: 149 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 150 | " It is recommended to use the NCCL backend in this case. 
This currently requires" + 151 | "PyTorch built from top of tree master.") 152 | module.warn_on_half = False 153 | 154 | for tp in buckets: 155 | bucket = buckets[tp] 156 | grads = [param.grad.data for param in bucket] 157 | coalesced = _flatten_dense_tensors(grads) 158 | dist.all_reduce(coalesced) 159 | coalesced /= dist.get_world_size() 160 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 161 | buf.copy_(synced) 162 | 163 | for param in list(module.parameters()): 164 | def allreduce_hook(*unused): 165 | Variable._execution_engine.queue_callback(allreduce_params) 166 | if param.requires_grad: 167 | param.register_hook(allreduce_hook) 168 | 169 | def set_needs_reduction(self, input, output): 170 | self.needs_reduction = True 171 | 172 | module.register_forward_hook(set_needs_reduction) 173 | return module 174 | -------------------------------------------------------------------------------- /filelists/ljs_audio_text_val_filelist.txt: -------------------------------------------------------------------------------- 1 | DUMMY/LJ022-0023.wav|The overwhelming majority of people in this country know how to sift the wheat from the chaff in what they hear and what they read. 2 | DUMMY/LJ043-0030.wav|If somebody did that to me, a lousy trick like that, to take my wife away, and all the furniture, I would be mad as hell, too. 3 | DUMMY/LJ005-0201.wav|as is shown by the report of the Commissioners to inquire into the state of the municipal corporations in eighteen thirty-five. 4 | DUMMY/LJ001-0110.wav|Even the Caslon type when enlarged shows great shortcomings in this respect: 5 | DUMMY/LJ003-0345.wav|All the committee could do in this respect was to throw the responsibility on others. 6 | DUMMY/LJ007-0154.wav|These pungent and well-grounded strictures applied with still greater force to the unconvicted prisoner, the man who came to the prison innocent, and still uncontaminated, 7 | DUMMY/LJ018-0098.wav|and recognized as one of the frequenters of the bogus law-stationers. His arrest led to that of others. 8 | DUMMY/LJ047-0044.wav|Oswald was, however, willing to discuss his contacts with Soviet authorities. He denied having any involvement with Soviet intelligence agencies 9 | DUMMY/LJ031-0038.wav|The first physician to see the President at Parkland Hospital was Dr. Charles J. Carrico, a resident in general surgery. 10 | DUMMY/LJ048-0194.wav|during the morning of November twenty-two prior to the motorcade. 11 | DUMMY/LJ049-0026.wav|On occasion the Secret Service has been permitted to have an agent riding in the passenger compartment with the President. 12 | DUMMY/LJ004-0152.wav|although at Mr. Buxton's visit a new jail was in process of erection, the first step towards reform since Howard's visitation in seventeen seventy-four. 13 | DUMMY/LJ008-0278.wav|or theirs might be one of many, and it might be considered necessary to "make an example." 14 | DUMMY/LJ043-0002.wav|The Warren Commission Report. By The President's Commission on the Assassination of President Kennedy. Chapter seven. Lee Harvey Oswald: 15 | DUMMY/LJ009-0114.wav|Mr. Wakefield winds up his graphic but somewhat sensational account by describing another religious service, which may appropriately be inserted here. 16 | DUMMY/LJ028-0506.wav|A modern artist would have difficulty in doing such accurate work. 17 | DUMMY/LJ050-0168.wav|with the particular purposes of the agency involved. 
The Commission recognizes that this is a controversial area 18 | DUMMY/LJ039-0223.wav|Oswald's Marine training in marksmanship, his other rifle experience and his established familiarity with this particular weapon 19 | DUMMY/LJ029-0032.wav|According to O'Donnell, quote, we had a motorcade wherever we went, end quote. 20 | DUMMY/LJ031-0070.wav|Dr. Clark, who most closely observed the head wound, 21 | DUMMY/LJ034-0198.wav|Euins, who was on the southwest corner of Elm and Houston Streets testified that he could not describe the man he saw in the window. 22 | DUMMY/LJ026-0068.wav|Energy enters the plant, to a small extent, 23 | DUMMY/LJ039-0075.wav|once you know that you must put the crosshairs on the target and that is all that is necessary. 24 | DUMMY/LJ004-0096.wav|the fatal consequences whereof might be prevented if the justices of the peace were duly authorized 25 | DUMMY/LJ005-0014.wav|Speaking on a debate on prison matters, he declared that 26 | DUMMY/LJ012-0161.wav|he was reported to have fallen away to a shadow. 27 | DUMMY/LJ018-0239.wav|His disappearance gave color and substance to evil reports already in circulation that the will and conveyance above referred to 28 | DUMMY/LJ019-0257.wav|Here the tread-wheel was in use, there cellular cranks, or hard-labor machines. 29 | DUMMY/LJ028-0008.wav|you tap gently with your heel upon the shoulder of the dromedary to urge her on. 30 | DUMMY/LJ024-0083.wav|This plan of mine is no attack on the Court; 31 | DUMMY/LJ042-0129.wav|No night clubs or bowling alleys, no places of recreation except the trade union dances. I have had enough. 32 | DUMMY/LJ036-0103.wav|The police asked him whether he could pick out his passenger from the lineup. 33 | DUMMY/LJ046-0058.wav|During his Presidency, Franklin D. Roosevelt made almost four hundred journeys and traveled more than three hundred fifty thousand miles. 34 | DUMMY/LJ014-0076.wav|He was seen afterwards smoking and talking with his hosts in their back parlor, and never seen again alive. 35 | DUMMY/LJ002-0043.wav|long narrow rooms -- one thirty-six feet, six twenty-three feet, and the eighth eighteen, 36 | DUMMY/LJ009-0076.wav|We come to the sermon. 37 | DUMMY/LJ017-0131.wav|even when the high sheriff had told him there was no possibility of a reprieve, and within a few hours of execution. 38 | DUMMY/LJ046-0184.wav|but there is a system for the immediate notification of the Secret Service by the confining institution when a subject is released or escapes. 39 | DUMMY/LJ014-0263.wav|When other pleasures palled he took a theatre, and posed as a munificent patron of the dramatic art. 40 | DUMMY/LJ042-0096.wav|(old exchange rate) in addition to his factory salary of approximately equal amount 41 | DUMMY/LJ049-0050.wav|Hill had both feet on the car and was climbing aboard to assist President and Mrs. Kennedy. 42 | DUMMY/LJ019-0186.wav|seeing that since the establishment of the Central Criminal Court, Newgate received prisoners for trial from several counties, 43 | DUMMY/LJ028-0307.wav|then let twenty days pass, and at the end of that time station near the Chaldasan gates a body of four thousand. 44 | DUMMY/LJ012-0235.wav|While they were in a state of insensibility the murder was committed. 45 | DUMMY/LJ034-0053.wav|reached the same conclusion as Latona that the prints found on the cartons were those of Lee Harvey Oswald. 46 | DUMMY/LJ014-0030.wav|These were damnatory facts which well supported the prosecution. 
47 | DUMMY/LJ015-0203.wav|but were the precautions too minute, the vigilance too close to be eluded or overcome? 48 | DUMMY/LJ028-0093.wav|but his scribe wrote it in the manner customary for the scribes of those days to write of their royal masters. 49 | DUMMY/LJ002-0018.wav|The inadequacy of the jail was noticed and reported upon again and again by the grand juries of the city of London, 50 | DUMMY/LJ028-0275.wav|At last, in the twentieth month, 51 | DUMMY/LJ012-0042.wav|which he kept concealed in a hiding-place with a trap-door just under his bed. 52 | DUMMY/LJ011-0096.wav|He married a lady also belonging to the Society of Friends, who brought him a large fortune, which, and his own money, he put into a city firm, 53 | DUMMY/LJ036-0077.wav|Roger D. Craig, a deputy sheriff of Dallas County, 54 | DUMMY/LJ016-0318.wav|Other officials, great lawyers, governors of prisons, and chaplains supported this view. 55 | DUMMY/LJ013-0164.wav|who came from his room ready dressed, a suspicious circumstance, as he was always late in the morning. 56 | DUMMY/LJ027-0141.wav|is closely reproduced in the life-history of existing deer. Or, in other words, 57 | DUMMY/LJ028-0335.wav|accordingly they committed to him the command of their whole army, and put the keys of their city into his hands. 58 | DUMMY/LJ031-0202.wav|Mrs. Kennedy chose the hospital in Bethesda for the autopsy because the President had served in the Navy. 59 | DUMMY/LJ021-0145.wav|From those willing to join in establishing this hoped-for period of peace, 60 | DUMMY/LJ016-0288.wav|"Müller, Müller, He's the man," till a diversion was created by the appearance of the gallows, which was received with continuous yells. 61 | DUMMY/LJ028-0081.wav|Years later, when the archaeologists could readily distinguish the false from the true, 62 | DUMMY/LJ018-0081.wav|his defense being that he had intended to commit suicide, but that, on the appearance of this officer who had wronged him, 63 | DUMMY/LJ021-0066.wav|together with a great increase in the payrolls, there has come a substantial rise in the total of industrial profits 64 | DUMMY/LJ009-0238.wav|After this the sheriffs sent for another rope, but the spectators interfered, and the man was carried back to jail. 65 | DUMMY/LJ005-0079.wav|and improve the morals of the prisoners, and shall insure the proper measure of punishment to convicted offenders. 66 | DUMMY/LJ035-0019.wav|drove to the northwest corner of Elm and Houston, and parked approximately ten feet from the traffic signal. 67 | DUMMY/LJ036-0174.wav|This is the approximate time he entered the roominghouse, according to Earlene Roberts, the housekeeper there. 68 | DUMMY/LJ046-0146.wav|The criteria in effect prior to November twenty-two, nineteen sixty-three, for determining whether to accept material for the PRS general files 69 | DUMMY/LJ017-0044.wav|and the deepest anxiety was felt that the crime, if crime there had been, should be brought home to its perpetrator. 70 | DUMMY/LJ017-0070.wav|but his sporting operations did not prosper, and he became a needy man, always driven to desperate straits for cash. 71 | DUMMY/LJ014-0020.wav|He was soon afterwards arrested on suspicion, and a search of his lodgings brought to light several garments saturated with blood; 72 | DUMMY/LJ016-0020.wav|He never reached the cistern, but fell back into the yard, injuring his legs severely. 73 | DUMMY/LJ045-0230.wav|when he was finally apprehended in the Texas Theatre. 
Although it is not fully corroborated by others who were present, 74 | DUMMY/LJ035-0129.wav|and she must have run down the stairs ahead of Oswald and would probably have seen or heard him. 75 | DUMMY/LJ008-0307.wav|afterwards express a wish to murder the Recorder for having kept them so long in suspense. 76 | DUMMY/LJ008-0294.wav|nearly indefinitely deferred. 77 | DUMMY/LJ047-0148.wav|On October twenty-five, 78 | DUMMY/LJ008-0111.wav|They entered a "stone cold room," and were presently joined by the prisoner. 79 | DUMMY/LJ034-0042.wav|that he could only testify with certainty that the print was less than three days old. 80 | DUMMY/LJ037-0234.wav|Mrs. Mary Brock, the wife of a mechanic who worked at the station, was there at the time and she saw a white male, 81 | DUMMY/LJ040-0002.wav|Chapter seven. Lee Harvey Oswald: Background and Possible Motives, Part one. 82 | DUMMY/LJ045-0140.wav|The arguments he used to justify his use of the alias suggest that Oswald may have come to think that the whole world was becoming involved 83 | DUMMY/LJ012-0035.wav|the number and names on watches, were carefully removed or obliterated after the goods passed out of his hands. 84 | DUMMY/LJ012-0250.wav|On the seventh July, eighteen thirty-seven, 85 | DUMMY/LJ016-0179.wav|contracted with sheriffs and conveners to work by the job. 86 | DUMMY/LJ016-0138.wav|at a distance from the prison. 87 | DUMMY/LJ027-0052.wav|These principles of homology are essential to a correct interpretation of the facts of morphology. 88 | DUMMY/LJ031-0134.wav|On one occasion Mrs. Johnson, accompanied by two Secret Service agents, left the room to see Mrs. Kennedy and Mrs. Connally. 89 | DUMMY/LJ019-0273.wav|which Sir Joshua Jebb told the committee he considered the proper elements of penal discipline. 90 | DUMMY/LJ014-0110.wav|At the first the boxes were impounded, opened, and found to contain many of O'Connor's effects. 91 | DUMMY/LJ034-0160.wav|on Brennan's subsequent certain identification of Lee Harvey Oswald as the man he saw fire the rifle. 92 | DUMMY/LJ038-0199.wav|eleven. If I am alive and taken prisoner, 93 | DUMMY/LJ014-0010.wav|yet he could not overcome the strange fascination it had for him, and remained by the side of the corpse till the stretcher came. 94 | DUMMY/LJ033-0047.wav|I noticed when I went out that the light was on, end quote, 95 | DUMMY/LJ040-0027.wav|He was never satisfied with anything. 96 | DUMMY/LJ048-0228.wav|and others who were present say that no agent was inebriated or acted improperly. 97 | DUMMY/LJ003-0111.wav|He was in consequence put out of the protection of their internal law, end quote. Their code was a subject of some curiosity. 98 | DUMMY/LJ008-0258.wav|Let me retrace my steps, and speak more in detail of the treatment of the condemned in those bloodthirsty and brutally indifferent days, 99 | DUMMY/LJ029-0022.wav|The original plan called for the President to spend only one day in the State, making whirlwind visits to Dallas, Fort Worth, San Antonio, and Houston. 100 | DUMMY/LJ004-0045.wav|Mr. Sturges Bourne, Sir James Mackintosh, Sir James Scarlett, and William Wilberforce. 
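Each row above is a `<wav path>|<transcript>` pair (the `DUMMY` prefix is rewritten by the `sed` command in the Setup section). A minimal sketch of reading this format, for illustration only; the real loading lives in `data_utils.TextMelLoader`:

```
def parse_filelist(path):
    # Each non-empty line holds "<wav path>|<normalized transcript>".
    with open(path, encoding='utf-8') as f:
        return [line.rstrip('\n').split('|', 1) for line in f if line.strip()]

pairs = parse_filelist('filelists/ljs_audio_text_val_filelist.txt')
wav_path, transcript = pairs[0]
```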
101 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import math 5 | from numpy import finfo 6 | 7 | import torch 8 | from distributed import apply_gradient_allreduce 9 | import torch.distributed as dist 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torch.utils.data import DataLoader 12 | 13 | from model import Tacotron2 14 | from data_utils import TextMelLoader, TextMelCollate 15 | from loss_function import Tacotron2Loss 16 | from logger import Tacotron2Logger 17 | from hparams import create_hparams 18 | 19 | 20 | def reduce_tensor(tensor, n_gpus): 21 | rt = tensor.clone() 22 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 23 | rt /= n_gpus 24 | return rt 25 | 26 | 27 | def init_distributed(hparams, n_gpus, rank, group_name): 28 | assert torch.cuda.is_available(), "Distributed mode requires CUDA." 29 | print("Initializing Distributed") 30 | 31 | # Set cuda device so everything is done on the right GPU. 32 | torch.cuda.set_device(rank % torch.cuda.device_count()) 33 | 34 | # Initialize distributed communication 35 | dist.init_process_group( 36 | backend=hparams.dist_backend, init_method=hparams.dist_url, 37 | world_size=n_gpus, rank=rank, group_name=group_name) 38 | 39 | print("Done initializing distributed") 40 | 41 | 42 | def prepare_dataloaders(hparams): 43 | # Get data, data loaders and collate function ready 44 | trainset = TextMelLoader(hparams.training_files, hparams) 45 | valset = TextMelLoader(hparams.validation_files, hparams) 46 | collate_fn = TextMelCollate(hparams.n_frames_per_step) 47 | 48 | if hparams.distributed_run: 49 | train_sampler = DistributedSampler(trainset) 50 | shuffle = False 51 | else: 52 | train_sampler = None 53 | shuffle = True 54 | 55 | train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle, 56 | sampler=train_sampler, 57 | batch_size=hparams.batch_size, pin_memory=False, 58 | drop_last=True, collate_fn=collate_fn) 59 | return train_loader, valset, collate_fn 60 | 61 | 62 | def prepare_directories_and_logger(output_directory, log_directory, rank): 63 | if rank == 0: 64 | if not os.path.isdir(output_directory): 65 | os.makedirs(output_directory) 66 | os.chmod(output_directory, 0o775) 67 | logger = Tacotron2Logger(os.path.join(output_directory, log_directory)) 68 | else: 69 | logger = None 70 | return logger 71 | 72 | 73 | def load_model(hparams,use_cuda=True): 74 | device = torch.device('cuda' if use_cuda else 'cpu') 75 | model = Tacotron2(hparams).to(device) 76 | if hparams.fp16_run: 77 | model.decoder.attention_layer.score_mask_value = finfo('float16').min 78 | 79 | if hparams.distributed_run: 80 | model = apply_gradient_allreduce(model) 81 | 82 | return model 83 | 84 | 85 | def warm_start_model(checkpoint_path, model, ignore_layers): 86 | assert os.path.isfile(checkpoint_path) 87 | print("Warm starting model from checkpoint '{}'".format(checkpoint_path)) 88 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 89 | model_dict = checkpoint_dict['state_dict'] 90 | if len(ignore_layers) > 0: 91 | model_dict = {k: v for k, v in model_dict.items() 92 | if k not in ignore_layers} 93 | dummy_dict = model.state_dict() 94 | dummy_dict.update(model_dict) 95 | model_dict = dummy_dict 96 | model.load_state_dict(model_dict) 97 | return model 98 | 99 | 100 | def load_checkpoint(checkpoint_path, model, optimizer): 101 | assert 
os.path.isfile(checkpoint_path) 102 | print("Loading checkpoint '{}'".format(checkpoint_path)) 103 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 104 | model.load_state_dict(checkpoint_dict['state_dict']) 105 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 106 | learning_rate = checkpoint_dict['learning_rate'] 107 | iteration = checkpoint_dict['iteration'] 108 | print("Loaded checkpoint '{}' from iteration {}" .format( 109 | checkpoint_path, iteration)) 110 | return model, optimizer, learning_rate, iteration 111 | 112 | 113 | def save_checkpoint(model, optimizer, learning_rate, iteration, filepath): 114 | print("Saving model and optimizer state at iteration {} to {}".format( 115 | iteration, filepath)) 116 | torch.save({'iteration': iteration, 117 | 'state_dict': model.state_dict(), 118 | 'optimizer': optimizer.state_dict(), 119 | 'learning_rate': learning_rate}, filepath) 120 | 121 | 122 | def validate(model, criterion, valset, iteration, batch_size, n_gpus, 123 | collate_fn, logger, distributed_run, rank): 124 | """Handles all the validation scoring and printing""" 125 | model.eval() 126 | with torch.no_grad(): 127 | val_sampler = DistributedSampler(valset) if distributed_run else None 128 | val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1, 129 | shuffle=False, batch_size=batch_size, 130 | pin_memory=False, collate_fn=collate_fn) 131 | 132 | val_loss = 0.0 133 | for i, batch in enumerate(val_loader): 134 | x, y = model.parse_batch(batch) 135 | y_pred = model(x) 136 | loss = criterion(y_pred, y) 137 | if distributed_run: 138 | reduced_val_loss = reduce_tensor(loss.data, n_gpus).item() 139 | else: 140 | reduced_val_loss = loss.item() 141 | val_loss += reduced_val_loss 142 | val_loss = val_loss / (i + 1) 143 | 144 | model.train() 145 | if rank == 0: 146 | print("Validation loss {}: {:9f} ".format(iteration, val_loss)) 147 | logger.log_validation(val_loss, model, y, y_pred, iteration) 148 | 149 | 150 | def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus, 151 | rank, group_name, hparams): 152 | """Training and validation logging results to tensorboard and stdout 153 | 154 | Params 155 | ------ 156 | output_directory (string): directory to save checkpoints 157 | log_directory (string) directory to save tensorboard logs 158 | checkpoint_path(string): checkpoint path 159 | n_gpus (int): number of gpus 160 | rank (int): rank of current gpu 161 | hparams (object): comma separated list of "name=value" pairs. 
162 | """ 163 | if hparams.distributed_run: 164 | init_distributed(hparams, n_gpus, rank, group_name) 165 | 166 | torch.manual_seed(hparams.seed) 167 | torch.cuda.manual_seed(hparams.seed) 168 | 169 | model = load_model(hparams) 170 | learning_rate = hparams.learning_rate 171 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, 172 | weight_decay=hparams.weight_decay) 173 | 174 | if hparams.fp16_run: 175 | from apex import amp 176 | model, optimizer = amp.initialize( 177 | model, optimizer, opt_level='O2') 178 | 179 | if hparams.distributed_run: 180 | model = apply_gradient_allreduce(model) 181 | 182 | criterion = Tacotron2Loss() 183 | 184 | logger = prepare_directories_and_logger( 185 | output_directory, log_directory, rank) 186 | 187 | train_loader, valset, collate_fn = prepare_dataloaders(hparams) 188 | 189 | # Load checkpoint if one exists 190 | iteration = 0 191 | epoch_offset = 0 192 | if checkpoint_path is not None: 193 | if warm_start: 194 | model = warm_start_model( 195 | checkpoint_path, model, hparams.ignore_layers) 196 | else: 197 | model, optimizer, _learning_rate, iteration = load_checkpoint( 198 | checkpoint_path, model, optimizer) 199 | if hparams.use_saved_learning_rate: 200 | learning_rate = _learning_rate 201 | iteration += 1 # next iteration is iteration + 1 202 | epoch_offset = max(0, int(iteration / len(train_loader))) 203 | 204 | model.train() 205 | is_overflow = False 206 | # ================ MAIN TRAINNIG LOOP! =================== 207 | for epoch in range(epoch_offset, hparams.epochs): 208 | print("Epoch: {}".format(epoch)) 209 | for i, batch in enumerate(train_loader): 210 | start = time.perf_counter() 211 | for param_group in optimizer.param_groups: 212 | param_group['lr'] = learning_rate 213 | 214 | model.zero_grad() 215 | x, y = model.parse_batch(batch) 216 | y_pred = model(x) 217 | 218 | loss = criterion(y_pred, y) 219 | if hparams.distributed_run: 220 | reduced_loss = reduce_tensor(loss.data, n_gpus).item() 221 | else: 222 | reduced_loss = loss.item() 223 | if hparams.fp16_run: 224 | with amp.scale_loss(loss, optimizer) as scaled_loss: 225 | scaled_loss.backward() 226 | else: 227 | loss.backward() 228 | 229 | if hparams.fp16_run: 230 | grad_norm = torch.nn.utils.clip_grad_norm_( 231 | amp.master_params(optimizer), hparams.grad_clip_thresh) 232 | is_overflow = math.isnan(grad_norm) 233 | else: 234 | grad_norm = torch.nn.utils.clip_grad_norm_( 235 | model.parameters(), hparams.grad_clip_thresh) 236 | 237 | optimizer.step() 238 | 239 | if not is_overflow and rank == 0: 240 | duration = time.perf_counter() - start 241 | print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format( 242 | iteration, reduced_loss, grad_norm, duration)) 243 | logger.log_training( 244 | reduced_loss, grad_norm, learning_rate, duration, iteration) 245 | 246 | if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0): 247 | validate(model, criterion, valset, iteration, 248 | hparams.batch_size, n_gpus, collate_fn, logger, 249 | hparams.distributed_run, rank) 250 | if rank == 0: 251 | checkpoint_path = os.path.join( 252 | output_directory, "checkpoint_{}".format(iteration)) 253 | save_checkpoint(model, optimizer, learning_rate, iteration, 254 | checkpoint_path) 255 | 256 | iteration += 1 257 | 258 | 259 | if __name__ == '__main__': 260 | parser = argparse.ArgumentParser() 261 | parser.add_argument('-o', '--output_directory', type=str, 262 | help='directory to save checkpoints') 263 | parser.add_argument('-l', '--log_directory', type=str, 264 | 
help='directory to save tensorboard logs') 265 | parser.add_argument('-c', '--checkpoint_path', type=str, default=None, 266 | required=False, help='checkpoint path') 267 | parser.add_argument('--warm_start', action='store_true', 268 | help='load model weights only, ignore specified layers') 269 | parser.add_argument('--n_gpus', type=int, default=1, 270 | required=False, help='number of gpus') 271 | parser.add_argument('--rank', type=int, default=0, 272 | required=False, help='rank of current gpu') 273 | parser.add_argument('--group_name', type=str, default='group_name', 274 | required=False, help='Distributed group name') 275 | parser.add_argument('--hparams', type=str, 276 | required=False, help='comma separated name=value pairs') 277 | 278 | args = parser.parse_args() 279 | hparams = create_hparams(args.hparams) 280 | 281 | torch.backends.cudnn.enabled = hparams.cudnn_enabled 282 | torch.backends.cudnn.benchmark = hparams.cudnn_benchmark 283 | 284 | print("FP16 Run:", hparams.fp16_run) 285 | print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling) 286 | print("Distributed Run:", hparams.distributed_run) 287 | print("cuDNN Enabled:", hparams.cudnn_enabled) 288 | print("cuDNN Benchmark:", hparams.cudnn_benchmark) 289 | 290 | train(args.output_directory, args.log_directory, args.checkpoint_path, 291 | args.warm_start, args.n_gpus, args.rank, args.group_name, hparams) 292 | -------------------------------------------------------------------------------- /nvidia_tacotron_TTS_Layout.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file '.\nvidia_tacotron_TTS_Layout.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.0 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing. 
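As a sketch of how the `--hparams` string is consumed (the override names below are illustrative; `create_hparams` is defined in `hparams.py` and called by `train.py` above):

```
from hparams import create_hparams

# Equivalent to passing --hparams=batch_size=32,fp16_run=True on the
# command line: overrides are applied on top of the defaults in hparams.py.
hparams = create_hparams("batch_size=32,fp16_run=True")
print(hparams.batch_size, hparams.fp16_run)
```

-------------------------------------------------------------------------------- /nvidia_tacotron_TTS_Layout.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Form implementation generated from reading ui file '.\nvidia_tacotron_TTS_Layout.ui' 4 | # 5 | # Created by: PyQt5 UI code generator 5.15.0 6 | # 7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is 8 | # run again. Do not edit this file unless you know what you are doing. 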
9 | 10 | 11 | from PyQt5 import QtCore, QtGui, QtWidgets 12 | 13 | 14 | class Ui_MainWindow(object): 15 | def setupUi(self, MainWindow): 16 | MainWindow.setObjectName("MainWindow") 17 | MainWindow.resize(518, 534) 18 | self.centralwidget = QtWidgets.QWidget(MainWindow) 19 | self.centralwidget.setObjectName("centralwidget") 20 | self.verticalLayout_3 = QtWidgets.QVBoxLayout(self.centralwidget) 21 | self.verticalLayout_3.setObjectName("verticalLayout_3") 22 | self.tabWidget = QtWidgets.QTabWidget(self.centralwidget) 23 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.MinimumExpanding) 24 | sizePolicy.setHorizontalStretch(0) 25 | sizePolicy.setVerticalStretch(0) 26 | sizePolicy.setHeightForWidth(self.tabWidget.sizePolicy().hasHeightForWidth()) 27 | self.tabWidget.setSizePolicy(sizePolicy) 28 | self.tabWidget.setObjectName("tabWidget") 29 | self.tab = QtWidgets.QWidget() 30 | self.tab.setObjectName("tab") 31 | self.verticalLayout = QtWidgets.QVBoxLayout(self.tab) 32 | self.verticalLayout.setObjectName("verticalLayout") 33 | self.gridLayout_5 = QtWidgets.QGridLayout() 34 | self.gridLayout_5.setObjectName("gridLayout_5") 35 | self.label_9 = QtWidgets.QLabel(self.tab) 36 | self.label_9.setObjectName("label_9") 37 | self.gridLayout_5.addWidget(self.label_9, 1, 0, 1, 1) 38 | self.WGModelCombo = QtWidgets.QComboBox(self.tab) 39 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed) 40 | sizePolicy.setHorizontalStretch(0) 41 | sizePolicy.setVerticalStretch(0) 42 | sizePolicy.setHeightForWidth(self.WGModelCombo.sizePolicy().hasHeightForWidth()) 43 | self.WGModelCombo.setSizePolicy(sizePolicy) 44 | self.WGModelCombo.setObjectName("WGModelCombo") 45 | self.gridLayout_5.addWidget(self.WGModelCombo, 1, 2, 1, 1) 46 | self.label_7 = QtWidgets.QLabel(self.tab) 47 | self.label_7.setObjectName("label_7") 48 | self.gridLayout_5.addWidget(self.label_7, 0, 0, 1, 1) 49 | self.TTModelCombo = QtWidgets.QComboBox(self.tab) 50 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed) 51 | sizePolicy.setHorizontalStretch(0) 52 | sizePolicy.setVerticalStretch(0) 53 | sizePolicy.setHeightForWidth(self.TTModelCombo.sizePolicy().hasHeightForWidth()) 54 | self.TTModelCombo.setSizePolicy(sizePolicy) 55 | self.TTModelCombo.setObjectName("TTModelCombo") 56 | self.gridLayout_5.addWidget(self.TTModelCombo, 0, 2, 1, 1) 57 | self.LoadTTButton = QtWidgets.QToolButton(self.tab) 58 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed) 59 | sizePolicy.setHorizontalStretch(0) 60 | sizePolicy.setVerticalStretch(0) 61 | sizePolicy.setHeightForWidth(self.LoadTTButton.sizePolicy().hasHeightForWidth()) 62 | self.LoadTTButton.setSizePolicy(sizePolicy) 63 | self.LoadTTButton.setLayoutDirection(QtCore.Qt.LeftToRight) 64 | self.LoadTTButton.setObjectName("LoadTTButton") 65 | self.gridLayout_5.addWidget(self.LoadTTButton, 0, 3, 1, 1) 66 | self.LoadWGButton = QtWidgets.QToolButton(self.tab) 67 | self.LoadWGButton.setObjectName("LoadWGButton") 68 | self.gridLayout_5.addWidget(self.LoadWGButton, 1, 3, 1, 1) 69 | self.verticalLayout.addLayout(self.gridLayout_5) 70 | self.horizontalLayout = QtWidgets.QHBoxLayout() 71 | self.horizontalLayout.setObjectName("horizontalLayout") 72 | spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) 73 | self.horizontalLayout.addItem(spacerItem) 74 | self.label = 
QtWidgets.QLabel(self.tab) 75 | self.label.setObjectName("label") 76 | self.horizontalLayout.addWidget(self.label) 77 | self.verticalLayout.addLayout(self.horizontalLayout) 78 | spacerItem1 = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed) 79 | self.verticalLayout.addItem(spacerItem1) 80 | self.label_2 = QtWidgets.QLabel(self.tab) 81 | self.label_2.setObjectName("label_2") 82 | self.verticalLayout.addWidget(self.label_2) 83 | self.TTSTextEdit = QtWidgets.QPlainTextEdit(self.tab) 84 | self.TTSTextEdit.setObjectName("TTSTextEdit") 85 | self.verticalLayout.addWidget(self.TTSTextEdit) 86 | self.gridLayout_2 = QtWidgets.QGridLayout() 87 | self.gridLayout_2.setObjectName("gridLayout_2") 88 | self.progressBar = QtWidgets.QProgressBar(self.tab) 89 | self.progressBar.setProperty("value", 0) 90 | self.progressBar.setObjectName("progressBar") 91 | self.gridLayout_2.addWidget(self.progressBar, 1, 0, 1, 1) 92 | self.progressBarLabel = QtWidgets.QLabel(self.tab) 93 | self.progressBarLabel.setObjectName("progressBarLabel") 94 | self.gridLayout_2.addWidget(self.progressBarLabel, 1, 0, 1, 1, QtCore.Qt.AlignHCenter) 95 | self.verticalLayout.addLayout(self.gridLayout_2) 96 | self.gridLayout_4 = QtWidgets.QGridLayout() 97 | self.gridLayout_4.setObjectName("gridLayout_4") 98 | spacerItem2 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) 99 | self.gridLayout_4.addItem(spacerItem2, 0, 0, 1, 1) 100 | self.TTSDialogButton = QtWidgets.QPushButton(self.tab) 101 | self.TTSDialogButton.setObjectName("TTSDialogButton") 102 | self.gridLayout_4.addWidget(self.TTSDialogButton, 0, 2, 1, 1, QtCore.Qt.AlignRight) 103 | self.TTSStopButton = QtWidgets.QPushButton(self.tab) 104 | self.TTSStopButton.setObjectName("TTSStopButton") 105 | self.gridLayout_4.addWidget(self.TTSStopButton, 0, 3, 1, 1) 106 | self.verticalLayout.addLayout(self.gridLayout_4) 107 | self.log_window1 = QtWidgets.QLabel(self.tab) 108 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed) 109 | sizePolicy.setHorizontalStretch(0) 110 | sizePolicy.setVerticalStretch(0) 111 | sizePolicy.setHeightForWidth(self.log_window1.sizePolicy().hasHeightForWidth()) 112 | self.log_window1.setSizePolicy(sizePolicy) 113 | self.log_window1.setMinimumSize(QtCore.QSize(0, 39)) 114 | self.log_window1.setAlignment(QtCore.Qt.AlignBottom|QtCore.Qt.AlignLeading|QtCore.Qt.AlignLeft) 115 | self.log_window1.setObjectName("log_window1") 116 | self.verticalLayout.addWidget(self.log_window1) 117 | self.tabWidget.addTab(self.tab, "") 118 | self.tab_2 = QtWidgets.QWidget() 119 | self.tab_2.setObjectName("tab_2") 120 | self.verticalLayout_2 = QtWidgets.QVBoxLayout(self.tab_2) 121 | self.verticalLayout_2.setObjectName("verticalLayout_2") 122 | self.gridLayout = QtWidgets.QGridLayout() 123 | self.gridLayout.setSizeConstraint(QtWidgets.QLayout.SetFixedSize) 124 | self.gridLayout.setObjectName("gridLayout") 125 | self.label_5 = QtWidgets.QLabel(self.tab_2) 126 | self.label_5.setObjectName("label_5") 127 | self.gridLayout.addWidget(self.label_5, 0, 0, 1, 1) 128 | self.ChannelName = QtWidgets.QLineEdit(self.tab_2) 129 | self.ChannelName.setObjectName("ChannelName") 130 | self.gridLayout.addWidget(self.ChannelName, 0, 1, 1, 1) 131 | self.label_3 = QtWidgets.QLabel(self.tab_2) 132 | self.label_3.setObjectName("label_3") 133 | self.gridLayout.addWidget(self.label_3, 1, 0, 1, 1) 134 | self.APIKeyLine = QtWidgets.QLineEdit(self.tab_2) 135 | 
self.APIKeyLine.setEchoMode(QtWidgets.QLineEdit.Password) 136 | self.APIKeyLine.setObjectName("APIKeyLine") 137 | self.gridLayout.addWidget(self.APIKeyLine, 1, 1, 1, 1) 138 | self.verticalLayout_2.addLayout(self.gridLayout) 139 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout() 140 | self.horizontalLayout_2.setObjectName("horizontalLayout_2") 141 | spacerItem3 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) 142 | self.horizontalLayout_2.addItem(spacerItem3) 143 | self.ClientStartBtn = QtWidgets.QPushButton(self.tab_2) 144 | self.ClientStartBtn.setObjectName("ClientStartBtn") 145 | self.horizontalLayout_2.addWidget(self.ClientStartBtn) 146 | self.ClientStopBtn = QtWidgets.QPushButton(self.tab_2) 147 | self.ClientStopBtn.setObjectName("ClientStopBtn") 148 | self.horizontalLayout_2.addWidget(self.ClientStopBtn, 0, QtCore.Qt.AlignRight) 149 | self.ClientSkipBtn = QtWidgets.QPushButton(self.tab_2) 150 | self.ClientSkipBtn.setObjectName("ClientSkipBtn") 151 | self.horizontalLayout_2.addWidget(self.ClientSkipBtn) 152 | self.verticalLayout_2.addLayout(self.horizontalLayout_2) 153 | spacerItem4 = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed) 154 | self.verticalLayout_2.addItem(spacerItem4) 155 | self.widget = QtWidgets.QWidget(self.tab_2) 156 | self.widget.setObjectName("widget") 157 | self.verticalLayout_2.addWidget(self.widget) 158 | self.label_8 = QtWidgets.QLabel(self.tab_2) 159 | self.label_8.setObjectName("label_8") 160 | self.verticalLayout_2.addWidget(self.label_8) 161 | self.ClientAmountLine = QtWidgets.QDoubleSpinBox(self.tab_2) 162 | self.ClientAmountLine.setObjectName("ClientAmountLine") 163 | self.verticalLayout_2.addWidget(self.ClientAmountLine) 164 | self.label_4 = QtWidgets.QLabel(self.tab_2) 165 | self.label_4.setObjectName("label_4") 166 | self.verticalLayout_2.addWidget(self.label_4) 167 | self.log_window2 = QtWidgets.QPlainTextEdit(self.tab_2) 168 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.MinimumExpanding) 169 | sizePolicy.setHorizontalStretch(0) 170 | sizePolicy.setVerticalStretch(0) 171 | sizePolicy.setHeightForWidth(self.log_window2.sizePolicy().hasHeightForWidth()) 172 | self.log_window2.setSizePolicy(sizePolicy) 173 | self.log_window2.setAutoFillBackground(False) 174 | self.log_window2.setObjectName("log_window2") 175 | self.verticalLayout_2.addWidget(self.log_window2) 176 | self.gridLayout_6 = QtWidgets.QGridLayout() 177 | self.gridLayout_6.setObjectName("gridLayout_6") 178 | self.progressBar2 = QtWidgets.QProgressBar(self.tab_2) 179 | self.progressBar2.setProperty("value", 0) 180 | self.progressBar2.setTextVisible(True) 181 | self.progressBar2.setOrientation(QtCore.Qt.Horizontal) 182 | self.progressBar2.setObjectName("progressBar2") 183 | self.gridLayout_6.addWidget(self.progressBar2, 0, 0, 1, 1) 184 | self.progressBar2Label = QtWidgets.QLabel(self.tab_2) 185 | self.progressBar2Label.setObjectName("progressBar2Label") 186 | self.gridLayout_6.addWidget(self.progressBar2Label, 0, 0, 1, 1, QtCore.Qt.AlignHCenter) 187 | self.verticalLayout_2.addLayout(self.gridLayout_6) 188 | self.tabWidget.addTab(self.tab_2, "") 189 | self.tab_3 = QtWidgets.QWidget() 190 | self.tab_3.setObjectName("tab_3") 191 | self.verticalLayout_4 = QtWidgets.QVBoxLayout(self.tab_3) 192 | self.verticalLayout_4.setObjectName("verticalLayout_4") 193 | self.groupBox = QtWidgets.QGroupBox(self.tab_3) 194 | sizePolicy = 
QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Fixed) 195 | sizePolicy.setHorizontalStretch(0) 196 | sizePolicy.setVerticalStretch(0) 197 | sizePolicy.setHeightForWidth(self.groupBox.sizePolicy().hasHeightForWidth()) 198 | self.groupBox.setSizePolicy(sizePolicy) 199 | self.groupBox.setMaximumSize(QtCore.QSize(16777215, 150)) 200 | self.groupBox.setObjectName("groupBox") 201 | self.gridLayout_7 = QtWidgets.QGridLayout(self.groupBox) 202 | self.gridLayout_7.setObjectName("gridLayout_7") 203 | self.OptLimitCpuBtn = QtWidgets.QCheckBox(self.groupBox) 204 | self.OptLimitCpuBtn.setMaximumSize(QtCore.QSize(67, 17)) 205 | self.OptLimitCpuBtn.setObjectName("OptLimitCpuBtn") 206 | self.gridLayout_7.addWidget(self.OptLimitCpuBtn, 0, 0, 1, 1) 207 | spacerItem5 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding) 208 | self.gridLayout_7.addItem(spacerItem5, 1, 0, 1, 1) 209 | spacerItem6 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) 210 | self.gridLayout_7.addItem(spacerItem6, 0, 4, 1, 1) 211 | self.label_10 = QtWidgets.QLabel(self.groupBox) 212 | self.label_10.setMaximumSize(QtCore.QSize(43, 16777215)) 213 | self.label_10.setObjectName("label_10") 214 | self.gridLayout_7.addWidget(self.label_10, 0, 2, 1, 1) 215 | self.OptLimitCpuCombo = QtWidgets.QComboBox(self.groupBox) 216 | self.OptLimitCpuCombo.setEnabled(True) 217 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed) 218 | sizePolicy.setHorizontalStretch(0) 219 | sizePolicy.setVerticalStretch(0) 220 | sizePolicy.setHeightForWidth(self.OptLimitCpuCombo.sizePolicy().hasHeightForWidth()) 221 | self.OptLimitCpuCombo.setSizePolicy(sizePolicy) 222 | self.OptLimitCpuCombo.setMaximumSize(QtCore.QSize(30, 20)) 223 | self.OptLimitCpuCombo.setObjectName("OptLimitCpuCombo") 224 | self.gridLayout_7.addWidget(self.OptLimitCpuCombo, 0, 3, 1, 1) 225 | spacerItem7 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Minimum) 226 | self.gridLayout_7.addItem(spacerItem7, 0, 1, 1, 1) 227 | self.verticalLayout_4.addWidget(self.groupBox) 228 | self.groupBox_2 = QtWidgets.QGroupBox(self.tab_3) 229 | self.groupBox_2.setObjectName("groupBox_2") 230 | self.verticalLayout_5 = QtWidgets.QVBoxLayout(self.groupBox_2) 231 | self.verticalLayout_5.setObjectName("verticalLayout_5") 232 | self.OptDonoNameAmountBtn = QtWidgets.QCheckBox(self.groupBox_2) 233 | self.OptDonoNameAmountBtn.setChecked(True) 234 | self.OptDonoNameAmountBtn.setObjectName("OptDonoNameAmountBtn") 235 | self.verticalLayout_5.addWidget(self.OptDonoNameAmountBtn) 236 | self.OptApproveDonoBtn = QtWidgets.QCheckBox(self.groupBox_2) 237 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed) 238 | sizePolicy.setHorizontalStretch(0) 239 | sizePolicy.setVerticalStretch(0) 240 | sizePolicy.setHeightForWidth(self.OptApproveDonoBtn.sizePolicy().hasHeightForWidth()) 241 | self.OptApproveDonoBtn.setSizePolicy(sizePolicy) 242 | self.OptApproveDonoBtn.setChecked(True) 243 | self.OptApproveDonoBtn.setObjectName("OptApproveDonoBtn") 244 | self.verticalLayout_5.addWidget(self.OptApproveDonoBtn) 245 | self.OptBlockNumberBtn = QtWidgets.QCheckBox(self.groupBox_2) 246 | self.OptBlockNumberBtn.setObjectName("OptBlockNumberBtn") 247 | self.verticalLayout_5.addWidget(self.OptBlockNumberBtn) 248 | spacerItem8 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, 
QtWidgets.QSizePolicy.Expanding) 249 | self.verticalLayout_5.addItem(spacerItem8) 250 | self.verticalLayout_4.addWidget(self.groupBox_2) 251 | self.tabWidget.addTab(self.tab_3, "") 252 | self.verticalLayout_3.addWidget(self.tabWidget) 253 | self.gridLayout_3 = QtWidgets.QGridLayout() 254 | self.gridLayout_3.setObjectName("gridLayout_3") 255 | self.statusbar = QtWidgets.QLabel(self.centralwidget) 256 | self.statusbar.setObjectName("statusbar") 257 | self.gridLayout_3.addWidget(self.statusbar, 0, 0, 1, 1, QtCore.Qt.AlignLeft) 258 | spacerItem9 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum) 259 | self.gridLayout_3.addItem(spacerItem9, 0, 1, 1, 1) 260 | self.label_6 = QtWidgets.QLabel(self.centralwidget) 261 | self.label_6.setOpenExternalLinks(True) 262 | self.label_6.setObjectName("label_6") 263 | self.gridLayout_3.addWidget(self.label_6, 0, 2, 1, 1, QtCore.Qt.AlignRight) 264 | self.verticalLayout_3.addLayout(self.gridLayout_3) 265 | MainWindow.setCentralWidget(self.centralwidget) 266 | 267 | self.retranslateUi(MainWindow) 268 | self.tabWidget.setCurrentIndex(0) 269 | self.setWindowTitle("Tacotron2 + Waveglow GUI v0.3") 270 | 271 | def retranslateUi(self, MainWindow): 272 | _translate = QtCore.QCoreApplication.translate 273 | MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow")) 274 | self.label_9.setText(_translate("MainWindow", "Waveglow Model:")) 275 | self.label_7.setText(_translate("MainWindow", "Tacotron 2 Model:")) 276 | self.LoadTTButton.setText(_translate("MainWindow", "Browse...")) 277 | self.LoadWGButton.setText(_translate("MainWindow", "Browse...")) 278 | self.label.setText(_translate("MainWindow", "GPU Mode (Requires CUDA)")) 279 | self.label_2.setText(_translate("MainWindow", "Enter Text to speech:")) 280 | self.progressBarLabel.setText(_translate("MainWindow", "TextLabel")) 281 | self.TTSDialogButton.setText(_translate("MainWindow", "Start")) 282 | self.TTSStopButton.setText(_translate("MainWindow", "Stop")) 283 | self.log_window1.setText(_translate("MainWindow", "Begin by loading a voice model")) 284 | self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab), _translate("MainWindow", "Text to Speech")) 285 | self.label_5.setText(_translate("MainWindow", "Channel name")) 286 | self.label_3.setText(_translate("MainWindow", "StreamElements JWT Token")) 287 | self.ClientStartBtn.setText(_translate("MainWindow", "Start")) 288 | self.ClientStopBtn.setText(_translate("MainWindow", "Stop")) 289 | self.ClientSkipBtn.setText(_translate("MainWindow", "Skip")) 290 | self.label_8.setText(_translate("MainWindow", "Minimum amount for TTS:")) 291 | self.label_4.setText(_translate("MainWindow", "Status:")) 292 | self.progressBar2Label.setText(_translate("MainWindow", "TextLabel")) 293 | self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab_2), _translate("MainWindow", "StreamElements")) 294 | self.groupBox.setTitle(_translate("MainWindow", "Pytorch")) 295 | self.OptLimitCpuBtn.setText(_translate("MainWindow", "Limit CPU")) 296 | self.label_10.setText(_translate("MainWindow", "Threads:")) 297 | self.groupBox_2.setTitle(_translate("MainWindow", "StreamElement")) 298 | self.OptDonoNameAmountBtn.setText(_translate("MainWindow", "Read donor\'s name and amount")) 299 | self.OptApproveDonoBtn.setText(_translate("MainWindow", "Approved donations only")) 300 | self.OptBlockNumberBtn.setText(_translate("MainWindow", "Block large numbers (>8 digits)")) 301 | self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab_3), 
_translate("MainWindow", "Options")) 302 | self.statusbar.setText(_translate("MainWindow", "Ready")) 303 | self.label_6.setText(_translate("MainWindow", "v0.3")) 304 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | import torch 3 | from torch.autograd import Variable 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from layers import ConvNorm, LinearNorm 7 | from utils import to_gpu, get_mask_from_lengths 8 | 9 | 10 | class LocationLayer(nn.Module): 11 | def __init__(self, attention_n_filters, attention_kernel_size, 12 | attention_dim): 13 | super(LocationLayer, self).__init__() 14 | padding = int((attention_kernel_size - 1) / 2) 15 | self.location_conv = ConvNorm(2, attention_n_filters, 16 | kernel_size=attention_kernel_size, 17 | padding=padding, bias=False, stride=1, 18 | dilation=1) 19 | self.location_dense = LinearNorm(attention_n_filters, attention_dim, 20 | bias=False, w_init_gain='tanh') 21 | 22 | def forward(self, attention_weights_cat): 23 | processed_attention = self.location_conv(attention_weights_cat) 24 | processed_attention = processed_attention.transpose(1, 2) 25 | processed_attention = self.location_dense(processed_attention) 26 | return processed_attention 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 31 | attention_location_n_filters, attention_location_kernel_size): 32 | super(Attention, self).__init__() 33 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 34 | bias=False, w_init_gain='tanh') 35 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, 36 | w_init_gain='tanh') 37 | self.v = LinearNorm(attention_dim, 1, bias=False) 38 | self.location_layer = LocationLayer(attention_location_n_filters, 39 | attention_location_kernel_size, 40 | attention_dim) 41 | self.score_mask_value = -float("inf") 42 | 43 | def get_alignment_energies(self, query, processed_memory, 44 | attention_weights_cat): 45 | """ 46 | PARAMS 47 | ------ 48 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 49 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 50 | attention_weights_cat: cumulative and prev. 
att weights (B, 2, max_time)
51 | 
52 |         RETURNS
53 |         -------
54 |         alignment (batch, max_time)
55 |         """
56 | 
57 |         processed_query = self.query_layer(query.unsqueeze(1))
58 |         processed_attention_weights = self.location_layer(attention_weights_cat)
59 |         energies = self.v(torch.tanh(
60 |             processed_query + processed_attention_weights + processed_memory))
61 | 
62 |         energies = energies.squeeze(-1)
63 |         return energies
64 | 
65 |     def forward(self, attention_hidden_state, memory, processed_memory,
66 |                 attention_weights_cat, mask):
67 |         """
68 |         PARAMS
69 |         ------
70 |         attention_hidden_state: attention rnn last output
71 |         memory: encoder outputs
72 |         processed_memory: processed encoder outputs
73 |         attention_weights_cat: previous and cumulative attention weights
74 |         mask: binary mask for padded data
75 |         """
76 |         alignment = self.get_alignment_energies(
77 |             attention_hidden_state, processed_memory, attention_weights_cat)
78 | 
79 |         if mask is not None:
80 |             alignment.data.masked_fill_(mask, self.score_mask_value)
81 | 
82 |         attention_weights = F.softmax(alignment, dim=1)
83 |         attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
84 |         attention_context = attention_context.squeeze(1)
85 | 
86 |         return attention_context, attention_weights
87 | 
88 | 
89 | class Prenet(nn.Module):
90 |     def __init__(self, in_dim, sizes):
91 |         super(Prenet, self).__init__()
92 |         in_sizes = [in_dim] + sizes[:-1]
93 |         self.layers = nn.ModuleList(
94 |             [LinearNorm(in_size, out_size, bias=False)
95 |              for (in_size, out_size) in zip(in_sizes, sizes)])
96 | 
97 |     def forward(self, x):
98 |         for linear in self.layers:
99 |             x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
100 |         return x
101 | 
102 | 
103 | class Postnet(nn.Module):
104 |     """Postnet
105 |         - Five 1-d convolutions with 512 channels and kernel size 5
106 |     """
107 | 
108 |     def __init__(self, hparams):
109 |         super(Postnet, self).__init__()
110 |         self.convolutions = nn.ModuleList()
111 | 
112 |         self.convolutions.append(
113 |             nn.Sequential(
114 |                 ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
115 |                          kernel_size=hparams.postnet_kernel_size, stride=1,
116 |                          padding=int((hparams.postnet_kernel_size - 1) / 2),
117 |                          dilation=1, w_init_gain='tanh'),
118 |                 nn.BatchNorm1d(hparams.postnet_embedding_dim))
119 |         )
120 | 
121 |         for i in range(1, hparams.postnet_n_convolutions - 1):
122 |             self.convolutions.append(
123 |                 nn.Sequential(
124 |                     ConvNorm(hparams.postnet_embedding_dim,
125 |                              hparams.postnet_embedding_dim,
126 |                              kernel_size=hparams.postnet_kernel_size, stride=1,
127 |                              padding=int((hparams.postnet_kernel_size - 1) / 2),
128 |                              dilation=1, w_init_gain='tanh'),
129 |                     nn.BatchNorm1d(hparams.postnet_embedding_dim))
130 |             )
131 | 
132 |         self.convolutions.append(
133 |             nn.Sequential(
134 |                 ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
135 |                          kernel_size=hparams.postnet_kernel_size, stride=1,
136 |                          padding=int((hparams.postnet_kernel_size - 1) / 2),
137 |                          dilation=1, w_init_gain='linear'),
138 |                 nn.BatchNorm1d(hparams.n_mel_channels))
139 |         )
140 | 
141 |     def forward(self, x):
142 |         for i in range(len(self.convolutions) - 1):
143 |             x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
144 |         x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
145 | 
146 |         return x
147 | 
148 | 
149 | class Encoder(nn.Module):
150 |     """Encoder module:
151 |         - Three 1-d convolution banks
152 |         - Bidirectional LSTM
153 |     """
154 |     def __init__(self, hparams):
155 |         super(Encoder, self).__init__()
156 | 
157 |         convolutions = []
158 |         for
_ in range(hparams.encoder_n_convolutions):
159 |             conv_layer = nn.Sequential(
160 |                 ConvNorm(hparams.encoder_embedding_dim,
161 |                          hparams.encoder_embedding_dim,
162 |                          kernel_size=hparams.encoder_kernel_size, stride=1,
163 |                          padding=int((hparams.encoder_kernel_size - 1) / 2),
164 |                          dilation=1, w_init_gain='relu'),
165 |                 nn.BatchNorm1d(hparams.encoder_embedding_dim))
166 |             convolutions.append(conv_layer)
167 |         self.convolutions = nn.ModuleList(convolutions)
168 | 
169 |         self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
170 |                             int(hparams.encoder_embedding_dim / 2), 1,
171 |                             batch_first=True, bidirectional=True)
172 | 
173 |     def forward(self, x, input_lengths):
174 |         for conv in self.convolutions:
175 |             x = F.dropout(F.relu(conv(x)), 0.5, self.training)
176 | 
177 |         x = x.transpose(1, 2)
178 | 
179 |         # pytorch tensors are not reversible, hence the conversion
180 |         input_lengths = input_lengths.cpu().numpy()
181 |         x = nn.utils.rnn.pack_padded_sequence(
182 |             x, input_lengths, batch_first=True)
183 | 
184 |         self.lstm.flatten_parameters()
185 |         outputs, _ = self.lstm(x)
186 | 
187 |         outputs, _ = nn.utils.rnn.pad_packed_sequence(
188 |             outputs, batch_first=True)
189 | 
190 |         return outputs
191 | 
192 |     def inference(self, x):
193 |         for conv in self.convolutions:
194 |             x = F.dropout(F.relu(conv(x)), 0.5, self.training)
195 | 
196 |         x = x.transpose(1, 2)
197 | 
198 |         self.lstm.flatten_parameters()
199 |         outputs, _ = self.lstm(x)
200 | 
201 |         return outputs
202 | 
203 | 
204 | class Decoder(nn.Module):
205 |     def __init__(self, hparams):
206 |         super(Decoder, self).__init__()
207 |         self.n_mel_channels = hparams.n_mel_channels
208 |         self.n_frames_per_step = hparams.n_frames_per_step
209 |         self.encoder_embedding_dim = hparams.encoder_embedding_dim
210 |         self.attention_rnn_dim = hparams.attention_rnn_dim
211 |         self.decoder_rnn_dim = hparams.decoder_rnn_dim
212 |         self.prenet_dim = hparams.prenet_dim
213 |         self.max_decoder_steps = hparams.max_decoder_steps
214 |         self.gate_threshold = hparams.gate_threshold
215 |         self.p_attention_dropout = hparams.p_attention_dropout
216 |         self.p_decoder_dropout = hparams.p_decoder_dropout
217 | 
218 |         self.prenet = Prenet(
219 |             hparams.n_mel_channels * hparams.n_frames_per_step,
220 |             [hparams.prenet_dim, hparams.prenet_dim])
221 | 
222 |         self.attention_rnn = nn.LSTMCell(
223 |             hparams.prenet_dim + hparams.encoder_embedding_dim,
224 |             hparams.attention_rnn_dim)
225 | 
226 |         self.attention_layer = Attention(
227 |             hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
228 |             hparams.attention_dim, hparams.attention_location_n_filters,
229 |             hparams.attention_location_kernel_size)
230 | 
231 |         self.decoder_rnn = nn.LSTMCell(
232 |             hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
233 |             hparams.decoder_rnn_dim, 1)
234 | 
235 |         self.linear_projection = LinearNorm(
236 |             hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
237 |             hparams.n_mel_channels * hparams.n_frames_per_step)
238 | 
239 |         self.gate_layer = LinearNorm(
240 |             hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
241 |             bias=True, w_init_gain='sigmoid')
242 | 
243 |     def get_go_frame(self, memory):
244 |         """ Gets all zeros frames to use as first decoder input
245 |         PARAMS
246 |         ------
247 |         memory: encoder outputs
248 | 
249 |         RETURNS
250 |         -------
251 |         decoder_input: all zeros frames
252 |         """
253 |         B = memory.size(0)
254 |         decoder_input = Variable(memory.data.new(
255 |             B, self.n_mel_channels * self.n_frames_per_step).zero_())
256 |         return decoder_input
257 | 
258 |     def initialize_decoder_states(self, memory,
mask):
259 |         """ Initializes attention rnn states, decoder rnn states, attention
260 |         weights, attention cumulative weights, attention context, stores memory
261 |         and stores processed memory
262 |         PARAMS
263 |         ------
264 |         memory: Encoder outputs
265 |         mask: Mask for padded data if training, expects None for inference
266 |         """
267 |         B = memory.size(0)
268 |         MAX_TIME = memory.size(1)
269 | 
270 |         self.attention_hidden = Variable(memory.data.new(
271 |             B, self.attention_rnn_dim).zero_())
272 |         self.attention_cell = Variable(memory.data.new(
273 |             B, self.attention_rnn_dim).zero_())
274 | 
275 |         self.decoder_hidden = Variable(memory.data.new(
276 |             B, self.decoder_rnn_dim).zero_())
277 |         self.decoder_cell = Variable(memory.data.new(
278 |             B, self.decoder_rnn_dim).zero_())
279 | 
280 |         self.attention_weights = Variable(memory.data.new(
281 |             B, MAX_TIME).zero_())
282 |         self.attention_weights_cum = Variable(memory.data.new(
283 |             B, MAX_TIME).zero_())
284 |         self.attention_context = Variable(memory.data.new(
285 |             B, self.encoder_embedding_dim).zero_())
286 | 
287 |         self.memory = memory
288 |         self.processed_memory = self.attention_layer.memory_layer(memory)
289 |         self.mask = mask
290 | 
291 |     def parse_decoder_inputs(self, decoder_inputs):
292 |         """ Prepares decoder inputs, i.e. mel outputs
293 |         PARAMS
294 |         ------
295 |         decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
296 | 
297 |         RETURNS
298 |         -------
299 |         inputs: processed decoder inputs
300 | 
301 |         """
302 |         # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
303 |         decoder_inputs = decoder_inputs.transpose(1, 2)
304 |         decoder_inputs = decoder_inputs.view(
305 |             decoder_inputs.size(0),
306 |             int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
307 |         # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
308 |         decoder_inputs = decoder_inputs.transpose(0, 1)
309 |         return decoder_inputs
310 | 
311 |     def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
312 |         """ Prepares decoder outputs for output
313 |         PARAMS
314 |         ------
315 |         mel_outputs: list of mel frames from each decoder step
316 |         gate_outputs: gate output energies
317 |         alignments: list of attention weights from each decoder step
318 | 
319 |         RETURNS
320 |         -------
321 |         mel_outputs: batch-first mel spectrogram (B, n_mel_channels, T_out)
322 |         gate_outputs: gate output energies
323 |         alignments: batch-first attention weights
324 |         """
325 |         # (T_out, B) -> (B, T_out)
326 |         alignments = torch.stack(alignments).transpose(0, 1)
327 |         # (T_out, B) -> (B, T_out)
328 |         gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
329 |         gate_outputs = gate_outputs.contiguous()
330 |         # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
331 |         mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
332 |         # decouple frames per step
333 |         mel_outputs = mel_outputs.view(
334 |             mel_outputs.size(0), -1, self.n_mel_channels)
335 |         # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
336 |         mel_outputs = mel_outputs.transpose(1, 2)
337 | 
338 |         return mel_outputs, gate_outputs, alignments
339 | 
340 |     def decode(self, decoder_input):
341 |         """ Decoder step using stored states, attention and memory
342 |         PARAMS
343 |         ------
344 |         decoder_input: previous mel output
345 | 
346 |         RETURNS
347 |         -------
348 |         mel_output: predicted mel frame(s) for this step
349 |         gate_output: gate output energies
350 |         attention_weights: attention weights for this step
351 |         """
352 |         cell_input = torch.cat((decoder_input, self.attention_context), -1)
353 |         self.attention_hidden, self.attention_cell = self.attention_rnn(
354 |             cell_input, (self.attention_hidden, self.attention_cell))
355 |         self.attention_hidden = F.dropout(
356 |             self.attention_hidden, self.p_attention_dropout, self.training)
357 | 
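        # Shape walk-through (illustrative, with batch B and encoder length
        # max_time): attention_weights and attention_weights_cum are each
        # (B, max_time); unsqueeze(1) plus cat below gives attention_weights_cat
        # of shape (B, 2, max_time). E.g. for B=1, max_time=5:
        #   attention_weights     -> [[0.0, 0.9, 0.1, 0.0, 0.0]]
        #   attention_weights_cum -> [[1.0, 0.9, 0.1, 0.0, 0.0]]
        # The location layer convolves these two channels, which is what makes
        # the attention "location sensitive".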
358 | attention_weights_cat = torch.cat( 359 | (self.attention_weights.unsqueeze(1), 360 | self.attention_weights_cum.unsqueeze(1)), dim=1) 361 | self.attention_context, self.attention_weights = self.attention_layer( 362 | self.attention_hidden, self.memory, self.processed_memory, 363 | attention_weights_cat, self.mask) 364 | 365 | self.attention_weights_cum += self.attention_weights 366 | decoder_input = torch.cat( 367 | (self.attention_hidden, self.attention_context), -1) 368 | self.decoder_hidden, self.decoder_cell = self.decoder_rnn( 369 | decoder_input, (self.decoder_hidden, self.decoder_cell)) 370 | self.decoder_hidden = F.dropout( 371 | self.decoder_hidden, self.p_decoder_dropout, self.training) 372 | 373 | decoder_hidden_attention_context = torch.cat( 374 | (self.decoder_hidden, self.attention_context), dim=1) 375 | decoder_output = self.linear_projection( 376 | decoder_hidden_attention_context) 377 | 378 | gate_prediction = self.gate_layer(decoder_hidden_attention_context) 379 | return decoder_output, gate_prediction, self.attention_weights 380 | 381 | def forward(self, memory, decoder_inputs, memory_lengths): 382 | """ Decoder forward pass for training 383 | PARAMS 384 | ------ 385 | memory: Encoder outputs 386 | decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs 387 | memory_lengths: Encoder output lengths for attention masking. 388 | 389 | RETURNS 390 | ------- 391 | mel_outputs: mel outputs from the decoder 392 | gate_outputs: gate outputs from the decoder 393 | alignments: sequence of attention weights from the decoder 394 | """ 395 | 396 | decoder_input = self.get_go_frame(memory).unsqueeze(0) 397 | decoder_inputs = self.parse_decoder_inputs(decoder_inputs) 398 | decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) 399 | decoder_inputs = self.prenet(decoder_inputs) 400 | 401 | self.initialize_decoder_states( 402 | memory, mask=~get_mask_from_lengths(memory_lengths)) 403 | 404 | mel_outputs, gate_outputs, alignments = [], [], [] 405 | while len(mel_outputs) < decoder_inputs.size(0) - 1: 406 | decoder_input = decoder_inputs[len(mel_outputs)] 407 | mel_output, gate_output, attention_weights = self.decode( 408 | decoder_input) 409 | mel_outputs += [mel_output.squeeze(1)] 410 | gate_outputs += [gate_output.squeeze(1)] 411 | alignments += [attention_weights] 412 | 413 | mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( 414 | mel_outputs, gate_outputs, alignments) 415 | 416 | return mel_outputs, gate_outputs, alignments 417 | 418 | def inference(self, memory): 419 | """ Decoder inference 420 | PARAMS 421 | ------ 422 | memory: Encoder outputs 423 | 424 | RETURNS 425 | ------- 426 | mel_outputs: mel outputs from the decoder 427 | gate_outputs: gate outputs from the decoder 428 | alignments: sequence of attention weights from the decoder 429 | """ 430 | decoder_input = self.get_go_frame(memory) 431 | 432 | self.initialize_decoder_states(memory, mask=None) 433 | 434 | mel_outputs, gate_outputs, alignments = [], [], [] 435 | while True: 436 | decoder_input = self.prenet(decoder_input) 437 | mel_output, gate_output, alignment = self.decode(decoder_input) 438 | 439 | mel_outputs += [mel_output.squeeze(1)] 440 | gate_outputs += [gate_output] 441 | alignments += [alignment] 442 | 443 | if torch.sigmoid(gate_output.data) > self.gate_threshold: 444 | break 445 | elif len(mel_outputs) == self.max_decoder_steps: 446 | print("Warning! 
Reached max decoder steps") 447 | break 448 | 449 | decoder_input = mel_output 450 | 451 | mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( 452 | mel_outputs, gate_outputs, alignments) 453 | 454 | return mel_outputs, gate_outputs, alignments 455 | 456 | 457 | class Tacotron2(nn.Module): 458 | def __init__(self, hparams): 459 | super(Tacotron2, self).__init__() 460 | self.mask_padding = hparams.mask_padding 461 | self.fp16_run = hparams.fp16_run 462 | self.n_mel_channels = hparams.n_mel_channels 463 | self.n_frames_per_step = hparams.n_frames_per_step 464 | self.embedding = nn.Embedding( 465 | hparams.n_symbols, hparams.symbols_embedding_dim) 466 | std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) 467 | val = sqrt(3.0) * std # uniform bounds for std 468 | self.embedding.weight.data.uniform_(-val, val) 469 | self.encoder = Encoder(hparams) 470 | self.decoder = Decoder(hparams) 471 | self.postnet = Postnet(hparams) 472 | 473 | def parse_batch(self, batch): 474 | text_padded, input_lengths, mel_padded, gate_padded, \ 475 | output_lengths = batch 476 | text_padded = to_gpu(text_padded).long() 477 | input_lengths = to_gpu(input_lengths).long() 478 | max_len = torch.max(input_lengths.data).item() 479 | mel_padded = to_gpu(mel_padded).float() 480 | gate_padded = to_gpu(gate_padded).float() 481 | output_lengths = to_gpu(output_lengths).long() 482 | 483 | return ( 484 | (text_padded, input_lengths, mel_padded, max_len, output_lengths), 485 | (mel_padded, gate_padded)) 486 | 487 | def parse_output(self, outputs, output_lengths=None): 488 | if self.mask_padding and output_lengths is not None: 489 | mask = ~get_mask_from_lengths(output_lengths) 490 | mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1)) 491 | mask = mask.permute(1, 0, 2) 492 | 493 | outputs[0].data.masked_fill_(mask, 0.0) 494 | outputs[1].data.masked_fill_(mask, 0.0) 495 | outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies 496 | 497 | return outputs 498 | 499 | def forward(self, inputs): 500 | text_inputs, text_lengths, mels, max_len, output_lengths = inputs 501 | text_lengths, output_lengths = text_lengths.data, output_lengths.data 502 | 503 | embedded_inputs = self.embedding(text_inputs).transpose(1, 2) 504 | 505 | encoder_outputs = self.encoder(embedded_inputs, text_lengths) 506 | 507 | mel_outputs, gate_outputs, alignments = self.decoder( 508 | encoder_outputs, mels, memory_lengths=text_lengths) 509 | 510 | mel_outputs_postnet = self.postnet(mel_outputs) 511 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet 512 | 513 | return self.parse_output( 514 | [mel_outputs, mel_outputs_postnet, gate_outputs, alignments], 515 | output_lengths) 516 | 517 | def inference(self, inputs): 518 | embedded_inputs = self.embedding(inputs).transpose(1, 2) 519 | encoder_outputs = self.encoder.inference(embedded_inputs) 520 | mel_outputs, gate_outputs, alignments = self.decoder.inference( 521 | encoder_outputs) 522 | 523 | mel_outputs_postnet = self.postnet(mel_outputs) 524 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet 525 | 526 | outputs = self.parse_output( 527 | [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]) 528 | 529 | return outputs 530 | -------------------------------------------------------------------------------- /utils_hparam.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Tensor2Tensor Authors. 
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Forked with minor changes from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py  # pylint: disable=line-too-long
17 | """Hyperparameter values."""
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | import json
23 | import numbers
24 | import re
25 | import six
26 | 
27 | # Define the regular expression for parsing a single clause of the input
28 | # (delimited by commas). A legal clause looks like:
29 | #   <variable name>[<index>]? = <rhs>
30 | # where <rhs> is either a single token or a []-enclosed list of tokens.
31 | # For example: "var[1] = a" or "x = [1,2,3]"
32 | PARAM_RE = re.compile(r"""
33 |   (?P<name>[a-zA-Z][\w\.]*)      # variable name: "var" or "x"
34 |   (\[\s*(?P<index>\d+)\s*\])?    # (optional) index: "1" or None
35 |   \s*=\s*
36 |   ((?P<val>[^,\[]*)              # single value: "a" or None
37 |    |
38 |    \[(?P<vals>[^\]]*)\])         # list of values: None or "1,2,3"
39 |   ($|,\s*)""", re.VERBOSE)
40 | 
41 | 
42 | def _parse_fail(name, var_type, value, values):
43 |   """Helper function for raising a value error for bad assignment."""
44 |   raise ValueError(
45 |       'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' %
46 |       (name, var_type.__name__, value, values))
47 | 
48 | 
49 | def _reuse_fail(name, values):
50 |   """Helper function for raising a value error for reuse of name."""
51 |   raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name,
52 |                                                                       values))
53 | 
54 | 
55 | def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
56 |                           results_dictionary):
57 |   """Update results_dictionary with a scalar value.
58 | 
59 |   Used to update the results_dictionary to be returned by parse_values when
60 |   encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".)
61 | 
62 |   Mutates results_dictionary.
63 | 
64 |   Args:
65 |     name: Name of variable in assignment ("s" or "arr").
66 |     parse_fn: Function for parsing the actual value.
67 |     var_type: Type of named variable.
68 |     m_dict: Dictionary constructed from regex parsing.
69 |       m_dict['val']: RHS value (scalar)
70 |       m_dict['index']: List index value (or None)
71 |     values: Full expression being parsed
72 |     results_dictionary: The dictionary being updated for return by the parsing
73 |       function.
74 | 
75 |   Raises:
76 |     ValueError: If the name has already been used.
77 |   """
78 |   try:
79 |     parsed_value = parse_fn(m_dict['val'])
80 |   except ValueError:
81 |     _parse_fail(name, var_type, m_dict['val'], values)
82 | 
83 |   # If no index is provided
84 |   if not m_dict['index']:
85 |     if name in results_dictionary:
86 |       _reuse_fail(name, values)
87 |     results_dictionary[name] = parsed_value
88 |   else:
89 |     if name in results_dictionary:
90 |       # If the name has already been used as a scalar, it will be in this
91 |       # dictionary and map to a non-dictionary value.
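      # Illustration (hypothetical input): for "a=1,a[0]=2" the clause
      # "a[0]=2" reaches this branch with results_dictionary == {'a': 1},
      # so the isinstance check below fails and _reuse_fail raises, while
      # "arr[0]=5,arr[1]=6" builds results_dictionary == {'arr': {0: 5, 1: 6}}.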
92 | if not isinstance(results_dictionary.get(name), dict): 93 | _reuse_fail(name, values) 94 | else: 95 | results_dictionary[name] = {} 96 | 97 | index = int(m_dict['index']) 98 | # Make sure the index position hasn't already been assigned a value. 99 | if index in results_dictionary[name]: 100 | _reuse_fail('{}[{}]'.format(name, index), values) 101 | results_dictionary[name][index] = parsed_value 102 | 103 | 104 | def _process_list_value(name, parse_fn, var_type, m_dict, values, 105 | results_dictionary): 106 | """Update results_dictionary from a list of values. 107 | 108 | Used to update results_dictionary to be returned by parse_values when 109 | encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) 110 | 111 | Mutates results_dictionary. 112 | 113 | Args: 114 | name: Name of variable in assignment ("arr"). 115 | parse_fn: Function for parsing individual values. 116 | var_type: Type of named variable. 117 | m_dict: Dictionary constructed from regex parsing. 118 | m_dict['val']: RHS value (scalar) 119 | values: Full expression being parsed 120 | results_dictionary: The dictionary being updated for return by the parsing 121 | function. 122 | 123 | Raises: 124 | ValueError: If the name has an index or the values cannot be parsed. 125 | """ 126 | if m_dict['index'] is not None: 127 | raise ValueError('Assignment of a list to a list index.') 128 | elements = filter(None, re.split('[ ,]', m_dict['vals'])) 129 | # Make sure the name hasn't already been assigned a value 130 | if name in results_dictionary: 131 | raise _reuse_fail(name, values) 132 | try: 133 | results_dictionary[name] = [parse_fn(e) for e in elements] 134 | except ValueError: 135 | _parse_fail(name, var_type, m_dict['vals'], values) 136 | 137 | 138 | def _cast_to_type_if_compatible(name, param_type, value): 139 | """Cast hparam to the provided type, if compatible. 140 | 141 | Args: 142 | name: Name of the hparam to be cast. 143 | param_type: The type of the hparam. 144 | value: The value to be cast, if compatible. 145 | 146 | Returns: 147 | The result of casting `value` to `param_type`. 148 | 149 | Raises: 150 | ValueError: If the type of `value` is not compatible with param_type. 151 | * If `param_type` is a string type, but `value` is not. 152 | * If `param_type` is a boolean, but `value` is not, or vice versa. 153 | * If `param_type` is an integer type, but `value` is not. 154 | * If `param_type` is a float type, but `value` is not a numeric type. 155 | """ 156 | fail_msg = ( 157 | "Could not cast hparam '%s' of type '%s' from value %r" % 158 | (name, param_type, value)) 159 | 160 | # Some callers use None, for which we can't do any casting/checking. :( 161 | if issubclass(param_type, type(None)): 162 | return value 163 | 164 | # Avoid converting a non-string type to a string. 165 | if (issubclass(param_type, (six.string_types, six.binary_type)) and 166 | not isinstance(value, (six.string_types, six.binary_type))): 167 | raise ValueError(fail_msg) 168 | 169 | # Avoid converting a number or string type to a boolean or vice versa. 170 | if issubclass(param_type, bool) != isinstance(value, bool): 171 | raise ValueError(fail_msg) 172 | 173 | # Avoid converting float to an integer (the reverse is fine). 174 | if (issubclass(param_type, numbers.Integral) and 175 | not isinstance(value, numbers.Integral)): 176 | raise ValueError(fail_msg) 177 | 178 | # Avoid converting a non-numeric type to a numeric type. 
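  # Illustration of the casting rules in this function (hypothetical calls):
  #   _cast_to_type_if_compatible('lr', float, 1)    -> 1.0 (int widens to float)
  #   _cast_to_type_if_compatible('steps', int, 1.5) -> ValueError (no float->int)
  #   _cast_to_type_if_compatible('flag', bool, 'y') -> ValueError (no str->bool)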
179 |   if (issubclass(param_type, numbers.Number) and
180 |       not isinstance(value, numbers.Number)):
181 |     raise ValueError(fail_msg)
182 | 
183 |   return param_type(value)
184 | 
185 | 
186 | def parse_values(values, type_map, ignore_unknown=False):
187 |   """Parses hyperparameter values from a string into a python map.
188 | 
189 |   `values` is a string containing comma-separated `name=value` pairs.
190 |   For each pair, the value of the hyperparameter named `name` is set to
191 |   `value`.
192 | 
193 |   If a hyperparameter name appears multiple times in `values`, a ValueError
194 |   is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
195 | 
196 |   If a hyperparameter name is used in both an index assignment and a scalar
197 |   assignment, a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
198 | 
199 |   The hyperparameter name may contain '.' symbols, which will result in an
200 |   attribute name that is only accessible through the getattr and setattr
201 |   functions. (And must first be explicitly added through add_hparam.)
202 | 
203 |   WARNING: Use of '.' in your variable names is allowed, but is not well
204 |   supported and not recommended.
205 | 
206 |   The `value` in `name=value` must follow the syntax according to the
207 |   type of the parameter:
208 | 
209 |   * Scalar integer: A Python-parsable integer value. E.g.: 1,
210 |     100, -12.
211 |   * Scalar float: A Python-parsable floating point value. E.g.: 1.0,
212 |     -.54e89.
213 |   * Boolean: Either true or false.
214 |   * Scalar string: A non-empty sequence of characters, excluding comma,
215 |     spaces, and square brackets. E.g.: foo, bar_1.
216 |   * List: A comma separated list of scalar values of the parameter type
217 |     enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].
218 | 
219 |   When index assignment is used, the corresponding type_map key should be the
220 |   list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
221 |   "arr[1]").
222 | 
223 |   Args:
224 |     values: String. Comma separated list of `name=value` pairs where
225 |       'value' must follow the syntax described above.
226 |     type_map: A dictionary mapping hyperparameter names to types. Note every
227 |       parameter name in values must be a key in type_map. The values must
228 |       conform to the types indicated, where a value V is said to conform to a
229 |       type T if either V has type T, or V is a list of elements of type T.
230 |       Hence, for a multidimensional parameter 'x' taking float values,
231 |       'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
232 |     ignore_unknown: Bool. Whether values that are missing a type in type_map
233 |       should be ignored. If set to True, a ValueError will not be raised for
234 |       unknown hyperparameters.
235 | 
236 |   Returns:
237 |     A python map mapping each name to either:
238 |     * A scalar value.
239 |     * A list of scalar values.
240 |     * A dictionary mapping index numbers to scalar values.
241 |     (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}})
242 | 
243 |   Raises:
244 |     ValueError: If there is a problem with input.
245 |     * If `values` cannot be parsed.
246 |     * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
247 |     * If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
248 |       'a[1]=1,a[1]=2', or 'a=1,a=[1]')
249 |   """
250 |   results_dictionary = {}
251 |   pos = 0
252 |   while pos < len(values):
253 |     m = PARAM_RE.match(values, pos)
254 |     if not m:
255 |       raise ValueError('Malformed hyperparameter value: %s' % values[pos:])
256 |     # Check that there is a comma between parameters and move past it.
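    # Worked example (illustration): for values = "x=5,L=[1,2],arr[1]=3" the
    # first match's groupdict() is {'name': 'x', 'index': None, 'val': '5',
    # 'vals': None}, and m.end() sits just past the comma consumed by the
    # ($|,\s*) tail, so the next iteration starts at "L=[1,2]".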
257 |     pos = m.end()
258 |     # Parse the values.
259 |     m_dict = m.groupdict()
260 |     name = m_dict['name']
261 |     if name not in type_map:
262 |       if ignore_unknown:
263 |         continue
264 |       raise ValueError('Unknown hyperparameter type for %s' % name)
265 |     type_ = type_map[name]
266 | 
267 |     # Set up correct parsing function (depending on whether type_ is a bool)
268 |     if type_ == bool:
269 | 
270 |       def parse_bool(value):
271 |         if value in ['true', 'True']:
272 |           return True
273 |         elif value in ['false', 'False']:
274 |           return False
275 |         else:
276 |           try:
277 |             return bool(int(value))
278 |           except ValueError:
279 |             _parse_fail(name, type_, value, values)
280 | 
281 |       parse = parse_bool
282 |     else:
283 |       parse = type_
284 | 
285 |     # If a single value is provided
286 |     if m_dict['val'] is not None:
287 |       _process_scalar_value(name, parse, type_, m_dict, values,
288 |                             results_dictionary)
289 | 
290 |     # If the assigned value is a list:
291 |     elif m_dict['vals'] is not None:
292 |       _process_list_value(name, parse, type_, m_dict, values,
293 |                           results_dictionary)
294 | 
295 |     else:  # Not assigned a list or value
296 |       _parse_fail(name, type_, '', values)
297 | 
298 |   return results_dictionary
299 | 
300 | 
301 | class HParams(object):
302 |   """Class to hold a set of hyperparameters as name-value pairs.
303 | 
304 |   A `HParams` object holds hyperparameters used to build and train a model,
305 |   such as the number of hidden units in a neural net layer or the learning rate
306 |   to use when training.
307 | 
308 |   You first create a `HParams` object by specifying the names and values of the
309 |   hyperparameters.
310 | 
311 |   To make them easily accessible the parameter names are added as direct
312 |   attributes of the class. A typical usage is as follows:
313 | 
314 |   ```python
315 |   # Create a HParams object specifying names and values of the model
316 |   # hyperparameters:
317 |   hparams = HParams(learning_rate=0.1, num_hidden_units=100)
318 | 
319 |   # The hyperparameters are available as attributes of the HParams object:
320 |   hparams.learning_rate ==> 0.1
321 |   hparams.num_hidden_units ==> 100
322 |   ```
323 | 
324 |   Hyperparameters have type, which is inferred from the type of their value
325 |   passed at construction time. The currently supported types are: integer,
326 |   float, boolean, string, and list of integer, float, boolean, or string.
327 | 
328 |   You can override hyperparameter values by calling the
329 |   [`parse()`](#HParams.parse) method, passing a string of comma separated
330 |   `name=value` pairs. This is intended to make it possible to override
331 |   any hyperparameter values from a single command-line flag to which
332 |   the user passes 'hyper-param=value' pairs. It avoids having to define
333 |   one flag for each hyperparameter.
334 | 
335 |   The syntax expected for each value depends on the type of the parameter.
336 |   See `parse()` for a description of the syntax.
337 | 
338 |   Example:
339 | 
340 |   ```python
341 |   # Define a command line flag to pass name=value pairs.
342 |   # For example using argparse:
343 |   import argparse
344 |   parser = argparse.ArgumentParser(description='Train my model.')
345 |   parser.add_argument('--hparams', type=str,
346 |                       help='Comma separated list of "name=value" pairs.')
347 |   args = parser.parse_args()
348 |   ...
349 |   def my_program():
350 |     # Create a HParams object specifying the names and values of the
351 |     # model hyperparameters:
352 |     hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100,
353 |                          activations=['relu', 'tanh'])
354 | 
355 |     # Override hyperparameters values by parsing the command line
356 |     hparams.parse(args.hparams)
357 | 
358 |     # If the user passed `--hparams=learning_rate=0.3` on the command line
359 |     # then 'hparams' has the following attributes:
360 |     hparams.learning_rate ==> 0.3
361 |     hparams.num_hidden_units ==> 100
362 |     hparams.activations ==> ['relu', 'tanh']
363 | 
364 |     # If the hyperparameters are in json format use parse_json:
365 |     hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
366 |   ```
367 |   """
368 | 
369 |   _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.
370 | 
371 |   def __init__(self, model_structure=None, **kwargs):
372 |     """Create an instance of `HParams` from keyword arguments.
373 | 
374 |     The keyword arguments specify name-value pairs for the hyperparameters.
375 |     The parameter types are inferred from the type of the values passed.
376 | 
377 |     The parameter names are added as attributes of `HParams` object, so they
378 |     can be accessed directly with the dot notation `hparams._name_`.
379 | 
380 |     Example:
381 | 
382 |     ```python
383 |     # Define 3 hyperparameters: 'learning_rate' is a float parameter,
384 |     # 'num_hidden_units' an integer parameter, and 'activation' a string
385 |     # parameter.
386 |     hparams = tf.HParams(
387 |         learning_rate=0.1, num_hidden_units=100, activation='relu')
388 | 
389 |     hparams.activation ==> 'relu'
390 |     ```
391 | 
392 |     Note that a few names are reserved and cannot be used as hyperparameter
393 |     names. If you use one of the reserved names the constructor raises a
394 |     `ValueError`.
395 | 
396 |     Args:
397 |       model_structure: An instance of ModelStructure, defining the feature
398 |         crosses to be used in the Trial.
399 |       **kwargs: Key-value pairs where the key is the hyperparameter name and
400 |         the value is the value for the parameter.
401 | 
402 |     Raises:
403 |       ValueError: If both `hparam_def` and initialization values are provided,
404 |         or if one of the arguments is invalid.
405 | 
406 |     """
407 |     # Register the hyperparameters and their type in _hparam_types.
408 |     # This simplifies the implementation of parse().
409 |     # _hparam_types maps the parameter name to a tuple (type, bool).
410 |     # The type value is the type of the parameter for scalar hyperparameters,
411 |     # or the type of the list elements for multidimensional hyperparameters.
412 |     # The bool value is True if the value is a list, False otherwise.
413 |     self._hparam_types = {}
414 |     self._model_structure = model_structure
415 |     for name, value in six.iteritems(kwargs):
416 |       self.add_hparam(name, value)
417 | 
418 |   def add_hparam(self, name, value):
419 |     """Adds {name, value} pair to hyperparameters.
420 | 
421 |     Args:
422 |       name: Name of the hyperparameter.
423 |       value: Value of the hyperparameter. Can be one of the following types:
424 |         int, float, string, int list, float list, or string list.
425 | 
426 |     Raises:
427 |       ValueError: if one of the arguments is invalid.
428 |     """
429 |     # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
430 |     # attribute of this object. In that case we refuse to use it as a
431 |     # hyperparameter name.
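    # For example (hypothetical): HParams(values=1) raises below, because
    # 'values' is already a method of this class, whereas add_hparam('lr', 0.1)
    # records _hparam_types['lr'] = (float, False) and sets self.lr = 0.1.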
432 | if getattr(self, name, None) is not None: 433 | raise ValueError('Hyperparameter name is reserved: %s' % name) 434 | if isinstance(value, (list, tuple)): 435 | if not value: 436 | raise ValueError( 437 | 'Multi-valued hyperparameters cannot be empty: %s' % name) 438 | self._hparam_types[name] = (type(value[0]), True) 439 | else: 440 | self._hparam_types[name] = (type(value), False) 441 | setattr(self, name, value) 442 | 443 | def set_hparam(self, name, value): 444 | """Set the value of an existing hyperparameter. 445 | 446 | This function verifies that the type of the value matches the type of the 447 | existing hyperparameter. 448 | 449 | Args: 450 | name: Name of the hyperparameter. 451 | value: New value of the hyperparameter. 452 | 453 | Raises: 454 | KeyError: If the hyperparameter doesn't exist. 455 | ValueError: If there is a type mismatch. 456 | """ 457 | param_type, is_list = self._hparam_types[name] 458 | if isinstance(value, list): 459 | if not is_list: 460 | raise ValueError( 461 | 'Must not pass a list for single-valued parameter: %s' % name) 462 | setattr(self, name, [ 463 | _cast_to_type_if_compatible(name, param_type, v) for v in value]) 464 | else: 465 | if is_list: 466 | raise ValueError( 467 | 'Must pass a list for multi-valued parameter: %s.' % name) 468 | setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) 469 | 470 | def del_hparam(self, name): 471 | """Removes the hyperparameter with key 'name'. 472 | 473 | Does nothing if it isn't present. 474 | 475 | Args: 476 | name: Name of the hyperparameter. 477 | """ 478 | if hasattr(self, name): 479 | delattr(self, name) 480 | del self._hparam_types[name] 481 | 482 | def parse(self, values): 483 | """Override existing hyperparameter values, parsing new values from a string. 484 | 485 | See parse_values for more detail on the allowed format for values. 486 | 487 | Args: 488 | values: String. Comma separated list of `name=value` pairs where 'value' 489 | must follow the syntax described above. 490 | 491 | Returns: 492 | The `HParams` instance. 493 | 494 | Raises: 495 | ValueError: If `values` cannot be parsed or a hyperparameter in `values` 496 | doesn't exist. 497 | """ 498 | type_map = {} 499 | for name, t in self._hparam_types.items(): 500 | param_type, _ = t 501 | type_map[name] = param_type 502 | 503 | values_map = parse_values(values, type_map) 504 | return self.override_from_dict(values_map) 505 | 506 | def override_from_dict(self, values_dict): 507 | """Override existing hyperparameter values, parsing new values from a dictionary. 508 | 509 | Args: 510 | values_dict: Dictionary of name:value pairs. 511 | 512 | Returns: 513 | The `HParams` instance. 514 | 515 | Raises: 516 | KeyError: If a hyperparameter in `values_dict` doesn't exist. 517 | ValueError: If `values_dict` cannot be parsed. 518 | """ 519 | for name, value in values_dict.items(): 520 | self.set_hparam(name, value) 521 | return self 522 | 523 | def set_model_structure(self, model_structure): 524 | self._model_structure = model_structure 525 | 526 | def get_model_structure(self): 527 | return self._model_structure 528 | 529 | def to_json(self, indent=None, separators=None, sort_keys=False): 530 | """Serializes the hyperparameters into JSON. 531 | 532 | Args: 533 | indent: If a non-negative integer, JSON array elements and object members 534 | will be pretty-printed with that indent level. An indent level of 0, or 535 | negative, will only insert newlines. `None` (the default) selects the 536 | most compact representation. 
537 | separators: Optional `(item_separator, key_separator)` tuple. Default is 538 | `(', ', ': ')`. 539 | sort_keys: If `True`, the output dictionaries will be sorted by key. 540 | 541 | Returns: 542 | A JSON string. 543 | """ 544 | def remove_callables(x): 545 | """Omit callable elements from input with arbitrary nesting.""" 546 | if isinstance(x, dict): 547 | return {k: remove_callables(v) for k, v in six.iteritems(x) 548 | if not callable(v)} 549 | elif isinstance(x, list): 550 | return [remove_callables(i) for i in x if not callable(i)] 551 | return x 552 | return json.dumps( 553 | remove_callables(self.values()), 554 | indent=indent, 555 | separators=separators, 556 | sort_keys=sort_keys) 557 | 558 | def parse_json(self, values_json): 559 | """Override existing hyperparameter values, parsing new values from a json object. 560 | 561 | Args: 562 | values_json: String containing a json object of name:value pairs. 563 | 564 | Returns: 565 | The `HParams` instance. 566 | 567 | Raises: 568 | KeyError: If a hyperparameter in `values_json` doesn't exist. 569 | ValueError: If `values_json` cannot be parsed. 570 | """ 571 | values_map = json.loads(values_json) 572 | return self.override_from_dict(values_map) 573 | 574 | def values(self): 575 | """Return the hyperparameter values as a Python dictionary. 576 | 577 | Returns: 578 | A dictionary with hyperparameter names as keys. The values are the 579 | hyperparameter values. 580 | """ 581 | return {n: getattr(self, n) for n in self._hparam_types.keys()} 582 | 583 | def get(self, key, default=None): 584 | """Returns the value of `key` if it exists, else `default`.""" 585 | if key in self._hparam_types: 586 | # Ensure that default is compatible with the parameter type. 587 | if default is not None: 588 | param_type, is_param_list = self._hparam_types[key] 589 | type_str = 'list<%s>' % param_type if is_param_list else str(param_type) 590 | fail_msg = ("Hparam '%s' of type '%s' is incompatible with " 591 | 'default=%s' % (key, type_str, default)) 592 | 593 | is_default_list = isinstance(default, list) 594 | if is_param_list != is_default_list: 595 | raise ValueError(fail_msg) 596 | 597 | try: 598 | if is_default_list: 599 | for value in default: 600 | _cast_to_type_if_compatible(key, param_type, value) 601 | else: 602 | _cast_to_type_if_compatible(key, param_type, default) 603 | except ValueError as e: 604 | raise ValueError('%s. %s' % (fail_msg, e)) 605 | 606 | return getattr(self, key) 607 | 608 | return default 609 | 610 | def __contains__(self, key): 611 | return key in self._hparam_types 612 | 613 | def __str__(self): 614 | return str(sorted(self.values().items())) 615 | 616 | def __repr__(self): 617 | return '%s(%s)' % (type(self).__name__, self.__str__()) 618 | 619 | @staticmethod 620 | def _get_kind_name(param_type, is_list): 621 | """Returns the field name given parameter type and is_list. 622 | 623 | Args: 624 | param_type: Data type of the hparam. 625 | is_list: Whether this is a list. 626 | 627 | Returns: 628 | A string representation of the field name. 629 | 630 | Raises: 631 | ValueError: If parameter type is not recognized. 632 | """ 633 | if issubclass(param_type, bool): 634 | # This check must happen before issubclass(param_type, six.integer_types), 635 | # since Python considers bool to be a subclass of int. 636 | typename = 'bool' 637 | elif issubclass(param_type, six.integer_types): 638 | # Setting 'int' and 'long' types to be 'int64' to ensure the type is 639 | # compatible with both Python2 and Python3. 
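      # e.g. (illustration) _get_kind_name(int, is_list=True) -> 'int64_list'
      # and _get_kind_name(str, is_list=False) -> 'bytes_value'.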
640 |       typename = 'int64'
641 |     elif issubclass(param_type, (six.string_types, six.binary_type)):
642 |       # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
643 |       # compatible with both Python2 and Python3.
644 |       typename = 'bytes'
645 |     elif issubclass(param_type, float):
646 |       typename = 'float'
647 |     else:
648 |       raise ValueError('Unsupported parameter type: %s' % str(param_type))
649 | 
650 |     suffix = 'list' if is_list else 'value'
651 |     return '_'.join([typename, suffix])
--------------------------------------------------------------------------------
/gui.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from PyQt5 import Qt
3 | from PyQt5 import QtCore,QtGui
4 | from PyQt5.QtCore import QMutex, QObject, QRunnable, pyqtSignal, pyqtSlot, QThreadPool, QTimer, QThread
5 | from PyQt5.QtWidgets import QWidget,QMainWindow,QHeaderView, QMessageBox, QFileDialog
6 | from nvidia_tacotron_TTS_Layout import Ui_MainWindow
7 | from ui import Ui_extras
8 | from timerthread import timerThread
9 | from preprocess import preprocess_text
10 | 
11 | import time
12 | import requests
13 | import json
14 | import datetime
15 | import numpy as np
16 | import os
17 | import pygame
18 | 
19 | import sys
20 | sys.path.append(os.path.join(sys.path[0],'waveglow/'))
21 | 
22 | import numpy as np
23 | import torch
24 | 
25 | from hparams import create_hparams
26 | from model import Tacotron2
27 | from train import load_model
28 | from text import text_to_sequence, cleaners
29 | #from denoiser import Denoiser
30 | 
31 | from secrets import TOKEN # for debugging
32 | 
33 | _mutex1 = QMutex()
34 | _running1 = False # tab 0 synthesis QThread : Start/stop
35 | _mutex2 = QMutex()
36 | _running2 = False # tab 1 eventloop QRunnable: Start/stop
37 | _mutex3 = QMutex()
38 | _running3 = False # tab 1 eventloop QRunnable: Skip current item
39 | 
40 | #https://www.learnpyqt.com/courses/concurrent-execution/multithreading-pyqt-applications-qthreadpool/
41 | class WorkerSignals(QObject):
42 |     '''
43 |     Defines the signals available from a running worker thread.
44 | 
45 |     Supported signals are:
46 | 
47 |     finished
48 |         No data
49 | 
50 |     error
51 |         `tuple` (exctype, value, traceback.format_exc() )
52 | 
53 |     result
54 |         `object` data returned from processing, anything
55 | 
56 |     progress
57 |         `int` indicating % progress
58 |     (textready `str`, elapsed `int` and fncallback `tuple` are also defined below)
59 |     '''
60 | 
61 |     textready = pyqtSignal(str)
62 |     finished = pyqtSignal()
63 |     error = pyqtSignal(tuple)
64 |     result = pyqtSignal(object)
65 |     progress = pyqtSignal(int)
66 |     elapsed = pyqtSignal(int)
67 |     fncallback = pyqtSignal(tuple)
68 | 
69 | class Worker(QRunnable):
70 |     '''
71 |     Worker thread
72 | 
73 |     Inherits from QRunnable to handle worker thread setup, signals and wrap-up.
74 | 
75 |     :param callback: The function callback to run on this worker thread. Supplied args and
76 |         kwargs will be passed through to the runner.
77 |     :type callback: function
78 |     :param args: Arguments to pass to the callback function
79 |     :param kwargs: Keywords to pass to the callback function
80 | 
81 |     '''
82 | 
83 |     def __init__(self, fn, *args, **kwargs):
84 |         super(Worker, self).__init__()
85 | 
86 |         # Store constructor arguments (re-used for processing)
87 |         self.fn = fn
88 |         self.args = args
89 |         self.kwargs = kwargs
90 |         self.signals = WorkerSignals()
91 | 
92 |         # Add the callback to our kwargs
93 |         self.kwargs['progress_callback'] = self.signals.progress
94 |         self.kwargs['elapsed_callback'] = self.signals.elapsed
95 |         self.kwargs['text_ready'] = self.signals.textready
96 |         self.kwargs['fn_callback'] = self.signals.fncallback
97 | 
98 |     @pyqtSlot()
99 |     def run(self):
100 |         '''
101 |         Initialise the runner function with passed args, kwargs.
102 |         '''
103 | 
104 |         # Retrieve args/kwargs here; and fire processing using them
105 |         try:
106 |             result = self.fn(*self.args, **self.kwargs)
107 |         except:
108 |             import traceback
109 |             traceback.print_exc()
110 |             exctype, value = sys.exc_info()[:2]
111 |             self.signals.error.emit((exctype, value, traceback.format_exc()))
112 |         else:
113 |             self.signals.result.emit(result) # Return the result of the processing
114 |         finally:
115 |             self.signals.finished.emit() # Done
116 | 
117 | class GUISignals(QObject):
118 |     progress = pyqtSignal(int)
119 |     elapsed = pyqtSignal(int)
120 | 
121 | class GUI(QMainWindow, Ui_MainWindow, Ui_extras):
122 |     def __init__(self,app):
123 |         super(GUI, self).__init__()
124 |         self.app = app
125 | 
126 |         ### Setup UI and signals
127 |         self.setupUi(self)
128 |         self.drawGpuSwitch(self)
129 |         self.initWidgets(self)
130 |         self.signals = GUISignals()
131 |         self.setUpconnections(self)
132 | 
133 |         ### Init vars
134 |         self.model = None
135 |         self.waveglow = None
136 |         self.hparams = None
137 |         self.current_thread = None
138 |         self.t_1 = None # timing
139 |         self.logs = [] # message logs
140 |         self.logs2 = []
141 |         self.max_log_lines = 3
142 |         self.max_log2_lines = 100
143 |         self.TTmodel_dir = [] # list of model paths
144 |         self.WGmodel_dir = []
145 |         self.reload_model_flag = True
146 |         self.channel_id = '' # stream elements channel ID
147 |         # Because of a bug in the StreamElements timestamp filter, we need 2 variables for the previous time
148 |         self.startup_time = datetime.datetime.utcnow().isoformat()
149 |         #self.startup_time = '0' # For debugging
150 |         self.prev_time = datetime.datetime.utcnow().isoformat()
151 |         #self.prev_time = '0' # for debugging
152 |         self.msg_offset = 0
153 |         self.se_opts = {'approve only': 2, # Stream element options
154 |                         'block large numbers': 0,
155 |                         'read dono amount': 2,
156 |                         }
157 |         self.fns = {'GUI: start of polling loop': self.fns_gui_startpolling, # Callback functions
158 |                     'GUI: end of polling loop': self.fns_gui_endpolling ,
159 |                     'Wav: playback' : self.fns_wav_playback,
160 |                     'Var: offset': self.fns_var_offset,
161 |                     'Var: prev_time': self.fns_var_prevtime,
162 |                     'GUI: progress bar 2 text' : self.fns_gui_pbtext,
163 |                     'GUI: reenable skip btn' : self.fns_gui_enableclientskipbtn}
164 |         self.pyt_opts = {'cpu limit': None, # pytorch options
165 |                          'denoiser':None}
166 | 
167 |         ### Init pygame mixer
168 |         pygame.mixer.quit()
169 |         pygame.mixer.init(frequency=22050,size=-16, channels=1)
170 |         self.channel = pygame.mixer.Channel(0)
171 | 
172 |         ### Init qthreadpool
173 |         self.threadpool = QThreadPool()
174 |         print("Multithreading with maximum %d threads" % self.threadpool.maxThreadCount())
175 | 
176 |         ### Setup Complete
177 |         self.update_log_window("Begin by loading a
model") 178 | 179 | @pyqtSlot(int) 180 | def toggle_cpu_limit(self, state): 181 | self.label_10.setEnabled(state) 182 | self.OptLimitCpuCombo.setEnabled(state) 183 | 184 | @pyqtSlot(int) 185 | def change_cpu_limit(self, indx): 186 | num_thread = indx + 1 187 | self.pyt_opts['cpu limit'] = num_thread 188 | 189 | @pyqtSlot(int) 190 | def toggle_approve_dono(self, state): 191 | self.se_opts['approve only'] = state 192 | 193 | @pyqtSlot(int) 194 | def toggle_block_number(self, state): 195 | self.se_opts['block large numbers'] = state 196 | 197 | @pyqtSlot(int) 198 | def toggle_dono_amount(self, state): 199 | self.se_opts['read dono amount'] = state 200 | 201 | @pyqtSlot(tuple) 202 | def on_fncallback(self,tup): 203 | option,arg = tup 204 | self.fns[option](arg) 205 | 206 | @pyqtSlot(str) 207 | def on_textready(self,text): 208 | # Function to send text from client thread to GUI thread 209 | # Format of text: : 210 | obj = text[0:4] 211 | msg = text[5:] 212 | if obj=='Log1': 213 | if len(self.logs) > self.max_log_lines: 214 | self.logs.pop(0) 215 | self.logs.append(msg) 216 | log_text = '\n'.join(self.logs) 217 | self.log_window1.setText(log_text) 218 | if obj=='Log2': 219 | if len(self.logs2) > self.max_log2_lines: 220 | self.logs2.pop(0) 221 | self.logs2.append(msg) 222 | log_text = '\n'.join(self.logs2) 223 | self.log_window2.setPlainText(log_text) 224 | self.log_window2.verticalScrollBar().setValue( 225 | self.log_window2.verticalScrollBar().maximum()) 226 | if obj=='Sta2': 227 | self.statusbar.setText(msg) 228 | 229 | @pyqtSlot(int) 230 | def update_log_bar(self,val): 231 | self.progressBar.setValue(val) 232 | #self.progressBar.setTextVisible(val != 0) 233 | 234 | @pyqtSlot(int) 235 | def update_log_bar2(self,val): 236 | self.progressBar2.setValue(val) 237 | #self.progressBar2.setTextVisible(val != 0) 238 | 239 | @pyqtSlot(int) 240 | def on_elapsed(self,val): 241 | if self.tabWidget.currentIndex()==0: 242 | self.update_log_window('Elapsed: '+str(val)+'s',mode='overwrite') 243 | else: 244 | pass # No elapsed time for tab2 245 | 246 | @pyqtSlot(np.ndarray) 247 | def on_inferThread_complete(self,wav): 248 | global _running1 249 | _mutex1.lock() 250 | _running1 = False 251 | _mutex1.unlock() 252 | self.playback_wav(wav) 253 | self.TTSDialogButton.setEnabled(True) 254 | self.TTModelCombo.setEnabled(True) 255 | self.WGModelCombo.setEnabled(True) 256 | self.TTSTextEdit.setEnabled(True) 257 | self.LoadTTButton.setEnabled(True) 258 | self.LoadWGButton.setEnabled(True) 259 | self.tab_2.setEnabled(True) 260 | elapsed = (time.time() - self.t_1) 261 | wav_length = (len(wav) / self.hparams.sampling_rate) 262 | rtf = elapsed / wav_length 263 | line = 'Generated {:.1f}s of audio in {:.1f}s ({:.2f} real-time factor)'.format(wav_length,elapsed,rtf) 264 | self.update_log_window(line,'overwrite') 265 | tps = elapsed / len(wav) 266 | print(" > Run-time: {}".format(elapsed)) 267 | print(" > Real-time factor: {}".format(rtf)) 268 | print(" > Time per step: {}".format(tps)) 269 | self.update_status_bar("Ready") 270 | # TODO get pygame mixer callback on end or use sounddevice 271 | 272 | @pyqtSlot(tuple) 273 | def on_itersignal(self,tup): 274 | # Displays current iteration on progress bar 275 | current,total = tup 276 | self.progressBarLabel.setText('{}/{}'.format(current,total)) 277 | 278 | @pyqtSlot() 279 | def on_interrupt(self): 280 | # Reenable buttons 281 | self.TTSDialogButton.setEnabled(True) 282 | self.TTModelCombo.setEnabled(True) 283 | self.WGModelCombo.setEnabled(True) 284 | 
self.TTSTextEdit.setEnabled(True) 285 | self.LoadTTButton.setEnabled(True) 286 | self.LoadWGButton.setEnabled(True) 287 | self.tab_2.setEnabled(True) 288 | # Refresh progress bar 289 | self.update_log_bar(0) 290 | self.progressBarLabel.setText('') 291 | # Write to log window 292 | self.update_log_window('Interrupted','overwrite') 293 | # Write to status bar 294 | self.update_status_bar("Ready") 295 | 296 | def fns_gui_startpolling(self,arg=None): 297 | self.ClientStartBtn.setDisabled(True) 298 | self.ClientStopBtn.setEnabled(True) 299 | self.tab.setDisabled(True) 300 | self.tab_3.setDisabled(True) 301 | self.ClientAmountLine.setDisabled(True) 302 | 303 | def fns_gui_endpolling(self,arg=None): 304 | self.update_log_bar2(0) 305 | self.progressBar2Label.setText('') 306 | self.ClientStartBtn.setEnabled(True) 307 | self.ClientStopBtn.setDisabled(True) 308 | self.ClientSkipBtn.setDisabled(True) 309 | self.tab.setEnabled(True) 310 | self.tab_3.setEnabled(True) 311 | self.ClientAmountLine.setEnabled(True) 312 | 313 | def fns_wav_playback(self,wav): 314 | if self.tabWidget.currentIndex()==0: 315 | self.TTSStopButton.setEnabled(True) 316 | else: 317 | self.ClientSkipBtn.setEnabled(True) 318 | if wav.dtype != np.int16 : 319 | # Convert from float32 or float16 to signed int16 for pygame 320 | wav = (wav/np.amax(wav) * 32767).astype(np.int16) 321 | sound = pygame.mixer.Sound(wav) 322 | self.channel.queue(sound) 323 | 324 | def fns_var_offset(self,arg): 325 | self.msg_offset = arg 326 | 327 | def fns_var_prevtime(self,arg): 328 | self.prev_time = arg 329 | 330 | def fns_gui_pbtext(self,tup): 331 | current,total = tup 332 | self.progressBar2Label.setText('{}/{}'.format(current,total)) 333 | 334 | def fns_gui_enableclientskipbtn(self,arg=None): 335 | self.ClientSkipBtn.setEnabled(True) 336 | 337 | def on_finished(self): 338 | #print("THREAD COMPLETE!") 339 | pass 340 | 341 | def on_result(self, s): 342 | #print(s) 343 | pass 344 | 345 | def start_eventloop(self): 346 | # Pass the function to execute 347 | global _running2,_running3 348 | if not self.validate_se(): 349 | return 350 | if self.reload_model_flag: 351 | self.reload_model() 352 | self.reload_model_flag = False 353 | min_donation = self.get_min_donation() 354 | TOKEN = self.get_token() 355 | _mutex2.lock() 356 | _running2 = True 357 | _mutex2.unlock() 358 | _mutex3.lock() 359 | _running3 = True 360 | _mutex3.unlock() 361 | worker = Worker(self.eventloop, TOKEN, min_donation, self.channel, 362 | self.se_opts, self.use_cuda, self.model, self.waveglow, self.pyt_opts['cpu limit'], 363 | self.msg_offset, self.prev_time, self.startup_time) 364 | # Any other args, kwargs are passed to the run function 365 | worker.signals.result.connect(self.on_result) 366 | worker.signals.finished.connect(self.on_finished) 367 | worker.signals.progress.connect(self.update_log_bar2) 368 | worker.signals.textready.connect(self.on_textready) 369 | worker.signals.elapsed.connect(self.on_elapsed) 370 | worker.signals.fncallback.connect(self.on_fncallback) 371 | # Execute 372 | self.threadpool.start(worker) 373 | 374 | def stop_eventloop(self): 375 | global _running2, _running3 376 | _mutex2.lock() 377 | _running2 = False 378 | _mutex2.unlock() 379 | _mutex3.lock() 380 | _running3 = False 381 | _mutex3.unlock() 382 | self.skip_wav() 383 | 384 | def skip_eventloop(self): 385 | global _running3 386 | _mutex3.lock() 387 | _running3 = False 388 | _mutex3.unlock() 389 | self.skip_wav() 390 | 391 | def eventloop(self, TOKEN, min_donation, channel, se_opts, 392 | use_cuda, model, 
393 |                   offset, prev_time, startup_time,
394 |                   progress_callback, elapsed_callback, text_ready, fn_callback):
395 |         # TODO: refactor this messy block
396 |         global _running3
397 |         if num_thread:
398 |             torch.set_num_threads(num_thread)
399 |             os.environ['OMP_NUM_THREADS'] = str(num_thread)
400 |             os.environ['MKL_NUM_THREADS'] = str(num_thread)
401 |         fn_callback.emit(('GUI: start of polling loop',None))
402 |         text_ready.emit("Sta2:Connecting to StreamElements")
403 |         url = "https://api.streamelements.com/kappa/v2/tips/"+self.channel_id
404 |         headers = {'accept': 'application/json',"Authorization": "Bearer "+TOKEN}
405 |         text_ready.emit('Log2:Initializing')
406 |         text_ready.emit('Log2:Minimum amount for TTS: '+str(min_donation))
407 |         while True:
408 |             _mutex2.lock()
409 |             if not _running2:
410 |                 _mutex2.unlock()
411 |                 break
412 |             else:
413 |                 _mutex2.unlock()
414 |             if not channel.get_busy():
415 |                 #print('Polling', datetime.datetime.utcnow().isoformat())
416 |                 text_ready.emit("Sta2:Waiting for incoming donations . . .")
417 |                 current_time = datetime.datetime.utcnow().isoformat()
418 |                 # TODO: possible bug: missed donations once time passes midnight
419 |                 querystring = {"offset":offset,
420 |                                "limit":"1",
421 |                                "sort":"createdAt",
422 |                                "after":startup_time,
423 |                                "before":current_time}
424 |                 response = requests.request("GET", url, headers=headers, params=querystring)
425 |                 data = json.loads(response.text)
426 |                 for dono in data['docs']:
427 |                     text_ready.emit("Sta2:Processing donations")
428 |                     dono_time = dono['createdAt']
429 |                     offset += 1
430 |                     if dono_time > prev_time: # String comparison; ISO 8601 timestamps sort lexicographically
431 |                         amount = dono['donation']['amount'] # Int
432 |                         if se_opts['approve only'] == 2:
433 |                             approved = dono['approved'] == 'allowed'
434 |                         else:
435 |                             approved = True
436 |                         if float(amount) >= min_donation and approved:
437 |                             _mutex3.lock()
438 |                             if not _running3:
439 |                                 _running3 = True
440 |                             _mutex3.unlock()
441 |                             fn_callback.emit(('GUI: reenable skip btn',None))
442 |                             name = dono['donation']['user']['username']
443 |                             msg = dono['donation']['message']
444 |                             if not msg.strip(): break # Skip empty or whitespace-only messages
445 |                             ## TODO Allow multiple speakers in msg
446 |                             currency = dono['donation']['currency']
447 |                             dono_id = dono['_id']
448 |                             text_ready.emit("Log2:\n###########################")
449 |                             text_ready.emit("Log2:"+name+' donated '+currency+str(amount))
450 |                             text_ready.emit("Log2:"+msg)
451 |                             lines = preprocess_text(msg)
452 |                             if se_opts['read dono amount'] == 2: # reads dono name and amount
453 |                                 msg = '{} donated {} {}.'.format(name,
454 |                                                                  str(amount),
455 |                                                                  cleaners.expand_currency(currency))
456 |                                 lines.insert(0,msg) # Add to head of list
457 |                             output = []
458 |                             for count, line in enumerate(lines):
459 |                                 fn_callback.emit(('GUI: progress bar 2 text', (count,len(lines))))
460 |                                 sequence = np.array(text_to_sequence(line, ['english_cleaners']))[None, :]
461 |                                 # Inference
462 |                                 device = torch.device('cuda' if use_cuda else 'cpu')
463 |                                 sequence = torch.autograd.Variable(
464 |                                     torch.from_numpy(sequence)).to(device).long()
465 |                                 # Decode text input
466 |                                 mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)
467 |                                 with torch.no_grad():
468 |                                     audio = waveglow.infer(mel_outputs_postnet,
469 |                                                            sigma=0.666,
470 |                                                            progress_callback = progress_callback,
471 |                                                            elapsed_callback = None,
472 |                                                            get_interruptflag = self.get_interruptflag2)
473 |                                 if not isinstance(audio, torch.Tensor):
474 |                                     # Catches when waveglow is interrupted and returns None
475 |                                     break
476 |                                 fn_callback.emit(('GUI: progress bar 2 text', (count+1,len(lines))))
477 |                                 wav = audio[0].data.cpu().numpy()
478 |                                 output.append(wav)
479 |                             _mutex3.lock()
480 |                             if _running3:
481 |                                 _mutex3.unlock()
482 |                                 outwav = np.concatenate(output)
483 |                                 # Playback
484 |                                 fn_callback.emit(('Wav: playback',outwav))
485 |                             else: _mutex3.unlock()
486 |                         prev_time = dono_time # Increment time
487 |             time.sleep(0.5)
488 |         fn_callback.emit(('GUI: end of polling loop',None))
489 |         text_ready.emit('Log2:\nDisconnected')
490 |         text_ready.emit('Sta2:Ready')
491 |         fn_callback.emit(('Var: offset', offset))
492 |         fn_callback.emit(('Var: prev_time', prev_time))
493 |         return #'Return value of execute_this_fn'
494 | 
495 |     def startup_update(self):
496 |         if not self.tab_2.isEnabled():
497 |             self.tab_2.setEnabled(True)
498 |         if not self.TTSDialogButton.isEnabled():
499 |             self.TTSDialogButton.setEnabled(True)
500 | 
501 |     def playback_wav(self,wav):
502 |         if self.tabWidget.currentIndex()==1:
503 |             self.ClientSkipBtn.setEnabled(True)
504 |         if wav.dtype != np.int16:
505 |             # Normalize by the peak absolute value, then convert to signed int16 for pygame
506 |             wav = (wav / np.max(np.abs(wav)) * 32767).astype(np.int16)
507 |         sound = pygame.mixer.Sound(wav)
508 |         self.channel.queue(sound)
509 |         # TODO Disable skip btn on playback end
510 | 
511 |     def skip_wav(self):
512 |         if self.channel.get_busy():
513 |             self.channel.stop()
514 |         self.ClientSkipBtn.setDisabled(True)
515 | 
516 |     def skip_infer_playback(self):
517 |         global _running1
518 |         if self.channel.get_busy():
519 |             self.channel.stop()
520 |         _mutex1.lock() # We could also use a signal/slot mechanism
521 |         if _running1:
522 |             self.progressBarLabel.setText('Interrupting...')
523 |             _running1 = False # instead of mutex since inference is on QThread
524 |         _mutex1.unlock()
525 |         self.TTSStopButton.setDisabled(True)
526 | 
527 |     def reload_model(self):
528 |         TTmodel_fpath = self.get_current_TTmodel_dir()
529 |         WGmodel_fpath = self.get_current_WGmodel_dir()
530 |         # Setup hparams
531 |         self.hparams = create_hparams()
532 |         self.hparams.sampling_rate = 22050
533 |         # Load Tacotron 2 from checkpoint
534 |         self.model = load_model(self.hparams,self.use_cuda)
535 |         device = torch.device('cuda' if self.use_cuda else 'cpu')
536 |         self.model.load_state_dict(torch.load(TTmodel_fpath, map_location = device)['state_dict'])
537 |         if self.use_cuda:
538 |             _ = self.model.cuda().eval().half()
539 |         else:
540 |             _ = self.model.eval()
541 |         # Load WaveGlow for mel2audio synthesis and denoiser
542 |         self.waveglow = torch.load(WGmodel_fpath, map_location = device)['model']
543 |         self.waveglow.use_cuda = self.use_cuda
544 |         if self.use_cuda:
545 |             self.waveglow.cuda().eval().half()
546 |         else:
547 |             self.waveglow.eval()
548 |         for k in self.waveglow.convinv:
549 |             k.float()
550 |         #denoiser = Denoiser(waveglow,use_cuda=self.use_cuda)
551 | 
552 |     def start_synthesis(self):
553 |         # Runs in the main GUI thread; synthesis blocks the GUI.
554 |         # Can update the GUI directly in this function.
555 |         text = self.TTSTextEdit.toPlainText()
556 |         if not text.strip(): return # Nothing to synthesize
557 |         global _running1
558 |         self.t_1 = time.time()
559 |         self.TTSDialogButton.setDisabled(True)
560 |         self.TTModelCombo.setDisabled(True)
561 |         self.WGModelCombo.setDisabled(True)
562 |         self.TTSTextEdit.setDisabled(True)
563 |         self.LoadTTButton.setDisabled(True)
564 |         self.LoadWGButton.setDisabled(True)
565 |         self.TTSStopButton.setEnabled(True)
566 |         self.tab_2.setDisabled(True)
567 |         self.update_log_bar(0)
568 |         self.update_log_window('Initializing','clear')
569 |         self.update_status_bar("Creating voice")
570 |         # We use a signal callback here to stick to the same params type in synthesize.py
571 |         if self.reload_model_flag:
572 |             self.reload_model()
573 |             self.reload_model_flag = False
574 |         # Prepare text input
575 |         _mutex1.lock()
576 |         _running1 = True
577 |         _mutex1.unlock()
578 |         self.current_thread = inferThread(text,
579 |                                           self.use_cuda,
580 |                                           self.model,
581 |                                           self.waveglow,
582 |                                           self.signals.progress,
583 |                                           None,
584 |                                           self.t_1,
585 |                                           self.pyt_opts['cpu limit'],
586 |                                           parent = self)
587 |         self.current_thread.audioSignal.connect(self.on_inferThread_complete)
588 |         self.current_thread.timeElapsed.connect(self.on_elapsed)
589 |         self.current_thread.iterSignal.connect(self.on_itersignal)
590 |         self.current_thread.interruptSignal.connect(self.on_interrupt)
591 | 
592 |     def validate_se(self):
593 |         # Connects to StreamElements and saves the channel id.
594 |         # Returns True if the channel name and token are valid.
595 |         # Test channel ID
596 |         self.update_status_bar("Validating StreamElements")
597 |         CHANNEL_NAME = ''.join(self.ChannelName.text().split())
598 |         url = "https://api.streamelements.com/kappa/v2/channels/"+CHANNEL_NAME
599 |         response = requests.request("GET", url, headers={'accept': 'application/json'})
600 |         if response.status_code == 200:
601 |             # Test JWT token
602 |             self.channel_id = json.loads(response.text)['_id']
603 |             url = "https://api.streamelements.com/kappa/v2/tips/"+self.channel_id
604 |             querystring = {"offset":"0","limit":"10","sort":"createdAt","after":"0","before":"0"}
605 |             TOKEN = self.get_token()
606 |             headers = {'accept': 'application/json',"Authorization": "Bearer "+TOKEN}
607 |             response2 = requests.request("GET", url, headers=headers, params=querystring)
608 |             if response2.status_code == 200:
609 |                 self.update_log_window_2("\nConnected to "+CHANNEL_NAME)
610 |                 return True
611 |             else:
612 |                 self.update_log_window_2("\nError: Double check your token")
613 |                 self.update_status_bar("Invalid StreamElements")
614 |                 print(response2.text)
615 |         else:
616 |             self.update_log_window_2("\nError: Double check your channel name")
617 |             self.update_status_bar("Invalid StreamElements")
618 |             print(response.text)
619 | 
620 |         return False
621 | 
622 |     def get_min_donation(self):
623 |         return float(self.ClientAmountLine.value())
624 | 
625 |     def get_token(self):
626 |         #TOKEN = ''.join(self.APIKeyLine.text().split())
627 |         #return TOKEN
628 |         tokenobj = TOKEN() # for debugging
629 |         return tokenobj.token # for debugging
630 | 
631 |     def get_current_TTmodel_dir(self):
632 |         return self.TTmodel_dir[self.TTModelCombo.currentIndex()]
633 | 
634 |     def get_current_WGmodel_dir(self):
635 |         return self.WGmodel_dir[self.WGModelCombo.currentIndex()]
636 | 
637 |     def get_current_TTmodel_fname(self):
638 |         return self.TTModelCombo.currentText()
639 | 
640 |     def get_current_WGmodel_fname(self):
641 |         return self.WGModelCombo.currentText()
642 | 
643 |     def get_interruptflag2(self):
644 |         _mutex3.lock()
645 |         val = _running3
646 |         _mutex3.unlock()
647 |         return val
648 | 
649 |     def set_reload_model_flag(self):
650 |         self.reload_model_flag = True
651 | 
652 |     def set_cuda(self):
653 |         self.use_cuda = self.GpuSwitch.isChecked()
654 |         self.reload_model_flag = True
655 | 
656 |     def add_TTmodel_path(self):
657 |         fpath = str(QFileDialog.getOpenFileName(self,
658 |                                                 'Select Tacotron2 model',
659 |                                                 filter='*.pt')[0])
660 |         if not fpath: # If no file selected
661 |             return
662 |         if fpath not in self.TTmodel_dir:
663 |             head,tail = os.path.split(fpath) # Split into directory and filename
664 |             self.TTmodel_dir.append(fpath) # Save full path
665 |             self.populate_modelcombo(tail, self.TTModelCombo)
666 |             self.update_log_window("Added Tacotron 2 model: "+tail)
667 |         if self.WGModelCombo.count() > 0:
668 |             self.startup_update()
669 | 
670 |     def add_WGmodel_path(self):
671 |         fpath = str(QFileDialog.getOpenFileName(self,
672 |                                                 'Select Waveglow model',
673 |                                                 filter='*.pt')[0])
674 |         if not fpath: # If no file selected
675 |             return
676 |         if fpath not in self.WGmodel_dir:
677 |             head,tail = os.path.split(fpath) # Split into directory and filename
678 |             self.WGmodel_dir.append(fpath) # Save full path
679 |             self.populate_modelcombo(tail, self.WGModelCombo)
680 |             self.update_log_window("Added Waveglow model: "+tail)
681 |         if self.TTModelCombo.count() > 0:
682 |             self.startup_update()
683 | 
684 |     def populate_modelcombo(self, item, combobox):
685 |         combobox.addItem(item)
686 |         combobox.setCurrentIndex(combobox.count()-1)
687 |         if not combobox.isEnabled():
688 |             combobox.setEnabled(True)
689 | 
690 |     def update_log_window(self, line, mode="newline"):
691 |         if mode == "newline" or not self.logs:
692 |             self.logs.append(line)
693 |             if len(self.logs) > self.max_log_lines:
694 |                 del self.logs[0]
695 |         elif mode == "append":
696 |             self.logs[-1] += line
697 |         elif mode == "overwrite":
698 |             self.logs[-1] = line
699 |         elif mode == "clear":
700 |             self.logs = [line]
701 |         log_text = '\n'.join(self.logs)
702 |         self.log_window1.setText(log_text)
703 | 
704 |     def update_log_window_2(self, line, mode="newline"):
705 |         if mode == "newline" or not self.logs2:
706 |             self.logs2.append(line)
707 |         elif mode == "append":
708 |             self.logs2[-1] += line
709 |         elif mode == "overwrite":
710 |             self.logs2[-1] = line
711 |         log_text = '\n'.join(self.logs2)
712 |         self.log_window2.setPlainText(log_text)
713 |         self.log_window2.verticalScrollBar().setValue(
714 |             self.log_window2.verticalScrollBar().maximum())
715 | 
716 |     def update_status_bar(self, line):
717 |         self.statusbar.setText(line)
718 | 
719 | class inferThread(QThread):
720 |     timeElapsed = pyqtSignal(int)
721 |     audioSignal = pyqtSignal(np.ndarray)
722 |     iterSignal = pyqtSignal(tuple)
723 |     interruptSignal = pyqtSignal()
724 | 
725 |     def __init__(self, text, use_cuda, model, waveglow,
726 |                  progress, elapsed, timestart, num_thread, parent=None):
727 |         super(inferThread, self).__init__(parent)
728 |         self.text = text
729 |         self.use_cuda = use_cuda
730 |         self.model = model
731 |         self.waveglow = waveglow
732 |         self.progress = progress
733 |         self.elapsed = elapsed
734 |         self.num_thread = num_thread
735 |         self.timeoffset = time.time()-timestart
736 |         self.timerThread = timerThread(self.timeoffset, parent = self)
737 |         self.timerThread.timeElapsed.connect(self.timeElapsed.emit)
738 |         self.start()
739 | 
740 |     def run(self):
741 |         self.timerThread.start(time.time())
742 |         if self.num_thread:
743 |             torch.set_num_threads(self.num_thread)
744 |             os.environ['OMP_NUM_THREADS'] = str(self.num_thread)
745 |             os.environ['MKL_NUM_THREADS'] = str(self.num_thread)
746 |         lines = preprocess_text(self.text)
747 |         output = []
748 |         for count,line in enumerate(lines):
749 |             _mutex1.lock()
750 |             if not _running1:
751 |                 _mutex1.unlock()
752 |                 self.interruptSignal.emit()
753 |                 return
754 |             else:
755 |                 _mutex1.unlock()
756 |             self.iterSignal.emit((count,len(lines)))
757 |             sequence = np.array(text_to_sequence(line, ['english_cleaners']))[None, :]
758 |             device = torch.device('cuda' if self.use_cuda else 'cpu')
759 |             sequence = torch.autograd.Variable(
760 |                 torch.from_numpy(sequence)).to(device).long()
761 |             # Decode text input
762 |             mel_outputs, mel_outputs_postnet, _, alignments = self.model.inference(sequence)
763 |             with torch.no_grad():
764 |                 audio = self.waveglow.infer(mel_outputs_postnet,
765 |                                             sigma=0.666,
766 |                                             progress_callback = self.progress,
767 |                                             elapsed_callback = self.elapsed,
768 |                                             get_interruptflag = self.get_interruptflag)
769 |             if not isinstance(audio, torch.Tensor):
770 |                 # Catches when waveglow is interrupted and returns None
771 |                 self.interruptSignal.emit()
772 |                 return
773 |             self.iterSignal.emit((count+1,len(lines)))
774 |             wav = audio[0].data.cpu().numpy()
775 |             output.append(wav)
776 |         outwav = np.concatenate(output)
777 |         self.audioSignal.emit(outwav)
778 | 
779 |     def get_interruptflag(self):
780 |         _mutex1.lock()
781 |         val = _running1
782 |         _mutex1.unlock()
783 |         return val
784 | 
785 | 
786 | if __name__ == '__main__':
787 |     app = Qt.QApplication(sys.argv)
788 |     window = GUI(app)
789 |     window.show()
790 |     sys.exit(app.exec_())
--------------------------------------------------------------------------------
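For reference, the synthesis path that inferThread.run() and reload_model() drive (preprocess_text → text_to_sequence → model.inference → waveglow.infer → peak-normalized int16 PCM) can be exercised headlessly. The sketch below is a minimal, untested distillation under stated assumptions: the checkpoint paths are placeholders, the `from train import load_model` import mirrors NVIDIA's Tacotron 2 layout and is not confirmed by gui.py, `text_to_sequence` is assumed to be exported by text/__init__.py, this fork's waveglow.infer() is assumed callable with only (mel, sigma) since its extra callback kwargs appear optional, and writing the result with scipy.io.wavfile is an addition not present in the repo.

import numpy as np
import torch
from scipy.io.wavfile import write   # scipy is already in requirements.txt

from hparams import create_hparams   # hparams.py is in this repo
from train import load_model         # assumed location, mirroring NVIDIA's Tacotron 2 layout
from preprocess import preprocess_text
from text import text_to_sequence    # assumed export of text/__init__.py

TACOTRON_CKPT = 'tacotron2.pt'       # placeholder paths; point at real checkpoints
WAVEGLOW_CKPT = 'waveglow.pt'

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

hparams = create_hparams()
hparams.sampling_rate = 22050

# Mirror reload_model(): restore Tacotron 2, then WaveGlow
model = load_model(hparams, use_cuda)  # this fork's load_model takes a use_cuda flag
model.load_state_dict(torch.load(TACOTRON_CKPT, map_location=device)['state_dict'])
if use_cuda:
    model.cuda().eval().half()
else:
    model.eval()

waveglow = torch.load(WAVEGLOW_CKPT, map_location=device)['model']
waveglow.use_cuda = use_cuda         # flag consulted by this fork's infer()
if use_cuda:
    waveglow.cuda().eval().half()
else:
    waveglow.eval()
for k in waveglow.convinv:
    k.float()                        # keep invertible convs in fp32 for stability

# Mirror inferThread.run(): one mel + vocoder pass per wrapped line
output = []
for line in preprocess_text('Hello world, this sentence has 2 numbers: 10 and 42.'):
    sequence = np.array(text_to_sequence(line, ['english_cleaners']))[None, :]
    sequence = torch.from_numpy(sequence).to(device).long()
    with torch.no_grad():
        _, mel_outputs_postnet, _, _ = model.inference(sequence)
        # Extra callback kwargs of this fork's infer() are assumed optional
        audio = waveglow.infer(mel_outputs_postnet, sigma=0.666)
    output.append(audio[0].data.cpu().numpy())

wav = np.concatenate(output)
wav = (wav / np.max(np.abs(wav)) * 32767).astype(np.int16)  # peak-normalize to int16
write('out.wav', hparams.sampling_rate, wav)

The GUI plays the same int16 buffer through pygame instead of writing a file; queuing it on a pygame channel, as fns_wav_playback() does, is the only difference in the final step.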