├── config ├── __init__.py └── shared_configs.py ├── utils ├── __init__.py ├── io.py └── audio.py ├── speaker_encoder ├── __init__.py ├── utils │ ├── __init__.py │ ├── visual.py │ ├── io.py │ ├── generic_utils.py │ └── prepare_voxceleb.py ├── requirements.txt ├── umap.png ├── README.md ├── speaker_encoder_config.py ├── models │ ├── lstm.py │ └── resnet.py ├── configs │ └── config.json ├── losses.py └── dataset.py ├── vi_speaker_center.py ├── README.md ├── vi_speaker_batch.py ├── vi_speaker_single.py ├── saved_models └── config.json └── LICENSE /config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /speaker_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /speaker_encoder/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /speaker_encoder/requirements.txt: -------------------------------------------------------------------------------- 1 | umap-learn 2 | numpy>=1.17.0 3 | -------------------------------------------------------------------------------- /speaker_encoder/umap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PlayVoice/VI-Speaker/HEAD/speaker_encoder/umap.png -------------------------------------------------------------------------------- /vi_speaker_center.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | single_id_path = "speaker_embedding" 5 | center_id_path = "speaker_embedding_center" 6 | 7 | os.makedirs(f"./{center_id_path}") 8 | 9 | for speaker in os.listdir(single_id_path): 10 | if os.path.isdir(f"./{single_id_path}/{speaker}"): 11 | print(f"---->{speaker}<----") 12 | subfile_num = 0 13 | speaker_cen = 0 14 | for file in os.listdir(f"./{single_id_path}/{speaker}"): 15 | if file.endswith(".npy"): 16 | source_embed = np.load(f"./{single_id_path}/{speaker}/{file}") 17 | source_embed = source_embed.astype(np.float32) 18 | speaker_cen = speaker_cen + source_embed 19 | subfile_num = subfile_num + 1 20 | speaker_cen = speaker_cen / subfile_num 21 | np.save(f"./{center_id_path}/{speaker}.npy", speaker_cen, allow_pickle=False) 22 | -------------------------------------------------------------------------------- /speaker_encoder/utils/visual.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import umap 5 | 6 | matplotlib.use("Agg") 7 | 8 | 9 | colormap = ( 10 | np.array( 11 | [ 12 | [76, 255, 0], 13 | [0, 127, 70], 14 | [255, 0, 0], 15 | [255, 217, 38], 16 | [0, 135, 255], 17 | [165, 0, 165], 18 | [255, 167, 255], 19 | [0, 255, 255], 20 | [255, 96, 38], 21 | [142, 76, 0], 22 | [33, 0, 127], 23 | [0, 0, 0], 24 | [183, 183, 183], 25 | ], 26 | dtype=np.float, 27 | ) 28 | / 255 29 | ) 30 | 31 | 32 | def plot_embeddings(embeddings, num_utter_per_speaker): 33 | embeddings = embeddings[: 10 * num_utter_per_speaker] 34 | model = umap.UMAP() 35 
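# fit_transform below maps the d-vectors to 2-D with UMAP for plotting; at most the first 10 speakers are shown and each point is colored by its ground-truth speaker
# NOTE: the colormap above is built with np.float, which was deprecated in NumPy 1.20 and removed in 1.24; plain float (or np.float64) is needed with a recent NumPy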
| projection = model.fit_transform(embeddings) 36 | num_speakers = embeddings.shape[0] // num_utter_per_speaker 37 | ground_truth = np.repeat(np.arange(num_speakers), num_utter_per_speaker) 38 | colors = [colormap[i] for i in ground_truth] 39 | 40 | fig, ax = plt.subplots(figsize=(16, 10)) 41 | _ = ax.scatter(projection[:, 0], projection[:, 1], c=colors) 42 | plt.gca().set_aspect("equal", "datalim") 43 | plt.title("UMAP projection") 44 | plt.tight_layout() 45 | plt.savefig("umap") 46 | return fig 47 | -------------------------------------------------------------------------------- /speaker_encoder/README.md: -------------------------------------------------------------------------------- 1 | ### Speaker Encoder 2 | 3 | This is an implementation of https://arxiv.org/abs/1710.10467. This model can be used for voice and speaker embedding. 4 | 5 | With the code here you can generate d-vectors for both multi-speaker and single-speaker TTS datasets, then visualise and explore them along with the associated audio files in an interactive chart. 6 | 7 | Below is an example showing embedding results of various speakers. You can generate the same plot with the provided notebook as demonstrated in [this video](https://youtu.be/KW3oO7JVa7Q). 8 | 9 | ![](umap.png) 10 | 11 | Download a pretrained model from [Released Models](https://github.com/mozilla/TTS/wiki/Released-Models) page. 12 | 13 | To run the code, you need to follow the same flow as in TTS. 14 | 15 | - Define 'config.json' for your needs. Note that, audio parameters should match your TTS model. 16 | - Example training call ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360``` 17 | - Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth.tar model/config/path/config.json dataset/path/ output_path``` . This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files. 
18 | - Watch training on Tensorboard as in TTS 19 | -------------------------------------------------------------------------------- /speaker_encoder/utils/io.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | from TTS.utils.io import save_fsspec 5 | 6 | 7 | def save_checkpoint(model, optimizer, model_loss, out_path, current_step): 8 | checkpoint_path = "checkpoint_{}.pth.tar".format(current_step) 9 | checkpoint_path = os.path.join(out_path, checkpoint_path) 10 | print(" | | > Checkpoint saving : {}".format(checkpoint_path)) 11 | 12 | new_state_dict = model.state_dict() 13 | state = { 14 | "model": new_state_dict, 15 | "optimizer": optimizer.state_dict() if optimizer is not None else None, 16 | "step": current_step, 17 | "loss": model_loss, 18 | "date": datetime.date.today().strftime("%B %d, %Y"), 19 | } 20 | save_fsspec(state, checkpoint_path) 21 | 22 | 23 | def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step): 24 | if model_loss < best_loss: 25 | new_state_dict = model.state_dict() 26 | state = { 27 | "model": new_state_dict, 28 | "optimizer": optimizer.state_dict(), 29 | "step": current_step, 30 | "loss": model_loss, 31 | "date": datetime.date.today().strftime("%B %d, %Y"), 32 | } 33 | best_loss = model_loss 34 | bestmodel_path = "best_model.pth.tar" 35 | bestmodel_path = os.path.join(out_path, bestmodel_path) 36 | print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) 37 | save_fsspec(state, bestmodel_path) 38 | return best_loss 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VI-Speaker 2 | Speaker embedding for VI-SVC and VI-SVS, alse for VITS; Use this to replace the ID to implement voice clone. 
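As a rough sketch of what "replace the ID" means: a multi-speaker VITS / VI-SVC model normally conditions its decoder on a learned `nn.Embedding` looked up by speaker ID. With this repo you load the saved 256-d d-vector `.npy` instead and feed (a projection of) it in as the speaker condition. In the sketch below, `SpeakerConditioner` and the projection layer are illustrative assumptions, not code from this repo; `gin_channels` follows the usual VITS name for the conditioning width.

```python
import numpy as np
import torch
from torch import nn


class SpeakerConditioner(nn.Module):
    """Hypothetical stand-in for a speaker-ID embedding table (nn.Embedding)."""

    def __init__(self, dvector_dim: int = 256, gin_channels: int = 256):
        super().__init__()
        # project the fixed d-vector to the conditioning width the decoder expects
        self.proj = nn.Linear(dvector_dim, gin_channels)

    def forward(self, dvector: torch.Tensor) -> torch.Tensor:
        # dvector: (B, 256) L2-normalized embedding produced by this repo
        return self.proj(dvector).unsqueeze(-1)  # (B, gin_channels, 1), VITS-style "g"


# usage: load an embedding written by vi_speaker_batch.py, or a per-speaker
# center written by vi_speaker_center.py, instead of looking up a speaker ID
embed = torch.from_numpy(np.load("speaker_embedding_center/spk1.npy")).unsqueeze(0)
g = SpeakerConditioner()(embed)  # pass g wherever the model used spk_emb(speaker_id)
```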
3 | 4 | # code from mozill_tts and Coqpit/TTS 5 | https://github.com/mozilla/TTS/tree/master/TTS/speaker_encoder 6 | 7 | https://github.com/coqui-ai/TTS 8 | 9 | pip install coqpit 10 | 11 | # download model, 12 | https://github.com/mozilla/TTS/wiki/Released-Models 13 | 14 | Speaker-Encoder by @mueller91 LibriTTS + VCTK + VoxCeleb + CommonVoice 15 | 16 | https://drive.google.com/drive/folders/15oeBYf6Qn1edONkVLXe82MzdIi3O_9m3 17 | 18 | Or get it at release **saved_models.zip** 19 | 20 | # use 21 | python vi_speaker_single.py ./saved_models/best_model.pth.tar ./saved_models/config.json -s TEST.wav -t TEST.npy 22 | 23 | # batch use 24 | python vi_speaker_batch.py ./saved_models/best_model.pth.tar ./saved_models/config.json ./data/waves ./speaker_embedding 25 | 26 | data/ 27 | └── waves 28 | ├── spk1 29 | │   ├── 000002.wav 30 | │   ├── 000006.wav 31 | │   └── 000038.wav 32 | └── spk2 33 | ├── 000040.wav 34 | ├── 000044.wav 35 | └── 000077.wav 36 | 37 | speaker_embedding/ 38 | ├── spk1 39 | │   ├── 000002.npy 40 | │   ├── 000006.npy 41 | │   └── 000038.npy 42 | └── spk2 43 | ├── 000040.npy 44 | ├── 000044.npy 45 | └── 000077.npy 46 | 47 | # compute speaker center 48 | input path = speaker_embedding, output path = speaker_embedding_center 49 | 50 | python vi_speaker_center.py 51 | 52 | speaker_embedding_center/ 53 | ├── spk1.npy 54 | └── spk2.npy 55 | 56 | 57 | # for VI-SVC 58 | mv speaker_embedding_center data/spkid 59 | 60 | data/ 61 | ├── waves 62 | │   ├── 10001 63 | │   ├── 20400 64 | │   │   ├── 20400_001.wav 65 | │   │   ├── 20456_019.wav 66 | │   │   67 | ├── phone 68 | │   ├── 10001 69 | │   ├── 20400 70 | │   │   ├── 20400_001.npy 71 | │   │   ├── 20456_019.npy 72 | │   │   73 | ├── lable 74 | │   ├── 10001 75 | │   ├── 20400 76 | │   │   ├── 20400_001.npy 77 | │   │   ├── 20456_019.npy 78 | │   │   79 | ├── spkid 80 | │   ├── 10001.npy 81 | │   ├── 20400.npy 82 | │   │   83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /speaker_encoder/speaker_encoder_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass, field 2 | from typing import Dict, List 3 | 4 | from coqpit import MISSING 5 | 6 | from config.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig 7 | 8 | 9 | @dataclass 10 | class SpeakerEncoderConfig(BaseTrainingConfig): 11 | """Defines parameters for Speaker Encoder model.""" 12 | 13 | model: str = "speaker_encoder" 14 | audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) 15 | datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) 16 | # model params 17 | model_params: Dict = field( 18 | default_factory=lambda: { 19 | "model_name": "lstm", 20 | "input_dim": 80, 21 | "proj_dim": 256, 22 | "lstm_dim": 768, 23 | "num_lstm_layers": 3, 24 | "use_lstm_with_projection": True, 25 | } 26 | ) 27 | 28 | audio_augmentation: Dict = field(default_factory=lambda: {}) 29 | 30 | storage: Dict = field( 31 | default_factory=lambda: { 32 | "sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage 33 | "storage_size": 15, # the size of the in-memory storage with respect to a single batch 34 | } 35 | ) 36 | 37 | # training params 38 | max_train_step: int = 1000000 # end training when number of training steps reaches this value. 
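# the loss choices documented for this config are implemented in speaker_encoder/losses.py: "ge2e" (Generalized End-to-End) and "angleproto" (Angular Prototypical)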
39 | loss: str = "angleproto" 40 | grad_clip: float = 3.0 41 | lr: float = 0.0001 42 | lr_decay: bool = False 43 | warmup_steps: int = 4000 44 | wd: float = 1e-6 45 | 46 | # logging params 47 | tb_model_param_stats: bool = False 48 | steps_plot_stats: int = 10 49 | checkpoint: bool = True 50 | save_step: int = 1000 51 | print_step: int = 20 52 | 53 | # data loader 54 | num_speakers_in_batch: int = MISSING 55 | num_utters_per_speaker: int = MISSING 56 | num_loader_workers: int = MISSING 57 | skip_speakers: bool = False 58 | voice_len: float = 1.6 59 | 60 | def check_values(self): 61 | super().check_values() 62 | c = asdict(self) 63 | assert ( 64 | c["model_params"]["input_dim"] == self.audio.num_mels 65 | ), " [!] model input dimendion must be equal to melspectrogram dimension." 66 | -------------------------------------------------------------------------------- /vi_speaker_batch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import fsspec 5 | import torch 6 | import numpy as np 7 | import argparse 8 | 9 | from tqdm import tqdm 10 | from argparse import RawTextHelpFormatter 11 | from speaker_encoder.models.lstm import LSTMSpeakerEncoder 12 | from speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig 13 | 14 | from utils.audio import AudioProcessor 15 | from vi_speaker_single import read_json 16 | 17 | 18 | def get_spk_wavs(dataset_path, output_path): 19 | wav_files = [] 20 | os.makedirs(f"./{output_path}") 21 | for spks in os.listdir(dataset_path): 22 | if os.path.isdir(f"./{dataset_path}/{spks}"): 23 | os.makedirs(f"./{output_path}/{spks}") 24 | for file in os.listdir(f"./{dataset_path}/{spks}"): 25 | if file.endswith(".wav"): 26 | wav_files.append(f"./{dataset_path}/{spks}/{file}") 27 | elif spks.endswith(".wav"): 28 | wav_files.append(f"./{dataset_path}/{spks}") 29 | return wav_files 30 | 31 | 32 | if __name__ == "__main__": 33 | 34 | parser = argparse.ArgumentParser( 35 | description="""Compute embedding vectors for each wav file in a dataset.""", 36 | formatter_class=RawTextHelpFormatter, 37 | ) 38 | parser.add_argument("model_path", type=str, help="Path to model checkpoint file.") 39 | parser.add_argument("config_path", type=str, help="Path to model config file.") 40 | parser.add_argument("dataset_path", type=str, help="Path to dataset waves.") 41 | parser.add_argument( 42 | "output_path", type=str, help="path for output speaker/speaker_wavs.npy." 
43 | ) 44 | parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) 45 | parser.add_argument("--eval", type=bool, help="compute eval.", default=True) 46 | args = parser.parse_args() 47 | dataset_path = args.dataset_path 48 | output_path = args.output_path 49 | 50 | # config 51 | config_dict = read_json(args.config_path) 52 | 53 | # model 54 | config = SpeakerEncoderConfig(config_dict) 55 | config.from_dict(config_dict) 56 | 57 | speaker_encoder = LSTMSpeakerEncoder( 58 | config.model_params["input_dim"], 59 | config.model_params["proj_dim"], 60 | config.model_params["lstm_dim"], 61 | config.model_params["num_lstm_layers"], 62 | ) 63 | 64 | speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda) 65 | 66 | # preprocess 67 | speaker_encoder_ap = AudioProcessor(**config.audio) 68 | # normalize the input audio level and trim silences 69 | speaker_encoder_ap.do_sound_norm = True 70 | speaker_encoder_ap.do_trim_silence = True 71 | 72 | wav_files = get_spk_wavs(dataset_path, output_path) 73 | 74 | # compute speaker embeddings 75 | for idx, wav_file in enumerate(tqdm(wav_files)): 76 | waveform = speaker_encoder_ap.load_wav( 77 | wav_file, sr=speaker_encoder_ap.sample_rate 78 | ) 79 | spec = speaker_encoder_ap.melspectrogram(waveform) 80 | spec = torch.from_numpy(spec.T) 81 | if args.use_cuda: 82 | spec = spec.cuda() 83 | spec = spec.unsqueeze(0) 84 | embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy() 85 | embed = embed.squeeze() 86 | embed_path = wav_file.replace(dataset_path, output_path) 87 | embed_path = embed_path.replace(".wav", ".npy") 88 | np.save(embed_path, embed, allow_pickle=False) 89 | -------------------------------------------------------------------------------- /vi_speaker_single.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import fsspec 4 | import torch 5 | import numpy as np 6 | import argparse 7 | 8 | from argparse import RawTextHelpFormatter 9 | from speaker_encoder.models.lstm import LSTMSpeakerEncoder 10 | from speaker_encoder.speaker_encoder_config import SpeakerEncoderConfig 11 | 12 | from utils.audio import AudioProcessor 13 | 14 | 15 | def read_json(json_path): 16 | config_dict = {} 17 | try: 18 | with fsspec.open(json_path, "r", encoding="utf-8") as f: 19 | data = json.load(f) 20 | except json.decoder.JSONDecodeError: 21 | # backwards compat. 
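# the config.json files shipped with this repo contain // comments, which json.load rejects, so fall back to the comment-stripping reader below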
22 | data = read_json_with_comments(json_path) 23 | config_dict.update(data) 24 | return config_dict 25 | 26 | 27 | def read_json_with_comments(json_path): 28 | """for backward compat.""" 29 | # fallback to json 30 | with fsspec.open(json_path, "r", encoding="utf-8") as f: 31 | input_str = f.read() 32 | # handle comments 33 | input_str = re.sub(r"\\\n", "", input_str) 34 | input_str = re.sub(r"//.*\n", "\n", input_str) 35 | data = json.loads(input_str) 36 | return data 37 | 38 | 39 | if __name__ == "__main__": 40 | 41 | parser = argparse.ArgumentParser( 42 | description="""Compute embedding vectors for each wav file in a dataset.""", 43 | formatter_class=RawTextHelpFormatter, 44 | ) 45 | parser.add_argument("model_path", type=str, help="Path to model checkpoint file.") 46 | parser.add_argument( 47 | "config_path", 48 | type=str, 49 | help="Path to model config file.", 50 | ) 51 | 52 | parser.add_argument("-s", "--source", help="input wave", dest="source") 53 | parser.add_argument( 54 | "-t", "--target", help="output 256d speaker embeddimg", dest="target" 55 | ) 56 | 57 | parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) 58 | parser.add_argument("--eval", type=bool, help="compute eval.", default=True) 59 | 60 | args = parser.parse_args() 61 | source_file = args.source 62 | target_file = args.target 63 | 64 | # config 65 | config_dict = read_json(args.config_path) 66 | # print(config_dict) 67 | 68 | # model 69 | config = SpeakerEncoderConfig(config_dict) 70 | config.from_dict(config_dict) 71 | 72 | speaker_encoder = LSTMSpeakerEncoder( 73 | config.model_params["input_dim"], 74 | config.model_params["proj_dim"], 75 | config.model_params["lstm_dim"], 76 | config.model_params["num_lstm_layers"], 77 | ) 78 | 79 | speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda) 80 | 81 | # preprocess 82 | speaker_encoder_ap = AudioProcessor(**config.audio) 83 | # normalize the input audio level and trim silences 84 | speaker_encoder_ap.do_sound_norm = True 85 | speaker_encoder_ap.do_trim_silence = True 86 | 87 | # compute speaker embeddings 88 | 89 | # extract the embedding 90 | waveform = speaker_encoder_ap.load_wav( 91 | source_file, sr=speaker_encoder_ap.sample_rate 92 | ) 93 | spec = speaker_encoder_ap.melspectrogram(waveform) 94 | spec = torch.from_numpy(spec.T) 95 | if args.use_cuda: 96 | spec = spec.cuda() 97 | spec = spec.unsqueeze(0) 98 | embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy() 99 | embed = embed.squeeze() 100 | # print(embed) 101 | # print(embed.size) 102 | np.save(target_file, embed, allow_pickle=False) 103 | 104 | 105 | if hasattr(speaker_encoder, 'module'): 106 | state_dict = speaker_encoder.module.state_dict() 107 | else: 108 | state_dict = speaker_encoder.state_dict() 109 | torch.save({'model': state_dict}, "model_small.pth") 110 | -------------------------------------------------------------------------------- /speaker_encoder/models/lstm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | from utils.io import load_fsspec 6 | 7 | 8 | class LSTMWithProjection(nn.Module): 9 | def __init__(self, input_size, hidden_size, proj_size): 10 | super().__init__() 11 | self.input_size = input_size 12 | self.hidden_size = hidden_size 13 | self.proj_size = proj_size 14 | self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True) 15 | self.linear = nn.Linear(hidden_size, proj_size, bias=False) 16 | 17 | 
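# run the LSTM over the input mel frames and project each timestep's hidden state down to proj_size (the LSTM-with-projection layer used by the GE2E speaker encoder)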
def forward(self, x): 18 | self.lstm.flatten_parameters() 19 | o, (_, _) = self.lstm(x) 20 | return self.linear(o) 21 | 22 | 23 | class LSTMWithoutProjection(nn.Module): 24 | def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers): 25 | super().__init__() 26 | self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True) 27 | self.linear = nn.Linear(lstm_dim, proj_dim, bias=True) 28 | self.relu = nn.ReLU() 29 | 30 | def forward(self, x): 31 | _, (hidden, _) = self.lstm(x) 32 | return self.relu(self.linear(hidden[-1])) 33 | 34 | 35 | class LSTMSpeakerEncoder(nn.Module): 36 | def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True): 37 | super().__init__() 38 | self.use_lstm_with_projection = use_lstm_with_projection 39 | layers = [] 40 | # choise LSTM layer 41 | if use_lstm_with_projection: 42 | layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim)) 43 | for _ in range(num_lstm_layers - 1): 44 | layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim)) 45 | self.layers = nn.Sequential(*layers) 46 | else: 47 | self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) 48 | 49 | self._init_layers() 50 | 51 | def _init_layers(self): 52 | for name, param in self.layers.named_parameters(): 53 | if "bias" in name: 54 | nn.init.constant_(param, 0.0) 55 | elif "weight" in name: 56 | nn.init.xavier_normal_(param) 57 | 58 | def forward(self, x): 59 | # TODO: implement state passing for lstms 60 | d = self.layers(x) 61 | if self.use_lstm_with_projection: 62 | d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) 63 | else: 64 | d = torch.nn.functional.normalize(d, p=2, dim=1) 65 | return d 66 | 67 | @torch.no_grad() 68 | def inference(self, x): 69 | d = self.layers.forward(x) 70 | if self.use_lstm_with_projection: 71 | d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) 72 | else: 73 | d = torch.nn.functional.normalize(d, p=2, dim=1) 74 | return d 75 | 76 | def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): 77 | """ 78 | Generate embeddings for a batch of utterances 79 | x: 1xTxD 80 | """ 81 | max_len = x.shape[1] 82 | 83 | if max_len < num_frames: 84 | num_frames = max_len 85 | 86 | offsets = np.linspace(0, max_len - num_frames, num=num_eval) 87 | 88 | frames_batch = [] 89 | for offset in offsets: 90 | offset = int(offset) 91 | end_offset = int(offset + num_frames) 92 | frames = x[:, offset:end_offset] 93 | frames_batch.append(frames) 94 | 95 | frames_batch = torch.cat(frames_batch, dim=0) 96 | embeddings = self.inference(frames_batch) 97 | 98 | if return_mean: 99 | embeddings = torch.mean(embeddings, dim=0, keepdim=True) 100 | 101 | return embeddings 102 | 103 | def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5): 104 | """ 105 | Generate embeddings for a batch of utterances 106 | x: BxTxD 107 | """ 108 | num_overlap = num_frames * overlap 109 | max_len = x.shape[1] 110 | embed = None 111 | num_iters = seq_lens / (num_frames - num_overlap) 112 | cur_iter = 0 113 | for offset in range(0, max_len, num_frames - num_overlap): 114 | cur_iter += 1 115 | end_offset = min(x.shape[1], offset + num_frames) 116 | frames = x[:, offset:end_offset] 117 | if embed is None: 118 | embed = self.inference(frames) 119 | else: 120 | embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :]) 121 | return embed / num_iters 122 | 123 | # pylint: disable=unused-argument, redefined-builtin 
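# load_checkpoint restores the weights stored under the "model" key by the save_* helpers; use_cuda moves the encoder to the GPU and eval=True switches it to inference mode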
124 | def load_checkpoint(self, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): 125 | state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) 126 | self.load_state_dict(state["model"]) 127 | if use_cuda: 128 | self.cuda() 129 | if eval: 130 | self.eval() 131 | assert not self.training 132 | -------------------------------------------------------------------------------- /saved_models/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "lstm", 3 | "run_name": "mueller91", 4 | "run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ", 5 | "audio":{ 6 | // Audio processing parameters 7 | "num_mels": 80, // size of the mel spec frame. 8 | "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. 9 | "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. 10 | "win_length": 1024, // stft window length in ms. 11 | "hop_length": 256, // stft window hop-lengh in ms. 12 | "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. 13 | "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. 14 | "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. 15 | "min_level_db": -100, // normalization range 16 | "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. 17 | "power": 1.5, // value to sharpen wav signals after GL algorithm. 18 | "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. 19 | // Normalization parameters 20 | "signal_norm": true, // normalize the spec values in range [0, 1] 21 | "symmetric_norm": true, // move normalization to range [-1, 1] 22 | "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] 23 | "clip_norm": true, // clip normalized values into the range. 24 | "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! 25 | "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! 26 | "do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) 27 | "trim_db": 60 // threshold for timming silence. Set this according to your dataset. 28 | }, 29 | "reinit_layers": [], 30 | "loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA) 31 | "grad_clip": 3.0, // upper limit for gradients for clipping. 32 | "epochs": 1000, // total number of epochs to train. 33 | "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. 34 | "lr_decay": false, // if true, Noam learning rate decaying is applied through training. 35 | "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" 36 | "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. 37 | "steps_plot_stats": 10, // number of steps to plot embeddings. 38 | "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. 
39 | "voice_len": 2.0, // size of the voice 40 | "num_utters_per_speaker": 10, // 41 | "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values. 42 | "wd": 0.000001, // Weight decay weight. 43 | "checkpoint": true, // If true, it saves checkpoints per "save_step" 44 | "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints. 45 | "print_step": 20, // Number of steps to log traning on console. 46 | "output_path": "../../OutputsMozilla/checkpoints/speaker_encoder/", // DATASET-RELATED: output path for all training outputs. 47 | "model": { 48 | "input_dim": 80, 49 | "proj_dim": 256, 50 | "lstm_dim": 768, 51 | "num_lstm_layers": 3, 52 | "use_lstm_with_projection": true 53 | }, 54 | "storage": { 55 | "sample_from_storage_p": 0.9, // the probability with which we'll sample from the DataSet in-memory storage 56 | "storage_size": 25, // the size of the in-memory storage with respect to a single batch 57 | "additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness 58 | }, 59 | "datasets": 60 | [ 61 | { 62 | "name": "vctk_slim", 63 | "path": "../../../audio-datasets/en/VCTK-Corpus/", 64 | "meta_file_train": null, 65 | "meta_file_val": null 66 | }, 67 | { 68 | "name": "libri_tts", 69 | "path": "../../../audio-datasets/en/LibriTTS/train-clean-100", 70 | "meta_file_train": null, 71 | "meta_file_val": null 72 | }, 73 | { 74 | "name": "libri_tts", 75 | "path": "../../../audio-datasets/en/LibriTTS/train-clean-360", 76 | "meta_file_train": null, 77 | "meta_file_val": null 78 | }, 79 | { 80 | "name": "libri_tts", 81 | "path": "../../../audio-datasets/en/LibriTTS/train-other-500", 82 | "meta_file_train": null, 83 | "meta_file_val": null 84 | }, 85 | { 86 | "name": "voxceleb1", 87 | "path": "../../../audio-datasets/en/voxceleb1/", 88 | "meta_file_train": null, 89 | "meta_file_val": null 90 | }, 91 | { 92 | "name": "voxceleb2", 93 | "path": "../../../audio-datasets/en/voxceleb2/", 94 | "meta_file_train": null, 95 | "meta_file_val": null 96 | }, 97 | { 98 | "name": "common_voice", 99 | "path": "../../../audio-datasets/en/MozillaCommonVoice", 100 | "meta_file_train": "train.tsv", 101 | "meta_file_val": "test.tsv" 102 | } 103 | ] 104 | } -------------------------------------------------------------------------------- /speaker_encoder/configs/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name": "lstm", 3 | "run_name": "mueller91", 4 | "run_description": "train speaker encoder with voxceleb1, voxceleb2 and libriSpeech ", 5 | "audio":{ 6 | // Audio processing parameters 7 | "num_mels": 80, // size of the mel spec frame. 8 | "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. 9 | "sample_rate": 16000, // DATASET-RELATED: wav sample-rate. If different than the original data, it is resampled. 10 | "win_length": 1024, // stft window length in ms. 11 | "hop_length": 256, // stft window hop-lengh in ms. 12 | "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. 13 | "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. 14 | "preemphasis": 0.98, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. 15 | "min_level_db": -100, // normalization range 16 | "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. 
17 | "power": 1.5, // value to sharpen wav signals after GL algorithm. 18 | "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. 19 | // Normalization parameters 20 | "signal_norm": true, // normalize the spec values in range [0, 1] 21 | "symmetric_norm": true, // move normalization to range [-1, 1] 22 | "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] 23 | "clip_norm": true, // clip normalized values into the range. 24 | "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! 25 | "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!! 26 | "do_trim_silence": true, // enable trimming of slience of audio as you load it. LJspeech (false), TWEB (false), Nancy (true) 27 | "trim_db": 60 // threshold for timming silence. Set this according to your dataset. 28 | }, 29 | "reinit_layers": [], 30 | "loss": "angleproto", // "ge2e" to use Generalized End-to-End loss and "angleproto" to use Angular Prototypical loss (new SOTA) 31 | "grad_clip": 3.0, // upper limit for gradients for clipping. 32 | "epochs": 1000, // total number of epochs to train. 33 | "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate. 34 | "lr_decay": false, // if true, Noam learning rate decaying is applied through training. 35 | "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" 36 | "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging. 37 | "steps_plot_stats": 10, // number of steps to plot embeddings. 38 | "num_speakers_in_batch": 64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. 39 | "voice_len": 2.0, // size of the voice 40 | "num_utters_per_speaker": 10, // 41 | "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values. 42 | "wd": 0.000001, // Weight decay weight. 43 | "checkpoint": true, // If true, it saves checkpoints per "save_step" 44 | "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints. 45 | "print_step": 20, // Number of steps to log traning on console. 46 | "output_path": "../../OutputsMozilla/checkpoints/speaker_encoder/", // DATASET-RELATED: output path for all training outputs. 
47 | "model": { 48 | "input_dim": 80, 49 | "proj_dim": 256, 50 | "lstm_dim": 768, 51 | "num_lstm_layers": 3, 52 | "use_lstm_with_projection": true 53 | }, 54 | "storage": { 55 | "sample_from_storage_p": 0.9, // the probability with which we'll sample from the DataSet in-memory storage 56 | "storage_size": 25, // the size of the in-memory storage with respect to a single batch 57 | "additive_noise": 1e-5 // add very small gaussian noise to the data in order to increase robustness 58 | }, 59 | "datasets": 60 | [ 61 | { 62 | "name": "vctk_slim", 63 | "path": "../../../audio-datasets/en/VCTK-Corpus/", 64 | "meta_file_train": null, 65 | "meta_file_val": null 66 | }, 67 | { 68 | "name": "libri_tts", 69 | "path": "../../../audio-datasets/en/LibriTTS/train-clean-100", 70 | "meta_file_train": null, 71 | "meta_file_val": null 72 | }, 73 | { 74 | "name": "libri_tts", 75 | "path": "../../../audio-datasets/en/LibriTTS/train-clean-360", 76 | "meta_file_train": null, 77 | "meta_file_val": null 78 | }, 79 | { 80 | "name": "libri_tts", 81 | "path": "../../../audio-datasets/en/LibriTTS/train-other-500", 82 | "meta_file_train": null, 83 | "meta_file_val": null 84 | }, 85 | { 86 | "name": "voxceleb1", 87 | "path": "../../../audio-datasets/en/voxceleb1/", 88 | "meta_file_train": null, 89 | "meta_file_val": null 90 | }, 91 | { 92 | "name": "voxceleb2", 93 | "path": "../../../audio-datasets/en/voxceleb2/", 94 | "meta_file_train": null, 95 | "meta_file_val": null 96 | }, 97 | { 98 | "name": "common_voice", 99 | "path": "../../../audio-datasets/en/MozillaCommonVoice", 100 | "meta_file_train": "train.tsv", 101 | "meta_file_val": "test.tsv" 102 | } 103 | ] 104 | } -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import os 4 | import pickle as pickle_tts 5 | import shutil 6 | from typing import Any, Callable, Dict, Union 7 | 8 | import fsspec 9 | import torch 10 | from coqpit import Coqpit 11 | 12 | 13 | class RenamingUnpickler(pickle_tts.Unpickler): 14 | """Overload default pickler to solve module renaming problem""" 15 | 16 | def find_class(self, module, name): 17 | return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name) 18 | 19 | 20 | class AttrDict(dict): 21 | """A custom dict which converts dict keys 22 | to class attributes""" 23 | 24 | def __init__(self, *args, **kwargs): 25 | super().__init__(*args, **kwargs) 26 | self.__dict__ = self 27 | 28 | 29 | def copy_model_files(config: Coqpit, out_path, new_fields): 30 | """Copy config.json and other model files to training folder and add 31 | new fields. 32 | 33 | Args: 34 | config (Coqpit): Coqpit config defining the training run. 35 | out_path (str): output path to copy the file. 36 | new_fields (dict): new fileds to be added or edited 37 | in the config file. 38 | """ 39 | copy_config_path = os.path.join(out_path, "config.json") 40 | # add extra information fields 41 | config.update(new_fields, allow_new=True) 42 | # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths. 
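# fsspec lets the same code write configs to local disk or to remote storage (e.g. s3://, gs://), mirroring load_fsspec/save_fsspec defined below in this file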
43 | with fsspec.open(copy_config_path, "w", encoding="utf8") as f: 44 | json.dump(config.to_dict(), f, indent=4) 45 | 46 | # copy model stats file if available 47 | if config.audio.stats_path is not None: 48 | copy_stats_path = os.path.join(out_path, "scale_stats.npy") 49 | filesystem = fsspec.get_mapper(copy_stats_path).fs 50 | if not filesystem.exists(copy_stats_path): 51 | with fsspec.open(config.audio.stats_path, "rb") as source_file: 52 | with fsspec.open(copy_stats_path, "wb") as target_file: 53 | shutil.copyfileobj(source_file, target_file) 54 | 55 | 56 | def load_fsspec( 57 | path: str, 58 | map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, 59 | **kwargs, 60 | ) -> Any: 61 | """Like torch.load but can load from other locations (e.g. s3:// , gs://). 62 | 63 | Args: 64 | path: Any path or url supported by fsspec. 65 | map_location: torch.device or str. 66 | **kwargs: Keyword arguments forwarded to torch.load. 67 | 68 | Returns: 69 | Object stored in path. 70 | """ 71 | with fsspec.open(path, "rb") as f: 72 | return torch.load(f, map_location=map_location, **kwargs) 73 | 74 | 75 | def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin 76 | try: 77 | state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) 78 | except ModuleNotFoundError: 79 | pickle_tts.Unpickler = RenamingUnpickler 80 | state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) 81 | model.load_state_dict(state["model"]) 82 | if use_cuda: 83 | model.cuda() 84 | if eval: 85 | model.eval() 86 | return model, state 87 | 88 | 89 | def save_fsspec(state: Any, path: str, **kwargs): 90 | """Like torch.save but can save to other locations (e.g. s3:// , gs://). 91 | 92 | Args: 93 | state: State object to save 94 | path: Any path or url supported by fsspec. 95 | **kwargs: Keyword arguments forwarded to torch.save. 
96 | """ 97 | with fsspec.open(path, "wb") as f: 98 | torch.save(state, f, **kwargs) 99 | 100 | 101 | def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs): 102 | if hasattr(model, "module"): 103 | model_state = model.module.state_dict() 104 | else: 105 | model_state = model.state_dict() 106 | if isinstance(optimizer, list): 107 | optimizer_state = [optim.state_dict() for optim in optimizer] 108 | else: 109 | optimizer_state = optimizer.state_dict() if optimizer is not None else None 110 | 111 | if isinstance(scaler, list): 112 | scaler_state = [s.state_dict() for s in scaler] 113 | else: 114 | scaler_state = scaler.state_dict() if scaler is not None else None 115 | 116 | if isinstance(config, Coqpit): 117 | config = config.to_dict() 118 | 119 | state = { 120 | "config": config, 121 | "model": model_state, 122 | "optimizer": optimizer_state, 123 | "scaler": scaler_state, 124 | "step": current_step, 125 | "epoch": epoch, 126 | "date": datetime.date.today().strftime("%B %d, %Y"), 127 | } 128 | state.update(kwargs) 129 | save_fsspec(state, output_path) 130 | 131 | 132 | def save_checkpoint( 133 | config, 134 | model, 135 | optimizer, 136 | scaler, 137 | current_step, 138 | epoch, 139 | output_folder, 140 | **kwargs, 141 | ): 142 | file_name = "checkpoint_{}.pth.tar".format(current_step) 143 | checkpoint_path = os.path.join(output_folder, file_name) 144 | print("\n > CHECKPOINT : {}".format(checkpoint_path)) 145 | save_model( 146 | config, 147 | model, 148 | optimizer, 149 | scaler, 150 | current_step, 151 | epoch, 152 | checkpoint_path, 153 | **kwargs, 154 | ) 155 | 156 | 157 | def save_best_model( 158 | current_loss, 159 | best_loss, 160 | config, 161 | model, 162 | optimizer, 163 | scaler, 164 | current_step, 165 | epoch, 166 | out_path, 167 | keep_all_best=False, 168 | keep_after=10000, 169 | **kwargs, 170 | ): 171 | if current_loss < best_loss: 172 | best_model_name = f"best_model_{current_step}.pth.tar" 173 | checkpoint_path = os.path.join(out_path, best_model_name) 174 | print(" > BEST MODEL : {}".format(checkpoint_path)) 175 | save_model( 176 | config, 177 | model, 178 | optimizer, 179 | scaler, 180 | current_step, 181 | epoch, 182 | checkpoint_path, 183 | model_loss=current_loss, 184 | **kwargs, 185 | ) 186 | fs = fsspec.get_mapper(out_path).fs 187 | # only delete previous if current is saved successfully 188 | if not keep_all_best or (current_step < keep_after): 189 | model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar")) 190 | for model_name in model_names: 191 | if os.path.basename(model_name) != best_model_name: 192 | fs.rm(model_name) 193 | # create a shortcut which always points to the currently best model 194 | shortcut_name = "best_model.pth.tar" 195 | shortcut_path = os.path.join(out_path, shortcut_name) 196 | fs.copy(checkpoint_path, shortcut_path) 197 | best_loss = current_loss 198 | return best_loss 199 | -------------------------------------------------------------------------------- /speaker_encoder/models/resnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | from TTS.utils.io import load_fsspec 6 | 7 | 8 | class SELayer(nn.Module): 9 | def __init__(self, channel, reduction=8): 10 | super(SELayer, self).__init__() 11 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 12 | self.fc = nn.Sequential( 13 | nn.Linear(channel, channel // reduction), 14 | nn.ReLU(inplace=True), 15 | nn.Linear(channel // reduction, channel), 16 
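# squeeze-and-excitation gate: the sigmoid below turns the two linear layers' output into per-channel weights in (0, 1) that rescale the block's feature maps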
| nn.Sigmoid(), 17 | ) 18 | 19 | def forward(self, x): 20 | b, c, _, _ = x.size() 21 | y = self.avg_pool(x).view(b, c) 22 | y = self.fc(y).view(b, c, 1, 1) 23 | return x * y 24 | 25 | 26 | class SEBasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): 30 | super(SEBasicBlock, self).__init__() 31 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 32 | self.bn1 = nn.BatchNorm2d(planes) 33 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.se = SELayer(planes, reduction) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.relu(out) 45 | out = self.bn1(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | out = self.se(out) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | return out 57 | 58 | 59 | class ResNetSpeakerEncoder(nn.Module): 60 | """Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153 61 | Adapted from: https://github.com/clovaai/voxceleb_trainer 62 | """ 63 | 64 | # pylint: disable=W0102 65 | def __init__( 66 | self, 67 | input_dim=64, 68 | proj_dim=512, 69 | layers=[3, 4, 6, 3], 70 | num_filters=[32, 64, 128, 256], 71 | encoder_type="ASP", 72 | log_input=False, 73 | ): 74 | super(ResNetSpeakerEncoder, self).__init__() 75 | 76 | self.encoder_type = encoder_type 77 | self.input_dim = input_dim 78 | self.log_input = log_input 79 | self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) 80 | self.relu = nn.ReLU(inplace=True) 81 | self.bn1 = nn.BatchNorm2d(num_filters[0]) 82 | 83 | self.inplanes = num_filters[0] 84 | self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) 85 | self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) 86 | self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) 87 | self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) 88 | 89 | self.instancenorm = nn.InstanceNorm1d(input_dim) 90 | 91 | outmap_size = int(self.input_dim / 8) 92 | 93 | self.attention = nn.Sequential( 94 | nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), 95 | nn.ReLU(), 96 | nn.BatchNorm1d(128), 97 | nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), 98 | nn.Softmax(dim=2), 99 | ) 100 | 101 | if self.encoder_type == "SAP": 102 | out_dim = num_filters[3] * outmap_size 103 | elif self.encoder_type == "ASP": 104 | out_dim = num_filters[3] * outmap_size * 2 105 | else: 106 | raise ValueError("Undefined encoder") 107 | 108 | self.fc = nn.Linear(out_dim, proj_dim) 109 | 110 | self._init_layers() 111 | 112 | def _init_layers(self): 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 116 | elif isinstance(m, nn.BatchNorm2d): 117 | nn.init.constant_(m.weight, 1) 118 | nn.init.constant_(m.bias, 0) 119 | 120 | def create_layer(self, block, planes, blocks, stride=1): 121 | downsample = None 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | nn.Conv2d(self.inplanes, planes * 
block.expansion, kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample)) 130 | self.inplanes = planes * block.expansion 131 | for _ in range(1, blocks): 132 | layers.append(block(self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | # pylint: disable=R0201 137 | def new_parameter(self, *size): 138 | out = nn.Parameter(torch.FloatTensor(*size)) 139 | nn.init.xavier_normal_(out) 140 | return out 141 | 142 | def forward(self, x, l2_norm=False): 143 | x = x.transpose(1, 2) 144 | with torch.no_grad(): 145 | with torch.cuda.amp.autocast(enabled=False): 146 | if self.log_input: 147 | x = (x + 1e-6).log() 148 | x = self.instancenorm(x).unsqueeze(1) 149 | 150 | x = self.conv1(x) 151 | x = self.relu(x) 152 | x = self.bn1(x) 153 | 154 | x = self.layer1(x) 155 | x = self.layer2(x) 156 | x = self.layer3(x) 157 | x = self.layer4(x) 158 | 159 | x = x.reshape(x.size()[0], -1, x.size()[-1]) 160 | 161 | w = self.attention(x) 162 | 163 | if self.encoder_type == "SAP": 164 | x = torch.sum(x * w, dim=2) 165 | elif self.encoder_type == "ASP": 166 | mu = torch.sum(x * w, dim=2) 167 | sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5)) 168 | x = torch.cat((mu, sg), 1) 169 | 170 | x = x.view(x.size()[0], -1) 171 | x = self.fc(x) 172 | 173 | if l2_norm: 174 | x = torch.nn.functional.normalize(x, p=2, dim=1) 175 | return x 176 | 177 | @torch.no_grad() 178 | def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): 179 | """ 180 | Generate embeddings for a batch of utterances 181 | x: 1xTxD 182 | """ 183 | max_len = x.shape[1] 184 | 185 | if max_len < num_frames: 186 | num_frames = max_len 187 | 188 | offsets = np.linspace(0, max_len - num_frames, num=num_eval) 189 | 190 | frames_batch = [] 191 | for offset in offsets: 192 | offset = int(offset) 193 | end_offset = int(offset + num_frames) 194 | frames = x[:, offset:end_offset] 195 | frames_batch.append(frames) 196 | 197 | frames_batch = torch.cat(frames_batch, dim=0) 198 | embeddings = self.forward(frames_batch, l2_norm=True) 199 | 200 | if return_mean: 201 | embeddings = torch.mean(embeddings, dim=0, keepdim=True) 202 | 203 | return embeddings 204 | 205 | def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): 206 | state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) 207 | self.load_state_dict(state["model"]) 208 | if use_cuda: 209 | self.cuda() 210 | if eval: 211 | self.eval() 212 | assert not self.training 213 | -------------------------------------------------------------------------------- /speaker_encoder/utils/generic_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import glob 3 | import os 4 | import random 5 | import re 6 | from multiprocessing import Manager 7 | 8 | import numpy as np 9 | from scipy import signal 10 | 11 | from TTS.speaker_encoder.models.lstm import LSTMSpeakerEncoder 12 | from TTS.speaker_encoder.models.resnet import ResNetSpeakerEncoder 13 | from TTS.utils.io import save_fsspec 14 | 15 | 16 | class Storage(object): 17 | def __init__(self, maxsize, storage_batchs, num_speakers_in_batch, num_threads=8): 18 | # use multiprocessing for threading safe 19 | self.storage = Manager().list() 20 | self.maxsize = maxsize 21 | self.num_speakers_in_batch = num_speakers_in_batch 22 | self.num_threads = num_threads 
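# when the storage is sized to hold at least 3 batches, the most recently stored batch is excluded from random sampling (see get_random_sample / get_random_sample_fast below)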
23 | self.ignore_last_batch = False 24 | 25 | if storage_batchs >= 3: 26 | self.ignore_last_batch = True 27 | 28 | # used for fast random sample 29 | self.safe_storage_size = self.maxsize - self.num_threads 30 | if self.ignore_last_batch: 31 | self.safe_storage_size -= self.num_speakers_in_batch 32 | 33 | def __len__(self): 34 | return len(self.storage) 35 | 36 | def full(self): 37 | return len(self.storage) >= self.maxsize 38 | 39 | def append(self, item): 40 | # if storage is full, remove an item 41 | if self.full(): 42 | self.storage.pop(0) 43 | 44 | self.storage.append(item) 45 | 46 | def get_random_sample(self): 47 | # safe storage size considering all threads remove one item from storage in same time 48 | storage_size = len(self.storage) - self.num_threads 49 | 50 | if self.ignore_last_batch: 51 | storage_size -= self.num_speakers_in_batch 52 | 53 | return self.storage[random.randint(0, storage_size)] 54 | 55 | def get_random_sample_fast(self): 56 | """Call this method only when storage is full""" 57 | return self.storage[random.randint(0, self.safe_storage_size)] 58 | 59 | 60 | class AugmentWAV(object): 61 | def __init__(self, ap, augmentation_config): 62 | 63 | self.ap = ap 64 | self.use_additive_noise = False 65 | 66 | if "additive" in augmentation_config.keys(): 67 | self.additive_noise_config = augmentation_config["additive"] 68 | additive_path = self.additive_noise_config["sounds_path"] 69 | if additive_path: 70 | self.use_additive_noise = True 71 | # get noise types 72 | self.additive_noise_types = [] 73 | for key in self.additive_noise_config.keys(): 74 | if isinstance(self.additive_noise_config[key], dict): 75 | self.additive_noise_types.append(key) 76 | 77 | additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True) 78 | 79 | self.noise_list = {} 80 | 81 | for wav_file in additive_files: 82 | noise_dir = wav_file.replace(additive_path, "").split(os.sep)[0] 83 | # ignore not listed directories 84 | if noise_dir not in self.additive_noise_types: 85 | continue 86 | if not noise_dir in self.noise_list: 87 | self.noise_list[noise_dir] = [] 88 | self.noise_list[noise_dir].append(wav_file) 89 | 90 | print( 91 | f" | > Using Additive Noise Augmentation: with {len(additive_files)} audios instances from {self.additive_noise_types}" 92 | ) 93 | 94 | self.use_rir = False 95 | 96 | if "rir" in augmentation_config.keys(): 97 | self.rir_config = augmentation_config["rir"] 98 | if self.rir_config["rir_path"]: 99 | self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True) 100 | self.use_rir = True 101 | 102 | print(f" | > Using RIR Noise Augmentation: with {len(self.rir_files)} audios instances") 103 | 104 | self.create_augmentation_global_list() 105 | 106 | def create_augmentation_global_list(self): 107 | if self.use_additive_noise: 108 | self.global_noise_list = self.additive_noise_types 109 | else: 110 | self.global_noise_list = [] 111 | if self.use_rir: 112 | self.global_noise_list.append("RIR_AUG") 113 | 114 | def additive_noise(self, noise_type, audio): 115 | 116 | clean_db = 10 * np.log10(np.mean(audio ** 2) + 1e-4) 117 | 118 | noise_list = random.sample( 119 | self.noise_list[noise_type], 120 | random.randint( 121 | self.additive_noise_config[noise_type]["min_num_noises"], 122 | self.additive_noise_config[noise_type]["max_num_noises"], 123 | ), 124 | ) 125 | 126 | audio_len = audio.shape[0] 127 | noises_wav = None 128 | for noise in noise_list: 129 | noiseaudio = self.ap.load_wav(noise, 
sr=self.ap.sample_rate)[:audio_len] 130 | 131 | if noiseaudio.shape[0] < audio_len: 132 | continue 133 | 134 | noise_snr = random.uniform( 135 | self.additive_noise_config[noise_type]["min_snr_in_db"], 136 | self.additive_noise_config[noise_type]["max_num_noises"], 137 | ) 138 | noise_db = 10 * np.log10(np.mean(noiseaudio ** 2) + 1e-4) 139 | noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio 140 | 141 | if noises_wav is None: 142 | noises_wav = noise_wav 143 | else: 144 | noises_wav += noise_wav 145 | 146 | # if all possible files is less than audio, choose other files 147 | if noises_wav is None: 148 | return self.additive_noise(noise_type, audio) 149 | 150 | return audio + noises_wav 151 | 152 | def reverberate(self, audio): 153 | audio_len = audio.shape[0] 154 | 155 | rir_file = random.choice(self.rir_files) 156 | rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate) 157 | rir = rir / np.sqrt(np.sum(rir ** 2)) 158 | return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len] 159 | 160 | def apply_one(self, audio): 161 | noise_type = random.choice(self.global_noise_list) 162 | if noise_type == "RIR_AUG": 163 | return self.reverberate(audio) 164 | 165 | return self.additive_noise(noise_type, audio) 166 | 167 | 168 | def to_camel(text): 169 | text = text.capitalize() 170 | return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) 171 | 172 | 173 | def setup_model(c): 174 | if c.model_params["model_name"].lower() == "lstm": 175 | model = LSTMSpeakerEncoder( 176 | c.model_params["input_dim"], 177 | c.model_params["proj_dim"], 178 | c.model_params["lstm_dim"], 179 | c.model_params["num_lstm_layers"], 180 | ) 181 | elif c.model_params["model_name"].lower() == "resnet": 182 | model = ResNetSpeakerEncoder(input_dim=c.model_params["input_dim"], proj_dim=c.model_params["proj_dim"]) 183 | return model 184 | 185 | 186 | def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch): 187 | checkpoint_path = "checkpoint_{}.pth.tar".format(current_step) 188 | checkpoint_path = os.path.join(out_path, checkpoint_path) 189 | print(" | | > Checkpoint saving : {}".format(checkpoint_path)) 190 | 191 | new_state_dict = model.state_dict() 192 | state = { 193 | "model": new_state_dict, 194 | "optimizer": optimizer.state_dict() if optimizer is not None else None, 195 | "criterion": criterion.state_dict(), 196 | "step": current_step, 197 | "epoch": epoch, 198 | "loss": model_loss, 199 | "date": datetime.date.today().strftime("%B %d, %Y"), 200 | } 201 | save_fsspec(state, checkpoint_path) 202 | 203 | 204 | def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step): 205 | if model_loss < best_loss: 206 | new_state_dict = model.state_dict() 207 | state = { 208 | "model": new_state_dict, 209 | "optimizer": optimizer.state_dict(), 210 | "criterion": criterion.state_dict(), 211 | "step": current_step, 212 | "loss": model_loss, 213 | "date": datetime.date.today().strftime("%B %d, %Y"), 214 | } 215 | best_loss = model_loss 216 | bestmodel_path = "best_model.pth.tar" 217 | bestmodel_path = os.path.join(out_path, bestmodel_path) 218 | print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) 219 | save_fsspec(state, bestmodel_path) 220 | return best_loss 221 | -------------------------------------------------------------------------------- /speaker_encoder/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | # adapted from https://github.com/cvqluu/GE2E-Loss 7 | class GE2ELoss(nn.Module): 8 | def __init__(self, init_w=10.0, init_b=-5.0, loss_method="softmax"): 9 | """ 10 | Implementation of the Generalized End-to-End loss defined in https://arxiv.org/abs/1710.10467 [1] 11 | Accepts an input of size (N, M, D) 12 | where N is the number of speakers in the batch, 13 | M is the number of utterances per speaker, 14 | and D is the dimensionality of the embedding vector (e.g. d-vector) 15 | Args: 16 | - init_w (float): defines the initial value of w in Equation (5) of [1] 17 | - init_b (float): definies the initial value of b in Equation (5) of [1] 18 | """ 19 | super().__init__() 20 | # pylint: disable=E1102 21 | self.w = nn.Parameter(torch.tensor(init_w)) 22 | # pylint: disable=E1102 23 | self.b = nn.Parameter(torch.tensor(init_b)) 24 | self.loss_method = loss_method 25 | 26 | print(" > Initialized Generalized End-to-End loss") 27 | 28 | assert self.loss_method in ["softmax", "contrast"] 29 | 30 | if self.loss_method == "softmax": 31 | self.embed_loss = self.embed_loss_softmax 32 | if self.loss_method == "contrast": 33 | self.embed_loss = self.embed_loss_contrast 34 | 35 | # pylint: disable=R0201 36 | def calc_new_centroids(self, dvecs, centroids, spkr, utt): 37 | """ 38 | Calculates the new centroids excluding the reference utterance 39 | """ 40 | excl = torch.cat((dvecs[spkr, :utt], dvecs[spkr, utt + 1 :])) 41 | excl = torch.mean(excl, 0) 42 | new_centroids = [] 43 | for i, centroid in enumerate(centroids): 44 | if i == spkr: 45 | new_centroids.append(excl) 46 | else: 47 | new_centroids.append(centroid) 48 | return torch.stack(new_centroids) 49 | 50 | def calc_cosine_sim(self, dvecs, centroids): 51 | """ 52 | Make the cosine similarity matrix with dims (N,M,N) 53 | """ 54 | cos_sim_matrix = [] 55 | for spkr_idx, speaker in enumerate(dvecs): 56 | cs_row = [] 57 | for utt_idx, utterance in enumerate(speaker): 58 | new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx) 59 | # vector based cosine similarity for speed 60 | cs_row.append( 61 | torch.clamp( 62 | torch.mm( 63 | utterance.unsqueeze(1).transpose(0, 1), 64 | new_centroids.transpose(0, 1), 65 | ) 66 | / (torch.norm(utterance) * torch.norm(new_centroids, dim=1)), 67 | 1e-6, 68 | ) 69 | ) 70 | cs_row = torch.cat(cs_row, dim=0) 71 | cos_sim_matrix.append(cs_row) 72 | return torch.stack(cos_sim_matrix) 73 | 74 | # pylint: disable=R0201 75 | def embed_loss_softmax(self, dvecs, cos_sim_matrix): 76 | """ 77 | Calculates the loss on each embedding $L(e_{ji})$ by taking softmax 78 | """ 79 | N, M, _ = dvecs.shape 80 | L = [] 81 | for j in range(N): 82 | L_row = [] 83 | for i in range(M): 84 | L_row.append(-F.log_softmax(cos_sim_matrix[j, i], 0)[j]) 85 | L_row = torch.stack(L_row) 86 | L.append(L_row) 87 | return torch.stack(L) 88 | 89 | # pylint: disable=R0201 90 | def embed_loss_contrast(self, dvecs, cos_sim_matrix): 91 | """ 92 | Calculates the loss on each embedding $L(e_{ji})$ by contrast loss with closest centroid 93 | """ 94 | N, M, _ = dvecs.shape 95 | L = [] 96 | for j in range(N): 97 | L_row = [] 98 | for i in range(M): 99 | centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i]) 100 | excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])) 101 | L_row.append(1.0 - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids)) 102 | L_row = torch.stack(L_row) 103 | L.append(L_row) 104 | return 
torch.stack(L) 105 | 106 | def forward(self, x, _label=None): 107 | """ 108 | Calculates the GE2E loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) 109 | """ 110 | 111 | assert x.size()[1] >= 2 112 | 113 | centroids = torch.mean(x, 1) 114 | cos_sim_matrix = self.calc_cosine_sim(x, centroids) 115 | torch.clamp(self.w, 1e-6) 116 | cos_sim_matrix = self.w * cos_sim_matrix + self.b 117 | L = self.embed_loss(x, cos_sim_matrix) 118 | return L.mean() 119 | 120 | 121 | # adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py 122 | class AngleProtoLoss(nn.Module): 123 | """ 124 | Implementation of the Angular Prototypical loss defined in https://arxiv.org/abs/2003.11982 125 | Accepts an input of size (N, M, D) 126 | where N is the number of speakers in the batch, 127 | M is the number of utterances per speaker, 128 | and D is the dimensionality of the embedding vector 129 | Args: 130 | - init_w (float): defines the initial value of w 131 | - init_b (float): definies the initial value of b 132 | """ 133 | 134 | def __init__(self, init_w=10.0, init_b=-5.0): 135 | super().__init__() 136 | # pylint: disable=E1102 137 | self.w = nn.Parameter(torch.tensor(init_w)) 138 | # pylint: disable=E1102 139 | self.b = nn.Parameter(torch.tensor(init_b)) 140 | self.criterion = torch.nn.CrossEntropyLoss() 141 | 142 | print(" > Initialized Angular Prototypical loss") 143 | 144 | def forward(self, x, _label=None): 145 | """ 146 | Calculates the AngleProto loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) 147 | """ 148 | 149 | assert x.size()[1] >= 2 150 | 151 | out_anchor = torch.mean(x[:, 1:, :], 1) 152 | out_positive = x[:, 0, :] 153 | num_speakers = out_anchor.size()[0] 154 | 155 | cos_sim_matrix = F.cosine_similarity( 156 | out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), 157 | out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2), 158 | ) 159 | torch.clamp(self.w, 1e-6) 160 | cos_sim_matrix = cos_sim_matrix * self.w + self.b 161 | label = torch.arange(num_speakers).to(cos_sim_matrix.device) 162 | L = self.criterion(cos_sim_matrix, label) 163 | return L 164 | 165 | 166 | class SoftmaxLoss(nn.Module): 167 | """ 168 | Implementation of the Softmax loss as defined in https://arxiv.org/abs/2003.11982 169 | Args: 170 | - embedding_dim (float): speaker embedding dim 171 | - n_speakers (float): number of speakers 172 | """ 173 | 174 | def __init__(self, embedding_dim, n_speakers): 175 | super().__init__() 176 | 177 | self.criterion = torch.nn.CrossEntropyLoss() 178 | self.fc = nn.Linear(embedding_dim, n_speakers) 179 | 180 | print("Initialised Softmax Loss") 181 | 182 | def forward(self, x, label=None): 183 | # reshape for compatibility 184 | x = x.reshape(-1, x.size()[-1]) 185 | label = label.reshape(-1) 186 | 187 | x = self.fc(x) 188 | L = self.criterion(x, label) 189 | 190 | return L 191 | 192 | 193 | class SoftmaxAngleProtoLoss(nn.Module): 194 | """ 195 | Implementation of the Softmax AnglePrototypical loss as defined in https://arxiv.org/abs/2009.14153 196 | Args: 197 | - embedding_dim (float): speaker embedding dim 198 | - n_speakers (float): number of speakers 199 | - init_w (float): defines the initial value of w 200 | - init_b (float): definies the initial value of b 201 | """ 202 | 203 | def __init__(self, embedding_dim, n_speakers, init_w=10.0, init_b=-5.0): 204 | super().__init__() 205 | 206 | self.softmax = SoftmaxLoss(embedding_dim, n_speakers) 207 | self.angleproto = AngleProtoLoss(init_w, 
init_b) 208 | 209 | print("Initialised SoftmaxAnglePrototypical Loss") 210 | 211 | def forward(self, x, label=None): 212 | """ 213 | Calculates the SoftmaxAnglePrototypical loss for an input of dimensions (num_speakers, num_utts_per_speaker, dvec_feats) 214 | """ 215 | 216 | Lp = self.angleproto(x) 217 | 218 | Ls = self.softmax(x, label) 219 | 220 | return Ls + Lp 221 | -------------------------------------------------------------------------------- /speaker_encoder/utils/prepare_voxceleb.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2020 ATHENA AUTHORS; Yiping Peng; Ne Luo 3 | # All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # ============================================================================== 17 | # Only support eager mode and TF>=2.0.0 18 | # pylint: disable=no-member, invalid-name, relative-beyond-top-level 19 | # pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes 20 | """ voxceleb 1 & 2 """ 21 | 22 | import hashlib 23 | import os 24 | import subprocess 25 | import sys 26 | import zipfile 27 | 28 | import pandas 29 | import soundfile as sf 30 | from absl import logging 31 | 32 | SUBSETS = { 33 | "vox1_dev_wav": [ 34 | "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa", 35 | "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab", 36 | "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac", 37 | "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad", 38 | ], 39 | "vox1_test_wav": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"], 40 | "vox2_dev_aac": [ 41 | "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa", 42 | "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab", 43 | "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac", 44 | "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad", 45 | "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae", 46 | "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf", 47 | "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag", 48 | "https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah", 49 | ], 50 | "vox2_test_aac": ["https://thor.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"], 51 | } 52 | 53 | MD5SUM = { 54 | "vox1_dev_wav": "ae63e55b951748cc486645f532ba230b", 55 | "vox2_dev_aac": "bbc063c46078a602ca71605645c2a402", 56 | "vox1_test_wav": "185fdc63c3c739954633d50379a3d102", 57 | "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312", 58 | } 59 | 60 | USER = {"user": "", "password": ""} 61 | 62 | speaker_id_dict = {} 63 | 64 | 65 | def download_and_extract(directory, subset, urls): 66 | """Download and extract the given split of dataset. 
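# Usage sketch for the loss classes defined in speaker_encoder/losses.py above
# (tensor sizes and the import path are illustrative assumptions and may need
# adjusting to this repository's layout). All of them expect embeddings grouped
# as (num_speakers, num_utterances_per_speaker, embedding_dim).
import torch
from TTS.speaker_encoder.losses import GE2ELoss, SoftmaxAngleProtoLoss

N, M, D = 4, 5, 256
dvecs = torch.randn(N, M, D)                        # d-vectors grouped by speaker
labels = torch.arange(N).unsqueeze(1).repeat(1, M)  # class id per utterance

ge2e = GE2ELoss(loss_method="softmax")
print(ge2e(dvecs))                                  # scalar loss, needs M >= 2

softmax_proto = SoftmaxAngleProtoLoss(embedding_dim=D, n_speakers=N)
print(softmax_proto(dvecs, labels))                 # cross-entropy over speakers + AngleProto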
67 | 68 | Args: 69 | directory: the directory where to put the downloaded data. 70 | subset: subset name of the corpus. 71 | urls: the list of urls to download the data file. 72 | """ 73 | os.makedirs(directory, exist_ok=True) 74 | 75 | try: 76 | for url in urls: 77 | zip_filepath = os.path.join(directory, url.split("/")[-1]) 78 | if os.path.exists(zip_filepath): 79 | continue 80 | logging.info("Downloading %s to %s" % (url, zip_filepath)) 81 | subprocess.call( 82 | "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath), 83 | shell=True, 84 | ) 85 | 86 | statinfo = os.stat(zip_filepath) 87 | logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) 88 | 89 | # concatenate all parts into zip files 90 | if ".zip" not in zip_filepath: 91 | zip_filepath = "_".join(zip_filepath.split("_")[:-1]) 92 | subprocess.call("cat %s* > %s.zip" % (zip_filepath, zip_filepath), shell=True) 93 | zip_filepath += ".zip" 94 | extract_path = zip_filepath.strip(".zip") 95 | 96 | # check zip file md5sum 97 | with open(zip_filepath, "rb") as f_zip: 98 | md5 = hashlib.md5(f_zip.read()).hexdigest() 99 | if md5 != MD5SUM[subset]: 100 | raise ValueError("md5sum of %s mismatch" % zip_filepath) 101 | 102 | with zipfile.ZipFile(zip_filepath, "r") as zfile: 103 | zfile.extractall(directory) 104 | extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename) 105 | subprocess.call("mv %s %s" % (extract_path_ori, extract_path), shell=True) 106 | finally: 107 | # os.remove(zip_filepath) 108 | pass 109 | 110 | 111 | def exec_cmd(cmd): 112 | """Run a command in a subprocess. 113 | Args: 114 | cmd: command line to be executed. 115 | Return: 116 | int, the return code. 117 | """ 118 | try: 119 | retcode = subprocess.call(cmd, shell=True) 120 | if retcode < 0: 121 | logging.info(f"Child was terminated by signal {retcode}") 122 | except OSError as e: 123 | logging.info(f"Execution failed: {e}") 124 | retcode = -999 125 | return retcode 126 | 127 | 128 | def decode_aac_with_ffmpeg(aac_file, wav_file): 129 | """Decode a given AAC file into WAV using ffmpeg. 130 | Args: 131 | aac_file: file path to input AAC file. 132 | wav_file: file path to output WAV file. 133 | Return: 134 | bool, True if success. 135 | """ 136 | cmd = f"ffmpeg -i {aac_file} {wav_file}" 137 | logging.info(f"Decoding aac file using command line: {cmd}") 138 | ret = exec_cmd(cmd) 139 | if ret != 0: 140 | logging.error(f"Failed to decode aac file with retcode {ret}") 141 | logging.error("Please check your ffmpeg installation.") 142 | return False 143 | return True 144 | 145 | 146 | def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): 147 | """Optionally convert AAC to WAV and make speaker labels. 148 | Args: 149 | input_dir: the directory which holds the input dataset. 150 | subset: the name of the specified subset. e.g. vox1_dev_wav 151 | output_dir: the directory to place the newly generated csv files. 152 | output_file: the name of the newly generated csv file. e.g. vox1_dev_wav.csv 153 | """ 154 | 155 | logging.info("Preprocessing audio and label for subset %s" % subset) 156 | source_dir = os.path.join(input_dir, subset) 157 | 158 | files = [] 159 | # Convert all AAC file into WAV format. 
At the same time, generate the csv 160 | for root, _, filenames in os.walk(source_dir): 161 | for filename in filenames: 162 | name, ext = os.path.splitext(filename) 163 | if ext.lower() == ".wav": 164 | _, ext2 = os.path.splitext(name) 165 | if ext2: 166 | continue 167 | wav_file = os.path.join(root, filename) 168 | elif ext.lower() == ".m4a": 169 | # Convert AAC to WAV. 170 | aac_file = os.path.join(root, filename) 171 | wav_file = aac_file + ".wav" 172 | if not os.path.exists(wav_file): 173 | if not decode_aac_with_ffmpeg(aac_file, wav_file): 174 | raise RuntimeError("Audio decoding failed.") 175 | else: 176 | continue 177 | speaker_name = root.split(os.path.sep)[-2] 178 | if speaker_name not in speaker_id_dict: 179 | num = len(speaker_id_dict) 180 | speaker_id_dict[speaker_name] = num 181 | # wav_filesize = os.path.getsize(wav_file) 182 | wav_length = len(sf.read(wav_file)[0]) 183 | files.append((os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)) 184 | 185 | # Write to CSV file which contains four columns: 186 | # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name". 187 | csv_file_path = os.path.join(output_dir, output_file) 188 | df = pandas.DataFrame(data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) 189 | df.to_csv(csv_file_path, index=False, sep="\t") 190 | logging.info("Successfully generated csv file {}".format(csv_file_path)) 191 | 192 | 193 | def processor(directory, subset, force_process): 194 | """download and process""" 195 | urls = SUBSETS 196 | if subset not in urls: 197 | raise ValueError(subset, "is not in voxceleb") 198 | 199 | subset_csv = os.path.join(directory, subset + ".csv") 200 | if not force_process and os.path.exists(subset_csv): 201 | return subset_csv 202 | 203 | logging.info("Downloading and process the voxceleb in %s", directory) 204 | logging.info("Preparing subset %s", subset) 205 | download_and_extract(directory, subset, urls[subset]) 206 | convert_audio_and_make_label(directory, subset, directory, subset + ".csv") 207 | logging.info("Finished downloading and processing") 208 | return subset_csv 209 | 210 | 211 | if __name__ == "__main__": 212 | logging.set_verbosity(logging.INFO) 213 | if len(sys.argv) != 4: 214 | print("Usage: python prepare_data.py save_directory user password") 215 | sys.exit() 216 | 217 | DIR, USER["user"], USER["password"] = sys.argv[1], sys.argv[2], sys.argv[3] 218 | for SUBSET in SUBSETS: 219 | processor(DIR, SUBSET, False) 220 | -------------------------------------------------------------------------------- /speaker_encoder/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | from TTS.speaker_encoder.utils.generic_utils import AugmentWAV, Storage 8 | 9 | 10 | class SpeakerEncoderDataset(Dataset): 11 | def __init__( 12 | self, 13 | ap, 14 | meta_data, 15 | voice_len=1.6, 16 | num_speakers_in_batch=64, 17 | storage_size=1, 18 | sample_from_storage_p=0.5, 19 | num_utter_per_speaker=10, 20 | skip_speakers=False, 21 | verbose=False, 22 | augmentation_config=None, 23 | ): 24 | """ 25 | Args: 26 | ap (TTS.tts.utils.AudioProcessor): audio processor object. 27 | meta_data (list): list of dataset instances. 28 | seq_len (int): voice segment length in seconds. 29 | verbose (bool): print diagnostic information. 
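# Illustrative wiring sketch (not taken verbatim from the training script): how
# this dataset is typically paired with a DataLoader. `ap` is assumed to be an
# initialized AudioProcessor and `meta_data` a list of
# (text, wav_path, speaker_name) tuples, matching load_data()/__parse_items().
from torch.utils.data import DataLoader

dataset = SpeakerEncoderDataset(
    ap,
    meta_data,
    voice_len=1.6,
    num_speakers_in_batch=64,
    num_utter_per_speaker=10,
    skip_speakers=True,
    verbose=True,
)
loader = DataLoader(
    dataset,
    batch_size=dataset.num_speakers_in_batch,
    collate_fn=dataset.collate_fn,
)
feats, labels = next(iter(loader))  # feats: [N * M, T, num_mels], labels: [N, M]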
30 | """ 31 | super().__init__() 32 | self.items = meta_data 33 | self.sample_rate = ap.sample_rate 34 | self.seq_len = int(voice_len * self.sample_rate) 35 | self.num_speakers_in_batch = num_speakers_in_batch 36 | self.num_utter_per_speaker = num_utter_per_speaker 37 | self.skip_speakers = skip_speakers 38 | self.ap = ap 39 | self.verbose = verbose 40 | self.__parse_items() 41 | storage_max_size = storage_size * num_speakers_in_batch 42 | self.storage = Storage( 43 | maxsize=storage_max_size, storage_batchs=storage_size, num_speakers_in_batch=num_speakers_in_batch 44 | ) 45 | self.sample_from_storage_p = float(sample_from_storage_p) 46 | 47 | speakers_aux = list(self.speakers) 48 | speakers_aux.sort() 49 | self.speakerid_to_classid = {key: i for i, key in enumerate(speakers_aux)} 50 | 51 | # Augmentation 52 | self.augmentator = None 53 | self.gaussian_augmentation_config = None 54 | if augmentation_config: 55 | self.data_augmentation_p = augmentation_config["p"] 56 | if self.data_augmentation_p and ("additive" in augmentation_config or "rir" in augmentation_config): 57 | self.augmentator = AugmentWAV(ap, augmentation_config) 58 | 59 | if "gaussian" in augmentation_config.keys(): 60 | self.gaussian_augmentation_config = augmentation_config["gaussian"] 61 | 62 | if self.verbose: 63 | print("\n > DataLoader initialization") 64 | print(f" | > Speakers per Batch: {num_speakers_in_batch}") 65 | print(f" | > Storage Size: {storage_max_size} instances, each with {num_utter_per_speaker} utters") 66 | print(f" | > Sample_from_storage_p : {self.sample_from_storage_p}") 67 | print(f" | > Number of instances : {len(self.items)}") 68 | print(f" | > Sequence length: {self.seq_len}") 69 | print(f" | > Num speakers: {len(self.speakers)}") 70 | 71 | def load_wav(self, filename): 72 | audio = self.ap.load_wav(filename, sr=self.ap.sample_rate) 73 | return audio 74 | 75 | def load_data(self, idx): 76 | text, wav_file, speaker_name = self.items[idx] 77 | wav = np.asarray(self.load_wav(wav_file), dtype=np.float32) 78 | mel = self.ap.melspectrogram(wav).astype("float32") 79 | # sample seq_len 80 | 81 | assert text.size > 0, self.items[idx][1] 82 | assert wav.size > 0, self.items[idx][1] 83 | 84 | sample = { 85 | "mel": mel, 86 | "item_idx": self.items[idx][1], 87 | "speaker_name": speaker_name, 88 | } 89 | return sample 90 | 91 | def __parse_items(self): 92 | self.speaker_to_utters = {} 93 | for i in self.items: 94 | path_ = i[1] 95 | speaker_ = i[2] 96 | if speaker_ in self.speaker_to_utters.keys(): 97 | self.speaker_to_utters[speaker_].append(path_) 98 | else: 99 | self.speaker_to_utters[speaker_] = [ 100 | path_, 101 | ] 102 | 103 | if self.skip_speakers: 104 | self.speaker_to_utters = { 105 | k: v for (k, v) in self.speaker_to_utters.items() if len(v) >= self.num_utter_per_speaker 106 | } 107 | 108 | self.speakers = [k for (k, v) in self.speaker_to_utters.items()] 109 | 110 | def __len__(self): 111 | return int(1e10) 112 | 113 | def get_num_speakers(self): 114 | return len(self.speakers) 115 | 116 | def __sample_speaker(self, ignore_speakers=None): 117 | speaker = random.sample(self.speakers, 1)[0] 118 | # if list of speakers_id is provide make sure that it's will be ignored 119 | if ignore_speakers and self.speakerid_to_classid[speaker] in ignore_speakers: 120 | while True: 121 | speaker = random.sample(self.speakers, 1)[0] 122 | if self.speakerid_to_classid[speaker] not in ignore_speakers: 123 | break 124 | 125 | if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]): 126 | utters = 
random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker) 127 | else: 128 | utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker) 129 | return speaker, utters 130 | 131 | def __sample_speaker_utterances(self, speaker): 132 | """ 133 | Sample all M utterances for the given speaker. 134 | """ 135 | wavs = [] 136 | labels = [] 137 | for _ in range(self.num_utter_per_speaker): 138 | # TODO:dummy but works 139 | while True: 140 | # remove speakers that have num_utter less than 2 141 | if len(self.speaker_to_utters[speaker]) > 1: 142 | utter = random.sample(self.speaker_to_utters[speaker], 1)[0] 143 | else: 144 | if speaker in self.speakers: 145 | self.speakers.remove(speaker) 146 | 147 | speaker, _ = self.__sample_speaker() 148 | continue 149 | 150 | wav = self.load_wav(utter) 151 | if wav.shape[0] - self.seq_len > 0: 152 | break 153 | 154 | if utter in self.speaker_to_utters[speaker]: 155 | self.speaker_to_utters[speaker].remove(utter) 156 | 157 | if self.augmentator is not None and self.data_augmentation_p: 158 | if random.random() < self.data_augmentation_p: 159 | wav = self.augmentator.apply_one(wav) 160 | 161 | wavs.append(wav) 162 | labels.append(self.speakerid_to_classid[speaker]) 163 | return wavs, labels 164 | 165 | def __getitem__(self, idx): 166 | speaker, _ = self.__sample_speaker() 167 | speaker_id = self.speakerid_to_classid[speaker] 168 | return speaker, speaker_id 169 | 170 | def __load_from_disk_and_storage(self, speaker): 171 | # don't sample from storage, but from HDD 172 | wavs_, labels_ = self.__sample_speaker_utterances(speaker) 173 | # put the newly loaded item into storage 174 | self.storage.append((wavs_, labels_)) 175 | return wavs_, labels_ 176 | 177 | def collate_fn(self, batch): 178 | # get the batch speaker_ids 179 | batch = np.array(batch) 180 | speakers_id_in_batch = set(batch[:, 1].astype(np.int32)) 181 | 182 | labels = [] 183 | feats = [] 184 | speakers = set() 185 | 186 | for speaker, speaker_id in batch: 187 | speaker_id = int(speaker_id) 188 | 189 | # ensure that an speaker appears only once in the batch 190 | if speaker_id in speakers: 191 | 192 | # remove current speaker 193 | if speaker_id in speakers_id_in_batch: 194 | speakers_id_in_batch.remove(speaker_id) 195 | 196 | speaker, _ = self.__sample_speaker(ignore_speakers=speakers_id_in_batch) 197 | speaker_id = self.speakerid_to_classid[speaker] 198 | speakers_id_in_batch.add(speaker_id) 199 | 200 | if random.random() < self.sample_from_storage_p and self.storage.full(): 201 | # sample from storage (if full) 202 | wavs_, labels_ = self.storage.get_random_sample_fast() 203 | 204 | # force choose the current speaker or other not in batch 205 | # It's necessary for ideal training with AngleProto and GE2E losses 206 | if labels_[0] in speakers_id_in_batch and labels_[0] != speaker_id: 207 | attempts = 0 208 | while True: 209 | wavs_, labels_ = self.storage.get_random_sample_fast() 210 | if labels_[0] == speaker_id or labels_[0] not in speakers_id_in_batch: 211 | break 212 | 213 | attempts += 1 214 | # Try 5 times after that load from disk 215 | if attempts >= 5: 216 | wavs_, labels_ = self.__load_from_disk_and_storage(speaker) 217 | break 218 | else: 219 | # don't sample from storage, but from HDD 220 | wavs_, labels_ = self.__load_from_disk_and_storage(speaker) 221 | 222 | # append speaker for control 223 | speakers.add(labels_[0]) 224 | 225 | # remove current speaker and append other 226 | if speaker_id in speakers_id_in_batch: 227 | 
speakers_id_in_batch.remove(speaker_id) 228 | 229 | speakers_id_in_batch.add(labels_[0]) 230 | 231 | # get a random subset of each of the wavs and extract mel spectrograms. 232 | feats_ = [] 233 | for wav in wavs_: 234 | offset = random.randint(0, wav.shape[0] - self.seq_len) 235 | wav = wav[offset : offset + self.seq_len] 236 | # add random gaussian noise 237 | if self.gaussian_augmentation_config and self.gaussian_augmentation_config["p"]: 238 | if random.random() < self.gaussian_augmentation_config["p"]: 239 | wav += np.random.normal( 240 | self.gaussian_augmentation_config["min_amplitude"], 241 | self.gaussian_augmentation_config["max_amplitude"], 242 | size=len(wav), 243 | ) 244 | mel = self.ap.melspectrogram(wav) 245 | feats_.append(torch.FloatTensor(mel)) 246 | 247 | labels.append(torch.LongTensor(labels_)) 248 | feats.extend(feats_) 249 | 250 | feats = torch.stack(feats) 251 | labels = torch.stack(labels) 252 | 253 | return feats.transpose(1, 2), labels 254 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /config/shared_configs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass 2 | from typing import List 3 | 4 | from coqpit import Coqpit, check_argument 5 | 6 | 7 | @dataclass 8 | class BaseAudioConfig(Coqpit): 9 | """Base config to definge audio processing parameters. 
It is used to initialize 10 | ```TTS.utils.audio.AudioProcessor.``` 11 | 12 | Args: 13 | fft_size (int): 14 | Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024. 15 | 16 | win_length (int): 17 | Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match 18 | ```fft_size```. Defaults to 1024. 19 | 20 | hop_length (int): 21 | Number of audio samples between adjacent STFT columns. Defaults to 1024. 22 | 23 | frame_shift_ms (int): 24 | Set ```hop_length``` based on milliseconds and sampling rate. 25 | 26 | frame_length_ms (int): 27 | Set ```win_length``` based on milliseconds and sampling rate. 28 | 29 | stft_pad_mode (str): 30 | Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'. 31 | 32 | sample_rate (int): 33 | Audio sampling rate. Defaults to 22050. 34 | 35 | resample (bool): 36 | Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```. 37 | 38 | preemphasis (float): 39 | Preemphasis coefficient. Defaults to 0.0. 40 | 41 | ref_level_db (int): 20 42 | Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air. 43 | Defaults to 20. 44 | 45 | do_sound_norm (bool): 46 | Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False. 47 | 48 | log_func (str): 49 | Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'. 50 | 51 | do_trim_silence (bool): 52 | Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```. 53 | 54 | do_amp_to_db_linear (bool, optional): 55 | enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. 56 | 57 | do_amp_to_db_mel (bool, optional): 58 | enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. 59 | 60 | trim_db (int): 61 | Silence threshold used for silence trimming. Defaults to 45. 62 | 63 | power (float): 64 | Exponent used for expanding spectrogra levels before running Griffin Lim. It helps to reduce the 65 | artifacts in the synthesized voice. Defaults to 1.5. 66 | 67 | griffin_lim_iters (int): 68 | Number of Griffing Lim iterations. Defaults to 60. 69 | 70 | num_mels (int): 71 | Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80. 72 | 73 | mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices. 74 | It needs to be adjusted for a dataset. Defaults to 0. 75 | 76 | mel_fmax (float): 77 | Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset. 78 | 79 | spec_gain (int): 80 | Gain applied when converting amplitude to DB. Defaults to 20. 81 | 82 | signal_norm (bool): 83 | enable/disable signal normalization. Defaults to True. 84 | 85 | min_level_db (int): 86 | minimum db threshold for the computed melspectrograms. Defaults to -100. 87 | 88 | symmetric_norm (bool): 89 | enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else 90 | [0, k], Defaults to True. 91 | 92 | max_norm (float): 93 | ```k``` defining the normalization range. Defaults to 4.0. 94 | 95 | clip_norm (bool): 96 | enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. 97 | 98 | stats_path (str): 99 | Path to the computed stats file. Defaults to None. 
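# Illustrative sketch: constructing this config with a few overrides and
# validating it (field values are arbitrary examples, not project defaults):
audio_config = BaseAudioConfig(
    sample_rate=16000,
    num_mels=80,
    fft_size=1024,
    win_length=1024,
    hop_length=256,
    do_trim_silence=True,
    trim_db=45,
)
audio_config.check_values()                 # raises if a field is out of its allowed range
print(asdict(audio_config)["sample_rate"])  # dataclass fields serialize to a plain dict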
100 | """ 101 | 102 | # stft parameters 103 | fft_size: int = 1024 104 | win_length: int = 1024 105 | hop_length: int = 256 106 | frame_shift_ms: int = None 107 | frame_length_ms: int = None 108 | stft_pad_mode: str = "reflect" 109 | # audio processing parameters 110 | sample_rate: int = 22050 111 | resample: bool = False 112 | preemphasis: float = 0.0 113 | ref_level_db: int = 20 114 | do_sound_norm: bool = False 115 | log_func: str = "np.log10" 116 | # silence trimming 117 | do_trim_silence: bool = True 118 | trim_db: int = 45 119 | # griffin-lim params 120 | power: float = 1.5 121 | griffin_lim_iters: int = 60 122 | # mel-spec params 123 | num_mels: int = 80 124 | mel_fmin: float = 0.0 125 | mel_fmax: float = None 126 | spec_gain: int = 20 127 | do_amp_to_db_linear: bool = True 128 | do_amp_to_db_mel: bool = True 129 | # normalization params 130 | signal_norm: bool = True 131 | min_level_db: int = -100 132 | symmetric_norm: bool = True 133 | max_norm: float = 4.0 134 | clip_norm: bool = True 135 | stats_path: str = None 136 | 137 | def check_values( 138 | self, 139 | ): 140 | """Check config fields""" 141 | c = asdict(self) 142 | check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056) 143 | check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058) 144 | check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000) 145 | check_argument( 146 | "frame_length_ms", 147 | c, 148 | restricted=True, 149 | min_val=10, 150 | max_val=1000, 151 | alternative="win_length", 152 | ) 153 | check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length") 154 | check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1) 155 | check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10) 156 | check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000) 157 | check_argument("power", c, restricted=True, min_val=1, max_val=5) 158 | check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000) 159 | 160 | # normalization parameters 161 | check_argument("signal_norm", c, restricted=True) 162 | check_argument("symmetric_norm", c, restricted=True) 163 | check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000) 164 | check_argument("clip_norm", c, restricted=True) 165 | check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000) 166 | check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True) 167 | check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100) 168 | check_argument("do_trim_silence", c, restricted=True) 169 | check_argument("trim_db", c, restricted=True) 170 | 171 | 172 | @dataclass 173 | class BaseDatasetConfig(Coqpit): 174 | """Base config for TTS datasets. 175 | 176 | Args: 177 | name (str): 178 | Dataset name that defines the preprocessor in use. Defaults to None. 179 | 180 | path (str): 181 | Root path to the dataset files. Defaults to None. 182 | 183 | meta_file_train (str): 184 | Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets. 185 | Defaults to None. 186 | 187 | unused_speakers (List): 188 | List of speakers IDs that are not used at the training. Default None. 189 | 190 | meta_file_val (str): 191 | Name of the dataset meta file that defines the instances used at validation. 
192 | 193 | meta_file_attn_mask (str): 194 | Path to the file that lists the attention mask files used with models that require attention masks to 195 | train the duration predictor. 196 | """ 197 | 198 | name: str = "" 199 | path: str = "" 200 | meta_file_train: str = "" 201 | ununsed_speakers: List[str] = None 202 | meta_file_val: str = "" 203 | meta_file_attn_mask: str = "" 204 | 205 | def check_values( 206 | self, 207 | ): 208 | """Check config fields""" 209 | c = asdict(self) 210 | check_argument("name", c, restricted=True) 211 | check_argument("path", c, restricted=True) 212 | check_argument("meta_file_train", c, restricted=True) 213 | check_argument("meta_file_val", c, restricted=False) 214 | check_argument("meta_file_attn_mask", c, restricted=False) 215 | 216 | 217 | @dataclass 218 | class BaseTrainingConfig(Coqpit): 219 | """Base config to define the basic training parameters that are shared 220 | among all the models. 221 | 222 | Args: 223 | model (str): 224 | Name of the model that is used in the training. 225 | 226 | run_name (str): 227 | Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`. 228 | 229 | run_description (str): 230 | Short description of the experiment. 231 | 232 | epochs (int): 233 | Number training epochs. Defaults to 10000. 234 | 235 | batch_size (int): 236 | Training batch size. 237 | 238 | eval_batch_size (int): 239 | Validation batch size. 240 | 241 | mixed_precision (bool): 242 | Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however 243 | it may also cause numerical unstability in some cases. 244 | 245 | scheduler_after_epoch (bool): 246 | If true, run the scheduler step after each epoch else run it after each model step. 247 | 248 | run_eval (bool): 249 | Enable / Disable evaluation (validation) run. Defaults to True. 250 | 251 | test_delay_epochs (int): 252 | Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful 253 | results, hence waiting for a couple of epochs might save some time. 254 | 255 | print_eval (bool): 256 | Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at 257 | the end of the evaluation. Default to ```False```. 258 | 259 | print_step (int): 260 | Number of steps required to print the next training log. 261 | 262 | log_dashboard (str): "tensorboard" or "wandb" 263 | Set the experiment tracking tool 264 | 265 | plot_step (int): 266 | Number of steps required to log training on Tensorboard. 267 | 268 | model_param_stats (bool): 269 | Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging. 270 | Defaults to ```False```. 271 | 272 | project_name (str): 273 | Name of the project. Defaults to config.model 274 | 275 | wandb_entity (str): 276 | Name of W&B entity/team. Enables collaboration across a team or org. 277 | 278 | log_model_step (int): 279 | Number of steps required to log a checkpoint as W&B artifact 280 | 281 | save_step (int):ipt 282 | Number of steps required to save the next checkpoint. 283 | 284 | checkpoint (bool): 285 | Enable / Disable checkpointing. 286 | 287 | keep_all_best (bool): 288 | Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults 289 | to ```False```. 290 | 291 | keep_after (int): 292 | Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults 293 | to 10000. 
294 | 295 | num_loader_workers (int): 296 | Number of workers for training time dataloader. 297 | 298 | num_eval_loader_workers (int): 299 | Number of workers for evaluation time dataloader. 300 | 301 | output_path (str): 302 | Path for training output folder, either a local file path or other 303 | URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or 304 | S3 (s3://) paths. The nonexist part of the given path is created 305 | automatically. All training artefacts are saved there. 306 | """ 307 | 308 | model: str = None 309 | run_name: str = "coqui_tts" 310 | run_description: str = "" 311 | # training params 312 | epochs: int = 10000 313 | batch_size: int = None 314 | eval_batch_size: int = None 315 | mixed_precision: bool = False 316 | scheduler_after_epoch: bool = False 317 | # eval params 318 | run_eval: bool = True 319 | test_delay_epochs: int = 0 320 | print_eval: bool = False 321 | # logging 322 | dashboard_logger: str = "tensorboard" 323 | print_step: int = 25 324 | plot_step: int = 100 325 | model_param_stats: bool = False 326 | project_name: str = None 327 | log_model_step: int = None 328 | wandb_entity: str = None 329 | # checkpointing 330 | save_step: int = 10000 331 | checkpoint: bool = True 332 | keep_all_best: bool = False 333 | keep_after: int = 10000 334 | # dataloading 335 | num_loader_workers: int = 0 336 | num_eval_loader_workers: int = 0 337 | use_noise_augment: bool = False 338 | # paths 339 | output_path: str = None 340 | # distributed 341 | distributed_backend: str = "nccl" 342 | distributed_url: str = "tcp://localhost:54321" 343 | -------------------------------------------------------------------------------- /utils/audio.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | import librosa 4 | import numpy as np 5 | import pyworld as pw 6 | import scipy.io.wavfile 7 | import scipy.signal 8 | import soundfile as sf 9 | import torch 10 | from torch import nn 11 | 12 | class StandardScaler: 13 | """StandardScaler for mean-scale normalization with the given mean and scale values.""" 14 | 15 | def __init__(self, mean: np.ndarray = None, scale: np.ndarray = None) -> None: 16 | self.mean_ = mean 17 | self.scale_ = scale 18 | 19 | def set_stats(self, mean, scale): 20 | self.mean_ = mean 21 | self.scale_ = scale 22 | 23 | def reset_stats(self): 24 | delattr(self, "mean_") 25 | delattr(self, "scale_") 26 | 27 | def transform(self, X): 28 | X = np.asarray(X) 29 | X -= self.mean_ 30 | X /= self.scale_ 31 | return X 32 | 33 | def inverse_transform(self, X): 34 | X = np.asarray(X) 35 | X *= self.scale_ 36 | X += self.mean_ 37 | return X 38 | 39 | class TorchSTFT(nn.Module): # pylint: disable=abstract-method 40 | """Some of the audio processing funtions using Torch for faster batch processing. 
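# Illustrative usage sketch (parameter values are examples): batched mel
# extraction from a tensor of waveforms. Note that _build_mel_basis() passes
# sample_rate/n_fft to librosa.filters.mel positionally, which requires
# librosa < 0.10; newer versions expect keyword arguments.
import torch

stft = TorchSTFT(
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    sample_rate=16000,
    n_mels=80,
    use_mel=True,
    do_amp_to_db=True,
)
wavs = torch.randn(8, 16000)  # [batch, samples]
mels = stft(wavs)             # [batch, n_mels, frames]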
41 | 42 | TODO: Merge this with audio.py 43 | """ 44 | 45 | def __init__( 46 | self, 47 | n_fft, 48 | hop_length, 49 | win_length, 50 | pad_wav=False, 51 | window="hann_window", 52 | sample_rate=None, 53 | mel_fmin=0, 54 | mel_fmax=None, 55 | n_mels=80, 56 | use_mel=False, 57 | do_amp_to_db=False, 58 | spec_gain=1.0, 59 | ): 60 | super().__init__() 61 | self.n_fft = n_fft 62 | self.hop_length = hop_length 63 | self.win_length = win_length 64 | self.pad_wav = pad_wav 65 | self.sample_rate = sample_rate 66 | self.mel_fmin = mel_fmin 67 | self.mel_fmax = mel_fmax 68 | self.n_mels = n_mels 69 | self.use_mel = use_mel 70 | self.do_amp_to_db = do_amp_to_db 71 | self.spec_gain = spec_gain 72 | self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) 73 | self.mel_basis = None 74 | if use_mel: 75 | self._build_mel_basis() 76 | 77 | def __call__(self, x): 78 | """Compute spectrogram frames by torch based stft. 79 | 80 | Args: 81 | x (Tensor): input waveform 82 | 83 | Returns: 84 | Tensor: spectrogram frames. 85 | 86 | Shapes: 87 | x: [B x T] or [:math:`[B, 1, T]`] 88 | """ 89 | if x.ndim == 2: 90 | x = x.unsqueeze(1) 91 | if self.pad_wav: 92 | padding = int((self.n_fft - self.hop_length) / 2) 93 | x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") 94 | # B x D x T x 2 95 | o = torch.stft( 96 | x.squeeze(1), 97 | self.n_fft, 98 | self.hop_length, 99 | self.win_length, 100 | self.window, 101 | center=True, 102 | pad_mode="reflect", # compatible with audio.py 103 | normalized=False, 104 | onesided=True, 105 | return_complex=False, 106 | ) 107 | M = o[:, :, :, 0] 108 | P = o[:, :, :, 1] 109 | S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8)) 110 | if self.use_mel: 111 | S = torch.matmul(self.mel_basis.to(x), S) 112 | if self.do_amp_to_db: 113 | S = self._amp_to_db(S, spec_gain=self.spec_gain) 114 | return S 115 | 116 | def _build_mel_basis(self): 117 | mel_basis = librosa.filters.mel( 118 | self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax 119 | ) 120 | self.mel_basis = torch.from_numpy(mel_basis).float() 121 | 122 | @staticmethod 123 | def _amp_to_db(x, spec_gain=1.0): 124 | return torch.log(torch.clamp(x, min=1e-5) * spec_gain) 125 | 126 | @staticmethod 127 | def _db_to_amp(x, spec_gain=1.0): 128 | return torch.exp(x) / spec_gain 129 | 130 | 131 | # pylint: disable=too-many-public-methods 132 | class AudioProcessor(object): 133 | """Audio Processor for TTS used by all the data pipelines. 134 | 135 | Note: 136 | All the class arguments are set to default values to enable a flexible initialization 137 | of the class with the model config. They are not meaningful for all the arguments. 138 | 139 | Args: 140 | sample_rate (int, optional): 141 | target audio sampling rate. Defaults to None. 142 | 143 | resample (bool, optional): 144 | enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False. 145 | 146 | num_mels (int, optional): 147 | number of melspectrogram dimensions. Defaults to None. 148 | 149 | log_func (int, optional): 150 | log exponent used for converting spectrogram aplitude to DB. 151 | 152 | min_level_db (int, optional): 153 | minimum db threshold for the computed melspectrograms. Defaults to None. 154 | 155 | frame_shift_ms (int, optional): 156 | milliseconds of frames between STFT columns. Defaults to None. 157 | 158 | frame_length_ms (int, optional): 159 | milliseconds of STFT window length. Defaults to None. 
160 | 161 | hop_length (int, optional): 162 | number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None. 163 | 164 | win_length (int, optional): 165 | STFT window length. Used if ```frame_length_ms``` is None. Defaults to None. 166 | 167 | ref_level_db (int, optional): 168 | reference DB level to avoid background noise. In general <20DB corresponds to the air noise. Defaults to None. 169 | 170 | fft_size (int, optional): 171 | FFT window size for STFT. Defaults to 1024. 172 | 173 | power (int, optional): 174 | Exponent value applied to the spectrogram before GriffinLim. Defaults to None. 175 | 176 | preemphasis (float, optional): 177 | Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0. 178 | 179 | signal_norm (bool, optional): 180 | enable/disable signal normalization. Defaults to None. 181 | 182 | symmetric_norm (bool, optional): 183 | enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None. 184 | 185 | max_norm (float, optional): 186 | ```k``` defining the normalization range. Defaults to None. 187 | 188 | mel_fmin (int, optional): 189 | minimum filter frequency for computing melspectrograms. Defaults to None. 190 | 191 | mel_fmax (int, optional): 192 | maximum filter frequency for computing melspectrograms.. Defaults to None. 193 | 194 | spec_gain (int, optional): 195 | gain applied when converting amplitude to DB. Defaults to 20. 196 | 197 | stft_pad_mode (str, optional): 198 | Padding mode for STFT. Defaults to 'reflect'. 199 | 200 | clip_norm (bool, optional): 201 | enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. 202 | 203 | griffin_lim_iters (int, optional): 204 | Number of GriffinLim iterations. Defaults to None. 205 | 206 | do_trim_silence (bool, optional): 207 | enable/disable silence trimming when loading the audio signal. Defaults to False. 208 | 209 | trim_db (int, optional): 210 | DB threshold used for silence trimming. Defaults to 60. 211 | 212 | do_sound_norm (bool, optional): 213 | enable/disable signal normalization. Defaults to False. 214 | 215 | do_amp_to_db_linear (bool, optional): 216 | enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. 217 | 218 | do_amp_to_db_mel (bool, optional): 219 | enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. 220 | 221 | stats_path (str, optional): 222 | Path to the computed stats file. Defaults to None. 223 | 224 | verbose (bool, optional): 225 | enable/disable logging. Defaults to True. 
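# Illustrative sketch: a minimal set of arguments for feature extraction
# (values are examples; the wav path is a placeholder):
ap = AudioProcessor(
    sample_rate=16000,
    num_mels=80,
    fft_size=1024,
    win_length=1024,
    hop_length=256,
    min_level_db=-100,
    ref_level_db=20,
    power=1.5,
    griffin_lim_iters=60,
    signal_norm=True,
    symmetric_norm=True,
    max_norm=4.0,
    mel_fmin=0.0,
    mel_fmax=None,
    spec_gain=20,
)
wav = ap.load_wav("path/to/utterance.wav")  # placeholder path
mel = ap.melspectrogram(wav)                # [num_mels, frames], normalized float32
linear = ap.spectrogram(wav)                # [fft_size // 2 + 1, frames]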
226 | 227 | """ 228 | 229 | def __init__( 230 | self, 231 | sample_rate=None, 232 | resample=False, 233 | num_mels=None, 234 | log_func="np.log10", 235 | min_level_db=None, 236 | frame_shift_ms=None, 237 | frame_length_ms=None, 238 | hop_length=None, 239 | win_length=None, 240 | ref_level_db=None, 241 | fft_size=1024, 242 | power=None, 243 | preemphasis=0.0, 244 | signal_norm=None, 245 | symmetric_norm=None, 246 | max_norm=None, 247 | mel_fmin=None, 248 | mel_fmax=None, 249 | spec_gain=20, 250 | stft_pad_mode="reflect", 251 | clip_norm=True, 252 | griffin_lim_iters=None, 253 | do_trim_silence=False, 254 | trim_db=60, 255 | do_sound_norm=False, 256 | do_amp_to_db_linear=True, 257 | do_amp_to_db_mel=True, 258 | stats_path=None, 259 | verbose=True, 260 | **_, 261 | ): 262 | 263 | # setup class attributed 264 | self.sample_rate = sample_rate 265 | self.resample = resample 266 | self.num_mels = num_mels 267 | self.log_func = log_func 268 | self.min_level_db = min_level_db or 0 269 | self.frame_shift_ms = frame_shift_ms 270 | self.frame_length_ms = frame_length_ms 271 | self.ref_level_db = ref_level_db 272 | self.fft_size = fft_size 273 | self.power = power 274 | self.preemphasis = preemphasis 275 | self.griffin_lim_iters = griffin_lim_iters 276 | self.signal_norm = signal_norm 277 | self.symmetric_norm = symmetric_norm 278 | self.mel_fmin = mel_fmin or 0 279 | self.mel_fmax = mel_fmax 280 | self.spec_gain = float(spec_gain) 281 | self.stft_pad_mode = stft_pad_mode 282 | self.max_norm = 1.0 if max_norm is None else float(max_norm) 283 | self.clip_norm = clip_norm 284 | self.do_trim_silence = do_trim_silence 285 | self.trim_db = trim_db 286 | self.do_sound_norm = do_sound_norm 287 | self.do_amp_to_db_linear = do_amp_to_db_linear 288 | self.do_amp_to_db_mel = do_amp_to_db_mel 289 | self.stats_path = stats_path 290 | # setup exp_func for db to amp conversion 291 | if log_func == "np.log": 292 | self.base = np.e 293 | elif log_func == "np.log10": 294 | self.base = 10 295 | else: 296 | raise ValueError(" [!] unknown `log_func` value.") 297 | # setup stft parameters 298 | if hop_length is None: 299 | # compute stft parameters from given time values 300 | self.hop_length, self.win_length = self._stft_parameters() 301 | else: 302 | # use stft parameters from config file 303 | self.hop_length = hop_length 304 | self.win_length = win_length 305 | assert min_level_db != 0.0, " [!] min_level_db is 0" 306 | assert self.win_length <= self.fft_size, " [!] win_length cannot be larger than fft_size" 307 | members = vars(self) 308 | if verbose: 309 | print(" > Setting up Audio Processor...") 310 | for key, value in members.items(): 311 | print(" | > {}:{}".format(key, value)) 312 | # create spectrogram utils 313 | self.mel_basis = self._build_mel_basis() 314 | self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis()) 315 | # setup scaler 316 | if stats_path and signal_norm: 317 | mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) 318 | self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std) 319 | self.signal_norm = True 320 | self.max_norm = None 321 | self.clip_norm = None 322 | self.symmetric_norm = None 323 | 324 | ### setting up the parameters ### 325 | def _build_mel_basis( 326 | self, 327 | ) -> np.ndarray: 328 | """Build melspectrogram basis. 329 | 330 | Returns: 331 | np.ndarray: melspectrogram basis. 
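# Shape sketch (example values): the filterbank maps linear-spectrogram bins to
# mel bands, and its pseudo-inverse (self.inv_mel_basis above) maps back for
# inv_melspectrogram(). Keyword form shown; the call below passes these positionally.
import librosa
import numpy as np

basis = librosa.filters.mel(sr=16000, n_fft=1024, n_mels=80)
print(basis.shape)                 # (80, 513) == (num_mels, fft_size // 2 + 1)
inv_basis = np.linalg.pinv(basis)  # (513, 80)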
332 | """ 333 | if self.mel_fmax is not None: 334 | assert self.mel_fmax <= self.sample_rate // 2 335 | return librosa.filters.mel( 336 | self.sample_rate, self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax 337 | ) 338 | 339 | def _stft_parameters( 340 | self, 341 | ) -> Tuple[int, int]: 342 | """Compute the real STFT parameters from the time values. 343 | 344 | Returns: 345 | Tuple[int, int]: hop length and window length for STFT. 346 | """ 347 | factor = self.frame_length_ms / self.frame_shift_ms 348 | assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" 349 | hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) 350 | win_length = int(hop_length * factor) 351 | return hop_length, win_length 352 | 353 | ### normalization ### 354 | def normalize(self, S: np.ndarray) -> np.ndarray: 355 | """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]` 356 | 357 | Args: 358 | S (np.ndarray): Spectrogram to normalize. 359 | 360 | Raises: 361 | RuntimeError: Mean and variance is computed from incompatible parameters. 362 | 363 | Returns: 364 | np.ndarray: Normalized spectrogram. 365 | """ 366 | # pylint: disable=no-else-return 367 | S = S.copy() 368 | if self.signal_norm: 369 | # mean-var scaling 370 | if hasattr(self, "mel_scaler"): 371 | if S.shape[0] == self.num_mels: 372 | return self.mel_scaler.transform(S.T).T 373 | elif S.shape[0] == self.fft_size / 2: 374 | return self.linear_scaler.transform(S.T).T 375 | else: 376 | raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") 377 | # range normalization 378 | S -= self.ref_level_db # discard certain range of DB assuming it is air noise 379 | S_norm = (S - self.min_level_db) / (-self.min_level_db) 380 | if self.symmetric_norm: 381 | S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm 382 | if self.clip_norm: 383 | S_norm = np.clip( 384 | S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type 385 | ) 386 | return S_norm 387 | else: 388 | S_norm = self.max_norm * S_norm 389 | if self.clip_norm: 390 | S_norm = np.clip(S_norm, 0, self.max_norm) 391 | return S_norm 392 | else: 393 | return S 394 | 395 | def denormalize(self, S: np.ndarray) -> np.ndarray: 396 | """Denormalize spectrogram values. 397 | 398 | Args: 399 | S (np.ndarray): Spectrogram to denormalize. 400 | 401 | Raises: 402 | RuntimeError: Mean and variance are incompatible. 403 | 404 | Returns: 405 | np.ndarray: Denormalized spectrogram. 406 | """ 407 | # pylint: disable=no-else-return 408 | S_denorm = S.copy() 409 | if self.signal_norm: 410 | # mean-var scaling 411 | if hasattr(self, "mel_scaler"): 412 | if S_denorm.shape[0] == self.num_mels: 413 | return self.mel_scaler.inverse_transform(S_denorm.T).T 414 | elif S_denorm.shape[0] == self.fft_size / 2: 415 | return self.linear_scaler.inverse_transform(S_denorm.T).T 416 | else: 417 | raise RuntimeError(" [!] 
Mean-Var stats does not match the given feature dimensions.") 418 | if self.symmetric_norm: 419 | if self.clip_norm: 420 | S_denorm = np.clip( 421 | S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type 422 | ) 423 | S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db 424 | return S_denorm + self.ref_level_db 425 | else: 426 | if self.clip_norm: 427 | S_denorm = np.clip(S_denorm, 0, self.max_norm) 428 | S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db 429 | return S_denorm + self.ref_level_db 430 | else: 431 | return S_denorm 432 | 433 | ### Mean-STD scaling ### 434 | def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]: 435 | """Load mean and variance statistics from a `npy` file. 436 | 437 | Args: 438 | stats_path (str): Path to the `npy` file containing the mean and variance statistics. 439 | 440 | Returns: 441 | Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to 442 | compute them. 443 | """ 444 | stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg 445 | mel_mean = stats["mel_mean"] 446 | mel_std = stats["mel_std"] 447 | linear_mean = stats["linear_mean"] 448 | linear_std = stats["linear_std"] 449 | stats_config = stats["audio_config"] 450 | # check all audio parameters used for computing stats 451 | skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"] 452 | for key in stats_config.keys(): 453 | if key in skip_parameters: 454 | continue 455 | if key not in ["sample_rate", "trim_db"]: 456 | assert ( 457 | stats_config[key] == self.__dict__[key] 458 | ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" 459 | return mel_mean, mel_std, linear_mean, linear_std, stats_config 460 | 461 | # pylint: disable=attribute-defined-outside-init 462 | def setup_scaler( 463 | self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray 464 | ) -> None: 465 | """Initialize scaler objects used in mean-std normalization. 466 | 467 | Args: 468 | mel_mean (np.ndarray): Mean for melspectrograms. 469 | mel_std (np.ndarray): STD for melspectrograms. 470 | linear_mean (np.ndarray): Mean for full scale spectrograms. 471 | linear_std (np.ndarray): STD for full scale spectrograms. 472 | """ 473 | self.mel_scaler = StandardScaler() 474 | self.mel_scaler.set_stats(mel_mean, mel_std) 475 | self.linear_scaler = StandardScaler() 476 | self.linear_scaler.set_stats(linear_mean, linear_std) 477 | 478 | ### DB and AMP conversion ### 479 | # pylint: disable=no-self-use 480 | def _amp_to_db(self, x: np.ndarray) -> np.ndarray: 481 | """Convert amplitude values to decibels. 482 | 483 | Args: 484 | x (np.ndarray): Amplitude spectrogram. 485 | 486 | Returns: 487 | np.ndarray: Decibels spectrogram. 488 | """ 489 | return self.spec_gain * _log(np.maximum(1e-5, x), self.base) 490 | 491 | # pylint: disable=no-self-use 492 | def _db_to_amp(self, x: np.ndarray) -> np.ndarray: 493 | """Convert decibels spectrogram to amplitude spectrogram. 494 | 495 | Args: 496 | x (np.ndarray): Decibels spectrogram. 497 | 498 | Returns: 499 | np.ndarray: Amplitude spectrogram. 500 | """ 501 | return _exp(x / self.spec_gain, self.base) 502 | 503 | ### Preemphasis ### 504 | def apply_preemphasis(self, x: np.ndarray) -> np.ndarray: 505 | """Apply pre-emphasis to the audio signal. 
Useful to reduce the correlation between neighbouring signal values. 506 | 507 | Args: 508 | x (np.ndarray): Audio signal. 509 | 510 | Raises: 511 | RuntimeError: Preemphasis coeff is set to 0. 512 | 513 | Returns: 514 | np.ndarray: Decorrelated audio signal. 515 | """ 516 | if self.preemphasis == 0: 517 | raise RuntimeError(" [!] Preemphasis is set to 0.0.") 518 | return scipy.signal.lfilter([1, -self.preemphasis], [1], x) 519 | 520 | def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray: 521 | """Reverse pre-emphasis.""" 522 | if self.preemphasis == 0: 523 | raise RuntimeError(" [!] Preemphasis is set to 0.0.") 524 | return scipy.signal.lfilter([1], [1, -self.preemphasis], x) 525 | 526 | ### SPECTROGRAMs ### 527 | def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray: 528 | """Project a full scale spectrogram to a melspectrogram. 529 | 530 | Args: 531 | spectrogram (np.ndarray): Full scale spectrogram. 532 | 533 | Returns: 534 | np.ndarray: Melspectrogram. 535 | """ 536 | return np.dot(self.mel_basis, spectrogram) 537 | 538 | def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray: 539 | """Convert a melspectrogram to full scale spectrogram.""" 540 | return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec)) 541 | 542 | def spectrogram(self, y: np.ndarray) -> np.ndarray: 543 | """Compute a spectrogram from a waveform. 544 | 545 | Args: 546 | y (np.ndarray): Waveform. 547 | 548 | Returns: 549 | np.ndarray: Spectrogram. 550 | """ 551 | if self.preemphasis != 0: 552 | D = self._stft(self.apply_preemphasis(y)) 553 | else: 554 | D = self._stft(y) 555 | if self.do_amp_to_db_linear: 556 | S = self._amp_to_db(np.abs(D)) 557 | else: 558 | S = np.abs(D) 559 | return self.normalize(S).astype(np.float32) 560 | 561 | def melspectrogram(self, y: np.ndarray) -> np.ndarray: 562 | """Compute a melspectrogram from a waveform.""" 563 | if self.preemphasis != 0: 564 | D = self._stft(self.apply_preemphasis(y)) 565 | else: 566 | D = self._stft(y) 567 | if self.do_amp_to_db_mel: 568 | S = self._amp_to_db(self._linear_to_mel(np.abs(D))) 569 | else: 570 | S = self._linear_to_mel(np.abs(D)) 571 | return self.normalize(S).astype(np.float32) 572 | 573 | def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: 574 | """Convert a spectrogram to a waveform using the Griffin-Lim vocoder.""" 575 | S = self.denormalize(spectrogram) 576 | S = self._db_to_amp(S) 577 | # Reconstruct phase 578 | if self.preemphasis != 0: 579 | return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) 580 | return self._griffin_lim(S ** self.power) 581 | 582 | def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: 583 | """Convert a melspectrogram to a waveform using the Griffin-Lim vocoder.""" 584 | D = self.denormalize(mel_spectrogram) 585 | S = self._db_to_amp(D) 586 | S = self._mel_to_linear(S) # Convert back to linear 587 | if self.preemphasis != 0: 588 | return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) 589 | return self._griffin_lim(S ** self.power) 590 | 591 | def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: 592 | """Convert a full scale linear spectrogram output of a network to a melspectrogram. 593 | 594 | Args: 595 | linear_spec (np.ndarray): Normalized full scale linear spectrogram. 596 | 597 | Returns: 598 | np.ndarray: Normalized melspectrogram. 
599 | """ 600 | S = self.denormalize(linear_spec) 601 | S = self._db_to_amp(S) 602 | S = self._linear_to_mel(np.abs(S)) 603 | S = self._amp_to_db(S) 604 | mel = self.normalize(S) 605 | return mel 606 | 607 | ### STFT and ISTFT ### 608 | def _stft(self, y: np.ndarray) -> np.ndarray: 609 | """Librosa STFT wrapper. 610 | 611 | Args: 612 | y (np.ndarray): Audio signal. 613 | 614 | Returns: 615 | np.ndarray: Complex number array. 616 | """ 617 | return librosa.stft( 618 | y=y, 619 | n_fft=self.fft_size, 620 | hop_length=self.hop_length, 621 | win_length=self.win_length, 622 | pad_mode=self.stft_pad_mode, 623 | window="hann", 624 | center=True, 625 | ) 626 | 627 | def _istft(self, y: np.ndarray) -> np.ndarray: 628 | """Librosa iSTFT wrapper.""" 629 | return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length) 630 | 631 | def _griffin_lim(self, S): 632 | angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) 633 | S_complex = np.abs(S).astype(complex)  # `np.complex` is deprecated; use the builtin `complex` 634 | y = self._istft(S_complex * angles) 635 | if not np.isfinite(y).all(): 636 | print(" [!] Waveform is not finite everywhere. Skipping the GL.") 637 | return np.array([0.0]) 638 | for _ in range(self.griffin_lim_iters): 639 | angles = np.exp(1j * np.angle(self._stft(y))) 640 | y = self._istft(S_complex * angles) 641 | return y 642 | 643 | def compute_stft_paddings(self, x, pad_sides=1): 644 | """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding 645 | (first and final frames)""" 646 | assert pad_sides in (1, 2) 647 | pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0] 648 | if pad_sides == 1: 649 | return 0, pad 650 | return pad // 2, pad // 2 + pad % 2 651 | 652 | def compute_f0(self, x: np.ndarray) -> np.ndarray: 653 | """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. 654 | 655 | Args: 656 | x (np.ndarray): Waveform. 657 | 658 | Returns: 659 | np.ndarray: Pitch. 660 | 661 | Examples: 662 | >>> WAV_FILE = filename = librosa.util.example_audio_file() 663 | >>> from TTS.config import BaseAudioConfig 664 | >>> from TTS.utils.audio import AudioProcessor 665 | >>> conf = BaseAudioConfig(mel_fmax=8000) 666 | >>> ap = AudioProcessor(**conf) 667 | >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050] 668 | >>> pitch = ap.compute_f0(wav) 669 | """ 670 | f0, t = pw.dio( 671 | x.astype(np.double), 672 | fs=self.sample_rate, 673 | f0_ceil=self.mel_fmax, 674 | frame_period=1000 * self.hop_length / self.sample_rate, 675 | ) 676 | f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) 677 | # pad = int((self.win_length / self.hop_length) / 2) 678 | # f0 = [0.0] * pad + f0 + [0.0] * pad 679 | # f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0) 680 | # f0 = np.array(f0, dtype=np.float32) 681 | 682 | # f01, _, _ = librosa.pyin( 683 | # x, 684 | # fmin=65 if self.mel_fmin == 0 else self.mel_fmin, 685 | # fmax=self.mel_fmax, 686 | # frame_length=self.win_length, 687 | # sr=self.sample_rate, 688 | # fill_na=0.0, 689 | # ) 690 | 691 | # spec = self.melspectrogram(x) 692 | return f0 693 | 694 | ### Audio Processing ### 695 | def find_endpoint(self, wav: np.ndarray, threshold_db=-40, min_silence_sec=0.8) -> int: 696 | """Find the last point without silence at the end of an audio signal. 697 | 698 | Args: 699 | wav (np.ndarray): Audio signal. 700 | threshold_db (int, optional): Silence threshold in decibels. Defaults to -40. 
701 | min_silence_sec (float, optional): Ignore silences that are shorter than this in secs. Defaults to 0.8. 702 | 703 | Returns: 704 | int: Last point without silence. 705 | """ 706 | window_length = int(self.sample_rate * min_silence_sec) 707 | hop_length = int(window_length / 4) 708 | threshold = self._db_to_amp(threshold_db) 709 | for x in range(hop_length, len(wav) - window_length, hop_length): 710 | if np.max(wav[x : x + window_length]) < threshold: 711 | return x + hop_length 712 | return len(wav) 713 | 714 | def trim_silence(self, wav): 715 | """Trim silent parts with a threshold and a 0.01 sec margin.""" 716 | margin = int(self.sample_rate * 0.01) 717 | wav = wav[margin:-margin] 718 | return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[ 719 | 0 720 | ] 721 | 722 | @staticmethod 723 | def sound_norm(x: np.ndarray) -> np.ndarray: 724 | """Normalize the volume of an audio signal. 725 | 726 | Args: 727 | x (np.ndarray): Raw waveform. 728 | 729 | Returns: 730 | np.ndarray: Volume normalized waveform. 731 | """ 732 | return x / abs(x).max() * 0.95 733 | 734 | ### save and load ### 735 | def load_wav(self, filename: str, sr: int = None) -> np.ndarray: 736 | """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. 737 | 738 | Args: 739 | filename (str): Path to the wav file. 740 | sr (int, optional): Sampling rate for resampling. Defaults to None. 741 | 742 | Returns: 743 | np.ndarray: Loaded waveform. 744 | """ 745 | if self.resample: 746 | x, sr = librosa.load(filename, sr=self.sample_rate) 747 | elif sr is None: 748 | x, sr = sf.read(filename) 749 | assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr) 750 | else: 751 | x, sr = librosa.load(filename, sr=sr) 752 | if self.do_trim_silence: 753 | try: 754 | x = self.trim_silence(x) 755 | except ValueError: 756 | print(f" [!] File cannot be trimmed for silence - {filename}") 757 | if self.do_sound_norm: 758 | x = self.sound_norm(x) 759 | return x 760 | 761 | def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None: 762 | """Save a waveform to a file using Scipy. 763 | 764 | Args: 765 | wav (np.ndarray): Waveform to save. 766 | path (str): Path to an output file. 767 | sr (int, optional): Sampling rate used for saving to the file. Defaults to None. 768 | """ 769 | wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) 770 | scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm.astype(np.int16)) 771 | 772 | @staticmethod 773 | def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray: 774 | mu = 2 ** qc - 1 775 | # wav_abs = np.minimum(np.abs(wav), 1.0) 776 | signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) 777 | # Quantize signal to the specified number of levels. 778 | signal = (signal + 1) / 2 * mu + 0.5 779 | return np.floor( 780 | signal, 781 | ) 782 | 783 | @staticmethod 784 | def mulaw_decode(wav, qc): 785 | """Recover a waveform from quantized values.""" 786 | mu = 2 ** qc - 1 787 | x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) 788 | return x 789 | 790 | @staticmethod 791 | def encode_16bits(x): 792 | return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16) 793 | 794 | @staticmethod 795 | def quantize(x: np.ndarray, bits: int) -> np.ndarray: 796 | """Quantize a waveform to a given number of bits. 797 | 798 | Args: 799 | x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`. 800 | bits (int): Number of quantization bits. 
801 | 802 | Returns: 803 | np.ndarray: Quantized waveform. 804 | """ 805 | return (x + 1.0) * (2 ** bits - 1) / 2 806 | 807 | @staticmethod 808 | def dequantize(x, bits): 809 | """Dequantize a waveform from the given number of bits.""" 810 | return 2 * x / (2 ** bits - 1) - 1 811 | 812 | 813 | def _log(x, base): 814 | if base == 10: 815 | return np.log10(x) 816 | return np.log(x) 817 | 818 | 819 | def _exp(x, base): 820 | if base == 10: 821 | return np.power(10, x) 822 | return np.exp(x) 823 | --------------------------------------------------------------------------------
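A minimal usage sketch for the processor defined in utils/audio.py above (not part of the repository): it assumes the class is named AudioProcessor, as the compute_f0 docstring example suggests, and the parameter values below are illustrative assumptions rather than settings taken from this repository's configs.

# Hypothetical usage sketch -- illustrative only, not part of the repo.
from utils.audio import AudioProcessor  # assumed import path for the class above

ap = AudioProcessor(
    sample_rate=22050,      # illustrative values, not taken from this repo's configs
    num_mels=80,
    fft_size=1024,
    hop_length=256,
    win_length=1024,
    mel_fmin=0,
    mel_fmax=8000,
    power=1.5,
    griffin_lim_iters=30,
    signal_norm=False,      # skip range / mean-var normalization in this sketch
    verbose=False,
)

wav = ap.load_wav("example.wav", sr=22050)   # hypothetical input file
mel = ap.melspectrogram(wav)                 # (num_mels, n_frames) float32 array
wav_hat = ap.inv_melspectrogram(mel)         # Griffin-Lim reconstruction
ap.save_wav(wav_hat, "reconstructed.wav")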