├── .gitignore
├── .idea
└── encodings.xml
├── LICENSE
├── README.md
├── configs
├── mfcc.conf
└── vbdiar.yml
├── examples
├── diarization.py
├── lists
│ ├── AMI_dev-eval.scp
│ ├── AMI_trn.scp
│ ├── list.scp
│ └── list_spk.scp
├── run.sh
├── vad
│ └── fisher-english-p1
│ │ ├── fe_03_00001-a.lab.gz
│ │ └── fe_03_00001-b.lab.gz
└── wav
│ └── fisher-english-p1
│ ├── README
│ ├── fe_03_00001-a.wav
│ ├── fe_03_00001-b.wav
│ └── smalltalk0501.wav
├── models
├── LDA.npy
├── final.onnx
├── gplda
│ ├── CB.npy
│ ├── CW.npy
│ └── mu.npy
└── mean.npy
├── requirements.txt
├── setup.py
└── vbdiar
├── __init__.py
├── clustering
├── __init__.py
└── pldakmeans.py
├── embeddings
├── __init__.py
└── embedding.py
├── features
├── __init__.py
└── segments.py
├── kaldi
├── __init__.py
├── kaldi_xvector_extraction.py
├── mfcc_features_extraction.py
├── onnx_xvector_extraction.py
└── utils.py
├── scoring
├── __init__.py
├── diarization.py
├── gplda.py
├── md-eval.pl
└── normalization.py
├── utils
├── __init__.py
└── utils.py
└── vad
├── __init__.py
└── vad.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # vbdiar
2 |
3 | This project is DEPRECATED and moved to https://github.com/BUTSpeechFIT/VBx, which achieves better results.
4 |
5 | Speaker diarization based on x-vectors using pretrained model trained in Kaldi (https://github.com/kaldi-asr/kaldi)
6 | and converted to ONNX format (https://github.com/onnx/onnx) running in ONNXRuntime (https://github.com/Microsoft/onnxruntime).
7 |
8 | X-vector model was trained using VoxCeleb1 and VoxCeleb2 16k data (http://www.robots.ox.ac.uk/~vgg/data/voxceleb/index.html#about).
9 |
10 | If you make use of the code or model, cite this: https://www.vutbr.cz/en/students/final-thesis/detail/122072
11 |
12 |
13 | ## Dependencies
14 |
15 | Dependencies are listed in `requirements.txt`.
16 |
17 | ## Installation
18 |
19 | It is recommended to use anaconda environment https://www.anaconda.com/download/.
20 | Run `python setup.py install`
21 | Also, since we are using Kaldi, path to Kaldi root must be set in `vbdiar/kaldi/__init__.py`
22 |
23 | ## Configs
24 |
25 | Config file declares used models and paths to them. Example configuration file is `configs/vbdiar.yml`.
26 |
27 | ## Models
28 |
29 | Pretrained models are stored in `models/` directory.
30 |
31 | ## Examples
32 |
33 | Example script `examples/diarization.py` is able to run full diarization process. The code is designed in a way, that you have everything in same tree structure with relative paths in list and then you just specify directories - audio, VAD, output, etc. See example configuration.
34 |
35 | ### Required Arguments
36 |
37 | `'-l', '--input-list'` - specifies relative path to files for testing, it is possible to specify number of speakers as the second column. Do not use file suffixes, path is always relative to input directory and suffix.
38 |
39 | `'-c', '--configuration'` - specifies configuration file/
40 |
41 | `'-m', '--mode'` - specifies running mode, there are two possible modes, classic `diarization` mode which should segment
42 | utterance into speakers and `sre` mode used for speaker recognition, which runs clustering for N iterations and saves all clusters
43 |
44 | ### Non-required Arguments
45 |
46 | `'--audio-dir'` - directory with audio files in `.wav` format - `8000Hz, 16bit-s, 1c`.
47 |
48 | `'--vad-dir'` - directory with lab files - Voice/Speech activity detection - format `speech_start speech_end`.
49 |
50 | `'--in-emb-dir'` - input directory containing embeddings (if they were previously saved).
51 |
52 | `'--out-emb-dir'` - output directory for storing embeddings.
53 |
54 | `'--norm-list'` - input list with files for score normalization. When performing score normalization, it is necessary to use input ground truth `.rttm` files with unique speaker label. Speaker labels should not overlap, only in case, that there is same speaker in more audio files. All normalization utterances will be merged by speaker labels.
55 |
56 | `'--in-rttm-dir'` - input directory with `.rttm` files (used primary for score normalization)
57 |
58 | `'--out-rttm-dir'` - output directory for storing `.rttm` files
59 |
60 | `'--min-window-size'` - minimal size of embedding window in miliseconds. Defines minimal size used for clustering algorithms.
61 |
62 | `'--max-window-size'` - maximal size of embedding window in miliseconds.
63 |
64 | `'--vad-tolerance'` - skip `n` frames of non-speech and merge them as speech.
65 |
66 | `'--max-num-speakers'` - maximal number of speakers. Used in clustering algorithm.
67 |
68 | `'--use-gpu'` - use GPU instead of cpu (onnxruntime-gpu must be installed)
69 |
70 |
71 | ## Results on Datasets
72 |
73 | ### AMI corpus http://groups.inf.ed.ac.uk/ami/corpus/ (development and evaluation set together)
74 | It is important to note that these results are obtained using summed individual head-mounted microphones.
75 | Results are reporting when using oracle number of speakers, collar size 0.25s and without scoring overlapped speech.
76 | Data were upsampled from 8k to 16k and 8k wav data are no longer supported.
77 |
78 | Results can be obtained using similar command
79 | ```bash
80 | python diarization.py -c ../configs/vbdiar.yml -l lists/AMI_dev-eval.scp --audio-dir wav/AMI/IHM_SUM --vad-dir vad/AMI --out-emb-dir emb/AMI/IHM_SUM --in-rttm-dir rttms/AMI
81 | ```
82 |
83 | | System | DER |
84 | |------------------------------------------------------------------------|-------|
85 | | Oracle number of speakers + x-vectors + mean + LDA + L2 Norm + GPLDA | 6.67 |
86 | | Oracle number of speakers + x-vectors + mean + LDA + L2 Norm | 9.16 |
87 | | x-vectors + mean + LDA + L2 Norm + GPLDA | 15.54 |
88 |
--------------------------------------------------------------------------------
/configs/mfcc.conf:
--------------------------------------------------------------------------------
1 | --sample-frequency=16000
2 | --frame-length=25 # the default is 25
3 | --low-freq=20 # the default.
4 | --high-freq=7700 # the default is zero meaning use the Nyquist (4k in this case).
5 | --num-ceps=23 # higher than the default which is 12.
6 | --snip-edges=false
7 |
--------------------------------------------------------------------------------
/configs/vbdiar.yml:
--------------------------------------------------------------------------------
1 | # defines MFCC configuration, following Kaldi's conventions
2 | MFCC:
3 | config_path: ../configs/mfcc.conf
4 | apply_cmvn_sliding: True
5 | norm_vars: False
6 | center: True
7 | cmn_window: 300
8 |
9 | # defines properties of x-vector extractor, such as neural net, ...
10 | EmbeddingExtractor:
11 | onnx_path: ../models/final.onnx
12 |
13 | PLDA:
14 | path: ../models/gplda
15 |
16 | # specifies possible transformations to embeddings - mean subtraction, LDA, l2 normalization, ...
17 | Transforms:
18 | mean: ../models/mean.npy
19 | lda: ../models/LDA.npy
20 | use_l2_norm: True
21 |
--------------------------------------------------------------------------------
/examples/diarization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import os
9 | import sys
10 | import ctypes
11 | import logging
12 | import argparse
13 | import multiprocessing
14 | import subprocess
15 |
16 | import numpy as np
17 |
18 | from vbdiar.scoring.gplda import GPLDA
19 | from vbdiar.vad import get_vad
20 | from vbdiar.utils import mkdir_p
21 | from vbdiar.utils.utils import Utils
22 | from vbdiar.embeddings.embedding import extract_embeddings
23 | from vbdiar.scoring.diarization import Diarization
24 | from vbdiar.scoring.normalization import Normalization
25 | from vbdiar.kaldi.onnx_xvector_extraction import ONNXXVectorExtraction
26 | from vbdiar.features.segments import get_segments, get_time_from_frames, get_frames_from_time
27 | from vbdiar.kaldi.mfcc_features_extraction import KaldiMFCCFeatureExtraction
28 |
29 |
30 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
31 | logger = logging.getLogger(__name__)
32 |
33 | CDIR = os.path.dirname(os.path.realpath(__file__))
34 |
35 |
36 | def _process_files(dargs):
37 | """
38 |
39 | Args:
40 | dargs:
41 |
42 | Returns:
43 |
44 | """
45 | fns, kwargs = dargs
46 | ret = []
47 | for fn in fns:
48 | ret.append(process_file(file_name=fn, **kwargs))
49 | return ret
50 |
51 |
52 | def process_files(fns, wav_dir, vad_dir, out_dir, features_extractor, embedding_extractor, min_size,
53 | max_size, overlap, tolerance, wav_suffix='.wav', vad_suffix='.lab.gz', n_jobs=1):
54 | """ Process all files from list.
55 |
56 | Args:
57 | fns (list): name of files to process
58 | wav_dir (str): directory with wav files
59 | vad_dir (str): directory with vad files
60 | out_dir (str|None): output directory
61 | features_extractor (Any): intialized object for feature extraction
62 | embedding_extractor (Any): initialized object for embedding extraction
63 | max_size (int): maximal size of window in ms
64 | min_size (int): minimal size of window in ms
65 | overlap (int): size of window overlap in ms
66 | tolerance (int): accept given number of frames as speech even when it is marked as silence
67 | wav_suffix (str): suffix of wav files
68 | vad_suffix (str): suffix of vad files
69 | n_jobs (int): number of jobs to run in parallel
70 |
71 | Returns:
72 | List[EmbeddingSet]
73 | """
74 | kwargs = dict(wav_dir=wav_dir, vad_dir=vad_dir, out_dir=out_dir, features_extractor=features_extractor,
75 | embedding_extractor=embedding_extractor, tolerance=tolerance, min_size=min_size,
76 | max_size=max_size, overlap=overlap, wav_suffix=wav_suffix, vad_suffix=vad_suffix)
77 | if n_jobs == 1:
78 | ret = _process_files((fns, kwargs))
79 | else:
80 | pool = multiprocessing.Pool(n_jobs)
81 | ret = pool.map(_process_files, ((part, kwargs) for part in Utils.partition(fns, n_jobs)))
82 | return [item for sublist in ret for item in sublist]
83 |
84 |
85 | def process_file(wav_dir, vad_dir, out_dir, file_name, features_extractor, embedding_extractor,
86 | min_size, max_size, overlap, tolerance, wav_suffix='.wav', vad_suffix='.lab.gz'):
87 | """ Process single audio file.
88 |
89 | Args:
90 | wav_dir (str): directory with wav files
91 | vad_dir (str): directory with vad files
92 | out_dir (str): output directory
93 | file_name (str): name of the file
94 | features_extractor (Any): intialized object for feature extraction
95 | embedding_extractor (Any): initialized object for embedding extraction
96 | max_size (int): maximal size of window in ms
97 | max_size (int): maximal size of window in ms
98 | overlap (int): size of window overlap in ms
99 | tolerance (int): accept given number of frames as speech even when it is marked as silence
100 | wav_suffix (str): suffix of wav files
101 | vad_suffix (str): suffix of vad files
102 |
103 | Returns:
104 | EmbeddingSet
105 | """
106 | logger.info('Processing file {}.'.format(file_name.split()[0]))
107 | num_speakers = None
108 | if len(file_name.split()) > 1: # number of speakers is defined
109 | file_name, num_speakers = file_name.split()[0], int(file_name.split()[1])
110 |
111 | wav_dir, vad_dir = os.path.abspath(wav_dir), os.path.abspath(vad_dir)
112 | if out_dir:
113 | out_dir = os.path.abspath(out_dir)
114 |
115 | # extract features
116 | features = features_extractor.audio2features(os.path.join(wav_dir, f'{file_name}{wav_suffix}'))
117 |
118 | # load voice activity detection from file
119 | vad, _, _ = get_vad(f'{os.path.join(vad_dir, file_name)}{vad_suffix}', features.shape[0])
120 |
121 | # parse segments and split features
122 | features_dict = {}
123 | for seg in get_segments(vad, max_size, tolerance):
124 | seg_start, seg_end = seg
125 | start, end = get_time_from_frames(seg_start), get_time_from_frames(seg_end)
126 | if start >= overlap:
127 | seg_start = get_frames_from_time(start - overlap)
128 | if seg_start > features.shape[0] - 1 or seg_end > features.shape[0] - 1:
129 | logger.warning(f'Frames not aligned, number of frames {features.shape[0]} and got ending segment {seg_end}')
130 | seg_end = features.shape[0]
131 | features_dict[(start, end)] = features[seg_start:seg_end]
132 |
133 | # extract embedding for each segment
134 | embedding_set = extract_embeddings(features_dict, embedding_extractor)
135 | embedding_set.name = file_name
136 | embedding_set.num_speakers = num_speakers
137 |
138 | # save embeddings if required
139 | if out_dir is not None:
140 | mkdir_p(os.path.join(out_dir, os.path.dirname(file_name)))
141 | embedding_set.save(os.path.join(out_dir, '{}.pkl'.format(file_name)))
142 |
143 | return embedding_set
144 |
145 |
146 | def set_mkl(num_cores=1):
147 | """ Set number of cores for mkl library.
148 |
149 | Args:
150 | num_cores (int): number of cores
151 | """
152 | try:
153 | mkl_rt = ctypes.CDLL('libmkl_rt.so')
154 | mkl_rt.mkl_set_dynamic(ctypes.byref(ctypes.c_int(0)))
155 | mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(num_cores)))
156 | except OSError:
157 | logger.warning('Failed to import libmkl_rt.so, it will not be possible to use mkl backend.')
158 |
159 |
160 | def get_gpu(really=True):
161 | try:
162 | if really:
163 | command = 'nvidia-smi --query-gpu=memory.free,memory.total --format=csv |tail -n+2| ' \
164 | 'awk \'BEGIN{FS=" "}{if ($1/$3 > 0.98) print NR-1}\''
165 | gpu_idx = subprocess.check_output(command, shell=True).rsplit(b'\n')[0].decode('utf-8')
166 | else:
167 | gpu_idx = '-1'
168 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_idx
169 | except subprocess.CalledProcessError:
170 | logger.warning('No GPUs seems to be available.')
171 |
172 |
173 | if __name__ == '__main__':
174 | parser = argparse.ArgumentParser('Extract embeddings used for diarization from audio wav files.')
175 |
176 | # required
177 | parser.add_argument('-l', '--input-list', help='list of input files without suffix',
178 | action='store', required=True)
179 | parser.add_argument('-c', '--configuration', help='input configuration of models',
180 | action='store', required=True)
181 | parser.add_argument('-m', '--mode', required=True, choices=['sre', 'diarization'],
182 | help='mode used - there are two possible modes, classic `diarization` mode which should'
183 | 'segment utterance into speakers and `sre` mode used for speaker recognition, which '
184 | 'runs clustering for N iterations and saves all clusters')
185 |
186 | # not required
187 | parser.add_argument('--audio-dir',
188 | help='directory with audio files in .wav format - 8000Hz, 16bit-s, 1c', required=False)
189 | parser.add_argument('--vad-dir',
190 | help='directory with lab files - Voice/Speech activity detection', required=False)
191 | parser.add_argument('--in-emb-dir',
192 | help='input directory containing embeddings', required=False)
193 | parser.add_argument('--out-emb-dir',
194 | help='output directory for storing embeddings', required=False)
195 | parser.add_argument('--norm-list',
196 | help='list of normalization files without suffix', required=False)
197 | parser.add_argument('--in-rttm-dir',
198 | help='input directory with rttm files', required=False)
199 | parser.add_argument('--out-rttm-dir',
200 | help='output directory for storing rttm files', required=False)
201 | parser.add_argument('--out-clusters-dir', required=False,
202 | help='output directory for storing clusters - only used when mode is `sre`')
203 | parser.add_argument('-wav-suffix',
204 | help='wav file suffix', required=False, default='.wav')
205 | parser.add_argument('-vad-suffix',
206 | help='Voice Activity Detector file suffix', required=False, default='.lab.gz')
207 | parser.add_argument('-rttm-suffix',
208 | help='rttm file suffix', required=False, default='.rttm')
209 | parser.add_argument('--min-window-size', default=1000,
210 | help='minimal window size for embedding clustering in ms', type=int, required=False)
211 | parser.add_argument('--max-window-size', default=2000,
212 | help='maximal window size for extracting embedding in ms', type=int, required=False)
213 | parser.add_argument('--window-overlap',
214 | help='overlap in window in ms', type=int, required=False, default=0)
215 | parser.add_argument('--vad-tolerance', default=0,
216 | help='tolerance critetion for ignoring frames of silence', type=float, required=False)
217 | parser.add_argument('--use-gpu', required=False, default=False, action='store_true',
218 | help='use GPU instead of cpu (onnxruntime-gpu must be installed)')
219 | parser.add_argument('--max-num-speakers',
220 | help='maximal number of speakers', required=False, default=10)
221 |
222 | args = parser.parse_args()
223 |
224 | logger.info(f'Running `{" ".join(sys.argv)}`.')
225 |
226 | set_mkl(1)
227 | get_gpu(args.use_gpu)
228 |
229 | # initialize extractor
230 | config = Utils.read_config(args.configuration)
231 |
232 | config_mfcc = config['MFCC']
233 | config_path = os.path.abspath(config_mfcc['config_path'])
234 | if not os.path.isfile(config_path):
235 | raise ValueError(f'Path to MFCC configuration `{config_path}` not found.')
236 | features_extractor = KaldiMFCCFeatureExtraction(
237 | config_path=config_path, apply_cmvn_sliding=config_mfcc['apply_cmvn_sliding'],
238 | norm_vars=config_mfcc['norm_vars'], center=config_mfcc['center'], cmn_window=config_mfcc['cmn_window'])
239 |
240 | config_embedding_extractor = config['EmbeddingExtractor']
241 | embedding_extractor = ONNXXVectorExtraction(onnx_path=os.path.abspath(config_embedding_extractor['onnx_path']))
242 |
243 | config_transforms = config['Transforms']
244 | mean = config_transforms.get('mean')
245 | lda = config_transforms.get('lda')
246 | if lda is not None:
247 | lda = np.load(lda)
248 | use_l2_norm = config_transforms.get('use_l2_norm')
249 | plda = config.get('PLDA')
250 | if plda is not None:
251 | plda = GPLDA(plda['path'])
252 |
253 | files = [line.rstrip('\n') for line in open(args.input_list)]
254 |
255 | # extract embeddings
256 | if args.in_emb_dir is None:
257 | if args.out_emb_dir is None:
258 | raise ValueError('At least one of `--in-emb-dir` or `--out-emb-dir` must be specified.')
259 | if args.audio_dir is None:
260 | raise ValueError('At least one of `--in-emb-dir` or `--audio-dir` must be specified.')
261 | if args.vad_dir is None:
262 | raise ValueError('`--audio-dir` was specified, `--vad-dir` must be specified too.')
263 | process_files(
264 | fns=files, wav_dir=args.audio_dir, vad_dir=args.vad_dir, out_dir=args.out_emb_dir,
265 | features_extractor=features_extractor, embedding_extractor=embedding_extractor,
266 | min_size=args.min_window_size, max_size=args.max_window_size, overlap=args.window_overlap,
267 | tolerance=args.vad_tolerance, wav_suffix=args.wav_suffix, vad_suffix=args.vad_suffix,
268 | n_jobs=1)
269 | if args.out_emb_dir:
270 | embeddings = args.out_emb_dir
271 | else:
272 | embeddings = args.in_emb_dir
273 |
274 | # initialize normalization
275 | if args.norm_list is not None:
276 | norm = Normalization(norm_list=args.norm_list, audio_dir=args.audio_dir,
277 | in_rttm_dir=args.in_rttm_dir, in_emb_dir=args.in_emb_dir,
278 | out_emb_dir=args.out_emb_dir, min_length=args.min_window_size, plda=plda,
279 | embedding_extractor=embedding_extractor, features_extractor=features_extractor,
280 | wav_suffix=args.wav_suffix, rttm_suffix=args.rttm_suffix, n_jobs=1)
281 | else:
282 | norm = None
283 |
284 | # load transformations if specified
285 | if not norm:
286 | if mean:
287 | mean = np.load(mean)
288 | else:
289 | mean = norm.mean
290 |
291 | # run diarization
292 | diar = Diarization(args.input_list, embeddings, embeddings_mean=mean, lda=lda,
293 | use_l2_norm=use_l2_norm, plda=plda, norm=norm)
294 | result = diar.score_embeddings(args.min_window_size, args.max_num_speakers, args.mode)
295 |
296 | if args.mode == 'diarization':
297 | if args.in_rttm_dir:
298 | diar.evaluate(scores=result, in_rttm_dir=args.in_rttm_dir, collar_size=0.25, evaluate_overlaps=False)
299 |
300 | if args.out_rttm_dir is not None:
301 | diar.dump_rttm(result, args.out_rttm_dir)
302 | else:
303 | if args.out_clusters_dir:
304 | for name in result:
305 | mkdir_p(os.path.join(args.out_clusters_dir, os.path.dirname(name)))
306 | np.save(os.path.join(args.out_clusters_dir, name), result[name])
307 |
--------------------------------------------------------------------------------
/examples/lists/AMI_dev-eval.scp:
--------------------------------------------------------------------------------
1 | EN2002a 4
2 | EN2002b 4
3 | EN2002c 4
4 | EN2002d 4
5 | ES2004a 4
6 | ES2004b 4
7 | ES2004c 4
8 | ES2004d 4
9 | ES2011a 4
10 | ES2011b 4
11 | ES2011c 4
12 | ES2011d 4
13 | IB4001 4
14 | IB4002 4
15 | IB4003 4
16 | IB4004 4
17 | IB4010 4
18 | IB4011 4
19 | IS1008a 4
20 | IS1008b 4
21 | IS1008c 4
22 | IS1008d 4
23 | IS1009a 4
24 | IS1009b 4
25 | IS1009c 4
26 | IS1009d 4
27 | TS3003a 4
28 | TS3003b 4
29 | TS3003c 4
30 | TS3003d 4
31 | TS3004a 4
32 | TS3004b 4
33 | TS3004c 4
34 | TS3004d 4
35 |
--------------------------------------------------------------------------------
/examples/lists/AMI_trn.scp:
--------------------------------------------------------------------------------
1 | EN2001a
2 | EN2001b
3 | EN2001d
4 | EN2001e
5 | EN2003a
6 | EN2004a
7 | EN2005a
8 | EN2006a
9 | EN2006b
10 | EN2009b
11 | EN2009c
12 | EN2009d
13 | ES2002a
14 | ES2002b
15 | ES2002c
16 | ES2002d
17 | ES2003a
18 | ES2003b
19 | ES2003c
20 | ES2003d
21 | ES2005a
22 | ES2005b
23 | ES2005c
24 | ES2005d
25 | ES2006a
26 | ES2006b
27 | ES2006c
28 | ES2006d
29 | ES2007a
30 | ES2007b
31 | ES2007c
32 | ES2007d
33 | ES2008a
34 | ES2008b
35 | ES2008c
36 | ES2008d
37 | ES2009a
38 | ES2009b
39 | ES2009c
40 | ES2009d
41 | ES2010a
42 | ES2010b
43 | ES2010c
44 | ES2010d
45 | ES2012a
46 | ES2012b
47 | ES2012c
48 | ES2012d
49 | ES2013a
50 | ES2013b
51 | ES2013c
52 | ES2013d
53 | ES2014a
54 | ES2014b
55 | ES2014c
56 | ES2014d
57 | ES2015a
58 | ES2015b
59 | ES2015c
60 | ES2015d
61 | ES2016a
62 | ES2016b
63 | ES2016c
64 | ES2016d
65 | IB4005
66 | IN1001
67 | IN1002
68 | IN1005
69 | IN1007
70 | IN1008
71 | IN1009
72 | IN1012
73 | IN1013
74 | IN1014
75 | IN1016
76 | IS1000a
77 | IS1000b
78 | IS1000c
79 | IS1000d
80 | IS1001a
81 | IS1001b
82 | IS1001c
83 | IS1001d
84 | IS1002b
85 | IS1002c
86 | IS1002d
87 | IS1003a
88 | IS1003c
89 | IS1003d
90 | IS1004a
91 | IS1004b
92 | IS1004c
93 | IS1004d
94 | IS1005a
95 | IS1005b
96 | IS1005c
97 | IS1006a
98 | IS1006b
99 | IS1006c
100 | IS1006d
101 | IS1007a
102 | IS1007b
103 | IS1007c
104 | TS3005a
105 | TS3005b
106 | TS3005c
107 | TS3005d
108 | TS3006a
109 | TS3006b
110 | TS3006c
111 | TS3006d
112 | TS3007a
113 | TS3007b
114 | TS3007c
115 | TS3007d
116 | TS3008a
117 | TS3008b
118 | TS3008c
119 | TS3008d
120 | TS3009a
121 | TS3009b
122 | TS3009c
123 | TS3009d
124 | TS3010a
125 | TS3010b
126 | TS3010c
127 | TS3010d
128 | TS3011a
129 | TS3011b
130 | TS3011c
131 | TS3011d
132 | TS3012a
133 | TS3012b
134 | TS3012c
135 | TS3012d
136 |
--------------------------------------------------------------------------------
/examples/lists/list.scp:
--------------------------------------------------------------------------------
1 | fisher-english-p1/fe_03_00001-a
2 | fisher-english-p1/fe_03_00001-b
3 |
--------------------------------------------------------------------------------
/examples/lists/list_spk.scp:
--------------------------------------------------------------------------------
1 | fe_03_00001-a 1
2 | fe_03_00001-b 1
3 |
--------------------------------------------------------------------------------
/examples/run.sh:
--------------------------------------------------------------------------------
1 | python diarization.py -c ../configs/vbdiar.yml \
2 | -l lists/list_spk.scp \
3 | --audio-dir wav/fisher-english-p1 \
4 | --vad-dir vad/fisher-english-p1 \
5 | --mode diarization \
6 | --out-emb-dir embeddings
7 |
--------------------------------------------------------------------------------
/examples/vad/fisher-english-p1/fe_03_00001-a.lab.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/examples/vad/fisher-english-p1/fe_03_00001-a.lab.gz
--------------------------------------------------------------------------------
/examples/vad/fisher-english-p1/fe_03_00001-b.lab.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/examples/vad/fisher-english-p1/fe_03_00001-b.lab.gz
--------------------------------------------------------------------------------
/examples/wav/fisher-english-p1/README:
--------------------------------------------------------------------------------
1 | These files are taken from the Fisher English database.
2 |
3 | The parameters of the files:
4 | 8kHz, 16bits per sample, 1 channel
5 |
6 | For reference, they were generated using BUT's raw files using:
7 | # for f in *.raw; do nf=$(echo $f|sed 's@raw@wav@'); sox -r 8k -s -b 16 -c 1 $f $nf; done
8 |
--------------------------------------------------------------------------------
/examples/wav/fisher-english-p1/fe_03_00001-a.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/examples/wav/fisher-english-p1/fe_03_00001-a.wav
--------------------------------------------------------------------------------
/examples/wav/fisher-english-p1/fe_03_00001-b.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/examples/wav/fisher-english-p1/fe_03_00001-b.wav
--------------------------------------------------------------------------------
/examples/wav/fisher-english-p1/smalltalk0501.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/examples/wav/fisher-english-p1/smalltalk0501.wav
--------------------------------------------------------------------------------
/models/LDA.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/models/LDA.npy
--------------------------------------------------------------------------------
/models/final.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/models/final.onnx
--------------------------------------------------------------------------------
/models/gplda/CB.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/models/gplda/CB.npy
--------------------------------------------------------------------------------
/models/gplda/CW.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/models/gplda/CW.npy
--------------------------------------------------------------------------------
/models/gplda/mu.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/models/gplda/mu.npy
--------------------------------------------------------------------------------
/models/mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/models/mean.npy
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy==1.2.1
2 | scikit-learn==0.20.3
3 | numpy==1.16.2
4 | pyyaml
5 | onnxruntime==0.3.0
6 | spherecluster==0.1.7
7 | pyclustering==0.8.2
8 | kaldiio==2.13.4
9 |
--------------------------------------------------------------------------------
/vbdiar/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
--------------------------------------------------------------------------------
/vbdiar/clustering/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
--------------------------------------------------------------------------------
/vbdiar/clustering/pldakmeans.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import numpy as np
9 |
10 |
11 | class PLDAKMeans(object):
12 | """ KMeans clustering algorithm using PLDA output as distance metric.
13 |
14 | """
15 | def __init__(self, centroids, k, plda, max_iter=10):
16 | """ Class constructor.
17 |
18 | :param centroids: initialization centroids
19 | :type centroids: numpy.array
20 | :param k: number of classes
21 | :type k: int
22 | :param plda: PLDA object
23 | :type plda: PLDA
24 | :param max_iter: maximal number of iterations
25 | :type max_iter: int
26 | """
27 | self.max_iter = max_iter
28 | self.old_labels = []
29 | self.data = None
30 | self.cluster_centers_ = centroids
31 | self.k = k
32 | self.plda = plda
33 |
34 | def fit(self, data):
35 | """ Fit the input data.
36 |
37 | :param data: input data
38 | :type data: numpy.array
39 | :returns: cluster centers
40 | :rtype: numpy.array
41 | """
42 | self.data = data
43 | iterations = 0
44 | while True:
45 | if self.stop(iterations):
46 | break
47 | else:
48 | iterations += 1
49 | return self.cluster_centers_
50 |
51 | def stop(self, iterations):
52 | """ Make the decision if algorithm should stop.
53 |
54 | :param iterations: number of successfull iterations
55 | :type iterations: int
56 | :returns: True if algorithm should stop, False otherwise
57 | :rtype: bool
58 | """
59 | labels = self.labels()
60 | if iterations > self.max_iter or self.old_labels == labels:
61 | return True
62 | else:
63 | self.old_labels = labels
64 | return False
65 |
66 | def labels(self):
67 | """ Predict labels.
68 |
69 | """
70 | scores = self.plda.score(self.data, self.cluster_centers_)
71 | centroids = {}
72 | for ii in range(self.k):
73 | centroids[ii] = []
74 | labels = []
75 | for ii in range(self.data.shape[0]):
76 | c = np.argmax(scores[ii])
77 | labels.append(c)
78 | centroids[c].append(self.data[ii])
79 | for ii in range(self.k):
80 | centroids[ii] = np.array(centroids[ii])
81 | # clustering has strange behaviour
82 | if centroids[ii].ndim == 1:
83 | return self.old_labels
84 | self.cluster_centers_[ii] = np.mean(centroids[ii], axis=0)
85 | return labels
86 |
--------------------------------------------------------------------------------
/vbdiar/embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | from .embedding import Embedding, EmbeddingSet
9 |
--------------------------------------------------------------------------------
/vbdiar/embeddings/embedding.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import os
9 | import pickle
10 | import numpy as np
11 |
12 | from vbdiar.features.segments import get_time_from_frames
13 | from vbdiar.utils import mkdir_p
14 |
15 |
16 | def extract_embeddings(features_dict, embedding_extractor):
17 | """ Extract embeddings from multiple segments.
18 |
19 | Args:
20 | features_dict (Dict): dictionary with segment range as key and features as values
21 | embedding_extractor (Any):
22 |
23 | Returns:
24 | EmbeddingSet: extracted embedding in embedding set
25 | """
26 | embedding_set = EmbeddingSet()
27 | embeddings = embedding_extractor.features2embeddings(features_dict)
28 | for embedding_key in embeddings:
29 | start, end = embedding_key
30 | embedding_set.add(embeddings[embedding_key], window_start=int(float(start)), window_end=int(float(end)))
31 | return embedding_set
32 |
33 |
34 | class Embedding(object):
35 | """ Class for basic i-vector operations.
36 |
37 | """
38 |
39 | def __init__(self):
40 | """ Class constructor.
41 |
42 | """
43 | self.data = None
44 | self.features = None
45 | self.window_start = None
46 | self.window_end = None
47 |
48 |
49 | class EmbeddingSet(object):
50 | """ Class for encapsulating ivectors set.
51 |
52 | """
53 |
54 | def __init__(self):
55 | """ Class constructor.
56 |
57 | """
58 | self.name = None
59 | self.num_speakers = None
60 | self.embeddings = []
61 |
62 | def __iter__(self):
63 | current = 0
64 | while current < len(self.embeddings):
65 | yield self.embeddings[current]
66 | current += 1
67 |
68 | def __getitem__(self, key):
69 | return self.embeddings[key]
70 |
71 | def __setitem__(self, key, value):
72 | self.embeddings[key] = value
73 |
74 | def __len__(self):
75 | return len(self.embeddings)
76 |
77 | def get_all_embeddings(self):
78 | """ Get all ivectors.
79 |
80 | """
81 | a = []
82 | for i in self.embeddings:
83 | a.append(i.data.flatten())
84 | return np.array(a)
85 |
86 | def get_longer_embeddings(self, min_length):
87 | """ Get i-vectors extracted from longer segments than minimal length.
88 |
89 | Args:
90 | min_length (int): minimal length of segment in miliseconds
91 |
92 | Returns:
93 | np.array: i-vectors
94 | """
95 | a = []
96 | for embedding in self.embeddings:
97 | if embedding.window_end - embedding.window_start >= min_length:
98 | a.append(embedding.data.flatten())
99 | return np.array(a)
100 |
101 | def add(self, data, window_start, window_end, features=None):
102 | """ Add embedding to set.
103 |
104 | Args:
105 | data (np.array): embeding data
106 | window_start (int): start of the window [ms]
107 | window_end (int): end of the window [ms]
108 | features (np.array): features from which embedding was extracted
109 | """
110 | i = Embedding()
111 | i.data = data
112 | i.window_start = window_start
113 | i.window_end = window_end
114 | i.features = features
115 | self.__append(i)
116 |
117 | def __append(self, embedding):
118 | """ Append embedding to set of embedding.
119 |
120 | Args:
121 | embedding (Embedding):
122 | """
123 | ii = 0
124 | for vp in self.embeddings:
125 | if vp.window_start > embedding.window_start:
126 | break
127 | ii += 1
128 | self.embeddings.insert(ii, embedding)
129 |
130 | def save(self, path):
131 | """ Save embedding set as pickled file.
132 |
133 | Args:
134 | path (string_types): output path
135 | """
136 | mkdir_p(os.path.dirname(path))
137 | with open(path, 'wb') as f:
138 | pickle.dump(self, f, pickle.HIGHEST_PROTOCOL)
139 |
140 |
141 | if __name__ == "__main__":
142 | import doctest
143 | doctest.testmod()
144 |
--------------------------------------------------------------------------------
/vbdiar/features/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
--------------------------------------------------------------------------------
/vbdiar/features/segments.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import logging
9 |
10 | import numpy as np
11 |
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | RATE = 16000
16 | SOURCERATE = 1250
17 | TARGETRATE = 100000
18 |
19 | ZMEANSOURCE = True
20 | WINDOWSIZE = 250000.0
21 | USEHAMMING = True
22 | PREEMCOEF = 0.97
23 | NUMCHANS = 24
24 | CEPLIFTER = 22
25 | NUMCEPS = 19
26 | ADDDITHER = 1.0
27 | RAWENERGY = True
28 | ENORMALISE = True
29 |
30 | deltawindow = accwindow = 2
31 |
32 | cmvn_lc = 150
33 | cmvn_rc = 150
34 |
35 | fs = 1e7 / SOURCERATE
36 |
37 |
38 | def get_segments(vad, max_size, tolerance):
39 | """ Return clustered speech segments.
40 |
41 | :param vad: list with labels - voice activity detection
42 | :type vad: list
43 | :param max_size: maximal size of window in ms
44 | :type max_size: int
45 | :param tolerance: accept given number of frames as speech even when it is marked as silence
46 | :type tolerance: int
47 | :returns: clustered segments
48 | :rtype: list
49 | """
50 | clusters = get_clusters(vad, tolerance)
51 | segments = []
52 | max_frames = get_frames_from_time(max_size)
53 | for item in clusters.values():
54 | if item[1] - item[0] > max_frames:
55 | for ss in split_segment(item, max_frames):
56 | segments.append(ss)
57 | else:
58 | segments.append(item)
59 | return segments
60 |
61 |
62 | def split_segment(segment, max_size):
63 | """ Split segment to more with adaptive size.
64 |
65 | :param segment: input segment
66 | :type segment: tuple
67 | :param max_size: maximal size of window in ms
68 | :type max_size: int
69 | :returns: splitted segment
70 | :rtype: list
71 | """
72 | size = segment[1] - segment[0]
73 | num_segments = int(np.math.ceil(size / max_size))
74 | size_segment = size / num_segments
75 | for ii in range(num_segments):
76 | yield (int(segment[0] + ii * size_segment), int(segment[0] + (ii + 1) * size_segment))
77 |
78 |
79 | def get_frames_from_time(n):
80 | """ Get number of frames from ms.
81 |
82 | :param n: number of ms
83 | :type n: int
84 | :returns: number of frames
85 | :rtype: int
86 |
87 | >>> get_frames_from_time(25)
88 | 1
89 | >>> get_frames_from_time(35)
90 | 2
91 | """
92 | assert n >= 0, 'Time must be at least equal to 0.'
93 | if n < 25:
94 | return 0
95 | return int(1 + (n - WINDOWSIZE / 10000) / (TARGETRATE / 10000))
96 |
97 |
98 | def get_time_from_frames(n):
99 | """ Get count of ms from number of frames.
100 |
101 | :param n: number of frames
102 | :type n: int
103 | :returns: number of ms
104 | :rtype: int
105 |
106 | >>> get_time_from_frames(1)
107 | 25
108 | >>> get_time_from_frames(2)
109 | 35
110 |
111 | """
112 | return int(n * (TARGETRATE / 10000) - (TARGETRATE / 10000) + (WINDOWSIZE / 10000))
113 |
114 |
115 | def get_clusters(vad, tolerance=10):
116 | """ Cluster speech segments.
117 |
118 | :param vad: list with labels - voice activity detection
119 | :type vad: list
120 | :param tolerance: accept given number of frames as speech even when it is marked as silence
121 | :type tolerance: int
122 | :returns: clustered speech segments
123 | :rtype: dict
124 | """
125 | num_prev = 0
126 | in_tolerance = 0
127 | num_clusters = 0
128 | clusters = {}
129 | for ii, frame in enumerate(vad):
130 | if frame:
131 | num_prev += 1
132 | else:
133 | in_tolerance += 1
134 | if in_tolerance > tolerance:
135 | if num_prev > 0:
136 | clusters[num_clusters] = (ii - num_prev, ii)
137 | num_clusters += 1
138 | num_prev = 0
139 | in_tolerance = 0
140 | if num_prev > 0:
141 | clusters[num_clusters] = (ii - num_prev, ii)
142 | num_clusters += 1
143 | return clusters
144 |
145 |
146 | def split_seq(seq, size):
147 | """ Split up seq in pieces of size.
148 |
149 | Args:
150 | seq:
151 | size:
152 |
153 | Returns:
154 |
155 | """
156 | return [seq[i:i + size] for i in range(0, len(seq), size)]
157 |
--------------------------------------------------------------------------------
/vbdiar/kaldi/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import os
9 |
10 |
11 | KALDI_ROOT_PATH = os.getenv('KALDI_ROOT_PATH', '.')
12 |
13 | bin_path = os.path.join(KALDI_ROOT_PATH, 'src', 'bin')
14 | featbin_path = os.path.join(KALDI_ROOT_PATH, 'src', 'featbin')
15 | nnet3bin_path = os.path.join(KALDI_ROOT_PATH, 'src', 'nnet3bin')
16 |
17 |
--------------------------------------------------------------------------------
/vbdiar/kaldi/kaldi_xvector_extraction.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import os
9 | import logging
10 | import tempfile
11 | import subprocess
12 |
13 | from vbdiar.kaldi import nnet3bin_path
14 | from vbdiar.kaldi.utils import write_txt_matrix, read_txt_vectors
15 |
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class KaldiXVectorExtraction(object):
21 |
22 | def __init__(self, nnet, binary_path=nnet3bin_path, use_gpu=False,
23 | min_chunk_size=25, chunk_size=10000, cache_capacity=64):
24 | """ Initialize Kaldi x-vector extractor.
25 |
26 | Args:
27 | nnet (string_types): path to neural net
28 | use_gpu (bool):
29 | min_chunk_size (int):
30 | chunk_size (int):
31 | cache_capacity (int):
32 | """
33 | self.nnet3_xvector_compute = os.path.join(binary_path, 'nnet3-xvector-compute')
34 | if not os.path.exists(self.nnet3_xvector_compute):
35 | raise ValueError(
36 | 'Path to nnet3-xvector-compute - `{}` does not exists.'.format(self.nnet3_xvector_compute))
37 | self.nnet3_copy = os.path.join(binary_path, 'nnet3-copy')
38 | if not os.path.exists(self.nnet3_copy):
39 | raise ValueError(
40 | 'Path to nnet3-copy - `{}` does not exists.'.format(self.nnet3_copy))
41 | if not os.path.isfile(nnet):
42 | raise ValueError('Invalid path to nnet `{}`.'.format(nnet))
43 | else:
44 | self.nnet = nnet
45 | self.binary_path = binary_path
46 | self.use_gpu = use_gpu
47 | self.min_chunk_size = min_chunk_size
48 | self.chunk_size = chunk_size
49 | self.cache_capacity = cache_capacity
50 |
51 | def features2embeddings(self, data_dict):
52 | """ Extract x-vector embeddings from feature vectors.
53 |
54 | Args:
55 | data_dict (Dict):
56 |
57 | Returns:
58 |
59 | """
60 | tmp_data_dict = {}
61 | for key in data_dict:
62 | tmp_data_dict[f'{key[0]}_{key[1]}'] = data_dict[key]
63 | with tempfile.NamedTemporaryFile() as xvec_ark, tempfile.NamedTemporaryFile() as mfcc_ark:
64 | write_txt_matrix(path=mfcc_ark.name, data_dict=tmp_data_dict)
65 |
66 | args = [self.nnet3_xvector_compute,
67 | '--use-gpu={}'.format('yes' if self.use_gpu else 'no'),
68 | '--min-chunk-size={}'.format(str(self.min_chunk_size)),
69 | '--chunk-size={}'.format(str(self.chunk_size)),
70 | '--cache-capacity={}'.format(str(self.cache_capacity)),
71 | self.nnet, 'ark,t:{}'.format(mfcc_ark.name), 'ark,t:{}'.format(xvec_ark.name)]
72 |
73 | logger.info('Extracting x-vectors from {} feature vectors to `{}`.'.format(len(tmp_data_dict), xvec_ark.name))
74 | process = subprocess.Popen(
75 | args, stderr=subprocess.PIPE, stdout=subprocess.PIPE, cwd=self.binary_path, shell=False)
76 | _, stderr = process.communicate()
77 | if process.returncode != 0:
78 | raise ValueError('`{}` binary returned error code {}.{}{}'.format(
79 | self.nnet3_xvector_compute, process.returncode, os.linesep, stderr))
80 | tmp_xvec_dict = read_txt_vectors(xvec_ark.name)
81 | xvec_dict = {}
82 | for key in tmp_xvec_dict:
83 | new_key = tuple(key.split('_'))
84 | xvec_dict[new_key] = tmp_xvec_dict[key]
85 | return xvec_dict
86 |
--------------------------------------------------------------------------------
/vbdiar/kaldi/mfcc_features_extraction.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import os
9 | import logging
10 | import tempfile
11 | import subprocess
12 |
13 | import kaldiio
14 |
15 | from vbdiar.kaldi import featbin_path
16 |
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class KaldiMFCCFeatureExtraction(object):
22 |
23 | def __init__(self, config_path, binary_path=featbin_path, apply_cmvn_sliding=True,
24 | norm_vars=False, center=True, cmn_window=300):
25 | """ Initialize Kaldi MFCC extraction component. Names of the arguments keep original Kaldi convention.
26 |
27 | Args:
28 | config_path (string_types): path to config file
29 | binary_path (string_types): path to directory containing binaries
30 | apply_cmvn_sliding (bool): apply cepstral mean and variance normalization
31 | norm_vars (bool): normalize variances
32 | center (bool): center window
33 | cmn_window (int): window size
34 | """
35 | self.binary_path = binary_path
36 | self.config_path = config_path
37 | self.apply_cmvn_sliding = apply_cmvn_sliding
38 | self.norm_vars = norm_vars
39 | self.center = center
40 | self.cmn_window = cmn_window
41 | self.compute_mfcc_feats_bin = os.path.join(binary_path, 'compute-mfcc-feats')
42 | if not os.path.exists(self.compute_mfcc_feats_bin):
43 | raise ValueError('Path to compute-mfcc-feats - {} does not exists.'.format(self.compute_mfcc_feats_bin))
44 | self.copy_feats_bin = os.path.join(binary_path, 'copy-feats')
45 | if not os.path.exists(self.copy_feats_bin):
46 | raise ValueError('Path to copy-feats - {} does not exists.'.format(self.copy_feats_bin))
47 | self.apply_cmvn_sliding_bin = os.path.join(binary_path, 'apply-cmvn-sliding')
48 | if not os.path.exists(self.apply_cmvn_sliding_bin):
49 | raise ValueError('Path to apply-cmvn-sliding - {} does not exists.'.format(self.apply_cmvn_sliding_bin))
50 |
51 | def __str__(self):
52 | return ''.format(self.config_path)
53 |
54 | def audio2features(self, input_path):
55 | """ Extract features from list of files into list of numpy.arrays
56 |
57 | Args:
58 | input_path (string_types): audio file path
59 |
60 | Returns:
61 | Tuple[str, np.array]: path to Kaldi ark file containing features and features itself
62 | """
63 | with tempfile.NamedTemporaryFile(mode='w') as wav_scp, tempfile.NamedTemporaryFile() as mfcc_ark:
64 | # dump list of file to wav.scp file
65 | wav_scp.write('{} {}{}'.format(input_path, input_path, os.linesep))
66 | wav_scp.flush()
67 |
68 | # run fextract
69 | args = [self.compute_mfcc_feats_bin, f'--config={self.config_path}', f'scp:{wav_scp.name}',
70 | f'ark:{mfcc_ark.name if not self.apply_cmvn_sliding else "-"}']
71 | logger.info('Extracting MFCC features from `{}`.'.format(input_path))
72 | compute_mfcc_feats = subprocess.Popen(
73 | args, stderr=subprocess.PIPE, stdout=subprocess.PIPE, cwd=self.binary_path, shell=False)
74 | if not self.apply_cmvn_sliding:
75 | # do not apply cmvn, so just simply compute features
76 | _, stderr = compute_mfcc_feats.communicate()
77 | if compute_mfcc_feats.returncode != 0:
78 | raise ValueError(f'`{self.compute_mfcc_feats_bin}` binary returned error code '
79 | f'{compute_mfcc_feats.returncode}.{os.linesep}{stderr}')
80 | else:
81 | args2 = [self.apply_cmvn_sliding_bin, f'--norm-vars={str(self.norm_vars).lower()}',
82 | f'--center={str(self.center).lower()}', f'--cmn-window={str(self.cmn_window)}',
83 | 'ark:-', f'ark:{mfcc_ark.name}']
84 | apply_cmvn_sliding = subprocess.Popen(args2, stdin=compute_mfcc_feats.stdout,
85 | stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=False)
86 | _, stderr = apply_cmvn_sliding.communicate()
87 | if apply_cmvn_sliding.returncode == 0:
88 | pass
89 | else:
90 | raise ValueError(f'`{self.compute_mfcc_feats_bin}` binary returned error code '
91 | f'{compute_mfcc_feats.returncode}.{os.linesep}{stderr}')
92 | ark = kaldiio.load_ark(mfcc_ark.name)
93 | for key, numpy_array in ark:
94 | return numpy_array
95 |
--------------------------------------------------------------------------------
/vbdiar/kaldi/onnx_xvector_extraction.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import logging
9 | import os
10 |
11 | import numpy as np
12 | import onnxruntime
13 |
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | MIN_SIGNAL_LEN = 25
18 |
19 |
20 | class ONNXXVectorExtraction(object):
21 |
22 | def __init__(self, onnx_path):
23 | """ Initialize ONNX x-vector extractor.
24 |
25 | Args:
26 | onnx_path (str): path to neural net in ONNX format, see https://github.com/onnx/onnx
27 | """
28 | if not os.path.isfile(onnx_path):
29 | raise ValueError(f'Invalid path to nnet `{onnx_path}`.')
30 | else:
31 | self.onnx_path = onnx_path
32 | self.sess = onnxruntime.InferenceSession(onnx_path)
33 | self.input_name = self.sess.get_inputs()[0].name
34 |
35 | def features2embeddings(self, data_dict):
36 | """ Extract x-vector embeddings from feature vectors.
37 |
38 | Args:
39 | data_dict (Dict):
40 |
41 | Returns:
42 |
43 | """
44 | logger.info(f'Extracting x-vectors from {len(data_dict)} segments.')
45 | xvec_dict = {}
46 | for name in data_dict:
47 | signal_len, num_coefs = data_dict[name].shape
48 | # here we need to avoid failing on very short inputs, so we will just concatenate frames in time
49 | if signal_len == 0:
50 | continue
51 | elif signal_len < MIN_SIGNAL_LEN:
52 | for i in range(MIN_SIGNAL_LEN // signal_len):
53 | data_dict[name] = np.concatenate((data_dict[name], data_dict[name]), axis=0)
54 | xvec = self.sess.run(None, {self.input_name: data_dict[name].T[np.newaxis, :, :]})[0]
55 | xvec_dict[name] = xvec.squeeze()
56 | return xvec_dict
57 |
--------------------------------------------------------------------------------
/vbdiar/kaldi/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import os
9 |
10 | import numpy as np
11 |
12 |
13 | def read_txt_matrix(path):
14 | """ Read features file in text format. This code expects correct format of input file.
15 |
16 | Args:
17 | path (string_types): path to txt file
18 |
19 | Returns:
20 | Dict[np.array]: name to array mapping
21 | """
22 | data_dict, name = {}, None
23 | with open(path) as f:
24 | for line in f:
25 | if '[' in line:
26 | name = line.split()[0] if len(line) > 3 else ''
27 | continue
28 | elif ']' in line:
29 | line = line.replace(' ]', '')
30 | assert name is not None, 'Incorrect format of input file `{}`.'.format(path)
31 | if name not in data_dict:
32 | data_dict[name] = []
33 | data_dict[name].append(np.fromstring(line, sep=' ', dtype=np.float32))
34 | for name in data_dict:
35 | data_dict[name] = np.array(data_dict[name])
36 | return data_dict
37 |
38 |
39 | def write_txt_matrix(path, data_dict):
40 | """ Write features into file in text format. This code expects correct format of input dictionary.
41 |
42 | Args:
43 | path (string_types): path to txt file
44 | data_dict (Dict[np.array]): name to array mapping
45 | """
46 | with open(path, 'w') as f:
47 | for name in sorted(data_dict.keys()):
48 | f.write('{} ['.format(name, os.linesep))
49 | for row_idx in range(data_dict[name].shape[0]):
50 | f.write('{} '.format(os.linesep))
51 | data_dict[name][row_idx].tofile(f, sep=' ', format='%.6f')
52 | f.write(' ]{}'.format(os.linesep))
53 |
54 |
55 | def read_txt_vectors(path):
56 | """ Read vectors file in text format. This code expects correct format of input file.
57 |
58 | Args:
59 | path (string_types): path to txt file
60 |
61 | Returns:
62 | Dict[np.array]: name to array mapping
63 | """
64 | data_dict = {}
65 | with open(path) as f:
66 | for line in f:
67 | splitted_line = line.split()
68 | name = splitted_line[0]
69 | end_idx = splitted_line.index(']')
70 | vector_data = np.array([float(single_float) for single_float in splitted_line[2:end_idx]])
71 | data_dict[name] = vector_data
72 | return data_dict
73 |
--------------------------------------------------------------------------------
/vbdiar/scoring/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
--------------------------------------------------------------------------------
/vbdiar/scoring/diarization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import os
9 | import re
10 | import pickle
11 | import logging
12 | from shutil import rmtree
13 | from subprocess import check_output
14 | from tempfile import mkdtemp, NamedTemporaryFile
15 |
16 | import numpy as np
17 | from spherecluster import SphericalKMeans
18 | from pyclustering.cluster.xmeans import xmeans
19 | from sklearn.cluster import KMeans as sklearnKMeans
20 | from sklearn.cluster import AgglomerativeClustering, DBSCAN, MeanShift
21 | from sklearn.metrics.pairwise import cosine_similarity, pairwise_distances
22 |
23 | from vbdiar.clustering.pldakmeans import PLDAKMeans
24 | from vbdiar.scoring.normalization import Normalization
25 | from vbdiar.utils import mkdir_p
26 | from vbdiar.utils.utils import Utils
27 |
28 |
29 | CDIR = os.path.dirname(os.path.realpath(__file__))
30 | MD_EVAL_SCRIPT_PATH = os.path.join(CDIR, 'md-eval.pl')
31 | MAX_SRE_CLUSTERS = 5
32 |
33 | logger = logging.getLogger(__name__)
34 |
35 |
36 | def evaluate2rttms(reference_path, hypothesis_path, collar_size=0.25, evaluate_overlaps=False):
37 | """ Evaluate two rttms.
38 |
39 | Args:
40 | reference_path (string_types):
41 | hypothesis_path (string_types):
42 | collar_size (float):
43 | evaluate_overlaps (bool):
44 |
45 | Returns:
46 | float: diarization error rate
47 | """
48 | args = [MD_EVAL_SCRIPT_PATH, '{}'.format('' if evaluate_overlaps else '-1'),
49 | '-c', str(collar_size), '-r', reference_path, '-s', hypothesis_path]
50 | stdout = check_output(args)
51 |
52 | for line in stdout.decode('utf-8').split(os.linesep):
53 | if ' OVERALL SPEAKER DIARIZATION ERROR = ' in line:
54 | return float(line.replace(
55 | ' OVERALL SPEAKER DIARIZATION ERROR = ', '').replace(
56 | ' percent of scored speaker time `(ALL)', ''))
57 | raise ValueError(f'Command `{" ".join(args)}` failed.')
58 |
59 |
60 | def evaluate_all(reference_dir, hypothesis_dir, names, collar_size=0.25, evaluate_overlaps=False, rttm_ext='.rttm'):
61 | """ Evaluate all rttms in directories specified by list of names.
62 |
63 | Args:
64 | reference_dir (string_types): directory containing reference rttm files
65 | hypothesis_dir (string_types): directory containing hypothesis rttm files
66 | names (List[string_types]): list containing relative names
67 | collar_size (float):
68 | evaluate_overlaps (bool):
69 | rttm_ext (string_types): extension of rttm files
70 |
71 | Returns:
72 | float: diarization error rate
73 | """
74 | with NamedTemporaryFile(mode='w') as ref, NamedTemporaryFile(mode='w') as hyp:
75 | for name in names:
76 | with open(f'{os.path.join(reference_dir, name)}{rttm_ext}') as f:
77 | for line in f:
78 | ref.write(line)
79 | ref.write(os.linesep)
80 | with open(f'{os.path.join(hypothesis_dir, name)}{rttm_ext}') as f:
81 | for line in f:
82 | hyp.write(line)
83 | hyp.write(os.linesep)
84 | ref.flush()
85 | hyp.flush()
86 |
87 | return evaluate2rttms(ref.name, hyp.name, collar_size=collar_size, evaluate_overlaps=evaluate_overlaps)
88 |
89 |
90 | class Diarization(object):
91 | """ Diarization class used as main diarization focused implementation.
92 |
93 | """
94 |
95 | def __init__(self, input_list, embeddings, embeddings_mean=None, lda=None, use_l2_norm=True, norm=None, plda=None):
96 | """ Initialize diarization class.
97 |
98 | Args:
99 | input_list (string_types): path to list of input files
100 | embeddings (string_types|List[EmbeddingSet]): path to directory containing embeddings or list
101 | of EmbeddingSet instances
102 | embeddings_mean (np.ndarray):
103 | lda (np.ndarray): linear discriminant analysis - dimensionality reduction
104 | use_l2_norm (bool): do l2 normalization
105 | norm (Normalization): instance of class Normalization
106 | plda (GPLDA): instance of class GPLDA
107 | """
108 | self.input_list = input_list
109 | if isinstance(embeddings, str):
110 | self.embeddings_dir = embeddings
111 | self.embeddings = list(self.load_embeddings())
112 | else:
113 | self.embeddings = embeddings
114 | self.lda = lda
115 | self.use_l2_norm = use_l2_norm
116 | self.norm = norm
117 | self.plda = plda
118 |
119 | for embedding_set in self.embeddings:
120 | for embedding in embedding_set:
121 | if embeddings_mean is not None:
122 | embedding.data = embedding.data - embeddings_mean
123 | if lda is not None:
124 | embedding.data = embedding.data.dot(lda)
125 | if use_l2_norm:
126 | embedding.data = Utils.l2_norm(embedding.data[np.newaxis, :]).flatten()
127 | if self.norm:
128 | assert embeddings_mean is not None, 'Expecting usage of mean from normalization set.'
129 | self.norm.embeddings = self.norm.embeddings - embeddings_mean
130 | if lda is not None:
131 | self.norm.embeddings = self.norm.embeddings.dot(lda)
132 | if use_l2_norm:
133 | self.norm.embeddings = Utils.l2_norm(self.norm.embeddings)
134 |
135 | def get_embedding(self, name):
136 | """ Get embedding set by name.
137 |
138 | Args:
139 | name (string_types):
140 |
141 | Returns:
142 | EmbeddingSet:
143 | """
144 | for ii in self.embeddings:
145 | if name == ii.name:
146 | return ii
147 | raise ValueError(f'Name of the set not found - `{name}`.')
148 |
149 | def load_embeddings(self):
150 | """ Load embedding from pickled files.
151 |
152 | Returns:
153 | List[EmbeddingSet]:
154 | """
155 | logger.info(f'Loading pickled evaluation embedding from `{self.embeddings_dir}`.')
156 | with open(self.input_list, 'r') as f:
157 | for line in f:
158 | if len(line) > 0:
159 | logger.info(f'Loading evaluation pickle file `{line.rstrip().split()[0]}`.')
160 | line = line.rstrip()
161 | try:
162 | if len(line.split()) == 1:
163 | with open(os.path.join(self.embeddings_dir, line + '.pkl'), 'rb') as i:
164 | yield pickle.load(i)
165 | elif len(line.split()) == 2:
166 | file_name = line.split()[0]
167 | num_spks = int(line.split()[1])
168 | with open(os.path.join(self.embeddings_dir, file_name + '.pkl'), 'rb') as i:
169 | ivec_set = pickle.load(i)
170 | ivec_set.num_speakers = num_spks
171 | yield ivec_set
172 | else:
173 | raise ValueError(f'Unexpected number of columns in input list `{self.input_list}`.')
174 | except IOError:
175 | logger.warning(f'No pickle file found for `{line.rstrip().split()[0]}`'
176 | f' in `{self.embeddings_dir}`.')
177 |
178 | def score_embeddings(self, min_length, max_num_speakers, mode):
179 | """ Score embeddings.
180 |
181 | Args:
182 | min_length (int): minimal length of segment used for clustering in miliseconds
183 | max_num_speakers (int): maximal number of speakers
184 | mode (str): running mode, see examples/diarization.py for details
185 |
186 | Returns:
187 | dict: dictionary with scores for each file
188 | """
189 | result_dict = {}
190 | logger.info('Scoring using `{}`.'.format('PLDA' if self.plda is not None else 'cosine distance'))
191 | for embedding_set in self.embeddings:
192 | name = os.path.normpath(embedding_set.name)
193 | embeddings_all = embedding_set.get_all_embeddings()
194 | embeddings_long = embedding_set.get_longer_embeddings(min_length)
195 | if len(embeddings_long) == 0:
196 | logger.warning(
197 | f'No embeddings found longer than {min_length} for embedding set `{name}`.')
198 | continue
199 | size = len(embedding_set)
200 | if size > 0:
201 | logger.info(f'Clustering `{name}` using {len(embeddings_long)} long embeddings.')
202 | if mode == 'diarization':
203 | if embedding_set.num_speakers is not None:
204 | num_speakers = embedding_set.num_speakers
205 | else:
206 | xm = xmeans(embeddings_long, kmax=max_num_speakers)
207 | xm.process()
208 | num_speakers = len(xm.get_clusters())
209 |
210 | centroids = self.run_clustering(num_speakers, embeddings_long)
211 | if self.norm is None:
212 | if self.plda is None:
213 | result_dict[name] = cosine_similarity(embeddings_all, centroids).T
214 | else:
215 | result_dict[name] = self.plda.score(embeddings_all, centroids)
216 | else:
217 | result_dict[name] = self.norm.s_norm(embeddings_all, centroids)
218 | else:
219 | clusters = []
220 | for k in range(1, MAX_SRE_CLUSTERS + 1):
221 | if size >= k:
222 | centroids = self.run_clustering(k, embeddings_long)
223 | clusters.extend(x for x in centroids)
224 | result_dict[name] = np.array(clusters)
225 | else:
226 | logger.warning(f'No embeddings to score in `{embedding_set.name}`.')
227 | return result_dict
228 |
229 | def run_ahc(self, n_clusters, embeddings, scores_matrix):
230 | """ Run agglomerative hierarchical clustering.
231 |
232 | Returns:
233 | np.array: means of clusters
234 | """
235 | scores_matrix = -((scores_matrix - np.min(scores_matrix)) / (np.max(scores_matrix) - np.min(scores_matrix)))
236 | ahc = AgglomerativeClustering(affinity='precomputed', linkage='complete', n_clusters=n_clusters)
237 | labels = ahc.fit_predict(scores_matrix)
238 | return np.array([np.mean(embeddings[np.where(labels == i)], axis=0) for i in range(n_clusters)])
239 |
240 | def run_clustering(self, num_speakers, embeddings):
241 | if self.use_l2_norm:
242 | kmeans_clustering = SphericalKMeans(
243 | n_clusters=num_speakers, n_init=100, n_jobs=1).fit(embeddings)
244 | else:
245 | kmeans_clustering = sklearnKMeans(
246 | n_clusters=num_speakers, n_init=100, n_jobs=1).fit(embeddings)
247 | centroids = kmeans_clustering.cluster_centers_
248 | if self.plda:
249 | centroids = PLDAKMeans(centroids=kmeans_clustering.cluster_centers_, k=num_speakers,
250 | plda=self.plda, max_iter=100).fit(embeddings)
251 | return centroids
252 |
253 | def dump_rttm(self, scores, out_dir):
254 | """ Dump rttm files to output directory. This function requires initialized embeddings.
255 |
256 | Args:
257 | scores (Dict): dictionary containing scores
258 | out_dir (string_types): path to output directory
259 | """
260 | for embedding_set in self.embeddings:
261 | if len(embedding_set) > 0:
262 | name = embedding_set.name
263 | reg_name = re.sub('/.*', '', embedding_set.name)
264 | mkdir_p(os.path.join(out_dir, os.path.dirname(name)))
265 | with open(os.path.join(out_dir, name + '.rttm'), 'w') as f:
266 | for i, embedding in enumerate(embedding_set.embeddings):
267 | start, end = embedding.window_start, embedding.window_end
268 | idx = np.argmax(scores[name][i])
269 | f.write(f'SPEAKER {reg_name} 1 {float(start / 1000.0)} {float((end - start) / 1000.0)} '
270 | f' {reg_name}_spkr_{idx} \n')
271 | else:
272 | logger.warning(f'No embedding to dump in {embedding_set.name}.')
273 |
274 | def evaluate(self, scores, in_rttm_dir, collar_size=0.25, evaluate_overlaps=False, rttm_ext='.rttm'):
275 | """ At first, separately evaluate each file based on ground truth segmentation. Then evaluate all files.
276 |
277 | Args:
278 | scores (dict): dictionary containing scores
279 | in_rttm_dir (string_types): input directory with rttm files
280 | collar_size (float): collar size for scoring
281 | evaluate_overlaps (bool): evaluate or ignore overlapping speech segments
282 | rttm_ext (string_types): extension for rttm files
283 | """
284 | tmp_dir = mkdtemp(prefix='rttm_')
285 | self.dump_rttm(scores, tmp_dir)
286 | for embedding_set in self.embeddings:
287 | name = embedding_set.name
288 | ground_truth_rttm = os.path.join(in_rttm_dir, '{}{}'.format(name, rttm_ext))
289 | if not os.path.exists(ground_truth_rttm):
290 | logger.warning(f'Ground truth rttm file not found in `{ground_truth_rttm}`.')
291 | continue
292 | # evaluate single rttm
293 | der = evaluate2rttms(ground_truth_rttm, os.path.join(tmp_dir, '{}{}'.format(name, rttm_ext)),
294 | collar_size=collar_size, evaluate_overlaps=evaluate_overlaps)
295 | logger.info(f'`{name}` DER={der}')
296 |
297 | # evaluate all rttms
298 | der = evaluate_all(reference_dir=in_rttm_dir, hypothesis_dir=tmp_dir, names=scores.keys(),
299 | collar_size=collar_size, evaluate_overlaps=evaluate_overlaps, rttm_ext=rttm_ext)
300 | logger.info(f'`Total` DER={der}')
301 | rmtree(tmp_dir)
302 |
--------------------------------------------------------------------------------
/vbdiar/scoring/gplda.py:
--------------------------------------------------------------------------------
1 | import os
2 | import operator
3 | from functools import reduce
4 |
5 | import numpy as np
6 | from numpy.linalg import inv, slogdet
7 |
8 |
9 | class GPLDA(object):
10 | """ Gaussian PLDA model.
11 |
12 |
13 | """
14 | cw = None
15 | cb = None
16 | mean = None
17 | p = None
18 | q = None
19 | k = None
20 | r = None
21 | s = None
22 | t = None
23 | u = None
24 | ct = None
25 | count = None
26 | initialized = False
27 |
28 | def __init__(self, path):
29 | """ Init all parameters needed for Gaussian PLDA Model score computation.
30 |
31 | Args:
32 | path(str): path with Gaussian PLDA model
33 | """
34 | self.cw = np.load(os.path.join(path, 'CW.npy'))
35 | self.cb = np.load(os.path.join(path, 'CB.npy'))
36 | self.mean = np.load(os.path.join(path, 'mu.npy'))
37 |
38 | self.initialize()
39 |
40 | def initialize(self):
41 | """ Initialize members for faster scoring. """
42 | self.ct = self.cw + self.cb
43 | self.p = inv(self.ct * 0.5) - inv(0.5 * self.cw + self.cb)
44 | self.q = inv(2 * self.cw) - inv(2 * self.ct)
45 | k1 = reduce(operator.mul, slogdet(0.5 * self.ct))
46 | k2 = reduce(operator.mul, slogdet(0.5 * self.cw + self.cb))
47 | k3 = reduce(operator.mul, slogdet(2 * self.ct))
48 | k4 = reduce(operator.mul, slogdet(2 * self.cw))
49 | self.k = 0.5 * (k1 - k2 + k3 - k4)
50 | self.r = 0.5 * (0.25 * self.p - self.q)
51 | self.s = 0.5 * (0.25 * self.p + self.q)
52 | self.t = 0.25 * np.dot(self.p, self.mean.T)
53 | u1 = 2 * np.dot(self.mean, 0.25 * self.p)
54 | self.u = self.k + np.dot(u1, self.mean.T)
55 | self.initialized = True
56 |
57 | def score(self, np_vec_1, np_vec_2):
58 | """ Compare two vectors using plda scoring metric. Function is symmetric.
59 |
60 | Args:
61 | np_vec_1 (np.array): array of vectors (e.g. nx250), depends on model
62 | np_vec_2 (np.array): array of vectors (e.g. nx250), depends on model
63 |
64 | Returns:
65 | 2-dimensional np.array: scores matrix
66 | """
67 | if not self.initialized:
68 | raise ValueError('Model is not trained nor initialized.')
69 | np_vec_1 = np_vec_1.T.copy()
70 | np_vec_2 = np_vec_2.T.copy()
71 | mat1 = np.dot(self.r, np_vec_1) * np_vec_1
72 | mat2 = np.dot(self.r, np_vec_2) * np_vec_2
73 | vct1 = np.sum(mat1, axis=0, keepdims=True)
74 | vct2 = np.sum(mat2, axis=0, keepdims=True)
75 | vct3 = 2 * np.dot(self.t.T, np_vec_1)
76 | vct4 = 2 * np.dot(self.t.T, np_vec_2)
77 | mat3 = np.dot(np_vec_1.T, self.s)
78 | scores_matrix = 2 * np.dot(mat3, np_vec_2) + vct1.T + vct2 - vct3.T - vct4 + self.u
79 | return scores_matrix
80 |
--------------------------------------------------------------------------------
/vbdiar/scoring/normalization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import os
9 | import logging
10 | import pickle
11 | import multiprocessing
12 |
13 | import numpy as np
14 | from sklearn.metrics.pairwise import cosine_similarity
15 |
16 | from vbdiar.features.segments import get_frames_from_time
17 | from vbdiar.embeddings.embedding import extract_embeddings
18 | from vbdiar.utils import mkdir_p
19 | from vbdiar.utils.utils import Utils
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | def process_files(fns, speakers_dict, features_extractor, embedding_extractor,
25 | audio_dir, wav_suffix, in_rttm_dir, rttm_suffix, min_length, n_jobs=1):
26 | """
27 |
28 | Args:
29 | fns:
30 | speakers_dict:
31 | features_extractor:
32 | embedding_extractor:
33 | audio_dir:
34 | wav_suffix:
35 | in_rttm_dir:
36 | rttm_suffix:
37 | min_length:
38 | n_jobs:
39 |
40 | Returns:
41 |
42 | """
43 | kwargs = dict(speakers_dict=speakers_dict, features_extractor=features_extractor,
44 | embedding_extractor=embedding_extractor, audio_dir=audio_dir, wav_suffix=wav_suffix,
45 | in_rttm_dir=in_rttm_dir, rttm_suffix=rttm_suffix, min_length=min_length)
46 | if n_jobs == 1:
47 | ret = _process_files((fns, kwargs))
48 | else:
49 | pool = multiprocessing.Pool(n_jobs)
50 | ret = pool.map(_process_files, ((part, kwargs) for part in Utils.partition(fns, n_jobs)))
51 | return ret
52 |
53 |
54 | def _process_files(dargs):
55 | """
56 |
57 | Args:
58 | dargs:
59 |
60 | Returns:
61 |
62 | """
63 | fns, kwargs = dargs
64 | ret = []
65 | for fn in fns:
66 | ret.append(process_file(file_name=fn, **kwargs))
67 | return ret
68 |
69 |
70 | def process_file(file_name, speakers_dict, features_extractor, embedding_extractor,
71 | audio_dir, wav_suffix, in_rttm_dir, rttm_suffix, min_length):
72 | """ Extract embeddings for all defined speakers.
73 |
74 | Args:
75 | file_name (string_types): path to input audio file
76 | speakers_dict (dict): dictionary containing all embedding across speakers
77 | features_extractor (Any):
78 | embedding_extractor (Any):
79 | audio_dir (string_types):
80 | wav_suffix (string_types):
81 | in_rttm_dir (string_types):
82 | rttm_suffix (string_types):
83 | min_length (float):
84 |
85 | Returns:
86 | dict: updated dictionary with speakers
87 | """
88 | logger.info('Processing file `{}`.'.format(file_name.split()[0]))
89 | # extract features from whole audio
90 | features = features_extractor.audio2features(os.path.join(audio_dir, '{}{}'.format(file_name, wav_suffix)))
91 |
92 | # process utterances of the speakers
93 | features_dict = {}
94 | with open(f'{os.path.join(in_rttm_dir, file_name)}{rttm_suffix}') as f:
95 | for line in f:
96 | start_time, dur = int(float(line.split()[3]) * 1000), int(float(line.split()[4]) * 1000)
97 | speaker = line.split()[7]
98 | if dur > min_length:
99 | end_time = start_time + dur
100 | start, end = get_frames_from_time(int(start_time)), get_frames_from_time(int(end_time))
101 | if speaker not in features_dict:
102 | features_dict[speaker] = {}
103 |
104 | assert 0 <= start < end, \
105 | f'Incorrect timing for extracting features, start: {start}, size: {features.shape[0]}, end: {end}.'
106 | if end >= features.shape[0]:
107 | end = features.shape[0] - 1
108 | features_dict[speaker][(start_time, end_time)] = features[start:end]
109 | for speaker in features_dict:
110 | embedding_set = extract_embeddings(features_dict[speaker], embedding_extractor)
111 | embeddings_long = embedding_set.get_all_embeddings()
112 | if speaker not in speakers_dict.keys():
113 | speakers_dict[speaker] = embeddings_long
114 | else:
115 | speakers_dict[speaker] = np.concatenate((speakers_dict[speaker], embeddings_long), axis=0)
116 | return speakers_dict
117 |
118 |
119 | class Normalization(object):
120 | """ Speaker normalization S-Norm. """
121 | embeddings = None
122 | in_emb_dir = None
123 |
124 | def __init__(self, norm_list, audio_dir=None, in_rttm_dir=None, in_emb_dir=None,
125 | out_emb_dir=None, min_length=None, features_extractor=None, embedding_extractor=None,
126 | plda=None, wav_suffix='.wav', rttm_suffix='.rttm', n_jobs=1):
127 | """ Initialize normalization object.
128 |
129 | Args:
130 | norm_list (string_types): path to normalization list
131 | audio_dir (string_types|None): path to audio directory
132 | in_rttm_dir (string_types|None): path to directory with rttm files
133 | in_emb_dir (str|None): path to directory with i-vectors
134 | out_emb_dir (str|None): path to directory for storing embeddings
135 | min_length (int): minimal length for extracting embeddings
136 | features_extractor (Any): object for feature extraction
137 | embedding_extractor (Any): object for extracting embedding
138 | plda (PLDA|None): plda model object
139 | wav_suffix (string_types): suffix of wav files
140 | rttm_suffix (string_types): suffix of rttm files
141 | """
142 | if audio_dir:
143 | self.audio_dir = os.path.abspath(audio_dir)
144 | self.norm_list = norm_list
145 | if in_rttm_dir:
146 | self.in_rttm_dir = os.path.abspath(in_rttm_dir)
147 | else:
148 | raise ValueError('It is required to have input rttm files for normalization.')
149 | self.features_extractor = features_extractor
150 | self.embedding_extractor = embedding_extractor
151 | self.plda = plda
152 | self.wav_suffix = wav_suffix
153 | self.rttm_suffix = rttm_suffix
154 | if in_emb_dir:
155 | self.in_emb_dir = os.path.abspath(in_emb_dir)
156 | if out_emb_dir:
157 | self.out_emb_dir = os.path.abspath(out_emb_dir)
158 | self.min_length = min_length
159 | self.n_jobs = n_jobs
160 | if self.in_emb_dir is None:
161 | self.embeddings = self.extract_embeddings()
162 | else:
163 | self.embeddings = self.load_embeddings()
164 | self.mean = np.mean(self.embeddings, axis=0)
165 |
166 | def __iter__(self):
167 | current = 0
168 | while current < len(self.embeddings):
169 | yield self.embeddings[current]
170 | current += 1
171 |
172 | def __getitem__(self, key):
173 | return self.embeddings[key]
174 |
175 | def __setitem__(self, key, value):
176 | self.embeddings[key] = value
177 |
178 | def __len__(self):
179 | return len(self.embeddings)
180 |
181 | def extract_embeddings(self):
182 | """ Extract normalization embeddings using averaging.
183 |
184 | Returns:
185 | Tuple[np.array, np.array]: vectors for individual speakers, global mean over all speakers
186 | """
187 | speakers_dict, fns = {}, []
188 | with open(self.norm_list) as f:
189 | for line in f:
190 | if len(line.split()) > 1: # number of speakers is defined
191 | line = line.split()[0]
192 | else:
193 | line = line.replace(os.linesep, '')
194 | fns.append(line)
195 |
196 | speakers_dict = process_files(fns, speakers_dict=speakers_dict, features_extractor=self.features_extractor,
197 | embedding_extractor=self.embedding_extractor, audio_dir=self.audio_dir,
198 | wav_suffix=self.wav_suffix, in_rttm_dir=self.in_rttm_dir,
199 | rttm_suffix=self.rttm_suffix, min_length=self.min_length, n_jobs=self.n_jobs)
200 | assert len(speakers_dict) == len(fns)
201 | # all are the same
202 | merged_speakers_dict = speakers_dict[0]
203 |
204 | if self.out_emb_dir:
205 | for speaker in merged_speakers_dict:
206 | out_path = os.path.join(self.out_emb_dir, f'{speaker}.pkl')
207 | mkdir_p(os.path.dirname(out_path))
208 | with open(out_path, 'wb') as f:
209 | pickle.dump(merged_speakers_dict[speaker], f, pickle.HIGHEST_PROTOCOL)
210 |
211 | for speaker in merged_speakers_dict:
212 | merged_speakers_dict[speaker] = np.mean(merged_speakers_dict[speaker], axis=0)
213 |
214 | return np.array(list(merged_speakers_dict.values()))
215 |
216 | def load_embeddings(self):
217 | """ Load normalization embeddings from pickle files.
218 |
219 | Returns:
220 | np.array: embeddings per speaker
221 | """
222 | embeddings, speakers = [], set()
223 | with open(self.norm_list) as f:
224 | for file_name in f:
225 | if len(file_name.split()) > 1: # number of speakers is defined
226 | file_name = file_name.split()[0]
227 | else:
228 | file_name = file_name.replace(os.linesep, '')
229 | with open('{}{}'.format(os.path.join(self.in_rttm_dir, file_name), self.rttm_suffix)) as fp:
230 | for line in fp:
231 | speakers.add(line.split()[7])
232 |
233 | logger.info('Loading pickled normalization embeddings from `{}`.'.format(self.in_emb_dir))
234 | for speaker in speakers:
235 | embedding_path = os.path.join(self.in_emb_dir, '{}.pkl'.format(speaker))
236 | if os.path.isfile(embedding_path):
237 | logger.info('Loading normalization pickle file `{}`.'.format(speaker))
238 | with open(embedding_path, 'rb') as f:
239 | # append mean from speaker's embeddings
240 | speaker_embeddings = pickle.load(f)
241 | embeddings.append(np.mean(speaker_embeddings, axis=0))
242 | else:
243 | logger.warning('No pickle file found for `{}` in `{}`.'.format(speaker, self.in_emb_dir))
244 | return np.array(embeddings)
245 |
246 | def s_norm(self, test, enroll):
247 | """ Run speaker normalization (S-Norm) on cached embeddings.
248 |
249 | Args:
250 | test (np.array): test embedding
251 | enroll (np.array): enroll embedding
252 |
253 | Returns:
254 | float: hypothesis
255 | """
256 | if self.plda:
257 | a = self.plda.score(test, self.embeddings).T
258 | b = self.plda.score(enroll, self.embeddings).T
259 | c = self.plda.score(enroll, test).T
260 | else:
261 | a = cosine_similarity(test, self.embeddings).T
262 | b = cosine_similarity(enroll, self.embeddings).T
263 | c = cosine_similarity(enroll, test).T
264 | scores = []
265 | for ii in range(test.shape[0]):
266 | test_scores = []
267 | for jj in range(enroll.shape[0]):
268 | test_mean, test_std = np.mean(a.T[ii]), np.std(a.T[ii])
269 | enroll_mean, enroll_std = np.mean(b.T[jj]), np.std(b.T[jj])
270 | s = c[ii][jj]
271 | test_scores.append((((s - test_mean) / test_std + (s - enroll_mean) / enroll_std) / 2))
272 | scores.append(test_scores)
273 | return np.array(scores)
274 |
--------------------------------------------------------------------------------
/vbdiar/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import os
9 | import errno
10 |
11 |
12 | def mkdir_p(path):
13 | """ Behaviour similar to mkdir -p in shell.
14 |
15 | Args:
16 | path (string_types): path to create
17 | """
18 | try:
19 | os.makedirs(path)
20 | except OSError as exc: # Python >2.5
21 | if exc.errno == errno.EEXIST and os.path.isdir(path):
22 | pass
23 | else:
24 | raise ValueError('Can not create directory {}.'.format(path))
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/vbdiar/utils/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import os
9 | import re
10 | import random
11 | from os import listdir
12 | from os.path import isfile, join
13 | import fnmatch
14 | import math
15 |
16 | import numpy as np
17 | import yaml
18 |
19 |
20 | class Utils(object):
21 | """ Class tools handles basic operations with files and directories.
22 |
23 | """
24 |
25 | def __init__(self):
26 | """ tools class constructor.
27 |
28 | """
29 | return
30 |
31 | @staticmethod
32 | def list_directory_by_suffix(directory, suffix):
33 | """ Return listed directory of files based on their suffix.
34 |
35 | :param directory: directory to be listed
36 | :type directory: str
37 | :param suffix: suffix of files in directory
38 | :type suffix: str
39 | :returns: list of files specified byt suffix in directory
40 | :rtype: list
41 |
42 | >>> Utils.list_directory_by_suffix('../../tests/tools', '.test')
43 | ['empty1.test', 'empty2.test']
44 | >>> Utils.list_directory_by_suffix('../../tests/tools_no_ex', '.test')
45 | Traceback (most recent call last):
46 | ...
47 | toolsException: [listDirectoryBySuffix] No directory found!
48 | >>> Utils.list_directory_by_suffix('../../tests/tools', '.py')
49 | []
50 | """
51 | abs_dir = os.path.abspath(directory)
52 | try:
53 | ofiles = [f for f in listdir(abs_dir) if isfile(join(abs_dir, f))]
54 | except OSError:
55 | raise ValueError('No directory named {} found!'.format(directory))
56 | out = []
57 | for file_in in ofiles:
58 | if file_in.find(suffix) != -1:
59 | out.append(file_in)
60 | out.sort()
61 | return out
62 |
63 | @staticmethod
64 | def list_directory(directory):
65 | """ List directory.
66 |
67 | :param directory: directory to be listed
68 | :type directory: str
69 | :returns: list with files in directory
70 | :rtype: list
71 |
72 | >>> Utils.list_directory('../../tests/tools')
73 | ['empty1.test', 'empty2.test', 'test', 'test.txt']
74 | >>> Utils.list_directory('../../tests/tools_no_ex')
75 | Traceback (most recent call last):
76 | ...
77 | toolsException: [listDirectory] No directory found!
78 | """
79 | directory = os.path.abspath(directory)
80 | try:
81 | out = [f for f in listdir(directory)]
82 | except OSError:
83 | raise ValueError('No directory found!')
84 | out.sort()
85 | return out
86 |
87 | @staticmethod
88 | def recursively_list_directory_by_suffix(directory, suffix):
89 | """ Return recursively listed directory of files based on their suffix.
90 |
91 | :param directory: directory to be listed
92 | :type directory: str
93 | :param suffix: suffix of files in directory
94 | :type suffix: str
95 | :returns: list of files specified by suffix in directory
96 | :rtype: list
97 |
98 | >>> Utils.recursively_list_directory_by_suffix( \
99 | '../../tests/tools', '.test')
100 | ['empty1.test', 'empty2.test', 'test/empty.test']
101 | >>> Utils.recursively_list_directory_by_suffix( \
102 | '../../tests/tools_no_ex', '.test')
103 | []
104 | """
105 | matches = []
106 | for root, dirnames, filenames in os.walk(directory):
107 | for filename in fnmatch.filter(filenames, '*' + suffix):
108 | app = os.path.join(root, filename).replace(directory + '/', '')
109 | matches.append(app)
110 | matches.sort()
111 | return matches
112 |
113 | @staticmethod
114 | def sed_in_file(input_file, regex1, regex2):
115 | """ Replace in input file by regex.
116 |
117 | :param input_file: input file
118 | :type input_file: str
119 | :param regex1: regular expression 1
120 | :type regex1: str
121 | :param regex2: regular expression 2
122 | :type regex2: str
123 | """
124 | with open(input_file, 'r') as sources:
125 | lines = sources.readlines()
126 | with open(input_file, 'w') as sources:
127 | for line in lines:
128 | sources.write(re.sub(regex1, regex2, line))
129 |
130 | @staticmethod
131 | def remove_lines_in_file_by_indexes(input_file, lines_indexes):
132 | """ Remove specified lines in file.
133 |
134 | :param input_file: input file name
135 | :type input_file: str
136 | :param lines_indexes: list with lines
137 | :type lines_indexes: list
138 | """
139 | with open(input_file, 'r') as sources:
140 | lines = sources.readlines()
141 | with open(input_file, 'w') as sources:
142 | for i in range(len(lines)):
143 | if i not in lines_indexes:
144 | sources.write(lines[i])
145 |
146 | @staticmethod
147 | def get_method(instance, method):
148 | """ Get method pointer.
149 |
150 | :param instance: input object
151 | :type instance: object
152 | :param method: name of method
153 | :type method: str
154 | :returns: pointer to method
155 | :rtype: method
156 | """
157 | try:
158 | attr = getattr(instance, method)
159 | except AttributeError:
160 | raise ValueError('Unknown class method!')
161 | return attr
162 |
163 | @staticmethod
164 | def configure_instance(instance, input_list):
165 | """ Configures instance base on methods list.
166 |
167 | :param instance: reference to class instance
168 | :type instance: object
169 | :param input_list: input list with name of class members
170 | :type input_list: list
171 | :returns: configured instance
172 | :rtype: object
173 | """
174 | for line in input_list:
175 | variable = line[:line.rfind('=')]
176 | value = line[line.rfind('=') + 1:]
177 | method_callback = Utils.get_method(instance, 'Set' + variable)
178 | method_callback(value)
179 | return instance
180 |
181 | @staticmethod
182 | def sort(scores, col=None):
183 | """ Sort scores list where score is in n-th-1 column.
184 |
185 | :param scores: scores list to be sorted
186 | :type scores: list
187 | :param col: index of column
188 | :type col: int
189 | :returns: sorted scores list
190 | :rtype: list
191 |
192 | >>> Utils.sort([['f1', 'f2', 10.0], \
193 | ['f3', 'f4', -10.0], \
194 | ['f5', 'f6', 9.58]], col=2)
195 | [['f3', 'f4', -10.0], ['f5', 'f6', 9.58], ['f1', 'f2', 10.0]]
196 | >>> Utils.sort([4.59, 8.8, 6.9, -10001.478])
197 | [-10001.478, 4.59, 6.9, 8.8]
198 | """
199 | if col is None:
200 | return sorted(scores, key=float)
201 | else:
202 | return sorted(scores, key=lambda x: x[col])
203 |
204 | @staticmethod
205 | def reverse_sort(scores, col=None):
206 | """ Reversively sort scores list where score is in n-th column.
207 |
208 | :param scores: scores list to be sorted
209 | :type scores: list
210 | :param col: number of columns
211 | :type col: int
212 | :returns: reversively sorted scores list
213 | :rtype: list
214 |
215 | >>> Utils.reverse_sort([['f1', 'f2', 10.0], \
216 | ['f3', 'f4', -10.0], \
217 | ['f5', 'f6', 9.58]], col=2)
218 | [['f1', 'f2', 10.0], ['f5', 'f6', 9.58], ['f3', 'f4', -10.0]]
219 | >>> Utils.reverse_sort([4.59, 8.8, 6.9, -10001.478])
220 | [8.8, 6.9, 4.59, -10001.478]
221 | """
222 | if col is None:
223 | return sorted(scores, key=float, reverse=True)
224 | else:
225 | return sorted(scores, key=lambda x: x[col], reverse=True)
226 |
227 | @staticmethod
228 | def get_nth_col(in_list, col):
229 | """ Extract n-th-1 columns from list.
230 |
231 | :param in_list: input list
232 | :type in_list: list
233 | :param col: column
234 | :type col: int
235 | :returns: list only with one column
236 | :rtype: list
237 |
238 | >>> Utils.get_nth_col([['1', '2'], ['3', '4'], ['5', '6']], col=1)
239 | ['2', '4', '6']
240 | >>> Utils.get_nth_col([['1', '2'], ['3', '4'], ['5', '6']], col=42)
241 | Traceback (most recent call last):
242 | ...
243 | toolsException: [getNthCol] Column out of range!
244 | """
245 | try:
246 | out = [row[col] for row in in_list]
247 | except IndexError:
248 | raise ValueError('Column out of range!')
249 | return out
250 |
251 | @staticmethod
252 | def find_in_dictionary(in_dict, value):
253 | """ Find value in directory whose items are lists and return key.
254 |
255 | :param in_dict: dictionary to search in
256 | :type in_dict: dict
257 | :param value: value to find
258 | :type value: any
259 | :returns: dictionary key
260 | :rtype: any
261 |
262 | >>> Utils.find_in_dictionary({ 0 : [42], 1 : [88], 2 : [69]}, 69)
263 | 2
264 | >>> Utils.find_in_dictionary(dict(), 69)
265 | Traceback (most recent call last):
266 | ...
267 | toolsException: [findInDictionary] Value not found!
268 | """
269 | for key in in_dict:
270 | if value in in_dict[key]:
271 | return key
272 | raise ValueError('Value not found!')
273 |
274 | @staticmethod
275 | def get_scores(scores, key):
276 | """ Get scores from scores list by key.
277 |
278 | :param scores: input scores list
279 | :type scores: list
280 | :param key: key to find
281 | :type key: list
282 | :returns: score if key is present in score, None otherwise
283 | :rtype: float
284 |
285 | >>> Utils.get_scores([['f1', 'f2', 10.1], ['f3', 'f4', 20.1], \
286 | ['f5', 'f6', 30.1]], ['f6', 'f5'])
287 | 30.1
288 | """
289 | if len(key) != 2:
290 | raise ValueError('Unexpected key!')
291 | if len(scores[0]) != 3:
292 | raise ValueError('Invalid input list!')
293 | for score in scores:
294 | a = score[0]
295 | b = score[1]
296 | if (key[0] == a and key[1] == b) or (key[0] == b and key[1] == a):
297 | return score[2]
298 | return None
299 |
300 | @staticmethod
301 | def get_line_from_file(line_num, infile):
302 | """ Get specified line from file.
303 |
304 | :param line_num: number of line
305 | :type line_num: int
306 | :param infile: file name
307 | :type infile: str
308 | :returns: specified line, None otherwise
309 | :rtype: str
310 |
311 | >>> Utils.get_line_from_file(3, '../../tests/tools/test.txt')
312 | 'c\\n'
313 | >>> Utils.get_line_from_file(10, '../../tests/tools/test.txt')
314 | Traceback (most recent call last):
315 | ...
316 | toolsException: [getLineFromFile] Line number not found!
317 | """
318 | with open(infile) as fp:
319 | for i, line in enumerate(fp):
320 | if i == line_num - 1:
321 | return line
322 | raise ValueError('Line number {} not found in file.'.format(line_num, infile))
323 |
324 | @staticmethod
325 | def list2dict(input_list):
326 | """ Create dictionary from list in format [key1, key2, score].
327 |
328 | :param input_list: list to process
329 | :type input_list: list
330 | :returns: preprocessed dictionary
331 | :rtype: dict
332 |
333 | >>> Utils.list2dict([['f1', 'f2', 10.1], ['f3', 'f4', 20.1], \
334 | ['f5', 'f6', 30.1], ['f1', 'f3', 40.1]])
335 | {'f1 f2': 10.1, 'f5 f6': 30.1, 'f3 f4': 20.1, 'f1 f3': 40.1}
336 | >>> Utils.list2dict([['f1', 'f2', 10.1], ['f3', 'f4']])
337 | Traceback (most recent call last):
338 | ...
339 | toolsException: [list2Dict] Invalid format of input list!
340 | """
341 | dictionary = dict()
342 | for item in input_list:
343 | if len(item) != 3:
344 | raise ValueError('Invalid format of input list!')
345 | tmp_list = [item[0], item[1]]
346 | tmp_list.sort()
347 | dictionary[tmp_list[0] + ' ' + tmp_list[1]] = item[2]
348 | return dictionary
349 |
350 | @staticmethod
351 | def merge_dicts(*dict_args):
352 | """ Merge dictionaries into single one.
353 |
354 | :param dict_args: input dictionaries
355 | :type dict_args: dict array
356 | :returns: merged dictionaries into single one
357 | :rtype: dict
358 |
359 | >>> Utils.merge_dicts( \
360 | {'f1 f2': 10.1, 'f5 f6': 30.1, 'f1 f3': 40.1}, {'f6 f2': 50.1})
361 | {'f1 f2': 10.1, 'f5 f6': 30.1, 'f6 f2': 50.1, 'f1 f3': 40.1}
362 | """
363 | result = {}
364 | for dictionary in dict_args:
365 | result.update(dictionary)
366 | return result
367 |
368 | @staticmethod
369 | def save_object(obj, path):
370 | """ Saves object to disk.
371 |
372 | :param obj: reference to object
373 | :type obj: any
374 | :param path: path to file
375 | :type path: str
376 | """
377 | np.save(path, obj)
378 |
379 | @staticmethod
380 | def load_object(path):
381 | """ Loads object from disk.
382 |
383 | :param path: path to file
384 | :type path: str
385 | """
386 | np.load(path)
387 |
388 | @staticmethod
389 | def common_prefix(m):
390 | """ Given a list of pathnames, returns the longest prefix."
391 |
392 | :param m: input list
393 | :type m: list
394 | :returns: longest prefix in list
395 | :rtype: str
396 | """
397 | if not m:
398 | return ''
399 | s1 = min(m)
400 | s2 = max(m)
401 | for i, c in enumerate(s1):
402 | if c != s2[i]:
403 | return s1[:i]
404 | return s1
405 |
406 | @staticmethod
407 | def root_name(d):
408 | """ Return a root directory by name.
409 |
410 | :param d: directory name
411 | :type d: str
412 | :returns: root directory name
413 | :rtype d: str
414 | """
415 | pass
416 |
417 | @staticmethod
418 | def read_config(config_path):
419 | """ Read config in yaml format.
420 |
421 | Args:
422 | config_path (str): path to config file
423 |
424 | Returns:
425 |
426 | """
427 | with open(config_path, 'r') as ymlfile:
428 | return yaml.load(ymlfile)
429 |
430 | @staticmethod
431 | def l2_norm(ivecs):
432 | """ Perform L2 normalization.
433 |
434 | Args:
435 | ivecs (np.array): input i-vector
436 |
437 | Returns:
438 | np.array: normalized i-vectors
439 | """
440 | ret_ivecs = ivecs.copy()
441 | ret_ivecs /= np.sqrt((ret_ivecs ** 2).sum(axis=1)[:, np.newaxis])
442 | return ret_ivecs
443 |
444 | @staticmethod
445 | def cos_sim(v1, v2):
446 | """
447 |
448 | Args:
449 | v1 (np.array): first vector
450 | v2 (np.array): second vector
451 |
452 | Returns:
453 |
454 | """
455 | sumxx, sumxy, sumyy = 0, 0, 0
456 | for i in range(len(v1)):
457 | x = v1[i]
458 | y = v2[i]
459 | sumxx += x * x
460 | sumyy += y * y
461 | sumxy += x * y
462 | return sumxy / math.sqrt(sumxx * sumyy)
463 |
464 | @staticmethod
465 | def partition(large_list, n_sublists, shuffle=False):
466 | """Partition a list ``l`` into ``n`` sublists."""
467 | return np.array_split(large_list, n_sublists)
468 |
469 |
470 | if __name__ == "__main__":
471 | import doctest
472 | doctest.testmod()
473 |
--------------------------------------------------------------------------------
/vbdiar/vad/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | from .vad import load_vad_lab_as_bool_vec, get_vad
9 |
--------------------------------------------------------------------------------
/vbdiar/vad/vad.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Copyright (C) 2018 Brno University of Technology FIT
5 | # Author: Jan Profant
6 | # All Rights Reserved
7 |
8 | import logging
9 |
10 | import numpy as np
11 |
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def get_vad(file_name, fea_len):
17 | """ Load .lab file as bool vector.
18 |
19 | Args:
20 | file_name (str): path to .lab file
21 | fea_len (int): length of features
22 |
23 | Returns:
24 | np.array: bool vector
25 | """
26 |
27 | logger.info('Loading VAD from file `{}`.'.format(file_name))
28 | return load_vad_lab_as_bool_vec(file_name)[:fea_len]
29 |
30 |
31 | def load_vad_lab_as_bool_vec(lab_file):
32 | """
33 |
34 | Args:
35 | lab_file:
36 |
37 | Returns:
38 |
39 | """
40 | lab_cont = np.atleast_2d(np.loadtxt(lab_file, dtype=object))
41 |
42 | if lab_cont.shape[1] == 0:
43 | return np.empty(0), 0, 0
44 |
45 | if lab_cont.shape[1] == 3:
46 | lab_cont = lab_cont[lab_cont[:, 2] == 'sp', :][:, [0, 1]]
47 |
48 | n_regions = lab_cont.shape[0]
49 | ii = 0
50 | while True:
51 | try:
52 | start1, end1 = float(lab_cont[ii][0]), float(lab_cont[ii][1])
53 | jj = ii + 1
54 | start2, end2 = float(lab_cont[jj][0]), float(lab_cont[jj][1])
55 | if end1 >= start2:
56 | lab_cont = np.delete(lab_cont, ii, axis=0)
57 | ii -= 1
58 | lab_cont[jj - 1][0] = str(start1)
59 | lab_cont[jj - 1][1] = str(max(end1, end2))
60 | ii += 1
61 | except IndexError:
62 | break
63 |
64 | vad = np.round(np.atleast_2d(lab_cont).astype(np.float).T * 100).astype(np.int)
65 | vad[1] += 1 # Paja's bug!!!
66 |
67 | if not vad.size:
68 | return np.empty(0, dtype=bool)
69 |
70 | npc1 = np.c_[np.zeros_like(vad[0], dtype=bool), np.ones_like(vad[0], dtype=bool)]
71 | npc2 = np.c_[vad[0] - np.r_[0, vad[1, :-1]], vad[1] - vad[0]]
72 | npc2[npc2 < 0] = 0
73 |
74 | out = np.repeat(npc1, npc2.flat)
75 |
76 | n_frames = sum(out)
77 |
78 | return out, n_regions, n_frames
79 |
80 |
81 |
--------------------------------------------------------------------------------