├── src ├── local │ ├── ast │ │ ├── __init__.py │ │ └── ast_models.py │ ├── panns │ │ ├── __init__.py │ │ ├── pytorch_utils.py │ │ └── models.py │ ├── __pycache__ │ │ ├── utils.cpython-38.pyc │ │ ├── con_loss.cpython-38.pyc │ │ ├── sed_trainer.cpython-38.pyc │ │ ├── classes_dict.cpython-38.pyc │ │ └── resample_folder.cpython-38.pyc │ ├── classes_dict.py │ ├── con_loss.py │ ├── resample_folder.py │ └── utils.py ├── desed_task │ ├── nnet │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── CNN.cpython-38.pyc │ │ │ ├── CRNN.cpython-38.pyc │ │ │ ├── RNN.cpython-38.pyc │ │ │ └── __init__.cpython-38.pyc │ │ ├── RNN.py │ │ ├── .ipynb_checkpoints │ │ │ ├── RNN-checkpoint.py │ │ │ ├── CNN-checkpoint.py │ │ │ └── CRNN-checkpoint.py │ │ ├── CNN.py │ │ └── CRNN.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── evaluation_measures.cpython-38.pyc │ │ ├── evaluation_measures.py │ │ └── .ipynb_checkpoints │ │ │ └── evaluation_measures-checkpoint.py │ ├── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── encoder.cpython-38.pyc │ │ │ ├── scaler.cpython-38.pyc │ │ │ └── schedulers.cpython-38.pyc │ │ ├── torch_utils.py │ │ ├── download.py │ │ ├── schedulers.py │ │ ├── scaler.py │ │ ├── encoder.py │ │ └── .ipynb_checkpoints │ │ │ └── encoder-checkpoint.py │ ├── __pycache__ │ │ └── data_augm.cpython-38.pyc │ ├── dataio │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── sampler.cpython-38.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── datasets.cpython-38.pyc │ │ ├── sampler.py │ │ ├── datasets.py │ │ └── .ipynb_checkpoints │ │ │ └── datasets-checkpoint.py │ └── data_augm.py ├── conda_create_environment.sh ├── confs │ └── default.yaml ├── generate_dcase_task4_2022.py ├── extract_embeddings.py ├── train_sed.py └── train_pretrained.py ├── asset └── lgc.png ├── LICENSE └── README.md /src/local/ast/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/local/panns/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/desed_task/nnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/desed_task/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asset/lgc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/asset/lgc.png -------------------------------------------------------------------------------- /src/desed_task/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import ManyHotEncoder 2 | from .schedulers import ExponentialWarmup 3 | -------------------------------------------------------------------------------- /src/local/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/local/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- 
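The `ExponentialWarmup` scheduler re-exported in `desed_task/utils/__init__.py` above (and defined in `src/desed_task/utils/schedulers.py`, shown further below) ramps the learning rate up with an exponential factor. The following is a small illustrative sketch, not code from this repo: it only mirrors `_get_scaling_factor()`, and the 200-steps-per-epoch figure is an assumption made for the example.

```python
import numpy as np

def warmup_scale(step, rampup_len, exponent=-5.0):
    # Mirrors ExponentialWarmup._get_scaling_factor(): exp(exponent * (1 - step/rampup_len)**2)
    if rampup_len == 0:
        return 1.0
    current = np.clip(step, 0.0, rampup_len)
    phase = 1.0 - current / rampup_len
    return float(np.exp(exponent * phase * phase))

# With n_epochs_warmup = 50 (confs/default.yaml) and a hypothetical 200 optimizer steps per epoch:
for step in (0, 2500, 5000, 10000):
    print(step, warmup_scale(step, rampup_len=50 * 200))
# approximately 0.0067, 0.060, 0.287 and 1.0
```

The scheduler multiplies `max_lr` by this factor in `_get_lr()`, so training starts close to zero and reaches the full learning rate at the end of the warmup.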
/src/local/__pycache__/con_loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/local/__pycache__/con_loss.cpython-38.pyc -------------------------------------------------------------------------------- /src/local/__pycache__/sed_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/local/__pycache__/sed_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/__pycache__/data_augm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/__pycache__/data_augm.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/nnet/__pycache__/CNN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/nnet/__pycache__/CNN.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/nnet/__pycache__/CRNN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/nnet/__pycache__/CRNN.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/nnet/__pycache__/RNN.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/nnet/__pycache__/RNN.cpython-38.pyc -------------------------------------------------------------------------------- /src/local/__pycache__/classes_dict.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/local/__pycache__/classes_dict.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/dataio/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import WeakSet, UnlabeledSet, StronglyAnnotatedSet 2 | from .sampler import ConcatDatasetBatchSampler 3 | -------------------------------------------------------------------------------- /src/local/__pycache__/resample_folder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/local/__pycache__/resample_folder.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/dataio/__pycache__/sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/dataio/__pycache__/sampler.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/nnet/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/nnet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- 
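The `dataio` package above re-exports the three training subsets (`StronglyAnnotatedSet`, `WeakSet`, `UnlabeledSet`) together with `ConcatDatasetBatchSampler`, which draws a fixed number of clips from each subset for every batch. Below is a minimal sketch of that composition; the `TensorDataset` stand-ins and the loader wiring are assumptions for illustration, while the sampler itself and the `[12, 12, 24]` split (`training.batch_size` in `confs/default.yaml`) come from this repo.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, RandomSampler, TensorDataset

from desed_task.dataio import ConcatDatasetBatchSampler  # assumes the package is installed, e.g. `pip install -e .`

# Placeholder datasets standing in for StronglyAnnotatedSet / WeakSet / UnlabeledSet.
synth_set = TensorDataset(torch.randn(100, 8))
weak_set = TensorDataset(torch.randn(100, 8))
unlabeled_set = TensorDataset(torch.randn(200, 8))

tot_train_data = [synth_set, weak_set, unlabeled_set]
train_dataset = ConcatDataset(tot_train_data)

batch_sizes = [12, 12, 24]  # [synth, weak, unlabeled], as in training.batch_size
samplers = [RandomSampler(ds) for ds in tot_train_data]
batch_sampler = ConcatDatasetBatchSampler(samplers, batch_sizes)

train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)
batch = next(iter(train_loader))
# batch[0] has shape (48, 8): rows 0-11 come from synth_set, 12-23 from weak_set, 24-47 from unlabeled_set.
```

Since batches are drawn until the smallest subset is exhausted, the number of steps per epoch is dictated by the subset with the fewest clips.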
/src/desed_task/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/utils/__pycache__/encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/utils/__pycache__/encoder.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/utils/__pycache__/scaler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/utils/__pycache__/scaler.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/dataio/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/dataio/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/dataio/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/dataio/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/utils/__pycache__/schedulers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/utils/__pycache__/schedulers.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/evaluation/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/evaluation/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/evaluation/__pycache__/evaluation_measures.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ming-er/LGC-SED/HEAD/src/desed_task/evaluation/__pycache__/evaluation_measures.cpython-38.pyc -------------------------------------------------------------------------------- /src/desed_task/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def nantensor(*args, **kwargs): 6 | return torch.ones(*args, **kwargs) * np.nan 7 | 8 | 9 | def nanmean(v, *args, inplace=False, **kwargs): 10 | if not inplace: 11 | v = v.clone() 12 | is_nan = torch.isnan(v) 13 | v[is_nan] = 0 14 | return v.sum(*args, **kwargs) / (~is_nan).float().sum(*args, **kwargs) 15 | -------------------------------------------------------------------------------- /src/local/classes_dict.py: -------------------------------------------------------------------------------- 1 | """ 2 | we store here a dict where we define the encodings for all classes in DESED task. 
3 | """ 4 | 5 | from collections import OrderedDict 6 | 7 | 8 | classes_labels = OrderedDict( 9 | { 10 | "Alarm_bell_ringing": 0, 11 | "Blender": 1, 12 | "Cat": 2, 13 | "Dishes": 3, 14 | "Dog": 4, 15 | "Electric_shaver_toothbrush": 5, 16 | "Frying": 6, 17 | "Running_water": 7, 18 | "Speech": 8, 19 | "Vacuum_cleaner": 9, 20 | } 21 | ) 22 | -------------------------------------------------------------------------------- /src/conda_create_environment.sh: -------------------------------------------------------------------------------- 1 | conda create -y -n dcase2022 python==3.8.5 2 | source activate dcase2022 3 | 4 | conda install -y numba 5 | conda install -y librosa -c conda-forge 6 | conda install -y ffmpeg -c conda-forge 7 | conda install -y sox -c conda-forge 8 | conda install -y pandas h5py scipy 9 | conda install -y pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch # for gpu install (or cpu in MAC) 10 | # conda install pytorch-cpu torchvision-cpu -c pytorch (cpu linux) 11 | conda install -y youtube-dl tqdm -c conda-forge 12 | 13 | pip install codecarbon==1.2.0 14 | pip install -r ../../requirements.txt 15 | pip install torchmetrics==0.7.3 16 | pip install -e ../../. -------------------------------------------------------------------------------- /src/desed_task/utils/download.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import os 3 | from tqdm import tqdm 4 | import requests 5 | from pathlib import Path 6 | 7 | def download_from_url(url, destination): 8 | if os.path.exists(destination): 9 | print("Skipping download as file in {} exists already".format(destination)) 10 | return 11 | response = requests.get(url, stream=True) 12 | total_size_in_bytes = int(response.headers.get('content-length', 0)) 13 | block_size = 1024 # 1 Kibibyte 14 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) 15 | os.makedirs(Path(destination).parent, exist_ok=True) 16 | with open(destination, 'wb') as file: 17 | for data in response.iter_content(block_size): 18 | progress_bar.update(len(data)) 19 | file.write(data) 20 | progress_bar.close() 21 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 22 | print("ERROR, something went wrong") -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yiming Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/desed_task/utils/schedulers.py: -------------------------------------------------------------------------------- 1 | from asteroid.engine.schedulers import * 2 | import numpy as np 3 | 4 | 5 | class ExponentialWarmup(BaseScheduler): 6 | """ Scheduler to apply ramp-up during training to the learning rate. 7 | Args: 8 | optimizer: torch.optimizer.Optimizer, the optimizer from which to rampup the value from 9 | max_lr: float, the maximum learning to use at the end of ramp-up. 10 | rampup_length: int, the length of the rampup (number of steps). 11 | exponent: float, the exponent to be used. 12 | """ 13 | 14 | def __init__(self, optimizer, max_lr, rampup_length, exponent=-5.0): 15 | super().__init__(optimizer) 16 | self.rampup_len = rampup_length 17 | self.max_lr = max_lr 18 | self.step_num = 1 19 | self.exponent = exponent 20 | 21 | def _get_scaling_factor(self): 22 | 23 | if self.rampup_len == 0: 24 | return 1.0 25 | else: 26 | 27 | current = np.clip(self.step_num, 0.0, self.rampup_len) 28 | phase = 1.0 - current / self.rampup_len 29 | return float(np.exp(self.exponent * phase * phase)) 30 | 31 | def _get_lr(self): 32 | return self.max_lr * self._get_scaling_factor() 33 | -------------------------------------------------------------------------------- /src/desed_task/nnet/RNN.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | from torch import nn as nn 5 | 6 | 7 | class BidirectionalGRU(nn.Module): 8 | def __init__(self, n_in, n_hidden, dropout=0, num_layers=1): 9 | 10 | """ 11 | Initialization of BidirectionalGRU instance 12 | Args: 13 | n_in: int, number of input 14 | n_hidden: int, number of hidden layers 15 | dropout: flat, dropout 16 | num_layers: int, number of layers 17 | """ 18 | 19 | super(BidirectionalGRU, self).__init__() 20 | self.rnn = nn.GRU( 21 | n_in, 22 | n_hidden, 23 | bidirectional=True, 24 | dropout=dropout, 25 | batch_first=True, 26 | num_layers=num_layers, 27 | ) 28 | 29 | def forward(self, input_feat): 30 | recurrent, _ = self.rnn(input_feat) 31 | return recurrent 32 | 33 | 34 | class BidirectionalLSTM(nn.Module): 35 | def __init__(self, nIn, nHidden, nOut, dropout=0, num_layers=1): 36 | super(BidirectionalLSTM, self).__init__() 37 | self.rnn = nn.LSTM( 38 | nIn, 39 | nHidden // 2, 40 | bidirectional=True, 41 | batch_first=True, 42 | dropout=dropout, 43 | num_layers=num_layers, 44 | ) 45 | self.embedding = nn.Linear(nHidden * 2, nOut) 46 | 47 | def forward(self, input_feat): 48 | recurrent, _ = self.rnn(input_feat) 49 | b, T, h = recurrent.size() 50 | t_rec = recurrent.contiguous().view(b * T, h) 51 | 52 | output = self.embedding(t_rec) # [T * b, nOut] 53 | output = output.view(b, T, -1) 54 | return output 55 | -------------------------------------------------------------------------------- /src/desed_task/nnet/.ipynb_checkpoints/RNN-checkpoint.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | from torch import nn as nn 5 | 6 | 7 | class BidirectionalGRU(nn.Module): 8 | def __init__(self, n_in, n_hidden, dropout=0, num_layers=1): 9 | 10 | """ 11 | 
Initialization of BidirectionalGRU instance 12 | Args: 13 | n_in: int, number of input 14 | n_hidden: int, number of hidden layers 15 | dropout: flat, dropout 16 | num_layers: int, number of layers 17 | """ 18 | 19 | super(BidirectionalGRU, self).__init__() 20 | self.rnn = nn.GRU( 21 | n_in, 22 | n_hidden, 23 | bidirectional=True, 24 | dropout=dropout, 25 | batch_first=True, 26 | num_layers=num_layers, 27 | ) 28 | 29 | def forward(self, input_feat): 30 | recurrent, _ = self.rnn(input_feat) 31 | return recurrent 32 | 33 | 34 | class BidirectionalLSTM(nn.Module): 35 | def __init__(self, nIn, nHidden, nOut, dropout=0, num_layers=1): 36 | super(BidirectionalLSTM, self).__init__() 37 | self.rnn = nn.LSTM( 38 | nIn, 39 | nHidden // 2, 40 | bidirectional=True, 41 | batch_first=True, 42 | dropout=dropout, 43 | num_layers=num_layers, 44 | ) 45 | self.embedding = nn.Linear(nHidden * 2, nOut) 46 | 47 | def forward(self, input_feat): 48 | recurrent, _ = self.rnn(input_feat) 49 | b, T, h = recurrent.size() 50 | t_rec = recurrent.contiguous().view(b * T, h) 51 | 52 | output = self.embedding(t_rec) # [T * b, nOut] 53 | output = output.view(b, T, -1) 54 | return output 55 | -------------------------------------------------------------------------------- /src/local/con_loss.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | import torch.nn as nn 4 | 5 | class SupConLoss(nn.Module): 6 | def __init__(self, temperature=0.10, pos_thresh=0.90): 7 | super(SupConLoss, self).__init__() 8 | self.temperature = temperature 9 | self.pos_thresh = pos_thresh 10 | 11 | def forward(self, features_stu, pseudo_lb_stu, features_proto, lb_proto): 12 | """ 13 | Args: 14 | features_stu: features from student model. 15 | features_proto: features from prototypes. 16 | pseudo_lb_stu: frame-level prob vector for student preds. 17 | lb_proto: labels for prototypes. 18 | """ 19 | bg_prob_stu = ((1.0 - pseudo_lb_stu.max(1)[0])).float().unsqueeze(1) # background class 20 | pseudo_lb_stu = torch.cat([pseudo_lb_stu, bg_prob_stu], dim=1) 21 | 22 | # get (hard) mask 23 | pseudo_lb_stu[pseudo_lb_stu > self.pos_thresh] = 1.0 24 | pseudo_lb_stu[pseudo_lb_stu < self.pos_thresh] = 0.0 25 | 26 | mask = pseudo_lb_stu 27 | 28 | # compute feature similarity logits 29 | # features_stu (n, d), features_proto (c, m, d) -> (c, m, n) -> (n, c, m) -> (n, c) 30 | feat_sim_mat = torch.max(torch.matmul(features_proto, features_stu.T).permute(2, 0, 1), dim=2)[0] 31 | feat_sim_mat = torch.div(feat_sim_mat, self.temperature) 32 | 33 | # for numerical stability 34 | logits_max, _ = torch.max(feat_sim_mat, dim=1, keepdim=True) 35 | feat_sim_mat = feat_sim_mat - logits_max.detach() 36 | 37 | # compute log_prob 38 | exp_logits = torch.exp(feat_sim_mat) 39 | log_prob = feat_sim_mat - torch.log(1e-7 + exp_logits.sum(1, keepdim=True)) 40 | 41 | # compute mean of log-likelihood over positive 42 | mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-7) 43 | 44 | # loss 45 | loss = - mean_log_prob_pos 46 | return loss.mean() 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LGC-SED 2 | 3 | The official implementations of "Semi-supervised Sound Event Detection with Local and Global Consistency Regularization" (accepted by ICASSP 2024). 
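The global-consistency branch of LGC is implemented by the `SupConLoss` module in `src/local/con_loss.py` shown above. The following is a minimal shape-check sketch of how it can be called; the tensor sizes mirror the `LGC` block of `confs/default.yaml` (`feat_dim` 128, `proto_nums` 3, `num_class` 10 plus one background prototype row), while the random inputs and the import path are assumptions rather than the real training pipeline.

```python
import torch
import torch.nn.functional as F

from local.con_loss import SupConLoss  # assumes the working directory is src/, as for train_sed.py

n_frames, n_cls, n_proto, d = 156, 10, 3, 128  # ~156 pooled frames for one 10 s clip (hop 256, net_subsample 4)

criterion = SupConLoss(temperature=0.10, pos_thresh=0.90)

feats_stu = F.normalize(torch.randn(n_frames, d), dim=-1)          # projected student frame features
pseudo_lb = torch.rand(n_frames, n_cls)                            # frame-level class probabilities
protos = F.normalize(torch.randn(n_cls + 1, n_proto, d), dim=-1)   # per-class prototypes plus a background row
proto_lb = torch.arange(n_cls + 1)                                 # not used inside forward(), kept for the interface

loss = criterion(feats_stu, pseudo_lb, protos, proto_lb)           # scalar contrastive loss
```

Frames whose class (or background) probability exceeds `pos_thresh` are treated as positives for the matching prototypes, which is what pulls frame features towards their class prototypes and pushes them away from the others.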
4 | 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | 7 | 8 | 9 | ## Introduction 10 | 11 | LGC pursues local consistency regularization and global consistency regularization simultaneously. The former adopts audio CutMix to change the boundaries of a sound event and helps to learn robust patterns under varying contexts, while the latter leverages a contrastive loss to encourage frame features to be attracted to their corresponding class prototypes and repelled from the prototypes of other classes. Therefore, the intra-class variance is decreased while the inter-class variance is increased, which favors learning a decision boundary in low-density regions. 12 | 13 | Here is an overview of our method: 14 | 15 | ![Overview of LGC](asset/lgc.png) 16 | 17 | 18 | 19 | ## Get started 20 | 21 | We provide the source code for vanilla LGC, which is not combined with audio augmentations or other consistency regularization methods (e.g., SCT, RCT). Besides, LGC is lightweight and can be reproduced on a single RTX 3080 with 10 GB RAM. Here are the instructions to run it: 22 | 23 | 1. Download the whole DESED dataset (note that we do not use the strong real subset, which contains 3470 audio clips). 24 | 2. Build up the environment. The environment of this code is the same as the DCASE 2022 baseline (so the PSDS score we use for evaluation is the threshold-dependent one). 25 | 3. Clone the code: 26 | 27 | ``` 28 | git clone https://github.com/Ming-er/LGC-SED.git 29 | ``` 30 | 31 | 4. Change all required paths in `src/confs/default.yaml` to your own paths. 32 | 33 | 5. Then start training with: 34 | 35 | ``` 36 | python train_sed.py 37 | ``` 38 | 39 | 40 | 41 | ## Citation 42 | 43 | If you want to cite this paper: 44 | 45 | ``` 46 | @article{li2023semi, 47 | title={Semi-supervised Sound Event Detection with Local and Global Consistency Regularization}, 48 | author={Li, Yiming and Wang, Xiangdong and Liu, Hong and Tao, Rui and Yan, Long and Ouchi, Kazushige}, 49 | journal={arXiv preprint arXiv:2309.08355}, 50 | year={2023} 51 | } 52 | ``` -------------------------------------------------------------------------------- /src/local/resample_folder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | from pathlib import Path 5 | 6 | import librosa 7 | import torch 8 | import torchaudio 9 | import tqdm 10 | 11 | parser = argparse.ArgumentParser("Resample a folder recursively") 12 | parser.add_argument( 13 | "--in_dir", 14 | type=str, 15 | default="/media/sam/bx500/DCASE_DATA/dataset/audio/validation/", 16 | ) 17 | parser.add_argument("--out_dir", type=str, default="/tmp/val16k") 18 | parser.add_argument("--target_fs", default=16000) 19 | parser.add_argument("--regex", type=str, default="*.wav") 20 | 21 | 22 | def resample(audio, orig_fs, target_fs): 23 | """ 24 | Resamples the audio given as input at the target_fs sample rate, if the target sample rate and the 25 | original sample rate are different. 
26 | 27 | Args: 28 | audio (Tensor): audio to resample 29 | orig_fs (int): original sample rate 30 | target_fs (int): target sample rate 31 | 32 | Returns: 33 | Tensor: audio resampled 34 | """ 35 | out = [] 36 | for c in range(audio.shape[0]): 37 | tmp = audio[c].detach().cpu().numpy() 38 | if target_fs != orig_fs: 39 | tmp = librosa.resample(tmp, orig_fs, target_fs) 40 | out.append(torch.from_numpy(tmp)) 41 | out = torch.stack(out) 42 | return out 43 | 44 | 45 | def resample_folder(in_dir, out_dir, target_fs=16000, regex="*.wav"): 46 | """ 47 | Resamples the audio files contained in the in_dir folder and saves them in out_dir folder 48 | 49 | Args: 50 | in_dir (str): path to audio directory (audio to be resampled) 51 | out_dir (str): path to audio resampled directory 52 | target_fs (int, optional): target sample rate. Defaults to 16000. 53 | regex (str, optional): regular expression for extension of file. Defaults to "*.wav". 54 | """ 55 | compute = True 56 | files = glob.glob(os.path.join(in_dir, regex)) 57 | if os.path.exists(out_dir): 58 | out_files = glob.glob(os.path.join(out_dir, regex)) 59 | if len(files) == len(out_files): 60 | compute = False 61 | 62 | if compute: 63 | for f in tqdm.tqdm(files): 64 | audio, orig_fs = torchaudio.load(f) 65 | audio = resample(audio, orig_fs, target_fs) 66 | 67 | os.makedirs( 68 | Path(os.path.join(out_dir, Path(f).relative_to(Path(in_dir)))).parent, 69 | exist_ok=True, 70 | ) 71 | torchaudio.save( 72 | os.path.join(out_dir, Path(f).relative_to(Path(in_dir))), 73 | audio, 74 | target_fs, 75 | ) 76 | return compute 77 | 78 | 79 | if __name__ == "__main__": 80 | args = parser.parse_args() 81 | resample_folder(args.in_dir, args.out_dir, int(args.target_fs), args.regex) 82 | -------------------------------------------------------------------------------- /src/confs/default.yaml: -------------------------------------------------------------------------------- 1 | training: 2 | # batch size: [synth, weak, unlabel] 3 | batch_size: [12, 12, 24] 4 | batch_size_val: 24 5 | const_max: 2 # max weight used for self supervised loss 6 | n_epochs_warmup: 50 # num epochs used for exponential warmup 7 | num_workers: 8 # change according to your cpu 8 | n_epochs: 200 # max num epochs 9 | early_stop_patience: 200 # Same as number of epochs by default, so no early stopping used 10 | accumulate_batches: 1 11 | gradient_clip: 0. # 0 no gradient clipping 12 | median_window: [3, 28, 7, 4, 7, 22, 48, 19, 10, 50] # length of median filter used to smooth prediction in inference (nb of output frames) 13 | val_thresholds: [0.5] # thresholds used to compute f1 intersection in validation. 14 | n_test_thresholds: 50 # number of thresholds used to compute psds in test 15 | ema_factor: 0.999 # ema factor for mean teacher 16 | self_sup_loss: mse # bce or mse for self supervised mean teacher loss 17 | backend: dp # pytorch lightning backend, ddp, dp or None 18 | validation_interval: 1 # perform validation every X epoch, 1 default 19 | weak_split: 0.9 20 | seed: 42 21 | precision: 32 22 | mixup: soft # Soft mixup gives the ratio of the mix to the labels, hard mixup gives a 1 to every label present. 
23 | obj_metric_synth_type: teacher_intersection 24 | precision: 32 25 | 26 | LGC: 27 | proto_nums: 3 28 | start_contrast_epochs: 100 29 | num_class: 10 30 | prototype_ema: 0.996 31 | feat_dim: 128 32 | pos_thresh: 0.90 33 | neg_thresh: 0.50 34 | 35 | scaler: 36 | statistic: instance # instance or dataset-wide statistic 37 | normtype: minmax # minmax or standard or mean normalization 38 | dims: [1, 2] # dimensions over which normalization is applied 39 | savepath: ./scaler.ckpt # path to scaler checkpoint 40 | 41 | data: # change with your paths if different. 42 | # NOTE: if you have data in 44kHz only then synth_folder will be the path where 43 | # resampled data will be placed. 44 | synth_folder: "/PATH/TO/YOUR/DATA" 45 | synth_folder_44k: "/PATH/TO/YOUR/DATA" 46 | synth_tsv: "/PATH/TO/YOUR/DATA" 47 | strong_folder: "/PATH/TO/YOUR/DATA" 48 | strong_folder_44k: "/PATH/TO/YOUR/DATA" 49 | strong_tsv: "/PATH/TO/YOUR/DATA" 50 | weak_folder: "/PATH/TO/YOUR/DATA" 51 | weak_folder_44k: "/PATH/TO/YOUR/DATA" 52 | weak_tsv: "/PATH/TO/YOUR/DATA" 53 | unlabeled_folder: "/PATH/TO/YOUR/DATA" 54 | unlabeled_folder_44k: "/PATH/TO/YOUR/DATA" 55 | val_folder: "/PATH/TO/YOUR/DATA" 56 | val_folder_44k: "/PATH/TO/YOUR/DATA" 57 | val_tsv: "/PATH/TO/YOUR/DATA" 58 | val_dur: "/PATH/TO/YOUR/DATA" 59 | test_folder: "/PATH/TO/YOUR/DATA" 60 | test_folder_44k: "/PATH/TO/YOUR/DATA" 61 | test_tsv: "/PATH/TO/YOUR/DATA" 62 | test_dur: "/PATH/TO/YOUR/DATA" 63 | eval_folder: "/PATH/TO/YOUR/DATA" 64 | eval_folder_44k: "/PATH/TO/YOUR/DATA" 65 | audio_max_len: 10 66 | fs: 16000 67 | net_subsample: 4 68 | opt: 69 | lr: 0.001 70 | feats: 71 | n_mels: 128 72 | n_filters: 2048 73 | hop_length: 256 74 | n_window: 2048 75 | sample_rate: 16000 76 | f_min: 0 77 | f_max: 8000 78 | net: 79 | dropout: 0.5 80 | rnn_layers: 2 81 | n_in_channel: 1 82 | nclass: 10 83 | attention: True 84 | n_RNN_cell: 128 85 | activation: glu 86 | rnn_type: BGRU 87 | kernel_size: [3, 3, 3, 3, 3, 3, 3] 88 | padding: [1, 1, 1, 1, 1, 1, 1] 89 | stride: [1, 1, 1, 1, 1, 1, 1] 90 | nb_filters: [ 16, 32, 64, 128, 128, 128, 128] 91 | pooling: [ [ 2, 2 ], [ 2, 2 ], [ 1, 2 ], [ 1, 2 ], [ 1, 2 ], [ 1, 2 ], [ 1, 2 ] ] 92 | dropout_recurrent: 0 93 | -------------------------------------------------------------------------------- /src/desed_task/dataio/sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Sampler 2 | import numpy as np 3 | 4 | 5 | class ConcatDatasetBatchSampler(Sampler): 6 | """This sampler is built to work with a standard Pytorch ConcatDataset. 7 | From SpeechBrain dataio see https://github.com/speechbrain/ 8 | 9 | It is used to retrieve elements from the different concatenated datasets placing them in the same batch 10 | with proportion specified by batch_sizes, e.g 8, 16 means each batch will 11 | be of 24 elements with the first 8 belonging to the first dataset in ConcatDataset 12 | object and the last 16 to the second. 13 | More than two datasets are supported, in that case you need to provide 3 batch 14 | sizes. 15 | 16 | Note 17 | ---- 18 | Batched are drawn from the datasets till the one with smallest length is exhausted. 19 | Thus number of examples in your training epoch is dictated by the dataset 20 | whose length is the smallest. 21 | 22 | 23 | Arguments 24 | --------- 25 | samplers : int 26 | The base seed to use for the random number generator. It is recommended 27 | to use a value which has a good mix of 0 and 1 bits. 28 | batch_sizes: list 29 | Batch sizes. 
30 | epoch : int 31 | The epoch to start at. 32 | """ 33 | 34 | def __init__(self, samplers, batch_sizes: (tuple, list), epoch=0) -> None: 35 | 36 | if not isinstance(samplers, (list, tuple)): 37 | raise ValueError( 38 | "samplers should be a list or tuple of Pytorch Samplers, " 39 | "but got samplers={}".format(batch_sizes) 40 | ) 41 | 42 | if not isinstance(batch_sizes, (list, tuple)): 43 | raise ValueError( 44 | "batch_sizes should be a list or tuple of integers, " 45 | "but got batch_sizes={}".format(batch_sizes) 46 | ) 47 | 48 | if not len(batch_sizes) == len(samplers): 49 | raise ValueError("batch_sizes and samplers should be have same length") 50 | 51 | self.batch_sizes = batch_sizes 52 | self.samplers = samplers 53 | self.offsets = [0] + np.cumsum([len(x) for x in self.samplers]).tolist()[:-1] 54 | 55 | self.epoch = epoch 56 | self.set_epoch(self.epoch) 57 | 58 | def _iter_one_dataset(self, c_batch_size, c_sampler, c_offset): 59 | batch = [] 60 | for idx in c_sampler: 61 | batch.append(c_offset + idx) 62 | if len(batch) == c_batch_size: 63 | yield batch 64 | 65 | def set_epoch(self, epoch): 66 | if hasattr(self.samplers[0], "epoch"): 67 | for s in self.samplers: 68 | s.set_epoch(epoch) 69 | 70 | def __iter__(self): 71 | 72 | iterators = [iter(i) for i in self.samplers] 73 | tot_batch = [] 74 | 75 | for b_num in range(len(self)): 76 | for samp_idx in range(len(self.samplers)): 77 | c_batch = [] 78 | while len(c_batch) < self.batch_sizes[samp_idx]: 79 | c_batch.append(self.offsets[samp_idx] + next(iterators[samp_idx])) 80 | tot_batch.extend(c_batch) 81 | yield tot_batch 82 | tot_batch = [] 83 | 84 | def __len__(self): 85 | 86 | min_len = float("inf") 87 | for idx, sampler in enumerate(self.samplers): 88 | c_len = (len(sampler)) // self.batch_sizes[idx] 89 | 90 | min_len = min(c_len, min_len) 91 | return min_len 92 | -------------------------------------------------------------------------------- /src/desed_task/nnet/CNN.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class GLU(nn.Module): 6 | def __init__(self, input_num): 7 | super(GLU, self).__init__() 8 | self.sigmoid = nn.Sigmoid() 9 | self.linear = nn.Linear(input_num, input_num) 10 | 11 | def forward(self, x): 12 | lin = self.linear(x.permute(0, 2, 3, 1)) 13 | lin = lin.permute(0, 3, 1, 2) 14 | sig = self.sigmoid(x) 15 | res = lin * sig 16 | return res 17 | 18 | 19 | class ContextGating(nn.Module): 20 | def __init__(self, input_num): 21 | super(ContextGating, self).__init__() 22 | self.sigmoid = nn.Sigmoid() 23 | self.linear = nn.Linear(input_num, input_num) 24 | 25 | def forward(self, x): 26 | lin = self.linear(x.permute(0, 2, 3, 1)) 27 | lin = lin.permute(0, 3, 1, 2) 28 | sig = self.sigmoid(lin) 29 | res = x * sig 30 | return res 31 | 32 | 33 | class CNN(nn.Module): 34 | def __init__( 35 | self, 36 | n_in_channel, 37 | activation="Relu", 38 | conv_dropout=0, 39 | kernel_size=[3, 3, 3], 40 | padding=[1, 1, 1], 41 | stride=[1, 1, 1], 42 | nb_filters=[64, 64, 64], 43 | pooling=[(1, 4), (1, 4), (1, 4)], 44 | normalization="batch", 45 | **transformer_kwargs 46 | ): 47 | """ 48 | Initialization of CNN network s 49 | 50 | Args: 51 | n_in_channel: int, number of input channel 52 | activation: str, activation function 53 | conv_dropout: float, dropout 54 | kernel_size: kernel size 55 | padding: padding 56 | stride: list, stride 57 | nb_filters: number of filters 58 | pooling: list of tuples, time and frequency pooling 59 | normalization: choose 
between "batch" for BatchNormalization and "layer" for LayerNormalization. 60 | """ 61 | super(CNN, self).__init__() 62 | 63 | self.nb_filters = nb_filters 64 | cnn = nn.Sequential() 65 | 66 | def conv(i, normalization="batch", dropout=None, activ="relu"): 67 | nIn = n_in_channel if i == 0 else nb_filters[i - 1] 68 | nOut = nb_filters[i] 69 | cnn.add_module( 70 | "conv{0}".format(i), 71 | nn.Conv2d(nIn, nOut, kernel_size[i], stride[i], padding[i]), 72 | ) 73 | if normalization == "batch": 74 | cnn.add_module( 75 | "batchnorm{0}".format(i), 76 | nn.BatchNorm2d(nOut, eps=0.001, momentum=0.99), 77 | ) 78 | elif normalization == "layer": 79 | cnn.add_module("layernorm{0}".format(i), nn.GroupNorm(1, nOut)) 80 | 81 | if activ.lower() == "leakyrelu": 82 | cnn.add_module("relu{0}".format(i), nn.LeakyReLU(0.2)) 83 | elif activ.lower() == "relu": 84 | cnn.add_module("relu{0}".format(i), nn.ReLU()) 85 | elif activ.lower() == "glu": 86 | cnn.add_module("glu{0}".format(i), GLU(nOut)) 87 | elif activ.lower() == "cg": 88 | cnn.add_module("cg{0}".format(i), ContextGating(nOut)) 89 | 90 | if dropout is not None: 91 | cnn.add_module("dropout{0}".format(i), nn.Dropout(dropout)) 92 | 93 | # 128x862x64 94 | for i in range(len(nb_filters)): 95 | conv(i, normalization=normalization, dropout=conv_dropout, activ=activation) 96 | cnn.add_module( 97 | "pooling{0}".format(i), nn.AvgPool2d(pooling[i]) 98 | ) # bs x tframe x mels 99 | 100 | self.cnn = cnn 101 | 102 | def forward(self, x): 103 | """ 104 | Forward step of the CNN module 105 | 106 | Args: 107 | x (Tensor): input batch of size (batch_size, n_channels, n_frames, n_freq) 108 | 109 | Returns: 110 | Tensor: batch embedded 111 | """ 112 | # conv features 113 | x = self.cnn(x) 114 | return x 115 | -------------------------------------------------------------------------------- /src/desed_task/nnet/.ipynb_checkpoints/CNN-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class GLU(nn.Module): 6 | def __init__(self, input_num): 7 | super(GLU, self).__init__() 8 | self.sigmoid = nn.Sigmoid() 9 | self.linear = nn.Linear(input_num, input_num) 10 | 11 | def forward(self, x): 12 | lin = self.linear(x.permute(0, 2, 3, 1)) 13 | lin = lin.permute(0, 3, 1, 2) 14 | sig = self.sigmoid(x) 15 | res = lin * sig 16 | return res 17 | 18 | 19 | class ContextGating(nn.Module): 20 | def __init__(self, input_num): 21 | super(ContextGating, self).__init__() 22 | self.sigmoid = nn.Sigmoid() 23 | self.linear = nn.Linear(input_num, input_num) 24 | 25 | def forward(self, x): 26 | lin = self.linear(x.permute(0, 2, 3, 1)) 27 | lin = lin.permute(0, 3, 1, 2) 28 | sig = self.sigmoid(lin) 29 | res = x * sig 30 | return res 31 | 32 | 33 | class CNN(nn.Module): 34 | def __init__( 35 | self, 36 | n_in_channel, 37 | activation="Relu", 38 | conv_dropout=0, 39 | kernel_size=[3, 3, 3], 40 | padding=[1, 1, 1], 41 | stride=[1, 1, 1], 42 | nb_filters=[64, 64, 64], 43 | pooling=[(1, 4), (1, 4), (1, 4)], 44 | normalization="batch", 45 | **transformer_kwargs 46 | ): 47 | """ 48 | Initialization of CNN network s 49 | 50 | Args: 51 | n_in_channel: int, number of input channel 52 | activation: str, activation function 53 | conv_dropout: float, dropout 54 | kernel_size: kernel size 55 | padding: padding 56 | stride: list, stride 57 | nb_filters: number of filters 58 | pooling: list of tuples, time and frequency pooling 59 | normalization: choose between "batch" for BatchNormalization and "layer" for 
LayerNormalization. 60 | """ 61 | super(CNN, self).__init__() 62 | 63 | self.nb_filters = nb_filters 64 | cnn = nn.Sequential() 65 | 66 | def conv(i, normalization="batch", dropout=None, activ="relu"): 67 | nIn = n_in_channel if i == 0 else nb_filters[i - 1] 68 | nOut = nb_filters[i] 69 | cnn.add_module( 70 | "conv{0}".format(i), 71 | nn.Conv2d(nIn, nOut, kernel_size[i], stride[i], padding[i]), 72 | ) 73 | if normalization == "batch": 74 | cnn.add_module( 75 | "batchnorm{0}".format(i), 76 | nn.BatchNorm2d(nOut, eps=0.001, momentum=0.99), 77 | ) 78 | elif normalization == "layer": 79 | cnn.add_module("layernorm{0}".format(i), nn.GroupNorm(1, nOut)) 80 | 81 | if activ.lower() == "leakyrelu": 82 | cnn.add_module("relu{0}".format(i), nn.LeakyReLU(0.2)) 83 | elif activ.lower() == "relu": 84 | cnn.add_module("relu{0}".format(i), nn.ReLU()) 85 | elif activ.lower() == "glu": 86 | cnn.add_module("glu{0}".format(i), GLU(nOut)) 87 | elif activ.lower() == "cg": 88 | cnn.add_module("cg{0}".format(i), ContextGating(nOut)) 89 | 90 | if dropout is not None: 91 | cnn.add_module("dropout{0}".format(i), nn.Dropout(dropout)) 92 | 93 | # 128x862x64 94 | for i in range(len(nb_filters)): 95 | conv(i, normalization=normalization, dropout=conv_dropout, activ=activation) 96 | cnn.add_module( 97 | "pooling{0}".format(i), nn.AvgPool2d(pooling[i]) 98 | ) # bs x tframe x mels 99 | 100 | self.cnn = cnn 101 | 102 | def forward(self, x): 103 | """ 104 | Forward step of the CNN module 105 | 106 | Args: 107 | x (Tensor): input batch of size (batch_size, n_channels, n_frames, n_freq) 108 | 109 | Returns: 110 | Tensor: batch embedded 111 | """ 112 | # conv features 113 | x = self.cnn(x) 114 | return x 115 | -------------------------------------------------------------------------------- /src/desed_task/utils/scaler.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import torch 3 | 4 | 5 | class TorchScaler(torch.nn.Module): 6 | """ 7 | This torch module implements scaling for input tensors, both instance based 8 | and dataset-wide statistic based. 9 | 10 | Args: 11 | statistic: str, (default='dataset'), represent how to compute the statistic for normalisation. 12 | Choice in {'dataset', 'instance'}. 13 | 'dataset' needs to be 'fit()' with a dataloader of the dataset. 14 | 'instance' apply the normalisation at an instance-level, so compute the statitics on the instance 15 | specified, it can be a clip or a batch. 16 | normtype: str, (default='standard') the type of normalisation to use. 17 | Choice in {'standard', 'mean', 'minmax'}. 'standard' applies a classic normalisation with mean and standard 18 | deviation. 'mean' substract the mean to the data. 'minmax' substract the minimum of the data and divide by 19 | the difference between max and min. 20 | """ 21 | 22 | def __init__(self, statistic="dataset", normtype="standard", dims=(1, 2), eps=1e-8): 23 | super(TorchScaler, self).__init__() 24 | assert statistic in ["dataset", "instance", None] 25 | assert normtype in ["standard", "mean", "minmax", None] 26 | if statistic == "dataset" and normtype == "minmax": 27 | raise NotImplementedError( 28 | "statistic==dataset and normtype==minmax is not currently implemented." 
29 | ) 30 | self.statistic = statistic 31 | self.normtype = normtype 32 | self.dims = dims 33 | self.eps = eps 34 | 35 | def load_state_dict(self, state_dict, strict=True): 36 | if self.statistic == "dataset": 37 | super(TorchScaler, self).load_state_dict(state_dict, strict) 38 | 39 | def _load_from_state_dict( 40 | self, 41 | state_dict, 42 | prefix, 43 | local_metadata, 44 | strict, 45 | missing_keys, 46 | unexpected_keys, 47 | error_msgs, 48 | ): 49 | if self.statistic == "dataset": 50 | super(TorchScaler, self)._load_from_state_dict( 51 | state_dict, 52 | prefix, 53 | local_metadata, 54 | strict, 55 | missing_keys, 56 | unexpected_keys, 57 | error_msgs, 58 | ) 59 | 60 | def fit(self, dataloader, transform_func=lambda x: x[0]): 61 | """ 62 | Scaler fitting 63 | 64 | Args: 65 | dataloader (DataLoader): training data DataLoader 66 | transform_func (lambda function, optional): Transforms applied to the data. 67 | Defaults to lambdax:x[0]. 68 | """ 69 | indx = 0 70 | for batch in tqdm.tqdm(dataloader): 71 | 72 | feats = transform_func(batch) 73 | if indx == 0: 74 | mean = torch.mean(feats, self.dims, keepdim=True).mean(0).unsqueeze(0) 75 | mean_squared = ( 76 | torch.mean(feats ** 2, self.dims, keepdim=True).mean(0).unsqueeze(0) 77 | ) 78 | else: 79 | mean += torch.mean(feats, self.dims, keepdim=True).mean(0).unsqueeze(0) 80 | mean_squared += ( 81 | torch.mean(feats ** 2, self.dims, keepdim=True).mean(0).unsqueeze(0) 82 | ) 83 | indx += 1 84 | 85 | mean /= indx 86 | mean_squared /= indx 87 | 88 | self.register_buffer("mean", mean) 89 | self.register_buffer("mean_squared", mean_squared) 90 | 91 | def forward(self, tensor): 92 | 93 | if self.statistic is None or self.normtype is None: 94 | return tensor 95 | 96 | if self.statistic == "dataset": 97 | assert hasattr(self, "mean") and hasattr( 98 | self, "mean_squared" 99 | ), "TorchScaler should be fit before used if statistics=dataset" 100 | assert tensor.ndim == self.mean.ndim, "Pre-computed statistics " 101 | if self.normtype == "mean": 102 | return tensor - self.mean 103 | elif self.normtype == "standard": 104 | std = torch.sqrt(self.mean_squared - self.mean ** 2) 105 | return (tensor - self.mean) / (std + self.eps) 106 | else: 107 | raise NotImplementedError 108 | 109 | else: 110 | if self.normtype == "mean": 111 | return tensor - torch.mean(tensor, self.dims, keepdim=True) 112 | elif self.normtype == "standard": 113 | return (tensor - torch.mean(tensor, self.dims, keepdim=True)) / ( 114 | torch.std(tensor, self.dims, keepdim=True) + self.eps 115 | ) 116 | elif self.normtype == "minmax": 117 | return (tensor - torch.amin(tensor, dim=self.dims, keepdim=True)) / ( 118 | torch.amax(tensor, dim=self.dims, keepdim=True) 119 | - torch.amin(tensor, dim=self.dims, keepdim=True) 120 | + self.eps 121 | ) 122 | -------------------------------------------------------------------------------- /src/desed_task/nnet/CRNN.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch.nn as nn 4 | import torch 5 | from .RNN import BidirectionalGRU 6 | from .CNN import CNN 7 | import torch.nn.functional as F 8 | 9 | class CRNN(nn.Module): 10 | def __init__( 11 | self, 12 | n_in_channel=1, 13 | nclass=10, 14 | attention=True, 15 | activation="glu", 16 | dropout=0.5, 17 | train_cnn=True, 18 | rnn_type="BGRU", 19 | n_RNN_cell=128, 20 | n_layers_RNN=2, 21 | dropout_recurrent=0, 22 | cnn_integration=False, 23 | freeze_bn=False, 24 | use_embeddings=False, 25 | embedding_size=527, 26 | 
embedding_type="global", 27 | frame_emb_enc_dim=512, 28 | **kwargs, 29 | ): 30 | """ 31 | Initialization of CRNN model 32 | 33 | Args: 34 | n_in_channel: int, number of input channel 35 | n_class: int, number of classes 36 | attention: bool, adding attention layer or not 37 | activation: str, activation function 38 | dropout: float, dropout 39 | train_cnn: bool, training cnn layers 40 | rnn_type: str, rnn type 41 | n_RNN_cell: int, RNN nodes 42 | n_layer_RNN: int, number of RNN layers 43 | dropout_recurrent: float, recurrent layers dropout 44 | cnn_integration: bool, integration of cnn 45 | freeze_bn: 46 | **kwargs: keywords arguments for CNN. 47 | """ 48 | super(CRNN, self).__init__() 49 | self.n_in_channel = n_in_channel 50 | self.attention = attention 51 | self.cnn_integration = cnn_integration 52 | self.freeze_bn = freeze_bn 53 | self.use_embeddings = use_embeddings 54 | self.embedding_type = embedding_type 55 | 56 | n_in_cnn = n_in_channel 57 | 58 | if cnn_integration: 59 | n_in_cnn = 1 60 | 61 | self.cnn = CNN( 62 | n_in_channel=n_in_cnn, activation=activation, conv_dropout=dropout, **kwargs 63 | ) 64 | 65 | 66 | self.train_cnn = train_cnn 67 | 68 | if not train_cnn: 69 | for param in self.cnn.parameters(): 70 | param.requires_grad = False 71 | 72 | if rnn_type == "BGRU": 73 | nb_in = self.cnn.nb_filters[-1] 74 | if self.cnn_integration: 75 | nb_in = nb_in * n_in_channel 76 | self.rnn = BidirectionalGRU( 77 | n_in=nb_in, 78 | n_hidden=n_RNN_cell, 79 | dropout=dropout_recurrent, 80 | num_layers=n_layers_RNN, 81 | ) 82 | else: 83 | NotImplementedError("Only BGRU supported for CRNN for now") 84 | 85 | self.cnn_proj_fc = nn.Linear(n_RNN_cell * 2, n_RNN_cell) 86 | self.cnn_proj_bn = nn.BatchNorm1d(nb_in, eps=0.001, momentum=0.99) 87 | self.cnn_relu = nn.LeakyReLU(inplace=False, negative_slope=0.2) 88 | 89 | self.dropout = nn.Dropout(dropout) 90 | self.dense = nn.Linear(n_RNN_cell * 2, nclass) 91 | self.sigmoid = nn.Sigmoid() 92 | 93 | if self.attention: 94 | self.dense_softmax = nn.Linear(n_RNN_cell * 2, nclass) 95 | self.softmax = nn.Softmax(dim=-1) 96 | 97 | 98 | if self.use_embeddings: 99 | if self.embedding_type == "frame": 100 | self.frame_embs_encoder = nn.GRU(batch_first=True, input_size=embedding_size, 101 | hidden_size=512, 102 | bidirectional=True) 103 | self.shrink_emb = torch.nn.Sequential(torch.nn.Linear(2 * frame_emb_enc_dim, nb_in), 104 | torch.nn.LayerNorm(nb_in)) 105 | else: 106 | self.shrink_emb = torch.nn.Sequential(torch.nn.Linear(embedding_size, nb_in), 107 | torch.nn.LayerNorm(nb_in)) 108 | self.cat_tf = torch.nn.Linear(2*nb_in, nb_in) 109 | 110 | def forward(self, x, pad_mask=None, embeddings=None, proj=False, refine=False, prototype_vec=None): 111 | 112 | x = x.transpose(1, 2).unsqueeze(1) 113 | 114 | # input size : (batch_size, n_channels, n_frames, n_freq) 115 | if self.cnn_integration: 116 | bs_in, nc_in = x.size(0), x.size(1) 117 | x = x.view(bs_in * nc_in, 1, *x.shape[2:]) 118 | 119 | # conv features 120 | x = self.cnn(x) 121 | bs, chan, frames, freq = x.size() 122 | if self.cnn_integration: 123 | x = x.reshape(bs_in, chan * nc_in, frames, freq) 124 | 125 | if freq != 1: 126 | warnings.warn( 127 | f"Output shape is: {(bs, frames, chan * freq)}, from {freq} staying freq" 128 | ) 129 | x = x.permute(0, 2, 1, 3) 130 | x = x.contiguous().view(bs, frames, chan * freq) 131 | else: 132 | x = x.squeeze(-1) 133 | x = x.permute(0, 2, 1) # [bs, frames, chan] 134 | 135 | # rnn features 136 | if self.use_embeddings: 137 | if self.embedding_type == "global": 138 | x = 
self.cat_tf(torch.cat((x, self.shrink_emb(embeddings).unsqueeze(1).repeat(1, x.shape[1], 1)), -1)) 139 | else: 140 | # there can be some mismatch between seq length of cnn of crnn and the pretrained embeddings, we use an rnn 141 | # as an encoder and we use the last state 142 | last, _ = self.frame_embs_encoder(embeddings.transpose(1, 2)) 143 | embeddings = last[:, -1] 144 | x = self.cat_tf(torch.cat((x, self.shrink_emb(embeddings).unsqueeze(1).repeat(1, x.shape[1], 1)), -1)) 145 | 146 | x = self.rnn(x) 147 | x = self.dropout(x) 148 | 149 | # projector: conv + relu + norm 150 | if proj: 151 | feat = self.cnn_proj_fc(x) 152 | feat = self.cnn_relu(feat) 153 | embed_dim = feat.size(-1) 154 | feat = feat.reshape(-1, embed_dim) 155 | feat = F.normalize(feat) 156 | 157 | strong = self.dense(x) # [bs, frames, nclass] 158 | strong = self.sigmoid(strong) 159 | 160 | 161 | if self.attention: 162 | sof = self.dense_softmax(x) # [bs, frames, nclass] 163 | if not pad_mask is None: 164 | sof = sof.masked_fill(pad_mask.transpose(1, 2), -1e30) # mask attention 165 | sof = self.softmax(sof) 166 | sof = torch.clamp(sof, min=1e-7, max=1) 167 | weak = (strong * sof).sum(1) / sof.sum(1) # [bs, nclass] 168 | else: 169 | weak = strong.mean(1) 170 | 171 | if proj: 172 | return strong.transpose(1, 2), weak, feat 173 | else: 174 | return strong.transpose(1, 2), weak 175 | 176 | def train(self, mode=True): 177 | """ 178 | Override the default train() to freeze the BN parameters 179 | """ 180 | super(CRNN, self).train(mode) 181 | if self.freeze_bn: 182 | print("Freezing Mean/Var of BatchNorm2D.") 183 | if self.freeze_bn: 184 | print("Freezing Weight/Bias of BatchNorm2D.") 185 | if self.freeze_bn: 186 | for m in self.modules(): 187 | if isinstance(m, nn.BatchNorm2d): 188 | m.eval() 189 | if self.freeze_bn: 190 | m.weight.requires_grad = False 191 | m.bias.requires_grad = False 192 | -------------------------------------------------------------------------------- /src/desed_task/nnet/.ipynb_checkpoints/CRNN-checkpoint.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch.nn as nn 4 | import torch 5 | from .RNN import BidirectionalGRU 6 | from .CNN import CNN 7 | import torch.nn.functional as F 8 | 9 | class CRNN(nn.Module): 10 | def __init__( 11 | self, 12 | n_in_channel=1, 13 | nclass=10, 14 | attention=True, 15 | activation="glu", 16 | dropout=0.5, 17 | train_cnn=True, 18 | rnn_type="BGRU", 19 | n_RNN_cell=128, 20 | n_layers_RNN=2, 21 | dropout_recurrent=0, 22 | cnn_integration=False, 23 | freeze_bn=False, 24 | use_embeddings=False, 25 | embedding_size=527, 26 | embedding_type="global", 27 | frame_emb_enc_dim=512, 28 | **kwargs, 29 | ): 30 | """ 31 | Initialization of CRNN model 32 | 33 | Args: 34 | n_in_channel: int, number of input channel 35 | n_class: int, number of classes 36 | attention: bool, adding attention layer or not 37 | activation: str, activation function 38 | dropout: float, dropout 39 | train_cnn: bool, training cnn layers 40 | rnn_type: str, rnn type 41 | n_RNN_cell: int, RNN nodes 42 | n_layer_RNN: int, number of RNN layers 43 | dropout_recurrent: float, recurrent layers dropout 44 | cnn_integration: bool, integration of cnn 45 | freeze_bn: 46 | **kwargs: keywords arguments for CNN. 
47 | """ 48 | super(CRNN, self).__init__() 49 | self.n_in_channel = n_in_channel 50 | self.attention = attention 51 | self.cnn_integration = cnn_integration 52 | self.freeze_bn = freeze_bn 53 | self.use_embeddings = use_embeddings 54 | self.embedding_type = embedding_type 55 | 56 | n_in_cnn = n_in_channel 57 | 58 | if cnn_integration: 59 | n_in_cnn = 1 60 | 61 | self.cnn = CNN( 62 | n_in_channel=n_in_cnn, activation=activation, conv_dropout=dropout, **kwargs 63 | ) 64 | 65 | 66 | self.train_cnn = train_cnn 67 | 68 | if not train_cnn: 69 | for param in self.cnn.parameters(): 70 | param.requires_grad = False 71 | 72 | if rnn_type == "BGRU": 73 | nb_in = self.cnn.nb_filters[-1] 74 | if self.cnn_integration: 75 | nb_in = nb_in * n_in_channel 76 | self.rnn = BidirectionalGRU( 77 | n_in=nb_in, 78 | n_hidden=n_RNN_cell, 79 | dropout=dropout_recurrent, 80 | num_layers=n_layers_RNN, 81 | ) 82 | else: 83 | NotImplementedError("Only BGRU supported for CRNN for now") 84 | 85 | self.cnn_proj_fc = nn.Linear(n_RNN_cell * 2, n_RNN_cell) 86 | self.cnn_proj_bn = nn.BatchNorm1d(nb_in, eps=0.001, momentum=0.99) 87 | self.cnn_relu = nn.LeakyReLU(inplace=False, negative_slope=0.2) 88 | 89 | self.dropout = nn.Dropout(dropout) 90 | self.dense = nn.Linear(n_RNN_cell * 2, nclass) 91 | self.sigmoid = nn.Sigmoid() 92 | 93 | if self.attention: 94 | self.dense_softmax = nn.Linear(n_RNN_cell * 2, nclass) 95 | self.softmax = nn.Softmax(dim=-1) 96 | 97 | 98 | if self.use_embeddings: 99 | if self.embedding_type == "frame": 100 | self.frame_embs_encoder = nn.GRU(batch_first=True, input_size=embedding_size, 101 | hidden_size=512, 102 | bidirectional=True) 103 | self.shrink_emb = torch.nn.Sequential(torch.nn.Linear(2 * frame_emb_enc_dim, nb_in), 104 | torch.nn.LayerNorm(nb_in)) 105 | else: 106 | self.shrink_emb = torch.nn.Sequential(torch.nn.Linear(embedding_size, nb_in), 107 | torch.nn.LayerNorm(nb_in)) 108 | self.cat_tf = torch.nn.Linear(2*nb_in, nb_in) 109 | 110 | def forward(self, x, pad_mask=None, embeddings=None, proj=False, refine=False, prototype_vec=None): 111 | 112 | x = x.transpose(1, 2).unsqueeze(1) 113 | 114 | # input size : (batch_size, n_channels, n_frames, n_freq) 115 | if self.cnn_integration: 116 | bs_in, nc_in = x.size(0), x.size(1) 117 | x = x.view(bs_in * nc_in, 1, *x.shape[2:]) 118 | 119 | # conv features 120 | x = self.cnn(x) 121 | bs, chan, frames, freq = x.size() 122 | if self.cnn_integration: 123 | x = x.reshape(bs_in, chan * nc_in, frames, freq) 124 | 125 | if freq != 1: 126 | warnings.warn( 127 | f"Output shape is: {(bs, frames, chan * freq)}, from {freq} staying freq" 128 | ) 129 | x = x.permute(0, 2, 1, 3) 130 | x = x.contiguous().view(bs, frames, chan * freq) 131 | else: 132 | x = x.squeeze(-1) 133 | x = x.permute(0, 2, 1) # [bs, frames, chan] 134 | 135 | # rnn features 136 | if self.use_embeddings: 137 | if self.embedding_type == "global": 138 | x = self.cat_tf(torch.cat((x, self.shrink_emb(embeddings).unsqueeze(1).repeat(1, x.shape[1], 1)), -1)) 139 | else: 140 | # there can be some mismatch between seq length of cnn of crnn and the pretrained embeddings, we use an rnn 141 | # as an encoder and we use the last state 142 | last, _ = self.frame_embs_encoder(embeddings.transpose(1, 2)) 143 | embeddings = last[:, -1] 144 | x = self.cat_tf(torch.cat((x, self.shrink_emb(embeddings).unsqueeze(1).repeat(1, x.shape[1], 1)), -1)) 145 | 146 | x = self.rnn(x) 147 | x = self.dropout(x) 148 | 149 | # projector: conv + relu + norm 150 | if proj: 151 | feat = self.cnn_proj_fc(x) 152 | feat = 
self.cnn_relu(feat) 153 | embed_dim = feat.size(-1) 154 | feat = feat.reshape(-1, embed_dim) 155 | feat = F.normalize(feat) 156 | 157 | strong = self.dense(x) # [bs, frames, nclass] 158 | strong = self.sigmoid(strong) 159 | 160 | 161 | if self.attention: 162 | sof = self.dense_softmax(x) # [bs, frames, nclass] 163 | if not pad_mask is None: 164 | sof = sof.masked_fill(pad_mask.transpose(1, 2), -1e30) # mask attention 165 | sof = self.softmax(sof) 166 | sof = torch.clamp(sof, min=1e-7, max=1) 167 | weak = (strong * sof).sum(1) / sof.sum(1) # [bs, nclass] 168 | else: 169 | weak = strong.mean(1) 170 | 171 | if proj: 172 | return strong.transpose(1, 2), weak, feat 173 | else: 174 | return strong.transpose(1, 2), weak 175 | 176 | def train(self, mode=True): 177 | """ 178 | Override the default train() to freeze the BN parameters 179 | """ 180 | super(CRNN, self).train(mode) 181 | if self.freeze_bn: 182 | print("Freezing Mean/Var of BatchNorm2D.") 183 | if self.freeze_bn: 184 | print("Freezing Weight/Bias of BatchNorm2D.") 185 | if self.freeze_bn: 186 | for m in self.modules(): 187 | if isinstance(m, nn.BatchNorm2d): 188 | m.eval() 189 | if self.freeze_bn: 190 | m.weight.requires_grad = False 191 | m.bias.requires_grad = False 192 | -------------------------------------------------------------------------------- /src/local/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import numpy as np 4 | import pandas as pd 5 | import scipy 6 | 7 | from desed_task.evaluation.evaluation_measures import compute_sed_eval_metrics 8 | import json 9 | 10 | import soundfile 11 | import glob 12 | 13 | 14 | def batched_decode_preds( 15 | strong_preds, filenames, encoder, thresholds=[0.5], median_filter=[3,28,7, 4,7,22,48,19,10, 50], pad_indx=None, 16 | ): 17 | """ Decode a batch of predictions to dataframes. Each threshold gives a different dataframe and stored in a 18 | dictionary 19 | 20 | Args: 21 | strong_preds: torch.Tensor, batch of strong predictions. 22 | filenames: list, the list of filenames of the current batch. 23 | encoder: ManyHotEncoder object, object used to decode predictions. 24 | thresholds: list, the list of thresholds to be used for predictions. 25 | median_filter: int, the number of frames for which to apply median window (smoothing). 26 | pad_indx: list, the list of indexes which have been used for padding. 27 | 28 | Returns: 29 | dict of predictions, each keys is a threshold and the value is the DataFrame of predictions. 
30 | """ 31 | # Init a dataframe per threshold 32 | prediction_dfs = {} 33 | for threshold in thresholds: 34 | prediction_dfs[threshold] = pd.DataFrame() 35 | 36 | for j in range(strong_preds.shape[0]): # over batches 37 | for c_th in thresholds: 38 | c_preds = strong_preds[j] 39 | if pad_indx is not None: 40 | true_len = int(c_preds.shape[-1] * pad_indx[j].item()) 41 | c_preds = c_preds[:true_len] 42 | pred = c_preds.transpose(0, 1).detach().cpu().numpy() 43 | pred = pred > c_th 44 | pred_lst = [] 45 | for lb in range(pred.shape[1]): 46 | cls_pred = pred[:, lb] 47 | cls_pred = scipy.ndimage.filters.median_filter(cls_pred, median_filter[lb]) 48 | pred_lst.append(cls_pred.reshape(-1, 1)) 49 | pred = np.concatenate(pred_lst, axis=1) 50 | pred = encoder.decode_strong(pred) 51 | pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"]) 52 | pred["filename"] = Path(filenames[j]).stem + ".wav" 53 | prediction_dfs[c_th] = prediction_dfs[c_th].append(pred, ignore_index=True) 54 | 55 | return prediction_dfs 56 | 57 | 58 | def convert_to_event_based(weak_dataframe): 59 | """ Convert a weakly labeled DataFrame ('filename', 'event_labels') to a DataFrame strongly labeled 60 | ('filename', 'onset', 'offset', 'event_label'). 61 | 62 | Args: 63 | weak_dataframe: pd.DataFrame, the dataframe to be converted. 64 | 65 | Returns: 66 | pd.DataFrame, the dataframe strongly labeled. 67 | """ 68 | 69 | new = [] 70 | for i, r in weak_dataframe.iterrows(): 71 | 72 | events = r["event_labels"].split(",") 73 | for e in events: 74 | new.append( 75 | {"filename": r["filename"], "event_label": e, "onset": 0, "offset": 1} 76 | ) 77 | return pd.DataFrame(new) 78 | 79 | 80 | def log_sedeval_metrics(predictions, ground_truth, save_dir=None): 81 | """ Return the set of metrics from sed_eval 82 | Args: 83 | predictions: pd.DataFrame, the dataframe of predictions. 84 | ground_truth: pd.DataFrame, the dataframe of groundtruth. 85 | save_dir: str, path to the folder where to save the event and segment based metrics outputs. 86 | 87 | Returns: 88 | tuple, event-based macro-F1 and micro-F1, segment-based macro-F1 and micro-F1 89 | """ 90 | if predictions.empty: 91 | return 0.0, 0.0, 0.0, 0.0 92 | 93 | gt = pd.read_csv(ground_truth, sep="\t") 94 | 95 | event_res, segment_res = compute_sed_eval_metrics(predictions, gt) 96 | 97 | if save_dir is not None: 98 | os.makedirs(save_dir, exist_ok=True) 99 | with open(os.path.join(save_dir, "event_f1.txt"), "w") as f: 100 | f.write(str(event_res)) 101 | 102 | with open(os.path.join(save_dir, "segment_f1.txt"), "w") as f: 103 | f.write(str(segment_res)) 104 | 105 | return ( 106 | event_res.results()["class_wise_average"]["f_measure"]["f_measure"], 107 | event_res.results()["overall"]["f_measure"]["f_measure"], 108 | segment_res.results()["class_wise_average"]["f_measure"]["f_measure"], 109 | segment_res.results()["overall"]["f_measure"]["f_measure"], 110 | ) # return also segment measures 111 | 112 | 113 | def parse_jams(jams_list, encoder, out_json): 114 | 115 | if len(jams_list) == 0: 116 | raise IndexError("jams list is empty ! 
Wrong path ?") 117 | 118 | backgrounds = [] 119 | sources = [] 120 | for jamfile in jams_list: 121 | 122 | with open(jamfile, "r") as f: 123 | jdata = json.load(f) 124 | 125 | # check if we have annotations for each source in scaper 126 | assert len(jdata["annotations"][0]["data"]) == len( 127 | jdata["annotations"][-1]["sandbox"]["scaper"]["isolated_events_audio_path"] 128 | ) 129 | 130 | for indx, sound in enumerate(jdata["annotations"][0]["data"]): 131 | source_name = Path( 132 | jdata["annotations"][-1]["sandbox"]["scaper"][ 133 | "isolated_events_audio_path" 134 | ][indx] 135 | ).stem 136 | source_file = os.path.join( 137 | Path(jamfile).parent, 138 | Path(jamfile).stem + "_events", 139 | source_name + ".wav", 140 | ) 141 | 142 | if sound["value"]["role"] == "background": 143 | backgrounds.append(source_file) 144 | else: # it is an event 145 | if ( 146 | sound["value"]["label"] not in encoder.labels 147 | ): # correct different labels 148 | if sound["value"]["label"].startswith("Frying"): 149 | sound["value"]["label"] = "Frying" 150 | elif sound["value"]["label"].startswith("Vacuum_cleaner"): 151 | sound["value"]["label"] = "Vacuum_cleaner" 152 | else: 153 | raise NotImplementedError 154 | 155 | sources.append( 156 | { 157 | "filename": source_file, 158 | "onset": sound["value"]["event_time"], 159 | "offset": sound["value"]["event_time"] 160 | + sound["value"]["event_duration"], 161 | "event_label": sound["value"]["label"], 162 | } 163 | ) 164 | 165 | os.makedirs(Path(out_json).parent, exist_ok=True) 166 | with open(out_json, "w") as f: 167 | json.dump({"backgrounds": backgrounds, "sources": sources}, f, indent=4) 168 | 169 | 170 | def generate_tsv_wav_durations(audio_dir, out_tsv): 171 | """ 172 | Generate a dataframe with filename and duration of the file 173 | 174 | Args: 175 | audio_dir: str, the path of the folder where audio files are (used by glob.glob) 176 | out_tsv: str, the path of the output tsv file 177 | 178 | Returns: 179 | pd.DataFrame: the dataframe containing filenames and durations 180 | """ 181 | meta_list = [] 182 | for file in glob.glob(os.path.join(audio_dir, "*.wav")): 183 | d = soundfile.info(file).duration 184 | meta_list.append([os.path.basename(file), d]) 185 | meta_df = pd.DataFrame(meta_list, columns=["filename", "duration"]) 186 | if out_tsv is not None: 187 | meta_df.to_csv(out_tsv, sep="\t", index=False, float_format="%.1f") 188 | 189 | return meta_df 190 | -------------------------------------------------------------------------------- /src/desed_task/data_augm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | def time_mask(features, labels=None, net_pooling=4, mask_ratios=(5, 20)): 6 | _, _, n_frame = labels.shape 7 | t_width = torch.randint(low=int(n_frame/mask_ratios[1]), high=int(n_frame/mask_ratios[0]), size=(1,)) # [low, high) 8 | t_low = torch.randint(low=0, high=n_frame-t_width[0], size=(1,)) 9 | features[:, :, t_low * net_pooling:(t_low+t_width)*net_pooling] = 0 10 | labels[:, :, t_low:t_low+t_width] = 0 11 | return features, labels 12 | 13 | def filt_aug(features, db_range=[-6, 6], n_band=[3, 6], min_bw=6, filter_type="linear"): 14 | if not isinstance(filter_type, str): 15 | if torch.rand(1).item() < filter_type: 16 | filter_type = "step" 17 | n_band = [2, 5] 18 | min_bw = 4 19 | else: 20 | filter_type = "linear" 21 | n_band = [3, 6] 22 | min_bw = 6 23 | 24 | batch_size, n_freq_bin, _ = features.shape 25 | n_freq_band = 
torch.randint(low=n_band[0], high=n_band[1], size=(1,)).item() # [low, high) 26 | if n_freq_band > 1: 27 | while n_freq_bin - n_freq_band * min_bw + 1 < 0: 28 | min_bw -= 1 29 | band_bndry_freqs = torch.sort(torch.randint(0, n_freq_bin - n_freq_band * min_bw + 1, 30 | (n_freq_band - 1,)))[0] + \ 31 | torch.arange(1, n_freq_band) * min_bw 32 | band_bndry_freqs = torch.cat((torch.tensor([0]), band_bndry_freqs, torch.tensor([n_freq_bin]))) 33 | 34 | if filter_type == "step": 35 | band_factors = torch.rand((batch_size, n_freq_band)).to(features) * (db_range[1] - db_range[0]) + db_range[0] 36 | band_factors = 10 ** (band_factors / 20) 37 | 38 | freq_filt = torch.ones((batch_size, n_freq_bin, 1)).to(features) 39 | for i in range(n_freq_band): 40 | freq_filt[:, band_bndry_freqs[i]:band_bndry_freqs[i + 1], :] = band_factors[:, i].unsqueeze(-1).unsqueeze(-1) 41 | 42 | elif filter_type == "linear": 43 | band_factors = torch.rand((batch_size, n_freq_band + 1)).to(features) * (db_range[1] - db_range[0]) + db_range[0] 44 | freq_filt = torch.ones((batch_size, n_freq_bin, 1)).to(features) 45 | for i in range(n_freq_band): 46 | for j in range(batch_size): 47 | freq_filt[j, band_bndry_freqs[i]:band_bndry_freqs[i+1], :] = \ 48 | torch.linspace(band_factors[j, i], band_factors[j, i+1], 49 | band_bndry_freqs[i+1] - band_bndry_freqs[i]).unsqueeze(-1) 50 | freq_filt = 10 ** (freq_filt / 20) 51 | return features * freq_filt 52 | 53 | else: 54 | return features 55 | 56 | def cut_mix(data, target, indx_synth=12, low_r=0.20, high_r=0.50): 57 | """ 58 | Args: 59 | data: torch.Tensor, input features of shape (batch, n_mels, n_feat_frames). 60 | target: torch.Tensor, frame-level targets of shape (batch, n_classes, n_frames). 61 | indx_synth: int, number of leading labeled (synthetic) clips in the batch; frames of the remaining clips are traced as unlabeled. 62 | low_r, high_r: floats, lower/upper bound of the ratio of frames that is cut and mixed. 63 | """ 64 | with torch.no_grad(): 65 | batch_size, feat_dims, feat_len = data.size() 66 | _, _, n_frame = target.size() 67 | net_pooling = feat_len // feat_dims 68 | 69 | mask_ratio = np.random.uniform(low_r, high_r, batch_size) 70 | 71 | # mix length 72 | lb_mask_len = (mask_ratio * n_frame).astype(int) # for target 73 | feat_mask_len = lb_mask_len * net_pooling # for data 74 | 75 | # start point 76 | lb_st = np.random.randint(0, n_frame - lb_mask_len, batch_size) 77 | feat_st = lb_st * net_pooling 78 | 79 | lb_mask = torch.ones((batch_size, n_frame)) 80 | feat_mask = torch.ones((batch_size, feat_len)) 81 | ulb_frame = torch.ones((batch_size, n_frame)) 82 | # the first indx_synth clips are labeled (synthetic), the remaining ones are traced as unlabeled 83 | ulb_frame[:indx_synth] = 0 84 | 85 | for i in range(batch_size): 86 | lb_mask[i, lb_st[i] : lb_st[i] + lb_mask_len[i]] = 0 87 | feat_mask[i, feat_st[i] : feat_st[i] + feat_mask_len[i]] = 0 88 | 89 | lb_mask = lb_mask.unsqueeze(-1).permute(0, 2, 1).cuda() 90 | feat_mask = feat_mask.unsqueeze(-1).permute(0, 2, 1).cuda() 91 | ulb_frame = ulb_frame.unsqueeze(-1).permute(0, 2, 1).cuda() 92 | 93 | # cutmix 94 | perm = torch.randperm(batch_size) 95 | mixed_data = feat_mask * data + (1 - feat_mask) * data[perm, :] 96 | mixed_target = lb_mask * target + (1 - lb_mask) * target[perm, :] 97 | mixed_ulb_frame = lb_mask * ulb_frame + (1 - lb_mask) * ulb_frame[perm, :] 98 | 99 | return mixed_data, mixed_target, mixed_ulb_frame 100 | 101 | 102 | def frame_shift(mels, labels, net_pooling=4): 103 | bsz, n_bands, frames = mels.shape 104 | shifted = [] 105 | new_labels = [] 106 | for bindx in range(bsz): 107 | shift = int(random.gauss(0, 90)) 108 | shifted.append(torch.roll(mels[bindx], shift, dims=-1)) 109 | shift = -abs(shift) // net_pooling if shift < 0 else shift // net_pooling 110 | 
new_labels.append(torch.roll(labels[bindx], shift, dims=-1)) 111 | return torch.stack(shifted), torch.stack(new_labels) 112 | 113 | 114 | def mixup(data, target=None, alpha=0.2, beta=0.2, mixup_label_type="soft"): 115 | """Mixup data augmentation by permuting the data 116 | 117 | Args: 118 | data: input tensor, must be a batch so data can be permuted and mixed. 119 | target: tensor of the target to be mixed, if None, do not return targets. 120 | alpha: float, the parameter to the np.random.beta distribution 121 | beta: float, the parameter to the np.random.beta distribution 122 | mixup_label_type: str, the type of mixup to be used choice between {'soft', 'hard'}. 123 | Returns: 124 | torch.Tensor of mixed data and labels if given 125 | """ 126 | with torch.no_grad(): 127 | batch_size = data.size(0) 128 | c = np.random.beta(alpha, beta) 129 | 130 | perm = torch.randperm(batch_size) 131 | 132 | mixed_data = c * data + (1 - c) * data[perm, :] 133 | if target is not None: 134 | if mixup_label_type == "soft": 135 | mixed_target = torch.clamp( 136 | c * target + (1 - c) * target[perm, :], min=0, max=1 137 | ) 138 | elif mixup_label_type == "hard": 139 | mixed_target = torch.clamp(target + target[perm, :], min=0, max=1) 140 | else: 141 | raise NotImplementedError( 142 | f"mixup_label_type: {mixup_label_type} not implemented. choice in " 143 | f"{'soft', 'hard'}" 144 | ) 145 | 146 | return mixed_data, mixed_target 147 | else: 148 | return mixed_data 149 | 150 | 151 | def add_noise(mels, snrs=(6, 30), dims=(1, 2)): 152 | """ Add white noise to mels spectrograms 153 | Args: 154 | mels: torch.tensor, mels spectrograms to apply the white noise to. 155 | snrs: int or tuple, the range of snrs to choose from if tuple (uniform) 156 | dims: tuple, the dimensions for which to compute the standard deviation (default to (1,2) because assume 157 | an input of a batch of mel spectrograms. 158 | Returns: 159 | torch.Tensor of mels with noise applied 160 | """ 161 | if isinstance(snrs, (list, tuple)): 162 | snr = (snrs[0] - snrs[1]) * torch.rand( 163 | (mels.shape[0],), device=mels.device 164 | ).reshape(-1, 1, 1) + snrs[1] 165 | else: 166 | snr = snrs 167 | 168 | snr = 10 ** (snr / 20) # linear domain 169 | sigma = torch.std(mels, dim=dims, keepdim=True) / snr 170 | mels = mels + torch.randn(mels.shape, device=mels.device) * sigma 171 | 172 | return mels 173 | -------------------------------------------------------------------------------- /src/generate_dcase_task4_2022.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import time 4 | import warnings 5 | from pprint import pformat 6 | from pathlib import Path 7 | 8 | 9 | import os 10 | import shutil 11 | 12 | import desed 13 | 14 | def create_folder(folder, exist_ok=True, delete_if_exists=False): 15 | """ Create folder (and parent folders) if not exists. 16 | 17 | Args: 18 | folder: str, path of folder(s) to create. 
19 | delete_if_exists: bool, True if you want to delete the folder when exists 20 | 21 | Returns: 22 | None 23 | """ 24 | if not folder == "": 25 | if delete_if_exists: 26 | if os.path.exists(folder): 27 | shutil.rmtree(folder) 28 | os.mkdir(folder) 29 | 30 | os.makedirs(folder, exist_ok=exist_ok) 31 | 32 | 33 | def _create_symlink(src, dest, **kwargs): 34 | if os.path.exists(dest): 35 | warnings.warn(f"Symlink already exists : {dest}, skipping.\n") 36 | else: 37 | os.makedirs(os.path.dirname(dest), exist_ok=True) 38 | os.symlink(os.path.abspath(src), dest, **kwargs) 39 | 40 | 41 | def create_synth_dcase(synth_path, destination_folder): 42 | """Create symbolic links for synethtic part of the dataset 43 | 44 | Args: 45 | synth_path (str): synthetic folder path 46 | destination_folder (str): destination folder path 47 | """ 48 | print("Creating symlinks for synthetic data") 49 | split_sets = ["train", "validation"] 50 | if os.path.exists(os.path.join(synth_path, "audio", "eval")): 51 | split_sets.append("eval") 52 | 53 | for split_set in split_sets: 54 | # AUDIO 55 | split_audio_folder = os.path.join(synth_path, "audio", split_set) 56 | audio_subfolders = [ 57 | d 58 | for d in os.listdir(split_audio_folder) 59 | if os.path.isdir(os.path.join(split_audio_folder, d)) 60 | ] 61 | # Manage the validation case which changed from 2020 62 | if split_set == "validation" and not len(audio_subfolders): 63 | split_audio_folder = os.path.join(synth_path, "audio") 64 | audio_subfolders = ["validation"] 65 | 66 | for subfolder in audio_subfolders: 67 | abs_src_folder = os.path.abspath( 68 | os.path.join(split_audio_folder, subfolder) 69 | ) 70 | dest_folder = os.path.join( 71 | destination_folder, "audio", split_set, subfolder 72 | ) 73 | _create_symlink(abs_src_folder, dest_folder) 74 | 75 | # META 76 | split_meta_folder = os.path.join(synth_path, "metadata", split_set, f"synthetic21_{split_set}") 77 | meta_files = glob.glob(os.path.join(split_meta_folder, "*.tsv")) 78 | for meta_file in meta_files: 79 | 80 | create_folder(destination_folder) 81 | dest_file = os.path.join( 82 | destination_folder, "metadata", split_set, f"synthetic21_{split_set}", os.path.basename(meta_file) 83 | ) 84 | _create_symlink(meta_file, dest_file) 85 | 86 | 87 | if __name__ == "__main__": 88 | t = time.time() 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument( 91 | "--basedir", 92 | type=str, 93 | default="../../data", 94 | help="The base data folder in which we'll create the different datasets." 
95 | "Useful when you don't have any dataset, provide this one and the output folder", 96 | ) 97 | parser.add_argument( 98 | "--out_dir", 99 | type=str, 100 | default=None, 101 | help="Output basefolder in which to put the created 2021 dataset (with real and soundscapes)", 102 | ) 103 | parser.add_argument( 104 | "--only_real", 105 | action="store_true", 106 | help="True if only the real part of the dataset need to be downloaded" 107 | ) 108 | 109 | parser.add_argument( 110 | "--only_synth", 111 | action="store_true", 112 | help="True if only the synthetic part of the dataset need to be downloaded" 113 | ) 114 | 115 | parser.add_argument( 116 | "--only_strong", 117 | action="store_true", 118 | help="True if only the strongly annotated part of the Audioset dataset need to be downloaded" 119 | ) 120 | 121 | args = parser.parse_args() 122 | pformat(vars(args)) 123 | 124 | # ######### 125 | # Paths 126 | # ######### 127 | bdir = args.basedir 128 | dcase_dataset_folder = args.out_dir 129 | only_real = args.only_real 130 | only_synth = args.only_synth 131 | only_strong = args.only_strong 132 | missing_files = None 133 | 134 | download_all = (only_real and only_synth and only_strong) or (not only_real and not only_synth and not only_strong) 135 | print(f"Download all: {download_all}") 136 | 137 | 138 | # Default paths if not defined (using basedir) 139 | if dcase_dataset_folder is None: 140 | dcase_dataset_folder = os.path.join(bdir, "dcase", "dataset") 141 | 142 | # ######### 143 | # Download the different datasets if they do not exist 144 | # ######### 145 | 146 | # download real dataset 147 | if only_real or download_all: 148 | print('Downloading audioset dataset') 149 | missing_files = desed.download_audioset_data(dcase_dataset_folder, n_jobs=3, chunk_size=10) 150 | 151 | # download strong-label Audioset dataset 152 | if only_strong or download_all: 153 | url_strong = ( 154 | "https://zenodo.org/record/6444477/files/audioset_strong.tsv?download=1" 155 | ) 156 | basedir_missing_files = "missing_files" 157 | desed.utils.create_folder(basedir_missing_files) 158 | 159 | strong_label_metadata_path = os.path.join(dcase_dataset_folder, "metadata", "train", "audioset_strong.tsv") 160 | sl_path = Path(strong_label_metadata_path) 161 | if not sl_path.is_file(): 162 | desed.utils.download_file_from_url(url_strong, strong_label_metadata_path) 163 | print(f"File saved in {strong_label_metadata_path}") 164 | 165 | print("Downloading strong-label Audioset dataset...") 166 | path_missing_files_audioset = os.path.join( 167 | basedir_missing_files, "missing_files_" + "strong_label_real" + ".tsv" 168 | ) 169 | desed.download.download_audioset_files_from_csv( 170 | strong_label_metadata_path, 171 | os.path.join(dcase_dataset_folder, "audio", "train", "strong_label_real"), 172 | missing_files_tsv=path_missing_files_audioset, 173 | ) 174 | 175 | else: 176 | print(f"The file {sl_path} already exists.") 177 | 178 | 179 | # download synthetic dataset 180 | if only_synth or download_all: 181 | print(f"Downloading synthetic part of the dataset") 182 | url_synth = ( 183 | "https://zenodo.org/record/6026841/files/dcase_synth.zip?download=1" 184 | ) 185 | synth_folder = str(os.path.basename(url_synth)).split('.')[0] 186 | desed.download.download_and_unpack_archive(url_synth, dcase_dataset_folder, archive_format="zip") 187 | synth_folder = os.path.join(bdir, "dcase", "dataset", synth_folder) 188 | create_synth_dcase(synth_folder, dcase_dataset_folder) 189 | 190 | print(f"Time of the program: {time.time() - t} s") 191 | 
print(f"The dcase dataset has been saved in the following path: {dcase_dataset_folder}") 192 | if missing_files is not None: 193 | warnings.warn( 194 | f"You have missing files.\n\n" 195 | f"Please try to redownload desed_real again: \n" 196 | f"import desed\n" 197 | f"desed.download_audioset_data('{dcase_dataset_folder}', n_jobs=3, chunk_size=10)\n\n" 198 | f"Please, send your missing_files_xx.tsv to the task organisers to get your missing files.\n" 199 | ) 200 | 201 | 202 | -------------------------------------------------------------------------------- /src/extract_embeddings.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import glob 4 | import os 5 | import argparse 6 | import torch 7 | from desed_task.dataio.datasets import read_audio 8 | import yaml 9 | import pandas as pd 10 | from desed_task.utils.download import download_from_url 11 | from tqdm import tqdm 12 | from pathlib import Path 13 | import torchaudio 14 | 15 | parser = argparse.ArgumentParser("Extract Embeddings with Audioset Pretrained Models") 16 | 17 | 18 | class WavDataset(torch.utils.data.Dataset): 19 | def __init__( 20 | self, 21 | folder, 22 | pad_to=10, 23 | fs=16000, 24 | feats_pipeline=None 25 | ): 26 | self.fs = fs 27 | self.pad_to = pad_to * fs if pad_to is not None else None 28 | self.examples = glob.glob(os.path.join(folder, "*.wav")) 29 | self.feats_pipeline = feats_pipeline 30 | 31 | def __len__(self): 32 | return len(self.examples) 33 | 34 | def __getitem__(self, item): 35 | c_ex = self.examples[item] 36 | 37 | mixture, _, _, padded_indx = read_audio( 38 | c_ex, False, False, self.pad_to 39 | ) 40 | 41 | if self.feats_pipeline is not None: 42 | mixture = self.feats_pipeline(mixture) 43 | return mixture, Path(c_ex).stem 44 | 45 | 46 | def extract(batch_size, folder, dset_name, torch_dset, embedding_model, use_gpu=True): 47 | 48 | Path(folder).mkdir(parents=True, exist_ok=True) 49 | f = h5py.File(os.path.join(folder, '{}.hdf5'.format(dset_name)), "w-") 50 | dt = h5py.vlen_dtype(np.dtype('float32')) 51 | global_embeddings = f.create_dataset('global_embeddings', (len(torch_dset), ), dtype=dt) 52 | if type(embedding_model).__name__ == "Cnn14_16k": 53 | emb_size = int(256*8) 54 | else: 55 | emb_size = 768 56 | frame_embeddings = f.create_dataset('frame_embeddings', (len(torch_dset), emb_size, ), dtype=dt) 57 | 58 | dloader = torch.utils.data.DataLoader(torch_dset, 59 | batch_size=batch_size, 60 | drop_last=False) 61 | global_indx = 0 62 | for i, batch in enumerate(tqdm(dloader)): 63 | feats, filenames = batch 64 | if use_gpu: 65 | feats = feats.cuda() 66 | 67 | with torch.inference_mode(): 68 | emb = embedding_model(feats) 69 | c_glob_emb = emb["global"] 70 | c_frame_emb = emb["frame"] 71 | # enumerate, convert to numpy and write to h5py 72 | bsz = feats.shape[0] 73 | for b_indx in range(bsz): 74 | global_embeddings[global_indx] = c_glob_emb[b_indx].detach().cpu().numpy() 75 | global_embeddings.attrs[filenames[b_indx]] = global_indx 76 | frame_embeddings[global_indx] = c_frame_emb[b_indx].detach().cpu().numpy() 77 | frame_embeddings.attrs[filenames[b_indx]] = global_indx 78 | global_indx += 1 79 | 80 | 81 | if __name__ == "__main__": 82 | parser.add_argument( 83 | "--output_dir", default="./embeddings", 84 | help="Output directory") 85 | parser.add_argument( 86 | "--conf_file", 87 | default="./confs/default.yaml", 88 | help="The configuration file with all the experiment parameters.") 89 | parser.add_argument("--pretrained_model", 
default="panns", help="The pretrained model to use," 90 | "choose between panns and ast") 91 | parser.add_argument( 92 | "--use_gpu", 93 | default="1", 94 | help="0 does not use GPU, 1 use GPU") 95 | parser.add_argument( 96 | "--batch_size", 97 | default="8", 98 | help="Batch size for model inference, used to speed up the embedding extraction.") 99 | 100 | args = parser.parse_args() 101 | assert args.pretrained_model in ["panns", "ast"], "pretrained model must be either panns or ast" 102 | 103 | with open(args.conf_file, "r") as f: 104 | config = yaml.safe_load(f) 105 | 106 | output_dir = os.path.join(args.output_dir, args.pretrained_model) 107 | # loading model 108 | if args.pretrained_model == "ast": 109 | # need feature extraction with torchaudio compliance feats 110 | class ASTFeatsExtraction: 111 | # need feature extraction in dataloader because kaldi compliant torchaudio fbank are used (no gpu support) 112 | def __init__(self, audioset_mean=-4.2677393, audioset_std=4.5689974, 113 | target_length=1024): 114 | super(ASTFeatsExtraction, self).__init__() 115 | self.audioset_mean = audioset_mean 116 | self.audioset_std = audioset_std 117 | self.target_length = target_length 118 | def __call__(self, waveform): 119 | waveform = waveform - torch.mean(waveform, -1) 120 | 121 | fbank = torchaudio.compliance.kaldi.fbank(waveform.unsqueeze(0), htk_compat=True, sample_frequency=16000, use_energy=False, 122 | window_type='hanning', num_mel_bins=128, 123 | dither=0.0, frame_shift=10) 124 | fbank = torch.nn.functional.pad(fbank, (0, 0, 0, self.target_length-fbank.shape[0]), mode="constant") 125 | 126 | fbank = (fbank - self.audioset_mean) / (self.audioset_std * 2) 127 | return fbank 128 | 129 | 130 | feature_extraction = ASTFeatsExtraction() 131 | from local.ast.ast_models import ASTModel 132 | pretrained = ASTModel(label_dim=527, 133 | fstride=10, tstride=10, 134 | input_fdim=128, input_tdim=1024, 135 | imagenet_pretrain=True, audioset_pretrain=True, 136 | model_size='base384') 137 | 138 | 139 | 140 | elif args.pretrained_model == "panns": 141 | feature_extraction = None # integrated in the model 142 | download_from_url("https://zenodo.org/record/3987831/files/Cnn14_16k_mAP%3D0.438.pth?download=1", "./pretrained_models/Cnn14_16k_mAP%3D0.438.pth") 143 | # use pannss as additional feature 144 | from local.panns.models import Cnn14_16k 145 | pretrained = Cnn14_16k(freeze_bn=True, 146 | use_specaugm=True) 147 | 148 | pretrained.load_state_dict(torch.load("./pretrained_models/Cnn14_16k_mAP%3D0.438.pth")["model"], strict=False) 149 | else: 150 | raise NotImplementedError 151 | 152 | use_gpu = int(args.use_gpu) 153 | if use_gpu: 154 | pretrained = pretrained.cuda() 155 | 156 | pretrained.eval() 157 | synth_df = pd.read_csv(config["data"]["synth_tsv"], sep="\t") 158 | synth_set = WavDataset( 159 | config["data"]["synth_folder"], 160 | feats_pipeline=feature_extraction) 161 | 162 | synth_set[0] 163 | 164 | weak_df = pd.read_csv(config["data"]["weak_tsv"], sep="\t") 165 | train_weak_df = weak_df.sample( 166 | frac=config["training"]["weak_split"], 167 | random_state=config["training"]["seed"]) 168 | 169 | valid_weak_df = weak_df.drop(train_weak_df.index).reset_index(drop=True) 170 | train_weak_df = train_weak_df.reset_index(drop=True) 171 | weak_set = WavDataset( 172 | config["data"]["weak_folder"], 173 | feats_pipeline=feature_extraction) 174 | 175 | unlabeled_set = WavDataset( 176 | config["data"]["unlabeled_folder"], 177 | feats_pipeline=feature_extraction) 178 | 179 | synth_df_val = 
pd.read_csv(config["data"]["synth_val_tsv"], 180 | sep="\t") 181 | synth_val = WavDataset( 182 | config["data"]["synth_val_folder"], 183 | feats_pipeline=feature_extraction 184 | ) 185 | 186 | weak_val = WavDataset( 187 | config["data"]["weak_folder"], 188 | feats_pipeline=feature_extraction 189 | ) 190 | 191 | devtest_dataset = WavDataset( 192 | config["data"]["test_folder"], feats_pipeline=feature_extraction) 193 | """ 194 | for k, elem in {"synth_train": synth_set, "weak_train": weak_set, 195 | "unlabeled_train" : unlabeled_set, 196 | "synth_val" : synth_val, 197 | "weak_val" : weak_val, 198 | "devtest": devtest_dataset}.items(): 199 | """ 200 | for k, elem in {"devtest": devtest_dataset}.items(): 201 | extract(int(args.batch_size), output_dir, k, elem, pretrained, use_gpu) -------------------------------------------------------------------------------- /src/desed_task/utils/encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from dcase_util.data import DecisionEncoder 4 | 5 | 6 | class ManyHotEncoder: 7 | """" 8 | Adapted after DecisionEncoder.find_contiguous_regions method in 9 | https://github.com/DCASE-REPO/dcase_util/blob/master/dcase_util/data/decisions.py 10 | 11 | Encode labels into numpy arrays where 1 correspond to presence of the class and 0 absence. 12 | Multiple 1 can appear on the same line, it is for multi label problem. 13 | Args: 14 | labels: list, the classes which will be encoded 15 | n_frames: int, (Default value = None) only useful for strong labels. The number of frames of a segment. 16 | Attributes: 17 | labels: list, the classes which will be encoded 18 | n_frames: int, only useful for strong labels. The number of frames of a segment. 
19 | """ 20 | 21 | def __init__( 22 | self, labels, audio_len, frame_len, frame_hop, net_pooling=1, fs=16000 23 | ): 24 | if type(labels) in [np.ndarray, np.array]: 25 | labels = labels.tolist() 26 | self.labels = labels 27 | self.audio_len = audio_len 28 | self.frame_len = frame_len 29 | self.frame_hop = frame_hop 30 | self.fs = fs 31 | self.net_pooling = net_pooling 32 | n_frames = self.audio_len * self.fs 33 | # self.n_frames = int( 34 | # int(((n_frames - self.frame_len) / self.frame_hop)) / self.net_pooling 35 | # ) 36 | self.n_frames = int(int((n_frames / self.frame_hop)) / self.net_pooling) 37 | 38 | def encode_weak(self, labels): 39 | """ Encode a list of weak labels into a numpy array 40 | 41 | Args: 42 | labels: list, list of labels to encode (to a vector of 0 and 1) 43 | 44 | Returns: 45 | numpy.array 46 | A vector containing 1 for each label, and 0 everywhere else 47 | """ 48 | # useful for tensor empty labels 49 | if type(labels) is str: 50 | if labels == "empty": 51 | y = np.zeros(len(self.labels)) - 1 52 | return y 53 | else: 54 | labels = labels.split(",") 55 | if type(labels) is pd.DataFrame: 56 | if labels.empty: 57 | labels = [] 58 | elif "event_label" in labels.columns: 59 | labels = labels["event_label"] 60 | y = np.zeros(len(self.labels)) 61 | for label in labels: 62 | if not pd.isna(label): 63 | i = self.labels.index(label) 64 | y[i] = 1 65 | return y 66 | 67 | def _time_to_frame(self, time): 68 | samples = time * self.fs 69 | frame = (samples) / self.frame_hop 70 | return np.clip(frame / self.net_pooling, a_min=0, a_max=self.n_frames) 71 | 72 | def _frame_to_time(self, frame): 73 | frame = frame * self.net_pooling / (self.fs / self.frame_hop) 74 | return np.clip(frame, a_min=0, a_max=self.audio_len) 75 | 76 | def encode_strong_df(self, label_df): 77 | """Encode a list (or pandas Dataframe or Serie) of strong labels, they correspond to a given filename 78 | 79 | Args: 80 | label_df: pandas DataFrame or Series, contains filename, onset (in frames) and offset (in frames) 81 | If only filename (no onset offset) is specified, it will return the event on all the frames 82 | onset and offset should be in frames 83 | Returns: 84 | numpy.array 85 | Encoded labels, 1 where the label is present, 0 otherwise 86 | """ 87 | 88 | assert any( 89 | [x is not None for x in [self.audio_len, self.frame_len, self.frame_hop]] 90 | ) 91 | 92 | samples_len = self.n_frames 93 | if type(label_df) is str: 94 | if label_df == "empty": 95 | y = np.zeros((samples_len, len(self.labels))) - 1 96 | return y 97 | y = np.zeros((samples_len, len(self.labels))) 98 | if type(label_df) is pd.DataFrame: 99 | if {"onset", "offset", "event_label"}.issubset(label_df.columns): 100 | for _, row in label_df.iterrows(): 101 | if not pd.isna(row["event_label"]): 102 | i = self.labels.index(row["event_label"]) 103 | onset = int(self._time_to_frame(row["onset"])) 104 | offset = int(np.ceil(self._time_to_frame(row["offset"]))) 105 | y[ 106 | onset:offset, i 107 | ] = 1 # means offset not included (hypothesis of overlapping frames, so ok) 108 | 109 | elif type(label_df) in [ 110 | pd.Series, 111 | list, 112 | np.ndarray, 113 | ]: # list of list or list of strings 114 | if type(label_df) is pd.Series: 115 | if {"onset", "offset", "event_label"}.issubset( 116 | label_df.index 117 | ): # means only one value 118 | if not pd.isna(label_df["event_label"]): 119 | i = self.labels.index(label_df["event_label"]) 120 | onset = int(self._time_to_frame(label_df["onset"])) 121 | offset = 
int(np.ceil(self._time_to_frame(label_df["offset"]))) 122 | y[onset:offset, i] = 1 123 | return y 124 | 125 | for event_label in label_df: 126 | # List of string, so weak labels to be encoded in strong 127 | if type(event_label) is str: 128 | if event_label != "": 129 | i = self.labels.index(event_label) 130 | y[:, i] = 1 131 | 132 | # List of list, with [label, onset, offset] 133 | elif len(event_label) == 3: 134 | if event_label[0] != "": 135 | i = self.labels.index(event_label[0]) 136 | onset = int(self._time_to_frame(event_label[1])) 137 | offset = int(np.ceil(self._time_to_frame(event_label[2]))) 138 | y[onset:offset, i] = 1 139 | 140 | else: 141 | raise NotImplementedError( 142 | "cannot encode strong, type mismatch: {}".format( 143 | type(event_label) 144 | ) 145 | ) 146 | 147 | else: 148 | raise NotImplementedError( 149 | "To encode_strong, type is pandas.Dataframe with onset, offset and event_label" 150 | "columns, or it is a list or pandas Series of event labels, " 151 | "type given: {}".format(type(label_df)) 152 | ) 153 | return y 154 | 155 | def decode_weak(self, labels): 156 | """ Decode the encoded weak labels 157 | Args: 158 | labels: numpy.array, the encoded labels to be decoded 159 | 160 | Returns: 161 | list 162 | Decoded labels, list of string 163 | 164 | """ 165 | result_labels = [] 166 | for i, value in enumerate(labels): 167 | if value == 1: 168 | result_labels.append(self.labels[i]) 169 | return result_labels 170 | 171 | def decode_strong(self, labels): 172 | """ Decode the encoded strong labels 173 | Args: 174 | labels: numpy.array, the encoded labels to be decoded 175 | Returns: 176 | list 177 | Decoded labels, list of list: [[label, onset offset], ...] 178 | 179 | """ 180 | result_labels = [] 181 | for i, label_column in enumerate(labels.T): 182 | change_indices = DecisionEncoder().find_contiguous_regions(label_column) 183 | 184 | # append [label, onset, offset] in the result list 185 | for row in change_indices: 186 | result_labels.append( 187 | [ 188 | self.labels[i], 189 | self._frame_to_time(row[0]), 190 | self._frame_to_time(row[1]), 191 | ] 192 | ) 193 | return result_labels 194 | 195 | def state_dict(self): 196 | return { 197 | "labels": self.labels, 198 | "audio_len": self.audio_len, 199 | "frame_len": self.frame_len, 200 | "frame_hop": self.frame_hop, 201 | "net_pooling": self.net_pooling, 202 | "fs": self.fs, 203 | } 204 | 205 | @classmethod 206 | def load_state_dict(cls, state_dict): 207 | labels = state_dict["labels"] 208 | audio_len = state_dict["audio_len"] 209 | frame_len = state_dict["frame_len"] 210 | frame_hop = state_dict["frame_hop"] 211 | net_pooling = state_dict["net_pooling"] 212 | fs = state_dict["fs"] 213 | return cls(labels, audio_len, frame_len, frame_hop, net_pooling, fs) 214 | -------------------------------------------------------------------------------- /src/desed_task/utils/.ipynb_checkpoints/encoder-checkpoint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from dcase_util.data import DecisionEncoder 4 | 5 | 6 | class ManyHotEncoder: 7 | """" 8 | Adapted after DecisionEncoder.find_contiguous_regions method in 9 | https://github.com/DCASE-REPO/dcase_util/blob/master/dcase_util/data/decisions.py 10 | 11 | Encode labels into numpy arrays where 1 correspond to presence of the class and 0 absence. 12 | Multiple 1 can appear on the same line, it is for multi label problem. 
13 | Args: 14 | labels: list, the classes which will be encoded 15 | n_frames: int, (Default value = None) only useful for strong labels. The number of frames of a segment. 16 | Attributes: 17 | labels: list, the classes which will be encoded 18 | n_frames: int, only useful for strong labels. The number of frames of a segment. 19 | """ 20 | 21 | def __init__( 22 | self, labels, audio_len, frame_len, frame_hop, net_pooling=1, fs=16000 23 | ): 24 | if type(labels) in [np.ndarray, np.array]: 25 | labels = labels.tolist() 26 | self.labels = labels 27 | self.audio_len = audio_len 28 | self.frame_len = frame_len 29 | self.frame_hop = frame_hop 30 | self.fs = fs 31 | self.net_pooling = net_pooling 32 | n_frames = self.audio_len * self.fs 33 | # self.n_frames = int( 34 | # int(((n_frames - self.frame_len) / self.frame_hop)) / self.net_pooling 35 | # ) 36 | self.n_frames = int(int((n_frames / self.frame_hop)) / self.net_pooling) 37 | 38 | def encode_weak(self, labels): 39 | """ Encode a list of weak labels into a numpy array 40 | 41 | Args: 42 | labels: list, list of labels to encode (to a vector of 0 and 1) 43 | 44 | Returns: 45 | numpy.array 46 | A vector containing 1 for each label, and 0 everywhere else 47 | """ 48 | # useful for tensor empty labels 49 | if type(labels) is str: 50 | if labels == "empty": 51 | y = np.zeros(len(self.labels)) - 1 52 | return y 53 | else: 54 | labels = labels.split(",") 55 | if type(labels) is pd.DataFrame: 56 | if labels.empty: 57 | labels = [] 58 | elif "event_label" in labels.columns: 59 | labels = labels["event_label"] 60 | y = np.zeros(len(self.labels)) 61 | for label in labels: 62 | if not pd.isna(label): 63 | i = self.labels.index(label) 64 | y[i] = 1 65 | return y 66 | 67 | def _time_to_frame(self, time): 68 | samples = time * self.fs 69 | frame = (samples) / self.frame_hop 70 | return np.clip(frame / self.net_pooling, a_min=0, a_max=self.n_frames) 71 | 72 | def _frame_to_time(self, frame): 73 | frame = frame * self.net_pooling / (self.fs / self.frame_hop) 74 | return np.clip(frame, a_min=0, a_max=self.audio_len) 75 | 76 | def encode_strong_df(self, label_df): 77 | """Encode a list (or pandas Dataframe or Serie) of strong labels, they correspond to a given filename 78 | 79 | Args: 80 | label_df: pandas DataFrame or Series, contains filename, onset (in frames) and offset (in frames) 81 | If only filename (no onset offset) is specified, it will return the event on all the frames 82 | onset and offset should be in frames 83 | Returns: 84 | numpy.array 85 | Encoded labels, 1 where the label is present, 0 otherwise 86 | """ 87 | 88 | assert any( 89 | [x is not None for x in [self.audio_len, self.frame_len, self.frame_hop]] 90 | ) 91 | 92 | samples_len = self.n_frames 93 | if type(label_df) is str: 94 | if label_df == "empty": 95 | y = np.zeros((samples_len, len(self.labels))) - 1 96 | return y 97 | y = np.zeros((samples_len, len(self.labels))) 98 | if type(label_df) is pd.DataFrame: 99 | if {"onset", "offset", "event_label"}.issubset(label_df.columns): 100 | for _, row in label_df.iterrows(): 101 | if not pd.isna(row["event_label"]): 102 | i = self.labels.index(row["event_label"]) 103 | onset = int(self._time_to_frame(row["onset"])) 104 | offset = int(np.ceil(self._time_to_frame(row["offset"]))) 105 | y[ 106 | onset:offset, i 107 | ] = 1 # means offset not included (hypothesis of overlapping frames, so ok) 108 | 109 | elif type(label_df) in [ 110 | pd.Series, 111 | list, 112 | np.ndarray, 113 | ]: # list of list or list of strings 114 | if type(label_df) is 
pd.Series: 115 | if {"onset", "offset", "event_label"}.issubset( 116 | label_df.index 117 | ): # means only one value 118 | if not pd.isna(label_df["event_label"]): 119 | i = self.labels.index(label_df["event_label"]) 120 | onset = int(self._time_to_frame(label_df["onset"])) 121 | offset = int(np.ceil(self._time_to_frame(label_df["offset"]))) 122 | y[onset:offset, i] = 1 123 | return y 124 | 125 | for event_label in label_df: 126 | # List of string, so weak labels to be encoded in strong 127 | if type(event_label) is str: 128 | if event_label != "": 129 | i = self.labels.index(event_label) 130 | y[:, i] = 1 131 | 132 | # List of list, with [label, onset, offset] 133 | elif len(event_label) == 3: 134 | if event_label[0] != "": 135 | i = self.labels.index(event_label[0]) 136 | onset = int(self._time_to_frame(event_label[1])) 137 | offset = int(np.ceil(self._time_to_frame(event_label[2]))) 138 | y[onset:offset, i] = 1 139 | 140 | else: 141 | raise NotImplementedError( 142 | "cannot encode strong, type mismatch: {}".format( 143 | type(event_label) 144 | ) 145 | ) 146 | 147 | else: 148 | raise NotImplementedError( 149 | "To encode_strong, type is pandas.Dataframe with onset, offset and event_label" 150 | "columns, or it is a list or pandas Series of event labels, " 151 | "type given: {}".format(type(label_df)) 152 | ) 153 | return y 154 | 155 | def decode_weak(self, labels): 156 | """ Decode the encoded weak labels 157 | Args: 158 | labels: numpy.array, the encoded labels to be decoded 159 | 160 | Returns: 161 | list 162 | Decoded labels, list of string 163 | 164 | """ 165 | result_labels = [] 166 | for i, value in enumerate(labels): 167 | if value == 1: 168 | result_labels.append(self.labels[i]) 169 | return result_labels 170 | 171 | def decode_strong(self, labels): 172 | """ Decode the encoded strong labels 173 | Args: 174 | labels: numpy.array, the encoded labels to be decoded 175 | Returns: 176 | list 177 | Decoded labels, list of list: [[label, onset offset], ...] 
178 | 179 | """ 180 | result_labels = [] 181 | for i, label_column in enumerate(labels.T): 182 | change_indices = DecisionEncoder().find_contiguous_regions(label_column) 183 | 184 | # append [label, onset, offset] in the result list 185 | for row in change_indices: 186 | result_labels.append( 187 | [ 188 | self.labels[i], 189 | self._frame_to_time(row[0]), 190 | self._frame_to_time(row[1]), 191 | ] 192 | ) 193 | return result_labels 194 | 195 | def state_dict(self): 196 | return { 197 | "labels": self.labels, 198 | "audio_len": self.audio_len, 199 | "frame_len": self.frame_len, 200 | "frame_hop": self.frame_hop, 201 | "net_pooling": self.net_pooling, 202 | "fs": self.fs, 203 | } 204 | 205 | @classmethod 206 | def load_state_dict(cls, state_dict): 207 | labels = state_dict["labels"] 208 | audio_len = state_dict["audio_len"] 209 | frame_len = state_dict["frame_len"] 210 | frame_hop = state_dict["frame_hop"] 211 | net_pooling = state_dict["net_pooling"] 212 | fs = state_dict["fs"] 213 | return cls(labels, audio_len, frame_len, frame_hop, net_pooling, fs) 214 | -------------------------------------------------------------------------------- /src/local/panns/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def move_data_to_device(x, device): 8 | if 'float' in str(x.dtype): 9 | x = torch.Tensor(x) 10 | elif 'int' in str(x.dtype): 11 | x = torch.LongTensor(x) 12 | else: 13 | return x 14 | 15 | return x.to(device) 16 | 17 | 18 | def do_mixup(x, mixup_lambda): 19 | """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes 20 | (1, 3, 5, ...). 21 | 22 | Args: 23 | x: (batch_size * 2, ...) 24 | mixup_lambda: (batch_size * 2,) 25 | 26 | Returns: 27 | out: (batch_size, ...) 28 | """ 29 | out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \ 30 | x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1) 31 | return out 32 | 33 | 34 | def append_to_dict(dict, key, value): 35 | if key in dict.keys(): 36 | dict[key].append(value) 37 | else: 38 | dict[key] = [value] 39 | 40 | 41 | def forward(model, generator, return_input=False, 42 | return_target=False): 43 | """Forward data to a model. 
44 | 45 | Args: 46 | model: object 47 | generator: object 48 | return_input: bool 49 | return_target: bool 50 | 51 | Returns: 52 | audio_name: (audios_num,) 53 | clipwise_output: (audios_num, classes_num) 54 | (ifexist) segmentwise_output: (audios_num, segments_num, classes_num) 55 | (ifexist) framewise_output: (audios_num, frames_num, classes_num) 56 | (optional) return_input: (audios_num, segment_samples) 57 | (optional) return_target: (audios_num, classes_num) 58 | """ 59 | output_dict = {} 60 | device = next(model.parameters()).device 61 | time1 = time.time() 62 | 63 | # Forward data to a model in mini-batches 64 | for n, batch_data_dict in enumerate(generator): 65 | print(n) 66 | batch_waveform = move_data_to_device(batch_data_dict['waveform'], device) 67 | 68 | with torch.no_grad(): 69 | model.eval() 70 | batch_output = model(batch_waveform) 71 | 72 | append_to_dict(output_dict, 'audio_name', batch_data_dict['audio_name']) 73 | 74 | append_to_dict(output_dict, 'clipwise_output', 75 | batch_output['clipwise_output'].data.cpu().numpy()) 76 | 77 | if 'segmentwise_output' in batch_output.keys(): 78 | append_to_dict(output_dict, 'segmentwise_output', 79 | batch_output['segmentwise_output'].data.cpu().numpy()) 80 | 81 | if 'framewise_output' in batch_output.keys(): 82 | append_to_dict(output_dict, 'framewise_output', 83 | batch_output['framewise_output'].data.cpu().numpy()) 84 | 85 | if return_input: 86 | append_to_dict(output_dict, 'waveform', batch_data_dict['waveform']) 87 | 88 | if return_target: 89 | if 'target' in batch_data_dict.keys(): 90 | append_to_dict(output_dict, 'target', batch_data_dict['target']) 91 | 92 | if n % 10 == 0: 93 | print(' --- Inference time: {:.3f} s / 10 iterations ---'.format( 94 | time.time() - time1)) 95 | time1 = time.time() 96 | 97 | for key in output_dict.keys(): 98 | output_dict[key] = np.concatenate(output_dict[key], axis=0) 99 | 100 | return output_dict 101 | 102 | 103 | def interpolate(x, ratio): 104 | """Interpolate data in time domain. This is used to compensate the 105 | resolution reduction in downsampling of a CNN. 106 | 107 | Args: 108 | x: (batch_size, time_steps, classes_num) 109 | ratio: int, ratio to interpolate 110 | 111 | Returns: 112 | upsampled: (batch_size, time_steps * ratio, classes_num) 113 | """ 114 | (batch_size, time_steps, classes_num) = x.shape 115 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 116 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) 117 | return upsampled 118 | 119 | 120 | def pad_framewise_output(framewise_output, frames_num): 121 | """Pad framewise_output to the same length as input frames. The pad value 122 | is the same as the value of the last frame. 123 | 124 | Args: 125 | framewise_output: (batch_size, frames_num, classes_num) 126 | frames_num: int, number of frames to pad 127 | 128 | Outputs: 129 | output: (batch_size, frames_num, classes_num) 130 | """ 131 | pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1) 132 | """tensor for padding""" 133 | 134 | output = torch.cat((framewise_output, pad), dim=1) 135 | """(batch_size, frames_num, classes_num)""" 136 | 137 | return output 138 | 139 | 140 | def count_parameters(model): 141 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 142 | 143 | 144 | def count_flops(model, audio_length): 145 | """Count flops. Code modified from others' implementation. 
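    Illustrative call (assumes the model accepts a raw waveform batch, e.g. Cnn14_16k, and that audio_length is given in samples):

        flops = count_flops(model, audio_length=10 * 16000)  # ~10 s of 16 kHz audio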
146 | """ 147 | multiply_adds = True 148 | list_conv2d=[] 149 | def conv2d_hook(self, input, output): 150 | batch_size, input_channels, input_height, input_width = input[0].size() 151 | output_channels, output_height, output_width = output[0].size() 152 | 153 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) 154 | bias_ops = 1 if self.bias is not None else 0 155 | 156 | params = output_channels * (kernel_ops + bias_ops) 157 | flops = batch_size * params * output_height * output_width 158 | 159 | list_conv2d.append(flops) 160 | 161 | list_conv1d=[] 162 | def conv1d_hook(self, input, output): 163 | batch_size, input_channels, input_length = input[0].size() 164 | output_channels, output_length = output[0].size() 165 | 166 | kernel_ops = self.kernel_size[0] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) 167 | bias_ops = 1 if self.bias is not None else 0 168 | 169 | params = output_channels * (kernel_ops + bias_ops) 170 | flops = batch_size * params * output_length 171 | 172 | list_conv1d.append(flops) 173 | 174 | list_linear=[] 175 | def linear_hook(self, input, output): 176 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 177 | 178 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 179 | bias_ops = self.bias.nelement() 180 | 181 | flops = batch_size * (weight_ops + bias_ops) 182 | list_linear.append(flops) 183 | 184 | list_bn=[] 185 | def bn_hook(self, input, output): 186 | list_bn.append(input[0].nelement() * 2) 187 | 188 | list_relu=[] 189 | def relu_hook(self, input, output): 190 | list_relu.append(input[0].nelement() * 2) 191 | 192 | list_pooling2d=[] 193 | def pooling2d_hook(self, input, output): 194 | batch_size, input_channels, input_height, input_width = input[0].size() 195 | output_channels, output_height, output_width = output[0].size() 196 | 197 | kernel_ops = self.kernel_size * self.kernel_size 198 | bias_ops = 0 199 | params = output_channels * (kernel_ops + bias_ops) 200 | flops = batch_size * params * output_height * output_width 201 | 202 | list_pooling2d.append(flops) 203 | 204 | list_pooling1d=[] 205 | def pooling1d_hook(self, input, output): 206 | batch_size, input_channels, input_length = input[0].size() 207 | output_channels, output_length = output[0].size() 208 | 209 | kernel_ops = self.kernel_size[0] 210 | bias_ops = 0 211 | 212 | params = output_channels * (kernel_ops + bias_ops) 213 | flops = batch_size * params * output_length 214 | 215 | list_pooling2d.append(flops) 216 | 217 | def foo(net): 218 | childrens = list(net.children()) 219 | if not childrens: 220 | if isinstance(net, nn.Conv2d): 221 | net.register_forward_hook(conv2d_hook) 222 | elif isinstance(net, nn.Conv1d): 223 | net.register_forward_hook(conv1d_hook) 224 | elif isinstance(net, nn.Linear): 225 | net.register_forward_hook(linear_hook) 226 | elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d): 227 | net.register_forward_hook(bn_hook) 228 | elif isinstance(net, nn.ReLU): 229 | net.register_forward_hook(relu_hook) 230 | elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d): 231 | net.register_forward_hook(pooling2d_hook) 232 | elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d): 233 | net.register_forward_hook(pooling1d_hook) 234 | else: 235 | print('Warning: flop of module {} is not counted!'.format(net)) 236 | return 237 | for c in childrens: 238 | foo(c) 239 | 240 | # Register hook 241 | foo(model) 242 | 243 | device = device = 
next(model.parameters()).device 244 | input = torch.rand(1, audio_length).to(device) 245 | 246 | out = model(input) 247 | 248 | total_flops = sum(list_conv2d) + sum(list_conv1d) + sum(list_linear) + \ 249 | sum(list_bn) + sum(list_relu) + sum(list_pooling2d) + sum(list_pooling1d) 250 | 251 | return total_flops -------------------------------------------------------------------------------- /src/local/panns/models.py: -------------------------------------------------------------------------------- 1 | ## script taken from https://github.com/qiuqiangkong/audioset_tagging_cnn 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchlibrosa.stft import Spectrogram, LogmelFilterBank 7 | from torchlibrosa.augmentation import SpecAugmentation 8 | from local.panns.pytorch_utils import do_mixup, interpolate, pad_framewise_output 9 | 10 | 11 | def init_layer(layer): 12 | """Initialize a Linear or Convolutional layer. """ 13 | nn.init.xavier_uniform_(layer.weight) 14 | 15 | if hasattr(layer, 'bias'): 16 | if layer.bias is not None: 17 | layer.bias.data.fill_(0.) 18 | 19 | 20 | def init_bn(bn): 21 | """Initialize a Batchnorm layer. """ 22 | bn.bias.data.fill_(0.) 23 | bn.weight.data.fill_(1.) 24 | 25 | 26 | class ConvBlock(nn.Module): 27 | def __init__(self, in_channels, out_channels): 28 | 29 | super(ConvBlock, self).__init__() 30 | 31 | self.conv1 = nn.Conv2d(in_channels=in_channels, 32 | out_channels=out_channels, 33 | kernel_size=(3, 3), stride=(1, 1), 34 | padding=(1, 1), bias=False) 35 | 36 | self.conv2 = nn.Conv2d(in_channels=out_channels, 37 | out_channels=out_channels, 38 | kernel_size=(3, 3), stride=(1, 1), 39 | padding=(1, 1), bias=False) 40 | 41 | self.bn1 = nn.BatchNorm2d(out_channels) 42 | self.bn2 = nn.BatchNorm2d(out_channels) 43 | 44 | self.init_weight() 45 | 46 | def init_weight(self): 47 | init_layer(self.conv1) 48 | init_layer(self.conv2) 49 | init_bn(self.bn1) 50 | init_bn(self.bn2) 51 | 52 | 53 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 54 | 55 | x = input 56 | x = F.relu_(self.bn1(self.conv1(x))) 57 | x = F.relu_(self.bn2(self.conv2(x))) 58 | if pool_type == 'max': 59 | x = F.max_pool2d(x, kernel_size=pool_size) 60 | elif pool_type == 'avg': 61 | x = F.avg_pool2d(x, kernel_size=pool_size) 62 | elif pool_type == 'avg+max': 63 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 64 | x2 = F.max_pool2d(x, kernel_size=pool_size) 65 | x = x1 + x2 66 | else: 67 | raise Exception('Incorrect argument!') 68 | 69 | return x 70 | 71 | 72 | class ConvBlock5x5(nn.Module): 73 | def __init__(self, in_channels, out_channels): 74 | 75 | super(ConvBlock5x5, self).__init__() 76 | 77 | self.conv1 = nn.Conv2d(in_channels=in_channels, 78 | out_channels=out_channels, 79 | kernel_size=(5, 5), stride=(1, 1), 80 | padding=(2, 2), bias=False) 81 | 82 | self.bn1 = nn.BatchNorm2d(out_channels) 83 | 84 | self.init_weight() 85 | 86 | def init_weight(self): 87 | init_layer(self.conv1) 88 | init_bn(self.bn1) 89 | 90 | 91 | def forward(self, input, pool_size=(2, 2), pool_type='avg'): 92 | 93 | x = input 94 | x = F.relu_(self.bn1(self.conv1(x))) 95 | if pool_type == 'max': 96 | x = F.max_pool2d(x, kernel_size=pool_size) 97 | elif pool_type == 'avg': 98 | x = F.avg_pool2d(x, kernel_size=pool_size) 99 | elif pool_type == 'avg+max': 100 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 101 | x2 = F.max_pool2d(x, kernel_size=pool_size) 102 | x = x1 + x2 103 | else: 104 | raise Exception('Incorrect argument!') 105 | 106 | return x 107 | 108 | 109 | class 
AttBlock(nn.Module): 110 | def __init__(self, n_in, n_out, activation='linear', temperature=1.): 111 | super(AttBlock, self).__init__() 112 | 113 | self.activation = activation 114 | self.temperature = temperature 115 | self.att = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 116 | self.cla = nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=1, stride=1, padding=0, bias=True) 117 | 118 | self.bn_att = nn.BatchNorm1d(n_out) 119 | self.init_weights() 120 | 121 | def init_weights(self): 122 | init_layer(self.att) 123 | init_layer(self.cla) 124 | init_bn(self.bn_att) 125 | 126 | def forward(self, x): 127 | # x: (n_samples, n_in, n_time) 128 | norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1) 129 | cla = self.nonlinear_transform(self.cla(x)) 130 | x = torch.sum(norm_att * cla, dim=2) 131 | return x, norm_att, cla 132 | 133 | def nonlinear_transform(self, x): 134 | if self.activation == 'linear': 135 | return x 136 | elif self.activation == 'sigmoid': 137 | return torch.sigmoid(x) 138 | 139 | class Cnn14_16k(nn.Module): 140 | def __init__(self, hop_size=160, freeze_bn=True, use_specaugm=False): 141 | 142 | super(Cnn14_16k, self).__init__() 143 | 144 | self.freeze_bn = freeze_bn 145 | self.use_specaugm = use_specaugm 146 | sample_rate = 16000 147 | window_size = 512 148 | assert hop_size == 160 149 | mel_bins = 64 150 | fmin = 50 151 | fmax = 8000 152 | classes_num = 527 153 | 154 | window = 'hann' 155 | center = True 156 | pad_mode = 'reflect' 157 | ref = 1.0 158 | amin = 1e-10 159 | top_db = None 160 | 161 | # Spectrogram extractor 162 | self.spectrogram_extractor = Spectrogram(n_fft=window_size, hop_length=hop_size, 163 | win_length=window_size, window=window, center=center, pad_mode=pad_mode, 164 | freeze_parameters=True) 165 | 166 | # Logmel feature extractor 167 | self.logmel_extractor = LogmelFilterBank(sr=sample_rate, n_fft=window_size, 168 | n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref, amin=amin, top_db=top_db, 169 | freeze_parameters=True) 170 | 171 | # Spec augmenter 172 | self.spec_augmenter = SpecAugmentation(time_drop_width=64, time_stripes_num=2, 173 | freq_drop_width=8, freq_stripes_num=2) 174 | 175 | self.bn0 = nn.BatchNorm2d(64) 176 | 177 | self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) 178 | self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) 179 | self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) 180 | 181 | self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) 182 | self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) 183 | self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048) 184 | 185 | self.fc1 = nn.Linear(2048, 2048, bias=True) 186 | # this is not used here, we only use the pretrained model as an embedding extractor 187 | #self.fc_audioset = nn.Linear(2048, classes_num, bias=True) 188 | self.init_weight() 189 | 190 | def init_weight(self): 191 | init_bn(self.bn0) 192 | init_layer(self.fc1) 193 | #init_layer(self.fc_audioset) 194 | 195 | def forward(self, input, mixup_lambda=None): 196 | """ 197 | Input: (batch_size, data_length)""" 198 | 199 | x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) 200 | x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 201 | 202 | x = x.transpose(1, 3) 203 | x = self.bn0(x) 204 | x = x.transpose(1, 3) 205 | 206 | if self.training and self.use_specaugm: 207 | x = self.spec_augmenter(x) 208 | 209 | # Mixup on spectrogram 210 | if self.training and 
mixup_lambda is not None: 211 | x = do_mixup(x, mixup_lambda) 212 | 213 | x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg') 214 | x = F.dropout(x, p=0.2, training=self.training) 215 | x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg') 216 | x = F.dropout(x, p=0.2, training=self.training) 217 | x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg') 218 | x = F.dropout(x, p=0.2, training=self.training) 219 | frame_embedding = x 220 | # we take frame-level embedding from this layer, participants are free to use other layers 221 | x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg') 222 | x = F.dropout(x, p=0.2, training=self.training) 223 | x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg') 224 | x = F.dropout(x, p=0.2, training=self.training) 225 | x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg') 226 | x = F.dropout(x, p=0.2, training=self.training) 227 | 228 | x = torch.mean(x, dim=3) 229 | (x1, _) = torch.max(x, dim=2) 230 | x2 = torch.mean(x, dim=2) 231 | x = x1 + x2 232 | global_emb = F.dropout(x, p=0.5, training=self.training) 233 | bsz, chans, time, freq = frame_embedding.shape 234 | return {"global": global_emb, "frame": frame_embedding.transpose(2, -1).reshape(bsz, chans*freq, time)} 235 | 236 | def train(self, mode=True): 237 | """ 238 | Override the default train() to freeze the BN parameters 239 | """ 240 | super(Cnn14_16k, self).train(mode) 241 | if self.freeze_bn: 242 | print("Freezing Mean/Var of BatchNorm2D.") 243 | if self.freeze_bn: 244 | print("Freezing Weight/Bias of BatchNorm2D.") 245 | if self.freeze_bn: 246 | for m in self.modules(): 247 | if isinstance(m, nn.BatchNorm2d): 248 | m.eval() 249 | if self.freeze_bn: 250 | m.weight.requires_grad = False 251 | m.bias.requires_grad = False -------------------------------------------------------------------------------- /src/desed_task/evaluation/evaluation_measures.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import psds_eval 6 | import sed_eval 7 | from psds_eval import PSDSEval, plot_psd_roc 8 | 9 | 10 | def get_event_list_current_file(df, fname): 11 | """ 12 | Get list of events for a given filename 13 | Args: 14 | df: pd.DataFrame, the dataframe to search on 15 | fname: the filename to extract the value from the dataframe 16 | Returns: 17 | list of events (dictionaries) for the given filename 18 | """ 19 | event_file = df[df["filename"] == fname] 20 | if len(event_file) == 1: 21 | if pd.isna(event_file["event_label"].iloc[0]): 22 | event_list_for_current_file = [{"filename": fname}] 23 | else: 24 | event_list_for_current_file = event_file.to_dict("records") 25 | else: 26 | event_list_for_current_file = event_file.to_dict("records") 27 | 28 | return event_list_for_current_file 29 | 30 | 31 | def psds_results(psds_obj): 32 | """ Compute psds scores 33 | Args: 34 | psds_obj: psds_eval.PSDSEval object with operating points. 
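Operating points must already have been added to psds_obj beforehand (e.g. via its add_operating_point method, as done in compute_psds_from_operating_points below).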
35 | Returns: 36 | """ 37 | try: 38 | psds_score = psds_obj.psds(alpha_ct=0, alpha_st=0, max_efpr=100) 39 | print(f"\nPSD-Score (0, 0, 100): {psds_score.value:.5f}") 40 | psds_score = psds_obj.psds(alpha_ct=1, alpha_st=0, max_efpr=100) 41 | print(f"\nPSD-Score (1, 0, 100): {psds_score.value:.5f}") 42 | psds_score = psds_obj.psds(alpha_ct=0, alpha_st=1, max_efpr=100) 43 | print(f"\nPSD-Score (0, 1, 100): {psds_score.value:.5f}") 44 | except psds_eval.psds.PSDSEvalError as e: 45 | print("psds did not work ....") 46 | raise EnvironmentError 47 | 48 | 49 | def event_based_evaluation_df( 50 | reference, estimated, t_collar=0.200, percentage_of_length=0.2 51 | ): 52 | """ Calculate EventBasedMetric given a reference and estimated dataframe 53 | 54 | Args: 55 | reference: pd.DataFrame containing "filename" "onset" "offset" and "event_label" columns which describe the 56 | reference events 57 | estimated: pd.DataFrame containing "filename" "onset" "offset" and "event_label" columns which describe the 58 | estimated events to be compared with reference 59 | t_collar: float, in seconds, the number of time allowed on onsets and offsets 60 | percentage_of_length: float, between 0 and 1, the percentage of length of the file allowed on the offset 61 | Returns: 62 | sed_eval.sound_event.EventBasedMetrics with the scores 63 | """ 64 | 65 | evaluated_files = reference["filename"].unique() 66 | 67 | classes = [] 68 | classes.extend(reference.event_label.dropna().unique()) 69 | classes.extend(estimated.event_label.dropna().unique()) 70 | classes = list(set(classes)) 71 | 72 | event_based_metric = sed_eval.sound_event.EventBasedMetrics( 73 | event_label_list=classes, 74 | t_collar=t_collar, 75 | percentage_of_length=percentage_of_length, 76 | empty_system_output_handling="zero_score", 77 | ) 78 | 79 | for fname in evaluated_files: 80 | reference_event_list_for_current_file = get_event_list_current_file( 81 | reference, fname 82 | ) 83 | estimated_event_list_for_current_file = get_event_list_current_file( 84 | estimated, fname 85 | ) 86 | 87 | event_based_metric.evaluate( 88 | reference_event_list=reference_event_list_for_current_file, 89 | estimated_event_list=estimated_event_list_for_current_file, 90 | ) 91 | 92 | return event_based_metric 93 | 94 | 95 | def segment_based_evaluation_df(reference, estimated, time_resolution=1.0): 96 | """ Calculate SegmentBasedMetrics given a reference and estimated dataframe 97 | 98 | Args: 99 | reference: pd.DataFrame containing "filename" "onset" "offset" and "event_label" columns which describe the 100 | reference events 101 | estimated: pd.DataFrame containing "filename" "onset" "offset" and "event_label" columns which describe the 102 | estimated events to be compared with reference 103 | time_resolution: float, the time resolution of the segment based metric 104 | Returns: 105 | sed_eval.sound_event.SegmentBasedMetrics with the scores 106 | """ 107 | evaluated_files = reference["filename"].unique() 108 | 109 | classes = [] 110 | classes.extend(reference.event_label.dropna().unique()) 111 | classes.extend(estimated.event_label.dropna().unique()) 112 | classes = list(set(classes)) 113 | 114 | segment_based_metric = sed_eval.sound_event.SegmentBasedMetrics( 115 | event_label_list=classes, time_resolution=time_resolution 116 | ) 117 | 118 | for fname in evaluated_files: 119 | reference_event_list_for_current_file = get_event_list_current_file( 120 | reference, fname 121 | ) 122 | estimated_event_list_for_current_file = get_event_list_current_file( 123 | estimated, fname 
124 | ) 125 | 126 | segment_based_metric.evaluate( 127 | reference_event_list=reference_event_list_for_current_file, 128 | estimated_event_list=estimated_event_list_for_current_file, 129 | ) 130 | 131 | return segment_based_metric 132 | 133 | 134 | def compute_sed_eval_metrics(predictions, groundtruth): 135 | """ Compute sed_eval metrics event based and segment based with default parameters used in the task. 136 | Args: 137 | predictions: pd.DataFrame, predictions dataframe 138 | groundtruth: pd.DataFrame, groundtruth dataframe 139 | Returns: 140 | tuple, (sed_eval.sound_event.EventBasedMetrics, sed_eval.sound_event.SegmentBasedMetrics) 141 | """ 142 | metric_event = event_based_evaluation_df( 143 | groundtruth, predictions, t_collar=0.200, percentage_of_length=0.2 144 | ) 145 | metric_segment = segment_based_evaluation_df( 146 | groundtruth, predictions, time_resolution=1.0 147 | ) 148 | 149 | return metric_event, metric_segment 150 | 151 | 152 | def compute_per_intersection_macro_f1( 153 | prediction_dfs, 154 | ground_truth_file, 155 | durations_file, 156 | dtc_threshold=0.5, 157 | gtc_threshold=0.5, 158 | cttc_threshold=0.3, 159 | ): 160 | """ Compute F1-score per intersection, using the defautl 161 | Args: 162 | prediction_dfs: dict, a dictionary with thresholds keys and predictions dataframe 163 | ground_truth_file: pd.DataFrame, the groundtruth dataframe 164 | durations_file: pd.DataFrame, the duration dataframe 165 | dtc_threshold: float, the parameter used in PSDSEval, percentage of tolerance for groundtruth intersection 166 | with predictions 167 | gtc_threshold: float, the parameter used in PSDSEval percentage of tolerance for predictions intersection 168 | with groundtruth 169 | gtc_threshold: float, the parameter used in PSDSEval to know the percentage needed to count FP as cross-trigger 170 | 171 | Returns: 172 | 173 | """ 174 | gt = pd.read_csv(ground_truth_file, sep="\t") 175 | durations = pd.read_csv(durations_file, sep="\t") 176 | 177 | psds = PSDSEval( 178 | ground_truth=gt, 179 | metadata=durations, 180 | dtc_threshold=dtc_threshold, 181 | gtc_threshold=gtc_threshold, 182 | cttc_threshold=cttc_threshold, 183 | ) 184 | psds_macro_f1 = [] 185 | for threshold in prediction_dfs.keys(): 186 | if not prediction_dfs[threshold].empty: 187 | threshold_f1, _ = psds.compute_macro_f_score(prediction_dfs[threshold]) 188 | else: 189 | threshold_f1 = 0 190 | if np.isnan(threshold_f1): 191 | threshold_f1 = 0.0 192 | psds_macro_f1.append(threshold_f1) 193 | psds_macro_f1 = np.mean(psds_macro_f1) 194 | return psds_macro_f1 195 | 196 | 197 | def compute_psds_from_operating_points( 198 | prediction_dfs, 199 | ground_truth_file, 200 | durations_file, 201 | dtc_threshold=0.5, 202 | gtc_threshold=0.5, 203 | cttc_threshold=0.3, 204 | alpha_ct=0, 205 | alpha_st=0, 206 | max_efpr=100, 207 | save_dir=None, 208 | ): 209 | 210 | gt = pd.read_csv(ground_truth_file, sep="\t") 211 | durations = pd.read_csv(durations_file, sep="\t") 212 | psds_eval = PSDSEval( 213 | ground_truth=gt, 214 | metadata=durations, 215 | dtc_threshold=dtc_threshold, 216 | gtc_threshold=gtc_threshold, 217 | cttc_threshold=cttc_threshold, 218 | ) 219 | 220 | for i, k in enumerate(prediction_dfs.keys()): 221 | det = prediction_dfs[k] 222 | # see issue https://github.com/audioanalytic/psds_eval/issues/3 223 | det["index"] = range(1, len(det) + 1) 224 | det = det.set_index("index") 225 | psds_eval.add_operating_point( 226 | det, info={"name": f"Op {i + 1:02d}", "threshold": k} 227 | ) 228 | 229 | psds_score = 
psds_eval.psds(alpha_ct=alpha_ct, alpha_st=alpha_st, max_efpr=max_efpr) 230 | 231 | if save_dir is not None: 232 | os.makedirs(save_dir, exist_ok=True) 233 | 234 | pred_dir = os.path.join( 235 | save_dir, 236 | f"predictions_dtc{dtc_threshold}_gtc{gtc_threshold}_cttc{cttc_threshold}", 237 | ) 238 | os.makedirs(pred_dir, exist_ok=True) 239 | for k in prediction_dfs.keys(): 240 | prediction_dfs[k].to_csv( 241 | os.path.join(pred_dir, f"predictions_th_{k:.2f}.tsv"), 242 | sep="\t", 243 | index=False, 244 | ) 245 | 246 | plot_psd_roc( 247 | psds_score, 248 | filename=os.path.join(save_dir, f"PSDS_ct{alpha_ct}_st{alpha_st}_100.png"), 249 | ) 250 | 251 | return psds_score.value 252 | --------------------------------------------------------------------------------
/src/local/ast/ast_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 6/10/21 5:04 PM 3 | # @Author : Yuan Gong 4 | # @Affiliation : Massachusetts Institute of Technology 5 | # @Email : yuangong@mit.edu 6 | # @File : ast_models.py 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.cuda.amp import autocast 11 | import os 12 | import wget 13 | import timm 14 | from timm.models.layers import to_2tuple,trunc_normal_ 15 | from pathlib import Path 16 | 17 | # override the timm package to relax the input shape constraint. 18 | class PatchEmbed(nn.Module): 19 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 20 | super().__init__() 21 | 22 | img_size = to_2tuple(img_size) 23 | patch_size = to_2tuple(patch_size) 24 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 25 | self.img_size = img_size 26 | self.patch_size = patch_size 27 | self.num_patches = num_patches 28 | 29 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 30 | 31 | def forward(self, x): 32 | x = self.proj(x).flatten(2).transpose(1, 2) 33 | return x 34 | 35 | class ASTModel(nn.Module): 36 | """ 37 | The AST model. 38 | :param label_dim: the label dimension, i.e., the number of total classes, it is 527 for AudioSet, 50 for ESC-50, and 35 for speechcommands v2-35 39 | :param fstride: the stride of patch spliting on the frequency dimension, for 16*16 patchs, fstride=16 means no overlap, fstride=10 means overlap of 6 40 | :param tstride: the stride of patch spliting on the time dimension, for 16*16 patchs, tstride=16 means no overlap, tstride=10 means overlap of 6 41 | :param input_fdim: the number of frequency bins of the input spectrogram 42 | :param input_tdim: the number of time frames of the input spectrogram 43 | :param imagenet_pretrain: if use ImageNet pretrained model 44 | :param audioset_pretrain: if use full AudioSet and ImageNet pretrained model 45 | :param model_size: the model size of AST, should be in [tiny224, small224, base224, base384], base224 and base 384 are same model, but are trained differently during ImageNet pretraining. 46 | """ 47 | def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=True): 48 | 49 | super(ASTModel, self).__init__() 50 | assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'
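# The two construction paths below work as follows: when audioset_pretrain is False, a DeiT backbone is
# created with timm (optionally ImageNet-pretrained), its patch projection is replaced by a single-channel
# Conv2d with 16x16 kernels and (fstride, tstride) strides, and, if ImageNet-pretrained, its positional
# embeddings are cropped/interpolated to the spectrogram patch grid (otherwise they are randomly initialized).
# When audioset_pretrain is True, an AudioSet-pretrained AST checkpoint is downloaded instead and only the
# time dimension of its positional embeddings is adapted to input_tdim.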
51 | 52 | if verbose == True: 53 | print('---------------AST Model Summary---------------') 54 | print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain),str(audioset_pretrain))) 55 | # override timm input shape restriction 56 | timm.models.vision_transformer.PatchEmbed = PatchEmbed 57 | 58 | # if AudioSet pretraining is not used (but ImageNet pretraining may still apply) 59 | if audioset_pretrain == False: 60 | if model_size == 'tiny224': 61 | self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain) 62 | elif model_size == 'small224': 63 | self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain) 64 | elif model_size == 'base224': 65 | self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain) 66 | elif model_size == 'base384': 67 | self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=imagenet_pretrain) 68 | else: 69 | raise Exception('Model size must be one of tiny224, small224, base224, base384.') 70 | self.original_num_patches = self.v.patch_embed.num_patches 71 | self.oringal_hw = int(self.original_num_patches ** 0.5) 72 | self.original_embedding_dim = self.v.pos_embed.shape[2] 73 | self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim)) 74 | 75 | # automatcially get the intermediate shape 76 | f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim) 77 | num_patches = f_dim * t_dim 78 | self.v.patch_embed.num_patches = num_patches 79 | if verbose == True: 80 | print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride)) 81 | print('number of patches={:d}'.format(num_patches)) 82 | 83 | # the linear projection layer 84 | new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride)) 85 | if imagenet_pretrain == True: 86 | new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1)) 87 | new_proj.bias = self.v.patch_embed.proj.bias 88 | self.v.patch_embed.proj = new_proj 89 | 90 | # the positional embedding 91 | if imagenet_pretrain == True: 92 | # get the positional embedding from deit model, skip the first two tokens (cls token and distillation token), reshape it to original 2D shape (24*24). 
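# e.g. for the default base384 backbone there are 24*24 = 576 patches with embedding dim 768, so this
# reshape goes from (1, 576, 768) to (1, 768, 24, 24); the two steps below then crop or interpolate
# the time axis to t_dim and the frequency axis to f_dim computed above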
93 | new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw) 94 | # cut (from middle) or interpolate the second dimension of the positional embedding 95 | if t_dim <= self.oringal_hw: 96 | new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim] 97 | else: 98 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bilinear') 99 | # cut (from middle) or interpolate the first dimension of the positional embedding 100 | if f_dim <= self.oringal_hw: 101 | new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :] 102 | else: 103 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear') 104 | # flatten the positional embedding 105 | new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1,2) 106 | # concatenate the above positional embedding with the cls token and distillation token of the deit model. 107 | self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1)) 108 | else: 109 | # if not use imagenet pretrained model, just randomly initialize a learnable positional embedding 110 | # TODO can use sinusoidal positional embedding instead 111 | new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim)) 112 | self.v.pos_embed = new_pos_embed 113 | trunc_normal_(self.v.pos_embed, std=.02) 114 | 115 | # now load a model that is pretrained on both ImageNet and AudioSet 116 | elif audioset_pretrain == True: 117 | if audioset_pretrain == True and imagenet_pretrain == False: 118 | raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.') 119 | if model_size != 'base384': 120 | raise ValueError('currently only has base384 AudioSet pretrained model.') 121 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 122 | if os.path.exists('./pretrained_models/audioset_10_10_0.4593.pth') == False: 123 | # this model performs 0.4593 mAP on the audioset eval set 124 | os.makedirs(Path('./pretrained_models/audioset_10_10_0.4593.pth').parent, exist_ok=True) 125 | audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1' 126 | wget.download(audioset_mdl_url, out='./pretrained_models/audioset_10_10_0.4593.pth') 127 | sd = torch.load('./pretrained_models/audioset_10_10_0.4593.pth', map_location=device) 128 | audio_model = ASTModel(label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False) 129 | audio_model = torch.nn.DataParallel(audio_model) 130 | audio_model.load_state_dict(sd, strict=False) 131 | self.v = audio_model.module.v 132 | self.original_embedding_dim = self.v.pos_embed.shape[2] 133 | self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim)) 134 | 135 | f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim) 136 | num_patches = f_dim * t_dim 137 | self.v.patch_embed.num_patches = num_patches 138 | if verbose == True: 139 | print('frequncey stride={:d}, time 
stride={:d}'.format(fstride, tstride)) 140 | print('number of patches={:d}'.format(num_patches)) 141 | 142 | new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, 1212, 768).transpose(1, 2).reshape(1, 768, 12, 101) 143 | # if the input sequence length is larger than the original audioset (10s), then cut the positional embedding 144 | if t_dim < 101: 145 | new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim] 146 | # otherwise interpolate 147 | else: 148 | new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear') 149 | new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2) 150 | self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1)) 151 | 152 | def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024): 153 | test_input = torch.randn(1, 1, input_fdim, input_tdim) 154 | test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride)) 155 | test_out = test_proj(test_input) 156 | f_dim = test_out.shape[2] 157 | t_dim = test_out.shape[3] 158 | return f_dim, t_dim 159 | 160 | @autocast() 161 | def forward(self, x): 162 | """ 163 | :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128) 164 | :return: dict with the clip-level output under "global" and frame-level embeddings under "frame" 165 | """ 166 | # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128) 167 | x = x.unsqueeze(1) 168 | x = x.transpose(2, 3) 169 | 170 | B = x.shape[0] 171 | x = self.v.patch_embed(x) 172 | cls_tokens = self.v.cls_token.expand(B, -1, -1) 173 | dist_token = self.v.dist_token.expand(B, -1, -1) 174 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 175 | x = x + self.v.pos_embed 176 | x = self.v.pos_drop(x) 177 | for blk in self.v.blocks: 178 | x = blk(x) 179 | x = self.v.norm(x) 180 | frame_embeds = x 181 | x = (x[:, 0] + x[:, 1]) / 2 182 | 183 | x = self.mlp_head(x) 184 | return {"global": x.float(), "frame": frame_embeds.transpose(1, 2).float()} 185 | 186 | if __name__ == '__main__': 187 | input_tdim = 100 188 | ast_mdl = ASTModel(input_tdim=input_tdim) 189 | # input a batch of 10 spectrograms, each with 100 time frames and 128 frequency bins 190 | test_input = torch.rand([10, input_tdim, 128]) 191 | test_output = ast_mdl(test_input) 192 | # the "global" output should be in shape [10, 527], i.e., 10 samples, each with prediction of 527 classes. 193 | print(test_output["global"].shape) 194 | 195 | input_tdim = 256 196 | ast_mdl = ASTModel(input_tdim=input_tdim, label_dim=50, audioset_pretrain=True) 197 | # input a batch of 10 spectrograms, each with 256 time frames and 128 frequency bins 198 | test_input = torch.rand([10, input_tdim, 128]) 199 | test_output = ast_mdl(test_input) 200 | # the "global" output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes.
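# note: forward() in this file returns a dict of outputs ("global" and "frame"), so both smoke tests
# read the clip-level prediction from test_output["global"]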
201 | print(test_output["global"].shape) -------------------------------------------------------------------------------- /src/train_sed.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import warnings 3 | 4 | import numpy as np 5 | import os 6 | import pandas as pd 7 | import random 8 | import torch 9 | import yaml 10 | 11 | import pytorch_lightning as pl 12 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 13 | from pytorch_lightning.loggers import TensorBoardLogger 14 | 15 | from desed_task.dataio import ConcatDatasetBatchSampler 16 | from desed_task.dataio.datasets import StronglyAnnotatedSet, UnlabeledSet, WeakSet 17 | from desed_task.nnet.CRNN import CRNN 18 | from desed_task.utils.encoder import ManyHotEncoder 19 | from desed_task.utils.schedulers import ExponentialWarmup 20 | 21 | from local.classes_dict import classes_labels 22 | from local.sed_trainer import SEDTask4 23 | from local.resample_folder import resample_folder 24 | from local.utils import generate_tsv_wav_durations 25 | 26 | 27 | def resample_data_generate_durations(config_data, test_only=False, evaluation=False): 28 | if evaluation: 29 | dsets = ["eval_folder"] 30 | elif test_only: 31 | dsets = ["test_folder"] 32 | else: 33 | dsets = [ 34 | "synth_folder", 35 | "val_folder", 36 | "strong_folder", 37 | "weak_folder", 38 | "unlabeled_folder", 39 | "test_folder", 40 | ] 41 | 42 | for dset in dsets: 43 | computed = resample_folder( 44 | config_data[dset + "_44k"], config_data[dset], target_fs=config_data["fs"] 45 | ) 46 | 47 | if not evaluation: 48 | for base_set in ["val", "test"]: 49 | if not os.path.exists(config_data[base_set + "_dur"]) or computed: 50 | generate_tsv_wav_durations( 51 | config_data[base_set + "_folder"], config_data[base_set + "_dur"] 52 | ) 53 | 54 | def single_run( 55 | config, 56 | log_dir, 57 | gpus, 58 | strong_real=False, 59 | checkpoint_resume=None, 60 | test_state_dict=None, 61 | fast_dev_run=False, 62 | evaluation=False 63 | ): 64 | """ 65 | Runs the sound event detection baseline. 66 | 67 | Args: 68 | config (dict): the dictionary of configuration params 69 | log_dir (str): path to log directory 70 | gpus (int): number of gpus to use 71 | checkpoint_resume (str, optional): path to checkpoint to resume from. Defaults to None. 72 | test_state_dict (dict, optional): if not None, no training is involved. This dictionary is the state_dict 73 | to be loaded to test the model. 74 | fast_dev_run (bool, optional): whether to use a run with only one batch at train and validation, useful 75 | for development purposes.
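strong_real (bool, optional): whether to add the real strongly-annotated (AudioSet strong) data to the training set. Defaults to False.
evaluation (bool, optional): whether to run inference on the unlabeled evaluation set instead of the test set. Defaults to False.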
76 | """ 77 | config.update({"log_dir": log_dir}) 78 | 79 | ##### data prep test ########## 80 | encoder = ManyHotEncoder( 81 | list(classes_labels.keys()), 82 | audio_len=config["data"]["audio_max_len"], 83 | frame_len=config["feats"]["n_filters"], 84 | frame_hop=config["feats"]["hop_length"], 85 | net_pooling=config["data"]["net_subsample"], 86 | fs=config["data"]["fs"], 87 | ) 88 | 89 | if not evaluation: 90 | devtest_df = pd.read_csv(config["data"]["test_tsv"], sep="\t") 91 | devtest_dataset = StronglyAnnotatedSet( 92 | config["data"]["test_folder"], 93 | devtest_df, 94 | encoder, 95 | return_filename=True, 96 | pad_to=config["data"]["audio_max_len"] 97 | ) 98 | else: 99 | devtest_dataset = UnlabeledSet( 100 | config["data"]["eval_folder"], 101 | encoder, 102 | pad_to=None, 103 | return_filename=True 104 | ) 105 | 106 | test_dataset = devtest_dataset 107 | 108 | ##### model definition ############ 109 | sed_student = CRNN(**config["net"]) 110 | 111 | if test_state_dict is None: 112 | ##### data prep train valid ########## 113 | synth_df = pd.read_csv(config["data"]["synth_tsv"], sep="\t") 114 | synth_set = StronglyAnnotatedSet( 115 | config["data"]["synth_folder"], 116 | synth_df, 117 | encoder, 118 | pad_to=config["data"]["audio_max_len"], 119 | ) 120 | 121 | if strong_real: 122 | strong_df = pd.read_csv(config["data"]["strong_tsv"], sep="\t") 123 | strong_set = StronglyAnnotatedSet( 124 | config["data"]["strong_folder"], 125 | strong_df, 126 | encoder, 127 | pad_to=config["data"]["audio_max_len"], 128 | ) 129 | 130 | 131 | weak_df = pd.read_csv(config["data"]["weak_tsv"], sep="\t") 132 | train_weak_df = weak_df.sample( 133 | frac=config["training"]["weak_split"], 134 | random_state=config["training"]["seed"], 135 | ) 136 | valid_weak_df = weak_df.drop(train_weak_df.index).reset_index(drop=True) 137 | train_weak_df = train_weak_df.reset_index(drop=True) 138 | weak_set = WeakSet( 139 | config["data"]["weak_folder"], 140 | train_weak_df, 141 | encoder, 142 | pad_to=config["data"]["audio_max_len"], 143 | ) 144 | 145 | unlabeled_set = UnlabeledSet( 146 | config["data"]["unlabeled_folder"], 147 | encoder, 148 | pad_to=config["data"]["audio_max_len"], 149 | ) 150 | 151 | strong_val_df = pd.read_csv(config["data"]["val_tsv"], sep="\t") 152 | strong_val = StronglyAnnotatedSet( 153 | config["data"]["val_folder"], 154 | strong_val_df, 155 | encoder, 156 | return_filename=True, 157 | pad_to=config["data"]["audio_max_len"], 158 | ) 159 | 160 | weak_val = WeakSet( 161 | config["data"]["weak_folder"], 162 | valid_weak_df, 163 | encoder, 164 | pad_to=config["data"]["audio_max_len"], 165 | return_filename=True, 166 | ) 167 | 168 | if strong_real: 169 | strong_full_set = torch.utils.data.ConcatDataset([strong_set, synth_set]) 170 | tot_train_data = [strong_full_set, weak_set, unlabeled_set] 171 | else: 172 | tot_train_data = [synth_set, weak_set, unlabeled_set] 173 | train_dataset = torch.utils.data.ConcatDataset(tot_train_data) 174 | 175 | batch_sizes = config["training"]["batch_size"] 176 | samplers = [torch.utils.data.RandomSampler(x) for x in tot_train_data] 177 | batch_sampler = ConcatDatasetBatchSampler(samplers, batch_sizes) 178 | 179 | valid_dataset = torch.utils.data.ConcatDataset([strong_val, weak_val]) 180 | 181 | ##### training params and optimizers ############ 182 | epoch_len = min( 183 | [ 184 | len(tot_train_data[indx]) 185 | // ( 186 | config["training"]["batch_size"][indx] 187 | * config["training"]["accumulate_batches"] 188 | ) 189 | for indx in range(len(tot_train_data)) 190 | ] 
191 | ) 192 | 193 | opt = torch.optim.Adam(sed_student.parameters(), 1e-3, betas=(0.9, 0.999)) 194 | exp_steps = config["training"]["n_epochs_warmup"] * epoch_len 195 | exp_scheduler = { 196 | "scheduler": ExponentialWarmup(opt, config["opt"]["lr"], exp_steps), 197 | "interval": "step", 198 | } 199 | logger = TensorBoardLogger( 200 | os.path.dirname(config["log_dir"]), config["log_dir"].split("/")[-1], 201 | ) 202 | print(f"experiment dir: {logger.log_dir}") 203 | 204 | callbacks = [ 205 | EarlyStopping( 206 | monitor="val/obj_metric", 207 | patience=config["training"]["early_stop_patience"], 208 | verbose=True, 209 | mode="max", 210 | ), 211 | ModelCheckpoint( 212 | logger.log_dir, 213 | monitor="val/obj_metric", 214 | save_top_k=1, 215 | mode="max", 216 | save_last=True, 217 | ), 218 | ] 219 | else: 220 | train_dataset = None 221 | valid_dataset = None 222 | batch_sampler = None 223 | opt = None 224 | exp_scheduler = None 225 | logger = True 226 | callbacks = None 227 | 228 | desed_training = SEDTask4( 229 | config, 230 | encoder=encoder, 231 | sed_student=sed_student, 232 | opt=opt, 233 | train_data=train_dataset, 234 | valid_data=valid_dataset, 235 | test_data=test_dataset, 236 | train_sampler=batch_sampler, 237 | scheduler=exp_scheduler, 238 | fast_dev_run=fast_dev_run, 239 | evaluation=evaluation, 240 | ) 241 | 242 | # Not using the fast_dev_run of Trainer because creates a DummyLogger so cannot check problems with the Logger 243 | if fast_dev_run: 244 | flush_logs_every_n_steps = 1 245 | log_every_n_steps = 1 246 | limit_train_batches = 2 247 | limit_val_batches = 2 248 | limit_test_batches = 2 249 | n_epochs = 3 250 | else: 251 | flush_logs_every_n_steps = 100 252 | log_every_n_steps = 40 253 | limit_train_batches = 1.0 254 | limit_val_batches = 1.0 255 | limit_test_batches = 1.0 256 | n_epochs = config["training"]["n_epochs"] 257 | 258 | 259 | if len(gpus.split(",")) > 1: 260 | raise NotImplementedError("Multiple GPUs are currently not supported") 261 | 262 | trainer = pl.Trainer( 263 | precision=config["training"]["precision"], 264 | max_epochs=n_epochs, 265 | callbacks=callbacks, 266 | gpus=gpus, 267 | strategy=config["training"].get("backend"), 268 | accumulate_grad_batches=config["training"]["accumulate_batches"], 269 | logger=logger, 270 | resume_from_checkpoint=checkpoint_resume, 271 | gradient_clip_val=config["training"]["gradient_clip"], 272 | check_val_every_n_epoch=config["training"]["validation_interval"], 273 | num_sanity_val_steps=0, 274 | log_every_n_steps=log_every_n_steps, 275 | flush_logs_every_n_steps=flush_logs_every_n_steps, 276 | limit_train_batches=limit_train_batches, 277 | limit_val_batches=limit_val_batches, 278 | limit_test_batches=limit_test_batches, 279 | ) 280 | 281 | if test_state_dict is None: 282 | 283 | # start tracking energy consumption 284 | trainer.fit(desed_training) 285 | best_path = trainer.checkpoint_callback.best_model_path 286 | print(f"best model: {best_path}") 287 | test_state_dict = torch.load(best_path)["state_dict"] 288 | 289 | desed_training.load_state_dict(test_state_dict) 290 | trainer.test(desed_training) 291 | 292 | 293 | if __name__ == "__main__": 294 | parser = argparse.ArgumentParser("Training a SED system for DESED Task") 295 | parser.add_argument( 296 | "--conf_file", 297 | default="./confs/default.yaml", 298 | help="The configuration file with all the experiment parameters.", 299 | ) 300 | parser.add_argument( 301 | "--log_dir", 302 | default="./exp/2022_baseline", 303 | help="Directory where to save tensorboard logs, 
saved models, etc.", 304 | ) 305 | 306 | parser.add_argument( 307 | "--strong_real", 308 | action="store_true", 309 | default=False, 310 | help="The strong annotations coming from Audioset will be included in the training phase.", 311 | ) 312 | parser.add_argument( 313 | "--resume_from_checkpoint", 314 | default=None, 315 | help="Allow the training to be resumed, take as input a previously saved model (.ckpt).", 316 | ) 317 | parser.add_argument( 318 | "--test_from_checkpoint", default=None, help="Test the model specified" 319 | ) 320 | parser.add_argument( 321 | "--gpus", 322 | default="1", 323 | help="The number of GPUs to train on, or the gpu to use, default='0', " 324 | "so uses one GPU", 325 | ) 326 | parser.add_argument( 327 | "--fast_dev_run", 328 | action="store_true", 329 | default=False, 330 | help="Use this option to make a 'fake' run which is useful for development and debugging. " 331 | "It uses very few batches and epochs so it won't give any meaningful result.", 332 | ) 333 | 334 | parser.add_argument( 335 | "--eval_from_checkpoint", 336 | default=None, 337 | help="Evaluate the model specified" 338 | ) 339 | 340 | args = parser.parse_args() 341 | 342 | with open(args.conf_file, "r") as f: 343 | configs = yaml.safe_load(f) 344 | 345 | evaluation = False 346 | test_from_checkpoint = args.test_from_checkpoint 347 | 348 | if args.eval_from_checkpoint is not None: 349 | test_from_checkpoint = args.eval_from_checkpoint 350 | evaluation = True 351 | 352 | test_model_state_dict = None 353 | if test_from_checkpoint is not None: 354 | checkpoint = torch.load(test_from_checkpoint) 355 | configs_ckpt = checkpoint["hyper_parameters"] 356 | configs_ckpt["data"] = configs["data"] 357 | print( 358 | f"loaded model: {test_from_checkpoint} \n" 359 | f"at epoch: {checkpoint['epoch']}" 360 | ) 361 | test_model_state_dict = checkpoint["state_dict"] 362 | 363 | if evaluation: 364 | configs["training"]["batch_size_val"] = 1 365 | 366 | seed = configs["training"]["seed"] 367 | if seed: 368 | torch.random.manual_seed(seed) 369 | np.random.seed(seed) 370 | random.seed(seed) 371 | pl.seed_everything(seed) 372 | 373 | test_only = test_from_checkpoint is not None 374 | resample_data_generate_durations(configs["data"], test_only, evaluation) 375 | single_run( 376 | configs, 377 | args.log_dir, 378 | args.gpus, 379 | args.strong_real, 380 | args.resume_from_checkpoint, 381 | test_model_state_dict, 382 | args.fast_dev_run, 383 | evaluation 384 | ) 385 | -------------------------------------------------------------------------------- /src/desed_task/dataio/datasets.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import pandas as pd 3 | import os 4 | import numpy as np 5 | import torchaudio 6 | import random 7 | import torch 8 | import glob 9 | import h5py 10 | from pathlib import Path 11 | 12 | def to_mono(mixture, random_ch=False): 13 | 14 | if mixture.ndim > 1: # multi channel 15 | if not random_ch: 16 | mixture = torch.mean(mixture, 0) 17 | else: # randomly select one channel 18 | indx = np.random.randint(0, mixture.shape[0] - 1) 19 | mixture = mixture[indx] 20 | return mixture 21 | 22 | 23 | def pad_audio(audio, target_len, fs): 24 | 25 | if audio.shape[-1] < target_len: 26 | audio = torch.nn.functional.pad( 27 | audio, (0, target_len - audio.shape[-1]), mode="constant" 28 | ) 29 | 30 | padded_indx = [target_len / len(audio)] 31 | onset_s = 0.000 32 | 33 | elif len(audio) > target_len: 34 | 35 | rand_onset = 
random.randint(0, len(audio) - target_len) 36 | audio = audio[rand_onset:rand_onset + target_len] 37 | onset_s = round(rand_onset / fs, 3) 38 | 39 | padded_indx = [target_len / len(audio)] 40 | else: 41 | 42 | onset_s = 0.000 43 | padded_indx = [1.0] 44 | 45 | offset_s = round(onset_s + (target_len / fs), 3) 46 | return audio, onset_s, offset_s, padded_indx 47 | 48 | def process_labels(df, onset, offset): 49 | 50 | 51 | df["onset"] = df["onset"] - onset 52 | df["offset"] = df["offset"] - onset 53 | 54 | df["onset"] = df.apply(lambda x: max(0, x["onset"]), axis=1) 55 | df["offset"] = df.apply(lambda x: min(10, x["offset"]), axis=1) 56 | 57 | df_new = df[(df.onset < df.offset)] 58 | 59 | return df_new.drop_duplicates() 60 | 61 | 62 | def read_audio(file, multisrc, random_channel, pad_to): 63 | 64 | mixture, fs = torchaudio.load(file) 65 | 66 | if not multisrc: 67 | mixture = to_mono(mixture, random_channel) 68 | 69 | if pad_to is not None: 70 | mixture, onset_s, offset_s, padded_indx = pad_audio(mixture, pad_to, fs) 71 | else: 72 | padded_indx = [1.0] 73 | onset_s = None 74 | offset_s = None 75 | 76 | mixture = mixture.float() 77 | return mixture, onset_s, offset_s, padded_indx 78 | 79 | 80 | class StronglyAnnotatedSet(Dataset): 81 | def __init__( 82 | self, 83 | audio_folder, 84 | tsv_entries, 85 | encoder, 86 | pad_to=10, 87 | fs=16000, 88 | return_filename=False, 89 | random_channel=False, 90 | multisrc=False, 91 | feats_pipeline=None, 92 | embeddings_hdf5_file=None, 93 | embedding_type=None 94 | 95 | ): 96 | 97 | self.encoder = encoder 98 | self.fs = fs 99 | self.pad_to = pad_to * fs 100 | self.return_filename = return_filename 101 | self.random_channel = random_channel 102 | self.multisrc = multisrc 103 | self.feats_pipeline = feats_pipeline 104 | self.embeddings_hdf5_file = embeddings_hdf5_file 105 | self.embedding_type = embedding_type 106 | assert embedding_type in ["global", "frame", None], "embedding type are either frame or global or None, got {}".format(embedding_type) 107 | 108 | tsv_entries = tsv_entries.dropna() 109 | 110 | examples = {} 111 | for i, r in tsv_entries.iterrows(): 112 | if r["filename"] not in examples.keys(): 113 | examples[r["filename"]] = { 114 | "mixture": os.path.join(audio_folder, r["filename"]), 115 | "events": [], 116 | } 117 | if not np.isnan(r["onset"]): 118 | examples[r["filename"]]["events"].append( 119 | { 120 | "event_label": r["event_label"], 121 | "onset": r["onset"], 122 | "offset": r["offset"], 123 | } 124 | ) 125 | else: 126 | if not np.isnan(r["onset"]): 127 | examples[r["filename"]]["events"].append( 128 | { 129 | "event_label": r["event_label"], 130 | "onset": r["onset"], 131 | "offset": r["offset"], 132 | } 133 | ) 134 | 135 | # we construct a dictionary for each example 136 | self.examples = examples 137 | self.examples_list = list(examples.keys()) 138 | 139 | if self.embeddings_hdf5_file is not None: 140 | assert self.embedding_type is not None, "If you use embeddings you need to specify also the type (global or frame)" 141 | # fetch dict of positions for each example 142 | self.ex2emb_idx = {} 143 | f = h5py.File(self.embeddings_hdf5_file, "r") 144 | for k, v in f["frame_embeddings"].attrs.items(): 145 | self.ex2emb_idx[k] = v 146 | self._opened_hdf5 = None 147 | 148 | def __len__(self): 149 | return len(self.examples_list) 150 | 151 | @property 152 | def hdf5_file(self): 153 | if self._opened_hdf5 is None: 154 | self._opened_hdf5 = h5py.File(self.embeddings_hdf5_file, "r") 155 | return self._opened_hdf5 156 | 157 | def 
__getitem__(self, item): 158 | 159 | c_ex = self.examples[self.examples_list[item]] 160 | mixture, onset_s, offset_s, padded_indx = read_audio( 161 | c_ex["mixture"], self.multisrc, self.random_channel, self.pad_to 162 | ) 163 | 164 | # labels 165 | labels = c_ex["events"] 166 | 167 | # to steps 168 | labels_df = pd.DataFrame(labels) 169 | labels_df = process_labels(labels_df, onset_s, offset_s) 170 | 171 | # check if labels exists: 172 | if not len(labels_df): 173 | max_len_targets = self.encoder.n_frames 174 | strong = torch.zeros(max_len_targets, len(self.encoder.labels)).float() 175 | else: 176 | strong = self.encoder.encode_strong_df(labels_df) 177 | strong = torch.from_numpy(strong).float() 178 | 179 | out_args = [mixture, strong.transpose(0, 1), padded_indx] 180 | 181 | if self.feats_pipeline is not None: 182 | # use this function to extract features in the dataloader and apply possibly some data augm 183 | feats = self.feats_pipeline(mixture) 184 | out_args.append(feats) 185 | if self.return_filename: 186 | out_args.append(c_ex["mixture"]) 187 | 188 | if self.embeddings_hdf5_file is not None: 189 | name = Path(c_ex["mixture"]).stem 190 | index = self.ex2emb_idx[name] 191 | 192 | global_embeddings = torch.from_numpy(self.hdf5_file["global_embeddings"][index]).float() 193 | frame_embeddings = torch.from_numpy(np.stack(self.hdf5_file["frame_embeddings"][index])).float() 194 | if self.embedding_type == "global": 195 | embeddings = global_embeddings 196 | elif self.embedding_type == "frame": 197 | embeddings = frame_embeddings 198 | else: 199 | raise NotImplementedError 200 | 201 | out_args.append(embeddings) 202 | 203 | return out_args 204 | 205 | 206 | class WeakSet(Dataset): 207 | 208 | def __init__( 209 | self, 210 | audio_folder, 211 | tsv_entries, 212 | encoder, 213 | pad_to=10, 214 | fs=16000, 215 | return_filename=False, 216 | random_channel=False, 217 | multisrc=False, 218 | feats_pipeline=None, 219 | embeddings_hdf5_file=None, 220 | embedding_type=None, 221 | 222 | ): 223 | 224 | self.encoder = encoder 225 | self.fs = fs 226 | self.pad_to = pad_to * fs 227 | self.return_filename = return_filename 228 | self.random_channel = random_channel 229 | self.multisrc = multisrc 230 | self.feats_pipeline = feats_pipeline 231 | self.embeddings_hdf5_file = embeddings_hdf5_file 232 | self.embedding_type = embedding_type 233 | assert embedding_type in ["global", "frame", 234 | None], "embedding type are either frame or global or None, got {}".format( 235 | embedding_type) 236 | 237 | examples = {} 238 | for i, r in tsv_entries.iterrows(): 239 | 240 | if r["filename"] not in examples.keys(): 241 | examples[r["filename"]] = { 242 | "mixture": os.path.join(audio_folder, r["filename"]), 243 | "events": r["event_labels"].split(","), 244 | } 245 | 246 | self.examples = examples 247 | self.examples_list = list(examples.keys()) 248 | 249 | if self.embeddings_hdf5_file is not None: 250 | assert self.embedding_type is not None, "If you use embeddings you need to specify also the type (global or frame)" 251 | # fetch dict of positions for each example 252 | self.ex2emb_idx = {} 253 | f = h5py.File(self.embeddings_hdf5_file, "r") 254 | for k, v in f["frame_embeddings"].attrs.items(): 255 | self.ex2emb_idx[k] = v 256 | self._opened_hdf5 = None 257 | 258 | def __len__(self): 259 | return len(self.examples_list) 260 | 261 | @property 262 | def hdf5_file(self): 263 | if self._opened_hdf5 is None: 264 | self._opened_hdf5 = h5py.File(self.embeddings_hdf5_file, "r") 265 | return self._opened_hdf5 266 | 267 | 
def __getitem__(self, item): 268 | file = self.examples_list[item] 269 | c_ex = self.examples[file] 270 | 271 | mixture, _, _, padded_indx = read_audio( 272 | c_ex["mixture"], self.multisrc, self.random_channel, self.pad_to 273 | ) 274 | 275 | # labels 276 | labels = c_ex["events"] 277 | # check if labels exists: 278 | max_len_targets = self.encoder.n_frames 279 | weak = torch.zeros(max_len_targets, len(self.encoder.labels)) 280 | if len(labels): 281 | weak_labels = self.encoder.encode_weak(labels) 282 | weak[0, :] = torch.from_numpy(weak_labels).float() 283 | 284 | out_args = [mixture, weak.transpose(0, 1), padded_indx] 285 | if self.feats_pipeline is not None: 286 | feats = self.feats_pipeline(mixture) 287 | out_args.append(feats) 288 | 289 | if self.return_filename: 290 | out_args.append(c_ex["mixture"]) 291 | 292 | if self.embeddings_hdf5_file is not None: 293 | name = Path(c_ex["mixture"]).stem 294 | index = self.ex2emb_idx[name] 295 | 296 | global_embeddings = torch.from_numpy(self.hdf5_file["global_embeddings"][index]).float() 297 | frame_embeddings = torch.from_numpy(np.stack(self.hdf5_file["frame_embeddings"][index])).float() 298 | if self.embedding_type == "global": 299 | embeddings = global_embeddings 300 | elif self.embedding_type == "frame": 301 | embeddings = frame_embeddings 302 | else: 303 | raise NotImplementedError 304 | 305 | out_args.append(embeddings) 306 | 307 | 308 | return out_args 309 | 310 | 311 | class UnlabeledSet(Dataset): 312 | def __init__( 313 | self, 314 | unlabeled_folder, 315 | encoder, 316 | pad_to=10, 317 | fs=16000, 318 | return_filename=False, 319 | random_channel=False, 320 | multisrc=False, 321 | feats_pipeline=None, 322 | embeddings_hdf5_file=None, 323 | embedding_type=None, 324 | ): 325 | 326 | self.encoder = encoder 327 | self.fs = fs 328 | self.pad_to = pad_to * fs if pad_to is not None else None 329 | self.examples = glob.glob(os.path.join(unlabeled_folder, "*.wav")) 330 | self.return_filename = return_filename 331 | self.random_channel = random_channel 332 | self.multisrc = multisrc 333 | self.feats_pipeline = feats_pipeline 334 | self.embeddings_hdf5_file = embeddings_hdf5_file 335 | self.embedding_type = embedding_type 336 | assert embedding_type in ["global", "frame", 337 | None], "embedding type are either frame or global or None, got {}".format( 338 | embedding_type) 339 | 340 | if self.embeddings_hdf5_file is not None: 341 | assert self.embedding_type is not None, "If you use embeddings you need to specify also the type (global or frame)" 342 | # fetch dict of positions for each example 343 | self.ex2emb_idx = {} 344 | f = h5py.File(self.embeddings_hdf5_file, "r") 345 | for k, v in f["frame_embeddings"].attrs.items(): 346 | self.ex2emb_idx[k] = v 347 | self._opened_hdf5 = None 348 | 349 | def __len__(self): 350 | return len(self.examples) 351 | 352 | @property 353 | def hdf5_file(self): 354 | if self._opened_hdf5 is None: 355 | self._opened_hdf5 = h5py.File(self.embeddings_hdf5_file, "r") 356 | return self._opened_hdf5 357 | 358 | def __getitem__(self, item): 359 | c_ex = self.examples[item] 360 | 361 | mixture, _, _, padded_indx = read_audio( 362 | c_ex, self.multisrc, self.random_channel, self.pad_to 363 | ) 364 | 365 | max_len_targets = self.encoder.n_frames 366 | strong = torch.zeros(max_len_targets, len(self.encoder.labels)).float() 367 | out_args = [mixture, strong.transpose(0, 1), padded_indx] 368 | if self.feats_pipeline is not None: 369 | feats = self.feats_pipeline(mixture) 370 | out_args.append(feats) 371 | 372 | if 
self.return_filename: 373 | out_args.append(c_ex) 374 | 375 | if self.embeddings_hdf5_file is not None: 376 | name = Path(c_ex).stem 377 | index = self.ex2emb_idx[name] 378 | 379 | global_embeddings = torch.from_numpy(self.hdf5_file["global_embeddings"][index]).float() 380 | frame_embeddings = torch.from_numpy(np.stack(self.hdf5_file["frame_embeddings"][index])).float() 381 | if self.embedding_type == "global": 382 | embeddings = global_embeddings 383 | elif self.embedding_type == "frame": 384 | embeddings = frame_embeddings 385 | else: 386 | raise NotImplementedError 387 | 388 | out_args.append(embeddings) 389 | 390 | return out_args 391 | --------------------------------------------------------------------------------
/src/train_pretrained.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from copy import deepcopy 3 | import numpy as np 4 | import os 5 | import pandas as pd 6 | import random 7 | import torch 8 | import yaml 9 | import torchaudio 10 | 11 | import pytorch_lightning as pl 12 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 13 | from pytorch_lightning.loggers import TensorBoardLogger 14 | 15 | from desed_task.dataio import ConcatDatasetBatchSampler 16 | from desed_task.dataio.datasets import StronglyAnnotatedSet, UnlabeledSet, WeakSet 17 | from desed_task.nnet.CRNN import CRNN 18 | from desed_task.utils.encoder import ManyHotEncoder 19 | from desed_task.utils.schedulers import ExponentialWarmup 20 | 21 | from local.classes_dict import classes_labels 22 | from local.sed_trainer_pretrained import SEDTask4 23 | from local.resample_folder import resample_folder 24 | from local.utils import generate_tsv_wav_durations 25 | from desed_task.utils.download import download_from_url 26 | 27 | 28 | def
resample_data_generate_durations(config_data, test_only=False, evaluation=False): 29 | if not test_only: 30 | dsets = [ 31 | "synth_folder", 32 | "synth_val_folder", 33 | "weak_folder", 34 | "unlabeled_folder", 35 | "test_folder", 36 | ] 37 | elif not evaluation: 38 | dsets = ["test_folder"] 39 | else: 40 | dsets = ["eval_folder"] 41 | 42 | for dset in dsets: 43 | computed = resample_folder( 44 | config_data[dset + "_44k"], config_data[dset], target_fs=config_data["fs"] 45 | ) 46 | 47 | if not evaluation: 48 | for base_set in ["synth_val", "test"]: 49 | if not os.path.exists(config_data[base_set + "_dur"]) or computed: 50 | generate_tsv_wav_durations( 51 | config_data[base_set + "_folder"], config_data[base_set + "_dur"] 52 | ) 53 | 54 | 55 | def single_run( 56 | config, 57 | log_dir, 58 | gpus, 59 | checkpoint_resume=None, 60 | test_state_dict=None, 61 | fast_dev_run=False, 62 | evaluation=False 63 | ): 64 | """ 65 | Running the sound event detection baseline. 66 | 67 | Args: 68 | config (dict): the dictionary of configuration params 69 | log_dir (str): path to log directory 70 | gpus (str): number of gpus to use 71 | checkpoint_resume (str, optional): path to checkpoint to resume from. Defaults to None. 72 | test_state_dict (dict, optional): if not None, no training is involved. This dictionary is the state_dict 73 | to be loaded to test the model. 74 | fast_dev_run (bool, optional): whether to use a run with only one batch at train and validation, useful 75 | for development purposes. 76 | """ 77 | config.update({"log_dir": log_dir}) 78 | 79 | ##### data prep test ########## 80 | encoder = ManyHotEncoder( 81 | list(classes_labels.keys()), 82 | audio_len=config["data"]["audio_max_len"], 83 | frame_len=config["feats"]["n_filters"], 84 | frame_hop=config["feats"]["hop_length"], 85 | net_pooling=config["data"]["net_subsample"], 86 | fs=config["data"]["fs"], 87 | ) 88 | 89 | if not config["pretrained"]["freezed"]: 90 | assert config["pretrained"]["e2e"], "If freezed is false, you have to train end2end ! " \ 91 | "You cannot use precomputed embeddings if you want to update the pretrained model."
92 | #FIXME 93 | if not config["pretrained"]["e2e"]: 94 | assert config["pretrained"]["extracted_embeddings_dir"] is not None, \ 95 | "If e2e is false, you have to download pretrained embeddings from {}" \ 96 | "and set in the config yaml file the path to the downloaded directory".format("REPLACE ME") 97 | 98 | if config["pretrained"]["model"] == "ast" and config["pretrained"]["e2e"]: 99 | # feature extraction pipeline for SSAST 100 | class ASTFeatsExtraction: 101 | # need feature extraction in dataloader because kaldi compliant torchaudio fbank are used (no gpu support) 102 | def __init__(self, audioset_mean=-4.2677393, audioset_std=4.5689974, 103 | target_length=1024): 104 | super(ASTFeatsExtraction, self).__init__() 105 | self.audioset_mean = audioset_mean 106 | self.audioset_std = audioset_std 107 | self.target_length = target_length 108 | def __call__(self, waveform): 109 | waveform = waveform - torch.mean(waveform, -1) 110 | 111 | fbank = torchaudio.compliance.kaldi.fbank(waveform.unsqueeze(0), htk_compat=True, sample_frequency=16000, use_energy=False, 112 | window_type='hanning', num_mel_bins=128, 113 | dither=0.0, frame_shift=10) 114 | fbank = torch.nn.functional.pad(fbank, (0, 0, 0, self.target_length-fbank.shape[0]), mode="constant") 115 | 116 | fbank = (fbank - self.audioset_mean) / (self.audioset_std * 2) 117 | return fbank 118 | 119 | assert config["data"]["fs"] == 16000, "this pretrained model is trained on 16k" 120 | feature_extraction = ASTFeatsExtraction() 121 | from local.ast.ast_models import ASTModel 122 | pretrained = ASTModel(label_dim=527, 123 | fstride=10, tstride=10, 124 | input_fdim=128, input_tdim=1024, 125 | imagenet_pretrain=True, audioset_pretrain=True, 126 | model_size='base384') 127 | 128 | elif config["pretrained"]["model"] == "panns" and config["pretrained"]["e2e"]: 129 | assert config["data"]["fs"] == 16000, "this pretrained model is trained on 16k" 130 | feature_extraction = None # integrated in the model 131 | download_from_url(config["pretrained"]["url"], config["pretrained"]["dest"]) 132 | # use PANNs as additional feature 133 | from local.panns.models import Cnn14_16k 134 | pretrained = Cnn14_16k() 135 | pretrained.load_state_dict(torch.load(config["pretrained"]["dest"])["model"], strict=False) 136 | else: 137 | pretrained = None 138 | feature_extraction = None 139 | 140 | crnn = CRNN(**config["net"]) 141 | 142 | if not evaluation: 143 | devtest_df = pd.read_csv(config["data"]["test_tsv"], sep="\t") 144 | devtest_embeddings = None if config["pretrained"]["e2e"] else os.path.join(config["pretrained"]["extracted_embeddings_dir"], 145 | config["pretrained"]["model"], "devtest.hdf5") 146 | devtest_dataset = StronglyAnnotatedSet( 147 | config["data"]["test_folder"], 148 | devtest_df, 149 | encoder, 150 | return_filename=True, 151 | pad_to=config["data"]["audio_max_len"], feats_pipeline=feature_extraction, 152 | embeddings_hdf5_file=devtest_embeddings, 153 | embedding_type=config["net"]["embedding_type"] 154 | ) 155 | else: 156 | devtest_dataset = UnlabeledSet( 157 | config["data"]["eval_folder"], 158 | encoder, 159 | pad_to=None, 160 | return_filename=True, feats_pipeline=feature_extraction 161 | ) 162 | 163 | test_dataset = devtest_dataset 164 | 165 | ##### model definition ############ 166 | if test_state_dict is None: 167 | ##### data prep train valid ########## 168 | synth_df = pd.read_csv(config["data"]["synth_tsv"], sep="\t") 169 | synth_set_embeddings = None if config["pretrained"]["e2e"] else 
os.path.join(config["pretrained"]["extracted_embeddings_dir"], 170 | config["pretrained"]["model"], 171 | "synth_train.hdf5") 172 | synth_set = StronglyAnnotatedSet( 173 | config["data"]["synth_folder"], 174 | synth_df, 175 | encoder, 176 | pad_to=config["data"]["audio_max_len"], 177 | feats_pipeline=feature_extraction, 178 | embeddings_hdf5_file=synth_set_embeddings, 179 | embedding_type=config["net"]["embedding_type"] 180 | ) 181 | synth_set[0] 182 | 183 | weak_df = pd.read_csv(config["data"]["weak_tsv"], sep="\t") 184 | train_weak_df = weak_df.sample( 185 | frac=config["training"]["weak_split"], 186 | random_state=config["training"]["seed"], 187 | ) 188 | valid_weak_df = weak_df.drop(train_weak_df.index).reset_index(drop=True) 189 | train_weak_df = train_weak_df.reset_index(drop=True) 190 | weak_set_embeddings = None if config["pretrained"]["e2e"] else os.path.join(config["pretrained"]["extracted_embeddings_dir"], 191 | config["pretrained"]["model"], 192 | "weak_train.hdf5") 193 | weak_set = WeakSet( 194 | config["data"]["weak_folder"], 195 | train_weak_df, 196 | encoder, 197 | pad_to=config["data"]["audio_max_len"], feats_pipeline=feature_extraction, 198 | embeddings_hdf5_file=weak_set_embeddings, 199 | embedding_type=config["net"]["embedding_type"] 200 | 201 | ) 202 | unlabeled_set_embeddings = None if config["pretrained"]["e2e"] else os.path.join(config["pretrained"]["extracted_embeddings_dir"], 203 | config["pretrained"]["model"], 204 | "unlabeled_train.hdf5") 205 | unlabeled_set = UnlabeledSet( 206 | config["data"]["unlabeled_folder"], 207 | encoder, 208 | pad_to=config["data"]["audio_max_len"], feats_pipeline=feature_extraction, 209 | embeddings_hdf5_file=unlabeled_set_embeddings, 210 | embedding_type=config["net"]["embedding_type"] 211 | ) 212 | 213 | synth_df_val = pd.read_csv(config["data"]["synth_val_tsv"], sep="\t") 214 | synth_val_embeddings = None if config["pretrained"]["e2e"] else os.path.join(config["pretrained"]["extracted_embeddings_dir"], 215 | config["pretrained"]["model"], 216 | "synth_val.hdf5") 217 | synth_val = StronglyAnnotatedSet( 218 | config["data"]["synth_val_folder"], 219 | synth_df_val, 220 | encoder, 221 | return_filename=True, 222 | pad_to=config["data"]["audio_max_len"], feats_pipeline=feature_extraction, 223 | embeddings_hdf5_file=synth_val_embeddings, 224 | embedding_type=config["net"]["embedding_type"] 225 | ) 226 | 227 | weak_val_embeddings = None if config["pretrained"]["e2e"] else os.path.join(config["pretrained"]["extracted_embeddings_dir"], 228 | config["pretrained"]["model"], 229 | "weak_val.hdf5") 230 | weak_val = WeakSet( 231 | config["data"]["weak_folder"], 232 | valid_weak_df, 233 | encoder, 234 | pad_to=config["data"]["audio_max_len"], 235 | return_filename=True, feats_pipeline=feature_extraction, 236 | embeddings_hdf5_file=weak_val_embeddings, 237 | embedding_type=config["net"]["embedding_type"] 238 | ) 239 | 240 | tot_train_data = [synth_set, weak_set, unlabeled_set] 241 | train_dataset = torch.utils.data.ConcatDataset(tot_train_data) 242 | 243 | batch_sizes = config["training"]["batch_size"] 244 | samplers = [torch.utils.data.RandomSampler(x) for x in tot_train_data] 245 | batch_sampler = ConcatDatasetBatchSampler(samplers, batch_sizes) 246 | 247 | valid_dataset = torch.utils.data.ConcatDataset([synth_val, weak_val]) 248 | 249 | ##### training params and optimizers ############ 250 | epoch_len = min( 251 | [ 252 | len(tot_train_data[indx]) 253 | // ( 254 | config["training"]["batch_size"][indx] 255 | * 
config["training"]["accumulate_batches"] 256 | ) 257 | for indx in range(len(tot_train_data)) 258 | ] 259 | ) 260 | 261 | if config["pretrained"]["freezed"] or not config["pretrained"]["e2e"]: 262 | parameters = list(crnn.parameters()) 263 | else: 264 | parameters = list(crnn.parameters()) + list(pretrained.parameters()) 265 | opt = torch.optim.Adam(parameters, config["opt"]["lr"], betas=(0.9, 0.999)) 266 | 267 | exp_steps = config["training"]["n_epochs_warmup"] * epoch_len 268 | exp_scheduler = { 269 | "scheduler": ExponentialWarmup(opt, config["opt"]["lr"], exp_steps), 270 | "interval": "step", 271 | } 272 | logger = TensorBoardLogger( 273 | os.path.dirname(config["log_dir"]), config["log_dir"].split("/")[-1], 274 | ) 275 | print(f"experiment dir: {logger.log_dir}") 276 | 277 | callbacks = [ 278 | EarlyStopping( 279 | monitor="val/obj_metric", 280 | patience=config["training"]["early_stop_patience"], 281 | verbose=True, 282 | mode="max", 283 | ), 284 | ModelCheckpoint( 285 | logger.log_dir, 286 | monitor="val/obj_metric", 287 | save_top_k=1, 288 | mode="max", 289 | save_last=True, 290 | ), 291 | ] 292 | else: 293 | train_dataset = None 294 | valid_dataset = None 295 | batch_sampler = None 296 | opt = None 297 | exp_scheduler = None 298 | logger = True 299 | callbacks = None 300 | 301 | desed_training = SEDTask4( 302 | config, 303 | encoder=encoder, 304 | sed_student=crnn, 305 | pretrained_model=pretrained, 306 | opt=opt, 307 | train_data=train_dataset, 308 | valid_data=valid_dataset, 309 | test_data=test_dataset, 310 | train_sampler=batch_sampler, 311 | scheduler=exp_scheduler, 312 | fast_dev_run=fast_dev_run, 313 | evaluation=evaluation 314 | ) 315 | # Not using the fast_dev_run of Trainer because creates a DummyLogger so cannot check problems with the Logger 316 | if fast_dev_run: 317 | flush_logs_every_n_steps = 1 318 | log_every_n_steps = 1 319 | limit_train_batches = 2 320 | limit_val_batches = 2 321 | limit_test_batches = 2 322 | n_epochs = 3 323 | else: 324 | flush_logs_every_n_steps = 100 325 | log_every_n_steps = 40 326 | limit_train_batches = 1.0 327 | limit_val_batches = 1.0 328 | limit_test_batches = 1.0 329 | n_epochs = config["training"]["n_epochs"] 330 | 331 | trainer = pl.Trainer( 332 | precision=config["training"]["precision"], 333 | max_epochs=n_epochs, 334 | callbacks=callbacks, 335 | gpus=gpus, 336 | strategy=config["training"].get("backend"), 337 | accumulate_grad_batches=config["training"]["accumulate_batches"], 338 | logger=logger, 339 | resume_from_checkpoint=checkpoint_resume, 340 | gradient_clip_val=config["training"]["gradient_clip"], 341 | check_val_every_n_epoch=config["training"]["validation_interval"], 342 | num_sanity_val_steps=0, 343 | log_every_n_steps=log_every_n_steps, 344 | flush_logs_every_n_steps=flush_logs_every_n_steps, 345 | limit_train_batches=limit_train_batches, 346 | limit_val_batches=limit_val_batches, 347 | limit_test_batches=limit_test_batches, 348 | ) 349 | 350 | if test_state_dict is None: 351 | # start tracking energy consumption 352 | trainer.fit(desed_training) 353 | best_path = trainer.checkpoint_callback.best_model_path 354 | print(f"best model: {best_path}") 355 | test_state_dict = torch.load(best_path)["state_dict"] 356 | 357 | desed_training.load_state_dict(test_state_dict) 358 | trainer.test(desed_training) 359 | 360 | 361 | if __name__ == "__main__": 362 | parser = argparse.ArgumentParser("Training a SED system for DESED Task") 363 | parser.add_argument( 364 | "--conf_file", 365 | default="./confs/pretrained.yaml", 366 | 
help="The configuration file with all the experiment parameters.", 367 | ) 368 | parser.add_argument( 369 | "--log_dir", 370 | default="./exp/2022_baseline_pretask", 371 | help="Directory where to save tensorboard logs, saved models, etc.", 372 | ) 373 | parser.add_argument( 374 | "--resume_from_checkpoint", 375 | default=None, 376 | help="Allow the training to be resumed, take as input a previously saved model (.ckpt).", 377 | ) 378 | parser.add_argument( 379 | "--test_from_checkpoint", default=None, help="Test the model specified" 380 | ) 381 | parser.add_argument( 382 | "--gpus", 383 | default="1", 384 | help="The number of GPUs to train on, or the gpu to use, default='0', " 385 | "so uses one GPU", 386 | ) 387 | parser.add_argument( 388 | "--fast_dev_run", 389 | action="store_true", 390 | default=False, 391 | help="Use this option to make a 'fake' run which is useful for development and debugging. " 392 | "It uses very few batches and epochs so it won't give any meaningful result.", 393 | ) 394 | 395 | parser.add_argument( 396 | "--eval_from_checkpoint", 397 | default=None, 398 | help="Evaluate the model specified" 399 | ) 400 | 401 | args = parser.parse_args() 402 | 403 | with open(args.conf_file, "r") as f: 404 | configs = yaml.safe_load(f) 405 | 406 | evaluation = False 407 | test_from_checkpoint = args.test_from_checkpoint 408 | 409 | if args.eval_from_checkpoint is not None: 410 | test_from_checkpoint = args.eval_from_checkpoint 411 | evaluation = True 412 | 413 | test_model_state_dict = None 414 | if test_from_checkpoint is not None: 415 | checkpoint = torch.load(test_from_checkpoint) 416 | configs_ckpt = checkpoint["hyper_parameters"] 417 | configs_ckpt["data"] = configs["data"] 418 | print( 419 | f"loaded model: {test_from_checkpoint} \n" 420 | f"at epoch: {checkpoint['epoch']}" 421 | ) 422 | test_model_state_dict = checkpoint["state_dict"] 423 | 424 | if evaluation: 425 | configs["training"]["batch_size_val"] = 1 426 | 427 | seed = configs["training"]["seed"] 428 | if seed: 429 | torch.random.manual_seed(seed) 430 | np.random.seed(seed) 431 | random.seed(seed) 432 | pl.seed_everything(seed) 433 | 434 | test_only = test_from_checkpoint is not None 435 | resample_data_generate_durations(configs["data"], test_only, evaluation) 436 | single_run( 437 | configs, 438 | args.log_dir, 439 | args.gpus, 440 | args.resume_from_checkpoint, 441 | test_model_state_dict, 442 | args.fast_dev_run, 443 | evaluation 444 | ) 445 | --------------------------------------------------------------------------------