├── datasets
│   ├── __init__.py
│   ├── ljspeech.py
│   ├── blizzard.py
│   ├── kss.py
│   ├── bible.py
│   └── datafeeder.py
├── utils
│   ├── NanumBarunGothic.ttf
│   ├── infolog.py
│   ├── plot.py
│   ├── audio.py
│   └── __init__.py
├── DISCLAIMER
├── models
│   ├── __init__.py
│   ├── helpers.py
│   ├── modules.py
│   ├── tacotron.py
│   └── rnn_wrappers.py
├── LICENSE
├── text
│   ├── cmudict.py
│   ├── symbols.py
│   ├── cleaners.py
│   ├── numbers.py
│   ├── __init__.py
│   ├── kor_dic.py
│   └── korean.py
├── audio
│   ├── get_duration.py
│   ├── silence.py
│   ├── __init__.py
│   └── google_speech.py
├── TRAINING_DATA.md
├── requirements.txt
├── download.py
├── hparams.py
├── recognition
│   ├── google.py
│   └── alignment.py
├── LJSpeech-1.1
│   └── README
├── eval.py
├── README.md
├── train.py
└── synthesizer.py
/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/NanumBarunGothic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSAIL-SKKU/Multispeaker/HEAD/utils/NanumBarunGothic.ttf -------------------------------------------------------------------------------- /DISCLAIMER: -------------------------------------------------------------------------------- 1 | This is not an official [DEVSISTERS](http://devsisters.com/) product, and DEVSISTERS is not responsible for misuse or for any damage that you may cause. You agree that you use this software at your own risk. 2 | 3 | 이것은 [데브시스터즈](http://devsisters.com/)의 공식적인 제품이 아닙니다. [데브시스터즈](http://devsisters.com)는 이 코드를 잘못 사용했을 시 발생한 문제나 이슈에 대한 책임을 지지 않으며 이 소프트웨어의 사용은 사용자 자신에게 전적으로 책임이 있습니다. 4 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from .tacotron import Tacotron 4 | 5 | 6 | def create_model(name, hparams): 7 | if name == 'tacotron': 8 | return Tacotron(hparams) 9 | else: 10 | raise Exception('Unknown model: ' + name) 11 | 12 | 13 | def get_most_recent_checkpoint(checkpoint_dir): 14 | checkpoint_paths = glob("{}/*.ckpt-*.data-*".format(checkpoint_dir)) 15 | idxes = [int(os.path.basename(path).split('-')[1].split('.')[0]) for path in checkpoint_paths] 16 | 17 | max_idx = max(idxes) 18 | latest_checkpoint = os.path.join(checkpoint_dir, "model.ckpt-{}".format(max_idx)) 19 | 20 | #latest_checkpoint=checkpoint_paths[0] 21 | print(" [*] Found latest checkpoint: {}".format(latest_checkpoint)) 22 | return latest_checkpoint 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software.
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /utils/infolog.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | from datetime import datetime 3 | import json 4 | from threading import Thread 5 | from urllib.request import Request, urlopen 6 | 7 | 8 | _format = '%Y-%m-%d %H:%M:%S.%f' 9 | _file = None 10 | _run_name = None 11 | _slack_url = None 12 | 13 | 14 | def init(filename, run_name, slack_url=None): 15 | global _file, _run_name, _slack_url 16 | _close_logfile() 17 | _file = open(filename, 'a', encoding="utf-8") 18 | _file.write('\n-----------------------------------------------------------------\n') 19 | _file.write('Starting new training run\n') 20 | _file.write('-----------------------------------------------------------------\n') 21 | _run_name = run_name 22 | _slack_url = slack_url 23 | 24 | 25 | def log(msg, slack=False): 26 | print(msg) 27 | if _file is not None: 28 | _file.write('[%s] %s\n' % (datetime.now().strftime(_format)[:-3], msg)) 29 | if slack and _slack_url is not None: 30 | Thread(target=_send_slack, args=(msg,)).start() 31 | 32 | 33 | def _close_logfile(): 34 | global _file 35 | if _file is not None: 36 | _file.close() 37 | _file = None 38 | 39 | 40 | def _send_slack(msg): 41 | req = Request(_slack_url) 42 | req.add_header('Content-Type', 'application/json') 43 | urlopen(req, json.dumps({ 44 | 'username': 'tacotron', 45 | 'icon_emoji': ':taco:', 46 | 'text': '*%s*: %s' % (_run_name, msg) 47 | }).encode()) 48 | 49 | 50 | atexit.register(_close_logfile) 51 | -------------------------------------------------------------------------------- /utils/plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib 3 | from jamo import h2j, j2hcj 4 | 5 | matplotlib.use('Agg') 6 | matplotlib.rc('font', family="NanumBarunGothic") 7 | import matplotlib.pyplot as plt 8 | 9 | from text import PAD, EOS 10 | from utils import add_postfix 11 | from text.korean import normalize 12 | 13 | def plot(alignment, info, text, isKorean=True): 14 | char_len, audio_len = alignment.shape # 145, 200 15 | 16 | fig, ax = plt.subplots(figsize=(char_len/5, 5)) 17 | im = ax.imshow( 18 | alignment.T, 19 | aspect='auto', 20 | origin='lower', 21 | interpolation='none') 22 | 23 | xlabel = 'Encoder timestep' 24 | ylabel = 'Decoder timestep' 25 | 26 | if info is not None: 27 | xlabel += '\n{}'.format(info) 28 | 29 | plt.xlabel(xlabel) 30 | plt.ylabel(ylabel) 31 | 32 | if text: 33 | if isKorean: 34 | jamo_text = j2hcj(h2j(normalize(text))) 35 | else: 36 | jamo_text=text 37 | pad = [PAD] * (char_len - len(jamo_text) - 1) 38 | 39 | plt.xticks(range(char_len), 40 | [tok for tok in jamo_text] + [EOS] + pad) 41 | 42 | if text is not None: 43 | while True: 44 | if text[-1] in [EOS, PAD]: 45 | text = text[:-1] 46 | else: 47 | break 48 | plt.title(text) 49 | 50 | plt.tight_layout() 51 | 52 | def plot_alignment( 53 | alignment, path, info=None, text=None, isKorean=True): 
54 | 55 | if text: 56 | tmp_alignment = alignment[:len(h2j(text)) + 2] 57 | 58 | plot(tmp_alignment, info, text, isKorean) 59 | plt.savefig(path, format='png') 60 | else: 61 | plot(alignment, info, text, isKorean) 62 | plt.savefig(path, format='png') 63 | 64 | print(" [*] Plot saved: {}".format(path)) -------------------------------------------------------------------------------- /text/cmudict.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | valid_symbols = [ 5 | 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', 6 | 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', 7 | 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', 8 | 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 9 | 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 10 | 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 11 | 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' 12 | ] 13 | 14 | _valid_symbol_set = set(valid_symbols) 15 | 16 | 17 | class CMUDict: 18 | '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' 19 | def __init__(self, file_or_path, keep_ambiguous=True): 20 | if isinstance(file_or_path, str): 21 | with open(file_or_path, encoding='latin-1') as f: 22 | entries = _parse_cmudict(f) 23 | else: 24 | entries = _parse_cmudict(file_or_path) 25 | if not keep_ambiguous: 26 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 27 | self._entries = entries 28 | 29 | 30 | def __len__(self): 31 | return len(self._entries) 32 | 33 | 34 | def lookup(self, word): 35 | '''Returns list of ARPAbet pronunciations of the given word.''' 36 | return self._entries.get(word.upper()) 37 | 38 | 39 | 40 | _alt_re = re.compile(r'\([0-9]+\)') 41 | 42 | 43 | def _parse_cmudict(file): 44 | cmudict = {} 45 | for line in file: 46 | if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): 47 | parts = line.split(' ') 48 | word = re.sub(_alt_re, '', parts[0]) 49 | pronunciation = _get_pronunciation(parts[1]) 50 | if pronunciation: 51 | if word in cmudict: 52 | cmudict[word].append(pronunciation) 53 | else: 54 | cmudict[word] = [pronunciation] 55 | return cmudict 56 | 57 | 58 | def _get_pronunciation(s): 59 | parts = s.strip().split(' ') 60 | for part in parts: 61 | if part not in _valid_symbol_set: 62 | return None 63 | return ' '.join(parts) 64 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | # ''' 2 | # Defines the set of symbols used in text input to the model. 3 | # 4 | # The default is a set of ASCII characters that works well for English or text that has been run 5 | # through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. 6 | # ''' 7 | # from text import cmudict 8 | # 9 | # _pad = '_' 10 | # _eos = '~' 11 | # _characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' 12 | # 13 | # # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 14 | # _arpabet = ['@' + s for s in cmudict.valid_symbols] 15 | # 16 | # # Export all symbols: 17 | # symbols = [_pad, _eos] + list(_characters) + _arpabet 18 | 19 | # coding: utf-8 20 | ''' 21 | Defines the set of symbols used in text input to the model. 
22 | 23 | The default is a set of ASCII characters that works well for English or text that has been run 24 | through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. 25 | ''' 26 | 27 | from jamo import h2j, j2h 28 | from jamo.jamo import _jamo_char_to_hcj 29 | 30 | from .korean import ALL_SYMBOLS, PAD, EOS 31 | 32 | # For english 33 | en_symbols = PAD + EOS + 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' # <-For deployment(Because korean ALL_SYMBOLS follow this convention) 34 | 35 | symbols = ALL_SYMBOLS # for korean 36 | 37 | """ 38 | 초성과 종성은 같아보이지만, 다른 character이다. 39 | '_~ᄀᄁᄂᄃᄄᄅᄆᄇᄈᄉᄊᄋᄌᄍᄎᄏᄐᄑ하ᅢᅣᅤᅥᅦᅧᅨᅩᅪᅫᅬᅭᅮᅯᅰᅱᅲᅳᅴᅵᆨᆩᆪᆫᆬᆭᆮᆯᆰᆱᆲᆳᆴᆵᆶᆷᆸᆹᆺᆻᆼᆽᆾᆿᇀᇁᇂ!'(),-.:;? ' 40 | '_': 0, '~': 1, 'ᄀ': 2, 'ᄁ': 3, 'ᄂ': 4, 'ᄃ': 5, 'ᄄ': 6, 'ᄅ': 7, 'ᄆ': 8, 'ᄇ': 9, 'ᄈ': 10, 41 | 'ᄉ': 11, 'ᄊ': 12, 'ᄋ': 13, 'ᄌ': 14, 'ᄍ': 15, 'ᄎ': 16, 'ᄏ': 17, 'ᄐ': 18, 'ᄑ': 19, 'ᄒ': 20, 42 | 'ᅡ': 21, 'ᅢ': 22, 'ᅣ': 23, 'ᅤ': 24, 'ᅥ': 25, 'ᅦ': 26, 'ᅧ': 27, 'ᅨ': 28, 'ᅩ': 29, 'ᅪ': 30, 43 | 'ᅫ': 31, 'ᅬ': 32, 'ᅭ': 33, 'ᅮ': 34, 'ᅯ': 35, 'ᅰ': 36, 'ᅱ': 37, 'ᅲ': 38, 'ᅳ': 39, 'ᅴ': 40, 44 | 'ᅵ': 41, 'ᆨ': 42, 'ᆩ': 43, 'ᆪ': 44, 'ᆫ': 45, 'ᆬ': 46, 'ᆭ': 47, 'ᆮ': 48, 'ᆯ': 49, 'ᆰ': 50, 45 | 'ᆱ': 51, 'ᆲ': 52, 'ᆳ': 53, 'ᆴ': 54, 'ᆵ': 55, 'ᆶ': 56, 'ᆷ': 57, 'ᆸ': 58, 'ᆹ': 59, 'ᆺ': 60, 46 | 'ᆻ': 61, 'ᆼ': 62, 'ᆽ': 63, 'ᆾ': 64, 'ᆿ': 65, 'ᇀ': 66, 'ᇁ': 67, 'ᇂ': 68, '!': 69, "'": 70, 47 | '(': 71, ')': 72, ',': 73, '-': 74, '.': 75, ':': 76, ';': 77, '?': 78, ' ': 79 48 | """ -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # Code based on https://github.com/keithito/tacotron/blob/master/text/cleaners.py 4 | 5 | import re 6 | from .korean import tokenize as ko_tokenize 7 | 8 | # # Added to support LJ_speech 9 | from unidecode import unidecode 10 | from .en_numbers import normalize_numbers as en_normalize_numbers 11 | 12 | # Regular expression matching whitespace: 13 | _whitespace_re = re.compile(r'\s+') 14 | 15 | 16 | def korean_cleaners(text): 17 | '''Pipeline for Korean text, including number and abbreviation expansion.''' 18 | text = ko_tokenize(text) # '존경하는' --> ['ᄌ', 'ᅩ', 'ᆫ', 'ᄀ', 'ᅧ', 'ᆼ', 'ᄒ', 'ᅡ', 'ᄂ', 'ᅳ', 'ᆫ', '~'] 19 | return text 20 | 21 | 22 | # List of (regular expression, replacement) pairs for abbreviations: 23 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 24 | ('mrs', 'misess'), 25 | ('mr', 'mister'), 26 | ('dr', 'doctor'), 27 | ('st', 'saint'), 28 | ('co', 'company'), 29 | ('jr', 'junior'), 30 | ('maj', 'major'), 31 | ('gen', 'general'), 32 | ('drs', 'doctors'), 33 | ('rev', 'reverend'), 34 | ('lt', 'lieutenant'), 35 | ('hon', 'honorable'), 36 | ('sgt', 'sergeant'), 37 | ('capt', 'captain'), 38 | ('esq', 'esquire'), 39 | ('ltd', 'limited'), 40 | ('col', 'colonel'), 41 | ('ft', 'fort'), 42 | ]] 43 | 44 | 45 | def expand_abbreviations(text): 46 | for regex, replacement in _abbreviations: 47 | text = re.sub(regex, replacement, text) 48 | return text 49 | 50 | 51 | def expand_numbers(text): 52 | return en_normalize_numbers(text) 53 | 54 | 55 | def lowercase(text): 56 | return text.lower() 57 | 58 | 59 | def collapse_whitespace(text): 60 | return re.sub(_whitespace_re, ' ', text) 61 | 62 | 63 | def convert_to_ascii(text): 64 | return unidecode(text) 65 | 66 | 67 | def basic_cleaners(text): 68 | text = lowercase(text) 69 | text = collapse_whitespace(text) 70 | return text 71 | 72 | 73 | def transliteration_cleaners(text): 74 | # text = convert_to_ascii(text) 75 | text = lowercase(text) 76 | text = collapse_whitespace(text) 77 | return text 78 | 79 | 80 | def english_cleaners(text): 81 | text = convert_to_ascii(text) 82 | text = lowercase(text) 83 | text = expand_numbers(text) 84 | text = expand_abbreviations(text) 85 | text = collapse_whitespace(text) 86 | return text 87 | -------------------------------------------------------------------------------- /audio/get_duration.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | from glob import glob 4 | from tqdm import tqdm 5 | from tinytag import TinyTag 6 | from collections import defaultdict 7 | from multiprocessing.dummy import Pool 8 | 9 | from utils import load_json 10 | 11 | def second_to_hour(sec): 12 | return str(datetime.timedelta(seconds=int(sec))) 13 | 14 | def get_duration(path): 15 | filename = os.path.basename(path) 16 | candidates = filename.split('.')[0].split('_') 17 | dataset = candidates[0] 18 | 19 | if not os.path.exists(path): 20 | print(" [!] {} not found".format(path)) 21 | return dataset, 0 22 | 23 | if True: # tinytag 24 | tag = TinyTag.get(path) 25 | duration = tag.duration 26 | else: # librosa 27 | y, sr = librosa.load(path) 28 | duration = librosa.get_duration(y=y, sr=sr) 29 | 30 | return dataset, duration 31 | 32 | def get_durations(paths, print_detail=True): 33 | duration_all = 0 34 | duration_book = defaultdict(list) 35 | 36 | pool = Pool() 37 | iterator = pool.imap_unordered(get_duration, paths) 38 | for dataset, duration in tqdm(iterator, total=len(paths)): 39 | duration_all += duration 40 | duration_book[dataset].append(duration) 41 | 42 | total_count = 0 43 | for book, duration in duration_book.items(): 44 | if book: 45 | time = second_to_hour(sum(duration)) 46 | file_count = len(duration) 47 | total_count += file_count 48 | 49 | if print_detail: 50 | print(" [*] Duration of {}: {} (file #: {})". \ 51 | format(book, time, file_count)) 52 | 53 | print(" [*] Total Duration : {} (file #: {})". 
\ 54 | format(second_to_hour(duration_all), total_count)) 55 | print() 56 | return duration_all 57 | 58 | 59 | if __name__ == '__main__': 60 | import argparse 61 | 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--audio-pattern', default=None) # datasets/krbook/audio/*.wav 64 | parser.add_argument('--data-path', default=None) # datasets/jtbc/alignment.json 65 | config, unparsed = parser.parse_known_args() 66 | 67 | if config.audio_pattern is not None: 68 | duration = get_durations(glob(config.audio_pattern)) 69 | elif config.data_path is not None: 70 | paths = load_json(config.data_path).keys() 71 | duration = get_durations(paths) 72 | -------------------------------------------------------------------------------- /TRAINING_DATA.md: -------------------------------------------------------------------------------- 1 | # Training Data 2 | 3 | 4 | This repo supports the following speech datasets: 5 | * [LJ Speech](https://keithito.com/LJ-Speech-Dataset/) (Public Domain) 6 | * [Blizzard
2012](http://www.cstr.ed.ac.uk/projects/blizzard/2012/phase_one) (Creative Commons Attribution Share-Alike) 7 | 8 | You can use any other dataset if you write a preprocessor for it. 9 | 10 | 11 | ### Writing a Preprocessor 12 | 13 | Each training example consists of: 14 | 1. The text that was spoken 15 | 2. A mel-scale spectrogram of the audio 16 | 3. A linear-scale spectrogram of the audio 17 | 18 | The preprocessor is responsible for generating these. See [ljspeech.py](datasets/ljspeech.py) for a 19 | commented example. 20 | 21 | For each training example, a preprocessor should: 22 | 23 | 1. Load the audio file: 24 | ```python 25 | wav = audio.load_wav(wav_path) 26 | ``` 27 | 28 | 2. Compute linear-scale and mel-scale spectrograms (float32 numpy arrays): 29 | ```python 30 | spectrogram = audio.spectrogram(wav).astype(np.float32) 31 | mel_spectrogram = audio.melspectrogram(wav).astype(np.float32) 32 | ``` 33 | 34 | 3. Save the spectrograms to disk: 35 | ```python 36 | np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False) 37 | np.save(os.path.join(out_dir, mel_spectrogram_filename), mel_spectrogram.T, allow_pickle=False) 38 | ``` 39 | Note that the transpose of the matrix returned by `audio.spectrogram` is saved so that it's 40 | in time-major format. 41 | 42 | 4. Generate a tuple `(spectrogram_filename, mel_spectrogram_filename, n_frames, text)` to 43 | write to train.txt. n_frames is just the length of the time axis of the spectrogram. 44 | 45 | 46 | After you've written your preprocessor, you can add it to [preprocess.py](preprocess.py) by 47 | following the example of the other preprocessors in that file. 48 | 49 | 50 | ### Non-English Data 51 | 52 | If your training data is in a language other than English, you will probably want to change the 53 | text cleaners by setting the `cleaners` hyperparameter. 54 | 55 | * If your text is in a Latin script or can be transliterated to ASCII using the 56 | [Unidecode](https://pypi.python.org/pypi/Unidecode) library, you can use the transliteration 57 | cleaners by setting the hyperparameter `cleaners=transliteration_cleaners`. 58 | 59 | * If you don't want to transliterate, you can define a custom character set. 60 | This allows you to train directly on the character set used in your data. 61 | 62 | To do so, edit [symbols.py](text/symbols.py) and change the `_characters` variable to be a 63 | string containing the UTF-8 characters in your data. Then set the hyperparameter `cleaners=basic_cleaners`. 
64 | 65 | * If you're not sure which option to use, you can evaluate the transliteration cleaners like this: 66 | 67 | ```python 68 | from text import cleaners 69 | cleaners.transliteration_cleaners('Здравствуйте') # Replace with the text you want to try 70 | ``` 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appnope==0.1.0 2 | audioread==2.1.5 3 | beautifulsoup4==4.6.0 4 | bleach==1.5.0 5 | bs4==0.0.1 6 | cachetools==2.0.1 7 | certifi==2017.7.27.1 8 | chardet==3.0.4 9 | click==6.7 10 | cycler==0.10.0 11 | decorator==4.1.2 12 | dill==0.2.7.1 13 | ffprobe==0.5 14 | Flask==0.12.2 15 | Flask-Cors==3.0.3 16 | future==0.16.0 17 | gapic-google-cloud-datastore-v1==0.15.3 18 | gapic-google-cloud-error-reporting-v1beta1==0.15.3 19 | gapic-google-cloud-logging-v2==0.91.3 20 | gapic-google-cloud-pubsub-v1==0.15.4 21 | gapic-google-cloud-spanner-admin-database-v1==0.15.3 22 | gapic-google-cloud-spanner-admin-instance-v1==0.15.3 23 | gapic-google-cloud-spanner-v1==0.15.3 24 | google-auth==1.1.1 25 | google-cloud==0.27.0 26 | google-cloud-bigquery==0.26.0 27 | google-cloud-bigtable==0.26.0 28 | google-cloud-core==0.26.0 29 | google-cloud-datastore==1.2.0 30 | google-cloud-dns==0.26.0 31 | google-cloud-error-reporting==0.26.0 32 | google-cloud-language==0.27.0 33 | google-cloud-logging==1.2.0 34 | google-cloud-monitoring==0.26.0 35 | google-cloud-pubsub==0.27.0 36 | google-cloud-resource-manager==0.26.0 37 | google-cloud-runtimeconfig==0.26.0 38 | google-cloud-spanner==0.26.0 39 | google-cloud-speech==0.28.0 40 | google-cloud-storage==1.3.2 41 | google-cloud-translate==1.1.0 42 | google-cloud-videointelligence==0.25.0 43 | google-cloud-vision==0.26.0 44 | google-api-core==1.1.2 45 | google-resumable-media==0.3.0 46 | googleapis-common-protos==1.5.3 47 | grpc-google-iam-v1==0.11.4 48 | grpcio==1.6.3 49 | html5lib==0.9999999 50 | httplib2==0.10.3 51 | idna==2.6 52 | ipdb==0.10.3 53 | ipython==6.2.1 54 | ipython-genutils==0.2.0 55 | iso8601==0.1.12 56 | itsdangerous==0.24 57 | jamo==0.4.1 58 | jedi==0.11.0 59 | Jinja2==2.9.6 60 | joblib==0.11 61 | librosa==0.5.1 62 | llvmlite==0.20.0 63 | m3u8==0.3.3 64 | Markdown==2.6.9 65 | MarkupSafe==1.0 66 | matplotlib==2.1.0 67 | monotonic==1.3 68 | nltk==3.2.5 69 | numba==0.35.0 70 | numpy==1.13.3 71 | oauth2client==3.0.0 72 | parso==0.1.0 73 | pexpect==4.2.1 74 | pickleshare==0.7.4 75 | ply==3.8 76 | prompt-toolkit==1.0.15 77 | proto-google-cloud-datastore-v1==0.90.4 78 | proto-google-cloud-error-reporting-v1beta1==0.15.3 79 | proto-google-cloud-logging-v2==0.91.3 80 | proto-google-cloud-pubsub-v1==0.15.4 81 | proto-google-cloud-spanner-admin-database-v1==0.15.3 82 | proto-google-cloud-spanner-admin-instance-v1==0.15.3 83 | proto-google-cloud-spanner-v1==0.15.3 84 | protobuf==3.4.0 85 | ptyprocess==0.5.2 86 | pyasn1==0.3.7 87 | pyasn1-modules==0.1.5 88 | pydub==0.20.0 89 | Pygments==2.2.0 90 | pyparsing==2.2.0 91 | python-dateutil==2.6.1 92 | pytz==2017.2 93 | requests==2.18.4 94 | resampy==0.2.0 95 | rsa==3.4.2 96 | scikit-learn==0.19.0 97 | scipy==0.19.1 98 | simplegeneric==0.8.1 99 | six==1.11.0 100 | tenacity==4.4.0 101 | #tensorflow-gpu==1.3.0 102 | #tensorflow-tensorboard==0.1.8 103 | tinytag==0.18.0 104 | tqdm==4.19.2 105 | traitlets==4.3.2 106 | urllib3==1.22 107 | wcwidth==0.1.7 108 | Werkzeug==0.12.2 109 | youtube-dl==2017.10.15.1 110 | unidecode==1.0.22 111 | inflect==0.2.5 
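A quick way to see how the cleaners and symbol tables described in TRAINING_DATA.md fit together is to round-trip a sentence through the `text` package. The sketch below is an illustrative addition, not a file from this repo; it assumes the repo root is on `PYTHONPATH`, a TensorFlow 1.x environment (hparams.py uses `tf.contrib`), and the default `cleaners='english_cleaners'` from hparams.py.

```python
# Hedged sketch: round-trip a sentence through the repo's text frontend.
from text import text_to_sequence, sequence_to_text

line = 'Dr. Smith paid $16 on March 1st.'

# english_cleaners lower-cases the text and expands numbers and abbreviations,
# then each remaining character is mapped to an ID over en_symbols and an
# EOS token is appended.
ids = text_to_sequence(line)
print(ids)  # numpy int32 array of symbol IDs

# Reverse mapping; skip_eos_and_pad drops the trailing EOS marker.
print(sequence_to_text(ids, skip_eos_and_pad=True))
# -> roughly 'doctor smith paid sixteen dollars on march first.'
```

With `cleaners='korean_cleaners'` the same call goes through the jamo tokenizer in text/korean.py instead, and `text_to_sequence(line, as_token=True)` recombines the jamo back into Hangul via `jamo_to_korean`.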
-------------------------------------------------------------------------------- /datasets/ljspeech.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ProcessPoolExecutor 2 | from functools import partial 3 | import numpy as np 4 | import os 5 | from util import audio 6 | 7 | 8 | def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x): 9 | '''Preprocesses the LJ Speech dataset from a given input path into a given output directory. 10 | 11 | Args: 12 | in_dir: The directory where you have downloaded the LJ Speech dataset 13 | out_dir: The directory to write the output into 14 | num_workers: Optional number of worker processes to parallelize across 15 | tqdm: You can optionally pass tqdm to get a nice progress bar 16 | 17 | Returns: 18 | A list of tuples describing the training examples. This should be written to train.txt 19 | ''' 20 | 21 | # We use ProcessPoolExecutor to parallelize across processes. This is just an optimization and you 22 | # can omit it and just call _process_utterance on each input if you want. 23 | executor = ProcessPoolExecutor(max_workers=num_workers) 24 | futures = [] 25 | index = 1 26 | with open(os.path.join(in_dir, 'metadata.csv'), encoding='utf-8') as f: 27 | for line in f: 28 | parts = line.strip().split('|') 29 | wav_path = os.path.join(in_dir, 'wavs', '%s.wav' % parts[0]) 30 | text = parts[2] 31 | futures.append(executor.submit(partial(_process_utterance, out_dir, index, wav_path, text))) 32 | index += 1 33 | return [future.result() for future in tqdm(futures)] 34 | 35 | 36 | def _process_utterance(out_dir, index, wav_path, text): 37 | '''Preprocesses a single utterance audio/text pair. 38 | 39 | This writes the mel and linear scale spectrograms to disk and returns a tuple to write 40 | to the train.txt file. 41 | 42 | Args: 43 | out_dir: The directory to write the spectrograms into 44 | index: The numeric index to use in the spectrogram filenames. 
45 | wav_path: Path to the audio file containing the speech input 46 | text: The text spoken in the input audio file 47 | 48 | Returns: 49 | A (spectrogram_filename, mel_filename, n_frames, text) tuple to write to train.txt 50 | ''' 51 | 52 | # Load the audio to a numpy array: 53 | wav = audio.load_wav(wav_path) 54 | 55 | # Compute the linear-scale spectrogram from the wav: 56 | spectrogram = audio.spectrogram(wav).astype(np.float32) 57 | n_frames = spectrogram.shape[1] 58 | 59 | # Compute a mel-scale spectrogram from the wav: 60 | mel_spectrogram = audio.melspectrogram(wav).astype(np.float32) 61 | 62 | # Write the spectrograms to disk: 63 | spectrogram_filename = 'ljspeech-spec-%05d.npy' % index 64 | mel_filename = 'ljspeech-mel-%05d.npy' % index 65 | np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False) 66 | np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False) 67 | 68 | # Return a tuple describing this training example: 69 | return (spectrogram_filename, mel_filename, n_frames, text) 70 | -------------------------------------------------------------------------------- /datasets/blizzard.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ProcessPoolExecutor 2 | from functools import partial 3 | import numpy as np 4 | import os 5 | from hparams import hparams 6 | from util import audio 7 | 8 | _max_out_length = 700 9 | _end_buffer = 0.05 10 | _min_confidence = 90 11 | 12 | # Note: "A Tramp Abroad" & "The Man That Corrupted Hadleyburg" are higher quality than the others. 13 | books = [ 14 | 'ATrampAbroad', 15 | 'TheManThatCorruptedHadleyburg', 16 | # 'LifeOnTheMississippi', 17 | # 'TheAdventuresOfTomSawyer', 18 | ] 19 | 20 | 21 | def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x): 22 | executor = ProcessPoolExecutor(max_workers=num_workers) 23 | futures = [] 24 | index = 1 25 | for book in books: 26 | with open(os.path.join(in_dir, book, 'sentence_index.txt')) as f: 27 | for line in f: 28 | parts = line.strip().split('\t') 29 | if line[0] is not '#' and len(parts) == 8 and float(parts[3]) > _min_confidence: 30 | wav_path = os.path.join(in_dir, book, 'wav', '%s.wav' % parts[0]) 31 | labels_path = os.path.join(in_dir, book, 'lab', '%s.lab' % parts[0]) 32 | text = parts[5] 33 | task = partial(_process_utterance, out_dir, index, wav_path, labels_path, text) 34 | futures.append(executor.submit(task)) 35 | index += 1 36 | results = [future.result() for future in tqdm(futures)] 37 | return [r for r in results if r is not None] 38 | 39 | 40 | def _process_utterance(out_dir, index, wav_path, labels_path, text): 41 | # Load the wav file and trim silence from the ends: 42 | wav = audio.load_wav(wav_path) 43 | start_offset, end_offset = _parse_labels(labels_path) 44 | start = int(start_offset * hparams.sample_rate) 45 | end = int(end_offset * hparams.sample_rate) if end_offset is not None else -1 46 | wav = wav[start:end] 47 | max_samples = _max_out_length * hparams.frame_shift_ms / 1000 * hparams.sample_rate 48 | if len(wav) > max_samples: 49 | return None 50 | spectrogram = audio.spectrogram(wav).astype(np.float32) 51 | n_frames = spectrogram.shape[1] 52 | mel_spectrogram = audio.melspectrogram(wav).astype(np.float32) 53 | spectrogram_filename = 'blizzard-spec-%05d.npy' % index 54 | mel_filename = 'blizzard-mel-%05d.npy' % index 55 | np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False) 56 | 
np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False) 57 | return (spectrogram_filename, mel_filename, n_frames, text) 58 | 59 | 60 | def _parse_labels(path): 61 | labels = [] 62 | with open(os.path.join(path)) as f: 63 | for line in f: 64 | parts = line.strip().split(' ') 65 | if len(parts) >= 3: 66 | labels.append((float(parts[0]), ' '.join(parts[2:]))) 67 | start = 0 68 | end = None 69 | if labels[0][1] == 'sil': 70 | start = labels[0][0] 71 | if labels[-1][1] == 'sil': 72 | end = labels[-2][0] + _end_buffer 73 | return (start, end) 74 | -------------------------------------------------------------------------------- /datasets/kss.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ProcessPoolExecutor 2 | from functools import partial 3 | import numpy as np 4 | import os 5 | from util import audio 6 | 7 | 8 | def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x): 9 | '''Preprocesses the LJ Speech dataset from a given input path into a given output directory. 10 | 11 | Args: 12 | in_dir: The directory where you have downloaded the LJ Speech dataset 13 | out_dir: The directory to write the output into 14 | num_workers: Optional number of worker processes to parallelize across 15 | tqdm: You can optionally pass tqdm to get a nice progress bar 16 | 17 | Returns: 18 | A list of tuples describing the training examples. This should be written to train.txt 19 | ''' 20 | 21 | # We use ProcessPoolExecutor to parallelize across processes. This is just an optimization and you 22 | # can omit it and just call _process_utterance on each input if you want. 23 | executor = ProcessPoolExecutor(max_workers=num_workers) 24 | futures = [] 25 | index = 1 26 | with open(os.path.join(in_dir, 'metadata.csv'), encoding='utf-8') as f: 27 | for line in f: 28 | try: 29 | 30 | parts = line.strip().split('|') 31 | wav_path = os.path.join(in_dir, 'wavs', '%s' % parts[0]) 32 | text = parts[1] 33 | futures.append(executor.submit(partial(_process_utterance, out_dir, index, wav_path, text))) 34 | index += 1 35 | 36 | except: 37 | 38 | pass 39 | return [future.result() for future in tqdm(futures)] 40 | 41 | 42 | def _process_utterance(out_dir, index, wav_path, text): 43 | '''Preprocesses a single utterance audio/text pair. 44 | 45 | This writes the mel and linear scale spectrograms to disk and returns a tuple to write 46 | to the train.txt file. 47 | 48 | Args: 49 | out_dir: The directory to write the spectrograms into 50 | index: The numeric index to use in the spectrogram filenames. 
51 | wav_path: Path to the audio file containing the speech input 52 | text: The text spoken in the input audio file 53 | 54 | Returns: 55 | A (spectrogram_filename, mel_filename, n_frames, text) tuple to write to train.txt 56 | ''' 57 | 58 | # Load the audio to a numpy array: 59 | wav = audio.load_wav(wav_path) 60 | 61 | # Compute the linear-scale spectrogram from the wav: 62 | spectrogram = audio.spectrogram(wav).astype(np.float32) 63 | n_frames = spectrogram.shape[1] 64 | 65 | # Compute a mel-scale spectrogram from the wav: 66 | mel_spectrogram = audio.melspectrogram(wav).astype(np.float32) 67 | 68 | # Write the spectrograms to disk: 69 | spectrogram_filename = 'kss-spec-%05d.npy' % index 70 | mel_filename = 'kss-mel-%05d.npy' % index 71 | np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False) 72 | np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False) 73 | 74 | # Return a tuple describing this training example: 75 | return (spectrogram_filename, mel_filename, n_frames, text) 76 | -------------------------------------------------------------------------------- /datasets/bible.py: -------------------------------------------------------------------------------- 1 | from concurrent.futures import ProcessPoolExecutor 2 | from functools import partial 3 | import numpy as np 4 | import os 5 | from util import audio 6 | 7 | 8 | def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x): 9 | '''Preprocesses the LJ Speech dataset from a given input path into a given output directory. 10 | 11 | Args: 12 | in_dir: The directory where you have downloaded the LJ Speech dataset 13 | out_dir: The directory to write the output into 14 | num_workers: Optional number of worker processes to parallelize across 15 | tqdm: You can optionally pass tqdm to get a nice progress bar 16 | 17 | Returns: 18 | A list of tuples describing the training examples. This should be written to train.txt 19 | ''' 20 | 21 | # We use ProcessPoolExecutor to parallelize across processes. This is just an optimization and you 22 | # can omit it and just call _process_utterance on each input if you want. 23 | executor = ProcessPoolExecutor(max_workers=num_workers) 24 | futures = [] 25 | index = 1 26 | with open(os.path.join(in_dir, 'metadata.csv'), encoding='utf-8') as f: 27 | for line in f: 28 | try: 29 | 30 | parts = line.strip().split('|') 31 | wav_path = os.path.join(in_dir, 'wavs', '%s' % parts[0]) 32 | text = parts[1] 33 | futures.append(executor.submit(partial(_process_utterance, out_dir, index, wav_path, text))) 34 | index += 1 35 | 36 | except: 37 | 38 | pass 39 | return [future.result() for future in tqdm(futures)] 40 | 41 | 42 | def _process_utterance(out_dir, index, wav_path, text): 43 | '''Preprocesses a single utterance audio/text pair. 44 | 45 | This writes the mel and linear scale spectrograms to disk and returns a tuple to write 46 | to the train.txt file. 47 | 48 | Args: 49 | out_dir: The directory to write the spectrograms into 50 | index: The numeric index to use in the spectrogram filenames. 
51 | wav_path: Path to the audio file containing the speech input 52 | text: The text spoken in the input audio file 53 | 54 | Returns: 55 | A (spectrogram_filename, mel_filename, n_frames, text) tuple to write to train.txt 56 | ''' 57 | 58 | # Load the audio to a numpy array: 59 | wav = audio.load_wav(wav_path) 60 | 61 | # Compute the linear-scale spectrogram from the wav: 62 | spectrogram = audio.spectrogram(wav).astype(np.float32) 63 | n_frames = spectrogram.shape[1] 64 | 65 | # Compute a mel-scale spectrogram from the wav: 66 | mel_spectrogram = audio.melspectrogram(wav).astype(np.float32) 67 | 68 | # Write the spectrograms to disk: 69 | spectrogram_filename = 'bible-spec-%05d.npy' % index 70 | mel_filename = 'bible-mel-%05d.npy' % index 71 | np.save(os.path.join(out_dir, spectrogram_filename), spectrogram.T, allow_pickle=False) 72 | np.save(os.path.join(out_dir, mel_filename), mel_spectrogram.T, allow_pickle=False) 73 | 74 | # Return a tuple describing this training example: 75 | return (spectrogram_filename, mel_filename, n_frames, text) 76 | -------------------------------------------------------------------------------- /models/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.contrib.seq2seq import Helper 4 | 5 | 6 | # Adapted from tf.contrib.seq2seq.GreedyEmbeddingHelper 7 | class TacoTestHelper(Helper): 8 | def __init__(self, batch_size, output_dim, r): 9 | with tf.name_scope('TacoTestHelper'): 10 | self._batch_size = batch_size 11 | self._output_dim = output_dim 12 | self._end_token = tf.tile([0.0], [output_dim * r]) 13 | 14 | @property 15 | def batch_size(self): 16 | return self._batch_size 17 | 18 | @property 19 | def sample_ids_shape(self): 20 | return tf.TensorShape([]) 21 | 22 | @property 23 | def sample_ids_dtype(self): 24 | return np.int32 25 | 26 | def initialize(self, name=None): 27 | return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim)) 28 | 29 | def sample(self, time, outputs, state, name=None): 30 | return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them 31 | 32 | def next_inputs(self, time, outputs, state, sample_ids, name=None): 33 | '''Stop on EOS. Otherwise, pass the last output as the next input and pass through state.''' 34 | with tf.name_scope('TacoTestHelper'): 35 | finished = tf.reduce_all(tf.equal(outputs, self._end_token), axis=1) 36 | # Feed last output frame as next input. 
outputs is [N, output_dim * r] 37 | next_inputs = outputs[:, -self._output_dim:] 38 | return (finished, next_inputs, state) 39 | 40 | 41 | class TacoTrainingHelper(Helper): 42 | def __init__(self, inputs, targets, output_dim, r, rnn_decoder_test_mode=False): 43 | # inputs is [N, T_in], targets is [N, T_out, D] 44 | with tf.name_scope('TacoTrainingHelper'): 45 | self._batch_size = tf.shape(inputs)[0] 46 | self._output_dim = output_dim 47 | self._rnn_decoder_test_mode = rnn_decoder_test_mode 48 | 49 | # Feed every r-th target frame as input 50 | self._targets = targets[:, r - 1::r, :] 51 | 52 | # Use full length for every target because we don't want to mask the padding frames 53 | num_steps = tf.shape(self._targets)[1] 54 | self._lengths = tf.tile([num_steps], [self._batch_size]) 55 | 56 | @property 57 | def batch_size(self): 58 | return self._batch_size 59 | 60 | @property 61 | def sample_ids_shape(self): 62 | return tf.TensorShape([]) 63 | 64 | @property 65 | def sample_ids_dtype(self): 66 | return np.int32 67 | 68 | def initialize(self, name=None): 69 | return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim)) 70 | 71 | def sample(self, time, outputs, state, name=None): 72 | return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them 73 | 74 | def next_inputs(self, time, outputs, state, sample_ids, name=None): 75 | with tf.name_scope(name or 'TacoTrainingHelper'): 76 | finished = (time + 1 >= self._lengths) 77 | if self._rnn_decoder_test_mode: 78 | next_inputs = outputs[:, -self._output_dim:] 79 | else: 80 | next_inputs = self._targets[:, time, :] 81 | return (finished, next_inputs, state) 82 | 83 | 84 | def _go_frames(batch_size, output_dim): 85 | '''Returns all-zero frames for a given batch size and output dimension''' 86 | return tf.tile([[0.0]], [batch_size, output_dim]) 87 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | # Code based on https://github.com/carpedm20/DCGAN-tensorflow/blob/master/download.py 2 | 3 | from __future__ import print_function 4 | import os 5 | import sys 6 | import gzip 7 | import json 8 | import tarfile 9 | import zipfile 10 | import argparse 11 | import requests 12 | from tqdm import tqdm 13 | from six.moves import urllib 14 | 15 | from utils import query_yes_no 16 | 17 | parser = argparse.ArgumentParser(description='Download model checkpoints.') 18 | parser.add_argument('checkpoints', metavar='N', type=str, nargs='+', choices=['son', 'park'], 19 | help='name of checkpoints to download [son, park]') 20 | 21 | def download(url, dirpath): 22 | filename = url.split('/')[-1] 23 | filepath = os.path.join(dirpath, filename) 24 | u = urllib.request.urlopen(url) 25 | f = open(filepath, 'wb') 26 | filesize = int(u.headers["Content-Length"]) 27 | print("Downloading: %s Bytes: %s" % (filename, filesize)) 28 | 29 | downloaded = 0 30 | block_sz = 8192 31 | status_width = 70 32 | while True: 33 | buf = u.read(block_sz) 34 | if not buf: 35 | print('') 36 | break 37 | else: 38 | print('', end='\r') 39 | downloaded += len(buf) 40 | f.write(buf) 41 | status = (("[%-" + str(status_width + 1) + "s] %3.2f%%") % 42 | ('=' * int(float(downloaded) / filesize * status_width) + '>', downloaded * 100. 
/ filesize)) 43 | print(status, end='') 44 | sys.stdout.flush() 45 | f.close() 46 | return filepath 47 | 48 | def download_file_from_google_drive(id, destination): 49 | URL = "https://docs.google.com/uc?export=download" 50 | session = requests.Session() 51 | 52 | response = session.get(URL, params={ 'id': id }, stream=True) 53 | token = get_confirm_token(response) 54 | 55 | if token: 56 | params = { 'id' : id, 'confirm' : token } 57 | response = session.get(URL, params=params, stream=True) 58 | 59 | save_response_content(response, destination) 60 | 61 | def get_confirm_token(response): 62 | for key, value in response.cookies.items(): 63 | if key.startswith('download_warning'): 64 | return value 65 | return None 66 | 67 | def save_response_content(response, destination, chunk_size=32*1024): 68 | total_size = int(response.headers.get('content-length', 0)) 69 | with open(destination, "wb") as f: 70 | for chunk in tqdm(response.iter_content(chunk_size), total=total_size, 71 | unit='B', unit_scale=True, desc=destination): 72 | if chunk: # filter out keep-alive new chunks 73 | f.write(chunk) 74 | 75 | def unzip(filepath): 76 | print("Extracting: " + filepath) 77 | dirpath = os.path.dirname(filepath) 78 | with zipfile.ZipFile(filepath) as zf: 79 | zf.extractall(dirpath) 80 | os.remove(filepath) 81 | 82 | def download_checkpoint(checkpoint): 83 | if checkpoint == "son": 84 | save_path, drive_id = "son-20171015.tar.gz", "0B_7wC-DuR6ORcmpaY1A5V1AzZUU" 85 | elif checkpoint == "park": 86 | save_path, drive_id = "park-20171015.tar.gz", "0B_7wC-DuR6ORYjhlekl5bVlkQ2c" 87 | else: 88 | raise Exception(" [!] Unknown checkpoint: {}".format(checkpoint)) 89 | 90 | if os.path.exists(save_path): 91 | print('[*] {} already exists'.format(save_path)) 92 | else: 93 | download_file_from_google_drive(drive_id, save_path) 94 | 95 | if save_path.endswith(".zip"): 96 | zip_dir = '' 97 | with zipfile.ZipFile(save_path) as zf: 98 | zip_dir = zf.namelist()[0] 99 | zf.extractall(dirpath) 100 | os.remove(save_path) 101 | os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir)) 102 | elif save_path.endswith("tar.gz"): 103 | tar = tarfile.open(save_path, "r:gz") 104 | tar.extractall() 105 | tar.close() 106 | elif save_path.endswith("tar"): 107 | tar = tarfile.open(save_path, "r:") 108 | tar.extractall() 109 | tar.close() 110 | 111 | if __name__ == '__main__': 112 | args = parser.parse_args() 113 | 114 | print(" [!] The pre-trained models are being made available for research purpose only") 115 | print(" [!] 학습된 모델을 연구 이외의 목적으로 사용하는 것을 금지합니다.") 116 | print() 117 | 118 | if query_yes_no(" [?] Are you agree on this? 이에 동의하십니까?"): 119 | if 'park' in args.checkpoints: 120 | download_checkpoint('park') 121 | if 'son' in args.checkpoints: 122 | download_checkpoint('son') 123 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | SCALE_FACTOR = 1 4 | 5 | def f(num): 6 | return num // SCALE_FACTOR 7 | 8 | basic_params = { 9 | # Comma-separated list of cleaners to run on text prior to training and eval. For non-English 10 | # text, you may want to use "basic_cleaners" or "transliteration_cleaners" See TRAINING_DATA.md. 
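    # For Korean datasets, switch the value below back to 'korean_cleaners'
    # (defined in text/cleaners.py); TRAINING_DATA.md describes the other options.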
11 | 'cleaners': 'english_cleaners', #originally korean_cleaners 12 | } 13 | 14 | basic_params.update({ 15 | # Audio 16 | 'num_mels': 80, 17 | 'num_freq': 1025, 18 | 'sample_rate': 24000, # trained as 20000 but need to be 24000 19 | 'frame_length_ms': 50, 20 | 'frame_shift_ms': 12.5, 21 | 'preemphasis': 0.97, 22 | 'min_level_db': -100, 23 | 'ref_level_db': 20, 24 | }) 25 | 26 | if True: 27 | basic_params.update({ 28 | 'sample_rate': 22050, #originally 24000 (krbook), 22050(lj-data), 20000(others) 29 | }) 30 | 31 | basic_params.update({ 32 | # Model 33 | 'model_type': 'single', # [single, simple, deepvoice] 34 | 'speaker_embedding_size': f(16), 35 | 36 | 'embedding_size': f(256), 37 | 'dropout_prob': 0.5, 38 | 39 | # Encoder 40 | 'enc_prenet_sizes': [f(256), f(128)], 41 | 'enc_bank_size': 16, 42 | 'enc_bank_channel_size': f(128), 43 | 'enc_maxpool_width': 2, 44 | 'enc_highway_depth': 4, 45 | 'enc_rnn_size': f(128), 46 | 'enc_proj_sizes': [f(128), f(128)], 47 | 'enc_proj_width': 3, 48 | 49 | # Attention 50 | 'attention_type': 'bah_mon', # ntm2-5 51 | 'attention_size': f(256), 52 | 'attention_state_size': f(256), 53 | 54 | # Decoder recurrent network 55 | 'dec_layer_num': 2, 56 | 'dec_rnn_size': f(256), 57 | 58 | # Decoder 59 | 'dec_prenet_sizes': [f(256), f(128)], 60 | 'post_bank_size': 8, 61 | 'post_bank_channel_size': f(256), 62 | 'post_maxpool_width': 2, 63 | 'post_highway_depth': 4, 64 | 'post_rnn_size': f(128), 65 | 'post_proj_sizes': [f(256), 80], # num_mels=80 66 | 'post_proj_width': 3, 67 | 68 | 'reduction_factor': 4, 69 | }) 70 | 71 | if False: # Deep Voice 2 AudioBook Dataset 72 | basic_params.update({ 73 | 'dropout_prob': 0.8, 74 | 75 | 'attention_size': f(512), 76 | 77 | 'dec_prenet_sizes': [f(256), f(128), f(64)], 78 | 'post_bank_channel_size': f(512), 79 | 'post_rnn_size': f(256), 80 | 81 | 'reduction_factor': 5, # changed from 4 82 | }) 83 | elif False: # Deep Voice 2 VCTK dataset 84 | basic_params.update({ 85 | 'dropout_prob': 0.8, 86 | 87 | #'attention_size': f(512), 88 | 89 | #'dec_prenet_sizes': [f(256), f(128)], 90 | #'post_bank_channel_size': f(512), 91 | 'post_rnn_size': f(256), 92 | 93 | 'reduction_factor': 5, 94 | }) 95 | elif True: # Single Speaker 96 | basic_params.update({ 97 | 'dropout_prob': 0.5, 98 | 99 | 'attention_size': f(128), 100 | 101 | 'post_bank_channel_size': f(128), 102 | #'post_rnn_size': f(128), 103 | 104 | 'reduction_factor': 5, #chhanged from 4 105 | }) 106 | elif False: # Single Speaker with generalization 107 | basic_params.update({ 108 | 'dropout_prob': 0.8, 109 | 110 | 'attention_size': f(256), 111 | 112 | 'dec_prenet_sizes': [f(256), f(128), f(64)], 113 | 'post_bank_channel_size': f(128), 114 | 'post_rnn_size': f(128), 115 | 116 | 'reduction_factor': 4, 117 | }) 118 | 119 | 120 | basic_params.update({ 121 | # Training 122 | 'batch_size': 32, 123 | 'adam_beta1': 0.9, 124 | 'adam_beta2': 0.999, 125 | 'use_fixed_test_inputs': False, 126 | 127 | 'initial_learning_rate': 0.001, 128 | 'decay_learning_rate_mode': 0, # True in deepvoice2 paper 129 | 'initial_data_greedy': True, 130 | 'initial_phase_step': 8000, 131 | 'main_data_greedy_factor': 0, 132 | 'main_data': [''], 133 | 'prioritize_loss': False, 134 | 135 | 'recognition_loss_coeff': 0.2, 136 | 'ignore_recognition_level': 0, # 0: use all, 1: ignore only unmatched_alignment, 2: fully ignore recognition 137 | 138 | # Eval 139 | 'min_tokens': 50,#originally 50, 30 is good for korean, 140 | 'min_iters': 30, 141 | 'max_iters': 200, 142 | 'skip_inadequate': False, 143 | 144 | 'griffin_lim_iters': 60, 
145 | 'power': 1.5, # Power to raise magnitudes to prior to Griffin-Lim 146 | }) 147 | 148 | 149 | # Default hyperparameters: 150 | hparams = tf.contrib.training.HParams(**basic_params) 151 | 152 | 153 | def hparams_debug_string(): 154 | values = hparams.values() 155 | hp = [' %s: %s' % (name, values[name]) for name in sorted(values)] 156 | return 'Hyperparameters:\n' + '\n'.join(hp) 157 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import re 3 | import string 4 | import numpy as np 5 | 6 | from text import cleaners 7 | from hparams import hparams 8 | from text.symbols import symbols, en_symbols, PAD, EOS 9 | from text.korean import jamo_to_korean 10 | 11 | # Mappings from symbol to numeric ID and vice versa: 12 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 13 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 14 | isEn = False 15 | 16 | # Regular expression matching text enclosed in curly braces: 17 | _curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') 18 | 19 | puncuation_table = str.maketrans({key: None for key in string.punctuation}) 20 | 21 | 22 | def convert_to_en_symbols(): 23 | '''Converts built-in korean symbols to english, to be used for english training 24 | 25 | ''' 26 | global _symbol_to_id, _id_to_symbol, isEn 27 | if not isEn: 28 | print(" [!] Converting to english mode") 29 | _symbol_to_id = {s: i for i, s in enumerate(en_symbols)} 30 | _id_to_symbol = {i: s for i, s in enumerate(en_symbols)} 31 | isEn = True 32 | 33 | 34 | def remove_puncuations(text): 35 | return text.translate(puncuation_table) 36 | 37 | 38 | # def text_to_sequence(text, as_token=False): 39 | # cleaner_names = [x.strip() for x in hparams.cleaners.split(',')] 40 | # if ('english_cleaners' in cleaner_names) and isEn == False: 41 | # convert_to_en_symbols() 42 | # else: 43 | # 44 | # return _text_to_sequence(text, cleaner_names, as_token) 45 | 46 | 47 | def text_to_sequence(text, as_token=False): 48 | cleaner_names = [x.strip() for x in hparams.cleaners.split(',')] 49 | if ('english_cleaners' in cleaner_names) and isEn==False: 50 | convert_to_en_symbols() 51 | return _text_to_sequence(text, cleaner_names, as_token) 52 | 53 | 54 | def _text_to_sequence(text, cleaner_names, as_token): 55 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 56 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 57 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
58 | Args: 59 | text: string to convert to a sequence 60 | cleaner_names: names of the cleaner functions to run the text through 61 | Returns: 62 | List of integers corresponding to the symbols in the text 63 | ''' 64 | sequence = [] 65 | 66 | # Check for curly braces and treat their contents as ARPAbet: 67 | while len(text): 68 | m = _curly_re.match(text) 69 | if not m: 70 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 71 | break 72 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 73 | sequence += _arpabet_to_sequence(m.group(2)) 74 | text = m.group(3) 75 | 76 | # Append EOS token 77 | sequence.append(_symbol_to_id[EOS]) 78 | 79 | if as_token: 80 | return sequence_to_text(sequence, combine_jamo=True) 81 | else: 82 | return np.array(sequence, dtype=np.int32) 83 | 84 | 85 | 86 | def sequence_to_text(sequence, skip_eos_and_pad=False, combine_jamo=False): 87 | '''Converts a sequence of IDs back to a string''' 88 | cleaner_names = [x.strip() for x in hparams.cleaners.split(',')] 89 | if 'english_cleaners' in cleaner_names and isEn == False: 90 | convert_to_en_symbols() 91 | 92 | result = '' 93 | for symbol_id in sequence: 94 | if symbol_id in _id_to_symbol: 95 | s = _id_to_symbol[symbol_id] 96 | # Enclose ARPAbet back in curly braces: 97 | if len(s) > 1 and s[0] == '@': 98 | s = '{%s}' % s[1:] 99 | 100 | if not skip_eos_and_pad or s not in [EOS, PAD]: 101 | result += s 102 | 103 | result = result.replace('}{', ' ') 104 | 105 | if combine_jamo: 106 | return jamo_to_korean(result) 107 | else: 108 | return result 109 | 110 | 111 | def _clean_text(text, cleaner_names): 112 | for name in cleaner_names: 113 | cleaner = getattr(cleaners, name) 114 | if not cleaner: 115 | raise Exception('Unknown cleaner: %s' % name) 116 | text = cleaner(text) # '존경하는' --> ['ᄌ', 'ᅩ', 'ᆫ', 'ᄀ', 'ᅧ', 'ᆼ', 'ᄒ', 'ᅡ', 'ᄂ', 'ᅳ', 'ᆫ', '~'] 117 | return text 118 | 119 | 120 | def _symbols_to_sequence(symbols): 121 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 122 | 123 | 124 | def _arpabet_to_sequence(text): 125 | return _symbols_to_sequence(['@' + s for s in text.split()]) 126 | 127 | 128 | def _should_keep_symbol(s): 129 | return s in _symbol_to_id and s is not '_' and s is not '~' 130 | -------------------------------------------------------------------------------- /recognition/google.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import json 4 | import argparse 5 | import numpy as np 6 | from glob import glob 7 | from functools import partial 8 | 9 | from utils import parallel_run, remove_file, backup_file, write_json 10 | from audio import load_audio, save_audio, resample_audio, get_duration 11 | 12 | def text_recognition(path, config): 13 | root, ext = os.path.splitext(path) 14 | txt_path = root + ".txt" 15 | 16 | if os.path.exists(txt_path): 17 | with open(txt_path) as f: 18 | out = json.loads(open(txt_path).read()) 19 | return out 20 | 21 | from google.cloud import speech 22 | from google.cloud.speech import enums 23 | from google.cloud.speech import types 24 | 25 | out = {} 26 | error_count = 0 27 | 28 | tmp_path = os.path.splitext(path)[0] + ".tmp.wav" 29 | client = speech.SpeechClient() # Fixed 30 | 31 | while True: 32 | try: 33 | # client= speech.SpeechClient() # Causes 10060 max retries exceeded -to OAuth -HK 34 | 35 | content = load_audio( 36 | path, pre_silence_length=config.pre_silence_length, 37 | post_silence_length=config.post_silence_length) 38 | 39 | max_duration 
= config.max_duration - \ 40 | config.pre_silence_length - config.post_silence_length 41 | audio_duration = get_duration(content) 42 | 43 | if audio_duration >= max_duration: 44 | print(" [!] Skip {} because of duration: {} > {}". \ 45 | format(path, audio_duration, max_duration)) 46 | return {} 47 | 48 | content = resample_audio(content, config.sample_rate) 49 | save_audio(content, tmp_path, config.sample_rate) 50 | 51 | with io.open(tmp_path, 'rb') as f: 52 | audio = types.RecognitionAudio(content=f.read()) 53 | 54 | config = types.RecognitionConfig( 55 | encoding=enums.RecognitionConfig.AudioEncoding.LINEAR16, 56 | sample_rate_hertz=config.sample_rate, 57 | language_code='ko-KR') 58 | 59 | response = client.recognize(config, audio) 60 | if len(response.results) > 0: 61 | alternatives = response.results[0].alternatives 62 | 63 | results = [alternative.transcript for alternative in alternatives] 64 | assert len(results) == 1, "More than 1 results: {}".format(results) 65 | 66 | out = { path: "" if len(results) == 0 else results[0] } 67 | print(path, results[0]) 68 | break 69 | break 70 | except Exception as err: 71 | raise Exception("OS error: {0}".format(err)) 72 | 73 | error_count += 1 74 | print("Skip warning for {} for {} times". \ 75 | format(path, error_count)) 76 | 77 | if error_count > 5: 78 | break 79 | else: 80 | continue 81 | 82 | remove_file(tmp_path) 83 | with open(txt_path, 'w') as f: 84 | json.dump(out, f, indent=2, ensure_ascii=False) 85 | 86 | return out 87 | 88 | def text_recognition_batch(paths, config): 89 | paths.sort() 90 | 91 | results = {} 92 | items = parallel_run( 93 | partial(text_recognition, config=config), paths, 94 | desc="text_recognition_batch", parallel=True) 95 | for item in items: 96 | results.update(item) 97 | return results 98 | 99 | 100 | if __name__ == '__main__': 101 | parser = argparse.ArgumentParser() 102 | parser.add_argument('--audio_pattern', required=True) 103 | parser.add_argument('--recognition_filename', default="recognition.json") 104 | parser.add_argument('--sample_rate', default=16000, type=int) 105 | parser.add_argument('--pre_silence_length', default=1, type=int) 106 | parser.add_argument('--post_silence_length', default=1, type=int) 107 | parser.add_argument('--max_duration', default=60, type=int) 108 | config, unparsed = parser.parse_known_args() 109 | 110 | audio_dir = os.path.dirname(config.audio_pattern) 111 | 112 | for tmp_path in glob(os.path.join(audio_dir, "*.tmp.*")): 113 | remove_file(tmp_path) 114 | 115 | paths = glob(config.audio_pattern) 116 | paths.sort() 117 | results = text_recognition_batch(paths, config) 118 | 119 | base_dir = os.path.dirname(audio_dir) 120 | recognition_path = \ 121 | os.path.join(base_dir, config.recognition_filename) 122 | 123 | if os.path.exists(recognition_path): 124 | backup_file(recognition_path) 125 | 126 | write_json(recognition_path, results) 127 | -------------------------------------------------------------------------------- /text/kor_dic.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | etc_dictionary = { 4 | '2 30대': '이삼십대', 5 | '20~30대': '이삼십대', 6 | '20, 30대': '이십대 삼십대', 7 | '1+1': '원플러스원', 8 | '3에서 6개월인': '3개월에서 육개월인', 9 | } 10 | 11 | english_dictionary = { 12 | 'Devsisters': '데브시스터즈', 13 | 'track': '트랙', 14 | 15 | # krbook 16 | 'LA': '엘에이', 17 | 'LG': '엘지', 18 | 'KOREA': '코리아', 19 | 'JSA': '제이에스에이', 20 | 'PGA': '피지에이', 21 | 'GA': '지에이', 22 | 'idol': '아이돌', 23 | 'KTX': '케이티엑스', 24 | 'AC': '에이씨', 25 | 'DVD': '디비디', 26 | 
'US': '유에스', 27 | 'CNN': '씨엔엔', 28 | 'LPGA': '엘피지에이', 29 | 'P': '피', 30 | 'L': '엘', 31 | 'T': '티', 32 | 'B': '비', 33 | 'C': '씨', 34 | 'BIFF': '비아이에프에프', 35 | 'GV': '지비', 36 | 37 | # JTBC 38 | 'IT': '아이티', 39 | 'IQ': '아이큐', 40 | 'JTBC': '제이티비씨', 41 | 'trickle down effect': '트리클 다운 이펙트', 42 | 'trickle up effect': '트리클 업 이펙트', 43 | 'down': '다운', 44 | 'up': '업', 45 | 'FCK': '에프씨케이', 46 | 'AP': '에이피', 47 | 'WHERETHEWILDTHINGSARE': '', 48 | 'Rashomon Effect': '', 49 | 'O': '오', 50 | 'OO': '오오', 51 | 'B': '비', 52 | 'GDP': '지디피', 53 | 'CIPA': '씨아이피에이', 54 | 'YS': '와이에스', 55 | 'Y': '와이', 56 | 'S': '에스', 57 | 'JTBC': '제이티비씨', 58 | 'PC': '피씨', 59 | 'bill': '빌', 60 | 'Halmuny': '하모니', ##### 61 | 'X': '엑스', 62 | 'SNS': '에스엔에스', 63 | 'ability': '어빌리티', 64 | 'shy': '', 65 | 'CCTV': '씨씨티비', 66 | 'IT': '아이티', 67 | 'the tenth man': '더 텐쓰 맨', #### 68 | 'L': '엘', 69 | 'PC': '피씨', 70 | 'YSDJJPMB': '', ######## 71 | 'Content Attitude Timing': '컨텐트 애티튜드 타이밍', 72 | 'CAT': '캣', 73 | 'IS': '아이에스', 74 | 'SNS': '에스엔에스', 75 | 'K': '케이', 76 | 'Y': '와이', 77 | 'KDI': '케이디아이', 78 | 'DOC': '디오씨', 79 | 'CIA': '씨아이에이', 80 | 'PBS': '피비에스', 81 | 'D': '디', 82 | 'PPropertyPositionPowerPrisonP' 83 | 'S': '에스', 84 | 'francisco': '프란시스코', 85 | 'I': '아이', 86 | 'III': '아이아이', ###### 87 | 'No joke': '노 조크', 88 | 'BBK': '비비케이', 89 | 'LA': '엘에이', 90 | 'Don': '', 91 | 't worry be happy': ' 워리 비 해피', 92 | 'NO': '엔오', ##### 93 | 'it was our sky': '잇 워즈 아워 스카이', 94 | 'it is our sky': '잇 이즈 아워 스카이', #### 95 | 'NEIS': '엔이아이에스', ##### 96 | 'IMF': '아이엠에프', 97 | 'apology': '어폴로지', 98 | 'humble': '험블', 99 | 'M': '엠', 100 | 'Nowhere Man': '노웨어 맨', 101 | 'The Tenth Man': '더 텐쓰 맨', 102 | 'PBS': '피비에스', 103 | 'BBC': '비비씨', 104 | 'MRJ': '엠알제이', 105 | 'CCTV': '씨씨티비', 106 | 'Pick me up': '픽 미 업', 107 | 'DNA': '디엔에이', 108 | 'UN': '유엔', 109 | 'STOP': '스탑', ##### 110 | 'PRESS': '프레스', ##### 111 | 'not to be': '낫 투비', 112 | 'Denial': '디나이얼', 113 | 'G': '지', 114 | 'IMF': '아이엠에프', 115 | 'GDP': '지디피', 116 | 'JTBC': '제이티비씨', 117 | 'Time flies like an arrow': '타임 플라이즈 라이크 언 애로우', 118 | 'DDT': '디디티', 119 | 'AI': '에이아이', 120 | 'Z': '제트', 121 | 'OECD': '오이씨디', 122 | 'N': '앤', 123 | 'A': '에이', 124 | 'MB': '엠비', 125 | 'EH': '이에이치', 126 | 'IS': '아이에스', 127 | 'TV': '티비', 128 | 'MIT': '엠아이티', 129 | 'KBO': '케이비오', 130 | 'I love America': '아이 러브 아메리카', 131 | 'SF': '에스에프', 132 | 'Q': '큐', 133 | 'KFX': '케이에프엑스', 134 | 'PM': '피엠', 135 | 'Prime Minister': '프라임 미니스터', 136 | 'Swordline': '스워드라인', 137 | 'TBS': '티비에스', 138 | 'DDT': '디디티', 139 | 'CS': '씨에스', 140 | 'Reflecting Absence': '리플렉팅 앱센스', 141 | 'PBS': '피비에스', 142 | 'Drum being beaten by everyone': '드럼 빙 비튼 바이 에브리원', 143 | 'negative pressure': '네거티브 프레셔', 144 | 'F': '에프', 145 | 'KIA': '기아', 146 | 'FTA': '에프티에이', 147 | 'Que sais-je': '', 148 | 'UFC': '유에프씨', 149 | 'P': '피', 150 | 'DJ': '디제이', 151 | 'Chaebol': '채벌', 152 | 'BBC': '비비씨', 153 | 'OECD': '오이씨디', 154 | 'BC': '삐씨', 155 | 'C': '씨', 156 | 'B': '씨', 157 | 'KY': '케이와이', 158 | 'K': '케이', 159 | 'CEO': '씨이오', 160 | 'YH': '와이에치', 161 | 'IS': '아이에스', 162 | 'who are you': '후 얼 유', 163 | 'Y': '와이', 164 | 'The Devils Advocate': '더 데빌즈 어드보카트', 165 | 'YS': '와이에스', 166 | 'so sorry': '쏘 쏘리', 167 | 'Santa': '산타', 168 | 'Big Endian': '빅 엔디안', 169 | 'Small Endian': '스몰 엔디안', 170 | 'Oh Captain My Captain': '오 캡틴 마이 캡틴', 171 | 'AIB': '에이아이비', 172 | 'K': '케이', 173 | 'PBS': '피비에스', 174 | } -------------------------------------------------------------------------------- /LJSpeech-1.1/README: -------------------------------------------------------------------------------- 1 | 
----------------------------------------------------------------------------- 2 | The LJ Speech Dataset 3 | 4 | Version 1.0 5 | July 5, 2017 6 | https://keithito.com/LJ-Speech-Dataset 7 | ----------------------------------------------------------------------------- 8 | 9 | 10 | OVERVIEW 11 | 12 | This is a public domain speech dataset consisting of 13,100 short audio clips 13 | of a single speaker reading passages from 7 non-fiction books. A transcription 14 | is provided for each clip. Clips vary in length from 1 to 10 seconds and have 15 | a total length of approximately 24 hours. 16 | 17 | The texts were published between 1884 and 1964, and are in the public domain. 18 | The audio was recorded in 2016-17 by the LibriVox project and is also in the 19 | public domain. 20 | 21 | 22 | 23 | FILE FORMAT 24 | 25 | Metadata is provided in metadata.csv. This file consists of one record per 26 | line, delimited by the pipe character (0x7c). The fields are: 27 | 28 | 1. ID: this is the name of the corresponding .wav file 29 | 2. Transcription: words spoken by the reader (UTF-8) 30 | 3. Normalized Transcription: transcription with numbers, ordinals, and 31 | monetary units expanded into full words (UTF-8). 32 | 33 | Each audio file is a single-channel 16-bit PCM WAV with a sample rate of 34 | 22050 Hz. 35 | 36 | 37 | 38 | STATISTICS 39 | 40 | Total Clips 13,100 41 | Total Words 225,715 42 | Total Characters 1,308,674 43 | Total Duration 23:55:17 44 | Mean Clip Duration 6.57 sec 45 | Min Clip Duration 1.11 sec 46 | Max Clip Duration 10.10 sec 47 | Mean Words per Clip 17.23 48 | Distinct Words 13,821 49 | 50 | 51 | 52 | MISCELLANEOUS 53 | 54 | The audio clips range in length from approximately 1 second to 10 seconds. 55 | They were segmented automatically based on silences in the recording. Clip 56 | boundaries generally align with sentence or clause boundaries, but not always. 57 | 58 | The text was matched to the audio manually, and a QA pass was done to ensure 59 | that the text accurately matched the words spoken in the audio. 60 | 61 | The original LibriVox recordings were distributed as 128 kbps MP3 files. As a 62 | result, they may contain artifacts introduced by the MP3 encoding. 63 | 64 | The following abbreviations appear in the text. They may be expanded as 65 | follows: 66 | 67 | Abbreviation Expansion 68 | -------------------------- 69 | Mr. Mister 70 | Mrs. Misess (*) 71 | Dr. Doctor 72 | No. Number 73 | St. Saint 74 | Co. Company 75 | Jr. Junior 76 | Maj. Major 77 | Gen. General 78 | Drs. Doctors 79 | Rev. Reverend 80 | Lt. Lieutenant 81 | Hon. Honorable 82 | Sgt. Sergeant 83 | Capt. Captain 84 | Esq. Esquire 85 | Ltd. Limited 86 | Col. Colonel 87 | Ft. Fort 88 | 89 | * there's no standard expansion of "Mrs." 90 | 91 | 92 | 19 of the transcriptions contain non-ASCII characters (for example, LJ016-0257 93 | contains "raison d'être"). 94 | 95 | For more information or to report errors, please email kito@kito.us. 96 | 97 | 98 | 99 | LICENSE 100 | 101 | This dataset is in the public domain in the USA (and likely other countries as 102 | well). There are no restrictions on its use. For more information, please see: 103 | https://librivox.org/pages/public-domain. 104 | 105 | 106 | CHANGELOG 107 | 108 | * 1.0 (July 8, 2017): 109 | Initial release 110 | 111 | * 1.1 (Feb 19, 2018): 112 | Version 1.0 included 30 .wav files with no corresponding annotations in 113 | metadata.csv. These have been removed in version 1.1. Thanks to Rafael Valle 114 | for spotting this. 
115 | 116 | 117 | CREDITS 118 | 119 | This dataset consists of excerpts from the following works: 120 | 121 | * Morris, William, et al. Arts and Crafts Essays. 1893. 122 | * Griffiths, Arthur. The Chronicles of Newgate, Vol. 2. 1884. 123 | * Roosevelt, Franklin D. The Fireside Chats of Franklin Delano Roosevelt. 124 | 1933-42. 125 | * Harland, Marion. Marion Harland's Cookery for Beginners. 1893. 126 | * Rolt-Wheeler, Francis. The Science - History of the Universe, Vol. 5: 127 | Biology. 1910. 128 | * Banks, Edgar J. The Seven Wonders of the Ancient World. 1916. 129 | * President's Commission on the Assassination of President Kennedy. Report 130 | of the President's Commission on the Assassination of President Kennedy. 131 | 1964. 132 | 133 | Recordings by Linda Johnson. Alignment and annotation by Keith Ito. All text, 134 | audio, and annotations are in the public domain. 135 | 136 | There's no requirement to cite this work, but if you'd like to do so, you can 137 | link to: https://keithito.com/LJ-Speech-Dataset 138 | 139 | or use the following: 140 | @misc{ljspeech17, 141 | author = {Keith Ito}, 142 | title = {The LJ Speech Dataset}, 143 | howpublished = {\url{https://keithito.com/LJ-Speech-Dataset/}}, 144 | year = 2017 145 | } 146 | -------------------------------------------------------------------------------- /audio/silence.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import json 5 | import librosa 6 | import argparse 7 | import numpy as np 8 | from tqdm import tqdm 9 | from glob import glob 10 | from pydub import silence 11 | from pydub import AudioSegment 12 | from functools import partial 13 | 14 | from hparams import hparams 15 | from utils import parallel_run, add_postfix 16 | from audio import load_audio, save_audio, get_duration, get_silence 17 | 18 | def abs_mean(x): 19 | return abs(x).mean() 20 | 21 | def remove_breath(audio): 22 | edges = librosa.effects.split( 23 | audio, top_db=40, frame_length=128, hop_length=32) 24 | 25 | for idx in range(len(edges)): 26 | start_idx, end_idx = edges[idx][0], edges[idx][1] 27 | if start_idx < len(audio): 28 | if abs_mean(audio[start_idx:end_idx]) < abs_mean(audio) - 0.05: 29 | audio[start_idx:end_idx] = 0 30 | 31 | return audio 32 | 33 | def split_on_silence_with_librosa( 34 | audio_path, top_db=40, frame_length=1024, hop_length=256, 35 | skip_idx=0, out_ext="wav", 36 | min_segment_length=3, max_segment_length=8, 37 | pre_silence_length=0, post_silence_length=0): 38 | 39 | filename = os.path.basename(audio_path).split('.', 1)[0] 40 | in_ext = audio_path.rsplit(".")[1] 41 | 42 | audio = load_audio(audio_path) 43 | 44 | edges = librosa.effects.split(audio, 45 | top_db=top_db, frame_length=frame_length, hop_length=hop_length) 46 | 47 | new_audio = np.zeros_like(audio) 48 | for idx, (start, end) in enumerate(edges[skip_idx:]): 49 | new_audio[start:end] = remove_breath(audio[start:end]) 50 | 51 | save_audio(new_audio, add_postfix(audio_path, "no_breath")) 52 | audio = new_audio 53 | edges = librosa.effects.split(audio, 54 | top_db=top_db, frame_length=frame_length, hop_length=hop_length) 55 | 56 | audio_paths = [] 57 | for idx, (start, end) in enumerate(edges[skip_idx:]): 58 | segment = audio[start:end] 59 | duration = get_duration(segment) 60 | 61 | if duration <= min_segment_length or duration >= max_segment_length: 62 | continue 63 | 64 | output_path = "{}/{}.{:04d}.{}".format( 65 | os.path.dirname(audio_path), filename, idx, out_ext) 66 | 67 | 
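        # pad the kept segment with the requested pre/post silence (zero-length by default) before saving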
padded_segment = np.concatenate([ 68 | get_silence(pre_silence_length), 69 | segment, 70 | get_silence(post_silence_length), 71 | ]) 72 | 73 | 74 | 75 | save_audio(padded_segment, output_path) 76 | audio_paths.append(output_path) 77 | 78 | return audio_paths 79 | 80 | def read_audio(audio_path): 81 | return AudioSegment.from_file(audio_path) 82 | 83 | def split_on_silence_with_pydub( 84 | audio_path, skip_idx=0, out_ext="wav", 85 | silence_thresh=-40, min_silence_len=400, 86 | silence_chunk_len=100, keep_silence=100): 87 | 88 | filename = os.path.basename(audio_path).split('.', 1)[0] 89 | in_ext = audio_path.rsplit(".")[1] 90 | 91 | audio = read_audio(audio_path) 92 | not_silence_ranges = silence.detect_nonsilent( 93 | audio, min_silence_len=silence_chunk_len, 94 | silence_thresh=silence_thresh) 95 | 96 | edges = [not_silence_ranges[0]] 97 | 98 | for idx in range(1, len(not_silence_ranges)-1): 99 | cur_start = not_silence_ranges[idx][0] 100 | prev_end = edges[-1][1] 101 | 102 | if cur_start - prev_end < min_silence_len: 103 | edges[-1][1] = not_silence_ranges[idx][1] 104 | else: 105 | edges.append(not_silence_ranges[idx]) 106 | 107 | audio_paths = [] 108 | for idx, (start_idx, end_idx) in enumerate(edges[skip_idx:]): 109 | start_idx = max(0, start_idx - keep_silence) 110 | end_idx += keep_silence 111 | 112 | target_audio_path = "{}/{}.{:04d}.{}".format( 113 | os.path.dirname(audio_path), filename, idx, out_ext) 114 | 115 | segment=audio[start_idx:end_idx] 116 | 117 | segment.export(target_audio_path, out_ext) # for soundsegment 118 | 119 | audio_paths.append(target_audio_path) 120 | 121 | return audio_paths 122 | 123 | def split_on_silence_batch(audio_paths, method, **kargv): 124 | audio_paths.sort() 125 | method = method.lower() 126 | 127 | if method == "librosa": 128 | fn = partial(split_on_silence_with_librosa, **kargv) 129 | elif method == "pydub": 130 | fn = partial(split_on_silence_with_pydub, **kargv) 131 | 132 | parallel_run(fn, audio_paths, 133 | desc="Split on silence", parallel=False) 134 | 135 | if __name__ == "__main__": 136 | parser = argparse.ArgumentParser() 137 | parser.add_argument('--audio_pattern', required=True) 138 | parser.add_argument('--out_ext', default='wav') 139 | parser.add_argument('--method', choices=['librosa', 'pydub'], required=True) 140 | config = parser.parse_args() 141 | 142 | audio_paths = glob(config.audio_pattern) 143 | 144 | split_on_silence_batch( 145 | audio_paths, config.method, 146 | out_ext=config.out_ext, 147 | ) 148 | -------------------------------------------------------------------------------- /utils/audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import librosa.filters 3 | import math 4 | import numpy as np 5 | import tensorflow as tf 6 | import scipy 7 | from hparams import hparams 8 | 9 | 10 | def load_wav(path): 11 | return librosa.core.load(path, sr=hparams.sample_rate)[0] 12 | 13 | 14 | def save_wav(wav, path): 15 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 16 | scipy.io.wavfile.write(path, hparams.sample_rate, wav.astype(np.int16)) 17 | 18 | 19 | def preemphasis(x): 20 | return scipy.signal.lfilter([1, -hparams.preemphasis], [1], x) 21 | 22 | 23 | def inv_preemphasis(x): 24 | return scipy.signal.lfilter([1], [1, -hparams.preemphasis], x) 25 | 26 | 27 | def spectrogram(y): 28 | D = _stft(preemphasis(y)) 29 | S = _amp_to_db(np.abs(D)) - hparams.ref_level_db 30 | return _normalize(S) 31 | 32 | 33 | def inv_spectrogram(spectrogram): 34 | '''Converts spectrogram to 
waveform using librosa''' 35 | S = _db_to_amp(_denormalize(spectrogram) + hparams.ref_level_db) # Convert back to linear 36 | return inv_preemphasis(_griffin_lim(S ** hparams.power)) # Reconstruct phase 37 | 38 | 39 | def inv_spectrogram_tensorflow(spectrogram): 40 | '''Builds computational graph to convert spectrogram to waveform using TensorFlow. 41 | 42 | Unlike inv_spectrogram, this does NOT invert the preemphasis. The caller should call 43 | inv_preemphasis on the output after running the graph. 44 | ''' 45 | S = _db_to_amp_tensorflow(_denormalize_tensorflow(spectrogram) + hparams.ref_level_db) 46 | return _griffin_lim_tensorflow(tf.pow(S, hparams.power)) 47 | 48 | 49 | def melspectrogram(y): 50 | D = _stft(preemphasis(y)) 51 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hparams.ref_level_db 52 | return _normalize(S) 53 | 54 | 55 | def find_endpoint(wav, threshold_db=-40, min_silence_sec=0.8): 56 | window_length = int(hparams.sample_rate * min_silence_sec) 57 | hop_length = int(window_length / 4) 58 | threshold = _db_to_amp(threshold_db) 59 | for x in range(hop_length, len(wav) - window_length, hop_length): 60 | if np.max(wav[x:x + window_length]) < threshold: 61 | return x + hop_length 62 | return len(wav) 63 | 64 | 65 | def _griffin_lim(S): 66 | '''librosa implementation of Griffin-Lim 67 | Based on https://github.com/librosa/librosa/issues/434 68 | ''' 69 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) 70 | S_complex = np.abs(S).astype(np.complex) 71 | y = _istft(S_complex * angles) 72 | for i in range(hparams.griffin_lim_iters): 73 | angles = np.exp(1j * np.angle(_stft(y))) 74 | y = _istft(S_complex * angles) 75 | return y 76 | 77 | 78 | def _griffin_lim_tensorflow(S): 79 | '''TensorFlow implementation of Griffin-Lim 80 | Based on https://github.com/Kyubyong/tensorflow-exercises/blob/master/Audio_Processing.ipynb 81 | ''' 82 | with tf.variable_scope('griffinlim'): 83 | # TensorFlow's stft and istft operate on a batch of spectrograms; create batch of size 1 84 | S = tf.expand_dims(S, 0) 85 | S_complex = tf.identity(tf.cast(S, dtype=tf.complex64)) 86 | y = _istft_tensorflow(S_complex) 87 | for i in range(hparams.griffin_lim_iters): 88 | est = _stft_tensorflow(y) 89 | angles = est / tf.cast(tf.maximum(1e-8, tf.abs(est)), tf.complex64) 90 | y = _istft_tensorflow(S_complex * angles) 91 | return tf.squeeze(y, 0) 92 | 93 | 94 | def _stft(y): 95 | n_fft, hop_length, win_length = _stft_parameters() 96 | return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) 97 | 98 | 99 | def _istft(y): 100 | _, hop_length, win_length = _stft_parameters() 101 | return librosa.istft(y, hop_length=hop_length, win_length=win_length) 102 | 103 | 104 | def _stft_tensorflow(signals): 105 | n_fft, hop_length, win_length = _stft_parameters() 106 | return tf.contrib.signal.stft(signals, win_length, hop_length, n_fft, pad_end=False) 107 | 108 | 109 | def _istft_tensorflow(stfts): 110 | n_fft, hop_length, win_length = _stft_parameters() 111 | return tf.contrib.signal.inverse_stft(stfts, win_length, hop_length, n_fft) 112 | 113 | 114 | def _stft_parameters(): 115 | n_fft = (hparams.num_freq - 1) * 2 116 | hop_length = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate) 117 | win_length = int(hparams.frame_length_ms / 1000 * hparams.sample_rate) 118 | return n_fft, hop_length, win_length 119 | 120 | 121 | # Conversions: 122 | 123 | _mel_basis = None 124 | 125 | 126 | def _linear_to_mel(spectrogram): 127 | global _mel_basis 128 | if _mel_basis is None: 129 | _mel_basis = 
_build_mel_basis() 130 | return np.dot(_mel_basis, spectrogram) 131 | 132 | 133 | def _build_mel_basis(): 134 | n_fft = (hparams.num_freq - 1) * 2 135 | return librosa.filters.mel(hparams.sample_rate, n_fft, n_mels=hparams.num_mels) 136 | 137 | 138 | def _amp_to_db(x): 139 | return 20 * np.log10(np.maximum(1e-5, x)) 140 | 141 | 142 | def _db_to_amp(x): 143 | return np.power(10.0, x * 0.05) 144 | 145 | 146 | def _db_to_amp_tensorflow(x): 147 | return tf.pow(tf.ones(tf.shape(x)) * 10.0, x * 0.05) 148 | 149 | 150 | def _normalize(S): 151 | return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1) 152 | 153 | 154 | def _denormalize(S): 155 | return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db 156 | 157 | 158 | def _denormalize_tensorflow(S): 159 | return (tf.clip_by_value(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db 160 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import math 4 | import argparse 5 | from glob import glob 6 | 7 | from synthesizer import Synthesizer 8 | from train import create_batch_inputs_from_texts 9 | from utils import makedirs, str2bool, backup_file 10 | from hparams import hparams, hparams_debug_string 11 | 12 | 13 | texts = [ 14 | '텍스트를 음성으로 읽어주는 "음성합성" 기술은 시각 장애인을 위한 오디오북, 음성 안내 시스템, 대화 인공지능 등 많은 분야에 활용할 수 있습니다.', 15 | "하지만 개인이 원하는 목소리로 음성합성 엔진을 만들기에는 어려움이 많았고 소수의 기업만이 기술을 보유하고 있었습니다.", 16 | "최근 딥러닝 기술의 발전은 음성합성 기술의 진입 장벽을 많이 낮췄고 이제는 누구나 손쉽게 음성합성 엔진을 만들 수 있게 되었습니다.", 17 | 18 | "본 세션에서는 딥러닝을 활용한 음성합성 기술을 소개하고 개발 경험과 그 과정에서 얻었던 팁을 공유하고자 합니다.", 19 | "음성합성 엔진을 구현하는데 사용한 세 가지 연구를 소개하고 각각의 기술이 얼마나 자연스러운 목소리를 만들 수 있는지를 공유합니다.", 20 | 21 | # Harry Potter 22 | "그리고 헤르미온느는 겁에 질려 마룻바닥에 쓰러져 있었다.", 23 | "그러자 론은 요술지팡이를 꺼냈다. 무엇을 할지도 모르면서 그는 머리에 처음으로 떠오른 주문을 외치고 있었다.", 24 | "윙가르디움 레비오우사.... 하지만, 그렇게 소리쳤다.", 25 | "그러자 그 방망이가 갑자기 트롤의 손에서 벗어나, 저 위로 올라가더니 탁하며 그 주인의 머리 위에 떨어졌다.", 26 | "그러자 트롤이 그 자리에서 비틀거리더니 방 전체를 흔들어버릴 것 같은 커다란 소리를 내며 쿵 하고 넘어졌다. ", 27 | "그러자 조그맣게 펑 하는 소리가 나면서 가장 가까이 있는 가로등이 꺼졌다.", 28 | "그리고 그가 다시 찰깍하자 그 다음 가로등이 깜박거리며 나가 버렸다.", 29 | 30 | #"그가 그렇게 가로등 끄기를 열두번 하자, 이제 그 거리에 남아 있는 불빛이라곤, ", 31 | #"바늘로 꼭 질러둔 것처럼 작게 보이는 멀리서 그를 지켜보고 있는 고양이의 두 눈뿐이었다.", 32 | #"프리벳가 4번지에 살고 있는 더즐리 부부는 자신들이 정상적이라는 것을 아주 자랑스럽게 여기는 사람들이었다. ", 33 | #"그들은 기이하거나 신비스런 일과는 전혀 무관해 보였다.", 34 | #"아니, 그런 터무니없는 것은 도저히 참아내지 못했다.", 35 | #"더즐리 씨는 그루닝스라는 드릴제작회사의 중역이었다.", 36 | #"그는 목이 거의 없을 정도로 살이 뒤룩뒤룩 찐 몸집이 큰 사내로, 코밑에는 커다란 콧수염을 기르고 있었다.", 37 | #"더즐리 부인은 마른 체구의 금발이었고, 목이 보통사람보다 두 배는 길어서, 담 너머로 고개를 쭉 배고 이웃 사람들을 몰래 훔쳐보는 그녀의 취미에는 더없이 제격이었다.", 38 | 39 | # From Yoo Inna's Audiobook (http://campaign.happybean.naver.com/yooinna_audiobook): 40 | #'16세기 중엽 어느 가을날 옛 런던 시의 가난한 캔티 집안에 사내아이 하나가 태어났다.', 41 | #'그런데 그 집안에서는 그 사내아이를 별로 반기지 않았다.', 42 | #'바로 같은 날 또 한 명의 사내아이가 영국의 부유한 튜터 가문에서 태어났다.', 43 | #'그런데 그 가문에서는 그 아이를 무척이나 반겼다.', 44 | #'온 영국이 다 함께 그 아이를 반겼다.', 45 | 46 | ## From NAVER's Audiobook (http://campaign.happybean.naver.com/yooinna_audiobook): 47 | #'부랑자 패거리는 이른 새벽에 일찍 출발하여 길을 떠났다.', 48 | #'하늘은 찌푸렸고, 발밑의 땅은 질퍽거렸으며, 겨울의 냉기가 공기 중에 감돌았다.', 49 | #'지난밤의 흥겨움은 온데간데없이 사라졌다.', 50 | #'시무룩하게 말이 없는 사람들도 있었고, 안달복달하며 조바심을 내는 사람들도 있었지만, 기분이 좋은 사람은 하나도 없었다.', 51 | 52 | ## From NAVER's nVoice example (https://www.facebook.com/naverlabs/videos/422780217913446): 53 | #'감사합니다. 
Devsisters 김태훈 님의 발표였습니다.', 54 | #'이것으로 금일 마련된 track 2의 모든 세션이 종료되었습니다.', 55 | #'장시간 끝까지 참석해주신 개발자 여러분들께 진심으로 감사의 말씀을 드리며,', 56 | #'잠시 후 5시 15분부터 특정 주제에 관심 있는 사람들이 모여 자유롭게 이야기하는 오프미팅이 진행될 예정이므로', 57 | #'참여신청을 해주신 분들은 진행 요원의 안내에 따라 이동해주시기 바랍니다.', 58 | 59 | ## From Kakao's Son Seok hee example (https://www.youtube.com/watch?v=ScfdAH2otrY): 60 | #'소설가 마크 트웨인이 말했습니다.', 61 | #'인생에 가장 중요한 이틀이 있는데, 하나는 세상에 태어난 날이고 다른 하나는 왜 이 세상에 왔는가를 깨닫는 날이다.', 62 | #'그런데 그 첫번째 날은 누구나 다 알지만 두번째 날은 참 어려운 것 같습니다.', 63 | #'누구나 그 두번째 날을 만나기 위해 애쓰는게 삶인지도 모르겠습니다.', 64 | #'뉴스룸도 그런 면에서 똑같습니다.', 65 | #'저희들도 그 두번째의 날을 만나고 기억하기 위해 매일 매일 최선을 다하겠습니다.', 66 | ] 67 | 68 | 69 | def get_output_base_path(load_path, eval_dirname="eval"): 70 | if not os.path.isdir(load_path): 71 | base_dir = os.path.dirname(load_path) 72 | else: 73 | base_dir = load_path 74 | 75 | base_dir = os.path.join(base_dir, eval_dirname) 76 | if os.path.exists(base_dir): 77 | backup_file(base_dir) 78 | makedirs(base_dir) 79 | 80 | m = re.compile(r'.*?\.ckpt\-([0-9]+)').match(load_path) 81 | base_path = os.path.join(base_dir, 82 | 'eval-%d' % int(m.group(1)) if m else 'eval') 83 | return base_path 84 | 85 | 86 | def run_eval(args): 87 | print(hparams_debug_string()) 88 | 89 | load_paths = glob(args.load_path_pattern) 90 | 91 | for load_path in load_paths: 92 | if not os.path.exists(os.path.join(load_path, "checkpoint")): 93 | print(" [!] Skip non model directory: {}".format(load_path)) 94 | continue 95 | 96 | synth = Synthesizer() 97 | synth.load(load_path) 98 | 99 | for speaker_id in range(synth.num_speakers): 100 | base_path = get_output_base_path(load_path, "eval-{}".format(speaker_id)) 101 | 102 | inputs, input_lengths = create_batch_inputs_from_texts(texts) 103 | 104 | for idx in range(math.ceil(len(inputs) / args.batch_size)): 105 | start_idx, end_idx = idx*args.batch_size, (idx+1)*args.batch_size 106 | 107 | cur_texts = texts[start_idx:end_idx] 108 | cur_inputs = inputs[start_idx:end_idx] 109 | 110 | synth.synthesize( 111 | texts=cur_texts, 112 | speaker_ids=[speaker_id] * len(cur_texts), 113 | tokens=cur_inputs, 114 | base_path="{}-{}".format(base_path, idx), 115 | manual_attention_mode=args.manual_attention_mode, 116 | base_alignment_path=args.base_alignment_path, 117 | ) 118 | 119 | synth.close() 120 | 121 | def main(): 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('--batch_size', default=16) 124 | parser.add_argument('--load_path_pattern', required=True) 125 | parser.add_argument('--base_alignment_path', default=None) 126 | parser.add_argument('--manual_attention_mode', default=0, type=int, 127 | help="0: None, 1: Argmax, 2: Sharpening, 3. 
Pruning") 128 | parser.add_argument('--hparams', default='', 129 | help='Hyperparameter overrides as a comma-separated list of name=value pairs') 130 | args = parser.parse_args() 131 | 132 | #hparams.max_iters = 100 133 | #hparams.parse(args.hparams) 134 | run_eval(args) 135 | 136 | 137 | if __name__ == '__main__': 138 | main() 139 | -------------------------------------------------------------------------------- /audio/__init__.py: -------------------------------------------------------------------------------- 1 | # Code based on https://github.com/keithito/tacotron/blob/master/util/audio.py 2 | import math 3 | import numpy as np 4 | import tensorflow as tf 5 | from scipy import signal 6 | from hparams import hparams 7 | 8 | import librosa 9 | import librosa.filters 10 | 11 | 12 | def load_audio(path, pre_silence_length=0, post_silence_length=0): 13 | audio = librosa.core.load(path, sr=hparams.sample_rate)[0] 14 | if pre_silence_length > 0 or post_silence_length > 0: 15 | audio = np.concatenate([ 16 | get_silence(pre_silence_length), 17 | audio, 18 | get_silence(post_silence_length), 19 | ]) 20 | return audio 21 | 22 | def save_audio(audio, path, sample_rate=None): 23 | audio *= 32767 / max(0.01, np.max(np.abs(audio))) 24 | librosa.output.write_wav(path, audio.astype(np.int16), 25 | hparams.sample_rate if sample_rate is None else sample_rate) 26 | 27 | print(" [*] Audio saved: {}".format(path)) 28 | 29 | 30 | def resample_audio(audio, target_sample_rate): 31 | return librosa.core.resample( 32 | audio, hparams.sample_rate, target_sample_rate) 33 | 34 | 35 | def get_duration(audio): 36 | return librosa.core.get_duration(audio, sr=hparams.sample_rate) 37 | 38 | 39 | def frames_to_hours(n_frames): 40 | return sum((n_frame for n_frame in n_frames)) * \ 41 | hparams.frame_shift_ms / (3600 * 1000) 42 | 43 | 44 | def get_silence(sec): 45 | return np.zeros(hparams.sample_rate * sec) 46 | 47 | 48 | def spectrogram(y): 49 | D = _stft(_preemphasis(y)) 50 | S = _amp_to_db(np.abs(D)) - hparams.ref_level_db 51 | return _normalize(S) 52 | 53 | 54 | def inv_spectrogram(spectrogram): 55 | S = _db_to_amp(_denormalize(spectrogram) + hparams.ref_level_db) # Convert back to linear 56 | return inv_preemphasis(_griffin_lim(S ** hparams.power)) # Reconstruct phase 57 | 58 | 59 | def inv_spectrogram_tensorflow(spectrogram): 60 | S = _db_to_amp_tensorflow(_denormalize_tensorflow(spectrogram) + hparams.ref_level_db) 61 | return _griffin_lim_tensorflow(tf.pow(S, hparams.power)) 62 | 63 | 64 | def melspectrogram(y): 65 | D = _stft(_preemphasis(y)) 66 | S = _amp_to_db(_linear_to_mel(np.abs(D))) 67 | return _normalize(S) 68 | 69 | 70 | def inv_melspectrogram(melspectrogram): 71 | S = _mel_to_linear(_db_to_amp(_denormalize(melspectrogram))) # Convert back to linear 72 | return inv_preemphasis(_griffin_lim(S ** hparams.power)) # Reconstruct phase 73 | 74 | 75 | # Based on https://github.com/librosa/librosa/issues/434 76 | def _griffin_lim(S): 77 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) 78 | S_complex = np.abs(S).astype(np.complex) 79 | 80 | y = _istft(S_complex * angles) 81 | for i in range(hparams.griffin_lim_iters): 82 | angles = np.exp(1j * np.angle(_stft(y))) 83 | y = _istft(S_complex * angles) 84 | return y 85 | 86 | 87 | def _griffin_lim_tensorflow(S): 88 | with tf.variable_scope('griffinlim'): 89 | S = tf.expand_dims(S, 0) 90 | S_complex = tf.identity(tf.cast(S, dtype=tf.complex64)) 91 | y = _istft_tensorflow(S_complex) 92 | for i in range(hparams.griffin_lim_iters): 93 | est = _stft_tensorflow(y) 
94 | angles = est / tf.cast(tf.maximum(1e-8, tf.abs(est)), tf.complex64) 95 | y = _istft_tensorflow(S_complex * angles) 96 | return tf.squeeze(y, 0) 97 | 98 | 99 | def _stft(y): 100 | n_fft, hop_length, win_length = _stft_parameters() 101 | return librosa.stft(y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length) 102 | 103 | 104 | def _istft(y): 105 | _, hop_length, win_length = _stft_parameters() 106 | return librosa.istft(y, hop_length=hop_length, win_length=win_length) 107 | 108 | 109 | def _stft_tensorflow(signals): 110 | n_fft, hop_length, win_length = _stft_parameters() 111 | return tf.contrib.signal.stft(signals, win_length, hop_length, n_fft, pad_end=False) 112 | 113 | 114 | def _istft_tensorflow(stfts): 115 | n_fft, hop_length, win_length = _stft_parameters() 116 | return tf.contrib.signal.inverse_stft(stfts, win_length, hop_length, n_fft) 117 | 118 | def _stft_parameters(): 119 | n_fft = (hparams.num_freq - 1) * 2 120 | hop_length = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate) 121 | win_length = int(hparams.frame_length_ms / 1000 * hparams.sample_rate) 122 | return n_fft, hop_length, win_length 123 | 124 | 125 | # Conversions: 126 | 127 | _mel_basis = None 128 | _inv_mel_basis = None 129 | 130 | def _linear_to_mel(spectrogram): 131 | global _mel_basis 132 | if _mel_basis is None: 133 | _mel_basis = _build_mel_basis() 134 | return np.dot(_mel_basis, spectrogram) 135 | 136 | def _mel_to_linear(mel_spectrogram): 137 | global _inv_mel_basis 138 | if _inv_mel_basis is None: 139 | _inv_mel_basis = np.linalg.pinv(_build_mel_basis()) 140 | return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram)) 141 | 142 | def _build_mel_basis(): 143 | n_fft = (hparams.num_freq - 1) * 2 144 | return librosa.filters.mel(hparams.sample_rate, n_fft, n_mels=hparams.num_mels) 145 | 146 | def _amp_to_db(x): 147 | return 20 * np.log10(np.maximum(1e-5, x)) 148 | 149 | def _db_to_amp(x): 150 | return np.power(10.0, x * 0.05) 151 | 152 | def _db_to_amp_tensorflow(x): 153 | return tf.pow(tf.ones(tf.shape(x)) * 10.0, x * 0.05) 154 | 155 | def _preemphasis(x): 156 | return signal.lfilter([1, -hparams.preemphasis], [1], x) 157 | 158 | def inv_preemphasis(x): 159 | return signal.lfilter([1], [1, -hparams.preemphasis], x) 160 | 161 | def _normalize(S): 162 | return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1) 163 | 164 | def _denormalize(S): 165 | return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db 166 | 167 | def _denormalize_tensorflow(S): 168 | return (tf.clip_by_value(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db 169 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Speaker Tacotron in TensorFlow 2 | 3 | TensorFlow implementation of: 4 | 5 | - [Deep Voice 2: Multi-Speaker Neural Text-to-Speech](https://arxiv.org/abs/1705.08947) 6 | - [Listening while Speaking: Speech Chain by Deep Learning](https://arxiv.org/abs/1707.04879) 7 | - [Tacotron: Towards End-to-End Speech Synthesis](https://arxiv.org/abs/1703.10135) 8 | 9 | Samples audios (in Korean) can be found [here](http://carpedm20.github.io/tacotron/en.html). 10 | 11 | ![model](./assets/model.png) 12 | 13 | 14 | ## Prerequisites 15 | 16 | - Python 3.6+ 17 | - FFmpeg 18 | - [Tensorflow 1.3](https://www.tensorflow.org/install/) 19 | 20 | 21 | ## Usage 22 | 23 | ### 1. 
Install prerequisites 24 | 25 | After preparing [Tensorflow](https://www.tensorflow.org/install/), install prerequisites with: 26 | 27 | pip3 install -r requirements.txt 28 | python -c "import nltk; nltk.download('punkt')" 29 | 30 | If you want to synthesize a speech in Korean dicrectly, follow [2-3. Download pre-trained models](#2-3-download-pre-trained-models). 31 | 32 | 33 | ### 2-1. Generate custom datasets 34 | 35 | The `datasets` directory should look like: 36 | 37 | datasets 38 | ├── son 39 | │ ├── alignment.json 40 | │ └── audio 41 | │ ├── 1.mp3 42 | │ ├── 2.mp3 43 | │ ├── 3.mp3 44 | │ └── ... 45 | └── YOUR_DATASET 46 | ├── alignment.json 47 | └── audio 48 | ├── 1.mp3 49 | ├── 2.mp3 50 | ├── 3.mp3 51 | └── ... 52 | 53 | and `YOUR_DATASET/alignment.json` should look like: 54 | 55 | { 56 | "./datasets/YOUR_DATASET/audio/001.mp3": "My name is Taehoon Kim.", 57 | "./datasets/YOUR_DATASET/audio/002.mp3": "The buses aren't the problem.", 58 | "./datasets/YOUR_DATASET/audio/003.mp3": "They have discovered a new particle.", 59 | } 60 | 61 | After you prepare as described, you should genearte preprocessed data with: 62 | 63 | python3 -m datasets.generate_data ./datasets/YOUR_DATASET/alignment.json 64 | 65 | 66 | ### 2-2. Generate Korean datasets 67 | 68 | Follow below commands. (explain with `son` dataset) 69 | 70 | 0. To automate an alignment between sounds and texts, prepare `GOOGLE_APPLICATION_CREDENTIALS` to use [Google Speech Recognition API](https://cloud.google.com/speech/). To get credentials, read [this](https://developers.google.com/identity/protocols/application-default-credentials). 71 | 72 | export GOOGLE_APPLICATION_CREDENTIALS="YOUR-GOOGLE.CREDENTIALS.json" 73 | 74 | 1. Download speech(or video) and text. 75 | 76 | python3 -m datasets.son.download 77 | 78 | 2. Segment all audios on silence. 79 | 80 | python3 -m audio.silence --audio_pattern "./datasets/son/audio/*.wav" --method=pydub 81 | 82 | 3. By using [Google Speech Recognition API](https://cloud.google.com/speech/), we predict sentences for all segmented audios. 83 | 84 | python3 -m recognition.google --audio_pattern "./datasets/son/audio/*.*.wav" 85 | 86 | 4. By comparing original text and recognised text, save `audio<->text` pair information into `./datasets/son/alignment.json`. 87 | 88 | python3 -m recognition.alignment --recognition_path "./datasets/son/recognition.json" --score_threshold=0.5 89 | 90 | 5. Finally, generated numpy files which will be used in training. 91 | 92 | python3 -m datasets.generate_data ./datasets/son/alignment.json 93 | 94 | Because the automatic generation is extremely naive, the dataset is noisy. However, if you have enough datasets (20+ hours with random initialization or 5+ hours with pretrained model initialization), you can expect an acceptable quality of audio synthesis. 95 | 96 | ### 2-3. Generate English datasets 97 | 98 | 1. Download speech dataset at https://keithito.com/LJ-Speech-Dataset/ 99 | 100 | 2. Convert metadata CSV file to json file. (arguments are available for changing preferences) 101 | 102 | python3 -m datasets.LJSpeech_1_0.prepare 103 | 104 | 3. Finally, generate numpy files which will be used in training. 105 | 106 | python3 -m datasets.generate_data ./datasets/LJSpeech_1_0 107 | 108 | 109 | ### 3. Train a model 110 | 111 | The important hyperparameters for a models are defined in `hparams.py`. 
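The fields referenced by the audio and text code include `cleaners`, `sample_rate`, `num_freq`, `num_mels`, `frame_shift_ms`, `frame_length_ms`, `preemphasis`, `ref_level_db`, `min_level_db`, `power`, `griffin_lim_iters`, and `model_type`. As a minimal sketch of what such a definition looks like (assuming `hparams.py` builds a `tf.contrib.training.HParams` object, which the `to_json()` call in `utils/__init__.py` suggests; the values below are placeholders, not the project defaults):

    import tensorflow as tf

    hparams = tf.contrib.training.HParams(
        cleaners='korean_cleaners',   # 'english_cleaners' for English data
        sample_rate=22050,            # placeholder; must match your audio files
        num_mels=80,                  # placeholder
        griffin_lim_iters=60,         # placeholder
        model_type='deepvoice',       # 'deepvoice' or 'simple' for multi-speaker training
    )

The same `hparams` object is imported by the audio utilities, so whatever values you choose apply to preprocessing, training, and synthesis alike.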
112 | 113 | (**Change `cleaners` in `hparams.py` from `korean_cleaners` to `english_cleaners` to train with English dataset**) 114 | 115 | To train a single-speaker model: 116 | 117 | python3 train.py --data_path=datasets/son 118 | python3 train.py --data_path=datasets/son --initialize_path=PATH_TO_CHECKPOINT 119 | 120 | To train a multi-speaker model: 121 | 122 | # after change `model_type` in `hparams.py` to `deepvoice` or `simple` 123 | python3 train.py --data_path=datasets/son1,datasets/son2 124 | 125 | To restart a training from previous experiments such as `logs/son-20171015`: 126 | 127 | python3 train.py --data_path=datasets/son --load_path logs/son-20171015 128 | 129 | If you don't have good and enough (10+ hours) dataset, it would be better to use `--initialize_path` to use a well-trained model as initial parameters. 130 | 131 | 132 | ### 4. Synthesize audio 133 | 134 | You can train your own models with: 135 | 136 | python3 app.py --load_path logs/son-20171015 --num_speakers=1 137 | 138 | or generate audio directly with: 139 | 140 | python3 synthesizer.py --load_path logs/son-20171015 --text "이거 실화냐?" 141 | 142 | ### 4-1. Synthesizing non-korean(english) audio 143 | 144 | For generating non-korean audio, you must set the argument --is_korean False. 145 | 146 | python3 app.py --load_path logs/LJSpeech_1_0-20180108 --num_speakers=1 --is_korean=False 147 | python3 synthesizer.py --load_path logs/LJSpeech_1_0-20180108 --text="Winter is coming." --is_korean=False 148 | 149 | ## Results 150 | 151 | Training attention on single speaker model: 152 | 153 | ![model](./assets/attention_single_speaker.gif) 154 | 155 | Training attention on multi speaker model: 156 | 157 | ![model](./assets/attention_multi_speaker.gif) 158 | 159 | 160 | ## Disclaimer 161 | 162 | This is not an official [DEVSISTERS](http://devsisters.com/) product. This project is not responsible for misuse or for any damage that you may cause. You agree that you use this software at your own risk. 163 | 164 | 165 | ## References 166 | 167 | - [Keith Ito](https://github.com/keithito)'s [tacotron](https://github.com/keithito/tacotron) 168 | - [DEVIEW 2017 presentation](https://www.slideshare.net/carpedm20/deview-2017-80824162) 169 | 170 | 171 | ## Author 172 | 173 | Taehoon Kim / [@carpedm20](http://carpedm20.github.io/) 174 | -------------------------------------------------------------------------------- /recognition/alignment.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import string 4 | import argparse 5 | import operator 6 | from functools import partial 7 | from difflib import SequenceMatcher 8 | 9 | from audio.get_duration import get_durations 10 | from text import remove_puncuations, text_to_sequence 11 | from utils import load_json, write_json, parallel_run, remove_postfix, backup_file 12 | 13 | def plain_text(text): 14 | return "".join(remove_puncuations(text.strip()).split()) 15 | 16 | def add_punctuation(text): 17 | if text.endswith('다'): 18 | return text + "." 
19 | else: 20 | return text 21 | 22 | def similarity(text_a, text_b): 23 | text_a = plain_text(text_a) 24 | text_b = plain_text(text_b) 25 | 26 | score = SequenceMatcher(None, text_a, text_b).ratio() 27 | return score 28 | 29 | def first_word_combined_words(text): 30 | words = text.split() 31 | if len(words) > 1: 32 | first_words = [words[0], words[0]+words[1]] 33 | else: 34 | first_words = [words[0]] 35 | return first_words 36 | 37 | def first_word_combined_texts(text): 38 | words = text.split() 39 | if len(words) > 1: 40 | if len(words) > 2: 41 | text2 = " ".join([words[0]+words[1]] + words[2:]) 42 | else: 43 | text2 = words[0]+words[1] 44 | texts = [text, text2] 45 | else: 46 | texts = [text] 47 | return texts 48 | 49 | def search_optimal(found_text, recognition_text): 50 | # 1. found_text is usually more accurate 51 | # 2. recognition_text can have more or less word 52 | 53 | optimal = None 54 | 55 | if plain_text(recognition_text) in plain_text(found_text): 56 | optimal = recognition_text 57 | else: 58 | found = False 59 | 60 | for tmp_text in first_word_combined_texts(found_text): 61 | for recognition_first_word in first_word_combined_words(recognition_text): 62 | if recognition_first_word in tmp_text: 63 | start_idx = tmp_text.find(recognition_first_word) 64 | 65 | if tmp_text != found_text: 66 | found_text = found_text[max(0, start_idx-1):].strip() 67 | else: 68 | found_text = found_text[start_idx:].strip() 69 | found = True 70 | break 71 | 72 | if found: 73 | break 74 | 75 | recognition_last_word = recognition_text.split()[-1] 76 | if recognition_last_word in found_text: 77 | end_idx = found_text.find(recognition_last_word) 78 | 79 | punctuation = "" 80 | if len(found_text) > end_idx + len(recognition_last_word): 81 | punctuation = found_text[end_idx + len(recognition_last_word)] 82 | if punctuation not in string.punctuation: 83 | punctuation = "" 84 | 85 | found_text = found_text[:end_idx] + recognition_last_word + punctuation 86 | found = True 87 | 88 | if found: 89 | optimal = found_text 90 | 91 | return optimal 92 | 93 | 94 | def align_text_fn( 95 | item, score_threshold, debug=False): 96 | 97 | audio_path, recognition_text = item 98 | 99 | audio_dir = os.path.dirname(audio_path) 100 | base_dir = os.path.dirname(audio_dir) 101 | 102 | news_path = remove_postfix(audio_path.replace("audio", "assets")) 103 | news_path = os.path.splitext(news_path)[0] + ".txt" 104 | 105 | strip_fn = lambda line: line.strip().replace('"', '').replace("'", "") 106 | candidates = [strip_fn(line) for line in open(news_path, encoding='cp949').readlines()] 107 | 108 | scores = { candidate: similarity(candidate, recognition_text) \ 109 | for candidate in candidates} 110 | sorted_scores = sorted(scores.items(), key=operator.itemgetter(1))[::-1] 111 | 112 | first, second = sorted_scores[0], sorted_scores[1] 113 | 114 | if first[1] > second[1] and first[1] >= score_threshold: 115 | found_text, score = first 116 | aligned_text = search_optimal(found_text, recognition_text) 117 | 118 | if debug: 119 | print(" ", audio_path) 120 | print(" ", recognition_text) 121 | print("=> ", found_text) 122 | print("==>", aligned_text) 123 | print("="*30) 124 | 125 | if aligned_text is not None: 126 | result = { audio_path: add_punctuation(aligned_text) } 127 | elif abs(len(text_to_sequence(found_text)) - len(text_to_sequence(recognition_text))) > 10: 128 | result = {} 129 | else: 130 | result = { audio_path: [add_punctuation(found_text), recognition_text] } 131 | else: 132 | result = {} 133 | 134 | if len(result) == 0: 135 
| result = { audio_path: [recognition_text] } 136 | 137 | return result 138 | 139 | def align_text_batch(config): 140 | align_text = partial(align_text_fn, 141 | score_threshold=config.score_threshold) 142 | 143 | results = {} 144 | data = load_json(config.recognition_path, encoding=config.recognition_encoding) 145 | 146 | items = parallel_run( 147 | align_text, data.items(), 148 | desc="align_text_batch", parallel=True) 149 | 150 | for item in items: 151 | results.update(item) 152 | 153 | found_count = sum([type(value) == str for value in results.values()]) 154 | print(" [*] # found: {:.5f}% ({}/{})".format( 155 | len(results)/len(data), len(results), len(data))) 156 | print(" [*] # exact match: {:.5f}% ({}/{})".format( 157 | found_count/len(items), found_count, len(items))) 158 | 159 | return results 160 | 161 | if __name__ == '__main__': 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument('--recognition_path', required=True) 164 | parser.add_argument('--alignment_filename', default="alignment.json") 165 | parser.add_argument('--score_threshold', default=0.4, type=float) 166 | parser.add_argument('--recognition_encoding', default='949') 167 | config, unparsed = parser.parse_known_args() 168 | 169 | results = align_text_batch(config) 170 | 171 | base_dir = os.path.dirname(config.recognition_path) 172 | alignment_path = \ 173 | os.path.join(base_dir, config.alignment_filename) 174 | 175 | if os.path.exists(alignment_path): 176 | backup_file(alignment_path) 177 | 178 | write_json(alignment_path, results) 179 | duration = get_durations(results.keys(), print_detail=False) 180 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.rnn import GRUCell 3 | from tensorflow.python.layers import core 4 | from tensorflow.contrib.seq2seq.python.ops.attention_wrapper \ 5 | import _bahdanau_score, _BaseAttentionMechanism, BahdanauAttention, \ 6 | AttentionWrapper, AttentionWrapperState 7 | 8 | 9 | def get_embed(inputs, num_inputs, embed_size, name): 10 | embed_table = tf.get_variable( 11 | name, [num_inputs, embed_size], dtype=tf.float32, 12 | initializer=tf.truncated_normal_initializer(stddev=0.1)) 13 | return tf.nn.embedding_lookup(embed_table, inputs) 14 | 15 | 16 | def prenet(inputs, is_training, layer_sizes, drop_prob, scope=None): 17 | """ 18 | Args: 19 | inputs: input vector 20 | is_training: dropout option 21 | layer_sizes: iteration number 22 | 23 | Output: 24 | x: prenet 25 | """ 26 | x = inputs 27 | drop_rate = drop_prob if is_training else 0.0 # set dropout rate 0.5 (only training) 28 | with tf.variable_scope(scope or 'prenet'): 29 | for i, size in enumerate(layer_sizes): # iterate layer_sizes 30 | dense = tf.layers.dense(x, units=size, activation=tf.nn.relu, name='dense_%d' % (i + 1)) 31 | x = tf.layers.dropout(dense, rate=drop_rate, training=is_training, name='dropout_%d' % (i + 1)) 32 | return x 33 | 34 | 35 | def encoder_cbhg(inputs, input_lengths, is_training, depth): 36 | """ 37 | Args: 38 | inputs: input tensor 39 | input_lengths: length of input tensor 40 | is_training: Batch Normalization option in Conv1D 41 | depth: dimensionality option of Highway net and Bidirectical GRU's output 42 | 43 | Output: 44 | cbhg function 45 | """ 46 | input_channels = inputs.get_shape()[2] # 3rd element of inputs' shape 47 | return cbhg( 48 | inputs, 49 | input_lengths, 50 | is_training, 51 | 
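        # encoder CBHG settings: conv bank with kernel sizes 1..16, two projection layers (128 -> input channels)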
scope='encoder_cbhg', 52 | K=16, 53 | projections=[128, input_channels], 54 | depth=depth) 55 | 56 | 57 | def post_cbhg(inputs, input_dim, is_training, depth): 58 | """ 59 | Args: 60 | inputs: input tensor 61 | input_dim: dimension of input tensor 62 | is_training: Batch Normalization option in Conv1D 63 | depth: dimensionality option of Highway net and Bidirectical GRU's output 64 | 65 | Output: 66 | cbhg function 67 | """ 68 | return cbhg( 69 | inputs, 70 | None, 71 | is_training, 72 | scope='post_cbhg', 73 | K=8, 74 | projections=[256, input_dim], 75 | depth=depth) 76 | 77 | 78 | def cbhg(inputs, input_lengths, is_training, bank_size, bank_channel_size, 79 | maxpool_width, highway_depth, rnn_size, proj_sizes, proj_width, scope, 80 | before_highway = None, encoder_rnn_init_state = None): 81 | """ 82 | Args: 83 | inputs: input tensor 84 | input_lengths: length of input tensor 85 | is_training: Batch Normalization option in Conv1D 86 | scope: network or model name 87 | K: kernel size range 88 | projections: projection layers option 89 | depth: dimensionality option of Highway net and Bidirectical GRU's output 90 | The layers in the code are staked in the order in which they came out. 91 | """ 92 | 93 | batch_size = tf.shape(inputs)[0] 94 | with tf.variable_scope(scope): 95 | with tf.variable_scope('conv_bank'): 96 | 97 | conv_outputs = tf.concat( 98 | [conv1d(inputs, k, 128, tf.nn.relu, is_training, 'conv1d_%d' % k) for k in range(1, bank_size + 1)], #1D Convolution layers using multiple types of Convolution Kernel. 99 | axis=-1 #Iterate K with increasing filter size by 1. 100 | )# Convolution bank: concatenate on the last axis to stack channels from all convolutions 101 | 102 | # Maxpooling: 103 | maxpool_output = tf.layers.max_pooling1d( 104 | conv_outputs, 105 | pool_size=maxpool_width, 106 | strides=1, 107 | padding='same') #1D Maxpooling layer(strides=1, width=2) 108 | 109 | # Two projection layers: 110 | proj1_output = conv1d(maxpool_output, proj_width, projections[0], tf.nn.relu, is_training, 'proj_1')#1st Conv1D projections 111 | proj2_output = conv1d(proj1_output, proj_width, projections[1], None, is_training, 'proj_2')#2nd Conv1D projections 112 | 113 | # Residual connection: 114 | if before_highway is not None: 115 | expanded_before_highway = tf.expand_dims(before_highway, [1]) 116 | tiled_before_highway = tf.tile( 117 | expanded_before_highway, [1, tf.shape(proj2_out)[1], 1]) 118 | highway_input = proj2_out + inputs + tiled_before_highway 119 | 120 | else: 121 | highway_input = proj2_out + inputs 122 | 123 | # Handle dimensionality mismatch: 124 | if highway_input.shape[2] != rnn_size: 125 | highway_input = tf.layers.dense(highway_input, rnn_size) 126 | 127 | # 4-layer HighwayNet: 128 | for idx in range(highway_depth): 129 | highway_input = highwaynet(highway_input, 'highway_%d' % (idx+1)) #make 4 Highway net layers 130 | rnn_input = highway_input 131 | 132 | # Bidirectional RNN 133 | if encoder_rnn_init_state is not None: 134 | initial_state_fw, initial_state_bw = tf.split(encoder_rnn_init_state, 2, 1) 135 | else: 136 | initial_state_fw, initial_state_bw = None, None 137 | 138 | outputs, states = tf.nn.bidirectional_dynamic_rnn( #make Bidirectional GRU 139 | GRUCell(rnn_size), 140 | GRUCell(rnn_size), 141 | rnn_input, 142 | sequence_length=input_lengths, 143 | initial_state_fw=initial_state_fw, 144 | initial_state_bw=initial_state_bw, 145 | dtype=tf.float32) 146 | return tf.concat(outputs, axis=2) # Concat forward sequence and backward sequence 147 | 148 | 149 | def 
batch_tile(tensor, batch_size): 150 | expaneded_tensor = tf.expand_dims(tensor, [0]) 151 | return tf.tile(expaneded_tensor, \ 152 | [batch_size] + [1 for _ in tensor.get_shape()]) 153 | 154 | 155 | def highwaynet(inputs, scope): 156 | highway_dim = int(inputs.get_shape()[-1]) 157 | 158 | with tf.variable_scope(scope): 159 | H = tf.layers.dense( 160 | inputs, 161 | units=highway_dim, 162 | activation=tf.nn.relu, 163 | name='H') 164 | T = tf.layers.dense( 165 | inputs, 166 | units=highway_dim, 167 | activation=tf.nn.sigmoid, 168 | name='T', 169 | bias_initializer=tf.constant_initializer(-1.0)) 170 | return H * T + inputs * (1.0 - T) 171 | 172 | 173 | def conv1d(inputs, kernel_size, channels, activation, is_training, scope): 174 | """ 175 | Args: 176 | inputs: input tensor 177 | kernel_size: length of the 1D convolution window 178 | channels: dimensionality of the output space 179 | activation: Activation function (None means linear activation) 180 | is_training: Batch Normalization option in Conv1D 181 | scope: namespace 182 | 183 | Output: 184 | output tensor 185 | """ 186 | with tf.variable_scope(scope): 187 | conv1d_output = tf.layers.conv1d( # creates a convolution kernel 188 | inputs, 189 | filters=channels, 190 | kernel_size=kernel_size, 191 | activation=activation, 192 | padding='same') # return output tensor 193 | return tf.layers.batch_normalization(conv1d_output, training=is_training) 194 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | import json 5 | import requests 6 | import subprocess 7 | from tqdm import tqdm 8 | from contextlib import closing 9 | from multiprocessing import Pool 10 | from collections import namedtuple 11 | from datetime import datetime, timedelta 12 | from shutil import copyfile as copy_file 13 | 14 | PARAMS_NAME = "params.json" 15 | 16 | class ValueWindow(): 17 | def __init__(self, window_size=100): 18 | self._window_size = window_size 19 | self._values = [] 20 | 21 | def append(self, x): 22 | self._values = self._values[-(self._window_size - 1):] + [x] 23 | 24 | @property 25 | def sum(self): 26 | return sum(self._values) 27 | 28 | @property 29 | def count(self): 30 | return len(self._values) 31 | 32 | @property 33 | def average(self): 34 | return self.sum / max(1, self.count) 35 | 36 | def reset(self): 37 | self._values = [] 38 | 39 | def prepare_dirs(config, hparams): 40 | if hasattr(config, "data_paths"): 41 | config.datasets = [ 42 | os.path.basename(data_path) for data_path in config.data_paths] 43 | dataset_desc = "+".join(config.datasets) 44 | 45 | if config.load_path: 46 | config.model_dir = config.load_path 47 | else: 48 | config.model_name = "{}_{}".format(dataset_desc, get_time()) 49 | config.model_dir = os.path.join(config.log_dir, config.model_name) 50 | 51 | for path in [config.log_dir, config.model_dir]: 52 | if not os.path.exists(path): 53 | os.makedirs(path) 54 | 55 | if config.load_path: 56 | load_hparams(hparams, config.model_dir) 57 | else: 58 | setattr(hparams, "num_speakers", len(config.datasets)) 59 | 60 | save_hparams(config.model_dir, hparams) 61 | copy_file("hparams.py", os.path.join(config.model_dir, "hparams.py")) 62 | 63 | def makedirs(path): 64 | if not os.path.exists(path): 65 | print(" [*] Make directories : {}".format(path)) 66 | os.makedirs(path) 67 | 68 | def remove_file(path): 69 | if os.path.exists(path): 70 | print(" [*] Removed: 
{}".format(path)) 71 | os.remove(path) 72 | 73 | def backup_file(path): 74 | root, ext = os.path.splitext(path) 75 | new_path = "{}.backup_{}{}".format(root, get_time(), ext) 76 | 77 | os.rename(path, new_path) 78 | print(" [*] {} has backup: {}".format(path, new_path)) 79 | 80 | def get_time(): 81 | return datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 82 | 83 | def write_json(path, data): 84 | with open(path, 'w',encoding='utf-8') as f: 85 | json.dump(data, f, indent=4, sort_keys=True, ensure_ascii=False) 86 | 87 | def load_json(path, as_class=False, encoding='euc-kr'): 88 | with open(path,encoding=encoding) as f: 89 | content = f.read() 90 | content = re.sub(",\s*}", "}", content) 91 | content = re.sub(",\s*]", "]", content) 92 | 93 | if as_class: 94 | data = json.loads(content, object_hook=\ 95 | lambda data: namedtuple('Data', data.keys())(*data.values())) 96 | else: 97 | data = json.loads(content) 98 | #print(data) 99 | return data 100 | 101 | def save_hparams(model_dir, hparams): 102 | param_path = os.path.join(model_dir, PARAMS_NAME) 103 | 104 | info = eval(hparams.to_json(). \ 105 | replace('true', 'True').replace('false', 'False')) 106 | write_json(param_path, info) 107 | 108 | print(" [*] MODEL dir: {}".format(model_dir)) 109 | print(" [*] PARAM path: {}".format(param_path)) 110 | 111 | def load_hparams(hparams, load_path, skip_list=[]): 112 | path = os.path.join(load_path, PARAMS_NAME) 113 | 114 | new_hparams = load_json(path) 115 | hparams_keys = vars(hparams).keys() 116 | 117 | for key, value in new_hparams.items(): 118 | if key in skip_list or key not in hparams_keys: 119 | print("Skip {} because it not exists".format(key)) 120 | continue 121 | 122 | if key not in ['job_name', 'num_workers', 'display', 'is_train', 'load_path'] or \ 123 | key == "pointer_load_path": 124 | original_value = getattr(hparams, key) 125 | if original_value != value: 126 | print("UPDATE {}: {} -> {}".format(key, getattr(hparams, key), value)) 127 | setattr(hparams, key, value) 128 | 129 | def add_prefix(path, prefix): 130 | dir_path, filename = os.path.dirname(path), os.path.basename(path) 131 | return "{}/{}.{}".format(dir_path, prefix, filename) 132 | 133 | def add_postfix(path, postfix): 134 | path_without_ext, ext = path.rsplit('.', 1) 135 | return "{}.{}.{}".format(path_without_ext, postfix, ext) 136 | 137 | def remove_postfix(path): 138 | items = path.rsplit('.', 2) 139 | return items[0] + "." 
+ items[2] 140 | 141 | def parallel_run(fn, items, desc="", parallel=True): 142 | results = [] 143 | 144 | if parallel: 145 | with closing(Pool()) as pool: 146 | for out in tqdm(pool.imap_unordered( 147 | fn, items), total=len(items), desc=desc): 148 | if out is not None: 149 | results.append(out) 150 | else: 151 | for item in tqdm(items, total=len(items), desc=desc): 152 | out = fn(item) 153 | if out is not None: 154 | results.append(out) 155 | 156 | return results 157 | 158 | def which(program): 159 | if os.name == "nt" and not program.endswith(".exe"): 160 | program += ".exe" 161 | 162 | envdir_list = [os.curdir] + os.environ["PATH"].split(os.pathsep) 163 | 164 | for envdir in envdir_list: 165 | program_path = os.path.join(envdir, program) 166 | if os.path.isfile(program_path) and os.access(program_path, os.X_OK): 167 | return program_path 168 | 169 | def get_encoder_name(): 170 | if which("avconv"): 171 | return "avconv" 172 | elif which("ffmpeg"): 173 | return "ffmpeg" 174 | else: 175 | return "ffmpeg" 176 | 177 | def download_with_url(url, dest_path, chunk_size=32*1024): 178 | with open(dest_path, "wb") as f: 179 | response = requests.get(url, stream=True) 180 | total_size = int(response.headers.get('content-length', 0)) 181 | 182 | for chunk in response.iter_content(chunk_size): 183 | if chunk: # filter out keep-alive new chunks 184 | f.write(chunk) 185 | return True 186 | 187 | def str2bool(v): 188 | return v.lower() in ('true', '1') 189 | 190 | def get_git_revision_hash(): 191 | return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode("utf-8") 192 | 193 | def get_git_diff(): 194 | return subprocess.check_output(['git', 'diff']).decode("utf-8") 195 | 196 | def warning(msg): 197 | print("="*40) 198 | print(" [!] {}".format(msg)) 199 | print("="*40) 200 | print() 201 | 202 | def query_yes_no(question, default=None): 203 | # Code from https://stackoverflow.com/a/3041990 204 | valid = {"yes": True, "y": True, "ye": True, 205 | "no": False, "n": False} 206 | if default is None: 207 | prompt = " [y/n] " 208 | elif default == "yes": 209 | prompt = " [Y/n] " 210 | elif default == "no": 211 | prompt = " [y/N] " 212 | else: 213 | raise ValueError("invalid default answer: '%s'" % default) 214 | 215 | while True: 216 | sys.stdout.write(question + prompt) 217 | choice = input().lower() 218 | if default is not None and choice == '': 219 | return valid[default] 220 | elif choice in valid: 221 | return valid[choice] 222 | else: 223 | sys.stdout.write("Please respond with 'yes' or 'no' " 224 | "(or 'y' or 'n').\n") -------------------------------------------------------------------------------- /text/korean.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | # Code based on carpedm20 3 | 4 | import re 5 | import os 6 | import ast 7 | import json 8 | from jamo import hangul_to_jamo, h2j, j2h 9 | 10 | from .kor_dic import english_dictionary, etc_dictionary 11 | 12 | PAD = '_' 13 | EOS = '~' 14 | PUNC = '!\'(),-.:;?' 
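# punctuation and whitespace are kept as valid symbols, together with the jamo ranges defined below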
15 | SPACE = ' ' 16 | 17 | JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)]) 18 | JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)]) 19 | JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)]) 20 | 21 | VALID_CHARS = JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + PUNC + SPACE 22 | ALL_SYMBOLS = PAD + EOS + VALID_CHARS 23 | 24 | char_to_id = {c: i for i, c in enumerate(ALL_SYMBOLS)} 25 | id_to_char = {i: c for i, c in enumerate(ALL_SYMBOLS)} 26 | 27 | quote_checker = """([`"'"“‘])(.+?)([`"'"”’])""" 28 | 29 | 30 | def is_lead(char): 31 | return char in JAMO_LEADS 32 | 33 | 34 | def is_vowel(char): 35 | return char in JAMO_VOWELS 36 | 37 | 38 | def is_tail(char): 39 | return char in JAMO_TAILS 40 | 41 | 42 | def get_mode(char): 43 | if is_lead(char): 44 | return 0 45 | elif is_vowel(char): 46 | return 1 47 | elif is_tail(char): 48 | return 2 49 | else: 50 | return -1 51 | 52 | 53 | def _get_text_from_candidates(candidates): 54 | if len(candidates) == 0: 55 | return "" 56 | elif len(candidates) == 1: 57 | return _jamo_char_to_hcj(candidates[0]) 58 | else: 59 | return j2h(**dict(zip(["lead", "vowel", "tail"], candidates))) 60 | 61 | 62 | def jamo_to_korean(text): 63 | text = h2j(text) 64 | 65 | idx = 0 66 | new_text = "" 67 | candidates = [] 68 | 69 | while True: 70 | if idx >= len(text): 71 | new_text += _get_text_from_candidates(candidates) 72 | break 73 | 74 | char = text[idx] 75 | mode = get_mode(char) 76 | 77 | if mode == 0: 78 | new_text += _get_text_from_candidates(candidates) 79 | candidates = [char] 80 | elif mode == -1: 81 | new_text += _get_text_from_candidates(candidates) 82 | new_text += char 83 | candidates = [] 84 | else: 85 | candidates.append(char) 86 | 87 | idx += 1 88 | return new_text 89 | 90 | 91 | num_to_kor = { 92 | '0': '영', 93 | '1': '일', 94 | '2': '이', 95 | '3': '삼', 96 | '4': '사', 97 | '5': '오', 98 | '6': '육', 99 | '7': '칠', 100 | '8': '팔', 101 | '9': '구', 102 | } 103 | 104 | unit_to_kor1 = { 105 | '%': '퍼센트', 106 | 'cm': '센치미터', 107 | 'mm': '밀리미터', 108 | 'km': '킬로미터', 109 | 'kg': '킬로그람', 110 | } 111 | unit_to_kor2 = { 112 | 'm': '미터', 113 | } 114 | 115 | upper_to_kor = { 116 | 'A': '에이', 117 | 'B': '비', 118 | 'C': '씨', 119 | 'D': '디', 120 | 'E': '이', 121 | 'F': '에프', 122 | 'G': '지', 123 | 'H': '에이치', 124 | 'I': '아이', 125 | 'J': '제이', 126 | 'K': '케이', 127 | 'L': '엘', 128 | 'M': '엠', 129 | 'N': '엔', 130 | 'O': '오', 131 | 'P': '피', 132 | 'Q': '큐', 133 | 'R': '알', 134 | 'S': '에스', 135 | 'T': '티', 136 | 'U': '유', 137 | 'V': '브이', 138 | 'W': '더블유', 139 | 'X': '엑스', 140 | 'Y': '와이', 141 | 'Z': '지', 142 | } 143 | 144 | 145 | def compare_sentence_with_jamo(text1, text2): 146 | return h2j(text1) != h2j(text2) 147 | 148 | 149 | def tokenize(text, as_id=False): 150 | # jamo package에 있는 hangul_to_jamo를 이용하여 한글 string을 초성/중성/종성으로 나눈다. 
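    # (i.e. use hangul_to_jamo from the jamo package to split a Hangul string into lead/vowel/tail jamo)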
151 | text = normalize(text) 152 | tokens = list(hangul_to_jamo(text)) # '존경하는' --> ['ᄌ', 'ᅩ', 'ᆫ', 'ᄀ', 'ᅧ', 'ᆼ', 'ᄒ', 'ᅡ', 'ᄂ', 'ᅳ', 'ᆫ', '~'] 153 | 154 | if as_id: 155 | return [char_to_id[token] for token in tokens] + [char_to_id[EOS]] 156 | else: 157 | return [token for token in tokens] + [EOS] 158 | 159 | 160 | def tokenizer_fn(iterator): 161 | return (token for x in iterator for token in tokenize(x, as_id=False)) 162 | 163 | 164 | def normalize(text): 165 | text = text.strip() 166 | 167 | text = re.sub('\(\d+일\)', '', text) 168 | text = re.sub('\([⺀-⺙⺛-⻳⼀-⿕々〇〡-〩〸-〺〻㐀-䶵一-鿃豈-鶴侮-頻並-龎]+\)', '', text) 169 | 170 | text = normalize_with_dictionary(text, etc_dictionary) 171 | text = normalize_english(text) 172 | text = re.sub('[a-zA-Z]+', normalize_upper, text) 173 | 174 | text = normalize_quote(text) 175 | text = normalize_number(text) 176 | 177 | return text 178 | 179 | 180 | def normalize_with_dictionary(text, dic): 181 | if any(key in text for key in dic.keys()): 182 | pattern = re.compile('|'.join(re.escape(key) for key in dic.keys())) 183 | return pattern.sub(lambda x: dic[x.group()], text) 184 | else: 185 | return text 186 | 187 | 188 | def normalize_english(text): 189 | def fn(m): 190 | word = m.group() 191 | if word in english_dictionary: 192 | return english_dictionary.get(word) 193 | else: 194 | return word 195 | 196 | text = re.sub("([A-Za-z]+)", fn, text) 197 | return text 198 | 199 | 200 | def normalize_upper(text): 201 | text = text.group(0) 202 | 203 | if all([char.isupper() for char in text]): 204 | return "".join(upper_to_kor[char] for char in text) 205 | else: 206 | return text 207 | 208 | 209 | def normalize_quote(text): 210 | def fn(found_text): 211 | from nltk import sent_tokenize # NLTK doesn't along with multiprocessing 212 | 213 | found_text = found_text.group() 214 | unquoted_text = found_text[1:-1] 215 | 216 | sentences = sent_tokenize(unquoted_text) 217 | return " ".join(["'{}'".format(sent) for sent in sentences]) 218 | 219 | return re.sub(quote_checker, fn, text) 220 | 221 | 222 | number_checker = "([+-]?\d[\d,]*)[\.]?\d*" 223 | count_checker = "(시|명|가지|살|마리|포기|송이|수|톨|통|점|개|벌|척|채|다발|그루|자루|줄|켤레|그릇|잔|마디|상자|사람|곡|병|판)" 224 | 225 | 226 | def normalize_number(text): 227 | text = normalize_with_dictionary(text, unit_to_kor1) 228 | text = normalize_with_dictionary(text, unit_to_kor2) 229 | text = re.sub(number_checker + count_checker, 230 | lambda x: number_to_korean(x, True), text) 231 | text = re.sub(number_checker, 232 | lambda x: number_to_korean(x, False), text) 233 | return text 234 | 235 | 236 | num_to_kor1 = [""] + list("일이삼사오육칠팔구") 237 | num_to_kor2 = [""] + list("만억조경해") 238 | num_to_kor3 = [""] + list("십백천") 239 | 240 | # count_to_kor1 = [""] + ["하나","둘","셋","넷","다섯","여섯","일곱","여덟","아홉"] 241 | count_to_kor1 = [""] + ["한", "두", "세", "네", "다섯", "여섯", "일곱", "여덟", "아홉"] 242 | 243 | count_tenth_dict = { 244 | "십": "열", 245 | "두십": "스물", 246 | "세십": "서른", 247 | "네십": "마흔", 248 | "다섯십": "쉰", 249 | "여섯십": "예순", 250 | "일곱십": "일흔", 251 | "여덟십": "여든", 252 | "아홉십": "아흔", 253 | } 254 | 255 | 256 | def number_to_korean(num_str, is_count=False): 257 | if is_count: 258 | num_str, unit_str = num_str.group(1), num_str.group(2) 259 | else: 260 | num_str, unit_str = num_str.group(), "" 261 | 262 | num_str = num_str.replace(',', '') 263 | num = ast.literal_eval(num_str) 264 | 265 | if num == 0: 266 | return "영" 267 | 268 | check_float = num_str.split('.') 269 | if len(check_float) == 2: 270 | digit_str, float_str = check_float 271 | elif len(check_float) >= 3: 272 | raise 
Exception(" [!] Wrong number format") 273 | else: 274 | digit_str, float_str = check_float[0], None 275 | 276 | if is_count and float_str is not None: 277 | raise Exception(" [!] `is_count` and float number does not fit each other") 278 | 279 | digit = int(digit_str) 280 | 281 | if digit_str.startswith("-"): 282 | digit, digit_str = abs(digit), str(abs(digit)) 283 | 284 | kor = "" 285 | size = len(str(digit)) 286 | tmp = [] 287 | 288 | for i, v in enumerate(digit_str, start=1): 289 | v = int(v) 290 | 291 | if v != 0: 292 | if is_count: 293 | tmp += count_to_kor1[v] 294 | else: 295 | tmp += num_to_kor1[v] 296 | 297 | tmp += num_to_kor3[(size - i) % 4] 298 | 299 | if (size - i) % 4 == 0 and len(tmp) != 0: 300 | kor += "".join(tmp) 301 | tmp = [] 302 | kor += num_to_kor2[int((size - i) / 4)] 303 | 304 | if is_count: 305 | if kor.startswith("한") and len(kor) > 1: 306 | kor = kor[1:] 307 | 308 | if any(word in kor for word in count_tenth_dict): 309 | kor = re.sub( 310 | '|'.join(count_tenth_dict.keys()), 311 | lambda x: count_tenth_dict[x.group()], kor) 312 | 313 | if not is_count and kor.startswith("일") and len(kor) > 1: 314 | kor = kor[1:] 315 | 316 | if float_str is not None: 317 | kor += "쩜 " 318 | kor += re.sub('\d', lambda x: num_to_kor[x.group()], float_str) 319 | 320 | if num_str.startswith("+"): 321 | kor = "플러스 " + kor 322 | elif num_str.startswith("-"): 323 | kor = "마이너스 " + kor 324 | 325 | return kor + unit_str 326 | 327 | 328 | if __name__ == "__main__": 329 | def test_normalize(text): 330 | print(text) 331 | print(normalize(text)) 332 | print("=" * 30) 333 | 334 | 335 | test_normalize("JTBC는 JTBCs를 DY는 A가 Absolute") 336 | test_normalize("오늘(13일) 3,600마리 강아지가") 337 | test_normalize("60.3%") 338 | test_normalize('"저돌"(猪突) 입니다.') 339 | test_normalize('비대위원장이 지난 1월 이런 말을 했습니다. “난 그냥 산돼지처럼 돌파하는 스타일이다”') 340 | test_normalize("지금은 -12.35%였고 종류는 5가지와 19가지, 그리고 55가지였다") 341 | test_normalize("JTBC는 TH와 K 양이 2017년 9월 12일 오후 12시에 24살이 된다") 342 | print(list(hangul_to_jamo(list(hangul_to_jamo('비대위원장이 지난 1월 이런 말을 했습니다? 
“난 그냥 산돼지처럼 돌파하는 스타일이다”'))))) -------------------------------------------------------------------------------- /datasets/datafeeder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import tensorflow as tf 5 | import threading 6 | import time 7 | import traceback 8 | from collections import defaultdict 9 | from glob import glob 10 | import pprint 11 | 12 | from text import text_to_sequence 13 | from util.infolog import log 14 | from utils import parallel_run, remove_file 15 | from audio import frames_to_hours 16 | from audio.get_duration import get_durations 17 | 18 | 19 | _batches_per_group = 32 20 | # _p_cmudict = 0.5 21 | _pad = 0 22 | 23 | 24 | def get_frame(path): 25 | data = np.load(path) 26 | n_frame = data["linear"].shape[0] 27 | n_token = len(data["tokens"]) 28 | return (path, n_frame, n_token) 29 | 30 | def get_path_dict( 31 | data_dirs, hparams, config, 32 | data_type, n_test=None, 33 | rng=np.random.RandomState(123)): 34 | 35 | # Load metadata: 36 | path_dict = {} 37 | for data_dir in data_dirs: 38 | paths = glob("{}/*.npz".format(data_dir)) 39 | 40 | if data_type == 'train': 41 | rng.shuffle(paths) 42 | 43 | if not config.skip_path_filter: 44 | items = parallel_run( 45 | get_frame, paths, desc="filter_by_min_max_frame_batch", parallel=True) 46 | 47 | min_n_frame = hparams.reduction_factor * hparams.min_iters 48 | max_n_frame = hparams.reduction_factor * hparams.max_iters - hparams.reduction_factor 49 | 50 | new_items = [(path, n) for path, n, n_tokens in items \ 51 | if min_n_frame <= n <= max_n_frame and n_tokens >= hparams.min_tokens] 52 | 53 | if any(check in data_dir for check in ["son", "yuinna"]): 54 | blacklists = [".0000.", ".0001.", "NB11479580.0001"] 55 | new_items = [item for item in new_items \ 56 | if any(check not in item[0] for check in blacklists)] 57 | 58 | new_paths = [path for path, n in new_items] 59 | new_n_frames = [n for path, n in new_items] 60 | 61 | hours = frames_to_hours(new_n_frames) 62 | 63 | log(' [{}] Loaded metadata for {} examples ({:.2f} hours)'. \ 64 | format(data_dir, len(new_n_frames), hours)) 65 | log(' [{}] Max length: {}'.format(data_dir, max(new_n_frames))) 66 | log(' [{}] Min length: {}'.format(data_dir, min(new_n_frames))) 67 | else: 68 | new_paths = paths 69 | 70 | if data_type == 'train': 71 | new_paths = new_paths[:-n_test] 72 | elif data_type == 'test': 73 | new_paths = new_paths[-n_test:] 74 | else: 75 | raise Exception(" [!] 
Unkown data_type: {}".format(data_type)) 76 | 77 | path_dict[data_dir] = new_paths 78 | 79 | return path_dict 80 | 81 | 82 | class DataFeeder(threading.Thread): 83 | '''Feeds batches of data into a queue on a background thread.''' 84 | 85 | def __init__(self, coordinator, metadata_filename, hparams, config, batches_per_group, data_type, batch_size): 86 | super(DataFeeder, self).__init__() 87 | self._coord = coordinator 88 | self._hparams = hparams 89 | self._cleaner_names = [x.strip() for x in hparams.cleaners.split(',')] 90 | self._step = 0 91 | self._offset = defaultdict(lambda: 2) 92 | self._batches_per_group = batches_per_group 93 | 94 | self.rng = np.random.RandomState(config.random_seed) 95 | self.data_type = data_type 96 | self.batch_size = batch_size 97 | 98 | self.min_tokens = hparams.min_tokens 99 | self.min_n_frame = hparams.reduction_factor * hparams.min_iters 100 | self.max_n_frame = hparams.reduction_factor * hparams.max_iters - hparams.reduction_factor 101 | self.skip_path_filter = config.skip_path_filter 102 | 103 | # Load metadata: 104 | self._datadir = os.path.dirname(metadata_filename) 105 | with open(metadata_filename, encoding='utf-8') as f: 106 | self._metadata = [line.strip().split('|') for line in f] 107 | hours = sum((int(x[2]) for x in self._metadata)) * hparams.frame_shift_ms / (3600 * 1000) 108 | log('Loaded metadata for %d examples (%.2f hours)' % (len(self._metadata), hours)) 109 | 110 | # Create placeholders for inputs and targets. Don't specify batch size because we want to 111 | # be able to feed different sized batches at eval time. 112 | self._placeholders = [ 113 | tf.placeholder(tf.int32, [None, None], 'inputs'), 114 | tf.placeholder(tf.int32, [None], 'input_lengths'), 115 | tf.placeholder(tf.float32, [None], 'loss_coeff'), 116 | tf.placeholder(tf.float32, [None, None, hparams.num_mels], 'mel_targets'), 117 | tf.placeholder(tf.float32, [None, None, hparams.num_freq], 'linear_targets') 118 | ] 119 | 120 | # Create queue for buffering data: 121 | dtypes = [tf.int32, tf.int32, tf.float32, tf.float32, tf.float32] 122 | 123 | self._placeholders.append( 124 | tf.placeholder(tf.int32, [None], 'inputs'), 125 | ) 126 | dtypes.append(tf.int32) 127 | num_worker = 8 if self.data_type == 'train' else 1 128 | 129 | queue = tf.FIFOQueue(num_worker, dtypes, name='input_queue') 130 | self._enqueue_op = queue.enqueue(self._placeholders) 131 | self.inputs, self.input_lengths, self.loss_coeff, self.mel_targets, self.linear_targets, self.speaker_id = queue.dequeue() 132 | self.inputs.set_shape(self._placeholders[0].shape) 133 | self.input_lengths.set_shape(self._placeholders[1].shape) 134 | self.loss_coeff.set_shape(self._placeholders[2].shape) 135 | self.mel_targets.set_shape(self._placeholders[3].shape) 136 | self.linear_targets.set_shape(self._placeholders[4].shape) 137 | self.speaker_id.set_shape(self._placeholders[5].shape) 138 | self._cmudict = None 139 | 140 | # # Load CMUDict: If enabled, this will randomly substitute some words in the training data with 141 | # # their ARPABet equivalents, which will allow you to also pass ARPABet to the model for 142 | # # synthesis (useful for proper nouns, etc.) 
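        # (Note: the block below is left commented out. CMUDict/ARPAbet substitution only
        #  applies to English transcripts, so this pipeline keeps self._cmudict = None.)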
143 | # if hparams.use_cmudict: 144 | # cmudict_path = os.path.join(self._datadir, 'cmudict-0.7b') 145 | # if not os.path.isfile(cmudict_path): 146 | # raise Exception('If use_cmudict=True, you must download ' + 147 | # 'http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b to %s' % cmudict_path) 148 | # self._cmudict = cmudict.CMUDict(cmudict_path, keep_ambiguous=False) 149 | # log('Loaded CMUDict with %d unambiguous entries' % len(self._cmudict)) 150 | # else: 151 | # self._cmudict = None 152 | 153 | if self.data_type == 'test': 154 | examples = [] 155 | while True: 156 | for data_dir in self._datadir: 157 | examples.append(self._get_next_example(data_dir)) 158 | #print(data_dir, text.sequence_to_text(examples[-1][0], False, True)) 159 | if len(examples) >= self.batch_size: 160 | break 161 | if len(examples) >= self.batch_size: 162 | break 163 | self.static_batches = [examples for _ in range(self._batches_per_group)] 164 | 165 | else: 166 | self.static_batches = None 167 | 168 | 169 | def start_in_session(self, session, start_step): 170 | self._step = start_step 171 | self._session = session 172 | self.start() 173 | 174 | def run(self): 175 | try: 176 | while not self._coord.should_stop(): 177 | self._enqueue_next_group() 178 | except Exception as e: 179 | traceback.print_exc() 180 | self._coord.request_stop(e) 181 | 182 | def _enqueue_next_group(self): 183 | start = time.time() 184 | 185 | # Read a group of examples: 186 | n = self._hparams.batch_size 187 | r = self._hparams.reduction_factor 188 | 189 | if self.static_batches is not None: 190 | batches = self.static_batches 191 | else: 192 | examples = [] 193 | for data_dir in self._datadir: 194 | if self._hparams.initial_data_greedy: 195 | if self._step < self._hparams.initial_phase_step and \ 196 | any("krbook" in data_dir for data_dir in self._datadir): 197 | data_dir = [data_dir for data_dir in self._datadir if "krbook" in data_dir][0] 198 | 199 | if self._step < self._hparams.initial_phase_step: 200 | example = [self._get_next_example(data_dir) \ 201 | for _ in range(int(n * self._batches_per_group // len(self._datadir)))] 202 | else: 203 | example = [self._get_next_example(data_dir) \ 204 | for _ in range(int(n * self._batches_per_group * self.data_ratio[data_dir]))] 205 | examples.extend(example) 206 | examples.sort(key=lambda x: x[-1]) 207 | 208 | batches = [examples[i:i+n] for i in range(0, len(examples), n)] 209 | self.rng.shuffle(batches) 210 | 211 | log('Generated %d batches of size %d in %.03f sec' % (len(batches), n, time.time() - start)) 212 | for batch in batches: 213 | feed_dict = dict(zip(self._placeholders, _prepare_batch(batch, r, self.rng, self.data_type))) 214 | self._session.run(self._enqueue_op, feed_dict=feed_dict) 215 | self._step += 1 216 | 217 | examples = [self._get_next_example() for i in range(n * _batches_per_group)] 218 | 219 | # Bucket examples based on similar output sequence length for efficiency: 220 | examples.sort(key=lambda x: x[-1]) 221 | batches = [examples[i:i + n] for i in range(0, len(examples), n)] 222 | random.shuffle(batches) 223 | 224 | log('Generated %d batches of size %d in %.03f sec' % (len(batches), n, time.time() - start)) 225 | for batch in batches: 226 | feed_dict = dict(zip(self._placeholders, _prepare_batch(batch, r))) 227 | self._session.run(self._enqueue_op, feed_dict=feed_dict) 228 | self._step += 1 229 | 230 | 231 | def _get_next_example(self, data_dir): 232 | '''Loads a single example (input, mel_target, linear_target, cost) from disk''' 233 | data_paths = 
self.path_dict[data_dir] 234 | 235 | while True: 236 | if self._offset[data_dir] >= len(data_paths): 237 | self._offset[data_dir] = 0 238 | 239 | if self.data_type == 'train': 240 | self.rng.shuffle(data_paths) 241 | 242 | data_path = data_paths[self._offset[data_dir]] 243 | self._offset[data_dir] += 1 244 | 245 | try: 246 | if os.path.exists(data_path): 247 | data = np.load(data_path) 248 | else: 249 | continue 250 | except: 251 | remove_file(data_path) 252 | continue 253 | 254 | if not self.skip_path_filter: 255 | break 256 | 257 | if self.min_n_frame <= data["linear"].shape[0] <= self.max_n_frame and \ 258 | len(data["tokens"]) > self.min_tokens: 259 | break 260 | 261 | input_data = data['tokens'] 262 | mel_target = data['mel'] 263 | 264 | if 'loss_coeff' in data: 265 | loss_coeff = data['loss_coeff'] 266 | else: 267 | loss_coeff = 1 268 | linear_target = data['linear'] 269 | 270 | return (input_data, loss_coeff, mel_target, linear_target, 271 | self.data_dir_to_id[data_dir], len(linear_target)) 272 | 273 | 274 | 275 | def _maybe_get_arpabet(self, word): 276 | arpabet = self._cmudict.lookup(word) 277 | return '{%s}' % arpabet[0] if arpabet is not None and random.random() < 0.5 else word 278 | 279 | 280 | def _prepare_batch(batch, reduction_factor, rng, data_type=None): 281 | if data_type == 'train': 282 | rng.shuffle(batch) 283 | 284 | inputs = _prepare_inputs([x[0] for x in batch]) 285 | input_lengths = np.asarray([len(x[0]) for x in batch], dtype=np.int32) 286 | loss_coeff = np.asarray([x[1] for x in batch], dtype=np.float32) 287 | mel_targets = _prepare_targets([x[2] for x in batch], reduction_factor) 288 | linear_targets = _prepare_targets([x[3] for x in batch], reduction_factor) 289 | 290 | if len(batch[0]) == 6: 291 | speaker_id = np.asarray([x[4] for x in batch], dtype=np.int32) 292 | return (inputs, input_lengths, loss_coeff, mel_targets, linear_targets, speaker_id) 293 | else: 294 | return (inputs, input_lengths, loss_coeff, mel_targets, linear_targets) 295 | 296 | 297 | 298 | def _prepare_inputs(inputs): 299 | max_len = max((len(x) for x in inputs)) 300 | return np.stack([_pad_input(x, max_len) for x in inputs]) 301 | 302 | 303 | def _prepare_targets(targets, alignment): 304 | max_len = max((len(t) for t in targets)) + 1 305 | return np.stack([_pad_target(t, _round_up(max_len, alignment)) for t in targets]) 306 | 307 | 308 | def _pad_input(x, length): 309 | return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) 310 | 311 | 312 | def _pad_target(t, length): 313 | return np.pad(t, [(0, length - t.shape[0]), (0, 0)], mode='constant', constant_values=_pad) 314 | 315 | 316 | def _round_up(x, multiple): 317 | remainder = x % multiple 318 | return x if remainder == 0 else x + multiple - remainder 319 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import argparse 5 | import traceback 6 | import subprocess 7 | import numpy as np 8 | from jamo import h2j 9 | import tensorflow as tf 10 | from datetime import datetime 11 | from functools import partial 12 | 13 | from hparams import hparams, hparams_debug_string 14 | from models import create_model, get_most_recent_checkpoint 15 | 16 | from utils import ValueWindow, prepare_dirs 17 | from utils import infolog, warning, plot, load_hparams 18 | from utils import get_git_revision_hash, get_git_diff, str2bool, parallel_run 19 | 20 | from 
audio import save_audio, inv_spectrogram 21 | from text import sequence_to_text, text_to_sequence 22 | from datasets.datafeeder import DataFeeder, _prepare_inputs 23 | 24 | log = infolog.log 25 | 26 | 27 | def create_batch_inputs_from_texts(texts): 28 | sequences = [text_to_sequence(text) for text in texts] 29 | 30 | inputs = _prepare_inputs(sequences) 31 | input_lengths = np.asarray([len(x) for x in inputs], dtype=np.int32) 32 | 33 | for idx, (seq, text) in enumerate(zip(inputs, texts)): 34 | recovered_text = sequence_to_text(seq, skip_eos_and_pad=True) 35 | if recovered_text != h2j(text): 36 | log(" [{}] {}".format(idx, text)) 37 | log(" [{}] {}".format(idx, recovered_text)) 38 | log("="*30) 39 | 40 | return inputs, input_lengths 41 | 42 | 43 | def get_git_commit(): 44 | subprocess.check_output(['git', 'diff-index', '--quiet', 'HEAD']) # Verify client is clean 45 | commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode().strip()[:10] 46 | log('Git commit: %s' % commit) 47 | return commit 48 | 49 | 50 | def add_stats(model, model2=None, scope_name='train'): 51 | with tf.variable_scope(scope_name) as scope: 52 | summaries = [ 53 | tf.summary.scalar('loss_mel', model.mel_loss), 54 | tf.summary.scalar('loss_linear', model.linear_loss), 55 | tf.summary.scalar('loss', model.loss_without_coeff), 56 | ] 57 | 58 | if scope_name == 'train': 59 | gradient_norms = [tf.norm(grad) for grad in model.gradients if grad is not None] 60 | 61 | summaries.extend([ 62 | tf.summary.scalar('learning_rate', model.learning_rate), 63 | tf.summary.scalar('max_gradient_norm', tf.reduce_max(gradient_norms)), 64 | ]) 65 | 66 | if model2 is not None: 67 | with tf.variable_scope('gap_test-train') as scope: 68 | summaries.extend([ 69 | tf.summary.scalar('loss_mel', 70 | model.mel_loss - model2.mel_loss), 71 | tf.summary.scalar('loss_linear', 72 | model.linear_loss - model2.linear_loss), 73 | tf.summary.scalar('loss', 74 | model.loss_without_coeff - model2.loss_without_coeff), 75 | ]) 76 | 77 | return tf.summary.merge(summaries) 78 | 79 | 80 | def save_and_plot_fn(args, log_dir, step, loss, prefix): 81 | idx, (seq, spec, align) = args 82 | 83 | audio_path = os.path.join( 84 | log_dir, '{}-step-{:09d}-audio{:03d}.wav'.format(prefix, step, idx)) 85 | align_path = os.path.join( 86 | log_dir, '{}-step-{:09d}-align{:03d}.png'.format(prefix, step, idx)) 87 | 88 | waveform = inv_spectrogram(spec.T) 89 | save_audio(waveform, audio_path) 90 | 91 | info_text = 'step={:d}, loss={:.5f}'.format(step, loss) 92 | if 'korean_cleaners' in [x.strip() for x in hparams.cleaners.split(',')]: 93 | log('Training korean : Use jamo') 94 | plot.plot_alignment( 95 | align, align_path, info=info_text, 96 | text=sequence_to_text(seq, 97 | skip_eos_and_pad=True, combine_jamo=True), isKorean=True) 98 | else: 99 | log('Training non-korean : X use jamo') 100 | plot.plot_alignment( 101 | align, align_path, info=info_text, 102 | text=sequence_to_text(seq, 103 | skip_eos_and_pad=True, combine_jamo=False), isKorean=False) 104 | 105 | def save_and_plot(sequences, spectrograms, 106 | alignments, log_dir, step, loss, prefix): 107 | 108 | fn = partial(save_and_plot_fn, 109 | log_dir=log_dir, step=step, loss=loss, prefix=prefix) 110 | items = list(enumerate(zip(sequences, spectrograms, alignments))) 111 | 112 | parallel_run(fn, items, parallel=False) 113 | log('Test finished for step {}.'.format(step)) 114 | 115 | 116 | def train(log_dir, config): 117 | config.data_paths = config.data_paths 118 | 119 | data_dirs = [os.path.join(data_path, "data") \ 
120 | for data_path in config.data_paths] 121 | num_speakers = len(data_dirs) 122 | config.num_test = config.num_test_per_speaker * num_speakers 123 | 124 | if num_speakers > 1 and hparams.model_type not in ["deepvoice", "simple"]: 125 | raise Exception("[!] Unkown model_type for multi-speaker: {}".format(config.model_type)) 126 | 127 | commit = get_git_commit() if config.git else 'None' 128 | checkpoint_path = os.path.join(log_dir, 'model.ckpt') 129 | 130 | log(' [*] git recv-parse HEAD:\n%s' % get_git_revision_hash()) 131 | log('='*50) 132 | #log(' [*] dit diff:\n%s' % get_git_diff()) 133 | log('='*50) 134 | log(' [*] Checkpoint path: %s' % checkpoint_path) 135 | log(' [*] Loading training data from: %s' % data_dirs) 136 | log(' [*] Using model: %s' % config.model_dir) 137 | log(hparams_debug_string()) 138 | 139 | # Set up DataFeeder: 140 | coord = tf.train.Coordinator() 141 | with tf.variable_scope('datafeeder') as scope: 142 | train_feeder = DataFeeder( 143 | coord, data_dirs, hparams, config, 32, 144 | data_type='train', batch_size=hparams.batch_size) 145 | test_feeder = DataFeeder( 146 | coord, data_dirs, hparams, config, 8, 147 | data_type='test', batch_size=config.num_test) 148 | 149 | # Set up model: 150 | is_randomly_initialized = config.initialize_path is None 151 | global_step = tf.Variable(0, name='global_step', trainable=False) 152 | 153 | with tf.variable_scope('model') as scope: 154 | model = create_model(hparams) 155 | model.initialize( 156 | train_feeder.inputs, train_feeder.input_lengths, 157 | num_speakers, train_feeder.speaker_id, 158 | train_feeder.mel_targets, train_feeder.linear_targets, 159 | train_feeder.loss_coeff, 160 | is_randomly_initialized=is_randomly_initialized) 161 | 162 | model.add_loss() 163 | model.add_optimizer(global_step) 164 | train_stats = add_stats(model, scope_name='stats') # legacy 165 | 166 | with tf.variable_scope('model', reuse=True) as scope: 167 | test_model = create_model(hparams) 168 | test_model.initialize( 169 | test_feeder.inputs, test_feeder.input_lengths, 170 | num_speakers, test_feeder.speaker_id, 171 | test_feeder.mel_targets, test_feeder.linear_targets, 172 | test_feeder.loss_coeff, rnn_decoder_test_mode=True, 173 | is_randomly_initialized=is_randomly_initialized) 174 | test_model.add_loss() 175 | 176 | test_stats = add_stats(test_model, model, scope_name='test') 177 | test_stats = tf.summary.merge([test_stats, train_stats]) 178 | 179 | # Bookkeeping: 180 | step = 0 181 | time_window = ValueWindow(100) 182 | loss_window = ValueWindow(100) 183 | saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=2) 184 | 185 | sess_config = tf.ConfigProto( 186 | log_device_placement=False, 187 | allow_soft_placement=True) 188 | sess_config.gpu_options.allow_growth=True 189 | 190 | # Train! 191 | #with tf.Session(config=sess_config) as sess: 192 | with tf.Session() as sess: 193 | try: 194 | summary_writer = tf.summary.FileWriter(log_dir, sess.graph) 195 | sess.run(tf.global_variables_initializer()) 196 | 197 | if config.load_path: 198 | # Restore from a checkpoint if the user requested it. 
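                # (Note: --load_path resumes training and keeps the restored global_step,
                #  whereas --initialize_path below restores the weights but resets
                #  global_step to 0; main() rejects passing both flags at once.)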
199 | restore_path = get_most_recent_checkpoint(config.model_dir) 200 | saver.restore(sess, restore_path) 201 | log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) 202 | elif config.initialize_path: 203 | restore_path = get_most_recent_checkpoint(config.initialize_path) 204 | saver.restore(sess, restore_path) 205 | log('Initialized from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True) 206 | 207 | zero_step_assign = tf.assign(global_step, 0) 208 | sess.run(zero_step_assign) 209 | 210 | start_step = sess.run(global_step) 211 | log('='*50) 212 | log(' [*] Global step is reset to {}'. \ 213 | format(start_step)) 214 | log('='*50) 215 | else: 216 | log('Starting new training run at commit: %s' % commit, slack=True) 217 | 218 | start_step = sess.run(global_step) 219 | 220 | train_feeder.start_in_session(sess, start_step) 221 | test_feeder.start_in_session(sess, start_step) 222 | 223 | while not coord.should_stop(): 224 | start_time = time.time() 225 | step, loss, opt = sess.run( 226 | [global_step, model.loss_without_coeff, model.optimize], 227 | feed_dict=model.get_dummy_feed_dict()) 228 | 229 | time_window.append(time.time() - start_time) 230 | loss_window.append(loss) 231 | 232 | message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % ( 233 | step, time_window.average, loss, loss_window.average) 234 | log(message, slack=(step % config.checkpoint_interval == 0)) 235 | 236 | if loss > 100 or math.isnan(loss): 237 | log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True) 238 | raise Exception('Loss Exploded') 239 | 240 | if step % config.summary_interval == 0: 241 | log('Writing summary at step: %d' % step) 242 | 243 | feed_dict = { 244 | **model.get_dummy_feed_dict(), 245 | **test_model.get_dummy_feed_dict() 246 | } 247 | summary_writer.add_summary(sess.run( 248 | test_stats, feed_dict=feed_dict), step) 249 | 250 | if step % config.checkpoint_interval == 0: 251 | log('Saving checkpoint to: %s-%d' % (checkpoint_path, step)) 252 | saver.save(sess, checkpoint_path, global_step=step) 253 | 254 | if step % config.test_interval == 0: 255 | log('Saving audio and alignment...') 256 | num_test = config.num_test 257 | 258 | fetches = [ 259 | model.inputs[:num_test], 260 | model.linear_outputs[:num_test], 261 | model.alignments[:num_test], 262 | test_model.inputs[:num_test], 263 | test_model.linear_outputs[:num_test], 264 | test_model.alignments[:num_test], 265 | ] 266 | feed_dict = { 267 | **model.get_dummy_feed_dict(), 268 | **test_model.get_dummy_feed_dict() 269 | } 270 | 271 | sequences, spectrograms, alignments, \ 272 | test_sequences, test_spectrograms, test_alignments = \ 273 | sess.run(fetches, feed_dict=feed_dict) 274 | 275 | save_and_plot(sequences[:1], spectrograms[:1], alignments[:1], 276 | log_dir, step, loss, "train") 277 | save_and_plot(test_sequences, test_spectrograms, test_alignments, 278 | log_dir, step, loss, "test") 279 | 280 | except Exception as e: 281 | log('Exiting due to exception: %s' % e, slack=True) 282 | traceback.print_exc() 283 | coord.request_stop(e) 284 | 285 | 286 | def main(): 287 | parser = argparse.ArgumentParser() 288 | 289 | parser.add_argument('--log_dir', default='logs') 290 | parser.add_argument('--data_paths', default='datasets/kr_example') 291 | parser.add_argument('--load_path', default=None) 292 | parser.add_argument('--initialize_path', default=None) 293 | 294 | parser.add_argument('--num_test_per_speaker', type=int, default=2) 295 | parser.add_argument('--random_seed', type=int, 
default=123) 296 | parser.add_argument('--summary_interval', type=int, default=100) 297 | parser.add_argument('--test_interval', type=int, default=500) 298 | parser.add_argument('--checkpoint_interval', type=int, default=1000) 299 | parser.add_argument('--skip_path_filter', 300 | type=str2bool, default=False, help='Use only for debugging') 301 | 302 | parser.add_argument('--slack_url', 303 | help='Slack webhook URL to get periodic reports.') 304 | parser.add_argument('--git', action='store_true', 305 | help='If set, verify that the client is clean.') 306 | 307 | config = parser.parse_args() 308 | config.data_paths = config.data_paths.split(",") 309 | setattr(hparams, "num_speakers", len(config.data_paths)) 310 | 311 | prepare_dirs(config, hparams) 312 | 313 | log_path = os.path.join(config.model_dir, 'train.log') 314 | infolog.init(log_path, config.model_dir, config.slack_url) 315 | 316 | tf.set_random_seed(config.random_seed) 317 | print(config.data_paths) 318 | 319 | if any("krbook" not in data_path for data_path in config.data_paths) and \ 320 | hparams.sample_rate != 20000: 321 | warning("Detect non-krbook dataset. May need to set sampling rate from {} to 20000".\ 322 | format(hparams.sample_rate)) 323 | 324 | if any('LJ' in data_path for data_path in config.data_paths) and \ 325 | hparams.sample_rate != 22050: 326 | warning("Detect LJ Speech dataset. Set sampling rate from {} to 22050".\ 327 | format(hparams.sample_rate)) 328 | 329 | if config.load_path is not None and config.initialize_path is not None: 330 | raise Exception(" [!] Only one of load_path and initialize_path should be set") 331 | 332 | train(config.model_dir, config) 333 | 334 | 335 | if __name__ == '__main__': 336 | main() 337 | -------------------------------------------------------------------------------- /synthesizer.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import re 4 | import librosa 5 | import argparse 6 | import numpy as np 7 | from glob import glob 8 | from tqdm import tqdm 9 | import tensorflow as tf 10 | from functools import partial 11 | 12 | from hparams import hparams 13 | from models import create_model, get_most_recent_checkpoint 14 | from audio import save_audio, inv_spectrogram, inv_preemphasis, \ 15 | inv_spectrogram_tensorflow 16 | from utils import plot, PARAMS_NAME, load_json, load_hparams, \ 17 | add_prefix, add_postfix, get_time, parallel_run, makedirs, str2bool 18 | 19 | from text.korean import tokenize 20 | from text import text_to_sequence, sequence_to_text 21 | 22 | 23 | class Synthesizer(object): 24 | def close(self): 25 | tf.reset_default_graph() 26 | self.sess.close() 27 | 28 | def load(self, checkpoint_path, num_speakers=2, checkpoint_step=None, model_name='tacotron'): 29 | self.num_speakers = num_speakers 30 | 31 | if os.path.isdir(checkpoint_path): 32 | load_path = checkpoint_path 33 | checkpoint_path = get_most_recent_checkpoint(checkpoint_path, checkpoint_step) 34 | else: 35 | load_path = os.path.dirname(checkpoint_path) 36 | 37 | print('Constructing model: %s' % model_name) 38 | 39 | inputs = tf.placeholder(tf.int32, [None, None], 'inputs') 40 | input_lengths = tf.placeholder(tf.int32, [None], 'input_lengths') 41 | 42 | batch_size = tf.shape(inputs)[0] 43 | speaker_id = tf.placeholder_with_default( 44 | tf.zeros([batch_size], dtype=tf.int32), [None], 'speaker_id') 45 | 46 | load_hparams(hparams, load_path) 47 | with tf.variable_scope('model') as scope: 48 | self.model = create_model(hparams) 49 | 50 | 
self.model.initialize( 51 | inputs, input_lengths, 52 | self.num_speakers, speaker_id) 53 | self.wav_output = \ 54 | inv_spectrogram_tensorflow(self.model.linear_outputs) 55 | 56 | print('Loading checkpoint: %s' % checkpoint_path) 57 | 58 | sess_config = tf.ConfigProto( 59 | allow_soft_placement=True, 60 | intra_op_parallelism_threads=1, 61 | inter_op_parallelism_threads=2) 62 | sess_config.gpu_options.allow_growth = True 63 | 64 | self.sess = tf.Session(config=sess_config) 65 | self.sess.run(tf.global_variables_initializer()) 66 | saver = tf.train.Saver() 67 | saver.restore(self.sess, checkpoint_path) 68 | 69 | def synthesize(self, 70 | texts=None, tokens=None, 71 | base_path=None, paths=None, speaker_ids=None, 72 | start_of_sentence=None, end_of_sentence=True, 73 | pre_word_num=0, post_word_num=0, 74 | pre_surplus_idx=0, post_surplus_idx=1, 75 | use_short_concat=False, 76 | manual_attention_mode=0, 77 | base_alignment_path=None, 78 | librosa_trim=False, 79 | attention_trim=True, 80 | isKorean=True): 81 | 82 | # Possible inputs: 83 | # 1) text=text 84 | # 2) text=texts 85 | # 3) tokens=tokens, texts=texts # use texts as guide 86 | 87 | if type(texts) == str: 88 | texts = [texts] 89 | 90 | if texts is not None and tokens is None: 91 | sequences = [text_to_sequence(text) for text in texts] 92 | elif tokens is not None: 93 | sequences = tokens 94 | 95 | if paths is None: 96 | paths = [None] * len(sequences) 97 | if texts is None: 98 | texts = [None] * len(sequences) 99 | 100 | time_str = get_time() 101 | def plot_and_save_parallel( 102 | wavs, alignments, use_manual_attention): 103 | 104 | items = list(enumerate(zip( 105 | wavs, alignments, paths, texts, sequences))) 106 | 107 | fn = partial( 108 | plot_graph_and_save_audio, 109 | base_path=base_path, 110 | start_of_sentence=start_of_sentence, end_of_sentence=end_of_sentence, 111 | pre_word_num=pre_word_num, post_word_num=post_word_num, 112 | pre_surplus_idx=pre_surplus_idx, post_surplus_idx=post_surplus_idx, 113 | use_short_concat=use_short_concat, 114 | use_manual_attention=use_manual_attention, 115 | librosa_trim=librosa_trim, 116 | attention_trim=attention_trim, 117 | time_str=time_str, 118 | isKorean=isKorean) 119 | return parallel_run(fn, items, 120 | desc="plot_graph_and_save_audio", parallel=False) 121 | 122 | input_lengths = np.argmax(np.array(sequences) == 1, 1) 123 | 124 | fetches = [ 125 | #self.wav_output, 126 | self.model.linear_outputs, 127 | self.model.alignments, 128 | ] 129 | 130 | feed_dict = { 131 | self.model.inputs: sequences, 132 | self.model.input_lengths: input_lengths, 133 | } 134 | if base_alignment_path is None: 135 | feed_dict.update({ 136 | self.model.manual_alignments: np.zeros([1, 1, 1]), 137 | self.model.is_manual_attention: False, 138 | }) 139 | else: 140 | manual_alignments = [] 141 | alignment_path = os.path.join( 142 | base_alignment_path, 143 | os.path.basename(base_path)) 144 | 145 | for idx in range(len(sequences)): 146 | numpy_path = "{}.{}.npy".format(alignment_path, idx) 147 | manual_alignments.append(np.load(numpy_path)) 148 | 149 | alignments_T = np.transpose(manual_alignments, [0, 2, 1]) 150 | feed_dict.update({ 151 | self.model.manual_alignments: alignments_T, 152 | self.model.is_manual_attention: True, 153 | }) 154 | 155 | if speaker_ids is not None: 156 | if type(speaker_ids) == dict: 157 | speaker_embed_table = sess.run( 158 | self.model.speaker_embed_table) 159 | 160 | speaker_embed = [speaker_ids[speaker_id] * \ 161 | speaker_embed_table[speaker_id] for speaker_id in speaker_ids] 162 | 
feed_dict.update({ 163 | self.model.speaker_embed_table: np.tile() 164 | }) 165 | else: 166 | feed_dict[self.model.speaker_id] = speaker_ids 167 | 168 | wavs, alignments = \ 169 | self.sess.run(fetches, feed_dict=feed_dict) 170 | results = plot_and_save_parallel( 171 | wavs, alignments, True) 172 | 173 | if manual_attention_mode > 0: 174 | # argmax one hot 175 | if manual_attention_mode == 1: 176 | alignments_T = np.transpose(alignments, [0, 2, 1]) # [N, E, D] 177 | new_alignments = np.zeros_like(alignments_T) 178 | 179 | for idx in range(len(alignments)): 180 | argmax = alignments[idx].argmax(1) 181 | new_alignments[idx][(argmax, range(len(argmax)))] = 1 182 | # sharpening 183 | elif manual_attention_mode == 2: 184 | new_alignments = np.transpose(alignments, [0, 2, 1]) # [N, E, D] 185 | 186 | for idx in range(len(alignments)): 187 | var = np.var(new_alignments[idx], 1) 188 | mean_var = var[:input_lengths[idx]].mean() 189 | 190 | new_alignments = np.pow(new_alignments[idx], 2) 191 | # prunning 192 | elif manual_attention_mode == 3: 193 | new_alignments = np.transpose(alignments, [0, 2, 1]) # [N, E, D] 194 | 195 | for idx in range(len(alignments)): 196 | argmax = alignments[idx].argmax(1) 197 | new_alignments[idx][(argmax, range(len(argmax)))] = 1 198 | 199 | feed_dict.update({ 200 | self.model.manual_alignments: new_alignments, 201 | self.model.is_manual_attention: True, 202 | }) 203 | 204 | new_wavs, new_alignments = \ 205 | self.sess.run(fetches, feed_dict=feed_dict) 206 | results = plot_and_save_parallel( 207 | new_wavs, new_alignments, True) 208 | 209 | return results 210 | 211 | def plot_graph_and_save_audio(args, 212 | base_path=None, 213 | start_of_sentence=None, end_of_sentence=None, 214 | pre_word_num=0, post_word_num=0, 215 | pre_surplus_idx=0, post_surplus_idx=1, 216 | use_short_concat=False, 217 | use_manual_attention=False, save_alignment=False, 218 | librosa_trim=False, attention_trim=False, 219 | time_str=None, isKorean=True): 220 | 221 | idx, (wav, alignment, path, text, sequence) = args 222 | 223 | if base_path: 224 | plot_path = "{}/{}.png".format(base_path, get_time()) 225 | elif path: 226 | plot_path = path.rsplit('.', 1)[0] + ".png" 227 | else: 228 | plot_path = None 229 | 230 | #plot_path = add_prefix(plot_path, time_str) 231 | if use_manual_attention: 232 | plot_path = add_postfix(plot_path, "manual") 233 | 234 | if plot_path: 235 | plot.plot_alignment(alignment, plot_path, text=text, isKorean=isKorean) 236 | 237 | if use_short_concat: 238 | wav = short_concat( 239 | wav, alignment, text, 240 | start_of_sentence, end_of_sentence, 241 | pre_word_num, post_word_num, 242 | pre_surplus_idx, post_surplus_idx) 243 | 244 | if attention_trim and end_of_sentence: 245 | end_idx_counter = 0 246 | attention_argmax = alignment.argmax(0) 247 | end_idx = min(len(sequence) - 1, max(attention_argmax)) 248 | max_counter = min((attention_argmax == end_idx).sum(), 5) 249 | 250 | for jdx, attend_idx in enumerate(attention_argmax): 251 | if len(attention_argmax) > jdx + 1: 252 | if attend_idx == end_idx: 253 | end_idx_counter += 1 254 | 255 | if attend_idx == end_idx and attention_argmax[jdx + 1] > end_idx: 256 | break 257 | 258 | if end_idx_counter >= max_counter: 259 | break 260 | else: 261 | break 262 | 263 | spec_end_idx = hparams.reduction_factor * jdx + 3 264 | wav = wav[:spec_end_idx] 265 | 266 | audio_out = inv_spectrogram(wav.T) 267 | 268 | if librosa_trim and end_of_sentence: 269 | yt, index = librosa.effects.trim(audio_out, 270 | frame_length=5120, hop_length=256, top_db=50) 
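        # librosa.effects.trim returns (trimmed_audio, index), where index = [start, end]
        # gives the sample bounds of the detected non-silent interval; the slice below
        # keeps everything up to that end point, so only trailing silence is removed.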
271 | audio_out = audio_out[:index[-1]] 272 | 273 | if save_alignment: 274 | alignment_path = "{}/{}.npy".format(base_path, idx) 275 | np.save(alignment_path, alignment, allow_pickle=False) 276 | 277 | if path or base_path: 278 | if path: 279 | current_path = add_postfix(path, idx) 280 | elif base_path: 281 | current_path = plot_path.replace(".png", ".wav") 282 | 283 | save_audio(audio_out, current_path) 284 | return True 285 | else: 286 | io_out = io.BytesIO() 287 | save_audio(audio_out, io_out) 288 | result = io_out.getvalue() 289 | return result 290 | 291 | def get_most_recent_checkpoint(checkpoint_dir, checkpoint_step=None): 292 | if checkpoint_step is None: 293 | checkpoint_paths = [path for path in glob("{}/*.ckpt-*.data-*".format(checkpoint_dir))] 294 | idxes = [int(os.path.basename(path).split('-')[1].split('.')[0]) for path in checkpoint_paths] 295 | 296 | max_idx = max(idxes) 297 | else: 298 | max_idx = checkpoint_step 299 | lastest_checkpoint = os.path.join(checkpoint_dir, "model.ckpt-{}".format(max_idx)) 300 | print(" [*] Found lastest checkpoint: {}".format(lastest_checkpoint)) 301 | return lastest_checkpoint 302 | 303 | def short_concat( 304 | wav, alignment, text, 305 | start_of_sentence, end_of_sentence, 306 | pre_word_num, post_word_num, 307 | pre_surplus_idx, post_surplus_idx): 308 | 309 | # np.array(list(decomposed_text))[attention_argmax] 310 | attention_argmax = alignment.argmax(0) 311 | 312 | if not start_of_sentence and pre_word_num > 0: 313 | surplus_decomposed_text = decompose_ko_text("".join(text.split()[0])) 314 | start_idx = len(surplus_decomposed_text) + 1 315 | 316 | for idx, attend_idx in enumerate(attention_argmax): 317 | if attend_idx == start_idx and attention_argmax[idx - 1] < start_idx: 318 | break 319 | 320 | wav_start_idx = hparams.reduction_factor * idx - 1 - pre_surplus_idx 321 | else: 322 | wav_start_idx = 0 323 | 324 | if not end_of_sentence and post_word_num > 0: 325 | surplus_decomposed_text = decompose_ko_text("".join(text.split()[-1])) 326 | end_idx = len(decomposed_text.replace(surplus_decomposed_text, '')) - 1 327 | 328 | for idx, attend_idx in enumerate(attention_argmax): 329 | if attend_idx == end_idx and attention_argmax[idx + 1] > end_idx: 330 | break 331 | 332 | wav_end_idx = hparams.reduction_factor * idx + 1 + post_surplus_idx 333 | else: 334 | if True: # attention based split 335 | if end_of_sentence: 336 | end_idx = min(len(decomposed_text) - 1, max(attention_argmax)) 337 | else: 338 | surplus_decomposed_text = decompose_ko_text("".join(text.split()[-1])) 339 | end_idx = len(decomposed_text.replace(surplus_decomposed_text, '')) - 1 340 | 341 | while True: 342 | if end_idx in attention_argmax: 343 | break 344 | end_idx -= 1 345 | 346 | end_idx_counter = 0 347 | for idx, attend_idx in enumerate(attention_argmax): 348 | if len(attention_argmax) > idx + 1: 349 | if attend_idx == end_idx: 350 | end_idx_counter += 1 351 | 352 | if attend_idx == end_idx and attention_argmax[idx + 1] > end_idx: 353 | break 354 | 355 | if end_idx_counter > 5: 356 | break 357 | else: 358 | break 359 | 360 | wav_end_idx = hparams.reduction_factor * idx + 1 + post_surplus_idx 361 | else: 362 | wav_end_idx = None 363 | 364 | wav = wav[wav_start_idx:wav_end_idx] 365 | 366 | if end_of_sentence: 367 | wav = np.lib.pad(wav, ((0, 20), (0, 0)), 'constant', constant_values=0) 368 | else: 369 | wav = np.lib.pad(wav, ((0, 10), (0, 0)), 'constant', constant_values=0) 370 | 371 | 372 | if __name__ == "__main__": 373 | parser = argparse.ArgumentParser() 374 | 
parser.add_argument('--load_path', required=True) 375 | parser.add_argument('--sample_path', default="samples") 376 | parser.add_argument('--text', required=True) 377 | parser.add_argument('--num_speakers', default=1, type=int) 378 | parser.add_argument('--speaker_id', default=0, type=int) 379 | parser.add_argument('--checkpoint_step', default=None, type=int) 380 | parser.add_argument('--is_korean', default=True, type=str2bool) 381 | config = parser.parse_args() 382 | 383 | makedirs(config.sample_path) 384 | 385 | synthesizer = Synthesizer() 386 | synthesizer.load(config.load_path, config.num_speakers, config.checkpoint_step) 387 | 388 | audio = synthesizer.synthesize( 389 | texts=[config.text], 390 | base_path=config.sample_path, 391 | speaker_ids=[config.speaker_id], 392 | attention_trim=False, 393 | isKorean=config.is_korean)[0] -------------------------------------------------------------------------------- /models/tacotron.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, OutputProjectionWrapper, ResidualWrapper, LSTMCell 4 | from tensorflow.contrib.seq2seq import BasicDecoder, BahdanauAttention, AttentionWrapper, BahdanauMonotonicAttention 5 | from text.symbols import symbols 6 | from util.infolog import log 7 | from .helpers import TacoTestHelper, TacoTrainingHelper 8 | from .modules import * 9 | from .rnn_wrappers import DecoderPrenetWrapper, ConcatOutputAndAttentionWrapper, AttentionWrapper 10 | 11 | 12 | class Tacotron(): 13 | def __init__(self, hparams): 14 | self._hparams = hparams 15 | 16 | def initialize(self, inputs, input_lengths, num_speakers, speaker_id, 17 | mel_targets=None, linear_targets=None, loss_coeff=None, 18 | rnn_decoder_test_mode=False, is_randomly_initialized=False): 19 | 20 | is_training = linear_targets is not None 21 | self.is_randomly_initialized = is_randomly_initialized 22 | 23 | with tf.variable_scope('inference') as scope: 24 | batch_size = tf.shape(inputs)[0] 25 | hp = self._hparams 26 | 27 | # Embeddings 28 | embedding_table = tf.get_variable( 29 | 'embedding', [len(symbols), hp.embed_depth], dtype=tf.float32, 30 | initializer=tf.truncated_normal_initializer(stddev=0.5)) 31 | embedded_inputs = tf.nn.embedding_lookup(embedding_table, inputs) # [N, T_in, embed_depth=256] 32 | 33 | self.num_speakers = num_speakers 34 | if self.num_speakers > 1: 35 | if hp.speaker_embedding_size != 1: 36 | speaker_embed_table = tf.get_variable( 37 | 'speaker_embedding', 38 | [self.num_speakers, hp.speaker_embedding_size], dtype=tf.float32, 39 | initializer=tf.truncated_normal_initializer(stddev=0.5)) 40 | # [N, T_in, speaker_embedding_size] 41 | speaker_embed = tf.nn.embedding_lookup(speaker_embed_table, speaker_id) 42 | 43 | if hp.model_type == 'deepvoice': 44 | if hp.speaker_embedding_size == 1: 45 | before_highway = get_embed( 46 | speaker_id, self.num_speakers, 47 | hp.enc_prenet_sizes[-1], "before_highway") 48 | encoder_rnn_init_state = get_embed( 49 | speaker_id, self.num_speakers, 50 | hp.enc_rnn_size * 2, "encoder_rnn_init_state") 51 | 52 | attention_rnn_init_state = get_embed( 53 | speaker_id, self.num_speakers, 54 | hp.attention_state_size, "attention_rnn_init_state") 55 | decoder_rnn_init_states = [get_embed( 56 | speaker_id, self.num_speakers, 57 | hp.dec_rnn_size, "decoder_rnn_init_states{}".format(idx + 1)) \ 58 | for idx in range(hp.dec_layer_num)] 59 | else: 60 | deep_dense = lambda x, dim: \ 61 | 
tf.layers.dense(x, dim, activation=tf.nn.softsign) 62 | 63 | before_highway = deep_dense(speaker_embed, hp.enc_prenet_sizes[-1]) 64 | encoder_rnn_init_state = deep_dense(speaker_embed, hp.enc_rnn_size * 2) 65 | 66 | attention_rnn_init_state = deep_dense(speaker_embed, hp.attention_state_size) 67 | decoder_rnn_init_states = [deep_dense(speaker_embed, hp.dec_rnn_size) for _ in range(hp.dec_layer_num)] 68 | 69 | speaker_embed = None # deepvoice does not use speaker_embed directly 70 | 71 | elif hp.model_type == 'simple': 72 | before_highway = None 73 | encoder_rnn_init_state = None 74 | attention_rnn_init_state = None 75 | decoder_rnn_init_states = None 76 | 77 | else: 78 | raise Exception(" [!] Unknown multi-speaker model type: {}".format(hp.model_type)) 79 | 80 | else: 81 | speaker_embed = None 82 | before_highway = None 83 | encoder_rnn_init_state = None 84 | attention_rnn_init_state = None 85 | decoder_rnn_init_states = None 86 | 87 | 88 | # Encoder 89 | prenet_outputs = prenet(embedded_inputs, is_training, hp.enc_prenet_sizes, hp.dropout_prob, scope = 'prenet') # [N, T_in, prenet_depths[-1]=128] 90 | encoder_outputs = cbhg(prenet_outputs, input_lengths, is_training, # [N, T_in, encoder_depth=256] 91 | hp.enc_bank_size, hp.enc_bank_channel_size, 92 | hp.enc_maxpool_width, hp.enc_highway_depth, hp.enc_rnn_size, 93 | hp.enc_proj_sizes, hp.enc_proj_width, 94 | scope="encoder_cbhg", before_highway=before_highway, 95 | encoder_rnn_init_state=encoder_rnn_init_state) 96 | 97 | # Attention 98 | # For manaul control of attention 99 | self.is_manual_attention = tf.placeholder( 100 | tf.bool, shape=(), name='is_manual_attention', 101 | ) 102 | self.manual_alignments = tf.placeholder( 103 | tf.float32, shape=[None, None, None], name="manual_alignments", 104 | ) 105 | 106 | dec_prenet_outputs = DecoderPrenetWrapper( 107 | GRUCell(hp.attention_state_size), 108 | speaker_embed, 109 | is_training, hp.dec_prenet_sizes, hp.dropout_prob) 110 | 111 | if hp.attention_type == 'bah_mon': 112 | attention_mechanism = BahdanauMonotonicAttention( 113 | hp.attention_size, encoder_outputs) 114 | elif hp.attention_type == 'bah_norm': 115 | attention_mechanism = BahdanauAttention( 116 | hp.attention_size, encoder_outputs, normalize=True) 117 | elif hp.attention_type == 'luong_scaled': 118 | attention_mechanism = LuongAttention( 119 | hp.attention_size, encoder_outputs, scale=True) 120 | elif hp.attention_type == 'luong': 121 | attention_mechanism = LuongAttention( 122 | hp.attention_size, encoder_outputs) 123 | elif hp.attention_type == 'bah': 124 | attention_mechanism = BahdanauAttention( 125 | hp.attention_size, encoder_outputs) 126 | elif hp.attention_type.startswith('ntm2'): 127 | shift_width = int(hp.attention_type.split('-')[-1]) 128 | attention_mechanism = NTMAttention2( 129 | hp.attention_size, encoder_outputs, shift_width=shift_width) 130 | else: 131 | raise Exception(" [!] Unkown attention type: {}".format(hp.attention_type)) 132 | 133 | attention_cell = AttentionWrapper( 134 | dec_prenet_outputs, 135 | attention_mechanism, 136 | self.is_manual_attention, 137 | self.manual_alignments, 138 | initial_cell_state=attention_rnn_init_state, 139 | alignment_history=True, 140 | output_attention=False 141 | ) 142 | 143 | # Concatenate attention context vector and RNN cell output into a 2*attention_depth=512D vector. 
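            # (Sketch of the decoder stack assembled below: concat_cell joins the attention
            #  RNN output with its attention context, plus the speaker embedding when
            #  embed_to_concat is not None; MultiRNNCell then adds a projection layer and
            #  two residual GRU layers, and a final OutputProjectionWrapper predicts
            #  num_mels * reduction_factor frames per decoder step.)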
144 | # [N, T_in, attention_size+attention_state_size] 145 | concat_cell = ConcatOutputAndAttentionWrapper(attention_cell, embed_to_concat=speaker_embed) 146 | 147 | # Decoder (layers specified bottom to top): 148 | decoder_cell = MultiRNNCell([ 149 | OutputProjectionWrapper(concat_cell, hp.dec_rnn_size), 150 | ResidualWrapper(GRUCell(hp.dec_rnn_size)), 151 | ResidualWrapper(GRUCell(hp.dec_rnn_size)), 152 | ], state_is_tuple=True) # [N, T_in, decoder_depth=256] 153 | 154 | # Project onto r mel spectrograms (predict r outputs at each RNN step): 155 | output_cell = OutputProjectionWrapper(decoder_cell, hp.num_mels * hp.reduction_factor) 156 | decoder_init_state = output_cell.zero_state(batch_size=batch_size, dtype=tf.float32) 157 | 158 | if hp.model_type == "deepvoice": 159 | # decoder_init_state[0] : AttentionWrapperState 160 | # = cell_state + attention + time + alignments + alignment_history 161 | # decoder_init_state[0][0] = attention_rnn_init_state (already applied) 162 | decoder_init_state = list(decoder_init_state) 163 | 164 | for idx, cell in enumerate(decoder_rnn_init_states): 165 | shape1 = decoder_init_state[idx + 1].get_shape().as_list() 166 | shape2 = cell.get_shape().as_list() 167 | if shape1 != shape2: 168 | raise Exception(" [!] Shape {} and {} should be equal". \ 169 | format(shape1, shape2)) 170 | decoder_init_state[idx + 1] = cell 171 | 172 | decoder_init_state = tuple(decoder_init_state) 173 | 174 | 175 | if is_training: 176 | helper = TacoTrainingHelper(inputs, mel_targets, hp.num_mels, hp.reduction_factor, rnn_decoder_test_mode) 177 | else: 178 | helper = TacoTestHelper(batch_size, hp.num_mels, hp.reduction_factor) 179 | 180 | (decoder_outputs, _), final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode( 181 | BasicDecoder(output_cell, helper, decoder_init_state), 182 | maximum_iterations=hp.max_iters) # [N, T_out/r, M*r] 183 | 184 | # Reshape outputs to be one output per entry 185 | mel_outputs = tf.reshape(decoder_outputs, [batch_size, -1, hp.num_mels]) # [N, T_out, M] 186 | 187 | # Add post-processing CBHG: 188 | # [N, T_out, postnet_depth=256] 189 | post_outputs = cbhg(mel_outputs, None, is_training, 190 | hp.post_bank_size, hp.post_bank_channel_size, hp.post_maxpool_width, 191 | hp.post_highway_depth, hp.post_rnn_size, hp.post_proj_sizes, hp.post_proj_width, 192 | scope='post_cbhg') 193 | 194 | if speaker_embed is not None and hp.model_type == 'simple': 195 | expanded_speaker_emb = tf.expand_dims(speaker_embed, [1]) 196 | tiled_speaker_embedding = tf.tile(expanded_speaker_emb, [1, tf.shape(post_outputs)[1], 1]) 197 | 198 | # [N, T_out, 256 + alpha] 199 | post_outputs = tf.concat([tiled_speaker_embedding, post_outputs], axis=-1) 200 | 201 | linear_outputs = tf.layers.dense(post_outputs, hp.num_freq) # [N, T_out, F] 202 | 203 | # Grab alignments from the final decoder state: 204 | alignments = tf.transpose(final_decoder_state[0].alignment_history.stack(), [1, 2, 0]) 205 | 206 | self.inputs = inputs 207 | self.speaker_id = speaker_id 208 | self.input_lengths = input_lengths 209 | self.loss_coeff = loss_coeff 210 | self.mel_outputs = mel_outputs 211 | self.linear_outputs = linear_outputs 212 | self.alignments = alignments 213 | self.mel_targets = mel_targets 214 | self.linear_targets = linear_targets 215 | self.final_decoder_state = final_decoder_state 216 | 217 | log('='*40) 218 | log(' model_type: %s' % hp.model_type) 219 | log('='*40) 220 | 221 | log('Initialized Tacotron model. 
Dimensions: ') 222 | log(' embedding: %d' % embedded_inputs.shape[-1]) 223 | if speaker_embed is not None: 224 | log(' speaker embedding: %d' % speaker_embed.shape[-1]) 225 | else: 226 | log(' speaker embedding: None') 227 | log(' prenet out: %d' % prenet_outputs.shape[-1]) 228 | log(' encoder out: %d' % encoder_outputs.shape[-1]) 229 | log(' attention out: %d' % attention_cell.output_size) 230 | log(' concat attn & out: %d' % concat_cell.output_size) 231 | log(' decoder cell out: %d' % decoder_cell.output_size) 232 | log(' decoder out (%d frames): %d' % (hp.outputs_per_step, decoder_outputs.shape[-1])) 233 | log(' decoder out (1 frame): %d' % mel_outputs.shape[-1]) 234 | log(' postnet out: %d' % post_outputs.shape[-1]) 235 | log(' linear out: %d' % linear_outputs.shape[-1]) 236 | 237 | def add_loss(self): 238 | '''Adds loss to the model. Sets "loss" field. initialize must have been called.''' 239 | with tf.variable_scope('loss') as scope: 240 | hp = self._hparams 241 | mel_loss = tf.abs(self.mel_targets - self.mel_outputs) 242 | l1 = tf.abs(self.linear_targets - self.linear_outputs) 243 | expanded_loss_coeff = tf.expand_dims(tf.expand_dims(self.loss_coeff, [-1]), [-1]) 244 | 245 | if hp.prioritize_loss: 246 | # Prioritize loss for frequencies under 3000 Hz. 247 | upper_priority_freq = int(3000 / (hp.sample_rate * 0.5) * hp.num_freq) 248 | lower_priority_freq = int(165 / (hp.sample_rate * 0.5) * hp.num_freq) 249 | 250 | l1_priority= l1[:,:,lower_priority_freq:upper_priority_freq] 251 | 252 | self.loss = tf.reduce_mean(mel_loss * expanded_loss_coeff) + \ 253 | 0.5 * tf.reduce_mean(l1 * expanded_loss_coeff) + \ 254 | 0.5 * tf.reduce_mean(l1_priority * expanded_loss_coeff) 255 | self.linear_loss = tf.reduce_mean(0.5 * (tf.reduce_mean(l1) + tf.reduce_mean(l1_priority))) 256 | 257 | else: 258 | self.loss = tf.reduce_mean(mel_loss * expanded_loss_coeff) + tf.reduce_mean(l1 * expanded_loss_coeff) 259 | self.linear_loss = tf.reduce_mean(l1) 260 | 261 | 262 | self.mel_loss = tf.reduce_mean(mel_loss) 263 | self.loss_without_coeff = self.mel_loss + self.linear_loss 264 | 265 | 266 | def add_optimizer(self, global_step): 267 | '''Adds optimizer. Sets "gradients" and "optimize" fields. add_loss must have been called. 268 | 269 | Args: 270 | global_step: int32 scalar Tensor representing current global step in training 271 | ''' 272 | with tf.variable_scope('optimizer') as scope: 273 | hp = self._hparams 274 | step = tf.cast(global_step + 1, dtype=tf.float32) 275 | 276 | if hp.decay_learning_rate_mode == 0: 277 | if self.is_randomly_initialized: 278 | warmup_steps = 4000.0 279 | else: 280 | warmup_steps = 40000.0 281 | self.learning_rate = hp.initial_learning_rate * warmup_steps**0.5 * \ 282 | tf.minimum(step * warmup_steps**-1.5, step**-0.5) 283 | elif hp.decay_learning_rate_mode == 1: 284 | self.learning_rate = hp.initial_learning_rate * \ 285 | tf.train.exponential_decay(1., step, 3000, 0.95) 286 | 287 | optimizer = tf.train.AdamOptimizer(self.learning_rate, hp.adam_beta1, hp.adam_beta2) 288 | gradients, variables = zip(*optimizer.compute_gradients(self.loss)) 289 | self.gradients = gradients 290 | clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0) 291 | 292 | # Add dependency on UPDATE_OPS; otherwise batchnorm won't work correctly. 
See: 293 | # https://github.com/tensorflow/tensorflow/issues/1122 294 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 295 | self.optimize = optimizer.apply_gradients(zip(clipped_gradients, variables), 296 | global_step=global_step) 297 | 298 | 299 | def get_dummy_feed_dict(self): 300 | feed_dict = { 301 | self.is_manual_attention: False, 302 | self.manual_alignments: np.zeros([1, 1, 1]), 303 | } 304 | return feed_dict -------------------------------------------------------------------------------- /audio/google_speech.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import sys 4 | import json 5 | import string 6 | import argparse 7 | import operator 8 | import numpy as np 9 | from glob import glob 10 | from tqdm import tqdm 11 | from nltk import ngrams 12 | from difflib import SequenceMatcher 13 | from collections import defaultdict 14 | 15 | from google.cloud import speech 16 | from google.cloud.speech import enums 17 | from google.cloud.speech import types 18 | 19 | from utils import parallel_run 20 | from text import text_to_sequence 21 | 22 | #################################################### 23 | # When one or two audio is missed in the middle 24 | #################################################### 25 | 26 | def get_continuous_audio_paths(paths, debug=False): 27 | audio_ids = get_audio_ids_from_paths(paths) 28 | min_id, max_id = min(audio_ids), max(audio_ids) 29 | 30 | if int(max_id) - int(min_id) + 1 != len(audio_ids): 31 | base_path = paths[0].replace(min_id, "{:0" + str(len(max_id)) + "d}") 32 | new_paths = [ 33 | base_path.format(audio_id) \ 34 | for audio_id in range(int(min_id), int(max_id) + 1)] 35 | 36 | if debug: print("Missing audio : {} -> {}".format(paths, new_paths)) 37 | return new_paths 38 | else: 39 | return paths 40 | 41 | def get_argmax_key(info, with_value=False): 42 | max_key = max(info.keys(), key=(lambda k: info[k])) 43 | 44 | if with_value: 45 | return max_key, info[max_key] 46 | else: 47 | return max_key 48 | 49 | def similarity(text_a, text_b): 50 | text_a = "".join(remove_puncuations(text_a.strip()).split()) 51 | text_b = "".join(remove_puncuations(text_b.strip()).split()) 52 | 53 | score = SequenceMatcher(None, text_a, text_b).ratio() 54 | #score = 1 / (distance(decompose_ko_text(text_a), decompose_ko_text(text_b)) + 1e-5) 55 | #score = SequenceMatcher(None, 56 | # decompose_ko_text(text_a), decompose_ko_text(text_b)).ratio() 57 | 58 | if len(text_a) < len(text_b): 59 | return -1 + score 60 | else: 61 | return score 62 | 63 | def get_key_value_sorted(data): 64 | keys = list(data.keys()) 65 | keys.sort() 66 | values = [data[key] for key in keys] 67 | return keys, values 68 | 69 | def replace_pred_with_book( 70 | path, book_path=None, threshold=0.9, max_candidate_num=5, 71 | min_post_char_check=2, max_post_char_check=7, max_n=5, 72 | max_allow_missing_when_matching=4, debug=False): 73 | 74 | ####################################### 75 | # find text book from pred 76 | ####################################### 77 | 78 | if book_path is None: 79 | book_path = path.replace("speech", "text").replace("json", "txt") 80 | 81 | data = json.loads(open(path).read()) 82 | 83 | keys, preds = get_key_value_sorted(data) 84 | 85 | book_words = [word for word in open(book_path).read().split() if word != "=="] 86 | book_texts = [text.replace('\n', '') for text in open(book_path).readlines()] 87 | 88 | loc = 0 89 | prev_key = None 90 | force_stop = False 91 | prev_end_loc = -1 92 | 
prev_sentence_ended = True 93 | 94 | prev_empty_skip = False 95 | prev_not_found_skip = False 96 | 97 | black_lists = ["160.{:04d}".format(audio_id) for audio_id in range(20, 36)] 98 | 99 | new_preds = {} 100 | for key, pred in zip(keys, preds): 101 | if debug: print(key, pred) 102 | 103 | if pred == "" or key in black_lists: 104 | prev_empty_skip = True 105 | continue 106 | 107 | width, counter = 1, 0 108 | sim_dict, loc_dict = {}, {} 109 | 110 | while True: 111 | words = book_words[loc:loc + width] 112 | 113 | if len(words) == 0: 114 | print("Force stop. Left {}, Del {} {}". \ 115 | format(len(preds) - len(new_preds), new_preds[prev_key], prev_key)) 116 | new_preds.pop(prev_key, None) 117 | force_stop = True 118 | break 119 | 120 | candidate_candidates = {} 121 | 122 | for _pred in list(set([pred, koreanize_numbers(pred)])): 123 | max_skip = 0 if has_number(_pred[0]) or \ 124 | _pred[0] in """"'“”’‘’""" else len(words) 125 | 126 | end_sims = [] 127 | for idx in range(min(max_skip, 10)): 128 | text = " ".join(words[idx:]) 129 | 130 | ################################################ 131 | # Score of trailing sentence is also important 132 | ################################################ 133 | 134 | for jdx in range(min_post_char_check, 135 | max_post_char_check): 136 | sim = similarity( 137 | "".join(_pred.split())[-jdx:], 138 | "".join(text.split())[-jdx:]) 139 | end_sims.append(sim) 140 | 141 | candidate_candidates[text] = similarity(_pred, text) 142 | 143 | candidate, sim = get_argmax_key( 144 | candidate_candidates, with_value=True) 145 | 146 | if sim > threshold or max(end_sims + [-1]) > threshold - 0.2 or \ 147 | len(sim_dict) > 0: 148 | sim_dict[candidate] = sim 149 | loc_dict[candidate] = loc + width 150 | 151 | if len(sim_dict) > 0: 152 | counter += 1 153 | 154 | if counter > max_candidate_num: 155 | break 156 | 157 | width += 1 158 | 159 | if width - len(_pred.split()) > 5: 160 | break 161 | 162 | if force_stop: 163 | break 164 | 165 | if len(sim_dict) != 0: 166 | ############################################################# 167 | # Check missing words between prev pred and current pred 168 | ############################################################# 169 | 170 | if prev_key is not None: 171 | cur_idx = int(key.rsplit('.', 2)[-2]) 172 | prev_idx = int(prev_key.rsplit('.', 2)[-2]) 173 | 174 | if cur_idx - prev_idx > 10: 175 | force_stop = True 176 | break 177 | 178 | # word alinged based on prediction but may contain missing words 179 | # because google speech recognition sometimes skip one or two word 180 | # ex. 
('오누이는 서로 자기가 할 일을 정했다.', '서로 자기가 할 일을 정했다.') 181 | original_candidate = new_candidate = get_argmax_key(sim_dict) 182 | 183 | word_to_find = original_candidate.split()[0] 184 | 185 | if not prev_empty_skip: 186 | search_idx = book_words[prev_end_loc:].index(word_to_find) \ 187 | if word_to_find in book_words[prev_end_loc:] else -1 188 | 189 | if 0 < search_idx < 4 and not prev_sentence_ended: 190 | words_to_check = book_words[prev_end_loc:prev_end_loc + search_idx] 191 | 192 | if ends_with_punctuation(words_to_check[0]) == True: 193 | tmp = " ".join([new_preds[prev_key]] + words_to_check[:1]) 194 | if debug: print(prev_key, tmp, new_preds[prev_key]) 195 | new_preds[prev_key] = tmp 196 | 197 | prev_end_loc += 1 198 | prev_sentence_ended = True 199 | 200 | search_idx = book_words[prev_end_loc:].index(word_to_find) \ 201 | if word_to_find in book_words[prev_end_loc:] else -1 202 | 203 | if 0 < search_idx < 4 and prev_sentence_ended: 204 | words_to_check = book_words[prev_end_loc:prev_end_loc + search_idx] 205 | 206 | if not any(ends_with_punctuation(word) for word in words_to_check): 207 | new_candidate = " ".join(words_to_check + [original_candidate]) 208 | if debug: print(key, new_candidate, original_candidate) 209 | 210 | new_preds[key] = new_candidate 211 | prev_sentence_ended = ends_with_punctuation(new_candidate) 212 | 213 | loc = loc_dict[original_candidate] 214 | prev_key = key 215 | prev_not_found_skip = False 216 | else: 217 | loc += len(_pred.split()) - 1 218 | prev_sentence_ended = True 219 | prev_not_found_skip = True 220 | 221 | prev_end_loc = loc 222 | prev_empty_skip = False 223 | 224 | if debug: 225 | print("=", pred) 226 | print("=", new_preds[key], loc) 227 | 228 | if force_stop: 229 | print(" [!] Force stop: {}".format(path)) 230 | 231 | align_diff = loc - len(book_words) 232 | 233 | if abs(align_diff) > 10: 234 | print(" => Align result of {}: {} - {} = {}".format(path, loc, len(book_words), align_diff)) 235 | 236 | ####################################### 237 | # find exact match of n-gram of pred 238 | ####################################### 239 | 240 | finished_ids = [] 241 | 242 | keys, preds = get_key_value_sorted(new_preds) 243 | 244 | if abs(align_diff) > 10: 245 | keys, preds = keys[:-30], preds[:-30] 246 | 247 | unfinished_ids = range(len(keys)) 248 | text_matches = [] 249 | 250 | for n in range(max_n, 1, -1): 251 | ngram_preds = ngrams(preds, n) 252 | 253 | for n_allow_missing in range(0, max_allow_missing_when_matching + 1): 254 | unfinished_ids = list(set(unfinished_ids) - set(finished_ids)) 255 | 256 | existing_ngram_preds = [] 257 | 258 | for ngram in ngram_preds: 259 | for text in book_texts: 260 | candidates = [ 261 | " ".join(text.split()[:-n_allow_missing]), 262 | " ".join(text.split()[n_allow_missing:]), 263 | ] 264 | for tmp_text in candidates: 265 | if " ".join(ngram) == tmp_text: 266 | existing_ngram_preds.append(ngram) 267 | break 268 | 269 | tmp_keys = [] 270 | cur_ngram = [] 271 | 272 | ngram_idx = 0 273 | ngram_found = False 274 | 275 | for id_idx in unfinished_ids: 276 | key, pred = keys[id_idx], preds[id_idx] 277 | 278 | if ngram_idx >= len(existing_ngram_preds): 279 | break 280 | 281 | cur_ngram = existing_ngram_preds[ngram_idx] 282 | 283 | if pred in cur_ngram: 284 | ngram_found = True 285 | 286 | tmp_keys.append(key) 287 | finished_ids.append(id_idx) 288 | 289 | if len(tmp_keys) == len(cur_ngram): 290 | if debug: print(n_allow_missing, tmp_keys, cur_ngram) 291 | 292 | tmp_keys = get_continuous_audio_paths(tmp_keys, debug) 293 | 
text_matches.append( 294 | [[" ".join(cur_ngram)], tmp_keys] 295 | ) 296 | 297 | ngram_idx += 1 298 | tmp_keys = [] 299 | cur_ngram = [] 300 | else: 301 | if pred == cur_ngram[-1]: 302 | ngram_idx += 1 303 | tmp_keys = [] 304 | cur_ngram = [] 305 | else: 306 | if len(tmp_keys) > 0: 307 | ngram_found = False 308 | 309 | tmp_keys = [] 310 | cur_ngram = [] 311 | 312 | for id_idx in range(len(keys)): 313 | if id_idx not in finished_ids: 314 | key, pred = keys[id_idx], preds[id_idx] 315 | 316 | text_matches.append( 317 | [[pred], [key]] 318 | ) 319 | 320 | ############################################################## 321 | # ngram again for just in case after adding missing words 322 | ############################################################## 323 | 324 | max_keys = [max(get_audio_ids_from_paths(item[1], as_int=True)) for item in text_matches] 325 | sorted_text_matches = \ 326 | [item for _, item in sorted(zip(max_keys, text_matches))] 327 | 328 | preds = [item[0][0] for item in sorted_text_matches] 329 | keys = [item[1] for item in sorted_text_matches] 330 | 331 | def book_sentence_idx_search(query, book_texts): 332 | for idx, text in enumerate(book_texts): 333 | if query in text: 334 | return idx, text 335 | return False, False 336 | 337 | text_matches = [] 338 | idx, book_cursor_idx = 0, 0 339 | 340 | if len(preds) == 0: 341 | return [] 342 | 343 | while True: 344 | tmp_texts = book_texts[book_cursor_idx:] 345 | 346 | jdx = 0 347 | tmp_pred = preds[idx] 348 | idxes_to_merge = [idx] 349 | 350 | prev_sent_idx, prev_sent = book_sentence_idx_search(tmp_pred, tmp_texts) 351 | while idx + jdx + 1 < len(preds): 352 | jdx += 1 353 | 354 | tmp_pred = preds[idx + jdx] 355 | sent_idx, sent = book_sentence_idx_search(tmp_pred, tmp_texts) 356 | 357 | if not sent_idx: 358 | if debug: print(" [!] 
NOT FOUND: {}".format(tmp_pred)) 359 | break 360 | 361 | if prev_sent_idx == sent_idx: 362 | idxes_to_merge.append(idx + jdx) 363 | else: 364 | break 365 | 366 | new_keys = get_continuous_audio_paths( 367 | sum([keys[jdx] for jdx in idxes_to_merge], [])) 368 | text_matches.append([ [tmp_texts[prev_sent_idx]], new_keys ]) 369 | 370 | if len(new_keys) > 1: 371 | book_cursor_idx += 1 372 | 373 | book_cursor_idx = max(book_cursor_idx, sent_idx) 374 | 375 | if idx == len(preds) - 1: 376 | break 377 | idx = idx + jdx 378 | 379 | # Counter([len(i) for i in text_matches.values()]) 380 | return text_matches 381 | 382 | def get_text_from_audio_batch(paths, multi_process=False): 383 | results = {} 384 | items = parallel_run(get_text_from_audio, paths, 385 | desc="get_text_from_audio_batch") 386 | for item in items: 387 | results.update(item) 388 | return results 389 | 390 | def get_text_from_audio(path): 391 | error_count = 0 392 | 393 | txt_path = path.replace('flac', 'txt') 394 | 395 | if os.path.exists(txt_path): 396 | with open(txt_path) as f: 397 | out = json.loads(f.read()) 398 | return out 399 | 400 | out = {} 401 | while True: 402 | try: 403 | client = speech.SpeechClient() 404 | 405 | with io.open(path, 'rb') as audio_file: 406 | content = audio_file.read() 407 | audio = types.RecognitionAudio(content=content) 408 | 409 | config = types.RecognitionConfig( 410 | encoding=enums.RecognitionConfig.AudioEncoding.FLAC, 411 | sample_rate_hertz=16000, 412 | language_code='ko-KR') 413 | 414 | response = client.recognize(config, audio) 415 | if len(response.results) > 0: 416 | alternatives = response.results[0].alternatives 417 | 418 | results = [alternative.transcript for alternative in alternatives] 419 | assert len(results) == 1, "More than 1 result: {}".format(results) 420 | 421 | out = { path: "" if len(results) == 0 else results[0] } 422 | print(results[0]) 423 | break 424 | break 425 | except Exception: 426 | error_count += 1 427 | print("Skip warning for {} for {} times".
\ 428 | format(path, error_count)) 429 | 430 | if error_count > 5: 431 | break 432 | else: 433 | continue 434 | 435 | with open(txt_path, 'w') as f: 436 | json.dump(out, f, indent=2, ensure_ascii=False) 437 | 438 | return out 439 | 440 | if __name__ == '__main__': 441 | parser = argparse.ArgumentParser() 442 | parser.add_argument('--asset-dir', type=str, default='assets') 443 | parser.add_argument('--data-dir', type=str, default='audio') 444 | parser.add_argument('--pattern', type=str, default="audio/*.flac") 445 | parser.add_argument('--metadata', type=str, default="metadata.json") 446 | config, unparsed = parser.parse_known_args() 447 | 448 | paths = glob(config.pattern) 449 | paths.sort() 450 | paths = paths 451 | 452 | book_ids = list(set([ 453 | os.path.basename(path).split('.', 1)[0] for path in paths])) 454 | book_ids.sort() 455 | 456 | def get_finished_ids(): 457 | finished_paths = glob(os.path.join( 458 | config.asset_dir, "speech-*.json")) 459 | finished_ids = list(set([ 460 | os.path.basename(path).split('.', 1)[0].replace("speech-", "") for path in finished_paths])) 461 | finished_ids.sort() 462 | return finished_ids 463 | 464 | finished_ids = get_finished_ids() 465 | 466 | print("# Finished : {}/{}".format(len(finished_ids), len(book_ids))) 467 | 468 | book_ids_to_parse = list(set(book_ids) - set(finished_ids)) 469 | book_ids_to_parse.sort() 470 | 471 | assert os.path.exists(config.asset_dir), "assert_dir not found" 472 | 473 | pbar = tqdm(book_ids_to_parse, "[1] google_speech", 474 | initial=len(finished_ids), total=len(book_ids)) 475 | 476 | for book_id in pbar: 477 | current_paths = glob(config.pattern.replace("*", "{}.*".format(book_id))) 478 | pbar.set_description("[1] google_speech : {}".format(book_id)) 479 | 480 | results = get_text_from_audio_batch(current_paths) 481 | 482 | filename = "speech-{}.json".format(book_id) 483 | path = os.path.join(config.asset_dir, filename) 484 | 485 | with open(path, "w") as f: 486 | json.dump(results, f, indent=2, ensure_ascii=False) 487 | 488 | finished_ids = get_finished_ids() 489 | 490 | for book_id in tqdm(finished_ids, "[2] text_match"): 491 | filename = "speech-{}.json".format(book_id) 492 | path = os.path.join(config.asset_dir, filename) 493 | clean_path = path.replace("speech", "clean-speech") 494 | 495 | if os.path.exists(clean_path): 496 | print(" [*] Skip {}".format(clean_path)) 497 | else: 498 | results = replace_pred_with_book(path) 499 | with open(clean_path, "w") as f: 500 | json.dump(results, f, indent=2, ensure_ascii=False) 501 | 502 | # Dummy 503 | 504 | if False: 505 | match_paths = get_paths_by_pattern( 506 | config.asset_dir, 'clean-speech-*.json') 507 | 508 | metadata_path = os.path.join(config.data_dir, config.metadata) 509 | 510 | print(" [3] Merge clean-speech-*.json into {}".format(metadata_path)) 511 | 512 | merged_data = [] 513 | for path in match_paths: 514 | with open(path) as f: 515 | merged_data.extend(json.loads(f.read())) 516 | 517 | import ipdb; ipdb.set_trace() 518 | 519 | with open(metadata_path, 'w') as f: 520 | json.dump(merged_data, f, indent=2, ensure_ascii=False) 521 | -------------------------------------------------------------------------------- /models/rnn_wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.contrib.rnn import RNNCell 4 | from tensorflow.python.ops import rnn_cell_impl 5 | from tensorflow.contrib.data.python.util import nest 6 | # from tensorflow.contrib.framework 
import nest 7 | from tensorflow.contrib.seq2seq.python.ops.attention_wrapper \ 8 | import _bahdanau_score, _BaseAttentionMechanism, \ 9 | BahdanauAttention, AttentionWrapperState, AttentionMechanism 10 | # _BaseMonotonicAttentionMechanism, _maybe_mask_score, _prepare_memory, \ 11 | # _monotonic_probability_fn 12 | # from tensorflow.python.ops import array_ops, math_ops, nn_ops, variable_scope 13 | from tensorflow.python.layers import core as layers_core  # Dense layer used when attention_layer_size is given 14 | from .modules import prenet 15 | import functools 16 | 17 | _zero_state_tensors = rnn_cell_impl._zero_state_tensors 18 | 19 | 20 | class AttentionWrapper(RNNCell): 21 | """Wraps another `RNNCell` with attention. 22 | """ 23 | 24 | def __init__(self, cell, attention_mechanism, is_manual_attention, manual_alignments, 25 | attention_layer_size=None, alignment_history=False, 26 | cell_input_fn=None, output_attention=True, initial_cell_state=None, name=None): 27 | """Construct the `AttentionWrapper`. 28 | Args: 29 | cell: An instance of `RNNCell`. 30 | attention_mechanism: A list of `AttentionMechanism` instances or a single 31 | instance. 32 | attention_layer_size: A list of Python integers or a single Python 33 | integer, the depth of the attention (output) layer(s). If None 34 | (default), use the context as attention at each time step. Otherwise, 35 | feed the context and cell output into the attention layer to generate 36 | attention at each time step. If attention_mechanism is a list, 37 | attention_layer_size must be a list of the same length. 38 | alignment_history: Python boolean, whether to store alignment history 39 | from all time steps in the final output state (currently stored as a 40 | time major `TensorArray` on which you must call `stack()`). 41 | cell_input_fn: (optional) A `callable`. The default is: 42 | `lambda inputs, attention: tf.concat([inputs, attention], -1)`. 43 | output_attention: Python bool. If `True` (default), the output at each 44 | time step is the attention value. This is the behavior of Luong-style 45 | attention mechanisms. If `False`, the output at each time step is 46 | the output of `cell`. This is the behavior of Bahdanau-style 47 | attention mechanisms. In both cases, the `attention` tensor is 48 | propagated to the next time step via the state and is used there. 49 | This flag only controls whether the attention mechanism is propagated 50 | up to the next cell in an RNN stack or to the top RNN output. 51 | initial_cell_state: The initial state value to use for the cell when 52 | the user calls `zero_state()`. Note that if this value is provided 53 | now, and the user uses a `batch_size` argument of `zero_state` which 54 | does not match the batch size of `initial_cell_state`, proper 55 | behavior is not guaranteed. 56 | name: Name to use when creating ops. 57 | Raises: 58 | TypeError: `attention_layer_size` is not None and (`attention_mechanism` 59 | is a list but `attention_layer_size` is not; or vice versa). 60 | ValueError: if `attention_layer_size` is not None, `attention_mechanism` 61 | is a list, and its length does not match that of `attention_layer_size`.
62 | """ 63 | super(AttentionWrapper, self).__init__(name=name) 64 | 65 | self.is_manual_attention = is_manual_attention 66 | self.manual_alignments = manual_alignments 67 | 68 | if isinstance(attention_mechanism, (list, tuple)): 69 | self._is_multi = True 70 | attention_mechanisms = attention_mechanism 71 | for attention_mechanism in attention_mechanisms: 72 | if not isinstance(attention_mechanism, AttentionMechanism): 73 | raise TypeError( 74 | "attention_mechanism must contain only instances of " 75 | "AttentionMechanism, saw type: %s" 76 | % type(attention_mechanism).__name__) 77 | else: 78 | self._is_multi = False 79 | if not isinstance(attention_mechanism, AttentionMechanism): 80 | raise TypeError( 81 | "attention_mechanism must be an AttentionMechanism or list of " 82 | "multiple AttentionMechanism instances, saw type: %s" 83 | % type(attention_mechanism).__name__) 84 | attention_mechanisms = (attention_mechanism,) 85 | 86 | if cell_input_fn is None: 87 | cell_input_fn = ( 88 | lambda inputs, attention: tf.concat([inputs, attention], -1)) 89 | else: 90 | if not callable(cell_input_fn): 91 | raise TypeError( 92 | "cell_input_fn must be callable, saw type: %s" 93 | % type(cell_input_fn).__name__) 94 | 95 | if attention_layer_size is not None: 96 | attention_layer_sizes = tuple( 97 | attention_layer_size 98 | if isinstance(attention_layer_size, (list, tuple)) 99 | else (attention_layer_size,)) 100 | if len(attention_layer_sizes) != len(attention_mechanisms): 101 | raise ValueError( 102 | "If provided, attention_layer_size must contain exactly one " 103 | "integer per attention_mechanism, saw: %d vs %d" 104 | % (len(attention_layer_sizes), len(attention_mechanisms))) 105 | self._attention_layers = tuple( 106 | layers_core.Dense( 107 | attention_layer_size, name="attention_layer", use_bias=False) 108 | for attention_layer_size in attention_layer_sizes) 109 | self._attention_layer_size = sum(attention_layer_sizes) 110 | else: 111 | self._attention_layers = None 112 | self._attention_layer_size = sum( 113 | attention_mechanism.values.get_shape()[-1].value 114 | for attention_mechanism in attention_mechanisms) 115 | 116 | self._cell = cell 117 | self._attention_mechanisms = attention_mechanisms 118 | self._cell_input_fn = cell_input_fn 119 | self._output_attention = output_attention 120 | self._alignment_history = alignment_history 121 | with tf.name_scope(name, "AttentionWrapperInit"): 122 | if initial_cell_state is None: 123 | self._initial_cell_state = None 124 | else: 125 | final_state_tensor = nest.flatten(initial_cell_state)[-1] 126 | state_batch_size = ( 127 | final_state_tensor.shape[0].value 128 | or tf.shape(final_state_tensor)[0]) 129 | error_message = ( 130 | "When constructing AttentionWrapper %s: " % self._base_name + 131 | "Non-matching batch sizes between the memory " 132 | "(encoder output) and initial_cell_state. Are you using " 133 | "the BeamSearchDecoder? 
You may need to tile your initial state " 134 | "via the tf.contrib.seq2seq.tile_batch function with argument " 135 | "multiple=beam_width.") 136 | with tf.control_dependencies( 137 | self._batch_size_checks(state_batch_size, error_message)): 138 | self._initial_cell_state = nest.map_structure( 139 | lambda s: tf.identity(s, name="check_initial_cell_state"), 140 | initial_cell_state) 141 | 142 | def _batch_size_checks(self, batch_size, error_message): 143 | return [tf.assert_equal(batch_size, 144 | attention_mechanism.batch_size, 145 | message=error_message) 146 | for attention_mechanism in self._attention_mechanisms] 147 | 148 | def _item_or_tuple(self, seq): 149 | """Returns `seq` as tuple or the singular element. 150 | Which is returned is determined by how the AttentionMechanism(s) were passed 151 | to the constructor. 152 | Args: 153 | seq: A non-empty sequence of items or generator. 154 | Returns: 155 | Either the values in the sequence as a tuple if AttentionMechanism(s) 156 | were passed to the constructor as a sequence or the singular element. 157 | """ 158 | t = tuple(seq) 159 | if self._is_multi: 160 | return t 161 | else: 162 | return t[0] 163 | 164 | @property 165 | def output_size(self): 166 | if self._output_attention: 167 | return self._attention_layer_size 168 | else: 169 | return self._cell.output_size 170 | 171 | @property 172 | def state_size(self): 173 | return AttentionWrapperState( 174 | cell_state=self._cell.state_size, 175 | time=tf.TensorShape([]), 176 | attention=self._attention_layer_size, 177 | alignments=self._item_or_tuple( 178 | a.alignments_size for a in self._attention_mechanisms), 179 | alignment_history=self._item_or_tuple( 180 | () for _ in self._attention_mechanisms)) # sometimes a TensorArray 181 | 182 | def zero_state(self, batch_size, dtype): 183 | with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 184 | if self._initial_cell_state is not None: 185 | cell_state = self._initial_cell_state 186 | else: 187 | cell_state = self._cell.zero_state(batch_size, dtype) 188 | error_message = ( 189 | "When calling zero_state of AttentionWrapper %s: " % self._base_name + 190 | "Non-matching batch sizes between the memory " 191 | "(encoder output) and the requested batch size. Are you using " 192 | "the BeamSearchDecoder? If so, make sure your encoder output has " 193 | "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and " 194 | "the batch_size= argument passed to zero_state is " 195 | "batch_size * beam_width.") 196 | with tf.control_dependencies( 197 | self._batch_size_checks(batch_size, error_message)): 198 | cell_state = nest.map_structure( 199 | lambda s: tf.identity(s, name="checked_cell_state"), 200 | cell_state) 201 | 202 | return AttentionWrapperState( 203 | cell_state=cell_state, 204 | time=tf.zeros([], dtype=tf.int32), 205 | attention=_zero_state_tensors(self._attention_layer_size, batch_size, dtype), 206 | alignments=self._item_or_tuple( 207 | attention_mechanism.initial_alignments(batch_size, dtype) 208 | for attention_mechanism in self._attention_mechanisms), 209 | alignment_history=self._item_or_tuple( 210 | tf.TensorArray(dtype=dtype, size=0, dynamic_size=True) 211 | if self._alignment_history else () 212 | for _ in self._attention_mechanisms)) 213 | 214 | def call(self, inputs, state): 215 | """Perform a step of attention-wrapped RNN. 216 | - Step 1: Mix the `inputs` and previous step's `attention` output via 217 | `cell_input_fn`. 218 | - Step 2: Call the wrapped `cell` with this input and its previous state. 
219 | - Step 3: Score the cell's output with `attention_mechanism`. 220 | - Step 4: Calculate the alignments by passing the score through the 221 | `normalizer`. 222 | - Step 5: Calculate the context vector as the inner product between the 223 | alignments and the attention_mechanism's values (memory). 224 | - Step 6: Calculate the attention output by concatenating the cell output 225 | and context through the attention layer (a linear layer with 226 | `attention_layer_size` outputs). 227 | Args: 228 | inputs: (Possibly nested tuple of) Tensor, the input at this time step. 229 | state: An instance of `AttentionWrapperState` containing 230 | tensors from the previous time step. 231 | Returns: 232 | A tuple `(attention_or_cell_output, next_state)`, where: 233 | - `attention_or_cell_output` depending on `output_attention`. 234 | - `next_state` is an instance of `AttentionWrapperState` 235 | containing the state calculated at this time step. 236 | Raises: 237 | TypeError: If `state` is not an instance of `AttentionWrapperState`. 238 | """ 239 | if not isinstance(state, AttentionWrapperState): 240 | raise TypeError("Expected state to be instance of AttentionWrapperState. " 241 | "Received type %s instead." % type(state)) 242 | 243 | # Step 1: Calculate the true inputs to the cell based on the 244 | # previous attention value. 245 | cell_inputs = self._cell_input_fn(inputs, state.attention) 246 | cell_state = state.cell_state 247 | cell_output, next_cell_state = self._cell(cell_inputs, cell_state) 248 | 249 | cell_batch_size = ( 250 | cell_output.shape[0].value or tf.shape(cell_output)[0]) 251 | error_message = ( 252 | "When applying AttentionWrapper %s: " % self.name + 253 | "Non-matching batch sizes between the memory " 254 | "(encoder output) and the query (decoder output). Are you using " 255 | "the BeamSearchDecoder? 
You may need to tile your memory input via " 256 | "the tf.contrib.seq2seq.tile_batch function with argument " 257 | "multiple=beam_width.") 258 | with tf.control_dependencies( 259 | self._batch_size_checks(cell_batch_size, error_message)): 260 | cell_output = tf.identity( 261 | cell_output, name="checked_cell_output") 262 | 263 | if self._is_multi: 264 | previous_alignments = state.alignments 265 | previous_alignment_history = state.alignment_history 266 | else: 267 | previous_alignments = [state.alignments] 268 | previous_alignment_history = [state.alignment_history] 269 | 270 | all_alignments = [] 271 | all_attentions = [] 272 | all_histories = [] 273 | 274 | for i, attention_mechanism in enumerate(self._attention_mechanisms): 275 | attention, alignments = _compute_attention( 276 | attention_mechanism, cell_output, previous_alignments[i], 277 | self._attention_layers[i] if self._attention_layers else None, 278 | self.is_manual_attention, self.manual_alignments, state.time) 279 | 280 | alignment_history = previous_alignment_history[i].write( 281 | state.time, alignments) if self._alignment_history else () 282 | 283 | all_alignments.append(alignments) 284 | all_histories.append(alignment_history) 285 | all_attentions.append(attention) 286 | 287 | attention = tf.concat(all_attentions, 1) 288 | next_state = AttentionWrapperState( 289 | time=state.time + 1, 290 | cell_state=next_cell_state, 291 | attention=attention, 292 | alignments=self._item_or_tuple(all_alignments), 293 | alignment_history=self._item_or_tuple(all_histories)) 294 | 295 | if self._output_attention: 296 | return attention, next_state 297 | else: 298 | return cell_output, next_state 299 | 300 | 301 | def _compute_attention( 302 | attention_mechanism, cell_output, previous_alignments, 303 | attention_layer, is_manual_attention, manual_alignments, time): 304 | 305 | computed_alignments = attention_mechanism( 306 | cell_output, previous_alignments=previous_alignments) 307 | batch_size, max_time = \ 308 | tf.shape(computed_alignments)[0], tf.shape(computed_alignments)[1] 309 | 310 | alignments = tf.cond( 311 | is_manual_attention, 312 | lambda: manual_alignments[:, time, :], 313 | lambda: computed_alignments, 314 | ) 315 | 316 | #alignments = tf.one_hot(tf.zeros((batch_size,), dtype=tf.int32), max_time, dtype=tf.float32) 317 | 318 | # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] 319 | expanded_alignments = tf.expand_dims(alignments, 1) 320 | 321 | # Context is the inner product of alignments and values along the 322 | # memory time dimension. 323 | # alignments shape is 324 | # [batch_size, 1, memory_time] 325 | # attention_mechanism.values shape is 326 | # [batch_size, memory_time, memory_size] 327 | # the batched matmul is over memory_time, so the output shape is 328 | # [batch_size, 1, memory_size]. 329 | # we then squeeze out the singleton dim. 
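# Shape example for the context computation below (illustrative sizes only: batch_size=32, 145 encoder steps, 256 encoder units): expanded_alignments is [32, 1, 145], attention_mechanism.values is [32, 145, 256], their batched matmul is [32, 1, 256], and squeezing the singleton dim leaves the [32, 256] context vector.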
330 | context = tf.matmul(expanded_alignments, attention_mechanism.values) 331 | context = tf.squeeze(context, [1]) 332 | 333 | if attention_layer is not None: 334 | attention = attention_layer(tf.concat([cell_output, context], 1)) 335 | else: 336 | attention = context 337 | 338 | return attention, alignments 339 | 340 | 341 | 342 | class DecoderPrenetWrapper(RNNCell): 343 | '''Runs RNN inputs through a prenet before sending them to the cell.''' 344 | 345 | def __init__(self, cell, embed_to_concat, is_training, prenet_sizes, dropout_prob): 346 | super(DecoderPrenetWrapper, self).__init__() 347 | self._cell = cell 348 | self._is_training = is_training 349 | self._embed_to_concat = embed_to_concat 350 | self.prenet_sizes = prenet_sizes 351 | self.dropout_prob = dropout_prob 352 | 353 | @property 354 | def state_size(self): 355 | return self._cell.state_size 356 | 357 | @property 358 | def output_size(self): 359 | return self._cell.output_size 360 | 361 | def call(self, inputs, state): 362 | prenet_out = prenet(inputs, self._is_training, self.prenet_sizes, self.dropout_prob, scope='decoder_prenet') 363 | 364 | if self._embed_to_concat is not None: 365 | concat_out = tf.concat([prenet_out, self._embed_to_concat], axis=-1, name='speaker_concat') 366 | return self._cell(concat_out, state) 367 | else: 368 | return self._cell(prenet_out, state) 369 | 370 | def zero_state(self, batch_size, dtype): 371 | return self._cell.zero_state(batch_size, dtype) 372 | 373 | 374 | class ConcatOutputAndAttentionWrapper(RNNCell): 375 | '''Concatenates RNN cell output with the attention context vector. 376 | 377 | This is expected to wrap a cell wrapped with an AttentionWrapper constructed with 378 | attention_layer_size=None and output_attention=False. Such a cell's state will include an 379 | "attention" field that is the context vector. 380 | ''' 381 | 382 | def __init__(self, cell, embed_to_concat): 383 | super(ConcatOutputAndAttentionWrapper, self).__init__() 384 | self._cell = cell 385 | self._embed_to_concat = embed_to_concat 386 | 387 | 388 | @property 389 | def state_size(self): 390 | return self._cell.state_size 391 | 392 | @property 393 | def output_size(self): 394 | return self._cell.output_size + self._cell.state_size.attention 395 | 396 | def call(self, inputs, state): 397 | output, res_state = self._cell(inputs, state) 398 | 399 | if self._embed_to_concat is not None: 400 | tensors = [ 401 | output, res_state.attention, 402 | self._embed_to_concat, 403 | ] 404 | return tf.concat(tensors, axis=-1), res_state 405 | else: 406 | return tf.concat([output, res_state.attention], axis=-1), res_state 407 | 408 | def zero_state(self, batch_size, dtype): 409 | return self._cell.zero_state(batch_size, dtype) 410 | --------------------------------------------------------------------------------
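For reference, the following is a minimal sketch of how the three wrappers above can be chained into a Tacotron-style decoder cell. It is an illustration only, not code taken from tacotron.py: the 256-unit GRU, the prenet sizes [256, 128], the 0.5 dropout and the placeholder names (encoder_outputs, speaker_embed, ...) are assumed values.

# Illustrative sketch; the sizes and names below are assumptions, not values from hparams.py.
import tensorflow as tf
from tensorflow.contrib.seq2seq import BahdanauAttention

from models.rnn_wrappers import (
    AttentionWrapper, ConcatOutputAndAttentionWrapper, DecoderPrenetWrapper)

# Control inputs expected by AttentionWrapper for optional manual attention.
is_manual_attention = tf.placeholder(tf.bool, shape=(), name='is_manual_attention')
manual_alignments = tf.placeholder(tf.float32, [None, None, None], name='manual_alignments')

# Stand-ins for the encoder output and the (optional) speaker embedding.
encoder_outputs = tf.placeholder(tf.float32, [None, None, 256], name='encoder_outputs')
speaker_embed = None  # or a [N, speaker_embedding_size] tensor for multi-speaker models

attention_mechanism = BahdanauAttention(256, encoder_outputs)

# Prenet -> attention RNN. With output_attention=False the wrapper returns the raw
# GRU output and keeps the attention context in the state's "attention" field.
attention_cell = AttentionWrapper(
    DecoderPrenetWrapper(tf.nn.rnn_cell.GRUCell(256), speaker_embed,
                         is_training=True, prenet_sizes=[256, 128], dropout_prob=0.5),
    attention_mechanism,
    is_manual_attention, manual_alignments,
    alignment_history=True,
    output_attention=False)

# Concatenate the RNN output with the attention context (256 + 256 = 512 per step),
# the layout ConcatOutputAndAttentionWrapper's docstring asks for.
decoder_cell = ConcatOutputAndAttentionWrapper(attention_cell, speaker_embed)

Under these assumptions decoder_cell.output_size is 512. Feeding is_manual_attention as True makes _compute_attention use the supplied manual_alignments instead of the computed ones; compare get_dummy_feed_dict in tacotron.py, which feeds False and a dummy zero array.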