├── LICENSE
├── README.md
├── augment.py
├── configs
│   └── example.yaml
├── data
│   ├── extract_feature.py
│   ├── extract_feature_for_train.py
│   └── prepare_labels.py
├── dataset.py
├── example
│   └── example.wav
├── figs
│   ├── data_driven_framework.png
│   ├── sample_background.png
│   ├── sample_music.png
│   ├── sample_speech.png
│   ├── samples_1.png
│   ├── samples_2.png
│   ├── samples_3.png
│   └── samples_4.png
├── forward.py
├── labelencoders
│   └── vad.pth
├── losses.py
├── metrics.py
├── models.py
├── pprint_results.py
├── pretrained_models
│   ├── audio2_vox2
│   │   └── model.pth
│   ├── audioset2
│   │   └── model.pth
│   ├── c1
│   │   └── model.pth
│   ├── labelencoders
│   │   ├── students.pth
│   │   └── teacher.pth
│   ├── sre
│   │   └── model.pth
│   ├── teacher1
│   │   └── model.pth
│   ├── teacher2
│   │   └── model.pth
│   └── vox2
│       └── model.pth
├── requirements.txt
├── run.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Heinrich Dinkel
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Data driven GPVAD
2 | Repository for the work in TASLP 2021 [Voice activity detection in the wild: A data-driven approach using teacher-student training](https://arxiv.org/abs/2105.04065).
3 |
4 |
5 | 
6 |
7 |
8 | ## Sample predictions against other methods
9 |
10 | 
11 |
12 | 
13 |
14 | 
15 |
16 | 
17 |
18 | ## Noise robustness
19 |
20 | 
21 |
22 | 
23 |
24 | 
25 |
26 | ## Results
27 |
28 | Our best model trained on the SRE (V3) dataset obtains the following results:
29 |
30 | | | Precision | Recall | F1 | AUC | FER | Event-F1 |
31 | |:-------------|------------:|---------:|-------:|------:|------:|-----------:|
32 | | aurora_clean | 96.844 | 95.102 | 95.93 | 98.66 | 3.06 | 74.8 |
33 | | aurora_noisy | 90.435 | 92.871 | 91.544 | 97.63 | 6.68 | 54.45 |
34 | | dcase18 | 89.202 | 88.362 | 88.717 | 95.2 | 10.82 | 57.85 |
35 |
36 | ## Usage
37 |
38 | We provide most of our pretrained models in this repository, including:
39 |
40 | 1. Both teachers (T_1, T_2)
41 | 2. Unbalanced audioset pretrained model
42 | 3. Voxceleb 2 pretrained model
43 | 4. Our best submission (SRE V3 trained)
44 |
45 | To download the repository and run an example evaluation, just do:
46 |
47 | ```bash
48 | git clone https://github.com/RicherMans/Datadriven-VAD
49 | cd Datadriven-VAD
50 | pip3 install -r requirements.txt
51 | python3 forward.py -w example/example.wav
52 | ```
53 |
54 | Running this will print:
55 |
56 | ```
57 | | index | event_label | onset | offset | filename |
58 | |--------:|:--------------|--------:|---------:|:--------------------|
59 | | 0 | Speech | 0.28 | 0.94 | example/example.wav |
60 | | 1 | Speech | 1.04 | 2.22 | example/example.wav |
61 | ```
62 |
63 | ### Predicting voice activity
64 |
65 | Our script supports both single-file input and batched processing via a filelist.
66 | Obtaining VAD predictions for a single file is easy:
67 |
68 | ```bash
69 | python3 forward.py -w example/example.wav
70 | ```
71 |
72 | Or, if one prefers to do that batch-wise, first prepare a filelist:
73 | `find . -type f -name '*.wav' > wavlist.txt`
74 | And then just run:
75 | ```bash
76 | python3 forward.py -l wavlist.txt
77 | ```
78 |
79 |
80 | #### Extra parameters
81 |
82 | * `-model` selects the pretrained model. Can be one of `t1,t2,v2,a2,a2_v2,sre`. Refer to the paper for details on each model. By default we use `sre`.
83 | * `-soft` outputs the raw frame-level speech probabilities instead of human-readable timestamps.
84 | * `-hard` outputs post-processed 0-1 flags indicating speech instead of human-readable timestamps. Please note that this differs from the paper, which thresholded the soft probabilities without post-processing.
85 | * `-th` adjusts the threshold(s). If a single threshold is passed (e.g., `-th 0.5`), simple binarization is used; otherwise the default double threshold `-th 0.5 0.1` is applied (see the sketch below).
86 | * `-o` writes the results to the given output folder.
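
For reference, the double threshold works like a hysteresis filter on the frame-level speech probabilities: frames above the high threshold seed speech segments, and each segment is then extended to all connected frames above the low threshold. Below is a minimal sketch of the idea (not the repository's exact `utils.double_threshold` implementation, which may differ in details):

```python
import numpy as np
import scipy.ndimage


def double_threshold_sketch(probs, high=0.5, low=0.1):
    """probs: (time,) speech probabilities in [0, 1]. Returns 0/1 speech flags."""
    seeds = probs > high        # confident speech frames
    candidates = probs > low    # frames that may belong to a speech segment
    # Keep only the connected candidate regions that contain at least one seed.
    regions, n_regions = scipy.ndimage.label(candidates)
    keep = np.zeros_like(candidates)
    for region_id in range(1, n_regions + 1):
        mask = regions == region_id
        if seeds[mask].any():
            keep |= mask
    return keep.astype(np.int64)
```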
87 |
88 |
89 | ## Training from scratch
90 |
91 | If you intend to rerun our work, prepare some data and extract log-Mel spectrogram features.
92 | Say, you have downloaded the [balanced](http://storage.googleapis.com/us_audioset/youtube_corpus/v1/csv/balanced_train_segments.csv) subset of AudioSet and stored all files in a folder `data/balanced/`. Then:
93 |
94 | ```bash
95 | cd data;
96 | mkdir hdf5 csv_labels;
97 | find balanced -type f > wavs.txt;
98 | python3 extract_feature.py wavs.txt -o hdf5/balanced.h5
99 | h5ls -r hdf5/balanced.h5 | awk -F[/' '] 'BEGIN{print "filename","hdf5path"}NR>1{print $2,"hdf5/balanced.h5"}'> csv_labels/balanced.csv
100 | ```
101 |
102 |
103 | The input for our label prediction script is a csv file with exactly two columns, `filename` and `hdf5path`.
104 |
105 | An example `csv_labels/balanced.csv` would be:
106 |
107 | ```
108 | filename hdf5path
109 | --PJHxphWEs_30.000.wav hdf5/balanced.h5
110 | --ZhevVpy1s_50.000.wav hdf5/balanced.h5
111 | --aE2O5G5WE_0.000.wav hdf5/balanced.h5
112 | --aO5cdqSAg_30.000.wav hdf5/balanced.h5
113 | ```
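
If `h5ls` or `awk` are not available, an equivalent csv can also be produced with a few lines of Python (a small sketch; it assumes the `hdf5/balanced.h5` file created above and writes the same space-separated two-column format):

```python
import h5py
import pandas as pd

# Every key in the feature hdf5 is one (base)filename; map each one to the hdf5 file.
with h5py.File('hdf5/balanced.h5', 'r') as store:
    filenames = list(store.keys())
pd.DataFrame({
    'filename': filenames,
    'hdf5path': 'hdf5/balanced.h5'
}).to_csv('csv_labels/balanced.csv', sep=' ', index=False)
```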
114 |
115 | After feature extraction, proceed to predict labels:
116 |
117 | ```bash
118 | mkdir -p softlabels/{hdf5,csv};
119 | python3 prepare_labels.py --pre ../pretrained_models/teacher1/model.pth csv_labels/balanced.csv softlabels/hdf5/balanced.h5 softlabels/csv/balanced.csv
120 | ```
121 |
122 | Lastly, just train:
123 |
124 | ```bash
125 | cd ../; # Go to the project root
126 | # Adjust the config according to your input data
127 | python3 run.py train configs/example.yaml
128 | ```
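
The `data` and `label` entries in `configs/example.yaml` already point to the files produced in the steps above (paths relative to the project root):

```yaml
data: data/csv_labels/balanced.csv
label: data/softlabels/csv/balanced.csv
```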
129 |
130 | ## Citation
131 |
132 | If you're using this work, please cite it in your publications.
133 |
134 | ```
135 | @article{Dinkel2021,
136 | author = {Dinkel, Heinrich and Wang, Shuai and Xu, Xuenan and Wu, Mengyue and Yu, Kai},
137 | doi = {10.1109/TASLP.2021.3073596},
138 | issn = {2329-9290},
139 | journal = {IEEE/ACM Transactions on Audio, Speech, and Language Processing},
140 | pages = {1542--1555},
141 | title = {{Voice Activity Detection in the Wild: A Data-Driven Approach Using Teacher-Student Training}},
142 | url = {https://ieeexplore.ieee.org/document/9405474/},
143 | volume = {29},
144 | year = {2021}
145 | }
146 | ```
147 | and
148 | ```
149 | @inproceedings{Dinkel2020,
150 | author={Heinrich Dinkel and Yefei Chen and Mengyue Wu and Kai Yu},
151 | title={{Voice Activity Detection in the Wild via Weakly Supervised Sound Event Detection}},
152 | year=2020,
153 | booktitle={Proc. Interspeech 2020},
154 | pages={3665--3669},
155 | doi={10.21437/Interspeech.2020-0995},
156 | url={http://dx.doi.org/10.21437/Interspeech.2020-0995}
157 | }
158 | ```
159 |
160 |
--------------------------------------------------------------------------------
/augment.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import torch.nn as nn
4 | import numpy as np
5 |
6 |
7 | class RandomPad(nn.Module):
8 | """docstring for RandomPad"""
9 | def __init__(self, value=0., padding=0):
10 | super().__init__()
11 | self.value = value
12 | self.padding = padding
13 |
14 | def forward(self, x):
15 | if self.training and self.padding > 0:
16 | left_right = torch.empty(2).random_(self.padding).int().numpy()
17 | topad = (0, 0, *left_right)
18 | x = nn.functional.pad(x, topad, value=self.value)
19 | return x
20 |
21 |
22 | class Roll(nn.Module):
23 | """docstring for Roll"""
24 | def __init__(self, mean, std):
25 | super().__init__()
26 | self.mean = mean
27 | self.std = std
28 |
29 | def forward(self, x):
30 | if self.training:
31 | shift = torch.empty(1).normal_(self.mean, self.std).int().item()
32 | x = torch.roll(x, shift, dims=0)
33 | return x
34 |
35 |
36 | class RandomCrop(nn.Module):
37 | """docstring for RandomPad"""
38 | def __init__(self, size: int = 100):
39 | super().__init__()
40 | self.size = int(size)
41 |
42 | def forward(self, x):
43 | if self.training:
44 | time, freq = x.shape
45 | if time < self.size:
46 | return x
47 | hi = time - self.size
48 | start_ind = torch.empty(1, dtype=torch.long).random_(0, hi).item()
49 | x = x[start_ind:start_ind + self.size, :]
50 | return x
51 |
52 |
53 | class TimeMask(nn.Module):
54 | def __init__(self, n=1, p=50):
55 | super().__init__()
56 | self.p = p
57 | self.n = n
58 |
59 | def forward(self, x):
60 | time, freq = x.shape
61 | if self.training:
62 | for i in range(self.n):
63 | t = torch.empty(1, dtype=int).random_(self.p).item()
64 | to_sample = max(time - t, 1)
65 | t0 = torch.empty(1, dtype=int).random_(to_sample).item()
66 | x[t0:t0 + t, :] = 0
67 | return x
68 |
69 |
70 | class FreqMask(nn.Module):
71 | def __init__(self, n=1, p=12):
72 | super().__init__()
73 | self.p = p
74 | self.n = n
75 |
76 | def forward(self, x):
77 | time, freq = x.shape
78 | if self.training:
79 | for i in range(self.n):
80 | f = torch.empty(1, dtype=int).random_(self.p).item()
81 | f0 = torch.empty(1, dtype=int).random_(freq - f).item()
82 | x[:, f0:f0 + f] = 0.
83 | return x
84 |
85 |
86 | class GaussianNoise(nn.Module):
87 | """docstring for Gaussian"""
88 | def __init__(self, snr=30, mean=0):
89 | super().__init__()
90 | self._mean = mean
91 | self._snr = snr
92 |
93 | def forward(self, x):
94 | if self.training:
95 | E_x = (x**2).sum()/x.shape[0]
96 | noise = torch.empty_like(x).normal_(self._mean, std=1)
97 | E_noise = (noise**2).sum()/noise.shape[0]
98 | alpha = np.sqrt(E_x / (E_noise * pow(10, self._snr / 10)))
99 | x = x + alpha * noise
100 | return x
101 |
102 |
103 | class Shift(nn.Module):
104 | """
105 | Randomly shift audio in time by up to `shift` samples.
106 | """
107 | def __init__(self, shift=4000):
108 | super().__init__()
109 | self.shift = shift
110 |
111 | def forward(self, wav):
112 | time, channels = wav.size()
113 | length = time - self.shift
114 | if self.shift > 0:
115 | if not self.training:
116 | wav = wav[..., :length]
117 | else:
118 | offset = torch.randint(self.shift, [channels, 1],
119 | device=wav.device)
120 | indexes = torch.arange(length, device=wav.device)
121 | offset = indexes + offset
122 | wav = wav.gather(0, offset.transpose(0, 1))
123 | return wav
124 |
125 |
126 | class FlipSign(nn.Module):
127 | """
128 | Random sign flip.
129 | """
130 | def forward(self, wav):
131 | time, channels = wav.size()
132 | if self.training:
133 | signs = torch.randint(2, (1, channels),
134 | device=wav.device,
135 | dtype=torch.float32)
136 | wav = wav * (2 * signs - 1)
137 | return wav
138 |
139 |
140 | if __name__ == "__main__":
141 | x = torch.randn(1, 10)
142 | y = GaussianNoise(10)(x)
143 | print(x)
144 | print(y)
145 |
--------------------------------------------------------------------------------
/configs/example.yaml:
--------------------------------------------------------------------------------
1 | data: data/csv_labels/balanced.csv
2 | label: data/softlabels/csv/balanced.csv
3 | batch_size: 64
4 | data_args:
5 | mode: Null
6 | num_workers: 8
7 | optimizer: AdamW
8 | optimizer_args:
9 | lr: 0.001
10 | scheduler_args:
11 | patience: 10
12 | factor: 0.1
13 | early_stop: 15
14 | epochs: 15
15 | itercv: 10000
16 | save: best
17 | model: CRNN
18 | model_args: {}
19 | outputpath: experiments/
20 | transforms: [timemask, freqmask]
21 | loss: FrameBCELoss
22 |
--------------------------------------------------------------------------------
/data/extract_feature.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import librosa
4 | from tqdm import tqdm
5 | import io
6 | import logging
7 | from pathlib import Path
8 | import pandas as pd
9 | import numpy as np
10 | import soundfile as sf
11 | from pypeln import process as pr
12 | import gzip
13 | import h5py
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('input_csv')
17 | parser.add_argument('-o', '--output', type=str, required=True)
18 | parser.add_argument('-c', type=int, default=4)
19 | parser.add_argument('-sr', type=int, default=22050)
20 | parser.add_argument('-col',
21 | default='filename',
22 | type=str,
23 | help='Column to search for audio files')
24 | parser.add_argument('-cmn', default=False, action='store_true')
25 | parser.add_argument('-cvn', default=False, action='store_true')
26 | parser.add_argument('-winlen',
27 | default=40,
28 | type=float,
29 | help='FFT duration in ms')
30 | parser.add_argument('-hoplen',
31 | default=20,
32 | type=float,
33 | help='hop duration in ms')
34 |
35 | parser.add_argument('-n_mels', default=64, type=int)
36 | ARGS = parser.parse_args()
37 |
38 | DF = pd.read_csv(ARGS.input_csv, sep='\t',
39 | usecols=[0]) # only read first cols, allows to have messy csv
40 |
41 | MEL_ARGS = {
42 | 'n_mels': ARGS.n_mels,
43 | 'n_fft': 2048,
44 | 'hop_length': int(ARGS.sr * ARGS.hoplen / 1000),
45 | 'win_length': int(ARGS.sr * ARGS.winlen / 1000)
46 | }
47 |
48 | EPS = np.spacing(1)
49 |
50 |
51 | def extract_feature(fname):
52 | """extract_feature
53 | Extracts a log mel spectrogram feature from a filename, currently supports two filetypes:
54 |
55 | 1. Wave
56 | 2. Gzipped wave
57 |
58 | :param fname: filepath to the file to extract
59 | """
60 | ext = Path(fname).suffix
61 | try:
62 | if ext == '.gz':
63 | with gzip.open(fname, 'rb') as gzipped_wav:
64 | y, sr = sf.read(io.BytesIO(gzipped_wav.read()),
65 | dtype='float32')
66 | # Multiple channels, reduce
67 | if y.ndim == 2:
68 | y = y.mean(1)
69 | y = librosa.resample(y, sr, ARGS.sr)
70 | elif ext in ('.wav', '.flac'):
71 | y, sr = sf.read(fname, dtype='float32')
72 | if y.ndim > 1:
73 | y = y.mean(1)
74 | y = librosa.resample(y, sr, ARGS.sr)
75 | except Exception as e:
77 | # Exception usually happens because some data has 6 channels, which librosa can't handle
77 | logging.error(e)
78 | logging.error(fname)
79 | raise
80 | lms_feature = np.log(librosa.feature.melspectrogram(y, **MEL_ARGS) + EPS).T
81 | return fname, lms_feature
82 |
83 |
84 | with h5py.File(ARGS.output, 'w') as store:
85 | for fname, feat in tqdm(pr.map(extract_feature,
86 | DF[ARGS.col].unique(),
87 | workers=ARGS.c,
88 | maxsize=4),
89 | total=len(DF[ARGS.col].unique())):
90 | basename = Path(fname).name
91 | store[basename] = feat
92 |
--------------------------------------------------------------------------------
/data/extract_feature_for_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import librosa
4 | from tqdm import tqdm
5 | import io
6 | from pathlib import Path
7 | from loguru import logger
8 | import pandas as pd
9 | import numpy as np
10 | import soundfile as sf
11 | from pypeln import process as pr
12 | import h5py
13 | import gzip
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('input_csv')
17 | parser.add_argument('-o', '--output', type=str, required=True)
18 | parser.add_argument('-l',
19 | '--label',
20 | type=str,
21 | required=True,
22 | help="Output label(chunked)")
23 | parser.add_argument('-s',
24 | '--size',
25 | type=float,
26 | default=10,
27 | help="Length of each segment")
28 | parser.add_argument('-t',
29 | '--threshold',
30 | type=float,
31 | default=1,
32 | help='Do not save files less than -t seconds',
33 | metavar='s')
34 | parser.add_argument('-c', type=int, default=4)
35 | parser.add_argument('-sr', type=int, default=22050)
36 | parser.add_argument('-col',
37 | default='filename',
38 | type=str,
39 | help='Column to search for audio files')
40 | parser.add_argument('-cmn', default=False, action='store_true')
41 | parser.add_argument('-cvn', default=False, action='store_true')
42 | parser.add_argument('-winlen',
43 | default=40,
44 | type=float,
45 | help='FFT duration in ms')
46 | parser.add_argument('-hoplen',
47 | default=20,
48 | type=float,
49 | help='hop duration in ms')
50 |
51 | parser.add_argument('-n_mels', default=64, type=int)
52 | ARGS = parser.parse_args()
53 |
54 | DF = pd.read_csv(ARGS.input_csv, usecols=[0], sep=' ')
55 |
56 | MEL_ARGS = {
57 | 'n_mels': ARGS.n_mels,
58 | 'n_fft': 2048,
59 | 'hop_length': int(ARGS.sr * ARGS.hoplen / 1000),
60 | 'win_length': int(ARGS.sr * ARGS.winlen / 1000)
61 | }
62 |
63 | EPS = np.spacing(1)
64 | DURATION_CHUNK = ARGS.size / (ARGS.hoplen / 1000)
65 | THRESHOLD = ARGS.threshold / (ARGS.hoplen / 1000)
66 |
67 |
68 | def extract_feature(fname):
69 | # def extract_feature(fname, segfname, start, end, nseg):
70 | """extract_feature
71 | Extracts a log mel spectrogram feature from a filename, currently supports two filetypes:
72 |
73 | 1. Wave
74 | 2. Gzipped wave
75 |
76 | :param fname: filepath to the file to extract
77 | """
78 | pospath = Path(fname)
79 | ext = pospath.suffix
80 | try:
81 | if ext == '.gz':
82 | with gzip.open(fname, 'rb') as gzipped_wav:
83 | y, sr = sf.read(io.BytesIO(gzipped_wav.read()),
84 | dtype='float32')
85 | # Multiple channels, reduce
86 | if y.ndim == 2:
87 | y = y.mean(1)
88 | y = librosa.resample(y, sr, ARGS.sr)
89 | elif ext in ('.wav', '.flac'):
90 | y, sr = sf.read(fname, dtype='float32')
91 | if y.ndim > 1:
92 | y = y.mean(1)
93 | y = librosa.resample(y, sr, ARGS.sr)
94 | except Exception as e:
96 | # Exception usually happens because some data has 6 channels, which librosa can't handle
96 | logger.error(e)
97 | logger.error(fname)
98 | raise
99 | fname = pospath.name
100 | feat = np.log(librosa.feature.melspectrogram(y, **MEL_ARGS) + EPS).T
101 | start_range = np.arange(0, feat.shape[0], DURATION_CHUNK, dtype=int)
102 | end_range = (start_range + DURATION_CHUNK).astype(int)
103 | end_range[-1] = feat.shape[0]
104 | for nseg, (start_time, end_time) in enumerate(zip(start_range, end_range)):
105 | seg = feat[start_time:end_time]
106 | if end_time - start_time < THRESHOLD:
107 | # Dont save
108 | continue
109 | yield fname, seg, nseg
110 |
111 |
112 | with h5py.File(ARGS.output, 'w') as store, tqdm() as pbar, open(ARGS.label,'w') as output_csv:
113 | output_csv.write(f"filename hdf5path\n") #write header
114 | hdf5_path = Path(ARGS.output).absolute()
115 | for fname, feat, nseg in pr.flat_map(extract_feature,
116 | DF['filename'].unique(),
117 | workers=ARGS.c,
118 | maxsize=ARGS.c * 2):
119 | new_fname = f"{Path(fname).stem}_{nseg:05d}{Path(fname).suffix}"
120 | store[new_fname] = feat
121 | output_csv.write(f"{new_fname} {hdf5_path}\n")
122 | pbar.set_postfix(stored=new_fname, shape=feat.shape)
123 | pbar.update()
124 |
--------------------------------------------------------------------------------
/data/prepare_labels.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | import numpy as np
4 | import argparse
5 | from h5py import File
6 | from pathlib import Path
7 | from loguru import logger
8 | import torch.utils.data as tdata
9 | from tqdm import tqdm
10 | from models import crnn, cnn10
11 | import sys
12 | import csv
13 |
14 |
15 | class HDF5Dataset(tdata.Dataset):
16 | """
17 | HDF5 dataset indexed by a labels dataframe.
18 | Indexing is done via the dataframe since we want to preserve some storage
19 | in cases where oversampling is needed ( pretty likely )
20 | """
21 | def __init__(self, h5file: File, transform=None):
22 | super(HDF5Dataset, self).__init__()
23 | self._h5file = h5file
24 | self.dataset = None
25 | # IF none is passed still use no transform at all
26 | self._transform = transform
27 | with File(self._h5file, 'r') as store:
28 | self._len = len(store)
29 | self._labels = list(store.keys())
30 | self.datadim = store[self._labels[0]].shape[-1]
31 |
32 | def __len__(self):
33 | return self._len
34 |
35 | def __getitem__(self, index):
36 | if self.dataset is None:
37 | self.dataset = File(self._h5file, 'r')
38 | fname = self._labels[index]
39 | data = self.dataset[fname][()]
40 | data = torch.as_tensor(data).float()
41 | if self._transform:
42 | data = self._transform(data)
43 | return data, fname
44 |
45 |
46 | MODELS = {
47 | 'crnn': {
48 | 'model': crnn,
49 | 'encoder': torch.load('encoders/balanced.pth'),
50 | 'outputdim': 527,
51 | },
52 | 'gpvb': {
53 | 'model': crnn,
54 | 'encoder': torch.load('encoders/balanced_binary.pth'),
55 | 'outputdim': 2,
56 | }
57 | }
58 |
59 | POOLING = {
60 | 'max': lambda x: np.max(x, axis=-1),
61 | 'mean': lambda x: np.mean(x, axis=-1)
62 | }
63 |
64 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
65 | DEVICE = torch.device(DEVICE)
66 |
67 |
68 | def main():
69 | parser = argparse.ArgumentParser()
70 | parser.add_argument('data', type=Path)
71 | parser.add_argument('-m', '--model', default='crnn', type=str)
72 | parser.add_argument('-po',
73 | '--pool',
74 | default='max',
75 | choices=POOLING.keys(),
76 | type=str)
77 | parser.add_argument('--pre', '-p', default='pretrained/gpv_f.pth')
78 | parser.add_argument('hdf5output', type=Path)
79 | parser.add_argument('csvoutput', type=Path)
80 | args = parser.parse_args()
81 |
82 | log_format = "[{time:YYYY-MM-DD HH:mm:ss}] {message}"
83 | logger.configure(handlers=[{"sink": sys.stderr, "format": log_format}])
84 |
85 | for k, v in vars(args).items():
86 | logger.info(f"{k} : {v}")
87 |
88 | model_dict = MODELS[args.model]
89 | model = model_dict['model'](outputdim=model_dict['outputdim'],
90 | pretrained_from=args.pre).to(DEVICE).eval()
91 | encoder = model_dict['encoder']
92 | logger.info(model)
93 | pooling_fun = POOLING[args.pool]
94 | if Path(args.data).suffix == '.csv':
95 | data = pd.read_csv(args.data, sep='\s+')
96 | data = data['hdf5path'].unique()
97 | assert len(data) == 1, "Only single hdf5 supported yet"
98 | data = data[0]
99 | else: #h5 file directly
100 | data = args.data
101 |
102 | logger.info(f"Reading from input file {data}")
103 | dataloader = tdata.DataLoader(HDF5Dataset(data),
104 | num_workers=4,
105 | batch_size=1)
106 | speech_class_idx = np.where(encoder.classes_ == 'Speech')[0]
107 | non_speech_idx = np.arange(len(encoder.classes_))
108 | non_speech_idx = np.delete(non_speech_idx, speech_class_idx)
109 | with torch.no_grad(), File(args.hdf5output, 'w') as store, tqdm(
110 | total=len(dataloader)) as pbar, open(args.csvoutput,
111 | 'w') as csvfile:
112 | abs_output_hdf5 = Path(args.hdf5output).absolute()
113 | csvwr = csv.writer(csvfile, delimiter='\t')
114 | csvwr.writerow(['filename', 'hdf5path'])
115 | for batch in dataloader:
116 | x, fname = batch
117 | fname = fname[0]
118 | x = x.to(DEVICE)
119 | if x.shape[1] < 8:
120 | continue
121 | clip_pred, time_pred = model(x)
122 | clip_pred = clip_pred.squeeze(0).to('cpu').numpy()
123 | time_pred = time_pred.squeeze(0).to('cpu').numpy()
124 | speech_time_pred = time_pred[..., speech_class_idx].squeeze(-1)
125 | speech_clip_pred = clip_pred[..., speech_class_idx].squeeze(-1)
126 | non_speech_clip_pred = clip_pred[..., non_speech_idx]
127 | non_speech_time_pred = time_pred[..., non_speech_idx]
128 | non_speech_time_pred = pooling_fun(non_speech_time_pred)
129 | store[f'{fname}/speech'] = speech_time_pred
130 | store[f'{fname}/noise'] = non_speech_time_pred
131 | store[f'{fname}/clipspeech'] = speech_clip_pred
132 | store[f'{fname}/clipnoise'] = non_speech_clip_pred
133 | csvwr.writerow([fname, abs_output_hdf5])
134 | pbar.set_postfix(fname=fname, speechsize=speech_time_pred.shape)
135 | pbar.update()
136 |
137 |
138 | if __name__ == "__main__":
139 | main()
140 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import numpy as np
4 | import pandas as pd
5 | import scipy
6 | from h5py import File
7 | import itertools, random
8 | from tqdm import tqdm
9 | from loguru import logger
10 | import torch.utils.data as tdata
11 | from typing import List, Dict
12 |
13 |
14 | class TrainHDF5Dataset(tdata.Dataset):
15 | """
16 | HDF5 dataset indexed by a labels dataframe.
17 | Indexing is done via the dataframe since we want to preserve some storage
18 | in cases where oversampling is needed ( pretty likely )
19 | """
20 | def __init__(self,
21 | h5filedict: Dict,
22 | h5labeldict: Dict,
23 | label_type='soft',
24 | transform=None):
25 | super(TrainHDF5Dataset, self).__init__()
26 | self._h5filedict = h5filedict
27 | self._h5labeldict = h5labeldict
28 | self._datasetcache = {}
29 | self._labelcache = {}
30 | self._len = len(self._h5labeldict)
31 | # IF none is passed still use no transform at all
32 | self._transform = transform
33 | assert label_type in ('soft', 'hard', 'softhard', 'hardnoise')
34 | self._label_type = label_type
35 |
36 | self.idx_to_item = {
37 | idx: item
38 | for idx, item in enumerate(self._h5labeldict.keys())
39 | }
40 | first_item = next(iter(self._h5filedict.keys()))
41 | with File(self._h5filedict[first_item], 'r') as store:
42 | self.datadim = store[first_item].shape[-1]
43 |
44 | def __len__(self):
45 | return self._len
46 |
47 | def __del__(self):
48 | for k, cache in self._datasetcache.items():
49 | cache.close()
50 | for k, cache in self._labelcache.items():
51 | cache.close()
52 |
53 | def __getitem__(self, index: int):
54 | fname: str = self.idx_to_item[index]
55 | h5file: str = self._h5filedict[fname]
56 | labelh5file: str = self._h5labeldict[fname]
57 | if not h5file in self._datasetcache:
58 | self._datasetcache[h5file] = File(h5file, 'r')
59 | if not labelh5file in self._labelcache:
60 | self._labelcache[labelh5file] = File(labelh5file, 'r')
61 |
62 | data = self._datasetcache[h5file][f"{fname}"][()]
63 | speech_target = self._labelcache[labelh5file][f"{fname}/speech"][()]
64 | noise_target = self._labelcache[labelh5file][f"{fname}/noise"][()]
65 | speech_clip_target = self._labelcache[labelh5file][
66 | f"{fname}/clipspeech"][()]
67 | noise_clip_target = self._labelcache[labelh5file][
68 | f"{fname}/clipnoise"][()]
69 |
70 | noise_clip_target = np.max(noise_clip_target) # take max around axis
71 | if self._label_type == 'hard':
72 | noise_clip_target = noise_clip_target.round()
73 | speech_target = speech_target.round()
74 | noise_target = noise_target.round()
75 | speech_clip_target = speech_clip_target.round()
76 | elif self._label_type == 'hardnoise': # only noise yay
77 | noise_clip_target = noise_clip_target.round()
78 | noise_target = noise_target.round()
79 | elif self._label_type == 'softhard':
80 | r = np.random.permutation(noise_target.shape[0] // 4)
81 | speech_target[r] = speech_target[r].round()
82 | target_clip = torch.tensor((noise_clip_target, speech_clip_target))
83 | data = torch.as_tensor(data).float()
84 | target_time = torch.as_tensor(
85 | np.stack((noise_target, speech_target), axis=-1)).float()
86 | if self._transform:
87 | data = self._transform(data)
88 | return data, target_time, target_clip, fname
89 |
90 |
91 | class HDF5Dataset(tdata.Dataset):
92 | """
93 | HDF5 dataset indexed by a labels dataframe.
94 | Indexing is done via the dataframe since we want to preserve some storage
95 | in cases where oversampling is needed ( pretty likely )
96 | """
97 | def __init__(self, h5file: File, h5label: File, fnames, transform=None):
98 | super(HDF5Dataset, self).__init__()
99 | self._h5file = h5file
100 | self._h5label = h5label
101 | self.fnames = fnames
102 | self.dataset = None
103 | self.label_dataset = None
104 | self._len = len(fnames)
105 | # IF none is passed still use no transform at all
106 | self._transform = transform
107 | with File(self._h5file, 'r') as store, File(self._h5label,
108 | 'r') as labelstore:
109 | self.datadim = store[self.fnames[0]].shape[-1]
110 |
111 | def __len__(self):
112 | return self._len
113 |
114 | def __getitem__(self, index):
115 | if self.dataset is None:
116 | self.dataset = File(self._h5file, 'r')
117 | self.label_dataset = File(self._h5label, 'r')
118 | fname = self.fnames[index]
119 | data = self.dataset[fname][()]
120 | speech_target = self.label_dataset[f"{fname}/speech"][()]
121 | noise_target = self.label_dataset[f"{fname}/noise"][()]
122 | speech_clip_target = self.label_dataset[f"{fname}/clipspeech"][()]
123 | noise_clip_target = self.label_dataset[f"{fname}/clipnoise"][()]
124 | noise_clip_target = np.max(noise_clip_target) # take max around axis
125 | target_clip = torch.tensor((noise_clip_target, speech_clip_target))
126 | data = torch.as_tensor(data).float()
127 | target_time = torch.as_tensor(
128 | np.stack((noise_target, speech_target), axis=-1)).float()
129 | if self._transform:
130 | data = self._transform(data)
131 | return data, target_time, target_clip, fname
132 |
133 |
134 | class EvalH5Dataset(tdata.Dataset):
135 | """
136 | HDF5 dataset indexed by a labels dataframe.
137 | Indexing is done via the dataframe since we want to preserve some storage
138 | in cases where oversampling is needed ( pretty likely )
139 | """
140 | def __init__(self, h5file: File, fnames=None):
141 | super(EvalH5Dataset, self).__init__()
142 | self._h5file = h5file
143 | self._dataset = None
144 | # IF none is passed still use no transform at all
145 | with File(self._h5file, 'r') as store:
146 | if fnames is None:
147 | self.fnames = list(store.keys())
148 | else:
149 | self.fnames = fnames
150 | self.datadim = store[self.fnames[0]].shape[-1]
151 | self._len = len(store)
152 |
153 | def __len__(self):
154 | return self._len
155 |
156 | def __getitem__(self, index):
157 | if self._dataset is None:
158 | self._dataset = File(self._h5file, 'r')
159 | fname = self.fnames[index]
160 | data = self._dataset[fname][()]
161 | data = torch.as_tensor(data).float()
162 | return data, fname
163 |
164 |
165 | class MinimumOccupancySampler(tdata.Sampler):
166 | """
167 | docstring for MinimumOccupancySampler
168 | samples at least one instance from each class sequentially
169 | """
170 | def __init__(self, labels, sampling_mode='same', random_state=None):
171 | self.labels = labels
172 | data_samples, n_labels = labels.shape
173 | label_to_idx_list, label_to_length = [], []
174 | self.random_state = np.random.RandomState(seed=random_state)
175 | for lb_idx in range(n_labels):
176 | label_selection = labels[:, lb_idx]
177 | if scipy.sparse.issparse(label_selection):
178 | label_selection = label_selection.toarray()
179 | label_indexes = np.where(label_selection == 1)[0]
180 | self.random_state.shuffle(label_indexes)
181 | label_to_length.append(len(label_indexes))
182 | label_to_idx_list.append(label_indexes)
183 |
184 | self.longest_seq = max(label_to_length)
185 | self.data_source = np.empty((self.longest_seq, len(label_to_length)),
186 | dtype=np.uint32)
187 | # Each column represents one "single instance per class" data piece
188 | for ix, leng in enumerate(label_to_length):
189 | # Fill first only "real" samples
190 | self.data_source[:leng, ix] = label_to_idx_list[ix]
191 |
192 | self.label_to_idx_list = label_to_idx_list
193 | self.label_to_length = label_to_length
194 |
195 | if sampling_mode == 'same':
196 | self.data_length = data_samples
197 | elif sampling_mode == 'over': # Sample all items
198 | self.data_length = np.prod(self.data_source.shape)
199 |
200 | def _reshuffle(self):
201 | # Reshuffle
202 | for ix, leng in enumerate(self.label_to_length):
203 | leftover = self.longest_seq - leng
204 | random_idxs = np.random.randint(leng, size=leftover)
205 | self.data_source[leng:,
206 | ix] = self.label_to_idx_list[ix][random_idxs]
207 |
208 | def __iter__(self):
209 | # Before each epoch, reshuffle random indices
210 | self._reshuffle()
211 | n_samples = len(self.data_source)
212 | random_indices = self.random_state.permutation(n_samples)
213 | data = np.concatenate(
214 | self.data_source[random_indices])[:self.data_length]
215 | return iter(data)
216 |
217 | def __len__(self):
218 | return self.data_length
219 |
220 |
221 | class MultiBalancedSampler(tdata.sampler.Sampler):
222 | """docstring for BalancedSampler
223 | Samples for Multi-label training
224 | Sampling is not totally equal, but aims to be roughly equal
225 | """
226 | def __init__(self, Y, replacement=False, num_samples=None):
227 | assert Y.ndim == 2, "Y needs to be one hot encoded"
228 | if scipy.sparse.issparse(Y):
229 | raise ValueError("Not supporting sparse matrices yet")
230 | class_counts = np.sum(Y, axis=0)
231 | class_weights = 1. / class_counts
232 | class_weights = class_weights / class_weights.sum()
233 | classes = np.arange(Y[0].shape[0])
234 | # Revert from many_hot to one
235 | class_ids = [tuple(classes.compress(idx)) for idx in Y]
236 |
237 | sample_weights = []
238 | for i in range(len(Y)):
239 | # Multiple classes were chosen, calculate average probability
240 | weight = class_weights[np.array(class_ids[i])]
241 | # Take the mean of the multiple classes and set as weight
242 | weight = np.mean(weight)
243 | sample_weights.append(weight)
244 | self._weights = torch.as_tensor(sample_weights, dtype=torch.float)
245 | self._len = num_samples if num_samples else len(Y)
246 | self._replacement = replacement
247 |
248 | def __len__(self):
249 | return self._len
250 |
251 | def __iter__(self):
252 | return iter(
253 | torch.multinomial(self._weights, self._len,
254 | self._replacement).tolist())
255 |
256 |
257 | def gettraindataloader(h5files,
258 | h5labels,
259 | label_type=False,
260 | transform=None,
261 | **dataloader_kwargs):
262 | dset = TrainHDF5Dataset(h5files,
263 | h5labels,
264 | label_type=label_type,
265 | transform=transform)
266 | return tdata.DataLoader(dset,
267 | collate_fn=sequential_collate,
268 | **dataloader_kwargs)
269 |
270 |
271 | def getdataloader(h5file, h5label, fnames, transform=None,
272 | **dataloader_kwargs):
273 | dset = HDF5Dataset(h5file, h5label, fnames, transform=transform)
274 | return tdata.DataLoader(dset,
275 | collate_fn=sequential_collate,
276 | **dataloader_kwargs)
277 |
278 |
279 | def pad(tensorlist, padding_value=0.):
280 | lengths = [len(f) for f in tensorlist]
281 | max_len = np.max(lengths)
282 | # max_len = 2000
283 | batch_dim = len(lengths)
284 | data_dim = tensorlist[0].shape[-1]
285 | out_tensor = torch.full((batch_dim, max_len, data_dim),
286 | fill_value=padding_value,
287 | dtype=torch.float32)
288 | for i, tensor in enumerate(tensorlist):
289 | length = tensor.shape[0]
290 | out_tensor[i, :length, ...] = tensor[:length, ...]
291 | return out_tensor, torch.tensor(lengths)
292 |
293 |
294 | def sequential_collate(batches):
295 | # sort length wise
296 | data, targets_time, targets_clip, fnames = zip(*batches)
297 | data, lengths_data = pad(data)
298 | targets_time, lengths_tar = pad(targets_time, padding_value=0)
299 | targets_clip = torch.stack(targets_clip)
300 | assert lengths_data.shape == lengths_tar.shape
301 | return data, targets_time, targets_clip, fnames, lengths_tar
302 |
303 |
304 | if __name__ == '__main__':
305 | import utils
306 | label_df = pd.read_csv(
307 | 'data/csv_labels/unbalanced_from_unbalanced/unbalanced.csv', sep='\s+')
308 | data_df = pd.read_csv("data/data_csv/unbalanced.csv", sep='\s+')
309 |
310 | merged = data_df.merge(label_df, on='filename')
311 | common_idxs = merged['filename']
312 | data_df = data_df[data_df['filename'].isin(common_idxs)]
313 | label_df = label_df[label_df['filename'].isin(common_idxs)]
314 |
315 | label = utils.df_to_dict(label_df)
316 | data = utils.df_to_dict(data_df)
317 |
318 | trainloader = gettraindataloader(
319 | h5files=data,
320 | h5labels=label,
321 | transform=None,
322 | label_type='soft',
323 | batch_size=64,
324 | num_workers=3,
325 | shuffle=False,
326 | )
327 |
328 | with tqdm(total=len(trainloader)) as pbar:
329 | for batch in trainloader:
330 | inputs, targets_time, targets_clip, filenames, lengths = batch
331 | pbar.set_postfix(inp=inputs.shape)
332 | pbar.update()
333 |
--------------------------------------------------------------------------------
/example/example.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/example/example.wav
--------------------------------------------------------------------------------
/figs/data_driven_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/data_driven_framework.png
--------------------------------------------------------------------------------
/figs/sample_background.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/sample_background.png
--------------------------------------------------------------------------------
/figs/sample_music.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/sample_music.png
--------------------------------------------------------------------------------
/figs/sample_speech.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/sample_speech.png
--------------------------------------------------------------------------------
/figs/samples_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/samples_1.png
--------------------------------------------------------------------------------
/figs/samples_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/samples_2.png
--------------------------------------------------------------------------------
/figs/samples_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/samples_3.png
--------------------------------------------------------------------------------
/figs/samples_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/figs/samples_4.png
--------------------------------------------------------------------------------
/forward.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 | from loguru import logger
4 | from pathlib import Path
5 | from tqdm import tqdm
6 | import utils
7 | import pandas as pd
8 | import numpy as np
9 | import librosa
10 | import soundfile as sf
11 | import uuid
12 | import argparse
13 | from models import crnn
14 | import os
15 |
16 | SAMPLE_RATE = 22050
17 | EPS = np.spacing(1)
18 | LMS_ARGS = {
19 | 'n_fft': 2048,
20 | 'n_mels': 64,
21 | 'hop_length': int(SAMPLE_RATE * 0.02),
22 | 'win_length': int(SAMPLE_RATE * 0.04)
23 | }
24 | DEVICE = 'cpu'
25 | if torch.cuda.is_available():
26 | DEVICE = 'cuda'
27 | DEVICE = torch.device(DEVICE)
28 |
29 |
30 | def extract_feature(wavefilepath, **kwargs):
31 | _, file_extension = os.path.splitext(wavefilepath)
32 | if file_extension == '.wav':
33 | wav, sr = sf.read(wavefilepath, dtype='float32')
34 | if file_extension == '.mp3':
35 | wav, sr = librosa.load(wavefilepath)
36 | elif file_extension not in ['.mp3', '.wav']:
37 | raise NotImplementedError('Audio extension not supported... yet ;)')
38 | if wav.ndim > 1:
39 | wav = wav.mean(-1)
40 | wav = librosa.resample(wav, sr, target_sr=SAMPLE_RATE)
41 | return np.log(
42 | librosa.feature.melspectrogram(wav.astype(np.float32), SAMPLE_RATE, **
43 | kwargs) + EPS).T
44 |
45 |
46 | class OnlineLogMelDataset(torch.utils.data.Dataset):
47 | def __init__(self, data_list, **kwargs):
48 | super().__init__()
49 | self.dlist = data_list
50 | self.kwargs = kwargs
51 |
52 | def __getitem__(self, idx):
53 | return extract_feature(wavefilepath=self.dlist[idx],
54 | **self.kwargs), self.dlist[idx]
55 |
56 | def __len__(self):
57 | return len(self.dlist)
58 |
59 |
60 | MODELS = {
61 | 't1': {
62 | 'model': crnn,
63 | 'outputdim': 527,
64 | 'encoder': 'labelencoders/teacher.pth',
65 | 'pretrained': 'teacher1/model.pth',
66 | 'resolution': 0.02
67 | },
68 | 't2': {
69 | 'model': crnn,
70 | 'outputdim': 527,
71 | 'encoder': 'labelencoders/teacher.pth',
72 | 'pretrained': 'teacher2/model.pth',
73 | 'resolution': 0.02
74 | },
75 | 'sre': {
76 | 'model': crnn,
77 | 'outputdim': 2,
78 | 'encoder': 'labelencoders/students.pth',
79 | 'pretrained': 'sre/model.pth',
80 | 'resolution': 0.02
81 | },
82 | 'v2': {
83 | 'model': crnn,
84 | 'outputdim': 2,
85 | 'encoder': 'labelencoders/students.pth',
86 | 'pretrained': 'vox2/model.pth',
87 | 'resolution': 0.02
88 | },
89 | 'a2': {
90 | 'model': crnn,
91 | 'outputdim': 2,
92 | 'encoder': 'labelencoders/students.pth',
93 | 'pretrained': 'audioset2/model.pth',
94 | 'resolution': 0.02
95 | },
96 | 'a2_v2': {
97 | 'model': crnn,
98 | 'outputdim': 2,
99 | 'encoder': 'labelencoders/students.pth',
100 | 'pretrained': 'audio2_vox2/model.pth',
101 | 'resolution': 0.02
102 | },
103 | 'c1': {
104 | 'model': crnn,
105 | 'outputdim': 2,
106 | 'encoder': 'labelencoders/students.pth',
107 | 'pretrained': 'c1/model.pth',
108 | 'resolution': 0.02
109 | },
110 | }
111 |
112 |
113 | def main():
114 | parser = argparse.ArgumentParser()
115 | group = parser.add_mutually_exclusive_group(required=True)
116 | group.add_argument(
117 | '-w',
118 | '--wav',
119 | help=
120 | 'A single wave/mp3/flac or any other compatible audio file with soundfile.read'
121 | )
122 | group.add_argument(
123 | '-l',
124 | '--wavlist',
125 | help=
126 | 'A list of wave or any other compatible audio files. E.g., output of find . -type f -name *.wav > wavlist.txt'
127 | )
128 | parser.add_argument('-model', choices=list(MODELS.keys()), default='sre')
129 | parser.add_argument(
130 | '--pretrained_dir',
131 | default='pretrained_models',
132 | help=
133 | 'Path to downloaded pretrained models directory, (default %(default)s)'
134 | )
135 | parser.add_argument('-o',
136 | '--output_path',
137 | default=None,
138 | help='Output folder to save predictions if necessary')
139 | parser.add_argument('-soft',
140 | default=False,
141 | action='store_true',
142 | help='Outputs soft probabilities.')
143 | parser.add_argument('-hard',
144 | default=False,
145 | action='store_true',
146 | help='Outputs hard labels as zero-one array.')
147 | parser.add_argument('-th',
148 | '--threshold',
149 | default=(0.5, 0.1),
150 | type=float,
151 | nargs="+")
152 | args = parser.parse_args()
153 | pretrained_dir = Path(args.pretrained_dir)
154 | if not (pretrained_dir.exists() and pretrained_dir.is_dir()):
155 | logger.error(f"""Pretrained directory {args.pretrained_dir} not found.
156 | Please download the pretrained models and try again, or set --pretrained_dir to your model directory."""
157 | )
158 | return
159 | logger.info("Passed args")
160 | for k, v in vars(args).items():
161 | logger.info(f"{k} : {str(v):<10}")
162 | if args.wavlist:
163 | wavlist = pd.read_csv(args.wavlist,
164 | usecols=[0],
165 | header=None,
166 | names=['filename'])
167 | wavlist = wavlist['filename'].values.tolist()
168 | elif args.wav:
169 | wavlist = [args.wav]
170 | dset = OnlineLogMelDataset(wavlist, **LMS_ARGS)
171 | dloader = torch.utils.data.DataLoader(dset,
172 | batch_size=1,
173 | num_workers=3,
174 | shuffle=False)
175 |
176 | model_kwargs_pack = MODELS[args.model]
177 | model_resolution = model_kwargs_pack['resolution']
178 | # Load model from relative path
179 | model = model_kwargs_pack['model'](
180 | outputdim=model_kwargs_pack['outputdim'],
181 | pretrained_from=pretrained_dir /
182 | model_kwargs_pack['pretrained']).to(DEVICE).eval()
183 | encoder = torch.load(pretrained_dir / model_kwargs_pack['encoder'])
184 | logger.trace(model)
185 |
186 | output_dfs = []
187 | frame_outputs = {}
188 | threshold = tuple(args.threshold)
189 |
190 | speech_label_idx = np.where('Speech' == encoder.classes_)[0].squeeze()
191 | # Using only binary thresholding without filter
192 | if len(threshold) == 1:
193 | postprocessing_method = utils.binarize
194 | else:
195 | postprocessing_method = utils.double_threshold
196 | with torch.no_grad(), tqdm(total=len(dloader), leave=False,
197 | unit='clip') as pbar:
198 | for feature, filename in dloader:
199 | feature = torch.as_tensor(feature).to(DEVICE)
200 | prediction_tag, prediction_time = model(feature)
201 | prediction_tag = prediction_tag.to('cpu')
202 | prediction_time = prediction_time.to('cpu')
203 |
204 | if prediction_time is not None: # Some models do not predict timestamps
205 |
206 | cur_filename = filename[0] #Remove batchsize
207 | thresholded_prediction = postprocessing_method(
208 | prediction_time, *threshold)
209 | speech_soft_pred = prediction_time[..., speech_label_idx]
210 | if args.soft:
211 | speech_soft_pred = prediction_time[
212 | ..., speech_label_idx].numpy()
213 | frame_outputs[cur_filename] = speech_soft_pred[
214 | 0] # 1 batch
215 |
216 | if args.hard:
217 | speech_hard_pred = thresholded_prediction[...,
218 | speech_label_idx]
219 | frame_outputs[cur_filename] = speech_hard_pred[
220 | 0] # 1 batch
221 | # frame_outputs_hard.append(thresholded_prediction)
222 |
223 | labelled_predictions = utils.decode_with_timestamps(
224 | encoder, thresholded_prediction)
225 | pred_label_df = pd.DataFrame(
226 | labelled_predictions[0],
227 | columns=['event_label', 'onset', 'offset'])
228 | if not pred_label_df.empty:
229 | pred_label_df['filename'] = cur_filename
230 | pred_label_df['onset'] *= model_resolution
231 | pred_label_df['offset'] *= model_resolution
232 | pbar.set_postfix(labels=','.join(
233 | np.unique(pred_label_df['event_label'].values)))
234 | pbar.update()
235 | output_dfs.append(pred_label_df)
236 |
237 | full_prediction_df = pd.concat(output_dfs).sort_values(by='onset',ascending=True).reset_index()
238 | prediction_df = full_prediction_df[full_prediction_df['event_label'] ==
239 | 'Speech']
240 |
241 | if args.output_path:
242 | args.output_path = Path(args.output_path)
243 | args.output_path.mkdir(parents=True, exist_ok=True)
244 | prediction_df.to_csv(args.output_path / 'speech_predictions.tsv',
245 | sep='\t',
246 | index=False)
247 | full_prediction_df.to_csv(args.output_path / 'all_predictions.tsv',
248 | sep='\t',
249 | index=False)
250 |
251 | if args.soft or args.hard:
252 | prefix = 'soft' if args.soft else 'hard'
253 | with open(args.output_path / f'{prefix}_predictions.txt',
254 | 'w') as wp:
255 | np.set_printoptions(suppress=True,
256 | precision=2,
257 | linewidth=np.inf)
258 | for fname, output in frame_outputs.items():
259 | print(f"{fname} {output}", file=wp)
260 | logger.info(f"Putting results also to dir {args.output_path}")
261 | if args.soft or args.hard:
262 | np.set_printoptions(suppress=True, precision=2, linewidth=np.inf)
263 | for fname, output in frame_outputs.items():
264 | print(f"{fname} {output}")
265 | else:
266 | print(prediction_df.to_markdown(showindex=False))
267 |
268 |
269 | if __name__ == "__main__":
270 | main()
271 |
--------------------------------------------------------------------------------
/labelencoders/vad.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/labelencoders/vad.pth
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import ignite.metrics as metrics
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class FrameBCELoss(nn.Module):
7 | """docstring for BCELoss"""
8 | def __init__(self):
9 | super().__init__()
10 |
11 | def forward(self, clip_prob, frame_prob, tar_time, tar_clip, length):
12 | batchsize, timesteps, ndim = tar_time.shape
13 | idxs = torch.arange(timesteps, device='cpu').repeat(batchsize).view(
14 | batchsize, timesteps)
15 | mask = (idxs < length.view(-1, 1)).to(frame_prob.device)
16 | masked_bce = nn.functional.binary_cross_entropy(
17 | input=frame_prob, target=tar_time,
18 | reduction='none') * mask.unsqueeze(-1)
19 | return masked_bce.sum() / mask.sum()
20 |
21 |
22 | class ClipFrameBCELoss(nn.Module):
23 | """docstring for BCELoss"""
24 | def __init__(self):
25 | super().__init__()
26 | self.frameloss = FrameBCELoss()
27 | self.cliploss = nn.BCELoss()
28 |
29 | def forward(self, clip_prob, frame_prob, tar_time, tar_clip, length):
30 | return self.frameloss(
31 | clip_prob, frame_prob, tar_time, tar_clip, length) + self.cliploss(
32 | clip_prob, tar_clip)
33 |
34 |
35 | class BCELossWithLabelSmoothing(nn.Module):
36 | """docstring for BCELoss"""
37 | def __init__(self, label_smoothing=0.1):
38 | super().__init__()
39 | self.label_smoothing = label_smoothing
40 |
41 | def forward(self, clip_prob, frame_prob, tar):
42 | n_classes = clip_prob.shape[-1]
43 | with torch.no_grad():
44 | tar = tar * (1 - self.label_smoothing) + (
45 | 1 - tar) * self.label_smoothing / (n_classes - 1)
46 | return nn.functional.binary_cross_entropy(clip_prob, tar)
47 |
48 |
49 | # Reimplement Loss, because the ignite Loss only takes 2 args, not 3, and needs to pass kwargs around ... just *output does the trick
50 | class Loss(metrics.Loss):
51 | def __init__(self,
52 | loss_fn,
53 | output_transform=lambda x: x,
54 | batch_size=lambda x: len(x),
55 | device=None):
56 | super(Loss, self).__init__(loss_fn=loss_fn,
57 | output_transform=output_transform,
58 | batch_size=batch_size)
59 |
60 | def update(self, output):
61 | average_loss = self._loss_fn(*output)
62 |
63 | if len(average_loss.shape) != 0:
64 | raise ValueError('loss_fn did not return the average loss.')
65 |
66 | N = self._batch_size(output[0])
67 | self._sum += average_loss.item() * N
68 | self._num_examples += N
69 |
70 |
71 | if __name__ == "__main__":
72 | batch, time, dim = 4, 500, 10
73 | frame = torch.sigmoid(torch.randn(batch, time, dim))
74 | clip = torch.sigmoid(torch.randn(batch, dim))
75 | tar = torch.empty(batch, dim).random_(2)
76 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import sed_eval
2 | import utils
3 | import pandas as pd
4 | from sklearn.preprocessing import binarize, MultiLabelBinarizer
5 | import sklearn.metrics as skmetrics
6 | import numpy as np
7 |
8 |
9 | def get_audio_tagging_df(df):
10 | return df.groupby('filename')['event_label'].unique().reset_index()
11 |
12 |
13 | def audio_tagging_results(reference, estimated):
14 | """audio_tagging_results. Returns clip-level F1 Scores
15 |
16 | :param reference: The ground truth dataframe as pd.DataFrame
17 | :param estimated: Predicted labels by the model ( thresholded )
18 | """
19 | if "event_label" in reference.columns:
20 | classes = reference.event_label.dropna().unique().tolist(
21 | ) + estimated.event_label.dropna().unique().tolist()
22 | encoder = MultiLabelBinarizer().fit([classes])
23 | reference = get_audio_tagging_df(reference)
24 | estimated = get_audio_tagging_df(estimated)
25 | ref_labels, _ = utils.encode_labels(reference['event_label'],
26 | encoder=encoder)
27 | reference['event_label'] = ref_labels.tolist()
28 | est_labels, _ = utils.encode_labels(estimated['event_label'],
29 | encoder=encoder)
30 | estimated['event_label'] = est_labels.tolist()
31 |
32 | matching = reference.merge(estimated,
33 | how='outer',
34 | on="filename",
35 | suffixes=["_ref", "_pred"])
36 |
37 | def na_values(val):
38 | if type(val) is np.ndarray:
39 | return val
40 | elif isinstance(val, list):
41 | return np.array(val)
42 | if pd.isna(val):
43 | return np.zeros(len(encoder.classes_))
44 | return val
45 |
46 | ret_df = pd.DataFrame(columns=['label', 'f1', 'precision', 'recall'])
47 | if not estimated.empty:
48 | matching['event_label_pred'] = matching.event_label_pred.apply(
49 | na_values)
50 | matching['event_label_ref'] = matching.event_label_ref.apply(na_values)
51 |
52 | y_true = np.vstack(matching['event_label_ref'].values)
53 | y_pred = np.vstack(matching['event_label_pred'].values)
54 | ret_df.loc[:, 'label'] = encoder.classes_
55 | for avg in [None, 'macro', 'micro']:
56 | avg_f1 = skmetrics.f1_score(y_true, y_pred, average=avg)
57 | avg_pre = skmetrics.precision_score(y_true, y_pred, average=avg)
58 | avg_rec = skmetrics.recall_score(y_true, y_pred, average=avg)
59 | # avg_auc = skmetrics.roc_auc_score(y_true, y_pred, average=avg)
60 |
61 | if avg == None:
62 | # Add for each label non pooled stats
63 | ret_df.loc[:, 'precision'] = avg_pre
64 | ret_df.loc[:, 'recall'] = avg_rec
65 | ret_df.loc[:, 'f1'] = avg_f1
66 | # ret_df.loc[:, 'AUC'] = avg_auc
67 | else:
68 | # Append macro and micro results in last 2 rows
69 | ret_df = ret_df.append(
70 | {
71 | 'label': avg,
72 | 'precision': avg_pre,
73 | 'recall': avg_rec,
74 | 'f1': avg_f1,
75 | # 'AUC': avg_auc
76 | },
77 | ignore_index=True)
78 | return ret_df
79 |
80 |
81 | def get_event_list_current_file(df, fname):
82 | """
83 | Get list of events for a given filename
84 | :param df: pd.DataFrame, the dataframe to search on
85 | :param fname: the filename to extract the value from the dataframe
86 | :return: list of events (dictionaries) for the given filename
87 | """
88 | event_file = df[df["filename"] == fname]
89 | if len(event_file) == 1:
90 | if pd.isna(event_file["event_label"].iloc[0]):
91 | event_list_for_current_file = [{"filename": fname}]
92 | else:
93 | event_list_for_current_file = event_file.to_dict('records')
94 | else:
95 | event_list_for_current_file = event_file.to_dict('records')
96 |
97 | return event_list_for_current_file
98 |
99 |
100 | def event_based_evaluation_df(reference,
101 | estimated,
102 | t_collar=0.200,
103 | percentage_of_length=0.2):
104 | """
105 | Calculate EventBasedMetric given a reference and estimated dataframe
106 | :param reference: pd.DataFrame containing "filename" "onset" "offset" and "event_label" columns which describe the
107 | reference events
108 | :param estimated: pd.DataFrame containing "filename" "onset" "offset" and "event_label" columns which describe the
109 | estimated events to be compared with reference
110 | :return: sed_eval.sound_event.EventBasedMetrics with the scores
111 | """
112 |
113 | evaluated_files = reference["filename"].unique()
114 |
115 | classes = []
116 | classes.extend(reference.event_label.dropna().unique())
117 | classes.extend(estimated.event_label.dropna().unique())
118 | classes = list(set(classes))
119 |
120 | event_based_metric = sed_eval.sound_event.EventBasedMetrics(
121 | event_label_list=classes,
122 | t_collar=t_collar,
123 | percentage_of_length=percentage_of_length,
124 | empty_system_output_handling='zero_score')
125 |
126 | for fname in evaluated_files:
127 | reference_event_list_for_current_file = get_event_list_current_file(
128 | reference, fname)
129 | estimated_event_list_for_current_file = get_event_list_current_file(
130 | estimated, fname)
131 |
132 | event_based_metric.evaluate(
133 | reference_event_list=reference_event_list_for_current_file,
134 | estimated_event_list=estimated_event_list_for_current_file,
135 | )
136 |
137 | return event_based_metric
138 |
139 |
140 | def segment_based_evaluation_df(reference, estimated, time_resolution=1.):
141 | evaluated_files = reference["filename"].unique()
142 |
143 | classes = []
144 | classes.extend(reference.event_label.dropna().unique())
145 | classes.extend(estimated.event_label.dropna().unique())
146 | classes = list(set(classes))
147 |
148 | segment_based_metric = sed_eval.sound_event.SegmentBasedMetrics(
149 | event_label_list=classes, time_resolution=time_resolution)
150 |
151 | for fname in evaluated_files:
152 | reference_event_list_for_current_file = get_event_list_current_file(
153 | reference, fname)
154 | estimated_event_list_for_current_file = get_event_list_current_file(
155 | estimated, fname)
156 |
157 | segment_based_metric.evaluate(
158 | reference_event_list=reference_event_list_for_current_file,
159 | estimated_event_list=estimated_event_list_for_current_file)
160 |
161 | return segment_based_metric
162 |
163 |
164 | def compute_metrics(valid_df, pred_df, time_resolution=1.):
165 |
166 | metric_event = event_based_evaluation_df(valid_df,
167 | pred_df,
168 | t_collar=0.200,
169 | percentage_of_length=0.2)
170 | metric_segment = segment_based_evaluation_df(
171 | valid_df, pred_df, time_resolution=time_resolution)
172 | return metric_event, metric_segment
173 |
174 |
175 | def roc(y_true, y_pred, average=None):
176 | return skmetrics.roc_auc_score(y_true, y_pred, average=average)
177 |
178 |
179 | def mAP(y_true, y_pred, average=None):
180 | return skmetrics.average_precision_score(y_true, y_pred, average=average)
181 |
182 |
183 | def precision_recall_fscore_support(y_true, y_pred, average=None):
184 | return skmetrics.precision_recall_fscore_support(y_true,
185 | y_pred,
186 | average=average)
187 |
188 |
189 | def tpr_fpr(y_true, y_pred):
190 | fpr, tpr, thresholds = skmetrics.roc_curve(y_true, y_pred)
191 | return fpr, tpr, thresholds
192 |
193 |
194 | def obtain_error_rates_alt(y_true, y_pred, threshold=0.5):
195 | speech_frame_predictions = binarize(y_pred.reshape(-1, 1),
196 | threshold=threshold)
197 | tn, fp, fn, tp = skmetrics.confusion_matrix(
198 | y_true, speech_frame_predictions).ravel()
199 |
200 | p_miss = 100 * (fn / (fn + tp))
201 | p_fa = 100 * (fp / (fp + tn))
202 | return p_fa, p_miss
203 |
204 |
205 | def confusion_matrix(y_true, y_pred):
206 | return skmetrics.confusion_matrix(y_true, y_pred)
207 |
208 |
209 | def obtain_error_rates(y_true, y_pred, threshold=0.5):
210 | negatives = y_pred[np.where(y_true == 0)]
211 | positives = y_pred[np.where(y_true == 1)]
212 | Pfa = np.sum(negatives >= threshold) / negatives.size
213 | Pmiss = np.sum(positives < threshold) / positives.size
214 | return Pfa, Pmiss
215 |
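# Illustrative note (added comment, not original code): obtain_error_rates derives
# the two standard VAD operating-point errors from frame-level probabilities.
# For example, with
#     y_true = np.array([0, 0, 1, 1]) and y_pred = np.array([0.6, 0.2, 0.4, 0.9])
# and threshold=0.5, one of the two negatives scores >= 0.5 and one of the two
# positives scores < 0.5, so Pfa = 0.5 and Pmiss = 0.5.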
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from pathlib import Path
4 | import torch.nn as nn
5 |
6 |
7 | def crnn(inputdim=64, outputdim=527, pretrained_from='balanced.pth'):
8 | model = CRNN(inputdim, outputdim)
9 | if pretrained_from:
10 | state = torch.load(pretrained_from,
11 | map_location='cpu')
12 | model.load_state_dict(state, strict=False)
13 | return model
14 |
15 |
16 | def cnn10(inputdim=64, outputdim=527, pretrained_from='balanced.pth'):
17 | model = CNN10(inputdim, outputdim)
18 | if pretrained_from:
19 | state = torch.load(pretrained_from,
20 | map_location='cpu')
21 | model.load_state_dict(state, strict=False)
22 | return model
23 |
24 |
25 | def init_weights(m):
26 | if isinstance(m, (nn.Conv2d, nn.Conv1d)):
27 | nn.init.kaiming_normal_(m.weight)
28 | if m.bias is not None:
29 | nn.init.constant_(m.bias, 0)
30 | elif isinstance(m, nn.BatchNorm2d):
31 | nn.init.constant_(m.weight, 1)
32 | if m.bias is not None:
33 | nn.init.constant_(m.bias, 0)
34 | if isinstance(m, nn.Linear):
35 | nn.init.kaiming_uniform_(m.weight)
36 | if m.bias is not None:
37 | nn.init.constant_(m.bias, 0)
38 |
39 |
40 | class LinearSoftPool(nn.Module):
41 | """LinearSoftPool
42 |
43 |     Linear softmax pooling: takes frame-level probabilities and returns a clip-level probability close to the actual maximum value (see the worked example after this class).
44 | Taken from the paper:
45 |
46 | A Comparison of Five Multiple Instance Learning Pooling Functions for Sound Event Detection with Weak Labeling
47 | https://arxiv.org/abs/1810.09050
48 |
49 | """
50 | def __init__(self, pooldim=1):
51 | super().__init__()
52 | self.pooldim = pooldim
53 |
54 | def forward(self, logits, time_decision):
55 | return (time_decision**2).sum(self.pooldim) / time_decision.sum(
56 | self.pooldim)
57 |
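# Worked example (added comment, not original code): LinearSoftPool computes
#     clip_prob = sum_t(p_t ** 2) / sum_t(p_t)
# over the time dimension, which lies between mean- and max-pooling. Frame
# probabilities [0.9, 0.1] give (0.81 + 0.01) / 1.0 = 0.82, much closer to the
# maximum (0.9) than to the mean (0.5).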
58 |
59 | class MeanPool(nn.Module):
60 | def __init__(self, pooldim=1):
61 | super().__init__()
62 | self.pooldim = pooldim
63 |
64 | def forward(self, logits, decision):
65 | return torch.mean(decision, dim=self.pooldim)
66 |
67 |
68 | def parse_poolingfunction(poolingfunction_name='mean', **kwargs):
69 | """parse_poolingfunction
70 |     A helper function to parse any temporal pooling (see the usage sketch after this function)
71 | Pooling is done on dimension 1
72 |
73 | :param poolingfunction_name:
74 | :param **kwargs:
75 | """
76 | poolingfunction_name = poolingfunction_name.lower()
77 | if poolingfunction_name == 'mean':
78 | return MeanPool(pooldim=1)
79 | elif poolingfunction_name == 'linear':
80 | return LinearSoftPool(pooldim=1)
81 | elif poolingfunction_name == 'attention':
82 | return AttentionPool(inputdim=kwargs['inputdim'],
83 | outputdim=kwargs['outputdim'])
84 |
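# Usage sketch (added comment, not original code): the config's "temppool" entry
# is mapped to a pooling module, e.g.
#     parse_poolingfunction('linear')                               -> LinearSoftPool(pooldim=1)
#     parse_poolingfunction('attention', inputdim=256, outputdim=2) -> AttentionPool(...)
# Note that names other than 'mean'/'linear'/'attention' fall through and return
# None, which only fails later when the pooling module is called.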
85 |
86 | class AttentionPool(nn.Module):
87 | """docstring for AttentionPool"""
88 | def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
89 | super().__init__()
90 | self.inputdim = inputdim
91 | self.outputdim = outputdim
92 | self.pooldim = pooldim
93 | self.transform = nn.Linear(inputdim, outputdim)
94 | self.activ = nn.Softmax(dim=self.pooldim)
95 | self.eps = 1e-7
96 |
97 | def forward(self, logits, decision):
98 | # Input is (B, T, D)
99 |         # w: attention weights over time, shape (B, T, outputdim)
100 | w = self.activ(self.transform(logits))
101 | detect = (decision * w).sum(
102 | self.pooldim) / (w.sum(self.pooldim) + self.eps)
103 |         # detect: (B, outputdim)
104 | return detect
105 |
106 |
107 | class Block2D(nn.Module):
108 | def __init__(self, cin, cout, kernel_size=3, padding=1):
109 | super().__init__()
110 | self.block = nn.Sequential(
111 | nn.BatchNorm2d(cin),
112 | nn.Conv2d(cin,
113 | cout,
114 | kernel_size=kernel_size,
115 | padding=padding,
116 | bias=False),
117 | nn.LeakyReLU(inplace=True, negative_slope=0.1))
118 |
119 | def forward(self, x):
120 | return self.block(x)
121 |
122 |
123 | class CRNN(nn.Module):
124 | def __init__(self, inputdim, outputdim, **kwargs):
125 | super().__init__()
126 | self.features = nn.Sequential(
127 | Block2D(1, 32),
128 | nn.LPPool2d(4, (2, 4)),
129 | Block2D(32, 128),
130 | Block2D(128, 128),
131 | nn.LPPool2d(4, (2, 4)),
132 | Block2D(128, 128),
133 | Block2D(128, 128),
134 | nn.LPPool2d(4, (1, 4)),
135 | nn.Dropout(0.3),
136 | )
137 | with torch.no_grad():
138 | rnn_input_dim = self.features(torch.randn(1, 1, 500,
139 | inputdim)).shape
140 | rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
141 |
142 | self.gru = nn.GRU(rnn_input_dim,
143 | 128,
144 | bidirectional=True,
145 | batch_first=True)
146 | self.temp_pool = parse_poolingfunction(kwargs.get(
147 | 'temppool', 'linear'),
148 | inputdim=256,
149 | outputdim=outputdim)
150 | self.outputlayer = nn.Linear(256, outputdim)
151 | self.features.apply(init_weights)
152 | self.outputlayer.apply(init_weights)
153 |
154 | def forward(self, x, upsample=True):
155 | batch, time, dim = x.shape
156 | x = x.unsqueeze(1)
157 | x = self.features(x)
158 | x = x.transpose(1, 2).contiguous().flatten(-2)
159 | x, _ = self.gru(x)
160 | decision_time = torch.sigmoid(self.outputlayer(x)).clamp(1e-7, 1.)
161 | if upsample:
162 | decision_time = torch.nn.functional.interpolate(
163 | decision_time.transpose(1, 2),
164 | time,
165 | mode='linear',
166 | align_corners=False).transpose(1, 2)
167 | decision = self.temp_pool(x, decision_time).clamp(1e-7, 1.).squeeze(1)
168 | return decision, decision_time
169 |
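# Shape sketch (added comment, not original code): for input features of shape
# (B, T, 64) (the default inputdim), the CNN front end downsamples time by 4 and
# frequency to 1, the BiGRU yields (B, T // 4, 256), and forward() returns
#     decision      -> (B, outputdim)      clip-level probabilities
#     decision_time -> (B, T, outputdim)   frame-level probabilities (upsampled back to T)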
170 |
171 | class CNN10(nn.Module):
172 | def __init__(self, inputdim, outputdim, **kwargs):
173 | super().__init__()
174 | self.features = nn.Sequential(
175 | Block2D(1, 64),
176 | Block2D(64, 64),
177 | nn.LPPool2d(4, (2, 4)),
178 | Block2D(64, 128),
179 | Block2D(128, 128),
180 | nn.LPPool2d(4, (2, 2)),
181 | Block2D(128, 256),
182 | Block2D(256, 256),
183 | nn.LPPool2d(4, (1, 2)),
184 | Block2D(256, 512),
185 | Block2D(512, 512),
186 | nn.LPPool2d(4, (1, 2)),
187 | nn.Dropout(0.3),
188 | nn.AdaptiveAvgPool2d((None, 1)),
189 | )
190 |
191 | self.temp_pool = parse_poolingfunction(kwargs.get(
192 | 'temppool', 'attention'),
193 | inputdim=512,
194 | outputdim=outputdim)
195 | self.outputlayer = nn.Linear(512, outputdim)
196 | self.features.apply(init_weights)
197 | self.outputlayer.apply(init_weights)
198 |
199 | def forward(self, x, upsample=True):
200 | batch, time, dim = x.shape
201 | x = x.unsqueeze(1)
202 | x = self.features(x)
203 | x = x.transpose(1, 2).contiguous().flatten(-2)
204 | decision_time = torch.sigmoid(self.outputlayer(x)).clamp(1e-7, 1.)
205 | decision = self.temp_pool(x, decision_time).clamp(1e-7, 1.).squeeze(1)
206 | if upsample:
207 | decision_time = torch.nn.functional.interpolate(
208 | decision_time.transpose(1, 2),
209 | time,
210 | mode='linear',
211 | align_corners=False).transpose(1, 2)
212 | return decision, decision_time
213 |
214 |
215 | class CRNN10(nn.Module):
216 | def __init__(self, inputdim, outputdim, **kwargs):
217 | super().__init__()
218 | self._hiddim = kwargs.get('hiddim', 256)
219 | self.features = nn.Sequential(
220 | Block2D(1, 64),
221 | Block2D(64, 64),
222 | nn.LPPool2d(4, (2, 4)),
223 | Block2D(64, 128),
224 | Block2D(128, 128),
225 | nn.LPPool2d(4, (2, 2)),
226 | Block2D(128, 256),
227 | Block2D(256, 256),
228 | nn.LPPool2d(4, (1, 2)),
229 | Block2D(256, 512),
230 | Block2D(512, 512),
231 | nn.LPPool2d(4, (1, 2)),
232 | nn.Dropout(0.3),
233 | nn.AdaptiveAvgPool2d((None, 1)),
234 | )
235 | with torch.no_grad():
236 | rnn_input_dim = self.features(torch.randn(1, 1, 500,
237 | inputdim)).shape
238 | rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
239 | self.gru = nn.GRU(rnn_input_dim,
240 | self._hiddim,
241 | bidirectional=True,
242 | batch_first=True)
243 | self.temp_pool = parse_poolingfunction(kwargs.get(
244 | 'temppool', 'linear'),
245 | inputdim=self._hiddim*2,
246 | outputdim=outputdim)
247 |
248 | self.outputlayer = nn.Linear(self._hiddim*2, outputdim)
249 | self.features.apply(init_weights)
250 | self.outputlayer.apply(init_weights)
251 |
252 | def forward(self, x, upsample=True):
253 | batch, time, dim = x.shape
254 | x = x.unsqueeze(1)
255 | x = self.features(x)
256 | x = x.transpose(1, 2).contiguous().flatten(-2)
257 | decision_time = torch.sigmoid(self.outputlayer(x)).clamp(1e-7, 1.)
258 | decision = self.temp_pool(x, decision_time).clamp(1e-7, 1.).squeeze(1)
259 | if upsample:
260 | decision_time = torch.nn.functional.interpolate(
261 | decision_time.transpose(1, 2),
262 | time,
263 | mode='linear',
264 | align_corners=False).transpose(1, 2)
265 | return decision, decision_time
266 |
--------------------------------------------------------------------------------
/pprint_results.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pandas as pd
3 | import re
4 | from pathlib import Path
5 | import torch
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('dir', type=str)
9 | parser.add_argument('fmt', default=None, nargs='?')
10 | args = parser.parse_args()
11 |
12 | res = {}
13 | root_dir = Path(args.dir)
14 | train_log = root_dir / 'train.log'
15 |
16 | config = torch.load(root_dir / 'run_config.pth')
17 | pretrained = config.get('pretrained', None)
18 | # logs
19 | augment = config.get('transforms', [])
20 | label_type = config.get('label_type', 'soft')
21 | model = config.get('model','CRNN')
22 |
23 |
24 | def get_seg_metrics(line, pointer, seg_type='Segment'):
25 | res = {}
26 | while not 'macro-average' in line:
27 | line = next(pointer).strip()
28 | while not 'F-measure (F1)' in line:
29 | line = next(pointer).strip()
30 | res[f'F1'] = float(line.split()[-2])
31 | while not 'Precision' in line:
32 | line = next(pointer).strip()
33 | res[f'Precision'] = float(line.split()[-2])
34 | while not 'Recall' in line:
35 | line = next(pointer).strip()
36 | res[f'Recall'] = float(line.split()[-2])
37 | return res
38 |
39 |
40 | def parse_eval_file(eval_file):
41 | res = {}
42 | frame_results = {}
43 | with open(eval_file, 'r') as rp:
44 | for line in rp:
45 | line = line.strip()
46 | if 'AUC' in line:
47 | auc = line.split()[-1]
48 | frame_results['AUC'] = float(auc)
49 | if 'FER' in line:
50 | fer = line.split()[-1]
51 | frame_results['FER'] = float(fer)
52 | if 'VAD macro' in line:
53 | f1, pre, rec = re.findall(r"[-+]?\d*\.\d+|\d+",
54 |                                           line)[1:]  # Skip the first match: the '1' in 'F1:'
55 | frame_results['F1'] = float(f1)
56 | frame_results['Precision'] = float(pre)
57 | frame_results['Recall'] = float(rec)
58 | if "Segment based metrics" in line:
59 | res['Segment'] = get_seg_metrics(line, rp)
60 | if 'Event based metrics' in line:
61 | res['Event'] = get_seg_metrics(line, rp, 'Event')
62 | res['Frame'] = frame_results
63 | return res
64 |
65 |
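# Note (added comment): run.py writes one result file per task named
# 'evaluation_<task>.txt', so str(f.stem)[11:] below strips the 11-character
# 'evaluation_' prefix to recover the task name.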
66 | all_results = []
67 | for f in root_dir.glob('*.txt'):
68 | eval_dataset = str(f.stem)[11:]
69 | res = parse_eval_file(f)
70 | df = pd.DataFrame(res).fillna('')
71 | df['data'] = eval_dataset
72 | df['augment'] = ",".join(augment)
73 | df['pretrained'] = pretrained
74 | df['label_type'] = label_type
75 | df['model'] = model
76 | all_results.append(df)
77 | df = pd.concat(all_results)
78 | if args.fmt == 'csv':
79 | print(df.to_csv())
80 | else:
81 | print(df)
82 |
--------------------------------------------------------------------------------
/pretrained_models/audio2_vox2/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/audio2_vox2/model.pth
--------------------------------------------------------------------------------
/pretrained_models/audioset2/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/audioset2/model.pth
--------------------------------------------------------------------------------
/pretrained_models/c1/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/c1/model.pth
--------------------------------------------------------------------------------
/pretrained_models/labelencoders/students.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/labelencoders/students.pth
--------------------------------------------------------------------------------
/pretrained_models/labelencoders/teacher.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/labelencoders/teacher.pth
--------------------------------------------------------------------------------
/pretrained_models/sre/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/sre/model.pth
--------------------------------------------------------------------------------
/pretrained_models/teacher1/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/teacher1/model.pth
--------------------------------------------------------------------------------
/pretrained_models/teacher2/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/teacher2/model.pth
--------------------------------------------------------------------------------
/pretrained_models/vox2/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RicherMans/Datadriven-GPVAD/6c94570bee753271722f64826a0ed00c030b089c/pretrained_models/vox2/model.pth
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==1.1.0
2 | SoundFile==0.10.2
3 | numpy==1.16.4
4 | loguru==0.4.0
5 | h5py==2.9.0
6 | scipy==1.3.0
7 | torch==1.2.0
8 | sed_eval==0.2.1
9 | pytorch_ignite==0.2.0
10 | tqdm==4.32.2
11 | tabulate==0.8.3
12 | six==1.12.0
13 | fire==0.1.3
14 | librosa==0.7.0
15 | ignite==1.1.0
16 | scikit_learn==0.23.1
17 | typing==3.7.4.1
18 | PyYAML==5.4
19 | numba==0.48
20 | pypeln
21 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import datetime
6 |
7 | import uuid
8 | import fire
9 | from pathlib import Path
10 |
11 | import pandas as pd
12 | import torch
13 | import numpy as np
14 | from tqdm import tqdm
15 | from ignite.contrib.handlers import ProgressBar, param_scheduler
16 | from ignite.engine import (Engine, Events)
17 | from ignite.handlers import EarlyStopping, ModelCheckpoint
18 | from ignite.metrics import Accuracy, RunningAverage, Precision, Recall
19 | from ignite.utils import convert_tensor
20 | from tabulate import tabulate
21 | from h5py import File
22 |
23 | import dataset
24 | import models
25 | import utils
26 | import metrics
27 | import losses
28 |
29 | DEVICE = 'cpu'
30 | if torch.cuda.is_available(
31 | ) and 'SLURM_JOB_PARTITION' in os.environ and 'gpu' in os.environ[
32 | 'SLURM_JOB_PARTITION']:
33 | DEVICE = 'cuda'
34 |     # Without this, results are slightly inconsistent between runs
35 | torch.backends.cudnn.deterministic = True
36 | DEVICE = torch.device(DEVICE)
37 |
38 |
39 | class Runner(object):
40 | """Main class to run experiments with e.g., train and evaluate"""
41 | def __init__(self, seed=42):
42 | """__init__
43 |
44 |         :param seed: Random seed used for the torch and numpy RNGs
45 |         :type seed: int
46 | """
47 | super().__init__()
48 | torch.manual_seed(seed)
49 | np.random.seed(seed)
50 |
51 | @staticmethod
52 | def _forward(model, batch):
53 | inputs, targets_time, targets_clip, filenames, lengths = batch
54 | inputs = convert_tensor(inputs, device=DEVICE, non_blocking=True)
55 | targets_time = convert_tensor(targets_time,
56 | device=DEVICE,
57 | non_blocking=True)
58 | targets_clip = convert_tensor(targets_clip,
59 | device=DEVICE,
60 | non_blocking=True)
61 | clip_level_output, frame_level_output = model(inputs)
62 | return clip_level_output, frame_level_output, targets_time, targets_clip, lengths
63 |
64 | @staticmethod
65 | def _negative_loss(engine):
66 | return -engine.state.metrics['Loss']
67 |
68 | def train(self, config, **kwargs):
69 | """Trains a given model specified in the config file or passed as the --model parameter.
70 | All options in the config file can be overwritten as needed by passing --PARAM
71 |         Options with variable length (e.g., dict-valued kwargs) can be passed as --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}'
72 |
73 | :param config: yaml config file
74 | :param **kwargs: parameters to overwrite yaml config
75 | """
76 |
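        # Usage sketch (added comment, illustrative invocation): any YAML key can
        # be overridden on the command line through python-fire, e.g.
        #     python3 run.py train configs/example.yaml --model CRNN \
        #         --batch_size 32 --model_args '{"temppool": "attention"}'
        # (the flag names mirror the keys of the YAML config / utils.DEFAULT_ARGS)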
77 | config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
78 | outputdir = os.path.join(
79 | config_parameters['outputpath'], config_parameters['model'],
80 | "{}_{}".format(
81 |                 datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
82 | uuid.uuid1().hex))
83 | # Early init because of creating dir
84 | checkpoint_handler = ModelCheckpoint(
85 | outputdir,
86 | 'run',
87 | n_saved=3,
88 | require_empty=False,
89 | create_dir=True,
90 | score_function=self._negative_loss,
91 | score_name='loss')
92 | logger = utils.getfile_outlogger(os.path.join(outputdir, 'train.log'))
93 | logger.info("Storing files in {}".format(outputdir))
94 | # utils.pprint_dict
95 | utils.pprint_dict(config_parameters, logger.info)
96 | logger.info("Running on device {}".format(DEVICE))
97 | label_df = pd.read_csv(config_parameters['label'], sep='\s+')
98 | data_df = pd.read_csv(config_parameters['data'], sep='\s+')
99 |         # Keep only filenames that appear in both the data and label lists
100 | merged = data_df.merge(label_df, on='filename')
101 | common_idxs = merged['filename']
102 | data_df = data_df[data_df['filename'].isin(common_idxs)]
103 | label_df = label_df[label_df['filename'].isin(common_idxs)]
104 |
105 | train_df, cv_df = utils.split_train_cv(
106 | label_df, **config_parameters['data_args'])
107 | train_label = utils.df_to_dict(train_df)
108 | cv_label = utils.df_to_dict(cv_df)
109 | data = utils.df_to_dict(data_df)
110 |
111 | transform = utils.parse_transforms(config_parameters['transforms'])
112 | torch.save(config_parameters, os.path.join(outputdir,
113 | 'run_config.pth'))
114 | logger.info("Transforms:")
115 | utils.pprint_dict(transform, logger.info, formatter='pretty')
116 | assert len(cv_df) > 0, "Fraction a bit too large?"
117 |
118 | trainloader = dataset.gettraindataloader(
119 | h5files=data,
120 | h5labels=train_label,
121 | transform=transform,
122 | label_type=config_parameters['label_type'],
123 | batch_size=config_parameters['batch_size'],
124 | num_workers=config_parameters['num_workers'],
125 | shuffle=True,
126 | )
127 |
128 | cvdataloader = dataset.gettraindataloader(
129 | h5files=data,
130 | h5labels=cv_label,
131 | label_type=config_parameters['label_type'],
132 | transform=None,
133 | shuffle=False,
134 | batch_size=config_parameters['batch_size'],
135 | num_workers=config_parameters['num_workers'],
136 | )
137 | model = getattr(models, config_parameters['model'],
138 |                         models.CRNN)(inputdim=trainloader.dataset.datadim,
139 | outputdim=2,
140 | **config_parameters['model_args'])
141 | if 'pretrained' in config_parameters and config_parameters[
142 | 'pretrained'] is not None:
143 | model_dump = torch.load(config_parameters['pretrained'],
144 | map_location='cpu')
145 | model_state = model.state_dict()
146 | pretrained_state = {
147 | k: v
148 | for k, v in model_dump.items()
149 | if k in model_state and v.size() == model_state[k].size()
150 | }
151 | model_state.update(pretrained_state)
152 | model.load_state_dict(model_state)
153 | logger.info("Loading pretrained model {}".format(
154 | config_parameters['pretrained']))
155 |
156 | model = model.to(DEVICE)
157 | optimizer = getattr(
158 | torch.optim,
159 | config_parameters['optimizer'],
160 | )(model.parameters(), **config_parameters['optimizer_args'])
161 |
162 | utils.pprint_dict(optimizer, logger.info, formatter='pretty')
163 | utils.pprint_dict(model, logger.info, formatter='pretty')
164 | if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
165 | logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
166 | model = torch.nn.DataParallel(model)
167 | criterion = getattr(losses, config_parameters['loss'])().to(DEVICE)
168 |
169 | def _train_batch(_, batch):
170 | model.train()
171 | with torch.enable_grad():
172 | optimizer.zero_grad()
173 | output = self._forward(
174 | model, batch) # output is tuple (clip, frame, target)
175 | loss = criterion(*output)
176 | loss.backward()
177 | # Single loss
178 | optimizer.step()
179 | return loss.item()
180 |
181 | def _inference(_, batch):
182 | model.eval()
183 | with torch.no_grad():
184 | return self._forward(model, batch)
185 |
186 | def thresholded_output_transform(output):
187 |             # Output is (clip, frame, targets_time, targets_clip, lengths)
188 | _, y_pred, y, y_clip, length = output
189 | batchsize, timesteps, ndim = y.shape
190 | idxs = torch.arange(timesteps,
191 | device='cpu').repeat(batchsize).view(
192 | batchsize, timesteps)
193 | mask = (idxs < length.view(-1, 1)).to(y.device)
194 | y = y * mask.unsqueeze(-1)
195 | y_pred = torch.round(y_pred)
196 | y = torch.round(y)
197 | return y_pred, y
198 |
199 | metrics = {
200 | 'Loss': losses.Loss(
201 | criterion), #reimplementation of Loss, supports 3 way loss
202 | 'Precision': Precision(thresholded_output_transform),
203 | 'Recall': Recall(thresholded_output_transform),
204 | 'Accuracy': Accuracy(thresholded_output_transform),
205 | }
206 | train_engine = Engine(_train_batch)
207 | inference_engine = Engine(_inference)
208 | for name, metric in metrics.items():
209 | metric.attach(inference_engine, name)
210 |
211 | def compute_metrics(engine):
212 | inference_engine.run(cvdataloader)
213 | results = inference_engine.state.metrics
214 | output_str_list = [
215 | "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
216 | ]
217 | for metric in metrics:
218 | output_str_list.append("{} {:<5.2f}".format(
219 | metric, results[metric]))
220 | logger.info(" ".join(output_str_list))
221 | pbar.n = pbar.last_print_n = 0
222 |
223 | pbar = ProgressBar(persist=False)
224 | pbar.attach(train_engine)
225 |
226 | train_engine.add_event_handler(Events.ITERATION_COMPLETED(every=5000),
227 | compute_metrics)
228 | train_engine.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics)
229 |
230 | early_stop_handler = EarlyStopping(
231 | patience=config_parameters['early_stop'],
232 | score_function=self._negative_loss,
233 | trainer=train_engine)
234 | inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
235 | early_stop_handler)
236 | inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
237 | checkpoint_handler, {
238 | 'model': model,
239 | })
240 |
241 | train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
242 | return outputdir
243 |
244 | def train_evaluate(self,
245 | config,
246 | tasks=['aurora_clean', 'aurora_noisy', 'dcase18'],
247 | **kwargs):
248 | experiment_path = self.train(config, **kwargs)
249 | for task in tasks:
250 | self.evaluate(experiment_path, task=task)
251 |
252 | def predict_time(
253 | self,
254 | experiment_path,
255 | output_h5,
256 |             rfac=2,  # Resolution upscale factor
257 | **kwargs): # overwrite --data
258 |
259 | experiment_path = Path(experiment_path)
260 | if experiment_path.is_file(): # Model is given
261 | model_path = experiment_path
262 | experiment_path = experiment_path.parent
263 | else:
264 | model_path = next(Path(experiment_path).glob("run_model*"))
265 | config = torch.load(next(Path(experiment_path).glob("run_config*")),
266 | map_location=lambda storage, loc: storage)
267 | logger = utils.getfile_outlogger(None)
268 | # Use previous config, but update data such as kwargs
269 | config_parameters = dict(config, **kwargs)
270 | # Default columns to search for in data
271 | encoder = torch.load('labelencoders/vad.pth')
272 | data = config_parameters['data']
273 | dset = dataset.EvalH5Dataset(data)
274 | dataloader = torch.utils.data.DataLoader(dset,
275 | batch_size=1,
276 | num_workers=4,
277 | shuffle=False)
278 |
279 | model = getattr(models, config_parameters['model'])(
280 | inputdim=dataloader.dataset.datadim,
281 | outputdim=len(encoder.classes_),
282 | **config_parameters['model_args'])
283 |
284 | model_parameters = torch.load(
285 | model_path, map_location=lambda storage, loc: storage)
286 | model.load_state_dict(model_parameters)
287 | model = model.to(DEVICE).eval()
288 |
289 | ## VAD preprocessing data
290 | logger.trace(model)
291 |
292 | output_dfs = []
293 |
294 | speech_label_idx = np.where('Speech' == encoder.classes_)[0].squeeze()
295 | non_speech_idx = np.arange(len(encoder.classes_))
296 | non_speech_idx = np.delete(non_speech_idx, speech_label_idx)
297 | speech_frame_predictions, speech_frame_prob_predictions = [], []
298 | with torch.no_grad(), tqdm(total=len(dataloader),
299 | leave=False,
300 | unit='clip') as pbar, File(output_h5,
301 | 'w') as store:
302 | for feature, filename in dataloader:
303 | feature = torch.as_tensor(feature).to(DEVICE)
304 | filename = Path(filename[0]).stem
305 | batch, time, dim = feature.shape
306 | # PANNS output a dict instead of 2 values
307 | prediction_tag, prediction_time = model(feature,
308 | upsample=False)
309 | prediction_tag = prediction_tag.to('cpu')
310 | prediction_time = torch.nn.functional.interpolate(
311 | prediction_time.transpose(1, 2),
312 | int(time * rfac),
313 | mode='linear',
314 | align_corners=False).transpose(1, 2)
315 | prediction_time = prediction_time.to('cpu').squeeze(0)
316 | speech_label_pred = prediction_time[
317 | ..., speech_label_idx].squeeze(-1)
318 | noise_label_pred = prediction_time[...,
319 | non_speech_idx].squeeze(-1)
320 | store[f'{filename}/speech'] = speech_label_pred
321 | store[f'{filename}/noise'] = noise_label_pred
322 | pbar.set_postfix(time=time,
323 | fname=filename,
324 | speech=speech_label_pred.shape,
325 | noise=noise_label_pred.shape)
326 | pbar.update()
327 |
328 | def predict_clip(self,
329 | experiment_path,
330 | output_csv,
331 | thres=0.5,
332 | **kwargs): # overwrite --data
333 | import h5py
334 | from sklearn.preprocessing import binarize
335 | from tqdm import tqdm
336 | config = torch.load(list(Path(experiment_path).glob("run_config*"))[0],
337 | map_location=lambda storage, loc: storage)
338 | config_parameters = dict(config, **kwargs)
339 | model_parameters = torch.load(
340 | list(Path(experiment_path).glob("run_model*"))[0],
341 | map_location=lambda storage, loc: storage)
342 | encoder = torch.load('labelencoders/vad.pth')
343 |
344 | predictions = []
345 | with h5py.File(config_parameters['data'],
346 | 'r') as input_store, torch.no_grad(), tqdm(
347 | total=len(input_store)) as pbar:
348 | inputdim = next(iter(input_store.values())).shape[-1]
349 | model = getattr(models, config_parameters['model'])(
350 | inputdim=inputdim,
351 | outputdim=len(encoder.classes_),
352 | **config_parameters['model_args'])
353 | model.load_state_dict(model_parameters)
354 | model = model.to(DEVICE).eval()
355 | for fname, sample in input_store.items():
356 | if sample.ndim > 1: # Global mean and Global_var might also be there
357 | sample = torch.as_tensor(sample[()]).unsqueeze(0).to(
358 | DEVICE) # batch + channel
359 | decision, _ = model(sample)
360 | decision = binarize(decision.to('cpu'), threshold=thres)
361 | pred_labels = encoder.inverse_transform(decision)[0]
362 | pbar.set_postfix(labels=pred_labels, file=fname)
363 | if len(pred_labels) > 0:
364 | predictions.append({
365 | 'filename':
366 | fname,
367 | 'event_labels':
368 | ",".join(pred_labels)
369 | })
370 | pbar.update()
371 |
372 | df = pd.DataFrame(predictions)
373 | df.to_csv(output_csv, sep='\t', index=False)
374 |
375 | def evaluate(self,
376 | experiment_path: Path,
377 | task: str = 'aurora_clean',
378 | model_resolution=0.02,
379 | time_resolution=0.02,
380 | threshold=(0.5, 0.1),
381 | **kwargs):
382 | EVALUATION_DATA = {
383 | 'aurora_clean': {
384 | 'data': 'data/evaluation/hdf5/aurora_clean.h5',
385 | 'label': 'data/evaluation/labels/aurora_clean_labels.tsv',
386 | },
387 | 'aurora_noisy': {
388 | 'data': 'data/evaluation/hdf5/aurora_noisy.h5',
389 | 'label': 'data/evaluation/labels/aurora_noisy_labels.tsv'
390 | },
391 | 'dihard_dev': {
392 | 'data': 'data/evaluation/hdf5/dihard_dev.h5',
393 | 'label': 'data/evaluation/labels/dihard_dev.csv'
394 | },
395 | 'dihard_eval': {
396 | 'data': 'data/evaluation/hdf5/dihard_eval.h5',
397 | 'label': 'data/evaluation/labels/dihard_eval.csv'
398 | },
399 | 'aurora_snr_20': {
400 | 'data':
401 | 'data/evaluation/hdf5/aurora_noisy_musan_snr_20.0.hdf5',
402 | 'label': 'data/evaluation/labels/musan_labels.tsv'
403 | },
404 | 'aurora_snr_15': {
405 | 'data':
406 | 'data/evaluation/hdf5/aurora_noisy_musan_snr_15.0.hdf5',
407 | 'label': 'data/evaluation/labels/musan_labels.tsv'
408 | },
409 | 'aurora_snr_10': {
410 | 'data':
411 | 'data/evaluation/hdf5/aurora_noisy_musan_snr_10.0.hdf5',
412 | 'label': 'data/evaluation/labels/musan_labels.tsv'
413 | },
414 | 'aurora_snr_5': {
415 | 'data': 'data/evaluation/hdf5/aurora_noisy_musan_snr_5.0.hdf5',
416 | 'label': 'data/evaluation/labels/musan_labels.tsv'
417 | },
418 | 'aurora_snr_0': {
419 | 'data': 'data/evaluation/hdf5/aurora_noisy_musan_snr_0.0.hdf5',
420 | 'label': 'data/evaluation/labels/musan_labels.tsv'
421 | },
422 | 'aurora_snr_-5': {
423 | 'data':
424 | 'data/evaluation/hdf5/aurora_noisy_musan_snr_-5.0.hdf5',
425 | 'label': 'data/evaluation/labels/musan_labels.tsv'
426 | },
427 | 'dcase18': {
428 | 'data': 'data/evaluation/hdf5/dcase18.h5',
429 | 'label': 'data/evaluation/labels/dcase18.tsv',
430 | },
431 | }
432 | assert task in EVALUATION_DATA, f"--task {'|'.join(list(EVALUATION_DATA.keys()))}"
433 | experiment_path = Path(experiment_path)
434 | if experiment_path.is_file(): # Model is given
435 | model_path = experiment_path
436 | experiment_path = experiment_path.parent
437 | else:
438 | model_path = next(Path(experiment_path).glob("run_model*"))
439 | config = torch.load(next(Path(experiment_path).glob("run_config*")),
440 | map_location='cpu')
441 | logger = utils.getfile_outlogger(None)
442 | # Use previous config, but update data such as kwargs
443 | config_parameters = dict(config, **kwargs)
444 | # Default columns to search for in data
445 | model_parameters = torch.load(
446 | model_path, map_location=lambda storage, loc: storage)
447 | encoder = torch.load('labelencoders/vad.pth')
448 | data = EVALUATION_DATA[task]['data']
449 | label_df = pd.read_csv(EVALUATION_DATA[task]['label'], sep='\s+')
450 | label_df['filename'] = label_df['filename'].apply(
451 | lambda x: Path(x).name)
452 | logger.info(f"Label_df shape is {label_df.shape}")
453 |
454 | dset = dataset.EvalH5Dataset(data,
455 | fnames=np.unique(
456 | label_df['filename'].values))
457 |
458 | dataloader = torch.utils.data.DataLoader(dset,
459 | batch_size=1,
460 | num_workers=4,
461 | shuffle=False)
462 |
463 | model = getattr(models, config_parameters['model'])(
464 | inputdim=dataloader.dataset.datadim,
465 | outputdim=len(encoder.classes_),
466 | **config_parameters['model_args'])
467 |
468 | model.load_state_dict(model_parameters)
469 | model = model.to(DEVICE).eval()
470 |
471 | ## VAD preprocessing data
472 | vad_label_helper_df = label_df.copy()
473 | vad_label_helper_df['onset'] = np.ceil(vad_label_helper_df['onset'] /
474 | model_resolution).astype(int)
475 | vad_label_helper_df['offset'] = np.ceil(vad_label_helper_df['offset'] /
476 | model_resolution).astype(int)
477 |
478 | vad_label_helper_df = vad_label_helper_df.groupby(['filename']).agg({
479 | 'onset':
480 | tuple,
481 | 'offset':
482 | tuple,
483 | 'event_label':
484 | tuple
485 | }).reset_index()
486 | logger.trace(model)
487 |
488 | output_dfs = []
489 |
490 | speech_label_idx = np.where('Speech' == encoder.classes_)[0].squeeze()
491 | speech_frame_predictions, speech_frame_ground_truth, speech_frame_prob_predictions = [], [],[]
492 |         # A single threshold uses plain binarization, two thresholds enable double thresholding
493 | if len(threshold) == 1:
494 | postprocessing_method = utils.binarize
495 | else:
496 | postprocessing_method = utils.double_threshold
497 | with torch.no_grad(), tqdm(total=len(dataloader),
498 | leave=False,
499 | unit='clip') as pbar:
500 | for feature, filename in dataloader:
501 | feature = torch.as_tensor(feature).to(DEVICE)
502 | # PANNS output a dict instead of 2 values
503 | prediction_tag, prediction_time = model(feature)
504 | prediction_tag = prediction_tag.to('cpu')
505 | prediction_time = prediction_time.to('cpu')
506 |
507 | if prediction_time is not None: # Some models do not predict timestamps
508 |
509 | cur_filename = filename[0]
510 |
511 | thresholded_prediction = postprocessing_method(
512 | prediction_time, *threshold)
513 |
514 | ## VAD predictions
515 | speech_frame_prob_predictions.append(
516 | prediction_time[..., speech_label_idx].squeeze())
517 | ### Thresholded speech predictions
518 | speech_prediction = thresholded_prediction[
519 | ..., speech_label_idx].squeeze()
520 | speech_frame_predictions.append(speech_prediction)
521 | targets = vad_label_helper_df[
522 | vad_label_helper_df['filename'] == cur_filename][[
523 | 'onset', 'offset'
524 | ]].values[0]
525 | target_arr = np.zeros_like(speech_prediction)
526 | for start, end in zip(*targets):
527 | target_arr[start:end] = 1
528 | speech_frame_ground_truth.append(target_arr)
529 |
530 | #### SED predictions
531 |
532 | labelled_predictions = utils.decode_with_timestamps(
533 | encoder, thresholded_prediction)
534 | pred_label_df = pd.DataFrame(
535 | labelled_predictions[0],
536 | columns=['event_label', 'onset', 'offset'])
537 | if not pred_label_df.empty:
538 | pred_label_df['filename'] = cur_filename
539 | pred_label_df['onset'] *= model_resolution
540 | pred_label_df['offset'] *= model_resolution
541 | pbar.set_postfix(labels=','.join(
542 | np.unique(pred_label_df['event_label'].values)))
543 | pbar.update()
544 | output_dfs.append(pred_label_df)
545 |
546 | full_prediction_df = pd.concat(output_dfs)
547 | prediction_df = full_prediction_df[full_prediction_df['event_label'] ==
548 | 'Speech']
549 | assert set(['onset', 'offset', 'filename', 'event_label'
550 | ]).issubset(prediction_df.columns), "Format is wrong"
551 | assert set(['onset', 'offset', 'filename', 'event_label'
552 | ]).issubset(label_df.columns), "Format is wrong"
553 | logger.info("Calculating VAD measures ... ")
554 | speech_frame_ground_truth = np.concatenate(speech_frame_ground_truth,
555 | axis=0)
556 | speech_frame_predictions = np.concatenate(speech_frame_predictions,
557 | axis=0)
558 | speech_frame_prob_predictions = np.concatenate(
559 | speech_frame_prob_predictions, axis=0)
560 |
561 | vad_results = []
562 | tn, fp, fn, tp = metrics.confusion_matrix(
563 | speech_frame_ground_truth, speech_frame_predictions).ravel()
564 | fer = 100 * ((fp + fn) / len(speech_frame_ground_truth))
565 | acc = 100 * ((tp + tn) / (len(speech_frame_ground_truth)))
566 |
567 | p_miss = 100 * (fn / (fn + tp))
568 | p_fa = 100 * (fp / (fp + tn))
569 | for i in [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 0.7,0.9]:
570 | mp_fa, mp_miss = metrics.obtain_error_rates(
571 | speech_frame_ground_truth, speech_frame_prob_predictions, i)
572 | tn, fp, fn, tp = metrics.confusion_matrix(
573 | speech_frame_ground_truth,
574 | speech_frame_prob_predictions > i).ravel()
575 | sub_fer = 100 * ((fp + fn) / len(speech_frame_ground_truth))
576 | logger.info(
577 | f"PFa {100*mp_fa:.2f} Pmiss {100*mp_miss:.2f} FER {sub_fer:.2f} t: {i:.2f}"
578 | )
579 |
580 | auc = metrics.roc(speech_frame_ground_truth,
581 | speech_frame_prob_predictions) * 100
582 | for avgtype in ('micro', 'macro', 'binary'):
583 | precision, recall, f1, _ = metrics.precision_recall_fscore_support(
584 | speech_frame_ground_truth,
585 | speech_frame_predictions,
586 | average=avgtype)
587 | vad_results.append(
588 | (avgtype, 100 * precision, 100 * recall, 100 * f1))
589 |
590 | logger.info("Calculating segment based metric .. ")
591 | # Change order just for better printing in file
592 | prediction_df = prediction_df[[
593 | 'filename', 'onset', 'offset', 'event_label'
594 | ]]
595 | metric = metrics.segment_based_evaluation_df(
596 | label_df, prediction_df, time_resolution=time_resolution)
597 | logger.info("Calculating event based metric .. ")
598 | event_metric = metrics.event_based_evaluation_df(
599 | label_df, prediction_df)
600 |
601 | prediction_df.to_csv(experiment_path /
602 | f'speech_predictions_{task}.tsv',
603 | sep='\t',
604 | index=False)
605 | full_prediction_df.to_csv(experiment_path / f'predictions_{task}.tsv',
606 | sep='\t',
607 | index=False)
608 | with open(experiment_path / f'evaluation_{task}.txt', 'w') as fp:
609 | for k, v in config_parameters.items():
610 | print(f"{k}:{v}", file=fp)
611 | print(metric, file=fp)
612 | print(event_metric, file=fp)
613 | for avgtype, precision, recall, f1 in vad_results:
614 | print(
615 | f"VAD {avgtype} F1: {f1:<10.3f} {precision:<10.3f} Recall: {recall:<10.3f}",
616 | file=fp)
617 | print(f"FER: {fer:.2f}", file=fp)
618 | print(f"AUC: {auc:.2f}", file=fp)
619 | print(f"Pfa: {p_fa:.2f}", file=fp)
620 | print(f"Pmiss: {p_miss:.2f}", file=fp)
621 | print(f"ACC: {acc:.2f}", file=fp)
622 | logger.info(f"Results are at {experiment_path}")
623 | for avgtype, precision, recall, f1 in vad_results:
624 | print(
625 | f"VAD {avgtype:<10} F1: {f1:<10.3f} Pre: {precision:<10.3f} Recall: {recall:<10.3f}"
626 | )
627 | print(f"FER: {fer:.2f}")
628 | print(f"AUC: {auc:.2f}")
629 | print(f"Pfa: {p_fa:.2f}")
630 | print(f"Pmiss: {p_miss:.2f}")
631 | print(f"ACC: {acc:.2f}")
632 | print(event_metric)
633 | print(metric)
634 |
635 |
636 | if __name__ == "__main__":
637 | fire.Fire(Runner)
638 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import collections
5 | import sys
6 | from loguru import logger
7 | from pprint import pformat
8 | from typing import List
9 |
10 | import numpy as np
11 | import pandas as pd
12 | import scipy
13 | import six
14 | import sklearn.preprocessing as pre
15 | import torch
16 | import tqdm
17 | import yaml
18 |
19 | import augment
20 | import dataset
21 |
22 | # Some defaults for non-specified arguments in yaml
23 | DEFAULT_ARGS = {
24 | 'outputpath': 'experiments',
25 | 'loss': 'BCELoss',
26 | 'batch_size': 64,
27 | 'num_workers': 4,
28 | 'epochs': 100,
29 | 'transforms': [],
30 | 'label_type':'soft',
31 | 'scheduler_args': {
32 | 'patience': 3,
33 | 'factor': 0.1,
34 | },
35 | 'early_stop': 7,
36 | 'optimizer': 'Adam',
37 | 'optimizer_args': {
38 | 'lr': 0.001,
39 | },
40 | 'threshold': None, #Default threshold for postprocessing function
41 | 'postprocessing': 'double',
42 | }
43 |
44 |
45 | def parse_config_or_kwargs(config_file, **kwargs):
46 | """parse_config_or_kwargs
47 |
48 | :param config_file: Config file that has parameters, yaml format
49 | :param **kwargs: Other alternative parameters or overwrites for config
50 | """
51 | with open(config_file) as con_read:
52 | yaml_config = yaml.load(con_read, Loader=yaml.FullLoader)
53 | # values from config file are all possible params
54 | arguments = dict(yaml_config, **kwargs)
55 | # In case some arguments were not passed, replace with default ones
56 | for key, value in DEFAULT_ARGS.items():
57 | arguments.setdefault(key, value)
58 | return arguments
59 |
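# Minimal config sketch (added comment; paths are placeholders, see
# configs/example.yaml for the real file). Omitted keys fall back to
# DEFAULT_ARGS above:
#     model: CRNN
#     model_args: {}
#     data: data/flists/train.tsv        # placeholder
#     label: data/labels/train.tsv       # placeholder
#     data_args:
#       frac: 0.9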
60 |
61 | def find_contiguous_regions(activity_array):
62 | """Find contiguous regions from bool valued numpy.array.
63 | Copy of https://dcase-repo.github.io/dcase_util/_modules/dcase_util/data/decisions.html#DecisionEncoder
64 |
65 | Reason is:
66 | 1. This does not belong to a class necessarily
67 |     2. Importing DecisionEncoder pulls in sndfile through other imports, which causes problems on some clusters
68 |
69 | """
70 |
71 | # Find the changes in the activity_array
72 | change_indices = np.logical_xor(activity_array[1:],
73 | activity_array[:-1]).nonzero()[0]
74 |
75 | # Shift change_index with one, focus on frame after the change.
76 | change_indices += 1
77 |
78 | if activity_array[0]:
79 | # If the first element of activity_array is True add 0 at the beginning
80 | change_indices = np.r_[0, change_indices]
81 |
82 | if activity_array[-1]:
83 | # If the last element of activity_array is True, add the length of the array
84 | change_indices = np.r_[change_indices, activity_array.size]
85 |
86 | # Reshape the result into two columns
87 | return change_indices.reshape((-1, 2))
88 |
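# Worked example (added comment, not original code): for
#     activity_array = np.array([False, True, True, False, True])
# the change indices are [1, 3, 4]; since the last frame is active, the array
# length (5) is appended, giving regions [[1, 3], [4, 5]] as half-open
# [onset, offset) frame-index pairs.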
89 |
90 | def split_train_cv(input_data, frac: float = 0.9, **kwargs):
91 | """split_train_cv
92 |
93 |     :param input_data: list of items or pd.DataFrame to split
94 | :param frac:
95 | :type frac: float
96 | """
97 | if isinstance(input_data, list):
98 | N = len(input_data)
99 | indicies = np.random.permutation(N)
100 | train_size = round(N * frac)
101 | cv_size = N - train_size
102 |         train_idxs, cv_idxs = indicies[:train_size], indicies[train_size:]
103 | input_data = np.array(input_data)
104 | return input_data[train_idxs].tolist(), input_data[cv_idxs].tolist()
105 | elif isinstance(input_data, pd.DataFrame):
106 | train_df = input_data.sample(frac=frac)
107 | cv_df = input_data[~input_data.index.isin(train_df.index)]
108 | return train_df, cv_df
109 |
110 |
111 | def parse_transforms(transform_list):
112 | """parse_transforms
113 |     parses the config file's transformation strings into the corresponding augmentation modules
114 |
115 | :param transform_list: String list
116 | """
117 | transforms = []
118 | for trans in transform_list:
119 | if trans == 'noise':
120 | transforms.append(augment.GaussianNoise(snr=25))
121 | elif trans == 'roll':
122 | transforms.append(augment.Roll(0, 10))
123 | elif trans == 'freqmask':
124 | transforms.append(augment.FreqMask(2, 8))
125 | elif trans == 'timemask':
126 | transforms.append(augment.TimeMask(2, 60))
127 | elif trans == 'crop':
128 | transforms.append(augment.RandomCrop(200))
129 | elif trans == 'randompad':
130 | transforms.append(augment.RandomPad(value=0., padding=25))
131 | elif trans == 'flipsign':
132 | transforms.append(augment.FlipSign())
133 | elif trans == 'shift':
134 | transforms.append(augment.Shift())
135 | return torch.nn.Sequential(*transforms)
136 |
137 |
138 | def pprint_dict(in_dict, outputfun=sys.stdout.write, formatter='yaml'):
139 | """pprint_dict
140 |
141 | :param outputfun: function to use, defaults to sys.stdout
142 | :param in_dict: dict to print
143 | """
144 | if formatter == 'yaml':
145 | format_fun = yaml.dump
146 | elif formatter == 'pretty':
147 | format_fun = pformat
148 | for line in format_fun(in_dict).split('\n'):
149 | outputfun(line)
150 |
151 |
152 | def getfile_outlogger(outputfile):
153 | log_format = "[{time:YYYY-MM-DD HH:mm:ss}] {message}"
154 | logger.configure(handlers=[{"sink": sys.stderr, "format": log_format}])
155 | if outputfile:
156 | logger.add(outputfile, enqueue=True, format=log_format)
157 | return logger
158 |
159 |
160 | def train_labelencoder(labels: pd.Series, sparse=True):
161 | """encode_labels
162 |
163 | Encodes labels
164 |
165 | :param labels: pd.Series representing the raw labels e.g., Speech, Water
166 | :param encoder (optional): Encoder already fitted
167 | returns encoded labels (many hot) and the encoder
168 | """
169 | assert isinstance(labels, pd.Series), "Labels need to be series"
170 | if isinstance(labels[0], six.string_types):
171 | # In case of using non processed strings, e.g., Vaccum, Speech
172 | label_array = labels.str.split(',').values.tolist()
173 | elif isinstance(labels[0], np.ndarray):
174 | # Encoder does not like to see numpy array
175 | label_array = [lab.tolist() for lab in labels]
176 | elif isinstance(labels[0], collections.Iterable):
177 | label_array = labels
178 | encoder = pre.MultiLabelBinarizer(sparse_output=sparse)
179 | encoder.fit(label_array)
180 | return encoder
181 |
182 |
183 | def encode_labels(labels: pd.Series, encoder=None, sparse=True):
184 | """encode_labels
185 |
186 | Encodes labels
187 |
188 | :param labels: pd.Series representing the raw labels e.g., Speech, Water
189 | :param encoder (optional): Encoder already fitted
190 | returns encoded labels (many hot) and the encoder
191 | """
192 | assert isinstance(labels, pd.Series), "Labels need to be series"
193 | instance = labels.iloc[0]
194 | if isinstance(instance, six.string_types):
195 | # In case of using non processed strings, e.g., Vaccum, Speech
196 | label_array = labels.str.split(',').values.tolist()
197 | elif isinstance(instance, np.ndarray):
198 | # Encoder does not like to see numpy array
199 | label_array = [lab.tolist() for lab in labels]
200 | elif isinstance(instance, collections.Iterable):
201 | label_array = labels
202 | if not encoder:
203 | encoder = pre.MultiLabelBinarizer(sparse_output=sparse)
204 | encoder.fit(label_array)
205 | labels_encoded = encoder.transform(label_array)
206 | return labels_encoded, encoder
207 |
208 | # return pd.arrays.SparseArray(
209 | # [row.toarray().ravel() for row in labels_encoded]), encoder
210 |
211 |
212 | def decode_with_timestamps(encoder: pre.MultiLabelBinarizer, labels: np.array):
213 | """decode_with_timestamps
214 | Decodes the predicted label array (2d) into a list of
215 | [(Labelname, onset, offset), ...]
216 |
217 | :param encoder: Encoder during training
218 | :type encoder: pre.MultiLabelBinarizer
219 | :param labels: n-dim array
220 | :type labels: np.array
221 | """
222 | if labels.ndim == 3:
223 | return [_decode_with_timestamps(encoder, lab) for lab in labels]
224 | else:
225 | return _decode_with_timestamps(encoder, labels)
226 |
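# Illustrative example (added comment): assuming encoder.classes_ == ['Noise', 'Speech']
# and a (T, 2) binary prediction whose 'Speech' column is active for frames 10..19
# only, decode_with_timestamps returns [('Speech', 10, 20)] (offsets are exclusive);
# run.py then multiplies onset/offset by the model resolution to obtain seconds.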
227 |
228 | def sma_filter(x, window_size, axis=1):
229 | """sma_filter
230 |
231 | :param x: Input numpy array,
232 | :param window_size: filter size
233 | :param axis: over which axis ( usually time ) to apply
234 | """
235 | # 1 is time axis
236 | kernel = np.ones((window_size, )) / window_size
237 |
238 | def moving_average(arr):
239 | return np.convolve(arr, kernel, 'same')
240 |
241 | return np.apply_along_axis(moving_average, axis, x)
242 |
243 |
244 | def median_filter(x, window_size, threshold=0.5):
245 | """median_filter
246 |
247 | :param x: input prediction array of shape (B, T, C) or (B, T).
248 | Input is a sequence of probabilities 0 <= x <= 1
249 | :param window_size: An integer to use
250 | :param threshold: Binary thresholding threshold
251 | """
252 | x = binarize(x, threshold=threshold)
253 | if x.ndim == 3:
254 | size = (1, window_size, 1)
255 | elif x.ndim == 2 and x.shape[0] == 1:
256 | # Assume input is class-specific median filtering
257 | # E.g, Batch x Time [1, 501]
258 | size = (1, window_size)
259 | elif x.ndim == 2 and x.shape[0] > 1:
260 | # Assume input is standard median pooling, class-independent
261 | # E.g., Time x Class [501, 10]
262 | size = (window_size, 1)
263 | return scipy.ndimage.median_filter(x, size=size)
264 |
265 |
266 | def _decode_with_timestamps(encoder, labels):
267 | result_labels = []
268 | for i, label_column in enumerate(labels.T):
269 | change_indices = find_contiguous_regions(label_column)
270 | # append [onset, offset] in the result list
271 | for row in change_indices:
272 | result_labels.append((encoder.classes_[i], row[0], row[1]))
273 | return result_labels
274 |
275 |
276 | def inverse_transform_labels(encoder, pred):
277 | if pred.ndim == 3:
278 | return [encoder.inverse_transform(x) for x in pred]
279 | else:
280 | return encoder.inverse_transform(pred)
281 |
282 |
283 | def binarize(pred, threshold=0.5):
284 | # Batch_wise
285 | if pred.ndim == 3:
286 | return np.array(
287 | [pre.binarize(sub, threshold=threshold) for sub in pred])
288 | else:
289 | return pre.binarize(pred, threshold=threshold)
290 |
291 |
292 | def double_threshold(x, high_thres, low_thres, n_connect=1):
293 | """double_threshold
294 | Helper function to calculate double threshold for n-dim arrays
295 |
296 | :param x: input array
297 | :param high_thres: high threshold value
298 | :param low_thres: Low threshold value
299 | :param n_connect: Distance of <= n clusters will be merged
300 | """
301 | assert x.ndim <= 3, "Whoops something went wrong with the input ({}), check if its <= 3 dims".format(
302 | x.shape)
303 | if x.ndim == 3:
304 | apply_dim = 1
305 | elif x.ndim < 3:
306 | apply_dim = 0
307 | # x is assumed to be 3d: (batch, time, dim)
308 | # Assumed to be 2d : (time, dim)
309 | # Assumed to be 1d : (time)
310 | # time axis is therefore at 1 for 3d and 0 for 2d (
311 | return np.apply_along_axis(lambda x: _double_threshold(
312 | x, high_thres, low_thres, n_connect=n_connect),
313 | axis=apply_dim,
314 | arr=x)
315 |
316 |
317 | def _double_threshold(x, high_thres, low_thres, n_connect=1, return_arr=True):
318 | """_double_threshold
319 | Computes a double threshold over the input array
320 |
321 | :param x: input array, needs to be 1d
322 | :param high_thres: High threshold over the array
323 | :param low_thres: Low threshold over the array
324 | :param n_connect: Postprocessing, maximal distance between clusters to connect
325 |     :param return_arr: By default this function returns the filtered indices, but if return_arr = True it returns an array of the same size as x filled with ones and zeros.
326 | """
327 | assert x.ndim == 1, "Input needs to be 1d"
328 | high_locations = np.where(x > high_thres)[0]
329 | locations = x > low_thres
330 | encoded_pairs = find_contiguous_regions(locations)
331 |
332 | filtered_list = list(
333 | filter(
334 | lambda pair:
335 | ((pair[0] <= high_locations) & (high_locations <= pair[1])).any(),
336 | encoded_pairs))
337 |
338 | filtered_list = connect_(filtered_list, n_connect)
339 | if return_arr:
340 | zero_one_arr = np.zeros_like(x, dtype=int)
341 | for sl in filtered_list:
342 | zero_one_arr[sl[0]:sl[1]] = 1
343 | return zero_one_arr
344 | return filtered_list
345 |
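# Worked example (added comment, not original code): with
#     x = np.array([0.2, 0.6, 0.9, 0.6, 0.2, 0.6, 0.1])
# double_threshold(x, high_thres=0.8, low_thres=0.5) keeps only the low-threshold
# regions that contain at least one frame above the high threshold, yielding
#     [0, 1, 1, 1, 0, 0, 0]
# i.e. the isolated 0.6 at index 5 is discarded because it never crosses 0.8.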
346 |
347 | def connect_clusters(x, n=1):
348 | if x.ndim == 1:
349 | return connect_clusters_(x, n)
350 | if x.ndim >= 2:
351 | return np.apply_along_axis(lambda a: connect_clusters_(a, n=n), -2, x)
352 |
353 |
354 | def connect_clusters_(x, n=1):
355 | """connect_clusters_
356 | Connects clustered predictions (0,1) in x with range n
357 |
358 | :param x: Input array. zero-one format
359 | :param n: Number of frames to skip until connection can be made
360 | """
361 | assert x.ndim == 1, "input needs to be 1d"
362 | reg = find_contiguous_regions(x)
363 | start_end = connect_(reg, n=n)
364 | zero_one_arr = np.zeros_like(x, dtype=int)
365 | for sl in start_end:
366 | zero_one_arr[sl[0]:sl[1]] = 1
367 | return zero_one_arr
368 |
369 |
370 | def connect_(pairs, n=1):
371 | """connect_
372 | Connects two adjacent clusters if their distance is <= n
373 |
374 |     :param pairs: Clusters as iterables of (start, end) pairs, e.g., [(1,5),(7,10)]
375 | :param n: distance between two clusters
376 | """
377 | if len(pairs) == 0:
378 | return []
379 | start_, end_ = pairs[0]
380 | new_pairs = []
381 | for i, (next_item, cur_item) in enumerate(zip(pairs[1:], pairs[0:])):
382 | end_ = next_item[1]
383 | if next_item[0] - cur_item[1] <= n:
384 | pass
385 | else:
386 | new_pairs.append((start_, cur_item[1]))
387 | start_ = next_item[0]
388 | new_pairs.append((start_, end_))
389 | return new_pairs
390 |
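# Illustrative example (added comment): connect_([(1, 5), (7, 10)], n=3) merges the
# two clusters into [(1, 10)] because the gap (7 - 5 = 2) is <= n, whereas n=1
# keeps them as [(1, 5), (7, 10)].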
391 |
392 | def predictions_to_time(df, ratio):
393 | df.onset = df.onset * ratio
394 | df.offset = df.offset * ratio
395 | return df
396 |
397 |
398 | def estimate_scaler(dataloader, **scaler_args):
399 |
400 | scaler = pre.StandardScaler(**scaler_args)
401 | with tqdm.tqdm(total=len(dataloader),
402 | unit='batch',
403 | leave=False,
404 | desc='Estimating Scaler') as pbar:
405 | for batch in dataloader:
406 | feature = batch[0]
407 | # Flatten time and batch dim to one
408 | feature = feature.reshape(-1, feature.shape[-1])
409 | pbar.set_postfix(feature=feature.shape)
410 | pbar.update()
411 | scaler.partial_fit(feature)
412 | return scaler
413 |
414 |
415 | def rescale_0_1(x):
416 |     if x.ndim == 2:
417 |         return pre.minmax_scale(x, axis=0)
418 |     else:
419 |         # Batched input, e.g., (batch, time, dim):
420 |         # scale each sample independently over its time axis
421 |         return np.array([pre.minmax_scale(sub, axis=0) for sub in x])
422 |
423 | def df_to_dict(df, index='filename', value='hdf5path'):
424 | return dict(zip(df[index],df[value]))
425 |
--------------------------------------------------------------------------------