├── .gitignore
├── requirements.txt
├── utils
│   ├── tools.py
│   ├── dataLoader.py
│   ├── loss.py
│   ├── ECAPAModel.py
│   └── model.py
├── create_data_index.py
├── README.md
├── inference.py
├── train.py
└── LICENSE
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.ipynb
exp_*/
*_sec*
submit.csv
data/
exps/
*.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==1.9.0
torchaudio==0.9.0
pydub==0.21.0
numba==0.48.0
numpy==1.15.4
pandas==0.23.3
scipy==1.2.1
scikit-learn==0.19.1
tqdm
SoundFile
librosa==0.6.2
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
'''
Utility functions, adapted from voxceleb_trainer:
https://github.com/clovaai/voxceleb_trainer/blob/master/tuneThreshold.py
'''

import os

def init_args(args):
    args.score_save_path = os.path.join(args.save_path, 'score.txt')
    args.model_save_path = os.path.join(args.save_path, 'model')
    os.makedirs(args.model_save_path, exist_ok=True)
    return args
--------------------------------------------------------------------------------
/create_data_index.py:
--------------------------------------------------------------------------------
import os
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split

SEED = 42

def main():
    # Map each dialect (the directory name under ./data/train) to a class index.
    label_dic = {
        '普通话': 0,
        '成都': 1,
        '郑州': 2,
        '武汉': 3,
        '广州': 4,
        '上海': 5,
        '杭州': 6,
        '厦门': 7,
        '长沙': 8,
    }
    ids = []
    paths = []
    labels = []
    for r, _, files in tqdm(os.walk('./data/train')):
        for fname in files:
            if fname.endswith('.wav'):
                # The parent directory name is the dialect label.
                label = label_dic[os.path.basename(r)]
                fid = fname.split('.')[0]
                fpath = os.path.join(r, fname)

                ids.append(fid)
                paths.append(fpath)
                labels.append(label)

    train_df = pd.DataFrame({
        'id': ids,
        'wav_path': paths,
        'label': labels
    })
    train_df = train_df.sample(frac=1, random_state=SEED)
    new_train_df, valid_df = train_test_split(train_df, test_size=0.2, random_state=SEED)
    new_train_df.to_csv('data/train_df', index=False, encoding='utf8')
    valid_df.to_csv('data/valid_df', index=False, encoding='utf8')
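    # Quick sanity check (illustrative, not required by the pipeline):
    # confirm the split sizes and the per-dialect class balance.
    print('train:', new_train_df.shape, 'valid:', valid_df.shape)
    print(new_train_df['label'].value_counts().sort_index())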

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Baseline for the 8th FinVolution Cup (信也科技杯)

数智创新,声至未来
Deep in Dialects, for Future Wave.

This is the baseline for the 8th FinVolution Cup.
Under the theme of "intelligent speech quality inspection for a better user experience", this competition explores using AI to identify and recover dialect information in speech data, and in particular the distance characteristics between different dialects. Working on this problem helps to better understand Chinese speech and its dialect and accent characteristics, and to carry the related techniques from theory into real applications, so as to serve users better.

## Environments
Implementing environment:
- python=3.6
- torch==1.9.0
- torchaudio==0.9.0
- pydub==0.21.0
- numba==0.48.0
- numpy==1.15.4
- pandas==0.23.3
- scipy==1.2.1
- scikit-learn==0.19.1
- tqdm
- SoundFile==0.12.1
- librosa==0.6.2

- GPU: Tesla V100 32G

## Dataset
The ./data directory contains the required test_pair data file.

test_pair holds the 1,000,000 data pairs needed for submission. Contestants must submit the corresponding 1,000,000 dialect distances, strictly in the same order as the samples inside test_pair.

Please download the audio data from the shared address:

Place the downloaded data files train.zip and test.zip in the project root, then run
```bash
unzip "*.zip" -d ./data/
python create_data_index.py
```
to unzip the files and generate the directory index. You may also build your own file index to suit your needs.
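inference.py reads the pair file with `pd.read_csv` and looks up its `id1`/`id2` columns, so a quick peek at it is a useful first step. A minimal sketch (assuming the default `data/test_pair` location):
```python
import pandas as pd

pair_df = pd.read_csv('data/test_pair')
print(pair_df.shape)                   # expect 1,000,000 rows
print(pair_df[['id1', 'id2']].head())  # the two utterance ids of each pair
```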
## Training

Train the ECAPA-TDNN encoder with the AAM-softmax classification loss:
```bash
python train.py --loss aamsoftmax --max_epoch 80 --device cuda:0 --save_path ./exps/
```

Train with the standard-similarity regression loss:
```bash
python train.py --loss StandardSimilarityLoss --max_epoch 80 --device cuda:0 --save_path ./exps_sim/
```

Train with the pairwise-distance loss:
```bash
python train.py --loss PairDistanceLoss --max_epoch 80 --device cuda:0 --save_path ./exps_pairdist/
```

## Inference
```bash
python inference.py --model_path exps/model/model_0001.model --test_path data/test_pair --device cuda:0
```
This generates the submit.csv file required for submission in the project root.
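Because the scorer matches distances to pairs purely by row order, it is worth verifying the output before submitting. A minimal sanity check (a sketch, assuming the default file locations):
```python
import pandas as pd

pair_df = pd.read_csv('data/test_pair')
submit = pd.read_csv('submit.csv')
assert len(submit) == len(pair_df), 'row count must match test_pair'
assert submit['distance'].notna().all(), 'every pair needs a distance'
```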
## Acknowledgements
- We borrowed a lot of code from [ECAPA-TDNN](https://github.com/TaoRuijie/ECAPA-TDNN) for the model implementation.

## Authors

The following authors contributed to this project (ordered by surname):
- Chen, Yifei
- Gao, Feng
- Kou, Kai
- Ni, Boyi
- Wang, Shaoming
- Zhang, Xuan
--------------------------------------------------------------------------------
/utils/dataLoader.py:
--------------------------------------------------------------------------------
'''
DataLoader for training
'''
import pandas as pd
import numpy, random, soundfile, torch

class train_loader(object):
    def __init__(self, train_list, num_frames, **kwargs):
        self.num_frames = num_frames
        # Load data & labels
        self.df_train = pd.read_csv(train_list)
        print("train_data_shape:", self.df_train.shape)
        self.data_list = self.df_train['wav_path'].tolist()
        self.data_label = self.df_train['label'].tolist()

    def __getitem__(self, index):
        # Read the utterance and randomly select a segment.
        audio, sr = soundfile.read(self.data_list[index])
        # One frame corresponds to 80 samples (10 ms at the 8 kHz sample rate).
        length = self.num_frames * 80
        # Data augmentation: if the utterance is too short, keep concatenating
        # randomly sampled utterances of the same dialect until it is long enough.
        while len(audio) <= length:
            df_tmp = self.df_train[self.df_train["label"] == self.data_label[index]]
            df_tmp = df_tmp.sample(n=1)
            audio_2, sr2 = soundfile.read(df_tmp["wav_path"].tolist()[0])
            audio = numpy.concatenate((audio, audio_2))

        start_frame = numpy.int64(random.random() * (audio.shape[0] - length))
        audio = audio[start_frame:start_frame + length]
        return torch.FloatTensor(audio), self.data_label[index]

    def __len__(self):
        return len(self.data_list)

class eval_loader(object):
    def __init__(self, eval_list, **kwargs):
        # Load data & labels
        df_eval = pd.read_csv(eval_list)
        print("valid_data_shape:", df_eval.shape)
        self.data_list = df_eval['wav_path'].tolist()
        self.data_label = df_eval['label'].tolist()

    def __getitem__(self, index):
        # Read the whole utterance.
        audio, sr = soundfile.read(self.data_list[index])
        return torch.FloatTensor(audio), self.data_label[index]

    def __len__(self):
        return len(self.data_list)

def my_collate_fn(data, max_length=12000):
    # Zero-pad every waveform in the batch to a common length, capped at max_length.
    lens = [x[0].shape[0] for x in data]
    max_len = max(lens)
    max_length = min(max_len, max_length)

    features = torch.zeros(len(data), max_length)
    for i, length in enumerate(lens):
        end = min(length, max_length)
        features[i, :end] = data[i][0][:end]
    labels = torch.tensor([x[1] for x in data])

    return features, labels

if __name__ == '__main__':
    eval_data = eval_loader('./data/valid_df')
    evalLoader = torch.utils.data.DataLoader(eval_data, batch_size=256, shuffle=True, num_workers=4, collate_fn=my_collate_fn)
    for data, label in evalLoader:
        print(data.shape)
        print(label.shape)
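
    # Illustrative check of my_collate_fn on synthetic waveforms (hypothetical
    # lengths; 12000 samples = 1.5 s at 8 kHz). The longer clip is truncated,
    # the shorter one is zero-padded, so both rows end up with 12000 samples.
    fake_batch = [(torch.randn(9600), 0), (torch.randn(16000), 3)]
    feats, labs = my_collate_fn(fake_batch, max_length=12000)
    print(feats.shape, labs)  # expected: torch.Size([2, 12000]) tensor([0, 3])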
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import pandas as pd
import soundfile as sf
from tqdm import tqdm
import torch
from utils.ECAPAModel import ECAPA_TDNN
import argparse
import os

def load_model(model_path, C=1024):
    model = ECAPA_TDNN(C=C)

    self_state = model.state_dict()
    loaded_state = torch.load(model_path)
    for name, param in loaded_state.items():
        origname = name
        if name not in self_state:
            # Checkpoints saved from ECAPAModel prefix the encoder weights.
            name = name.replace("dialect_encoder.", "")
            if name not in self_state:
                # Skip parameters that belong to the loss head, not the encoder.
                continue
        if self_state[name].size() != loaded_state[origname].size():
            print("Wrong parameter length: %s, model: %s, loaded: %s" % (origname, self_state[name].size(), loaded_state[origname].size()))
            continue
        self_state[name].copy_(param)

    model.load_state_dict(self_state)
    model.eval()
    return model


def compute_pair_distance(model, pair_df, device='cuda', res_path='submit.csv'):
    model.to(device)
    # Index every wav file under ./data by its id.
    id2path = {}
    ids = set(pair_df.id1.tolist() + pair_df.id2.tolist())
    for r, _, files in os.walk('data'):
        for fname in files:
            if fname.endswith('.wav'):
                fid = fname.split('.')[0]
                fpath = os.path.join(r, fname)
                id2path[fid] = fpath
    pair_df['wav_path1'] = pair_df.id1.apply(lambda x: id2path[x])
    pair_df['wav_path2'] = pair_df.id2.apply(lambda x: id2path[x])

    # Embed each unique utterance once, then score every pair.
    embeddings = {}
    res = []
    for wav_id in tqdm(ids):
        wav_path = id2path[wav_id]
        audio, sr = sf.read(wav_path)
        audio = torch.FloatTensor(audio).unsqueeze(0).to(device)
        with torch.no_grad():
            embedding = model.forward(audio, aug=False)
        embeddings[wav_id] = embedding

    for i, row in tqdm(pair_df.iterrows(), total=len(pair_df)):
        dist = torch.cdist(embeddings[row['id1']], embeddings[row['id2']], p=2)
        res.append(dist.item())
    pair_df['distance'] = res
    pair_df[['distance']].to_csv(res_path, encoding='utf8', index=False)
    return pair_df
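
# Optional sketch (illustrative; main() below does not use it): the per-pair
# torch.cdist call above is equivalent to a row-wise L2 norm, so the distance
# computation can be vectorized once the two sides are stacked. Assumes
# `embeddings` maps each wav id to a (1, 192) tensor as built above.
def compute_pair_distance_batched(embeddings, pair_df):
    e1 = torch.cat([embeddings[i] for i in pair_df['id1']], dim=0)  # (N, 192)
    e2 = torch.cat([embeddings[i] for i in pair_df['id2']], dim=0)  # (N, 192)
    return torch.norm(e1 - e2, p=2, dim=1).cpu().numpy()            # row-wise L2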

def main():
    parser = argparse.ArgumentParser(description="ECAPA_inference")
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to run inference on')
    parser.add_argument('--model_path', type=str, default='exps/model/model_0002.model', help='Model checkpoint path')
    parser.add_argument('--test_path', type=str, default='data/test_pair', help='Path to the test pair file; keep it strictly identical to the original file')
    parser.add_argument('--save_path', type=str, default='submit.csv', help='Path of the result file')
    args = parser.parse_args()
    print('loading model...')
    model = load_model(args.model_path)
    pair_df = pd.read_csv(args.test_path)

    print('model inferring...')
    compute_pair_distance(model, pair_df, device=args.device, res_path=args.save_path)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
'''
AAMsoftmax loss function copied from voxceleb_trainer:
https://github.com/clovaai/voxceleb_trainer/blob/master/loss/aamsoftmax.py
'''

import torch, math
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

# 9x9 inter-dialect distance matrix; row/column indices follow the label order
# defined in create_data_index.py (0 = 普通话, 1 = 成都, ..., 8 = 长沙).
DISTANCE_MATRIX = np.array([[ 0. , 32.1, 23.4, 34.4, 68.7, 67.7, 57.1, 79.2, 51.7],
                            [32.1,  0. , 38.2, 24.5, 65.8, 60.7, 49.9, 76.6, 43.7],
                            [23.4, 38.2,  0. , 40.1, 68.6, 70.5, 61.8, 79.1, 55.9],
                            [34.4, 24.5, 40.1,  0. , 66. , 62.6, 54.2, 77.2, 34. ],
                            [68.7, 65.8, 68.6, 66. ,  0. , 67.1, 68.3, 71. , 68.4],
                            [67.7, 60.7, 70.5, 62.6, 67.1,  0. , 40.7, 78. , 63.9],
                            [57.1, 49.9, 61.8, 54.2, 68.3, 40.7,  0. , 76.2, 57.3],
                            [79.2, 76.6, 79.1, 77.2, 71. , 78. , 76.2,  0. , 77. ],
                            [51.7, 43.7, 55.9, 34. , 68.4, 63.9, 57.3, 77. ,  0. ]])

class AAMsoftmax(nn.Module):
    def __init__(self, n_class, m, s):
        super(AAMsoftmax, self).__init__()
        self.m = m
        self.s = s
        self.weight = torch.nn.Parameter(torch.FloatTensor(n_class, 192), requires_grad=True)
        self.ce = nn.CrossEntropyLoss()
        nn.init.xavier_normal_(self.weight, gain=1)
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, x, label=None):
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s

        loss = self.ce(output, label)
        return loss

class StandardSimilarityLoss(nn.Module):
    def __init__(self, n_class, n_model=192, loss_function=nn.MSELoss, distance_matrix=DISTANCE_MATRIX):
        super(StandardSimilarityLoss, self).__init__()
        self.linear = nn.Linear(n_model, n_class)
        self.loss_function = loss_function()
        self.distance_matrix = distance_matrix

    def forward(self, x, label):
        # Regress, for each sample, its similarity (100 - distance) to all 9 dialects.
        n = x.shape[0]
        y_true = torch.zeros((n, 9))
        for i in range(n):
            y_true[i] = torch.from_numpy(100 - self.distance_matrix[label[i]][:9])
        y_true = y_true.to(x.device)
        loss = self.loss_function(self.linear(x), y_true)
        return loss

class PairDistanceLoss(nn.Module):
    def __init__(self, loss_function=nn.MSELoss, distance_matrix=DISTANCE_MATRIX):
        super(PairDistanceLoss, self).__init__()
        self.loss_function = loss_function()
        self.distance_matrix = distance_matrix

    def forward(self, y_pred, labels):
        # Match the pairwise embedding distances within the batch to the
        # reference inter-dialect distances of the corresponding labels.
        n = y_pred.shape[0]
        y_true = torch.zeros((n, n))
        for i in range(n):
            for j in range(n):
                y_true[i][j] = self.distance_matrix[labels[i]][labels[j]]
        y_true = y_true.to(y_pred.device)
        dist = torch.cdist(y_pred, y_pred, p=2)
        loss = self.loss_function(dist, y_true)
        return loss
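
if __name__ == '__main__':
    # Worked example (a sketch): for a batch with labels [0, 1] (普通话, 成都)
    # the 2x2 target matrix is [[0.0, 32.1], [32.1, 0.0]], and the loss is the
    # MSE between that target and the pairwise L2 distances of the embeddings.
    torch.manual_seed(0)
    criterion = PairDistanceLoss()
    y_pred = torch.randn(2, 192)  # stand-in embeddings
    labels = torch.tensor([0, 1])
    print(criterion(y_pred, labels))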
--------------------------------------------------------------------------------
/utils/ECAPAModel.py:
--------------------------------------------------------------------------------
'''
This part is used to train the dialect model and evaluate the performance
'''

import torch, sys, time
import tqdm
import torch.nn as nn
from utils.loss import AAMsoftmax, StandardSimilarityLoss, PairDistanceLoss
from utils.model import ECAPA_TDNN

class ECAPAModel(nn.Module):
    def __init__(self, lr, lr_decay, C, n_class, m, s, test_step, loss, device, **kwargs):
        super(ECAPAModel, self).__init__()
        ## Model
        self.device = device
        self.dialect_encoder = ECAPA_TDNN(C=C).to(self.device)

        ## Loss
        if loss == 'aamsoftmax':
            self.dialect_loss = AAMsoftmax(n_class=n_class, m=m, s=s).to(self.device)
        elif loss == 'StandardSimilarityLoss':
            self.dialect_loss = StandardSimilarityLoss(n_class).to(self.device)
        elif loss == 'PairDistanceLoss':
            self.dialect_loss = PairDistanceLoss().to(self.device)
        else:
            raise NotImplementedError
        self.optim = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=2e-5)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim, step_size=test_step, gamma=lr_decay)
        print(time.strftime("%m-%d %H:%M:%S") + " Model para number = %.2fM" % (sum(param.numel() for param in self.dialect_encoder.parameters()) / 1024 / 1024))

    def train_network(self, epoch, loader):
        self.train()
        ## Update the learning rate based on the current epoch
        self.scheduler.step(epoch - 1)
        index, loss = 0, 0
        lr = self.optim.param_groups[0]['lr']

        for num, (data, labels) in enumerate(loader, start=1):
            self.zero_grad()
            labels = torch.LongTensor(labels).to(self.device)
            dialect_embedding = self.dialect_encoder.forward(data.to(self.device), aug=True)
            nloss = self.dialect_loss.forward(dialect_embedding, labels)
            nloss.backward()
            self.optim.step()
            index += len(labels)
            loss += nloss.detach().cpu().numpy()
            sys.stderr.write(time.strftime("%m-%d %H:%M:%S") + \
                " [%2d] Lr: %5f, Training: %.2f%%, " % (epoch, lr, 100 * (num / loader.__len__())) + \
                " Loss: %.5f\r" % (loss / num))
            sys.stderr.flush()
        sys.stdout.write("\n")
        return loss / num, lr

    def eval_network(self, loader):
        self.eval()
        total_loss = 0.0
        for idx, (data, label) in tqdm.tqdm(enumerate(loader)):
            data = data.to(self.device)
            label = label.to(self.device)
            with torch.no_grad():
                embedding = self.dialect_encoder.forward(data, aug=False)
                loss = self.dialect_loss(embedding, label)
                total_loss += loss.item()
        total_loss = total_loss / (idx + 1)
        return total_loss

    def save_parameters(self, path):
        torch.save(self.state_dict(), path)

    def load_parameters(self, path):
        self_state = self.state_dict()
        loaded_state = torch.load(path)
        for name, param in loaded_state.items():
            origname = name
            if name not in self_state:
                name = name.replace("module.", "")
                if name not in self_state:
                    print("%s is not in the model." % origname)
                    continue
            if self_state[name].size() != loaded_state[origname].size():
                print("Wrong parameter length: %s, model: %s, loaded: %s" % (origname, self_state[name].size(), loaded_state[origname].size()))
                continue
            self_state[name].copy_(param)
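
if __name__ == '__main__':
    # Smoke test (illustrative): build the model with the train.py defaults on
    # CPU and push two seconds of random 8 kHz "audio" through the encoder.
    model = ECAPAModel(lr=0.001, lr_decay=0.97, C=1024, n_class=9, m=0.2, s=30,
                       test_step=1, loss='aamsoftmax', device='cpu')
    with torch.no_grad():
        emb = model.dialect_encoder(torch.randn(2, 16000), aug=False)
    print(emb.shape)  # expected: torch.Size([2, 192])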
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
'''
This is the main code of the ECAPA-TDNN project; it defines the parameters and builds the training procedure
'''

import argparse, glob, os, torch, warnings, time
from utils.tools import *
from utils.dataLoader import train_loader, eval_loader
from utils.ECAPAModel import ECAPAModel

parser = argparse.ArgumentParser(description="ECAPA_trainer")
parser.add_argument('--device', type=str, default='cuda:0', help='Device to train on')
## Training settings
parser.add_argument('--num_frames', type=int, default=300, help='Duration of the input segments, e.g. 200 for 2 seconds')
parser.add_argument('--max_epoch', type=int, default=80, help='Maximum number of epochs')
parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
parser.add_argument('--n_cpu', type=int, default=4, help='Number of loader threads')
parser.add_argument('--test_step', type=int, default=1, help='Test and save every [test_step] epochs')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
parser.add_argument("--lr_decay", type=float, default=0.97, help='Learning rate decay every [test_step] epochs')

## Training and evaluation path/lists, save path
parser.add_argument('--train_list', type=str, default="./data/train_df", help='The path of the training list, e.g. "/data08/VoxCeleb2/train_list.txt" in my case')
parser.add_argument('--eval_list', type=str, default="./data/valid_df", help='The path of the evaluation list, e.g. "/data08/VoxCeleb1/veri_test2.txt" in my case')
parser.add_argument('--eval_max_length', type=int, default=120000, help='The max length of the evaluation audio')
parser.add_argument('--save_path', type=str, default="./exps", help='Path to save the score.txt and models')
parser.add_argument('--initial_model', type=str, default="", help='Path of the initial_model')

## Model and loss settings
parser.add_argument('--C', type=int, default=1024, help='Channel size for the dialect encoder')
parser.add_argument('--m', type=float, default=0.2, help='Loss margin in AAM softmax')
parser.add_argument('--s', type=float, default=30, help='Loss scale in AAM softmax')
parser.add_argument('--n_class', type=int, default=9, help='Number of dialects')
parser.add_argument('--loss', type=str, default="PairDistanceLoss", help='Target and loss function')
## Command
parser.add_argument('--eval', dest='eval', action='store_true', help='Only do evaluation')

## Initialization
warnings.simplefilter("ignore")
torch.multiprocessing.set_sharing_strategy('file_system')
args = parser.parse_args()
args = init_args(args)

## Define the data loaders
train_data = train_loader(**vars(args))
trainLoader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.n_cpu, drop_last=True)

def my_collate_fn(data, max_length=args.eval_max_length):
    # Same padding collate as in utils/dataLoader.py, but capped at eval_max_length.
    lens = [x[0].shape[0] for x in data]
    max_len = max(lens)
    max_length = min(max_len, max_length)

    features = torch.zeros(len(data), max_length)
    for i, length in enumerate(lens):
        end = min(length, max_length)
        features[i, :end] = data[i][0][:end]
    labels = torch.tensor([x[1] for x in data])

    return features, labels

eval_data = eval_loader(**vars(args))
evalLoader = torch.utils.data.DataLoader(eval_data, batch_size=args.batch_size * 2, shuffle=True, num_workers=args.n_cpu, drop_last=True, collate_fn=my_collate_fn)

## Search for existing models
#modelfiles = glob.glob('%s/model_0*.model'%args.model_save_path)
#modelfiles.sort()
modelfiles = []

## Only do evaluation; the initial_model is necessary
if args.eval == True:
    s = ECAPAModel(**vars(args))
    print("Model %s loaded from previous state!" % args.initial_model)
    s.load_parameters(args.initial_model)
    loss = s.eval_network(loader=evalLoader)
    print("EvalLoss %2.2f" % (loss))
    quit()

## If initial_model exists, the system trains from the initial_model
if args.initial_model != "":
    print("Model %s loaded from previous state!" % args.initial_model)
    s = ECAPAModel(**vars(args))
    s.load_parameters(args.initial_model)
    epoch = 1

## Otherwise, the system tries to resume from the last saved model & epoch
elif len(modelfiles) >= 1:
    print("Model %s loaded from previous state!" % modelfiles[-1])
    epoch = 1
    s = ECAPAModel(**vars(args))
    s.load_parameters(modelfiles[-1])
## Otherwise, the system trains from scratch
else:
    epoch = 1
    s = ECAPAModel(**vars(args))

eval_losses = []
score_file = open(args.score_save_path, "a+")

while True:
    ## Train for one epoch
    loss, lr = s.train_network(epoch=epoch, loader=trainLoader)

    ## Evaluate every [test_step] epochs
    if epoch % args.test_step == 0:
        s.save_parameters(args.model_save_path + "/model_%04d.model" % epoch)
        eval_losses.append(s.eval_network(loader=evalLoader))
        print(time.strftime("%Y-%m-%d %H:%M:%S"), "%d epoch, EvalLoss %2.2f, minEvalLoss %2.2f" % (epoch, eval_losses[-1], min(eval_losses)))
        score_file.write("%d epoch, LR %f, LOSS %f, EvalLoss %2.2f, minEvalLoss %2.2f\n" % (epoch, lr, loss, eval_losses[-1], min(eval_losses)))
        score_file.flush()

    if epoch >= args.max_epoch:
        quit()

    epoch += 1
--------------------------------------------------------------------------------
/utils/model.py:
--------------------------------------------------------------------------------
'''
This is the ECAPA-TDNN model.
This model is modified and combined based on the following three projects:
  1. https://github.com/clovaai/voxceleb_trainer/issues/86
  2. https://github.com/lawlict/ECAPA-TDNN/blob/master/ecapa_tdnn.py
  3. https://github.com/speechbrain/speechbrain/blob/96077e9a1afff89d3f5ff47cab4bca0202770e4f/speechbrain/lobes/models/ECAPA_TDNN.py
'''

import math, torch, torchaudio
import torch.nn as nn
import torch.nn.functional as F


class SEModule(nn.Module):
    def __init__(self, channels, bottleneck=128):
        super(SEModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
            nn.ReLU(),
            nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, input):
        x = self.se(input)
        return input * x

class Bottle2neck(nn.Module):

    def __init__(self, inplanes, planes, kernel_size=None, dilation=None, scale=8):
        super(Bottle2neck, self).__init__()
        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)
        self.nums = scale - 1
        convs = []
        bns = []
        num_pad = math.floor(kernel_size / 2) * dilation
        for i in range(self.nums):
            convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
            bns.append(nn.BatchNorm1d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.width = width
        self.se = SEModule(planes)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(sp)
            sp = self.bns[i](sp)
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = torch.cat((out, spx[self.nums]), 1)

        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)
        out = self.se(out)
        out += residual
        return out

class PreEmphasis(torch.nn.Module):

    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        self.register_buffer(
            'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, input: torch.tensor) -> torch.tensor:
        input = input.unsqueeze(1)
        input = F.pad(input, (1, 0), 'reflect')
        return F.conv1d(input, self.flipped_filter).squeeze(1)

class FbankAug(nn.Module):

    def __init__(self, freq_mask_width=(0, 8), time_mask_width=(0, 10)):
        self.time_mask_width = time_mask_width
        self.freq_mask_width = freq_mask_width
        super().__init__()

    def mask_along_axis(self, x, dim):
        original_size = x.shape
        batch, fea, time = x.shape
        if dim == 1:
            D = fea
            width_range = self.freq_mask_width
        else:
            D = time
            width_range = self.time_mask_width

        mask_len = torch.randint(width_range[0], width_range[1], (batch, 1), device=x.device).unsqueeze(2)
        mask_pos = torch.randint(0, max(1, D - mask_len.max()), (batch, 1), device=x.device).unsqueeze(2)
        arange = torch.arange(D, device=x.device).view(1, 1, -1)
        mask = (mask_pos <= arange) * (arange < (mask_pos + mask_len))
        mask = mask.any(dim=1)

        if dim == 1:
            mask = mask.unsqueeze(2)
        else:
            mask = mask.unsqueeze(1)

        x = x.masked_fill_(mask, 0.0)
        return x.view(*original_size)

    def forward(self, x):
        x = self.mask_along_axis(x, dim=2)
        x = self.mask_along_axis(x, dim=1)
        return x

class ECAPA_TDNN(nn.Module):

    def __init__(self, C):

        super(ECAPA_TDNN, self).__init__()

        self.torchfbank = torch.nn.Sequential(
            PreEmphasis(),
            torchaudio.transforms.MelSpectrogram(sample_rate=8000, n_fft=512, win_length=400, hop_length=160, \
                                                 #f_min = 20, f_max = 7600,
                                                 window_fn=torch.hamming_window, n_mels=80),
        )

        self.specaug = FbankAug()  # Spec augmentation

        self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(C)
        self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
        self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
        self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)

        self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
        self.attention = nn.Sequential(
            nn.Conv1d(4608, 256, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Tanh(),  # I add this layer
            nn.Conv1d(256, 1536, kernel_size=1),
            nn.Softmax(dim=2),
        )
        self.bn5 = nn.BatchNorm1d(3072)
        self.fc6 = nn.Linear(3072, 192)
        self.bn6 = nn.BatchNorm1d(192)


    def forward(self, x, aug):

        with torch.no_grad():
            x = self.torchfbank(x) + 1e-6

            x = x.log()

            x = x - torch.mean(x, dim=-1, keepdim=True)
            if aug == True:
                x = self.specaug(x)

        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x + x1)
        x3 = self.layer3(x + x1 + x2)

        x = self.layer4(torch.cat((x1, x2, x3), dim=1))
        x = self.relu(x)

        t = x.size()[-1]
        # Attentive statistics pooling: append the global mean/std to every frame
        # before computing the attention weights.
        global_x = torch.cat((x, torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t), torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4)).repeat(1, 1, t)), dim=1)

        w = self.attention(global_x)

        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-4))

        x = torch.cat((mu, sg), 1)
        x = self.bn5(x)
        x = self.fc6(x)
        x = self.bn6(x)
        return x
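
if __name__ == '__main__':
    # Shape checks (a sketch; the small channel count is only to keep it fast).
    # Bottle2neck is residual, so its output shape matches its input shape, and
    # FbankAug only zeroes random time/frequency stripes without reshaping.
    block = Bottle2neck(64, 64, kernel_size=3, dilation=2, scale=8)
    print(block(torch.randn(4, 64, 100)).shape)  # torch.Size([4, 64, 100])
    aug = FbankAug()
    print(aug(torch.randn(4, 80, 100)).shape)    # torch.Size([4, 80, 100])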
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity.
      For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------