├── .gitignore ├── README.md ├── audio_tag └── backbone.py ├── config.py ├── data ├── URBAN-SED_v2.0.0 │ └── metadata │ │ ├── test.tsv │ │ ├── train.tsv │ │ └── validate.tsv └── dcase2019 │ └── metadata │ ├── eval │ └── public.tsv │ ├── train │ ├── dcase2018_task5.tsv │ ├── synthetic_2019 │ │ └── soundscapes.tsv │ ├── unlabel_in_domain.tsv │ └── weak.tsv │ └── validation │ └── validation.tsv ├── data_utils ├── DataLoad.py ├── SedData.py └── collapse_event.py ├── engine.py ├── img ├── sedt.png └── sp-sedt.png ├── sedt ├── __init__.py ├── backbone.py ├── matcher.py ├── position_encoding.py ├── sedt.py ├── spsedt.py └── transformer.py ├── train_at.py ├── train_sedt.py ├── train_spsedt.py ├── train_ss_sedt.py └── utilities ├── BoxEncoder.py ├── BoxTransforms.py ├── FrameEncoder.py ├── FrameTransforms.py ├── Logger.py ├── Scaler.py ├── box_ops.py ├── distribute.py ├── metrics.py ├── mixup.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | /exp/* 2 | /.idea/* 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sound Event Detection Transformer 2 | ![image](./img/sedt.png) 3 | ## Prepare your data 4 | + URBANSED Dataset 5 | Download [URBAN-SED_v2.0.0](https://zenodo.org/record/1324404/files/URBAN-SED_v2.0.0.tar.gz?download=1) dataset, and 6 | change $urbansed_dir in config.py to your own URBAN-SED data path. To generate *.tsv file, run 7 | ```python 8 | python ./data_utils/collapse_event.py 9 | ``` 10 | 11 | 12 | + DCASE2019 Task4 Dataset 13 | Download the dataset from the website of [DCASE](http://dcase.community/), and change $dcase_dir in config.py to your own 14 | DCASE data path. 
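For reference, the URBAN-SED preparation above can be scripted end to end. This is only a sketch: the download URL is the one linked above, and the `./data/URBAN-SED_v2.0.0/` target simply matches the default `$urbansed_dir` in config.py, so adjust both to your own layout.
```shell script
# run from the repository root; keeps the default $urbansed_dir from config.py
wget -O URBAN-SED_v2.0.0.tar.gz "https://zenodo.org/record/1324404/files/URBAN-SED_v2.0.0.tar.gz?download=1"
mkdir -p ./data && tar -xzf URBAN-SED_v2.0.0.tar.gz -C ./data/
# collapse_event.py reads the per-clip annotation .txt files and writes
# metadata/train.tsv, metadata/validate.tsv and metadata/test.tsv
python ./data_utils/collapse_event.py
```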
15 | 16 | ## Train models 17 | + To train model on the dataset of URBANSED, run 18 | ```shell script 19 | python train_sedt.py 20 | --gpus $ngpu 21 | --dataname urbansed 22 | --batch_size 64 23 | --fusion_strategy 1 24 | --dec_at 25 | --weak_loss_coef 1 26 | --epochs 400 # total epochs 27 | --epochs_ls 280 # epochs of learning stage 28 | --lr_drop 160 29 | --num_queries 10 30 | ``` 31 | 32 | ## Evaluate models 33 | + For URBAN-SED dataset, download our [SEDT(E=3, Eb_F1=38.15)](https://drive.google.com/file/d/1X7PEZzPH61W1KCFAyLN6RspvIfabb2H-/view?usp=sharing), put it in ./exp/urbansed/model/ , then run 34 | ```shell script 35 | python train_sedt.py --gpus 0 --dataname urbansed --dec_at --fusion_strategy 1 --num_queries 10 --eval --info URBAN-SED 36 | ``` 37 | 38 | # SP-SEDT: Self-supervised Pretraining for SEDT 39 | ![image](img/sp-sedt.png) 40 | ## Prepare your data 41 | + DCASE2018 Task5 development dataset 42 | Download [the dataset](https://zenodo.org/record/1247102), put the audios in $dcase_dir/audio/train/ and the *.tsv file 43 | in $dcase_dir/metadata/train/ 44 | ## Train models 45 | + To train backbone by audio tagging task, run 46 | ```shell script 47 | python train_at.py --dataname dcase --gpu $ngpu --pooling max 48 | ``` 49 | You can also download our [backbone](https://drive.google.com/file/d/1R-hAnM6cW1Q9TvLBqROrTxOp4T99Ih76/view?usp=sharing), and put it in ./exp/dcase/model/ 50 | + To pretrain SEDT on a single node with $N gpus, run 51 | ```shell script 52 | python -m torch.distributed.launch --nproc_per_node=$N train_spsedt.py 53 | --gpus $ngpu 54 | --dataname dcase 55 | --batch_size 200 56 | --self_sup 57 | --feature_recon 58 | --num_patches 10 59 | --num_queries 20 60 | --enc_layers 6 61 | --epochs 160 62 | --pretrain "backbone" # file name of the backbone model 63 | --checkpoint_epochs 20 64 | ``` 65 | You can also download our models pretrained [with](https://drive.google.com/file/d/1iYykmwu0Imuoypb30IQDRWIf-_3F7mXu/view?usp=sharing) or [without](https://drive.google.com/file/d/1TpR0YhmPxVYyJ0HOm1tn4AnYPqZe442-/view?usp=sharing) dcase2018 task5 data, 66 | and put it in ./exp/dcase/model/ 67 | + To fine-tune SEDT, run 68 | ```shell script 69 | python train_sedt.py 70 | --gpus $ngpu 71 | --dataname dcase 72 | --batch_size 32 73 | --n_weak 16 74 | --num_queries 20 75 | --enc_layers 6 76 | --dec_at 77 | --fusion_strategy 1 78 | --epochs 300 79 | --pretrain "Pretrained_SP_SEDT" 80 | --weak_loss_coef 0.25 81 | ``` 82 | ## Evaluate models 83 | Download our [SP-SEDT(E=6, Eb_F1=39.03)](https://drive.google.com/file/d/1JIhvRpvW6MC7N88PxCVQ8BpckaAYLDDU/view?usp=sharing), put it in ./exp/dcase/model/ , then run 84 | 85 | ```shell script 86 | python train_sedt.py 87 | --gpus $ngpu 88 | --dataname dcase 89 | --num_queries 20 90 | --enc_layers 6 91 | --dec_at 92 | --fusion_strategy 1 93 | --eval 94 | --info SP_SEDT 95 | ``` 96 | # DCASE challenge system 97 | A SP-SEDT-based system is constructed with mixup, frequency mask, frequency shift and time mask as data augmentation methods, and hybrid pseudo-labelling/mean-teacher as semi-supervised learning method where a teacher model is updated from the student model online using exponential moving average strategy and used to generate hard pseudo labels for unlabeled data. 
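To make the mean-teacher part of the paragraph above concrete, here is a minimal sketch of an exponential-moving-average (EMA) teacher update. It is only an illustration: the repository's own EMA helper (used as `ema.update()` / `ema.apply_shadow()` in engine.py) may differ in detail, and the decay value below is an assumption.
```python
import torch

@torch.no_grad()
def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, decay: float = 0.999):
    """One EMA step: teacher <- decay * teacher + (1 - decay) * student."""
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(decay).add_(s_param.detach(), alpha=1 - decay)
    # buffers (e.g. BatchNorm running statistics) are simply copied from the student
    for t_buf, s_buf in zip(teacher.buffers(), student.buffers()):
        t_buf.copy_(s_buf)
```
The teacher obtained this way is never updated by back-propagation; it only produces the hard pseudo labels (class-wise thresholded predictions, cf. `get_pseudo_labels` in engine.py) that supervise the student on the unlabeled subset.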
98 | ## Train models
99 | + To prepare a teacher model with the weak and synthetic subsets, download [SP-SEDT(E=6)](https://drive.google.com/file/d/1TpR0YhmPxVYyJ0HOm1tn4AnYPqZe442-/view?usp=sharing) pretrained with only the dcase2019 unlabeled subset, put it in ./exp/dcase/model/, then run
100 | ```shell script
101 | python train_sedt.py
102 | --gpus $ngpu
103 | --dataname dcase
104 | --batch_size 64
105 | --n_weak 32
106 | --num_queries 20
107 | --enc_layers 6
108 | --dec_at
109 | --fusion_strategy 1
110 | --epochs 400
111 | --pretrain "Pretrained_SP_SEDT_unlabel"
112 | --weak_loss_coef 0.25
113 | --freq_mask
114 | --freq_shift
115 | --mix_up_ratio 0.6
116 | ```
117 | You can also download our [teacher model (EB_F1=40.88)](https://drive.google.com/file/d/15EGgn6tKnQ9AUHPzAzlEBeCzslMVPvLs/view?usp=sharing) and put it in ./exp/dcase/model/
118 | + To train SEDT in a semi-supervised manner with the additional unlabeled subset, run
119 | ```shell script
120 | python train_ss_sedt.py
121 | --gpus $ngpu
122 | --dataname dcase
123 | --num_queries 20
124 | --enc_layers 6
125 | --dec_at
126 | --fusion_strategy 1
127 | --epochs 400
128 | --teacher_model "teacher_model_40.88" # file name of the teacher model
129 | --weak_loss_coef 0.25
130 | --freq_mask
131 | --freq_shift
132 | --time_mask
133 | --focal_loss
134 | --mix_up_ratio 0.6
135 | ```
136 | ## Evaluate models
137 | Download our [SP-SEDT(E=6, Eb_F1=51.75)](https://drive.google.com/file/d/1e9x4ZY5WccoYmwhlErr5a1PhAtvdSPL0/view?usp=sharing), put it in ./exp/dcase/model/, then run
138 | ```shell script
139 | python train_ss_sedt.py
140 | --gpus $ngpu
141 | --dataname dcase
142 | --num_queries 20
143 | --enc_layers 6
144 | --dec_at
145 | --fusion_strategy 1
146 | --eval
147 | --info SP-SEDT-system
148 | ```
149 | ## Related papers
150 | ```
151 | @article{2021Sound,
152 | title={Sound Event Detection Transformer: An Event-based End-to-End Model for Sound Event Detection},
153 | author={ Ye, Zhirong and Wang, Xiangdong and Liu, Hong and Qian, Yueliang and Tao, Rui and Yan, Long and Ouchi, Kazushige },
154 | year={2021},
155 | }
156 | @article{2021SP,
157 | title={SP-SEDT: Self-supervised Pre-training for Sound Event Detection Transformer},
158 | author={ Ye, Z. and Wang, X. and Liu, H. and Qian, Y. and Tao, R. and Yan, L. and Ouchi, K. },
159 | journal={arXiv e-prints},
160 | year={2021},
161 | }
162 | ```
163 | 
-------------------------------------------------------------------------------- /audio_tag/backbone.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Backbone modules. 
4 | """ 5 | import torch 6 | import torchvision 7 | from torch import nn 8 | from torchvision.models._utils import IntermediateLayerGetter 9 | from sedt.backbone import FrozenBatchNorm2d 10 | 11 | 12 | 13 | class BackboneBase(nn.Module): 14 | 15 | def __init__(self, backbone: nn.Module, fix_backbone: bool, num_channels: int, pooling): 16 | super().__init__() 17 | for name, parameter in backbone.named_parameters(): 18 | # freeze backbone while training audio tag classifier 19 | if fix_backbone: 20 | parameter.requires_grad_(False) 21 | 22 | return_layers = {'layer4': "0"} 23 | return_layers[f"{pooling}pool_"] = str(int(return_layers["layer4"]) + 1) 24 | self.weak_label = torch.nn.Sequential( 25 | torch.nn.Linear(num_channels, 1000), 26 | torch.nn.ReLU(), 27 | torch.nn.Linear(1000, 10) 28 | ) 29 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 30 | self.num_channels = num_channels 31 | 32 | def forward(self, x): 33 | xs = self.body(x) 34 | x = xs['1'].flatten(1) # cnn_at 35 | at = self.weak_label(x) 36 | at = at.sigmoid() 37 | return at 38 | 39 | 40 | 41 | class Backbone(BackboneBase): 42 | """ResNet backbone with frozen BatchNorm.""" 43 | def __init__(self, name: str, 44 | fix_backbone: bool, 45 | dilation: bool, 46 | pooling: str, 47 | pretrained: bool 48 | ): 49 | backbone = nn.Sequential() 50 | backbone.add_module('conv0', nn.Conv2d(1, 3, 1)) 51 | resnet = getattr(torchvision.models, name)( 52 | replace_stride_with_dilation=[False, False, dilation], 53 | pretrained=pretrained, norm_layer=FrozenBatchNorm2d) 54 | for name, module in resnet.named_children(): 55 | if "avgpool" in name : 56 | if "max" in pooling : 57 | backbone.add_module('maxpool_', nn.AdaptiveMaxPool2d(output_size=(1, 1))) 58 | else: 59 | backbone.add_module('avgpool_', module) 60 | else: 61 | backbone.add_module(name, module) 62 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 63 | super().__init__(backbone, fix_backbone, num_channels, pooling) 64 | 65 | 66 | 67 | 68 | def build_backbone(args): 69 | model = Backbone(args.backbone, args.fix_backbone, args.dilation, args.pooling, args.pretrained) 70 | return model 71 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | """ 4 | @author: yzr 5 | @file: config.py 6 | @time: 2020/7/30 12:06 7 | """ 8 | import logging 9 | import math 10 | import os 11 | import pandas as pd 12 | 13 | # save directory 14 | dir_root = "./exp/" 15 | 16 | # data 17 | dcase_dir = "./data/dcase2019/" 18 | # # DESED Paths 19 | weak = dcase_dir + 'metadata/train/weak.tsv' 20 | synthetic = dcase_dir + 'metadata/train/synthetic_2019/soundscapes.tsv' 21 | unlabel = dcase_dir + 'metadata/train/unlabel_in_domain.tsv' 22 | dcase2018_task5 = dcase_dir + "metadata/train/dcase2018_task5.tsv" 23 | validation = dcase_dir + 'metadata/validation/validation.tsv' 24 | eval_desed = dcase_dir + "metadata/eval/public.tsv" 25 | # Useful because does not correspond to the tsv file path (metadata replace by audio), (due to subsets test/eval2018) 26 | audio_validation_dir = dcase_dir + '/audio/validation' 27 | 28 | # urbansound 29 | urbansed_dir = "./data/URBAN-SED_v2.0.0/" 30 | urban_train_tsv = urbansed_dir + "metadata/train.tsv" 31 | urban_valid_tsv = urbansed_dir + "metadata/validate.tsv" 32 | urban_eval_tsv = urbansed_dir + "metadata/test.tsv" 33 | 34 | 35 | max_len_seconds = 10. 
36 | noise_snr = 30 37 | 38 | # dcase features 39 | sample_rate = 16000 40 | n_window = 1024 41 | n_fft = 1024 42 | hop_size = 323 43 | n_mels = 64 44 | max_frames = math.ceil(max_len_seconds * sample_rate / hop_size) # 496 45 | 46 | # urbansound feature 47 | usample_rate = 44100 48 | un_fft = 2048 49 | un_window = int(0.04 * usample_rate) 50 | uhop_size = int(0.02 * usample_rate) 51 | un_mels = 64 52 | umax_frames = int(max_len_seconds * usample_rate / uhop_size) 53 | 54 | # Training 55 | checkpoint_epochs = None 56 | save_best = True 57 | early_stopping = True 58 | es_init_wait = 50 # es for early stopping 59 | in_memory = True 60 | 61 | # Classes 62 | file_path = os.path.abspath(os.path.dirname(__file__)) 63 | dcase_classes = pd.read_csv(os.path.join(file_path, validation), sep="\t").event_label.dropna().sort_values().unique() 64 | urban_classes = pd.read_csv(os.path.join(file_path, urban_train_tsv), 65 | sep="\t").event_label.dropna().sort_values().unique() 66 | 67 | # Logger 68 | terminal_level = logging.INFO 69 | 70 | # focal loss related 71 | alpha_fl = 0.5 72 | gamma_fl = float(1) 73 | 74 | -------------------------------------------------------------------------------- /data_utils/DataLoad.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | import random 6 | import warnings 7 | from PIL import ImageFilter 8 | from torch.utils.data import Dataset 9 | from torch.utils.data.sampler import Sampler 10 | from utilities.utils import to_cuda_if_available 11 | from utilities.Logger import create_logger 12 | import config as cfg 13 | from utilities.BoxTransforms import Compose 14 | 15 | torch.manual_seed(0) 16 | random.seed(0) 17 | 18 | 19 | class DataLoadDf(Dataset): 20 | """ Class derived from pytorch DESED 21 | Prepare the data to be use in a batch mode 22 | 23 | Args: 24 | df: pandas.DataFrame, the dataframe containing the set infromation (feat_filenames, labels), 25 | it should contain these columns : 26 | "feature_filename" 27 | "feature_filename", "event_labels" 28 | "feature_filename", "onset", "offset", "event_label" 29 | encode_function: function(), function which encode labels 30 | transform: function(), (Default value = None), function to be applied to the sample (pytorch transformations) 31 | return_indexes: bool, (Default value = False) whether or not to return indexes when use __getitem__ 32 | 33 | Attributes: 34 | df: pandas.DataFrame, the dataframe containing the set information (feat_filenames, labels, ...) 
35 | encode_function: function(), function which encode labels 36 | transform : function(), function to be applied to the sample (pytorch transformations) 37 | return_indexes: bool, whether or not to return indexes when use __getitem__ 38 | """ 39 | 40 | def __init__(self, df, encode_function=None, transform=None, return_indexes=False, in_memory=False, 41 | num_patches=None, sigma=0.26, mu= 0.2, fixed_patch_size=False): 42 | self.df = df 43 | self.encode_function = encode_function 44 | self.transform = transform 45 | self.return_indexes = return_indexes 46 | self.feat_filenames = df.feature_filename.drop_duplicates() 47 | self.filenames = df.filename.drop_duplicates() 48 | self.in_memory = in_memory 49 | self.num_patches = num_patches 50 | self.sigma = sigma 51 | self.mu = mu 52 | self.fixed_patch_size=fixed_patch_size 53 | self.logger = create_logger(__name__, terminal_level=cfg.terminal_level) 54 | if self.in_memory: 55 | self.features = {} 56 | 57 | def get_random_patch(self, feature): 58 | 59 | def get_random_center(i): 60 | return np.random.randint(int(t*i/2)+1, int(t*(1-i/2)))/t 61 | t, f = feature.shape 62 | 63 | if self.fixed_patch_size: 64 | l = np.asarray([128/t] * self.num_patches) 65 | else: 66 | l = self.mu + self.sigma * np.random.randn(5*self.num_patches) 67 | idx = [ i >= 0.05 and i < 0.8 for i in l] 68 | l = l[idx][:self.num_patches] 69 | c= [ get_random_center(i) for i in l] 70 | s, e = (c-l/2)*t, (c+l/2)*t 71 | s = [int(i) for i in s] 72 | if self.fixed_patch_size: 73 | e = [i+128 for i in s] 74 | else: 75 | e = [int(i) for i in e] 76 | boxes = [[(i+j)/(2*t), (j-i)/t] for i, j in zip(s, e)] 77 | return boxes 78 | 79 | 80 | def set_return_indexes(self, val): 81 | """ Set the value of self.return_indexes 82 | Args: 83 | val : bool, whether or not to return indexes when use __getitem__ 84 | """ 85 | self.return_indexes = val 86 | 87 | def get_feature_file_func(self, filename): 88 | """Get a feature file from a filename 89 | Args: 90 | filename: str, name of the file to get the feature 91 | 92 | Returns: 93 | numpy.array 94 | containing the features computed previously 95 | """ 96 | if not self.in_memory: 97 | data = np.load(filename).astype(np.float32) 98 | else: 99 | if self.features.get(filename) is None: 100 | data = np.load(filename).astype(np.float32) 101 | self.features[filename] = data 102 | else: 103 | data = self.features[filename] 104 | return data 105 | 106 | def __len__(self): 107 | """ 108 | Returns: 109 | int 110 | Length of the object 111 | """ 112 | length = len(self.feat_filenames) 113 | return length 114 | 115 | def get_sample(self, index): 116 | """From an index, get the features and the labels to create a sample 117 | 118 | Args: 119 | index: int, Index of the sample desired 120 | 121 | Returns: 122 | tuple 123 | Tuple containing the features and the labels (numpy.array, numpy.array) 124 | 125 | """ 126 | features = self.get_feature_file_func(self.feat_filenames.iloc[index]) 127 | 128 | # event_labels means weak labels, event_label means strong labels 129 | if "event_labels" in self.df.columns or {"onset", "offset", "event_label"}.issubset(self.df.columns): 130 | if "event_labels" in self.df.columns: 131 | label = self.df.iloc[index]["event_labels"] 132 | if pd.isna(label): 133 | label = [] 134 | if type(label) is str: 135 | if label == "": 136 | label = [] 137 | else: 138 | label = label.split(",") 139 | else: 140 | cols = ["onset", "offset", "event_label"] 141 | label = self.df[self.df.filename == self.filenames.iloc[index]][cols] 142 | if label.empty: 
143 | label = [] 144 | else: 145 | if self.num_patches : 146 | label = self.get_random_patch(features) 147 | else: 148 | label = "empty" 149 | 150 | if index == 0: 151 | self.logger.debug("label to encode: {}".format(label)) 152 | if self.encode_function is not None: 153 | # labels are a list of string or list of list [[label, onset, offset]] 154 | y = self.encode_function(label) 155 | else: 156 | y = label 157 | sample = features, y 158 | return sample 159 | 160 | def __getitem__(self, index): 161 | """ Get a sample and transform it to be used in a ss_model, use the transformations 162 | 163 | Args: 164 | index : int, index of the sample desired 165 | 166 | Returns: 167 | tuple 168 | Tuple containing the features and the labels (numpy.array, numpy.array) or 169 | Tuple containing the features, the labels and the index (numpy.array, numpy.array, int) 170 | 171 | """ 172 | sample = self.get_sample(index) 173 | 174 | if self.transform: 175 | sample = self.transform(sample) 176 | 177 | if self.return_indexes: 178 | sample = (sample, index) 179 | 180 | return sample 181 | 182 | def set_transform(self, transform): 183 | """Set the transformations used on a sample 184 | 185 | Args: 186 | transform: function(), the new transformations 187 | """ 188 | self.transform = transform 189 | 190 | def add_transform(self, transform): 191 | if type(self.transform) is not Compose: 192 | raise TypeError("To add transform, the transform should already be a compose of transforms") 193 | transforms = self.transform.add_transform(transform) 194 | return DataLoadDf(self.df, self.encode_function, transforms, self.return_indexes, self.in_memory) 195 | 196 | 197 | class ConcatDataset(Dataset): 198 | """ 199 | DESED to concatenate multiple datasets. 200 | Purpose: useful to assemble different existing datasets, possibly 201 | large-scale datasets as the concatenation operation is done in an 202 | on-the-fly manner. 
203 | 204 | Args: 205 | datasets : sequence, list of datasets to be concatenated 206 | """ 207 | 208 | @staticmethod 209 | def cumsum(sequence): 210 | r, s = [], 0 211 | for e in sequence: 212 | l = len(e) 213 | r.append(l + s) 214 | s += l 215 | return r 216 | 217 | @property 218 | def cluster_indices(self): 219 | cluster_ind = [] 220 | prec = 0 221 | for size in self.cumulative_sizes: 222 | cluster_ind.append(range(prec, size)) 223 | prec = size 224 | return cluster_ind 225 | 226 | def __init__(self, datasets): 227 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 228 | self.datasets = list(datasets) 229 | self.cumulative_sizes = self.cumsum(self.datasets) 230 | 231 | def __len__(self): 232 | return self.cumulative_sizes[-1] 233 | 234 | def __getitem__(self, idx): 235 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 236 | if dataset_idx == 0: 237 | sample_idx = idx 238 | else: 239 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 240 | return self.datasets[dataset_idx][sample_idx] 241 | 242 | @property 243 | def cummulative_sizes(self): 244 | warnings.warn("cummulative_sizes attribute is renamed to " 245 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 246 | return self.cumulative_sizes 247 | 248 | @property 249 | def df(self): 250 | df = self.datasets[0].df 251 | for dataset in self.datasets[1:]: 252 | df = pd.concat([df, dataset.df], axis=0, ignore_index=True, sort=False) 253 | return df 254 | 255 | 256 | class MultiStreamBatchSampler(Sampler): 257 | """Takes a dataset with cluster_indices property, cuts it into batch-sized chunks 258 | Drops the extra items, not fitting into exact batches 259 | Args: 260 | data_source : DESED, a DESED to sample from. Should have a cluster_indices property 261 | batch_size : int, a batch size that you would like to use later with Dataloader class 262 | shuffle : bool, whether to shuffle the data or not 263 | Attributes: 264 | data_source : DESED, a DESED to sample from. 
Should have a cluster_indices property 265 | batch_size : int, a batch size that you would like to use later with Dataloader class 266 | shuffle : bool, whether to shuffle the data or not 267 | """ 268 | 269 | def __init__(self, data_source, batch_sizes, shuffle=True): 270 | super(MultiStreamBatchSampler, self).__init__(data_source) 271 | self.data_source = data_source 272 | self.batch_sizes = batch_sizes 273 | l_bs = len(batch_sizes) 274 | nb_dataset = len(self.data_source.cluster_indices) 275 | assert l_bs == nb_dataset, "batch_sizes must be the same length as the number of datasets in " \ 276 | "the source {} != {}".format(l_bs, nb_dataset) 277 | self.shuffle = shuffle 278 | 279 | def __iter__(self): 280 | indices = self.data_source.cluster_indices 281 | if self.shuffle: 282 | for i in range(len(self.batch_sizes)): 283 | indices[i] = np.random.permutation(indices[i]) 284 | iterators = [] 285 | for i in range(len(self.batch_sizes)): 286 | iterators.append(grouper(indices[i], self.batch_sizes[i])) 287 | 288 | return (sum(subbatch_ind, ()) for subbatch_ind in zip(*iterators)) 289 | 290 | def __len__(self): 291 | val = np.inf 292 | for i in range(len(self.batch_sizes)): 293 | val = min(val, len(self.data_source.cluster_indices[i]) // self.batch_sizes[i]) 294 | return val 295 | 296 | 297 | def grouper(iterable, n): 298 | "Collect data into fixed-length chunks or blocks" 299 | # grouper('ABCDEFG', 3) --> ABC DEF" 300 | args = [iter(iterable)] * n 301 | return zip(*args) 302 | 303 | 304 | class data_prefetcher(): 305 | def __init__(self, loader, return_indexes=False): 306 | self.loader = iter(loader) 307 | self.stream = torch.cuda.Stream() 308 | self.return_index = return_indexes 309 | self.preload() 310 | 311 | def preload(self): 312 | try: 313 | if self.return_index: 314 | (self.next_input, self.next_target), self.next_index = next(self.loader) 315 | else: 316 | self.next_input, self.next_target = next(self.loader) 317 | self.next_index = None 318 | except StopIteration: 319 | self.next_input = None 320 | self.next_target = None 321 | self.next_index = None 322 | return 323 | with torch.cuda.stream(self.stream): 324 | self.next_input = to_cuda_if_available(self.next_input) 325 | self.next_target = to_cuda_if_available(self.next_target) 326 | 327 | def next(self): 328 | torch.cuda.current_stream().wait_stream(self.stream) 329 | input = self.next_input 330 | target = self.next_target 331 | index = self.next_index 332 | self.preload() 333 | if not self.return_index: 334 | return input, target 335 | else: 336 | return (input, target), index 337 | -------------------------------------------------------------------------------- /data_utils/collapse_event.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | """ 4 | @author: yzr 5 | @file: collapse_event.py 6 | @time: 2020/9/2 21:43 7 | """ 8 | import os, sys 9 | sys.path.append(os.path.abspath('.')) 10 | import pandas as pd 11 | from tqdm import tqdm 12 | import config as cfg 13 | 14 | def collapse(meta_df): 15 | df_new = pd.DataFrame() 16 | filenames = meta_df.filename.drop_duplicates() 17 | cols = ["onset", "offset", "event_label"] 18 | for f in tqdm(filenames): 19 | label = meta_df[meta_df.filename == f][cols] 20 | events = label.event_label.drop_duplicates() 21 | for e in events: 22 | time = label[label.event_label == e][["onset", "offset"]] 23 | time = time.sort_values(by='onset') 24 | time = time.reset_index(drop=True) 25 | i = 0 26 | while i < len(time): 27 | 
if i == 0: 28 | i += 1 29 | continue 30 | if time.loc[i, 'onset'] <= time.loc[i-1, 'offset']: 31 | time.loc[i-1, 'offset'] = max(time.loc[i, 'offset'], time.loc[i-1, 'offset']) 32 | time = time.drop(index=i).reset_index(drop=True) 33 | i = i-1 34 | i += 1 35 | time["event_label"] = e.strip() 36 | time["filename"] = f 37 | df_new = df_new.append(time, ignore_index=True) 38 | return df_new 39 | 40 | if __name__=='__main__': 41 | annotation_dir = os.path.join(cfg.urbansed_dir, "annotations") 42 | datasets = ["train", "validate", "test"] 43 | meta_dir = annotation_dir.replace("annotations", "metadata") 44 | os.makedirs(meta_dir, exist_ok=True) 45 | for dataset in datasets: 46 | df = pd.DataFrame(columns=["filename", "event_label", "onset", "offset"]) 47 | df_sub = pd.DataFrame() 48 | f_list = list(filter(lambda x: x.endswith(".txt") and not x.startswith("."), os.listdir(os.path.join(annotation_dir, dataset)))) 49 | for f in tqdm(f_list): 50 | fr = open(os.path.join(annotation_dir, dataset, f),'r') 51 | lines = fr.readlines() 52 | lines = [l.split('\t') for l in lines] 53 | df_sub = pd.DataFrame(lines, columns=["onset", "offset", "event_label"]) 54 | df_sub["filename"] = os.path.splitext(f)[0] + ".wav" 55 | df = df.append(df_sub,ignore_index=True) 56 | df = collapse(df) 57 | df = df[["filename", "event_label", "onset", "offset"]] 58 | df.to_csv(os.path.join(meta_dir, dataset+".tsv"), index=False, sep="\t", float_format="%.3f") 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | import time 4 | 5 | import inspect 6 | import torch 7 | import pandas as pd 8 | from data_utils.DataLoad import data_prefetcher 9 | from utilities.metrics import audio_tagging_results, compute_metrics 10 | from utilities.Logger import create_logger 11 | from utilities.distribute import reduce_dict, get_reduced_loss 12 | from utilities.mixup import mixup_data, mixup_label_unlabel 13 | from utilities.utils import MetricLogger, SmoothedValue, AverageMeter, to_cuda_if_available 14 | from collections import Counter 15 | import config as cfg 16 | import numpy as np 17 | 18 | 19 | def train(train_loader, model, criterion, optimizer, c_epoch, accumrating_gradient_steps, 20 | mask_weak=None, mask_strong=None, fine_tune=False, normalize=False, max_norm=0.1, mix_up_ratio=0): 21 | """ One epoch of a Mean Teacher model 22 | Args: 23 | train_loader: torch.utils.data.DataLoader, iterator of training batches for an epoch. 
24 | Should return a tuple: (input, labels) 25 | model: torch.Module, model to be trained 26 | criterion: 27 | optimizer: torch.Module, optimizer used to train the model 28 | c_epoch: int, the current epoch of training 29 | mask_weak: slice or list, mask the batch to get only the weak labeled data (used to calculate the loss) 30 | mask_strong: slice or list, mask the batch to get only the strong labeled data (used to calcultate the loss) 31 | adjust_lr: bool, Whether or not to adjust the learning rate during training (params in config) 32 | """ 33 | log = create_logger(__name__ + "/" + inspect.currentframe().f_code.co_name, terminal_level=cfg.terminal_level) 34 | metric_logger: MetricLogger = MetricLogger(delimiter=" ") 35 | metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}')) 36 | metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}')) 37 | log.debug("Nb batches: {}".format(len(train_loader))) 38 | end = time.time() 39 | data_time = AverageMeter() 40 | batch_time = AverageMeter() 41 | prefetcher = data_prefetcher(train_loader) 42 | batch_input, target = prefetcher.next() 43 | i = -1 44 | while batch_input is not None: 45 | i += 1 46 | # measure data loading time 47 | data_time.update(time.time() - end) 48 | global_step = c_epoch * len(train_loader) + i + 1 49 | 50 | if mix_up_ratio: 51 | batch_input, target, mask_strong_c, mask_weak_c = mixup_data(batch_input, target, mask_strong, mask_weak, mix_up_ratio, alpha=1) 52 | else: 53 | mask_weak_c, mask_strong_c = mask_weak, mask_strong 54 | 55 | # Outputs 56 | if 'patches' in target[0]: 57 | patches = [t['patches'] for t in target] 58 | patches = torch.stack(patches, dim=0) 59 | outputs = model(batch_input.decompose(), patches) 60 | else: 61 | outputs = model(batch_input) 62 | 63 | loss_dict, _ = criterion(outputs, target, mask_weak_c, mask_strong_c, fine_tune, normalize) 64 | weight_dict = criterion.weight_dict 65 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 66 | 67 | # reduce losses over all GPUs for logging purposes 68 | loss_value = get_reduced_loss(loss_dict, weight_dict, metric_logger) 69 | 70 | if not math.isfinite(loss_value): 71 | log.info("Loss is {}, stopping training".format(loss_value)) 72 | log.info(loss_dict) 73 | sys.exit(1) 74 | 75 | losses.backward() 76 | if (i + 1) % accumrating_gradient_steps == 0: 77 | if max_norm > 0: 78 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 79 | optimizer.step() 80 | optimizer.zero_grad() 81 | 82 | global_step += 1 83 | metric_logger.update(loss=loss_value) 84 | metric_logger.update(class_error=0) 85 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 86 | # measure elapsed time 87 | batch_time.update(time.time() - end) 88 | end = time.time() 89 | batch_input, target = prefetcher.next() 90 | # gather the stats from all processes 91 | metric_logger.synchronize_between_processes() 92 | log.info("Epoch:{} data_time:{:.3f}({:.3f}) batch_time:{:.3f}({:.3f})". 
93 | format(c_epoch, data_time.val, data_time.avg, batch_time.val, batch_time.avg)) 94 | log.info("Train averaged stats: \n" + str(metric_logger)) 95 | return loss_value 96 | 97 | def semi_train(train_loader, model, ema, criterion, optimizer, c_epoch, accumrating_gradient_steps, 98 | accumlating_ema_steps, postprocessor, 99 | mask_weak=None, mask_strong=None, fine_tune=False, normalize=False, max_norm=0.1, mask_unlabel=None, 100 | mask_label=None, fl=False, mix_up_ratio=0, classwise_threshold=None): 101 | log = create_logger(__name__ + "/" + inspect.currentframe().f_code.co_name, terminal_level=cfg.terminal_level) 102 | metric_logger: MetricLogger = MetricLogger(delimiter=" ") 103 | metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}')) 104 | metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}')) 105 | metric_logger.add_meter('loss', SmoothedValue(window_size=1, fmt='{value:.6f}')) 106 | log.debug("Nb batches: {}".format(len(train_loader))) 107 | end = time.time() 108 | data_time = AverageMeter() 109 | batch_time = AverageMeter() 110 | i = -1 111 | pseudo_labels_counter = Counter() 112 | for _, data in enumerate(train_loader): 113 | i += 1 114 | ((batch_input_teacher, batch_input_student), target) = to_cuda_if_available(data) 115 | data_time.update(time.time() - end) 116 | global_step = c_epoch * len(train_loader) + i + 1 117 | 118 | # split labeled and unlabeld data 119 | batch_input_labeled = batch_input_teacher[mask_label] 120 | target_labeled = target[mask_label] 121 | 122 | batch_input_unnlabel_teacher = batch_input_teacher[mask_unlabel] 123 | batch_input_unnlabel_student = batch_input_student[mask_unlabel] 124 | target_unlabeled = target[mask_unlabel] 125 | 126 | 127 | # train on labeled data like sedt 128 | if mix_up_ratio > 0: 129 | batch_input_labeled, target_labeled, mask_strong_c, mask_weak_c = mixup_data(batch_input_labeled, 130 | target_labeled, mask_strong, 131 | mask_weak, mix_up_ratio = mix_up_ratio, alpha=1) 132 | else: 133 | mask_weak_c, mask_strong_c = mask_weak, mask_strong 134 | labeled_outputs = model(batch_input_labeled) 135 | sup_loss_dict, _ = criterion(labeled_outputs, target_labeled, mask_weak_c, mask_strong_c, fine_tune, normalize, 136 | fl) 137 | weight_dict = criterion.weight_dict 138 | sup_losses = sum(sup_loss_dict[k] * weight_dict[k] for k in sup_loss_dict.keys() if k in weight_dict) 139 | sup_loss_value = get_reduced_loss(sup_loss_dict, weight_dict, metric_logger, prefix="sup_") 140 | 141 | 142 | # train on unlabeld data 143 | # teacher 144 | ema.apply_shadow() 145 | with torch.no_grad(): 146 | tea_outputs = model(batch_input_unnlabel_teacher) 147 | orig_unlabel_target_sizes = torch.stack([t["orig_size"] for t in target_unlabeled], dim=0) 148 | pseudo_labels = get_pseudo_labels(tea_outputs, postprocessor, orig_unlabel_target_sizes, target_unlabeled, 149 | pseudo_labels_counter, classwise_threshold=classwise_threshold) 150 | if mix_up_ratio > 0: 151 | batch_input_unnlabel_student, pseudo_labels = mixup_label_unlabel(batch_input_labeled, 152 | batch_input_unnlabel_student, target_labeled, 153 | pseudo_labels, alpha=1) 154 | ema.restore() 155 | 156 | # student 157 | st_outputs = model(batch_input_unnlabel_student) 158 | 159 | unsup_loss_dict, _ = criterion(st_outputs, pseudo_labels, None, 160 | slice(batch_input_unnlabel_student.tensors.size(0)), fine_tune, normalize, fl) 161 | 162 | unsup_losses = sum(unsup_loss_dict[k] * weight_dict[k] for k in unsup_loss_dict.keys() if k in weight_dict) 163 | 
unsup_loss_value = get_reduced_loss(unsup_loss_dict, weight_dict, metric_logger, prefix="unsup_") 164 | 165 | 166 | total_losses = sup_losses + unsup_losses 167 | if not (math.isfinite(total_losses)): 168 | print("Loss is infinite, stopping training") 169 | sys.exit(1) 170 | total_losses.backward() 171 | 172 | 173 | 174 | if (i + 1) % accumrating_gradient_steps == 0: 175 | if max_norm > 0: 176 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 177 | optimizer.step() 178 | optimizer.zero_grad() 179 | 180 | if (i + 1) % accumlating_ema_steps == 0: 181 | ema.update() 182 | 183 | global_step += 1 184 | metric_logger.update(class_error=0) 185 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 186 | metric_logger.update(loss=sup_loss_value + unsup_loss_value) 187 | # measure elapsed time 188 | batch_time.update(time.time() - end) 189 | end = time.time() 190 | 191 | # gather the stats from all processes 192 | log.info("Epoch:{} data_time:{:.3f}({:.3f}) batch_time:{:.3f}({:.3f})". 193 | format(c_epoch, data_time.val, data_time.avg, batch_time.val, batch_time.avg)) 194 | log.info("Train averaged stats: \n" + str(metric_logger)) 195 | log.info("class nums: " + str(pseudo_labels_counter)) 196 | return sup_loss_value + unsup_loss_value, pseudo_labels_counter 197 | 198 | 199 | def evaluate(model, criterion, postprocessors, dataloader, decoder, ref_df, fusion_strategy, at=True, cal_seg=False, cal_clip=False): 200 | logger = create_logger(__name__ + "/" + inspect.currentframe().f_code.co_name, terminal_level=cfg.terminal_level) 201 | audio_tag_dfs, dec_prediction_dfs = get_sedt_predictions(model, criterion, postprocessors, dataloader, decoder, fusion_strategy, at) 202 | 203 | 204 | if not audio_tag_dfs.empty: 205 | clip_metric = audio_tagging_results(ref_df, audio_tag_dfs) 206 | logger.info(f"AT Class-wise clip metrics \n {'=' * 50} \n {clip_metric}") 207 | 208 | metrics = {} 209 | logger.info(f"decoder output \n {'=' * 50}") 210 | for at_m, dec_pred in dec_prediction_dfs.items(): 211 | logger.info(f"Fusion strategy: {at_m}") 212 | event_macro_f1 = compute_metrics(dec_pred, ref_df, cal_seg=cal_seg, cal_clip=cal_clip) 213 | metrics[at_m] = event_macro_f1 214 | return metrics 215 | 216 | 217 | 218 | def get_sedt_predictions(model, criterion, postprocessors, dataloader, decoder, fusion_strategy, at=True): 219 | """ Get the predictions of a trained model on a specific set 220 | Args: 221 | model: torch.Module, a trained pytorch model (you usually want it to be in .eval() mode). 222 | dataloader: torch.utils.data.DataLoader, giving ((input_data, label), indexes) but label is not used here 223 | decoder: function, takes a numpy.array of shape (time_steps, n_labels) as input and return a list of lists 224 | of ("event_label", "onset", "offset") for each label predicted. 225 | 226 | Returns: 227 | dict of the different predictions with associated fusion_strategy 228 | """ 229 | logger = create_logger(__name__ + "/" + inspect.currentframe().f_code.co_name, terminal_level=cfg.terminal_level) 230 | # Init a dataframe per threshold 231 | metric_logger = MetricLogger(delimiter=" ") 232 | # metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 233 | epoch_time = time.time() 234 | decoding_time = 0. 
235 | dec_prediction_dfs = {} 236 | audio_tag_dfs = pd.DataFrame() 237 | for at_m in fusion_strategy: 238 | dec_prediction_dfs[at_m] = pd.DataFrame() 239 | 240 | # Get predictions 241 | prefetcher = data_prefetcher(dataloader, return_indexes=True) 242 | i = -1 243 | (input_data, targets), indexes = prefetcher.next() 244 | while input_data is not None: 245 | i += 1 246 | with torch.no_grad(): 247 | outputs = model(input_data) 248 | # ############## 249 | # compute losses 250 | # ############## 251 | weak_mask = None 252 | strong_mask = slice(len(input_data.tensors)) 253 | loss_dict, indices = criterion(outputs, targets, weak_mask, strong_mask) 254 | weight_dict = criterion.weight_dict 255 | 256 | # reduce losses over all GPUs for logging purposes 257 | loss_value = get_reduced_loss(loss_dict, weight_dict, metric_logger) 258 | 259 | # ################### 260 | # get decoder results 261 | # ################### 262 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 263 | if at: 264 | assert "at" in outputs 265 | audio_tags = outputs["at"] 266 | audio_tags = (audio_tags > 0.5).long() 267 | for j, audio_tag in enumerate(audio_tags): 268 | audio_tag_res = decoder.decode_weak(audio_tag) 269 | audio_tag_res = pd.DataFrame(audio_tag_res, columns=["event_label"]) 270 | audio_tag_res["filename"] = dataloader.dataset.filenames.iloc[indexes[j]] 271 | audio_tag_res["onset"] = 0 272 | audio_tag_res["offset"] = 0 273 | audio_tag_dfs = audio_tag_dfs.append(audio_tag_res) 274 | else: 275 | audio_tags = None 276 | 277 | decoding_start = time.time() 278 | for at_m in fusion_strategy: 279 | results = postprocessors['bbox'](outputs, orig_target_sizes, audio_tags=audio_tags, at_m=at_m) 280 | 281 | for j, res in enumerate(results): 282 | for item in res: 283 | res[item] = res[item].cpu() 284 | pred = decoder.decode_strong(res, threshold=0.5) 285 | pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset", "score"]) 286 | # Put them in seconds 287 | pred.loc[:, ["onset", "offset"]] = pred[["onset", "offset"]].clip(0, cfg.max_len_seconds) 288 | pred["filename"] = dataloader.dataset.filenames.iloc[indexes[j]] 289 | dec_prediction_dfs[at_m] = dec_prediction_dfs[at_m].append(pred, ignore_index=True) 290 | 291 | decoding_time += time.time() - decoding_start 292 | (input_data, targets), indexes = prefetcher.next() 293 | 294 | logger.info("Val averaged stats:" + metric_logger.__str__()) 295 | epoch_time = time.time() - epoch_time 296 | logger.info(f"val_epoch_time:{epoch_time} decoding_time:{decoding_time}") 297 | return audio_tag_dfs, dec_prediction_dfs 298 | 299 | 300 | def get_pseudo_labels(tea_outputs, postprocessor, orig_unlabel_target_sizes, target_unlabeled, pseudo_labels_counter, 301 | threshold=0.5, del_overlap=True, classwise_threshold=None): 302 | if "at" in tea_outputs: 303 | audio_tags = tea_outputs["at"] 304 | audio_tags = (audio_tags >= classwise_threshold).long() 305 | else: 306 | audio_tags = None 307 | 308 | results = postprocessor['bbox'](tea_outputs, orig_unlabel_target_sizes, audio_tags=audio_tags, at_m=1, is_semi=True, 309 | threshold=None) 310 | 311 | for i, result in enumerate(results): 312 | filter_class = classwise_threshold[result['labels']] 313 | filtered_idx_1 = result['scores'] >= filter_class # confidence score > threshold 314 | filtered_idx_2 = result['boxes'][:, 1] > 0.2 / orig_unlabel_target_sizes[0].item() # duration > 0.02 s 315 | filtered_idx = filtered_idx_1 & filtered_idx_2 316 | 317 | if not del_overlap: 318 | target_unlabeled[i]['labels'] = 
result['labels'][filtered_idx] 319 | target_unlabeled[i]['boxes'] = result['boxes'][filtered_idx] 320 | else: 321 | # delete overlapped event 322 | tmp_labels, tmp_boxes, tmp_scores = result['labels'][filtered_idx], result['boxes'][filtered_idx], \ 323 | result['scores'][filtered_idx] 324 | tmp_scores, indices = tmp_scores.sort(descending=True) 325 | x = tmp_boxes[:, 0] - tmp_boxes[:, 1] / 2 326 | y = tmp_boxes[:, 0] + tmp_boxes[:, 1] / 2 327 | keep = [] 328 | while indices.numel() > 0: 329 | if indices.numel() == 1: 330 | k = indices.item() 331 | keep.append(k) 332 | break 333 | else: 334 | k = indices[0].item() 335 | keep.append(k) 336 | cur_label = tmp_labels[k] 337 | x_max = x[indices[1:]].clamp(min=x[k]) 338 | y_min = y[indices[1:]].clamp(max=y[k]) 339 | overlap = (y_min - x_max).clamp(min=0) 340 | idx = ((overlap == 0) + (tmp_labels[indices[1:]] != cur_label.item())).nonzero().squeeze() 341 | if idx.numel() == 0: 342 | break 343 | indices = indices[idx + 1] 344 | target_unlabeled[i]['labels'] = tmp_labels[keep] 345 | target_unlabeled[i]['boxes'] = tmp_boxes[keep] 346 | pseudo_labels_counter.update(tmp_labels[keep].cpu().numpy().tolist()) 347 | 348 | return target_unlabeled 349 | 350 | def adjust_threshold(pseudo_labels_counter, origin_threshold): 351 | labels_num_dict = dict(sorted(dict(pseudo_labels_counter).items(), key=lambda x: x[0])) 352 | labels_num = np.array(list(labels_num_dict.values())) 353 | labels_ratio = torch.tensor(labels_num / np.sum(labels_num)) 354 | true_distribution = torch.tensor( 355 | [0.09915014, 0.02266289, 0.08050047, 0.13385269, 0.13456091, 0.01534466, 0.02219075, 0.05594901, 0.41406988, 356 | 0.0217186]) 357 | adjust_ratio = (labels_ratio / true_distribution) ** 0.7 358 | adjust_ratio = to_cuda_if_available(adjust_ratio) 359 | class_threshold = torch.clamp(adjust_ratio * origin_threshold, min=0.45, max=0.7) 360 | return class_threshold -------------------------------------------------------------------------------- /img/sedt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anaesthesiaye/sound_event_detection_transformer/03e2a64acfd499a549be3a81ac71c1b6ce87198a/img/sedt.png -------------------------------------------------------------------------------- /img/sp-sedt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Anaesthesiaye/sound_event_detection_transformer/03e2a64acfd499a549be3a81ac71c1b6ce87198a/img/sp-sedt.png -------------------------------------------------------------------------------- /sedt/__init__.py: -------------------------------------------------------------------------------- 1 | from utilities.utils import to_cuda_if_available 2 | from .spsedt import SPSEDT 3 | from .sedt import SEDT, SetCriterion, PostProcess 4 | from .backbone import build_backbone 5 | from .transformer import build_transformer, TransformerDecoder, TransformerDecoderLayer 6 | from .matcher import build_matcher 7 | 8 | def build_model(args): 9 | if args.self_sup: 10 | num_classes = 1 11 | else: 12 | num_classes = args.num_classes 13 | 14 | backbone = build_backbone(args) 15 | transformer = build_transformer(args) 16 | if args.self_sup: 17 | model = SPSEDT( 18 | backbone, 19 | transformer, 20 | num_classes=num_classes, 21 | num_queries=args.num_queries, 22 | aux_loss=args.aux_loss, 23 | feature_recon=args.feature_recon, 24 | query_shuffle=args.query_shuffle, 25 | num_patches=args.num_patches 26 | ) 27 | else: 28 | model = SEDT( 29 
| backbone, 30 | transformer, 31 | num_classes=num_classes, 32 | num_queries=args.num_queries, 33 | aux_loss=args.aux_loss, 34 | dec_at=args.dec_at, 35 | pooling=args.pooling 36 | ) 37 | matcher = build_matcher(args) 38 | weight_dict = {'loss_ce': args.ce_loss_coef, 'loss_bbox': args.bbox_loss_coef, 'loss_giou': args.giou_loss_coef} 39 | losses = ['labels', 'boxes', 'cardinality'] 40 | if not args.self_sup: 41 | if args.dec_at: 42 | weight_dict['loss_weak'] = args.weak_loss_coef 43 | losses += ['weak'] 44 | if args.pooling: 45 | weight_dict['loss_weak_p'] = args.weak_loss_p_coef 46 | else: 47 | if args.feature_recon: 48 | losses += ['feature'] 49 | weight_dict['loss_feature'] = 1 50 | 51 | # TODO this is a hack 52 | if args.aux_loss: 53 | aux_weight_dict = {} 54 | for i in range(args.dec_layers - 1): 55 | aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) 56 | weight_dict.update(aux_weight_dict) 57 | 58 | criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, 59 | eos_coef=args.eos_coef, losses=losses) 60 | criterion = to_cuda_if_available(criterion) 61 | postprocessors = {'bbox': PostProcess()} 62 | 63 | return model, criterion, postprocessors 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /sedt/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Backbone modules. 4 | """ 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision 9 | from torch import nn 10 | from torchvision.models._utils import IntermediateLayerGetter 11 | 12 | from utilities.utils import NestedTensor 13 | 14 | from .position_encoding import build_position_encoding 15 | 16 | 17 | class FrozenBatchNorm2d(torch.nn.Module): 18 | """ 19 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 20 | 21 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 22 | without which any other models than torchvision.models.resnet[18,34,50,101] 23 | produce nans. 
24 | """ 25 | 26 | def __init__(self, n): 27 | super(FrozenBatchNorm2d, self).__init__() 28 | self.register_buffer("weight", torch.ones(n)) 29 | self.register_buffer("bias", torch.zeros(n)) 30 | self.register_buffer("running_mean", torch.zeros(n)) 31 | self.register_buffer("running_var", torch.ones(n)) 32 | 33 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 34 | missing_keys, unexpected_keys, error_msgs): 35 | num_batches_tracked_key = prefix + 'num_batches_tracked' 36 | if num_batches_tracked_key in state_dict: 37 | del state_dict[num_batches_tracked_key] 38 | 39 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 40 | state_dict, prefix, local_metadata, strict, 41 | missing_keys, unexpected_keys, error_msgs) 42 | 43 | def forward(self, x): 44 | # move reshapes to the beginning 45 | # to make it fuser-friendly 46 | w = self.weight.reshape(1, -1, 1, 1) 47 | b = self.bias.reshape(1, -1, 1, 1) 48 | rv = self.running_var.reshape(1, -1, 1, 1) 49 | rm = self.running_mean.reshape(1, -1, 1, 1) 50 | eps = 1e-5 51 | scale = w * (rv + eps).rsqrt() 52 | bias = b - rm * scale 53 | return x * scale + bias 54 | 55 | 56 | class BackboneBase(nn.Module): 57 | 58 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 59 | super().__init__() 60 | for name, parameter in backbone.named_parameters(): 61 | if not train_backbone or 'conv0' not in name and 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: # conv(1,3) 62 | parameter.requires_grad_(False) 63 | if return_interm_layers: 64 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 65 | else: 66 | return_layers = {'layer4': "0"} 67 | 68 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 69 | self.num_channels = num_channels 70 | 71 | def forward(self, tensor_list: NestedTensor): 72 | if isinstance(tensor_list, NestedTensor): 73 | xs = self.body(tensor_list.tensors) 74 | # out: Dict[str, NestedTensor] = {} 75 | out = {} 76 | for name, x in xs.items(): 77 | if name == '1': # the value of "avgpool" 78 | continue 79 | m = tensor_list.mask 80 | assert m is not None 81 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 82 | out[name] = NestedTensor(x, mask) 83 | else: 84 | out = self.body(tensor_list) 85 | 86 | return out 87 | 88 | 89 | class Backbone(BackboneBase): 90 | """ResNet backbone with frozen BatchNorm.""" 91 | 92 | def __init__(self, name: str, 93 | train_backbone: bool, 94 | return_interm_layers: bool, 95 | dilation: bool 96 | ): 97 | backbone = nn.Sequential() 98 | resnet = getattr(torchvision.models, name)( 99 | replace_stride_with_dilation=[False, False, dilation], 100 | pretrained=True, norm_layer=FrozenBatchNorm2d) 101 | # strategy 1: add conv0 to change the channel of spectrogram 102 | backbone.add_module('conv0', nn.Conv2d(1, 3, 1)) 103 | # # strategy 2 : alter the kernel of conv1 from (3,64) to (1,64) 104 | # ori_conv1 = resnet.conv1 105 | # resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 106 | # resnet.conv1.weight = nn.Parameter(ori_conv1.weight.mean(dim=1).unsqueeze(dim=1)) 107 | for name, module in resnet.named_children(): 108 | if name == "avgpool": 109 | backbone.add_module('maxpool_', nn.AdaptiveMaxPool2d(output_size=(1, 1))) 110 | else: 111 | backbone.add_module(name, module) 112 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 113 | super().__init__(backbone, train_backbone, 
num_channels, return_interm_layers) 114 | 115 | 116 | class Joiner(nn.Sequential): 117 | def __init__(self, backbone, position_embedding): 118 | super().__init__(backbone, position_embedding) 119 | 120 | def forward(self, tensor_list: NestedTensor): 121 | if isinstance(tensor_list, NestedTensor): 122 | xs = self[0](tensor_list) 123 | # out: List[NestedTensor] = [] 124 | out = [] 125 | pos = [] 126 | for name, x in xs.items(): 127 | out.append(x) 128 | # position encoding 129 | pos.append(self[1](x).to(x.tensors.dtype)) 130 | return out, pos 131 | else: 132 | return list(self[0](tensor_list).values()) 133 | 134 | 135 | def build_backbone(args): 136 | position_embedding = build_position_encoding(args) 137 | train_backbone = args.lr_backbone > 0 138 | backbone = Backbone(args.backbone, train_backbone, False, args.dilation) 139 | model = Joiner(backbone, position_embedding) 140 | model.num_channels = backbone.num_channels 141 | return model 142 | -------------------------------------------------------------------------------- /sedt/matcher.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | """ 6 | Modules to compute the matching cost and solve the corresponding LSAP. 7 | """ 8 | 9 | import torch 10 | from scipy.optimize import linear_sum_assignment 11 | from torch import nn 12 | from collections import Counter 13 | from utilities.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 14 | import config as cfg 15 | 16 | 17 | class HungarianMatcher(nn.Module): 18 | """This class computes an assignment between the targets and the predictions of the network 19 | 20 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 21 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 22 | while the others are un-matched (and thus treated as non-objects). 
23 | """ 24 | 25 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, epsilon=0, alpha=100): 26 | """Creates the matcher 27 | 28 | Params: 29 | cost_class: This is the relative weight of the classification error in the matching cost 30 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 31 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 32 | """ 33 | super().__init__() 34 | self.cost_class = cost_class 35 | self.cost_bbox = cost_bbox 36 | self.cost_giou = cost_giou 37 | self.epsilon = epsilon 38 | self.alpha = alpha 39 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 40 | 41 | @torch.no_grad() 42 | def forward(self, outputs, targets, fine_tune=False, normalize=False, fl = False): 43 | """ Performs the matching 44 | 45 | Params: 46 | outputs: This is a dict that contains at least these entries: 47 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 48 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 49 | 50 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 51 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 52 | objects in the target) containing the class labels 53 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 54 | 55 | Returns: 56 | A list of size batch_size, containing tuples of (index_i, index_j) where: 57 | - index_i is the indices of the selected predictions (in order) 58 | - index_j is the indices of the corresponding selected targets (in order) 59 | For each batch element, it holds: 60 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 61 | """ 62 | bs, num_queries = outputs["pred_logits"].shape[:2] 63 | 64 | # We flatten to compute the cost matrices in a batch 65 | out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) if not fl else outputs["pred_logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] 66 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 67 | 68 | # Also concat the target labels and boxes 69 | tgt_ids = torch.cat([v["labels"][:len(v["boxes"])] for v in targets]) 70 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 71 | 72 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 73 | # but approximate it in 1 - proba[target class]. 74 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 
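        # When fl=True, the classification cost instead follows the focal-loss form computed below:
        #   cost_class[:, t] = alpha_fl * (1 - p_t)^gamma_fl * (-log(p_t)) - (1 - alpha_fl) * p_t^gamma_fl * (-log(1 - p_t))
        # with alpha_fl and gamma_fl taken from config.py (a small epsilon is added inside the logs for numerical stability).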
75 | if not fl: 76 | cost_class = -out_prob[:, tgt_ids] 77 | else: 78 | alpha_fl = cfg.alpha_fl 79 | gamma_fl = cfg.gamma_fl 80 | neg_cost_class = (1 - alpha_fl) * (out_prob ** gamma_fl) * (-(1 - out_prob + 1e-8).log()) 81 | pos_cost_class = alpha_fl * ((1 - out_prob) ** gamma_fl) * (-(out_prob + 1e-8).log()) 82 | cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] 83 | 84 | # Compute the L1 cost between boxes 85 | cost_bbox = torch.cdist(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox), p=1) 86 | 87 | # Compute the giou cost betwen boxes 88 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) 89 | 90 | # Final cost matrix 91 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 92 | C = C.view(bs, num_queries, -1).cpu() 93 | 94 | sizes = [len(v["boxes"]) for v in targets] 95 | indices1 = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 96 | # matching order by TLOSS matcher 97 | idx = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices1] 98 | Coef = [] 99 | if fine_tune: 100 | # matching order by LLoss 101 | out_bbox = out_bbox.unsqueeze(dim=1).repeat(1, tgt_bbox.shape[0], 1) 102 | tgt_bbox = tgt_bbox.unsqueeze(dim=0).repeat(out_bbox.shape[0], 1, 1) 103 | # C_l = torch.abs(box_cxcywh_to_xyxy(out_bbox) - box_cxcywh_to_xyxy(tgt_bbox)).max(-1)[0] 104 | C_l = self.cost_bbox * cost_bbox + self.cost_giou * cost_giou 105 | C_l = C_l.view(bs, num_queries, -1).cpu() 106 | indices2 = [c[i].min(-1) for i, c in enumerate(C_l.split(sizes, -1))] 107 | idx1 = idx 108 | idx = [] 109 | for i1, i2 in zip(idx1, indices2): 110 | i1 = list(i1) 111 | num_gt = len(i1[1]) 112 | reserved = i2[0] < self.epsilon 113 | idx1_res = reserved[i1[0]] == True 114 | i1[0] = i1[0][idx1_res] 115 | i1[1] = i1[1][idx1_res] 116 | reserved[i1[0]] = False 117 | reserved_index = torch.where(reserved == True)[0] 118 | random_del_index = torch.where(torch.rand(len(reserved_index)) > (self.alpha * num_gt / num_queries))[0] 119 | reserved[reserved_index[random_del_index]] = False 120 | idx += [(torch.cat([i1[0], torch.arange(num_queries)[reserved]], dim=-1), 121 | torch.cat([i1[1], torch.as_tensor(i2[1], dtype=torch.int64)[reserved]], dim=-1))] 122 | 123 | for i, (_, tgt) in enumerate(idx): 124 | if normalize: 125 | cur_list = tgt.tolist() 126 | num = Counter(cur_list) 127 | coef = torch.tensor([1 / num[i] for i in cur_list], dtype=torch.float32).cpu() 128 | elif "ratio" in targets[i]: 129 | coef = targets[i]["ratio"].cpu() 130 | else: 131 | coef = torch.tensor([1] * len(tgt), dtype=torch.float32).cpu() 132 | Coef.append(coef) 133 | return idx, Coef 134 | 135 | 136 | 137 | 138 | def build_matcher(args): 139 | return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, 140 | cost_giou=args.set_cost_giou, epsilon=args.epsilon, alpha=args.alpha) 141 | 142 | -------------------------------------------------------------------------------- /sedt/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 
4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | from utilities.utils import NestedTensor 9 | 10 | 11 | class PositionEmbeddingSine(nn.Module): 12 | """ 13 | This is a more standard version of the position embedding, very similar to the one 14 | used by the Attention is all you need paper, generalized to work on images. 15 | """ 16 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 17 | super().__init__() 18 | self.num_pos_feats = num_pos_feats 19 | self.temperature = temperature 20 | self.normalize = normalize 21 | if scale is not None and normalize is False: 22 | raise ValueError("normalize should be True if scale is passed") 23 | if scale is None: 24 | scale = 2 * math.pi 25 | self.scale = scale 26 | 27 | def forward(self, tensor_list: NestedTensor): 28 | x = tensor_list.tensors 29 | mask = tensor_list.mask 30 | assert mask is not None 31 | not_mask = ~mask 32 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 33 | # x_embed = not_mask.cumsum(2, dtype=torch.float32) 34 | if self.normalize: 35 | eps = 1e-6 36 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 37 | # x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 38 | 39 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 40 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 41 | 42 | # pos_x = x_embed[:, :, :, None] / dim_t 43 | pos_y = y_embed[:, :, :, None] / dim_t 44 | # pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | # pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 47 | return pos_y.permute(0, 3, 1, 2) 48 | 49 | 50 | class PositionEmbeddingLearned(nn.Module): 51 | """ 52 | Absolute pos embedding, learned. 
53 | """ 54 | def __init__(self, num_pos_feats=256): 55 | super().__init__() 56 | self.row_embed = nn.Embedding(50, num_pos_feats) 57 | self.col_embed = nn.Embedding(50, num_pos_feats) 58 | self.reset_parameters() 59 | 60 | def reset_parameters(self): 61 | nn.init.uniform_(self.row_embed.weight) 62 | nn.init.uniform_(self.col_embed.weight) 63 | 64 | def forward(self, tensor_list: NestedTensor): 65 | x = tensor_list.tensors 66 | h, w = x.shape[-2:] 67 | i = torch.arange(w, device=x.device) 68 | j = torch.arange(h, device=x.device) 69 | x_emb = self.col_embed(i) 70 | y_emb = self.row_embed(j) 71 | pos = torch.cat([ 72 | x_emb.unsqueeze(0).repeat(h, 1, 1), 73 | y_emb.unsqueeze(1).repeat(1, w, 1), 74 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 75 | return pos 76 | 77 | 78 | def build_position_encoding(args): 79 | # N_steps = args.hidden_dim // 2 80 | N_steps = args.hidden_dim 81 | if args.position_embedding in ('v2', 'sine'): 82 | # TODO find a better way of exposing other arguments 83 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 84 | elif args.position_embedding in ('v3', 'learned'): 85 | position_embedding = PositionEmbeddingLearned(N_steps) 86 | else: 87 | raise ValueError(f"not supported {args.position_embedding}") 88 | 89 | return position_embedding 90 | -------------------------------------------------------------------------------- /sedt/spsedt.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from UP-DETR (https://github.com/dddzg/up-detr) 3 | # Copyright (c) Tencent, Inc. and its affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | """ 6 | SP-SEDT model 7 | """ 8 | import torch 9 | from torch import nn 10 | from utilities.utils import NestedTensor 11 | from .sedt import SEDT, MLP 12 | 13 | 14 | class SPSEDT(SEDT): 15 | def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False, dec_at=False,feature_recon=True, 16 | query_shuffle=False, mask_ratio=0.1, num_patches=10, pooling=None): 17 | super().__init__(backbone, transformer, num_classes, num_queries, aux_loss, dec_at, pooling) 18 | hidden_dim = transformer.d_model 19 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 20 | self.patch2query = nn.Linear(backbone.num_channels, hidden_dim) 21 | self.num_patches = num_patches 22 | self.mask_ratio = mask_ratio 23 | self.feature_recon = feature_recon 24 | if self.feature_recon: 25 | self.feature_align = MLP(hidden_dim, hidden_dim, backbone.num_channels, 2) 26 | self.query_shuffle = query_shuffle 27 | assert num_queries % num_patches == 0 28 | query_per_patch = num_queries // num_patches 29 | self.attention_mask = torch.ones(self.num_queries, self.num_queries) * float('-inf') 30 | for i in range(num_patches): 31 | self.attention_mask[i * query_per_patch:(i + 1) * query_per_patch, 32 | i * query_per_patch:(i + 1) * query_per_patch] = 0 33 | 34 | def forward(self, samples: list, patches: torch.Tensor): 35 | # def forward(self, samples: NestedTensor, patches: torch.Tensor): 36 | batch_num_patches = patches.shape[1] 37 | samples = [s.cuda() for s in samples] 38 | patches = patches.cuda() 39 | if isinstance(samples, (list, torch.Tensor)): 40 | # samples = nested_tensor_from_tensor_list(samples) 41 | samples=NestedTensor(samples[0],samples[1]) 42 | feature, pos = self.backbone(samples) 43 | 44 | src, mask = feature[-1].decompose() 45 | 
assert mask is not None 46 | 47 | bs = patches.shape[0] 48 | patches = patches.flatten(0, 1) 49 | patches_feature = self.backbone(patches) 50 | patches_feature_gt = self.avgpool(patches_feature[-1]).flatten(1) 51 | 52 | # [num_queries, bs, hidden_dim] 53 | patches_feature = self.patch2query(patches_feature_gt) \ 54 | .view(bs, batch_num_patches, 1, -1) \ 55 | .repeat(1, 1, self.num_queries // self.num_patches, 1) \ 56 | .flatten(1, 2).permute(1, 0, 2) \ 57 | .contiguous() 58 | 59 | # only shuffle the event queries 60 | idx = torch.randperm(self.num_queries) if self.query_shuffle else torch.arange(self.num_queries) 61 | 62 | start = 1 if self.dec_at else 0 63 | if self.training: 64 | # for training, it uses fixed number of query patches. 65 | mask_query_patch = (torch.rand(self.num_queries, bs, 1, device=patches.device) > self.mask_ratio).float() 66 | decoder_input = self.query_embed.weight[start:, :].unsqueeze(1).repeat(1, bs, 1)[idx] # don't include audio query 67 | decoder_input += patches_feature * mask_query_patch + decoder_input 68 | hs, memory = self.transformer(self.input_proj(src), mask, decoder_input, pos[-1], 69 | decoder_mask=self.attention_mask.to(patches_feature.device)) 70 | else: 71 | # for test, it supports x query patches, where x<=self.num_queries. 72 | num_queries = batch_num_patches * self.num_queries // self.num_patches 73 | decoder_input = patches_feature + self.query_embed.weight[start:num_queries, :].unsqueeze(1).repeat(1, bs, 1) 74 | hs, memory = self.transformer(self.input_proj(src), mask, decoder_input, pos[-1], 75 | decoder_mask=self.attention_mask.to(patches_feature.device)[:num_queries, :num_queries]) 76 | 77 | outputs_class = self.class_embed(hs) 78 | outputs_coord = self.bbox_embed(hs).sigmoid() 79 | if self.feature_recon: 80 | outputs_feature = self.feature_align(hs) 81 | out = {'pred_logits': outputs_class[-1], 'pred_feature': outputs_feature[-1], 82 | 'gt_feature': patches_feature_gt, 83 | 'pred_boxes': outputs_coord[-1]} 84 | if self.aux_loss: 85 | out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, outputs_feature, patches_feature_gt) 86 | else: 87 | out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} 88 | if self.aux_loss: 89 | out['aux_outputs'] = super()._set_aux_loss(outputs_class, outputs_coord) 90 | 91 | return out 92 | 93 | def _set_aux_loss(self, outputs_class, outputs_coord, outputs_feature, backbone_out): 94 | return [{'pred_logits': a, 'pred_boxes' : b, 'pred_feature': c, 'gt_feature': backbone_out} 95 | for a, b, c in zip(outputs_class[:-1], outputs_coord[:-1], outputs_feature[:-1])] 96 | -------------------------------------------------------------------------------- /train_at.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | """ 4 | @author: yzr 5 | @file: train_at.py 6 | @time: 2020/12/1 14:52 7 | """ 8 | import torch 9 | import torch.nn as nn 10 | import inspect 11 | 12 | from utilities.metrics import audio_tagging_results 13 | from utilities.Logger import create_logger, set_logger 14 | from data_utils.SedData import SedData 15 | from utilities.FrameEncoder import ManyHotEncoder 16 | from utilities.FrameTransforms import get_transforms 17 | from data_utils.DataLoad import DataLoadDf, data_prefetcher, ConcatDataset 18 | from utilities.Scaler import Scaler 19 | from torch.utils.data import DataLoader 20 | from audio_tag.backbone import build_backbone 21 | from utilities.utils import to_cuda_if_available, SaveBest 
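# Overview: this script trains the CNN backbone on clip-level audio tagging. Weak (and weakly collapsed
# synthetic) labels are encoded as multi-hot vectors, the model is optimized with a BCE loss, and the
# checkpoint with the best clip-level macro F1 on the validation set is saved; that backbone can then
# initialize SEDT / SP-SEDT (see the README).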
22 | import datetime 23 | import pandas as pd 24 | import argparse 25 | import shutil 26 | import config as cfg 27 | from pprint import pprint 28 | import os 29 | 30 | 31 | def get_dfs(desed_dataset, dataname): 32 | if "urban" in dataname: 33 | train_df = desed_dataset.initialize_and_get_df(cfg.urban_train_tsv) 34 | valid_df = desed_dataset.initialize_and_get_df(cfg.urban_valid_tsv) 35 | eval_df = desed_dataset.initialize_and_get_df(cfg.urban_eval_tsv) 36 | return {"train": train_df, 37 | "val": valid_df, 38 | "test": eval_df} 39 | else: 40 | synthetic_df = desed_dataset.initialize_and_get_df(cfg.synthetic) 41 | validation_df = desed_dataset.initialize_and_get_df(cfg.validation, audio_dir=cfg.audio_validation_dir) 42 | weak_df = desed_dataset.initialize_and_get_df(cfg.weak) 43 | eval_df = desed_dataset.initialize_and_get_df(cfg.eval_desed) 44 | return {"weak": weak_df, 45 | "synthetic": synthetic_df, 46 | "val": validation_df, 47 | "test": eval_df} 48 | 49 | 50 | def train(model, train_loader, optim, c_epoch, grad_step, max_norm=0.1): 51 | loss_func = nn.BCELoss() 52 | prefetcher = data_prefetcher(train_loader) 53 | input, targets = prefetcher.next() 54 | i = -1 55 | while input is not None: 56 | output = model(input) 57 | loss = loss_func(output, targets) 58 | 59 | loss.backward() 60 | if i % grad_step == 0: 61 | if max_norm > 0: 62 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 63 | optim.step() 64 | optim.zero_grad() 65 | input, targets = prefetcher.next() 66 | print("Epoch:{} Loss:{} lr:{}".format(c_epoch, loss.item(), optim.param_groups[0]["lr"])) 67 | 68 | 69 | def evaluate(model, data_loader, decoder): 70 | logger.info("validation") 71 | loss_func = nn.BCELoss() 72 | audio_tag_dfs = pd.DataFrame() 73 | prefetcher = data_prefetcher(data_loader, return_indexes=True) 74 | (input, targets), indexes = prefetcher.next() 75 | i = -1 76 | while input is not None: 77 | i += 1 78 | indexes = indexes.numpy() 79 | with torch.no_grad(): 80 | output = model(input) 81 | loss = loss_func(output, targets) 82 | 83 | audio_tags = output 84 | audio_tags = (audio_tags > 0.5).long() 85 | for j, audio_tag in enumerate(audio_tags): 86 | audio_tag_res = decoder(audio_tag) 87 | audio_tag_res = pd.DataFrame(audio_tag_res, columns=["event_label"]) 88 | audio_tag_res["filename"] = data_loader.dataset.filenames.iloc[indexes[j]] 89 | audio_tag_res["onset"] = 0 90 | audio_tag_res["offset"] = 0 91 | audio_tag_dfs = audio_tag_dfs.append(audio_tag_res) 92 | (input, targets), indexes = prefetcher.next() 93 | if "event_labels" in data_loader.dataset.df.columns: 94 | reformat_df = pd.DataFrame() 95 | filenames = data_loader.dataset.filenames 96 | for file in filenames: 97 | labels = audio_tag_dfs[audio_tag_dfs['filename']==file].event_label.drop_duplicates().to_list() 98 | labels = ",".join(labels) 99 | df = pd.DataFrame([[file, labels]], columns=['filename', 'event_labels']) 100 | reformat_df = reformat_df.append(df) 101 | return reformat_df 102 | else: 103 | return audio_tag_dfs 104 | 105 | 106 | if __name__ == "__main__": 107 | torch.manual_seed(2020) 108 | 109 | parser = argparse.ArgumentParser(description="") 110 | # model param 111 | parser.add_argument("--pooling", choices=["max", "avg"], default="avg") 112 | parser.add_argument("--pretrained", action="store_false", default=True) 113 | ### 114 | parser.add_argument('--hidden_dim', default=256, type=int, 115 | help="Size of the embeddings (dimension of the transformer)") 116 | parser.add_argument('--backbone', default='resnet50', type=str, 117 | 
help="Name of the convolutional backbone to use") 118 | parser.add_argument('--dilation', action='store_false', default=True, 119 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 120 | # train param 121 | parser.add_argument("--nepochs", type=int, default=100) 122 | parser.add_argument("--batch_size", type=int, default=128) 123 | parser.add_argument("--grad_steps", type=int, default=1) 124 | parser.add_argument("--lr", type=float, default=0.0001) 125 | parser.add_argument("--lr_drop", type=int, default=20) 126 | parser.add_argument("--gpu", type=str, default="-1") 127 | parser.add_argument("--back_up", action="store_true", default=False) 128 | parser.add_argument("--fix_backbone", action="store_true", default=False) 129 | # data param 130 | parser.add_argument('--dataname', default='urbansed', choices=['urbansed', 'dcase']) 131 | 132 | f_args = parser.parse_args() 133 | os.environ["CUDA_VISIBLE_DEVICES"] = f_args.gpu 134 | 135 | store_dir = os.path.join(cfg.dir_root, f_args.dataname) 136 | code_dir = os.path.join(store_dir, "code") 137 | model_dir = os.path.join(store_dir, "model") 138 | os.makedirs(model_dir, exist_ok=True) 139 | model_name = f"backbone_{f_args.backbone}_{f_args.pooling}" 140 | if f_args.pretrained: 141 | model_name += '_pretrained' 142 | model_path = os.path.join(model_dir, model_name) 143 | set_logger(model_name) 144 | logger = create_logger(__name__ + "/" + inspect.currentframe().f_code.co_name, terminal_level=cfg.terminal_level) 145 | logger.info("Audio_Tag_Module") 146 | logger.info(f"starting time :{datetime.datetime.now()}") 147 | pprint(vars(f_args)) 148 | 149 | ################ 150 | # code back-up 151 | ################ 152 | current_time = datetime.datetime.now().strftime('%F_%H%M') 153 | if f_args.back_up: 154 | # code file path 155 | cur_code_dir = os.path.join(code_dir, f'{current_time}_{model_name}') 156 | if os.path.exists(cur_code_dir): 157 | shutil.rmtree(cur_code_dir) 158 | os.makedirs(cur_code_dir) 159 | this_dir = os.path.dirname(os.path.abspath(__file__)) 160 | for filename in os.listdir(this_dir): 161 | if filename in ['data', 'exp', 'log']: 162 | continue 163 | old_path = os.path.join(this_dir, filename) 164 | new_path = os.path.join(cur_code_dir, filename) 165 | if os.path.isdir(old_path): 166 | shutil.copytree(old_path, new_path) 167 | else: 168 | shutil.copyfile(old_path, new_path) 169 | 170 | 171 | # model 172 | model = build_backbone(f_args) 173 | model = to_cuda_if_available(model) 174 | logger.info(model) 175 | param_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 176 | logger.info("number of parameters in the model: {}".format(param_num)) 177 | 178 | # data preparation 179 | dataset = SedData(f_args.dataname, recompute_features=False, compute_log=False) 180 | dfs = get_dfs(dataset, f_args.dataname) 181 | if "urban" in f_args.dataname: 182 | encoder = ManyHotEncoder(cfg.urban_classes, n_frames=cfg.umax_frames) 183 | transformer = get_transforms(cfg.umax_frames, add_axis=0) 184 | else: 185 | encoder = ManyHotEncoder(cfg.dcase_classes, n_frames=cfg.max_frames) 186 | transformer = get_transforms(cfg.max_frames, add_axis=0) 187 | 188 | weak_data = DataLoadDf(dfs["weak"], encoder.encode_weak, transform=transformer) 189 | syn_data = DataLoadDf(dfs["synthetic"], encoder.encode_weak, transform=transformer) 190 | train_data = ConcatDataset([weak_data, syn_data]) 191 | scaler = Scaler() 192 | scaler.calculate_scaler(train_data) 193 | 194 | transformer = get_transforms(cfg.umax_frames if 
"urbansed" in f_args.dataname else cfg.max_frames, scaler=scaler, add_axis=0) 195 | weak_data = DataLoadDf(dfs["weak"], encoder.encode_weak, transform=transformer, in_memory=cfg.in_memory) 196 | syn_data = DataLoadDf(dfs["synthetic"], encoder.encode_weak, transform=transformer, in_memory=cfg.in_memory) 197 | val_data = DataLoadDf(dfs["val"], encoder.encode_weak, transform=transformer, return_indexes=True) 198 | test_data = DataLoadDf(dfs["test"], encoder.encode_weak, transform=transformer, return_indexes=True) 199 | 200 | train_loader = DataLoader(ConcatDataset([weak_data, syn_data]), batch_size=f_args.batch_size, shuffle=True, pin_memory=True) 201 | val_loader = DataLoader(val_data, batch_size=f_args.batch_size, shuffle=False, drop_last=False) 202 | test_loader = DataLoader(test_data, batch_size=f_args.batch_size, shuffle=False, drop_last=False) 203 | 204 | validation_labels_df = dfs["val"].drop("feature_filename", axis=1) 205 | test_labels_df = dfs["test"].drop("feature_filename", axis=1) 206 | 207 | 208 | optim = torch.optim.Adam(model.parameters(), lr=f_args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, 209 | amsgrad=True) 210 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optim, f_args.lr_drop) 211 | best_saver = SaveBest("sup") 212 | # train 213 | state = {"model": model.state_dict(), "epoch": 0} 214 | for epoch in range(f_args.nepochs): 215 | model.train() 216 | train(model, train_loader, optim, epoch, f_args.grad_steps) 217 | lr_scheduler.step() 218 | model = model.eval() 219 | 220 | 221 | audio_tag_df = evaluate(model, val_loader, encoder.decode_weak) 222 | clip_metric = audio_tagging_results(validation_labels_df, audio_tag_df) 223 | clip_macro_f1 = clip_metric.loc['avg', 'f'] 224 | print("AT Class-wise clip metrics") 225 | print("=" * 50) 226 | print(clip_metric) 227 | # print("clip_macro_metrics:" + f'{clip_metric.values.mean():.3f}') 228 | state["model"] = model.state_dict() 229 | state["epoch"] = epoch 230 | # save best model 231 | if best_saver.apply(clip_macro_f1): 232 | torch.save(state, model_path) 233 | state = torch.load(model_path, map_location=torch.device("cpu") if not torch.cuda.is_available() else None) 234 | model.load_state_dict(state['model']) 235 | logger.info(f"testing model of epoch {state['epoch']} at {model_path}") 236 | model.eval() 237 | audio_tag_df = evaluate(model, val_loader, encoder.decode_weak) 238 | clip_metric = audio_tagging_results(validation_labels_df, audio_tag_df) 239 | clip_macro_f1 = clip_metric.loc['avg', 'f'] 240 | print("AT Class-wise clip metrics on validation set") 241 | print("=" * 50) 242 | print(clip_metric) 243 | 244 | audio_tag_df = evaluate(model, test_loader, encoder.decode_weak) 245 | clip_metric = audio_tagging_results(test_labels_df, audio_tag_df) 246 | clip_macro_f1 = clip_metric.loc['avg', 'f'] 247 | print("AT Class-wise clip metrics on test set") 248 | print("=" * 50) 249 | print(clip_metric) 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | -------------------------------------------------------------------------------- /train_sedt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | """ 4 | @author: yzr 5 | @file: train_sedt.py 6 | @time: 2020/8/1 10:43 7 | """ 8 | import argparse 9 | import datetime 10 | import inspect 11 | import os 12 | from pprint import pprint 13 | import numpy as np 14 | import torch 15 | 16 | from data_utils.SedData import SedData, get_dfs 17 | from torch.utils.data import DataLoader 18 | from 
data_utils.DataLoad import DataLoadDf, ConcatDataset, MultiStreamBatchSampler 19 | from engine import train, evaluate 20 | import config as cfg 21 | from utilities.Logger import create_logger, set_logger 22 | from utilities.Scaler import Scaler 23 | from utilities.utils import SaveBest, collate_fn, back_up_code, EarlyStopping 24 | from utilities.BoxEncoder import BoxEncoder 25 | from utilities.BoxTransforms import get_transforms as box_transforms 26 | from sedt import build_model 27 | 28 | def get_parser(): 29 | parser = argparse.ArgumentParser(description="") 30 | # dataset parameters 31 | parser.add_argument('--num_classes', default=10, type=int) 32 | parser.add_argument('--dataname', default='dcase', choices=['urbansed', 'dcase']) 33 | parser.add_argument('--synthetic', dest='synthetic', action='store_true', default=True, 34 | help="using synthetic labels during training") 35 | parser.add_argument('--weak', dest='weak', action='store_false', default=True, 36 | help="Not using weak labels during training") 37 | 38 | # train parameters 39 | parser.add_argument('--lr', default=1e-4, type=float) 40 | parser.add_argument('--lr_backbone', default=1e-4, type=float) 41 | parser.add_argument('--batch_size', default=64, type=int) 42 | parser.add_argument('--n_weak', default=16, type=int) 43 | parser.add_argument('--accumrating_gradient_steps', default=1, type=int) 44 | parser.add_argument('--adjust_lr', action='store_false', default=True) 45 | parser.add_argument('--weight_decay', default=1e-4, type=float) 46 | parser.add_argument('--eval', action="store_true", help='evaluate existing model') 47 | parser.add_argument('--epochs', default=400, type=int) 48 | parser.add_argument('--epochs_ls', default=400, type=int, help='number of epochs for learning stage') 49 | parser.add_argument('--checkpoint_epochs', default=0, type=int, help='save model every checkpoint_epochs') 50 | parser.add_argument('--lr_drop', default=200, type=int) 51 | parser.add_argument('--fine_tune', action="store_true", default=False) 52 | parser.add_argument('--normalize', action="store_true", default=False) 53 | parser.add_argument('--clip_max_norm', default=0.1, type=float, help='gradient clipping max norm') 54 | 55 | # data augmentation parameters 56 | parser.add_argument("--mix_up_ratio", type=float, default=0, 57 | help="the ratio of data to be mixed up during training") 58 | parser.add_argument("--time_mask", action="store_true", default=False, 59 | help="perform time mask during training") 60 | parser.add_argument("--freq_mask", action="store_true", default=False, 61 | help="perform frequency mask during training") 62 | parser.add_argument("--freq_shift", action="store_true", default=False, 63 | help="perform frequency shift during training") 64 | 65 | # model parameters 66 | parser.add_argument('--self_sup', dest='self_sup', action='store_true') 67 | parser.add_argument('--gpus', type=str, default='0') 68 | parser.add_argument('--pretrain', default='', help='initialized from the pre-training model') 69 | parser.add_argument('--resume', default='', help='resume training from specific model') 70 | parser.add_argument("--dec_at", action="store_true", default=False, help="add audio tagging branch") 71 | parser.add_argument("--fusion_strategy", default=[1], nargs='+', type=int) 72 | parser.add_argument("--pooling", type=str, default=None, choices=('max', 'avg', 'attn', 'weighted_sum')) 73 | # * Backbone 74 | parser.add_argument('--backbone', default='resnet50', type=str, 75 | help="Name of the convolutional backbone to use") 
76 | parser.add_argument('--dilation', action='store_false', default=True, 77 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 78 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 79 | help="Type of positional embedding to use on top of the image features") 80 | 81 | # * Transformer 82 | parser.add_argument('--enc_layers', default=3, type=int, 83 | help="Number of encoding layers in the transformer") 84 | parser.add_argument('--dec_layers', default=3, type=int, 85 | help="Number of decoding layers in the transformer") 86 | parser.add_argument('--idim', default=128, type=int, 87 | help="Size of the transformer input") 88 | parser.add_argument('--dim_feedforward', default=2048, type=int, 89 | help="Intermediate size of the feedforward layers in the transformer blocks") 90 | parser.add_argument('--hidden_dim', default=256, type=int, 91 | help="Size of the embeddings (dimension of the transformer)") 92 | parser.add_argument('--dropout', default=0.1, type=float, 93 | help="Dropout applied in the transformer") 94 | parser.add_argument('--nheads', default=8, type=int, 95 | help="Number of attention heads inside the transformer's attentions") 96 | parser.add_argument('--num_queries', default=20, type=int, 97 | help="Number of query slots") 98 | parser.add_argument('--pre_norm', action='store_false', default=True) 99 | parser.add_argument('--input_layer', default="linear", type=str, 100 | help="input layer type in the transformer") 101 | 102 | # Loss 103 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 104 | help="Disables auxiliary decoding losses (loss at each layer)") 105 | # * Matcher 106 | parser.add_argument('--set_cost_class', default=1, type=float, 107 | help="Class coefficient in the matching cost") 108 | parser.add_argument('--set_cost_bbox', default=5, type=float, 109 | help="L1 box coefficient in the matching cost") 110 | parser.add_argument('--set_cost_giou', default=2, type=float, 111 | help="giou box coefficient in the matching cost") 112 | parser.add_argument('--epsilon', default=1, type=float) 113 | parser.add_argument('--alpha', default=1, type=float) 114 | # * Loss coefficients 115 | parser.add_argument('--dice_loss_coef', default=1, type=float) 116 | parser.add_argument('--bbox_loss_coef', default=5, type=float) 117 | parser.add_argument('--giou_loss_coef', default=2, type=float) 118 | parser.add_argument('--eos_coef', default=0.1, type=float, 119 | help="Relative classification weight of the no-object class") 120 | parser.add_argument('--weak_loss_coef', default=1, type=float) 121 | parser.add_argument('--weak_loss_p_coef', default=1, type=float) 122 | parser.add_argument('--ce_loss_coef', default=1, type=float) 123 | 124 | parser.add_argument('--info', default=None, type=str) # experiment information 125 | parser.add_argument('--back_up', action='store_true', default=False, 126 | help="store current code") 127 | parser.add_argument('--log', action='store_false', default=True, 128 | help="generate log file for this experiment") 129 | return parser 130 | 131 | 132 | if __name__ == '__main__': 133 | torch.manual_seed(2020) 134 | np.random.seed(2020) 135 | parser = get_parser() 136 | f_args = parser.parse_args() 137 | if f_args.eval: 138 | f_args.epochs = 0 139 | assert f_args.info, "Don't give the model information to be evaluated" 140 | if f_args.info is None: 141 | f_args.info = 
f"{f_args.dataname}_atloss_{f_args.weak_loss_coef}_atploss_{f_args.weak_loss_p_coef}_enc_{f_args.enc_layers}_pooling_{f_args.pooling}_{f_args.fusion_strategy}" 142 | if f_args.pretrain: 143 | f_args.info += "_" + f_args.pretrain 144 | if f_args.log: 145 | set_logger(f_args.info) 146 | logger = create_logger(__name__ + "/" + inspect.currentframe().f_code.co_name, terminal_level=cfg.terminal_level) 147 | logger.info("Sound Event Detection Transformer") 148 | logger.info(f"Starting time: {datetime.datetime.now()}") 149 | os.environ["CUDA_VISIBLE_DEVICES"] = f_args.gpus 150 | 151 | if 'dcase' in f_args.dataname: 152 | f_args.num_queries=20 153 | pprint(vars(f_args)) 154 | store_dir = os.path.join(cfg.dir_root, f_args.dataname) 155 | saved_model_dir = os.path.join(store_dir, "model") 156 | os.makedirs(saved_model_dir, exist_ok=True) 157 | if f_args.back_up: 158 | back_up_code(store_dir, f_args.info) 159 | 160 | # ############## 161 | # DATA 162 | # ############## 163 | dataset = SedData(f_args.dataname, recompute_features=False, compute_log=False) 164 | dfs = get_dfs(dataset, f_args.dataname) 165 | 166 | 167 | # Normalisation per audio or on the full dataset 168 | add_axis_conv = 0 169 | scaler = Scaler() 170 | scaler_path = os.path.join(store_dir, f_args.dataname + ".json") 171 | if f_args.dataname == 'urbansed': 172 | label_encoder = BoxEncoder(cfg.urban_classes, seconds=cfg.max_len_seconds) 173 | transforms = box_transforms(cfg.umax_frames, add_axis=add_axis_conv) 174 | encod_func = label_encoder.encode_strong_df 175 | train_data = DataLoadDf(dfs['train'], encod_func, transforms) 176 | train_data = ConcatDataset([train_data]) 177 | else: 178 | label_encoder = BoxEncoder(cfg.dcase_classes, seconds=cfg.max_len_seconds) 179 | transforms = box_transforms(cfg.max_frames, add_axis=add_axis_conv) 180 | encod_func = label_encoder.encode_strong_df 181 | weak_data = DataLoadDf(dfs["weak"], encod_func, transforms) 182 | train_synth_data = DataLoadDf(dfs["synthetic"], encod_func, transforms) 183 | train_data = ConcatDataset([weak_data, train_synth_data]) 184 | if os.path.isfile(scaler_path): 185 | logger.info('loading scaler from {}'.format(scaler_path)) 186 | scaler.load(scaler_path) 187 | else: 188 | scaler.calculate_scaler(train_data) 189 | scaler.save(scaler_path) 190 | logger.debug(f"scaler mean: {scaler.mean_}") 191 | 192 | 193 | if f_args.dataname == 'urbansed': 194 | transforms = box_transforms(cfg.umax_frames, scaler, add_axis_conv, time_mask=f_args.time_mask, 195 | freq_mask=f_args.freq_mask, freq_shift=f_args.freq_shift,) 196 | transforms_valid = box_transforms(cfg.umax_frames, scaler, add_axis_conv) 197 | train_data = DataLoadDf(dfs["train"], encod_func, transform=transforms, in_memory=cfg.in_memory) 198 | eval_data = DataLoadDf(dfs["eval"], encod_func, transform=transforms_valid, return_indexes=True) 199 | validation_data = DataLoadDf(dfs["validation"], encod_func, transform=transforms_valid, return_indexes=True) 200 | 201 | train_dataset = [train_data] 202 | batch_sizes = [f_args.batch_size] 203 | strong_mask = slice(batch_sizes[0]) 204 | weak_mask = None 205 | else: 206 | transforms = box_transforms(cfg.max_frames, scaler, add_axis_conv, time_mask=f_args.time_mask, 207 | freq_mask=f_args.freq_mask, freq_shift=f_args.freq_shift,) 208 | transforms_valid = box_transforms(cfg.max_frames, scaler, add_axis_conv) 209 | weak_data = DataLoadDf(dfs["weak"], encod_func, transforms, in_memory=cfg.in_memory) 210 | train_synth_data = DataLoadDf(dfs["synthetic"], encod_func, transforms, 
in_memory=cfg.in_memory) 211 | validation_data = DataLoadDf(dfs["validation"], encod_func, transform=transforms_valid, return_indexes=True) 212 | eval_data = DataLoadDf(dfs["eval"], encod_func, transform=transforms_valid, return_indexes=True) 213 | train_dataset = [train_synth_data, weak_data] 214 | batch_sizes = [f_args.batch_size-f_args.n_weak, f_args.n_weak] 215 | weak_mask = slice(batch_sizes[0], f_args.batch_size) 216 | strong_mask = slice(batch_sizes[0]) 217 | 218 | concat_dataset = ConcatDataset(train_dataset) 219 | sampler = MultiStreamBatchSampler(concat_dataset, batch_sizes=batch_sizes) 220 | training_loader = DataLoader(concat_dataset, batch_sampler=sampler, collate_fn=collate_fn, pin_memory=True) 221 | validation_dataloader = DataLoader(validation_data, batch_size=f_args.batch_size, collate_fn=collate_fn) 222 | eval_dataloader = DataLoader(eval_data, batch_size=f_args.batch_size, collate_fn=collate_fn) 223 | validation_labels_df = dfs["validation"].drop("feature_filename", axis=1) 224 | eval_labels_df = dfs["eval"].drop("feature_filename", axis=1) 225 | 226 | 227 | # ############## 228 | # Model 229 | # ############## 230 | model, criterion, postprocessors = build_model(f_args) 231 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 232 | logger.info(model) 233 | logger.info("number of parameters in the model: {}".format(pytorch_total_params)) 234 | param_dicts = [ 235 | {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, 236 | { 237 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], 238 | "lr": f_args.lr_backbone, 239 | }, 240 | ] 241 | 242 | 243 | if f_args.pretrain: 244 | logger.info('loading the self-supervised model') 245 | model_fname = os.path.join(saved_model_dir, f_args.pretrain) 246 | state = torch.load(model_fname, map_location=torch.device('cpu')) 247 | model_dict = model.state_dict() 248 | # There are no audio query in self-supervised model 249 | # only load the backbone, encoder, decoder and MLP for box prediction 250 | load_dict = state['model']['state_dict'] 251 | model_dict["query_embed.weight"][1:, :] = load_dict["query_embed.weight"] 252 | load_dict = {k: v for k, v in load_dict.items() if (k in model_dict and "class_embed" not in k and "query_embed" not in k)} 253 | model_dict.update(load_dict) 254 | model.load_state_dict(model_dict) 255 | 256 | start_epoch = 0 257 | if f_args.resume: 258 | model_fname = os.path.join(saved_model_dir, f_args.resume) 259 | if torch.cuda.is_available(): 260 | state = torch.load(model_fname) 261 | else: 262 | state = torch.load(model_fname, map_location=torch.device('cpu')) 263 | load_dict = state['model']['state_dict'] 264 | model.load_state_dict(load_dict) 265 | start_epoch = state['epoch'] 266 | logger.info('Resume training form epoch {}'.format(state['epoch'])) 267 | 268 | model = model.cuda() 269 | optim = torch.optim.AdamW(param_dicts, lr=f_args.lr, 270 | weight_decay=f_args.weight_decay) 271 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optim, f_args.lr_drop) 272 | if f_args.resume: 273 | optim.load_state_dict(state['optimizer']['state_dict']) 274 | 275 | state = { 276 | 'model': {"name": model.__class__.__name__, 277 | 'args': '', 278 | "kwargs": '', 279 | 'state_dict': model.state_dict()}, 280 | 281 | 'optimizer': {"name": optim.__class__.__name__, 282 | 'args': '', 283 | 'state_dict': optim.state_dict()}, 284 | } 285 | 286 | fusion_strategy = f_args.fusion_strategy 287 | best_saver = {} 288 | 289 
| for at_m in fusion_strategy: 290 | best_saver[at_m] = SaveBest("sup") 291 | 292 | if cfg.early_stopping is not None: 293 | early_stopping_call = EarlyStopping(patience=cfg.early_stopping, fusion_strategy=f_args.fusion_strategy, 294 | val_comp="sup", init_patience=cfg.es_init_wait) 295 | 296 | for epoch in range(start_epoch, f_args.epochs): 297 | model.train() 298 | if epoch == f_args.epochs_ls: 299 | logger.info("enter the fine-tuning stage") 300 | # load the best model of the learning stage 301 | try: 302 | model_fname = os.path.join(saved_model_dir, f"{f_args.info}_1_best") 303 | state = torch.load(model_fname) 304 | model.load_state_dict(state['model']['state_dict']) 305 | except: 306 | logger.info("No best model exists, fine-tune current model") 307 | # fix the learning rate as 1e-5 308 | f_args.adjust_lr = False 309 | f_args.fine_tune = True 310 | f_args.info += "_ft" 311 | 312 | loss_value = train(training_loader, model, criterion, optim, epoch, f_args.accumrating_gradient_steps, 313 | mask_weak=weak_mask, fine_tune=f_args.fine_tune, normalize=f_args.normalize, 314 | mask_strong=strong_mask, max_norm=0.1, mix_up_ratio=f_args.mix_up_ratio) 315 | if f_args.adjust_lr: 316 | lr_scheduler.step() 317 | 318 | # Update state 319 | state['model']['state_dict'] = model.state_dict() 320 | state['optimizer']['state_dict'] = optim.state_dict() 321 | state['epoch'] = epoch 322 | # Validation 323 | model = model.eval() 324 | logger.info("Metric on validation") 325 | metrics = evaluate(model, criterion, postprocessors, validation_dataloader, label_encoder, validation_labels_df, 326 | at=True, fusion_strategy=fusion_strategy) 327 | 328 | if cfg.save_best: 329 | for at_m, eb in metrics.items(): 330 | state[f'event_based_f1_{at_m}'] = eb 331 | if best_saver[at_m].apply(eb): 332 | model_fname = os.path.join(saved_model_dir, f"{f_args.info}_{at_m}_best") 333 | torch.save(state, model_fname) 334 | 335 | if cfg.early_stopping: 336 | if early_stopping_call.apply(eb): 337 | logger.warn("EARLY STOPPING") 338 | break 339 | 340 | if f_args.checkpoint_epochs > 0 and (epoch + 1) % f_args.checkpoint_epochs == 0: 341 | model_fname = os.path.join(saved_model_dir, f"{f_args.info}_{epoch}") 342 | torch.save(state, model_fname) 343 | 344 | if cfg.save_best or f_args.eval: 345 | for at_m in fusion_strategy: 346 | model_fname = os.path.join(saved_model_dir, f"{f_args.info}_{at_m}_best") 347 | if torch.cuda.is_available(): 348 | state = torch.load(model_fname) 349 | else: 350 | state = torch.load(model_fname, map_location=torch.device('cpu')) 351 | model.load_state_dict(state['model']['state_dict']) 352 | logger.info(f"testing model: {model_fname}, epoch: {state['epoch']}") 353 | 354 | model.eval() 355 | logger.info("Metric on validation") 356 | evaluate(model, criterion, postprocessors, validation_dataloader, label_encoder, validation_labels_df, 357 | at=True, fusion_strategy=[at_m], cal_seg=True, cal_clip=True) 358 | 359 | logger.info("Metric on eval") 360 | evaluate(model, criterion, postprocessors, eval_dataloader, label_encoder, eval_labels_df, 361 | at=True, fusion_strategy=[at_m], cal_seg=True, cal_clip=True) 362 | -------------------------------------------------------------------------------- /train_spsedt.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import inspect 3 | import os 4 | import numpy as np 5 | import torch 6 | 7 | from data_utils.SedData import SedData 8 | from torch.nn.parallel import DistributedDataParallel 9 | from torch.utils.data 
import DataLoader, DistributedSampler, RandomSampler, BatchSampler 10 | from data_utils.DataLoad import DataLoadDf 11 | from engine import train 12 | from train_sedt import get_parser 13 | from utilities.Logger import create_logger, set_logger 14 | from utilities.Scaler import Scaler 15 | from utilities.distribute import is_main_process, init_distributed_mode 16 | from utilities.utils import collate_fn, back_up_code 17 | from utilities.BoxEncoder import BoxEncoder 18 | from utilities.BoxTransforms import get_transforms as box_transforms 19 | from sedt import build_model 20 | import config as cfg 21 | 22 | 23 | def get_pretrain_data(desed_dataset, extra_data=False): 24 | unlabel_df = desed_dataset.initialize_and_get_df(cfg.unlabel) 25 | if extra_data: 26 | dcase2018_task5 = desed_dataset.initialize_and_get_df(cfg.dcase2018_task5) 27 | unlabel_df = unlabel_df.append(dcase2018_task5, ignore_index=True) 28 | return unlabel_df 29 | 30 | 31 | if __name__ == '__main__': 32 | torch.manual_seed(2020) 33 | np.random.seed(2020) 34 | 35 | parser = get_parser() 36 | # sp-sedt related parameters 37 | parser.add_argument('--num_patches', default=10, type=int, help="number of query patches") 38 | parser.add_argument('--feature_recon', action='store_true', default=False) 39 | parser.add_argument('--query_shuffle', action='store_true', default=False) 40 | parser.add_argument('--fixed_patch_size', default=False, action='store_true', 41 | help="use fixed size for each patch") 42 | parser.add_argument('--extra_data', default=False, action='store_true', 43 | help="use dcase2018 task5 data to pretrain") 44 | # distributed training parameters 45 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 46 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 47 | parser.add_argument('--local_rank', default=0, type=int) 48 | f_args = parser.parse_args() 49 | assert f_args.dataname == "dcase", "only support dcase dataset now" 50 | f_args.lr_backbone = 0 51 | init_distributed_mode(f_args) 52 | if f_args.info is None: 53 | f_args.info = f"pretrain_enc_{f_args.enc_layers}" 54 | if f_args.feature_recon: 55 | f_args.info += "_feature_recon" 56 | if f_args.fixed_patch_size: 57 | f_args.info += "_fixed_patch_size" 58 | if f_args.extra_data: 59 | f_args.info += "_extra_data" 60 | if f_args.log: 61 | set_logger(f_args.info) 62 | logger = create_logger(__name__ + "/" + inspect.currentframe().f_code.co_name, terminal_level=cfg.terminal_level) 63 | logger.info("Self-supervised Pre-training for Sound Event Detection Transformer") 64 | logger.info(f"Starting time: {datetime.datetime.now()}") 65 | os.environ["CUDA_VISIBLE_DEVICES"] = f_args.gpus 66 | 67 | logger.info(vars(f_args)) 68 | store_dir = os.path.join(cfg.dir_root, "dcase") 69 | saved_model_dir = os.path.join(store_dir, "model") 70 | os.makedirs(saved_model_dir, exist_ok=True) 71 | if f_args.back_up: 72 | back_up_code(store_dir, f_args.info) 73 | 74 | # ############## 75 | # DATA 76 | # ############## 77 | dataset = SedData("dcase", recompute_features=False, compute_log=False) 78 | unlabel_data = get_pretrain_data(dataset, extra_data=f_args.extra_data) 79 | 80 | # Normalisation per audio or on the full dataset 81 | add_axis_conv = 0 82 | scaler = Scaler() 83 | if f_args.extra_data: 84 | scaler_path = os.path.join(store_dir, "dcase_sp_bd.json") 85 | else: 86 | scaler_path = os.path.join(store_dir, "dcase_sp.json") 87 | num_class = 1 88 | label_encoder = BoxEncoder(num_class,
seconds=cfg.max_len_seconds, generate_patch=True) 89 | encod_func = label_encoder.encode_strong_df 90 | 91 | if os.path.isfile(scaler_path): 92 | logger.info('loading scaler from {}'.format(scaler_path)) 93 | scaler.load(scaler_path) 94 | else: 95 | transforms = box_transforms(cfg.max_frames, add_axis=add_axis_conv, crop_patch=f_args.self_sup, 96 | fixed_patch_size=f_args.fixed_patch_size) 97 | train_data = DataLoadDf(unlabel_data, label_encoder.encode_unlabel, transforms, 98 | num_patches=f_args.num_patches, fixed_patch_size=f_args.fixed_patch_size) 99 | scaler.calculate_scaler(train_data) 100 | scaler.save(scaler_path) 101 | 102 | logger.debug(f"scaler mean: {scaler.mean_}") 103 | transforms = box_transforms(cfg.max_frames, scaler, add_axis_conv, crop_patch=True, 104 | fixed_patch_size=f_args.fixed_patch_size) 105 | train_data = DataLoadDf(unlabel_data, label_encoder.encode_unlabel, transforms, 106 | num_patches=f_args.num_patches, fixed_patch_size=f_args.fixed_patch_size) 107 | strong_mask = slice(f_args.batch_size) 108 | weak_mask = slice(f_args.batch_size) 109 | 110 | if torch.cuda.device_count() > 1: 111 | train_sampler = DistributedSampler(train_data) 112 | else: 113 | train_sampler = RandomSampler(train_data) 114 | train_sampler = BatchSampler(train_sampler, f_args.batch_size, drop_last=True) 115 | training_loader = DataLoader(train_data, batch_sampler=train_sampler, collate_fn=collate_fn, pin_memory=True) 116 | 117 | 118 | # ############## 119 | # Model 120 | # ############## 121 | model, criterion, postprocessors = build_model(f_args) 122 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 123 | logger.info(model) 124 | logger.info("number of parameters in the model: {}".format(pytorch_total_params)) 125 | param_dicts = [ 126 | {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, 127 | { 128 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], 129 | "lr": f_args.lr_backbone, 130 | }, 131 | ] 132 | 133 | if f_args.pretrain: 134 | logger.info('loading the pretrained backbone for self-supervised training') 135 | model_fname = os.path.join(saved_model_dir, f_args.pretrain) 136 | state = torch.load(model_fname, map_location=torch.device('cpu')) 137 | model_dict = model.state_dict() 138 | load_dict = state['model'] 139 | load_dict = {'backbone.0.' + k: v for k, v in load_dict.items() if 140 | ('backbone.0.'
+ k in model_dict and "class_embed" not in k and "query_embed" not in k)} 141 | model_dict.update(load_dict) 142 | model.load_state_dict(model_dict) 143 | 144 | start_epoch = 0 145 | if f_args.resume: 146 | model_fname = os.path.join(saved_model_dir, f_args.resume) 147 | if torch.cuda.is_available(): 148 | state = torch.load(model_fname) 149 | else: 150 | state = torch.load(model_fname, map_location=torch.device('cpu')) 151 | load_dict = state['model']['state_dict'] 152 | model.load_state_dict(load_dict) 153 | start_epoch = state['epoch'] 154 | logger.info('Resume training form epoch {}'.format(state['epoch'])) 155 | 156 | model = model.cuda() 157 | if torch.cuda.device_count() > 1: 158 | model = DistributedDataParallel(model, device_ids=[f_args.gpu]) 159 | 160 | 161 | optim = torch.optim.AdamW(param_dicts, lr=f_args.lr, 162 | weight_decay=f_args.weight_decay) 163 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optim, f_args.lr_drop) 164 | if f_args.resume: 165 | optim.load_state_dict(state['optimizer']['state_dict']) 166 | 167 | state = { 168 | 'model': {"name": model.__class__.__name__, 169 | 'args': '', 170 | "kwargs": '', 171 | 'state_dict': model.state_dict()}, 172 | 173 | 'optimizer': {"name": optim.__class__.__name__, 174 | 'args': '', 175 | 'state_dict': optim.state_dict()}, 176 | } 177 | 178 | 179 | for epoch in range(start_epoch, f_args.epochs): 180 | model.train() 181 | 182 | loss_value = train(training_loader, model, criterion, optim, epoch, f_args.accumrating_gradient_steps, 183 | mask_weak=weak_mask, normalize=f_args.normalize, mask_strong=strong_mask, max_norm=0.1) 184 | if f_args.adjust_lr: 185 | lr_scheduler.step() 186 | # Validation 187 | model = model.eval() 188 | 189 | # Update state 190 | if is_main_process(): 191 | if torch.cuda.device_count() > 1: 192 | state['model']['state_dict'] = model.module.state_dict() 193 | else: 194 | state['model']['state_dict'] = model.state_dict() 195 | state['optimizer']['state_dict'] = optim.state_dict() 196 | state['epoch'] = epoch 197 | 198 | if f_args.checkpoint_epochs > 0 and (epoch + 1) % f_args.checkpoint_epochs == 0: 199 | model_fname = os.path.join(saved_model_dir, "pretrained_{}_loss_{}".format(f_args.info, epoch)) 200 | torch.save(state, model_fname) -------------------------------------------------------------------------------- /train_ss_sedt.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import inspect 3 | import os 4 | from pprint import pprint 5 | import numpy as np 6 | import torch 7 | 8 | from engine import evaluate, semi_train, adjust_threshold 9 | from data_utils.SedData import SedData, get_dfs 10 | from torch.utils.data import DataLoader 11 | from data_utils.DataLoad import DataLoadDf, ConcatDataset, MultiStreamBatchSampler 12 | import config as cfg 13 | from train_sedt import get_parser 14 | from utilities.Logger import create_logger, set_logger 15 | from utilities.Scaler import Scaler 16 | from utilities.utils import SaveBest, collate_fn, get_cosine_schedule_with_warmup, back_up_code, EarlyStopping 17 | from utilities.utils import to_cuda_if_available 18 | from utilities.BoxEncoder import BoxEncoder 19 | from utilities.BoxTransforms import get_transforms as box_transforms 20 | from sedt import build_model 21 | from utilities.utils import EMA 22 | 23 | 24 | 25 | if __name__ == '__main__': 26 | torch.manual_seed(2020) 27 | np.random.seed(2020) 28 | parser = get_parser() 29 | # semi-train 30 | parser.add_argument('--focal_loss', action="store_true", 
default=False) 31 | parser.add_argument('--ema_m', type=float, default=0.9996, help='ema momentum for eval_model') 32 | parser.add_argument('--semi_batch_size', default=64, type=int) 33 | parser.add_argument('--accumlating_ema_steps', default=1, type=int) 34 | parser.add_argument('--teacher_model', default=None, help='load teacher from specific model') 35 | parser.add_argument('--teacher_eval', help='load teacher model for evaluation', action="store_false", default=True) 36 | 37 | f_args = parser.parse_args() 38 | assert f_args.dataname == "dcase", "only support dcase dataset now" 39 | if f_args.eval: 40 | f_args.epochs = 0 41 | assert f_args.info, "Don't give the model information to be evaluated" 42 | if f_args.info is None: 43 | f_args.info = f"semi_supervised_{f_args.dataname}_atloss_{f_args.weak_loss_coef}_atploss_{f_args.weak_loss_p_coef}_enc_{f_args.enc_layers}_pooling_{f_args.pooling}_{f_args.fusion_strategy}" 44 | if f_args.log: 45 | set_logger(f_args.info) 46 | logger = create_logger(__name__ + "/" + inspect.currentframe().f_code.co_name, terminal_level=cfg.terminal_level) 47 | logger.info("Semi-supervised Learning for Sound Event Detection Transformer") 48 | logger.info(f"Starting time: {datetime.datetime.now()}") 49 | 50 | os.environ["CUDA_VISIBLE_DEVICES"] = f_args.gpus 51 | 52 | pprint(vars(f_args)) 53 | store_dir = os.path.join(cfg.dir_root, f_args.dataname) 54 | 55 | saved_model_dir = os.path.join(store_dir, "model") 56 | os.makedirs(saved_model_dir, exist_ok=True) 57 | if f_args.back_up: 58 | back_up_code(store_dir, f_args.info) 59 | 60 | # ############## 61 | # DATA 62 | # ############## 63 | dataset = SedData(f_args.dataname, recompute_features=False, compute_log=False) 64 | dfs = get_dfs(dataset, f_args.dataname, unlabel_data=True) 65 | 66 | # Normalisation per audio 67 | add_axis_conv = 0 68 | scaler = Scaler() 69 | scaler_path = os.path.join(store_dir, f_args.dataname + ".json") 70 | num_class = cfg.dcase_classes 71 | label_encoder = BoxEncoder(num_class, seconds=cfg.max_len_seconds) 72 | transforms = box_transforms(cfg.max_frames, add_axis=add_axis_conv) 73 | encod_func = label_encoder.encode_strong_df 74 | weak_data = DataLoadDf(dfs["weak"], encod_func, transforms) 75 | train_synth_data = DataLoadDf(dfs["synthetic"], encod_func, transforms) 76 | train_labeled_data = ConcatDataset([weak_data, train_synth_data]) 77 | 78 | if os.path.isfile(scaler_path): 79 | logger.info('loading scaler from {}'.format(scaler_path)) 80 | scaler.load(scaler_path) 81 | else: 82 | scaler.calculate_scaler(train_labeled_data) 83 | scaler.save(scaler_path) 84 | logger.debug(f"scaler mean: {scaler.mean_}") 85 | 86 | # prepare transforms 87 | transforms_noise = box_transforms(cfg.max_frames, scaler, add_axis_conv, 88 | noise_dict_params={"mean": 0., "snr": cfg.noise_snr}, 89 | freq_mask=f_args.freq_mask, freq_shift=f_args.freq_shift, 90 | time_mask=f_args.time_mask) 91 | transforms_valid = box_transforms(cfg.max_frames, scaler, add_axis_conv) 92 | 93 | 94 | # prepare train dataset 95 | semi_weak_data = DataLoadDf(dfs["weak"], encod_func, transforms_noise, in_memory=cfg.in_memory) 96 | semi_train_synth_data = DataLoadDf(dfs["synthetic"], encod_func, transforms_noise, in_memory=cfg.in_memory) 97 | unlabel_data = DataLoadDf(dfs["unlabel"], encod_func, transforms_noise, in_memory=cfg.in_memory) 98 | 99 | # prepare semi-supervised learning dataset, default: a batch contains 1/4 synthetic data, 1/4 weak data, 1/2 unlabel data 100 | train_semi_dataset = [semi_train_synth_data, semi_weak_data, 
unlabel_data] 101 | semi_batch_sizes = [f_args.semi_batch_size // 4, f_args.semi_batch_size // 4, 2 * f_args.semi_batch_size // 4] 102 | 103 | # prepare semi dataloader 104 | semi_concat_dataset = ConcatDataset(train_semi_dataset) 105 | semi_sampler = MultiStreamBatchSampler(semi_concat_dataset, batch_sizes=semi_batch_sizes) 106 | semi_training_loader = DataLoader(semi_concat_dataset, batch_sampler=semi_sampler, collate_fn=collate_fn, 107 | pin_memory=True, num_workers=0) 108 | 109 | # prepare data mask, use it to calculate loss and split labeled and unlabeled dataset 110 | semi_weak_mask = slice(semi_batch_sizes[0], semi_batch_sizes[0] + semi_batch_sizes[1]) 111 | semi_strong_mask = slice(semi_batch_sizes[0]) 112 | semi_label_mask = slice(semi_batch_sizes[0] + semi_batch_sizes[1]) 113 | semi_unlabel_mask = slice(semi_batch_sizes[0] + semi_batch_sizes[1], f_args.semi_batch_size) 114 | 115 | # prepare eval dataloader 116 | validation_data = DataLoadDf(dfs["validation"], encod_func, transform=transforms_valid, return_indexes=True) 117 | eval_data = DataLoadDf(dfs["eval"], encod_func, transform=transforms_valid, return_indexes=True) 118 | validation_dataloader = DataLoader(validation_data, batch_size=f_args.batch_size, collate_fn=collate_fn, 119 | num_workers=0) 120 | eval_dataloader = DataLoader(eval_data, batch_size=f_args.batch_size, collate_fn=collate_fn, num_workers=0) 121 | validation_labels_df = dfs["validation"].drop("feature_filename", axis=1) 122 | eval_labels_df = dfs["eval"].drop("feature_filename", axis=1) 123 | 124 | # ############## 125 | # Model 126 | # ############## 127 | model, criterion, postprocessors = build_model(f_args) 128 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 129 | logger.info(model) 130 | logger.info("number of parameters in the model: {}".format(pytorch_total_params)) 131 | param_dicts = [ 132 | {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]}, 133 | { 134 | "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad], 135 | "lr": f_args.lr_backbone, 136 | }, 137 | ] 138 | 139 | # load a well-trained model as teacher 140 | if not f_args.eval: 141 | assert f_args.teacher_model is not None, "please provide the teacher model" 142 | model_fname = os.path.join(saved_model_dir, f_args.teacher_model) 143 | if torch.cuda.is_available(): 144 | state = torch.load(model_fname) 145 | else: 146 | state = torch.load(model_fname, map_location=torch.device('cpu')) 147 | load_dict = state['model']['state_dict'] 148 | model.load_state_dict(load_dict) 149 | logger.info('Using teacher model: ' + model_fname) 150 | 151 | model = model.cuda() 152 | 153 | # ema formula 154 | ema = EMA(model, f_args.ema_m) 155 | ema.register() 156 | 157 | # optimizer and scheduler 158 | optim= torch.optim.AdamW(param_dicts, lr=f_args.lr, weight_decay=f_args.weight_decay) 159 | lr_scheduler = get_cosine_schedule_with_warmup(optim, f_args.epochs, num_warmup_steps=f_args.epochs * 0) 160 | 161 | state = { 162 | 'model': {"name": model.__class__.__name__, 163 | 'args': '', 164 | "kwargs": '', 165 | 'state_dict': model.state_dict()}, 166 | 167 | 'ema_model': {"name": ema.model.__class__.__name__, 168 | 'args': '', 169 | "kwargs": '', 170 | 'state_dict': ema.model.state_dict()}, 171 | 172 | 'optimizer': {"name": optim.__class__.__name__, 173 | 'args': '', 174 | 'state_dict': optim.state_dict()}, 175 | } 176 | 177 | fusion_strategy = f_args.fusion_strategy 178 | best_saver = {} 179 | 180 | 
for at_m in fusion_strategy: 181 | best_saver[at_m] = SaveBest("sup") 182 | if cfg.early_stopping is not None: 183 | early_stopping_call = EarlyStopping(patience=cfg.early_stopping, fusion_strategy=f_args.fusion_strategy, 184 | val_comp="sup", init_patience=cfg.es_init_wait) 185 | 186 | start_epoch = 0 187 | origin_threshold = torch.tensor([0.5] * f_args.num_classes) 188 | origin_threshold = to_cuda_if_available(origin_threshold) 189 | classwise_threshold = origin_threshold 190 | 191 | for epoch in range(start_epoch, f_args.epochs): 192 | # Train 193 | model.train() 194 | loss_value, pseudo_labels_counter = semi_train(semi_training_loader, model, ema, criterion, optim, epoch, 195 | f_args.accumrating_gradient_steps, f_args.accumlating_ema_steps, 196 | postprocessors, 197 | mask_weak=semi_weak_mask, fine_tune=f_args.fine_tune, 198 | normalize=f_args.normalize, 199 | mask_strong=semi_strong_mask, max_norm=0.1, 200 | mask_unlabel=semi_unlabel_mask, mask_label=semi_label_mask, 201 | fl=f_args.focal_loss, mix_up_ratio=f_args.mix_up_ratio, 202 | classwise_threshold=classwise_threshold) 203 | 204 | classwise_threshold = adjust_threshold(pseudo_labels_counter, origin_threshold) 205 | 206 | if f_args.adjust_lr: 207 | lr_scheduler.step() 208 | 209 | # Validation 210 | model = model.eval() 211 | 212 | # Update state 213 | state['model']['state_dict'] = model.state_dict() # student 214 | ema.apply_shadow() 215 | state['ema_model']['state_dict'] = ema.model.state_dict() # teacher 216 | ema.restore() 217 | 218 | state['optimizer']['state_dict'] = optim.state_dict() 219 | state['epoch'] = epoch 220 | 221 | 222 | # Validation with real data 223 | if f_args.teacher_eval: 224 | logger.info("Using teacher model for validation \n") 225 | ema.apply_shadow() 226 | else: 227 | logger.info("Using student model for validation \n") 228 | 229 | metrics = evaluate(model, criterion, postprocessors, validation_dataloader, label_encoder, validation_labels_df, 230 | at=True, fusion_strategy=fusion_strategy) 231 | 232 | if f_args.teacher_eval: 233 | ema.restore() 234 | 235 | if cfg.save_best: 236 | for at_m, eb in metrics.items(): 237 | state[f'event_based_f1_{at_m}'] = eb 238 | if best_saver[at_m].apply(eb): 239 | model_fname = os.path.join(saved_model_dir, f"{f_args.info}_{at_m}_best") 240 | torch.save(state, model_fname) 241 | 242 | if cfg.early_stopping: 243 | if early_stopping_call.apply(eb): 244 | logger.warn("EARLY STOPPING") 245 | break 246 | 247 | if f_args.checkpoint_epochs > 0 and (epoch + 1) % f_args.checkpoint_epochs == 0: 248 | model_fname = os.path.join(saved_model_dir, "semi_train_{}_loss_{}".format(f_args.info, epoch)) 249 | torch.save(state, model_fname) 250 | 251 | if cfg.save_best or f_args.eval: 252 | for at_m in fusion_strategy: 253 | model_fname = os.path.join(saved_model_dir, f"{f_args.info}_{at_m}_best") 254 | if torch.cuda.is_available(): 255 | state = torch.load(model_fname) 256 | else: 257 | state = torch.load(model_fname, map_location=torch.device('cpu')) 258 | if f_args.teacher_eval: 259 | model.load_state_dict(state['ema_model']['state_dict']) 260 | logger.info(f"using teacher model for test...") 261 | else: 262 | model.load_state_dict(state['model']['state_dict']) 263 | logger.info(f"using student model for test...") 264 | logger.info(f"testing model: {model_fname}, epoch: {state['epoch']}") 265 | 266 | # ############## 267 | # Validation 268 | # ############## 269 | model.eval() 270 | logger.info("Metric on validation") 271 | evaluate(model, criterion, postprocessors, 
validation_dataloader, label_encoder, validation_labels_df, 272 | at=True, fusion_strategy=[at_m], cal_seg=True, cal_clip=True) 273 | 274 | logger.info("Metric on eval") 275 | evaluate(model, criterion, postprocessors, eval_dataloader, label_encoder, eval_labels_df, 276 | at=True, fusion_strategy=[at_m], cal_seg=True, cal_clip=True) -------------------------------------------------------------------------------- /utilities/BoxEncoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import config as cfg 4 | from dcase_util.data import DecisionEncoder 5 | from dcase_util.data import ProbabilityEncoder 6 | 7 | class BoxEncoder: 8 | """" 9 | Adapted after DecisionEncoder.find_contiguous_regions method in 10 | https://github.com/DCASE-REPO/dcase_util/blob/master/dcase_util/data/decisions.py 11 | 12 | Encode labels into numpy arrays where 1 correspond to presence of the class and 0 absence. 13 | Multiple 1 can appear on the same line, it is for multi label problem. 14 | Args: 15 | labels: list, the classes which will be encoded 16 | n_frames: int, (Default value = None) only useful for strong labels. The number of frames of a segment. 17 | Attributes: 18 | labels: list, the classes which will be encoded 19 | n_frames: int, only useful for strong labels. The number of frames of a segment. 20 | """ 21 | 22 | def __init__(self, labels, seconds, generate_patch=False): 23 | if type(labels) in [np.ndarray, np.array]: 24 | labels = labels.tolist() 25 | self.labels = labels 26 | self.seconds = seconds 27 | self.generate_patch = generate_patch 28 | 29 | def encode_unlabel(self, boxes): 30 | """ 31 | Args: 32 | labels: (c_list, l_list) 33 | Returns: 34 | 35 | """ 36 | y = {} 37 | y["labels"] = np.asarray([0]*len(boxes)) 38 | y["boxes"] = np.asarray(boxes) 39 | y["orig_size"] = np.asarray(self.seconds) 40 | y["patches"] = [] 41 | return y 42 | 43 | 44 | def encode_weak(self, labels): 45 | """ Encode a list of weak labels into a numpy array 46 | 47 | Args: 48 | labels: list, list of labels to encode (to a vector of 0 and 1) 49 | 50 | Returns: 51 | numpy.array 52 | A vector containing 1 for each label, and 0 everywhere else 53 | """ 54 | # useful for tensor empty labels 55 | y = {"labels": [], "boxes": [], "orig_size": []} 56 | if type(labels) is str: 57 | if labels == "empty": 58 | return y 59 | else: 60 | labels = labels.split(",") 61 | if type(labels) is pd.DataFrame: 62 | if labels.empty: 63 | labels = [] 64 | elif "event_label" in labels.columns: 65 | labels = labels["event_label"] 66 | if isinstance(self.labels, int): 67 | y[labels] = len(labels) * [0] 68 | else: 69 | for label in labels: 70 | if not pd.isna(label): 71 | i = int(self.labels.index(label)) 72 | y["labels"].append(i) 73 | y["labels"] = np.asarray(y["labels"]) 74 | y["boxes"] = np.asarray(y["boxes"]) 75 | y["orig_size"] = np.asarray(self.seconds) 76 | if self.generate_patch: 77 | y["patches"] = [] 78 | return y 79 | 80 | def encode_strong_df(self, label_df): 81 | """Encode a list (or pandas Dataframe or Serie) of strong labels, they correspond to a given filename 82 | 83 | Args: 84 | label_df: pandas DataFrame or Series, contains filename, onset (in frames) and offset (in frames) 85 | If only filename (no onset offset) is specified, it will return the event on all the frames 86 | onset and offset should be in frames 87 | Returns: 88 | numpy.array 89 | Encoded labels, 1 where the label is present, 0 otherwise 90 | """ 91 | y = {"labels": [], "boxes": [], 
"orig_size": []} 92 | assert self.seconds is not None, "n_seconds need to be specified when using strong encoder" 93 | if type(label_df) is str: 94 | if label_df == 'empty': 95 | pass 96 | elif type(label_df) is pd.DataFrame: 97 | if {"onset", "offset", "event_label"}.issubset(label_df.columns): 98 | for _, row in label_df.iterrows(): 99 | if not pd.isna(row["event_label"]): 100 | if isinstance(self.labels, int): 101 | i = 0 102 | else: 103 | i = int(self.labels.index(row["event_label"])) 104 | y["labels"].append(i) 105 | onset = float(row["onset"]) / self.seconds 106 | offset = float(row["offset"]) / self.seconds 107 | y["boxes"].append([(onset + offset) / 2, offset - onset]) 108 | 109 | elif type(label_df) in [pd.Series, list, np.ndarray]: # list of list or list of strings 110 | if type(label_df) is pd.Series: 111 | if {"onset", "offset", "event_label"}.issubset(label_df.index): # means only one value 112 | if not pd.isna(label_df["event_label"]): 113 | if isinstance(self.labels, int): 114 | i = 0 115 | else: 116 | i = int(self.labels.index(label_df["event_label"])) 117 | onset = float(label_df["onset"]) / self.seconds 118 | offset = float(label_df["offset"]) / self.seconds 119 | y["labels"].append(i) 120 | y["boxes"].append([(onset + offset) / 2, offset - onset]) 121 | return y 122 | 123 | for event_label in label_df: 124 | # List of string, so weak labels to be encoded in strong 125 | if type(event_label) is str: 126 | if event_label is not "": 127 | if isinstance(self.labels, int): 128 | i = 0 129 | else: 130 | i = int(self.labels.index(event_label)) 131 | y["labels"].append(i) 132 | 133 | # List of list, with [label, onset, offset] 134 | elif len(event_label) == 3: 135 | if event_label[0] is not "": 136 | if isinstance(self.labels, int): 137 | i = 0 138 | else: 139 | i = int(self.labels.index(event_label[0])) 140 | onset = float(event_label[1]) / self.seconds 141 | offset = float(event_label[2]) / self.seconds 142 | y["labels"].append(i) 143 | y["boxes"].append([(onset + offset) / 2, offset - onset]) 144 | 145 | else: 146 | raise NotImplementedError("cannot encode strong, type mismatch: {}".format(type(event_label))) 147 | 148 | else: 149 | raise NotImplementedError("To encode_strong, type is pandas.Dataframe with onset, offset and event_label" 150 | "columns, or it is a list or pandas Series of event labels, " 151 | "type given: {}".format(type(label_df))) 152 | # put events with the same class label together 153 | y["labels"] = np.asarray(y["labels"]) 154 | y["boxes"] = np.asarray(y["boxes"]) 155 | # index = y["labels"].argsort() 156 | # y["labels"] = y["labels"][index] 157 | # y["boxes"] = y["boxes"][index] 158 | y["orig_size"] = np.asarray(self.seconds) 159 | if self.generate_patch: 160 | y["patches"] = [] 161 | return y 162 | 163 | def decode_weak(self, labels): 164 | """ Decode the encoded weak labels 165 | Args: 166 | labels: numpy.array, the encoded labels to be decoded 167 | 168 | Returns: 169 | list 170 | Decoded labels, list of string 171 | 172 | """ 173 | result_labels = [] 174 | for i, value in enumerate(labels): 175 | if value == 1: 176 | result_labels.append(self.labels[i]) 177 | return result_labels 178 | 179 | def decode_strong(self, labels, threshold=0.5, del_overlap = True): 180 | """ Decode the encoded strong labels 181 | Args: 182 | labels: numpy.array, the encoded labels to be decoded 183 | Returns: 184 | list 185 | Decoded labels, list of list: [[label, onset offset], ...] 
186 | 187 | """ 188 | result_labels = [] 189 | num_queries = len(labels["scores"]) 190 | event_dict = {} 191 | if not del_overlap: 192 | for i in range(num_queries): 193 | if labels["scores"][i] > threshold : 194 | # ignore result with duration less than 0.2s 195 | onset, offset = labels['boxes'][i] 196 | if offset - onset >= 0.2: 197 | onset, offset = labels['boxes'][i] 198 | result_labels.append([self.labels[labels["labels"][i]], onset, offset, labels["scores"][i]]) 199 | else: 200 | assert not isinstance(self.labels, int), "Don't support del-overlap under self-supervision mode" 201 | for i in range(num_queries): 202 | if labels["scores"][i] >= threshold : 203 | onset, offset = labels['boxes'][i] 204 | # ignore result with duration less than 0.2s 205 | if offset - onset >= 0.2: 206 | class_index = labels["labels"][i] 207 | event_dict.setdefault(self.labels[class_index], []).append( 208 | np.asarray([labels['scores'][i], onset, offset])) 209 | 210 | # del overlap box of same class according to score 211 | for event in event_dict: 212 | event_dict[event] = np.vstack(event_dict[event]) 213 | index = np.argsort(event_dict[event], axis=0)[:, 1] 214 | event_dict[event] = event_dict[event][index] 215 | i = 1 216 | while i < len(event_dict[event]): 217 | if event_dict[event][i][1] < event_dict[event][i - 1][2]: 218 | if event_dict[event][i][0] > event_dict[event][i - 1][0]: 219 | event_dict[event] = np.delete(event_dict[event], i - 1, axis=0) 220 | else: 221 | event_dict[event] = np.delete(event_dict[event], i, axis=0) 222 | continue 223 | i += 1 224 | for i in event_dict[event]: 225 | result_labels.append([event, i[1], i[2], i[0]]) 226 | return result_labels 227 | 228 | def state_dict(self): 229 | return {"labels": self.labels, 230 | "n_frames": self.seconds} 231 | 232 | @classmethod 233 | def load_state_dict(cls, state_dict): 234 | labels = state_dict["labels"] 235 | n_frames = state_dict["n_frames"] 236 | return cls(labels, n_frames) 237 | -------------------------------------------------------------------------------- /utilities/BoxTransforms.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import librosa 4 | import numpy as np 5 | import torch 6 | from torchvision import transforms 7 | from PIL import ImageFilter 8 | import random 9 | 10 | class Transform: 11 | def transform_data(self, data): 12 | # Mandatory to be defined by subclasses 13 | raise NotImplementedError("Abstract object") 14 | 15 | def transform_label(self, label): 16 | # Do nothing, to be changed in subclasses if needed 17 | return label 18 | 19 | def _apply_transform(self, sample_no_index): 20 | data, label = sample_no_index 21 | if type(data) is tuple: # meaning there is more than one data_input (could be duet, triplet...) 
22 | data = list(data) 23 | for k in range(len(data)): 24 | if (type(self).__name__ == "TimeMask"): 25 | if (k == 0): 26 | continue 27 | data[k] = self.transform_data(data[k]) 28 | data = tuple(data) 29 | else: 30 | data = self.transform_data(data) 31 | if self.__class__.__name__ == 'Query': 32 | data, label = self.transform_label(sample_no_index) 33 | else: 34 | label = self.transform_label(label) 35 | return data, label 36 | 37 | def __call__(self, sample): 38 | """ Apply the transformation 39 | Args: 40 | sample: tuple, a sample defined by a DataLoad class 41 | 42 | Returns: 43 | tuple 44 | The transformed tuple 45 | """ 46 | if type(sample[1]) is int: # Means there is an index, may be another way to make it cleaner 47 | sample_data, index = sample 48 | sample_data = self._apply_transform(sample_data) 49 | sample = sample_data, index 50 | else: 51 | sample = self._apply_transform(sample) 52 | return sample 53 | 54 | 55 | class ApplyLog(Transform): 56 | """Convert ndarrays in sample to Tensors.""" 57 | 58 | def transform_data(self, data): 59 | """ Apply the transformation on data 60 | Args: 61 | data: np.array, the data to be modified 62 | 63 | Returns: 64 | np.array 65 | The transformed data 66 | """ 67 | return librosa.amplitude_to_db(data.T).T 68 | 69 | 70 | def pad_trunc_seq(x, max_len): 71 | """Pad or truncate a sequence data to a fixed length. 72 | The sequence should be on axis -2. 73 | 74 | Args: 75 | x: ndarray, input sequence data. 76 | max_len: integer, length of sequence to be padded or truncated. 77 | 78 | Returns: 79 | ndarray, Padded or truncated input sequence data. 80 | """ 81 | shape = x.shape 82 | if shape[-2] <= max_len: 83 | padded = max_len - shape[-2] 84 | padded_shape = ((0, 0),) * len(shape[:-2]) + ((0, padded), (0, 0)) 85 | x = np.pad(x, padded_shape, mode="constant") 86 | else: 87 | x = x[..., :max_len, :] 88 | return x 89 | 90 | 91 | class PadOrTrunc(Transform): 92 | """ Pad or truncate a sequence given a number of frames 93 | Args: 94 | nb_frames: int, the number of frames to match 95 | Attributes: 96 | nb_frames: int, the number of frames to match 97 | """ 98 | 99 | def __init__(self, nb_frames, apply_to_label=False): 100 | self.nb_frames = nb_frames 101 | self.apply_to_label = apply_to_label 102 | 103 | def transform_label(self, label): 104 | if self.apply_to_label: 105 | return pad_trunc_seq(label, self.nb_frames) 106 | else: 107 | return label 108 | 109 | def transform_data(self, data): 110 | """ Apply the transformation on data 111 | Args: 112 | data: np.array, the data to be modified 113 | 114 | Returns: 115 | np.array 116 | The transformed data 117 | """ 118 | return pad_trunc_seq(data, self.nb_frames) 119 | 120 | 121 | class AugmentGaussianNoise(Transform): 122 | """ Pad or truncate a sequence given a number of frames 123 | Args: 124 | mean: float, mean of the Gaussian noise to add 125 | Attributes: 126 | std: float, std of the Gaussian noise to add 127 | """ 128 | 129 | def __init__(self, mean=0., std=None, snr=None, p=0.5): 130 | self.mean = mean 131 | self.std = std 132 | self.snr = snr 133 | self.p = p 134 | 135 | @staticmethod 136 | def gaussian_noise(features, snr): 137 | """Apply gaussian noise on each point of the data 138 | 139 | Args: 140 | features: numpy.array, features to be modified 141 | snr: float, average snr to be used for data augmentation 142 | Returns: 143 | numpy.ndarray 144 | Modified features 145 | """ 146 | # If using source separation, using only the first audio (the mixture) to compute the gaussian noise, 147 | # Otherwise 
it just removes the first axis if it was an extended one 148 | if len(features.shape) == 3: 149 | feat_used = features[0] 150 | else: 151 | feat_used = features 152 | std = np.sqrt(np.mean((feat_used ** 2) * (10 ** (-snr / 10)), axis=-2)) 153 | try: 154 | noise = np.random.normal(0, std, features.shape) 155 | except Exception as e: 156 | warnings.warn(f"the computed noise did not work std: {std}, using 0.5 for std instead") 157 | noise = np.random.normal(0, 0.5, features.shape) 158 | 159 | return features + noise 160 | 161 | def transform_data(self, data): 162 | """ Apply the transformation on data 163 | Args: 164 | data: np.array, the data to be modified 165 | 166 | Returns: 167 | (np.array, np.array) 168 | (original data, noisy_data (data + noise)) 169 | """ 170 | random_num = np.random.uniform(0, 1) 171 | if random_num < self.p: 172 | if self.std is not None: 173 | noisy_data = data + np.abs(np.random.normal(0, 0.5 ** 2, data.shape)) 174 | elif self.snr is not None: 175 | noisy_data = self.gaussian_noise(data, self.snr) 176 | else: 177 | raise NotImplementedError("Only (mean, std) or snr can be given") 178 | return data, noisy_data 179 | else: 180 | return data, data 181 | 182 | 183 | class ToTensor(Transform): 184 | """Convert ndarrays in sample to Tensors. 185 | Args: 186 | unsqueeze_axis: int, (Default value = None) add an dimension to the axis mentioned. 187 | Useful to add a channel axis to use CNN. 188 | Attributes: 189 | unsqueeze_axis: int, add an dimension to the axis mentioned. 190 | Useful to add a channel axis to use CNN. 191 | """ 192 | 193 | def __init__(self, unsqueeze_axis=None): 194 | self.unsqueeze_axis = unsqueeze_axis 195 | 196 | def transform_data(self, data): 197 | """ Apply the transformation on data 198 | Args: 199 | data: np.array, the data to be modified 200 | 201 | Returns: 202 | np.array 203 | The transformed data 204 | """ 205 | res_data = torch.from_numpy(data).float() 206 | if self.unsqueeze_axis is not None: 207 | res_data = res_data.unsqueeze(self.unsqueeze_axis) 208 | return res_data 209 | 210 | def transform_label(self, label): 211 | label["labels"] = torch.from_numpy(label["labels"]).long() 212 | label["boxes"] = torch.from_numpy(label["boxes"]).float() 213 | label["orig_size"] = torch.from_numpy(label["orig_size"]) 214 | return label # float otherwise error 215 | 216 | 217 | class Normalize(Transform): 218 | """Normalize inputs 219 | Args: 220 | scaler: Scaler object, the scaler to be used to normalize the data 221 | Attributes: 222 | scaler : Scaler object, the scaler to be used to normalize the data 223 | """ 224 | 225 | def __init__(self, scaler): 226 | self.scaler = scaler 227 | 228 | def transform_data(self, data): 229 | """ Apply the transformation on data 230 | Args: 231 | data: np.array, the data to be modified 232 | 233 | Returns: 234 | np.array 235 | The transformed data 236 | """ 237 | return self.scaler.normalize(data) 238 | 239 | 240 | class CombineChannels(Transform): 241 | """ Combine channels when using source separation (to remove the channels with low intensity) 242 | Args: 243 | combine_on: str, in {"max", "min"}, the channel in which to combine the channels with the smallest energy 244 | n_channel_mix: int, the number of lowest energy channel to combine in another one 245 | """ 246 | 247 | def __init__(self, combine_on="max", n_channel_mix=2): 248 | self.combine_on = combine_on 249 | self.n_channel_mix = n_channel_mix 250 | 251 | def transform_data(self, data): 252 | """ Apply the transformation on data 253 | Args: 254 | data: 
np.array, the data to be modified, assuming the first values are the mixture, 255 | and the other channels the sources 256 | 257 | Returns: 258 | np.array 259 | The transformed data 260 | """ 261 | mix = data[:1] # :1 is just to keep the first axis 262 | sources = data[1:] 263 | channels_en = (sources ** 2).sum(-1).sum(-1) # Get the energy per channel 264 | indexes_sorted = channels_en.argsort() 265 | sources_to_add = sources[indexes_sorted[:2]].sum(0) 266 | if self.combine_on == "min": 267 | sources[indexes_sorted[2]] += sources_to_add 268 | elif self.combine_on == "max": 269 | sources[indexes_sorted[-1]] += sources_to_add 270 | return np.concatenate((mix, sources[indexes_sorted[2:]])) 271 | 272 | 273 | class Compose(object): 274 | """Composes several transforms together. 275 | Args: 276 | transforms: list of ``Transform`` objects, list of transforms to compose. 277 | Example of transform: ToTensor() 278 | """ 279 | 280 | def __init__(self, transforms): 281 | self.transforms = transforms 282 | 283 | def add_transform(self, transform): 284 | t = self.transforms.copy() 285 | t.append(transform) 286 | return Compose(t) 287 | 288 | def __call__(self, audio): 289 | for t in self.transforms: 290 | audio = t(audio) 291 | return audio 292 | 293 | def __repr__(self): 294 | format_string = self.__class__.__name__ + '(' 295 | for t in self.transforms: 296 | format_string += '\n' 297 | format_string += ' {0}'.format(t) 298 | format_string += '\n)' 299 | 300 | return format_string 301 | 302 | 303 | class GaussianBlur(object): 304 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 305 | 306 | def __init__(self, sigma=[.1, 2.]): 307 | self.sigma = sigma 308 | 309 | def __call__(self, x): 310 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 311 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 312 | return x 313 | 314 | 315 | class Query(Transform): 316 | def __init__(self, fixed_patch_size=False): 317 | self.fixed_patch_size = fixed_patch_size 318 | self.transformer = transforms.Compose([ 319 | transforms.ToPILImage(), 320 | transforms.Resize((128, 64)), 321 | # transforms.RandomApply([ 322 | # transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened 323 | # ], p=0.8), 324 | # transforms.RandomGrayscale(p=0.2), 325 | # transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5), 326 | transforms.ToTensor() 327 | ]) 328 | 329 | def transform_data(self, data): 330 | return data 331 | 332 | def transform_label(self, sample): 333 | data, label = sample 334 | if "patches" not in label: 335 | return data, label 336 | c, t, f = data.shape 337 | assert "boxes" in label, "there are no 'boxes' in label, please check your data" 338 | patches = [] 339 | for box in label['boxes']: 340 | c, l = box.numpy() 341 | s, e = c - l / 2, c + l / 2 342 | s_idx, e_idx = int(s * t), int(e * t) 343 | if self.fixed_patch_size: 344 | e_idx = min(t, s_idx + 128) 345 | s_idx = e_idx - 128 346 | patch_t = data[:, s_idx:e_idx, :] 347 | else: 348 | # make sure patch is not empty 349 | if s_idx >= e_idx: 350 | s_idx = max(0, s_idx - 1) 351 | e_idx = min(t, e_idx + 1) 352 | patch_ori = data[:, s_idx:e_idx, :] 353 | # map to [0,1] 354 | min_v, max_v = patch_ori.min(), patch_ori.max() 355 | patch_norm = (patch_ori - min_v) / (max_v - min_v) 356 | patch_norm_t = self.transformer(patch_norm) 357 | patch_t = patch_norm_t * (max_v - min_v) + min_v 358 | patches.append(patch_t) 359 | label["patches"] = torch.stack(patches, dim=0) 360 | return data, label 361 | 362 | 363 | class TimeMask(Transform): 364 | 
def __init__(self, min_band_part=0.0, max_band_part=0.1, fade=False, p=0.2): 365 | """ 366 | :param min_band_part: Minimum length of the silent part as a fraction of the 367 | total sound length. Float. 368 | :param max_band_part: Maximum length of the silent part as a fraction of the 369 | total sound length. Float. 370 | :param fade: Bool, Add linear fade in and fade out of the silent part. 371 | :param p: The probability of applying this transform 372 | """ 373 | self.min_band_part = min_band_part 374 | self.max_band_part = max_band_part 375 | self.fade = fade 376 | self.p = p 377 | self.parameters = {} 378 | 379 | def randomize_parameters(self): 380 | self.parameters["apply"] = np.random.uniform(0, 1) < self.p 381 | self.parameters["t"] = np.random.uniform(self.min_band_part, self.max_band_part) 382 | self.parameters["t0"] = np.random.uniform(0, 1 - self.parameters["t"]) 383 | 384 | def transform_data(self, data): 385 | self.randomize_parameters() 386 | if self.parameters["apply"]: 387 | nframes, nfreq = data.shape 388 | t = int(self.parameters["t"] * nframes) 389 | t0 = int(self.parameters["t0"] * nframes) 390 | mask = np.zeros((t, nfreq)) 391 | if self.fade: 392 | fade_length = int(t * 0.1) 393 | mask[0:fade_length, :] = np.linspace(1, 0, num=fade_length) 394 | mask[-fade_length:, :] = np.linspace(0, 1, num=fade_length) 395 | data[t0:t0 + t, :] *= mask 396 | return data 397 | 398 | 399 | class FreqMask(Transform): 400 | def __init__(self, min_mask_fraction=0.03, max_mask_fraction=0.4, fill_mode="constant", fill_constant=0, p=0.5): 401 | self.min_mask_fraction = min_mask_fraction 402 | self.max_mask_fraction = max_mask_fraction 403 | assert fill_mode in ("mean", "constant") 404 | self.fill_mode = fill_mode 405 | self.constant = fill_constant 406 | self.p = p 407 | self.parameters = {} 408 | 409 | def randomize_parameters(self): 410 | self.parameters["apply"] = np.random.uniform(0, 1) < self.p 411 | self.parameters["f"] = np.random.uniform(self.min_mask_fraction, self.max_mask_fraction) 412 | self.parameters["f0"] = np.random.uniform(0, 1 - self.parameters["f"]) 413 | 414 | def transform_data(self, data): 415 | self.randomize_parameters() 416 | if self.parameters["apply"]: 417 | nframe, nmel = data.shape 418 | f = int(self.parameters["f"] * nmel) 419 | f0 = int(self.parameters["f0"] * nmel) 420 | if self.fill_mode == "mean": 421 | fill_value = np.mean(data[:, f0:f0 + f]) 422 | else: 423 | fill_value = self.constant 424 | data[:, f0:f + f0] = fill_value 425 | return data 426 | 427 | 428 | class FreqShift(Transform): 429 | def __init__(self, p=0.5, max_band=4, mean=0, std=2): 430 | self.p = p 431 | self.max_band = max_band 432 | self.mean = mean 433 | self.std = std 434 | self.parameters = {} 435 | 436 | def randomize_parameters(self): 437 | self.parameters["apply"] = np.random.uniform(0, 1) < self.p 438 | shift_size = int(np.random.normal(self.mean, self.std)) 439 | while abs(shift_size) > self.max_band: 440 | shift_size = int(np.random.normal(self.mean, self.std)) 441 | self.parameters["shift_size"] = shift_size 442 | 443 | def transform_data(self, data): 444 | self.randomize_parameters() 445 | if self.parameters["apply"]: 446 | data = np.roll(data, self.parameters["shift_size"], axis=1) 447 | if self.parameters["shift_size"] >= 0: 448 | data[:, :self.parameters["shift_size"]] = 0 449 | else: 450 | data[:, self.parameters["shift_size"]:] = 0 451 | return data 452 | 453 | 454 | def get_transforms(frames=None, scaler=None, add_axis=0, noise_dict_params=None, combine_channels_args=None, 
455 | crop_patch=False, fixed_patch_size=False, freq_mask=False, freq_shift=False, time_mask=False): 456 | transf = [] 457 | unsqueeze_axis = None 458 | if add_axis is not None: 459 | unsqueeze_axis = add_axis 460 | 461 | if combine_channels_args is not None: 462 | transf.append(CombineChannels(*combine_channels_args)) 463 | 464 | if noise_dict_params is not None: 465 | transf.append(AugmentGaussianNoise(**noise_dict_params)) 466 | 467 | transf.append(ApplyLog()) 468 | 469 | if frames is not None: 470 | transf.append(PadOrTrunc(nb_frames=frames)) 471 | 472 | if time_mask: 473 | transf.append(TimeMask()) 474 | 475 | if freq_mask: 476 | transf.append(FreqMask(fill_mode="mean")) 477 | 478 | 479 | if freq_shift: 480 | transf.append(FreqShift()) 481 | 482 | transf.append(ToTensor(unsqueeze_axis=unsqueeze_axis)) 483 | 484 | if scaler is not None: 485 | transf.append(Normalize(scaler=scaler)) 486 | 487 | if crop_patch: 488 | transf.append(Query(fixed_patch_size)) 489 | 490 | return Compose(transf) 491 | -------------------------------------------------------------------------------- /utilities/FrameEncoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from dcase_util.data import DecisionEncoder 4 | 5 | class ManyHotEncoder: 6 | """" 7 | Adapted after DecisionEncoder.find_contiguous_regions method in 8 | https://github.com/DCASE-REPO/dcase_util/blob/master/dcase_util/data/decisions.py 9 | 10 | Encode labels into numpy arrays where 1 correspond to presence of the class and 0 absence. 11 | Multiple 1 can appear on the same line, it is for multi label problem. 12 | Args: 13 | labels: list, the classes which will be encoded 14 | n_frames: int, (Default value = None) only useful for strong labels. The number of frames of a segment. 15 | Attributes: 16 | labels: list, the classes which will be encoded 17 | n_frames: int, only useful for strong labels. The number of frames of a segment. 
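For orientation, a hedged sketch of how the get_transforms helper that closes BoxTransforms.py above might be assembled into a pipeline; the argument values here are placeholders, not the project defaults.
```python
# Illustrative only: composing the BoxTransforms pipeline; values are placeholders.
transforms = get_transforms(
    frames=500,                                 # pad/truncate every clip to 500 frames
    scaler=None,                                # pass a fitted Scaler to normalize features
    add_axis=0,                                 # add a channel axis for the CNN backbone
    noise_dict_params={"mean": 0., "snr": 30},  # AugmentGaussianNoise at an assumed 30 dB SNR
    freq_mask=True, freq_shift=True, time_mask=True,
    crop_patch=False,                           # set True (adds Query) for SP-SEDT pretraining
)
# `transforms` is a Compose object applied to each (features, label) sample.
```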
18 | """ 19 | def __init__(self, labels, n_frames=None): 20 | if type(labels) in [np.ndarray, np.array]: 21 | labels = labels.tolist() 22 | self.labels = labels 23 | self.n_frames = n_frames 24 | 25 | def encode_weak(self, labels): 26 | """ Encode a list of weak labels into a numpy array 27 | 28 | Args: 29 | labels: list, list of labels to encode (to a vector of 0 and 1) 30 | 31 | Returns: 32 | numpy.array 33 | A vector containing 1 for each label, and 0 everywhere else 34 | """ 35 | # useful for tensor empty labels 36 | if type(labels) is str: 37 | if labels == "empty": 38 | y = np.zeros(len(self.labels)) - 1 39 | return y 40 | else: 41 | labels = labels.split(",") 42 | if type(labels) is pd.DataFrame: 43 | if labels.empty: 44 | labels = [] 45 | elif "event_label" in labels.columns: 46 | labels = labels["event_label"] 47 | y = np.zeros(len(self.labels)) 48 | for label in labels: 49 | if not pd.isna(label): 50 | i = self.labels.index(label) 51 | y[i] = 1 52 | return y 53 | 54 | def encode_strong_df(self, label_df): 55 | """Encode a list (or pandas Dataframe or Serie) of strong labels, they correspond to a given filename 56 | 57 | Args: 58 | label_df: pandas DataFrame or Series, contains filename, onset (in frames) and offset (in frames) 59 | If only filename (no onset offset) is specified, it will return the event on all the frames 60 | onset and offset should be in frames 61 | Returns: 62 | numpy.array 63 | Encoded labels, 1 where the label is present, 0 otherwise 64 | """ 65 | 66 | assert self.n_frames is not None, "n_frames need to be specified when using strong encoder" 67 | if type(label_df) is str: 68 | if label_df == 'empty': 69 | y = np.zeros((self.n_frames, len(self.labels))) - 1 70 | return y 71 | y = np.zeros((self.n_frames, len(self.labels))) 72 | if type(label_df) is pd.DataFrame: 73 | if {"onset", "offset", "event_label"}.issubset(label_df.columns): 74 | for _, row in label_df.iterrows(): 75 | if not pd.isna(row["event_label"]): 76 | i = self.labels.index(row["event_label"]) 77 | onset = int(row["onset"]) 78 | offset = int(row["offset"]) 79 | y[onset:offset, i] = 1 # means offset not included (hypothesis of overlapping frames, so ok) 80 | 81 | elif type(label_df) in [pd.Series, list, np.ndarray]: # list of list or list of strings 82 | if type(label_df) is pd.Series: 83 | if {"onset", "offset", "event_label"}.issubset(label_df.index): # means only one value 84 | if not pd.isna(label_df["event_label"]): 85 | i = self.labels.index(label_df["event_label"]) 86 | onset = int(label_df["onset"]) 87 | offset = int(label_df["offset"]) 88 | y[onset:offset, i] = 1 89 | return y 90 | 91 | for event_label in label_df: 92 | # List of string, so weak labels to be encoded in strong 93 | if type(event_label) is str: 94 | if event_label is not "": 95 | i = self.labels.index(event_label) 96 | y[:, i] = 1 97 | 98 | # List of list, with [label, onset, offset] 99 | elif len(event_label) == 3: 100 | if event_label[0] is not "": 101 | i = self.labels.index(event_label[0]) 102 | onset = int(event_label[1]) 103 | offset = int(event_label[2]) 104 | y[onset:offset, i] = 1 105 | 106 | else: 107 | raise NotImplementedError("cannot encode strong, type mismatch: {}".format(type(event_label))) 108 | 109 | else: 110 | raise NotImplementedError("To encode_strong, type is pandas.Dataframe with onset, offset and event_label" 111 | "columns, or it is a list or pandas Series of event labels, " 112 | "type given: {}".format(type(label_df))) 113 | return y 114 | 115 | def decode_weak(self, labels): 116 | """ Decode 
the encoded weak labels 117 | Args: 118 | labels: numpy.array, the encoded labels to be decoded 119 | 120 | Returns: 121 | list 122 | Decoded labels, list of string 123 | 124 | """ 125 | result_labels = [] 126 | for i, value in enumerate(labels): 127 | if value == 1: 128 | result_labels.append(self.labels[i]) 129 | return result_labels 130 | 131 | def decode_strong(self, labels): 132 | """ Decode the encoded strong labels 133 | Args: 134 | labels: numpy.array, the encoded labels to be decoded 135 | Returns: 136 | list 137 | Decoded labels, list of list: [[label, onset offset], ...] 138 | 139 | """ 140 | result_labels = [] 141 | for i, label_column in enumerate(labels.T): 142 | change_indices = DecisionEncoder().find_contiguous_regions(label_column) 143 | 144 | # append [label, onset, offset] in the result list 145 | for row in change_indices: 146 | result_labels.append([self.labels[i], row[0], row[1]]) 147 | return result_labels 148 | 149 | def state_dict(self): 150 | return {"labels": self.labels, 151 | "n_frames": self.n_frames} 152 | 153 | @classmethod 154 | def load_state_dict(cls, state_dict): 155 | labels = state_dict["labels"] 156 | n_frames = state_dict["n_frames"] 157 | return cls(labels, n_frames) -------------------------------------------------------------------------------- /utilities/FrameTransforms.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import librosa 4 | import numpy as np 5 | import torch 6 | 7 | 8 | class Transform: 9 | def transform_data(self, data): 10 | # Mandatory to be defined by subclasses 11 | raise NotImplementedError("Abstract object") 12 | 13 | def transform_label(self, label): 14 | # Do nothing, to be changed in subclasses if needed 15 | return label 16 | 17 | def _apply_transform(self, sample_no_index): 18 | data, label = sample_no_index 19 | if self.__class__.__name__ == 'Time_shift': 20 | data = self.transform_data(data) 21 | else: 22 | if type(data) is tuple: # meaning there is more than one data_input (could be duet, triplet...) 23 | data = list(data) 24 | for k in range(len(data)): 25 | data[k] = self.transform_data(data[k]) 26 | data = tuple(data) 27 | else: 28 | data = self.transform_data(data) 29 | label = self.transform_label(label) 30 | return data, label 31 | 32 | def __call__(self, sample): 33 | """ Apply the transformation 34 | Args: 35 | sample: tuple, a sample defined by a DataLoad class 36 | 37 | Returns: 38 | tuple 39 | The transformed tuple 40 | """ 41 | if type(sample[1]) is int: # Means there is an index, may be another way to make it cleaner 42 | sample_data, index = sample 43 | sample_data = self._apply_transform(sample_data) 44 | sample = sample_data, index 45 | else: 46 | sample = self._apply_transform(sample) 47 | return sample 48 | 49 | 50 | class ApplyLog(Transform): 51 | """Convert ndarrays in sample to Tensors.""" 52 | 53 | def transform_data(self, data): 54 | """ Apply the transformation on data 55 | Args: 56 | data: np.array, the data to be modified 57 | 58 | Returns: 59 | np.array 60 | The transformed data 61 | """ 62 | return librosa.amplitude_to_db(data.T).T 63 | 64 | 65 | def pad_trunc_seq(x, max_len): 66 | """Pad or truncate a sequence data to a fixed length. 67 | The sequence should be on axis -2. 68 | 69 | Args: 70 | x: ndarray, input sequence data. 71 | max_len: integer, length of sequence to be padded or truncated. 72 | 73 | Returns: 74 | ndarray, Padded or truncated input sequence data. 
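A small illustration of pad_trunc_seq (the same helper also appears in BoxTransforms.py); the shapes below are placeholders.
```python
# Illustrative only: pad_trunc_seq pads or truncates along axis -2 (the time axis).
import numpy as np

x = np.ones((496, 64))                             # hypothetical (frames, mel bins) feature
assert pad_trunc_seq(x, 500).shape == (500, 64)    # zero-padded by 4 frames
assert pad_trunc_seq(x, 400).shape == (400, 64)    # truncated to the first 400 frames
```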
75 | """ 76 | shape = x.shape 77 | if shape[-2] <= max_len: 78 | padded = max_len - shape[-2] 79 | padded_shape = ((0, 0),) * len(shape[:-2]) + ((0, padded), (0, 0)) 80 | x = np.pad(x, padded_shape, mode="constant") 81 | else: 82 | x = x[..., :max_len, :] 83 | return x 84 | 85 | 86 | class PadOrTrunc(Transform): 87 | """ Pad or truncate a sequence given a number of frames 88 | Args: 89 | nb_frames: int, the number of frames to match 90 | Attributes: 91 | nb_frames: int, the number of frames to match 92 | """ 93 | 94 | def __init__(self, nb_frames, apply_to_label=False): 95 | self.nb_frames = nb_frames 96 | self.apply_to_label = apply_to_label 97 | 98 | def transform_label(self, label): 99 | if self.apply_to_label: 100 | return pad_trunc_seq(label, self.nb_frames) 101 | else: 102 | return label 103 | 104 | def transform_data(self, data): 105 | """ Apply the transformation on data 106 | Args: 107 | data: np.array, the data to be modified 108 | 109 | Returns: 110 | np.array 111 | The transformed data 112 | """ 113 | return pad_trunc_seq(data, self.nb_frames) 114 | 115 | 116 | class AugmentGaussianNoise(Transform): 117 | """ Pad or truncate a sequence given a number of frames 118 | Args: 119 | mean: float, mean of the Gaussian noise to add 120 | Attributes: 121 | std: float, std of the Gaussian noise to add 122 | """ 123 | 124 | def __init__(self, mean=0., std=None, snr=None): 125 | self.mean = mean 126 | self.std = std 127 | self.snr = snr 128 | 129 | @staticmethod 130 | def gaussian_noise(features, snr): 131 | """Apply gaussian noise on each point of the data 132 | 133 | Args: 134 | features: numpy.array, features to be modified 135 | snr: float, average snr to be used for data augmentation 136 | Returns: 137 | numpy.ndarray 138 | Modified features 139 | """ 140 | # If using source separation, using only the first audio (the mixture) to compute the gaussian noise, 141 | # Otherwise it just removes the first axis if it was an extended one 142 | if len(features.shape) == 3: 143 | feat_used = features[0] 144 | else: 145 | feat_used = features 146 | std = np.sqrt(np.mean((feat_used ** 2) * (10 ** (-snr / 10)), axis=-2)) 147 | try: 148 | noise = np.random.normal(0, std, features.shape) 149 | except Exception as e: 150 | warnings.warn(f"the computed noise did not work std: {std}, using 0.5 for std instead") 151 | noise = np.random.normal(0, 0.5, features.shape) 152 | 153 | return features + noise 154 | 155 | def transform_data(self, data): 156 | """ Apply the transformation on data 157 | Args: 158 | data: np.array, the data to be modified 159 | 160 | Returns: 161 | (np.array, np.array) 162 | (original data, noisy_data (data + noise)) 163 | """ 164 | if self.std is not None: 165 | noisy_data = data + np.abs(np.random.normal(0, 0.5 ** 2, data.shape)) 166 | elif self.snr is not None: 167 | noisy_data = self.gaussian_noise(data, self.snr) 168 | else: 169 | raise NotImplementedError("Only (mean, std) or snr can be given") 170 | return data, noisy_data 171 | 172 | 173 | class ToTensor(Transform): 174 | """Convert ndarrays in sample to Tensors. 175 | Args: 176 | unsqueeze_axis: int, (Default value = None) add an dimension to the axis mentioned. 177 | Useful to add a channel axis to use CNN. 178 | Attributes: 179 | unsqueeze_axis: int, add an dimension to the axis mentioned. 180 | Useful to add a channel axis to use CNN. 
181 | """ 182 | 183 | def __init__(self, unsqueeze_axis=None): 184 | self.unsqueeze_axis = unsqueeze_axis 185 | 186 | def transform_data(self, data): 187 | """ Apply the transformation on data 188 | Args: 189 | data: np.array, the data to be modified 190 | 191 | Returns: 192 | np.array 193 | The transformed data 194 | """ 195 | res_data = torch.from_numpy(data).float() 196 | if self.unsqueeze_axis is not None: 197 | res_data = res_data.unsqueeze(self.unsqueeze_axis) 198 | return res_data 199 | 200 | def transform_label(self, label): 201 | return torch.from_numpy(label).float() # float otherwise error 202 | 203 | 204 | class Normalize(Transform): 205 | """Normalize inputs 206 | Args: 207 | scaler: Scaler object, the scaler to be used to normalize the data 208 | Attributes: 209 | scaler : Scaler object, the scaler to be used to normalize the data 210 | """ 211 | 212 | def __init__(self, scaler): 213 | self.scaler = scaler 214 | 215 | def transform_data(self, data): 216 | """ Apply the transformation on data 217 | Args: 218 | data: np.array, the data to be modified 219 | 220 | Returns: 221 | np.array 222 | The transformed data 223 | """ 224 | return self.scaler.normalize(data) 225 | 226 | 227 | class CombineChannels(Transform): 228 | """ Combine channels when using source separation (to remove the channels with low intensity) 229 | Args: 230 | combine_on: str, in {"max", "min"}, the channel in which to combine the channels with the smallest energy 231 | n_channel_mix: int, the number of lowest energy channel to combine in another one 232 | """ 233 | 234 | def __init__(self, combine_on="max", n_channel_mix=2): 235 | self.combine_on = combine_on 236 | self.n_channel_mix = n_channel_mix 237 | 238 | def transform_data(self, data): 239 | """ Apply the transformation on data 240 | Args: 241 | data: np.array, the data to be modified, assuming the first values are the mixture, 242 | and the other channels the sources 243 | 244 | Returns: 245 | np.array 246 | The transformed data 247 | """ 248 | mix = data[:1] # :1 is just to keep the first axis 249 | sources = data[1:] 250 | channels_en = (sources ** 2).sum(-1).sum(-1) # Get the energy per channel 251 | indexes_sorted = channels_en.argsort() 252 | sources_to_add = sources[indexes_sorted[:2]].sum(0) 253 | if self.combine_on == "min": 254 | sources[indexes_sorted[2]] += sources_to_add 255 | elif self.combine_on == "max": 256 | sources[indexes_sorted[-1]] += sources_to_add 257 | return np.concatenate((mix, sources[indexes_sorted[2:]])) 258 | 259 | 260 | class Compose(object): 261 | """Composes several transforms together. 262 | Args: 263 | transforms: list of ``Transform`` objects, list of transforms to compose. 
264 | Example of transform: ToTensor() 265 | """ 266 | 267 | def __init__(self, transforms): 268 | self.transforms = transforms 269 | 270 | def add_transform(self, transform): 271 | t = self.transforms.copy() 272 | t.append(transform) 273 | return Compose(t) 274 | 275 | def __call__(self, audio): 276 | for t in self.transforms: 277 | audio = t(audio) 278 | return audio 279 | 280 | def __repr__(self): 281 | format_string = self.__class__.__name__ + '(' 282 | for t in self.transforms: 283 | format_string += '\n' 284 | format_string += ' {0}'.format(t) 285 | format_string += '\n)' 286 | 287 | return format_string 288 | 289 | 290 | class Time_warping(Transform): 291 | def __init__(self, mean=0, std=90): 292 | self.mean = mean 293 | self.std = std 294 | 295 | def transform_data(self, data): 296 | time_warping_para = np.random.normal(self.mean, self.std, (1, ))[0] 297 | warped_data = time_warp(data, W=time_warping_para) 298 | return warped_data 299 | 300 | 301 | class Time_shift(Transform): 302 | def __init__(self, tpr, mean=0, std=90): 303 | self.mean = mean 304 | self.tpr = tpr 305 | self.std = std // tpr 306 | self.label_shift_size = 0 307 | self.shift_size = 0 308 | 309 | def transform_data(self, data): 310 | """Time shifting. 311 | Args: 312 | data: tuple (data,noise_data) size: (channel, Time, Freq) 313 | shift_size: shift size parameter 314 | """ 315 | data = list(data) 316 | t = data[0].shape[1] 317 | self.label_shift_size = int(np.random.normal(self.mean, self.std, (1, ))[0]) 318 | while abs(self.label_shift_size)*self.tpr > t: 319 | self.label_shift_size = int(np.random.normal(self.mean, self.std, (1,))[0]) 320 | self.shift_size = self.tpr * self.label_shift_size 321 | new_data = [] 322 | for d in data: 323 | if self.shift_size > 0: 324 | left = d[:, self.shift_size:, :] 325 | right = d[:, :self.shift_size, :] 326 | else: 327 | right = d[:, :(t + self.shift_size), :] 328 | left = d[:, (t + self.shift_size):, :] 329 | d = torch.cat((left, right), dim=1) 330 | new_data.append(d) 331 | return tuple(new_data) 332 | def transform_label(self, label): 333 | ''' 334 | :param label: [time,fre] 335 | :return: 336 | ''' 337 | t = label.shape[0] 338 | if self.label_shift_size > 0: 339 | left = label[self.label_shift_size:, :] 340 | right = label[:self.label_shift_size, :] 341 | else: 342 | right = label[:(t + self.label_shift_size), :] 343 | left = label[(t + self.label_shift_size):, :] 344 | label = torch.cat((left, right), dim=0) 345 | return label 346 | 347 | 348 | 349 | 350 | 351 | def get_transforms(frames=None, scaler=None, add_axis=0, noise_dict_params=None, combine_channels_args=None, 352 | time_shifting=None): 353 | transf = [] 354 | unsqueeze_axis = None 355 | if add_axis is not None: 356 | unsqueeze_axis = add_axis 357 | 358 | if combine_channels_args is not None: 359 | transf.append(CombineChannels(*combine_channels_args)) 360 | 361 | if noise_dict_params is not None: 362 | transf.append(AugmentGaussianNoise(**noise_dict_params)) 363 | 364 | transf.append(ApplyLog()) 365 | 366 | if frames is not None: 367 | transf.append(PadOrTrunc(nb_frames=frames)) 368 | 369 | transf.append(ToTensor(unsqueeze_axis=unsqueeze_axis)) 370 | 371 | if scaler is not None: 372 | transf.append(Normalize(scaler=scaler)) 373 | 374 | if time_shifting is not None: 375 | transf.append(Time_shift(time_shifting)) 376 | 377 | return Compose(transf) 378 | -------------------------------------------------------------------------------- /utilities/Logger.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import logging.config 4 | import os 5 | import time 6 | from utilities.distribute import is_main_process 7 | 8 | 9 | class Logger(object): 10 | def __init__(self, file_name="Default.log", stream=sys.stdout): 11 | self.terminal = stream 12 | self.log = open(file_name, "a") 13 | 14 | def write(self, message): 15 | self.terminal.write(message) 16 | self.log.write(message) 17 | 18 | def flush(self): 19 | pass 20 | 21 | 22 | def create_logger(logger_name, terminal_level=logging.INFO): 23 | """ Create a logger. 24 | Args: 25 | logger_name: str, name of the logger 26 | terminal_level: int, logging level in the terminal 27 | """ 28 | logging.config.dictConfig({ 29 | 'version': 1, 30 | 'disable_existing_loggers': False, 31 | }) 32 | logger = logging.getLogger(logger_name) 33 | tool_formatter = logging.Formatter('%(levelname)s - %(name)s - %(message)s') 34 | 35 | if type(terminal_level) is str: 36 | if terminal_level.lower() == "debug": 37 | res_terminal_level = logging.DEBUG 38 | elif terminal_level.lower() == "info": 39 | res_terminal_level = logging.INFO 40 | elif "warn" in terminal_level.lower(): 41 | res_terminal_level = logging.WARNING 42 | elif terminal_level.lower() == "error": 43 | res_terminal_level = logging.ERROR 44 | elif terminal_level.lower() == "critical": 45 | res_terminal_level = logging.CRITICAL 46 | else: 47 | res_terminal_level = logging.NOTSET 48 | else: 49 | res_terminal_level = terminal_level 50 | 51 | if not is_main_process(): 52 | res_terminal_level = logging.ERROR 53 | 54 | logger.setLevel(res_terminal_level) 55 | # Remove the stdout handler 56 | logger_handlers = logger.handlers[:] 57 | if not len(logger_handlers): 58 | terminal_h = logging.StreamHandler(sys.stdout) 59 | terminal_h.setLevel(res_terminal_level) 60 | terminal_h.set_name('stdout') 61 | terminal_h.setFormatter(tool_formatter) 62 | logger.addHandler(terminal_h) 63 | return logger 64 | 65 | 66 | def set_logger(info): 67 | log_path = './log/' 68 | if not os.path.exists(log_path): 69 | os.makedirs(log_path) 70 | log_file_name = log_path + info + time.strftime("-%Y%m%d-%H%M%S", time.localtime()) + '.log' 71 | sys.stdout = Logger(log_file_name) 72 | sys.stderr = Logger(log_file_name) 73 | -------------------------------------------------------------------------------- /utilities/Scaler.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | import json 7 | from utilities.Logger import create_logger 8 | 9 | 10 | 11 | class Scaler: 12 | """ 13 | operates on one or multiple existing datasets and applies operations 14 | """ 15 | 16 | def __init__(self): 17 | self.mean_ = None 18 | self.mean_of_square_ = None 19 | self.std_ = None 20 | self.logger = create_logger(__name__) 21 | 22 | # compute the mean incrementaly 23 | def mean(self, data, axis=-1): 24 | # -1 means have at the end a mean vector of the last dimension 25 | if axis == -1: 26 | mean = data 27 | while len(mean.shape) != 1: 28 | mean = np.mean(mean, axis=0, dtype=np.float64) 29 | else: 30 | mean = np.mean(data, axis=axis, dtype=np.float64) 31 | return mean 32 | 33 | # compute variance thanks to mean and mean of square 34 | def variance(self, mean, mean_of_square): 35 | return mean_of_square - mean**2 36 | 37 | def means(self, dataset): 38 | """ 39 | Splits a dataset in to train test validation. 
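The Scaler above accumulates a running mean and mean-of-square per feature bin and recovers the standard deviation through the identity Var(X) = E[X^2] - (E[X])^2, which is what variance() and std() implement; a quick illustrative check:
```python
# Illustrative sanity check of the identity used by Scaler.variance / Scaler.std.
import numpy as np

x = np.random.randn(1000, 64)             # dummy (samples, mel bins) data
mean = x.mean(axis=0)
mean_of_square = (x ** 2).mean(axis=0)
std = np.sqrt(mean_of_square - mean ** 2)
assert np.allclose(std, x.std(axis=0))    # matches the population standard deviation
```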
40 | :param dataset: dataset, from DataLoad class, each sample is an (X, y) tuple. 41 | """ 42 | self.logger.info('computing mean') 43 | start = time.time() 44 | 45 | shape = None 46 | 47 | counter = 0 48 | for sample in dataset: 49 | if type(sample) in [tuple, list] and len(sample) == 2: 50 | batch_x, _ = sample 51 | else: 52 | batch_x = sample 53 | if type(batch_x) is torch.Tensor: 54 | batch_x_arr = batch_x.numpy() 55 | else: 56 | batch_x_arr = batch_x 57 | data_square = batch_x_arr ** 2 58 | counter += 1 59 | 60 | if shape is None: 61 | shape = batch_x_arr.shape 62 | else: 63 | if not batch_x_arr.shape == shape: 64 | raise NotImplementedError("Not possible to add data with different shape in mean calculation yet") 65 | 66 | # assume first item will have shape info 67 | if self.mean_ is None: 68 | self.mean_ = self.mean(batch_x_arr, axis=-1) 69 | else: 70 | self.mean_ += self.mean(batch_x_arr, axis=-1) 71 | 72 | if self.mean_of_square_ is None: 73 | self.mean_of_square_ = self.mean(data_square, axis=-1) 74 | else: 75 | self.mean_of_square_ += self.mean(data_square, axis=-1) 76 | 77 | self.mean_ /= counter 78 | self.mean_of_square_ /= counter 79 | 80 | # ### To be used if data different shape, but need to stop the iteration before. 81 | # rest = len(dataset) - i 82 | # if rest != 0: 83 | # weight = rest / float(i + rest) 84 | # X, y = dataset[-1] 85 | # data_square = X ** 2 86 | # mean = mean * (1 - weight) + self.mean(X, axis=-1) * weight 87 | # mean_of_square = mean_of_square * (1 - weight) + self.mean(data_square, axis=-1) * weight 88 | 89 | self.logger.debug('time to compute means: ' + str(time.time() - start)) 90 | return self 91 | 92 | def std(self, variance): 93 | return np.sqrt(variance) 94 | 95 | def calculate_scaler(self, dataset): 96 | self.means(dataset) 97 | variance = self.variance(self.mean_, self.mean_of_square_) 98 | self.std_ = self.std(variance) 99 | 100 | return self.mean_, self.std_ 101 | 102 | def normalize(self, batch): 103 | if type(batch) is torch.Tensor: 104 | batch_ = batch.numpy() 105 | batch_ = (batch_ - self.mean_) / self.std_ 106 | return torch.Tensor(batch_) 107 | else: 108 | return (batch - self.mean_) / self.std_ 109 | 110 | def state_dict(self): 111 | if type(self.mean_) is not np.ndarray: 112 | raise NotImplementedError("Save scaler only implemented for numpy array means_") 113 | 114 | dict_save = {"mean_": self.mean_.tolist(), 115 | "mean_of_square_": self.mean_of_square_.tolist()} 116 | return dict_save 117 | 118 | def save(self, path): 119 | dict_save = self.state_dict() 120 | with open(path, "w") as f: 121 | json.dump(dict_save, f) 122 | 123 | def load(self, path): 124 | with open(path, "r") as f: 125 | dict_save = json.load(f) 126 | 127 | self.load_state_dict(dict_save) 128 | 129 | def load_state_dict(self, state_dict): 130 | self.mean_ = np.array(state_dict["mean_"]) 131 | self.mean_of_square_ = np.array(state_dict["mean_of_square_"]) 132 | variance = self.variance(self.mean_, self.mean_of_square_) 133 | self.std_ = self.std(variance) 134 | 135 | 136 | class ScalerPerAudio: 137 | """Normalize inputs one by one 138 | Args: 139 | normalization: str, in {"global", "per_channel"} 140 | type_norm: str, in {"mean", "max"} 141 | """ 142 | 143 | def __init__(self, normalization="global", type_norm="mean"): 144 | self.normalization = normalization 145 | self.type_norm = type_norm 146 | 147 | def normalize(self, spectrogram): 148 | """ Apply the transformation on data 149 | Args: 150 | spectrogram: np.array, the data to be modified, assume to have 3 
dimensions 151 | 152 | Returns: 153 | np.array 154 | The transformed data 155 | """ 156 | if type(spectrogram) is torch.Tensor: 157 | tensor = True 158 | spectrogram = spectrogram.numpy() 159 | else: 160 | tensor = False 161 | 162 | if self.normalization == "global": 163 | axis = None 164 | elif self.normalization == "per_band": 165 | axis = 0 166 | else: 167 | raise NotImplementedError("normalization is 'global' or 'per_band'") 168 | 169 | if self.type_norm == "standard": 170 | res_data = (spectrogram - spectrogram[0].mean(axis)) / (spectrogram[0].std(axis) + np.finfo(float).eps) 171 | elif self.type_norm == "max": 172 | res_data = spectrogram[0] / (np.abs(spectrogram[0].max(axis)) + np.finfo(float).eps) 173 | elif self.type_norm == "min-max": 174 | res_data = (spectrogram - spectrogram[0].min(axis)) / (spectrogram[0].max(axis) - spectrogram[0].min(axis) 175 | + np.finfo(float).eps) 176 | else: 177 | raise NotImplementedError("No other type_norm implemented except {'standard', 'max', 'min-max'}") 178 | if np.isnan(res_data).any(): 179 | res_data = np.nan_to_num(res_data, posinf=0, neginf=0) 180 | warnings.warn("Trying to divide by zeros while normalizing spectrogram, replacing nan by 0") 181 | 182 | if tensor: 183 | res_data = torch.Tensor(res_data) 184 | return res_data 185 | 186 | def state_dict(self): 187 | pass 188 | 189 | def save(self, path): 190 | pass 191 | 192 | def load(self, path): 193 | pass 194 | 195 | def load_state_dict(self, state_dict): 196 | pass 197 | -------------------------------------------------------------------------------- /utilities/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 
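In SEDT an event prediction is a 1-D (center, length) box on the normalized time axis; the helpers below lift it to a degenerate 2-D box of height 1 so that torchvision's box_area can be reused for IoU computation. A small illustration with made-up event times:
```python
# Illustrative only: two events as (center, length) boxes and their temporal IoU.
import torch

events = torch.tensor([[0.35, 0.30],    # event A covers [0.20, 0.50]
                       [0.45, 0.10]])   # event B covers [0.40, 0.50]
xyxy = box_cxcywh_to_xyxy(events)       # [[0.20, 0., 0.50, 1.], [0.40, 0., 0.50, 1.]]
iou, _ = box_iou(xyxy[:1], xyxy[1:])    # 0.10 overlap / 0.30 union ~= 0.33
```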
4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | c, l = x.unbind(-1) 11 | zero = torch.zeros(c.shape).to(c.device) 12 | one = torch.ones(c.shape).to(c.device) 13 | b = [c-l/2, zero, c+l/2, one] 14 | return torch.stack(b, dim=-1) 15 | 16 | def box_cxcywh_to_se(x): 17 | c, l = x.unbind(-1) 18 | b = [c-l/2, c+l/2] 19 | return torch.stack(b, dim=-1) 20 | 21 | 22 | def box_xyxy_to_cxcywh(x): 23 | x0, y0, x1, y1 = x.unbind(-1) 24 | b = [(x0 + x1) / 2, (x1 - x0)] 25 | return torch.stack(b, dim=-1) 26 | 27 | 28 | 29 | def box_iou(boxes1, boxes2): 30 | area1 = box_area(boxes1) 31 | area2 = box_area(boxes2) 32 | 33 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 34 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 35 | 36 | wh = (rb - lt).clamp(min=0) # [N,M,2] 37 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 38 | 39 | union = area1[:, None] + area2 - inter 40 | 41 | iou = inter / union 42 | return iou, union 43 | 44 | 45 | def generalized_box_iou(boxes1, boxes2): 46 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 47 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 48 | iou, union = box_iou(boxes1, boxes2) 49 | 50 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 51 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 52 | 53 | wh = (rb - lt).clamp(min=0) # [N,M,2] 54 | area = wh[:, :, 0] * wh[:, :, 1] 55 | 56 | return iou - (area - union) / area 57 | -------------------------------------------------------------------------------- /utilities/distribute.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | 6 | def is_dist_avail_and_initialized(): 7 | if not dist.is_available(): 8 | return False 9 | if not dist.is_initialized(): 10 | return False 11 | return True 12 | 13 | def get_rank(): 14 | if not is_dist_avail_and_initialized(): 15 | return 0 16 | return dist.get_rank() 17 | 18 | 19 | def is_main_process(): 20 | return get_rank() == 0 21 | 22 | 23 | def get_world_size(): 24 | if not is_dist_avail_and_initialized(): 25 | return 1 26 | return dist.get_world_size() 27 | 28 | 29 | def setup_for_distributed(is_master): 30 | """ 31 | This function disables printing when not in master process 32 | """ 33 | import builtins as __builtin__ 34 | builtin_print = __builtin__.print 35 | 36 | def print(*args, **kwargs): 37 | force = kwargs.pop('force', False) 38 | if is_master or force: 39 | builtin_print(*args, **kwargs) 40 | 41 | __builtin__.print = print 42 | 43 | def init_distributed_mode(args): 44 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 45 | args.rank = int(os.environ["RANK"]) 46 | args.world_size = int(os.environ['WORLD_SIZE']) 47 | args.gpu = int(os.environ['LOCAL_RANK']) 48 | elif 'SLURM_PROCID' in os.environ: 49 | args.rank = int(os.environ['SLURM_PROCID']) 50 | args.gpu = args.rank % torch.cuda.device_count() 51 | else: 52 | print('Not using distributed mode') 53 | args.distributed = False 54 | return 55 | 56 | args.distributed = True 57 | 58 | torch.cuda.set_device(args.gpu) 59 | args.dist_backend = 'nccl' 60 | print('| distributed init (rank {}): {}'.format( 61 | args.rank, args.dist_url), flush=True) 62 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 63 | world_size=args.world_size, rank=args.rank) 64 | torch.distributed.barrier() 65 | setup_for_distributed(args.rank == 0) 66 | 67 | def reduce_dict(input_dict, average=True): 68 | """ 69 | 
Args: 70 | input_dict (dict): all the values will be reduced 71 | average (bool): whether to do average or sum 72 | Reduce the values in the dictionary from all processes so that all processes 73 | have the averaged results. Returns a dict with the same fields as 74 | input_dict, after reduction. 75 | """ 76 | world_size = get_world_size() 77 | if world_size < 2: 78 | return input_dict 79 | with torch.no_grad(): 80 | names = [] 81 | values = [] 82 | # sort the keys so that they are consistent across processes 83 | for k in sorted(input_dict.keys()): 84 | if input_dict[k].grad_fn: 85 | names.append(k) 86 | values.append(input_dict[k]) 87 | values = torch.stack(values, dim=0) 88 | dist.all_reduce(values) 89 | if average: 90 | values /= world_size 91 | reduced_dict = {k: v for k, v in zip(names, values)} 92 | return reduced_dict 93 | 94 | def get_reduced_loss(loss_dict, weight_dict, metric_logger, prefix=''): 95 | 96 | # reduce losses over all GPUs for logging purposes 97 | loss_dict_reduced = reduce_dict(loss_dict) 98 | loss_dict_reduced_unscaled = {prefix + f'{k}_unscaled': v 99 | for k, v in loss_dict_reduced.items()} 100 | loss_dict_reduced_scaled = {prefix + k: v * weight_dict[k] 101 | for k, v in loss_dict_reduced.items() if k in weight_dict} 102 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 103 | loss_value = losses_reduced_scaled.item() 104 | loss_name = prefix + "loss" 105 | metric_logger.update(**{loss_name: loss_value}, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) 106 | 107 | return loss_value -------------------------------------------------------------------------------- /utilities/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | from os import path as osp 4 | import psds_eval 5 | from dcase_util.data import ProbabilityEncoder 6 | import sed_eval 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | from psds_eval import plot_psd_roc, PSDSEval 11 | from utilities.FrameEncoder import ManyHotEncoder 12 | from collections.abc import Iterable 13 | 14 | 15 | def flatten(l): 16 | for el in l: 17 | if isinstance(el, Iterable) and not isinstance(el, (str, bytes)): 18 | yield from flatten(el) 19 | else: 20 | yield el 21 | 22 | 23 | 24 | def get_event_list_current_file(df, fname): 25 | """ 26 | Get list of events for a given filename 27 | :param df: pd.DataFrame, the dataframe to search on 28 | :param fname: the filename to extract the value from the dataframe 29 | :return: list of events (dictionaries) for the given filename 30 | """ 31 | event_file = df[df["filename"] == fname] 32 | if len(event_file) == 1: 33 | if pd.isna(event_file["event_label"].iloc[0]): 34 | event_list_for_current_file = [{"filename": fname}] 35 | else: 36 | event_list_for_current_file = event_file.to_dict('records') 37 | else: 38 | event_list_for_current_file = event_file.to_dict('records') 39 | 40 | return event_list_for_current_file 41 | 42 | 43 | def event_based_evaluation_df(reference, estimated, t_collar=0.200, percentage_of_length=0.2): 44 | """ Calculate EventBasedMetric given a reference and estimated dataframe 45 | 46 | Args: 47 | reference: pd.DataFrame containing "filename" "onset" "offset" and "event_label" columns which describe the 48 | reference events 49 | estimated: pd.DataFrame containing "filename" "onset" "offset" and "event_label" columns which describe the 50 | estimated events to be compared with reference 51 | t_collar: float, in seconds, the number of time 
allowed on onsets and offsets 52 | percentage_of_length: float, between 0 and 1, the percentage of length of the file allowed on the offset 53 | Returns: 54 | sed_eval.sound_event.EventBasedMetrics with the scores 55 | """ 56 | 57 | evaluated_files = reference["filename"].unique() 58 | 59 | classes = [] 60 | classes.extend(reference.event_label.dropna().unique()) 61 | classes.extend(estimated.event_label.dropna().unique()) 62 | classes = list(set(classes)) 63 | 64 | event_based_metric = sed_eval.sound_event.EventBasedMetrics( 65 | event_label_list=classes, 66 | t_collar=t_collar, 67 | percentage_of_length=percentage_of_length, 68 | empty_system_output_handling='zero_score' 69 | ) 70 | 71 | for fname in evaluated_files: 72 | reference_event_list_for_current_file = get_event_list_current_file(reference, fname) 73 | estimated_event_list_for_current_file = get_event_list_current_file(estimated, fname) 74 | 75 | event_based_metric.evaluate( 76 | reference_event_list=reference_event_list_for_current_file, 77 | estimated_event_list=estimated_event_list_for_current_file, 78 | ) 79 | 80 | return event_based_metric 81 | 82 | 83 | def segment_based_evaluation_df(reference, estimated, time_resolution=1.): 84 | """ Calculate SegmentBasedMetrics given a reference and estimated dataframe 85 | 86 | Args: 87 | reference: pd.DataFrame containing "filename" "onset" "offset" and "event_label" columns which describe the 88 | reference events 89 | estimated: pd.DataFrame containing "filename" "onset" "offset" and "event_label" columns which describe the 90 | estimated events to be compared with reference 91 | time_resolution: float, the time resolution of the segment based metric 92 | Returns: 93 | sed_eval.sound_event.SegmentBasedMetrics with the scores 94 | """ 95 | evaluated_files = reference["filename"].unique() 96 | 97 | classes = [] 98 | classes.extend(reference.event_label.dropna().unique()) 99 | classes.extend(estimated.event_label.dropna().unique()) 100 | classes = list(set(classes)) 101 | 102 | segment_based_metric = sed_eval.sound_event.SegmentBasedMetrics( 103 | event_label_list=classes, 104 | time_resolution=time_resolution 105 | ) 106 | 107 | for fname in evaluated_files: 108 | reference_event_list_for_current_file = get_event_list_current_file(reference, fname) 109 | estimated_event_list_for_current_file = get_event_list_current_file(estimated, fname) 110 | 111 | segment_based_metric.evaluate( 112 | reference_event_list=reference_event_list_for_current_file, 113 | estimated_event_list=estimated_event_list_for_current_file 114 | ) 115 | 116 | return segment_based_metric 117 | 118 | 119 | 120 | def psds_score(psds, filename_roc_curves=None): 121 | """ add operating points to PSDSEval object and compute metrics 122 | 123 | Args: 124 | psds: psds.PSDSEval object initialized with the groundtruth corresponding to the predictions 125 | filename_roc_curves: str, the base filename of the roc curve to be computed 126 | """ 127 | try: 128 | psds_score = psds.psds(alpha_ct=0, alpha_st=0, max_efpr=100) 129 | print(f"\nPSD-Score (0, 0, 100): {psds_score.value:.5f}") 130 | psds_ct_score = psds.psds(alpha_ct=1, alpha_st=0, max_efpr=100) 131 | print(f"\nPSD-Score (1, 0, 100): {psds_ct_score.value:.5f}") 132 | psds_macro_score = psds.psds(alpha_ct=0, alpha_st=1, max_efpr=100) 133 | print(f"\nPSD-Score (0, 1, 100): {psds_macro_score.value:.5f}") 134 | if filename_roc_curves is not None: 135 | if osp.dirname(filename_roc_curves) != "": 136 | os.makedirs(osp.dirname(filename_roc_curves), exist_ok=True) 137 | base, 
ext = osp.splitext(filename_roc_curves) 138 | plot_psd_roc(psds_score, filename=f"{base}_0_0_100{ext}") 139 | plot_psd_roc(psds_ct_score, filename=f"{base}_1_0_100{ext}") 140 | plot_psd_roc(psds_macro_score, filename=f"{base}_0_1_100{ext}") 141 | 142 | except psds_eval.psds.PSDSEvalError as e: 143 | print("psds score did not work ....") 144 | print(e) 145 | 146 | 147 | def compute_sed_eval_metrics(predictions, groundtruth, report=True, cal_seg=False): 148 | metric_event = event_based_evaluation_df(groundtruth, predictions, t_collar=0.200, 149 | percentage_of_length=0.2) 150 | if report: 151 | print(metric_event) 152 | metric_segment = None 153 | if cal_seg: 154 | metric_segment = segment_based_evaluation_df(groundtruth, predictions, time_resolution=1.) 155 | print(metric_segment) 156 | return metric_event, metric_segment 157 | 158 | 159 | def format_df(df, mhe): 160 | """ Make a weak labels dataframe from strongly labeled (join labels) 161 | Args: 162 | df: pd.DataFrame, the dataframe strongly labeled with onset and offset columns (+ event_label) 163 | mhe: ManyHotEncoder object, the many hot encoder object that can encode the weak labels 164 | 165 | Returns: 166 | weakly labeled dataframe 167 | """ 168 | def join_labels(x): 169 | return pd.Series(dict(filename=x['filename'].iloc[0], 170 | event_label=mhe.encode_weak(x["event_label"].drop_duplicates().dropna().tolist()))) 171 | 172 | if "onset" in df.columns or "offset" in df.columns: 173 | df = df.groupby("filename", as_index=False).apply(join_labels) 174 | return df 175 | 176 | 177 | def get_f_measure_by_class(torch_model, nb_tags, dataloader_, thresholds_=None): 178 | """ get f measure for each class given a model and a generator of data (batch_x, y) 179 | 180 | Args: 181 | torch_model : Model, model to get predictions, forward should return weak and strong predictions 182 | nb_tags : int, number of classes which are represented 183 | dataloader_ : generator, data generator used to get f_measure 184 | thresholds_ : int or list, thresholds to apply to each class to binarize probabilities 185 | 186 | Returns: 187 | macro_f_measure : list, f measure for each class 188 | 189 | """ 190 | if torch.cuda.is_available(): 191 | torch_model = torch_model.cuda() 192 | 193 | # Calculate external metrics 194 | tp = np.zeros(nb_tags) 195 | tn = np.zeros(nb_tags) 196 | fp = np.zeros(nb_tags) 197 | fn = np.zeros(nb_tags) 198 | for counter, (batch_x, y) in enumerate(dataloader_): 199 | if torch.cuda.is_available(): 200 | batch_x = batch_x.cuda() 201 | 202 | pred_strong, pred_weak = torch_model(batch_x) 203 | pred_weak = pred_weak.cpu().data.numpy() 204 | labels = y.numpy() 205 | 206 | # Used only with a model predicting only strong outputs 207 | if len(pred_weak.shape) == 3: 208 | # average data to have weak labels 209 | pred_weak = np.max(pred_weak, axis=1) 210 | 211 | if len(labels.shape) == 3: 212 | labels = np.max(labels, axis=1) 213 | labels = ProbabilityEncoder().binarization(labels, 214 | binarization_type="global_threshold", 215 | threshold=0.5) 216 | 217 | if thresholds_ is None: 218 | binarization_type = 'global_threshold' 219 | thresh = 0.5 220 | else: 221 | binarization_type = "class_threshold" 222 | assert type(thresholds_) is list 223 | thresh = thresholds_ 224 | 225 | batch_predictions = ProbabilityEncoder().binarization(pred_weak, 226 | binarization_type=binarization_type, 227 | threshold=thresh, 228 | time_axis=0 229 | ) 230 | 231 | tp_, fp_, fn_, tn_ = intermediate_at_measures(labels, batch_predictions) 232 | tp += tp_ 233 | fp += fp_ 
234 | fn += fn_ 235 | tn += tn_ 236 | 237 | macro_f_score = np.zeros(nb_tags) 238 | mask_f_score = 2 * tp + fp + fn != 0 239 | macro_f_score[mask_f_score] = 2 * tp[mask_f_score] / (2 * tp + fp + fn)[mask_f_score] 240 | 241 | return macro_f_score 242 | 243 | 244 | def intermediate_at_measures(encoded_ref, encoded_est): 245 | """ Calculate true/false - positives/negatives. 246 | 247 | Args: 248 | encoded_ref: np.array, the reference array where a 1 means the label is present, 0 otherwise 249 | encoded_est: np.array, the estimated array, where a 1 means the label is present, 0 otherwise 250 | 251 | Returns: 252 | tuple 253 | number of (true positives, false positives, false negatives, true negatives) 254 | 255 | """ 256 | tp = (encoded_est + encoded_ref == 2).sum(axis=0) 257 | fp = (encoded_est - encoded_ref == 1).sum(axis=0) 258 | fn = (encoded_ref - encoded_est == 1).sum(axis=0) 259 | tn = (encoded_est + encoded_ref == 0).sum(axis=0) 260 | return tp, fp, fn, tn 261 | 262 | 263 | def macro_f_measure(tp, fp, fn): 264 | """ From intermediates measures, give the macro F-measure 265 | 266 | Args: 267 | tp: int, number of true positives 268 | fp: int, number of false positives 269 | fn: int, number of true negatives 270 | 271 | Returns: 272 | float 273 | The macro F-measure 274 | """ 275 | macro_f_score = np.zeros(tp.shape[-1]) 276 | mask_f_score = 2 * tp + fp + fn != 0 277 | macro_f_score[mask_f_score] = 2 * tp[mask_f_score] / (2 * tp + fp + fn)[mask_f_score] 278 | return macro_f_score 279 | 280 | 281 | def audio_tagging_results(reference, estimated): 282 | classes = [] 283 | if "event_label" in reference.columns: 284 | classes.extend(reference.event_label.dropna().unique()) 285 | classes.extend(estimated.event_label.dropna().unique()) 286 | classes = list(set(classes)) 287 | mhe = ManyHotEncoder(classes) 288 | reference = format_df(reference, mhe) 289 | estimated = format_df(estimated, mhe) 290 | else: 291 | classes.extend(reference.event_labels.str.split(',', expand=True).unstack().dropna().unique()) 292 | classes.extend(estimated.event_labels.str.split(',', expand=True).unstack().dropna().unique()) 293 | classes = list(set(classes)) 294 | mhe = ManyHotEncoder(classes) 295 | 296 | matching = reference.merge(estimated, how='outer', on="filename", suffixes=["_ref", "_pred"]) 297 | 298 | def na_values(val): 299 | if type(val) is np.ndarray: 300 | return val 301 | if pd.isna(val): 302 | return np.zeros(len(classes)) 303 | return val 304 | 305 | if not estimated.empty: 306 | matching.event_label_pred = matching.event_label_pred.apply(na_values) 307 | matching.event_label_ref = matching.event_label_ref.apply(na_values) 308 | 309 | tp, fp, fn, tn = intermediate_at_measures(np.array(matching.event_label_ref.tolist()), 310 | np.array(matching.event_label_pred.tolist())) 311 | macro_f = macro_f_measure(tp, fp, fn) 312 | macro_p = tp / (tp + fp) 313 | macro_r = tp / (tp + fn) 314 | else: 315 | macro_f = np.zeros(len(classes)) 316 | macro_p = np.zeros(len(classes)) 317 | macro_r = np.zeros(len(classes)) 318 | data = np.asarray([macro_f, macro_p, macro_r]).transpose(1, 0) 319 | results_serie = pd.DataFrame(data, columns=['f', 'p', 'r'], index=mhe.labels) 320 | results_serie = results_serie.append( 321 | pd.DataFrame(data.mean(0).reshape(1, -1), columns=['f', 'p', 'r'], index=['avg'])) 322 | return results_serie 323 | 324 | 325 | def compute_psds_from_operating_points(list_predictions, groundtruth_df, meta_df, dtc_threshold=0.5, gtc_threshold=0.5, 326 | cttc_threshold=0.3): 327 | psds = 
PSDSEval(dtc_threshold, gtc_threshold, cttc_threshold, ground_truth=groundtruth_df, metadata=meta_df) 328 | for prediction_df in list_predictions: 329 | psds.add_operating_point(prediction_df) 330 | return psds 331 | 332 | 333 | def compute_metrics(predictions, gtruth_df, meta_df=None, cal_seg=True, cal_clip=True): 334 | # report results 335 | if predictions.empty: 336 | return 0 337 | events_metric, segments_metric = compute_sed_eval_metrics(predictions, gtruth_df, report=True, cal_seg=cal_seg) 338 | events_macro = events_metric.results_class_wise_average_metrics() 339 | events_macro_f1 = events_macro['f_measure']['f_measure'] 340 | events_macro_p = events_macro['f_measure']['precision'] 341 | events_macro_r = events_macro['f_measure']['recall'] 342 | clip_macro_f1 = None 343 | if cal_clip: 344 | clip_metric = audio_tagging_results(gtruth_df, predictions) 345 | # clip_macro_f1 = clip_metric.values.mean() 346 | clip_macro_f1 = clip_metric.loc['avg', 'f'] 347 | print("Class-wise clip metrics") 348 | print("=" * 50) 349 | print(clip_metric) 350 | if segments_metric is not None: 351 | seg_macro = segments_metric.results_class_wise_average_metrics() 352 | seg_macro_f1 = seg_macro['f_measure']['f_measure'] 353 | seg_macro_p = seg_macro['f_measure']['precision'] 354 | seg_macro_r = seg_macro['f_measure']['recall'] 355 | metric = pd.DataFrame([['%.2f%%'%(events_macro_f1 * 100), '%.2f%%'%(events_macro_p * 100), 356 | '%.2f%%'%(events_macro_r * 100), '%.2f%%'%(seg_macro_f1 * 100), 357 | '%.2f%%'%(seg_macro_p * 100), '%.2f%%'%(seg_macro_r * 100), 358 | '%.2f%%'%(clip_macro_f1 * 100)]], 359 | columns=['Eb_F1', 'Eb_P', 'Eb_R', 'Sb_F', 'Sb_P', 'Sb_R', 'At_F1']) 360 | print("\nAll Metrics") 361 | print("=" * 55) 362 | print(metric) 363 | print("=" * 55) 364 | # dtc_threshold, gtc_threshold, cttc_threshold = 0.5, 0.5, 0.3 365 | # psds = PSDSEval(dtc_threshold, gtc_threshold, cttc_threshold, ground_truth=gtruth_df, metadata=meta_df) 366 | # psds_macro_f1, psds_f1_classes = psds.compute_macro_f_score(predictions) 367 | # logger.info(f"F1_score (psds_eval) accounting cross triggers: {psds_macro_f1}") 368 | return events_macro_f1 -------------------------------------------------------------------------------- /utilities/mixup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | """ 4 | @author: yzr 5 | @file: mixup.py 6 | @time: 2020/8/5 12:53 7 | """ 8 | import numpy as np 9 | import torch 10 | from utilities.box_ops import box_cxcywh_to_se 11 | 12 | 13 | def mixup_data(x, y, mask_strong, mask_weak, mix_up_ratio=0.5, max_events=20, alpha=3): 14 | """ 15 | mix up data for sedt model, only mix up data with same type of labels 16 | :param x: feature, shape:[batch_size, channel, nframe, nfreq] 17 | :param y: label, [{"labels":ndarray, "boxes":ndarray, "orig_size"},{...},...] 
18 | :param mix_up_ratio: only mix up a part of data 19 | :param alpha: 20 | :return: 21 | """ 22 | if alpha > 0.: 23 | lam = np.random.beta(alpha, alpha) 24 | else: 25 | lam = 1.0 26 | bs = x.tensors.shape[0] 27 | mix_num = int(bs*mix_up_ratio) 28 | index = np.asarray(list(range(bs))) 29 | np.random.shuffle(index) 30 | 31 | data_1 = x.tensors[:mix_num, :] 32 | data_2 = x.tensors[index][:mix_num, :] 33 | label_1 = y[:mix_num] 34 | label_2 = [y[i] for i in index[:mix_num]] 35 | 36 | # not collapse overlapped events of the same class 37 | data = lam * data_1 + (1-lam) * data_2 38 | strong_label, strong_data = [], [] 39 | weak_label, weak_data = [], [] 40 | unlabel, unlabel_data = [], [] 41 | for i, (l_1, l_2 )in enumerate(zip(label_1,label_2)): 42 | if len(l_1["boxes"]) == 0 or len(l_2["boxes"]) ==0: 43 | if (len(l_1["boxes"]) > 0): 44 | strong_label.append(label_1[i]) 45 | strong_data.append(data_1[i].unsqueeze(dim=0)) 46 | elif (len(l_2["boxes"]) > 0): 47 | strong_label.append(label_2[i]) 48 | strong_data.append(data_2[i].unsqueeze(dim=0)) 49 | else: 50 | weak_label.append({ 51 | "labels": torch.cat((l_1["labels"], l_2["labels"]), dim=0), 52 | "boxes": torch.tensor([], device=x.tensors.device), 53 | "ratio": torch.tensor([lam] * len(l_1["labels"]) + [1 - lam] * len(l_2["labels"]), device=x.tensors.device), 54 | "orig_size": l_1["orig_size"]}) 55 | weak_data.append(data[i].unsqueeze(dim=0)) 56 | else: 57 | # abandom data with more than num_queries events 58 | if len(l_1["boxes"]) + len(l_2["boxes"]) > max_events: 59 | if len(l_1["boxes"]): 60 | strong_label.append(l_1) 61 | strong_data.append(data_1[i].unsqueeze(dim=0)) 62 | else: 63 | strong_label.append(l_2) 64 | strong_data.append(data_2[i].unsqueeze(dim=0)) 65 | else: 66 | ds = data_1[i] # data with strong label 67 | if len(l_1["boxes"]) == 0: 68 | tmp = l_1 69 | l_1 = l_2 70 | l_2 = tmp 71 | lam = 1-lam 72 | ds = data_2[i] 73 | 74 | strong_label.append({ 75 | "labels": torch.cat((l_1["labels"], l_2["labels"]),dim=0), 76 | "boxes": torch.cat((l_1["boxes"], l_2["boxes"]), dim=0), 77 | "ratio": torch.tensor([lam]*len(l_1["labels"]) + [1-lam]*len(l_2["labels"]), device=x.tensors.device), 78 | "orig_size": l_1["orig_size"] 79 | }) 80 | strong_data.append(data[i].unsqueeze(dim=0)) 81 | 82 | # abandom data which mix up events with same class label 83 | cur_labels = strong_label[-1]["labels"] 84 | cur_boxes = strong_label[-1]["boxes"] 85 | events = set(cur_labels.tolist()) 86 | for e in events: 87 | boxes = cur_boxes[(cur_labels == e)[:len(cur_boxes)]] 88 | boxes = box_cxcywh_to_se(boxes) 89 | boxes = boxes[boxes.argsort(dim=0)[:, 0]] 90 | boxes_e = boxes[:, 1][:-1] 91 | boxes_s = boxes[:, 0][1:] 92 | if not (boxes_e < boxes_s).all().item(): 93 | strong_label[-1] = l_1 94 | strong_data[-1] = ds.unsqueeze(dim=0) 95 | break 96 | data_final = [] 97 | label_final = [] 98 | # integrate with non-mix-up data 99 | if len(x.tensors[mask_strong][mix_num:]): 100 | strong_data.append(x.tensors[mask_strong][mix_num:]) 101 | strong_label.extend(y[mask_strong][mix_num:]) 102 | if len(strong_data): 103 | data_final.extend(strong_data) 104 | label_final.extend(strong_label) 105 | 106 | if mask_weak is not None: 107 | left_weak_index = max(0, mix_num - mask_strong.stop) 108 | if len(x.tensors[mask_weak][left_weak_index:]): 109 | weak_data.append(x.tensors[mask_weak][left_weak_index:]) 110 | weak_label.extend(y[mask_weak][left_weak_index:]) 111 | if len(weak_data): 112 | data_final.extend(weak_data) 113 | label_final.extend(weak_label) 114 | 115 | 116 | 
left_unlabel_index = max(0, mix_num - mask_weak.stop) 117 | if len(x.tensors[mask_weak.stop:][left_unlabel_index:]): 118 | unlabel_data.append(x.tensors[mask_weak.stop:][left_unlabel_index:]) 119 | unlabel.extend(y[mask_weak.stop:][left_unlabel_index:]) 120 | if len(unlabel_data): 121 | data_final.extend(unlabel_data) 122 | label_final.extend(unlabel) 123 | 124 | x.tensors = torch.cat(data_final,dim=0) 125 | y = label_final 126 | 127 | return x, y, slice(len(strong_label)), slice(len(strong_label), len(strong_label)+len(weak_label)) 128 | 129 | def mixup_label_unlabel(x1, x2, y1, y2, mix_up_ratio=0.5, max_events=20, alpha=3): 130 | """ 131 | mix up data for sedt model 132 | :param x1: label data feature, shape:[batch_size, channel, nframe, nfreq] 133 | :param y1: label data target, [{"labels":ndarray, "boxes":ndarray, "orig_size"},{...},...] 134 | :param x2: unlabel data feature, shape:[batch_size, channel, nframe, nfreq] 135 | :param y2: unlabel data target, [{"labels":ndarray, "boxes":ndarray, "orig_size"},{...},...] 136 | :param mix_up_ratio: only mix up a part of data 137 | :param alpha: 138 | :return: 139 | """ 140 | assert mix_up_ratio <= 0.5 141 | if alpha > 0.: 142 | lam = np.random.beta(alpha, alpha) 143 | else: 144 | lam = 1.0 145 | bs = x1.tensors.shape[0] 146 | mix_num = int(bs * mix_up_ratio) 147 | 148 | data_1 = x1.tensors[:mix_num, :] 149 | data_2 = x2.tensors[:mix_num, :] 150 | label_1 = y1[:mix_num] 151 | label_2 = y2[:mix_num] 152 | 153 | # not collapse overlapped events of the same class 154 | data = lam * data_1 + (1 - lam) * data_2 155 | strong_label, strong_data = [], [] 156 | 157 | for i, (l_1, l_2) in enumerate(zip(label_1, label_2)): 158 | if len(l_1["boxes"]) + len(l_2["boxes"]) > max_events: 159 | if len(l_2["boxes"]): 160 | strong_label.append(l_2) 161 | strong_data.append(data_2[i].unsqueeze(dim=0)) 162 | else: 163 | strong_label.append(l_1) 164 | strong_data.append(data_1[i].unsqueeze(dim=0)) 165 | else: 166 | ds = data_1[i] 167 | strong_label.append({ 168 | "labels": torch.cat((l_1["labels"], l_2["labels"]), dim=0), 169 | "boxes": torch.cat((l_1["boxes"], l_2["boxes"]), dim=0), 170 | "ratio": torch.tensor([lam] * len(l_1["labels"]) + [1 - lam] * len(l_2["labels"]), 171 | device=x1.tensors.device), 172 | "orig_size": l_1["orig_size"] 173 | }) 174 | strong_data.append(data[i].unsqueeze(dim=0)) 175 | 176 | # abandon data which mix up events with same class label 177 | cur_labels = strong_label[-1]["labels"] 178 | cur_boxes = strong_label[-1]["boxes"] 179 | events = set(cur_labels.tolist()) 180 | for e in events: 181 | boxes = cur_boxes[(cur_labels == e)[:len(cur_boxes)]] 182 | boxes = box_cxcywh_to_se(boxes) 183 | boxes = boxes[boxes.argsort(dim=0)[:, 0]] 184 | boxes_e = boxes[:, 1][:-1] 185 | boxes_s = boxes[:, 0][1:] 186 | if not (boxes_e < boxes_s).all().item(): 187 | strong_label[-1] = l_1 188 | strong_data[-1] = ds.unsqueeze(dim=0) 189 | break 190 | 191 | strong_data.append(x2.tensors[mix_num:]) 192 | strong_label.extend(y2[mix_num:]) 193 | 194 | x2.tensors, y2 = torch.cat(strong_data, dim=0), strong_label 195 | 196 | return x2, y2 197 | 198 | --------------------------------------------------------------------------------
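The mix-up helpers in utilities/mixup.py operate on a batch object that exposes a `.tensors` attribute (shape `[batch, channel, frames, freq]`) and on DETR-style target dictionaries with `labels`, `boxes` (normalised (center, width) pairs) and `orig_size` keys, with the strongly labelled and weakly labelled parts of the batch addressed by Python slices. The minimal sketch below shows one way to exercise `mixup_data`; the `SimpleNamespace` stand-in for the real batch wrapper, the `make_target` helper, the 2-strong/1-weak/1-unlabelled batch layout and all shapes and events are illustrative assumptions, not the project's actual data pipeline.
```python
# Minimal sketch of driving mixup_data() with a toy batch. Everything here is
# an illustrative assumption: SimpleNamespace merely stands in for whatever
# batched-feature wrapper the training loop uses (only its .tensors attribute
# is touched), and the batch layout, feature shape and event boxes are invented.
import torch
from types import SimpleNamespace
from utilities.mixup import mixup_data

features = SimpleNamespace(tensors=torch.randn(4, 1, 128, 64))  # [bs, ch, frames, freq]

def make_target(labels, boxes):
    # DETR-style target: class indices plus (center, width) boxes normalised to [0, 1].
    return {"labels": torch.tensor(labels, dtype=torch.long),
            "boxes": torch.tensor(boxes, dtype=torch.float).reshape(-1, 2),
            "orig_size": torch.tensor([128])}

targets = [
    make_target([0, 3], [[0.2, 0.1], [0.6, 0.2]]),  # strongly labelled clip
    make_target([1],    [[0.5, 0.3]]),              # strongly labelled clip
    make_target([2],    []),                        # weakly labelled clip: tag only, no boxes
    make_target([],     []),                        # unlabelled clip
]

mask_strong = slice(0, 2)  # strongly labelled slice of the batch
mask_weak = slice(2, 3)    # weakly labelled slice; the remainder is unlabelled

mixed, mixed_targets, new_strong, new_weak = mixup_data(
    features, targets, mask_strong, mask_weak,
    mix_up_ratio=0.5, max_events=20, alpha=3)

print(mixed.tensors.shape, len(mixed_targets), new_strong, new_weak)
```
`mixup_label_unlabel` consumes the same target dictionaries but takes two batches (labelled and pseudo-labelled) instead of slice masks; in both helpers, mixed targets gain an extra `"ratio"` entry recording the mixing weight assigned to each event.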
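For the evaluation helpers in utilities/metrics.py, the expected input is a strongly labelled DataFrame with one row per event. The sketch below is a minimal, invented example of that layout fed to `compute_metrics`; it assumes `sed_eval`, `psds_eval` and `dcase_util` are installed and a pre-2.0 pandas, since the clip-level branch (`cal_clip=True`) still relies on `DataFrame.append` inside `audio_tagging_results`.
```python
# Minimal sketch of the DataFrame layout compute_metrics() expects.
# File names, timestamps and labels are invented for illustration.
import pandas as pd
from utilities.metrics import compute_metrics

# Ground truth and predictions share the same strongly labelled layout:
# one row per event with filename / onset / offset (seconds) / event_label.
groundtruth = pd.DataFrame({
    "filename":    ["a.wav", "a.wav", "b.wav"],
    "onset":       [0.5, 3.0, 1.0],
    "offset":      [2.0, 6.0, 4.0],
    "event_label": ["Speech", "Dog", "Speech"],
})
predictions = pd.DataFrame({
    "filename":    ["a.wav", "b.wav"],
    "onset":       [0.6, 1.2],
    "offset":      [2.1, 3.8],
    "event_label": ["Speech", "Speech"],
})

# Prints the event-based report, the segment-based report and the clip-level
# tagging table, then returns the event-based macro F1.
eb_f1 = compute_metrics(predictions, groundtruth, cal_seg=True, cal_clip=True)
print(f"event-based macro F1: {eb_f1:.3f}")
```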