├── .gitignore
├── images
│   ├── mem_comp.png
│   ├── model_macs.png
│   ├── model_params.png
│   ├── model_macs_TASLP.png
│   ├── model_params_TASLP.png
│   ├── hear_task_categories.png
│   └── hear_individual_task_results.png
├── wandb
│   └── README.md
├── resources
│   ├── README.md
│   └── metro_station-paris.wav
├── requirements.txt
├── models
│   ├── dymn
│   │   ├── utils.py
│   │   ├── model.py
│   │   └── dy_block.py
│   ├── ensemble.py
│   ├── mn
│   │   ├── utils.py
│   │   ├── attention_pooling.py
│   │   └── block_types.py
│   └── preprocess.py
├── helpers
│   ├── receptive_field.py
│   ├── init.py
│   ├── utils.py
│   ├── peak_memory.py
│   └── flop_count.py
├── LICENSE
├── receptive_field_cnn.py
├── datasets
│   ├── helpers
│   │   └── audiodatasets.py
│   ├── esc50.py
│   ├── dcase20.py
│   ├── openmic.py
│   ├── fsd50k.py
│   └── audioset.py
├── inference.py
├── complexity.py
├── windowed_inference.py
├── ex_esc50.py
├── ex_dcase20.py
├── ex_openmic.py
├── ex_fsd50k.py
├── metadata
│   └── class_labels_indices.csv
├── ex_pl_audioset.py
└── ex_audioset.py
/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | resources/*.pt 3 | resources/*.npy 4 | .idea -------------------------------------------------------------------------------- /images/mem_comp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/EfficientAT/main/images/mem_comp.png -------------------------------------------------------------------------------- /images/model_macs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/EfficientAT/main/images/model_macs.png -------------------------------------------------------------------------------- /wandb/README.md: -------------------------------------------------------------------------------- 1 | WANDB uses this directory by default to log the experiments and store the models. -------------------------------------------------------------------------------- /images/model_params.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/EfficientAT/main/images/model_params.png -------------------------------------------------------------------------------- /images/model_macs_TASLP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/EfficientAT/main/images/model_macs_TASLP.png -------------------------------------------------------------------------------- /images/model_params_TASLP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/EfficientAT/main/images/model_params_TASLP.png -------------------------------------------------------------------------------- /images/hear_task_categories.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/EfficientAT/main/images/hear_task_categories.png -------------------------------------------------------------------------------- /resources/README.md: -------------------------------------------------------------------------------- 1 | Download the pre-trained model checkpoints from this repo's GitHub Releases and place them inside this folder.
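A minimal sketch (not part of the repository) of how a checkpoint placed in this folder could be checked and inspected; the file name 'mn10_as.pt' and the assumption that the release file stores a plain state dict are illustrative only:

import os
import torch

ckpt_path = os.path.join("resources", "mn10_as.pt")  # assumed example file name from the GitHub Releases
if os.path.exists(ckpt_path):
    state_dict = torch.load(ckpt_path, map_location="cpu")  # assumes the file holds a plain state dict
    print(f"{ckpt_path} contains {len(state_dict)} tensors")
else:
    print(f"{ckpt_path} not found - download it from the GitHub Releases page first")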
-------------------------------------------------------------------------------- /resources/metro_station-paris.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/EfficientAT/main/resources/metro_station-paris.wav -------------------------------------------------------------------------------- /images/hear_individual_task_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wetdog/EfficientAT/main/images/hear_individual_task_results.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | av==10.0.0 2 | h5py==3.7.0 3 | librosa==0.9.2 4 | numpy==1.23.3 5 | scikit_learn==1.1.3 6 | torch==1.13.0 7 | torchaudio==0.13.0 8 | torchvision==0.14.0 9 | tqdm==4.64.1 10 | wandb==0.13.5 11 | pandas==1.5.2 -------------------------------------------------------------------------------- /models/dymn/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | 5 | def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: 6 | """ 7 | This function is taken from the original tf repo. 8 | It ensures that all layers have a channel number that is divisible by 8 9 | It can be seen here: 10 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 11 | """ 12 | if min_value is None: 13 | min_value = divisor 14 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 15 | # Make sure that round down does not go down by more than 10%. 16 | if new_v < 0.9 * v: 17 | new_v += divisor 18 | return new_v 19 | 20 | 21 | def cnn_out_size(in_size, padding, dilation, kernel, stride): 22 | s = in_size + 2 * padding - dilation * (kernel - 1) - 1 23 | return math.floor(s / stride + 1) 24 | -------------------------------------------------------------------------------- /helpers/receptive_field.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def receptive_field_cnn(model, spec_size): 6 | kernel_sizes = [] 7 | strides = [] 8 | 9 | def conv2d_hook(self, input, output): 10 | kernel_sizes.append(self.kernel_size[0]) 11 | strides.append(self.stride[0]) 12 | 13 | def foo(net): 14 | if net.__class__.__name__ == 'Conv2d': 15 | net.register_forward_hook(conv2d_hook) 16 | childrens = list(net.children()) 17 | if isinstance(net, nn.Conv2d): 18 | net.register_forward_hook(conv2d_hook) 19 | for c in childrens: 20 | foo(c) 21 | 22 | # Register hook 23 | foo(model) 24 | 25 | device = next(model.parameters()).device 26 | input = torch.rand(spec_size).to(device) 27 | with torch.no_grad(): 28 | model(input) 29 | 30 | r = 1 31 | for k, s in zip(kernel_sizes[::-1], strides[::-1]): 32 | r = s * r + (k - s) 33 | 34 | return r 35 | -------------------------------------------------------------------------------- /helpers/init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | 5 | 6 | def worker_init_fn(wid): 7 | seed_sequence = np.random.SeedSequence( 8 | [torch.initial_seed(), wid] 9 | ) 10 | 11 | to_seed = spawn_get(seed_sequence, 2, dtype=int) 12 | torch.random.manual_seed(to_seed) 13 | 14 | np_seed = spawn_get(seed_sequence, 2, 
dtype=np.ndarray) 15 | np.random.seed(np_seed) 16 | 17 | py_seed = spawn_get(seed_sequence, 2, dtype=int) 18 | random.seed(py_seed) 19 | 20 | 21 | def spawn_get(seedseq, n_entropy, dtype): 22 | child = seedseq.spawn(1)[0] 23 | state = child.generate_state(n_entropy, dtype=np.uint32) 24 | 25 | if dtype == np.ndarray: 26 | return state 27 | elif dtype == int: 28 | state_as_int = 0 29 | for shift, s in enumerate(state): 30 | state_as_int = state_as_int + int((2 ** (32 * shift) * s)) 31 | return state_as_int 32 | else: 33 | raise ValueError(f'not a valid dtype "{dtype}"') 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Florian Schmid 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /models/ensemble.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from models.mn.model import get_model as get_mobilenet 4 | from models.dymn.model import get_model as get_dymn 5 | from helpers.utils import NAME_TO_WIDTH 6 | 7 | 8 | class EnsemblerModel(nn.Module): 9 | def __init__(self, models): 10 | super(EnsemblerModel, self).__init__() 11 | self.models = nn.ModuleList(models) 12 | 13 | def forward(self, x): 14 | all_out = None 15 | for m in self.models: 16 | out, _ = m(x) 17 | if all_out is None: 18 | all_out = out 19 | else: 20 | all_out = out + all_out 21 | all_out = all_out / len(self.models) 22 | return all_out, all_out 23 | 24 | 25 | def get_ensemble_model(model_names): 26 | models = [] 27 | for model_name in model_names: 28 | if model_name.startswith("dymn"): 29 | model = get_dymn(width_mult=NAME_TO_WIDTH(model_name), pretrained_name=model_name) 30 | else: 31 | model = get_mobilenet(width_mult=NAME_TO_WIDTH(model_name), pretrained_name=model_name) 32 | models.append(model) 33 | return EnsemblerModel(models) 34 | -------------------------------------------------------------------------------- /receptive_field_cnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from models.mn.model import get_model 3 | from helpers.utils import NAME_TO_WIDTH 4 | from helpers.receptive_field import receptive_field_cnn 5 | 6 | 7 | def calc_receptive_field(args): 8 | # model 9 | if args.model_width: 10 | # manually specified settings 11 | width = args.model_width 12 | model_name = "mn{}".format(str(width).replace(".", "")) 13 | else: 14 | # model width specified via model name 15 | model_name = args.model_name 16 | width = NAME_TO_WIDTH(model_name) 17 | model = get_model(width_mult=width, se_dims=args.se_dims, head_type=args.head_type, strides=args.strides) 18 | model.eval() 19 | 20 | r = receptive_field_cnn(model, (1, 1, 128, 1000)) 21 | print(f"Receptive field size of {model_name} with strides {args.strides}: ", r) 22 | 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser(description='Example of parser. ') 26 | 27 | # model name decides, which pre-trained model is evaluated in terms of complexity 28 | parser.add_argument('--model_name', type=str, default='mn10_as') 29 | # alternatively, specify model configurations manually 30 | parser.add_argument('--model_width', type=float, default=None) 31 | parser.add_argument('--head_type', type=str, default='mlp') 32 | parser.add_argument('--strides', nargs=4, default=[2, 2, 2, 2], type=int) 33 | parser.add_argument('--se_dims', type=str, default='c') 34 | 35 | args = parser.parse_args() 36 | calc_receptive_field(args) 37 | -------------------------------------------------------------------------------- /datasets/helpers/audiodatasets.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | import numpy as np 4 | from functools import partial 5 | 6 | 7 | class PreprocessDataset(Dataset): 8 | """A base preprocessing dataset representing a preprocessing step of a Dataset preprocessed on the fly. 9 | supporting integer indexing in range from 0 to len(self) exclusive. 
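    For example, PreprocessDataset(base_dataset, get_roll_func()) returns a dataset that applies a random
    time roll to each waveform on access; the dataset modules below chain several such preprocessing steps.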
10 | """ 11 | 12 | def __init__(self, dataset, preprocessor): 13 | self.dataset = dataset 14 | if not callable(preprocessor): 15 | print("preprocessor: ", preprocessor) 16 | raise ValueError('preprocessor should be callable') 17 | self.preprocessor = preprocessor 18 | 19 | def __getitem__(self, index): 20 | return self.preprocessor(self.dataset[index]) 21 | 22 | def __len__(self): 23 | return len(self.dataset) 24 | 25 | 26 | def get_roll_func(axis=1, shift=None, shift_range=4000): 27 | return partial(roll_func, axis=axis, shift=shift, shift_range=shift_range) 28 | 29 | 30 | # roll waveform (over time) 31 | def roll_func(b, axis=1, shift=None, shift_range=4000): 32 | x = b[0] 33 | others = b[1:] 34 | x = torch.as_tensor(x) 35 | sf = shift 36 | if shift is None: 37 | sf = int(np.random.random_integers(-shift_range, shift_range)) 38 | return (x.roll(sf, axis), *others) 39 | 40 | 41 | def get_gain_augment_func(gain_augment): 42 | return partial(gain_augment_func, gain_augment=gain_augment) 43 | 44 | 45 | def gain_augment_func(b, gain_augment=12): 46 | x = b[0] 47 | others = b[1:] 48 | gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment 49 | amp = 10 ** (gain / 20) 50 | x = x * amp 51 | return (x, *others) 52 | -------------------------------------------------------------------------------- /models/mn/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Callable 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | 8 | def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: 9 | """ 10 | This function is taken from the original tf repo. 11 | It ensures that all layers have a channel number that is divisible by 8 12 | It can be seen here: 13 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 14 | """ 15 | if min_value is None: 16 | min_value = divisor 17 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 18 | # Make sure that round down does not go down by more than 10%. 
19 | if new_v < 0.9 * v: 20 | new_v += divisor 21 | return new_v 22 | 23 | 24 | def cnn_out_size(in_size, padding, dilation, kernel, stride): 25 | s = in_size + 2 * padding - dilation * (kernel - 1) - 1 26 | return math.floor(s / stride + 1) 27 | 28 | 29 | def collapse_dim(x: Tensor, dim: int, mode: str = "pool", pool_fn: Callable[[Tensor, int], Tensor] = torch.mean, 30 | combine_dim: int = None): 31 | """ 32 | Collapses dimension of multi-dimensional tensor by pooling or combining dimensions 33 | :param x: input Tensor 34 | :param dim: dimension to collapse 35 | :param mode: 'pool' or 'combine' 36 | :param pool_fn: function to be applied in case of pooling 37 | :param combine_dim: dimension to join 'dim' to 38 | :return: collapsed tensor 39 | """ 40 | if mode == "pool": 41 | return pool_fn(x, dim) 42 | elif mode == "combine": 43 | s = list(x.size()) 44 | s[combine_dim] *= dim 45 | s[dim] //= dim 46 | return x.view(s) 47 | 48 | 49 | class CollapseDim(nn.Module): 50 | def __init__(self, dim: int, mode: str = "pool", pool_fn: Callable[[Tensor, int], Tensor] = torch.mean, 51 | combine_dim: int = None): 52 | super(CollapseDim, self).__init__() 53 | self.dim = dim 54 | self.mode = mode 55 | self.pool_fn = pool_fn 56 | self.combine_dim = combine_dim 57 | 58 | def forward(self, x): 59 | return collapse_dim(x, dim=self.dim, mode=self.mode, pool_fn=self.pool_fn, combine_dim=self.combine_dim) 60 | -------------------------------------------------------------------------------- /models/mn/attention_pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | from models.mn.utils import collapse_dim 7 | 8 | 9 | class MultiHeadAttentionPooling(nn.Module): 10 | """Multi-Head Attention as used in PSLA paper (https://arxiv.org/pdf/2102.01243.pdf) 11 | """ 12 | def __init__(self, in_dim, out_dim, att_activation: str = 'sigmoid', 13 | clf_activation: str = 'ident', num_heads: int = 4, epsilon: float = 1e-7): 14 | super(MultiHeadAttentionPooling, self).__init__() 15 | 16 | self.in_dim = in_dim 17 | self.out_dim = out_dim 18 | self.num_heads = num_heads 19 | self.epsilon = epsilon 20 | 21 | self.att_activation = att_activation 22 | self.clf_activation = clf_activation 23 | 24 | # out size: out dim x 2 (att and clf paths) x num_heads 25 | self.subspace_proj = nn.Linear(self.in_dim, self.out_dim * 2 * self.num_heads) 26 | self.head_weight = nn.Parameter(torch.tensor([1.0 / self.num_heads] * self.num_heads).view(1, -1, 1)) 27 | 28 | def activate(self, x, activation): 29 | if activation == 'linear': 30 | return x 31 | elif activation == 'relu': 32 | return F.relu(x) 33 | elif activation == 'sigmoid': 34 | return torch.sigmoid(x) 35 | elif activation == 'softmax': 36 | return F.softmax(x, dim=1) 37 | elif activation == 'ident': 38 | return x 39 | 40 | def forward(self, x) -> Tensor: 41 | """x: Tensor of size (batch_size, channels, frequency bands, sequence length) 42 | """ 43 | x = collapse_dim(x, dim=2) # results in tensor of size (batch_size, channels, sequence_length) 44 | x = x.transpose(1, 2) # results in tensor of size (batch_size, sequence_length, channels) 45 | b, n, c = x.shape 46 | 47 | x = self.subspace_proj(x).reshape(b, n, 2, self.num_heads, self.out_dim).permute(2, 0, 3, 1, 4) 48 | att, val = x[0], x[1] 49 | val = self.activate(val, self.clf_activation) 50 | att = self.activate(att, self.att_activation) 51 | att = torch.clamp(att, self.epsilon, 1. 
- self.epsilon) 52 | att = att / torch.sum(att, dim=2, keepdim=True) 53 | 54 | out = torch.sum(att * val, dim=2) * self.head_weight 55 | out = torch.sum(out, dim=1) 56 | return out 57 | -------------------------------------------------------------------------------- /models/preprocess.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchaudio 3 | import torch 4 | 5 | 6 | class AugmentMelSTFT(nn.Module): 7 | def __init__(self, n_mels=128, sr=32000, win_length=800, hopsize=320, n_fft=1024, freqm=48, timem=192, 8 | fmin=0.0, fmax=None, fmin_aug_range=10, fmax_aug_range=2000): 9 | torch.nn.Module.__init__(self) 10 | # adapted from: https://github.com/CPJKU/kagglebirds2020/commit/70f8308b39011b09d41eb0f4ace5aa7d2b0e806e 11 | 12 | self.win_length = win_length 13 | self.n_mels = n_mels 14 | self.n_fft = n_fft 15 | self.sr = sr 16 | self.fmin = fmin 17 | if fmax is None: 18 | fmax = sr // 2 - fmax_aug_range // 2 19 | print(f"Warning: FMAX is None setting to {fmax} ") 20 | self.fmax = fmax 21 | self.hopsize = hopsize 22 | self.register_buffer('window', 23 | torch.hann_window(win_length, periodic=False), 24 | persistent=False) 25 | assert fmin_aug_range >= 1, f"fmin_aug_range={fmin_aug_range} should be >=1; 1 means no augmentation" 26 | assert fmax_aug_range >= 1, f"fmax_aug_range={fmax_aug_range} should be >=1; 1 means no augmentation" 27 | self.fmin_aug_range = fmin_aug_range 28 | self.fmax_aug_range = fmax_aug_range 29 | 30 | self.register_buffer("preemphasis_coefficient", torch.as_tensor([[[-.97, 1]]]), persistent=False) 31 | if freqm == 0: 32 | self.freqm = torch.nn.Identity() 33 | else: 34 | self.freqm = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True) 35 | if timem == 0: 36 | self.timem = torch.nn.Identity() 37 | else: 38 | self.timem = torchaudio.transforms.TimeMasking(timem, iid_masks=True) 39 | 40 | def forward(self, x): 41 | x = nn.functional.conv1d(x.unsqueeze(1), self.preemphasis_coefficient).squeeze(1) 42 | x = torch.stft(x, self.n_fft, hop_length=self.hopsize, win_length=self.win_length, 43 | center=True, normalized=False, window=self.window, return_complex=False) 44 | x = (x ** 2).sum(dim=-1) # power mag 45 | fmin = self.fmin + torch.randint(self.fmin_aug_range, (1,)).item() 46 | fmax = self.fmax + self.fmax_aug_range // 2 - torch.randint(self.fmax_aug_range, (1,)).item() 47 | # don't augment eval data 48 | if not self.training: 49 | fmin = self.fmin 50 | fmax = self.fmax 51 | 52 | mel_basis, _ = torchaudio.compliance.kaldi.get_mel_banks(self.n_mels, self.n_fft, self.sr, 53 | fmin, fmax, vtln_low=100.0, vtln_high=-500., vtln_warp_factor=1.0) 54 | mel_basis = torch.as_tensor(torch.nn.functional.pad(mel_basis, (0, 1), mode='constant', value=0), 55 | device=x.device) 56 | with torch.cuda.amp.autocast(enabled=False): 57 | melspec = torch.matmul(mel_basis, x) 58 | 59 | melspec = (melspec + 0.00001).log() 60 | 61 | if self.training: 62 | melspec = self.freqm(melspec) 63 | melspec = self.timem(melspec) 64 | 65 | melspec = (melspec + 4.5) / 5. 
# fast normalization 66 | 67 | return melspec 68 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import librosa 4 | import numpy as np 5 | from torch import autocast 6 | from contextlib import nullcontext 7 | 8 | from models.mn.model import get_model as get_mobilenet 9 | from models.dymn.model import get_model as get_dymn 10 | from models.ensemble import get_ensemble_model 11 | from models.preprocess import AugmentMelSTFT 12 | from helpers.utils import NAME_TO_WIDTH, labels 13 | 14 | 15 | def audio_tagging(args): 16 | """ 17 | Running Inference on an audio clip. 18 | """ 19 | model_name = args.model_name 20 | device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu') 21 | audio_path = args.audio_path 22 | sample_rate = args.sample_rate 23 | window_size = args.window_size 24 | hop_size = args.hop_size 25 | n_mels = args.n_mels 26 | 27 | # load pre-trained model 28 | if len(args.ensemble) > 0: 29 | model = get_ensemble_model(args.ensemble) 30 | else: 31 | if model_name.startswith("dymn"): 32 | model = get_dymn(width_mult=NAME_TO_WIDTH(model_name), pretrained_name=model_name, 33 | strides=args.strides) 34 | else: 35 | model = get_mobilenet(width_mult=NAME_TO_WIDTH(model_name), pretrained_name=model_name, 36 | strides=args.strides, head_type=args.head_type) 37 | model.to(device) 38 | model.eval() 39 | 40 | # model to preprocess waveform into mel spectrograms 41 | mel = AugmentMelSTFT(n_mels=n_mels, sr=sample_rate, win_length=window_size, hopsize=hop_size) 42 | mel.to(device) 43 | mel.eval() 44 | 45 | (waveform, _) = librosa.core.load(audio_path, sr=sample_rate, mono=True) 46 | waveform = torch.from_numpy(waveform[None, :]).to(device) 47 | 48 | # our models are trained in half precision mode (torch.float16) 49 | # run on cuda with torch.float16 to get the best performance 50 | # running on cpu with torch.float32 gives similar performance, using torch.bfloat16 is worse 51 | with torch.no_grad(), autocast(device_type=device.type) if args.cuda else nullcontext(): 52 | spec = mel(waveform) 53 | preds, features = model(spec.unsqueeze(0)) 54 | preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy() 55 | 56 | sorted_indexes = np.argsort(preds)[::-1] 57 | 58 | # Print audio tagging top probabilities 59 | print("************* Acoustic Event Detected: *****************") 60 | for k in range(10): 61 | print('{}: {:.3f}'.format(labels[sorted_indexes[k]], 62 | preds[sorted_indexes[k]])) 63 | print("********************************************************") 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser(description='Example of parser. 
') 68 | # model name decides, which pre-trained model is loaded 69 | parser.add_argument('--model_name', type=str, default='mn10_as') 70 | parser.add_argument('--strides', nargs=4, default=[2, 2, 2, 2], type=int) 71 | parser.add_argument('--head_type', type=str, default="mlp") 72 | parser.add_argument('--cuda', action='store_true', default=False) 73 | parser.add_argument('--audio_path', type=str, required=True) 74 | 75 | # preprocessing 76 | parser.add_argument('--sample_rate', type=int, default=32000) 77 | parser.add_argument('--window_size', type=int, default=800) 78 | parser.add_argument('--hop_size', type=int, default=320) 79 | parser.add_argument('--n_mels', type=int, default=128) 80 | 81 | # overwrite 'model_name' by 'ensemble_model' to evaluate an ensemble 82 | parser.add_argument('--ensemble', nargs='+', default=[]) 83 | 84 | args = parser.parse_args() 85 | 86 | audio_tagging(args) 87 | -------------------------------------------------------------------------------- /complexity.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from helpers.flop_count import count_macs, count_macs_transformer 5 | from helpers.peak_memory import peak_memory_mnv3, peak_memory_cnn 6 | from models.mn.model import get_model 7 | from helpers.utils import NAME_TO_WIDTH 8 | from models.preprocess import AugmentMelSTFT 9 | 10 | 11 | def calc_complexity(args): 12 | # mel 13 | mel = AugmentMelSTFT(n_mels=args.n_mels, 14 | sr=args.resample_rate, 15 | win_length=args.window_size, 16 | hopsize=args.hop_size, 17 | n_fft=args.n_fft 18 | ) 19 | 20 | # model 21 | if args.model_width: 22 | # manually specified settings 23 | width = args.model_width 24 | model_name = "mn{}".format(str(width).replace(".", "")) 25 | else: 26 | # model width specified via model name 27 | model_name = args.model_name 28 | width = NAME_TO_WIDTH(model_name) 29 | model = get_model(width_mult=width, se_dims=args.se_dims, head_type=args.head_type) 30 | model.eval() 31 | 32 | # waveform 33 | waveform = torch.zeros((1, args.resample_rate * 10)) # 10 seconds waveform 34 | spectrogram = mel(waveform) 35 | # squeeze in channel dimension 36 | spectrogram = spectrogram.unsqueeze(1) 37 | if args.complexity_type == "computation": 38 | # use size of spectrogram to calculate multiply-accumulate operations 39 | total_macs = count_macs(model, spectrogram.size()) 40 | total_params = sum(p.numel() for p in model.parameters()) 41 | print("Model '{}' has {:.2f} million parameters and inference of a single 10-seconds audio clip requires " 42 | "{:.2f} billion multiply-accumulate operations.".format(model_name, total_params/10**6, total_macs/10**9)) 43 | elif args.complexity_type == "memory": 44 | if args.memory_efficient_inference: 45 | peak_mem = peak_memory_mnv3(model, spectrogram.size(), args.bits_per_elem) 46 | print("Model '{}' inference (memory efficient) of a single 10-seconds audio clip " 47 | "has a peak memory of {:.2f} kB." 48 | .format(model_name, peak_mem)) 49 | else: 50 | peak_mem = peak_memory_cnn(model, spectrogram.size(), args.bits_per_elem) 51 | print("Model '{}' inference of a single 10-seconds audio clip has a peak memory of {:.2f} kB." 52 | .format(model_name, peak_mem)) 53 | else: 54 | raise NotImplementedError(f"Unknown complexity type: {args.complexity_type}") 55 | 56 | 57 | if __name__ == '__main__': 58 | parser = argparse.ArgumentParser(description='Example of parser. 
') 59 | # either computation or memory complexity 60 | parser.add_argument('--complexity_type', type=str, default='computation') 61 | # for memory complexity 62 | parser.add_argument('--memory_efficient_inference', action='store_true', default=False) 63 | 64 | # model name decides, which pre-trained model is evaluated in terms of complexity 65 | parser.add_argument('--model_name', type=str, default='mn10_as') 66 | # alternatively, specify model configurations manually 67 | parser.add_argument('--model_width', type=float, default=None) 68 | parser.add_argument('--se_dims', type=str, default='c') 69 | parser.add_argument('--head_type', type=str, default='mlp') 70 | 71 | # preprocessing 72 | parser.add_argument('--resample_rate', type=int, default=32000) 73 | parser.add_argument('--window_size', type=int, default=800) 74 | parser.add_argument('--hop_size', type=int, default=320) 75 | parser.add_argument('--n_fft', type=int, default=1024) 76 | parser.add_argument('--n_mels', type=int, default=128) 77 | 78 | # memory 79 | parser.add_argument('--bits_per_elem', type=int, default=16) 80 | 81 | args = parser.parse_args() 82 | calc_complexity(args) 83 | -------------------------------------------------------------------------------- /helpers/utils.py: -------------------------------------------------------------------------------- 1 | def NAME_TO_WIDTH(name): 2 | mn_map = { 3 | 'mn01': 0.1, 4 | 'mn02': 0.2, 5 | 'mn04': 0.4, 6 | 'mn05': 0.5, 7 | 'mn06': 0.6, 8 | 'mn08': 0.8, 9 | 'mn10': 1.0, 10 | 'mn12': 1.2, 11 | 'mn14': 1.4, 12 | 'mn16': 1.6, 13 | 'mn20': 2.0, 14 | 'mn30': 3.0, 15 | 'mn40': 4.0, 16 | } 17 | 18 | dymn_map = { 19 | 'dymn04': 0.4, 20 | 'dymn10': 1.0, 21 | 'dymn20': 2.0 22 | } 23 | 24 | try: 25 | if name.startswith('dymn'): 26 | w = dymn_map[name[:6]] 27 | else: 28 | w = mn_map[name[:4]] 29 | except: 30 | w = 1.0 31 | 32 | return w 33 | 34 | 35 | import csv 36 | 37 | # Load label 38 | with open('metadata/class_labels_indices.csv', 'r') as f: 39 | reader = csv.reader(f, delimiter=',') 40 | lines = list(reader) 41 | 42 | labels = [] 43 | ids = [] # Each label has a unique id such as "/m/068hy" 44 | for i1 in range(1, len(lines)): 45 | id = lines[i1][1] 46 | label = lines[i1][2] 47 | ids.append(id) 48 | labels.append(label) 49 | 50 | classes_num = len(labels) 51 | 52 | 53 | import numpy as np 54 | 55 | 56 | def exp_warmup_linear_down(warmup, rampdown_length, start_rampdown, last_value): 57 | rampup = exp_rampup(warmup) 58 | rampdown = linear_rampdown(rampdown_length, start_rampdown, last_value) 59 | def wrapper(epoch): 60 | return rampup(epoch) * rampdown(epoch) 61 | return wrapper 62 | 63 | 64 | def exp_rampup(rampup_length): 65 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 66 | def wrapper(epoch): 67 | if epoch < rampup_length: 68 | epoch = np.clip(epoch, 0.5, rampup_length) 69 | phase = 1.0 - epoch / rampup_length 70 | return float(np.exp(-5.0 * phase * phase)) 71 | else: 72 | return 1.0 73 | return wrapper 74 | 75 | 76 | def linear_rampdown(rampdown_length, start=0, last_value=0): 77 | def wrapper(epoch): 78 | if epoch <= start: 79 | return 1. 80 | elif epoch - start < rampdown_length: 81 | return last_value + (1. 
- last_value) * (rampdown_length - epoch + start) / rampdown_length 82 | else: 83 | return last_value 84 | return wrapper 85 | 86 | 87 | import torch 88 | 89 | 90 | def mixup(size, alpha): 91 | rn_indices = torch.randperm(size) 92 | lambd = np.random.beta(alpha, alpha, size).astype(np.float32) 93 | lambd = np.concatenate([lambd[:, None], 1 - lambd[:, None]], 1).max(1) 94 | lam = torch.FloatTensor(lambd) 95 | return rn_indices, lam 96 | 97 | 98 | from torch.distributions.beta import Beta 99 | 100 | 101 | def mixstyle(x, p=0.4, alpha=0.4, eps=1e-6, mix_labels=False): 102 | if np.random.rand() > p: 103 | return x 104 | batch_size = x.size(0) 105 | 106 | # changed from dim=[2,3] to dim=[1,3] - from channel-wise statistics to frequency-wise statistics 107 | f_mu = x.mean(dim=[1, 3], keepdim=True) 108 | f_var = x.var(dim=[1, 3], keepdim=True) 109 | 110 | f_sig = (f_var + eps).sqrt() # compute instance standard deviation 111 | f_mu, f_sig = f_mu.detach(), f_sig.detach() # block gradients 112 | x_normed = (x - f_mu) / f_sig # normalize input 113 | lmda = Beta(alpha, alpha).sample((batch_size, 1, 1, 1)).to(x.device) # sample instance-wise convex weights 114 | perm = torch.randperm(batch_size).to(x.device) # generate shuffling indices 115 | f_mu_perm, f_sig_perm = f_mu[perm], f_sig[perm] # shuffling 116 | mu_mix = f_mu * lmda + f_mu_perm * (1 - lmda) # generate mixed mean 117 | sig_mix = f_sig * lmda + f_sig_perm * (1 - lmda) # generate mixed standard deviation 118 | x = x_normed * sig_mix + mu_mix # denormalize input using the mixed statistics 119 | if mix_labels: 120 | return x, perm, lmda 121 | return x 122 | -------------------------------------------------------------------------------- /datasets/esc50.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import Dataset as TorchDataset 3 | import torch 4 | import numpy as np 5 | import pandas as pd 6 | import librosa 7 | 8 | from datasets.helpers.audiodatasets import PreprocessDataset, get_roll_func 9 | 10 | # specify ESC50 location in 'dataset_dir' 11 | # the following data has to be located there: 12 | # - meta/esc50.csv (metadata file) 13 | # - audio_32k/ (folder containing the ESC50 audio files) 14 | # 15 | # see these instructions for downloading and preparing ESC50: 16 | # https://github.com/kkoutini/PaSST/tree/main/esc50 17 | 18 | dataset_dir = None 19 | 20 | assert dataset_dir is not None, "Specify ESC50 dataset location in variable 'dataset_dir'. " \ 21 | "Check out the Readme file for further instructions. 
" \ 22 | "https://github.com/fschmid56/EfficientAT/blob/main/README.md" 23 | 24 | dataset_config = { 25 | 'meta_csv': os.path.join(dataset_dir, "meta/esc50.csv"), 26 | 'audio_path': os.path.join(dataset_dir, "audio_32k/"), 27 | 'num_of_classes': 50 28 | } 29 | 30 | 31 | def pad_or_truncate(x, audio_length): 32 | """Pad all audio to specific length.""" 33 | if len(x) <= audio_length: 34 | return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0) 35 | else: 36 | return x[0: audio_length] 37 | 38 | 39 | def pydub_augment(waveform, gain_augment=0): 40 | if gain_augment: 41 | gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment 42 | amp = 10 ** (gain / 20) 43 | waveform = waveform * amp 44 | return waveform 45 | 46 | 47 | class MixupDataset(TorchDataset): 48 | """ Mixing Up wave forms 49 | """ 50 | 51 | def __init__(self, dataset, beta=2, rate=0.5): 52 | self.beta = beta 53 | self.rate = rate 54 | self.dataset = dataset 55 | print(f"Mixing up waveforms from dataset of len {len(dataset)}") 56 | 57 | def __getitem__(self, index): 58 | if torch.rand(1) < self.rate: 59 | x1, f1, y1 = self.dataset[index] 60 | idx2 = torch.randint(len(self.dataset), (1,)).item() 61 | x2, f2, y2 = self.dataset[idx2] 62 | l = np.random.beta(self.beta, self.beta) 63 | l = max(l, 1. - l) 64 | x1 = x1 - x1.mean() 65 | x2 = x2 - x2.mean() 66 | x = (x1 * l + x2 * (1. - l)) 67 | x = x - x.mean() 68 | return x, f1, (y1 * l + y2 * (1. - l)) 69 | return self.dataset[index] 70 | 71 | def __len__(self): 72 | return len(self.dataset) 73 | 74 | 75 | class AudioSetDataset(TorchDataset): 76 | def __init__(self, meta_csv, audiopath, fold, train=False, resample_rate=32000, classes_num=50, 77 | clip_length=5, gain_augment=0): 78 | """ 79 | Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav 80 | """ 81 | self.resample_rate = resample_rate 82 | self.meta_csv = meta_csv 83 | self.df = pd.read_csv(meta_csv) 84 | if train: # training all except this 85 | print(f"Dataset training fold {fold} selection out of {len(self.df)}") 86 | self.df = self.df[self.df.fold != fold] 87 | print(f" for training remains {len(self.df)}") 88 | else: 89 | print(f"Dataset testing fold {fold} selection out of {len(self.df)}") 90 | self.df = self.df[self.df.fold == fold] 91 | print(f" for testing remains {len(self.df)}") 92 | 93 | self.clip_length = clip_length * resample_rate 94 | self.classes_num = classes_num 95 | self.gain_augment = gain_augment 96 | self.audiopath = audiopath 97 | 98 | def __len__(self): 99 | return len(self.df) 100 | 101 | def __getitem__(self, index): 102 | """Load waveform and target of an audio clip. 
103 | Args: 104 | meta: { 105 | 'hdf5_path': str, 106 | 'index_in_hdf5': int} 107 | Returns: 108 | data_dict: { 109 | 'audio_name': str, 110 | 'waveform': (clip_samples,), 111 | 'target': (classes_num,)} 112 | """ 113 | row = self.df.iloc[index] 114 | 115 | waveform, _ = librosa.load(self.audiopath + row.filename, sr=self.resample_rate, mono=True) 116 | if self.gain_augment: 117 | waveform = pydub_augment(waveform, self.gain_augment) 118 | waveform = pad_or_truncate(waveform, self.clip_length) 119 | target = np.zeros(self.classes_num) 120 | target[row.target] = 1 121 | return waveform.reshape(1, -1), row.filename, target 122 | 123 | 124 | def get_base_training_set(resample_rate=32000, gain_augment=0, fold=1): 125 | meta_csv = dataset_config['meta_csv'] 126 | audiopath = dataset_config['audio_path'] 127 | ds = AudioSetDataset(meta_csv, audiopath, fold, train=True, 128 | resample_rate=resample_rate, gain_augment=gain_augment) 129 | return ds 130 | 131 | 132 | def get_base_test_set(resample_rate=32000, fold=1): 133 | meta_csv = dataset_config['meta_csv'] 134 | audiopath = dataset_config['audio_path'] 135 | ds = AudioSetDataset(meta_csv, audiopath, fold, train=False, resample_rate=resample_rate) 136 | return ds 137 | 138 | 139 | def get_training_set(resample_rate=32000, roll=False, wavmix=False, gain_augment=0, fold=1): 140 | ds = get_base_training_set(resample_rate=resample_rate, gain_augment=gain_augment, fold=fold) 141 | if roll: 142 | ds = PreprocessDataset(ds, get_roll_func()) 143 | if wavmix: 144 | ds = MixupDataset(ds) 145 | return ds 146 | 147 | 148 | def get_test_set(resample_rate=32000, fold=1): 149 | ds = get_base_test_set(resample_rate, fold=fold) 150 | return ds 151 | -------------------------------------------------------------------------------- /windowed_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import librosa 4 | import numpy as np 5 | from torch import autocast 6 | from contextlib import nullcontext 7 | 8 | from models.mn.model import get_model as get_mobilenet, get_ensemble_model 9 | from models.preprocess import AugmentMelSTFT 10 | from helpers.utils import NAME_TO_WIDTH, labels 11 | 12 | class EATagger: 13 | """ 14 | EATagger: A class for tagging audio files with acoustic event tags. 15 | 16 | Parameters: 17 | 18 | model_name (str, optional): name of the pre-trained model to use. 19 | ensemble (str, optional): name of the ensemble of models to use. 20 | device (str, optional): device to run the model on, either 'cuda' or 'cpu'. 21 | sample_rate (int, optional): sample rate of the audio. 22 | window_size (int, optional): window size for audio analysis in samples. 23 | hop_size (int, optional): hop size for audio analysis in samples. 24 | n_mels (int, optional): number of mel bands to use for audio analysis. 25 | 26 | Methods: 27 | 28 | tag_audio_window(audio_path, window_size=20.0, hop_length=10.0): tags an audio file with an acoustic event. 
29 | audio_path (str): path to the audio file 30 | window_size (float, optional): size of the window in seconds 31 | hop_length (float, optional): hop length in seconds 32 | 33 | Returns: list of dictionaries with the following keys: 34 | 'start': start time of the window in seconds 35 | 'end': end time of the window in seconds 36 | 'tags': list of tags for the window in dictionary format 37 | 'tag': name of the tag 38 | 'probability': confidence of the tag 39 | """ 40 | def __init__(self, 41 | model_name=None, 42 | ensemble=None, 43 | device='cuda', 44 | sample_rate=32000, 45 | window_size=800, 46 | hop_size=320, 47 | n_mels=128): 48 | 49 | self.device = torch.device('cuda') if device == 'cuda' and torch.cuda.is_available() else torch.device('cpu') 50 | self.sample_rate = sample_rate 51 | self.window_size = window_size 52 | self.hop_size = hop_size 53 | self.n_mels = n_mels 54 | 55 | # load pre-trained model 56 | if ensemble is not None: 57 | self.model = get_ensemble_model(ensemble) 58 | elif model_name is not None: 59 | self.model = get_mobilenet(width_mult=NAME_TO_WIDTH(model_name), pretrained_name=model_name) 60 | else: 61 | raise ValueError('Please provide a model name or an ensemble of models') 62 | 63 | self.model.to(self.device) 64 | self.model.eval() 65 | 66 | # model to preprocess waveform into mel spectrograms 67 | self.mel = AugmentMelSTFT(n_mels=self.n_mels, sr=self.sample_rate, win_length=self.window_size, hopsize=self.hop_size) 68 | self.mel.to(self.device) 69 | self.mel.eval() 70 | 71 | def tag_audio_window(self, audio_path, window_size=20.0, hop_length=10.0): 72 | """ 73 | Tags an audio file with an acoustic event. 74 | Args: 75 | audio_path (str): path to the audio file 76 | window_size (float): size of the window in seconds 77 | hop_length (float): hop length in seconds 78 | Returns: 79 | List of dictionaries with the following keys: 80 | - 'start': start time of the window in seconds 81 | - 'end': end time of the window in seconds 82 | - 'tags': list of tags for the window in dictionary format 83 | - 'tag': name of the tag 84 | - 'probability': confidence of the tag 85 | 86 | """ 87 | 88 | # load audio file 89 | (waveform, _) = librosa.core.load(audio_path, sr=self.sample_rate, mono=True) 90 | waveform = torch.from_numpy(waveform[None, :]).to(self.device) 91 | 92 | # analyze the audio file in windows, pad the last window if needed 93 | window_size = int(window_size * self.sample_rate) 94 | hop_length = int(hop_length * self.sample_rate) 95 | n_windows = int(np.ceil((waveform.shape[1] - window_size) / hop_length)) + 1 96 | waveform = torch.nn.functional.pad(waveform, (0, n_windows * hop_length + window_size - waveform.shape[1])) 97 | 98 | 99 | with torch.no_grad(), autocast(device_type=self.device.type) if self.device.type == 'cuda' else nullcontext(): 100 | tags = [] 101 | for i in range(n_windows): 102 | start = i * hop_length 103 | end = start + window_size 104 | spec = self.mel(waveform[:, start:end]) 105 | preds, features = self.model(spec.unsqueeze(0)) 106 | preds = torch.sigmoid(preds.float()).squeeze().cpu().numpy() 107 | sorted_indexes = np.argsort(preds)[::-1] 108 | 109 | # Print audio tagging top probabilities 110 | tags.append({ 111 | 'start': start / self.sample_rate, 112 | 'end': end / self.sample_rate, 113 | 'tags': [{ 114 | 'tag': labels[sorted_indexes[k]], 115 | 'probability': preds[sorted_indexes[k]] 116 | } for k in range(10)] 117 | }) 118 | 119 | # progress bar 120 | print(f'\rProgress: {i+1}/{n_windows}', end='') 121 | print() 122 | 123 | 124 | 
return tags 125 | 126 | 127 | 128 | if __name__ == '__main__': 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument('--model', type=str, default='mn10_as', help='model name') 131 | parser.add_argument('--cuda', action='store_true', default=False) 132 | parser.add_argument('--audio_path', type=str, help='path to the audio file', required=True) 133 | parser.add_argument('--window_size', type=float, default=10.0, help='window size in seconds') 134 | parser.add_argument('--hop_length', type=float, default=2.5, help='hop length in seconds') 135 | args = parser.parse_args() 136 | 137 | # load the model 138 | model = EATagger(model_name=args.model, device='cuda' if args.cuda else 'cpu') 139 | 140 | # tag the audio file 141 | tags = model.tag_audio_window(args.audio_path, window_size=args.window_size, hop_length=args.hop_length) 142 | 143 | # for each window, print the top 5 tags and their probabilities 144 | for window in tags: 145 | print(f'Window: {window["start"]:.2f} - {window["end"]:.2f}') 146 | for tag in window['tags'][:5]: 147 | print(f'\t{tag["tag"]}: {tag["probability"]:.2f}') 148 | print() 149 | -------------------------------------------------------------------------------- /models/mn/block_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Callable, List 2 | import torch 3 | import torch.nn as nn 4 | from torch import Tensor 5 | from torchvision.ops.misc import ConvNormActivation 6 | 7 | from models.mn.utils import make_divisible, cnn_out_size 8 | 9 | 10 | class ConcurrentSEBlock(torch.nn.Module): 11 | def __init__( 12 | self, 13 | c_dim: int, 14 | f_dim: int, 15 | t_dim: int, 16 | se_cnf: Dict 17 | ) -> None: 18 | super().__init__() 19 | dims = [c_dim, f_dim, t_dim] 20 | self.conc_se_layers = nn.ModuleList() 21 | for d in se_cnf['se_dims']: 22 | input_dim = dims[d-1] 23 | squeeze_dim = make_divisible(input_dim // se_cnf['se_r'], 8) 24 | self.conc_se_layers.append(SqueezeExcitation(input_dim, squeeze_dim, d)) 25 | if se_cnf['se_agg'] == "max": 26 | self.agg_op = lambda x: torch.max(x, dim=0)[0] 27 | elif se_cnf['se_agg'] == "avg": 28 | self.agg_op = lambda x: torch.mean(x, dim=0) 29 | elif se_cnf['se_agg'] == "add": 30 | self.agg_op = lambda x: torch.sum(x, dim=0) 31 | elif se_cnf['se_agg'] == "min": 32 | self.agg_op = lambda x: torch.min(x, dim=0)[0] 33 | else: 34 | raise NotImplementedError(f"SE aggregation operation '{self.agg_op}' not implemented") 35 | 36 | def forward(self, input: Tensor) -> Tensor: 37 | # apply all concurrent se layers 38 | se_outs = [] 39 | for se_layer in self.conc_se_layers: 40 | se_outs.append(se_layer(input)) 41 | out = self.agg_op(torch.stack(se_outs, dim=0)) 42 | return out 43 | 44 | 45 | class SqueezeExcitation(torch.nn.Module): 46 | """ 47 | This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507. 
48 | Args: 49 | input_dim (int): Input dimension 50 | squeeze_dim (int): Size of Bottleneck 51 | activation (Callable): activation applied to bottleneck 52 | scale_activation (Callable): activation applied to the output 53 | """ 54 | 55 | def __init__( 56 | self, 57 | input_dim: int, 58 | squeeze_dim: int, 59 | se_dim: int, 60 | activation: Callable[..., torch.nn.Module] = torch.nn.ReLU, 61 | scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid, 62 | ) -> None: 63 | super().__init__() 64 | self.fc1 = torch.nn.Linear(input_dim, squeeze_dim) 65 | self.fc2 = torch.nn.Linear(squeeze_dim, input_dim) 66 | assert se_dim in [1, 2, 3] 67 | self.se_dim = [1, 2, 3] 68 | self.se_dim.remove(se_dim) 69 | self.activation = activation() 70 | self.scale_activation = scale_activation() 71 | 72 | def _scale(self, input: Tensor) -> Tensor: 73 | scale = torch.mean(input, self.se_dim, keepdim=True) 74 | shape = scale.size() 75 | scale = self.fc1(scale.squeeze(2).squeeze(2)) 76 | scale = self.activation(scale) 77 | scale = self.fc2(scale) 78 | scale = scale 79 | return self.scale_activation(scale).view(shape) 80 | 81 | def forward(self, input: Tensor) -> Tensor: 82 | scale = self._scale(input) 83 | return scale * input 84 | 85 | 86 | class InvertedResidualConfig: 87 | # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper 88 | def __init__( 89 | self, 90 | input_channels: int, 91 | kernel: int, 92 | expanded_channels: int, 93 | out_channels: int, 94 | use_se: bool, 95 | activation: str, 96 | stride: int, 97 | dilation: int, 98 | width_mult: float, 99 | ): 100 | self.input_channels = self.adjust_channels(input_channels, width_mult) 101 | self.kernel = kernel 102 | self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) 103 | self.out_channels = self.adjust_channels(out_channels, width_mult) 104 | self.use_se = use_se 105 | self.use_hs = activation == "HS" 106 | self.stride = stride 107 | self.dilation = dilation 108 | self.f_dim = None 109 | self.t_dim = None 110 | 111 | @staticmethod 112 | def adjust_channels(channels: int, width_mult: float): 113 | return make_divisible(channels * width_mult, 8) 114 | 115 | def out_size(self, in_size): 116 | padding = (self.kernel - 1) // 2 * self.dilation 117 | return cnn_out_size(in_size, padding, self.dilation, self.kernel, self.stride) 118 | 119 | 120 | class InvertedResidual(nn.Module): 121 | def __init__( 122 | self, 123 | cnf: InvertedResidualConfig, 124 | se_cnf: Dict, 125 | norm_layer: Callable[..., nn.Module], 126 | depthwise_norm_layer: Callable[..., nn.Module] 127 | ): 128 | super().__init__() 129 | if not (1 <= cnf.stride <= 2): 130 | raise ValueError("illegal stride value") 131 | 132 | self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels 133 | 134 | layers: List[nn.Module] = [] 135 | activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU 136 | 137 | # expand 138 | if cnf.expanded_channels != cnf.input_channels: 139 | layers.append( 140 | ConvNormActivation( 141 | cnf.input_channels, 142 | cnf.expanded_channels, 143 | kernel_size=1, 144 | norm_layer=norm_layer, 145 | activation_layer=activation_layer, 146 | ) 147 | ) 148 | 149 | # depthwise 150 | stride = 1 if cnf.dilation > 1 else cnf.stride 151 | layers.append( 152 | ConvNormActivation( 153 | cnf.expanded_channels, 154 | cnf.expanded_channels, 155 | kernel_size=cnf.kernel, 156 | stride=stride, 157 | dilation=cnf.dilation, 158 | groups=cnf.expanded_channels, 159 | norm_layer=depthwise_norm_layer, 160 | 
activation_layer=activation_layer, 161 | ) 162 | ) 163 | if cnf.use_se and se_cnf['se_dims'] is not None: 164 | layers.append(ConcurrentSEBlock(cnf.expanded_channels, cnf.f_dim, cnf.t_dim, se_cnf)) 165 | 166 | # project 167 | layers.append( 168 | ConvNormActivation( 169 | cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None 170 | ) 171 | ) 172 | 173 | self.block = nn.Sequential(*layers) 174 | self.out_channels = cnf.out_channels 175 | self._is_cn = cnf.stride > 1 176 | 177 | def forward(self, inp: Tensor) -> Tensor: 178 | result = self.block(inp) 179 | if self.use_res_connect: 180 | result += inp 181 | return result 182 | -------------------------------------------------------------------------------- /datasets/dcase20.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | from sklearn import preprocessing 4 | from torch.utils.data import Dataset as TorchDataset 5 | import torch 6 | import numpy as np 7 | import librosa 8 | 9 | from datasets.helpers.audiodatasets import PreprocessDataset, get_roll_func, get_gain_augment_func 10 | 11 | dataset_dir = None 12 | assert dataset_dir is not None, "Specify 'TAU Urban Acoustic Scenes 2020 Mobile dataset' location in variable " \ 13 | "'dataset_dir'. Check out the Readme file for further instructions. " \ 14 | "https://github.com/fschmid56/EfficientAT/blob/main/README.md" 15 | 16 | dataset_config = { 17 | "dataset_name": "tau_urban_acoustic_scene20", 18 | "meta_csv": os.path.join(dataset_dir, "meta.csv"), 19 | "train_files_csv": os.path.join(dataset_dir, "evaluation_setup", "fold1_train.csv"), 20 | "test_files_csv": os.path.join(dataset_dir, "evaluation_setup", "fold1_evaluate.csv") 21 | } 22 | 23 | 24 | class BasicDCASE20Dataset(TorchDataset): 25 | """ 26 | Basic DCASE20 Dataset 27 | """ 28 | 29 | def __init__(self, meta_csv, sr=32000, cache_path=None): 30 | """ 31 | @param meta_csv: meta csv file for the dataset 32 | @param sr: specify sampling rate 33 | @param sr: specify cache path to store resampled waveforms 34 | return: waveform, name of the file, label, device and cities 35 | """ 36 | df = pd.read_csv(meta_csv, sep="\t") 37 | le = preprocessing.LabelEncoder() 38 | self.labels = torch.from_numpy(le.fit_transform(df[['scene_label']].values.reshape(-1))) 39 | self.devices = le.fit_transform(df[['source_label']].values.reshape(-1)) 40 | self.cities = le.fit_transform(df['identifier'].apply(lambda loc: loc.split("-")[0]).values.reshape(-1)) 41 | self.files = df[['filename']].values.reshape(-1) 42 | self.sr = sr 43 | if cache_path is not None: 44 | self.cache_path = os.path.join(cache_path, dataset_config["dataset_name"] + f"_r{self.sr}", "files_cache") 45 | os.makedirs(self.cache_path, exist_ok=True) 46 | else: 47 | self.cache_path = None 48 | 49 | def __getitem__(self, index): 50 | if self.cache_path: 51 | cpath = os.path.join(self.cache_path, str(index) + ".pt") 52 | try: 53 | sig = torch.load(cpath) 54 | except FileNotFoundError: 55 | sig, _ = librosa.load(os.path.join(dataset_dir, self.files[index]), sr=self.sr, mono=True) 56 | sig = torch.from_numpy(sig[np.newaxis]) 57 | torch.save(sig, cpath) 58 | else: 59 | sig, _ = librosa.load(os.path.join(dataset_dir, self.files[index]), sr=self.sr, mono=True) 60 | sig = torch.from_numpy(sig[np.newaxis]) 61 | return sig, self.files[index], self.labels[index], self.devices[index], self.cities[index] 62 | 63 | def __len__(self): 64 | return len(self.files) 65 | 66 | 67 | class 
SimpleSelectionDataset(TorchDataset): 68 | """A dataset that selects a subsample from a dataset based on a set of sample ids. 69 | Supporting integer indexing in range from 0 to len(self) exclusive. 70 | """ 71 | 72 | def __init__(self, dataset, available_indices): 73 | """ 74 | @param dataset: dataset to load data from 75 | @param available_indices: available indices of samples for 'training', 'testing' 76 | return: x, label, device, city, index 77 | """ 78 | self.available_indices = available_indices 79 | self.dataset = dataset 80 | 81 | def __getitem__(self, index): 82 | x, file, label, device, city = self.dataset[self.available_indices[index]] 83 | return x, file, label, device, city, self.available_indices[index] 84 | 85 | def __len__(self): 86 | return len(self.available_indices) 87 | 88 | 89 | class MixupDataset(TorchDataset): 90 | """ Mixing Up wave forms 91 | """ 92 | 93 | def __init__(self, dataset, beta=2, rate=0.5, num_classes=10): 94 | self.beta = beta 95 | self.rate = rate 96 | self.dataset = dataset 97 | self.num_classes = num_classes 98 | print(f"Mixing up waveforms from dataset of len {len(dataset)}") 99 | 100 | def __getitem__(self, index): 101 | x1, f1, y1, d1, c1, i1 = self.dataset[index] 102 | y = np.zeros(self.num_classes, dtype="float32") 103 | y[y1] = 1. 104 | y1 = y 105 | if torch.rand(1) < self.rate: 106 | idx2 = torch.randint(len(self.dataset), (1,)).item() 107 | x2, _, y2, _, _, _ = self.dataset[idx2] 108 | y = np.zeros(self.num_classes, dtype="float32") 109 | y[y2] = 1. 110 | y2 = y 111 | l = np.random.beta(self.beta, self.beta) 112 | l = max(l, 1. - l) 113 | x1 = x1 - x1.mean() 114 | x2 = x2 - x2.mean() 115 | x = (x1 * l + x2 * (1. - l)) 116 | x = x - x.mean() 117 | return x, f1, (y1 * l + y2 * (1. - l)), d1, c1, i1 118 | return x1, f1, y1, d1, c1, i1 119 | 120 | def __len__(self): 121 | return len(self.dataset) 122 | 123 | 124 | # commands to create the datasets for training and testing 125 | def get_training_set(cache_path=None, resample_rate=32000, roll=False, gain_augment=False, wavmix=False): 126 | ds = get_base_training_set(dataset_config['meta_csv'], dataset_config['train_files_csv'], cache_path, 127 | resample_rate) 128 | if roll: 129 | ds = PreprocessDataset(ds, get_roll_func()) 130 | 131 | if gain_augment: 132 | ds = PreprocessDataset(ds, get_gain_augment_func(gain_augment)) 133 | 134 | if wavmix: 135 | ds = MixupDataset(ds) 136 | 137 | return ds 138 | 139 | 140 | def get_base_training_set(meta_csv, train_files_csv, cache_path, resample_rate): 141 | train_files = pd.read_csv(train_files_csv, sep='\t')['filename'].values.reshape(-1) 142 | meta = pd.read_csv(meta_csv, sep="\t") 143 | train_indices = list(meta[meta['filename'].isin(train_files)].index) 144 | ds = SimpleSelectionDataset(BasicDCASE20Dataset(meta_csv, sr=resample_rate, cache_path=cache_path), train_indices) 145 | return ds 146 | 147 | 148 | def get_test_set(cache_path=None, resample_rate=32000): 149 | ds = get_base_test_set(dataset_config['meta_csv'], dataset_config['test_files_csv'], cache_path, 150 | resample_rate) 151 | return ds 152 | 153 | 154 | def get_base_test_set(meta_csv, test_files_csv, cache_path, resample_rate): 155 | test_files = pd.read_csv(test_files_csv, sep='\t')['filename'].values.reshape(-1) 156 | meta = pd.read_csv(meta_csv, sep="\t") 157 | test_indices = list(meta[meta['filename'].isin(test_files)].index) 158 | ds = SimpleSelectionDataset(BasicDCASE20Dataset(meta_csv, sr=resample_rate, cache_path=cache_path), test_indices) 159 | return ds 160 | 
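The following is a minimal usage sketch (not part of the repository) of how the helpers above could be wired into PyTorch DataLoaders; batch size and worker count are arbitrary example values, and the actual training setup lives in the experiment scripts (e.g. ex_dcase20.py):

# sketch only: assumes 'dataset_dir' in datasets/dcase20.py points to the TAU 2020 Mobile data
from torch.utils.data import DataLoader

from datasets.dcase20 import get_training_set, get_test_set
from helpers.init import worker_init_fn

train_ds = get_training_set(resample_rate=32000, roll=True, wavmix=True)
test_ds = get_test_set(resample_rate=32000)

# each item is (waveform, filename, label, device, city, index); batch size / workers are illustrative
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=8, worker_init_fn=worker_init_fn)
test_dl = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=8, worker_init_fn=worker_init_fn)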
-------------------------------------------------------------------------------- /helpers/peak_memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # analytical memory profiling as in: 6 | # https://proceedings.neurips.cc/paper/2021/file/1371bccec2447b5aa6d96d2a540fb401-Paper.pdf 7 | # "memory required for a layer is the sum of input and output activation 8 | # (since weights can be partially fetched from Flash" 9 | # calculated using optimization described in https://arxiv.org/pdf/1801.04381.pdf (memory efficient inference) 10 | 11 | def peak_memory_mnv3(model, spec_size, bits_per_elem=16): 12 | global_in_elements = [] 13 | def in_conv_hook(self, input, output): 14 | global_in_elements.append(input[0].nelement()) 15 | 16 | inv_residual_elems = [] 17 | def first_inv_residual_block_hook(self, input, output, slice=8): 18 | mem = global_in_elements[-1] + output[0].nelement() 19 | # we need to only partially materialize internal block representation, we assume 8 parallel path per default 20 | block_in_t = input[0].size(3) 21 | block_in_f = input[0].size(2) 22 | ch = input[0].size(1) 23 | mem += block_in_t * block_in_f * ch / slice # repr. before depth-wise 24 | mem += block_in_t * block_in_f * ch / slice # repr. after depth-wise 25 | inv_residual_elems.append(mem) 26 | 27 | res_elements = [] 28 | def res_hook(self, input, output): 29 | res_elements.append(output[0].nelement()) 30 | 31 | def inv_residual_hook(self, input, output): 32 | mem = input[0].nelement() + output[0].nelement() 33 | # add possible memory for residual connection 34 | mem += res_elements[-1] 35 | inv_residual_elems.append(mem) 36 | 37 | def inv_no_residual_hook(self, input, output, slice=8): 38 | mem = input[0].nelement() + output[0].nelement() 39 | # we need to only partially materialize internal block representation, we assume 8 parallel path per default 40 | block_in_t = input[0].size(3) 41 | block_in_f = input[0].size(2) 42 | stride = self.block[1][0].stride[0] 43 | mem += block_in_t * block_in_f * self.block[0].out_channels / slice # repr. before depth-wise 44 | next_in_f = block_in_f // stride 45 | next_in_t = block_in_t // stride 46 | mem += next_in_t * next_in_f * self.block[0].out_channels / slice # repr. after depth-wise 47 | inv_residual_elems.append(mem) 48 | 49 | def foo(net): 50 | children = [] 51 | if hasattr(net, "features"): 52 | # first call to foo with full network 53 | # treat first ConvNormActivation and InvertedResidual - can be calculated memory efficient 54 | net.features[0].register_forward_hook(in_conv_hook) 55 | net.features[1].register_forward_hook(first_inv_residual_block_hook) 56 | children = list(net.features.children())[2:] 57 | elif net.__class__.__name__ == 'InvertedResidual': 58 | # account for residual connection if Squeeze-and-Excitation block 59 | net.block.register_forward_hook(res_hook) 60 | 61 | if len(net.block) > 3: 62 | # contains Squeeze-and-Excitation Layer -> cannot use memory efficient inference 63 | # -> must fully materialize all convs in block 64 | # -> last conv layer has max sum of input and output activation sizes 65 | net.block[3].register_forward_hook(inv_residual_hook) 66 | elif len(net.block) == 3: 67 | # block with no Squeeze-and-Excitation 68 | # can use memory efficient inference, no need to fully materialize expanded channel representation 69 | net.register_forward_hook(inv_no_residual_hook) 70 | else: 71 | raise ValueError("Can treat only MobileNetV3 blocks. 
Block 1 consists of 2 modules and following" 72 | "blocks of 3 or 4 modules. Block 1 must be treated differently.") 73 | else: 74 | children = list(net.children()) 75 | 76 | for c in children: 77 | foo(c) 78 | 79 | # Register hook 80 | foo(model) 81 | 82 | device = next(model.parameters()).device 83 | input = torch.rand(spec_size).to(device) 84 | with torch.no_grad(): 85 | model(input) 86 | 87 | block_mems = [elem * bits_per_elem / (8 * 1000) for elem in inv_residual_elems] 88 | peak_mem = max(block_mems) 89 | 90 | print("*************Memory Complexity (kB) **************") 91 | for i, block_mem in enumerate(block_mems): 92 | print(f"block {i + 1} memory: {block_mem} kB") 93 | print("**************************************************") 94 | print("Analytical peak memory: ", peak_mem, " kB") 95 | print("**************************************************") 96 | return peak_mem 97 | 98 | 99 | def peak_memory_cnn(model, spec_size, bits_per_elem=16): 100 | first_conv_in_block = [True] 101 | res_elems = [] # initialized with one 0 for input conv 102 | 103 | def res_hook(self, input, output): 104 | first_conv_in_block[0] = True 105 | res_elems.append(output[0].nelement()) 106 | 107 | conv_activation_elems = [] 108 | 109 | def conv2d_res_hook(self, input, output): 110 | mem = input[0].nelement() + output[0].nelement() 111 | # maybe have to add size of parallel residual path 112 | if not first_conv_in_block[0]: 113 | mem += res_elems[-1] 114 | else: 115 | first_conv_in_block[0] = False 116 | conv_activation_elems.append(mem) 117 | 118 | def conv2d_hook(self, input, output): 119 | mem = input[0].nelement() + output[0].nelement() 120 | conv_activation_elems.append(mem) 121 | 122 | def foo(net, residual_block=False): 123 | if hasattr(net, "features"): 124 | net.features[0].register_forward_hook(res_hook) 125 | if net.__class__.__name__ == 'InvertedResidual': 126 | net.register_forward_hook(res_hook) 127 | if net.use_res_connect: 128 | residual_block = True 129 | if isinstance(net, nn.Conv2d): 130 | if residual_block: 131 | net.register_forward_hook(conv2d_res_hook) 132 | else: 133 | net.register_forward_hook(conv2d_hook) 134 | 135 | for c in net.children(): 136 | foo(c, residual_block) 137 | 138 | # Register hook 139 | foo(model) 140 | 141 | device = next(model.parameters()).device 142 | input = torch.rand(spec_size).to(device) 143 | with torch.no_grad(): 144 | model(input) 145 | 146 | conv_act_mems = [elem * bits_per_elem / (8 * 1000) for elem in conv_activation_elems] 147 | peak_mem = max(conv_act_mems) 148 | 149 | print("*************Memory Complexity (kB) **************") 150 | for i, conv_mem in enumerate(conv_act_mems): 151 | print(f"conv {i + 1} memory: {conv_mem} kB") 152 | print("**************************************************") 153 | print("Analytical peak memory: ", peak_mem, " kB") 154 | print("**************************************************") 155 | return peak_mem 156 | -------------------------------------------------------------------------------- /helpers/flop_count.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # adapted from PANNs (https://github.com/qiuqiangkong/audioset_tagging_cnn) 6 | 7 | def count_macs(model, spec_size): 8 | list_conv2d = [] 9 | 10 | def conv2d_hook(self, input, output): 11 | batch_size, input_channels, input_height, input_width = input[0].size() 12 | assert batch_size == 1 13 | output_channels, output_height, output_width = output[0].size() 14 | 15 | kernel_ops = 
self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 16 | bias_ops = 1 if self.bias is not None else 0 17 | 18 | params = output_channels * (kernel_ops + bias_ops) 19 | # overall macs count is: 20 | # kernel**2 * in_channels/groups * out_channels * out_width * out_height 21 | macs = batch_size * params * output_height * output_width 22 | 23 | list_conv2d.append(macs) 24 | 25 | list_linear = [] 26 | 27 | def linear_hook(self, input, output): 28 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 29 | assert batch_size == 1 30 | weight_ops = self.weight.nelement() 31 | bias_ops = self.bias.nelement() 32 | 33 | # overall macs count is equal to the number of parameters in layer 34 | macs = batch_size * (weight_ops + bias_ops) 35 | list_linear.append(macs) 36 | 37 | def foo(net): 38 | if net.__class__.__name__ == 'Conv2dStaticSamePadding': 39 | net.register_forward_hook(conv2d_hook) 40 | childrens = list(net.children()) 41 | if not childrens: 42 | if isinstance(net, nn.Conv2d): 43 | net.register_forward_hook(conv2d_hook) 44 | elif isinstance(net, nn.Linear): 45 | net.register_forward_hook(linear_hook) 46 | else: 47 | print('Warning: flop of module {} is not counted!'.format(net)) 48 | return 49 | for c in childrens: 50 | foo(c) 51 | 52 | # Register hook 53 | foo(model) 54 | 55 | device = next(model.parameters()).device 56 | input = torch.rand(spec_size).to(device) 57 | with torch.no_grad(): 58 | model(input) 59 | 60 | total_macs = sum(list_conv2d) + sum(list_linear) 61 | 62 | print("*************Computational Complexity (multiply-adds) **************") 63 | print("Number of Convolutional Layers: ", len(list_conv2d)) 64 | print("Number of Linear Layers: ", len(list_linear)) 65 | print("Relative Share of Convolutional Layers: {:.2f}".format((sum(list_conv2d) / total_macs))) 66 | print("Relative Share of Linear Layers: {:.2f}".format(sum(list_linear) / total_macs)) 67 | print("Total MACs (multiply-accumulate operations in Billions): {:.2f}".format(total_macs/10**9)) 68 | print("********************************************************************") 69 | return total_macs 70 | 71 | 72 | def count_macs_transformer(model, spec_size): 73 | """Count macs. Code modified from others' implementation. 
74 | """ 75 | list_conv2d = [] 76 | 77 | def conv2d_hook(self, input, output): 78 | batch_size, input_channels, input_height, input_width = input[0].size() 79 | assert batch_size == 1 80 | output_channels, output_height, output_width = output[0].size() 81 | 82 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 83 | bias_ops = 1 if self.bias is not None else 0 84 | 85 | params = output_channels * (kernel_ops + bias_ops) 86 | # overall macs count is: 87 | # kernel**2 * in_channels/groups * out_channels * out_width * out_height 88 | macs = batch_size * params * output_height * output_width 89 | 90 | list_conv2d.append(macs) 91 | 92 | list_linear = [] 93 | 94 | def linear_hook(self, input, output): 95 | batch_size = input[0].size(0) if input[0].dim() >= 2 else 1 96 | assert batch_size == 1 97 | if input[0].dim() == 3: 98 | # (batch size, sequence length, embeddings size) 99 | batch_size, seq_len, embed_size = input[0].size() 100 | 101 | weight_ops = self.weight.nelement() 102 | bias_ops = self.bias.nelement() if self.bias is not None else 0 103 | # linear layer applied position-wise, multiply with sequence length 104 | macs = batch_size * (weight_ops + bias_ops) * seq_len 105 | else: 106 | # classification head 107 | # (batch size, embeddings size) 108 | batch_size, embed_size = input[0].size() 109 | weight_ops = self.weight.nelement() 110 | bias_ops = self.bias.nelement() if self.bias is not None else 0 111 | # overall macs count is equal to the number of parameters in layer 112 | macs = batch_size * (weight_ops + bias_ops) 113 | list_linear.append(macs) 114 | 115 | list_att = [] 116 | 117 | def attention_hook(self, input, output): 118 | # here we only calculate the attention macs; linear layers are processed in linear_hook 119 | batch_size, seq_len, embed_size = input[0].size() 120 | 121 | # 2 times embed_size * seq_len**2 122 | # - computing the attention matrix: embed_size * seq_len**2 123 | # - multiply attention matrix with value matrix: embed_size * seq_len**2 124 | macs = batch_size * embed_size * seq_len * seq_len * 2 125 | list_att.append(macs) 126 | 127 | def foo(net): 128 | childrens = list(net.children()) 129 | if net.__class__.__name__ == "MultiHeadAttention": 130 | net.register_forward_hook(attention_hook) 131 | if not childrens: 132 | if isinstance(net, nn.Conv2d): 133 | net.register_forward_hook(conv2d_hook) 134 | elif isinstance(net, nn.Linear): 135 | net.register_forward_hook(linear_hook) 136 | else: 137 | print('Warning: flop of module {} is not counted!'.format(net)) 138 | return 139 | for c in childrens: 140 | foo(c) 141 | 142 | # Register hook 143 | foo(model) 144 | 145 | device = next(model.parameters()).device 146 | input = torch.rand(spec_size).to(device) 147 | 148 | with torch.no_grad(): 149 | model(input) 150 | 151 | total_macs = sum(list_conv2d) + sum(list_linear) + sum(list_att) 152 | 153 | print("*************Computational Complexity (multiply-adds) **************") 154 | print("Number of Convolutional Layers: ", len(list_conv2d)) 155 | print("Number of Linear Layers: ", len(list_linear)) 156 | print("Number of Attention Layers: ", len(list_att)) 157 | print("Relative Share of Convolutional Layers: {:.2f}".format((sum(list_conv2d) / total_macs))) 158 | print("Relative Share of Linear Layers: {:.2f}".format(sum(list_linear) / total_macs)) 159 | print("Relative Share of Attention Layers: {:.2f}".format(sum(list_att) / total_macs)) 160 | print("Total MACs (multiply-accumulate operations in Billions): 
{:.2f}".format(total_macs/10**9)) 161 | print("********************************************************************") 162 | return total_macs 163 | -------------------------------------------------------------------------------- /datasets/openmic.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import av 4 | from torch.utils.data import Dataset as TorchDataset, WeightedRandomSampler 5 | import torch 6 | import numpy as np 7 | import h5py 8 | 9 | from datasets.helpers.audiodatasets import PreprocessDataset, get_roll_func 10 | 11 | # specify OpenMic location in 'dataset_dir' 12 | # 2 files have to be located there: 13 | # - openmic_train.csv_mp3.hdf 14 | # - openmic_test.csv_mp3.hdf 15 | # follow the instructions here to get these 2 files: 16 | # https://github.com/kkoutini/PaSST/tree/main/openmic 17 | 18 | dataset_dir = None 19 | 20 | assert dataset_dir is not None, "Specify OpenMic dataset location in variable 'dataset_dir'. " \ 21 | "Check out the Readme file for further instructions. " \ 22 | "https://github.com/fschmid56/EfficientAT/blob/main/README.md" 23 | 24 | dataset_config = { 25 | 'openmic_train_hdf5': os.path.join(dataset_dir, "openmic_train.csv_mp3.hdf"), 26 | 'openmic_test_hdf5': os.path.join(dataset_dir, "openmic_test.csv_mp3.hdf"), 27 | 'num_of_classes': 20 28 | } 29 | 30 | 31 | def decode_mp3(mp3_arr): 32 | """ 33 | decodes an array if uint8 representing an mp3 file 34 | :rtype: np.array 35 | """ 36 | container = av.open(io.BytesIO(mp3_arr.tobytes())) 37 | stream = next(s for s in container.streams if s.type == 'audio') 38 | a = [] 39 | for i, packet in enumerate(container.demux(stream)): 40 | for frame in packet.decode(): 41 | a.append(frame.to_ndarray().reshape(-1)) 42 | waveform = np.concatenate(a) 43 | if waveform.dtype != 'float32': 44 | raise RuntimeError("Unexpected wave type") 45 | return waveform 46 | 47 | 48 | def pad_or_truncate(x, audio_length): 49 | """Pad all audio to specific length.""" 50 | if len(x) <= audio_length: 51 | return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0) 52 | else: 53 | return x[0: audio_length] 54 | 55 | 56 | def pydub_augment(waveform, gain_augment=0): 57 | if gain_augment: 58 | gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment 59 | amp = 10 ** (gain / 20) 60 | waveform = waveform * amp 61 | return waveform 62 | 63 | 64 | class MixupDataset(TorchDataset): 65 | """ Mixing Up wave forms 66 | """ 67 | 68 | def __init__(self, dataset, beta=2, rate=0.5): 69 | self.beta = beta 70 | self.rate = rate 71 | self.dataset = dataset 72 | print(f"Mixing up waveforms from dataset of len {len(dataset)}") 73 | 74 | def __getitem__(self, index): 75 | x1, f1, y1 = self.dataset[index] 76 | y1 = torch.as_tensor(y1) 77 | if torch.rand(1) < self.rate: 78 | idx2 = torch.randint(len(self.dataset), (1,)).item() 79 | x2, f2, y2 = self.dataset[idx2] 80 | y2 = torch.as_tensor(y2) 81 | l = np.random.beta(self.beta, self.beta) 82 | l = max(l, 1. - l) 83 | x1 = x1 - x1.mean() 84 | x2 = x2 - x2.mean() 85 | x = (x1 * l + x2 * (1. - l)) 86 | x = x - x.mean() 87 | assert len(y1) == 40, "only for openmic this works" 88 | y_mask1 = (torch.as_tensor(y1[20:]) > 0.5).float() 89 | y_mask2 = (torch.as_tensor(y2[20:]) > 0.5).float() 90 | y1[:20] *= y_mask1 91 | y2[:20] *= y_mask2 92 | yres = (y1 * l + y2 * (1. 
- l)) 93 | yres[20:] = torch.stack([y_mask1, y_mask2]).max(dim=0).values 94 | return x, f1, yres 95 | return x1, f1, y1 96 | 97 | def __len__(self): 98 | return len(self.dataset) 99 | 100 | 101 | class AudioSetDataset(TorchDataset): 102 | def __init__(self, hdf5_file, resample_rate=32000, classes_num=20, clip_length=10, 103 | in_mem=False, gain_augment=0): 104 | """ 105 | Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav 106 | """ 107 | self.resample_rate = resample_rate 108 | self.hdf5_file = hdf5_file 109 | if in_mem: 110 | print("\nPreloading in memory\n") 111 | with open(hdf5_file, 'rb') as f: 112 | self.hdf5_file = io.BytesIO(f.read()) 113 | with h5py.File(hdf5_file, 'r') as f: 114 | self.length = len(f['audio_name']) 115 | print(f"Dataset from {hdf5_file} with length {self.length}.") 116 | self.dataset_file = None # lazy init 117 | self.clip_length = clip_length 118 | if clip_length is not None: 119 | self.clip_length = clip_length * resample_rate 120 | self.classes_num = classes_num 121 | self.gain_augment = gain_augment 122 | 123 | def open_hdf5(self): 124 | self.dataset_file = h5py.File(self.hdf5_file, 'r') 125 | 126 | def __len__(self): 127 | return self.length 128 | 129 | def __del__(self): 130 | if self.dataset_file is not None: 131 | self.dataset_file.close() 132 | self.dataset_file = None 133 | 134 | def __getitem__(self, index): 135 | """Load waveform and target of an audio clip. 136 | Args: 137 | meta: { 138 | 'hdf5_path': str, 139 | 'index_in_hdf5': int} 140 | Returns: 141 | data_dict: { 142 | 'audio_name': str, 143 | 'waveform': (clip_samples,), 144 | 'target': (classes_num,)} 145 | """ 146 | if self.dataset_file is None: 147 | self.open_hdf5() 148 | 149 | audio_name = self.dataset_file['audio_name'][index].decode() 150 | waveform = decode_mp3(self.dataset_file['mp3'][index]) 151 | waveform = pydub_augment(waveform, self.gain_augment) 152 | waveform = pad_or_truncate(waveform, self.clip_length) 153 | waveform = self.resample(waveform) 154 | target = self.dataset_file['target'][index] 155 | target = target.astype(np.float32) 156 | return waveform.reshape(1, -1), audio_name, target 157 | 158 | def resample(self, waveform): 159 | """Resample. 
160 | Args: 161 | waveform: (clip_samples,) 162 | Returns: 163 | (resampled_clip_samples,) 164 | """ 165 | if self.resample_rate == 32000: 166 | return waveform 167 | elif self.resample_rate == 16000: 168 | return waveform[0:: 2] 169 | elif self.resample_rate == 8000: 170 | return waveform[0:: 4] 171 | else: 172 | raise Exception('Incorrect sample rate!') 173 | 174 | 175 | def get_base_training_set(resample_rate=32000, gain_augment=0): 176 | balanced_train_hdf5 = dataset_config['openmic_train_hdf5'] 177 | ds = AudioSetDataset(balanced_train_hdf5, resample_rate=resample_rate, gain_augment=gain_augment) 178 | return ds 179 | 180 | 181 | def get_base_test_set(resample_rate=32000): 182 | test_hdf5 = dataset_config['openmic_test_hdf5'] 183 | ds = AudioSetDataset(test_hdf5, resample_rate=resample_rate) 184 | return ds 185 | 186 | 187 | def get_training_set(roll=False, wavmix=False, gain_augment=0, resample_rate=32000): 188 | ds = get_base_training_set(resample_rate=resample_rate, gain_augment=gain_augment) 189 | if roll: 190 | ds = PreprocessDataset(ds, get_roll_func()) 191 | if wavmix: 192 | ds = MixupDataset(ds) 193 | return ds 194 | 195 | 196 | def get_test_set(resample_rate=32000): 197 | ds = get_base_test_set(resample_rate) 198 | return ds 199 | 200 | -------------------------------------------------------------------------------- /datasets/fsd50k.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import av 4 | from torch.utils.data import Dataset as TorchDataset, WeightedRandomSampler 5 | import torch 6 | import numpy as np 7 | import h5py 8 | 9 | from datasets.helpers.audiodatasets import PreprocessDataset, get_roll_func 10 | 11 | # specify FSD50K location in 'dataset_dir' 12 | # 3 files have to be located there: 13 | # - FSD50K.eval_mp3.hdf 14 | # - FSD50K.val_mp3.hdf 15 | # - FSD50K.train_mp3.hdf 16 | # follow the instructions here to get these 3 files: 17 | # https://github.com/kkoutini/PaSST/tree/main/fsd50k 18 | 19 | dataset_dir = None 20 | 21 | assert dataset_dir is not None, "Specify FSD50K dataset location in variable 'dataset_dir'. " \ 22 | "Check out the Readme file for further instructions. " \ 23 | "https://github.com/fschmid56/EfficientAT/blob/main/README.md" 24 | 25 | dataset_config = { 26 | 'balanced_train_hdf5': os.path.join(dataset_dir, "FSD50K.train_mp3.hdf"), 27 | 'valid_hdf5': os.path.join(dataset_dir, "FSD50K.val_mp3.hdf"), 28 | 'eval_hdf5': os.path.join(dataset_dir, "FSD50K.eval_mp3.hdf"), 29 | 'num_of_classes': 200 30 | } 31 | 32 | 33 | def decode_mp3(mp3_arr): 34 | """ 35 | decodes an array if uint8 representing an mp3 file 36 | :rtype: np.array 37 | """ 38 | container = av.open(io.BytesIO(mp3_arr.tobytes())) 39 | stream = next(s for s in container.streams if s.type == 'audio') 40 | a = [] 41 | for i, packet in enumerate(container.demux(stream)): 42 | for frame in packet.decode(): 43 | a.append(frame.to_ndarray().reshape(-1)) 44 | waveform = np.concatenate(a) 45 | if waveform.dtype != 'float32': 46 | raise RuntimeError("Unexpected wave type") 47 | return waveform 48 | 49 | 50 | def pad_or_truncate(x, audio_length): 51 | """Pad all audio to specific length.""" 52 | if audio_length is None: 53 | # audio_length not specified don't do anything. 
54 | return x 55 | if len(x) <= audio_length: 56 | return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0) 57 | else: 58 | offset = torch.randint(0, len(x) - audio_length + 1, (1,)).item() 59 | return x[offset:offset + audio_length] 60 | 61 | 62 | def pydub_augment(waveform, gain_augment=0): 63 | if gain_augment: 64 | gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment 65 | amp = 10 ** (gain / 20) 66 | waveform = waveform * amp 67 | return waveform 68 | 69 | 70 | class MixupDataset(TorchDataset): 71 | """ Mixing Up wave forms 72 | """ 73 | 74 | def __init__(self, dataset, beta=2, rate=0.5): 75 | self.beta = beta 76 | self.rate = rate 77 | self.dataset = dataset 78 | print(f"Mixing up waveforms from dataset of len {len(dataset)}") 79 | 80 | def __getitem__(self, index): 81 | if torch.rand(1) < self.rate: 82 | x1, f1, y1 = self.dataset[index] 83 | idx2 = torch.randint(len(self.dataset), (1,)).item() 84 | x2, f2, y2 = self.dataset[idx2] 85 | l = np.random.beta(self.beta, self.beta) 86 | l = max(l, 1. - l) 87 | x1 = x1 - x1.mean() 88 | x2 = x2 - x2.mean() 89 | x = (x1 * l + x2 * (1. - l)) 90 | x = x - x.mean() 91 | return x, f1, (y1 * l + y2 * (1. - l)) 92 | return self.dataset[index] 93 | 94 | def __len__(self): 95 | return len(self.dataset) 96 | 97 | 98 | class AudioSetDataset(TorchDataset): 99 | def __init__(self, hdf5_file, resample_rate=32000, classes_num=200, clip_length=10, 100 | in_mem=False, gain_augment=0): 101 | """ 102 | Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav 103 | """ 104 | self.resample_rate = resample_rate 105 | self.hdf5_file = hdf5_file 106 | if in_mem: 107 | print("\nPreloading in memory\n") 108 | with open(hdf5_file, 'rb') as f: 109 | self.hdf5_file = io.BytesIO(f.read()) 110 | with h5py.File(hdf5_file, 'r') as f: 111 | self.length = len(f['audio_name']) 112 | print(f"Dataset from {hdf5_file} with length {self.length}.") 113 | self.dataset_file = None # lazy init 114 | self.clip_length = clip_length 115 | if clip_length is not None: 116 | self.clip_length = clip_length * resample_rate 117 | self.classes_num = classes_num 118 | self.gain_augment = gain_augment 119 | 120 | def open_hdf5(self): 121 | self.dataset_file = h5py.File(self.hdf5_file, 'r') 122 | 123 | def __len__(self): 124 | return self.length 125 | 126 | def __del__(self): 127 | if self.dataset_file is not None: 128 | self.dataset_file.close() 129 | self.dataset_file = None 130 | 131 | def __getitem__(self, index): 132 | """Load waveform and target of an audio clip. 133 | Args: 134 | meta: { 135 | 'hdf5_path': str, 136 | 'index_in_hdf5': int} 137 | Returns: 138 | data_dict: { 139 | 'audio_name': str, 140 | 'waveform': (clip_samples,), 141 | 'target': (classes_num,)} 142 | """ 143 | if self.dataset_file is None: 144 | self.open_hdf5() 145 | 146 | audio_name = self.dataset_file['audio_name'][index].decode() 147 | waveform = decode_mp3(self.dataset_file['mp3'][index]) 148 | waveform = pydub_augment(waveform, self.gain_augment) 149 | waveform = pad_or_truncate(waveform, self.clip_length) 150 | waveform = self.resample(waveform) 151 | target = self.dataset_file['target'][index] 152 | target = np.unpackbits(target, axis=-1, 153 | count=self.classes_num).astype(np.float32) 154 | return waveform.reshape(1, -1), audio_name, target 155 | 156 | def resample(self, waveform): 157 | """Resample. 
158 | Args: 159 | waveform: (clip_samples,) 160 | Returns: 161 | (resampled_clip_samples,) 162 | """ 163 | if self.resample_rate == 32000: 164 | return waveform 165 | elif self.resample_rate == 16000: 166 | return waveform[0:: 2] 167 | elif self.resample_rate == 8000: 168 | return waveform[0:: 4] 169 | else: 170 | raise Exception('Incorrect sample rate!') 171 | 172 | 173 | def get_base_training_set(resample_rate=32000, gain_augment=0): 174 | balanced_train_hdf5 = dataset_config['balanced_train_hdf5'] 175 | ds = AudioSetDataset(balanced_train_hdf5, resample_rate=resample_rate, gain_augment=gain_augment) 176 | return ds 177 | 178 | 179 | def get_base_eval_set(resample_rate=32000, variable_eval=None): 180 | eval_hdf5 = dataset_config['eval_hdf5'] 181 | if variable_eval: 182 | print("Variable length eval!!") 183 | ds = AudioSetDataset(eval_hdf5, resample_rate=resample_rate, clip_length=None) 184 | else: 185 | ds = AudioSetDataset(eval_hdf5, resample_rate=resample_rate) 186 | return ds 187 | 188 | 189 | def get_base_valid_set(resample_rate=32000, variable_eval=None): 190 | valid_hdf5 = dataset_config['valid_hdf5'] 191 | if variable_eval: 192 | print("Variable length valid_set !!") 193 | ds = AudioSetDataset(valid_hdf5, resample_rate=resample_rate, clip_length=None) 194 | else: 195 | ds = AudioSetDataset(valid_hdf5, resample_rate=resample_rate) 196 | return ds 197 | 198 | 199 | def get_training_set(roll=False, wavmix=False, gain_augment=0, resample_rate=32000): 200 | ds = get_base_training_set(resample_rate=resample_rate, gain_augment=gain_augment) 201 | if roll: 202 | ds = PreprocessDataset(ds, get_roll_func()) 203 | if wavmix: 204 | ds = MixupDataset(ds) 205 | return ds 206 | 207 | 208 | def get_valid_set(resample_rate=32000, variable_eval=None): 209 | ds = get_base_valid_set(resample_rate, variable_eval) 210 | return ds 211 | 212 | 213 | def get_eval_set(resample_rate=32000, variable_eval=None): 214 | ds = get_base_eval_set(resample_rate, variable_eval) 215 | return ds 216 | -------------------------------------------------------------------------------- /ex_esc50.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import numpy as np 3 | import os 4 | from tqdm import tqdm 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import argparse 8 | from sklearn import metrics 9 | import torch.nn.functional as F 10 | 11 | from datasets.esc50 import get_test_set, get_training_set 12 | from models.mn.model import get_model as get_mobilenet 13 | from models.dymn.model import get_model as get_dymn 14 | from models.preprocess import AugmentMelSTFT 15 | from helpers.init import worker_init_fn 16 | from helpers.utils import NAME_TO_WIDTH, exp_warmup_linear_down, mixup 17 | 18 | 19 | def train(args): 20 | # Train Models for Acoustic Scene Classification 21 | 22 | # logging is done using wandb 23 | wandb.init( 24 | project="ESC50", 25 | notes="Fine-tune Models on ESC50.", 26 | tags=["Environmental Sound Classification", "Fine-Tuning"], 27 | config=args, 28 | name=args.experiment_name 29 | ) 30 | 31 | device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu') 32 | 33 | # model to preprocess waveform into mel spectrograms 34 | mel = AugmentMelSTFT(n_mels=args.n_mels, 35 | sr=args.resample_rate, 36 | win_length=args.window_size, 37 | hopsize=args.hop_size, 38 | n_fft=args.n_fft, 39 | freqm=args.freqm, 40 | timem=args.timem, 41 | fmin=args.fmin, 42 | fmax=args.fmax, 43 | fmin_aug_range=args.fmin_aug_range, 
44 | fmax_aug_range=args.fmax_aug_range 45 | ) 46 | mel.to(device) 47 | 48 | # load prediction model 49 | model_name = args.model_name 50 | pretrained_name = model_name if args.pretrained else None 51 | width = NAME_TO_WIDTH(model_name) if model_name and args.pretrained else args.model_width 52 | if model_name.startswith("dymn"): 53 | model = get_dymn(width_mult=width, pretrained_name=pretrained_name, 54 | pretrain_final_temp=args.pretrain_final_temp, 55 | num_classes=50) 56 | else: 57 | model = get_mobilenet(width_mult=width, pretrained_name=pretrained_name, 58 | head_type=args.head_type, se_dims=args.se_dims, 59 | num_classes=50) 60 | model.to(device) 61 | 62 | # dataloader 63 | dl = DataLoader(dataset=get_training_set(resample_rate=args.resample_rate, 64 | roll=False if args.no_roll else True, 65 | wavmix=False if args.no_wavmix else True, 66 | gain_augment=args.gain_augment, 67 | fold=args.fold), 68 | worker_init_fn=worker_init_fn, 69 | num_workers=args.num_workers, 70 | batch_size=args.batch_size, 71 | shuffle=True) 72 | 73 | # evaluation loader 74 | eval_dl = DataLoader(dataset=get_test_set(resample_rate=args.resample_rate, fold=args.fold), 75 | worker_init_fn=worker_init_fn, 76 | num_workers=args.num_workers, 77 | batch_size=args.batch_size) 78 | 79 | # optimizer & scheduler 80 | lr = args.lr 81 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 82 | # phases of lr schedule: exponential increase, constant lr, linear decrease, fine-tune 83 | schedule_lambda = \ 84 | exp_warmup_linear_down(args.warm_up_len, args.ramp_down_len, args.ramp_down_start, args.last_lr_value) 85 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule_lambda) 86 | 87 | name = None 88 | accuracy, val_loss = float('NaN'), float('NaN') 89 | 90 | for epoch in range(args.n_epochs): 91 | mel.train() 92 | model.train() 93 | train_stats = dict(train_loss=list()) 94 | pbar = tqdm(dl) 95 | pbar.set_description("Epoch {}/{}: accuracy: {:.4f}, val_loss: {:.4f}" 96 | .format(epoch + 1, args.n_epochs, accuracy, val_loss)) 97 | for batch in pbar: 98 | x, f, y = batch 99 | bs = x.size(0) 100 | x, y = x.to(device), y.to(device) 101 | x = _mel_forward(x, mel) 102 | 103 | if args.mixup_alpha: 104 | rn_indices, lam = mixup(bs, args.mixup_alpha) 105 | lam = lam.to(x.device) 106 | x = x * lam.reshape(bs, 1, 1, 1) + \ 107 | x[rn_indices] * (1. - lam.reshape(bs, 1, 1, 1)) 108 | y_hat, _ = model(x) 109 | samples_loss = (F.cross_entropy(y_hat, y, reduction="none") * lam.reshape(bs) + 110 | F.cross_entropy(y_hat, y[rn_indices], reduction="none") * ( 111 | 1. 
- lam.reshape(bs))) 112 | 113 | else: 114 | y_hat, _ = model(x) 115 | samples_loss = F.cross_entropy(y_hat, y, reduction="none") 116 | 117 | # loss 118 | loss = samples_loss.mean() 119 | 120 | # append training statistics 121 | train_stats['train_loss'].append(loss.detach().cpu().numpy()) 122 | 123 | # Update Model 124 | loss.backward() 125 | optimizer.step() 126 | optimizer.zero_grad() 127 | # Update learning rate 128 | scheduler.step() 129 | 130 | # evaluate 131 | accuracy, val_loss = _test(model, mel, eval_dl, device) 132 | 133 | # log train and validation statistics 134 | wandb.log({"train_loss": np.mean(train_stats['train_loss']), 135 | "accuracy": accuracy, 136 | "val_loss": val_loss 137 | }) 138 | 139 | # remove previous model (we try to not flood your hard disk) and save latest model 140 | if name is not None: 141 | os.remove(os.path.join(wandb.run.dir, name)) 142 | name = f"mn{str(width).replace('.', '')}_esc50_epoch_{epoch}_acc_{int(round(accuracy*1000))}.pt" 143 | torch.save(model.state_dict(), os.path.join(wandb.run.dir, name)) 144 | 145 | 146 | def _mel_forward(x, mel): 147 | old_shape = x.size() 148 | x = x.reshape(-1, old_shape[2]) 149 | x = mel(x) 150 | x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2]) 151 | return x 152 | 153 | 154 | def _test(model, mel, eval_loader, device): 155 | model.eval() 156 | mel.eval() 157 | 158 | targets = [] 159 | outputs = [] 160 | losses = [] 161 | pbar = tqdm(eval_loader) 162 | pbar.set_description("Validating") 163 | for batch in pbar: 164 | x, f, y = batch 165 | x = x.to(device) 166 | y = y.to(device) 167 | with torch.no_grad(): 168 | x = _mel_forward(x, mel) 169 | y_hat, _ = model(x) 170 | targets.append(y.cpu().numpy()) 171 | outputs.append(y_hat.float().cpu().numpy()) 172 | losses.append(F.cross_entropy(y_hat, y).cpu().numpy()) 173 | 174 | targets = np.concatenate(targets) 175 | outputs = np.concatenate(outputs) 176 | losses = np.stack(losses) 177 | accuracy = metrics.accuracy_score(targets.argmax(axis=1), outputs.argmax(axis=1)) 178 | return accuracy, losses.mean() 179 | 180 | 181 | if __name__ == '__main__': 182 | parser = argparse.ArgumentParser(description='Example of parser. 
')
183 | 
184 |     # general
185 |     parser.add_argument('--experiment_name', type=str, default="ESC50")
186 |     parser.add_argument('--cuda', action='store_true', default=False)
187 |     parser.add_argument('--batch_size', type=int, default=128)
188 |     parser.add_argument('--num_workers', type=int, default=12)
189 |     parser.add_argument('--fold', type=int, default=1)
190 | 
191 |     # training
192 |     parser.add_argument('--pretrained', action='store_true', default=False)
193 |     parser.add_argument('--model_name', type=str, default="mn10_as")
194 |     parser.add_argument('--pretrain_final_temp', type=float, default=1.0) # for DyMN
195 |     parser.add_argument('--model_width', type=float, default=1.0)
196 |     parser.add_argument('--head_type', type=str, default="mlp")
197 |     parser.add_argument('--se_dims', type=str, default="c")
198 |     parser.add_argument('--n_epochs', type=int, default=80)
199 |     parser.add_argument('--mixup_alpha', type=float, default=0.3)
200 |     parser.add_argument('--no_roll', action='store_true', default=False)
201 |     parser.add_argument('--no_wavmix', action='store_true', default=False)
202 |     parser.add_argument('--gain_augment', type=int, default=12)
203 |     parser.add_argument('--weight_decay', type=float, default=0.0)
204 | 
205 |     # lr schedule
206 |     parser.add_argument('--lr', type=float, default=6e-5)
207 |     parser.add_argument('--warm_up_len', type=int, default=10)
208 |     parser.add_argument('--ramp_down_start', type=int, default=10)
209 |     parser.add_argument('--ramp_down_len', type=int, default=65)
210 |     parser.add_argument('--last_lr_value', type=float, default=0.01)
211 | 
212 |     # preprocessing
213 |     parser.add_argument('--resample_rate', type=int, default=32000)
214 |     parser.add_argument('--window_size', type=int, default=800)
215 |     parser.add_argument('--hop_size', type=int, default=320)
216 |     parser.add_argument('--n_fft', type=int, default=1024)
217 |     parser.add_argument('--n_mels', type=int, default=128)
218 |     parser.add_argument('--freqm', type=int, default=0)
219 |     parser.add_argument('--timem', type=int, default=0)
220 |     parser.add_argument('--fmin', type=int, default=0)
221 |     parser.add_argument('--fmax', type=int, default=None)
222 |     parser.add_argument('--fmin_aug_range', type=int, default=10)
223 |     parser.add_argument('--fmax_aug_range', type=int, default=2000)
224 | 
225 |     args = parser.parse_args()
226 |     train(args)
227 | 
--------------------------------------------------------------------------------
/ex_dcase20.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | import numpy as np
3 | import os
4 | from tqdm import tqdm
5 | import torch
6 | from torch.utils.data import DataLoader
7 | import argparse
8 | from sklearn import metrics
9 | import torch.nn.functional as F
10 | 
11 | from datasets.dcase20 import get_test_set, get_training_set
12 | from models.mn.model import get_model as get_mobilenet
13 | from models.dymn.model import get_model as get_dymn
14 | from models.preprocess import AugmentMelSTFT
15 | from helpers.init import worker_init_fn
16 | from helpers.utils import NAME_TO_WIDTH, exp_warmup_linear_down, mixup, mixstyle
17 | 
18 | 
19 | def train(args):
20 |     # Train Models for Acoustic Scene Classification
21 | 
22 |     # logging is done using wandb
23 |     wandb.init(
24 |         project="DCASE20",
25 |         notes="Fine-tune Models for Acoustic Scene Classification.",
26 |         tags=["Tau Urban Acoustic Scenes 2020 Mobile", "Acoustic Scene Classification", "Fine-Tuning"],
27 |         config=args,
28 |         name=args.experiment_name
29 |     )
30 | 
31 |     device = 
torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu') 32 | 33 | # model to preprocess waveform into mel spectrograms 34 | mel = AugmentMelSTFT(n_mels=args.n_mels, 35 | sr=args.resample_rate, 36 | win_length=args.window_size, 37 | hopsize=args.hop_size, 38 | n_fft=args.n_fft, 39 | freqm=args.freqm, 40 | timem=args.timem, 41 | fmin=args.fmin, 42 | fmax=args.fmax, 43 | fmin_aug_range=args.fmin_aug_range, 44 | fmax_aug_range=args.fmax_aug_range 45 | ) 46 | mel.to(device) 47 | 48 | # load prediction model 49 | model_name = args.model_name 50 | pretrained_name = model_name if args.pretrained else None 51 | width = NAME_TO_WIDTH(model_name) if model_name and args.pretrained else args.model_width 52 | if model_name.startswith("dymn"): 53 | model = get_dymn(width_mult=width, pretrained_name=pretrained_name, 54 | pretrain_final_temp=args.pretrain_final_temp, 55 | num_classes=10) 56 | else: 57 | model = get_mobilenet(width_mult=width, pretrained_name=pretrained_name, 58 | head_type=args.head_type, se_dims=args.se_dims, 59 | num_classes=10) 60 | model.to(device) 61 | 62 | # dataloader 63 | dl = DataLoader(dataset=get_training_set(args.cache_path, 64 | args.resample_rate, 65 | roll=False if args.no_roll else True, 66 | wavmix=False if args.no_wavmix else True, 67 | gain_augment=args.gain_augment 68 | ), 69 | worker_init_fn=worker_init_fn, 70 | num_workers=args.num_workers, 71 | batch_size=args.batch_size, 72 | shuffle=True) 73 | 74 | # evaluation loader 75 | eval_dl = DataLoader(dataset=get_test_set(args.cache_path, args.resample_rate), 76 | worker_init_fn=worker_init_fn, 77 | num_workers=args.num_workers, 78 | batch_size=args.batch_size) 79 | 80 | # optimizer & scheduler 81 | lr = args.lr 82 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 83 | # phases of lr schedule: exponential increase, constant lr, linear decrease, fine-tune 84 | schedule_lambda = \ 85 | exp_warmup_linear_down(args.warm_up_len, args.ramp_down_len, args.ramp_down_start, args.last_lr_value) 86 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule_lambda) 87 | 88 | name = None 89 | accuracy, val_loss = float('NaN'), float('NaN') 90 | 91 | for epoch in range(args.n_epochs): 92 | mel.train() 93 | model.train() 94 | train_stats = dict(train_loss=list()) 95 | pbar = tqdm(dl) 96 | pbar.set_description("Epoch {}/{}: accuracy: {:.4f}, val_loss: {:.4f}" 97 | .format(epoch + 1, args.n_epochs, accuracy, val_loss)) 98 | for batch in pbar: 99 | x, f, y, dev, city, index = batch 100 | bs = x.size(0) 101 | x, y = x.to(device), y.to(device) 102 | x = _mel_forward(x, mel) 103 | 104 | if args.mixstyle_p > 0: 105 | x = mixstyle(x, args.mixstyle_p, args.mixstyle_alpha) 106 | y_hat, _ = model(x) 107 | samples_loss = F.cross_entropy(y_hat, y, reduction="none") 108 | elif args.mixup_alpha: 109 | rn_indices, lam = mixup(bs, args.mixup_alpha) 110 | lam = lam.to(x.device) 111 | x = x * lam.reshape(bs, 1, 1, 1) + \ 112 | x[rn_indices] * (1. - lam.reshape(bs, 1, 1, 1)) 113 | y_hat, _ = model(x) 114 | samples_loss = (F.cross_entropy(y_hat, y, reduction="none") * lam.reshape(bs) + 115 | F.cross_entropy(y_hat, y[rn_indices], reduction="none") * ( 116 | 1. 
- lam.reshape(bs))) 117 | 118 | else: 119 | y_hat, _ = model(x) 120 | samples_loss = F.cross_entropy(y_hat, y, reduction="none") 121 | 122 | # loss 123 | loss = samples_loss.mean() 124 | 125 | # append training statistics 126 | train_stats['train_loss'].append(loss.detach().cpu().numpy()) 127 | 128 | # Update Model 129 | loss.backward() 130 | optimizer.step() 131 | optimizer.zero_grad() 132 | # Update learning rate 133 | scheduler.step() 134 | 135 | # evaluate 136 | accuracy, val_loss = _test(model, mel, eval_dl, device) 137 | 138 | # log train and validation statistics 139 | wandb.log({"train_loss": np.mean(train_stats['train_loss']), 140 | "accuracy": accuracy, 141 | "val_loss": val_loss 142 | }) 143 | 144 | # remove previous model (we try to not flood your hard disk) and save latest model 145 | if name is not None: 146 | os.remove(os.path.join(wandb.run.dir, name)) 147 | name = f"mn{str(width).replace('.', '')}_dcase_epoch_{epoch}_acc_{int(round(accuracy*1000))}.pt" 148 | torch.save(model.state_dict(), os.path.join(wandb.run.dir, name)) 149 | 150 | 151 | def _mel_forward(x, mel): 152 | old_shape = x.size() 153 | x = x.reshape(-1, old_shape[2]) 154 | x = mel(x) 155 | x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2]) 156 | return x 157 | 158 | 159 | def _test(model, mel, eval_loader, device): 160 | model.eval() 161 | mel.eval() 162 | 163 | targets = [] 164 | outputs = [] 165 | losses = [] 166 | pbar = tqdm(eval_loader) 167 | pbar.set_description("Validating") 168 | for batch in pbar: 169 | x, f, y, dev, city, index = batch 170 | x = x.to(device) 171 | y = y.to(device) 172 | with torch.no_grad(): 173 | x = _mel_forward(x, mel) 174 | y_hat, _ = model(x) 175 | targets.append(y.cpu().numpy()) 176 | outputs.append(y_hat.float().cpu().numpy()) 177 | losses.append(F.cross_entropy(y_hat, y).cpu().numpy()) 178 | 179 | targets = np.concatenate(targets) 180 | outputs = np.concatenate(outputs) 181 | losses = np.stack(losses) 182 | accuracy = metrics.accuracy_score(targets, outputs.argmax(axis=1)) 183 | return accuracy, losses.mean() 184 | 185 | 186 | if __name__ == '__main__': 187 | parser = argparse.ArgumentParser(description='Example of parser. 
')
188 | 
189 |     # general
190 |     parser.add_argument('--experiment_name', type=str, default="DCASE20")
191 |     parser.add_argument('--cuda', action='store_true', default=False)
192 |     parser.add_argument('--batch_size', type=int, default=64)
193 |     parser.add_argument('--num_workers', type=int, default=12)
194 |     parser.add_argument('--cache_path', type=str, default=None)
195 | 
196 |     # training
197 |     parser.add_argument('--pretrained', action='store_true', default=False)
198 |     parser.add_argument('--model_name', type=str, default="mn10_as")
199 |     parser.add_argument('--pretrain_final_temp', type=float, default=1.0) # for DyMN
200 |     parser.add_argument('--model_width', type=float, default=1.0)
201 |     parser.add_argument('--head_type', type=str, default="mlp")
202 |     parser.add_argument('--se_dims', type=str, default="c")
203 |     parser.add_argument('--n_epochs', type=int, default=80)
204 |     parser.add_argument('--mixup_alpha', type=float, default=0.3)
205 |     parser.add_argument('--mixstyle_p', type=float, default=0.0)
206 |     parser.add_argument('--mixstyle_alpha', type=float, default=0.4)
207 |     parser.add_argument('--no_roll', action='store_true', default=False)
208 |     parser.add_argument('--no_wavmix', action='store_true', default=False)
209 |     parser.add_argument('--gain_augment', type=int, default=12)
210 |     parser.add_argument('--weight_decay', type=float, default=0.0)
211 | 
212 |     # lr schedule
213 |     parser.add_argument('--lr', type=float, default=8e-4)
214 |     parser.add_argument('--warm_up_len', type=int, default=10)
215 |     parser.add_argument('--ramp_down_start', type=int, default=10)
216 |     parser.add_argument('--ramp_down_len', type=int, default=65)
217 |     parser.add_argument('--last_lr_value', type=float, default=0.01)
218 | 
219 |     # preprocessing
220 |     parser.add_argument('--resample_rate', type=int, default=32000)
221 |     parser.add_argument('--window_size', type=int, default=800)
222 |     parser.add_argument('--hop_size', type=int, default=320)
223 |     parser.add_argument('--n_fft', type=int, default=1024)
224 |     parser.add_argument('--n_mels', type=int, default=128)
225 |     parser.add_argument('--freqm', type=int, default=0)
226 |     parser.add_argument('--timem', type=int, default=0)
227 |     parser.add_argument('--fmin', type=int, default=0)
228 |     parser.add_argument('--fmax', type=int, default=None)
229 |     parser.add_argument('--fmin_aug_range', type=int, default=10)
230 |     parser.add_argument('--fmax_aug_range', type=int, default=2000)
231 | 
232 |     args = parser.parse_args()
233 |     train(args)
234 | 
--------------------------------------------------------------------------------
/datasets/audioset.py:
--------------------------------------------------------------------------------
1 | import io
2 | import av
3 | from torch.utils.data import Dataset as TorchDataset, ConcatDataset, WeightedRandomSampler
4 | import torch
5 | import numpy as np
6 | import h5py
7 | import os
8 | 
9 | from datasets.helpers.audiodatasets import PreprocessDataset, get_roll_func
10 | 
11 | # specify AudioSet location in 'dataset_dir'
12 | # 3 files have to be located there:
13 | # - balanced_train_segments_mp3.hdf
14 | # - unbalanced_train_segments_mp3.hdf
15 | # - eval_segments_mp3.hdf
16 | # follow the instructions here to get these 3 files:
17 | # https://github.com/kkoutini/PaSST/tree/main/audioset
18 | 
19 | dataset_dir = None
20 | assert dataset_dir is not None, "Specify AudioSet location in variable 'dataset_dir'. " \
21 |                                 "Check out the Readme file for further instructions. 
" \ 22 | "https://github.com/fschmid56/EfficientAT/blob/main/README.md" 23 | 24 | dataset_config = { 25 | 'balanced_train_hdf5': os.path.join(dataset_dir, "balanced_train_segments_mp3.hdf"), 26 | 'unbalanced_train_hdf5': os.path.join(dataset_dir, "unbalanced_train_segments_mp3.hdf"), 27 | 'eval_hdf5': os.path.join(dataset_dir, "eval_segments_mp3.hdf"), 28 | 'num_of_classes': 527 29 | } 30 | 31 | 32 | def decode_mp3(mp3_arr): 33 | """ 34 | decodes an array if uint8 representing an mp3 file 35 | :rtype: np.array 36 | """ 37 | container = av.open(io.BytesIO(mp3_arr.tobytes())) 38 | stream = next(s for s in container.streams if s.type == 'audio') 39 | # print(stream) 40 | a = [] 41 | for i, packet in enumerate(container.demux(stream)): 42 | for frame in packet.decode(): 43 | a.append(frame.to_ndarray().reshape(-1)) 44 | waveform = np.concatenate(a) 45 | if waveform.dtype != 'float32': 46 | raise RuntimeError("Unexpected wave type") 47 | return waveform 48 | 49 | 50 | def pad_or_truncate(x, audio_length): 51 | """Pad all audio to specific length.""" 52 | if len(x) <= audio_length: 53 | return np.concatenate((x, np.zeros(audio_length - len(x), dtype=np.float32)), axis=0) 54 | else: 55 | return x[0: audio_length] 56 | 57 | 58 | def pydub_augment(waveform, gain_augment=0): 59 | if gain_augment: 60 | gain = torch.randint(gain_augment * 2, (1,)).item() - gain_augment 61 | amp = 10 ** (gain / 20) 62 | waveform = waveform * amp 63 | return waveform 64 | 65 | 66 | class MixupDataset(TorchDataset): 67 | """ Mixing Up wave forms 68 | """ 69 | 70 | def __init__(self, dataset, beta=2, rate=0.5): 71 | self.beta = beta 72 | self.rate = rate 73 | self.dataset = dataset 74 | print(f"Mixing up waveforms from dataset of len {len(dataset)}") 75 | 76 | def __getitem__(self, index): 77 | if torch.rand(1) < self.rate: 78 | x1, f1, y1 = self.dataset[index] 79 | idx2 = torch.randint(len(self.dataset), (1,)).item() 80 | x2, f2, y2 = self.dataset[idx2] 81 | l = np.random.beta(self.beta, self.beta) 82 | l = max(l, 1. - l) 83 | x1 = x1 - x1.mean() 84 | x2 = x2 - x2.mean() 85 | x = (x1 * l + x2 * (1. - l)) 86 | x = x - x.mean() 87 | return x, f1, (y1 * l + y2 * (1. 
- l)) 88 | return self.dataset[index] 89 | 90 | def __len__(self): 91 | return len(self.dataset) 92 | 93 | 94 | class AddIndexDataset(TorchDataset): 95 | def __init__(self, ds): 96 | self.ds = ds 97 | 98 | def __getitem__(self, index): 99 | x, f, y = self.ds[index] 100 | return x, f, y, index 101 | 102 | def __len__(self): 103 | return len(self.ds) 104 | 105 | 106 | class AudioSetDataset(TorchDataset): 107 | def __init__(self, hdf5_file, sample_rate=32000, resample_rate=32000, classes_num=527, 108 | clip_length=10, in_mem=False, gain_augment=0): 109 | """ 110 | Reads the mp3 bytes from HDF file decodes using av and returns a fixed length audio wav 111 | """ 112 | self.sample_rate = sample_rate 113 | self.resample_rate = resample_rate 114 | self.hdf5_file = hdf5_file 115 | if in_mem: 116 | print("\nPreloading in memory\n") 117 | with open(hdf5_file, 'rb') as f: 118 | self.hdf5_file = io.BytesIO(f.read()) 119 | with h5py.File(hdf5_file, 'r') as f: 120 | self.length = len(f['audio_name']) 121 | print(f"Dataset from {hdf5_file} with length {self.length}.") 122 | self.dataset_file = None # lazy init 123 | self.clip_length = clip_length * sample_rate 124 | self.classes_num = classes_num 125 | self.gain_augment = gain_augment 126 | 127 | def open_hdf5(self): 128 | self.dataset_file = h5py.File(self.hdf5_file, 'r') 129 | 130 | def __len__(self): 131 | return self.length 132 | 133 | def __del__(self): 134 | if self.dataset_file is not None: 135 | self.dataset_file.close() 136 | self.dataset_file = None 137 | 138 | def __getitem__(self, index): 139 | """Load waveform and target of an audio clip. 140 | Args: 141 | 'index': int 142 | Returns: 143 | data_dict: { 144 | 'audio_name': str, 145 | 'waveform': (clip_samples,), 146 | 'target': (classes_num,)} 147 | """ 148 | if self.dataset_file is None: 149 | self.open_hdf5() 150 | 151 | audio_name = self.dataset_file['audio_name'][index].decode() 152 | # convert our modified filenames to official file names 153 | audio_name = audio_name.replace(".mp3", "").split("Y", 1)[1] 154 | waveform = decode_mp3(self.dataset_file['mp3'][index]) 155 | waveform = pydub_augment(waveform, self.gain_augment) 156 | waveform = pad_or_truncate(waveform, self.clip_length) 157 | waveform = self.resample(waveform) 158 | target = self.dataset_file['target'][index] 159 | target = np.unpackbits(target, axis=-1, 160 | count=self.classes_num).astype(np.float32) 161 | return waveform.reshape(1, -1), audio_name, target 162 | 163 | def resample(self, waveform): 164 | """Resample. 165 | Args: 166 | waveform: (clip_samples,) 167 | Returns: 168 | (resampled_clip_samples,) 169 | """ 170 | if self.resample_rate == 32000: 171 | return waveform 172 | elif self.resample_rate == 16000: 173 | return waveform[0:: 2] 174 | elif self.resample_rate == 8000: 175 | return waveform[0:: 4] 176 | else: 177 | raise Exception('Incorrect sample rate!') 178 | 179 | 180 | def get_ft_weighted_sampler(epoch_len=100000, sampler_replace=False): 181 | samples_weights = get_ft_cls_balanced_sample_weights() 182 | return WeightedRandomSampler(samples_weights, num_samples=epoch_len, replacement=sampler_replace) 183 | 184 | 185 | def get_ft_cls_balanced_sample_weights(sample_weight_offset=100, sample_weight_sum=True): 186 | """ 187 | :return: float tensor of shape len(full_training_set) representing the weights of each sample. 188 | """ 189 | # the order of balanced_train_hdf5,unbalanced_train_hdf5 is important. 
190 | # should match get_full_training_set 191 | unbalanced_train_hdf5 = dataset_config['unbalanced_train_hdf5'] 192 | balanced_train_hdf5 = dataset_config['balanced_train_hdf5'] 193 | num_of_classes = dataset_config['num_of_classes'] 194 | 195 | all_y = [] 196 | for hdf5_file in [balanced_train_hdf5, unbalanced_train_hdf5]: 197 | with h5py.File(hdf5_file, 'r') as dataset_file: 198 | target = dataset_file['target'] 199 | target = np.unpackbits(target, axis=-1, count=num_of_classes) 200 | all_y.append(target) 201 | all_y = np.concatenate(all_y, axis=0) 202 | all_y = torch.as_tensor(all_y) 203 | per_class = all_y.long().sum(0).float().reshape(1, -1) # frequencies per class 204 | 205 | per_class = sample_weight_offset + per_class # offset low freq classes 206 | if sample_weight_offset > 0: 207 | print(f"Warning: sample_weight_offset={sample_weight_offset} minnow={per_class.min()}") 208 | per_class_weights = 1000. / per_class 209 | all_weight = all_y * per_class_weights 210 | if sample_weight_sum: 211 | all_weight = all_weight.sum(dim=1) 212 | else: 213 | all_weight, _ = all_weight.max(dim=1) 214 | return all_weight 215 | 216 | 217 | def get_base_full_training_set(resample_rate=32000, gain_augment=0): 218 | sets = [get_base_training_set(resample_rate=resample_rate, gain_augment=gain_augment), 219 | get_unbalanced_training_set(resample_rate=resample_rate, gain_augment=gain_augment)] 220 | ds = ConcatDataset(sets) 221 | return ds 222 | 223 | 224 | def get_base_training_set(resample_rate=32000, gain_augment=0): 225 | balanced_train_hdf5 = dataset_config['balanced_train_hdf5'] 226 | ds = AudioSetDataset(balanced_train_hdf5, resample_rate=resample_rate, gain_augment=gain_augment) 227 | return ds 228 | 229 | 230 | def get_unbalanced_training_set(resample_rate=32000, gain_augment=0): 231 | unbalanced_train_hdf5 = dataset_config['unbalanced_train_hdf5'] 232 | ds = AudioSetDataset(unbalanced_train_hdf5, resample_rate=resample_rate, gain_augment=gain_augment) 233 | return ds 234 | 235 | 236 | def get_base_test_set(resample_rate=32000): 237 | eval_hdf5 = dataset_config['eval_hdf5'] 238 | ds = AudioSetDataset(eval_hdf5, resample_rate=resample_rate) 239 | return ds 240 | 241 | 242 | def get_training_set(add_index=True, roll=False, wavmix=False, gain_augment=0, resample_rate=32000): 243 | ds = get_base_training_set(resample_rate=resample_rate, gain_augment=gain_augment) 244 | if roll: 245 | ds = PreprocessDataset(ds, get_roll_func()) 246 | if wavmix: 247 | ds = MixupDataset(ds) 248 | if add_index: 249 | ds = AddIndexDataset(ds) 250 | return ds 251 | 252 | 253 | def get_full_training_set(add_index=True, roll=False, wavmix=False, gain_augment=0, resample_rate=32000): 254 | ds = get_base_full_training_set(resample_rate=resample_rate, gain_augment=gain_augment) 255 | if roll: 256 | ds = PreprocessDataset(ds, get_roll_func()) 257 | if wavmix: 258 | ds = MixupDataset(ds) 259 | if add_index: 260 | ds = AddIndexDataset(ds) 261 | return ds 262 | 263 | 264 | def get_test_set(resample_rate=32000): 265 | ds = get_base_test_set(resample_rate=resample_rate) 266 | return ds 267 | -------------------------------------------------------------------------------- /ex_openmic.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import numpy as np 3 | import os 4 | from tqdm import tqdm 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import argparse 8 | from sklearn import metrics 9 | import torch.nn.functional as F 10 | 11 | from datasets.openmic import 
get_test_set, get_training_set 12 | from models.mn.model import get_model as get_mobilenet 13 | from models.dymn.model import get_model as get_dymn 14 | from models.preprocess import AugmentMelSTFT 15 | from helpers.init import worker_init_fn 16 | from helpers.utils import NAME_TO_WIDTH, exp_warmup_linear_down, mixup 17 | 18 | 19 | def train(args): 20 | # Train Models on OpenMic 21 | 22 | # logging is done using wandb 23 | wandb.init( 24 | project="OpenMic", 25 | notes="Fine-tune Models on OpenMic.", 26 | tags=["OpenMic", "Instrument Recognition"], 27 | config=args, 28 | name=args.experiment_name 29 | ) 30 | 31 | device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu') 32 | 33 | # model to preprocess waveform into mel spectrograms 34 | mel = AugmentMelSTFT(n_mels=args.n_mels, 35 | sr=args.resample_rate, 36 | win_length=args.window_size, 37 | hopsize=args.hop_size, 38 | n_fft=args.n_fft, 39 | freqm=args.freqm, 40 | timem=args.timem, 41 | fmin=args.fmin, 42 | fmax=args.fmax, 43 | fmin_aug_range=args.fmin_aug_range, 44 | fmax_aug_range=args.fmax_aug_range 45 | ) 46 | mel.to(device) 47 | 48 | # load prediction model 49 | model_name = args.model_name 50 | pretrained_name = model_name if args.pretrained else None 51 | width = NAME_TO_WIDTH(model_name) if model_name and args.pretrained else args.model_width 52 | if model_name.startswith("dymn"): 53 | model = get_dymn(width_mult=width, pretrained_name=pretrained_name, 54 | pretrain_final_temp=args.pretrain_final_temp, 55 | num_classes=20) 56 | else: 57 | model = get_mobilenet(width_mult=width, pretrained_name=pretrained_name, 58 | head_type=args.head_type, se_dims=args.se_dims, 59 | num_classes=20) 60 | model.to(device) 61 | 62 | # dataloader 63 | dl = DataLoader(dataset=get_training_set(resample_rate=args.resample_rate, 64 | roll=False if args.no_roll else True, 65 | wavmix=False if args.no_wavmix else True, 66 | gain_augment=args.gain_augment), 67 | worker_init_fn=worker_init_fn, 68 | num_workers=args.num_workers, 69 | batch_size=args.batch_size, 70 | shuffle=True) 71 | 72 | # evaluation loader 73 | valid_dl = DataLoader(dataset=get_test_set(resample_rate=args.resample_rate), 74 | worker_init_fn=worker_init_fn, 75 | num_workers=args.num_workers, 76 | batch_size=args.batch_size) 77 | 78 | # optimizer & scheduler 79 | lr = args.lr 80 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 81 | # phases of lr schedule: exponential increase, constant lr, linear decrease, fine-tune 82 | schedule_lambda = \ 83 | exp_warmup_linear_down(args.warm_up_len, args.ramp_down_len, args.ramp_down_start, args.last_lr_value) 84 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule_lambda) 85 | 86 | name = None 87 | mAP, ROC, val_loss = float('NaN'), float('NaN'), float('NaN') 88 | 89 | for epoch in range(args.n_epochs): 90 | mel.train() 91 | model.train() 92 | train_stats = dict(train_loss=list()) 93 | pbar = tqdm(dl) 94 | pbar.set_description("Epoch {}/{}: mAP: {:.4f}, val_loss: {:.4f}" 95 | .format(epoch + 1, args.n_epochs, mAP, val_loss)) 96 | for batch in pbar: 97 | x, f, y = batch 98 | bs = x.size(0) 99 | x, y = x.to(device), y.to(device) 100 | x = _mel_forward(x, mel) 101 | 102 | y_mask = y[:, 20:] 103 | y = y[:, :20] > 0.5 104 | y = y.float() 105 | 106 | if args.mixup_alpha: 107 | rn_indices, lam = mixup(bs, args.mixup_alpha) 108 | lam = lam.to(x.device) 109 | x = x * lam.reshape(bs, 1, 1, 1) + \ 110 | x[rn_indices] * (1. 
- lam.reshape(bs, 1, 1, 1)) 111 | y_hat, _ = model(x) 112 | y_mix = y * lam.reshape(bs, 1) + y[rn_indices] * (1. - lam.reshape(bs, 1)) 113 | samples_loss = F.binary_cross_entropy_with_logits(y_hat, y_mix, reduction="none") 114 | samples_loss = y_mask.float() * samples_loss 115 | else: 116 | y_hat, _ = model(x) 117 | samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none") 118 | samples_loss = y_mask.float() * samples_loss 119 | 120 | # loss 121 | loss = samples_loss.mean() 122 | 123 | # append training statistics 124 | train_stats['train_loss'].append(loss.detach().cpu().numpy()) 125 | 126 | # Update Model 127 | loss.backward() 128 | optimizer.step() 129 | optimizer.zero_grad() 130 | 131 | # Update learning rate 132 | scheduler.step() 133 | 134 | # evaluate 135 | mAP, ROC, val_loss = _test(model, mel, valid_dl, device) 136 | 137 | # log train and validation statistics 138 | wandb.log({"train_loss": np.mean(train_stats['train_loss']), 139 | "learning_rate": scheduler.get_last_lr()[0], 140 | "mAP": mAP, 141 | "ROC": ROC, 142 | "val_loss": val_loss 143 | }) 144 | 145 | # remove previous model (we try to not flood your hard disk) and save latest model 146 | if name is not None: 147 | os.remove(os.path.join(wandb.run.dir, name)) 148 | name = f"mn{str(width).replace('.', '')}_openmic_epoch_{epoch}_mAP_{int(round(mAP*1000))}.pt" 149 | torch.save(model.state_dict(), os.path.join(wandb.run.dir, name)) 150 | 151 | 152 | def _mel_forward(x, mel): 153 | old_shape = x.size() 154 | x = x.reshape(-1, old_shape[2]) 155 | x = mel(x) 156 | x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2]) 157 | return x 158 | 159 | 160 | def _test(model, mel, eval_loader, device): 161 | model.eval() 162 | mel.eval() 163 | 164 | targets = [] 165 | targets_mask = [] 166 | outputs = [] 167 | losses = [] 168 | pbar = tqdm(eval_loader) 169 | pbar.set_description("Validating") 170 | for batch in pbar: 171 | x, _, y = batch 172 | x = x.to(device) 173 | y = y.to(device) 174 | y_mask = y[:, 20:] 175 | y = y[:, :20] > 0.5 176 | y = y.float() 177 | with torch.no_grad(): 178 | x = _mel_forward(x, mel) 179 | y_hat, _ = model(x) 180 | 181 | samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none") 182 | samples_loss = y_mask.float() * samples_loss 183 | losses.append(samples_loss.mean().cpu().numpy()) 184 | 185 | targets.append(y.float().cpu().numpy()) 186 | targets_mask.append(y_mask.float().cpu().numpy()) 187 | outputs.append(torch.sigmoid(y_hat.float()).cpu().numpy()) 188 | 189 | targets = np.concatenate(targets) 190 | targets_mask = np.concatenate(targets_mask) 191 | outputs = np.concatenate(outputs) 192 | losses = np.stack(losses) 193 | 194 | try: 195 | mAP = np.array([metrics.average_precision_score( 196 | targets[:, i], outputs[:, i], sample_weight=targets_mask[:, i]) for i in range(targets.shape[1])]) 197 | except ValueError: 198 | mAP = np.array([np.nan] * targets.shape[1]) 199 | 200 | try: 201 | ROC = np.array([metrics.roc_auc_score( 202 | targets[:, i], outputs[:, i], sample_weight=targets_mask[:, i]) for i in range(targets.shape[1])]) 203 | except ValueError: 204 | ROC = np.array([np.nan] * targets.shape[1]) 205 | 206 | return mAP.mean(), ROC.mean(), losses.mean() 207 | 208 | 209 | if __name__ == '__main__': 210 | parser = argparse.ArgumentParser(description='Example of parser. 
') 211 | 212 | # general 213 | parser.add_argument('--experiment_name', type=str, default="OpenMic") 214 | parser.add_argument('--train', action='store_true', default=False) 215 | parser.add_argument('--cuda', action='store_true', default=False) 216 | parser.add_argument('--batch_size', type=int, default=64) 217 | parser.add_argument('--num_workers', type=int, default=12) 218 | 219 | # training 220 | parser.add_argument('--pretrained', action='store_true', default=False) 221 | parser.add_argument('--model_name', type=str, default="mn10_as") 222 | parser.add_argument('--pretrain_final_temp', type=float, default=1.0) # for DyMN 223 | parser.add_argument('--model_width', type=float, default=1.0) 224 | parser.add_argument('--head_type', type=str, default="mlp") 225 | parser.add_argument('--se_dims', type=str, default="c") 226 | parser.add_argument('--n_epochs', type=int, default=80) 227 | parser.add_argument('--mixup_alpha', type=float, default=0.3) 228 | parser.add_argument('--no_roll', action='store_true', default=False) 229 | parser.add_argument('--no_wavmix', action='store_true', default=False) 230 | parser.add_argument('--gain_augment', type=int, default=12) 231 | parser.add_argument('--weight_decay', type=int, default=0.0) 232 | # lr schedule 233 | parser.add_argument('--lr', type=float, default=1e-5) 234 | parser.add_argument('--warm_up_len', type=int, default=10) 235 | parser.add_argument('--ramp_down_start', type=int, default=10) 236 | parser.add_argument('--ramp_down_len', type=int, default=65) 237 | parser.add_argument('--last_lr_value', type=float, default=0.01) 238 | 239 | # preprocessing 240 | parser.add_argument('--resample_rate', type=int, default=32000) 241 | parser.add_argument('--window_size', type=int, default=800) 242 | parser.add_argument('--hop_size', type=int, default=320) 243 | parser.add_argument('--n_fft', type=int, default=1024) 244 | parser.add_argument('--n_mels', type=int, default=128) 245 | parser.add_argument('--freqm', type=int, default=0) 246 | parser.add_argument('--timem', type=int, default=0) 247 | parser.add_argument('--fmin', type=int, default=0) 248 | parser.add_argument('--fmax', type=int, default=None) 249 | parser.add_argument('--fmin_aug_range', type=int, default=10) 250 | parser.add_argument('--fmax_aug_range', type=int, default=2000) 251 | 252 | args = parser.parse_args() 253 | if args.train: 254 | train(args) 255 | else: 256 | evaluate(args) 257 | -------------------------------------------------------------------------------- /ex_fsd50k.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import numpy as np 3 | import os 4 | from tqdm import tqdm 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import argparse 8 | from sklearn import metrics 9 | import torch.nn.functional as F 10 | 11 | from datasets.fsd50k import get_eval_set, get_valid_set, get_training_set 12 | from models.mn.model import get_model as get_mobilenet 13 | from models.dymn.model import get_model as get_dymn 14 | from models.preprocess import AugmentMelSTFT 15 | from helpers.init import worker_init_fn 16 | from helpers.utils import NAME_TO_WIDTH, exp_warmup_linear_down, mixup 17 | 18 | 19 | def train(args): 20 | # Train Models on FSD50K 21 | 22 | # logging is done using wandb 23 | wandb.init( 24 | project="FSD50K", 25 | notes="Fine-tune Models on FSD50K.", 26 | tags=["FSDK50K", "Audio Tagging"], 27 | config=args, 28 | name=args.experiment_name 29 | ) 30 | 31 | device = torch.device('cuda') if args.cuda and 
torch.cuda.is_available() else torch.device('cpu') 32 | 33 | # model to preprocess waveform into mel spectrograms 34 | mel = AugmentMelSTFT(n_mels=args.n_mels, 35 | sr=args.resample_rate, 36 | win_length=args.window_size, 37 | hopsize=args.hop_size, 38 | n_fft=args.n_fft, 39 | freqm=args.freqm, 40 | timem=args.timem, 41 | fmin=args.fmin, 42 | fmax=args.fmax, 43 | fmin_aug_range=args.fmin_aug_range, 44 | fmax_aug_range=args.fmax_aug_range 45 | ) 46 | mel.to(device) 47 | 48 | # load prediction model 49 | model_name = args.model_name 50 | pretrained_name = model_name if args.pretrained else None 51 | width = NAME_TO_WIDTH(model_name) if model_name and args.pretrained else args.model_width 52 | if model_name.startswith("dymn"): 53 | model = get_dymn(width_mult=width, pretrained_name=pretrained_name, 54 | pretrain_final_temp=args.pretrain_final_temp, 55 | num_classes=200) 56 | else: 57 | model = get_mobilenet(width_mult=width, pretrained_name=pretrained_name, 58 | head_type=args.head_type, se_dims=args.se_dims, 59 | num_classes=200) 60 | model.to(device) 61 | 62 | # dataloader 63 | dl = DataLoader(dataset=get_training_set(resample_rate=args.resample_rate, 64 | roll=False if args.no_roll else True, 65 | wavmix=False if args.no_wavmix else True, 66 | gain_augment=args.gain_augment), 67 | worker_init_fn=worker_init_fn, 68 | num_workers=args.num_workers, 69 | batch_size=args.batch_size, 70 | shuffle=True) 71 | 72 | # evaluation loader 73 | valid_dl = DataLoader(dataset=get_valid_set(resample_rate=args.resample_rate, 74 | variable_eval=args.variable_eval_length), 75 | worker_init_fn=worker_init_fn, 76 | num_workers=args.num_workers, 77 | batch_size=1 if args.variable_eval_length else args.batch_size) 78 | 79 | # optimizer & scheduler 80 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 81 | # phases of lr schedule: exponential increase, constant lr, linear decrease, fine-tune 82 | schedule_lambda = \ 83 | exp_warmup_linear_down(args.warm_up_len, args.ramp_down_len, args.ramp_down_start, args.last_lr_value) 84 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule_lambda) 85 | 86 | name = None 87 | mAP, ROC, val_loss = float('NaN'), float('NaN'), float('NaN') 88 | 89 | for epoch in range(args.n_epochs): 90 | mel.train() 91 | model.train() 92 | train_stats = dict(train_loss=list()) 93 | pbar = tqdm(dl) 94 | pbar.set_description("Epoch {}/{}: mAP: {:.4f}, val_loss: {:.4f}" 95 | .format(epoch + 1, args.n_epochs, mAP, val_loss)) 96 | for batch in pbar: 97 | x, f, y = batch 98 | bs = x.size(0) 99 | x, y = x.to(device), y.to(device) 100 | x = _mel_forward(x, mel) 101 | 102 | if args.mixup_alpha: 103 | rn_indices, lam = mixup(bs, args.mixup_alpha) 104 | lam = lam.to(x.device) 105 | x = x * lam.reshape(bs, 1, 1, 1) + \ 106 | x[rn_indices] * (1. - lam.reshape(bs, 1, 1, 1)) 107 | y_hat, _ = model(x) 108 | y_mix = y * lam.reshape(bs, 1) + y[rn_indices] * (1. 
- lam.reshape(bs, 1)) 109 | samples_loss = F.binary_cross_entropy_with_logits(y_hat, y_mix, reduction="none") 110 | else: 111 | y_hat , _= model(x) 112 | samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none") 113 | 114 | # loss 115 | loss = samples_loss.mean() 116 | 117 | # append training statistics 118 | train_stats['train_loss'].append(loss.detach().cpu().numpy()) 119 | 120 | # Update Model 121 | loss.backward() 122 | optimizer.step() 123 | optimizer.zero_grad() 124 | # Update learning rate 125 | scheduler.step() 126 | 127 | # evaluate 128 | mAP, ROC, val_loss = _test(model, mel, valid_dl, device) 129 | 130 | # log train and validation statistics 131 | wandb.log({"train_loss": np.mean(train_stats['train_loss']), 132 | "learning_rate": scheduler.get_last_lr()[0], 133 | "mAP": mAP, 134 | "ROC": ROC, 135 | "val_loss": val_loss 136 | }) 137 | 138 | # remove previous model (we try to not flood your hard disk) and save latest model 139 | if name is not None: 140 | os.remove(os.path.join(wandb.run.dir, name)) 141 | name = f"mn{str(width).replace('.', '')}_fsd50k_epoch_{epoch}_mAP_{int(round(mAP*1000))}.pt" 142 | torch.save(model.state_dict(), os.path.join(wandb.run.dir, name)) 143 | 144 | 145 | def _mel_forward(x, mel): 146 | old_shape = x.size() 147 | x = x.reshape(-1, old_shape[2]) 148 | x = mel(x) 149 | x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2]) 150 | return x 151 | 152 | 153 | def _test(model, mel, eval_loader, device): 154 | model.eval() 155 | mel.eval() 156 | 157 | targets = [] 158 | outputs = [] 159 | losses = [] 160 | pbar = tqdm(eval_loader) 161 | pbar.set_description("Validating") 162 | for batch in pbar: 163 | x, _, y = batch 164 | x = x.to(device) 165 | y = y.to(device) 166 | with torch.no_grad(): 167 | x = _mel_forward(x, mel) 168 | y_hat, _ = model(x) 169 | targets.append(y.cpu().numpy()) 170 | outputs.append(y_hat.float().cpu().numpy()) 171 | losses.append(F.binary_cross_entropy_with_logits(y_hat, y).cpu().numpy()) 172 | 173 | targets = np.concatenate(targets) 174 | outputs = np.concatenate(outputs) 175 | losses = np.stack(losses) 176 | mAP = metrics.average_precision_score(targets, outputs, average=None) 177 | ROC = metrics.roc_auc_score(targets, outputs, average=None) 178 | return mAP.mean(), ROC.mean(), losses.mean() 179 | 180 | 181 | def evaluate(args): 182 | model_name = args.model_name 183 | device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu') 184 | 185 | # load pre-trained model 186 | model_name = args.model_name 187 | width = NAME_TO_WIDTH(model_name) 188 | if model_name.startswith("dymn"): 189 | model = get_dymn(width_mult=width, pretrained_name=model_name, 190 | pretrain_final_temp=args.pretrain_final_temp, 191 | num_classes=200) 192 | else: 193 | model = get_mobilenet(width_mult=width, pretrained_name=model_name, 194 | head_type=args.head_type, se_dims=args.se_dims, 195 | num_classes=200) 196 | model.to(device) 197 | model.eval() 198 | 199 | # model to preprocess waveform into mel spectrograms 200 | mel = AugmentMelSTFT(n_mels=args.n_mels, 201 | sr=args.resample_rate, 202 | win_length=args.window_size, 203 | hopsize=args.hop_size, 204 | n_fft=args.n_fft, 205 | freqm=args.freqm, 206 | timem=args.timem, 207 | fmin=args.fmin, 208 | fmax=args.fmax, 209 | fmin_aug_range=args.fmin_aug_range, 210 | fmax_aug_range=args.fmax_aug_range 211 | ) 212 | mel.to(device) 213 | mel.eval() 214 | 215 | dl = DataLoader(dataset=get_eval_set(resample_rate=args.resample_rate, 216 | 
variable_eval=args.variable_eval_length), 217 | worker_init_fn=worker_init_fn, 218 | num_workers=args.num_workers, 219 | batch_size=1 if args.variable_eval_length else args.batch_size) 220 | 221 | print(f"Running FSD50K evaluation for model '{model_name}' on device '{device}'") 222 | targets = [] 223 | outputs = [] 224 | for batch in tqdm(dl): 225 | x, _, y = batch 226 | x = x.to(device) 227 | y = y.to(device) 228 | with torch.no_grad(): 229 | x = _mel_forward(x, mel) 230 | y_hat, _ = model(x) 231 | targets.append(y.cpu().numpy()) 232 | outputs.append(y_hat.float().cpu().numpy()) 233 | 234 | targets = np.concatenate(targets) 235 | outputs = np.concatenate(outputs) 236 | mAP = metrics.average_precision_score(targets, outputs, average=None) 237 | ROC = metrics.roc_auc_score(targets, outputs, average=None) 238 | 239 | print(f"Results on FSD50K evaluation split for loaded model: {model_name}") 240 | print(" mAP: {:.3f}".format(mAP.mean())) 241 | print(" ROC: {:.3f}".format(ROC.mean())) 242 | 243 | 244 | if __name__ == '__main__': 245 | parser = argparse.ArgumentParser(description='Example of parser. ') 246 | 247 | # general 248 | parser.add_argument('--experiment_name', type=str, default="FSD50K") 249 | parser.add_argument('--train', action='store_true', default=False) 250 | parser.add_argument('--cuda', action='store_true', default=False) 251 | parser.add_argument('--batch_size', type=int, default=64) 252 | parser.add_argument('--num_workers', type=int, default=12) 253 | 254 | # validation & evaluation 255 | # if true, requires setting validation and evaluation batch size to 1 256 | parser.add_argument('--variable_eval_length', action='store_true', default=False) 257 | 258 | # training 259 | parser.add_argument('--pretrained', action='store_true', default=False) 260 | parser.add_argument('--model_name', type=str, default="mn10_as") 261 | parser.add_argument('--pretrain_final_temp', type=float, default=1.0) # for DyMN 262 | parser.add_argument('--model_width', type=float, default=1.0) 263 | parser.add_argument('--head_type', type=str, default="mlp") 264 | parser.add_argument('--se_dims', type=str, default="c") 265 | parser.add_argument('--n_epochs', type=int, default=80) 266 | parser.add_argument('--mixup_alpha', type=float, default=0.3) 267 | parser.add_argument('--no_roll', action='store_true', default=False) 268 | parser.add_argument('--no_wavmix', action='store_true', default=False) 269 | parser.add_argument('--gain_augment', type=int, default=12) 270 | parser.add_argument('--weight_decay', type=int, default=0.0) 271 | # lr schedule 272 | parser.add_argument('--lr', type=float, default=7e-5) 273 | parser.add_argument('--warm_up_len', type=int, default=10) 274 | parser.add_argument('--ramp_down_start', type=int, default=10) 275 | parser.add_argument('--ramp_down_len', type=int, default=65) 276 | parser.add_argument('--last_lr_value', type=float, default=0.01) 277 | 278 | # preprocessing 279 | parser.add_argument('--resample_rate', type=int, default=32000) 280 | parser.add_argument('--window_size', type=int, default=800) 281 | parser.add_argument('--hop_size', type=int, default=320) 282 | parser.add_argument('--n_fft', type=int, default=1024) 283 | parser.add_argument('--n_mels', type=int, default=128) 284 | parser.add_argument('--freqm', type=int, default=0) 285 | parser.add_argument('--timem', type=int, default=0) 286 | parser.add_argument('--fmin', type=int, default=0) 287 | parser.add_argument('--fmax', type=int, default=None) 288 | parser.add_argument('--fmin_aug_range', type=int, 
default=10) 289 | parser.add_argument('--fmax_aug_range', type=int, default=2000) 290 | 291 | args = parser.parse_args() 292 | if args.train: 293 | train(args) 294 | else: 295 | evaluate(args) 296 | -------------------------------------------------------------------------------- /models/dymn/model.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Callable, List, Optional, Sequence, Tuple 3 | from torch import nn, Tensor 4 | import torch.nn.functional as F 5 | from torchvision.ops.misc import ConvNormActivation 6 | from torch.hub import load_state_dict_from_url 7 | import urllib.parse 8 | 9 | from models.dymn.dy_block import DynamicInvertedResidualConfig, DY_Block, DynamicConv, DyReLUB 10 | from models.mn.block_types import InvertedResidualConfig, InvertedResidual 11 | 12 | # points to github releases 13 | model_url = "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/" 14 | # folder to store downloaded models to 15 | model_dir = "resources" 16 | 17 | 18 | pretrained_models = { 19 | # ImageNet pre-trained models 20 | "dymn04_im": urllib.parse.urljoin(model_url, "dymn04_im.pt"), 21 | "dymn10_im": urllib.parse.urljoin(model_url, "dymn10_im.pt"), 22 | "dymn20_im": urllib.parse.urljoin(model_url, "dymn20_im.pt"), 23 | 24 | # Models trained on AudioSet 25 | "dymn04_as": urllib.parse.urljoin(model_url, "dymn04_as.pt"), 26 | "dymn10_as": urllib.parse.urljoin(model_url, "dymn10_as.pt"), 27 | "dymn20_as": urllib.parse.urljoin(model_url, "dymn20_as_mAP_493.pt"), 28 | "dymn20_as(1)": urllib.parse.urljoin(model_url, "dymn20_as.pt"), 29 | "dymn20_as(2)": urllib.parse.urljoin(model_url, "dymn20_as_mAP_489.pt"), 30 | "dymn20_as(3)": urllib.parse.urljoin(model_url, "dymn20_as_mAP_490.pt"), 31 | "dymn04_replace_se_as": urllib.parse.urljoin(model_url, "dymn04_replace_se_as.pt"), 32 | "dymn10_replace_se_as": urllib.parse.urljoin(model_url, " dymn10_replace_se_as.pt"), 33 | } 34 | 35 | 36 | class DyMN(nn.Module): 37 | def __init__( 38 | self, 39 | inverted_residual_setting: List[DynamicInvertedResidualConfig], 40 | last_channel: int, 41 | num_classes: int = 527, 42 | head_type: str = "mlp", 43 | block: Optional[Callable[..., nn.Module]] = None, 44 | norm_layer: Optional[Callable[..., nn.Module]] = None, 45 | dropout: float = 0.2, 46 | in_conv_kernel: int = 3, 47 | in_conv_stride: int = 2, 48 | in_channels: int = 1, 49 | context_ratio: int = 4, 50 | max_context_size: int = 128, 51 | min_context_size: int = 32, 52 | dyrelu_k=2, 53 | dyconv_k=4, 54 | no_dyrelu: bool = False, 55 | no_dyconv: bool = False, 56 | no_ca: bool = False, 57 | temp_schedule: tuple = (30, 1, 1, 0.05), 58 | **kwargs: Any, 59 | ) -> None: 60 | super(DyMN, self).__init__() 61 | 62 | if not inverted_residual_setting: 63 | raise ValueError("The inverted_residual_setting should not be empty") 64 | elif not ( 65 | isinstance(inverted_residual_setting, Sequence) 66 | and all([isinstance(s, DynamicInvertedResidualConfig) for s in inverted_residual_setting]) 67 | ): 68 | raise TypeError("The inverted_residual_setting should be List[DynamicInvertedResidualConfig]") 69 | 70 | if block is None: 71 | block = DY_Block 72 | 73 | norm_layer = \ 74 | norm_layer if norm_layer is not None else partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) 75 | 76 | self.layers = nn.ModuleList() 77 | 78 | # building first layer 79 | firstconv_output_channels = inverted_residual_setting[0].input_channels 80 | self.in_c = ConvNormActivation( 81 | in_channels, 82 | 
firstconv_output_channels, 83 | kernel_size=in_conv_kernel, 84 | stride=in_conv_stride, 85 | norm_layer=norm_layer, 86 | activation_layer=nn.Hardswish, 87 | ) 88 | 89 | for cnf in inverted_residual_setting: 90 | if cnf.use_dy_block: 91 | b = block(cnf, 92 | context_ratio=context_ratio, 93 | max_context_size=max_context_size, 94 | min_context_size=min_context_size, 95 | dyrelu_k=dyrelu_k, 96 | dyconv_k=dyconv_k, 97 | no_dyrelu=no_dyrelu, 98 | no_dyconv=no_dyconv, 99 | no_ca=no_ca, 100 | temp_schedule=temp_schedule 101 | ) 102 | else: 103 | b = InvertedResidual(cnf, None, norm_layer, partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)) 104 | 105 | self.layers.append(b) 106 | 107 | # building last several layers 108 | lastconv_input_channels = inverted_residual_setting[-1].out_channels 109 | lastconv_output_channels = 6 * lastconv_input_channels 110 | self.out_c = ConvNormActivation( 111 | lastconv_input_channels, 112 | lastconv_output_channels, 113 | kernel_size=1, 114 | norm_layer=norm_layer, 115 | activation_layer=nn.Hardswish, 116 | ) 117 | 118 | self.head_type = head_type 119 | if self.head_type == "fully_convolutional": 120 | self.classifier = nn.Sequential( 121 | nn.Conv2d( 122 | lastconv_output_channels, 123 | num_classes, 124 | kernel_size=(1, 1), 125 | stride=(1, 1), 126 | padding=(0, 0), 127 | bias=False), 128 | nn.BatchNorm2d(num_classes), 129 | nn.AdaptiveAvgPool2d((1, 1)), 130 | ) 131 | elif self.head_type == "mlp": 132 | self.classifier = nn.Sequential( 133 | nn.AdaptiveAvgPool2d(1), 134 | nn.Flatten(start_dim=1), 135 | nn.Linear(lastconv_output_channels, last_channel), 136 | nn.Hardswish(inplace=True), 137 | nn.Dropout(p=dropout, inplace=True), 138 | nn.Linear(last_channel, num_classes), 139 | ) 140 | else: 141 | raise NotImplementedError(f"Head '{self.head_type}' unknown. 
Must be one of: 'mlp', " 142 | f"'fully_convolutional', 'multihead_attention_pooling'") 143 | 144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | nn.init.kaiming_normal_(m.weight, mode="fan_out") 147 | if m.bias is not None: 148 | nn.init.zeros_(m.bias) 149 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm, nn.InstanceNorm2d)): 150 | nn.init.ones_(m.weight) 151 | nn.init.zeros_(m.bias) 152 | elif isinstance(m, nn.Linear): 153 | nn.init.normal_(m.weight, 0, 0.01) 154 | if m.bias is not None: 155 | nn.init.zeros_(m.bias) 156 | 157 | def _feature_forward(self, x: Tensor) -> (Tensor, Tensor): 158 | x = self.in_c(x) 159 | g = None 160 | for layer in self.layers: 161 | x = layer(x) 162 | x = self.out_c(x) 163 | return x 164 | 165 | def _clf_forward(self, x: Tensor): 166 | embed = F.adaptive_avg_pool2d(x, (1, 1)).view(x.size(0), -1) 167 | x = self.classifier(x).squeeze() 168 | if x.dim() == 1: 169 | # squeezed batch dimension 170 | x = x.unsqueeze(0) 171 | return x, embed 172 | 173 | def _forward_impl(self, x: Tensor) -> (Tensor, Tensor): 174 | x = self._feature_forward(x) 175 | x, embed = self._clf_forward(x) 176 | return x, embed 177 | 178 | def forward(self, x: Tensor) -> (Tensor, Tensor): 179 | return self._forward_impl(x) 180 | 181 | def update_params(self, epoch): 182 | for module in self.modules(): 183 | if isinstance(module, DynamicConv): 184 | module.update_params(epoch) 185 | 186 | 187 | def _dymn_conf( 188 | width_mult: float = 1.0, 189 | reduced_tail: bool = False, 190 | dilated: bool = False, 191 | strides: Tuple[int] = (2, 2, 2, 2), 192 | use_dy_blocks: str = "all", 193 | **kwargs: Any 194 | ): 195 | reduce_divider = 2 if reduced_tail else 1 196 | dilation = 2 if dilated else 1 197 | 198 | bneck_conf = partial(DynamicInvertedResidualConfig, width_mult=width_mult) 199 | adjust_channels = partial(DynamicInvertedResidualConfig.adjust_channels, width_mult=width_mult) 200 | 201 | activations = ["RE", "RE", "RE", "RE", "RE", "RE", "HS", "HS", "HS", "HS", "HS", "HS", "HS", "HS", "HS"] 202 | 203 | if use_dy_blocks == "all": 204 | # per default the dynamic blocks replace all conventional IR blocks 205 | use_dy_block = [True] * 15 206 | elif use_dy_blocks == "replace_se": 207 | use_dy_block = [False, False, False, True, True, True, False, False, False, False, True, True, True, True, True] 208 | else: 209 | raise NotImplementedError(f"Config use_dy_blocks={use_dy_blocks} not implemented.") 210 | 211 | inverted_residual_setting = [ 212 | bneck_conf(16, 3, 16, 16, use_dy_block[0], activations[0], 1, 1), 213 | bneck_conf(16, 3, 64, 24, use_dy_block[1], activations[1], strides[0], 1), # C1 214 | bneck_conf(24, 3, 72, 24, use_dy_block[2], activations[2], 1, 1), 215 | bneck_conf(24, 5, 72, 40, use_dy_block[3], activations[3], strides[1], 1), # C2 216 | bneck_conf(40, 5, 120, 40, use_dy_block[4], activations[4], 1, 1), 217 | bneck_conf(40, 5, 120, 40, use_dy_block[5], activations[5], 1, 1), 218 | bneck_conf(40, 3, 240, 80, use_dy_block[6], activations[6], strides[2], 1), # C3 219 | bneck_conf(80, 3, 200, 80, use_dy_block[7], activations[7], 1, 1), 220 | bneck_conf(80, 3, 184, 80, use_dy_block[8], activations[8], 1, 1), 221 | bneck_conf(80, 3, 184, 80, use_dy_block[9], activations[9], 1, 1), 222 | bneck_conf(80, 3, 480, 112, use_dy_block[10], activations[10], 1, 1), 223 | bneck_conf(112, 3, 672, 112, use_dy_block[11], activations[11], 1, 1), 224 | bneck_conf(112, 5, 672, 160 // reduce_divider, use_dy_block[12], activations[12], strides[3], dilation), # C4 225 | 
bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, use_dy_block[13], 226 | activations[13], 1, dilation), 227 | bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, use_dy_block[14], 228 | activations[14], 1, dilation), 229 | ] 230 | last_channel = adjust_channels(1280 // reduce_divider) 231 | 232 | return inverted_residual_setting, last_channel 233 | 234 | 235 | def _dymn( 236 | inverted_residual_setting: List[DynamicInvertedResidualConfig], 237 | last_channel: int, 238 | pretrained_name: str, 239 | **kwargs: Any, 240 | ): 241 | model = DyMN(inverted_residual_setting, last_channel, **kwargs) 242 | 243 | # load pre-trained model using specified name 244 | if pretrained_name: 245 | # download from GitHub or load cached state_dict from 'resources' folder 246 | model_url = pretrained_models.get(pretrained_name) 247 | state_dict = load_state_dict_from_url(model_url, model_dir=model_dir, map_location="cpu") 248 | cls_in_state_dict = state_dict['classifier.5.weight'].shape[0] 249 | cls_in_current_model = model.classifier[5].out_features 250 | if cls_in_state_dict != cls_in_current_model: 251 | print(f"The number of classes in the loaded state dict (={cls_in_state_dict}) and " 252 | f"the current model (={cls_in_current_model}) is not the same. Dropping final fully-connected layer " 253 | f"and loading weights in non-strict mode!") 254 | del state_dict['classifier.5.weight'] 255 | del state_dict['classifier.5.bias'] 256 | model.load_state_dict(state_dict, strict=False) 257 | else: 258 | model.load_state_dict(state_dict) 259 | return model 260 | 261 | 262 | def dymn(pretrained_name: str = None, **kwargs: Any): 263 | inverted_residual_setting, last_channel = _dymn_conf(**kwargs) 264 | return _dymn(inverted_residual_setting, last_channel, pretrained_name, **kwargs) 265 | 266 | 267 | def get_model(num_classes: int = 527, 268 | pretrained_name: str = None, 269 | width_mult: float = 1.0, 270 | strides: Tuple[int, int, int, int] = (2, 2, 2, 2), 271 | # Context 272 | context_ratio: int = 4, 273 | max_context_size: int = 128, 274 | min_context_size: int = 32, 275 | # Dy-ReLU 276 | dyrelu_k: int = 2, 277 | no_dyrelu: bool = False, 278 | # Dy-Conv 279 | dyconv_k: int = 4, 280 | no_dyconv: bool = False, 281 | T_max: float = 30.0, 282 | T0_slope: float = 1.0, 283 | T1_slope: float = 0.02, 284 | T_min: float = 1, 285 | pretrain_final_temp: float = 1.0, 286 | # Coordinate Attention 287 | no_ca: bool = False, 288 | use_dy_blocks="all"): 289 | """ 290 | Arguments to modify the instantiation of a DyMN 291 | 292 | Args: 293 | num_classes (int): Specifies number of classes to predict 294 | pretrained_name (str): Specifies name of pre-trained model to load 295 | width_mult (float): Scales width of network 296 | strides (Tuple): Strides that are set to '2' in original implementation; 297 | might be changed to modify the size of receptive field and the downsampling factor in 298 | time and frequency dimension 299 | context_ratio (int): fraction of expanded channel representation used as context size 300 | max_context_size (int): maximum size of context 301 | min_context_size (int): minimum size of context 302 | dyrelu_k (int): number of linear mappings 303 | no_dyrelu (bool): not use Dy-ReLU 304 | dyconv_k (int): number of kernels for dynamic convolution 305 | no_dyconv (bool): not use Dy-Conv 306 | T_max, T0_slope, T1_slope, T_min (float): hyperparameters to steer the temperature schedule for Dy-Conv 307 | pretrain_final_temp (float): if model is 
pre-trained, then final Dy-Conv temperature 308 | of pre-training stage should be used 309 | no_ca (bool): not use Coordinate Attention 310 | use_dy_blocks (str): use dynamic block at all positions per default, other option: "replace_se" 311 | """ 312 | 313 | block = DY_Block 314 | if pretrained_name: 315 | # if model is pre-trained, set Dy-Conv temperature to 'pretrain_final_temp' 316 | # pretrained on ImageNet -> 30 317 | # pretrained on AudioSet -> 1 318 | T_max = pretrain_final_temp 319 | 320 | temp_schedule = (T_max, T_min, T0_slope, T1_slope) 321 | 322 | m = dymn(num_classes=num_classes, 323 | pretrained_name=pretrained_name, 324 | block=block, 325 | width_mult=width_mult, 326 | strides=strides, 327 | context_ratio=context_ratio, 328 | max_context_size=max_context_size, 329 | min_context_size=min_context_size, 330 | dyrelu_k=dyrelu_k, 331 | dyconv_k=dyconv_k, 332 | no_dyrelu=no_dyrelu, 333 | no_dyconv=no_dyconv, 334 | no_ca=no_ca, 335 | temp_schedule=temp_schedule, 336 | use_dy_blocks=use_dy_blocks 337 | ) 338 | print(m) 339 | return m 340 | -------------------------------------------------------------------------------- /metadata/class_labels_indices.csv: -------------------------------------------------------------------------------- 1 | index,mid,display_name 2 | 0,/m/09x0r,"Speech" 3 | 1,/m/05zppz,"Male speech, man speaking" 4 | 2,/m/02zsn,"Female speech, woman speaking" 5 | 3,/m/0ytgt,"Child speech, kid speaking" 6 | 4,/m/01h8n0,"Conversation" 7 | 5,/m/02qldy,"Narration, monologue" 8 | 6,/m/0261r1,"Babbling" 9 | 7,/m/0brhx,"Speech synthesizer" 10 | 8,/m/07p6fty,"Shout" 11 | 9,/m/07q4ntr,"Bellow" 12 | 10,/m/07rwj3x,"Whoop" 13 | 11,/m/07sr1lc,"Yell" 14 | 12,/m/04gy_2,"Battle cry" 15 | 13,/t/dd00135,"Children shouting" 16 | 14,/m/03qc9zr,"Screaming" 17 | 15,/m/02rtxlg,"Whispering" 18 | 16,/m/01j3sz,"Laughter" 19 | 17,/t/dd00001,"Baby laughter" 20 | 18,/m/07r660_,"Giggle" 21 | 19,/m/07s04w4,"Snicker" 22 | 20,/m/07sq110,"Belly laugh" 23 | 21,/m/07rgt08,"Chuckle, chortle" 24 | 22,/m/0463cq4,"Crying, sobbing" 25 | 23,/t/dd00002,"Baby cry, infant cry" 26 | 24,/m/07qz6j3,"Whimper" 27 | 25,/m/07qw_06,"Wail, moan" 28 | 26,/m/07plz5l,"Sigh" 29 | 27,/m/015lz1,"Singing" 30 | 28,/m/0l14jd,"Choir" 31 | 29,/m/01swy6,"Yodeling" 32 | 30,/m/02bk07,"Chant" 33 | 31,/m/01c194,"Mantra" 34 | 32,/t/dd00003,"Male singing" 35 | 33,/t/dd00004,"Female singing" 36 | 34,/t/dd00005,"Child singing" 37 | 35,/t/dd00006,"Synthetic singing" 38 | 36,/m/06bxc,"Rapping" 39 | 37,/m/02fxyj,"Humming" 40 | 38,/m/07s2xch,"Groan" 41 | 39,/m/07r4k75,"Grunt" 42 | 40,/m/01w250,"Whistling" 43 | 41,/m/0lyf6,"Breathing" 44 | 42,/m/07mzm6,"Wheeze" 45 | 43,/m/01d3sd,"Snoring" 46 | 44,/m/07s0dtb,"Gasp" 47 | 45,/m/07pyy8b,"Pant" 48 | 46,/m/07q0yl5,"Snort" 49 | 47,/m/01b_21,"Cough" 50 | 48,/m/0dl9sf8,"Throat clearing" 51 | 49,/m/01hsr_,"Sneeze" 52 | 50,/m/07ppn3j,"Sniff" 53 | 51,/m/06h7j,"Run" 54 | 52,/m/07qv_x_,"Shuffle" 55 | 53,/m/07pbtc8,"Walk, footsteps" 56 | 54,/m/03cczk,"Chewing, mastication" 57 | 55,/m/07pdhp0,"Biting" 58 | 56,/m/0939n_,"Gargling" 59 | 57,/m/01g90h,"Stomach rumble" 60 | 58,/m/03q5_w,"Burping, eructation" 61 | 59,/m/02p3nc,"Hiccup" 62 | 60,/m/02_nn,"Fart" 63 | 61,/m/0k65p,"Hands" 64 | 62,/m/025_jnm,"Finger snapping" 65 | 63,/m/0l15bq,"Clapping" 66 | 64,/m/01jg02,"Heart sounds, heartbeat" 67 | 65,/m/01jg1z,"Heart murmur" 68 | 66,/m/053hz1,"Cheering" 69 | 67,/m/028ght,"Applause" 70 | 68,/m/07rkbfh,"Chatter" 71 | 69,/m/03qtwd,"Crowd" 72 | 70,/m/07qfr4h,"Hubbub, speech noise, speech babble" 73 | 
71,/t/dd00013,"Children playing" 74 | 72,/m/0jbk,"Animal" 75 | 73,/m/068hy,"Domestic animals, pets" 76 | 74,/m/0bt9lr,"Dog" 77 | 75,/m/05tny_,"Bark" 78 | 76,/m/07r_k2n,"Yip" 79 | 77,/m/07qf0zm,"Howl" 80 | 78,/m/07rc7d9,"Bow-wow" 81 | 79,/m/0ghcn6,"Growling" 82 | 80,/t/dd00136,"Whimper (dog)" 83 | 81,/m/01yrx,"Cat" 84 | 82,/m/02yds9,"Purr" 85 | 83,/m/07qrkrw,"Meow" 86 | 84,/m/07rjwbb,"Hiss" 87 | 85,/m/07r81j2,"Caterwaul" 88 | 86,/m/0ch8v,"Livestock, farm animals, working animals" 89 | 87,/m/03k3r,"Horse" 90 | 88,/m/07rv9rh,"Clip-clop" 91 | 89,/m/07q5rw0,"Neigh, whinny" 92 | 90,/m/01xq0k1,"Cattle, bovinae" 93 | 91,/m/07rpkh9,"Moo" 94 | 92,/m/0239kh,"Cowbell" 95 | 93,/m/068zj,"Pig" 96 | 94,/t/dd00018,"Oink" 97 | 95,/m/03fwl,"Goat" 98 | 96,/m/07q0h5t,"Bleat" 99 | 97,/m/07bgp,"Sheep" 100 | 98,/m/025rv6n,"Fowl" 101 | 99,/m/09b5t,"Chicken, rooster" 102 | 100,/m/07st89h,"Cluck" 103 | 101,/m/07qn5dc,"Crowing, cock-a-doodle-doo" 104 | 102,/m/01rd7k,"Turkey" 105 | 103,/m/07svc2k,"Gobble" 106 | 104,/m/09ddx,"Duck" 107 | 105,/m/07qdb04,"Quack" 108 | 106,/m/0dbvp,"Goose" 109 | 107,/m/07qwf61,"Honk" 110 | 108,/m/01280g,"Wild animals" 111 | 109,/m/0cdnk,"Roaring cats (lions, tigers)" 112 | 110,/m/04cvmfc,"Roar" 113 | 111,/m/015p6,"Bird" 114 | 112,/m/020bb7,"Bird vocalization, bird call, bird song" 115 | 113,/m/07pggtn,"Chirp, tweet" 116 | 114,/m/07sx8x_,"Squawk" 117 | 115,/m/0h0rv,"Pigeon, dove" 118 | 116,/m/07r_25d,"Coo" 119 | 117,/m/04s8yn,"Crow" 120 | 118,/m/07r5c2p,"Caw" 121 | 119,/m/09d5_,"Owl" 122 | 120,/m/07r_80w,"Hoot" 123 | 121,/m/05_wcq,"Bird flight, flapping wings" 124 | 122,/m/01z5f,"Canidae, dogs, wolves" 125 | 123,/m/06hps,"Rodents, rats, mice" 126 | 124,/m/04rmv,"Mouse" 127 | 125,/m/07r4gkf,"Patter" 128 | 126,/m/03vt0,"Insect" 129 | 127,/m/09xqv,"Cricket" 130 | 128,/m/09f96,"Mosquito" 131 | 129,/m/0h2mp,"Fly, housefly" 132 | 130,/m/07pjwq1,"Buzz" 133 | 131,/m/01h3n,"Bee, wasp, etc." 
134 | 132,/m/09ld4,"Frog" 135 | 133,/m/07st88b,"Croak" 136 | 134,/m/078jl,"Snake" 137 | 135,/m/07qn4z3,"Rattle" 138 | 136,/m/032n05,"Whale vocalization" 139 | 137,/m/04rlf,"Music" 140 | 138,/m/04szw,"Musical instrument" 141 | 139,/m/0fx80y,"Plucked string instrument" 142 | 140,/m/0342h,"Guitar" 143 | 141,/m/02sgy,"Electric guitar" 144 | 142,/m/018vs,"Bass guitar" 145 | 143,/m/042v_gx,"Acoustic guitar" 146 | 144,/m/06w87,"Steel guitar, slide guitar" 147 | 145,/m/01glhc,"Tapping (guitar technique)" 148 | 146,/m/07s0s5r,"Strum" 149 | 147,/m/018j2,"Banjo" 150 | 148,/m/0jtg0,"Sitar" 151 | 149,/m/04rzd,"Mandolin" 152 | 150,/m/01bns_,"Zither" 153 | 151,/m/07xzm,"Ukulele" 154 | 152,/m/05148p4,"Keyboard (musical)" 155 | 153,/m/05r5c,"Piano" 156 | 154,/m/01s0ps,"Electric piano" 157 | 155,/m/013y1f,"Organ" 158 | 156,/m/03xq_f,"Electronic organ" 159 | 157,/m/03gvt,"Hammond organ" 160 | 158,/m/0l14qv,"Synthesizer" 161 | 159,/m/01v1d8,"Sampler" 162 | 160,/m/03q5t,"Harpsichord" 163 | 161,/m/0l14md,"Percussion" 164 | 162,/m/02hnl,"Drum kit" 165 | 163,/m/0cfdd,"Drum machine" 166 | 164,/m/026t6,"Drum" 167 | 165,/m/06rvn,"Snare drum" 168 | 166,/m/03t3fj,"Rimshot" 169 | 167,/m/02k_mr,"Drum roll" 170 | 168,/m/0bm02,"Bass drum" 171 | 169,/m/011k_j,"Timpani" 172 | 170,/m/01p970,"Tabla" 173 | 171,/m/01qbl,"Cymbal" 174 | 172,/m/03qtq,"Hi-hat" 175 | 173,/m/01sm1g,"Wood block" 176 | 174,/m/07brj,"Tambourine" 177 | 175,/m/05r5wn,"Rattle (instrument)" 178 | 176,/m/0xzly,"Maraca" 179 | 177,/m/0mbct,"Gong" 180 | 178,/m/016622,"Tubular bells" 181 | 179,/m/0j45pbj,"Mallet percussion" 182 | 180,/m/0dwsp,"Marimba, xylophone" 183 | 181,/m/0dwtp,"Glockenspiel" 184 | 182,/m/0dwt5,"Vibraphone" 185 | 183,/m/0l156b,"Steelpan" 186 | 184,/m/05pd6,"Orchestra" 187 | 185,/m/01kcd,"Brass instrument" 188 | 186,/m/0319l,"French horn" 189 | 187,/m/07gql,"Trumpet" 190 | 188,/m/07c6l,"Trombone" 191 | 189,/m/0l14_3,"Bowed string instrument" 192 | 190,/m/02qmj0d,"String section" 193 | 191,/m/07y_7,"Violin, fiddle" 194 | 192,/m/0d8_n,"Pizzicato" 195 | 193,/m/01xqw,"Cello" 196 | 194,/m/02fsn,"Double bass" 197 | 195,/m/085jw,"Wind instrument, woodwind instrument" 198 | 196,/m/0l14j_,"Flute" 199 | 197,/m/06ncr,"Saxophone" 200 | 198,/m/01wy6,"Clarinet" 201 | 199,/m/03m5k,"Harp" 202 | 200,/m/0395lw,"Bell" 203 | 201,/m/03w41f,"Church bell" 204 | 202,/m/027m70_,"Jingle bell" 205 | 203,/m/0gy1t2s,"Bicycle bell" 206 | 204,/m/07n_g,"Tuning fork" 207 | 205,/m/0f8s22,"Chime" 208 | 206,/m/026fgl,"Wind chime" 209 | 207,/m/0150b9,"Change ringing (campanology)" 210 | 208,/m/03qjg,"Harmonica" 211 | 209,/m/0mkg,"Accordion" 212 | 210,/m/0192l,"Bagpipes" 213 | 211,/m/02bxd,"Didgeridoo" 214 | 212,/m/0l14l2,"Shofar" 215 | 213,/m/07kc_,"Theremin" 216 | 214,/m/0l14t7,"Singing bowl" 217 | 215,/m/01hgjl,"Scratching (performance technique)" 218 | 216,/m/064t9,"Pop music" 219 | 217,/m/0glt670,"Hip hop music" 220 | 218,/m/02cz_7,"Beatboxing" 221 | 219,/m/06by7,"Rock music" 222 | 220,/m/03lty,"Heavy metal" 223 | 221,/m/05r6t,"Punk rock" 224 | 222,/m/0dls3,"Grunge" 225 | 223,/m/0dl5d,"Progressive rock" 226 | 224,/m/07sbbz2,"Rock and roll" 227 | 225,/m/05w3f,"Psychedelic rock" 228 | 226,/m/06j6l,"Rhythm and blues" 229 | 227,/m/0gywn,"Soul music" 230 | 228,/m/06cqb,"Reggae" 231 | 229,/m/01lyv,"Country" 232 | 230,/m/015y_n,"Swing music" 233 | 231,/m/0gg8l,"Bluegrass" 234 | 232,/m/02x8m,"Funk" 235 | 233,/m/02w4v,"Folk music" 236 | 234,/m/06j64v,"Middle Eastern music" 237 | 235,/m/03_d0,"Jazz" 238 | 236,/m/026z9,"Disco" 239 | 237,/m/0ggq0m,"Classical music" 240 | 
238,/m/05lls,"Opera" 241 | 239,/m/02lkt,"Electronic music" 242 | 240,/m/03mb9,"House music" 243 | 241,/m/07gxw,"Techno" 244 | 242,/m/07s72n,"Dubstep" 245 | 243,/m/0283d,"Drum and bass" 246 | 244,/m/0m0jc,"Electronica" 247 | 245,/m/08cyft,"Electronic dance music" 248 | 246,/m/0fd3y,"Ambient music" 249 | 247,/m/07lnk,"Trance music" 250 | 248,/m/0g293,"Music of Latin America" 251 | 249,/m/0ln16,"Salsa music" 252 | 250,/m/0326g,"Flamenco" 253 | 251,/m/0155w,"Blues" 254 | 252,/m/05fw6t,"Music for children" 255 | 253,/m/02v2lh,"New-age music" 256 | 254,/m/0y4f8,"Vocal music" 257 | 255,/m/0z9c,"A capella" 258 | 256,/m/0164x2,"Music of Africa" 259 | 257,/m/0145m,"Afrobeat" 260 | 258,/m/02mscn,"Christian music" 261 | 259,/m/016cjb,"Gospel music" 262 | 260,/m/028sqc,"Music of Asia" 263 | 261,/m/015vgc,"Carnatic music" 264 | 262,/m/0dq0md,"Music of Bollywood" 265 | 263,/m/06rqw,"Ska" 266 | 264,/m/02p0sh1,"Traditional music" 267 | 265,/m/05rwpb,"Independent music" 268 | 266,/m/074ft,"Song" 269 | 267,/m/025td0t,"Background music" 270 | 268,/m/02cjck,"Theme music" 271 | 269,/m/03r5q_,"Jingle (music)" 272 | 270,/m/0l14gg,"Soundtrack music" 273 | 271,/m/07pkxdp,"Lullaby" 274 | 272,/m/01z7dr,"Video game music" 275 | 273,/m/0140xf,"Christmas music" 276 | 274,/m/0ggx5q,"Dance music" 277 | 275,/m/04wptg,"Wedding music" 278 | 276,/t/dd00031,"Happy music" 279 | 277,/t/dd00032,"Funny music" 280 | 278,/t/dd00033,"Sad music" 281 | 279,/t/dd00034,"Tender music" 282 | 280,/t/dd00035,"Exciting music" 283 | 281,/t/dd00036,"Angry music" 284 | 282,/t/dd00037,"Scary music" 285 | 283,/m/03m9d0z,"Wind" 286 | 284,/m/09t49,"Rustling leaves" 287 | 285,/t/dd00092,"Wind noise (microphone)" 288 | 286,/m/0jb2l,"Thunderstorm" 289 | 287,/m/0ngt1,"Thunder" 290 | 288,/m/0838f,"Water" 291 | 289,/m/06mb1,"Rain" 292 | 290,/m/07r10fb,"Raindrop" 293 | 291,/t/dd00038,"Rain on surface" 294 | 292,/m/0j6m2,"Stream" 295 | 293,/m/0j2kx,"Waterfall" 296 | 294,/m/05kq4,"Ocean" 297 | 295,/m/034srq,"Waves, surf" 298 | 296,/m/06wzb,"Steam" 299 | 297,/m/07swgks,"Gurgling" 300 | 298,/m/02_41,"Fire" 301 | 299,/m/07pzfmf,"Crackle" 302 | 300,/m/07yv9,"Vehicle" 303 | 301,/m/019jd,"Boat, Water vehicle" 304 | 302,/m/0hsrw,"Sailboat, sailing ship" 305 | 303,/m/056ks2,"Rowboat, canoe, kayak" 306 | 304,/m/02rlv9,"Motorboat, speedboat" 307 | 305,/m/06q74,"Ship" 308 | 306,/m/012f08,"Motor vehicle (road)" 309 | 307,/m/0k4j,"Car" 310 | 308,/m/0912c9,"Vehicle horn, car horn, honking" 311 | 309,/m/07qv_d5,"Toot" 312 | 310,/m/02mfyn,"Car alarm" 313 | 311,/m/04gxbd,"Power windows, electric windows" 314 | 312,/m/07rknqz,"Skidding" 315 | 313,/m/0h9mv,"Tire squeal" 316 | 314,/t/dd00134,"Car passing by" 317 | 315,/m/0ltv,"Race car, auto racing" 318 | 316,/m/07r04,"Truck" 319 | 317,/m/0gvgw0,"Air brake" 320 | 318,/m/05x_td,"Air horn, truck horn" 321 | 319,/m/02rhddq,"Reversing beeps" 322 | 320,/m/03cl9h,"Ice cream truck, ice cream van" 323 | 321,/m/01bjv,"Bus" 324 | 322,/m/03j1ly,"Emergency vehicle" 325 | 323,/m/04qvtq,"Police car (siren)" 326 | 324,/m/012n7d,"Ambulance (siren)" 327 | 325,/m/012ndj,"Fire engine, fire truck (siren)" 328 | 326,/m/04_sv,"Motorcycle" 329 | 327,/m/0btp2,"Traffic noise, roadway noise" 330 | 328,/m/06d_3,"Rail transport" 331 | 329,/m/07jdr,"Train" 332 | 330,/m/04zmvq,"Train whistle" 333 | 331,/m/0284vy3,"Train horn" 334 | 332,/m/01g50p,"Railroad car, train wagon" 335 | 333,/t/dd00048,"Train wheels squealing" 336 | 334,/m/0195fx,"Subway, metro, underground" 337 | 335,/m/0k5j,"Aircraft" 338 | 336,/m/014yck,"Aircraft engine" 339 | 337,/m/04229,"Jet 
engine" 340 | 338,/m/02l6bg,"Propeller, airscrew" 341 | 339,/m/09ct_,"Helicopter" 342 | 340,/m/0cmf2,"Fixed-wing aircraft, airplane" 343 | 341,/m/0199g,"Bicycle" 344 | 342,/m/06_fw,"Skateboard" 345 | 343,/m/02mk9,"Engine" 346 | 344,/t/dd00065,"Light engine (high frequency)" 347 | 345,/m/08j51y,"Dental drill, dentist's drill" 348 | 346,/m/01yg9g,"Lawn mower" 349 | 347,/m/01j4z9,"Chainsaw" 350 | 348,/t/dd00066,"Medium engine (mid frequency)" 351 | 349,/t/dd00067,"Heavy engine (low frequency)" 352 | 350,/m/01h82_,"Engine knocking" 353 | 351,/t/dd00130,"Engine starting" 354 | 352,/m/07pb8fc,"Idling" 355 | 353,/m/07q2z82,"Accelerating, revving, vroom" 356 | 354,/m/02dgv,"Door" 357 | 355,/m/03wwcy,"Doorbell" 358 | 356,/m/07r67yg,"Ding-dong" 359 | 357,/m/02y_763,"Sliding door" 360 | 358,/m/07rjzl8,"Slam" 361 | 359,/m/07r4wb8,"Knock" 362 | 360,/m/07qcpgn,"Tap" 363 | 361,/m/07q6cd_,"Squeak" 364 | 362,/m/0642b4,"Cupboard open or close" 365 | 363,/m/0fqfqc,"Drawer open or close" 366 | 364,/m/04brg2,"Dishes, pots, and pans" 367 | 365,/m/023pjk,"Cutlery, silverware" 368 | 366,/m/07pn_8q,"Chopping (food)" 369 | 367,/m/0dxrf,"Frying (food)" 370 | 368,/m/0fx9l,"Microwave oven" 371 | 369,/m/02pjr4,"Blender" 372 | 370,/m/02jz0l,"Water tap, faucet" 373 | 371,/m/0130jx,"Sink (filling or washing)" 374 | 372,/m/03dnzn,"Bathtub (filling or washing)" 375 | 373,/m/03wvsk,"Hair dryer" 376 | 374,/m/01jt3m,"Toilet flush" 377 | 375,/m/012xff,"Toothbrush" 378 | 376,/m/04fgwm,"Electric toothbrush" 379 | 377,/m/0d31p,"Vacuum cleaner" 380 | 378,/m/01s0vc,"Zipper (clothing)" 381 | 379,/m/03v3yw,"Keys jangling" 382 | 380,/m/0242l,"Coin (dropping)" 383 | 381,/m/01lsmm,"Scissors" 384 | 382,/m/02g901,"Electric shaver, electric razor" 385 | 383,/m/05rj2,"Shuffling cards" 386 | 384,/m/0316dw,"Typing" 387 | 385,/m/0c2wf,"Typewriter" 388 | 386,/m/01m2v,"Computer keyboard" 389 | 387,/m/081rb,"Writing" 390 | 388,/m/07pp_mv,"Alarm" 391 | 389,/m/07cx4,"Telephone" 392 | 390,/m/07pp8cl,"Telephone bell ringing" 393 | 391,/m/01hnzm,"Ringtone" 394 | 392,/m/02c8p,"Telephone dialing, DTMF" 395 | 393,/m/015jpf,"Dial tone" 396 | 394,/m/01z47d,"Busy signal" 397 | 395,/m/046dlr,"Alarm clock" 398 | 396,/m/03kmc9,"Siren" 399 | 397,/m/0dgbq,"Civil defense siren" 400 | 398,/m/030rvx,"Buzzer" 401 | 399,/m/01y3hg,"Smoke detector, smoke alarm" 402 | 400,/m/0c3f7m,"Fire alarm" 403 | 401,/m/04fq5q,"Foghorn" 404 | 402,/m/0l156k,"Whistle" 405 | 403,/m/06hck5,"Steam whistle" 406 | 404,/t/dd00077,"Mechanisms" 407 | 405,/m/02bm9n,"Ratchet, pawl" 408 | 406,/m/01x3z,"Clock" 409 | 407,/m/07qjznt,"Tick" 410 | 408,/m/07qjznl,"Tick-tock" 411 | 409,/m/0l7xg,"Gears" 412 | 410,/m/05zc1,"Pulleys" 413 | 411,/m/0llzx,"Sewing machine" 414 | 412,/m/02x984l,"Mechanical fan" 415 | 413,/m/025wky1,"Air conditioning" 416 | 414,/m/024dl,"Cash register" 417 | 415,/m/01m4t,"Printer" 418 | 416,/m/0dv5r,"Camera" 419 | 417,/m/07bjf,"Single-lens reflex camera" 420 | 418,/m/07k1x,"Tools" 421 | 419,/m/03l9g,"Hammer" 422 | 420,/m/03p19w,"Jackhammer" 423 | 421,/m/01b82r,"Sawing" 424 | 422,/m/02p01q,"Filing (rasp)" 425 | 423,/m/023vsd,"Sanding" 426 | 424,/m/0_ksk,"Power tool" 427 | 425,/m/01d380,"Drill" 428 | 426,/m/014zdl,"Explosion" 429 | 427,/m/032s66,"Gunshot, gunfire" 430 | 428,/m/04zjc,"Machine gun" 431 | 429,/m/02z32qm,"Fusillade" 432 | 430,/m/0_1c,"Artillery fire" 433 | 431,/m/073cg4,"Cap gun" 434 | 432,/m/0g6b5,"Fireworks" 435 | 433,/g/122z_qxw,"Firecracker" 436 | 434,/m/07qsvvw,"Burst, pop" 437 | 435,/m/07pxg6y,"Eruption" 438 | 436,/m/07qqyl4,"Boom" 439 | 437,/m/083vt,"Wood" 440 
| 438,/m/07pczhz,"Chop" 441 | 439,/m/07pl1bw,"Splinter" 442 | 440,/m/07qs1cx,"Crack" 443 | 441,/m/039jq,"Glass" 444 | 442,/m/07q7njn,"Chink, clink" 445 | 443,/m/07rn7sz,"Shatter" 446 | 444,/m/04k94,"Liquid" 447 | 445,/m/07rrlb6,"Splash, splatter" 448 | 446,/m/07p6mqd,"Slosh" 449 | 447,/m/07qlwh6,"Squish" 450 | 448,/m/07r5v4s,"Drip" 451 | 449,/m/07prgkl,"Pour" 452 | 450,/m/07pqc89,"Trickle, dribble" 453 | 451,/t/dd00088,"Gush" 454 | 452,/m/07p7b8y,"Fill (with liquid)" 455 | 453,/m/07qlf79,"Spray" 456 | 454,/m/07ptzwd,"Pump (liquid)" 457 | 455,/m/07ptfmf,"Stir" 458 | 456,/m/0dv3j,"Boiling" 459 | 457,/m/0790c,"Sonar" 460 | 458,/m/0dl83,"Arrow" 461 | 459,/m/07rqsjt,"Whoosh, swoosh, swish" 462 | 460,/m/07qnq_y,"Thump, thud" 463 | 461,/m/07rrh0c,"Thunk" 464 | 462,/m/0b_fwt,"Electronic tuner" 465 | 463,/m/02rr_,"Effects unit" 466 | 464,/m/07m2kt,"Chorus effect" 467 | 465,/m/018w8,"Basketball bounce" 468 | 466,/m/07pws3f,"Bang" 469 | 467,/m/07ryjzk,"Slap, smack" 470 | 468,/m/07rdhzs,"Whack, thwack" 471 | 469,/m/07pjjrj,"Smash, crash" 472 | 470,/m/07pc8lb,"Breaking" 473 | 471,/m/07pqn27,"Bouncing" 474 | 472,/m/07rbp7_,"Whip" 475 | 473,/m/07pyf11,"Flap" 476 | 474,/m/07qb_dv,"Scratch" 477 | 475,/m/07qv4k0,"Scrape" 478 | 476,/m/07pdjhy,"Rub" 479 | 477,/m/07s8j8t,"Roll" 480 | 478,/m/07plct2,"Crushing" 481 | 479,/t/dd00112,"Crumpling, crinkling" 482 | 480,/m/07qcx4z,"Tearing" 483 | 481,/m/02fs_r,"Beep, bleep" 484 | 482,/m/07qwdck,"Ping" 485 | 483,/m/07phxs1,"Ding" 486 | 484,/m/07rv4dm,"Clang" 487 | 485,/m/07s02z0,"Squeal" 488 | 486,/m/07qh7jl,"Creak" 489 | 487,/m/07qwyj0,"Rustle" 490 | 488,/m/07s34ls,"Whir" 491 | 489,/m/07qmpdm,"Clatter" 492 | 490,/m/07p9k1k,"Sizzle" 493 | 491,/m/07qc9xj,"Clicking" 494 | 492,/m/07rwm0c,"Clickety-clack" 495 | 493,/m/07phhsh,"Rumble" 496 | 494,/m/07qyrcz,"Plop" 497 | 495,/m/07qfgpx,"Jingle, tinkle" 498 | 496,/m/07rcgpl,"Hum" 499 | 497,/m/07p78v5,"Zing" 500 | 498,/t/dd00121,"Boing" 501 | 499,/m/07s12q4,"Crunch" 502 | 500,/m/028v0c,"Silence" 503 | 501,/m/01v_m0,"Sine wave" 504 | 502,/m/0b9m1,"Harmonic" 505 | 503,/m/0hdsk,"Chirp tone" 506 | 504,/m/0c1dj,"Sound effect" 507 | 505,/m/07pt_g0,"Pulse" 508 | 506,/t/dd00125,"Inside, small room" 509 | 507,/t/dd00126,"Inside, large room or hall" 510 | 508,/t/dd00127,"Inside, public space" 511 | 509,/t/dd00128,"Outside, urban or manmade" 512 | 510,/t/dd00129,"Outside, rural or natural" 513 | 511,/m/01b9nn,"Reverberation" 514 | 512,/m/01jnbd,"Echo" 515 | 513,/m/096m7z,"Noise" 516 | 514,/m/06_y0by,"Environmental noise" 517 | 515,/m/07rgkc5,"Static" 518 | 516,/m/06xkwv,"Mains hum" 519 | 517,/m/0g12c5,"Distortion" 520 | 518,/m/08p9q4,"Sidetone" 521 | 519,/m/07szfh9,"Cacophony" 522 | 520,/m/0chx_,"White noise" 523 | 521,/m/0cj0r,"Pink noise" 524 | 522,/m/07p_0gm,"Throbbing" 525 | 523,/m/01jwx6,"Vibration" 526 | 524,/m/07c52,"Television" 527 | 525,/m/06bz3,"Radio" 528 | 526,/m/07hvw1,"Field recording" 529 | -------------------------------------------------------------------------------- /models/dymn/dy_block.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from models.dymn.utils import make_divisible, cnn_out_size 9 | 10 | 11 | class DynamicInvertedResidualConfig: 12 | def __init__( 13 | self, 14 | input_channels: int, 15 | kernel: int, 16 | expanded_channels: int, 17 | out_channels: int, 18 | use_dy_block: bool, 19 | activation: str, 20 
| stride: int, 21 | dilation: int, 22 | width_mult: float, 23 | ): 24 | self.input_channels = self.adjust_channels(input_channels, width_mult) 25 | self.kernel = kernel 26 | self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) 27 | self.out_channels = self.adjust_channels(out_channels, width_mult) 28 | self.use_dy_block = use_dy_block 29 | self.use_hs = activation == "HS" 30 | self.use_se = False 31 | self.stride = stride 32 | self.dilation = dilation 33 | self.width_mult = width_mult 34 | 35 | @staticmethod 36 | def adjust_channels(channels: int, width_mult: float): 37 | return make_divisible(channels * width_mult, 8) 38 | 39 | def out_size(self, in_size): 40 | padding = (self.kernel - 1) // 2 * self.dilation 41 | return cnn_out_size(in_size, padding, self.dilation, self.kernel, self.stride) 42 | 43 | 44 | class DynamicConv(nn.Module): 45 | def __init__(self, 46 | in_channels, 47 | out_channels, 48 | context_dim, 49 | kernel_size, 50 | stride=1, 51 | dilation=1, 52 | padding=0, 53 | groups=1, 54 | att_groups=1, 55 | bias=False, 56 | k=4, 57 | temp_schedule=(30, 1, 1, 0.05) 58 | ): 59 | super(DynamicConv, self).__init__() 60 | assert in_channels % groups == 0 61 | self.in_channels = in_channels 62 | self.out_channels = out_channels 63 | self.kernel_size = kernel_size 64 | self.stride = stride 65 | self.padding = padding 66 | self.dilation = dilation 67 | self.groups = groups 68 | self.k = k 69 | self.T_max, self.T_min, self.T0_slope, self.T1_slope = temp_schedule 70 | self.temperature = self.T_max 71 | # att_groups splits the channels into 'att_groups' groups and predicts separate attention weights 72 | # for each of the groups; did only give slight improvements in our experiments and not mentioned in paper 73 | self.att_groups = att_groups 74 | 75 | # Equation 6 in paper: obtain coefficients for K attention weights over conv. 
kernels 76 | self.residuals = nn.Sequential( 77 | nn.Linear(context_dim, k * self.att_groups) 78 | ) 79 | 80 | # k sets of weights for convolution 81 | weight = torch.randn(k, out_channels, in_channels // groups, kernel_size, kernel_size) 82 | 83 | if bias: 84 | self.bias = nn.Parameter(torch.zeros(k, out_channels), requires_grad=True) 85 | else: 86 | self.bias = None 87 | 88 | self._initialize_weights(weight, self.bias) 89 | 90 | weight = weight.view(1, k, att_groups, out_channels, 91 | in_channels // groups, kernel_size, kernel_size) 92 | 93 | weight = weight.transpose(1, 2).view(1, self.att_groups, self.k, -1) 94 | self.weight = nn.Parameter(weight, requires_grad=True) 95 | 96 | def _initialize_weights(self, weight, bias): 97 | init_func = partial(nn.init.kaiming_normal_, mode="fan_out") 98 | for i in range(self.k): 99 | init_func(weight[i]) 100 | if bias is not None: 101 | nn.init.zeros_(bias[i]) 102 | 103 | def forward(self, x, g=None): 104 | b, c, f, t = x.size() 105 | g_c = g[0].view(b, -1) 106 | residuals = self.residuals(g_c).view(b, self.att_groups, 1, -1) 107 | attention = F.softmax(residuals / self.temperature, dim=-1) 108 | 109 | # attention shape: batch_size x 1 x 1 x k 110 | # self.weight shape: 1 x 1 x k x out_channels * (in_channels // groups) * kernel_size ** 2 111 | aggregate_weight = (attention @ self.weight).transpose(1, 2).reshape(b, self.out_channels, 112 | self.in_channels // self.groups, 113 | self.kernel_size, self.kernel_size) 114 | 115 | # aggregate_weight shape: batch_size x out_channels x in_channels // groups x kernel_size x kernel_size 116 | aggregate_weight = aggregate_weight.view(b * self.out_channels, self.in_channels // self.groups, 117 | self.kernel_size, self.kernel_size) 118 | # each sample in the batch has different weights for the convolution - therefore batch and channel dims need to 119 | # be merged together in channel dimension 120 | x = x.view(1, -1, f, t) 121 | if self.bias is not None: 122 | aggregate_bias = torch.mm(attention, self.bias).view(-1) 123 | output = F.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding, 124 | dilation=self.dilation, groups=self.groups * b) 125 | else: 126 | output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding, 127 | dilation=self.dilation, groups=self.groups * b) 128 | 129 | # output shape: 1 x batch_size * channels x f_bands x time_frames 130 | output = output.view(b, self.out_channels, output.size(-2), output.size(-1)) 131 | return output 132 | 133 | def update_params(self, epoch): 134 | # temperature schedule for attention weights 135 | # see Equation 5: tau = temperature 136 | t0 = self.T_max - self.T0_slope * epoch 137 | t1 = 1 + self.T1_slope * (self.T_max - 1) / self.T0_slope - self.T1_slope * epoch 138 | self.temperature = max(t0, t1, self.T_min) 139 | print(f"Setting temperature for attention over kernels to {self.temperature}") 140 | 141 | 142 | class DyReLU(nn.Module): 143 | def __init__(self, channels, context_dim, M=2): 144 | super(DyReLU, self).__init__() 145 | self.channels = channels 146 | self.M = M 147 | 148 | self.coef_net = nn.Sequential( 149 | nn.Linear(context_dim, 2 * M) 150 | ) 151 | 152 | self.sigmoid = nn.Sigmoid() 153 | 154 | self.register_buffer('lambdas', torch.Tensor([1.] * M + [0.5] * M).float()) 155 | self.register_buffer('init_v', torch.Tensor([1.] + [0.] 
* (2 * M - 1)).float()) 156 | 157 | def get_relu_coefs(self, x): 158 | theta = self.coef_net(x) 159 | theta = 2 * self.sigmoid(theta) - 1 160 | return theta 161 | 162 | def forward(self, x, g): 163 | raise NotImplementedError 164 | 165 | 166 | class DyReLUB(DyReLU): 167 | def __init__(self, channels, context_dim, M=2): 168 | super(DyReLUB, self).__init__(channels, context_dim, M) 169 | # Equation 4 in paper: obtain coefficients for M linear mappings for each of the C channels 170 | self.coef_net[-1] = nn.Linear(context_dim, 2 * M * self.channels) 171 | 172 | def forward(self, x, g): 173 | assert x.shape[1] == self.channels 174 | assert g is not None 175 | b, c, f, t = x.size() 176 | h_c = g[0].view(b, -1) 177 | theta = self.get_relu_coefs(h_c) 178 | 179 | relu_coefs = theta.view(-1, self.channels, 1, 1, 2 * self.M) * self.lambdas + self.init_v 180 | # relu_coefs shape: batch_size x channels x 1 x 1 x 2*M 181 | # x shape: batch_size x channels x f_bands x time_frames 182 | x_mapped = x.unsqueeze(-1) * relu_coefs[:, :, :, :, :self.M] + relu_coefs[:, :, :, :, self.M:] 183 | if self.M == 2: 184 | # torch.maximum turned out to be faster than torch.max for M=2 185 | result = torch.maximum(x_mapped[:, :, :, :, 0], x_mapped[:, :, :, :, 1]) 186 | else: 187 | result = torch.max(x_mapped, dim=-1)[0] 188 | return result 189 | 190 | 191 | class CoordAtt(nn.Module): 192 | def __init__(self): 193 | super(CoordAtt, self).__init__() 194 | 195 | def forward(self, x, g): 196 | g_cf, g_ct = g[1], g[2] 197 | a_f = g_cf.sigmoid() 198 | a_t = g_ct.sigmoid() 199 | # recalibration with channel-frequency and channel-time weights 200 | out = x * a_f * a_t 201 | return out 202 | 203 | 204 | class DynamicWrapper(torch.nn.Module): 205 | # wrap a pytorch module in a dynamic module 206 | def __init__(self, module): 207 | super().__init__() 208 | self.module = module 209 | 210 | def forward(self, x, g): 211 | return self.module(x) 212 | 213 | 214 | class ContextGen(nn.Module): 215 | def __init__(self, context_dim, in_ch, exp_ch, norm_layer, stride: int = 1): 216 | super(ContextGen, self).__init__() 217 | 218 | # shared linear layer implemented as a 2D convolution with 1x1 kernel 219 | self.joint_conv = nn.Conv2d(in_ch, context_dim, kernel_size=(1, 1), stride=(1, 1), padding=0, bias=False) 220 | self.joint_norm = norm_layer(context_dim) 221 | self.joint_act = nn.Hardswish(inplace=True) 222 | 223 | # separate linear layers for Coordinate Attention 224 | self.conv_f = nn.Conv2d(context_dim, exp_ch, kernel_size=(1, 1), stride=(1, 1), padding=0) 225 | self.conv_t = nn.Conv2d(context_dim, exp_ch, kernel_size=(1, 1), stride=(1, 1), padding=0) 226 | 227 | if stride > 1: 228 | # sequence pooling for Coordinate Attention 229 | self.pool_f = nn.AvgPool2d(kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)) 230 | self.pool_t = nn.AvgPool2d(kernel_size=(1, 3), stride=(1, stride), padding=(0, 1)) 231 | else: 232 | self.pool_f = nn.Sequential() 233 | self.pool_t = nn.Sequential() 234 | 235 | def forward(self, x, g): 236 | cf = F.adaptive_avg_pool2d(x, (None, 1)) 237 | ct = F.adaptive_avg_pool2d(x, (1, None)).permute(0, 1, 3, 2) 238 | f, t = cf.size(2), ct.size(2) 239 | 240 | g_cat = torch.cat([cf, ct], dim=2) 241 | # joint frequency and time sequence transformation (S_F and S_T in the paper) 242 | g_cat = self.joint_norm(self.joint_conv(g_cat)) 243 | g_cat = self.joint_act(g_cat) 244 | 245 | h_cf, h_ct = torch.split(g_cat, [f, t], dim=2) 246 | h_ct = h_ct.permute(0, 1, 3, 2) 247 | # pooling over sequence dimension to get context 
vector of size H to parameterize Dy-ReLU and Dy-Conv 248 | h_c = torch.mean(g_cat, dim=2, keepdim=True) 249 | g_cf, g_ct = self.conv_f(self.pool_f(h_cf)), self.conv_t(self.pool_t(h_ct)) 250 | 251 | # g[0]: context vector of size H to parameterize Dy-ReLU and Dy-Conv 252 | # g[1], g[2]: frequency and time sequences for Coordinate Attention 253 | g = (h_c, g_cf, g_ct) 254 | return g 255 | 256 | 257 | class DY_Block(nn.Module): 258 | def __init__( 259 | self, 260 | cnf: DynamicInvertedResidualConfig, 261 | context_ratio: int = 4, 262 | max_context_size: int = 128, 263 | min_context_size: int = 32, 264 | temp_schedule: tuple = (30, 1, 1, 0.05), 265 | dyrelu_k: int = 2, 266 | dyconv_k: int = 4, 267 | no_dyrelu: bool = False, 268 | no_dyconv: bool = False, 269 | no_ca: bool = False, 270 | **kwargs: Any 271 | ): 272 | super().__init__() 273 | if not (1 <= cnf.stride <= 2): 274 | raise ValueError("illegal stride value") 275 | 276 | self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels 277 | # context_dim is denoted as 'H' in the paper 278 | self.context_dim = np.clip(make_divisible(cnf.expanded_channels // context_ratio, 8), 279 | make_divisible(min_context_size * cnf.width_mult, 8), 280 | make_divisible(max_context_size * cnf.width_mult, 8) 281 | ) 282 | 283 | activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU 284 | norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) 285 | 286 | # expand 287 | if cnf.expanded_channels != cnf.input_channels: 288 | if no_dyconv: 289 | self.exp_conv = DynamicWrapper( 290 | nn.Conv2d( 291 | cnf.input_channels, 292 | cnf.expanded_channels, 293 | kernel_size=(1, 1), 294 | stride=(1, 1), 295 | dilation=(1, 1), 296 | padding=0, 297 | bias=False 298 | ) 299 | ) 300 | else: 301 | self.exp_conv = DynamicConv( 302 | cnf.input_channels, 303 | cnf.expanded_channels, 304 | self.context_dim, 305 | kernel_size=1, 306 | k=dyconv_k, 307 | temp_schedule=temp_schedule, 308 | stride=1, 309 | dilation=1, 310 | padding=0, 311 | bias=False 312 | ) 313 | 314 | self.exp_norm = norm_layer(cnf.expanded_channels) 315 | self.exp_act = DynamicWrapper(activation_layer(inplace=True)) 316 | else: 317 | self.exp_conv = DynamicWrapper(nn.Identity()) 318 | self.exp_norm = nn.Identity() 319 | self.exp_act = DynamicWrapper(nn.Identity()) 320 | 321 | # depthwise 322 | stride = 1 if cnf.dilation > 1 else cnf.stride 323 | padding = (cnf.kernel - 1) // 2 * cnf.dilation 324 | if no_dyconv: 325 | self.depth_conv = DynamicWrapper( 326 | nn.Conv2d( 327 | cnf.expanded_channels, 328 | cnf.expanded_channels, 329 | kernel_size=(cnf.kernel, cnf.kernel), 330 | groups=cnf.expanded_channels, 331 | stride=(stride, stride), 332 | dilation=(cnf.dilation, cnf.dilation), 333 | padding=padding, 334 | bias=False 335 | ) 336 | ) 337 | else: 338 | self.depth_conv = DynamicConv( 339 | cnf.expanded_channels, 340 | cnf.expanded_channels, 341 | self.context_dim, 342 | kernel_size=cnf.kernel, 343 | k=dyconv_k, 344 | temp_schedule=temp_schedule, 345 | groups=cnf.expanded_channels, 346 | stride=stride, 347 | dilation=cnf.dilation, 348 | padding=padding, 349 | bias=False 350 | ) 351 | self.depth_norm = norm_layer(cnf.expanded_channels) 352 | self.depth_act = DynamicWrapper(activation_layer(inplace=True)) if no_dyrelu \ 353 | else DyReLUB(cnf.expanded_channels, self.context_dim, M=dyrelu_k) 354 | 355 | self.ca = DynamicWrapper(nn.Identity()) if no_ca else CoordAtt() 356 | 357 | # project 358 | if no_dyconv: 359 | self.proj_conv = DynamicWrapper( 360 | nn.Conv2d( 361 | 
cnf.expanded_channels, 362 | cnf.out_channels, 363 | kernel_size=(1, 1), 364 | stride=(1, 1), 365 | dilation=(1, 1), 366 | padding=0, 367 | bias=False 368 | ) 369 | ) 370 | else: 371 | self.proj_conv = DynamicConv( 372 | cnf.expanded_channels, 373 | cnf.out_channels, 374 | self.context_dim, 375 | kernel_size=1, 376 | k=dyconv_k, 377 | temp_schedule=temp_schedule, 378 | stride=1, 379 | dilation=1, 380 | padding=0, 381 | bias=False, 382 | ) 383 | 384 | self.proj_norm = norm_layer(cnf.out_channels) 385 | 386 | context_norm_layer = norm_layer 387 | self.context_gen = ContextGen(self.context_dim, cnf.input_channels, cnf.expanded_channels, 388 | norm_layer=context_norm_layer, stride=stride) 389 | 390 | def forward(self, x, g=None): 391 | # x: CNN feature map (C x F x T) 392 | inp = x 393 | 394 | g = self.context_gen(x, g) 395 | x = self.exp_conv(x, g) 396 | x = self.exp_norm(x) 397 | x = self.exp_act(x, g) 398 | 399 | x = self.depth_conv(x, g) 400 | x = self.depth_norm(x) 401 | x = self.depth_act(x, g) 402 | x = self.ca(x, g) 403 | 404 | x = self.proj_conv(x, g) 405 | x = self.proj_norm(x) 406 | 407 | if self.use_res_connect: 408 | x += inp 409 | return x 410 | -------------------------------------------------------------------------------- /ex_pl_audioset.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import numpy as np 3 | import os 4 | from torch import autocast 5 | from tqdm import tqdm 6 | import torch 7 | from torch.utils.data import DataLoader 8 | import argparse 9 | from sklearn import metrics 10 | from contextlib import nullcontext 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.hub import download_url_to_file 14 | import pickle 15 | import pytorch_lightning as pl 16 | from pytorch_lightning.loggers import WandbLogger 17 | from pytorch_lightning.callbacks import LearningRateMonitor 18 | 19 | from datasets.audioset import get_test_set, get_full_training_set, get_ft_weighted_sampler 20 | from models.mn.model import get_model as get_mobilenet 21 | from models.dymn.model import get_model as get_dymn 22 | from models.ensemble import get_ensemble_model 23 | from models.preprocess import AugmentMelSTFT 24 | from helpers.init import worker_init_fn 25 | from helpers.utils import NAME_TO_WIDTH, exp_warmup_linear_down, mixup 26 | 27 | preds_url = \ 28 | "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/passt_enemble_logits_mAP_495.npy" 29 | 30 | fname_to_index_url = "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/fname_to_index.pkl" 31 | 32 | 33 | class PLModule(pl.LightningModule): 34 | def __init__(self, config): 35 | super().__init__() 36 | self.config = config 37 | # model to preprocess waveform to mel spectrograms 38 | self.mel = AugmentMelSTFT(n_mels=config.n_mels, 39 | sr=config.resample_rate, 40 | win_length=config.window_size, 41 | hopsize=config.hop_size, 42 | n_fft=config.n_fft, 43 | freqm=config.freqm, 44 | timem=config.timem, 45 | fmin=config.fmin, 46 | fmax=config.fmax, 47 | fmin_aug_range=config.fmin_aug_range, 48 | fmax_aug_range=config.fmax_aug_range 49 | ) 50 | 51 | # load prediction model 52 | model_name = config.model_name 53 | pretrained_name = model_name if config.pretrained else None 54 | width = NAME_TO_WIDTH(model_name) if model_name and config.pretrained else config.model_width 55 | if model_name.startswith("dymn"): 56 | model = get_dymn(width_mult=width, pretrained_name=pretrained_name, 57 | strides=config.strides, 
pretrain_final_temp=config.pretrain_final_temp) 58 | else: 59 | model = get_mobilenet(width_mult=width, pretrained_name=pretrained_name, 60 | strides=config.strides, head_type=config.head_type, se_dims=config.se_dims) 61 | self.model = model 62 | 63 | # prepare ingredients for knowledge distillation 64 | assert 0 <= config.kd_lambda <= 1, "Lambda for Knowledge Distillation must be between 0 and 1." 65 | self.distillation_loss = nn.BCEWithLogitsLoss(reduction="none") 66 | 67 | # load stored teacher predictions 68 | if not os.path.isfile(config.teacher_preds): 69 | # download file 70 | print("Download teacher predictions...") 71 | download_url_to_file(preds_url, config.teacher_preds) 72 | print(f"Load teacher predictions from file {config.teacher_preds}") 73 | teacher_preds = np.load(config.teacher_preds) 74 | teacher_preds = torch.from_numpy(teacher_preds).float() 75 | teacher_preds = torch.sigmoid(teacher_preds / config.temperature) 76 | teacher_preds.requires_grad = False 77 | self.teacher_preds = teacher_preds 78 | 79 | if not os.path.isfile(config.fname_to_index): 80 | print("Download filename to teacher prediction index dictionary...") 81 | download_url_to_file(fname_to_index_url, config.fname_to_index) 82 | with open(config.fname_to_index, 'rb') as f: 83 | fname_to_index = pickle.load(f) 84 | self.fname_to_index = fname_to_index 85 | 86 | self.distributed_mode = config.num_devices > 1 87 | self.training_step_outputs = [] 88 | self.validation_step_outputs = [] 89 | 90 | def mel_forward(self, x): 91 | old_shape = x.size() 92 | x = x.reshape(-1, old_shape[2]) 93 | x = self.mel(x) 94 | x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2]) 95 | return x 96 | 97 | def forward(self, x): 98 | """ 99 | :param x: batch of raw audio signals (waveforms) 100 | :return: final model predictions 101 | """ 102 | x = self.mel_forward(x) 103 | x = self.model(x) 104 | return x 105 | 106 | def configure_optimizers(self): 107 | """ 108 | This is the way pytorch lightening requires optimizers and learning rate schedulers to be defined. 109 | The specified items are used automatically in the optimization loop (no need to call optimizer.step() yourself). 
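A rough sketch of the resulting schedule (the exact shape is implemented by exp_warmup_linear_down in helpers/utils.py): the learning rate ramps up exponentially to max_lr over warm_up_len epochs, stays constant until ramp_down_start, decreases linearly over ramp_down_len epochs, and then remains at a small final value (last_lr_value, applied as a multiplier by LambdaLR) for the remaining fine-tune epochs.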
110 | :return: dict containing optimizer and learning rate scheduler 111 | """ 112 | if self.config.adamw: 113 | optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.max_lr, 114 | weight_decay=self.config.weight_decay) 115 | else: 116 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.config.max_lr, 117 | weight_decay=self.config.weight_decay) 118 | 119 | # phases of lr schedule: exponential increase, constant lr, linear decrease, fine-tune 120 | schedule_lambda = \ 121 | exp_warmup_linear_down(self.config.warm_up_len, self.config.ramp_down_len, self.config.ramp_down_start, 122 | self.config.last_lr_value) 123 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule_lambda) 124 | 125 | return { 126 | 'optimizer': optimizer, 127 | 'lr_scheduler': lr_scheduler 128 | } 129 | 130 | def on_train_epoch_start(self): 131 | # in case of DyMN: update DyConv temperature 132 | if hasattr(self.model, "update_params"): 133 | self.model.update_params(self.current_epoch) 134 | 135 | def training_step(self, train_batch, batch_idx): 136 | """ 137 | :param train_batch: contains one batch from train dataloader 138 | :param batch_idx 139 | :return: a dict containing at least loss that is used to update model parameters, can also contain 140 | other items that can be processed in 'training_epoch_end' to log other metrics than loss 141 | """ 142 | x, f, y, i = train_batch 143 | bs = x.size(0) 144 | x = self.mel_forward(x) 145 | 146 | rn_indices, lam = None, None 147 | if self.config.mixup_alpha: 148 | rn_indices, lam = mixup(bs, self.config.mixup_alpha) 149 | lam = lam.to(x.device) 150 | x = x * lam.reshape(bs, 1, 1, 1) + \ 151 | x[rn_indices] * (1. - lam.reshape(bs, 1, 1, 1)) 152 | y_hat, _ = self.model(x) 153 | y_mix = y * lam.reshape(bs, 1) + y[rn_indices] * (1. - lam.reshape(bs, 1)) 154 | samples_loss = F.binary_cross_entropy_with_logits(y_hat, y_mix, reduction="none") 155 | else: 156 | y_hat, _ = self.model(x) 157 | samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none") 158 | 159 | # hard label loss 160 | label_loss = samples_loss.mean() 161 | 162 | # distillation loss 163 | if self.config.kd_lambda > 0: 164 | # fetch the correct index in 'teacher_preds' for given filename 165 | # insert -1 for files not in fname_to_index (proportion of files successfully downloaded from 166 | # YouTube can vary for AudioSet) 167 | indices = torch.tensor( 168 | [self.fname_to_index[fname] if fname in self.fname_to_index else -1 for fname in f], dtype=torch.int64 169 | ) 170 | # get indices of files we could not find the teacher predictions for 171 | unknown_indices = indices == -1 172 | y_soft_teacher = self.teacher_preds[indices] 173 | y_soft_teacher = y_soft_teacher.to(y_hat.device).type_as(y_hat) 174 | 175 | if self.config.mixup_alpha: 176 | soft_targets_loss = \ 177 | self.distillation_loss(y_hat, y_soft_teacher).mean(dim=1) * lam.reshape(bs) + \ 178 | self.distillation_loss(y_hat, y_soft_teacher[rn_indices]).mean(dim=1) \ 179 | * (1. 
- lam.reshape(bs)) 180 | else: 181 | soft_targets_loss = self.distillation_loss(y_hat, y_soft_teacher) 182 | 183 | # zero out loss for samples we don't have teacher predictions for 184 | soft_targets_loss[unknown_indices] = soft_targets_loss[unknown_indices] * 0 185 | soft_targets_loss = soft_targets_loss.mean() 186 | 187 | # weighting losses 188 | label_loss = self.config.kd_lambda * label_loss 189 | soft_targets_loss = (1 - self.config.kd_lambda) * soft_targets_loss 190 | else: 191 | soft_targets_loss = torch.tensor(0., device=label_loss.device, dtype=label_loss.dtype) 192 | 193 | # total loss is sum of lambda-weighted label and distillation loss 194 | loss = label_loss + soft_targets_loss 195 | 196 | results = {"loss": loss.detach().cpu(), "label_loss": label_loss.detach().cpu(), 197 | "kd_loss": soft_targets_loss.detach().cpu()} 198 | self.training_step_outputs.append(results) 199 | return loss 200 | 201 | def on_train_epoch_end(self): 202 | """ 203 | :return: a dict containing the metrics you want to log to Weights and Biases 204 | """ 205 | avg_loss = torch.stack([x['loss'] for x in self.training_step_outputs]).mean() 206 | avg_label_loss = torch.stack([x['label_loss'] for x in self.training_step_outputs]).mean() 207 | avg_kd_loss = torch.stack([x['kd_loss'] for x in self.training_step_outputs]).mean() 208 | self.log_dict({'train/loss': torch.as_tensor(avg_loss).cuda(), 209 | 'train/label_loss': torch.as_tensor(avg_label_loss).cuda(), 210 | 'train/kd_loss': torch.as_tensor(avg_kd_loss).cuda() 211 | }, sync_dist=True) 212 | 213 | self.training_step_outputs.clear() 214 | 215 | def validation_step(self, val_batch, batch_idx): 216 | x, _, y = val_batch 217 | x = self.mel_forward(x) 218 | y_hat, _ = self.model(x) 219 | loss = F.binary_cross_entropy_with_logits(y_hat, y) 220 | preds = torch.sigmoid(y_hat) 221 | results = {'val_loss': loss, "preds": preds, "targets": y} 222 | results = {k: v.cpu() for k, v in results.items()} 223 | self.validation_step_outputs.append(results) 224 | 225 | def on_validation_epoch_end(self): 226 | loss = torch.stack([x['val_loss'] for x in self.validation_step_outputs]) 227 | preds = torch.cat([x['preds'] for x in self.validation_step_outputs], dim=0) 228 | targets = torch.cat([x['targets'] for x in self.validation_step_outputs], dim=0) 229 | 230 | all_preds = self.all_gather(preds).reshape(-1, preds.shape[-1]).cpu().float().numpy() 231 | all_targets = self.all_gather(targets).reshape(-1, targets.shape[-1]).cpu().float().numpy() 232 | all_loss = self.all_gather(loss).reshape(-1,) 233 | 234 | try: 235 | average_precision = metrics.average_precision_score( 236 | all_targets, all_preds, average=None) 237 | except ValueError: 238 | average_precision = np.array([np.nan] * 527) 239 | try: 240 | roc = metrics.roc_auc_score(all_targets, all_preds, average=None) 241 | except ValueError: 242 | roc = np.array([np.nan] * 527) 243 | logs = {'val/loss': torch.as_tensor(all_loss).mean().cuda(), 244 | 'val/ap': torch.as_tensor(average_precision).mean().cuda(), 245 | 'val/roc': torch.as_tensor(roc).mean().cuda() 246 | } 247 | self.log_dict(logs, sync_dist=False) 248 | self.validation_step_outputs.clear() 249 | 250 | 251 | def train(config): 252 | # Train Models from scratch or ImageNet pre-trained on AudioSet 253 | # PaSST ensemble (https://github.com/kkoutini/PaSST) stored in 'resources/passt_enemble_logits_mAP_495.npy' 254 | # can be used as a teacher.
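# The loss optimized in training_step above combines hard labels with the teacher's soft targets, roughly:
#   loss = kd_lambda * BCE(student_logits, labels) + (1 - kd_lambda) * BCE(student_logits, sigmoid(teacher_logits / temperature))
# so with the default kd_lambda=0.1 the distillation term carries most of the weight.
# A minimal example invocation (an editor's sketch; it assumes AudioSet is prepared as expected by
# datasets/audioset.py and that the teacher predictions can be downloaded):
#   python ex_pl_audioset.py --model_name=mn10_as --num_devices=4 --kd_lambda=0.1 --mixup_alpha=0.3
# All flags correspond to the argparse arguments defined at the bottom of this file.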
255 | 256 | # logging is done using wandb 257 | wandb_logger = WandbLogger( 258 | project="EfficientAudioTagging", 259 | notes="Training efficient audio tagging models on AudioSet using Knowledge Distillation.", 260 | tags=["AudioSet", "Audio Tagging", "Knowledge Distillation"], 261 | config=config, 262 | name=config.experiment_name 263 | ) 264 | 265 | train_dl = DataLoader(dataset=get_full_training_set(resample_rate=config.resample_rate, 266 | roll=config.roll, 267 | wavmix=config.wavmix, 268 | gain_augment=config.gain_augment), 269 | sampler=get_ft_weighted_sampler(config.epoch_len), # sampler important to balance classes 270 | worker_init_fn=worker_init_fn, 271 | num_workers=config.num_workers, 272 | batch_size=config.batch_size) 273 | 274 | # eval dataloader 275 | eval_dl = DataLoader(dataset=get_test_set(resample_rate=config.resample_rate), 276 | worker_init_fn=worker_init_fn, 277 | num_workers=config.num_workers, 278 | batch_size=config.batch_size) 279 | 280 | # create PyTorch Lightning module 281 | pl_module = PLModule(config) 282 | 283 | # create monitor to keep track of the learning rate - we want to check the behaviour of our learning rate schedule 284 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 285 | # create the PyTorch Lightning trainer by specifying the number of epochs to train, the logger, 286 | # on which kind of device(s) to train and possible callbacks 287 | trainer = pl.Trainer(max_epochs=config.n_epochs, 288 | logger=wandb_logger, 289 | accelerator='auto', 290 | devices=config.num_devices, 291 | precision=config.precision, 292 | num_sanity_val_steps=0, 293 | callbacks=[lr_monitor]) 294 | 295 | # start training and validation for the specified number of epochs 296 | trainer.fit(pl_module, train_dl, eval_dl) 297 | 298 | 299 | if __name__ == '__main__': 300 | parser = argparse.ArgumentParser(description='Example of parser. 
') 301 | 302 | # general 303 | parser.add_argument('--experiment_name', type=str, default="AudioSet") 304 | parser.add_argument('--batch_size', type=int, default=120) 305 | parser.add_argument('--num_workers', type=int, default=12) 306 | parser.add_argument('--num_devices', type=int, default=4) 307 | 308 | # evaluation 309 | # if ensemble is set, 'model_name' is not used 310 | parser.add_argument('--ensemble', nargs='+', default=[]) 311 | parser.add_argument('--model_name', type=str, default="mn10_as") # used also for training 312 | parser.add_argument('--cuda', action='store_true', default=False) 313 | 314 | # training 315 | parser.add_argument('--precision', type=int, default=16) 316 | parser.add_argument('--pretrained', action='store_true', default=False) 317 | parser.add_argument('--pretrain_final_temp', type=float, default=30.0) # for DyMN 318 | parser.add_argument('--model_width', type=float, default=1.0) 319 | parser.add_argument('--strides', nargs=4, default=[2, 2, 2, 2], type=int) 320 | parser.add_argument('--head_type', type=str, default="mlp") 321 | parser.add_argument('--se_dims', type=str, default="c") 322 | parser.add_argument('--n_epochs', type=int, default=200) 323 | parser.add_argument('--mixup_alpha', type=float, default=0.3) 324 | parser.add_argument('--epoch_len', type=int, default=100000) 325 | parser.add_argument('--roll', action='store_true', default=False) 326 | parser.add_argument('--wavmix', action='store_true', default=False) 327 | parser.add_argument('--gain_augment', type=int, default=0) 328 | 329 | # optimizer 330 | parser.add_argument('--adamw', action='store_true', default=False) 331 | parser.add_argument('--weight_decay', type=float, default=0.0001) 332 | # lr schedule 333 | parser.add_argument('--max_lr', type=float, default=0.003) 334 | parser.add_argument('--warm_up_len', type=int, default=8) 335 | parser.add_argument('--ramp_down_start', type=int, default=80) 336 | parser.add_argument('--ramp_down_len', type=int, default=95) 337 | parser.add_argument('--last_lr_value', type=float, default=0.01) 338 | 339 | # knowledge distillation 340 | parser.add_argument('--teacher_preds', type=str, 341 | default=os.path.join("resources", "passt_enemble_logits_mAP_495.npy")) 342 | parser.add_argument('--fname_to_index', type=str, 343 | default=os.path.join("resources", "fname_to_index.pkl")) 344 | parser.add_argument('--temperature', type=float, default=1) 345 | parser.add_argument('--kd_lambda', type=float, default=0.1) 346 | 347 | # preprocessing 348 | parser.add_argument('--resample_rate', type=int, default=32000) 349 | parser.add_argument('--window_size', type=int, default=800) 350 | parser.add_argument('--hop_size', type=int, default=320) 351 | parser.add_argument('--n_fft', type=int, default=1024) 352 | parser.add_argument('--n_mels', type=int, default=128) 353 | parser.add_argument('--freqm', type=int, default=0) 354 | parser.add_argument('--timem', type=int, default=0) 355 | parser.add_argument('--fmin', type=int, default=0) 356 | parser.add_argument('--fmax', type=int, default=None) 357 | parser.add_argument('--fmin_aug_range', type=int, default=10) 358 | parser.add_argument('--fmax_aug_range', type=int, default=2000) 359 | 360 | args = parser.parse_args() 361 | train(args) 362 | -------------------------------------------------------------------------------- /ex_audioset.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import numpy as np 3 | import os 4 | from torch import autocast 5 | from tqdm import 
tqdm 6 | import torch 7 | from torch.utils.data import DataLoader 8 | import argparse 9 | from sklearn import metrics 10 | from contextlib import nullcontext 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.hub import download_url_to_file 14 | import pickle 15 | 16 | from datasets.audioset import get_test_set, get_full_training_set, get_ft_weighted_sampler 17 | from models.mn.model import get_model as get_mobilenet 18 | from models.dymn.model import get_model as get_dymn 19 | from models.ensemble import get_ensemble_model 20 | from models.preprocess import AugmentMelSTFT 21 | from helpers.init import worker_init_fn 22 | from helpers.utils import NAME_TO_WIDTH, exp_warmup_linear_down, mixup 23 | 24 | preds_url = \ 25 | "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/passt_enemble_logits_mAP_495.npy" 26 | 27 | fname_to_index_url = "https://github.com/fschmid56/EfficientAT/releases/download/v0.0.1/fname_to_index.pkl" 28 | 29 | 30 | def train(args): 31 | # Train Models from scratch or ImageNet pre-trained on AudioSet 32 | # PaSST ensemble (https://github.com/kkoutini/PaSST) stored in 'resources/passt_enemble_logits_mAP_495.npy' 33 | # can be used as a teacher. 34 | 35 | # logging is done using wandb 36 | wandb.init( 37 | project="EfficientAudioTagging", 38 | notes="Training efficient audio tagging models on AudioSet using Knowledge Distillation.", 39 | tags=["AudioSet", "Audio Tagging", "Knowledge Disitillation"], 40 | config=args, 41 | name=args.experiment_name 42 | ) 43 | 44 | device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu') 45 | 46 | # model to preprocess waveform into mel spectrograms 47 | mel = AugmentMelSTFT(n_mels=args.n_mels, 48 | sr=args.resample_rate, 49 | win_length=args.window_size, 50 | hopsize=args.hop_size, 51 | n_fft=args.n_fft, 52 | freqm=args.freqm, 53 | timem=args.timem, 54 | fmin=args.fmin, 55 | fmax=args.fmax, 56 | fmin_aug_range=args.fmin_aug_range, 57 | fmax_aug_range=args.fmax_aug_range 58 | ) 59 | mel.to(device) 60 | # load prediction model 61 | model_name = args.model_name 62 | pretrained_name = model_name if args.pretrained else None 63 | width = NAME_TO_WIDTH(model_name) if model_name and args.pretrained else args.model_width 64 | if model_name.startswith("dymn"): 65 | model = get_dymn(width_mult=width, pretrained_name=pretrained_name, 66 | strides=args.strides, pretrain_final_temp=args.pretrain_final_temp) 67 | else: 68 | model = get_mobilenet(width_mult=width, pretrained_name=pretrained_name, 69 | strides=args.strides, head_type=args.head_type, se_dims=args.se_dims) 70 | model.to(device) 71 | 72 | # dataloader 73 | dl = DataLoader(dataset=get_full_training_set(resample_rate=args.resample_rate, roll=args.roll, wavmix=args.wavmix, 74 | gain_augment=args.gain_augment), 75 | sampler=get_ft_weighted_sampler(args.epoch_len), # sampler important to balance classes 76 | worker_init_fn=worker_init_fn, 77 | num_workers=args.num_workers, 78 | batch_size=args.batch_size) 79 | 80 | # evaluation loader 81 | eval_dl = DataLoader(dataset=get_test_set(resample_rate=args.resample_rate), 82 | worker_init_fn=worker_init_fn, 83 | num_workers=args.num_workers, 84 | batch_size=args.batch_size) 85 | 86 | if args.adamw: 87 | # optimizer & scheduler 88 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.max_lr, weight_decay=args.weight_decay) 89 | else: 90 | # optimizer & scheduler 91 | optimizer = torch.optim.Adam(model.parameters(), lr=args.max_lr, weight_decay=args.weight_decay) 92 
| 93 | 94 | # phases of lr schedule: exponential increase, constant lr, linear decrease, fine-tune 95 | schedule_lambda = \ 96 | exp_warmup_linear_down(args.warm_up_len, args.ramp_down_len, args.ramp_down_start, args.last_lr_value) 97 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule_lambda) 98 | 99 | # prepare ingredients for knowledge distillation 100 | assert 0 <= args.kd_lambda <= 1, "Lambda for Knowledge Distillation must be between 0 and 1." 101 | distillation_loss = nn.BCEWithLogitsLoss(reduction="none") 102 | # load stored teacher predictions 103 | 104 | if not os.path.isfile(args.teacher_preds): 105 | # download file 106 | print("Download teacher predictions...") 107 | download_url_to_file(preds_url, args.teacher_preds) 108 | print(f"Load teacher predictions from file {args.teacher_preds}") 109 | teacher_preds = np.load(args.teacher_preds) 110 | teacher_preds = torch.from_numpy(teacher_preds).float() 111 | teacher_preds = torch.sigmoid(teacher_preds / args.temperature) 112 | teacher_preds.requires_grad = False 113 | 114 | if not os.path.isfile(args.fname_to_index): 115 | print("Download filename to teacher prediction index dictionary...") 116 | download_url_to_file(fname_to_index_url, args.fname_to_index) 117 | with open(args.fname_to_index, 'rb') as f: 118 | fname_to_index = pickle.load(f) 119 | 120 | name = None 121 | mAP, ROC, val_loss = float('NaN'), float('NaN'), float('NaN') 122 | 123 | for epoch in range(args.n_epochs): 124 | mel.train() 125 | model.train() 126 | train_stats = dict(train_loss=list(), label_loss=list(), distillation_loss=list()) 127 | pbar = tqdm(dl) 128 | pbar.set_description("Epoch {}/{}: mAP: {:.4f}, val_loss: {:.4f}" 129 | .format(epoch + 1, args.n_epochs, mAP, val_loss)) 130 | 131 | # in case of DyMN: update DyConv temperature 132 | if hasattr(model, "update_params"): 133 | model.update_params(epoch) 134 | 135 | for batch in pbar: 136 | x, f, y, i = batch 137 | bs = x.size(0) 138 | x, y = x.to(device), y.to(device) 139 | x = _mel_forward(x, mel) 140 | 141 | rn_indices, lam = None, None 142 | if args.mixup_alpha: 143 | rn_indices, lam = mixup(bs, args.mixup_alpha) 144 | lam = lam.to(x.device) 145 | x = x * lam.reshape(bs, 1, 1, 1) + \ 146 | x[rn_indices] * (1. - lam.reshape(bs, 1, 1, 1)) 147 | y_hat, _ = model(x) 148 | y_mix = y * lam.reshape(bs, 1) + y[rn_indices] * (1. - lam.reshape(bs, 1)) 149 | samples_loss = F.binary_cross_entropy_with_logits(y_hat, y_mix, reduction="none") 150 | else: 151 | y_hat, _ = model(x) 152 | samples_loss = F.binary_cross_entropy_with_logits(y_hat, y, reduction="none") 153 | 154 | # hard label loss 155 | label_loss = samples_loss.mean() 156 | 157 | # distillation loss 158 | if args.kd_lambda > 0: 159 | # fetch the correct index in 'teacher_preds' for given filename 160 | # insert -1 for files not in fname_to_index (proportion of files successfully downloaded from 161 | # YouTube can vary for AudioSet) 162 | indices = torch.tensor( 163 | [fname_to_index[fname] if fname in fname_to_index else -1 for fname in f], dtype=torch.int64 164 | ) 165 | # get indices of files we could not find the teacher predictions for 166 | unknown_indices = indices == -1 167 | y_soft_teacher = teacher_preds[indices] 168 | y_soft_teacher = y_soft_teacher.to(y_hat.device).type_as(y_hat) 169 | 170 | if args.mixup_alpha: 171 | soft_targets_loss = \ 172 | distillation_loss(y_hat, y_soft_teacher).mean(dim=1) * lam.reshape(bs) + \ 173 | distillation_loss(y_hat, y_soft_teacher[rn_indices]).mean(dim=1) \ 174 | * (1. 
- lam.reshape(bs)) 175 | else: 176 | soft_targets_loss = distillation_loss(y_hat, y_soft_teacher) 177 | 178 | # zero out loss for samples we don't have teacher predictions for 179 | soft_targets_loss[unknown_indices] = soft_targets_loss[unknown_indices] * 0 180 | soft_targets_loss = soft_targets_loss.mean() 181 | 182 | # weighting losses 183 | label_loss = args.kd_lambda * label_loss 184 | soft_targets_loss = (1 - args.kd_lambda) * soft_targets_loss 185 | else: 186 | soft_targets_loss = torch.tensor(0., device=label_loss.device, dtype=label_loss.dtype) 187 | 188 | # total loss is sum of lambda-weighted label and distillation loss 189 | loss = label_loss + soft_targets_loss 190 | 191 | # append training statistics 192 | train_stats['train_loss'].append(loss.detach().cpu().numpy()) 193 | train_stats['label_loss'].append(label_loss.detach().cpu().numpy()) 194 | train_stats['distillation_loss'].append(soft_targets_loss.detach().cpu().numpy()) 195 | 196 | # Update Model 197 | loss.backward() 198 | optimizer.step() 199 | optimizer.zero_grad() 200 | # Update learning rate 201 | scheduler.step() 202 | 203 | # evaluate 204 | mAP, ROC, val_loss = _test(model, mel, eval_dl, device) 205 | 206 | # log train and validation statistics 207 | wandb.log({"train_loss": np.mean(train_stats['train_loss']), 208 | "label_loss": np.mean(train_stats['label_loss']), 209 | "distillation_loss": np.mean(train_stats['distillation_loss']), 210 | "learning_rate": scheduler.get_last_lr()[0], 211 | "mAP": mAP, 212 | "ROC": ROC, 213 | "val_loss": val_loss 214 | }) 215 | 216 | # remove previous model (we try to not flood your hard disk) and save latest model 217 | if name is not None: 218 | os.remove(os.path.join(wandb.run.dir, name)) 219 | name = f"mn{str(width).replace('.', '')}_as_epoch_{epoch}_mAP_{int(round(mAP*1000))}.pt" 220 | torch.save(model.state_dict(), os.path.join(wandb.run.dir, name)) 221 | 222 | 223 | def _mel_forward(x, mel): 224 | old_shape = x.size() 225 | x = x.reshape(-1, old_shape[2]) 226 | x = mel(x) 227 | x = x.reshape(old_shape[0], old_shape[1], x.shape[1], x.shape[2]) 228 | return x 229 | 230 | 231 | def _test(model, mel, eval_loader, device): 232 | model.eval() 233 | mel.eval() 234 | 235 | targets = [] 236 | outputs = [] 237 | losses = [] 238 | pbar = tqdm(eval_loader) 239 | pbar.set_description("Validating") 240 | for batch in pbar: 241 | x, _, y = batch 242 | x = x.to(device) 243 | y = y.to(device) 244 | with torch.no_grad(): 245 | x = _mel_forward(x, mel) 246 | y_hat, _ = model(x) 247 | targets.append(y.cpu().numpy()) 248 | outputs.append(y_hat.float().cpu().numpy()) 249 | losses.append(F.binary_cross_entropy_with_logits(y_hat, y).cpu().numpy()) 250 | 251 | targets = np.concatenate(targets) 252 | outputs = np.concatenate(outputs) 253 | losses = np.stack(losses) 254 | mAP = metrics.average_precision_score(targets, outputs, average=None) 255 | ROC = metrics.roc_auc_score(targets, outputs, average=None) 256 | return mAP.mean(), ROC.mean(), losses.mean() 257 | 258 | 259 | def evaluate(args): 260 | model_name = args.model_name 261 | device = torch.device('cuda') if args.cuda and torch.cuda.is_available() else torch.device('cpu') 262 | 263 | # load pre-trained model 264 | if len(args.ensemble) > 0: 265 | print(f"Running AudioSet evaluation for models '{args.ensemble}' on device '{device}'") 266 | model = get_ensemble_model(args.ensemble) 267 | else: 268 | print(f"Running AudioSet evaluation for model '{model_name}' on device '{device}'") 269 | if model_name.startswith("dymn"): 270 | model = 
get_dymn(width_mult=NAME_TO_WIDTH(model_name), pretrained_name=model_name, 271 | strides=args.strides) 272 | else: 273 | model = get_mobilenet(width_mult=NAME_TO_WIDTH(model_name), pretrained_name=model_name, 274 | strides=args.strides, head_type=args.head_type) 275 | model.to(device) 276 | model.eval() 277 | 278 | # model to preprocess waveform into mel spectrograms 279 | mel = AugmentMelSTFT(n_mels=args.n_mels, 280 | sr=args.resample_rate, 281 | win_length=args.window_size, 282 | hopsize=args.hop_size, 283 | n_fft=args.n_fft, 284 | fmin=args.fmin, 285 | fmax=args.fmax 286 | ) 287 | mel.to(device) 288 | mel.eval() 289 | 290 | dl = DataLoader(dataset=get_test_set(resample_rate=args.resample_rate), 291 | worker_init_fn=worker_init_fn, 292 | num_workers=args.num_workers, 293 | batch_size=args.batch_size) 294 | 295 | targets = [] 296 | outputs = [] 297 | for batch in tqdm(dl): 298 | x, _, y = batch 299 | x = x.to(device) 300 | y = y.to(device) 301 | # our models are trained in half precision mode (torch.float16) 302 | # run on cuda with torch.float16 to get the best performance 303 | # running on cpu with torch.float32 gives similar performance, using torch.bfloat16 is worse 304 | with autocast(device_type=device.type) if args.cuda else nullcontext(): 305 | with torch.no_grad(): 306 | x = _mel_forward(x, mel) 307 | y_hat, _ = model(x) 308 | targets.append(y.cpu().numpy()) 309 | outputs.append(y_hat.float().cpu().numpy()) 310 | 311 | targets = np.concatenate(targets) 312 | outputs = np.concatenate(outputs) 313 | mAP = metrics.average_precision_score(targets, outputs, average=None) 314 | ROC = metrics.roc_auc_score(targets, outputs, average=None) 315 | 316 | if len(args.ensemble) > 0: 317 | print(f"Results on AudioSet test split for loaded models: {args.ensemble}") 318 | else: 319 | print(f"Results on AudioSet test split for loaded model: {model_name}") 320 | print(" mAP: {:.3f}".format(mAP.mean())) 321 | print(" ROC: {:.3f}".format(ROC.mean())) 322 | 323 | 324 | if __name__ == '__main__': 325 | parser = argparse.ArgumentParser(description='Example of parser. 
') 326 | 327 | # general 328 | parser.add_argument('--experiment_name', type=str, default="AudioSet") 329 | parser.add_argument('--train', action='store_true', default=False) 330 | parser.add_argument('--cuda', action='store_true', default=False) 331 | parser.add_argument('--batch_size', type=int, default=120) 332 | parser.add_argument('--num_workers', type=int, default=12) 333 | 334 | # evaluation 335 | # if ensemble is set, 'model_name' is not used 336 | parser.add_argument('--ensemble', nargs='+', default=[]) 337 | parser.add_argument('--model_name', type=str, default="mn10_as") # used also for training 338 | 339 | # training 340 | parser.add_argument('--pretrained', action='store_true', default=False) 341 | parser.add_argument('--pretrain_final_temp', type=float, default=30.0) # for DyMN 342 | parser.add_argument('--model_width', type=float, default=1.0) 343 | parser.add_argument('--strides', nargs=4, default=[2, 2, 2, 2], type=int) 344 | parser.add_argument('--head_type', type=str, default="mlp") 345 | parser.add_argument('--se_dims', type=str, default="c") 346 | parser.add_argument('--n_epochs', type=int, default=200) 347 | parser.add_argument('--mixup_alpha', type=float, default=0.3) 348 | parser.add_argument('--epoch_len', type=int, default=100000) 349 | parser.add_argument('--roll', action='store_true', default=False) 350 | parser.add_argument('--wavmix', action='store_true', default=False) 351 | parser.add_argument('--gain_augment', type=int, default=0) 352 | 353 | # optimizer 354 | parser.add_argument('--adamw', action='store_true', default=False) 355 | parser.add_argument('--weight_decay', type=float, default=0) 356 | # lr schedule 357 | parser.add_argument('--max_lr', type=float, default=0.0008) 358 | parser.add_argument('--warm_up_len', type=int, default=8) 359 | parser.add_argument('--ramp_down_start', type=int, default=80) 360 | parser.add_argument('--ramp_down_len', type=int, default=95) 361 | parser.add_argument('--last_lr_value', type=float, default=0.01) 362 | 363 | # knowledge distillation 364 | parser.add_argument('--teacher_preds', type=str, 365 | default=os.path.join("resources", "passt_enemble_logits_mAP_495.npy")) 366 | parser.add_argument('--fname_to_index', type=str, 367 | default=os.path.join("resources", "fname_to_index.pkl")) 368 | parser.add_argument('--temperature', type=float, default=1) 369 | parser.add_argument('--kd_lambda', type=float, default=0.1) 370 | 371 | # preprocessing 372 | parser.add_argument('--resample_rate', type=int, default=32000) 373 | parser.add_argument('--window_size', type=int, default=800) 374 | parser.add_argument('--hop_size', type=int, default=320) 375 | parser.add_argument('--n_fft', type=int, default=1024) 376 | parser.add_argument('--n_mels', type=int, default=128) 377 | parser.add_argument('--freqm', type=int, default=0) 378 | parser.add_argument('--timem', type=int, default=0) 379 | parser.add_argument('--fmin', type=int, default=0) 380 | parser.add_argument('--fmax', type=int, default=None) 381 | parser.add_argument('--fmin_aug_range', type=int, default=10) 382 | parser.add_argument('--fmax_aug_range', type=int, default=2000) 383 | 384 | args = parser.parse_args() 385 | if args.train: 386 | train(args) 387 | else: 388 | evaluate(args) 389 | --------------------------------------------------------------------------------