├── LICENSE ├── README.md ├── augment.py ├── configs └── example.yaml ├── data ├── extract_feature.py ├── extract_feature_for_train.py └── prepare_labels.py ├── dataset.py ├── example └── example.wav ├── figs ├── data_driven_framework.png ├── sample_background.png ├── sample_music.png ├── sample_speech.png ├── samples_1.png ├── samples_2.png ├── samples_3.png └── samples_4.png ├── forward.py ├── labelencoders └── vad.pth ├── losses.py ├── metrics.py ├── models.py ├── pprint_results.py ├── pretrained_models ├── audio2_vox2 │ └── model.pth ├── audioset2 │ └── model.pth ├── c1 │ └── model.pth ├── labelencoders │ ├── students.pth │ └── teacher.pth ├── sre │ └── model.pth ├── teacher1 │ └── model.pth ├── teacher2 │ └── model.pth └── vox2 │ └── model.pth ├── requirements.txt ├── run.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Heinrich Dinkel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data driven GPVAD 2 | Repository for the work in TASLP 2021 [Voice activity detection in the wild: A data-driven approach using teacher-student training](https://arxiv.org/abs/2105.04065). 3 | 4 | 5 | ![Framework](figs/data_driven_framework.png) 6 | 7 | 8 | ## Sample predictions against other methods 9 | 10 | ![Samples_1](figs/samples_1.png) 11 | 12 | ![Samples_2](figs/samples_2.png) 13 | 14 | ![Samples_3](figs/samples_3.png) 15 | 16 | ![Samples_4](figs/samples_4.png) 17 | 18 | ## Noise robustness 19 | 20 | ![Speech](figs/sample_speech.png) 21 | 22 | ![Background](figs/sample_background.png) 23 | 24 | ![Speech](figs/sample_speech.png) 25 | 26 | ## Results 27 | 28 | Our best model trained on the SRE (V3) dataset obtains the following results: 29 | 30 | | | Precision | Recall | F1 | AUC | FER | Event-F1 | 31 | |:-------------|------------:|---------:|-------:|------:|------:|-----------:| 32 | | aurora_clean | 96.844 | 95.102 | 95.93 | 98.66 | 3.06 | 74.8 | 33 | | aurora_noisy | 90.435 | 92.871 | 91.544 | 97.63 | 6.68 | 54.45 | 34 | | dcase18 | 89.202 | 88.362 | 88.717 | 95.2 | 10.82 | 57.85 | 35 | 36 | ## Usage 37 | 38 | We provide most of our pretrained models in this repository, including: 39 | 40 | 1. Both teachers (T_1, T_2) 41 | 2. Unbalanced audioset pretrained model 42 | 3. 
Voxceleb 2 pretrained model
4. Our best submission (SRE V3 trained)

To download and run evaluation just do:

```bash
git clone https://github.com/RicherMans/Datadriven-VAD
cd Datadriven-VAD
pip3 install -r requirements.txt
python3 forward.py -w example/example.wav
```

Running this will print:

```
|   index | event_label   |   onset |   offset | filename            |
|--------:|:--------------|--------:|---------:|:--------------------|
|       0 | Speech        |    0.28 |     0.94 | example/example.wav |
|       1 | Speech        |    1.04 |     2.22 | example/example.wav |
```

### Predicting voice activity

We support single-file processing as well as filelist batching in our script.
Obtaining VAD predictions is easy:

```bash
python3 forward.py -w example/example.wav
```

Or, if you prefer to run batch-wise, first prepare a filelist:
`find . -type f -name "*.wav" > wavlist.txt`
And then run:
```bash
python3 forward.py -l wavlist.txt
```


#### Extra parameters

* `-model` selects the pretrained model. Can be one of `t1,t2,v2,a2,a2_v2,sre`. Refer to the paper for each respective model. By default we use `sre`.
* `-soft` outputs the raw frame-level probabilities instead of human-readable timestamps.
* `-hard` outputs post-processed 0-1 flags indicating speech instead of human-readable timestamps. Please note this is different from the paper, which thresholded the soft probabilities without post-processing.
* `-th` adjusts the threshold. If a single threshold is passed (e.g., `-th 0.5`), simple binarization is used. Otherwise the default double threshold `-th 0.5 0.1` is applied.
* `-o` writes the results into the given output folder.


## Training from scratch

If you intend to rerun our work, prepare some data and extract log-Mel spectrogram features.
Say you have downloaded the [balanced](http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv) subset of AudioSet and stored all files in a folder `data/balanced/`. Then:

```bash
cd data;
mkdir hdf5 csv_labels;
find balanced -type f > wavs.txt;
python3 extract_feature.py wavs.txt -o hdf5/balanced.h5
h5ls -r hdf5/balanced.h5 | awk -F[/' '] 'BEGIN{print "filename","hdf5path"}NR>1{print $2,"hdf5/balanced.h5"}'> csv_labels/balanced.csv
```


The input for our label prediction script is a csv file with exactly two columns, `filename` and `hdf5path`.

An example `csv_labels/balanced.csv` would be:

```
filename hdf5path
--PJHxphWEs_30.000.wav hdf5/balanced.h5
--ZhevVpy1s_50.000.wav hdf5/balanced.h5
--aE2O5G5WE_0.000.wav hdf5/balanced.h5
--aO5cdqSAg_30.000.wav hdf5/balanced.h5
```

After feature extraction, proceed to predict labels:

```bash
mkdir -p softlabels/{hdf5,csv};
python3 prepare_labels.py --pre ../pretrained_models/teacher1/model.pth csv_labels/balanced.csv softlabels/hdf5/balanced.h5 softlabels/csv/balanced.csv
```

Lastly, just train:

```bash
cd ../; # Go to project root
# Change config accordingly with input data
python3 run.py train configs/example.yaml
```
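Before launching training, it can help to verify that the extracted features and the generated soft labels line up. The snippet below is only a minimal sketch (it is not part of the repository) and assumes the example paths used above; `prepare_labels.py` writes a tab-separated label csv with `filename`/`hdf5path` columns and stores, per file, frame-level posteriors under `<filename>/speech` and `<filename>/noise` plus clip-level scores under `<filename>/clipspeech` and `<filename>/clipnoise`.

```python
# Minimal sanity-check sketch (not part of the repository): confirm that every
# extracted feature file has matching soft labels before starting run.py.
import pandas as pd
from h5py import File

feature_csv = "data/csv_labels/balanced.csv"    # "filename hdf5path", space separated
label_csv = "data/softlabels/csv/balanced.csv"  # written by prepare_labels.py (tab separated)

features = pd.read_csv(feature_csv, sep=r"\s+")
labels = pd.read_csv(label_csv, sep="\t")

missing = set(features["filename"]) - set(labels["filename"])
print(f"{len(missing)} feature files have no soft labels")

# Peek at one example: speech/noise are frame-level posteriors from the teacher,
# clipspeech/clipnoise are the corresponding clip-level scores.
row = labels.iloc[0]
with File(row["hdf5path"], "r") as store:
    speech = store[f"{row['filename']}/speech"][()]
    noise = store[f"{row['filename']}/noise"][()]
print("speech target shape:", speech.shape, "noise target shape:", noise.shape)
```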
## Citation

If you're using this work, please cite it in your publications:

```
@article{Dinkel2021,
  author = {Dinkel, Heinrich and Wang, Shuai and Xu, Xuenan and Wu, Mengyue and Yu, Kai},
  doi = {10.1109/TASLP.2021.3073596},
  issn = {2329-9290},
  journal = {IEEE/ACM Transactions on Audio, Speech, and Language Processing},
  pages = {1542--1555},
  title = {{Voice Activity Detection in the Wild: A Data-Driven Approach Using Teacher-Student Training}},
  url = {https://ieeexplore.ieee.org/document/9405474/},
  volume = {29},
  year = {2021}
}
```
and
```
@inproceedings{Dinkel2020,
  author={Heinrich Dinkel and Yefei Chen and Mengyue Wu and Kai Yu},
  title={{Voice Activity Detection in the Wild via Weakly Supervised Sound Event Detection}},
  year=2020,
  booktitle={Proc. Interspeech 2020},
  pages={3665--3669},
  doi={10.21437/Interspeech.2020-0995},
  url={http://dx.doi.org/10.21437/Interspeech.2020-0995}
}
```

--------------------------------------------------------------------------------
/augment.py:
--------------------------------------------------------------------------------
import torch
import logging
import torch.nn as nn
import numpy as np


class RandomPad(nn.Module):
    """Randomly pad the time axis with up to `padding` frames during training."""
    def __init__(self, value=0., padding=0):
        super().__init__()
        self.value = value
        self.padding = padding

    def forward(self, x):
        if self.training and self.padding > 0:
            left_right = torch.empty(2).random_(self.padding).int().numpy()
            topad = (0, 0, *left_right)
            x = nn.functional.pad(x, topad, value=self.value)
        return x


class Roll(nn.Module):
    """Randomly roll the time axis by a normally distributed shift during training."""
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, x):
        if self.training:
            shift = torch.empty(1).normal_(self.mean, self.std).int().item()
            x = torch.roll(x, shift, dims=0)
        return x


class RandomCrop(nn.Module):
    """Randomly crop a window of `size` frames from the time axis during training."""
    def __init__(self, size: int = 100):
        super().__init__()
        self.size = int(size)

    def forward(self, x):
        if self.training:
            time, freq = x.shape
            if time < self.size:
                return x
            hi = time - self.size
            start_ind = torch.empty(1, dtype=torch.long).random_(0, hi).item()
            x = x[start_ind:start_ind + self.size, :]
        return x


class TimeMask(nn.Module):
    """Mask `n` random time stripes of at most `p` frames (SpecAugment-style)."""
    def __init__(self, n=1, p=50):
        super().__init__()
        self.p = p
        self.n = n

    def forward(self, x):
        time, freq = x.shape
        if self.training:
            for i in range(self.n):
                t = torch.empty(1, dtype=int).random_(self.p).item()
                to_sample = max(time - t, 1)
                t0 = torch.empty(1, dtype=int).random_(to_sample).item()
                x[t0:t0 + t, :] = 0
        return x


class FreqMask(nn.Module):
    """Mask `n` random frequency stripes of at most `p` bins (SpecAugment-style)."""
    def __init__(self, n=1, p=12):
        super().__init__()
        self.p = p
        self.n = n

    def forward(self, x):
        time, freq = x.shape
        if self.training:
            for i in range(self.n):
                f = torch.empty(1, dtype=int).random_(self.p).item()
                f0 = torch.empty(1, dtype=int).random_(freq - f).item()
                x[:, f0:f0 + f] = 0.
83 | return x 84 | 85 | 86 | class GaussianNoise(nn.Module): 87 | """docstring for Gaussian""" 88 | def __init__(self, snr=30, mean=0): 89 | super().__init__() 90 | self._mean = mean 91 | self._snr = snr 92 | 93 | def forward(self, x): 94 | if self.training: 95 | E_x = (x**2).sum()/x.shape[0] 96 | noise = torch.empty_like(x).normal_(self._mean, std=1) 97 | E_noise = (noise**2).sum()/noise.shape[0] 98 | alpha = np.sqrt(E_x / (E_noise * pow(10, self._snr / 10))) 99 | x = x + alpha * noise 100 | return x 101 | 102 | 103 | class Shift(nn.Module): 104 | """ 105 | Randomly shift audio in time by up to `shift` samples. 106 | """ 107 | def __init__(self, shift=4000): 108 | super().__init__() 109 | self.shift = shift 110 | 111 | def forward(self, wav): 112 | time, channels = wav.size() 113 | length = time - self.shift 114 | if self.shift > 0: 115 | if not self.training: 116 | wav = wav[..., :length] 117 | else: 118 | offset = torch.randint(self.shift, [channels, 1], 119 | device=wav.device) 120 | indexes = torch.arange(length, device=wav.device) 121 | offset = indexes + offset 122 | wav = wav.gather(0, offset.transpose(0, 1)) 123 | return wav 124 | 125 | 126 | class FlipSign(nn.Module): 127 | """ 128 | Random sign flip. 129 | """ 130 | def forward(self, wav): 131 | time, channels = wav.size() 132 | if self.training: 133 | signs = torch.randint(2, (1, channels), 134 | device=wav.device, 135 | dtype=torch.float32) 136 | wav = wav * (2 * signs - 1) 137 | return wav 138 | 139 | 140 | if __name__ == "__main__": 141 | x = torch.randn(1, 10) 142 | y = GaussianNoise(10)(x) 143 | print(x) 144 | print(y) 145 | -------------------------------------------------------------------------------- /configs/example.yaml: -------------------------------------------------------------------------------- 1 | data: data/csv_labels/balanced.csv 2 | label: data/softlabels/csv/balanced.csv 3 | batch_size: 64 4 | data_args: 5 | mode: Null 6 | num_workers: 8 7 | optimizer: AdamW 8 | optimizer_args: 9 | lr: 0.001 10 | scheduler_args: 11 | patience: 10 12 | factor: 0.1 13 | early_stop: 15 14 | epochs: 15 15 | itercv: 10000 16 | save: best 17 | model: CRNN 18 | model_args: {} 19 | outputpath: experiments/ 20 | transforms: [timemask, freqmask] 21 | loss: FrameBCELoss 22 | -------------------------------------------------------------------------------- /data/extract_feature.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import librosa 4 | from tqdm import tqdm 5 | import io 6 | import logging 7 | from pathlib import Path 8 | import pandas as pd 9 | import numpy as np 10 | import soundfile as sf 11 | from pypeln import process as pr 12 | import gzip 13 | import h5py 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('input_csv') 17 | parser.add_argument('-o', '--output', type=str, required=True) 18 | parser.add_argument('-c', type=int, default=4) 19 | parser.add_argument('-sr', type=int, default=22050) 20 | parser.add_argument('-col', 21 | default='filename', 22 | type=str, 23 | help='Column to search for audio files') 24 | parser.add_argument('-cmn', default=False, action='store_true') 25 | parser.add_argument('-cvn', default=False, action='store_true') 26 | parser.add_argument('-winlen', 27 | default=40, 28 | type=float, 29 | help='FFT duration in ms') 30 | parser.add_argument('-hoplen', 31 | default=20, 32 | type=float, 33 | help='hop duration in ms') 34 | 35 | parser.add_argument('-n_mels', default=64, type=int) 36 | ARGS 
= parser.parse_args() 37 | 38 | DF = pd.read_csv(ARGS.input_csv, sep='\t', 39 | usecols=[0]) # only read first cols, allows to have messy csv 40 | 41 | MEL_ARGS = { 42 | 'n_mels': ARGS.n_mels, 43 | 'n_fft': 2048, 44 | 'hop_length': int(ARGS.sr * ARGS.hoplen / 1000), 45 | 'win_length': int(ARGS.sr * ARGS.winlen / 1000) 46 | } 47 | 48 | EPS = np.spacing(1) 49 | 50 | 51 | def extract_feature(fname): 52 | """extract_feature 53 | Extracts a log mel spectrogram feature from a filename, currently supports two filetypes: 54 | 55 | 1. Wave 56 | 2. Gzipped wave 57 | 58 | :param fname: filepath to the file to extract 59 | """ 60 | ext = Path(fname).suffix 61 | try: 62 | if ext == '.gz': 63 | with gzip.open(fname, 'rb') as gzipped_wav: 64 | y, sr = sf.read(io.BytesIO(gzipped_wav.read()), 65 | dtype='float32') 66 | # Multiple channels, reduce 67 | if y.ndim == 2: 68 | y = y.mean(1) 69 | y = librosa.resample(y, sr, ARGS.sr) 70 | elif ext in ('.wav', '.flac'): 71 | y, sr = sf.read(fname, dtype='float32') 72 | if y.ndim > 1: 73 | y = y.mean(1) 74 | y = librosa.resample(y, sr, ARGS.sr) 75 | except Exception as e: 76 | # Exception usually happens because some data has 6 channels , which librosa cant handle 77 | logging.error(e) 78 | logging.error(fname) 79 | raise 80 | lms_feature = np.log(librosa.feature.melspectrogram(y, **MEL_ARGS) + EPS).T 81 | return fname, lms_feature 82 | 83 | 84 | with h5py.File(ARGS.output, 'w') as store: 85 | for fname, feat in tqdm(pr.map(extract_feature, 86 | DF[ARGS.col].unique(), 87 | workers=ARGS.c, 88 | maxsize=4), 89 | total=len(DF[ARGS.col].unique())): 90 | basename = Path(fname).name 91 | store[basename] = feat 92 | -------------------------------------------------------------------------------- /data/extract_feature_for_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import argparse 3 | import librosa 4 | from tqdm import tqdm 5 | import io 6 | from pathlib import Path 7 | from loguru import logger 8 | import pandas as pd 9 | import numpy as np 10 | import soundfile as sf 11 | from pypeln import process as pr 12 | import h5py 13 | import gzip 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('input_csv') 17 | parser.add_argument('-o', '--output', type=str, required=True) 18 | parser.add_argument('-l', 19 | '--label', 20 | type=str, 21 | required=True, 22 | help="Output label(chunked)") 23 | parser.add_argument('-s', 24 | '--size', 25 | type=float, 26 | default=10, 27 | help="Length of each segment") 28 | parser.add_argument('-t', 29 | '--threshold', 30 | type=float, 31 | default=1, 32 | help='Do not save files less than -t seconds', 33 | metavar='s') 34 | parser.add_argument('-c', type=int, default=4) 35 | parser.add_argument('-sr', type=int, default=22050) 36 | parser.add_argument('-col', 37 | default='filename', 38 | type=str, 39 | help='Column to search for audio files') 40 | parser.add_argument('-cmn', default=False, action='store_true') 41 | parser.add_argument('-cvn', default=False, action='store_true') 42 | parser.add_argument('-winlen', 43 | default=40, 44 | type=float, 45 | help='FFT duration in ms') 46 | parser.add_argument('-hoplen', 47 | default=20, 48 | type=float, 49 | help='hop duration in ms') 50 | 51 | parser.add_argument('-n_mels', default=64, type=int) 52 | ARGS = parser.parse_args() 53 | 54 | DF = pd.read_csv(ARGS.input_csv, usecols=[0], sep=' ') 55 | 56 | MEL_ARGS = { 57 | 'n_mels': ARGS.n_mels, 58 | 'n_fft': 2048, 59 | 'hop_length': int(ARGS.sr * 
ARGS.hoplen / 1000), 60 | 'win_length': int(ARGS.sr * ARGS.winlen / 1000) 61 | } 62 | 63 | EPS = np.spacing(1) 64 | DURATION_CHUNK = ARGS.size / (ARGS.hoplen / 1000) 65 | THRESHOLD = ARGS.threshold / (ARGS.hoplen / 1000) 66 | 67 | 68 | def extract_feature(fname): 69 | # def extract_feature(fname, segfname, start, end, nseg): 70 | """extract_feature 71 | Extracts a log mel spectrogram feature from a filename, currently supports two filetypes: 72 | 73 | 1. Wave 74 | 2. Gzipped wave 75 | 76 | :param fname: filepath to the file to extract 77 | """ 78 | pospath = Path(fname) 79 | ext = pospath.suffix 80 | try: 81 | if ext == '.gz': 82 | with gzip.open(fname, 'rb') as gzipped_wav: 83 | y, sr = sf.read(io.BytesIO(gzipped_wav.read()), 84 | dtype='float32') 85 | # Multiple channels, reduce 86 | if y.ndim == 2: 87 | y = y.mean(1) 88 | y = librosa.resample(y, sr, ARGS.sr) 89 | elif ext in ('.wav', '.flac'): 90 | y, sr = sf.read(fname, dtype='float32') 91 | if y.ndim > 1: 92 | y = y.mean(1) 93 | y = librosa.resample(y, sr, ARGS.sr) 94 | except Exception as e: 95 | # Exception usually happens because some data has 6 channels , which librosa cant handle 96 | logger.error(e) 97 | logger.error(fname) 98 | raise 99 | fname = pospath.name 100 | feat = np.log(librosa.feature.melspectrogram(y, **MEL_ARGS) + EPS).T 101 | start_range = np.arange(0, feat.shape[0], DURATION_CHUNK, dtype=int) 102 | end_range = (start_range + DURATION_CHUNK).astype(int) 103 | end_range[-1] = feat.shape[0] 104 | for nseg, (start_time, end_time) in enumerate(zip(start_range, end_range)): 105 | seg = feat[start_time:end_time] 106 | if end_time - start_time < THRESHOLD: 107 | # Dont save 108 | continue 109 | yield fname, seg, nseg 110 | 111 | 112 | with h5py.File(ARGS.output, 'w') as store, tqdm() as pbar, open(ARGS.label,'w') as output_csv: 113 | output_csv.write(f"filename hdf5path\n") #write header 114 | hdf5_path = Path(ARGS.output).absolute() 115 | for fname, feat, nseg in pr.flat_map(extract_feature, 116 | DF['filename'].unique(), 117 | workers=ARGS.c, 118 | maxsize=ARGS.c * 2): 119 | new_fname = f"{Path(fname).stem}_{nseg:05d}{Path(fname).suffix}" 120 | store[new_fname] = feat 121 | output_csv.write(f"{new_fname} {hdf5_path}\n") 122 | pbar.set_postfix(stored=new_fname, shape=feat.shape) 123 | pbar.update() 124 | -------------------------------------------------------------------------------- /data/prepare_labels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | import numpy as np 4 | import argparse 5 | from h5py import File 6 | from pathlib import Path 7 | from loguru import logger 8 | import torch.utils.data as tdata 9 | from tqdm import tqdm 10 | from models import crnn, cnn10 11 | import sys 12 | import csv 13 | 14 | 15 | class HDF5Dataset(tdata.Dataset): 16 | """ 17 | HDF5 dataset indexed by a labels dataframe. 
18 | Indexing is done via the dataframe since we want to preserve some storage 19 | in cases where oversampling is needed ( pretty likely ) 20 | """ 21 | def __init__(self, h5file: File, transform=None): 22 | super(HDF5Dataset, self).__init__() 23 | self._h5file = h5file 24 | self.dataset = None 25 | # IF none is passed still use no transform at all 26 | self._transform = transform 27 | with File(self._h5file, 'r') as store: 28 | self._len = len(store) 29 | self._labels = list(store.keys()) 30 | self.datadim = store[self._labels[0]].shape[-1] 31 | 32 | def __len__(self): 33 | return self._len 34 | 35 | def __getitem__(self, index): 36 | if self.dataset is None: 37 | self.dataset = File(self._h5file, 'r') 38 | fname = self._labels[index] 39 | data = self.dataset[fname][()] 40 | data = torch.as_tensor(data).float() 41 | if self._transform: 42 | data = self._transform(data) 43 | return data, fname 44 | 45 | 46 | MODELS = { 47 | 'crnn': { 48 | 'model': crnn, 49 | 'encoder': torch.load('encoders/balanced.pth'), 50 | 'outputdim': 527, 51 | }, 52 | 'gpvb': { 53 | 'model': crnn, 54 | 'encoder': torch.load('encoders/balanced_binary.pth'), 55 | 'outputdim': 2, 56 | } 57 | } 58 | 59 | POOLING = { 60 | 'max': lambda x: np.max(x, axis=-1), 61 | 'mean': lambda x: np.mean(x, axis=-1) 62 | } 63 | 64 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 65 | DEVICE = torch.device(DEVICE) 66 | 67 | 68 | def main(): 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument('data', type=Path) 71 | parser.add_argument('-m', '--model', default='crnn', type=str) 72 | parser.add_argument('-po', 73 | '--pool', 74 | default='max', 75 | choices=POOLING.keys(), 76 | type=str) 77 | parser.add_argument('--pre', '-p', default='pretrained/gpv_f.pth') 78 | parser.add_argument('hdf5output', type=Path) 79 | parser.add_argument('csvoutput', type=Path) 80 | args = parser.parse_args() 81 | 82 | log_format = "[{time:YYYY-MM-DD HH:mm:ss}] {message}" 83 | logger.configure(handlers=[{"sink": sys.stderr, "format": log_format}]) 84 | 85 | for k, v in vars(args).items(): 86 | logger.info(f"{k} : {v}") 87 | 88 | model_dict = MODELS[args.model] 89 | model = model_dict['model'](outputdim=model_dict['outputdim'], 90 | pretrained_from=args.pre).to(DEVICE).eval() 91 | encoder = model_dict['encoder'] 92 | logger.info(model) 93 | pooling_fun = POOLING[args.pool] 94 | if Path(args.data).suffix == '.csv': 95 | data = pd.read_csv(args.data, sep='\s+') 96 | data = data['hdf5path'].unique() 97 | assert len(data) == 1, "Only single hdf5 supported yet" 98 | data = data[0] 99 | else: #h5 file directly 100 | data = args.data 101 | 102 | logger.info(f"Reading from input file {data}") 103 | dataloader = tdata.DataLoader(HDF5Dataset(data), 104 | num_workers=4, 105 | batch_size=1) 106 | speech_class_idx = np.where(encoder.classes_ == 'Speech')[0] 107 | non_speech_idx = np.arange(len(encoder.classes_)) 108 | non_speech_idx = np.delete(non_speech_idx, speech_class_idx) 109 | with torch.no_grad(), File(args.hdf5output, 'w') as store, tqdm( 110 | total=len(dataloader)) as pbar, open(args.csvoutput, 111 | 'w') as csvfile: 112 | abs_output_hdf5 = Path(args.hdf5output).absolute() 113 | csvwr = csv.writer(csvfile, delimiter='\t') 114 | csvwr.writerow(['filename', 'hdf5path']) 115 | for batch in dataloader: 116 | x, fname = batch 117 | fname = fname[0] 118 | x = x.to(DEVICE) 119 | if x.shape[1] < 8: 120 | continue 121 | clip_pred, time_pred = model(x) 122 | clip_pred = clip_pred.squeeze(0).to('cpu').numpy() 123 | time_pred = 
time_pred.squeeze(0).to('cpu').numpy() 124 | speech_time_pred = time_pred[..., speech_class_idx].squeeze(-1) 125 | speech_clip_pred = clip_pred[..., speech_class_idx].squeeze(-1) 126 | non_speech_clip_pred = clip_pred[..., non_speech_idx] 127 | non_speech_time_pred = time_pred[..., non_speech_idx] 128 | non_speech_time_pred = pooling_fun(non_speech_time_pred) 129 | store[f'{fname}/speech'] = speech_time_pred 130 | store[f'{fname}/noise'] = non_speech_time_pred 131 | store[f'{fname}/clipspeech'] = speech_clip_pred 132 | store[f'{fname}/clipnoise'] = non_speech_clip_pred 133 | csvwr.writerow([fname, abs_output_hdf5]) 134 | pbar.set_postfix(fname=fname, speechsize=speech_time_pred.shape) 135 | pbar.update() 136 | 137 | 138 | if __name__ == "__main__": 139 | main() 140 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import numpy as np 4 | import pandas as pd 5 | import scipy 6 | from h5py import File 7 | import itertools, random 8 | from tqdm import tqdm 9 | from loguru import logger 10 | import torch.utils.data as tdata 11 | from typing import List, Dict 12 | 13 | 14 | class TrainHDF5Dataset(tdata.Dataset): 15 | """ 16 | HDF5 dataset indexed by a labels dataframe. 17 | Indexing is done via the dataframe since we want to preserve some storage 18 | in cases where oversampling is needed ( pretty likely ) 19 | """ 20 | def __init__(self, 21 | h5filedict: Dict, 22 | h5labeldict: Dict, 23 | label_type='soft', 24 | transform=None): 25 | super(TrainHDF5Dataset, self).__init__() 26 | self._h5filedict = h5filedict 27 | self._h5labeldict = h5labeldict 28 | self._datasetcache = {} 29 | self._labelcache = {} 30 | self._len = len(self._h5labeldict) 31 | # IF none is passed still use no transform at all 32 | self._transform = transform 33 | assert label_type in ('soft', 'hard', 'softhard', 'hardnoise') 34 | self._label_type = label_type 35 | 36 | self.idx_to_item = { 37 | idx: item 38 | for idx, item in enumerate(self._h5labeldict.keys()) 39 | } 40 | first_item = next(iter(self._h5filedict.keys())) 41 | with File(self._h5filedict[first_item], 'r') as store: 42 | self.datadim = store[first_item].shape[-1] 43 | 44 | def __len__(self): 45 | return self._len 46 | 47 | def __del__(self): 48 | for k, cache in self._datasetcache.items(): 49 | cache.close() 50 | for k, cache in self._labelcache.items(): 51 | cache.close() 52 | 53 | def __getitem__(self, index: int): 54 | fname: str = self.idx_to_item[index] 55 | h5file: str = self._h5filedict[fname] 56 | labelh5file: str = self._h5labeldict[fname] 57 | if not h5file in self._datasetcache: 58 | self._datasetcache[h5file] = File(h5file, 'r') 59 | if not labelh5file in self._labelcache: 60 | self._labelcache[labelh5file] = File(labelh5file, 'r') 61 | 62 | data = self._datasetcache[h5file][f"{fname}"][()] 63 | speech_target = self._labelcache[labelh5file][f"{fname}/speech"][()] 64 | noise_target = self._labelcache[labelh5file][f"{fname}/noise"][()] 65 | speech_clip_target = self._labelcache[labelh5file][ 66 | f"{fname}/clipspeech"][()] 67 | noise_clip_target = self._labelcache[labelh5file][ 68 | f"{fname}/clipnoise"][()] 69 | 70 | noise_clip_target = np.max(noise_clip_target) # take max around axis 71 | if self._label_type == 'hard': 72 | noise_clip_target = noise_clip_target.round() 73 | speech_target = speech_target.round() 74 | noise_target = noise_target.round() 75 | speech_clip_target = 
speech_clip_target.round() 76 | elif self._label_type == 'hardnoise': # only noise yay 77 | noise_clip_target = noise_clip_target.round() 78 | noise_target = noise_target.round() 79 | elif self._label_type == 'softhard': 80 | r = np.random.permutation(noise_target.shape[0] // 4) 81 | speech_target[r] = speech_target[r].round() 82 | target_clip = torch.tensor((noise_clip_target, speech_clip_target)) 83 | data = torch.as_tensor(data).float() 84 | target_time = torch.as_tensor( 85 | np.stack((noise_target, speech_target), axis=-1)).float() 86 | if self._transform: 87 | data = self._transform(data) 88 | return data, target_time, target_clip, fname 89 | 90 | 91 | class HDF5Dataset(tdata.Dataset): 92 | """ 93 | HDF5 dataset indexed by a labels dataframe. 94 | Indexing is done via the dataframe since we want to preserve some storage 95 | in cases where oversampling is needed ( pretty likely ) 96 | """ 97 | def __init__(self, h5file: File, h5label: File, fnames, transform=None): 98 | super(HDF5Dataset, self).__init__() 99 | self._h5file = h5file 100 | self._h5label = h5label 101 | self.fnames = fnames 102 | self.dataset = None 103 | self.label_dataset = None 104 | self._len = len(fnames) 105 | # IF none is passed still use no transform at all 106 | self._transform = transform 107 | with File(self._h5file, 'r') as store, File(self._h5label, 108 | 'r') as labelstore: 109 | self.datadim = store[self.fnames[0]].shape[-1] 110 | 111 | def __len__(self): 112 | return self._len 113 | 114 | def __getitem__(self, index): 115 | if self.dataset is None: 116 | self.dataset = File(self._h5file, 'r') 117 | self.label_dataset = File(self._h5label, 'r') 118 | fname = self.fnames[index] 119 | data = self.dataset[fname][()] 120 | speech_target = self.label_dataset[f"{fname}/speech"][()] 121 | noise_target = self.label_dataset[f"{fname}/noise"][()] 122 | speech_clip_target = self.label_dataset[f"{fname}/clipspeech"][()] 123 | noise_clip_target = self.label_dataset[f"{fname}/clipnoise"][()] 124 | noise_clip_target = np.max(noise_clip_target) # take max around axis 125 | target_clip = torch.tensor((noise_clip_target, speech_clip_target)) 126 | data = torch.as_tensor(data).float() 127 | target_time = torch.as_tensor( 128 | np.stack((noise_target, speech_target), axis=-1)).float() 129 | if self._transform: 130 | data = self._transform(data) 131 | return data, target_time, target_clip, fname 132 | 133 | 134 | class EvalH5Dataset(tdata.Dataset): 135 | """ 136 | HDF5 dataset indexed by a labels dataframe. 
137 | Indexing is done via the dataframe since we want to preserve some storage 138 | in cases where oversampling is needed ( pretty likely ) 139 | """ 140 | def __init__(self, h5file: File, fnames=None): 141 | super(EvalH5Dataset, self).__init__() 142 | self._h5file = h5file 143 | self._dataset = None 144 | # IF none is passed still use no transform at all 145 | with File(self._h5file, 'r') as store: 146 | if fnames is None: 147 | self.fnames = list(store.keys()) 148 | else: 149 | self.fnames = fnames 150 | self.datadim = store[self.fnames[0]].shape[-1] 151 | self._len = len(store) 152 | 153 | def __len__(self): 154 | return self._len 155 | 156 | def __getitem__(self, index): 157 | if self._dataset is None: 158 | self._dataset = File(self._h5file, 'r') 159 | fname = self.fnames[index] 160 | data = self._dataset[fname][()] 161 | data = torch.as_tensor(data).float() 162 | return data, fname 163 | 164 | 165 | class MinimumOccupancySampler(tdata.Sampler): 166 | """ 167 | docstring for MinimumOccupancySampler 168 | samples at least one instance from each class sequentially 169 | """ 170 | def __init__(self, labels, sampling_mode='same', random_state=None): 171 | self.labels = labels 172 | data_samples, n_labels = labels.shape 173 | label_to_idx_list, label_to_length = [], [] 174 | self.random_state = np.random.RandomState(seed=random_state) 175 | for lb_idx in range(n_labels): 176 | label_selection = labels[:, lb_idx] 177 | if scipy.sparse.issparse(label_selection): 178 | label_selection = label_selection.toarray() 179 | label_indexes = np.where(label_selection == 1)[0] 180 | self.random_state.shuffle(label_indexes) 181 | label_to_length.append(len(label_indexes)) 182 | label_to_idx_list.append(label_indexes) 183 | 184 | self.longest_seq = max(label_to_length) 185 | self.data_source = np.empty((self.longest_seq, len(label_to_length)), 186 | dtype=np.uint32) 187 | # Each column represents one "single instance per class" data piece 188 | for ix, leng in enumerate(label_to_length): 189 | # Fill first only "real" samples 190 | self.data_source[:leng, ix] = label_to_idx_list[ix] 191 | 192 | self.label_to_idx_list = label_to_idx_list 193 | self.label_to_length = label_to_length 194 | 195 | if sampling_mode == 'same': 196 | self.data_length = data_samples 197 | elif sampling_mode == 'over': # Sample all items 198 | self.data_length = np.prod(self.data_source.shape) 199 | 200 | def _reshuffle(self): 201 | # Reshuffle 202 | for ix, leng in enumerate(self.label_to_length): 203 | leftover = self.longest_seq - leng 204 | random_idxs = np.random.randint(leng, size=leftover) 205 | self.data_source[leng:, 206 | ix] = self.label_to_idx_list[ix][random_idxs] 207 | 208 | def __iter__(self): 209 | # Before each epoch, reshuffle random indicies 210 | self._reshuffle() 211 | n_samples = len(self.data_source) 212 | random_indices = self.random_state.permutation(n_samples) 213 | data = np.concatenate( 214 | self.data_source[random_indices])[:self.data_length] 215 | return iter(data) 216 | 217 | def __len__(self): 218 | return self.data_length 219 | 220 | 221 | class MultiBalancedSampler(tdata.sampler.Sampler): 222 | """docstring for BalancedSampler 223 | Samples for Multi-label training 224 | Sampling is not totally equal, but aims to be roughtly equal 225 | """ 226 | def __init__(self, Y, replacement=False, num_samples=None): 227 | assert Y.ndim == 2, "Y needs to be one hot encoded" 228 | if scipy.sparse.issparse(Y): 229 | raise ValueError("Not supporting sparse amtrices yet") 230 | class_counts = np.sum(Y, axis=0) 
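        # Inverse-frequency weighting: classes that occur rarely in Y receive a
        # proportionally larger weight, so the multinomial draw in __iter__
        # samples them more often; normalising makes the class weights sum to one.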
231 | class_weights = 1. / class_counts 232 | class_weights = class_weights / class_weights.sum() 233 | classes = np.arange(Y[0].shape[0]) 234 | # Revert from many_hot to one 235 | class_ids = [tuple(classes.compress(idx)) for idx in Y] 236 | 237 | sample_weights = [] 238 | for i in range(len(Y)): 239 | # Multiple classes were chosen, calculate average probability 240 | weight = class_weights[np.array(class_ids[i])] 241 | # Take the mean of the multiple classes and set as weight 242 | weight = np.mean(weight) 243 | sample_weights.append(weight) 244 | self._weights = torch.as_tensor(sample_weights, dtype=torch.float) 245 | self._len = num_samples if num_samples else len(Y) 246 | self._replacement = replacement 247 | 248 | def __len__(self): 249 | return self._len 250 | 251 | def __iter__(self): 252 | return iter( 253 | torch.multinomial(self._weights, self._len, 254 | self._replacement).tolist()) 255 | 256 | 257 | def gettraindataloader(h5files, 258 | h5labels, 259 | label_type=False, 260 | transform=None, 261 | **dataloader_kwargs): 262 | dset = TrainHDF5Dataset(h5files, 263 | h5labels, 264 | label_type=label_type, 265 | transform=transform) 266 | return tdata.DataLoader(dset, 267 | collate_fn=sequential_collate, 268 | **dataloader_kwargs) 269 | 270 | 271 | def getdataloader(h5file, h5label, fnames, transform=None, 272 | **dataloader_kwargs): 273 | dset = HDF5Dataset(h5file, h5label, fnames, transform=transform) 274 | return tdata.DataLoader(dset, 275 | collate_fn=sequential_collate, 276 | **dataloader_kwargs) 277 | 278 | 279 | def pad(tensorlist, padding_value=0.): 280 | lengths = [len(f) for f in tensorlist] 281 | max_len = np.max(lengths) 282 | # max_len = 2000 283 | batch_dim = len(lengths) 284 | data_dim = tensorlist[0].shape[-1] 285 | out_tensor = torch.full((batch_dim, max_len, data_dim), 286 | fill_value=padding_value, 287 | dtype=torch.float32) 288 | for i, tensor in enumerate(tensorlist): 289 | length = tensor.shape[0] 290 | out_tensor[i, :length, ...] = tensor[:length, ...] 
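    # The original sequence lengths are returned together with the padded batch
    # so that frame-level losses (e.g. FrameBCELoss) can mask out padded frames.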
291 | return out_tensor, torch.tensor(lengths) 292 | 293 | 294 | def sequential_collate(batches): 295 | # sort length wise 296 | data, targets_time, targets_clip, fnames = zip(*batches) 297 | data, lengths_data = pad(data) 298 | targets_time, lengths_tar = pad(targets_time, padding_value=0) 299 | targets_clip = torch.stack(targets_clip) 300 | assert lengths_data.shape == lengths_tar.shape 301 | return data, targets_time, targets_clip, fnames, lengths_tar 302 | 303 | 304 | if __name__ == '__main__': 305 | import utils 306 | label_df = pd.read_csv( 307 | 'data/csv_labels/unbalanced_from_unbalanced/unbalanced.csv', sep='\s+') 308 | data_df = pd.read_csv("data/data_csv/unbalanced.csv", sep='\s+') 309 | 310 | merged = data_df.merge(label_df, on='filename') 311 | common_idxs = merged['filename'] 312 | data_df = data_df[data_df['filename'].isin(common_idxs)] 313 | label_df = label_df[label_df['filename'].isin(common_idxs)] 314 | 315 | label = utils.df_to_dict(label_df) 316 | data = utils.df_to_dict(data_df) 317 | 318 | trainloader = gettraindataloader( 319 | h5files=data, 320 | h5labels=label, 321 | transform=None, 322 | label_type='soft', 323 | batch_size=64, 324 | num_workers=3, 325 | shuffle=False, 326 | ) 327 | 328 | with tqdm(total=len(trainloader)) as pbar: 329 | for batch in trainloader: 330 | inputs, targets_time, targets_clip, filenames, lengths = batch 331 | pbar.set_postfix(inp=inputs.shape) 332 | pbar.update() 333 | -------------------------------------------------------------------------------- /example/example.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/example/example.wav -------------------------------------------------------------------------------- /figs/data_driven_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/data_driven_framework.png -------------------------------------------------------------------------------- /figs/sample_background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/sample_background.png -------------------------------------------------------------------------------- /figs/sample_music.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/sample_music.png -------------------------------------------------------------------------------- /figs/sample_speech.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/sample_speech.png -------------------------------------------------------------------------------- /figs/samples_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/samples_1.png -------------------------------------------------------------------------------- /figs/samples_2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/samples_2.png -------------------------------------------------------------------------------- /figs/samples_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/samples_3.png -------------------------------------------------------------------------------- /figs/samples_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/samples_4.png -------------------------------------------------------------------------------- /forward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | from loguru import logger 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | import utils 7 | import pandas as pd 8 | import numpy as np 9 | import librosa 10 | import soundfile as sf 11 | import uuid 12 | import argparse 13 | from models import crnn 14 | import os 15 | 16 | SAMPLE_RATE = 22050 17 | EPS = np.spacing(1) 18 | LMS_ARGS = { 19 | 'n_fft': 2048, 20 | 'n_mels': 64, 21 | 'hop_length': int(SAMPLE_RATE * 0.02), 22 | 'win_length': int(SAMPLE_RATE * 0.04) 23 | } 24 | DEVICE = 'cpu' 25 | if torch.cuda.is_available(): 26 | DEVICE = 'cuda' 27 | DEVICE = torch.device(DEVICE) 28 | 29 | 30 | def extract_feature(wavefilepath, **kwargs): 31 | _, file_extension = os.path.splitext(wavefilepath) 32 | if file_extension == '.wav': 33 | wav, sr = sf.read(wavefilepath, dtype='float32') 34 | if file_extension == '.mp3': 35 | wav, sr = librosa.load(wavefilepath) 36 | elif file_extension not in ['.mp3', '.wav']: 37 | raise NotImplementedError('Audio extension not supported... 
yet ;)') 38 | if wav.ndim > 1: 39 | wav = wav.mean(-1) 40 | wav = librosa.resample(wav, sr, target_sr=SAMPLE_RATE) 41 | return np.log( 42 | librosa.feature.melspectrogram(wav.astype(np.float32), SAMPLE_RATE, ** 43 | kwargs) + EPS).T 44 | 45 | 46 | class OnlineLogMelDataset(torch.utils.data.Dataset): 47 | def __init__(self, data_list, **kwargs): 48 | super().__init__() 49 | self.dlist = data_list 50 | self.kwargs = kwargs 51 | 52 | def __getitem__(self, idx): 53 | return extract_feature(wavefilepath=self.dlist[idx], 54 | **self.kwargs), self.dlist[idx] 55 | 56 | def __len__(self): 57 | return len(self.dlist) 58 | 59 | 60 | MODELS = { 61 | 't1': { 62 | 'model': crnn, 63 | 'outputdim': 527, 64 | 'encoder': 'labelencoders/teacher.pth', 65 | 'pretrained': 'teacher1/model.pth', 66 | 'resolution': 0.02 67 | }, 68 | 't2': { 69 | 'model': crnn, 70 | 'outputdim': 527, 71 | 'encoder': 'labelencoders/teacher.pth', 72 | 'pretrained': 'teacher2/model.pth', 73 | 'resolution': 0.02 74 | }, 75 | 'sre': { 76 | 'model': crnn, 77 | 'outputdim': 2, 78 | 'encoder': 'labelencoders/students.pth', 79 | 'pretrained': 'sre/model.pth', 80 | 'resolution': 0.02 81 | }, 82 | 'v2': { 83 | 'model': crnn, 84 | 'outputdim': 2, 85 | 'encoder': 'labelencoders/students.pth', 86 | 'pretrained': 'vox2/model.pth', 87 | 'resolution': 0.02 88 | }, 89 | 'a2': { 90 | 'model': crnn, 91 | 'outputdim': 2, 92 | 'encoder': 'labelencoders/students.pth', 93 | 'pretrained': 'audioset2/model.pth', 94 | 'resolution': 0.02 95 | }, 96 | 'a2_v2': { 97 | 'model': crnn, 98 | 'outputdim': 2, 99 | 'encoder': 'labelencoders/students.pth', 100 | 'pretrained': 'audio2_vox2/model.pth', 101 | 'resolution': 0.02 102 | }, 103 | 'c1': { 104 | 'model': crnn, 105 | 'outputdim': 2, 106 | 'encoder': 'labelencoders/students.pth', 107 | 'pretrained': 'c1/model.pth', 108 | 'resolution': 0.02 109 | }, 110 | } 111 | 112 | 113 | def main(): 114 | parser = argparse.ArgumentParser() 115 | group = parser.add_mutually_exclusive_group(required=True) 116 | group.add_argument( 117 | '-w', 118 | '--wav', 119 | help= 120 | 'A single wave/mp3/flac or any other compatible audio file with soundfile.read' 121 | ) 122 | group.add_argument( 123 | '-l', 124 | '--wavlist', 125 | help= 126 | 'A list of wave or any other compatible audio files. E.g., output of find . -type f -name *.wav > wavlist.txt' 127 | ) 128 | parser.add_argument('-model', choices=list(MODELS.keys()), default='sre') 129 | parser.add_argument( 130 | '--pretrained_dir', 131 | default='pretrained_models', 132 | help= 133 | 'Path to downloaded pretrained models directory, (default %(default)s)' 134 | ) 135 | parser.add_argument('-o', 136 | '--output_path', 137 | default=None, 138 | help='Output folder to save predictions if necessary') 139 | parser.add_argument('-soft', 140 | default=False, 141 | action='store_true', 142 | help='Outputs soft probabilities.') 143 | parser.add_argument('-hard', 144 | default=False, 145 | action='store_true', 146 | help='Outputs hard labels as zero-one array.') 147 | parser.add_argument('-th', 148 | '--threshold', 149 | default=(0.5, 0.1), 150 | type=float, 151 | nargs="+") 152 | args = parser.parse_args() 153 | pretrained_dir = Path(args.pretrained_dir) 154 | if not (pretrained_dir.exists() and pretrained_dir.is_dir()): 155 | logger.error(f"""Pretrained directory {args.pretrained_dir} not found. 
156 | Please download the pretrained models from and try again or set --pretrained_dir to your directory.""" 157 | ) 158 | return 159 | logger.info("Passed args") 160 | for k, v in vars(args).items(): 161 | logger.info(f"{k} : {str(v):<10}") 162 | if args.wavlist: 163 | wavlist = pd.read_csv(args.wavlist, 164 | usecols=[0], 165 | header=None, 166 | names=['filename']) 167 | wavlist = wavlist['filename'].values.tolist() 168 | elif args.wav: 169 | wavlist = [args.wav] 170 | dset = OnlineLogMelDataset(wavlist, **LMS_ARGS) 171 | dloader = torch.utils.data.DataLoader(dset, 172 | batch_size=1, 173 | num_workers=3, 174 | shuffle=False) 175 | 176 | model_kwargs_pack = MODELS[args.model] 177 | model_resolution = model_kwargs_pack['resolution'] 178 | # Load model from relative path 179 | model = model_kwargs_pack['model']( 180 | outputdim=model_kwargs_pack['outputdim'], 181 | pretrained_from=pretrained_dir / 182 | model_kwargs_pack['pretrained']).to(DEVICE).eval() 183 | encoder = torch.load(pretrained_dir / model_kwargs_pack['encoder']) 184 | logger.trace(model) 185 | 186 | output_dfs = [] 187 | frame_outputs = {} 188 | threshold = tuple(args.threshold) 189 | 190 | speech_label_idx = np.where('Speech' == encoder.classes_)[0].squeeze() 191 | # Using only binary thresholding without filter 192 | if len(threshold) == 1: 193 | postprocessing_method = utils.binarize 194 | else: 195 | postprocessing_method = utils.double_threshold 196 | with torch.no_grad(), tqdm(total=len(dloader), leave=False, 197 | unit='clip') as pbar: 198 | for feature, filename in dloader: 199 | feature = torch.as_tensor(feature).to(DEVICE) 200 | prediction_tag, prediction_time = model(feature) 201 | prediction_tag = prediction_tag.to('cpu') 202 | prediction_time = prediction_time.to('cpu') 203 | 204 | if prediction_time is not None: # Some models do not predict timestamps 205 | 206 | cur_filename = filename[0] #Remove batchsize 207 | thresholded_prediction = postprocessing_method( 208 | prediction_time, *threshold) 209 | speech_soft_pred = prediction_time[..., speech_label_idx] 210 | if args.soft: 211 | speech_soft_pred = prediction_time[ 212 | ..., speech_label_idx].numpy() 213 | frame_outputs[cur_filename] = speech_soft_pred[ 214 | 0] # 1 batch 215 | 216 | if args.hard: 217 | speech_hard_pred = thresholded_prediction[..., 218 | speech_label_idx] 219 | frame_outputs[cur_filename] = speech_hard_pred[ 220 | 0] # 1 batch 221 | # frame_outputs_hard.append(thresholded_prediction) 222 | 223 | labelled_predictions = utils.decode_with_timestamps( 224 | encoder, thresholded_prediction) 225 | pred_label_df = pd.DataFrame( 226 | labelled_predictions[0], 227 | columns=['event_label', 'onset', 'offset']) 228 | if not pred_label_df.empty: 229 | pred_label_df['filename'] = cur_filename 230 | pred_label_df['onset'] *= model_resolution 231 | pred_label_df['offset'] *= model_resolution 232 | pbar.set_postfix(labels=','.join( 233 | np.unique(pred_label_df['event_label'].values))) 234 | pbar.update() 235 | output_dfs.append(pred_label_df) 236 | 237 | full_prediction_df = pd.concat(output_dfs).sort_values(by='onset',ascending=True).reset_index() 238 | prediction_df = full_prediction_df[full_prediction_df['event_label'] == 239 | 'Speech'] 240 | 241 | if args.output_path: 242 | args.output_path = Path(args.output_path) 243 | args.output_path.mkdir(parents=True, exist_ok=True) 244 | prediction_df.to_csv(args.output_path / 'speech_predictions.tsv', 245 | sep='\t', 246 | index=False) 247 | full_prediction_df.to_csv(args.output_path / 'all_predictions.tsv', 
248 | sep='\t', 249 | index=False) 250 | 251 | if args.soft or args.hard: 252 | prefix = 'soft' if args.soft else 'hard' 253 | with open(args.output_path / f'{prefix}_predictions.txt', 254 | 'w') as wp: 255 | np.set_printoptions(suppress=True, 256 | precision=2, 257 | linewidth=np.inf) 258 | for fname, output in frame_outputs.items(): 259 | print(f"{fname} {output}", file=wp) 260 | logger.info(f"Putting results also to dir {args.output_path}") 261 | if args.soft or args.hard: 262 | np.set_printoptions(suppress=True, precision=2, linewidth=np.inf) 263 | for fname, output in frame_outputs.items(): 264 | print(f"{fname} {output}") 265 | else: 266 | print(prediction_df.to_markdown(showindex=False)) 267 | 268 | 269 | if __name__ == "__main__": 270 | main() 271 | -------------------------------------------------------------------------------- /labelencoders/vad.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/labelencoders/vad.pth -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import ignite.metrics as metrics 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class FrameBCELoss(nn.Module): 7 | """docstring for BCELoss""" 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, clip_prob, frame_prob, tar_time, tar_clip, length): 12 | batchsize, timesteps, ndim = tar_time.shape 13 | idxs = torch.arange(timesteps, device='cpu').repeat(batchsize).view( 14 | batchsize, timesteps) 15 | mask = (idxs < length.view(-1, 1)).to(frame_prob.device) 16 | masked_bce = nn.functional.binary_cross_entropy( 17 | input=frame_prob, target=tar_time, 18 | reduction='none') * mask.unsqueeze(-1) 19 | return masked_bce.sum() / mask.sum() 20 | 21 | 22 | class ClipFrameBCELoss(nn.Module): 23 | """docstring for BCELoss""" 24 | def __init__(self): 25 | super().__init__() 26 | self.frameloss = FrameBCELoss() 27 | self.cliploss = nn.BCELoss() 28 | 29 | def forward(self, clip_prob, frame_prob, tar_time, tar_clip, length): 30 | return self.frameloss( 31 | clip_prob, frame_prob, tar_time, tar_clip, length) + self.cliploss( 32 | clip_prob, tar_clip) 33 | 34 | 35 | class BCELossWithLabelSmoothing(nn.Module): 36 | """docstring for BCELoss""" 37 | def __init__(self, label_smoothing=0.1): 38 | super().__init__() 39 | self.label_smoothing = label_smoothing 40 | 41 | def forward(self, clip_prob, frame_prob, tar): 42 | n_classes = clip_prob.shape[-1] 43 | with torch.no_grad(): 44 | tar = tar * (1 - self.label_smoothing) + ( 45 | 1 - tar) * self.label_smoothing / (n_classes - 1) 46 | return nn.functional.binary_cross_entropy(clip_prob, tar) 47 | 48 | 49 | # Reimplement Loss, because ignite loss only takes 2 args, not 3 and nees to parse kwargs around ... 
just *output does the trick 50 | class Loss(metrics.Loss): 51 | def __init__(self, 52 | loss_fn, 53 | output_transform=lambda x: x, 54 | batch_size=lambda x: len(x), 55 | device=None): 56 | super(Loss, self).__init__(loss_fn=loss_fn, 57 | output_transform=output_transform, 58 | batch_size=batch_size) 59 | 60 | def update(self, output): 61 | average_loss = self._loss_fn(*output) 62 | 63 | if len(average_loss.shape) != 0: 64 | raise ValueError('loss_fn did not return the average loss.') 65 | 66 | N = self._batch_size(output[0]) 67 | self._sum += average_loss.item() * N 68 | self._num_examples += N 69 | 70 | 71 | if __name__ == "__main__": 72 | batch, time, dim = 4, 500, 10 73 | frame = torch.sigmoid(torch.randn(batch, time, dim)) 74 | clip = torch.sigmoid(torch.randn(batch, dim)) 75 | tar = torch.empty(batch, dim).random_(2) 76 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import sed_eval 2 | import utils 3 | import pandas as pd 4 | from sklearn.preprocessing import binarize, MultiLabelBinarizer 5 | import sklearn.metrics as skmetrics 6 | import numpy as np 7 | 8 | 9 | def get_audio_tagging_df(df): 10 | return df.groupby('filename')['event_label'].unique().reset_index() 11 | 12 | 13 | def audio_tagging_results(reference, estimated): 14 | """audio_tagging_results. Returns clip-level F1 Scores 15 | 16 | :param reference: The ground truth dataframe as pd.DataFrame 17 | :param estimated: Predicted labels by the model ( thresholded ) 18 | """ 19 | if "event_label" in reference.columns: 20 | classes = reference.event_label.dropna().unique().tolist( 21 | ) + estimated.event_label.dropna().unique().tolist() 22 | encoder = MultiLabelBinarizer().fit([classes]) 23 | reference = get_audio_tagging_df(reference) 24 | estimated = get_audio_tagging_df(estimated) 25 | ref_labels, _ = utils.encode_labels(reference['event_label'], 26 | encoder=encoder) 27 | reference['event_label'] = ref_labels.tolist() 28 | est_labels, _ = utils.encode_labels(estimated['event_label'], 29 | encoder=encoder) 30 | estimated['event_label'] = est_labels.tolist() 31 | 32 | matching = reference.merge(estimated, 33 | how='outer', 34 | on="filename", 35 | suffixes=["_ref", "_pred"]) 36 | 37 | def na_values(val): 38 | if type(val) is np.ndarray: 39 | return val 40 | elif isinstance(val, list): 41 | return np.array(val) 42 | if pd.isna(val): 43 | return np.zeros(len(encoder.classes_)) 44 | return val 45 | 46 | ret_df = pd.DataFrame(columns=['label', 'f1', 'precision', 'recall']) 47 | if not estimated.empty: 48 | matching['event_label_pred'] = matching.event_label_pred.apply( 49 | na_values) 50 | matching['event_label_ref'] = matching.event_label_ref.apply(na_values) 51 | 52 | y_true = np.vstack(matching['event_label_ref'].values) 53 | y_pred = np.vstack(matching['event_label_pred'].values) 54 | ret_df.loc[:, 'label'] = encoder.classes_ 55 | for avg in [None, 'macro', 'micro']: 56 | avg_f1 = skmetrics.f1_score(y_true, y_pred, average=avg) 57 | avg_pre = skmetrics.precision_score(y_true, y_pred, average=avg) 58 | avg_rec = skmetrics.recall_score(y_true, y_pred, average=avg) 59 | # avg_auc = skmetrics.roc_auc_score(y_true, y_pred, average=avg) 60 | 61 | if avg == None: 62 | # Add for each label non pooled stats 63 | ret_df.loc[:, 'precision'] = avg_pre 64 | ret_df.loc[:, 'recall'] = avg_rec 65 | ret_df.loc[:, 'f1'] = avg_f1 66 | # ret_df.loc[:, 'AUC'] = avg_auc 67 | else: 68 | # Append macro and micro 
results in last 2 rows 69 | ret_df = ret_df.append( 70 | { 71 | 'label': avg, 72 | 'precision': avg_pre, 73 | 'recall': avg_rec, 74 | 'f1': avg_f1, 75 | # 'AUC': avg_auc 76 | }, 77 | ignore_index=True) 78 | return ret_df 79 | 80 | 81 | def get_event_list_current_file(df, fname): 82 | """ 83 | Get list of events for a given filename 84 | :param df: pd.DataFrame, the dataframe to search on 85 | :param fname: the filename to extract the value from the dataframe 86 | :return: list of events (dictionaries) for the given filename 87 | """ 88 | event_file = df[df["filename"] == fname] 89 | if len(event_file) == 1: 90 | if pd.isna(event_file["event_label"].iloc[0]): 91 | event_list_for_current_file = [{"filename": fname}] 92 | else: 93 | event_list_for_current_file = event_file.to_dict('records') 94 | else: 95 | event_list_for_current_file = event_file.to_dict('records') 96 | 97 | return event_list_for_current_file 98 | 99 | 100 | def event_based_evaluation_df(reference, 101 | estimated, 102 | t_collar=0.200, 103 | percentage_of_length=0.2): 104 | """ 105 | Calculate EventBasedMetric given a reference and estimated dataframe 106 | :param reference: pd.DataFrame containing "filename" "onset" "offset" and "event_label" columns which describe the 107 | reference events 108 | :param estimated: pd.DataFrame containing "filename" "onset" "offset" and "event_label" columns which describe the 109 | estimated events to be compared with reference 110 | :return: sed_eval.sound_event.EventBasedMetrics with the scores 111 | """ 112 | 113 | evaluated_files = reference["filename"].unique() 114 | 115 | classes = [] 116 | classes.extend(reference.event_label.dropna().unique()) 117 | classes.extend(estimated.event_label.dropna().unique()) 118 | classes = list(set(classes)) 119 | 120 | event_based_metric = sed_eval.sound_event.EventBasedMetrics( 121 | event_label_list=classes, 122 | t_collar=t_collar, 123 | percentage_of_length=percentage_of_length, 124 | empty_system_output_handling='zero_score') 125 | 126 | for fname in evaluated_files: 127 | reference_event_list_for_current_file = get_event_list_current_file( 128 | reference, fname) 129 | estimated_event_list_for_current_file = get_event_list_current_file( 130 | estimated, fname) 131 | 132 | event_based_metric.evaluate( 133 | reference_event_list=reference_event_list_for_current_file, 134 | estimated_event_list=estimated_event_list_for_current_file, 135 | ) 136 | 137 | return event_based_metric 138 | 139 | 140 | def segment_based_evaluation_df(reference, estimated, time_resolution=1.): 141 | evaluated_files = reference["filename"].unique() 142 | 143 | classes = [] 144 | classes.extend(reference.event_label.dropna().unique()) 145 | classes.extend(estimated.event_label.dropna().unique()) 146 | classes = list(set(classes)) 147 | 148 | segment_based_metric = sed_eval.sound_event.SegmentBasedMetrics( 149 | event_label_list=classes, time_resolution=time_resolution) 150 | 151 | for fname in evaluated_files: 152 | reference_event_list_for_current_file = get_event_list_current_file( 153 | reference, fname) 154 | estimated_event_list_for_current_file = get_event_list_current_file( 155 | estimated, fname) 156 | 157 | segment_based_metric.evaluate( 158 | reference_event_list=reference_event_list_for_current_file, 159 | estimated_event_list=estimated_event_list_for_current_file) 160 | 161 | return segment_based_metric 162 | 163 | 164 | def compute_metrics(valid_df, pred_df, time_resolution=1.): 165 | 166 | metric_event = event_based_evaluation_df(valid_df, 167 | pred_df, 168 
| t_collar=0.200, 169 | percentage_of_length=0.2) 170 | metric_segment = segment_based_evaluation_df( 171 | valid_df, pred_df, time_resolution=time_resolution) 172 | return metric_event, metric_segment 173 | 174 | 175 | def roc(y_true, y_pred, average=None): 176 | return skmetrics.roc_auc_score(y_true, y_pred, average=average) 177 | 178 | 179 | def mAP(y_true, y_pred, average=None): 180 | return skmetrics.average_precision_score(y_true, y_pred, average=average) 181 | 182 | 183 | def precision_recall_fscore_support(y_true, y_pred, average=None): 184 | return skmetrics.precision_recall_fscore_support(y_true, 185 | y_pred, 186 | average=average) 187 | 188 | 189 | def tpr_fpr(y_true, y_pred): 190 | fpr, tpr, thresholds = skmetrics.roc_curve(y_true, y_pred) 191 | return fpr, tpr, thresholds 192 | 193 | 194 | def obtain_error_rates_alt(y_true, y_pred, threshold=0.5): 195 | speech_frame_predictions = binarize(y_pred.reshape(-1, 1), 196 | threshold=threshold) 197 | tn, fp, fn, tp = skmetrics.confusion_matrix( 198 | y_true, speech_frame_predictions).ravel() 199 | 200 | p_miss = 100 * (fn / (fn + tp)) 201 | p_fa = 100 * (fp / (fp + tn)) 202 | return p_fa, p_miss 203 | 204 | 205 | def confusion_matrix(y_true, y_pred): 206 | return skmetrics.confusion_matrix(y_true, y_pred) 207 | 208 | 209 | def obtain_error_rates(y_true, y_pred, threshold=0.5): 210 | negatives = y_pred[np.where(y_true == 0)] 211 | positives = y_pred[np.where(y_true == 1)] 212 | Pfa = np.sum(negatives >= threshold) / negatives.size 213 | Pmiss = np.sum(positives < threshold) / positives.size 214 | return Pfa, Pmiss 215 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from pathlib import Path 4 | import torch.nn as nn 5 | 6 | 7 | def crnn(inputdim=64, outputdim=527, pretrained_from='balanced.pth'): 8 | model = CRNN(inputdim, outputdim) 9 | if pretrained_from: 10 | state = torch.load(pretrained_from, 11 | map_location='cpu') 12 | model.load_state_dict(state, strict=False) 13 | return model 14 | 15 | 16 | def cnn10(inputdim=64, outputdim=527, pretrained_from='balanced.pth'): 17 | model = CNN10(inputdim, outputdim) 18 | if pretrained_from: 19 | state = torch.load(pretrained_from, 20 | map_location='cpu') 21 | model.load_state_dict(state, strict=False) 22 | return model 23 | 24 | 25 | def init_weights(m): 26 | if isinstance(m, (nn.Conv2d, nn.Conv1d)): 27 | nn.init.kaiming_normal_(m.weight) 28 | if m.bias is not None: 29 | nn.init.constant_(m.bias, 0) 30 | elif isinstance(m, nn.BatchNorm2d): 31 | nn.init.constant_(m.weight, 1) 32 | if m.bias is not None: 33 | nn.init.constant_(m.bias, 0) 34 | if isinstance(m, nn.Linear): 35 | nn.init.kaiming_uniform_(m.weight) 36 | if m.bias is not None: 37 | nn.init.constant_(m.bias, 0) 38 | 39 | 40 | class LinearSoftPool(nn.Module): 41 | """LinearSoftPool 42 | 43 | Linear softmax, takes logits and returns a probability, near to the actual maximum value. 
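Formally, for frame-level probabilities y_t along the pooled dimension the clip-level score is sum_t(y_t**2) / sum_t(y_t), so frames with a higher probability contribute proportionally more weight (see forward() below).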
44 | Taken from the paper: 45 | 46 | A Comparison of Five Multiple Instance Learning Pooling Functions for Sound Event Detection with Weak Labeling 47 | https://arxiv.org/abs/1810.09050 48 | 49 | """ 50 | def __init__(self, pooldim=1): 51 | super().__init__() 52 | self.pooldim = pooldim 53 | 54 | def forward(self, logits, time_decision): 55 | return (time_decision**2).sum(self.pooldim) / time_decision.sum( 56 | self.pooldim) 57 | 58 | 59 | class MeanPool(nn.Module): 60 | def __init__(self, pooldim=1): 61 | super().__init__() 62 | self.pooldim = pooldim 63 | 64 | def forward(self, logits, decision): 65 | return torch.mean(decision, dim=self.pooldim) 66 | 67 | 68 | def parse_poolingfunction(poolingfunction_name='mean', **kwargs): 69 | """parse_poolingfunction 70 | A heler function to parse any temporal pooling 71 | Pooling is done on dimension 1 72 | 73 | :param poolingfunction_name: 74 | :param **kwargs: 75 | """ 76 | poolingfunction_name = poolingfunction_name.lower() 77 | if poolingfunction_name == 'mean': 78 | return MeanPool(pooldim=1) 79 | elif poolingfunction_name == 'linear': 80 | return LinearSoftPool(pooldim=1) 81 | elif poolingfunction_name == 'attention': 82 | return AttentionPool(inputdim=kwargs['inputdim'], 83 | outputdim=kwargs['outputdim']) 84 | 85 | 86 | class AttentionPool(nn.Module): 87 | """docstring for AttentionPool""" 88 | def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs): 89 | super().__init__() 90 | self.inputdim = inputdim 91 | self.outputdim = outputdim 92 | self.pooldim = pooldim 93 | self.transform = nn.Linear(inputdim, outputdim) 94 | self.activ = nn.Softmax(dim=self.pooldim) 95 | self.eps = 1e-7 96 | 97 | def forward(self, logits, decision): 98 | # Input is (B, T, D) 99 | # B, T , D 100 | w = self.activ(self.transform(logits)) 101 | detect = (decision * w).sum( 102 | self.pooldim) / (w.sum(self.pooldim) + self.eps) 103 | # B, T, D 104 | return detect 105 | 106 | 107 | class Block2D(nn.Module): 108 | def __init__(self, cin, cout, kernel_size=3, padding=1): 109 | super().__init__() 110 | self.block = nn.Sequential( 111 | nn.BatchNorm2d(cin), 112 | nn.Conv2d(cin, 113 | cout, 114 | kernel_size=kernel_size, 115 | padding=padding, 116 | bias=False), 117 | nn.LeakyReLU(inplace=True, negative_slope=0.1)) 118 | 119 | def forward(self, x): 120 | return self.block(x) 121 | 122 | 123 | class CRNN(nn.Module): 124 | def __init__(self, inputdim, outputdim, **kwargs): 125 | super().__init__() 126 | self.features = nn.Sequential( 127 | Block2D(1, 32), 128 | nn.LPPool2d(4, (2, 4)), 129 | Block2D(32, 128), 130 | Block2D(128, 128), 131 | nn.LPPool2d(4, (2, 4)), 132 | Block2D(128, 128), 133 | Block2D(128, 128), 134 | nn.LPPool2d(4, (1, 4)), 135 | nn.Dropout(0.3), 136 | ) 137 | with torch.no_grad(): 138 | rnn_input_dim = self.features(torch.randn(1, 1, 500, 139 | inputdim)).shape 140 | rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1] 141 | 142 | self.gru = nn.GRU(rnn_input_dim, 143 | 128, 144 | bidirectional=True, 145 | batch_first=True) 146 | self.temp_pool = parse_poolingfunction(kwargs.get( 147 | 'temppool', 'linear'), 148 | inputdim=256, 149 | outputdim=outputdim) 150 | self.outputlayer = nn.Linear(256, outputdim) 151 | self.features.apply(init_weights) 152 | self.outputlayer.apply(init_weights) 153 | 154 | def forward(self, x, upsample=True): 155 | batch, time, dim = x.shape 156 | x = x.unsqueeze(1) 157 | x = self.features(x) 158 | x = x.transpose(1, 2).contiguous().flatten(-2) 159 | x, _ = self.gru(x) 160 | decision_time = 
torch.sigmoid(self.outputlayer(x)).clamp(1e-7, 1.) 161 | if upsample: 162 | decision_time = torch.nn.functional.interpolate( 163 | decision_time.transpose(1, 2), 164 | time, 165 | mode='linear', 166 | align_corners=False).transpose(1, 2) 167 | decision = self.temp_pool(x, decision_time).clamp(1e-7, 1.).squeeze(1) 168 | return decision, decision_time 169 | 170 | 171 | class CNN10(nn.Module): 172 | def __init__(self, inputdim, outputdim, **kwargs): 173 | super().__init__() 174 | self.features = nn.Sequential( 175 | Block2D(1, 64), 176 | Block2D(64, 64), 177 | nn.LPPool2d(4, (2, 4)), 178 | Block2D(64, 128), 179 | Block2D(128, 128), 180 | nn.LPPool2d(4, (2, 2)), 181 | Block2D(128, 256), 182 | Block2D(256, 256), 183 | nn.LPPool2d(4, (1, 2)), 184 | Block2D(256, 512), 185 | Block2D(512, 512), 186 | nn.LPPool2d(4, (1, 2)), 187 | nn.Dropout(0.3), 188 | nn.AdaptiveAvgPool2d((None, 1)), 189 | ) 190 | 191 | self.temp_pool = parse_poolingfunction(kwargs.get( 192 | 'temppool', 'attention'), 193 | inputdim=512, 194 | outputdim=outputdim) 195 | self.outputlayer = nn.Linear(512, outputdim) 196 | self.features.apply(init_weights) 197 | self.outputlayer.apply(init_weights) 198 | 199 | def forward(self, x, upsample=True): 200 | batch, time, dim = x.shape 201 | x = x.unsqueeze(1) 202 | x = self.features(x) 203 | x = x.transpose(1, 2).contiguous().flatten(-2) 204 | decision_time = torch.sigmoid(self.outputlayer(x)).clamp(1e-7, 1.) 205 | decision = self.temp_pool(x, decision_time).clamp(1e-7, 1.).squeeze(1) 206 | if upsample: 207 | decision_time = torch.nn.functional.interpolate( 208 | decision_time.transpose(1, 2), 209 | time, 210 | mode='linear', 211 | align_corners=False).transpose(1, 2) 212 | return decision, decision_time 213 | 214 | 215 | class CRNN10(nn.Module): 216 | def __init__(self, inputdim, outputdim, **kwargs): 217 | super().__init__() 218 | self._hiddim = kwargs.get('hiddim', 256) 219 | self.features = nn.Sequential( 220 | Block2D(1, 64), 221 | Block2D(64, 64), 222 | nn.LPPool2d(4, (2, 4)), 223 | Block2D(64, 128), 224 | Block2D(128, 128), 225 | nn.LPPool2d(4, (2, 2)), 226 | Block2D(128, 256), 227 | Block2D(256, 256), 228 | nn.LPPool2d(4, (1, 2)), 229 | Block2D(256, 512), 230 | Block2D(512, 512), 231 | nn.LPPool2d(4, (1, 2)), 232 | nn.Dropout(0.3), 233 | nn.AdaptiveAvgPool2d((None, 1)), 234 | ) 235 | with torch.no_grad(): 236 | rnn_input_dim = self.features(torch.randn(1, 1, 500, 237 | inputdim)).shape 238 | rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1] 239 | self.gru = nn.GRU(rnn_input_dim, 240 | self._hiddim, 241 | bidirectional=True, 242 | batch_first=True) 243 | self.temp_pool = parse_poolingfunction(kwargs.get( 244 | 'temppool', 'linear'), 245 | inputdim=self._hiddim*2, 246 | outputdim=outputdim) 247 | 248 | self.outputlayer = nn.Linear(self._hiddim*2, outputdim) 249 | self.features.apply(init_weights) 250 | self.outputlayer.apply(init_weights) 251 | 252 | def forward(self, x, upsample=True): 253 | batch, time, dim = x.shape 254 | x = x.unsqueeze(1) 255 | x = self.features(x) 256 | x = x.transpose(1, 2).contiguous().flatten(-2) 257 | decision_time = torch.sigmoid(self.outputlayer(x)).clamp(1e-7, 1.) 
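        # Frame probabilities are clamped to [1e-7, 1] so that downstream
        # BCE-style losses never evaluate log(0); the clip-level decision below
        # aggregates them with the configured temporal pooling function.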
258 | decision = self.temp_pool(x, decision_time).clamp(1e-7, 1.).squeeze(1) 259 | if upsample: 260 | decision_time = torch.nn.functional.interpolate( 261 | decision_time.transpose(1, 2), 262 | time, 263 | mode='linear', 264 | align_corners=False).transpose(1, 2) 265 | return decision, decision_time 266 | -------------------------------------------------------------------------------- /pprint_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pandas as pd 3 | import re 4 | from pathlib import Path 5 | import torch 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('dir', type=str) 9 | parser.add_argument('fmt', default=None,nargs='?') 10 | args = parser.parse_args() 11 | 12 | res = {} 13 | root_dir = Path(args.dir) 14 | train_log = root_dir / 'train.log' 15 | 16 | config = torch.load(root_dir / 'run_config.pth') 17 | pretrained = config.get('pretrained', None) 18 | # logs 19 | augment = config.get('transforms', []) 20 | label_type = config.get('label_type', 'soft') 21 | model = config.get('model','CRNN') 22 | 23 | 24 | def get_seg_metrics(line, pointer, seg_type='Segment'): 25 | res = {} 26 | while not 'macro-average' in line: 27 | line = next(pointer).strip() 28 | while not 'F-measure (F1)' in line: 29 | line = next(pointer).strip() 30 | res[f'F1'] = float(line.split()[-2]) 31 | while not 'Precision' in line: 32 | line = next(pointer).strip() 33 | res[f'Precision'] = float(line.split()[-2]) 34 | while not 'Recall' in line: 35 | line = next(pointer).strip() 36 | res[f'Recall'] = float(line.split()[-2]) 37 | return res 38 | 39 | 40 | def parse_eval_file(eval_file): 41 | res = {} 42 | frame_results = {} 43 | with open(eval_file, 'r') as rp: 44 | for line in rp: 45 | line = line.strip() 46 | if 'AUC' in line: 47 | auc = line.split()[-1] 48 | frame_results['AUC'] = float(auc) 49 | if 'FER' in line: 50 | fer = line.split()[-1] 51 | frame_results['FER'] = float(fer) 52 | if 'VAD macro' in line: 53 | f1, pre, rec = re.findall(r"[-+]?\d*\.\d+|\d+", 54 | line)[1:] # First hit is F1 55 | frame_results['F1'] = float(f1) 56 | frame_results['Precision'] = float(pre) 57 | frame_results['Recall'] = float(rec) 58 | if "Segment based metrics" in line: 59 | res['Segment'] = get_seg_metrics(line, rp) 60 | if 'Event based metrics' in line: 61 | res['Event'] = get_seg_metrics(line, rp, 'Event') 62 | res['Frame'] = frame_results 63 | return res 64 | 65 | 66 | all_results = [] 67 | for f in root_dir.glob('*.txt'): 68 | eval_dataset = str(f.stem)[11:] 69 | res = parse_eval_file(f) 70 | df = pd.DataFrame(res).fillna('') 71 | df['data'] = eval_dataset 72 | df['augment'] = ",".join(augment) 73 | df['pretrained'] = pretrained 74 | df['label_type'] = label_type 75 | df['model'] = model 76 | all_results.append(df) 77 | df = pd.concat(all_results) 78 | if args.fmt == 'csv': 79 | print(df.to_csv()) 80 | else: 81 | print(df) 82 | -------------------------------------------------------------------------------- /pretrained_models/audio2_vox2/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/audio2_vox2/model.pth -------------------------------------------------------------------------------- /pretrained_models/audioset2/model.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/audioset2/model.pth -------------------------------------------------------------------------------- /pretrained_models/c1/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/c1/model.pth -------------------------------------------------------------------------------- /pretrained_models/labelencoders/students.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/labelencoders/students.pth -------------------------------------------------------------------------------- /pretrained_models/labelencoders/teacher.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/labelencoders/teacher.pth -------------------------------------------------------------------------------- /pretrained_models/sre/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/sre/model.pth -------------------------------------------------------------------------------- /pretrained_models/teacher1/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/teacher1/model.pth -------------------------------------------------------------------------------- /pretrained_models/teacher2/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/teacher2/model.pth -------------------------------------------------------------------------------- /pretrained_models/vox2/model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/vox2/model.pth -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==1.1.0 2 | SoundFile==0.10.2 3 | numpy==1.16.4 4 | loguru==0.4.0 5 | h5py==2.9.0 6 | scipy==1.3.0 7 | torch==1.2.0 8 | sed_eval==0.2.1 9 | pytorch_ignite==0.2.0 10 | tqdm==4.32.2 11 | tabulate==0.8.3 12 | six==1.12.0 13 | fire==0.1.3 14 | librosa==0.7.0 15 | ignite==1.1.0 16 | scikit_learn==0.23.1 17 | typing==3.7.4.1 18 | PyYAML==5.4 19 | numba==0.48 20 | pypeln 21 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import datetime 6 | 7 | import uuid 8 | import fire 9 | from pathlib import Path 10 | 11 | import pandas as pd 12 | import torch 13 | import numpy as np 14 | from tqdm import tqdm 15 | from ignite.contrib.handlers 
import ProgressBar, param_scheduler 16 | from ignite.engine import (Engine, Events) 17 | from ignite.handlers import EarlyStopping, ModelCheckpoint 18 | from ignite.metrics import Accuracy, RunningAverage, Precision, Recall 19 | from ignite.utils import convert_tensor 20 | from tabulate import tabulate 21 | from h5py import File 22 | 23 | import dataset 24 | import models 25 | import utils 26 | import metrics 27 | import losses 28 | 29 | DEVICE = 'cpu' 30 | if torch.cuda.is_available( 31 | ) and 'SLURM_JOB_PARTITION' in os.environ and 'gpu' in os.environ[ 32 | 'SLURM_JOB_PARTITION']: 33 | DEVICE = 'cuda' 34 | # Without results are slightly inconsistent 35 | torch.backends.cudnn.deterministic = True 36 | DEVICE = torch.device(DEVICE) 37 | 38 | 39 | class Runner(object): 40 | """Main class to run experiments with e.g., train and evaluate""" 41 | def __init__(self, seed=42): 42 | """__init__ 43 | 44 | :param config: YAML config file 45 | :param **kwargs: Overwrite of yaml config 46 | """ 47 | super().__init__() 48 | torch.manual_seed(seed) 49 | np.random.seed(seed) 50 | 51 | @staticmethod 52 | def _forward(model, batch): 53 | inputs, targets_time, targets_clip, filenames, lengths = batch 54 | inputs = convert_tensor(inputs, device=DEVICE, non_blocking=True) 55 | targets_time = convert_tensor(targets_time, 56 | device=DEVICE, 57 | non_blocking=True) 58 | targets_clip = convert_tensor(targets_clip, 59 | device=DEVICE, 60 | non_blocking=True) 61 | clip_level_output, frame_level_output = model(inputs) 62 | return clip_level_output, frame_level_output, targets_time, targets_clip, lengths 63 | 64 | @staticmethod 65 | def _negative_loss(engine): 66 | return -engine.state.metrics['Loss'] 67 | 68 | def train(self, config, **kwargs): 69 | """Trains a given model specified in the config file or passed as the --model parameter. 
70 | All options in the config file can be overwritten as needed by passing --PARAM 71 | Options with variable lengths ( e.g., kwargs can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}' 72 | 73 | :param config: yaml config file 74 | :param **kwargs: parameters to overwrite yaml config 75 | """ 76 | 77 | config_parameters = utils.parse_config_or_kwargs(config, **kwargs) 78 | outputdir = os.path.join( 79 | config_parameters['outputpath'], config_parameters['model'], 80 | "{}_{}".format( 81 | datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%m'), 82 | uuid.uuid1().hex)) 83 | # Early init because of creating dir 84 | checkpoint_handler = ModelCheckpoint( 85 | outputdir, 86 | 'run', 87 | n_saved=3, 88 | require_empty=False, 89 | create_dir=True, 90 | score_function=self._negative_loss, 91 | score_name='loss') 92 | logger = utils.getfile_outlogger(os.path.join(outputdir, 'train.log')) 93 | logger.info("Storing files in {}".format(outputdir)) 94 | # utils.pprint_dict 95 | utils.pprint_dict(config_parameters, logger.info) 96 | logger.info("Running on device {}".format(DEVICE)) 97 | label_df = pd.read_csv(config_parameters['label'], sep='\s+') 98 | data_df = pd.read_csv(config_parameters['data'], sep='\s+') 99 | # In case that both are not matching 100 | merged = data_df.merge(label_df, on='filename') 101 | common_idxs = merged['filename'] 102 | data_df = data_df[data_df['filename'].isin(common_idxs)] 103 | label_df = label_df[label_df['filename'].isin(common_idxs)] 104 | 105 | train_df, cv_df = utils.split_train_cv( 106 | label_df, **config_parameters['data_args']) 107 | train_label = utils.df_to_dict(train_df) 108 | cv_label = utils.df_to_dict(cv_df) 109 | data = utils.df_to_dict(data_df) 110 | 111 | transform = utils.parse_transforms(config_parameters['transforms']) 112 | torch.save(config_parameters, os.path.join(outputdir, 113 | 'run_config.pth')) 114 | logger.info("Transforms:") 115 | utils.pprint_dict(transform, logger.info, formatter='pretty') 116 | assert len(cv_df) > 0, "Fraction a bit too large?" 
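        # Build train/CV dataloaders from the same HDF5 feature store; only the
        # training loader applies the parsed augmentation transforms, while the
        # CV loader runs without augmentation and without shuffling.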
117 | 118 | trainloader = dataset.gettraindataloader( 119 | h5files=data, 120 | h5labels=train_label, 121 | transform=transform, 122 | label_type=config_parameters['label_type'], 123 | batch_size=config_parameters['batch_size'], 124 | num_workers=config_parameters['num_workers'], 125 | shuffle=True, 126 | ) 127 | 128 | cvdataloader = dataset.gettraindataloader( 129 | h5files=data, 130 | h5labels=cv_label, 131 | label_type=config_parameters['label_type'], 132 | transform=None, 133 | shuffle=False, 134 | batch_size=config_parameters['batch_size'], 135 | num_workers=config_parameters['num_workers'], 136 | ) 137 | model = getattr(models, config_parameters['model'], 138 | 'CRNN')(inputdim=trainloader.dataset.datadim, 139 | outputdim=2, 140 | **config_parameters['model_args']) 141 | if 'pretrained' in config_parameters and config_parameters[ 142 | 'pretrained'] is not None: 143 | model_dump = torch.load(config_parameters['pretrained'], 144 | map_location='cpu') 145 | model_state = model.state_dict() 146 | pretrained_state = { 147 | k: v 148 | for k, v in model_dump.items() 149 | if k in model_state and v.size() == model_state[k].size() 150 | } 151 | model_state.update(pretrained_state) 152 | model.load_state_dict(model_state) 153 | logger.info("Loading pretrained model {}".format( 154 | config_parameters['pretrained'])) 155 | 156 | model = model.to(DEVICE) 157 | optimizer = getattr( 158 | torch.optim, 159 | config_parameters['optimizer'], 160 | )(model.parameters(), **config_parameters['optimizer_args']) 161 | 162 | utils.pprint_dict(optimizer, logger.info, formatter='pretty') 163 | utils.pprint_dict(model, logger.info, formatter='pretty') 164 | if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1: 165 | logger.info("Using {} GPUs!".format(torch.cuda.device_count())) 166 | model = torch.nn.DataParallel(model) 167 | criterion = getattr(losses, config_parameters['loss'])().to(DEVICE) 168 | 169 | def _train_batch(_, batch): 170 | model.train() 171 | with torch.enable_grad(): 172 | optimizer.zero_grad() 173 | output = self._forward( 174 | model, batch) # output is tuple (clip, frame, target) 175 | loss = criterion(*output) 176 | loss.backward() 177 | # Single loss 178 | optimizer.step() 179 | return loss.item() 180 | 181 | def _inference(_, batch): 182 | model.eval() 183 | with torch.no_grad(): 184 | return self._forward(model, batch) 185 | 186 | def thresholded_output_transform(output): 187 | # Output is (clip, frame, target, lengths) 188 | _, y_pred, y, y_clip, length = output 189 | batchsize, timesteps, ndim = y.shape 190 | idxs = torch.arange(timesteps, 191 | device='cpu').repeat(batchsize).view( 192 | batchsize, timesteps) 193 | mask = (idxs < length.view(-1, 1)).to(y.device) 194 | y = y * mask.unsqueeze(-1) 195 | y_pred = torch.round(y_pred) 196 | y = torch.round(y) 197 | return y_pred, y 198 | 199 | metrics = { 200 | 'Loss': losses.Loss( 201 | criterion), #reimplementation of Loss, supports 3 way loss 202 | 'Precision': Precision(thresholded_output_transform), 203 | 'Recall': Recall(thresholded_output_transform), 204 | 'Accuracy': Accuracy(thresholded_output_transform), 205 | } 206 | train_engine = Engine(_train_batch) 207 | inference_engine = Engine(_inference) 208 | for name, metric in metrics.items(): 209 | metric.attach(inference_engine, name) 210 | 211 | def compute_metrics(engine): 212 | inference_engine.run(cvdataloader) 213 | results = inference_engine.state.metrics 214 | output_str_list = [ 215 | "Validation Results - Epoch : {:<5}".format(engine.state.epoch) 216 | ] 217 | for 
metric in metrics: 218 | output_str_list.append("{} {:<5.2f}".format( 219 | metric, results[metric])) 220 | logger.info(" ".join(output_str_list)) 221 | pbar.n = pbar.last_print_n = 0 222 | 223 | pbar = ProgressBar(persist=False) 224 | pbar.attach(train_engine) 225 | 226 | train_engine.add_event_handler(Events.ITERATION_COMPLETED(every=5000), 227 | compute_metrics) 228 | train_engine.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics) 229 | 230 | early_stop_handler = EarlyStopping( 231 | patience=config_parameters['early_stop'], 232 | score_function=self._negative_loss, 233 | trainer=train_engine) 234 | inference_engine.add_event_handler(Events.EPOCH_COMPLETED, 235 | early_stop_handler) 236 | inference_engine.add_event_handler(Events.EPOCH_COMPLETED, 237 | checkpoint_handler, { 238 | 'model': model, 239 | }) 240 | 241 | train_engine.run(trainloader, max_epochs=config_parameters['epochs']) 242 | return outputdir 243 | 244 | def train_evaluate(self, 245 | config, 246 | tasks=['aurora_clean', 'aurora_noisy', 'dcase18'], 247 | **kwargs): 248 | experiment_path = self.train(config, **kwargs) 249 | for task in tasks: 250 | self.evaluate(experiment_path, task=task) 251 | 252 | def predict_time( 253 | self, 254 | experiment_path, 255 | output_h5, 256 | rfac=2, # Resultuion upscale fator 257 | **kwargs): # overwrite --data 258 | 259 | experiment_path = Path(experiment_path) 260 | if experiment_path.is_file(): # Model is given 261 | model_path = experiment_path 262 | experiment_path = experiment_path.parent 263 | else: 264 | model_path = next(Path(experiment_path).glob("run_model*")) 265 | config = torch.load(next(Path(experiment_path).glob("run_config*")), 266 | map_location=lambda storage, loc: storage) 267 | logger = utils.getfile_outlogger(None) 268 | # Use previous config, but update data such as kwargs 269 | config_parameters = dict(config, **kwargs) 270 | # Default columns to search for in data 271 | encoder = torch.load('labelencoders/vad.pth') 272 | data = config_parameters['data'] 273 | dset = dataset.EvalH5Dataset(data) 274 | dataloader = torch.utils.data.DataLoader(dset, 275 | batch_size=1, 276 | num_workers=4, 277 | shuffle=False) 278 | 279 | model = getattr(models, config_parameters['model'])( 280 | inputdim=dataloader.dataset.datadim, 281 | outputdim=len(encoder.classes_), 282 | **config_parameters['model_args']) 283 | 284 | model_parameters = torch.load( 285 | model_path, map_location=lambda storage, loc: storage) 286 | model.load_state_dict(model_parameters) 287 | model = model.to(DEVICE).eval() 288 | 289 | ## VAD preprocessing data 290 | logger.trace(model) 291 | 292 | output_dfs = [] 293 | 294 | speech_label_idx = np.where('Speech' == encoder.classes_)[0].squeeze() 295 | non_speech_idx = np.arange(len(encoder.classes_)) 296 | non_speech_idx = np.delete(non_speech_idx, speech_label_idx) 297 | speech_frame_predictions, speech_frame_prob_predictions = [], [] 298 | with torch.no_grad(), tqdm(total=len(dataloader), 299 | leave=False, 300 | unit='clip') as pbar, File(output_h5, 301 | 'w') as store: 302 | for feature, filename in dataloader: 303 | feature = torch.as_tensor(feature).to(DEVICE) 304 | filename = Path(filename[0]).stem 305 | batch, time, dim = feature.shape 306 | # PANNS output a dict instead of 2 values 307 | prediction_tag, prediction_time = model(feature, 308 | upsample=False) 309 | prediction_tag = prediction_tag.to('cpu') 310 | prediction_time = torch.nn.functional.interpolate( 311 | prediction_time.transpose(1, 2), 312 | int(time * rfac), 313 | mode='linear', 314 
| align_corners=False).transpose(1, 2) 315 | prediction_time = prediction_time.to('cpu').squeeze(0) 316 | speech_label_pred = prediction_time[ 317 | ..., speech_label_idx].squeeze(-1) 318 | noise_label_pred = prediction_time[..., 319 | non_speech_idx].squeeze(-1) 320 | store[f'{filename}/speech'] = speech_label_pred 321 | store[f'{filename}/noise'] = noise_label_pred 322 | pbar.set_postfix(time=time, 323 | fname=filename, 324 | speech=speech_label_pred.shape, 325 | noise=noise_label_pred.shape) 326 | pbar.update() 327 | 328 | def predict_clip(self, 329 | experiment_path, 330 | output_csv, 331 | thres=0.5, 332 | **kwargs): # overwrite --data 333 | import h5py 334 | from sklearn.preprocessing import binarize 335 | from tqdm import tqdm 336 | config = torch.load(list(Path(experiment_path).glob("run_config*"))[0], 337 | map_location=lambda storage, loc: storage) 338 | config_parameters = dict(config, **kwargs) 339 | model_parameters = torch.load( 340 | list(Path(experiment_path).glob("run_model*"))[0], 341 | map_location=lambda storage, loc: storage) 342 | encoder = torch.load('labelencoders/vad.pth') 343 | 344 | predictions = [] 345 | with h5py.File(config_parameters['data'], 346 | 'r') as input_store, torch.no_grad(), tqdm( 347 | total=len(input_store)) as pbar: 348 | inputdim = next(iter(input_store.values())).shape[-1] 349 | model = getattr(models, config_parameters['model'])( 350 | inputdim=inputdim, 351 | outputdim=len(encoder.classes_), 352 | **config_parameters['model_args']) 353 | model.load_state_dict(model_parameters) 354 | model = model.to(DEVICE).eval() 355 | for fname, sample in input_store.items(): 356 | if sample.ndim > 1: # Global mean and Global_var might also be there 357 | sample = torch.as_tensor(sample[()]).unsqueeze(0).to( 358 | DEVICE) # batch + channel 359 | decision, _ = model(sample) 360 | decision = binarize(decision.to('cpu'), threshold=thres) 361 | pred_labels = encoder.inverse_transform(decision)[0] 362 | pbar.set_postfix(labels=pred_labels, file=fname) 363 | if len(pred_labels) > 0: 364 | predictions.append({ 365 | 'filename': 366 | fname, 367 | 'event_labels': 368 | ",".join(pred_labels) 369 | }) 370 | pbar.update() 371 | 372 | df = pd.DataFrame(predictions) 373 | df.to_csv(output_csv, sep='\t', index=False) 374 | 375 | def evaluate(self, 376 | experiment_path: Path, 377 | task: str = 'aurora_clean', 378 | model_resolution=0.02, 379 | time_resolution=0.02, 380 | threshold=(0.5, 0.1), 381 | **kwargs): 382 | EVALUATION_DATA = { 383 | 'aurora_clean': { 384 | 'data': 'data/evaluation/hdf5/aurora_clean.h5', 385 | 'label': 'data/evaluation/labels/aurora_clean_labels.tsv', 386 | }, 387 | 'aurora_noisy': { 388 | 'data': 'data/evaluation/hdf5/aurora_noisy.h5', 389 | 'label': 'data/evaluation/labels/aurora_noisy_labels.tsv' 390 | }, 391 | 'dihard_dev': { 392 | 'data': 'data/evaluation/hdf5/dihard_dev.h5', 393 | 'label': 'data/evaluation/labels/dihard_dev.csv' 394 | }, 395 | 'dihard_eval': { 396 | 'data': 'data/evaluation/hdf5/dihard_eval.h5', 397 | 'label': 'data/evaluation/labels/dihard_eval.csv' 398 | }, 399 | 'aurora_snr_20': { 400 | 'data': 401 | 'data/evaluation/hdf5/aurora_noisy_musan_snr_20.0.hdf5', 402 | 'label': 'data/evaluation/labels/musan_labels.tsv' 403 | }, 404 | 'aurora_snr_15': { 405 | 'data': 406 | 'data/evaluation/hdf5/aurora_noisy_musan_snr_15.0.hdf5', 407 | 'label': 'data/evaluation/labels/musan_labels.tsv' 408 | }, 409 | 'aurora_snr_10': { 410 | 'data': 411 | 'data/evaluation/hdf5/aurora_noisy_musan_snr_10.0.hdf5', 412 | 'label': 
'data/evaluation/labels/musan_labels.tsv' 413 | }, 414 | 'aurora_snr_5': { 415 | 'data': 'data/evaluation/hdf5/aurora_noisy_musan_snr_5.0.hdf5', 416 | 'label': 'data/evaluation/labels/musan_labels.tsv' 417 | }, 418 | 'aurora_snr_0': { 419 | 'data': 'data/evaluation/hdf5/aurora_noisy_musan_snr_0.0.hdf5', 420 | 'label': 'data/evaluation/labels/musan_labels.tsv' 421 | }, 422 | 'aurora_snr_-5': { 423 | 'data': 424 | 'data/evaluation/hdf5/aurora_noisy_musan_snr_-5.0.hdf5', 425 | 'label': 'data/evaluation/labels/musan_labels.tsv' 426 | }, 427 | 'dcase18': { 428 | 'data': 'data/evaluation/hdf5/dcase18.h5', 429 | 'label': 'data/evaluation/labels/dcase18.tsv', 430 | }, 431 | } 432 | assert task in EVALUATION_DATA, f"--task {'|'.join(list(EVALUATION_DATA.keys()))}" 433 | experiment_path = Path(experiment_path) 434 | if experiment_path.is_file(): # Model is given 435 | model_path = experiment_path 436 | experiment_path = experiment_path.parent 437 | else: 438 | model_path = next(Path(experiment_path).glob("run_model*")) 439 | config = torch.load(next(Path(experiment_path).glob("run_config*")), 440 | map_location='cpu') 441 | logger = utils.getfile_outlogger(None) 442 | # Use previous config, but update data such as kwargs 443 | config_parameters = dict(config, **kwargs) 444 | # Default columns to search for in data 445 | model_parameters = torch.load( 446 | model_path, map_location=lambda storage, loc: storage) 447 | encoder = torch.load('labelencoders/vad.pth') 448 | data = EVALUATION_DATA[task]['data'] 449 | label_df = pd.read_csv(EVALUATION_DATA[task]['label'], sep='\s+') 450 | label_df['filename'] = label_df['filename'].apply( 451 | lambda x: Path(x).name) 452 | logger.info(f"Label_df shape is {label_df.shape}") 453 | 454 | dset = dataset.EvalH5Dataset(data, 455 | fnames=np.unique( 456 | label_df['filename'].values)) 457 | 458 | dataloader = torch.utils.data.DataLoader(dset, 459 | batch_size=1, 460 | num_workers=4, 461 | shuffle=False) 462 | 463 | model = getattr(models, config_parameters['model'])( 464 | inputdim=dataloader.dataset.datadim, 465 | outputdim=len(encoder.classes_), 466 | **config_parameters['model_args']) 467 | 468 | model.load_state_dict(model_parameters) 469 | model = model.to(DEVICE).eval() 470 | 471 | ## VAD preprocessing data 472 | vad_label_helper_df = label_df.copy() 473 | vad_label_helper_df['onset'] = np.ceil(vad_label_helper_df['onset'] / 474 | model_resolution).astype(int) 475 | vad_label_helper_df['offset'] = np.ceil(vad_label_helper_df['offset'] / 476 | model_resolution).astype(int) 477 | 478 | vad_label_helper_df = vad_label_helper_df.groupby(['filename']).agg({ 479 | 'onset': 480 | tuple, 481 | 'offset': 482 | tuple, 483 | 'event_label': 484 | tuple 485 | }).reset_index() 486 | logger.trace(model) 487 | 488 | output_dfs = [] 489 | 490 | speech_label_idx = np.where('Speech' == encoder.classes_)[0].squeeze() 491 | speech_frame_predictions, speech_frame_ground_truth, speech_frame_prob_predictions = [], [],[] 492 | # Using only binary thresholding without filter 493 | if len(threshold) == 1: 494 | postprocessing_method = utils.binarize 495 | else: 496 | postprocessing_method = utils.double_threshold 497 | with torch.no_grad(), tqdm(total=len(dataloader), 498 | leave=False, 499 | unit='clip') as pbar: 500 | for feature, filename in dataloader: 501 | feature = torch.as_tensor(feature).to(DEVICE) 502 | # PANNS output a dict instead of 2 values 503 | prediction_tag, prediction_time = model(feature) 504 | prediction_tag = prediction_tag.to('cpu') 505 | prediction_time = 
prediction_time.to('cpu') 506 | 507 | if prediction_time is not None: # Some models do not predict timestamps 508 | 509 | cur_filename = filename[0] 510 | 511 | thresholded_prediction = postprocessing_method( 512 | prediction_time, *threshold) 513 | 514 | ## VAD predictions 515 | speech_frame_prob_predictions.append( 516 | prediction_time[..., speech_label_idx].squeeze()) 517 | ### Thresholded speech predictions 518 | speech_prediction = thresholded_prediction[ 519 | ..., speech_label_idx].squeeze() 520 | speech_frame_predictions.append(speech_prediction) 521 | targets = vad_label_helper_df[ 522 | vad_label_helper_df['filename'] == cur_filename][[ 523 | 'onset', 'offset' 524 | ]].values[0] 525 | target_arr = np.zeros_like(speech_prediction) 526 | for start, end in zip(*targets): 527 | target_arr[start:end] = 1 528 | speech_frame_ground_truth.append(target_arr) 529 | 530 | #### SED predictions 531 | 532 | labelled_predictions = utils.decode_with_timestamps( 533 | encoder, thresholded_prediction) 534 | pred_label_df = pd.DataFrame( 535 | labelled_predictions[0], 536 | columns=['event_label', 'onset', 'offset']) 537 | if not pred_label_df.empty: 538 | pred_label_df['filename'] = cur_filename 539 | pred_label_df['onset'] *= model_resolution 540 | pred_label_df['offset'] *= model_resolution 541 | pbar.set_postfix(labels=','.join( 542 | np.unique(pred_label_df['event_label'].values))) 543 | pbar.update() 544 | output_dfs.append(pred_label_df) 545 | 546 | full_prediction_df = pd.concat(output_dfs) 547 | prediction_df = full_prediction_df[full_prediction_df['event_label'] == 548 | 'Speech'] 549 | assert set(['onset', 'offset', 'filename', 'event_label' 550 | ]).issubset(prediction_df.columns), "Format is wrong" 551 | assert set(['onset', 'offset', 'filename', 'event_label' 552 | ]).issubset(label_df.columns), "Format is wrong" 553 | logger.info("Calculating VAD measures ... ") 554 | speech_frame_ground_truth = np.concatenate(speech_frame_ground_truth, 555 | axis=0) 556 | speech_frame_predictions = np.concatenate(speech_frame_predictions, 557 | axis=0) 558 | speech_frame_prob_predictions = np.concatenate( 559 | speech_frame_prob_predictions, axis=0) 560 | 561 | vad_results = [] 562 | tn, fp, fn, tp = metrics.confusion_matrix( 563 | speech_frame_ground_truth, speech_frame_predictions).ravel() 564 | fer = 100 * ((fp + fn) / len(speech_frame_ground_truth)) 565 | acc = 100 * ((tp + tn) / (len(speech_frame_ground_truth))) 566 | 567 | p_miss = 100 * (fn / (fn + tp)) 568 | p_fa = 100 * (fp / (fp + tn)) 569 | for i in [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 0.7,0.9]: 570 | mp_fa, mp_miss = metrics.obtain_error_rates( 571 | speech_frame_ground_truth, speech_frame_prob_predictions, i) 572 | tn, fp, fn, tp = metrics.confusion_matrix( 573 | speech_frame_ground_truth, 574 | speech_frame_prob_predictions > i).ravel() 575 | sub_fer = 100 * ((fp + fn) / len(speech_frame_ground_truth)) 576 | logger.info( 577 | f"PFa {100*mp_fa:.2f} Pmiss {100*mp_miss:.2f} FER {sub_fer:.2f} t: {i:.2f}" 578 | ) 579 | 580 | auc = metrics.roc(speech_frame_ground_truth, 581 | speech_frame_prob_predictions) * 100 582 | for avgtype in ('micro', 'macro', 'binary'): 583 | precision, recall, f1, _ = metrics.precision_recall_fscore_support( 584 | speech_frame_ground_truth, 585 | speech_frame_predictions, 586 | average=avgtype) 587 | vad_results.append( 588 | (avgtype, 100 * precision, 100 * recall, 100 * f1)) 589 | 590 | logger.info("Calculating segment based metric .. 
") 591 | # Change order just for better printing in file 592 | prediction_df = prediction_df[[ 593 | 'filename', 'onset', 'offset', 'event_label' 594 | ]] 595 | metric = metrics.segment_based_evaluation_df( 596 | label_df, prediction_df, time_resolution=time_resolution) 597 | logger.info("Calculating event based metric .. ") 598 | event_metric = metrics.event_based_evaluation_df( 599 | label_df, prediction_df) 600 | 601 | prediction_df.to_csv(experiment_path / 602 | f'speech_predictions_{task}.tsv', 603 | sep='\t', 604 | index=False) 605 | full_prediction_df.to_csv(experiment_path / f'predictions_{task}.tsv', 606 | sep='\t', 607 | index=False) 608 | with open(experiment_path / f'evaluation_{task}.txt', 'w') as fp: 609 | for k, v in config_parameters.items(): 610 | print(f"{k}:{v}", file=fp) 611 | print(metric, file=fp) 612 | print(event_metric, file=fp) 613 | for avgtype, precision, recall, f1 in vad_results: 614 | print( 615 | f"VAD {avgtype} F1: {f1:<10.3f} {precision:<10.3f} Recall: {recall:<10.3f}", 616 | file=fp) 617 | print(f"FER: {fer:.2f}", file=fp) 618 | print(f"AUC: {auc:.2f}", file=fp) 619 | print(f"Pfa: {p_fa:.2f}", file=fp) 620 | print(f"Pmiss: {p_miss:.2f}", file=fp) 621 | print(f"ACC: {acc:.2f}", file=fp) 622 | logger.info(f"Results are at {experiment_path}") 623 | for avgtype, precision, recall, f1 in vad_results: 624 | print( 625 | f"VAD {avgtype:<10} F1: {f1:<10.3f} Pre: {precision:<10.3f} Recall: {recall:<10.3f}" 626 | ) 627 | print(f"FER: {fer:.2f}") 628 | print(f"AUC: {auc:.2f}") 629 | print(f"Pfa: {p_fa:.2f}") 630 | print(f"Pmiss: {p_miss:.2f}") 631 | print(f"ACC: {acc:.2f}") 632 | print(event_metric) 633 | print(metric) 634 | 635 | 636 | if __name__ == "__main__": 637 | fire.Fire(Runner) 638 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import collections 5 | import sys 6 | from loguru import logger 7 | from pprint import pformat 8 | from typing import List 9 | 10 | import numpy as np 11 | import pandas as pd 12 | import scipy 13 | import six 14 | import sklearn.preprocessing as pre 15 | import torch 16 | import tqdm 17 | import yaml 18 | 19 | import augment 20 | import dataset 21 | 22 | # Some defaults for non-specified arguments in yaml 23 | DEFAULT_ARGS = { 24 | 'outputpath': 'experiments', 25 | 'loss': 'BCELoss', 26 | 'batch_size': 64, 27 | 'num_workers': 4, 28 | 'epochs': 100, 29 | 'transforms': [], 30 | 'label_type':'soft', 31 | 'scheduler_args': { 32 | 'patience': 3, 33 | 'factor': 0.1, 34 | }, 35 | 'early_stop': 7, 36 | 'optimizer': 'Adam', 37 | 'optimizer_args': { 38 | 'lr': 0.001, 39 | }, 40 | 'threshold': None, #Default threshold for postprocessing function 41 | 'postprocessing': 'double', 42 | } 43 | 44 | 45 | def parse_config_or_kwargs(config_file, **kwargs): 46 | """parse_config_or_kwargs 47 | 48 | :param config_file: Config file that has parameters, yaml format 49 | :param **kwargs: Other alternative parameters or overwrites for config 50 | """ 51 | with open(config_file) as con_read: 52 | yaml_config = yaml.load(con_read, Loader=yaml.FullLoader) 53 | # values from config file are all possible params 54 | arguments = dict(yaml_config, **kwargs) 55 | # In case some arguments were not passed, replace with default ones 56 | for key, value in DEFAULT_ARGS.items(): 57 | arguments.setdefault(key, value) 58 | return arguments 59 | 60 | 61 | def 
find_contiguous_regions(activity_array): 62 | """Find contiguous regions from bool valued numpy.array. 63 | Copy of https://dcase-repo.github.io/dcase_util/_modules/dcase_util/data/decisions.html#DecisionEncoder 64 | 65 | Reason is: 66 | 1. This does not belong to a class necessarily 67 | 2. Import DecisionEncoder requires sndfile over some other imports..which causes some problems on clusters 68 | 69 | """ 70 | 71 | # Find the changes in the activity_array 72 | change_indices = np.logical_xor(activity_array[1:], 73 | activity_array[:-1]).nonzero()[0] 74 | 75 | # Shift change_index with one, focus on frame after the change. 76 | change_indices += 1 77 | 78 | if activity_array[0]: 79 | # If the first element of activity_array is True add 0 at the beginning 80 | change_indices = np.r_[0, change_indices] 81 | 82 | if activity_array[-1]: 83 | # If the last element of activity_array is True, add the length of the array 84 | change_indices = np.r_[change_indices, activity_array.size] 85 | 86 | # Reshape the result into two columns 87 | return change_indices.reshape((-1, 2)) 88 | 89 | 90 | def split_train_cv(input_data, frac: float = 0.9, **kwargs): 91 | """split_train_cv 92 | 93 | :param data_frame: 94 | :param frac: 95 | :type frac: float 96 | """ 97 | if isinstance(input_data, list): 98 | N = len(input_data) 99 | indicies = np.random.permutation(N) 100 | train_size = round(N * frac) 101 | cv_size = N - train_size 102 | train_idxs, cv_idxs = indicies[:train_size], indicies[cv_size:] 103 | input_data = np.array(input_data) 104 | return input_data[train_idxs].tolist(), input_data[cv_idxs].tolist() 105 | elif isinstance(input_data, pd.DataFrame): 106 | train_df = input_data.sample(frac=frac) 107 | cv_df = input_data[~input_data.index.isin(train_df.index)] 108 | return train_df, cv_df 109 | 110 | 111 | def parse_transforms(transform_list): 112 | """parse_transforms 113 | parses the config files transformation strings to coresponding methods 114 | 115 | :param transform_list: String list 116 | """ 117 | transforms = [] 118 | for trans in transform_list: 119 | if trans == 'noise': 120 | transforms.append(augment.GaussianNoise(snr=25)) 121 | elif trans == 'roll': 122 | transforms.append(augment.Roll(0, 10)) 123 | elif trans == 'freqmask': 124 | transforms.append(augment.FreqMask(2, 8)) 125 | elif trans == 'timemask': 126 | transforms.append(augment.TimeMask(2, 60)) 127 | elif trans == 'crop': 128 | transforms.append(augment.RandomCrop(200)) 129 | elif trans == 'randompad': 130 | transforms.append(augment.RandomPad(value=0., padding=25)) 131 | elif trans == 'flipsign': 132 | transforms.append(augment.FlipSign()) 133 | elif trans == 'shift': 134 | transforms.append(augment.Shift()) 135 | return torch.nn.Sequential(*transforms) 136 | 137 | 138 | def pprint_dict(in_dict, outputfun=sys.stdout.write, formatter='yaml'): 139 | """pprint_dict 140 | 141 | :param outputfun: function to use, defaults to sys.stdout 142 | :param in_dict: dict to print 143 | """ 144 | if formatter == 'yaml': 145 | format_fun = yaml.dump 146 | elif formatter == 'pretty': 147 | format_fun = pformat 148 | for line in format_fun(in_dict).split('\n'): 149 | outputfun(line) 150 | 151 | 152 | def getfile_outlogger(outputfile): 153 | log_format = "[{time:YYYY-MM-DD HH:mm:ss}] {message}" 154 | logger.configure(handlers=[{"sink": sys.stderr, "format": log_format}]) 155 | if outputfile: 156 | logger.add(outputfile, enqueue=True, format=log_format) 157 | return logger 158 | 159 | 160 | def train_labelencoder(labels: pd.Series, sparse=True): 161 
| """encode_labels 162 | 163 | Encodes labels 164 | 165 | :param labels: pd.Series representing the raw labels e.g., Speech, Water 166 | :param encoder (optional): Encoder already fitted 167 | returns encoded labels (many hot) and the encoder 168 | """ 169 | assert isinstance(labels, pd.Series), "Labels need to be series" 170 | if isinstance(labels[0], six.string_types): 171 | # In case of using non processed strings, e.g., Vaccum, Speech 172 | label_array = labels.str.split(',').values.tolist() 173 | elif isinstance(labels[0], np.ndarray): 174 | # Encoder does not like to see numpy array 175 | label_array = [lab.tolist() for lab in labels] 176 | elif isinstance(labels[0], collections.Iterable): 177 | label_array = labels 178 | encoder = pre.MultiLabelBinarizer(sparse_output=sparse) 179 | encoder.fit(label_array) 180 | return encoder 181 | 182 | 183 | def encode_labels(labels: pd.Series, encoder=None, sparse=True): 184 | """encode_labels 185 | 186 | Encodes labels 187 | 188 | :param labels: pd.Series representing the raw labels e.g., Speech, Water 189 | :param encoder (optional): Encoder already fitted 190 | returns encoded labels (many hot) and the encoder 191 | """ 192 | assert isinstance(labels, pd.Series), "Labels need to be series" 193 | instance = labels.iloc[0] 194 | if isinstance(instance, six.string_types): 195 | # In case of using non processed strings, e.g., Vaccum, Speech 196 | label_array = labels.str.split(',').values.tolist() 197 | elif isinstance(instance, np.ndarray): 198 | # Encoder does not like to see numpy array 199 | label_array = [lab.tolist() for lab in labels] 200 | elif isinstance(instance, collections.Iterable): 201 | label_array = labels 202 | if not encoder: 203 | encoder = pre.MultiLabelBinarizer(sparse_output=sparse) 204 | encoder.fit(label_array) 205 | labels_encoded = encoder.transform(label_array) 206 | return labels_encoded, encoder 207 | 208 | # return pd.arrays.SparseArray( 209 | # [row.toarray().ravel() for row in labels_encoded]), encoder 210 | 211 | 212 | def decode_with_timestamps(encoder: pre.MultiLabelBinarizer, labels: np.array): 213 | """decode_with_timestamps 214 | Decodes the predicted label array (2d) into a list of 215 | [(Labelname, onset, offset), ...] 216 | 217 | :param encoder: Encoder during training 218 | :type encoder: pre.MultiLabelBinarizer 219 | :param labels: n-dim array 220 | :type labels: np.array 221 | """ 222 | if labels.ndim == 3: 223 | return [_decode_with_timestamps(encoder, lab) for lab in labels] 224 | else: 225 | return _decode_with_timestamps(encoder, labels) 226 | 227 | 228 | def sma_filter(x, window_size, axis=1): 229 | """sma_filter 230 | 231 | :param x: Input numpy array, 232 | :param window_size: filter size 233 | :param axis: over which axis ( usually time ) to apply 234 | """ 235 | # 1 is time axis 236 | kernel = np.ones((window_size, )) / window_size 237 | 238 | def moving_average(arr): 239 | return np.convolve(arr, kernel, 'same') 240 | 241 | return np.apply_along_axis(moving_average, axis, x) 242 | 243 | 244 | def median_filter(x, window_size, threshold=0.5): 245 | """median_filter 246 | 247 | :param x: input prediction array of shape (B, T, C) or (B, T). 
248 | Input is a sequence of probabilities 0 <= x <= 1 249 | :param window_size: An integer to use 250 | :param threshold: Binary thresholding threshold 251 | """ 252 | x = binarize(x, threshold=threshold) 253 | if x.ndim == 3: 254 | size = (1, window_size, 1) 255 | elif x.ndim == 2 and x.shape[0] == 1: 256 | # Assume input is class-specific median filtering 257 | # E.g, Batch x Time [1, 501] 258 | size = (1, window_size) 259 | elif x.ndim == 2 and x.shape[0] > 1: 260 | # Assume input is standard median pooling, class-independent 261 | # E.g., Time x Class [501, 10] 262 | size = (window_size, 1) 263 | return scipy.ndimage.median_filter(x, size=size) 264 | 265 | 266 | def _decode_with_timestamps(encoder, labels): 267 | result_labels = [] 268 | for i, label_column in enumerate(labels.T): 269 | change_indices = find_contiguous_regions(label_column) 270 | # append [onset, offset] in the result list 271 | for row in change_indices: 272 | result_labels.append((encoder.classes_[i], row[0], row[1])) 273 | return result_labels 274 | 275 | 276 | def inverse_transform_labels(encoder, pred): 277 | if pred.ndim == 3: 278 | return [encoder.inverse_transform(x) for x in pred] 279 | else: 280 | return encoder.inverse_transform(pred) 281 | 282 | 283 | def binarize(pred, threshold=0.5): 284 | # Batch_wise 285 | if pred.ndim == 3: 286 | return np.array( 287 | [pre.binarize(sub, threshold=threshold) for sub in pred]) 288 | else: 289 | return pre.binarize(pred, threshold=threshold) 290 | 291 | 292 | def double_threshold(x, high_thres, low_thres, n_connect=1): 293 | """double_threshold 294 | Helper function to calculate double threshold for n-dim arrays 295 | 296 | :param x: input array 297 | :param high_thres: high threshold value 298 | :param low_thres: Low threshold value 299 | :param n_connect: Distance of <= n clusters will be merged 300 | """ 301 | assert x.ndim <= 3, "Whoops something went wrong with the input ({}), check if its <= 3 dims".format( 302 | x.shape) 303 | if x.ndim == 3: 304 | apply_dim = 1 305 | elif x.ndim < 3: 306 | apply_dim = 0 307 | # x is assumed to be 3d: (batch, time, dim) 308 | # Assumed to be 2d : (time, dim) 309 | # Assumed to be 1d : (time) 310 | # time axis is therefore at 1 for 3d and 0 for 2d ( 311 | return np.apply_along_axis(lambda x: _double_threshold( 312 | x, high_thres, low_thres, n_connect=n_connect), 313 | axis=apply_dim, 314 | arr=x) 315 | 316 | 317 | def _double_threshold(x, high_thres, low_thres, n_connect=1, return_arr=True): 318 | """_double_threshold 319 | Computes a double threshold over the input array 320 | 321 | :param x: input array, needs to be 1d 322 | :param high_thres: High threshold over the array 323 | :param low_thres: Low threshold over the array 324 | :param n_connect: Postprocessing, maximal distance between clusters to connect 325 | :param return_arr: By default this function returns the filtered indiced, but if return_arr = True it returns an array of tsame size as x filled with ones and zeros. 
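Frames above high_thres seed candidate segments; each such segment is kept together with the surrounding contiguous frames that stay above low_thres, and neighbouring segments at most n_connect frames apart are merged.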
326 | """ 327 | assert x.ndim == 1, "Input needs to be 1d" 328 | high_locations = np.where(x > high_thres)[0] 329 | locations = x > low_thres 330 | encoded_pairs = find_contiguous_regions(locations) 331 | 332 | filtered_list = list( 333 | filter( 334 | lambda pair: 335 | ((pair[0] <= high_locations) & (high_locations <= pair[1])).any(), 336 | encoded_pairs)) 337 | 338 | filtered_list = connect_(filtered_list, n_connect) 339 | if return_arr: 340 | zero_one_arr = np.zeros_like(x, dtype=int) 341 | for sl in filtered_list: 342 | zero_one_arr[sl[0]:sl[1]] = 1 343 | return zero_one_arr 344 | return filtered_list 345 | 346 | 347 | def connect_clusters(x, n=1): 348 | if x.ndim == 1: 349 | return connect_clusters_(x, n) 350 | if x.ndim >= 2: 351 | return np.apply_along_axis(lambda a: connect_clusters_(a, n=n), -2, x) 352 | 353 | 354 | def connect_clusters_(x, n=1): 355 | """connect_clusters_ 356 | Connects clustered predictions (0,1) in x with range n 357 | 358 | :param x: Input array. zero-one format 359 | :param n: Number of frames to skip until connection can be made 360 | """ 361 | assert x.ndim == 1, "input needs to be 1d" 362 | reg = find_contiguous_regions(x) 363 | start_end = connect_(reg, n=n) 364 | zero_one_arr = np.zeros_like(x, dtype=int) 365 | for sl in start_end: 366 | zero_one_arr[sl[0]:sl[1]] = 1 367 | return zero_one_arr 368 | 369 | 370 | def connect_(pairs, n=1): 371 | """connect_ 372 | Connects two adjacent clusters if their distance is <= n 373 | 374 | :param pairs: Clusters of iterateables e.g., [(1,5),(7,10)] 375 | :param n: distance between two clusters 376 | """ 377 | if len(pairs) == 0: 378 | return [] 379 | start_, end_ = pairs[0] 380 | new_pairs = [] 381 | for i, (next_item, cur_item) in enumerate(zip(pairs[1:], pairs[0:])): 382 | end_ = next_item[1] 383 | if next_item[0] - cur_item[1] <= n: 384 | pass 385 | else: 386 | new_pairs.append((start_, cur_item[1])) 387 | start_ = next_item[0] 388 | new_pairs.append((start_, end_)) 389 | return new_pairs 390 | 391 | 392 | def predictions_to_time(df, ratio): 393 | df.onset = df.onset * ratio 394 | df.offset = df.offset * ratio 395 | return df 396 | 397 | 398 | def estimate_scaler(dataloader, **scaler_args): 399 | 400 | scaler = pre.StandardScaler(**scaler_args) 401 | with tqdm.tqdm(total=len(dataloader), 402 | unit='batch', 403 | leave=False, 404 | desc='Estimating Scaler') as pbar: 405 | for batch in dataloader: 406 | feature = batch[0] 407 | # Flatten time and batch dim to one 408 | feature = feature.reshape(-1, feature.shape[-1]) 409 | pbar.set_postfix(feature=feature.shape) 410 | pbar.update() 411 | scaler.partial_fit(feature) 412 | return scaler 413 | 414 | 415 | def rescale_0_1(x): 416 | if x.ndim == 2: 417 | return pre.minmax_scale(x, axis=0) 418 | else: 419 | 420 | def min_max_scale(a): 421 | return pre.minmax_scale(a, axis=0) 422 | 423 | def df_to_dict(df, index='filename', value='hdf5path'): 424 | return dict(zip(df[index],df[value])) 425 | --------------------------------------------------------------------------------
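As a usage illustration of the post-processing utilities above, the following minimal sketch applies the default double threshold from `evaluate()` (`threshold=(0.5, 0.1)`) to a toy frame-level speech probability curve and converts the resulting 0-1 flags back into onset/offset pairs. It assumes the script is run from the repository root with the listed requirements installed; the probability values are made up purely for illustration.

```python
import numpy as np

import utils  # repository module shown above

# Toy frame-level speech probabilities for one clip
# (one value per 20 ms frame, matching model_resolution=0.02 in evaluate()).
probs = np.array([0.05, 0.2, 0.6, 0.9, 0.8, 0.3, 0.15, 0.7, 0.4, 0.1])

# Frames above 0.5 seed a speech segment, which is kept together with all
# contiguous neighbouring frames above 0.1.
hard = utils.double_threshold(probs, high_thres=0.5, low_thres=0.1)
print(hard)  # -> [0 1 1 1 1 1 1 1 1 0]

# Turn the 0-1 flags into (onset, offset) frame indices, then into seconds.
regions = utils.find_contiguous_regions(hard.astype(bool))
print(regions * 0.02)  # -> [[0.02 0.18]]
```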