├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── img
│   ├── draw_embd_sim.py
│   ├── gen_hfg.wav
│   ├── gen_rtg.wav
│   ├── gen_tts.wav
│   ├── gt_hfg.wav
│   ├── gt_rtg.wav
│   ├── gt_tts.wav
│   ├── hifi-gan.png
│   ├── rtg_D.png
│   ├── rtg_G.png
│   ├── rtg_arch.png
│   ├── rtg_cmp.png
│   ├── tacotron-cn.mini.step-48000-embed_cosine.png
│   ├── tacotron-cn.png
│   ├── tacotron-en.png
│   ├── transtacos.mini.png
│   ├── transtacos.mini.step-48000_prds_embd_cosine.png
│   ├── transtacos.mini.step-48000_text_embd_cosine.png
│   ├── tts.png
│   ├── tts_arch.png
│   ├── tts_decoder.png
│   ├── tts_embed.png
│   ├── tts_encoder.png
│   ├── tts_out_align.png
│   ├── tts_out_spec.png
│   ├── tts_posnet.png
│   └── y_tmpl.wav
├── index.html
├── requirements.txt
├── retunegan
│   ├── Makefile
│   ├── audio.py
│   ├── audio_proxy.py
│   ├── data.py
│   ├── hparam.py
│   ├── infer.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── discrminator.py
│   │   ├── generator.py
│   │   └── loss.py
│   ├── server.py
│   ├── tools
│   │   ├── spec2wavset.py
│   │   ├── test_downsample.py
│   │   ├── test_envolope.py
│   │   ├── test_griffinlim.py
│   │   ├── test_istft_iter.py
│   │   ├── test_pesq.py
│   │   ├── test_phase_recover.py
│   │   └── test_strip_mirror.py
│   ├── train.py
│   └── utils.py
├── stats
│   ├── DataBaker-lexicon.txt
│   ├── DataBaker-stats.txt
│   ├── DataBaker.stats
│   ├── DataBaker_gen_stat.py
│   ├── DataBaker_print_pinyins.py
│   ├── DataBaker_print_symbols.py
│   ├── inspect_preproc.py
│   ├── inspect_spec.py
│   ├── thchs30-lexicon.txt
│   ├── thchs30_gen_vbanks.py
│   └── thchs30_print_symbols.py
└── transtacos
    ├── Makefile
    ├── audio.py
    ├── data.py
    ├── datasets
    │   ├── __skel__.py
    │   ├── databaker.py
    │   └── thchs30.py
    ├── hparam.py
    ├── models
    │   ├── attention.py
    │   ├── custom_decoder.py
    │   ├── modules.py
    │   ├── rnn_wrappers.py
    │   └── tacotron.py
    ├── preprocess.py
    ├── server.py
    ├── synth.py
    ├── text
    │   ├── __init__.py
    │   ├── g2p.py
    │   ├── phonodict_cn.csv
    │   ├── phonodict_cn.py
    │   ├── phonodict_cn.txt
    │   ├── symbols.py
    │   └── text.py
    ├── train.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # project meta
2 | .vscode/
3 |
4 | # cache
5 | __pycache__/
6 | .cache/
7 | *.py[cod]
8 |
9 | # misc
10 | ref/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Armit
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TransTacoS-RetuneGAN
2 |
3 | A lighter-weight (perhaps!) Text-to-Speech system for Chinese/Mandarin synthesis, inspired by Tacotron, FastSpeech2 & RefineGAN.
4 | It is also my shitty graduation design project, just a toy, so lower your expectations :)
5 |
6 | ----
7 |
8 | ## Quick Start
9 |
10 | ### setup
11 |
12 | Since `TransTacoS` is implemented in TensorFlow while `RetuneGAN` is in PyTorch, you could keep them in separate virtual envs; however, they are unlikely to conflict, so you can also try putting everything together:
13 |
14 | - install `tensorflow-gpu==1.14.0 tensorboard==1.14.0` following `https://tensorflow.google.cn/install/pip`
15 | - install `torch==1.8.0+cu1xx torchaudio==0.8.0` following `https://pytorch.org/`, where `cu1xx` is your cuda version
16 | - run `pip install -r requirements.txt` for the remaining dependencies
17 |
18 | ### dataset
19 |
20 | - download and unzip the open source dataset [DataBaker](https://www.data-baker.com/data/index/TNtts)
21 | - other datasets require a user-defined preprocessor; please refer to `transtacos/datasets/__skel__.py`
22 |
23 | ### train
24 |
25 | - check the path configs in each `Makefile`
26 | - `cd transtacos && make preprocess` to prepare acoustic features (linear/mel/f0/c0/zcr)
27 | - `cd transtacos && make train` to train TransTacoS
28 | - `cd retunegan && make finetune` to train RetuneGAN using the preprocessed linear spectrograms (rather than spectrograms computed from the raw waves)
29 |
30 | ### deploy
31 |
32 | - check the port configs in each `Makefile`
33 | - `cd transtacos && make server` to start the TransTacoS headless HTTP server (default port 5105)
34 | - `cd retunegan && make server` to start the RetuneGAN headless HTTP server (default port 5104)
35 | - `python app.py` to start the WebUI app (default port 5103)
36 | - point your browser to `http://localhost:5103`, and have a try!
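
With all three services up, you can also call the WebUI's `/synth` endpoint directly; a minimal client sketch (the sentence and output path are just examples):

```python
# minimal client sketch: GET /synth?text=... returns the synthesized wav (or mp3) bytes
import requests

resp = requests.get('http://localhost:5103/synth', params={'text': '你好,世界!'})
with open('synth_out.wav', 'wb') as fp:
    fp.write(resp.content)
```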
37 |
38 | ## Model Architecture
39 |
40 | ### TransTacoS
41 |
42 | 
43 |
44 | 
45 |
46 | #### What I actually did:
47 |
48 | - TransTacoS :=
49 | - embed: self-designed G2P solution (syl4), employing prosody marks as linguistic features
50 | - for the G2P solution, I tried char(seq), phoneme, pinyin (CV), CVC, VCV, CVVC ...
51 | - in fact they merely influence mel_loss, but they do affect how controllable the pronunciation is
52 | - so I designed **syl4** to split the phonemes apart from the tone (just a small improvement; see the sketch after this list)
53 | - I predict prosody marks with a simple CNN rather than an RNN; it's not entirely principled though..
54 | - encoder: modified from FastSpeech2
55 | - I use multi-head self-DotAttn with GFFW (gated-FFW) as the backbone transform
56 | - f0/c0 features are quantized to embeddings, then fused in with cross-DotAttn
57 | - decoder: inherited from Tacotron
58 | - this RNN+LSA decoder is really complicated, I dare not touch it :(
59 | - posnet: self-designed simple MLP and grouped-Linear
60 | - I found that a simple Linear layer with growing depth leads to lower mel_loss than Conv1d
61 | - thus I really DO NOT understand why Tacotron2 and even FastSpeech2 still use a postnet that directly stretches n_channels from 80 (n_mel) to 1025 (n_freq); this is surely tooooo hard a way to compensate for the information loss
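
As an illustration of the syl4 idea, here is a minimal sketch (the helper below is hypothetical; the real scheme lives in `transtacos/text/g2p.py`) that splits a numbered pinyin syllable so the tone can be embedded apart from the phonemes:

```python
# hypothetical sketch of the syl4 idea: split the phonemes apart from the tone
INITIALS = ['zh', 'ch', 'sh', 'b', 'p', 'm', 'f', 'd', 't', 'n', 'l',
            'g', 'k', 'h', 'j', 'q', 'x', 'r', 'z', 'c', 's']  # multi-char initials first

def split_syl(pinyin: str):
    # e.g. 'zhong1' -> ('zh', 'ong', '1'); '5' marks the neutral tone
    tone = pinyin[-1] if pinyin[-1].isdigit() else '5'
    syl = pinyin.rstrip('012345')
    for init in INITIALS:
        if syl.startswith(init):
            return init, syl[len(init):], tone
    return '', syl, tone  # zero-initial syllable, e.g. 'an4'
```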
62 |
63 | Frankly speaking, TransTacoS didn't profoundly improve anything over Tacotron; I just found that a shallower network leads to lower mel_loss, so maybe a simple embed+decoder is already enough :(
64 |
65 | #### Notes on ideas to try, or tried and failed:
66 |
67 | - To predict `f0/sp/ap` features so that we can use the WORLD vocoder
68 | - it seems `sp` is OK, because it resembles mel very much
69 | - but `ap` requires careful normalization, and an accurate `f0` is even harder to predict
70 | - the audio quality of WORLD sounds worse than Griffin-Lim at times :(
71 | - To predict the spectrogram's magnitude part together with its phase part, so that we can vocode directly with a pure `istft`
72 | - it should be hard to optimize losses on phase, especially at the junction points between frames
73 | - but these guys claim that they did it: [iSTFTNet](https://arxiv.org/abs/2203.02395)
74 | - To predict `f0` and `dyn` so that the vocoder might benefit
75 | - I don't know how to separate them from mel, because so far I must regularize the decoder's output to be mel (for the sake of teacher forcing)
76 | - when I removed the mel loss, I found that the RNN decoder became lazy to learn, and the alignment also stopped working
77 | - mel is even more coarsely quantized than linear; abstracting `f0` and `dyn` from mel alone seems unreasonable
78 | - To extract duration info from the soft alignment map, so that we can further train a non-autoregressive FastSpeech2 to speed up inference (see the sketch after this list)
79 | - just like [FastSpeech](https://arxiv.org/abs/1905.09263) and [DeepSinger](https://arxiv.org/abs/2007.04590) did
80 | - but FastSpeech reported that the extracted durations are not accurate enough, and I could not fully understand and reproduce DeepSinger's DP algorithm for duration extraction
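
The extraction trick itself is simple; a minimal sketch (illustrative, not from this repo) of the hard-argmax method:

```python
# illustrative: turn a soft alignment map into per-symbol durations by hard argmax
import numpy as np

def extract_durations(align: np.ndarray) -> np.ndarray:
    # align[t, n]: attention weight of decoder frame t over text symbol n
    n_text = align.shape[1]
    winners = align.argmax(axis=1)                 # the symbol each frame attends to most
    return np.bincount(winners, minlength=n_text)  # frames per symbol == duration
```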
81 |
82 |
83 | ### RetuneGAN
84 |
85 | 
86 |
87 |
88 | #### What I actually did:
89 |
90 | - RetuneGAN :=
91 | - preprocess
92 | - extract the `reference wav` using Griffin-Lim
93 | - extract the `u/v mask` by hand-tuned zcr/c0 thresholds (for Split-G only)
94 | - generators
95 | - `UNet-G` (encoder-decoder generator): modified from RefineGAN; we use the output of Griffin-Lim as the reference wav, rather than an F0/C0-guided hand-crafted *speech template*
96 | - `Split-G` (split u/v generator): self-designed, inspired by Multi-Band MelGAN, but I found the generated quality is holy shit :(
97 | - `ResStack` borrowed from MelGAN
98 | - `ResBlock` modified from HiFiGAN
99 | - discriminators
100 | - `MSD` (multi scale discriminator): borrowed from MelGAN; I think it's good for plosive consonants
101 | - `MPD` (multi period discriminator): borrowed from HiFiGAN; I take it as a stack-up of multiple MSDs
102 | - `MTD` (multi stft discriminator): modified from UnivNet; it has two working modes depending on its input (MPSD seems better indeed ...)
103 | - `MPSD` (multi parameter spectrogram discriminator): like in UnivNet, but we let it judge both the phase part and the magnitude part
104 | - `PHD` (phase discriminator): self-designed, caring more about phase, since `l_mstft` has already regulated magnitude
105 | - the input of MPSD is `[(mag_real, phase_real), (mag_fake, phase_fake)]`, so it distinguishes real/fake stft data (see the sketch after this list)
106 | - the input of PHD is `[(mag_real, phase_real), (mag_real, phase_fake)]`, so it ONLY distinguishes real/fake phase
107 | - losses
108 | - `l_adv` (adversarial loss): modified from HiFiGAN, but relativized
109 | - `l_fm` (feature map loss): borrowed from MelGAN
110 | - `l_mstft` (multi stft loss): modified from Parallel WaveGAN, but we calculate mel_loss rather than linear_loss
111 | - `l_env` (envelope loss): borrowed from RefineGAN
112 | - `l_dyn` (dynamic loss): self-designed, inspired by `l_env`
113 | - `l_sm` (strip mirror loss): self-designed, but it might hurt audio quality :(
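
To make the two input pairings concrete, a minimal sketch (the function name and stacking layout are my assumptions; `get_stft_torch` is from `retunegan/audio.py`):

```python
# illustrative: build MPSD/PHD discriminator inputs from real/fake waveforms
import torch
from audio import get_stft_torch  # returns (magnitude, mel, phase)

def make_d_inputs(y, y_hat, n_fft=2048, win_length=1024, hop_length=240):
    mag_r, _, ph_r = get_stft_torch(y,     n_fft, win_length, hop_length)
    mag_f, _, ph_f = get_stft_torch(y_hat, n_fft, win_length, hop_length)
    # MPSD judges whole stft data: real (mag, phase) vs fake (mag, phase)
    mpsd_pair = (torch.stack([mag_r, ph_r], dim=1),   # [B, 2, F, T]
                 torch.stack([mag_f, ph_f], dim=1))
    # PHD sees the real magnitude on both sides, so only the phase can betray the fake
    phd_pair  = (torch.stack([mag_r, ph_r], dim=1),
                 torch.stack([mag_r, ph_f], dim=1))
    return mpsd_pair, phd_pair
```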
114 |
115 | Oh my dude, it's really a big stitched-together monster :(
116 |
117 | #### Notes on ideas to try, or tried and failed:
118 |
119 | - To divide the *consonant* part from the *vowel* part in the time domain, then use two generators to synthesize them separately
120 | - I found this brings breakups at the junction points, so the audio sounds noisy, and the mstft loss also becomes more unstable
121 | - To shallow-fuse the reference wav with the half-decoded mel, rather than the `encode-merge-decode` UNet architecture
122 | - I found this makes the generator lazy to learn, and it even overfits the train set
123 | - maybe because a raw waveform is far from audio semantics and close to raw representation, so an encoder is necessary to extract semantic info
124 | - NOTE: the weight of the mstft loss should NOT be too overwhelming in front of the adversarial loss (like in HiFiGAN); see the sketch after this list
125 | - the adversarial loss leads to clearer plosives (`b/p/g/k/d/t`), while the mstft loss contributes little to consonants
126 | - `hop_length` in stft is much larger than `stride` in the discriminators, so the mstft loss is usually coarser than the adversarial loss in the time domain
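
For reference, a minimal sketch of the mel-flavoured multi-stft loss mentioned above (my assumption of the exact form; the real one lives in `retunegan/models/loss.py`):

```python
# illustrative: multi-resolution stft loss computed on log-mel instead of linear magnitude
import torch
import hparam as hp
from audio import get_stft_torch

def multi_stft_mel_loss(y, y_hat, eps=1e-5):
    loss = 0.0
    for n_fft, win_length, hop_length in hp.multi_stft_params:
        _, M_r, _ = get_stft_torch(y,     n_fft, win_length, hop_length)
        _, M_f, _ = get_stft_torch(y_hat, n_fft, win_length, hop_length)
        loss = loss + torch.mean(torch.abs(M_r.clamp(min=eps).log() - M_f.clamp(min=eps).log()))
    return loss
```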
127 |
128 | ## Acknowledgements
129 |
130 | Codes referred to:
131 |
132 | - [keithito's Tacotron](https://github.com/keithito/tacotron)
133 | - [jaywalnut310's MelGAN](https://github.com/jaywalnut310/MelGAN-Pytorch)
134 | - [jik876's official HiFiGAN](https://github.com/jik876/hifi-gan)
135 | - [ming024's FastSpeech2](https://github.com/ming024/FastSpeech2/)
136 |
137 | Ideas plagiarized from:
138 |
139 | - [Tacotron: Towards End-to-End Speech Synthesis](https://arxiv.org/abs/1703.10135)
140 | - [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263)
141 | - [FastSpeech 2: Fast and High-Quality End-to-End Text to Speech](https://arxiv.org/abs/2006.04558)
142 | - [MelGAN: Generative Adversarial Networks for Conditional Waveform Synthesis](https://arxiv.org/abs/1910.06711)
143 | - [Multi-band MelGAN: Faster Waveform Generation for High-Quality Text-to-Speech](https://arxiv.org/abs/2005.05106)
144 | - [HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis](https://arxiv.org/abs/2010.05646)
145 | - [UnivNet: A Neural Vocoder with Multi-Resolution Spectrogram Discriminators for High-Fidelity Waveform Generation](https://arxiv.org/abs/2106.07889)
146 | - [RefineGAN: Universally Generating Waveform Better than Ground Truth with Highly Accurate Pitch and Intensity Responses](https://arxiv.org/abs/2111.00962)
147 |
148 | Code released under the MIT license; greatest thanks to all the authors!! :)
149 |
150 | ----
151 |
152 | by Armit
153 | 2022/02/15
154 | 2022/05/25
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/04/15
4 |
5 | import os
6 | import re
7 | import pickle
8 | from time import time
9 | from argparse import ArgumentParser
10 | from tempfile import gettempdir
11 |
12 | import numpy as np
13 | from flask import Flask, request, jsonify, send_file
14 | from requests.utils import unquote
15 | from scipy.io import wavfile
16 | from xpinyin import Pinyin
17 | from requests import session
18 |
19 |
20 | BASE_PATH = os.path.dirname(os.path.abspath(__file__))
21 | HTML_FILE = os.path.join(BASE_PATH, 'index.html')
22 |
23 | TMP_DIR = gettempdir()
24 | WAV_TMP_FILE = os.path.join(TMP_DIR, 'synth.wav')
25 | MP3_TMP_FILE = os.path.join(TMP_DIR, 'synth.mp3')
26 |
27 | REGEX_PUNCT_IGNORE = re.compile('、|:|;|“|”|‘|’')
28 | REGEX_PUNCT_BREAK = re.compile(',|。|!|?')
29 | MAX_CLAUSE_LENGTH = 20
30 | CONVERT_MP3 = False
31 |
32 | SAMPLE_RATE = 22050
33 | SYNTH_API = 'http://127.0.0.1:5105/synth_spec'
34 | VOCODER_API = 'http://127.0.0.1:5104/vocode'
35 |
36 |
37 | app = Flask(__name__)
38 | html_page = None
39 | http = session()
40 | kanji2pinyin = Pinyin()
41 |
42 |
43 | def synth_and_save_file(txt):
44 | # Text-Norm
45 | if True:
46 | s = time()
47 | print(f'text/raw: {txt!r}')
48 |
49 | kanji = REGEX_PUNCT_IGNORE.sub('', txt)
50 | kanji = REGEX_PUNCT_BREAK.sub(' ', kanji)  # NOTE: chain on `kanji`, or the ignored-punctuation pass above is discarded
51 | segs = [''] # dummy init
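# greedily pack the punctuation-split clauses into segments shorter than MAX_CLAUSE_LENGTH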
52 | for rs in [s.strip() for s in kanji.split(' ') if s.strip()]:
53 | if (not segs[-1]) or (len(rs) + len(segs[-1]) < MAX_CLAUSE_LENGTH):
54 | segs[-1] = segs[-1] + rs
55 | else: segs.append(rs)
56 | print(f'text/segs: {segs!r}')
57 | t = time()
58 | print('[TextNorm] Done in %.2fs' % (t - s))
59 |
60 | # Synth
61 | if True:
62 | s = time()
63 | spec_clips = []
64 | for seg in segs:
65 | pinyin = ' '.join(kanji2pinyin.get_pinyin(seg, tone_marks='numbers').split('-'))
66 | resp = http.post(SYNTH_API, json={'pinyin': pinyin})
67 | spec = pickle.loads(resp.content)
68 | spec_clips.append(spec)
69 | spec = np.concatenate(spec_clips)
70 | print('spec.shape:', spec.shape)
71 | t = time()
72 | print('[Synth] Done in %.2fs' % (t - s))
73 |
74 | # Vocode
75 | if True:
76 | s = time()
77 | resp = http.post(VOCODER_API, data=pickle.dumps(spec))
78 | wav = pickle.loads(resp.content)
79 | wavfile.write(WAV_TMP_FILE, SAMPLE_RATE, wav)
80 | print('wav.length:', len(wav))
81 | t = time()
82 | print('[Vocode] Done in %.2fs' % (t - s))
83 |
84 | # Compress
85 | if CONVERT_MP3:
86 | s = time()
87 | cmd = f'ffmpeg -i "{WAV_TMP_FILE}" -f mp3 -acodec libmp3lame -y "{MP3_TMP_FILE}" -loglevel quiet'
88 | r = os.system(cmd)
89 | t = time()
90 | print('[Compress] Done in %.2fs' % (t - s))
91 |
92 |
93 | @app.route('/', methods=['GET'])
94 | def root():
95 | global html_page
96 | if not html_page:
97 | with open(HTML_FILE, encoding='utf-8') as fp:
98 | html_page = fp.read()
99 | return html_page
100 |
101 |
102 | @app.route('/synth', methods=['GET'])
103 | def synth():
104 | txt = unquote(request.args.get('text')).strip()
105 | if not txt: return jsonify({'error': 'empty request'})
106 |
107 | try:
108 | synth_and_save_file(txt)
109 | if CONVERT_MP3:
110 | return send_file(MP3_TMP_FILE, mimetype='audio/mp3')
111 | else:
112 | return send_file(WAV_TMP_FILE, mimetype='audio/wav')
113 | except Exception as e:
114 | print('[Error] %r' % e)
115 | return jsonify({'error': str(e)})  # Exception objects are not JSON-serializable
116 |
117 |
118 | if __name__ == '__main__':
119 | parser = ArgumentParser()
120 | parser.add_argument('--host', type=str, default='0.0.0.0')
121 | parser.add_argument('--port', type=int, default=5103)
122 | args = parser.parse_args()
123 |
124 | app.run(host=args.host, port=args.port, debug=False)
125 |
--------------------------------------------------------------------------------
/img/draw_embd_sim.py:
--------------------------------------------------------------------------------
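# Usage: python draw_embd_sim.py <embd.npy>
#   accepts either a plain ndarray of embeddings, or a dict of arrays whose
#   keys end with 'embd' (e.g. arrays dumped from a checkpoint)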
1 | import os
2 | import sys
3 |
4 | import numpy as np
5 | import seaborn as sns
6 | import matplotlib.pyplot as plt
7 |
8 |
9 | BASE_PATH = os.path.dirname(os.path.abspath(__file__))
10 |
11 |
12 | def process(e, name):
13 | def dot_sim(x):
14 | return np.dot(x, x.T)
15 |
16 | def cosine_sim(x):
17 | n = np.linalg.norm(x, axis=-1, keepdims=True)
18 | return dot_sim(x) / (np.dot(n, n.T) + 1e-8)
19 |
20 | #s1 = dot_sim(e)
21 | #np.save(os.path.join(BASE_PATH, f'{name}_dot.npy'), s1)
22 | #sns.heatmap(s1)
23 | #plt.gca().invert_yaxis()
24 | #plt.savefig(os.path.join(BASE_PATH, f'{name}_dot.png'))
25 | #plt.clf()
26 |
27 | s2 = cosine_sim(e)
28 | #np.save(os.path.join(BASE_PATH, f'{name}_cosine.npy'), s2)
29 | sns.heatmap(s2)
30 | plt.gca().invert_yaxis()
31 | plt.savefig(os.path.join(BASE_PATH, f'{name}_cosine.png'))
32 | plt.clf()
33 |
34 |
35 | fp = sys.argv[1]
36 | fn = os.path.basename(fp)
37 | base, ext = os.path.splitext(fn)
38 |
39 | d = np.load(fp, allow_pickle=True)
40 | if isinstance(d, np.ndarray):
41 | process(d, base)
42 | elif isinstance(d, dict):
43 | for k in d.keys():
44 | if not k.endswith('embd'): continue
45 | process(d[k], f'{base}_{k}')
46 | else: raise TypeError(f'unsupported data type: {type(d)}')
47 |
--------------------------------------------------------------------------------
/img/gen_hfg.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/gen_hfg.wav
--------------------------------------------------------------------------------
/img/gen_rtg.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/gen_rtg.wav
--------------------------------------------------------------------------------
/img/gen_tts.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/gen_tts.wav
--------------------------------------------------------------------------------
/img/gt_hfg.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/gt_hfg.wav
--------------------------------------------------------------------------------
/img/gt_rtg.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/gt_rtg.wav
--------------------------------------------------------------------------------
/img/gt_tts.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/gt_tts.wav
--------------------------------------------------------------------------------
/img/hifi-gan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/hifi-gan.png
--------------------------------------------------------------------------------
/img/rtg_D.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/rtg_D.png
--------------------------------------------------------------------------------
/img/rtg_G.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/rtg_G.png
--------------------------------------------------------------------------------
/img/rtg_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/rtg_arch.png
--------------------------------------------------------------------------------
/img/rtg_cmp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/rtg_cmp.png
--------------------------------------------------------------------------------
/img/tacotron-cn.mini.step-48000-embed_cosine.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tacotron-cn.mini.step-48000-embed_cosine.png
--------------------------------------------------------------------------------
/img/tacotron-cn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tacotron-cn.png
--------------------------------------------------------------------------------
/img/tacotron-en.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tacotron-en.png
--------------------------------------------------------------------------------
/img/transtacos.mini.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/transtacos.mini.png
--------------------------------------------------------------------------------
/img/transtacos.mini.step-48000_prds_embd_cosine.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/transtacos.mini.step-48000_prds_embd_cosine.png
--------------------------------------------------------------------------------
/img/transtacos.mini.step-48000_text_embd_cosine.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/transtacos.mini.step-48000_text_embd_cosine.png
--------------------------------------------------------------------------------
/img/tts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts.png
--------------------------------------------------------------------------------
/img/tts_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_arch.png
--------------------------------------------------------------------------------
/img/tts_decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_decoder.png
--------------------------------------------------------------------------------
/img/tts_embed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_embed.png
--------------------------------------------------------------------------------
/img/tts_encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_encoder.png
--------------------------------------------------------------------------------
/img/tts_out_align.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_out_align.png
--------------------------------------------------------------------------------
/img/tts_out_spec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_out_spec.png
--------------------------------------------------------------------------------
/img/tts_posnet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/tts_posnet.png
--------------------------------------------------------------------------------
/img/y_tmpl.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kahsolt/TransTacoS-RetuneGAN/1963a1ca86045aac270e68e58d4187665efd91d4/img/y_tmpl.wav
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
(file content lost in extraction: the HTML markup was stripped; only the page title "Demo" survives)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | librosa==0.8.1
2 | soundfile==0.10.3
3 | matplotlib==3.1.3
4 | numpy==1.20.2
5 | scipy==1.6.2
6 | Flask==2.0.1
7 | requests==2.25.1
8 | xpinyin==0.7.6
9 | tensorboardX==2.4.1
10 | seaborn
11 | #gast==0.2.2
12 |
--------------------------------------------------------------------------------
/retunegan/Makefile:
--------------------------------------------------------------------------------
1 | ifeq ($(shell uname -s), Linux)
2 | BASE_PATH=~/Data
3 | else
4 | BASE_PATH=D:/Desktop/Workspace/Data
5 | endif
6 |
7 | DATASET=DataBaker
8 | DATA_PATH=$(BASE_PATH)/$(DATASET).tts_processed
9 | #LOG_PATH=$(BASE_PATH)/rtg-$(DATASET).$(VER)
10 | LOG_PATH=$(BASE_PATH)/rtg-$(DATASET)
11 |
12 | .PHONY: train test server clean stat
13 |
14 | train:
15 | python train.py \
16 | --data_dp $(DATA_PATH) \
17 | --log_path $(LOG_PATH) \
18 | --epochs 3100
19 |
20 | finetune:
21 | python train.py \
22 | --finetune \
23 | --data_dp $(DATA_PATH) \
24 | --log_path $(LOG_PATH) \
25 | --epochs 3100
26 |
27 | test:
28 | python infer.py \
29 | --log_path $(LOG_PATH) \
30 | --input_path test
31 |
32 | server:
33 | python server.py \
34 | --log_path $(LOG_PATH)
35 |
36 | stat:
37 | tensorboard \
38 | --logdir $(LOG_PATH) \
39 | --port 5101
40 |
41 | clean:
42 | rm -rf $(LOG_PATH)
43 |
--------------------------------------------------------------------------------
/retunegan/audio.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/01/07
4 |
5 | import torch
6 | from torch.nn import AvgPool2d
7 | import numpy as np
8 | import numpy.random as R
9 | import librosa as L
10 | from scipy.io import wavfile
11 |
12 | import seaborn as sns
13 | import matplotlib.pyplot as plt
14 |
15 | import hparam as hp
16 | R.seed(hp.randseed)
17 |
18 |
19 | eps = 1e-5
20 | mel_basis = L.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.n_mel, fmin=hp.fmin, fmax=hp.fmax)
21 | mag_to_mel = lambda x: np.dot(mel_basis, x)
22 |
23 | avg_pool_2d = AvgPool2d(kernel_size=3, stride=1, padding=1)
24 |
25 | mel_basis_torch = { } # { n_fft: mel_basis }
26 | window_fn_torch = { } # { win_length: window_fn }
27 |
28 |
29 | def load_wav(path): # float values in range (-1,1)
30 | y, _ = L.load(path, sr=hp.sample_rate, mono=True, res_type='kaiser_best')
31 | return y.astype(np.float32) # [T,]
32 |
33 |
34 | def save_wav(wav, path):
35 | wavfile.write(path, hp.sample_rate, wav)
36 |
37 |
38 | def align_wav(wav, r=hp.hop_length):
39 | d = len(wav) % r
40 | if d != 0:
41 | wav = np.pad(wav, (0, (r - d)))
42 | return wav
43 |
44 |
45 | def augment_wav(y, pitch_shift=True, time_stretch=True, dynamic_scale=True):
46 | if pitch_shift:
47 | # 75% unmodified, 25% shifted
48 | if R.random() > 0.75:
49 | # ~10% unmodified, ~30% in (-1,+1), ~47% in (-2,+2), ~74% in (-4,+4)
50 | semitone = max(min(round(R.normal(scale=12/3)), 12), -12) # 3-sigma principle:99.74% in [mu-3*sigma,mu+3*sigma]
51 | if semitone != 0: y = L.effects.pitch_shift(y, hp.sample_rate, semitone, res_type='kaiser_best')
52 |
53 | if time_stretch:
54 | # 90% unmodified, 10% stretched; because `time_stretch` hurts quality a lot
55 | if R.random() > 0.90:
56 | alpha = 2 ** R.normal(scale=1/5)
57 | if abs(alpha - 1.0) < 0.1: alpha = 1.0
58 | if alpha != 1.0: y = L.effects.time_stretch(y=y, rate=alpha, win_length=hp.win_length, hop_length=hp.hop_length)
59 |
60 | if dynamic_scale:
61 | # 25% unmodified, 75% global shift
62 | r = R.random()
63 | if r > 0.25:
64 | alpha = 2 ** R.normal(scale=1/3)
65 | y = y * alpha
66 | absmax = max(y.max(), -y.min())
67 | if absmax > 1.0: y /= absmax
68 |
69 | return y.astype(np.float32) # [T,]
70 |
71 |
72 | def augment_spec(S, time_mask=True, freq_mask=True, prob=0.2, rounds=3, freq_width=9, time_width=3):
73 | F, T = S.shape
74 | S = torch.from_numpy(S).unsqueeze(0)
75 |
76 | # local mask
77 | # 10.7% unmodified, 57.0% masked <=2 times, 86.0% masked <=4 times
78 | for _ in range(rounds):
79 | if freq_mask and R.random() < prob:
80 | s = R.randint(0, F - freq_width)
81 | r = R.randint(1, freq_width)
82 | mask_val = R.uniform(low=S.min(), high=S.mean())
83 | S[:, s:s+r, :] = torch.ones([1, r, T]) * mask_val
84 |
85 | if time_mask and R.random() < prob:
86 | s = R.randint(0, T - time_width)
87 | r = R.randint(1, time_width)
88 | mask_val = R.uniform(low=S.min(), high=S.mean())
89 | S[:, :, s:s+r] = torch.ones([1, F, r]) * mask_val
90 |
91 | # global blur
92 | S = avg_pool_2d(S)
93 |
94 | S = S.squeeze(0).numpy()
95 | return S.astype(np.float32)
96 |
97 |
98 | def get_zcr(y):
99 | zcr = L.feature.zero_crossing_rate(y, frame_length=hp.win_length, hop_length=hp.hop_length)[0]
100 | return zcr.astype(np.float32) # [T,]
101 |
102 |
103 | def get_c0(y):
104 | c0 = L.feature.rms(y=y, frame_length=hp.win_length, hop_length=hp.hop_length)[0]
105 | return c0.astype(np.float32) # [T,]
106 |
107 |
108 | def get_uv(zcr, dyn):
109 | uv = np.empty_like(zcr)
110 | for i in range(len(uv)):
111 | # NOTE: these numbers are magic, tune by hand according to your dataset
112 | uv[i] = zcr[i] > 0.18 or dyn[i] < 0.03
113 | return uv
114 |
115 |
116 | def get_mag(y, clamp_low=True):
117 | D = L.stft(y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
118 | S = np.abs(D)
119 | mag = np.log(S.clip(min=eps) if clamp_low else S)
120 | return mag.astype(np.float32) # [F, T]
121 |
122 |
123 | def get_mel(y, clamp_low=True):
124 | M = L.feature.melspectrogram(y=y, sr=hp.sample_rate, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length,
125 | n_mels=hp.n_mel, fmin=hp.fmin, fmax=hp.fmax,
126 | window=hp.window_fn, power=1, htk=hp.mel_scale=='htk')
127 | mel = np.log(M.clip(min=eps) if clamp_low else M)
128 | return mel.astype(np.float32) # [M, T]
129 |
130 |
131 | def _griffinlim(S, wavlen=None):
132 | if hp.gl_power: S = S ** hp.gl_power
133 | y = L.griffinlim(S, n_iter=hp.gl_iters,
134 | hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn,
135 | length=wavlen, momentum=hp.gl_momentum, init='random', random_state=hp.randseed)
136 | return y.astype(np.float32)
137 |
138 |
139 | def inv_mag(mag, wavlen=None):
140 | S = np.exp(mag) # [F/F-1, T], reverse np.log
141 | F, T = mag.shape
142 | if F == hp.n_freq - 1: # NOTE: preprend zero DC component
143 | S = np.concatenate([np.zeros([1, T]), S], axis=0)
144 | #print('S.min():', S.min(), 'S.max(): ', S.max(), 'S.mean(): ', S.mean())
145 | y = _griffinlim(S, wavlen)
146 | if wavlen: assert len(y) == wavlen
147 | return y
148 |
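# round-trip sketch (illustrative): rebuild a rough reference wav from a log-magnitude
# spectrogram, as data.py does when preparing `wav_tmpl`:
#   wav      = load_wav('foo.wav')               # 'foo.wav' is a placeholder path
#   mag      = get_mag(wav[:-1])                 # [F, T] log-magnitude
#   wav_tmpl = inv_mag(mag, wavlen=len(wav)-1)   # Griffin-Lim reconstruction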
149 |
150 | def get_stft_torch(y, n_fft, win_length, hop_length):
151 | ''' returns the raw stft/mel values, without low-clipping or log compression '''
152 |
153 | global mel_basis_torch, window_fn_torch
154 | if win_length not in window_fn_torch:
155 | win_functor = getattr(torch, f'{hp.window_fn}_window')
156 | window_fn_torch[win_length] = win_functor(win_length).to(y.device) # [n_fft]
157 | if n_fft not in mel_basis_torch:
158 | mel_filter = L.filters.mel(hp.sample_rate, n_fft, hp.n_mel, hp.fmin, hp.fmax)
159 | mel_basis_torch[n_fft] = torch.from_numpy(mel_filter).float().to(y.device) # [n_mel, n_fft//2+1]
160 |
161 | D = torch.stft(y, n_fft, return_complex=True, # [n_fft/2+1, n_frames], complex-valued since return_complex=True
162 | hop_length=hop_length, win_length=win_length, window=window_fn_torch[win_length],
163 | center=True, pad_mode='reflect', normalized=False, onesided=True)
164 |
165 | #S = torch.sqrt(D.pow(2).sum(-1) + (1e-9)) # [n_fft/2+1, n_frames], get modulo, aka. magnitude
166 | S = torch.abs(D + 1e-9)
167 | M = torch.matmul(mel_basis_torch[n_fft], S) # [n_mel, n_frames]
168 | P = torch.angle(D)
169 |
170 | return S, M, P
171 |
--------------------------------------------------------------------------------
/retunegan/audio_proxy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/04/16
4 |
5 | # NOTE: proxy by TransTacoS for finetune
6 |
7 | import os
8 | import sys
9 | from importlib import import_module
10 | BASE_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11 | sys.path.append(BASE_PATH)
12 | print('TRANSTACOS_PATH:', os.path.join(BASE_PATH, 'transtacos'))
13 | AP = import_module('transtacos.audio')
14 |
--------------------------------------------------------------------------------
/retunegan/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from random import randint
3 |
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | import seaborn as sns
7 | import matplotlib.pyplot as plt
8 |
9 | import hparam as hp
10 | import audio as A
11 |
12 | # for finetuning with TransTacoS
13 | from audio_proxy import AP
14 |
15 |
16 | assert hp.segment_size % hp.hop_length == 0
17 | frames_per_seg = hp.segment_size // hp.hop_length
18 |
19 |
20 | class Dataset(Dataset):
21 |
22 | def __init__(self, name, data_dp, finetune=False, limit=None):
23 | self.is_train = name == 'train'
24 | self.data_dp = data_dp # the preprocessed folder containing index files and mel/mag features (if finetune)
25 | self.finetune = finetune
26 |
27 | with open(os.path.join(data_dp, 'wav_path.txt')) as fh:
28 | wav_path = fh.read().strip()
29 | with open(os.path.join(data_dp, f'{name}.txt'), encoding='utf-8') as fh:
30 | self.wav_fps = [os.path.join(wav_path, line.split('|')[0] + '.wav') for line in fh.readlines() if line]
31 | if limit: self.wav_fps = self.wav_fps[:limit]
32 |
33 | self.data = [None] * len(self.wav_fps)
34 |
35 | def __len__(self):
36 | return len(self.wav_fps)
37 |
38 | def __getitem__(self, index):
39 | # repreprocess & cache
40 | if self.data[index] is None:
41 | ''' prepare GT wav '''
42 | wav_fp = self.wav_fps[index]
43 | if not self.finetune:
44 | wav = A.load_wav(wav_fp)
45 | if self.is_train:
46 | # aug data once and freeze
47 | wav = A.augment_wav(wav)
48 | wav = A.align_wav(wav)
49 | else:
50 | # keep identical to `preprocessor.make_metadata()` of TransTacoS
51 | wav = AP.load_wav(wav_fp)
52 | wav = AP.trim_silence(wav)
53 | wav = AP.align_wav(wav)
54 |
55 | wavlen = len(wav)
56 |
57 | ''' prepare GT mel '''
58 | if not self.finetune:
59 | # `[:-1]` to avoid an extra trailing frame
60 | mag = A.get_mag(wav[:-1]) # [M, T]
61 | else:
62 | # keep identical to preprocessors of TransTacoS
63 | name = os.path.splitext(os.path.basename(wav_fp))[0]
64 | mag = np.load(os.path.join(self.data_dp, f'mag-{name}.npy')) # [M, T]
65 | mag = AP.spec_to_natural_scale(mag)
66 |
67 | mel = A.mag_to_mel(mag)
68 |
69 | if self.is_train:
70 | # aug data once and freeze
71 | mel_aug = A.augment_spec(mel, rounds=5)
72 | mel = mel / 2 + mel_aug / 2
73 |
74 | ''' prepare ref wav '''
75 | try:
76 | wav_tmpl = A.inv_mag(mag, wavlen=wavlen-1) # `wavlen-1` to avoid an extra trailing frame
77 | wav_tmpl = np.pad(wav_tmpl, (0, 1)) # pad to align
78 | except:
79 | breakpoint()
80 |
81 | # dy: in theory the first-order difference can roughly reveal the pulse positions, but wav_tmpl's phase may be a bit off
82 | if hp.ref_wav == 'dy':
83 | wav_tmpl = np.pad(wav_tmpl, (0, 1))
84 | wav_tmpl = np.asarray([b-a for a, b in zip(wav_tmpl[:-1], wav_tmpl[1:])])
85 |
86 | ''' prepare u/v mask '''
87 | if hp.split_cv: # the time-domain method has rather large error
88 | zcr = A.get_zcr(wav_tmpl[:-1])
89 | dyn = A.get_c0 (wav_tmpl[:-1])
90 | uv = A.get_uv (zcr, dyn)
91 |
92 | ''' prepare u/v-splitted mel & ref wav '''
93 | if hp.split_cv:
94 | uv_ex = np.repeat(uv, hp.hop_length)
95 | wav_tmpl_c = wav_tmpl * uv_ex
96 | wav_tmpl_v = wav_tmpl * (1 - uv_ex)
97 | mel_min = mel.min()
98 | mel_shift = mel - mel_min # assure > 0 for mask product
99 | mel_c = mel_shift * uv + mel_min
100 | mel_v = mel_shift * (1-uv) + mel_min
101 |
102 | if not 'check':
103 | if hp.split_cv:
104 | wav_c = wav * uv_ex
105 | wav_v = wav * (1 - uv_ex)
106 | plt.subplot(411); plt.plot(wav_c, 'r') ; plt.plot(wav_v, 'b')
107 | plt.subplot(412); plt.plot(wav_tmpl_c, 'r') ; plt.plot(wav_tmpl_v, 'b')
108 | plt.subplot(413); sns.heatmap(mel_c, cbar=False) ; plt.gca().invert_yaxis()
109 | plt.subplot(414); sns.heatmap(mel_v, cbar=False) ; plt.gca().invert_yaxis()
110 | plt.show()
111 | else:
112 | plt.subplot(411); plt.plot(wav, 'b')
113 | plt.subplot(412); sns.heatmap(mag, cbar=False) ; plt.gca().invert_yaxis()
114 | plt.subplot(413); plt.plot(wav_tmpl, 'r')
115 | plt.subplot(414); sns.heatmap(mel, cbar=False) ; plt.gca().invert_yaxis()
116 | plt.show()
117 |
118 | ''' check shape aligns '''
119 | if hp.split_cv: assert len(dyn) == len(zcr) == mel.shape[1]
120 | assert len(wav) == len(wav_tmpl) == mel.shape[1] * hp.hop_length
121 |
122 | ''' done '''
123 | if hp.split_cv:
124 | self.data[index] = (mel, wav, mel_c, mel_v, wav_tmpl_c, wav_tmpl_v, uv_ex)
125 | else:
126 | self.data[index] = (mel, wav, wav_tmpl)
127 |
128 | # get from cache (full length data)
129 | if hp.split_cv:
130 | mel, wav, mel_c, mel_v, wav_tmpl_c, wav_tmpl_v, uv_ex = self.data[index]
131 | else:
132 | mel, wav, wav_tmpl = self.data[index]
133 |
134 | # make slices during training: wav[S=8192] <=> mel[T=32]
135 | if self.is_train:
136 | wavlen, mellen = len(wav), mel.shape[1]
137 | if wavlen > hp.segment_size:
138 | cp = randint(0, mellen - frames_per_seg - 1)
139 | if hp.split_cv:
140 | mel_c = mel_c [:, cp : cp + frames_per_seg]
141 | mel_v = mel_v [:, cp : cp + frames_per_seg]
142 | wav_tmpl_c = wav_tmpl_c[cp * hp.hop_length : (cp + frames_per_seg) * hp.hop_length]
143 | wav_tmpl_v = wav_tmpl_v[cp * hp.hop_length : (cp + frames_per_seg) * hp.hop_length]
144 | wav = wav [cp * hp.hop_length : (cp + frames_per_seg) * hp.hop_length]
145 | uv_ex = uv_ex [cp * hp.hop_length : (cp + frames_per_seg) * hp.hop_length]
146 | else:
147 | mel = mel [:, cp : cp + frames_per_seg]
148 | wav = wav [cp * hp.hop_length : (cp + frames_per_seg) * hp.hop_length]
149 | wav_tmpl = wav_tmpl [cp * hp.hop_length : (cp + frames_per_seg) * hp.hop_length]
150 | else:
151 | if hp.split_cv:
152 | mel_c = np.pad(mel_c, ((0, 0), (0, frames_per_seg - mellen)), constant_values=mel.min()) # pad time axis with the floor value
153 | mel_v = np.pad(mel_v, ((0, 0), (0, frames_per_seg - mellen)), constant_values=mel.min())
154 | wav_tmpl_c = np.pad(wav_tmpl_c, (0, hp.segment_size - wavlen))
155 | wav_tmpl_v = np.pad(wav_tmpl_v, (0, hp.segment_size - wavlen))
156 | wav = np.pad(wav, (0, hp.segment_size - wavlen))
157 | uv_ex = np.pad(uv_ex, (0, hp.segment_size - wavlen))
158 | else:
159 | mel = np.pad(mel, ((0, 0), (0, frames_per_seg - mellen)), constant_values=mel.min()) # pad time axis with the floor value
160 | wav = np.pad(wav, (0, hp.segment_size - wavlen))
161 | wav_tmpl = np.pad(wav_tmpl, (0, hp.segment_size - wavlen))
162 |
163 | # mel: filtered from the externally-supplied mag; the main input at inference time
164 | # wav_tmpl: rough waveform recovered from the external mag by a classical algorithm; the reference input at inference time
165 | # wav: the ground-truth recorded waveform; the target output during training
166 | # uv_ex: unvoiced mask; an optional reference input at inference time
167 | if hp.split_cv:
168 | ret = mel_c, mel_v, wav_tmpl_c, wav_tmpl_v, wav, uv_ex
169 | else:
170 | ret = mel, wav_tmpl, wav
171 |
172 | return [x.astype(np.float32) for x in ret]
173 |
--------------------------------------------------------------------------------
/retunegan/hparam.py:
--------------------------------------------------------------------------------
1 | '''Audio: proxied by transtacos, please keep in sync'''
2 | # Audio
3 | sample_rate = 22050 # sample rate (Hz) of wav file
4 | n_fft = 2048
5 | win_length = 1024 # :=n_fft//2
6 | hop_length = 256 # :=win_length//4, 11.6ms; one pinyin phone spans ~9 frames on average (min 2 ~ max 20)
7 | n_mel = 80 # number of mel bands (default: 160), 120 should be more reasonable though
8 | n_freq = 1025 # number of linear spectrogram bins :=n_fft//2+1
9 | preemphasis = 0.97 # boost high frequencies to balance the EQ
10 | ref_level_db = 20 # highest spectral magnitude considered (virtual 0dB); theoretically 94 in a quiet environment, but the noisier the recording, the smaller this should be (default: 20)
11 | min_level_db = -100 # lowest spectral magnitude considered, used for dynamic-range clipping (default: -100)
12 | max_abs_value = 4 # normalize spectral magnitudes into [-max_abs_value, max_abs_value]
13 | trim_below_peak_db = 35 # trim beginning/ending silence parts (default: 60)
14 | fmin = 125 # mel filterbank frequency bounds (set 55/3600 for male voices)
15 | fmax = 7600
16 | rf0min = 'D2' # f0 detection bounds
17 | rf0max = 'D5'
18 |
19 | ## see `Databaker stats` or `stats.txt` in preprocessed folder
20 | c0min = 4.6309418394230306e-05
21 | c0max = 0.3751049339771271
22 | f0min = 73.25581359863281
23 | f0max = 595.9459228515625
24 | n_tone = 5+1
25 | n_prds = 5+1
26 | n_c0_bins = 32
27 | n_f0_bins = None # keep None for auto detect using f0min & f0max
28 | n_f0_min = None # as offset
29 | maxlen_text = 128 # for pos_enc, 27 in train set
30 | maxlen_spec = 1024 # for pos_enc, 524 in train set
31 |
32 | ####################################################################################################################
33 |
34 | '''Audio'''
35 | segment_size = 8192
36 | window_fn = 'hann' # ['bartlett', 'blackman', 'hamming', 'hann', 'kaiser']
37 | mel_scale = 'slaney' # 'htk' is smooth, 'slaney' breaks at f=1000; see `https://blog.csdn.net/qq_39995815/article/details/116269040`
38 | gl_iters = 4
39 | gl_momentum = 0.7 # 0.99 may deadloop
40 | gl_power = 1.2 # power magnitudes before Griffin-Lim
41 | ref_wav = 'y' # ['y', 'dy']
42 |
43 |
44 | '''Model'''
45 | # RetuneCNN ( | | mstft= at 30 epoch)
46 | # HiFiGAN_mini ( | | mstft= at 30 epoch)
47 | # HiFiGAN_micro ( | | mstft= at 30 epoch)
48 | # HiFiGAN_mu ( | | mstft= at 30 epoch)
49 | # RefineGAN ( | | mstft= at 30 epoch)
50 | # RefineGAN_small (2748371 | | mstft= at 30 epoch)
51 | # MelGAN (4524290 | 2.36 s/b | mstft=10.084 at 30 epoch)
52 | # MelGANRetune (1409427 | 2.42 s/b | mstft= 7.000 at 30 epoch)
53 | # MelGANSplit
54 | # HiFiGAN (1421314 | 2.30 s/b | mstft=10.346 at 30 epoch)
55 | # HiFiGANRetune (1716627 | 2.45 s/b | mstft= 7.041 at 30 epoch) # 高频撕裂
56 | # HiFiGANSplit (2849890 | 2.49 s/b | mstft=11.320 at 30 epoch)
57 |
58 | # generator
59 | generator_ver = 'RefineGAN_small'
60 | split_cv = generator_ver.endswith('Split')
61 | upsample_rates = [8, 8, 4]
62 | upsample_kernel_sizes = [15, 15, 7]
63 | upsample_initial_channel = 256
64 | resblock_kernel_sizes = [3, 5, 7]
65 | resblock_dilation_sizes = [[1, 2], [2, 6], [3, 12]]
66 | #resblock_kernel_sizes = [3, 7, 11]
67 | #resblock_dilation_sizes = [[2, 3], [3, 5], [5, 11]]
68 |
69 | # discriminator
70 | msd_layers = 3
71 | mpd_periods = [3, 5, 7, 11]
72 | multi_stft_params = [
73 | # (n_fft, win_length, hop_length)
74 | # (2048, 1024, 256),
75 | # (1024, 512, 128),
76 | # ( 512, 256, 64),
77 | # by UnivNet
78 | (2048, 1024, 240),
79 | (1024, 512, 120),
80 | ( 512, 256, 60),
81 | ]
82 | phd_layers = len(multi_stft_params)
83 | phd_input = 'stft' # ['phase', 'stft']; 'phase' doesn't seem to work well
84 |
85 | # loss
86 | relative_gan_loss = False
87 | strip_mirror_loss = False
88 | dynamic_loss = True
89 | envelope_loss = False
90 | envelope_pool_k = 160 # use `tools/test_envolope.py` to pick a proper value
91 | downsample_pool_k = 4
92 |
93 |
94 | '''Misc'''
95 | from sys import platform
96 | debug = platform == 'win32' # debug locally on Windows, actually train on Linux
97 | randseed = 114514
98 |
99 |
100 | '''Training'''
101 | num_workers = 1 if debug else 4
102 | batch_size = 4 if debug else 16 # 16 = 567.6 steps per epoch
103 | learning_rate_d = 2e-4
104 | learning_rate_g = 1.8e-4
105 | d_train_times = 2 # how many D updates per G update
106 | adam_b1 = 0.8
107 | adam_b2 = 0.99
108 | lr_decay = 0.999
109 |
110 | w_loss_fm = 2
111 | w_loss_mstft = 8
112 | w_loss_env = 4
113 | w_loss_dyn = 4
114 | w_loss_sm = 0.01
115 |
116 |
117 | '''Eval'''
118 | valid_limit = batch_size * 4
119 |
--------------------------------------------------------------------------------
/retunegan/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import importlib
4 | from argparse import ArgumentParser
5 |
6 |
7 | import torch
8 | import numpy as np
9 |
10 | from models import Generator
11 | from audio import load_wav, save_wav, get_mel, get_mag, inv_mag  # local module, not `retunegan.audio`
12 | from utils import load_checkpoint, scan_checkpoint
13 |
14 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
15 | generator = None
16 |
17 |
18 | def load_generator(a):
19 | global generator
20 | if not generator:
21 | Generator = globals().get(f'Generator_{h.generator_ver}')
22 | generator = Generator().to(device)
23 | state_dict_g = load_checkpoint(scan_checkpoint(a.log_path, 'g_'), device)
24 | generator.load_state_dict(state_dict_g['generator'])
25 | generator.eval()
26 | generator.remove_weight_norm()
27 | return generator
28 |
29 |
30 | def inference(a, x, wav_ref):
31 | generator = load_generator(a)
32 | with torch.no_grad():
33 | y_g_hat = generator(x, wav_ref)
34 | wav = y_g_hat.squeeze()
35 | wav = wav.cpu().numpy().astype(np.float32)
36 | return wav
37 |
38 |
39 | def inference_from_mag(a, fp):
40 | x = np.load(fp)
41 | x = torch.from_numpy(x).to(device)
42 | if x.size(1) == h.n_freq: x = x.T
43 |
44 | if len(x.shape) < 3: x = x.unsqueeze(0) # set batch_size=1
45 | y = inv_mag(x)
46 | wav = inference(a, x, y)
47 |
48 | wav_fp = os.path.join(a.output_dir, os.path.splitext(os.path.basename(fp))[0] + '_gen_from_mag.wav')
49 | save_wav(wav, wav_fp)
50 | print(f' Done {wav_fp!r}')
51 |
52 |
53 | def inference_from_wav(a, fp):
54 | wav = load_wav(fp)
55 | wav = torch.from_numpy(wav).to(device)
56 | x = get_mag(wav)
57 |
58 | if len(x.shape) < 3: x = x.unsqueeze(0) # set batch_size=1
59 | y = inv_mag(x)
60 | wav = inference(a, x, y)
61 |
62 | wav_fp = os.path.join(a.input_path, os.path.splitext(os.path.basename(fp))[0] + '_gen_from_wav.wav')
63 | save_wav(wav, wav_fp)
64 | print(f' Done {wav_fp!r}')
65 |
66 |
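# Usage: python infer.py --log_path <checkpoint dir> [--input_path test]
#   vocodes every .npy magnitude spectrogram and every .wav file found under --input_path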
67 | if __name__ == '__main__':
68 | parser = ArgumentParser()
69 | parser.add_argument('--input_path', default='test')
70 | parser.add_argument('--log_path', required=True)
71 | a = parser.parse_args()
72 |
73 | # load frozen hparam
74 | sys.path.insert(0, a.log_path)
75 | h = importlib.import_module('hparam')
76 | torch.manual_seed(h.randseed)
77 | torch.cuda.manual_seed(h.randseed)
78 |
79 | print('Initializing Inference Process..')
80 | fps = [os.path.join(a.input_path, fn) for fn in os.listdir(a.input_path)]
81 | for fp in [fp for fp in fps if fp.lower().endswith('.npy')]:
82 | inference_from_mag(a, fp)
83 | for fp in [fp for fp in fps if fp.lower().endswith('.wav')]:
84 | inference_from_wav(a, fp)
85 |
--------------------------------------------------------------------------------
/retunegan/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .generator import *
2 | from .discrminator import *
3 | from .loss import *
4 |
--------------------------------------------------------------------------------
/retunegan/models/discrminator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/04/16
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch.nn import Conv1d, Conv2d
9 | from torch.nn.utils import weight_norm, spectral_norm
10 |
11 | import hparam as hp
12 | from utils import *
13 |
14 | PI = 3.14159265358979
15 |
16 |
17 | class DiscriminatorS(nn.Module):
18 |
19 | def __init__(self, use_sn=False):
20 | super().__init__()
21 |
22 | #norm_f = spectral_norm if use_sn else weight_norm
23 |
24 | # [8192]; downsamples by 4*4*4*4=256x; across the multi-scale input versions this means [256,512,1024]x overall
25 | sel = 'MelGAN_small'
26 | if sel == 'MelGAN':
27 | self.convs = nn.ModuleList([
28 | weight_norm(Conv1d( 1, 16, 15, 1, padding=7)),
29 | weight_norm(Conv1d( 16, 64, 41, 4, padding=20, groups=4 )),
30 | weight_norm(Conv1d( 64, 256, 41, 4, padding=20, groups=16)),
31 | weight_norm(Conv1d( 256, 1024, 41, 4, padding=20, groups=64)),
32 | weight_norm(Conv1d(1024, 1024, 41, 4, padding=20, groups=256)),
33 | weight_norm(Conv1d(1024, 1024, 5, 1, padding=2)),
34 | ])
35 | self.conv_post = weight_norm(Conv1d(1024, 1, 3, 1, padding=1))
36 | elif sel == 'MelGAN_small':
37 | self.convs = nn.ModuleList([
38 | weight_norm(Conv1d( 1, 32, 15, 1, padding=7)),
39 | weight_norm(Conv1d( 32, 64, 41, 2, padding=20, groups=4 )),
40 | weight_norm(Conv1d( 64, 128, 41, 2, padding=20, groups=8 )),
41 | weight_norm(Conv1d(128, 512, 41, 4, padding=20, groups=32)),
42 | weight_norm(Conv1d(512, 512, 41, 4, padding=20, groups=64)),
43 | weight_norm(Conv1d(512, 512, 5, 1, padding=2)),
44 | ])
45 | self.conv_post = weight_norm(Conv1d(512, 1, 3, 1, padding=1))
46 | elif sel == 'HiFiGAN':
47 | self.convs = nn.ModuleList([
48 | weight_norm(Conv1d( 1, 128, 15, 1, padding=7)),
49 | weight_norm(Conv1d( 128, 128, 41, 2, padding=20, groups=4 )),
50 | weight_norm(Conv1d( 128, 256, 41, 2, padding=20, groups=16)),
51 | weight_norm(Conv1d( 256, 512, 41, 4, padding=20, groups=16)),
52 | weight_norm(Conv1d( 512, 1024, 41, 4, padding=20, groups=16)),
53 | weight_norm(Conv1d(1024, 1024, 41, 1, padding=20, groups=16)),
54 | weight_norm(Conv1d(1024, 1024, 5, 1, padding=2)),
55 | ])
56 | self.conv_post = weight_norm(Conv1d(1024, 1, 3, 1, padding=1))
57 |
58 | def forward(self, x):
59 | # torch.Size([16, 1, 8192])
60 | # torch.Size([16, 32, 8192])
61 | # torch.Size([16, 64, 4096])
62 | # torch.Size([16, 128, 2048])
63 | # torch.Size([16, 512, 512])
64 | # torch.Size([16, 512, 128])
65 | # torch.Size([16, 512, 128])
66 | # torch.Size([16, 1, 128])
67 | #
68 | # torch.Size([16, 1, 4096])
69 | # torch.Size([16, 32, 4096])
70 | # torch.Size([16, 64, 2048])
71 | # torch.Size([16, 128, 1024])
72 | # torch.Size([16, 512, 256])
73 | # torch.Size([16, 512, 64])
74 | # torch.Size([16, 512, 64])
75 | # torch.Size([16, 1, 64])
76 | #
77 | # torch.Size([16, 1, 2048])
78 | # torch.Size([16, 32, 2048])
79 | # torch.Size([16, 64, 1024])
80 | # torch.Size([16, 128, 512])
81 | # torch.Size([16, 512, 128])
82 | # torch.Size([16, 512, 32])
83 | # torch.Size([16, 512, 32])
84 | # torch.Size([16, 1, 32])
85 |
86 | DEBUG = 0
87 | if DEBUG: print('[DiscriminatorS]')
88 |
89 | fmap = []
90 | if DEBUG: print(x.shape)
91 |
92 | for i, l in enumerate(self.convs):
93 | x = l(x)
94 | if DEBUG: print(x.shape)
95 | fmap.append(x)
96 | x = F.leaky_relu(x, LRELU_SLOPE)
97 | x = self.conv_post(x)
98 | if DEBUG: print(x.shape)
99 | x = torch.flatten(x, 1, -1)
100 |
101 | return x, fmap
102 |
103 |
104 | class MultiScaleDiscriminator(nn.Module):
105 |
106 | def __init__(self):
107 | super().__init__()
108 |
109 | self.discriminators = nn.ModuleList([
110 | DiscriminatorS(use_sn=i==0) for i in range(hp.msd_layers)
111 | ])
112 | # do not use an audio-processing Resample here; use AvgPool to progressively smear out the high-frequency details
113 | self.avgpool = nn.AvgPool1d(kernel_size=hp.downsample_pool_k, stride=2, padding=1)
114 |
115 | def forward(self, y, y_hat):
116 | y_d_rs, y_d_gs = [], []
117 | fmap_rs, fmap_gs = [], []
118 |
119 | for i, d in enumerate(self.discriminators):
120 | y_d_r, fmap_r = d(y)
121 | y_d_g, fmap_g = d(y_hat)
122 | y_d_rs.append(y_d_r) ; fmap_rs.append(fmap_r)
123 | y_d_gs.append(y_d_g) ; fmap_gs.append(fmap_g)
124 |
125 | if i != len(self.discriminators) - 1:
126 | y = self.avgpool(y)
127 | y_hat = self.avgpool(y_hat)
128 |
129 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
130 |
131 |
132 | class DiscriminatorP(nn.Module):
133 |
134 | def __init__(self, period):
135 | super().__init__()
136 |
137 | self.period = period
138 |
139 | # [8192]; downsample only along T, by 3*3*3*3=81x
140 | # = [4096,2]
141 | # = [2731,3] (*)
142 | # = [1639,5] (*)
143 | # = [1171,7] (*)
144 | # = [745,11] (*)
145 | sel = 'HiFiGAN_small'
146 | if sel == 'HiFiGAN':
147 | self.convs = nn.ModuleList([
148 | weight_norm(Conv2d( 1, 32, (5, 1), (3, 1), padding=(2, 0))),
149 | weight_norm(Conv2d( 32, 128, (5, 1), (3, 1), padding=(2, 0))),
150 | weight_norm(Conv2d( 128, 512, (5, 1), (3, 1), padding=(2, 0))),
151 | weight_norm(Conv2d( 512, 1024, (5, 1), (3, 1), padding=(2, 0))),
152 | weight_norm(Conv2d(1024, 1024, (5, 1), 1, padding=(2, 0))),
153 | ])
154 | self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
155 | elif sel == 'HiFiGAN_small':
156 | self.convs = nn.ModuleList([
157 | weight_norm(Conv2d( 1, 32, (5, 1), (3, 1), padding=(2, 0))),
158 | weight_norm(Conv2d( 32, 128, (5, 1), (3, 1), padding=(2, 0))),
159 | weight_norm(Conv2d(128, 256, (5, 1), (3, 1), padding=(2, 0))),
160 | weight_norm(Conv2d(256, 512, (5, 1), (3, 1), padding=(2, 0))),
161 | weight_norm(Conv2d(512, 512, (5, 1), 1, padding=(2, 0))),
162 | ])
163 | self.conv_post = weight_norm(Conv2d(512, 1, (3, 1), 1, padding=(1, 0)))
164 |
165 | def forward(self, x):
166 | # torch.Size([16, 1, 2731, 3])
167 | # torch.Size([16, 32, 911, 3])
168 | # torch.Size([16, 128, 304, 3])
169 | # torch.Size([16, 256, 102, 3])
170 | # torch.Size([16, 512, 34, 3])
171 | # torch.Size([16, 512, 34, 3])
172 | # torch.Size([16, 1, 34, 3])
173 | #
174 | # torch.Size([16, 1, 1639, 5])
175 | # torch.Size([16, 32, 547, 5])
176 | # torch.Size([16, 128, 183, 5])
177 | # torch.Size([16, 256, 61, 5])
178 | # torch.Size([16, 512, 21, 5])
179 | # torch.Size([16, 512, 21, 5])
180 | # torch.Size([16, 1, 21, 5])
181 | #
182 | # torch.Size([16, 1, 1171, 7])
183 | # torch.Size([16, 32, 391, 7])
184 | # torch.Size([16, 128, 131, 7])
185 | # torch.Size([16, 256, 44, 7])
186 | # torch.Size([16, 512, 15, 7])
187 | # torch.Size([16, 512, 15, 7])
188 | # torch.Size([16, 1, 15, 7])
189 | #
190 | # torch.Size([16, 1, 745, 11])
191 | # torch.Size([16, 32, 249, 11])
192 | # torch.Size([16, 128, 83, 11])
193 | # torch.Size([16, 256, 28, 11])
194 | # torch.Size([16, 512, 10, 11])
195 | # torch.Size([16, 512, 10, 11])
196 | # torch.Size([16, 1, 10, 11])
197 |
198 | DEBUG = 0
199 | if DEBUG: print('[DiscriminatorP]')
200 |
201 | fmap = []
202 |
203 | # 1d to 2d
204 | b, c, t = x.shape
205 | if t % self.period != 0: # pad tail
206 | n_pad = self.period - (t % self.period)
207 | x = F.pad(x, (0, n_pad), "reflect")
208 | t = t + n_pad
209 | # [B, C, T', P]
210 | x = x.view(b, c, t // self.period, self.period)
211 |
212 | if DEBUG: print(x.shape)
213 | for i, l in enumerate(self.convs):
214 | x = l(x)
215 | if DEBUG: print(x.shape)
216 | fmap.append(x)
217 | x = F.leaky_relu(x, LRELU_SLOPE)
218 | x = self.conv_post(x)
219 | if DEBUG: print(x.shape)
220 | x = torch.flatten(x, 1, -1)
221 |
222 | return x, fmap
223 |
224 |
225 | class MultiPeriodDiscriminator(nn.Module):
226 |
227 | def __init__(self):
228 | super().__init__()
229 |
230 | self.discriminators = nn.ModuleList([
231 | DiscriminatorP(hp.mpd_periods[i]) for i in range(len(hp.mpd_periods))
232 | ])
233 |
234 | def forward(self, y, y_hat):
235 | y_d_rs, y_d_gs = [], []
236 | fmap_rs, fmap_gs = [], []
237 |
238 | for d in self.discriminators:
239 | y_d_r, fmap_r = d(y)
240 | y_d_g, fmap_g = d(y_hat)
241 | y_d_rs.append(y_d_r) ; fmap_rs.append(fmap_r)
242 | y_d_gs.append(y_d_g) ; fmap_gs.append(fmap_g)
243 |
244 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
245 |
246 |
247 | class StftDiscriminator(nn.Module):
248 |
249 |   def __init__(self, i, ch=2):   # i: scale index (currently unused); ch: input channels (log-mag + phase)
250 | super().__init__()
251 |
252 | # [1025, 35]
253 | # [513, 69]
254 | # [257, 137]
255 | self.convs = nn.ModuleList([
256 | weight_norm(Conv2d( ch, 32, (3, 3), (2, 1), padding=(1, 1))),
257 | weight_norm(Conv2d( 32, 64, (3, 3), (2, 2), padding=(1, 1))),
258 | weight_norm(Conv2d( 64, 256, (5, 3), (3, 2), padding=(2, 1))),
259 | weight_norm(Conv2d(256, 512, (5, 3), (3, 2), padding=(2, 1))),
260 | weight_norm(Conv2d(512, 512, 3, 1, padding=1)),
261 | ])
262 | self.conv_post = weight_norm(Conv2d(512, 1, 3, 1, padding=1))
263 |
264 | self.convs .apply(init_weights)
265 | self.conv_post.apply(init_weights)
266 |
267 | def forward(self, x):
268 | # torch.Size([16, 2, 1025, 35])
269 | # torch.Size([16, 32, 513, 35])
270 | # torch.Size([16, 64, 257, 18])
271 | # torch.Size([16, 256, 86, 9])
272 | # torch.Size([16, 512, 29, 5])
273 | # torch.Size([16, 512, 29, 5])
274 | # torch.Size([16, 1, 29, 5])
275 | #
276 | # torch.Size([16, 2, 513, 69])
277 | # torch.Size([16, 32, 257, 69])
278 | # torch.Size([16, 64, 129, 35])
279 | # torch.Size([16, 256, 43, 18])
280 | # torch.Size([16, 512, 15, 9])
281 | # torch.Size([16, 512, 15, 9])
282 | # torch.Size([16, 1, 15, 9])
283 | #
284 | # torch.Size([16, 2, 257, 137])
285 | # torch.Size([16, 32, 129, 137])
286 | # torch.Size([16, 64, 65, 69])
287 | # torch.Size([16, 256, 22, 35])
288 | # torch.Size([16, 512, 8, 18])
289 | # torch.Size([16, 512, 8, 18])
290 | # torch.Size([16, 1, 8, 18])
291 |
292 | DEBUG = 0
293 | if DEBUG: print('[StftDiscriminator]')
294 |
295 | fmap = []
296 | if DEBUG: print(x.shape)
297 |
298 | # x.shape = [B, 2, C, T]
299 | for i, l in enumerate(self.convs):
300 | x = l(x)
301 | if DEBUG: print(x.shape)
302 | fmap.append(x)
303 | x = F.leaky_relu(x, LRELU_SLOPE)
304 | x = self.conv_post(x)
305 | if DEBUG: print(x.shape)
306 | x = torch.flatten(x, 1, -1)
307 |
308 | return x, fmap
309 |
310 |
311 | class MultiStftDiscriminator(nn.Module):
312 |
313 | def __init__(self):
314 | super().__init__()
315 |
316 | self.discriminators = nn.ModuleList([
317 | StftDiscriminator(i) for i in range(len(hp.multi_stft_params))
318 | ])
319 |
320 | def forward(self, phs, ph_hats):
321 | ph_d_rs, ph_d_gs = [], []
322 | fmap_rs, fmap_gs = [], []
323 |
324 | for d, ph, ph_hat in zip(self.discriminators, phs, ph_hats):
325 | ph_d_r, fmap_r = d(ph)
326 | ph_d_g, fmap_g = d(ph_hat)
327 | ph_d_rs.append(ph_d_r) ; fmap_rs.append(fmap_r)
328 | ph_d_gs.append(ph_d_g) ; fmap_gs.append(fmap_g)
329 |
330 | return ph_d_rs, ph_d_gs, fmap_rs, fmap_gs
331 |
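332 | 
333 | # Illustrative smoke test (a sketch, not part of the training pipeline): feed
334 | # random waveforms through the multi-period discriminator and print the output
335 | # shapes. The segment length 8192 follows the shape comments above; the batch
336 | # size is arbitrary.
337 | if __name__ == '__main__':
338 |   y     = torch.randn(4, 1, 8192)
339 |   y_hat = torch.randn(4, 1, 8192)
340 |   mpd = MultiPeriodDiscriminator()
341 |   y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat)
342 |   print('scores:', [o.shape for o in y_d_rs])   # one flattened score map per period
343 |   print('fmaps: ', [len(f) for f in fmap_rs])   # one feature-map list per sub-discriminator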
--------------------------------------------------------------------------------
/retunegan/models/loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/04/16
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import seaborn as sns
10 | import matplotlib.pyplot as plt
11 |
12 | import hparam as hp
13 | from audio import get_stft_torch
14 | from utils import PI
15 |
16 |
17 | # globals
18 | MaxPool = nn.MaxPool1d(hp.envelope_pool_k)
19 |
20 |
21 | # multi-stft loss
22 | def multi_stft_loss(y, y_g, ret_loss=False, ret_specs=False):
23 | loss = 0
24 | if ret_specs: stft_r, stft_g = [], []
25 |
26 | # [B, 1, T] => [B, T]
27 | if len(y.shape) == 3:
28 | y, y_g = y.squeeze(1), y_g.squeeze(1)
29 |
30 | for n_fft, win_length, hop_length in hp.multi_stft_params:
31 |     # raw (linear-amplitude) spectral and mel values
32 |     y_mag,   y_mel,   y_phase   = get_stft_torch(y,   n_fft, win_length, hop_length)   # TODO: this term could be factored out and cached
33 | y_g_mag, y_g_mel, y_g_phase = get_stft_torch(y_g, n_fft, win_length, hop_length)
34 |
35 |     # log-scaled spectral values
36 | log_y_mel, log_y_g_mel = torch.log(y_mel), torch.log(y_g_mel)
37 | log_y_mag, log_y_g_mag = torch.log(y_mag), torch.log(y_g_mag)
38 | norm_y_phase, norm_y_g_phase = y_phase / PI, y_g_phase / PI
39 |
40 | if ret_specs:
41 |       # the discriminator judges on the linear spectrogram
42 | if hp.phd_input == 'stft':
43 | stft_r.append(torch.stack([log_y_mag, norm_y_phase], dim=1))
44 | stft_g.append(torch.stack([log_y_g_mag, norm_y_g_phase], dim=1))
45 | elif hp.phd_input == 'phase':
46 | stft_r.append(torch.stack([log_y_mag, norm_y_phase], dim=1))
47 | stft_g.append(torch.stack([log_y_mag, norm_y_g_phase], dim=1))
48 |         else: raise ValueError(f'unknown phd_input: {hp.phd_input}')
49 |
50 |     # the spectral loss compares mel in both raw and log scale
51 | loss += F.l1_loss( y_mel, y_g_mel)
52 | loss += F.l1_loss(log_y_mel, log_y_g_mel)
53 |
54 | loss /= len(hp.multi_stft_params)
55 |
56 | if ret_loss and ret_specs:
57 | return loss, (stft_r, stft_g)
58 | elif ret_loss:
59 | return loss
60 | elif ret_specs:
61 | return (stft_r, stft_g)
62 |   else: raise ValueError('at least one of ret_loss / ret_specs must be True')
63 |
64 |
65 | # envelope loss for waveform
66 | def envelope_loss(y, y_g):
67 |   # absolute dynamic envelope
68 | loss = 0
69 | loss += torch.mean(torch.abs(MaxPool( y) - MaxPool( y_g)))
70 | loss += torch.mean(torch.abs(MaxPool(-y) - MaxPool(-y_g)))
71 |
72 | return loss
73 |
74 |
75 | # dynamic loss for waveform
76 | def dynamic_loss(y, y_g):
77 |   # relative dynamic range
78 | dyn_y = torch.abs(MaxPool(y) + MaxPool(-y))
79 | dyn_y_g = torch.abs(MaxPool(y_g) + MaxPool(-y_g))
80 | loss = torch.mean(torch.abs(dyn_y - dyn_y_g))
81 |
82 | return loss
83 |
84 |
85 | # strip mirror loss for waveform
86 | def strip_mirror_loss(y):
87 |   # a regularization loss that is probably useless and may even be counterproductive
88 |
89 | # assure length is even
90 | if y.shape[-1] % 2 != 0: y = y[:,:,:-1]
91 | # strip split & de-mean
92 | even, odd = y[:,:,::2], y[:,:,1::2]
93 | even = even - even.mean()
94 | odd = odd - odd .mean()
95 | # maximize |e-o|
96 | loss = torch.mean(-torch.log(torch.clamp_max((torch.abs(even - odd) + 1e-9), max=1.0)))
97 |
98 | return loss
99 |
100 |
101 | # adversarial loss for discriminator
102 | def discriminator_loss(disc_r, disc_g):
103 | loss = 0
104 |
105 | for dr, dg in zip(disc_r, disc_g): # List[[B, T], ...]
106 | if hp.relative_gan_loss:
107 |       # maximize the gap between dg & dr
108 | #r_loss = torch.mean((1 - (dr - dg.detach().mean())) ** 2)
109 | #r_loss = torch.mean((1 - (dr - dg.detach().mean(axis=-1))) ** 2)
110 | #r_loss = torch.mean(1 - torch.tanh(dr - dg.detach()))
111 | #r_loss = torch.mean((1 - (dr - dg.detach())) ** 2)
112 | #g_loss = torch.mean((0 - dg) ** 2)
113 | r_loss = torch.mean(torch.mean((1 - (dr - dg.detach())) ** 2, axis=-1))
114 | g_loss = torch.mean(torch.mean((0 - dg) ** 2, axis=-1))
115 | #r_loss = torch.mean(-torch.log(dr))
116 | #g_loss = torch.mean(dg)
117 | else:
118 | # let dr -> 1, dg -> 0
119 | #r_loss = torch.mean((1 - dr) ** 2)
120 | #g_loss = torch.mean((0 - dg) ** 2)
121 | r_loss = torch.mean(torch.mean((1 - dr) ** 2, axis=-1))
122 | g_loss = torch.mean(torch.mean((0 - dg) ** 2, axis=-1))
123 | loss += (r_loss + g_loss)
124 |
125 | return loss
126 |
127 |
128 | # adversarial loss for generator
129 | def generator_loss(disc_g, disc_r):
130 | loss = 0
131 |
132 | for dg, dr in zip(disc_g, disc_r):
133 | if hp.relative_gan_loss:
134 | # let dg ~= dr
135 | #g_loss = torch.mean((dr.detach().mean(axis=-1) - dg) ** 2)
136 | #g_loss = torch.mean(1 - torch.tanh(dg - dr.detach()))
137 | #g_loss = torch.mean((dr.detach() - dg) ** 2)
138 | g_loss = torch.mean(torch.mean((dg - dr.detach()) ** 2, axis=-1))
139 | else:
140 | # let dg -> 1
141 | #g_loss = torch.mean((1 - dg) ** 2)
142 | g_loss = torch.mean(torch.mean((1 - dg) ** 2, axis=-1))
143 | loss += g_loss
144 |
145 | return loss
146 |
147 |
148 | # feature map loss for generator
149 | def feature_loss(fmap_r, fmap_g):
150 | loss = 0
151 |
152 | for dr, dg in zip(fmap_r, fmap_g):
153 | for r, g in zip(dr, dg):
154 | loss += F.l1_loss(r, g)
155 |
156 | return loss
157 |
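158 | 
159 | # Illustrative sketch of how the pieces above would typically be combined into
160 | # the total generator objective (kept commented out: the weights 45 and 2 are
161 | # assumptions in the spirit of HiFi-GAN, not values taken from this project's
162 | # train.py):
163 | #
164 | # def generator_objective(y, y_g, disc_g, disc_r, fmap_r, fmap_g):
165 | #   loss_spec = multi_stft_loss(y, y_g, ret_loss=True)   # spectral reconstruction
166 | #   loss_adv  = generator_loss(disc_g, disc_r)           # fool the discriminators
167 | #   loss_fm   = feature_loss(fmap_r, fmap_g)             # match intermediate features
168 | #   loss_env  = envelope_loss(y, y_g)                    # match waveform envelopes
169 | #   return 45 * loss_spec + loss_adv + 2 * loss_fm + loss_env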
--------------------------------------------------------------------------------
/retunegan/server.py:
--------------------------------------------------------------------------------
1 | import os, io
2 | import pickle
3 | from time import time
4 | from argparse import ArgumentParser
5 |
6 | import torch
7 | import numpy as np
8 | from flask import Flask, request, jsonify, send_file
9 |
10 | import hparam as h
11 | from models.generator import *
12 | from audio import mag_to_mel, inv_mag
13 | from utils import scan_checkpoint, load_checkpoint
14 |
15 |
16 | os.environ['LIBROSA_CACHE_LEVEL'] = '50'
17 |
18 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
19 | torch.backends.cudnn.enabled = True
20 | torch.backends.cudnn.benchmark = True
21 | torch. manual_seed(h.randseed)
22 | torch.cuda.manual_seed(h.randseed)
23 |
24 | torch.autograd.set_detect_anomaly(h.debug)
25 |
26 |
27 | # globals
28 | app = Flask(__name__)
29 | generator = None
30 |
31 |
32 | # vocode
33 | @app.route('/vocode', methods=['POST'])
34 | def vocode():
35 | try:
36 | # chk mag
37 | mag = pickle.loads(request.data)
38 | print(f'mag.shape: {mag.shape}, dyn_range: [{mag.min()}, {mag.max()}]')
39 |     if mag.shape[1] == h.n_freq: mag = mag.T   # ensure [F, T]
40 | # ref: preprocess in `data.Dataset.__getitem__()`
41 | mel = mag_to_mel(mag)
42 | wavlen = h.hop_length * mag.shape[1]
43 | wav_tmpl = inv_mag(mag, wavlen=wavlen-1)
44 | wav_tmpl = np.pad(wav_tmpl, (0, 1))
45 |
46 | # mel to wav
47 | s = time()
48 | with torch.no_grad():
49 | mel = torch.from_numpy(mel) .to(device, non_blocking=True).float().unsqueeze(0)
50 | wav_tmpl = torch.from_numpy(wav_tmpl).to(device, non_blocking=True).float().unsqueeze(0).unsqueeze(1)
51 | y_g_hat = generator(mel, wav_tmpl)
52 | wav = y_g_hat.squeeze()
53 | wav = wav.cpu().numpy().astype(np.float32)
54 | t = time()
55 | print(f'wav.shape: {wav.shape}, dyn_range: [{wav.min()}, {wav.max()}]')
56 | print(f'[Vocode] Done in {t - s:.2f}s')
57 |
58 | # transfer
59 | bio = io.BytesIO()
60 | bio.write(pickle.dumps(wav)) # float32 -> byte
61 | bio.seek(0) # reset fp to beginning for `send_file` to read
62 | return send_file(bio, mimetype='application/octet-stream')
63 |
64 | except Exception as e:
65 | print('[Error] %r' % e)
66 |     return jsonify({'error': repr(e)})
67 |
68 |
69 | if __name__ == '__main__':
70 | parser = ArgumentParser()
71 | parser.add_argument('--log_path', required=True)
72 | parser.add_argument('--host', type=str, default='0.0.0.0')
73 | parser.add_argument('--port', type=int, default=5104)
74 | args = parser.parse_args()
75 |
76 | # load ckpt
77 | generator = globals().get(f'Generator_{h.generator_ver}')().to(device)
78 | state_dict_g = load_checkpoint(scan_checkpoint(args.log_path, 'g_'), device)
79 | generator.load_state_dict(state_dict_g['generator'])
80 | generator.eval()
81 | generator.remove_weight_norm()
82 |
83 | app.run(host=args.host, port=args.port, debug=False)
84 |
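85 | # Example client (a sketch; the URL assumes the default host/port above, and
86 | # 'mag.npy' is a hypothetical magnitude-spectrogram file):
87 | #
88 | #   import pickle, requests
89 | #   import numpy as np
90 | #   mag = np.load('mag.npy').astype(np.float32)        # [T, F] or [F, T]
91 | #   r = requests.post('http://127.0.0.1:5104/vocode', data=pickle.dumps(mag))
92 | #   wav = pickle.loads(r.content)                      # float32 waveform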
--------------------------------------------------------------------------------
/retunegan/tools/spec2wavset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/02/14
4 |
5 | # Borrows an idea from RefineGAN:
6 | # RefineGAN: from the f0, c0 and voiced-flag predicted by the upstream model, build a waveform template --
7 | #   fill unvoiced segments with Gaussian white noise, place pulses at the f0 rate in voiced segments, then shape the overall envelope by c0
8 | # spec2wavset: the STFT decomposes the signal into a set of sinusoids at equally spaced frequencies, so we can simply sum this set of sinusoids as the waveform template (see the sketch at the end of this file)
9 | # Q: why not use the GriffinLim output as the template? The downstream network would then amount to a frequency-domain denoiser
10 | # A: denoising is hard; think of it the other way round -- a sum of sinusoids is fairly clean, so we add noise on top of it!
11 | # *NOTE: multiple sets of STFT parameters are needed to avoid losing frequencies and to mitigate spectral leakage caused by the window function
12 | # fft_params:
13 | #   n_fft  win_length  hop_length
14 | #   2048   1024        256
15 | #   1024   512         128
16 | #   512    256         64
17 | 
18 | # Extracting frequency and amplitude from an FFT: https://blog.csdn.net/taw19960426/article/details/101684663
19 | # wide-band spectrogram:   frame length  3ms, hop_length=48(16000Hz)/66(22050Hz)
20 | # narrow-band spectrogram: frame length 20ms, hop_length=320(16000Hz)/441(22050Hz)
21 | 
22 | # NOTE: statistical characteristics
23 | #   magnitude: a single peak (speech) plus a plateau on the left (noise floor)
24 | #   phase: close to a uniform distribution over [-pi, pi]
25 |
26 | import os
27 | import numpy as np
28 | import numpy.random as R
29 | import matplotlib.pyplot as plt
30 | import seaborn as sns
31 | import librosa as L
32 | from scipy.io import wavfile
33 |
34 | sample_rate = 16000
35 | n_mel = 80
36 | fmin = 70
37 | fmax = 7600
38 | fft_params = [
39 | # (n_fft, win_length, hop_length)
40 | (4096, 2048, 512),
41 | (2048, 1024, 256),
42 | (1024, 512, 128),
43 | ( 512, 256, 64),
44 | ( 256, 128, 32),
45 | ]
46 |
47 |
48 | def calc_spec(y, n_fft, hop_length, win_length, clamp_lower=False):
49 | D = L.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, window='hann')
50 | S = np.abs(D)
51 | print('S.min():', S.min())
52 | mag = np.log(S.clip(min=1e-5)) if clamp_lower else S
53 | M = L.feature.melspectrogram(S=S, sr=sample_rate, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
54 | n_mels=n_mel, fmin=fmin, fmax=fmax, window='hann', power=1, htk=True)
55 | print('M.min():', M.min())
56 | mel = np.log(M.clip(min=1e-5)) if clamp_lower else M
57 | return mag, mel
58 |
59 |
60 | def get_specs(fp):
61 | y, _ = L.load(fp, sr=sample_rate, mono=True, res_type='kaiser_best')
62 | return [calc_spec(y, n_fft, hop_length, win_length) for n_fft, win_length, hop_length in fft_params]
63 |
64 |
65 | def display_specs(specs):
66 | n_fig = len(specs)
67 | for i in range(n_fig):
68 |     if isinstance(specs[i], (list, tuple)):
69 |       plt.subplot(n_fig, 2, i*2+1)
70 |       ax = sns.heatmap(specs[i][0])
71 |       ax.invert_yaxis()
72 |       plt.subplot(n_fig, 2, i*2+2)
73 |       ax = sns.heatmap(specs[i][1])
74 |       ax.invert_yaxis()
75 |     else:
76 |       plt.subplot(n_fig, 1, i+1)
77 |       ax = sns.heatmap(specs[i])
78 |       ax.invert_yaxis()
79 | plt.show()
80 | plt.clf()
81 |
82 | def f(k, n_fft):   # frequency (Hz) of FFT bin k
83 | return k * (sample_rate / n_fft)
84 |
85 | def sin(x, A, freq, phi=0):
86 | w = 2 * np.pi * freq # f = w / (2*pi)
87 | return A * np.sin(w * x + phi)
88 |
89 |
90 | def extract_f_A(fp):
91 | for i, (mag, _) in enumerate(get_specs(fp)):
92 | n_fft, _, _ = fft_params[i]
93 | print(f'[n_fft={n_fft}]')
94 |
95 |     fr = sample_rate / n_fft   # frequency resolution (Hz per bin)
96 | print('mag min/max/mean:', mag.min(), mag.max(), mag.mean())
97 |
98 | f_A = []
99 | thresh = mag.mean() * 2
100 | for i, energy in enumerate(mag):
101 | print(f'frame {i}: ')
102 | j = 0
103 | while j < len(energy):
104 |         # threshold gating
105 |         if energy[j] <= thresh:
106 |           j += 1
107 |         else:
108 |           # skip the rising slope
109 |           while j + 1 < len(energy) and energy[j+1] >= energy[j]:
110 |             j += 1
111 |           # take the peak
112 |           f_A.append((fr * j, energy[j]))
113 |           print(f'  freq({j}) = {fr * j} Hz, amp = {energy[j]}')
114 |           # skip the falling slope
115 |           while j + 1 < len(energy) and energy[j+1] < energy[j]:
116 |             j += 1
117 |
118 |
119 | def demo_extract_f_A():
120 |   st = 1   # sampling duration: 1 s
121 | x = np.arange(0, 1, st / sample_rate)
122 | print('signal len:', len(x))
123 | y1 = sin(x, 2, 207)
124 | y2 = sin(x, -1, 843)
125 | y = y1 + y2
126 |
127 | plt.subplot(311) ; plt.plot(x, y1)
128 | plt.subplot(312) ; plt.plot(x, y2)
129 | plt.subplot(313) ; plt.plot(x, y)
130 | plt.show()
131 |
132 | n_fft = 2048
133 |   mag, _ = calc_spec(y, n_fft, hop_length=256, win_length=1024)
134 |   display_specs([mag])
135 |
136 | mag = mag.T
137 |
138 | n_fft = fft_params[1][0]
139 | A = 2 * mag / n_fft
140 |   fr = sample_rate / n_fft   # frequency resolution (Hz per bin)
141 |
142 | f_A = []
143 | thresh = 0.1
144 | for i, energy in enumerate(A):
145 | print(f'frame {i}: ')
146 | j = 0
147 | while j < len(energy):
148 |       # threshold gating
149 |       if energy[j] <= thresh:
150 |         j += 1
151 |       else:
152 |         # skip the rising slope
153 |         while j + 1 < len(energy) and energy[j+1] >= energy[j]:
154 |           j += 1
155 |         # take the peak
156 |         f_A.append((fr * j, energy[j]))
157 |         print(f'  freq({j}) = {fr * j} Hz, amp = {energy[j]}')
158 |         # skip the falling slope
159 |         while j + 1 < len(energy) and energy[j+1] < energy[j]:
160 |           j += 1
161 |
162 | return f_A
163 |
164 |
165 | def demo_stft_istft(fp):
166 | n_fft, win_length, hop_length = 2048, 1024, 256
167 |
168 | y, _ = L.load(fp, sr=sample_rate, mono=True, res_type='kaiser_best')
169 | length = len(y) # for align output
170 |
171 | stft = L.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
172 | mag, phase = np.abs(stft), np.angle(stft)
173 | y_istft = L.istft(stft, win_length=win_length, hop_length=hop_length, length=length)
174 | D = L.stft(y_istft, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
175 | mag_istft, phase_istft = np.abs(D), np.angle(D)
176 |
177 | print('mag.min():', mag.min(), 'mag.max():', mag.max())
178 | plt.subplot(311) ; plt.hist(mag.flatten(), bins=150) # [0, 10]
179 | maglog = np.log(mag)
180 | print('maglog.min():', maglog.min(), 'maglog.max():', maglog.max())
181 | plt.subplot(312) ; plt.hist(maglog.flatten(), bins=150) # [-12, 3]
182 | plt.subplot(313) ; plt.hist(phase.flatten(), bins=150) # [-pi, pi]
183 | plt.show()
184 |
185 | phase_ristft_r = np.random.uniform(low=-np.pi, high=np.pi, size=phase.shape)
186 | rstft = mag * np.exp(1j * phase_ristft_r)
187 | y_ristft = L.istft(rstft, win_length=win_length, hop_length=hop_length, length=length)
188 | D = L.stft(y_ristft, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
189 | mag_ristft, phase_ristft = np.abs(D), np.angle(D)
190 |
191 | y_gl = L.griffinlim(mag, hop_length=hop_length, win_length=win_length, length=length)
192 | D = L.stft(y_gl, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
193 | mag_gl, phase_gl = np.abs(D), np.angle(D)
194 |
195 | print('y_istft error:', np.sum(y - y_istft))
196 | print('y_ristft error:', np.sum(y - y_ristft))
197 | print('y_gl error:', np.sum(y - y_gl))
198 |
199 | print('y_istft mirror error:', np.sum([e - o for e, o in zip(y_istft[::2], y_istft[1::2])]))
200 | print('y_ristft mirror error:', np.sum([e - o for e, o in zip(y_ristft[::2], y_ristft[1::2])]))
201 | print('y_gl mirror error:', np.sum([e - o for e, o in zip(y_gl[::2], y_gl[1::2])]))
202 |
203 | print('mag_istft error:', np.sum(mag - mag_istft))
204 | print('mag_ristft error:', np.sum(mag - mag_ristft))
205 | print('mag_gl error:', np.sum(mag - mag_gl))
206 |
207 | print('phase_istft error:', np.sum(phase - phase_istft))
208 | print('phase_ristft error:', np.sum(phase - phase_ristft))
209 | print('phase_ristft_r error:', np.sum(phase_ristft_r - phase_ristft))
210 | print('phase_gl error:', np.sum(phase - phase_gl))
211 |
212 | plt.subplot(231) ; sns.heatmap(mag)
213 | plt.subplot(232) ; sns.heatmap(mag_ristft)
214 | plt.subplot(233) ; sns.heatmap(mag_gl)
215 | plt.subplot(234) ; sns.heatmap(phase)
216 | plt.subplot(235) ; sns.heatmap(phase_ristft)
217 | plt.subplot(236) ; sns.heatmap(phase_gl)
218 | plt.show()
219 |
220 | wavfile.write('y.wav', sample_rate, y)
221 | wavfile.write('y_istft.wav', sample_rate, y_istft)
222 | wavfile.write('y_ristft.wav', sample_rate, y_ristft)
223 | wavfile.write('y_gl.wav', sample_rate, y_gl)
224 |
225 |
226 | if __name__ == '__main__':
227 | dp = r'D:\Desktop\Workspace\Data\DataBaker\Wave'
228 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')])
229 | fp = os.path.join(dp, fn)
230 |
231 | demo_stft_istft(fp)
232 | #extract_f_A(fp)
233 |
--------------------------------------------------------------------------------
/retunegan/tools/test_downsample.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/03/10
4 |
5 | # use this script to decide `hp.downsample_pool_k`
6 | # higher sample_rate needs larger `k`
7 |
8 |
9 | import os
10 | import random as R
11 | import librosa as L
12 | from scipy.io import wavfile
13 | import matplotlib.pyplot as plt
14 | import torch
15 | import torch.nn as nn
16 | import torchaudio.transforms as T
17 |
18 | import hparam as hp
19 |
20 | resamplers = [
21 | T.Resample(hp.sample_rate, hp.sample_rate//2, resampling_method='sinc_interpolation'), # kaiser_window
22 | T.Resample(hp.sample_rate, hp.sample_rate//4, resampling_method='sinc_interpolation'),
23 | T.Resample(hp.sample_rate, hp.sample_rate//8, resampling_method='sinc_interpolation'),
24 | ]
25 | avg_pools = [
26 | nn.AvgPool1d(kernel_size=2, stride=2, padding=1),
27 | nn.AvgPool1d(kernel_size=4, stride=2, padding=2),
28 | nn.AvgPool1d(kernel_size=6, stride=2, padding=3), # for 16000 Hz
29 | nn.AvgPool1d(kernel_size=8, stride=2, padding=4),
30 | ]
31 |
32 | dp = r'C:\Users\Kahsolt\Desktop\Workspace\Data\DataBaker\Wave'
33 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')])
34 | fp = os.path.join(dp, fn)
35 |
36 | y = L.load(fp, hp.sample_rate)[0]
37 | y = torch.from_numpy(y).unsqueeze(0).unsqueeze(0)
38 |
39 | plt.subplot(len(resamplers)+1, 1, 1)
40 | plt.plot(y.numpy().squeeze())
41 | for i, resampler in enumerate(resamplers):
42 | s = resampler(y)
43 | plt.subplot(len(resamplers)+1, 1, i+2)
44 | plt.plot(s.numpy().squeeze())
45 | plt.show()
46 |
47 | for i, avg_pool in enumerate(avg_pools):
48 | plt.subplot(len(resamplers)+1, 1, 1)
49 | plt.plot(y.numpy().squeeze())
50 | s = y
51 | for j in range(len(resamplers)):
52 | s = avg_pool(s)
53 | plt.subplot(len(resamplers)+1, 1, j+2)
54 | plt.plot(s.numpy().squeeze())
55 | plt.show()
56 |
57 |
--------------------------------------------------------------------------------
/retunegan/tools/test_envolope.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/03/10
4 |
5 | # use this script to decide `hp.envelope_pool_k`
6 | # higher sample_rate needs larger `k`
7 |
8 |
9 | import os
10 | import random as R
11 | import librosa as L
12 | import matplotlib.pyplot as plt
13 | import torch
14 | import torch.nn as nn
15 |
16 | import hparam as hp
17 |
18 |
19 | max_pools = [
20 | nn.MaxPool1d( 64, 1),
21 | nn.MaxPool1d(128, 1), # for 16000 Hz
22 | nn.MaxPool1d(160, 1), # for 22050 Hz
23 | nn.MaxPool1d(256, 1), # for 44100 Hz
24 | nn.MaxPool1d(512, 1),
25 | ]
26 |
27 | dp = r'C:\Users\Kahsolt\Desktop\Workspace\Data\DataBaker\Wave'
28 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')])
29 | fp = os.path.join(dp, fn)
30 |
31 |
32 | y = L.load(fp, hp.sample_rate)[0]
33 | y_np = y
34 | y = torch.from_numpy(y).unsqueeze(0).unsqueeze(0)
35 |
36 | plt.subplot(4, 1, 1)
37 | plt.plot(y.numpy().squeeze())
38 | plt.title('y')
39 | pool = max_pools[2]
40 | u = pool( y)
41 | d = -pool(-y)
42 | plt.subplot(4, 1, 2)
43 | plt.title('y_envolope')
44 | plt.plot(u.numpy().squeeze())
45 | plt.plot(d.numpy().squeeze())
46 | plt.subplot(4, 1, 3)
47 | plt.title('y_even')
48 | plt.plot(y_np[::2])
49 | plt.subplot(4, 1, 4)
50 | plt.title('y_odd')
51 | plt.plot(y_np[1::2])
52 | plt.show()
53 |
54 | exit(0)
55 |
56 | y = L.load(fp, hp.sample_rate)[0]
57 | y = torch.from_numpy(y).unsqueeze(0).unsqueeze(0)
58 |
59 | plt.subplot(len(max_pools)+1, 1, 1)
60 | plt.plot(y.numpy().squeeze())
61 | for i, pool in enumerate(max_pools):
62 | u = pool( y)
63 | d = -pool(-y)
64 | plt.subplot(len(max_pools)+1, 1, i+2)
65 | plt.plot(u.numpy().squeeze())
66 | plt.plot(d.numpy().squeeze())
67 | plt.show()
68 |
--------------------------------------------------------------------------------
/retunegan/tools/test_griffinlim.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/03/10
4 |
5 | import os
6 | import random as R
7 | import copy
8 | import librosa as L
9 | import seaborn as sns
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | from scipy.io import wavfile
13 |
14 | import hparam as hp
15 |
16 | R.seed(114514)
17 |
18 |
19 | def decompose(D):
20 | P, S = np.angle(D), np.abs(D)
21 | logS = np.log(S.clip(1e-5, None))
22 | return P, S, logS
23 |
24 |
25 | def griffin_lim(S, n_iter=60):
26 | X_best = copy.deepcopy(S)
27 | for _ in range(n_iter):
28 | X_t = L.istft(X_best, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
29 | X_best = L.stft(X_t, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
30 | phase = X_best / np.maximum(1e-8, np.abs(X_best))
31 | X_best = S * phase
32 | X_t = L.istft(X_best, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
33 | y = np.real(X_t)
34 | return y
35 |
36 |
37 | def griffinlim(S, n_iter=60):
38 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
39 |
40 | for _ in range(n_iter):
41 |         full = np.abs(S).astype(np.complex128) * angles
42 | inverse = L.istft(full, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
43 | rebuilt = L.stft(inverse, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
44 | angles = np.exp(1j * np.angle(rebuilt))
45 |     full = np.abs(S).astype(np.complex128) * angles
46 | inverse = L.istft(full, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
47 | return inverse
48 |
49 |
50 | def griffinlim_conj(P, n_iter=60):
51 | #S = np.random.uniform(low=np.exp(-4), high=np.exp(4), size=P.shape)
52 | S = np.random.normal(loc=-4, scale=1, size=P.shape)
53 | #S = np.random.rand(*P.shape)
54 |
55 | for _ in range(n_iter):
56 | D = S * np.exp(1j * P)
57 | y_hat = L.istft(D, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
58 | D_hat = L.stft(y_hat, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
59 | S = np.abs(D_hat)
60 |
61 | D = S * np.exp(1j * P)
62 | y = L.istft(D, hop_length=hp.hop_length, win_length=hp.win_length, window=hp.window_fn)
63 | return y
64 |
65 |
66 | dp = r'C:\Users\Kahsolt\Desktop\Workspace\Data\DataBaker\Wave'
67 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')])
68 | fp = os.path.join(dp, fn)
69 |
70 | y = L.load(fp, hp.sample_rate)[0]
71 | ylen = len(y)
72 | D = L.stft(y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
73 | P, S, logS = decompose(D)
74 |
75 | y1 = griffin_lim(S)
76 | y2 = griffinlim(S)
77 | y3 = griffinlim_conj(P)
78 |
79 | plt.subplot(4, 1, 1) ; plt.plot(y)
80 | plt.subplot(4, 1, 2) ; plt.plot(y1)
81 | plt.subplot(4, 1, 3) ; plt.plot(y2)
82 | plt.subplot(4, 1, 4) ; plt.plot(y3)
83 | plt.show()
84 |
85 | D1 = L.stft(y1, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
86 | D2 = L.stft(y2, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
87 | D3 = L.stft(y3, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
88 | _, _, logS1 = decompose(D1)
89 | _, _, logS2 = decompose(D2)
90 | _, _, logS3 = decompose(D3)
91 |
92 | plt.subplot(2, 2, 1) ; sns.heatmap(logS, cbar=False)
93 | plt.subplot(2, 2, 2) ; sns.heatmap(logS1, cbar=False)
94 | plt.subplot(2, 2, 3) ; sns.heatmap(logS2, cbar=False)
95 | plt.subplot(2, 2, 4) ; sns.heatmap(logS3, cbar=False)
96 | plt.show()
97 |
98 | wavfile.write('y_griffinlim_conj.wav', hp.sample_rate, y3)
99 |
--------------------------------------------------------------------------------
/retunegan/tools/test_istft_iter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/03/10
4 |
5 | import os
6 | import random as R
7 | import librosa as L
8 | import seaborn as sns
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | from scipy.io import wavfile
12 |
13 | import hparam as hp
14 |
15 | R.seed(114514)
16 |
17 |
18 | def decompose(D):
19 | P, S = np.angle(D), np.abs(D)
20 | return P, S
21 |
22 |
23 | dp = r'C:\Users\Kahsolt\Desktop\Workspace\Data\DataBaker\Wave'
24 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')])
25 | fp = os.path.join(dp, fn)
26 |
27 | y = L.load(fp, hp.sample_rate)[0]
28 | ylen = len(y)
29 | D = L.stft(y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
30 | P, S = decompose(D)
31 |
32 | y_loss, P_loss, S_loss = [], [], []
33 | D_i, P_i, S_i = D, P, S
34 | for i in range(1000):
35 | #D_i = S_i * np.exp(1j * P_i)
36 | y_i = L.istft(D_i, win_length=hp.win_length, hop_length=hp.hop_length, length=ylen)
37 | D_i = L.stft(y_i, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
38 | P_i, S_i = decompose(D_i)
39 |
40 | y_l = np.mean(np.abs(y - y_i)) ; y_loss.append(y_l)
41 | P_l = np.mean(np.abs(P - P_i)) ; P_loss.append(P_l)
42 | S_l = np.mean(np.abs(S - S_i)) ; S_loss.append(S_l)
43 |
44 | plt.subplot(3, 1, 1) ; plt.plot(y_loss)
45 | plt.subplot(3, 1, 2) ; plt.plot(P_loss)
46 | plt.subplot(3, 1, 3) ; plt.plot(S_loss)
47 | plt.show()
48 |
--------------------------------------------------------------------------------
/retunegan/tools/test_pesq.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/03/10
4 |
5 | # use this script to evaluate PESQ
6 |
7 | import sys
8 | import librosa as L
9 | from pesq import pesq
10 |
11 | sr = 16000
12 | #sr = 8000
13 |
14 | BASE_PATH = r'C:\Users\Kahsolt\Desktop\Workspace\Essay\基于韵律优化与波形修复的汉语语音合成方法研究\audio'
15 |
16 |
17 | def test(fp_y, fp_y_hat):
18 | ref, _ = L.load(BASE_PATH + '\\' + fp_y, sr)
19 | deg, _ = L.load(BASE_PATH + '\\' + fp_y_hat, sr)
20 |
21 | print(pesq(sr, ref, ref, 'wb'))
22 | print(pesq(sr, ref, deg, 'wb'))
23 |
24 | print('[gl]')
25 | test('gl_gt.wav', 'gl_64i.wav')
26 | print('[mlg]')
27 | test('mlg-gt.wav', 'mlg-100e.wav')
28 | print('[hfg-40k]')
29 | test('hfg-gt.wav', 'hfg-40k.wav')
30 | print('[hfg-85k]')
31 | test('hfg-gt.wav', 'hfg-85k.wav')
32 |
33 | print('[taco]')
34 | test('taco-gt.wav', 'taco-103k.wav')
35 |
--------------------------------------------------------------------------------
/retunegan/tools/test_phase_recover.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/03/10
4 |
5 | import os
6 | import random as R
7 | import librosa as L
8 | import seaborn as sns
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | from scipy.io import wavfile
12 |
13 | import hparam as hp
14 |
15 | R.seed(114514)
16 |
17 |
18 | def decompose(D):
19 | P, S = np.angle(D), np.abs(D)
20 | logS = np.log(S.clip(1e-5, None))
21 | return P, S, logS
22 |
23 |
24 | dp = r'C:\Users\Kahsolt\Desktop\Workspace\Data\DataBaker\Wave'
25 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')])
26 | fp = os.path.join(dp, fn)
27 |
28 | sr, hsr = hp.sample_rate, hp.sample_rate//2
29 | y = L.load(fp, hp.sample_rate)[0]
30 | if len(y) % 2 != 0: y = y[:-1]
31 | ylen = len(y)
32 |
33 |
34 | print('>> effect of noise in time domain')
35 | D = L.stft(y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
36 | P, S, logS = decompose(D)
37 | for eps in [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1]:
38 | y_n = y + np.random.uniform(low=-eps, high=eps, size=y.shape)
39 | D_n = L.stft(y_n, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
40 | P_n, S_n, logS_n = decompose(D_n)
41 | print('eps =', eps)
42 | print('|y - y_n|:', np.mean(np.abs(y - y_n)))
43 | print('|P - P_n|:', np.mean(np.abs(P - P_n)))
44 | print('|S - S_n|:', np.mean(np.abs(S - S_n)))
45 | print('|logS - logS_n|:', np.mean(np.abs(logS - logS_n)))
46 | print()
47 | if False: sns.heatmap(logS_n) ; plt.show()
48 |
49 |
50 | print('>> reverse with original magnitude & original phase')
51 | y_i = L.istft(D, win_length=hp.win_length, hop_length=hp.hop_length, length=ylen)
52 | D_i = L.stft(y_i, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
53 | P_i, S_i, logS_i = decompose(D_i)
54 |
55 | print('|y - y_i|:', np.mean(np.abs(y - y_i)))
56 | print('|P - P_i|:', np.mean(np.abs(P - P_i)))
57 | print('|S - S_i|:', np.mean(np.abs(S - S_i)))
58 | print('|logS - logS_i|:', np.mean(np.abs(logS - logS_i)))
59 | print()
60 |
61 |
62 | print('>> reverse with original magnitude by GriffinLim')   # this step is inherently somewhat random
63 | y_gl = L.griffinlim(S, hop_length=hp.hop_length, win_length=hp.win_length, length=ylen)
64 | D_gl = L.stft(y_gl, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
65 | P_gl, S_gl, logS_gl = decompose(D_gl)
66 |
67 | print('|y - y_gl|:', np.mean(np.abs(y - y_gl)))
68 | print('|P - P_gl|:', np.mean(np.abs(P - P_gl)))
69 | print('|S - S_gl|:', np.mean(np.abs(S - S_gl)))
70 | print('|logS - logS_gl|:', np.mean(np.abs(logS - logS_gl)))
71 | print()
72 |
73 |
74 | print('>> reverse with original magnitude & random phase')
75 | S_r, logS_r = S, logS
76 | P_r = np.random.uniform(low=-np.pi, high=np.pi, size=P.shape)
77 | D_r = S_r * np.exp(1j * P_r)
78 | y_r = L.istft(D_r, win_length=hp.win_length, hop_length=hp.hop_length, length=ylen)
79 | D_rr = L.stft(y_r, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
80 | P_rr, S_rr, logS_rr = decompose(D_rr)
81 |
82 | print('|y - y_r|:', np.mean(np.abs(y - y_r)))
83 | print('|P - P_r|:', np.mean(np.abs(P - P_r)))
84 | print('|P - P_rr|:', np.mean(np.abs(P - P_rr)))
85 | print('|P_r - P_rr|:', np.mean(np.abs(P_r - P_rr)))
86 | print('|S - S_rr|:', np.mean(np.abs(S - S_rr)))
87 | print('|logS - logS_rr|:', np.mean(np.abs(logS - logS_rr)))
88 | print()
89 |
90 | print('>> reverse with random magnitude & original phase')
91 | # none of these work
92 | #S_f = np.exp(np.random.normal(loc=logS.mean(), scale=logS.std(), size=S.shape))/10
93 | # S_f = np.random.normal(loc=S.mean(), scale=S.std(), size=S.shape)
94 | # these work
95 | #S_f = np.random.normal(loc=logS.mean(), scale=logS.std(), size=S.shape)
96 | #S_f = np.random.normal(loc=S.mean(), scale=S.mean(), size=S.shape)
97 | S_f = np.random.uniform(low=S.min(), high=S.max(), size=S.shape)
98 | P_f = P
99 | D_f = S_f * np.exp(1j * P_f)
100 | y_f = L.istft(D_f, win_length=hp.win_length, hop_length=hp.hop_length, length=ylen)
101 | D_fr = L.stft(y_f, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
102 | P_fr, S_fr, logS_fr = decompose(D_fr)
103 |
104 | print('|y - y_f|:', np.mean(np.abs(y - y_f)))
105 | print('|P - P_fr|:', np.mean(np.abs(P - P_fr)))
106 | print('|S - S_f|:', np.mean(np.abs(S - S_f)))
107 | print('|S - S_fr|:', np.mean(np.abs(S - S_fr)))
108 | print()
109 |
110 | #if True: sns.heatmap(S_f) ; plt.show()
111 | #if True: sns.heatmap(P_fr) ; plt.show()
112 | if True: sns.heatmap(S_fr) ; plt.show()
113 |
--------------------------------------------------------------------------------
/retunegan/tools/test_strip_mirror.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/03/10
4 |
5 | import os
6 | import random as R
7 | import librosa as L
8 | import seaborn as sns
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | from scipy.io import wavfile
12 | import torch
13 | import torch.nn as nn
14 |
15 | import hparam as hp
16 |
17 | #R.seed(114514)
18 |
19 |
20 | def decompose(D):
21 | P, S = np.angle(D), np.abs(D)
22 | logS = np.log(S.clip(1e-5, None))
23 | return P, S, logS
24 |
25 |
26 | dp = r'C:\Users\Kahsolt\Desktop\Workspace\Data\DataBaker\Wave'
27 | fn = R.choice([fn for fn in os.listdir(dp) if fn.endswith('.wav')])
28 | fp = os.path.join(dp, fn)
29 |
30 | sr, hsr = hp.sample_rate, hp.sample_rate//2
31 | y = L.load(fp, hp.sample_rate)[0]
32 | if len(y) % 2 != 0: y = y[:-1]
33 | ylen = len(y)
34 |
35 | avgpool = nn.AvgPool1d(hp.downsample_pool_k, 2)
36 | Ty = torch.from_numpy(y).unsqueeze(0).unsqueeze(0)
37 | for i in range(3):
38 | if Ty.shape[-1] % 2 != 0: Ty = Ty[:,:,:-1]
39 | even, odd = Ty[:,:,::2], Ty[:,:,1::2]
40 | diff = even - odd
41 | print(f'strip_mirror_loss({i}):', torch.mean(torch.abs(diff)).item())
42 | Ty = avgpool(Ty)
43 |
44 |
45 | even, odd = y[::2], y[1::2]
46 | mean = (even + odd) / 2
47 | diff = even - odd
48 | print('strip_mirror_loss:', np.mean(np.abs(diff)))
49 |
50 | if False:
51 | plt.subplot(4, 1, 1) ; plt.plot(y)
52 | plt.subplot(4, 1, 2) ; plt.plot(even)
53 | plt.subplot(4, 1, 3) ; plt.plot(odd)
54 | plt.subplot(4, 1, 4) ; plt.plot(diff)
55 | plt.show()
56 |
57 | if False:
58 | wavfile.write('y.wav', sr, y)
59 | wavfile.write('even.wav', hsr, even)
60 | wavfile.write('odd.wav', hsr, odd)
61 | wavfile.write('diff.wav', hsr, diff)
62 |
63 | De = L.stft(even, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
64 | Do = L.stft(odd, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
65 | Dd = L.stft(diff, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
66 | Dm = L.stft(mean, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
67 | Pe, Se, logSe = decompose(De)
68 | Po, So, logSo = decompose(Do)
69 | Pd, Sd, logSd = decompose(Dd)
70 | Pm, Sm, logSm = decompose(Dm)
71 |
72 | print('|Pe - Po|:', np.mean(np.abs(Pe - Po)))
73 | print('|Se - So|:', np.mean(np.abs(Se - So)))
74 | print('|logSe - logSo|:', np.mean(np.abs(logSe - logSo)))
75 | print()
76 |
77 | print('|Pd - Pe|:', np.mean(np.abs(Pd - Pe)))
78 | print('|Sd - Se|:', np.mean(np.abs(Sd - Se)))
79 | print('|logSd - logSe|:', np.mean(np.abs(logSd - logSe)))
80 | print('|Pd - Po|:', np.mean(np.abs(Pd - Po)))
81 | print('|Sd - So|:', np.mean(np.abs(Sd - So)))
82 | print('|logSd - logSo|:', np.mean(np.abs(logSd - logSo)))
83 | print()
84 |
85 |
86 |
--------------------------------------------------------------------------------
/retunegan/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | from time import time
4 |
5 | import torch
6 | import matplotlib
7 | import hparam as hp
8 | if not hp.debug: matplotlib.use("Agg")
9 | import matplotlib.pylab as plt
10 |
11 | LRELU_SLOPE = 0.15
12 | PI = 3.14159265358979
13 |
14 |
15 | # plot
16 | def plot_spectrogram(spec):
17 | fig, ax = plt.subplots(figsize=(10, 2))
18 | im = ax.imshow(spec, aspect="auto", origin="lower", interpolation='none')
19 | plt.colorbar(im, ax=ax)
20 | fig.canvas.draw()
21 | plt.close()
22 | return fig
23 |
24 |
25 | # network
26 | def init_weights(m, mean=0.0, std=0.01):
27 | if 'Conv' in m.__class__.__name__:
28 | #m.weight.data.normal_(mean, std)
29 | torch.nn.init.kaiming_normal_(m.weight.data, nonlinearity='leaky_relu', a=LRELU_SLOPE)
30 |
31 |
32 | def get_padding(kernel_size, dilation=1):
33 | return (kernel_size * dilation - dilation) // 2
34 |
35 |
36 | def get_same_padding(kernel_size, dilation=1):
37 | return dilation * (kernel_size // 2)
38 |
39 |
40 | def truncate_align(x, y):
41 | d = x.shape[-1] - y.shape[-1]
42 | if d != 0:
43 | print('[truncate_align] x.shape:', x.shape, 'y.shape:', y.shape)
44 | if d > 0: x = x[:, :, d //2 : -( d - d //2)]
45 | elif d < 0: y = y[:, :, (-d)//2 : -((-d) - (-d)//2)]
46 | return x, y
47 |
48 |
49 | def get_param_cnt(model):
50 | return sum(param.numel() for param in model.parameters())
51 |
52 |
53 | def stat_grad(model, name):
54 | max_grad, min_grad = 0, 1e5
55 | for p in model.parameters():
56 |     vabs = p.grad.abs() if p.grad is not None else p.abs()
57 | vmin, vmax = vabs.min(), vabs.max()
58 | if vmin < min_grad: min_grad = vmin
59 | if vmax > max_grad: max_grad = vmax
60 | print(f'grad_{name}: max = {max_grad.item()}, min={min_grad.item()}')
61 |
62 |
63 | # ckpt
64 | def load_checkpoint(fp, device):
65 | assert os.path.isfile(fp)
66 | print(f"Loading '{fp}'")
67 | checkpoint_dict = torch.load(fp, map_location=device)
68 | print("Complete.")
69 | return checkpoint_dict
70 |
71 |
72 | def save_checkpoint(fp, obj):
73 | print(f"Saving checkpoint to {fp}")
74 | torch.save(obj, fp)
75 | print("Complete.")
76 |
77 |
78 | def scan_checkpoint(dp, prefix):
79 | pattern = os.path.join(dp, prefix + '*')
80 | cp_list = glob.glob(pattern)
81 |   return sorted(cp_list)[-1] if cp_list else None
82 |
83 |
84 | # decorator
85 | def timer(fn):
86 | def wrapper(*args, **kwargs):
87 | start = time()
88 | r = fn(*args, **kwargs)
89 | end = time()
90 |     print(f'[Timer]: {fn.__name__} took {end - start:.2f}s')
91 | return r
92 | return wrapper
93 |
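94 | # Example usage of the decorator (illustrative; `load_all` is a made-up name):
95 | #
96 | #   @timer
97 | #   def load_all(fp, device):
98 | #     return load_checkpoint(fp, device)
99 | #
100 | # prints e.g. "[Timer]: load_all took 0.52s" on return.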
--------------------------------------------------------------------------------
/stats/DataBaker-stats.txt:
--------------------------------------------------------------------------------
1 | { 'acode_len:avg': 312.1219,
2 | 'acode_len:max': 668,
3 | 'acode_len:min': 66,
4 | 'ap:max': 0.0,
5 | 'ap:min': -35.58398,
6 | 'dyn:max': 0.37510493,
7 | 'dyn:min': 0.0,
8 | 'f0:max': 7577.7227,
9 | 'f0:min': 0.0,
10 | 'mag:max': 2.3220108,
11 | 'mag:min': -8.0,
12 | 'mel:max': 0.3767597,
13 | 'mel:min': -8.0,
14 | 'n_frames': 3121219,
15 | 'n_hours': 40.26364646006551,
16 | 'n_utterances': 10000,
17 | 'num_note:max': 125,
18 | 'num_note:min': 37,
19 | 'pit:max': 11025.0,
20 | 'pit:min': 70.0,
21 | 'sp:max': 6.6273775,
22 | 'sp:min': -36.75616,
23 | 'text_len:avg': 16.2864,
24 | 'text_len:max': 34,
25 | 'text_len:min': 3}
26 |
27 |
28 | longest:  007537.wav
29 | shortest: 007245.wav
30 |
--------------------------------------------------------------------------------
/stats/DataBaker.stats:
--------------------------------------------------------------------------------
1 | total_examples 9559
2 | total_hours 9.463644041320231
3 | min_len_txt 6
4 | max_len_txt 27
5 | avg_len_txt 16.08818914112355
6 | min_len_wav 25856
7 | max_len_wav 134144
8 | avg_len_wav 78588.14352965791
9 | min_len_spec 101
10 | max_len_spec 524
11 | avg_len_spec 306.98493566272623
12 | max_mel 2.9994443721145663
13 | min_mel -5.6
14 | max_mag 4.912159066528323
15 | min_mag -5.6
16 | max_f0 595.9459228515625
17 | min_f0 73.25581359863281
18 | max_c0 0.3751049339771271
19 | min_c0 4.6309418394230306e-05
20 |
--------------------------------------------------------------------------------
/stats/DataBaker_gen_stat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2021/01/07
4 |
5 | import tgt
6 | import numpy as np
7 | import pandas as pd
8 | from pathlib import Path
9 | from collections import defaultdict
10 | from os import chdir as cd, getcwd as pwd, listdir
11 |
12 | BASE_PATH = Path(__file__).parent.absolute()
13 | WORKING_DIR = BASE_PATH / 'DataBaker.preproc' / 'TextGrid'
14 | STAT_OUT_FILE_FMT = BASE_PATH / 'DataBaker.stat-%s.csv'
15 |
16 | def collect_stat(by_name='phones'):
17 | durdict = defaultdict(list)
18 | for fn in listdir():
19 | tg = tgt.read_textgrid(fn)
20 | for ph in tg.get_tier_by_name(by_name).intervals:
21 | durdict[ph.text].append(ph.duration())
22 |
23 | stat = {k: (len(v), np.mean(v), np.std(v), np.min(v), np.max(v)) for k, v in durdict.items()}
24 | df = pd.DataFrame(stat, index=['freq', 'mean', 'std', 'min', 'max']).T
25 | df.to_csv(str(STAT_OUT_FILE_FMT) % by_name)
26 |
27 | if __name__ == '__main__':
28 | savedp = pwd()
29 | cd(WORKING_DIR)
30 | for name in ['words', 'phones']:
31 | collect_stat(name)
32 | cd(savedp)
33 |
--------------------------------------------------------------------------------
/stats/DataBaker_print_pinyins.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2020/12/16
4 |
5 | # Collect the pinyin syllables that appear in DataBaker and print them as a list
6 |
7 | import re
8 | from math import ceil
9 | from pathlib import Path
10 | from os import chdir as cd, getcwd as pwd
11 |
12 | SPEC_DIR = 'DataBaker.spec'
13 | BASE_PATH = Path(__file__).parent.absolute() / SPEC_DIR
14 | INDEX_FILE = 'train.txt'
15 |
16 | def collect_symbols():
17 | pinyins = set()
18 | with open(INDEX_FILE) as fh:
19 | samples = fh.read().split('\n')
20 | for s in samples:
21 | if len(s) == 0: continue
22 | txt = s.split('|')[-1]
23 | pinyins = pinyins.union(txt.split(' '))
24 | return sorted(list(pinyins))
25 |
26 | def pprint_symbols(symbols):
27 | n_sym = len(symbols)
28 | SYM_PER_LINE = 15
29 | n_line = ceil(n_sym / SYM_PER_LINE)
30 |
31 | print('_pinyin = [', end='')
32 | for idx, sym in enumerate(symbols):
33 | col = idx % SYM_PER_LINE
34 | if col == 0: print('\n ', end='')
35 | print(f"'{sym}', ", end='')
36 | print(']')
37 |
38 | if __name__ == '__main__':
39 | savedp = pwd()
40 | cd(BASE_PATH)
41 | pprint_symbols(collect_symbols())
42 | cd(savedp)
43 |
--------------------------------------------------------------------------------
/stats/DataBaker_print_symbols.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2020/12/16
4 |
5 | # Collect the pinyin syllables that appear in DataBaker and print them as a list
6 |
7 | import re
8 | from math import ceil
9 | from pathlib import Path
10 | from os import chdir as cd, getcwd as pwd
11 |
12 | SPEC_DIR = 'DataBaker.preproc'
13 | BASE_PATH = Path(__file__).parent.absolute() / SPEC_DIR
14 | INDEX_FILES = ['train.txt', 'val.txt']
15 |
16 | def collect_symbols():
17 | symbols = set()
18 | for fn in INDEX_FILES:
19 | with open(fn) as fh:
20 | samples = fh.read().split('\n')
21 | for s in samples:
22 | if len(s) == 0: continue
23 | _, txt = s.split('|')
24 | symbols = symbols.union(txt[1:-1].split(' '))
25 | return sorted(list(symbols))
26 |
27 | def pprint_symbols(symbols):
28 | n_sym = len(symbols)
29 | SYM_PER_LINE = 15
30 | n_line = ceil(n_sym / SYM_PER_LINE)
31 |
32 | print('_pinyin = [', end='')
33 | for idx, sym in enumerate(symbols):
34 | col = idx % SYM_PER_LINE
35 | if col == 0: print('\n ', end='')
36 | print(f"'{sym}', ", end='')
37 | print(']')
38 |
39 | if __name__ == '__main__':
40 | savedp = pwd()
41 | cd(BASE_PATH)
42 | pprint_symbols(collect_symbols())
43 | cd(savedp)
44 |
--------------------------------------------------------------------------------
/stats/inspect_preproc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2021/03/04
4 |
5 | from sys import argv
6 | import numpy as np
7 | from random import randrange
8 | import matplotlib.pyplot as plt
9 | import seaborn as sns
10 |
11 | ID = (len(argv) >= 2 and argv[1] or str(randrange(10000) + 1)).rjust(6,'0')
12 |
13 | plt.subplot(311)
14 | plt.title('energy')
15 | e = np.load('DataBaker.preproc/energy/DataBaker-energy-%s.npy' % ID)
16 | plt.xlim(0, len(e))
17 | plt.plot(e)
18 |
19 | plt.subplot(312)
20 | plt.title('f0')
21 | f = np.load('DataBaker.preproc/f0/DataBaker-f0-%s.npy' % ID)
22 | plt.xlim(0, len(f))
23 | plt.plot(f)
24 |
25 | plt.subplot(313)
26 | plt.title('mel')
27 | m = np.load('DataBaker.preproc/mel/DataBaker-mel-%s.npy' % ID)
28 | sns.heatmap(m.T[::-1], cbar=False)
29 |
30 | plt.show()
31 |
--------------------------------------------------------------------------------
/stats/inspect_spec.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2021/03/23
4 |
5 | from sys import argv
6 | import numpy as np
7 | from random import randrange
8 | import matplotlib.pyplot as plt
9 | import seaborn as sns
10 |
11 | ID = (len(argv) >= 2 and argv[1] or str(randrange(20) + 1)).rjust(6,'0')
12 |
13 | plt.subplot(211)
14 | plt.title('linear spec')
15 | m = np.load('DataBaker.spec/databaker-spec-%s.npy' % ID)
16 | sns.heatmap(m.T[::-1], cbar=False)
17 |
18 | plt.subplot(212)
19 | plt.title('mel spec')
20 | m = np.load('DataBaker.spec/databaker-mel-%s.npy' % ID)
21 | sns.heatmap(m.T[::-1], cbar=False)
22 |
23 | plt.show()
24 |
--------------------------------------------------------------------------------
/stats/thchs30_gen_vbanks.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2020/11/14
4 |
5 | # Split thchs30 into voice banks by timbre, producing several train.txt files
6 |
7 | import re
8 | from pathlib import Path
9 | from os import chdir as cd, getcwd as pwd
10 | from collections import defaultdict
11 |
12 | SPEC_DIR = 'thchs30.spec'
13 | INDEX_FILE = 'train.txt'
14 | BASE_PATH = Path(__file__).parent.absolute() / SPEC_DIR
15 | R = re.compile(r'-([ABCD]\d+)_')
16 |
17 | MALE_LIST = [ 'A8', 'B8', 'C8', 'D8' ]
18 | FEMALE_POWER_LIST = [ 'A2', 'A4', 'A6', 'A14', 'A22', 'A34', 'B4', 'B6', 'B12', 'B22', 'B31', 'C4', 'C6', 'C31', 'D6', 'D31', 'D32' ]
19 | FEMALE_SOFT_LIST = [ 'A7', 'A11', 'A19', 'B7', 'C7', 'C14', 'C17', 'C18', 'C20', 'C32', 'D7', 'D11' ]
20 | CHILD_LIST = [ 'A13', 'B11', 'C12', 'C13', 'C19', 'C21', 'C22', 'D21' ]
21 |
22 | # => {'uid': ['sample_config1', ...]}
23 | def read_index(fn=INDEX_FILE) -> defaultdict:
24 | index_dict = defaultdict(list)
25 |   with open(fn) as fh:
26 | samples = fh.read().split('\n')
27 | for s in samples:
28 | if len(s) == 0: continue
29 | spec_fn, mel_fn, _, txt = s.split('|')
30 | id = R.findall(spec_fn)[0]
31 | index_dict[id].append(s)
32 | return index_dict
33 |
34 | def write_index(fn, vbank):
35 | with open(fn, 'w') as fh:
36 | for s in vbank:
37 | fh.write(s)
38 | fh.write('\n')
39 |
40 | # => ['sample_config1', ...]
41 | def gen_index(index, vt) -> list:
42 | uid_list = globals().get(vt.upper() + '_LIST', [])
43 | vbank = [ ]
44 | for uid in uid_list:
45 | vbank.extend(index[uid])
46 | return vbank
47 |
48 | if __name__ == '__main__':
49 | savedp = pwd()
50 | cd(BASE_PATH)
51 | index = read_index()
52 | for vt in ['male', 'female_power', 'female_soft', 'child']:
53 | vbank = gen_index(index, vt)
54 | write_index('vbank_' + vt + '.txt', vbank)
55 | cd(savedp)
56 |
--------------------------------------------------------------------------------
/stats/thchs30_print_symbols.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2020/11/23
4 |
5 | # Collect the pinyin syllables that appear in thchs30 and print them as a list
6 |
7 | import re
8 | from math import ceil
9 | from pathlib import Path
10 | from os import chdir as cd, getcwd as pwd
11 |
12 | SPEC_DIR = 'thchs30.spec'
13 | BASE_PATH = Path(__file__).parent.absolute() / SPEC_DIR
14 | INDEX_FILE = 'train.txt'
15 |
16 | def collect_symbols():
17 | symbols = set()
18 | with open(INDEX_FILE) as fh:
19 | samples = fh.read().split('\n')
20 | for s in samples:
21 | if len(s) == 0: continue
22 | _, _, _, txt = s.split('|')
23 | symbols = symbols.union(txt.split(' '))
24 | return sorted(list(symbols))
25 |
26 | def pprint_symbols(symbols):
27 | n_sym = len(symbols)
28 | SYM_PER_LINE = 15
29 | n_line = ceil(n_sym / SYM_PER_LINE)
30 |
31 | print('_pinyin = [', end='')
32 | for idx, sym in enumerate(symbols):
33 | col = idx % SYM_PER_LINE
34 | if col == 0: print('\n ', end='')
35 | print(f"'{sym}', ", end='')
36 | print(']')
37 |
38 |
39 | if __name__ == '__main__':
40 | savedp = pwd()
41 | cd(BASE_PATH)
42 | pprint_symbols(collect_symbols())
43 | cd(savedp)
44 |
--------------------------------------------------------------------------------
/transtacos/Makefile:
--------------------------------------------------------------------------------
1 | ifeq ($(shell uname -s), Linux)
2 | BASE_PATH=~/Data
3 | else
4 | BASE_PATH=D:/Desktop/Workspace/Data
5 | endif
6 |
7 | DATASET=DataBaker
8 | #LOG_NAME=tts-$(DATASET).$(VER)
9 | LOG_NAME=tts-$(DATASET)
10 | LOG_PATH=${BASE_PATH}/$(LOG_NAME)
11 |
12 |
13 | .PHONY: train preprocess server test_server stat clean
14 |
15 | train:
16 | python train.py \
17 | --base_dir $(BASE_PATH) \
18 | --input $(DATASET).tts_processed/train.txt \
19 | --name $(LOG_NAME) \
20 | --summary_interval 500 \
21 | --checkpoint_interval 1000
22 |
23 | preprocess:
24 | python preprocess.py \
25 | --base_dir $(BASE_PATH) \
26 | --out_dir $(DATASET).tts_processed \
27 | --dataset $(shell echo $(DATASET) | tr '[A-Z]' '[a-z]')
28 |
29 | server:
30 | python server.py \
31 | --log_path $(LOG_PATH)
32 |
33 | test_server:
34 | python server.py \
35 | --log_path $(LOG_PATH) \
36 | --port 5103
37 |
38 | stat:
39 | tensorboard \
40 | --logdir $(LOG_PATH) \
41 | --port 5103
42 |
43 | clean:
44 | rm -rf $(LOG_PATH)
45 |
--------------------------------------------------------------------------------
/transtacos/audio.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/01/07
4 |
5 | import numpy as np
6 | import librosa as L
7 | from scipy import signal
8 | from scipy.io import wavfile
9 |
10 | import hparam as hp
11 |
12 |
13 | eps = 1e-5
14 | firwin = signal.firwin(hp.n_freq, [hp.fmin, hp.fmax], pass_zero=False, fs=hp.sample_rate)
15 | rf0min = L.note_to_hz(hp.rf0min) if isinstance(hp.rf0min, str) else float(hp.rf0min)
16 | rf0max = L.note_to_hz(hp.rf0max) if isinstance(hp.rf0max, str) else float(hp.rf0max)
17 | c0min = hp.c0min
18 | c0max = hp.c0max
19 | qt_f0min = int(np.floor(L.hz_to_midi(hp.f0min)))
20 | qt_f0max = int(np.ceil (L.hz_to_midi(hp.f0max)))
21 |
22 | hp.n_f0_min = qt_f0min
23 | hp.n_f0_bins = qt_f0max - qt_f0min + 1
24 |
25 | print(f'c0: min={c0min} max={c0max} n_bins={hp.n_c0_bins}')
26 | print(f'qt_f0: min={qt_f0min} max={qt_f0max} n_bins={hp.n_f0_bins}')
27 |
28 |
29 | def load_wav(path): # float values in range (-1,1)
30 | y, _ = L.load(path, sr=hp.sample_rate, mono=True, res_type='kaiser_best')
31 | return y.astype(np.float32) # [T,]
32 |
33 | def save_wav(wav, path):
34 | if hp.postprocess:
35 | # rescaling for unified measure for all clips
36 | # NOTE: normalize amplification
37 | wav = wav / np.abs(wav).max() * 0.999
38 | # factor 0.5 in case of overflow for int16
39 | f1 = 0.5 * 32767 / max(0.01, np.max(np.abs(wav)))
40 | # sublinear scaling as Y ~ X ^ k (k < 1)
41 | f2 = np.sign(wav) * np.power(np.abs(wav), 0.667)
42 | wav = f1 * f2
43 |
44 | # bandpass for less noises
45 | wav = signal.convolve(wav, firwin)
46 |
47 | wavfile.write(path, hp.sample_rate, wav.astype(np.int16))
48 | else:
49 | wavfile.write(path, hp.sample_rate, wav.astype(np.float32))
50 |
51 |
52 | def align_wav(wav, r=hp.hop_length):
53 | d = len(wav) % r
54 | if d != 0:
55 | wav = np.pad(wav, (0, (r - d)))
56 | return wav
57 |
58 |
59 | def trim_silence(wav, frame_length=512, hop_length=128):
60 |   # human voice dynamics typically span up to ~55 dB
61 | return L.effects.trim(wav, top_db=hp.trim_below_peak_db, frame_length=frame_length, hop_length=hop_length)[0]
62 |
63 |
64 | def preemphasis(x):   # boosts high frequencies; sounds slightly far-field
65 | # x[i] = x[i] - k * x[i-1], k ie. preemphasis
66 | return signal.lfilter([1, -hp.preemphasis], [1], x)
67 |
68 |
69 | def inv_preemphasis(x): # undo preemphasis, should be put after Griffin-Lim
70 | return signal.lfilter([1], [1, -hp.preemphasis], x)
71 |
72 |
73 | def get_specs(y):
74 | D = np.abs(_stft(preemphasis(y)))
75 | S = _amp_to_db(D) - hp.ref_level_db
76 | M = _amp_to_db(_linear_to_mel(D)) - hp.ref_level_db
77 | return (_normalize(S), _normalize(M))
78 |
79 |
80 | def spec_to_natural_scale(spec):
81 | '''from inner normalized scale to raw scale of stft output'''
82 | return _db_to_amp(_denormalize(spec) + hp.ref_level_db)
83 |
84 |
85 | def fix_zero_DC(S):
86 | F, T = S.shape
87 |   if F == hp.n_freq - 1:   # NOTE: prepend zero DC component
88 | S = np.concatenate([np.ones([1, T]) * S.min() * 1e-2, S], axis=0)
89 | #S = np.concatenate([np.zeros([1, T]), S], axis=0)
90 | return S
91 |
92 |
93 | def inv_spec(spec):
94 | S = spec_to_natural_scale(spec) # denorm
95 | S = fix_zero_DC(S)
96 | wav = inv_preemphasis(_griffin_lim(S ** hp.gl_power)) # reconstruct phase
97 | return wav.astype(np.float32)
98 |
99 |
100 | def inv_mel(mel): # This might have no use case
101 | M = spec_to_natural_scale(mel) # denorm
102 | S = _mel_to_linear(M) # back to linear
103 | wav = inv_preemphasis(_griffin_lim(S ** hp.gl_power)) # reconstruct phase
104 | return wav.astype(np.float32)
105 |
106 |
107 | def get_f0(y):
108 | f0 = L.yin(y, fmin=rf0min, fmax=rf0max, frame_length=hp.win_length, hop_length=hp.hop_length)
109 | return f0.astype(np.float32) # [T,]
110 |
111 |
112 | def get_c0(y):
113 | c0 = L.feature.rms(y=y, frame_length=hp.win_length, hop_length=hp.hop_length)[0]
114 | return c0.astype(np.float32) # [T,]
115 |
116 |
117 | def quantilize_f0(f0):
118 | f0 = np.asarray([L.hz_to_midi(f) - hp.n_f0_min for f in f0])
119 | f0 = f0.clip(0, hp.n_f0_bins - 1)
120 | return f0.astype(np.int32) # [T,]
121 |
122 |
123 | def quantilize_c0(c0):
124 | c0 = (c0 - c0min) / (c0max - c0min)
125 | c0 = c0 * hp.n_c0_bins
126 | c0 = c0.clip(0, hp.n_c0_bins - 1)
127 | return c0.astype(np.int32) # [T,]
128 |
129 |
130 | def _griffin_lim(S):
131 | '''librosa implementation of Griffin-Lim
132 | Based on https://github.com/librosa/librosa/issues/434
133 | '''
134 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
135 |   S_complex = np.abs(S).astype(np.complex128)
136 | y = _istft(S_complex * angles)
137 | for i in range(hp.gl_iters):
138 | angles = np.exp(1j * np.angle(_stft(y)))
139 | y = _istft(S_complex * angles)
140 | return y
141 |
142 |
143 | def _stft(y):
144 | return L.stft(y=y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
145 |
146 |
147 | def _istft(y):
148 | return L.istft(y, hop_length=hp.hop_length, win_length=hp.win_length)
149 |
150 |
151 | _mel_basis = None
152 | _linear_basis = None
153 |
154 | def _linear_to_mel(spec):
155 | return np.dot(_get_mel_basis(), spec)
156 |
157 | def _get_mel_basis():
158 | global _mel_basis
159 | if _mel_basis is None:
160 | assert hp.fmax < hp.sample_rate // 2
161 | _mel_basis = L.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.n_mel, fmin=hp.fmin, fmax=hp.fmax)
162 | return _mel_basis
163 |
164 | def _mel_to_linear(mel):
165 | return np.dot(_get_linear_basis(), mel)
166 |
167 | def _get_linear_basis():
168 | global _linear_basis
169 | if _linear_basis is None:
170 | m = _get_mel_basis()
171 | m_T = np.transpose(m)
172 | p = np.matmul(m, m_T)
173 | d = [1.0 / x if np.abs(x) > 1.0e-8 else x for x in np.sum(p, axis=0)]
174 | _linear_basis = np.matmul(m_T, np.diag(d))
175 | return _linear_basis
176 |
177 | def _amp_to_db(x):
178 | # audible sound pressure ranges from 2e-5 to 20 Pa, corresponding to sound pressure levels of 0~120dB
179 | # SPL formula: SPL = 20 * log10(p_e/p_ref)
180 | # where p_e is the sound pressure/amplitude and the reference p_ref is the human hearing threshold, usually 2e-5 in air
181 | # hence SPL = 20 * log10(p_e/2e-5)
182 | # = 20 * (log10(p_e) - log10(2e-5))
183 | # ~= 20 * log10(p_e) + 94
184 | return 20 * np.log10(np.maximum(1e-5, x)) # floor clip; optional, just makes heatmaps easier to read
185 |
186 | def _db_to_amp(x):
187 | # =10^(x/20)
188 | return np.power(10.0, x * 0.05)
189 |
190 | def _normalize(S):
191 | # mapping [hp.min_level_db, 0] => [-hp.max_abs_value, hp.max_abs_value]
192 | # typically: [-100, 0] => [-4, 4]
193 | return 2 * hp.max_abs_value * ((S - hp.min_level_db) / -hp.min_level_db) - hp.max_abs_value
194 |
195 | def _denormalize(S):
196 | return ((S + hp.max_abs_value) * -hp.min_level_db) / (2 * hp.max_abs_value) + hp.min_level_db
197 |
--------------------------------------------------------------------------------
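A quick numeric check of the dB conversion and normalization helpers above; a minimal standalone sketch that mirrors the formulas (assuming the hparam.py values min_level_db=-100, max_abs_value=4, ref_level_db=20) rather than importing the module:

```python
import numpy as np

min_level_db, max_abs_value, ref_level_db = -100, 4, 20

def amp_to_db(x):   return 20 * np.log10(np.maximum(1e-5, x))
def db_to_amp(x):   return np.power(10.0, x * 0.05)
def normalize(S):   return 2 * max_abs_value * ((S - min_level_db) / -min_level_db) - max_abs_value
def denormalize(S): return ((S + max_abs_value) * -min_level_db) / (2 * max_abs_value) + min_level_db

D  = np.abs(np.random.randn(1025, 100)) + 1e-3  # stand-in |STFT| magnitudes
S  = normalize(amp_to_db(D) - ref_level_db)     # what get_specs() stores
D2 = db_to_amp(denormalize(S) + ref_level_db)   # what spec_to_natural_scale() recovers
print(np.allclose(D, D2))                       # True: invertible above the -100 dB floor
```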
/transtacos/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from threading import Thread
4 | import traceback
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 |
9 | import hparam as hp
10 | from text.text import text_to_phoneme, phoneme_to_sequence
11 | from text.symbols import _eos, _sep, get_vocab_size
12 | from text.phonodict_cn import phonodict
13 | from audio import quantilize_c0, quantilize_f0
14 |
15 |
16 | _batches_per_group = hp.batch_size
17 | _pad = 0 # FIXME: hardcoded pad id for the '_' text token
18 |
19 |
20 | class DataFeeder(Thread):
21 | '''Feeds batches of data into a queue on a background thread.'''
22 |
23 | def __init__(self, coordinator, metadata_fp, hparams):
24 | super(DataFeeder, self).__init__()
25 | self._hparams = hparams
26 | self._session = None
27 | self._coord = coordinator
28 | self._offset = 0
29 |
30 | # Load metadata & make cache:
31 | self._datadir = os.path.dirname(metadata_fp)
32 | with open(metadata_fp, encoding='utf-8') as f:
33 | self._metadata = [line.strip().split('|') for line in f]
34 | self.data = [None] * len(self._metadata)
35 |
36 | # Create placeholders for inputs and targets. Don't specify batch size because we want to
37 | # be able to feed different sized batches at eval time.
38 | if hp.g2p == 'seq':
39 | text_shape = [None, None]
40 | elif hp.g2p == 'syl4':
41 | text_shape = [None, None, 2]
42 | self._placeholders = [
43 | tf.placeholder(tf.int32, [None], 'text_lengths'),
44 | tf.placeholder(tf.int32, text_shape, 'text'),
45 | tf.placeholder(tf.int32, [None, None], 'prds'),
46 | tf.placeholder(tf.int32, [None], 'spec_lengths'),
47 | tf.placeholder(tf.float32, [None, None, hparams.n_mel], 'mel_targets'),
48 | tf.placeholder(tf.float32, [None, None, hparams.n_freq-1], 'mag_targets'),
49 | tf.placeholder(tf.int32, [None, None], 'f0_targets'),
50 | tf.placeholder(tf.int32, [None, None], 'c0_targets'),
51 | tf.placeholder(tf.float32, [None, None], 'stop_token_targets')
52 | ]
53 |
54 | # Create queue for buffering data:
55 | queue = tf.FIFOQueue(hp.batch_size, [h.dtype for h in self._placeholders], name='input_queue')
56 | self._enqueue_op = queue.enqueue(self._placeholders)
57 | holders = queue.dequeue()
58 | for i, holder in enumerate(holders): holder.set_shape(self._placeholders[i].shape)
59 | (self.text_lengths, self.text, self.prds, self.spec_lengths, self.mel_targets, self.mag_targets, self.f0_targets, self.c0_targets, self.stop_token_targets) = holders
60 |
61 | def start_in_session(self, session):
62 | self._session = session
63 | self.start()
64 |
65 | def run(self):
66 | try:
67 | while not self._coord.should_stop():
68 | self._enqueue_next_group()
69 | except Exception as e:
70 | traceback.print_exc()
71 | self._coord.request_stop(e)
72 |
73 | def _enqueue_next_group(self):
74 | def _get_next_example():
75 | '''Loads a single example (input, mel_target, mag_target, stop_token_target, len(spec)) from the in-memory cache'''
76 |
77 | if self._offset >= len(self.data): # loop over the dataset indefinitely
78 | self._offset = 0
79 | random.shuffle(self.data)
80 |
81 | if self.data[self._offset] is None:
82 | self.load_data(self._offset)
83 | data = self.data[self._offset]
84 | self._offset += 1
85 | return data
86 |
87 | # Read a group of examples:
88 | n = self._hparams.batch_size
89 | r = self._hparams.outputs_per_step
90 | examples = [_get_next_example() for _ in range(n * _batches_per_group)] # group = batch_size * batches_per_group
91 |
92 | # Bucket examples based on similar output sequence length (spec n_frames) for efficiency
93 | # NOTE: sort by output frame length, not by input text length!
94 | examples.sort(key=lambda x: len(x[-1]))
95 | batches = [examples[i:i+n] for i in range(0, len(examples), n)] # split to batches
96 | random.shuffle(batches)
97 |
98 | for batch in batches:
99 | feed_dict = dict(zip(self._placeholders, _prepare_batch(batch, r)))
100 | self._session.run(self._enqueue_op, feed_dict=feed_dict)
101 |
102 | def load_data(self, index):
103 | '''Loads the example at `index` (input, mel_target, mag_target, stop_token_target, len(spec)) from disk into the cache'''
104 |
105 | meta = self._metadata[index] # meta: (id, prds, text)
106 | id, prds, text = meta
107 | if hp.g2p == 'seq':
108 | # NOTE: pad here, then convert to id_seq
109 | seq = phoneme_to_sequence(text_to_phoneme(text + _eos))
110 | prds = [int(d) for d in prds]
111 | elif hp.g2p == 'syl4':
112 | C, V, T, Vx = text_to_phoneme(text) # [[str]]
113 | prds = [int(d) for d in prds]
114 | try:
115 | assert len(C) == len(prds)
116 | except:
117 | breakpoint()
118 |
119 | CVVx, Tx, P = [ ], [ ], [ ]
120 | n_syllable = len(C)
121 | for i in range(n_syllable):
122 | if C[i] != phonodict.vacant:
123 | CVVx.append(C[i]) ; Tx.append(T[i]) ; P.append(0)
124 | if V[i] != phonodict.vacant:
125 | CVVx.append(V[i]) ; Tx.append(T[i]) ; P.append(0)
126 | if Vx[i] != phonodict.vacant:
127 | CVVx.append(Vx[i]) ; Tx.append(T[i]) ; P.append(0)
128 |
129 | CVVx.append(_sep) ; Tx.append(0) ; P.append(prds[i])
130 |
131 | # NOTE: pad here, then convert to id_seq
132 | CVVx = phoneme_to_sequence(CVVx + [_eos]) # see phone table
133 | Tx = [int(t) for t in Tx] + [0] # should be 0 ~ 5
134 | for i in range(len(P) - 2, -1, -1):
135 | if P[i] == 0:
136 | P[i] = P[i + 1]
137 | P = P + [5] # should be 0 ~ 5
138 |
139 | try:
140 | assert len(CVVx) == len(Tx) == len(P)
141 | assert 0 <= min(CVVx) and max(CVVx) < get_vocab_size()
142 | assert 0 <= min(P) and max(P) < hp.n_prds
143 | assert 0 <= min(Tx) and max(Tx) < hp.n_tone
144 | except:
145 | breakpoint()
146 |
147 | seq = np.stack([CVVx, Tx], axis=-1) # [T, 2]
148 | prds = P
149 | else: raise
150 |
151 | text = np.asarray(seq, dtype=np.int32)
152 | prds = np.asarray(prds, dtype=np.int32)
153 | mel_target = np.load(os.path.join(self._datadir, f'mel-{id}.npy')).T # [T, M]
154 | mag_target = np.load(os.path.join(self._datadir, f'mag-{id}.npy')).T # [T, F]
155 | f0_target = np.load(os.path.join(self._datadir, f'f0-{id}.npy'))
156 | c0_target = np.load(os.path.join(self._datadir, f'c0-{id}.npy'))
157 | stop_token_target = np.zeros(mel_target.shape[0]) # NOTE: stop probability starts at 0 on real mel frames; see also `_pad_stop_token_target()`
158 |
159 | mag_target = mag_target[:, 1:] # remove DC
160 | f0_target = quantilize_f0(f0_target)
161 | c0_target = quantilize_c0(c0_target)
162 | try:
163 | assert 0 <= min(f0_target) and max(f0_target) < hp.n_f0_bins
164 | assert 0 <= min(c0_target) and max(c0_target) < hp.n_c0_bins
165 | except:
166 | breakpoint()
167 |
168 | #breakpoint()
169 |
170 | self.data[index] = (text, prds, mel_target, mag_target, f0_target, c0_target, stop_token_target)
171 |
172 | def _prepare_batch(batch, outputs_per_step): # FIXME: what is `outputs_per_step` used for here?
173 | # batch of data, one line per sample (input/id_seq, mel, mag, stop_token, len(spec))
174 | random.shuffle(batch)
175 | # pad within batch
176 | text_lengths = np.asarray([len(x[0]) for x in batch], dtype=np.int32)
177 | if hp.g2p == 'seq': text = _prepare_inputs ([x[0] for x in batch])
178 | elif hp.g2p == 'syl4': text = _prepare_inputs_2d([x[0] for x in batch])
179 | else: raise
180 | prds = _prepare_inputs([x[1] for x in batch])
181 | spec_lengths = np.asarray([len(x[2]) for x in batch], dtype=np.int32)
182 | mel_targets = _prepare_targets([x[2] for x in batch], outputs_per_step)
183 | mag_targets = _prepare_targets([x[3] for x in batch], outputs_per_step)
184 | f0_targets = _prepare_stop_token_targets([x[4] for x in batch], outputs_per_step, 0)
185 | c0_targets = _prepare_stop_token_targets([x[5] for x in batch], outputs_per_step, 0)
186 | stop_token_targets = _prepare_stop_token_targets([x[6] for x in batch], outputs_per_step, 1.0)
187 | return (text_lengths, text, prds, spec_lengths, mel_targets, mag_targets, f0_targets, c0_targets, stop_token_targets)
188 |
189 |
190 | def _prepare_inputs(inputs):
191 | def _pad_input(x:list, length): # pad =0 for text
192 | return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
193 |
194 | max_len = max((len(x) for x in inputs)) # pad to the longest length in this batch; the EOS marker was already appended to seq earlier
195 | return np.stack([_pad_input(x, max_len) for x in inputs])
196 |
197 |
198 | def _prepare_inputs_2d(inputs):
199 | def _pad_input(x:list, length): # pad =0 for text
200 | return np.pad(x, [(0, length - x.shape[0]), (0, 0)], mode='constant', constant_values=_pad)
201 |
202 | max_len = max((len(x) for x in inputs)) # pad to the longest length in this batch; the EOS marker was already appended to seq earlier
203 | return np.stack([_pad_input(x, max_len) for x in inputs])
204 |
205 |
206 | def _prepare_targets(targets, r):
207 | def _pad_target(x, length): # pad with x.min() (i.e. silence) for spec
208 | return np.pad(x, [(0, length - x.shape[0]), (0, 0)], mode='constant', constant_values=x.min())
209 |
210 | max_len = max((len(t) for t in targets)) + 1 # +1 guarantees at least one padded tail frame
211 | max_len = _round_up(max_len, r) # round up to a multiple of r
212 | return np.stack([_pad_target(x, max_len) for x in targets])
213 |
214 |
215 | def _prepare_stop_token_targets(targets, r, pad_val):
216 | def _pad_stop_token_target(x, length):
217 | return np.pad(x, (0, length - x.shape[0]), mode='constant', constant_values=pad_val) # NOTE: padded tail frames get pad_val (stop probability 1 for stop tokens)
218 |
219 | max_len = max((len(t) for t in targets)) + 1 # +1 guarantees at least one padded tail frame
220 | max_len = _round_up(max_len, r) # round up to a multiple of r
221 | return np.stack([_pad_stop_token_target(x, max_len) for x in targets])
222 |
223 |
224 | def _round_up(x, multiple): # round up to a multiple of `multiple`
225 | remainder = x % multiple
226 | return x if remainder == 0 else (x + multiple - remainder)
227 |
--------------------------------------------------------------------------------
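The bucketing in `_enqueue_next_group()` and the padding in `_prepare_targets()` are easier to see on toy numbers; a standalone sketch with hypothetical spec lengths and the default r=5:

```python
def round_up(x, multiple):
    rem = x % multiple
    return x if rem == 0 else x + multiple - rem

r = 5                                        # hp.outputs_per_step
n = 3                                        # toy batch size
spec_lengths = [37, 203, 41, 198, 44, 210]   # hypothetical output frame counts
order = sorted(range(len(spec_lengths)), key=lambda i: spec_lengths[i])
batches = [order[i:i+n] for i in range(0, len(order), n)]
for b in batches:
    max_len = round_up(max(spec_lengths[i] for i in b) + 1, r)  # +1 reserves a padded stop frame
    print(b, '-> padded to', max_len, 'frames')
# [0, 2, 4] -> padded to 45 frames
# [3, 1, 5] -> padded to 215 frames
# similar lengths share a batch, so little compute is wasted on padding
```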
/transtacos/datasets/__skel__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/01/10
4 |
5 | # write your own dataset preprocessor
6 | # here's the template :)
7 |
8 | import os
9 | from typing import List, Tuple
10 |
11 |
12 | def preprocess(args) -> Tuple[List[Tuple], dict, str]:
13 | # wav_dp is a directory path string pointing to the folder containing the *.wav files
14 | wav_dp = os.path.join(args.base_path, 'dataset', 'wavs')
15 |
16 | # metadata is a list containing textual information
17 | metadata = [
18 | # for example, name-text pairs
19 | ('00001', 'this is an example'),
20 | ('00002', 'this is another example'),
21 | ('00003', 'yet another example'),
22 | ]
23 |
24 | # stats is a dictionary of statistics
25 | stats = {
26 | 'min_len_txt': 18,
27 | 'max_len_txt': 23,
28 | 'avg_len_txt': 20.0,
29 | 'min_len_wav': 100,
30 | 'max_len_wav': 200,
31 | 'avg_len_wav': 150.0,
32 | }
33 |
34 | # return them all
35 | return metadata, stats, wav_dp
36 |
--------------------------------------------------------------------------------
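For concreteness, this is what a filled-in dataset module following the skeleton might look like; a hypothetical `datasets/mydataset.py` whose `transcripts.txt` layout and `MyDataset` paths are illustrative assumptions, not files from this repo:

```python
import os
from typing import List, Tuple

def preprocess(args) -> Tuple[List[Tuple], dict, str]:
    # assumed layout: <base_path>/MyDataset/wavs/*.wav + a transcripts.txt
    wav_dp = os.path.join(args.base_path, 'MyDataset', 'wavs')
    # one (name, text) pair per transcript line, e.g. "00001|some sentence"
    metadata = []
    with open(os.path.join(args.base_path, 'MyDataset', 'transcripts.txt'), encoding='utf-8') as fh:
        for line in fh:
            name, text = line.strip().split('|', 1)
            metadata.append((name, text))
    lens = [len(t) for _, t in metadata]
    stats = {'min_len_txt': min(lens), 'max_len_txt': max(lens),
             'avg_len_txt': sum(lens) / len(lens)}
    return metadata, stats, wav_dp
```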
/transtacos/datasets/databaker.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/01/10
4 |
5 | import os
6 | from re import compile as Regex
7 | from collections import defaultdict
8 | from concurrent.futures import ProcessPoolExecutor
9 | from functools import partial
10 | from typing import Dict, List, Tuple
11 |
12 | import numpy as np
13 | from tqdm import tqdm
14 |
15 | import hparam as hp
16 | import audio as A
17 |
18 | DROPOUT_2SIGMA = True
19 |
20 |
21 | # punctuation set collected verbatim from `DataBaker` transcripts
22 | PUNCT_KANJI_REGEX = Regex(r',|。|、|:|;|?|!|(|)|“|”|…|—')
23 |
24 |
25 | def preprocess(args) -> Tuple[List[Tuple], dict, str]:
26 | wav_dp = os.path.join(args.base_dir, 'DataBaker', 'Wave')
27 | out_dp = os.path.join(args.base_dir, args.out_dir)
28 | os.makedirs(out_dp, exist_ok=True)
29 | label_dict = parse_label_file(os.path.join(args.base_dir, 'DataBaker', 'ProsodyLabeling', '000001-010000.txt'))
30 |
31 | executor = ProcessPoolExecutor(max_workers=args.num_workers)
32 | futures = []
33 | for name, feats in label_dict.items():
34 | wav_fp = os.path.join(wav_dp, f'{name}.wav')
35 | futures.append(executor.submit(partial(make_metadata, name, feats, wav_fp, out_dp)))
36 | # (name, prds, text, len_text, len_wav, len_spec, stats)
37 | metadata = [future.result() for future in tqdm(futures)]
38 | metadata = [mt for mt in metadata if mt is not None]
39 |
40 | # only keep samples within the 2-sigma (95.45%) range of the Gaussian distribution
41 | if DROPOUT_2SIGMA:
42 | tlens = np.asarray([mt[-4] for mt in metadata]) # mt[-4] := len_text
43 | tlens_mu, tlens_sigma = tlens.mean(), tlens.std()
44 | tlen_L = tlens_mu - 2 * tlens_sigma
45 | tlen_R = tlens_mu + 2 * tlens_sigma
46 | alens = np.asarray([mt[-2] for mt in metadata]) # mt[-2] := len_spec
47 | alens_mu, alens_sigma = alens.mean(), alens.std()
48 | alen_L = alens_mu - 2 * alens_sigma
49 | alen_R = alens_mu + 2 * alens_sigma
50 |
51 | metadata_filtered = []
52 | for mt in metadata:
53 | if not tlen_L <= mt[-4] <= tlen_R: continue
54 | if not alen_L <= mt[-2] <= alen_R: continue
55 | metadata_filtered.append(mt)
56 | else:
57 | metadata_filtered = metadata
58 |
59 | len_text = np.asarray([mt[-4] for mt in metadata_filtered])
60 | len_wav = np.asarray([mt[-3] for mt in metadata_filtered])
61 | len_spec = np.asarray([mt[-2] for mt in metadata_filtered])
62 | stat_dicts = np.asarray([mt[-1] for mt in metadata_filtered])
63 | stats_agg = defaultdict(list)
64 | for stat in stat_dicts:
65 | for k, v in stat.items():
66 | stats_agg[k].append(v)
67 | stats_agg = { k: np.asarray(v) for k, v in stats_agg.items() }
68 |
69 | stats = {
70 | 'total_examples': len(metadata_filtered),
71 | 'total_hours': len_wav.sum() / hp.sample_rate / (60 * 60),
72 | 'min_len_txt': len_text.min(), # n_pinyins
73 | 'max_len_txt': len_text.max(),
74 | 'avg_len_txt': len_text.mean(),
75 | 'min_len_wav': len_wav.min(), # n_samples
76 | 'max_len_wav': len_wav.max(),
77 | 'avg_len_wav': len_wav.mean(),
78 | 'min_len_spec': len_spec.min(), # n_frames
79 | 'max_len_spec': len_spec.max(),
80 | 'avg_len_spec': len_spec.mean(),
81 | }
82 | for k, v in stats_agg.items():
83 | try:
84 | agg_fn = k[:k.find('_')]
85 | stats[k] = getattr(v, agg_fn)()
86 | except:
87 | print(f'unknown aggregate method for {k}')
88 |
89 | # (name, prds, text)
90 | metadata = [mt[:3] for mt in metadata_filtered]
91 | return metadata, stats, wav_dp
92 |
93 |
94 | def make_metadata(name, feats, wav_fp, out_dp):
95 | if not os.path.exists(wav_fp): return None
96 | text, prds = feats
97 | len_text = len(text.split(' '))
98 | if not len_text == len(prds): return None
99 |
100 | y = A.load_wav(wav_fp)
101 | y = A.trim_silence(y)
102 | y = A.align_wav(y)
103 | len_wav = len(y)
104 |
105 | y_cut = y[:-1]
106 | mag, mel = A.get_specs(y_cut) # [F, T], [M, T]
107 | f0 = A.get_f0 (y_cut) # [T,]
108 | c0 = A.get_c0 (y_cut) # [T,]
109 | len_spec = mel.shape[1]
110 |
111 | assert len_wav == len_spec * hp.hop_length
112 |
113 | np.save(os.path.join(out_dp, f'mel-{name}.npy'), mel, allow_pickle=False)
114 | np.save(os.path.join(out_dp, f'mag-{name}.npy'), mag, allow_pickle=False)
115 | np.save(os.path.join(out_dp, f'f0-{name}.npy'), f0, allow_pickle=False)
116 | np.save(os.path.join(out_dp, f'c0-{name}.npy'), c0, allow_pickle=False)
117 |
118 | stats = {
119 | 'max_mel': mel.max(), 'min_mel': mel.min(),
120 | 'max_mag': mag.max(), 'min_mag': mag.min(),
121 | 'max_f0' : f0 .max(), 'min_f0' : f0 .min(),
122 | 'max_c0' : c0 .max(), 'min_c0' : c0 .min(),
123 | }
124 | return (name, prds, text, len_text, len_wav, len_spec, stats)
125 |
126 |
127 | def parse_label_file(fp) -> Dict[str, Tuple[str, str]]:
128 | '''prosody break levels:
129 | 0: word-internal
130 | 1: end of a word, read continuously
131 | 2: end of a word, with lengthening or a pause
132 | 3: end of a clause
133 | 4: end of the sentence
134 | 5: terminal marker
135 | prosodic hierarchy:
136 | syllables#0 -> words#1 -> phrases#2 -> clauses#3 -> full sentence#4
137 | '''
138 |
139 | r = { }
140 | with open(fp, encoding='utf-8') as fh:
141 | while True:
142 | name_kanji = fh.readline().strip()
143 | if not name_kanji: break
144 |
145 | name, kanji = name_kanji.split('\t') # '002333', '这是个#1例子#2'
146 | pinyin = fh.readline().strip().lower() # 'zhe4 shi4 ge4 li4 zi5'
147 | kanji = PUNCT_KANJI_REGEX.sub('', kanji)
148 |
149 | prodosy = []
150 | for k in kanji:
151 | if k == '#': continue
152 | if k.isdigit():
153 | if prodosy: prodosy[-1] = k
154 | else: prodosy.append(k)
155 | else: prodosy.append('0')
156 | prodosy = ''.join(prodosy) # '00102'
157 |
158 | r[name] = (pinyin, prodosy)
159 | return r
160 |
--------------------------------------------------------------------------------
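The prosody-string construction in `parse_label_file()` above is the subtle part: a `#<digit>` marker upgrades the break level recorded for the previous syllable. A standalone trace of the same logic on the example from the code comments:

```python
PUNCT = ',。、:;?!()“”…—'          # same set as PUNCT_KANJI_REGEX
kanji = '这是个#1例子#2'               # same shape as the DataBaker label shown above
kanji = ''.join(k for k in kanji if k not in PUNCT)

prodosy = []                           # spelling mirrors the module's variable name
for k in kanji:
    if k == '#':
        continue
    if k.isdigit():
        if prodosy: prodosy[-1] = k    # upgrade the previous syllable's break level
        else:       prodosy.append(k)
    else:
        prodosy.append('0')
print(''.join(prodosy))                # '00102': break #1 after 个, #2 after 子
```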
/transtacos/datasets/thchs30.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import partial
3 | from concurrent.futures import ProcessPoolExecutor
4 |
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | import audio
9 |
10 | # NOTE: this is broken, do not use without modification!
11 |
12 |
13 | def preprocess(args):
14 | in_dir = os.path.join(args.base_dir, 'thchs30')
15 | if not os.path.exists(in_dir):
16 | in_dir = os.path.join(args.base_dir, 'data_thchs30')
17 | out_dir = os.path.join(args.base_dir, args.out_dir)
18 | os.makedirs(out_dir, exist_ok=True)
19 |
20 | executor = ProcessPoolExecutor(max_workers=args.num_workers)
21 | futures = []
22 | dp = os.path.join(in_dir, 'data')
23 | for fn in (fn for fn in os.listdir(dp) if fn.endswith('.wav')):
24 | wav_path = os.path.join(dp, fn)
25 | with open(wav_path + '.trn', encoding='utf8') as fh:
26 | fh.readline() # ignore first line (kanji)
27 | text = fh.readline().strip() # use pinyin only
28 | id = os.path.splitext(fn)[0] # '_'
29 | futures.append(executor.submit(partial(_process_utterance, out_dir, id, wav_path, text)))
30 | return [future.result() for future in tqdm(futures)]
31 |
32 |
33 | def _process_utterance(out_dir, id, wav_path, text):
34 | wav = audio.load_wav(wav_path)
35 | wav = audio.trim_silence(wav)
36 |
37 | mag, mel = audio.get_specs(wav)
38 | mag = mag.astype(np.float32)
39 | mel = mel.astype(np.float32)
40 | n_frames = mel.shape[1]
41 |
42 | spec_fn = 'thchs30-spec-%s.npy' % id
43 | np.save(os.path.join(out_dir, spec_fn), mag.T, allow_pickle=False)
44 | mel_fn = 'thchs30-mel-%s.npy' % id
45 | np.save(os.path.join(out_dir, mel_fn), mel.T, allow_pickle=False)
46 |
47 | return (spec_fn, mel_fn, n_frames, text)
48 |
--------------------------------------------------------------------------------
/transtacos/hparam.py:
--------------------------------------------------------------------------------
1 | # Text
2 | g2p = 'syl4' # ['seq', 'syl4']
3 |
4 | # Audio
5 | sample_rate = 22050 # sample rate (Hz) of wav file
6 | n_fft = 2048
7 | win_length = 1024 # :=n_fft//2
8 | hop_length = 256 # :=win_length//4, 11.6ms; one pinyin phone spans ~9 frames on average (min 2 ~ max 20)
9 | n_mel = 80 # number of mel bands (default: 160); 120 should be more reasonable
10 | n_freq = 1025 # number of linear spectrogram bins :=n_fft//2+1
11 | preemphasis = 0.97 # boost high frequencies to balance the EQ
12 | ref_level_db = 20 # highest spectral magnitude considered (virtual 0dB); theoretically 94 in a quiet room, but the noisier the recording the smaller this should be (default: 20)
13 | min_level_db = -100 # lowest spectral magnitude considered, used to clip and compress dynamic range (default: -100)
14 | max_abs_value = 4 # normalize spectral magnitudes into [-max_abs_value, max_abs_value]
15 | trim_below_peak_db = 35 # trim beginning/ending silence parts (default:60)
16 | fmin = 125 # mel filterbank frequency bounds (set 55/3600 for male voices)
17 | fmax = 7600
18 | rf0min = 'D2' # F0 detection bounds (note names)
19 | rf0max = 'D5'
20 |
21 | ## see `DataBaker stats` or `stats.txt` in the preprocessed folder
22 | c0min = 4.6309418394230306e-05
23 | c0max = 0.3751049339771271
24 | f0min = 73.25581359863281
25 | f0max = 595.9459228515625
26 | n_tone = 5+1
27 | n_prds = 5+1
28 | n_c0_bins = 32
29 | n_f0_bins = None # keep None for auto detect using f0min & f0max
30 | n_f0_min = None # as offset
31 | maxlen_text = 128 # for pos_enc, 27 in train set
32 | maxlen_spec = 1024 # for pos_enc, 524 in train set
33 |
34 | # Model
35 | outputs_per_step = 5 # default: 5 (aka reduction factor r); a kind of prosody quantization: smaller r gives finer prosody, but the RNN may fail to remember long sequences
36 | # generally r is set to the average length of a half-phone, which makes vowel junctions behave like VCV
37 | hidden_gauss_std = 1e-5
38 |
39 | embed_depth = 256 # text/prds embed depth
40 | var_embed_depth = 64 # f0/c0 embed depth
41 | posenc_depth = 32 # pos_enc embed depth
42 | txt_use_posenc = True # omitting PE seems necessary for self-attention to learn the diagonal (?)
43 | var_use_posenc = True # PE is required for cross-attention to learn the diagonal
44 | prdsnet_depth = 64
45 | prdsnet_conv_k = 9
46 | embed_dropout = False
47 |
48 | encoder_depth = 256 # aka. inner_repr_depth
49 | encoder_type = 'sa' # ['sa', 'cb']
50 | if encoder_type == 'sa': # like FastSpeech2
51 | encoder_attn_layers = 2 # NOTE: set 4 will lead to nan in loss (grad vanish?)
52 | encoder_attn_nhead = 2
53 | encoder_dropout = False
54 | encoder_fusenet = True
55 | gffw_conv_k = 9
56 | var_prednet_depth = 64
57 | var_prednet_conv_k = 13
58 | if encoder_type == 'cb': # like Tacotron
59 | encoder_conv_K = 16
60 | highway_layers = 4
61 |
62 | decoder_layers = 2 # single layer is not enough
63 | decoder_depth = 512 # default: 1024
64 | attention_depth = 128 # single LSA
65 | prenet_depths = [256] # prenet for decoder RNN, single layer seems enough
66 | decoder_sew_layer = False
67 |
68 | n_mel_low = 42
69 | posnet_depth = 512
70 | posnet_ngroup = 8
71 |
72 | # Training
73 | max_steps = 320000 # force stop train
74 | max_ckpt = 1
75 | batch_size = 16
76 | adam_beta1 = 0.9
77 | adam_beta2 = 0.999 # 0.98
78 | adam_eps = 1e-7
79 | reg_weight = 1e-6 # 1e-8
80 | sim_weight = 1e-5
81 | initial_learning_rate = 0.001
82 | decay_learning_rate = True # decrease learning rate by step, see `models.tacotron._learning_rate_decay()`
83 | tf_method = 'mix' # ['random', 'mix', 'force']
84 | tf_init = 1.0
85 | tf_start_decay = 20000 # default: 20000
86 | tf_decay = 200000 # default: 200000
87 |
88 | # Eval
89 | max_iters = 300 # max RNN iterations; yields at most max_iters*r frames, guarding against endless generation
90 | gl_iters = 30 # griffin_lim algorithm iters (default: 60)
91 | gl_power = 1.2 # Power to raise magnitudes to prior to Griffin-Lim
92 | postprocess = False # see `audio.save_wav()`
93 |
94 | # MISC
95 | randseed = 114514
96 | debug = False
97 |
--------------------------------------------------------------------------------
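A sketch of how the derived hyperparameters above relate; the `n_f0_bins` auto-detect rule shown is an assumption inferred from `audio.quantilize_f0` (one bin per semitone via librosa's `hz_to_midi`), not code copied from the repo:

```python
import math
import librosa as L

sample_rate, n_fft, win_length, hop_length = 22050, 2048, 1024, 256
assert win_length == n_fft // 2 and hop_length == win_length // 4
print('n_freq =', n_fft // 2 + 1)                              # 1025 linear bins
print('frame  = %.1f ms' % (hop_length / sample_rate * 1000))  # ~11.6 ms per hop

f0min, f0max = 73.25581359863281, 595.9459228515625            # from the DataBaker stats
n_f0_min  = math.floor(L.hz_to_midi(f0min))                    # semitone offset of the lowest f0
n_f0_bins = math.ceil(L.hz_to_midi(f0max)) - n_f0_min + 1      # one bin per semitone
print(n_f0_min, n_f0_bins)
```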
/transtacos/models/attention.py:
--------------------------------------------------------------------------------
1 | """Attention file for location based attention (compatible with tensorflow attention wrapper)"""
2 |
3 | import tensorflow as tf
4 | from tensorflow.contrib.seq2seq.python.ops.attention_wrapper import BahdanauAttention
5 | from tensorflow.python.ops import array_ops, variable_scope
6 |
7 |
8 | def _location_sensitive_score(W_query, W_fil, W_keys):
9 | """Impelements Bahdanau-style (cumulative) scoring function.
10 | This attention is described in:
11 | J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Ben-
12 | gio, “Attention-based models for speech recognition,” in Ad-
13 | vances in Neural Information Processing Systems, 2015, pp.
14 | 577–585.
15 |
16 | #############################################################################
17 | hybrid attention (content-based + location-based)
18 | f = F * α_{i-1}
19 | energy = dot(v_a, tanh(W_keys(h_enc) + W_query(h_dec) + W_fil(f) + b_a))
20 | #############################################################################
21 |
22 | Args:
23 | W_query: Tensor, shape '[batch_size, 1, attention_dim]' to compare to location features.
24 | W_fil: previous alignments processed into location features, shape '[batch_size, max_time, attention_dim]'
25 | W_keys: Tensor, shape '[batch_size, max_time, attention_dim]', typically the encoder outputs.
26 | Returns:
27 | A '[batch_size, max_time]' attention score (energy)
28 | """
29 | # Get the number of hidden units from the trailing dimension of keys
30 | dtype = W_query.dtype
31 | num_units = W_keys.shape[-1].value or array_ops.shape(W_keys)[-1]
32 |
33 | v_a = tf.get_variable(
34 | 'attention_variable', shape=[num_units], dtype=dtype,
35 | initializer=tf.contrib.layers.xavier_initializer())
36 | b_a = tf.get_variable(
37 | 'attention_bias', shape=[num_units], dtype=dtype,
38 | initializer=tf.zeros_initializer())
39 |
40 | return tf.reduce_sum(v_a * tf.tanh(W_keys + W_query + W_fil + b_a), [2])
41 |
42 |
43 | class LocationSensitiveAttention(BahdanauAttention):
44 | """Impelements Bahdanau-style (cumulative) scoring function.
45 | Usually referred to as "hybrid" attention (content-based + location-based)
46 | Extends the additive attention described in:
47 | "D. Bahdanau, K. Cho, and Y. Bengio, “Neural machine transla-
48 | tion by jointly learning to align and translate,” in Proceedings
49 | of ICLR, 2015."
50 | to use previous alignments as additional location features.
51 |
52 | This attention is described in:
53 | J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho,
54 | and Y. Bengio, "Attention-based models for speech recognition,"
55 | in Advances in Neural Information Processing Systems, 2015,
56 | pp. 577–585.
57 | """
58 |
59 | def __init__(self,
60 | num_units,
61 | memory,
62 | memory_sequence_length=None,
63 | cumulate_weights=True,
64 | name='LocationSensitiveAttention'):
65 | """Construct the Attention mechanism.
66 | Args:
67 | num_units: The depth of the query mechanism.
68 | memory: The memory to query; usually the output of an RNN encoder. This
69 | tensor should be shaped `[batch_size, max_time, ...]`.
70 | memory_sequence_length (optional): Sequence lengths for the batch entries
71 | in memory. If provided, the memory tensor rows are masked with zeros
72 | for values past the respective sequence lengths. Only relevant if mask_encoder = True.
73 | name: Name to use when creating ops.
74 | """
75 | # Create normalization function
76 | # Setting it to None defaults to using softmax
77 | super(LocationSensitiveAttention, self).__init__(
78 | num_units=num_units,
79 | memory=memory,
80 | memory_sequence_length=memory_sequence_length,
81 | probability_fn=None,
82 | name=name)
83 |
84 | self.location_convolution = tf.layers.Conv1D(filters=32,
85 | kernel_size=(31, ), padding='same', use_bias=True, # original: 31
86 | bias_initializer=tf.zeros_initializer(), name='location_features_convolution')
87 | self.location_layer = tf.layers.Dense(units=num_units, use_bias=False,
88 | dtype=tf.float32, name='location_features_layer')
89 | self._cumulate = cumulate_weights
90 |
91 | def __call__(self, query, state):
92 | """Score the query based on the keys and values.
93 | Args:
94 | query: Tensor of dtype matching `self.values` and shape
95 | `[batch_size, query_depth]`.
96 | state (previous alignments): Tensor of dtype matching `self.values` and shape
97 | `[batch_size, alignments_size]`
98 | (`alignments_size` is memory's `max_time`).
99 | Returns:
100 | alignments: Tensor of dtype matching `self.values` and shape
101 | `[batch_size, alignments_size]` (`alignments_size` is memory's
102 | `max_time`).
103 | """
104 | previous_alignments = state
105 | with variable_scope.variable_scope(None, "Location_Sensitive_Attention", [query]):
106 |
107 | # processed_query shape [batch_size, query_depth] -> [batch_size, attention_dim]
108 | processed_query = self.query_layer(query) if self.query_layer else query
109 | # -> [batch_size, 1, attention_dim]
110 | processed_query = tf.expand_dims(processed_query, 1)
111 |
112 | # processed_location_features shape [batch_size, max_time, attention dimension]
113 | # [batch_size, max_time] -> [batch_size, max_time, 1]
114 | expanded_alignments = tf.expand_dims(previous_alignments, axis=2)
115 | # location features [batch_size, max_time, filters]
116 | f = self.location_convolution(expanded_alignments)
117 | # Projected location features [batch_size, max_time, attention_dim]
118 | processed_location_features = self.location_layer(f)
119 |
120 | # energy shape [batch_size, max_time]
121 | energy = _location_sensitive_score(processed_query, processed_location_features, self.keys)
122 |
123 | # alignments shape = energy shape = [batch_size, max_time]
124 | alignments = self._probability_fn(energy, previous_alignments)
125 |
126 | # Cumulate alignments
127 | if self._cumulate:
128 | next_state = alignments + previous_alignments
129 | else:
130 | next_state = alignments
131 |
132 | return alignments, next_state
133 |
--------------------------------------------------------------------------------
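A NumPy restatement of the hybrid score computed by `_location_sensitive_score()` above, with random stand-in weights, just to make the shapes concrete:

```python
import numpy as np

B, T, A = 2, 7, 128                 # batch, max_time, attention_dim
W_query = np.random.randn(B, 1, A)  # processed decoder query (broadcast over time)
W_fil   = np.random.randn(B, T, A)  # processed cumulative-alignment (location) features
W_keys  = np.random.randn(B, T, A)  # processed encoder outputs
v_a     = np.random.randn(A)
b_a     = np.zeros(A)

# energy[b, t] = sum_d v_a[d] * tanh(W_keys + W_query + W_fil + b_a)[b, t, d]
energy = np.sum(v_a * np.tanh(W_keys + W_query + W_fil + b_a), axis=2)
alignments = np.exp(energy) / np.exp(energy).sum(axis=1, keepdims=True)  # softmax
print(energy.shape, alignments.sum(axis=1))  # (2, 7), each row sums to 1
```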
/transtacos/models/custom_decoder.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow.contrib.seq2seq import Helper, Decoder
6 | from tensorflow.python.framework import ops, tensor_shape
7 | from tensorflow.python.layers import base as layers_base
8 | from tensorflow.python.ops import rnn_cell_impl
9 | from tensorflow.python.util import nest
10 |
11 | import hparam as hp
12 |
13 |
14 | # Adapted from tf.contrib.seq2seq.GreedyEmbeddingHelper
15 | class TacoTestHelper(Helper):
16 | def __init__(self, batch_size, output_dim, r):
17 | with tf.name_scope('TacoTestHelper'):
18 | self._batch_size = batch_size
19 | self._output_dim = output_dim # output_dim==n_mels
20 | self._reduction_factor = r
21 |
22 | @property
23 | def batch_size(self): # IGNORED
24 | return self._batch_size
25 |
26 | @property
27 | def token_output_size(self): # IGNORED
28 | return self._reduction_factor # emits r frames per step
29 |
30 | @property
31 | def sample_ids_shape(self): # IGNORED
32 | return tf.TensorShape([])
33 |
34 | @property
35 | def sample_ids_dtype(self): # IGNORED
36 | return np.int32
37 |
38 | def sample(self, time, outputs, state, name=None): # IGNORED
39 | return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them
40 |
41 | def initialize(self, name=None): # init a 1d vector of False
42 | return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim)) # append
43 |
44 | def next_inputs(self, time, outputs, state, sample_ids, stop_token_preds, name=None):
45 | '''Stop on EOS. Otherwise, pass the last output as the next input and pass through state.'''
46 | with tf.name_scope('TacoTestHelper'):
47 | # A sequence is finished when the stop token probability is > 0.5
48 | # With enough training steps, the model should be able to predict when to stop correctly
49 | # and the use of stop_at_any = True would be recommended. If however the model didn't
50 | # learn to stop correctly yet, (stops too soon) one could choose to use the safer option
51 | # to get a correct synthesis
52 | # it is hard to detect end-of-generation by spotting a single silent mel frame, so a stop_token vector with one entry per mel frame is used instead
53 | # once a value close to 1.0 appears in the vector (in theory near its tail), generation is considered finished; see also `datafeeder._pad_stop_token_target()`
54 | # NOTE: for fault tolerance, do not check only stop_token_preds[-1]
55 | finished = tf.reduce_any(tf.cast(tf.round(stop_token_preds), tf.bool))
56 |
57 | # Feed last output frame as next input. outputs is [N, output_dim * r]
58 | next_inputs = outputs[:, -self._output_dim:] # take last frame of a frame group
59 | return (finished, next_inputs, state)
60 |
61 |
62 | class TacoTrainingHelper(Helper):
63 | def __init__(self, batch_size, targets, output_dim, r, global_step):
64 | # inputs is [N, T_in], targets is [N, T_out, D]
65 | with tf.name_scope('TacoTrainingHelper'):
66 | self._batch_size = batch_size
67 | self._output_dim = output_dim # =n_mels
68 | self._reduction_factor = r
69 | self._ratio = None
70 | self.global_step = global_step
71 |
72 | # Feed every r-th target frame as input
73 | self._targets = targets[:, r-1::r, :] # the last frame of each frame group, i.e. every r-th frame
74 |
75 | # Use full length for every target because we don't want to mask the padding frames
76 | num_steps = tf.shape(self._targets)[1] # =max_timesteps # FIXME: why can't it stop early?
77 | self._lengths = tf.tile([num_steps], [self._batch_size])
78 |
79 | @property
80 | def batch_size(self): # IGNORED
81 | return self._batch_size
82 |
83 | @property
84 | def token_output_size(self): # IGNORED
85 | return self._reduction_factor
86 |
87 | @property
88 | def sample_ids_shape(self): # IGNORED
89 | return tf.TensorShape([])
90 |
91 | @property
92 | def sample_ids_dtype(self): # IGNORED
93 | return np.int32
94 |
95 | def sample(self, time, outputs, state, name=None): # IGNORED
96 | return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them
97 |
98 | def initialize(self, name=None):
99 | self._ratio = _teacher_forcing_ratio_decay(hp.tf_init, self.global_step)
100 | return (tf.tile([False], [self._batch_size]), _go_frames(self._batch_size, self._output_dim)) # append
101 |
102 | def next_inputs(self, time, outputs, state, sample_ids, stop_token_preds, name='TacoTrainingHelper'):
103 | with tf.name_scope(name):
104 | finished = (time + 1 >= self._lengths) # during training, decoding ends at the last mel frame, i.e. this batch's max frame-group length
105 |
106 | if hp.tf_method == 'force':
107 | next_inputs = self._targets[:, time, :]
108 | elif hp.tf_method == 'random':
109 | next_inputs = tf.cond(tf.less(tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32), self._ratio),
110 | lambda: self._targets[:, time, :],
111 | lambda: outputs[:, -self._output_dim:])
112 | elif hp.tf_method == 'mix':
113 | next_inputs = self._ratio * self._targets[:, time, :] + (1 - self._ratio) * outputs[:, -self._output_dim:]
114 | else: raise ValueError
115 |
116 | return (finished, next_inputs, state)
117 |
118 |
119 | def _go_frames(batch_size, output_dim):
120 | '''Returns all-zero frames for a given batch size and output dimension'''
121 | return tf.tile([[0.0]], [batch_size, output_dim])
122 |
123 |
124 | def _teacher_forcing_ratio_decay(init_tfr, global_step):
125 | #################################################################
126 | # Narrow Cosine Decay:
127 |
128 | # Phase 1: tfr = 1
129 | # decay only starts after `hp.tf_start_decay` steps (20k by default)
130 |
131 | # Phase 2: tfr in [0, 1]
132 | # decay reaches its minimal value at step tf_start_decay + tf_decay (~220k with the defaults)
133 |
134 | # Phase 3: tfr = 0
135 | # clipped at the minimal teacher forcing ratio value (step >~ 220k)
136 | #################################################################
137 | # Compute natural cosine decay
138 | tfr = tf.train.cosine_decay(init_tfr,
139 | global_step=global_step - hp.tf_start_decay, # tfr = 1 until step tf_start_decay (20k by default)
140 | decay_steps=hp.tf_decay, # tfr reaches 0 after tf_decay further steps (200k by default)
141 | alpha=0., # tfr = 0% of init_tfr as final value
142 | name='tfr_cosine_decay')
143 |
144 | # force teacher forcing ratio to take initial value when global step < start decay step.
145 | # NOTE: narrow_tfr = global_step < tf_start_decay ? init_tfr : tfr
146 | narrow_tfr = tf.cond(
147 | tf.less(global_step, tf.convert_to_tensor(hp.tf_start_decay)), # original: 20000
148 | lambda: tf.convert_to_tensor(init_tfr),
149 | lambda: tfr)
150 |
151 | return narrow_tfr
152 |
153 |
154 | class CustomDecoderOutput(
155 | namedtuple("CustomDecoderOutput", ("rnn_output", "token_output", "sample_id"))):
156 | pass
157 |
158 |
159 | class CustomDecoder(Decoder):
160 | """Custom sampling decoder.
161 |
162 | Allows for stop token prediction at inference time
163 | and returns equivalent loss in training time.
164 |
165 | Note:
166 | Only use this decoder with Tacotron 2 as it only accepts tacotron custom helpers
167 | """
168 |
169 | def __init__(self, cell, helper, initial_state, output_layer=None):
170 | """Initialize CustomDecoder.
171 | Args:
172 | cell: An `RNNCell` instance.
173 | helper: A `Helper` instance.
174 | initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
175 | The initial state of the RNNCell.
176 | output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
177 | `tf.layers.Dense`. Optional layer to apply to the RNN output prior
178 | to storing the result or sampling.
179 | Raises:
180 | TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
181 | """
182 | rnn_cell_impl.assert_like_rnncell(type(cell), cell)
183 | if not isinstance(helper, Helper):
184 | raise TypeError("helper must be a Helper, received: %s" % type(helper))
185 | if (output_layer is not None and not isinstance(output_layer, layers_base.Layer)):
186 | raise TypeError("output_layer must be a Layer, received: %s" % type(output_layer))
187 | self._cell = cell
188 | self._helper = helper
189 | self._initial_state = initial_state
190 | self._output_layer = output_layer
191 |
192 | @property
193 | def batch_size(self):
194 | return self._helper.batch_size
195 |
196 | def _rnn_output_size(self):
197 | size = self._cell.output_size
198 | if self._output_layer is None:
199 | return size
200 | else:
201 | # To use layer's compute_output_shape, we need to convert the
202 | # RNNCell's output_size entries into shapes with an unknown
203 | # batch size. We then pass this through the layer's
204 | # compute_output_shape and read off all but the first (batch)
205 | # dimensions to get the output size of the rnn with the layer
206 | # applied to the top.
207 | output_shape_with_unknown_batch = nest.map_structure(
208 | lambda s: tensor_shape.TensorShape([None]).concatenate(s),
209 | size)
210 | layer_output_shape = self._output_layer._compute_output_shape( # pylint: disable=protected-access
211 | output_shape_with_unknown_batch)
212 | return nest.map_structure(lambda s: s[1:], layer_output_shape)
213 |
214 | @property
215 | def output_size(self):
216 | # Return the cell output and the id
217 | return CustomDecoderOutput(
218 | rnn_output=self._rnn_output_size(),
219 | token_output=self._helper.token_output_size,
220 | sample_id=self._helper.sample_ids_shape)
221 |
222 | @property
223 | def output_dtype(self):
224 | # Assume the dtype of the cell is the output_size structure
225 | # containing the input_state's first component's dtype.
226 | # Return that structure and the sample_ids_dtype from the helper.
227 | dtype = nest.flatten(self._initial_state)[0].dtype
228 | return CustomDecoderOutput(
229 | nest.map_structure(lambda _: dtype, self._rnn_output_size()),
230 | tf.float32,
231 | self._helper.sample_ids_dtype)
232 |
233 | def initialize(self, name=None):
234 | """Initialize the decoder.
235 | Args:
236 | name: Name scope for any created operations.
237 | Returns:
238 | `(finished, first_inputs, initial_state)`.
239 | """
240 | return self._helper.initialize() + (self._initial_state,)
241 |
242 | def step(self, time, inputs, state, name=None):
243 | """Perform a custom decoding step.
244 | Enables dynamic prediction
245 | Args:
246 | time: scalar `int32` tensor.
247 | inputs: A (structure of) input tensors.
248 | state: A (structure of) state tensors and TensorArrays.
249 | name: Name scope for any created operations.
250 | Returns:
251 | `(outputs, next_state, next_inputs, finished)`.
252 | """
253 | with ops.name_scope(name, "CustomDecoderStep", (time, inputs, state)):
254 | # Call the output projection wrapper cell
255 | (cell_outputs, stop_token), cell_state = self._cell(inputs, state)
256 |
257 | # apply output_layer (if it exists)
258 | if self._output_layer is not None:
259 | cell_outputs = self._output_layer(cell_outputs)
260 | sample_ids = self._helper.sample(
261 | time=time, outputs=cell_outputs, state=cell_state)
262 |
263 | (finished, next_inputs, next_state) = self._helper.next_inputs(
264 | time=time,
265 | outputs=cell_outputs,
266 | state=cell_state,
267 | sample_ids=sample_ids,
268 | stop_token_preds=stop_token)
269 |
270 | outputs = CustomDecoderOutput(cell_outputs, stop_token, sample_ids)
271 | return (outputs, next_state, next_inputs, finished)
272 |
--------------------------------------------------------------------------------
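The schedule produced by `_teacher_forcing_ratio_decay()` can be checked numerically; a sketch reproducing TF's `cosine_decay` with `alpha=0` and the hparam.py defaults (tf_init=1.0, tf_start_decay=20000, tf_decay=200000):

```python
import math

def tfr_at(step, init_tfr=1.0, start=20000, decay_steps=200000):
    if step < start:                                      # phase 1: full teacher forcing
        return init_tfr
    t = min((step - start) / decay_steps, 1.0)
    return init_tfr * 0.5 * (1 + math.cos(math.pi * t))  # phases 2-3: cosine decay to 0

for s in [0, 20000, 70000, 120000, 220000, 300000]:
    print(s, round(tfr_at(s), 3))
# 1.0 up to 20k, ~0.854 at 70k, 0.5 at 120k, 0.0 from 220k on
```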
/transtacos/models/modules.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from tensorflow.contrib.rnn import GRUCell
4 | from tensorflow.keras.backend import batch_dot
5 |
6 | import hparam as hp
7 |
8 | REUSE = False
9 |
10 |
11 | ''' below is for Tacotron compatibility '''
12 |
13 | def prenet(inputs, layer_sizes, is_training, scope='prenet'):
14 | x = inputs
15 | drop_rate = 0.5 if is_training else 0.0 # NOTE: dropout only during training
16 | with tf.variable_scope(scope):
17 | # NOTE: a chain of dense + dropout layers
18 | for i, size in enumerate(layer_sizes):
19 | dense = tf.layers.dense(x, units=size, activation=tf.nn.relu, name='dense_%d' % (i+1))
20 | x = tf.layers.dropout(dense, rate=drop_rate, training=is_training, name='dropout_%d' % (i+1))
21 | return x
22 |
23 |
24 | def conv1d(inputs, k, filters, activation, is_training, scope='conv1d'):
25 | with tf.variable_scope(scope):
26 | conv = tf.layers.conv1d(
27 | inputs,
28 | filters=filters,
29 | kernel_size=k,
30 | activation=None,
31 | padding='same')
32 | bn = tf.layers.batch_normalization(conv, training=is_training)
33 | return activation(bn)
34 |
35 |
36 | def highwaynet(inputs, depth, scope='highwaynet'):
37 | with tf.variable_scope(scope):
38 | H = tf.layers.dense(
39 | inputs,
40 | units=depth,
41 | activation=tf.nn.relu,
42 | name='H')
43 | T = tf.layers.dense(
44 | inputs,
45 | units=depth,
46 | activation=tf.nn.sigmoid,
47 | name='T',
48 | bias_initializer=tf.constant_initializer(-1.0))
49 | return H * T + inputs * (1.0 - T)
50 |
51 |
52 | def cbhg(inputs, input_lengths, K, proj_dims, depth, is_training, scope='cbhg'):
53 | with tf.variable_scope(scope):
54 | with tf.variable_scope('conv_bank'):
55 | # Convolution bank: concatenate on the last axis to connect channels from all convolutions
56 | conv = tf.concat(
57 | [conv1d(inputs, k+1, depth//2, tf.nn.relu, is_training, 'conv1d_%d' % (k+1)) for k in range(K)],
58 | axis=-1)
59 |
60 | # Maxpooling:
61 | conv = tf.layers.max_pooling1d(
62 | conv,
63 | pool_size=2,
64 | strides=1,
65 | padding='same')
66 |
67 | # Two projection layers: reduce depth
68 | proj = conv1d(conv, 3, proj_dims[0], tf.nn.relu, is_training, 'proj_1') # depth: 896(7*128) -> 128
69 | proj = conv1d(proj, 3, proj_dims[1], lambda _:_, is_training, 'proj_2') # depth: 128 -> 256
70 |
71 | # Residual connection:
72 | # now we merge `acoustics (phoneme)` with `prosody (text context)`
73 | highway_input = inputs + proj
74 |
75 | # Handle dimensionality mismatch:
76 | if highway_input.shape[-1] != depth:
77 | highway_input = tf.layers.dense(highway_input, depth)
78 | # 4-layer HighwayNet:
79 | for i in range(hp.highway_layers):
80 | highway_input = highwaynet(highway_input, depth, 'highway_%d' % (i+1))
81 |
82 | # Bidirectional RNN
83 | outputs, states = tf.nn.bidirectional_dynamic_rnn(
84 | GRUCell(depth//2),
85 | GRUCell(depth//2),
86 | highway_input,
87 | sequence_length=input_lengths,
88 | dtype=tf.float32)
89 |
90 | return tf.concat(outputs, axis=-1) # Concat forward and backward
91 |
92 |
93 | ''' below is my stuff '''
94 |
95 | def gaussian_noise(x, is_training):
96 | if hp.hidden_gauss_std:
97 | x = tf.keras.layers.GaussianNoise(hp.hidden_gauss_std)(x, training=is_training)
98 | return x
99 |
100 |
101 | def conv_stack(x, n_layers, k, d_in, d_out, activation=tf.nn.relu, scope='conv_stack'):
102 | with tf.variable_scope(scope, reuse=REUSE):
103 | for i in range(n_layers-1):
104 | x = tf.layers.conv1d(x, d_in, k, padding='same', name=f'conv{i+1}')
105 | x = activation(x)
106 | x = tf.layers.conv1d(x, d_out, k, padding='same', name=f'conv{n_layers}')
107 | return x
108 |
109 |
110 | def dot_attn(x, y, mask, attn_dim, scope='dot_attn'):
111 | with tf.variable_scope(scope, reuse=REUSE):
112 | # [B, N, A]
113 | q = tf.layers.dense(x, attn_dim, name='q')
114 | # [B, T, A]
115 | k = tf.layers.dense(y, attn_dim, name='k')
116 | v = tf.layers.dense(y, attn_dim, name='v')
117 |
118 | # [B, N, T]
119 | e = tf.matmul(q, k, transpose_b=True)
120 | e = e * mask + (1 - mask) * -1e8 # push masked energies toward -inf
121 | e = e / tf.sqrt(tf.cast(hp.encoder_depth, tf.float32))
122 | sc = tf.nn.softmax(e, axis=-1)
123 |
124 | # [B, N, A]
125 | r = tf.matmul(sc, v)
126 |
127 | return r, sc
128 |
129 |
130 | def GLU(inputs, depth, k=7, activation=None, scope='GLU'):
131 | with tf.variable_scope(scope):
132 | conv = tf.layers.conv1d(
133 | inputs,
134 | filters=depth*2,
135 | kernel_size=k,
136 | activation=activation,
137 | padding='same',
138 | name='conv')
139 |
140 | x, gate = tf.split(conv, 2, axis=-1)
141 | if activation: x = activation(x)
142 | gate = tf.nn.sigmoid(gate)
143 |
144 | return x * gate
145 |
146 |
147 | def gffw(x, depth, scope='gffw'):
148 | with tf.variable_scope(scope, reuse=REUSE):
149 | o = GLU(x, depth, k=hp.gffw_conv_k, activation=tf.nn.leaky_relu, scope='GLU')
150 | o = tf.layers.conv1d(o, depth, 1, padding='same', name='conv_pointwise')
151 | return o
152 |
153 |
154 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
155 |
156 | def cal_angle(position, hid_idx):
157 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
158 |
159 | def get_posi_angle_vec(position):
160 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
161 |
162 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
163 |
164 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
165 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
166 |
167 | # zero vector for padding dimension
168 | if padding_idx is not None: sinusoid_table[padding_idx] = 0.
169 |
170 | # [1, K, D]
171 | return tf.expand_dims(tf.convert_to_tensor(sinusoid_table, dtype=tf.float32), axis=0)
172 |
173 |
174 | def get_attn_mask(xlen, max_xlen, ylen=None, max_ylen=None):
175 | if ylen is None and max_ylen is None: ylen, max_ylen = xlen, max_xlen
176 | x_unary = tf.expand_dims(tf.sequence_mask(xlen, max_xlen, dtype=tf.float32), 1)
177 | y_unary = tf.expand_dims(tf.sequence_mask(ylen, max_ylen, dtype=tf.float32), 1)
178 | mask = batch_dot(tf.transpose(x_unary, [0, 2, 1]), y_unary)
179 | return mask
180 |
181 |
182 | def encoder_sa(x, x_len, f0, c0, y_len, is_training, scope='encoder'):
183 | depth = hp.encoder_depth
184 |
185 | with tf.variable_scope(scope, reuse=REUSE):
186 | # prenet
187 | if hp.txt_use_posenc:
188 | x = tf.layers.dense(x, depth, None, name='prenet')
189 | if hp.encoder_dropout:
190 | x = tf.layers.dropout(x, rate=0.2, training=is_training, name='dropout')
191 |
192 | '''multi-head self-attn: acoustics => global prosody'''
193 | slf_attns = []
194 | max_xlen = tf.shape(x)[-2] # dynamic padded length N
195 | slf_mask = get_attn_mask(x_len, max_xlen)
196 | # D: 256 -> 64*4 -> 256
197 | for i in range(hp.encoder_attn_layers):
198 | # multi-head sa
199 | rs, attns = [], []
200 | #x = tf.keras.layers.LayerNormalization()(x) # pre-norm
201 | for h in range(hp.encoder_attn_nhead):
202 | r, sc = dot_attn(x, x, slf_mask, depth // hp.encoder_attn_nhead, scope=f'sa_{i}_{h}')
203 | rs.append(r) ; attns.append(sc)
204 | slf_attns.append(attns)
205 |
206 | # combine multi-head
207 | sa = tf.layers.dense(tf.concat(rs, axis=-1), depth, name=f'proj_sa_{i}')
208 | if hp.encoder_dropout:
209 | sa = tf.layers.dropout(sa, rate=0.2, training=is_training, name='dropout')
210 |
211 | # transform (gffw)
212 | x = x + gffw(x + sa, depth, scope=f'gffw_sa_{i}')
213 | #x = tf.keras.layers.LayerNormalization()(x) # post-norm
214 |
215 | ''' fusenet '''
216 | crx_attns = []
217 | f0_r = c0_r = f0_r_pred = c0_r_pred = 0.0
218 | if hp.encoder_fusenet:
219 | f0_r_pred = conv_stack(x, 2, hp.var_prednet_conv_k, hp.var_prednet_depth, hp.var_prednet_depth, activation=tf.nn.leaky_relu, scope='ca_f0_prednet')
220 | c0_r_pred = conv_stack(x, 2, hp.var_prednet_conv_k, hp.var_prednet_depth, hp.var_prednet_depth, activation=tf.nn.leaky_relu, scope='ca_c0_prednet')
221 | if is_training:
222 | max_ylen = tf.shape(f0)[-2] # dynamic padded length T
223 | crx_mask = get_attn_mask(x_len, max_xlen, y_len, max_ylen)
224 |
225 | # [B, N, 256] cross attn [B, T, 64]
226 | f0_r, sc = dot_attn(x, f0, crx_mask, hp.var_prednet_depth, scope='ca_f0')
227 | crx_attns.append(sc)
228 | c0_r, sc = dot_attn(x, c0, crx_mask, hp.var_prednet_depth, scope='ca_c0')
229 | crx_attns.append(sc)
230 |
231 | # combine f0 & c0
232 | if is_training: f = tf.layers.dense(tf.concat([f0_r, c0_r], axis=-1), depth, name='proj_ca')
233 | else: f = tf.layers.dense(tf.concat([f0_r_pred, c0_r_pred], axis=-1), depth, name='proj_ca')
234 | if hp.encoder_dropout:
235 | f = tf.layers.dropout(f, rate=0.2, training=is_training, name='dropout')
236 |
237 | # combine & transform (gffw)
238 | x = x + gffw(tf.concat([x, f], axis=-1), depth, scope=f'gffw_ca')
239 |
240 | return x, (slf_attns, crx_attns), ((f0_r, f0_r_pred), (c0_r, c0_r_pred))
241 |
--------------------------------------------------------------------------------
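The table built by `get_sinusoid_encoding_table()` above can be written vectorized; a minimal NumPy equivalent showing the sin/cos interleaving:

```python
import numpy as np

def sinusoid_table(n_position, d_hid):
    pos = np.arange(n_position)[:, None]                   # [K, 1]
    dim = np.arange(d_hid)[None, :]                        # [1, D]
    angle = pos / np.power(10000, 2 * (dim // 2) / d_hid)  # [K, D]
    angle[:, 0::2] = np.sin(angle[:, 0::2])                # dim 2i
    angle[:, 1::2] = np.cos(angle[:, 1::2])                # dim 2i+1
    return angle

tbl = sinusoid_table(8, 4)
print(tbl.shape)   # (8, 4); the TF version adds a leading batch axis
print(tbl[0])      # position 0 -> [sin 0, cos 0, sin 0, cos 0] = [0, 1, 0, 1]
```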
/transtacos/models/rnn_wrappers.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 | from tensorflow.contrib.rnn import RNNCell
6 | from tensorflow.python.framework import ops, tensor_shape
7 | from tensorflow.python.ops import array_ops, check_ops, rnn_cell_impl, tensor_array_ops
8 | from tensorflow.python.util import nest
9 |
10 | from .modules import prenet
11 | import hparam as hp
12 |
13 |
14 | class FrameProjection:
15 | """Projection layer to r * n_mel dimensions or n_mel dimensions"""
16 | def __init__(self, shape=hp.n_mel, activation=None, scope='linear_projection'):
17 | """
18 | Args:
19 | shape: integer, dimensionality of output space (r*n_mels for decoder or n_mels for postnet)
20 | activation: callable, activation function
21 | scope: FrameProjection scope.
22 | """
23 | super(FrameProjection, self).__init__()
24 |
25 | self.shape = shape
26 | self.activation = activation
27 | self.scope = scope
28 | self.dense = tf.layers.Dense(units=shape, activation=activation, name=f'projection_{self.scope}')
29 |
30 | def __call__(self, inputs):
31 | with tf.variable_scope(self.scope):
32 | # If activation==None, this returns a simple Linear projection
33 | # else the projection will be passed through an activation function
34 | return self.dense(inputs)
35 |
36 |
37 | class StopProjection:
38 | """Projection to a scalar and through a sigmoid activation"""
39 | def __init__(self, is_training, shape=1, activation=tf.nn.sigmoid, scope='stop_token_projection'):
40 | """
41 | Args:
42 | is_training: Boolean, to control the use of the sigmoid function, which is useless
43 | during training since it is integrated inside the sigmoid_cross_entropy loss
44 | shape: integer, dimensionality of output space. Defaults to 1 (scalar)
45 | activation: callable, activation function. only used during inference
46 | scope: StopProjection scope.
47 | """
48 | super(StopProjection, self).__init__()
49 |
50 | self.is_training = is_training
51 | self.shape = shape
52 | self.activation = activation
53 | self.scope = scope
54 |
55 | def __call__(self, inputs):
56 | with tf.variable_scope(self.scope):
57 | output = tf.layers.dense(inputs, units=self.shape, activation=None, name=f'projection_{self.scope}')
58 | #During training, don't use activation as it is integrated inside the sigmoid_cross_entropy loss function
59 | return output if self.is_training else self.activation(output)
60 |
61 |
62 | class TacotronDecoderCellState(
63 | namedtuple("TacotronDecoderCellState",
64 | ("cell_state", "attention", "time", "alignments",
65 | "alignment_history"))):
66 | """`namedtuple` storing the state of a `TacotronDecoderCell`.
67 | Contains:
68 | - `cell_state`: The state of the wrapped `RNNCell` at the previous time
69 | step.
70 | - `attention`: The attention emitted at the previous time step.
71 | - `time`: int32 scalar containing the current time step.
72 | - `alignments`: A single or tuple of `Tensor`(s) containing the alignments
73 | emitted at the previous time step for each attention mechanism.
74 | - `alignment_history`: a single or tuple of `TensorArray`(s)
75 | containing alignment matrices from all time steps for each attention
76 | mechanism. Call `stack()` on each to convert to a `Tensor`.
77 | """
78 | def replace(self, **kwargs):
79 | """Clones the current state while overwriting components provided by kwargs.
80 | """
81 | return super(TacotronDecoderCellState, self)._replace(**kwargs)
82 |
83 |
84 | class TacotronDecoderWrapper(RNNCell):
85 | """Tactron 2 Decoder Cell
86 | Decodes encoder output and previous mel frames into next r frames
87 |
88 | Decoder Step i:
89 | 1) Prenet to compress last output information
90 | 2) Concat compressed inputs with previous context vector (input feeding) *
91 | 3) Decoder RNN (actual decoding) to predict current state s_{i} *
92 | 4) Compute new context vector c_{i} based on s_{i} and a cumulative sum of previous alignments *
93 | 5) Predict new output y_{i} using s_{i} and c_{i} (concatenated)
94 | 6) Predict new stop token ys_{i} using s_{i} and c_{i} (concatenated)
95 |
96 | * : This is typically taking a vanilla LSTM, wrapping it using tensorflow's attention wrapper,
97 | and wrap that with the prenet before doing an input feeding, and with the prediction layer
98 | that uses RNN states to project on output space. Actions marked with (*) can be replaced with
99 | tensorflow's attention wrapper call if it was using cumulative alignments instead of previous alignments only.
100 | """
101 |
102 | def __init__(self, rnn_cell, attention_mechanism, frame_projection, stop_projection, is_training):
103 | """Initialize decoder parameters
104 |
105 | Args:
106 | prenet: A tensorflow fully connected layer acting as the decoder pre-net
107 | attention_mechanism: A _BaseAttentionMechanism instance, useful to
108 | learn encoder-decoder alignments
109 | rnn_cell: Instance of RNNCell, main body of the decoder
110 | frame_projection: tensorflow fully connected layer with r * n_mel output units
111 | stop_projection: tensorflow fully connected layer, expected to project to a scalar
112 | and through a sigmoid activation
113 | mask_finished: Boolean, whether to mask decoder frames after the <stop_token>
114 | """
115 | super(TacotronDecoderWrapper, self).__init__()
116 | #Initialize decoder layers
117 | self._training = is_training
118 | self._attention_mechanism = attention_mechanism
119 | self._cell = rnn_cell
120 | self._frame_projection = frame_projection
121 | self._stop_projection = stop_projection
122 | self._attention_layer_size = self._attention_mechanism.values.get_shape()[-1].value
123 |
124 | def _batch_size_checks(self, batch_size, error_message):
125 | return [check_ops.assert_equal(batch_size,
126 | self._attention_mechanism.batch_size,
127 | message=error_message)]
128 |
129 | @property
130 | def output_size(self):
131 | return self._frame_projection.shape
132 |
133 | # @property
134 | def state_size(self):
135 | """The `state_size` property of `TacotronDecoderWrapper`.
136 |
137 | Returns:
138 | An `TacotronDecoderWrapper` tuple containing shapes used by this object.
139 | """
140 | return TacotronDecoderCellState(
141 | cell_state=self._cell._cell.state_size,
142 | time=tensor_shape.TensorShape([]),
143 | attention=self._attention_layer_size,
144 | alignments=self._attention_mechanism.alignments_size,
145 | alignment_history=())
146 |
147 | def zero_state(self, batch_size, dtype):
148 | """Return an initial (zero) state tuple for this `AttentionWrapper`.
149 |
150 | Args:
151 | batch_size: `0D` integer tensor: the batch size.
152 | dtype: The internal state data type.
153 | Returns:
154 | An `TacotronDecoderCellState` tuple containing zeroed out tensors and,
155 | possibly, empty `TensorArray` objects.
156 | Raises:
157 | ValueError: (or, possibly at runtime, InvalidArgument), if
158 | `batch_size` does not match the output size of the encoder passed
159 | to the wrapper object at initialization time.
160 | """
161 | with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
162 | cell_state = self._cell.zero_state(batch_size, dtype)
163 | error_message = (
164 |         "When calling zero_state of TacotronDecoderWrapper %s: " % self._base_name +
165 | "Non-matching batch sizes between the memory "
166 | "(encoder output) and the requested batch size.")
167 | with ops.control_dependencies(
168 | self._batch_size_checks(batch_size, error_message)):
169 | cell_state = nest.map_structure(
170 | lambda s: array_ops.identity(s, name="checked_cell_state"),
171 | cell_state)
172 | return TacotronDecoderCellState(
173 | cell_state=cell_state,
174 | time=array_ops.zeros([], dtype=tf.int32),
175 | attention=rnn_cell_impl._zero_state_tensors(self._attention_layer_size, batch_size, dtype),
176 | alignments=self._attention_mechanism.initial_alignments(batch_size, dtype),
177 | alignment_history=tensor_array_ops.TensorArray(dtype=dtype, size=0,
178 | dynamic_size=True))
179 |
180 |
181 | def __call__(self, inputs, state):
182 | #Information bottleneck (essential for learning attention)
183 | # just adjust n_mel to encoder_depth (?)
184 | prenet_output = prenet(inputs, hp.prenet_depths, self._training, scope='decoder_prenet')
185 |
186 | #Concat context vector and prenet output to form RNN cells input (input feeding)
187 | rnn_input = tf.concat([prenet_output, state.attention], axis=-1)
188 |
189 | #Unidirectional RNN layers
190 | rnn_output, next_cell_state = self._cell(tf.layers.dense(rnn_input, hp.decoder_depth), state.cell_state)
191 |
192 | #Compute the attention (context) vector and alignments using
193 | #the new decoder cell hidden state as query vector
194 | #and cumulative alignments to extract location features
195 | #The choice of the new cell hidden state (s_{i}) of the last
196 | #decoder RNN Cell is based on Luong et Al. (2015):
197 | #https://arxiv.org/pdf/1508.04025.pdf
198 | previous_alignments = state.alignments
199 | previous_alignment_history = state.alignment_history
200 |
201 | alignments, cumulated_alignments = self._attention_mechanism(rnn_output, state=previous_alignments)
202 |
203 | # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
204 | expanded_alignments = array_ops.expand_dims(alignments, 1)
205 | # Context is the inner product of alignments and values along the
206 | # memory time dimension.
207 | # alignments shape is
208 | # [batch_size, 1, memory_time]
209 | # attention_mechanism.values shape is
210 | # [batch_size, memory_time, memory_size]
211 | # the batched matmul is over memory_time, so the output shape is
212 | # [batch_size, 1, memory_size].
213 | # we then squeeze out the singleton dim.
214 | context = tf.matmul(expanded_alignments, self._attention_mechanism.values)
215 | context = tf.squeeze(context, [1]) # V
216 |
217 | #Concat RNN outputs and context vector to form projections inputs
218 | projections_input = tf.concat([rnn_output, context], axis=-1)
219 |
220 |     #Compute predicted frames and predicted stop tokens
221 | cell_outputs = self._frame_projection(projections_input)
222 | stop_tokens = self._stop_projection (projections_input)
223 |
224 | #Save alignment history
225 | alignment_history = previous_alignment_history.write(state.time, alignments)
226 |
227 | #Prepare next decoder state
228 | next_state = TacotronDecoderCellState(
229 | time=state.time + 1,
230 | cell_state=next_cell_state,
231 | attention=context,
232 | alignments=cumulated_alignments,
233 | alignment_history=alignment_history)
234 |
235 | return (cell_outputs, stop_tokens), next_state
236 |
--------------------------------------------------------------------------------
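
A minimal NumPy sketch of step 4 in the decoder docstring above: the context vector is the alignment-weighted sum of the encoder memory, exactly as the batched matmul in `__call__` computes it (all shapes here are illustrative):

import numpy as np

B, T_in, D = 2, 13, 256                          # batch, encoder steps, memory depth
alignments = np.random.rand(B, T_in)             # attention weights from the mechanism
alignments /= alignments.sum(-1, keepdims=True)  # rows sum to 1
memory = np.random.randn(B, T_in, D)             # encoder outputs (attention values)

# [B, 1, T_in] @ [B, T_in, D] -> [B, 1, D]; squeeze the singleton time dim
context = np.matmul(alignments[:, None, :], memory).squeeze(1)
assert context.shape == (B, D)
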
/transtacos/models/tacotron.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, ResidualWrapper
3 |
4 | from .modules import *
5 | from .custom_decoder import *
6 | from .rnn_wrappers import *
7 | from .attention import *
8 | from text.symbols import get_vocab_size
9 | from audio import inv_spec
10 | from utils import log
11 |
12 |
13 | class Tacotron():
14 |
15 | def __init__(self, hparams):
16 | self._hparams = hparams
17 | self.global_step = tf.Variable(0, name='global_step', trainable=False)
18 |
19 |
20 | def initialize(self, text_lengths, text, prds=None,
21 | spec_lengths=None, mel_targets=None, mag_targets=None, f0_targets=None, c0_targets=None,
22 | stop_token_targets=None):
23 | with tf.variable_scope('inference'):
24 | hp = self._hparams
25 | is_training = mel_targets is not None
26 | B = batch_size = tf.shape(text)[0]
27 |
28 | print('text_lengths.shape:', text_lengths.shape)
29 | print('text.shape:', text.shape)
30 | if is_training:
31 | print('prds.shape:', prds.shape)
32 | print('spec_lengths.shape:', spec_lengths.shape)
33 | print('mel_targets.shape:', mel_targets.shape)
34 | print('mag_targets.shape:', mag_targets.shape)
35 | print('f0_targets.shape:', f0_targets.shape)
36 | print('c0_targets.shape:', c0_targets.shape)
37 | print('stop_token_targets.shape:', stop_token_targets.shape)
38 | log(f'[Tacotron] vocab size {get_vocab_size()}')
39 |
40 | # Embeddings
41 |       # zero-padding the embedding seems to matter little in practice
42 | zero_embedding_pad = tf.constant(0, shape=[1, hp.embed_depth], dtype=tf.float32, name='zero_embedding_pad')
43 | zero_embedding_pad_half = tf.constant(0, shape=[1, hp.encoder_depth//2], dtype=tf.float32, name='zero_embedding_pad_half')
44 |
45 |       '''Positional encoding embedding'''
46 |       PE_table = get_sinusoid_encoding_table(max(hp.maxlen_text, hp.maxlen_spec), hp.posenc_depth)   # unify the depth for now; fused via concat
47 |
48 |       '''Linguistic feature embeddings'''
49 | if hp.g2p == 'seq':
50 | E_text = tf.get_variable('E_text', [get_vocab_size(), hp.embed_depth], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.5))
51 |
52 | # seq
53 | text_embd = tf.nn.embedding_lookup(E_text, text)
54 | embd_out = text_embd
55 |
56 | elif hp.g2p == 'syl4':
57 | E_text = tf.get_variable('E_text', [get_vocab_size(), hp.embed_depth], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.5))
58 | E_tone = tf.get_variable('E_tone', [hp.n_tone, hp.embed_depth], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.5))
59 | E_prds = tf.get_variable('E_prds', [hp.n_prds, hp.embed_depth], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.5))
60 |
61 | # syl4
62 | CVVx, T = [tf.squeeze(p, axis=-1) for p in tf.split(text, 2, axis=-1)] # [B, T, 2] => 2 * [B, T]
63 | phone_embd = tf.nn.embedding_lookup(E_text, CVVx)
64 | tone_embd = tf.nn.embedding_lookup(E_tone, T)
65 | text_embd = phone_embd + tone_embd
66 |
67 | # prds
68 | prds_prob = conv_stack(text_embd, 3, hp.prdsnet_conv_k, hp.prdsnet_depth, hp.n_prds, activation=tf.nn.relu, scope='prdsnet')
69 | prds_out = tf.argmax(prds_prob, axis=-1)
70 | if is_training: prds_embd = tf.nn.embedding_lookup(E_prds, prds)
71 | else: prds_embd = tf.nn.embedding_lookup(E_prds, prds_out)
72 |
73 | embd_out = text_embd + prds_embd
74 |
75 | if hp.embed_dropout:
76 | embd_out = tf.layers.dropout(embd_out, rate=0.2, training=is_training, name='dropout_N')
77 | if is_training:
78 | embd_out = gaussian_noise(embd_out, is_training)
79 |
80 | if hp.encoder_type == 'sa':
81 | if hp.txt_use_posenc:
82 | N_pos_embd_out = tf.tile(PE_table[:, :tf.shape(embd_out)[1], :], (B, 1, 1))
83 | embd_out = tf.concat([embd_out, N_pos_embd_out], axis=-1)
84 |
85 | if is_training:
86 |         '''Discretized acoustic feature embeddings'''
87 | E_f0 = tf.get_variable('E_f0', [hp.n_f0_bins, hp.var_embed_depth], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.5))
88 | E_c0 = tf.get_variable('E_c0', [hp.n_c0_bins, hp.var_embed_depth], dtype=tf.float32, initializer=tf.truncated_normal_initializer(stddev=0.5))
89 |
90 | f0_embd = tf.nn.embedding_lookup(E_f0, f0_targets)
91 | c0_embd = tf.nn.embedding_lookup(E_c0, c0_targets)
92 |
93 | if hp.embed_dropout:
94 | f0_embd = tf.layers.dropout(f0_embd, rate=0.2, training=is_training, name='dropout_T')
95 | c0_embd = tf.layers.dropout(c0_embd, rate=0.2, training=is_training, name='dropout_T')
96 | if is_training:
97 | f0_embd = gaussian_noise(f0_embd, is_training)
98 | c0_embd = gaussian_noise(c0_embd, is_training)
99 |
100 | if hp.var_use_posenc:
101 | T_pos_embd_out = tf.tile(PE_table[:, :tf.shape(f0_targets)[-1], :], (B, 1, 1))
102 | f0_embd = tf.concat([f0_embd, T_pos_embd_out], axis=-1)
103 | c0_embd = tf.concat([c0_embd, T_pos_embd_out], axis=-1)
104 | else:
105 | f0_embd = c0_embd = None
106 |
107 | # Encoder
108 | if hp.encoder_type == 'sa':
109 | encoder_out, (slf_attn, crx_attn), ((f0_r, f0_r_pred), (c0_r, c0_r_pred)) = encoder_sa(embd_out, text_lengths, f0_embd, c0_embd, spec_lengths, is_training)
110 | elif hp.encoder_type == 'cb':
111 | encoder_out = cbhg(embd_out, text_lengths, hp.encoder_conv_K, [hp.encoder_depth//2, hp.encoder_depth], hp.encoder_depth, is_training)
112 |       else: raise ValueError(f'unknown encoder_type: {hp.encoder_type}')
113 | if is_training: encoder_out = gaussian_noise(encoder_out, is_training)
114 |
115 |       # Decoder (layers specified bottom to top):   # turns output-length control into predicting the number of RNN iterations, see `stop_projection`
116 |       multi_rnn_cell = MultiRNNCell([                # feed inner_repr to the RNN to produce rnn_output
117 |           ResidualWrapper(GRUCell(hp.decoder_depth))
118 |           for _ in range(hp.decoder_layers)
119 |         ], state_is_tuple=True)                                                          # [N, T_in, decoder_depth=256]
120 |       attention_mechanism = LocationSensitiveAttention(hp.attention_depth, encoder_out, text_lengths)  # [N, T_in, attn_depth=128]
121 |       frame_projection = FrameProjection(hp.n_mel * hp.outputs_per_step)                 # [N, T_out/r, M*r], project concat([rnn_output, attn_context]) to a vector of length r*n_mel, later reshaped into r frames
122 |       stop_projection = StopProjection(is_training, shape=hp.outputs_per_step)           # [N, T_out/r, r], project to r scalars; generation stops once any of them exceeds 0.5
123 | decoder_cell = TacotronDecoderWrapper(multi_rnn_cell, attention_mechanism, frame_projection, stop_projection, is_training)
124 | if is_training: helper = TacoTrainingHelper(batch_size, mel_targets, hp.n_mel, hp.outputs_per_step, self.global_step)
125 | else: helper = TacoTestHelper(batch_size, hp.n_mel, hp.outputs_per_step)
126 | decoder_init_state = decoder_cell.zero_state(batch_size=batch_size, dtype=tf.float32)
127 | (decoder_out, stop_token_out, _), final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
128 | CustomDecoder(decoder_cell, helper, decoder_init_state), # [N, T_out/r, M*r]
129 | impute_finished=True, maximum_iterations=hp.max_iters)
130 |
131 | # Reshape outputs to be one output per entry
132 |       mel_out = tf.reshape(decoder_out, [batch_size, -1, hp.n_mel])    # [N, T_out, M], mel only joins the loss; it is not used to produce the final wav
133 |       stop_token_out = tf.reshape(stop_token_out, [batch_size, -1])    # [N, T_out], this output is unused downstream; kept only as a reference during decoding
134 | alignments = tf.transpose(final_decoder_state.alignment_history.stack(), [1, 2, 0])
135 |
136 | # k = 7 should cover 1.5 mel-groups (reduce_factor)
137 | if hp.decoder_sew_layer:
138 | mel_out += tf.layers.conv1d(mel_out, hp.n_mel, 7, padding='same', name='sew_up_layer')
139 |
140 | # Posnet
141 | x = mel_out[:,:,:hp.n_mel_low]
142 | x = tf.layers.dense(x, hp.posnet_depth//4, name='posnet1')
143 | x = tf.nn.leaky_relu(x)
144 | x = tf.layers.dense(x, hp.posnet_depth//2, name='posnet2')
145 | x = tf.nn.leaky_relu(x)
146 | x = tf.layers.dense(x, hp.posnet_depth, name='posnet3')
147 | x = tf.nn.leaky_relu(x)
148 | mag_out = tf.concat([tf.layers.dense(s, (hp.n_freq-1)//hp.posnet_ngroup, name=f'posnet4_{i}')
149 | for i, s in enumerate(tf.split(x, hp.posnet_ngroup, axis=-1))], axis=-1)
150 |
151 | # data in
152 | self.text_lengths = text_lengths
153 | self.text = text
154 | self.prds = prds
155 | self.spec_lengths = spec_lengths
156 | self.mel_targets = mel_targets
157 | self.mag_targets = mag_targets
158 | self.stop_token_targets = stop_token_targets
159 | # data out
160 | self.prds_prob = prds_prob
161 | self.prds_out = prds_out
162 | self.mel_outputs = mel_out
163 | self.mag_outputs = mag_out
164 | self.stop_token_outputs = stop_token_out
165 | self.alignments = alignments
166 | # misc
167 | if is_training:
168 | # NOTE: must get `._ratio` after `TacoTrainingHelper.initialize()`
169 | self.tfr = helper._ratio
170 | if hp.encoder_type == 'sa':
171 | self.slf_attn = slf_attn
172 | self.crx_attn = crx_attn
173 | self.f0_r = f0_r
174 | self.f0_r_pred = f0_r_pred
175 | self.c0_r = c0_r
176 | self.c0_r_pred = c0_r_pred
177 |
178 | def get_cosine_sim(x):
179 | dot = tf.matmul(x, x, transpose_b=True)
180 | n = tf.norm(x, axis=-1, keepdims=True)
181 | norm = tf.matmul(n, n, transpose_b=True)
182 | sim = dot / (norm + 1e-8)
183 | return sim
184 |
185 | self.E_text = E_text
186 | self.E_text_sim = get_cosine_sim(E_text)
187 | if hp.g2p == 'syl4':
188 | self.E_tone = E_tone
189 | self.E_tone_sim = get_cosine_sim(E_tone)
190 | self.E_prds = E_prds
191 | self.E_prds_sim = get_cosine_sim(E_prds)
192 |
193 |       log('Initialized TransTacoS Model: ')
194 | log(f' embd out: {embd_out.shape}')
195 | if hp.g2p == 'syl4':
196 | log(f' syl4 embd: {text_embd.shape}')
197 | log(f' tone embd: {tone_embd.shape}')
198 | log(f' prds embd: {prds_embd.shape}')
199 | if hp.encoder_type == 'sa' and is_training:
200 | log(f' f0 embd: {f0_embd.shape}')
201 | log(f' c0 embd: {c0_embd.shape}')
202 | log(f' encoder out: {encoder_out.shape}')
203 | log(f' decoder out (r frames): {decoder_out.shape}')
204 | log(f' mel out (1 frame): {mel_out.shape}')
205 | log(f' stoptoken out: {stop_token_out.shape}')
206 | log(f' mag out: {mag_out.shape}')
207 | log(f' E_text: {E_text.shape}')
208 | if hp.g2p == 'syl4':
209 | log(f' E_tone: {E_tone.shape}')
210 | log(f' E_prds: {E_prds.shape}')
211 |
212 |
213 | def add_loss(self):
214 | '''Adds loss to the model. Sets "loss" field. initialize() must have been called.'''
215 |
216 | hp = self._hparams
217 | with tf.variable_scope('loss'):
218 |       self.mel_loss = tf.reduce_mean(tf.abs(self.mel_targets - self.mel_outputs))
219 |       self.mag_loss = tf.reduce_mean(tf.abs(self.mag_targets - self.mag_outputs))
220 | if hp.encoder_type == 'sa' and hp.encoder_fusenet:
221 | self.f0_loss = tf.reduce_mean(tf.square(self.f0_r - self.f0_r_pred))
222 | self.c0_loss = tf.reduce_mean(tf.square(self.c0_r - self.c0_r_pred))
223 | else:
224 | self.f0_loss = 0.0
225 | self.c0_loss = 0.0
226 | if hp.g2p == 'syl4':
227 | self.prds_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.prds, logits=self.prds_prob))
228 | else:
229 | self.prds_loss = 0.0
230 | if hp.g2p == 'seq':
231 | self.sim_loss = tf.reduce_mean(tf.abs((1.0 - tf.eye(get_vocab_size())) * self.E_text_sim)) * hp.sim_weight
232 | else:
233 | self.sim_loss = tf.add_n([tf.reduce_mean(tf.abs((1.0 - tf.eye(get_vocab_size())) * self.E_text_sim)),
234 | tf.reduce_mean(tf.abs((1.0 - tf.eye(hp.n_prds)) * self.E_prds_sim))]) * hp.sim_weight
235 | self.stop_token_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=self.stop_token_targets, logits=self.stop_token_outputs))
236 | self.reg_loss = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()]) * hp.reg_weight
237 |
238 | self.loss = (self.prds_loss +
239 | self.mel_loss +
240 | self.mag_loss +
241 | self.f0_loss +
242 | self.c0_loss +
243 | self.sim_loss +
244 | self.stop_token_loss +
245 | self.reg_loss)
246 |
247 | def add_optimizer(self):
248 | '''Adds optimizer. Sets "gradients" and "optimize" fields. add_loss() must have been called.'''
249 |
250 | with tf.variable_scope('optimizer'):
251 | hp = self._hparams
252 |
253 | if hp.decay_learning_rate:
254 | self.learning_rate = _learning_rate_decay(hp.initial_learning_rate, self.global_step)
255 | else:
256 | self.learning_rate = tf.convert_to_tensor(hp.initial_learning_rate)
257 | optimizer = tf.train.AdamOptimizer(self.learning_rate, hp.adam_beta1, hp.adam_beta2, hp.adam_eps)
258 | gradients, variables = zip(*optimizer.compute_gradients(self.loss))
259 |       self.gradients = tuple([g for g in gradients if g is not None])   # some variables receive no gradient; drop None entries (used only for stats)
260 |       clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0)      # guard against gradient explosion
261 |
262 | # Add dependency on UPDATE_OPS; otherwise batchnorm won't work correctly. See:
263 | # https://github.com/tensorflow/tensorflow/issues/1122
264 | with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
265 | self.optimize = optimizer.apply_gradients(zip(clipped_gradients, variables),
266 | global_step=self.global_step)
267 |
268 |
269 | def add_stats(self):
270 | with tf.variable_scope('stats'):
271 | hp = self._hparams
272 |
273 | tf.summary.histogram('mel_outputs', self.mel_outputs)
274 | tf.summary.histogram('mel_targets', self.mel_targets)
275 | tf.summary.histogram('mag_outputs', self.mag_outputs)
276 | tf.summary.histogram('mag_targets', self.mag_targets)
277 |
278 | tf.summary.scalar('learning_rate', self.learning_rate)
279 | tf.summary.scalar('loss', self.loss)
280 | tf.summary.scalar('tfr', self.tfr)
281 | tf.summary.scalar('mel_loss', self.mel_loss)
282 | tf.summary.scalar('mag_loss', self.mag_loss)
283 | if hp.g2p == 'syl4':
284 | tf.summary.scalar('prds_loss', self.prds_loss)
285 | if hp.encoder_type == 'sa':
286 | tf.summary.scalar('f0_loss', self.f0_loss)
287 | tf.summary.scalar('c0_loss', self.c0_loss)
288 | tf.summary.scalar('sim_loss', self.sim_loss)
289 | tf.summary.scalar('stop_token_loss', self.stop_token_loss)
290 | tf.summary.scalar('reg_loss', self.reg_loss)
291 |
292 | gradient_norms = [tf.norm(grad) for grad in self.gradients]
293 | tf.summary.histogram('gradient_norm', gradient_norms)
294 | tf.summary.scalar('max_gradient_norm', tf.reduce_max(gradient_norms))
295 |
296 | raw = tf.numpy_function(inv_spec, [tf.transpose(self.mag_targets[0])], tf.float32)
297 | gen = tf.numpy_function(inv_spec, [tf.transpose(self.mag_outputs[0])], tf.float32)
298 | tf.summary.audio('raw', tf.expand_dims(raw, 0), hp.sample_rate, 1)
299 | tf.summary.audio('gen', tf.expand_dims(gen, 0), hp.sample_rate, 1)
300 |
301 | expand_dims = lambda x: tf.expand_dims(tf.expand_dims(x, 0), -1)
302 | tf.summary.image('alignments', expand_dims(self.alignments[0]))
303 | tf.summary.image('E_text_sim', expand_dims(self.E_text_sim))
304 | if hp.g2p != 'seq':
305 | tf.summary.image('E_tone_sim', expand_dims(self.E_tone_sim))
306 | tf.summary.image('E_prds_sim', expand_dims(self.E_prds_sim))
307 | if hp.encoder_type == 'sa':
308 | for i in range(hp.encoder_attn_layers):
309 | for j in range(hp.encoder_attn_nhead):
310 | tf.summary.image(f'slf_attn_{i}{j}', expand_dims(self.slf_attn[i][j][0]))
311 | if self.crx_attn:
312 | for i in range(2):
313 | tf.summary.image(f'crx_attn_{i}', expand_dims(self.crx_attn[i][0]))
314 |
315 | #tf.summary.image('mel_out', expand_dims(self.mel_outputs[0]))
316 |
317 | self.stats = tf.summary.merge_all()
318 |
319 |
320 | def _learning_rate_decay(init_lr, global_step):
321 | # Noam scheme from tensor2tensor:
322 | warmup_steps = 4000.0
323 | step = tf.cast(global_step + 1, dtype=tf.float32)
324 | return init_lr * warmup_steps**0.5 * tf.minimum(step * warmup_steps**-1.5, step**-0.5)
325 |
--------------------------------------------------------------------------------
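
`_learning_rate_decay` above is the Noam schedule: the rate climbs roughly linearly over the first 4000 steps, peaks near `init_lr`, then decays as step**-0.5. A pure-Python sketch to get a feel for the curve (the init_lr and step values here are arbitrary):

def noam_lr(init_lr, global_step, warmup_steps=4000.0):
    step = global_step + 1
    return init_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)

for s in [0, 1000, 3999, 16000, 64000]:
    print(s, noam_lr(2e-3, s))   # peaks at ~2e-3 around step 4000, then decays
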
/transtacos/preprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2022/01/07
4 |
5 | import os
6 | import random
7 | from pprint import pformat
8 | from argparse import ArgumentParser
9 | from importlib import import_module
10 | from typing import List, Tuple
11 |
12 | import hparam as hp
13 | random.seed(hp.randseed)
14 |
15 |
16 | def write_metadata(metadata:Tuple[List, List], stats:dict, wav_path, args):
17 | if args.shuffle: random.shuffle(metadata)
18 |
19 | out_path = os.path.join(args.base_dir, args.out_dir)
20 | os.makedirs(out_path, exist_ok=True)
21 |
22 | cp = int(len(metadata) * args.split_ratio)
23 | mt_test, mt_train = metadata[:cp], metadata[cp:]
24 |
25 | with open(os.path.join(out_path, 'train.txt'), 'w', encoding='utf-8') as fh:
26 | for mt in mt_train:
27 | fh.write('|'.join([str(x) for x in mt]))
28 | fh.write('\n')
29 |
30 | with open(os.path.join(out_path, 'test.txt'), 'w', encoding='utf-8') as fh:
31 | for mt in mt_test:
32 | fh.write('|'.join([str(x) for x in mt]))
33 | fh.write('\n')
34 |
35 | with open(os.path.join(out_path, 'stats.txt'), 'w', encoding='utf-8') as fh:
36 | for k, v in stats.items():
37 | fh.write(f'{k}\t{v}')
38 | fh.write('\n')
39 |
40 | with open(os.path.join(out_path, 'wav_path.txt'), 'w', encoding='utf-8') as fh:
41 | fh.write(wav_path)
42 |
43 |
44 | if __name__ == '__main__':
45 | def str2bool(s:str) -> bool:
46 | s = s.lower()
47 | if s in ['true', 't', '1']: return True
48 | if s in ['false', 'f', '0']: return False
49 | raise ValueError(f'invalid bool value: {s}')
50 |
51 | base_dir = os.path.dirname(os.path.abspath(__file__))
52 | DATASETS = [fn[:-3] for fn in os.listdir(os.path.join(base_dir, 'datasets')) if not fn.startswith('__')]
53 |
54 | parser = ArgumentParser()
55 | parser.add_argument('--base_dir', required=True, help='base path containing the dataset folder')
56 | parser.add_argument('--out_dir', default='preprocessed', help='preprocessed output folder')
57 | parser.add_argument('--dataset', required=True, choices=DATASETS)
58 | parser.add_argument('--shuffle', type=str2bool, default=True, help='shuffle metadata')
59 | parser.add_argument('--split_ratio', type=float, default=0.05, help='test/train split')
60 | parser.add_argument('--num_workers', type=int, default=4)
61 | args = parser.parse_args()
62 |
63 | os.environ['LIBROSA_CACHE_LEVEL'] = '50'
64 |
65 | proc = import_module(f'datasets.{args.dataset}')
66 | metadata, stats, wav_path = proc.preprocess(args)
67 | print('wav_path:', wav_path)
68 |     print('stats:', pformat(stats))   # pformat always sorts keys before Py3.8; `sort_dicts=False` requires 3.8+
69 | write_metadata(metadata, stats, wav_path, args)
70 |
--------------------------------------------------------------------------------
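
Each metadata row above is written as pipe-joined fields, one utterance per line, so a consumer can read it back with a simple split; a sketch (the exact field layout depends on the dataset module's `preprocess()`):

with open('preprocessed/train.txt', encoding='utf-8') as fh:
    rows = [line.rstrip('\n').split('|') for line in fh]
print(len(rows), 'training utterances')
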
/transtacos/server.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import pickle
4 | from re import compile as Regex
5 | from time import time
6 | from tempfile import gettempdir
7 | from argparse import ArgumentParser
8 |
9 | import tensorflow as tf
10 | import numpy as np
11 | from flask import Flask, request, jsonify, send_file
12 | from requests.utils import unquote
13 | from xpinyin import Pinyin
14 |
15 | import hparam as hp
16 | from synth import Synthesizer
17 | from audio import save_wav
18 |
19 |
20 | for gpu in tf.config.experimental.list_physical_devices('GPU'):
21 | tf.config.experimental.set_memory_growth(gpu, True)
22 |
23 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
24 | os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
25 | os.environ['XLA_FLAGS'] = '--xla_hlo_profile'
26 |
27 | np.random.seed (hp.randseed)
28 | tf.random.set_random_seed(hp.randseed)
29 |
30 |
31 | # globals
32 | app = Flask(__name__)
33 | kanji2pinyin = Pinyin()
34 | synthesizer = None
35 | html_text = None
36 |
37 | BASE_PATH = os.path.dirname(os.path.abspath(__file__))
38 | HTML_FILE = os.path.join(BASE_PATH, '..', 'index.html')
39 |
40 | TMP_DIR = gettempdir()
41 | WAV_TMP_FILE = os.path.join(TMP_DIR, 'synth.wav')
42 |
43 | REGEX_PUNCT_IGNORE = Regex('、|:|;|“|”|‘|’')
44 | REGEX_PUNCT_BREAK = Regex(',|。|!|?')
45 | MAX_CLAUSE_LENGTH = 20
46 |
47 |
48 | # quick demo index page
49 | @app.route('/', methods=['GET'])
50 | def root():
51 | global html_text
52 | if not html_text:
53 | with open(HTML_FILE, encoding='utf-8') as fp:
54 | html_text = fp.read()
55 | return html_text
56 |
57 |
58 | # vocode with internal Griffin-Lim
59 | @app.route('/synth', methods=['GET'])
60 | def synth():
61 | kanji = unquote(request.args.get('text'))
62 |
63 | if kanji:
64 | try:
65 | # Text-Norm
66 | if True:
67 | s = time()
68 | print(f'text/raw: {kanji!r}')
69 |
70 | kanji = REGEX_PUNCT_IGNORE.sub('', kanji)
71 | kanji = REGEX_PUNCT_BREAK.sub(' ', kanji)
72 | segs = [''] # dummy init
73 | for rs in [s.strip() for s in kanji.split(' ') if s.strip()]:
74 |         if (not segs[-1]) or (len(rs) + len(segs[-1]) < MAX_CLAUSE_LENGTH):
75 | segs[-1] = segs[-1] + rs
76 | else: segs.append(rs)
77 | print(f'text/segs: {segs!r}')
78 | t = time()
79 | print(f'[TextNorm] Done in {t - s:.2f}s')
80 |
81 | # Synth
82 | if True:
83 | s = time()
84 | wav_clips = []
85 | for seg in segs:
86 | text = ' '.join(kanji2pinyin.get_pinyin(seg, tone_marks='numbers').split('-'))
87 | wav = synthesizer.synthesize(text, 'wav')
88 | wav_clips.append(wav)
89 | wav = np.concatenate(wav_clips)
90 | print('wav.shape:', wav.shape)
91 | t = time()
92 | print(f'[Synth] Done in {t - s:.2f}s')
93 |
94 | # Save file
95 | if True:
96 | s = time()
97 | save_wav(wav, WAV_TMP_FILE)
98 | t = time()
99 | print(f'[SaveFile] Done in {t - s:.2f}s')
100 |
101 | return send_file(WAV_TMP_FILE, mimetype='audio/wav')
102 | except Exception as e:
103 | print('[Error] %r' % e)
104 | error_msg = 'synth failed, see logs'
105 | else:
106 | error_msg = 'bad request params or no text to synth?'
107 |
108 | return jsonify({'error': error_msg})
109 |
110 |
111 | # return linear spec
112 | @app.route('/synth_spec', methods=['POST'])
113 | def synth_spec():
114 | try:
115 | # chk txt
116 | pinyin = request.get_json().get('pinyin').strip()
117 | if not pinyin:
118 | return jsonify({'error': 'no text to synth'})
119 |
120 | # text to mag
121 | s = time()
122 | spec = synthesizer.synthesize(pinyin, 'spec')
123 | print('spec.shape:', spec.shape)
124 | t = time()
125 | print(f'[Synth] Done in {t - s:.2f}s')
126 |
127 | # transfer
128 | bio = io.BytesIO()
129 | bio.write(pickle.dumps(spec)) # float32 -> byte
130 | bio.seek(0) # reset fp to beginning for `send_file` to read
131 | return send_file(bio, mimetype='application/octet-stream')
132 |
133 | except Exception as e:
134 | print('[Error] %r' % e)
135 |     return jsonify({'error': repr(e)})   # exception objects are not JSON-serializable
136 |
137 |
138 | if __name__ == '__main__':
139 | parser = ArgumentParser()
140 | parser.add_argument('--log_path', required=True)
141 | parser.add_argument('--host', type=str, default='0.0.0.0')
142 | parser.add_argument('--port', type=int, default=5105)
143 | args = parser.parse_args()
144 |
145 | # load ckpt
146 | synthesizer = Synthesizer()
147 | synthesizer.load(args.log_path)
148 |
149 | app.run(host=args.host, port=args.port, debug=False)
150 |
--------------------------------------------------------------------------------
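
A small client sketch against the two routes above, assuming the server runs on the default port 5105; `/synth` returns finished audio, `/synth_spec` returns a pickled float32 spectrogram for an external vocoder:

import pickle
import requests

BASE = 'http://localhost:5105'

# GET /synth: Griffin-Lim vocoded audio
wav_bytes = requests.get(f'{BASE}/synth', params={'text': '你好世界'}).content
with open('synth.wav', 'wb') as fh:
    fh.write(wav_bytes)

# POST /synth_spec: linear spectrogram as a pickled ndarray
resp = requests.post(f'{BASE}/synth_spec', json={'pinyin': 'ni3 hao3 shi4 jie4'})
spec = pickle.loads(resp.content)
print('spec.shape:', spec.shape)
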
/transtacos/synth.py:
--------------------------------------------------------------------------------
1 | from time import time
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | import hparam as hp
6 | from models.tacotron import Tacotron
7 | from text.text import text_to_phoneme, phoneme_to_sequence, sequence_to_phoneme
8 | import audio as A
9 | from text.symbols import _eos, _sep, get_vocab_size
10 | from text.phonodict_cn import phonodict
11 |
12 |
13 | class Synthesizer:
14 |
15 | def load(self, log_dir):
16 | print('Constructing tacotron model')
17 |
18 | # init data placeholder
19 | if hp.g2p == 'seq': text_shape = [1, None]
20 | elif hp.g2p == 'syl4': text_shape = [1, None, 2]
21 | text_lengths = tf.placeholder(tf.int32, [1], 'text_lengths') # bs=1 for one sample
22 | text = tf.placeholder(tf.int32, text_shape, 'text')
23 | with tf.variable_scope('model'):
24 | self.model = Tacotron(hp)
25 | self.model.initialize(text_lengths, text)
26 | self.mag_output = self.model.mag_outputs[0]
27 |
28 | # load ckpt
29 | checkpoint_state = tf.train.get_checkpoint_state(log_dir)
30 | print('Resuming from checkpoint: %s' % checkpoint_state.model_checkpoint_path)
31 | self.session = tf.Session()
32 | self.session.run(tf.global_variables_initializer())
33 | saver = tf.train.Saver()
34 | saver.restore(self.session, checkpoint_state.model_checkpoint_path)
35 |
36 | def synthesize(self, text, out_type='wav') -> bytes:
37 | # ref: `data.DataFeeder.load_data`
38 | if hp.g2p == 'seq':
39 | text = text + _eos
40 | print('text: ', text)
41 | phs = text_to_phoneme(text)
42 | print('phs: ', phs)
43 | seq = phoneme_to_sequence(phs)
44 | print('seq: ', seq)
45 | phs_rev = sequence_to_phoneme(seq)
46 | print('phs_rev: ', phs_rev)
47 | elif hp.g2p == 'syl4':
48 | C, V, T, Vx = text_to_phoneme(text) # [[str]]
49 |
50 | CVVx, Tx = [ ], [ ]
51 | n_syllable = len(C)
52 | for i in range(n_syllable):
53 | if C[i] != phonodict.vacant:
54 | CVVx.append(C[i]) ; Tx.append(T[i])
55 | if V[i] != phonodict.vacant:
56 | CVVx.append(V[i]) ; Tx.append(T[i])
57 | if Vx[i] != phonodict.vacant:
58 | CVVx.append(Vx[i]) ; Tx.append(T[i])
59 |
60 | CVVx.append(_sep) ; Tx.append(0)
61 |
62 | # NOTE: pad here, then convert to id_seq
63 | CVVx = phoneme_to_sequence(CVVx + [_eos]) # see phone table
64 | Tx = [int(t) for t in Tx] + [0] # should be 0 ~ 5
65 |
66 | assert len(CVVx) == len(Tx)
67 | assert 0 <= min(CVVx) and max(CVVx) < get_vocab_size()
68 | assert 0 <= min(Tx) and max(Tx) < 6
69 |
70 | seq = np.stack([CVVx, Tx], axis=-1) # [T, 2]
71 |
72 | seq = np.asarray(seq, dtype=np.int32)
73 |
74 | feed_dict = {
75 | self.model.text_lengths: [len(seq)], # len(id_seq)
76 | self.model.text: [seq], # id_seq
77 | }
78 | mag = self.session.run(self.mag_output, feed_dict=feed_dict)
79 | mag = mag.T # [F-1, T]
80 | if out_type == 'wav':
81 | wav = A.inv_spec(mag) # vocode with internal Griffin-Lim
82 | wav = A.trim_silence(wav)
83 |       return wav          # sample data only, no RIFF encapsulation
84 | if out_type == 'spec':
85 | S = A.spec_to_natural_scale(mag) # denorm
86 | S = A.fix_zero_DC(S)
87 | return S
88 |
--------------------------------------------------------------------------------
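
Offline usage follows the same pattern as server.py; a minimal sketch, where the log dir is a placeholder for wherever train.py wrote its checkpoints:

from synth import Synthesizer
from audio import save_wav

synthesizer = Synthesizer()
synthesizer.load('transtacos')                         # dir containing model.ckpt-*
wav = synthesizer.synthesize('ni3 hao3 shi4 jie4', 'wav')
save_wav(wav, 'out.wav')
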
/transtacos/text/__init__.py:
--------------------------------------------------------------------------------
1 | from .symbols import *
2 | from .text import *
3 | from .phonodict_cn import *
4 | from .g2p import *
5 |
--------------------------------------------------------------------------------
/transtacos/text/g2p.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Author: Armit
3 | # Create Time: 2021/3/29
4 |
5 | from typing import List
6 |
7 | from .symbols import _unk
8 | from .phonodict_cn import phonodict
9 |
10 |
11 | def to_syl4(pinyin:str, sep=' ') -> List[List[str]]:
12 | C, V, T, Vx = [], [], [], []
13 |
14 | py_ls = pinyin.split(sep)
15 | n_syllable = len(py_ls)
16 | for py in py_ls:
17 | # split tone
18 | t = py[-1]
19 | if t.isdigit(): py = py[:-1]
20 | else: t = '5'
21 |
22 |     # delete R-ending (erhua)
23 | r_ending = False
24 | if py[-1] == 'r':
25 | r_ending = True
26 | if py != 'er':
27 | py = py[:-1]
28 |
29 | # split CV
30 | try:
31 | c, v, e = phonodict[py]
32 | C.append(c) ; V.append(v) ; T.append(t)
33 |       if r_ending: Vx.append('_R')   # let R override N or NG
34 | else: Vx.append(e)
35 |
36 |     except Exception:
37 | C.append(_unk) ; V.append(_unk) ; T.append(_unk) ; Vx.append(_unk)
38 | print('[Syllable] cannot parse %r' % py)
39 |
40 | assert len(C) == len(V) == len(T) == len(Vx) == n_syllable
41 | return [C, V, T, Vx]
42 |
43 |
44 | def from_syl4(syl4:List[List[str]], sep=' ') -> str:
45 | return sep.join([''.join(s) for s in zip(*syl4)])
46 |
47 |
48 | if __name__ == '__main__':
49 | pinyin = 'zi3 se4 de hua1 er2 wei4 shen2 me zher4 yang4 yuan2'
50 | print('pinyin:', pinyin)
51 | syl4 = to_syl4(pinyin)
52 | print('syl4:', syl4)
53 | syl4_serial = from_syl4(syl4)
54 | print('syl4_serial:', syl4_serial)
55 |
--------------------------------------------------------------------------------
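
Tracing `to_syl4` by hand against phonodict_cn.csv: 'hua1' splits off tone '1' and looks up 'hua' -> 'h ua' (no nasal ending), giving (C, V, T, Vx) = ('h', 'ua', '1', '_'); 'zher4' first strips the R-ending to 'zhe' -> 'zh e' and records Vx='_R'. As a sketch:

from text.g2p import to_syl4

C, V, T, Vx = to_syl4('hua1 zher4')
# expected: C=['h', 'zh'], V=['ua', 'e'], T=['1', '4'], Vx=['_', '_R']
print(C, V, T, Vx)
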
/transtacos/text/phonodict_cn.csv:
--------------------------------------------------------------------------------
1 | ,-,b,d,g,p,t,k,j,q,x,z,c,s,zh,ch,sh,m,n,l,f,h,y,w,r
2 | a,a,b a,d a,g a,p a,t a,k a,,,,z a,c a,s a,zh a,ch a,sh a,m a,n a,l a,f a,h a,ia,ua,
3 | o,o,b uo,,,p uo,,,,,,,,,,,,m uo,,l uo,f uo,,io,uo,
4 | e,e,,d e,g e,,t e,k e,,,,z e,c e,s e,zh e,ch e,sh e,m e,n e,l e,,h e,iE,,r e
5 | i,,b i,d i,,p i,t i,,j i,q i,x i,z i0,c i0,s i0,zh iR,ch iR,sh iR,m i,n i,l i,,,i,,r iR
6 | u,,b u,d u,g u,p u,t u,k u,j v,q v,x v,z u,c u,s u,zh u,ch u,sh u,m u,n u,l u,f u,h u,v,u,r u
7 | v,,,,,,,,,,,,,,,,,,n v,l v,,,,,
8 | ai,ai,b ai,d ai,g ai,p ai,t ai,k ai,,,,z ai,c ai,s ai,zh ai,ch ai,sh ai,m ai,n ai,l ai,,h ai,,uai,
9 | ao,ao,b ao,d ao,g ao,p ao,t ao,k ao,,,,z ao,c ao,s ao,zh ao,ch ao,sh ao,m ao,n ao,l ao,,h ao,iao,,r ao
10 | ei,Ei,b Ei,d Ei,g Ei,p Ei,t Ei,k Ei,,,,z Ei,,,zh Ei,,sh Ei,m Ei,n Ei,l Ei,f Ei,h Ei,,uEi,
11 | ou,ou,,d ou,g ou,p ou,t ou,k ou,,,,z ou,c ou,s ou,zh ou,ch ou,sh ou,m ou,n ou,l ou,f ou,h ou,iou,,r ou
12 | uo,,,d uo,g uo,,t uo,k uo,,,,z uo,c uo,s uo,zh uo,ch uo,sh uo,,n uo,l uo,,h uo,,,r uo
13 | an,an,b an,d an,g an,p an,t an,k an,,,,z an,c an,s an,zh an,ch an,sh an,m an,n an,l an,f an,h an,iEn,uan,r an
14 | en,en,b en,d en,g en,p en,,k en,,,,z en,c en,s en,zh en,ch en,sh en,m en,n en,,f en,h en,,un,r en
15 | in,,b in,,,p in,,,j in,q in,x in,,,,,,,m in,n in,l in,,,in,,
16 | un,,,d un,g un,,t un,k un,j vn,q vn,x vn,z un,c un,s un,zh un,ch un,sh un,,n un,l un,,h un,vn,,r un
17 | ang,ang,b ang,d ang,g ang,p ang,t ang,k ang,,,,z ang,c ang,s ang,zh ang,ch ang,sh ang,m ang,n ang,l ang,f ang,h ang,iang,uang,r ang
18 | eng,eng,b eng,d eng,g eng,p eng,t eng,k eng,,,,z eng,c eng,s eng,zh eng,ch eng,sh eng,m eng,n eng,l eng,f eng,h eng,,ueng,r eng
19 | ing,,b ing,d ing,,p ing,t ing,,j ing,q ing,x ing,,,,,,,m ing,n ing,l ing,,,ing,,
20 | ong,,,d ong,g ong,,t ong,k ong,,,,z ong,c ong,s ong,zh ong,ch ong,,,n ong,l ong,,h ong,iong,,r ong
21 | ia,,,d ia,,,,,j ia,q ia,x ia,,,,,,,,,l ia,,,,,
22 | ian,,b iEn,d iEn,,p iEn,t iEn,,j iEn,q iEn,x iEn,,,,,,,m iEn,n iEn,l iEn,,,,,
23 | iang,,b iang,,,,,,j iang,q iang,x iang,,,,,,,,n iang,l iang,,,,,
24 | iong,,,,,,,,j iong,q iong,x iong,,,,,,,,,,,,,,
25 | ie,,b iE,d iE,,p iE,t iE,,j iE,q iE,x iE,,,,,,,m iE,n iE,l iE,,,,,
26 | iu,,,d iou,,,,,j iou,q iou,x iou,,,,,,,m iou,n iou,l iou,,,,,
27 | iao,,b iao,d iao,,p iao,t iao,,j iao,q iao,x iao,,,,,,,m iao,n iao,l iao,f iao,,,,
28 | ua,,,,g ua,,,k ua,,,,,,,zh ua,ch ua,sh ua,,,,,h ua,,,r ua
29 | uan,,,d uan,g uan,,t uan,k uan,j vEn,q vEn,x vEn,z uan,c uan,s uan,zh uan,ch uan,sh uan,,n uan,l uan,,h uan,vEn,,r uan
30 | uang,,,,g uang,,,k uang,,,,,,,zh uang,ch uang,sh uang,,,,,h uang,,,
31 | ue,,,,,,,,j vE,q vE,x vE,,,,,,,,,,,,vE,,
32 | ui,,,d uEi,g uEi,,t uEi,k uEi,,,,z uEi,c uEi,s uEi,zh uEi,ch uEi,sh uEi,,,,,h uEi,,,r uEi
33 | uai,,,,g uai,,,k uai,,,,,,,zh uai,ch uai,sh uai,,,,,h uai,,,
34 | ve,,,,,,,,,,,,,,,,,,n vE,l vE,,,,,
35 | er,R,,,,,,,,,,,,,,,,,,,,,,,
36 |
--------------------------------------------------------------------------------
/transtacos/text/phonodict_cn.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | # Toneless phoneme dictionary, roughly equivalent to the X-SAMPA (Vocaloid) scheme
4 | # Initials are plain consonants; medials attach toward the final (vowel)
5 |
6 | LEXDICT_FILE = Path(__file__).absolute().parent / 'phonodict_cn.csv'
7 |
8 | # NOTE: hard-coded according to `symbols.py`
9 | _pad = '_'
10 |
11 |
12 | class Phonodict4:
13 |
14 | def __init__(self, fp=LEXDICT_FILE, vac_sym=_pad):
15 |     self.entry = { }        # e.g. {'hui': ['h', 'uEi', '_']}, one pinyin -> (consonant, vowel, ending)
16 | self.initials = [] # pinyin := initial + final
17 | self.finals = []
18 | self.consonants = [] # syl4 := consonant + vowel + (tone) + ending
19 | self.vowels = []
20 | self.endings = ['_N', '_NG', '_R']
21 | self.vacant = vac_sym # '_', for zero consonant/ending
22 |
23 | self.dict_fp = fp
24 | self._load_dict()
25 |
26 | def _load_dict(self):
27 | I, F, C, V = set(), set(), set(), set()
28 |
29 | with open(self.dict_fp) as fh:
30 | ilist = list(fh.readline().strip().split(',')[1:])
31 | ilist[0] = '' # fix '-' to ''
32 | for row in fh.readlines():
33 | ls = row.strip().split(',')
34 | f, cvlist = ls[0], ls[1:]
35 | F.add(f)
36 | for i in ilist:
37 | I.add(i)
38 | cv = cvlist[ilist.index(i)]
39 |           if cv:    # found a valid syllable
40 | if cv == 'R':
41 | c = self.vacant
42 | v = 'e'
43 | e = '_R'
44 | else:
45 | if cv.endswith('ng'):
46 | cv = cv[:-2] ; e = '_NG'
47 | elif cv.endswith('n'):
48 | cv = cv[:-1] ; e = '_N'
49 | else:
50 | e = self.vacant
51 | if ' ' in cv: # 'k ua'
52 | c, v = cv.split(' ')
53 | else:
54 | c = self.vacant ; v = cv
55 | C.add(c) ; V.add(v)
56 | self.entry[i + f] = [c, v, e]
57 |
58 | self.initials = sorted(list(I))
59 | self.finals = sorted(list(F))
60 | self.consonants = sorted(list(C))
61 | self.vowels = sorted(list(V))
62 |
63 | def __getitem__(self, py:str) -> str:
64 | return self.entry.get(py, None)
65 |
66 | def __len__(self) -> int:
67 | return len(self.entry)
68 |
69 | @property
70 | def vacant_symbol(self) -> str:
71 | return self.vacant
72 |
73 | def inspect(self):
74 | print(f'syllable count: {len(self.entry)}')
75 | print(f'initials({len(self.initials)}): {self.initials}')
76 | print(f'finals({len(self.finals)}): {self.finals}')
77 | print(f'consonants({len(self.consonants)}): {self.consonants}')
78 | print(f'vowels({len(self.vowels)}): {self.vowels}')
79 | print(f'endings({len(self.endings)}): {self.endings}')
80 |
81 |
82 | phonodict = Phonodict4()
83 |
84 |
85 | if __name__ == '__main__':
86 | phonodict.inspect()
87 |
--------------------------------------------------------------------------------
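
A quick lookup sketch, matching the entry format noted in `__init__` (one pinyin maps to [consonant, vowel, ending], with '_' for vacant slots):

from text.phonodict_cn import phonodict

print(phonodict['hui'])    # ['h', 'uEi', '_']
print(phonodict['huang'])  # ['h', 'ua', '_NG'], the 'ng' is split off as the ending
print(len(phonodict))      # number of legal pinyin syllables loaded from the CSV
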
/transtacos/text/phonodict_cn.txt:
--------------------------------------------------------------------------------
1 | [Dictionary]
2 | Initials (21):
3 | - retroflex (4): zh, ch, sh, r
4 | - dental sibilants (3): z, c, s
5 | - palatals (3): j, q, x
6 | - by aspiration (waveforms are more alike)
7 |   - aspirated (5): p, t, k, f, h
8 |   - unaspirated (4): b, d, g, l
9 |   - nasal (2): m, n
10 | - by place of articulation
11 |   - labial (4): b, p, m, f
12 |   - coronal (4): d, t, n, l
13 |   - guttural (3): g, k, h
14 |
15 | Finals (39):
16 | - simple (7/9): a, o, e, i(i, iR, i0), u, v, R
17 | - front-closed (5): an en in un vn
18 | - back-closed (4): ang eng ing ong
19 | - diphthongs (5): ai ao Ei ou uo
20 | - with medials
21 |   - y-medial: ia, iao, iE, iEn, io, iou, iang, iong
22 |   - w-medial: ua, uai, uEi, uan, uang, ueng
23 |   - v-medial: vE, vEn
24 |
25 | Notes:
26 | 0. If the onset of a zero-initial syllable is strong enough, the initial sounds like b/d
27 |
28 |
29 | [Pinyin scheme]
30 | * With phonemes as pronunciation units, there are 21+39*5=216 legal toned phonemes; DataBaker's observed size is 210
31 | * With syllables as pronunciation units, standard Mandarin has 414 legal toneless syllables; after the six-way split below the total dictionary size is 80 (22+6+4+39+5+4)
32 | - C: zero initial plus the initial table; 22 in total
33 | - Cx: retroflex, dental sibilant, palatal, aspirated, unaspirated, nasal; 6 in total
34 | - xV: none, y-medial, w-medial, v-medial; 4 in total
35 | - V: the final table; 39 in total
36 | - T: tones 1-4 plus the neutral tone; 5 in total
37 | - Vx: none, front-closed, back-closed, R-colored (erhua); 4 in total
38 |
39 | [Prosody modeling]
40 | Prosody marks:
41 |   0: inside a word
42 |   1: end of a word, read continuously
43 |   2: end of a word, with lengthening or a pause
44 |   3: end of a clause
45 |   4: end of the sentence
46 |   5: final marker
47 | Prosody hierarchy:
48 |   syllables#0 -> words#1 -> phrases#2 -> short clauses#3 -> full sentence#4
49 |
50 | [Tacotron experimental findings]
51 | 0. i/i0/iR sound almost identical and are interchangeable
52 | 1. j/q/x seem to carry an inherent i, i.e. ja = jia
53 |
--------------------------------------------------------------------------------
/transtacos/text/symbols.py:
--------------------------------------------------------------------------------
1 | # marks
2 | _pad = '_'    # '_', right padding for fixed-length RNN;
3 |               # also marks a short silence / break in speech
4 | _eos = '~'    # '~', end of sentence
5 | _sep = '/'    # separator between syllables
6 | _unk = '?'    # unknown symbol
7 |
8 | _markers = [_pad, _eos, _sep, _unk] # NOTE: `_pad` MUST be at index 0
9 |
10 |
11 | ''' G2P = seq '''
12 | _chars = 'abcdefghijklmnopqrstuvwxyz 12345'
13 |
14 |
15 | ''' G2P = syl4 '''
16 | # phonetic unit under syllable repr refer to `phonodict_cn.txt`
17 | # syl4 := CxVTx = C + V + T + Vx
18 | # _syl4 == [
19 | # '-', 'b', 'c', 'ch', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 'sh', 't', 'x', 'z', 'zh',
20 | # 'Ei', 'R', 'a', 'ai', 'ao', 'e', 'i', 'i0', 'iE', 'iR', 'ia', 'iao', 'io', 'iou', 'o', 'ou', 'u', 'uEi', 'ua', 'uai', 'ue', 'uo', 'v', 'vE',
21 | # '0', '1', '2', '3', '4', '5',
22 | # '_N', '_NG', '_R',
23 | # ]
24 | # TODO: consider merging 'i'/'iR' with 'i0'
25 | from .phonodict_cn import phonodict
26 | #_syl4_T = ['0', '1', '2', '3', '4', '5'] # 6, NOTE: exclude from phone table
27 | #_syl4_P = ['0', '1', '2', '3', '4', '5'] # 6, NOTE: exclude from phone table
28 | _syl4_C = phonodict.consonants # 22
29 | _syl4_V = phonodict.vowels # 24
30 | _syl4_Vx = phonodict.endings # 3
31 | _syl4 = _syl4_C + _syl4_V + _syl4_Vx # 54
32 |
33 |
34 | # phonetic unit list
35 | _g2p_mapping = {
36 | 'seq': _chars,
37 | 'syl4': _syl4,
38 | }
39 |
40 | import hparam as hp
41 |
42 | assert len(set(_g2p_mapping[hp.g2p])) == len(_g2p_mapping[hp.g2p]) # assure no duplicates
43 | _symbols = _markers + sorted(set(_g2p_mapping[hp.g2p]) - set(_markers)) # keep order
44 | print(f'[Symbols] collect {len(_symbols)} symbols in {hp.g2p} repr')
45 | print(f' {_symbols}')
46 |
47 | _symbol_to_id = {s: i for i, s in enumerate(_symbols)}
48 | _id_to_symbol = {i: s for i, s in enumerate(_symbols)}
49 |
50 |
51 | def symbol_to_id(sym:str) -> int:
52 | return _symbol_to_id.get(sym, _symbol_to_id[_unk])
53 |
54 |
55 | def id_to_symbol(id:int) -> str:
56 | return _id_to_symbol.get(id, _unk)
57 |
58 |
59 | def get_vocab_size():
60 | return len(_symbols)
61 |
62 |
63 | def get_symbol_id(s:str):
64 | return {
65 | 'pad': symbol_to_id(_pad),
66 | 'eos': symbol_to_id(_eos),
67 | 'sep': symbol_to_id(_sep),
68 | 'unk': symbol_to_id(_unk),
69 | 'vac': symbol_to_id(phonodict.vacant_symbol), # := _pad
70 | }.get(s, symbol_to_id(s))
71 |
--------------------------------------------------------------------------------
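
A sketch of the resulting symbol-table API; the hard guarantees above are that `_pad` sits at index 0 and that unknown symbols fall back to `_unk` in both directions:

from text.symbols import symbol_to_id, id_to_symbol, get_vocab_size

assert symbol_to_id('_') == 0             # `_pad` MUST be at index 0
print(get_vocab_size())                   # 4 markers + phonetic units of the active g2p
print(id_to_symbol(symbol_to_id('@')))    # unmapped input resolves to '?' (_unk)
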
/transtacos/text/text.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import List, Union
3 |
4 | import hparam as hp
5 | from .symbols import symbol_to_id, id_to_symbol
6 | from .g2p import to_syl4
7 |
8 |
9 | _whitespace_re = re.compile(r'\s+')
10 |
11 |
12 | def text_to_phoneme(text:str) -> str:
13 | # clean up
14 | text = text.strip()
15 | text = text.lower()
16 | text = re.sub(_whitespace_re, ' ', text)
17 |
18 | # g2p
19 | _converter_mapping = {
20 | 'seq': lambda _: _, # => 'str'
21 | 'syl4': to_syl4, # => [C, V, T, Vx]
22 | }
23 | phs = _converter_mapping[hp.g2p](text)
24 | return phs
25 |
26 |
27 | def phoneme_to_sequence(phoneme:Union[str, List]) -> List[int]:
28 | return [symbol_to_id(ph) for ph in phoneme]
29 |
30 |
31 | def sequence_to_phoneme(sequence:List[int]) -> str:
32 | return ''.join([id_to_symbol(id) for id in sequence])
33 |
--------------------------------------------------------------------------------
/transtacos/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from time import time, sleep
3 | from pprint import pformat
4 | from argparse import ArgumentParser
5 | import traceback
6 |
7 | import tensorflow as tf
8 | import numpy as np
9 |
10 | import hparam as hp
11 | from models.tacotron import Tacotron
12 | from data import DataFeeder
13 | from text.text import sequence_to_phoneme
14 | from audio import save_wav, inv_spec
15 | from utils import *
16 |
17 |
18 | for gpu in tf.config.experimental.list_physical_devices('GPU'):
19 | tf.config.experimental.set_memory_growth(gpu, True)
20 |
21 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
22 | os.environ['TF_XLA_FLAGS'] = '--tf_xla_cpu_global_jit'
23 | os.environ['XLA_FLAGS'] = '--xla_hlo_profile'
24 |
25 | np.random.seed (hp.randseed)
26 | tf.random.set_random_seed(hp.randseed)
27 |
28 |
29 | def train(args):
30 |   # Log folder
31 | log_dir = os.path.join(args.base_dir, args.name)
32 | os.makedirs(log_dir, exist_ok=True)
33 | log_init(os.path.join(log_dir, 'train.log'))
34 |
35 | ckpt_path = os.path.join(log_dir, 'model.ckpt')
36 | log('Checkpoint path: %s' % ckpt_path)
37 | input_path = os.path.join(args.base_dir, args.input)
38 | log('Loading training data from: %s' % input_path)
39 | log('Hyperparams:')
40 | log(pformat({k: getattr(hp, k) for k in dir(hp) if not k.startswith('__')}, indent=2))
41 |
42 | # DataFeeder
43 | coord = tf.train.Coordinator()
44 | with tf.variable_scope('datafeeder'):
45 | feeder = DataFeeder(coord, input_path, hp)
46 |
47 | # Model
48 | with tf.variable_scope('model'):
49 | model = Tacotron(hp)
50 | model.initialize(feeder.text_lengths,
51 | feeder.text, feeder.prds,
52 | feeder.spec_lengths,
53 | feeder.mel_targets, feeder.mag_targets, feeder.f0_targets, feeder.c0_targets,
54 | feeder.stop_token_targets)
55 | model.add_loss()
56 | model.add_optimizer()
57 | model.add_stats()
58 | param_count = sum([np.prod(v.get_shape()) for v in tf.trainable_variables()])
59 | log(f'param_cnt = {param_count}')
60 |
61 | # Bookkeeping
62 | step = 0 # local step
63 | time_window = ValueWindow(100) # for perfcount
64 | loss_window = ValueWindow(100)
65 | saver = tf.train.Saver(max_to_keep=hp.max_ckpt)
66 |
67 | # Train!
68 | with tf.Session() as sess:
69 | try:
70 | sw = tf.summary.FileWriter(log_dir, sess.graph)
71 | sess.run(tf.global_variables_initializer())
72 |
73 | # Restore from a checkpoint if available
74 | ckpt_state = tf.train.get_checkpoint_state(log_dir)
75 | if ckpt_state is not None:
76 | saver.restore(sess, ckpt_state.model_checkpoint_path)
77 | log('Resuming from checkpoint: %s' % ckpt_state.model_checkpoint_path)
78 | else:
79 | log('Starting new training run')
80 |
81 | feeder.start_in_session(sess)
82 | while not coord.should_stop():
83 | t = time()
84 | step, loss, opt = sess.run([model.global_step, model.loss, model.optimize])
85 | time_window.append(time() - t)
86 | loss_window.append(loss)
87 | log('Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (step, time_window.average, loss, loss_window.average))
88 |
89 | if loss > 300 or np.isnan(loss):
90 | log('Loss exploded to %.05f at step %d!' % (loss, step))
91 | raise Exception('Loss Exploded')
92 |
93 | if step % args.summary_interval == 0:
94 | log('Writing summary at step: %d' % step)
95 | sw.add_summary(sess.run(model.stats), step)
96 |
97 | if step % args.checkpoint_interval == 0:
98 | log('Saving checkpoint to: %s-%d' % (ckpt_path, step))
99 | saver.save(sess, ckpt_path, global_step=step)
100 | log('Saving audio and alignment...')
101 |
102 | if hp.g2p == 'seq':
103 | (text, mel, mag, alignment, spec_len, mel_r, mag_r, mel_loss, mag_loss) = sess.run([
104 | model.text[0],
105 | model.mel_outputs[0], model.mag_outputs[0], model.alignments[0], model.spec_lengths[0],
106 | model.mel_targets[0], model.mag_targets[0], model.mel_loss, model.mag_loss])
107 | log('Input:')
108 | log(f' seq: {text}')
109 | log(f' phs: {sequence_to_phoneme(text)}')
110 | elif hp.g2p == 'syl4':
111 | (text, prds_o, prds_r, mel, mag, alignment, spec_len, mel_r, mag_r, mel_loss, mag_loss) = sess.run([
112 | model.text[0], model.prds_out[0], model.prds[0],
113 | model.mel_outputs[0], model.mag_outputs[0], model.alignments[0], model.spec_lengths[0],
114 | model.mel_targets[0], model.mag_targets[0], model.mel_loss, model.mag_loss])
115 | CVVx, T = text.T.tolist()
116 | log('Input:')
117 | log(f' text: {sequence_to_phoneme(CVVx)}')
118 | log(f' tone: {"".join([str(t) for t in T])}')
119 | log(f' prds: {"".join([str(p) for p in prds_r])}')
120 | log(f' pred: {"".join([str(p) for p in prds_o])}')
121 |
122 | mel, mag, mel_r, mag_r = [m[:spec_len,:].T for m in [mel, mag, mel_r, mag_r]]
123 | save_wav(inv_spec(mag), os.path.join(log_dir, 'step-%d-audio.wav' % step))
124 | plot_specs([mel, mag, mel_r, mag_r], os.path.join(log_dir, 'step-%d-specs.png' % step),
125 | info=f'{time_string()}, mel_loss={mel_loss:.5f}, mag_loss={mag_loss:.5f}')
126 | plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
127 | info='%s, step=%d, loss=%.5f' % (time_string(), step, loss))
128 |
129 | if step >= hp.max_steps + 10:
130 | print('[Train] Done')
131 | sleep(5)
132 | break
133 |
134 | except Exception as e:
135 | log('Exiting due to exception: %s' % e)
136 | traceback.print_exc()
137 | coord.request_stop(e)
138 |
139 |
140 | if __name__ == '__main__':
141 | parser = ArgumentParser()
142 | parser.add_argument('--base_dir', default=os.path.expanduser('.'))
143 | parser.add_argument('--input', default='preprocessed/train.txt')
144 | parser.add_argument('--name', default='transtacos', help='Name of the run, used for logging.')
145 | parser.add_argument('--summary_interval', type=int, default=1000, help='Steps between running summary ops.')
146 | parser.add_argument('--checkpoint_interval', type=int, default=1500, help='Steps between writing checkpoints.')
147 | args = parser.parse_args()
148 |
149 | train(args)
150 |
--------------------------------------------------------------------------------
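
For reference, a typical run using the argparse defaults shown above; training resumes automatically if a checkpoint already exists under <base_dir>/<name>:

python train.py --base_dir . --input preprocessed/train.txt --name transtacos
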
/transtacos/utils.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | from datetime import datetime
3 |
4 | import matplotlib ; matplotlib.use('Agg')
5 | import matplotlib.pyplot as plt
6 | import seaborn as sns
7 |
8 | import hparam as hp
9 |
10 | _log_fmt = '%Y-%m-%d %H:%M:%S.%f'
11 | _log_fp = None
12 |
13 |
14 | def log_init(fp):
15 | global _log_fp
16 | _close_logfile()
17 |
18 | _log_fp = open(fp, 'a')
19 | _log_fp.write('\n')
20 | _log_fp.write('-----------------------------------------------------------------\n')
21 | _log_fp.write(' Starting new training run\n')
22 | _log_fp.write('-----------------------------------------------------------------\n')
23 |
24 |
25 | def log(msg):
26 | print(msg)
27 | if _log_fp:
28 | _log_fp.write('[%s] %s\n' % (datetime.now().strftime(_log_fmt)[:-3], msg))
29 |
30 |
31 | def _close_logfile():
32 | global _log_fp
33 | if _log_fp:
34 | _log_fp.close()
35 | _log_fp = None
36 |
37 |
38 | atexit.register(_close_logfile)
39 |
40 |
41 | def plot_alignment(alignment, path, info=None):
42 | fig, ax = plt.subplots()
43 | im = ax.imshow(
44 | alignment,
45 | aspect='auto',
46 | origin='lower',
47 | interpolation='none')
48 | fig.colorbar(im, ax=ax)
49 | plt.xlabel('Decoder timestep' + (f'\n\n{info}' if info else ''))
50 | plt.ylabel('Encoder timestep')
51 | plt.tight_layout()
52 | plt.savefig(path, format='png')
53 |
54 |
55 | def plot_specs(specs, path, info=None):
56 | # mel_g mel_r
57 | # mag_g mag_r
58 | ax = plt.subplot(221) ; sns.heatmap(specs[0]) ; ax.invert_yaxis()
59 | ax = plt.subplot(222) ; sns.heatmap(specs[2]) ; ax.invert_yaxis()
60 | ax = plt.subplot(223) ; sns.heatmap(specs[1]) ; ax.invert_yaxis()
61 | ax = plt.subplot(224) ; sns.heatmap(specs[3]) ; ax.invert_yaxis()
62 | plt.xlabel(info)
63 | plt.tight_layout()
64 | plt.margins(0, 0)
65 | plt.savefig(path, format='png', dpi=400)
66 |
67 |
68 | def time_string():
69 | return datetime.now().strftime('%Y-%m-%d %H:%M')
70 |
71 |
72 | class ValueWindow():    # NOTE: fixed-length sliding window (push right, drop left)
73 |
74 | def __init__(self, window_size=100):
75 | self._window_size = window_size
76 | self._values = []
77 |
78 | def append(self, x):
79 | self._values = self._values[-(self._window_size - 1):] + [x]
80 |
81 | @property
82 | def sum(self):
83 | return sum(self._values)
84 |
85 | @property
86 | def count(self):
87 | return len(self._values)
88 |
89 | @property
90 | def average(self):
91 | return self.sum / max(1, self.count)
92 |
93 | def reset(self):
94 | self._values = []
95 |
--------------------------------------------------------------------------------
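
A tiny check of the ValueWindow semantics above: with window_size=3, appending 1, 2, 3, 4 keeps only the last three values:

w = ValueWindow(window_size=3)
for x in [1, 2, 3, 4]:
    w.append(x)
assert (w.count, w.sum, w.average) == (3, 9, 3.0)
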