├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── img ├── draw_embd_sim.py ├── gen_hfg.wav ├── gen_rtg.wav ├── gen_tts.wav ├── gt_hfg.wav ├── gt_rtg.wav ├── gt_tts.wav ├── hifi-gan.png ├── rtg_D.png ├── rtg_G.png ├── rtg_arch.png ├── rtg_cmp.png ├── tacotron-cn.mini.step-48000-embed_cosine.png ├── tacotron-cn.png ├── tacotron-en.png ├── transtacos.mini.png ├── transtacos.mini.step-48000_prds_embd_cosine.png ├── transtacos.mini.step-48000_text_embd_cosine.png ├── tts.png ├── tts_arch.png ├── tts_decoder.png ├── tts_embed.png ├── tts_encoder.png ├── tts_out_align.png ├── tts_out_spec.png ├── tts_posnet.png └── y_tmpl.wav ├── index.html ├── requirements.txt ├── retunegan ├── Makefile ├── audio.py ├── audio_proxy.py ├── data.py ├── hparam.py ├── infer.py ├── models │ ├── __init__.py │ ├── discrminator.py │ ├── generator.py │ └── loss.py ├── server.py ├── tools │ ├── spec2wavset.py │ ├── test_downsample.py │ ├── test_envolope.py │ ├── test_griffinlim.py │ ├── test_istft_iter.py │ ├── test_pesq.py │ ├── test_phase_recover.py │ └── test_strip_mirror.py ├── train.py └── utils.py ├── stats ├── DataBaker-lexicon.txt ├── DataBaker-stats.txt ├── DataBaker.stats ├── DataBaker_gen_stat.py ├── DataBaker_print_pinyins.py ├── DataBaker_print_symbols.py ├── inspect_preproc.py ├── inspect_spec.py ├── thchs30-lexicon.txt ├── thchs30_gen_vbanks.py └── thchs30_print_symbols.py └── transtacos ├── Makefile ├── audio.py ├── data.py ├── datasets ├── __skel__.py ├── databaker.py └── thchs30.py ├── hparam.py ├── models ├── attention.py ├── custom_decoder.py ├── modules.py ├── rnn_wrappers.py └── tacotron.py ├── preprocess.py ├── server.py ├── synth.py ├── text ├── __init__.py ├── g2p.py ├── phonodict_cn.csv ├── phonodict_cn.py ├── phonodict_cn.txt ├── symbols.py └── text.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # project meta 2 | .vscode/ 3 | 4 | # cache 5 | __pycache__/ 6 | .cache/ 7 | *.py[cod] 8 | 9 | # misc 10 | ref/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Armit 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TransTacoS-RetuneGAN

A lighter-weight (perhaps!) Text-to-Speech system for Chinese/Mandarin synthesis, inspired by Tacotron & FastSpeech2 & RefineGAN.
It is also my shitty graduation design project, just a toy, so lower your expectations :)

----

## Quick Start

### setup

Since `TransTacoS` is implemented in TensorFlow while `RetuneGAN` is in PyTorch, you could separate them by creating virtual envs; however, they are unlikely to conflict, so you can also try putting everything together:

- install `tensorflow-gpu==1.14.0 tensorboard==1.14.0` following `https://tensorflow.google.cn/install/pip`
- install `torch==1.8.0+cu1xx torchaudio==0.8.0` following `https://pytorch.org/`, where `cu1xx` is your cuda version
- run `pip install -r requirements.txt` for the remaining dependencies

### dataset

- download and unzip the open-source dataset [DataBaker](https://www.data-baker.com/data/index/TNtts)
- other datasets require a user-defined preprocessor, please refer to `transtacos/datasets/__skel__.py`

### train

- check the path configs in all `Makefile`s
- `cd transtacos && make preprocess` to prepare acoustic features (linear/mel/f0/c0/zcr)
- `cd transtacos && make train` to train TransTacoS
- `cd retunegan && make finetune` to train RetuneGAN using the preprocessed linear spectrograms (rather than from raw waves)

### deploy

- check the port configs in all `Makefile`s
- `cd transtacos && make server` to start the TransTacoS headless HTTP server (default at port 5105)
- `cd retunegan && make server` to start the RetuneGAN headless HTTP server (default at port 5104)
- `python app.py` to start the WebUI app (default at port 5103)
- point your browser to `http://localhost:5103`, now have a try!

## Model Architecture

### TransTacoS

![align](/img/tts_out_align.png)

![spec](/img/tts_out_spec.png)

#### What I actually did:

- TransTacoS :=
  - embed: self-designed G2P solution (syl4), employing prosody marks as linguistic features
    - for the G2P solution, I tried char(seq), phoneme, pinyin (CV), CVC, VCV, CVVC ...
    - in fact, they barely influence mel_loss, but they do affect how controllable the pronunciation is
    - so I designed **syl4** to split phonemes from tones (just a small improvement); a rough sketch follows below
    - I predict prosody marks with a simple CNN rather than an RNN, which is not entirely sound though..
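    A rough, hypothetical sketch of the syl4 idea: split a numbered-pinyin syllable into initial / final / tone, so that tone gets its own embedding entry instead of multiplying the phoneme inventory. The names and the initial table below are purely illustrative; the real table-driven G2P lives in `transtacos/text/g2p.py`, and its actual syl4 unit inventory differs:

    ```python
    # illustrative only: NOT the repo's actual G2P tables
    _INITIALS = ['zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l',
                 'g', 'k', 'h', 'j', 'q', 'x', 'r', 'z', 'c', 's', 'y', 'w']

    def split_syllable(syl: str):
        tone = syl[-1] if syl[-1].isdigit() else '5'    # '5' = neutral tone
        base = syl[:-1] if syl[-1].isdigit() else syl
        for ini in _INITIALS:                           # 'zh/ch/sh' listed first for longest match
            if base.startswith(ini):
                return ini, base[len(ini):], tone
        return '', base, tone                           # zero-initial syllable

    print(split_syllable('zhong1'))   # ('zh', 'ong', '1')
    print(split_syllable('e4'))       # ('', 'e', '4')
    ```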
  - encoder: modified from FastSpeech2
    - I use multi-head self-DotAttn with GFFW (gated-FFW) for the backbone transform
    - f0/c0 features are quantized into embeddings, then fused in with cross-DotAttn
  - decoder: inherited from Tacotron
    - this RNN+LSA decoder is really complicated, I dare not touch it :(
  - posnet: self-designed simple MLP and grouped-Linear
    - I found that a simple Linear layer with growing depth leads to lower mel_loss than Conv1d
    - thus I really DO NOT understand why Tacotron2 and even FastSpeech2 still use a postnet that directly stretches n_channels from 80 (n_mel) to 1025 (n_freq); this is surely too hard to compensate for the information loss

Frankly speaking, TransTacoS didn't improve anything profoundly over Tacotron; I just found that a shallower network leads to lower mel_loss, so maybe a simple embed+decoder is already enough :(

#### Tips on ideas tried or failed:

- To predict `f0/sp/ap` features so that we can use the WORLD vocoder
  - It seems `sp` is OK, because it resembles mel very much
  - but `ap` requires careful normalization, and an accurate `f0` is even harder to predict
  - the audio quality of WORLD sounds worse than Griffin-Lim at times :(
- To predict the spectrogram's magnitude part together with its phase part, so that we can vocode directly using pure `istft`
  - It should be hard to optimize losses on phase, especially at junction points between mel frames
  - but these guys claim that they did it: [iSTFTNet](https://arxiv.org/abs/2203.02395)
- To predict `f0` and `dyn` so that the vocoder might benefit
  - I don't know how to separate them from mel, because so far I must regularize the decoder's output to be mel (for the sake of teacher forcing)
  - When I removed the mel loss, I found that the RNN decoder became lazy to learn, and the alignment module stopped working as well
  - mel is even more quantized than linear, so abstracting `f0` and `dyn` from mel alone seems unreasonable
- To extract duration info from the soft alignment map, so that we can further train a non-autoregressive FastSpeech2 to speed up inference
  - just like [FastSpeech](https://arxiv.org/abs/1905.09263) and [DeepSinger](https://arxiv.org/abs/2007.04590) did
  - but FastSpeech reported that the extracted durations are not accurate enough, and I could not fully understand and reproduce DeepSinger's DP algorithm for duration extraction


### RetuneGAN

![gen_wav_cmp](/img/rtg_cmp.png)


#### What I actually did:

- RetuneGAN :=
  - preprocess
    - extract the `reference wav` using Griffin-Lim
    - extract the `u/v mask` by hand-tuned zcr/c0 thresholds (for Split-G only)
  - generators
    - `UNet-G` (encoder-decoder generator): modified from RefineGAN; we use the output of Griffin-Lim as the reference wav, rather than an F0/C0-guided hand-crafted *speech template*
    - `Split-G` (split u/v generator): self-designed, inspired by Multi-Band MelGAN, but I found the generated quality is holy shit :(
    - `ResStack` borrowed from MelGAN
    - `ResBlock` modified from HiFiGAN
  - discriminators
    - `MSD` (multi scale discriminator): borrowed from MelGAN; I think it's good for plosive consonants
    - `MPD` (multi period discriminator): borrowed from HiFiGAN; I take it as a stack-up of multiple MSDs
    - `MTD` (multi stft discriminator): modified from UnivNet; it has two working modes depending on its input (the MPSD mode indeed seems better ...)
      - `MPSD` (multi parameter spectrogram discriminator): like in UnivNet, but we let it judge both the phase part and the magnitude part
      - `PHD` (phase discriminator): self-designed; it cares more about phase, since `l_mstft` already regularizes the magnitude
      - the input of MPSD is `[(mag_real, phase_real), (mag_fake, phase_fake)]`, so it distinguishes real/fake stft data
      - the input of PHD is `[(mag_real, phase_real), (mag_real, phase_fake)]`, so it ONLY distinguishes real/fake phase
      - (the sketch right after this list shows how these input pairs are assembled)
  - losses
    - `l_adv` (adversarial loss): modified from HiFiGAN, but relativized (see the sketch at the end of this section)
    - `l_fm` (feature map loss): borrowed from MelGAN
    - `l_mstft` (multi stft loss): modified from Parallel WaveGAN, but we calculate mel_loss rather than linear_loss
    - `l_env` (envelope loss): borrowed from RefineGAN
    - `l_dyn` (dynamic loss): self-designed, inspired by `l_env`
    - `l_sm` (strip mirror loss): self-designed, but it might hurt audio quality :(
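To make the two MTD modes concrete, here is a trimmed sketch of how `retunegan/models/loss.py` stacks the 2-channel `[log-magnitude, normalized phase]` discriminator inputs (tensor names are simplified from the actual `multi_stft_loss` code):

```python
import torch

def stack_stft_disc_inputs(log_mag_r, phase_r, log_mag_g, phase_g, mode='stft'):
    # every input is a [B, F, T] spectrogram of one STFT resolution
    real = torch.stack([log_mag_r, phase_r], dim=1)        # [B, 2, F, T]
    if mode == 'stft':      # MPSD: a fully real pair vs a fully fake pair
        fake = torch.stack([log_mag_g, phase_g], dim=1)
    elif mode == 'phase':   # PHD: magnitude held real, so only phase decides real/fake
        fake = torch.stack([log_mag_r, phase_g], dim=1)
    else:
        raise ValueError(mode)
    return real, fake
```

`phd_input` in `retunegan/hparam.py` selects the mode, and one such pair is built per `(n_fft, win_length, hop_length)` entry in `hp.multi_stft_params`.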

Oh my dude, it's really a big stitched-together (feng-he) monster :(

#### Tips on ideas tried or failed:

- To divide the *consonant* part from the *vowel* part in the time domain, then use two generators to synthesize them separately
  - I found this brings breakups at the junction points, so the audio sounds noisy, and the mstft loss also becomes more unstable
- To shallow-fuse the reference wav with the half-decoded mel, rather than using the `encode-merge-decode` UNet architecture
  - I found this makes the generator lazy to learn, and it even overfits the train set
  - might be because a waveform is far from audio semantics but close to raw representation, so an encoder is necessary to extract semantic info
- NOTE: the weight of the mstft loss should NOT be overwhelming compared with the adversarial loss (like it is in HiFiGAN)
  - the adversarial loss leads to clearer plosives (`b/p/g/k/d/t`), while the `mstft loss` contributes little to consonants
  - `hop_length` in stft is much larger than `stride` in the discriminators, thus the mstft loss is usually coarser than the adversarial loss in the time domain
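For reference, a minimal sketch of what "relativized" means for the LSGAN-style `l_adv`. This is the common relativistic-average formulation, shown only for illustration; the repo's exact variant is implemented in `retunegan/models/loss.py` and gated by `relative_gan_loss` in `retunegan/hparam.py`:

```python
import torch

def d_loss_rel(d_real, d_fake):
    # each side is scored relative to the mean score of the other side
    return torch.mean((d_real - d_fake.mean() - 1) ** 2) + \
           torch.mean((d_fake - d_real.mean() + 1) ** 2)

def g_loss_rel(d_real, d_fake):
    # the generator pushes fake scores above the real ones, not just toward 1
    return torch.mean((d_fake - d_real.mean() - 1) ** 2) + \
           torch.mean((d_real - d_fake.mean() + 1) ** 2)
```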

## Acknowledgements

Codes referred to:

- [keithito's Tacotron](https://github.com/keithito/tacotron)
- [jaywalnut310's MelGAN](https://github.com/jaywalnut310/MelGAN-Pytorch)
- [jik876's official HiFiGAN](https://github.com/jik876/hifi-gan)
- [ming024's FastSpeech2](https://github.com/ming024/FastSpeech2/)

Ideas plagiarized from:

- [Tacotron: Towards End-to-End Speech Synthesis](https://arxiv.org/abs/1703.10135)
- [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263)
- [FastSpeech 2: Fast and High-Quality End-to-End Text to Speech](https://arxiv.org/abs/2006.04558)
- [MelGAN: Generative Adversarial Networks for Conditional Waveform Synthesis](https://arxiv.org/abs/1910.06711)
- [Multi-band MelGAN: Faster Waveform Generation for High-Quality Text-to-Speech](https://arxiv.org/abs/2005.05106)
- [HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646)
- [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889)
- [RefineGAN: Universally Generating Waveform Better than Ground Truth with Highly Accurate Pitch and Intensity Responses](https://arxiv.org/abs/2111.00962)

Code released under the MIT license, greatest thanks to all the authors!! :)

----

by Armit
2022/02/15
2022/05/25
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# Author: Armit
# Create Time: 2022/04/15

import os
import re
import pickle
from time import time
from argparse import ArgumentParser
from tempfile import gettempdir

import numpy as np
from flask import Flask, request, jsonify, send_file
from requests.utils import unquote
from scipy.io import wavfile
from xpinyin import Pinyin
from requests import session


BASE_PATH = os.path.dirname(os.path.abspath(__file__))
HTML_FILE = os.path.join(BASE_PATH, 'index.html')

TMP_DIR = gettempdir()
WAV_TMP_FILE = os.path.join(TMP_DIR, 'synth.wav')
MP3_TMP_FILE = os.path.join(TMP_DIR, 'synth.mp3')

REGEX_PUNCT_IGNORE = re.compile('、|:|;|“|”|‘|’')
REGEX_PUNCT_BREAK = re.compile(',|。|!|?')
MAX_CLAUSE_LENGTH = 20
CONVERT_MP3 = False

SAMPLE_RATE = 22050
SYNTH_API   = 'http://127.0.0.1:5105/synth_spec'
VOCODER_API = 'http://127.0.0.1:5104/vocode'


app = Flask(__name__)
html_page = None
http = session()
kanji2pinyin = Pinyin()


def synth_and_save_file(txt):
  # Text-Norm
  if True:
    s = time()
    print(f'text/raw: {txt!r}')

    kanji = REGEX_PUNCT_IGNORE.sub('', txt)
    kanji = REGEX_PUNCT_BREAK.sub(' ', kanji)   # NOTE: chain on `kanji`, not `txt`, or the first substitution is lost
    segs = ['']   # dummy init
    for rs in [s.strip() for s in kanji.split(' ') if s.strip()]:
      if (not segs[-1]) or (len(rs) + len(segs[-1]) < MAX_CLAUSE_LENGTH):
        segs[-1] = segs[-1] + rs
      else: segs.append(rs)
    print(f'text/segs: {segs!r}')
    t = time()
    print('[TextNorm] Done in %.2fs' % (t - s))

  # Synth
  if True:
    s = time()
    spec_clips = []
    for seg in segs:
      pinyin = ' '.join(kanji2pinyin.get_pinyin(seg, tone_marks='numbers').split('-'))
      resp = http.post(SYNTH_API, json={'pinyin': pinyin})
      spec = pickle.loads(resp.content)
      spec_clips.append(spec)
    spec = np.concatenate(spec_clips)
    print('spec.shape:', spec.shape)
    t = time()
    print('[Synth] Done in %.2fs' % (t - s))

  # Vocode
  if True:
    s = time()
    resp = http.post(VOCODER_API, data=pickle.dumps(spec))
    wav = pickle.loads(resp.content)
    wavfile.write(WAV_TMP_FILE, SAMPLE_RATE, wav)
    print('wav.length:', len(wav))
    t = time()
    print('[Vocode] Done in %.2fs' % (t - s))

  # Compress
  if CONVERT_MP3:
    s = time()
    cmd = f'ffmpeg -i "{WAV_TMP_FILE}" -f mp3 -acodec libmp3lame -y "{MP3_TMP_FILE}" -loglevel quiet'
    r = os.system(cmd)
    t = time()
    print('[Compress] Done in %.2fs' % (t - s))


@app.route('/', methods=['GET'])
def root():
  global html_page
  if not html_page:
    with open(HTML_FILE, encoding='utf-8') as fp:
      html_page = fp.read()
  return html_page

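# Example request against this app, assuming the default port 5103
# (`text` is URL-encoded UTF-8, here '你好' meaning 'hello';
#  the response body is a wav, or an mp3 when CONVERT_MP3 is set):
#   curl 'http://localhost:5103/synth?text=%E4%BD%A0%E5%A5%BD' --output hello.wav
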
@app.route('/synth', methods=['GET'])
def synth():
  txt = unquote(request.args.get('text', '')).strip()
  if not txt: return jsonify({'error': 'empty request'})

  try:
    synth_and_save_file(txt)
    if CONVERT_MP3:
      return send_file(MP3_TMP_FILE, mimetype='audio/mp3')
    else:
      return send_file(WAV_TMP_FILE, mimetype='audio/wav')
  except Exception as e:
    print('[Error] %r' % e)
    return jsonify({'error': str(e)})   # NOTE: an Exception object itself is not JSON-serializable


if __name__ == '__main__':
  parser = ArgumentParser()
  parser.add_argument('--host', type=str, default='0.0.0.0')
  parser.add_argument('--port', type=int, default=5103)
  args = parser.parse_args()

  app.run(host=args.host, port=args.port, debug=False)
--------------------------------------------------------------------------------
/img/draw_embd_sim.py:
--------------------------------------------------------------------------------
import os
import sys

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt


BASE_PATH = os.path.dirname(os.path.abspath(__file__))


def process(e, name):
  def dot_sim(x):
    return np.dot(x, x.T)

  def cosine_sim(x):
    n = np.linalg.norm(x, axis=-1, keepdims=True)
    return dot_sim(x) / (np.dot(n, n.T) + 1e-8)

  #s1 = dot_sim(e)
  #np.save(os.path.join(BASE_PATH, f'{name}_dot.npy'), s1)
  #sns.heatmap(s1)
  #plt.gca().invert_yaxis()
  #plt.savefig(os.path.join(BASE_PATH, f'{name}_dot.png'))
  #plt.clf()

  s2 = cosine_sim(e)
  #np.save(os.path.join(BASE_PATH, f'{name}_cosine.npy'), s2)
  sns.heatmap(s2)
  plt.gca().invert_yaxis()
  plt.savefig(os.path.join(BASE_PATH, f'{name}_cosine.png'))
  plt.clf()


fp = sys.argv[1]
fn = os.path.basename(fp)
base, ext = os.path.splitext(fn)

d = np.load(fp, allow_pickle=True)
if isinstance(d, np.ndarray):
  process(d, base)
elif isinstance(d, dict):
  for k in d.keys():
    if not k.endswith('embd'): continue
    process(d[k], f'{base}_{k}')
else: raise TypeError(type(d))   # NOTE: a bare `raise` here would have no active exception to re-raise
--------------------------------------------------------------------------------
/img/gen_hfg.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/gen_hfg.wav
--------------------------------------------------------------------------------
/img/gen_rtg.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/gen_rtg.wav
--------------------------------------------------------------------------------
/img/gen_tts.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/gen_tts.wav
--------------------------------------------------------------------------------
/img/gt_hfg.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/gt_hfg.wav
--------------------------------------------------------------------------------
/img/gt_rtg.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/gt_rtg.wav
--------------------------------------------------------------------------------
/img/gt_tts.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/gt_tts.wav
-------------------------------------------------------------------------------- /img/hifi-gan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/hifi-gan.png -------------------------------------------------------------------------------- /img/rtg_D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/rtg_D.png -------------------------------------------------------------------------------- /img/rtg_G.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/rtg_G.png -------------------------------------------------------------------------------- /img/rtg_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/rtg_arch.png -------------------------------------------------------------------------------- /img/rtg_cmp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/rtg_cmp.png -------------------------------------------------------------------------------- /img/tacotron-cn.mini.step-48000-embed_cosine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tacotron-cn.mini.step-48000-embed_cosine.png -------------------------------------------------------------------------------- /img/tacotron-cn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tacotron-cn.png -------------------------------------------------------------------------------- /img/tacotron-en.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tacotron-en.png -------------------------------------------------------------------------------- /img/transtacos.mini.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/transtacos.mini.png -------------------------------------------------------------------------------- /img/transtacos.mini.step-48000_prds_embd_cosine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/transtacos.mini.step-48000_prds_embd_cosine.png -------------------------------------------------------------------------------- /img/transtacos.mini.step-48000_text_embd_cosine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/transtacos.mini.step-48000_text_embd_cosine.png 
-------------------------------------------------------------------------------- /img/tts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts.png -------------------------------------------------------------------------------- /img/tts_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_arch.png -------------------------------------------------------------------------------- /img/tts_decoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_decoder.png -------------------------------------------------------------------------------- /img/tts_embed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_embed.png -------------------------------------------------------------------------------- /img/tts_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_encoder.png -------------------------------------------------------------------------------- /img/tts_out_align.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_out_align.png -------------------------------------------------------------------------------- /img/tts_out_spec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_out_spec.png -------------------------------------------------------------------------------- /img/tts_posnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_posnet.png -------------------------------------------------------------------------------- /img/y_tmpl.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/y_tmpl.wav -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | Demo 2 | 13 | 14 |
15 | 16 | 17 |
18 |

19 | 20 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | librosa==0.8.1 2 | soundfile==0.10.3 3 | matplotlib==3.1.3 4 | numpy==1.20.2 5 | scipy==1.6.2 6 | Flask==2.0.1 7 | requests==2.25.1 8 | xpinyin==0.7.6 9 | tensorboardX==2.4.1 10 | seaborn 11 | #gast==0.2.2 12 | -------------------------------------------------------------------------------- /retunegan/Makefile: -------------------------------------------------------------------------------- 1 | ifeq ($(shell uname -s), Linux) 2 | BASE_PATH=~/Data 3 | else 4 | BASE_PATH=D:/Desktop/Workspace/Data 5 | endif 6 | 7 | DATASET=DataBaker 8 | DATA_PATH=$(BASE_PATH)/$(DATASET).tts_processed 9 | #LOG_PATH=$(BASE_PATH)/rtg-$(DATASET).$(VER) 10 | LOG_PATH=$(BASE_PATH)/rtg-$(DATASET) 11 | 12 | .PHONY: train test server clean stat 13 | 14 | train: 15 | python train.py \ 16 | --data_dp $(DATA_PATH) \ 17 | --log_path $(LOG_PATH) \ 18 | --epochs 3100 19 | 20 | finetune: 21 | python train.py \ 22 | --finetune \ 23 | --data_dp $(DATA_PATH) \ 24 | --log_path $(LOG_PATH) \ 25 | --epochs 3100 26 | 27 | test: 28 | python infer.py \ 29 | --log_path $(LOG_PATH) \ 30 | --input_path test 31 | 32 | server: 33 | python server.py \ 34 | --log_path $(LOG_PATH) 35 | 36 | stat: 37 | tensorboard \ 38 | --logdir $(LOG_PATH) \ 39 | --port 5101 40 | 41 | clean: 42 | rm -rf $(LOG_PATH) 43 | -------------------------------------------------------------------------------- /retunegan/audio.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Armit 3 | # Create Time: 2022/01/07 4 | 5 | import torch 6 | from torch.nn import AvgPool2d 7 | import numpy as np 8 | import numpy.random as R 9 | import librosa as L 10 | from scipy.io import wavfile 11 | 12 | import seaborn as sns 13 | import matplotlib.pyplot as plt 14 | 15 | import hparam as hp 16 | R.seed(hp.randseed) 17 | 18 | 19 | eps = 1e-5 20 | mel_basis = L.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.n_mel, fmin=hp.fmin, fmax=hp.fmax) 21 | mag_to_mel = lambda x: np.dot(mel_basis, x) 22 | 23 | avg_pool_2d = AvgPool2d(kernel_size=3, stride=1, padding=1) 24 | 25 | mel_basis_torch = { } # { n_fft: mel_basis } 26 | window_fn_torch = { } # { win_length: window_fn } 27 | 28 | 29 | def load_wav(path): # float values in range (-1,1) 30 | y, _ = L.load(path, sr=hp.sample_rate, mono=True, res_type='kaiser_best') 31 | return y.astype(np.float32) # [T,] 32 | 33 | 34 | def save_wav(wav, path): 35 | wavfile.write(path, hp.sample_rate, wav) 36 | 37 | 38 | def align_wav(wav, r=hp.hop_length): 39 | d = len(wav) % r 40 | if d != 0: 41 | wav = np.pad(wav, (0, (r - d))) 42 | return wav 43 | 44 | 45 | def augment_wav(y, pitch_shift=True, time_stretch=True, dynamic_scale=True): 46 | if pitch_shift: 47 | # 75% unmodified, 25% shifted 48 | if R.random() > 0.75: 49 | # ~10% unmodified, ~30% in (-1,1), ~47% in (-2,+2), ~%74 in (-4,+4) 50 | semitone = max(min(round(R.normal(scale=12/3)), 12), -12) # 3-sigma principle:99.74% in [mu-3*sigma,mu+3*sigma] 51 | if semitone != 0: y = L.effects.pitch_shift(y, hp.sample_rate, semitone, res_type='kaiser_best') 52 | 53 | if time_stretch: 54 | # 90% unmodified, 10% twisted; because `time_stretch`` hurts quality a lot 55 | if R.random() > 0.90: 56 | alpha = 2 ** R.normal(scale=1/5) 57 | if abs(alpha - 1.0) < 0.1: alpha = 1.0 58 | if alpha != 1.0: y = L.effects.time_stretch(y=y, rate=alpha, 
win_length=hp.win_length, hop_length=hp.hop_length) 59 | 60 | if dynamic_scale: 61 | # 25% unmodified, 75% global shift 62 | r = R.random() 63 | if r > 0.25: 64 | alpha = 2 ** R.normal(scale=1/3) 65 | y = y * alpha 66 | absmax = max(y.max(), -y.min()) 67 | if absmax > 1.0: y /= absmax 68 | 69 | return y.astype(np.float32) # [T,] 70 | 71 | 72 | def augment_spec(S, time_mask=True, freq_mask=True, prob=0.2, rounds=3, freq_width=9, time_width=3): 73 | F, T = S.shape 74 | S = torch.from_numpy(S).unsqueeze(0) 75 | 76 | # local mask 77 | # 10.7% unmodified, 57.0% maskes <=2 times, 86.0% masked <=4 times 78 | for _ in range(rounds): 79 | if freq_mask and R.random() < prob: 80 | s = R.randint(0, F - freq_width) 81 | r = R.randint(1, freq_width) 82 | mask_val = R.uniform(low=S.min(), high=S.mean()) 83 | S[:, s:s+r, :] = torch.ones([1, r, T]) * mask_val 84 | 85 | if time_mask and R.random() < prob: 86 | s = R.randint(0, T - time_width) 87 | r = R.randint(1, time_width) 88 | mask_val = R.uniform(low=S.min(), high=S.mean()) 89 | S[:, :, s:s+r] = torch.ones([1, F, r]) * mask_val 90 | 91 | # global blur 92 | S = avg_pool_2d(S) 93 | 94 | S = S.squeeze(0).numpy() 95 | return S.astype(np.float32) 96 | 97 | 98 | def get_zcr(y): 99 | zcr = L.feature.zero_crossing_rate(y, frame_length=hp.win_length, hop_length=hp.hop_length)[0] 100 | return zcr.astype(np.float32) # [T,] 101 | 102 | 103 | def get_c0(y): 104 | c0 = L.feature.rms(y=y, frame_length=hp.win_length, hop_length=hp.hop_length)[0] 105 | return c0.astype(np.float32) # [T,] 106 | 107 | 108 | def get_uv(zcr, dyn): 109 | uv = np.empty_like(zcr) 110 | for i in range(len(uv)): 111 | # NOTE: these numbers are magic, tune by hand according to your dataset 112 | uv[i] = zcr[i] > 0.18 or dyn[i] < 0.03 113 | return uv 114 | 115 | 116 | def get_mag(y, clamp_low=True): 117 | D = L.stft(y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn) 118 | S = np.abs(D) 119 | mag = np.log(S.clip(min=eps) if clamp_low else S) 120 | return mag.astype(np.float32) # [F, T] 121 | 122 | 123 | def get_mel(y, clamp_low=True): 124 | M = L.feature.melspectrogram(y=y, sr=hp.sample_rate, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length, 125 | n_mels=hp.n_mel, fmin=hp.fmin, fmax=hp.fmax, 126 | window=hp.window_fn, power=1, htk=hp.mel_scale=='htk') 127 | mel = np.log(M.clip(min=eps) if clamp_low else M) 128 | return mel.astype(np.float32) # [M, T] 129 | 130 | 131 | def _griffinlim(S, wavlen=None): 132 | if hp.gl_power: S = S ** hp.gl_power 133 | y = L.griffinlim(S, n_iter=hp.gl_iters, 134 | hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn, 135 | length=wavlen, momentum=hp.gl_momentum, init='random', random_state=hp.randseed) 136 | return y.astype(np.float32) 137 | 138 | 139 | def inv_mag(mag, wavlen=None): 140 | S = np.exp(mag) # [F/F-1, T], reverse np.log 141 | F, T = mag.shape 142 | if F == hp.n_freq - 1: # NOTE: preprend zero DC component 143 | S = np.concatenate([np.zeros([1, T]), S], axis=0) 144 | #print('S.min():', S.min(), 'S.max(): ', S.max(), 'S.mean(): ', S.mean()) 145 | y = _griffinlim(S, wavlen) 146 | if wavlen: assert len(y) == wavlen 147 | return y 148 | 149 | 150 | def get_stft_torch(y, n_fft, win_length, hop_length): 151 | ''' 该函数得到原始的Mel值,没有数值下截断和取对数过程 ''' 152 | 153 | global mel_basis_torch, window_fn_torch 154 | if win_length not in window_fn_torch: 155 | win_functor = getattr(torch, f'{hp.window_fn}_window') 156 | window_fn_torch[win_length] = win_functor(win_length).to(y.device) # [n_fft] 157 | 
if n_fft not in mel_basis_torch: 158 | mel_filter = L.filters.mel(hp.sample_rate, n_fft, hp.n_mel, hp.fmin, hp.fmax) 159 | mel_basis_torch[n_fft] = torch.from_numpy(mel_filter).float().to(y.device) # [n_mel, n_fft//2+1] 160 | 161 | D = torch.stft(y, n_fft, return_complex=True, # [n_fft/2+1, n_frames, 2], last dim 2 for real/image parts 162 | hop_length=hop_length, win_length=win_length, window=window_fn_torch[win_length], 163 | center=True, pad_mode='reflect', normalized=False, onesided=True) 164 | 165 | #S = torch.sqrt(D.pow(2).sum(-1) + (1e-9)) # [n_fft/2+1, n_frames], get modulo, aka. magnitude 166 | S = torch.abs(D + 1e-9) 167 | M = torch.matmul(mel_basis_torch[n_fft], S) # [n_mel, n_frames] 168 | P = torch.angle(D) 169 | 170 | return S, M, P 171 | -------------------------------------------------------------------------------- /retunegan/audio_proxy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Armit 3 | # Create Time: 2022/04/16 4 | 5 | # NOTE: proxy by TransTacoS for finetune 6 | 7 | import os 8 | import sys 9 | from importlib import import_module 10 | BASH_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 11 | sys.path.append(BASH_PATH) 12 | print('TRANSTACOS_PATH:', os.path.join(BASH_PATH, 'transtacos')) 13 | AP = import_module('transtacos.audio') 14 | -------------------------------------------------------------------------------- /retunegan/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from random import randint 3 | 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | import seaborn as sns 7 | import matplotlib.pyplot as plt 8 | 9 | import hparam as hp 10 | import audio as A 11 | 12 | # for fintune tune with TranTacoS 13 | from audio_proxy import AP 14 | 15 | 16 | assert hp.segment_size % hp.hop_length == 0 17 | frames_per_seg = hp.segment_size // hp.hop_length 18 | 19 | 20 | class Dataset(Dataset): 21 | 22 | def __init__(self, name, data_dp, finetune=False, limit=None): 23 | self.is_train = name == 'train' 24 | self.data_dp = data_dp # the preprocessed folder containing index files and mel/mag features (if finetune) 25 | self.finetune = finetune 26 | 27 | with open(os.path.join(data_dp, 'wav_path.txt')) as fh: 28 | wav_path = fh.read().strip() 29 | with open(os.path.join(data_dp, f'{name}.txt'), encoding='utf-8') as fh: 30 | self.wav_fps = [os.path.join(wav_path, line.split('|')[0] + '.wav') for line in fh.readlines() if line] 31 | if limit: self.wav_fps = self.wav_fps[:limit] 32 | 33 | self.data = [None] * len(self.wav_fps) 34 | 35 | def __len__(self): 36 | return len(self.wav_fps) 37 | 38 | def __getitem__(self, index): 39 | # repreprocess & cache 40 | if self.data[index] is None: 41 | ''' prepare GT wav ''' 42 | wav_fp = self.wav_fps[index] 43 | if not self.finetune: 44 | wav = A.load_wav(wav_fp) 45 | if self.is_train: 46 | # aug data once and freeze 47 | wav = A.augment_wav(wav) 48 | wav = A.align_wav(wav) 49 | else: 50 | # keep identical to `preprocessor.make_metadata()` of TransTacoS 51 | wav = AP.load_wav(wav_fp) 52 | wav = AP.trim_silence(wav) 53 | wav = AP.align_wav(wav) 54 | 55 | wavlen = len(wav) 56 | 57 | ''' prepare GT mel ''' 58 | if not self.finetune: 59 | # `[:-1]` to avoid extra tailing frame 60 | mag = A.get_mag(wav[:-1]) # [M, T] 61 | else: 62 | # keep identical to preprocessors of TransTacoS 63 | name = os.path.splitext(os.path.basename(wav_fp))[0] 64 | mag = 
np.load(os.path.join(self.data_dp, f'mag-{name}.npy'))   # [M, T]
65 |         mag = AP.spec_to_natural_scale(mag)
66 | 
67 |       mel = A.mag_to_mel(mag)
68 | 
69 |       if self.is_train:
70 |         # aug data once and freeze
71 |         mel_aug = A.augment_spec(mel, rounds=5)
72 |         mel = mel / 2 + mel_aug / 2
73 | 
74 |       ''' prepare ref wav '''
75 |       try:
76 |         wav_tmpl = A.inv_mag(mag, wavlen=wavlen-1)   # `wavlen-1` to avoid extra tailing frame
77 |         wav_tmpl = np.pad(wav_tmpl, (0, 1))          # pad to align
78 |       except:
79 |         breakpoint()
80 | 
81 |       # dy: in theory the first-order difference should roughly reveal the pulse positions, but wav_tmpl's phase may be a bit off
82 |       if hp.ref_wav == 'dy':
83 |         wav_tmpl = np.pad(wav_tmpl, (0, 1))
84 |         wav_tmpl = np.asarray([b-a for a, b in zip(wav_tmpl[:-1], wav_tmpl[1:])])
85 | 
86 |       ''' prepare u/v mask '''
87 |       if hp.split_cv:   # the time-domain method has rather large errors
88 |         zcr = A.get_zcr(wav_tmpl[:-1])
89 |         dyn = A.get_c0 (wav_tmpl[:-1])
90 |         uv  = A.get_uv (zcr, dyn)
91 | 
92 |       ''' prepare u/v-splitted mel & ref wav '''
93 |       if hp.split_cv:
94 |         uv_ex = np.repeat(uv, hp.hop_length)
95 |         wav_tmpl_c = wav_tmpl * uv_ex
96 |         wav_tmpl_v = wav_tmpl * (1 - uv_ex)
97 |         mel_min = mel.min()
98 |         mel_shift = mel - mel_min   # assure > 0 for mask product
99 |         mel_c = mel_shift * uv     + mel_min
100 |         mel_v = mel_shift * (1-uv) + mel_min
101 | 
102 |       if not 'check':
103 |         if hp.split_cv:
104 |           wav_c = wav * uv_ex
105 |           wav_v = wav * (1 - uv_ex)
106 |           plt.subplot(411) ; plt.plot(wav_c, 'r')           ; plt.plot(wav_v, 'b')
107 |           plt.subplot(412) ; plt.plot(wav_tmpl_c, 'r')      ; plt.plot(wav_tmpl_v, 'b')
108 |           plt.subplot(413) ; sns.heatmap(mel_c, cbar=False) ; plt.gca().invert_yaxis()
109 |           plt.subplot(414) ; sns.heatmap(mel_v, cbar=False) ; plt.gca().invert_yaxis()
110 |           plt.show()
111 |         else:
112 |           plt.subplot(411) ; plt.plot(wav, 'b')
113 |           plt.subplot(412) ; sns.heatmap(mag, cbar=False) ; plt.gca().invert_yaxis()
114 |           plt.subplot(413) ; plt.plot(wav_tmpl, 'r')
115 |           plt.subplot(414) ; sns.heatmap(mel, cbar=False) ; plt.gca().invert_yaxis()
116 |           plt.show()
117 | 
118 |       ''' check shape aligns '''
119 |       if hp.split_cv: assert len(dyn) == len(zcr) == mel.shape[1]
120 |       assert len(wav) == len(wav_tmpl) == mel.shape[1] * hp.hop_length
121 | 
122 |       ''' done '''
123 |       if hp.split_cv:
124 |         self.data[index] = (mel, wav, mel_c, mel_v, wav_tmpl_c, wav_tmpl_v, uv_ex)
125 |       else:
126 |         self.data[index] = (mel, wav, wav_tmpl)
127 | 
128 |     # get from cache (full length data)
129 |     if hp.split_cv:
130 |       mel, wav, mel_c, mel_v, wav_tmpl_c, wav_tmpl_v, uv_ex = self.data[index]
131 |     else:
132 |       mel, wav, wav_tmpl = self.data[index]
133 | 
134 |     # make slices during training: wav[S=8192] <=> mel[T=32]
135 |     if self.is_train:
136 |       wavlen, mellen = len(wav), mel.shape[1]
137 |       if wavlen > hp.segment_size:
138 |         cp = randint(0, mellen - frames_per_seg - 1)
139 |         if hp.split_cv:
140 |           mel_c      = mel_c     [:, cp : cp + frames_per_seg]
141 |           mel_v      = mel_v     [:, cp : cp + frames_per_seg]
142 |           wav_tmpl_c = wav_tmpl_c[cp * hp.hop_length : (cp + frames_per_seg) * hp.hop_length]
143 |           wav_tmpl_v = wav_tmpl_v[cp * hp.hop_length : (cp + frames_per_seg) * hp.hop_length]
144 |           wav        = wav       [cp * hp.hop_length : (cp + frames_per_seg) * hp.hop_length]
145 |           uv_ex      = uv_ex     [cp * hp.hop_length : (cp + frames_per_seg) * hp.hop_length]
146 |         else:
147 |           mel        = mel       [:, cp : cp + frames_per_seg]
148 |           wav        = wav       [cp * hp.hop_length : (cp + frames_per_seg) * hp.hop_length]
149 |           wav_tmpl   = wav_tmpl  [cp * hp.hop_length : (cp + frames_per_seg) * hp.hop_length]
150 |       else:
151 |         if hp.split_cv:
152 |           # NOTE: np.pad needs per-axis pad widths for 2-D arrays, and the fill value goes in `constant_values`
153 |           mel_c      = np.pad(mel_c, ((0, 0), (0, frames_per_seg - mellen)), constant_values=mel.min())
154 |           mel_v      = np.pad(mel_v, ((0, 0), (0, frames_per_seg - mellen)), constant_values=mel.min())
155 |           wav_tmpl_c = np.pad(wav_tmpl_c, (0, hp.segment_size - wavlen))
156 |           wav_tmpl_v = np.pad(wav_tmpl_v, (0, hp.segment_size - wavlen))
157 |           wav        = np.pad(wav,        (0, hp.segment_size - wavlen))
158 |           uv_ex      = np.pad(uv_ex,      (0, hp.segment_size - wavlen))
159 |         else:
160 |           mel        = np.pad(mel, ((0, 0), (0, frames_per_seg - mellen)), constant_values=mel.min())
161 |           wav        = np.pad(wav,      (0, hp.segment_size - wavlen))
162 |           wav_tmpl   = np.pad(wav_tmpl, (0, hp.segment_size - wavlen))
163 | 
164 |     # mel:      filtered from the external mag; the main input at inference time
165 |     # wav_tmpl: coarse waveform recovered from the external mag by a classical algorithm; the reference input at inference time
166 |     # wav:      the real target recording waveform; the target output during training
167 |     # uv_ex:    unvoiced mask; may serve as a reference input at inference time
168 |     if hp.split_cv:
169 |       ret = mel_c, mel_v, wav_tmpl_c, wav_tmpl_v, wav, uv_ex
170 |     else:
171 |       ret = mel, wav_tmpl, wav
172 | 
173 |     return [x.astype(np.float32) for x in ret]
--------------------------------------------------------------------------------
/retunegan/hparam.py:
--------------------------------------------------------------------------------
'''Audio: proxied by transtacos, plz keep in sync'''
# Audio
sample_rate = 22050       # sample rate (Hz) of wav file
n_fft = 2048
win_length = 1024         # :=n_fft//2
hop_length = 256          # :=win_length//4, 11.6ms; on average one pinyin phone spans 9 frames (min 2 ~ max 20)
n_mel = 80                # number of mel bands (default: 160), 120 should be more reasonable though
n_freq = 1025             # number of linear spectrogram bins :=n_fft//2+1
preemphasis = 0.97        # boost high frequencies to balance the EQ
ref_level_db = 20         # highest spectral magnitude considered (virtual 0dB); theoretically 94 in a quiet environment, but the noisier the recording, the smaller this should be (default: 20)
min_level_db = -100       # lowest spectral magnitude considered, for dynamic-range clipping & compression (default: -100)
max_abs_value = 4         # normalize spectral magnitudes into [-max_abs_value, max_abs_value]
trim_below_peak_db = 35   # trim beginning/ending silence parts (default: 60)
fmin = 125                # mel filterbank frequency bounds (set 55/3600 for a male voice)
fmax = 7600
rf0min = 'D2'             # F0 detection bounds
rf0max = 'D5'

## see `Databaker stats` or `stats.txt` in the preprocessed folder
c0min = 4.6309418394230306e-05
c0max = 0.3751049339771271
f0min = 73.25581359863281
f0max = 595.9459228515625
n_tone = 5+1
n_prds = 5+1
n_c0_bins = 32
n_f0_bins = None    # keep None to auto-detect using f0min & f0max
n_f0_min  = None    # as offset
maxlen_text = 128   # for pos_enc, 27 in train set
maxlen_spec = 1024  # for pos_enc, 524 in train set

####################################################################################################################

'''Audio'''
segment_size = 8192
window_fn = 'hann'      # ['bartlett', 'blackman', 'hamming', 'hann', 'kaiser']
mel_scale = 'slaney'    # 'htk' is smooth, 'slaney' breaks at f=1000; see `https://blog.csdn.net/qq_39995815/article/details/116269040`
gl_iters = 4
gl_momentum = 0.7       # 0.99 may deadloop
gl_power = 1.2          # power magnitudes before Griffin-Lim
ref_wav = 'y'           # ['y', 'dy']


'''Model'''
# RetuneCNN       (        |          | mstft=       at 30 epoch)
# HiFiGAN_mini    (        |          | mstft=       at 30 epoch)
# HiFiGAN_micro   (        |          | mstft=       at 30 epoch)
# HiFiGAN_mu      (        |          | mstft=       at 30 epoch)
# RefineGAN       (        |          | mstft=       at 30 epoch)
# RefineGAN_small (2748371 |          | mstft=       at 30 epoch)
# MelGAN          (4524290 | 2.36 s/b | mstft=10.084 at 30 epoch)
# MelGANRetune    (1409427 | 2.42 s/b | mstft= 7.000 at 30 epoch)
# MelGANSplit
# HiFiGAN         (1421314 | 2.30 s/b | mstft=10.346 at 30 epoch)
# HiFiGANRetune   (1716627 | 2.45 s/b | mstft= 7.041 at 30 epoch)   # high-frequency tearing
# HiFiGANSplit    (2849890 | 2.49 s/b | mstft=11.320 at 30 epoch)

# generator
generator_ver = 'RefineGAN_small'
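# NOTE: a generator_ver ending in 'Split' switches on the u/v-split pipeline (split_cv below),
#       which data.py uses to prepare separate consonant/vowel mels and reference wavs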
split_cv = generator_ver.endswith('Split')
upsample_rates = [8, 8, 4]
upsample_kernel_sizes = [15, 15, 7]
upsample_initial_channel = 256
resblock_kernel_sizes = [3, 5, 7]
resblock_dilation_sizes = [[1, 2], [2, 6], [3, 12]]
#resblock_kernel_sizes = [3, 7, 11]
#resblock_dilation_sizes = [[2, 3], [3, 5], [5, 11]]

# discriminator
msd_layers = 3
mpd_periods = [3, 5, 7, 11]
multi_stft_params = [
  # (n_fft, win_length, hop_length)
  # (2048, 1024, 256),
  # (1024,  512, 128),
  # ( 512,  256,  64),
  # by UnivNet
  (2048, 1024, 240),
  (1024,  512, 120),
  ( 512,  256,  60),
]
phd_layers = len(multi_stft_params)
phd_input = 'stft'    # ['phase', 'stft'], 'phase' doesn't seem to work well

# loss
relative_gan_loss = False
strip_mirror_loss = False
dynamic_loss  = True
envelope_loss = False
envelope_pool_k = 160   # use `tools/test_envolope.py` to pick a proper value
downsample_pool_k = 4


'''Misc'''
from sys import platform
debug = platform == 'win32'   # debug locally on Windows, actually train on Linux
randseed = 114514


'''Training'''
num_workers = 1 if debug else 4
batch_size  = 4 if debug else 16    # 16 = 567.6 steps per epoch
learning_rate_d = 2e-4
learning_rate_g = 1.8e-4
d_train_times = 2    # how many D updates per G update
adam_b1 = 0.8
adam_b2 = 0.99
lr_decay = 0.999

w_loss_fm    = 2
w_loss_mstft = 8
w_loss_env   = 4
w_loss_dyn   = 4
w_loss_sm    = 0.01


'''Eval'''
valid_limit = batch_size * 4
--------------------------------------------------------------------------------
/retunegan/infer.py:
--------------------------------------------------------------------------------
import os
import sys
import importlib
from argparse import ArgumentParser

import torch
import numpy as np

from models import *    # NOTE: wildcard, so the concrete `Generator_<ver>` class is reachable via globals()
from audio import load_wav, save_wav, get_mag, inv_mag
from utils import load_checkpoint, scan_checkpoint

device = 'cuda' if torch.cuda.is_available() else 'cpu'
generator = None


def load_generator(a):
  global generator
  if not generator:
    Generator = globals().get(f'Generator_{h.generator_ver}')
    generator = Generator().to(device)
    state_dict_g = load_checkpoint(scan_checkpoint(a.log_path, 'g_'), device)
    generator.load_state_dict(state_dict_g['generator'])
    generator.eval()
    generator.remove_weight_norm()
  return generator


def inference(a, x, wav_ref):
  generator = load_generator(a)
  with torch.no_grad():
    y_g_hat = generator(x, wav_ref)
    wav = y_g_hat.squeeze()
    wav = wav.cpu().numpy().astype(np.float32)
  return wav


def inference_from_mag(a, fp):
  x = np.load(fp)                       # log-magnitude spectrogram
  if x.shape[1] == h.n_freq: x = x.T    # ensure [F, T] layout
  y = inv_mag(x)                        # Griffin-Lim reference wav, computed while x is still numpy
  x = torch.from_numpy(x).to(device)
  y = torch.from_numpy(y).to(device)
  if len(x.shape) < 3:                  # set batch_size=1
    x = x.unsqueeze(0)
    y = y.unsqueeze(0)
  wav = inference(a, x, y)

  # NOTE: argparse defines no --output_dir, so write next to the inputs
  wav_fp = os.path.join(a.input_path, os.path.splitext(os.path.basename(fp))[0] + '_gen_from_mag.wav')
  save_wav(wav, wav_fp)
  print(f'  Done {wav_fp!r}')

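# Example (paths per retunegan/Makefile): vocode every *.npy mag and *.wav file under ./test
#   python infer.py --log_path ~/Data/rtg-DataBaker --input_path test
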
def inference_from_wav(a, fp):
  wav = load_wav(fp)
  x = get_mag(wav)                      # numpy [F, T] log-magnitude
  y = inv_mag(x)                        # Griffin-Lim reference wav, computed while x is still numpy
  x = torch.from_numpy(x).to(device)
  y = torch.from_numpy(y).to(device)
  if len(x.shape) < 3:                  # set batch_size=1
    x = x.unsqueeze(0)
    y = y.unsqueeze(0)
  wav = inference(a, x, y)

  wav_fp = os.path.join(a.input_path, os.path.splitext(os.path.basename(fp))[0] + '_gen_from_wav.wav')
  save_wav(wav, wav_fp)
  print(f'  Done {wav_fp!r}')


if __name__ == '__main__':
  parser = ArgumentParser()
  parser.add_argument('--input_path', default='test')
  parser.add_argument('--log_path', required=True)
  a = parser.parse_args()

  # load frozen hparam
  sys.path.insert(0, a.log_path)
  h = importlib.import_module('hparam')
  torch.manual_seed(h.randseed)
  torch.cuda.manual_seed(h.randseed)

  print('Initializing Inference Process..')
  fps = [os.path.join(a.input_path, fn) for fn in os.listdir(a.input_path)]
  for fp in [fp for fp in fps if fp.lower().endswith('.npy')]:
    inference_from_mag(a, fp)
  for fp in [fp for fp in fps if fp.lower().endswith('.wav')]:
    inference_from_wav(a, fp)
--------------------------------------------------------------------------------
/retunegan/models/__init__.py:
--------------------------------------------------------------------------------
from .generator import *
from .discrminator import *
from .loss import *
--------------------------------------------------------------------------------
/retunegan/models/discrminator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/04/16
4 | 
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn import Conv1d, Conv2d
9 | from torch.nn.utils import weight_norm, spectral_norm
10 | 
11 | import hparam as hp
12 | from utils import *
13 | 
14 | PI = 3.14159265358979
15 | 
16 | 
17 | class DiscriminatorS(nn.Module):
18 | 
19 |   def __init__(self, use_sn=False):
20 |     super().__init__()
21 | 
22 |     #norm_f = spectral_norm if use_sn else weight_norm
23 | 
24 |     # [8192]; downsamples 4*4*4*4=256x, i.e. [256,512,1024]x overall for the multi-scale input versions
25 |     sel = 'MelGAN_small'
26 |     if sel == 'MelGAN':
27 |       self.convs = nn.ModuleList([
28 |         weight_norm(Conv1d(   1,   16, 15, 1, padding=7)),
29 |         weight_norm(Conv1d(  16,   64, 41, 4, padding=20, groups=4  )),
30 |         weight_norm(Conv1d(  64,  256, 41, 4, padding=20, groups=16 )),
31 |         weight_norm(Conv1d( 256, 1024, 41, 4, padding=20, groups=64 )),
32 |         weight_norm(Conv1d(1024, 1024, 41, 4, padding=20, groups=256)),
33 |         weight_norm(Conv1d(1024, 1024,  5, 1, padding=2)),
34 |       ])
35 |       self.conv_post = weight_norm(Conv1d(1024, 1, 3, 1, padding=1))
36 |     elif sel == 'MelGAN_small':
37 |       self.convs = nn.ModuleList([
38 |         weight_norm(Conv1d(  1,  32, 15, 1, padding=7)),
39 |         weight_norm(Conv1d( 32,  64, 41, 2, padding=20, groups=4 )),
40 |         weight_norm(Conv1d( 64, 128, 41, 2, padding=20, groups=8 )),
41 |         weight_norm(Conv1d(128, 512, 41, 4, padding=20, groups=32)),
42 |         weight_norm(Conv1d(512, 512, 41, 4, padding=20, groups=64)),
43 |         weight_norm(Conv1d(512, 512,  5, 1, padding=2)),
44 |       ])
45 |       self.conv_post = weight_norm(Conv1d(512, 1, 3, 1, padding=1))
46 |     elif sel == 'HiFiGAN':
47 |       self.convs = nn.ModuleList([
48 |         weight_norm(Conv1d(   1,  128, 15, 1, padding=7)),
49 |         weight_norm(Conv1d( 128,  128, 41, 2, padding=20, groups=4 )),
50 |         weight_norm(Conv1d( 128,  256, 41, 2, padding=20, groups=16)),
51 |         weight_norm(Conv1d( 256,  512, 41, 4, padding=20, groups=16)),
52 |         weight_norm(Conv1d( 512, 1024, 41, 4, padding=20, groups=16)),
53 |         weight_norm(Conv1d(1024, 1024, 41, 1, padding=20, groups=16)),
54 |         weight_norm(Conv1d(1024,
1024, 5, 1, padding=2)), 55 | ]) 56 | self.conv_post = weight_norm(Conv1d(1024, 1, 3, 1, padding=1)) 57 | 58 | def forward(self, x): 59 | # torch.Size([16, 1, 8192]) 60 | # torch.Size([16, 32, 8192]) 61 | # torch.Size([16, 64, 4096]) 62 | # torch.Size([16, 128, 2048]) 63 | # torch.Size([16, 512, 512]) 64 | # torch.Size([16, 512, 128]) 65 | # torch.Size([16, 512, 128]) 66 | # torch.Size([16, 1, 128]) 67 | # 68 | # torch.Size([16, 1, 4096]) 69 | # torch.Size([16, 32, 4096]) 70 | # torch.Size([16, 64, 2048]) 71 | # torch.Size([16, 128, 1024]) 72 | # torch.Size([16, 512, 256]) 73 | # torch.Size([16, 512, 64]) 74 | # torch.Size([16, 512, 64]) 75 | # torch.Size([16, 1, 64]) 76 | # 77 | # torch.Size([16, 1, 2048]) 78 | # torch.Size([16, 32, 2048]) 79 | # torch.Size([16, 64, 1024]) 80 | # torch.Size([16, 128, 512]) 81 | # torch.Size([16, 512, 128]) 82 | # torch.Size([16, 512, 32]) 83 | # torch.Size([16, 512, 32]) 84 | # torch.Size([16, 1, 32]) 85 | 86 | DEBUG = 0 87 | if DEBUG: print('[DiscriminatorS]') 88 | 89 | fmap = [] 90 | if DEBUG: print(x.shape) 91 | 92 | for i, l in enumerate(self.convs): 93 | x = l(x) 94 | if DEBUG: print(x.shape) 95 | fmap.append(x) 96 | x = F.leaky_relu(x, LRELU_SLOPE) 97 | x = self.conv_post(x) 98 | if DEBUG: print(x.shape) 99 | x = torch.flatten(x, 1, -1) 100 | 101 | return x, fmap 102 | 103 | 104 | class MultiScaleDiscriminator(nn.Module): 105 | 106 | def __init__(self): 107 | super().__init__() 108 | 109 | self.discriminators = nn.ModuleList([ 110 | DiscriminatorS(use_sn=i==0) for i in range(hp.msd_layers) 111 | ]) 112 | # 不能使用音频处理的Resample,而要用AvgPool逐步抹去高频细节 113 | self.avgpool = nn.AvgPool1d(kernel_size=hp.downsample_pool_k, stride=2, padding=1) 114 | 115 | def forward(self, y, y_hat): 116 | y_d_rs, y_d_gs = [], [] 117 | fmap_rs, fmap_gs = [], [] 118 | 119 | for i, d in enumerate(self.discriminators): 120 | y_d_r, fmap_r = d(y) 121 | y_d_g, fmap_g = d(y_hat) 122 | y_d_rs.append(y_d_r) ; fmap_rs.append(fmap_r) 123 | y_d_gs.append(y_d_g) ; fmap_gs.append(fmap_g) 124 | 125 | if i != len(self.discriminators) - 1: 126 | y = self.avgpool(y) 127 | y_hat = self.avgpool(y_hat) 128 | 129 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 130 | 131 | 132 | class DiscriminatorP(nn.Module): 133 | 134 | def __init__(self, period): 135 | super().__init__() 136 | 137 | self.period = period 138 | 139 | # [8192]; 只在T方向降采样 3*3*3*3=81 倍 140 | # = [4096,2] 141 | # = [2731,3] (*) 142 | # = [1639,5] (*) 143 | # = [1171,7] (*) 144 | # = [745,11] (*) 145 | sel = 'HiFiGAN_small' 146 | if sel == 'HiFiGAN': 147 | self.convs = nn.ModuleList([ 148 | weight_norm(Conv2d( 1, 32, (5, 1), (3, 1), padding=(2, 0))), 149 | weight_norm(Conv2d( 32, 128, (5, 1), (3, 1), padding=(2, 0))), 150 | weight_norm(Conv2d( 128, 512, (5, 1), (3, 1), padding=(2, 0))), 151 | weight_norm(Conv2d( 512, 1024, (5, 1), (3, 1), padding=(2, 0))), 152 | weight_norm(Conv2d(1024, 1024, (5, 1), 1, padding=(2, 0))), 153 | ]) 154 | self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 155 | elif sel == 'HiFiGAN_small': 156 | self.convs = nn.ModuleList([ 157 | weight_norm(Conv2d( 1, 32, (5, 1), (3, 1), padding=(2, 0))), 158 | weight_norm(Conv2d( 32, 128, (5, 1), (3, 1), padding=(2, 0))), 159 | weight_norm(Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))), 160 | weight_norm(Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))), 161 | weight_norm(Conv2d(512, 512, (5, 1), 1, padding=(2, 0))), 162 | ]) 163 | self.conv_post = weight_norm(Conv2d(512, 1, (3, 1), 1, padding=(1, 0))) 164 | 165 | def forward(self, x): 166 | # 
torch.Size([16, 1, 2731, 3]) 167 | # torch.Size([16, 32, 911, 3]) 168 | # torch.Size([16, 128, 304, 3]) 169 | # torch.Size([16, 256, 102, 3]) 170 | # torch.Size([16, 512, 34, 3]) 171 | # torch.Size([16, 512, 34, 3]) 172 | # torch.Size([16, 1, 34, 3]) 173 | # 174 | # torch.Size([16, 1, 1639, 5]) 175 | # torch.Size([16, 32, 547, 5]) 176 | # torch.Size([16, 128, 183, 5]) 177 | # torch.Size([16, 256, 61, 5]) 178 | # torch.Size([16, 512, 21, 5]) 179 | # torch.Size([16, 512, 21, 5]) 180 | # torch.Size([16, 1, 21, 5]) 181 | # 182 | # torch.Size([16, 1, 1171, 7]) 183 | # torch.Size([16, 32, 391, 7]) 184 | # torch.Size([16, 128, 131, 7]) 185 | # torch.Size([16, 256, 44, 7]) 186 | # torch.Size([16, 512, 15, 7]) 187 | # torch.Size([16, 512, 15, 7]) 188 | # torch.Size([16, 1, 15, 7]) 189 | # 190 | # torch.Size([16, 1, 745, 11]) 191 | # torch.Size([16, 32, 249, 11]) 192 | # torch.Size([16, 128, 83, 11]) 193 | # torch.Size([16, 256, 28, 11]) 194 | # torch.Size([16, 512, 10, 11]) 195 | # torch.Size([16, 512, 10, 11]) 196 | # torch.Size([16, 1, 10, 11]) 197 | 198 | DEBUG = 0 199 | if DEBUG: print('[DiscriminatorP]') 200 | 201 | fmap = [] 202 | 203 | # 1d to 2d 204 | b, c, t = x.shape 205 | if t % self.period != 0: # pad tail 206 | n_pad = self.period - (t % self.period) 207 | x = F.pad(x, (0, n_pad), "reflect") 208 | t = t + n_pad 209 | # [B, C, T', P] 210 | x = x.view(b, c, t // self.period, self.period) 211 | 212 | if DEBUG: print(x.shape) 213 | for i, l in enumerate(self.convs): 214 | x = l(x) 215 | if DEBUG: print(x.shape) 216 | fmap.append(x) 217 | x = F.leaky_relu(x, LRELU_SLOPE) 218 | x = self.conv_post(x) 219 | if DEBUG: print(x.shape) 220 | x = torch.flatten(x, 1, -1) 221 | 222 | return x, fmap 223 | 224 | 225 | class MultiPeriodDiscriminator(nn.Module): 226 | 227 | def __init__(self): 228 | super().__init__() 229 | 230 | self.discriminators = nn.ModuleList([ 231 | DiscriminatorP(hp.mpd_periods[i]) for i in range(len(hp.mpd_periods)) 232 | ]) 233 | 234 | def forward(self, y, y_hat): 235 | y_d_rs, y_d_gs = [], [] 236 | fmap_rs, fmap_gs = [], [] 237 | 238 | for d in self.discriminators: 239 | y_d_r, fmap_r = d(y) 240 | y_d_g, fmap_g = d(y_hat) 241 | y_d_rs.append(y_d_r) ; fmap_rs.append(fmap_r) 242 | y_d_gs.append(y_d_g) ; fmap_gs.append(fmap_g) 243 | 244 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 245 | 246 | 247 | class StftDiscriminator(nn.Module): 248 | 249 | def __init__(self, i, ch=2): 250 | super().__init__() 251 | 252 | # [1025, 35] 253 | # [513, 69] 254 | # [257, 137] 255 | self.convs = nn.ModuleList([ 256 | weight_norm(Conv2d( ch, 32, (3, 3), (2, 1), padding=(1, 1))), 257 | weight_norm(Conv2d( 32, 64, (3, 3), (2, 2), padding=(1, 1))), 258 | weight_norm(Conv2d( 64, 256, (5, 3), (3, 2), padding=(2, 1))), 259 | weight_norm(Conv2d(256, 512, (5, 3), (3, 2), padding=(2, 1))), 260 | weight_norm(Conv2d(512, 512, 3, 1, padding=1)), 261 | ]) 262 | self.conv_post = weight_norm(Conv2d(512, 1, 3, 1, padding=1)) 263 | 264 | self.convs .apply(init_weights) 265 | self.conv_post.apply(init_weights) 266 | 267 | def forward(self, x): 268 | # torch.Size([16, 2, 1025, 35]) 269 | # torch.Size([16, 32, 513, 35]) 270 | # torch.Size([16, 64, 257, 18]) 271 | # torch.Size([16, 256, 86, 9]) 272 | # torch.Size([16, 512, 29, 5]) 273 | # torch.Size([16, 512, 29, 5]) 274 | # torch.Size([16, 1, 29, 5]) 275 | # 276 | # torch.Size([16, 2, 513, 69]) 277 | # torch.Size([16, 32, 257, 69]) 278 | # torch.Size([16, 64, 129, 35]) 279 | # torch.Size([16, 256, 43, 18]) 280 | # torch.Size([16, 512, 15, 9]) 281 | # torch.Size([16, 512, 
247 | class StftDiscriminator(nn.Module):
248 |
249 |     def __init__(self, i, ch=2):
250 |         super().__init__()
251 |
252 |         # [1025, 35]
253 |         # [513, 69]
254 |         # [257, 137]
255 |         self.convs = nn.ModuleList([
256 |             weight_norm(Conv2d( ch,  32, (3, 3), (2, 1), padding=(1, 1))),
257 |             weight_norm(Conv2d( 32,  64, (3, 3), (2, 2), padding=(1, 1))),
258 |             weight_norm(Conv2d( 64, 256, (5, 3), (3, 2), padding=(2, 1))),
259 |             weight_norm(Conv2d(256, 512, (5, 3), (3, 2), padding=(2, 1))),
260 |             weight_norm(Conv2d(512, 512, 3, 1, padding=1)),
261 |         ])
262 |         self.conv_post = weight_norm(Conv2d(512, 1, 3, 1, padding=1))
263 |
264 |         self.convs    .apply(init_weights)
265 |         self.conv_post.apply(init_weights)
266 |
267 |     def forward(self, x):
268 |         # torch.Size([16, 2, 1025, 35])
269 |         # torch.Size([16, 32, 513, 35])
270 |         # torch.Size([16, 64, 257, 18])
271 |         # torch.Size([16, 256, 86, 9])
272 |         # torch.Size([16, 512, 29, 5])
273 |         # torch.Size([16, 512, 29, 5])
274 |         # torch.Size([16, 1, 29, 5])
275 |         #
276 |         # torch.Size([16, 2, 513, 69])
277 |         # torch.Size([16, 32, 257, 69])
278 |         # torch.Size([16, 64, 129, 35])
279 |         # torch.Size([16, 256, 43, 18])
280 |         # torch.Size([16, 512, 15, 9])
281 |         # torch.Size([16, 512, 15, 9])
282 |         # torch.Size([16, 1, 15, 9])
283 |         #
284 |         # torch.Size([16, 2, 257, 137])
285 |         # torch.Size([16, 32, 129, 137])
286 |         # torch.Size([16, 64, 65, 69])
287 |         # torch.Size([16, 256, 22, 35])
288 |         # torch.Size([16, 512, 8, 18])
289 |         # torch.Size([16, 512, 8, 18])
290 |         # torch.Size([16, 1, 8, 18])
291 |
292 |         DEBUG = 0
293 |         if DEBUG: print('[StftDiscriminator]')
294 |
295 |         fmap = []
296 |         if DEBUG: print(x.shape)
297 |
298 |         # x.shape = [B, 2, C, T]
299 |         for i, l in enumerate(self.convs):
300 |             x = l(x)
301 |             if DEBUG: print(x.shape)
302 |             fmap.append(x)
303 |             x = F.leaky_relu(x, LRELU_SLOPE)
304 |         x = self.conv_post(x)
305 |         if DEBUG: print(x.shape)
306 |         x = torch.flatten(x, 1, -1)
307 |
308 |         return x, fmap
309 |
310 |
311 | class MultiStftDiscriminator(nn.Module):
312 |
313 |     def __init__(self):
314 |         super().__init__()
315 |
316 |         self.discriminators = nn.ModuleList([
317 |             StftDiscriminator(i) for i in range(len(hp.multi_stft_params))
318 |         ])
319 |
320 |     def forward(self, phs, ph_hats):
321 |         ph_d_rs, ph_d_gs = [], []
322 |         fmap_rs, fmap_gs = [], []
323 |
324 |         for d, ph, ph_hat in zip(self.discriminators, phs, ph_hats):
325 |             ph_d_r, fmap_r = d(ph)
326 |             ph_d_g, fmap_g = d(ph_hat)
327 |             ph_d_rs.append(ph_d_r) ; fmap_rs.append(fmap_r)
328 |             ph_d_gs.append(ph_d_g) ; fmap_gs.append(fmap_g)
329 |
330 |         return ph_d_rs, ph_d_gs, fmap_rs, fmap_gs
331 |
--------------------------------------------------------------------------------
/retunegan/models/loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/04/16
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import seaborn as sns
10 | import matplotlib.pyplot as plt
11 |
12 | import hparam as hp
13 | from audio import get_stft_torch
14 | from utils import PI
15 |
16 |
17 | # globals
18 | MaxPool = nn.MaxPool1d(hp.envelope_pool_k)
19 |
20 |
21 | # multi-stft loss
22 | def multi_stft_loss(y, y_g, ret_loss=False, ret_specs=False):
23 |     loss = 0
24 |     if ret_specs: stft_r, stft_g = [], []
25 |
26 |     # [B, 1, T] => [B, T]
27 |     if len(y.shape) == 3:
28 |         y, y_g = y.squeeze(1), y_g.squeeze(1)
29 |
30 |     for n_fft, win_length, hop_length in hp.multi_stft_params:
31 |         # compute the raw mel spectrogram values
32 |         y_mag,   y_mel,   y_phase   = get_stft_torch(y,   n_fft, win_length, hop_length)    # TODO: the ground-truth branch could be hoisted out and cached
33 |         y_g_mag, y_g_mel, y_g_phase = get_stft_torch(y_g, n_fft, win_length, hop_length)
34 |
35 |         # log-scaled spectrogram values
36 |         log_y_mel, log_y_g_mel = torch.log(y_mel), torch.log(y_g_mel)
37 |         log_y_mag, log_y_g_mag = torch.log(y_mag), torch.log(y_g_mag)
38 |         norm_y_phase, norm_y_g_phase = y_phase / PI, y_g_phase / PI
39 |
40 |         if ret_specs:
41 |             # the discriminator judges on the linear spectrogram
42 |             if hp.phd_input == 'stft':
43 |                 stft_r.append(torch.stack([log_y_mag,   norm_y_phase],   dim=1))
44 |                 stft_g.append(torch.stack([log_y_g_mag, norm_y_g_phase], dim=1))
45 |             elif hp.phd_input == 'phase':
46 |                 stft_r.append(torch.stack([log_y_mag, norm_y_phase],   dim=1))
47 |                 stft_g.append(torch.stack([log_y_mag, norm_y_g_phase], dim=1))
48 |             else: raise
49 |
50 |         # the spectral loss compares mel in both raw and log scale
51 |         loss += F.l1_loss(    y_mel,     y_g_mel)
52 |         loss += F.l1_loss(log_y_mel, log_y_g_mel)
53 |
54 |     loss /= len(hp.multi_stft_params)
55 |
56 |     if ret_loss and ret_specs:
57 |         return loss, (stft_r, stft_g)
58 |     elif ret_loss:
59 |         return loss
60 |     elif ret_specs:
61 |         return (stft_r, stft_g)
62 |     else: raise
63 |
64 |
65 | # envelope loss for waveform
66 | def envelope_loss(y, y_g):
67 |     # absolute dynamic envelope
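    # MaxPool1d(hp.envelope_pool_k) slides a non-overlapping max window over the
    # waveform: pooling y traces the upper envelope, pooling -y the (negated)
    # lower envelope. The L1 terms below therefore compare the coarse amplitude
    # contours of real and generated audio while staying insensitive to phase.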
68 |     loss = 0
69 |     loss += torch.mean(torch.abs(MaxPool( y) - MaxPool( y_g)))
70 |     loss += torch.mean(torch.abs(MaxPool(-y) - MaxPool(-y_g)))
71 |
72 |     return loss
73 |
74 |
75 | # dynamic loss for waveform
76 | def dynamic_loss(y, y_g):
77 |     # relative dynamic range
78 |     dyn_y   = torch.abs(MaxPool(y)   + MaxPool(-y))
79 |     dyn_y_g = torch.abs(MaxPool(y_g) + MaxPool(-y_g))
80 |     loss = torch.mean(torch.abs(dyn_y - dyn_y_g))
81 |
82 |     return loss
83 |
84 |
85 | # strip mirror loss for waveform
86 | def strip_mirror_loss(y):
87 |     # a regularizer that may be useless or even counterproductive
88 |
89 |     # ensure length is even
90 |     if y.shape[-1] % 2 != 0: y = y[:,:,:-1]
91 |     # strip split & de-mean
92 |     even, odd = y[:,:,::2], y[:,:,1::2]
93 |     even = even - even.mean()
94 |     odd  = odd  - odd .mean()
95 |     # maximize |e-o|
96 |     loss = torch.mean(-torch.log(torch.clamp_max((torch.abs(even - odd) + 1e-9), max=1.0)))
97 |
98 |     return loss
99 |
100 |
101 | # adversarial loss for discriminator
102 | def discriminator_loss(disc_r, disc_g):
103 |     loss = 0
104 |
105 |     for dr, dg in zip(disc_r, disc_g):    # List[[B, T], ...]
106 |         if hp.relative_gan_loss:
107 |             # maximize the gap between dr & dg
108 |             #r_loss = torch.mean((1 - (dr - dg.detach().mean())) ** 2)
109 |             #r_loss = torch.mean((1 - (dr - dg.detach().mean(axis=-1))) ** 2)
110 |             #r_loss = torch.mean(1 - torch.tanh(dr - dg.detach()))
111 |             #r_loss = torch.mean((1 - (dr - dg.detach())) ** 2)
112 |             #g_loss = torch.mean((0 - dg) ** 2)
113 |             r_loss = torch.mean(torch.mean((1 - (dr - dg.detach())) ** 2, axis=-1))
114 |             g_loss = torch.mean(torch.mean((0 - dg) ** 2, axis=-1))
115 |             #r_loss = torch.mean(-torch.log(dr))
116 |             #g_loss = torch.mean(dg)
117 |         else:
118 |             # let dr -> 1, dg -> 0
119 |             #r_loss = torch.mean((1 - dr) ** 2)
120 |             #g_loss = torch.mean((0 - dg) ** 2)
121 |             r_loss = torch.mean(torch.mean((1 - dr) ** 2, axis=-1))
122 |             g_loss = torch.mean(torch.mean((0 - dg) ** 2, axis=-1))
123 |         loss += (r_loss + g_loss)
124 |
125 |     return loss
126 |
127 |
128 | # adversarial loss for generator
129 | def generator_loss(disc_g, disc_r):
130 |     loss = 0
131 |
132 |     for dg, dr in zip(disc_g, disc_r):
133 |         if hp.relative_gan_loss:
134 |             # let dg ~= dr
135 |             #g_loss = torch.mean((dr.detach().mean(axis=-1) - dg) ** 2)
136 |             #g_loss = torch.mean(1 - torch.tanh(dg - dr.detach()))
137 |             #g_loss = torch.mean((dr.detach() - dg) ** 2)
138 |             g_loss = torch.mean(torch.mean((dg - dr.detach()) ** 2, axis=-1))
139 |         else:
140 |             # let dg -> 1
141 |             #g_loss = torch.mean((1 - dg) ** 2)
142 |             g_loss = torch.mean(torch.mean((1 - dg) ** 2, axis=-1))
143 |         loss += g_loss
144 |
145 |     return loss
146 |
147 |
148 | # feature map loss for generator
149 | def feature_loss(fmap_r, fmap_g):
150 |     loss = 0
151 |
152 |     for dr, dg in zip(fmap_r, fmap_g):
153 |         for r, g in zip(dr, dg):
154 |             loss += F.l1_loss(r, g)
155 |
156 |     return loss
157 |
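# A hedged sketch of how these terms typically combine for one generator step
# (assumption: the lambda weights are illustrative; the actual weighting, if
# any, lives in hparam.py / train.py):
#
#   y_g = generator(mel, wav_tmpl)
#   stft_l, (stft_r, stft_g) = multi_stft_loss(y, y_g, ret_loss=True, ret_specs=True)
#   d_r, fmap_r = disc(y) ; d_g, fmap_g = disc(y_g)
#   loss_G = generator_loss(d_g, d_r) \
#          + lambda_fm   * feature_loss(fmap_r, fmap_g) \
#          + lambda_stft * stft_l \
#          + lambda_env  * envelope_loss(y, y_g)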
--------------------------------------------------------------------------------
/retunegan/server.py:
--------------------------------------------------------------------------------
1 | import io, os
2 | import pickle
3 | from time import time
4 | from argparse import ArgumentParser
5 |
6 | import torch
7 | import numpy as np
8 | from flask import Flask, request, jsonify, send_file
9 |
10 | import hparam as h
11 | from models.generator import *
12 | from audio import mag_to_mel, inv_mag
13 | from utils import scan_checkpoint, load_checkpoint
14 |
15 |
16 | os.environ['LIBROSA_CACHE_LEVEL'] = '50'
17 |
18 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
19 | torch.backends.cudnn.enabled = True
20 | torch.backends.cudnn.benchmark = True
21 | torch.manual_seed(h.randseed)
22 | torch.cuda.manual_seed(h.randseed)
23 |
24 | torch.autograd.set_detect_anomaly(h.debug)
25 |
26 |
27 | # globals
28 | app = Flask(__name__)
29 | generator = None
30 |
31 |
32 | # vocode
33 | @app.route('/vocode', methods=['POST'])
34 | def vocode():
35 |     try:
36 |         # check mag
37 |         mag = pickle.loads(request.data)
38 |         print(f'mag.shape: {mag.shape}, dyn_range: [{mag.min()}, {mag.max()}]')
39 |         if mag.shape[1] == h.n_freq: mag = mag.T    # ensure [F, T]
40 |         # ref: preprocess in `data.Dataset.__getitem__()`
41 |         mel = mag_to_mel(mag)
42 |         wavlen = h.hop_length * mag.shape[1]
43 |         wav_tmpl = inv_mag(mag, wavlen=wavlen-1)
44 |         wav_tmpl = np.pad(wav_tmpl, (0, 1))
45 |
46 |         # mel to wav
47 |         s = time()
48 |         with torch.no_grad():
49 |             mel      = torch.from_numpy(mel)     .to(device, non_blocking=True).float().unsqueeze(0)
50 |             wav_tmpl = torch.from_numpy(wav_tmpl).to(device, non_blocking=True).float().unsqueeze(0).unsqueeze(1)
51 |             y_g_hat = generator(mel, wav_tmpl)
52 |             wav = y_g_hat.squeeze()
53 |             wav = wav.cpu().numpy().astype(np.float32)
54 |         t = time()
55 |         print(f'wav.shape: {wav.shape}, dyn_range: [{wav.min()}, {wav.max()}]')
56 |         print(f'[Vocode] Done in {t - s:.2f}s')
57 |
58 |         # transfer
59 |         bio = io.BytesIO()
60 |         bio.write(pickle.dumps(wav))    # float32 -> byte
61 |         bio.seek(0)                     # reset fp to beginning for `send_file` to read
62 |         return send_file(bio, mimetype='application/octet-stream')
63 |
64 |     except Exception as e:
65 |         print('[Error] %r' % e)
66 |         return jsonify({'error': repr(e)})
67 |
68 |
69 | if __name__ == '__main__':
70 |     parser = ArgumentParser()
71 |     parser.add_argument('--log_path', required=True)
72 |     parser.add_argument('--host', type=str, default='0.0.0.0')
73 |     parser.add_argument('--port', type=int, default=5104)
74 |     args = parser.parse_args()
75 |
76 |     # load ckpt
77 |     generator = globals().get(f'Generator_{h.generator_ver}')().to(device)
78 |     state_dict_g = load_checkpoint(scan_checkpoint(args.log_path, 'g_'), device)
79 |     generator.load_state_dict(state_dict_g['generator'])
80 |     generator.eval()
81 |     generator.remove_weight_norm()
82 |
83 |     app.run(host=args.host, port=args.port, debug=False)
84 |
--------------------------------------------------------------------------------
/retunegan/tools/spec2wavset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/02/14
4 |
5 | # Borrowing the idea of RefineGAN:
6 | #   RefineGAN: build a waveform template from the f0/c0/voice-flag information predicted by the upstream model:
7 | #     fill unvoiced segments with Gaussian white noise, place pulses at the f0 rate in voiced segments, then shape the overall envelope by c0
8 | #   spec2wavset: the stft spectrum decomposes the signal into a set of sinusoids at equally spaced frequencies, so we can directly sum this set of sine waves as the waveform template
9 | # Q: why not use the GriffinLim output as the template, so that the downstream network acts as a frequency-domain denoiser?
10 | # A: denoising is hard, so think in reverse: a sum of sinusoids is fairly clean, and we add noise on top of it!
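# A minimal self-contained sketch of that idea (illustration only, not the
# script's actual pipeline; all names below are hypothetical):
#
#   import numpy as np
#
#   def sine_set_template(mag, n_fft, hop_length, sr):
#       """mag: [F, T] linear magnitude -> clean sinusoid-sum waveform."""
#       n_freq, n_frames = mag.shape
#       t = np.arange(n_fft) / sr                  # time axis of one frame
#       freqs = np.arange(n_freq) * sr / n_fft     # center frequency per bin
#       y = np.zeros(n_frames * hop_length + n_fft)
#       win = np.hanning(n_fft)
#       for i in range(n_frames):                  # overlap-add frame by frame
#           frame = (mag[:, i:i+1] * np.sin(2 * np.pi * np.outer(freqs, t))).sum(axis=0)
#           y[i*hop_length : i*hop_length + n_fft] += frame * win
#       return y / max(1e-9, np.abs(y).max())      # rough normalization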
11 | # *Note: multiple sets of stft params are needed to avoid missing frequencies and to mitigate spectral leakage caused by the window function
12 | #   fft_params:
13 | #     n_fft  win_length  hop_length
14 | #     2048   1024        256
15 | #     1024   512         128
16 | #     512    256         64
17 |
18 | # Extracting frequency and amplitude from an FFT: https://blog.csdn.net/taw19960426/article/details/101684663
19 | # wide-band spectrogram:   frame length  3ms, hop_length=48(16000Hz)/66(22050Hz)
20 | # narrow-band spectrogram: frame length 20ms, hop_length=320(16000Hz)/441(22050Hz)
21 |
22 | # NOTE: statistical profile
23 | #   magnitude: a single peak (the speech signal) plus a plateau on the left (noise floor)
24 | #   phase: nearly uniform over [-pi, pi]
25 |
26 | import os
27 | import numpy as np
28 | import numpy.random as R
29 | import matplotlib.pyplot as plt
30 | import seaborn as sns
31 | import librosa as L
32 | from scipy.io import wavfile
33 |
34 | sample_rate = 16000
35 | n_mel = 80
36 | fmin = 70
37 | fmax = 7600
38 | fft_params = [
39 |     # (n_fft, win_length, hop_length)
40 |     (4096, 2048, 512),
41 |     (2048, 1024, 256),
42 |     (1024,  512, 128),
43 |     ( 512,  256,  64),
44 |     ( 256,  128,  32),
45 | ]
46 |
47 |
48 | def calc_spec(y, n_fft, hop_length, win_length, clamp_lower=False):
49 |     D = L.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window='hann')
50 |     S = np.abs(D)
51 |     print('S.min():', S.min())
52 |     mag = np.log(S.clip(min=1e-5)) if clamp_lower else S
53 |     M = L.feature.melspectrogram(S=S, sr=sample_rate, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
54 |                                  n_mels=n_mel, fmin=fmin, fmax=fmax, window='hann', power=1, htk=True)
55 |     print('M.min():', M.min())
56 |     mel = np.log(M.clip(min=1e-5)) if clamp_lower else M
57 |     return mag, mel
58 |
59 |
60 | def get_specs(fp):
61 |     y, _ = L.load(fp, sr=sample_rate, mono=True, res_type='kaiser_best')
62 |     return [calc_spec(y, n_fft, hop_length, win_length) for n_fft, win_length, hop_length in fft_params]
63 |
64 |
65 | def display_specs(specs):
66 |     n_fig = len(specs)
67 |     for i in range(n_fig):
68 |         if isinstance(specs[i], (list, tuple)):
69 |             plt.subplot(n_fig*100+10+i+1)
70 |             ax = sns.heatmap(specs[i][0])
71 |             ax.invert_yaxis()
72 |             plt.subplot(n_fig*100+20+i+1)
73 |             ax = sns.heatmap(specs[i][1])
74 |             ax.invert_yaxis()
75 |         else:
76 |             plt.subplot(n_fig*100+10+i+1)
77 |             ax = sns.heatmap(specs[i])
78 |             ax.invert_yaxis()
79 |     plt.show()
80 |     plt.clf()
81 |
82 | def f(k, n_fft):
83 |     return k * (sample_rate / n_fft)
84 |
85 | def sin(x, A, freq, phi=0):
86 |     w = 2 * np.pi * freq    # f = w / (2*pi)
87 |     return A * np.sin(w * x + phi)
88 |
89 |
90 | def extract_f_A(fp):
91 |     for i, (mag, _) in enumerate(get_specs(fp)):
92 |         n_fft, _, _ = fft_params[i]
93 |         print(f'[n_fft={n_fft}]')
94 |
95 |         fr = sample_rate / n_fft    # frequency resolution unit (Hz per bin)
96 |         print('mag min/max/mean:', mag.min(), mag.max(), mag.mean())
97 |
98 |         f_A = []
99 |         thresh = mag.mean() * 2
100 |         for i, energy in enumerate(mag.T):    # [T, F]: one spectrum per frame
101 |             print(f'frame {i}: ')
102 |             j = 0
103 |             while j < len(energy):
104 |                 # threshold response
105 |                 if energy[j] <= thresh:
106 |                     j += 1
107 |                 else:
108 |                     # skip the rising slope
109 |                     while j + 1 < len(energy) and energy[j+1] >= energy[j]:
110 |                         j += 1
111 |                     # take the peak
112 |                     f_A.append((fr * j, energy[j]))
113 |                     print(f'   freq({j}) = {fr * j} Hz, amp = {energy[j]}')
114 |                     # skip the falling slope
115 |                     while j + 1 < len(energy) and energy[j+1] < energy[j]:
116 |                         j += 1
117 |
118 |
119 | def demo_extract_f_A():
120 |     st = 1    # sampling duration: 1s
121 |     x = np.arange(0, 1, st / sample_rate)
122 |     print('signal len:', len(x))
123 |     y1 = sin(x,  2, 207)
124 |     y2 = sin(x, -1, 843)
125 |     y = y1 + y2
126 |
127 |     plt.subplot(311) ; plt.plot(x, y1)
128 |     plt.subplot(312) ; plt.plot(x, y2)
129 |     plt.subplot(313) ; plt.plot(x, y)
130 |     plt.show()
131 |
132 |     n_fft = 2048
133 |     mag, _ = calc_spec(y, n_fft=n_fft, win_length=1024, hop_length=256)
134 |     display_specs([mag])
135 |
136 |     mag = mag.T
137 |
138 |     n_fft = fft_params[1][0]
139 |     A = 2 * mag / n_fft
140 |     fr = sample_rate / n_fft    # frequency resolution unit (Hz per bin)
141 |
142 |     f_A = []
143 |     thresh = 0.1
144 |     for i, energy in enumerate(A):
145 |         print(f'frame {i}: ')
146 |         j = 0
147 |         while j < len(energy):
148 |             # threshold response
149 |             if energy[j] <= thresh:
150 |                 j += 1
151 |             else:
152 |                 # skip the rising slope
153 |                 while j + 1 < len(energy) and energy[j+1] >= energy[j]:
154 |                     j += 1
155 |                 # take the peak
156 |                 f_A.append((fr * j, energy[j]))
157 |                 print(f'   freq({j}) = {fr * j} Hz, amp = {energy[j]}')
158 |                 # skip the falling slope
159 |                 while j + 1 < len(energy) and energy[j+1] < energy[j]:
160 |                     j += 1
161 |
162 |     return f_A
163 |
164 |
165 | def demo_stft_istft(fp):
166 |     n_fft, win_length, hop_length = 2048, 1024, 256
167 |
168 |     y, _ = L.load(fp, sr=sample_rate, mono=True, res_type='kaiser_best')
169 |     length = len(y)    # for align output
170 |
171 |     stft = L.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
172 |     mag, phase = np.abs(stft), np.angle(stft)
173 |     y_istft = L.istft(stft, win_length=win_length, hop_length=hop_length, length=length)
174 |     D = L.stft(y_istft, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
175 |     mag_istft, phase_istft = np.abs(D), np.angle(D)
176 |
177 |     print('mag.min():', mag.min(), 'mag.max():', mag.max())
178 |     plt.subplot(311) ; plt.hist(mag.flatten(), bins=150)       # [0, 10]
179 |     maglog = np.log(mag)
180 |     print('maglog.min():', maglog.min(), 'maglog.max():', maglog.max())
181 |     plt.subplot(312) ; plt.hist(maglog.flatten(), bins=150)    # [-12, 3]
182 |     plt.subplot(313) ; plt.hist(phase.flatten(), bins=150)     # [-pi, pi]
183 |     plt.show()
184 |
185 |     phase_ristft_r = np.random.uniform(low=-np.pi, high=np.pi, size=phase.shape)
186 |     rstft = mag * np.exp(1j * phase_ristft_r)
187 |     y_ristft = L.istft(rstft, win_length=win_length, hop_length=hop_length, length=length)
188 |     D = L.stft(y_ristft, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
189 |     mag_ristft, phase_ristft = np.abs(D), np.angle(D)
190 |
191 |     y_gl = L.griffinlim(mag, hop_length=hop_length, win_length=win_length, length=length)
192 |     D = L.stft(y_gl, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
193 |     mag_gl, phase_gl = np.abs(D), np.angle(D)
194 |
195 |     print('y_istft error:',  np.sum(y - y_istft))
196 |     print('y_ristft error:', np.sum(y - y_ristft))
197 |     print('y_gl error:',     np.sum(y - y_gl))
198 |
199 |     print('y_istft mirror error:',  np.sum([e - o for e, o in zip(y_istft[::2],  y_istft[1::2])]))
200 |     print('y_ristft mirror error:', np.sum([e - o for e, o in zip(y_ristft[::2], y_ristft[1::2])]))
201 |     print('y_gl mirror error:',     np.sum([e - o for e, o in zip(y_gl[::2],     y_gl[1::2])]))
202 |
203 |     print('mag_istft error:',  np.sum(mag - mag_istft))
204 |     print('mag_ristft error:', np.sum(mag - mag_ristft))
205 |     print('mag_gl error:',     np.sum(mag - mag_gl))
206 |
207 |     print('phase_istft error:',    np.sum(phase - phase_istft))
208 |     print('phase_ristft error:',   np.sum(phase - phase_ristft))
209 |     print('phase_ristft_r error:', np.sum(phase_ristft_r - phase_ristft))
210 |     print('phase_gl error:',       np.sum(phase - phase_gl))
211 |
212 |     plt.subplot(231) ; sns.heatmap(mag)
213 |     plt.subplot(232) ; sns.heatmap(mag_ristft)
214 |     plt.subplot(233) ; sns.heatmap(mag_gl)
215 |     plt.subplot(234) ; sns.heatmap(phase)
216 |     plt.subplot(235) ; sns.heatmap(phase_ristft)
217 |     plt.subplot(236) ; sns.heatmap(phase_gl)
218 |     plt.show()
219 |
220 |     wavfile.write('y.wav', sample_rate, y)
221 |
wavfile.write('y_istft.wav', sample_rate, y_istft) 222 | wavfile.write('y_ristft.wav', sample_rate, y_ristft) 223 | wavfile.write('y_gl.wav', sample_rate, y_gl) 224 | 225 | 226 | if __name__ == '__main__': 227 | dp = r'D:\Desktop\Workspace\Data\DataBaker\Wave' 228 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')]) 229 | fp = os.path.join(dp, fn) 230 | 231 | demo_stft_istft(fp) 232 | #extract_f_A(fp) 233 | -------------------------------------------------------------------------------- /retunegan/tools/test_downsample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Armit 3 | # Create Time: 2022/03/10 4 | 5 | # use this script to decide `hp.downsample_pool_k` 6 | # higher sample_rate needs larger `k` 7 | 8 | 9 | import os 10 | import random as R 11 | import librosa as L 12 | from scipy.io import wavfile 13 | import matplotlib.pyplot as plt 14 | import torch 15 | import torch.nn as nn 16 | import torchaudio.transforms as T 17 | 18 | import hparam as hp 19 | 20 | resamplers = [ 21 | T.Resample(hp.sample_rate, hp.sample_rate//2, resampling_method='sinc_interpolation'), # kaiser_window 22 | T.Resample(hp.sample_rate, hp.sample_rate//4, resampling_method='sinc_interpolation'), 23 | T.Resample(hp.sample_rate, hp.sample_rate//8, resampling_method='sinc_interpolation'), 24 | ] 25 | avg_pools = [ 26 | nn.AvgPool1d(kernel_size=2, stride=2, padding=1), 27 | nn.AvgPool1d(kernel_size=4, stride=2, padding=2), 28 | nn.AvgPool1d(kernel_size=6, stride=2, padding=3), # for 16000 Hz 29 | nn.AvgPool1d(kernel_size=8, stride=2, padding=4), 30 | ] 31 | 32 | dp = r'C:\Users\Kahsolt\Desktop\Workspace\Data\DataBaker\Wave' 33 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')]) 34 | fp = os.path.join(dp, fn) 35 | 36 | y = L.load(fp, hp.sample_rate)[0] 37 | y = torch.from_numpy(y).unsqueeze(0).unsqueeze(0) 38 | 39 | plt.subplot(len(resamplers)+1, 1, 1) 40 | plt.plot(y.numpy().squeeze()) 41 | for i, resampler in enumerate(resamplers): 42 | s = resampler(y) 43 | plt.subplot(len(resamplers)+1, 1, i+2) 44 | plt.plot(s.numpy().squeeze()) 45 | plt.show() 46 | 47 | for i, avg_pool in enumerate(avg_pools): 48 | plt.subplot(len(resamplers)+1, 1, 1) 49 | plt.plot(y.numpy().squeeze()) 50 | s = y 51 | for j in range(len(resamplers)): 52 | s = avg_pool(s) 53 | plt.subplot(len(resamplers)+1, 1, j+2) 54 | plt.plot(s.numpy().squeeze()) 55 | plt.show() 56 | 57 | -------------------------------------------------------------------------------- /retunegan/tools/test_envolope.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Armit 3 | # Create Time: 2022/03/10 4 | 5 | # use this script to decide `hp.envelope_pool_k` 6 | # higher sample_rate needs larger `k` 7 | 8 | 9 | import os 10 | import random as R 11 | import librosa as L 12 | import matplotlib.pyplot as plt 13 | import torch 14 | import torch.nn as nn 15 | 16 | import hparam as hp 17 | 18 | 19 | max_pools = [ 20 | nn.MaxPool1d( 64, 1), 21 | nn.MaxPool1d(128, 1), # for 16000 Hz 22 | nn.MaxPool1d(160, 1), # for 22050 Hz 23 | nn.MaxPool1d(256, 1), # for 44100 Hz 24 | nn.MaxPool1d(512, 1), 25 | ] 26 | 27 | dp = r'C:\Users\Kahsolt\Desktop\Workspace\Data\DataBaker\Wave' 28 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')]) 29 | fp = os.path.join(dp, fn) 30 | 31 | 32 | y = L.load(fp, hp.sample_rate)[0] 33 | y_np = y 34 | y = torch.from_numpy(y).unsqueeze(0).unsqueeze(0) 35 | 36 | 
plt.subplot(4, 1, 1)
37 | plt.plot(y.numpy().squeeze())
38 | plt.title('y')
39 | pool = max_pools[2]
40 | u =  pool( y)
41 | d = -pool(-y)
42 | plt.subplot(4, 1, 2)
43 | plt.title('y_envelope')
44 | plt.plot(u.numpy().squeeze())
45 | plt.plot(d.numpy().squeeze())
46 | plt.subplot(4, 1, 3)
47 | plt.title('y_even')
48 | plt.plot(y_np[::2])
49 | plt.subplot(4, 1, 4)
50 | plt.title('y_odd')
51 | plt.plot(y_np[1::2])
52 | plt.show()
53 |
54 | exit(0)
55 |
56 | y = L.load(fp, hp.sample_rate)[0]
57 | y = torch.from_numpy(y).unsqueeze(0).unsqueeze(0)
58 |
59 | plt.subplot(len(max_pools)+1, 1, 1)
60 | plt.plot(y.numpy().squeeze())
61 | for i, pool in enumerate(max_pools):
62 |     u =  pool( y)
63 |     d = -pool(-y)
64 |     plt.subplot(len(max_pools)+1, 1, i+2)
65 |     plt.plot(u.numpy().squeeze())
66 |     plt.plot(d.numpy().squeeze())
67 | plt.show()
68 |
--------------------------------------------------------------------------------
/retunegan/tools/test_griffinlim.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/03/10
4 |
5 | import os
6 | import random as R
7 | import copy
8 | import librosa as L
9 | import seaborn as sns
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from scipy.io import wavfile
13 |
14 | import hparam as hp
15 |
16 | R.seed(114514)
17 |
18 |
19 | def decompose(D):
20 |     P, S = np.angle(D), np.abs(D)
21 |     logS = np.log(S.clip(1e-5, None))
22 |     return P, S, logS
23 |
24 |
25 | def griffin_lim(S, n_iter=60):
26 |     X_best = copy.deepcopy(S)
27 |     for _ in range(n_iter):
28 |         X_t = L.istft(X_best, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
29 |         X_best = L.stft(X_t, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
30 |         phase = X_best / np.maximum(1e-8, np.abs(X_best))
31 |         X_best = S * phase
32 |     X_t = L.istft(X_best, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
33 |     y = np.real(X_t)
34 |     return y
35 |
36 |
37 | def griffinlim(S, n_iter=60):
38 |     angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
39 |
40 |     for _ in range(n_iter):
41 |         full = np.abs(S).astype(np.complex) * angles
42 |         inverse = L.istft(full, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
43 |         rebuilt = L.stft(inverse, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
44 |         angles = np.exp(1j * np.angle(rebuilt))
45 |     full = np.abs(S).astype(np.complex) * angles
46 |     inverse = L.istft(full, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
47 |     return inverse
48 |
49 |
50 | def griffinlim_conj(P, n_iter=60):
51 |     #S = np.random.uniform(low=np.exp(-4), high=np.exp(4), size=P.shape)
52 |     S = np.random.normal(loc=-4, scale=1, size=P.shape)
53 |     #S = np.random.rand(*P.shape)
54 |
55 |     for _ in range(n_iter):
56 |         D = S * np.exp(1j * P)
57 |         y_hat = L.istft(D, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
58 |         D_hat = L.stft(y_hat, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
59 |         S = np.abs(D_hat)
60 |
61 |     D = S * np.exp(1j * P)
62 |     y = L.istft(D, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
63 |     return y
64 |
65 |
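# NOTE on the three reconstructions compared below, as inferred from the code
# above: griffin_lim seeds the iteration with the magnitude itself as the
# initial complex spectrum, griffinlim starts from uniformly random phase (the
# librosa-style variant), and griffinlim_conj inverts the problem: it keeps the
# true phase P fixed and iterates the magnitude from a random init.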
66 | dp = r'C:\Users\Kahsolt\Desktop\Workspace\Data\DataBaker\Wave'
67 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')])
68 | fp = os.path.join(dp, fn)
69 |
70 | y = L.load(fp, hp.sample_rate)[0]
71 | ylen = len(y)
72 | D = L.stft(y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
73 | P, S, logS = decompose(D)
74 |
75 | y1 = griffin_lim(S)
76 | y2 = griffinlim(S)
77 | y3 = griffinlim_conj(P)
78 |
79 | plt.subplot(4, 1, 1) ; plt.plot(y)
80 | plt.subplot(4, 1, 2) ; plt.plot(y1)
81 | plt.subplot(4, 1, 3) ; plt.plot(y2)
82 | plt.subplot(4, 1, 4) ; plt.plot(y3)
83 | plt.show()
84 |
85 | D1 = L.stft(y1, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
86 | D2 = L.stft(y2, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
87 | D3 = L.stft(y3, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
88 | _, _, logS1 = decompose(D1)
89 | _, _, logS2 = decompose(D2)
90 | _, _, logS3 = decompose(D3)
91 |
92 | plt.subplot(2, 2, 1) ; sns.heatmap(logS,  cbar=False)
93 | plt.subplot(2, 2, 2) ; sns.heatmap(logS1, cbar=False)
94 | plt.subplot(2, 2, 3) ; sns.heatmap(logS2, cbar=False)
95 | plt.subplot(2, 2, 4) ; sns.heatmap(logS3, cbar=False)
96 | plt.show()
97 |
98 | wavfile.write('y_griffinlim_conj.wav', hp.sample_rate, y3)
99 |
--------------------------------------------------------------------------------
/retunegan/tools/test_istft_iter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/03/10
4 |
5 | import os
6 | import random as R
7 | import librosa as L
8 | import seaborn as sns
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | from scipy.io import wavfile
12 |
13 | import hparam as hp
14 |
15 | R.seed(114514)
16 |
17 |
18 | def decompose(D):
19 |     P, S = np.angle(D), np.abs(D)
20 |     return P, S
21 |
22 |
23 | dp = r'C:\Users\Kahsolt\Desktop\Workspace\Data\DataBaker\Wave'
24 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')])
25 | fp = os.path.join(dp, fn)
26 |
27 | y = L.load(fp, hp.sample_rate)[0]
28 | ylen = len(y)
29 | D = L.stft(y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
30 | P, S = decompose(D)
31 |
32 | y_loss, P_loss, S_loss = [], [], []
33 | D_i, P_i, S_i = D, P, S
34 | for i in range(1000):
35 |     #D_i = S_i * np.exp(1j * P_i)
36 |     y_i = L.istft(D_i, win_length=hp.win_length, hop_length=hp.hop_length, length=ylen)
37 |     D_i = L.stft(y_i, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
38 |     P_i, S_i = decompose(D_i)
39 |
40 |     y_l = np.mean(np.abs(y - y_i)) ; y_loss.append(y_l)
41 |     P_l = np.mean(np.abs(P - P_i)) ; P_loss.append(P_l)
42 |     S_l = np.mean(np.abs(S - S_i)) ; S_loss.append(S_l)
43 |
44 | plt.subplot(3, 1, 1) ; plt.plot(y_loss)
45 | plt.subplot(3, 1, 2) ; plt.plot(P_loss)
46 | plt.subplot(3, 1, 3) ; plt.plot(S_loss)
47 | plt.show()
48 |
--------------------------------------------------------------------------------
/retunegan/tools/test_pesq.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/03/10
4 |
5 | # use this script to evaluate PESQ
6 |
7 | import sys
8 | import librosa as L
9 | from pesq import pesq
10 |
11 | sr = 16000
12 | #sr = 8000
13 |
14 | BASE_PATH = r'C:\Users\Kahsolt\Desktop\Workspace\Essay\基于韵律优化与波形修复的汉语语音合成方法研究\audio'
15 |
16 |
17 | def test(fp_y, fp_y_hat):
18 |     ref, _ = L.load(BASE_PATH + '\\' + fp_y,     sr)
19 |     deg, _ = L.load(BASE_PATH + '\\' + fp_y_hat, sr)
20 |
21 |     print(pesq(sr, ref, ref, 'wb'))
22 |     print(pesq(sr, ref, deg, 'wb'))
23 |
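# NOTE (general PESQ background, not project-specific): pesq(..., 'wb') returns
# a MOS-LQO score, roughly in [-0.5, 4.64]; the ref-vs-ref call above is a
# sanity ceiling and should print the maximum attainable score.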
24 | print('[gl]')
25 | test('gl_gt.wav', 'gl_64i.wav')
26 | print('[mlg]')
27 | test('mlg-gt.wav', 'mlg-100e.wav')
28 | print('[hfg-40k]')
29 | test('hfg-gt.wav', 'hfg-40k.wav')
30 | print('[hfg-85k]')
31 | test('hfg-gt.wav', 'hfg-85k.wav')
32 |
33 | print('[taco]')
34 | test('taco-gt.wav', 'taco-103k.wav')
35 |
--------------------------------------------------------------------------------
/retunegan/tools/test_phase_recover.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/03/10
4 |
5 | import os
6 | import random as R
7 | import librosa as L
8 | import seaborn as sns
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | from scipy.io import wavfile
12 |
13 | import hparam as hp
14 |
15 | R.seed(114514)
16 |
17 |
18 | def decompose(D):
19 |     P, S = np.angle(D), np.abs(D)
20 |     logS = np.log(S.clip(1e-5, None))
21 |     return P, S, logS
22 |
23 |
24 | dp = r'C:\Users\Kahsolt\Desktop\Workspace\Data\DataBaker\Wave'
25 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')])
26 | fp = os.path.join(dp, fn)
27 |
28 | sr, hsr = hp.sample_rate, hp.sample_rate//2
29 | y = L.load(fp, hp.sample_rate)[0]
30 | if len(y) % 2 != 0: y = y[:-1]
31 | ylen = len(y)
32 |
33 |
34 | print('>> effect of noise in time domain')
35 | D = L.stft(y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
36 | P, S, logS = decompose(D)
37 | for eps in [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]:
38 |     y_n = y + np.random.uniform(low=-eps, high=eps, size=y.shape)
39 |     D_n = L.stft(y_n, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
40 |     P_n, S_n, logS_n = decompose(D_n)
41 |     print('eps =', eps)
42 |     print('|y - y_n|:', np.mean(np.abs(y - y_n)))
43 |     print('|P - P_n|:', np.mean(np.abs(P - P_n)))
44 |     print('|S - S_n|:', np.mean(np.abs(S - S_n)))
45 |     print('|logS - logS_n|:', np.mean(np.abs(logS - logS_n)))
46 |     print()
47 |     if False: sns.heatmap(logS_n) ; plt.show()
48 |
49 |
50 | print('>> reverse with original magnitude & original phase')
51 | y_i = L.istft(D, win_length=hp.win_length, hop_length=hp.hop_length, length=ylen)
52 | D_i = L.stft(y_i, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
53 | P_i, S_i, logS_i = decompose(D_i)
54 |
55 | print('|y - y_i|:', np.mean(np.abs(y - y_i)))
56 | print('|P - P_i|:', np.mean(np.abs(P - P_i)))
57 | print('|S - S_i|:', np.mean(np.abs(S - S_i)))
58 | print('|logS - logS_i|:', np.mean(np.abs(logS - logS_i)))
59 | print()
60 |
61 |
62 | print('>> reverse with original magnitude by GriffinLim')    # this step is inherently somewhat random
63 | y_gl = L.griffinlim(S, hop_length=hp.hop_length, win_length=hp.win_length, length=ylen)
64 | D_gl = L.stft(y_gl, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
65 | P_gl, S_gl, logS_gl = decompose(D_gl)
66 |
67 | print('|y - y_gl|:', np.mean(np.abs(y - y_gl)))
68 | print('|P - P_gl|:', np.mean(np.abs(P - P_gl)))
69 | print('|S - S_gl|:', np.mean(np.abs(S - S_gl)))
70 | print('|logS - logS_gl|:', np.mean(np.abs(logS - logS_gl)))
71 | print()
72 |
73 |
74 | print('>> reverse with original magnitude & random phase')
75 | S_r, logS_r = S, logS
76 | P_r = np.random.uniform(low=-np.pi, high=np.pi, size=P.shape)
77 | D_r = S_r * np.exp(1j * P_r)
78 | y_r = L.istft(D_r, win_length=hp.win_length, hop_length=hp.hop_length, length=ylen)
79 | D_rr = L.stft(y_r, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
80 | P_rr, S_rr, logS_rr = decompose(D_rr)
81 |
82 | print('|y - y_r|:', np.mean(np.abs(y - y_r)))
83 | print('|P - P_r|:',
np.mean(np.abs(P - P_r)))
84 | print('|P - P_rr|:', np.mean(np.abs(P - P_rr)))
85 | print('|P_r - P_rr|:', np.mean(np.abs(P_r - P_rr)))
86 | print('|S - S_rr|:', np.mean(np.abs(S - S_rr)))
87 | print('|logS - logS_rr|:', np.mean(np.abs(logS - logS_rr)))
88 | print()
89 |
90 | print('>> reverse with random magnitude & original phase')
91 | # none of these work:
92 | #S_f = np.exp(np.random.normal(loc=logS.mean(), scale=logS.std(), size=S.shape))/10
93 | # S_f = np.random.normal(loc=S.mean(), scale=S.std(), size=S.shape)
94 | # these work:
95 | #S_f = np.random.normal(loc=logS.mean(), scale=logS.std(), size=S.shape)
96 | #S_f = np.random.normal(loc=S.mean(), scale=S.mean(), size=S.shape)
97 | S_f = np.random.uniform(low=S.min(), high=S.max(), size=S.shape)
98 | P_f = P
99 | D_f = S_f * np.exp(1j * P_f)
100 | y_f = L.istft(D_f, win_length=hp.win_length, hop_length=hp.hop_length, length=ylen)
101 | D_fr = L.stft(y_f, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
102 | P_fr, S_fr, logS_fr = decompose(D_fr)
103 |
104 | print('|y - y_f|:', np.mean(np.abs(y - y_f)))
105 | print('|P - P_fr|:', np.mean(np.abs(P - P_fr)))
106 | print('|S - S_f|:', np.mean(np.abs(S - S_f)))
107 | print('|S - S_fr|:', np.mean(np.abs(S - S_fr)))
108 | print()
109 |
110 | #if True: sns.heatmap(S_f) ; plt.show()
111 | #if True: sns.heatmap(P_fr) ; plt.show()
112 | if True: sns.heatmap(S_fr) ; plt.show()
113 |
--------------------------------------------------------------------------------
/retunegan/tools/test_strip_mirror.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/03/10
4 |
5 | import os
6 | import random as R
7 | import librosa as L
8 | import seaborn as sns
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | from scipy.io import wavfile
12 | import torch
13 | import torch.nn as nn
14 |
15 | import hparam as hp
16 |
17 | #R.seed(114514)
18 |
19 |
20 | def decompose(D):
21 |     P, S = np.angle(D), np.abs(D)
22 |     logS = np.log(S.clip(1e-5, None))
23 |     return P, S, logS
24 |
25 |
26 | dp = r'C:\Users\Kahsolt\Desktop\Workspace\Data\DataBaker\Wave'
27 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')])
28 | fp = os.path.join(dp, fn)
29 |
30 | sr, hsr = hp.sample_rate, hp.sample_rate//2
31 | y = L.load(fp, hp.sample_rate)[0]
32 | if len(y) % 2 != 0: y = y[:-1]
33 | ylen = len(y)
34 |
35 | avgpool = nn.AvgPool1d(hp.downsample_pool_k, 2)
36 | Ty = torch.from_numpy(y).unsqueeze(0).unsqueeze(0)
37 | for i in range(3):
38 |     if Ty.shape[-1] % 2 != 0: Ty = Ty[:,:,:-1]
39 |     even, odd = Ty[:,:,::2], Ty[:,:,1::2]
40 |     diff = even - odd
41 |     print(f'strip_mirror_loss({i}):', torch.mean(torch.abs(diff)).item())
42 |     Ty = avgpool(Ty)
43 |
44 |
45 | even, odd = y[::2], y[1::2]
46 | mean = (even + odd) / 2
47 | diff = even - odd
48 | print('strip_mirror_loss:', np.mean(np.abs(diff)))
49 |
50 | if False:
51 |     plt.subplot(4, 1, 1) ; plt.plot(y)
52 |     plt.subplot(4, 1, 2) ; plt.plot(even)
53 |     plt.subplot(4, 1, 3) ; plt.plot(odd)
54 |     plt.subplot(4, 1, 4) ; plt.plot(diff)
55 |     plt.show()
56 |
57 | if False:
58 |     wavfile.write('y.wav',    sr,  y)
59 |     wavfile.write('even.wav', hsr, even)
60 |     wavfile.write('odd.wav',  hsr, odd)
61 |     wavfile.write('diff.wav', hsr, diff)
62 |
63 | De = L.stft(even, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
64 | Do = L.stft(odd,  n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
65 | Dd = L.stft(diff, n_fft=hp.n_fft,
hop_length=hp.hop_length, win_length=hp.win_length)
66 | Dm = L.stft(mean, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
67 | Pe, Se, logSe = decompose(De)
68 | Po, So, logSo = decompose(Do)
69 | Pd, Sd, logSd = decompose(Dd)
70 | Pm, Sm, logSm = decompose(Dm)
71 |
72 | print('|Pe - Po|:', np.mean(np.abs(Pe - Po)))
73 | print('|Se - So|:', np.mean(np.abs(Se - So)))
74 | print('|logSe - logSo|:', np.mean(np.abs(logSe - logSo)))
75 | print()
76 |
77 | print('|Pd - Pe|:', np.mean(np.abs(Pd - Pe)))
78 | print('|Sd - Se|:', np.mean(np.abs(Sd - Se)))
79 | print('|logSd - logSe|:', np.mean(np.abs(logSd - logSe)))
80 | print('|Pd - Po|:', np.mean(np.abs(Pd - Po)))
81 | print('|Sd - So|:', np.mean(np.abs(Sd - So)))
82 | print('|logSd - logSo|:', np.mean(np.abs(logSd - logSo)))
83 | print()
84 |
85 |
86 |
--------------------------------------------------------------------------------
/retunegan/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | from time import time
4 |
5 | import torch
6 | import matplotlib
7 | import hparam as hp
8 | if not hp.debug: matplotlib.use("Agg")
9 | import matplotlib.pylab as plt
10 |
11 | LRELU_SLOPE = 0.15
12 | PI = 3.14159265358979
13 |
14 |
15 | # plot
16 | def plot_spectrogram(spec):
17 |     fig, ax = plt.subplots(figsize=(10, 2))
18 |     im = ax.imshow(spec, aspect="auto", origin="lower", interpolation='none')
19 |     plt.colorbar(im, ax=ax)
20 |     fig.canvas.draw()
21 |     plt.close()
22 |     return fig
23 |
24 |
25 | # network
26 | def init_weights(m, mean=0.0, std=0.01):
27 |     if 'Conv' in m.__class__.__name__:
28 |         #m.weight.data.normal_(mean, std)
29 |         torch.nn.init.kaiming_normal_(m.weight.data, nonlinearity='leaky_relu', a=LRELU_SLOPE)
30 |
31 |
32 | def get_padding(kernel_size, dilation=1):
33 |     return (kernel_size * dilation - dilation) // 2
34 |
35 |
36 | def get_same_padding(kernel_size, dilation=1):
37 |     return dilation * (kernel_size // 2)
38 |
39 |
40 | def truncate_align(x, y):
41 |     d = x.shape[-1] - y.shape[-1]
42 |     if d != 0:
43 |         print('[truncate_align] x.shape:', x.shape, 'y.shape:', y.shape)
44 |         if   d > 0: x = x[:, :,    d //2  : -(  d  -    d //2)]
45 |         elif d < 0: y = y[:, :, (-d)//2 : -((-d) - (-d)//2)]
46 |     return x, y
47 |
48 |
49 | def get_param_cnt(model):
50 |     return sum(param.numel() for param in model.parameters())
51 |
52 |
53 | def stat_grad(model, name):
54 |     max_grad, min_grad = 0, 1e5
55 |     for p in model.parameters():
56 |         vabs = (p.grad if p.grad is not None else p).abs()    # gradients, not weights; fall back to weights before the first backward
57 |         vmin, vmax = vabs.min(), vabs.max()
58 |         if vmin < min_grad: min_grad = vmin
59 |         if vmax > max_grad: max_grad = vmax
60 |     print(f'grad_{name}: max = {max_grad.item()}, min={min_grad.item()}')
61 |
62 |
63 | # ckpt
64 | def load_checkpoint(fp, device):
65 |     assert os.path.isfile(fp)
66 |     print(f"Loading '{fp}'")
67 |     checkpoint_dict = torch.load(fp, map_location=device)
68 |     print("Complete.")
69 |     return checkpoint_dict
70 |
71 |
72 | def save_checkpoint(fp, obj):
73 |     print(f"Saving checkpoint to {fp}")
74 |     torch.save(obj, fp)
75 |     print("Complete.")
76 |
77 |
78 | def scan_checkpoint(dp, prefix):
79 |     pattern = os.path.join(dp, prefix + '*')
80 |     cp_list = glob.glob(pattern)
81 |     return len(cp_list) and sorted(cp_list)[-1] or None
82 |
83 |
84 | # decorator
85 | def timer(fn):
86 |     def wrapper(*args, **kwargs):
87 |         start = time()
88 |         r = fn(*args, **kwargs)
89 |         end = time()
90 |         print(f'[Timer]: {fn.__name__} took {end - start:.2f}')
91 |         return r
92 |     return wrapper
93 |
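# A hedged variant of scan_checkpoint (assumption: checkpoint names embed the
# step number, e.g. 'g_00090000'): the lexicographic sort above matches numeric
# order only while steps are zero-padded to equal width; sorting numerically is
# safer if the naming scheme ever changes.
def scan_checkpoint_numeric(dp, prefix):
    import re
    cps = glob.glob(os.path.join(dp, prefix + '*'))
    return max(cps, key=lambda fp: int(re.sub(r'\D', '', os.path.basename(fp)) or 0), default=None)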
--------------------------------------------------------------------------------
/stats/DataBaker-stats.txt:
--------------------------------------------------------------------------------
1 | { 'acode_len:avg': 312.1219,
2 |   'acode_len:max': 668,
3 |   'acode_len:min': 66,
4 |   'ap:max': 0.0,
5 |   'ap:min': -35.58398,
6 |   'dyn:max': 0.37510493,
7 |   'dyn:min': 0.0,
8 |   'f0:max': 7577.7227,
9 |   'f0:min': 0.0,
10 |   'mag:max': 2.3220108,
11 |   'mag:min': -8.0,
12 |   'mel:max': 0.3767597,
13 |   'mel:min': -8.0,
14 |   'n_frames': 3121219,
15 |   'n_hours': 40.26364646006551,
16 |   'n_utterances': 10000,
17 |   'num_note:max': 125,
18 |   'num_note:min': 37,
19 |   'pit:max': 11025.0,
20 |   'pit:min': 70.0,
21 |   'sp:max': 6.6273775,
22 |   'sp:min': -36.75616,
23 |   'text_len:avg': 16.2864,
24 |   'text_len:max': 34,
25 |   'text_len:min': 3}
26 |
27 |
28 | longest:  007537.wav
29 | shortest: 007245.wav
30 |
--------------------------------------------------------------------------------
/stats/DataBaker.stats:
--------------------------------------------------------------------------------
1 | total_examples 9559
2 | total_hours    9.463644041320231
3 | min_len_txt    6
4 | max_len_txt    27
5 | avg_len_txt    16.08818914112355
6 | min_len_wav    25856
7 | max_len_wav    134144
8 | avg_len_wav    78588.14352965791
9 | min_len_spec   101
10 | max_len_spec   524
11 | avg_len_spec   306.98493566272623
12 | max_mel        2.9994443721145663
13 | min_mel        -5.6
14 | max_mag        4.912159066528323
15 | min_mag        -5.6
16 | max_f0         595.9459228515625
17 | min_f0         73.25581359863281
18 | max_c0         0.3751049339771271
19 | min_c0         4.6309418394230306e-05
20 |
--------------------------------------------------------------------------------
/stats/DataBaker_gen_stat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2021/01/07
4 |
5 | import tgt
6 | import numpy as np
7 | import pandas as pd
8 | from pathlib import Path
9 | from collections import defaultdict
10 | from os import chdir as cd, getcwd as pwd, listdir
11 |
12 | BASE_PATH = Path(__file__).parent.absolute()
13 | WORKING_DIR = BASE_PATH / 'DataBaker.preproc' / 'TextGrid'
14 | STAT_OUT_FILE_FMT = BASE_PATH / 'DataBaker.stat-%s.csv'
15 |
16 | def collect_stat(by_name='phones'):
17 |     durdict = defaultdict(list)
18 |     for fn in listdir():
19 |         tg = tgt.read_textgrid(fn)
20 |         for ph in tg.get_tier_by_name(by_name).intervals:
21 |             durdict[ph.text].append(ph.duration())
22 |
23 |     stat = {k: (len(v), np.mean(v), np.std(v), np.min(v), np.max(v)) for k, v in durdict.items()}
24 |     df = pd.DataFrame(stat, index=['freq', 'mean', 'std', 'min', 'max']).T
25 |     df.to_csv(str(STAT_OUT_FILE_FMT) % by_name)
26 |
27 | if __name__ == '__main__':
28 |     savedp = pwd()
29 |     cd(WORKING_DIR)
30 |     for name in ['words', 'phones']:
31 |         collect_stat(name)
32 |     cd(savedp)
33 |
--------------------------------------------------------------------------------
/stats/DataBaker_print_pinyins.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2020/12/16
4 |
5 | # Collect the pinyin syllables appearing in DataBaker and print them as a list
6 |
7 | import re
8 | from math import ceil
9 | from pathlib import Path
10 | from os import chdir as cd, getcwd as pwd
11 |
12 | SPEC_DIR = 'DataBaker.spec'
13 | BASE_PATH = Path(__file__).parent.absolute() / SPEC_DIR
14 | INDEX_FILE = 'train.txt'
15 |
16 | def collect_symbols():
17 |     pinyins = set()
18 |     with open(INDEX_FILE) as fh:
19 |         samples = fh.read().split('\n')
20 |         for s in samples:
21 |             if len(s) == 0: continue
22 |             txt = s.split('|')[-1]
23 |             pinyins = pinyins.union(txt.split(' '))
24 |     return sorted(list(pinyins))
25 |
26 | def pprint_symbols(symbols):
27 |     n_sym = len(symbols)
28 |     SYM_PER_LINE = 15
29 |     n_line = ceil(n_sym / SYM_PER_LINE)
30 |
31 |     print('_pinyin = [', end='')
32 |     for idx, sym in enumerate(symbols):
33 |         col = idx % SYM_PER_LINE
34 |         if col == 0: print('\n  ', end='')
35 |         print(f"'{sym}', ", end='')
36 |     print(']')
37 |
38 | if __name__ == '__main__':
39 |     savedp = pwd()
40 |     cd(BASE_PATH)
41 |     pprint_symbols(collect_symbols())
42 |     cd(savedp)
43 |
--------------------------------------------------------------------------------
/stats/DataBaker_print_symbols.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2020/12/16
4 |
5 | # Collect the pinyin syllables appearing in DataBaker and print them as a list
6 |
7 | import re
8 | from math import ceil
9 | from pathlib import Path
10 | from os import chdir as cd, getcwd as pwd
11 |
12 | SPEC_DIR = 'DataBaker.preproc'
13 | BASE_PATH = Path(__file__).parent.absolute() / SPEC_DIR
14 | INDEX_FILES = ['train.txt', 'val.txt']
15 |
16 | def collect_symbols():
17 |     symbols = set()
18 |     for fn in INDEX_FILES:
19 |         with open(fn) as fh:
20 |             samples = fh.read().split('\n')
21 |             for s in samples:
22 |                 if len(s) == 0: continue
23 |                 _, txt = s.split('|')
24 |                 symbols = symbols.union(txt[1:-1].split(' '))
25 |     return sorted(list(symbols))
26 |
27 | def pprint_symbols(symbols):
28 |     n_sym = len(symbols)
29 |     SYM_PER_LINE = 15
30 |     n_line = ceil(n_sym / SYM_PER_LINE)
31 |
32 |     print('_pinyin = [', end='')
33 |     for idx, sym in enumerate(symbols):
34 |         col = idx % SYM_PER_LINE
35 |         if col == 0: print('\n  ', end='')
36 |         print(f"'{sym}', ", end='')
37 |     print(']')
38 |
39 | if __name__ == '__main__':
40 |     savedp = pwd()
41 |     cd(BASE_PATH)
42 |     pprint_symbols(collect_symbols())
43 |     cd(savedp)
44 |
--------------------------------------------------------------------------------
/stats/inspect_preproc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2021/03/04
4 |
5 | from sys import argv
6 | import numpy as np
7 | from random import randrange
8 | import matplotlib.pyplot as plt
9 | import seaborn as sns
10 |
11 | ID = (len(argv) >= 2 and argv[1] or str(randrange(10000) + 1)).rjust(6,'0')
12 |
13 | plt.subplot(311)
14 | plt.title('energy')
15 | e = np.load('DataBaker.preproc/energy/DataBaker-energy-%s.npy' % ID)
16 | plt.xlim(0, len(e))
17 | plt.plot(e)
18 |
19 | plt.subplot(312)
20 | plt.title('f0')
21 | f = np.load('DataBaker.preproc/f0/DataBaker-f0-%s.npy' % ID)
22 | plt.xlim(0, len(f))
23 | plt.plot(f)
24 |
25 | plt.subplot(313)
26 | plt.title('mel')
27 | m = np.load('DataBaker.preproc/mel/DataBaker-mel-%s.npy' % ID)
28 | sns.heatmap(m.T[::-1], cbar=False)
29 |
30 | plt.show()
31 |
--------------------------------------------------------------------------------
/stats/inspect_spec.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2021/03/23
4 |
5 | from sys import argv
6 | import numpy as np
7 | from random import randrange
8 | import matplotlib.pyplot as plt
9 | import seaborn as sns
10 |
11 | ID = (len(argv) >= 2 and argv[1] or str(randrange(20) + 1)).rjust(6,'0')
12 |
13 | plt.subplot(211)
14 | plt.title('linear spec')
15 | m = np.load('DataBaker.spec/databaker-spec-%s.npy' % ID)
16 | sns.heatmap(m.T[::-1], cbar=False)
17 |
18 | plt.subplot(212)
19 | plt.title('mel spec')
20 | m = np.load('DataBaker.spec/databaker-mel-%s.npy' % ID)
21 | sns.heatmap(m.T[::-1], cbar=False)
22 |
23 | plt.show()
24 |
--------------------------------------------------------------------------------
/stats/thchs30_gen_vbanks.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2020/11/14
4 |
5 | # Split thchs30 into voice banks by timbre, producing several train.txt index files
6 |
7 | import re
8 | from pathlib import Path
9 | from os import chdir as cd, getcwd as pwd
10 | from collections import defaultdict
11 |
12 | SPEC_DIR = 'thchs30.spec'
13 | INDEX_FILE = 'train.txt'
14 | BASE_PATH = Path(__file__).parent.absolute() / SPEC_DIR
15 | R = re.compile(r'-([ABCD]\d+)_')
16 |
17 | MALE_LIST         = [ 'A8', 'B8', 'C8', 'D8' ]
18 | FEMALE_POWER_LIST = [ 'A2', 'A4', 'A6', 'A14', 'A22', 'A34', 'B4', 'B6', 'B12', 'B22', 'B31', 'C4', 'C6', 'C31', 'D6', 'D31', 'D32' ]
19 | FEMALE_SOFT_LIST  = [ 'A7', 'A11', 'A19', 'B7', 'C7', 'C14', 'C17', 'C18', 'C20', 'C32', 'D7', 'D11' ]
20 | CHILD_LIST        = [ 'A13', 'B11', 'C12', 'C13', 'C19', 'C21', 'C22', 'D21' ]
21 |
22 | # => {'uid': ['sample_config1', ...]}
23 | def read_index(fn=INDEX_FILE) -> defaultdict:
24 |     index_dict = defaultdict(list)
25 |     with open(INDEX_FILE) as fh:
26 |         samples = fh.read().split('\n')
27 |         for s in samples:
28 |             if len(s) == 0: continue
29 |             spec_fn, mel_fn, _, txt = s.split('|')
30 |             id = R.findall(spec_fn)[0]
31 |             index_dict[id].append(s)
32 |     return index_dict
33 |
34 | def write_index(fn, vbank):
35 |     with open(fn, 'w') as fh:
36 |         for s in vbank:
37 |             fh.write(s)
38 |             fh.write('\n')
39 |
40 | # => ['sample_config1', ...]
41 | def gen_index(index, vt) -> list:
42 |     uid_list = globals().get(vt.upper() + '_LIST', [])
43 |     vbank = [ ]
44 |     for uid in uid_list:
45 |         vbank.extend(index[uid])
46 |     return vbank
47 |
48 | if __name__ == '__main__':
49 |     savedp = pwd()
50 |     cd(BASE_PATH)
51 |     index = read_index()
52 |     for vt in ['male', 'female_power', 'female_soft', 'child']:
53 |         vbank = gen_index(index, vt)
54 |         write_index('vbank_' + vt + '.txt', vbank)
55 |     cd(savedp)
56 |
--------------------------------------------------------------------------------
/stats/thchs30_print_symbols.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2020/11/23
4 |
5 | # Collect the pinyin syllables appearing in thchs30 and print them as a list
6 |
7 | import re
8 | from math import ceil
9 | from pathlib import Path
10 | from os import chdir as cd, getcwd as pwd
11 |
12 | SPEC_DIR = 'thchs30.spec'
13 | BASE_PATH = Path(__file__).parent.absolute() / SPEC_DIR
14 | INDEX_FILE = 'train.txt'
15 |
16 | def collect_symbols():
17 |     symbols = set()
18 |     with open(INDEX_FILE) as fh:
19 |         samples = fh.read().split('\n')
20 |         for s in samples:
21 |             if len(s) == 0: continue
22 |             _, _, _, txt = s.split('|')
23 |             symbols = symbols.union(txt.split(' '))
24 |     return sorted(list(symbols))
25 |
26 | def pprint_symbols(symbols):
27 |     n_sym = len(symbols)
28 |     SYM_PER_LINE = 15
29 |     n_line = ceil(n_sym / SYM_PER_LINE)
30 |
31 |     print('_pinyin = [', end='')
32 |     for idx, sym in enumerate(symbols):
33 |         col = idx % SYM_PER_LINE
34 |         if col == 0: print('\n  ', end='')
35 |         print(f"'{sym}', ", end='')
36 |     print(']')
37 |
38 |
39 | if __name__ == '__main__':
40 |     savedp = pwd()
41 |     cd(BASE_PATH)
42 |     pprint_symbols(collect_symbols())
43 |     cd(savedp)
44 |
--------------------------------------------------------------------------------
/transtacos/Makefile:
--------------------------------------------------------------------------------
1 | ifeq ($(shell uname -s), Linux)
2 | BASE_PATH=~/Data
3 | else
4 | BASE_PATH=D:/Desktop/Workspace/Data
5 | endif
6 |
7 | DATASET=DataBaker
8 | #LOG_NAME=tts-$(DATASET).$(VER)
9 | LOG_NAME=tts-$(DATASET)
10 | LOG_PATH=${BASE_PATH}/$(LOG_NAME)
11 |
12 |
13 | .PHONY: train preprocess server test_server stat clean
14 |
15 | train:
16 | 	python train.py \
17 | 	  --base_dir $(BASE_PATH) \
18 | 	  --input $(DATASET).tts_processed/train.txt \
19 | 	  --name $(LOG_NAME) \
20 | 	  --summary_interval 500 \
21 | 	  --checkpoint_interval 1000
22 |
23 | preprocess:
24 | 	python preprocess.py \
25 | 	  --base_dir $(BASE_PATH) \
26 | 	  --out_dir $(DATASET).tts_processed \
27 | 	  --dataset $(shell echo $(DATASET) | tr '[A-Z]' '[a-z]')
28 |
29 | server:
30 | 	python server.py \
31 | 	  --log_path $(LOG_PATH)
32 |
33 | test_server:
34 | 	python server.py \
35 | 	  --log_path $(LOG_PATH) \
36 | 	  --port 5103
37 |
38 | stat:
39 | 	tensorboard \
40 | 	  --logdir $(LOG_PATH) \
41 | 	  --port 5103
42 |
43 | clean:
44 | 	rm -rf $(LOG_PATH)
45 |
--------------------------------------------------------------------------------
/transtacos/audio.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/01/07
4 |
5 | import numpy as np
6 | import librosa as L
7 | from scipy import signal
8 | from scipy.io import wavfile
9 |
10 | import hparam as hp
11 |
12 |
13 | eps = 1e-5
14 | firwin = signal.firwin(hp.n_freq, [hp.fmin, hp.fmax], pass_zero=False, fs=hp.sample_rate)
15 | rf0min = L.note_to_hz(hp.rf0min) if isinstance(hp.rf0min, str) else float(hp.rf0min)
16 | rf0max = L.note_to_hz(hp.rf0max) if isinstance(hp.rf0max, str) else float(hp.rf0max)
17 | c0min = hp.c0min
18 | c0max = hp.c0max
19 | qt_f0min = int(np.floor(L.hz_to_midi(hp.f0min)))
20 | qt_f0max = int(np.ceil (L.hz_to_midi(hp.f0max)))
21 |
22 | hp.n_f0_min  = qt_f0min
23 | hp.n_f0_bins = qt_f0max - qt_f0min + 1
24 |
25 | print(f'c0: min={c0min} max={c0max} n_bins={hp.n_c0_bins}')
26 | print(f'qt_f0: min={qt_f0min} max={qt_f0max} n_bins={hp.n_f0_bins}')
27 |
28 |
29 | def load_wav(path):    # float values in range (-1,1)
30 |     y, _ = L.load(path, sr=hp.sample_rate, mono=True, res_type='kaiser_best')
31 |     return y.astype(np.float32)    # [T,]
32 |
33 | def save_wav(wav, path):
34 |     if hp.postprocess:
35 |         # rescaling for a unified measure across all clips
36 |         # NOTE: normalize amplification
37 |         wav = wav / np.abs(wav).max() * 0.999
38 |         # factor 0.5 in case of overflow for int16
39 |         f1 = 0.5 * 32767 / max(0.01, np.max(np.abs(wav)))
40 |         # sublinear scaling as Y ~ X ^ k (k < 1)
41 |         f2 = np.sign(wav) * np.power(np.abs(wav), 0.667)
42 |         wav = f1 * f2
43 |
44 |         # bandpass to reduce noise
45 |         wav = signal.convolve(wav, firwin)
46 |
47 |         wavfile.write(path, hp.sample_rate, wav.astype(np.int16))
48 |     else:
49 |         wavfile.write(path, hp.sample_rate, wav.astype(np.float32))
50 |
51 |
52 | def align_wav(wav, r=hp.hop_length):
53 |     d = len(wav) % r
54 |     if d != 0:
55 |         wav = np.pad(wav, (0, (r - d)))
56 |     return wav
57 |
58 |
59 | def trim_silence(wav, frame_length=512, hop_length=128):
60 |     # the dynamic range of human voice typically reaches ~55 dB
61 |     return L.effects.trim(wav, top_db=hp.trim_below_peak_db, frame_length=frame_length, hop_length=hop_length)[0]
62 |
63 |
64 | def preemphasis(x):    # boosts high frequencies; sounds slightly far-field
65 |     # x[i] = x[i] - k * x[i-1], where k is hp.preemphasis
66 |     return signal.lfilter([1, -hp.preemphasis], [1], x)
67 |
68 |
69 | def inv_preemphasis(x):    # undo preemphasis, should be put after Griffin-Lim
70 |     return signal.lfilter([1], [1, -hp.preemphasis], x)
71 |
72 |
73 | def get_specs(y):
74 |     D = np.abs(_stft(preemphasis(y)))
75 |     S = _amp_to_db(D) - hp.ref_level_db
76 |     M = _amp_to_db(_linear_to_mel(D)) - hp.ref_level_db
77 |     return (_normalize(S), _normalize(M))
78 |
79 |
80 | def spec_to_natural_scale(spec):
81 |     '''from inner normalized scale to raw scale of stft output'''
82 |     return _db_to_amp(_denormalize(spec) + hp.ref_level_db)
83 |
84 |
85 | def fix_zero_DC(S):
86 |     F, T = S.shape
87 |     if F == hp.n_freq - 1:    # NOTE: prepend zero DC component
88 |         S = np.concatenate([np.ones([1, T]) * S.min() * 1e-2, S], axis=0)
89 |         #S = np.concatenate([np.zeros([1, T]), S], axis=0)
90 |     return S
91 |
92 |
93 | def inv_spec(spec):
94 |     S = spec_to_natural_scale(spec)    # denorm
95 |     S = fix_zero_DC(S)
96 |     wav = inv_preemphasis(_griffin_lim(S ** hp.gl_power))    # reconstruct phase
97 |     return wav.astype(np.float32)
98 |
99 |
100 | def inv_mel(mel):    # This might have no use case
101 |     M = spec_to_natural_scale(mel)    # denorm
102 |     S = _mel_to_linear(M)             # back to linear
103 |     wav = inv_preemphasis(_griffin_lim(S ** hp.gl_power))    # reconstruct phase
104 |     return wav.astype(np.float32)
105 |
106 |
107 | def get_f0(y):
108 |     f0 = L.yin(y, fmin=rf0min, fmax=rf0max, frame_length=hp.win_length, hop_length=hp.hop_length)
109 |     return f0.astype(np.float32)    # [T,]
110 |
111 |
112 | def get_c0(y):
113 |     c0 = L.feature.rms(y=y, frame_length=hp.win_length, hop_length=hp.hop_length)[0]
114 |     return c0.astype(np.float32)    # [T,]
115 |
116 |
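# NOTE: L.hz_to_midi maps Hz to MIDI semitones, midi = 69 + 12*log2(f/440), so
# the quantization below yields perceptually even (semitone-wide) pitch bins;
# n_f0_min / n_f0_bins are derived from hp.f0min / hp.f0max in the header above.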
117 | def quantilize_f0(f0):
118 |     f0 = np.asarray([L.hz_to_midi(f) - hp.n_f0_min for f in f0])
119 |     f0 = f0.clip(0, hp.n_f0_bins - 1)
120 |     return f0.astype(np.int32)    # [T,]
121 |
122 |
123 | def quantilize_c0(c0):
124 |     c0 = (c0 - c0min) / (c0max - c0min)
125 |     c0 = c0 * hp.n_c0_bins
126 |     c0 = c0.clip(0, hp.n_c0_bins - 1)
127 |     return c0.astype(np.int32)    # [T,]
128 |
129 |
130 | def _griffin_lim(S):
131 |     '''librosa implementation of Griffin-Lim
132 |     Based on https://github.com/librosa/librosa/issues/434
133 |     '''
134 |     angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
135 |     S_complex = np.abs(S).astype(np.complex)
136 |     y = _istft(S_complex * angles)
137 |     for i in range(hp.gl_iters):
138 |         angles = np.exp(1j * np.angle(_stft(y)))
139 |         y = _istft(S_complex * angles)
140 |     return y
141 |
142 |
143 | def _stft(y):
144 |     return L.stft(y=y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
145 |
146 |
147 | def _istft(y):
148 |     return L.istft(y, hop_length=hp.hop_length, win_length=hp.win_length)
149 |
150 |
151 | _mel_basis = None
152 | _linear_basis = None
153 |
154 | def _linear_to_mel(spec):
155 |     return np.dot(_get_mel_basis(), spec)
156 |
157 | def _get_mel_basis():
158 |     global _mel_basis
159 |     if _mel_basis is None:
160 |         assert hp.fmax < hp.sample_rate // 2
161 |         _mel_basis = L.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.n_mel, fmin=hp.fmin, fmax=hp.fmax)
162 |     return _mel_basis
163 |
164 | def _mel_to_linear(mel):
165 |     return np.dot(_get_linear_basis(), mel)
166 |
167 | def _get_linear_basis():
168 |     global _linear_basis
169 |     if _linear_basis is None:
170 |         m = _get_mel_basis()
171 |         m_T = np.transpose(m)
172 |         p = np.matmul(m, m_T)
173 |         d = [1.0 / x if np.abs(x) > 1.0e-8 else x for x in np.sum(p, axis=0)]
174 |         _linear_basis = np.matmul(m_T, np.diag(d))
175 |     return _linear_basis
176 |
177 | def _amp_to_db(x):
178 |     # the audible sound pressure range is 2e-5~20Pa, i.e. a sound pressure level range of 0~120dB
179 |     # sound pressure level: SPL = 20 * log10(p_e/p_ref)
180 |     # where p_e is the sound pressure/amplitude and p_ref is the hearing-threshold reference pressure, usually 2e-5 in air
181 |     # hence SPL = 20 * log10(p_e/2e-5)
182 |     #           = 20 * (log10(p_e) - log10(2e-5))
183 |     #          ~= 20 * log10(p_e) + 94
184 |     return 20 * np.log10(np.maximum(1e-5, x))    # clip from below; optional, just makes heatmaps easier to read
185 |
186 | def _db_to_amp(x):
187 |     # =10^(x/20)
188 |     return np.power(10.0, x * 0.05)
189 |
190 | def _normalize(S):
191 |     # mapping [hp.min_level_db, 0] => [-hp.max_abs_value, hp.max_abs_value]
192 |     # typically: [-100, 0] => [-4, 4]
193 |     return 2 * hp.max_abs_value * ((S - hp.min_level_db) / -hp.min_level_db) - hp.max_abs_value
194 |
195 | def _denormalize(S):
196 |     return ((S + hp.max_abs_value) * -hp.min_level_db) / (2 * hp.max_abs_value) + hp.min_level_db
197 |
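# A worked example of the mapping above, assuming the typical values
# min_level_db = -100, max_abs_value = 4:
#   _normalize(-50)   = 2*4 * ((-50 + 100) / 100) - 4 =  0.0
#   _normalize(  0)   = 2*4 * ((  0 + 100) / 100) - 4 = +4.0
#   _denormalize(0.0) = ((0 + 4) * 100) / (2*4) - 100  = -50.0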
| 30 | # Load metadata & make cache: 31 | self._datadir = os.path.dirname(metadata_fp) 32 | with open(metadata_fp, encoding='utf-8') as f: 33 | self._metadata = [line.strip().split('|') for line in f] 34 | self.data = [None] * len(self._metadata) 35 | 36 | # Create placeholders for inputs and targets. Don't specify batch size because we want to 37 | # be able to feed different sized batches at eval time. 38 | if hp.g2p == 'seq': 39 | text_shape = [None, None] 40 | elif hp.g2p == 'syl4': 41 | text_shape = [None, None, 2] 42 | self._placeholders = [ 43 | tf.placeholder(tf.int32, [None], 'text_lengths'), 44 | tf.placeholder(tf.int32, text_shape, 'text'), 45 | tf.placeholder(tf.int32, [None, None], 'prds'), 46 | tf.placeholder(tf.int32, [None], 'spec_lengths'), 47 | tf.placeholder(tf.float32, [None, None, hparams.n_mel], 'mel_targets'), 48 | tf.placeholder(tf.float32, [None, None, hparams.n_freq-1], 'mag_targets'), 49 | tf.placeholder(tf.int32, [None, None], 'f0_targets'), 50 | tf.placeholder(tf.int32, [None, None], 'c0_targets'), 51 | tf.placeholder(tf.float32, [None, None], 'stop_token_targets') 52 | ] 53 | 54 | # Create queue for buffering data: 55 | queue = tf.FIFOQueue(hp.batch_size, [h.dtype for h in self._placeholders], name='input_queue') 56 | self._enqueue_op = queue.enqueue(self._placeholders) 57 | holders = queue.dequeue() 58 | for i, holder in enumerate(holders): holder.set_shape(self._placeholders[i].shape) 59 | (self.text_lengths, self.text, self.prds, self.spec_lengths, self.mel_targets, self.mag_targets, self.f0_targets, self.c0_targets, self.stop_token_targets) = holders 60 | 61 | def start_in_session(self, session): 62 | self._session = session 63 | self.start() 64 | 65 | def run(self): 66 | try: 67 | while not self._coord.should_stop(): 68 | self._enqueue_next_group() 69 | except Exception as e: 70 | traceback.print_exc() 71 | self._coord.request_stop(e) 72 | 73 | def _enqueue_next_group(self): 74 | def _get_next_example(): 75 | '''Loads a single example (input, mel_target, mag_target, stop_token_target, len(spec)) from memory cached''' 76 | 77 | if self._offset >= len(self.data): # infinit loop 78 | self._offset = 0 79 | random.shuffle(self.data) 80 | 81 | if self.data[self._offset] is None: 82 | self.load_data(self._offset) 83 | data = self.data[self._offset] 84 | self._offset += 1 85 | return data 86 | 87 | # Read a group of examples: 88 | n = self._hparams.batch_size 89 | r = self._hparams.outputs_per_step 90 | examples = [_get_next_example() for _ in range(n * _batches_per_group)] # group = batch_size * batches_per_group 91 | 92 | # Bucket examples based on similar output sequence length (spec n_frames) for efficiency 93 | # NOTE: 按照输出的帧长度排序,而不是输入的文本长度! 
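# (Translation of the note above: sort by *output* frame length, not by input
# text length. `x[-1]` is the stop_token_target, which has one entry per mel
# frame, so `len(x[-1])` equals the utterance's frame count; sorting on it
# buckets utterances of similar duration together and minimizes padding waste.)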
94 | examples.sort(key=lambda x: len(x[-1])) 95 | batches = [examples[i:i+n] for i in range(0, len(examples), n)] # split to batches 96 | random.shuffle(batches) 97 | 98 | for batch in batches: 99 | feed_dict = dict(zip(self._placeholders, _prepare_batch(batch, r))) 100 | self._session.run(self._enqueue_op, feed_dict=feed_dict) 101 | 102 | def load_data(self, index): 103 | '''Loads all examples [(input, mel_target, mag_target, stop_token_target, len(spec))] from disk''' 104 | 105 | meta = self._metadata[index] # meta: (id, len(spec), text) 106 | id, prds, text = meta 107 | if hp.g2p == 'seq': 108 | # NOTE: pad here, then convert to id_seq 109 | seq = phoneme_to_sequence(text_to_phoneme(text + _eos)) 110 | prds = [int(d) for d in prds] 111 | elif hp.g2p == 'syl4': 112 | C, V, T, Vx = text_to_phoneme(text) # [[str]] 113 | prds = [int(d) for d in prds] 114 | try: 115 | assert len(C) == len(prds) 116 | except: 117 | breakpoint() 118 | 119 | CVVx, Tx, P = [ ], [ ], [ ] 120 | n_syllable = len(C) 121 | for i in range(n_syllable): 122 | if C[i] != phonodict.vacant: 123 | CVVx.append(C[i]) ; Tx.append(T[i]) ; P.append(0) 124 | if V[i] != phonodict.vacant: 125 | CVVx.append(V[i]) ; Tx.append(T[i]) ; P.append(0) 126 | if Vx[i] != phonodict.vacant: 127 | CVVx.append(Vx[i]) ; Tx.append(T[i]) ; P.append(0) 128 | 129 | CVVx.append(_sep) ; Tx.append(0) ; P.append(prds[i]) 130 | 131 | # NOTE: pad here, then convert to id_seq 132 | CVVx = phoneme_to_sequence(CVVx + [_eos]) # see phone table 133 | Tx = [int(t) for t in Tx] + [0] # should be 0 ~ 5 134 | for i in range(len(P) - 2, -1, -1): 135 | if P[i] == 0: 136 | P[i] = P[i + 1] 137 | P = P + [5] # should be 0 ~ 5 138 | 139 | try: 140 | assert len(CVVx) == len(Tx) == len(P) 141 | assert 0 <= min(CVVx) and max(CVVx) < get_vocab_size() 142 | assert 0 <= min(P) and max(P) < hp.n_prds 143 | assert 0 <= min(Tx) and max(Tx) < hp.n_tone 144 | except: 145 | breakpoint() 146 | 147 | seq = np.stack([CVVx, Tx], axis=-1) # [T, 2] 148 | prds = P 149 | else: raise 150 | 151 | text = np.asarray(seq, dtype=np.int32) 152 | prds = np.asarray(prds, dtype=np.int32) 153 | mel_target = np.load(os.path.join(self._datadir, f'mel-{id}.npy')).T # [T, F] 154 | mag_target = np.load(os.path.join(self._datadir, f'mag-{id}.npy')).T # [T, M] 155 | f0_target = np.load(os.path.join(self._datadir, f'f0-{id}.npy')) 156 | c0_target = np.load(os.path.join(self._datadir, f'c0-{id}.npy')) 157 | stop_token_target = np.zeros(mel_target.shape[0]) # NOTE: 在有数据的mel帧上,初始化停止概率为0,另参见`_pad_stop_token_target()` 158 | 159 | mag_target = mag_target[:, 1:] # remove DC 160 | f0_target = quantilize_f0(f0_target) 161 | c0_target = quantilize_c0(c0_target) 162 | try: 163 | assert 0 <= min(f0_target) and max(f0_target) < hp.n_f0_bins 164 | assert 0 <= min(c0_target) and max(c0_target) < hp.n_c0_bins 165 | except: 166 | breakpoint() 167 | 168 | #breakpoint() 169 | 170 | self.data[index] = (text, prds, mel_target, mag_target, f0_target, c0_target, stop_token_target) 171 | 172 | def _prepare_batch(batch, outputs_per_step): # FIXME: what for `outputs_per_step` 173 | # batch of data, one line per sample (input/id_seq, mel, mag, stop_token, len(spec)) 174 | random.shuffle(batch) 175 | # pad within batch 176 | text_lengths = np.asarray([len(x[0]) for x in batch], dtype=np.int32) 177 | if hp.g2p == 'seq': text = _prepare_inputs ([x[0] for x in batch]) 178 | elif hp.g2p == 'syl4': text = _prepare_inputs_2d([x[0] for x in batch]) 179 | else: raise 180 | prds = _prepare_inputs([x[1] for x in batch]) 181 | spec_lengths = 
np.asarray([len(x[2]) for x in batch], dtype=np.int32) 182 | mel_targets = _prepare_targets([x[2] for x in batch], outputs_per_step) 183 | mag_targets = _prepare_targets([x[3] for x in batch], outputs_per_step) 184 | f0_targets = _prepare_stop_token_targets([x[4] for x in batch], outputs_per_step, 0) 185 | c0_targets = _prepare_stop_token_targets([x[5] for x in batch], outputs_per_step, 0) 186 | stop_token_targets = _prepare_stop_token_targets([x[6] for x in batch], outputs_per_step, 1.0) 187 | return (text_lengths, text, prds, spec_lengths, mel_targets, mag_targets, f0_targets, c0_targets, stop_token_targets) 188 | 189 | 190 | def _prepare_inputs(inputs): 191 | def _pad_input(x:list, length): # pad =0 for text 192 | return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) 193 | 194 | max_len = max((len(x) for x in inputs)) # 填充到该batch中最长的长度,seq在前面已经附加了 195 | return np.stack([_pad_input(x, max_len) for x in inputs]) 196 | 197 | 198 | def _prepare_inputs_2d(inputs): 199 | def _pad_input(x:list, length): # pad =0 for text 200 | return np.pad(x, [(0, length - x.shape[0]), (0, 0)], mode='constant', constant_values=_pad) 201 | 202 | max_len = max((len(x) for x in inputs)) # 填充到该batch中最长的长度,seq在前面已经附加了 203 | return np.stack([_pad_input(x, max_len) for x in inputs]) 204 | 205 | 206 | def _prepare_targets(targets, r): 207 | def _pad_target(x, length): # pad =0.0 for spec 208 | return np.pad(x, [(0, length - x.shape[0]), (0, 0)], mode='constant', constant_values=x.min()) 209 | 210 | max_len = max((len(t) for t in targets)) + 1 # +1 for 211 | max_len = _round_up(max_len, r) # 上取整到r的整数倍 212 | return np.stack([_pad_target(x, max_len) for x in targets]) 213 | 214 | 215 | def _prepare_stop_token_targets(targets, r, pad_val): 216 | def _pad_stop_token_target(x, length): 217 | return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=pad_val) # NOTE: 对于填充的尾帧,初始化停止概率为1 218 | 219 | max_len = max((len(t) for t in targets)) + 1 # +1 for 220 | max_len = _round_up(max_len, r) # 上取整到r的整数倍 221 | return np.stack([_pad_stop_token_target(x, max_len) for x in targets]) 222 | 223 | 224 | def _round_up(x, multiple): # 向上捨入到multiple的整數倍 225 | remainder = x % multiple 226 | return x if remainder == 0 else (x + multiple - remainder) 227 | -------------------------------------------------------------------------------- /transtacos/datasets/__skel__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Armit 3 | # Create Time: 2022/01/10 4 | 5 | # write your own dataset preprocesser 6 | # here's the template :) 7 | 8 | import os 9 | from typing import List, Tuple 10 | 11 | 12 | def preprocess(args) -> Tuple[List[Tuple], dict, str]: 13 | # wav_dp is a file path string, pointing to the folder containing *.wav files 14 | wav_dp = os.path.join(args.base_path, 'dataset', 'wavs') 15 | 16 | # metadata is a list containing textual informatin 17 | metadata = [ 18 | # for exmaple, name-text pairs 19 | ('00001', 'this is an exmaple'), 20 | ('00002', 'this is another exmaple'), 21 | ('00003', 'yet another exmaple'), 22 | ] 23 | 24 | # stats is a dictionary about statistcs 25 | stats = { 26 | 'min_len_txt': 18, 27 | 'max_len_txt': 23, 28 | 'avg_len_txt': 20.0, 29 | 'min_len_wav': 100, 30 | 'max_len_wav': 200, 31 | 'avg_len_wav': 150.0, 32 | } 33 | 34 | # return them all 35 | return metadata, stats, wav_dp 36 | -------------------------------------------------------------------------------- /transtacos/datasets/databaker.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Armit 3 | # Create Time: 2022/01/10 4 | 5 | import os 6 | from re import compile as Regex 7 | from collections import defaultdict 8 | from concurrent.futures import ProcessPoolExecutor 9 | from functools import partial 10 | from typing import Dict, List, Tuple 11 | 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | import hparam as hp 16 | import audio as A 17 | 18 | DROPOUT_2SIGMA = True 19 | 20 | 21 | # identically collected from `DataBaker` 22 | PUNCT_KANJI_REGEX = Regex(r',|。|、|:|;|?|!|(|)|“|”|…|—') 23 | 24 | 25 | def preprocess(args) -> Tuple[List[Tuple], dict]: 26 | wav_dp = os.path.join(args.base_dir, 'DataBaker', 'Wave') 27 | out_dp = os.path.join(args.base_dir, args.out_dir) 28 | os.makedirs(out_dp, exist_ok=True) 29 | label_dict = parse_label_file(os.path.join(args.base_dir, 'DataBaker', 'ProsodyLabeling', '000001-010000.txt')) 30 | 31 | executor = ProcessPoolExecutor(max_workers=args.num_workers) 32 | futures = [] 33 | for name, feats in label_dict.items(): 34 | wav_fp = os.path.join(wav_dp, f'{name}.wav') 35 | futures.append(executor.submit(partial(make_metadata, name, feats, wav_fp, out_dp))) 36 | # (name, prds, text, len_text, len_wav, len_spec, stats) 37 | metadata = [future.result() for future in tqdm(futures)] 38 | metadata = [mt for mt in metadata if mt is not None] 39 | 40 | # onely use sample within 2-sigma (95.45%) range of gauss distribution 41 | if DROPOUT_2SIGMA: 42 | tlens = np.asarray([mt[-4] for mt in metadata]) # mt[-4] := len_text 43 | tlens_mu, tlens_sigma = tlens.mean(), tlens.std() 44 | tlen_L = tlens_mu - 2 * tlens_sigma 45 | tlen_R = tlens_mu + 2 * tlens_sigma 46 | alens = np.asarray([mt[-2] for mt in metadata]) # mt[-2] := len_spec 47 | alens_mu, alens_sigma = alens.mean(), alens.std() 48 | alen_L = alens_mu - 2 * alens_sigma 49 | alen_R = alens_mu + 2 * alens_sigma 50 | 51 | metadata_filtered = [] 52 | for mt in metadata: 53 | if not tlen_L <= mt[-4] <= tlen_R: continue 54 | if not alen_L <= mt[-2] <= alen_R: continue 55 | metadata_filtered.append(mt) 56 | else: 57 | metadata_filtered = metadata 58 | 59 | len_text = np.asarray([mt[-4] for mt in metadata_filtered]) 60 | len_wav = np.asarray([mt[-3] for mt in metadata_filtered]) 61 | len_spec = np.asarray([mt[-2] for mt in metadata_filtered]) 62 | stat_dicts = np.asarray([mt[-1] for mt in metadata_filtered]) 63 | stats_agg = defaultdict(list) 64 | for stat in stat_dicts: 65 | for k, v in stat.items(): 66 | stats_agg[k].append(v) 67 | stats_agg = { k: np.asarray(v) for k, v in stats_agg.items() } 68 | 69 | stats = { 70 | 'total_examples': len(metadata_filtered), 71 | 'total_hours': len_wav.sum() / hp.sample_rate / (60 * 60), 72 | 'min_len_txt': len_text.min(), # n_pinyins 73 | 'max_len_txt': len_text.max(), 74 | 'avg_len_txt': len_text.mean(), 75 | 'min_len_wav': len_wav.min(), # n_samples 76 | 'max_len_wav': len_wav.max(), 77 | 'avg_len_wav': len_wav.mean(), 78 | 'min_len_spec': len_spec.min(), # n_frames 79 | 'max_len_spec': len_spec.max(), 80 | 'avg_len_spec': len_spec.mean(), 81 | } 82 | for k, v in stats_agg.items(): 83 | try: 84 | agg_fn = k[:k.find('_')] 85 | stats[k] = getattr(v, agg_fn)() 86 | except: 87 | print(f'unknown aggregate method for {k}') 88 | 89 | # (name, prds, text) 90 | metadata = [mt[:3] for mt in metadata_filtered] 91 | return metadata, stats, wav_dp 92 | 93 | 94 | def make_metadata(name, feats, wav_fp, out_dp): 95 | if not os.path.exists(wav_fp): 
return None 96 | text, prds = feats 97 | len_text = len(text.split(' ')) 98 | if len_text != len(prds): return None 99 | 100 | y = A.load_wav(wav_fp) 101 | y = A.trim_silence(y) 102 | y = A.align_wav(y) 103 | len_wav = len(y) 104 | 105 | y_cut = y[:-1] 106 | mag, mel = A.get_specs(y_cut) # [M, T], [F, T] 107 | f0 = A.get_f0 (y_cut) # [T,] 108 | c0 = A.get_c0 (y_cut) # [T,] 109 | len_spec = mel.shape[1] 110 | 111 | assert len_wav == len_spec * hp.hop_length 112 | 113 | np.save(os.path.join(out_dp, f'mel-{name}.npy'), mel, allow_pickle=False) 114 | np.save(os.path.join(out_dp, f'mag-{name}.npy'), mag, allow_pickle=False) 115 | np.save(os.path.join(out_dp, f'f0-{name}.npy'), f0, allow_pickle=False) 116 | np.save(os.path.join(out_dp, f'c0-{name}.npy'), c0, allow_pickle=False) 117 | 118 | stats = { 119 | 'max_mel': mel.max(), 'min_mel': mel.min(), 120 | 'max_mag': mag.max(), 'min_mag': mag.min(), 121 | 'max_f0' : f0 .max(), 'min_f0' : f0 .min(), 122 | 'max_c0' : c0 .max(), 'min_c0' : c0 .min(), 123 | } 124 | return (name, prds, text, len_text, len_wav, len_spec, stats) 125 | 126 | 127 | def parse_label_file(fp) -> Dict[str, Tuple[str, str]]: 128 | '''prosody rank per syllable: 129 | 0: word-internal 130 | 1: end of a prosodic word (read straight on) 131 | 2: end of a prosodic word with lengthening or a short pause 132 | 3: end of a clause 133 | 4: end of the sentence 134 | 5: trailing end marker 135 | prosodic hierarchy: 136 | syllables#0 -> words#1 -> phrases#2 -> clauses#3 -> full sentence#4 137 | ''' 138 | 139 | r = { } 140 | with open(fp, encoding='utf-8') as fh: 141 | while True: 142 | name_kanji = fh.readline().strip() 143 | if not name_kanji: break 144 | 145 | name, kanji = name_kanji.split('\t') # '002333', '这是个#1例子#2' 146 | pinyin = fh.readline().strip().lower() # 'zhe4 shi4 ge4 li4 zi5' 147 | kanji = PUNCT_KANJI_REGEX.sub('', kanji) 148 | 149 | prosody = [] 150 | for k in kanji: 151 | if k == '#': continue 152 | if k.isdigit(): 153 | if prosody: prosody[-1] = k 154 | else: prosody.append(k) 155 | else: prosody.append('0') 156 | prosody = ''.join(prosody) # '00102' 157 | 158 | r[name] = (pinyin, prosody) 159 | return r 160 | -------------------------------------------------------------------------------- /transtacos/datasets/thchs30.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | from concurrent.futures import ProcessPoolExecutor 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | import audio 9 | 10 | # NOTE: this module is broken, do not use it without modification!
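# (At least two contract mismatches: this `preprocess()` returns only a list
# of per-utterance tuples instead of the `(metadata, stats, wav_dp)` triple
# specified by `datasets/__skel__.py`, and its metadata rows carry no
# prosody-mark field, which the `syl4` front-end in `data.py` expects.)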
11 | 12 | 13 | def preprocess(args): 14 | in_dir = os.path.join(args.base_dir, 'thchs30') 15 | if not os.path.exists(in_dir): 16 | in_dir = os.path.join(args.base_dir, 'data_thchs30') 17 | out_dir = os.path.join(args.base_dir, args.out_dir) 18 | os.makedirs(out_dir, exist_ok=True) 19 | 20 | executor = ProcessPoolExecutor(max_workers=args.num_workers) 21 | futures = [] 22 | dp = os.path.join(in_dir, 'data') 23 | for fn in (fn for fn in os.listdir(dp) if fn.endswith('.wav')): 24 | wav_path = os.path.join(dp, fn) 25 | with open(wav_path + '.trn', encoding='utf8') as fh: 26 | fh.readline() # ignore first line (kanji) 27 | text = fh.readline().strip() # use pinyin only 28 | id = os.path.splitext(fn)[0] # '_' 29 | futures.append(executor.submit(partial(_process_utterance, out_dir, id, wav_path, text))) 30 | return [future.result() for future in tqdm(futures)] 31 | 32 | 33 | def _process_utterance(out_dir, id, wav_path, text): 34 | wav = audio.load_wav(wav_path) 35 | wav = audio.trim_silence(wav) 36 | 37 | mag, mel = audio.get_specs(wav) 38 | mag = mag.astype(np.float32) 39 | mel = mel.astype(np.float32) 40 | n_frames = mel.shape[1] 41 | 42 | spec_fn = 'thchs30-spec-%s.npy' % id 43 | np.save(os.path.join(out_dir, spec_fn), mag.T, allow_pickle=False) 44 | mel_fn = 'thchs30-mel-%s.npy' % id 45 | np.save(os.path.join(out_dir, mel_fn), mel.T, allow_pickle=False) 46 | 47 | return (spec_fn, mel_fn, n_frames, text) 48 | -------------------------------------------------------------------------------- /transtacos/hparam.py: -------------------------------------------------------------------------------- 1 | # Text 2 | g2p = 'syl4' # ['seq', 'syl4'] 3 | 4 | # Audio 5 | sample_rate = 22050 # sample rate (Hz) of wav file 6 | n_fft = 2048 7 | win_length = 1024 # :=n_fft//2 8 | hop_length = 256 # :=win_length//4, 11.6ms; one pinyin phone spans ~9 frames on average (min 2 ~ max 20) 9 | n_mel = 80 # number of mel bands (default: 160), 120 should be more reasonable 10 | n_freq = 1025 # number of linear spectrogram bins :=n_fft//2+1 11 | preemphasis = 0.97 # boost high frequencies to flatten the EQ 12 | ref_level_db = 20 # highest spectral magnitude considered (virtual 0dB); theoretically 94 in a quiet room, but the noisier the recording the smaller this value should be (default: 20) 13 | min_level_db = -100 # lowest spectral magnitude considered, used to clip/compress the dynamic range (default: -100) 14 | max_abs_value = 4 # normalize spectral magnitudes into [-max_abs_value, max_abs_value] 15 | trim_below_peak_db = 35 # trim beginning/ending silence parts (default: 60) 16 | fmin = 125 # lower/upper frequency bounds of the mel filterbank (set 55/3600 for male voices) 17 | fmax = 7600 18 | rf0min = 'D2' # lower/upper bounds for F0 detection 19 | rf0max = 'D5' 20 | 21 | ## see `Databaker stats` or `stats.txt` in preprocessed folder 22 | c0min = 4.6309418394230306e-05 23 | c0max = 0.3751049339771271 24 | f0min = 73.25581359863281 25 | f0max = 595.9459228515625 26 | n_tone = 5+1 27 | n_prds = 5+1 28 | n_c0_bins = 32 29 | n_f0_bins = None # keep None for auto detect using f0min & f0max 30 | n_f0_min = None # as offset 31 | maxlen_text = 128 # for pos_enc, 27 in train set 32 | maxlen_spec = 1024 # for pos_enc, 524 in train set 33 | 34 | # Model 35 | outputs_per_step = 5 # default: 5 (aka. reduction factor r); in effect a quantization of prosody: smaller r gives finer prosody, but the RNN may fail to track long sequences 36 | # a rule of thumb is to set r to the average length of a half-phone, which joins vowels roughly the way VCV units do 37 | hidden_gauss_std = 1e-5 38 | 39 | embed_depth = 256 # text/prds embed depth 40 | var_embed_depth = 64 # f0/c0 embed depth 41 | posenc_depth = 32 # pos_enc embed depth 42 | txt_use_posenc = True # oddly, the self-attention diagonal seems to emerge only *without* PE (?) 43 | var_use_posenc = True # PE is required for the diagonal to emerge in cross-attention 44 | prdsnet_depth = 64 45 | prdsnet_conv_k = 9 46 | embed_dropout = False 47 | 48 | encoder_depth = 256 # aka.
inner_repr_depth 49 | encoder_type = 'sa' # ['sa', 'cb'] 50 | if encoder_type == 'sa': # like FastSpeech2 51 | encoder_attn_layers = 2 # NOTE: set 4 will lead to nan in loss (grad vanish?) 52 | encoder_attn_nhead = 2 53 | encoder_dropout = False 54 | encoder_fusenet = True 55 | gffw_conv_k = 9 56 | var_prednet_depth = 64 57 | var_prednet_conv_k = 13 58 | if encoder_type == 'cb': # like Tacotron 59 | encoder_conv_K = 16 60 | highway_layers = 4 61 | 62 | decoder_layers = 2 # single layer is not enough 63 | decoder_depth = 512 # default: 1024 64 | attention_depth = 128 # single LSA 65 | prenet_depths = [256] # prenet for decoder RNN, single layer seems enough 66 | decoder_sew_layer = False 67 | 68 | n_mel_low = 42 69 | posnet_depth = 512 70 | posnet_ngroup = 8 71 | 72 | # Training 73 | max_steps = 320000 # force stop train 74 | max_ckpt = 1 75 | batch_size = 16 76 | adam_beta1 = 0.9 77 | adam_beta2 = 0.999 # 0.98 78 | adam_eps = 1e-7 79 | reg_weight = 1e-6 # 1e-8 80 | sim_weight = 1e-5 81 | initial_learning_rate = 0.001 82 | decay_learning_rate = True # decrease learning rate by step, see `models.tacotron._learning_rate_decay()` 83 | tf_method = 'mix' # ['random', 'mix', 'force'] 84 | tf_init = 1.0 85 | tf_start_decay = 20000 # default: 20000 86 | tf_decay = 200000 # default: 200000 87 | 88 | # Eval 89 | max_iters = 300 # max iter of RNN, 最多产生max_iters*r帧、防止无限生成 90 | gl_iters = 30 # griffin_lim algorithm iters (default: 60) 91 | gl_power = 1.2 # Power to raise magnitudes to prior to Griffin-Lim 92 | postprocess = False # see `audio.save_wav()` 93 | 94 | # MISC 95 | randseed = 114514 96 | debug = False 97 | -------------------------------------------------------------------------------- /transtacos/models/attention.py: -------------------------------------------------------------------------------- 1 | """Attention file for location based attention (compatible with tensorflow attention wrapper)""" 2 | 3 | import tensorflow as tf 4 | from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import BahdanauAttention 5 | from tensorflow.python.ops import array_ops, variable_scope 6 | 7 | 8 | def _location_sensitive_score(W_query, W_fil, W_keys): 9 | """Impelements Bahdanau-style (cumulative) scoring function. 10 | This attention is described in: 11 | J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Ben- 12 | gio, “Attention-based models for speech recognition,” in Ad- 13 | vances in Neural Information Processing Systems, 2015, pp. 14 | 577–585. 15 | 16 | ############################################################################# 17 | hybrid attention (content-based + location-based) 18 | f = F * α_{i-1} 19 | energy = dot(v_a, tanh(W_keys(h_enc) + W_query(h_dec) + W_fil(f) + b_a)) 20 | ############################################################################# 21 | 22 | Args: 23 | W_query: Tensor, shape '[batch_size, 1, attention_dim]' to compare to location features. 24 | W_location: processed previous alignments into location features, shape '[batch_size, max_time, attention_dim]' 25 | W_keys: Tensor, shape '[batch_size, max_time, attention_dim]', typically the encoder outputs. 
26 | Returns: 27 | A '[batch_size, max_time]' attention score (energy) 28 | """ 29 | # Get the number of hidden units from the trailing dimension of keys 30 | dtype = W_query.dtype 31 | num_units = W_keys.shape[-1].value or array_ops.shape(W_keys)[-1] 32 | 33 | v_a = tf.get_variable( 34 | 'attention_variable', shape=[num_units], dtype=dtype, 35 | initializer=tf.contrib.layers.xavier_initializer()) 36 | b_a = tf.get_variable( 37 | 'attention_bias', shape=[num_units], dtype=dtype, 38 | initializer=tf.zeros_initializer()) 39 | 40 | return tf.reduce_sum(v_a * tf.tanh(W_keys + W_query + W_fil + b_a), [2]) 41 | 42 | 43 | class LocationSensitiveAttention(BahdanauAttention): 44 | """Impelements Bahdanau-style (cumulative) scoring function. 45 | Usually referred to as "hybrid" attention (content-based + location-based) 46 | Extends the additive attention described in: 47 | "D. Bahdanau, K. Cho, and Y. Bengio, “Neural machine transla- 48 | tion by jointly learning to align and translate,” in Proceedings 49 | of ICLR, 2015." 50 | to use previous alignments as additional location features. 51 | 52 | This attention is described in: 53 | J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Ben- 54 | gio, “Attention-based models for speech recognition,” in Ad- 55 | vances in Neural Information Processing Systems, 2015, pp. 56 | 577–585. 57 | """ 58 | 59 | def __init__(self, 60 | num_units, 61 | memory, 62 | memory_sequence_length=None, 63 | cumulate_weights=True, 64 | name='LocationSensitiveAttention'): 65 | """Construct the Attention mechanism. 66 | Args: 67 | num_units: The depth of the query mechanism. 68 | memory: The memory to query; usually the output of an RNN encoder. This 69 | tensor should be shaped `[batch_size, max_time, ...]`. 70 | memory_sequence_length (optional): Sequence lengths for the batch entries 71 | in memory. If provided, the memory tensor rows are masked with zeros 72 | for values past the respective sequence lengths. Only relevant if mask_encoder = True. 73 | name: Name to use when creating ops. 74 | """ 75 | #Create normalization function 76 | #Setting it to None defaults in using softmax 77 | super(LocationSensitiveAttention, self).__init__( 78 | num_units=num_units, 79 | memory=memory, 80 | memory_sequence_length=memory_sequence_length, 81 | probability_fn=None, 82 | name=name) 83 | 84 | self.location_convolution = tf.layers.Conv1D(filters=32, 85 | kernel_size=(31, ), padding='same', use_bias=True, # original: 31 86 | bias_initializer=tf.zeros_initializer(), name='location_features_convolution') 87 | self.location_layer = tf.layers.Dense(units=num_units, use_bias=False, 88 | dtype=tf.float32, name='location_features_layer') 89 | self._cumulate = cumulate_weights 90 | 91 | def __call__(self, query, state): 92 | """Score the query based on the keys and values. 93 | Args: 94 | query: Tensor of dtype matching `self.values` and shape 95 | `[batch_size, query_depth]`. 96 | state (previous alignments): Tensor of dtype matching `self.values` and shape 97 | `[batch_size, alignments_size]` 98 | (`alignments_size` is memory's `max_time`). 99 | Returns: 100 | alignments: Tensor of dtype matching `self.values` and shape 101 | `[batch_size, alignments_size]` (`alignments_size` is memory's 102 | `max_time`). 
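next_state: the updated attention state — with `cumulate_weights=True` this
is `alignments + previous_alignments`, i.e. the cumulative alignment mass
that the location convolution consumes as location features on the next step.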
103 | """ 104 | previous_alignments = state 105 | with variable_scope.variable_scope(None, "Location_Sensitive_Attention", [query]): 106 | 107 | # processed_query shape [batch_size, query_depth] -> [batch_size, attention_dim] 108 | processed_query = self.query_layer(query) if self.query_layer else query 109 | # -> [batch_size, 1, attention_dim] 110 | processed_query = tf.expand_dims(processed_query, 1) 111 | 112 | # processed_location_features shape [batch_size, max_time, attention dimension] 113 | # [batch_size, max_time] -> [batch_size, max_time, 1] 114 | expanded_alignments = tf.expand_dims(previous_alignments, axis=2) 115 | # location features [batch_size, max_time, filters] 116 | f = self.location_convolution(expanded_alignments) 117 | # Projected location features [batch_size, max_time, attention_dim] 118 | processed_location_features = self.location_layer(f) 119 | 120 | # energy shape [batch_size, max_time] 121 | energy = _location_sensitive_score(processed_query, processed_location_features, self.keys) 122 | 123 | # alignments shape = energy shape = [batch_size, max_time] 124 | alignments = self._probability_fn(energy, previous_alignments) 125 | 126 | # Cumulate alignments 127 | if self._cumulate: 128 | next_state = alignments + previous_alignments 129 | else: 130 | next_state = alignments 131 | 132 | return alignments, next_state 133 | -------------------------------------------------------------------------------- /transtacos/models/custom_decoder.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.contrib.seq2seq import Helper, Decoder 6 | from tensorflow.python.framework import ops, tensor_shape 7 | from tensorflow.python.layers import base as layers_base 8 | from tensorflow.python.ops import rnn_cell_impl 9 | from tensorflow.python.util import nest 10 | 11 | import hparam as hp 12 | 13 | 14 | # Adapted from tf.contrib.seq2seq.GreedyEmbeddingHelper 15 | class TacoTestHelper(Helper): 16 | def __init__(self, batch_size, output_dim, r): 17 | with tf.name_scope('TacoTestHelper'): 18 | self._batch_size = batch_size 19 | self._output_dim = output_dim # output_dim==n_mels 20 | self._reduction_factor = r 21 | 22 | @property 23 | def batch_size(self): # IGNORED 24 | return self._batch_size 25 | 26 | @property 27 | def token_output_size(self): # IGNORED 28 | return self._reduction_factor # 每次吐r帧 29 | 30 | @property 31 | def sample_ids_shape(self): # IGNORED 32 | return tf.TensorShape([]) 33 | 34 | @property 35 | def sample_ids_dtype(self): # IGNORED 36 | return np.int32 37 | 38 | def sample(self, time, outputs, state, name=None): # IGNORED 39 | return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them 40 | 41 | def initialize(self, name=None): # init 1d vetor of False 42 | return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim)) # append 43 | 44 | def next_inputs(self, time, outputs, state, sample_ids, stop_token_preds, name=None): 45 | '''Stop on EOS. Otherwise, pass the last output as the next input and pass through state.''' 46 | with tf.name_scope('TacoTestHelper'): 47 | # A sequence is finished when the stop token probability is > 0.5 48 | # With enough training steps, the model should be able to predict when to stop correctly 49 | # and the use of stop_at_any = True would be recommended. 
If however the model didn't 50 | # learn to stop correctly yet, (stops too soon) one could choose to use the safer option 51 | # to get a correct synthesis 52 | # 难以通过判断产生了一个静音的mel帧来判定生成结束,所以用了一个与mel帧等长的stop_token向量 53 | # 一旦向量中(理论上应该靠近末端)出现了接近1.0的值即可认为生成结束,另参见`datafeeder._pad_stop_token_target()` 54 | # NOTE: 为了容错,不能只看stop_token_preds[-1] 55 | finished = tf.reduce_any(tf.cast(tf.round(stop_token_preds), tf.bool)) 56 | 57 | # Feed last output frame as next input. outputs is [N, output_dim * r] 58 | next_inputs = outputs[:, -self._output_dim:] # take last frame of a frame group 59 | return (finished, next_inputs, state) 60 | 61 | 62 | class TacoTrainingHelper(Helper): 63 | def __init__(self, batch_size, targets, output_dim, r, global_step): 64 | # inputs is [N, T_in], targets is [N, T_out, D] 65 | with tf.name_scope('TacoTrainingHelper'): 66 | self._batch_size = batch_size 67 | self._output_dim = output_dim # =n_mels 68 | self._reduction_factor = r 69 | self._ratio = None 70 | self.global_step = global_step 71 | 72 | # Feed every r-th target frame as input 73 | self._targets = targets[:, r-1::r, :] # 每个帧组的最后一帧, every r-th frame 74 | 75 | # Use full length for every target because we don't want to mask the padding frames 76 | num_steps = tf.shape(self._targets)[1] # =max_timesetps # FIXME: why cannot stop early? 77 | self._lengths = tf.tile([num_steps], [self._batch_size]) 78 | 79 | @property 80 | def batch_size(self): # IGNORED 81 | return self._batch_size 82 | 83 | @property 84 | def token_output_size(self): # IGNORED 85 | return self._reduction_factor 86 | 87 | @property 88 | def sample_ids_shape(self): # IGNORED 89 | return tf.TensorShape([]) 90 | 91 | @property 92 | def sample_ids_dtype(self): # IGNORED 93 | return np.int32 94 | 95 | def sample(self, time, outputs, state, name=None): # IGNORED 96 | return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them 97 | 98 | def initialize(self, name=None): 99 | self._ratio = _teacher_forcing_ratio_decay(hp.tf_init, self.global_step) 100 | return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim)) # append 101 | 102 | def next_inputs(self, time, outputs, state, sample_ids, stop_token_preds, name='TacoTrainingHelper'): 103 | with tf.name_scope(name): 104 | finished = (time + 1 >= self._lengths) # 训练时读到最后一帧mel就算结束,即这个batch的最大帧组长度 105 | 106 | if hp.tf_method == 'force': 107 | next_inputs = self._targets[:, time, :] 108 | elif hp.tf_method == 'random': 109 | next_inputs = tf.cond(tf.less(tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32), self._ratio), 110 | lambda: self._targets[:, time, :], 111 | lambda: outputs[:, -self._output_dim:]) 112 | elif hp.tf_method == 'mix': 113 | next_inputs = self._ratio * self._targets[:, time, :] + (1 - self._ratio) * outputs[:, -self._output_dim:] 114 | else: raise ValueError 115 | 116 | return (finished, next_inputs, state) 117 | 118 | 119 | def _go_frames(batch_size, output_dim): 120 | '''Returns all-zero frames for a given batch size and output dimension''' 121 | return tf.tile([[0.0]], [batch_size, output_dim]) 122 | 123 | 124 | def _teacher_forcing_ratio_decay(init_tfr, global_step): 125 | ################################################################# 126 | # Narrow Cosine Decay: 127 | 128 | # Phase 1: tfr = 1 129 | # We only start learning rate decay after 10k steps 130 | 131 | # Phase 2: tfr in [0, 1] 132 | # decay reach minimal value at step ~280k 133 | 134 | # Phase 3: tfr = 0 135 | # clip by minimal teacher forcing ratio value (step >~ 280k) 136 | 
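# With this repo's hparams (tf_start_decay=20000, tf_decay=200000) the cosine
# schedule below works out to:
#   tfr(s) = 1                                        for s <= 20k
#   tfr(s) = 0.5 * (1 + cos(pi * (s - 20k) / 200k))   for 20k < s < 220k
#   tfr(s) = 0                                        for s >= 220k
# (the "10k" / "~280k" figures above appear to stem from older defaults)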
################################################################# 137 | # Compute natural cosine decay 138 | tfr = tf.train.cosine_decay(init_tfr, 139 | global_step=global_step - hp.tf_start_decay, # tfr = 1 at step 10k, (original: 20000) 140 | decay_steps=hp.tf_decay, # tfr = 0 at step ~280k, (original: 200000) 141 | alpha=0., # tfr = 0% of init_tfr as final value 142 | name='tfr_cosine_decay') 143 | 144 | # force teacher forcing ratio to take initial value when global step < start decay step. 145 | # NOTE: narrow_tfr = global_step < 10000 ? init_tfr : tfr 146 | narrow_tfr = tf.cond( 147 | tf.less(global_step, tf.convert_to_tensor(hp.tf_start_decay)), # original: 20000 148 | lambda: tf.convert_to_tensor(init_tfr), 149 | lambda: tfr) 150 | 151 | return narrow_tfr 152 | 153 | 154 | class CustomDecoderOutput( 155 | namedtuple("CustomDecoderOutput", ("rnn_output", "token_output", "sample_id"))): 156 | pass 157 | 158 | 159 | class CustomDecoder(Decoder): 160 | """Custom sampling decoder. 161 | 162 | Allows for stop token prediction at inference time 163 | and returns equivalent loss in training time. 164 | 165 | Note: 166 | Only use this decoder with Tacotron 2 as it only accepts tacotron custom helpers 167 | """ 168 | 169 | def __init__(self, cell, helper, initial_state, output_layer=None): 170 | """Initialize CustomDecoder. 171 | Args: 172 | cell: An `RNNCell` instance. 173 | helper: A `Helper` instance. 174 | initial_state: A (possibly nested tuple of...) tensors and TensorArrays. 175 | The initial state of the RNNCell. 176 | output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., 177 | `tf.layers.Dense`. Optional layer to apply to the RNN output prior 178 | to storing the result or sampling. 179 | Raises: 180 | TypeError: if `cell`, `helper` or `output_layer` have an incorrect type. 181 | """ 182 | rnn_cell_impl.assert_like_rnncell(type(cell), cell) 183 | if not isinstance(helper, Helper): 184 | raise TypeError("helper must be a Helper, received: %s" % type(helper)) 185 | if (output_layer is not None and not isinstance(output_layer, layers_base.Layer)): 186 | raise TypeError("output_layer must be a Layer, received: %s" % type(output_layer)) 187 | self._cell = cell 188 | self._helper = helper 189 | self._initial_state = initial_state 190 | self._output_layer = output_layer 191 | 192 | @property 193 | def batch_size(self): 194 | return self._helper.batch_size 195 | 196 | def _rnn_output_size(self): 197 | size = self._cell.output_size 198 | if self._output_layer is None: 199 | return size 200 | else: 201 | # To use layer's compute_output_shape, we need to convert the 202 | # RNNCell's output_size entries into shapes with an unknown 203 | # batch size. We then pass this through the layer's 204 | # compute_output_shape and read off all but the first (batch) 205 | # dimensions to get the output size of the rnn with the layer 206 | # applied to the top. 
207 | output_shape_with_unknown_batch = nest.map_structure( 208 | lambda s: tensor_shape.TensorShape([None]).concatenate(s), 209 | size) 210 | layer_output_shape = self._output_layer._compute_output_shape( # pylint: disable=protected-access 211 | output_shape_with_unknown_batch) 212 | return nest.map_structure(lambda s: s[1:], layer_output_shape) 213 | 214 | @property 215 | def output_size(self): 216 | # Return the cell output and the id 217 | return CustomDecoderOutput( 218 | rnn_output=self._rnn_output_size(), 219 | token_output=self._helper.token_output_size, 220 | sample_id=self._helper.sample_ids_shape) 221 | 222 | @property 223 | def output_dtype(self): 224 | # Assume the dtype of the cell is the output_size structure 225 | # containing the input_state's first component's dtype. 226 | # Return that structure and the sample_ids_dtype from the helper. 227 | dtype = nest.flatten(self._initial_state)[0].dtype 228 | return CustomDecoderOutput( 229 | nest.map_structure(lambda _: dtype, self._rnn_output_size()), 230 | tf.float32, 231 | self._helper.sample_ids_dtype) 232 | 233 | def initialize(self, name=None): 234 | """Initialize the decoder. 235 | Args: 236 | name: Name scope for any created operations. 237 | Returns: 238 | `(finished, first_inputs, initial_state)`. 239 | """ 240 | return self._helper.initialize() + (self._initial_state,) 241 | 242 | def step(self, time, inputs, state, name=None): 243 | """Perform a custom decoding step. 244 | Enables for dyanmic prediction 245 | Args: 246 | time: scalar `int32` tensor. 247 | inputs: A (structure of) input tensors. 248 | state: A (structure of) state tensors and TensorArrays. 249 | name: Name scope for any created operations. 250 | Returns: 251 | `(outputs, next_state, next_inputs, finished)`. 252 | """ 253 | with ops.name_scope(name, "CustomDecoderStep", (time, inputs, state)): 254 | #Call outputprojection wrapper cell 255 | (cell_outputs, stop_token), cell_state = self._cell(inputs, state) 256 | 257 | #apply output_layer (if existant) 258 | if self._output_layer is not None: 259 | cell_outputs = self._output_layer(cell_outputs) 260 | sample_ids = self._helper.sample( 261 | time=time, outputs=cell_outputs, state=cell_state) 262 | 263 | (finished, next_inputs, next_state) = self._helper.next_inputs( 264 | time=time, 265 | outputs=cell_outputs, 266 | state=cell_state, 267 | sample_ids=sample_ids, 268 | stop_token_preds=stop_token) 269 | 270 | outputs = CustomDecoderOutput(cell_outputs, stop_token, sample_ids) 271 | return (outputs, next_state, next_inputs, finished) 272 | -------------------------------------------------------------------------------- /transtacos/models/modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.contrib.rnn import GRUCell 4 | from tensorflow.keras.backend import batch_dot 5 | 6 | import hparam as hp 7 | 8 | REUSE = False 9 | 10 | 11 | ''' below is for tacotron compactible ''' 12 | 13 | def prenet(inputs, layer_sizes, is_training, scope='prenet'): 14 | x = inputs 15 | drop_rate = 0.5 if is_training else 0.0 # NOTE: only dropout on trainning 16 | with tf.variable_scope(scope): 17 | # NOTE: chained i-layers of dense and dropout 18 | for i, size in enumerate(layer_sizes): 19 | dense = tf.layers.dense(x, units=size, activation=tf.nn.relu, name='dense_%d' % (i+1)) 20 | x = tf.layers.dropout(dense, rate=drop_rate, training=is_training, name='dropout_%d' % (i+1)) 21 | return x 22 | 23 | 24 | def conv1d(inputs, k, 
filters, activation, is_training, scope='conv1d'): 25 | with tf.variable_scope(scope): 26 | conv = tf.layers.conv1d( 27 | inputs, 28 | filters=filters, 29 | kernel_size=k, 30 | activation=None, 31 | padding='same') 32 | bn = tf.layers.batch_normalization(conv, training=is_training) 33 | return activation(bn) 34 | 35 | 36 | def highwaynet(inputs, depth, scope='highwaynet'): 37 | with tf.variable_scope(scope): 38 | H = tf.layers.dense( 39 | inputs, 40 | units=depth, 41 | activation=tf.nn.relu, 42 | name='H') 43 | T = tf.layers.dense( 44 | inputs, 45 | units=depth, 46 | activation=tf.nn.sigmoid, 47 | name='T', 48 | bias_initializer=tf.constant_initializer(-1.0)) 49 | return H * T + inputs * (1.0 - T) 50 | 51 | 52 | def cbhg(inputs, input_lengths, K, proj_dims, depth, is_training, scope='cbhg'): 53 | with tf.variable_scope(scope): 54 | with tf.variable_scope('conv_bank'): 55 | # Convolution bank: concatenate on the last axis to connect channels from all convolutions 56 | conv = tf.concat( 57 | [conv1d(inputs, k+1, depth//2, tf.nn.relu, is_training, 'conv1d_%d' % (k+1)) for k in range(K)], 58 | axis=-1) 59 | 60 | # Maxpooling: 61 | conv = tf.layers.max_pooling1d( 62 | conv, 63 | pool_size=2, 64 | strides=1, 65 | padding='same') 66 | 67 | # Two projection layers: reduce depth 68 | proj = conv1d(conv, 3, proj_dims[0], tf.nn.relu, is_training, 'proj_1') # depth: 896(7*128) -> 128 69 | proj = conv1d(proj, 3, proj_dims[1], lambda _:_, is_training, 'proj_2') # depth: 128 -> 256 70 | 71 | # Residual connection: 72 | # now we marge `acoustics (phoneme)` with `pardosy (text context)` 73 | highway_input = inputs + proj 74 | 75 | # Handle dimensionality mismatch: 76 | if highway_input.shape[-1] != depth: 77 | highway_input = tf.layers.dense(highway_input, depth) 78 | # 4-layer HighwayNet: 79 | for i in range(hp.highway_layers): 80 | highway_input = highwaynet(highway_input, depth, 'highway_%d' % (i+1)) 81 | 82 | # Bidirectional RNN 83 | outputs, states = tf.nn.bidirectional_dynamic_rnn( 84 | GRUCell(depth//2), 85 | GRUCell(depth//2), 86 | highway_input, 87 | sequence_length=input_lengths, 88 | dtype=tf.float32) 89 | 90 | return tf.concat(outputs, axis=-1) # Concat forward and backward 91 | 92 | 93 | ''' below is my stuff ''' 94 | 95 | def gaussian_noise(x, is_training): 96 | if hp.hidden_gauss_std: 97 | x = tf.keras.layers.GaussianNoise(hp.hidden_gauss_std)(x, training=is_training) 98 | return x 99 | 100 | 101 | def conv_stack(x, n_layers, k, d_in, d_out, activation=tf.nn.relu, scope='conv_stack'): 102 | with tf.variable_scope(scope, reuse=REUSE): 103 | for i in range(n_layers-1): 104 | x = tf.layers.conv1d(x, d_in, k, padding='same', name=f'conv{i+1}') 105 | x = activation(x) 106 | x = tf.layers.conv1d(x, d_out, k, padding='same', name=f'conv{n_layers}') 107 | return x 108 | 109 | 110 | def dot_attn(x, y, mask, attn_dim, scope='dot_attn'): 111 | with tf.variable_scope(scope, reuse=REUSE): 112 | # [B, N, A] 113 | q = tf.layers.dense(x, attn_dim, name='q') 114 | # [B, T, A] 115 | k = tf.layers.dense(y, attn_dim, name='k') 116 | v = tf.layers.dense(y, attn_dim, name='v') 117 | 118 | # [B, N, T] 119 | e = tf.matmul(q, k, transpose_b=True) 120 | e = e * mask + (1 - mask) * -1e8 # mask energe to inf 121 | e = e / tf.sqrt(tf.cast(hp.encoder_depth, tf.float32)) 122 | sc = tf.nn.softmax(e, axis=-1) 123 | 124 | # [B, N, A] 125 | r = tf.matmul(sc, v) 126 | 127 | return r, sc 128 | 129 | 130 | def GLU(inputs, depth, k=7, activation=None, scope='GLU'): 131 | with tf.variable_scope(scope): 132 | conv = 
tf.layers.conv1d( 133 | inputs, 134 | filters=depth*2, 135 | kernel_size=k, 136 | activation=activation, 137 | padding='same', 138 | name='conv') 139 | 140 | x, gate = tf.split(conv, 2, axis=-1) 141 | if activation: x = activation(x) 142 | gate = tf.nn.sigmoid(gate) 143 | 144 | return x * gate 145 | 146 | 147 | def gffw(x, depth, scope='gffw'): 148 | with tf.variable_scope(scope, reuse=REUSE): 149 | o = GLU(x, depth, k=hp.gffw_conv_k, activation=tf.nn.leaky_relu, scope='GLU') 150 | o = tf.layers.conv1d(o, depth, 1, padding='same', name='conv_pointwise') 151 | return o 152 | 153 | 154 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 155 | 156 | def cal_angle(position, hid_idx): 157 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 158 | 159 | def get_posi_angle_vec(position): 160 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 161 | 162 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 163 | 164 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 165 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 166 | 167 | # zero vector for padding dimension 168 | if padding_idx is not None: sinusoid_table[padding_idx] = 0. 169 | 170 | # [1, K, D] 171 | return tf.expand_dims(tf.convert_to_tensor(sinusoid_table, dtype=tf.float32), axis=0) 172 | 173 | 174 | def get_attn_mask(xlen, max_xlen, ylen=None, max_ylen=None): 175 | if ylen is None and max_ylen is None: ylen, max_ylen = xlen, max_xlen 176 | x_unary = tf.expand_dims(tf.sequence_mask(xlen, max_xlen, dtype=tf.float32), 1) 177 | y_unary = tf.expand_dims(tf.sequence_mask(ylen, max_ylen, dtype=tf.float32), 1) 178 | mask = batch_dot(tf.transpose(x_unary, [0, 2, 1]), y_unary) 179 | return mask 180 | 181 | 182 | def encoder_sa(x, x_len, f0, c0, y_len, is_training, scope='encoder'): 183 | depth = hp.encoder_depth 184 | 185 | with tf.variable_scope(scope, reuse=REUSE): 186 | # prenet 187 | if hp.txt_use_posenc: 188 | x = tf.layers.dense(x, depth, None, name='prenet') 189 | if hp.encoder_dropout: 190 | x = tf.layers.dropout(x, rate=0.2, training=is_training, name='dropout') 191 | 192 | '''multi-head self-attn: acoustics => global prosody''' 193 | slf_attns = [] 194 | max_xlen = tf.shape(x)[-2] # dynamic padded length N 195 | slf_mask = get_attn_mask(x_len, max_xlen) 196 | # D: 256 -> 64*4 -> 256 197 | for i in range(hp.encoder_attn_layers): 198 | # multi-head sa 199 | rs, attns = [], [] 200 | #x = tf.keras.layers.LayerNormalization()(x) # pre-norm 201 | for h in range(hp.encoder_attn_nhead): 202 | r, sc = dot_attn(x, x, slf_mask, depth // hp.encoder_attn_nhead, scope=f'sa_{i}_{h}') 203 | rs.append(r) ; attns.append(sc) 204 | slf_attns.append(attns) 205 | 206 | # combine multi-head 207 | sa = tf.layers.dense(tf.concat(rs, axis=-1), depth, name=f'proj_sa_{i}') 208 | if hp.encoder_dropout: 209 | sa = tf.layers.dropout(sa, rate=0.2, training=is_training, name='dropout') 210 | 211 | # transform (gffw) 212 | x = x + gffw(x + sa, depth, scope=f'gffw_sa_{i}') 213 | #x = tf.keras.layers.LayerNormalization()(x) # post-norm 214 | 215 | ''' fusenet ''' 216 | crx_attns = [] 217 | f0_r = c0_r = f0_r_pred = c0_r_pred = 0.0 218 | if hp.encoder_fusenet: 219 | f0_r_pred = conv_stack(x, 2, hp.var_prednet_conv_k, hp.var_prednet_depth, hp.var_prednet_depth, activation=tf.nn.leaky_relu, scope='ca_f0_prednet') 220 | c0_r_pred = conv_stack(x, 2, hp.var_prednet_conv_k, hp.var_prednet_depth, hp.var_prednet_depth, activation=tf.nn.leaky_relu, 
scope='ca_c0_prednet') 221 | if is_training: 222 | max_ylen = tf.shape(f0)[-2] # dynamic padded length T 223 | crx_mask = get_attn_mask(x_len, max_xlen, y_len, max_ylen) 224 | 225 | # [B, N, 256] cross attn [B, T, 64] 226 | f0_r, sc = dot_attn(x, f0, crx_mask, hp.var_prednet_depth, scope='ca_f0') 227 | crx_attns.append(sc) 228 | c0_r, sc = dot_attn(x, c0, crx_mask, hp.var_prednet_depth, scope='ca_c0') 229 | crx_attns.append(sc) 230 | 231 | # combine f0 & c0 232 | if is_training: f = tf.layers.dense(tf.concat([f0_r, c0_r], axis=-1), depth, name='proj_ca') 233 | else: f = tf.layers.dense(tf.concat([f0_r_pred, c0_r_pred], axis=-1), depth, name='proj_ca') 234 | if hp.encoder_dropout: 235 | f = tf.layers.dropout(f, rate=0.2, training=is_training, name='dropout') 236 | 237 | # combine & transform (gffw) 238 | x = x + gffw(tf.concat([x, f], axis=-1), depth, scope=f'gffw_ca') 239 | 240 | return x, (slf_attns, crx_attns), ((f0_r, f0_r_pred), (c0_r, c0_r_pred)) 241 | -------------------------------------------------------------------------------- /transtacos/models/rnn_wrappers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.contrib.rnn import RNNCell 6 | from tensorflow.python.framework import ops, tensor_shape 7 | from tensorflow.python.ops import array_ops, check_ops, rnn_cell_impl, tensor_array_ops 8 | from tensorflow.python.util import nest 9 | 10 | from .modules import prenet 11 | import hparam as hp 12 | 13 | 14 | class FrameProjection: 15 | """Projection layer to r * n_mel dimensions or n_mel dimensions""" 16 | def __init__(self, shape=hp.n_mel, activation=None, scope='linear_projection'): 17 | """ 18 | Args: 19 | shape: integer, dimensionality of output space (r*n_mels for decoder or n_mels for postnet) 20 | activation: callable, activation function 21 | scope: FrameProjection scope. 22 | """ 23 | super(FrameProjection, self).__init__() 24 | 25 | self.shape = shape 26 | self.activation = activation 27 | self.scope = scope 28 | self.dense = tf.layers.Dense(units=shape, activation=activation, name=f'projection_{self.scope}') 29 | 30 | def __call__(self, inputs): 31 | with tf.variable_scope(self.scope): 32 | # If activation==None, this returns a simple Linear projection 33 | # else the projection will be passed through an activation function 34 | return self.dense(inputs) 35 | 36 | 37 | class StopProjection: 38 | """Projection to a scalar and through a sigmoid activation""" 39 | def __init__(self, is_training, shape=1, activation=tf.nn.sigmoid, scope='stop_token_projection'): 40 | """ 41 | Args: 42 | is_training: Boolean, to control the use of sigmoid function as it is useless to use it 43 | during training since it is integrate inside the sigmoid_crossentropy loss 44 | shape: integer, dimensionality of output space. Defaults to 1 (scalar) 45 | activation: callable, activation function. only used during inference 46 | scope: StopProjection scope. 
47 | """ 48 | super(StopProjection, self).__init__() 49 | 50 | self.is_training = is_training 51 | self.shape = shape 52 | self.activation = activation 53 | self.scope = scope 54 | 55 | def __call__(self, inputs): 56 | with tf.variable_scope(self.scope): 57 | output = tf.layers.dense(inputs, units=self.shape, activation=None, name=f'projection_{self.scope}') 58 | #During training, don't use activation as it is integrated inside the sigmoid_cross_entropy loss function 59 | return output if self.is_training else self.activation(output) 60 | 61 | 62 | class TacotronDecoderCellState( 63 | namedtuple("TacotronDecoderCellState", 64 | ("cell_state", "attention", "time", "alignments", 65 | "alignment_history"))): 66 | """`namedtuple` storing the state of a `TacotronDecoderCell`. 67 | Contains: 68 | - `cell_state`: The state of the wrapped `RNNCell` at the previous time 69 | step. 70 | - `attention`: The attention emitted at the previous time step. 71 | - `time`: int32 scalar containing the current time step. 72 | - `alignments`: A single or tuple of `Tensor`(s) containing the alignments 73 | emitted at the previous time step for each attention mechanism. 74 | - `alignment_history`: a single or tuple of `TensorArray`(s) 75 | containing alignment matrices from all time steps for each attention 76 | mechanism. Call `stack()` on each to convert to a `Tensor`. 77 | """ 78 | def replace(self, **kwargs): 79 | """Clones the current state while overwriting components provided by kwargs. 80 | """ 81 | return super(TacotronDecoderCellState, self)._replace(**kwargs) 82 | 83 | 84 | class TacotronDecoderWrapper(RNNCell): 85 | """Tactron 2 Decoder Cell 86 | Decodes encoder output and previous mel frames into next r frames 87 | 88 | Decoder Step i: 89 | 1) Prenet to compress last output information 90 | 2) Concat compressed inputs with previous context vector (input feeding) * 91 | 3) Decoder RNN (actual decoding) to predict current state s_{i} * 92 | 4) Compute new context vector c_{i} based on s_{i} and a cumulative sum of previous alignments * 93 | 5) Predict new output y_{i} using s_{i} and c_{i} (concatenated) 94 | 6) Predict output ys_{i} using s_{i} and c_{i} (concatenated) 95 | 96 | * : This is typically taking a vanilla LSTM, wrapping it using tensorflow's attention wrapper, 97 | and wrap that with the prenet before doing an input feeding, and with the prediction layer 98 | that uses RNN states to project on output space. Actions marked with (*) can be replaced with 99 | tensorflow's attention wrapper call if it was using cumulative alignments instead of previous alignments only. 
100 | """ 101 | 102 | def __init__(self, rnn_cell, attention_mechanism, frame_projection, stop_projection, is_training): 103 | """Initialize decoder parameters 104 | 105 | Args: 106 | prenet: A tensorflow fully connected layer acting as the decoder pre-net 107 | attention_mechanism: A _BaseAttentionMechanism instance, usefull to 108 | learn encoder-decoder alignments 109 | rnn_cell: Instance of RNNCell, main body of the decoder 110 | frame_projection: tensorflow fully connected layer with r * n_mel output units 111 | stop_projection: tensorflow fully connected layer, expected to project to a scalar 112 | and through a sigmoid activation 113 | mask_finished: Boolean, Whether to mask decoder frames after the 114 | """ 115 | super(TacotronDecoderWrapper, self).__init__() 116 | #Initialize decoder layers 117 | self._training = is_training 118 | self._attention_mechanism = attention_mechanism 119 | self._cell = rnn_cell 120 | self._frame_projection = frame_projection 121 | self._stop_projection = stop_projection 122 | self._attention_layer_size = self._attention_mechanism.values.get_shape()[-1].value 123 | 124 | def _batch_size_checks(self, batch_size, error_message): 125 | return [check_ops.assert_equal(batch_size, 126 | self._attention_mechanism.batch_size, 127 | message=error_message)] 128 | 129 | @property 130 | def output_size(self): 131 | return self._frame_projection.shape 132 | 133 | # @property 134 | def state_size(self): 135 | """The `state_size` property of `TacotronDecoderWrapper`. 136 | 137 | Returns: 138 | An `TacotronDecoderWrapper` tuple containing shapes used by this object. 139 | """ 140 | return TacotronDecoderCellState( 141 | cell_state=self._cell._cell.state_size, 142 | time=tensor_shape.TensorShape([]), 143 | attention=self._attention_layer_size, 144 | alignments=self._attention_mechanism.alignments_size, 145 | alignment_history=()) 146 | 147 | def zero_state(self, batch_size, dtype): 148 | """Return an initial (zero) state tuple for this `AttentionWrapper`. 149 | 150 | Args: 151 | batch_size: `0D` integer tensor: the batch size. 152 | dtype: The internal state data type. 153 | Returns: 154 | An `TacotronDecoderCellState` tuple containing zeroed out tensors and, 155 | possibly, empty `TensorArray` objects. 156 | Raises: 157 | ValueError: (or, possibly at runtime, InvalidArgument), if 158 | `batch_size` does not match the output size of the encoder passed 159 | to the wrapper object at initialization time. 
160 | """ 161 | with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 162 | cell_state = self._cell.zero_state(batch_size, dtype) 163 | error_message = ( 164 | "When calling zero_state of TacotronDecoderCell %s: " % self._base_name + 165 | "Non-matching batch sizes between the memory " 166 | "(encoder output) and the requested batch size.") 167 | with ops.control_dependencies( 168 | self._batch_size_checks(batch_size, error_message)): 169 | cell_state = nest.map_structure( 170 | lambda s: array_ops.identity(s, name="checked_cell_state"), 171 | cell_state) 172 | return TacotronDecoderCellState( 173 | cell_state=cell_state, 174 | time=array_ops.zeros([], dtype=tf.int32), 175 | attention=rnn_cell_impl._zero_state_tensors(self._attention_layer_size, batch_size, dtype), 176 | alignments=self._attention_mechanism.initial_alignments(batch_size, dtype), 177 | alignment_history=tensor_array_ops.TensorArray(dtype=dtype, size=0, 178 | dynamic_size=True)) 179 | 180 | 181 | def __call__(self, inputs, state): 182 | #Information bottleneck (essential for learning attention) 183 | # just adjust n_mel to encoder_depth (?) 184 | prenet_output = prenet(inputs, hp.prenet_depths, self._training, scope='decoder_prenet') 185 | 186 | #Concat context vector and prenet output to form RNN cells input (input feeding) 187 | rnn_input = tf.concat([prenet_output, state.attention], axis=-1) 188 | 189 | #Unidirectional RNN layers 190 | rnn_output, next_cell_state = self._cell(tf.layers.dense(rnn_input, hp.decoder_depth), state.cell_state) 191 | 192 | #Compute the attention (context) vector and alignments using 193 | #the new decoder cell hidden state as query vector 194 | #and cumulative alignments to extract location features 195 | #The choice of the new cell hidden state (s_{i}) of the last 196 | #decoder RNN Cell is based on Luong et Al. (2015): 197 | #https://arxiv.org/pdf/1508.04025.pdf 198 | previous_alignments = state.alignments 199 | previous_alignment_history = state.alignment_history 200 | 201 | alignments, cumulated_alignments = self._attention_mechanism(rnn_output, state=previous_alignments) 202 | 203 | # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time] 204 | expanded_alignments = array_ops.expand_dims(alignments, 1) 205 | # Context is the inner product of alignments and values along the 206 | # memory time dimension. 207 | # alignments shape is 208 | # [batch_size, 1, memory_time] 209 | # attention_mechanism.values shape is 210 | # [batch_size, memory_time, memory_size] 211 | # the batched matmul is over memory_time, so the output shape is 212 | # [batch_size, 1, memory_size]. 213 | # we then squeeze out the singleton dim. 
214 | context = tf.matmul(expanded_alignments, self._attention_mechanism.values) 215 | context = tf.squeeze(context, [1]) # V 216 | 217 | #Concat RNN outputs and context vector to form projections inputs 218 | projections_input = tf.concat([rnn_output, context], axis=-1) 219 | 220 | #Compute predicted frames and predicted stop tokens 221 | cell_outputs = self._frame_projection(projections_input) 222 | stop_tokens = self._stop_projection(projections_input) 223 | 224 | #Save alignment history 225 | alignment_history = previous_alignment_history.write(state.time, alignments) 226 | 227 | #Prepare next decoder state 228 | next_state = TacotronDecoderCellState( 229 | time=state.time + 1, 230 | cell_state=next_cell_state, 231 | attention=context, 232 | alignments=cumulated_alignments, 233 | alignment_history=alignment_history) 234 | 235 | return (cell_outputs, stop_tokens), next_state 236 | -------------------------------------------------------------------------------- /transtacos/models/tacotron.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, ResidualWrapper 3 | 4 | from .modules import * 5 | from .custom_decoder import * 6 | from .rnn_wrappers import * 7 | from .attention import * 8 | from text.symbols import get_vocab_size 9 | from audio import inv_spec 10 | from utils import log 11 | 12 | 13 | class Tacotron(): 14 | 15 | def __init__(self, hparams): 16 | self._hparams = hparams 17 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 18 | 19 | 20 | def initialize(self, text_lengths, text, prds=None, 21 | spec_lengths=None, mel_targets=None, mag_targets=None, f0_targets=None, c0_targets=None, 22 | stop_token_targets=None): 23 | with tf.variable_scope('inference'): 24 | hp = self._hparams 25 | is_training = mel_targets is not None 26 | B = batch_size = tf.shape(text)[0] 27 | 28 | print('text_lengths.shape:', text_lengths.shape) 29 | print('text.shape:', text.shape) 30 | if is_training: 31 | print('prds.shape:', prds.shape) 32 | print('spec_lengths.shape:', spec_lengths.shape) 33 | print('mel_targets.shape:', mel_targets.shape) 34 | print('mag_targets.shape:', mag_targets.shape) 35 | print('f0_targets.shape:', f0_targets.shape) 36 | print('c0_targets.shape:', c0_targets.shape) 37 | print('stop_token_targets.shape:', stop_token_targets.shape) 38 | log(f'[Tacotron] vocab size {get_vocab_size()}') 39 | 40 | # Embeddings 41 | # zero-padding the embedding table seems to matter little 42 | zero_embedding_pad = tf.constant(0, shape=[1, hp.embed_depth], dtype=tf.float32, name='zero_embedding_pad') 43 | zero_embedding_pad_half = tf.constant(0, shape=[1, hp.encoder_depth//2], dtype=tf.float32, name='zero_embedding_pad_half') 44 | 45 | '''positional encoding embedding''' 46 | PE_table = get_sinusoid_encoding_table(max(hp.maxlen_text, hp.maxlen_spec), hp.posenc_depth) # unify the depths for now; fused via concat 47 | 48 | '''linguistic feature embeddings''' 49 | if hp.g2p == 'seq': 50 | E_text = tf.get_variable('E_text', [get_vocab_size(), hp.embed_depth], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.5)) 51 | 52 | # seq 53 | text_embd = tf.nn.embedding_lookup(E_text, text) 54 | embd_out = text_embd 55 | 56 | elif hp.g2p == 'syl4': 57 | E_text = tf.get_variable('E_text', [get_vocab_size(), hp.embed_depth], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.5)) 58 | E_tone = tf.get_variable('E_tone', [hp.n_tone, hp.embed_depth], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.5)) 59 | E_prds = 
tf.get_variable('E_prds', [hp.n_prds, hp.embed_depth], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.5)) 60 | 61 | # syl4 62 | CVVx, T = [tf.squeeze(p, axis=-1) for p in tf.split(text, 2, axis=-1)] # [B, T, 2] => 2 * [B, T] 63 | phone_embd = tf.nn.embedding_lookup(E_text, CVVx) 64 | tone_embd = tf.nn.embedding_lookup(E_tone, T) 65 | text_embd = phone_embd + tone_embd 66 | 67 | # prds 68 | prds_prob = conv_stack(text_embd, 3, hp.prdsnet_conv_k, hp.prdsnet_depth, hp.n_prds, activation=tf.nn.relu, scope='prdsnet') 69 | prds_out = tf.argmax(prds_prob, axis=-1) 70 | if is_training: prds_embd = tf.nn.embedding_lookup(E_prds, prds) 71 | else: prds_embd = tf.nn.embedding_lookup(E_prds, prds_out) 72 | 73 | embd_out = text_embd + prds_embd 74 | 75 | if hp.embed_dropout: 76 | embd_out = tf.layers.dropout(embd_out, rate=0.2, training=is_training, name='dropout_N') 77 | if is_training: 78 | embd_out = gaussian_noise(embd_out, is_training) 79 | 80 | if hp.encoder_type == 'sa': 81 | if hp.txt_use_posenc: 82 | N_pos_embd_out = tf.tile(PE_table[:, :tf.shape(embd_out)[1], :], (B, 1, 1)) 83 | embd_out = tf.concat([embd_out, N_pos_embd_out], axis=-1) 84 | 85 | if is_training: 86 | '''discretized acoustic feature embeddings''' 87 | E_f0 = tf.get_variable('E_f0', [hp.n_f0_bins, hp.var_embed_depth], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.5)) 88 | E_c0 = tf.get_variable('E_c0', [hp.n_c0_bins, hp.var_embed_depth], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.5)) 89 | 90 | f0_embd = tf.nn.embedding_lookup(E_f0, f0_targets) 91 | c0_embd = tf.nn.embedding_lookup(E_c0, c0_targets) 92 | 93 | if hp.embed_dropout: 94 | f0_embd = tf.layers.dropout(f0_embd, rate=0.2, training=is_training, name='dropout_T') 95 | c0_embd = tf.layers.dropout(c0_embd, rate=0.2, training=is_training, name='dropout_T') 96 | if is_training: 97 | f0_embd = gaussian_noise(f0_embd, is_training) 98 | c0_embd = gaussian_noise(c0_embd, is_training) 99 | 100 | if hp.var_use_posenc: 101 | T_pos_embd_out = tf.tile(PE_table[:, :tf.shape(f0_targets)[-1], :], (B, 1, 1)) 102 | f0_embd = tf.concat([f0_embd, T_pos_embd_out], axis=-1) 103 | c0_embd = tf.concat([c0_embd, T_pos_embd_out], axis=-1) 104 | else: 105 | f0_embd = c0_embd = None 106 | 107 | # Encoder 108 | if hp.encoder_type == 'sa': 109 | encoder_out, (slf_attn, crx_attn), ((f0_r, f0_r_pred), (c0_r, c0_r_pred)) = encoder_sa(embd_out, text_lengths, f0_embd, c0_embd, spec_lengths, is_training) 110 | elif hp.encoder_type == 'cb': 111 | encoder_out = cbhg(embd_out, text_lengths, hp.encoder_conv_K, [hp.encoder_depth//2, hp.encoder_depth], hp.encoder_depth, is_training) 112 | else: raise ValueError(f'unknown encoder_type: {hp.encoder_type}') 113 | if is_training: encoder_out = gaussian_noise(encoder_out, is_training) 114 | 115 | # Decoder (layers specified bottom to top): # output-length control is recast as predicting the number of RNN iterations, see `stop_projection` 116 | multi_rnn_cell = MultiRNNCell([ # feed the inner repr through the RNN stack to produce rnn_output 117 | ResidualWrapper(GRUCell(hp.decoder_depth)) 118 | for _ in range(hp.decoder_layers) 119 | ], state_is_tuple=True) # [N, T_in, decoder_depth=256] 120 | attention_mechanism = LocationSensitiveAttention(hp.attention_depth, encoder_out, text_lengths) # [N, T_in, attn_depth=128] 121 | frame_projection = FrameProjection(hp.n_mel * hp.outputs_per_step) # [N, T_out/r, M*r], project concat([rnn_output, attn_context]) to a length r*n_mel vector, later reshaped into r frames 122 | stop_projection = StopProjection(is_training, shape=hp.outputs_per_step) # [N, T_out/r, r], project to r scalars; generation ends once any of them exceeds 0.5 123 | decoder_cell = TacotronDecoderWrapper(multi_rnn_cell, 
attention_mechanism, frame_projection, stop_projection, is_training) 124 | if is_training: helper = TacoTrainingHelper(batch_size, mel_targets, hp.n_mel, hp.outputs_per_step, self.global_step) 125 | else: helper = TacoTestHelper(batch_size, hp.n_mel, hp.outputs_per_step) 126 | decoder_init_state = decoder_cell.zero_state(batch_size=batch_size, dtype=tf.float32) 127 | (decoder_out, stop_token_out, _), final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode( 128 | CustomDecoder(decoder_cell, helper, decoder_init_state), # [N, T_out/r, M*r] 129 | impute_finished=True, maximum_iterations=hp.max_iters) 130 | 131 | # Reshape outputs to be one output per entry 132 | mel_out = tf.reshape(decoder_out, [batch_size, -1, hp.n_mel]) # [N, T_out, M], mel only joins the loss; it is not used to produce the final wav 133 | stop_token_out = tf.reshape(stop_token_out, [batch_size, -1]) # [N, T_out], unused downstream; only consulted during decoding 134 | alignments = tf.transpose(final_decoder_state.alignment_history.stack(), [1, 2, 0]) 135 | 136 | # k = 7 should cover 1.5 mel-groups (reduce_factor) 137 | if hp.decoder_sew_layer: 138 | mel_out += tf.layers.conv1d(mel_out, hp.n_mel, 7, padding='same', name='sew_up_layer') 139 | 140 | # Posnet 141 | x = mel_out[:,:,:hp.n_mel_low] 142 | x = tf.layers.dense(x, hp.posnet_depth//4, name='posnet1') 143 | x = tf.nn.leaky_relu(x) 144 | x = tf.layers.dense(x, hp.posnet_depth//2, name='posnet2') 145 | x = tf.nn.leaky_relu(x) 146 | x = tf.layers.dense(x, hp.posnet_depth, name='posnet3') 147 | x = tf.nn.leaky_relu(x) 148 | mag_out = tf.concat([tf.layers.dense(s, (hp.n_freq-1)//hp.posnet_ngroup, name=f'posnet4_{i}') 149 | for i, s in enumerate(tf.split(x, hp.posnet_ngroup, axis=-1))], axis=-1) 150 | 151 | # data in 152 | self.text_lengths = text_lengths 153 | self.text = text 154 | self.prds = prds 155 | self.spec_lengths = spec_lengths 156 | self.mel_targets = mel_targets 157 | self.mag_targets = mag_targets 158 | self.stop_token_targets = stop_token_targets 159 | # data out 160 | self.prds_prob = prds_prob 161 | self.prds_out = prds_out 162 | self.mel_outputs = mel_out 163 | self.mag_outputs = mag_out 164 | self.stop_token_outputs = stop_token_out 165 | self.alignments = alignments 166 | # misc 167 | if is_training: 168 | # NOTE: must get `._ratio` after `TacoTrainingHelper.initialize()` 169 | self.tfr = helper._ratio 170 | if hp.encoder_type == 'sa': 171 | self.slf_attn = slf_attn 172 | self.crx_attn = crx_attn 173 | self.f0_r = f0_r 174 | self.f0_r_pred = f0_r_pred 175 | self.c0_r = c0_r 176 | self.c0_r_pred = c0_r_pred 177 | 178 | def get_cosine_sim(x): 179 | dot = tf.matmul(x, x, transpose_b=True) 180 | n = tf.norm(x, axis=-1, keepdims=True) 181 | norm = tf.matmul(n, n, transpose_b=True) 182 | sim = dot / (norm + 1e-8) 183 | return sim 184 | 185 | self.E_text = E_text 186 | self.E_text_sim = get_cosine_sim(E_text) 187 | if hp.g2p == 'syl4': 188 | self.E_tone = E_tone 189 | self.E_tone_sim = get_cosine_sim(E_tone) 190 | self.E_prds = E_prds 191 | self.E_prds_sim = get_cosine_sim(E_prds) 192 | 193 | log('Initialized TransTacoS Model: ') 194 | log(f' embd out: {embd_out.shape}') 195 | if hp.g2p == 'syl4': 196 | log(f' syl4 embd: {text_embd.shape}') 197 | log(f' tone embd: {tone_embd.shape}') 198 | log(f' prds embd: {prds_embd.shape}') 199 | if hp.encoder_type == 'sa' and is_training: 200 | log(f' f0 embd: {f0_embd.shape}') 201 | log(f' c0 embd: {c0_embd.shape}') 202 | log(f' encoder out: {encoder_out.shape}') 203 | log(f' decoder out (r frames): {decoder_out.shape}') 204 | log(f' mel out (1 frame): {mel_out.shape}') 205 | log(f' 
stoptoken out: {stop_token_out.shape}') 206 | log(f' mag out: {mag_out.shape}') 207 | log(f' E_text: {E_text.shape}') 208 | if hp.g2p == 'syl4': 209 | log(f' E_tone: {E_tone.shape}') 210 | log(f' E_prds: {E_prds.shape}') 211 | 212 | 213 | def add_loss(self): 214 | '''Adds loss to the model. Sets "loss" field. initialize() must have been called.''' 215 | 216 | hp = self._hparams 217 | with tf.variable_scope('loss'): 218 | self.mel_loss = tf.reduce_mean(tf.abs(self.mel_targets - self.mel_outputs)) 219 | self.mag_loss = tf.reduce_mean(tf.abs(self.mag_targets - self.mag_outputs)) 220 | if hp.encoder_type == 'sa' and hp.encoder_fusenet: 221 | self.f0_loss = tf.reduce_mean(tf.square(self.f0_r - self.f0_r_pred)) 222 | self.c0_loss = tf.reduce_mean(tf.square(self.c0_r - self.c0_r_pred)) 223 | else: 224 | self.f0_loss = 0.0 225 | self.c0_loss = 0.0 226 | if hp.g2p == 'syl4': 227 | self.prds_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.prds, logits=self.prds_prob)) 228 | else: 229 | self.prds_loss = 0.0 230 | if hp.g2p == 'seq': 231 | self.sim_loss = tf.reduce_mean(tf.abs((1.0 - tf.eye(get_vocab_size())) * self.E_text_sim)) * hp.sim_weight 232 | else: 233 | self.sim_loss = tf.add_n([tf.reduce_mean(tf.abs((1.0 - tf.eye(get_vocab_size())) * self.E_text_sim)), 234 | tf.reduce_mean(tf.abs((1.0 - tf.eye(hp.n_prds)) * self.E_prds_sim))]) * hp.sim_weight 235 | self.stop_token_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=self.stop_token_targets, logits=self.stop_token_outputs)) 236 | self.reg_loss = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()]) * hp.reg_weight 237 | 238 | self.loss = (self.prds_loss + 239 | self.mel_loss + 240 | self.mag_loss + 241 | self.f0_loss + 242 | self.c0_loss + 243 | self.sim_loss + 244 | self.stop_token_loss + 245 | self.reg_loss) 246 | 247 | def add_optimizer(self): 248 | '''Adds optimizer. Sets "gradients" and "optimize" fields. add_loss() must have been called.''' 249 | 250 | with tf.variable_scope('optimizer'): 251 | hp = self._hparams 252 | 253 | if hp.decay_learning_rate: 254 | self.learning_rate = _learning_rate_decay(hp.initial_learning_rate, self.global_step) 255 | else: 256 | self.learning_rate = tf.convert_to_tensor(hp.initial_learning_rate) 257 | optimizer = tf.train.AdamOptimizer(self.learning_rate, hp.adam_beta1, hp.adam_beta2, hp.adam_eps) 258 | gradients, variables = zip(*optimizer.compute_gradients(self.loss)) 259 | self.gradients = tuple([g for g in gradients if g is not None]) # FIXME: some variables receive no gradient (likely branches unused under the current config); keep only non-None entries for the stats 260 | clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0) # guard against gradient explosion 261 | 262 | # Add dependency on UPDATE_OPS; otherwise batchnorm won't work correctly. 
See: 263 | # https://github.com/tensorflow/tensorflow/issues/1122 264 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 265 | self.optimize = optimizer.apply_gradients(zip(clipped_gradients, variables), 266 | global_step=self.global_step) 267 | 268 | 269 | def add_stats(self): 270 | with tf.variable_scope('stats'): 271 | hp = self._hparams 272 | 273 | tf.summary.histogram('mel_outputs', self.mel_outputs) 274 | tf.summary.histogram('mel_targets', self.mel_targets) 275 | tf.summary.histogram('mag_outputs', self.mag_outputs) 276 | tf.summary.histogram('mag_targets', self.mag_targets) 277 | 278 | tf.summary.scalar('learning_rate', self.learning_rate) 279 | tf.summary.scalar('loss', self.loss) 280 | tf.summary.scalar('tfr', self.tfr) 281 | tf.summary.scalar('mel_loss', self.mel_loss) 282 | tf.summary.scalar('mag_loss', self.mag_loss) 283 | if hp.g2p == 'syl4': 284 | tf.summary.scalar('prds_loss', self.prds_loss) 285 | if hp.encoder_type == 'sa': 286 | tf.summary.scalar('f0_loss', self.f0_loss) 287 | tf.summary.scalar('c0_loss', self.c0_loss) 288 | tf.summary.scalar('sim_loss', self.sim_loss) 289 | tf.summary.scalar('stop_token_loss', self.stop_token_loss) 290 | tf.summary.scalar('reg_loss', self.reg_loss) 291 | 292 | gradient_norms = [tf.norm(grad) for grad in self.gradients] 293 | tf.summary.histogram('gradient_norm', gradient_norms) 294 | tf.summary.scalar('max_gradient_norm', tf.reduce_max(gradient_norms)) 295 | 296 | raw = tf.numpy_function(inv_spec, [tf.transpose(self.mag_targets[0])], tf.float32) 297 | gen = tf.numpy_function(inv_spec, [tf.transpose(self.mag_outputs[0])], tf.float32) 298 | tf.summary.audio('raw', tf.expand_dims(raw, 0), hp.sample_rate, 1) 299 | tf.summary.audio('gen', tf.expand_dims(gen, 0), hp.sample_rate, 1) 300 | 301 | expand_dims = lambda x: tf.expand_dims(tf.expand_dims(x, 0), -1) 302 | tf.summary.image('alignments', expand_dims(self.alignments[0])) 303 | tf.summary.image('E_text_sim', expand_dims(self.E_text_sim)) 304 | if hp.g2p != 'seq': 305 | tf.summary.image('E_tone_sim', expand_dims(self.E_tone_sim)) 306 | tf.summary.image('E_prds_sim', expand_dims(self.E_prds_sim)) 307 | if hp.encoder_type == 'sa': 308 | for i in range(hp.encoder_attn_layers): 309 | for j in range(hp.encoder_attn_nhead): 310 | tf.summary.image(f'slf_attn_{i}{j}', expand_dims(self.slf_attn[i][j][0])) 311 | if self.crx_attn: 312 | for i in range(2): 313 | tf.summary.image(f'crx_attn_{i}', expand_dims(self.crx_attn[i][0])) 314 | 315 | #tf.summary.image('mel_out', expand_dims(self.mel_outputs[0])) 316 | 317 | self.stats = tf.summary.merge_all() 318 | 319 | 320 | def _learning_rate_decay(init_lr, global_step): 321 | # Noam scheme from tensor2tensor: 322 | warmup_steps = 4000.0 323 | step = tf.cast(global_step + 1, dtype=tf.float32) 324 | return init_lr * warmup_steps**0.5 * tf.minimum(step * warmup_steps**-1.5, step**-0.5) 325 | -------------------------------------------------------------------------------- /transtacos/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Armit 3 | # Create Time: 2022/01/07 4 | 5 | import os 6 | import random 7 | from pprint import pformat 8 | from argparse import ArgumentParser 9 | from importlib import import_module 10 | from typing import List 11 | 12 | import hparam as hp 13 | random.seed(hp.randseed) 14 | 15 | 16 | def write_metadata(metadata:List[List], stats:dict, wav_path, args): 17 | if args.shuffle: random.shuffle(metadata) 18 | 
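# NOTE: after the optional shuffle above, the first `split_ratio` fraction of
# the records becomes the test split and the remainder the train split (see
# the cut point `cp` computed below).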
out_path = os.path.join(args.base_dir, args.out_dir) 20 | os.makedirs(out_path, exist_ok=True) 21 | 22 | cp = int(len(metadata) * args.split_ratio) 23 | mt_test, mt_train = metadata[:cp], metadata[cp:] 24 | 25 | with open(os.path.join(out_path, 'train.txt'), 'w', encoding='utf-8') as fh: 26 | for mt in mt_train: 27 | fh.write('|'.join([str(x) for x in mt])) 28 | fh.write('\n') 29 | 30 | with open(os.path.join(out_path, 'test.txt'), 'w', encoding='utf-8') as fh: 31 | for mt in mt_test: 32 | fh.write('|'.join([str(x) for x in mt])) 33 | fh.write('\n') 34 | 35 | with open(os.path.join(out_path, 'stats.txt'), 'w', encoding='utf-8') as fh: 36 | for k, v in stats.items(): 37 | fh.write(f'{k}\t{v}') 38 | fh.write('\n') 39 | 40 | with open(os.path.join(out_path, 'wav_path.txt'), 'w', encoding='utf-8') as fh: 41 | fh.write(wav_path) 42 | 43 | 44 | if __name__ == '__main__': 45 | def str2bool(s:str) -> bool: 46 | s = s.lower() 47 | if s in ['true', 't', '1']: return True 48 | if s in ['false', 'f', '0']: return False 49 | raise ValueError(f'invalid bool value: {s}') 50 | 51 | base_dir = os.path.dirname(os.path.abspath(__file__)) 52 | DATASETS = [fn[:-3] for fn in os.listdir(os.path.join(base_dir, 'datasets')) if not fn.startswith('__')] 53 | 54 | parser = ArgumentParser() 55 | parser.add_argument('--base_dir', required=True, help='base path containing the dataset folder') 56 | parser.add_argument('--out_dir', default='preprocessed', help='preprocessed output folder') 57 | parser.add_argument('--dataset', required=True, choices=DATASETS) 58 | parser.add_argument('--shuffle', type=str2bool, default=True, help='shuffle metadata') 59 | parser.add_argument('--split_ratio', type=float, default=0.05, help='fraction of data held out as the test split') 60 | parser.add_argument('--num_workers', type=int, default=4) 61 | args = parser.parse_args() 62 | 63 | os.environ['LIBROSA_CACHE_LEVEL'] = '50' 64 | 65 | proc = import_module(f'datasets.{args.dataset}') 66 | metadata, stats, wav_path = proc.preprocess(args) 67 | print('wav_path:', wav_path) 68 | print('stats:', pformat(stats)) # NOTE: pprint's `sort_dicts=False` would keep insertion order, but it needs Python 3.8+ 69 | write_metadata(metadata, stats, wav_path, args) 70 | -------------------------------------------------------------------------------- /transtacos/server.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import pickle 4 | from re import compile as Regex 5 | from time import time 6 | from tempfile import gettempdir 7 | from argparse import ArgumentParser 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | from flask import Flask, request, jsonify, send_file 12 | from requests.utils import unquote 13 | from xpinyin import Pinyin 14 | 15 | import hparam as hp 16 | from synth import Synthesizer 17 | from audio import save_wav 18 | 19 | 20 | for gpu in tf.config.experimental.list_physical_devices('GPU'): 21 | tf.config.experimental.set_memory_growth(gpu, True) 22 | 23 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 24 | os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit' 25 | os.environ['XLA_FLAGS'] = '--xla_hlo_profile' 26 | 27 | np.random.seed (hp.randseed) 28 | tf.random.set_random_seed(hp.randseed) 29 | 30 | 31 | # globals 32 | app = Flask(__name__) 33 | kanji2pinyin = Pinyin() 34 | synthesizer = None 35 | html_text = None 36 | 37 | BASE_PATH = os.path.dirname(os.path.abspath(__file__)) 38 | HTML_FILE = os.path.join(BASE_PATH, '..', 'index.html') 39 | 40 | TMP_DIR = gettempdir() 41 | WAV_TMP_FILE = os.path.join(TMP_DIR, 'synth.wav') 42 | 
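# Descriptive note on the text-norm path below: the first regex drops "soft"
# punctuation outright, the second turns "hard" punctuation into break points;
# `synth()` then greedily re-merges the pieces into clauses of at most
# MAX_CLAUSE_LENGTH characters before synthesis.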
REGEX_PUNCT_IGNORE = Regex('、|:|;|“|”|‘|’') 44 | REGEX_PUNCT_BREAK = Regex(',|。|!|?') 45 | MAX_CLAUSE_LENGTH = 20 46 | 47 | 48 | # quick demo index page 49 | @app.route('/', methods=['GET']) 50 | def root(): 51 | global html_text 52 | if not html_text: 53 | with open(HTML_FILE, encoding='utf-8') as fp: 54 | html_text = fp.read() 55 | return html_text 56 | 57 | 58 | # vocode with internal Griffin-Lim 59 | @app.route('/synth', methods=['GET']) 60 | def synth(): 61 | kanji = unquote(request.args.get('text')) 62 | 63 | if kanji: 64 | try: 65 | # Text-Norm 66 | if True: 67 | s = time() 68 | print(f'text/raw: {kanji!r}') 69 | 70 | kanji = REGEX_PUNCT_IGNORE.sub('', kanji) 71 | kanji = REGEX_PUNCT_BREAK.sub(' ', kanji) 72 | segs = [''] # dummy init 73 | for rs in [s.strip() for s in kanji.split(' ') if s.strip()]: 74 | if (not segs[-1]) or (len(rs) + len(segs[-1]) < MAX_CLAUSE_LENGTH): 75 | segs[-1] = segs[-1] + rs 76 | else: segs.append(rs) 77 | print(f'text/segs: {segs!r}') 78 | t = time() 79 | print(f'[TextNorm] Done in {t - s:.2f}s') 80 | 81 | # Synth 82 | if True: 83 | s = time() 84 | wav_clips = [] 85 | for seg in segs: 86 | text = ' '.join(kanji2pinyin.get_pinyin(seg, tone_marks='numbers').split('-')) 87 | wav = synthesizer.synthesize(text, 'wav') 88 | wav_clips.append(wav) 89 | wav = np.concatenate(wav_clips) 90 | print('wav.shape:', wav.shape) 91 | t = time() 92 | print(f'[Synth] Done in {t - s:.2f}s') 93 | 94 | # Save file 95 | if True: 96 | s = time() 97 | save_wav(wav, WAV_TMP_FILE) 98 | t = time() 99 | print(f'[SaveFile] Done in {t - s:.2f}s') 100 | 101 | return send_file(WAV_TMP_FILE, mimetype='audio/wav') 102 | except Exception as e: 103 | print('[Error] %r' % e) 104 | error_msg = 'synth failed, see logs' 105 | else: 106 | error_msg = 'bad request params or no text to synth?' 
107 | 108 | return jsonify({'error': error_msg}) 109 | 110 | 111 | # return linear spec 112 | @app.route('/synth_spec', methods=['POST']) 113 | def synth_spec(): 114 | try: 115 | # check input text 116 | pinyin = request.get_json().get('pinyin').strip() 117 | if not pinyin: 118 | return jsonify({'error': 'no text to synth'}) 119 | 120 | # text to mag 121 | s = time() 122 | spec = synthesizer.synthesize(pinyin, 'spec') 123 | print('spec.shape:', spec.shape) 124 | t = time() 125 | print(f'[Synth] Done in {t - s:.2f}s') 126 | 127 | # transfer 128 | bio = io.BytesIO() 129 | bio.write(pickle.dumps(spec)) # float32 -> byte 130 | bio.seek(0) # reset fp to beginning for `send_file` to read 131 | return send_file(bio, mimetype='application/octet-stream') 132 | 133 | except Exception as e: 134 | print('[Error] %r' % e) 135 | return jsonify({'error': repr(e)}) # NOTE: exceptions are not JSON-serializable, send the repr 136 | 137 | 138 | if __name__ == '__main__': 139 | parser = ArgumentParser() 140 | parser.add_argument('--log_path', required=True) 141 | parser.add_argument('--host', type=str, default='0.0.0.0') 142 | parser.add_argument('--port', type=int, default=5105) 143 | args = parser.parse_args() 144 | 145 | # load ckpt 146 | synthesizer = Synthesizer() 147 | synthesizer.load(args.log_path) 148 | 149 | app.run(host=args.host, port=args.port, debug=False) 150 | -------------------------------------------------------------------------------- /transtacos/synth.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | import hparam as hp 6 | from models.tacotron import Tacotron 7 | from text.text import text_to_phoneme, phoneme_to_sequence, sequence_to_phoneme 8 | import audio as A 9 | from text.symbols import _eos, _sep, get_vocab_size 10 | from text.phonodict_cn import phonodict 11 | 12 | 13 | class Synthesizer: 14 | 15 | def load(self, log_dir): 16 | print('Constructing tacotron model') 17 | 18 | # init data placeholder 19 | if hp.g2p == 'seq': text_shape = [1, None] 20 | elif hp.g2p == 'syl4': text_shape = [1, None, 2] 21 | text_lengths = tf.placeholder(tf.int32, [1], 'text_lengths') # bs=1 for one sample 22 | text = tf.placeholder(tf.int32, text_shape, 'text') 23 | with tf.variable_scope('model'): 24 | self.model = Tacotron(hp) 25 | self.model.initialize(text_lengths, text) 26 | self.mag_output = self.model.mag_outputs[0] 27 | 28 | # load ckpt 29 | checkpoint_state = tf.train.get_checkpoint_state(log_dir) 30 | print('Resuming from checkpoint: %s' % checkpoint_state.model_checkpoint_path) 31 | self.session = tf.Session() 32 | self.session.run(tf.global_variables_initializer()) 33 | saver = tf.train.Saver() 34 | saver.restore(self.session, checkpoint_state.model_checkpoint_path) 35 | 36 | def synthesize(self, text, out_type='wav') -> np.ndarray: 37 | # ref: `data.DataFeeder.load_data` 38 | if hp.g2p == 'seq': 39 | text = text + _eos 40 | print('text: ', text) 41 | phs = text_to_phoneme(text) 42 | print('phs: ', phs) 43 | seq = phoneme_to_sequence(phs) 44 | print('seq: ', seq) 45 | phs_rev = sequence_to_phoneme(seq) 46 | print('phs_rev: ', phs_rev) 47 | elif hp.g2p == 'syl4': 48 | C, V, T, Vx = text_to_phoneme(text) # [[str]] 49 | 50 | CVVx, Tx = [ ], [ ] 51 | n_syllable = len(C) 52 | for i in range(n_syllable): 53 | if C[i] != phonodict.vacant: 54 | CVVx.append(C[i]) ; Tx.append(T[i]) 55 | if V[i] != phonodict.vacant: 56 | CVVx.append(V[i]) ; Tx.append(T[i]) 57 | if Vx[i] != phonodict.vacant: 58 | CVVx.append(Vx[i]) ; Tx.append(T[i]) 59 | 60 | CVVx.append(_sep) ; 
Tx.append(0) 61 | 62 | # NOTE: pad here, then convert to id_seq 63 | CVVx = phoneme_to_sequence(CVVx + [_eos]) # see phone table 64 | Tx = [int(t) for t in Tx] + [0] # should be 0 ~ 5 65 | 66 | assert len(CVVx) == len(Tx) 67 | assert 0 <= min(CVVx) and max(CVVx) < get_vocab_size() 68 | assert 0 <= min(Tx) and max(Tx) < 6 69 | 70 | seq = np.stack([CVVx, Tx], axis=-1) # [T, 2] 71 | 72 | seq = np.asarray(seq, dtype=np.int32) 73 | 74 | feed_dict = { 75 | self.model.text_lengths: [len(seq)], # len(id_seq) 76 | self.model.text: [seq], # id_seq 77 | } 78 | mag = self.session.run(self.mag_output, feed_dict=feed_dict) 79 | mag = mag.T # [F-1, T] 80 | if out_type == 'wav': 81 | wav = A.inv_spec(mag) # vocode with internal Griffin-Lim 82 | wav = A.trim_silence(wav) 83 | return wav # only the data chunk, no RIFF encapsulation 84 | if out_type == 'spec': 85 | S = A.spec_to_natural_scale(mag) # denorm 86 | S = A.fix_zero_DC(S) 87 | return S 88 | -------------------------------------------------------------------------------- /transtacos/text/__init__.py: -------------------------------------------------------------------------------- 1 | from .symbols import * 2 | from .text import * 3 | from .phonodict_cn import * 4 | from .g2p import * 5 | -------------------------------------------------------------------------------- /transtacos/text/g2p.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author: Armit 3 | # Create Time: 2021/3/29 4 | 5 | from typing import List 6 | 7 | from .symbols import _unk 8 | from .phonodict_cn import phonodict 9 | 10 | 11 | def to_syl4(pinyin:str, sep=' ') -> List[List[str]]: 12 | C, V, T, Vx = [], [], [], [] 13 | 14 | py_ls = pinyin.split(sep) 15 | n_syllable = len(py_ls) 16 | for py in py_ls: 17 | # split tone 18 | t = py[-1] 19 | if t.isdigit(): py = py[:-1] 20 | else: t = '5' 21 | 22 | # delete R-ending (erhua) 23 | r_ending = False 24 | if py[-1] == 'r': 25 | r_ending = True 26 | if py != 'er': 27 | py = py[:-1] 28 | 29 | # split CV 30 | try: 31 | c, v, e = phonodict[py] 32 | C.append(c) ; V.append(v) ; T.append(t) 33 | if r_ending: Vx.append('_R') # let R override N or NG 34 | else: Vx.append(e) 35 | 36 | except Exception: 37 | C.append(_unk) ; V.append(_unk) ; T.append(_unk) ; Vx.append(_unk) 38 | print('[Syllable] cannot parse %r' % py) 39 | 40 | assert len(C) == len(V) == len(T) == len(Vx) == n_syllable 41 | return [C, V, T, Vx] 42 | 43 | 44 | def from_syl4(syl4:List[List[str]], sep=' ') -> str: 45 | return sep.join([''.join(s) for s in zip(*syl4)]) 46 | 47 | 48 | if __name__ == '__main__': 49 | pinyin = 'zi3 se4 de hua1 er2 wei4 shen2 me zher4 yang4 yuan2' 50 | print('pinyin:', pinyin) 51 | syl4 = to_syl4(pinyin) 52 | print('syl4:', syl4) 53 | syl4_serial = from_syl4(syl4) 54 | print('syl4_serial:', syl4_serial) 55 | -------------------------------------------------------------------------------- /transtacos/text/phonodict_cn.csv: -------------------------------------------------------------------------------- 1 | ,-,b,d,g,p,t,k,j,q,x,z,c,s,zh,ch,sh,m,n,l,f,h,y,w,r 2 | a,a,b a,d a,g a,p a,t a,k a,,,,z a,c a,s a,zh a,ch a,sh a,m a,n a,l a,f a,h a,ia,ua, 3 | o,o,b uo,,,p uo,,,,,,,,,,,,m uo,,l uo,f uo,,io,uo, 4 | e,e,,d e,g e,,t e,k e,,,,z e,c e,s e,zh e,ch e,sh e,m e,n e,l e,,h e,iE,,r e 5 | i,,b i,d i,,p i,t i,,j i,q i,x i,z i0,c i0,s i0,zh iR,ch iR,sh iR,m i,n i,l i,,,i,,r iR 6 | u,,b u,d u,g u,p u,t u,k u,j v,q v,x v,z u,c u,s u,zh u,ch u,sh u,m u,n u,l u,f u,h u,v,u,r u 7 | v,,,,,,,,,,,,,,,,,,n v,l v,,,,, 8 | ai,ai,b 
ai,d ai,g ai,p ai,t ai,k ai,,,,z ai,c ai,s ai,zh ai,ch ai,sh ai,m ai,n ai,l ai,,h ai,,uai, 9 | ao,ao,b ao,d ao,g ao,p ao,t ao,k ao,,,,z ao,c ao,s ao,zh ao,ch ao,sh ao,m ao,n ao,l ao,,h ao,iao,,r ao 10 | ei,Ei,b Ei,d Ei,g Ei,p Ei,t Ei,k Ei,,,,z Ei,,,zh Ei,,sh Ei,m Ei,n Ei,l Ei,f Ei,h Ei,,uEi, 11 | ou,ou,,d ou,g ou,p ou,t ou,k ou,,,,z ou,c ou,s ou,zh ou,ch ou,sh ou,m ou,n ou,l ou,f ou,h ou,iou,,r ou 12 | uo,,,d uo,g uo,,t uo,k uo,,,,z uo,c uo,s uo,zh uo,ch uo,sh uo,,n uo,l uo,,h uo,,,r uo 13 | an,an,b an,d an,g an,p an,t an,k an,,,,z an,c an,s an,zh an,ch an,sh an,m an,n an,l an,f an,h an,iEn,uan,r an 14 | en,en,b en,d en,g en,p en,,k en,,,,z en,c en,s en,zh en,ch en,sh en,m en,n en,,f en,h en,,un,r en 15 | in,,b in,,,p in,,,j in,q in,x in,,,,,,,m in,n in,l in,,,in,, 16 | un,,,d un,g un,,t un,k un,j vn,q vn,x vn,z un,c un,s un,zh un,ch un,sh un,,n un,l un,,h un,vn,,r un 17 | ang,ang,b ang,d ang,g ang,p ang,t ang,k ang,,,,z ang,c ang,s ang,zh ang,ch ang,sh ang,m ang,n ang,l ang,f ang,h ang,iang,uang,r ang 18 | eng,eng,b eng,d eng,g eng,p eng,t eng,k eng,,,,z eng,c eng,s eng,zh eng,ch eng,sh eng,m eng,n eng,l eng,f eng,h eng,,ueng,r eng 19 | ing,,b ing,d ing,,p ing,t ing,,j ing,q ing,x ing,,,,,,,m ing,n ing,l ing,,,ing,, 20 | ong,,,d ong,g ong,,t ong,k ong,,,,z ong,c ong,s ong,zh ong,ch ong,,,n ong,l ong,,h ong,iong,,r ong 21 | ia,,,d ia,,,,,j ia,q ia,x ia,,,,,,,,,l ia,,,,, 22 | ian,,b iEn,d iEn,,p iEn,t iEn,,j iEn,q iEn,x iEn,,,,,,,m iEn,n iEn,l iEn,,,,, 23 | iang,,b iang,,,,,,j iang,q iang,x iang,,,,,,,,n iang,l iang,,,,, 24 | iong,,,,,,,,j iong,q iong,x iong,,,,,,,,,,,,,, 25 | ie,,b iE,d iE,,p iE,t iE,,j iE,q iE,x iE,,,,,,,m iE,n iE,l iE,,,,, 26 | iu,,,d iou,,,,,j iou,q iou,x iou,,,,,,,m iou,n iou,l iou,,,,, 27 | iao,,b iao,d iao,,p iao,t iao,,j iao,q iao,x iao,,,,,,,m iao,n iao,l iao,f iao,,,, 28 | ua,,,,g ua,,,k ua,,,,,,,zh ua,ch ua,sh ua,,,,,h ua,,,r ua 29 | uan,,,d uan,g uan,,t uan,k uan,j vEn,q vEn,x vEn,z uan,c uan,s uan,zh uan,ch uan,sh uan,,n uan,l uan,,h uan,vEn,,r uan 30 | uang,,,,g uang,,,k uang,,,,,,,zh uang,ch uang,sh uang,,,,,h uang,,, 31 | ue,,,,,,,,j vE,q vE,x vE,,,,,,,,,,,,vE,, 32 | ui,,,d uEi,g uEi,,t uEi,k uEi,,,,z uEi,c uEi,s uEi,zh uEi,ch uEi,sh uEi,,,,,h uEi,,,r uEi 33 | uai,,,,g uai,,,k uai,,,,,,,zh uai,ch uai,sh uai,,,,,h uai,,, 34 | ve,,,,,,,,,,,,,,,,,,n vE,l vE,,,,, 35 | er,R,,,,,,,,,,,,,,,,,,,,,,, 36 | -------------------------------------------------------------------------------- /transtacos/text/phonodict_cn.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | # toneless phoneme dictionary, roughly equivalent to the X-SAMPA (Vocaloid) scheme 4 | # initials are plain consonants; medials attach towards the final 5 | 6 | LEXDICT_FILE = Path(__file__).absolute().parent / 'phonodict_cn.csv' 7 | 8 | # NOTE: hard-coded according to `symbols.py` 9 | _pad = '_' 10 | 11 | 12 | class Phonodict4: 13 | 14 | def __init__(self, fp=LEXDICT_FILE, vac_sym=_pad): 15 | self.entry = { } # {'hui': 'h uEi _'}, 1 py = 3 syl 16 | self.initials = [] # pinyin := initial + final 17 | self.finals = [] 18 | self.consonants = [] # syl4 := consonant + vowel + (tone) + ending 19 | self.vowels = [] 20 | self.endings = ['_N', '_NG', '_R'] 21 | self.vacant = vac_sym # '_', for zero consonant/ending 22 | 23 | self.dict_fp = fp 24 | self._load_dict() 25 | 26 | def _load_dict(self): 27 | I, F, C, V = set(), set(), set(), set() 28 | 29 | with open(self.dict_fp) as fh: 30 | ilist = list(fh.readline().strip().split(',')[1:]) 31 | ilist[0] = '' # fix '-' to '' 32 | for row in fh.readlines(): 33 | ls = row.strip().split(',') 34 | 
f, cvlist = ls[0], ls[1:] 35 | F.add(f) 36 | for i in ilist: 37 | I.add(i) 38 | cv = cvlist[ilist.index(i)] 39 | if cv: # found a valid syllable 40 | if cv == 'R': 41 | c = self.vacant 42 | v = 'e' 43 | e = '_R' 44 | else: 45 | if cv.endswith('ng'): 46 | cv = cv[:-2] ; e = '_NG' 47 | elif cv.endswith('n'): 48 | cv = cv[:-1] ; e = '_N' 49 | else: 50 | e = self.vacant 51 | if ' ' in cv: # 'k ua' 52 | c, v = cv.split(' ') 53 | else: 54 | c = self.vacant ; v = cv 55 | C.add(c) ; V.add(v) 56 | self.entry[i + f] = [c, v, e] 57 | 58 | self.initials = sorted(list(I)) 59 | self.finals = sorted(list(F)) 60 | self.consonants = sorted(list(C)) 61 | self.vowels = sorted(list(V)) 62 | 63 | def __getitem__(self, py:str) -> str: 64 | return self.entry.get(py, None) 65 | 66 | def __len__(self) -> int: 67 | return len(self.entry) 68 | 69 | @property 70 | def vacant_symbol(self) -> str: 71 | return self.vacant 72 | 73 | def inspect(self): 74 | print(f'syllable count: {len(self.entry)}') 75 | print(f'initials({len(self.initials)}): {self.initials}') 76 | print(f'finals({len(self.finals)}): {self.finals}') 77 | print(f'consonants({len(self.consonants)}): {self.consonants}') 78 | print(f'vowels({len(self.vowels)}): {self.vowels}') 79 | print(f'endings({len(self.endings)}): {self.endings}') 80 | 81 | 82 | phonodict = Phonodict4() 83 | 84 | 85 | if __name__ == '__main__': 86 | phonodict.inspect() 87 | -------------------------------------------------------------------------------- /transtacos/text/phonodict_cn.txt: -------------------------------------------------------------------------------- 1 | [Dictionary] 2 | Initials (21): 3 | - retroflex (4): zh, ch, sh, r 4 | - dental (3): z, c, s 5 | - palatal (3): j, q, x 6 | - grouped by aspiration (waveforms are more alike) 7 | - aspirated (5): p, t, k, f, h 8 | - unaspirated (4): b, d, g, l 9 | - nasal (2): m, n 10 | - grouped by place of articulation 11 | - labial (4): b, p, m, f 12 | - lingual (4): d, t, n, l 13 | - guttural (3): g, k, h 14 | 15 | Finals (39): 16 | - simple (7/9): a, o, e, i(i, iR, i0), u, v, R 17 | - front-closed (5): an en in un vn 18 | - back-closed (4): ang eng ing ong 19 | - diphthong (5): ai ao Ei ou uo 20 | - with medial 21 | - y-medial: ia, iao, iE, iEn, io, iou, iang, iong 22 | - w-medial: ua, uai, uEi, uan, uang, ueng 23 | - v-medial: vE, vEn 24 | 25 | Notes: 26 | 0. if the onset of a zero-initial syllable is strong enough, it sounds as if the initial were b/d 27 | 28 | 29 | [Pinyin scheme] 30 | * with phonemes as the pronunciation unit, there are 21+39*5=216 legal toned phonemes; DataBaker exhibits 210 of them 31 | * with syllables as the pronunciation unit, standard Mandarin has 414 legal toneless syllables; after the six-way factorization below the total dictionary size is 80, e.g. 'zhuang3' factors into C=zh, V=ua, T=3, Vx=_NG 32 | - C: zero initial plus the initial table; 22 in total 33 | - Cx: retroflex / dental / palatal / aspirated / unaspirated / nasal; 6 in total 34 | - xV: none / y- / w- / v-medial; 4 in total 35 | - V: the final table; 39 in total 36 | - T: tones 1-4 plus the neutral tone; 5 in total 37 | - Vx: none / front-closed / back-closed / R-colored; 4 in total 38 | 39 | [Prosody modeling] 40 | Prosody marks: 41 | 0: word-internal 42 | 1: end of a word, read on without a pause 43 | 2: end of a word, with lengthening or a pause 44 | 3: end of a clause 45 | 4: end of the sentence 46 | 5: trailing marker 47 | Prosody hierarchy: 48 | syllable string #0 -> word string #1 -> phrase string #2 -> clause string #3 -> full sentence #4 49 | 50 | [Tacotron empirical findings] 51 | 0. i/i0/iR sound almost identical and can substitute for one another 52 | 1. j/q/x seem to carry an inherent i, ja = jia 53 | -------------------------------------------------------------------------------- /transtacos/text/symbols.py: -------------------------------------------------------------------------------- 1 | # marks 2 | _pad = '_' # , right padding for fixed-length RNN; 3 | # /, short silence / break of speech 4 | _eos = '~' # , end of sentence 5 | _sep = '/' # separator between syllables 6 | _unk = '?' 
# 7 | 8 | _markers = [_pad, _eos, _sep, _unk] # NOTE: `_pad` MUST be at index 0 9 | 10 | 11 | ''' G2P = seq ''' 12 | _chars = 'abcdefghijklmnopqrstuvwxyz 12345' 13 | 14 | 15 | ''' G2P = syl4 ''' 16 | # phonetic units under the syllable repr; refer to `phonodict_cn.txt` 17 | # syl4 := CxVTx = C + V + T + Vx 18 | # _syl4 == [ 19 | # '-', 'b', 'c', 'ch', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 'sh', 't', 'x', 'z', 'zh', 20 | # 'Ei', 'R', 'a', 'ai', 'ao', 'e', 'i', 'i0', 'iE', 'iR', 'ia', 'iao', 'io', 'iou', 'o', 'ou', 'u', 'uEi', 'ua', 'uai', 'ue', 'uo', 'v', 'vE', 21 | # '0', '1', '2', '3', '4', '5', 22 | # '_N', '_NG', '_R', 23 | # ] 24 | # TODO: consider merging 'i'/'iR' with 'i0' 25 | from .phonodict_cn import phonodict 26 | #_syl4_T = ['0', '1', '2', '3', '4', '5'] # 6, NOTE: exclude from phone table 27 | #_syl4_P = ['0', '1', '2', '3', '4', '5'] # 6, NOTE: exclude from phone table 28 | _syl4_C = phonodict.consonants # 22 29 | _syl4_V = phonodict.vowels # 24 30 | _syl4_Vx = phonodict.endings # 3 31 | _syl4 = _syl4_C + _syl4_V + _syl4_Vx # 54 32 | 33 | 34 | # phonetic unit list 35 | _g2p_mapping = { 36 | 'seq': _chars, 37 | 'syl4': _syl4, 38 | } 39 | 40 | import hparam as hp 41 | 42 | assert len(set(_g2p_mapping[hp.g2p])) == len(_g2p_mapping[hp.g2p]) # assure no duplicates 43 | _symbols = _markers + sorted(set(_g2p_mapping[hp.g2p]) - set(_markers)) # keep order 44 | print(f'[Symbols] collect {len(_symbols)} symbols in {hp.g2p} repr') 45 | print(f' {_symbols}') 46 | 47 | _symbol_to_id = {s: i for i, s in enumerate(_symbols)} 48 | _id_to_symbol = {i: s for i, s in enumerate(_symbols)} 49 | 50 | 51 | def symbol_to_id(sym:str) -> int: 52 | return _symbol_to_id.get(sym, _symbol_to_id[_unk]) 53 | 54 | 55 | def id_to_symbol(id:int) -> str: 56 | return _id_to_symbol.get(id, _unk) 57 | 58 | 59 | def get_vocab_size(): 60 | return len(_symbols) 61 | 62 | 63 | def get_symbol_id(s:str): 64 | return { 65 | 'pad': symbol_to_id(_pad), 66 | 'eos': symbol_to_id(_eos), 67 | 'sep': symbol_to_id(_sep), 68 | 'unk': symbol_to_id(_unk), 69 | 'vac': symbol_to_id(phonodict.vacant_symbol), # := _pad 70 | }.get(s, symbol_to_id(s)) 71 | -------------------------------------------------------------------------------- /transtacos/text/text.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import List, Union 3 | 4 | import hparam as hp 5 | from .symbols import symbol_to_id, id_to_symbol 6 | from .g2p import to_syl4 7 | 8 | 9 | _whitespace_re = re.compile(r'\s+') 10 | 11 | 12 | def text_to_phoneme(text:str) -> Union[str, List]: 13 | # clean up 14 | text = text.strip() 15 | text = text.lower() 16 | text = re.sub(_whitespace_re, ' ', text) 17 | 18 | # g2p 19 | _converter_mapping = { 20 | 'seq': lambda _: _, # => 'str' 21 | 'syl4': to_syl4, # => [C, V, T, Vx] 22 | } 23 | phs = _converter_mapping[hp.g2p](text) 24 | return phs 25 | 26 | 27 | def phoneme_to_sequence(phoneme:Union[str, List]) -> List[int]: 28 | return [symbol_to_id(ph) for ph in phoneme] 29 | 30 | 31 | def sequence_to_phoneme(sequence:List[int]) -> str: 32 | return ''.join([id_to_symbol(id) for id in sequence]) 33 | -------------------------------------------------------------------------------- /transtacos/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import time, sleep 3 | from pprint import pformat 4 | from argparse import ArgumentParser 5 | import traceback 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | 10 | import 
hparam as hp 11 | from models.tacotron import Tacotron 12 | from data import DataFeeder 13 | from text.text import sequence_to_phoneme 14 | from audio import save_wav, inv_spec 15 | from utils import * 16 | 17 | 18 | for gpu in tf.config.experimental.list_physical_devices('GPU'): 19 | tf.config.experimental.set_memory_growth(gpu, True) 20 | 21 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' 22 | os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit' 23 | os.environ['XLA_FLAGS'] = '--xla_hlo_profile' 24 | 25 | np.random.seed (hp.randseed) 26 | tf.random.set_random_seed(hp.randseed) 27 | 28 | 29 | def train(args): 30 | # Log Folder 31 | log_dir = os.path.join(args.base_dir, args.name) 32 | os.makedirs(log_dir, exist_ok=True) 33 | log_init(os.path.join(log_dir, 'train.log')) 34 | 35 | ckpt_path = os.path.join(log_dir, 'model.ckpt') 36 | log('Checkpoint path: %s' % ckpt_path) 37 | input_path = os.path.join(args.base_dir, args.input) 38 | log('Loading training data from: %s' % input_path) 39 | log('Hyperparams:') 40 | log(pformat({k: getattr(hp, k) for k in dir(hp) if not k.startswith('__')}, indent=2)) 41 | 42 | # DataFeeder 43 | coord = tf.train.Coordinator() 44 | with tf.variable_scope('datafeeder'): 45 | feeder = DataFeeder(coord, input_path, hp) 46 | 47 | # Model 48 | with tf.variable_scope('model'): 49 | model = Tacotron(hp) 50 | model.initialize(feeder.text_lengths, 51 | feeder.text, feeder.prds, 52 | feeder.spec_lengths, 53 | feeder.mel_targets, feeder.mag_targets, feeder.f0_targets, feeder.c0_targets, 54 | feeder.stop_token_targets) 55 | model.add_loss() 56 | model.add_optimizer() 57 | model.add_stats() 58 | param_count = sum([np.prod(v.get_shape()) for v in tf.trainable_variables()]) 59 | log(f'param_cnt = {param_count}') 60 | 61 | # Bookkeeping 62 | step = 0 # local step 63 | time_window = ValueWindow(100) # for perfcount 64 | loss_window = ValueWindow(100) 65 | saver = tf.train.Saver(max_to_keep=hp.max_ckpt) 66 | 67 | # Train! 68 | with tf.Session() as sess: 69 | try: 70 | sw = tf.summary.FileWriter(log_dir, sess.graph) 71 | sess.run(tf.global_variables_initializer()) 72 | 73 | # Restore from a checkpoint if available 74 | ckpt_state = tf.train.get_checkpoint_state(log_dir) 75 | if ckpt_state is not None: 76 | saver.restore(sess, ckpt_state.model_checkpoint_path) 77 | log('Resuming from checkpoint: %s' % ckpt_state.model_checkpoint_path) 78 | else: 79 | log('Starting new training run') 80 | 81 | feeder.start_in_session(sess) 82 | while not coord.should_stop(): 83 | t = time() 84 | step, loss, opt = sess.run([model.global_step, model.loss, model.optimize]) 85 | time_window.append(time() - t) 86 | loss_window.append(loss) 87 | log('Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (step, time_window.average, loss, loss_window.average)) 88 | 89 | if loss > 300 or np.isnan(loss): 90 | log('Loss exploded to %.05f at step %d!' 
% (loss, step)) 91 | raise Exception('Loss Exploded') 92 | 93 | if step % args.summary_interval == 0: 94 | log('Writing summary at step: %d' % step) 95 | sw.add_summary(sess.run(model.stats), step) 96 | 97 | if step % args.checkpoint_interval == 0: 98 | log('Saving checkpoint to: %s-%d' % (ckpt_path, step)) 99 | saver.save(sess, ckpt_path, global_step=step) 100 | log('Saving audio and alignment...') 101 | 102 | if hp.g2p == 'seq': 103 | (text, mel, mag, alignment, spec_len, mel_r, mag_r, mel_loss, mag_loss) = sess.run([ 104 | model.text[0], 105 | model.mel_outputs[0], model.mag_outputs[0], model.alignments[0], model.spec_lengths[0], 106 | model.mel_targets[0], model.mag_targets[0], model.mel_loss, model.mag_loss]) 107 | log('Input:') 108 | log(f' seq: {text}') 109 | log(f' phs: {sequence_to_phoneme(text)}') 110 | elif hp.g2p == 'syl4': 111 | (text, prds_o, prds_r, mel, mag, alignment, spec_len, mel_r, mag_r, mel_loss, mag_loss) = sess.run([ 112 | model.text[0], model.prds_out[0], model.prds[0], 113 | model.mel_outputs[0], model.mag_outputs[0], model.alignments[0], model.spec_lengths[0], 114 | model.mel_targets[0], model.mag_targets[0], model.mel_loss, model.mag_loss]) 115 | CVVx, T = text.T.tolist() 116 | log('Input:') 117 | log(f' text: {sequence_to_phoneme(CVVx)}') 118 | log(f' tone: {"".join([str(t) for t in T])}') 119 | log(f' prds: {"".join([str(p) for p in prds_r])}') 120 | log(f' pred: {"".join([str(p) for p in prds_o])}') 121 | 122 | mel, mag, mel_r, mag_r = [m[:spec_len,:].T for m in [mel, mag, mel_r, mag_r]] 123 | save_wav(inv_spec(mag), os.path.join(log_dir, 'step-%d-audio.wav' % step)) 124 | plot_specs([mel, mag, mel_r, mag_r], os.path.join(log_dir, 'step-%d-specs.png' % step), 125 | info=f'{time_string()}, mel_loss={mel_loss:.5f}, mag_loss={mag_loss:.5f}') 126 | plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step), 127 | info='%s, step=%d, loss=%.5f' % (time_string(), step, loss)) 128 | 129 | if step >= hp.max_steps + 10: 130 | print('[Train] Done') 131 | sleep(5) 132 | break 133 | 134 | except Exception as e: 135 | log('Exiting due to exception: %s' % e) 136 | traceback.print_exc() 137 | coord.request_stop(e) 138 | 139 | 140 | if __name__ == '__main__': 141 | parser = ArgumentParser() 142 | parser.add_argument('--base_dir', default=os.path.expanduser('.')) 143 | parser.add_argument('--input', default='preprocessed/train.txt') 144 | parser.add_argument('--name', default='transtacos', help='Name of the run, used for logging.') 145 | parser.add_argument('--summary_interval', type=int, default=1000, help='Steps between running summary ops.') 146 | parser.add_argument('--checkpoint_interval', type=int, default=1500, help='Steps between writing checkpoints.') 147 | args = parser.parse_args() 148 | 149 | train(args) 150 | -------------------------------------------------------------------------------- /transtacos/utils.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | from datetime import datetime 3 | 4 | import matplotlib ; matplotlib.use('Agg') 5 | import matplotlib.pyplot as plt 6 | import seaborn as sns 7 | 8 | import hparam as hp 9 | 10 | _log_fmt = '%Y-%m-%d %H:%M:%S.%f' 11 | _log_fp = None 12 | 13 | 14 | def log_init(fp): 15 | global _log_fp 16 | _close_logfile() 17 | 18 | _log_fp = open(fp, 'a') 19 | _log_fp.write('\n') 20 | _log_fp.write('-----------------------------------------------------------------\n') 21 | _log_fp.write(' Starting new training run\n') 22 | 
_log_fp.write('-----------------------------------------------------------------\n') 23 | 24 | 25 | def log(msg): 26 | print(msg) 27 | if _log_fp: 28 | _log_fp.write('[%s] %s\n' % (datetime.now().strftime(_log_fmt)[:-3], msg)) 29 | 30 | 31 | def _close_logfile(): 32 | global _log_fp 33 | if _log_fp: 34 | _log_fp.close() 35 | _log_fp = None 36 | 37 | 38 | atexit.register(_close_logfile) 39 | 40 | 41 | def plot_alignment(alignment, path, info=None): 42 | fig, ax = plt.subplots() 43 | im = ax.imshow( 44 | alignment, 45 | aspect='auto', 46 | origin='lower', 47 | interpolation='none') 48 | fig.colorbar(im, ax=ax) 49 | plt.xlabel('Decoder timestep' + (f'\n\n{info}' if info else '')) 50 | plt.ylabel('Encoder timestep') 51 | plt.tight_layout() 52 | plt.savefig(path, format='png') 53 | 54 | 55 | def plot_specs(specs, path, info=None): 56 | # mel_g mel_r 57 | # mag_g mag_r 58 | ax = plt.subplot(221) ; sns.heatmap(specs[0]) ; ax.invert_yaxis() 59 | ax = plt.subplot(222) ; sns.heatmap(specs[2]) ; ax.invert_yaxis() 60 | ax = plt.subplot(223) ; sns.heatmap(specs[1]) ; ax.invert_yaxis() 61 | ax = plt.subplot(224) ; sns.heatmap(specs[3]) ; ax.invert_yaxis() 62 | plt.xlabel(info) 63 | plt.tight_layout() 64 | plt.margins(0, 0) 65 | plt.savefig(path, format='png', dpi=400) 66 | 67 | 68 | def time_string(): 69 | return datetime.now().strftime('%Y-%m-%d %H:%M') 70 | 71 | 72 | class ValueWindow(): # NOTE: fixed-length window; append on the right, evict from the left 73 | 74 | def __init__(self, window_size=100): 75 | self._window_size = window_size 76 | self._values = [] 77 | 78 | def append(self, x): 79 | self._values = self._values[-(self._window_size - 1):] + [x] 80 | 81 | @property 82 | def sum(self): 83 | return sum(self._values) 84 | 85 | @property 86 | def count(self): 87 | return len(self._values) 88 | 89 | @property 90 | def average(self): 91 | return self.sum / max(1, self.count) 92 | 93 | def reset(self): 94 | self._values = [] 95 | --------------------------------------------------------------------------------