├── LICENSE ├── README.md ├── Results └── Loss.png ├── VAD_segments.py ├── config └── config.yaml ├── data_load.py ├── data_preprocess.py ├── dvector_create.py ├── hparam.py ├── speech_embedder_net.py ├── train_speech_embedder.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2019, HarryVolek 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch_Speaker_Verification 2 | 3 | PyTorch implementation of the speech embedding net and loss described here: https://arxiv.org/pdf/1710.10467.pdf. 4 | 5 | Also contains code to create embeddings compatible as input for the speaker diarization model found at https://github.com/google/uis-rnn 6 | 7 | ![training loss](https://github.com/HarryVolek/PyTorch_Speaker_Verification/blob/master/Results/Loss.png) 8 | 9 | The TIMIT speech corpus was used to train the model, found here: https://catalog.ldc.upenn.edu/LDC93S1, 10 | or here, https://github.com/philipperemy/timit 11 | 12 | # Dependencies 13 | 14 | * PyTorch 0.4.1 15 | * python 3.5+ 16 | * numpy 1.15.4 17 | * librosa 0.6.1 18 | 19 | The python WebRTC VAD found at https://github.com/wiseman/py-webrtcvad is required to run dvector_create.py, but not to train the neural network. 20 | 21 | # Preprocessing 22 | 23 | Change the following config.yaml key to a glob pattern matching all .WAV files in your downloaded TIMIT dataset. The TIMIT .WAV files must be converted to the standard format (RIFF) for the dvector_create.py script, but not for training the neural network.
24 | ```yaml 25 | unprocessed_data: './TIMIT/*/*/*/*.wav' 26 | ``` 27 | Run the preprocessing script: 28 | ``` 29 | ./data_preprocess.py 30 | ``` 31 | Two folders will be created, train_tisv and test_tisv, containing .npy files containing numpy ndarrays of speaker utterances with a 90%/10% training/testing split. 32 | 33 | # Training 34 | 35 | To train the speaker verification model, run: 36 | ``` 37 | ./train_speech_embedder.py 38 | ``` 39 | with the following config.yaml key set to true: 40 | ```yaml 41 | training: !!bool "true" 42 | ``` 43 | for testing, set the key value to: 44 | ```yaml 45 | training: !!bool "false" 46 | ``` 47 | The log file and checkpoint save locations are controlled by the following values: 48 | ```yaml 49 | log_file: './speech_id_checkpoint/Stats' 50 | checkpoint_dir: './speech_id_checkpoint' 51 | ``` 52 | Only TI-SV is implemented. 53 | 54 | # Performance 55 | 56 | ``` 57 | EER across 10 epochs: 0.0377 58 | ``` 59 | 60 | # D vector embedding creation 61 | 62 | After training and testing the model, run dvector_create.py to create the numpy files train_sequence.npy, train_cluster_ids.npy, test_sequence.npy, and test_cluster_ids.npy. 63 | 64 | These files can be loaded and used to train the uis-rnn model found at https://github.com/google/uis-rnn 65 | -------------------------------------------------------------------------------- /Results/Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HarryVolek/PyTorch_Speaker_Verification/10e159a8d3255503c0184cde4eb7097968857a31/Results/Loss.png -------------------------------------------------------------------------------- /VAD_segments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Dec 18 16:22:41 2018 5 | 6 | @author: Harry 7 | Modified from https://github.com/wiseman/py-webrtcvad/blob/master/example.py 8 | """ 9 | 10 | import collections 11 | import contextlib 12 | import numpy as np 13 | import sys 14 | import librosa 15 | import wave 16 | 17 | import webrtcvad 18 | 19 | from hparam import hparam as hp 20 | 21 | def read_wave(path, sr): 22 | """Reads a .wav file. 23 | Takes the path, and returns (PCM audio data, sample rate). 24 | Assumes sample width == 2 25 | """ 26 | with contextlib.closing(wave.open(path, 'rb')) as wf: 27 | num_channels = wf.getnchannels() 28 | assert num_channels == 1 29 | sample_width = wf.getsampwidth() 30 | assert sample_width == 2 31 | sample_rate = wf.getframerate() 32 | assert sample_rate in (8000, 16000, 32000, 48000) 33 | pcm_data = wf.readframes(wf.getnframes()) 34 | data, _ = librosa.load(path, sr) 35 | assert len(data.shape) == 1 36 | assert sr in (8000, 16000, 32000, 48000) 37 | return data, pcm_data 38 | 39 | class Frame(object): 40 | """Represents a "frame" of audio data.""" 41 | def __init__(self, bytes, timestamp, duration): 42 | self.bytes = bytes 43 | self.timestamp = timestamp 44 | self.duration = duration 45 | 46 | 47 | def frame_generator(frame_duration_ms, audio, sample_rate): 48 | """Generates audio frames from PCM audio data. 49 | Takes the desired frame duration in milliseconds, the PCM data, and 50 | the sample rate. 51 | Yields Frames of the requested duration. 
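For example, with the 16 kHz, 16-bit mono PCM that read_wave() asserts, a 20 ms frame (the duration VAD_chunk() passes in below) works out to int(16000 * 0.020 * 2) = 640 bytes, i.e. 320 samples per frame.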
52 | """ 53 | n = int(sample_rate * (frame_duration_ms / 1000.0) * 2) 54 | offset = 0 55 | timestamp = 0.0 56 | duration = (float(n) / sample_rate) / 2.0 57 | while offset + n < len(audio): 58 | yield Frame(audio[offset:offset + n], timestamp, duration) 59 | timestamp += duration 60 | offset += n 61 | 62 | 63 | def vad_collector(sample_rate, frame_duration_ms, 64 | padding_duration_ms, vad, frames): 65 | """Filters out non-voiced audio frames. 66 | Given a webrtcvad.Vad and a source of audio frames, yields only 67 | the voiced audio. 68 | Uses a padded, sliding window algorithm over the audio frames. 69 | When more than 90% of the frames in the window are voiced (as 70 | reported by the VAD), the collector triggers and begins yielding 71 | audio frames. Then the collector waits until 90% of the frames in 72 | the window are unvoiced to detrigger. 73 | The window is padded at the front and back to provide a small 74 | amount of silence or the beginnings/endings of speech around the 75 | voiced frames. 76 | Arguments: 77 | sample_rate - The audio sample rate, in Hz. 78 | frame_duration_ms - The frame duration in milliseconds. 79 | padding_duration_ms - The amount to pad the window, in milliseconds. 80 | vad - An instance of webrtcvad.Vad. 81 | frames - a source of audio frames (sequence or generator). 82 | Returns: A generator that yields PCM audio data. 83 | """ 84 | num_padding_frames = int(padding_duration_ms / frame_duration_ms) 85 | # We use a deque for our sliding window/ring buffer. 86 | ring_buffer = collections.deque(maxlen=num_padding_frames) 87 | # We have two states: TRIGGERED and NOTTRIGGERED. We start in the 88 | # NOTTRIGGERED state. 89 | triggered = False 90 | 91 | voiced_frames = [] 92 | for frame in frames: 93 | is_speech = vad.is_speech(frame.bytes, sample_rate) 94 | 95 | if not triggered: 96 | ring_buffer.append((frame, is_speech)) 97 | num_voiced = len([f for f, speech in ring_buffer if speech]) 98 | # If we're NOTTRIGGERED and more than 90% of the frames in 99 | # the ring buffer are voiced frames, then enter the 100 | # TRIGGERED state. 101 | if num_voiced > 0.9 * ring_buffer.maxlen: 102 | triggered = True 103 | start = ring_buffer[0][0].timestamp 104 | # We want to yield all the audio we see from now until 105 | # we are NOTTRIGGERED, but we have to start with the 106 | # audio that's already in the ring buffer. 107 | for f, s in ring_buffer: 108 | voiced_frames.append(f) 109 | ring_buffer.clear() 110 | else: 111 | # We're in the TRIGGERED state, so collect the audio data 112 | # and add it to the ring buffer. 113 | voiced_frames.append(frame) 114 | ring_buffer.append((frame, is_speech)) 115 | num_unvoiced = len([f for f, speech in ring_buffer if not speech]) 116 | # If more than 90% of the frames in the ring buffer are 117 | # unvoiced, then enter NOTTRIGGERED and yield whatever 118 | # audio we've collected. 119 | if num_unvoiced > 0.9 * ring_buffer.maxlen: 120 | triggered = False 121 | yield (start, frame.timestamp + frame.duration) 122 | ring_buffer.clear() 123 | voiced_frames = [] 124 | # If we have any leftover voiced audio when we run out of input, 125 | # yield it. 
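# Note: each yielded item here is a (segment_start_seconds, segment_end_seconds) tuple rather than raw PCM bytes; VAD_chunk() below slices the float audio returned by read_wave() with these timestamps.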
126 | if voiced_frames: 127 | yield (start, frame.timestamp + frame.duration) 128 | 129 | 130 | def VAD_chunk(aggressiveness, path): 131 | audio, byte_audio = read_wave(path, hp.data.sr) 132 | vad = webrtcvad.Vad(int(aggressiveness)) 133 | frames = frame_generator(20, byte_audio, hp.data.sr) 134 | frames = list(frames) 135 | times = vad_collector(hp.data.sr, 20, 200, vad, frames) 136 | speech_times = [] 137 | speech_segs = [] 138 | for i, time in enumerate(times): 139 | start = np.round(time[0],decimals=2) 140 | end = np.round(time[1],decimals=2) 141 | j = start 142 | while j + .4 < end: 143 | end_j = np.round(j+.4,decimals=2) 144 | speech_times.append((j, end_j)) 145 | speech_segs.append(audio[int(j*hp.data.sr):int(end_j*hp.data.sr)]) 146 | j = end_j 147 | else: 148 | speech_times.append((j, end)) 149 | speech_segs.append(audio[int(j*hp.data.sr):int(end*hp.data.sr)]) 150 | return speech_times, speech_segs 151 | 152 | if __name__ == '__main__': 153 | speech_times, speech_segs = VAD_chunk(sys.argv[1], sys.argv[2]) 154 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | training: !!bool "true" 2 | device: "cuda" 3 | unprocessed_data: './TIMIT/*/*/*/*.wav' 4 | --- 5 | data: 6 | train_path: './train_tisv' 7 | train_path_unprocessed: './TIMIT/TRAIN/*/*/*.wav' 8 | test_path: './test_tisv' 9 | test_path_unprocessed: './TIMIT/TEST/*/*/*.wav' 10 | data_preprocessed: !!bool "true" 11 | sr: 16000 12 | nfft: 512 #For mel spectrogram preprocess 13 | window: 0.025 #(s) 14 | hop: 0.01 #(s) 15 | nmels: 40 #Number of mel energies 16 | tisv_frame: 180 #Max number of time steps in input after preprocess 17 | --- 18 | model: 19 | hidden: 768 #Number of LSTM hidden layer units 20 | num_layer: 3 #Number of LSTM layers 21 | proj: 256 #Embedding size 22 | model_path: './model.model' #Model path for testing, inference, or resuming training 23 | --- 24 | train: 25 | N : 4 #Number of speakers in batch 26 | M : 5 #Number of utterances per speaker 27 | num_workers: 0 #number of workers for dataloader 28 | lr: 0.01 29 | epochs: 950 #Max training speaker epoch 30 | log_interval: 30 #Epochs before printing progress 31 | log_file: './speech_id_checkpoint/Stats' 32 | checkpoint_interval: 120 #Save model after x speaker epochs 33 | checkpoint_dir: './speech_id_checkpoint' 34 | restore: !!bool "false" #Resume training from previous model path 35 | --- 36 | test: 37 | N : 4 #Number of speakers in batch 38 | M : 6 #Number of utterances per speaker 39 | num_workers: 8 #number of workers for data laoder 40 | epochs: 10 #testing speaker epochs 41 | -------------------------------------------------------------------------------- /data_load.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Aug 6 20:55:52 2018 5 | 6 | @author: harry 7 | """ 8 | import glob 9 | import numpy as np 10 | import os 11 | import random 12 | from random import shuffle 13 | import torch 14 | from torch.utils.data import Dataset 15 | 16 | from hparam import hparam as hp 17 | from utils import mfccs_and_spec 18 | 19 | class SpeakerDatasetTIMIT(Dataset): 20 | 21 | def __init__(self): 22 | 23 | if hp.training: 24 | self.path = hp.data.train_path_unprocessed 25 | self.utterance_number = hp.train.M 26 | else: 27 | self.path = hp.data.test_path_unprocessed 28 | self.utterance_number = hp.test.M 29 | self.speakers = 
glob.glob(os.path.dirname(self.path)) 30 | shuffle(self.speakers) 31 | 32 | def __len__(self): 33 | return len(self.speakers) 34 | 35 | def __getitem__(self, idx): 36 | 37 | speaker = self.speakers[idx] 38 | wav_files = glob.glob(speaker+'/*.WAV') 39 | shuffle(wav_files) 40 | wav_files = wav_files[0:self.utterance_number] 41 | 42 | mel_dbs = [] 43 | for f in wav_files: 44 | _, mel_db, _ = mfccs_and_spec(f, wav_process = True) 45 | mel_dbs.append(mel_db) 46 | return torch.Tensor(mel_dbs) 47 | 48 | class SpeakerDatasetTIMITPreprocessed(Dataset): 49 | 50 | def __init__(self, shuffle=True, utter_start=0): 51 | 52 | # data path 53 | if hp.training: 54 | self.path = hp.data.train_path 55 | self.utter_num = hp.train.M 56 | else: 57 | self.path = hp.data.test_path 58 | self.utter_num = hp.test.M 59 | self.file_list = os.listdir(self.path) 60 | self.shuffle=shuffle 61 | self.utter_start = utter_start 62 | 63 | def __len__(self): 64 | return len(self.file_list) 65 | 66 | def __getitem__(self, idx): 67 | 68 | np_file_list = os.listdir(self.path) 69 | 70 | if self.shuffle: 71 | selected_file = random.sample(np_file_list, 1)[0] # select random speaker 72 | else: 73 | selected_file = np_file_list[idx] 74 | 75 | utters = np.load(os.path.join(self.path, selected_file)) # load utterance spectrogram of selected speaker 76 | if self.shuffle: 77 | utter_index = np.random.randint(0, utters.shape[0], self.utter_num) # select M utterances per speaker 78 | utterance = utters[utter_index] 79 | else: 80 | utterance = utters[self.utter_start: self.utter_start+self.utter_num] # utterances of a speaker [batch(M), n_mels, frames] 81 | 82 | utterance = utterance[:,:,:160] # TODO implement variable length batch size 83 | 84 | utterance = torch.tensor(np.transpose(utterance, axes=(0,2,1))) # transpose [batch, frames, n_mels] 85 | return utterance 86 | -------------------------------------------------------------------------------- /data_preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | #Modified from https://github.com/JanhHyun/Speaker_Verification 4 | import glob 5 | import os 6 | import librosa 7 | import numpy as np 8 | from hparam import hparam as hp 9 | 10 | # downloaded dataset path 11 | audio_path = glob.glob(os.path.dirname(hp.unprocessed_data)) 12 | 13 | def save_spectrogram_tisv(): 14 | """ Full preprocess of text independent utterance. The log-mel-spectrogram is saved as numpy file. 15 | Each partial utterance is splitted by voice detection using DB 16 | and the first and the last 180 frames from each partial utterance are saved. 
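With the default config (sr=16000, window=0.025 s, hop=0.01 s, tisv_frame=180), the minimum interval kept below is utter_min_len = (180 * 0.01 + 0.025) * 16000 = 29,200 samples, i.e. about 1.8 s of continuous speech.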
17 | Need : utterance data set (VCTK) 18 | """ 19 | print("start text independent utterance feature extraction") 20 | os.makedirs(hp.data.train_path, exist_ok=True) # make folder to save train file 21 | os.makedirs(hp.data.test_path, exist_ok=True) # make folder to save test file 22 | 23 | utter_min_len = (hp.data.tisv_frame * hp.data.hop + hp.data.window) * hp.data.sr # lower bound of utterance length 24 | total_speaker_num = len(audio_path) 25 | train_speaker_num= (total_speaker_num//10)*9 # split total data 90% train and 10% test 26 | print("total speaker number : %d"%total_speaker_num) 27 | print("train : %d, test : %d"%(train_speaker_num, total_speaker_num-train_speaker_num)) 28 | for i, folder in enumerate(audio_path): 29 | print("%dth speaker processing..."%i) 30 | utterances_spec = [] 31 | for utter_name in os.listdir(folder): 32 | if utter_name[-4:] == '.WAV': 33 | utter_path = os.path.join(folder, utter_name) # path of each utterance 34 | utter, sr = librosa.core.load(utter_path, hp.data.sr) # load utterance audio 35 | intervals = librosa.effects.split(utter, top_db=30) # voice activity detection 36 | # this works fine for timit but if you get array of shape 0 for any other audio change value of top_db 37 | # for vctk dataset use top_db=100 38 | for interval in intervals: 39 | if (interval[1]-interval[0]) > utter_min_len: # If partial utterance is sufficiently long, 40 | utter_part = utter[interval[0]:interval[1]] # save first and last 180 frames of spectrogram. 41 | S = librosa.core.stft(y=utter_part, n_fft=hp.data.nfft, 42 | win_length=int(hp.data.window * sr), hop_length=int(hp.data.hop * sr)) 43 | S = np.abs(S) ** 2 44 | mel_basis = librosa.filters.mel(sr=hp.data.sr, n_fft=hp.data.nfft, n_mels=hp.data.nmels) 45 | S = np.log10(np.dot(mel_basis, S) + 1e-6) # log mel spectrogram of utterances 46 | utterances_spec.append(S[:, :hp.data.tisv_frame]) # first 180 frames of partial utterance 47 | utterances_spec.append(S[:, -hp.data.tisv_frame:]) # last 180 frames of partial utterance 48 | 49 | utterances_spec = np.array(utterances_spec) 50 | print(utterances_spec.shape) 51 | if i < train_speaker_num: 111 | train_sequence = np.concatenate(train_sequence,axis=0) 112 | train_cluster_id = np.asarray(train_cluster_id) 113 | np.save('train_sequence',train_sequence) 114 | np.save('train_cluster_id',train_cluster_id) 115 | train_saved = True 116 | train_sequence = [] 117 | train_cluster_id = [] 118 | 119 | train_sequence = np.concatenate(train_sequence,axis=0) 120 | train_cluster_id = np.asarray(train_cluster_id) 121 | np.save('test_sequence',train_sequence) 122 | np.save('test_cluster_id',train_cluster_id) 123 | -------------------------------------------------------------------------------- /hparam.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #!/usr/bin/env python 3 | 4 | import yaml 5 | 6 | 7 | def load_hparam(filename): 8 | stream = open(filename, 'r') 9 | docs = yaml.load_all(stream) 10 | hparam_dict = dict() 11 | for doc in docs: 12 | for k, v in doc.items(): 13 | hparam_dict[k] = v 14 | return hparam_dict 15 | 16 | 17 | def merge_dict(user, default): 18 | if isinstance(user, dict) and isinstance(default, dict): 19 | for k, v in default.items(): 20 | if k not in user: 21 | user[k] = v 22 | else: 23 | user[k] = merge_dict(user[k], v) 24 | return user 25 | 26 | 27 | class Dotdict(dict): 28 | """ 29 | a dictionary that supports dot notation 30 | as well as dictionary access notation 31 | usage: d = DotDict() or d =
DotDict({'val1':'first'}) 32 | set attributes: d.val2 = 'second' or d['val2'] = 'second' 33 | get attributes: d.val2 or d['val2'] 34 | """ 35 | __getattr__ = dict.__getitem__ 36 | __setattr__ = dict.__setitem__ 37 | __delattr__ = dict.__delitem__ 38 | 39 | def __init__(self, dct=None): 40 | dct = dict() if not dct else dct 41 | for key, value in dct.items(): 42 | if hasattr(value, 'keys'): 43 | value = Dotdict(value) 44 | self[key] = value 45 | 46 | 47 | class Hparam(Dotdict): 48 | 49 | def __init__(self, file='config/config.yaml'): 50 | super(Dotdict, self).__init__() 51 | hp_dict = load_hparam(file) 52 | hp_dotdict = Dotdict(hp_dict) 53 | for k, v in hp_dotdict.items(): 54 | setattr(self, k, v) 55 | 56 | __getattr__ = Dotdict.__getitem__ 57 | __setattr__ = Dotdict.__setitem__ 58 | __delattr__ = Dotdict.__delitem__ 59 | 60 | 61 | hparam = Hparam() 62 | -------------------------------------------------------------------------------- /speech_embedder_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Sep 5 20:58:34 2018 5 | 6 | @author: harry 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from hparam import hparam as hp 13 | from utils import get_centroids, get_cossim, calc_loss 14 | 15 | class SpeechEmbedder(nn.Module): 16 | 17 | def __init__(self): 18 | super(SpeechEmbedder, self).__init__() 19 | self.LSTM_stack = nn.LSTM(hp.data.nmels, hp.model.hidden, num_layers=hp.model.num_layer, batch_first=True) 20 | for name, param in self.LSTM_stack.named_parameters(): 21 | if 'bias' in name: 22 | nn.init.constant_(param, 0.0) 23 | elif 'weight' in name: 24 | nn.init.xavier_normal_(param) 25 | self.projection = nn.Linear(hp.model.hidden, hp.model.proj) 26 | 27 | def forward(self, x): 28 | x, _ = self.LSTM_stack(x.float()) #(batch, frames, n_mels) 29 | #only use last frame 30 | x = x[:,x.size(1)-1] 31 | x = self.projection(x.float()) 32 | x = x / torch.norm(x, dim=1).unsqueeze(1) 33 | return x 34 | 35 | class GE2ELoss(nn.Module): 36 | 37 | def __init__(self, device): 38 | super(GE2ELoss, self).__init__() 39 | self.w = nn.Parameter(torch.tensor(10.0).to(device), requires_grad=True) 40 | self.b = nn.Parameter(torch.tensor(-5.0).to(device), requires_grad=True) 41 | self.device = device 42 | 43 | def forward(self, embeddings): 44 | torch.clamp(self.w, 1e-6) 45 | centroids = get_centroids(embeddings) 46 | cossim = get_cossim(embeddings, centroids) 47 | sim_matrix = self.w*cossim.to(self.device) + self.b 48 | loss, _ = calc_loss(sim_matrix) 49 | return loss 50 | -------------------------------------------------------------------------------- /train_speech_embedder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Sep 5 21:49:16 2018 5 | 6 | @author: harry 7 | """ 8 | 9 | import os 10 | import random 11 | import time 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | from hparam import hparam as hp 16 | from data_load import SpeakerDatasetTIMIT, SpeakerDatasetTIMITPreprocessed 17 | from speech_embedder_net import SpeechEmbedder, GE2ELoss, get_centroids, get_cossim 18 | 19 | def train(model_path): 20 | device = torch.device(hp.device) 21 | 22 | if hp.data.data_preprocessed: 23 | train_dataset = SpeakerDatasetTIMITPreprocessed() 24 | else: 25 | train_dataset = SpeakerDatasetTIMIT() 26 | train_loader = DataLoader(train_dataset, 
batch_size=hp.train.N, shuffle=True, num_workers=hp.train.num_workers, drop_last=True) 27 | 28 | embedder_net = SpeechEmbedder().to(device) 29 | if hp.train.restore: 30 | embedder_net.load_state_dict(torch.load(model_path)) 31 | ge2e_loss = GE2ELoss(device) 32 | #Both net and loss have trainable parameters 33 | optimizer = torch.optim.SGD([ 34 | {'params': embedder_net.parameters()}, 35 | {'params': ge2e_loss.parameters()} 36 | ], lr=hp.train.lr) 37 | 38 | os.makedirs(hp.train.checkpoint_dir, exist_ok=True) 39 | 40 | embedder_net.train() 41 | iteration = 0 42 | for e in range(hp.train.epochs): 43 | total_loss = 0 44 | for batch_id, mel_db_batch in enumerate(train_loader): 45 | mel_db_batch = mel_db_batch.to(device) 46 | 47 | mel_db_batch = torch.reshape(mel_db_batch, (hp.train.N*hp.train.M, mel_db_batch.size(2), mel_db_batch.size(3))) 48 | perm = random.sample(range(0, hp.train.N*hp.train.M), hp.train.N*hp.train.M) 49 | unperm = list(perm) 50 | for i,j in enumerate(perm): 51 | unperm[j] = i 52 | mel_db_batch = mel_db_batch[perm] 53 | #gradient accumulates 54 | optimizer.zero_grad() 55 | 56 | embeddings = embedder_net(mel_db_batch) 57 | embeddings = embeddings[unperm] 58 | embeddings = torch.reshape(embeddings, (hp.train.N, hp.train.M, embeddings.size(1))) 59 | 60 | #get loss, call backward, step optimizer 61 | loss = ge2e_loss(embeddings) #wants (Speaker, Utterances, embedding) 62 | loss.backward() 63 | torch.nn.utils.clip_grad_norm_(embedder_net.parameters(), 3.0) 64 | torch.nn.utils.clip_grad_norm_(ge2e_loss.parameters(), 1.0) 65 | optimizer.step() 66 | 67 | total_loss = total_loss + loss 68 | iteration += 1 69 | if (batch_id + 1) % hp.train.log_interval == 0: 70 | mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss:{5:.4f}\tTLoss:{6:.4f}\t\n".format(time.ctime(), e+1, 71 | batch_id+1, len(train_dataset)//hp.train.N, iteration,loss, total_loss / (batch_id + 1)) 72 | print(mesg) 73 | if hp.train.log_file is not None: 74 | with open(hp.train.log_file,'a') as f: 75 | f.write(mesg) 76 | 77 | if hp.train.checkpoint_dir is not None and (e + 1) % hp.train.checkpoint_interval == 0: 78 | embedder_net.eval().cpu() 79 | ckpt_model_filename = "ckpt_epoch_" + str(e+1) + "_batch_id_" + str(batch_id+1) + ".pth" 80 | ckpt_model_path = os.path.join(hp.train.checkpoint_dir, ckpt_model_filename) 81 | torch.save(embedder_net.state_dict(), ckpt_model_path) 82 | embedder_net.to(device).train() 83 | 84 | #save model 85 | embedder_net.eval().cpu() 86 | save_model_filename = "final_epoch_" + str(e + 1) + "_batch_id_" + str(batch_id + 1) + ".model" 87 | save_model_path = os.path.join(hp.train.checkpoint_dir, save_model_filename) 88 | torch.save(embedder_net.state_dict(), save_model_path) 89 | 90 | print("\nDone, trained model saved at", save_model_path) 91 | 92 | def test(model_path): 93 | 94 | if hp.data.data_preprocessed: 95 | test_dataset = SpeakerDatasetTIMITPreprocessed() 96 | else: 97 | test_dataset = SpeakerDatasetTIMIT() 98 | test_loader = DataLoader(test_dataset, batch_size=hp.test.N, shuffle=True, num_workers=hp.test.num_workers, drop_last=True) 99 | 100 | embedder_net = SpeechEmbedder() 101 | embedder_net.load_state_dict(torch.load(model_path)) 102 | embedder_net.eval() 103 | 104 | avg_EER = 0 105 | for e in range(hp.test.epochs): 106 | batch_avg_EER = 0 107 | for batch_id, mel_db_batch in enumerate(test_loader): 108 | assert hp.test.M % 2 == 0 109 | enrollment_batch, verification_batch = torch.split(mel_db_batch, int(mel_db_batch.size(1)/2), dim=1) 110 | 111 | enrollment_batch = 
torch.reshape(enrollment_batch, (hp.test.N*hp.test.M//2, enrollment_batch.size(2), enrollment_batch.size(3))) 112 | verification_batch = torch.reshape(verification_batch, (hp.test.N*hp.test.M//2, verification_batch.size(2), verification_batch.size(3))) 113 | 114 | perm = random.sample(range(0,verification_batch.size(0)), verification_batch.size(0)) 115 | unperm = list(perm) 116 | for i,j in enumerate(perm): 117 | unperm[j] = i 118 | 119 | verification_batch = verification_batch[perm] 120 | enrollment_embeddings = embedder_net(enrollment_batch) 121 | verification_embeddings = embedder_net(verification_batch) 122 | verification_embeddings = verification_embeddings[unperm] 123 | 124 | enrollment_embeddings = torch.reshape(enrollment_embeddings, (hp.test.N, hp.test.M//2, enrollment_embeddings.size(1))) 125 | verification_embeddings = torch.reshape(verification_embeddings, (hp.test.N, hp.test.M//2, verification_embeddings.size(1))) 126 | 127 | enrollment_centroids = get_centroids(enrollment_embeddings) 128 | 129 | sim_matrix = get_cossim(verification_embeddings, enrollment_centroids) 130 | 131 | # calculating EER 132 | diff = 1; EER=0; EER_thresh = 0; EER_FAR=0; EER_FRR=0 133 | 134 | for thres in [0.01*i+0.5 for i in range(50)]: 135 | sim_matrix_thresh = sim_matrix>thres 136 | 137 | FAR = (sum([sim_matrix_thresh[i].float().sum()-sim_matrix_thresh[i,:,i].float().sum() for i in range(int(hp.test.N))]) 138 | /(hp.test.N-1.0)/(float(hp.test.M/2))/hp.test.N) 139 | 140 | FRR = (sum([hp.test.M/2-sim_matrix_thresh[i,:,i].float().sum() for i in range(int(hp.test.N))]) 141 | /(float(hp.test.M/2))/hp.test.N) 142 | 143 | # Save threshold when FAR = FRR (=EER) 144 | if diff> abs(FAR-FRR): 145 | diff = abs(FAR-FRR) 146 | EER = (FAR+FRR)/2 147 | EER_thresh = thres 148 | EER_FAR = FAR 149 | EER_FRR = FRR 150 | batch_avg_EER += EER 151 | print("\nEER : %0.2f (thres:%0.2f, FAR:%0.2f, FRR:%0.2f)"%(EER,EER_thresh,EER_FAR,EER_FRR)) 152 | avg_EER += batch_avg_EER/(batch_id+1) 153 | avg_EER = avg_EER / hp.test.epochs 154 | print("\n EER across {0} epochs: {1:.4f}".format(hp.test.epochs, avg_EER)) 155 | 156 | if __name__=="__main__": 157 | if hp.training: 158 | train(hp.model.model_path) 159 | else: 160 | test(hp.model.model_path) 161 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Sep 20 16:56:19 2018 5 | 6 | @author: harry 7 | """ 8 | import librosa 9 | import numpy as np 10 | import torch 11 | import torch.autograd as grad 12 | import torch.nn.functional as F 13 | 14 | from hparam import hparam as hp 15 | 16 | def get_centroids_prior(embeddings): 17 | centroids = [] 18 | for speaker in embeddings: 19 | centroid = 0 20 | for utterance in speaker: 21 | centroid = centroid + utterance 22 | centroid = centroid/len(speaker) 23 | centroids.append(centroid) 24 | centroids = torch.stack(centroids) 25 | return centroids 26 | 27 | def get_centroids(embeddings): 28 | centroids = embeddings.mean(dim=1) 29 | return centroids 30 | 31 | def get_centroid(embeddings, speaker_num, utterance_num): 32 | centroid = 0 33 | for utterance_id, utterance in enumerate(embeddings[speaker_num]): 34 | if utterance_id == utterance_num: 35 | continue 36 | centroid = centroid + utterance 37 | centroid = centroid/(len(embeddings[speaker_num])-1) 38 | return centroid 39 | 40 | def get_utterance_centroids(embeddings): 41 | """ 42 | Returns the 
centroids for each utterance of a speaker, where 43 | the utterance centroid is the speaker centroid without considering 44 | this utterance 45 | 46 | Shape of embeddings should be: 47 | (speaker_ct, utterance_per_speaker_ct, embedding_size) 48 | """ 49 | sum_centroids = embeddings.sum(dim=1) 50 | # we want to subtract out each utterance, prior to calculating the 51 | # the utterance centroid 52 | sum_centroids = sum_centroids.reshape( 53 | sum_centroids.shape[0], 1, sum_centroids.shape[-1] 54 | ) 55 | # we want the mean but not including the utterance itself, so -1 56 | num_utterances = embeddings.shape[1] - 1 57 | centroids = (sum_centroids - embeddings) / num_utterances 58 | return centroids 59 | 60 | def get_cossim_prior(embeddings, centroids): 61 | # Calculates cosine similarity matrix. Requires (N, M, feature) input 62 | cossim = torch.zeros(embeddings.size(0),embeddings.size(1),centroids.size(0)) 63 | for speaker_num, speaker in enumerate(embeddings): 64 | for utterance_num, utterance in enumerate(speaker): 65 | for centroid_num, centroid in enumerate(centroids): 66 | if speaker_num == centroid_num: 67 | centroid = get_centroid(embeddings, speaker_num, utterance_num) 68 | output = F.cosine_similarity(utterance,centroid,dim=0)+1e-6 69 | cossim[speaker_num][utterance_num][centroid_num] = output 70 | return cossim 71 | 72 | def get_cossim(embeddings, centroids): 73 | # number of utterances per speaker 74 | num_utterances = embeddings.shape[1] 75 | utterance_centroids = get_utterance_centroids(embeddings) 76 | 77 | # flatten the embeddings and utterance centroids to just utterance, 78 | # so we can do cosine similarity 79 | utterance_centroids_flat = utterance_centroids.view( 80 | utterance_centroids.shape[0] * utterance_centroids.shape[1], 81 | -1 82 | ) 83 | embeddings_flat = embeddings.view( 84 | embeddings.shape[0] * num_utterances, 85 | -1 86 | ) 87 | # the cosine distance between utterance and the associated centroids 88 | # for that utterance 89 | # this is each speaker's utterances against his own centroid, but each 90 | # comparison centroid has the current utterance removed 91 | cos_same = F.cosine_similarity(embeddings_flat, utterance_centroids_flat) 92 | 93 | # now we get the cosine distance between each utterance and the other speakers' 94 | # centroids 95 | # to do so requires comparing each utterance to each centroid. 
To keep the 96 | # operation fast, we vectorize by using matrices L (embeddings) and 97 | # R (centroids) where L has each utterance repeated sequentially for all 98 | # comparisons and R has the entire centroids frame repeated for each utterance 99 | centroids_expand = centroids.repeat((num_utterances * embeddings.shape[0], 1)) 100 | embeddings_expand = embeddings_flat.unsqueeze(1).repeat(1, embeddings.shape[0], 1) 101 | embeddings_expand = embeddings_expand.view( 102 | embeddings_expand.shape[0] * embeddings_expand.shape[1], 103 | embeddings_expand.shape[-1] 104 | ) 105 | cos_diff = F.cosine_similarity(embeddings_expand, centroids_expand) 106 | cos_diff = cos_diff.view( 107 | embeddings.size(0), 108 | num_utterances, 109 | centroids.size(0) 110 | ) 111 | # assign the cosine distance for same speakers to the proper idx 112 | same_idx = list(range(embeddings.size(0))) 113 | cos_diff[same_idx, :, same_idx] = cos_same.view(embeddings.shape[0], num_utterances) 114 | cos_diff = cos_diff + 1e-6 115 | return cos_diff 116 | 117 | def calc_loss_prior(sim_matrix): 118 | # Calculates loss from (N, M, K) similarity matrix 119 | per_embedding_loss = torch.zeros(sim_matrix.size(0), sim_matrix.size(1)) 120 | for j in range(len(sim_matrix)): 121 | for i in range(sim_matrix.size(1)): 122 | per_embedding_loss[j][i] = -(sim_matrix[j][i][j] - ((torch.exp(sim_matrix[j][i]).sum()+1e-6).log_())) 123 | loss = per_embedding_loss.sum() 124 | return loss, per_embedding_loss 125 | 126 | def calc_loss(sim_matrix): 127 | same_idx = list(range(sim_matrix.size(0))) 128 | pos = sim_matrix[same_idx, :, same_idx] 129 | neg = (torch.exp(sim_matrix).sum(dim=2) + 1e-6).log_() 130 | per_embedding_loss = -1 * (pos - neg) 131 | loss = per_embedding_loss.sum() 132 | return loss, per_embedding_loss 133 | 134 | def normalize_0_1(values, max_value, min_value): 135 | normalized = np.clip((values - min_value) / (max_value - min_value), 0, 1) 136 | return normalized 137 | 138 | def mfccs_and_spec(wav_file, wav_process = False, calc_mfccs=False, calc_mag_db=False): 139 | sound_file, _ = librosa.core.load(wav_file, sr=hp.data.sr) 140 | window_length = int(hp.data.window*hp.data.sr) 141 | hop_length = int(hp.data.hop*hp.data.sr) 142 | duration = hp.data.tisv_frame * hp.data.hop + hp.data.window 143 | 144 | # Cut silence and fix length 145 | if wav_process == True: 146 | sound_file, index = librosa.effects.trim(sound_file, frame_length=window_length, hop_length=hop_length) 147 | length = int(hp.data.sr * duration) 148 | sound_file = librosa.util.fix_length(sound_file, length) 149 | 150 | spec = librosa.stft(sound_file, n_fft=hp.data.nfft, hop_length=hop_length, win_length=window_length) 151 | mag_spec = np.abs(spec) 152 | 153 | mel_basis = librosa.filters.mel(hp.data.sr, hp.data.nfft, n_mels=hp.data.nmels) 154 | mel_spec = np.dot(mel_basis, mag_spec) 155 | 156 | mag_db = librosa.amplitude_to_db(mag_spec) 157 | #db mel spectrogram 158 | mel_db = librosa.amplitude_to_db(mel_spec).T 159 | 160 | mfccs = None 161 | if calc_mfccs: 162 | mfccs = np.dot(librosa.filters.dct(40, mel_db.shape[0]), mel_db).T 163 | 164 | return mfccs, mel_db, mag_db 165 | 166 | if __name__ == "__main__": 167 | w = grad.Variable(torch.tensor(1.0)) 168 | b = grad.Variable(torch.tensor(0.0)) 169 | embeddings = torch.tensor([[0,1,0],[0,0,1], [0,1,0], [0,1,0], [1,0,0], [1,0,0]]).to(torch.float).reshape(3,2,3) 170 | centroids = get_centroids(embeddings) 171 | cossim = get_cossim(embeddings, centroids) 172 | sim_matrix = w*cossim + b 173 | loss, per_embedding_loss = 
calc_loss(sim_matrix) 174 | --------------------------------------------------------------------------------
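As a quick end-to-end sanity check of the modules above, the sketch below feeds SpeechEmbedder and GE2ELoss a random batch shaped the way train_speech_embedder.py builds it (N=4 speakers x M=5 utterances, 160 frames, 40 mel bins). It is an illustrative sketch only, not part of the original repository: it assumes the repository modules are importable from the working directory and that config/config.yaml is readable, since hparam.py loads the config at import time; the input is random noise, so the printed loss value only confirms that the shapes line up and gradients flow.

```python
# ge2e_sanity_check.py -- illustrative sketch, not part of the original repo.
# Assumes config/config.yaml is readable (hparam.py loads it on import).
import torch

from speech_embedder_net import SpeechEmbedder, GE2ELoss

N, M, frames, n_mels = 4, 5, 160, 40   # hp.train.N, hp.train.M, the 160-frame crop in data_load.py, hp.data.nmels

device = torch.device("cpu")           # config defaults to "cuda"; cpu keeps the check portable
embedder_net = SpeechEmbedder().to(device)
ge2e_loss = GE2ELoss(device)

# The train loop flattens each batch to (N*M, frames, n_mels) before the embedder...
mel_db_batch = torch.randn(N * M, frames, n_mels).to(device)
embeddings = embedder_net(mel_db_batch)                    # (N*M, hp.model.proj)

# ...and reshapes the embeddings to (N, M, proj) before computing the GE2E loss.
embeddings = torch.reshape(embeddings, (N, M, embeddings.size(1)))
loss = ge2e_loss(embeddings)
loss.backward()
print("GE2E loss on random input:", loss.item())
```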