├── .gitignore ├── 200420608-Multi_task_speech_classification.pdf ├── LICENSE.md ├── README.md ├── checkpoints └── resnet18.pt ├── figures ├── accent.png └── gender.png ├── main.py ├── models ├── __init__.py ├── attention.py ├── dataset.py ├── inference.py ├── loss_func.py ├── lstm.py ├── resnet.py ├── simple_cnn.py ├── test.py ├── tests │ └── test_models.py └── train.py ├── requirements.txt └── utils ├── __init__.py ├── config.py └── preprocess.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .idea/
--------------------------------------------------------------------------------
/200420608-Multi_task_speech_classification.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/karthikbhamidipati/multi-task-speech-classification/c81edc0093b6bbee76e82d21cd2c804b46d31dc0/200420608-Multi_task_speech_classification.pdf
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright 2021 [Karthik Bhamidipati](https://github.com/karthikbhamidipati)
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 |     http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # multi-task-speech-classification
2 | Multi-task speech classification of the accent and gender of an English speaker on [Mozilla's Common Voice dataset](https://www.kaggle.com/mozillaorg/common-voice).
3 | The paper can be found [here](200420608-Multi_task_speech_classification.pdf).
4 |
5 | # Run instructions
6 | 1. To `preprocess` the audio data, run
7 | ```shell
8 | python main.py preprocess -r <root_dir>
9 | ```
10 |
11 | 2. To `train` the model using the preprocessed audio data, run
12 | ```shell
13 | python main.py train -r <root_dir> -m <model_name>
14 | ```
15 | **Models Implemented:** simple_cnn, resnet18, resnet34, resnet50, simple_lstm, bi_lstm, lstm_attention, bi_lstm_attention
16 |
17 | 3. To `test` the model on the test data, run
18 | ```shell
19 | python main.py test -r <root_dir> -m <model_name> -c <checkpoint_dir>
20 | ```
21 |
22 | 4. To perform `inference` on the audio files directly, run
23 | ```shell
24 | python main.py inference -r <root_dir> -m <model_path>
25 | ```
26 |
--------------------------------------------------------------------------------
/checkpoints/resnet18.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/karthikbhamidipati/multi-task-speech-classification/c81edc0093b6bbee76e82d21cd2c804b46d31dc0/checkpoints/resnet18.pt
--------------------------------------------------------------------------------
/figures/accent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/karthikbhamidipati/multi-task-speech-classification/c81edc0093b6bbee76e82d21cd2c804b46d31dc0/figures/accent.png
--------------------------------------------------------------------------------
/figures/gender.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/karthikbhamidipati/multi-task-speech-classification/c81edc0093b6bbee76e82d21cd2c804b46d31dc0/figures/gender.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from models.inference import inference
4 | from models.test import test_model
5 | from models.train import train_model
6 | from utils.preprocess import preprocess, extract_features
7 |
8 |
9 | def is_sub_arg(arg):
10 |     key, value = arg
11 |     return value is not None and key != 'action'
12 |
13 |
14 | def clean_args(args):
15 |     action = args.action
16 |     cleaned_args = dict(filter(is_sub_arg, args._get_kwargs()))
17 |     return action, cleaned_args
18 |
19 |
20 | def main():
21 |     parser = ArgumentParser()
22 |     action_parser = parser.add_subparsers(title="actions", dest="action", required=True,
23 |                                           help="select action to execute")
24 |
25 |     # args for preprocessing
26 |     preprocess_parser = action_parser.add_parser("preprocess", help="preprocess data")
27 |     preprocess_parser.add_argument("-r", "--root-dir", dest="root_dir", required=True,
28 |                                    help="root directory of the common voice dataset")
29 |
30 |     # args for feature extraction
31 |     feature_extractor_parser = action_parser.add_parser("feature_extractor", help="feature extractor")
32 |     feature_extractor_parser.add_argument("-r", "--root-dir", dest="root_dir", required=True,
33 |                                           help="root directory of the common voice dataset")
34 |
35 |     # args for training
36 |     training_parser = action_parser.add_parser("train", help="Train the model")
37 |     training_parser.add_argument("-r", "--root-dir", dest="root_dir", required=True,
38 |                                  help="root directory of the common voice dataset")
39 |     training_parser.add_argument("-m", "--model-name", dest="model_key", required=True,
40 |                                  help="key to determine the model to be trained")
41 |
42 |     # args for testing
43 |     test_parser = action_parser.add_parser("test", help="Test the model")
44 |     test_parser.add_argument("-r", "--root-dir", dest="root_dir", required=True,
45 |                              help="root directory of the common voice dataset")
46 |     test_parser.add_argument("-m", "--model-name", dest="model_key", required=True,
47 |                              help="key to determine the model to be tested")
48 |     test_parser.add_argument("-c", "--checkpoint-dir", dest="checkpoint_path", required=True,
49 |                              help="root directory of the saved models")
50 |
51 |     # args for inference
52 |     inference_parser = action_parser.add_parser("inference", help="Run
inference on the model") 53 | inference_parser.add_argument("-r", "--root-dir", dest="root_dir", required=True, 54 | help="root directory of the audio files") 55 | inference_parser.add_argument("-m", "--model-path", dest="model_path", required=True, 56 | help="path of the model") 57 | 58 | action, args = clean_args(parser.parse_args()) 59 | 60 | if action == 'preprocess': 61 | preprocess(**args) 62 | elif action == 'feature_extractor': 63 | extract_features(**args) 64 | elif action == 'train': 65 | train_model(**args) 66 | elif action == 'test': 67 | test_model(**args) 68 | elif action == 'inference': 69 | inference(**args) 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import device, cuda 3 | 4 | np.set_printoptions(precision=2) 5 | 6 | run_device = device("cuda" if cuda.is_available() else "cpu") 7 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | from models import run_device 7 | 8 | 9 | class Attention(nn.Module): 10 | def __init__(self, hidden_size, lengths, batch_first=False): 11 | super(Attention, self).__init__() 12 | self.lengths = lengths 13 | self.hidden_size = hidden_size 14 | self.batch_first = batch_first 15 | 16 | self.att_weights = nn.Parameter(torch.Tensor(1, hidden_size), requires_grad=True) 17 | 18 | stdv = 1.0 / np.sqrt(self.hidden_size) 19 | for weight in self.att_weights: 20 | nn.init.uniform_(weight, -stdv, stdv) 21 | 22 | def forward(self, inputs): 23 | if self.batch_first: 24 | batch_size, max_len = inputs.size()[:2] 25 | else: 26 | max_len, batch_size = inputs.size()[:2] 27 | 28 | # apply attention layer 29 | weights = torch.bmm(inputs, 30 | self.att_weights # (1, hidden_size) 31 | .permute(1, 0) # (hidden_size, 1) 32 | .unsqueeze(0) # (1, hidden_size, 1) 33 | .repeat(batch_size, 1, 1) # (batch_size, hidden_size, 1) 34 | ).to(run_device) 35 | 36 | attentions = torch.softmax(F.relu(weights.squeeze()), dim=-1) 37 | 38 | # create mask based on the sentence lengths 39 | mask = torch.ones(attentions.size(), requires_grad=True).to(run_device) 40 | for i, l in enumerate(self.lengths): # skip the first sentence 41 | if l < max_len: 42 | mask[i, l:] = 0 43 | 44 | # apply mask and re-normalize attention scores (weights) 45 | masked = attentions * mask 46 | _sums = masked.sum(-1).unsqueeze(-1) # sums per row 47 | 48 | attentions = masked.div(_sums) 49 | 50 | # apply attention weights 51 | weighted = torch.mul(inputs, attentions.unsqueeze(-1).expand_as(inputs)) 52 | 53 | # get the final fixed vector representations of the sentences 54 | representations = weighted.sum(1) 55 | 56 | return representations, attentions 57 | -------------------------------------------------------------------------------- /models/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | 3 | 4 | class CommonVoiceDataset(Dataset): 5 | def __init__(self, dataset, dataset_type, feature_type): 6 | self.accent_mappings = dataset['mappings']['accent'] 7 | self.gender_mappings = dataset['mappings']['gender'] 8 | self.age_mappings = 
dataset['mappings']['age'] 9 | self.data = dataset['processed_data'][dataset_type] 10 | self.feature_type = feature_type 11 | 12 | def __len__(self): 13 | return len(self.data) 14 | 15 | def __getitem__(self, idx): 16 | features_labels_dict = self.data[idx] 17 | features = features_labels_dict[self.feature_type] 18 | gender_label = self.gender_mappings['gender2idx'][features_labels_dict['gender']] 19 | accent_label = self.accent_mappings['accent2idx'][features_labels_dict['accent']] 20 | return features, (gender_label, accent_label) 21 | 22 | 23 | def get_data_loaders(data, config): 24 | # Load train data 25 | print('Reading Train data', flush=True) 26 | train_dataset = CommonVoiceDataset(data, 'train_set', config.feature_type) 27 | train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True) 28 | 29 | print('Reading Validation data', flush=True) 30 | val_dataset = CommonVoiceDataset(data, 'val_set', config.feature_type) 31 | val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False) 32 | 33 | # Load test data 34 | print('Reading Test data', flush=True) 35 | test_dataset = CommonVoiceDataset(data, 'test_set', config.feature_type) 36 | test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False) 37 | 38 | return { 39 | 'train': train_loader, 40 | 'val': val_loader, 41 | 'test': test_loader 42 | } 43 | -------------------------------------------------------------------------------- /models/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | 4 | import torch 5 | 6 | from models import run_device 7 | from utils.config import GENDER_CLASSES, ACCENT_CLASSES, RAW_DATA_DIR 8 | from utils.preprocess import extract_audio_features 9 | 10 | 11 | class Row: 12 | def __init__(self, filename): 13 | self.filename = filename 14 | 15 | def to_dict(self): 16 | return { 17 | "filename": self.filename 18 | } 19 | 20 | 21 | def inference(root_dir, model_path): 22 | audio_features = [extract_audio_features(root_dir, Row(filename))['mfcc'] for filename in 23 | os.listdir(join(root_dir, RAW_DATA_DIR))] 24 | audio_features = torch.Tensor(audio_features).to(run_device) 25 | model = torch.load(model_path).to(run_device) 26 | 27 | model.eval() 28 | with torch.no_grad(): 29 | predictions_gender, predictions_accent = model(audio_features) 30 | predictions_gender = predictions_gender.argmax(dim=1).cpu().numpy() 31 | predictions_accent = predictions_accent.argmax(dim=1).cpu().numpy() 32 | 33 | for filename, gender, accent in zip(os.listdir(join(root_dir, RAW_DATA_DIR)), predictions_gender, 34 | predictions_accent): 35 | print("Filename: {}, Gender: {}, Accent: {}".format(filename, GENDER_CLASSES[gender], ACCENT_CLASSES[accent])) 36 | -------------------------------------------------------------------------------- /models/loss_func.py: -------------------------------------------------------------------------------- 1 | from torch import from_numpy, zeros, exp, FloatTensor 2 | from torch.nn import Module, CrossEntropyLoss, Parameter 3 | 4 | from models import run_device 5 | 6 | 7 | class MultiTaskLossWrapper(Module): 8 | def __init__(self, num_tasks, weights=None, log_vars=None): 9 | super(MultiTaskLossWrapper, self).__init__() 10 | self.task_num = num_tasks 11 | self.loss_func = [CrossEntropyLoss(weight=weights[i] if weights else None) for i in range(num_tasks)] 12 | self.log_vars = Parameter(from_numpy(log_vars) if log_vars else zeros(num_tasks)) 13 | 14 | def 
forward(self, predictions, targets): 15 | running_loss = 0.0 16 | for idx in range(self.task_num): 17 | loss = self.loss_func[idx](predictions[idx], targets[idx]) 18 | running_loss += (exp(-self.log_vars[idx]) * loss) + self.log_vars[idx] 19 | return running_loss 20 | 21 | 22 | def calculate_class_weights(config, normalized_counts): 23 | # Loss weighting proposed in https://openaccess.thecvf.com/content_CVPR_2019/html/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.html 24 | class_weights = (1 - config.beta) / (1 - (config.beta ** normalized_counts)) 25 | return FloatTensor(class_weights).to(run_device) 26 | 27 | 28 | def get_loss_function(config, features): 29 | if config.use_class_weights: 30 | gender_class_weights = calculate_class_weights(config, features['mappings']['gender']['weights']) 31 | accent_class_weights = calculate_class_weights(config, features['mappings']['accent']['weights']) 32 | loss_func = MultiTaskLossWrapper(2, (gender_class_weights, accent_class_weights)) 33 | else: 34 | loss_func = MultiTaskLossWrapper(2) 35 | 36 | return loss_func.to(run_device) -------------------------------------------------------------------------------- /models/lstm.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module, Sequential, LSTM, BatchNorm2d, Linear, Dropout, init, AdaptiveAvgPool2d 2 | 3 | from models.attention import Attention 4 | 5 | 6 | class _LSTMModel(Module): 7 | def __init__(self, config, hidden_size=512, n_layers=8, bidirectional=False, attention=False): 8 | super(_LSTMModel, self).__init__() 9 | self.attention = attention 10 | 11 | # lstm layers 12 | self.lstm = LSTM(64, hidden_size, n_layers, dropout=config.lstm_dropout, bidirectional=bidirectional) 13 | 14 | n_layers *= 2 if bidirectional else 1 15 | hidden_size *= 2 if bidirectional else 1 16 | 17 | if attention: 18 | self.att_layer = Attention(hidden_size, (256, hidden_size), batch_first=True) 19 | 20 | self.avg_pooling = AdaptiveAvgPool2d((1, hidden_size)) 21 | 22 | # fully connected output layers 23 | self.gender_out = Sequential( 24 | Dropout(config.fc_dropout), 25 | Linear(hidden_size, 3) 26 | ) 27 | 28 | self.accent_out = Sequential( 29 | Dropout(config.fc_dropout), 30 | Linear(hidden_size, 16) 31 | ) 32 | 33 | # initialise the network's weights 34 | self.init_weights() 35 | 36 | def forward(self, x): 37 | # pass through lstm layers 38 | x, _ = self.lstm(x) 39 | 40 | if self.attention: 41 | x, _ = self.att_layer(x) 42 | else: 43 | # reshape the input into 4D format 44 | x = x.unsqueeze(1) 45 | x = self.avg_pooling(x) 46 | 47 | # flatten the features 48 | x = x.view(x.shape[0], -1) 49 | 50 | # pass through the multiple output layers 51 | gender = self.gender_out(x) 52 | accent = self.accent_out(x) 53 | 54 | return [gender, accent] 55 | 56 | def init_weights(self, layer=None): 57 | # Method to recursively initialize network weights 58 | # linear and conv layers, weights are initialized using xavier normal initialization 59 | # batch norm will have it's weights initialized with 1 and bias with 0 60 | if not layer or type(layer) == Sequential: 61 | children = layer.children() if layer else self.children() 62 | for module in children: 63 | self.init_weights(module) 64 | elif type(layer) == Linear: 65 | init.xavier_normal_(layer.weight) 66 | if layer.bias is not None: 67 | init.constant_(layer.bias, 0.001) 68 | elif type(layer) == BatchNorm2d: 69 | init.constant_(layer.weight, 1) 70 | init.constant_(layer.bias, 0) 71 | 72 | 73 | 
def simple_lstm(config): 74 | return _LSTMModel(config=config) 75 | 76 | 77 | def bi_lstm(config): 78 | return _LSTMModel(config=config, bidirectional=True) 79 | 80 | 81 | def lstm_attention(config): 82 | return _LSTMModel(config=config, bidirectional=False, attention=True) 83 | 84 | 85 | def bi_lstm_attention(config): 86 | return _LSTMModel(config=config, bidirectional=True, attention=True) 87 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module, Sequential, Conv2d, BatchNorm2d, ReLU, MaxPool2d, Linear, Dropout, init, AdaptiveAvgPool2d, \ 2 | GroupNorm, DataParallel 3 | from torchvision.models.resnet import BasicBlock, Bottleneck, conv1x1 4 | 5 | 6 | class _ResNet(Module): 7 | def __init__(self, config, layers, block, inplanes=64, groups=1, norm_layer=BatchNorm2d): 8 | super(_ResNet, self).__init__() 9 | 10 | # type checks 11 | if type(layers) == list and len(layers) != 4 and all(map(lambda x: isinstance(x, int), layers)): 12 | raise ValueError("layers should be a list of ints with size 4") 13 | elif block not in (BasicBlock, Bottleneck): 14 | raise ValueError("invalid block, possible values: ") 15 | 16 | # constants 17 | self.inplanes = inplanes 18 | self.base_width = inplanes 19 | self.groups = groups 20 | self.norm_layer = norm_layer 21 | self.linear_units = 512 if block == BasicBlock else 2048 22 | 23 | # Initial convolutional layer 24 | self.conv = Sequential( 25 | Conv2d(1, self.inplanes, kernel_size=(7, 5), stride=(2, 1), padding=(3, 2), bias=False), 26 | self.norm_layer(self.inplanes), 27 | ReLU(inplace=True), 28 | MaxPool2d(kernel_size=(5, 3), stride=(2, 1), padding=(2, 1)) 29 | ) 30 | 31 | # Residual blocks 32 | self.res1 = self._make_layer(block, 64, layers[0]) 33 | self.res2 = self._make_layer(block, 128, layers[1], stride=2) 34 | self.res3 = self._make_layer(block, 256, layers[2], stride=2) 35 | self.res4 = self._make_layer(block, 512, layers[3], stride=2) 36 | 37 | self.avg_pool = AdaptiveAvgPool2d((1, 1)) 38 | 39 | # fully connected output layers 40 | self.gender_out = Sequential( 41 | Dropout(config.fc_dropout), 42 | Linear(self.linear_units, 3) 43 | ) 44 | 45 | self.accent_out = Sequential( 46 | Dropout(config.fc_dropout), 47 | Linear(self.linear_units, 16) 48 | ) 49 | 50 | # initialise the network's weights 51 | self._init_weights() 52 | 53 | def forward(self, x): 54 | # reshape the input into 4D format by adding an empty dimension at axis 1 for channels 55 | # (batch_size, channels, time_len, frequency_len) 56 | x = x.unsqueeze(1) 57 | 58 | # pass through convolutional layer 59 | x = self.conv(x) 60 | 61 | # pass through residual blocks, and pooling layers 62 | x = self.res1(x) 63 | x = self.res2(x) 64 | x = self.res3(x) 65 | x = self.res4(x) 66 | x = self.avg_pool(x) 67 | 68 | # flatten the features 69 | x = x.view(x.shape[0], -1) 70 | 71 | # pass through the multiple output layers 72 | gender = self.gender_out(x) 73 | accent = self.accent_out(x) 74 | 75 | return [gender, accent] 76 | 77 | def _init_weights(self): 78 | # Method to iteratively initialize network weights 79 | # conv layers, weights are initialized using kaiming normal initialization 80 | # linear layers, weights are initialized using xavier normal initialization 81 | # batch norm will have it's weights initialized with 1 and bias with 0 82 | for m in self.modules(): 83 | if isinstance(m, Conv2d): 84 | init.kaiming_normal_(m.weight, 
mode='fan_out', nonlinearity='relu') 85 | elif isinstance(m, Linear): 86 | init.xavier_normal_(m.weight) 87 | elif isinstance(m, (BatchNorm2d, GroupNorm)): 88 | init.constant_(m.weight, 1) 89 | init.constant_(m.bias, 0) 90 | elif isinstance(m, Bottleneck): 91 | init.constant_(m.bn3.weight, 0) 92 | elif isinstance(m, BasicBlock): 93 | init.constant_(m.bn2.weight, 0) 94 | 95 | def _make_layer(self, block, planes, blocks, stride=1): 96 | norm_layer = self.norm_layer 97 | downsample = None 98 | 99 | if stride != 1 or self.inplanes != planes * block.expansion: 100 | downsample = Sequential( 101 | conv1x1(self.inplanes, planes * block.expansion, stride), 102 | norm_layer(planes * block.expansion), 103 | ) 104 | 105 | layers = [block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, 1, norm_layer)] 106 | 107 | self.inplanes = planes * block.expansion 108 | 109 | for _ in range(1, blocks): 110 | layers.append( 111 | block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, norm_layer=norm_layer)) 112 | 113 | return Sequential(*layers) 114 | 115 | 116 | def resnet18(config): 117 | return DataParallel(_ResNet(config, [2, 2, 2, 2], block=BasicBlock)) 118 | 119 | 120 | def resnet34(config): 121 | return DataParallel(_ResNet(config, [3, 4, 6, 3], block=BasicBlock)) 122 | 123 | 124 | def resnet50(config): 125 | return DataParallel(_ResNet(config, [3, 4, 6, 3], block=Bottleneck)) 126 | 127 | 128 | def resnet101(config): 129 | return DataParallel(_ResNet(config, [3, 4, 23, 3], block=Bottleneck)) -------------------------------------------------------------------------------- /models/simple_cnn.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module, Sequential, Conv2d, BatchNorm2d, ReLU, MaxPool2d, Linear, Dropout, init, DataParallel 2 | 3 | 4 | class _SimpleCNNModel(Module): 5 | def __init__(self, config): 6 | super(_SimpleCNNModel, self).__init__() 7 | 8 | # convolutional layers 9 | self.conv1 = Sequential( 10 | Conv2d(1, 64, kernel_size=(5, 5), padding=(2, 2)), 11 | BatchNorm2d(64), 12 | ReLU(), 13 | MaxPool2d(kernel_size=(8, 4)) 14 | ) 15 | 16 | self.conv2 = Sequential( 17 | Conv2d(64, 128, kernel_size=(5, 5), padding=(2, 2)), 18 | BatchNorm2d(128), 19 | ReLU(), 20 | MaxPool2d(kernel_size=(4, 2)) 21 | ) 22 | 23 | self.conv3 = Sequential( 24 | Conv2d(128, 128, kernel_size=(3, 3), padding=(1, 1)), 25 | BatchNorm2d(128), 26 | ReLU(), 27 | MaxPool2d(kernel_size=(2, 2)) 28 | ) 29 | 30 | # fully connected output layers 31 | self.gender_out = Sequential( 32 | Linear(2048, 256), 33 | ReLU(), 34 | Dropout(config.fc_dropout), 35 | Linear(256, 3) 36 | ) 37 | 38 | self.accent_out = Sequential( 39 | Linear(2048, 1024), 40 | ReLU(), 41 | Linear(1024, 512), 42 | ReLU(), 43 | Linear(512, 512), 44 | ReLU(), 45 | Linear(512, 256), 46 | ReLU(), 47 | Linear(256, 16) 48 | ) 49 | 50 | # initialise the network's weights 51 | self.init_weights() 52 | 53 | def forward(self, x): 54 | # reshape the input into 4D format by adding an empty dimension at axis 1 for channels 55 | # (batch_size, channels, time_len, frequency_len) 56 | x = x.unsqueeze(1) 57 | 58 | # pass through convolutional layers 59 | x = self.conv1(x) 60 | x = self.conv2(x) 61 | x = self.conv3(x) 62 | 63 | # flatten the features 64 | x = x.view(x.shape[0], -1) 65 | 66 | # pass through the multiple output layers 67 | gender = self.gender_out(x) 68 | accent = self.accent_out(x) 69 | 70 | return [gender, accent] 71 | 72 | def init_weights(self, layer=None): 73 | # 
Method to recursively initialize network weights
74 |         # linear and conv layers, weights are initialized using xavier normal initialization
75 |         # batch norm will have its weights initialized with 1 and bias with 0
76 |         if not layer or type(layer) == Sequential:
77 |             children = layer.children() if layer else self.children()
78 |             for module in children:
79 |                 self.init_weights(module)
80 |         elif type(layer) in (Conv2d, Linear):
81 |             init.xavier_normal_(layer.weight)
82 |             if layer.bias is not None:
83 |                 init.constant_(layer.bias, 0.001)
84 |         elif type(layer) == BatchNorm2d:
85 |             init.constant_(layer.weight, 1)
86 |             init.constant_(layer.bias, 0)
87 |
88 |
89 | def simple_cnn(config):
90 |     return DataParallel(_SimpleCNNModel(config=config))
91 |
--------------------------------------------------------------------------------
/models/test.py:
--------------------------------------------------------------------------------
1 | from _pickle import load
2 | from math import ceil
3 | from os.path import join
4 |
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import torch
8 | from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
9 | from tqdm import tqdm
10 |
11 | from models import run_device
12 | from models.dataset import get_data_loaders
13 | from models.loss_func import get_loss_function
14 | from models.tests.test_models import Config
15 | from models.train import compute_metrics
16 | from utils.config import FEATURES_PATH, MODELS, HYPER_PARAMETERS
17 |
18 |
19 | def test(checkpoint_path, model_key, test_loader, criterion):
20 |     test_loss, test_metrics = 0.0, np.zeros(2)
21 |
22 |     dataset_size = len(test_loader.dataset)
23 |     num_test_iters = ceil(dataset_size / test_loader.batch_size)
24 |     ground_truth_accent, predictions_accent = [], []
25 |     ground_truth_gender, predictions_gender = [], []
26 |
27 |     model = torch.load(join(checkpoint_path, model_key + ".pt"))
28 |     model.eval()
29 |
30 |     for data, labels in tqdm(test_loader):
31 |         # move data, labels to run_device
32 |         data = data.to(run_device)
33 |
34 |         ground_truth_gender.extend(labels[0].tolist())
35 |         ground_truth_accent.extend(labels[1].tolist())
36 |
37 |         labels = [label.to(run_device) for label in labels]
38 |
39 |         # forward pass without grad to calculate the test loss
40 |         with torch.no_grad():
41 |             outputs = model(data)
42 |
43 |         predictions_gender.extend(outputs[0].argmax(dim=1).cpu().tolist())
44 |         predictions_accent.extend(outputs[1].argmax(dim=1).cpu().tolist())
45 |
46 |         loss = criterion(outputs, labels)
47 |
48 |         # accumulate the test loss and metrics
49 |         test_loss += loss.item()
50 |         test_metrics += compute_metrics(outputs, labels)
51 |
52 |     print("Test Loss: {}, Test Metrics: {}"
53 |           .format(test_loss / num_test_iters, test_metrics / num_test_iters))
54 |
55 |     return np.array(ground_truth_gender), np.array(predictions_gender), np.array(ground_truth_accent), np.array(predictions_accent)
56 |
57 |
58 | def test_model(root_dir, model_key, checkpoint_path):
59 |     features = load(open(join(root_dir, FEATURES_PATH), 'rb'))
60 |     _, model_func = MODELS[model_key]
61 |     config = Config(HYPER_PARAMETERS)
62 |
63 |     data_loaders = get_data_loaders(features, config)
64 |     loss_func = get_loss_function(config, features)
65 |
66 |     ground_truth_gender, predictions_gender, ground_truth_accent, predictions_accent = test(checkpoint_path, model_key,
67 |                                                                                             data_loaders['test'],
68 |                                                                                             loss_func)
69 |     plot_confusion_matrix(ground_truth_gender, predictions_gender,
70 |
list(features['mappings']['gender']['idx2gender'].values()), 71 | "gender") 72 | plot_confusion_matrix(ground_truth_accent, predictions_accent, 73 | list(features['mappings']['accent']['idx2accent'].values()), 74 | "accent") 75 | 76 | 77 | def plot_confusion_matrix(ground_truth, predictions, display_labels, fig_name): 78 | cm = confusion_matrix(ground_truth, predictions) 79 | fig, ax = plt.subplots(figsize=(15, 15)) 80 | 81 | display = ConfusionMatrixDisplay(confusion_matrix=cm, 82 | display_labels=display_labels) 83 | display.plot(ax=ax, xticks_rotation='vertical') 84 | plt.savefig("figures/" + fig_name + ".png") 85 | -------------------------------------------------------------------------------- /models/tests/test_models.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from parameterized import parameterized 5 | 6 | from models import run_device 7 | from utils.config import HYPER_PARAMETERS, MODELS 8 | 9 | 10 | class Config(object): 11 | def __init__(self, d): 12 | self.__dict__ = d 13 | 14 | 15 | def predict(model, sample_data): 16 | model.eval() 17 | with torch.no_grad(): 18 | return model(sample_data) 19 | 20 | 21 | class TestModels(unittest.TestCase): 22 | 23 | @classmethod 24 | def setUpClass(cls): 25 | cls.sample_data = torch.randn(1, 256, 64).to(run_device) 26 | cls.config = Config(HYPER_PARAMETERS) 27 | 28 | @parameterized.expand(MODELS.keys()) 29 | def test(self, name): 30 | model = MODELS[name][1](self.config).to(run_device) 31 | gender, accent = predict(model, self.sample_data) 32 | self.assertEqual(gender.shape, torch.Size([1, 3]), "Gender output shape not equal") 33 | self.assertEqual(accent.shape, torch.Size([1, 16]), "Accent output shape not equal") 34 | del model 35 | 36 | 37 | if __name__ == '__main__': 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /models/train.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from os.path import join 3 | from pickle import load 4 | 5 | import numpy as np 6 | import wandb 7 | from sklearn.metrics import f1_score 8 | 9 | import torch 10 | from torch.nn import DataParallel 11 | from torch.optim import Adam 12 | 13 | from models import run_device 14 | from models.dataset import get_data_loaders 15 | from models.loss_func import get_loss_function 16 | from utils.config import FEATURES_PATH, PROJECT_NAME, HYPER_PARAMETERS, MODELS 17 | 18 | 19 | def start_wandb_session(run_name): 20 | wandb.init(project=PROJECT_NAME, config=HYPER_PARAMETERS) 21 | wandb.run.name = run_name 22 | return wandb.config 23 | 24 | 25 | def train_model(root_dir, model_key): 26 | # get features and init wandb session 27 | features = load(open(join(root_dir, FEATURES_PATH), 'rb')) 28 | run_name, model_func = MODELS[model_key] 29 | config = start_wandb_session(run_name) 30 | 31 | # get data_loaders, loss_func, model, optimizer 32 | data_loaders = get_data_loaders(features, config) 33 | loss_func = get_loss_function(config, features) 34 | model = model_func(config).to(run_device) 35 | optimizer = Adam(model.parameters(), lr=config.learning_rate) 36 | 37 | # set wandb to watch model parameters 38 | wandb.watch(model, log='all') 39 | model_save_path = join(wandb.run.dir, "model-" + wandb.run.id + ".pt") 40 | 41 | # begin training 42 | _train(config.num_epochs, data_loaders, model, optimizer, loss_func, model_save_path) 43 | 44 | # finish wandb session 45 | 
wandb.finish() 46 | 47 | 48 | # helper for computing metrics 49 | def compute_metrics(predictions, targets): 50 | metrics = [] 51 | for prediction, target in zip(predictions, targets): 52 | prediction = prediction.argmax(dim=1).cpu().numpy() 53 | target = target.cpu().numpy() 54 | metrics.append(f1_score(target, prediction, average='macro')) 55 | return metrics 56 | 57 | 58 | def _train(num_epochs, loaders, model, optimizer, criterion, save_path, min_loss=np.Inf): 59 | """returns trained model""" 60 | # initialize tracker for minimum validation loss 61 | val_loss_min = min_loss 62 | num_train_iters = ceil(len(loaders['train'].dataset) / loaders['train'].batch_size) 63 | num_val_iters = ceil(len(loaders['val'].dataset) / loaders['val'].batch_size) 64 | 65 | for epoch in range(1, num_epochs + 1): 66 | # initialize variables to monitor training and validation loss 67 | train_loss, val_loss = 0.0, 0.0 68 | train_metrics, val_metrics = np.zeros(2), np.zeros(2) 69 | 70 | # training the model 71 | model.train() 72 | for data, labels in loaders['train']: 73 | # move data, labels to run_device 74 | data = data.to(run_device) 75 | labels = [label.to(run_device) for label in labels] 76 | 77 | # forward pass, backward pass and update weights 78 | optimizer.zero_grad() 79 | outputs = model(data) 80 | loss = criterion(outputs, labels) 81 | loss.backward() 82 | optimizer.step() 83 | 84 | # calculate training loss and metrics 85 | train_loss += loss.item() 86 | train_metrics += compute_metrics(outputs, labels) 87 | 88 | # evaluating the model 89 | model.eval() 90 | for data, labels in loaders['val']: 91 | # move data, labels to run_device 92 | data = data.to(run_device) 93 | labels = [label.to(run_device) for label in labels] 94 | 95 | # forward pass without grad to calculate the validation loss 96 | with torch.no_grad(): 97 | outputs = model(data) 98 | loss = criterion(outputs, labels) 99 | 100 | # calculate validation loss 101 | val_loss += loss.item() 102 | val_metrics += compute_metrics(outputs, labels) 103 | 104 | # compute average loss and accuracy 105 | train_loss /= num_train_iters 106 | val_loss /= num_val_iters 107 | train_metrics *= 100 / num_train_iters 108 | val_metrics *= 100 / num_val_iters 109 | 110 | # logging metrics to wandb 111 | wandb.log({'epoch': epoch, 'train_loss': train_loss, 'train_fscore_gender': train_metrics[0], 112 | 'train_fscore_accent': train_metrics[1], 'val_loss': val_loss, 'val_fscore_gender': val_metrics[0], 113 | 'val_fscore_accent': val_metrics[1]}) 114 | 115 | # print training & validation statistics 116 | print( 117 | "Epoch: {}\tTraining Loss: {:.6f}\tTraining F-score: {}\tValidation Loss: {:.6f}\tValidation F-score: {}".format( 118 | epoch, train_loss, train_metrics, val_loss, val_metrics)) 119 | 120 | # saving the model when validation loss decreases 121 | if val_loss <= val_loss_min: 122 | print("Validation loss decreased ({:.6f} --> {:.6f}). 
Saving model ...".format(val_loss_min, val_loss))
123 |             torch.save(model.module if isinstance(model, DataParallel) else model, save_path)
124 |             val_loss_min = val_loss
125 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas~=1.2.3
2 | librosa~=0.8.0
3 | joblib>=1.2.0
4 | tqdm~=4.56.0
5 | scikit-learn~=0.24.1
6 | numpy>=1.22.0
7 | wandb~=0.10.26
8 | parameterized~=0.8.1
9 | --find-links https://download.pytorch.org/whl/torch_stable.html
10 | torch==2.2.0
11 | torchvision==0.17.0
12 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | # filtering warnings while preprocessing
4 | warnings.filterwarnings('ignore')
5 |
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | from os.path import join
3 |
4 | from models.lstm import simple_lstm, bi_lstm, lstm_attention, bi_lstm_attention
5 | from models.resnet import resnet18, resnet34, resnet50, resnet101
6 | from models.simple_cnn import simple_cnn
7 |
8 | # Config for reading csv files
9 | RAW_DATA_DIR = "raw_data/"
10 | PRE_PROCESSED_DATA_DIR = "pre_processed_data/"
11 | ANNOTATED_FILE_NAME = join(PRE_PROCESSED_DATA_DIR, "common_voice_annotated.csv")
12 | FILTERED_FILE_NAME = join(PRE_PROCESSED_DATA_DIR, "common_voice_annotated_filtered.csv")
13 |
14 | # Config for multiprocessing
15 | NUM_CORES = int(max((multiprocessing.cpu_count() / 2) - 1, 1))
16 |
17 | # Audio Extraction configs
18 | SAMPLING_RATE = 22050
19 | MAX_TIME_LEN = 256
20 |
21 | FEATURE_ARGS = {
22 |     "sr": SAMPLING_RATE,
23 |     "n_fft": 2048,
24 |     "win_length": 1024,
25 |     "hop_length": 512,
26 | }
27 | EXTRACTOR_ARGS = {
28 |     "n_mels": 64,
29 |     "n_mfcc": 64
30 | }
31 |
32 | FEATURES_PATH = join(PRE_PROCESSED_DATA_DIR, "common_voice_features.pkl")
33 |
34 | # wandb args
35 | PROJECT_NAME = "multitask-audio-classification"
36 | HYPER_PARAMETERS = {
37 |     'beta': 0.001,  # loss weighting parameter
38 |     'num_epochs': 50,  # number of epochs
39 |     'batch_size': 256,  # batch size
40 |     'learning_rate': 0.001,  # learning rate
41 |     'feature_type': 'mfcc',  # feature_type (mfcc or mel-spectrogram)
42 |     'use_class_weights': False,  # use class weights for loss
43 |     'fc_dropout': 0.1,  # dropout for fully connected layers
44 |     'lstm_dropout': 0.2  # dropout for lstm layers
45 | }
46 |
47 | # Model args
48 | MODELS = {
49 |     "simple_cnn": ("CNN Baseline", simple_cnn),
50 |     "resnet18": ("CNN ResNet18", resnet18),
51 |     "resnet34": ("CNN ResNet34", resnet34),
52 |     "resnet50": ("CNN ResNet50", resnet50),
53 |     "resnet101": ("CNN ResNet101", resnet101),
54 |     "simple_lstm": ("LSTM Baseline", simple_lstm),
55 |     "bi_lstm": ("Bidirectional LSTM", bi_lstm),
56 |     "lstm_attention": ("LSTM with attention", lstm_attention),
57 |     "bi_lstm_attention": ("Bidirectional LSTM with Attention", bi_lstm_attention)
58 | }
59 |
60 | GENDER_CLASSES = ['female', 'male', 'other']
61 | ACCENT_CLASSES = ['african', 'australia', 'bermuda', 'canada', 'england', 'hongkong', 'indian', 'ireland', 'malaysia',
62 |                   'newzealand', 'philippines', 'scotland', 'singapore', 'southatlandtic', 'us', 'wales']
63 |
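# Illustrative usage sketch (hypothetical, not used anywhere in this codebase): MODELS maps
# the CLI model key (main.py --model-name) to a (wandb run name, constructor) pair, and each
# constructor expects an object that exposes the HYPER_PARAMETERS entries as attributes
# (see the small Config wrapper in models/tests/test_models.py), e.g.:
#
#     run_name, model_func = MODELS["resnet18"]      # ("CNN ResNet18", resnet18)
#     model = model_func(Config(HYPER_PARAMETERS))   # a DataParallel-wrapped _ResNet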
-------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from os.path import join 3 | from pickle import dump 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from joblib import Parallel, delayed 8 | from librosa import get_duration, load 9 | from librosa.feature import melspectrogram, mfcc 10 | from sklearn.model_selection import train_test_split 11 | from tqdm import tqdm 12 | 13 | from utils.config import RAW_DATA_DIR, NUM_CORES, PRE_PROCESSED_DATA_DIR, FILTERED_FILE_NAME, ANNOTATED_FILE_NAME, \ 14 | MAX_TIME_LEN, FEATURE_ARGS, EXTRACTOR_ARGS, FEATURES_PATH 15 | 16 | 17 | def _compute_audio_duration(parent_dir, filename): 18 | filepath = parent_dir + filename 19 | return get_duration(filename=filepath) 20 | 21 | 22 | def _get_mappings(df, col_name): 23 | idx2col = dict(enumerate(df[col_name].cat.categories)) 24 | col2idx = dict((v, k) for k, v in idx2col.items()) 25 | return col2idx, idx2col 26 | 27 | 28 | def _get_class_weights(df, col_name, lookup): 29 | weights = np.zeros(len(lookup)) 30 | value_counts = df[col_name].value_counts(normalize=True).to_dict() 31 | for key, value in lookup.items(): 32 | weights[value] = value_counts[key] 33 | return weights 34 | 35 | 36 | def _clean_features(features): 37 | features = features.transpose(1, 0) 38 | if len(features) < MAX_TIME_LEN: 39 | features = np.pad(features, ((MAX_TIME_LEN - len(features), 0), (0, 0))) 40 | return features[:MAX_TIME_LEN] 41 | 42 | 43 | def extract_audio_features(root_dir, row): 44 | raw_data_dir = join(root_dir, RAW_DATA_DIR) 45 | row_dict = row.to_dict() 46 | waveform, _ = load(raw_data_dir + row_dict['filename'], sr=FEATURE_ARGS['sr']) 47 | row_dict['melspec'] = _clean_features(melspectrogram(waveform, n_mels=EXTRACTOR_ARGS['n_mels'], **FEATURE_ARGS)) 48 | row_dict['mfcc'] = _clean_features(mfcc(waveform, n_mfcc=EXTRACTOR_ARGS['n_mfcc'], **FEATURE_ARGS)) 49 | return row_dict 50 | 51 | 52 | def _get_features(root_dir, df): 53 | job = Parallel(n_jobs=NUM_CORES) 54 | return job(delayed(extract_audio_features)(root_dir, row) 55 | for index, row in tqdm(df.iterrows(), total=df.shape[0])) 56 | 57 | 58 | def preprocess_csv(root_dir): 59 | # getting the full path of raw features dir 60 | raw_data_dir = join(root_dir, RAW_DATA_DIR) 61 | 62 | # reading the csv files 63 | common_voice_df = pd.concat(map(pd.read_csv, glob(raw_data_dir + "*.csv"))) 64 | print('Total number of records: {}'.format(common_voice_df.shape[0]), flush=True) 65 | 66 | # filtering rows and columns that are not required 67 | common_voice_df = common_voice_df[common_voice_df.accent.notna() & 68 | common_voice_df.gender.notna() & 69 | common_voice_df.age.notna()].reset_index(drop=True) 70 | 71 | common_voice_df = common_voice_df[["filename", "age", "gender", "accent"]] 72 | print('After removing empty/null rows: {}'.format(common_voice_df.shape[0]), flush=True) 73 | 74 | # calculating and adding duration to the dataset 75 | job = Parallel(n_jobs=NUM_CORES) 76 | durations = job(delayed(_compute_audio_duration)(raw_data_dir, filename) 77 | for filename in tqdm(common_voice_df.filename)) 78 | 79 | common_voice_df['duration'] = durations 80 | 81 | # saving the dataframe 82 | common_voice_df.to_csv(join(root_dir, ANNOTATED_FILE_NAME), index=False) 83 | 84 | # filtering the dataframe based on audio duration 85 | common_voice_df = common_voice_df[(common_voice_df.duration >= 2.0) & 86 | 
(common_voice_df.duration <= 5.0)].reset_index(drop=True) 87 | print('After filtering using duration length: {}'.format(common_voice_df.shape[0]), flush=True) 88 | 89 | # saving the filtered dataframe 90 | common_voice_df.to_csv(join(root_dir, FILTERED_FILE_NAME), index=False) 91 | 92 | 93 | def extract_features(root_dir): 94 | # reading the filtered dataframe 95 | common_voice_df = pd.read_csv(join(root_dir, FILTERED_FILE_NAME)) 96 | common_voice_df = common_voice_df.astype( 97 | {'filename': 'string', 'age': 'category', 'gender': 'category', 'accent': 'category'}) 98 | 99 | # Train test split 100 | common_voice_train_df, common_voice_test_df = train_test_split(common_voice_df, test_size=0.2, random_state=0, 101 | stratify=common_voice_df[['gender', 'accent']]) 102 | 103 | # Train validation split 104 | common_voice_train_df, common_voice_val_df = train_test_split(common_voice_train_df, test_size=0.2, random_state=0, 105 | stratify=common_voice_train_df[['gender', 'accent']]) 106 | 107 | print("Train Shape: {}, Validation Shape: {}, Test Shape: {}".format(common_voice_train_df.shape, 108 | common_voice_val_df.shape, 109 | common_voice_test_df.shape)) 110 | 111 | age2idx, idx2age = _get_mappings(common_voice_df, 'age') 112 | gender2idx, idx2gender = _get_mappings(common_voice_df, 'gender') 113 | accent2idx, idx2accent = _get_mappings(common_voice_df, 'accent') 114 | 115 | age_weights = _get_class_weights(common_voice_train_df, 'age', age2idx) 116 | gender_weights = _get_class_weights(common_voice_train_df, 'gender', gender2idx) 117 | accent_weights = _get_class_weights(common_voice_train_df, 'accent', accent2idx) 118 | 119 | print("Extracting train features", flush=True) 120 | train_features = _get_features(root_dir, common_voice_train_df) 121 | 122 | print("Extracting val features", flush=True) 123 | val_features = _get_features(root_dir, common_voice_val_df) 124 | 125 | print("Extracting test features", flush=True) 126 | test_features = _get_features(root_dir, common_voice_test_df) 127 | 128 | features = { 129 | "mappings": { 130 | "accent": { 131 | "accent2idx": accent2idx, 132 | "idx2accent": idx2accent, 133 | "weights": accent_weights 134 | }, 135 | "age": { 136 | "age2idx": age2idx, 137 | "idx2age": idx2age, 138 | "weights": age_weights 139 | }, 140 | "gender": { 141 | "gender2idx": gender2idx, 142 | "idx2gender": idx2gender, 143 | "weights": gender_weights 144 | } 145 | }, 146 | "processed_data": { 147 | "train_set": train_features, 148 | "val_set": val_features, 149 | "test_set": test_features 150 | } 151 | } 152 | 153 | print("Saving the features...", flush=True) 154 | dump(features, open(join(root_dir, FEATURES_PATH), 'wb')) 155 | 156 | 157 | def preprocess(root_dir): 158 | preprocess_csv(root_dir) 159 | extract_features(root_dir) 160 | --------------------------------------------------------------------------------
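As a final sanity check of the feature layout that ties `utils/preprocess.py` to the models, the short sketch below (a hypothetical snippet, not part of the repository; it assumes the project root is on `PYTHONPATH` and the requirements are installed) verifies that `_clean_features` pads or truncates librosa's `(n_coefficients, n_frames)` output to the `(MAX_TIME_LEN, 64)` layout that the models and `models/tests/test_models.py` expect.

```python
# Hypothetical sanity-check snippet (not part of the repository).
import numpy as np

from utils.config import MAX_TIME_LEN, EXTRACTOR_ARGS
from utils.preprocess import _clean_features

# librosa returns features as (n_coefficients, n_frames); simulate a short and a long clip
short_clip = np.random.randn(EXTRACTOR_ARGS["n_mfcc"], 100)  # fewer frames than MAX_TIME_LEN
long_clip = np.random.randn(EXTRACTOR_ARGS["n_mfcc"], 400)   # more frames than MAX_TIME_LEN

# short clips are zero-padded at the front, long clips are truncated to MAX_TIME_LEN frames
assert _clean_features(short_clip).shape == (MAX_TIME_LEN, EXTRACTOR_ARGS["n_mfcc"])
assert _clean_features(long_clip).shape == (MAX_TIME_LEN, EXTRACTOR_ARGS["n_mfcc"])
```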