├── SpeakerNetTrainer ├── 00-train_model.sh ├── models │ ├── baseline_lite_ap.model │ ├── ResNetBlocks.py │ └── ResNetSE34L.py ├── setup.py ├── 01-get_model_list.sh ├── 02-download_model.sh ├── accuracy.py ├── tuneThreshold.py ├── LICENSE.md ├── loss │ └── angleproto.py ├── trainSpeakerNet.py ├── DatasetLoader.py └── SpeakerNet.py ├── SpeakerDiarization ├── diarization.png ├── configs │ └── speaker_diarization.conf ├── main.py └── nssd │ ├── endpoint_detector.py │ ├── clustering.py │ ├── epd │ └── webrtc.py │ └── emb_extractor.py └── README.md /SpeakerNetTrainer/00-train_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python trainSpeakerNet.py "$@" 4 | 5 | -------------------------------------------------------------------------------- /SpeakerDiarization/diarization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Namsik-Yoon/Speaker-Diarization/main/SpeakerDiarization/diarization.png -------------------------------------------------------------------------------- /SpeakerNetTrainer/models/baseline_lite_ap.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Namsik-Yoon/Speaker-Diarization/main/SpeakerNetTrainer/models/baseline_lite_ap.model -------------------------------------------------------------------------------- /SpeakerNetTrainer/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='speaker-diarization', 5 | version='0.1', 6 | install_requires=['torchaudio==0.3.1'] 7 | ) 8 | -------------------------------------------------------------------------------- /SpeakerNetTrainer/01-get_model_list.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ID=$1 4 | DATASET=vox 5 | SESSION_NUM=$2 6 | 7 | echo "This script previously listed models on NSML." \ 8 | "Download models manually as needed." 9 | 10 | -------------------------------------------------------------------------------- /SpeakerNetTrainer/02-download_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ID=$1 4 | DATASET=vox 5 | SESSION_NUM=$2 6 | CHECKPOINT=$3 7 | 8 | echo "Download the pretrained model manually and place it in the appropriate directory." 9 | 10 | -------------------------------------------------------------------------------- /SpeakerNetTrainer/accuracy.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | def accuracy(output, target, topk=(1,)): 5 | """Computes the precision@k for the specified values of k""" 6 | maxk = max(topk) 7 | batch_size = target.size(0) 8 | 9 | _, pred = output.topk(maxk, 1, True, True) 10 | pred = pred.t() 11 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 12 | 13 | res = [] 14 | for k in topk: 15 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 16 | res.append(correct_k.mul_(100.0 / batch_size)) 17 | return res -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Speaker Diarization 2 | 3 | This project contains code for training speaker embedding models and running a simple diarization pipeline. 
The original implementation relied on the NSML platform, but the scripts have been adapted for normal Python execution. 4 | 5 | ## Running diarization 6 | 7 | ``` 8 | python SpeakerDiarization/main.py path/to/audio.wav --config SpeakerDiarization/configs/speaker_diarization.conf 9 | ``` 10 | 11 | ## Training SpeakerNet 12 | 13 | ``` 14 | cd SpeakerNetTrainer 15 | python trainSpeakerNet.py --train_list --test_list --save_path exp 16 | ``` 17 | 18 | The training script will create `exp/model` and `exp/result` directories to store checkpoints and evaluation results. 19 | -------------------------------------------------------------------------------- /SpeakerDiarization/configs/speaker_diarization.conf: -------------------------------------------------------------------------------- 1 | { 2 | "inference_config": { 3 | "model_type": "ResNetSEL_16k_150", 4 | "model_path": "third_party/SpeakerNet/models/weights/16k/heavy_256.pt", 5 | "device" : "cuda", 6 | "batch_size": 512 7 | }, 8 | "diarization_config": { 9 | "max_seg_ms": 1380, 10 | "shift_ms": 500, 11 | "method": "ahc", 12 | "num_cluster": "None", 13 | "normalize": true, 14 | "clustering_parameters": { 15 | "ahc_metric" : "cosine", 16 | "ahc_method": "complete", 17 | "ahc_criterion": "distance", 18 | "threshold": 0.94 19 | } 20 | }, 21 | "epd_config": { 22 | "epd_mode": "webrtc", 23 | "resolution": 30, 24 | "voice_criteria": 0.7 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /SpeakerNetTrainer/tuneThreshold.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | import os 5 | import glob 6 | import sys 7 | import time 8 | from sklearn import metrics 9 | import numpy 10 | import pdb 11 | 12 | def tuneThresholdfromScore(scores, labels, target_fa, target_fr = None): 13 | 14 | fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=1) 15 | fnr = 1 - tpr 16 | 17 | fnr = fnr*100 18 | fpr = fpr*100 19 | 20 | tunedThreshold = []; 21 | if target_fr: 22 | for tfr in target_fr: 23 | idx = numpy.nanargmin(numpy.absolute((tfr - fnr))) 24 | tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]]); 25 | 26 | for tfa in target_fa: 27 | idx = numpy.nanargmin(numpy.absolute((tfa - fpr))) # numpy.where(fpr<=tfa)[0][-1] 28 | tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]]); 29 | 30 | idxE = numpy.nanargmin(numpy.absolute((fnr - fpr))) 31 | eer = max(fpr[idxE],fnr[idxE]) 32 | 33 | return (tunedThreshold, eer, fpr, fnr); 34 | -------------------------------------------------------------------------------- /SpeakerNetTrainer/LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020-present NAVER Corp. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 
12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /SpeakerDiarization/main.py: -------------------------------------------------------------------------------- 1 | """Simple CLI entry point for running speaker diarization locally.""" 2 | import argparse 3 | import json 4 | from pprint import pprint 5 | 6 | from nssd.speaker_diarization import SpeakerDiarization 7 | 8 | 9 | def _build_config(args): 10 | # load configurations 11 | config = json.loads(open(args.config, 'r').read()) 12 | return config 13 | 14 | 15 | def main(args): 16 | """Run diarization on a WAV file.""" 17 | config = _build_config(args) 18 | pprint(config) 19 | speaker_diarization = SpeakerDiarization(**config) 20 | 21 | run_result = speaker_diarization.run([args.input_pcm]) 22 | _, sel_tuples = run_result.export_nssd_data() 23 | pprint(sel_tuples[0]) 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = argparse.ArgumentParser( 28 | description="Run speaker diarization on a WAV file.") 29 | parser.add_argument('--config', 30 | dest='config', 31 | type=str, 32 | help='configuration file for speaker diarization.', 33 | default='configs/speaker_diarization.conf') 34 | parser.add_argument('input_pcm', 35 | type=str, 36 | help='Path to input WAV file') 37 | main(parser.parse_args()) 38 | -------------------------------------------------------------------------------- /SpeakerNetTrainer/loss/angleproto.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import time, pdb, numpy 8 | from accuracy import accuracy 9 | 10 | class AngleProtoLoss(nn.Module): 11 | 12 | def __init__(self, init_w=10.0, init_b=-5.0): 13 | super(AngleProtoLoss, self).__init__() 14 | self.w = nn.Parameter(torch.tensor(init_w)) 15 | self.b = nn.Parameter(torch.tensor(init_b)) 16 | self.criterion = torch.nn.CrossEntropyLoss() 17 | 18 | print('Initialised AngleProto') 19 | 20 | def forward(self, x, label=None): 21 | 22 | out_anchor = torch.mean(x[:,1:,:],1) 23 | out_positive = x[:,0,:] 24 | stepsize = out_anchor.size()[0] 25 | 26 | cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1,-1,stepsize),out_anchor.unsqueeze(-1).expand(-1,-1,stepsize).transpose(0,2)) 27 | torch.clamp(self.w, 1e-6) 28 | cos_sim_matrix = cos_sim_matrix * self.w + self.b 29 | 30 | label = torch.from_numpy(numpy.asarray(range(0,stepsize))).cuda() 31 | nloss = self.criterion(cos_sim_matrix, label) 32 | prec1, _ = accuracy(cos_sim_matrix.detach().cpu(), label.detach().cpu(), topk=(1, 5)) 33 | 34 | return nloss, prec1 -------------------------------------------------------------------------------- /SpeakerDiarization/nssd/endpoint_detector.py: -------------------------------------------------------------------------------- 1 | """End point detector.""" 2 | import contextlib 3 | import wave 4 | 5 | import nssd.epd.webrtc as webrtc 6 | 7 | def read_wave(path): 8 | """Read WAV file. 9 | 10 | Args: 11 | path: string, path of WAV format file. 12 | 13 | Returns: 14 | pcm_data: bytes, raw wave data. 15 | sample_rate: integer, sampling frequency. 16 | """ 17 | with contextlib.closing(wave.open(path, 'rb')) as wave_file: 18 | num_channels = wave_file.getnchannels() 19 | assert num_channels == 1 20 | sample_width = wave_file.getsampwidth() 21 | assert sample_width == 2 22 | sample_rate = wave_file.getframerate() 23 | assert sample_rate in (8000, 16000, 32000, 48000) 24 | pcm_data = wave_file.readframes(wave_file.getnframes()) 25 | 26 | return pcm_data, sample_rate 27 | 28 | 29 | class EndPointDetector(): 30 | def __init__(self, config, sampling_rate=8000): 31 | self.mode = config.get('epd_mode', 'webrtc') 32 | self.resolution = config.get('resolution', 30) 33 | self.voice_criteria = config.get('voice_criteria', 0.7) 34 | self.sampling_rate = sampling_rate 35 | 36 | def get_epd_result(self, input_pcm): 37 | """Get positions where a speech segment contains in given audio. 38 | 39 | Args: 40 | input_pcm: string, bytearray or bytes. 41 | string, path of audio file. 42 | bytearray or bytes, data from WAV file. 43 | Its frame rate should be 8k or 16k. 
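                (If a file path is given, read_wave above also accepts 32 kHz and 48 kHz audio,
                as its sample-rate assertion shows.)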
44 | Returns: 45 | list of 2-tuple of (start_time, end_time) 46 | """ 47 | vad_segments = [] 48 | 49 | voice_criteria = self.voice_criteria 50 | resolution = self.resolution 51 | mode = self.mode 52 | sampling_rate = self.sampling_rate 53 | 54 | if isinstance(input_pcm, str): 55 | audio, sampling_rate = read_wave(input_pcm) 56 | elif isinstance(input_pcm, bytearray): 57 | audio = input_pcm 58 | elif isinstance(input_pcm, bytes): 59 | audio = bytearray(input_pcm) 60 | else: 61 | raise TypeError("Unsupported input pcm: %s" % type(input_pcm)) 62 | 63 | if mode == 'webrtc': 64 | vad_segments = webrtc.epd(audio, resolution, sampling_rate, voice_criteria) 65 | else: 66 | raise ValueError("Unsupported EPD mode: %s" % mode) 67 | 68 | return vad_segments 69 | -------------------------------------------------------------------------------- /SpeakerDiarization/nssd/clustering.py: -------------------------------------------------------------------------------- 1 | """Clustering modules.""" 2 | from functools import partial 3 | 4 | from scipy.cluster.hierarchy import linkage, fcluster 5 | 6 | def _setup_clusterer(config): 7 | """Factory method to setup clusterer. 8 | 9 | Args: 10 | config: dictionary, hparams for initializing clusterer. 11 | 12 | Returns: 13 | Function that requests clustering. 14 | """ 15 | def _ahc(embeddings, config): 16 | # initialize hparams related with ahc. 17 | num_cluster = config['num_cluster'] 18 | threshold = config['threshold'] 19 | ahc_metric = config['ahc_metric'] 20 | ahc_method = config['ahc_method'] 21 | ahc_criterion = config['ahc_criterion'] 22 | 23 | if ahc_metric == "euclidean": # default method = 'single' 24 | _linkage = linkage(embeddings, metric=ahc_metric, method=ahc_method) 25 | else: # default method = 'complete' 26 | _linkage = linkage(embeddings, metric=ahc_metric, method=ahc_method) 27 | 28 | if num_cluster != "None": # default criterion = 'maxclust' 29 | cluster_labels = fcluster(_linkage, float(num_cluster), 30 | criterion=ahc_criterion) 31 | else: # default criterion = 'distance' 32 | cluster_labels = fcluster(_linkage, threshold, 33 | criterion=ahc_criterion) 34 | return cluster_labels 35 | 36 | method = config['method'] 37 | 38 | if method == "ahc": 39 | return partial(_ahc, config=config) 40 | 41 | raise ValueError("Unsupported clustering method : %s" % method) 42 | 43 | class Clusterer(): 44 | """Wrapper class for clustering module.""" 45 | def __init__(self, config): 46 | self.config = config 47 | self.clusterer = _setup_clusterer(config) 48 | 49 | def predict(self, embeddings): 50 | """Predict speaker label. 51 | 52 | Args: 53 | embeddings: list of embeddings will be clustered. 54 | 55 | Returns: 56 | labels: ndarray of shape (n_samples,). 57 | Index of cluster each sample belong to. 58 | """ 59 | return self.clusterer(embeddings) 60 | 61 | def update_num_cluster(self, n_clusters): 62 | """Update num_cluster. 63 | 64 | If given n_clusters is different from one that clusterer has, 65 | update clusterer. 66 | 67 | Args: 68 | n_clusters: integer, the number of clusters. 69 | """ 70 | if self.config['num_cluster'] != n_clusters: 71 | self.config['num_cluster'] = n_clusters 72 | self.clusterer = _setup_clusterer(self.config) 73 | -------------------------------------------------------------------------------- /SpeakerNetTrainer/models/ResNetBlocks.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | class SEBasicBlock(nn.Module): 8 | expansion = 1 9 | 10 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): 11 | super(SEBasicBlock, self).__init__() 12 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(planes) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(planes) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.se = SELayer(planes, reduction) 18 | self.downsample = downsample 19 | self.stride = stride 20 | 21 | def forward(self, x): 22 | residual = x 23 | 24 | out = self.conv1(x) 25 | out = self.relu(out) 26 | out = self.bn1(out) 27 | 28 | out = self.conv2(out) 29 | out = self.bn2(out) 30 | out = self.se(out) 31 | 32 | if self.downsample is not None: 33 | residual = self.downsample(x) 34 | 35 | out += residual 36 | out = self.relu(out) 37 | return out 38 | 39 | 40 | class SEBottleneck(nn.Module): 41 | expansion = 4 42 | 43 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): 44 | super(SEBottleneck, self).__init__() 45 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 46 | self.bn1 = nn.BatchNorm2d(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 48 | padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(planes * 4) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.se = SELayer(planes * 4, reduction) 54 | self.downsample = downsample 55 | self.stride = stride 56 | 57 | def forward(self, x): 58 | residual = x 59 | 60 | out = self.conv1(x) 61 | out = self.bn1(out) 62 | out = self.relu(out) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv3(out) 69 | out = self.bn3(out) 70 | out = self.se(out) 71 | 72 | if self.downsample is not None: 73 | residual = self.downsample(x) 74 | 75 | out += residual 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class SELayer(nn.Module): 82 | def __init__(self, channel, reduction=8): 83 | super(SELayer, self).__init__() 84 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 85 | self.fc = nn.Sequential( 86 | nn.Linear(channel, channel // reduction), 87 | nn.ReLU(inplace=True), 88 | nn.Linear(channel // reduction, channel), 89 | nn.Sigmoid() 90 | ) 91 | 92 | def forward(self, x): 93 | b, c, _, _ = x.size() 94 | y = self.avg_pool(x).view(b, c) 95 | y = self.fc(y).view(b, c, 1, 1) 96 | return x * y -------------------------------------------------------------------------------- /SpeakerDiarization/nssd/epd/webrtc.py: -------------------------------------------------------------------------------- 1 | """webrtc epd.""" 2 | import collections 3 | 4 | import webrtcvad 5 | 6 | class Frame(): 7 | def __init__(self, _bytes, timestamp, duration): 8 | self.bytes = _bytes 9 | self.timestamp = timestamp 10 | self.duration = duration 11 | 12 | def frame_generator(frame_duration_ms, audio, sr): 13 | """Split given audio into frames. 14 | 15 | Args: 16 | frame_duration_ms: integer, frame duration. 17 | audio: bytearray, data from WAV file. 18 | sr: integer, sampling rate. 19 | 20 | Returns: 21 | list of Frame. 
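        (Each frame holds sr * frame_duration_ms / 1000 samples of 16-bit PCM,
        i.e. twice that many bytes, which is why the byte count below uses a factor of 2.)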
22 | """ 23 | 24 | n = int(sr * (frame_duration_ms / 1000.0) * 2) 25 | offset = 0 26 | timestamp = 0.0 27 | duration = (float(n) / sr) / 2.0 28 | frames = [] 29 | 30 | while offset + n < len(audio): 31 | frames.append(Frame(audio[offset:offset + n], timestamp, duration)) 32 | timestamp += duration 33 | offset += n 34 | 35 | return frames 36 | 37 | 38 | def vad_collector(sample_rate, frame_duration_ms, padding_duration_ms, vad, 39 | frames, voice_criteria): 40 | """ 41 | voice_criteria: criteria for TRIGGERED and NONTRIGGERED state 42 | """ 43 | num_padding_frames = int(padding_duration_ms / frame_duration_ms) 44 | voice_criteria = voice_criteria * num_padding_frames 45 | 46 | ring_buffer = collections.deque(maxlen=num_padding_frames) 47 | triggered = False 48 | 49 | s_start, s_end = 0.0, 0.0 50 | 51 | vad_segments = [] 52 | for frame in frames: 53 | is_speech = vad.is_speech(frame.bytes, sample_rate) 54 | 55 | if not triggered: 56 | ring_buffer.append((frame, is_speech)) 57 | num_voiced = len([f for f, speech in ring_buffer if speech]) 58 | 59 | # If more than voice_criteria of the frames in the ring buffer are voiced, 60 | # We enter the TRIGGERED state. 61 | if num_voiced > voice_criteria: 62 | triggered = True 63 | s_start = ring_buffer[0][0].timestamp 64 | ring_buffer.clear() 65 | else: 66 | # TRIGGERED state: append frames to ring_buffer 67 | ring_buffer.append((frame, is_speech)) 68 | num_unvoiced = len([f for f, s in ring_buffer if not s]) 69 | 70 | # If more than voice_criteria of the frames in the ring buffer are unvoiced, 71 | # We enter the NONTRIGGERED state. 72 | if num_unvoiced > voice_criteria: 73 | 74 | # For smaller margin 75 | s_end = ring_buffer[int(num_padding_frames / 2)][0].timestamp 76 | vad_segments.append((s_start, s_end)) 77 | triggered = False 78 | ring_buffer.clear() 79 | 80 | if triggered: 81 | if vad_segments and vad_segments[-1][1] == s_start: 82 | vad_segments[-1] = (vad_segments[-1][0], 83 | frame.timestamp + frame.duration) 84 | else: 85 | vad_segments.append((s_start, frame.timestamp + frame.duration)) 86 | 87 | ## Rounding 88 | for i in range(len(vad_segments)): 89 | x, y = vad_segments[i] 90 | vad_segments[i] = (round(x, 2), round(y, 2)) 91 | 92 | return vad_segments 93 | 94 | def epd(audio, resolution, sample_rate, voice_criteria): 95 | """Get positions where a speech segment contains in given audio. 96 | 97 | Args: 98 | audio: bytearray, data from WAV file. 99 | resolution: integer, frame duration (ms). 100 | sample_rate: interger, sampling rate. 101 | voice_criteria: float, criteria for TRIGGERED and NONTRIGGERED state. 102 | Returns: 103 | list of 2-tuple of (start_time, end_time) 104 | 105 | """ 106 | vad = webrtcvad.Vad(3) 107 | frames = frame_generator(resolution, audio, sample_rate) 108 | vad_segments = vad_collector(sample_rate, resolution, 300, vad, frames, 109 | voice_criteria) 110 | 111 | return vad_segments 112 | -------------------------------------------------------------------------------- /SpeakerDiarization/nssd/emb_extractor.py: -------------------------------------------------------------------------------- 1 | """Embedding extractor.""" 2 | import numpy as np 3 | import torch 4 | 5 | import nssd.utils as utils 6 | from SpeakerNet import SpeakerNet 7 | 8 | def split_segment(segment, batch_size): 9 | """Split the speech segment into small segments to be consumable 10 | by SpeakerNet. 11 | 12 | Args: 13 | segment: np.ndarray with shape (n_windows, window_size). 14 | batch_size: integer, batch size. 
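    Note that this is a generator: it yields arrays of at most batch_size
    windows, so the final chunk may be smaller than batch_size.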
15 | 16 | Returns: 17 | np.ndarray with shape (batch_size, window_size). 18 | """ 19 | n_windows = len(segment) 20 | for cur_idx in range(0, n_windows, batch_size): 21 | if cur_idx + batch_size >= n_windows: 22 | yield segment[cur_idx:] 23 | else: 24 | yield segment[cur_idx:cur_idx+batch_size] 25 | 26 | class EmbeddingExtractor(): 27 | """Extract feature vector of speech segment. 28 | 29 | Args: 30 | model_path: string, weight file path of SpeakerNet. 31 | model_type: string, SpeakerNet model type. 32 | batch_size: integer, batch size. 33 | max_frames: integer, window size. 34 | sampling_rate: integer, sampling rate of audio. 35 | device: string, 'cuda' for running with GPU, 36 | 'cpu' for running with CPU. 37 | 38 | Example of using EmbeddingExtractor: 39 | # Instantiation. 40 | embedding_extractor = EmbeddingExtractor( 41 | model_path='/ResNetSE_16k_150.model', 42 | model_type='ResNetSE_16k_150', 43 | batch_size=512, 44 | max_frames=150, 45 | sampling_rate=16000, 46 | device='cuda') 47 | 48 | input_pcm = '/audio.wav' 49 | vad_segment = [(0.14, 2.33), (3.12, 5.77), (10.01, 15.22)] 50 | embeddings = embedding_extractor(input_pcm, vad_segment) 51 | """ 52 | def __init__(self, 53 | model_path, 54 | model_type, 55 | batch_size=512, 56 | max_frames=100, 57 | sampling_rate=8000, 58 | device='cuda'): 59 | assert model_path != '' 60 | self.max_frames = max_frames 61 | self.device = torch.device(device) 62 | self.sampling_rate = sampling_rate 63 | self.batch_size = batch_size 64 | 65 | self.speakernet = SpeakerNet(model_path=model_path, 66 | model_type=model_type, 67 | device=self.device) 68 | 69 | def get_embeddings(self, input_pcm, vad_segments): 70 | """Get embeddings of audio segments. 71 | 72 | Args: 73 | input_pcm: It can be one of string, bytearray or bytes. 74 | string, path of audio file. 75 | bytearray or bytes, data from wavfile. 76 | vad_segments: list of tuple, which is (start_time, end_time). 77 | 78 | Returns: 79 | embeddings: np.ndarray with shape (the number of segments, embedding_size) 80 | """ 81 | embeddings = [] 82 | max_frames = self.max_frames 83 | batch_size = self.batch_size 84 | 85 | # Chop the pcm in to segments from input_pcm 86 | pcm_segs = utils.extract_pcms(input_pcm, 87 | vad_segments, 88 | max_frames, 89 | sr=self.sampling_rate) 90 | pcm_segs = [np.concatenate(pcm_segs, axis=0)] 91 | # input segment corresponding to each vad segment 92 | for seg in pcm_segs: 93 | seg = seg.astype(np.short) 94 | cur_segment = [] 95 | 96 | batches = split_segment(seg, batch_size) 97 | for batch in batches: 98 | with torch.no_grad(): 99 | output = self.speakernet.get_embedding(batch) 100 | cur_segment.append(output) 101 | 102 | cur_segment = np.concatenate(cur_segment, 0) 103 | embeddings.append(cur_segment) 104 | 105 | return np.concatenate(embeddings, 0) 106 | -------------------------------------------------------------------------------- /SpeakerNetTrainer/models/ResNetSE34L.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torchaudio 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn import Parameter 9 | from models.ResNetBlocks import * 10 | 11 | class ResNetSE(nn.Module): 12 | def __init__(self, block, layers, num_filters, nOut, encoder_type='SAP', **kwargs): 13 | 14 | print('Embedding size is %d, encoder %s.'%(nOut, encoder_type)) 15 | 16 | self.inplanes = num_filters[0] 17 | self.encoder_type = encoder_type 18 | super(ResNetSE, self).__init__() 19 | 20 | self.conv1 = nn.Conv2d(1, num_filters[0] , kernel_size=7, stride=(2, 1), padding=3, 21 | bias=False) 22 | self.bn1 = nn.BatchNorm2d(num_filters[0]) 23 | self.relu = nn.ReLU(inplace=True) 24 | 25 | self.layer1 = self._make_layer(block, num_filters[0], layers[0]) 26 | self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2)) 27 | self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2)) 28 | self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(1, 1)) 29 | 30 | self.avgpool = nn.AvgPool2d((5, 1), stride=1) 31 | 32 | self.instancenorm = nn.InstanceNorm1d(40) 33 | self.torchfb = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400, hop_length=160, f_min=0.0, f_max=8000, pad=0, n_mels=40) 34 | 35 | if self.encoder_type == "SAP": 36 | self.sap_linear = nn.Linear(num_filters[3] * block.expansion, num_filters[3] * block.expansion) 37 | self.attention = self.new_parameter(num_filters[3] * block.expansion, 1) 38 | out_dim = num_filters[3] * block.expansion 39 | else: 40 | raise ValueError('Undefined encoder') 41 | 42 | self.fc = nn.Linear(out_dim, nOut) 43 | 44 | for m in self.modules(): 45 | if isinstance(m, nn.Conv2d): 46 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 47 | elif isinstance(m, nn.BatchNorm2d): 48 | nn.init.constant_(m.weight, 1) 49 | nn.init.constant_(m.bias, 0) 50 | 51 | def _make_layer(self, block, planes, blocks, stride=1): 52 | downsample = None 53 | if stride != 1 or self.inplanes != planes * block.expansion: 54 | downsample = nn.Sequential( 55 | nn.Conv2d(self.inplanes, planes * block.expansion, 56 | kernel_size=1, stride=stride, bias=False), 57 | nn.BatchNorm2d(planes * block.expansion), 58 | ) 59 | 60 | layers = [] 61 | layers.append(block(self.inplanes, planes, stride, downsample)) 62 | self.inplanes = planes * block.expansion 63 | for i in range(1, blocks): 64 | layers.append(block(self.inplanes, planes)) 65 | 66 | return nn.Sequential(*layers) 67 | 68 | def new_parameter(self, *size): 69 | out = nn.Parameter(torch.FloatTensor(*size)) 70 | nn.init.xavier_normal_(out) 71 | return out 72 | 73 | def forward(self, x): 74 | 75 | x = self.torchfb(x)+1e-6 76 | x = self.instancenorm(x.log()).unsqueeze(1).detach() 77 | 78 | x = self.conv1(x) 79 | x = self.bn1(x) 80 | x = self.relu(x) 81 | 82 | x = self.layer1(x) 83 | x = self.layer2(x) 84 | x = self.layer3(x) 85 | x = self.layer4(x) 86 | x = self.avgpool(x) 87 | 88 | if self.encoder_type == "SAP": 89 | x = x.permute(0, 2, 1, 3) 90 | x = x.squeeze(dim=1).permute(0, 2, 1) # batch * L * D 91 | h = torch.tanh(self.sap_linear(x)) 92 | w = torch.matmul(h, self.attention).squeeze(dim=2) 93 | w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1) 94 | x = torch.sum(x * w, dim=1) 95 | 96 | x = x.view(x.size()[0], -1) 97 | x = self.fc(x) 98 | 99 | return x 100 | 101 | 102 | def ResNetSE34L(nOut=256, **kwargs): 103 | # Number of filters 104 | num_filters = [32, 64, 128, 128] 105 | model 
= ResNetSE(SEBasicBlock, [3, 4, 6, 3], num_filters, nOut, **kwargs) 106 | return model 107 | -------------------------------------------------------------------------------- /SpeakerNetTrainer/trainSpeakerNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | import sys, time, os, argparse, socket 5 | import numpy 6 | import pdb 7 | import torch 8 | import glob 9 | from tuneThreshold import tuneThresholdfromScore 10 | from SpeakerNet import SpeakerNet 11 | from DatasetLoader import get_data_loader 12 | 13 | # NSML support has been removed. Training now runs on a local environment. 14 | 15 | parser = argparse.ArgumentParser(description = "SpeakerNet"); 16 | 17 | ## Data loader 18 | parser.add_argument('--max_frames', type=int, default=200, help='Input length to the network'); 19 | parser.add_argument('--batch_size', type=int, default=200, help='Batch size'); 20 | parser.add_argument('--nDataLoaderThread', type=int, default=10, help='Number of loader threads'); 21 | parser.add_argument('--save_path', type=str, default='exp', help='Directory to save models and results'); 22 | 23 | ## Training details 24 | parser.add_argument('--test_interval', type=int, default=1, help='Test and save every [test_interval] epochs'); 25 | parser.add_argument('--max_epoch', type=int, default=100, help='Maximum number of epochs'); 26 | parser.add_argument('--trainfunc', type=str, default="angleproto", help='Loss function'); 27 | parser.add_argument('--optimizer', type=str, default="sgd", help='sgd or adam'); 28 | 29 | ## Learning rates 30 | parser.add_argument('--lr', type=float, default=0.1, help='Learning rate'); 31 | 32 | 33 | ## Training and test data 34 | parser.add_argument('--train_list', type=str, default="", help='Train list'); 35 | parser.add_argument('--test_list', type=str, default="", help='Evaluation list'); 36 | parser.add_argument('--train_path', type=str, default="voxceleb2", help='Absolute path to the train set'); 37 | parser.add_argument('--test_path', type=str, default="voxceleb1", help='Absolute path to the test set'); 38 | 39 | ## For test only 40 | parser.add_argument('--eval', dest='eval', action='store_true', help='Eval only') 41 | 42 | ## Model definition 43 | parser.add_argument('--model', type=str, default="ResNetSE34L", help='Name of model definition'); 44 | parser.add_argument('--initial_model', type=str, default="./models/baseline_lite_ap.model", help='Initial model weights'); 45 | parser.add_argument('--encoder_type', type=str, default="SAP", help='Type of encoder'); 46 | parser.add_argument('--nOut', type=int, default=512, help='Embedding size in the last FC layer'); 47 | 48 | args = parser.parse_args(); 49 | 50 | # ==================== INITIAL SETUP ==================== 51 | 52 | model_save_path = os.path.join(args.save_path, "model") 53 | result_save_path = os.path.join(args.save_path, "result") 54 | feat_save_path = "" 55 | 56 | # ==================== MAKE DIRECTORIES ==================== 57 | 58 | if not(os.path.exists(model_save_path)): 59 | os.makedirs(model_save_path) 60 | 61 | if not(os.path.exists(result_save_path)): 62 | os.makedirs(result_save_path) 63 | else: 64 | print("Folder already exists. 
Press Enter to continue...") 65 | 66 | # ==================== LOAD MODEL ==================== 67 | 68 | s = SpeakerNet(**vars(args)) 69 | if(args.initial_model != ""): 70 | s.loadParameters(args.initial_model); 71 | print("Model %s loaded!"%args.initial_model); 72 | 73 | 74 | it = 1; 75 | prevloss = float("inf"); 76 | sumloss = 0; 77 | min_eer = []; 78 | 79 | trainLoader = get_data_loader(args.train_list, **vars(args)) 80 | s.make_scheduler(trainLoader, args.max_epoch, args.lr) 81 | 82 | while(1): 83 | 84 | trainLoader.dataset.make_spk_list() 85 | loss, traineer = s.train_network(loader=trainLoader); 86 | 87 | # ==================== EVALUATE LIST ==================== 88 | 89 | if it % args.test_interval == 0: 90 | 91 | print(time.strftime("%Y-%m-%d %H:%M:%S"), it, "Evaluating..."); 92 | 93 | sc, lab = s.evaluateFromListSave(args.test_list, print_interval=100, feat_dir=feat_save_path, test_path=args.test_path) 94 | result = tuneThresholdfromScore(sc, lab, [1, 0.1]); 95 | 96 | print(time.strftime("%Y-%m-%d %H:%M:%S"), "TEER %2.2f, TLOSS %f, EER %2.4f"%( traineer, loss, result[1])); 97 | 98 | min_eer.append(result[1]) 99 | s.current_EER = result[1] 100 | 101 | else: 102 | print(time.strftime("%Y-%m-%d %H:%M:%S"), "TEER %2.2f, TLOSS %f"%( traineer, loss)); 103 | 104 | 105 | 106 | 107 | # ==================== SAVE MODEL ==================== 108 | 109 | if it >= args.max_epoch: 110 | quit(); 111 | 112 | it+=1; 113 | print(""); 114 | 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /SpeakerNetTrainer/DatasetLoader.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python 2 | # -*- encoding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy 8 | import random 9 | import pdb 10 | import os 11 | import threading 12 | import time 13 | import math 14 | import glob 15 | from scipy.io import wavfile 16 | from queue import Queue 17 | from torch.utils.data import Dataset, DataLoader 18 | from torchaudio import transforms 19 | from scipy import signal 20 | 21 | def worker_init_fn(worker_id): 22 | numpy.random.seed(numpy.random.get_state()[1][0] + worker_id) 23 | 24 | class wav_split(Dataset): 25 | def __init__(self, dataset_file_name, max_frames, train_path, batch_size, noise_aug = False): 26 | self.dataset_file_name = dataset_file_name 27 | self.max_frames = max_frames 28 | self.noise_aug = noise_aug 29 | # self.instancenorm = nn.InstanceNorm1d(40) 30 | self.data_dict = {} 31 | self.data_list = [] 32 | self.nFiles = 0 33 | self.batch_size = batch_size 34 | 35 | self.noisetypes = ['noise','speech','music'] 36 | 37 | self.noisesnr = {'noise':[0,15],'speech':[13,20],'music':[5,15]} 38 | self.noiselist = {} 39 | 40 | augment_files = glob.glob('/home1/irteam/db/musan/*/*/*/*.wav'); 41 | 42 | for file in augment_files: 43 | if not file.split('/')[-4] in self.noiselist: 44 | self.noiselist[file.split('/')[-4]] = [] 45 | self.noiselist[file.split('/')[-4]].append(file) 46 | 47 | ### Read Training Files... 
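        # Each line of the list is parsed as '<speaker_id> <wav path relative to train_path>',
        # e.g. 'id00012 id00012/clip_001/00001.wav' (illustrative VoxCeleb-style entry).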
48 | self.spk_dic = {} 49 | with open(dataset_file_name) as dataset_file: 50 | while True: 51 | line = dataset_file.readline(); 52 | if not line: 53 | break; 54 | 55 | data = line.split(); 56 | speaker_name = data[0]; 57 | filename = os.path.join(train_path,data[1]); 58 | if speaker_name not in self.spk_dic: 59 | self.spk_dic[speaker_name] = [] 60 | 61 | self.data_list.append(filename) 62 | self.spk_dic[speaker_name].append(filename) 63 | 64 | 65 | 66 | def make_spk_list(self): 67 | self.spk_list = [] 68 | 69 | while len(self.spk_list) < self.__len__(): 70 | spk_list_tmp = list(self.spk_dic.keys()) 71 | numpy.random.shuffle(spk_list_tmp) 72 | self.spk_list.extend(spk_list_tmp) 73 | 74 | 75 | def __getitem__(self, index): 76 | fns = numpy.random.choice(self.spk_dic[self.spk_list[index]],size = 2, replace = False) 77 | 78 | audio = [] 79 | audio.append(loadWAV(fns[0], self.max_frames, evalmode=False).astype(numpy.float)[0]) 80 | audio.append(loadWAV(fns[1], self.max_frames, evalmode=False).astype(numpy.float)[0]) 81 | 82 | 83 | if self.noise_aug: 84 | augment_profiles = [] 85 | audio_aug = [] 86 | for ii in range(len(audio)): 87 | ## additive noise profile 88 | noisecat = random.choice(self.noisetypes) 89 | noisefile = random.choice(self.noiselist[noisecat].copy()) 90 | snr = [random.uniform(self.noisesnr[noisecat][0],self.noisesnr[noisecat][1])] 91 | augment_profiles.append({'add_noise': noisefile, 'add_snr': snr}) 92 | 93 | audio_aug.append(self.augment_wav(audio[0],augment_profiles[0])) 94 | audio_aug.append(self.augment_wav(audio[1],augment_profiles[1])) 95 | 96 | audio = numpy.concatenate(audio_aug,axis=0) 97 | else: 98 | audio = numpy.stack(audio,axis=0) 99 | 100 | audio = torch.FloatTensor(audio) 101 | 102 | return audio 103 | 104 | def __len__(self): 105 | return len(self.data_list) 106 | 107 | 108 | def augment_wav(self,audio,augment): 109 | 110 | noiseaudio = loadWAV(augment['add_noise'], self.max_frames, evalmode=False).astype(numpy.float) 111 | 112 | noise_db = 10 * numpy.log10(numpy.mean(noiseaudio[0] ** 2)+1e-4) 113 | clean_db = 10 * numpy.log10(numpy.mean(audio ** 2)+1e-4) 114 | 115 | noise = numpy.sqrt(10 ** ((clean_db - noise_db - augment['add_snr']) / 10)) * noiseaudio 116 | audio = audio + noise 117 | 118 | return audio 119 | 120 | 121 | def round_down(num, divisor): 122 | return num - (num%divisor) 123 | 124 | def loadWAV(filename, max_frames, evalmode=True, num_eval=10): 125 | 126 | # Maximum audio length 127 | max_audio = max_frames * 160 + 240 128 | 129 | # Read wav file and convert to torch tensor 130 | sample_rate, audio = wavfile.read(filename) 131 | 132 | audiosize = audio.shape[0] 133 | 134 | if audiosize <= max_audio: 135 | shortage = math.floor( ( max_audio - audiosize + 1 ) / 2 ) 136 | audio = numpy.pad(audio, (shortage, shortage), 'constant', constant_values=0) 137 | audiosize = audio.shape[0] 138 | 139 | if evalmode: 140 | startframe = numpy.linspace(0,audiosize-max_audio,num=num_eval) 141 | else: 142 | startframe = numpy.array([numpy.int64(random.random()*(audiosize-max_audio))]) 143 | 144 | feats = [] 145 | if evalmode and max_frames == 0: 146 | feats.append(audio) 147 | else: 148 | for asf in startframe: 149 | feats.append(audio[int(asf):int(asf)+max_audio]) 150 | 151 | feat = numpy.stack(feats,axis=0) 152 | 153 | #feat = torch.FloatTensor(feat) 154 | 155 | return feat; 156 | 157 | 158 | def get_data_loader(dataset_file_name, batch_size, max_frames, nDataLoaderThread, train_path, **kwargs): 159 | 160 | train_dataset = wav_split(dataset_file_name, 
max_frames, train_path, batch_size) 161 | 162 | train_loader = torch.utils.data.DataLoader( 163 | train_dataset, 164 | batch_size=batch_size, 165 | shuffle=False, 166 | num_workers=nDataLoaderThread, 167 | pin_memory=True, 168 | drop_last=True, 169 | worker_init_fn=worker_init_fn, 170 | ) 171 | 172 | return train_loader 173 | -------------------------------------------------------------------------------- /SpeakerNetTrainer/SpeakerNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | #-*- coding: utf-8 -*- 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy, math, pdb, sys, random 8 | import time, os, itertools, shutil, importlib 9 | from tuneThreshold import tuneThresholdfromScore 10 | from DatasetLoader import loadWAV 11 | from loss.angleproto import AngleProtoLoss 12 | 13 | def cosine_annealing(step, total_steps, lr_max, lr_min): 14 | return lr_min + (lr_max - lr_min) * 0.5 * ( 15 | 1 + numpy.cos(step / total_steps * numpy.pi)) 16 | 17 | class SpeakerNet(nn.Module): 18 | 19 | def __init__(self, max_frames, lr = 0.0001, model="ResNetSE34", nOut = 512, nSpeakers = 1000, optimizer = 'adam', encoder_type = 'SAP', normalize = True, trainfunc='angleproto', **kwargs): 20 | super(SpeakerNet, self).__init__(); 21 | 22 | argsdict = {'nOut': nOut, 'encoder_type':encoder_type} 23 | print(lr) 24 | 25 | SpeakerNetModel = importlib.import_module('models.%s'%(model)).__getattribute__(model) 26 | self.__S__ = SpeakerNetModel(**argsdict).cuda(); 27 | 28 | if trainfunc == 'angleproto': 29 | self.__L__ = AngleProtoLoss().cuda() 30 | self.__train_normalize__ = True 31 | self.__test_normalize__ = True 32 | else: 33 | raise ValueError('Undefined loss.') 34 | 35 | if optimizer == 'adam': 36 | self.__optimizer__ = torch.optim.Adam(self.parameters(), lr = lr); 37 | elif optimizer == 'sgd': 38 | self.__optimizer__ = torch.optim.SGD(self.parameters(), lr = lr, momentum = 0.9, weight_decay=5e-5); 39 | else: 40 | raise ValueError('Undefined optimizer.') 41 | 42 | self.__max_frames__ = max_frames; 43 | self.current_EER = 50.0 44 | 45 | ## ===== ===== ===== ===== ===== ===== ===== ===== 46 | ## Train contrastive 47 | ## ===== ===== ===== ===== ===== ===== ===== ===== 48 | 49 | def train_network(self, loader, print_interval=100): 50 | clr = [] 51 | for param_group in self.__optimizer__.param_groups: 52 | clr.append(param_group['lr']) 53 | 54 | print(time.strftime("%Y-%m-%d %H:%M:%S"), "Training with LR %.5f..."%(max(clr))) 55 | self.train(); 56 | 57 | # ==================== INITIAL PARAMETERS ==================== 58 | 59 | stepsize = loader.batch_size; 60 | 61 | counter = 0; 62 | index = 0; 63 | loss = 0; 64 | top1 = 0 65 | 66 | for data in loader: 67 | 68 | tstart = time.time() 69 | 70 | self.zero_grad(); 71 | 72 | org_size = data.size() 73 | data = data.reshape(org_size[0] * org_size[1], -1) 74 | outp = self.__S__.forward(data.cuda()) 75 | if self.__train_normalize__: 76 | outp = F.normalize(outp, p=2, dim=1) 77 | feat = outp.reshape(org_size[0], org_size[1], -1) 78 | 79 | nloss, prec1 = self.__L__.forward(feat) 80 | 81 | loss += nloss.detach().cpu(); 82 | top1 += prec1 83 | counter += 1; 84 | index += stepsize; 85 | 86 | nloss.backward(); 87 | self.__optimizer__.step(); 88 | self.scheduler.step() 89 | 90 | telapsed = time.time() - tstart 91 | 92 | if counter % print_interval == 1: 93 | print("Processing (%d) Loss %f ACC/T1 %2.3f%% - %.2f Hz "%(index, loss/counter, top1/counter, stepsize/telapsed)) 94 | 95 | 
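        # Return the loss and top-1 accuracy averaged over the number of batches seen this epoch.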
return (loss/counter, top1/counter); 96 | 97 | ## ===== ===== ===== ===== ===== ===== ===== ===== 98 | ## Read data from list 99 | ## ===== ===== ===== ===== ===== ===== ===== ===== 100 | 101 | def readDataFromList(self, listfilename): 102 | 103 | data_list = {}; 104 | 105 | with open(listfilename) as listfile: 106 | while True: 107 | line = listfile.readline(); 108 | if not line: 109 | break; 110 | 111 | data = line.split(); 112 | filename = data[1]; 113 | speaker_name = data[0] 114 | 115 | if not (speaker_name in data_list): 116 | data_list[speaker_name] = []; 117 | data_list[speaker_name].append(filename); 118 | 119 | return data_list 120 | 121 | 122 | ## ===== ===== ===== ===== ===== ===== ===== ===== 123 | ## Evaluate from list 124 | ## ===== ===== ===== ===== ===== ===== ===== ===== 125 | 126 | def evaluateFromListSave(self, listfilename, print_interval=5000, feat_dir='', test_path='', num_eval=10): 127 | 128 | self.eval(); 129 | 130 | lines = [] 131 | files = [] 132 | filedict = {} 133 | feats = {} 134 | tstart = time.time() 135 | 136 | if feat_dir != '': 137 | print('Saving temporary files to %s'%feat_dir) 138 | if not(os.path.exists(feat_dir)): 139 | os.makedirs(feat_dir) 140 | 141 | ## Read all lines 142 | with open(listfilename) as listfile: 143 | while True: 144 | line = listfile.readline(); 145 | if (not line): # or (len(all_scores)==1000) 146 | break; 147 | 148 | data = line.split(); 149 | 150 | files.append(data[1]) 151 | files.append(data[2]) 152 | lines.append(line) 153 | 154 | setfiles = list(set(files)) 155 | setfiles.sort() 156 | 157 | ## Save all features to file 158 | for idx, file in enumerate(setfiles): 159 | 160 | inp1 = torch.FloatTensor(loadWAV(os.path.join(test_path,file), self.__max_frames__, evalmode=True, num_eval=num_eval)).cuda() 161 | 162 | with torch.no_grad(): 163 | ref_feat = self.__S__.forward(inp1).detach().cpu() 164 | 165 | filename = '%06d.wav'%idx 166 | 167 | if feat_dir == '': 168 | feats[file] = ref_feat 169 | else: 170 | filedict[file] = filename 171 | torch.save(ref_feat,os.path.join(feat_dir,filename)) 172 | 173 | telapsed = time.time() - tstart 174 | 175 | if idx % print_interval == 0: 176 | print("Reading %d: %.2f Hz, embed size %d"%(idx,idx/telapsed,ref_feat.size()[1])) 177 | 178 | print('') 179 | all_scores = []; 180 | all_labels = []; 181 | tstart = time.time() 182 | 183 | ## Read files and compute all scores 184 | for idx, line in enumerate(lines): 185 | 186 | data = line.split(); 187 | 188 | if feat_dir == '': 189 | ref_feat = feats[data[1]].cuda() 190 | com_feat = feats[data[2]].cuda() 191 | else: 192 | ref_feat = torch.load(os.path.join(feat_dir,filedict[data[1]])).cuda() 193 | com_feat = torch.load(os.path.join(feat_dir,filedict[data[2]])).cuda() 194 | 195 | if self.__test_normalize__: 196 | ref_feat = F.normalize(ref_feat, p=2, dim=1) 197 | com_feat = F.normalize(com_feat, p=2, dim=1) 198 | 199 | dist = F.pairwise_distance(ref_feat.unsqueeze(-1).expand(-1,-1,num_eval), com_feat.unsqueeze(-1).expand(-1,-1,num_eval).transpose(0,2)).detach().cpu().numpy(); 200 | 201 | score = -1 * numpy.mean(dist); 202 | 203 | all_scores.append(score); 204 | all_labels.append(int(data[0])); 205 | 206 | if idx % print_interval == 0: 207 | telapsed = time.time() - tstart 208 | print("Computing %d: %.2f Hz"%(idx,idx/telapsed)) 209 | 210 | if feat_dir != '': 211 | print(' Deleting temporary files.') 212 | shutil.rmtree(feat_dir) 213 | 214 | print('\n') 215 | 216 | return (all_scores, all_labels); 217 | 218 | 219 | ## ===== ===== ===== ===== ===== ===== 
===== ===== 220 | ## Update learning rate 221 | ## ===== ===== ===== ===== ===== ===== ===== ===== 222 | 223 | def make_scheduler(self, loader, max_epoch, lr, min_lr = 0.000005): 224 | total_steps = max_epoch * len(loader) 225 | 226 | self.scheduler = torch.optim.lr_scheduler.LambdaLR( 227 | self.__optimizer__, 228 | lr_lambda=lambda step: cosine_annealing( 229 | step, 230 | total_steps, 231 | 1, 232 | min_lr / lr)) 233 | 234 | 235 | ## ===== ===== ===== ===== ===== ===== ===== ===== 236 | ## Save parameters 237 | ## ===== ===== ===== ===== ===== ===== ===== ===== 238 | 239 | def saveParameters(self, path): 240 | 241 | torch.save(self.state_dict(), path + '/pretrained.pt'); 242 | 243 | 244 | ## ===== ===== ===== ===== ===== ===== ===== ===== 245 | ## Load parameters 246 | ## ===== ===== ===== ===== ===== ===== ===== ===== 247 | 248 | def loadParameters(self, path): 249 | 250 | self_state = self.state_dict(); 251 | loaded_state = torch.load(path); 252 | for name, param in loaded_state.items(): 253 | origname = name; 254 | if name not in self_state: 255 | name = name.replace("module.", ""); 256 | 257 | if name not in self_state: 258 | print("%s is not in the model."%origname); 259 | continue; 260 | 261 | if self_state[name].size() != loaded_state[origname].size(): 262 | print("Wrong parameter length: %s, model: %s, loaded: %s"%(origname, self_state[name].size(), loaded_state[origname].size())); 263 | continue; 264 | 265 | self_state[name].copy_(param); 266 | 267 | 268 | ## ===== ===== ===== ===== ===== ===== ===== ===== 269 | ## Load parameters 270 | ## ===== ===== ===== ===== ===== ===== ===== ===== 271 | 272 | def loadParametersModelEnsemble(self, paths): 273 | 274 | self_state = self.state_dict(); 275 | 276 | states = [] 277 | for path in paths: 278 | states.append(torch.load(path)); 279 | 280 | loaded_state = states[0] 281 | 282 | for state in states[1:]: 283 | for name, param in state.items(): 284 | loaded_state[name] = loaded_state[name] + param 285 | 286 | for name, param in state.items(): 287 | loaded_state[name] = loaded_state[name] / len(paths) 288 | 289 | for name, param in loaded_state.items(): 290 | origname = name; 291 | if name not in self_state: 292 | name = name.replace("module.", ""); 293 | 294 | if name not in self_state: 295 | print("%s is not in the model."%origname); 296 | continue; 297 | 298 | if self_state[name].size() != loaded_state[origname].size(): 299 | print("Wrong parameter length: %s, model: %s, loaded: %s"%(origname, self_state[name].size(), loaded_state[origname].size())); 300 | continue; 301 | 302 | self_state[name].copy_(param); --------------------------------------------------------------------------------