├── utils ├── __init__.py ├── schedulers.py ├── optimizers.py ├── metrics.py ├── losses.py ├── utils.py └── visualize.py ├── assests ├── test.mp3 ├── test.wav ├── test2.wav ├── sed_result.png └── noises │ └── voices.wav ├── requirements.txt ├── models ├── __init__.py └── cnn14.py ├── LICENSE ├── datasets ├── __init__.py ├── urbansound.py ├── esc50.py ├── fsdkaggle.py ├── speechcommands.py ├── aug_test.ipynb ├── transforms.py └── audioset.py ├── configs ├── audioset.yaml ├── audioset_sed.yaml ├── esc50.yaml ├── fsd2018.yaml └── scv1.yaml ├── tools ├── val.py ├── infer.py ├── sed_infer.py └── train.py ├── .gitignore └── README.md /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assests/test.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/audio-tagging/HEAD/assests/test.mp3 -------------------------------------------------------------------------------- /assests/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/audio-tagging/HEAD/assests/test.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/audio-tagging/HEAD/requirements.txt -------------------------------------------------------------------------------- /assests/test2.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/audio-tagging/HEAD/assests/test2.wav -------------------------------------------------------------------------------- /assests/sed_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/audio-tagging/HEAD/assests/sed_result.png -------------------------------------------------------------------------------- /assests/noises/voices.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sithu31296/audio-tagging/HEAD/assests/noises/voices.wav -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .cnn14 import CNN14, CNN14DecisionLevelMax 2 | 3 | __all__ = { 4 | "cnn14": CNN14, 5 | "cnn14decisionlevelmax": CNN14DecisionLevelMax 6 | } 7 | 8 | 9 | def get_model(model_name: str, num_classes: int = 527): 10 | assert model_name in __all__.keys(), f"Unavailable model name >> {model_name}.\nList of available model names: {list(__all__.keys())}" 11 | return __all__[model_name](num_classes) -------------------------------------------------------------------------------- /utils/schedulers.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import StepLR, MultiStepLR 2 | 3 | 4 | __all__ = { 5 | "steplr": StepLR 6 | } 7 | 8 | def get_scheduler(cfg, optimizer): 9 | scheduler_name = cfg['SCHEDULER']['NAME'] 10 | assert scheduler_name in __all__.keys(), f"Unavailable scheduler name >> {scheduler_name}.\nList of available schedulers: {list(__all__.keys())}" 11 | return __all__[scheduler_name](optimizer, *cfg['SCHEDULER']['PARAMS']) 12 | 13 | 14 
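# A minimal usage sketch for the get_model() factory defined in models/__init__.py above; run it
# from the repository root so the models package is importable. get_model() only builds the
# architecture; the pretrained weights (e.g. checkpoints/cnn14.pth, as referenced in the configs)
# are assumed to have been downloaded separately and are loaded explicitly, as tools/infer.py does.
import torch
from models import get_model

model = get_model('cnn14', num_classes=527)   # registry keys are lowercase: 'cnn14', 'cnn14decisionlevelmax'
model.load_state_dict(torch.load('checkpoints/cnn14.pth', map_location='cpu'))
model.eval()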
| if __name__ == '__main__': 15 | import torch 16 | model = torch.nn.Linear(1024, 50) 17 | optimizer = torch.optim.SGD(model.parameters(), 0.01) 18 | scheduler = MultiStepLR(optimizer, [10, 20, 30], gamma=0.5) -------------------------------------------------------------------------------- /utils/optimizers.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.optim import AdamW, SGD 3 | 4 | 5 | def get_optimizer(model: nn.Module, optimizer: str, lr: float, weight_decay: float = 0.01): 6 | wd_params, nwd_params = [], [] 7 | for p in model.parameters(): 8 | if p.dim() == 1: 9 | nwd_params.append(p) 10 | else: 11 | wd_params.append(p) 12 | 13 | params = [ 14 | {"params": wd_params}, 15 | {"params": nwd_params, "weight_decay": 0} 16 | ] 17 | 18 | if optimizer == 'adamw': 19 | return AdamW(params, lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=weight_decay) 20 | else: 21 | return SGD(params, lr, momentum=0.9, weight_decay=weight_decay) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 sithu3 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
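A minimal sketch of how the get_optimizer() helper in utils/optimizers.py above behaves: 1-D parameters (biases, norm weights) are routed to a parameter group with weight decay disabled, while all other parameters keep the configured decay. The tiny model below is a stand-in used purely for illustration; run it from the repo root so the utils package resolves.

import torch
from torch import nn
from utils.optimizers import get_optimizer

model = nn.Sequential(nn.Linear(64, 527), nn.BatchNorm1d(527))
optimizer = get_optimizer(model, 'adamw', lr=1e-4, weight_decay=1e-3)

for group in optimizer.param_groups:
    # the second group (the 1-D parameters) should report weight_decay == 0
    print(len(group['params']), group['weight_decay'])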
22 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from torch import distributed as dist 2 | from torch.utils.data import DistributedSampler, SequentialSampler, RandomSampler 3 | from .esc50 import ESC50 4 | from .audioset import AudioSet 5 | from .fsdkaggle import FSDKaggle2018 6 | from .urbansound import UrbanSound8k 7 | from .speechcommands import SpeechCommandsv1 8 | from .speechcommands import SpeechCommandsv2 9 | 10 | 11 | __all__ = { 12 | "esc50": ESC50, 13 | "audioset": AudioSet, 14 | "fsdkaggle2018": FSDKaggle2018, 15 | "urbansound8k": UrbanSound8k, 16 | "speechcommandsv1": SpeechCommandsv1, 17 | "speechcommandsv2": SpeechCommandsv2 18 | } 19 | 20 | 21 | def get_train_dataset(dataset_cfg, aug_cfg, transform, spec_transform): 22 | dataset_name = dataset_cfg['NAME'] 23 | assert dataset_name in __all__.keys(), f"Unavailable dataset name >> {dataset_name}.\nList of available datasets: {list(__all__.keys())}" 24 | return __all__[dataset_name]('train', dataset_cfg, aug_cfg, transform, spec_transform) 25 | 26 | def get_val_dataset(dataset_cfg): 27 | dataset_name = dataset_cfg['NAME'] 28 | assert dataset_name in __all__.keys(), f"Unavailable dataset name >> {dataset_name}.\nList of available datasets: {list(__all__.keys())}" 29 | return __all__[dataset_name]('val', dataset_cfg) 30 | 31 | 32 | def get_sampler(ddp_enable, train_dataset, val_dataset): 33 | if not ddp_enable: 34 | train_sampler = RandomSampler(train_dataset) 35 | else: 36 | train_sampler = DistributedSampler(train_dataset, dist.get_world_size(), dist.get_rank(), shuffle=True) 37 | val_sampler = SequentialSampler(val_dataset) 38 | return train_sampler, val_sampler -------------------------------------------------------------------------------- /configs/audioset.yaml: -------------------------------------------------------------------------------- 1 | DEVICE: cpu # device used for training 2 | 3 | MODEL: 4 | NAME: cnn14 # name of the model you are using 5 | PRETRAINED: '' 6 | 7 | DATASET: 8 | NAME: audioset # dataset name 9 | ROOT: '' # dataset root path 10 | METRIC: mAP 11 | SAMPLE_RATE: 32000 12 | AUDIO_LENGTH: 5 13 | WIN_LENGTH: 1024 14 | HOP_LENGTH: 320 15 | N_MELS: 64 16 | FMIN: 50 17 | FMAX: 14000 18 | 19 | AUG: 20 | MIXUP: 0.0 21 | MIXUP_ALPHA: 10 22 | SMOOTHING: 0.1 23 | TIME_MASK: 96 24 | FREQ_MASK: 24 25 | 26 | TRAIN: 27 | EPOCHS: 100 # number of epochs to train 28 | EVAL_INTERVAL: 10 # interval to evaluate the model during training 29 | BATCH_SIZE: 16 # batch size used to train 30 | LOSS: bcelogits # loss function name (ce, bce, bcelogits, label_smooth, soft_target) 31 | AMP: true # use Automatic Mixed Precision training or not 32 | DDP: false 33 | SAVE_DIR: 'output' # output folder name used for saving the trained model and logs 34 | 35 | OPTIMIZER: 36 | NAME: adamw 37 | LR: 0.0001 # initial learning rate used in optimizer 38 | WEIGHT_DECAY: 0.001 # decay rate use in optimizer 39 | 40 | SCHEDULER: 41 | NAME: steplr 42 | PARAMS: [30, 0.1] 43 | 44 | 45 | TEST: 46 | MODE: file # inference mode (file, mic) 47 | FILE: 'assests/test.wav' # audio file name (not use if you choose MODE=mic) 48 | MODEL_PATH: 'checkpoints/cnn14.pth' # trained model path 49 | TOPK: 5 -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from 
torch import Tensor 4 | from sklearn import metrics as skmetrics 5 | from scipy import stats 6 | 7 | 8 | class Metrics: 9 | def __init__(self, metric='accuracy') -> None: 10 | assert metric in ['accuracy', 'mAP', 'map'] 11 | self.metric = metric 12 | self.count = 0 13 | self.acc = 0.0 14 | self.preds = [] 15 | self.targets = [] 16 | 17 | def _update_accuracy(self, pred: Tensor, target: Tensor): 18 | self.acc += (pred.argmax(dim=-1) == target.argmax(dim=-1)).sum(dim=0).item() 19 | self.count += target.shape[0] 20 | 21 | def _update_mAP(self, pred: Tensor, target: Tensor): 22 | self.preds.append(pred) 23 | self.targets.append(target) 24 | 25 | def _compute_accuracy(self): 26 | return [(self.acc / self.count) * 100] 27 | 28 | def _compute_mAP(self): 29 | preds = torch.cat(self.preds, dim=0).cpu().numpy() 30 | targets = torch.cat(self.targets, dim=0).cpu().numpy() 31 | ap = skmetrics.average_precision_score(targets, preds, average=None) 32 | auc = skmetrics.roc_auc_score(targets, preds, average=None) 33 | mAP = ap.mean() 34 | mAUC = auc.mean() 35 | d_prime = stats.norm().ppf(mAUC) * math.sqrt(2.0) 36 | return mAP * 100, mAUC * 100, d_prime * 100 37 | 38 | def update(self, pred: Tensor, target: Tensor): 39 | if self.metric == 'accuracy': 40 | self._update_accuracy(pred, target) 41 | else: 42 | self._update_mAP(pred, target) 43 | 44 | def compute(self): 45 | if self.metric == 'accuracy': 46 | return self._compute_accuracy() 47 | else: 48 | return self._compute_mAP() -------------------------------------------------------------------------------- /configs/audioset_sed.yaml: -------------------------------------------------------------------------------- 1 | DEVICE: cpu # device used for training 2 | 3 | MODEL: 4 | NAME: cnn14decisionlevelmax # name of the model you are using 5 | PRETRAINED: '' 6 | 7 | DATASET: 8 | NAME: audioset # dataset name 9 | ROOT: '' # dataset root path 10 | METRIC: mAP 11 | SAMPLE_RATE: 32000 12 | AUDIO_LENGTH: 5 13 | WIN_LENGTH: 1024 14 | HOP_LENGTH: 320 15 | N_MELS: 64 16 | FMIN: 50 17 | FMAX: 14000 18 | 19 | AUG: 20 | MIXUP: 0.0 21 | MIXUP_ALPHA: 10 22 | SMOOTHING: 0.1 23 | TIME_MASK: 96 24 | FREQ_MASK: 24 25 | 26 | TRAIN: 27 | EPOCHS: 100 # number of epochs to train 28 | EVAL_INTERVAL: 10 # interval to evaluate the model during training 29 | BATCH_SIZE: 16 # batch size used to train 30 | LOSS: bcelogits # loss function name (ce, bce, bcelogits, label_smooth, soft_target) 31 | AMP: true # use Automatic Mixed Precision training or not 32 | DDP: false 33 | SAVE_DIR: 'output' # output folder name used for saving the trained model and logs 34 | 35 | OPTIMIZER: 36 | NAME: adamw 37 | LR: 0.0001 # initial learning rate used in optimizer 38 | WEIGHT_DECAY: 0.001 # decay rate use in optimizer 39 | 40 | SCHEDULER: 41 | NAME: steplr 42 | PARAMS: [30, 0.1] 43 | 44 | 45 | TEST: 46 | MODE: file # inference mode (file, mic) 47 | FILE: 'assests/test.wav' # audio file name (not use if you choose MODE=mic) 48 | MODEL_PATH: 'checkpoints/cnn14_decisionlevelmax.pth' # trained model path 49 | THRESHOLD: 0.2 50 | PLOT: false -------------------------------------------------------------------------------- /configs/esc50.yaml: -------------------------------------------------------------------------------- 1 | DEVICE: cpu # device used for training 2 | 3 | MODEL: 4 | NAME: cnn14 # name of the model you are using 5 | PRETRAINED: 'checkpoints/cnn14.pth' 6 | 7 | DATASET: 8 | NAME: esc50 # dataset name 9 | ROOT: 'C:/Users/sithu/Documents/Datasets/ESC50' # dataset root path 10 | METRIC: accuracy 11 
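# The mAP / AUC / d-prime values reported by utils/metrics.py above are computed from per-class
# average precision and ROC-AUC averaged over classes, with d-prime derived from the mean AUC as
# d' = sqrt(2) * Phi^{-1}(AUC). A self-contained check with random scores (illustrative values only):
import math
import numpy as np
from scipy import stats
from sklearn import metrics as skmetrics

rng = np.random.default_rng(0)
targets = rng.integers(0, 2, size=(100, 5))   # multi-hot labels for 5 classes
preds = rng.random((100, 5))                  # model scores in [0, 1]

ap = skmetrics.average_precision_score(targets, preds, average=None)
auc = skmetrics.roc_auc_score(targets, preds, average=None)
d_prime = stats.norm().ppf(auc.mean()) * math.sqrt(2.0)
print(f"mAP={ap.mean():.3f} mAUC={auc.mean():.3f} d-prime={d_prime:.3f}")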
| SOURCE_SAMPLE: 44100 12 | SAMPLE_RATE: 32000 13 | AUDIO_LENGTH: 5 14 | WIN_LENGTH: 1024 15 | HOP_LENGTH: 320 16 | N_MELS: 64 17 | FMIN: 50 18 | FMAX: 14000 19 | 20 | AUG: 21 | MIXUP: 0.0 22 | MIXUP_ALPHA: 10 23 | SMOOTHING: 0.1 24 | FREQ_MASK: 24 25 | TIME_MASK: 96 26 | 27 | TRAIN: 28 | EPOCHS: 100 # number of epochs to train 29 | EVAL_INTERVAL: 5 # interval to evaluate the model during training 30 | BATCH_SIZE: 16 # batch size used to train 31 | LOSS: label_smooth # loss function name (ce, bce, bcelogits, label_smooth, soft_target) 32 | AMP: false # use Automatic Mixed Precision training or not 33 | DDP: false 34 | SAVE_DIR: output # output folder name used for saving the trained model and logs 35 | 36 | OPTIMIZER: 37 | NAME: adamw 38 | LR: 0.0001 # initial learning rate used in optimizer 39 | WEIGHT_DECAY: 0.001 # decay rate use in optimizer 40 | 41 | SCHEDULER: 42 | NAME: steplr 43 | PARAMS: [20, 0.5] 44 | 45 | 46 | TEST: 47 | MODE: file # inference mode (file, mic) 48 | FILE: 'assests/test2.wav' # audio file name (not use if you choose MODE=mic) 49 | MODEL_PATH: 'checkpoints/cnn14_esc50.pth' # trained model path 50 | TOPK: 5 -------------------------------------------------------------------------------- /configs/fsd2018.yaml: -------------------------------------------------------------------------------- 1 | DEVICE: cuda # device used for training 2 | 3 | MODEL: 4 | NAME: cnn14 # name of the model you are using 5 | PRETRAINED: 'checkpoints/cnn14.pth' 6 | 7 | DATASET: 8 | NAME: fsdkaggle2018 # dataset name 9 | ROOT: 'C:/Users/sithu/Documents/Datasets/FSDKaggle2018' # dataset root path 10 | METRIC: accuracy 11 | SOURCE_SAMPLE: 44100 12 | SAMPLE_RATE: 32000 13 | AUDIO_LENGTH: 5 14 | WIN_LENGTH: 1024 15 | HOP_LENGTH: 320 16 | N_MELS: 64 17 | FMIN: 50 18 | FMAX: 14000 19 | 20 | AUG: 21 | MIXUP: 0.0 22 | MIXUP_ALPHA: 10 23 | SMOOTHING: 0.1 24 | FREQ_MASK: 24 25 | TIME_MASK: 96 26 | 27 | TRAIN: 28 | EPOCHS: 100 # number of epochs to train 29 | EVAL_INTERVAL: 10 # interval to evaluate the model during training 30 | BATCH_SIZE: 16 # batch size used to train 31 | LOSS: label_smooth # loss function name (ce, bce, bcelogits, label_smooth, soft_target) 32 | AMP: true # use Automatic Mixed Precision training or not 33 | DDP: false 34 | SAVE_DIR: output # output folder name used for saving the trained model and logs 35 | 36 | OPTIMIZER: 37 | NAME: adamw 38 | LR: 0.0001 # initial learning rate used in optimizer 39 | WEIGHT_DECAY: 0.001 # decay rate use in optimizer 40 | 41 | SCHEDULER: 42 | NAME: steplr 43 | PARAMS: [30, 0.1] 44 | 45 | 46 | TEST: 47 | MODE: file # inference mode (file, mic) 48 | FILE: 'assests/test.wav' # audio file name (not use if you choose MODE=mic) 49 | MODEL_PATH: 'checkpoints/cnn14_fsdkaggle2018.pth' # trained model path 50 | TOPK: 5 -------------------------------------------------------------------------------- /configs/scv1.yaml: -------------------------------------------------------------------------------- 1 | DEVICE: cpu # device used for training 2 | 3 | MODEL: 4 | NAME: cnn14 # name of the model you are using 5 | PRETRAINED: 'checkpoints/cnn14.pth' 6 | 7 | DATASET: 8 | NAME: speechcommandsv1 # dataset name 9 | ROOT: 'C:/Users/sithu/Documents/Datasets/SpeechCommandsv1' # dataset root path 10 | METRIC: accuracy 11 | SOURCE_SAMPLE: 16000 12 | SAMPLE_RATE: 32000 13 | AUDIO_LENGTH: 1 14 | WIN_LENGTH: 1024 15 | HOP_LENGTH: 320 16 | N_MELS: 64 17 | FMIN: 50 18 | FMAX: 14000 19 | 20 | AUG: 21 | MIXUP: 0.0 22 | MIXUP_ALPHA: 10 23 | SMOOTHING: 0.1 24 | FREQ_MASK: 24 25 | 
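# The FREQ_MASK / TIME_MASK settings in this AUG block are SpecAugment-style mask widths applied to
# the log-mel spectrogram. The repo's actual implementation lives in datasets/transforms.py
# (get_spec_transforms), which is not shown in this listing; a rough torchaudio equivalent, for
# illustration only, could look like this:
import torch
from torchaudio import transforms as T

spec_aug = torch.nn.Sequential(
    T.FrequencyMasking(freq_mask_param=24),   # mask up to 24 of the 64 mel bins
    T.TimeMasking(time_mask_param=96),        # mask up to 96 time frames
)
log_mel = torch.randn(1, 64, 501)             # (channels, n_mels, frames) dummy spectrogram
augmented = spec_aug(log_mel)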
TIME_MASK: 96 26 | 27 | TRAIN: 28 | EPOCHS: 100 # number of epochs to train 29 | EVAL_INTERVAL: 10 # interval to evaluate the model during training 30 | BATCH_SIZE: 16 # batch size used to train 31 | LOSS: label_smooth # loss function name (ce, bce, bcelogits, label_smooth, soft_target) 32 | AMP: true # use Automatic Mixed Precision training or not 33 | DDP: false 34 | SAVE_DIR: output # output folder name used for saving the trained model and logs 35 | 36 | OPTIMIZER: 37 | NAME: adamw 38 | LR: 0.0001 # initial learning rate used in optimizer 39 | WEIGHT_DECAY: 0.001 # decay rate use in optimizer 40 | 41 | SCHEDULER: 42 | NAME: steplr 43 | PARAMS: [30, 0.1] 44 | 45 | 46 | TEST: 47 | MODE: file # inference mode (file, mic) 48 | FILE: 'assests/test.wav' # audio file name (not use if you choose MODE=mic) 49 | MODEL_PATH: 'checkpoints/cnn14_speechcommandsv1.pth' # trained model path 50 | TOPK: 5 -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | from torch.nn import CrossEntropyLoss, BCELoss, BCEWithLogitsLoss 4 | 5 | 6 | class CrossEntropy(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.ce = CrossEntropyLoss() 10 | 11 | def forward(self, pred: Tensor, target: Tensor) -> Tensor: 12 | target = target.argmax(dim=1) 13 | loss = self.ce(pred, target) 14 | return loss 15 | 16 | 17 | class LabelSmoothCrossEntropy(nn.Module): 18 | def __init__(self, smoothing=0.1): 19 | super().__init__() 20 | assert smoothing < 1.0 21 | self.smoothing = smoothing 22 | self.confidence = 1. - smoothing 23 | self.log_softmax = nn.LogSoftmax(dim=-1) 24 | 25 | def forward(self, pred: Tensor, target: Tensor) -> Tensor: 26 | pred = self.log_softmax(pred) 27 | target = target.argmax(dim=1) 28 | nll_loss = -pred.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1) 29 | smooth_loss = -pred.mean(dim=-1) 30 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 31 | return loss.mean() 32 | 33 | 34 | class SoftTargetCrossEntropy(nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | self.log_softmax = nn.LogSoftmax(dim=-1) 38 | 39 | def forward(self, pred: Tensor, target: Tensor) -> Tensor: 40 | pred = self.log_softmax(pred) 41 | loss = (-target * pred).sum(dim=-1) 42 | return loss.mean() 43 | 44 | 45 | __all__ = { 46 | # for audio classification, the following can be used 47 | 'ce': CrossEntropy, 48 | 'label_smooth': LabelSmoothCrossEntropy, 49 | 50 | # for audio tagging and sed, the following can be used 51 | 'bce': BCELoss, 52 | 'bcelogits': BCEWithLogitsLoss, 53 | 'soft_target': SoftTargetCrossEntropy 54 | } 55 | 56 | 57 | def get_loss(loss_fn_name: str): 58 | assert loss_fn_name in __all__.keys(), f"Unavailable loss function name >> {loss_fn_name}.\nList of available loss functions: {list(__all__.keys())}" 59 | return __all__[loss_fn_name]() 60 | 61 | 62 | if __name__ == '__main__': 63 | import torch 64 | from torch.nn import functional as F 65 | torch.manual_seed(123) 66 | B = 2 67 | C = 10 68 | x = torch.rand(B, C) 69 | y = torch.rand(B, C) 70 | loss_fn = CrossEntropy() 71 | loss = loss_fn(x, y) 72 | print(loss) 73 | -------------------------------------------------------------------------------- /tools/val.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import torch 4 | import multiprocessing as mp 5 | from pathlib import 
Path 6 | from pprint import pprint 7 | from tabulate import tabulate 8 | from tqdm import tqdm 9 | from torch.utils.data import DataLoader 10 | 11 | import sys 12 | sys.path.insert(0, '.') 13 | from models import get_model 14 | from datasets import get_val_dataset 15 | from utils.utils import setup_cudnn 16 | from utils.metrics import Metrics 17 | 18 | 19 | @torch.no_grad() 20 | def evaluate(dataloader, model, device, metric='accuracy'): 21 | print('Evaluating...') 22 | model.eval() 23 | ametrics = Metrics(metric) 24 | 25 | for audio, target in tqdm(dataloader, ): 26 | audio = audio.to(device) 27 | target = target.to(device) 28 | pred = model(audio) 29 | ametrics.update(pred, target) 30 | 31 | return ametrics.compute() 32 | 33 | 34 | def main(cfg): 35 | save_dir = Path(cfg['TRAIN']['SAVE_DIR']) 36 | device = torch.device(cfg['DEVICE']) 37 | metric_name = cfg['DATASET']['METRIC'] 38 | num_workers = mp.cpu_count() 39 | 40 | dataset = get_val_dataset(cfg['DATASET']) 41 | dataloader = DataLoader(dataset, batch_size=cfg['TRAIN']['BATCH_SIZE'], num_workers=num_workers, pin_memory=True) 42 | model = get_model(cfg['MODEL']['NAME'], dataset.num_classes) 43 | 44 | try: 45 | model_weights = save_dir / f"{cfg['MODEL']['NAME']}_{cfg['DATASET']['NAME']}.pth" 46 | model.load_state_dict(torch.load(str(model_weights), map_location='cpu')) 47 | print(f"Loading Model and trained weights from {model_weights}") 48 | except: 49 | print(f"Please consider placing your model's weights in {save_dir}") 50 | 51 | model = model.to(device) 52 | 53 | if metric_name == 'accuracy': 54 | acc = evaluate(dataloader, model, device, metric_name)[-1] 55 | table = [['Accuracy', f"{acc:.2f}"]] 56 | else: 57 | mAP, mAUC, d_prime = evaluate(dataloader, model, device, metric_name) 58 | table = [ 59 | ['mAP', f"{mAP:.2f}"], 60 | ['AUC', f"{mAUC:.2f}"], 61 | ['d-prime', f"{d_prime:.2f}"] 62 | ] 63 | print(tabulate(table, numalign='right')) 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--cfg', type=str, required=True, help='Experiment configuration file name') 69 | args = parser.parse_args() 70 | 71 | with open(args.cfg) as f: 72 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 73 | 74 | pprint(cfg) 75 | setup_cudnn() 76 | main(cfg) -------------------------------------------------------------------------------- /tools/infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import torchaudio 4 | import yaml 5 | import numpy as np 6 | from torch import Tensor 7 | from tabulate import tabulate 8 | from pathlib import Path 9 | from torchaudio import transforms as T 10 | from torchaudio import functional as F 11 | 12 | import sys 13 | sys.path.insert(0, '.') 14 | from models import get_model 15 | from datasets import __all__ 16 | from utils.utils import time_sync 17 | 18 | 19 | class AudioTagging: 20 | def __init__(self, cfg) -> None: 21 | self.device = torch.device(cfg['DEVICE']) 22 | self.labels = np.array(__all__[cfg['DATASET']['NAME']].CLASSES) 23 | self.model = get_model(cfg['MODEL']['NAME'], len(self.labels)) 24 | self.model.load_state_dict(torch.load(cfg['TEST']['MODEL_PATH'], map_location='cpu')) 25 | self.model = self.model.to(self.device) 26 | self.model.eval() 27 | 28 | self.topk = cfg['TEST']['TOPK'] 29 | self.sample_rate = cfg['DATASET']['SAMPLE_RATE'] 30 | self.mel_tf = T.MelSpectrogram(self.sample_rate, cfg['DATASET']['WIN_LENGTH'], cfg['DATASET']['WIN_LENGTH'], cfg['DATASET']['HOP_LENGTH'], 
cfg['DATASET']['FMIN'], cfg['DATASET']['FMAX'], n_mels=cfg['DATASET']['N_MELS'], norm='slaney') 31 | 32 | def preprocess(self, file: str) -> Tensor: 33 | audio, sr = torchaudio.load(file) 34 | if sr != self.sample_rate: audio = F.resample(audio, sr, self.sample_rate) 35 | audio = self.mel_tf(audio) 36 | audio = 10.0 * audio.clamp_(1e-10).log10() 37 | audio = audio.unsqueeze(0) 38 | audio = audio.to(self.device) 39 | return audio 40 | 41 | def postprocess(self, prob: Tensor) -> str: 42 | probs, indices = torch.topk(prob.sigmoid().squeeze().cpu(), self.topk) 43 | return self.labels[indices], probs 44 | 45 | @torch.no_grad() 46 | def model_forward(self, audio: Tensor) -> Tensor: 47 | start = time_sync() 48 | pred = self.model(audio) 49 | end = time_sync() 50 | print(f"Model Inference Time: {(end-start)*1000:.2f}ms") 51 | return pred 52 | 53 | def predict(self, file: str) -> str: 54 | audio = self.preprocess(file) 55 | pred = self.model_forward(audio) 56 | labels, probs = self.postprocess(pred) 57 | return labels, probs 58 | 59 | 60 | if __name__ == '__main__': 61 | parser = argparse.ArgumentParser() 62 | parser.add_argument('--cfg', type=str, default='configs/audioset.yaml') 63 | args = parser.parse_args() 64 | 65 | with open(args.cfg) as f: 66 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 67 | 68 | file_path = Path(cfg['TEST']['FILE']) 69 | model = AudioTagging(cfg) 70 | 71 | if cfg['TEST']['MODE'] == 'file': 72 | if file_path.is_file(): 73 | labels, probs = model.predict(str(file_path)) 74 | print(tabulate({"Class": labels, "Confidence": probs}, headers='keys')) 75 | else: 76 | raise NotImplementedError 77 | else: 78 | raise NotImplementedError -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import time 5 | import os 6 | from pathlib import Path 7 | from torch.backends import cudnn 8 | from torch import nn, Tensor 9 | from torch.autograd import profiler 10 | from typing import Union 11 | from torch import distributed as dist 12 | 13 | 14 | def fix_seeds(seed: int = 123) -> None: 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | np.random.seed(seed) 18 | random.seed(seed) 19 | 20 | def setup_cudnn() -> None: 21 | # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html 22 | cudnn.benchmark = True 23 | cudnn.deterministic = False 24 | 25 | def time_sync() -> float: 26 | if torch.cuda.is_available(): 27 | torch.cuda.synchronize() 28 | return time.time() 29 | 30 | def get_model_size(model: Union[nn.Module, torch.jit.ScriptModule]): 31 | tmp_model_path = Path('temp.p') 32 | if isinstance(model, torch.jit.ScriptModule): 33 | torch.jit.save(model, tmp_model_path) 34 | else: 35 | torch.save(model.state_dict(), tmp_model_path) 36 | size = tmp_model_path.stat().st_size 37 | os.remove(tmp_model_path) 38 | return size / 1e6 # in MB 39 | 40 | @torch.no_grad() 41 | def test_model_latency(model: nn.Module, inputs: torch.Tensor, use_cuda: bool = False) -> float: 42 | with profiler.profile(use_cuda=use_cuda) as prof: 43 | _ = model(inputs) 44 | return prof.self_cpu_time_total / 1000 # ms 45 | 46 | def count_parameters(model: nn.Module) -> float: 47 | return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6 # in M 48 | 49 | def setup_ddp() -> int: 50 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 51 | rank = int(os.environ['RANK']) 52 | 
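# The helpers defined above (count_parameters, get_model_size, test_model_latency) give a quick
# footprint report for a model built through the factory; the numbers printed depend on the machine
# and build and are purely illustrative. Run from the repository root.
import sys
sys.path.insert(0, '.')
from models import get_model
from utils.utils import count_parameters, get_model_size

model = get_model('cnn14', 527)
print(f"{count_parameters(model):.1f}M trainable parameters, ~{get_model_size(model):.1f} MB serialized")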
world_size = int(os.environ['WORLD_SIZE']) 53 | gpu = int(os.environ['LOCAL_RANK']) 54 | torch.cuda.set_device(gpu) 55 | dist.init_process_group('nccl', init_method="env://", world_size=world_size, rank=rank) 56 | dist.barrier() 57 | else: 58 | gpu = 0 59 | return gpu 60 | 61 | def cleanup_ddp(): 62 | if dist.is_initialized(): 63 | dist.destroy_process_group() 64 | 65 | def reduce_tensor(tensor: Tensor) -> Tensor: 66 | rt = tensor.clone() 67 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 68 | rt /= dist.get_world_size() 69 | return rt 70 | 71 | @torch.no_grad() 72 | def throughput(dataloader, model: nn.Module, times: int = 30): 73 | model.eval() 74 | images, _ = next(iter(dataloader)) 75 | images = images.cuda(non_blocking=True) 76 | B = images.shape[0] 77 | print(f"Throughput averaged with {times} times") 78 | start = time_sync() 79 | for _ in range(times): 80 | model(images) 81 | end = time_sync() 82 | 83 | print(f"Batch Size {B} throughput {times * B / (end - start)} images/s") 84 | 85 | 86 | def get_dataset_norm(dataloader): 87 | mean, std = 0.0, 0.0 88 | 89 | for audio, _ in dataloader: 90 | mean += audio.mean().item() * audio.shape[0] 91 | std += audio.std().item() * audio.shape[0] 92 | return mean / len(dataloader.dataset), std / len(dataloader.dataset) -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from torch import Tensor 6 | from torchaudio import functional as AF 7 | 8 | 9 | def plot_waveform(waveform: Tensor, sample_rate: int): 10 | waveform = waveform.numpy() 11 | num_channels, num_frames = waveform.shape 12 | time_axis = torch.arange(0, num_frames) / sample_rate 13 | 14 | fig, axes = plt.subplots(num_channels, 1) 15 | if num_channels == 1: axes = [axes] 16 | 17 | for c in range(num_channels): 18 | axes[c].plot(time_axis, waveform[c], linewidth=1) 19 | axes[c].grid(True) 20 | if num_channels > 1: axes[c].set_ylabel(f"Channel {c+1}") 21 | 22 | fig.suptitle("Waveform") 23 | plt.show(block=False) 24 | 25 | def plot_specgram(waveform: Tensor, sample_rate: int): 26 | waveform = waveform.numpy() 27 | num_channels, _ = waveform.shape 28 | 29 | fig, axes = plt.subplots(num_channels, 1) 30 | if num_channels == 1: axes = [axes] 31 | 32 | for c in range(num_channels): 33 | axes[c].specgram(waveform[c], Fs=sample_rate) 34 | if num_channels > 1: axes[c].set_ylabel(f"Channel {c+1}") 35 | 36 | fig.suptitle("Spectrogram") 37 | plt.show(block=False) 38 | 39 | def plot_spectrogram(spec: Tensor): 40 | fig, ax = plt.subplots(1, 1) 41 | ax.set_title("Spectrogram (db)") 42 | ax.set_ylabel('freq_bin') 43 | ax.set_xlabel('frame') 44 | im = ax.imshow(AF.amplitude_to_DB(spec[0], 10, 1e-10, np.log10(max(spec.max(), 1e-10))).numpy(), origin='lower', aspect='auto') 45 | fig.colorbar(im, ax=ax) 46 | plt.show(block=False) 47 | 48 | def plot_mel_fbank(fbank: Tensor): 49 | plt.imshow(fbank.numpy(), aspect='auto') 50 | plt.xlabel('mel_bin') 51 | plt.ylabel('freq_bin') 52 | plt.title('Filter bank') 53 | plt.show(block=False) 54 | 55 | def plot_pitch(waveform: Tensor, sample_rate: int, pitch: Tensor): 56 | num_channels, num_frames = waveform.shape 57 | pitch_channels, pitch_frames = pitch.shape 58 | time_axis = torch.linspace(0, num_frames / sample_rate, num_frames) 59 | pitch_axis = torch.linspace(0, num_frames / sample_rate, pitch_frames) 60 | 61 | plt.plot(time_axis,
num_channels, linewidth=1, color='gray', alpha=0.3, label='Waveform') 62 | plt.plot(pitch_axis, pitch_channels, linewidth=2, color='green', label='Pitch') 63 | plt.title("Pitch Feature") 64 | plt.grid(True) 65 | plt.legend(loc=0) 66 | plt.show(block=False) 67 | 68 | 69 | def play_audio(waveform: Tensor, sample_rate: int): 70 | from IPython.display import Audio, display 71 | waveform = waveform.numpy() 72 | num_channels, _ = waveform.shape 73 | 74 | if num_channels == 1: 75 | display(Audio(waveform[0], rate=sample_rate)) 76 | elif num_channels == 2: 77 | display(Audio((waveform[0], waveform[1]), rate=sample_rate)) 78 | else: 79 | raise ValueError("Waveform with more than 2 channels are not supported.") 80 | 81 | def plot_sound_events(results: Tensor, labels: list, fps: int = 100): 82 | fig, ax = plt.subplots(1, 1, figsize=(10, 4)) 83 | ax.matshow(results, origin='upper', aspect='auto', cmap='jet', vmin=0, vmax=1) 84 | ax.set_title('Sound Event Detection') 85 | ax.set_xlabel('Seconds') 86 | ax.set_xticks(np.arange(0, results.shape[1], fps)) 87 | ax.set_xticklabels(np.arange(0, results.shape[1]/fps)) 88 | ax.set_yticks(np.arange(0, results.shape[0])) 89 | ax.set_yticklabels(labels) 90 | ax.yaxis.grid(color='k', linestyle='solid', linewidth=0.3, alpha=0.3) 91 | ax.xaxis.set_ticks_position('bottom') 92 | 93 | plt.tight_layout() 94 | plt.show() 95 | # plt.savefig('result.jpg') -------------------------------------------------------------------------------- /tools/sed_infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import torchaudio 4 | import yaml 5 | import numpy as np 6 | from torch import Tensor 7 | from tabulate import tabulate 8 | from pathlib import Path 9 | from torchaudio import transforms as T 10 | from torchaudio import functional as F 11 | 12 | import sys 13 | sys.path.insert(0, '.') 14 | from models import get_model 15 | from datasets import __all__ 16 | from utils.utils import time_sync 17 | from utils.visualize import plot_sound_events 18 | 19 | 20 | class SED: 21 | def __init__(self, cfg) -> None: 22 | self.device = torch.device(cfg['DEVICE']) 23 | self.labels = np.array(__all__[cfg['DATASET']['NAME']].CLASSES) 24 | self.model = get_model(cfg['MODEL']['NAME'], len(self.labels)) 25 | self.model.load_state_dict(torch.load(cfg['TEST']['MODEL_PATH'], map_location='cpu')) 26 | self.model = self.model.to(self.device) 27 | self.model.eval() 28 | 29 | self.threshold = cfg['TEST']['THRESHOLD'] 30 | self.sample_rate = cfg['DATASET']['SAMPLE_RATE'] 31 | self.mel_tf = T.MelSpectrogram(self.sample_rate, cfg['DATASET']['WIN_LENGTH'], cfg['DATASET']['WIN_LENGTH'], cfg['DATASET']['HOP_LENGTH'], cfg['DATASET']['FMIN'], cfg['DATASET']['FMAX'], n_mels=cfg['DATASET']['N_MELS'], norm='slaney') 32 | 33 | def preprocess(self, file: str) -> Tensor: 34 | audio, sr = torchaudio.load(file) 35 | if sr != self.sample_rate: audio = F.resample(audio, sr, self.sample_rate) 36 | audio = self.mel_tf(audio) 37 | audio = 10.0 * audio.clamp_(1e-10).log10() 38 | audio = audio.unsqueeze(0) 39 | audio = audio.to(self.device) 40 | return audio 41 | 42 | def postprocess(self, prob: Tensor) -> str: 43 | probs, indices = torch.sort(prob.max(dim=0)[0], descending=True) 44 | indices = indices[probs > self.threshold] 45 | # probs, indices = torch.topk(prob.max(dim=0)[0], topk) 46 | top_results = prob[:, indices].t() 47 | 48 | starts = [] 49 | ends = [] 50 | 51 | for result in top_results: 52 | index = torch.where(result > self.threshold)[0] 53 
| starts.append(round(index[0].item()/100, 1)) 54 | ends.append(round(index[-1].item()/100, 1)) 55 | 56 | return top_results, self.labels[indices], starts, ends 57 | 58 | 59 | @torch.no_grad() 60 | def model_forward(self, audio: Tensor) -> Tensor: 61 | start = time_sync() 62 | pred = self.model(audio)[0].squeeze().cpu() 63 | end = time_sync() 64 | print(f"Model Inference Time: {(end-start)*1000:.2f}ms") 65 | return pred 66 | 67 | def predict(self, file: str) -> str: 68 | audio = self.preprocess(file) 69 | pred = self.model_forward(audio) 70 | results, labels, starts, ends = self.postprocess(pred) 71 | return results, labels, starts, ends 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument('--cfg', type=str, default='configs/audioset_sed.yaml') 77 | args = parser.parse_args() 78 | 79 | with open(args.cfg) as f: 80 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 81 | 82 | file_path = Path(cfg['TEST']['FILE']) 83 | model = SED(cfg) 84 | 85 | if cfg['TEST']['MODE'] == 'file': 86 | if file_path.is_file(): 87 | results, labels, starts, ends = model.predict(str(file_path)) 88 | print(tabulate({"Class": labels, "Start": starts, "End": ends}, headers='keys')) 89 | if cfg['TEST']['PLOT']: 90 | plot_sound_events(results, labels) 91 | else: 92 | raise NotImplementedError 93 | else: 94 | raise NotImplementedError -------------------------------------------------------------------------------- /datasets/urbansound.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import torchaudio 4 | from pathlib import Path 5 | from torch import Tensor 6 | from torchaudio import transforms as T 7 | from torchaudio import functional as AF 8 | from torch.nn import functional as F 9 | from torch.utils.data import Dataset, DataLoader 10 | from typing import Tuple 11 | from .transforms import mixup_augment 12 | 13 | 14 | class UrbanSound8k(Dataset): 15 | CLASSES = ['air conditioner', 'car horn', 'children playing', 'dog bark', 'drilling', 'engine idling', 'gun shot', 'jackhammer', 'siren', 'street music'] 16 | 17 | def __init__(self, split, data_cfg, mixup_cfg=None, transform=None, spec_transform=None) -> None: 18 | super().__init__() 19 | assert split in ['train', 'val'] 20 | self.num_classes = len(self.CLASSES) 21 | self.transform = transform 22 | self.spec_transform = spec_transform 23 | self.mixup = mixup_cfg['MIXUP'] if mixup_cfg is not None else 0.0 24 | self.mixup_alpha = mixup_cfg['MIXUP_ALPHA'] if mixup_cfg is not None else 0.0 25 | self.label_smooth = mixup_cfg['SMOOTHING'] if mixup_cfg is not None else 0.0 26 | self.sample_rate = data_cfg['SAMPLE_RATE'] 27 | self.num_frames = self.sample_rate * data_cfg['AUDIO_LENGTH'] 28 | 29 | val_fold = 10 30 | 31 | self.mel_tf = T.MelSpectrogram(self.sample_rate, data_cfg['WIN_LENGTH'], data_cfg['WIN_LENGTH'], data_cfg['HOP_LENGTH'], data_cfg['FMIN'], data_cfg['FMAX'], n_mels=data_cfg['N_MELS'], norm='slaney') # using mel_scale='slaney' is better 32 | self.data, self.targets = self.get_data(data_cfg['ROOT'], split, val_fold) 33 | print(f"Found {len(self.data)} {split} audios in {data_cfg['ROOT']}.") 34 | 35 | 36 | def get_data(self, root: str, split: int, fold: int): 37 | root = Path(root) 38 | files = (root / 'audio').rglob('*.wav') 39 | files = list(filter(lambda x: not str(x.parent).endswith(f"fold{fold}") if split == 'train' else str(x.parent).endswith(f"fold{fold}"), files)) 40 | targets = list(map(lambda x: int(x.stem.split('-', maxsplit=2)[1]), 
files)) 41 | assert len(files) == len(targets) 42 | return files, targets 43 | 44 | 45 | def __len__(self) -> int: 46 | return len(self.data) 47 | 48 | 49 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 50 | audio, sr = torchaudio.load(self.data[index]) 51 | if audio.shape[0] != 1: audio = audio[:1] # reduce to mono 52 | audio = AF.resample(audio, sr, self.sample_rate) # resample the audio 53 | if audio.shape[1] < self.num_frames: audio = torch.cat([audio, torch.zeros(1, self.num_frames-audio.shape[1])], dim=-1) 54 | target = torch.tensor(self.targets[index]) 55 | 56 | if self.transform: audio = self.transform(audio) 57 | 58 | if random.random() < self.mixup: 59 | next_index = random.randint(0, len(self.data)-1) 60 | next_audio, sr = torchaudio.load(self.data[next_index]) 61 | if next_audio.shape[0] != 1: next_audio = next_audio[:1] # reduce to mono 62 | next_audio = AF.resample(next_audio, sr, self.sample_rate) 63 | if next_audio.shape[1] < self.num_frames: next_audio = torch.cat([next_audio, torch.zeros(1, self.num_frames-next_audio.shape[1])], dim=-1) 64 | next_target = torch.tensor(self.targets[next_index]) 65 | audio, target = mixup_augment(audio, target, next_audio, next_target, self.mixup_alpha, self.num_classes, self.label_smooth) 66 | else: 67 | target = F.one_hot(target, self.num_classes).float() 68 | 69 | audio = self.mel_tf(audio) # convert to mel spectrogram 70 | audio = 10.0 * audio.clamp_(1e-10).log10() # convert to log mel spectrogram 71 | 72 | if self.spec_transform: audio = self.spec_transform(audio) 73 | 74 | return audio, target 75 | 76 | 77 | if __name__ == '__main__': 78 | data_cfg = { 79 | 'ROOT': 'C:/Users/sithu/Documents/Datasets/Urbansound8K', 80 | 'SAMPLE_RATE': 32000, 81 | 'AUDIO_LENGTH': 4, 82 | 'WIN_LENGTH': 1024, 83 | 'HOP_LENGTH': 320, 84 | 'N_MELS': 64, 85 | 'FMIN': 50, 86 | 'FMAX': 14000 87 | } 88 | aug_cfg = { 89 | 'MIXUP': 0.5, 90 | 'MIXUP_ALPHA': 10, 91 | 'SMOOTHING': 0.1 92 | } 93 | dataset = UrbanSound8k('train', data_cfg, aug_cfg) 94 | dataloader = DataLoader(dataset, 2, True) 95 | for audio, target in dataloader: 96 | print(audio.shape, target.argmax(dim=1)) 97 | print(audio.min(), audio.max()) 98 | break -------------------------------------------------------------------------------- /datasets/esc50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import torchaudio 4 | from pathlib import Path 5 | from torch import Tensor 6 | from torchaudio import transforms as T 7 | from torch.nn import functional as F 8 | from torch.utils.data import Dataset, DataLoader 9 | from typing import Tuple 10 | from .transforms import mixup_augment 11 | 12 | 13 | class ESC50(Dataset): 14 | """ 15 | 50 classes 16 | 40 examples per class 17 | 2000 recordings 18 | 5 major categories: [Animals, Nature sounds, Human non-speech sounds, Interior/domestic sounds, Exterior/urban sounds] 19 | Each of the audio is named like this: 20 | {FOLD}-{CLIP_ID}-{TAKE}-{TARGET}.wav 21 | """ 22 | CLASSES = ['dog', 'rooster', 'pig', 'cow', 'frog', 'cat', 'hen', 'insects', 'sheep', 'crow', 'rain', 'sea_waves', 'crackling_fire', 'crickets', 'chirping_birds', 'water_drops', 'wind', 'pouring_water', 'toilet_flush', 'thunderstorm', 'crying_baby', 'sneezing', 'clapping', 'breathing', 'coughing', 'footsteps', 'laughing', 'brushing_teeth', 23 | 'snoring', 'drinking_sipping', 'door_wood_knock', 'mouse_click', 'keyboard_typing', 'door_wood_creaks', 'can_opening', 'washing_machine', 'vacuum_cleaner', 'clock_alarm', 
'clock_tick', 'glass_breaking', 'helicopter', 'chainsaw', 'siren', 'car_horn', 'engine', 'train', 'church_bells', 'airplane', 'fireworks', 'hand_saw'] 24 | 25 | def __init__(self, split, data_cfg, mixup_cfg=None, transform=None, spec_transform=None) -> None: 26 | super().__init__() 27 | assert split in ['train', 'val'] 28 | self.num_classes = len(self.CLASSES) 29 | self.transform = transform 30 | self.spec_transform = spec_transform 31 | self.mixup = mixup_cfg['MIXUP'] if mixup_cfg is not None else 0.0 32 | self.mixup_alpha = mixup_cfg['MIXUP_ALPHA'] if mixup_cfg is not None else 0.0 33 | self.label_smooth = mixup_cfg['SMOOTHING'] if mixup_cfg is not None else 0.0 34 | val_fold = 5 35 | 36 | self.mel_tf = T.MelSpectrogram(data_cfg['SAMPLE_RATE'], data_cfg['WIN_LENGTH'], data_cfg['WIN_LENGTH'], data_cfg['HOP_LENGTH'], data_cfg['FMIN'], data_cfg['FMAX'], n_mels=data_cfg['N_MELS'], norm='slaney') # using mel_scale='slaney' is better 37 | self.resample = T.Resample(data_cfg['SOURCE_SAMPLE'], data_cfg['SAMPLE_RATE']) 38 | 39 | self.data, self.targets = self.get_data(data_cfg['ROOT'], split, val_fold) 40 | print(f"Found {len(self.data)} {split} audios in {data_cfg['ROOT']}.") 41 | 42 | 43 | def get_data(self, root: str, split: str, fold: int): 44 | root = Path(root) 45 | files = (root / 'audio').glob('*.wav') 46 | files = list(filter(lambda x: not x.stem.startswith(f"{fold}") if split == 'train' else x.stem.startswith(f"{fold}"), files)) 47 | targets = list(map(lambda x: int(x.stem.rsplit('-', maxsplit=1)[-1]), files)) 48 | assert len(files) == len(targets) 49 | return files, targets 50 | 51 | 52 | def __len__(self) -> int: 53 | return len(self.data) 54 | 55 | 56 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 57 | audio, _ = torchaudio.load(self.data[index]) 58 | audio = self.resample(audio) 59 | target = torch.tensor(self.targets[index]) 60 | 61 | if self.transform: audio = self.transform(audio) 62 | 63 | if random.random() < self.mixup: 64 | next_index = random.randint(0, len(self.data)-1) 65 | next_audio, _ = torchaudio.load(self.data[next_index]) 66 | next_audio = self.resample(next_audio) 67 | next_target = torch.tensor(self.targets[next_index]) 68 | audio, target = mixup_augment(audio, target, next_audio, next_target, self.mixup_alpha, self.num_classes, self.label_smooth) 69 | else: 70 | target = F.one_hot(target, self.num_classes).float() 71 | 72 | audio = self.mel_tf(audio) 73 | audio = 10.0 * audio.clamp_(1e-10).log10() 74 | 75 | if self.spec_transform: audio = self.spec_transform(audio) 76 | 77 | return audio, target 78 | 79 | 80 | if __name__ == '__main__': 81 | data_cfg = { 82 | 'ROOT': 'C:/Users/sithu/Documents/Datasets/ESC50', 83 | 'SOURCE_SAMPLE': 44100, 84 | 'SAMPLE_RATE': 32000, 85 | 'AUDIO_LENGTH': 5, 86 | 'WIN_LENGTH': 1024, 87 | 'HOP_LENGTH': 320, 88 | 'N_MELS': 64, 89 | 'FMIN': 50, 90 | 'FMAX': 14000 91 | } 92 | aug_cfg = { 93 | 'MIXUP': 0.5, 94 | 'MIXUP_ALPHA': 10, 95 | 'SMOOTHING': 0.1 96 | } 97 | dataset = ESC50('val', data_cfg, aug_cfg) 98 | dataloader = DataLoader(dataset, 2, True) 99 | for audio, target in dataloader: 100 | print(audio.shape, target.argmax(dim=1)) 101 | print(audio.min(), audio.max()) 102 | break 103 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Repo-specific GitIgnore ---------------------------------------------------------------------------------------------- 2 | *.jpg 3 | *.jpeg 4 | *.png 5 | *.bmp 
6 | *.tif 7 | *.tiff 8 | *.heic 9 | *.JPG 10 | *.JPEG 11 | *.PNG 12 | *.BMP 13 | *.TIF 14 | *.TIFF 15 | *.HEIC 16 | *.mp4 17 | *.mov 18 | *.MOV 19 | *.avi 20 | *.data 21 | *.json 22 | 23 | *.cfg 24 | !cfg/yolov3*.cfg 25 | 26 | storage.googleapis.com 27 | test_imgs/ 28 | runs/* 29 | data/* 30 | !data/images/zidane.jpg 31 | !data/images/bus.jpg 32 | !data/coco.names 33 | !data/coco_paper.names 34 | !data/coco.data 35 | !data/coco_*.data 36 | !data/coco_*.txt 37 | !data/trainvalno5k.shapes 38 | !data/*.sh 39 | 40 | pycocotools/* 41 | results*.txt 42 | gcp_test*.sh 43 | 44 | checkpoints/ 45 | output/ 46 | 47 | # Datasets ------------------------------------------------------------------------------------------------------------- 48 | coco/ 49 | coco128/ 50 | VOC/ 51 | 52 | # MATLAB GitIgnore ----------------------------------------------------------------------------------------------------- 53 | *.m~ 54 | *.mat 55 | !targets*.mat 56 | 57 | # Neural Network weights ----------------------------------------------------------------------------------------------- 58 | *.weights 59 | *.pt 60 | *.onnx 61 | *.mlmodel 62 | *.torchscript 63 | darknet53.conv.74 64 | yolov3-tiny.conv.15 65 | 66 | # GitHub Python GitIgnore ---------------------------------------------------------------------------------------------- 67 | # Byte-compiled / optimized / DLL files 68 | __pycache__/ 69 | *.py[cod] 70 | *$py.class 71 | 72 | # C extensions 73 | *.so 74 | 75 | # Distribution / packaging 76 | .Python 77 | env/ 78 | build/ 79 | develop-eggs/ 80 | dist/ 81 | downloads/ 82 | eggs/ 83 | .eggs/ 84 | lib/ 85 | lib64/ 86 | parts/ 87 | sdist/ 88 | var/ 89 | wheels/ 90 | *.egg-info/ 91 | wandb/ 92 | .installed.cfg 93 | *.egg 94 | 95 | 96 | # PyInstaller 97 | # Usually these files are written by a python script from a template 98 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 99 | *.manifest 100 | *.spec 101 | 102 | # Installer logs 103 | pip-log.txt 104 | pip-delete-this-directory.txt 105 | 106 | # Unit test / coverage reports 107 | htmlcov/ 108 | .tox/ 109 | .coverage 110 | .coverage.* 111 | .cache 112 | nosetests.xml 113 | coverage.xml 114 | *.cover 115 | .hypothesis/ 116 | 117 | # Translations 118 | *.mo 119 | *.pot 120 | 121 | # Django stuff: 122 | *.log 123 | local_settings.py 124 | tools/flask_deploy.py 125 | # Flask stuff: 126 | instance/ 127 | .webassets-cache 128 | 129 | # Scrapy stuff: 130 | .scrapy 131 | 132 | # Sphinx documentation 133 | docs/_build/ 134 | 135 | # PyBuilder 136 | target/ 137 | 138 | # Jupyter Notebook 139 | .ipynb_checkpoints 140 | 141 | # pyenv 142 | .python-version 143 | 144 | # celery beat schedule file 145 | celerybeat-schedule 146 | 147 | # SageMath parsed files 148 | *.sage.py 149 | 150 | # dotenv 151 | .env 152 | 153 | # virtualenv 154 | .venv* 155 | venv*/ 156 | ENV*/ 157 | 158 | # Spyder project settings 159 | .spyderproject 160 | .spyproject 161 | 162 | # Rope project settings 163 | .ropeproject 164 | 165 | # mkdocs documentation 166 | /site 167 | 168 | # mypy 169 | .mypy_cache/ 170 | 171 | 172 | # https://github.com/github/gitignore/blob/master/Global/macOS.gitignore ----------------------------------------------- 173 | 174 | # General 175 | .DS_Store 176 | .AppleDouble 177 | .LSOverride 178 | 179 | # Icon must end with two \r 180 | Icon 181 | Icon? 
182 | 183 | # Thumbnails 184 | ._* 185 | 186 | # Files that might appear in the root of a volume 187 | .DocumentRevisions-V100 188 | .fseventsd 189 | .Spotlight-V100 190 | .TemporaryItems 191 | .Trashes 192 | .VolumeIcon.icns 193 | .com.apple.timemachine.donotpresent 194 | 195 | # Directories potentially created on remote AFP share 196 | .AppleDB 197 | .AppleDesktop 198 | Network Trash Folder 199 | Temporary Items 200 | .apdisk 201 | 202 | 203 | # https://github.com/github/gitignore/blob/master/Global/JetBrains.gitignore 204 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 205 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 206 | 207 | # User-specific stuff: 208 | .idea/* 209 | .idea/**/workspace.xml 210 | .idea/**/tasks.xml 211 | .idea/dictionaries 212 | .html # Bokeh Plots 213 | .pg # TensorFlow Frozen Graphs 214 | .avi # videos 215 | 216 | # Sensitive or high-churn files: 217 | .idea/**/dataSources/ 218 | .idea/**/dataSources.ids 219 | .idea/**/dataSources.local.xml 220 | .idea/**/sqlDataSources.xml 221 | .idea/**/dynamic.xml 222 | .idea/**/uiDesigner.xml 223 | 224 | # Gradle: 225 | .idea/**/gradle.xml 226 | .idea/**/libraries 227 | 228 | # CMake 229 | cmake-build-debug/ 230 | cmake-build-release/ 231 | 232 | # Mongo Explorer plugin: 233 | .idea/**/mongoSettings.xml 234 | 235 | ## File-based project format: 236 | *.iws 237 | 238 | ## Plugin-specific files: 239 | 240 | # IntelliJ 241 | out/ 242 | 243 | # mpeltonen/sbt-idea plugin 244 | .idea_modules/ 245 | 246 | # JIRA plugin 247 | atlassian-ide-plugin.xml 248 | 249 | # Cursive Clojure plugin 250 | .idea/replstate.xml 251 | 252 | # Crashlytics plugin (for Android Studio and IntelliJ) 253 | com_crashlytics_export_strings.xml 254 | crashlytics.properties 255 | crashlytics-build.properties 256 | fabric.properties 257 | -------------------------------------------------------------------------------- /datasets/fsdkaggle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import torchaudio 4 | from pathlib import Path 5 | from torch import Tensor 6 | from torchaudio import transforms as T 7 | from torch.nn import functional as F 8 | from torch.utils.data import Dataset, DataLoader 9 | from typing import Tuple 10 | from .transforms import mixup_augment 11 | 12 | 13 | class FSDKaggle2018(Dataset): 14 | CLASSES = ["Acoustic_guitar", "Applause", "Bark", "Bass_drum", "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet", "Computer_keyboard", "Cough", "Cowbell", "Double_bass", "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping", "Fireworks", "Flute", 15 | "Glockenspiel", "Gong", "Gunshot_or_gunfire", "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow", "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter", "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone", "Trumpet", "Violin_or_fiddle", "Writing"] 16 | 17 | def __init__(self, split, data_cfg, mixup_cfg=None, transform=None, spec_transform=None) -> None: 18 | super().__init__() 19 | assert split in ['train', 'val'] 20 | split = 'train' if split == 'train' else 'test' 21 | self.num_classes = len(self.CLASSES) 22 | self.transform = transform 23 | self.spec_transform = spec_transform 24 | self.mixup = mixup_cfg['MIXUP'] if mixup_cfg is not None else 0.0 25 | self.mixup_alpha = mixup_cfg['MIXUP_ALPHA'] if mixup_cfg is not None else 0.0 26 | self.label_smooth = mixup_cfg['SMOOTHING'] 
if mixup_cfg is not None else 0.0 27 | self.num_frames = data_cfg['SAMPLE_RATE'] * data_cfg['AUDIO_LENGTH'] 28 | 29 | self.mel_tf = T.MelSpectrogram(data_cfg['SAMPLE_RATE'], data_cfg['WIN_LENGTH'], data_cfg['WIN_LENGTH'], data_cfg['HOP_LENGTH'], data_cfg['FMIN'], data_cfg['FMAX'], n_mels=data_cfg['N_MELS'], norm='slaney') # using mel_scale='slaney' is better 30 | self.resample = T.Resample(data_cfg['SOURCE_SAMPLE'], data_cfg['SAMPLE_RATE']) 31 | 32 | self.data, self.targets = self.get_data(data_cfg['ROOT'], split) 33 | print(f"Found {len(self.data)} {split} audios in {data_cfg['ROOT']}.") 34 | 35 | 36 | def get_data(self, root: str, split: str): 37 | root = Path(root) 38 | csv_path = 'train_post_competition.csv' if split == 'train' else 'test_post_competition_scoring_clips.csv' 39 | files, targets = [], [] 40 | 41 | with open(root / 'FSDKaggle2018.meta' / csv_path) as f: 42 | lines = f.read().splitlines()[1:] 43 | for line in lines: 44 | fname, label, *_ = line.split(',') 45 | files.append(str(root / f"audio_{split}" / fname)) 46 | targets.append(int(self.CLASSES.index(label))) 47 | return files, targets 48 | 49 | 50 | def __len__(self) -> int: 51 | return len(self.data) 52 | 53 | 54 | def cut_pad(self, audio: Tensor) -> Tensor: 55 | if audio.shape[1] < self.num_frames: # if less than 5s, pad the audio 56 | audio = torch.cat([audio, torch.zeros(1, self.num_frames-audio.shape[1])], dim=-1) 57 | else: # if not, trim the audio to 5s 58 | audio = audio[:, :self.num_frames] 59 | return audio 60 | 61 | 62 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 63 | audio, _ = torchaudio.load(self.data[index]) 64 | audio = self.resample(audio) # resample to 32kHz 65 | audio = self.cut_pad(audio) 66 | target = torch.tensor(self.targets[index]) 67 | 68 | if self.transform: audio = self.transform(audio) 69 | 70 | if random.random() < self.mixup: 71 | next_index = random.randint(0, len(self.data)-1) 72 | next_audio, _ = torchaudio.load(self.data[next_index]) 73 | next_audio = self.resample(next_audio) 74 | next_audio = self.cut_pad(next_audio) 75 | next_target = torch.tensor(self.targets[next_index]) 76 | audio, target = mixup_augment(audio, target, next_audio, next_target, self.mixup_alpha, self.num_classes, self.label_smooth) 77 | else: 78 | target = F.one_hot(target, self.num_classes).float() 79 | 80 | audio = self.mel_tf(audio) # convert to mel spectrogram 81 | audio = 10.0 * audio.clamp_(1e-10).log10() # convert to log mel spectrogram 82 | 83 | if self.spec_transform: audio = self.spec_transform(audio) 84 | 85 | return audio, target 86 | 87 | 88 | if __name__ == '__main__': 89 | data_cfg = { 90 | 'ROOT': 'C:/Users/sithu/Documents/Datasets/FSDKaggle2018', 91 | 'SOURCE_SAMPLE': 44100, 92 | 'SAMPLE_RATE': 32000, 93 | 'AUDIO_LENGTH': 5, 94 | 'WIN_LENGTH': 1024, 95 | 'HOP_LENGTH': 320, 96 | 'N_MELS': 64, 97 | 'FMIN': 50, 98 | 'FMAX': 14000 99 | } 100 | aug_cfg = { 101 | 'MIXUP': 0.5, 102 | 'MIXUP_ALPHA': 10, 103 | 'SMOOTHING': 0.1 104 | } 105 | dataset = FSDKaggle2018('val', data_cfg, aug_cfg) 106 | dataloader = DataLoader(dataset, 2, True) 107 | for audio, target in dataloader: 108 | print(audio.shape, target.argmax(dim=1)) 109 | print(audio.min(), audio.max()) 110 | break 111 | 112 | -------------------------------------------------------------------------------- /datasets/speechcommands.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import torchaudio 4 | import os 5 | from pathlib import Path 6 | from torch import 
Tensor 7 | from torchaudio import transforms as T 8 | from torch.nn import functional as F 9 | from torch.utils.data import Dataset, DataLoader 10 | from typing import Tuple 11 | from .transforms import mixup_augment 12 | 13 | 14 | class SpeechCommandsv1(Dataset): 15 | CLASSES = ['bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'four', 'go', 'happy', 'house', 'left', 'marvin', 'nine', 'no', 16 | 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 'up', 'wow', 'yes', 'zero'] 17 | 18 | def __init__(self, split, data_cfg, mixup_cfg=None, transform=None, spec_transform=None) -> None: 19 | super().__init__() 20 | assert split in ['train', 'val', 'test'] 21 | self.num_classes = len(self.CLASSES) 22 | self.transform = transform 23 | self.spec_transform = spec_transform 24 | self.mixup = mixup_cfg['MIXUP'] if mixup_cfg is not None else 0.0 25 | self.mixup_alpha = mixup_cfg['MIXUP_ALPHA'] if mixup_cfg is not None else 0.0 26 | self.label_smooth = mixup_cfg['SMOOTHING'] if mixup_cfg is not None else 0.0 27 | self.num_frames = data_cfg['SAMPLE_RATE'] * data_cfg['AUDIO_LENGTH'] 28 | 29 | self.mel_tf = T.MelSpectrogram(data_cfg['SAMPLE_RATE'], data_cfg['WIN_LENGTH'], data_cfg['WIN_LENGTH'], data_cfg['HOP_LENGTH'], data_cfg['FMIN'], data_cfg['FMAX'], n_mels=data_cfg['N_MELS'], norm='slaney') # using mel_scale='slaney' is better 30 | self.resample = T.Resample(data_cfg['SOURCE_SAMPLE'], data_cfg['SAMPLE_RATE']) 31 | 32 | self.data, self.targets = self.get_data(data_cfg['ROOT'], split) 33 | print(f"Found {len(self.data)} {split} audios in {data_cfg['ROOT']}.") 34 | 35 | 36 | def get_data(self, root: str, split: int): 37 | root = Path(root) 38 | if split == 'train': 39 | files = root.rglob('*.wav') 40 | excludes = [] 41 | with open(root / 'testing_list.txt') as f1, open(root / 'validation_list.txt') as f2: 42 | excludes += f1.read().splitlines() 43 | excludes += f2.read().splitlines() 44 | 45 | excludes = list(map(lambda x: str(root / x), excludes)) 46 | files = list(filter(lambda x: "_background_noise_" not in str(x) and str(x) not in excludes, files)) 47 | else: 48 | split = 'testing' if split == 'test' else 'validation' 49 | with open(root / f'{split}_list.txt') as f: 50 | files = f.read().splitlines() 51 | 52 | files = list(map(lambda x: root / x, files)) 53 | 54 | targets = list(map(lambda x: int(self.CLASSES.index(str(x.parent).rsplit(os.path.sep, maxsplit=1)[-1])), files)) 55 | assert len(files) == len(targets) 56 | return files, targets 57 | 58 | 59 | def __len__(self) -> int: 60 | return len(self.data) 61 | 62 | 63 | def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: 64 | audio, _ = torchaudio.load(self.data[index]) 65 | audio = self.resample(audio) 66 | if audio.shape[1] < self.num_frames: audio = torch.cat([audio, torch.zeros(1, self.num_frames-audio.shape[1])], dim=-1) # if less than 1s, pad the audio 67 | target = torch.tensor(self.targets[index]) 68 | 69 | if self.transform: audio = self.transform(audio) 70 | 71 | if random.random() < self.mixup: 72 | next_index = random.randint(0, len(self.data)-1) 73 | next_audio, _ = torchaudio.load(self.data[next_index]) 74 | next_audio = self.resample(next_audio) 75 | if next_audio.shape[1] < self.num_frames: next_audio = torch.cat([next_audio, torch.zeros(1, self.num_frames-next_audio.shape[1])], dim=-1) # if less than 1s, pad the audio 76 | next_target = torch.tensor(self.targets[next_index]) 77 | audio, target = mixup_augment(audio, target, next_audio, next_target, self.mixup_alpha, self.num_classes, 
self.label_smooth) 78 | else: 79 | target = F.one_hot(target, self.num_classes).float() 80 | 81 | audio = self.mel_tf(audio) # convert to mel spectrogram 82 | audio = 10.0 * audio.clamp_(1e-10).log10() # convert to log mel spectrogram 83 | 84 | if self.spec_transform: audio = self.spec_transform(audio) 85 | 86 | return audio, target 87 | 88 | 89 | class SpeechCommandsv2(SpeechCommandsv1): 90 | CLASSES = ['backward', 'bed', 'bird', 'cat', 'dog', 'down', 'eight', 'five', 'follow', 'forward', 'four', 'go', 'happy', 'house', 'learn', 'left', 91 | 'marvin', 'nine', 'no', 'off', 'on', 'one', 'right', 'seven', 'sheila', 'six', 'stop', 'three', 'tree', 'two', 'up', 'visual', 'wow', 'yes', 'zero'] 92 | 93 | 94 | if __name__ == '__main__': 95 | data_cfg = { 96 | 'ROOT': 'C:/Users/sithu/Documents/Datasets/SpeechCommands', 97 | 'SOURCE_SAMPLE': 16000, 98 | 'SAMPLE_RATE': 32000, 99 | 'AUDIO_LENGTH': 1, 100 | 'WIN_LENGTH': 1024, 101 | 'HOP_LENGTH': 320, 102 | 'N_MELS': 64, 103 | 'FMIN': 50, 104 | 'FMAX': 14000 105 | } 106 | aug_cfg = { 107 | 'MIXUP': 0.5, 108 | 'MIXUP_ALPHA': 10, 109 | 'SMOOTHING': 0.1 110 | } 111 | dataset = SpeechCommandsv1('train', data_cfg, aug_cfg) 112 | dataloader = DataLoader(dataset, 2, True) 113 | for audio, target in dataloader: 114 | print(audio.shape, target.argmax(dim=1)) 115 | print(audio.min(), audio.max()) 116 | break -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import yaml 4 | import time 5 | import multiprocessing as mp 6 | from pprint import pprint 7 | from tqdm import tqdm 8 | from tabulate import tabulate 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from pathlib import Path 11 | from torch.utils.data import DataLoader 12 | from torch.cuda.amp import GradScaler, autocast 13 | from torch.utils.tensorboard import SummaryWriter 14 | 15 | import sys 16 | sys.path.insert(0, '.') 17 | from datasets import get_train_dataset, get_val_dataset, get_sampler 18 | from datasets.transforms import get_waveform_transforms, get_spec_transforms 19 | from models import get_model 20 | from utils.utils import fix_seeds, time_sync, setup_cudnn, setup_ddp, cleanup_ddp 21 | from utils.schedulers import get_scheduler 22 | from utils.losses import get_loss 23 | from utils.optimizers import get_optimizer 24 | from val import evaluate 25 | 26 | 27 | def main(cfg, gpu, save_dir): 28 | start = time_sync() 29 | 30 | best_score = 0.0 31 | num_workers = mp.cpu_count() 32 | device = torch.device(cfg['DEVICE']) 33 | train_config = cfg['TRAIN'] 34 | epochs = train_config['EPOCHS'] 35 | metric = cfg['DATASET']['METRIC'] 36 | lr = cfg['OPTIMIZER']['LR'] 37 | 38 | # augmentations 39 | # waveform_transforms = get_waveform_transforms(cfg['DATASET']['SAMPLE_RATE'], cfg['DATASET']['AUDIO_LENGTH'], 0.1, 3, 0.005) 40 | waveform_transforms = None 41 | spec_transforms = get_spec_transforms(cfg['DATASET'], cfg['AUG']) 42 | 43 | # dataset 44 | train_dataset = get_train_dataset(cfg['DATASET'], cfg['AUG'], waveform_transforms, spec_transforms) 45 | val_dataset = get_val_dataset(cfg['DATASET']) 46 | 47 | # dataset sampler 48 | train_sampler, val_sampler = get_sampler(train_config['DDP'], train_dataset, val_dataset) 49 | 50 | # dataloader 51 | train_dataloader = DataLoader(train_dataset, batch_size=train_config['BATCH_SIZE'], num_workers=num_workers, drop_last=True, pin_memory=True, sampler=train_sampler) 52 | 
val_dataloader = DataLoader(val_dataset, batch_size=train_config['BATCH_SIZE'], num_workers=num_workers, pin_memory=True, sampler=val_sampler) 53 | 54 | # create model 55 | model = get_model(cfg['MODEL']['NAME'], train_dataset.num_classes) 56 | model._init_weights(cfg['MODEL']['PRETRAINED']) 57 | model = model.to(device) 58 | if train_config['DDP']: model = DDP(model, device_ids=[gpu]) 59 | 60 | # loss function, optimizer, scheduler, AMP scaler, tensorboard writer 61 | loss_fn = get_loss(train_config['LOSS']) 62 | optimizer = get_optimizer(model, cfg['OPTIMIZER']['NAME'], lr, cfg['OPTIMIZER']['WEIGHT_DECAY']) 63 | scheduler = get_scheduler(cfg, optimizer) 64 | scaler = GradScaler(enabled=train_config['AMP']) 65 | writer = SummaryWriter(save_dir / 'logs') 66 | iters_per_epoch = len(train_dataset) // train_config['BATCH_SIZE'] 67 | 68 | for epoch in range(epochs): 69 | model.train() 70 | 71 | if train_config['DDP']: train_sampler.set_epoch(epoch) 72 | 73 | train_loss = 0.0 74 | pbar = tqdm(enumerate(train_dataloader), total=iters_per_epoch, desc=f"Epoch: [{epoch+1}/{epochs}] Iter: [{0}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss:.6f}") 75 | 76 | for iter, (audio, target) in pbar: 77 | audio = audio.to(device) 78 | target = target.to(device) 79 | 80 | optimizer.zero_grad() 81 | 82 | with autocast(enabled=train_config['AMP']): 83 | pred = model(audio) 84 | loss = loss_fn(pred, target) 85 | 86 | # Backpropagation 87 | scaler.scale(loss).backward() 88 | scaler.step(optimizer) 89 | scaler.update() 90 | 91 | lr = scheduler.get_last_lr()[0] 92 | train_loss += loss.item() 93 | 94 | pbar.set_description(f"Epoch: [{epoch+1}/{epochs}] Iter: [{iter+1}/{iters_per_epoch}] LR: {lr:.8f} Loss: {train_loss/(iter+1):.6f}") 95 | 96 | train_loss /= iter+1 97 | writer.add_scalar('train/loss', train_loss, epoch) 98 | 99 | scheduler.step() 100 | torch.cuda.empty_cache() 101 | 102 | if (epoch+1) % train_config['EVAL_INTERVAL'] == 0 or (epoch+1) == epochs: 103 | # evaluate the model 104 | score = evaluate(val_dataloader, model, device, metric)[0] 105 | writer.add_scalar(f'val/{metric}', score, epoch) 106 | 107 | if score >= best_score: 108 | best_score = score 109 | torch.save(model.module.state_dict() if train_config['DDP'] else model.state_dict(), save_dir / f"{cfg['MODEL']['NAME']}_{cfg['DATASET']['NAME']}.pth") 110 | print(f"Current {metric}: {score:.2f} Best {metric}: {best_score:.2f}") 111 | 112 | end = time.gmtime(time_sync() - start) 113 | total_time = time.strftime("%H:%M:%S", end) 114 | 115 | # results table 116 | table = [ 117 | [f'Best {metric}', f"{best_score:.2f}"], 118 | ['Total Training Time', total_time] 119 | ] 120 | print(tabulate(table, numalign='right')) 121 | 122 | writer.close() 123 | pbar.close() 124 | 125 | 126 | if __name__ == '__main__': 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument('--cfg', type=str, required=True, help='Experiment configuration file name') 129 | args = parser.parse_args() 130 | 131 | with open(args.cfg) as f: 132 | cfg = yaml.load(f, Loader=yaml.SafeLoader) 133 | 134 | pprint(cfg) 135 | save_dir = Path(cfg['TRAIN']['SAVE_DIR']) 136 | save_dir.mkdir(exist_ok=True) 137 | fix_seeds(123) 138 | setup_cudnn() 139 | gpu = setup_ddp() 140 | main(cfg, gpu, save_dir) 141 | cleanup_ddp() -------------------------------------------------------------------------------- /datasets/aug_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | 
"metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torchaudio\n", 11 | "from torchaudio import transforms as T\n", 12 | "from torchaudio import functional as AF\n", 13 | "from transforms import *\n", 14 | "import sys\n", 15 | "sys.path.insert(0, '../')\n", 16 | "from utils.visualize import play_audio, plot_spectrogram" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "audio_path = '../assests/test.wav'\n", 26 | "noise_path = '../assests/noises/voices.wav'\n", 27 | "\n", 28 | "sample_rate = 32000\n", 29 | "audio_length = 5 # 5s\n", 30 | "audio, sr = torchaudio.load(audio_path)\n", 31 | "noise, sr1 = torchaudio.load(noise_path)\n", 32 | "\n", 33 | "if sr != sample_rate: audio = AF.resample(audio, sr, sample_rate)\n", 34 | "if audio.shape[0] != 1: audio = audio[:1]\n", 35 | "if audio.shape[1]!= sample_rate*audio_length: audio = audio[:, :sample_rate*audio_length]\n", 36 | "if sr1 != sample_rate: noise = AF.resample(noise, sr1, sample_rate)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "play_audio(audio, sample_rate)\n", 46 | "play_audio(noise, sample_rate)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## WaveForm Augmentations" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "## Background Noise\n", 63 | "background_noise = BackgroundNoise(3) #20(weak)~3(strong)\n", 64 | "noise_audio = background_noise(audio, noise)\n", 65 | "play_audio(noise_audio, sample_rate)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "# Fade In/Out\n", 75 | "fade = Fade(0.1) # 0.1(10% of audio)~0.5(50% of audio)\n", 76 | "fade_audio = fade(audio)\n", 77 | "play_audio(fade_audio, sample_rate)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "# Add Volume\n", 87 | "volume = Volume(20) # 3(weak)~20(strong)\n", 88 | "vol_audio = volume(audio)\n", 89 | "play_audio(vol_audio, sample_rate)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "# Gaussian Noise\n", 99 | "gnoise = GaussianNoise(0.005) # 0.005(weak)~0.02(strong)\n", 100 | "gnoise_audio = gnoise(audio)\n", 101 | "play_audio(gnoise_audio, sample_rate)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "## Spectrogram Augmentations" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "sample_rate = 32000\n", 118 | "win_length = 1024\n", 119 | "hop_length = 320\n", 120 | "n_mels = 64\n", 121 | "fmin = 50\n", 122 | "fmax = 14000\n", 123 | "audio_length = 5\n", 124 | "\n", 125 | "mel_tf = T.MelSpectrogram(sample_rate, win_length, win_length, hop_length, fmin, fmax, n_mels=n_mels, norm='slaney')\n", 126 | "spec = mel_tf(audio)\n", 127 | "plot_spectrogram(spec)\n", 128 | "spec.shape, spec.min(), spec.max()" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "## Log-mel Spectrogram\n", 138 | 
"logspec = 10.0 * spec.log10()\n", 139 | "plot_spectrogram(logspec)" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "metadata": {}, 146 | "outputs": [], 147 | "source": [ 148 | "## Time Masking\n", 149 | "tmasking = TimeMasking(200, audio_length)\n", 150 | "tspec = tmasking(spec)\n", 151 | "plot_spectrogram(tspec)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "## Frequency Masking\n", 161 | "fmasking = FrequencyMasking(24, n_mels)\n", 162 | "fspec = fmasking(spec)\n", 163 | "plot_spectrogram(fspec)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "## Filter Augment\n", 173 | "filteraug = FilterAugment((-20, 20), (5, 10))\n", 174 | "filtspec = filteraug(spec)\n", 175 | "plot_spectrogram(spec)\n", 176 | "plot_spectrogram(filtspec)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "execution_count": null, 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [] 185 | } 186 | ], 187 | "metadata": { 188 | "interpreter": { 189 | "hash": "78184fe1b8a3f830767e8814b2b01c36fc7c8ac521e39cb583cd3fce210fee57" 190 | }, 191 | "kernelspec": { 192 | "display_name": "Python 3", 193 | "language": "python", 194 | "name": "python3" 195 | }, 196 | "language_info": { 197 | "codemirror_mode": { 198 | "name": "ipython", 199 | "version": 3 200 | }, 201 | "file_extension": ".py", 202 | "mimetype": "text/x-python", 203 | "name": "python", 204 | "nbconvert_exporter": "python", 205 | "pygments_lexer": "ipython3", 206 | "version": "3.9.5" 207 | } 208 | }, 209 | "nbformat": 4, 210 | "nbformat_minor": 2 211 | } 212 | -------------------------------------------------------------------------------- /models/cnn14.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | 5 | class Conv2Layer(nn.Module): 6 | def __init__(self, c1, c2, pool_size=2): 7 | super().__init__() 8 | self.conv1 = nn.Conv2d(c1, c2, 3, 1, 1, bias=False) 9 | self.conv2 = nn.Conv2d(c2, c2, 3, 1, 1, bias=False) 10 | self.bn1 = nn.BatchNorm2d(c2) 11 | self.bn2 = nn.BatchNorm2d(c2) 12 | self.relu = nn.ReLU() 13 | self.avgpool = nn.AvgPool2d(pool_size) 14 | self.dropout = nn.Dropout(0.2) 15 | 16 | def forward(self, x: Tensor) -> Tensor: 17 | x = self.relu(self.bn1(self.conv1(x))) 18 | x = self.relu(self.bn2(self.conv2(x))) 19 | x = self.avgpool(x) 20 | x = self.dropout(x) 21 | return x 22 | 23 | 24 | 25 | class CNN14(nn.Module): 26 | def __init__(self, num_classes: int = 50): 27 | super().__init__() 28 | self.bn0 = nn.BatchNorm2d(64) 29 | self.conv_block1 = Conv2Layer(1, 64) 30 | self.conv_block2 = Conv2Layer(64, 128) 31 | self.conv_block3 = Conv2Layer(128, 256) 32 | self.conv_block4 = Conv2Layer(256, 512) 33 | self.conv_block5 = Conv2Layer(512, 1024) 34 | self.conv_block6 = Conv2Layer(1024, 2048, 1) 35 | 36 | self.dropout = nn.Dropout(0.5) 37 | self.relu = nn.ReLU() 38 | 39 | self.fc1 = nn.Linear(2048, 2048) 40 | self.fc = nn.Linear(2048, num_classes) 41 | 42 | def _init_weights(self, pretrained: str = None): 43 | if pretrained: 44 | print(f"Loading Pretrained Weights from {pretrained}") 45 | pretrained_dict = torch.load(pretrained, map_location='cpu') 46 | model_dict = self.state_dict() 47 | for k in model_dict.keys(): 48 | if not k.startswith('fc.'): 49 | model_dict[k] = pretrained_dict[k] 50 | 
self.load_state_dict(model_dict) 51 | else: 52 | for m in self.modules(): 53 | if isinstance(m, nn.Conv2d): 54 | nn.init.xavier_uniform_(m.weight) 55 | elif isinstance(m, nn.Linear): 56 | nn.init.xavier_uniform_(m.weight) 57 | if hasattr(m, 'bias'): 58 | nn.init.constant_(m.bias, 0) 59 | elif isinstance(m, nn.BatchNorm2d): 60 | nn.init.constant_(m.weight, 1.0) 61 | nn.init.constant_(m.bias, 0.0) 62 | 63 | def forward(self, x: Tensor) -> Tensor: 64 | # x [B, 1, mel_bins, time_steps] 65 | x = x.permute(0, 2, 3, 1) 66 | x = self.bn0(x) 67 | x = x.transpose(1, 3) 68 | 69 | x = self.conv_block1(x) 70 | x = self.conv_block2(x) 71 | x = self.conv_block3(x) 72 | x = self.conv_block4(x) 73 | x = self.conv_block5(x) 74 | x = self.conv_block6(x) 75 | 76 | x = x.mean(3) 77 | x = x.max(dim=2)[0] + x.mean(2) 78 | x = self.dropout(x) 79 | x = self.relu(self.fc1(x)) 80 | return self.fc(x) 81 | 82 | 83 | class CNN14DecisionLevelMax(nn.Module): 84 | def __init__(self, num_classes: int = 50): 85 | super().__init__() 86 | self.interpolate_ratio = 32 # downsampled ratio 87 | 88 | self.bn0 = nn.BatchNorm2d(64) 89 | self.conv_block1 = Conv2Layer(1, 64) 90 | self.conv_block2 = Conv2Layer(64, 128) 91 | self.conv_block3 = Conv2Layer(128, 256) 92 | self.conv_block4 = Conv2Layer(256, 512) 93 | self.conv_block5 = Conv2Layer(512, 1024) 94 | self.conv_block6 = Conv2Layer(1024, 2048, 1) 95 | 96 | self.dropout = nn.Dropout(0.5) 97 | self.relu = nn.ReLU() 98 | self.maxpool = nn.MaxPool1d(3, 1, 1) 99 | self.avgpool = nn.AvgPool1d(3, 1, 1) 100 | 101 | self.fc1 = nn.Linear(2048, 2048) 102 | self.fc = nn.Linear(2048, num_classes) 103 | 104 | def _init_weights(self, pretrained: str = None): 105 | if pretrained: 106 | print(f"Loading Pretrained Weights from {pretrained}") 107 | pretrained_dict = torch.load(pretrained, map_location='cpu') 108 | model_dict = self.state_dict() 109 | for k in model_dict.keys(): 110 | if not k.startswith('fc.'): 111 | model_dict[k] = pretrained_dict[k] 112 | self.load_state_dict(model_dict) 113 | else: 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | nn.init.xavier_uniform_(m.weight) 117 | elif isinstance(m, nn.Linear): 118 | nn.init.xavier_uniform_(m.weight) 119 | if hasattr(m, 'bias'): 120 | nn.init.constant_(m.bias, 0) 121 | elif isinstance(m, nn.BatchNorm2d): 122 | nn.init.constant_(m.weight, 1.0) 123 | nn.init.constant_(m.bias, 0.0) 124 | 125 | def forward(self, x: Tensor) -> Tensor: 126 | # x [B, 1, mel_bins, time_steps] 127 | num_frames = x.shape[-1] 128 | x = x.permute(0, 2, 3, 1) 129 | x = self.bn0(x) 130 | x = x.transpose(1, 3) 131 | 132 | x = self.conv_block1(x) 133 | x = self.conv_block2(x) 134 | x = self.conv_block3(x) 135 | x = self.conv_block4(x) 136 | x = self.conv_block5(x) 137 | x = self.conv_block6(x) 138 | 139 | x = x.mean(3) 140 | x = self.maxpool(x) + self.avgpool(x) 141 | x = self.dropout(x) 142 | x = x.transpose(1, 2) 143 | x = self.relu(self.fc1(x)) 144 | 145 | segmentwise = self.fc(x).sigmoid() 146 | clipwise = segmentwise.max(dim=1)[0] 147 | 148 | # get framewise output 149 | framewise = interpolate(segmentwise, self.interpolate_ratio) 150 | framewise = pad_framewise(framewise, num_frames) 151 | 152 | return framewise, clipwise 153 | 154 | 155 | 156 | def interpolate(x, ratio): 157 | B, T, C = x.shape 158 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 159 | upsampled = upsampled.reshape(B, T*ratio, C) 160 | return upsampled 161 | 162 | 163 | def pad_framewise(framewise, num_frames): 164 | pad = framewise[:, -1:, :].repeat(1, 
num_frames-framewise.shape[1], 1) 165 | return torch.cat([framewise, pad], dim=1) 166 | 167 | 168 | if __name__ == '__main__': 169 | import time 170 | model = CNN14(527) 171 | model.load_state_dict(torch.load('checkpoints/cnn14.pth', map_location='cpu')) 172 | x = torch.randn(3, 1, 64, 701) 173 | start = time.time() 174 | y = model(x) 175 | print(time.time()-start) 176 | print(y.shape) 177 | print(y.min(), y.max()) -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import math 4 | import numpy as np 5 | from torch import nn, Tensor 6 | from torchaudio import transforms as T 7 | from typing import Tuple 8 | 9 | 10 | ####################################################### 11 | ## Voice Activity Detector 12 | 13 | class VoiceActivityDetector(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | self.vad = T.Vad( 17 | sample_rate=16000, 18 | trigger_level=7, 19 | trigger_time=0.25, 20 | search_time=1.0, 21 | allowed_gap=0.25, 22 | pre_trigger_time=0.0, 23 | boot_time=0.35, 24 | noise_up_time=0.1, 25 | noise_down_time=0.01, 26 | noise_reduction_amount=1.35, 27 | measure_freq=20.0, 28 | measure_duration=None, 29 | measure_smooth_time=0.4 30 | ) 31 | 32 | def forward(self, waveform: Tensor) -> Tensor: 33 | waveform = self.vad(waveform) 34 | return waveform 35 | 36 | 37 | ############################################################# 38 | ### WAVEFORM AUGMENTATIONS ### 39 | 40 | class Fade(nn.Module): 41 | def __init__(self, fio: float = 0.1, sample_rate: int = 32000, audio_length: int = 5): 42 | super().__init__() 43 | fiop = int(fio * sample_rate * audio_length) 44 | self.fade = T.Fade(fiop, fiop) 45 | 46 | def forward(self, waveform: Tensor) -> Tensor: 47 | return self.fade(waveform) 48 | 49 | 50 | class Volume(nn.Module): 51 | """ 52 | gain: in decibel (3 (weak) to 20 (strong)) 53 | """ 54 | def __init__(self, gain: int = 3): 55 | super().__init__() 56 | self.volume = T.Vol(gain, gain_type='db') 57 | 58 | def forward(self, waveform: Tensor) -> Tensor: 59 | return self.volume(waveform) 60 | 61 | 62 | class GaussianNoise(nn.Module): 63 | def __init__(self, sigma: float = 0.005): 64 | super().__init__() 65 | self.sigma = sigma # 0.005 (weak) to 0.02 (strong) 66 | 67 | def forward(self, waveform: Tensor) -> Tensor: 68 | gauss = np.random.normal(0, self.sigma, waveform.shape) 69 | waveform += gauss 70 | return waveform 71 | 72 | 73 | class BackgroundNoise(nn.Module): 74 | """To add background noise to audio 75 | For simplicity, you can add audio Tensor with noise Tensor. 
76 |     A common way to adjust the intensity of the noise is to change the Signal-to-Noise Ratio (SNR):
77 | 
78 |         SNR = signal power / noise power
79 |         SNR(dB) = 10 * log10(SNR)
80 |     """
81 |     def __init__(self, snr_db: int = 20):
82 |         super().__init__()
83 |         self.snr_db = snr_db    # 3 (strong) to 20 (weak)
84 | 
85 |     def forward(self, waveform: Tensor, noise: Tensor) -> Tensor:
86 |         # assume waveform and noise have the same sample rate
87 |         if noise.shape[1] < waveform.shape[1]:
88 |             noise = torch.cat([noise[:1], torch.zeros(1, waveform.shape[1]-noise.shape[1])], dim=-1)
89 |         noise = noise[:1, :waveform.shape[1]]
90 |         scale = math.exp(self.snr_db / 10) * noise.norm(p=2) / waveform.norm(p=2)  # rescale the clean waveform relative to the noise according to the requested SNR
91 |         return (scale * waveform + noise) / 2
92 | 
93 | 
94 | def mixup_augment(audio1: Tensor, target1: Tensor, audio2: Tensor, target2: Tensor, alpha: int = 10, num_classes: int = 50, smoothing: float = 0.1) -> Tuple[Tensor, Tensor]:
95 |     # assume audio1 and audio2 are mono-channel and share the same sample rate
96 |     mix_lambda = np.random.beta(alpha, alpha)
97 | 
98 |     off_value = smoothing / num_classes
99 |     on_value = 1 - smoothing + off_value
100 |     target1 = torch.full((1, num_classes), off_value).scatter_(1, target1.long().view(-1, 1), on_value).squeeze()  # label-smoothed one-hot target
101 |     target2 = torch.full((1, num_classes), off_value).scatter_(1, target2.long().view(-1, 1), on_value).squeeze()
102 | 
103 |     # target1 = F.one_hot(target1, num_classes)
104 |     # target2 = F.one_hot(target2, num_classes)
105 | 
106 |     audio = audio1 * mix_lambda + audio2 * (1 - mix_lambda)
107 |     target = target1 * mix_lambda + target2 * (1 - mix_lambda)
108 | 
109 |     return audio, target
110 | 
111 | 
112 | ##########################################################
113 | ## Spectrogram Augmentations
114 | 
115 | class TimeMasking(nn.Module):
116 |     def __init__(self, mask: int = 96, audio_length: int = 5):
117 |         super().__init__()
118 |         assert mask < audio_length*100, f"TimeMasking parameter should be less than time frames >> {mask} > {audio_length*100}"
119 |         self.masking = T.TimeMasking(mask)
120 | 
121 |     def forward(self, spec: Tensor) -> Tensor:
122 |         return self.masking(spec)
123 | 
124 | 
125 | class FrequencyMasking(nn.Module):
126 |     def __init__(self, mask: int = 24, n_mels: int = 64):
127 |         super().__init__()
128 |         assert mask < n_mels, f"FrequencyMasking parameter should be less than num mels >> {mask} > {n_mels}"
129 |         self.masking = T.FrequencyMasking(mask)
130 | 
131 |     def forward(self, spec: Tensor) -> Tensor:
132 |         return self.masking(spec)
133 | 
134 | 
135 | class FilterAugment(nn.Module):
136 |     """
137 |     https://github.com/frednam93/FilterAugSED
138 |     https://arxiv.org/abs/2107.03649
139 |     """
140 |     def __init__(self, db_range=(-7.5, 6), n_bands=(2, 5)):
141 |         super().__init__()
142 |         self.db_range = db_range
143 |         self.n_bands = n_bands
144 | 
145 |     def forward(self, audio: Tensor) -> Tensor:
146 |         C, F, _ = audio.shape
147 |         n_freq_band = random.randint(self.n_bands[0], self.n_bands[1])
148 | 
149 |         if n_freq_band > 1:
150 |             band_boundary_freqs = torch.cat([
151 |                 torch.tensor([0]),
152 |                 torch.sort(torch.randint(1, F-1, (n_freq_band-1,)))[0],  # n_freq_band-1 interior boundaries split the frequency axis into n_freq_band bands
153 |                 torch.tensor([F])
154 |             ])
155 |             band_factors = torch.rand((C, n_freq_band)) * (self.db_range[1] - self.db_range[0]) + self.db_range[0]
156 |             band_factors = 10 ** (band_factors / 20)  # dB -> linear amplitude scale
157 |             freq_filter = torch.ones((C, F, 1))
158 | 
159 |             for i in range(n_freq_band):
160 |                 freq_filter[:, band_boundary_freqs[i]:band_boundary_freqs[i+1], :] = band_factors[:, i].unsqueeze(-1).unsqueeze(-1)
161 | 
162 |             audio *= freq_filter
163 |         return audio
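# --- Illustrative usage sketch (not part of the original file) -------------------------------
# The waveform-level classes above are plain nn.Modules, so they can be called directly on a
# [1, num_frames] mono clip before mixing. `clip_a`, `clip_b` and `voices` below are hypothetical
# stand-ins for two training clips and a background-noise recording:
#
#   noisy = BackgroundNoise(snr_db=10)(clip_a, voices)          # add background voices at the chosen SNR
#   noisy = GaussianNoise(0.01)(Volume(6)(Fade(0.2)(noisy)))    # fade edges, boost gain, add hiss
#   mixed, soft_target = mixup_augment(noisy, torch.tensor(3),
#                                      clip_b, torch.tensor(7),
#                                      alpha=10, num_classes=50, smoothing=0.1)
#   # `mixed` keeps the [1, num_frames] shape; `soft_target` is a label-smoothed [num_classes] vector.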
164 | 
165 | 
166 | def get_waveform_transforms(data_config, aug_config):
167 |     return nn.Sequential(
168 |         Fade(0.1, data_config['SAMPLE_RATE'], data_config['AUDIO_LENGTH']),
169 |         Volume(3),
170 |         GaussianNoise(0.005)
171 |     )
172 | 
173 | 
174 | def get_spec_transforms(data_config, aug_config):
175 |     return nn.Sequential(
176 |         TimeMasking(aug_config['TIME_MASK'], data_config['AUDIO_LENGTH']),
177 |         FrequencyMasking(aug_config['FREQ_MASK'], data_config['N_MELS'])
178 |     )
179 | 
180 | 
181 | if __name__ == '__main__':
182 |     torch.manual_seed(123)
183 |     x = torch.randn(4, 1, 64, 701)  # dummy batch of log-mel spectrograms [B, 1, n_mels, time_frames]
184 |     spec_tf = nn.Sequential(TimeMasking(96, 5), FrequencyMasking(24, 64))
185 |     print(spec_tf(x).shape)  # masking preserves the input shape: [4, 1, 64, 701]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | #