├── requirements.txt
├── .gitattributes
├── .gitignore
├── train.sh
├── ge2e
│   ├── conf.py
│   ├── utils.py
│   ├── dataset.py
│   ├── profile_loss.py
│   ├── nnet.py
│   ├── train_ge2e.py
│   ├── compute_dvector.py
│   └── trainer.py
└── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.0.0
2 | numpy==1.14.3
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | * text=auto
2 |
3 | *.py text eol=lf
4 | *.sh text eol=lf
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | __pycache__/
3 | ge2e/libs/
4 | .*.py
5 | *.yaml
6 | egs.py
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # wujian@2018
3 |
4 | set -eu
5 |
6 | train_steps=2500
7 | dev_steps=800
8 | chunk_size="140,180"
9 | cpt_dir=exp/ge2e
10 | epochs=50
11 |
12 | echo "$0 $@"
13 |
14 | [ $# -ne 2 ] && echo "Script format error: $0 <gpu-id> <checkpoint-id>" && exit 1
15 |
16 | ./ge2e/train_ge2e.py \
17 |   --M 10 \
18 |   --N 64 \
19 |   --gpu $1 \
20 |   --epochs $epochs \
21 |   --train-steps $train_steps \
22 |   --dev-steps $dev_steps \
23 |   --chunk-size $chunk_size \
24 |   --checkpoint $cpt_dir/$2 \
25 |   > $2.train.log 2>&1
--------------------------------------------------------------------------------
/ge2e/conf.py:
--------------------------------------------------------------------------------
1 | # nnet opts
2 |
3 | lstm_conf = {"num_layers": 3, "hidden_size": 738, "dropout": 0.2}
4 |
5 | nnet_conf = {"feature_dim": 40, "embedding_dim": 256, "lstm_conf": lstm_conf}
6 |
7 | # trainer opts
8 | opt_kwargs = {"lr": 1e-2, "weight_decay": 1e-5, "momentum": 0.8}
9 |
10 | trainer_conf = {
11 |     "optimizer": "sgd",
12 |     "optimizer_kwargs": opt_kwargs,
13 |     "clip_norm": 10,
14 |     "min_lr": 1e-8,
15 |     "patience": 2,
16 |     "factor": 0.5,
17 |     "no_impr": 6,
18 |     "logging_period": 200  # steps
19 | }
20 |
21 | train_dir = "data/train"
22 | dev_dir = "data/dev"
--------------------------------------------------------------------------------
/ge2e/utils.py:
--------------------------------------------------------------------------------
1 | # wujian@2018
2 |
3 | import os
4 | import json
5 |
6 | import os.path as op
7 |
8 |
9 | def dump_json(obj, fdir, name):
10 |     """
11 |     Dump python object to json
12 |     """
13 |     if fdir and not op.exists(fdir):
14 |         os.makedirs(fdir)
15 |     with open(op.join(fdir, name), "w") as f:
16 |         json.dump(obj, f, indent=4, sort_keys=False)
17 |
18 |
19 | def load_json(fdir, name):
20 |     """
21 |     Load json as python object
22 |     """
23 |     path = op.join(fdir, name)
24 |     if not op.exists(path):
25 |         raise FileNotFoundError("Could not find json file: {}".format(path))
26 |     with open(path, "r") as f:
27 |         obj = json.load(f)
28 |     return obj
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Speaker Verification with GE2E Loss
2 |
3 | PyTorch implementation of "Generalized End-to-End Loss for Speaker Verification"
4 |
5 | ### Data Processing
6 |
7 | 1. VAD (recommend [py-webrtcvad](https://github.com/wiseman/py-webrtcvad))
8 | 2. 
Log mel-spectrogram features (recommend [librosa](https://github.com/librosa/librosa)) 9 | 3. Prepare data as `data/{train,dev}/{feats.scp,spk2utt}` 10 | 11 | ### Usage 12 | 13 | see [train.sh](train.sh) and [compute_dvector.py](ge2e/compute_dvector.py) 14 | 15 | ### Reference 16 | 17 | Wan L, Wang Q, Papir A, et al. Generalized end-to-end loss for speaker verification[C]//2018 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018: 4879-4883. -------------------------------------------------------------------------------- /ge2e/dataset.py: -------------------------------------------------------------------------------- 1 | # wujian@2018 2 | 3 | import random 4 | import os.path as op 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from kaldi_python_io import Reader, ScriptReader 10 | 11 | 12 | class SpeakerSampler(object): 13 | """ 14 | Remember to filter speakers which utterance number lower than M 15 | """ 16 | 17 | def __init__(self, data_dir): 18 | depends = [op.join(data_dir, x) for x in ["feats.scp", "spk2utt"]] 19 | for depend in depends: 20 | if not op.exists(depend): 21 | raise RuntimeError("Missing {}!".format(depend)) 22 | self.reader = ScriptReader(depends[0]) 23 | self.spk2utt = Reader(depends[1], num_tokens=-1) 24 | 25 | def sample(self, N=64, M=10, chunk_size=(140, 180)): 26 | """ 27 | N: number of spks 28 | M: number of utts 29 | """ 30 | spks = random.sample(self.spk2utt.index_keys, N) 31 | chunks = [] 32 | eg = dict() 33 | eg["N"] = N 34 | eg["M"] = M 35 | C = random.randint(*chunk_size) 36 | for spk in spks: 37 | utt_sets = self.spk2utt[spk] 38 | if len(utt_sets) < M: 39 | raise RuntimeError( 40 | "Speaker {} can not got enough utterance with M = {:d}". 41 | format(spk, M)) 42 | samp_utts = random.sample(utt_sets, M) 43 | for uttid in samp_utts: 44 | utt = self.reader[uttid] 45 | pad = C - utt.shape[0] 46 | if pad < 0: 47 | start = random.randint(0, -pad) 48 | chunks.append(utt[start:start + C]) 49 | else: 50 | chunk = np.pad(utt, ((pad, 0), (0, 0)), "edge") 51 | chunks.append(chunk) 52 | eg["feats"] = th.from_numpy(np.stack(chunks)) 53 | return eg 54 | 55 | 56 | class SpeakerLoader(object): 57 | def __init__(self, 58 | data_dir, 59 | N=64, 60 | M=10, 61 | chunk_size=(140, 180), 62 | num_steps=10000): 63 | self.sampler = SpeakerSampler(data_dir) 64 | self.N, self.M, self.C = N, M, chunk_size 65 | self.num_steps = num_steps 66 | 67 | def _sample(self): 68 | return self.sampler.sample(self.N, self.M, self.C) 69 | 70 | def __iter__(self): 71 | for _ in range(self.num_steps): 72 | yield self._sample() 73 | -------------------------------------------------------------------------------- /ge2e/profile_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # wujian@2018 3 | 4 | import time 5 | import torch as th 6 | import torch.nn.functional as F 7 | 8 | 9 | def ge2e_v1(e, N, M): 10 | """ 11 | e: N x M x D, after L2 norm 12 | N: number of spks 13 | M: number of utts 14 | """ 15 | # N x D 16 | c = th.mean(e, dim=1) 17 | s = th.sum(e, dim=1) 18 | # build similarity matrix 19 | dst = [] 20 | # jth speaker 21 | for j in range(N): 22 | # ith utterance 23 | for i in range(M): 24 | # kth ref speaker 25 | for k in range(N): 26 | if k == j: 27 | # fix centroid 28 | cj = (s[j] - e[j][i]) / (M - 1) 29 | dst.append(th.dot(e[j][i], cj)) 30 | else: 31 | dst.append(th.dot(e[j][i], c[k])) 32 | # N*M*N 33 | sim = th.stack(dst) 34 | # N*M x N 35 | sim = sim.view(-1, N) 36 | # build 
label N*M 37 | ref = th.zeros(N * M, dtype=th.int64, device=e.device) 38 | for r, s in enumerate(range(0, N * M, M)): 39 | ref[s:s + M] = r 40 | # ce loss 41 | loss = F.cross_entropy(sim, ref) 42 | return loss 43 | 44 | 45 | def ge2e_v2(e, N, M): 46 | """ 47 | e: N x M x D, after L2 norm 48 | N: number of spks 49 | M: number of utts 50 | """ 51 | # N x D 52 | c = th.mean(e, dim=1) 53 | s = th.sum(e, dim=1) 54 | # build similarity matrix 55 | # NM * D 56 | e = e.view(N * M, -1) 57 | # NM * N 58 | sim = th.mm(e, th.transpose(c, 0, 1)) 59 | # fix similarity matrix 60 | for j in range(N): 61 | for i in range(M): 62 | cj = (s[j] - e[j*M + i]) / (M - 1) 63 | sim[j*M + i][j] = th.dot(cj, e[j*M + i]) 64 | # build label N*M 65 | ref = th.zeros(N * M, dtype=th.int64, device=e.device) 66 | for r, s in enumerate(range(0, N * M, M)): 67 | ref[s:s + M] = r 68 | # ce loss 69 | loss = F.cross_entropy(sim, ref) 70 | return loss 71 | 72 | 73 | def foo(): 74 | N, M, D = 64, 20, 64 75 | e = th.rand(N, M, D) 76 | e = e / th.norm(e, dim=-1, keepdim=True) 77 | s = time.time() 78 | loss = ge2e_v1(e, N, M) 79 | t = time.time() 80 | print(loss.data) 81 | print("cost: {:.2f}".format(t - s)) 82 | s = time.time() 83 | loss = ge2e_v2(e, N, M) 84 | t = time.time() 85 | print(loss.data) 86 | print("cost: {:.2f}".format(t - s)) 87 | 88 | if __name__ == "__main__": 89 | foo() -------------------------------------------------------------------------------- /ge2e/nnet.py: -------------------------------------------------------------------------------- 1 | # wujian@2018 2 | 3 | import math 4 | import torch as th 5 | import torch.nn as nn 6 | 7 | import torch.nn.functional as F 8 | from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence 9 | 10 | 11 | class TorchRNN(nn.Module): 12 | def __init__(self, 13 | feature_dim, 14 | rnn="lstm", 15 | num_layers=2, 16 | hidden_size=512, 17 | dropout=0.0, 18 | bidirectional=False): 19 | super(TorchRNN, self).__init__() 20 | RNN = rnn.upper() 21 | supported_rnn = {"LSTM": nn.LSTM, "RNN": nn.RNN, "GRU": nn.GRU} 22 | if RNN not in supported_rnn: 23 | raise RuntimeError("unknown RNN type: {}".format(RNN)) 24 | self.rnn = supported_rnn[RNN]( 25 | feature_dim, 26 | hidden_size, 27 | num_layers, 28 | batch_first=True, 29 | dropout=dropout, 30 | bidirectional=bidirectional) 31 | self.output_dim = hidden_size if not bidirectional else hidden_size * 2 32 | 33 | def forward(self, x, squeeze=False, total_length=None): 34 | """ 35 | Accept tensor([N]xTxF) or PackedSequence Object 36 | """ 37 | is_packed = isinstance(x, PackedSequence) 38 | # extend dim when inference 39 | if not is_packed: 40 | if x.dim() not in [2, 3]: 41 | raise RuntimeError( 42 | "RNN expect input dim as 2 or 3, got {:d}".format(x.dim())) 43 | if x.dim() != 3: 44 | x = th.unsqueeze(x, 0) 45 | x, _ = self.rnn(x) 46 | # using unpacked sequence 47 | # x: NxTxD 48 | if is_packed: 49 | x, _ = pad_packed_sequence( 50 | x, batch_first=True, total_length=total_length) 51 | if squeeze: 52 | x = th.squeeze(x) 53 | return x 54 | 55 | 56 | class Nnet(nn.Module): 57 | def __init__(self, feature_dim=40, embedding_dim=256, lstm_conf=None): 58 | super(Nnet, self).__init__() 59 | self.encoder = TorchRNN(feature_dim, **lstm_conf) 60 | self.linear = nn.Linear(self.encoder.output_dim, embedding_dim) 61 | 62 | def forward(self, x): 63 | x = self.encoder(x) 64 | if x.dim() == 3: 65 | x = self.linear(x[:, -1, :]) 66 | else: 67 | x = self.linear(x[-1, :]) 68 | return x / th.norm(x, dim=-1, keepdim=True) 69 | 70 | 71 | def foo_lstm(): 72 | 
lstm_conf = {"num_layers": 3, "hidden_size": 738, "dropout": 0.5} 73 | nnet_conf = { 74 | "feature_dim": 40, 75 | "embedding_dim": 256, 76 | "lstm_conf": lstm_conf 77 | } 78 | nnet = Nnet(**nnet_conf) 79 | x = th.rand(100, 40) 80 | x = nnet(x) 81 | print(x.shape) 82 | 83 | 84 | if __name__ == "__main__": 85 | foo_lstm() -------------------------------------------------------------------------------- /ge2e/train_ge2e.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # wujian@2018 4 | 5 | import os 6 | import pprint 7 | import argparse 8 | import random 9 | 10 | from nnet import Nnet 11 | from trainer import GE2ETrainer, get_logger 12 | from dataset import SpeakerLoader 13 | from utils import dump_json 14 | from conf import nnet_conf, trainer_conf, train_dir, dev_dir 15 | 16 | logger = get_logger(__name__) 17 | 18 | 19 | def run(args): 20 | parse_str = lambda s: tuple(map(int, s.split(","))) 21 | nnet = Nnet(**nnet_conf) 22 | 23 | trainer = GE2ETrainer( 24 | nnet, 25 | gpuid=parse_str(args.gpu), 26 | checkpoint=args.checkpoint, 27 | resume=args.resume, 28 | **trainer_conf) 29 | 30 | loader_conf = { 31 | "M": args.M, 32 | "N": args.N, 33 | "chunk_size": parse_str(args.chunk_size) 34 | } 35 | for conf, fname in zip([nnet_conf, trainer_conf, loader_conf], 36 | ["mdl.json", "trainer.json", "loader.json"]): 37 | dump_json(conf, args.checkpoint, fname) 38 | 39 | train_loader = SpeakerLoader( 40 | train_dir, **loader_conf, num_steps=args.train_steps) 41 | dev_loader = SpeakerLoader( 42 | dev_dir, **loader_conf, num_steps=args.dev_steps) 43 | 44 | trainer.run(train_loader, dev_loader, num_epochs=args.epochs) 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser( 49 | description="Command to train speaker embedding model using GE2E loss, " 50 | "auto configured from conf.py", 51 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 52 | parser.add_argument( 53 | "--gpu", type=str, default=0, help="Training on which GPUs") 54 | parser.add_argument( 55 | "--epochs", type=int, default=50, help="Number of training epochs") 56 | parser.add_argument( 57 | "--checkpoint", 58 | type=str, 59 | required=True, 60 | help="Directory to dump models") 61 | parser.add_argument( 62 | "--N", type=int, default=64, help="Number of speakers in each batch") 63 | parser.add_argument( 64 | "--M", 65 | type=int, 66 | default=10, 67 | help="Number of utterances for each speaker") 68 | parser.add_argument( 69 | "--train-steps", 70 | type=int, 71 | default=5000, 72 | help="Number of training steps in one epoch") 73 | parser.add_argument( 74 | "--dev-steps", 75 | type=int, 76 | default=800, 77 | help="Number of validation steps in one epoch") 78 | parser.add_argument( 79 | "--chunk-size", 80 | type=str, 81 | default="140,180", 82 | help="Range of chunk size, eg: 140,180") 83 | parser.add_argument( 84 | "--resume", 85 | type=str, 86 | default="", 87 | help="Checkpoint to resume training process") 88 | args = parser.parse_args() 89 | logger.info("Arguments in command:\n{}".format(pprint.pformat(vars(args)))) 90 | run(args) 91 | -------------------------------------------------------------------------------- /ge2e/compute_dvector.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # wujian@2018 4 | 5 | import os 6 | import argparse 7 | 8 | import torch as th 9 | import numpy as np 10 | 11 | from utils import load_json 12 | from nnet import Nnet 13 | from trainer import 
get_logger 14 | 15 | from kaldi_python_io import Reader, ScriptReader 16 | 17 | logger = get_logger(__name__) 18 | 19 | 20 | class NnetComputer(object): 21 | """ 22 | Compute output of networks 23 | """ 24 | 25 | def __init__(self, cpt_dir, gpuid): 26 | # chunk size when inference 27 | loader_conf = load_json(cpt_dir, "loader.json") 28 | self.chunk_size = sum(loader_conf["chunk_size"]) // 2 29 | logger.info("Using chunk size {:d}".format(self.chunk_size)) 30 | # GPU or CPU 31 | self.device = "cuda:{}".format(gpuid) if gpuid >= 0 else "cpu" 32 | # load nnet 33 | nnet = self._load_nnet(cpt_dir) 34 | self.nnet = nnet.to(self.device) 35 | 36 | def _load_nnet(self, cpt_dir): 37 | # nnet config 38 | nnet_conf = load_json(cpt_dir, "mdl.json") 39 | nnet = Nnet(**nnet_conf) 40 | cpt_fname = os.path.join(cpt_dir, "best.pt.tar") 41 | cpt = th.load(cpt_fname, map_location="cpu") 42 | nnet.load_state_dict(cpt["model_state_dict"]) 43 | logger.info("Load checkpoint from {}, epoch {:d}".format( 44 | cpt_fname, cpt["epoch"])) 45 | nnet.eval() 46 | return nnet 47 | 48 | def _make_chunk(self, feats): 49 | T, F = feats.shape 50 | # step: half chunk 51 | S = self.chunk_size // 2 52 | N = (T - self.chunk_size) // S + 1 53 | if N <= 0: 54 | return feats 55 | elif N == 1: 56 | return feats[:self.chunk_size] 57 | else: 58 | chunks = th.zeros([N, self.chunk_size, F], 59 | device=feats.device, 60 | dtype=feats.dtype) 61 | for n in range(N): 62 | chunks[n] = feats[n * S:n * S + self.chunk_size] 63 | return chunks 64 | 65 | def compute(self, feats): 66 | feats = th.tensor(feats, device=self.device) 67 | with th.no_grad(): 68 | chunks = self._make_chunk(feats) # N x C x F 69 | dvector = self.nnet(chunks) # N x D 70 | dvector = th.mean(dvector, dim=0).detach() 71 | return dvector.cpu().numpy() 72 | 73 | 74 | def run(args): 75 | feats_reader = ScriptReader(args.feats) 76 | computer = NnetComputer(args.checkpoint, args.gpu) 77 | if not os.path.exists(args.dump_dir): 78 | os.makedirs(args.dump_dir) 79 | for key, feats in feats_reader: 80 | logger.info("Compute dvector on utterance {}...".format(key)) 81 | dvector = computer.compute(feats) 82 | np.save(os.path.join(args.dump_dir, key), dvector) 83 | logger.info("Compute over {:d} utterances".format(len(feats_reader))) 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser( 88 | description="Command to compute dvector from SpeakerNet", 89 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 90 | parser.add_argument("checkpoint", type=str, help="Directory of checkpoint") 91 | parser.add_argument( 92 | "--feats", 93 | type=str, 94 | required=True, 95 | help="Rspecifier for input features") 96 | parser.add_argument( 97 | "--gpu", 98 | type=int, 99 | default=-1, 100 | help="GPU-id to offload model to, -1 means running on CPU") 101 | parser.add_argument( 102 | "--dump-dir", 103 | type=str, 104 | default="dvector", 105 | help="Directory to dump dvector out") 106 | args = parser.parse_args() 107 | run(args) -------------------------------------------------------------------------------- /ge2e/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # wujian@2018 4 | 5 | import os 6 | import sys 7 | import time 8 | import logging 9 | 10 | from collections import defaultdict 11 | 12 | import torch as th 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from torch.optim.lr_scheduler import ReduceLROnPlateau 17 | from torch.nn.utils import clip_grad_norm_ 18 | 19 
| 20 | def get_logger( 21 | name, 22 | format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s", 23 | date_format="%Y-%m-%d %H:%M:%S", 24 | file=False): 25 | """ 26 | Get python logger instance 27 | """ 28 | logger = logging.getLogger(name) 29 | logger.setLevel(logging.INFO) 30 | # file or console 31 | handler = logging.StreamHandler() if not file else logging.FileHandler( 32 | name) 33 | handler.setLevel(logging.INFO) 34 | formatter = logging.Formatter(fmt=format_str, datefmt=date_format) 35 | handler.setFormatter(formatter) 36 | logger.addHandler(handler) 37 | return logger 38 | 39 | 40 | def load_obj(obj, device): 41 | """ 42 | Offload tensor object in obj to cuda device 43 | """ 44 | 45 | def cuda(obj): 46 | return obj.to(device) if isinstance(obj, th.Tensor) else obj 47 | 48 | if isinstance(obj, dict): 49 | return {key: load_obj(obj[key], device) for key in obj} 50 | elif isinstance(obj, list): 51 | return [load_obj(val, device) for val in obj] 52 | else: 53 | return cuda(obj) 54 | 55 | 56 | class SimpleTimer(object): 57 | """ 58 | A simple timer 59 | """ 60 | 61 | def __init__(self): 62 | self.reset() 63 | 64 | def reset(self): 65 | self.start = time.time() 66 | 67 | def elapsed(self): 68 | return (time.time() - self.start) / 60 69 | 70 | 71 | class Reporter(object): 72 | """ 73 | A simple progress reporter 74 | """ 75 | 76 | def __init__(self, logger, period=100): 77 | self.period = period 78 | self.logger = logger 79 | self.loss = [] 80 | self.timer = SimpleTimer() 81 | 82 | def add(self, loss): 83 | self.loss.append(loss) 84 | N = len(self.loss) 85 | if not N % self.period: 86 | avg = sum(self.loss[-self.period:]) / self.period 87 | self.logger.info("Processed {:d} batches" 88 | "(loss = {:+.2f})...".format(N, avg)) 89 | 90 | def report(self, details=False): 91 | N = len(self.loss) 92 | if details: 93 | sstr = ",".join(map(lambda f: "{:.2f}".format(f), self.loss)) 94 | self.logger.info("Loss on {:d} batches: {}".format(N, sstr)) 95 | return { 96 | "loss": sum(self.loss) / N, 97 | "batches": N, 98 | "cost": self.timer.elapsed() 99 | } 100 | 101 | 102 | class GE2ELoss(nn.Module): 103 | def __init__(self): 104 | super(GE2ELoss, self).__init__() 105 | self.w = nn.Parameter(th.tensor(10.0)) 106 | self.b = nn.Parameter(th.tensor(-5.0)) 107 | 108 | def forward(self, e, N, M): 109 | """ 110 | e: N x M x D, after L2 norm 111 | N: number of spks 112 | M: number of utts 113 | """ 114 | # N x D 115 | c = th.mean(e, dim=1) 116 | s = th.sum(e, dim=1) 117 | # NM * D 118 | e = e.view(N * M, -1) 119 | # compute similarity matrix: NM * N 120 | sim = th.mm(e, th.transpose(c, 0, 1)) 121 | # fix similarity matrix: eq (8), (9) 122 | for j in range(N): 123 | for i in range(M): 124 | cj = (s[j] - e[j * M + i]) / (M - 1) 125 | sim[j * M + i][j] = th.dot(cj, e[j * M + i]) 126 | # eq (5) 127 | sim = self.w * sim + self.b 128 | # build label N*M 129 | ref = th.zeros(N * M, dtype=th.int64, device=e.device) 130 | for r, s in enumerate(range(0, N * M, M)): 131 | ref[s:s + M] = r 132 | # ce loss 133 | loss = F.cross_entropy(sim, ref) 134 | return loss 135 | 136 | 137 | class GE2ETrainer(object): 138 | """ 139 | Train speaker embedding model using GE2E loss 140 | """ 141 | 142 | def __init__(self, 143 | nnet, 144 | checkpoint="checkpoint", 145 | optimizer="sgd", 146 | gpuid=None, 147 | optimizer_kwargs=None, 148 | clip_norm=None, 149 | min_lr=0, 150 | patience=0, 151 | factor=0.5, 152 | logging_period=1000, 153 | resume=None, 154 | no_impr=6): 155 | if not th.cuda.is_available(): 156 | 
raise RuntimeError("CUDA device unavailable...exist") 157 | if not isinstance(gpuid, tuple): 158 | gpuid = (gpuid, ) 159 | self.device = th.device("cuda:{}".format(gpuid[0])) 160 | self.gpuid = gpuid 161 | if checkpoint and not os.path.exists(checkpoint): 162 | os.makedirs(checkpoint) 163 | self.checkpoint = checkpoint 164 | self.logger = get_logger( 165 | os.path.join(checkpoint, "trainer.log"), file=True) 166 | 167 | self.clip_norm = clip_norm 168 | self.logging_period = logging_period 169 | self.cur_epoch = 0 # zero based 170 | self.no_impr = no_impr 171 | 172 | if resume: 173 | if not os.path.exists(resume): 174 | raise FileNotFoundError( 175 | "Could not find resume checkpoint: {}".format(resume)) 176 | cpt = th.load(resume, map_location="cpu") 177 | self.cur_epoch = cpt["epoch"] 178 | self.logger.info("Resume from checkpoint {}: epoch {:d}".format( 179 | resume, self.cur_epoch)) 180 | # load nnet 181 | nnet.load_state_dict(cpt["model_state_dict"]) 182 | self.nnet = nnet.to(self.device) 183 | # load ge2e 184 | ge2e_loss = GE2ELoss() 185 | ge2e_loss.load_state_dict(cpt["ge2e_state_dict"]) 186 | self.ge2e = ge2e_loss.to(self.device) 187 | self.optimizer = self.create_optimizer( 188 | optimizer, optimizer_kwargs, state=cpt["optim_state_dict"]) 189 | else: 190 | self.nnet = nnet.to(self.device) 191 | ge2e_loss = GE2ELoss() 192 | self.ge2e = ge2e_loss.to(self.device) 193 | self.optimizer = self.create_optimizer(optimizer, optimizer_kwargs) 194 | self.scheduler = ReduceLROnPlateau( 195 | self.optimizer, 196 | mode="min", 197 | factor=factor, 198 | patience=patience, 199 | min_lr=min_lr, 200 | verbose=True) 201 | self.num_params = sum( 202 | [param.nelement() for param in nnet.parameters()]) / 10.0**6 203 | 204 | # logging 205 | self.logger.info("Model summary:\n{}".format(nnet)) 206 | self.logger.info("Loading model to GPUs:{}, #param: {:.2f}M".format( 207 | gpuid, self.num_params)) 208 | if clip_norm: 209 | self.logger.info( 210 | "Gradient clipping by {}, default L2".format(clip_norm)) 211 | 212 | def save_checkpoint(self, best=True): 213 | cpt = { 214 | "epoch": self.cur_epoch, 215 | "model_state_dict": self.nnet.state_dict(), 216 | "optim_state_dict": self.optimizer.state_dict(), 217 | "ge2e_state_dict": self.ge2e.state_dict() 218 | } 219 | th.save( 220 | cpt, 221 | os.path.join(self.checkpoint, 222 | "{0}.pt.tar".format("best" if best else "last"))) 223 | 224 | def create_optimizer(self, optimizer, kwargs, state=None): 225 | supported_optimizer = { 226 | "sgd": th.optim.SGD, # momentum, weight_decay, lr 227 | "rmsprop": th.optim.RMSprop, # momentum, weight_decay, lr 228 | "adam": th.optim.Adam, # weight_decay, lr 229 | "adadelta": th.optim.Adadelta, # weight_decay, lr 230 | "adagrad": th.optim.Adagrad, # lr, lr_decay, weight_decay 231 | "adamax": th.optim.Adamax # lr, weight_decay 232 | # ... 
233 | } 234 | if optimizer not in supported_optimizer: 235 | raise ValueError("Now only support optimizer {}".format(optimizer)) 236 | params = [{ 237 | "params": self.nnet.parameters() 238 | }, { 239 | "params": self.ge2e.parameters() 240 | }] 241 | opt = supported_optimizer[optimizer](params, **kwargs) 242 | self.logger.info("Create optimizer {0}: {1}".format(optimizer, kwargs)) 243 | if state is not None: 244 | opt.load_state_dict(state) 245 | self.logger.info("Load optimizer state dict from checkpoint") 246 | return opt 247 | 248 | def compute_loss(self, egs): 249 | """ 250 | Compute ge2e loss 251 | """ 252 | N, M = egs["N"], egs["M"] 253 | # NM x D 254 | embed = th.nn.parallel.data_parallel( 255 | self.nnet, egs["feats"], device_ids=self.gpuid) 256 | if embed.size(0) != N * M: 257 | raise RuntimeError( 258 | "Seems something wrong with egs, dimention check failed({:d} vs {:d})" 259 | .format(embed.size(0), M * N)) 260 | embed = embed.view(N, M, -1) 261 | loss = self.ge2e(embed, N, M) 262 | return loss 263 | 264 | def train(self, data_loader): 265 | self.logger.info("Set train mode...") 266 | self.nnet.train() 267 | reporter = Reporter(self.logger, period=self.logging_period) 268 | 269 | for egs in data_loader: 270 | # load to gpu 271 | egs = load_obj(egs, self.device) 272 | 273 | self.optimizer.zero_grad() 274 | loss = self.compute_loss(egs) 275 | loss.backward() 276 | if self.clip_norm: 277 | clip_grad_norm_(self.nnet.parameters(), self.clip_norm) 278 | self.optimizer.step() 279 | 280 | reporter.add(loss.item()) 281 | return reporter.report() 282 | 283 | def eval(self, data_loader): 284 | self.logger.info("Set eval mode...") 285 | self.nnet.eval() 286 | reporter = Reporter(self.logger, period=self.logging_period) 287 | 288 | with th.no_grad(): 289 | for egs in data_loader: 290 | egs = load_obj(egs, self.device) 291 | loss = self.compute_loss(egs) 292 | reporter.add(loss.item()) 293 | return reporter.report(details=True) 294 | 295 | def run(self, train_loader, dev_loader, num_epochs=50): 296 | # avoid alloc memory from gpu0 297 | with th.cuda.device(self.gpuid[0]): 298 | stats = dict() 299 | # check if save is OK 300 | self.save_checkpoint(best=False) 301 | cv = self.eval(dev_loader) 302 | best_loss = cv["loss"] 303 | self.logger.info("START FROM EPOCH {:d}, LOSS = {:.4f}".format( 304 | self.cur_epoch, best_loss)) 305 | no_impr = 0 306 | # make sure not inf 307 | self.scheduler.best = best_loss 308 | while self.cur_epoch < num_epochs: 309 | self.cur_epoch += 1 310 | cur_lr = self.optimizer.param_groups[0]["lr"] 311 | stats[ 312 | "title"] = "Loss(time/N, lr={:.3e}) - Epoch {:2d}:".format( 313 | cur_lr, self.cur_epoch) 314 | tr = self.train(train_loader) 315 | stats["tr"] = "train = {:+.4f}({:.2f}m/{:d})".format( 316 | tr["loss"], tr["cost"], tr["batches"]) 317 | cv = self.eval(dev_loader) 318 | stats["cv"] = "dev = {:+.4f}({:.2f}m/{:d})".format( 319 | cv["loss"], cv["cost"], cv["batches"]) 320 | stats["scheduler"] = "" 321 | if cv["loss"] > best_loss: 322 | no_impr += 1 323 | stats["scheduler"] = "| no impr, best = {:.4f}".format( 324 | self.scheduler.best) 325 | else: 326 | best_loss = cv["loss"] 327 | no_impr = 0 328 | self.save_checkpoint(best=True) 329 | self.logger.info( 330 | "{title} {tr} | {cv} {scheduler}".format(**stats)) 331 | # schedule here 332 | self.scheduler.step(cv["loss"]) 333 | # flush scheduler info 334 | sys.stdout.flush() 335 | # save last checkpoint 336 | self.save_checkpoint(best=False) 337 | if no_impr == self.no_impr: 338 | self.logger.info( 339 | "Stop 
training: no improvement for {:d} epochs".format( 340 | no_impr)) 341 | break 342 | self.logger.info("Training for {:d}/{:d} epochs done!".format( 343 | self.cur_epoch, num_epochs)) 344 | --------------------------------------------------------------------------------
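The double Python loop in `GE2ELoss.forward` (and in `ge2e_v1`/`ge2e_v2` from `ge2e/profile_loss.py`) rebuilds the leave-one-out centroid of eq. (8) one entry at a time. Below is a rough sketch, not part of the repository, of how the same similarity matrix could be built with broadcasting; the name `ge2e_vectorized` and the fixed `w`/`b` scalars are illustrative stand-ins for the learnable parameters in `trainer.GE2ELoss`, and the result should be checked against `profile_loss.py` before relying on it.

```python
# Sketch only (assumed helper, not a file in this repo): loop-free version of
# the centroid fix performed in GE2ELoss.forward / ge2e_v2.
import torch as th
import torch.nn.functional as F


def ge2e_vectorized(e, N, M, w=10.0, b=-5.0):
    """
    e: N x M x D embedding tensor, rows already L2-normalized
    """
    c = th.mean(e, dim=1)                      # N x D, full centroids
    s = th.sum(e, dim=1)                       # N x D
    e_flat = e.view(N * M, -1)                 # NM x D
    # similarity against full centroids: NM x N
    sim = th.mm(e_flat, th.transpose(c, 0, 1))
    # eq (8): centroid of speaker j with utterance i left out
    c_excl = (s.unsqueeze(1) - e) / (M - 1)    # N x M x D
    # eq (9): overwrite each row's "own speaker" column with the
    # leave-one-out similarity
    own = th.sum(e * c_excl, dim=-1).view(-1)  # NM
    rows = th.arange(N * M, device=e.device)
    spk = th.arange(N, device=e.device).unsqueeze(1).expand(
        N, M).contiguous().view(-1)            # speaker label per row
    sim[rows, spk] = own
    # eq (5) + softmax loss, same target layout as the loop version
    return F.cross_entropy(w * sim + b, spk)


if __name__ == "__main__":
    N, M, D = 8, 4, 16
    e = th.rand(N, M, D)
    e = e / th.norm(e, dim=-1, keepdim=True)
    print(ge2e_vectorized(e, N, M))
```

Under these assumptions the value matches the loop-based `ge2e_v2` (with `w=1`, `b=0` there), since row `j*M + i` of `sim` corresponds to utterance `i` of speaker `j` in both layouts.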