├── README.md
├── data
    ├── __init__.py
    ├── ge2e_dataset.py
    ├── infinite_dataloader.py
    └── wav2mel.py
├── equal_error_rate.py
├── images
    └── tsne.png
├── modules
    ├── __init__.py
    ├── dvector.py
    └── ge2e.py
├── preprocess.py
├── train.py
└── visualize.py

/README.md:
--------------------------------------------------------------------------------
1 | # D-vector
2 | 
3 | This is a PyTorch implementation of a speaker embedding model (d-vector) trained with GE2E loss.
4 | The original paper on GE2E loss can be found here: [Generalized End-to-End Loss for Speaker Verification](https://arxiv.org/abs/1710.10467)
5 | 
6 | ## Usage
7 | 
8 | ```python
9 | import torch
10 | import torchaudio
11 | 
12 | wav2mel = torch.jit.load("wav2mel.pt")
13 | dvector = torch.jit.load("dvector.pt").eval()
14 | 
15 | wav_tensor, sample_rate = torchaudio.load("example.wav")
16 | mel_tensor = wav2mel(wav_tensor, sample_rate) # shape: (frames, mel_dim)
17 | emb_tensor = dvector.embed_utterance(mel_tensor) # shape: (emb_dim)
18 | ```
19 | 
20 | You can also embed multiple utterances of a speaker at once:
21 | 
22 | ```python
23 | emb_tensor = dvector.embed_utterances([mel_tensor_1, mel_tensor_2]) # shape: (emb_dim)
24 | ```
25 | 
26 | There are 2 modules in this example:
27 | - `wav2mel.pt` is the preprocessing module, which is composed of 2 modules:
28 |   - `sox_effects.pt` is used to normalize volume, remove silence, resample audio to 16 kHz, 16 bits, and remix all channels to a single channel
29 |   - `log_melspectrogram.pt` is used to transform waveforms to log mel spectrograms
30 | - `dvector.pt` is the speaker encoder
31 | 
32 | Since all the modules are compiled with [TorchScript](https://pytorch.org/docs/stable/jit.html), you can simply load and use them anywhere, **without any code from this repository**.
33 | 
34 | ### Pretrained models & preprocessing modules
35 | 
36 | You can download them from the [*Releases*](https://github.com/yistLin/dvector/releases) page.
37 | 
38 | ## Evaluate model performance
39 | 
40 | You can evaluate the performance of the model with the equal error rate (EER).
41 | For example, download the official test splits (`veri_test.txt` and `veri_test2.txt`) from [The VoxCeleb1 Dataset](https://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1.html) and run the following command:
42 | ```bash
43 | python equal_error_rate.py VoxCeleb1/test VoxCeleb1/test/veri_test.txt -w wav2mel.pt -c dvector.pt
44 | ```
45 | 
46 | So far, the released checkpoint has only been trained on VoxCeleb1 without any data augmentation.
47 | Its performance on the official test splits of VoxCeleb1 is as follows:
48 | | Test Split | Equal Error Rate | Threshold |
49 | | :-: | :-: | :-: |
50 | | veri_test.txt | 12.0% | 0.222 |
51 | | veri_test2.txt | 11.9% | 0.223 |
52 | 
53 | ## Train from scratch
54 | 
55 | ### Preprocess training data
56 | 
57 | To use the script provided here, you have to organize your raw data in this way:
58 | 
59 | - all utterances from a speaker should be put under a directory (**speaker directory**)
60 | - all speaker directories should be put under a directory (**root directory**)
61 | - a **speaker directory** can have subdirectories, and utterances can be placed under those subdirectories
62 | 
63 | And you can extract utterances from multiple **root directories**, e.g.
64 | 
65 | ```bash
66 | python preprocess.py VoxCeleb1/dev LibriSpeech/train-clean-360 -o preprocessed
67 | ```
68 | 
69 | If you need to modify any audio preprocessing hyperparameters, edit `data/wav2mel.py` directly; its defaults are sketched below.
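For reference, this is only a sketch of the `Wav2Mel` constructor with its default values spelled out; check `data/wav2mel.py` for the authoritative signature:

```python
from data import Wav2Mel

# Sketch: Wav2Mel with its default hyperparameters made explicit.
# preprocess.py instantiates Wav2Mel() with no arguments, so change the
# defaults in data/wav2mel.py rather than passing them at the call site.
wav2mel = Wav2Mel(
    sample_rate=16000,   # target sample rate after resampling
    norm_db=-3.0,        # volume normalization level in dB
    sil_threshold=1.0,   # silence threshold (%) for the sox "silence" effect
    sil_duration=0.1,    # minimum silence duration in seconds
    fft_window_ms=25.0,  # STFT window length in milliseconds
    fft_hop_ms=10.0,     # STFT hop length in milliseconds
    f_min=50.0,          # lowest frequency of the mel filterbank
    n_mels=40,           # number of mel bins (also the model's input dimension)
)
```

Note that `n_mels` is written to `metadata.json` during preprocessing and later read by `train.py` as the d-vector's input dimension, so a different value propagates to training automatically once you re-run preprocessing.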
70 | After preprocessing, 3 preprocessing modules will be saved in the output directory:
71 | 1. `wav2mel.pt`
72 | 2. `sox_effects.pt`
73 | 3. `log_melspectrogram.pt`
74 | 
75 | > The first module `wav2mel.pt` is composed of the second and the third modules.
76 | > These modules were compiled with TorchScript and can be used anywhere to preprocess audio data.
77 | 
78 | ### Train a model
79 | 
80 | You have to specify where to store checkpoints and logs, e.g.
81 | 
82 | ```bash
83 | python train.py preprocessed model_dir
84 | ```
85 | 
86 | During training, logs will be put under `model_dir/logs` and checkpoints will be placed under `model_dir/checkpoints`.
87 | For more details, check the usage with `python train.py -h`.
88 | 
89 | ### Use different speaker encoders
90 | 
91 | By default I'm using a 3-layer LSTM with attentive pooling as the speaker encoder, but you can use speaker encoders with different architectures.
92 | For more information, please take a look at `modules/dvector.py`.
93 | 
94 | ## Visualize speaker embeddings
95 | 
96 | You can visualize speaker embeddings using a trained d-vector.
97 | Note that you have to structure the speaker directories in the same way as for preprocessing,
98 | e.g.
99 | 
100 | ```bash
101 | python visualize.py LibriSpeech/dev-clean -w wav2mel.pt -c dvector.pt -o tsne.jpg
102 | ```
103 | 
104 | The following plot shows the dimensionality reduction result (using t-SNE) of some utterances from LibriSpeech.
105 | 
106 | ![TSNE result](images/tsne.png)
107 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .ge2e_dataset import GE2EDataset, collate_batch
2 | from .wav2mel import Wav2Mel
3 | from .infinite_dataloader import InfiniteDataLoader, infinite_iterator
4 | 
--------------------------------------------------------------------------------
/data/ge2e_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset for speaker embedding."""
2 | 
3 | import random
4 | from pathlib import Path
5 | from typing import Union
6 | 
7 | import torch
8 | from torch.utils.data import Dataset
9 | from torch.nn.utils.rnn import pad_sequence
10 | 
11 | 
12 | class GE2EDataset(Dataset):
13 |     """Sample utterances from speakers."""
14 | 
15 |     def __init__(
16 |         self,
17 |         data_dir: Union[str, Path],
18 |         speaker_infos: dict,
19 |         n_utterances: int,
20 |         seg_len: int,
21 |     ):
22 |         """
23 |         Args:
24 |             data_dir (string): path to the directory of pickle files.
25 |             n_utterances (int): # of utterances per speaker to be sampled.
26 |             seg_len (int): the minimum length of segments of utterances.
27 | """ 28 | 29 | self.data_dir = data_dir 30 | self.n_utterances = n_utterances 31 | self.seg_len = seg_len 32 | self.infos = [] 33 | 34 | for uttr_infos in speaker_infos.values(): 35 | feature_paths = [ 36 | uttr_info["feature_path"] 37 | for uttr_info in uttr_infos 38 | if uttr_info["mel_len"] > seg_len 39 | ] 40 | if len(feature_paths) > n_utterances: 41 | self.infos.append(feature_paths) 42 | 43 | def __len__(self): 44 | return len(self.infos) 45 | 46 | def __getitem__(self, index): 47 | feature_paths = random.sample(self.infos[index], self.n_utterances) 48 | uttrs = [ 49 | torch.load(Path(self.data_dir, feature_path)) 50 | for feature_path in feature_paths 51 | ] 52 | lefts = [random.randint(0, len(uttr) - self.seg_len) for uttr in uttrs] 53 | segments = [ 54 | uttr[left : left + self.seg_len, :] for uttr, left in zip(uttrs, lefts) 55 | ] 56 | return segments 57 | 58 | 59 | def collate_batch(batch): 60 | """Collate a whole batch of utterances.""" 61 | flatten = [u for s in batch for u in s] 62 | return pad_sequence(flatten, batch_first=True, padding_value=0) 63 | -------------------------------------------------------------------------------- /data/infinite_dataloader.py: -------------------------------------------------------------------------------- 1 | """Reference: https://discuss.pytorch.org/t/enumerate-dataloader-slow/87778/4""" 2 | 3 | import torch 4 | 5 | 6 | class InfiniteDataLoader(torch.utils.data.DataLoader): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self._DataLoader__initialized = False 10 | self.batch_sampler = _RepeatSampler(self.batch_sampler) 11 | self._DataLoader__initialized = True 12 | self.iterator = super().__iter__() 13 | 14 | def __len__(self): 15 | return len(self.batch_sampler.sampler) 16 | 17 | def __iter__(self): 18 | for _ in range(len(self)): 19 | yield next(self.iterator) 20 | 21 | 22 | class _RepeatSampler(object): 23 | """Sampler that repeats forever. 
24 | Args: 25 | sampler (Sampler) 26 | """ 27 | 28 | def __init__(self, sampler): 29 | self.sampler = sampler 30 | 31 | def __iter__(self): 32 | while True: 33 | yield from iter(self.sampler) 34 | 35 | 36 | def infinite_iterator(dataloader): 37 | """Infinitely yield a batch of data.""" 38 | while True: 39 | for batch in iter(dataloader): 40 | yield batch 41 | -------------------------------------------------------------------------------- /data/wav2mel.py: -------------------------------------------------------------------------------- 1 | """Wav2Mel for processing audio data.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchaudio.sox_effects import apply_effects_tensor 6 | from torchaudio.transforms import MelSpectrogram 7 | 8 | 9 | class Wav2Mel(nn.Module): 10 | """Transform audio file into mel spectrogram tensors.""" 11 | 12 | def __init__( 13 | self, 14 | sample_rate: int = 16000, 15 | norm_db: float = -3.0, 16 | sil_threshold: float = 1.0, 17 | sil_duration: float = 0.1, 18 | fft_window_ms: float = 25.0, 19 | fft_hop_ms: float = 10.0, 20 | f_min: float = 50.0, 21 | n_mels: int = 40, 22 | ): 23 | super().__init__() 24 | 25 | self.sample_rate = sample_rate 26 | self.norm_db = norm_db 27 | self.sil_threshold = sil_threshold 28 | self.sil_duration = sil_duration 29 | self.fft_window_ms = fft_window_ms 30 | self.fft_hop_ms = fft_hop_ms 31 | self.f_min = f_min 32 | self.n_mels = n_mels 33 | 34 | self.sox_effects = SoxEffects(sample_rate, norm_db, sil_threshold, sil_duration) 35 | self.log_melspectrogram = LogMelspectrogram( 36 | sample_rate, fft_window_ms, fft_hop_ms, f_min, n_mels 37 | ) 38 | 39 | def forward(self, wav_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor: 40 | wav_tensor = self.sox_effects(wav_tensor, sample_rate) 41 | mel_tensor = self.log_melspectrogram(wav_tensor) 42 | return mel_tensor 43 | 44 | 45 | class SoxEffects(nn.Module): 46 | """Transform waveform tensors.""" 47 | 48 | def __init__( 49 | self, 50 | sample_rate: int, 51 | norm_db: float, 52 | sil_threshold: float, 53 | sil_duration: float, 54 | ): 55 | super().__init__() 56 | self.effects = [ 57 | ["channels", "1"], # convert to mono 58 | ["rate", f"{sample_rate}"], # resample 59 | ["norm", f"{norm_db}"], # normalize to -3 dB 60 | [ 61 | "silence", 62 | "1", 63 | f"{sil_duration}", 64 | f"{sil_threshold}%", 65 | "-1", 66 | f"{sil_duration}", 67 | f"{sil_threshold}%", 68 | ], # remove silence throughout the file 69 | ] 70 | 71 | def forward(self, wav_tensor: torch.Tensor, sample_rate: int) -> torch.Tensor: 72 | wav_tensor, _ = apply_effects_tensor(wav_tensor, sample_rate, self.effects) 73 | return wav_tensor 74 | 75 | 76 | class LogMelspectrogram(nn.Module): 77 | """Transform waveform tensors into log mel spectrogram tensors.""" 78 | 79 | def __init__( 80 | self, 81 | sample_rate: int, 82 | fft_window_ms: float, 83 | fft_hop_ms: float, 84 | f_min: float, 85 | n_mels: int, 86 | ): 87 | super().__init__() 88 | self.melspectrogram = MelSpectrogram( 89 | sample_rate=sample_rate, 90 | hop_length=int(sample_rate * fft_hop_ms / 1000), 91 | n_fft=int(sample_rate * fft_window_ms / 1000), 92 | f_min=f_min, 93 | n_mels=n_mels, 94 | ) 95 | 96 | def forward(self, wav_tensor: torch.Tensor) -> torch.Tensor: 97 | mel_tensor = self.melspectrogram(wav_tensor).squeeze(0).T # (time, n_mels) 98 | return torch.log(torch.clamp(mel_tensor, min=1e-9)) 99 | -------------------------------------------------------------------------------- /equal_error_rate.py: 
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """Compute equal error rate."""
4 | 
5 | from argparse import ArgumentParser
6 | from pathlib import Path
7 | from warnings import filterwarnings
8 | 
9 | import numpy as np
10 | import torch
11 | import torch.nn.functional as F
12 | import torchaudio
13 | from scipy.interpolate import interp1d
14 | from scipy.optimize import brentq
15 | from sklearn.metrics import roc_curve
16 | from torch.utils.data import DataLoader, Dataset
17 | from tqdm import tqdm
18 | 
19 | 
20 | def equal_error_rate(test_dir, test_txt, wav2mel_path, checkpoint_path):
21 |     """Compute equal error rate on test set."""
22 | 
23 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24 |     test_dir_path = Path(test_dir)
25 |     test_txt_path = Path(test_txt)
26 | 
27 |     wav2mel = torch.jit.load(wav2mel_path)
28 |     dvector = torch.jit.load(checkpoint_path).eval().to(device)
29 | 
30 |     pairs = []
31 |     with test_txt_path.open() as file:
32 |         for line in file:
33 |             label, audio_path1, audio_path2 = line.strip().split()
34 |             pairs.append((label, audio_path1, audio_path2))
35 | 
36 |     class MyDataset(Dataset):
37 |         def __init__(self):
38 |             self.pairs = pairs
39 | 
40 |         def __len__(self):
41 |             return len(self.pairs)
42 | 
43 |         def __getitem__(self, index):
44 |             label, path1, path2 = self.pairs[index]
45 |             audio_path1 = test_dir_path / path1
46 |             audio_path2 = test_dir_path / path2
47 |             wav_tensor1, sample_rate = torchaudio.load(audio_path1)
48 |             wav_tensor2, sample_rate = torchaudio.load(audio_path2)
49 |             mel_tensor1 = wav2mel(wav_tensor1, sample_rate)
50 |             mel_tensor2 = wav2mel(wav_tensor2, sample_rate)
51 |             return int(label), mel_tensor1, mel_tensor2
52 | 
53 |     dataloader = DataLoader(
54 |         MyDataset(),
55 |         batch_size=1,
56 |         shuffle=False,
57 |         drop_last=False,
58 |         num_workers=8,
59 |         prefetch_factor=4,
60 |     )
61 | 
62 |     labels, scores = [], []
63 |     for label, mel1, mel2 in tqdm(dataloader, ncols=0, desc="Calculate Similarity"):
64 |         mel1, mel2 = mel1.to(device), mel2.to(device)
65 |         with torch.no_grad():
66 |             emb1 = dvector.embed_utterance(mel1)
67 |             emb2 = dvector.embed_utterance(mel2)
68 |             score = F.cosine_similarity(emb1.unsqueeze(0), emb2.unsqueeze(0))
69 |         labels.append(label[0])
70 |         scores.append(score.item())
71 | 
72 |     labels = np.array(labels)
73 |     scores = np.array(scores)
74 | 
75 |     fpr, tpr, thresholds = roc_curve(labels, scores)
76 |     eer = brentq(lambda x: 1.0 - x - interp1d(fpr, tpr)(x), 0.0, 1.0)
77 |     thresh = interp1d(fpr, thresholds)(eer)
78 | 
79 |     print("eer =", eer)
80 |     print("thresh =", thresh)
81 | 
82 | 
83 | if __name__ == "__main__":
84 |     filterwarnings("ignore")
85 |     PARSER = ArgumentParser()
86 |     PARSER.add_argument("test_dir")
87 |     PARSER.add_argument("test_txt")
88 |     PARSER.add_argument("-w", "--wav2mel_path", required=True)
89 |     PARSER.add_argument("-c", "--checkpoint_path", required=True)
90 |     equal_error_rate(**vars(PARSER.parse_args()))
91 | 
--------------------------------------------------------------------------------
/images/tsne.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yistLin/dvector/3964747aff05ba51645634faea41a1fec5a68b99/images/tsne.png
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .dvector import *
2 | from .ge2e import 
GE2ELoss 3 | -------------------------------------------------------------------------------- /modules/dvector.py: -------------------------------------------------------------------------------- 1 | """Build a model for d-vector speaker embedding.""" 2 | 3 | import abc 4 | import math 5 | from typing import List 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch import Tensor 11 | 12 | 13 | class DvectorInterface(nn.Module, metaclass=abc.ABCMeta): 14 | """d-vector interface.""" 15 | 16 | @classmethod 17 | def __subclasshook__(cls, subclass): 18 | return ( 19 | hasattr(subclass, "forward") 20 | and callable(subclass.forward) 21 | and hasattr(subclass, "seg_len") 22 | or NotImplemented 23 | ) 24 | 25 | @abc.abstractmethod 26 | def forward(self, inputs: Tensor) -> Tensor: 27 | """Forward a batch through network. 28 | 29 | Args: 30 | inputs: (batch, seg_len, mel_dim) 31 | 32 | Returns: 33 | embeds: (batch, emb_dim) 34 | """ 35 | raise NotImplementedError 36 | 37 | @torch.jit.export 38 | def embed_utterance(self, utterance: Tensor) -> Tensor: 39 | """Embed an utterance by segmentation and averaging 40 | 41 | Args: 42 | utterance: (uttr_len, mel_dim) or (1, uttr_len, mel_dim) 43 | 44 | Returns: 45 | embed: (emb_dim) 46 | """ 47 | assert utterance.ndim == 2 or (utterance.ndim == 3 and utterance.size(0) == 1) 48 | 49 | if utterance.ndim == 3: 50 | utterance = utterance.squeeze(0) 51 | 52 | if utterance.size(0) <= self.seg_len: 53 | embed = self.forward(utterance.unsqueeze(0)).squeeze(0) 54 | else: 55 | # Pad to multiple of hop length 56 | hop_len = self.seg_len // 2 57 | tgt_len = math.ceil(utterance.size(0) / hop_len) * hop_len 58 | zero_padding = torch.zeros(tgt_len - utterance.size(0), utterance.size(1)) 59 | padded = torch.cat([utterance, zero_padding.to(utterance.device)]) 60 | 61 | segments = padded.unfold(0, self.seg_len, self.seg_len // 2) 62 | segments = segments.transpose(1, 2) # (batch, seg_len, mel_dim) 63 | embeds = self.forward(segments) 64 | embed = embeds.mean(dim=0) 65 | embed = embed.div(embed.norm(p=2, dim=-1, keepdim=True)) 66 | 67 | return embed 68 | 69 | @torch.jit.export 70 | def embed_utterances(self, utterances: List[Tensor]) -> Tensor: 71 | """Embed utterances by averaging the embeddings of utterances 72 | 73 | Args: 74 | utterances: [(uttr_len, mel_dim), ...] 
75 | 76 | Returns: 77 | embed: (emb_dim) 78 | """ 79 | embeds = torch.stack([self.embed_utterance(uttr) for uttr in utterances]) 80 | embed = embeds.mean(dim=0) 81 | return embed.div(embed.norm(p=2, dim=-1, keepdim=True)) 82 | 83 | 84 | class LSTMDvector(DvectorInterface): 85 | """LSTM-based d-vector.""" 86 | 87 | def __init__( 88 | self, 89 | num_layers=3, 90 | dim_input=40, 91 | dim_cell=256, 92 | dim_emb=256, 93 | seg_len=160, 94 | ): 95 | super().__init__() 96 | self.lstm = nn.LSTM(dim_input, dim_cell, num_layers, batch_first=True) 97 | self.embedding = nn.Linear(dim_cell, dim_emb) 98 | self.seg_len = seg_len 99 | 100 | def forward(self, inputs: Tensor) -> Tensor: 101 | """Forward a batch through network.""" 102 | lstm_outs, _ = self.lstm(inputs) # (batch, seg_len, dim_cell) 103 | embeds = self.embedding(lstm_outs[:, -1, :]) # (batch, dim_emb) 104 | return embeds.div(embeds.norm(p=2, dim=-1, keepdim=True)) # (batch, dim_emb) 105 | 106 | 107 | class AttentivePooledLSTMDvector(DvectorInterface): 108 | """LSTM-based d-vector with attentive pooling.""" 109 | 110 | def __init__( 111 | self, 112 | num_layers=3, 113 | dim_input=40, 114 | dim_cell=256, 115 | dim_emb=256, 116 | seg_len=160, 117 | ): 118 | super().__init__() 119 | self.lstm = nn.LSTM(dim_input, dim_cell, num_layers, batch_first=True) 120 | self.embedding = nn.Linear(dim_cell, dim_emb) 121 | self.linear = nn.Linear(dim_emb, 1) 122 | self.seg_len = seg_len 123 | 124 | def forward(self, inputs: Tensor) -> Tensor: 125 | """Forward a batch through network.""" 126 | lstm_outs, _ = self.lstm(inputs) # (batch, seg_len, dim_cell) 127 | embeds = torch.tanh(self.embedding(lstm_outs)) # (batch, seg_len, dim_emb) 128 | attn_weights = F.softmax(self.linear(embeds), dim=1) 129 | embeds = torch.sum(embeds * attn_weights, dim=1) 130 | return embeds.div(embeds.norm(p=2, dim=-1, keepdim=True)) 131 | -------------------------------------------------------------------------------- /modules/ge2e.py: -------------------------------------------------------------------------------- 1 | """PyTorch implementation of GE2E loss""" 2 | from functools import lru_cache 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class GE2ELoss(nn.Module): 10 | """Implementation of the GE2E loss in https://arxiv.org/abs/1710.10467 11 | 12 | Accepts an input of size (N, M, D) 13 | 14 | where N is the number of speakers in the batch, 15 | M is the number of utterances per speaker, 16 | and D is the dimensionality of the embedding vector (e.g. 
d-vector)
17 | 
18 |     Args:
19 |         - init_w (float): the initial value of w in Equation (5)
20 |         - init_b (float): the initial value of b in Equation (5)
21 |     """
22 | 
23 |     def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"):
24 |         super(GE2ELoss, self).__init__()
25 |         self.w = nn.Parameter(torch.FloatTensor([init_w]))
26 |         self.b = nn.Parameter(torch.FloatTensor([init_b]))
27 |         self.loss_method = loss_method
28 | 
29 |         assert self.loss_method in ["softmax", "contrast"]
30 | 
31 |         if self.loss_method == "softmax":
32 |             self.embed_loss = self.embed_loss_softmax
33 |         if self.loss_method == "contrast":
34 |             self.embed_loss = self.embed_loss_contrast
35 | 
36 |     def cosine_similarity(self, dvecs):
37 |         """Calculate cosine similarity matrix of shape (N, M, N)."""
38 |         n_spkr, n_uttr, d_embd = dvecs.size()
39 | 
40 |         dvec_expns = dvecs.unsqueeze(-1).expand(n_spkr, n_uttr, d_embd, n_spkr)
41 |         dvec_expns = dvec_expns.transpose(2, 3)
42 | 
43 |         ctrds = dvecs.mean(dim=1).to(dvecs.device)
44 |         ctrd_expns = ctrds.unsqueeze(0).expand(n_spkr * n_uttr, n_spkr, d_embd)
45 |         ctrd_expns = ctrd_expns.reshape(-1, d_embd)
46 | 
47 |         dvec_rolls = torch.cat([dvecs[:, 1:, :], dvecs[:, :-1, :]], dim=1)
48 |         dvec_excls = dvec_rolls.unfold(1, n_uttr - 1, 1)
49 |         mean_excls = dvec_excls.mean(dim=-1).reshape(-1, d_embd)
50 | 
51 |         indices = _indices_to_replace(n_spkr, n_uttr).to(dvecs.device)
52 |         ctrd_excls = ctrd_expns.index_copy(0, indices, mean_excls)
53 |         ctrd_excls = ctrd_excls.view_as(dvec_expns)
54 | 
55 |         return F.cosine_similarity(dvec_expns, ctrd_excls, 3, 1e-6)
56 | 
57 |     def embed_loss_softmax(self, dvecs, cos_sim_matrix):
58 |         """Calculate the loss on each embedding by taking softmax."""
59 |         n_spkr, n_uttr, _ = dvecs.size()
60 |         indices = _indices_to_replace(n_spkr, n_uttr).to(dvecs.device)
61 |         losses = -F.log_softmax(cos_sim_matrix, 2)
62 |         return losses.flatten().index_select(0, indices).view(n_spkr, n_uttr)
63 | 
64 |     def embed_loss_contrast(self, dvecs, cos_sim_matrix):
65 |         """Calculate the loss on each embedding by contrast loss."""
66 |         N, M, _ = dvecs.shape
67 |         L = []
68 |         for j in range(N):
69 |             L_row = []
70 |             for i in range(M):
71 |                 centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
72 |                 excl_centroids_sigmoids = torch.cat(
73 |                     (centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])
74 |                 )
75 |                 L_row.append(
76 |                     1.0
77 |                     - torch.sigmoid(cos_sim_matrix[j, i, j])
78 |                     + torch.max(excl_centroids_sigmoids)
79 |                 )
80 |             L_row = torch.stack(L_row)
81 |             L.append(L_row)
82 |         return torch.stack(L)
83 | 
84 |     def forward(self, dvecs):
85 |         """Calculate the GE2E loss for an input of dimensions (N, M, D)."""
86 |         cos_sim_matrix = self.cosine_similarity(dvecs)
87 |         # Keep the scaling factor w positive, as required by Equation (5) in the paper.
88 |         cos_sim_matrix = cos_sim_matrix * torch.clamp(self.w, min=1e-6) + self.b
89 |         L = self.embed_loss(dvecs, cos_sim_matrix)
90 |         return L.sum()
91 | 
92 | 
93 | @lru_cache(maxsize=5)
94 | def _indices_to_replace(n_spkr, n_uttr):
95 |     indices = [
96 |         (s * n_uttr + u) * n_spkr + s for s in range(n_spkr) for u in range(n_uttr)
97 |     ]
98 |     return torch.LongTensor(indices)
99 | 
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Preprocess script"""
3 | 
4 | import json
5 | from argparse import ArgumentParser
6 | from multiprocessing import cpu_count
7 | from pathlib import Path
8 | from typing import List
9 | from uuid import uuid4
10 | from warnings import filterwarnings
11 | 
12 | 
import torch 13 | import torchaudio 14 | from librosa.util import find_files 15 | from torch.utils.data import DataLoader 16 | from tqdm import tqdm 17 | 18 | from data import Wav2Mel 19 | 20 | 21 | class PreprocessDataset(torch.utils.data.Dataset): 22 | """Preprocess dataset.""" 23 | 24 | def __init__(self, data_dirs: List[str], wav2mel): 25 | self.wav2mel = wav2mel 26 | self.speakers = set() 27 | self.infos = [] 28 | 29 | for data_dir in data_dirs: 30 | speaker_dir_paths = [x for x in Path(data_dir).iterdir() if x.is_dir()] 31 | for speaker_dir_path in speaker_dir_paths: 32 | audio_paths = find_files(speaker_dir_path) 33 | speaker_name = speaker_dir_path.name 34 | self.speakers.add(speaker_name) 35 | for audio_path in audio_paths: 36 | self.infos.append((speaker_name, audio_path)) 37 | 38 | def __len__(self): 39 | return len(self.infos) 40 | 41 | def __getitem__(self, index): 42 | speaker_name, audio_path = self.infos[index] 43 | wav_tensor, sample_rate = torchaudio.load(audio_path) 44 | mel_tensor = self.wav2mel(wav_tensor, sample_rate) 45 | return speaker_name, mel_tensor 46 | 47 | 48 | def preprocess(data_dirs, output_dir): 49 | """Preprocess audio files into features for training.""" 50 | 51 | output_dir_path = Path(output_dir) 52 | output_dir_path.mkdir(parents=True, exist_ok=True) 53 | 54 | wav2mel = Wav2Mel() 55 | wav2mel_jit = torch.jit.script(wav2mel) 56 | sox_effects_jit = torch.jit.script(wav2mel.sox_effects) 57 | log_melspectrogram_jit = torch.jit.script(wav2mel.log_melspectrogram) 58 | 59 | wav2mel_jit.save(str(output_dir_path / "wav2mel.pt")) 60 | sox_effects_jit.save(str(output_dir_path / "sox_effects.pt")) 61 | log_melspectrogram_jit.save(str(output_dir_path / "log_melspectrogram.pt")) 62 | 63 | dataset = PreprocessDataset(data_dirs, wav2mel_jit) 64 | dataloader = DataLoader(dataset, batch_size=1, num_workers=cpu_count()) 65 | 66 | infos = { 67 | "n_mels": wav2mel.n_mels, 68 | "speakers": {speaker_name: [] for speaker_name in dataset.speakers}, 69 | } 70 | 71 | for speaker_name, mel_tensor in tqdm(dataloader, ncols=0, desc="Preprocess"): 72 | speaker_name = speaker_name[0] 73 | mel_tensor = mel_tensor.squeeze(0) 74 | random_file_path = output_dir_path / f"uttr-{uuid4().hex}.pt" 75 | torch.save(mel_tensor, random_file_path) 76 | infos["speakers"][speaker_name].append( 77 | { 78 | "feature_path": random_file_path.name, 79 | "mel_len": len(mel_tensor), 80 | } 81 | ) 82 | 83 | with open(output_dir_path / "metadata.json", "w") as f: 84 | json.dump(infos, f, indent=2) 85 | 86 | 87 | if __name__ == "__main__": 88 | filterwarnings("ignore") 89 | PARSER = ArgumentParser() 90 | PARSER.add_argument("data_dirs", type=str, nargs="+") 91 | PARSER.add_argument("-o", "--output_dir", type=str, required=True) 92 | preprocess(**vars(PARSER.parse_args())) 93 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """Train d-vector.""" 4 | 5 | import json 6 | from argparse import ArgumentParser 7 | from collections import deque 8 | from datetime import datetime 9 | from itertools import count 10 | from multiprocessing import cpu_count 11 | from pathlib import Path 12 | 13 | import torch 14 | from torch.optim import SGD 15 | from torch.optim.lr_scheduler import StepLR 16 | from torch.utils.data import random_split 17 | from torch.utils.tensorboard import SummaryWriter 18 | from tqdm import tqdm 19 | 20 | from data 
import GE2EDataset, InfiniteDataLoader, collate_batch, infinite_iterator 21 | from modules import AttentivePooledLSTMDvector, GE2ELoss 22 | 23 | 24 | def train( 25 | data_dir, 26 | model_dir, 27 | n_speakers, 28 | n_utterances, 29 | seg_len, 30 | save_every, 31 | valid_every, 32 | decay_every, 33 | batch_per_valid, 34 | n_workers, 35 | comment, 36 | ): 37 | """Train a d-vector network.""" 38 | 39 | # setup job name 40 | start_time = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") 41 | job_name = f"{start_time}_{comment}" if comment is not None else start_time 42 | 43 | # setup checkpoint and log dirs 44 | checkpoints_path = Path(model_dir) / "checkpoints" / job_name 45 | checkpoints_path.mkdir(parents=True, exist_ok=True) 46 | writer = SummaryWriter(Path(model_dir) / "logs" / job_name) 47 | 48 | # create data loader, iterator 49 | with open(Path(data_dir, "metadata.json"), "r") as f: 50 | metadata = json.load(f) 51 | dataset = GE2EDataset(data_dir, metadata["speakers"], n_utterances, seg_len) 52 | trainset, validset = random_split(dataset, [len(dataset) - n_speakers, n_speakers]) 53 | train_loader = InfiniteDataLoader( 54 | trainset, 55 | batch_size=n_speakers, 56 | num_workers=n_workers, 57 | collate_fn=collate_batch, 58 | drop_last=True, 59 | ) 60 | valid_loader = InfiniteDataLoader( 61 | validset, 62 | batch_size=n_speakers, 63 | num_workers=n_workers, 64 | collate_fn=collate_batch, 65 | drop_last=True, 66 | ) 67 | train_iter = infinite_iterator(train_loader) 68 | valid_iter = infinite_iterator(valid_loader) 69 | 70 | # display training infos 71 | assert len(trainset) >= n_speakers 72 | assert len(validset) >= n_speakers 73 | print(f"[INFO] Use {len(trainset)} speakers for training.") 74 | print(f"[INFO] Use {len(validset)} speakers for validation.") 75 | 76 | # build network and training tools 77 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 78 | dvector = AttentivePooledLSTMDvector( 79 | dim_input=metadata["n_mels"], 80 | seg_len=seg_len, 81 | ).to(device) 82 | dvector = torch.jit.script(dvector) 83 | criterion = GE2ELoss().to(device) 84 | optimizer = SGD(list(dvector.parameters()) + list(criterion.parameters()), lr=0.01) 85 | scheduler = StepLR(optimizer, step_size=decay_every, gamma=0.5) 86 | 87 | # record training infos 88 | pbar = tqdm(total=valid_every, ncols=0, desc="Train") 89 | running_train_loss, running_grad_norm = deque(maxlen=100), deque(maxlen=100) 90 | running_valid_loss = deque(maxlen=batch_per_valid) 91 | 92 | # start training 93 | for step in count(start=1): 94 | 95 | batch = next(train_iter).to(device) 96 | embds = dvector(batch).view(n_speakers, n_utterances, -1) 97 | loss = criterion(embds) 98 | 99 | optimizer.zero_grad() 100 | loss.backward() 101 | 102 | grad_norm = torch.nn.utils.clip_grad_norm_( 103 | list(dvector.parameters()) + list(criterion.parameters()), 104 | max_norm=3, 105 | norm_type=2.0, 106 | ) 107 | dvector.embedding.weight.grad *= 0.5 108 | dvector.embedding.bias.grad *= 0.5 109 | criterion.w.grad *= 0.01 110 | criterion.b.grad *= 0.01 111 | 112 | optimizer.step() 113 | scheduler.step() 114 | 115 | running_train_loss.append(loss.item()) 116 | running_grad_norm.append(grad_norm.item()) 117 | avg_train_loss = sum(running_train_loss) / len(running_train_loss) 118 | avg_grad_norm = sum(running_grad_norm) / len(running_grad_norm) 119 | 120 | pbar.update(1) 121 | pbar.set_postfix(loss=avg_train_loss, grad_norm=avg_grad_norm) 122 | 123 | if step % valid_every == 0: 124 | pbar.reset() 125 | 126 | for _ in range(batch_per_valid): 127 | 
                batch = next(valid_iter).to(device)
128 |                 with torch.no_grad():
129 |                     embd = dvector(batch).view(n_speakers, n_utterances, -1)
130 |                     loss = criterion(embd)
131 |                 running_valid_loss.append(loss.item())
132 | 
133 |             avg_valid_loss = sum(running_valid_loss) / len(running_valid_loss)
134 | 
135 |             tqdm.write(f"Valid: step={step}, loss={avg_valid_loss:.1f}")
136 |             writer.add_scalar("Loss/train", avg_train_loss, step)
137 |             writer.add_scalar("Loss/valid", avg_valid_loss, step)
138 | 
139 |         if step % save_every == 0:
140 |             ckpt_path = checkpoints_path / f"dvector-step{step}.pt"
141 |             dvector.cpu()
142 |             dvector.save(str(ckpt_path))
143 |             dvector.to(device)
144 | 
145 | 
146 | if __name__ == "__main__":
147 |     PARSER = ArgumentParser()
148 |     PARSER.add_argument("data_dir", type=str)
149 |     PARSER.add_argument("model_dir", type=str)
150 |     PARSER.add_argument("-n", "--n_speakers", type=int, default=64)
151 |     PARSER.add_argument("-m", "--n_utterances", type=int, default=10)
152 |     PARSER.add_argument("--seg_len", type=int, default=160)
153 |     PARSER.add_argument("--save_every", type=int, default=10000)
154 |     PARSER.add_argument("--valid_every", type=int, default=1000)
155 |     PARSER.add_argument("--decay_every", type=int, default=100000)
156 |     PARSER.add_argument("--batch_per_valid", type=int, default=100)
157 |     PARSER.add_argument("--n_workers", type=int, default=cpu_count())
158 |     PARSER.add_argument("--comment", type=str)
159 |     train(**vars(PARSER.parse_args()))
160 | 
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """Visualize speaker embeddings."""
4 | 
5 | from argparse import ArgumentParser
6 | from pathlib import Path
7 | from warnings import filterwarnings
8 | 
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import seaborn as sns
12 | import torch
13 | import torchaudio
14 | from librosa.util import find_files
15 | from sklearn.manifold import TSNE
16 | from tqdm import tqdm
17 | 
18 | def visualize(data_dirs, wav2mel_path, checkpoint_path, output_path):
19 |     """Visualize high-dimensional embeddings using t-SNE."""
20 | 
21 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22 | 
23 |     wav2mel = torch.jit.load(wav2mel_path)
24 |     dvector = torch.jit.load(checkpoint_path).eval().to(device)
25 | 
26 |     print("[INFO] model loaded.")
27 | 
28 |     n_spkrs = 0
29 |     paths, spkr_names, mels = [], [], []
30 | 
31 |     for data_dir in data_dirs:
32 |         data_dir_path = Path(data_dir)
33 |         for spkr_dir in [x for x in data_dir_path.iterdir() if x.is_dir()]:
34 |             n_spkrs += 1
35 |             audio_paths = find_files(spkr_dir)
36 |             spkr_name = spkr_dir.name
37 |             for audio_path in audio_paths:
38 |                 paths.append(audio_path)
39 |                 spkr_names.append(spkr_name)
40 | 
41 |     for audio_path in tqdm(paths, ncols=0, desc="Preprocess"):
42 |         wav_tensor, sample_rate = torchaudio.load(audio_path)
43 |         with torch.no_grad():
44 |             mel_tensor = wav2mel(wav_tensor, sample_rate)
45 |         mels.append(mel_tensor)
46 | 
47 |     embs = []
48 | 
49 |     for mel in tqdm(mels, ncols=0, desc="Embed"):
50 |         with torch.no_grad():
51 |             emb = dvector.embed_utterance(mel.to(device))
52 |             emb = emb.detach().cpu().numpy()
53 |         embs.append(emb)
54 | 
55 |     embs = np.array(embs)
56 |     tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
57 |     transformed = tsne.fit_transform(embs)
58 | 
59 |     print("[INFO] embeddings transformed.")
60 | 
61 |     data = {
62 |         "dim-1": transformed[:, 0],
63 | 
"dim-2": transformed[:, 1], 64 | "label": spkr_names, 65 | } 66 | 67 | plt.figure() 68 | sns.scatterplot( 69 | x="dim-1", 70 | y="dim-2", 71 | hue="label", 72 | palette=sns.color_palette(n_colors=n_spkrs), 73 | data=data, 74 | legend="full", 75 | ) 76 | plt.legend(loc="center left", bbox_to_anchor=(1, 0.5)) 77 | plt.tight_layout() 78 | plt.savefig(output_path) 79 | 80 | 81 | if __name__ == "__main__": 82 | filterwarnings("ignore") 83 | PARSER = ArgumentParser() 84 | PARSER.add_argument("data_dirs", type=str, nargs="+") 85 | PARSER.add_argument("-w", "--wav2mel_path", required=True) 86 | PARSER.add_argument("-c", "--checkpoint_path", required=True) 87 | PARSER.add_argument("-o", "--output_path", required=True) 88 | visualize(**vars(PARSER.parse_args())) 89 | --------------------------------------------------------------------------------