├── .gitignore
├── README.md
├── SpeechDataGenerator.py
├── SpeechDataGenerator_precomp_feats.py
├── datasets.py
├── feature_extraction.py
├── meta
│   ├── testing.txt
│   ├── testing_feat.txt
│   ├── training.txt
│   ├── training_feat.txt
│   ├── validation.txt
│   └── validation_feat.txt
├── models
│   ├── tdnn.py
│   ├── x_vector.py
│   └── x_vector_Indian_LID.py
├── requirements.txt
├── training_xvector.py
└── utils
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
.DS_store
best_check_point*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# x-vector-pytorch
This repo contains the implementation of the paper "Spoken Language Recognition using X-vectors" in PyTorch.
Paper: https://danielpovey.com/files/2018_odyssey_xvector_lid.pdf
Tutorial: https://www.youtube.com/watch?v=8nZjiXEdMH0

## Installation

I suggest you install Anaconda3 on your system. First download Anaconda3 from https://docs.anaconda.com/anaconda/install/hashes/lin-3-64/
```bash
bash Anaconda3-2019.03-Linux-x86_64.sh
```
## Clone the repo
```bash
git clone https://github.com/KrishnaDN/x-vector-pytorch.git
```
Once Anaconda3 is installed successfully, install the required packages using requirements.txt
```bash
pip install -r requirements.txt
```

## Create manifest files for training and testing
This step creates the training, validation, and testing manifest files.
```
python datasets.py --processed_data /media/newhd/youtube_lid_data/download_data --meta_store_path meta/
```

## Training
This step starts training the x-vector model for language identification.
```
python training_xvector.py --training_filepath meta/training.txt --testing_filepath meta/testing.txt --validation_filepath meta/validation.txt \
       --input_dim 40 --num_classes 8 --batch_size 32 --use_gpu True --num_epochs 100
```

Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
For any queries contact: krishnadn94@gmail.com
## License
[MIT](https://choosealicense.com/licenses/mit/)
--------------------------------------------------------------------------------
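The directory layout that datasets.py expects can be inferred from the glob patterns in its extract_files function: one folder per language, one sub-folder per source video, and .wav clips inside. The video and clip names below are hypothetical:

```
download_data/
├── English/
│   ├── video_0001/
│   │   ├── clip_000.wav
│   │   └── clip_001.wav
│   └── video_0002/
│       └── clip_000.wav
└── Hindi/
    └── ...
```

Each resulting manifest line pairs a clip path with its integer class id, e.g. `/media/newhd/youtube_lid_data/download_data/Hindi/video_0002/clip_000.wav 1`.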
/SpeechDataGenerator.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 20 14:09:31 2019

@author: Krishna
"""
import numpy as np
import torch
from utils import utils

class SpeechDataGenerator():
    """Speech dataset that loads features from raw audio files."""

    def __init__(self, manifest, mode):
        """
        Read the manifest file and collect the audio paths and labels.
        """
        self.mode = mode
        self.audio_links = [line.rstrip('\n').split(' ')[0] for line in open(manifest)]
        self.labels = [int(line.rstrip('\n').split(' ')[1]) for line in open(manifest)]

    def __len__(self):
        return len(self.audio_links)

    def __getitem__(self, idx):
        audio_link = self.audio_links[idx]
        class_id = self.labels[idx]
        spec = utils.load_data(audio_link, mode=self.mode)
        sample = {'features': torch.from_numpy(np.ascontiguousarray(spec)),
                  'labels': torch.from_numpy(np.ascontiguousarray(class_id))}
        return sample
--------------------------------------------------------------------------------
/SpeechDataGenerator_precomp_feats.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Jul 20 14:09:31 2019

@author: Krishna
"""
import numpy as np
import torch
from utils import utils

class SpeechDataGenerator_precomp_features():
    """Speech dataset that loads precomputed .npy feature files."""

    def __init__(self, manifest, mode):
        """
        Read the manifest file and collect the .npy paths and labels.
        """
        self.mode = mode
        self.npy_files = [line.rstrip('\n').split(' ')[0] for line in open(manifest)]
        self.labels = [int(line.rstrip('\n').split(' ')[1]) for line in open(manifest)]

    def __len__(self):
        return len(self.npy_files)

    def __getitem__(self, idx):
        npy_filepath = self.npy_files[idx]
        class_id = self.labels[idx]
        spec = utils.load_npy_data(npy_filepath, mode=self.mode)
        sample = {'features': torch.from_numpy(np.ascontiguousarray(spec)),
                  'labels': torch.from_numpy(np.ascontiguousarray(class_id))}
        return sample
--------------------------------------------------------------------------------
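A minimal sketch of how these generators are consumed, mirroring the DataLoader setup in training_xvector.py. speech_collate lives in utils/utils.py (its body is truncated in this dump); judging from its use in the training script, it returns a pair of lists holding the per-utterance feature and label tensors:

```python
from torch.utils.data import DataLoader

from SpeechDataGenerator import SpeechDataGenerator
from utils.utils import speech_collate

# Manifest lines have the form "<wav_path> <class_id>".
dataset = SpeechDataGenerator(manifest='meta/training.txt', mode='train')
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=speech_collate)

for sample_batched in loader:
    features, labels = sample_batched[0], sample_batched[1]
    print(len(features), len(labels))  # lists of per-sample tensors
    break
```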
/datasets.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat May 30 19:09:44 2020

@author: krishna
"""

import os
import glob
import argparse


class_ids = {'English': 0, 'Hindi': 1, 'Kannada': 2, 'Tamil': 3, 'Telugu': 4,
             'Malayalam': 5, 'Marathi': 6, 'Gujarathi': 7}

def create_meta(files_list, store_loc, mode='train'):
    if not os.path.exists(store_loc):
        os.makedirs(store_loc)

    mode_to_filename = {'train': 'training.txt', 'test': 'testing.txt', 'validation': 'validation.txt'}
    if mode not in mode_to_filename:
        print('Error in creating meta files')
        return
    meta_store = os.path.join(store_loc, mode_to_filename[mode])
    fid = open(meta_store, 'w')
    for filepath in files_list:
        fid.write(filepath + '\n')
    fid.close()

def extract_files(folder_path):
    all_lang_folders = sorted(glob.glob(folder_path + '/*/'))
    train_lists = []
    test_lists = []
    val_lists = []

    for lang_folderpath in all_lang_folders:
        language = lang_folderpath.split('/')[-2]
        sub_folders = sorted(glob.glob(lang_folderpath + '/*/'))
        # Per language: ~10% of video folders go to test, ~5% to validation, the rest to train.
        train_nums = len(sub_folders) - int(len(sub_folders) * 0.1) - int(len(sub_folders) * 0.05)
        for i in range(train_nums):
            sub_folder = sub_folders[i]
            all_files = sorted(glob.glob(sub_folder + '/*.wav'))
            for audio_filepath in all_files:
                to_write = audio_filepath + ' ' + str(class_ids[language])
                train_lists.append(to_write)

        for i in range(train_nums, train_nums + int(len(sub_folders) * 0.05)):
            sub_folder = sub_folders[i]
            all_files = sorted(glob.glob(sub_folder + '/*.wav'))
            for audio_filepath in all_files:
                to_write = audio_filepath + ' ' + str(class_ids[language])
                val_lists.append(to_write)

        for i in range(train_nums + int(len(sub_folders) * 0.05), len(sub_folders)):
            sub_folder = sub_folders[i]
            all_files = sorted(glob.glob(sub_folder + '/*.wav'))
            for audio_filepath in all_files:
                to_write = audio_filepath + ' ' + str(class_ids[language])
                test_lists.append(to_write)
    return train_lists, test_lists, val_lists


if __name__ == '__main__':
    parser = argparse.ArgumentParser("Configuration for data preparation")
    parser.add_argument("--processed_data", default="/media/newhd/youtube_lid_data/download_data", type=str, help='Dataset path')
    parser.add_argument("--meta_store_path", default="meta/", type=str, help='Save directory after processing')
    config = parser.parse_args()
    train_list, test_list, val_lists = extract_files(config.processed_data)

    create_meta(train_list, config.meta_store_path, mode='train')
    create_meta(test_list, config.meta_store_path, mode='test')
    create_meta(val_lists, config.meta_store_path, mode='validation')
--------------------------------------------------------------------------------
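The split in extract_files is per language and over whole video sub-folders, so clips from one video never straddle two splits. A quick check of the arithmetic, with a hypothetical count of 40 sub-folders:

```python
n = 40                           # video sub-folders for one language
test_n = int(n * 0.1)            # 4
val_n = int(n * 0.05)            # 2
train_nums = n - test_n - val_n  # 34
# indices [0, 34) -> train, [34, 36) -> validation, [36, 40) -> test
```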
/feature_extraction.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun May 31 11:15:47 2020

@author: krishna
"""

import os
import numpy as np
from utils import utils

def extract_features(audio_filepath):
    features = utils.feature_extraction(audio_filepath)
    return features


def FE_pipeline(feature_list, store_loc, mode):
    create_root = os.path.join(store_loc, mode)
    if not os.path.exists(create_root):
        os.makedirs(create_root)
    if mode == 'train':
        fid = open('meta/training_feat.txt', 'w')
    elif mode == 'test':
        fid = open('meta/testing_feat.txt', 'w')
    elif mode == 'validation':
        fid = open('meta/validation_feat.txt', 'w')
    else:
        print('Unknown mode')
        return  # bail out early so fid is never used unbound

    for row in feature_list:
        filepath = row.split(' ')[0]
        lang_id = row.split(' ')[1]
        vid_folder = filepath.split('/')[-2]
        lang_folder = filepath.split('/')[-3]
        filename = filepath.split('/')[-1]
        create_folders = os.path.join(create_root, lang_folder, vid_folder)
        if not os.path.exists(create_folders):
            os.makedirs(create_folders)
        extract_feats = extract_features(filepath)
        dest_filepath = create_folders + '/' + filename[:-4] + '.npy'
        np.save(dest_filepath, extract_feats)
        to_write = dest_filepath + ' ' + lang_id
        fid.write(to_write + '\n')
    fid.close()


if __name__ == '__main__':
    store_loc = '/media/newhd/youtube_lid_data/Features'
    read_train = [line.rstrip('\n') for line in open('meta/training.txt')]
    FE_pipeline(read_train, store_loc, mode='train')

    read_test = [line.rstrip('\n') for line in open('meta/testing.txt')]
    FE_pipeline(read_test, store_loc, mode='test')

    read_val = [line.rstrip('\n') for line in open('meta/validation.txt')]
    FE_pipeline(read_val, store_loc, mode='validation')
--------------------------------------------------------------------------------
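Each clip is saved as a .npy array mirroring the language/video folder structure, and the *_feat.txt manifests point at those arrays. A small sanity check after the pipeline has run; the exact feature shape depends on utils.feature_extraction, which is not shown in full in this dump:

```python
import numpy as np

# Inspect the first entry of the generated feature manifest.
line = open('meta/training_feat.txt').readline().rstrip('\n')
npy_path, lang_id = line.split(' ')
feats = np.load(npy_path)
print(npy_path, lang_id, feats.shape)
```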
/models/tdnn.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: cvqluu
repo: https://github.com/cvqluu/TDNN
"""

import torch.nn as nn
import torch.nn.functional as F

class TDNN(nn.Module):

    def __init__(
        self,
        input_dim=23,
        output_dim=512,
        context_size=5,
        stride=1,
        dilation=1,
        batch_norm=False,
        dropout_p=0.2
    ):
        '''
        TDNN as defined by https://www.danielpovey.com/files/2015_interspeech_multisplice.pdf

        The affine transformation is not applied globally to all frames, but to smaller windows with local context.

        batch_norm: True to include batch normalisation after the non-linearity

        Context size and dilation determine the frames selected
        (although context size is not really defined in the traditional sense).
        For example:
            context size 5 and dilation 1 is equivalent to [-2,-1,0,1,2]
            context size 3 and dilation 2 is equivalent to [-2, 0, 2]
            context size 1 and dilation 1 is equivalent to [0]
        '''
        super(TDNN, self).__init__()
        self.context_size = context_size
        self.stride = stride
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dilation = dilation
        self.dropout_p = dropout_p
        self.batch_norm = batch_norm

        self.kernel = nn.Linear(input_dim*context_size, output_dim)
        self.nonlinearity = nn.ReLU()
        if self.batch_norm:
            self.bn = nn.BatchNorm1d(output_dim)
        if self.dropout_p:
            self.drop = nn.Dropout(p=self.dropout_p)

    def forward(self, x):
        '''
        input: size (batch, seq_len, input_features)
        output: size (batch, new_seq_len, output_features)
        '''

        _, _, d = x.shape
        assert (d == self.input_dim), 'Input dimension was wrong. Expected ({}), got ({})'.format(self.input_dim, d)
        x = x.unsqueeze(1)

        # Unfold input into smaller temporal contexts
        x = F.unfold(
            x,
            (self.context_size, self.input_dim),
            stride=(1, self.input_dim),
            dilation=(self.dilation, 1)
        )

        # N, output_dim*context_size, new_t = x.shape
        x = x.transpose(1, 2)
        x = self.kernel(x.float())
        x = self.nonlinearity(x)

        if self.dropout_p:
            x = self.drop(x)

        if self.batch_norm:
            x = x.transpose(1, 2)
            x = self.bn(x)
            x = x.transpose(1, 2)

        return x
--------------------------------------------------------------------------------
/models/x_vector.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat May 30 19:59:45 2020

@author: krishna

"""


import torch.nn as nn
from models.tdnn import TDNN
import torch

class X_vector(nn.Module):
    def __init__(self, input_dim=40, num_classes=8):
        super(X_vector, self).__init__()
        self.tdnn1 = TDNN(input_dim=input_dim, output_dim=512, context_size=5, dilation=1, dropout_p=0.5)
        self.tdnn2 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=1, dropout_p=0.5)
        self.tdnn3 = TDNN(input_dim=512, output_dim=512, context_size=2, dilation=2, dropout_p=0.5)
        self.tdnn4 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=1, dropout_p=0.5)
        self.tdnn5 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=3, dropout_p=0.5)
        #### Segment-level layers after statistics pooling
        self.segment6 = nn.Linear(1024, 512)
        self.segment7 = nn.Linear(512, 512)
        self.output = nn.Linear(512, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inputs):
        tdnn1_out = self.tdnn1(inputs)
        tdnn2_out = self.tdnn2(tdnn1_out)
        tdnn3_out = self.tdnn3(tdnn2_out)
        tdnn4_out = self.tdnn4(tdnn3_out)
        tdnn5_out = self.tdnn5(tdnn4_out)
        ### Statistics pooling: concatenate mean and std over the time axis
        mean = torch.mean(tdnn5_out, 1)
        std = torch.std(tdnn5_out, 1)
        stat_pooling = torch.cat((mean, std), 1)
        segment6_out = self.segment6(stat_pooling)
        x_vec = self.segment7(segment6_out)
        predictions = self.softmax(self.output(x_vec))
        return predictions, x_vec
--------------------------------------------------------------------------------
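Because F.unfold only emits full context windows, each TDNN layer shortens the time axis by dilation × (context_size − 1) frames. A quick shape check, with arbitrary batch size and sequence length:

```python
import torch
from models.tdnn import TDNN

layer = TDNN(input_dim=40, output_dim=512, context_size=5, dilation=1)
x = torch.randn(4, 100, 40)   # (batch, seq_len, input_features)
y = layer(x)
print(y.shape)                # torch.Size([4, 96, 512]); 100 - 1*(5-1) = 96
```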
/models/x_vector_Indian_LID.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat May 30 19:59:45 2020

@author: krishna

"""


import torch.nn as nn
from models.tdnn import TDNN
import torch


class X_vector(nn.Module):
    def __init__(self, input_dim=40, num_classes=8):
        super(X_vector, self).__init__()
        self.tdnn1 = TDNN(input_dim=input_dim, output_dim=512, context_size=5, dilation=1, dropout_p=0.5)
        self.tdnn2 = TDNN(input_dim=512, output_dim=512, context_size=3, dilation=1, dropout_p=0.5)
        self.tdnn3 = TDNN(input_dim=512, output_dim=512, context_size=2, dilation=2, dropout_p=0.5)
        self.tdnn4 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=1, dropout_p=0.5)
        self.tdnn5 = TDNN(input_dim=512, output_dim=512, context_size=1, dilation=3, dropout_p=0.5)
        #### Segment-level layers after statistics pooling
        self.segment6 = nn.Linear(1024, 512)
        self.segment7 = nn.Linear(512, 512)
        self.output = nn.Linear(512, num_classes)

    def forward(self, inputs):
        tdnn1_out = self.tdnn1(inputs)
        tdnn2_out = self.tdnn2(tdnn1_out)
        tdnn3_out = self.tdnn3(tdnn2_out)
        tdnn4_out = self.tdnn4(tdnn3_out)
        tdnn5_out = self.tdnn5(tdnn4_out)
        ### Statistics pooling: concatenate mean and standard deviation over the time axis
        mean = torch.mean(tdnn5_out, 1)
        std = torch.std(tdnn5_out, 1)
        stat_pooling = torch.cat((mean, std), 1)
        segment6_out = self.segment6(stat_pooling)
        x_vec = self.segment7(segment6_out)
        # Raw logits: CrossEntropyLoss in training_xvector.py applies log-softmax internally
        predictions = self.output(x_vec)
        return predictions, x_vec
--------------------------------------------------------------------------------
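Across the five TDNN layers the time axis shrinks by (5−1)·1 + (3−1)·1 + (2−1)·2 = 8 frames in total (the context-size-1 layers consume nothing), and statistics pooling then removes the time axis entirely. An end-to-end shape check with arbitrary sizes:

```python
import torch
from models.x_vector_Indian_LID import X_vector

model = X_vector(input_dim=40, num_classes=8)
feats = torch.randn(2, 50, 40)     # (batch, frames, features)
preds, x_vec = model(feats)
print(preds.shape, x_vec.shape)    # torch.Size([2, 8]) torch.Size([2, 512])
```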
/requirements.txt:
--------------------------------------------------------------------------------
torch
numpy
scikit-learn
librosa
--------------------------------------------------------------------------------
/training_xvector.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat May 30 20:22:26 2020

@author: krishna
"""



import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from SpeechDataGenerator import SpeechDataGenerator
from models.x_vector_Indian_LID import X_vector
from utils.utils import speech_collate
torch.multiprocessing.set_sharing_strategy('file_system')


########## Argument parser
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--training_filepath', type=str, default='meta/training_feat.txt')
parser.add_argument('--testing_filepath', type=str, default='meta/testing_feat.txt')
parser.add_argument('--validation_filepath', type=str, default='meta/validation_feat.txt')

parser.add_argument('--input_dim', type=int, default=257)
parser.add_argument('--num_classes', type=int, default=8)
parser.add_argument('--lamda_val', type=float, default=0.1)
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--use_gpu', type=bool, default=True)
parser.add_argument('--num_epochs', type=int, default=100)
args = parser.parse_args()

### Data related
dataset_train = SpeechDataGenerator(manifest=args.training_filepath, mode='train')
dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, collate_fn=speech_collate)

dataset_val = SpeechDataGenerator(manifest=args.validation_filepath, mode='train')
dataloader_val = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=True, collate_fn=speech_collate)


dataset_test = SpeechDataGenerator(manifest=args.testing_filepath, mode='test')
dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size, shuffle=True, collate_fn=speech_collate)

## Model related
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = X_vector(args.input_dim, args.num_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0, betas=(0.9, 0.98), eps=1e-9)
loss_fun = nn.CrossEntropyLoss()



def train(dataloader_train, epoch):
    train_loss_list = []
    full_preds = []
    full_gts = []
    model.train()
    for i_batch, sample_batched in enumerate(dataloader_train):
        features = torch.from_numpy(np.asarray([torch_tensor.numpy().T for torch_tensor in sample_batched[0]])).float()
        labels = torch.from_numpy(np.asarray([torch_tensor[0].numpy() for torch_tensor in sample_batched[1]]))
        features, labels = features.to(device), labels.to(device)
        features.requires_grad = True
        optimizer.zero_grad()
        pred_logits, x_vec = model(features)
        #### CE loss
        loss = loss_fun(pred_logits, labels)
        loss.backward()
        optimizer.step()
        train_loss_list.append(loss.item())

        predictions = np.argmax(pred_logits.detach().cpu().numpy(), axis=1)
        for pred in predictions:
            full_preds.append(pred)
        for lab in labels.detach().cpu().numpy():
            full_gts.append(lab)

    mean_acc = accuracy_score(full_gts, full_preds)
    mean_loss = np.mean(np.asarray(train_loss_list))
    print('Total training loss {} and training accuracy {} after {} epochs'.format(mean_loss, mean_acc, epoch))



def validation(dataloader_val, epoch):
    model.eval()
    with torch.no_grad():
        val_loss_list = []
        full_preds = []
        full_gts = []
        for i_batch, sample_batched in enumerate(dataloader_val):
            features = torch.from_numpy(np.asarray([torch_tensor.numpy().T for torch_tensor in sample_batched[0]])).float()
            labels = torch.from_numpy(np.asarray([torch_tensor[0].numpy() for torch_tensor in sample_batched[1]]))
            features, labels = features.to(device), labels.to(device)
            pred_logits, x_vec = model(features)
            #### CE loss
            loss = loss_fun(pred_logits, labels)
            val_loss_list.append(loss.item())
            predictions = np.argmax(pred_logits.detach().cpu().numpy(), axis=1)
            for pred in predictions:
                full_preds.append(pred)
            for lab in labels.detach().cpu().numpy():
                full_gts.append(lab)

        mean_acc = accuracy_score(full_gts, full_preds)
        mean_loss = np.mean(np.asarray(val_loss_list))
        print('Total validation loss {} and validation accuracy {} after {} epochs'.format(mean_loss, mean_acc, epoch))

        os.makedirs('save_model', exist_ok=True)
        model_save_path = os.path.join('save_model', 'best_check_point_' + str(epoch) + '_' + str(mean_loss))
        state_dict = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state_dict, model_save_path)

if __name__ == '__main__':
    for epoch in range(args.num_epochs):
        train(dataloader_train, epoch)
        validation(dataloader_val, epoch)
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
# Third Party
import librosa
import numpy as np
import random
# ===============================================
# code from Arsha for loading data.
# This code extracts features for a given audio file
# ===============================================
def load_wav(audio_filepath, sr, min_dur_sec=4):
    audio_data, fs = librosa.load(audio_filepath, sr=16000)
    len_file = len(audio_data)

    if len_file
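validation() in training_xvector.py checkpoints the model, optimizer, and epoch after every epoch under save_model/. A hedged sketch of resuming training from one of those files; the checkpoint filename here is hypothetical, since each name embeds the epoch and mean validation loss:

```python
import torch

checkpoint = torch.load('save_model/best_check_point_42_0.1532', map_location=device)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1  # continue from the following epoch
```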