├── src ├── encoder │ ├── __init__.py │ ├── saved_models │ │ └── .gitkeep │ ├── params_model.py │ ├── params_data.py │ ├── audio.py │ ├── model.py │ └── inference.py ├── utils │ ├── __init__.py │ ├── modelutils.py │ ├── argutils.py │ ├── profiler.py │ └── logmmse.py ├── synthesizer │ ├── __init__.py │ ├── saved_models │ │ └── pretrained │ │ │ └── .gitkeep │ ├── utils │ │ ├── symbols.py │ │ ├── __init__.py │ │ ├── _cmudict.py │ │ ├── numbers.py │ │ ├── text.py │ │ ├── plot.py │ │ └── cleaners.py │ ├── hparams.py │ ├── inference.py │ ├── audio.py │ └── models │ │ └── tacotron.py ├── vocoder │ ├── saved_models │ │ └── pretrained │ │ │ └── .gitkeep │ ├── hparams.py │ ├── inference.py │ ├── audio.py │ ├── display.py │ ├── distribution.py │ └── models │ │ └── fatchord_version.py ├── audio │ ├── source.wav │ ├── target.wav │ └── audio_out.wav ├── README.md └── main.py ├── docs ├── audio │ ├── class_1_output.wav │ ├── class_1_source.wav │ ├── class_1_target.wav │ ├── class_2_output.wav │ ├── class_2_source.wav │ ├── class_2_target.wav │ ├── class_3_output.wav │ ├── class_3_source.wav │ ├── class_3_target.wav │ ├── class_4_output.wav │ ├── class_4_source.wav │ ├── class_4_target.wav │ ├── class_5_output.wav │ ├── class_5_source.wav │ └── class_5_target.wav ├── _config.yml ├── _layouts │ └── default.html └── index.md ├── .gitignore ├── README.md └── requirements.txt /src/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/synthesizer/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /src/encoder/saved_models/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/synthesizer/saved_models/pretrained/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/vocoder/saved_models/pretrained/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/audio/source.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/src/audio/source.wav -------------------------------------------------------------------------------- /src/audio/target.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/src/audio/target.wav -------------------------------------------------------------------------------- /src/audio/audio_out.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/src/audio/audio_out.wav -------------------------------------------------------------------------------- /docs/audio/class_1_output.wav: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_1_output.wav -------------------------------------------------------------------------------- /docs/audio/class_1_source.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_1_source.wav -------------------------------------------------------------------------------- /docs/audio/class_1_target.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_1_target.wav -------------------------------------------------------------------------------- /docs/audio/class_2_output.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_2_output.wav -------------------------------------------------------------------------------- /docs/audio/class_2_source.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_2_source.wav -------------------------------------------------------------------------------- /docs/audio/class_2_target.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_2_target.wav -------------------------------------------------------------------------------- /docs/audio/class_3_output.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_3_output.wav -------------------------------------------------------------------------------- /docs/audio/class_3_source.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_3_source.wav -------------------------------------------------------------------------------- /docs/audio/class_3_target.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_3_target.wav -------------------------------------------------------------------------------- /docs/audio/class_4_output.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_4_output.wav -------------------------------------------------------------------------------- /docs/audio/class_4_source.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_4_source.wav -------------------------------------------------------------------------------- /docs/audio/class_4_target.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_4_target.wav -------------------------------------------------------------------------------- /docs/audio/class_5_output.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_5_output.wav -------------------------------------------------------------------------------- 
/docs/audio/class_5_source.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_5_source.wav -------------------------------------------------------------------------------- /docs/audio/class_5_target.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fmiotello/fastVC/HEAD/docs/audio/class_5_target.wav -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman 2 | title: Fast VC 3 | description: A fast and efficient voice conversion tool 4 | -------------------------------------------------------------------------------- /src/encoder/params_model.py: -------------------------------------------------------------------------------- 1 | 2 | ## Model parameters 3 | model_hidden_size = 256 4 | model_embedding_size = 256 5 | model_num_layers = 3 6 | 7 | 8 | ## Training parameters 9 | learning_rate_init = 1e-4 10 | speakers_per_batch = 64 11 | utterances_per_speaker = 10 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | env/ 3 | LICENSE.txt 4 | *__pycache__ 5 | *.wav 6 | !src/audio/*.wav 7 | !docs/audio/*.wav 8 | src/encoder/saved_models/pretrained.pt 9 | src/vocoder/saved_models/pretrained/pretrained.pt 10 | src/synthesizer/saved_models/pretrained/pretrained.pt 11 | -------------------------------------------------------------------------------- /src/synthesizer/utils/symbols.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines the set of symbols used in text input to the model. 3 | 4 | The default is a set of ASCII characters that works well for English or text that has been run 5 | through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. 6 | """ 7 | # from . import cmudict 8 | 9 | _pad = "_" 10 | _eos = "~" 11 | _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? " 12 | 13 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): 14 | #_arpabet = ["@' + s for s in cmudict.valid_symbols] 15 | 16 | # Export all symbols: 17 | symbols = [_pad, _eos] + list(_characters) #+ _arpabet 18 | -------------------------------------------------------------------------------- /src/utils/modelutils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | def check_model_paths(encoder_path: Path, synthesizer_path: Path, vocoder_path: Path): 4 | # This function tests the model paths and makes sure at least one is valid. 5 | if encoder_path.is_file() or encoder_path.is_dir(): 6 | return 7 | if synthesizer_path.is_file() or synthesizer_path.is_dir(): 8 | return 9 | if vocoder_path.is_file() or vocoder_path.is_dir(): 10 | return 11 | 12 | # If none of the paths exist, remind the user to download models if needed 13 | print("********************************************************************************") 14 | print("Error: Model files not found. 
Follow these instructions to get and install the models:") 15 | print("https://github.com/CorentinJ/Real-Time-Voice-Cloning/wiki/Pretrained-models") 16 | print("********************************************************************************\n") 17 | quit(-1) 18 | -------------------------------------------------------------------------------- /src/encoder/params_data.py: -------------------------------------------------------------------------------- 1 | 2 | ## Mel-filterbank 3 | mel_window_length = 25 # In milliseconds 4 | mel_window_step = 10 # In milliseconds 5 | mel_n_channels = 40 6 | 7 | 8 | ## Audio 9 | sampling_rate = 16000 10 | # Number of spectrogram frames in a partial utterance 11 | partials_n_frames = 160 # 1600 ms 12 | # Number of spectrogram frames at inference 13 | inference_n_frames = 80 # 800 ms 14 | 15 | 16 | ## Voice Activation Detection 17 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds. 18 | # This sets the granularity of the VAD. Should not need to be changed. 19 | vad_window_length = 30 # In milliseconds 20 | # Number of frames to average together when performing the moving average smoothing. 21 | # The larger this value, the larger the VAD variations must be to not get smoothed out. 22 | vad_moving_average_width = 8 23 | # Maximum number of consecutive silent frames a segment can have. 24 | vad_max_silence_length = 6 25 | 26 | 27 | ## Audio volume normalization 28 | audio_norm_target_dBFS = -30 29 | 30 | -------------------------------------------------------------------------------- /src/utils/argutils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import argparse 4 | 5 | _type_priorities = [ # In decreasing order 6 | Path, 7 | str, 8 | int, 9 | float, 10 | bool, 11 | ] 12 | 13 | def _priority(o): 14 | p = next((i for i, t in enumerate(_type_priorities) if type(o) is t), None) 15 | if p is not None: 16 | return p 17 | p = next((i for i, t in enumerate(_type_priorities) if isinstance(o, t)), None) 18 | if p is not None: 19 | return p 20 | return len(_type_priorities) 21 | 22 | def print_args(args: argparse.Namespace, parser=None): 23 | args = vars(args) 24 | if parser is None: 25 | priorities = list(map(_priority, args.values())) 26 | else: 27 | all_params = [a.dest for g in parser._action_groups for a in g._group_actions ] 28 | priority = lambda p: all_params.index(p) if p in all_params else len(all_params) 29 | priorities = list(map(priority, args.keys())) 30 | 31 | pad = max(map(len, args.keys())) + 3 32 | indices = np.lexsort((list(args.keys()), priorities)) 33 | items = list(args.items()) 34 | 35 | print("Arguments:") 36 | for i in indices: 37 | param, value = items[i] 38 | print(" {0}:{1}{2}".format(param, ' ' * (pad - len(param)), value)) 39 | print("") 40 | -------------------------------------------------------------------------------- /src/synthesizer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | _output_ref = None 5 | _replicas_ref = None 6 | 7 | def data_parallel_workaround(model, *input): 8 | global _output_ref 9 | global _replicas_ref 10 | device_ids = list(range(torch.cuda.device_count())) 11 | output_device = device_ids[0] 12 | replicas = torch.nn.parallel.replicate(model, device_ids) 13 | # input.shape = (num_args, batch, ...) 14 | inputs = torch.nn.parallel.scatter(input, device_ids) 15 | # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...) 
16 | replicas = replicas[:len(inputs)] 17 | outputs = torch.nn.parallel.parallel_apply(replicas, inputs) 18 | y_hat = torch.nn.parallel.gather(outputs, output_device) 19 | _output_ref = outputs 20 | _replicas_ref = replicas 21 | return y_hat 22 | 23 | 24 | class ValueWindow(): 25 | def __init__(self, window_size=100): 26 | self._window_size = window_size 27 | self._values = [] 28 | 29 | def append(self, x): 30 | self._values = self._values[-(self._window_size - 1):] + [x] 31 | 32 | @property 33 | def sum(self): 34 | return sum(self._values) 35 | 36 | @property 37 | def count(self): 38 | return len(self._values) 39 | 40 | @property 41 | def average(self): 42 | return self.sum / max(1, self.count) 43 | 44 | def reset(self): 45 | self._values = [] 46 | -------------------------------------------------------------------------------- /src/utils/profiler.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter as timer 2 | from collections import OrderedDict 3 | import numpy as np 4 | 5 | 6 | class Profiler: 7 | def __init__(self, summarize_every=5, disabled=False): 8 | self.last_tick = timer() 9 | self.logs = OrderedDict() 10 | self.summarize_every = summarize_every 11 | self.disabled = disabled 12 | 13 | def tick(self, name): 14 | if self.disabled: 15 | return 16 | 17 | # Log the time needed to execute that function 18 | if not name in self.logs: 19 | self.logs[name] = [] 20 | if len(self.logs[name]) >= self.summarize_every: 21 | self.summarize() 22 | self.purge_logs() 23 | self.logs[name].append(timer() - self.last_tick) 24 | 25 | self.reset_timer() 26 | 27 | def purge_logs(self): 28 | for name in self.logs: 29 | self.logs[name].clear() 30 | 31 | def reset_timer(self): 32 | self.last_tick = timer() 33 | 34 | def summarize(self): 35 | n = max(map(len, self.logs.values())) 36 | assert n == self.summarize_every 37 | print("\nAverage execution time over %d steps:" % n) 38 | 39 | name_msgs = ["%s (%d/%d):" % (name, len(deltas), n) for name, deltas in self.logs.items()] 40 | pad = max(map(len, name_msgs)) 41 | for name_msg, deltas in zip(name_msgs, self.logs.values()): 42 | print(" %s mean: %4.0fms std: %4.0fms" % 43 | (name_msg.ljust(pad), np.mean(deltas) * 1000, np.std(deltas) * 1000)) 44 | print("", flush=True) 45 | -------------------------------------------------------------------------------- /src/README.md: -------------------------------------------------------------------------------- 1 | The implementation of this tool is based on [this project](https://github.com/CorentinJ/Real-Time-Voice-Cloning), that implements the three-stage deep learning TTS framework proposed in [SV2TTS](https://arxiv.org/pdf/1806.04558.pdf). We extended it adding Facebook AI [Wav2vec 2.0](https://arxiv.org/pdf/2006.11477.pdf) ASR model. 2 | 3 | For the computation of relevant metrics we used some classic NLP libraries: [NLTK](https://www.nltk.org), [JiWER](https://github.com/jitsi/jiwer), [speechmetrics](https://github.com/aliutkus/speechmetrics) and [asrtoolkit](https://github.com/finos/greenkey-asrtoolkit). 4 | 5 | ## Source code organization 6 | `main.py` is the main script. 
This can be executed as a command line tool using `python main.py` (or `src/main.py` from outside this directory) with some options: 7 | - `--source`: specify source speaker path (otherwise the default source speaker `./audio/source.wav` will be used) 8 | - `--target`: specify target speaker path (otherwise the default target speaker `./audio/target.wav` will be used) 9 | - `--string`: choose the string to give as input to the tts 10 | - `--seed`: random number seed to make the output deterministic 11 | - `--metrics`: print relevant metrics 12 | - `--enhance`: trim output audio silences 13 | - `--help`: help page 14 | 15 | `./audio` contains the default source and target speaker files and the output file `audio_out.wav` 16 | 17 | `./encoder` contains the speaker encoder which creates a numerical representation of a voice from a few seconds of audio 18 | 19 | `./synthesizer` contains a modified text-to-speech synthesizer that generates an audio spectrogram in the target voice 20 | 21 | `./vocoder` contains the vocoder that transforms the spectrograms into waveform audio 22 | 23 | `./utils` contains some general utilities 24 | -------------------------------------------------------------------------------- /src/vocoder/hparams.py: -------------------------------------------------------------------------------- 1 | from synthesizer.hparams import hparams as _syn_hp 2 | 3 | 4 | # Audio settings------------------------------------------------------------------------ 5 | # Match the values of the synthesizer 6 | sample_rate = _syn_hp.sample_rate 7 | n_fft = _syn_hp.n_fft 8 | num_mels = _syn_hp.num_mels 9 | hop_length = _syn_hp.hop_size 10 | win_length = _syn_hp.win_size 11 | fmin = _syn_hp.fmin 12 | min_level_db = _syn_hp.min_level_db 13 | ref_level_db = _syn_hp.ref_level_db 14 | mel_max_abs_value = _syn_hp.max_abs_value 15 | preemphasis = _syn_hp.preemphasis 16 | apply_preemphasis = _syn_hp.preemphasize 17 | 18 | bits = 9 # bit depth of signal 19 | mu_law = True # Recommended to suppress noise if using raw bits in hp.voc_mode 20 | # below 21 | 22 | 23 | # WAVERNN / VOCODER -------------------------------------------------------------------------------- 24 | voc_mode = 'RAW' # either 'RAW' (softmax on raw bits) or 'MOL' (sample from 25 | # mixture of logistics) 26 | voc_upsample_factors = (5, 5, 8) # NB - this needs to correctly factorise hop_length 27 | voc_rnn_dims = 512 28 | voc_fc_dims = 512 29 | voc_compute_dims = 128 30 | voc_res_out_dims = 128 31 | voc_res_blocks = 10 32 | 33 | # Training 34 | voc_batch_size = 100 35 | voc_lr = 1e-4 36 | voc_gen_at_checkpoint = 5 # number of samples to generate at each checkpoint 37 | voc_pad = 2 # this will pad the input so that the resnet can 'see' wider 38 | # than input length 39 | voc_seq_len = hop_length * 5 # must be a multiple of hop_length 40 | 41 | # Generating / Synthesizing 42 | voc_gen_batched = True # very fast (realtime+) single utterance batched generation 43 | voc_target = 8000 # target number of samples to be generated in each batch entry 44 | voc_overlap = 400 # number of samples for crossfading between batches 45 | -------------------------------------------------------------------------------- /src/vocoder/inference.py: -------------------------------------------------------------------------------- 1 | from vocoder.models.fatchord_version import WaveRNN 2 | from vocoder import hparams as hp 3 | import torch 4 | 5 | 6 | _model = None # type: WaveRNN 7 | 8 | def load_model(weights_fpath, verbose=True): 9 | global _model, _device 10 | 
11 | if verbose: 12 | print("Building Wave-RNN") 13 | _model = WaveRNN( 14 | rnn_dims=hp.voc_rnn_dims, 15 | fc_dims=hp.voc_fc_dims, 16 | bits=hp.bits, 17 | pad=hp.voc_pad, 18 | upsample_factors=hp.voc_upsample_factors, 19 | feat_dims=hp.num_mels, 20 | compute_dims=hp.voc_compute_dims, 21 | res_out_dims=hp.voc_res_out_dims, 22 | res_blocks=hp.voc_res_blocks, 23 | hop_length=hp.hop_length, 24 | sample_rate=hp.sample_rate, 25 | mode=hp.voc_mode 26 | ) 27 | 28 | if torch.cuda.is_available(): 29 | _model = _model.cuda() 30 | _device = torch.device('cuda') 31 | else: 32 | _device = torch.device('cpu') 33 | 34 | if verbose: 35 | print("Loading model weights at %s" % weights_fpath) 36 | checkpoint = torch.load(weights_fpath, _device) 37 | _model.load_state_dict(checkpoint['model_state']) 38 | _model.eval() 39 | 40 | 41 | def is_loaded(): 42 | return _model is not None 43 | 44 | 45 | def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800, 46 | progress_callback=None): 47 | """ 48 | Infers the waveform of a mel spectrogram output by the synthesizer (the format must match 49 | that of the synthesizer!) 50 | 51 | :param normalize: 52 | :param batched: 53 | :param target: 54 | :param overlap: 55 | :return: 56 | """ 57 | if _model is None: 58 | raise Exception("Please load Wave-RNN in memory before using it") 59 | 60 | if normalize: 61 | mel = mel / hp.mel_max_abs_value 62 | mel = torch.from_numpy(mel[None, ...]) 63 | wav = _model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback) 64 | return wav 65 | -------------------------------------------------------------------------------- /src/synthesizer/utils/_cmudict.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | valid_symbols = [ 4 | "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2", 5 | "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2", 6 | "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY", 7 | "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1", 8 | "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0", 9 | "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW", 10 | "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH" 11 | ] 12 | 13 | _valid_symbol_set = set(valid_symbols) 14 | 15 | 16 | class CMUDict: 17 | """Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" 18 | def __init__(self, file_or_path, keep_ambiguous=True): 19 | if isinstance(file_or_path, str): 20 | with open(file_or_path, encoding="latin-1") as f: 21 | entries = _parse_cmudict(f) 22 | else: 23 | entries = _parse_cmudict(file_or_path) 24 | if not keep_ambiguous: 25 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1} 26 | self._entries = entries 27 | 28 | 29 | def __len__(self): 30 | return len(self._entries) 31 | 32 | 33 | def lookup(self, word): 34 | """Returns list of ARPAbet pronunciations of the given word.""" 35 | return self._entries.get(word.upper()) 36 | 37 | 38 | 39 | _alt_re = re.compile(r"\([0-9]+\)") 40 | 41 | 42 | def _parse_cmudict(file): 43 | cmudict = {} 44 | for line in file: 45 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): 46 | parts = line.split(" ") 47 | word = re.sub(_alt_re, "", parts[0]) 48 | pronunciation = _get_pronunciation(parts[1]) 49 | if pronunciation: 50 | if word in cmudict: 51 | cmudict[word].append(pronunciation) 52 | else: 53 | cmudict[word] = [pronunciation] 54 | return cmudict 55 | 56 | 57 | def _get_pronunciation(s): 58 | parts = s.strip().split(" ") 59 | for part in parts: 60 | if part not in _valid_symbol_set: 61 | return None 62 | return " ".join(parts) 63 | -------------------------------------------------------------------------------- /src/synthesizer/utils/numbers.py: -------------------------------------------------------------------------------- 1 | import re 2 | import inflect 3 | 4 | _inflect = inflect.engine() 5 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 6 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 7 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 8 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 9 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 10 | _number_re = re.compile(r"[0-9]+") 11 | 12 | 13 | def _remove_commas(m): 14 | return m.group(1).replace(",", "") 15 | 16 | 17 | def _expand_decimal_point(m): 18 | return m.group(1).replace(".", " point ") 19 | 20 | 21 | def _expand_dollars(m): 22 | match = m.group(1) 23 | parts = match.split(".") 24 | if len(parts) > 2: 25 | return match + " dollars" # Unexpected format 26 | dollars = int(parts[0]) if parts[0] else 0 27 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 28 | if dollars and cents: 29 | dollar_unit = "dollar" if dollars == 1 else "dollars" 30 | cent_unit = "cent" if cents == 1 else "cents" 31 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) 32 | elif dollars: 33 | dollar_unit = "dollar" if dollars == 1 else "dollars" 34 | return "%s %s" % (dollars, dollar_unit) 35 | elif cents: 36 | cent_unit = "cent" if cents == 1 else "cents" 37 | return "%s %s" % (cents, cent_unit) 38 | else: 39 | return "zero dollars" 40 | 41 | 42 | def _expand_ordinal(m): 43 | return _inflect.number_to_words(m.group(0)) 44 | 45 | 46 | def _expand_number(m): 47 | num = int(m.group(0)) 48 | if num > 1000 and num < 3000: 49 | if num == 2000: 50 | return "two thousand" 51 | elif num > 2000 and num < 2010: 52 | return "two thousand " + _inflect.number_to_words(num % 100) 53 | elif num % 100 == 0: 54 | return _inflect.number_to_words(num // 100) + " hundred" 55 | else: 56 | return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") 57 | else: 58 | return _inflect.number_to_words(num, andword="") 59 | 60 | 61 | def normalize_numbers(text): 62 | text = re.sub(_comma_number_re, _remove_commas, text) 63 | text = 
re.sub(_pounds_re, r"\1 pounds", text) 64 | text = re.sub(_dollars_re, _expand_dollars, text) 65 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 66 | text = re.sub(_ordinal_re, _expand_ordinal, text) 67 | text = re.sub(_number_re, _expand_number, text) 68 | return text 69 | -------------------------------------------------------------------------------- /src/synthesizer/utils/text.py: -------------------------------------------------------------------------------- 1 | from .symbols import symbols 2 | from . import cleaners 3 | import re 4 | 5 | # Mappings from symbol to numeric ID and vice versa: 6 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 7 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 8 | 9 | # Regular expression matching text enclosed in curly braces: 10 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") 11 | 12 | 13 | def text_to_sequence(text, cleaner_names): 14 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 15 | 16 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded 17 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 18 | 19 | Args: 20 | text: string to convert to a sequence 21 | cleaner_names: names of the cleaner functions to run the text through 22 | 23 | Returns: 24 | List of integers corresponding to the symbols in the text 25 | """ 26 | sequence = [] 27 | 28 | # Check for curly braces and treat their contents as ARPAbet: 29 | while len(text): 30 | m = _curly_re.match(text) 31 | if not m: 32 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) 33 | break 34 | sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) 35 | sequence += _arpabet_to_sequence(m.group(2)) 36 | text = m.group(3) 37 | 38 | # Append EOS token 39 | sequence.append(_symbol_to_id["~"]) 40 | return sequence 41 | 42 | 43 | def sequence_to_text(sequence): 44 | """Converts a sequence of IDs back to a string""" 45 | result = "" 46 | for symbol_id in sequence: 47 | if symbol_id in _id_to_symbol: 48 | s = _id_to_symbol[symbol_id] 49 | # Enclose ARPAbet back in curly braces: 50 | if len(s) > 1 and s[0] == "@": 51 | s = "{%s}" % s[1:] 52 | result += s 53 | return result.replace("}{", " ") 54 | 55 | 56 | def _clean_text(text, cleaner_names): 57 | for name in cleaner_names: 58 | cleaner = getattr(cleaners, name) 59 | if not cleaner: 60 | raise Exception("Unknown cleaner: %s" % name) 61 | text = cleaner(text) 62 | return text 63 | 64 | 65 | def _symbols_to_sequence(symbols): 66 | return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] 67 | 68 | 69 | def _arpabet_to_sequence(text): 70 | return _symbols_to_sequence(["@" + s for s in text.split()]) 71 | 72 | 73 | def _should_keep_symbol(s): 74 | return s in _symbol_to_id and s not in ("_", "~") 75 | -------------------------------------------------------------------------------- /src/synthesizer/utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use("Agg") 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | def split_title_line(title_text, max_words=5): 8 | """ 9 | A function that splits any string based on specific character 10 | (returning it with the string), with maximum number of words on it 11 | """ 12 | seq = title_text.split() 13 | return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)]) 14 | 15 | def plot_alignment(alignment, path, title=None, 
split_title=False, max_len=None): 16 | if max_len is not None: 17 | alignment = alignment[:, :max_len] 18 | 19 | fig = plt.figure(figsize=(8, 6)) 20 | ax = fig.add_subplot(111) 21 | 22 | im = ax.imshow( 23 | alignment, 24 | aspect="auto", 25 | origin="lower", 26 | interpolation="none") 27 | fig.colorbar(im, ax=ax) 28 | xlabel = "Decoder timestep" 29 | 30 | if split_title: 31 | title = split_title_line(title) 32 | 33 | plt.xlabel(xlabel) 34 | plt.title(title) 35 | plt.ylabel("Encoder timestep") 36 | plt.tight_layout() 37 | plt.savefig(path, format="png") 38 | plt.close() 39 | 40 | 41 | def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False): 42 | if max_len is not None: 43 | target_spectrogram = target_spectrogram[:max_len] 44 | pred_spectrogram = pred_spectrogram[:max_len] 45 | 46 | if split_title: 47 | title = split_title_line(title) 48 | 49 | fig = plt.figure(figsize=(10, 8)) 50 | # Set common labels 51 | fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16) 52 | 53 | #target spectrogram subplot 54 | if target_spectrogram is not None: 55 | ax1 = fig.add_subplot(311) 56 | ax2 = fig.add_subplot(312) 57 | 58 | if auto_aspect: 59 | im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none") 60 | else: 61 | im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none") 62 | ax1.set_title("Target Mel-Spectrogram") 63 | fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1) 64 | ax2.set_title("Predicted Mel-Spectrogram") 65 | else: 66 | ax2 = fig.add_subplot(211) 67 | 68 | if auto_aspect: 69 | im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none") 70 | else: 71 | im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none") 72 | fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2) 73 | 74 | plt.tight_layout() 75 | plt.savefig(path, format="png") 76 | plt.close() 77 | -------------------------------------------------------------------------------- /docs/_layouts/default.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | {% if site.google_analytics %} 6 | 7 | 13 | {% endif %} 14 | 15 | 16 | {% seo %} 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 29 | 30 | 31 | 32 | 33 | 43 | 44 |
45 | {{ content }} 46 | 47 | 53 |
54 | 55 | 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | FastVC 3 |

4 | 5 | ## Overview 6 | 7 | FastVC is a fast and efficient, non-parallel, any-to-any *voice conversion (VC)* tool. VC modifies the voice of a source speaker so that it sounds like that of a target speaker, without changing the linguistic content of the sentence. Our tool tackles the task by cascading an Automatic Speech Recognition (ASR) model and a Text-To-Speech (TTS) model. 8 | 9 |

10 | 11 |

12 | 13 | The ASR is based on [Wav2vec 2.0](https://arxiv.org/pdf/2006.11477.pdf) and is used to transcribe the speech from a source speaker. The TTS is based on [SV2TTS](https://arxiv.org/pdf/1806.04558.pdf) and is used to generate the output speech from a target speaker embedding. 14 | 15 | For a more detailed explanation check out [the paper of our project](https://github.com/fmiotello/fastVC/files/6849958/L11_Report.pdf). A demo page is available [here](https://fmiotello.github.io/fastVC). 16 | 17 | ## Installation & usage 18 | 19 | The software was implemented using `python 3.9.4` 20 | 21 | 1. Clone the repository (`git clone https://github.com/fmiotello/fastVC.git`) and enter the directory (`cd fastVC`) 22 | 2. (*optional*) Create virtual env and activate it: `python -m venv env` and `source env/bin/activate` (if using macOS/Linux) or `.\env\Scripts\activate` (if using Windows) 23 | 3. Upgrade pip: `python -m pip install --upgrade pip` 24 | 4. Install dependencies: `python -m pip install -r requirements.txt` 25 | 5. Download the pretrained models ([encoder](https://drive.google.com/file/d/1q8mEGwCkFy23KZsinbuvdKAQLqNKbYf1/view?usp=sharing), [synthesizer](https://drive.google.com/file/d/1EqFMIbvxffxtjiVrtykroF6_mUh-5Z3s/view?usp=sharing), [vocoder](https://drive.google.com/file/d/1cf2NO6FtI0jDuy8AV3Xgn6leO6dHjIgu/view?usp=sharing)) and put them in the correct directories: 26 | ```` 27 | ./src/encoder/saved_models/pretrained.pt 28 | ./src/synthesizer/saved_models/pretrained/pretrained.pt 29 | ./src/vocoder/saved_models/pretrained/pretrained.pt 30 | ```` 31 | 6. Run the main script: `python src/main.py` (use `--help` for displaying available options). The output audio will be `./src/audio/audio_out.wav`. 32 | 33 | More instructions can be found [here](https://github.com/fmiotello/fastVC/tree/main/src). 34 | 35 | ## Notes 36 | 37 | This application was developed as a project at [Politecnico di Milano](https://www.polimi.it/en/) (MSc in Music and Acoustic Engineering). 38 | 39 | *[Luigi Attorresi](https://github.com/LuigiAttorresi)*
40 | *[Federico Miotello](https://github.com/fmiotello)*
41 | *[Eugenio Poliuti](https://github.com/Poliuti)*
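For quick reference, a minimal end-to-end run from the repository root might look like the sketch below. It assumes the three pretrained models from step 5 are already in place and uses the default audio files shipped in `src/audio`; the flags are the ones documented in `src/README.md`.

````
# Convert the default source utterance (src/audio/source.wav) into the default
# target voice (src/audio/target.wav) and print the relevant metrics.
python src/main.py --metrics
````

The converted speech is written to `./src/audio/audio_out.wav`, as noted in step 6 above.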
42 | -------------------------------------------------------------------------------- /src/vocoder/audio.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import librosa 4 | import vocoder.hparams as hp 5 | from scipy.signal import lfilter 6 | import soundfile as sf 7 | 8 | 9 | def label_2_float(x, bits) : 10 | return 2 * x / (2**bits - 1.) - 1. 11 | 12 | 13 | def float_2_label(x, bits) : 14 | assert abs(x).max() <= 1.0 15 | x = (x + 1.) * (2**bits - 1) / 2 16 | return x.clip(0, 2**bits - 1) 17 | 18 | 19 | def load_wav(path) : 20 | return librosa.load(str(path), sr=hp.sample_rate)[0] 21 | 22 | 23 | def save_wav(x, path) : 24 | sf.write(path, x.astype(np.float32), hp.sample_rate) 25 | 26 | 27 | def split_signal(x) : 28 | unsigned = x + 2**15 29 | coarse = unsigned // 256 30 | fine = unsigned % 256 31 | return coarse, fine 32 | 33 | 34 | def combine_signal(coarse, fine) : 35 | return coarse * 256 + fine - 2**15 36 | 37 | 38 | def encode_16bits(x) : 39 | return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16) 40 | 41 | 42 | mel_basis = None 43 | 44 | 45 | def linear_to_mel(spectrogram): 46 | global mel_basis 47 | if mel_basis is None: 48 | mel_basis = build_mel_basis() 49 | return np.dot(mel_basis, spectrogram) 50 | 51 | 52 | def build_mel_basis(): 53 | return librosa.filters.mel(hp.sample_rate, hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin) 54 | 55 | 56 | def normalize(S): 57 | return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1) 58 | 59 | 60 | def denormalize(S): 61 | return (np.clip(S, 0, 1) * -hp.min_level_db) + hp.min_level_db 62 | 63 | 64 | def amp_to_db(x): 65 | return 20 * np.log10(np.maximum(1e-5, x)) 66 | 67 | 68 | def db_to_amp(x): 69 | return np.power(10.0, x * 0.05) 70 | 71 | 72 | def spectrogram(y): 73 | D = stft(y) 74 | S = amp_to_db(np.abs(D)) - hp.ref_level_db 75 | return normalize(S) 76 | 77 | 78 | def melspectrogram(y): 79 | D = stft(y) 80 | S = amp_to_db(linear_to_mel(np.abs(D))) 81 | return normalize(S) 82 | 83 | 84 | def stft(y): 85 | return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length) 86 | 87 | 88 | def pre_emphasis(x): 89 | return lfilter([1, -hp.preemphasis], [1], x) 90 | 91 | 92 | def de_emphasis(x): 93 | return lfilter([1], [1, -hp.preemphasis], x) 94 | 95 | 96 | def encode_mu_law(x, mu) : 97 | mu = mu - 1 98 | fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu) 99 | return np.floor((fx + 1) / 2 * mu + 0.5) 100 | 101 | 102 | def decode_mu_law(y, mu, from_labels=True) : 103 | if from_labels: 104 | y = label_2_float(y, math.log2(mu)) 105 | mu = mu - 1 106 | x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1) 107 | return x 108 | -------------------------------------------------------------------------------- /src/synthesizer/utils/cleaners.py: -------------------------------------------------------------------------------- 1 | """ 2 | Cleaners are transformations that run over the input text at both training and eval time. 3 | 4 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 5 | hyperparameter. Some cleaners are English-specific. You"ll typically want to use: 6 | 1. "english_cleaners" for English text 7 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 8 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 9 | 3. 
"basic_cleaners" if you do not want to transliterate (in this case, you should also update 10 | the symbols in symbols.py to match your data). 11 | """ 12 | 13 | import re 14 | from unidecode import unidecode 15 | from .numbers import normalize_numbers 16 | 17 | # Regular expression matching whitespace: 18 | _whitespace_re = re.compile(r"\s+") 19 | 20 | # List of (regular expression, replacement) pairs for abbreviations: 21 | _abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [ 22 | ("mrs", "misess"), 23 | ("mr", "mister"), 24 | ("dr", "doctor"), 25 | ("st", "saint"), 26 | ("co", "company"), 27 | ("jr", "junior"), 28 | ("maj", "major"), 29 | ("gen", "general"), 30 | ("drs", "doctors"), 31 | ("rev", "reverend"), 32 | ("lt", "lieutenant"), 33 | ("hon", "honorable"), 34 | ("sgt", "sergeant"), 35 | ("capt", "captain"), 36 | ("esq", "esquire"), 37 | ("ltd", "limited"), 38 | ("col", "colonel"), 39 | ("ft", "fort"), 40 | ]] 41 | 42 | 43 | def expand_abbreviations(text): 44 | for regex, replacement in _abbreviations: 45 | text = re.sub(regex, replacement, text) 46 | return text 47 | 48 | 49 | def expand_numbers(text): 50 | return normalize_numbers(text) 51 | 52 | 53 | def lowercase(text): 54 | """lowercase input tokens.""" 55 | return text.lower() 56 | 57 | 58 | def collapse_whitespace(text): 59 | return re.sub(_whitespace_re, " ", text) 60 | 61 | 62 | def convert_to_ascii(text): 63 | return unidecode(text) 64 | 65 | 66 | def basic_cleaners(text): 67 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 68 | text = lowercase(text) 69 | text = collapse_whitespace(text) 70 | return text 71 | 72 | 73 | def transliteration_cleaners(text): 74 | """Pipeline for non-English text that transliterates to ASCII.""" 75 | text = convert_to_ascii(text) 76 | text = lowercase(text) 77 | text = collapse_whitespace(text) 78 | return text 79 | 80 | 81 | def english_cleaners(text): 82 | """Pipeline for English text, including number and abbreviation expansion.""" 83 | text = convert_to_ascii(text) 84 | text = lowercase(text) 85 | text = expand_numbers(text) 86 | text = expand_abbreviations(text) 87 | text = collapse_whitespace(text) 88 | return text 89 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.12.0 2 | appdirs==1.4.4 3 | asrtoolkit @ https://github.com/finos/greenkey-asrtoolkit/archive/refs/tags/0.2.4.tar.gz 4 | astunparse==1.6.3 5 | attrs==21.2.0 6 | audioread==2.1.9 7 | beautifulsoup4==4.9.3 8 | cachetools==4.2.2 9 | certifi==2020.12.5 10 | cffi==1.14.5 11 | chardet==4.0.0 12 | click==7.1.2 13 | cycler==0.10.0 14 | datasets==1.6.2 15 | decorator==5.0.7 16 | dill==0.3.3 17 | docopt==0.6.2 18 | editdistance==0.5.3 19 | ffmpeg==1.4 20 | ffmpeg-python==0.2.0 21 | filelock==3.0.12 22 | fire==0.4.0 23 | flatbuffers==1.12 24 | fsspec==2021.4.0 25 | future==0.18.2 26 | gast==0.4.0 27 | google-auth==1.30.1 28 | google-auth-oauthlib==0.4.4 29 | google-pasta==0.2.0 30 | grpcio==1.34.1 31 | h5py==3.1.0 32 | huggingface-hub==0.0.8 33 | idna==2.10 34 | inflect==5.3.0 35 | iniconfig==1.1.1 36 | jiwer==2.2.0 37 | joblib==1.0.1 38 | jsonpatch==1.32 39 | jsonpointer==2.1 40 | jsonschema==3.2.0 41 | keras-nightly==2.5.0.dev2021032900 42 | Keras-Preprocessing==1.1.2 43 | kiwisolver==1.3.1 44 | librosa==0.8.0 45 | llvmlite==0.36.0 46 | Markdown==3.3.4 47 | matplotlib==3.4.1 48 | mock==4.0.3 49 | multiprocess==0.70.11.1 
50 | musdb==0.4.0 51 | museval==0.4.0 52 | nltk==3.6.2 53 | nose==1.3.7 54 | num2words==0.5.10 55 | numba==0.53.1 56 | numpy==1.19.3; platform_system == "Windows" 57 | numpy==1.19.4; platform_system != "Windows" 58 | oauthlib==3.1.0 59 | opt-einsum==3.3.0 60 | packaging==20.9 61 | pandas==1.2.4 62 | pesq==0.0.3 63 | Pillow==8.2.0 64 | pluggy==0.13.1 65 | pooch==1.3.0 66 | protobuf==3.17.1 67 | py==1.10.0 68 | pyaml==20.4.0 69 | pyarrow==4.0.0 70 | pyasn1==0.4.8 71 | pyasn1-modules==0.2.8 72 | pycparser==2.20 73 | pynndescent==0.5.2 74 | pyparsing==2.4.7 75 | pypesq==1.2.4 76 | PyQt5==5.15.4 77 | PyQt5-Qt5==5.15.2 78 | PyQt5-sip==12.8.1 79 | pyrsistent==0.17.3 80 | pystoi==0.3.3 81 | pytest==6.2.4 82 | python-dateutil==2.8.1 83 | python-Levenshtein==0.12.2 84 | pytz==2021.1 85 | PyYAML==5.4.1 86 | pyzmq==22.0.3 87 | regex==2021.4.4 88 | requests==2.25.1 89 | requests-oauthlib==1.3.0 90 | resampy==0.2.2 91 | rsa==4.7.2 92 | sacremoses==0.0.45 93 | scikit-learn==0.24.2 94 | scipy==1.6.3 95 | simplejson==3.17.2 96 | six==1.15.0 97 | sounddevice==0.4.1 98 | SoundFile==0.10.3.post1 99 | soupsieve==2.2.1 100 | git+https://github.com/aliutkus/speechmetrics#egg=speechmetrics[cpu] 101 | stempeg==0.2.3 102 | tensorboard==2.5.0 103 | tensorboard-data-server==0.6.1 104 | tensorboard-plugin-wit==1.8.0 105 | tensorflow==2.5.0 106 | tensorflow-estimator==2.5.0 107 | termcolor==1.1.0 108 | threadpoolctl==2.1.0 109 | tokenizers==0.10.2 110 | toml==0.10.2 111 | torch==1.8.1 112 | torchaudio==0.8.1 113 | torchfile==0.1.0 114 | torchvision==0.9.1 115 | tornado==6.1 116 | tqdm==4.49.0 117 | transformers==4.5.1 118 | typing-extensions==3.7.4.3 119 | umap-learn==0.5.1 120 | Unidecode==1.2.0 121 | urllib3==1.26.5 122 | visdom==0.1.8.9 123 | webrtcvad==2.0.10; platform_system != "Windows" 124 | websocket-client==0.59.0 125 | webvtt-py==0.4.6 126 | Werkzeug==2.0.1 127 | wrapt==1.12.1 128 | xlrd==1.2.0 129 | xxhash==2.0.2 130 | -------------------------------------------------------------------------------- /src/vocoder/display.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import time 3 | import numpy as np 4 | import sys 5 | 6 | 7 | def progbar(i, n, size=16): 8 | done = (i * size) // n 9 | bar = '' 10 | for i in range(size): 11 | bar += '█' if i <= done else '░' 12 | return bar 13 | 14 | 15 | def stream(message) : 16 | try: 17 | sys.stdout.write("\r{%s}" % message) 18 | except: 19 | #Remove non-ASCII characters from message 20 | message = ''.join(i for i in message if ord(i)<128) 21 | sys.stdout.write("\r{%s}" % message) 22 | 23 | 24 | def simple_table(item_tuples) : 25 | 26 | border_pattern = '+---------------------------------------' 27 | whitespace = ' ' 28 | 29 | headings, cells, = [], [] 30 | 31 | for item in item_tuples : 32 | 33 | heading, cell = str(item[0]), str(item[1]) 34 | 35 | pad_head = True if len(heading) < len(cell) else False 36 | 37 | pad = abs(len(heading) - len(cell)) 38 | pad = whitespace[:pad] 39 | 40 | pad_left = pad[:len(pad)//2] 41 | pad_right = pad[len(pad)//2:] 42 | 43 | if pad_head : 44 | heading = pad_left + heading + pad_right 45 | else : 46 | cell = pad_left + cell + pad_right 47 | 48 | headings += [heading] 49 | cells += [cell] 50 | 51 | border, head, body = '', '', '' 52 | 53 | for i in range(len(item_tuples)) : 54 | 55 | temp_head = f'| {headings[i]} ' 56 | temp_body = f'| {cells[i]} ' 57 | 58 | border += border_pattern[:len(temp_head)] 59 | head += temp_head 60 | body += temp_body 61 | 62 | if i == 
len(item_tuples) - 1 : 63 | head += '|' 64 | body += '|' 65 | border += '+' 66 | 67 | print(border) 68 | print(head) 69 | print(border) 70 | print(body) 71 | print(border) 72 | print(' ') 73 | 74 | 75 | def time_since(started) : 76 | elapsed = time.time() - started 77 | m = int(elapsed // 60) 78 | s = int(elapsed % 60) 79 | if m >= 60 : 80 | h = int(m // 60) 81 | m = m % 60 82 | return f'{h}h {m}m {s}s' 83 | else : 84 | return f'{m}m {s}s' 85 | 86 | 87 | def save_attention(attn, path) : 88 | fig = plt.figure(figsize=(12, 6)) 89 | plt.imshow(attn.T, interpolation='nearest', aspect='auto') 90 | fig.savefig(f'{path}.png', bbox_inches='tight') 91 | plt.close(fig) 92 | 93 | 94 | def save_spectrogram(M, path, length=None) : 95 | M = np.flip(M, axis=0) 96 | if length : M = M[:, :length] 97 | fig = plt.figure(figsize=(12, 6)) 98 | plt.imshow(M, interpolation='nearest', aspect='auto') 99 | fig.savefig(f'{path}.png', bbox_inches='tight') 100 | plt.close(fig) 101 | 102 | 103 | def plot(array) : 104 | fig = plt.figure(figsize=(30, 5)) 105 | ax = fig.add_subplot(111) 106 | ax.xaxis.label.set_color('grey') 107 | ax.yaxis.label.set_color('grey') 108 | ax.xaxis.label.set_fontsize(23) 109 | ax.yaxis.label.set_fontsize(23) 110 | ax.tick_params(axis='x', colors='grey', labelsize=23) 111 | ax.tick_params(axis='y', colors='grey', labelsize=23) 112 | plt.plot(array) 113 | 114 | 115 | def plot_spec(M) : 116 | M = np.flip(M, axis=0) 117 | plt.figure(figsize=(18,4)) 118 | plt.imshow(M, interpolation='nearest', aspect='auto') 119 | plt.show() 120 | 121 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | layout: default 3 | title: Fast VC 4 | --- 5 | 6 | ## Abstract 7 | 8 | Current state-of-the-art *voice conversion (VC)* tools rely on neural models trained on massive corpora of data for hundreds of hours. This approach surely leads to astonishing results, but lacks speed, simplicity and accessibility. In this paper we introduce a simple and fast *any-to-any non-parallel* voice conversion tool that is able to perform its task provided only with a small audio excerpt of the target speaker. We consider a modular approach to VC, cascading an *automatic-speech-recognition (ASR)* model, used to transcribe the source speech, and a *text-to-speech (TTS)* model, used to generate the target speech. This approach presents a straightforward pipeline, allows the use of already available models and opens the door to many extensions. We show that our outputs are intelligible and distinguishable across different speakers. 9 | 10 | ## Audio examples 11 | 12 | ### Class 1 13 | 14 |
[Embedded audio players: Source speaker (class_1_source.wav), Target speaker (class_1_target.wav), Output (class_1_output.wav)]

### Class 2

[Embedded audio players: Source speaker (class_2_source.wav), Target speaker (class_2_target.wav), Output (class_2_output.wav)]

### Class 3

[Embedded audio players: Source speaker (class_3_source.wav), Target speaker (class_3_target.wav), Output (class_3_output.wav)]

### Class 4

[Embedded audio players: Source speaker (class_4_source.wav), Target speaker (class_4_target.wav), Output (class_4_output.wav)]

### Class 5

[Embedded audio players: Source speaker (class_5_source.wav), Target speaker (class_5_target.wav), Output (class_5_output.wav)]
95 |
96 | -------------------------------------------------------------------------------- /src/synthesizer/hparams.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import pprint 3 | 4 | class HParams(object): 5 | def __init__(self, **kwargs): self.__dict__.update(kwargs) 6 | def __setitem__(self, key, value): setattr(self, key, value) 7 | def __getitem__(self, key): return getattr(self, key) 8 | def __repr__(self): return pprint.pformat(self.__dict__) 9 | 10 | def parse(self, string): 11 | # Overrides hparams from a comma-separated string of name=value pairs 12 | if len(string) > 0: 13 | overrides = [s.split("=") for s in string.split(",")] 14 | keys, values = zip(*overrides) 15 | keys = list(map(str.strip, keys)) 16 | values = list(map(str.strip, values)) 17 | for k in keys: 18 | self.__dict__[k] = ast.literal_eval(values[keys.index(k)]) 19 | return self 20 | 21 | hparams = HParams( 22 | ### Signal Processing (used in both synthesizer and vocoder) 23 | sample_rate = 16000, 24 | n_fft = 800, 25 | num_mels = 80, 26 | hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125) 27 | win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050) 28 | fmin = 55, 29 | min_level_db = -100, 30 | ref_level_db = 20, 31 | max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small. 32 | preemphasis = 0.97, # Filter coefficient to use if preemphasize is True 33 | preemphasize = True, 34 | 35 | ### Tacotron Text-to-Speech (TTS) 36 | tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs 37 | tts_encoder_dims = 256, 38 | tts_decoder_dims = 128, 39 | tts_postnet_dims = 512, 40 | tts_encoder_K = 5, 41 | tts_lstm_dims = 1024, 42 | tts_postnet_K = 5, 43 | tts_num_highways = 4, 44 | tts_dropout = 0.5, 45 | tts_cleaner_names = ["english_cleaners"], 46 | tts_stop_threshold = -3.4, # Value below which audio generation ends. 47 | # For example, for a range of [-4, 4], this 48 | # will terminate the sequence at the first 49 | # frame that has all values < -3.4 50 | 51 | ### Tacotron Training 52 | tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule 53 | (2, 5e-4, 40_000, 12), # (r, lr, step, batch_size) 54 | (2, 2e-4, 80_000, 12), # 55 | (2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames 56 | (2, 3e-5, 320_000, 12), # synthesized for each decoder iteration) 57 | (2, 1e-5, 640_000, 12)], # lr = learning rate 58 | 59 | tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed 60 | tts_eval_interval = 500, # Number of steps between model evaluation (sample generation) 61 | # Set to -1 to generate after completing epoch, or 0 to disable 62 | 63 | tts_eval_num_samples = 1, # Makes this number of samples 64 | 65 | ### Data Preprocessing 66 | max_mel_frames = 900, 67 | rescale = True, 68 | rescaling_max = 0.9, 69 | synthesis_batch_size = 16, # For vocoder preprocessing and inference. 
70 | 71 | ### Mel Visualization and Griffin-Lim 72 | signal_normalization = True, 73 | power = 1.5, 74 | griffin_lim_iters = 60, 75 | 76 | ### Audio processing options 77 | fmax = 7600, # Should not exceed (sample_rate // 2) 78 | allow_clipping_in_normalization = True, # Used when signal_normalization = True 79 | clip_mels_length = True, # If true, discards samples exceeding max_mel_frames 80 | use_lws = False, # "Fast spectrogram phase recovery using local weighted sums" 81 | symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True, 82 | # and [0, max_abs_value] if False 83 | trim_silence = True, # Use with sample_rate of 16000 for best results 84 | 85 | ### SV2TTS 86 | speaker_embedding_size = 256, # Dimension for the speaker embedding 87 | silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split 88 | utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded 89 | ) 90 | 91 | def hparams_debug_string(): 92 | return str(hparams) 93 | -------------------------------------------------------------------------------- /src/encoder/audio.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage.morphology import binary_dilation 2 | from encoder.params_data import * 3 | from pathlib import Path 4 | from typing import Optional, Union 5 | from warnings import warn 6 | import numpy as np 7 | import librosa 8 | import struct 9 | 10 | try: 11 | import webrtcvad 12 | except: 13 | warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.") 14 | webrtcvad=None 15 | 16 | int16_max = (2 ** 15) - 1 17 | 18 | 19 | def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], 20 | source_sr: Optional[int] = None, 21 | normalize: Optional[bool] = True, 22 | trim_silence: Optional[bool] = True): 23 | """ 24 | Applies the preprocessing operations used in training the Speaker Encoder to a waveform 25 | either on disk or in memory. The waveform will be resampled to match the data hyperparameters. 26 | 27 | :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not 28 | just .wav), either the waveform as a numpy array of floats. 29 | :param source_sr: if passing an audio waveform, the sampling rate of the waveform before 30 | preprocessing. After preprocessing, the waveform's sampling rate will match the data 31 | hyperparameters. If passing a filepath, the sampling rate will be automatically detected and 32 | this argument will be ignored. 33 | """ 34 | # Load the wav from disk if needed 35 | if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): 36 | wav, source_sr = librosa.load(str(fpath_or_wav), sr=None) 37 | else: 38 | wav = fpath_or_wav 39 | 40 | # Resample the wav if needed 41 | if source_sr is not None and source_sr != sampling_rate: 42 | wav = librosa.resample(wav, source_sr, sampling_rate) 43 | 44 | # Apply the preprocessing: normalize volume and shorten long silences 45 | if normalize: 46 | wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) 47 | if webrtcvad and trim_silence: 48 | wav = trim_long_silences(wav) 49 | 50 | return wav 51 | 52 | 53 | def wav_to_mel_spectrogram(wav): 54 | """ 55 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. 56 | Note: this not a log-mel spectrogram. 
57 | """ 58 | frames = librosa.feature.melspectrogram( 59 | wav, 60 | sampling_rate, 61 | n_fft=int(sampling_rate * mel_window_length / 1000), 62 | hop_length=int(sampling_rate * mel_window_step / 1000), 63 | n_mels=mel_n_channels 64 | ) 65 | return frames.astype(np.float32).T 66 | 67 | 68 | def trim_long_silences(wav): 69 | """ 70 | Ensures that segments without voice in the waveform remain no longer than a 71 | threshold determined by the VAD parameters in params.py. 72 | 73 | :param wav: the raw waveform as a numpy array of floats 74 | :return: the same waveform with silences trimmed away (length <= original wav length) 75 | """ 76 | # Compute the voice detection window size 77 | samples_per_window = (vad_window_length * sampling_rate) // 1000 78 | 79 | # Trim the end of the audio to have a multiple of the window size 80 | wav = wav[:len(wav) - (len(wav) % samples_per_window)] 81 | 82 | # Convert the float waveform to 16-bit mono PCM 83 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) 84 | 85 | # Perform voice activation detection 86 | voice_flags = [] 87 | vad = webrtcvad.Vad(mode=3) 88 | for window_start in range(0, len(wav), samples_per_window): 89 | window_end = window_start + samples_per_window 90 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], 91 | sample_rate=sampling_rate)) 92 | voice_flags = np.array(voice_flags) 93 | 94 | # Smooth the voice detection with a moving average 95 | def moving_average(array, width): 96 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) 97 | ret = np.cumsum(array_padded, dtype=float) 98 | ret[width:] = ret[width:] - ret[:-width] 99 | return ret[width - 1:] / width 100 | 101 | audio_mask = moving_average(voice_flags, vad_moving_average_width) 102 | audio_mask = np.round(audio_mask).astype(np.bool) 103 | 104 | # Dilate the voiced regions 105 | audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) 106 | audio_mask = np.repeat(audio_mask, samples_per_window) 107 | 108 | return wav[audio_mask == True] 109 | 110 | 111 | def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): 112 | if increase_only and decrease_only: 113 | raise ValueError("Both increase only and decrease only are set") 114 | dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2)) 115 | if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): 116 | return wav 117 | return wav * (10 ** (dBFS_change / 20)) 118 | -------------------------------------------------------------------------------- /src/vocoder/distribution.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | def log_sum_exp(x): 7 | """ numerically stable log_sum_exp implementation that prevents overflow """ 8 | # TF ordering 9 | axis = len(x.size()) - 1 10 | m, _ = torch.max(x, dim=axis) 11 | m2, _ = torch.max(x, dim=axis, keepdim=True) 12 | return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) 13 | 14 | 15 | # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py 16 | def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, 17 | log_scale_min=None, reduce=True): 18 | if log_scale_min is None: 19 | log_scale_min = float(np.log(1e-14)) 20 | y_hat = y_hat.permute(0,2,1) 21 | assert y_hat.dim() == 3 22 | assert y_hat.size(1) % 3 == 0 23 | nr_mix = y_hat.size(1) // 3 24 | 25 | 
# (B x T x C) 26 | y_hat = y_hat.transpose(1, 2) 27 | 28 | # unpack parameters. (B, T, num_mixtures) x 3 29 | logit_probs = y_hat[:, :, :nr_mix] 30 | means = y_hat[:, :, nr_mix:2 * nr_mix] 31 | log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min) 32 | 33 | # B x T x 1 -> B x T x num_mixtures 34 | y = y.expand_as(means) 35 | 36 | centered_y = y - means 37 | inv_stdv = torch.exp(-log_scales) 38 | plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1)) 39 | cdf_plus = torch.sigmoid(plus_in) 40 | min_in = inv_stdv * (centered_y - 1. / (num_classes - 1)) 41 | cdf_min = torch.sigmoid(min_in) 42 | 43 | # log probability for edge case of 0 (before scaling) 44 | # equivalent: torch.log(F.sigmoid(plus_in)) 45 | log_cdf_plus = plus_in - F.softplus(plus_in) 46 | 47 | # log probability for edge case of 255 (before scaling) 48 | # equivalent: (1 - F.sigmoid(min_in)).log() 49 | log_one_minus_cdf_min = -F.softplus(min_in) 50 | 51 | # probability for all other cases 52 | cdf_delta = cdf_plus - cdf_min 53 | 54 | mid_in = inv_stdv * centered_y 55 | # log probability in the center of the bin, to be used in extreme cases 56 | # (not actually used in our code) 57 | log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in) 58 | 59 | # tf equivalent 60 | """ 61 | log_probs = tf.where(x < -0.999, log_cdf_plus, 62 | tf.where(x > 0.999, log_one_minus_cdf_min, 63 | tf.where(cdf_delta > 1e-5, 64 | tf.log(tf.maximum(cdf_delta, 1e-12)), 65 | log_pdf_mid - np.log(127.5)))) 66 | """ 67 | # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value 68 | # for num_classes=65536 case? 1e-7? not sure.. 69 | inner_inner_cond = (cdf_delta > 1e-5).float() 70 | 71 | inner_inner_out = inner_inner_cond * \ 72 | torch.log(torch.clamp(cdf_delta, min=1e-12)) + \ 73 | (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) 74 | inner_cond = (y > 0.999).float() 75 | inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out 76 | cond = (y < -0.999).float() 77 | log_probs = cond * log_cdf_plus + (1. - cond) * inner_out 78 | 79 | log_probs = log_probs + F.log_softmax(logit_probs, -1) 80 | 81 | if reduce: 82 | return -torch.mean(log_sum_exp(log_probs)) 83 | else: 84 | return -log_sum_exp(log_probs).unsqueeze(-1) 85 | 86 | 87 | def sample_from_discretized_mix_logistic(y, log_scale_min=None): 88 | """ 89 | Sample from discretized mixture of logistic distributions 90 | Args: 91 | y (Tensor): B x C x T 92 | log_scale_min (float): Log scale minimum value 93 | Returns: 94 | Tensor: sample in range of [-1, 1]. 
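# Minimal shape sketch for the two entry points in this file, using random data:
# discretized_mix_logistic_loss() takes y_hat as (B, T, 3 * nr_mix) with logits,
# means and log-scales packed on the last axis, and targets y as (B, T, 1) in [-1, 1];
# sample_from_discretized_mix_logistic() takes (B, 3 * nr_mix, T) instead, which is
# how WaveRNN.generate() calls it. B, T and nr_mix below are illustrative values.
import torch
from vocoder.distribution import (discretized_mix_logistic_loss,
                                  sample_from_discretized_mix_logistic)

B, T, nr_mix = 2, 100, 10
y_hat = torch.randn(B, T, 3 * nr_mix)
y = torch.rand(B, T, 1) * 2 - 1
loss = discretized_mix_logistic_loss(y_hat, y)                   # scalar (reduce=True)
x = sample_from_discretized_mix_logistic(y_hat.transpose(1, 2))  # (B, T), values in [-1, 1]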
95 | """ 96 | if log_scale_min is None: 97 | log_scale_min = float(np.log(1e-14)) 98 | assert y.size(1) % 3 == 0 99 | nr_mix = y.size(1) // 3 100 | 101 | # B x T x C 102 | y = y.transpose(1, 2) 103 | logit_probs = y[:, :, :nr_mix] 104 | 105 | # sample mixture indicator from softmax 106 | temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) 107 | temp = logit_probs.data - torch.log(- torch.log(temp)) 108 | _, argmax = temp.max(dim=-1) 109 | 110 | # (B, T) -> (B, T, nr_mix) 111 | one_hot = to_one_hot(argmax, nr_mix) 112 | # select logistic parameters 113 | means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1) 114 | log_scales = torch.clamp(torch.sum( 115 | y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min) 116 | # sample from logistic & clip to interval 117 | # we don't actually round to the nearest 8bit value when sampling 118 | u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) 119 | x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u)) 120 | 121 | x = torch.clamp(torch.clamp(x, min=-1.), max=1.) 122 | 123 | return x 124 | 125 | 126 | def to_one_hot(tensor, n, fill_with=1.): 127 | # we perform one hot encore with respect to the last axis 128 | one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_() 129 | if tensor.is_cuda: 130 | one_hot = one_hot.cuda() 131 | one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) 132 | return one_hot 133 | -------------------------------------------------------------------------------- /src/encoder/model.py: -------------------------------------------------------------------------------- 1 | from encoder.params_model import * 2 | from encoder.params_data import * 3 | from scipy.interpolate import interp1d 4 | from sklearn.metrics import roc_curve 5 | from torch.nn.utils import clip_grad_norm_ 6 | from scipy.optimize import brentq 7 | from torch import nn 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class SpeakerEncoder(nn.Module): 13 | def __init__(self, device, loss_device): 14 | super().__init__() 15 | self.loss_device = loss_device 16 | 17 | # Network defition 18 | self.lstm = nn.LSTM(input_size=mel_n_channels, 19 | hidden_size=model_hidden_size, 20 | num_layers=model_num_layers, 21 | batch_first=True).to(device) 22 | self.linear = nn.Linear(in_features=model_hidden_size, 23 | out_features=model_embedding_size).to(device) 24 | self.relu = torch.nn.ReLU().to(device) 25 | 26 | # Cosine similarity scaling (with fixed initial parameter values) 27 | self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device) 28 | self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device) 29 | 30 | # Loss 31 | self.loss_fn = nn.CrossEntropyLoss().to(loss_device) 32 | 33 | def do_gradient_ops(self): 34 | # Gradient scale 35 | self.similarity_weight.grad *= 0.01 36 | self.similarity_bias.grad *= 0.01 37 | 38 | # Gradient clipping 39 | clip_grad_norm_(self.parameters(), 3, norm_type=2) 40 | 41 | def forward(self, utterances, hidden_init=None): 42 | """ 43 | Computes the embeddings of a batch of utterance spectrograms. 44 | 45 | :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape 46 | (batch_size, n_frames, n_channels) 47 | :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, 48 | batch_size, hidden_size). Will default to a tensor of zeros if None. 
49 | :return: the embeddings as a tensor of shape (batch_size, embedding_size) 50 | """ 51 | # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state 52 | # and the final cell state. 53 | out, (hidden, cell) = self.lstm(utterances, hidden_init) 54 | 55 | # We take only the hidden state of the last layer 56 | embeds_raw = self.relu(self.linear(hidden[-1])) 57 | 58 | # L2-normalize it 59 | embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5) 60 | 61 | return embeds 62 | 63 | def similarity_matrix(self, embeds): 64 | """ 65 | Computes the similarity matrix according the section 2.1 of GE2E. 66 | 67 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 68 | utterances_per_speaker, embedding_size) 69 | :return: the similarity matrix as a tensor of shape (speakers_per_batch, 70 | utterances_per_speaker, speakers_per_batch) 71 | """ 72 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2] 73 | 74 | # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation 75 | centroids_incl = torch.mean(embeds, dim=1, keepdim=True) 76 | centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5) 77 | 78 | # Exclusive centroids (1 per utterance) 79 | centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds) 80 | centroids_excl /= (utterances_per_speaker - 1) 81 | centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5) 82 | 83 | # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot 84 | # product of these vectors (which is just an element-wise multiplication reduced by a sum). 85 | # We vectorize the computation for efficiency. 86 | sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker, 87 | speakers_per_batch).to(self.loss_device) 88 | mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int) 89 | for j in range(speakers_per_batch): 90 | mask = np.where(mask_matrix[j])[0] 91 | sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2) 92 | sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1) 93 | 94 | ## Even more vectorized version (slower maybe because of transpose) 95 | # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker 96 | # ).to(self.loss_device) 97 | # eye = np.eye(speakers_per_batch, dtype=np.int) 98 | # mask = np.where(1 - eye) 99 | # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2) 100 | # mask = np.where(eye) 101 | # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2) 102 | # sim_matrix2 = sim_matrix2.transpose(1, 2) 103 | 104 | sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias 105 | return sim_matrix 106 | 107 | def loss(self, embeds): 108 | """ 109 | Computes the softmax loss according the section 2.1 of GE2E. 110 | 111 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 112 | utterances_per_speaker, embedding_size) 113 | :return: the loss and the EER for this batch of embeddings. 
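# Minimal sketch of a GE2E forward pass on random data (CPU), assuming
# mel_n_channels from params_data.py and the dependency versions pinned in
# requirements.txt; 4 speakers x 5 utterances of 160 frames are illustrative values.
import torch
from encoder.params_data import mel_n_channels
from encoder.model import SpeakerEncoder

device = torch.device("cpu")
model = SpeakerEncoder(device, loss_device=device)
utterances = torch.rand(4 * 5, 160, mel_n_channels)   # (batch, n_frames, n_channels)
embeds = model(utterances).view(4, 5, -1)             # (speakers, utterances, emb_size)
loss, eer = model.loss(embeds)                        # GE2E softmax loss + batch EER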
114 | """ 115 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2] 116 | 117 | # Loss 118 | sim_matrix = self.similarity_matrix(embeds) 119 | sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker, 120 | speakers_per_batch)) 121 | ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker) 122 | target = torch.from_numpy(ground_truth).long().to(self.loss_device) 123 | loss = self.loss_fn(sim_matrix, target) 124 | 125 | # EER (not backpropagated) 126 | with torch.no_grad(): 127 | inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0] 128 | labels = np.array([inv_argmax(i) for i in ground_truth]) 129 | preds = sim_matrix.detach().cpu().numpy() 130 | 131 | # Snippet from https://yangcha.github.io/EER-ROC/ 132 | fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten()) 133 | eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.) 134 | 135 | return loss, eer 136 | -------------------------------------------------------------------------------- /src/synthesizer/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from synthesizer import audio 3 | from synthesizer.hparams import hparams 4 | from synthesizer.models.tacotron import Tacotron 5 | from synthesizer.utils.symbols import symbols 6 | from synthesizer.utils.text import text_to_sequence 7 | from vocoder.display import simple_table 8 | from pathlib import Path 9 | from typing import Union, List 10 | import numpy as np 11 | import librosa 12 | 13 | 14 | class Synthesizer: 15 | sample_rate = hparams.sample_rate 16 | hparams = hparams 17 | 18 | def __init__(self, model_fpath: Path, verbose=True): 19 | """ 20 | The model isn't instantiated and loaded in memory until needed or until load() is called. 21 | 22 | :param model_fpath: path to the trained model file 23 | :param verbose: if False, prints less information when using the model 24 | """ 25 | self.model_fpath = model_fpath 26 | self.verbose = verbose 27 | 28 | # Check for GPU 29 | if torch.cuda.is_available(): 30 | self.device = torch.device("cuda") 31 | else: 32 | self.device = torch.device("cpu") 33 | if self.verbose: 34 | print("Synthesizer using device:", self.device) 35 | 36 | # Tacotron model will be instantiated later on first use. 37 | self._model = None 38 | 39 | def is_loaded(self): 40 | """ 41 | Whether the model is loaded in memory. 42 | """ 43 | return self._model is not None 44 | 45 | def load(self): 46 | """ 47 | Instantiates and loads the model given the weights file that was passed in the constructor. 
48 | """ 49 | self._model = Tacotron(embed_dims=hparams.tts_embed_dims, 50 | num_chars=len(symbols), 51 | encoder_dims=hparams.tts_encoder_dims, 52 | decoder_dims=hparams.tts_decoder_dims, 53 | n_mels=hparams.num_mels, 54 | fft_bins=hparams.num_mels, 55 | postnet_dims=hparams.tts_postnet_dims, 56 | encoder_K=hparams.tts_encoder_K, 57 | lstm_dims=hparams.tts_lstm_dims, 58 | postnet_K=hparams.tts_postnet_K, 59 | num_highways=hparams.tts_num_highways, 60 | dropout=hparams.tts_dropout, 61 | stop_threshold=hparams.tts_stop_threshold, 62 | speaker_embedding_size=hparams.speaker_embedding_size).to(self.device) 63 | 64 | self._model.load(self.model_fpath) 65 | self._model.eval() 66 | 67 | #if self.verbose: 68 | # print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"])) 69 | 70 | def synthesize_spectrograms(self, texts: List[str], 71 | embeddings: Union[np.ndarray, List[np.ndarray]], 72 | return_alignments=False): 73 | """ 74 | Synthesizes mel spectrograms from texts and speaker embeddings. 75 | 76 | :param texts: a list of N text prompts to be synthesized 77 | :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256) 78 | :param return_alignments: if True, a matrix representing the alignments between the 79 | characters 80 | and each decoder output step will be returned for each spectrogram 81 | :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the 82 | sequence length of spectrogram i, and possibly the alignments. 83 | """ 84 | # Load the model on the first request. 85 | if not self.is_loaded(): 86 | self.load() 87 | 88 | # Print some info about the model when it is loaded 89 | #tts_k = self._model.get_step() // 1000 90 | 91 | #simple_table([("Tacotron", str(tts_k) + "k"), ("r", self._model.r)]) 92 | 93 | # Preprocess text inputs 94 | inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts] 95 | if not isinstance(embeddings, list): 96 | embeddings = [embeddings] 97 | 98 | # Batch inputs 99 | batched_inputs = [inputs[i:i+hparams.synthesis_batch_size] 100 | for i in range(0, len(inputs), hparams.synthesis_batch_size)] 101 | batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size] 102 | for i in range(0, len(embeddings), hparams.synthesis_batch_size)] 103 | 104 | specs = [] 105 | for i, batch in enumerate(batched_inputs, 1): 106 | if self.verbose: 107 | print(f"\n| Generating {i}/{len(batched_inputs)}") 108 | 109 | # Pad texts so they are all the same length 110 | text_lens = [len(text) for text in batch] 111 | max_text_len = max(text_lens) 112 | chars = [pad1d(text, max_text_len) for text in batch] 113 | chars = np.stack(chars) 114 | 115 | # Stack speaker embeddings into 2D array for batch processing 116 | speaker_embeds = np.stack(batched_embeds[i-1]) 117 | 118 | # Convert to tensor 119 | chars = torch.tensor(chars).long().to(self.device) 120 | speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device) 121 | 122 | # Inference 123 | _, mels, alignments = self._model.generate(chars, speaker_embeddings) 124 | mels = mels.detach().cpu().numpy() 125 | for m in mels: 126 | # Trim silence from end of each spectrogram 127 | while np.max(m[:, -1]) < hparams.tts_stop_threshold: 128 | m = m[:, :-1] 129 | specs.append(m) 130 | 131 | #if self.verbose: 132 | # print("\n\nDone.\n") 133 | return (specs, alignments) if return_alignments else specs 134 | 135 | @staticmethod 136 | def load_preprocess_wav(fpath): 137 | """ 138 | Loads and preprocesses an 
audio file under the same conditions the audio files were used to 139 | train the synthesizer. 140 | """ 141 | wav = librosa.load(str(fpath), hparams.sample_rate)[0] 142 | if hparams.rescale: 143 | wav = wav / np.abs(wav).max() * hparams.rescaling_max 144 | return wav 145 | 146 | @staticmethod 147 | def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]): 148 | """ 149 | Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that 150 | were fed to the synthesizer when training. 151 | """ 152 | if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): 153 | wav = Synthesizer.load_preprocess_wav(fpath_or_wav) 154 | else: 155 | wav = fpath_or_wav 156 | 157 | mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32) 158 | return mel_spectrogram 159 | 160 | @staticmethod 161 | def griffin_lim(mel): 162 | """ 163 | Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built 164 | with the same parameters present in hparams.py. 165 | """ 166 | return audio.inv_mel_spectrogram(mel, hparams) 167 | 168 | 169 | def pad1d(x, max_len, pad_value=0): 170 | return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value) 171 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | ### IMPORTS ### 2 | 3 | print('[Importing libraries...]') 4 | 5 | import soundfile as sf 6 | import torch 7 | import librosa 8 | import numpy as np 9 | import sys 10 | import os 11 | from os.path import exists, join, basename, splitext 12 | from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC, logging 13 | from datasets import load_dataset 14 | 15 | from synthesizer.inference import Synthesizer 16 | from encoder import inference as encoder 17 | from encoder.audio import preprocess_wav 18 | from vocoder import inference as vocoder 19 | from pathlib import Path 20 | import argparse 21 | from utils.argutils import print_args 22 | 23 | import jiwer 24 | import speechmetrics 25 | from asrtoolkit import cer 26 | import nltk 27 | 28 | ### STD OUT SUPPRESSION UTILITY ### 29 | 30 | class suppress_output: 31 | def __init__(self, suppress_stdout=True, suppress_stderr=True): 32 | self.suppress_stdout = suppress_stdout 33 | self.suppress_stderr = suppress_stderr 34 | self._stdout = None 35 | self._stderr = None 36 | 37 | def __enter__(self): 38 | devnull = open(os.devnull, "w") 39 | if self.suppress_stdout: 40 | self._stdout = sys.stdout 41 | sys.stdout = devnull 42 | 43 | if self.suppress_stderr: 44 | self._stderr = sys.stderr 45 | sys.stderr = devnull 46 | 47 | def __exit__(self, *args): 48 | if self.suppress_stdout: 49 | sys.stdout = self._stdout 50 | if self.suppress_stderr: 51 | sys.stderr = self._stderr 52 | 53 | 54 | ### MODELS DOWNLOAD ### 55 | print('[Loading models...]') 56 | 57 | dir = os.getcwd() 58 | if os.path.basename(os.path.normpath(dir)) != "src": 59 | dir += "/src" 60 | 61 | 62 | logging.set_verbosity_error() 63 | 64 | with suppress_output(suppress_stdout=True, suppress_stderr=True): 65 | tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self") 66 | model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self") 67 | 68 | encoder.load_model(Path(dir + "/encoder/saved_models/pretrained.pt")) 69 | synthesizer = Synthesizer(Path(dir + "/synthesizer/saved_models/pretrained/pretrained.pt")) 70 | vocoder.load_model(Path(dir + 
"/vocoder/saved_models/pretrained/pretrained.pt")) 71 | 72 | ### FUNCTIONS and GLOBAL VARIABLES ### 73 | 74 | SAMPLE_RATE = 16000 75 | 76 | # def listToString(s): 77 | # str1 = "" 78 | # for ele in s: 79 | # str1 += ele 80 | # return str1 81 | 82 | def synthesize(embed, text): 83 | print('[Synthesizing new audio...]') 84 | print('Text: ' + text + '\n') 85 | #with io.capture_output() as captured: 86 | #with suppress_output(suppress_stdout=True, suppress_stderr=True): 87 | specs = synthesizer.synthesize_spectrograms([text], [embed]) 88 | generated_wav = vocoder.infer_waveform(specs[0]) 89 | generated_wav = np.pad(generated_wav, (0, synthesizer.sample_rate), mode="constant") 90 | # print('\n[...synthesis done]') 91 | return generated_wav 92 | 93 | ### MAIN ### 94 | 95 | if __name__ == '__main__': 96 | 97 | # args parsing 98 | parser = argparse.ArgumentParser( 99 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 100 | ) 101 | parser.add_argument("--source", type=Path, default=None, help=\ 102 | "Source speaker for voice conversion.") 103 | parser.add_argument("--target", type=Path, default=None, help=\ 104 | "Target speaker for voice conversion.") 105 | parser.add_argument("--string", type=str, default=None, help=\ 106 | "Sentence used by the synthesized voice.") 107 | parser.add_argument("--seed", type=int, default=None, help=\ 108 | "Optional random number seed value to make toolbox deterministic.") 109 | parser.add_argument("-m", "--metrics", action="store_true", help=\ 110 | "Print some metrics.") 111 | parser.add_argument("-e", "--enhance", action="store_true", help=\ 112 | "Trims output audio silences.") 113 | 114 | args = parser.parse_args() 115 | # print_args(args, parser) 116 | 117 | if args.seed is not None: 118 | torch.manual_seed(args.seed) 119 | #synthesizer = Synthesizer(args.syn_model_fpath, verbose=False) 120 | 121 | if args.source is not None: 122 | source_audio, _ = librosa.load(args.source, sr=SAMPLE_RATE) 123 | else: 124 | source_audio, _ = librosa.load(dir + "/audio/source.wav", sr=SAMPLE_RATE) 125 | 126 | if args.target is not None: 127 | target_audio, _ = librosa.load(args.target, sr=SAMPLE_RATE) 128 | else: 129 | target_audio, _ = librosa.load(dir + "/audio/target.wav", sr=SAMPLE_RATE) 130 | 131 | if args.string is not None: 132 | if args.source is not None: 133 | raise Exception("[ERROR] Can't specify both source and string args.") 134 | transcription = args.string 135 | # text = listToString(transcription) 136 | else: 137 | input_values = tokenizer(np.asarray(source_audio), return_tensors="pt").input_values 138 | logits = model(input_values).logits 139 | predicted_ids = torch.argmax(logits, dim=-1) 140 | transcription = tokenizer.batch_decode(predicted_ids)[0] 141 | # text = listToString(transcription) 142 | 143 | embedding = encoder.embed_utterance(encoder.preprocess_wav(target_audio, SAMPLE_RATE)) 144 | 145 | out_audio = synthesize(embedding, transcription) 146 | if args.enhance: 147 | out_audio = preprocess_wav(out_audio) 148 | 149 | sf.write(dir + "/audio/audio_out.wav", out_audio, 16000) 150 | 151 | if args.metrics: 152 | input_values = tokenizer(np.asarray(out_audio), return_tensors="pt").input_values 153 | logits = model(input_values).logits 154 | predicted_ids = torch.argmax(logits, dim=-1) 155 | transcription_out = tokenizer.batch_decode(predicted_ids)[0] 156 | # text_out = listToString(transcription_out) 157 | 158 | ground_truth = transcription 159 | hypothesis = transcription_out 160 | 161 | wer_before_lemma = jiwer.wer(ground_truth, hypothesis) 
#word error rate 162 | 163 | print('\n') 164 | 165 | nltk.download('wordnet') 166 | wnl = nltk.stem.WordNetLemmatizer() 167 | stm = nltk.stem.snowball.EnglishStemmer() 168 | 169 | transcription_lemma = "" 170 | transcription_out_lemma = "" 171 | 172 | for s in transcription.split(" "): 173 | transcription_lemma += (stm.stem(wnl.lemmatize(s)) + " ").upper() 174 | 175 | for s in transcription_out.split(" "): 176 | transcription_out_lemma += (stm.stem(wnl.lemmatize(s)) + " ").upper() 177 | 178 | wer_after_lemma = jiwer.wer(transcription_lemma, transcription_out_lemma) 179 | cer = cer(ground_truth, hypothesis) #character error rate 180 | mer = jiwer.mer(ground_truth, hypothesis) #match error rate 181 | wil = jiwer.wil(ground_truth, hypothesis) #word information lost 182 | 183 | #mosnet 184 | 185 | window_length = None 186 | with suppress_output(suppress_stdout=True, suppress_stderr=True): 187 | metrics = speechmetrics.load('absolute.mosnet',window_length) 188 | results = metrics(dir + "/audio/audio_out.wav") 189 | 190 | print('\n\n[+++METRICS+++]\n') 191 | 192 | print('Detected text: ' + transcription_out) 193 | 194 | print('Original text lemmatized: ' + transcription_lemma) 195 | print('Synthesized text lemmatized: ' + transcription_out_lemma) 196 | 197 | print('\nWord Error Rate before lemmatization is: ', wer_before_lemma) 198 | print('Word Error Rate after lemmatization is: ', wer_after_lemma) 199 | print('Character Error Rate is: ', cer/100) 200 | #print('Match Error Rate is: ', mer) 201 | #print('Word Information Lost is: ', wil) 202 | print('MOSNet is: ', results['mosnet'][0][0]) 203 | print('\n[Done]\n') 204 | else: 205 | print('\n[Done]\n') 206 | -------------------------------------------------------------------------------- /src/encoder/inference.py: -------------------------------------------------------------------------------- 1 | from encoder.params_data import * 2 | from encoder.model import SpeakerEncoder 3 | from encoder.audio import preprocess_wav # We want to expose this function from here 4 | from matplotlib import cm 5 | from encoder import audio 6 | from pathlib import Path 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch 10 | 11 | _model = None # type: SpeakerEncoder 12 | _device = None # type: torch.device 13 | 14 | 15 | def load_model(weights_fpath: Path, device=None): 16 | """ 17 | Loads the model in memory. If this function is not explicitely called, it will be run on the 18 | first call to embed_frames() with the default weights file. 19 | 20 | :param weights_fpath: the path to saved model weights. 21 | :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The 22 | model will be loaded and will run on this device. Outputs will however always be on the cpu. 23 | If None, will default to your GPU if it"s available, otherwise your CPU. 24 | """ 25 | # TODO: I think the slow loading of the encoder might have something to do with the device it 26 | # was saved on. Worth investigating. 
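# Minimal sketch mirroring the calls in main.py, assuming it runs from src/ with the
# pretrained encoder weights at encoder/saved_models/pretrained.pt:
import librosa
from pathlib import Path
from encoder import inference as encoder

encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
target_audio, _ = librosa.load("audio/target.wav", sr=16000)
embed = encoder.embed_utterance(encoder.preprocess_wav(target_audio, 16000))
# embed is an L2-normalised float32 vector of shape (model_embedding_size,)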
27 | global _model, _device 28 | if device is None: 29 | _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | elif isinstance(device, str): 31 | _device = torch.device(device) 32 | _model = SpeakerEncoder(_device, torch.device("cpu")) 33 | checkpoint = torch.load(weights_fpath, _device) 34 | _model.load_state_dict(checkpoint["model_state"]) 35 | _model.eval() 36 | print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"])) 37 | 38 | 39 | def is_loaded(): 40 | return _model is not None 41 | 42 | 43 | def embed_frames_batch(frames_batch): 44 | """ 45 | Computes embeddings for a batch of mel spectrogram. 46 | 47 | :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape 48 | (batch_size, n_frames, n_channels) 49 | :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size) 50 | """ 51 | if _model is None: 52 | raise Exception("Model was not loaded. Call load_model() before inference.") 53 | 54 | frames = torch.from_numpy(frames_batch).to(_device) 55 | embed = _model.forward(frames).detach().cpu().numpy() 56 | return embed 57 | 58 | 59 | def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames, 60 | min_pad_coverage=0.75, overlap=0.5): 61 | """ 62 | Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain 63 | partial utterances of each. Both the waveform and the mel 64 | spectrogram slices are returned, so as to make each partial utterance waveform correspond to 65 | its spectrogram. This function assumes that the mel spectrogram parameters used are those 66 | defined in params_data.py. 67 | 68 | The returned ranges may be indexing further than the length of the waveform. It is 69 | recommended that you pad the waveform with zeros up to wave_slices[-1].stop. 70 | 71 | :param n_samples: the number of samples in the waveform 72 | :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial 73 | utterance 74 | :param min_pad_coverage: when reaching the last partial utterance, it may or may not have 75 | enough frames. If at least of are present, 76 | then the last partial utterance will be considered, as if we padded the audio. Otherwise, 77 | it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial 78 | utterance, this parameter is ignored so that the function always returns at least 1 slice. 79 | :param overlap: by how much the partial utterance should overlap. If set to 0, the partial 80 | utterances are entirely disjoint. 81 | :return: the waveform slices and mel spectrogram slices as lists of array slices. Index 82 | respectively the waveform and the mel spectrogram with these slices to obtain the partial 83 | utterances. 
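# Minimal sketch of the slicing alone, assuming sampling_rate=16000 in params_data.py
# (about 3 s of audio); no model needs to be loaded for this:
from encoder.inference import compute_partial_slices

n_samples = 16000 * 3
wave_slices, mel_slices = compute_partial_slices(n_samples)
# wave_slices[i] indexes the waveform, mel_slices[i] the corresponding mel frames.
# The last wav slice may stop past n_samples, which is why embed_utterance() pads
# the waveform with zeros up to wave_slices[-1].stop before cutting out the partials.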
84 | """ 85 | assert 0 <= overlap < 1 86 | assert 0 < min_pad_coverage <= 1 87 | 88 | samples_per_frame = int((sampling_rate * mel_window_step / 1000)) 89 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) 90 | frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) 91 | 92 | # Compute the slices 93 | wav_slices, mel_slices = [], [] 94 | steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1) 95 | for i in range(0, steps, frame_step): 96 | mel_range = np.array([i, i + partial_utterance_n_frames]) 97 | wav_range = mel_range * samples_per_frame 98 | mel_slices.append(slice(*mel_range)) 99 | wav_slices.append(slice(*wav_range)) 100 | 101 | # Evaluate whether extra padding is warranted or not 102 | last_wav_range = wav_slices[-1] 103 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) 104 | if coverage < min_pad_coverage and len(mel_slices) > 1: 105 | mel_slices = mel_slices[:-1] 106 | wav_slices = wav_slices[:-1] 107 | 108 | return wav_slices, mel_slices 109 | 110 | 111 | def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs): 112 | """ 113 | Computes an embedding for a single utterance. 114 | 115 | # TODO: handle multiple wavs to benefit from batching on GPU 116 | :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32 117 | :param using_partials: if True, then the utterance is split in partial utterances of 118 | frames and the utterance embedding is computed from their 119 | normalized average. If False, the utterance is instead computed from feeding the entire 120 | spectogram to the network. 121 | :param return_partials: if True, the partial embeddings will also be returned along with the 122 | wav slices that correspond to the partial embeddings. 123 | :param kwargs: additional arguments to compute_partial_splits() 124 | :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If 125 | is True, the partial utterances as a numpy array of float32 of shape 126 | (n_partials, model_embedding_size) and the wav partials as a list of slices will also be 127 | returned. If is simultaneously set to False, both these values will be None 128 | instead. 
129 | """ 130 | # Process the entire utterance if not using partials 131 | if not using_partials: 132 | frames = audio.wav_to_mel_spectrogram(wav) 133 | embed = embed_frames_batch(frames[None, ...])[0] 134 | if return_partials: 135 | return embed, None, None 136 | return embed 137 | 138 | # Compute where to split the utterance into partials and pad if necessary 139 | wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs) 140 | max_wave_length = wave_slices[-1].stop 141 | if max_wave_length >= len(wav): 142 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") 143 | 144 | # Split the utterance into partials 145 | frames = audio.wav_to_mel_spectrogram(wav) 146 | frames_batch = np.array([frames[s] for s in mel_slices]) 147 | partial_embeds = embed_frames_batch(frames_batch) 148 | 149 | # Compute the utterance embedding from the partial embeddings 150 | raw_embed = np.mean(partial_embeds, axis=0) 151 | embed = raw_embed / np.linalg.norm(raw_embed, 2) 152 | 153 | if return_partials: 154 | return embed, partial_embeds, wave_slices 155 | return embed 156 | 157 | 158 | def embed_speaker(wavs, **kwargs): 159 | raise NotImplemented() 160 | 161 | 162 | def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)): 163 | if ax is None: 164 | ax = plt.gca() 165 | 166 | if shape is None: 167 | height = int(np.sqrt(len(embed))) 168 | shape = (height, -1) 169 | embed = embed.reshape(shape) 170 | 171 | cmap = cm.get_cmap() 172 | mappable = ax.imshow(embed, cmap=cmap) 173 | cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04) 174 | sm = cm.ScalarMappable(cmap=cmap) 175 | sm.set_clim(*color_range) 176 | 177 | ax.set_xticks([]), ax.set_yticks([]) 178 | ax.set_title(title) 179 | -------------------------------------------------------------------------------- /src/synthesizer/audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import librosa.filters 3 | import numpy as np 4 | from scipy import signal 5 | from scipy.io import wavfile 6 | import soundfile as sf 7 | 8 | 9 | def load_wav(path, sr): 10 | return librosa.core.load(path, sr=sr)[0] 11 | 12 | def save_wav(wav, path, sr): 13 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 14 | #proposed by @dsmiller 15 | wavfile.write(path, sr, wav.astype(np.int16)) 16 | 17 | def save_wavenet_wav(wav, path, sr): 18 | sf.write(path, wav.astype(np.float32), sr) 19 | 20 | def preemphasis(wav, k, preemphasize=True): 21 | if preemphasize: 22 | return signal.lfilter([1, -k], [1], wav) 23 | return wav 24 | 25 | def inv_preemphasis(wav, k, inv_preemphasize=True): 26 | if inv_preemphasize: 27 | return signal.lfilter([1], [1, -k], wav) 28 | return wav 29 | 30 | #From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py 31 | def start_and_end_indices(quantized, silence_threshold=2): 32 | for start in range(quantized.size): 33 | if abs(quantized[start] - 127) > silence_threshold: 34 | break 35 | for end in range(quantized.size - 1, 1, -1): 36 | if abs(quantized[end] - 127) > silence_threshold: 37 | break 38 | 39 | assert abs(quantized[start] - 127) > silence_threshold 40 | assert abs(quantized[end] - 127) > silence_threshold 41 | 42 | return start, end 43 | 44 | def get_hop_size(hparams): 45 | hop_size = hparams.hop_size 46 | if hop_size is None: 47 | assert hparams.frame_shift_ms is not None 48 | hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate) 49 | return hop_size 50 | 51 | def linearspectrogram(wav, hparams): 52 | D = 
_stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams) 53 | S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db 54 | 55 | if hparams.signal_normalization: 56 | return _normalize(S, hparams) 57 | return S 58 | 59 | def melspectrogram(wav, hparams): 60 | D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams) 61 | S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db 62 | 63 | if hparams.signal_normalization: 64 | return _normalize(S, hparams) 65 | return S 66 | 67 | def inv_linear_spectrogram(linear_spectrogram, hparams): 68 | """Converts linear spectrogram to waveform using librosa""" 69 | if hparams.signal_normalization: 70 | D = _denormalize(linear_spectrogram, hparams) 71 | else: 72 | D = linear_spectrogram 73 | 74 | S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear 75 | 76 | if hparams.use_lws: 77 | processor = _lws_processor(hparams) 78 | D = processor.run_lws(S.astype(np.float64).T ** hparams.power) 79 | y = processor.istft(D).astype(np.float32) 80 | return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize) 81 | else: 82 | return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize) 83 | 84 | def inv_mel_spectrogram(mel_spectrogram, hparams): 85 | """Converts mel spectrogram to waveform using librosa""" 86 | if hparams.signal_normalization: 87 | D = _denormalize(mel_spectrogram, hparams) 88 | else: 89 | D = mel_spectrogram 90 | 91 | S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear 92 | 93 | if hparams.use_lws: 94 | processor = _lws_processor(hparams) 95 | D = processor.run_lws(S.astype(np.float64).T ** hparams.power) 96 | y = processor.istft(D).astype(np.float32) 97 | return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize) 98 | else: 99 | return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize) 100 | 101 | def _lws_processor(hparams): 102 | import lws 103 | return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech") 104 | 105 | def _griffin_lim(S, hparams): 106 | """librosa implementation of Griffin-Lim 107 | Based on https://github.com/librosa/librosa/issues/434 108 | """ 109 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) 110 | S_complex = np.abs(S).astype(np.complex) 111 | y = _istft(S_complex * angles, hparams) 112 | for i in range(hparams.griffin_lim_iters): 113 | angles = np.exp(1j * np.angle(_stft(y, hparams))) 114 | y = _istft(S_complex * angles, hparams) 115 | return y 116 | 117 | def _stft(y, hparams): 118 | if hparams.use_lws: 119 | return _lws_processor(hparams).stft(y).T 120 | else: 121 | return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size) 122 | 123 | def _istft(y, hparams): 124 | return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size) 125 | 126 | ########################################################## 127 | #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 
128 | def num_frames(length, fsize, fshift): 129 | """Compute number of time frames of spectrogram 130 | """ 131 | pad = (fsize - fshift) 132 | if length % fshift == 0: 133 | M = (length + pad * 2 - fsize) // fshift + 1 134 | else: 135 | M = (length + pad * 2 - fsize) // fshift + 2 136 | return M 137 | 138 | 139 | def pad_lr(x, fsize, fshift): 140 | """Compute left and right padding 141 | """ 142 | M = num_frames(len(x), fsize, fshift) 143 | pad = (fsize - fshift) 144 | T = len(x) + 2 * pad 145 | r = (M - 1) * fshift + fsize - T 146 | return pad, pad + r 147 | ########################################################## 148 | #Librosa correct padding 149 | def librosa_pad_lr(x, fsize, fshift): 150 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] 151 | 152 | # Conversions 153 | _mel_basis = None 154 | _inv_mel_basis = None 155 | 156 | def _linear_to_mel(spectogram, hparams): 157 | global _mel_basis 158 | if _mel_basis is None: 159 | _mel_basis = _build_mel_basis(hparams) 160 | return np.dot(_mel_basis, spectogram) 161 | 162 | def _mel_to_linear(mel_spectrogram, hparams): 163 | global _inv_mel_basis 164 | if _inv_mel_basis is None: 165 | _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams)) 166 | return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram)) 167 | 168 | def _build_mel_basis(hparams): 169 | assert hparams.fmax <= hparams.sample_rate // 2 170 | return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels, 171 | fmin=hparams.fmin, fmax=hparams.fmax) 172 | 173 | def _amp_to_db(x, hparams): 174 | min_level = np.exp(hparams.min_level_db / 20 * np.log(10)) 175 | return 20 * np.log10(np.maximum(min_level, x)) 176 | 177 | def _db_to_amp(x): 178 | return np.power(10.0, (x) * 0.05) 179 | 180 | def _normalize(S, hparams): 181 | if hparams.allow_clipping_in_normalization: 182 | if hparams.symmetric_mels: 183 | return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value, 184 | -hparams.max_abs_value, hparams.max_abs_value) 185 | else: 186 | return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value) 187 | 188 | assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0 189 | if hparams.symmetric_mels: 190 | return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value 191 | else: 192 | return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)) 193 | 194 | def _denormalize(D, hparams): 195 | if hparams.allow_clipping_in_normalization: 196 | if hparams.symmetric_mels: 197 | return (((np.clip(D, -hparams.max_abs_value, 198 | hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) 199 | + hparams.min_level_db) 200 | else: 201 | return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db) 202 | 203 | if hparams.symmetric_mels: 204 | return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db) 205 | else: 206 | return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db) 207 | -------------------------------------------------------------------------------- /src/utils/logmmse.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 2015 braindead 4 | # 5 | # Permission is hereby granted, 
free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # 23 | # 24 | # This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I 25 | # simply modified the interface to meet my needs. 26 | 27 | 28 | import numpy as np 29 | import math 30 | from scipy.special import expn 31 | from collections import namedtuple 32 | 33 | NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2") 34 | 35 | 36 | def profile_noise(noise, sampling_rate, window_size=0): 37 | """ 38 | Creates a profile of the noise in a given waveform. 39 | 40 | :param noise: a waveform containing noise ONLY, as a numpy array of floats or ints. 41 | :param sampling_rate: the sampling rate of the audio 42 | :param window_size: the size of the window the logmmse algorithm operates on. A default value 43 | will be picked if left as 0. 44 | :return: a NoiseProfile object 45 | """ 46 | noise, dtype = to_float(noise) 47 | noise += np.finfo(np.float64).eps 48 | 49 | if window_size == 0: 50 | window_size = int(math.floor(0.02 * sampling_rate)) 51 | 52 | if window_size % 2 == 1: 53 | window_size = window_size + 1 54 | 55 | perc = 50 56 | len1 = int(math.floor(window_size * perc / 100)) 57 | len2 = int(window_size - len1) 58 | 59 | win = np.hanning(window_size) 60 | win = win * len2 / np.sum(win) 61 | n_fft = 2 * window_size 62 | 63 | noise_mean = np.zeros(n_fft) 64 | n_frames = len(noise) // window_size 65 | for j in range(0, window_size * n_frames, window_size): 66 | noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0)) 67 | noise_mu2 = (noise_mean / n_frames) ** 2 68 | 69 | return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2) 70 | 71 | 72 | def denoise(wav, noise_profile: NoiseProfile, eta=0.15): 73 | """ 74 | Cleans the noise from a speech waveform given a noise profile. The waveform must have the 75 | same sampling rate as the one used to create the noise profile. 76 | 77 | :param wav: a speech waveform as a numpy array of floats or ints. 78 | :param noise_profile: a NoiseProfile object that was created from a similar (or a segment of 79 | the same) waveform. 80 | :param eta: voice threshold for noise update. While the voice activation detection value is 81 | below this threshold, the noise profile will be continuously updated throughout the audio. 82 | Set to 0 to disable updating the noise profile. 
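# Minimal end-to-end sketch, assuming the first half second of the recording
# contains noise only (the random array below is a stand-in for a real waveform):
import numpy as np
from utils.logmmse import profile_noise, denoise

sampling_rate = 16000
noisy_wav = np.random.randn(sampling_rate * 2).astype(np.float32)
profile = profile_noise(noisy_wav[:sampling_rate // 2], sampling_rate)
clean_wav = denoise(noisy_wav, profile)   # float32, same length as the input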
83 | :return: the clean wav as a numpy array of floats or ints of the same length. 84 | """ 85 | wav, dtype = to_float(wav) 86 | wav += np.finfo(np.float64).eps 87 | p = noise_profile 88 | 89 | nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2)) 90 | x_final = np.zeros(nframes * p.len2) 91 | 92 | aa = 0.98 93 | mu = 0.98 94 | ksi_min = 10 ** (-25 / 10) 95 | 96 | x_old = np.zeros(p.len1) 97 | xk_prev = np.zeros(p.len1) 98 | noise_mu2 = p.noise_mu2 99 | for k in range(0, nframes * p.len2, p.len2): 100 | insign = p.win * wav[k:k + p.window_size] 101 | 102 | spec = np.fft.fft(insign, p.n_fft, axis=0) 103 | sig = np.absolute(spec) 104 | sig2 = sig ** 2 105 | 106 | gammak = np.minimum(sig2 / noise_mu2, 40) 107 | 108 | if xk_prev.all() == 0: 109 | ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0) 110 | else: 111 | ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0) 112 | ksi = np.maximum(ksi_min, ksi) 113 | 114 | log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi) 115 | vad_decision = np.sum(log_sigma_k) / p.window_size 116 | if vad_decision < eta: 117 | noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2 118 | 119 | a = ksi / (1 + ksi) 120 | vk = a * gammak 121 | ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8)) 122 | hw = a * np.exp(ei_vk) 123 | sig = sig * hw 124 | xk_prev = sig ** 2 125 | xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0) 126 | xi_w = np.real(xi_w) 127 | 128 | x_final[k:k + p.len2] = x_old + xi_w[0:p.len1] 129 | x_old = xi_w[p.len1:p.window_size] 130 | 131 | output = from_float(x_final, dtype) 132 | output = np.pad(output, (0, len(wav) - len(output)), mode="constant") 133 | return output 134 | 135 | 136 | ## Alternative VAD algorithm to webrctvad. It has the advantage of not requiring to install that 137 | ## darn package and it also works for any sampling rate. Maybe I'll eventually use it instead of 138 | ## webrctvad 139 | # def vad(wav, sampling_rate, eta=0.15, window_size=0): 140 | # """ 141 | # TODO: fix doc 142 | # Creates a profile of the noise in a given waveform. 143 | # 144 | # :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints. 145 | # :param sampling_rate: the sampling rate of the audio 146 | # :param window_size: the size of the window the logmmse algorithm operates on. A default value 147 | # will be picked if left as 0. 148 | # :param eta: voice threshold for noise update. While the voice activation detection value is 149 | # below this threshold, the noise profile will be continuously updated throughout the audio. 150 | # Set to 0 to disable updating the noise profile. 
151 | # """ 152 | # wav, dtype = to_float(wav) 153 | # wav += np.finfo(np.float64).eps 154 | # 155 | # if window_size == 0: 156 | # window_size = int(math.floor(0.02 * sampling_rate)) 157 | # 158 | # if window_size % 2 == 1: 159 | # window_size = window_size + 1 160 | # 161 | # perc = 50 162 | # len1 = int(math.floor(window_size * perc / 100)) 163 | # len2 = int(window_size - len1) 164 | # 165 | # win = np.hanning(window_size) 166 | # win = win * len2 / np.sum(win) 167 | # n_fft = 2 * window_size 168 | # 169 | # wav_mean = np.zeros(n_fft) 170 | # n_frames = len(wav) // window_size 171 | # for j in range(0, window_size * n_frames, window_size): 172 | # wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0)) 173 | # noise_mu2 = (wav_mean / n_frames) ** 2 174 | # 175 | # wav, dtype = to_float(wav) 176 | # wav += np.finfo(np.float64).eps 177 | # 178 | # nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2)) 179 | # vad = np.zeros(nframes * len2, dtype=np.bool) 180 | # 181 | # aa = 0.98 182 | # mu = 0.98 183 | # ksi_min = 10 ** (-25 / 10) 184 | # 185 | # xk_prev = np.zeros(len1) 186 | # noise_mu2 = noise_mu2 187 | # for k in range(0, nframes * len2, len2): 188 | # insign = win * wav[k:k + window_size] 189 | # 190 | # spec = np.fft.fft(insign, n_fft, axis=0) 191 | # sig = np.absolute(spec) 192 | # sig2 = sig ** 2 193 | # 194 | # gammak = np.minimum(sig2 / noise_mu2, 40) 195 | # 196 | # if xk_prev.all() == 0: 197 | # ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0) 198 | # else: 199 | # ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0) 200 | # ksi = np.maximum(ksi_min, ksi) 201 | # 202 | # log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi) 203 | # vad_decision = np.sum(log_sigma_k) / window_size 204 | # if vad_decision < eta: 205 | # noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2 206 | # print(vad_decision) 207 | # 208 | # a = ksi / (1 + ksi) 209 | # vk = a * gammak 210 | # ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8)) 211 | # hw = a * np.exp(ei_vk) 212 | # sig = sig * hw 213 | # xk_prev = sig ** 2 214 | # 215 | # vad[k:k + len2] = vad_decision >= eta 216 | # 217 | # vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant") 218 | # return vad 219 | 220 | 221 | def to_float(_input): 222 | if _input.dtype == np.float64: 223 | return _input, _input.dtype 224 | elif _input.dtype == np.float32: 225 | return _input.astype(np.float64), _input.dtype 226 | elif _input.dtype == np.uint8: 227 | return (_input - 128) / 128., _input.dtype 228 | elif _input.dtype == np.int16: 229 | return _input / 32768., _input.dtype 230 | elif _input.dtype == np.int32: 231 | return _input / 2147483648., _input.dtype 232 | raise ValueError('Unsupported wave file format') 233 | 234 | 235 | def from_float(_input, dtype): 236 | if dtype == np.float64: 237 | return _input, np.float64 238 | elif dtype == np.float32: 239 | return _input.astype(np.float32) 240 | elif dtype == np.uint8: 241 | return ((_input * 128) + 128).astype(np.uint8) 242 | elif dtype == np.int16: 243 | return (_input * 32768).astype(np.int16) 244 | elif dtype == np.int32: 245 | print(_input) 246 | return (_input * 2147483648).astype(np.int32) 247 | raise ValueError('Unsupported wave file format') 248 | -------------------------------------------------------------------------------- /src/vocoder/models/fatchord_version.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from 
vocoder.distribution import sample_from_discretized_mix_logistic 5 | from vocoder.display import * 6 | from vocoder.audio import * 7 | 8 | 9 | class ResBlock(nn.Module): 10 | def __init__(self, dims): 11 | super().__init__() 12 | self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) 13 | self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) 14 | self.batch_norm1 = nn.BatchNorm1d(dims) 15 | self.batch_norm2 = nn.BatchNorm1d(dims) 16 | 17 | def forward(self, x): 18 | residual = x 19 | x = self.conv1(x) 20 | x = self.batch_norm1(x) 21 | x = F.relu(x) 22 | x = self.conv2(x) 23 | x = self.batch_norm2(x) 24 | return x + residual 25 | 26 | 27 | class MelResNet(nn.Module): 28 | def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad): 29 | super().__init__() 30 | k_size = pad * 2 + 1 31 | self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False) 32 | self.batch_norm = nn.BatchNorm1d(compute_dims) 33 | self.layers = nn.ModuleList() 34 | for i in range(res_blocks): 35 | self.layers.append(ResBlock(compute_dims)) 36 | self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1) 37 | 38 | def forward(self, x): 39 | x = self.conv_in(x) 40 | x = self.batch_norm(x) 41 | x = F.relu(x) 42 | for f in self.layers: x = f(x) 43 | x = self.conv_out(x) 44 | return x 45 | 46 | 47 | class Stretch2d(nn.Module): 48 | def __init__(self, x_scale, y_scale): 49 | super().__init__() 50 | self.x_scale = x_scale 51 | self.y_scale = y_scale 52 | 53 | def forward(self, x): 54 | b, c, h, w = x.size() 55 | x = x.unsqueeze(-1).unsqueeze(3) 56 | x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) 57 | return x.view(b, c, h * self.y_scale, w * self.x_scale) 58 | 59 | 60 | class UpsampleNetwork(nn.Module): 61 | def __init__(self, feat_dims, upsample_scales, compute_dims, 62 | res_blocks, res_out_dims, pad): 63 | super().__init__() 64 | total_scale = np.cumproduct(upsample_scales)[-1] 65 | self.indent = pad * total_scale 66 | self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad) 67 | self.resnet_stretch = Stretch2d(total_scale, 1) 68 | self.up_layers = nn.ModuleList() 69 | for scale in upsample_scales: 70 | k_size = (1, scale * 2 + 1) 71 | padding = (0, scale) 72 | stretch = Stretch2d(scale, 1) 73 | conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) 74 | conv.weight.data.fill_(1. 
/ k_size[1]) 75 | self.up_layers.append(stretch) 76 | self.up_layers.append(conv) 77 | 78 | def forward(self, m): 79 | aux = self.resnet(m).unsqueeze(1) 80 | aux = self.resnet_stretch(aux) 81 | aux = aux.squeeze(1) 82 | m = m.unsqueeze(1) 83 | for f in self.up_layers: m = f(m) 84 | m = m.squeeze(1)[:, :, self.indent:-self.indent] 85 | return m.transpose(1, 2), aux.transpose(1, 2) 86 | 87 | 88 | class WaveRNN(nn.Module): 89 | def __init__(self, rnn_dims, fc_dims, bits, pad, upsample_factors, 90 | feat_dims, compute_dims, res_out_dims, res_blocks, 91 | hop_length, sample_rate, mode='RAW'): 92 | super().__init__() 93 | self.mode = mode 94 | self.pad = pad 95 | if self.mode == 'RAW' : 96 | self.n_classes = 2 ** bits 97 | elif self.mode == 'MOL' : 98 | self.n_classes = 30 99 | else : 100 | RuntimeError("Unknown model mode value - ", self.mode) 101 | 102 | self.rnn_dims = rnn_dims 103 | self.aux_dims = res_out_dims // 4 104 | self.hop_length = hop_length 105 | self.sample_rate = sample_rate 106 | 107 | self.upsample = UpsampleNetwork(feat_dims, upsample_factors, compute_dims, res_blocks, res_out_dims, pad) 108 | self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims) 109 | self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) 110 | self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True) 111 | self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims) 112 | self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims) 113 | self.fc3 = nn.Linear(fc_dims, self.n_classes) 114 | 115 | self.step = nn.Parameter(torch.zeros(1).long(), requires_grad=False) 116 | self.num_params() 117 | 118 | def forward(self, x, mels): 119 | self.step += 1 120 | bsize = x.size(0) 121 | if torch.cuda.is_available(): 122 | h1 = torch.zeros(1, bsize, self.rnn_dims).cuda() 123 | h2 = torch.zeros(1, bsize, self.rnn_dims).cuda() 124 | else: 125 | h1 = torch.zeros(1, bsize, self.rnn_dims).cpu() 126 | h2 = torch.zeros(1, bsize, self.rnn_dims).cpu() 127 | mels, aux = self.upsample(mels) 128 | 129 | aux_idx = [self.aux_dims * i for i in range(5)] 130 | a1 = aux[:, :, aux_idx[0]:aux_idx[1]] 131 | a2 = aux[:, :, aux_idx[1]:aux_idx[2]] 132 | a3 = aux[:, :, aux_idx[2]:aux_idx[3]] 133 | a4 = aux[:, :, aux_idx[3]:aux_idx[4]] 134 | 135 | x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2) 136 | x = self.I(x) 137 | res = x 138 | x, _ = self.rnn1(x, h1) 139 | 140 | x = x + res 141 | res = x 142 | x = torch.cat([x, a2], dim=2) 143 | x, _ = self.rnn2(x, h2) 144 | 145 | x = x + res 146 | x = torch.cat([x, a3], dim=2) 147 | x = F.relu(self.fc1(x)) 148 | 149 | x = torch.cat([x, a4], dim=2) 150 | x = F.relu(self.fc2(x)) 151 | return self.fc3(x) 152 | 153 | def generate(self, mels, batched, target, overlap, mu_law, progress_callback=None): 154 | mu_law = mu_law if self.mode == 'RAW' else False 155 | progress_callback = progress_callback or self.gen_display 156 | 157 | self.eval() 158 | output = [] 159 | start = time.time() 160 | rnn1 = self.get_gru_cell(self.rnn1) 161 | rnn2 = self.get_gru_cell(self.rnn2) 162 | 163 | with torch.no_grad(): 164 | if torch.cuda.is_available(): 165 | mels = mels.cuda() 166 | else: 167 | mels = mels.cpu() 168 | wave_len = (mels.size(-1) - 1) * self.hop_length 169 | mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side='both') 170 | mels, aux = self.upsample(mels.transpose(1, 2)) 171 | 172 | if batched: 173 | mels = self.fold_with_overlap(mels, target, overlap) 174 | aux = self.fold_with_overlap(aux, target, overlap) 175 | 176 | b_size, seq_len, _ = mels.size() 177 | 178 | if 
torch.cuda.is_available(): 179 | h1 = torch.zeros(b_size, self.rnn_dims).cuda() 180 | h2 = torch.zeros(b_size, self.rnn_dims).cuda() 181 | x = torch.zeros(b_size, 1).cuda() 182 | else: 183 | h1 = torch.zeros(b_size, self.rnn_dims).cpu() 184 | h2 = torch.zeros(b_size, self.rnn_dims).cpu() 185 | x = torch.zeros(b_size, 1).cpu() 186 | 187 | d = self.aux_dims 188 | aux_split = [aux[:, :, d * i:d * (i + 1)] for i in range(4)] 189 | 190 | for i in range(seq_len): 191 | 192 | m_t = mels[:, i, :] 193 | 194 | a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split) 195 | 196 | x = torch.cat([x, m_t, a1_t], dim=1) 197 | x = self.I(x) 198 | h1 = rnn1(x, h1) 199 | 200 | x = x + h1 201 | inp = torch.cat([x, a2_t], dim=1) 202 | h2 = rnn2(inp, h2) 203 | 204 | x = x + h2 205 | x = torch.cat([x, a3_t], dim=1) 206 | x = F.relu(self.fc1(x)) 207 | 208 | x = torch.cat([x, a4_t], dim=1) 209 | x = F.relu(self.fc2(x)) 210 | 211 | logits = self.fc3(x) 212 | 213 | if self.mode == 'MOL': 214 | sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2)) 215 | output.append(sample.view(-1)) 216 | if torch.cuda.is_available(): 217 | # x = torch.FloatTensor([[sample]]).cuda() 218 | x = sample.transpose(0, 1).cuda() 219 | else: 220 | x = sample.transpose(0, 1) 221 | 222 | elif self.mode == 'RAW' : 223 | posterior = F.softmax(logits, dim=1) 224 | distrib = torch.distributions.Categorical(posterior) 225 | 226 | sample = 2 * distrib.sample().float() / (self.n_classes - 1.) - 1. 227 | output.append(sample) 228 | x = sample.unsqueeze(-1) 229 | else: 230 | raise RuntimeError("Unknown model mode value - ", self.mode) 231 | 232 | if i % 100 == 0: 233 | gen_rate = (i + 1) / (time.time() - start) * b_size / 1000 234 | progress_callback(i, seq_len, b_size, gen_rate) 235 | 236 | output = torch.stack(output).transpose(0, 1) 237 | output = output.cpu().numpy() 238 | output = output.astype(np.float64) 239 | 240 | if batched: 241 | output = self.xfade_and_unfold(output, target, overlap) 242 | else: 243 | output = output[0] 244 | 245 | if mu_law: 246 | output = decode_mu_law(output, self.n_classes, False) 247 | if hp.apply_preemphasis: 248 | output = de_emphasis(output) 249 | 250 | # Fade-out at the end to avoid signal cutting out suddenly 251 | fade_out = np.linspace(1, 0, 20 * self.hop_length) 252 | output = output[:wave_len] 253 | output[-20 * self.hop_length:] *= fade_out 254 | 255 | self.train() 256 | 257 | return output 258 | 259 | 260 | def gen_display(self, i, seq_len, b_size, gen_rate): 261 | pbar = progbar(i, seq_len) 262 | msg = f'| {pbar} {i*b_size}/{seq_len*b_size} | Batch Size: {b_size} | Gen Rate: {gen_rate:.1f}kHz | ' 263 | stream(msg) 264 | 265 | def get_gru_cell(self, gru): 266 | gru_cell = nn.GRUCell(gru.input_size, gru.hidden_size) 267 | gru_cell.weight_hh.data = gru.weight_hh_l0.data 268 | gru_cell.weight_ih.data = gru.weight_ih_l0.data 269 | gru_cell.bias_hh.data = gru.bias_hh_l0.data 270 | gru_cell.bias_ih.data = gru.bias_ih_l0.data 271 | return gru_cell 272 | 273 | def pad_tensor(self, x, pad, side='both'): 274 | # NB - this is just a quick method i need right now 275 | # i.e., it won't generalise to other shapes/dims 276 | b, t, c = x.size() 277 | total = t + 2 * pad if side == 'both' else t + pad 278 | if torch.cuda.is_available(): 279 | padded = torch.zeros(b, total, c).cuda() 280 | else: 281 | padded = torch.zeros(b, total, c).cpu() 282 | if side == 'before' or side == 'both': 283 | padded[:, pad:pad + t, :] = x 284 | elif side == 'after': 285 | padded[:, :t, :] = x 286 | return padded 
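# Worked example of the folding arithmetic in fold_with_overlap() below
# (target and overlap values here are illustrative, not the hparams defaults):
#   total_len=30000, target=8000, overlap=400
#   num_folds    = (30000 - 400) // (8000 + 400) = 3
#   extended_len = 3 * (8000 + 400) + 400        = 25600
#   remaining    = 30000 - 25600                 = 4400  -> pad and add a 4th fold
#   result shape = (4, 8000 + 2 * 400, features)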
287 | 288 | def fold_with_overlap(self, x, target, overlap): 289 | 290 | ''' Fold the tensor with overlap for quick batched inference. 291 | Overlap will be used for crossfading in xfade_and_unfold() 292 | 293 | Args: 294 | x (tensor) : Upsampled conditioning features. 295 | shape=(1, timesteps, features) 296 | target (int) : Target timesteps for each index of batch 297 | overlap (int) : Timesteps for both xfade and rnn warmup 298 | 299 | Return: 300 | (tensor) : shape=(num_folds, target + 2 * overlap, features) 301 | 302 | Details: 303 | x = [[h1, h2, ... hn]] 304 | 305 | Where each h is a vector of conditioning features 306 | 307 | Eg: target=2, overlap=1 with x.size(1)=10 308 | 309 | folded = [[h1, h2, h3, h4], 310 | [h4, h5, h6, h7], 311 | [h7, h8, h9, h10]] 312 | ''' 313 | 314 | _, total_len, features = x.size() 315 | 316 | # Calculate variables needed 317 | num_folds = (total_len - overlap) // (target + overlap) 318 | extended_len = num_folds * (overlap + target) + overlap 319 | remaining = total_len - extended_len 320 | 321 | # Pad if some time steps are poking out 322 | if remaining != 0: 323 | num_folds += 1 324 | padding = target + 2 * overlap - remaining 325 | x = self.pad_tensor(x, padding, side='after') 326 | 327 | if torch.cuda.is_available(): 328 | folded = torch.zeros(num_folds, target + 2 * overlap, features).cuda() 329 | else: 330 | folded = torch.zeros(num_folds, target + 2 * overlap, features).cpu() 331 | 332 | # Get the values for the folded tensor 333 | for i in range(num_folds): 334 | start = i * (target + overlap) 335 | end = start + target + 2 * overlap 336 | folded[i] = x[:, start:end, :] 337 | 338 | return folded 339 | 340 | def xfade_and_unfold(self, y, target, overlap): 341 | 342 | ''' Applies a crossfade and unfolds into a 1d array. 343 | 344 | Args: 345 | y (ndarray) : Batched sequences of audio samples 346 | shape=(num_folds, target + 2 * overlap) 347 | dtype=np.float64 348 | overlap (int) : Timesteps for both xfade and rnn warmup 349 | 350 | Return: 351 | (ndarray) : audio samples in a 1d array 352 | shape=(total_len) 353 | dtype=np.float64 354 | 355 | Details: 356 | y = [[seq1], 357 | [seq2], 358 | [seq3]] 359 | 360 | Apply a gain envelope at both ends of the sequences 361 | 362 | y = [[seq1_in, seq1_target, seq1_out], 363 | [seq2_in, seq2_target, seq2_out], 364 | [seq3_in, seq3_target, seq3_out]] 365 | 366 | Stagger and add up the groups of samples: 367 | 368 | [seq1_in, seq1_target, (seq1_out + seq2_in), seq2_target, ...]
369 | 370 | ''' 371 | 372 | num_folds, length = y.shape 373 | target = length - 2 * overlap 374 | total_len = num_folds * (target + overlap) + overlap 375 | 376 | # Need some silence for the rnn warmup 377 | silence_len = overlap // 2 378 | fade_len = overlap - silence_len 379 | silence = np.zeros((silence_len), dtype=np.float64) 380 | 381 | # Equal power crossfade 382 | t = np.linspace(-1, 1, fade_len, dtype=np.float64) 383 | fade_in = np.sqrt(0.5 * (1 + t)) 384 | fade_out = np.sqrt(0.5 * (1 - t)) 385 | 386 | # Concat the silence to the fades 387 | fade_in = np.concatenate([silence, fade_in]) 388 | fade_out = np.concatenate([fade_out, silence]) 389 | 390 | # Apply the gain to the overlap samples 391 | y[:, :overlap] *= fade_in 392 | y[:, -overlap:] *= fade_out 393 | 394 | unfolded = np.zeros((total_len), dtype=np.float64) 395 | 396 | # Loop to add up all the samples 397 | for i in range(num_folds): 398 | start = i * (target + overlap) 399 | end = start + target + 2 * overlap 400 | unfolded[start:end] += y[i] 401 | 402 | return unfolded 403 | 404 | def get_step(self) : 405 | return self.step.data.item() 406 | 407 | def checkpoint(self, model_dir, optimizer) : 408 | k_steps = self.get_step() // 1000 409 | self.save(model_dir.joinpath("checkpoint_%dk_steps.pt" % k_steps), optimizer) 410 | 411 | def log(self, path, msg) : 412 | with open(path, 'a') as f: 413 | print(msg, file=f) 414 | 415 | def load(self, path, optimizer) : 416 | checkpoint = torch.load(path) 417 | if "optimizer_state" in checkpoint: 418 | self.load_state_dict(checkpoint["model_state"]) 419 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 420 | else: 421 | # Backwards compatibility 422 | self.load_state_dict(checkpoint) 423 | 424 | def save(self, path, optimizer) : 425 | torch.save({ 426 | "model_state": self.state_dict(), 427 | "optimizer_state": optimizer.state_dict(), 428 | }, path) 429 | 430 | def num_params(self, print_out=True): 431 | parameters = filter(lambda p: p.requires_grad, self.parameters()) 432 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 433 | if print_out : 434 | print('Trainable Parameters: %.3fM' % parameters) 435 | -------------------------------------------------------------------------------- /src/synthesizer/models/tacotron.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from pathlib import Path 7 | from typing import Union 8 | 9 | 10 | class HighwayNetwork(nn.Module): 11 | def __init__(self, size): 12 | super().__init__() 13 | self.W1 = nn.Linear(size, size) 14 | self.W2 = nn.Linear(size, size) 15 | self.W1.bias.data.fill_(0.) 16 | 17 | def forward(self, x): 18 | x1 = self.W1(x) 19 | x2 = self.W2(x) 20 | g = torch.sigmoid(x2) 21 | y = g * F.relu(x1) + (1. 
- g) * x 22 | return y 23 | 24 | 25 | class Encoder(nn.Module): 26 | def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout): 27 | super().__init__() 28 | prenet_dims = (encoder_dims, encoder_dims) 29 | cbhg_channels = encoder_dims 30 | self.embedding = nn.Embedding(num_chars, embed_dims) 31 | self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1], 32 | dropout=dropout) 33 | self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels, 34 | proj_channels=[cbhg_channels, cbhg_channels], 35 | num_highways=num_highways) 36 | 37 | def forward(self, x, speaker_embedding=None): 38 | x = self.embedding(x) 39 | x = self.pre_net(x) 40 | x.transpose_(1, 2) 41 | x = self.cbhg(x) 42 | if speaker_embedding is not None: 43 | x = self.add_speaker_embedding(x, speaker_embedding) 44 | return x 45 | 46 | def add_speaker_embedding(self, x, speaker_embedding): 47 | # SV2TTS 48 | # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims) 49 | # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size) 50 | # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size)) 51 | # This concats the speaker embedding for each char in the encoder output 52 | 53 | # Save the dimensions as human-readable names 54 | batch_size = x.size()[0] 55 | num_chars = x.size()[1] 56 | 57 | if speaker_embedding.dim() == 1: 58 | idx = 0 59 | else: 60 | idx = 1 61 | 62 | # Start by making a copy of each speaker embedding to match the input text length 63 | # The output of this has size (batch_size, num_chars * tts_embed_dims) 64 | speaker_embedding_size = speaker_embedding.size()[idx] 65 | e = speaker_embedding.repeat_interleave(num_chars, dim=idx) 66 | 67 | # Reshape it and transpose 68 | e = e.reshape(batch_size, speaker_embedding_size, num_chars) 69 | e = e.transpose(1, 2) 70 | 71 | # Concatenate the tiled speaker embedding with the encoder output 72 | x = torch.cat((x, e), 2) 73 | return x 74 | 75 | 76 | class BatchNormConv(nn.Module): 77 | def __init__(self, in_channels, out_channels, kernel, relu=True): 78 | super().__init__() 79 | self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False) 80 | self.bnorm = nn.BatchNorm1d(out_channels) 81 | self.relu = relu 82 | 83 | def forward(self, x): 84 | x = self.conv(x) 85 | x = F.relu(x) if self.relu is True else x 86 | return self.bnorm(x) 87 | 88 | 89 | class CBHG(nn.Module): 90 | def __init__(self, K, in_channels, channels, proj_channels, num_highways): 91 | super().__init__() 92 | 93 | # List of all rnns to call `flatten_parameters()` on 94 | self._to_flatten = [] 95 | 96 | self.bank_kernels = [i for i in range(1, K + 1)] 97 | self.conv1d_bank = nn.ModuleList() 98 | for k in self.bank_kernels: 99 | conv = BatchNormConv(in_channels, channels, k) 100 | self.conv1d_bank.append(conv) 101 | 102 | self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) 103 | 104 | self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3) 105 | self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False) 106 | 107 | # Fix the highway input if necessary 108 | if proj_channels[-1] != channels: 109 | self.highway_mismatch = True 110 | self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False) 111 | else: 112 | self.highway_mismatch = False 113 | 114 | self.highways = nn.ModuleList() 115 | for i in range(num_highways): 116 
| hn = HighwayNetwork(channels) 117 | self.highways.append(hn) 118 | 119 | self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True) 120 | self._to_flatten.append(self.rnn) 121 | 122 | # Avoid fragmentation of RNN parameters and associated warning 123 | self._flatten_parameters() 124 | 125 | def forward(self, x): 126 | # Although we `_flatten_parameters()` on init, when using DataParallel 127 | # the model gets replicated, making it no longer guaranteed that the 128 | # weights are contiguous in GPU memory. Hence, we must call it again 129 | self._flatten_parameters() 130 | 131 | # Save these for later 132 | residual = x 133 | seq_len = x.size(-1) 134 | conv_bank = [] 135 | 136 | # Convolution Bank 137 | for conv in self.conv1d_bank: 138 | c = conv(x) # Convolution 139 | conv_bank.append(c[:, :, :seq_len]) 140 | 141 | # Stack along the channel axis 142 | conv_bank = torch.cat(conv_bank, dim=1) 143 | 144 | # Discard the extra padding so the length matches the residual 145 | x = self.maxpool(conv_bank)[:, :, :seq_len] 146 | 147 | # Conv1d projections 148 | x = self.conv_project1(x) 149 | x = self.conv_project2(x) 150 | 151 | # Residual Connect 152 | x = x + residual 153 | 154 | # Through the highways 155 | x = x.transpose(1, 2) 156 | if self.highway_mismatch is True: 157 | x = self.pre_highway(x) 158 | for h in self.highways: x = h(x) 159 | 160 | # And then the RNN 161 | x, _ = self.rnn(x) 162 | return x 163 | 164 | def _flatten_parameters(self): 165 | """Calls `flatten_parameters` on all the rnns used by the CBHG. Used 166 | to improve efficiency and avoid PyTorch yelling at us.""" 167 | [m.flatten_parameters() for m in self._to_flatten] 168 | 169 | class PreNet(nn.Module): 170 | def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5): 171 | super().__init__() 172 | self.fc1 = nn.Linear(in_dims, fc1_dims) 173 | self.fc2 = nn.Linear(fc1_dims, fc2_dims) 174 | self.p = dropout 175 | 176 | def forward(self, x): 177 | x = self.fc1(x) 178 | x = F.relu(x) 179 | x = F.dropout(x, self.p, training=True) 180 | x = self.fc2(x) 181 | x = F.relu(x) 182 | x = F.dropout(x, self.p, training=True) 183 | return x 184 | 185 | 186 | class Attention(nn.Module): 187 | def __init__(self, attn_dims): 188 | super().__init__() 189 | self.W = nn.Linear(attn_dims, attn_dims, bias=False) 190 | self.v = nn.Linear(attn_dims, 1, bias=False) 191 | 192 | def forward(self, encoder_seq_proj, query, t): 193 | 194 | # print(encoder_seq_proj.shape) 195 | # Transform the query vector 196 | query_proj = self.W(query).unsqueeze(1) 197 | 198 | # Compute the scores 199 | u = self.v(torch.tanh(encoder_seq_proj + query_proj)) 200 | scores = F.softmax(u, dim=1) 201 | 202 | return scores.transpose(1, 2) 203 | 204 | 205 | class LSA(nn.Module): 206 | def __init__(self, attn_dim, kernel_size=31, filters=32): 207 | super().__init__() 208 | self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True) 209 | self.L = nn.Linear(filters, attn_dim, bias=False) 210 | self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term 211 | self.v = nn.Linear(attn_dim, 1, bias=False) 212 | self.cumulative = None 213 | self.attention = None 214 | 215 | def init_attention(self, encoder_seq_proj): 216 | device = next(self.parameters()).device # use same device as parameters 217 | b, t, c = encoder_seq_proj.size() 218 | self.cumulative = torch.zeros(b, t, device=device) 219 | self.attention = torch.zeros(b, t, device=device) 220 | 221 | def forward(self, encoder_seq_proj,
query, t, chars): 222 | 223 | if t == 0: self.init_attention(encoder_seq_proj) 224 | 225 | processed_query = self.W(query).unsqueeze(1) 226 | 227 | location = self.cumulative.unsqueeze(1) 228 | processed_loc = self.L(self.conv(location).transpose(1, 2)) 229 | 230 | u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc)) 231 | u = u.squeeze(-1) 232 | 233 | # Mask zero padding chars 234 | u = u * (chars != 0).float() 235 | 236 | # Smooth Attention 237 | # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True) 238 | scores = F.softmax(u, dim=1) 239 | self.attention = scores 240 | self.cumulative = self.cumulative + self.attention 241 | 242 | return scores.unsqueeze(-1).transpose(1, 2) 243 | 244 | 245 | class Decoder(nn.Module): 246 | # Class variable because its value doesn't change between instances, 247 | # yet ought to be scoped by the class because it's a property of a Decoder 248 | max_r = 20 249 | def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims, 250 | dropout, speaker_embedding_size): 251 | super().__init__() 252 | self.register_buffer("r", torch.tensor(1, dtype=torch.int)) 253 | self.n_mels = n_mels 254 | prenet_dims = (decoder_dims * 2, decoder_dims * 2) 255 | self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1], 256 | dropout=dropout) 257 | self.attn_net = LSA(decoder_dims) 258 | self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims) 259 | self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims) 260 | self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims) 261 | self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims) 262 | self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False) 263 | self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1) 264 | 265 | def zoneout(self, prev, current, p=0.1): 266 | device = next(self.parameters()).device # Use same device as parameters 267 | mask = torch.zeros(prev.size(), device=device).bernoulli_(p) 268 | return prev * mask + current * (1 - mask) 269 | 270 | def forward(self, encoder_seq, encoder_seq_proj, prenet_in, 271 | hidden_states, cell_states, context_vec, t, chars): 272 | 273 | # Need this for reshaping mels 274 | batch_size = encoder_seq.size(0) 275 | 276 | # Unpack the hidden and cell states 277 | attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states 278 | rnn1_cell, rnn2_cell = cell_states 279 | 280 | # PreNet for the Attention RNN 281 | prenet_out = self.prenet(prenet_in) 282 | 283 | # Compute the Attention RNN hidden state 284 | attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1) 285 | attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden) 286 | 287 | # Compute the attention scores 288 | scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars) 289 | 290 | # Dot product to create the context vector 291 | context_vec = scores @ encoder_seq 292 | context_vec = context_vec.squeeze(1) 293 | 294 | # Concat Attention RNN output w.
Context Vector & project 295 | x = torch.cat([context_vec, attn_hidden], dim=1) 296 | x = self.rnn_input(x) 297 | 298 | # Compute first Residual RNN 299 | rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell)) 300 | if self.training: 301 | rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next) 302 | else: 303 | rnn1_hidden = rnn1_hidden_next 304 | x = x + rnn1_hidden 305 | 306 | # Compute second Residual RNN 307 | rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell)) 308 | if self.training: 309 | rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next) 310 | else: 311 | rnn2_hidden = rnn2_hidden_next 312 | x = x + rnn2_hidden 313 | 314 | # Project Mels 315 | mels = self.mel_proj(x) 316 | mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r] 317 | hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden) 318 | cell_states = (rnn1_cell, rnn2_cell) 319 | 320 | # Stop token prediction 321 | s = torch.cat((x, context_vec), dim=1) 322 | s = self.stop_proj(s) 323 | stop_tokens = torch.sigmoid(s) 324 | 325 | return mels, scores, hidden_states, cell_states, context_vec, stop_tokens 326 | 327 | 328 | class Tacotron(nn.Module): 329 | def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels, 330 | fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways, 331 | dropout, stop_threshold, speaker_embedding_size): 332 | super().__init__() 333 | self.n_mels = n_mels 334 | self.lstm_dims = lstm_dims 335 | self.encoder_dims = encoder_dims 336 | self.decoder_dims = decoder_dims 337 | self.speaker_embedding_size = speaker_embedding_size 338 | self.encoder = Encoder(embed_dims, num_chars, encoder_dims, 339 | encoder_K, num_highways, dropout) 340 | self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False) 341 | self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims, 342 | dropout, speaker_embedding_size) 343 | self.postnet = CBHG(postnet_K, n_mels, postnet_dims, 344 | [postnet_dims, fft_bins], num_highways) 345 | self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False) 346 | 347 | self.init_model() 348 | self.num_params(print_out=False) 349 | 350 | self.register_buffer("step", torch.zeros(1, dtype=torch.long)) 351 | self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32)) 352 | 353 | @property 354 | def r(self): 355 | return self.decoder.r.item() 356 | 357 | @r.setter 358 | def r(self, value): 359 | self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False) 360 | 361 | def forward(self, x, m, speaker_embedding): 362 | device = next(self.parameters()).device # use same device as parameters 363 | 364 | self.step += 1 365 | batch_size, _, steps = m.size() 366 | 367 | # Initialise all hidden states and pack into tuple 368 | attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device) 369 | rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) 370 | rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) 371 | hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden) 372 | 373 | # Initialise all lstm cell states and pack into tuple 374 | rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device) 375 | rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device) 376 | cell_states = (rnn1_cell, rnn2_cell) 377 | 378 | # Frame for start of decoder loop 379 | go_frame = torch.zeros(batch_size, self.n_mels, device=device) 380 | 381 | # Need an initial context vector 382 | context_vec = 
torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device) 383 | 384 | # SV2TTS: Run the encoder with the speaker embedding 385 | # The projection avoids unnecessary matmuls in the decoder loop 386 | encoder_seq = self.encoder(x, speaker_embedding) 387 | encoder_seq_proj = self.encoder_proj(encoder_seq) 388 | 389 | # Need a couple of lists for outputs 390 | mel_outputs, attn_scores, stop_outputs = [], [], [] 391 | 392 | # Run the decoder loop 393 | for t in range(0, steps, self.r): 394 | prenet_in = m[:, :, t - 1] if t > 0 else go_frame 395 | mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \ 396 | self.decoder(encoder_seq, encoder_seq_proj, prenet_in, 397 | hidden_states, cell_states, context_vec, t, x) 398 | mel_outputs.append(mel_frames) 399 | attn_scores.append(scores) 400 | stop_outputs.extend([stop_tokens] * self.r) 401 | 402 | # Concat the mel outputs into sequence 403 | mel_outputs = torch.cat(mel_outputs, dim=2) 404 | 405 | # Post-Process for Linear Spectrograms 406 | postnet_out = self.postnet(mel_outputs) 407 | linear = self.post_proj(postnet_out) 408 | linear = linear.transpose(1, 2) 409 | 410 | # For easy visualisation 411 | attn_scores = torch.cat(attn_scores, 1) 412 | # attn_scores = attn_scores.cpu().data.numpy() 413 | stop_outputs = torch.cat(stop_outputs, 1) 414 | 415 | return mel_outputs, linear, attn_scores, stop_outputs 416 | 417 | def generate(self, x, speaker_embedding=None, steps=2000): 418 | self.eval() 419 | device = next(self.parameters()).device # use same device as parameters 420 | 421 | batch_size, _ = x.size() 422 | 423 | # Need to initialise all hidden states and pack into tuple for tidiness 424 | attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device) 425 | rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) 426 | rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device) 427 | hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden) 428 | 429 | # Need to initialise all lstm cell states and pack into tuple for tidiness 430 | rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device) 431 | rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device) 432 | cell_states = (rnn1_cell, rnn2_cell) 433 | 434 | # Need a Frame for start of decoder loop 435 | go_frame = torch.zeros(batch_size, self.n_mels, device=device) 436 | 437 | # Need an initial context vector 438 | context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device) 439 | 440 | # SV2TTS: Run the encoder with the speaker embedding 441 | # The projection avoids unnecessary matmuls in the decoder loop 442 | encoder_seq = self.encoder(x, speaker_embedding) 443 | encoder_seq_proj = self.encoder_proj(encoder_seq) 444 | 445 | # Need a couple of lists for outputs 446 | mel_outputs, attn_scores, stop_outputs = [], [], [] 447 | 448 | # Run the decoder loop 449 | for t in range(0, steps, self.r): 450 | prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame 451 | mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \ 452 | self.decoder(encoder_seq, encoder_seq_proj, prenet_in, 453 | hidden_states, cell_states, context_vec, t, x) 454 | mel_outputs.append(mel_frames) 455 | attn_scores.append(scores) 456 | stop_outputs.extend([stop_tokens] * self.r) 457 | # Stop the loop when all stop tokens in batch exceed threshold 458 | if (stop_tokens > 0.5).all() and t > 10: break 459 | 460 | # Concat the mel outputs into sequence 461 | mel_outputs =
torch.cat(mel_outputs, dim=2) 462 | 463 | # Post-Process for Linear Spectrograms 464 | postnet_out = self.postnet(mel_outputs) 465 | linear = self.post_proj(postnet_out) 466 | 467 | 468 | linear = linear.transpose(1, 2) 469 | 470 | # For easy visualisation 471 | attn_scores = torch.cat(attn_scores, 1) 472 | stop_outputs = torch.cat(stop_outputs, 1) 473 | 474 | self.train() 475 | 476 | return mel_outputs, linear, attn_scores 477 | 478 | def init_model(self): 479 | for p in self.parameters(): 480 | if p.dim() > 1: nn.init.xavier_uniform_(p) 481 | 482 | def get_step(self): 483 | return self.step.data.item() 484 | 485 | def reset_step(self): 486 | # assignment to parameters or buffers is overloaded, updates internal dict entry 487 | self.step = self.step.data.new_tensor(1) 488 | 489 | def log(self, path, msg): 490 | with open(path, "a") as f: 491 | print(msg, file=f) 492 | 493 | def load(self, path, optimizer=None): 494 | # Use device of model params as location for loaded state 495 | device = next(self.parameters()).device 496 | checkpoint = torch.load(str(path), map_location=device) 497 | self.load_state_dict(checkpoint["model_state"]) 498 | 499 | if "optimizer_state" in checkpoint and optimizer is not None: 500 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 501 | 502 | def save(self, path, optimizer=None): 503 | if optimizer is not None: 504 | torch.save({ 505 | "model_state": self.state_dict(), 506 | "optimizer_state": optimizer.state_dict(), 507 | }, str(path)) 508 | else: 509 | torch.save({ 510 | "model_state": self.state_dict(), 511 | }, str(path)) 512 | 513 | 514 | def num_params(self, print_out=True): 515 | parameters = filter(lambda p: p.requires_grad, self.parameters()) 516 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 517 | if print_out: 518 | print("Trainable Parameters: %.3fM" % parameters) 519 | return parameters 520 | --------------------------------------------------------------------------------
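
The two model files above are normally driven through the repository's inference wrappers (src/synthesizer/inference.py and src/vocoder/inference.py), but the sketch below shows how the classes fit together at the lowest level: a Tacotron instance maps a character sequence plus a speaker embedding to a mel spectrogram, and a WaveRNN instance vocodes that spectrogram with generate(). Every hyperparameter value, checkpoint path and dummy input here is an illustrative assumption rather than the repo's actual configuration (the real values live in src/synthesizer/hparams.py and src/vocoder/hparams.py), and any mel scaling the vocoder expects is glossed over.

# Minimal end-to-end sketch (illustrative assumptions throughout): text + speaker embedding
# -> Tacotron mel spectrogram -> WaveRNN waveform. Assumes src/ is on the Python path;
# hyperparameters, checkpoint paths and dummy inputs are placeholders, not the repo's config.
import torch
from synthesizer.models.tacotron import Tacotron
from vocoder.models.fatchord_version import WaveRNN

tacotron = Tacotron(embed_dims=512, num_chars=66, encoder_dims=256, decoder_dims=128,
                    n_mels=80, fft_bins=80,  # fft_bins = n_mels: the postnet CBHG adds a residual of its mel input
                    postnet_dims=512, encoder_K=5, lstm_dims=1024, postnet_K=5,
                    num_highways=4, dropout=0.5, stop_threshold=-3.4, speaker_embedding_size=256)
tacotron.load("synthesizer/saved_models/pretrained/pretrained.pt")  # hypothetical checkpoint path

wavernn = WaveRNN(rnn_dims=512, fc_dims=512, bits=9, pad=2, upsample_factors=(5, 5, 8),
                  feat_dims=80, compute_dims=128, res_out_dims=128, res_blocks=10,
                  hop_length=200, sample_rate=16000, mode='RAW')
wavernn.load("vocoder/saved_models/pretrained/pretrained.pt",       # hypothetical checkpoint path
             torch.optim.Adam(wavernn.parameters()))                # WaveRNN.load() expects an optimizer

chars = torch.randint(1, 66, (1, 50))      # stand-in for an encoded character sequence
speaker_embedding = torch.rand(1, 256)     # stand-in for a speaker-encoder embedding

if torch.cuda.is_available():              # WaveRNN.generate() allocates CUDA tensors whenever a GPU is visible
    tacotron, wavernn = tacotron.cuda(), wavernn.cuda()
    chars, speaker_embedding = chars.cuda(), speaker_embedding.cuda()

# Tacotron.generate() returns (mels, linear, attn_scores); mels has shape (1, n_mels, T)
mels, _, _ = tacotron.generate(chars, speaker_embedding)

# Batched vocoding: fold_with_overlap() splits the mels into overlapping chunks that are
# sampled in parallel, and xfade_and_unfold() crossfades them back into a single waveform
wav = wavernn.generate(mels, batched=True, target=8000, overlap=800, mu_law=True)

Setting batched=False instead runs the sample loop over the whole utterance in one fold, which is slower but skips the fold/crossfade step entirely.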