├── .gitattributes ├── demo.wav ├── tensorboard.png ├── .gitmodules ├── requirements.txt ├── Dockerfile ├── preprocess.py ├── multiproc.py ├── loss_function.py ├── timerthread.py ├── text ├── symbols.py ├── LICENSE ├── cmudict.py ├── numbers.py ├── __init__.py └── cleaners.py ├── utils.py ├── LICENSE ├── plotting_utils.py ├── logger.py ├── .gitignore ├── audio_processing.py ├── ui.py ├── hparams.py ├── layers.py ├── data_utils.py ├── loss_scaler.py ├── README.md ├── stft.py ├── switch.py ├── distributed.py ├── filelists └── ljs_audio_text_val_filelist.txt ├── train.py ├── nvidia_tacotron_TTS_Layout.py ├── model.py ├── utils_hparam.py └── gui.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python -------------------------------------------------------------------------------- /demo.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lokkelvin2/tacotron2-tts-GUI/HEAD/demo.wav -------------------------------------------------------------------------------- /tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lokkelvin2/tacotron2-tts-GUI/HEAD/tensorboard.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "waveglow"] 2 | path = waveglow 3 | url = https://github.com/lokkelvin2/waveglow_GUI 4 | branch = gui 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | requests 3 | numpy 4 | inflect 5 | librosa 6 | scipy 7 | Unidecode 8 | pillow 9 | pygame 10 | pyqt5==5.15.0 11 | numba==0.48 12 | tqdm -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:nightly-devel-cuda10.0-cudnn7 2 | ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH} 3 | 4 | RUN apt-get update -y 5 | 6 | RUN pip install numpy scipy matplotlib librosa==0.6.0 tensorflow tensorboardX inflect==0.2.5 Unidecode==1.0.22 pillow jupyter 7 | 8 | ADD apex /apex/ 9 | WORKDIR /apex/ 10 | RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 11 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import re 2 | import num2words 3 | import textwrap 4 | 5 | _MAXIMUM_ALLOWED_LENGTH = 140 6 | 7 | 8 | def add_fullstop(text): 9 | out = text+'.' if text[-1]!='.' else text 10 | return out 11 | 12 | def break_long_sentences(text): 13 | return textwrap.wrap(text, _MAXIMUM_ALLOWED_LENGTH, break_long_words=False) 14 | 15 | def preprocess_text(text): 16 | ''' 17 | Takes in string, replaces numbers with words, wraps the string into a list 18 | of multiple lines. 
19 | ''' 20 | text = re.sub(r"(\d+)", lambda x: num2words.num2words(int(x.group(0))), text) 21 | lines = break_long_sentences(text) 22 | if lines: 23 | lines[-1] = add_fullstop(lines[-1]) 24 | return lines -------------------------------------------------------------------------------- /multiproc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import sys 4 | import subprocess 5 | 6 | argslist = list(sys.argv)[1:] 7 | num_gpus = torch.cuda.device_count() 8 | argslist.append('--n_gpus={}'.format(num_gpus)) 9 | workers = [] 10 | job_id = time.strftime("%Y_%m_%d-%H%M%S") 11 | argslist.append("--group_name=group_{}".format(job_id)) 12 | 13 | for i in range(num_gpus): 14 | argslist.append('--rank={}'.format(i)) 15 | stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i), 16 | "w") 17 | print(argslist) 18 | p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) 19 | workers.append(p) 20 | argslist = argslist[:-1] 21 | 22 | for p in workers: 23 | p.wait() 24 | -------------------------------------------------------------------------------- /loss_function.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Tacotron2Loss(nn.Module): 5 | def __init__(self): 6 | super(Tacotron2Loss, self).__init__() 7 | 8 | def forward(self, model_output, targets): 9 | mel_target, gate_target = targets[0], targets[1] 10 | mel_target.requires_grad = False 11 | gate_target.requires_grad = False 12 | gate_target = gate_target.view(-1, 1) 13 | 14 | mel_out, mel_out_postnet, gate_out, _ = model_output 15 | gate_out = gate_out.view(-1, 1) 16 | mel_loss = nn.MSELoss()(mel_out, mel_target) + \ 17 | nn.MSELoss()(mel_out_postnet, mel_target) 18 | gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target) 19 | return mel_loss + gate_loss 20 | -------------------------------------------------------------------------------- /timerthread.py: -------------------------------------------------------------------------------- 1 | # from https://stackoverflow.com/a/14369192 2 | from PyQt5 import QtCore, QtGui 3 | from PyQt5.QtCore import QObject, pyqtSignal, pyqtSlot, QThread 4 | import time 5 | 6 | class timerThread(QThread): 7 | timeElapsed = pyqtSignal(int) 8 | 9 | def __init__(self, timeoffset, parent=None): 10 | super(timerThread, self).__init__(parent) 11 | self.timeoffset = timeoffset 12 | self.timeStart = None 13 | 14 | 15 | def start(self, timeStart): 16 | self.timeStart = timeStart 17 | 18 | return super(timerThread, self).start() 19 | 20 | def run(self): 21 | while self.parent().isRunning(): 22 | self.timeElapsed.emit(time.time() - self.timeStart + self.timeoffset) 23 | time.sleep(1) 24 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Defines the set of symbols used in text input to the model. 5 | 6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' 7 | from text import cmudict 8 | 9 | _pad = '_' 10 | _punctuation = '!\'(),.:;? 
' 11 | _special = '-' 12 | _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' 13 | 14 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 15 | _arpabet = ['@' + s for s in cmudict.valid_symbols] 16 | 17 | # Export all symbols: 18 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet 19 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.io.wavfile import read 3 | import torch 4 | 5 | 6 | def get_mask_from_lengths(lengths): 7 | max_len = torch.max(lengths).item() 8 | ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) 9 | mask = (ids < lengths.unsqueeze(1)).bool() 10 | return mask 11 | 12 | 13 | def load_wav_to_torch(full_path): 14 | sampling_rate, data = read(full_path) 15 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 16 | 17 | 18 | def load_filepaths_and_text(filename, split="|"): 19 | with open(filename, encoding='utf-8') as f: 20 | filepaths_and_text = [line.strip().split(split) for line in f] 21 | return filepaths_and_text 22 | 23 | 24 | def to_gpu(x): 25 | x = x.contiguous() 26 | 27 | if torch.cuda.is_available(): 28 | x = x.cuda(non_blocking=True) 29 | return torch.autograd.Variable(x) 30 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, NVIDIA Corporation 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /plotting_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pylab as plt 4 | import numpy as np 5 | 6 | 7 | def save_figure_to_numpy(fig): 8 | # save it to a numpy array. 9 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 10 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 11 | return data 12 | 13 | 14 | def plot_alignment_to_numpy(alignment, info=None): 15 | fig, ax = plt.subplots(figsize=(6, 4)) 16 | im = ax.imshow(alignment, aspect='auto', origin='lower', 17 | interpolation='none') 18 | fig.colorbar(im, ax=ax) 19 | xlabel = 'Decoder timestep' 20 | if info is not None: 21 | xlabel += '\n\n' + info 22 | plt.xlabel(xlabel) 23 | plt.ylabel('Encoder timestep') 24 | plt.tight_layout() 25 | 26 | fig.canvas.draw() 27 | data = save_figure_to_numpy(fig) 28 | plt.close() 29 | return data 30 | 31 | 32 | def plot_spectrogram_to_numpy(spectrogram): 33 | fig, ax = plt.subplots(figsize=(12, 3)) 34 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 35 | interpolation='none') 36 | plt.colorbar(im, ax=ax) 37 | plt.xlabel("Frames") 38 | plt.ylabel("Channels") 39 | plt.tight_layout() 40 | 41 | fig.canvas.draw() 42 | data = save_figure_to_numpy(fig) 43 | plt.close() 44 | return data 45 | 46 | 47 | def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): 48 | fig, ax = plt.subplots(figsize=(12, 3)) 49 | ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5, 50 | color='green', marker='+', s=1, label='target') 51 | ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5, 52 | color='red', marker='.', s=1, label='predicted') 53 | 54 | plt.xlabel("Frames (Green target, Red predicted)") 55 | plt.ylabel("Gate State") 56 | plt.tight_layout() 57 | 58 | fig.canvas.draw() 59 | data = save_figure_to_numpy(fig) 60 | plt.close() 61 | return data 62 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.utils.tensorboard import SummaryWriter 4 | from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy 5 | from plotting_utils import plot_gate_outputs_to_numpy 6 | 7 | 8 | class Tacotron2Logger(SummaryWriter): 9 | def __init__(self, logdir): 10 | super(Tacotron2Logger, 
self).__init__(logdir) 11 | 12 | def log_training(self, reduced_loss, grad_norm, learning_rate, duration, 13 | iteration): 14 | self.add_scalar("training.loss", reduced_loss, iteration) 15 | self.add_scalar("grad.norm", grad_norm, iteration) 16 | self.add_scalar("learning.rate", learning_rate, iteration) 17 | self.add_scalar("duration", duration, iteration) 18 | 19 | def log_validation(self, reduced_loss, model, y, y_pred, iteration): 20 | self.add_scalar("validation.loss", reduced_loss, iteration) 21 | _, mel_outputs, gate_outputs, alignments = y_pred 22 | mel_targets, gate_targets = y 23 | 24 | # plot distribution of parameters 25 | for tag, value in model.named_parameters(): 26 | tag = tag.replace('.', '/') 27 | self.add_histogram(tag, value.data.cpu().numpy(), iteration) 28 | 29 | # plot alignment, mel target and predicted, gate target and predicted 30 | idx = random.randint(0, alignments.size(0) - 1) 31 | self.add_image( 32 | "alignment", 33 | plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T), 34 | iteration, dataformats='HWC') 35 | self.add_image( 36 | "mel_target", 37 | plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()), 38 | iteration, dataformats='HWC') 39 | self.add_image( 40 | "mel_predicted", 41 | plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()), 42 | iteration, dataformats='HWC') 43 | self.add_image( 44 | "gate", 45 | plot_gate_outputs_to_numpy( 46 | gate_targets[idx].data.cpu().numpy(), 47 | torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()), 48 | iteration, dataformats='HWC') 49 | -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | 6 | valid_symbols = [ 7 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 8 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 9 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 10 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 11 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 12 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 13 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 14 | ] 15 | 16 | _valid_symbol_set = set(valid_symbols) 17 | 18 | 19 | class CMUDict: 20 | '''Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 21 | def __init__(self, file_or_path, keep_ambiguous=True): 22 | if isinstance(file_or_path, str): 23 | with open(file_or_path, encoding='latin-1') as f: 24 | entries = _parse_cmudict(f) 25 | else: 26 | entries = _parse_cmudict(file_or_path) 27 | if not keep_ambiguous: 28 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 29 | self._entries = entries 30 | 31 | 32 | def __len__(self): 33 | return len(self._entries) 34 | 35 | 36 | def lookup(self, word): 37 | '''Returns list of ARPAbet pronunciations of the given word.''' 38 | return self._entries.get(word.upper()) 39 | 40 | 41 | 42 | _alt_re = re.compile(r'\([0-9]+\)') 43 | 44 | 45 | def _parse_cmudict(file): 46 | cmudict = {} 47 | for line in file: 48 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 49 | parts = line.split(' ') 50 | word = re.sub(_alt_re, '', parts[0]) 51 | pronunciation = _get_pronunciation(parts[1]) 52 | if pronunciation: 53 | if word in cmudict: 54 | cmudict[word].append(pronunciation) 55 | else: 56 | cmudict[word] = [pronunciation] 57 | return cmudict 58 | 59 | 60 | def _get_pronunciation(s): 61 | parts = s.strip().split(' ') 62 | for part in parts: 63 | if part not in _valid_symbol_set: 64 | return None 65 | return ' '.join(parts) 66 | -------------------------------------------------------------------------------- /text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import inflect 4 | import re 5 | 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 9 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 10 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 11 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 12 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 13 | _number_re = re.compile(r'[0-9]+') 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(',', '') 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace('.', ' point ') 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split('.') 27 | if len(parts) > 2: 28 | return match + ' dollars' # Unexpected format 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 33 | cent_unit = 'cent' if cents == 1 else 'cents' 34 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 35 | elif dollars: 36 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 37 | return '%s %s' % (dollars, dollar_unit) 38 | elif cents: 39 | cent_unit = 'cent' if cents == 1 else 'cents' 40 | return '%s %s' % (cents, cent_unit) 41 | else: 42 | return 'zero dollars' 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return 'two thousand' 54 | elif num > 2000 and num < 2010: 55 | return 'two thousand ' + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + ' hundred' 58 | else: 59 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 60 | else: 61 | return _inflect.number_to_words(num, andword='') 62 | 63 | 64 | def normalize_numbers(text): 65 | text = 
re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r'\1 pounds', text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | import re 3 | from text import cleaners 4 | from text.symbols import symbols 5 | 6 | 7 | # Mappings from symbol to numeric ID and vice versa: 8 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 9 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 10 | 11 | # Regular expression matching text enclosed in curly braces: 12 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 13 | 14 | 15 | def text_to_sequence(text, cleaner_names): 16 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 17 | 18 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 19 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 20 | 21 | Args: 22 | text: string to convert to a sequence 23 | cleaner_names: names of the cleaner functions to run the text through 24 | 25 | Returns: 26 | List of integers corresponding to the symbols in the text 27 | ''' 28 | sequence = [] 29 | 30 | # Check for curly braces and treat their contents as ARPAbet: 31 | while len(text): 32 | m = _curly_re.match(text) 33 | if not m: 34 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 35 | break 36 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 37 | sequence += _arpabet_to_sequence(m.group(2)) 38 | text = m.group(3) 39 | 40 | return sequence 41 | 42 | 43 | def sequence_to_text(sequence): 44 | '''Converts a sequence of IDs back to a string''' 45 | result = '' 46 | for symbol_id in sequence: 47 | if symbol_id in _id_to_symbol: 48 | s = _id_to_symbol[symbol_id] 49 | # Enclose ARPAbet back in curly braces: 50 | if len(s) > 1 and s[0] == '@': 51 | s = '{%s}' % s[1:] 52 | result += s 53 | return result.replace('}{', ' ') 54 | 55 | 56 | def _clean_text(text, cleaner_names): 57 | for name in cleaner_names: 58 | cleaner = getattr(cleaners, name) 59 | if not cleaner: 60 | raise Exception('Unknown cleaner: %s' % name) 61 | text = cleaner(text) 62 | return text 63 | 64 | 65 | def _symbols_to_sequence(symbols): 66 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 67 | 68 | 69 | def _arpabet_to_sequence(text): 70 | return _symbols_to_sequence(['@' + s for s in text.split()]) 71 | 72 | 73 | def _should_keep_symbol(s): 74 | return s in _symbol_to_id and s is not '_' and s is not '~' 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # 
PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | *,cover 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | virtualenv/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | 140 | # pytype static type analyzer 141 | .pytype/ 142 | 143 | # Cython debug symbols 144 | cython_debug/ 145 | 146 | # Others 147 | .ipynb_checkpoints/ 148 | .vs/ 149 | srt2fcpxml.py 150 | secrets.py 151 | .vscode/ 152 | target/ 153 | -------------------------------------------------------------------------------- /audio_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.signal import get_window 4 | import librosa.util as librosa_util 5 | 6 | 7 | def window_sumsquare(window, n_frames, hop_length=200, win_length=800, 8 | n_fft=800, dtype=np.float32, norm=None): 9 | """ 10 | # from librosa 0.6 11 | Compute the sum-square envelope of a window function at a given hop length. 12 | 13 | This is used to estimate modulation effects induced by windowing 14 | observations in short-time fourier transforms. 15 | 16 | Parameters 17 | ---------- 18 | window : string, tuple, number, callable, or list-like 19 | Window specification, as in `get_window` 20 | 21 | n_frames : int > 0 22 | The number of analysis frames 23 | 24 | hop_length : int > 0 25 | The number of samples to advance between frames 26 | 27 | win_length : [optional] 28 | The length of the window function. By default, this matches `n_fft`. 
29 | 30 | n_fft : int > 0 31 | The length of each analysis frame. 32 | 33 | dtype : np.dtype 34 | The data type of the output 35 | 36 | Returns 37 | ------- 38 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` 39 | The sum-squared envelope of the window function 40 | """ 41 | if win_length is None: 42 | win_length = n_fft 43 | 44 | n = n_fft + hop_length * (n_frames - 1) 45 | x = np.zeros(n, dtype=dtype) 46 | 47 | # Compute the squared window at the desired length 48 | win_sq = get_window(window, win_length, fftbins=True) 49 | win_sq = librosa_util.normalize(win_sq, norm=norm)**2 50 | win_sq = librosa_util.pad_center(win_sq, n_fft) 51 | 52 | # Fill the envelope 53 | for i in range(n_frames): 54 | sample = i * hop_length 55 | x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] 56 | return x 57 | 58 | 59 | def griffin_lim(magnitudes, stft_fn, n_iters=30): 60 | """ 61 | PARAMS 62 | ------ 63 | magnitudes: spectrogram magnitudes 64 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods 65 | """ 66 | 67 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) 68 | angles = angles.astype(np.float32) 69 | angles = torch.autograd.Variable(torch.from_numpy(angles)) 70 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 71 | 72 | for i in range(n_iters): 73 | _, angles = stft_fn.transform(signal) 74 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1) 75 | return signal 76 | 77 | 78 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 79 | """ 80 | PARAMS 81 | ------ 82 | C: compression factor 83 | """ 84 | return torch.log(torch.clamp(x, min=clip_val) * C) 85 | 86 | 87 | def dynamic_range_decompression(x, C=1): 88 | """ 89 | PARAMS 90 | ------ 91 | C: compression factor used to compress 92 | """ 93 | return torch.exp(x) / C 94 | -------------------------------------------------------------------------------- /ui.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from PyQt5 import QtCore, QtGui, QtWidgets 4 | from switch import Switch 5 | import torch 6 | 7 | class Ui_extras(object): 8 | def drawGpuSwitch(self, MainWindow): 9 | MainWindow.GpuSwitch = Switch(thumb_radius=8, track_radius=10, show_text = False) 10 | MainWindow.horizontalLayout.addWidget(MainWindow.GpuSwitch) 11 | MainWindow.GpuSwitch.setEnabled(torch.cuda.is_available()) 12 | MainWindow.use_cuda = False 13 | MainWindow.GpuSwitch.setToolTip("
|
9 | # Overview
10 | A machine learning based text-to-speech program with a user-friendly GUI. The target audience includes Twitch streamers and content creators looking for an open-source TTS program. The aim of this software is to make TTS synthesis accessible offline (no coding experience, GPU, or Colab required) in a portable exe.
11 |
12 | ## Features
13 | * Reads donations from Stream Elements automatically
14 | * PyQt5 wrapper for NVIDIA/tacotron2 & NVIDIA/waveglow
15 |
16 | ## Download Link
17 | A portable executable can be found at the [Releases](https://github.com/lokkelvin2/tacotron2-tts-GUI/releases) page, or directly [here](https://github.com/lokkelvin2/tacotron2-tts-GUI/releases/download/v0.3/nvidia_waveglow-v0.3.1_x86_64.exe). Download a pretrained *Tacotron 2* and *Waveglow* model from below.
18 |
19 | Warning: the portable executable runs on CPU which leads to a >10x speed slowdown compared to running it on GPU.
20 |
21 | # Building from source
22 | ## Requirements
23 | * Python >=3.7
24 | * librosa
25 | * numpy
26 | * PyQt5==5.15.0
27 | * requests
28 | * tqdm
29 | * matplotlib
30 | * scipy
31 | * num2words
32 | * pygame
33 |
34 | [PyTorch 1.0](https://pytorch.org/)
35 |
36 | ## To Run
37 | ```
38 | python gui.py
39 | ```
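
For reference, `preprocess.py` in this repo normalizes input text before synthesis: digits are spelled out with `num2words` and long input is wrapped into chunks of at most 140 characters. A condensed sketch of that behaviour (the example call and its output are illustrative only):

```python
import re
import textwrap

import num2words

_MAXIMUM_ALLOWED_LENGTH = 140  # characters per synthesized chunk


def preprocess_text(text):
    # Spell out digits ("50" -> "fifty") so the model never sees numerals.
    text = re.sub(r"(\d+)", lambda m: num2words.num2words(int(m.group(0))), text)
    # Wrap into chunks the model can handle, without splitting words.
    lines = textwrap.wrap(text, _MAXIMUM_ALLOWED_LENGTH, break_long_words=False)
    if lines and not lines[-1].endswith('.'):
        lines[-1] += '.'  # make sure the final chunk ends the utterance
    return lines


print(preprocess_text("Thanks for the 50 dollar donation!"))
# ['Thanks for the fifty dollar donation!.']
```
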
40 | ## License
41 | * NVIDIA/tacotron2 & waveglow: BSD-3-Clause License
42 |
43 | ## Notes
44 | * TTS code from [NVIDIA/tacotron2](https://github.com/NVIDIA/tacotron2)
45 | * Partial GUI code from [https://github.com/CorentinJ/Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) and layout inspired by u/realstreamer's Forsen TTS [https://www.youtube.com/watch?v=kL2tglbcDCo](https://www.youtube.com/watch?v=kL2tglbcDCo)
46 |
47 |
48 | # Original Repo:
49 |
50 | # Tacotron 2 (without wavenet)
51 |
52 | PyTorch implementation of [Natural TTS Synthesis by Conditioning
53 | WaveNet on Mel Spectrogram Predictions](https://arxiv.org/pdf/1712.05884.pdf).
54 |
55 | This implementation includes **distributed** and **automatic mixed precision** support
56 | and uses the [LJSpeech dataset](https://keithito.com/LJ-Speech-Dataset/).
57 |
58 | Distributed and Automatic Mixed Precision support relies on NVIDIA's [Apex] and [AMP].
59 |
60 | Visit our [website] for audio samples using our published [Tacotron 2] and
61 | [WaveGlow] models.
62 |
63 | 
64 |
65 |
66 | ## Pre-requisites
67 | 1. NVIDIA GPU + CUDA cuDNN
68 |
69 | ## Setup
70 | 1. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/)
71 | 2. Clone this repo: `git clone https://github.com/NVIDIA/tacotron2.git`
72 | 3. CD into this repo: `cd tacotron2`
73 | 4. Initialize submodule: `git submodule init; git submodule update`
74 | 5. Update .wav paths: `sed -i -- 's,DUMMY,ljs_dataset_folder/wavs,g' filelists/*.txt` (the filelist format is sketched after these steps)
75 | - Alternatively, set `load_mel_from_disk=True` in `hparams.py` and update mel-spectrogram paths
76 | 6. Install [PyTorch 1.0]
77 | 7. Install [Apex]
78 | 8. Install python requirements or build docker image
79 | - Install python requirements: `pip install -r requirements.txt`
80 |
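Each line of the filelists pairs a wav path with its transcript, separated by `|`; the `sed` command in step 5 only rewrites the `DUMMY` prefix so the paths point at the extracted dataset. A minimal sketch of how those lines are read, mirroring `load_filepaths_and_text` in this repo's `utils.py`:

```python
def load_filepaths_and_text(filename, split="|"):
    # One "wav_path|transcript" pair per line.
    with open(filename, encoding="utf-8") as f:
        return [line.strip().split(split) for line in f]


pairs = load_filepaths_and_text("filelists/ljs_audio_text_val_filelist.txt")
print(pairs[0])
# after step 5: ['ljs_dataset_folder/wavs/LJ022-0023.wav',
#                'The overwhelming majority of people in this country know how to sift ...']
```
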
81 | ## Training
82 | 1. `python train.py --output_directory=outdir --log_directory=logdir`
83 | 2. (OPTIONAL) `tensorboard --logdir=outdir/logdir`
84 |
85 | ## Training using a pre-trained model
86 | Training from a pre-trained model can lead to faster convergence.
87 | By default, the dataset-dependent text embedding layers are [ignored] when warm starting; a sketch of this filtering follows the steps below.
88 |
89 | 1. Download our published [Tacotron 2] model
90 | 2. `python train.py --output_directory=outdir --log_directory=logdir -c tacotron2_statedict.pt --warm_start`
91 |
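What `--warm_start` does, roughly: weights are copied from the checkpoint except for the layers named in the `ignore_layers` hparam (the text embedding by default), which keep their fresh initialization. A minimal sketch of that filtering, assuming the usual checkpoint layout with a `state_dict` key (the authoritative version lives in `train.py`):

```python
import torch


def warm_start_model(checkpoint_path, model, ignore_layers):
    # Load the pretrained weights, dropping the dataset-dependent layers.
    pretrained = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    if ignore_layers:
        pretrained = {k: v for k, v in pretrained.items() if k not in ignore_layers}
        merged = model.state_dict()   # freshly initialized weights
        merged.update(pretrained)     # overwrite everything except the ignored layers
        pretrained = merged
    model.load_state_dict(pretrained)
    return model


# e.g. warm_start_model("tacotron2_statedict.pt", model, hparams.ignore_layers)
```
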
92 | ## Multi-GPU (distributed) and Automatic Mixed Precision Training
93 | 1. `python -m multiproc train.py --output_directory=outdir --log_directory=logdir --hparams=distributed_run=True,fp16_run=True`
94 |
95 | ## Inference demo
96 | 1. Download our published [Tacotron 2] model
97 | 2. Download our published [WaveGlow] model
98 | 3. `jupyter notebook --ip=127.0.0.1 --port=31337`
99 | 4. Load inference.ipynb
100 |
101 | N.b. When performing Mel-Spectrogram to Audio synthesis, make sure Tacotron 2
102 | and the Mel decoder were trained on the same mel-spectrogram representation.
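
Outside the notebook, inference amounts to encoding the text, running Tacotron 2, and vocoding the predicted mel spectrogram with WaveGlow. A rough CPU-only sketch using this repo's modules, assuming the stock `create_hparams`/`Tacotron2` interfaces from the NVIDIA repo; the checkpoint filenames are placeholders and the notebook remains the reference:

```python
import torch
from hparams import create_hparams
from model import Tacotron2
from text import text_to_sequence

hparams = create_hparams()
model = Tacotron2(hparams)
model.load_state_dict(torch.load("tacotron2_statedict.pt", map_location="cpu")["state_dict"])
model.eval()

# The waveglow submodule must be importable so torch.load can unpickle the model.
waveglow = torch.load("waveglow_256channels.pt", map_location="cpu")["model"]
waveglow.eval()

sequence = torch.LongTensor(text_to_sequence("Hello world.", ["english_cleaners"]))[None, :]
with torch.no_grad():
    _, mel_postnet, _, _ = model.inference(sequence)   # (B, n_mel_channels, T)
    audio = waveglow.infer(mel_postnet, sigma=0.666)   # (B, samples)
```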
103 |
104 |
105 | ## Related repos
106 | [WaveGlow](https://github.com/NVIDIA/WaveGlow): a faster-than-real-time flow-based
107 | generative network for speech synthesis.
108 |
109 | [nv-wavenet](https://github.com/NVIDIA/nv-wavenet/): a faster-than-real-time
110 | WaveNet.
111 |
112 | ## Acknowledgements
113 | This implementation uses code from the following repos: [Keith
114 | Ito](https://github.com/keithito/tacotron/) and [Prem
115 | Seetharaman](https://github.com/pseeth/pytorch-stft), as described in our code.
116 |
117 | We are inspired by [Ryuichi Yamamoto's](https://github.com/r9y9/tacotron_pytorch)
118 | Tacotron PyTorch implementation.
119 |
120 | We are thankful to the Tacotron 2 paper authors, especially Jonathan Shen, Yuxuan
121 | Wang and Zongheng Yang.
122 |
123 |
124 | [WaveGlow]: https://drive.google.com/open?id=1rpK8CzAAirq9sWZhe9nlfvxMF1dRgFbF
125 | [Tacotron 2]: https://drive.google.com/file/d/1c5ZTuT7J08wLUoVZ2KkUs_VdZuJ86ZqA/view?usp=sharing
126 | [pytorch 1.0]: https://github.com/pytorch/pytorch#installation
127 | [website]: https://nv-adlr.github.io/WaveGlow
128 | [ignored]: https://github.com/NVIDIA/tacotron2/blob/master/hparams.py#L22
129 | [Apex]: https://github.com/nvidia/apex
130 | [AMP]: https://github.com/NVIDIA/apex/tree/master/apex/amp
131 |
--------------------------------------------------------------------------------
/stft.py:
--------------------------------------------------------------------------------
1 | """
2 | BSD 3-Clause License
3 |
4 | Copyright (c) 2017, Prem Seetharaman
5 | All rights reserved.
6 |
7 | * Redistribution and use in source and binary forms, with or without
8 | modification, are permitted provided that the following conditions are met:
9 |
10 | * Redistributions of source code must retain the above copyright notice,
11 | this list of conditions and the following disclaimer.
12 |
13 | * Redistributions in binary form must reproduce the above copyright notice, this
14 | list of conditions and the following disclaimer in the
15 | documentation and/or other materials provided with the distribution.
16 |
17 | * Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived from this
19 | software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 | """
32 |
33 | import torch
34 | import numpy as np
35 | import torch.nn.functional as F
36 | from torch.autograd import Variable
37 | from scipy.signal import get_window
38 | from librosa.util import pad_center, tiny
39 | from audio_processing import window_sumsquare
40 |
41 |
42 | class STFT(torch.nn.Module):
43 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
44 | def __init__(self, filter_length=800, hop_length=200, win_length=800,
45 | window='hann',use_cuda = True):
46 | super(STFT, self).__init__()
47 | self.device = torch.device('cuda' if use_cuda else 'cpu')
48 | self.filter_length = filter_length
49 | self.hop_length = hop_length
50 | self.win_length = win_length
51 | self.window = window
52 | self.forward_transform = None
53 | scale = self.filter_length / self.hop_length
54 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
55 |
56 | cutoff = int((self.filter_length / 2 + 1))
57 | fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
58 | np.imag(fourier_basis[:cutoff, :])])
59 |
60 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
61 | inverse_basis = torch.FloatTensor(
62 | np.linalg.pinv(scale * fourier_basis).T[:, None, :])
63 |
64 | if window is not None:
65 | assert(filter_length >= win_length)
66 | # get window and zero center pad it to filter_length
67 | fft_window = get_window(window, win_length, fftbins=True)
68 | fft_window = pad_center(fft_window, filter_length)
69 | fft_window = torch.from_numpy(fft_window).float()
70 |
71 | # window the bases
72 | forward_basis *= fft_window
73 | inverse_basis *= fft_window
74 |
75 | self.register_buffer('forward_basis', forward_basis.float())
76 | self.register_buffer('inverse_basis', inverse_basis.float())
77 |
78 | def transform(self, input_data):
79 | num_batches = input_data.size(0)
80 | num_samples = input_data.size(1)
81 |
82 | self.num_samples = num_samples
83 |
84 | # similar to librosa, reflect-pad the input
85 | input_data = input_data.view(num_batches, 1, num_samples)
86 | input_data = F.pad(
87 | input_data.unsqueeze(1),
88 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
89 | mode='reflect')
90 | input_data = input_data.squeeze(1)
91 |
92 | forward_transform = F.conv1d(
93 | input_data.to(self.device),
94 | Variable(self.forward_basis, requires_grad=False).to(self.device),
95 | stride=self.hop_length,
96 | padding=0)
97 |
98 | cutoff = int((self.filter_length / 2) + 1)
99 | real_part = forward_transform[:, :cutoff, :]
100 | imag_part = forward_transform[:, cutoff:, :]
101 |
102 | magnitude = torch.sqrt(real_part**2 + imag_part**2)
103 | phase = torch.autograd.Variable(
104 | torch.atan2(imag_part.data, real_part.data))
105 |
106 | return magnitude, phase
107 |
108 | def inverse(self, magnitude, phase):
109 | recombine_magnitude_phase = torch.cat(
110 | [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
111 |
112 | inverse_transform = F.conv_transpose1d(
113 | recombine_magnitude_phase,
114 | Variable(self.inverse_basis, requires_grad=False),
115 | stride=self.hop_length,
116 | padding=0)
117 |
118 | if self.window is not None:
119 | window_sum = window_sumsquare(
120 | self.window, magnitude.size(-1), hop_length=self.hop_length,
121 | win_length=self.win_length, n_fft=self.filter_length,
122 | dtype=np.float32)
123 | # remove modulation effects
124 | approx_nonzero_indices = torch.from_numpy(
125 | np.where(window_sum > tiny(window_sum))[0])
126 | window_sum = torch.autograd.Variable(
127 | torch.from_numpy(window_sum), requires_grad=False)
128 | window_sum = window_sum.to(self.device) if magnitude.is_cuda else window_sum
129 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
130 |
131 | # scale by hop ratio
132 | inverse_transform *= float(self.filter_length) / self.hop_length
133 |
134 | inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
135 | inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]
136 |
137 | return inverse_transform
138 |
139 | def forward(self, input_data):
140 | self.magnitude, self.phase = self.transform(input_data)
141 | reconstruction = self.inverse(self.magnitude, self.phase)
142 | return reconstruction
143 |
--------------------------------------------------------------------------------
/switch.py:
--------------------------------------------------------------------------------
1 | # Taken from https://stackoverflow.com/a/51825815
2 | from PyQt5.QtCore import QPropertyAnimation, QRectF, QSize, Qt, pyqtProperty
3 | from PyQt5.QtGui import QPainter
4 | from PyQt5.QtWidgets import (
5 | QAbstractButton,
6 | QApplication,
7 | QHBoxLayout,
8 | QSizePolicy,
9 | QWidget,
10 | )
11 |
12 |
13 | class Switch(QAbstractButton):
14 | def __init__(self, parent=None, track_radius=10, thumb_radius=8, show_text = True):
15 | super().__init__(parent=parent)
16 | self.setCheckable(True)
17 | self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
18 |
19 | self._track_radius = track_radius
20 | self._thumb_radius = thumb_radius
21 |
22 | self._margin = max(0, self._thumb_radius - self._track_radius)
23 | self._base_offset = max(self._thumb_radius, self._track_radius)
24 | self._end_offset = {
25 | True: lambda: self.width() - self._base_offset,
26 | False: lambda: self._base_offset,
27 | }
28 | self._offset = self._base_offset
29 |
30 | palette = self.palette()
31 | if not show_text:
32 | self._track_color = {
33 | True: palette.highlight(),
34 | False: palette.dark(),
35 | }
36 | self._thumb_color = {
37 | True: palette.highlight(),
38 | False: palette.light(),
39 | }
40 | self._text_color = {
41 | True: palette.highlightedText().color(),
42 | False: palette.dark().color(),
43 | }
44 | self._thumb_text = {
45 | True: '',
46 | False: '',
47 | }
48 | self._track_opacity = 0.5
49 | else:
50 | self._thumb_color = {
51 | True: palette.highlightedText(),
52 | False: palette.light(),
53 | }
54 | self._track_color = {
55 | True: palette.highlight(),
56 | False: palette.dark(),
57 | }
58 | self._text_color = {
59 | True: palette.highlight().color(),
60 | False: palette.dark().color(),
61 | }
62 | self._thumb_text = {
63 | True: '✔',
64 | False: '✕',
65 | #True: 'On',
66 | #False: 'Off',
67 | }
68 | self._track_opacity = 1
69 |
70 | @pyqtProperty(int)
71 | def offset(self):
72 | return self._offset
73 |
74 | @offset.setter
75 | def offset(self, value):
76 | self._offset = value
77 | self.update()
78 |
79 | def sizeHint(self): # pylint: disable=invalid-name
80 | return QSize(
81 | 4 * self._track_radius + 2 * self._margin,
82 | 2 * self._track_radius + 2 * self._margin,
83 | )
84 |
85 | def setChecked(self, checked):
86 | super().setChecked(checked)
87 | self.offset = self._end_offset[checked]()
88 |
89 | def resizeEvent(self, event):
90 | super().resizeEvent(event)
91 | self.offset = self._end_offset[self.isChecked()]()
92 |
93 | def paintEvent(self, event): # pylint: disable=invalid-name, unused-argument
94 | p = QPainter(self)
95 | p.setRenderHint(QPainter.Antialiasing, True)
96 | p.setPen(Qt.NoPen)
97 | track_opacity = self._track_opacity
98 | thumb_opacity = 1.0
99 | text_opacity = 1.0
100 | if self.isEnabled():
101 | track_brush = self._track_color[self.isChecked()]
102 | thumb_brush = self._thumb_color[self.isChecked()]
103 | text_color = self._text_color[self.isChecked()]
104 | else:
105 | track_opacity *= 0.8
106 | track_brush = self.palette().shadow()
107 | thumb_brush = self.palette().mid()
108 | text_color = self.palette().shadow().color()
109 |
110 | p.setBrush(track_brush)
111 | p.setOpacity(track_opacity)
112 | p.drawRoundedRect(
113 | self._margin,
114 | self._margin,
115 | self.width() - 2 * self._margin,
116 | self.height() - 2 * self._margin,
117 | self._track_radius,
118 | self._track_radius,
119 | )
120 | p.setBrush(thumb_brush)
121 | p.setOpacity(thumb_opacity)
122 | p.drawEllipse(
123 | self.offset - self._thumb_radius,
124 | self._base_offset - self._thumb_radius,
125 | 2 * self._thumb_radius,
126 | 2 * self._thumb_radius,
127 | )
128 | p.setPen(text_color)
129 | p.setOpacity(text_opacity)
130 | font = p.font()
131 | font.setPixelSize(1.5 * self._thumb_radius)
132 | p.setFont(font)
133 | p.drawText(
134 | QRectF(
135 | self.offset - self._thumb_radius,
136 | self._base_offset - self._thumb_radius,
137 | 2 * self._thumb_radius,
138 | 2 * self._thumb_radius,
139 | ),
140 | Qt.AlignCenter,
141 | self._thumb_text[self.isChecked()],
142 | )
143 |
144 | def mouseReleaseEvent(self, event): # pylint: disable=invalid-name
145 | super().mouseReleaseEvent(event)
146 | if event.button() == Qt.LeftButton:
147 | anim = QPropertyAnimation(self, b'offset', self)
148 | anim.setDuration(120)
149 | anim.setStartValue(self.offset)
150 | anim.setEndValue(self._end_offset[self.isChecked()]())
151 | anim.start()
152 |
153 | def enterEvent(self, event): # pylint: disable=invalid-name
154 | self.setCursor(Qt.PointingHandCursor)
155 | super().enterEvent(event)
156 |
157 |
158 | def main():
159 | app = QApplication([])
160 |
161 | # Thumb size < track size (Gitlab style)
162 | s1 = Switch()
163 | s1.toggled.connect(lambda c: print('toggled', c))
164 | s1.clicked.connect(lambda c: print('clicked', c))
165 | s1.pressed.connect(lambda: print('pressed'))
166 | s1.released.connect(lambda: print('released'))
167 | s2 = Switch()
168 | s2.setEnabled(False)
169 |
170 | # Thumb size > track size (Android style)
171 | s3 = Switch(thumb_radius=11, track_radius=8)
172 | s4 = Switch(thumb_radius=8, track_radius=10, show_text = False)
173 |
174 | #s4.setEnabled(False)
175 |
176 | l = QHBoxLayout()
177 | l.addWidget(s1)
178 | l.addWidget(s2)
179 | l.addWidget(s3)
180 | l.addWidget(s4)
181 | w = QWidget()
182 | w.setLayout(l)
183 | w.show()
184 |
185 | app.exec()
186 |
187 |
188 | if __name__ == '__main__':
189 | main()
--------------------------------------------------------------------------------
/distributed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from torch.nn.modules import Module
4 | from torch.autograd import Variable
5 |
6 | def _flatten_dense_tensors(tensors):
7 | """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
8 | same dense type.
9 | Since inputs are dense, the resulting tensor will be a concatenated 1D
10 | buffer. Element-wise operation on this buffer will be equivalent to
11 | operating individually.
12 | Arguments:
13 | tensors (Iterable[Tensor]): dense tensors to flatten.
14 | Returns:
15 | A contiguous 1D buffer containing input tensors.
16 | """
17 | if len(tensors) == 1:
18 | return tensors[0].contiguous().view(-1)
19 | flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
20 | return flat
21 |
22 | def _unflatten_dense_tensors(flat, tensors):
23 | """View a flat buffer using the sizes of tensors. Assume that tensors are of
24 | same dense type, and that flat is given by _flatten_dense_tensors.
25 | Arguments:
26 | flat (Tensor): flattened dense tensors to unflatten.
27 | tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
28 | unflatten flat.
29 | Returns:
30 | Unflattened dense tensors with sizes same as tensors and values from
31 | flat.
32 | """
33 | outputs = []
34 | offset = 0
35 | for tensor in tensors:
36 | numel = tensor.numel()
37 | outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
38 | offset += numel
39 | return tuple(outputs)
40 |
41 |
42 | '''
43 | This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py
44 | launcher included with this example. It assumes that your run is using multiprocess with 1
45 | GPU/process, that the model is on the correct device, and that torch.set_device has been
46 | used to set the device.
47 |
48 | Parameters are broadcasted to the other processes on initialization of DistributedDataParallel,
49 | and will be allreduced at the finish of the backward pass.
50 | '''
51 | class DistributedDataParallel(Module):
52 |
53 | def __init__(self, module):
54 | super(DistributedDataParallel, self).__init__()
55 | #fallback for PyTorch 0.3
56 | if not hasattr(dist, '_backend'):
57 | self.warn_on_half = True
58 | else:
59 | self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
60 |
61 | self.module = module
62 |
63 | for p in self.module.state_dict().values():
64 | if not torch.is_tensor(p):
65 | continue
66 | dist.broadcast(p, 0)
67 |
68 | def allreduce_params():
69 | if(self.needs_reduction):
70 | self.needs_reduction = False
71 | buckets = {}
72 | for param in self.module.parameters():
73 | if param.requires_grad and param.grad is not None:
74 | tp = type(param.data)
75 | if tp not in buckets:
76 | buckets[tp] = []
77 | buckets[tp].append(param)
78 | if self.warn_on_half:
79 | if torch.cuda.HalfTensor in buckets:
80 | print("WARNING: gloo dist backend for half parameters may be extremely slow." +
81 | " It is recommended to use the NCCL backend in this case. This currently requires" +
82 | "PyTorch built from top of tree master.")
83 | self.warn_on_half = False
84 |
85 | for tp in buckets:
86 | bucket = buckets[tp]
87 | grads = [param.grad.data for param in bucket]
88 | coalesced = _flatten_dense_tensors(grads)
89 | dist.all_reduce(coalesced)
90 | coalesced /= dist.get_world_size()
91 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
92 | buf.copy_(synced)
93 |
94 | for param in list(self.module.parameters()):
95 | def allreduce_hook(*unused):
96 | param._execution_engine.queue_callback(allreduce_params)
97 | if param.requires_grad:
98 | param.register_hook(allreduce_hook)
99 |
100 | def forward(self, *inputs, **kwargs):
101 | self.needs_reduction = True
102 | return self.module(*inputs, **kwargs)
103 |
104 | '''
105 | def _sync_buffers(self):
106 | buffers = list(self.module._all_buffers())
107 | if len(buffers) > 0:
108 | # cross-node buffer sync
109 | flat_buffers = _flatten_dense_tensors(buffers)
110 | dist.broadcast(flat_buffers, 0)
111 | for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
112 | buf.copy_(synced)
113 | def train(self, mode=True):
114 | # Clear NCCL communicator and CUDA event cache of the default group ID,
115 | # These cache will be recreated at the later call. This is currently a
116 | # work-around for a potential NCCL deadlock.
117 | if dist._backend == dist.dist_backend.NCCL:
118 | dist._clear_group_cache()
119 | super(DistributedDataParallel, self).train(mode)
120 | self.module.train(mode)
121 | '''
122 | '''
123 | Modifies existing model to do gradient allreduce, but doesn't change class
124 | so you don't need "module"
125 | '''
126 | def apply_gradient_allreduce(module):
127 | if not hasattr(dist, '_backend'):
128 | module.warn_on_half = True
129 | else:
130 | module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
131 |
132 | for p in module.state_dict().values():
133 | if not torch.is_tensor(p):
134 | continue
135 | dist.broadcast(p, 0)
136 |
137 | def allreduce_params():
138 | if(module.needs_reduction):
139 | module.needs_reduction = False
140 | buckets = {}
141 | for param in module.parameters():
142 | if param.requires_grad and param.grad is not None:
143 | tp = param.data.dtype
144 | if tp not in buckets:
145 | buckets[tp] = []
146 | buckets[tp].append(param)
147 | if module.warn_on_half:
148 | if torch.cuda.HalfTensor in buckets:
149 | print("WARNING: gloo dist backend for half parameters may be extremely slow." +
150 | " It is recommended to use the NCCL backend in this case. This currently requires" +
151 | "PyTorch built from top of tree master.")
152 | module.warn_on_half = False
153 |
154 | for tp in buckets:
155 | bucket = buckets[tp]
156 | grads = [param.grad.data for param in bucket]
157 | coalesced = _flatten_dense_tensors(grads)
158 | dist.all_reduce(coalesced)
159 | coalesced /= dist.get_world_size()
160 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
161 | buf.copy_(synced)
162 |
163 | for param in list(module.parameters()):
164 | def allreduce_hook(*unused):
165 | Variable._execution_engine.queue_callback(allreduce_params)
166 | if param.requires_grad:
167 | param.register_hook(allreduce_hook)
168 |
169 | def set_needs_reduction(self, input, output):
170 | self.needs_reduction = True
171 |
172 | module.register_forward_hook(set_needs_reduction)
173 | return module
174 |
--------------------------------------------------------------------------------
/filelists/ljs_audio_text_val_filelist.txt:
--------------------------------------------------------------------------------
1 | DUMMY/LJ022-0023.wav|The overwhelming majority of people in this country know how to sift the wheat from the chaff in what they hear and what they read.
2 | DUMMY/LJ043-0030.wav|If somebody did that to me, a lousy trick like that, to take my wife away, and all the furniture, I would be mad as hell, too.
3 | DUMMY/LJ005-0201.wav|as is shown by the report of the Commissioners to inquire into the state of the municipal corporations in eighteen thirty-five.
4 | DUMMY/LJ001-0110.wav|Even the Caslon type when enlarged shows great shortcomings in this respect:
5 | DUMMY/LJ003-0345.wav|All the committee could do in this respect was to throw the responsibility on others.
6 | DUMMY/LJ007-0154.wav|These pungent and well-grounded strictures applied with still greater force to the unconvicted prisoner, the man who came to the prison innocent, and still uncontaminated,
7 | DUMMY/LJ018-0098.wav|and recognized as one of the frequenters of the bogus law-stationers. His arrest led to that of others.
8 | DUMMY/LJ047-0044.wav|Oswald was, however, willing to discuss his contacts with Soviet authorities. He denied having any involvement with Soviet intelligence agencies
9 | DUMMY/LJ031-0038.wav|The first physician to see the President at Parkland Hospital was Dr. Charles J. Carrico, a resident in general surgery.
10 | DUMMY/LJ048-0194.wav|during the morning of November twenty-two prior to the motorcade.
11 | DUMMY/LJ049-0026.wav|On occasion the Secret Service has been permitted to have an agent riding in the passenger compartment with the President.
12 | DUMMY/LJ004-0152.wav|although at Mr. Buxton's visit a new jail was in process of erection, the first step towards reform since Howard's visitation in seventeen seventy-four.
13 | DUMMY/LJ008-0278.wav|or theirs might be one of many, and it might be considered necessary to "make an example."
14 | DUMMY/LJ043-0002.wav|The Warren Commission Report. By The President's Commission on the Assassination of President Kennedy. Chapter seven. Lee Harvey Oswald:
15 | DUMMY/LJ009-0114.wav|Mr. Wakefield winds up his graphic but somewhat sensational account by describing another religious service, which may appropriately be inserted here.
16 | DUMMY/LJ028-0506.wav|A modern artist would have difficulty in doing such accurate work.
17 | DUMMY/LJ050-0168.wav|with the particular purposes of the agency involved. The Commission recognizes that this is a controversial area
18 | DUMMY/LJ039-0223.wav|Oswald's Marine training in marksmanship, his other rifle experience and his established familiarity with this particular weapon
19 | DUMMY/LJ029-0032.wav|According to O'Donnell, quote, we had a motorcade wherever we went, end quote.
20 | DUMMY/LJ031-0070.wav|Dr. Clark, who most closely observed the head wound,
21 | DUMMY/LJ034-0198.wav|Euins, who was on the southwest corner of Elm and Houston Streets testified that he could not describe the man he saw in the window.
22 | DUMMY/LJ026-0068.wav|Energy enters the plant, to a small extent,
23 | DUMMY/LJ039-0075.wav|once you know that you must put the crosshairs on the target and that is all that is necessary.
24 | DUMMY/LJ004-0096.wav|the fatal consequences whereof might be prevented if the justices of the peace were duly authorized
25 | DUMMY/LJ005-0014.wav|Speaking on a debate on prison matters, he declared that
26 | DUMMY/LJ012-0161.wav|he was reported to have fallen away to a shadow.
27 | DUMMY/LJ018-0239.wav|His disappearance gave color and substance to evil reports already in circulation that the will and conveyance above referred to
28 | DUMMY/LJ019-0257.wav|Here the tread-wheel was in use, there cellular cranks, or hard-labor machines.
29 | DUMMY/LJ028-0008.wav|you tap gently with your heel upon the shoulder of the dromedary to urge her on.
30 | DUMMY/LJ024-0083.wav|This plan of mine is no attack on the Court;
31 | DUMMY/LJ042-0129.wav|No night clubs or bowling alleys, no places of recreation except the trade union dances. I have had enough.
32 | DUMMY/LJ036-0103.wav|The police asked him whether he could pick out his passenger from the lineup.
33 | DUMMY/LJ046-0058.wav|During his Presidency, Franklin D. Roosevelt made almost four hundred journeys and traveled more than three hundred fifty thousand miles.
34 | DUMMY/LJ014-0076.wav|He was seen afterwards smoking and talking with his hosts in their back parlor, and never seen again alive.
35 | DUMMY/LJ002-0043.wav|long narrow rooms -- one thirty-six feet, six twenty-three feet, and the eighth eighteen,
36 | DUMMY/LJ009-0076.wav|We come to the sermon.
37 | DUMMY/LJ017-0131.wav|even when the high sheriff had told him there was no possibility of a reprieve, and within a few hours of execution.
38 | DUMMY/LJ046-0184.wav|but there is a system for the immediate notification of the Secret Service by the confining institution when a subject is released or escapes.
39 | DUMMY/LJ014-0263.wav|When other pleasures palled he took a theatre, and posed as a munificent patron of the dramatic art.
40 | DUMMY/LJ042-0096.wav|(old exchange rate) in addition to his factory salary of approximately equal amount
41 | DUMMY/LJ049-0050.wav|Hill had both feet on the car and was climbing aboard to assist President and Mrs. Kennedy.
42 | DUMMY/LJ019-0186.wav|seeing that since the establishment of the Central Criminal Court, Newgate received prisoners for trial from several counties,
43 | DUMMY/LJ028-0307.wav|then let twenty days pass, and at the end of that time station near the Chaldasan gates a body of four thousand.
44 | DUMMY/LJ012-0235.wav|While they were in a state of insensibility the murder was committed.
45 | DUMMY/LJ034-0053.wav|reached the same conclusion as Latona that the prints found on the cartons were those of Lee Harvey Oswald.
46 | DUMMY/LJ014-0030.wav|These were damnatory facts which well supported the prosecution.
47 | DUMMY/LJ015-0203.wav|but were the precautions too minute, the vigilance too close to be eluded or overcome?
48 | DUMMY/LJ028-0093.wav|but his scribe wrote it in the manner customary for the scribes of those days to write of their royal masters.
49 | DUMMY/LJ002-0018.wav|The inadequacy of the jail was noticed and reported upon again and again by the grand juries of the city of London,
50 | DUMMY/LJ028-0275.wav|At last, in the twentieth month,
51 | DUMMY/LJ012-0042.wav|which he kept concealed in a hiding-place with a trap-door just under his bed.
52 | DUMMY/LJ011-0096.wav|He married a lady also belonging to the Society of Friends, who brought him a large fortune, which, and his own money, he put into a city firm,
53 | DUMMY/LJ036-0077.wav|Roger D. Craig, a deputy sheriff of Dallas County,
54 | DUMMY/LJ016-0318.wav|Other officials, great lawyers, governors of prisons, and chaplains supported this view.
55 | DUMMY/LJ013-0164.wav|who came from his room ready dressed, a suspicious circumstance, as he was always late in the morning.
56 | DUMMY/LJ027-0141.wav|is closely reproduced in the life-history of existing deer. Or, in other words,
57 | DUMMY/LJ028-0335.wav|accordingly they committed to him the command of their whole army, and put the keys of their city into his hands.
58 | DUMMY/LJ031-0202.wav|Mrs. Kennedy chose the hospital in Bethesda for the autopsy because the President had served in the Navy.
59 | DUMMY/LJ021-0145.wav|From those willing to join in establishing this hoped-for period of peace,
60 | DUMMY/LJ016-0288.wav|"Müller, Müller, He's the man," till a diversion was created by the appearance of the gallows, which was received with continuous yells.
61 | DUMMY/LJ028-0081.wav|Years later, when the archaeologists could readily distinguish the false from the true,
62 | DUMMY/LJ018-0081.wav|his defense being that he had intended to commit suicide, but that, on the appearance of this officer who had wronged him,
63 | DUMMY/LJ021-0066.wav|together with a great increase in the payrolls, there has come a substantial rise in the total of industrial profits
64 | DUMMY/LJ009-0238.wav|After this the sheriffs sent for another rope, but the spectators interfered, and the man was carried back to jail.
65 | DUMMY/LJ005-0079.wav|and improve the morals of the prisoners, and shall insure the proper measure of punishment to convicted offenders.
66 | DUMMY/LJ035-0019.wav|drove to the northwest corner of Elm and Houston, and parked approximately ten feet from the traffic signal.
67 | DUMMY/LJ036-0174.wav|This is the approximate time he entered the roominghouse, according to Earlene Roberts, the housekeeper there.
68 | DUMMY/LJ046-0146.wav|The criteria in effect prior to November twenty-two, nineteen sixty-three, for determining whether to accept material for the PRS general files
69 | DUMMY/LJ017-0044.wav|and the deepest anxiety was felt that the crime, if crime there had been, should be brought home to its perpetrator.
70 | DUMMY/LJ017-0070.wav|but his sporting operations did not prosper, and he became a needy man, always driven to desperate straits for cash.
71 | DUMMY/LJ014-0020.wav|He was soon afterwards arrested on suspicion, and a search of his lodgings brought to light several garments saturated with blood;
72 | DUMMY/LJ016-0020.wav|He never reached the cistern, but fell back into the yard, injuring his legs severely.
73 | DUMMY/LJ045-0230.wav|when he was finally apprehended in the Texas Theatre. Although it is not fully corroborated by others who were present,
74 | DUMMY/LJ035-0129.wav|and she must have run down the stairs ahead of Oswald and would probably have seen or heard him.
75 | DUMMY/LJ008-0307.wav|afterwards express a wish to murder the Recorder for having kept them so long in suspense.
76 | DUMMY/LJ008-0294.wav|nearly indefinitely deferred.
77 | DUMMY/LJ047-0148.wav|On October twenty-five,
78 | DUMMY/LJ008-0111.wav|They entered a "stone cold room," and were presently joined by the prisoner.
79 | DUMMY/LJ034-0042.wav|that he could only testify with certainty that the print was less than three days old.
80 | DUMMY/LJ037-0234.wav|Mrs. Mary Brock, the wife of a mechanic who worked at the station, was there at the time and she saw a white male,
81 | DUMMY/LJ040-0002.wav|Chapter seven. Lee Harvey Oswald: Background and Possible Motives, Part one.
82 | DUMMY/LJ045-0140.wav|The arguments he used to justify his use of the alias suggest that Oswald may have come to think that the whole world was becoming involved
83 | DUMMY/LJ012-0035.wav|the number and names on watches, were carefully removed or obliterated after the goods passed out of his hands.
84 | DUMMY/LJ012-0250.wav|On the seventh July, eighteen thirty-seven,
85 | DUMMY/LJ016-0179.wav|contracted with sheriffs and conveners to work by the job.
86 | DUMMY/LJ016-0138.wav|at a distance from the prison.
87 | DUMMY/LJ027-0052.wav|These principles of homology are essential to a correct interpretation of the facts of morphology.
88 | DUMMY/LJ031-0134.wav|On one occasion Mrs. Johnson, accompanied by two Secret Service agents, left the room to see Mrs. Kennedy and Mrs. Connally.
89 | DUMMY/LJ019-0273.wav|which Sir Joshua Jebb told the committee he considered the proper elements of penal discipline.
90 | DUMMY/LJ014-0110.wav|At the first the boxes were impounded, opened, and found to contain many of O'Connor's effects.
91 | DUMMY/LJ034-0160.wav|on Brennan's subsequent certain identification of Lee Harvey Oswald as the man he saw fire the rifle.
92 | DUMMY/LJ038-0199.wav|eleven. If I am alive and taken prisoner,
93 | DUMMY/LJ014-0010.wav|yet he could not overcome the strange fascination it had for him, and remained by the side of the corpse till the stretcher came.
94 | DUMMY/LJ033-0047.wav|I noticed when I went out that the light was on, end quote,
95 | DUMMY/LJ040-0027.wav|He was never satisfied with anything.
96 | DUMMY/LJ048-0228.wav|and others who were present say that no agent was inebriated or acted improperly.
97 | DUMMY/LJ003-0111.wav|He was in consequence put out of the protection of their internal law, end quote. Their code was a subject of some curiosity.
98 | DUMMY/LJ008-0258.wav|Let me retrace my steps, and speak more in detail of the treatment of the condemned in those bloodthirsty and brutally indifferent days,
99 | DUMMY/LJ029-0022.wav|The original plan called for the President to spend only one day in the State, making whirlwind visits to Dallas, Fort Worth, San Antonio, and Houston.
100 | DUMMY/LJ004-0045.wav|Mr. Sturges Bourne, Sir James Mackintosh, Sir James Scarlett, and William Wilberforce.
101 |
--------------------------------------------------------------------------------
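Each line of the validation filelist above pairs a wav path with its transcript, separated by '|' (the DUMMY prefix is a placeholder for the actual LJ Speech wav directory). A rough sketch of how such a file can be read — the helper name is illustrative, not necessarily the one used by utils.py or data_utils.py:

def load_filepaths_and_text(filelist_path, split='|'):
    # Each non-empty line is "wav_path|transcript".
    with open(filelist_path, encoding='utf-8') as f:
        return [line.strip().split(split) for line in f if line.strip()]

pairs = load_filepaths_and_text('filelists/ljs_audio_text_val_filelist.txt')
# pairs[i] == [wav_path, transcript]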
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import math
5 | from numpy import finfo
6 |
7 | import torch
8 | from distributed import apply_gradient_allreduce
9 | import torch.distributed as dist
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torch.utils.data import DataLoader
12 |
13 | from model import Tacotron2
14 | from data_utils import TextMelLoader, TextMelCollate
15 | from loss_function import Tacotron2Loss
16 | from logger import Tacotron2Logger
17 | from hparams import create_hparams
18 |
19 |
20 | def reduce_tensor(tensor, n_gpus):
21 | rt = tensor.clone()
22 | dist.all_reduce(rt, op=dist.reduce_op.SUM)
23 | rt /= n_gpus
24 | return rt
25 |
26 |
27 | def init_distributed(hparams, n_gpus, rank, group_name):
28 | assert torch.cuda.is_available(), "Distributed mode requires CUDA."
29 | print("Initializing Distributed")
30 |
31 | # Set cuda device so everything is done on the right GPU.
32 | torch.cuda.set_device(rank % torch.cuda.device_count())
33 |
34 | # Initialize distributed communication
35 | dist.init_process_group(
36 | backend=hparams.dist_backend, init_method=hparams.dist_url,
37 | world_size=n_gpus, rank=rank, group_name=group_name)
38 |
39 | print("Done initializing distributed")
40 |
41 |
42 | def prepare_dataloaders(hparams):
43 | # Get data, data loaders and collate function ready
44 | trainset = TextMelLoader(hparams.training_files, hparams)
45 | valset = TextMelLoader(hparams.validation_files, hparams)
46 | collate_fn = TextMelCollate(hparams.n_frames_per_step)
47 |
48 | if hparams.distributed_run:
49 | train_sampler = DistributedSampler(trainset)
50 | shuffle = False
51 | else:
52 | train_sampler = None
53 | shuffle = True
54 |
55 | train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
56 | sampler=train_sampler,
57 | batch_size=hparams.batch_size, pin_memory=False,
58 | drop_last=True, collate_fn=collate_fn)
59 | return train_loader, valset, collate_fn
60 |
61 |
62 | def prepare_directories_and_logger(output_directory, log_directory, rank):
63 | if rank == 0:
64 | if not os.path.isdir(output_directory):
65 | os.makedirs(output_directory)
66 | os.chmod(output_directory, 0o775)
67 | logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
68 | else:
69 | logger = None
70 | return logger
71 |
72 |
73 | def load_model(hparams,use_cuda=True):
74 | device = torch.device('cuda' if use_cuda else 'cpu')
75 | model = Tacotron2(hparams).to(device)
76 | if hparams.fp16_run:
77 | model.decoder.attention_layer.score_mask_value = finfo('float16').min
78 |
79 | if hparams.distributed_run:
80 | model = apply_gradient_allreduce(model)
81 |
82 | return model
83 |
84 |
85 | def warm_start_model(checkpoint_path, model, ignore_layers):
86 | assert os.path.isfile(checkpoint_path)
87 | print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
88 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
89 | model_dict = checkpoint_dict['state_dict']
90 | if len(ignore_layers) > 0:
91 | model_dict = {k: v for k, v in model_dict.items()
92 | if k not in ignore_layers}
93 | dummy_dict = model.state_dict()
94 | dummy_dict.update(model_dict)
95 | model_dict = dummy_dict
96 | model.load_state_dict(model_dict)
97 | return model
98 |
99 |
100 | def load_checkpoint(checkpoint_path, model, optimizer):
101 | assert os.path.isfile(checkpoint_path)
102 | print("Loading checkpoint '{}'".format(checkpoint_path))
103 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
104 | model.load_state_dict(checkpoint_dict['state_dict'])
105 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
106 | learning_rate = checkpoint_dict['learning_rate']
107 | iteration = checkpoint_dict['iteration']
108 | print("Loaded checkpoint '{}' from iteration {}" .format(
109 | checkpoint_path, iteration))
110 | return model, optimizer, learning_rate, iteration
111 |
112 |
113 | def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
114 | print("Saving model and optimizer state at iteration {} to {}".format(
115 | iteration, filepath))
116 | torch.save({'iteration': iteration,
117 | 'state_dict': model.state_dict(),
118 | 'optimizer': optimizer.state_dict(),
119 | 'learning_rate': learning_rate}, filepath)
120 |
121 |
122 | def validate(model, criterion, valset, iteration, batch_size, n_gpus,
123 | collate_fn, logger, distributed_run, rank):
124 | """Handles all the validation scoring and printing"""
125 | model.eval()
126 | with torch.no_grad():
127 | val_sampler = DistributedSampler(valset) if distributed_run else None
128 | val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
129 | shuffle=False, batch_size=batch_size,
130 | pin_memory=False, collate_fn=collate_fn)
131 |
132 | val_loss = 0.0
133 | for i, batch in enumerate(val_loader):
134 | x, y = model.parse_batch(batch)
135 | y_pred = model(x)
136 | loss = criterion(y_pred, y)
137 | if distributed_run:
138 | reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
139 | else:
140 | reduced_val_loss = loss.item()
141 | val_loss += reduced_val_loss
142 | val_loss = val_loss / (i + 1)
143 |
144 | model.train()
145 | if rank == 0:
146 | print("Validation loss {}: {:9f} ".format(iteration, val_loss))
147 | logger.log_validation(val_loss, model, y, y_pred, iteration)
148 |
149 |
150 | def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
151 | rank, group_name, hparams):
152 | """Training and validation logging results to tensorboard and stdout
153 |
154 | Params
155 | ------
156 | output_directory (string): directory to save checkpoints
157 | log_directory (string): directory to save tensorboard logs
158 | checkpoint_path (string): checkpoint path
159 | n_gpus (int): number of gpus
160 | rank (int): rank of current gpu
161 | hparams (object): hyperparameters object, created from comma separated "name=value" pairs.
162 | """
163 | if hparams.distributed_run:
164 | init_distributed(hparams, n_gpus, rank, group_name)
165 |
166 | torch.manual_seed(hparams.seed)
167 | torch.cuda.manual_seed(hparams.seed)
168 |
169 | model = load_model(hparams)
170 | learning_rate = hparams.learning_rate
171 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
172 | weight_decay=hparams.weight_decay)
173 |
174 | if hparams.fp16_run:
175 | from apex import amp
176 | model, optimizer = amp.initialize(
177 | model, optimizer, opt_level='O2')
178 |
179 | if hparams.distributed_run:
180 | model = apply_gradient_allreduce(model)
181 |
182 | criterion = Tacotron2Loss()
183 |
184 | logger = prepare_directories_and_logger(
185 | output_directory, log_directory, rank)
186 |
187 | train_loader, valset, collate_fn = prepare_dataloaders(hparams)
188 |
189 | # Load checkpoint if one exists
190 | iteration = 0
191 | epoch_offset = 0
192 | if checkpoint_path is not None:
193 | if warm_start:
194 | model = warm_start_model(
195 | checkpoint_path, model, hparams.ignore_layers)
196 | else:
197 | model, optimizer, _learning_rate, iteration = load_checkpoint(
198 | checkpoint_path, model, optimizer)
199 | if hparams.use_saved_learning_rate:
200 | learning_rate = _learning_rate
201 | iteration += 1 # next iteration is iteration + 1
202 | epoch_offset = max(0, int(iteration / len(train_loader)))
203 |
204 | model.train()
205 | is_overflow = False
206 | # ================ MAIN TRAINING LOOP! ===================
207 | for epoch in range(epoch_offset, hparams.epochs):
208 | print("Epoch: {}".format(epoch))
209 | for i, batch in enumerate(train_loader):
210 | start = time.perf_counter()
211 | for param_group in optimizer.param_groups:
212 | param_group['lr'] = learning_rate
213 |
214 | model.zero_grad()
215 | x, y = model.parse_batch(batch)
216 | y_pred = model(x)
217 |
218 | loss = criterion(y_pred, y)
219 | if hparams.distributed_run:
220 | reduced_loss = reduce_tensor(loss.data, n_gpus).item()
221 | else:
222 | reduced_loss = loss.item()
223 | if hparams.fp16_run:
224 | with amp.scale_loss(loss, optimizer) as scaled_loss:
225 | scaled_loss.backward()
226 | else:
227 | loss.backward()
228 |
229 | if hparams.fp16_run:
230 | grad_norm = torch.nn.utils.clip_grad_norm_(
231 | amp.master_params(optimizer), hparams.grad_clip_thresh)
232 | is_overflow = math.isnan(grad_norm)
233 | else:
234 | grad_norm = torch.nn.utils.clip_grad_norm_(
235 | model.parameters(), hparams.grad_clip_thresh)
236 |
237 | optimizer.step()
238 |
239 | if not is_overflow and rank == 0:
240 | duration = time.perf_counter() - start
241 | print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
242 | iteration, reduced_loss, grad_norm, duration))
243 | logger.log_training(
244 | reduced_loss, grad_norm, learning_rate, duration, iteration)
245 |
246 | if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0):
247 | validate(model, criterion, valset, iteration,
248 | hparams.batch_size, n_gpus, collate_fn, logger,
249 | hparams.distributed_run, rank)
250 | if rank == 0:
251 | checkpoint_path = os.path.join(
252 | output_directory, "checkpoint_{}".format(iteration))
253 | save_checkpoint(model, optimizer, learning_rate, iteration,
254 | checkpoint_path)
255 |
256 | iteration += 1
257 |
258 |
259 | if __name__ == '__main__':
260 | parser = argparse.ArgumentParser()
261 | parser.add_argument('-o', '--output_directory', type=str,
262 | help='directory to save checkpoints')
263 | parser.add_argument('-l', '--log_directory', type=str,
264 | help='directory to save tensorboard logs')
265 | parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
266 | required=False, help='checkpoint path')
267 | parser.add_argument('--warm_start', action='store_true',
268 | help='load model weights only, ignore specified layers')
269 | parser.add_argument('--n_gpus', type=int, default=1,
270 | required=False, help='number of gpus')
271 | parser.add_argument('--rank', type=int, default=0,
272 | required=False, help='rank of current gpu')
273 | parser.add_argument('--group_name', type=str, default='group_name',
274 | required=False, help='Distributed group name')
275 | parser.add_argument('--hparams', type=str,
276 | required=False, help='comma separated name=value pairs')
277 |
278 | args = parser.parse_args()
279 | hparams = create_hparams(args.hparams)
280 |
281 | torch.backends.cudnn.enabled = hparams.cudnn_enabled
282 | torch.backends.cudnn.benchmark = hparams.cudnn_benchmark
283 |
284 | print("FP16 Run:", hparams.fp16_run)
285 | print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
286 | print("Distributed Run:", hparams.distributed_run)
287 | print("cuDNN Enabled:", hparams.cudnn_enabled)
288 | print("cuDNN Benchmark:", hparams.cudnn_benchmark)
289 |
290 | train(args.output_directory, args.log_directory, args.checkpoint_path,
291 | args.warm_start, args.n_gpus, args.rank, args.group_name, hparams)
292 |
--------------------------------------------------------------------------------
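For reference, train.py is typically launched from the command line using the arguments parsed above; the paths and hparams overrides below are illustrative only:

# single GPU
python train.py --output_directory=outdir --log_directory=logdir

# resume from a checkpoint and override hyperparameters
python train.py -o outdir -l logdir -c outdir/checkpoint_10000 --hparams=batch_size=32

# multi-GPU / distributed run, launched through multiproc.py
python -m multiproc train.py -o outdir -l logdir --hparams=distributed_run=True,fp16_run=True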
/nvidia_tacotron_TTS_Layout.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Form implementation generated from reading ui file '.\nvidia_tacotron_TTS_Layout.ui'
4 | #
5 | # Created by: PyQt5 UI code generator 5.15.0
6 | #
7 | # WARNING: Any manual changes made to this file will be lost when pyuic5 is
8 | # run again. Do not edit this file unless you know what you are doing.
9 |
10 |
11 | from PyQt5 import QtCore, QtGui, QtWidgets
12 |
13 |
14 | class Ui_MainWindow(object):
15 | def setupUi(self, MainWindow):
16 | MainWindow.setObjectName("MainWindow")
17 | MainWindow.resize(518, 534)
18 | self.centralwidget = QtWidgets.QWidget(MainWindow)
19 | self.centralwidget.setObjectName("centralwidget")
20 | self.verticalLayout_3 = QtWidgets.QVBoxLayout(self.centralwidget)
21 | self.verticalLayout_3.setObjectName("verticalLayout_3")
22 | self.tabWidget = QtWidgets.QTabWidget(self.centralwidget)
23 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.MinimumExpanding)
24 | sizePolicy.setHorizontalStretch(0)
25 | sizePolicy.setVerticalStretch(0)
26 | sizePolicy.setHeightForWidth(self.tabWidget.sizePolicy().hasHeightForWidth())
27 | self.tabWidget.setSizePolicy(sizePolicy)
28 | self.tabWidget.setObjectName("tabWidget")
29 | self.tab = QtWidgets.QWidget()
30 | self.tab.setObjectName("tab")
31 | self.verticalLayout = QtWidgets.QVBoxLayout(self.tab)
32 | self.verticalLayout.setObjectName("verticalLayout")
33 | self.gridLayout_5 = QtWidgets.QGridLayout()
34 | self.gridLayout_5.setObjectName("gridLayout_5")
35 | self.label_9 = QtWidgets.QLabel(self.tab)
36 | self.label_9.setObjectName("label_9")
37 | self.gridLayout_5.addWidget(self.label_9, 1, 0, 1, 1)
38 | self.WGModelCombo = QtWidgets.QComboBox(self.tab)
39 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed)
40 | sizePolicy.setHorizontalStretch(0)
41 | sizePolicy.setVerticalStretch(0)
42 | sizePolicy.setHeightForWidth(self.WGModelCombo.sizePolicy().hasHeightForWidth())
43 | self.WGModelCombo.setSizePolicy(sizePolicy)
44 | self.WGModelCombo.setObjectName("WGModelCombo")
45 | self.gridLayout_5.addWidget(self.WGModelCombo, 1, 2, 1, 1)
46 | self.label_7 = QtWidgets.QLabel(self.tab)
47 | self.label_7.setObjectName("label_7")
48 | self.gridLayout_5.addWidget(self.label_7, 0, 0, 1, 1)
49 | self.TTModelCombo = QtWidgets.QComboBox(self.tab)
50 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed)
51 | sizePolicy.setHorizontalStretch(0)
52 | sizePolicy.setVerticalStretch(0)
53 | sizePolicy.setHeightForWidth(self.TTModelCombo.sizePolicy().hasHeightForWidth())
54 | self.TTModelCombo.setSizePolicy(sizePolicy)
55 | self.TTModelCombo.setObjectName("TTModelCombo")
56 | self.gridLayout_5.addWidget(self.TTModelCombo, 0, 2, 1, 1)
57 | self.LoadTTButton = QtWidgets.QToolButton(self.tab)
58 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed)
59 | sizePolicy.setHorizontalStretch(0)
60 | sizePolicy.setVerticalStretch(0)
61 | sizePolicy.setHeightForWidth(self.LoadTTButton.sizePolicy().hasHeightForWidth())
62 | self.LoadTTButton.setSizePolicy(sizePolicy)
63 | self.LoadTTButton.setLayoutDirection(QtCore.Qt.LeftToRight)
64 | self.LoadTTButton.setObjectName("LoadTTButton")
65 | self.gridLayout_5.addWidget(self.LoadTTButton, 0, 3, 1, 1)
66 | self.LoadWGButton = QtWidgets.QToolButton(self.tab)
67 | self.LoadWGButton.setObjectName("LoadWGButton")
68 | self.gridLayout_5.addWidget(self.LoadWGButton, 1, 3, 1, 1)
69 | self.verticalLayout.addLayout(self.gridLayout_5)
70 | self.horizontalLayout = QtWidgets.QHBoxLayout()
71 | self.horizontalLayout.setObjectName("horizontalLayout")
72 | spacerItem = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
73 | self.horizontalLayout.addItem(spacerItem)
74 | self.label = QtWidgets.QLabel(self.tab)
75 | self.label.setObjectName("label")
76 | self.horizontalLayout.addWidget(self.label)
77 | self.verticalLayout.addLayout(self.horizontalLayout)
78 | spacerItem1 = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
79 | self.verticalLayout.addItem(spacerItem1)
80 | self.label_2 = QtWidgets.QLabel(self.tab)
81 | self.label_2.setObjectName("label_2")
82 | self.verticalLayout.addWidget(self.label_2)
83 | self.TTSTextEdit = QtWidgets.QPlainTextEdit(self.tab)
84 | self.TTSTextEdit.setObjectName("TTSTextEdit")
85 | self.verticalLayout.addWidget(self.TTSTextEdit)
86 | self.gridLayout_2 = QtWidgets.QGridLayout()
87 | self.gridLayout_2.setObjectName("gridLayout_2")
88 | self.progressBar = QtWidgets.QProgressBar(self.tab)
89 | self.progressBar.setProperty("value", 0)
90 | self.progressBar.setObjectName("progressBar")
91 | self.gridLayout_2.addWidget(self.progressBar, 1, 0, 1, 1)
92 | self.progressBarLabel = QtWidgets.QLabel(self.tab)
93 | self.progressBarLabel.setObjectName("progressBarLabel")
94 | self.gridLayout_2.addWidget(self.progressBarLabel, 1, 0, 1, 1, QtCore.Qt.AlignHCenter)
95 | self.verticalLayout.addLayout(self.gridLayout_2)
96 | self.gridLayout_4 = QtWidgets.QGridLayout()
97 | self.gridLayout_4.setObjectName("gridLayout_4")
98 | spacerItem2 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
99 | self.gridLayout_4.addItem(spacerItem2, 0, 0, 1, 1)
100 | self.TTSDialogButton = QtWidgets.QPushButton(self.tab)
101 | self.TTSDialogButton.setObjectName("TTSDialogButton")
102 | self.gridLayout_4.addWidget(self.TTSDialogButton, 0, 2, 1, 1, QtCore.Qt.AlignRight)
103 | self.TTSStopButton = QtWidgets.QPushButton(self.tab)
104 | self.TTSStopButton.setObjectName("TTSStopButton")
105 | self.gridLayout_4.addWidget(self.TTSStopButton, 0, 3, 1, 1)
106 | self.verticalLayout.addLayout(self.gridLayout_4)
107 | self.log_window1 = QtWidgets.QLabel(self.tab)
108 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed)
109 | sizePolicy.setHorizontalStretch(0)
110 | sizePolicy.setVerticalStretch(0)
111 | sizePolicy.setHeightForWidth(self.log_window1.sizePolicy().hasHeightForWidth())
112 | self.log_window1.setSizePolicy(sizePolicy)
113 | self.log_window1.setMinimumSize(QtCore.QSize(0, 39))
114 | self.log_window1.setAlignment(QtCore.Qt.AlignBottom|QtCore.Qt.AlignLeading|QtCore.Qt.AlignLeft)
115 | self.log_window1.setObjectName("log_window1")
116 | self.verticalLayout.addWidget(self.log_window1)
117 | self.tabWidget.addTab(self.tab, "")
118 | self.tab_2 = QtWidgets.QWidget()
119 | self.tab_2.setObjectName("tab_2")
120 | self.verticalLayout_2 = QtWidgets.QVBoxLayout(self.tab_2)
121 | self.verticalLayout_2.setObjectName("verticalLayout_2")
122 | self.gridLayout = QtWidgets.QGridLayout()
123 | self.gridLayout.setSizeConstraint(QtWidgets.QLayout.SetFixedSize)
124 | self.gridLayout.setObjectName("gridLayout")
125 | self.label_5 = QtWidgets.QLabel(self.tab_2)
126 | self.label_5.setObjectName("label_5")
127 | self.gridLayout.addWidget(self.label_5, 0, 0, 1, 1)
128 | self.ChannelName = QtWidgets.QLineEdit(self.tab_2)
129 | self.ChannelName.setObjectName("ChannelName")
130 | self.gridLayout.addWidget(self.ChannelName, 0, 1, 1, 1)
131 | self.label_3 = QtWidgets.QLabel(self.tab_2)
132 | self.label_3.setObjectName("label_3")
133 | self.gridLayout.addWidget(self.label_3, 1, 0, 1, 1)
134 | self.APIKeyLine = QtWidgets.QLineEdit(self.tab_2)
135 | self.APIKeyLine.setEchoMode(QtWidgets.QLineEdit.Password)
136 | self.APIKeyLine.setObjectName("APIKeyLine")
137 | self.gridLayout.addWidget(self.APIKeyLine, 1, 1, 1, 1)
138 | self.verticalLayout_2.addLayout(self.gridLayout)
139 | self.horizontalLayout_2 = QtWidgets.QHBoxLayout()
140 | self.horizontalLayout_2.setObjectName("horizontalLayout_2")
141 | spacerItem3 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
142 | self.horizontalLayout_2.addItem(spacerItem3)
143 | self.ClientStartBtn = QtWidgets.QPushButton(self.tab_2)
144 | self.ClientStartBtn.setObjectName("ClientStartBtn")
145 | self.horizontalLayout_2.addWidget(self.ClientStartBtn)
146 | self.ClientStopBtn = QtWidgets.QPushButton(self.tab_2)
147 | self.ClientStopBtn.setObjectName("ClientStopBtn")
148 | self.horizontalLayout_2.addWidget(self.ClientStopBtn, 0, QtCore.Qt.AlignRight)
149 | self.ClientSkipBtn = QtWidgets.QPushButton(self.tab_2)
150 | self.ClientSkipBtn.setObjectName("ClientSkipBtn")
151 | self.horizontalLayout_2.addWidget(self.ClientSkipBtn)
152 | self.verticalLayout_2.addLayout(self.horizontalLayout_2)
153 | spacerItem4 = QtWidgets.QSpacerItem(20, 20, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Fixed)
154 | self.verticalLayout_2.addItem(spacerItem4)
155 | self.widget = QtWidgets.QWidget(self.tab_2)
156 | self.widget.setObjectName("widget")
157 | self.verticalLayout_2.addWidget(self.widget)
158 | self.label_8 = QtWidgets.QLabel(self.tab_2)
159 | self.label_8.setObjectName("label_8")
160 | self.verticalLayout_2.addWidget(self.label_8)
161 | self.ClientAmountLine = QtWidgets.QDoubleSpinBox(self.tab_2)
162 | self.ClientAmountLine.setObjectName("ClientAmountLine")
163 | self.verticalLayout_2.addWidget(self.ClientAmountLine)
164 | self.label_4 = QtWidgets.QLabel(self.tab_2)
165 | self.label_4.setObjectName("label_4")
166 | self.verticalLayout_2.addWidget(self.label_4)
167 | self.log_window2 = QtWidgets.QPlainTextEdit(self.tab_2)
168 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.MinimumExpanding)
169 | sizePolicy.setHorizontalStretch(0)
170 | sizePolicy.setVerticalStretch(0)
171 | sizePolicy.setHeightForWidth(self.log_window2.sizePolicy().hasHeightForWidth())
172 | self.log_window2.setSizePolicy(sizePolicy)
173 | self.log_window2.setAutoFillBackground(False)
174 | self.log_window2.setObjectName("log_window2")
175 | self.verticalLayout_2.addWidget(self.log_window2)
176 | self.gridLayout_6 = QtWidgets.QGridLayout()
177 | self.gridLayout_6.setObjectName("gridLayout_6")
178 | self.progressBar2 = QtWidgets.QProgressBar(self.tab_2)
179 | self.progressBar2.setProperty("value", 0)
180 | self.progressBar2.setTextVisible(True)
181 | self.progressBar2.setOrientation(QtCore.Qt.Horizontal)
182 | self.progressBar2.setObjectName("progressBar2")
183 | self.gridLayout_6.addWidget(self.progressBar2, 0, 0, 1, 1)
184 | self.progressBar2Label = QtWidgets.QLabel(self.tab_2)
185 | self.progressBar2Label.setObjectName("progressBar2Label")
186 | self.gridLayout_6.addWidget(self.progressBar2Label, 0, 0, 1, 1, QtCore.Qt.AlignHCenter)
187 | self.verticalLayout_2.addLayout(self.gridLayout_6)
188 | self.tabWidget.addTab(self.tab_2, "")
189 | self.tab_3 = QtWidgets.QWidget()
190 | self.tab_3.setObjectName("tab_3")
191 | self.verticalLayout_4 = QtWidgets.QVBoxLayout(self.tab_3)
192 | self.verticalLayout_4.setObjectName("verticalLayout_4")
193 | self.groupBox = QtWidgets.QGroupBox(self.tab_3)
194 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Fixed)
195 | sizePolicy.setHorizontalStretch(0)
196 | sizePolicy.setVerticalStretch(0)
197 | sizePolicy.setHeightForWidth(self.groupBox.sizePolicy().hasHeightForWidth())
198 | self.groupBox.setSizePolicy(sizePolicy)
199 | self.groupBox.setMaximumSize(QtCore.QSize(16777215, 150))
200 | self.groupBox.setObjectName("groupBox")
201 | self.gridLayout_7 = QtWidgets.QGridLayout(self.groupBox)
202 | self.gridLayout_7.setObjectName("gridLayout_7")
203 | self.OptLimitCpuBtn = QtWidgets.QCheckBox(self.groupBox)
204 | self.OptLimitCpuBtn.setMaximumSize(QtCore.QSize(67, 17))
205 | self.OptLimitCpuBtn.setObjectName("OptLimitCpuBtn")
206 | self.gridLayout_7.addWidget(self.OptLimitCpuBtn, 0, 0, 1, 1)
207 | spacerItem5 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding)
208 | self.gridLayout_7.addItem(spacerItem5, 1, 0, 1, 1)
209 | spacerItem6 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
210 | self.gridLayout_7.addItem(spacerItem6, 0, 4, 1, 1)
211 | self.label_10 = QtWidgets.QLabel(self.groupBox)
212 | self.label_10.setMaximumSize(QtCore.QSize(43, 16777215))
213 | self.label_10.setObjectName("label_10")
214 | self.gridLayout_7.addWidget(self.label_10, 0, 2, 1, 1)
215 | self.OptLimitCpuCombo = QtWidgets.QComboBox(self.groupBox)
216 | self.OptLimitCpuCombo.setEnabled(True)
217 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed)
218 | sizePolicy.setHorizontalStretch(0)
219 | sizePolicy.setVerticalStretch(0)
220 | sizePolicy.setHeightForWidth(self.OptLimitCpuCombo.sizePolicy().hasHeightForWidth())
221 | self.OptLimitCpuCombo.setSizePolicy(sizePolicy)
222 | self.OptLimitCpuCombo.setMaximumSize(QtCore.QSize(30, 20))
223 | self.OptLimitCpuCombo.setObjectName("OptLimitCpuCombo")
224 | self.gridLayout_7.addWidget(self.OptLimitCpuCombo, 0, 3, 1, 1)
225 | spacerItem7 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Minimum)
226 | self.gridLayout_7.addItem(spacerItem7, 0, 1, 1, 1)
227 | self.verticalLayout_4.addWidget(self.groupBox)
228 | self.groupBox_2 = QtWidgets.QGroupBox(self.tab_3)
229 | self.groupBox_2.setObjectName("groupBox_2")
230 | self.verticalLayout_5 = QtWidgets.QVBoxLayout(self.groupBox_2)
231 | self.verticalLayout_5.setObjectName("verticalLayout_5")
232 | self.OptDonoNameAmountBtn = QtWidgets.QCheckBox(self.groupBox_2)
233 | self.OptDonoNameAmountBtn.setChecked(True)
234 | self.OptDonoNameAmountBtn.setObjectName("OptDonoNameAmountBtn")
235 | self.verticalLayout_5.addWidget(self.OptDonoNameAmountBtn)
236 | self.OptApproveDonoBtn = QtWidgets.QCheckBox(self.groupBox_2)
237 | sizePolicy = QtWidgets.QSizePolicy(QtWidgets.QSizePolicy.MinimumExpanding, QtWidgets.QSizePolicy.Fixed)
238 | sizePolicy.setHorizontalStretch(0)
239 | sizePolicy.setVerticalStretch(0)
240 | sizePolicy.setHeightForWidth(self.OptApproveDonoBtn.sizePolicy().hasHeightForWidth())
241 | self.OptApproveDonoBtn.setSizePolicy(sizePolicy)
242 | self.OptApproveDonoBtn.setChecked(True)
243 | self.OptApproveDonoBtn.setObjectName("OptApproveDonoBtn")
244 | self.verticalLayout_5.addWidget(self.OptApproveDonoBtn)
245 | self.OptBlockNumberBtn = QtWidgets.QCheckBox(self.groupBox_2)
246 | self.OptBlockNumberBtn.setObjectName("OptBlockNumberBtn")
247 | self.verticalLayout_5.addWidget(self.OptBlockNumberBtn)
248 | spacerItem8 = QtWidgets.QSpacerItem(20, 40, QtWidgets.QSizePolicy.Minimum, QtWidgets.QSizePolicy.Expanding)
249 | self.verticalLayout_5.addItem(spacerItem8)
250 | self.verticalLayout_4.addWidget(self.groupBox_2)
251 | self.tabWidget.addTab(self.tab_3, "")
252 | self.verticalLayout_3.addWidget(self.tabWidget)
253 | self.gridLayout_3 = QtWidgets.QGridLayout()
254 | self.gridLayout_3.setObjectName("gridLayout_3")
255 | self.statusbar = QtWidgets.QLabel(self.centralwidget)
256 | self.statusbar.setObjectName("statusbar")
257 | self.gridLayout_3.addWidget(self.statusbar, 0, 0, 1, 1, QtCore.Qt.AlignLeft)
258 | spacerItem9 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
259 | self.gridLayout_3.addItem(spacerItem9, 0, 1, 1, 1)
260 | self.label_6 = QtWidgets.QLabel(self.centralwidget)
261 | self.label_6.setOpenExternalLinks(True)
262 | self.label_6.setObjectName("label_6")
263 | self.gridLayout_3.addWidget(self.label_6, 0, 2, 1, 1, QtCore.Qt.AlignRight)
264 | self.verticalLayout_3.addLayout(self.gridLayout_3)
265 | MainWindow.setCentralWidget(self.centralwidget)
266 |
267 | self.retranslateUi(MainWindow)
268 | self.tabWidget.setCurrentIndex(0)
269 | self.setWindowTitle("Tacotron2 + Waveglow GUI v0.3")
270 |
271 | def retranslateUi(self, MainWindow):
272 | _translate = QtCore.QCoreApplication.translate
273 | MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
274 | self.label_9.setText(_translate("MainWindow", "Waveglow Model:"))
275 | self.label_7.setText(_translate("MainWindow", "Tacotron 2 Model:"))
276 | self.LoadTTButton.setText(_translate("MainWindow", "Browse..."))
277 | self.LoadWGButton.setText(_translate("MainWindow", "Browse..."))
278 | self.label.setText(_translate("MainWindow", "GPU Mode (Requires CUDA)"))
279 | self.label_2.setText(_translate("MainWindow", "Enter Text to speech:"))
280 | self.progressBarLabel.setText(_translate("MainWindow", "TextLabel"))
281 | self.TTSDialogButton.setText(_translate("MainWindow", "Start"))
282 | self.TTSStopButton.setText(_translate("MainWindow", "Stop"))
283 | self.log_window1.setText(_translate("MainWindow", "Begin by loading a voice model"))
284 | self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab), _translate("MainWindow", "Text to Speech"))
285 | self.label_5.setText(_translate("MainWindow", "Channel name"))
286 | self.label_3.setText(_translate("MainWindow", "StreamElements JWT Token"))
287 | self.ClientStartBtn.setText(_translate("MainWindow", "Start"))
288 | self.ClientStopBtn.setText(_translate("MainWindow", "Stop"))
289 | self.ClientSkipBtn.setText(_translate("MainWindow", "Skip"))
290 | self.label_8.setText(_translate("MainWindow", "Minimum amount for TTS:"))
291 | self.label_4.setText(_translate("MainWindow", "Status:"))
292 | self.progressBar2Label.setText(_translate("MainWindow", "TextLabel"))
293 | self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab_2), _translate("MainWindow", "StreamElements"))
294 | self.groupBox.setTitle(_translate("MainWindow", "Pytorch"))
295 | self.OptLimitCpuBtn.setText(_translate("MainWindow", "Limit CPU"))
296 | self.label_10.setText(_translate("MainWindow", "Threads:"))
297 | self.groupBox_2.setTitle(_translate("MainWindow", "StreamElement"))
298 | self.OptDonoNameAmountBtn.setText(_translate("MainWindow", "Read donor\'s name and amount"))
299 | self.OptApproveDonoBtn.setText(_translate("MainWindow", "Approved donations only"))
300 | self.OptBlockNumberBtn.setText(_translate("MainWindow", "Block large numbers (>8 digits)"))
301 | self.tabWidget.setTabText(self.tabWidget.indexOf(self.tab_3), _translate("MainWindow", "Options"))
302 | self.statusbar.setText(_translate("MainWindow", "Ready"))
303 | self.label_6.setText(_translate("MainWindow", "v0.3"))
304 |
--------------------------------------------------------------------------------
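Since this layout module is pyuic5 output, it is normally consumed by subclassing rather than edited directly (gui.py presumably plays that role here). A minimal standalone sketch of wiring it to a window — note that setupUi() ends by calling self.setWindowTitle(), so the Ui class needs to be mixed into an actual QMainWindow:

import sys
from PyQt5 import QtWidgets
from nvidia_tacotron_TTS_Layout import Ui_MainWindow

class MainWindow(QtWidgets.QMainWindow, Ui_MainWindow):
    def __init__(self):
        super().__init__()
        self.setupUi(self)  # builds the three tabs onto this window

if __name__ == '__main__':
    app = QtWidgets.QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(app.exec_())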
/model.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | import torch
3 | from torch.autograd import Variable
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from layers import ConvNorm, LinearNorm
7 | from utils import to_gpu, get_mask_from_lengths
8 |
9 |
10 | class LocationLayer(nn.Module):
11 | def __init__(self, attention_n_filters, attention_kernel_size,
12 | attention_dim):
13 | super(LocationLayer, self).__init__()
14 | padding = int((attention_kernel_size - 1) / 2)
15 | self.location_conv = ConvNorm(2, attention_n_filters,
16 | kernel_size=attention_kernel_size,
17 | padding=padding, bias=False, stride=1,
18 | dilation=1)
19 | self.location_dense = LinearNorm(attention_n_filters, attention_dim,
20 | bias=False, w_init_gain='tanh')
21 |
22 | def forward(self, attention_weights_cat):
23 | processed_attention = self.location_conv(attention_weights_cat)
24 | processed_attention = processed_attention.transpose(1, 2)
25 | processed_attention = self.location_dense(processed_attention)
26 | return processed_attention
27 |
28 |
29 | class Attention(nn.Module):
30 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
31 | attention_location_n_filters, attention_location_kernel_size):
32 | super(Attention, self).__init__()
33 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
34 | bias=False, w_init_gain='tanh')
35 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
36 | w_init_gain='tanh')
37 | self.v = LinearNorm(attention_dim, 1, bias=False)
38 | self.location_layer = LocationLayer(attention_location_n_filters,
39 | attention_location_kernel_size,
40 | attention_dim)
41 | self.score_mask_value = -float("inf")
42 |
43 | def get_alignment_energies(self, query, processed_memory,
44 | attention_weights_cat):
45 | """
46 | PARAMS
47 | ------
48 | query: decoder output (batch, n_mel_channels * n_frames_per_step)
49 | processed_memory: processed encoder outputs (B, T_in, attention_dim)
50 | attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
51 |
52 | RETURNS
53 | -------
54 | alignment (batch, max_time)
55 | """
56 |
57 | processed_query = self.query_layer(query.unsqueeze(1))
58 | processed_attention_weights = self.location_layer(attention_weights_cat)
59 | energies = self.v(torch.tanh(
60 | processed_query + processed_attention_weights + processed_memory))
61 |
62 | energies = energies.squeeze(-1)
63 | return energies
64 |
65 | def forward(self, attention_hidden_state, memory, processed_memory,
66 | attention_weights_cat, mask):
67 | """
68 | PARAMS
69 | ------
70 | attention_hidden_state: attention rnn last output
71 | memory: encoder outputs
72 | processed_memory: processed encoder outputs
73 | attention_weights_cat: previous and cumulative attention weights
74 | mask: binary mask for padded data
75 | """
76 | alignment = self.get_alignment_energies(
77 | attention_hidden_state, processed_memory, attention_weights_cat)
78 |
79 | if mask is not None:
80 | alignment.data.masked_fill_(mask, self.score_mask_value)
81 |
82 | attention_weights = F.softmax(alignment, dim=1)
83 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
84 | attention_context = attention_context.squeeze(1)
85 |
86 | return attention_context, attention_weights
87 |
88 |
89 | class Prenet(nn.Module):
90 | def __init__(self, in_dim, sizes):
91 | super(Prenet, self).__init__()
92 | in_sizes = [in_dim] + sizes[:-1]
93 | self.layers = nn.ModuleList(
94 | [LinearNorm(in_size, out_size, bias=False)
95 | for (in_size, out_size) in zip(in_sizes, sizes)])
96 |
97 | def forward(self, x):
98 | for linear in self.layers:
99 | x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
100 | return x
101 |
102 |
103 | class Postnet(nn.Module):
104 | """Postnet
105 | - Five 1-d convolutions with 512 channels and kernel size 5
106 | """
107 |
108 | def __init__(self, hparams):
109 | super(Postnet, self).__init__()
110 | self.convolutions = nn.ModuleList()
111 |
112 | self.convolutions.append(
113 | nn.Sequential(
114 | ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
115 | kernel_size=hparams.postnet_kernel_size, stride=1,
116 | padding=int((hparams.postnet_kernel_size - 1) / 2),
117 | dilation=1, w_init_gain='tanh'),
118 | nn.BatchNorm1d(hparams.postnet_embedding_dim))
119 | )
120 |
121 | for i in range(1, hparams.postnet_n_convolutions - 1):
122 | self.convolutions.append(
123 | nn.Sequential(
124 | ConvNorm(hparams.postnet_embedding_dim,
125 | hparams.postnet_embedding_dim,
126 | kernel_size=hparams.postnet_kernel_size, stride=1,
127 | padding=int((hparams.postnet_kernel_size - 1) / 2),
128 | dilation=1, w_init_gain='tanh'),
129 | nn.BatchNorm1d(hparams.postnet_embedding_dim))
130 | )
131 |
132 | self.convolutions.append(
133 | nn.Sequential(
134 | ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
135 | kernel_size=hparams.postnet_kernel_size, stride=1,
136 | padding=int((hparams.postnet_kernel_size - 1) / 2),
137 | dilation=1, w_init_gain='linear'),
138 | nn.BatchNorm1d(hparams.n_mel_channels))
139 | )
140 |
141 | def forward(self, x):
142 | for i in range(len(self.convolutions) - 1):
143 | x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
144 | x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
145 |
146 | return x
147 |
148 |
149 | class Encoder(nn.Module):
150 | """Encoder module:
151 | - Three 1-d convolution banks
152 | - Bidirectional LSTM
153 | """
154 | def __init__(self, hparams):
155 | super(Encoder, self).__init__()
156 |
157 | convolutions = []
158 | for _ in range(hparams.encoder_n_convolutions):
159 | conv_layer = nn.Sequential(
160 | ConvNorm(hparams.encoder_embedding_dim,
161 | hparams.encoder_embedding_dim,
162 | kernel_size=hparams.encoder_kernel_size, stride=1,
163 | padding=int((hparams.encoder_kernel_size - 1) / 2),
164 | dilation=1, w_init_gain='relu'),
165 | nn.BatchNorm1d(hparams.encoder_embedding_dim))
166 | convolutions.append(conv_layer)
167 | self.convolutions = nn.ModuleList(convolutions)
168 |
169 | self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
170 | int(hparams.encoder_embedding_dim / 2), 1,
171 | batch_first=True, bidirectional=True)
172 |
173 | def forward(self, x, input_lengths):
174 | for conv in self.convolutions:
175 | x = F.dropout(F.relu(conv(x)), 0.5, self.training)
176 |
177 | x = x.transpose(1, 2)
178 |
179 | # pytorch tensors are not reversible, hence the conversion
180 | input_lengths = input_lengths.cpu().numpy()
181 | x = nn.utils.rnn.pack_padded_sequence(
182 | x, input_lengths, batch_first=True)
183 |
184 | self.lstm.flatten_parameters()
185 | outputs, _ = self.lstm(x)
186 |
187 | outputs, _ = nn.utils.rnn.pad_packed_sequence(
188 | outputs, batch_first=True)
189 |
190 | return outputs
191 |
192 | def inference(self, x):
193 | for conv in self.convolutions:
194 | x = F.dropout(F.relu(conv(x)), 0.5, self.training)
195 |
196 | x = x.transpose(1, 2)
197 |
198 | self.lstm.flatten_parameters()
199 | outputs, _ = self.lstm(x)
200 |
201 | return outputs
202 |
203 |
204 | class Decoder(nn.Module):
205 | def __init__(self, hparams):
206 | super(Decoder, self).__init__()
207 | self.n_mel_channels = hparams.n_mel_channels
208 | self.n_frames_per_step = hparams.n_frames_per_step
209 | self.encoder_embedding_dim = hparams.encoder_embedding_dim
210 | self.attention_rnn_dim = hparams.attention_rnn_dim
211 | self.decoder_rnn_dim = hparams.decoder_rnn_dim
212 | self.prenet_dim = hparams.prenet_dim
213 | self.max_decoder_steps = hparams.max_decoder_steps
214 | self.gate_threshold = hparams.gate_threshold
215 | self.p_attention_dropout = hparams.p_attention_dropout
216 | self.p_decoder_dropout = hparams.p_decoder_dropout
217 |
218 | self.prenet = Prenet(
219 | hparams.n_mel_channels * hparams.n_frames_per_step,
220 | [hparams.prenet_dim, hparams.prenet_dim])
221 |
222 | self.attention_rnn = nn.LSTMCell(
223 | hparams.prenet_dim + hparams.encoder_embedding_dim,
224 | hparams.attention_rnn_dim)
225 |
226 | self.attention_layer = Attention(
227 | hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
228 | hparams.attention_dim, hparams.attention_location_n_filters,
229 | hparams.attention_location_kernel_size)
230 |
231 | self.decoder_rnn = nn.LSTMCell(
232 | hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
233 | hparams.decoder_rnn_dim, 1)
234 |
235 | self.linear_projection = LinearNorm(
236 | hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
237 | hparams.n_mel_channels * hparams.n_frames_per_step)
238 |
239 | self.gate_layer = LinearNorm(
240 | hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
241 | bias=True, w_init_gain='sigmoid')
242 |
243 | def get_go_frame(self, memory):
244 | """ Gets all zeros frames to use as first decoder input
245 | PARAMS
246 | ------
247 | memory: encoder outputs
248 |
249 | RETURNS
250 | -------
251 | decoder_input: all zeros frames
252 | """
253 | B = memory.size(0)
254 | decoder_input = Variable(memory.data.new(
255 | B, self.n_mel_channels * self.n_frames_per_step).zero_())
256 | return decoder_input
257 |
258 | def initialize_decoder_states(self, memory, mask):
259 | """ Initializes attention rnn states, decoder rnn states, attention
260 | weights, attention cumulative weights, attention context, stores memory
261 | and stores processed memory
262 | PARAMS
263 | ------
264 | memory: Encoder outputs
265 | mask: Mask for padded data if training, expects None for inference
266 | """
267 | B = memory.size(0)
268 | MAX_TIME = memory.size(1)
269 |
270 | self.attention_hidden = Variable(memory.data.new(
271 | B, self.attention_rnn_dim).zero_())
272 | self.attention_cell = Variable(memory.data.new(
273 | B, self.attention_rnn_dim).zero_())
274 |
275 | self.decoder_hidden = Variable(memory.data.new(
276 | B, self.decoder_rnn_dim).zero_())
277 | self.decoder_cell = Variable(memory.data.new(
278 | B, self.decoder_rnn_dim).zero_())
279 |
280 | self.attention_weights = Variable(memory.data.new(
281 | B, MAX_TIME).zero_())
282 | self.attention_weights_cum = Variable(memory.data.new(
283 | B, MAX_TIME).zero_())
284 | self.attention_context = Variable(memory.data.new(
285 | B, self.encoder_embedding_dim).zero_())
286 |
287 | self.memory = memory
288 | self.processed_memory = self.attention_layer.memory_layer(memory)
289 | self.mask = mask
290 |
291 | def parse_decoder_inputs(self, decoder_inputs):
292 | """ Prepares decoder inputs, i.e. mel outputs
293 | PARAMS
294 | ------
295 | decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
296 |
297 | RETURNS
298 | -------
299 | inputs: processed decoder inputs
300 |
301 | """
302 | # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
303 | decoder_inputs = decoder_inputs.transpose(1, 2)
304 | decoder_inputs = decoder_inputs.view(
305 | decoder_inputs.size(0),
306 | int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
307 | # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
308 | decoder_inputs = decoder_inputs.transpose(0, 1)
309 | return decoder_inputs
310 |
311 | def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
312 | """ Prepares decoder outputs for output
313 | PARAMS
314 | ------
315 | mel_outputs:
316 | gate_outputs: gate output energies
317 | alignments:
318 |
319 | RETURNS
320 | -------
321 | mel_outputs:
322 | gate_outputs: gate output energies
323 | alignments:
324 | """
325 | # (T_out, B) -> (B, T_out)
326 | alignments = torch.stack(alignments).transpose(0, 1)
327 | # (T_out, B) -> (B, T_out)
328 | gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
329 | gate_outputs = gate_outputs.contiguous()
330 | # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
331 | mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
332 | # decouple frames per step
333 | mel_outputs = mel_outputs.view(
334 | mel_outputs.size(0), -1, self.n_mel_channels)
335 | # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
336 | mel_outputs = mel_outputs.transpose(1, 2)
337 |
338 | return mel_outputs, gate_outputs, alignments
339 |
340 | def decode(self, decoder_input):
341 | """ Decoder step using stored states, attention and memory
342 | PARAMS
343 | ------
344 | decoder_input: previous mel output
345 |
346 | RETURNS
347 | -------
348 | mel_output:
349 | gate_output: gate output energies
350 | attention_weights:
351 | """
352 | cell_input = torch.cat((decoder_input, self.attention_context), -1)
353 | self.attention_hidden, self.attention_cell = self.attention_rnn(
354 | cell_input, (self.attention_hidden, self.attention_cell))
355 | self.attention_hidden = F.dropout(
356 | self.attention_hidden, self.p_attention_dropout, self.training)
357 |
358 | attention_weights_cat = torch.cat(
359 | (self.attention_weights.unsqueeze(1),
360 | self.attention_weights_cum.unsqueeze(1)), dim=1)
361 | self.attention_context, self.attention_weights = self.attention_layer(
362 | self.attention_hidden, self.memory, self.processed_memory,
363 | attention_weights_cat, self.mask)
364 |
365 | self.attention_weights_cum += self.attention_weights
366 | decoder_input = torch.cat(
367 | (self.attention_hidden, self.attention_context), -1)
368 | self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
369 | decoder_input, (self.decoder_hidden, self.decoder_cell))
370 | self.decoder_hidden = F.dropout(
371 | self.decoder_hidden, self.p_decoder_dropout, self.training)
372 |
373 | decoder_hidden_attention_context = torch.cat(
374 | (self.decoder_hidden, self.attention_context), dim=1)
375 | decoder_output = self.linear_projection(
376 | decoder_hidden_attention_context)
377 |
378 | gate_prediction = self.gate_layer(decoder_hidden_attention_context)
379 | return decoder_output, gate_prediction, self.attention_weights
380 |
381 | def forward(self, memory, decoder_inputs, memory_lengths):
382 | """ Decoder forward pass for training
383 | PARAMS
384 | ------
385 | memory: Encoder outputs
386 | decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
387 | memory_lengths: Encoder output lengths for attention masking.
388 |
389 | RETURNS
390 | -------
391 | mel_outputs: mel outputs from the decoder
392 | gate_outputs: gate outputs from the decoder
393 | alignments: sequence of attention weights from the decoder
394 | """
395 |
396 | decoder_input = self.get_go_frame(memory).unsqueeze(0)
397 | decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
398 | decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
399 | decoder_inputs = self.prenet(decoder_inputs)
400 |
401 | self.initialize_decoder_states(
402 | memory, mask=~get_mask_from_lengths(memory_lengths))
403 |
404 | mel_outputs, gate_outputs, alignments = [], [], []
405 | while len(mel_outputs) < decoder_inputs.size(0) - 1:
406 | decoder_input = decoder_inputs[len(mel_outputs)]
407 | mel_output, gate_output, attention_weights = self.decode(
408 | decoder_input)
409 | mel_outputs += [mel_output.squeeze(1)]
410 | gate_outputs += [gate_output.squeeze(1)]
411 | alignments += [attention_weights]
412 |
413 | mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
414 | mel_outputs, gate_outputs, alignments)
415 |
416 | return mel_outputs, gate_outputs, alignments
417 |
418 | def inference(self, memory):
419 | """ Decoder inference
420 | PARAMS
421 | ------
422 | memory: Encoder outputs
423 |
424 | RETURNS
425 | -------
426 | mel_outputs: mel outputs from the decoder
427 | gate_outputs: gate outputs from the decoder
428 | alignments: sequence of attention weights from the decoder
429 | """
430 | decoder_input = self.get_go_frame(memory)
431 |
432 | self.initialize_decoder_states(memory, mask=None)
433 |
434 | mel_outputs, gate_outputs, alignments = [], [], []
435 | while True:
436 | decoder_input = self.prenet(decoder_input)
437 | mel_output, gate_output, alignment = self.decode(decoder_input)
438 |
439 | mel_outputs += [mel_output.squeeze(1)]
440 | gate_outputs += [gate_output]
441 | alignments += [alignment]
442 |
443 | if torch.sigmoid(gate_output.data) > self.gate_threshold:
444 | break
445 | elif len(mel_outputs) == self.max_decoder_steps:
446 | print("Warning! Reached max decoder steps")
447 | break
448 |
449 | decoder_input = mel_output
450 |
451 | mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
452 | mel_outputs, gate_outputs, alignments)
453 |
454 | return mel_outputs, gate_outputs, alignments
455 |
456 |
457 | class Tacotron2(nn.Module):
458 | def __init__(self, hparams):
459 | super(Tacotron2, self).__init__()
460 | self.mask_padding = hparams.mask_padding
461 | self.fp16_run = hparams.fp16_run
462 | self.n_mel_channels = hparams.n_mel_channels
463 | self.n_frames_per_step = hparams.n_frames_per_step
464 | self.embedding = nn.Embedding(
465 | hparams.n_symbols, hparams.symbols_embedding_dim)
466 | std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
467 | val = sqrt(3.0) * std # uniform bounds for std
468 | self.embedding.weight.data.uniform_(-val, val)
469 | self.encoder = Encoder(hparams)
470 | self.decoder = Decoder(hparams)
471 | self.postnet = Postnet(hparams)
472 |
473 | def parse_batch(self, batch):
474 | text_padded, input_lengths, mel_padded, gate_padded, \
475 | output_lengths = batch
476 | text_padded = to_gpu(text_padded).long()
477 | input_lengths = to_gpu(input_lengths).long()
478 | max_len = torch.max(input_lengths.data).item()
479 | mel_padded = to_gpu(mel_padded).float()
480 | gate_padded = to_gpu(gate_padded).float()
481 | output_lengths = to_gpu(output_lengths).long()
482 |
483 | return (
484 | (text_padded, input_lengths, mel_padded, max_len, output_lengths),
485 | (mel_padded, gate_padded))
486 |
487 | def parse_output(self, outputs, output_lengths=None):
488 | if self.mask_padding and output_lengths is not None:
489 | mask = ~get_mask_from_lengths(output_lengths)
490 | mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
491 | mask = mask.permute(1, 0, 2)
492 |
493 | outputs[0].data.masked_fill_(mask, 0.0)
494 | outputs[1].data.masked_fill_(mask, 0.0)
495 | outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies
496 |
497 | return outputs
498 |
499 | def forward(self, inputs):
500 | text_inputs, text_lengths, mels, max_len, output_lengths = inputs
501 | text_lengths, output_lengths = text_lengths.data, output_lengths.data
502 |
503 | embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
504 |
505 | encoder_outputs = self.encoder(embedded_inputs, text_lengths)
506 |
507 | mel_outputs, gate_outputs, alignments = self.decoder(
508 | encoder_outputs, mels, memory_lengths=text_lengths)
509 |
510 | mel_outputs_postnet = self.postnet(mel_outputs)
511 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet
512 |
513 | return self.parse_output(
514 | [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
515 | output_lengths)
516 |
517 | def inference(self, inputs):
518 | embedded_inputs = self.embedding(inputs).transpose(1, 2)
519 | encoder_outputs = self.encoder.inference(embedded_inputs)
520 | mel_outputs, gate_outputs, alignments = self.decoder.inference(
521 | encoder_outputs)
522 |
523 | mel_outputs_postnet = self.postnet(mel_outputs)
524 | mel_outputs_postnet = mel_outputs + mel_outputs_postnet
525 |
526 | outputs = self.parse_output(
527 | [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
528 |
529 | return outputs
530 |
--------------------------------------------------------------------------------
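Putting model.py together with the rest of the repo, a minimal inference sketch — the checkpoint path and cleaner name are illustrative, and create_hparams / text_to_sequence are assumed to follow the usual Tacotron 2 layout of hparams.py and the text package:

import torch
import numpy as np
from hparams import create_hparams
from model import Tacotron2
from text import text_to_sequence

hparams = create_hparams()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Tacotron2(hparams).to(device).eval()
# Checkpoint layout matches save_checkpoint() in train.py; the path is illustrative.
checkpoint = torch.load('outdir/checkpoint_10000', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])

# Text -> symbol ids -> (1, T_in) LongTensor on the right device.
sequence = np.array(text_to_sequence('Hello world.', ['english_cleaners']))[None, :]
sequence = torch.from_numpy(sequence).long().to(device)

with torch.no_grad():
    mel, mel_postnet, gate, alignment = model.inference(sequence)
# mel_postnet has shape (1, n_mel_channels, T_out); it is the input passed to the WaveGlow vocoder.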
/utils_hparam.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Tensor2Tensor Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Forked with minor changes from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py pylint: disable=line-too-long
17 | """Hyperparameter values."""
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import json
23 | import numbers
24 | import re
25 | import six
26 |
27 | # Define the regular expression for parsing a single clause of the input
28 | # (delimited by commas). A legal clause looks like:
29 | #