├── utils ├── __init__.py ├── schedulers.py ├── optimizers.py ├── metrics.py ├── losses.py ├── utils.py └── visualize.py ├── assests ├── test.mp3 ├── test.wav ├── test2.wav ├── sed_result.png └── noises │ └── voices.wav ├── requirements.txt ├── models ├── __init__.py └── cnn14.py ├── LICENSE ├── datasets ├── __init__.py ├── urbansound.py ├── esc50.py ├── fsdkaggle.py ├── speechcommands.py ├── aug_test.ipynb ├── transforms.py └── audioset.py ├── configs ├── audioset.yaml ├── audioset_sed.yaml ├── esc50.yaml ├── fsd2018.yaml └── scv1.yaml ├── tools ├── val.py ├── infer.py ├── sed_infer.py └── train.py ├── .gitignore └── README.md /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assests/test.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/audio-tagging/HEAD/assests/test.mp3 -------------------------------------------------------------------------------- /assests/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/audio-tagging/HEAD/assests/test.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/audio-tagging/HEAD/requirements.txt -------------------------------------------------------------------------------- /assests/test2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/audio-tagging/HEAD/assests/test2.wav -------------------------------------------------------------------------------- /assests/sed_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/audio-tagging/HEAD/assests/sed_result.png -------------------------------------------------------------------------------- /assests/noises/voices.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/audio-tagging/HEAD/assests/noises/voices.wav -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .cnn14 import CNN14, CNN14DecisionLevelMax 2 | 3 | __all__ = { 4 | "cnn14": CNN14, 5 | "cnn14decisionlevelmax": CNN14DecisionLevelMax 6 | } 7 | 8 | 9 | def get_model(model_name: str, num_classes: int = 527): 10 | assert model_name in __all__.keys(), f"Unavailable model name >> {model_name}.\nList of available model names: {list(__all__.keys())}" 11 | return __all__[model_name](num_classes) -------------------------------------------------------------------------------- /utils/schedulers.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import StepLR, MultiStepLR 2 | 3 | 4 | __all__ = { 5 | "steplr": StepLR 6 | } 7 | 8 | def get_scheduler(cfg, optimizer): 9 | scheduler_name = cfg['SCHEDULER']['NAME'] 10 | assert scheduler_name in __all__.keys(), f"Unavailable scheduler name >> {scheduler_name}.\nList of available schedulers: {list(__all__.keys())}" 11 | return __all__[scheduler_name](optimizer, *cfg['SCHEDULER']['PARAMS']) 12 | 13 | 14 
| if __name__ == '__main__': 15 | import torch 16 | model = torch.nn.Linear(1024, 50) 17 | optimizer = torch.optim.SGD(model.parameters(), 0.01) 18 | scheduler = MultiStepLR(optimizer, [10, 20, 30], gamma=0.5) -------------------------------------------------------------------------------- /utils/optimizers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.optim import AdamW, SGD 3 | 4 | 5 | def get_optimizer(model: nn.Module, optimizer: str, lr: float, weight_decay: float = 0.01): 6 | wd_params, nwd_params = [], [] 7 | for p in model.parameters(): 8 | if p.dim() == 1: 9 | nwd_params.append(p) 10 | else: 11 | wd_params.append(p) 12 | 13 | params = [ 14 | {"params": wd_params}, 15 | {"params": nwd_params, "weight_decay": 0} 16 | ] 17 | 18 | if optimizer == 'adamw': 19 | return AdamW(params, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=weight_decay) 20 | else: 21 | return SGD(params, lr, momentum=0.9, weight_decay=weight_decay) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 sithu3 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from torch import distributed as dist 2 | from torch.utils.data import DistributedSampler, SequentialSampler, RandomSampler 3 | from .esc50 import ESC50 4 | from .audioset import AudioSet 5 | from .fsdkaggle import FSDKaggle2018 6 | from .urbansound import UrbanSound8k 7 | from .speechcommands import SpeechCommandsv1 8 | from .speechcommands import SpeechCommandsv2 9 | 10 | 11 | __all__ = { 12 | "esc50": ESC50, 13 | "audioset": AudioSet, 14 | "fsdkaggle2018": FSDKaggle2018, 15 | "urbansound8k": UrbanSound8k, 16 | "speechcommandsv1": SpeechCommandsv1, 17 | "speechcommandsv2": SpeechCommandsv2 18 | } 19 | 20 | 21 | def get_train_dataset(dataset_cfg, aug_cfg, transform, spec_transform): 22 | dataset_name = dataset_cfg['NAME'] 23 | assert dataset_name in __all__.keys(), f"Unavailable dataset name >> {dataset_name}.\nList of available datasets: {list(__all__.keys())}" 24 | return __all__[dataset_name]('train', dataset_cfg, aug_cfg, transform, spec_transform) 25 | 26 | def get_val_dataset(dataset_cfg): 27 | dataset_name = dataset_cfg['NAME'] 28 | assert dataset_name in __all__.keys(), f"Unavailable dataset name >> {dataset_name}.\nList of available datasets: {list(__all__.keys())}" 29 | return __all__[dataset_name]('val', dataset_cfg) 30 | 31 | 32 | def get_sampler(ddp_enable, train_dataset, val_dataset): 33 | if not ddp_enable: 34 | train_sampler = RandomSampler(train_dataset) 35 | else: 36 | train_sampler = DistributedSampler(train_dataset, dist.get_world_size(), dist.get_rank(), shuffle=True) 37 | val_sampler = SequentialSampler(val_dataset) 38 | return train_sampler, val_sampler -------------------------------------------------------------------------------- /configs/audioset.yaml: -------------------------------------------------------------------------------- 1 | DEVICE: cpu # device used for training 2 | 3 | MODEL: 4 | NAME: cnn14 # name of the model you are using 5 | PRETRAINED: '' 6 | 7 | DATASET: 8 | NAME: audioset # dataset name 9 | ROOT: '' # dataset root path 10 | METRIC: mAP 11 | SAMPLE_RATE: 32000 12 | AUDIO_LENGTH: 5 13 | WIN_LENGTH: 1024 14 | HOP_LENGTH: 320 15 | N_MELS: 64 16 | FMIN: 50 17 | FMAX: 14000 18 | 19 | AUG: 20 | MIXUP: 0.0 21 | MIXUP_ALPHA: 10 22 | SMOOTHING: 0.1 23 | TIME_MASK: 96 24 | FREQ_MASK: 24 25 | 26 | TRAIN: 27 | EPOCHS: 100 # number of epochs to train 28 | EVAL_INTERVAL: 10 # interval to evaluate the model during training 29 | BATCH_SIZE: 16 # batch size used to train 30 | LOSS: bcelogits # loss function name (ce, bce, bcelogits, label_smooth, soft_target) 31 | AMP: true # use Automatic Mixed Precision training or not 32 | DDP: false 33 | SAVE_DIR: 'output' # output folder name used for saving the trained model and logs 34 | 35 | OPTIMIZER: 36 | NAME: adamw 37 | LR: 0.0001 # initial learning rate used in optimizer 38 | WEIGHT_DECAY: 0.001 # decay rate use in optimizer 39 | 40 | SCHEDULER: 41 | NAME: steplr 42 | PARAMS: [30, 0.1] 43 | 44 | 45 | TEST: 46 | MODE: file # inference mode (file, mic) 47 | FILE: 'assests/test.wav' # audio file name (not use if you choose MODE=mic) 48 | MODEL_PATH: 'checkpoints/cnn14.pth' # trained model path 49 | TOPK: 5 -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from 
torch import Tensor 4 | from sklearn import metrics as skmetrics 5 | from scipy import stats 6 | 7 | 8 | class Metrics: 9 | def __init__(self, metric='accuracy') -> None: 10 | assert metric in ['accuracy', 'mAP', 'map'] 11 | self.metric = metric 12 | self.count = 0 13 | self.acc = 0.0 14 | self.preds = [] 15 | self.targets = [] 16 | 17 | def _update_accuracy(self, pred: Tensor, target: Tensor): 18 | self.acc += (pred.argmax(dim=-1) == target.argmax(dim=-1)).sum(dim=0).item() 19 | self.count += target.shape[0] 20 | 21 | def _update_mAP(self, pred: Tensor, target: Tensor): 22 | self.preds.append(pred) 23 | self.targets.append(target) 24 | 25 | def _compute_accuracy(self): 26 | return [(self.acc / self.count) * 100] 27 | 28 | def _compute_mAP(self): 29 | preds = torch.cat(self.preds, dim=0).cpu().numpy() 30 | targets = torch.cat(self.targets, dim=0).cpu().numpy() 31 | ap = skmetrics.average_precision_score(targets, preds, average=None) 32 | auc = skmetrics.roc_auc_score(targets, preds, average=None) 33 | mAP = ap.mean() 34 | mAUC = auc.mean() 35 | d_prime = stats.norm().ppf(mAUC) * math.sqrt(2.0) 36 | return mAP * 100, mAUC * 100, d_prime * 100 37 | 38 | def update(self, pred: Tensor, target: Tensor): 39 | if self.metric == 'accuracy': 40 | self._update_accuracy(pred, target) 41 | else: 42 | self._update_mAP(pred, target) 43 | 44 | def compute(self): 45 | if self.metric == 'accuracy': 46 | return self._compute_accuracy() 47 | else: 48 | return self._compute_mAP() -------------------------------------------------------------------------------- /configs/audioset_sed.yaml: -------------------------------------------------------------------------------- 1 | DEVICE: cpu # device used for training 2 | 3 | MODEL: 4 | NAME: cnn14decisionlevelmax # name of the model you are using 5 | PRETRAINED: '' 6 | 7 | DATASET: 8 | NAME: audioset # dataset name 9 | ROOT: '' # dataset root path 10 | METRIC: mAP 11 | SAMPLE_RATE: 32000 12 | AUDIO_LENGTH: 5 13 | WIN_LENGTH: 1024 14 | HOP_LENGTH: 320 15 | N_MELS: 64 16 | FMIN: 50 17 | FMAX: 14000 18 | 19 | AUG: 20 | MIXUP: 0.0 21 | MIXUP_ALPHA: 10 22 | SMOOTHING: 0.1 23 | TIME_MASK: 96 24 | FREQ_MASK: 24 25 | 26 | TRAIN: 27 | EPOCHS: 100 # number of epochs to train 28 | EVAL_INTERVAL: 10 # interval to evaluate the model during training 29 | BATCH_SIZE: 16 # batch size used to train 30 | LOSS: bcelogits # loss function name (ce, bce, bcelogits, label_smooth, soft_target) 31 | AMP: true # use Automatic Mixed Precision training or not 32 | DDP: false 33 | SAVE_DIR: 'output' # output folder name used for saving the trained model and logs 34 | 35 | OPTIMIZER: 36 | NAME: adamw 37 | LR: 0.0001 # initial learning rate used in optimizer 38 | WEIGHT_DECAY: 0.001 # decay rate use in optimizer 39 | 40 | SCHEDULER: 41 | NAME: steplr 42 | PARAMS: [30, 0.1] 43 | 44 | 45 | TEST: 46 | MODE: file # inference mode (file, mic) 47 | FILE: 'assests/test.wav' # audio file name (not use if you choose MODE=mic) 48 | MODEL_PATH: 'checkpoints/cnn14_decisionlevelmax.pth' # trained model path 49 | THRESHOLD: 0.2 50 | PLOT: false -------------------------------------------------------------------------------- /configs/esc50.yaml: -------------------------------------------------------------------------------- 1 | DEVICE: cpu # device used for training 2 | 3 | MODEL: 4 | NAME: cnn14 # name of the model you are using 5 | PRETRAINED: 'checkpoints/cnn14.pth' 6 | 7 | DATASET: 8 | NAME: esc50 # dataset name 9 | ROOT: 'C:/Users/sithu/Documents/Datasets/ESC50' # dataset root path 10 | METRIC: accuracy 11 
| SOURCE_SAMPLE: 44100 12 | SAMPLE_RATE: 32000 13 | AUDIO_LENGTH: 5 14 | WIN_LENGTH: 1024 15 | HOP_LENGTH: 320 16 | N_MELS: 64 17 | FMIN: 50 18 | FMAX: 14000 19 | 20 | AUG: 21 | MIXUP: 0.0 22 | MIXUP_ALPHA: 10 23 | SMOOTHING: 0.1 24 | FREQ_MASK: 24 25 | TIME_MASK: 96 26 | 27 | TRAIN: 28 | EPOCHS: 100 # number of epochs to train 29 | EVAL_INTERVAL: 5 # interval to evaluate the model during training 30 | BATCH_SIZE: 16 # batch size used to train 31 | LOSS: label_smooth # loss function name (ce, bce, bcelogits, label_smooth, soft_target) 32 | AMP: false # use Automatic Mixed Precision training or not 33 | DDP: false 34 | SAVE_DIR: output # output folder name used for saving the trained model and logs 35 | 36 | OPTIMIZER: 37 | NAME: adamw 38 | LR: 0.0001 # initial learning rate used in optimizer 39 | WEIGHT_DECAY: 0.001 # decay rate use in optimizer 40 | 41 | SCHEDULER: 42 | NAME: steplr 43 | PARAMS: [20, 0.5] 44 | 45 | 46 | TEST: 47 | MODE: file # inference mode (file, mic) 48 | FILE: 'assests/test2.wav' # audio file name (not use if you choose MODE=mic) 49 | MODEL_PATH: 'checkpoints/cnn14_esc50.pth' # trained model path 50 | TOPK: 5 -------------------------------------------------------------------------------- /configs/fsd2018.yaml: -------------------------------------------------------------------------------- 1 | DEVICE: cuda # device used for training 2 | 3 | MODEL: 4 | NAME: cnn14 # name of the model you are using 5 | PRETRAINED: 'checkpoints/cnn14.pth' 6 | 7 | DATASET: 8 | NAME: fsdkaggle2018 # dataset name 9 | ROOT: 'C:/Users/sithu/Documents/Datasets/FSDKaggle2018' # dataset root path 10 | METRIC: accuracy 11 | SOURCE_SAMPLE: 44100 12 | SAMPLE_RATE: 32000 13 | AUDIO_LENGTH: 5 14 | WIN_LENGTH: 1024 15 | HOP_LENGTH: 320 16 | N_MELS: 64 17 | FMIN: 50 18 | FMAX: 14000 19 | 20 | AUG: 21 | MIXUP: 0.0 22 | MIXUP_ALPHA: 10 23 | SMOOTHING: 0.1 24 | FREQ_MASK: 24 25 | TIME_MASK: 96 26 | 27 | TRAIN: 28 | EPOCHS: 100 # number of epochs to train 29 | EVAL_INTERVAL: 10 # interval to evaluate the model during training 30 | BATCH_SIZE: 16 # batch size used to train 31 | LOSS: label_smooth # loss function name (ce, bce, bcelogits, label_smooth, soft_target) 32 | AMP: true # use Automatic Mixed Precision training or not 33 | DDP: false 34 | SAVE_DIR: output # output folder name used for saving the trained model and logs 35 | 36 | OPTIMIZER: 37 | NAME: adamw 38 | LR: 0.0001 # initial learning rate used in optimizer 39 | WEIGHT_DECAY: 0.001 # decay rate use in optimizer 40 | 41 | SCHEDULER: 42 | NAME: steplr 43 | PARAMS: [30, 0.1] 44 | 45 | 46 | TEST: 47 | MODE: file # inference mode (file, mic) 48 | FILE: 'assests/test.wav' # audio file name (not use if you choose MODE=mic) 49 | MODEL_PATH: 'checkpoints/cnn14_fsdkaggle2018.pth' # trained model path 50 | TOPK: 5 -------------------------------------------------------------------------------- /configs/scv1.yaml: -------------------------------------------------------------------------------- 1 | DEVICE: cpu # device used for training 2 | 3 | MODEL: 4 | NAME: cnn14 # name of the model you are using 5 | PRETRAINED: 'checkpoints/cnn14.pth' 6 | 7 | DATASET: 8 | NAME: speechcommandsv1 # dataset name 9 | ROOT: 'C:/Users/sithu/Documents/Datasets/SpeechCommandsv1' # dataset root path 10 | METRIC: accuracy 11 | SOURCE_SAMPLE: 16000 12 | SAMPLE_RATE: 32000 13 | AUDIO_LENGTH: 1 14 | WIN_LENGTH: 1024 15 | HOP_LENGTH: 320 16 | N_MELS: 64 17 | FMIN: 50 18 | FMAX: 14000 19 | 20 | AUG: 21 | MIXUP: 0.0 22 | MIXUP_ALPHA: 10 23 | SMOOTHING: 0.1 24 | FREQ_MASK: 24 25 | 
TIME_MASK: 96 26 | 27 | TRAIN: 28 | EPOCHS: 100 # number of epochs to train 29 | EVAL_INTERVAL: 10 # interval to evaluate the model during training 30 | BATCH_SIZE: 16 # batch size used to train 31 | LOSS: label_smooth # loss function name (ce, bce, bcelogits, label_smooth, soft_target) 32 | AMP: true # use Automatic Mixed Precision training or not 33 | DDP: false 34 | SAVE_DIR: output # output folder name used for saving the trained model and logs 35 | 36 | OPTIMIZER: 37 | NAME: adamw 38 | LR: 0.0001 # initial learning rate used in optimizer 39 | WEIGHT_DECAY: 0.001 # decay rate use in optimizer 40 | 41 | SCHEDULER: 42 | NAME: steplr 43 | PARAMS: [30, 0.1] 44 | 45 | 46 | TEST: 47 | MODE: file # inference mode (file, mic) 48 | FILE: 'assests/test.wav' # audio file name (not use if you choose MODE=mic) 49 | MODEL_PATH: 'checkpoints/cnn14_speechcommandsv1.pth' # trained model path 50 | TOPK: 5 -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import CrossEntropyLoss, BCELoss, BCEWithLogitsLoss 4 | 5 | 6 | class CrossEntropy(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.ce = CrossEntropyLoss() 10 | 11 | def forward(self, pred: Tensor, target: Tensor) -> Tensor: 12 | target = target.argmax(dim=1) 13 | loss = self.ce(pred, target) 14 | return loss 15 | 16 | 17 | class LabelSmoothCrossEntropy(nn.Module): 18 | def __init__(self, smoothing=0.1): 19 | super().__init__() 20 | assert smoothing < 1.0 21 | self.smoothing = smoothing 22 | self.confidence = 1. - smoothing 23 | self.log_softmax = nn.LogSoftmax(dim=-1) 24 | 25 | def forward(self, pred: Tensor, target: Tensor) -> Tensor: 26 | pred = self.log_softmax(pred) 27 | target = target.argmax(dim=1) 28 | nll_loss = -pred.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1) 29 | smooth_loss = -pred.mean(dim=-1) 30 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 31 | return loss.mean() 32 | 33 | 34 | class SoftTargetCrossEntropy(nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | self.log_softmax = nn.LogSoftmax(dim=-1) 38 | 39 | def forward(self, pred: Tensor, target: Tensor) -> Tensor: 40 | pred = self.log_softmax(pred) 41 | loss = (-target * pred).sum(dim=-1) 42 | return loss.mean() 43 | 44 | 45 | __all__ = { 46 | # for audio classification, the following can be used 47 | 'ce': CrossEntropy, 48 | 'label_smooth': LabelSmoothCrossEntropy, 49 | 50 | # for audio tagging and sed, the following can be used 51 | 'bce': BCELoss, 52 | 'bcelogits': BCEWithLogitsLoss, 53 | 'soft_target': SoftTargetCrossEntropy 54 | } 55 | 56 | 57 | def get_loss(loss_fn_name: str): 58 | assert loss_fn_name in __all__.keys(), f"Unavailable loss function name >> {loss_fn_name}.\nList of available loss functions: {list(__all__.keys())}" 59 | return __all__[loss_fn_name]() 60 | 61 | 62 | if __name__ == '__main__': 63 | import torch 64 | from torch.nn import functional as F 65 | torch.manual_seed(123) 66 | B = 2 67 | C = 10 68 | x = torch.rand(B, C) 69 | y = torch.rand(B, C) 70 | loss_fn = CrossEntropy() 71 | loss = loss_fn(x, y) 72 | print(loss) 73 | -------------------------------------------------------------------------------- /tools/val.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import torch 4 | import multiprocessing as mp 5 | from pathlib import 
Path 6 | from pprint import pprint 7 | from tabulate import tabulate 8 | from tqdm import tqdm 9 | from torch.utils.data import DataLoader 10 | 11 | import sys 12 | sys.path.insert(0, '.') 13 | from models import get_model 14 | from datasets import get_val_dataset 15 | from utils.utils import setup_cudnn 16 | from utils.metrics import Metrics 17 | 18 | 19 | @torch.no_grad() 20 | def evaluate(dataloader, model, device, metric='accuracy'): 21 | print('Evaluating...') 22 | model.eval() 23 | ametrics = Metrics(metric) 24 | 25 | for audio, target in tqdm(dataloader, ): 26 | audio = audio.to(device) 27 | target = target.to(device) 28 | pred = model(audio) 29 | ametrics.update(pred, target) 30 | 31 | return ametrics.compute() 32 | 33 | 34 | def main(cfg): 35 | save_dir = Path(cfg['TRAIN']['SAVE_DIR']) 36 | device = torch.device(cfg['DEVICE']) 37 | metric_name = cfg['DATASET']['METRIC'] 38 | num_workers = mp.cpu_count() 39 | 40 | dataset = get_val_dataset(cfg['DATASET']) 41 | dataloader = DataLoader(dataset, batch_size=cfg['TRAIN']['BATCH_SIZE'], num_workers=num_workers, pin_memory=True) 42 | model = get_model(cfg['MODEL']['NAME'], dataset.num_classes) 43 | 44 | try: 45 | model_weights = save_dir / f"{cfg['MODEL']['NAME']}_{cfg['DATASET']['NAME']}.pth" 46 | model.load_state_dict(torch.load(str(model_weights), map_location='cpu')) 47 | print(f"Loading Model and trained weights from {model_weights}") 48 | except: 49 | print(f"Please consider placing your model's weights in {save_dir}") 50 | 51 | model = model.to(device) 52 | 53 | if metric_name == 'accuracy': 54 | acc = evaluate(dataloader, model, device, metric_name)[-1] 55 | table = [['Accuracy', f"{acc:.2f}"]] 56 | else: 57 | mAP, mAUC, d_prime = evaluate(dataloader, model, device, metric_name) 58 | table = [ 59 | ['mAP', f"{mAP:.2f}"], 60 | ['AUC', f"{mAUC:.2f}"], 61 | ['d-prime', f"{d_prime:.2f}"] 62 | ] 63 | print(tabulate(table, numalign='right')) 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--cfg', type=str, required=True, help='Experiment configuration file name') 69 | args = parser.parse_args() 70 | 71 | with open(args.cfg) as f: 72 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 73 | 74 | pprint(cfg) 75 | setup_cudnn() 76 | main(cfg) -------------------------------------------------------------------------------- /tools/infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import torchaudio 4 | import yaml 5 | import numpy as np 6 | from torch import Tensor 7 | from tabulate import tabulate 8 | from pathlib import Path 9 | from torchaudio import transforms as T 10 | from torchaudio import functional as F 11 | 12 | import sys 13 | sys.path.insert(0, '.') 14 | from models import get_model 15 | from datasets import __all__ 16 | from utils.utils import time_sync 17 | 18 | 19 | class AudioTagging: 20 | def __init__(self, cfg) -> None: 21 | self.device = torch.device(cfg['DEVICE']) 22 | self.labels = np.array(__all__[cfg['DATASET']['NAME']].CLASSES) 23 | self.model = get_model(cfg['MODEL']['NAME'], len(self.labels)) 24 | self.model.load_state_dict(torch.load(cfg['TEST']['MODEL_PATH'], map_location='cpu')) 25 | self.model = self.model.to(self.device) 26 | self.model.eval() 27 | 28 | self.topk = cfg['TEST']['TOPK'] 29 | self.sample_rate = cfg['DATASET']['SAMPLE_RATE'] 30 | self.mel_tf = T.MelSpectrogram(self.sample_rate, cfg['DATASET']['WIN_LENGTH'], cfg['DATASET']['WIN_LENGTH'], cfg['DATASET']['HOP_LENGTH'], 
cfg['DATASET']['FMIN'], cfg['DATASET']['FMAX'], n_mels=cfg['DATASET']['N_MELS'], norm='slaney') 31 | 32 | def preprocess(self, file: str) -> Tensor: 33 | audio, sr = torchaudio.load(file) 34 | if sr != self.sample_rate: audio = F.resample(audio, sr, self.sample_rate) 35 | audio = self.mel_tf(audio) 36 | audio = 10.0 * audio.clamp_(1e-10).log10() 37 | audio = audio.unsqueeze(0) 38 | audio = audio.to(self.device) 39 | return audio 40 | 41 | def postprocess(self, prob: Tensor) -> str: 42 | probs, indices = torch.topk(prob.sigmoid().squeeze().cpu(), self.topk) 43 | return self.labels[indices], probs 44 | 45 | @torch.no_grad() 46 | def model_forward(self, audio: Tensor) -> Tensor: 47 | start = time_sync() 48 | pred = self.model(audio) 49 | end = time_sync() 50 | print(f"Model Inference Time: {(end-start)*1000:.2f}ms") 51 | return pred 52 | 53 | def predict(self, file: str) -> str: 54 | audio = self.preprocess(file) 55 | pred = self.model_forward(audio) 56 | labels, probs = self.postprocess(pred) 57 | return labels, probs 58 | 59 | 60 | if __name__ == '__main__': 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('--cfg', type=str, default='configs/audioset.yaml') 63 | args = parser.parse_args() 64 | 65 | with open(args.cfg) as f: 66 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 67 | 68 | file_path = Path(cfg['TEST']['FILE']) 69 | model = AudioTagging(cfg) 70 | 71 | if cfg['TEST']['MODE'] == 'file': 72 | if file_path.is_file(): 73 | labels, probs = model.predict(str(file_path)) 74 | print(tabulate({"Class": labels, "Confidence": probs}, headers='keys')) 75 | else: 76 | raise NotImplementedError 77 | else: 78 | raise NotImplementedError -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import time 5 | import os 6 | from pathlib import Path 7 | from torch.backends import cudnn 8 | from torch import nn, Tensor 9 | from torch.autograd import profiler 10 | from typing import Union 11 | from torch import distributed as dist 12 | 13 | 14 | def fix_seeds(seed: int = 123) -> None: 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | np.random.seed(seed) 18 | random.seed(seed) 19 | 20 | def setup_cudnn() -> None: 21 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 22 | cudnn.benchmark = True 23 | cudnn.deterministic = False 24 | 25 | def time_sync() -> float: 26 | if torch.cuda.is_available(): 27 | torch.cuda.synchronize() 28 | return time.time() 29 | 30 | def get_model_size(model: Union[nn.Module, torch.jit.ScriptModule]): 31 | tmp_model_path = Path('temp.p') 32 | if isinstance(model, torch.jit.ScriptModule): 33 | torch.jit.save(model, tmp_model_path) 34 | else: 35 | torch.save(model.state_dict(), tmp_model_path) 36 | size = tmp_model_path.stat().st_size 37 | os.remove(tmp_model_path) 38 | return size / 1e6 # in MB 39 | 40 | @torch.no_grad() 41 | def test_model_latency(model: nn.Module, inputs: torch.Tensor, use_cuda: bool = False) -> float: 42 | with profiler.profile(use_cuda=use_cuda) as prof: 43 | _ = model(inputs) 44 | return prof.self_cpu_time_total / 1000 # ms 45 | 46 | def count_parameters(model: nn.Module) -> float: 47 | return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 # in M 48 | 49 | def setup_ddp() -> int: 50 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 51 | rank = int(os.environ['RANK']) 52 | 
world_size = int(os.environ['WORLD_SIZE']) 53 | gpu = int(os.environ['LOCAL_RANK']) 54 | torch.cuda.set_device(gpu) 55 | dist.init_process_group('nccl', init_method="env://", world_size=world_size, rank=rank) 56 | dist.barrier() 57 | else: 58 | gpu = 0 59 | return gpu 60 | 61 | def cleanup_ddp(): 62 | if dist.is_initialized(): 63 | dist.destroy_process_group() 64 | 65 | def reduce_tensor(tensor: Tensor) -> Tensor: 66 | rt = tensor.clone() 67 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 68 | rt /= dist.get_world_size() 69 | return rt 70 | 71 | @torch.no_grad() 72 | def throughput(dataloader, model: nn.Module, times: int = 30): 73 | model.eval() 74 | images, _ = next(iter(dataloader)) 75 | images = images.cuda(non_blocking=True) 76 | B = images.shape[0] 77 | print(f"Throughput averaged with {times} times") 78 | start = time_sync() 79 | for _ in range(times): 80 | model(images) 81 | end = time_sync() 82 | 83 | print(f"Batch Size {B} throughput {times * B / (end - start)} images/s") 84 | 85 | 86 | def get_dataset_norm(dataloader): 87 | mean, std = 0.0, 0.0 88 | 89 | for audio, _ in dataloader: 90 | mean += audio.mean().item() * audio.shape[0] 91 | std += audio.std().item() * audio.shape[0] 92 | return mean / len(dataloader.dataset), std / len(dataloader.dataset) -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from torch import Tensor 6 | from torchaudio import functional as AF 7 | 8 | 9 | def plot_waveform(waveform: Tensor, sample_rate: int): 10 | waveform = waveform.numpy() 11 | num_channels, num_frames = waveform.shape 12 | time_axis = torch.arange(0, num_frames) / sample_rate 13 | 14 | fig, axes = plt.subplots(num_channels, 1) 15 | if num_channels == 1: axes = [axes] 16 | 17 | for c in range(num_channels): 18 | axes[c].plot(time_axis, waveform[c], linewidth=1) 19 | axes[c].grid(True) 20 | if num_channels > 1: axes[c].set_ylabel(f"Channel {c+1}") 21 | 22 | fig.suptitle("Waveform") 23 | plt.show(block=False) 24 | 25 | def plot_specgram(waveform: Tensor, sample_rate: int): 26 | waveform = waveform.numpy() 27 | num_channels, _ = waveform.shape 28 | 29 | fig, axes = plt.subplots(num_channels, 1) 30 | if num_channels == 1: axes = [axes] 31 | 32 | for c in range(num_channels): 33 | axes[c].specgram(waveform[c], Fs=sample_rate) 34 | if num_channels > 1: axes[c].set_ylabel(f"Channel {c+1}") 35 | 36 | fig.suptitle("Spectrogram") 37 | plt.show(block=False) 38 | 39 | def plot_spectrogram(spec: Tensor): 40 | fig, ax = plt.subplots(1, 1) 41 | ax.set_title("Spectrogram (db)") 42 | ax.set_ylabel('freq_bin') 43 | ax.set_xlabel('frame') 44 | im = ax.imshow(AF.amplitude_to_DB(spec[0], 10, 1e-10, np.log10(max(spec.max(), 1e-10))).numpy(), origin='lower', aspect='auto') 45 | fig.colorbar(im, ax=ax) 46 | plt.show(block=False) 47 | 48 | def plot_mel_fbank(fbank: Tensor): 49 | plt.imshow(fbank.numpy(), aspect='auto') 50 | plt.xlabel('mel_bin') 51 | plt.ylabel('freq_bin') 52 | plt.title('Filter bank') 53 | plt.show(block=False) 54 | 55 | def plot_pitch(waveform: Tensor, sample_rate: int, pitch: Tensor): 56 | num_channels, num_frames = waveform.shape 57 | pitch_channels, pitch_frames = pitch.shape 58 | time_axis = torch.linspace(0, num_frames / sample_rate, num_frames) 59 | pitch_axis = torch.linspace(0, num_frames / sample_rate, pitch_frames) 60 | 61 | plt.plot(time_axis,
waveform[0], linewidth=1, color='gray', alpha=0.3, label='Waveform') 62 | plt.plot(pitch_axis, pitch[0], linewidth=2, color='green', label='Pitch') 63 | plt.title("Pitch Feature") 64 | plt.grid(True) 65 | plt.legend(loc=0) 66 | plt.show(block=False) 67 | 68 | 69 | def play_audio(waveform: Tensor, sample_rate: int): 70 | from IPython.display import Audio, display 71 | waveform = waveform.numpy() 72 | num_channels, _ = waveform.shape 73 | 74 | if num_channels == 1: 75 | display(Audio(waveform[0], rate=sample_rate)) 76 | elif num_channels == 2: 77 | display(Audio((waveform[0], waveform[1]), rate=sample_rate)) 78 | else: 79 | raise ValueError("Waveforms with more than 2 channels are not supported.") 80 | 81 | def plot_sound_events(results: Tensor, labels: list, fps: int = 100): 82 | fig, ax = plt.subplots(1, 1, figsize=(10, 4)) 83 | ax.matshow(results, origin='upper', aspect='auto', cmap='jet', vmin=0, vmax=1) 84 | ax.set_title('Sound Event Detection') 85 | ax.set_xlabel('Seconds') 86 | ax.set_xticks(np.arange(0, results.shape[1], fps)) 87 | ax.set_xticklabels(np.arange(0, results.shape[1]/fps)) 88 | ax.set_yticks(np.arange(0, results.shape[0])) 89 | ax.set_yticklabels(labels) 90 | ax.yaxis.grid(color='k', linestyle='solid', linewidth=0.3, alpha=0.3) 91 | ax.xaxis.set_ticks_position('bottom') 92 | 93 | plt.tight_layout() 94 | plt.show() 95 | # plt.savefig('result.jpg') -------------------------------------------------------------------------------- /tools/sed_infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import torchaudio 4 | import yaml 5 | import numpy as np 6 | from torch import Tensor 7 | from tabulate import tabulate 8 | from pathlib import Path 9 | from torchaudio import transforms as T 10 | from torchaudio import functional as F 11 | 12 | import sys 13 | sys.path.insert(0, '.') 14 | from models import get_model 15 | from datasets import __all__ 16 | from utils.utils import time_sync 17 | from utils.visualize import plot_sound_events 18 | 19 | 20 | class SED: 21 | def __init__(self, cfg) -> None: 22 | self.device = torch.device(cfg['DEVICE']) 23 | self.labels = np.array(__all__[cfg['DATASET']['NAME']].CLASSES) 24 | self.model = get_model(cfg['MODEL']['NAME'], len(self.labels)) 25 | self.model.load_state_dict(torch.load(cfg['TEST']['MODEL_PATH'], map_location='cpu')) 26 | self.model = self.model.to(self.device) 27 | self.model.eval() 28 | 29 | self.threshold = cfg['TEST']['THRESHOLD'] 30 | self.sample_rate = cfg['DATASET']['SAMPLE_RATE'] 31 | self.mel_tf = T.MelSpectrogram(self.sample_rate, cfg['DATASET']['WIN_LENGTH'], cfg['DATASET']['WIN_LENGTH'], cfg['DATASET']['HOP_LENGTH'], cfg['DATASET']['FMIN'], cfg['DATASET']['FMAX'], n_mels=cfg['DATASET']['N_MELS'], norm='slaney') 32 | 33 | def preprocess(self, file: str) -> Tensor: 34 | audio, sr = torchaudio.load(file) 35 | if sr != self.sample_rate: audio = F.resample(audio, sr, self.sample_rate) 36 | audio = self.mel_tf(audio) 37 | audio = 10.0 * audio.clamp_(1e-10).log10() 38 | audio = audio.unsqueeze(0) 39 | audio = audio.to(self.device) 40 | return audio 41 | 42 | def postprocess(self, prob: Tensor) -> str: 43 | probs, indices = torch.sort(prob.max(dim=0)[0], descending=True) 44 | indices = indices[probs > self.threshold] 45 | # probs, indices = torch.topk(prob.max(dim=0)[0], topk) 46 | top_results = prob[:, indices].t() 47 | 48 | starts = [] 49 | ends = [] 50 | 51 | for result in top_results: 52 | index = torch.where(result > self.threshold)[0] 53
| starts.append(round(index[0].item()/100, 1)) 54 | ends.append(round(index[-1].item()/100, 1)) 55 | 56 | return top_results, self.labels[indices], starts, ends 57 | 58 | 59 | @torch.no_grad() 60 | def model_forward(self, audio: Tensor) -> Tensor: 61 | start = time_sync() 62 | pred = self.model(audio)[0].squeeze().cpu() 63 | end = time_sync() 64 | print(f"Model Inference Time: {(end-start)*1000:.2f}ms") 65 | return pred 66 | 67 | def predict(self, file: str) -> str: 68 | audio = self.preprocess(file) 69 | pred = self.model_forward(audio) 70 | results, labels, starts, ends = self.postprocess(pred) 71 | return results, labels, starts, ends 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument('--cfg', type=str, default='configs/audioset_sed.yaml') 77 | args = parser.parse_args() 78 | 79 | with open(args.cfg) as f: 80 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 81 | 82 | file_path = Path(cfg['TEST']['FILE']) 83 | model = SED(cfg) 84 | 85 | if cfg['TEST']['MODE'] == 'file': 86 | if file_path.is_file(): 87 | results, labels, starts, ends = model.predict(str(file_path)) 88 | print(tabulate({"Class": labels, "Start": starts, "End": ends}, headers='keys')) 89 | if cfg['TEST']['PLOT']: 90 | plot_sound_events(results, labels) 91 | else: 92 | raise NotImplementedError 93 | else: 94 | raise NotImplementedError -------------------------------------------------------------------------------- /datasets/urbansound.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import torchaudio 4 | from pathlib import Path 5 | from torch import Tensor 6 | from torchaudio import transforms as T 7 | from torchaudio import functional as AF 8 | from torch.nn import functional as F 9 | from torch.utils.data import Dataset, DataLoader 10 | from typing import Tuple 11 | from .transforms import mixup_augment 12 | 13 | 14 | class UrbanSound8k(Dataset): 15 | CLASSES = ['air conditioner', 'car horn', 'children playing', 'dog bark', 'drilling', 'engine idling', 'gun shot', 'jackhammer', 'siren', 'street music'] 16 | 17 | def __init__(self, split, data_cfg, mixup_cfg=None, transform=None, spec_transform=None) -> None: 18 | super().__init__() 19 | assert split in ['train', 'val'] 20 | self.num_classes = len(self.CLASSES) 21 | self.transform = transform 22 | self.spec_transform = spec_transform 23 | self.mixup = mixup_cfg['MIXUP'] if mixup_cfg is not None else 0.0 24 | self.mixup_alpha = mixup_cfg['MIXUP_ALPHA'] if mixup_cfg is not None else 0.0 25 | self.label_smooth = mixup_cfg['SMOOTHING'] if mixup_cfg is not None else 0.0 26 | self.sample_rate = data_cfg['SAMPLE_RATE'] 27 | self.num_frames = self.sample_rate * data_cfg['AUDIO_LENGTH'] 28 | 29 | val_fold = 10 30 | 31 | self.mel_tf = T.MelSpectrogram(self.sample_rate, data_cfg['WIN_LENGTH'], data_cfg['WIN_LENGTH'], data_cfg['HOP_LENGTH'], data_cfg['FMIN'], data_cfg['FMAX'], n_mels=data_cfg['N_MELS'], norm='slaney') # using mel_scale='slaney' is better 32 | self.data, self.targets = self.get_data(data_cfg['ROOT'], split, val_fold) 33 | print(f"Found {len(self.data)} {split} audios in {data_cfg['ROOT']}.") 34 | 35 | 36 | def get_data(self, root: str, split: int, fold: int): 37 | root = Path(root) 38 | files = (root / 'audio').rglob('*.wav') 39 | files = list(filter(lambda x: not str(x.parent).endswith(f"fold{fold}") if split == 'train' else str(x.parent).endswith(f"fold{fold}"), files)) 40 | targets = list(map(lambda x: int(x.stem.split('-', maxsplit=2)[1]), 
files)) 41 | assert len(files) == len(targets) 42 | return files, targets 43 | 44 | 45 | def __len__(self) -> int: 46 | return len(self.data) 47 | 48 | 49 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 50 | audio, sr = torchaudio.load(self.data[index]) 51 | if audio.shape[0] != 1: audio = audio[:1] # reduce to mono 52 | audio = AF.resample(audio, sr, self.sample_rate) # resample the audio 53 | if audio.shape[1] < self.num_frames: audio = torch.cat([audio, torch.zeros(1, self.num_frames-audio.shape[1])], dim=-1) 54 | target = torch.tensor(self.targets[index]) 55 | 56 | if self.transform: audio = self.transform(audio) 57 | 58 | if random.random() < self.mixup: 59 | next_index = random.randint(0, len(self.data)-1) 60 | next_audio, sr = torchaudio.load(self.data[next_index]) 61 | if next_audio.shape[0] != 1: next_audio = next_audio[:1] # reduce to mono 62 | next_audio = AF.resample(next_audio, sr, self.sample_rate) 63 | if next_audio.shape[1] < self.num_frames: next_audio = torch.cat([next_audio, torch.zeros(1, self.num_frames-next_audio.shape[1])], dim=-1) 64 | next_target = torch.tensor(self.targets[next_index]) 65 | audio, target = mixup_augment(audio, target, next_audio, next_target, self.mixup_alpha, self.num_classes, self.label_smooth) 66 | else: 67 | target = F.one_hot(target, self.num_classes).float() 68 | 69 | audio = self.mel_tf(audio) # convert to mel spectrogram 70 | audio = 10.0 * audio.clamp_(1e-10).log10() # convert to log mel spectrogram 71 | 72 | if self.spec_transform: audio = self.spec_transform(audio) 73 | 74 | return audio, target 75 | 76 | 77 | if __name__ == '__main__': 78 | data_cfg = { 79 | 'ROOT': 'C:/Users/sithu/Documents/Datasets/Urbansound8K', 80 | 'SAMPLE_RATE': 32000, 81 | 'AUDIO_LENGTH': 4, 82 | 'WIN_LENGTH': 1024, 83 | 'HOP_LENGTH': 320, 84 | 'N_MELS': 64, 85 | 'FMIN': 50, 86 | 'FMAX': 14000 87 | } 88 | aug_cfg = { 89 | 'MIXUP': 0.5, 90 | 'MIXUP_ALPHA': 10, 91 | 'SMOOTHING': 0.1 92 | } 93 | dataset = UrbanSound8k('train', data_cfg, aug_cfg) 94 | dataloader = DataLoader(dataset, 2, True) 95 | for audio, target in dataloader: 96 | print(audio.shape, target.argmax(dim=1)) 97 | print(audio.min(), audio.max()) 98 | break -------------------------------------------------------------------------------- /datasets/esc50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import torchaudio 4 | from pathlib import Path 5 | from torch import Tensor 6 | from torchaudio import transforms as T 7 | from torch.nn import functional as F 8 | from torch.utils.data import Dataset, DataLoader 9 | from typing import Tuple 10 | from .transforms import mixup_augment 11 | 12 | 13 | class ESC50(Dataset): 14 | """ 15 | 50 classes 16 | 40 examples per class 17 | 2000 recordings 18 | 5 major categories: [Animals, Nature sounds, Human non-speech sounds, Interior/domestic sounds, Exterior/urban sounds] 19 | Each of the audio is named like this: 20 | {FOLD}-{CLIP_ID}-{TAKE}-{TARGET}.wav 21 | """ 22 | CLASSES = ['dog', 'rooster', 'pig', 'cow', 'frog', 'cat', 'hen', 'insects', 'sheep', 'crow', 'rain', 'sea_waves', 'crackling_fire', 'crickets', 'chirping_birds', 'water_drops', 'wind', 'pouring_water', 'toilet_flush', 'thunderstorm', 'crying_baby', 'sneezing', 'clapping', 'breathing', 'coughing', 'footsteps', 'laughing', 'brushing_teeth', 23 | 'snoring', 'drinking_sipping', 'door_wood_knock', 'mouse_click', 'keyboard_typing', 'door_wood_creaks', 'can_opening', 'washing_machine', 'vacuum_cleaner', 'clock_alarm', 
'clock_tick', 'glass_breaking', 'helicopter', 'chainsaw', 'siren', 'car_horn', 'engine', 'train', 'church_bells', 'airplane', 'fireworks', 'hand_saw'] 24 | 25 | def __init__(self, split, data_cfg, mixup_cfg=None, transform=None, spec_transform=None) -> None: 26 | super().__init__() 27 | assert split in ['train', 'val'] 28 | self.num_classes = len(self.CLASSES) 29 | self.transform = transform 30 | self.spec_transform = spec_transform 31 | self.mixup = mixup_cfg['MIXUP'] if mixup_cfg is not None else 0.0 32 | self.mixup_alpha = mixup_cfg['MIXUP_ALPHA'] if mixup_cfg is not None else 0.0 33 | self.label_smooth = mixup_cfg['SMOOTHING'] if mixup_cfg is not None else 0.0 34 | val_fold = 5 35 | 36 | self.mel_tf = T.MelSpectrogram(data_cfg['SAMPLE_RATE'], data_cfg['WIN_LENGTH'], data_cfg['WIN_LENGTH'], data_cfg['HOP_LENGTH'], data_cfg['FMIN'], data_cfg['FMAX'], n_mels=data_cfg['N_MELS'], norm='slaney') # using mel_scale='slaney' is better 37 | self.resample = T.Resample(data_cfg['SOURCE_SAMPLE'], data_cfg['SAMPLE_RATE']) 38 | 39 | self.data, self.targets = self.get_data(data_cfg['ROOT'], split, val_fold) 40 | print(f"Found {len(self.data)} {split} audios in {data_cfg['ROOT']}.") 41 | 42 | 43 | def get_data(self, root: str, split: str, fold: int): 44 | root = Path(root) 45 | files = (root / 'audio').glob('*.wav') 46 | files = list(filter(lambda x: not x.stem.startswith(f"{fold}") if split == 'train' else x.stem.startswith(f"{fold}"), files)) 47 | targets = list(map(lambda x: int(x.stem.rsplit('-', maxsplit=1)[-1]), files)) 48 | assert len(files) == len(targets) 49 | return files, targets 50 | 51 | 52 | def __len__(self) -> int: 53 | return len(self.data) 54 | 55 | 56 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 57 | audio, _ = torchaudio.load(self.data[index]) 58 | audio = self.resample(audio) 59 | target = torch.tensor(self.targets[index]) 60 | 61 | if self.transform: audio = self.transform(audio) 62 | 63 | if random.random() < self.mixup: 64 | next_index = random.randint(0, len(self.data)-1) 65 | next_audio, _ = torchaudio.load(self.data[next_index]) 66 | next_audio = self.resample(next_audio) 67 | next_target = torch.tensor(self.targets[next_index]) 68 | audio, target = mixup_augment(audio, target, next_audio, next_target, self.mixup_alpha, self.num_classes, self.label_smooth) 69 | else: 70 | target = F.one_hot(target, self.num_classes).float() 71 | 72 | audio = self.mel_tf(audio) 73 | audio = 10.0 * audio.clamp_(1e-10).log10() 74 | 75 | if self.spec_transform: audio = self.spec_transform(audio) 76 | 77 | return audio, target 78 | 79 | 80 | if __name__ == '__main__': 81 | data_cfg = { 82 | 'ROOT': 'C:/Users/sithu/Documents/Datasets/ESC50', 83 | 'SOURCE_SAMPLE': 44100, 84 | 'SAMPLE_RATE': 32000, 85 | 'AUDIO_LENGTH': 5, 86 | 'WIN_LENGTH': 1024, 87 | 'HOP_LENGTH': 320, 88 | 'N_MELS': 64, 89 | 'FMIN': 50, 90 | 'FMAX': 14000 91 | } 92 | aug_cfg = { 93 | 'MIXUP': 0.5, 94 | 'MIXUP_ALPHA': 10, 95 | 'SMOOTHING': 0.1 96 | } 97 | dataset = ESC50('val', data_cfg, aug_cfg) 98 | dataloader = DataLoader(dataset, 2, True) 99 | for audio, target in dataloader: 100 | print(audio.shape, target.argmax(dim=1)) 101 | print(audio.min(), audio.max()) 102 | break 103 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Repo-specific GitIgnore ---------------------------------------------------------------------------------------------- 2 | *.jpg 3 | *.jpeg 4 | *.png 5 | *.bmp 
6 | *.tif 7 | *.tiff 8 | *.heic 9 | *.JPG 10 | *.JPEG 11 | *.PNG 12 | *.BMP 13 | *.TIF 14 | *.TIFF 15 | *.HEIC 16 | *.mp4 17 | *.mov 18 | *.MOV 19 | *.avi 20 | *.data 21 | *.json 22 | 23 | *.cfg 24 | !cfg/yolov3*.cfg 25 | 26 | storage.googleapis.com 27 | test_imgs/ 28 | runs/* 29 | data/* 30 | !data/images/zidane.jpg 31 | !data/images/bus.jpg 32 | !data/coco.names 33 | !data/coco_paper.names 34 | !data/coco.data 35 | !data/coco_*.data 36 | !data/coco_*.txt 37 | !data/trainvalno5k.shapes 38 | !data/*.sh 39 | 40 | pycocotools/* 41 | results*.txt 42 | gcp_test*.sh 43 | 44 | checkpoints/ 45 | output/ 46 | 47 | # Datasets ------------------------------------------------------------------------------------------------------------- 48 | coco/ 49 | coco128/ 50 | VOC/ 51 | 52 | # MATLAB GitIgnore ----------------------------------------------------------------------------------------------------- 53 | *.m~ 54 | *.mat 55 | !targets*.mat 56 | 57 | # Neural Network weights ----------------------------------------------------------------------------------------------- 58 | *.weights 59 | *.pt 60 | *.onnx 61 | *.mlmodel 62 | *.torchscript 63 | darknet53.conv.74 64 | yolov3-tiny.conv.15 65 | 66 | # GitHub Python GitIgnore ---------------------------------------------------------------------------------------------- 67 | # Byte-compiled / optimized / DLL files 68 | __pycache__/ 69 | *.py[cod] 70 | *$py.class 71 | 72 | # C extensions 73 | *.so 74 | 75 | # Distribution / packaging 76 | .Python 77 | env/ 78 | build/ 79 | develop-eggs/ 80 | dist/ 81 | downloads/ 82 | eggs/ 83 | .eggs/ 84 | lib/ 85 | lib64/ 86 | parts/ 87 | sdist/ 88 | var/ 89 | wheels/ 90 | *.egg-info/ 91 | wandb/ 92 | .installed.cfg 93 | *.egg 94 | 95 | 96 | # PyInstaller 97 | # Usually these files are written by a python script from a template 98 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 99 | *.manifest 100 | *.spec 101 | 102 | # Installer logs 103 | pip-log.txt 104 | pip-delete-this-directory.txt 105 | 106 | # Unit test / coverage reports 107 | htmlcov/ 108 | .tox/ 109 | .coverage 110 | .coverage.* 111 | .cache 112 | nosetests.xml 113 | coverage.xml 114 | *.cover 115 | .hypothesis/ 116 | 117 | # Translations 118 | *.mo 119 | *.pot 120 | 121 | # Django stuff: 122 | *.log 123 | local_settings.py 124 | tools/flask_deploy.py 125 | # Flask stuff: 126 | instance/ 127 | .webassets-cache 128 | 129 | # Scrapy stuff: 130 | .scrapy 131 | 132 | # Sphinx documentation 133 | docs/_build/ 134 | 135 | # PyBuilder 136 | target/ 137 | 138 | # Jupyter Notebook 139 | .ipynb_checkpoints 140 | 141 | # pyenv 142 | .python-version 143 | 144 | # celery beat schedule file 145 | celerybeat-schedule 146 | 147 | # SageMath parsed files 148 | *.sage.py 149 | 150 | # dotenv 151 | .env 152 | 153 | # virtualenv 154 | .venv* 155 | venv*/ 156 | ENV*/ 157 | 158 | # Spyder project settings 159 | .spyderproject 160 | .spyproject 161 | 162 | # Rope project settings 163 | .ropeproject 164 | 165 | # mkdocs documentation 166 | /site 167 | 168 | # mypy 169 | .mypy_cache/ 170 | 171 | 172 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore ----------------------------------------------- 173 | 174 | # General 175 | .DS_Store 176 | .AppleDouble 177 | .LSOverride 178 | 179 | # Icon must end with two \r 180 | Icon 181 | Icon? 
182 | 183 | # Thumbnails 184 | ._* 185 | 186 | # Files that might appear in the root of a volume 187 | .DocumentRevisions-V100 188 | .fseventsd 189 | .Spotlight-V100 190 | .TemporaryItems 191 | .Trashes 192 | .VolumeIcon.icns 193 | .com.apple.timemachine.donotpresent 194 | 195 | # Directories potentially created on remote AFP share 196 | .AppleDB 197 | .AppleDesktop 198 | Network Trash Folder 199 | Temporary Items 200 | .apdisk 201 | 202 | 203 | # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore 204 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 205 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 206 | 207 | # User-specific stuff: 208 | .idea/* 209 | .idea/**/workspace.xml 210 | .idea/**/tasks.xml 211 | .idea/dictionaries 212 | .html # Bokeh Plots 213 | .pg # TensorFlow Frozen Graphs 214 | .avi # videos 215 | 216 | # Sensitive or high-churn files: 217 | .idea/**/dataSources/ 218 | .idea/**/dataSources.ids 219 | .idea/**/dataSources.local.xml 220 | .idea/**/sqlDataSources.xml 221 | .idea/**/dynamic.xml 222 | .idea/**/uiDesigner.xml 223 | 224 | # Gradle: 225 | .idea/**/gradle.xml 226 | .idea/**/libraries 227 | 228 | # CMake 229 | cmake-build-debug/ 230 | cmake-build-release/ 231 | 232 | # Mongo Explorer plugin: 233 | .idea/**/mongoSettings.xml 234 | 235 | ## File-based project format: 236 | *.iws 237 | 238 | ## Plugin-specific files: 239 | 240 | # IntelliJ 241 | out/ 242 | 243 | # mpeltonen/sbt-idea plugin 244 | .idea_modules/ 245 | 246 | # JIRA plugin 247 | atlassian-ide-plugin.xml 248 | 249 | # Cursive Clojure plugin 250 | .idea/replstate.xml 251 | 252 | # Crashlytics plugin (for Android Studio and IntelliJ) 253 | com_crashlytics_export_strings.xml 254 | crashlytics.properties 255 | crashlytics-build.properties 256 | fabric.properties 257 | -------------------------------------------------------------------------------- /datasets/fsdkaggle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import torchaudio 4 | from pathlib import Path 5 | from torch import Tensor 6 | from torchaudio import transforms as T 7 | from torch.nn import functional as F 8 | from torch.utils.data import Dataset, DataLoader 9 | from typing import Tuple 10 | from .transforms import mixup_augment 11 | 12 | 13 | class FSDKaggle2018(Dataset): 14 | CLASSES = ["Acoustic_guitar", "Applause", "Bark", "Bass_drum", "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet", "Computer_keyboard", "Cough", "Cowbell", "Double_bass", "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping", "Fireworks", "Flute", 15 | "Glockenspiel", "Gong", "Gunshot_or_gunfire", "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow", "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter", "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone", "Trumpet", "Violin_or_fiddle", "Writing"] 16 | 17 | def __init__(self, split, data_cfg, mixup_cfg=None, transform=None, spec_transform=None) -> None: 18 | super().__init__() 19 | assert split in ['train', 'val'] 20 | split = 'train' if split == 'train' else 'test' 21 | self.num_classes = len(self.CLASSES) 22 | self.transform = transform 23 | self.spec_transform = spec_transform 24 | self.mixup = mixup_cfg['MIXUP'] if mixup_cfg is not None else 0.0 25 | self.mixup_alpha = mixup_cfg['MIXUP_ALPHA'] if mixup_cfg is not None else 0.0 26 | self.label_smooth = mixup_cfg['SMOOTHING'] 
if mixup_cfg is not None else 0.0 27 | self.num_frames = data_cfg['SAMPLE_RATE'] * data_cfg['AUDIO_LENGTH'] 28 | 29 | self.mel_tf = T.MelSpectrogram(data_cfg['SAMPLE_RATE'], data_cfg['WIN_LENGTH'], data_cfg['WIN_LENGTH'], data_cfg['HOP_LENGTH'], data_cfg['FMIN'], data_cfg['FMAX'], n_mels=data_cfg['N_MELS'], norm='slaney') # using mel_scale='slaney' is better 30 | self.resample = T.Resample(data_cfg['SOURCE_SAMPLE'], data_cfg['SAMPLE_RATE']) 31 | 32 | self.data, self.targets = self.get_data(data_cfg['ROOT'], split) 33 | print(f"Found {len(self.data)} {split} audios in {data_cfg['ROOT']}.") 34 | 35 | 36 | def get_data(self, root: str, split: str): 37 | root = Path(root) 38 | csv_path = 'train_post_competition.csv' if split == 'train' else 'test_post_competition_scoring_clips.csv' 39 | files, targets = [], [] 40 | 41 | with open(root / 'FSDKaggle2018.meta' / csv_path) as f: 42 | lines = f.read().splitlines()[1:] 43 | for line in lines: 44 | fname, label, *_ = line.split(',') 45 | files.append(str(root / f"audio_{split}" / fname)) 46 | targets.append(int(self.CLASSES.index(label))) 47 | return files, targets 48 | 49 | 50 | def __len__(self) -> int: 51 | return len(self.data) 52 | 53 | 54 | def cut_pad(self, audio: Tensor) -> Tensor: 55 | if audio.shape[1] < self.num_frames: # if less than 5s, pad the audio 56 | audio = torch.cat([audio, torch.zeros(1, self.num_frames-audio.shape[1])], dim=-1) 57 | else: # if not, trim the audio to 5s 58 | audio = audio[:, :self.num_frames] 59 | return audio 60 | 61 | 62 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 63 | audio, _ = torchaudio.load(self.data[index]) 64 | audio = self.resample(audio) # resample to 32kHz 65 | audio = self.cut_pad(audio) 66 | target = torch.tensor(self.targets[index]) 67 | 68 | if self.transform: audio = self.transform(audio) 69 | 70 | if random.random() < self.mixup: 71 | next_index = random.randint(0, len(self.data)-1) 72 | next_audio, _ = torchaudio.load(self.data[next_index]) 73 | next_audio = self.resample(next_audio) 74 | next_audio = self.cut_pad(next_audio) 75 | next_target = torch.tensor(self.targets[next_index]) 76 | audio, target = mixup_augment(audio, target, next_audio, next_target, self.mixup_alpha, self.num_classes, self.label_smooth) 77 | else: 78 | target = F.one_hot(target, self.num_classes).float() 79 | 80 | audio = self.mel_tf(audio) # convert to mel spectrogram 81 | audio = 10.0 * audio.clamp_(1e-10).log10() # convert to log mel spectrogram 82 | 83 | if self.spec_transform: audio = self.spec_transform(audio) 84 | 85 | return audio, target 86 | 87 | 88 | if __name__ == '__main__': 89 | data_cfg = { 90 | 'ROOT': 'C:/Users/sithu/Documents/Datasets/FSDKaggle2018', 91 | 'SOURCE_SAMPLE': 44100, 92 | 'SAMPLE_RATE': 32000, 93 | 'AUDIO_LENGTH': 5, 94 | 'WIN_LENGTH': 1024, 95 | 'HOP_LENGTH': 320, 96 | 'N_MELS': 64, 97 | 'FMIN': 50, 98 | 'FMAX': 14000 99 | } 100 | aug_cfg = { 101 | 'MIXUP': 0.5, 102 | 'MIXUP_ALPHA': 10, 103 | 'SMOOTHING': 0.1 104 | } 105 | dataset = FSDKaggle2018('val', data_cfg, aug_cfg) 106 | dataloader = DataLoader(dataset, 2, True) 107 | for audio, target in dataloader: 108 | print(audio.shape, target.argmax(dim=1)) 109 | print(audio.min(), audio.max()) 110 | break 111 | 112 | -------------------------------------------------------------------------------- /datasets/speechcommands.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import torchaudio 4 | import os 5 | from pathlib import Path 6 | from torch import 
Tensor 7 | from torchaudio import transforms as T 8 | from torch.nn import functional as F 9 | from torch.utils.data import Dataset, DataLoader 10 | from typing import Tuple 11 | from .transforms import mixup_augment 12 | 13 | 14 | class SpeechCommandsv1(Dataset): 15 | CLASSES = ['bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'four', 'go', 'happy', 'house', 'left', 'marvin', 'nine', 'no', 16 | 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 'up', 'wow', 'yes', 'zero'] 17 | 18 | def __init__(self, split, data_cfg, mixup_cfg=None, transform=None, spec_transform=None) -> None: 19 | super().__init__() 20 | assert split in ['train', 'val', 'test'] 21 | self.num_classes = len(self.CLASSES) 22 | self.transform = transform 23 | self.spec_transform = spec_transform 24 | self.mixup = mixup_cfg['MIXUP'] if mixup_cfg is not None else 0.0 25 | self.mixup_alpha = mixup_cfg['MIXUP_ALPHA'] if mixup_cfg is not None else 0.0 26 | self.label_smooth = mixup_cfg['SMOOTHING'] if mixup_cfg is not None else 0.0 27 | self.num_frames = data_cfg['SAMPLE_RATE'] * data_cfg['AUDIO_LENGTH'] 28 | 29 | self.mel_tf = T.MelSpectrogram(data_cfg['SAMPLE_RATE'], data_cfg['WIN_LENGTH'], data_cfg['WIN_LENGTH'], data_cfg['HOP_LENGTH'], data_cfg['FMIN'], data_cfg['FMAX'], n_mels=data_cfg['N_MELS'], norm='slaney') # using mel_scale='slaney' is better 30 | self.resample = T.Resample(data_cfg['SOURCE_SAMPLE'], data_cfg['SAMPLE_RATE']) 31 | 32 | self.data, self.targets = self.get_data(data_cfg['ROOT'], split) 33 | print(f"Found {len(self.data)} {split} audios in {data_cfg['ROOT']}.") 34 | 35 | 36 | def get_data(self, root: str, split: int): 37 | root = Path(root) 38 | if split == 'train': 39 | files = root.rglob('*.wav') 40 | excludes = [] 41 | with open(root / 'testing_list.txt') as f1, open(root / 'validation_list.txt') as f2: 42 | excludes += f1.read().splitlines() 43 | excludes += f2.read().splitlines() 44 | 45 | excludes = list(map(lambda x: str(root / x), excludes)) 46 | files = list(filter(lambda x: "_background_noise_" not in str(x) and str(x) not in excludes, files)) 47 | else: 48 | split = 'testing' if split == 'test' else 'validation' 49 | with open(root / f'{split}_list.txt') as f: 50 | files = f.read().splitlines() 51 | 52 | files = list(map(lambda x: root / x, files)) 53 | 54 | targets = list(map(lambda x: int(self.CLASSES.index(str(x.parent).rsplit(os.path.sep, maxsplit=1)[-1])), files)) 55 | assert len(files) == len(targets) 56 | return files, targets 57 | 58 | 59 | def __len__(self) -> int: 60 | return len(self.data) 61 | 62 | 63 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 64 | audio, _ = torchaudio.load(self.data[index]) 65 | audio = self.resample(audio) 66 | if audio.shape[1] < self.num_frames: audio = torch.cat([audio, torch.zeros(1, self.num_frames-audio.shape[1])], dim=-1) # if less than 1s, pad the audio 67 | target = torch.tensor(self.targets[index]) 68 | 69 | if self.transform: audio = self.transform(audio) 70 | 71 | if random.random() < self.mixup: 72 | next_index = random.randint(0, len(self.data)-1) 73 | next_audio, _ = torchaudio.load(self.data[next_index]) 74 | next_audio = self.resample(next_audio) 75 | if next_audio.shape[1] < self.num_frames: next_audio = torch.cat([next_audio, torch.zeros(1, self.num_frames-next_audio.shape[1])], dim=-1) # if less than 1s, pad the audio 76 | next_target = torch.tensor(self.targets[next_index]) 77 | audio, target = mixup_augment(audio, target, next_audio, next_target, self.mixup_alpha, self.num_classes, 
self.label_smooth) 78 | else: 79 | target = F.one_hot(target, self.num_classes).float() 80 | 81 | audio = self.mel_tf(audio) # convert to mel spectrogram 82 | audio = 10.0 * audio.clamp_(1e-10).log10() # convert to log mel spectrogram 83 | 84 | if self.spec_transform: audio = self.spec_transform(audio) 85 | 86 | return audio, target 87 | 88 | 89 | class SpeechCommandsv2(SpeechCommandsv1): 90 | CLASSES = ['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 91 | 'marvin', 'nine', 'no', 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero'] 92 | 93 | 94 | if __name__ == '__main__': 95 | data_cfg = { 96 | 'ROOT': 'C:/Users/sithu/Documents/Datasets/SpeechCommands', 97 | 'SOURCE_SAMPLE': 16000, 98 | 'SAMPLE_RATE': 32000, 99 | 'AUDIO_LENGTH': 1, 100 | 'WIN_LENGTH': 1024, 101 | 'HOP_LENGTH': 320, 102 | 'N_MELS': 64, 103 | 'FMIN': 50, 104 | 'FMAX': 14000 105 | } 106 | aug_cfg = { 107 | 'MIXUP': 0.5, 108 | 'MIXUP_ALPHA': 10, 109 | 'SMOOTHING': 0.1 110 | } 111 | dataset = SpeechCommandsv1('train', data_cfg, aug_cfg) 112 | dataloader = DataLoader(dataset, 2, True) 113 | for audio, target in dataloader: 114 | print(audio.shape, target.argmax(dim=1)) 115 | print(audio.min(), audio.max()) 116 | break -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import yaml 4 | import time 5 | import multiprocessing as mp 6 | from pprint import pprint 7 | from tqdm import tqdm 8 | from tabulate import tabulate 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from pathlib import Path 11 | from torch.utils.data import DataLoader 12 | from torch.cuda.amp import GradScaler, autocast 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | import sys 16 | sys.path.insert(0, '.') 17 | from datasets import get_train_dataset, get_val_dataset, get_sampler 18 | from datasets.transforms import get_waveform_transforms, get_spec_transforms 19 | from models import get_model 20 | from utils.utils import fix_seeds, time_sync, setup_cudnn, setup_ddp, cleanup_ddp 21 | from utils.schedulers import get_scheduler 22 | from utils.losses import get_loss 23 | from utils.optimizers import get_optimizer 24 | from val import evaluate 25 | 26 | 27 | def main(cfg, gpu, save_dir): 28 | start = time_sync() 29 | 30 | best_score = 0.0 31 | num_workers = mp.cpu_count() 32 | device = torch.device(cfg['DEVICE']) 33 | train_config = cfg['TRAIN'] 34 | epochs = train_config['EPOCHS'] 35 | metric = cfg['DATASET']['METRIC'] 36 | lr = cfg['OPTIMIZER']['LR'] 37 | 38 | # augmentations 39 | # waveform_transforms = get_waveform_transforms(cfg['DATASET']['SAMPLE_RATE'], cfg['DATASET']['AUDIO_LENGTH'], 0.1, 3, 0.005) 40 | waveform_transforms = None 41 | spec_transforms = get_spec_transforms(cfg['DATASET'], cfg['AUG']) 42 | 43 | # dataset 44 | train_dataset = get_train_dataset(cfg['DATASET'], cfg['AUG'], waveform_transforms, spec_transforms) 45 | val_dataset = get_val_dataset(cfg['DATASET']) 46 | 47 | # dataset sampler 48 | train_sampler, val_sampler = get_sampler(train_config['DDP'], train_dataset, val_dataset) 49 | 50 | # dataloader 51 | train_dataloader = DataLoader(train_dataset, batch_size=train_config['BATCH_SIZE'], num_workers=num_workers, drop_last=True, pin_memory=True, sampler=train_sampler) 52 | 
val_dataloader = DataLoader(val_dataset, batch_size=train_config['BATCH_SIZE'], num_workers=num_workers, pin_memory=True, sampler=val_sampler) 53 | 54 | # create model 55 | model = get_model(cfg['MODEL']['NAME'], train_dataset.num_classes) 56 | model._init_weights(cfg['MODEL']['PRETRAINED']) 57 | model = model.to(device) 58 | if train_config['DDP']: model = DDP(model, device_ids=[gpu]) 59 | 60 | # loss function, optimizer, scheduler, AMP scaler, tensorboard writer 61 | loss_fn = get_loss(train_config['LOSS']) 62 | optimizer = get_optimizer(model, cfg['OPTIMIZER']['NAME'], lr, cfg['OPTIMIZER']['WEIGHT_DECAY']) 63 | scheduler = get_scheduler(cfg, optimizer) 64 | scaler = GradScaler(enabled=train_config['AMP']) 65 | writer = SummaryWriter(save_dir / 'logs') 66 | iters_per_epoch = len(train_dataset) // train_config['BATCH_SIZE'] 67 | 68 | for epoch in range(epochs): 69 | model.train() 70 | 71 | if train_config['DDP']: train_sampler.set_epoch(epoch) 72 | 73 | train_loss = 0.0 74 | pbar = tqdm(enumerate(train_dataloader), total=iters_per_epoch, desc=f"Epoch: [{epoch+1}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.6f}") 75 | 76 | for iter, (audio, target) in pbar: 77 | audio = audio.to(device) 78 | target = target.to(device) 79 | 80 | optimizer.zero_grad() 81 | 82 | with autocast(enabled=train_config['AMP']): 83 | pred = model(audio) 84 | loss = loss_fn(pred, target) 85 | 86 | # Backpropagation 87 | scaler.scale(loss).backward() 88 | scaler.step(optimizer) 89 | scaler.update() 90 | 91 | lr = scheduler.get_last_lr()[0] 92 | train_loss += loss.item() 93 | 94 | pbar.set_description(f"Epoch: [{epoch+1}/{epochs}] Iter: [{iter+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss/(iter+1):.6f}") 95 | 96 | train_loss /= iter+1 97 | writer.add_scalar('train/loss', train_loss, epoch) 98 | 99 | scheduler.step() 100 | torch.cuda.empty_cache() 101 | 102 | if (epoch+1) % train_config['EVAL_INTERVAL'] == 0 or (epoch+1) == epochs: 103 | # evaluate the model 104 | score = evaluate(val_dataloader, model, device, metric)[0] 105 | writer.add_scalar(f'val/{metric}', score, epoch) 106 | 107 | if score >= best_score: 108 | best_score = score 109 | torch.save(model.module.state_dict() if train_config['DDP'] else model.state_dict(), save_dir / f"{cfg['MODEL']['NAME']}_{cfg['DATASET']['NAME']}.pth") 110 | print(f"Current {metric}: {score:.2f} Best {metric}: {best_score:.2f}") 111 | 112 | end = time.gmtime(time_sync() - start) 113 | total_time = time.strftime("%H:%M:%S", end) 114 | 115 | # results table 116 | table = [ 117 | [f'Best {metric}', f"{best_score:.2f}"], 118 | ['Total Training Time', total_time] 119 | ] 120 | print(tabulate(table, numalign='right')) 121 | 122 | writer.close() 123 | pbar.close() 124 | 125 | 126 | if __name__ == '__main__': 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument('--cfg', type=str, required=True, help='Experiment configuration file name') 129 | args = parser.parse_args() 130 | 131 | with open(args.cfg) as f: 132 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 133 | 134 | pprint(cfg) 135 | save_dir = Path(cfg['TRAIN']['SAVE_DIR']) 136 | save_dir.mkdir(exist_ok=True) 137 | fix_seeds(123) 138 | setup_cudnn() 139 | gpu = setup_ddp() 140 | main(cfg, gpu, save_dir) 141 | cleanup_ddp() -------------------------------------------------------------------------------- /datasets/aug_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | 
"metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torchaudio\n", 11 | "from torchaudio import transforms as T\n", 12 | "from torchaudio import functional as AF\n", 13 | "from transforms import *\n", 14 | "import sys\n", 15 | "sys.path.insert(0, '../')\n", 16 | "from utils.visualize import play_audio, plot_spectrogram" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "audio_path = '../assests/test.wav'\n", 26 | "noise_path = '../assests/noises/voices.wav'\n", 27 | "\n", 28 | "sample_rate = 32000\n", 29 | "audio_length = 5 # 5s\n", 30 | "audio, sr = torchaudio.load(audio_path)\n", 31 | "noise, sr1 = torchaudio.load(noise_path)\n", 32 | "\n", 33 | "if sr != sample_rate: audio = AF.resample(audio, sr, sample_rate)\n", 34 | "if audio.shape[0] != 1: audio = audio[:1]\n", 35 | "if audio.shape[1]!= sample_rate*audio_length: audio = audio[:, :sample_rate*audio_length]\n", 36 | "if sr1 != sample_rate: noise = AF.resample(noise, sr1, sample_rate)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "play_audio(audio, sample_rate)\n", 46 | "play_audio(noise, sample_rate)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## WaveForm Augmentations" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "## Background Noise\n", 63 | "background_noise = BackgroundNoise(3) #20(weak)~3(strong)\n", 64 | "noise_audio = background_noise(audio, noise)\n", 65 | "play_audio(noise_audio, sample_rate)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "# Fade In/Out\n", 75 | "fade = Fade(0.1) # 0.1(10% of audio)~0.5(50% of audio)\n", 76 | "fade_audio = fade(audio)\n", 77 | "play_audio(fade_audio, sample_rate)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "# Add Volume\n", 87 | "volume = Volume(20) # 3(weak)~20(strong)\n", 88 | "vol_audio = volume(audio)\n", 89 | "play_audio(vol_audio, sample_rate)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "# Gaussian Noise\n", 99 | "gnoise = GaussianNoise(0.005) # 0.005(weak)~0.02(strong)\n", 100 | "gnoise_audio = gnoise(audio)\n", 101 | "play_audio(gnoise_audio, sample_rate)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "## Spectrogram Augmentations" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "sample_rate = 32000\n", 118 | "win_length = 1024\n", 119 | "hop_length = 320\n", 120 | "n_mels = 64\n", 121 | "fmin = 50\n", 122 | "fmax = 14000\n", 123 | "audio_length = 5\n", 124 | "\n", 125 | "mel_tf = T.MelSpectrogram(sample_rate, win_length, win_length, hop_length, fmin, fmax, n_mels=n_mels, norm='slaney')\n", 126 | "spec = mel_tf(audio)\n", 127 | "plot_spectrogram(spec)\n", 128 | "spec.shape, spec.min(), spec.max()" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "## Log-mel Spectrogram\n", 138 | 
"logspec = 10.0 * spec.log10()\n", 139 | "plot_spectrogram(logspec)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "## Time Masking\n", 149 | "tmasking = TimeMasking(200, audio_length)\n", 150 | "tspec = tmasking(spec)\n", 151 | "plot_spectrogram(tspec)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "## Frequency Masking\n", 161 | "fmasking = FrequencyMasking(24, n_mels)\n", 162 | "fspec = fmasking(spec)\n", 163 | "plot_spectrogram(fspec)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "## Filter Augment\n", 173 | "filteraug = FilterAugment((-20, 20), (5, 10))\n", 174 | "filtspec = filteraug(spec)\n", 175 | "plot_spectrogram(spec)\n", 176 | "plot_spectrogram(filtspec)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [] 185 | } 186 | ], 187 | "metadata": { 188 | "interpreter": { 189 | "hash": "78184fe1b8a3f830767e8814b2b01c36fc7c8ac521e39cb583cd3fce210fee57" 190 | }, 191 | "kernelspec": { 192 | "display_name": "Python 3", 193 | "language": "python", 194 | "name": "python3" 195 | }, 196 | "language_info": { 197 | "codemirror_mode": { 198 | "name": "ipython", 199 | "version": 3 200 | }, 201 | "file_extension": ".py", 202 | "mimetype": "text/x-python", 203 | "name": "python", 204 | "nbconvert_exporter": "python", 205 | "pygments_lexer": "ipython3", 206 | "version": "3.9.5" 207 | } 208 | }, 209 | "nbformat": 4, 210 | "nbformat_minor": 2 211 | } 212 | -------------------------------------------------------------------------------- /models/cnn14.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | 5 | class Conv2Layer(nn.Module): 6 | def __init__(self, c1, c2, pool_size=2): 7 | super().__init__() 8 | self.conv1 = nn.Conv2d(c1, c2, 3, 1, 1, bias=False) 9 | self.conv2 = nn.Conv2d(c2, c2, 3, 1, 1, bias=False) 10 | self.bn1 = nn.BatchNorm2d(c2) 11 | self.bn2 = nn.BatchNorm2d(c2) 12 | self.relu = nn.ReLU() 13 | self.avgpool = nn.AvgPool2d(pool_size) 14 | self.dropout = nn.Dropout(0.2) 15 | 16 | def forward(self, x: Tensor) -> Tensor: 17 | x = self.relu(self.bn1(self.conv1(x))) 18 | x = self.relu(self.bn2(self.conv2(x))) 19 | x = self.avgpool(x) 20 | x = self.dropout(x) 21 | return x 22 | 23 | 24 | 25 | class CNN14(nn.Module): 26 | def __init__(self, num_classes: int = 50): 27 | super().__init__() 28 | self.bn0 = nn.BatchNorm2d(64) 29 | self.conv_block1 = Conv2Layer(1, 64) 30 | self.conv_block2 = Conv2Layer(64, 128) 31 | self.conv_block3 = Conv2Layer(128, 256) 32 | self.conv_block4 = Conv2Layer(256, 512) 33 | self.conv_block5 = Conv2Layer(512, 1024) 34 | self.conv_block6 = Conv2Layer(1024, 2048, 1) 35 | 36 | self.dropout = nn.Dropout(0.5) 37 | self.relu = nn.ReLU() 38 | 39 | self.fc1 = nn.Linear(2048, 2048) 40 | self.fc = nn.Linear(2048, num_classes) 41 | 42 | def _init_weights(self, pretrained: str = None): 43 | if pretrained: 44 | print(f"Loading Pretrained Weights from {pretrained}") 45 | pretrained_dict = torch.load(pretrained, map_location='cpu') 46 | model_dict = self.state_dict() 47 | for k in model_dict.keys(): 48 | if not k.startswith('fc.'): 49 | model_dict[k] = pretrained_dict[k] 50 | 
self.load_state_dict(model_dict) 51 | else: 52 | for m in self.modules(): 53 | if isinstance(m, nn.Conv2d): 54 | nn.init.xavier_uniform_(m.weight) 55 | elif isinstance(m, nn.Linear): 56 | nn.init.xavier_uniform_(m.weight) 57 | if hasattr(m, 'bias'): 58 | nn.init.constant_(m.bias, 0) 59 | elif isinstance(m, nn.BatchNorm2d): 60 | nn.init.constant_(m.weight, 1.0) 61 | nn.init.constant_(m.bias, 0.0) 62 | 63 | def forward(self, x: Tensor) -> Tensor: 64 | # x [B, 1, mel_bins, time_steps] 65 | x = x.permute(0, 2, 3, 1) 66 | x = self.bn0(x) 67 | x = x.transpose(1, 3) 68 | 69 | x = self.conv_block1(x) 70 | x = self.conv_block2(x) 71 | x = self.conv_block3(x) 72 | x = self.conv_block4(x) 73 | x = self.conv_block5(x) 74 | x = self.conv_block6(x) 75 | 76 | x = x.mean(3) 77 | x = x.max(dim=2)[0] + x.mean(2) 78 | x = self.dropout(x) 79 | x = self.relu(self.fc1(x)) 80 | return self.fc(x) 81 | 82 | 83 | class CNN14DecisionLevelMax(nn.Module): 84 | def __init__(self, num_classes: int = 50): 85 | super().__init__() 86 | self.interpolate_ratio = 32 # downsampled ratio 87 | 88 | self.bn0 = nn.BatchNorm2d(64) 89 | self.conv_block1 = Conv2Layer(1, 64) 90 | self.conv_block2 = Conv2Layer(64, 128) 91 | self.conv_block3 = Conv2Layer(128, 256) 92 | self.conv_block4 = Conv2Layer(256, 512) 93 | self.conv_block5 = Conv2Layer(512, 1024) 94 | self.conv_block6 = Conv2Layer(1024, 2048, 1) 95 | 96 | self.dropout = nn.Dropout(0.5) 97 | self.relu = nn.ReLU() 98 | self.maxpool = nn.MaxPool1d(3, 1, 1) 99 | self.avgpool = nn.AvgPool1d(3, 1, 1) 100 | 101 | self.fc1 = nn.Linear(2048, 2048) 102 | self.fc = nn.Linear(2048, num_classes) 103 | 104 | def _init_weights(self, pretrained: str = None): 105 | if pretrained: 106 | print(f"Loading Pretrained Weights from {pretrained}") 107 | pretrained_dict = torch.load(pretrained, map_location='cpu') 108 | model_dict = self.state_dict() 109 | for k in model_dict.keys(): 110 | if not k.startswith('fc.'): 111 | model_dict[k] = pretrained_dict[k] 112 | self.load_state_dict(model_dict) 113 | else: 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | nn.init.xavier_uniform_(m.weight) 117 | elif isinstance(m, nn.Linear): 118 | nn.init.xavier_uniform_(m.weight) 119 | if hasattr(m, 'bias'): 120 | nn.init.constant_(m.bias, 0) 121 | elif isinstance(m, nn.BatchNorm2d): 122 | nn.init.constant_(m.weight, 1.0) 123 | nn.init.constant_(m.bias, 0.0) 124 | 125 | def forward(self, x: Tensor) -> Tensor: 126 | # x [B, 1, mel_bins, time_steps] 127 | num_frames = x.shape[-1] 128 | x = x.permute(0, 2, 3, 1) 129 | x = self.bn0(x) 130 | x = x.transpose(1, 3) 131 | 132 | x = self.conv_block1(x) 133 | x = self.conv_block2(x) 134 | x = self.conv_block3(x) 135 | x = self.conv_block4(x) 136 | x = self.conv_block5(x) 137 | x = self.conv_block6(x) 138 | 139 | x = x.mean(3) 140 | x = self.maxpool(x) + self.avgpool(x) 141 | x = self.dropout(x) 142 | x = x.transpose(1, 2) 143 | x = self.relu(self.fc1(x)) 144 | 145 | segmentwise = self.fc(x).sigmoid() 146 | clipwise = segmentwise.max(dim=1)[0] 147 | 148 | # get framewise output 149 | framewise = interpolate(segmentwise, self.interpolate_ratio) 150 | framewise = pad_framewise(framewise, num_frames) 151 | 152 | return framewise, clipwise 153 | 154 | 155 | 156 | def interpolate(x, ratio): 157 | B, T, C = x.shape 158 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 159 | upsampled = upsampled.reshape(B, T*ratio, C) 160 | return upsampled 161 | 162 | 163 | def pad_framewise(framewise, num_frames): 164 | pad = framewise[:, -1:, :].repeat(1, 
num_frames-framewise.shape[1], 1) 165 | return torch.cat([framewise, pad], dim=1) 166 | 167 | 168 | if __name__ == '__main__': 169 | import time 170 | model = CNN14(527) 171 | model.load_state_dict(torch.load('checkpoints/cnn14.pth', map_location='cpu')) 172 | x = torch.randn(3, 1, 64, 701) 173 | start = time.time() 174 | y = model(x) 175 | print(time.time()-start) 176 | print(y.shape) 177 | print(y.min(), y.max()) -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import math 4 | import numpy as np 5 | from torch import nn, Tensor 6 | from torchaudio import transforms as T 7 | from typing import Tuple 8 | 9 | 10 | ####################################################### 11 | ## Voice Activity Detector 12 | 13 | class VoiceActivityDetector(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | self.vad = T.Vad( 17 | sample_rate=16000, 18 | trigger_level=7, 19 | trigger_time=0.25, 20 | search_time=1.0, 21 | allowed_gap=0.25, 22 | pre_trigger_time=0.0, 23 | boot_time=0.35, 24 | noise_up_time=0.1, 25 | noise_down_time=0.01, 26 | noise_reduction_amount=1.35, 27 | measure_freq=20.0, 28 | measure_duration=None, 29 | measure_smooth_time=0.4 30 | ) 31 | 32 | def forward(self, waveform: Tensor) -> Tensor: 33 | waveform = self.vad(waveform) 34 | return waveform 35 | 36 | 37 | ############################################################# 38 | ### WAVEFORM AUGMENTATIONS ### 39 | 40 | class Fade(nn.Module): 41 | def __init__(self, fio: float = 0.1, sample_rate: int = 32000, audio_length: int = 5): 42 | super().__init__() 43 | fiop = int(fio * sample_rate * audio_length) 44 | self.fade = T.Fade(fiop, fiop) 45 | 46 | def forward(self, waveform: Tensor) -> Tensor: 47 | return self.fade(waveform) 48 | 49 | 50 | class Volume(nn.Module): 51 | """ 52 | gain: in decibel (3 (weak) to 20 (strong)) 53 | """ 54 | def __init__(self, gain: int = 3): 55 | super().__init__() 56 | self.volume = T.Vol(gain, gain_type='db') 57 | 58 | def forward(self, waveform: Tensor) -> Tensor: 59 | return self.volume(waveform) 60 | 61 | 62 | class GaussianNoise(nn.Module): 63 | def __init__(self, sigma: float = 0.005): 64 | super().__init__() 65 | self.sigma = sigma # 0.005 (weak) to 0.02 (strong) 66 | 67 | def forward(self, waveform: Tensor) -> Tensor: 68 | gauss = np.random.normal(0, self.sigma, waveform.shape) 69 | waveform += gauss 70 | return waveform 71 | 72 | 73 | class BackgroundNoise(nn.Module): 74 | """To add background noise to audio 75 | For simplicity, you can add audio Tensor with noise Tensor. 
76 | A common way to adjust the intensity of noise is to change Signal-to-Noise Ratio (SNR) 77 | 78 | SNR = audio / noise 79 | SNR(db) = 10 * log10(SNR) 80 | """ 81 | def __init__(self, snr_db: int = 20): 82 | super().__init__() 83 | self.snr_db = snr_db # 3 (strong) to 20 (weak) 84 | 85 | def forward(self, waveform: Tensor, noise: Tensor) -> Tensor: 86 | # assume waveform and noise have same sample rates 87 | if noise.shape[1] < waveform.shape[1]: 88 | noise = torch.cat([noise[:1], torch.zeros(1, waveform.shape[1]-noise.shape[1])], dim=-1) 89 | noise = noise[:1, :waveform.shape[1]] 90 | scale = math.exp(self.snr_db / 10) * noise.norm(p=2) / waveform.norm(p=2) 91 | return (scale * waveform + noise) / 2 92 | 93 | 94 | def mixup_augment(audio1: Tensor, target1: Tensor, audio2: Tensor, target2: Tensor, alpha: int = 10, num_classes: int = 50, smoothing: float = 0.1) -> Tuple[Tensor, Tensor]: 95 | ## assume audio1 and audio2 are mono channel audios and have same sampling rates 96 | mix_lambda = np.random.beta(alpha, alpha) 97 | 98 | off_value = smoothing / num_classes 99 | on_value = 1 - smoothing + off_value 100 | target1 = torch.full((1, num_classes), off_value).scatter_(1, target1.long().view(-1, 1), on_value).squeeze() 101 | target2 = torch.full((1, num_classes), off_value).scatter_(1, target2.long().view(-1, 1), on_value).squeeze() 102 | 103 | # target1 = F.one_hot(target1, num_classes) 104 | # target2 = F.one_hot(target2, num_classes) 105 | 106 | audio = audio1 * mix_lambda + audio2 * (1 - mix_lambda) 107 | target = target1 * mix_lambda + target2 * (1 - mix_lambda) 108 | 109 | return audio, target 110 | 111 | 112 | ########################################################## 113 | ## Spectrogram Augmentations 114 | 115 | class TimeMasking(nn.Module): 116 | def __init__(self, mask: int = 96, audio_length: int = 5): 117 | super().__init__() 118 | assert mask < audio_length*100, f"TimeMasking parameter should be less than time frames >> {mask} > {audio_length*100}" 119 | self.masking = T.TimeMasking(mask) 120 | 121 | def forward(self, spec: Tensor) -> Tensor: 122 | return self.masking(spec) 123 | 124 | 125 | class FrequencyMasking(nn.Module): 126 | def __init__(self, mask: int = 24, n_mels: int = 64): 127 | super().__init__() 128 | assert mask < n_mels, f"FrequencyMasking parameter should be less than num mels >> {mask} > {n_mels}" 129 | self.masking = T.FrequencyMasking(mask) 130 | 131 | def forward(self, spec: Tensor) -> Tensor: 132 | return self.masking(spec) 133 | 134 | 135 | class FilterAugment(nn.Module): 136 | """ 137 | https://github.com/frednam93/FilterAugSED 138 | https://arxiv.org/abs/2107.03649 139 | """ 140 | def __init__(self, db_range=(-7.5, 6), n_bands=(2, 5)): 141 | super().__init__() 142 | self.db_range = db_range 143 | self.n_bands = n_bands 144 | 145 | def forward(self, audio: Tensor) -> Tensor: 146 | C, F, _ = audio.shape 147 | n_freq_band = random.randint(self.n_bands[0], self.n_bands[1]) 148 | 149 | if n_freq_band > 1: 150 | band_boundary_freqs = torch.cat([ 151 | torch.tensor([0]), 152 | torch.sort(torch.randint(1, F-1, (F-1,)))[0], 153 | torch.tensor([F]) 154 | ]) 155 | band_factors = torch.rand((C, n_freq_band)) * (self.db_range[1] - self.db_range[0]) + self.db_range[0] 156 | band_factors = 10 ** (band_factors / 20) 157 | freq_filter = torch.ones((C, F, 1)) 158 | 159 | for i in range(n_freq_band): 160 | freq_filter[:, band_boundary_freqs[i]:band_boundary_freqs[i+1], :] = band_factors[:, i].unsqueeze(-1).unsqueeze(-1) 161 | 162 | audio *= freq_filter 163 | return audio 
164 | 165 | 166 | def get_waveform_transforms(data_config, aug_config): 167 | return nn.Sequential( 168 | Fade(0.1, data_config['SAMPLE_RATE'], data_config['AUDIO_LENGTH']), 169 | Volume(3), 170 | GaussianNoise(0.005) 171 | ) 172 | 173 | 174 | def get_spec_transforms(data_config, aug_config): 175 | return nn.Sequential( 176 | TimeMasking(aug_config['TIME_MASK'], data_config['AUDIO_LENGTH']), 177 | FrequencyMasking(aug_config['FREQ_MASK'], data_config['N_MELS']) 178 | ) 179 | 180 | 181 | if __name__ == '__main__': 182 | torch.manual_seed(123) 183 | x = torch.randn(4, 1, 64, 701) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #
Audio Classification, Tagging & Sound Event Detection in PyTorch
2 | 3 | Progress: 4 | 5 | - [x] Fine-tune on audio classification 6 | - [ ] Fine-tune on audio tagging 7 | - [ ] Fine-tune on sound event detection 8 | - [x] Add tagging metrics 9 | - [ ] Add Tutorial 10 | - [x] Add Augmentation Notebook 11 | - [ ] Add more schedulers 12 | - [ ] Add FSDKaggle2019 dataset 13 | - [ ] Add MTT dataset 14 | - [ ] Add DESED 15 | - [ ] Test in real-time 16 | 17 | 18 | ##
Model Zoo
19 | 20 | [cnn14]: https://drive.google.com/file/d/1GhDXnyj9KgDMyOOoMuSBn8pb1iELlEp7/view?usp=sharing 21 | [cnn1416k]: https://drive.google.com/file/d/1BGAfVH_6xt06YZUDPqRLNtyj7KoyoEaF/view?usp=sharing 22 | [cnn14max]: https://drive.google.com/file/d/1K0XKf6JbFIgCoo70WvdunQoWWMMmrqDl/view?usp=sharing 23 | 24 |
25 | AudioSet Pretrained Models 26 | 27 | Model | Task | mAP (%) | Sample Rate (kHz) | Window Length | Num Mels | Fmax | Weights 28 | --- | --- | --- | --- | --- | --- | --- | --- 29 | CNN14 | Tagging | 43.1 | 32 | 1024 | 64 | 14k | [download][cnn14] 30 | CNN14_16k | Tagging | 43.8 | 16 | 512 | 64 | 8k | [download][cnn1416k] 31 | || 32 | CNN14_DecisionLevelMax | SED | 38.5 | 32 | 1024 | 64 | 14k | [download][cnn14max] 33 | 34 |
35 | 36 | > Note: These models will be used as pretrained models in the fine-tuning tasks below. Check out [audioset-tagging-cnn](https://github.com/qiuqiangkong/audioset_tagging_cnn) if you want to train on the AudioSet dataset yourself. 37 | 38 | [esc50cnn14]: https://drive.google.com/file/d/1itN-WyEL6Wp_jVBlld6vLaj47UWL2JaP/view?usp=sharing 39 | [fsd2018]: https://drive.google.com/file/d/1KzKd4icIV2xF7BdW9EZpU9BAZyfCatrD/view?usp=sharing 40 | [scv1]: https://drive.google.com/file/d/1Mc4UxHOEvaeJXKcuP4RiTggqZZ0CCmOB/view?usp=sharing 41 | 42 |
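The AudioSet checkpoints are loaded through the model's `_init_weights` method, which copies every layer except the final `fc` classification head so that the head can be re-initialized for the target dataset. A minimal sketch (the checkpoint path is a placeholder for wherever you saved the download):

```python
from models import get_model

# Build CNN14 with the number of classes of the fine-tuning dataset (e.g. 50 for ESC-50).
model = get_model("cnn14", num_classes=50)

# Copy the AudioSet-pretrained weights into every layer except the `fc` head.
model._init_weights("checkpoints/cnn14.pth")  # placeholder path to the downloaded checkpoint
```

This is what `tools/train.py` does when `MODEL` >> `PRETRAINED` is set in the configuration file.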
43 | Fine-tuned Classification Models 44 | 45 | Model | Dataset | Accuracy (%) | Sample Rate (kHz) | Weights 46 | --- | --- | --- | --- | --- 47 | CNN14 | ESC50 (Fold-5) | 95.75 | 32 | [download][esc50cnn14] 48 | CNN14 | FSDKaggle2018 (test) | 93.56 | 32 | [download][fsd2018] 49 | CNN14 | SpeechCommandsv1 (val/test) | 96.60/96.77 | 32 | [download][scv1] 50 | 51 |
52 | 53 |
54 | Fine-tuned Tagging Models 55 | 56 | Model | Dataset | mAP (%) | AUC | d-prime | Sample Rate (kHz) | Config | Weights 57 | --- | --- | --- | --- | --- | --- | --- | --- 58 | CNN14 | FSDKaggle2019 | - | - | - | 32 | - | - 59 | 60 |
61 | 62 |
63 | Fine-tuned SED Models 64 | 65 | Model | Dataset | F1 | Sample Rate (kHz) | Config | Weights 66 | --- | --- | --- | --- | --- | --- 67 | CNN14_DecisionLevelMax | DESED | - | 32 | - | - 68 | 69 |
70 | 71 | --- 72 | 73 | ##
Supported Datasets
74 | 75 | [esc50]: https://github.com/karolpiczak/ESC-50 76 | [fsdkaggle2018]: https://zenodo.org/record/2552860 77 | [fsdkaggle2019]: https://zenodo.org/record/3612637 78 | [audioset]: https://research.google.com/audioset/ 79 | [urbansound8k]: https://urbansounddataset.weebly.com/urbansound8k.html 80 | [speechcommandsv1]: https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html 81 | [speechcommandsv2]: http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz 82 | [mtt]: https://github.com/keunwoochoi/magnatagatune-list 83 | [desed]: https://project.inria.fr/desed/ 84 | 85 | Dataset | Task | Classes | Train | Val | Test | Audio Length | Audio Spec | Size 86 | --- | --- | --- | --- | --- | --- | --- | --- | --- 87 | [ESC-50][esc50] | Classification | 50 | 2,000 | 5 folds | - | 5s | 44.1kHz, mono | 600MB 88 | [UrbanSound8k][urbansound8k] | Classification | 10 | 8,732 | 10 folds | - | <=4s | Vary | 5.6GB 89 | [FSDKaggle2018][fsdkaggle2018] | Classification | 41 | 9,473 | - | 1,600 | 300ms~30s | 44.1kHz, mono | 4.6GB 90 | [SpeechCommandsv1][speechcommandsv1] | Classification | 30 | 51,088 | 6,798 | 6,835 | <=1s | 16kHz, mono | 1.4GB 91 | [SpeechCommandsv2][speechcommandsv2] | Classification | 35 | 84,843 | 9,981 | 11,005 | <=1s | 16kHz, mono | 2.3GB 92 | || 93 | [FSDKaggle2019][fsdkaggle2019]* | Tagging | 80 | 4,970+19,815 | - | 4,481 | 300ms~30s | 44.1kHz, mono | 24GB 94 | [MTT][mtt]* | Tagging | 50 | 19,000 | - | - | - | - | 3GB 95 | || 96 | [DESED][desed]* | SED | 10 | - | - | - | 10s | - | - 97 | 98 | > Notes: `*` datasets are not available yet. Classification datasets are treated as multi-class/single-label classification; tagging and SED datasets are treated as multi-label classification. 99 | 100 |
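In practice the distinction is just the target format: the classification datasets above return a one-hot target per clip, while the tagging/SED datasets would return a multi-hot target. A tiny illustration with made-up class indices:

```python
import torch
from torch.nn import functional as F

num_classes = 5

# Multi-class / single-label: exactly one active class per clip (one-hot target).
single = F.one_hot(torch.tensor(2), num_classes).float()  # tensor([0., 0., 1., 0., 0.])

# Multi-label (tagging / SED): any number of active classes per clip (multi-hot target).
multi = torch.zeros(num_classes)
multi[[0, 3]] = 1.0                                        # tensor([1., 0., 0., 1., 0.])
```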
101 | Dataset Structure (click to expand) 102 | 103 | Download the dataset and prepare it into the following structure. 104 | 105 | ``` 106 | datasets 107 | |__ ESC50 108 | |__ audio 109 | 110 | |__ Urbansound8k 111 | |__ audio 112 | 113 | |__ FSDKaggle2018 114 | |__ audio_train 115 | |__ audio_test 116 | |__ FSDKaggle2018.meta 117 | |__ train_post_competition.csv 118 | |__ test_post_competition_scoring_clips.csv 119 | 120 | |__ SpeechCommandsv1/v2 121 | |__ bed 122 | |__ bird 123 | |__ ... 124 | |__ testing_list.txt 125 | |__ validation_list.txt 126 | 127 | ``` 128 | 129 |
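Once a dataset is laid out like this, it can also be loaded directly with the classes in [datasets](./datasets/); below is a minimal sketch mirroring the `__main__` block of `datasets/fsdkaggle.py` (the `ROOT` path is a placeholder for your local copy):

```python
from datasets.fsdkaggle import FSDKaggle2018

data_cfg = {
    'ROOT': 'datasets/FSDKaggle2018',  # placeholder: wherever you prepared the dataset
    'SOURCE_SAMPLE': 44100,
    'SAMPLE_RATE': 32000,
    'AUDIO_LENGTH': 5,
    'WIN_LENGTH': 1024,
    'HOP_LENGTH': 320,
    'N_MELS': 64,
    'FMIN': 50,
    'FMAX': 14000,
}

dataset = FSDKaggle2018('val', data_cfg, None)  # no mixup; waveform/spec transforms are optional
audio, target = dataset[0]                      # log-mel spectrogram and one-hot target
print(audio.shape, target.shape)                # e.g. torch.Size([1, 64, 501]) torch.Size([41])
```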
130 | 131 |
132 | Augmentations (click to expand) 133 | 134 | Currently, the following augmentations are supported. More will be added in the future. You can test the effects of augmentations with this [notebook](./datasets/aug_test.ipynb); a minimal usage sketch is also shown after the lists below. 135 | 136 | WaveForm Augmentations: 137 | 138 | - [x] MixUp 139 | - [x] Background Noise 140 | - [x] Gaussian Noise 141 | - [x] Fade In/Out 142 | - [x] Volume 143 | - [ ] CutMix 144 | 145 | Spectrogram Augmentations: 146 | 147 | - [x] Time Masking 148 | - [x] Frequency Masking 149 | - [x] Filter Augmentation 150 | 151 |
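A minimal sketch of applying a few of these transforms outside the notebook (class names and signatures as defined in [datasets/transforms.py](./datasets/transforms.py); the random tensors below simply stand in for a real clip and its spectrogram):

```python
import torch
from datasets.transforms import Fade, Volume, TimeMasking, FrequencyMasking

sample_rate, audio_length, n_mels = 32000, 5, 64
waveform = torch.randn(1, sample_rate * audio_length)       # stand-in for a loaded 5s clip

# waveform-level augmentations
waveform = Fade(0.1, sample_rate, audio_length)(waveform)   # fade 10% of the clip in/out
waveform = Volume(3)(waveform)                               # +3 dB gain

# spectrogram-level augmentations (applied to the (log-)mel spectrogram)
spec = torch.randn(1, n_mels, 501)                           # stand-in spectrogram
spec = TimeMasking(96, audio_length)(spec)                   # mask up to 96 time frames
spec = FrequencyMasking(24, n_mels)(spec)                    # mask up to 24 mel bins
```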
152 | 153 | --- 154 | 155 | ##
Usage
156 | 157 |
158 | Requirements (click to expand) 159 | 160 | * python >= 3.6 161 | * torch >= 1.8.1 162 | * torchaudio >= 0.8.1 163 | 164 | Other requirements can be installed with `pip install -r requirements.txt`. 165 | 166 |
167 | 168 |
169 |
170 | Configuration (click to expand) 171 | 172 | * Create a configuration file in [configs](./configs/). A sample configuration for the ESC50 dataset can be found [here](configs/esc50.yaml). 173 | * Copy its contents and edit the fields as needed; the sketch below lists the main groups the training script reads. 174 | * This configuration file is used by all of the training, evaluation and prediction scripts. 175 | 176 |
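The sample YAMLs in [configs](./configs/) are the authoritative reference for the schema; the sketch below only illustrates the main groups that `tools/train.py` and the dataset/transform code read from the loaded config (field names inferred from that code, not an exhaustive list):

```python
import yaml

# Load a config the same way tools/train.py does.
with open('configs/esc50.yaml') as f:
    cfg = yaml.safe_load(f)

model_cfg = cfg['MODEL']      # NAME, PRETRAINED
data_cfg  = cfg['DATASET']    # NAME, ROOT, METRIC, SAMPLE_RATE, AUDIO_LENGTH, N_MELS, ...
aug_cfg   = cfg['AUG']        # MIXUP, MIXUP_ALPHA, SMOOTHING, TIME_MASK, FREQ_MASK
optim_cfg = cfg['OPTIMIZER']  # NAME, LR, WEIGHT_DECAY
sched_cfg = cfg['SCHEDULER']  # NAME, PARAMS
train_cfg = cfg['TRAIN']      # EPOCHS, BATCH_SIZE, AMP, DDP, EVAL_INTERVAL, SAVE_DIR, LOSS

print(cfg['DEVICE'], model_cfg['NAME'], data_cfg['NAME'])
```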
177 |
178 |
179 | Training (click to expand) 180 | 181 | To train with a single GPU: 182 | 183 | ```bash 184 | $ python tools/train.py --cfg configs/CONFIG_FILE_NAME.yaml 185 | ``` 186 | 187 | To train with multiple GPUs, set the `DDP` field in the config file to `true` and run as follows: 188 | 189 | ```bash 190 | $ python -m torch.distributed.launch --nproc_per_node=2 --use_env tools/train.py --cfg configs/CONFIG_FILE_NAME.yaml 191 | ``` 192 | 193 |
194 | 195 |
196 |
197 | Evaluation (click to expand) 198 | 199 | Make sure to set `MODEL_PATH` of the configuration file to your trained model directory. 200 | 201 | ```bash 202 | $ python tools/val.py --cfg configs/CONFIG_FILE.yaml 203 | ``` 204 | 205 |
206 | 207 |
208 |
209 | Audio Classification/Tagging Inference 210 | 211 | * Set `MODEL_PATH` of the configuration file to your model's trained weights. 212 | * Change the dataset name in `DATASET` >> `NAME` to the dataset your model was trained on. 213 | * Set the testing audio file path in `TEST` >> `FILE`. 214 | * Run the following command. 215 | 216 | ```bash 217 | $ python tools/infer.py --cfg configs/CONFIG_FILE.yaml 218 | 219 | ## for example 220 | $ python tools/infer.py --cfg configs/audioset.yaml 221 | ``` 222 | You will get an output similar to this: 223 | 224 | ```bash 225 | Class Confidence 226 | ---------------------- ------------ 227 | Speech 0.897762 228 | Telephone bell ringing 0.752206 229 | Telephone 0.219329 230 | Inside, small room 0.20761 231 | Music 0.0770325 232 | ``` 233 | 234 |
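`tools/infer.py` wraps these steps for you; the rough sketch below shows the same tagging pipeline in plain Python (checkpoint path and audio file are placeholders, preprocessing values follow the AudioSet settings above, and class names can be looked up in `datasets.audioset.AudioSet.CLASSES`):

```python
import torch
import torchaudio
from torchaudio import transforms as T
from models import get_model

model = get_model('cnn14', num_classes=527)
model.load_state_dict(torch.load('checkpoints/cnn14.pth', map_location='cpu'))  # placeholder path
model.eval()

audio, sr = torchaudio.load('assests/test.wav')
audio = T.Resample(sr, 32000)(audio[:1])                      # mono, resampled to 32 kHz
mel = T.MelSpectrogram(32000, 1024, 1024, 320, 50, 14000, n_mels=64, norm='slaney')(audio)
logmel = 10.0 * mel.clamp(1e-10).log10()                      # log-mel spectrogram

with torch.no_grad():
    probs = model(logmel.unsqueeze(0)).sigmoid().squeeze(0)   # per-class confidences

print(probs.topk(5))  # top-5 confidences and class indices (map indices to AudioSet.CLASSES)
```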
235 | 236 |
237 |
238 | Sound Event Detection Inference 239 | 240 | * Set `MODEL_PATH` of the configuration file to your model's trained weights. 241 | * Change the dataset name in `DATASET` >> `NAME` to the dataset your model was trained on. 242 | * Set the testing audio file path in `TEST` >> `FILE`. 243 | * Run the following command. 244 | 245 | ```bash 246 | $ python tools/sed_infer.py --cfg configs/CONFIG_FILE.yaml 247 | 248 | ## for example 249 | $ python tools/sed_infer.py --cfg configs/audioset_sed.yaml 250 | ``` 251 | 252 | You will get an output similar to this: 253 | 254 | ```bash 255 | Class Start End 256 | ---------------------- ------- ----- 257 | Speech 2.2 7 258 | Telephone bell ringing 0 2.5 259 | ``` 260 | 261 | The following plot will also be shown if you set `PLOT` to `true`: 262 | 263 | ![sed_result](./assests/sed_result.png) 264 | 265 |
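`tools/sed_infer.py` derives the event table above from the framewise output of `CNN14DecisionLevelMax`; a rough sketch of what that output looks like (placeholder checkpoint and audio paths, same preprocessing as in the tagging sketch):

```python
import torch
import torchaudio
from torchaudio import transforms as T
from models import get_model

model = get_model('cnn14decisionlevelmax', num_classes=527)
model.load_state_dict(torch.load('checkpoints/cnn14_sed.pth', map_location='cpu'))  # placeholder path
model.eval()

audio, sr = torchaudio.load('assests/test.wav')
audio = T.Resample(sr, 32000)(audio[:1])
mel = T.MelSpectrogram(32000, 1024, 1024, 320, 50, 14000, n_mels=64, norm='slaney')(audio)
logmel = 10.0 * mel.clamp(1e-10).log10()

with torch.no_grad():
    framewise, clipwise = model(logmel.unsqueeze(0))  # [1, time_frames, classes], [1, classes]

# Each frame covers HOP_LENGTH / SAMPLE_RATE = 320 / 32000 = 10 ms of audio, so thresholding
# the framewise probabilities (e.g. at 0.5) gives the start/end times of each detected event.
active = framewise.squeeze(0) > 0.5
print(active.shape, clipwise.squeeze(0).topk(3))
```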
266 | 267 |
268 |
269 | References (click to expand) 270 | 271 | * https://github.com/qiuqiangkong/audioset_tagging_cnn 272 | * https://github.com/YuanGongND/ast 273 | * https://github.com/frednam93/FilterAugSED 274 | * https://github.com/lRomul/argus-freesound 275 | 276 |
277 | 278 |
279 | Citations (click to expand) 280 | 281 | ``` 282 | @misc{kong2020panns, 283 | title={PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition}, 284 | author={Qiuqiang Kong and Yin Cao and Turab Iqbal and Yuxuan Wang and Wenwu Wang and Mark D. Plumbley}, 285 | year={2020}, 286 | eprint={1912.10211}, 287 | archivePrefix={arXiv}, 288 | primaryClass={cs.SD} 289 | } 290 | 291 | @misc{gong2021ast, 292 | title={AST: Audio Spectrogram Transformer}, 293 | author={Yuan Gong and Yu-An Chung and James Glass}, 294 | year={2021}, 295 | eprint={2104.01778}, 296 | archivePrefix={arXiv}, 297 | primaryClass={cs.SD} 298 | } 299 | 300 | @misc{nam2021heavily, 301 | title={Heavily Augmented Sound Event Detection utilizing Weak Predictions}, 302 | author={Hyeonuk Nam and Byeong-Yun Ko and Gyeong-Tae Lee and Seong-Hu Kim and Won-Ho Jung and Sang-Min Choi and Yong-Hwa Park}, 303 | year={2021}, 304 | eprint={2107.03649}, 305 | archivePrefix={arXiv}, 306 | primaryClass={eess.AS} 307 | } 308 | ``` 309 | 310 |
-------------------------------------------------------------------------------- /datasets/audioset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | 5 | class AudioSet(Dataset): 6 | WINDS = ['/m/09x0r', '/m/05zppz', '/m/02zsn', '/m/0ytgt', '/m/01h8n0', '/m/02qldy', '/m/0261r1', '/m/0brhx', '/m/07p6fty', '/m/07q4ntr', '/m/07rwj3x', '/m/07sr1lc', '/m/04gy_2', '/t/dd00135', '/m/03qc9zr', '/m/02rtxlg', '/m/01j3sz', '/t/dd00001', '/m/07r660_', '/m/07s04w4', '/m/07sq110', '/m/07rgt08', '/m/0463cq4', '/t/dd00002', '/m/07qz6j3', '/m/07qw_06', '/m/07plz5l', '/m/015lz1', '/m/0l14jd', '/m/01swy6', '/m/02bk07', '/m/01c194', '/t/dd00003', '/t/dd00004', '/t/dd00005', '/t/dd00006', '/m/06bxc', '/m/02fxyj', '/m/07s2xch', '/m/07r4k75', '/m/01w250', '/m/0lyf6', '/m/07mzm6', '/m/01d3sd', '/m/07s0dtb', '/m/07pyy8b', '/m/07q0yl5', '/m/01b_21', '/m/0dl9sf8', '/m/01hsr_', '/m/07ppn3j', '/m/06h7j', '/m/07qv_x_', '/m/07pbtc8', '/m/03cczk', '/m/07pdhp0', '/m/0939n_', '/m/01g90h', '/m/03q5_w', '/m/02p3nc', '/m/02_nn', '/m/0k65p', '/m/025_jnm', '/m/0l15bq', '/m/01jg02', '/m/01jg1z', '/m/053hz1', '/m/028ght', '/m/07rkbfh', '/m/03qtwd', '/m/07qfr4h', '/t/dd00013', '/m/0jbk', '/m/068hy', '/m/0bt9lr', '/m/05tny_', '/m/07r_k2n', '/m/07qf0zm', '/m/07rc7d9', '/m/0ghcn6', '/t/dd00136', '/m/01yrx', '/m/02yds9', '/m/07qrkrw', '/m/07rjwbb', '/m/07r81j2', '/m/0ch8v', '/m/03k3r', '/m/07rv9rh', '/m/07q5rw0', '/m/01xq0k1', '/m/07rpkh9', '/m/0239kh', '/m/068zj', '/t/dd00018', '/m/03fwl', '/m/07q0h5t', '/m/07bgp', '/m/025rv6n', '/m/09b5t', '/m/07st89h', '/m/07qn5dc', '/m/01rd7k', '/m/07svc2k', '/m/09ddx', '/m/07qdb04', '/m/0dbvp', '/m/07qwf61', '/m/01280g', '/m/0cdnk', '/m/04cvmfc', '/m/015p6', '/m/020bb7', '/m/07pggtn', '/m/07sx8x_', '/m/0h0rv', '/m/07r_25d', '/m/04s8yn', '/m/07r5c2p', '/m/09d5_', '/m/07r_80w', '/m/05_wcq', '/m/01z5f', '/m/06hps', '/m/04rmv', '/m/07r4gkf', '/m/03vt0', '/m/09xqv', '/m/09f96', '/m/0h2mp', '/m/07pjwq1', '/m/01h3n', '/m/09ld4', '/m/07st88b', '/m/078jl', '/m/07qn4z3', '/m/032n05', '/m/04rlf', '/m/04szw', '/m/0fx80y', '/m/0342h', '/m/02sgy', '/m/018vs', '/m/042v_gx', '/m/06w87', '/m/01glhc', '/m/07s0s5r', '/m/018j2', '/m/0jtg0', '/m/04rzd', '/m/01bns_', '/m/07xzm', '/m/05148p4', '/m/05r5c', '/m/01s0ps', '/m/013y1f', '/m/03xq_f', '/m/03gvt', '/m/0l14qv', '/m/01v1d8', '/m/03q5t', '/m/0l14md', '/m/02hnl', '/m/0cfdd', '/m/026t6', '/m/06rvn', '/m/03t3fj', '/m/02k_mr', '/m/0bm02', '/m/011k_j', '/m/01p970', '/m/01qbl', '/m/03qtq', '/m/01sm1g', '/m/07brj', '/m/05r5wn', '/m/0xzly', '/m/0mbct', '/m/016622', '/m/0j45pbj', '/m/0dwsp', '/m/0dwtp', '/m/0dwt5', '/m/0l156b', '/m/05pd6', '/m/01kcd', '/m/0319l', '/m/07gql', '/m/07c6l', '/m/0l14_3', '/m/02qmj0d', '/m/07y_7', '/m/0d8_n', '/m/01xqw', '/m/02fsn', '/m/085jw', '/m/0l14j_', '/m/06ncr', '/m/01wy6', '/m/03m5k', '/m/0395lw', '/m/03w41f', '/m/027m70_', '/m/0gy1t2s', '/m/07n_g', '/m/0f8s22', '/m/026fgl', '/m/0150b9', '/m/03qjg', '/m/0mkg', '/m/0192l', '/m/02bxd', '/m/0l14l2', '/m/07kc_', '/m/0l14t7', '/m/01hgjl', '/m/064t9', '/m/0glt670', '/m/02cz_7', '/m/06by7', '/m/03lty', '/m/05r6t', '/m/0dls3', '/m/0dl5d', '/m/07sbbz2', '/m/05w3f', '/m/06j6l', '/m/0gywn', '/m/06cqb', '/m/01lyv', '/m/015y_n', '/m/0gg8l', '/m/02x8m', '/m/02w4v', '/m/06j64v', '/m/03_d0', '/m/026z9', '/m/0ggq0m', '/m/05lls', '/m/02lkt', '/m/03mb9', '/m/07gxw', '/m/07s72n', '/m/0283d', '/m/0m0jc', '/m/08cyft', '/m/0fd3y', '/m/07lnk', '/m/0g293', '/m/0ln16', '/m/0326g', '/m/0155w', '/m/05fw6t', '/m/02v2lh', 
'/m/0y4f8', '/m/0z9c', '/m/0164x2', '/m/0145m', '/m/02mscn', '/m/016cjb', '/m/028sqc', '/m/015vgc', '/m/0dq0md', '/m/06rqw', '/m/02p0sh1', '/m/05rwpb', 7 | '/m/074ft', '/m/025td0t', '/m/02cjck', '/m/03r5q_', '/m/0l14gg', '/m/07pkxdp', '/m/01z7dr', '/m/0140xf', '/m/0ggx5q', '/m/04wptg', '/t/dd00031', '/t/dd00032', '/t/dd00033', '/t/dd00034', '/t/dd00035', '/t/dd00036', '/t/dd00037', '/m/03m9d0z', '/m/09t49', '/t/dd00092', '/m/0jb2l', '/m/0ngt1', '/m/0838f', '/m/06mb1', '/m/07r10fb', '/t/dd00038', '/m/0j6m2', '/m/0j2kx', '/m/05kq4', '/m/034srq', '/m/06wzb', '/m/07swgks', '/m/02_41', '/m/07pzfmf', '/m/07yv9', '/m/019jd', '/m/0hsrw', '/m/056ks2', '/m/02rlv9', '/m/06q74', '/m/012f08', '/m/0k4j', '/m/0912c9', '/m/07qv_d5', '/m/02mfyn', '/m/04gxbd', '/m/07rknqz', '/m/0h9mv', '/t/dd00134', '/m/0ltv', '/m/07r04', '/m/0gvgw0', '/m/05x_td', '/m/02rhddq', '/m/03cl9h', '/m/01bjv', '/m/03j1ly', '/m/04qvtq', '/m/012n7d', '/m/012ndj', '/m/04_sv', '/m/0btp2', '/m/06d_3', '/m/07jdr', '/m/04zmvq', '/m/0284vy3', '/m/01g50p', '/t/dd00048', '/m/0195fx', '/m/0k5j', '/m/014yck', '/m/04229', '/m/02l6bg', '/m/09ct_', '/m/0cmf2', '/m/0199g', '/m/06_fw', '/m/02mk9', '/t/dd00065', '/m/08j51y', '/m/01yg9g', '/m/01j4z9', '/t/dd00066', '/t/dd00067', '/m/01h82_', '/t/dd00130', '/m/07pb8fc', '/m/07q2z82', '/m/02dgv', '/m/03wwcy', '/m/07r67yg', '/m/02y_763', '/m/07rjzl8', '/m/07r4wb8', '/m/07qcpgn', '/m/07q6cd_', '/m/0642b4', '/m/0fqfqc', '/m/04brg2', '/m/023pjk', '/m/07pn_8q', '/m/0dxrf', '/m/0fx9l', '/m/02pjr4', '/m/02jz0l', '/m/0130jx', '/m/03dnzn', '/m/03wvsk', '/m/01jt3m', '/m/012xff', '/m/04fgwm', '/m/0d31p', '/m/01s0vc', '/m/03v3yw', '/m/0242l', '/m/01lsmm', '/m/02g901', '/m/05rj2', '/m/0316dw', '/m/0c2wf', '/m/01m2v', '/m/081rb', '/m/07pp_mv', '/m/07cx4', '/m/07pp8cl', '/m/01hnzm', '/m/02c8p', '/m/015jpf', '/m/01z47d', '/m/046dlr', '/m/03kmc9', '/m/0dgbq', '/m/030rvx', '/m/01y3hg', '/m/0c3f7m', '/m/04fq5q', '/m/0l156k', '/m/06hck5', '/t/dd00077', '/m/02bm9n', '/m/01x3z', '/m/07qjznt', '/m/07qjznl', '/m/0l7xg', '/m/05zc1', '/m/0llzx', '/m/02x984l', '/m/025wky1', '/m/024dl', '/m/01m4t', '/m/0dv5r', '/m/07bjf', '/m/07k1x', '/m/03l9g', '/m/03p19w', '/m/01b82r', '/m/02p01q', '/m/023vsd', '/m/0_ksk', '/m/01d380', '/m/014zdl', '/m/032s66', '/m/04zjc', '/m/02z32qm', '/m/0_1c', '/m/073cg4', '/m/0g6b5', '/g/122z_qxw', '/m/07qsvvw', '/m/07pxg6y', '/m/07qqyl4', '/m/083vt', '/m/07pczhz', '/m/07pl1bw', '/m/07qs1cx', '/m/039jq', '/m/07q7njn', '/m/07rn7sz', '/m/04k94', '/m/07rrlb6', '/m/07p6mqd', '/m/07qlwh6', '/m/07r5v4s', '/m/07prgkl', '/m/07pqc89', '/t/dd00088', '/m/07p7b8y', '/m/07qlf79', '/m/07ptzwd', '/m/07ptfmf', '/m/0dv3j', '/m/0790c', '/m/0dl83', '/m/07rqsjt', '/m/07qnq_y', '/m/07rrh0c', '/m/0b_fwt', '/m/02rr_', '/m/07m2kt', '/m/018w8', '/m/07pws3f', '/m/07ryjzk', '/m/07rdhzs', '/m/07pjjrj', '/m/07pc8lb', '/m/07pqn27', '/m/07rbp7_', '/m/07pyf11', '/m/07qb_dv', '/m/07qv4k0', '/m/07pdjhy', '/m/07s8j8t', '/m/07plct2', '/t/dd00112', '/m/07qcx4z', '/m/02fs_r', '/m/07qwdck', '/m/07phxs1', '/m/07rv4dm', '/m/07s02z0', '/m/07qh7jl', '/m/07qwyj0', '/m/07s34ls', '/m/07qmpdm', '/m/07p9k1k', '/m/07qc9xj', '/m/07rwm0c', '/m/07phhsh', '/m/07qyrcz', '/m/07qfgpx', '/m/07rcgpl', '/m/07p78v5', '/t/dd00121', '/m/07s12q4', '/m/028v0c', '/m/01v_m0', '/m/0b9m1', '/m/0hdsk', '/m/0c1dj', '/m/07pt_g0', '/t/dd00125', '/t/dd00126', '/t/dd00127', '/t/dd00128', '/t/dd00129', '/m/01b9nn', '/m/01jnbd', '/m/096m7z', '/m/06_y0by', '/m/07rgkc5', '/m/06xkwv', '/m/0g12c5', '/m/08p9q4', '/m/07szfh9', '/m/0chx_', '/m/0cj0r', '/m/07p_0gm', '/m/01jwx6', 
'/m/07c52', '/m/06bz3', '/m/07hvw1'] 8 | 9 | CLASSES = ['Speech', 'Male speech, man speaking', 'Female speech, woman speaking', 'Child speech, kid speaking', 'Conversation', 'Narration, monologue', 'Babbling', 'Speech synthesizer', 'Shout', 'Bellow', 'Whoop', 'Yell', 'Battle cry', 'Children shouting', 'Screaming', 'Whispering', 'Laughter', 'Baby laughter', 'Giggle', 'Snicker', 'Belly laugh', 'Chuckle, chortle', 'Crying, sobbing', 'Baby cry, infant cry', 'Whimper', 'Wail, moan', 'Sigh', 'Singing', 'Choir', 'Yodeling', 'Chant', 'Mantra', 'Male singing', 'Female singing', 'Child singing', 'Synthetic singing', 'Rapping', 'Humming', 'Groan', 'Grunt', 'Whistling', 'Breathing', 'Wheeze', 'Snoring', 'Gasp', 'Pant', 'Snort', 'Cough', 'Throat clearing', 'Sneeze', 'Sniff', 'Run', 'Shuffle', 'Walk, footsteps', 'Chewing, mastication', 'Biting', 'Gargling', 'Stomach rumble', 'Burping, eructation', 'Hiccup', 'Fart', 'Hands', 'Finger snapping', 'Clapping', 'Heart sounds, heartbeat', 'Heart murmur', 'Cheering', 'Applause', 'Chatter', 'Crowd', 'Hubbub, speech noise, speech babble', 'Children playing', 'Animal', 'Domestic animals, pets', 'Dog', 'Bark', 'Yip', 'Howl', 'Bow-wow', 'Growling', 'Whimper (dog)', 'Cat', 'Purr', 'Meow', 'Hiss', 'Caterwaul', 'Livestock, farm animals, working animals', 'Horse', 'Clip-clop', 'Neigh, whinny', 'Cattle, bovinae', 'Moo', 'Cowbell', 'Pig', 'Oink', 'Goat', 'Bleat', 'Sheep', 'Fowl', 'Chicken, rooster', 'Cluck', 'Crowing, cock-a-doodle-doo', 'Turkey', 'Gobble', 'Duck', 'Quack', 'Goose', 'Honk', 'Wild animals', 'Roaring cats (lions, tigers)', 'Roar', 'Bird', 'Bird vocalization, bird call, bird song', 'Chirp, tweet', 'Squawk', 'Pigeon, dove', 'Coo', 'Crow', 'Caw', 'Owl', 'Hoot', 'Bird flight, flapping wings', 'Canidae, dogs, wolves', 'Rodents, rats, mice', 'Mouse', 'Patter', 'Insect', 'Cricket', 'Mosquito', 'Fly, housefly', 'Buzz', 'Bee, wasp, etc.', 'Frog', 'Croak', 'Snake', 'Rattle', 'Whale vocalization', 'Music', 'Musical instrument', 'Plucked string instrument', 'Guitar', 'Electric guitar', 'Bass guitar', 'Acoustic guitar', 'Steel guitar, slide guitar', 'Tapping (guitar technique)', 'Strum', 'Banjo', 'Sitar', 'Mandolin', 'Zither', 'Ukulele', 'Keyboard (musical)', 'Piano', 'Electric piano', 'Organ', 'Electronic organ', 'Hammond organ', 'Synthesizer', 'Sampler', 'Harpsichord', 'Percussion', 'Drum kit', 'Drum machine', 'Drum', 'Snare drum', 'Rimshot', 'Drum roll', 'Bass drum', 'Timpani', 'Tabla', 'Cymbal', 'Hi-hat', 'Wood block', 'Tambourine', 'Rattle (instrument)', 'Maraca', 'Gong', 'Tubular bells', 'Mallet percussion', 'Marimba, xylophone', 'Glockenspiel', 'Vibraphone', 'Steelpan', 'Orchestra', 'Brass instrument', 'French horn', 'Trumpet', 'Trombone', 'Bowed string instrument', 'String section', 'Violin, fiddle', 'Pizzicato', 'Cello', 'Double bass', 'Wind instrument, woodwind instrument', 'Flute', 'Saxophone', 'Clarinet', 'Harp', 'Bell', 'Church bell', 'Jingle bell', 'Bicycle bell', 'Tuning fork', 'Chime', 'Wind chime', 'Change ringing (campanology)', 'Harmonica', 'Accordion', 'Bagpipes', 'Didgeridoo', 'Shofar', 'Theremin', 'Singing bowl', 'Scratching (performance technique)', 'Pop music', 'Hip hop music', 'Beatboxing', 'Rock music', 'Heavy metal', 'Punk rock', 'Grunge', 'Progressive rock', 'Rock and roll', 'Psychedelic rock', 'Rhythm and blues', 'Soul music', 'Reggae', 'Country', 'Swing music', 'Bluegrass', 'Funk', 'Folk music', 'Middle Eastern music', 'Jazz', 'Disco', 'Classical music', 'Opera', 'Electronic music', 'House music', 'Techno', 'Dubstep', 'Drum and bass', 
'Electronica', 'Electronic dance music', 'Ambient music', 'Trance music', 'Music of Latin America', 'Salsa music', 'Flamenco', 'Blues', 'Music for children', 'New-age music', 'Vocal music', 'A capella', 'Music of Africa', 'Afrobeat', 'Christian music', 'Gospel music', 'Music of Asia', 'Carnatic music', 'Music of Bollywood', 'Ska', 'Traditional music', 'Independent music', 'Song', 'Background music', 'Theme music', 'Jingle (music)', 'Soundtrack music', 'Lullaby', 10 | 'Video game music', 'Christmas music', 'Dance music', 'Wedding music', 'Happy music', 'Funny music', 'Sad music', 'Tender music', 'Exciting music', 'Angry music', 'Scary music', 'Wind', 'Rustling leaves', 'Wind noise (microphone)', 'Thunderstorm', 'Thunder', 'Water', 'Rain', 'Raindrop', 'Rain on surface', 'Stream', 'Waterfall', 'Ocean', 'Waves, surf', 'Steam', 'Gurgling', 'Fire', 'Crackle', 'Vehicle', 'Boat, Water vehicle', 'Sailboat, sailing ship', 'Rowboat, canoe, kayak', 'Motorboat, speedboat', 'Ship', 'Motor vehicle (road)', 'Car', 'Vehicle horn, car horn, honking', 'Toot', 'Car alarm', 'Power windows, electric windows', 'Skidding', 'Tire squeal', 'Car passing by', 'Race car, auto racing', 'Truck', 'Air brake', 'Air horn, truck horn', 'Reversing beeps', 'Ice cream truck, ice cream van', 'Bus', 'Emergency vehicle', 'Police car (siren)', 'Ambulance (siren)', 'Fire engine, fire truck (siren)', 'Motorcycle', 'Traffic noise, roadway noise', 'Rail transport', 'Train', 'Train whistle', 'Train horn', 'Railroad car, train wagon', 'Train wheels squealing', 'Subway, metro, underground', 'Aircraft', 'Aircraft engine', 'Jet engine', 'Propeller, airscrew', 'Helicopter', 'Fixed-wing aircraft, airplane', 'Bicycle', 'Skateboard', 'Engine', 'Light engine (high frequency)', "Dental drill, dentist's drill", 'Lawn mower', 'Chainsaw', 'Medium engine (mid frequency)', 'Heavy engine (low frequency)', 'Engine knocking', 'Engine starting', 'Idling', 'Accelerating, revving, vroom', 'Door', 'Doorbell', 'Ding-dong', 'Sliding door', 'Slam', 'Knock', 'Tap', 'Squeak', 'Cupboard open or close', 'Drawer open or close', 'Dishes, pots, and pans', 'Cutlery, silverware', 'Chopping (food)', 'Frying (food)', 'Microwave oven', 'Blender', 'Water tap, faucet', 'Sink (filling or washing)', 'Bathtub (filling or washing)', 'Hair dryer', 'Toilet flush', 'Toothbrush', 'Electric toothbrush', 'Vacuum cleaner', 'Zipper (clothing)', 'Keys jangling', 'Coin (dropping)', 'Scissors', 'Electric shaver, electric razor', 'Shuffling cards', 'Typing', 'Typewriter', 'Computer keyboard', 'Writing', 'Alarm', 'Telephone', 'Telephone bell ringing', 'Ringtone', 'Telephone dialing, DTMF', 'Dial tone', 'Busy signal', 'Alarm clock', 'Siren', 'Civil defense siren', 'Buzzer', 'Smoke detector, smoke alarm', 'Fire alarm', 'Foghorn', 'Whistle', 'Steam whistle', 'Mechanisms', 'Ratchet, pawl', 'Clock', 'Tick', 'Tick-tock', 'Gears', 'Pulleys', 'Sewing machine', 'Mechanical fan', 'Air conditioning', 'Cash register', 'Printer', 'Camera', 'Single-lens reflex camera', 'Tools', 'Hammer', 'Jackhammer', 'Sawing', 'Filing (rasp)', 'Sanding', 'Power tool', 'Drill', 'Explosion', 'Gunshot, gunfire', 'Machine gun', 'Fusillade', 'Artillery fire', 'Cap gun', 'Fireworks', 'Firecracker', 'Burst, pop', 'Eruption', 'Boom', 'Wood', 'Chop', 'Splinter', 'Crack', 'Glass', 'Chink, clink', 'Shatter', 'Liquid', 'Splash, splatter', 'Slosh', 'Squish', 'Drip', 'Pour', 'Trickle, dribble', 'Gush', 'Fill (with liquid)', 'Spray', 'Pump (liquid)', 'Stir', 'Boiling', 'Sonar', 'Arrow', 'Whoosh, swoosh, swish', 'Thump, thud', 'Thunk', 
'Electronic tuner', 'Effects unit', 'Chorus effect', 'Basketball bounce', 'Bang', 'Slap, smack', 'Whack, thwack', 'Smash, crash', 'Breaking', 'Bouncing', 'Whip', 'Flap', 'Scratch', 'Scrape', 'Rub', 'Roll', 'Crushing', 'Crumpling, crinkling', 'Tearing', 'Beep, bleep', 'Ping', 'Ding', 'Clang', 'Squeal', 'Creak', 'Rustle', 'Whir', 'Clatter', 'Sizzle', 'Clicking', 'Clickety-clack', 'Rumble', 'Plop', 'Jingle, tinkle', 'Hum', 'Zing', 'Boing', 'Crunch', 'Silence', 'Sine wave', 'Harmonic', 'Chirp tone', 'Sound effect', 'Pulse', 'Inside, small room', 'Inside, large room or hall', 'Inside, public space', 'Outside, urban or manmade', 'Outside, rural or natural', 'Reverberation', 'Echo', 'Noise', 'Environmental noise', 'Static', 'Mains hum', 'Distortion', 'Sidetone', 'Cacophony', 'White noise', 'Pink noise', 'Throbbing', 'Vibration', 'Television', 'Radio', 'Field recording'] 11 | 12 | def __init__(self, *args, **kwargs) -> None: 13 | super().__init__() --------------------------------------------------------------------------------