4 |
5 | ## Overview
6 |
7 | FastVC is a fast and efficient, non-parallel, any-to-any *voice conversion (VC)* tool. VC modifies the voice of a source speaker so that it sounds like that of a target speaker, without changing the linguistic content of the sentence. Our tool tackles the task by cascading an Automatic Speech Recognition (ASR) model and a Text-To-Speech (TTS) model.
8 |
13 | The ASR is based on [Wav2vec 2.0](https://arxiv.org/pdf/2006.11477.pdf) and is used to transcribe the speech from a source speaker. The TTS is based on [SV2TTS](https://arxiv.org/pdf/1806.04558.pdf) and is used to generate the output speech from a target speaker embedding.
14 |
15 | For a more detailed explanation, check out [our project paper](https://github.com/fmiotello/fastVC/files/6849958/L11_Report.pdf). A demo page is available [here](https://fmiotello.github.io/fastVC).
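
Conceptually, the whole conversion boils down to a handful of calls. The sketch below is distilled from `src/main.py`; it assumes the pretrained models described in the next section are in place, that it is run from the repository root with `src/` on the Python path, and the audio paths are only illustrative:

````python
from pathlib import Path

import librosa
import numpy as np
import soundfile as sf
import torch
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC

from encoder import inference as encoder
from synthesizer.inference import Synthesizer
from vocoder import inference as vocoder

# ASR (Wav2vec 2.0): transcribe the source speaker
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
asr = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
source, _ = librosa.load("src/audio/source.wav", sr=16000)
logits = asr(tokenizer(source, return_tensors="pt").input_values).logits
text = tokenizer.batch_decode(torch.argmax(logits, dim=-1))[0]

# Speaker embedding of the target speaker
encoder.load_model(Path("src/encoder/saved_models/pretrained.pt"))
target, _ = librosa.load("src/audio/target.wav", sr=16000)
embed = encoder.embed_utterance(encoder.preprocess_wav(target, 16000))

# TTS (SV2TTS): synthesize the transcription with the target voice
synthesizer = Synthesizer(Path("src/synthesizer/saved_models/pretrained/pretrained.pt"))
vocoder.load_model(Path("src/vocoder/saved_models/pretrained/pretrained.pt"))
spec = synthesizer.synthesize_spectrograms([text], [embed])[0]
wav = vocoder.infer_waveform(spec)
sf.write("src/audio/audio_out.wav", wav.astype(np.float32), 16000)
````

`src/main.py` wraps these same steps and adds argument parsing, optional silence trimming of the output (`--enhance`) and evaluation metrics (`--metrics`).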
16 |
17 | ## Installation & usage
18 |
19 | The software was implemented using `python 3.9.4`.
20 |
21 | 1. Clone the repository (`git clone https://github.com/fmiotello/fastVC.git`) and enter the directory (`cd fastVC`)
22 | 2. (*optional*) Create a virtual environment and activate it: `python -m venv env`, then `source env/bin/activate` (macOS/Linux) or `.\env\Scripts\activate` (Windows)
23 | 3. Upgrade pip: `python -m pip install --upgrade pip`
24 | 4. Install dependencies: `python -m pip install -r requirements.txt`
25 | 5. Download the pretrained models ([encoder](https://drive.google.com/file/d/1q8mEGwCkFy23KZsinbuvdKAQLqNKbYf1/view?usp=sharing), [synthesizer](https://drive.google.com/file/d/1EqFMIbvxffxtjiVrtykroF6_mUh-5Z3s/view?usp=sharing), [vocoder](https://drive.google.com/file/d/1cf2NO6FtI0jDuy8AV3Xgn6leO6dHjIgu/view?usp=sharing)) and put them in the correct directories:
26 | ````
27 | ./src/encoder/saved_models/pretrained.pt
28 | ./src/synthesizer/saved_models/pretrained/pretrained.pt
29 | ./src/vocoder/saved_models/pretrained/pretrained.pt
30 | ````
31 | 6. Run the main script: `python src/main.py` (use `--help` to display the available options). The output audio will be saved to `./src/audio/audio_out.wav`.
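For example, `python src/main.py --source path/to/source.wav --target path/to/target.wav --metrics` converts the utterance in `source.wav` into the voice of the speaker in `target.wav` and prints some evaluation metrics. The two paths are placeholders; when they are omitted, the default files in `src/audio/` are used.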
32 |
33 | More instructions can be found [here](https://github.com/fmiotello/fastVC/tree/main/src).
34 |
35 | ## Notes
36 |
37 | This application was developed as a project at [Politecnico di Milano](https://www.polimi.it/en/) (MSc in Music and Acoustic Engineering).
38 |
39 | *[Luigi Attorresi](https://github.com/LuigiAttorresi)*
40 | *[Federico Miotello](https://github.com/fmiotello)*
41 | *[Eugenio Poliuti](https://github.com/Poliuti)*
42 |
--------------------------------------------------------------------------------
/src/vocoder/audio.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import librosa
4 | import vocoder.hparams as hp
5 | from scipy.signal import lfilter
6 | import soundfile as sf
7 |
8 |
9 | def label_2_float(x, bits) :
10 | return 2 * x / (2**bits - 1.) - 1.
11 |
12 |
13 | def float_2_label(x, bits) :
14 | assert abs(x).max() <= 1.0
15 | x = (x + 1.) * (2**bits - 1) / 2
16 | return x.clip(0, 2**bits - 1)
17 |
18 |
19 | def load_wav(path) :
20 | return librosa.load(str(path), sr=hp.sample_rate)[0]
21 |
22 |
23 | def save_wav(x, path) :
24 | sf.write(path, x.astype(np.float32), hp.sample_rate)
25 |
26 |
27 | def split_signal(x) :
28 | unsigned = x + 2**15
29 | coarse = unsigned // 256
30 | fine = unsigned % 256
31 | return coarse, fine
32 |
33 |
34 | def combine_signal(coarse, fine) :
35 | return coarse * 256 + fine - 2**15
36 |
37 |
38 | def encode_16bits(x) :
39 | return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)
40 |
41 |
42 | mel_basis = None
43 |
44 |
45 | def linear_to_mel(spectrogram):
46 | global mel_basis
47 | if mel_basis is None:
48 | mel_basis = build_mel_basis()
49 | return np.dot(mel_basis, spectrogram)
50 |
51 |
52 | def build_mel_basis():
53 | return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin)
54 |
55 |
56 | def normalize(S):
57 | return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1)
58 |
59 |
60 | def denormalize(S):
61 | return (np.clip(S, 0, 1) * -hp.min_level_db) + hp.min_level_db
62 |
63 |
64 | def amp_to_db(x):
65 | return 20 * np.log10(np.maximum(1e-5, x))
66 |
67 |
68 | def db_to_amp(x):
69 | return np.power(10.0, x * 0.05)
70 |
71 |
72 | def spectrogram(y):
73 | D = stft(y)
74 | S = amp_to_db(np.abs(D)) - hp.ref_level_db
75 | return normalize(S)
76 |
77 |
78 | def melspectrogram(y):
79 | D = stft(y)
80 | S = amp_to_db(linear_to_mel(np.abs(D)))
81 | return normalize(S)
82 |
83 |
84 | def stft(y):
85 | return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
86 |
87 |
88 | def pre_emphasis(x):
89 | return lfilter([1, -hp.preemphasis], [1], x)
90 |
91 |
92 | def de_emphasis(x):
93 | return lfilter([1], [1, -hp.preemphasis], x)
94 |
95 |
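# Mu-law companding: encode_mu_law maps a float signal in [-1, 1] to integer
# labels in [0, mu - 1] (with mu = 2**bits), and decode_mu_law inverts the
# mapping, optionally converting label indices back to floats via label_2_float.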
96 | def encode_mu_law(x, mu) :
97 | mu = mu - 1
98 | fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
99 | return np.floor((fx + 1) / 2 * mu + 0.5)
100 |
101 |
102 | def decode_mu_law(y, mu, from_labels=True) :
103 | if from_labels:
104 | y = label_2_float(y, math.log2(mu))
105 | mu = mu - 1
106 | x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)
107 | return x
108 |
--------------------------------------------------------------------------------
/src/synthesizer/utils/cleaners.py:
--------------------------------------------------------------------------------
1 | """
2 | Cleaners are transformations that run over the input text at both training and eval time.
3 |
4 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
5 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
6 | 1. "english_cleaners" for English text
7 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
8 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
9 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
10 | the symbols in symbols.py to match your data).
11 | """
12 |
13 | import re
14 | from unidecode import unidecode
15 | from .numbers import normalize_numbers
16 |
17 | # Regular expression matching whitespace:
18 | _whitespace_re = re.compile(r"\s+")
19 |
20 | # List of (regular expression, replacement) pairs for abbreviations:
21 | _abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
22 | ("mrs", "misess"),
23 | ("mr", "mister"),
24 | ("dr", "doctor"),
25 | ("st", "saint"),
26 | ("co", "company"),
27 | ("jr", "junior"),
28 | ("maj", "major"),
29 | ("gen", "general"),
30 | ("drs", "doctors"),
31 | ("rev", "reverend"),
32 | ("lt", "lieutenant"),
33 | ("hon", "honorable"),
34 | ("sgt", "sergeant"),
35 | ("capt", "captain"),
36 | ("esq", "esquire"),
37 | ("ltd", "limited"),
38 | ("col", "colonel"),
39 | ("ft", "fort"),
40 | ]]
41 |
42 |
43 | def expand_abbreviations(text):
44 | for regex, replacement in _abbreviations:
45 | text = re.sub(regex, replacement, text)
46 | return text
47 |
48 |
49 | def expand_numbers(text):
50 | return normalize_numbers(text)
51 |
52 |
53 | def lowercase(text):
54 | """lowercase input tokens."""
55 | return text.lower()
56 |
57 |
58 | def collapse_whitespace(text):
59 | return re.sub(_whitespace_re, " ", text)
60 |
61 |
62 | def convert_to_ascii(text):
63 | return unidecode(text)
64 |
65 |
66 | def basic_cleaners(text):
67 | """Basic pipeline that lowercases and collapses whitespace without transliteration."""
68 | text = lowercase(text)
69 | text = collapse_whitespace(text)
70 | return text
71 |
72 |
73 | def transliteration_cleaners(text):
74 | """Pipeline for non-English text that transliterates to ASCII."""
75 | text = convert_to_ascii(text)
76 | text = lowercase(text)
77 | text = collapse_whitespace(text)
78 | return text
79 |
80 |
81 | def english_cleaners(text):
82 | """Pipeline for English text, including number and abbreviation expansion."""
83 | text = convert_to_ascii(text)
84 | text = lowercase(text)
85 | text = expand_numbers(text)
86 | text = expand_abbreviations(text)
87 | text = collapse_whitespace(text)
88 | return text
89 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.12.0
2 | appdirs==1.4.4
3 | asrtoolkit @ https://github.com/finos/greenkey-asrtoolkit/archive/refs/tags/0.2.4.tar.gz
4 | astunparse==1.6.3
5 | attrs==21.2.0
6 | audioread==2.1.9
7 | beautifulsoup4==4.9.3
8 | cachetools==4.2.2
9 | certifi==2020.12.5
10 | cffi==1.14.5
11 | chardet==4.0.0
12 | click==7.1.2
13 | cycler==0.10.0
14 | datasets==1.6.2
15 | decorator==5.0.7
16 | dill==0.3.3
17 | docopt==0.6.2
18 | editdistance==0.5.3
19 | ffmpeg==1.4
20 | ffmpeg-python==0.2.0
21 | filelock==3.0.12
22 | fire==0.4.0
23 | flatbuffers==1.12
24 | fsspec==2021.4.0
25 | future==0.18.2
26 | gast==0.4.0
27 | google-auth==1.30.1
28 | google-auth-oauthlib==0.4.4
29 | google-pasta==0.2.0
30 | grpcio==1.34.1
31 | h5py==3.1.0
32 | huggingface-hub==0.0.8
33 | idna==2.10
34 | inflect==5.3.0
35 | iniconfig==1.1.1
36 | jiwer==2.2.0
37 | joblib==1.0.1
38 | jsonpatch==1.32
39 | jsonpointer==2.1
40 | jsonschema==3.2.0
41 | keras-nightly==2.5.0.dev2021032900
42 | Keras-Preprocessing==1.1.2
43 | kiwisolver==1.3.1
44 | librosa==0.8.0
45 | llvmlite==0.36.0
46 | Markdown==3.3.4
47 | matplotlib==3.4.1
48 | mock==4.0.3
49 | multiprocess==0.70.11.1
50 | musdb==0.4.0
51 | museval==0.4.0
52 | nltk==3.6.2
53 | nose==1.3.7
54 | num2words==0.5.10
55 | numba==0.53.1
56 | numpy==1.19.3; platform_system == "Windows"
57 | numpy==1.19.4; platform_system != "Windows"
58 | oauthlib==3.1.0
59 | opt-einsum==3.3.0
60 | packaging==20.9
61 | pandas==1.2.4
62 | pesq==0.0.3
63 | Pillow==8.2.0
64 | pluggy==0.13.1
65 | pooch==1.3.0
66 | protobuf==3.17.1
67 | py==1.10.0
68 | pyaml==20.4.0
69 | pyarrow==4.0.0
70 | pyasn1==0.4.8
71 | pyasn1-modules==0.2.8
72 | pycparser==2.20
73 | pynndescent==0.5.2
74 | pyparsing==2.4.7
75 | pypesq==1.2.4
76 | PyQt5==5.15.4
77 | PyQt5-Qt5==5.15.2
78 | PyQt5-sip==12.8.1
79 | pyrsistent==0.17.3
80 | pystoi==0.3.3
81 | pytest==6.2.4
82 | python-dateutil==2.8.1
83 | python-Levenshtein==0.12.2
84 | pytz==2021.1
85 | PyYAML==5.4.1
86 | pyzmq==22.0.3
87 | regex==2021.4.4
88 | requests==2.25.1
89 | requests-oauthlib==1.3.0
90 | resampy==0.2.2
91 | rsa==4.7.2
92 | sacremoses==0.0.45
93 | scikit-learn==0.24.2
94 | scipy==1.6.3
95 | simplejson==3.17.2
96 | six==1.15.0
97 | sounddevice==0.4.1
98 | SoundFile==0.10.3.post1
99 | soupsieve==2.2.1
100 | git+https://github.com/aliutkus/speechmetrics#egg=speechmetrics[cpu]
101 | stempeg==0.2.3
102 | tensorboard==2.5.0
103 | tensorboard-data-server==0.6.1
104 | tensorboard-plugin-wit==1.8.0
105 | tensorflow==2.5.0
106 | tensorflow-estimator==2.5.0
107 | termcolor==1.1.0
108 | threadpoolctl==2.1.0
109 | tokenizers==0.10.2
110 | toml==0.10.2
111 | torch==1.8.1
112 | torchaudio==0.8.1
113 | torchfile==0.1.0
114 | torchvision==0.9.1
115 | tornado==6.1
116 | tqdm==4.49.0
117 | transformers==4.5.1
118 | typing-extensions==3.7.4.3
119 | umap-learn==0.5.1
120 | Unidecode==1.2.0
121 | urllib3==1.26.5
122 | visdom==0.1.8.9
123 | webrtcvad==2.0.10; platform_system != "Windows"
124 | websocket-client==0.59.0
125 | webvtt-py==0.4.6
126 | Werkzeug==2.0.1
127 | wrapt==1.12.1
128 | xlrd==1.2.0
129 | xxhash==2.0.2
130 |
--------------------------------------------------------------------------------
/src/vocoder/display.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import time
3 | import numpy as np
4 | import sys
5 |
6 |
7 | def progbar(i, n, size=16):
8 |     done = (i * size) // n
9 |     bar = ''
10 |     for step in range(size):
11 |         bar += '█' if step <= done else '░'
12 |     return bar
13 |
14 |
15 | def stream(message) :
16 | try:
17 | sys.stdout.write("\r{%s}" % message)
18 | except:
19 | #Remove non-ASCII characters from message
20 | message = ''.join(i for i in message if ord(i)<128)
21 | sys.stdout.write("\r{%s}" % message)
22 |
23 |
24 | def simple_table(item_tuples) :
25 |
26 | border_pattern = '+---------------------------------------'
27 | whitespace = ' '
28 |
29 | headings, cells, = [], []
30 |
31 | for item in item_tuples :
32 |
33 | heading, cell = str(item[0]), str(item[1])
34 |
35 | pad_head = True if len(heading) < len(cell) else False
36 |
37 | pad = abs(len(heading) - len(cell))
38 | pad = whitespace[:pad]
39 |
40 | pad_left = pad[:len(pad)//2]
41 | pad_right = pad[len(pad)//2:]
42 |
43 | if pad_head :
44 | heading = pad_left + heading + pad_right
45 | else :
46 | cell = pad_left + cell + pad_right
47 |
48 | headings += [heading]
49 | cells += [cell]
50 |
51 | border, head, body = '', '', ''
52 |
53 | for i in range(len(item_tuples)) :
54 |
55 | temp_head = f'| {headings[i]} '
56 | temp_body = f'| {cells[i]} '
57 |
58 | border += border_pattern[:len(temp_head)]
59 | head += temp_head
60 | body += temp_body
61 |
62 | if i == len(item_tuples) - 1 :
63 | head += '|'
64 | body += '|'
65 | border += '+'
66 |
67 | print(border)
68 | print(head)
69 | print(border)
70 | print(body)
71 | print(border)
72 | print(' ')
73 |
74 |
75 | def time_since(started) :
76 | elapsed = time.time() - started
77 | m = int(elapsed // 60)
78 | s = int(elapsed % 60)
79 | if m >= 60 :
80 | h = int(m // 60)
81 | m = m % 60
82 | return f'{h}h {m}m {s}s'
83 | else :
84 | return f'{m}m {s}s'
85 |
86 |
87 | def save_attention(attn, path) :
88 | fig = plt.figure(figsize=(12, 6))
89 | plt.imshow(attn.T, interpolation='nearest', aspect='auto')
90 | fig.savefig(f'{path}.png', bbox_inches='tight')
91 | plt.close(fig)
92 |
93 |
94 | def save_spectrogram(M, path, length=None) :
95 | M = np.flip(M, axis=0)
96 | if length : M = M[:, :length]
97 | fig = plt.figure(figsize=(12, 6))
98 | plt.imshow(M, interpolation='nearest', aspect='auto')
99 | fig.savefig(f'{path}.png', bbox_inches='tight')
100 | plt.close(fig)
101 |
102 |
103 | def plot(array) :
104 | fig = plt.figure(figsize=(30, 5))
105 | ax = fig.add_subplot(111)
106 | ax.xaxis.label.set_color('grey')
107 | ax.yaxis.label.set_color('grey')
108 | ax.xaxis.label.set_fontsize(23)
109 | ax.yaxis.label.set_fontsize(23)
110 | ax.tick_params(axis='x', colors='grey', labelsize=23)
111 | ax.tick_params(axis='y', colors='grey', labelsize=23)
112 | plt.plot(array)
113 |
114 |
115 | def plot_spec(M) :
116 | M = np.flip(M, axis=0)
117 | plt.figure(figsize=(18,4))
118 | plt.imshow(M, interpolation='nearest', aspect='auto')
119 | plt.show()
120 |
121 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ---
2 | layout: default
3 | title: Fast VC
4 | ---
5 |
6 | ## Abstract
7 |
8 | Current state-of-the-art *voice conversion (VC)* tools rely on neural models trained for hundreds of hours on massive corpora of data. This approach surely leads to astonishing results, but it lacks speed, simplicity and accessibility. In this paper we introduce a simple and fast *any-to-any non-parallel* voice conversion tool that can perform its task when provided with only a small audio excerpt of the target speaker. We take a modular approach to VC, cascading an *automatic speech recognition (ASR)* model, used to transcribe the source speech, and a *text-to-speech (TTS)* model, used to generate the target speech. This approach offers a straightforward pipeline, allows the use of already available models and opens the door to many extensions. We show that our output is intelligible and that different speakers remain distinguishable.
9 |
10 | ## Audio examples
11 |
12 | ### Class 1
13 | Source speaker · Target speaker · Output
14 |
15 | ### Class 2
16 | Source speaker · Target speaker · Output
17 |
18 | ### Class 3
19 | Source speaker · Target speaker · Output
20 |
21 | ### Class 4
22 | Source speaker · Target speaker · Output
23 |
24 | ### Class 5
25 | Source speaker · Target speaker · Output
--------------------------------------------------------------------------------
/src/synthesizer/hparams.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import pprint
3 |
4 | class HParams(object):
5 | def __init__(self, **kwargs): self.__dict__.update(kwargs)
6 | def __setitem__(self, key, value): setattr(self, key, value)
7 | def __getitem__(self, key): return getattr(self, key)
8 | def __repr__(self): return pprint.pformat(self.__dict__)
9 |
10 | def parse(self, string):
11 | # Overrides hparams from a comma-separated string of name=value pairs
12 | if len(string) > 0:
13 | overrides = [s.split("=") for s in string.split(",")]
14 | keys, values = zip(*overrides)
15 | keys = list(map(str.strip, keys))
16 | values = list(map(str.strip, values))
17 | for k in keys:
18 | self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
19 | return self
20 |
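# Example: hparams.parse("sample_rate=22050,rescale=False") overrides the
# corresponding attributes below; each value is parsed with ast.literal_eval,
# so it must be a valid Python literal.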
21 | hparams = HParams(
22 | ### Signal Processing (used in both synthesizer and vocoder)
23 | sample_rate = 16000,
24 | n_fft = 800,
25 | num_mels = 80,
26 | hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
27 | win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
28 | fmin = 55,
29 | min_level_db = -100,
30 | ref_level_db = 20,
31 | max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small.
32 | preemphasis = 0.97, # Filter coefficient to use if preemphasize is True
33 | preemphasize = True,
34 |
35 | ### Tacotron Text-to-Speech (TTS)
36 | tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs
37 | tts_encoder_dims = 256,
38 | tts_decoder_dims = 128,
39 | tts_postnet_dims = 512,
40 | tts_encoder_K = 5,
41 | tts_lstm_dims = 1024,
42 | tts_postnet_K = 5,
43 | tts_num_highways = 4,
44 | tts_dropout = 0.5,
45 | tts_cleaner_names = ["english_cleaners"],
46 | tts_stop_threshold = -3.4, # Value below which audio generation ends.
47 | # For example, for a range of [-4, 4], this
48 | # will terminate the sequence at the first
49 | # frame that has all values < -3.4
50 |
51 | ### Tacotron Training
52 | tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule
53 | (2, 5e-4, 40_000, 12), # (r, lr, step, batch_size)
54 | (2, 2e-4, 80_000, 12), #
55 | (2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
56 | (2, 3e-5, 320_000, 12), # synthesized for each decoder iteration)
57 | (2, 1e-5, 640_000, 12)], # lr = learning rate
58 |
59 | tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
60 | tts_eval_interval = 500, # Number of steps between model evaluation (sample generation)
61 | # Set to -1 to generate after completing epoch, or 0 to disable
62 |
63 | tts_eval_num_samples = 1, # Makes this number of samples
64 |
65 | ### Data Preprocessing
66 | max_mel_frames = 900,
67 | rescale = True,
68 | rescaling_max = 0.9,
69 | synthesis_batch_size = 16, # For vocoder preprocessing and inference.
70 |
71 | ### Mel Visualization and Griffin-Lim
72 | signal_normalization = True,
73 | power = 1.5,
74 | griffin_lim_iters = 60,
75 |
76 | ### Audio processing options
77 | fmax = 7600, # Should not exceed (sample_rate // 2)
78 | allow_clipping_in_normalization = True, # Used when signal_normalization = True
79 | clip_mels_length = True, # If true, discards samples exceeding max_mel_frames
80 | use_lws = False, # "Fast spectrogram phase recovery using local weighted sums"
81 | symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,
82 | # and [0, max_abs_value] if False
83 | trim_silence = True, # Use with sample_rate of 16000 for best results
84 |
85 | ### SV2TTS
86 | speaker_embedding_size = 256, # Dimension for the speaker embedding
87 | silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
88 | utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded
89 | )
90 |
91 | def hparams_debug_string():
92 | return str(hparams)
93 |
--------------------------------------------------------------------------------
/src/encoder/audio.py:
--------------------------------------------------------------------------------
1 | from scipy.ndimage.morphology import binary_dilation
2 | from encoder.params_data import *
3 | from pathlib import Path
4 | from typing import Optional, Union
5 | from warnings import warn
6 | import numpy as np
7 | import librosa
8 | import struct
9 |
10 | try:
11 | import webrtcvad
12 | except:
13 | warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
14 | webrtcvad=None
15 |
16 | int16_max = (2 ** 15) - 1
17 |
18 |
19 | def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
20 | source_sr: Optional[int] = None,
21 | normalize: Optional[bool] = True,
22 | trim_silence: Optional[bool] = True):
23 | """
24 | Applies the preprocessing operations used in training the Speaker Encoder to a waveform
25 | either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
26 |
27 | :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
28 | just .wav), either the waveform as a numpy array of floats.
29 | :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
30 | preprocessing. After preprocessing, the waveform's sampling rate will match the data
31 | hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
32 | this argument will be ignored.
33 | """
34 | # Load the wav from disk if needed
35 | if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
36 | wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
37 | else:
38 | wav = fpath_or_wav
39 |
40 | # Resample the wav if needed
41 | if source_sr is not None and source_sr != sampling_rate:
42 | wav = librosa.resample(wav, source_sr, sampling_rate)
43 |
44 | # Apply the preprocessing: normalize volume and shorten long silences
45 | if normalize:
46 | wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
47 | if webrtcvad and trim_silence:
48 | wav = trim_long_silences(wav)
49 |
50 | return wav
51 |
52 |
53 | def wav_to_mel_spectrogram(wav):
54 | """
55 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
56 |     Note: this is not a log-mel spectrogram.
57 | """
58 | frames = librosa.feature.melspectrogram(
59 | wav,
60 | sampling_rate,
61 | n_fft=int(sampling_rate * mel_window_length / 1000),
62 | hop_length=int(sampling_rate * mel_window_step / 1000),
63 | n_mels=mel_n_channels
64 | )
65 | return frames.astype(np.float32).T
66 |
67 |
68 | def trim_long_silences(wav):
69 | """
70 | Ensures that segments without voice in the waveform remain no longer than a
71 | threshold determined by the VAD parameters in params.py.
72 |
73 | :param wav: the raw waveform as a numpy array of floats
74 | :return: the same waveform with silences trimmed away (length <= original wav length)
75 | """
76 | # Compute the voice detection window size
77 | samples_per_window = (vad_window_length * sampling_rate) // 1000
78 |
79 | # Trim the end of the audio to have a multiple of the window size
80 | wav = wav[:len(wav) - (len(wav) % samples_per_window)]
81 |
82 | # Convert the float waveform to 16-bit mono PCM
83 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
84 |
85 | # Perform voice activation detection
86 | voice_flags = []
87 | vad = webrtcvad.Vad(mode=3)
88 | for window_start in range(0, len(wav), samples_per_window):
89 | window_end = window_start + samples_per_window
90 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
91 | sample_rate=sampling_rate))
92 | voice_flags = np.array(voice_flags)
93 |
94 | # Smooth the voice detection with a moving average
95 | def moving_average(array, width):
96 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
97 | ret = np.cumsum(array_padded, dtype=float)
98 | ret[width:] = ret[width:] - ret[:-width]
99 | return ret[width - 1:] / width
100 |
101 | audio_mask = moving_average(voice_flags, vad_moving_average_width)
102 | audio_mask = np.round(audio_mask).astype(np.bool)
103 |
104 | # Dilate the voiced regions
105 | audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
106 | audio_mask = np.repeat(audio_mask, samples_per_window)
107 |
108 | return wav[audio_mask == True]
109 |
110 |
111 | def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
112 | if increase_only and decrease_only:
113 | raise ValueError("Both increase only and decrease only are set")
114 | dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
115 | if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
116 | return wav
117 | return wav * (10 ** (dBFS_change / 20))
118 |
--------------------------------------------------------------------------------
/src/vocoder/distribution.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | def log_sum_exp(x):
7 | """ numerically stable log_sum_exp implementation that prevents overflow """
8 | # TF ordering
9 | axis = len(x.size()) - 1
10 | m, _ = torch.max(x, dim=axis)
11 | m2, _ = torch.max(x, dim=axis, keepdim=True)
12 | return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
13 |
14 |
15 | # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
16 | def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
17 | log_scale_min=None, reduce=True):
18 | if log_scale_min is None:
19 | log_scale_min = float(np.log(1e-14))
20 | y_hat = y_hat.permute(0,2,1)
21 | assert y_hat.dim() == 3
22 | assert y_hat.size(1) % 3 == 0
23 | nr_mix = y_hat.size(1) // 3
24 |
25 | # (B x T x C)
26 | y_hat = y_hat.transpose(1, 2)
27 |
28 | # unpack parameters. (B, T, num_mixtures) x 3
29 | logit_probs = y_hat[:, :, :nr_mix]
30 | means = y_hat[:, :, nr_mix:2 * nr_mix]
31 | log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)
32 |
33 | # B x T x 1 -> B x T x num_mixtures
34 | y = y.expand_as(means)
35 |
36 | centered_y = y - means
37 | inv_stdv = torch.exp(-log_scales)
38 | plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
39 | cdf_plus = torch.sigmoid(plus_in)
40 | min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
41 | cdf_min = torch.sigmoid(min_in)
42 |
43 | # log probability for edge case of 0 (before scaling)
44 | # equivalent: torch.log(F.sigmoid(plus_in))
45 | log_cdf_plus = plus_in - F.softplus(plus_in)
46 |
47 | # log probability for edge case of 255 (before scaling)
48 | # equivalent: (1 - F.sigmoid(min_in)).log()
49 | log_one_minus_cdf_min = -F.softplus(min_in)
50 |
51 | # probability for all other cases
52 | cdf_delta = cdf_plus - cdf_min
53 |
54 | mid_in = inv_stdv * centered_y
55 | # log probability in the center of the bin, to be used in extreme cases
56 | # (not actually used in our code)
57 | log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
58 |
59 | # tf equivalent
60 | """
61 | log_probs = tf.where(x < -0.999, log_cdf_plus,
62 | tf.where(x > 0.999, log_one_minus_cdf_min,
63 | tf.where(cdf_delta > 1e-5,
64 | tf.log(tf.maximum(cdf_delta, 1e-12)),
65 | log_pdf_mid - np.log(127.5))))
66 | """
67 | # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
68 | # for num_classes=65536 case? 1e-7? not sure..
69 | inner_inner_cond = (cdf_delta > 1e-5).float()
70 |
71 | inner_inner_out = inner_inner_cond * \
72 | torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
73 | (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
74 | inner_cond = (y > 0.999).float()
75 | inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
76 | cond = (y < -0.999).float()
77 | log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
78 |
79 | log_probs = log_probs + F.log_softmax(logit_probs, -1)
80 |
81 | if reduce:
82 | return -torch.mean(log_sum_exp(log_probs))
83 | else:
84 | return -log_sum_exp(log_probs).unsqueeze(-1)
85 |
86 |
87 | def sample_from_discretized_mix_logistic(y, log_scale_min=None):
88 | """
89 | Sample from discretized mixture of logistic distributions
90 | Args:
91 | y (Tensor): B x C x T
92 | log_scale_min (float): Log scale minimum value
93 | Returns:
94 | Tensor: sample in range of [-1, 1].
95 | """
96 | if log_scale_min is None:
97 | log_scale_min = float(np.log(1e-14))
98 | assert y.size(1) % 3 == 0
99 | nr_mix = y.size(1) // 3
100 |
101 | # B x T x C
102 | y = y.transpose(1, 2)
103 | logit_probs = y[:, :, :nr_mix]
104 |
105 | # sample mixture indicator from softmax
106 | temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
107 | temp = logit_probs.data - torch.log(- torch.log(temp))
108 | _, argmax = temp.max(dim=-1)
109 |
110 | # (B, T) -> (B, T, nr_mix)
111 | one_hot = to_one_hot(argmax, nr_mix)
112 | # select logistic parameters
113 | means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
114 | log_scales = torch.clamp(torch.sum(
115 | y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
116 | # sample from logistic & clip to interval
117 | # we don't actually round to the nearest 8bit value when sampling
118 | u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
119 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
120 |
121 | x = torch.clamp(torch.clamp(x, min=-1.), max=1.)
122 |
123 | return x
124 |
125 |
126 | def to_one_hot(tensor, n, fill_with=1.):
127 |     # we perform one-hot encoding with respect to the last axis
128 | one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
129 | if tensor.is_cuda:
130 | one_hot = one_hot.cuda()
131 | one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
132 | return one_hot
133 |
--------------------------------------------------------------------------------
/src/encoder/model.py:
--------------------------------------------------------------------------------
1 | from encoder.params_model import *
2 | from encoder.params_data import *
3 | from scipy.interpolate import interp1d
4 | from sklearn.metrics import roc_curve
5 | from torch.nn.utils import clip_grad_norm_
6 | from scipy.optimize import brentq
7 | from torch import nn
8 | import numpy as np
9 | import torch
10 |
11 |
12 | class SpeakerEncoder(nn.Module):
13 | def __init__(self, device, loss_device):
14 | super().__init__()
15 | self.loss_device = loss_device
16 |
17 |         # Network definition
18 | self.lstm = nn.LSTM(input_size=mel_n_channels,
19 | hidden_size=model_hidden_size,
20 | num_layers=model_num_layers,
21 | batch_first=True).to(device)
22 | self.linear = nn.Linear(in_features=model_hidden_size,
23 | out_features=model_embedding_size).to(device)
24 | self.relu = torch.nn.ReLU().to(device)
25 |
26 | # Cosine similarity scaling (with fixed initial parameter values)
27 | self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28 | self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29 |
30 | # Loss
31 | self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32 |
33 | def do_gradient_ops(self):
34 | # Gradient scale
35 | self.similarity_weight.grad *= 0.01
36 | self.similarity_bias.grad *= 0.01
37 |
38 | # Gradient clipping
39 | clip_grad_norm_(self.parameters(), 3, norm_type=2)
40 |
41 | def forward(self, utterances, hidden_init=None):
42 | """
43 | Computes the embeddings of a batch of utterance spectrograms.
44 |
45 | :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46 | (batch_size, n_frames, n_channels)
47 | :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48 | batch_size, hidden_size). Will default to a tensor of zeros if None.
49 | :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50 | """
51 | # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52 | # and the final cell state.
53 | out, (hidden, cell) = self.lstm(utterances, hidden_init)
54 |
55 | # We take only the hidden state of the last layer
56 | embeds_raw = self.relu(self.linear(hidden[-1]))
57 |
58 | # L2-normalize it
59 | embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
60 |
61 | return embeds
62 |
63 | def similarity_matrix(self, embeds):
64 | """
65 |         Computes the similarity matrix according to section 2.1 of GE2E.
66 |
67 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68 | utterances_per_speaker, embedding_size)
69 | :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70 | utterances_per_speaker, speakers_per_batch)
71 | """
72 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73 |
74 | # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75 | centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76 | centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
77 |
78 | # Exclusive centroids (1 per utterance)
79 | centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80 | centroids_excl /= (utterances_per_speaker - 1)
81 | centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
82 |
83 | # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84 | # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85 | # We vectorize the computation for efficiency.
86 | sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87 | speakers_per_batch).to(self.loss_device)
88 | mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
89 | for j in range(speakers_per_batch):
90 | mask = np.where(mask_matrix[j])[0]
91 | sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92 | sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93 |
94 | ## Even more vectorized version (slower maybe because of transpose)
95 | # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96 | # ).to(self.loss_device)
97 | # eye = np.eye(speakers_per_batch, dtype=np.int)
98 | # mask = np.where(1 - eye)
99 | # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100 | # mask = np.where(eye)
101 | # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102 | # sim_matrix2 = sim_matrix2.transpose(1, 2)
103 |
104 | sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105 | return sim_matrix
106 |
107 | def loss(self, embeds):
108 | """
109 |         Computes the softmax loss according to section 2.1 of GE2E.
110 |
111 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112 | utterances_per_speaker, embedding_size)
113 | :return: the loss and the EER for this batch of embeddings.
114 | """
115 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116 |
117 | # Loss
118 | sim_matrix = self.similarity_matrix(embeds)
119 | sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120 | speakers_per_batch))
121 | ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122 | target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123 | loss = self.loss_fn(sim_matrix, target)
124 |
125 | # EER (not backpropagated)
126 | with torch.no_grad():
127 | inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
128 | labels = np.array([inv_argmax(i) for i in ground_truth])
129 | preds = sim_matrix.detach().cpu().numpy()
130 |
131 | # Snippet from https://yangcha.github.io/EER-ROC/
132 | fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133 | eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134 |
135 | return loss, eer
136 |
--------------------------------------------------------------------------------
/src/synthesizer/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from synthesizer import audio
3 | from synthesizer.hparams import hparams
4 | from synthesizer.models.tacotron import Tacotron
5 | from synthesizer.utils.symbols import symbols
6 | from synthesizer.utils.text import text_to_sequence
7 | from vocoder.display import simple_table
8 | from pathlib import Path
9 | from typing import Union, List
10 | import numpy as np
11 | import librosa
12 |
13 |
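# Typical usage (see src/main.py): build a Synthesizer with the path to a pretrained
# checkpoint, then call synthesize_spectrograms() with a list of texts and the matching
# speaker embeddings; the Tacotron model is lazily loaded on the first call.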
14 | class Synthesizer:
15 | sample_rate = hparams.sample_rate
16 | hparams = hparams
17 |
18 | def __init__(self, model_fpath: Path, verbose=True):
19 | """
20 | The model isn't instantiated and loaded in memory until needed or until load() is called.
21 |
22 | :param model_fpath: path to the trained model file
23 | :param verbose: if False, prints less information when using the model
24 | """
25 | self.model_fpath = model_fpath
26 | self.verbose = verbose
27 |
28 | # Check for GPU
29 | if torch.cuda.is_available():
30 | self.device = torch.device("cuda")
31 | else:
32 | self.device = torch.device("cpu")
33 | if self.verbose:
34 | print("Synthesizer using device:", self.device)
35 |
36 | # Tacotron model will be instantiated later on first use.
37 | self._model = None
38 |
39 | def is_loaded(self):
40 | """
41 | Whether the model is loaded in memory.
42 | """
43 | return self._model is not None
44 |
45 | def load(self):
46 | """
47 | Instantiates and loads the model given the weights file that was passed in the constructor.
48 | """
49 | self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
50 | num_chars=len(symbols),
51 | encoder_dims=hparams.tts_encoder_dims,
52 | decoder_dims=hparams.tts_decoder_dims,
53 | n_mels=hparams.num_mels,
54 | fft_bins=hparams.num_mels,
55 | postnet_dims=hparams.tts_postnet_dims,
56 | encoder_K=hparams.tts_encoder_K,
57 | lstm_dims=hparams.tts_lstm_dims,
58 | postnet_K=hparams.tts_postnet_K,
59 | num_highways=hparams.tts_num_highways,
60 | dropout=hparams.tts_dropout,
61 | stop_threshold=hparams.tts_stop_threshold,
62 | speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
63 |
64 | self._model.load(self.model_fpath)
65 | self._model.eval()
66 |
67 | #if self.verbose:
68 | # print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
69 |
70 | def synthesize_spectrograms(self, texts: List[str],
71 | embeddings: Union[np.ndarray, List[np.ndarray]],
72 | return_alignments=False):
73 | """
74 | Synthesizes mel spectrograms from texts and speaker embeddings.
75 |
76 | :param texts: a list of N text prompts to be synthesized
77 | :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
78 | :param return_alignments: if True, a matrix representing the alignments between the
79 | characters
80 | and each decoder output step will be returned for each spectrogram
81 | :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
82 | sequence length of spectrogram i, and possibly the alignments.
83 | """
84 | # Load the model on the first request.
85 | if not self.is_loaded():
86 | self.load()
87 |
88 | # Print some info about the model when it is loaded
89 | #tts_k = self._model.get_step() // 1000
90 |
91 | #simple_table([("Tacotron", str(tts_k) + "k"), ("r", self._model.r)])
92 |
93 | # Preprocess text inputs
94 | inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
95 | if not isinstance(embeddings, list):
96 | embeddings = [embeddings]
97 |
98 | # Batch inputs
99 | batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
100 | for i in range(0, len(inputs), hparams.synthesis_batch_size)]
101 | batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
102 | for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
103 |
104 | specs = []
105 | for i, batch in enumerate(batched_inputs, 1):
106 | if self.verbose:
107 | print(f"\n| Generating {i}/{len(batched_inputs)}")
108 |
109 | # Pad texts so they are all the same length
110 | text_lens = [len(text) for text in batch]
111 | max_text_len = max(text_lens)
112 | chars = [pad1d(text, max_text_len) for text in batch]
113 | chars = np.stack(chars)
114 |
115 | # Stack speaker embeddings into 2D array for batch processing
116 | speaker_embeds = np.stack(batched_embeds[i-1])
117 |
118 | # Convert to tensor
119 | chars = torch.tensor(chars).long().to(self.device)
120 | speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
121 |
122 | # Inference
123 | _, mels, alignments = self._model.generate(chars, speaker_embeddings)
124 | mels = mels.detach().cpu().numpy()
125 | for m in mels:
126 | # Trim silence from end of each spectrogram
127 | while np.max(m[:, -1]) < hparams.tts_stop_threshold:
128 | m = m[:, :-1]
129 | specs.append(m)
130 |
131 | #if self.verbose:
132 | # print("\n\nDone.\n")
133 | return (specs, alignments) if return_alignments else specs
134 |
135 | @staticmethod
136 | def load_preprocess_wav(fpath):
137 | """
138 | Loads and preprocesses an audio file under the same conditions the audio files were used to
139 | train the synthesizer.
140 | """
141 | wav = librosa.load(str(fpath), hparams.sample_rate)[0]
142 | if hparams.rescale:
143 | wav = wav / np.abs(wav).max() * hparams.rescaling_max
144 | return wav
145 |
146 | @staticmethod
147 | def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
148 | """
149 | Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
150 | were fed to the synthesizer when training.
151 | """
152 | if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
153 | wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
154 | else:
155 | wav = fpath_or_wav
156 |
157 | mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
158 | return mel_spectrogram
159 |
160 | @staticmethod
161 | def griffin_lim(mel):
162 | """
163 | Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
164 | with the same parameters present in hparams.py.
165 | """
166 | return audio.inv_mel_spectrogram(mel, hparams)
167 |
168 |
169 | def pad1d(x, max_len, pad_value=0):
170 | return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
171 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | ### IMPORTS ###
2 |
3 | print('[Importing libraries...]')
4 |
5 | import soundfile as sf
6 | import torch
7 | import librosa
8 | import numpy as np
9 | import sys
10 | import os
11 | from os.path import exists, join, basename, splitext
12 | from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC, logging
13 | from datasets import load_dataset
14 |
15 | from synthesizer.inference import Synthesizer
16 | from encoder import inference as encoder
17 | from encoder.audio import preprocess_wav
18 | from vocoder import inference as vocoder
19 | from pathlib import Path
20 | import argparse
21 | from utils.argutils import print_args
22 |
23 | import jiwer
24 | import speechmetrics
25 | from asrtoolkit import cer
26 | import nltk
27 |
28 | ### STD OUT SUPPRESSION UTILITY ###
29 |
30 | class suppress_output:
31 | def __init__(self, suppress_stdout=True, suppress_stderr=True):
32 | self.suppress_stdout = suppress_stdout
33 | self.suppress_stderr = suppress_stderr
34 | self._stdout = None
35 | self._stderr = None
36 |
37 | def __enter__(self):
38 | devnull = open(os.devnull, "w")
39 | if self.suppress_stdout:
40 | self._stdout = sys.stdout
41 | sys.stdout = devnull
42 |
43 | if self.suppress_stderr:
44 | self._stderr = sys.stderr
45 | sys.stderr = devnull
46 |
47 | def __exit__(self, *args):
48 | if self.suppress_stdout:
49 | sys.stdout = self._stdout
50 | if self.suppress_stderr:
51 | sys.stderr = self._stderr
52 |
53 |
54 | ### MODELS DOWNLOAD ###
55 | print('[Loading models...]')
56 |
57 | dir = os.getcwd()
58 | if os.path.basename(os.path.normpath(dir)) != "src":
59 | dir += "/src"
60 |
61 |
62 | logging.set_verbosity_error()
63 |
64 | with suppress_output(suppress_stdout=True, suppress_stderr=True):
65 | tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
66 | model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
67 |
68 | encoder.load_model(Path(dir + "/encoder/saved_models/pretrained.pt"))
69 | synthesizer = Synthesizer(Path(dir + "/synthesizer/saved_models/pretrained/pretrained.pt"))
70 | vocoder.load_model(Path(dir + "/vocoder/saved_models/pretrained/pretrained.pt"))
71 |
72 | ### FUNCTIONS and GLOBAL VARIABLES ###
73 |
74 | SAMPLE_RATE = 16000
75 |
76 | # def listToString(s):
77 | # str1 = ""
78 | # for ele in s:
79 | # str1 += ele
80 | # return str1
81 |
82 | def synthesize(embed, text):
83 | print('[Synthesizing new audio...]')
84 | print('Text: ' + text + '\n')
85 | #with io.capture_output() as captured:
86 | #with suppress_output(suppress_stdout=True, suppress_stderr=True):
87 | specs = synthesizer.synthesize_spectrograms([text], [embed])
88 | generated_wav = vocoder.infer_waveform(specs[0])
89 | generated_wav = np.pad(generated_wav, (0, synthesizer.sample_rate), mode="constant")
90 | # print('\n[...synthesis done]')
91 | return generated_wav
92 |
93 | ### MAIN ###
94 |
95 | if __name__ == '__main__':
96 |
97 | # args parsing
98 | parser = argparse.ArgumentParser(
99 | formatter_class=argparse.ArgumentDefaultsHelpFormatter
100 | )
101 | parser.add_argument("--source", type=Path, default=None, help=\
102 | "Source speaker for voice conversion.")
103 | parser.add_argument("--target", type=Path, default=None, help=\
104 | "Target speaker for voice conversion.")
105 | parser.add_argument("--string", type=str, default=None, help=\
106 | "Sentence used by the synthesized voice.")
107 | parser.add_argument("--seed", type=int, default=None, help=\
108 | "Optional random number seed value to make toolbox deterministic.")
109 | parser.add_argument("-m", "--metrics", action="store_true", help=\
110 | "Print some metrics.")
111 | parser.add_argument("-e", "--enhance", action="store_true", help=\
112 | "Trims output audio silences.")
113 |
114 | args = parser.parse_args()
115 | # print_args(args, parser)
116 |
117 | if args.seed is not None:
118 | torch.manual_seed(args.seed)
119 | #synthesizer = Synthesizer(args.syn_model_fpath, verbose=False)
120 |
121 | if args.source is not None:
122 | source_audio, _ = librosa.load(args.source, sr=SAMPLE_RATE)
123 | else:
124 | source_audio, _ = librosa.load(dir + "/audio/source.wav", sr=SAMPLE_RATE)
125 |
126 | if args.target is not None:
127 | target_audio, _ = librosa.load(args.target, sr=SAMPLE_RATE)
128 | else:
129 | target_audio, _ = librosa.load(dir + "/audio/target.wav", sr=SAMPLE_RATE)
130 |
131 | if args.string is not None:
132 | if args.source is not None:
133 | raise Exception("[ERROR] Can't specify both source and string args.")
134 | transcription = args.string
135 | # text = listToString(transcription)
136 | else:
137 | input_values = tokenizer(np.asarray(source_audio), return_tensors="pt").input_values
138 | logits = model(input_values).logits
139 | predicted_ids = torch.argmax(logits, dim=-1)
140 | transcription = tokenizer.batch_decode(predicted_ids)[0]
141 | # text = listToString(transcription)
142 |
143 | embedding = encoder.embed_utterance(encoder.preprocess_wav(target_audio, SAMPLE_RATE))
144 |
145 | out_audio = synthesize(embedding, transcription)
146 | if args.enhance:
147 | out_audio = preprocess_wav(out_audio)
148 |
149 | sf.write(dir + "/audio/audio_out.wav", out_audio, 16000)
150 |
151 | if args.metrics:
152 | input_values = tokenizer(np.asarray(out_audio), return_tensors="pt").input_values
153 | logits = model(input_values).logits
154 | predicted_ids = torch.argmax(logits, dim=-1)
155 | transcription_out = tokenizer.batch_decode(predicted_ids)[0]
156 | # text_out = listToString(transcription_out)
157 |
158 | ground_truth = transcription
159 | hypothesis = transcription_out
160 |
161 | wer_before_lemma = jiwer.wer(ground_truth, hypothesis) #word error rate
162 |
163 | print('\n')
164 |
165 | nltk.download('wordnet')
166 | wnl = nltk.stem.WordNetLemmatizer()
167 | stm = nltk.stem.snowball.EnglishStemmer()
168 |
169 | transcription_lemma = ""
170 | transcription_out_lemma = ""
171 |
172 | for s in transcription.split(" "):
173 | transcription_lemma += (stm.stem(wnl.lemmatize(s)) + " ").upper()
174 |
175 | for s in transcription_out.split(" "):
176 | transcription_out_lemma += (stm.stem(wnl.lemmatize(s)) + " ").upper()
177 |
178 | wer_after_lemma = jiwer.wer(transcription_lemma, transcription_out_lemma)
179 |         cer_score = cer(ground_truth, hypothesis) #character error rate
180 | mer = jiwer.mer(ground_truth, hypothesis) #match error rate
181 | wil = jiwer.wil(ground_truth, hypothesis) #word information lost
182 |
183 | #mosnet
184 |
185 | window_length = None
186 | with suppress_output(suppress_stdout=True, suppress_stderr=True):
187 | metrics = speechmetrics.load('absolute.mosnet',window_length)
188 | results = metrics(dir + "/audio/audio_out.wav")
189 |
190 | print('\n\n[+++METRICS+++]\n')
191 |
192 | print('Detected text: ' + transcription_out)
193 |
194 | print('Original text lemmatized: ' + transcription_lemma)
195 | print('Synthesized text lemmatized: ' + transcription_out_lemma)
196 |
197 | print('\nWord Error Rate before lemmatization is: ', wer_before_lemma)
198 | print('Word Error Rate after lemmatization is: ', wer_after_lemma)
199 |         print('Character Error Rate is: ', cer_score/100)
200 | #print('Match Error Rate is: ', mer)
201 | #print('Word Information Lost is: ', wil)
202 | print('MOSNet is: ', results['mosnet'][0][0])
203 | print('\n[Done]\n')
204 | else:
205 | print('\n[Done]\n')
206 |
--------------------------------------------------------------------------------
/src/encoder/inference.py:
--------------------------------------------------------------------------------
1 | from encoder.params_data import *
2 | from encoder.model import SpeakerEncoder
3 | from encoder.audio import preprocess_wav # We want to expose this function from here
4 | from matplotlib import cm
5 | from encoder import audio
6 | from pathlib import Path
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | import torch
10 |
11 | _model = None # type: SpeakerEncoder
12 | _device = None # type: torch.device
13 |
14 |
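# Typical usage (see src/main.py): call load_model() with the pretrained encoder
# checkpoint, then embed_utterance() on a waveform preprocessed with preprocess_wav()
# to obtain an L2-normalized float32 speaker embedding of shape (model_embedding_size,).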
15 | def load_model(weights_fpath: Path, device=None):
16 | """
17 |     Loads the model in memory. This function must be called explicitly before the first call to
18 |     embed_frames_batch(), otherwise an exception is raised.
19 |
20 | :param weights_fpath: the path to saved model weights.
21 | :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
22 | model will be loaded and will run on this device. Outputs will however always be on the cpu.
23 |     If None, will default to your GPU if it's available, otherwise your CPU.
24 | """
25 | # TODO: I think the slow loading of the encoder might have something to do with the device it
26 | # was saved on. Worth investigating.
27 | global _model, _device
28 | if device is None:
29 | _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30 | elif isinstance(device, str):
31 | _device = torch.device(device)
32 | _model = SpeakerEncoder(_device, torch.device("cpu"))
33 | checkpoint = torch.load(weights_fpath, _device)
34 | _model.load_state_dict(checkpoint["model_state"])
35 | _model.eval()
36 | print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
37 |
38 |
39 | def is_loaded():
40 | return _model is not None
41 |
42 |
43 | def embed_frames_batch(frames_batch):
44 | """
45 |     Computes embeddings for a batch of mel spectrograms.
46 |
47 |     :param frames_batch: a batch of mel spectrograms as a numpy array of float32 of shape
48 | (batch_size, n_frames, n_channels)
49 | :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
50 | """
51 | if _model is None:
52 | raise Exception("Model was not loaded. Call load_model() before inference.")
53 |
54 | frames = torch.from_numpy(frames_batch).to(_device)
55 | embed = _model.forward(frames).detach().cpu().numpy()
56 | return embed
57 |
58 |
59 | def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
60 | min_pad_coverage=0.75, overlap=0.5):
61 | """
62 | Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
63 | partial utterances of each. Both the waveform and the mel
64 | spectrogram slices are returned, so as to make each partial utterance waveform correspond to
65 | its spectrogram. This function assumes that the mel spectrogram parameters used are those
66 | defined in params_data.py.
67 |
68 | The returned ranges may be indexing further than the length of the waveform. It is
69 | recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
70 |
71 | :param n_samples: the number of samples in the waveform
72 | :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
73 | utterance
74 | :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
75 |     enough frames. If at least min_pad_coverage of partial_utterance_n_frames are present,
76 | then the last partial utterance will be considered, as if we padded the audio. Otherwise,
77 | it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
78 | utterance, this parameter is ignored so that the function always returns at least 1 slice.
79 | :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
80 | utterances are entirely disjoint.
81 | :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
82 | respectively the waveform and the mel spectrogram with these slices to obtain the partial
83 | utterances.
84 | """
85 | assert 0 <= overlap < 1
86 | assert 0 < min_pad_coverage <= 1
87 |
88 | samples_per_frame = int((sampling_rate * mel_window_step / 1000))
89 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
90 | frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
91 |
92 | # Compute the slices
93 | wav_slices, mel_slices = [], []
94 | steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
95 | for i in range(0, steps, frame_step):
96 | mel_range = np.array([i, i + partial_utterance_n_frames])
97 | wav_range = mel_range * samples_per_frame
98 | mel_slices.append(slice(*mel_range))
99 | wav_slices.append(slice(*wav_range))
100 |
101 | # Evaluate whether extra padding is warranted or not
102 | last_wav_range = wav_slices[-1]
103 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
104 | if coverage < min_pad_coverage and len(mel_slices) > 1:
105 | mel_slices = mel_slices[:-1]
106 | wav_slices = wav_slices[:-1]
107 |
108 | return wav_slices, mel_slices
109 |
110 |
111 | def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
112 | """
113 | Computes an embedding for a single utterance.
114 |
115 | # TODO: handle multiple wavs to benefit from batching on GPU
116 | :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
117 | :param using_partials: if True, then the utterance is split in partial utterances of
118 |     partial_utterance_n_frames frames and the utterance embedding is computed from their
119 | normalized average. If False, the utterance is instead computed from feeding the entire
120 |     spectrogram to the network.
121 | :param return_partials: if True, the partial embeddings will also be returned along with the
122 | wav slices that correspond to the partial embeddings.
123 |     :param kwargs: additional arguments to compute_partial_slices()
124 | :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
125 |     return_partials is True, the partial utterances as a numpy array of float32 of shape
126 | (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
127 |     returned. If using_partials is simultaneously set to False, both these values will be None
128 | instead.
129 | """
130 | # Process the entire utterance if not using partials
131 | if not using_partials:
132 | frames = audio.wav_to_mel_spectrogram(wav)
133 | embed = embed_frames_batch(frames[None, ...])[0]
134 | if return_partials:
135 | return embed, None, None
136 | return embed
137 |
138 | # Compute where to split the utterance into partials and pad if necessary
139 | wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
140 | max_wave_length = wave_slices[-1].stop
141 | if max_wave_length >= len(wav):
142 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
143 |
144 | # Split the utterance into partials
145 | frames = audio.wav_to_mel_spectrogram(wav)
146 | frames_batch = np.array([frames[s] for s in mel_slices])
147 | partial_embeds = embed_frames_batch(frames_batch)
148 |
149 | # Compute the utterance embedding from the partial embeddings
150 | raw_embed = np.mean(partial_embeds, axis=0)
151 | embed = raw_embed / np.linalg.norm(raw_embed, 2)
152 |
153 | if return_partials:
154 | return embed, partial_embeds, wave_slices
155 | return embed
156 |
157 |
158 | def embed_speaker(wavs, **kwargs):
159 | raise NotImplementedError()
160 |
161 |
162 | def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
163 | if ax is None:
164 | ax = plt.gca()
165 |
166 | if shape is None:
167 | height = int(np.sqrt(len(embed)))
168 | shape = (height, -1)
169 | embed = embed.reshape(shape)
170 |
171 | cmap = cm.get_cmap()
172 | mappable = ax.imshow(embed, cmap=cmap)
173 | cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174 | sm = cm.ScalarMappable(cmap=cmap)
175 | sm.set_clim(*color_range)
176 |
177 | ax.set_xticks([]), ax.set_yticks([])
178 | ax.set_title(title)
179 |
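
The slicing contract documented in `compute_partial_slices` (pad the waveform with zeros up to `wave_slices[-1].stop` before indexing) is easy to miss. Below is a minimal sketch, not part of the repository, written as if it ran inside this module so that `compute_partial_slices` and the `params_data` constants are in scope; the silent 1.5 s test waveform is an illustrative stand-in for real audio.

````python
import numpy as np

# Illustrative only: a 1.5 s silent waveform; sampling_rate comes from
# params_data via this module's imports, as used by the functions above.
wav = np.zeros(int(1.5 * sampling_rate), dtype=np.float32)

# Compute the slices with the default keyword arguments.
wave_slices, mel_slices = compute_partial_slices(len(wav))

# The last slice may reach past the end of the waveform, so pad with zeros first,
# exactly as embed_utterance does.
max_wave_length = wave_slices[-1].stop
if max_wave_length >= len(wav):
    wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

partial_wavs = [wav[s] for s in wave_slices]  # one waveform chunk per partial utterance
````

This mirrors the padding step performed inside `embed_utterance`, which then feeds the corresponding mel slices to `embed_frames_batch` and averages the partial embeddings.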
--------------------------------------------------------------------------------
/src/synthesizer/audio.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import librosa.filters
3 | import numpy as np
4 | from scipy import signal
5 | from scipy.io import wavfile
6 | import soundfile as sf
7 |
8 |
9 | def load_wav(path, sr):
10 | return librosa.core.load(path, sr=sr)[0]
11 |
12 | def save_wav(wav, path, sr):
13 | wav *= 32767 / max(0.01, np.max(np.abs(wav)))
14 | #proposed by @dsmiller
15 | wavfile.write(path, sr, wav.astype(np.int16))
16 |
17 | def save_wavenet_wav(wav, path, sr):
18 | sf.write(path, wav.astype(np.float32), sr)
19 |
20 | def preemphasis(wav, k, preemphasize=True):
21 | if preemphasize:
22 | return signal.lfilter([1, -k], [1], wav)
23 | return wav
24 |
25 | def inv_preemphasis(wav, k, inv_preemphasize=True):
26 | if inv_preemphasize:
27 | return signal.lfilter([1], [1, -k], wav)
28 | return wav
29 |
30 | #From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
31 | def start_and_end_indices(quantized, silence_threshold=2):
32 | for start in range(quantized.size):
33 | if abs(quantized[start] - 127) > silence_threshold:
34 | break
35 | for end in range(quantized.size - 1, 1, -1):
36 | if abs(quantized[end] - 127) > silence_threshold:
37 | break
38 |
39 | assert abs(quantized[start] - 127) > silence_threshold
40 | assert abs(quantized[end] - 127) > silence_threshold
41 |
42 | return start, end
43 |
44 | def get_hop_size(hparams):
45 | hop_size = hparams.hop_size
46 | if hop_size is None:
47 | assert hparams.frame_shift_ms is not None
48 | hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
49 | return hop_size
50 |
51 | def linearspectrogram(wav, hparams):
52 | D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
53 | S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
54 |
55 | if hparams.signal_normalization:
56 | return _normalize(S, hparams)
57 | return S
58 |
59 | def melspectrogram(wav, hparams):
60 | D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
61 | S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
62 |
63 | if hparams.signal_normalization:
64 | return _normalize(S, hparams)
65 | return S
66 |
67 | def inv_linear_spectrogram(linear_spectrogram, hparams):
68 | """Converts linear spectrogram to waveform using librosa"""
69 | if hparams.signal_normalization:
70 | D = _denormalize(linear_spectrogram, hparams)
71 | else:
72 | D = linear_spectrogram
73 |
74 | S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
75 |
76 | if hparams.use_lws:
77 | processor = _lws_processor(hparams)
78 | D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
79 | y = processor.istft(D).astype(np.float32)
80 | return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
81 | else:
82 | return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
83 |
84 | def inv_mel_spectrogram(mel_spectrogram, hparams):
85 | """Converts mel spectrogram to waveform using librosa"""
86 | if hparams.signal_normalization:
87 | D = _denormalize(mel_spectrogram, hparams)
88 | else:
89 | D = mel_spectrogram
90 |
91 | S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear
92 |
93 | if hparams.use_lws:
94 | processor = _lws_processor(hparams)
95 | D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
96 | y = processor.istft(D).astype(np.float32)
97 | return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
98 | else:
99 | return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
100 |
101 | def _lws_processor(hparams):
102 | import lws
103 | return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
104 |
105 | def _griffin_lim(S, hparams):
106 | """librosa implementation of Griffin-Lim
107 | Based on https://github.com/librosa/librosa/issues/434
108 | """
109 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
110 | S_complex = np.abs(S).astype(complex)
111 | y = _istft(S_complex * angles, hparams)
112 | for i in range(hparams.griffin_lim_iters):
113 | angles = np.exp(1j * np.angle(_stft(y, hparams)))
114 | y = _istft(S_complex * angles, hparams)
115 | return y
116 |
117 | def _stft(y, hparams):
118 | if hparams.use_lws:
119 | return _lws_processor(hparams).stft(y).T
120 | else:
121 | return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
122 |
123 | def _istft(y, hparams):
124 | return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
125 |
126 | ##########################################################
127 | #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
128 | def num_frames(length, fsize, fshift):
129 | """Compute number of time frames of spectrogram
130 | """
131 | pad = (fsize - fshift)
132 | if length % fshift == 0:
133 | M = (length + pad * 2 - fsize) // fshift + 1
134 | else:
135 | M = (length + pad * 2 - fsize) // fshift + 2
136 | return M
137 |
138 |
139 | def pad_lr(x, fsize, fshift):
140 | """Compute left and right padding
141 | """
142 | M = num_frames(len(x), fsize, fshift)
143 | pad = (fsize - fshift)
144 | T = len(x) + 2 * pad
145 | r = (M - 1) * fshift + fsize - T
146 | return pad, pad + r
147 | ##########################################################
148 | #Librosa correct padding
149 | def librosa_pad_lr(x, fsize, fshift):
150 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
151 |
152 | # Conversions
153 | _mel_basis = None
154 | _inv_mel_basis = None
155 |
156 | def _linear_to_mel(spectogram, hparams):
157 | global _mel_basis
158 | if _mel_basis is None:
159 | _mel_basis = _build_mel_basis(hparams)
160 | return np.dot(_mel_basis, spectogram)
161 |
162 | def _mel_to_linear(mel_spectrogram, hparams):
163 | global _inv_mel_basis
164 | if _inv_mel_basis is None:
165 | _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
166 | return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
167 |
168 | def _build_mel_basis(hparams):
169 | assert hparams.fmax <= hparams.sample_rate // 2
170 | return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
171 | fmin=hparams.fmin, fmax=hparams.fmax)
172 |
173 | def _amp_to_db(x, hparams):
174 | min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
175 | return 20 * np.log10(np.maximum(min_level, x))
176 |
177 | def _db_to_amp(x):
178 | return np.power(10.0, (x) * 0.05)
179 |
180 | def _normalize(S, hparams):
181 | if hparams.allow_clipping_in_normalization:
182 | if hparams.symmetric_mels:
183 | return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
184 | -hparams.max_abs_value, hparams.max_abs_value)
185 | else:
186 | return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
187 |
188 | assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
189 | if hparams.symmetric_mels:
190 | return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
191 | else:
192 | return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
193 |
194 | def _denormalize(D, hparams):
195 | if hparams.allow_clipping_in_normalization:
196 | if hparams.symmetric_mels:
197 | return (((np.clip(D, -hparams.max_abs_value,
198 | hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
199 | + hparams.min_level_db)
200 | else:
201 | return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
202 |
203 | if hparams.symmetric_mels:
204 | return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
205 | else:
206 | return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
207 |
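
The helpers above compose into a simple analysis/resynthesis round trip. The sketch below is not part of the repository: the `SimpleNamespace` stands in for the synthesizer's hyperparameter object, its field values are illustrative placeholders rather than the project's actual settings, the file names are hypothetical, and it assumes `./src` is on the import path.

````python
from types import SimpleNamespace
from synthesizer import audio  # assumes ./src is on the import path

# Placeholder hyperparameters; the real values live in the synthesizer's hparams.
hp = SimpleNamespace(
    sample_rate=16000, n_fft=800, hop_size=200, win_size=800,
    num_mels=80, fmin=55, fmax=7600,
    preemphasis=0.97, preemphasize=True,
    ref_level_db=20, min_level_db=-100,
    signal_normalization=True, allow_clipping_in_normalization=True,
    symmetric_mels=True, max_abs_value=4.0,
    use_lws=False, power=1.5, griffin_lim_iters=60,
)

wav = audio.load_wav("some_utterance.wav", sr=hp.sample_rate)  # hypothetical input file
mel = audio.melspectrogram(wav, hp)           # (num_mels, frames), normalized dB values
approx = audio.inv_mel_spectrogram(mel, hp)   # Griffin-Lim estimate of the waveform
audio.save_wav(approx, "reconstructed.wav", sr=hp.sample_rate)
````

With `signal_normalization` enabled and `symmetric_mels` set, `melspectrogram` returns values clipped to `[-max_abs_value, max_abs_value]`, which is exactly the range `inv_mel_spectrogram` denormalizes before running Griffin-Lim.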
--------------------------------------------------------------------------------
/src/utils/logmmse.py:
--------------------------------------------------------------------------------
1 | # The MIT License (MIT)
2 | #
3 | # Copyright (c) 2015 braindead
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | #
23 | #
24 | # This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I
25 | # simply modified the interface to meet my needs.
26 |
27 |
28 | import numpy as np
29 | import math
30 | from scipy.special import expn
31 | from collections import namedtuple
32 |
33 | NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2")
34 |
35 |
36 | def profile_noise(noise, sampling_rate, window_size=0):
37 | """
38 | Creates a profile of the noise in a given waveform.
39 |
40 | :param noise: a waveform containing noise ONLY, as a numpy array of floats or ints.
41 | :param sampling_rate: the sampling rate of the audio
42 | :param window_size: the size of the window the logmmse algorithm operates on. A default value
43 | will be picked if left as 0.
44 | :return: a NoiseProfile object
45 | """
46 | noise, dtype = to_float(noise)
47 | noise += np.finfo(np.float64).eps
48 |
49 | if window_size == 0:
50 | window_size = int(math.floor(0.02 * sampling_rate))
51 |
52 | if window_size % 2 == 1:
53 | window_size = window_size + 1
54 |
55 | perc = 50
56 | len1 = int(math.floor(window_size * perc / 100))
57 | len2 = int(window_size - len1)
58 |
59 | win = np.hanning(window_size)
60 | win = win * len2 / np.sum(win)
61 | n_fft = 2 * window_size
62 |
63 | noise_mean = np.zeros(n_fft)
64 | n_frames = len(noise) // window_size
65 | for j in range(0, window_size * n_frames, window_size):
66 | noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0))
67 | noise_mu2 = (noise_mean / n_frames) ** 2
68 |
69 | return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2)
70 |
71 |
72 | def denoise(wav, noise_profile: NoiseProfile, eta=0.15):
73 | """
74 | Cleans the noise from a speech waveform given a noise profile. The waveform must have the
75 | same sampling rate as the one used to create the noise profile.
76 |
77 | :param wav: a speech waveform as a numpy array of floats or ints.
78 | :param noise_profile: a NoiseProfile object that was created from a similar (or a segment of
79 | the same) waveform.
80 | :param eta: voice threshold for noise update. While the voice activation detection value is
81 | below this threshold, the noise profile will be continuously updated throughout the audio.
82 | Set to 0 to disable updating the noise profile.
83 | :return: the clean wav as a numpy array of floats or ints of the same length.
84 | """
85 | wav, dtype = to_float(wav)
86 | wav += np.finfo(np.float64).eps
87 | p = noise_profile
88 |
89 | nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2))
90 | x_final = np.zeros(nframes * p.len2)
91 |
92 | aa = 0.98
93 | mu = 0.98
94 | ksi_min = 10 ** (-25 / 10)
95 |
96 | x_old = np.zeros(p.len1)
97 | xk_prev = np.zeros(p.len1)
98 | noise_mu2 = p.noise_mu2
99 | for k in range(0, nframes * p.len2, p.len2):
100 | insign = p.win * wav[k:k + p.window_size]
101 |
102 | spec = np.fft.fft(insign, p.n_fft, axis=0)
103 | sig = np.absolute(spec)
104 | sig2 = sig ** 2
105 |
106 | gammak = np.minimum(sig2 / noise_mu2, 40)
107 |
108 | if xk_prev.all() == 0:
109 | ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
110 | else:
111 | ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
112 | ksi = np.maximum(ksi_min, ksi)
113 |
114 | log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi)
115 | vad_decision = np.sum(log_sigma_k) / p.window_size
116 | if vad_decision < eta:
117 | noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
118 |
119 | a = ksi / (1 + ksi)
120 | vk = a * gammak
121 | ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
122 | hw = a * np.exp(ei_vk)
123 | sig = sig * hw
124 | xk_prev = sig ** 2
125 | xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0)
126 | xi_w = np.real(xi_w)
127 |
128 | x_final[k:k + p.len2] = x_old + xi_w[0:p.len1]
129 | x_old = xi_w[p.len1:p.window_size]
130 |
131 | output = from_float(x_final, dtype)
132 | output = np.pad(output, (0, len(wav) - len(output)), mode="constant")
133 | return output
134 |
135 |
136 | ## Alternative VAD algorithm to webrtcvad. It has the advantage of not requiring installation of
137 | ## that darn package, and it also works for any sampling rate. Maybe I'll eventually use it instead
138 | ## of webrtcvad.
139 | # def vad(wav, sampling_rate, eta=0.15, window_size=0):
140 | # """
141 | # TODO: fix doc
142 | # Creates a profile of the noise in a given waveform.
143 | #
144 | # :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints.
145 | # :param sampling_rate: the sampling rate of the audio
146 | # :param window_size: the size of the window the logmmse algorithm operates on. A default value
147 | # will be picked if left as 0.
148 | # :param eta: voice threshold for noise update. While the voice activation detection value is
149 | # below this threshold, the noise profile will be continuously updated throughout the audio.
150 | # Set to 0 to disable updating the noise profile.
151 | # """
152 | # wav, dtype = to_float(wav)
153 | # wav += np.finfo(np.float64).eps
154 | #
155 | # if window_size == 0:
156 | # window_size = int(math.floor(0.02 * sampling_rate))
157 | #
158 | # if window_size % 2 == 1:
159 | # window_size = window_size + 1
160 | #
161 | # perc = 50
162 | # len1 = int(math.floor(window_size * perc / 100))
163 | # len2 = int(window_size - len1)
164 | #
165 | # win = np.hanning(window_size)
166 | # win = win * len2 / np.sum(win)
167 | # n_fft = 2 * window_size
168 | #
169 | # wav_mean = np.zeros(n_fft)
170 | # n_frames = len(wav) // window_size
171 | # for j in range(0, window_size * n_frames, window_size):
172 | # wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0))
173 | # noise_mu2 = (wav_mean / n_frames) ** 2
174 | #
175 | # wav, dtype = to_float(wav)
176 | # wav += np.finfo(np.float64).eps
177 | #
178 | # nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2))
179 | # vad = np.zeros(nframes * len2, dtype=np.bool)
180 | #
181 | # aa = 0.98
182 | # mu = 0.98
183 | # ksi_min = 10 ** (-25 / 10)
184 | #
185 | # xk_prev = np.zeros(len1)
186 | # noise_mu2 = noise_mu2
187 | # for k in range(0, nframes * len2, len2):
188 | # insign = win * wav[k:k + window_size]
189 | #
190 | # spec = np.fft.fft(insign, n_fft, axis=0)
191 | # sig = np.absolute(spec)
192 | # sig2 = sig ** 2
193 | #
194 | # gammak = np.minimum(sig2 / noise_mu2, 40)
195 | #
196 | # if xk_prev.all() == 0:
197 | # ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
198 | # else:
199 | # ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
200 | # ksi = np.maximum(ksi_min, ksi)
201 | #
202 | # log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi)
203 | # vad_decision = np.sum(log_sigma_k) / window_size
204 | # if vad_decision < eta:
205 | # noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
206 | # print(vad_decision)
207 | #
208 | # a = ksi / (1 + ksi)
209 | # vk = a * gammak
210 | # ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
211 | # hw = a * np.exp(ei_vk)
212 | # sig = sig * hw
213 | # xk_prev = sig ** 2
214 | #
215 | # vad[k:k + len2] = vad_decision >= eta
216 | #
217 | # vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant")
218 | # return vad
219 |
220 |
221 | def to_float(_input):
222 | if _input.dtype == np.float64:
223 | return _input, _input.dtype
224 | elif _input.dtype == np.float32:
225 | return _input.astype(np.float64), _input.dtype
226 | elif _input.dtype == np.uint8:
227 | return (_input - 128) / 128., _input.dtype
228 | elif _input.dtype == np.int16:
229 | return _input / 32768., _input.dtype
230 | elif _input.dtype == np.int32:
231 | return _input / 2147483648., _input.dtype
232 | raise ValueError('Unsupported wave file format')
233 |
234 |
235 | def from_float(_input, dtype):
236 | if dtype == np.float64:
237 | return _input
238 | elif dtype == np.float32:
239 | return _input.astype(np.float32)
240 | elif dtype == np.uint8:
241 | return ((_input * 128) + 128).astype(np.uint8)
242 | elif dtype == np.int16:
243 | return (_input * 32768).astype(np.int16)
244 | elif dtype == np.int32:
245 | # rescale from [-1, 1] back to the full int32 range
246 | return (_input * 2147483648).astype(np.int32)
247 | raise ValueError('Unsupported wave file format')
248 |
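
A minimal sketch of the intended calling pattern, not taken from the repository: the random waveform is a stand-in for a real recording, the assumption that its first half second is noise-only is purely illustrative, and the import assumes `./src` is on the path.

````python
import numpy as np
from utils import logmmse  # assumes ./src is on the import path

sr = 16000
wav = np.random.randn(3 * sr).astype(np.float32)  # stand-in for a real recording

# Build the noise profile from a noise-only segment, then clean the full signal.
profile = logmmse.profile_noise(wav[: sr // 2], sr)
clean = logmmse.denoise(wav, profile, eta=0.15)   # same length and dtype as wav
````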
--------------------------------------------------------------------------------
/src/vocoder/models/fatchord_version.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from vocoder.distribution import sample_from_discretized_mix_logistic
5 | from vocoder.display import *
6 | from vocoder.audio import *
7 |
8 |
9 | class ResBlock(nn.Module):
10 | def __init__(self, dims):
11 | super().__init__()
12 | self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
13 | self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
14 | self.batch_norm1 = nn.BatchNorm1d(dims)
15 | self.batch_norm2 = nn.BatchNorm1d(dims)
16 |
17 | def forward(self, x):
18 | residual = x
19 | x = self.conv1(x)
20 | x = self.batch_norm1(x)
21 | x = F.relu(x)
22 | x = self.conv2(x)
23 | x = self.batch_norm2(x)
24 | return x + residual
25 |
26 |
27 | class MelResNet(nn.Module):
28 | def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
29 | super().__init__()
30 | k_size = pad * 2 + 1
31 | self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False)
32 | self.batch_norm = nn.BatchNorm1d(compute_dims)
33 | self.layers = nn.ModuleList()
34 | for i in range(res_blocks):
35 | self.layers.append(ResBlock(compute_dims))
36 | self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
37 |
38 | def forward(self, x):
39 | x = self.conv_in(x)
40 | x = self.batch_norm(x)
41 | x = F.relu(x)
42 | for f in self.layers: x = f(x)
43 | x = self.conv_out(x)
44 | return x
45 |
46 |
47 | class Stretch2d(nn.Module):
48 | def __init__(self, x_scale, y_scale):
49 | super().__init__()
50 | self.x_scale = x_scale
51 | self.y_scale = y_scale
52 |
53 | def forward(self, x):
54 | b, c, h, w = x.size()
55 | x = x.unsqueeze(-1).unsqueeze(3)
56 | x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
57 | return x.view(b, c, h * self.y_scale, w * self.x_scale)
58 |
59 |
60 | class UpsampleNetwork(nn.Module):
61 | def __init__(self, feat_dims, upsample_scales, compute_dims,
62 | res_blocks, res_out_dims, pad):
63 | super().__init__()
64 | total_scale = np.cumproduct(upsample_scales)[-1]
65 | self.indent = pad * total_scale
66 | self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
67 | self.resnet_stretch = Stretch2d(total_scale, 1)
68 | self.up_layers = nn.ModuleList()
69 | for scale in upsample_scales:
70 | k_size = (1, scale * 2 + 1)
71 | padding = (0, scale)
72 | stretch = Stretch2d(scale, 1)
73 | conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
74 | conv.weight.data.fill_(1. / k_size[1])
75 | self.up_layers.append(stretch)
76 | self.up_layers.append(conv)
77 |
78 | def forward(self, m):
79 | aux = self.resnet(m).unsqueeze(1)
80 | aux = self.resnet_stretch(aux)
81 | aux = aux.squeeze(1)
82 | m = m.unsqueeze(1)
83 | for f in self.up_layers: m = f(m)
84 | m = m.squeeze(1)[:, :, self.indent:-self.indent]
85 | return m.transpose(1, 2), aux.transpose(1, 2)
86 |
87 |
88 | class WaveRNN(nn.Module):
89 | def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors,
90 | feat_dims, compute_dims, res_out_dims, res_blocks,
91 | hop_length, sample_rate, mode='RAW'):
92 | super().__init__()
93 | self.mode = mode
94 | self.pad = pad
95 | if self.mode == 'RAW' :
96 | self.n_classes = 2 ** bits
97 | elif self.mode == 'MOL' :
98 | self.n_classes = 30
99 |         else :
100 |             raise RuntimeError("Unknown model mode value - ", self.mode)
101 |
102 | self.rnn_dims = rnn_dims
103 | self.aux_dims = res_out_dims // 4
104 | self.hop_length = hop_length
105 | self.sample_rate = sample_rate
106 |
107 | self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad)
108 | self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims)
109 | self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True)
110 | self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True)
111 | self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims)
112 | self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims)
113 | self.fc3 = nn.Linear(fc_dims, self.n_classes)
114 |
115 | self.step = nn.Parameter(torch.zeros(1).long(), requires_grad=False)
116 | self.num_params()
117 |
118 | def forward(self, x, mels):
119 | self.step += 1
120 | bsize = x.size(0)
121 | if torch.cuda.is_available():
122 | h1 = torch.zeros(1, bsize, self.rnn_dims).cuda()
123 | h2 = torch.zeros(1, bsize, self.rnn_dims).cuda()
124 | else:
125 | h1 = torch.zeros(1, bsize, self.rnn_dims).cpu()
126 | h2 = torch.zeros(1, bsize, self.rnn_dims).cpu()
127 | mels, aux = self.upsample(mels)
128 |
129 | aux_idx = [self.aux_dims * i for i in range(5)]
130 | a1 = aux[:, :, aux_idx[0]:aux_idx[1]]
131 | a2 = aux[:, :, aux_idx[1]:aux_idx[2]]
132 | a3 = aux[:, :, aux_idx[2]:aux_idx[3]]
133 | a4 = aux[:, :, aux_idx[3]:aux_idx[4]]
134 |
135 | x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
136 | x = self.I(x)
137 | res = x
138 | x, _ = self.rnn1(x, h1)
139 |
140 | x = x + res
141 | res = x
142 | x = torch.cat([x, a2], dim=2)
143 | x, _ = self.rnn2(x, h2)
144 |
145 | x = x + res
146 | x = torch.cat([x, a3], dim=2)
147 | x = F.relu(self.fc1(x))
148 |
149 | x = torch.cat([x, a4], dim=2)
150 | x = F.relu(self.fc2(x))
151 | return self.fc3(x)
152 |
153 | def generate(self, mels, batched, target, overlap, mu_law, progress_callback=None):
154 | mu_law = mu_law if self.mode == 'RAW' else False
155 | progress_callback = progress_callback or self.gen_display
156 |
157 | self.eval()
158 | output = []
159 | start = time.time()
160 | rnn1 = self.get_gru_cell(self.rnn1)
161 | rnn2 = self.get_gru_cell(self.rnn2)
162 |
163 | with torch.no_grad():
164 | if torch.cuda.is_available():
165 | mels = mels.cuda()
166 | else:
167 | mels = mels.cpu()
168 | wave_len = (mels.size(-1) - 1) * self.hop_length
169 | mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side='both')
170 | mels, aux = self.upsample(mels.transpose(1, 2))
171 |
172 | if batched:
173 | mels = self.fold_with_overlap(mels, target, overlap)
174 | aux = self.fold_with_overlap(aux, target, overlap)
175 |
176 | b_size, seq_len, _ = mels.size()
177 |
178 | if torch.cuda.is_available():
179 | h1 = torch.zeros(b_size, self.rnn_dims).cuda()
180 | h2 = torch.zeros(b_size, self.rnn_dims).cuda()
181 | x = torch.zeros(b_size, 1).cuda()
182 | else:
183 | h1 = torch.zeros(b_size, self.rnn_dims).cpu()
184 | h2 = torch.zeros(b_size, self.rnn_dims).cpu()
185 | x = torch.zeros(b_size, 1).cpu()
186 |
187 | d = self.aux_dims
188 | aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)]
189 |
190 | for i in range(seq_len):
191 |
192 | m_t = mels[:, i, :]
193 |
194 | a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
195 |
196 | x = torch.cat([x, m_t, a1_t], dim=1)
197 | x = self.I(x)
198 | h1 = rnn1(x, h1)
199 |
200 | x = x + h1
201 | inp = torch.cat([x, a2_t], dim=1)
202 | h2 = rnn2(inp, h2)
203 |
204 | x = x + h2
205 | x = torch.cat([x, a3_t], dim=1)
206 | x = F.relu(self.fc1(x))
207 |
208 | x = torch.cat([x, a4_t], dim=1)
209 | x = F.relu(self.fc2(x))
210 |
211 | logits = self.fc3(x)
212 |
213 | if self.mode == 'MOL':
214 | sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2))
215 | output.append(sample.view(-1))
216 | if torch.cuda.is_available():
217 | # x = torch.FloatTensor([[sample]]).cuda()
218 | x = sample.transpose(0, 1).cuda()
219 | else:
220 | x = sample.transpose(0, 1)
221 |
222 | elif self.mode == 'RAW' :
223 | posterior = F.softmax(logits, dim=1)
224 | distrib = torch.distributions.Categorical(posterior)
225 |
226 | sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1.
227 | output.append(sample)
228 | x = sample.unsqueeze(-1)
229 | else:
230 | raise RuntimeError("Unknown model mode value - ", self.mode)
231 |
232 | if i % 100 == 0:
233 | gen_rate = (i + 1) / (time.time() - start) * b_size / 1000
234 | progress_callback(i, seq_len, b_size, gen_rate)
235 |
236 | output = torch.stack(output).transpose(0, 1)
237 | output = output.cpu().numpy()
238 | output = output.astype(np.float64)
239 |
240 | if batched:
241 | output = self.xfade_and_unfold(output, target, overlap)
242 | else:
243 | output = output[0]
244 |
245 | if mu_law:
246 | output = decode_mu_law(output, self.n_classes, False)
247 | if hp.apply_preemphasis:
248 | output = de_emphasis(output)
249 |
250 | # Fade-out at the end to avoid signal cutting out suddenly
251 | fade_out = np.linspace(1, 0, 20 * self.hop_length)
252 | output = output[:wave_len]
253 | output[-20 * self.hop_length:] *= fade_out
254 |
255 | self.train()
256 |
257 | return output
258 |
259 |
260 | def gen_display(self, i, seq_len, b_size, gen_rate):
261 | pbar = progbar(i, seq_len)
262 | msg = f'| {pbar} {i*b_size}/{seq_len*b_size} | Batch Size: {b_size} | Gen Rate: {gen_rate:.1f}kHz | '
263 | stream(msg)
264 |
265 | def get_gru_cell(self, gru):
266 | gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size)
267 | gru_cell.weight_hh.data = gru.weight_hh_l0.data
268 | gru_cell.weight_ih.data = gru.weight_ih_l0.data
269 | gru_cell.bias_hh.data = gru.bias_hh_l0.data
270 | gru_cell.bias_ih.data = gru.bias_ih_l0.data
271 | return gru_cell
272 |
273 | def pad_tensor(self, x, pad, side='both'):
274 |         # NB - this is just a quick method I need right now
275 | # i.e., it won't generalise to other shapes/dims
276 | b, t, c = x.size()
277 | total = t + 2 * pad if side == 'both' else t + pad
278 | if torch.cuda.is_available():
279 | padded = torch.zeros(b, total, c).cuda()
280 | else:
281 | padded = torch.zeros(b, total, c).cpu()
282 | if side == 'before' or side == 'both':
283 | padded[:, pad:pad + t, :] = x
284 | elif side == 'after':
285 | padded[:, :t, :] = x
286 | return padded
287 |
288 | def fold_with_overlap(self, x, target, overlap):
289 |
290 | ''' Fold the tensor with overlap for quick batched inference.
291 | Overlap will be used for crossfading in xfade_and_unfold()
292 |
293 | Args:
294 | x (tensor) : Upsampled conditioning features.
295 | shape=(1, timesteps, features)
296 | target (int) : Target timesteps for each index of batch
297 | overlap (int) : Timesteps for both xfade and rnn warmup
298 |
299 | Return:
300 | (tensor) : shape=(num_folds, target + 2 * overlap, features)
301 |
302 | Details:
303 | x = [[h1, h2, ... hn]]
304 |
305 | Where each h is a vector of conditioning features
306 |
307 | Eg: target=2, overlap=1 with x.size(1)=10
308 |
309 | folded = [[h1, h2, h3, h4],
310 | [h4, h5, h6, h7],
311 | [h7, h8, h9, h10]]
312 | '''
313 |
314 | _, total_len, features = x.size()
315 |
316 | # Calculate variables needed
317 | num_folds = (total_len - overlap) // (target + overlap)
318 | extended_len = num_folds * (overlap + target) + overlap
319 | remaining = total_len - extended_len
320 |
321 | # Pad if some time steps poking out
322 | if remaining != 0:
323 | num_folds += 1
324 | padding = target + 2 * overlap - remaining
325 | x = self.pad_tensor(x, padding, side='after')
326 |
327 | if torch.cuda.is_available():
328 | folded = torch.zeros(num_folds, target + 2 * overlap, features).cuda()
329 | else:
330 | folded = torch.zeros(num_folds, target + 2 * overlap, features).cpu()
331 |
332 | # Get the values for the folded tensor
333 | for i in range(num_folds):
334 | start = i * (target + overlap)
335 | end = start + target + 2 * overlap
336 | folded[i] = x[:, start:end, :]
337 |
338 | return folded
339 |
340 | def xfade_and_unfold(self, y, target, overlap):
341 |
342 | ''' Applies a crossfade and unfolds into a 1d array.
343 |
344 | Args:
345 |             y (ndarray) : Batched sequences of audio samples
346 | shape=(num_folds, target + 2 * overlap)
347 | dtype=np.float64
348 | overlap (int) : Timesteps for both xfade and rnn warmup
349 |
350 | Return:
351 |             (ndarray) : audio samples in a 1d array
352 | shape=(total_len)
353 | dtype=np.float64
354 |
355 | Details:
356 | y = [[seq1],
357 | [seq2],
358 | [seq3]]
359 |
360 | Apply a gain envelope at both ends of the sequences
361 |
362 | y = [[seq1_in, seq1_target, seq1_out],
363 | [seq2_in, seq2_target, seq2_out],
364 | [seq3_in, seq3_target, seq3_out]]
365 |
366 | Stagger and add up the groups of samples:
367 |
368 | [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
369 |
370 | '''
371 |
372 | num_folds, length = y.shape
373 | target = length - 2 * overlap
374 | total_len = num_folds * (target + overlap) + overlap
375 |
376 | # Need some silence for the rnn warmup
377 | silence_len = overlap // 2
378 | fade_len = overlap - silence_len
379 | silence = np.zeros((silence_len), dtype=np.float64)
380 |
381 | # Equal power crossfade
382 | t = np.linspace(-1, 1, fade_len, dtype=np.float64)
383 | fade_in = np.sqrt(0.5 * (1 + t))
384 | fade_out = np.sqrt(0.5 * (1 - t))
385 |
386 | # Concat the silence to the fades
387 | fade_in = np.concatenate([silence, fade_in])
388 | fade_out = np.concatenate([fade_out, silence])
389 |
390 | # Apply the gain to the overlap samples
391 | y[:, :overlap] *= fade_in
392 | y[:, -overlap:] *= fade_out
393 |
394 | unfolded = np.zeros((total_len), dtype=np.float64)
395 |
396 | # Loop to add up all the samples
397 | for i in range(num_folds):
398 | start = i * (target + overlap)
399 | end = start + target + 2 * overlap
400 | unfolded[start:end] += y[i]
401 |
402 | return unfolded
403 |
404 | def get_step(self) :
405 | return self.step.data.item()
406 |
407 | def checkpoint(self, model_dir, optimizer) :
408 | k_steps = self.get_step() // 1000
409 | self.save(model_dir.joinpath("checkpoint_%dk_steps.pt" % k_steps), optimizer)
410 |
411 | def log(self, path, msg) :
412 | with open(path, 'a') as f:
413 | print(msg, file=f)
414 |
415 | def load(self, path, optimizer) :
416 | checkpoint = torch.load(path)
417 | if "optimizer_state" in checkpoint:
418 | self.load_state_dict(checkpoint["model_state"])
419 | optimizer.load_state_dict(checkpoint["optimizer_state"])
420 | else:
421 | # Backwards compatibility
422 | self.load_state_dict(checkpoint)
423 |
424 | def save(self, path, optimizer) :
425 | torch.save({
426 | "model_state": self.state_dict(),
427 | "optimizer_state": optimizer.state_dict(),
428 | }, path)
429 |
430 | def num_params(self, print_out=True):
431 | parameters = filter(lambda p: p.requires_grad, self.parameters())
432 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
433 | if print_out :
434 | print('Trainable Parameters: %.3fM' % parameters)
435 |
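
For orientation, here is a rough sketch of how this vocoder is driven, not taken from the repository: every constructor value below is an illustrative placeholder (the real configuration lives in `vocoder/hparams.py`), and the random mel spectrogram is a stand-in for a synthesizer output. The upsample factors are chosen so their product matches `hop_length`.

````python
import torch
from vocoder.models.fatchord_version import WaveRNN  # assumes ./src is on the import path

# Placeholder configuration; the real values come from vocoder/hparams.py.
# upsample_factors are picked so that 5 * 5 * 8 = hop_length = 200.
model = WaveRNN(rnn_dims=512, fc_dims=512, bits=9, pad=2,
                upsample_factors=(5, 5, 8), feat_dims=80,
                compute_dims=128, res_out_dims=128, res_blocks=10,
                hop_length=200, sample_rate=16000, mode='RAW')

mel = torch.rand(1, 80, 40)  # stand-in mel spectrogram: (batch, num_mels, frames)
wav = model.generate(mel, batched=False, target=8000, overlap=800, mu_law=True)
print(wav.shape)             # (frames - 1) * hop_length samples
````

Passing `batched=True` with `target`/`overlap` is what makes long utterances fast: `fold_with_overlap` splits the conditioning features into overlapping chunks that are generated in parallel and then crossfaded back together by `xfade_and_unfold`.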
--------------------------------------------------------------------------------
/src/synthesizer/models/tacotron.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from pathlib import Path
7 | from typing import Union
8 |
9 |
10 | class HighwayNetwork(nn.Module):
11 | def __init__(self, size):
12 | super().__init__()
13 | self.W1 = nn.Linear(size, size)
14 | self.W2 = nn.Linear(size, size)
15 | self.W1.bias.data.fill_(0.)
16 |
17 | def forward(self, x):
18 | x1 = self.W1(x)
19 | x2 = self.W2(x)
20 | g = torch.sigmoid(x2)
21 | y = g * F.relu(x1) + (1. - g) * x
22 | return y
23 |
24 |
25 | class Encoder(nn.Module):
26 | def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
27 | super().__init__()
28 | prenet_dims = (encoder_dims, encoder_dims)
29 | cbhg_channels = encoder_dims
30 | self.embedding = nn.Embedding(num_chars, embed_dims)
31 | self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
32 | dropout=dropout)
33 | self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
34 | proj_channels=[cbhg_channels, cbhg_channels],
35 | num_highways=num_highways)
36 |
37 | def forward(self, x, speaker_embedding=None):
38 | x = self.embedding(x)
39 | x = self.pre_net(x)
40 | x.transpose_(1, 2)
41 | x = self.cbhg(x)
42 | if speaker_embedding is not None:
43 | x = self.add_speaker_embedding(x, speaker_embedding)
44 | return x
45 |
46 | def add_speaker_embedding(self, x, speaker_embedding):
47 | # SV2TTS
48 |         # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, encoder_dims)
49 |         # When training, speaker_embedding is a 2D tensor with size (batch_size, speaker_embedding_size)
50 | # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
51 | # This concats the speaker embedding for each char in the encoder output
52 |
53 | # Save the dimensions as human-readable names
54 | batch_size = x.size()[0]
55 | num_chars = x.size()[1]
56 |
57 | if speaker_embedding.dim() == 1:
58 | idx = 0
59 | else:
60 | idx = 1
61 |
62 | # Start by making a copy of each speaker embedding to match the input text length
63 | # The output of this has size (batch_size, num_chars * tts_embed_dims)
64 | speaker_embedding_size = speaker_embedding.size()[idx]
65 | e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
66 |
67 | # Reshape it and transpose
68 | e = e.reshape(batch_size, speaker_embedding_size, num_chars)
69 | e = e.transpose(1, 2)
70 |
71 | # Concatenate the tiled speaker embedding with the encoder output
72 | x = torch.cat((x, e), 2)
73 | return x
74 |
75 |
76 | class BatchNormConv(nn.Module):
77 | def __init__(self, in_channels, out_channels, kernel, relu=True):
78 | super().__init__()
79 | self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
80 | self.bnorm = nn.BatchNorm1d(out_channels)
81 | self.relu = relu
82 |
83 | def forward(self, x):
84 | x = self.conv(x)
85 | x = F.relu(x) if self.relu is True else x
86 | return self.bnorm(x)
87 |
88 |
89 | class CBHG(nn.Module):
90 | def __init__(self, K, in_channels, channels, proj_channels, num_highways):
91 | super().__init__()
92 |
93 | # List of all rnns to call `flatten_parameters()` on
94 | self._to_flatten = []
95 |
96 | self.bank_kernels = [i for i in range(1, K + 1)]
97 | self.conv1d_bank = nn.ModuleList()
98 | for k in self.bank_kernels:
99 | conv = BatchNormConv(in_channels, channels, k)
100 | self.conv1d_bank.append(conv)
101 |
102 | self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
103 |
104 | self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
105 | self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
106 |
107 | # Fix the highway input if necessary
108 | if proj_channels[-1] != channels:
109 | self.highway_mismatch = True
110 | self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
111 | else:
112 | self.highway_mismatch = False
113 |
114 | self.highways = nn.ModuleList()
115 | for i in range(num_highways):
116 | hn = HighwayNetwork(channels)
117 | self.highways.append(hn)
118 |
119 | self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
120 | self._to_flatten.append(self.rnn)
121 |
122 | # Avoid fragmentation of RNN parameters and associated warning
123 | self._flatten_parameters()
124 |
125 | def forward(self, x):
126 | # Although we `_flatten_parameters()` on init, when using DataParallel
127 | # the model gets replicated, making it no longer guaranteed that the
128 | # weights are contiguous in GPU memory. Hence, we must call it again
129 | self._flatten_parameters()
130 |
131 | # Save these for later
132 | residual = x
133 | seq_len = x.size(-1)
134 | conv_bank = []
135 |
136 | # Convolution Bank
137 | for conv in self.conv1d_bank:
138 | c = conv(x) # Convolution
139 | conv_bank.append(c[:, :, :seq_len])
140 |
141 | # Stack along the channel axis
142 | conv_bank = torch.cat(conv_bank, dim=1)
143 |
144 | # dump the last padding to fit residual
145 | x = self.maxpool(conv_bank)[:, :, :seq_len]
146 |
147 | # Conv1d projections
148 | x = self.conv_project1(x)
149 | x = self.conv_project2(x)
150 |
151 | # Residual Connect
152 | x = x + residual
153 |
154 | # Through the highways
155 | x = x.transpose(1, 2)
156 | if self.highway_mismatch is True:
157 | x = self.pre_highway(x)
158 | for h in self.highways: x = h(x)
159 |
160 | # And then the RNN
161 | x, _ = self.rnn(x)
162 | return x
163 |
164 | def _flatten_parameters(self):
165 | """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
166 | to improve efficiency and avoid PyTorch yelling at us."""
167 | [m.flatten_parameters() for m in self._to_flatten]
168 |
169 | class PreNet(nn.Module):
170 | def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
171 | super().__init__()
172 | self.fc1 = nn.Linear(in_dims, fc1_dims)
173 | self.fc2 = nn.Linear(fc1_dims, fc2_dims)
174 | self.p = dropout
175 |
176 | def forward(self, x):
177 | x = self.fc1(x)
178 | x = F.relu(x)
179 | x = F.dropout(x, self.p, training=True)
180 | x = self.fc2(x)
181 | x = F.relu(x)
182 | x = F.dropout(x, self.p, training=True)
183 | return x
184 |
185 |
186 | class Attention(nn.Module):
187 | def __init__(self, attn_dims):
188 | super().__init__()
189 | self.W = nn.Linear(attn_dims, attn_dims, bias=False)
190 | self.v = nn.Linear(attn_dims, 1, bias=False)
191 |
192 | def forward(self, encoder_seq_proj, query, t):
193 |
194 | # print(encoder_seq_proj.shape)
195 | # Transform the query vector
196 | query_proj = self.W(query).unsqueeze(1)
197 |
198 | # Compute the scores
199 | u = self.v(torch.tanh(encoder_seq_proj + query_proj))
200 | scores = F.softmax(u, dim=1)
201 |
202 | return scores.transpose(1, 2)
203 |
204 |
205 | class LSA(nn.Module):
206 | def __init__(self, attn_dim, kernel_size=31, filters=32):
207 | super().__init__()
208 | self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
209 | self.L = nn.Linear(filters, attn_dim, bias=False)
210 | self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
211 | self.v = nn.Linear(attn_dim, 1, bias=False)
212 | self.cumulative = None
213 | self.attention = None
214 |
215 | def init_attention(self, encoder_seq_proj):
216 | device = next(self.parameters()).device # use same device as parameters
217 | b, t, c = encoder_seq_proj.size()
218 | self.cumulative = torch.zeros(b, t, device=device)
219 | self.attention = torch.zeros(b, t, device=device)
220 |
221 | def forward(self, encoder_seq_proj, query, t, chars):
222 |
223 | if t == 0: self.init_attention(encoder_seq_proj)
224 |
225 | processed_query = self.W(query).unsqueeze(1)
226 |
227 | location = self.cumulative.unsqueeze(1)
228 | processed_loc = self.L(self.conv(location).transpose(1, 2))
229 |
230 | u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
231 | u = u.squeeze(-1)
232 |
233 | # Mask zero padding chars
234 | u = u * (chars != 0).float()
235 |
236 | # Smooth Attention
237 | # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
238 | scores = F.softmax(u, dim=1)
239 | self.attention = scores
240 | self.cumulative = self.cumulative + self.attention
241 |
242 | return scores.unsqueeze(-1).transpose(1, 2)
243 |
244 |
245 | class Decoder(nn.Module):
246 |     # Class variable because its value doesn't change between instances,
247 |     # yet it ought to be scoped by the class because it's a property of a Decoder
248 | max_r = 20
249 | def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
250 | dropout, speaker_embedding_size):
251 | super().__init__()
252 | self.register_buffer("r", torch.tensor(1, dtype=torch.int))
253 | self.n_mels = n_mels
254 | prenet_dims = (decoder_dims * 2, decoder_dims * 2)
255 | self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
256 | dropout=dropout)
257 | self.attn_net = LSA(decoder_dims)
258 | self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
259 | self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
260 | self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
261 | self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
262 | self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
263 | self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
264 |
265 | def zoneout(self, prev, current, p=0.1):
266 | device = next(self.parameters()).device # Use same device as parameters
267 | mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
268 | return prev * mask + current * (1 - mask)
269 |
270 | def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
271 | hidden_states, cell_states, context_vec, t, chars):
272 |
273 | # Need this for reshaping mels
274 | batch_size = encoder_seq.size(0)
275 |
276 | # Unpack the hidden and cell states
277 | attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
278 | rnn1_cell, rnn2_cell = cell_states
279 |
280 | # PreNet for the Attention RNN
281 | prenet_out = self.prenet(prenet_in)
282 |
283 | # Compute the Attention RNN hidden state
284 | attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
285 | attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
286 |
287 | # Compute the attention scores
288 | scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
289 |
290 | # Dot product to create the context vector
291 | context_vec = scores @ encoder_seq
292 | context_vec = context_vec.squeeze(1)
293 |
294 | # Concat Attention RNN output w. Context Vector & project
295 | x = torch.cat([context_vec, attn_hidden], dim=1)
296 | x = self.rnn_input(x)
297 |
298 | # Compute first Residual RNN
299 | rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
300 | if self.training:
301 | rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
302 | else:
303 | rnn1_hidden = rnn1_hidden_next
304 | x = x + rnn1_hidden
305 |
306 | # Compute second Residual RNN
307 | rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
308 | if self.training:
309 | rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
310 | else:
311 | rnn2_hidden = rnn2_hidden_next
312 | x = x + rnn2_hidden
313 |
314 | # Project Mels
315 | mels = self.mel_proj(x)
316 | mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
317 | hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
318 | cell_states = (rnn1_cell, rnn2_cell)
319 |
320 | # Stop token prediction
321 | s = torch.cat((x, context_vec), dim=1)
322 | s = self.stop_proj(s)
323 | stop_tokens = torch.sigmoid(s)
324 |
325 | return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
326 |
327 |
328 | class Tacotron(nn.Module):
329 | def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
330 | fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
331 | dropout, stop_threshold, speaker_embedding_size):
332 | super().__init__()
333 | self.n_mels = n_mels
334 | self.lstm_dims = lstm_dims
335 | self.encoder_dims = encoder_dims
336 | self.decoder_dims = decoder_dims
337 | self.speaker_embedding_size = speaker_embedding_size
338 | self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
339 | encoder_K, num_highways, dropout)
340 | self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
341 | self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
342 | dropout, speaker_embedding_size)
343 | self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
344 | [postnet_dims, fft_bins], num_highways)
345 | self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
346 |
347 | self.init_model()
348 | self.num_params(print_out=False)
349 |
350 | self.register_buffer("step", torch.zeros(1, dtype=torch.long))
351 | self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
352 |
353 | @property
354 | def r(self):
355 | return self.decoder.r.item()
356 |
357 | @r.setter
358 | def r(self, value):
359 | self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
360 |
361 | def forward(self, x, m, speaker_embedding):
362 | device = next(self.parameters()).device # use same device as parameters
363 |
364 | self.step += 1
365 | batch_size, _, steps = m.size()
366 |
367 | # Initialise all hidden states and pack into tuple
368 | attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
369 | rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
370 | rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
371 | hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
372 |
373 | # Initialise all lstm cell states and pack into tuple
374 | rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
375 | rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
376 | cell_states = (rnn1_cell, rnn2_cell)
377 |
378 | # Frame for start of decoder loop
379 | go_frame = torch.zeros(batch_size, self.n_mels, device=device)
380 |
381 | # Need an initial context vector
382 | context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
383 |
384 | # SV2TTS: Run the encoder with the speaker embedding
385 | # The projection avoids unnecessary matmuls in the decoder loop
386 | encoder_seq = self.encoder(x, speaker_embedding)
387 | encoder_seq_proj = self.encoder_proj(encoder_seq)
388 |
389 | # Need a couple of lists for outputs
390 | mel_outputs, attn_scores, stop_outputs = [], [], []
391 |
392 | # Run the decoder loop
393 | for t in range(0, steps, self.r):
394 | prenet_in = m[:, :, t - 1] if t > 0 else go_frame
395 | mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
396 | self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
397 | hidden_states, cell_states, context_vec, t, x)
398 | mel_outputs.append(mel_frames)
399 | attn_scores.append(scores)
400 | stop_outputs.extend([stop_tokens] * self.r)
401 |
402 | # Concat the mel outputs into sequence
403 | mel_outputs = torch.cat(mel_outputs, dim=2)
404 |
405 | # Post-Process for Linear Spectrograms
406 | postnet_out = self.postnet(mel_outputs)
407 | linear = self.post_proj(postnet_out)
408 | linear = linear.transpose(1, 2)
409 |
410 | # For easy visualisation
411 | attn_scores = torch.cat(attn_scores, 1)
412 | # attn_scores = attn_scores.cpu().data.numpy()
413 | stop_outputs = torch.cat(stop_outputs, 1)
414 |
415 | return mel_outputs, linear, attn_scores, stop_outputs
416 |
417 | def generate(self, x, speaker_embedding=None, steps=2000):
418 | self.eval()
419 | device = next(self.parameters()).device # use same device as parameters
420 |
421 | batch_size, _ = x.size()
422 |
423 |         # Need to initialise all hidden states and pack into tuple for tidiness
424 | attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
425 | rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
426 | rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
427 | hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
428 |
429 |         # Need to initialise all lstm cell states and pack into tuple for tidiness
430 | rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
431 | rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
432 | cell_states = (rnn1_cell, rnn2_cell)
433 |
434 | # Need a Frame for start of decoder loop
435 | go_frame = torch.zeros(batch_size, self.n_mels, device=device)
436 |
437 | # Need an initial context vector
438 | context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
439 |
440 | # SV2TTS: Run the encoder with the speaker embedding
441 | # The projection avoids unnecessary matmuls in the decoder loop
442 | encoder_seq = self.encoder(x, speaker_embedding)
443 | encoder_seq_proj = self.encoder_proj(encoder_seq)
444 |
445 | # Need a couple of lists for outputs
446 | mel_outputs, attn_scores, stop_outputs = [], [], []
447 |
448 | # Run the decoder loop
449 | for t in range(0, steps, self.r):
450 | prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
451 | mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
452 | self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
453 | hidden_states, cell_states, context_vec, t, x)
454 | mel_outputs.append(mel_frames)
455 | attn_scores.append(scores)
456 | stop_outputs.extend([stop_tokens] * self.r)
457 | # Stop the loop when all stop tokens in batch exceed threshold
458 | if (stop_tokens > 0.5).all() and t > 10: break
459 |
460 | # Concat the mel outputs into sequence
461 | mel_outputs = torch.cat(mel_outputs, dim=2)
462 |
463 | # Post-Process for Linear Spectrograms
464 | postnet_out = self.postnet(mel_outputs)
465 | linear = self.post_proj(postnet_out)
466 |
467 |
468 | linear = linear.transpose(1, 2)
469 |
470 | # For easy visualisation
471 | attn_scores = torch.cat(attn_scores, 1)
472 | stop_outputs = torch.cat(stop_outputs, 1)
473 |
474 | self.train()
475 |
476 | return mel_outputs, linear, attn_scores
477 |
478 | def init_model(self):
479 | for p in self.parameters():
480 | if p.dim() > 1: nn.init.xavier_uniform_(p)
481 |
482 | def get_step(self):
483 | return self.step.data.item()
484 |
485 | def reset_step(self):
486 | # assignment to parameters or buffers is overloaded, updates internal dict entry
487 | self.step = self.step.data.new_tensor(1)
488 |
489 | def log(self, path, msg):
490 | with open(path, "a") as f:
491 | print(msg, file=f)
492 |
493 | def load(self, path, optimizer=None):
494 | # Use device of model params as location for loaded state
495 | device = next(self.parameters()).device
496 | checkpoint = torch.load(str(path), map_location=device)
497 | self.load_state_dict(checkpoint["model_state"])
498 |
499 | if "optimizer_state" in checkpoint and optimizer is not None:
500 | optimizer.load_state_dict(checkpoint["optimizer_state"])
501 |
502 | def save(self, path, optimizer=None):
503 | if optimizer is not None:
504 | torch.save({
505 | "model_state": self.state_dict(),
506 | "optimizer_state": optimizer.state_dict(),
507 | }, str(path))
508 | else:
509 | torch.save({
510 | "model_state": self.state_dict(),
511 | }, str(path))
512 |
513 |
514 | def num_params(self, print_out=True):
515 | parameters = filter(lambda p: p.requires_grad, self.parameters())
516 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
517 | if print_out:
518 | print("Trainable Parameters: %.3fM" % parameters)
519 | return parameters
520 |
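
And a comparable sketch for the synthesizer model, again not from the repository: the constructor values are illustrative placeholders for the project's synthesizer hyperparameters (note that `fft_bins` must equal `n_mels` here for the postnet's residual connection inside `CBHG.forward` to line up), and the character indices and speaker embedding are random stand-ins for real preprocessed inputs.

````python
import torch
from synthesizer.models.tacotron import Tacotron  # assumes ./src is on the import path

# Placeholder configuration; the real values come from the synthesizer's hparams.
model = Tacotron(embed_dims=512, num_chars=66, encoder_dims=256, decoder_dims=128,
                 n_mels=80, fft_bins=80, postnet_dims=512, encoder_K=5,
                 lstm_dims=1024, postnet_K=5, num_highways=4, dropout=0.5,
                 stop_threshold=-3.4, speaker_embedding_size=256)

chars = torch.randint(1, 66, (1, 20))  # stand-in character indices, shape (batch, num_chars)
speaker_embed = torch.rand(1, 256)     # stand-in utterance embedding from the encoder
mel, linear, attn = model.generate(chars, speaker_embed, steps=200)
print(mel.shape)                       # (batch, n_mels, decoded frames)
````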
--------------------------------------------------------------------------------