├── utils ├── __init__.py ├── argutils.py ├── profiler.py └── logmmse.py ├── encoder ├── __init__.py ├── data_objects │ ├── __init__.py │ ├── speaker_batch.py │ ├── utterance.py │ ├── random_cycler.py │ ├── speaker.py │ └── speaker_verification_dataset.py ├── params_model.py ├── params_data.py ├── config.py ├── plot_umap.py ├── audio.py ├── train.py ├── model.py ├── visualizations.py ├── preprocess.py └── inference.py ├── .gitattributes ├── requirements.txt ├── toolbox ├── utterance.py ├── __init__.py └── ui.py ├── .gitignore ├── README.md ├── LICENSE.txt ├── encoder_train.py ├── encoder_preprocess.py └── generate_embeddings.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-vendored 2 | -------------------------------------------------------------------------------- /encoder/data_objects/__init__.py: -------------------------------------------------------------------------------- 1 | from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset 2 | from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.0.1 2 | umap-learn 3 | visdom 4 | webrtcvad 5 | librosa>=0.5.1 6 | matplotlib>=2.0.2 7 | numpy>=1.14.0 8 | scipy>=1.0.0 9 | tqdm 10 | sounddevice 11 | Unidecode 12 | inflect 13 | PyQt5 14 | multiprocess 15 | numba 16 | -------------------------------------------------------------------------------- /toolbox/utterance.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | Utterance = namedtuple("Utterance", "name speaker_name wav spec embed partial_embeds synth") 4 | Utterance.__eq__ = lambda x, y: x.name == y.name 5 | Utterance.__hash__ = lambda x: hash(x.name) 6 | -------------------------------------------------------------------------------- /encoder/params_model.py: -------------------------------------------------------------------------------- 1 | 2 | ## Model parameters 3 | model_hidden_size = 256 4 | model_embedding_size = 256 5 | model_num_layers = 3 6 | 7 | 8 | ## Training parameters 9 | learning_rate_init = 1e-4 10 | speakers_per_batch = 64 11 | utterances_per_speaker = 10 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.aux 3 | *.log 4 | *.out 5 | *.synctex.gz 6 | *.suo 7 | *__pycache__ 8 | *.idea 9 | *.ipynb_checkpoints 10 | *.pickle 11 | *.npy 12 | *.blg 13 | *.bbl 14 | *.bcf 15 | *.toc 16 | *.wav 17 | *.sh 18 | encoder/saved_models/* 19 | synthesizer/saved_models/* 20 | vocoder/saved_models/* 21 | -------------------------------------------------------------------------------- /encoder/data_objects/speaker_batch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List 3 | from 
encoder.data_objects.speaker import Speaker 4 | 5 | class SpeakerBatch: 6 | def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int): 7 | self.speakers = speakers 8 | self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers} 9 | 10 | # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with 11 | # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40) 12 | self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]]) 13 | -------------------------------------------------------------------------------- /encoder/data_objects/utterance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Utterance: 5 | def __init__(self, frames_fpath, wave_fpath): 6 | self.frames_fpath = frames_fpath 7 | self.wave_fpath = wave_fpath 8 | 9 | def get_frames(self): 10 | return np.load(self.frames_fpath) 11 | 12 | def random_partial(self, n_frames): 13 | """ 14 | Crops the frames into a partial utterance of n_frames 15 | 16 | :param n_frames: The number of frames of the partial utterance 17 | :return: the partial utterance frames and a tuple indicating the start and end of the 18 | partial utterance in the complete utterance. 19 | """ 20 | frames = self.get_frames() 21 | if frames.shape[0] == n_frames: 22 | start = 0 23 | else: 24 | start = np.random.randint(0, frames.shape[0] - n_frames) 25 | end = start + n_frames 26 | return frames[start:end], (start, end) -------------------------------------------------------------------------------- /encoder/params_data.py: -------------------------------------------------------------------------------- 1 | 2 | ## Mel-filterbank 3 | mel_window_length = 25 # In milliseconds 4 | mel_window_step = 10 # In milliseconds 5 | mel_n_channels = 40 6 | 7 | 8 | ## Audio 9 | sampling_rate = 16000 10 | # Number of spectrogram frames in a partial utterance 11 | partials_n_frames = 160 # 1600 ms 12 | # Number of spectrogram frames at inference 13 | inference_n_frames = 80 # 800 ms 14 | 15 | 16 | ## Voice Activation Detection 17 | # Window size of the VAD. Must be either 10, 20 or 30 milliseconds. 18 | # This sets the granularity of the VAD. Should not need to be changed. 19 | vad_window_length = 30 # In milliseconds 20 | # Number of frames to average together when performing the moving average smoothing. 21 | # The larger this value, the larger the VAD variations must be to not get smoothed out. 22 | vad_moving_average_width = 8 23 | # Maximum number of consecutive silent frames a segment can have. 
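# Expressed in VAD windows (vad_window_length ms each): with the default of 6, at most
# ~180 ms of contiguous silence is kept between voiced segments. See trim_long_silences
# in encoder/audio.py, where this value sets the width of the dilation applied to the
# voiced mask.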
24 | vad_max_silence_length = 6 25 | 26 | 27 | ## Audio volume normalization 28 | audio_norm_target_dBFS = -30 29 | 30 | -------------------------------------------------------------------------------- /utils/argutils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import argparse 4 | 5 | _type_priorities = [ # In decreasing order 6 | Path, 7 | str, 8 | int, 9 | float, 10 | bool, 11 | ] 12 | 13 | def _priority(o): 14 | p = next((i for i, t in enumerate(_type_priorities) if type(o) is t), None) 15 | if p is not None: 16 | return p 17 | p = next((i for i, t in enumerate(_type_priorities) if isinstance(o, t)), None) 18 | if p is not None: 19 | return p 20 | return len(_type_priorities) 21 | 22 | def print_args(args: argparse.Namespace, parser=None): 23 | args = vars(args) 24 | if parser is None: 25 | priorities = list(map(_priority, args.values())) 26 | else: 27 | all_params = [a.dest for g in parser._action_groups for a in g._group_actions ] 28 | priority = lambda p: all_params.index(p) if p in all_params else len(all_params) 29 | priorities = list(map(priority, args.keys())) 30 | 31 | pad = max(map(len, args.keys())) + 3 32 | indices = np.lexsort((list(args.keys()), priorities)) 33 | items = list(args.items()) 34 | 35 | print("Arguments:") 36 | for i in indices: 37 | param, value = items[i] 38 | print(" {0}:{1}{2}".format(param, ' ' * (pad - len(param)), value)) 39 | print("") 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GE2E Speaker Embeddings 2 | This repository is an implementation of Generalized End-to-End Loss for Speaker Verification 3 | paper: [GE2E](https://arxiv.org/abs/1710.10467) 4 | 5 | ### Requirements 6 | 7 | **Python 3.6 +**. 8 | 9 | Run `pip install -r requirements.txt` to install the necessary packages. 10 | 11 | A GPU is mandatory, but you don't necessarily need a high tier GPU if you only want to use the toolbox. 12 | 13 | ### Datasets 14 | 15 | Ideally, all your datasets are kept under a same directory i.e., . All prepreprocessing scripts will, by default, output the clean data to a new directory SV2TTS created in your datasets root directory. Inside this directory will be created a directory for the encoder. 
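For example, with a hypothetical datasets root directory named `datasets_root`, the expected layout (as also described in the help text of `encoder_preprocess.py`) is:

    datasets_root
        LibriSpeech
            train-other-500
        VoxCeleb1
            wav
            vox1_meta.csv
        VoxCeleb2
            dev
        SV2TTS
            encoder        (created by encoder_preprocess.py)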
16 | 17 | For the encoder: 18 | 19 | LibriSpeech: train-other-500 (extract as LibriSpeech/train-other-500) 20 | VoxCeleb1: Dev A - D as well as the metadata file (extract as VoxCeleb1/wav and VoxCeleb1/vox1_meta.csv) 21 | VoxCeleb2: Dev A - H (extract as VoxCeleb2/dev) 22 | 23 | ### Preprocessing and training 24 | 25 | python encoder_preprocess.py 26 | python encoder_train.py my_run /SV2TTS/encoder 27 | 28 | ### Generate speaker embeddings 29 | 30 | python generate_embeddings.py 31 | 32 | ## For more details 33 | (https://github.com/CorentinJ/Real-Time-Voice-Cloning) 34 | 35 | -------------------------------------------------------------------------------- /encoder/config.py: -------------------------------------------------------------------------------- 1 | librispeech_datasets = { 2 | "train": { 3 | "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"], 4 | "other": ["LibriSpeech/train-other-500"] 5 | }, 6 | "test": { 7 | "clean": ["LibriSpeech/test-clean"], 8 | "other": ["LibriSpeech/test-other"] 9 | }, 10 | "dev": { 11 | "clean": ["LibriSpeech/dev-clean"], 12 | "other": ["LibriSpeech/dev-other"] 13 | }, 14 | } 15 | libritts_datasets = { 16 | "train": { 17 | "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"], 18 | "other": ["LibriTTS/train-other-500"] 19 | }, 20 | "test": { 21 | "clean": ["LibriTTS/test-clean"], 22 | "other": ["LibriTTS/test-other"] 23 | }, 24 | "dev": { 25 | "clean": ["LibriTTS/dev-clean"], 26 | "other": ["LibriTTS/dev-other"] 27 | }, 28 | } 29 | voxceleb_datasets = { 30 | "voxceleb1" : { 31 | "train": ["VoxCeleb1/wav"], 32 | "test": ["VoxCeleb1/test_wav"] 33 | }, 34 | "voxceleb2" : { 35 | "train": ["VoxCeleb2/dev/aac"], 36 | "test": ["VoxCeleb2/test_wav"] 37 | } 38 | } 39 | 40 | other_datasets = [ 41 | "LJSpeech-1.1", 42 | "VCTK-Corpus/wav48", 43 | ] 44 | 45 | anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"] 46 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Modified & original work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ) 4 | Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah) 5 | Original work Copyright (c) 2019 fatchord (https://github.com/fatchord) 6 | Original work Copyright (c) 2015 braindead (https://github.com/braindead) 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 25 | -------------------------------------------------------------------------------- /encoder/data_objects/random_cycler.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | class RandomCycler: 4 | """ 5 | Creates an internal copy of a sequence and allows access to its items in a constrained random 6 | order. For a source sequence of n items and one or several consecutive queries of a total 7 | of m items, the following guarantees hold (one implies the other): 8 | - Each item will be returned between m // n and ((m - 1) // n) + 1 times. 9 | - Between two appearances of the same item, there may be at most 2 * (n - 1) other items. 10 | """ 11 | 12 | def __init__(self, source): 13 | if len(source) == 0: 14 | raise Exception("Can't create RandomCycler from an empty collection") 15 | self.all_items = list(source) 16 | self.next_items = [] 17 | 18 | def sample(self, count: int): 19 | shuffle = lambda l: random.sample(l, len(l)) 20 | 21 | out = [] 22 | while count > 0: 23 | if count >= len(self.all_items): 24 | out.extend(shuffle(list(self.all_items))) 25 | count -= len(self.all_items) 26 | continue 27 | n = min(count, len(self.next_items)) 28 | out.extend(self.next_items[:n]) 29 | count -= n 30 | self.next_items = self.next_items[n:] 31 | if len(self.next_items) == 0: 32 | self.next_items = shuffle(list(self.all_items)) 33 | return out 34 | 35 | def __next__(self): 36 | return self.sample(1)[0] 37 | 38 | -------------------------------------------------------------------------------- /utils/profiler.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter as timer 2 | from collections import OrderedDict 3 | import numpy as np 4 | 5 | 6 | class Profiler: 7 | def __init__(self, summarize_every=5, disabled=False): 8 | self.last_tick = timer() 9 | self.logs = OrderedDict() 10 | self.summarize_every = summarize_every 11 | self.disabled = disabled 12 | 13 | def tick(self, name): 14 | if self.disabled: 15 | return 16 | 17 | # Log the time needed to execute that function 18 | if not name in self.logs: 19 | self.logs[name] = [] 20 | if len(self.logs[name]) >= self.summarize_every: 21 | self.summarize() 22 | self.purge_logs() 23 | self.logs[name].append(timer() - self.last_tick) 24 | 25 | self.reset_timer() 26 | 27 | def purge_logs(self): 28 | for name in self.logs: 29 | self.logs[name].clear() 30 | 31 | def reset_timer(self): 32 | self.last_tick = timer() 33 | 34 | def summarize(self): 35 | n = max(map(len, self.logs.values())) 36 | assert n == self.summarize_every 37 | print("\nAverage execution time over %d steps:" % n) 38 | 39 | name_msgs = ["%s (%d/%d):" % (name, len(deltas), n) for name, deltas in self.logs.items()] 40 | pad = max(map(len, name_msgs)) 41 | for name_msg, deltas in zip(name_msgs, self.logs.values()): 42 | print(" %s mean: %4.0fms std: %4.0fms" % 43 | (name_msg.ljust(pad), np.mean(deltas) * 1000, np.std(deltas) * 1000)) 44 | print("", flush=True) 45 | -------------------------------------------------------------------------------- /encoder/data_objects/speaker.py: -------------------------------------------------------------------------------- 1 | from encoder.data_objects.random_cycler import 
RandomCycler 2 | from encoder.data_objects.utterance import Utterance 3 | from pathlib import Path 4 | 5 | # Contains the set of utterances of a single speaker 6 | class Speaker: 7 | def __init__(self, root: Path): 8 | self.root = root 9 | self.name = root.name 10 | self.utterances = None 11 | self.utterance_cycler = None 12 | 13 | def _load_utterances(self): 14 | with self.root.joinpath("_sources.txt").open("r") as sources_file: 15 | sources = [l.split(",") for l in sources_file] 16 | sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources} 17 | self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()] 18 | self.utterance_cycler = RandomCycler(self.utterances) 19 | 20 | def random_partial(self, count, n_frames): 21 | """ 22 | Samples a batch of unique partial utterances from the disk in a way that all 23 | utterances come up at least once every two cycles and in a random order every time. 24 | 25 | :param count: The number of partial utterances to sample from the set of utterances from 26 | that speaker. Utterances are guaranteed not to be repeated if is not larger than 27 | the number of utterances available. 28 | :param n_frames: The number of frames in the partial utterance. 29 | :return: A list of tuples (utterance, frames, range) where utterance is an Utterance, 30 | frames are the frames of the partial utterances and range is the range of the partial 31 | utterance with regard to the complete utterance. 32 | """ 33 | if self.utterances is None: 34 | self._load_utterances() 35 | 36 | utterances = self.utterance_cycler.sample(count) 37 | 38 | a = [(u,) + u.random_partial(n_frames) for u in utterances] 39 | 40 | return a 41 | -------------------------------------------------------------------------------- /encoder/data_objects/speaker_verification_dataset.py: -------------------------------------------------------------------------------- 1 | from encoder.data_objects.random_cycler import RandomCycler 2 | from encoder.data_objects.speaker_batch import SpeakerBatch 3 | from encoder.data_objects.speaker import Speaker 4 | from encoder.params_data import partials_n_frames 5 | from torch.utils.data import Dataset, DataLoader 6 | from pathlib import Path 7 | 8 | # TODO: improve with a pool of speakers for data efficiency 9 | 10 | class SpeakerVerificationDataset(Dataset): 11 | def __init__(self, datasets_root: Path): 12 | self.root = datasets_root 13 | speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()] 14 | if len(speaker_dirs) == 0: 15 | raise Exception("No speakers found. 
Make sure you are pointing to the directory " 16 | "containing all preprocessed speaker directories.") 17 | self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs] 18 | self.speaker_cycler = RandomCycler(self.speakers) 19 | 20 | def __len__(self): 21 | return int(1e10) 22 | 23 | def __getitem__(self, index): 24 | return next(self.speaker_cycler) 25 | 26 | def get_logs(self): 27 | log_string = "" 28 | for log_fpath in self.root.glob("*.txt"): 29 | with log_fpath.open("r") as log_file: 30 | log_string += "".join(log_file.readlines()) 31 | return log_string 32 | 33 | 34 | class SpeakerVerificationDataLoader(DataLoader): 35 | def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None, 36 | batch_sampler=None, num_workers=0, pin_memory=False, timeout=0, 37 | worker_init_fn=None): 38 | self.utterances_per_speaker = utterances_per_speaker 39 | 40 | super().__init__( 41 | dataset=dataset, 42 | batch_size=speakers_per_batch, 43 | shuffle=False, 44 | sampler=sampler, 45 | batch_sampler=batch_sampler, 46 | num_workers=num_workers, 47 | collate_fn=self.collate, 48 | pin_memory=pin_memory, 49 | drop_last=False, 50 | timeout=timeout, 51 | worker_init_fn=worker_init_fn 52 | ) 53 | 54 | def collate(self, speakers): 55 | return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames) 56 | -------------------------------------------------------------------------------- /encoder_train.py: -------------------------------------------------------------------------------- 1 | from utils.argutils import print_args 2 | from encoder.train import train 3 | from pathlib import Path 4 | import argparse 5 | 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser( 9 | description="Trains the speaker encoder. You must have run encoder_preprocess.py first.", 10 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 11 | ) 12 | 13 | parser.add_argument("run_id", type=str, help= \ 14 | "Name for this model instance. If a model state from the same run ID was previously " 15 | "saved, the training will restart from there. Pass -f to overwrite saved states and " 16 | "restart from scratch.") 17 | parser.add_argument("clean_data_root", type=Path, help= \ 18 | "Path to the output directory of encoder_preprocess.py. If you left the default " 19 | "output directory when preprocessing, it should be /SV2TTS/encoder/.") 20 | parser.add_argument("-m", "--models_dir", type=Path, default="encoder/saved_models/", help=\ 21 | "Path to the output directory that will contain the saved model weights, as well as " 22 | "backups of those weights and plots generated during training.") 23 | parser.add_argument("-v", "--vis_every", type=int, default=10, help= \ 24 | "Number of steps between updates of the loss and the plots.") 25 | parser.add_argument("-u", "--umap_every", type=int, default=100, help= \ 26 | "Number of steps between updates of the umap projection. Set to 0 to never update the " 27 | "projections.") 28 | parser.add_argument("-s", "--save_every", type=int, default=500, help= \ 29 | "Number of steps between updates of the model on the disk. Set to 0 to never save the " 30 | "model.") 31 | parser.add_argument("-b", "--backup_every", type=int, default=7500, help= \ 32 | "Number of steps between backups of the model. 
Set to 0 to never make backups of the " 33 | "model.") 34 | parser.add_argument("-f", "--force_restart", action="store_true", help= \ 35 | "Do not load any saved model.") 36 | parser.add_argument("--visdom_server", type=str, default="http://localhost") 37 | parser.add_argument("--no_visdom", action="store_true", help= \ 38 | "Disable visdom.") 39 | args = parser.parse_args() 40 | 41 | # Process the arguments 42 | args.models_dir.mkdir(exist_ok=True) 43 | 44 | # Run the training 45 | print_args(args, parser) 46 | train(**vars(args)) 47 | -------------------------------------------------------------------------------- /encoder_preprocess.py: -------------------------------------------------------------------------------- 1 | from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2 2 | from utils.argutils import print_args 3 | from pathlib import Path 4 | import argparse 5 | 6 | 7 | if __name__ == "__main__": 8 | class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter): 9 | pass 10 | 11 | parser = argparse.ArgumentParser( 12 | description="Preprocesses audio files from datasets, encodes them as mel spectrograms and " 13 | "writes them to the disk. This will allow you to train the encoder. The " 14 | "datasets required are at least one of VoxCeleb1, VoxCeleb2 and LibriSpeech. " 15 | "Ideally, you should have all three. You should extract them as they are " 16 | "after having downloaded them and put them in a same directory, e.g.:\n" 17 | "-[datasets_root]\n" 18 | " -LibriSpeech\n" 19 | " -train-other-500\n" 20 | " -VoxCeleb1\n" 21 | " -wav\n" 22 | " -vox1_meta.csv\n" 23 | " -VoxCeleb2\n" 24 | " -dev", 25 | formatter_class=MyFormatter 26 | ) 27 | parser.add_argument("datasets_root", type=Path, help=\ 28 | "Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.") 29 | parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\ 30 | "Path to the output directory that will contain the mel spectrograms. If left out, " 31 | "defaults to /SV2TTS/encoder/") 32 | parser.add_argument("-d", "--datasets", type=str, 33 | default="librispeech_other,voxceleb1,voxceleb2", help=\ 34 | "Comma-separated list of the name of the datasets you want to preprocess. Only the train " 35 | "set of these datasets will be used. Possible names: librispeech_other, voxceleb1, " 36 | "voxceleb2.") 37 | parser.add_argument("-s", "--skip_existing", action="store_true", help=\ 38 | "Whether to skip existing output files with the same name. 
Useful if this script was " 39 | "interrupted.") 40 | args = parser.parse_args() 41 | 42 | # Process the arguments 43 | args.datasets = args.datasets.split(",") 44 | if not hasattr(args, "out_dir"): 45 | args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder") 46 | assert args.datasets_root.exists() 47 | args.out_dir.mkdir(exist_ok=True, parents=True) 48 | 49 | # Preprocess the datasets 50 | print_args(args, parser) 51 | preprocess_func = { 52 | "librispeech_other": preprocess_librispeech, 53 | "voxceleb1": preprocess_voxceleb1, 54 | "voxceleb2": preprocess_voxceleb2, 55 | } 56 | args = vars(args) 57 | for dataset in args.pop("datasets"): 58 | print("Preprocessing %s" % dataset) 59 | preprocess_func[dataset](**args) 60 | -------------------------------------------------------------------------------- /generate_embeddings.py: -------------------------------------------------------------------------------- 1 | from encoder.params_model import model_embedding_size as speaker_embedding_size 2 | from encoder import inference as encoder 3 | from pathlib import Path 4 | import numpy as np 5 | import librosa 6 | import argparse 7 | import torch 8 | import sys 9 | import os 10 | import glob 11 | 12 | 13 | if __name__ == '__main__': 14 | # Info & args 15 | parser = argparse.ArgumentParser( 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 17 | ) 18 | parser.add_argument("-e", "--enc_model_fpath", type=Path, 19 | default="encoder/saved_models/pretrained.pt", 20 | help="Path to a saved encoder") 21 | parser.add_argument("-s", "--syn_model_dir", type=Path, 22 | default="synthesizer/saved_models/logs-pretrained/", 23 | help="Directory containing the synthesizer model") 24 | parser.add_argument("-v", "--voc_model_fpath", type=Path, 25 | default="vocoder/saved_models/pretrained/pretrained.pt", 26 | help="Path to a saved vocoder") 27 | parser.add_argument("--low_mem", action="store_true", help="If True, the memory used by the synthesizer will be freed after each use. Adds large " 28 | "overhead but allows to save some GPU memory for lower-end GPUs.") 29 | parser.add_argument("--no_sound", action="store_true", 30 | help="If True, audio won't be played.") 31 | parser.add_argument("-a", "--audio_fpath", type=Path, 32 | default="audios/", 33 | help="Path to wave files") 34 | parser.add_argument("-m", "--embed_fpath", type=Path, 35 | default="audios/embeds/", 36 | help="Path to save embeddings") 37 | 38 | args = parser.parse_args() 39 | if not args.no_sound: 40 | import sounddevice as sd 41 | 42 | # File path where wav files are stored 43 | #path_wav = '/home/dipjyoti/speaker_embeddings_GE2E/audios/' 44 | 45 | # File path where generated speaker embeddings are stored 46 | #path_embed = '/home/dipjyoti/speaker_embeddings_GE2E/audios/embeds/' 47 | 48 | 49 | 50 | # Load the models one by one. 51 | print("Preparing the encoder...") 52 | encoder.load_model(args.enc_model_fpath) 53 | print("Insert the wav file name...") 54 | try: 55 | # Get the reference audio filepath 56 | 57 | for filename in glob.glob(os.path.join(args.audio_fpath, '*.wav')): 58 | print(filename) 59 | # Computing the embedding 60 | # First, we load the wav using the function that the speaker encoder provides. This is 61 | # important: there is preprocessing that must be applied. 62 | 63 | # The following two methods are equivalent: 64 | # - Directly load from the filepath: 65 | preprocessed_wav = encoder.preprocess_wav(filename) 66 | 67 | # Then we derive the embedding. 
There are many functions and parameters that the 68 | # speaker encoder interfaces. These are mostly for in-depth research. You will typically 69 | # only use this function (with its default parameters): 70 | embed = encoder.embed_utterance( 71 | preprocessed_wav) 72 | embed_path = args.embed_fpath / \ 73 | filename.split('/')[-1].replace('.wav', '.npy') 74 | np.save(embed_path, embed) 75 | print("Created the embeddings") 76 | 77 | except Exception as e: 78 | print("Caught exception: %s" % repr(e)) 79 | print("Restarting\n") 80 | -------------------------------------------------------------------------------- /encoder/plot_umap.py: -------------------------------------------------------------------------------- 1 | from encoder.visualizations import Visualizations 2 | from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset 3 | from encoder.params_model import * 4 | from encoder.model import SpeakerEncoder 5 | from utils.profiler import Profiler 6 | from pathlib import Path 7 | import torch 8 | 9 | def sync(device: torch.device): 10 | # FIXME 11 | return 12 | # For correct profiling (cuda operations are async) 13 | if device.type == "cuda": 14 | torch.cuda.synchronize(device) 15 | 16 | def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int, 17 | backup_every: int, vis_every: int, force_restart: bool, visdom_server: str, 18 | no_visdom: bool): 19 | # Create a dataset and a dataloader 20 | dataset = SpeakerVerificationDataset(clean_data_root) 21 | loader = SpeakerVerificationDataLoader( 22 | dataset, 23 | speakers_per_batch, 24 | utterances_per_speaker, 25 | num_workers=8, 26 | ) 27 | 28 | # Setup the device on which to run the forward pass and the loss. These can be different, 29 | # because the forward pass is faster on the GPU whereas the loss is often (depending on your 30 | # hyperparameters) faster on the CPU. 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | # FIXME: currently, the gradient is None if loss_device is cuda 33 | loss_device = torch.device("cpu") 34 | 35 | # Create the model and the optimizer 36 | model = SpeakerEncoder(device, loss_device) 37 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init) 38 | init_step = 1 39 | 40 | # Configure file path for the model 41 | state_fpath = models_dir.joinpath(run_id + ".pt") 42 | backup_dir = models_dir.joinpath(run_id + "_backups") 43 | 44 | # Load any existing model 45 | if not force_restart: 46 | if state_fpath.exists(): 47 | print("Found existing model \"%s\", loading it and resuming training." % run_id) 48 | checkpoint = torch.load(state_fpath) 49 | init_step = checkpoint["step"] 50 | model.load_state_dict(checkpoint["model_state"]) 51 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 52 | optimizer.param_groups[0]["lr"] = learning_rate_init 53 | else: 54 | print("No model \"%s\" found, starting training from scratch." 
% run_id) 55 | else: 56 | print("Starting the training from scratch.") 57 | model.train() 58 | 59 | # Initialize the visualization environment 60 | vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom) 61 | vis.log_dataset(dataset) 62 | vis.log_params() 63 | device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU") 64 | vis.log_implementation({"Device": device_name}) 65 | 66 | # Training loop 67 | profiler = Profiler(summarize_every=10, disabled=False) 68 | for step, speaker_batch in enumerate(loader, init_step): 69 | profiler.tick("Blocking, waiting for batch (threaded)") 70 | 71 | # Forward pass 72 | inputs = torch.from_numpy(speaker_batch.data).to(device) 73 | sync(device) 74 | profiler.tick("Data to %s" % device) 75 | embeds = model(inputs) 76 | sync(device) 77 | profiler.tick("Forward pass") 78 | embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device) 79 | loss, eer = model.loss(embeds_loss) 80 | sync(loss_device) 81 | profiler.tick("Loss") 82 | 83 | # Backward pass 84 | model.zero_grad() 85 | loss.backward() 86 | profiler.tick("Backward pass") 87 | model.do_gradient_ops() 88 | optimizer.step() 89 | profiler.tick("Parameter update") 90 | 91 | # Update visualizations 92 | # learning_rate = optimizer.param_groups[0]["lr"] 93 | vis.update(loss.item(), eer, step) 94 | 95 | # Draw projections and save them to the backup folder 96 | if umap_every != 0 and step % umap_every == 0: 97 | print("Drawing and saving projections (step %d)" % step) 98 | backup_dir.mkdir(exist_ok=True) 99 | projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step)) 100 | embeds = embeds.detach().cpu().numpy() 101 | vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath) 102 | vis.save() 103 | 104 | 105 | -------------------------------------------------------------------------------- /encoder/audio.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage.morphology import binary_dilation 2 | from encoder.params_data import * 3 | from pathlib import Path 4 | from typing import Optional, Union 5 | import numpy as np 6 | import webrtcvad 7 | import librosa 8 | import struct 9 | 10 | int16_max = (2 ** 15) - 1 11 | 12 | 13 | def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray], 14 | source_sr: Optional[int] = None): 15 | """ 16 | Applies the preprocessing operations used in training the Speaker Encoder to a waveform 17 | either on disk or in memory. The waveform will be resampled to match the data hyperparameters. 18 | 19 | :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not 20 | just .wav), either the waveform as a numpy array of floats. 21 | :param source_sr: if passing an audio waveform, the sampling rate of the waveform before 22 | preprocessing. After preprocessing, the waveform's sampling rate will match the data 23 | hyperparameters. If passing a filepath, the sampling rate will be automatically detected and 24 | this argument will be ignored. 
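:return: the preprocessed waveform as a numpy array of floats, resampled to the data hyperparameters' sampling rate, volume-normalized and with long silences trimmed.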
25 | """ 26 | # Load the wav from disk if needed 27 | if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path): 28 | wav, source_sr = librosa.load(fpath_or_wav, sr=None) 29 | else: 30 | wav = fpath_or_wav 31 | 32 | # Resample the wav if needed 33 | if source_sr is not None and source_sr != sampling_rate: 34 | wav = librosa.resample(wav, source_sr, sampling_rate) 35 | 36 | # Apply the preprocessing: normalize volume and shorten long silences 37 | wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True) 38 | wav = trim_long_silences(wav) 39 | 40 | return wav 41 | 42 | 43 | def wav_to_mel_spectrogram(wav): 44 | """ 45 | Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform. 46 | Note: this not a log-mel spectrogram. 47 | """ 48 | frames = librosa.feature.melspectrogram( 49 | wav, 50 | sampling_rate, 51 | n_fft=int(sampling_rate * mel_window_length / 1000), 52 | hop_length=int(sampling_rate * mel_window_step / 1000), 53 | n_mels=mel_n_channels 54 | ) 55 | return frames.astype(np.float32).T 56 | 57 | 58 | def trim_long_silences(wav): 59 | """ 60 | Ensures that segments without voice in the waveform remain no longer than a 61 | threshold determined by the VAD parameters in params.py. 62 | 63 | :param wav: the raw waveform as a numpy array of floats 64 | :return: the same waveform with silences trimmed away (length <= original wav length) 65 | """ 66 | # Compute the voice detection window size 67 | samples_per_window = (vad_window_length * sampling_rate) // 1000 68 | 69 | # Trim the end of the audio to have a multiple of the window size 70 | wav = wav[:len(wav) - (len(wav) % samples_per_window)] 71 | 72 | # Convert the float waveform to 16-bit mono PCM 73 | pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) 74 | 75 | # Perform voice activation detection 76 | voice_flags = [] 77 | vad = webrtcvad.Vad(mode=3) 78 | for window_start in range(0, len(wav), samples_per_window): 79 | window_end = window_start + samples_per_window 80 | voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], 81 | sample_rate=sampling_rate)) 82 | voice_flags = np.array(voice_flags) 83 | 84 | # Smooth the voice detection with a moving average 85 | def moving_average(array, width): 86 | array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) 87 | ret = np.cumsum(array_padded, dtype=float) 88 | ret[width:] = ret[width:] - ret[:-width] 89 | return ret[width - 1:] / width 90 | 91 | audio_mask = moving_average(voice_flags, vad_moving_average_width) 92 | audio_mask = np.round(audio_mask).astype(np.bool) 93 | 94 | # Dilate the voiced regions 95 | audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) 96 | audio_mask = np.repeat(audio_mask, samples_per_window) 97 | 98 | return wav[audio_mask == True] 99 | 100 | 101 | def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False): 102 | if increase_only and decrease_only: 103 | raise ValueError("Both increase only and decrease only are set") 104 | dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2)) 105 | if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only): 106 | return wav 107 | return wav * (10 ** (dBFS_change / 20)) 108 | -------------------------------------------------------------------------------- /encoder/train.py: -------------------------------------------------------------------------------- 1 | from encoder.visualizations import Visualizations 2 | 
from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset 3 | from encoder.params_model import * 4 | from encoder.model import SpeakerEncoder 5 | from utils.profiler import Profiler 6 | from pathlib import Path 7 | import torch 8 | 9 | def sync(device: torch.device): 10 | # FIXME 11 | return 12 | # For correct profiling (cuda operations are async) 13 | if device.type == "cuda": 14 | torch.cuda.synchronize(device) 15 | 16 | def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int, 17 | backup_every: int, vis_every: int, force_restart: bool, visdom_server: str, 18 | no_visdom: bool): 19 | # Create a dataset and a dataloader 20 | dataset = SpeakerVerificationDataset(clean_data_root) 21 | loader = SpeakerVerificationDataLoader( 22 | dataset, 23 | speakers_per_batch, 24 | utterances_per_speaker, 25 | num_workers=8, 26 | ) 27 | 28 | # Setup the device on which to run the forward pass and the loss. These can be different, 29 | # because the forward pass is faster on the GPU whereas the loss is often (depending on your 30 | # hyperparameters) faster on the CPU. 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | # FIXME: currently, the gradient is None if loss_device is cuda 33 | loss_device = torch.device("cpu") 34 | 35 | # Create the model and the optimizer 36 | model = SpeakerEncoder(device, loss_device) 37 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init) 38 | init_step = 1 39 | 40 | # Configure file path for the model 41 | state_fpath = models_dir.joinpath(run_id + ".pt") 42 | backup_dir = models_dir.joinpath(run_id + "_backups") 43 | 44 | # Load any existing model 45 | if not force_restart: 46 | if state_fpath.exists(): 47 | print("Found existing model \"%s\", loading it and resuming training." % run_id) 48 | checkpoint = torch.load(state_fpath) 49 | init_step = checkpoint["step"] 50 | model.load_state_dict(checkpoint["model_state"]) 51 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 52 | optimizer.param_groups[0]["lr"] = learning_rate_init 53 | else: 54 | print("No model \"%s\" found, starting training from scratch." 
% run_id) 55 | else: 56 | print("Starting the training from scratch.") 57 | model.train() 58 | 59 | # Initialize the visualization environment 60 | vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom) 61 | vis.log_dataset(dataset) 62 | vis.log_params() 63 | device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU") 64 | vis.log_implementation({"Device": device_name}) 65 | 66 | # Training loop 67 | profiler = Profiler(summarize_every=10, disabled=False) 68 | for step, speaker_batch in enumerate(loader, init_step): 69 | profiler.tick("Blocking, waiting for batch (threaded)") 70 | 71 | # Forward pass 72 | inputs = torch.from_numpy(speaker_batch.data).to(device) 73 | sync(device) 74 | profiler.tick("Data to %s" % device) 75 | embeds = model(inputs) 76 | sync(device) 77 | profiler.tick("Forward pass") 78 | embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device) 79 | loss, eer = model.loss(embeds_loss) 80 | sync(loss_device) 81 | profiler.tick("Loss") 82 | 83 | # Backward pass 84 | model.zero_grad() 85 | loss.backward() 86 | profiler.tick("Backward pass") 87 | model.do_gradient_ops() 88 | optimizer.step() 89 | profiler.tick("Parameter update") 90 | 91 | # Update visualizations 92 | # learning_rate = optimizer.param_groups[0]["lr"] 93 | vis.update(loss.item(), eer, step) 94 | 95 | # Draw projections and save them to the backup folder 96 | if umap_every != 0 and step % umap_every == 0: 97 | print("Drawing and saving projections (step %d)" % step) 98 | backup_dir.mkdir(exist_ok=True) 99 | projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step)) 100 | embeds = embeds.detach().cpu().numpy() 101 | vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath) 102 | vis.save() 103 | 104 | # Overwrite the latest version of the model 105 | if save_every != 0 and step % save_every == 0: 106 | print("Saving the model (step %d)" % step) 107 | torch.save({ 108 | "step": step + 1, 109 | "model_state": model.state_dict(), 110 | "optimizer_state": optimizer.state_dict(), 111 | }, state_fpath) 112 | 113 | # Make a backup 114 | if backup_every != 0 and step % backup_every == 0: 115 | print("Making a backup (step %d)" % step) 116 | backup_dir.mkdir(exist_ok=True) 117 | backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step)) 118 | torch.save({ 119 | "step": step + 1, 120 | "model_state": model.state_dict(), 121 | "optimizer_state": optimizer.state_dict(), 122 | }, backup_fpath) 123 | 124 | profiler.tick("Extras (visualizations, saving)") 125 | -------------------------------------------------------------------------------- /encoder/model.py: -------------------------------------------------------------------------------- 1 | from encoder.params_model import * 2 | from encoder.params_data import * 3 | from scipy.interpolate import interp1d 4 | from sklearn.metrics import roc_curve 5 | from torch.nn.utils import clip_grad_norm_ 6 | from scipy.optimize import brentq 7 | from torch import nn 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class SpeakerEncoder(nn.Module): 13 | def __init__(self, device, loss_device): 14 | super().__init__() 15 | self.loss_device = loss_device 16 | 17 | # Network defition 18 | self.lstm = nn.LSTM(input_size=mel_n_channels, 19 | hidden_size=model_hidden_size, 20 | num_layers=model_num_layers, 21 | batch_first=True).to(device) 22 | self.linear = nn.Linear(in_features=model_hidden_size, 23 | out_features=model_embedding_size).to(device) 24 | 
self.relu = torch.nn.ReLU().to(device) 25 | 26 | # Cosine similarity scaling (with fixed initial parameter values) 27 | self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device) 28 | self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device) 29 | 30 | # Loss 31 | self.loss_fn = nn.CrossEntropyLoss().to(loss_device) 32 | 33 | def do_gradient_ops(self): 34 | # Gradient scale 35 | self.similarity_weight.grad *= 0.01 36 | self.similarity_bias.grad *= 0.01 37 | 38 | # Gradient clipping 39 | clip_grad_norm_(self.parameters(), 3, norm_type=2) 40 | 41 | def forward(self, utterances, hidden_init=None): 42 | """ 43 | Computes the embeddings of a batch of utterance spectrograms. 44 | 45 | :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape 46 | (batch_size, n_frames, n_channels) 47 | :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, 48 | batch_size, hidden_size). Will default to a tensor of zeros if None. 49 | :return: the embeddings as a tensor of shape (batch_size, embedding_size) 50 | """ 51 | # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state 52 | # and the final cell state. 53 | out, (hidden, cell) = self.lstm(utterances, hidden_init) 54 | 55 | # We take only the hidden state of the last layer 56 | embeds_raw = self.relu(self.linear(hidden[-1])) 57 | 58 | # L2-normalize it 59 | embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True) 60 | 61 | return embeds 62 | 63 | def similarity_matrix(self, embeds): 64 | """ 65 | Computes the similarity matrix according the section 2.1 of GE2E. 66 | 67 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 68 | utterances_per_speaker, embedding_size) 69 | :return: the similarity matrix as a tensor of shape (speakers_per_batch, 70 | utterances_per_speaker, speakers_per_batch) 71 | """ 72 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2] 73 | 74 | # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation 75 | centroids_incl = torch.mean(embeds, dim=1, keepdim=True) 76 | centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True) 77 | 78 | # Exclusive centroids (1 per utterance) 79 | centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds) 80 | centroids_excl /= (utterances_per_speaker - 1) 81 | centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True) 82 | 83 | # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot 84 | # product of these vectors (which is just an element-wise multiplication reduced by a sum). 85 | # We vectorize the computation for efficiency. 
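# Concretely, sim_matrix[s, u, k] = cos(e_su, c_k): the similarity between utterance u of
# speaker s and the inclusive centroid of speaker k, except when k == s, where the
# exclusive (leave-one-out) centroid of that utterance is used instead (section 2.1 of
# GE2E). The learned scale and bias (similarity_weight, similarity_bias) are applied below.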
86 | sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker, 87 | speakers_per_batch).to(self.loss_device) 88 | mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int) 89 | for j in range(speakers_per_batch): 90 | mask = np.where(mask_matrix[j])[0] 91 | sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2) 92 | sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1) 93 | 94 | ## Even more vectorized version (slower maybe because of transpose) 95 | # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker 96 | # ).to(self.loss_device) 97 | # eye = np.eye(speakers_per_batch, dtype=np.int) 98 | # mask = np.where(1 - eye) 99 | # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2) 100 | # mask = np.where(eye) 101 | # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2) 102 | # sim_matrix2 = sim_matrix2.transpose(1, 2) 103 | 104 | sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias 105 | return sim_matrix 106 | 107 | def loss(self, embeds): 108 | """ 109 | Computes the softmax loss according the section 2.1 of GE2E. 110 | 111 | :param embeds: the embeddings as a tensor of shape (speakers_per_batch, 112 | utterances_per_speaker, embedding_size) 113 | :return: the loss and the EER for this batch of embeddings. 114 | """ 115 | speakers_per_batch, utterances_per_speaker = embeds.shape[:2] 116 | 117 | # Loss 118 | sim_matrix = self.similarity_matrix(embeds) 119 | sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker, 120 | speakers_per_batch)) 121 | ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker) 122 | target = torch.from_numpy(ground_truth).long().to(self.loss_device) 123 | loss = self.loss_fn(sim_matrix, target) 124 | 125 | # EER (not backpropagated) 126 | with torch.no_grad(): 127 | inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0] 128 | labels = np.array([inv_argmax(i) for i in ground_truth]) 129 | preds = sim_matrix.detach().cpu().numpy() 130 | 131 | # Snippet from https://yangcha.github.io/EER-ROC/ 132 | fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten()) 133 | eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.) 134 | 135 | return loss, eer -------------------------------------------------------------------------------- /encoder/visualizations.py: -------------------------------------------------------------------------------- 1 | from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset 2 | from datetime import datetime 3 | from time import perf_counter as timer 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | # import webbrowser 7 | import visdom 8 | import umap 9 | 10 | colormap = np.array([ 11 | [76, 255, 0], 12 | [0, 127, 70], 13 | [255, 0, 0], 14 | [255, 217, 38], 15 | [0, 135, 255], 16 | [165, 0, 165], 17 | [255, 167, 255], 18 | [0, 255, 255], 19 | [255, 96, 38], 20 | [142, 76, 0], 21 | [33, 0, 127], 22 | [0, 0, 0], 23 | [183, 183, 183], 24 | ], dtype=np.float) / 255 25 | 26 | 27 | class Visualizations: 28 | def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False): 29 | # Tracking data 30 | self.last_update_timestamp = timer() 31 | self.update_every = update_every 32 | self.step_times = [] 33 | self.losses = [] 34 | self.eers = [] 35 | print("Updating the visualizations every %d steps." 
% update_every) 36 | 37 | # If visdom is disabled TODO: use a better paradigm for that 38 | self.disabled = disabled 39 | if self.disabled: 40 | return 41 | 42 | # Set the environment name 43 | now = str(datetime.now().strftime("%d-%m %Hh%M")) 44 | if env_name is None: 45 | self.env_name = now 46 | else: 47 | self.env_name = "%s (%s)" % (env_name, now) 48 | 49 | # Connect to visdom and open the corresponding window in the browser 50 | try: 51 | self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True) 52 | except ConnectionError: 53 | raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to " 54 | "start it.") 55 | # webbrowser.open("http://localhost:8097/env/" + self.env_name) 56 | 57 | # Create the windows 58 | self.loss_win = None 59 | self.eer_win = None 60 | # self.lr_win = None 61 | self.implementation_win = None 62 | self.projection_win = None 63 | self.implementation_string = "" 64 | 65 | def log_params(self): 66 | if self.disabled: 67 | return 68 | from encoder import params_data 69 | from encoder import params_model 70 | param_string = "Model parameters:<br>
" 71 | for param_name in (p for p in dir(params_model) if not p.startswith("__")): 72 | value = getattr(params_model, param_name) 73 | param_string += "\t%s: %s<br>
" % (param_name, value) 74 | param_string += "Data parameters:<br>
" 75 | for param_name in (p for p in dir(params_data) if not p.startswith("__")): 76 | value = getattr(params_data, param_name) 77 | param_string += "\t%s: %s<br>
" % (param_name, value) 78 | self.vis.text(param_string, opts={"title": "Parameters"}) 79 | 80 | def log_dataset(self, dataset: SpeakerVerificationDataset): 81 | if self.disabled: 82 | return 83 | dataset_string = "" 84 | dataset_string += "Speakers: %s\n" % len(dataset.speakers) 85 | dataset_string += "\n" + dataset.get_logs() 86 | dataset_string = dataset_string.replace("\n", "<br>
") 87 | self.vis.text(dataset_string, opts={"title": "Dataset"}) 88 | 89 | def log_implementation(self, params): 90 | if self.disabled: 91 | return 92 | implementation_string = "" 93 | for param, value in params.items(): 94 | implementation_string += "%s: %s\n" % (param, value) 95 | implementation_string = implementation_string.replace("\n", "<br>
") 96 | self.implementation_string = implementation_string 97 | self.implementation_win = self.vis.text( 98 | implementation_string, 99 | opts={"title": "Training implementation"} 100 | ) 101 | 102 | def update(self, loss, eer, step): 103 | # Update the tracking data 104 | now = timer() 105 | self.step_times.append(1000 * (now - self.last_update_timestamp)) 106 | self.last_update_timestamp = now 107 | self.losses.append(loss) 108 | self.eers.append(eer) 109 | print(".", end="") 110 | 111 | # Update the plots every steps 112 | if step % self.update_every != 0: 113 | return 114 | time_string = "Step time: mean: %5dms std: %5dms" % \ 115 | (int(np.mean(self.step_times)), int(np.std(self.step_times))) 116 | print("\nStep %6d Loss: %.4f EER: %.4f %s" % 117 | (step, np.mean(self.losses), np.mean(self.eers), time_string)) 118 | if not self.disabled: 119 | self.loss_win = self.vis.line( 120 | [np.mean(self.losses)], 121 | [step], 122 | win=self.loss_win, 123 | update="append" if self.loss_win else None, 124 | opts=dict( 125 | legend=["Avg. loss"], 126 | xlabel="Step", 127 | ylabel="Loss", 128 | title="Loss", 129 | ) 130 | ) 131 | self.eer_win = self.vis.line( 132 | [np.mean(self.eers)], 133 | [step], 134 | win=self.eer_win, 135 | update="append" if self.eer_win else None, 136 | opts=dict( 137 | legend=["Avg. EER"], 138 | xlabel="Step", 139 | ylabel="EER", 140 | title="Equal error rate" 141 | ) 142 | ) 143 | if self.implementation_win is not None: 144 | self.vis.text( 145 | self.implementation_string + ("%s" % time_string), 146 | win=self.implementation_win, 147 | opts={"title": "Training implementation"}, 148 | ) 149 | 150 | # Reset the tracking 151 | self.losses.clear() 152 | self.eers.clear() 153 | self.step_times.clear() 154 | 155 | def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, 156 | max_speakers=10): 157 | max_speakers = min(max_speakers, len(colormap)) 158 | embeds = embeds[:max_speakers * utterances_per_speaker] 159 | 160 | n_speakers = len(embeds) // utterances_per_speaker 161 | ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker) 162 | colors = [colormap[i] for i in ground_truth] 163 | 164 | reducer = umap.UMAP() 165 | projected = reducer.fit_transform(embeds) 166 | plt.scatter(projected[:, 0], projected[:, 1], c=colors) 167 | plt.gca().set_aspect("equal", "datalim") 168 | plt.title("UMAP projection (step %d)" % step) 169 | if not self.disabled: 170 | self.projection_win = self.vis.matplot(plt, win=self.projection_win) 171 | if out_fpath is not None: 172 | plt.savefig(out_fpath) 173 | plt.clf() 174 | 175 | def save(self): 176 | if not self.disabled: 177 | self.vis.save([self.env_name]) 178 | -------------------------------------------------------------------------------- /encoder/preprocess.py: -------------------------------------------------------------------------------- 1 | from multiprocess.pool import ThreadPool 2 | from encoder.params_data import * 3 | from encoder.config import librispeech_datasets, anglophone_nationalites 4 | from datetime import datetime 5 | from encoder import audio 6 | from pathlib import Path 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | 11 | class DatasetLog: 12 | """ 13 | Registers metadata about the dataset in a text file. 
14 | """ 15 | def __init__(self, root, name): 16 | self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w") 17 | self.sample_data = dict() 18 | 19 | start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) 20 | self.write_line("Creating dataset %s on %s" % (name, start_time)) 21 | self.write_line("-----") 22 | self._log_params() 23 | 24 | def _log_params(self): 25 | from encoder import params_data 26 | self.write_line("Parameter values:") 27 | for param_name in (p for p in dir(params_data) if not p.startswith("__")): 28 | value = getattr(params_data, param_name) 29 | self.write_line("\t%s: %s" % (param_name, value)) 30 | self.write_line("-----") 31 | 32 | def write_line(self, line): 33 | self.text_file.write("%s\n" % line) 34 | 35 | def add_sample(self, **kwargs): 36 | for param_name, value in kwargs.items(): 37 | if not param_name in self.sample_data: 38 | self.sample_data[param_name] = [] 39 | self.sample_data[param_name].append(value) 40 | 41 | def finalize(self): 42 | self.write_line("Statistics:") 43 | for param_name, values in self.sample_data.items(): 44 | self.write_line("\t%s:" % param_name) 45 | self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values))) 46 | self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values))) 47 | self.write_line("-----") 48 | end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M")) 49 | self.write_line("Finished on %s" % end_time) 50 | self.text_file.close() 51 | 52 | 53 | def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog): 54 | dataset_root = datasets_root.joinpath(dataset_name) 55 | if not dataset_root.exists(): 56 | print("Couldn\'t find %s, skipping this dataset." % dataset_root) 57 | return None, None 58 | return dataset_root, DatasetLog(out_dir, dataset_name) 59 | 60 | 61 | def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension, 62 | skip_existing, logger): 63 | print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs))) 64 | 65 | # Function to preprocess utterances for one speaker 66 | def preprocess_speaker(speaker_dir: Path): 67 | # Give a name to the speaker that includes its dataset 68 | speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts) 69 | 70 | # Create an output directory with that name, as well as a txt file containing a 71 | # reference to each source file. 72 | speaker_out_dir = out_dir.joinpath(speaker_name) 73 | speaker_out_dir.mkdir(exist_ok=True) 74 | sources_fpath = speaker_out_dir.joinpath("_sources.txt") 75 | 76 | # There's a possibility that the preprocessing was interrupted earlier, check if 77 | # there already is a sources file. 
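# If it does and skip_existing is set, utterances already listed there are skipped and
# new ones are appended to the file; otherwise the file is overwritten.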
78 | if sources_fpath.exists(): 79 | try: 80 | with sources_fpath.open("r") as sources_file: 81 | existing_fnames = {line.split(",")[0] for line in sources_file} 82 | except: 83 | existing_fnames = {} 84 | else: 85 | existing_fnames = {} 86 | 87 | # Gather all audio files for that speaker recursively 88 | sources_file = sources_fpath.open("a" if skip_existing else "w") 89 | for in_fpath in speaker_dir.glob("**/*.%s" % extension): 90 | # Check if the target output file already exists 91 | out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts) 92 | out_fname = out_fname.replace(".%s" % extension, ".npy") 93 | if skip_existing and out_fname in existing_fnames: 94 | continue 95 | 96 | # Load and preprocess the waveform 97 | wav = audio.preprocess_wav(in_fpath) 98 | if len(wav) == 0: 99 | continue 100 | 101 | # Create the mel spectrogram, discard those that are too short 102 | frames = audio.wav_to_mel_spectrogram(wav) 103 | if len(frames) < partials_n_frames: 104 | continue 105 | 106 | out_fpath = speaker_out_dir.joinpath(out_fname) 107 | np.save(out_fpath, frames) 108 | logger.add_sample(duration=len(wav) / sampling_rate) 109 | sources_file.write("%s,%s\n" % (out_fname, in_fpath)) 110 | 111 | sources_file.close() 112 | 113 | # Process the utterances for each speaker 114 | with ThreadPool(8) as pool: 115 | list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs), 116 | unit="speakers")) 117 | logger.finalize() 118 | print("Done preprocessing %s.\n" % dataset_name) 119 | 120 | 121 | def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False): 122 | for dataset_name in librispeech_datasets["train"]["other"]: 123 | # Initialize the preprocessing 124 | dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) 125 | if not dataset_root: 126 | return 127 | 128 | # Preprocess all speakers 129 | speaker_dirs = list(dataset_root.glob("*")) 130 | _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac", 131 | skip_existing, logger) 132 | 133 | 134 | def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False): 135 | # Initialize the preprocessing 136 | dataset_name = "VoxCeleb1" 137 | dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) 138 | if not dataset_root: 139 | return 140 | 141 | # Get the contents of the meta file 142 | with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile: 143 | metadata = [line.split("\t") for line in metafile][1:] 144 | 145 | # Select the ID and the nationality, filter out non-anglophone speakers 146 | nationalities = {line[0]: line[3] for line in metadata} 147 | keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if 148 | nationality.lower() in anglophone_nationalites] 149 | print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." % 150 | (len(keep_speaker_ids), len(nationalities))) 151 | 152 | # Get the speaker directories for anglophone speakers only 153 | speaker_dirs = dataset_root.joinpath("wav").glob("*") 154 | speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if 155 | speaker_dir.name in keep_speaker_ids] 156 | print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." 
% 157 | (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs))) 158 | 159 | # Preprocess all speakers 160 | _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav", 161 | skip_existing, logger) 162 | 163 | 164 | def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False): 165 | # Initialize the preprocessing 166 | dataset_name = "VoxCeleb2" 167 | dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir) 168 | if not dataset_root: 169 | return 170 | 171 | # Get the speaker directories 172 | # Preprocess all speakers 173 | speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*")) 174 | _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a", 175 | skip_existing, logger) 176 | -------------------------------------------------------------------------------- /encoder/inference.py: -------------------------------------------------------------------------------- 1 | from encoder.params_data import * 2 | from encoder.model import SpeakerEncoder 3 | from encoder.audio import preprocess_wav # We want to expose this function from here 4 | from matplotlib import cm 5 | from encoder import audio 6 | from pathlib import Path 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torch 10 | 11 | _model = None # type: SpeakerEncoder 12 | _device = None # type: torch.device 13 | 14 | 15 | def load_model(weights_fpath: Path, device=None): 16 | """ 17 | Loads the model in memory. If this function is not explicitely called, it will be run on the 18 | first call to embed_frames() with the default weights file. 19 | 20 | :param weights_fpath: the path to saved model weights. 21 | :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The 22 | model will be loaded and will run on this device. Outputs will however always be on the cpu. 23 | If None, will default to your GPU if it"s available, otherwise your CPU. 24 | """ 25 | # TODO: I think the slow loading of the encoder might have something to do with the device it 26 | # was saved on. Worth investigating. 27 | global _model, _device 28 | if device is None: 29 | _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 | elif isinstance(device, str): 31 | _device = torch.device(device) 32 | _model = SpeakerEncoder(_device, torch.device("cpu")) 33 | checkpoint = torch.load(weights_fpath) 34 | _model.load_state_dict(checkpoint["model_state"]) 35 | _model.eval() 36 | print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"])) 37 | 38 | 39 | def is_loaded(): 40 | return _model is not None 41 | 42 | 43 | def embed_frames_batch(frames_batch): 44 | """ 45 | Computes embeddings for a batch of mel spectrogram. 46 | 47 | :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape 48 | (batch_size, n_frames, n_channels) 49 | :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size) 50 | """ 51 | if _model is None: 52 | raise Exception("Model was not loaded. 
Call load_model() before inference.") 53 | 54 | frames = torch.from_numpy(frames_batch).to(_device) 55 | embed = _model.forward(frames).detach().cpu().numpy() 56 | return embed 57 | 58 | 59 | def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames, 60 | min_pad_coverage=0.75, overlap=0.5): 61 | """ 62 | Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain 63 | partial utterances of partial_utterance_n_frames frames each. Both the waveform and the mel 64 | spectrogram slices are returned, so as to make each partial utterance waveform correspond to 65 | its spectrogram. This function assumes that the mel spectrogram parameters used are those 66 | defined in params_data.py. 67 | 68 | The returned ranges may be indexing further than the length of the waveform. It is 69 | recommended that you pad the waveform with zeros up to wave_slices[-1].stop. 70 | 71 | :param n_samples: the number of samples in the waveform 72 | :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial 73 | utterance 74 | :param min_pad_coverage: when reaching the last partial utterance, it may or may not have 75 | enough frames. If at least min_pad_coverage of partial_utterance_n_frames are present, 76 | then the last partial utterance will be considered, as if we padded the audio. Otherwise, 77 | it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial 78 | utterance, this parameter is ignored so that the function always returns at least 1 slice. 79 | :param overlap: by how much the partial utterance should overlap. If set to 0, the partial 80 | utterances are entirely disjoint. 81 | :return: the waveform slices and mel spectrogram slices as lists of array slices. Index 82 | respectively the waveform and the mel spectrogram with these slices to obtain the partial 83 | utterances. 84 | """ 85 | assert 0 <= overlap < 1 86 | assert 0 < min_pad_coverage <= 1 87 | 88 | samples_per_frame = int((sampling_rate * mel_window_step / 1000)) 89 | n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) 90 | frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) 91 | 92 | # Compute the slices 93 | wav_slices, mel_slices = [], [] 94 | steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1) 95 | for i in range(0, steps, frame_step): 96 | mel_range = np.array([i, i + partial_utterance_n_frames]) 97 | wav_range = mel_range * samples_per_frame 98 | mel_slices.append(slice(*mel_range)) 99 | wav_slices.append(slice(*wav_range)) 100 | 101 | # Evaluate whether extra padding is warranted or not 102 | last_wav_range = wav_slices[-1] 103 | coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start) 104 | if coverage < min_pad_coverage and len(mel_slices) > 1: 105 | mel_slices = mel_slices[:-1] 106 | wav_slices = wav_slices[:-1] 107 | 108 | return wav_slices, mel_slices 109 | 110 | 111 | def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs): 112 | """ 113 | Computes an embedding for a single utterance. 114 | 115 | # TODO: handle multiple wavs to benefit from batching on GPU 116 | :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32 117 | :param using_partials: if True, then the utterance is split in partial utterances of partial_utterance_n_frames 118 | frames and the utterance embedding is computed from their 119 | normalized average. If False, the utterance is instead computed from feeding the entire 120 | spectrogram to the network.
121 | :param return_partials: if True, the partial embeddings will also be returned along with the 122 | wav slices that correspond to the partial embeddings. 123 | :param kwargs: additional arguments to compute_partial_slices() 124 | :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If 125 | return_partials is True, the partial utterances as a numpy array of float32 of shape 126 | (n_partials, model_embedding_size) and the wav partials as a list of slices will also be 127 | returned. If using_partials is simultaneously set to False, both these values will be None 128 | instead. 129 | """ 130 | # Process the entire utterance if not using partials 131 | if not using_partials: 132 | frames = audio.wav_to_mel_spectrogram(wav) 133 | embed = embed_frames_batch(frames[None, ...])[0] 134 | if return_partials: 135 | return embed, None, None 136 | return embed 137 | 138 | # Compute where to split the utterance into partials and pad if necessary 139 | wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs) 140 | max_wave_length = wave_slices[-1].stop 141 | if max_wave_length >= len(wav): 142 | wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") 143 | 144 | # Split the utterance into partials 145 | frames = audio.wav_to_mel_spectrogram(wav) 146 | frames_batch = np.array([frames[s] for s in mel_slices]) 147 | partial_embeds = embed_frames_batch(frames_batch) 148 | 149 | # Compute the utterance embedding from the partial embeddings 150 | raw_embed = np.mean(partial_embeds, axis=0) 151 | embed = raw_embed / np.linalg.norm(raw_embed, 2) 152 | 153 | if return_partials: 154 | return embed, partial_embeds, wave_slices 155 | return embed 156 | 157 | 158 | def embed_speaker(wavs, **kwargs): 159 | raise NotImplementedError() 160 | 161 | 162 | def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)): 163 | if ax is None: 164 | ax = plt.gca() 165 | 166 | if shape is None: 167 | height = int(np.sqrt(len(embed))) 168 | shape = (height, -1) 169 | embed = embed.reshape(shape) 170 | 171 | cmap = cm.get_cmap() 172 | mappable = ax.imshow(embed, cmap=cmap) 173 | cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04) 174 | cbar.set_clim(*color_range) 175 | 176 | ax.set_xticks([]), ax.set_yticks([]) 177 | ax.set_title(title) 178 | -------------------------------------------------------------------------------- /utils/logmmse.py: -------------------------------------------------------------------------------- 1 | # The MIT License (MIT) 2 | # 3 | # Copyright (c) 2015 braindead 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | # 23 | # 24 | # This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I 25 | # simply modified the interface to meet my needs. 26 | 27 | 28 | import numpy as np 29 | import math 30 | from scipy.special import expn 31 | from collections import namedtuple 32 | 33 | NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2") 34 | 35 | 36 | def profile_noise(noise, sampling_rate, window_size=0): 37 | """ 38 | Creates a profile of the noise in a given waveform. 39 | 40 | :param noise: a waveform containing noise ONLY, as a numpy array of floats or ints. 41 | :param sampling_rate: the sampling rate of the audio 42 | :param window_size: the size of the window the logmmse algorithm operates on. A default value 43 | will be picked if left as 0. 44 | :return: a NoiseProfile object 45 | """ 46 | noise, dtype = to_float(noise) 47 | noise += np.finfo(np.float64).eps 48 | 49 | if window_size == 0: 50 | window_size = int(math.floor(0.02 * sampling_rate)) 51 | 52 | if window_size % 2 == 1: 53 | window_size = window_size + 1 54 | 55 | perc = 50 56 | len1 = int(math.floor(window_size * perc / 100)) 57 | len2 = int(window_size - len1) 58 | 59 | win = np.hanning(window_size) 60 | win = win * len2 / np.sum(win) 61 | n_fft = 2 * window_size 62 | 63 | noise_mean = np.zeros(n_fft) 64 | n_frames = len(noise) // window_size 65 | for j in range(0, window_size * n_frames, window_size): 66 | noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0)) 67 | noise_mu2 = (noise_mean / n_frames) ** 2 68 | 69 | return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2) 70 | 71 | 72 | def denoise(wav, noise_profile: NoiseProfile, eta=0.15): 73 | """ 74 | Cleans the noise from a speech waveform given a noise profile. The waveform must have the 75 | same sampling rate as the one used to create the noise profile. 76 | 77 | :param wav: a speech waveform as a numpy array of floats or ints. 78 | :param noise_profile: a NoiseProfile object that was created from a similar (or a segment of 79 | the same) waveform. 80 | :param eta: voice threshold for noise update. While the voice activation detection value is 81 | below this threshold, the noise profile will be continuously updated throughout the audio. 82 | Set to 0 to disable updating the noise profile. 83 | :return: the clean wav as a numpy array of floats or ints of the same length. 
84 | """ 85 | wav, dtype = to_float(wav) 86 | wav += np.finfo(np.float64).eps 87 | p = noise_profile 88 | 89 | nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2)) 90 | x_final = np.zeros(nframes * p.len2) 91 | 92 | aa = 0.98 93 | mu = 0.98 94 | ksi_min = 10 ** (-25 / 10) 95 | 96 | x_old = np.zeros(p.len1) 97 | xk_prev = np.zeros(p.len1) 98 | noise_mu2 = p.noise_mu2 99 | for k in range(0, nframes * p.len2, p.len2): 100 | insign = p.win * wav[k:k + p.window_size] 101 | 102 | spec = np.fft.fft(insign, p.n_fft, axis=0) 103 | sig = np.absolute(spec) 104 | sig2 = sig ** 2 105 | 106 | gammak = np.minimum(sig2 / noise_mu2, 40) 107 | 108 | if xk_prev.all() == 0: 109 | ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0) 110 | else: 111 | ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0) 112 | ksi = np.maximum(ksi_min, ksi) 113 | 114 | log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi) 115 | vad_decision = np.sum(log_sigma_k) / p.window_size 116 | if vad_decision < eta: 117 | noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2 118 | 119 | a = ksi / (1 + ksi) 120 | vk = a * gammak 121 | ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8)) 122 | hw = a * np.exp(ei_vk) 123 | sig = sig * hw 124 | xk_prev = sig ** 2 125 | xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0) 126 | xi_w = np.real(xi_w) 127 | 128 | x_final[k:k + p.len2] = x_old + xi_w[0:p.len1] 129 | x_old = xi_w[p.len1:p.window_size] 130 | 131 | output = from_float(x_final, dtype) 132 | output = np.pad(output, (0, len(wav) - len(output)), mode="constant") 133 | return output 134 | 135 | 136 | ## Alternative VAD algorithm to webrctvad. It has the advantage of not requiring to install that 137 | ## darn package and it also works for any sampling rate. Maybe I'll eventually use it instead of 138 | ## webrctvad 139 | # def vad(wav, sampling_rate, eta=0.15, window_size=0): 140 | # """ 141 | # TODO: fix doc 142 | # Creates a profile of the noise in a given waveform. 143 | # 144 | # :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints. 145 | # :param sampling_rate: the sampling rate of the audio 146 | # :param window_size: the size of the window the logmmse algorithm operates on. A default value 147 | # will be picked if left as 0. 148 | # :param eta: voice threshold for noise update. While the voice activation detection value is 149 | # below this threshold, the noise profile will be continuously updated throughout the audio. 150 | # Set to 0 to disable updating the noise profile. 
151 | # """ 152 | # wav, dtype = to_float(wav) 153 | # wav += np.finfo(np.float64).eps 154 | # 155 | # if window_size == 0: 156 | # window_size = int(math.floor(0.02 * sampling_rate)) 157 | # 158 | # if window_size % 2 == 1: 159 | # window_size = window_size + 1 160 | # 161 | # perc = 50 162 | # len1 = int(math.floor(window_size * perc / 100)) 163 | # len2 = int(window_size - len1) 164 | # 165 | # win = np.hanning(window_size) 166 | # win = win * len2 / np.sum(win) 167 | # n_fft = 2 * window_size 168 | # 169 | # wav_mean = np.zeros(n_fft) 170 | # n_frames = len(wav) // window_size 171 | # for j in range(0, window_size * n_frames, window_size): 172 | # wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0)) 173 | # noise_mu2 = (wav_mean / n_frames) ** 2 174 | # 175 | # wav, dtype = to_float(wav) 176 | # wav += np.finfo(np.float64).eps 177 | # 178 | # nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2)) 179 | # vad = np.zeros(nframes * len2, dtype=np.bool) 180 | # 181 | # aa = 0.98 182 | # mu = 0.98 183 | # ksi_min = 10 ** (-25 / 10) 184 | # 185 | # xk_prev = np.zeros(len1) 186 | # noise_mu2 = noise_mu2 187 | # for k in range(0, nframes * len2, len2): 188 | # insign = win * wav[k:k + window_size] 189 | # 190 | # spec = np.fft.fft(insign, n_fft, axis=0) 191 | # sig = np.absolute(spec) 192 | # sig2 = sig ** 2 193 | # 194 | # gammak = np.minimum(sig2 / noise_mu2, 40) 195 | # 196 | # if xk_prev.all() == 0: 197 | # ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0) 198 | # else: 199 | # ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0) 200 | # ksi = np.maximum(ksi_min, ksi) 201 | # 202 | # log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi) 203 | # vad_decision = np.sum(log_sigma_k) / window_size 204 | # if vad_decision < eta: 205 | # noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2 206 | # print(vad_decision) 207 | # 208 | # a = ksi / (1 + ksi) 209 | # vk = a * gammak 210 | # ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8)) 211 | # hw = a * np.exp(ei_vk) 212 | # sig = sig * hw 213 | # xk_prev = sig ** 2 214 | # 215 | # vad[k:k + len2] = vad_decision >= eta 216 | # 217 | # vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant") 218 | # return vad 219 | 220 | 221 | def to_float(_input): 222 | if _input.dtype == np.float64: 223 | return _input, _input.dtype 224 | elif _input.dtype == np.float32: 225 | return _input.astype(np.float64), _input.dtype 226 | elif _input.dtype == np.uint8: 227 | return (_input - 128) / 128., _input.dtype 228 | elif _input.dtype == np.int16: 229 | return _input / 32768., _input.dtype 230 | elif _input.dtype == np.int32: 231 | return _input / 2147483648., _input.dtype 232 | raise ValueError('Unsupported wave file format') 233 | 234 | 235 | def from_float(_input, dtype): 236 | if dtype == np.float64: 237 | return _input, np.float64 238 | elif dtype == np.float32: 239 | return _input.astype(np.float32) 240 | elif dtype == np.uint8: 241 | return ((_input * 128) + 128).astype(np.uint8) 242 | elif dtype == np.int16: 243 | return (_input * 32768).astype(np.int16) 244 | elif dtype == np.int32: 245 | print(_input) 246 | return (_input * 2147483648).astype(np.int32) 247 | raise ValueError('Unsupported wave file format') 248 | -------------------------------------------------------------------------------- /toolbox/__init__.py: -------------------------------------------------------------------------------- 1 | from toolbox.ui import UI 2 | from encoder import inference as encoder 3 | from synthesizer.inference import 
Synthesizer 4 | from vocoder import inference as vocoder 5 | from pathlib import Path 6 | from time import perf_counter as timer 7 | from toolbox.utterance import Utterance 8 | import numpy as np 9 | import traceback 10 | import sys 11 | 12 | 13 | # Use this directory structure for your datasets, or modify it to fit your needs 14 | recognized_datasets = [ 15 | "LibriSpeech/dev-clean", 16 | "LibriSpeech/dev-other", 17 | "LibriSpeech/test-clean", 18 | "LibriSpeech/test-other", 19 | "LibriSpeech/train-clean-100", 20 | "LibriSpeech/train-clean-360", 21 | "LibriSpeech/train-other-500", 22 | "LibriTTS/dev-clean", 23 | "LibriTTS/dev-other", 24 | "LibriTTS/test-clean", 25 | "LibriTTS/test-other", 26 | "LibriTTS/train-clean-100", 27 | "LibriTTS/train-clean-360", 28 | "LibriTTS/train-other-500", 29 | "LJSpeech-1.1", 30 | "VoxCeleb1/wav", 31 | "VoxCeleb1/test_wav", 32 | "VoxCeleb2/dev/aac", 33 | "VoxCeleb2/test/aac", 34 | "VCTK-Corpus/wav48", 35 | ] 36 | 37 | class Toolbox: 38 | def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, low_mem): 39 | sys.excepthook = self.excepthook 40 | self.datasets_root = datasets_root 41 | self.low_mem = low_mem 42 | self.utterances = set() 43 | self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav 44 | 45 | self.synthesizer = None # type: Synthesizer 46 | 47 | # Initialize the events and the interface 48 | self.ui = UI() 49 | self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir) 50 | self.setup_events() 51 | self.ui.start() 52 | 53 | def excepthook(self, exc_type, exc_value, exc_tb): 54 | traceback.print_exception(exc_type, exc_value, exc_tb) 55 | self.ui.log("Exception: %s" % exc_value) 56 | 57 | def setup_events(self): 58 | # Dataset, speaker and utterance selection 59 | self.ui.browser_load_button.clicked.connect(lambda: self.load_from_browser()) 60 | random_func = lambda level: lambda: self.ui.populate_browser(self.datasets_root, 61 | recognized_datasets, 62 | level) 63 | self.ui.random_dataset_button.clicked.connect(random_func(0)) 64 | self.ui.random_speaker_button.clicked.connect(random_func(1)) 65 | self.ui.random_utterance_button.clicked.connect(random_func(2)) 66 | self.ui.dataset_box.currentIndexChanged.connect(random_func(1)) 67 | self.ui.speaker_box.currentIndexChanged.connect(random_func(2)) 68 | 69 | # Model selection 70 | self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder) 71 | def func(): 72 | self.synthesizer = None 73 | self.ui.synthesizer_box.currentIndexChanged.connect(func) 74 | self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder) 75 | 76 | # Utterance selection 77 | func = lambda: self.load_from_browser(self.ui.browse_file()) 78 | self.ui.browser_browse_button.clicked.connect(func) 79 | func = lambda: self.ui.draw_utterance(self.ui.selected_utterance, "current") 80 | self.ui.utterance_history.currentIndexChanged.connect(func) 81 | func = lambda: self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate) 82 | self.ui.play_button.clicked.connect(func) 83 | self.ui.stop_button.clicked.connect(self.ui.stop) 84 | self.ui.record_button.clicked.connect(self.record) 85 | 86 | # Generation 87 | func = lambda: self.synthesize() or self.vocode() 88 | self.ui.generate_button.clicked.connect(func) 89 | self.ui.synthesize_button.clicked.connect(self.synthesize) 90 | self.ui.vocode_button.clicked.connect(self.vocode) 91 | 92 | # UMAP legend 93 | self.ui.clear_button.clicked.connect(self.clear_utterances) 94 | 95 | def reset_ui(self, encoder_models_dir, 
synthesizer_models_dir, vocoder_models_dir): 96 | self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True) 97 | self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir) 98 | 99 | def load_from_browser(self, fpath=None): 100 | if fpath is None: 101 | fpath = Path(self.datasets_root, 102 | self.ui.current_dataset_name, 103 | self.ui.current_speaker_name, 104 | self.ui.current_utterance_name) 105 | name = str(fpath.relative_to(self.datasets_root)) 106 | speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name 107 | 108 | # Select the next utterance 109 | if self.ui.auto_next_checkbox.isChecked(): 110 | self.ui.browser_select_next() 111 | elif fpath == "": 112 | return 113 | else: 114 | name = fpath.name 115 | speaker_name = fpath.parent.name 116 | 117 | # Get the wav from the disk. We take the wav with the vocoder/synthesizer format for 118 | # playback, so as to have a fair comparison with the generated audio 119 | wav = Synthesizer.load_preprocess_wav(fpath) 120 | self.ui.log("Loaded %s" % name) 121 | 122 | self.add_real_utterance(wav, name, speaker_name) 123 | 124 | def record(self): 125 | wav = self.ui.record_one(encoder.sampling_rate, 5) 126 | if wav is None: 127 | return 128 | self.ui.play(wav, encoder.sampling_rate) 129 | 130 | speaker_name = "user01" 131 | name = speaker_name + "_rec_%05d" % np.random.randint(100000) 132 | self.add_real_utterance(wav, name, speaker_name) 133 | 134 | def add_real_utterance(self, wav, name, speaker_name): 135 | # Compute the mel spectrogram 136 | spec = Synthesizer.make_spectrogram(wav) 137 | self.ui.draw_spec(spec, "current") 138 | 139 | # Compute the embedding 140 | if not encoder.is_loaded(): 141 | self.init_encoder() 142 | encoder_wav = encoder.preprocess_wav(wav) 143 | embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True) 144 | 145 | # Add the utterance 146 | utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False) 147 | self.utterances.add(utterance) 148 | self.ui.register_utterance(utterance) 149 | 150 | # Plot it 151 | self.ui.draw_embed(embed, name, "current") 152 | self.ui.draw_umap_projections(self.utterances) 153 | 154 | def clear_utterances(self): 155 | self.utterances.clear() 156 | self.ui.draw_umap_projections(self.utterances) 157 | 158 | def synthesize(self): 159 | self.ui.log("Generating the mel spectrogram...") 160 | self.ui.set_loading(1) 161 | 162 | # Synthesize the spectrogram 163 | if self.synthesizer is None: 164 | model_dir = self.ui.current_synthesizer_model_dir 165 | checkpoints_dir = model_dir.joinpath("taco_pretrained") 166 | self.synthesizer = Synthesizer(checkpoints_dir, low_mem=self.low_mem) 167 | if not self.synthesizer.is_loaded(): 168 | self.ui.log("Loading the synthesizer %s" % self.synthesizer.checkpoint_fpath) 169 | 170 | texts = self.ui.text_prompt.toPlainText().split("\n") 171 | embed = self.ui.selected_utterance.embed 172 | embeds = np.stack([embed] * len(texts)) 173 | specs = self.synthesizer.synthesize_spectrograms(texts, embeds) 174 | breaks = [spec.shape[1] for spec in specs] 175 | spec = np.concatenate(specs, axis=1) 176 | 177 | self.ui.draw_spec(spec, "generated") 178 | self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None) 179 | self.ui.set_loading(0) 180 | 181 | def vocode(self): 182 | speaker_name, spec, breaks, _ = self.current_generated 183 | assert spec is not None 184 | 185 | # Synthesize the waveform 186 | if not vocoder.is_loaded(): 187 | 
self.init_vocoder() 188 | def vocoder_progress(i, seq_len, b_size, gen_rate): 189 | real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000 190 | line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \ 191 | % (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor) 192 | self.ui.log(line, "overwrite") 193 | self.ui.set_loading(i, seq_len) 194 | if self.ui.current_vocoder_fpath is not None: 195 | self.ui.log("") 196 | wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress) 197 | else: 198 | self.ui.log("Waveform generation with Griffin-Lim... ") 199 | wav = Synthesizer.griffin_lim(spec) 200 | self.ui.set_loading(0) 201 | self.ui.log(" Done!", "append") 202 | 203 | # Add breaks 204 | b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size) 205 | b_starts = np.concatenate(([0], b_ends[:-1])) 206 | wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)] 207 | breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks) 208 | wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)]) 209 | 210 | # Play it 211 | wav = wav / np.abs(wav).max() * 0.97 212 | self.ui.play(wav, Synthesizer.sample_rate) 213 | 214 | # Compute the embedding 215 | # TODO: this is problematic with different sampling rates, gotta fix it 216 | if not encoder.is_loaded(): 217 | self.init_encoder() 218 | encoder_wav = encoder.preprocess_wav(wav) 219 | embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True) 220 | 221 | # Add the utterance 222 | name = speaker_name + "_gen_%05d" % np.random.randint(100000) 223 | utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True) 224 | self.utterances.add(utterance) 225 | 226 | # Plot it 227 | self.ui.draw_embed(embed, name, "generated") 228 | self.ui.draw_umap_projections(self.utterances) 229 | 230 | def init_encoder(self): 231 | model_fpath = self.ui.current_encoder_fpath 232 | 233 | self.ui.log("Loading the encoder %s... " % model_fpath) 234 | self.ui.set_loading(1) 235 | start = timer() 236 | encoder.load_model(model_fpath) 237 | self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append") 238 | self.ui.set_loading(0) 239 | 240 | def init_vocoder(self): 241 | model_fpath = self.ui.current_vocoder_fpath 242 | # Case of Griffin-lim 243 | if model_fpath is None: 244 | return 245 | 246 | self.ui.log("Loading the vocoder %s... " % model_fpath) 247 | self.ui.set_loading(1) 248 | start = timer() 249 | vocoder.load_model(model_fpath) 250 | self.ui.log("Done (%dms)." 
% int(1000 * (timer() - start)), "append") 251 | self.ui.set_loading(0) 252 | -------------------------------------------------------------------------------- /toolbox/ui.py: -------------------------------------------------------------------------------- 1 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas 2 | from matplotlib.figure import Figure 3 | from PyQt5.QtCore import Qt 4 | from PyQt5.QtWidgets import * 5 | from encoder.inference import plot_embedding_as_heatmap 6 | from toolbox.utterance import Utterance 7 | from pathlib import Path 8 | from typing import List, Set 9 | import sounddevice as sd 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | # from sklearn.manifold import TSNE # You can try with TSNE if you like, I prefer UMAP 13 | from time import sleep 14 | import umap 15 | import sys 16 | from warnings import filterwarnings 17 | filterwarnings("ignore") 18 | 19 | 20 | colormap = np.array([ 21 | [0, 127, 70], 22 | [255, 0, 0], 23 | [255, 217, 38], 24 | [0, 135, 255], 25 | [165, 0, 165], 26 | [255, 167, 255], 27 | [97, 142, 151], 28 | [0, 255, 255], 29 | [255, 96, 38], 30 | [142, 76, 0], 31 | [33, 0, 127], 32 | [0, 0, 0], 33 | [183, 183, 183], 34 | [76, 255, 0], 35 | ], dtype=np.float) / 255 36 | 37 | default_text = \ 38 | "Welcome to the toolbox! To begin, load an utterance from your datasets or record one " \ 39 | "yourself.\nOnce its embedding has been created, you can synthesize any text written here.\n" \ 40 | "With the current synthesizer model, punctuation and special characters will be ignored.\n" \ 41 | "The synthesizer expects to generate " \ 42 | "outputs that are somewhere between 5 and 12 seconds.\nTo mark breaks, write a new line. " \ 43 | "Each line will be treated separately.\nThen, they are joined together to make the final " \ 44 | "spectrogram. Use the vocoder to generate audio.\nThe vocoder generates almost in constant " \ 45 | "time, so it will be more time efficient for longer inputs like this one.\nOn the left you " \ 46 | "have the embedding projections. Load or record more utterances to see them.\nIf you have " \ 47 | "at least 2 or 3 utterances from a same speaker, a cluster should form.\nSynthesized " \ 48 | "utterances are of the same color as the speaker whose voice was used, but they're " \ 49 | "represented with a cross." 
50 | 51 | 52 | class UI(QDialog): 53 | min_umap_points = 4 54 | max_log_lines = 5 55 | max_saved_utterances = 20 56 | 57 | def draw_utterance(self, utterance: Utterance, which): 58 | self.draw_spec(utterance.spec, which) 59 | self.draw_embed(utterance.embed, utterance.name, which) 60 | 61 | def draw_embed(self, embed, name, which): 62 | embed_ax, _ = self.current_ax if which == "current" else self.gen_ax 63 | embed_ax.figure.suptitle("" if embed is None else name) 64 | 65 | ## Embedding 66 | # Clear the plot 67 | if len(embed_ax.images) > 0: 68 | embed_ax.images[0].colorbar.remove() 69 | embed_ax.clear() 70 | 71 | # Draw the embed 72 | if embed is not None: 73 | plot_embedding_as_heatmap(embed, embed_ax) 74 | embed_ax.set_title("embedding") 75 | embed_ax.set_aspect("equal", "datalim") 76 | embed_ax.set_xticks([]) 77 | embed_ax.set_yticks([]) 78 | embed_ax.figure.canvas.draw() 79 | 80 | def draw_spec(self, spec, which): 81 | _, spec_ax = self.current_ax if which == "current" else self.gen_ax 82 | 83 | ## Spectrogram 84 | # Draw the spectrogram 85 | spec_ax.clear() 86 | if spec is not None: 87 | im = spec_ax.imshow(spec, aspect="auto", interpolation="none") 88 | # spec_ax.figure.colorbar(mappable=im, shrink=0.65, orientation="horizontal", 89 | # spec_ax=spec_ax) 90 | spec_ax.set_title("mel spectrogram") 91 | 92 | spec_ax.set_xticks([]) 93 | spec_ax.set_yticks([]) 94 | spec_ax.figure.canvas.draw() 95 | if which != "current": 96 | self.vocode_button.setDisabled(spec is None) 97 | 98 | def draw_umap_projections(self, utterances: Set[Utterance]): 99 | self.umap_ax.clear() 100 | 101 | speakers = np.unique([u.speaker_name for u in utterances]) 102 | colors = {speaker_name: colormap[i] for i, speaker_name in enumerate(speakers)} 103 | embeds = [u.embed for u in utterances] 104 | 105 | # Display a message if there aren't enough points 106 | if len(utterances) < self.min_umap_points: 107 | self.umap_ax.text(.5, .5, "Add %d more points to\ngenerate the projections" % 108 | (self.min_umap_points - len(utterances)), 109 | horizontalalignment='center', fontsize=15) 110 | self.umap_ax.set_title("") 111 | 112 | # Compute the projections 113 | else: 114 | if not self.umap_hot: 115 | self.log( 116 | "Drawing UMAP projections for the first time, this will take a few seconds.") 117 | self.umap_hot = True 118 | 119 | reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine") 120 | # reducer = TSNE() 121 | projections = reducer.fit_transform(embeds) 122 | 123 | speakers_done = set() 124 | for projection, utterance in zip(projections, utterances): 125 | color = colors[utterance.speaker_name] 126 | mark = "x" if "_gen_" in utterance.name else "o" 127 | label = None if utterance.speaker_name in speakers_done else utterance.speaker_name 128 | speakers_done.add(utterance.speaker_name) 129 | self.umap_ax.scatter(projection[0], projection[1], c=[color], marker=mark, 130 | label=label) 131 | # self.umap_ax.set_title("UMAP projections") 132 | self.umap_ax.legend(prop={'size': 10}) 133 | 134 | # Draw the plot 135 | self.umap_ax.set_aspect("equal", "datalim") 136 | self.umap_ax.set_xticks([]) 137 | self.umap_ax.set_yticks([]) 138 | self.umap_ax.figure.canvas.draw() 139 | 140 | def play(self, wav, sample_rate): 141 | sd.stop() 142 | sd.play(wav, sample_rate) 143 | 144 | def stop(self): 145 | sd.stop() 146 | 147 | def record_one(self, sample_rate, duration): 148 | self.record_button.setText("Recording...") 149 | self.record_button.setDisabled(True) 150 | 151 | self.log("Recording %d seconds of audio" % 
duration) 152 | sd.stop() 153 | try: 154 | wav = sd.rec(duration * sample_rate, sample_rate, 1) 155 | except Exception as e: 156 | print(e) 157 | self.log("Could not record anything. Is your recording device enabled?") 158 | self.log("Your device must be connected before you start the toolbox.") 159 | return None 160 | 161 | for i in np.arange(0, duration, 0.1): 162 | self.set_loading(i, duration) 163 | sleep(0.1) 164 | self.set_loading(duration, duration) 165 | sd.wait() 166 | 167 | self.log("Done recording.") 168 | self.record_button.setText("Record one") 169 | self.record_button.setDisabled(False) 170 | 171 | return wav.squeeze() 172 | 173 | @property 174 | def current_dataset_name(self): 175 | return self.dataset_box.currentText() 176 | 177 | @property 178 | def current_speaker_name(self): 179 | return self.speaker_box.currentText() 180 | 181 | @property 182 | def current_utterance_name(self): 183 | return self.utterance_box.currentText() 184 | 185 | def browse_file(self): 186 | fpath = QFileDialog().getOpenFileName( 187 | parent=self, 188 | caption="Select an audio file", 189 | filter="Audio Files (*.mp3 *.flac *.wav *.m4a)" 190 | ) 191 | return Path(fpath[0]) if fpath[0] != "" else "" 192 | 193 | @staticmethod 194 | def repopulate_box(box, items, random=False): 195 | """ 196 | Resets a box and adds a list of items. Pass a list of (item, data) pairs instead to join 197 | data to the items 198 | """ 199 | box.blockSignals(True) 200 | box.clear() 201 | for item in items: 202 | item = list(item) if isinstance(item, tuple) else [item] 203 | box.addItem(str(item[0]), *item[1:]) 204 | if len(items) > 0: 205 | box.setCurrentIndex(np.random.randint(len(items)) if random else 0) 206 | box.setDisabled(len(items) == 0) 207 | box.blockSignals(False) 208 | 209 | def populate_browser(self, datasets_root: Path, recognized_datasets: List, level: int, 210 | random=True): 211 | # Select a random dataset 212 | if level <= 0: 213 | if datasets_root is not None: 214 | datasets = [datasets_root.joinpath(d) for d in recognized_datasets] 215 | datasets = [d.relative_to(datasets_root) for d in datasets if d.exists()] 216 | self.browser_load_button.setDisabled(len(datasets) == 0) 217 | if datasets_root is None or len(datasets) == 0: 218 | msg = "Warning: you d" + ("id not pass a root directory for datasets as argument" \ 219 | if datasets_root is None else "o not have any of the recognized datasets" \ 220 | " in %s" % datasets_root) 221 | self.log(msg) 222 | msg += ".\nThe recognized datasets are:\n\t%s\nFeel free to add your own. You " \ 223 | "can still use the toolbox by recording samples yourself." 
% \ 224 | ("\n\t".join(recognized_datasets)) 225 | print(msg, file=sys.stderr) 226 | 227 | self.random_utterance_button.setDisabled(True) 228 | self.random_speaker_button.setDisabled(True) 229 | self.random_dataset_button.setDisabled(True) 230 | self.utterance_box.setDisabled(True) 231 | self.speaker_box.setDisabled(True) 232 | self.dataset_box.setDisabled(True) 233 | return 234 | self.repopulate_box(self.dataset_box, datasets, random) 235 | 236 | # Select a random speaker 237 | if level <= 1: 238 | speakers_root = datasets_root.joinpath(self.current_dataset_name) 239 | speaker_names = [d.stem for d in speakers_root.glob("*") if d.is_dir()] 240 | self.repopulate_box(self.speaker_box, speaker_names, random) 241 | 242 | # Select a random utterance 243 | if level <= 2: 244 | utterances_root = datasets_root.joinpath( 245 | self.current_dataset_name, 246 | self.current_speaker_name 247 | ) 248 | utterances = [] 249 | for extension in ['mp3', 'flac', 'wav', 'm4a']: 250 | utterances.extend(Path(utterances_root).glob("**/*.%s" % extension)) 251 | utterances = [fpath.relative_to(utterances_root) for fpath in utterances] 252 | self.repopulate_box(self.utterance_box, utterances, random) 253 | 254 | def browser_select_next(self): 255 | index = (self.utterance_box.currentIndex() + 1) % len(self.utterance_box) 256 | self.utterance_box.setCurrentIndex(index) 257 | 258 | @property 259 | def current_encoder_fpath(self): 260 | return self.encoder_box.itemData(self.encoder_box.currentIndex()) 261 | 262 | @property 263 | def current_synthesizer_model_dir(self): 264 | return self.synthesizer_box.itemData(self.synthesizer_box.currentIndex()) 265 | 266 | @property 267 | def current_vocoder_fpath(self): 268 | return self.vocoder_box.itemData(self.vocoder_box.currentIndex()) 269 | 270 | def populate_models(self, encoder_models_dir: Path, synthesizer_models_dir: Path, 271 | vocoder_models_dir: Path): 272 | # Encoder 273 | encoder_fpaths = list(encoder_models_dir.glob("*.pt")) 274 | if len(encoder_fpaths) == 0: 275 | raise Exception("No encoder models found in %s" % encoder_models_dir) 276 | self.repopulate_box(self.encoder_box, [(f.stem, f) for f in encoder_fpaths]) 277 | 278 | # Synthesizer 279 | synthesizer_model_dirs = list(synthesizer_models_dir.glob("*")) 280 | synthesizer_items = [(f.name.replace("logs-", ""), f) for f in synthesizer_model_dirs] 281 | if len(synthesizer_model_dirs) == 0: 282 | raise Exception("No synthesizer models found in %s. 
For the synthesizer, the expected " 283 | "structure is /logs-/taco_pretrained/" 284 | "checkpoint" % synthesizer_models_dir) 285 | self.repopulate_box(self.synthesizer_box, synthesizer_items) 286 | 287 | # Vocoder 288 | vocoder_fpaths = list(vocoder_models_dir.glob("**/*.pt")) 289 | vocoder_items = [(f.stem, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)] 290 | self.repopulate_box(self.vocoder_box, vocoder_items) 291 | 292 | @property 293 | def selected_utterance(self): 294 | return self.utterance_history.itemData(self.utterance_history.currentIndex()) 295 | 296 | def register_utterance(self, utterance: Utterance): 297 | self.utterance_history.blockSignals(True) 298 | self.utterance_history.insertItem(0, utterance.name, utterance) 299 | self.utterance_history.setCurrentIndex(0) 300 | self.utterance_history.blockSignals(False) 301 | 302 | if len(self.utterance_history) > self.max_saved_utterances: 303 | self.utterance_history.removeItem(self.max_saved_utterances) 304 | 305 | self.play_button.setDisabled(False) 306 | self.generate_button.setDisabled(False) 307 | self.synthesize_button.setDisabled(False) 308 | 309 | def log(self, line, mode="newline"): 310 | if mode == "newline": 311 | self.logs.append(line) 312 | if len(self.logs) > self.max_log_lines: 313 | del self.logs[0] 314 | elif mode == "append": 315 | self.logs[-1] += line 316 | elif mode == "overwrite": 317 | self.logs[-1] = line 318 | log_text = '\n'.join(self.logs) 319 | 320 | self.log_window.setText(log_text) 321 | self.app.processEvents() 322 | 323 | def set_loading(self, value, maximum=1): 324 | self.loading_bar.setValue(value * 100) 325 | self.loading_bar.setMaximum(maximum * 100) 326 | self.loading_bar.setTextVisible(value != 0) 327 | self.app.processEvents() 328 | 329 | def reset_interface(self): 330 | self.draw_embed(None, None, "current") 331 | self.draw_embed(None, None, "generated") 332 | self.draw_spec(None, "current") 333 | self.draw_spec(None, "generated") 334 | self.draw_umap_projections(set()) 335 | self.set_loading(0) 336 | self.play_button.setDisabled(True) 337 | self.generate_button.setDisabled(True) 338 | self.synthesize_button.setDisabled(True) 339 | self.vocode_button.setDisabled(True) 340 | [self.log("") for _ in range(self.max_log_lines)] 341 | 342 | def __init__(self): 343 | ## Initialize the application 344 | self.app = QApplication(sys.argv) 345 | super().__init__(None) 346 | self.setWindowTitle("SV2TTS toolbox") 347 | 348 | 349 | ## Main layouts 350 | # Root 351 | root_layout = QGridLayout() 352 | self.setLayout(root_layout) 353 | 354 | # Browser 355 | browser_layout = QGridLayout() 356 | root_layout.addLayout(browser_layout, 0, 1) 357 | 358 | # Visualizations 359 | vis_layout = QVBoxLayout() 360 | root_layout.addLayout(vis_layout, 1, 1, 2, 3) 361 | 362 | # Generation 363 | gen_layout = QVBoxLayout() 364 | root_layout.addLayout(gen_layout, 0, 2) 365 | 366 | # Projections 367 | self.projections_layout = QVBoxLayout() 368 | root_layout.addLayout(self.projections_layout, 1, 0) 369 | 370 | 371 | ## Projections 372 | # UMap 373 | fig, self.umap_ax = plt.subplots(figsize=(4, 4), facecolor="#F0F0F0") 374 | fig.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98) 375 | self.projections_layout.addWidget(FigureCanvas(fig)) 376 | self.umap_hot = False 377 | self.clear_button = QPushButton("Clear") 378 | self.projections_layout.addWidget(self.clear_button) 379 | 380 | 381 | ## Browser 382 | # Dataset, speaker and utterance selection 383 | i = 0 384 | self.dataset_box = QComboBox() 385 | 
browser_layout.addWidget(QLabel("Dataset"), i, 0) 386 | browser_layout.addWidget(self.dataset_box, i + 1, 0) 387 | self.speaker_box = QComboBox() 388 | browser_layout.addWidget(QLabel("Speaker"), i, 1) 389 | browser_layout.addWidget(self.speaker_box, i + 1, 1) 390 | self.utterance_box = QComboBox() 391 | browser_layout.addWidget(QLabel("Utterance"), i, 2) 392 | browser_layout.addWidget(self.utterance_box, i + 1, 2) 393 | self.browser_browse_button = QPushButton("Browse") 394 | browser_layout.addWidget(self.browser_browse_button, i, 3) 395 | self.browser_load_button = QPushButton("Load") 396 | browser_layout.addWidget(self.browser_load_button, i + 1, 3) 397 | i += 2 398 | 399 | # Random buttons 400 | self.random_dataset_button = QPushButton("Random") 401 | browser_layout.addWidget(self.random_dataset_button, i, 0) 402 | self.random_speaker_button = QPushButton("Random") 403 | browser_layout.addWidget(self.random_speaker_button, i, 1) 404 | self.random_utterance_button = QPushButton("Random") 405 | browser_layout.addWidget(self.random_utterance_button, i, 2) 406 | self.auto_next_checkbox = QCheckBox("Auto select next") 407 | self.auto_next_checkbox.setChecked(True) 408 | browser_layout.addWidget(self.auto_next_checkbox, i, 3) 409 | i += 1 410 | 411 | # Utterance box 412 | browser_layout.addWidget(QLabel("Use embedding from:"), i, 0) 413 | i += 1 414 | 415 | # Random & next utterance buttons 416 | self.utterance_history = QComboBox() 417 | browser_layout.addWidget(self.utterance_history, i, 0, 1, 3) 418 | i += 1 419 | 420 | # Random & next utterance buttons 421 | self.take_generated_button = QPushButton("Take generated") 422 | browser_layout.addWidget(self.take_generated_button, i, 0) 423 | self.record_button = QPushButton("Record") 424 | browser_layout.addWidget(self.record_button, i, 1) 425 | self.play_button = QPushButton("Play") 426 | browser_layout.addWidget(self.play_button, i, 2) 427 | self.stop_button = QPushButton("Stop") 428 | browser_layout.addWidget(self.stop_button, i, 3) 429 | i += 2 430 | 431 | # Model selection 432 | self.encoder_box = QComboBox() 433 | browser_layout.addWidget(QLabel("Encoder"), i, 0) 434 | browser_layout.addWidget(self.encoder_box, i + 1, 0) 435 | self.synthesizer_box = QComboBox() 436 | browser_layout.addWidget(QLabel("Synthesizer"), i, 1) 437 | browser_layout.addWidget(self.synthesizer_box, i + 1, 1) 438 | self.vocoder_box = QComboBox() 439 | browser_layout.addWidget(QLabel("Vocoder"), i, 2) 440 | browser_layout.addWidget(self.vocoder_box, i + 1, 2) 441 | i += 2 442 | 443 | 444 | ## Embed & spectrograms 445 | vis_layout.addStretch() 446 | 447 | gridspec_kw = {"width_ratios": [1, 4]} 448 | fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0", 449 | gridspec_kw=gridspec_kw) 450 | fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8) 451 | vis_layout.addWidget(FigureCanvas(fig)) 452 | 453 | fig, self.gen_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0", 454 | gridspec_kw=gridspec_kw) 455 | fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8) 456 | vis_layout.addWidget(FigureCanvas(fig)) 457 | 458 | for ax in self.current_ax.tolist() + self.gen_ax.tolist(): 459 | ax.set_facecolor("#F0F0F0") 460 | for side in ["top", "right", "bottom", "left"]: 461 | ax.spines[side].set_visible(False) 462 | 463 | 464 | ## Generation 465 | self.text_prompt = QPlainTextEdit(default_text) 466 | gen_layout.addWidget(self.text_prompt, stretch=1) 467 | 468 | self.generate_button = QPushButton("Synthesize and vocode") 469 | 
gen_layout.addWidget(self.generate_button) 470 | 471 | layout = QHBoxLayout() 472 | self.synthesize_button = QPushButton("Synthesize only") 473 | layout.addWidget(self.synthesize_button) 474 | self.vocode_button = QPushButton("Vocode only") 475 | layout.addWidget(self.vocode_button) 476 | gen_layout.addLayout(layout) 477 | 478 | self.loading_bar = QProgressBar() 479 | gen_layout.addWidget(self.loading_bar) 480 | 481 | self.log_window = QLabel() 482 | self.log_window.setAlignment(Qt.AlignBottom | Qt.AlignLeft) 483 | gen_layout.addWidget(self.log_window) 484 | self.logs = [] 485 | gen_layout.addStretch() 486 | 487 | 488 | ## Set the size of the window and of the elements 489 | max_size = QDesktopWidget().availableGeometry(self).size() * 0.8 490 | self.resize(max_size) 491 | 492 | ## Finalize the display 493 | self.reset_interface() 494 | self.show() 495 | 496 | def start(self): 497 | self.app.exec_() 498 | --------------------------------------------------------------------------------
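The files above expose two small, self-contained APIs: the log-MMSE noise profiler/denoiser in utils/logmmse.py and the speaker-embedding interface in encoder/inference.py (with preprocess_wav re-exported from encoder/audio.py). The sketch below shows one plausible way to chain them from a script run at the repository root: profile a noise-only clip, clean a speech recording, then embed it with the encoder. The audio paths and the "pretrained.pt" weights filename are placeholders rather than files shipped with the repository, and loading audio with librosa at the encoder's sampling rate is an assumption; treat this as an illustration of the call order, not an official entry point.

from pathlib import Path

import librosa

from encoder import inference as encoder
from encoder.params_data import sampling_rate
from utils import logmmse

# Load a noise-only clip and a speech clip at the encoder's sampling rate.
# Both paths are hypothetical placeholders.
noise, _ = librosa.load("noise_only.wav", sr=sampling_rate)
speech, _ = librosa.load("speech.wav", sr=sampling_rate)

# Build a noise profile once, then clean the speech waveform with it. The
# denoiser expects both waveforms to share the same sampling rate.
profile = logmmse.profile_noise(noise, sampling_rate)
cleaned = logmmse.denoise(speech, profile)

# Load the encoder weights (placeholder filename) and compute an utterance
# embedding. preprocess_wav applies the encoder's own preprocessing, and
# embed_utterance averages the partial embeddings and renormalizes the result.
encoder.load_model(Path("encoder", "saved_models", "pretrained.pt"))
wav = encoder.preprocess_wav(cleaned)
embed, partial_embeds, wav_slices = encoder.embed_utterance(wav, return_partials=True)
print("Embedding shape:", embed.shape)  # (model_embedding_size,), unit L2 norm

Because embed_utterance renormalizes the mean of the partial embeddings, the returned vector always has unit norm; the wav_slices returned alongside it index into the (possibly zero-padded) waveform, as described in compute_partial_slices().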