├── .gitignore ├── .idea └── encodings.xml ├── LICENSE ├── README.md ├── configs ├── mfcc.conf └── vbdiar.yml ├── examples ├── diarization.py ├── lists │ ├── AMI_dev-eval.scp │ ├── AMI_trn.scp │ ├── list.scp │ └── list_spk.scp ├── run.sh ├── vad │ └── fisher-english-p1 │ │ ├── fe_03_00001-a.lab.gz │ │ └── fe_03_00001-b.lab.gz └── wav │ └── fisher-english-p1 │ ├── README │ ├── fe_03_00001-a.wav │ ├── fe_03_00001-b.wav │ └── smalltalk0501.wav ├── models ├── LDA.npy ├── final.onnx ├── gplda │ ├── CB.npy │ ├── CW.npy │ └── mu.npy └── mean.npy ├── requirements.txt ├── setup.py └── vbdiar ├── __init__.py ├── clustering ├── __init__.py └── pldakmeans.py ├── embeddings ├── __init__.py └── embedding.py ├── features ├── __init__.py └── segments.py ├── kaldi ├── __init__.py ├── kaldi_xvector_extraction.py ├── mfcc_features_extraction.py ├── onnx_xvector_extraction.py └── utils.py ├── scoring ├── __init__.py ├── diarization.py ├── gplda.py ├── md-eval.pl └── normalization.py ├── utils ├── __init__.py └── utils.py └── vad ├── __init__.py └── vad.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # vbdiar 2 | 3 | This project is DEPRECATED and moved to https://github.com/BUTSpeechFIT/VBx, which achieves better results. 4 | 5 | Speaker diarization based on x-vectors using pretrained model trained in Kaldi (https://github.com/kaldi-asr/kaldi) 6 | and converted to ONNX format (https://github.com/onnx/onnx) running in ONNXRuntime (https://github.com/Microsoft/onnxruntime). 7 | 8 | X-vector model was trained using VoxCeleb1 and VoxCeleb2 16k data (http://www.robots.ox.ac.uk/~vgg/data/voxceleb/index.html#about). 9 | 10 | If you make use of the code or model, cite this: https://www.vutbr.cz/en/students/final-thesis/detail/122072 11 | 12 | 13 | ## Dependencies 14 | 15 | Dependencies are listed in `requirements.txt`. 16 | 17 | ## Installation 18 | 19 | It is recommended to use anaconda environment https://www.anaconda.com/download/. 20 | Run `python setup.py install` 21 | Also, since we are using Kaldi, path to Kaldi root must be set in `vbdiar/kaldi/__init__.py` 22 | 23 | ## Configs 24 | 25 | Config file declares used models and paths to them. Example configuration file is `configs/vbdiar.yml`. 26 | 27 | ## Models 28 | 29 | Pretrained models are stored in `models/` directory. 30 | 31 | ## Examples 32 | 33 | Example script `examples/diarization.py` is able to run full diarization process. The code is designed in a way, that you have everything in same tree structure with relative paths in list and then you just specify directories - audio, VAD, output, etc. See example configuration. 34 | 35 | ### Required Arguments 36 | 37 | `'-l', '--input-list'` - specifies relative path to files for testing, it is possible to specify number of speakers as the second column. Do not use file suffixes, path is always relative to input directory and suffix. 38 | 39 | `'-c', '--configuration'` - specifies configuration file/ 40 | 41 | `'-m', '--mode'` - specifies running mode, there are two possible modes, classic `diarization` mode which should segment 42 | utterance into speakers and `sre` mode used for speaker recognition, which runs clustering for N iterations and saves all clusters 43 | 44 | ### Non-required Arguments 45 | 46 | `'--audio-dir'` - directory with audio files in `.wav` format - `8000Hz, 16bit-s, 1c`. 47 | 48 | `'--vad-dir'` - directory with lab files - Voice/Speech activity detection - format `speech_start speech_end`. 49 | 50 | `'--in-emb-dir'` - input directory containing embeddings (if they were previously saved). 51 | 52 | `'--out-emb-dir'` - output directory for storing embeddings. 53 | 54 | `'--norm-list'` - input list with files for score normalization. When performing score normalization, it is necessary to use input ground truth `.rttm` files with unique speaker label. Speaker labels should not overlap, only in case, that there is same speaker in more audio files. All normalization utterances will be merged by speaker labels. 55 | 56 | `'--in-rttm-dir'` - input directory with `.rttm` files (used primary for score normalization) 57 | 58 | `'--out-rttm-dir'` - output directory for storing `.rttm` files 59 | 60 | `'--min-window-size'` - minimal size of embedding window in miliseconds. Defines minimal size used for clustering algorithms. 61 | 62 | `'--max-window-size'` - maximal size of embedding window in miliseconds. 63 | 64 | `'--vad-tolerance'` - skip `n` frames of non-speech and merge them as speech. 65 | 66 | `'--max-num-speakers'` - maximal number of speakers. Used in clustering algorithm. 67 | 68 | `'--use-gpu'` - use GPU instead of cpu (onnxruntime-gpu must be installed) 69 | 70 | 71 | ## Results on Datasets 72 | 73 | ### AMI corpus http://groups.inf.ed.ac.uk/ami/corpus/ (development and evaluation set together) 74 | It is important to note that these results are obtained using summed individual head-mounted microphones. 75 | Results are reporting when using oracle number of speakers, collar size 0.25s and without scoring overlapped speech. 76 | Data were upsampled from 8k to 16k and 8k wav data are no longer supported. 77 | 78 | Results can be obtained using similar command 79 | ```bash 80 | python diarization.py -c ../configs/vbdiar.yml -l lists/AMI_dev-eval.scp --audio-dir wav/AMI/IHM_SUM --vad-dir vad/AMI --out-emb-dir emb/AMI/IHM_SUM --in-rttm-dir rttms/AMI 81 | ``` 82 | 83 | | System | DER | 84 | |------------------------------------------------------------------------|-------| 85 | | Oracle number of speakers + x-vectors + mean + LDA + L2 Norm + GPLDA | 6.67 | 86 | | Oracle number of speakers + x-vectors + mean + LDA + L2 Norm | 9.16 | 87 | | x-vectors + mean + LDA + L2 Norm + GPLDA | 15.54 | 88 | -------------------------------------------------------------------------------- /configs/mfcc.conf: -------------------------------------------------------------------------------- 1 | --sample-frequency=16000 2 | --frame-length=25 # the default is 25 3 | --low-freq=20 # the default. 4 | --high-freq=7700 # the default is zero meaning use the Nyquist (4k in this case). 5 | --num-ceps=23 # higher than the default which is 12. 6 | --snip-edges=false 7 | -------------------------------------------------------------------------------- /configs/vbdiar.yml: -------------------------------------------------------------------------------- 1 | # defines MFCC configuration, following Kaldi's conventions 2 | MFCC: 3 | config_path: ../configs/mfcc.conf 4 | apply_cmvn_sliding: True 5 | norm_vars: False 6 | center: True 7 | cmn_window: 300 8 | 9 | # defines properties of x-vector extractor, such as neural net, ... 10 | EmbeddingExtractor: 11 | onnx_path: ../models/final.onnx 12 | 13 | PLDA: 14 | path: ../models/gplda 15 | 16 | # specifies possible transformations to embeddings - mean subtraction, LDA, l2 normalization, ... 17 | Transforms: 18 | mean: ../models/mean.npy 19 | lda: ../models/LDA.npy 20 | use_l2_norm: True 21 | -------------------------------------------------------------------------------- /examples/diarization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import os 9 | import sys 10 | import ctypes 11 | import logging 12 | import argparse 13 | import multiprocessing 14 | import subprocess 15 | 16 | import numpy as np 17 | 18 | from vbdiar.scoring.gplda import GPLDA 19 | from vbdiar.vad import get_vad 20 | from vbdiar.utils import mkdir_p 21 | from vbdiar.utils.utils import Utils 22 | from vbdiar.embeddings.embedding import extract_embeddings 23 | from vbdiar.scoring.diarization import Diarization 24 | from vbdiar.scoring.normalization import Normalization 25 | from vbdiar.kaldi.onnx_xvector_extraction import ONNXXVectorExtraction 26 | from vbdiar.features.segments import get_segments, get_time_from_frames, get_frames_from_time 27 | from vbdiar.kaldi.mfcc_features_extraction import KaldiMFCCFeatureExtraction 28 | 29 | 30 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 31 | logger = logging.getLogger(__name__) 32 | 33 | CDIR = os.path.dirname(os.path.realpath(__file__)) 34 | 35 | 36 | def _process_files(dargs): 37 | """ 38 | 39 | Args: 40 | dargs: 41 | 42 | Returns: 43 | 44 | """ 45 | fns, kwargs = dargs 46 | ret = [] 47 | for fn in fns: 48 | ret.append(process_file(file_name=fn, **kwargs)) 49 | return ret 50 | 51 | 52 | def process_files(fns, wav_dir, vad_dir, out_dir, features_extractor, embedding_extractor, min_size, 53 | max_size, overlap, tolerance, wav_suffix='.wav', vad_suffix='.lab.gz', n_jobs=1): 54 | """ Process all files from list. 55 | 56 | Args: 57 | fns (list): name of files to process 58 | wav_dir (str): directory with wav files 59 | vad_dir (str): directory with vad files 60 | out_dir (str|None): output directory 61 | features_extractor (Any): intialized object for feature extraction 62 | embedding_extractor (Any): initialized object for embedding extraction 63 | max_size (int): maximal size of window in ms 64 | min_size (int): minimal size of window in ms 65 | overlap (int): size of window overlap in ms 66 | tolerance (int): accept given number of frames as speech even when it is marked as silence 67 | wav_suffix (str): suffix of wav files 68 | vad_suffix (str): suffix of vad files 69 | n_jobs (int): number of jobs to run in parallel 70 | 71 | Returns: 72 | List[EmbeddingSet] 73 | """ 74 | kwargs = dict(wav_dir=wav_dir, vad_dir=vad_dir, out_dir=out_dir, features_extractor=features_extractor, 75 | embedding_extractor=embedding_extractor, tolerance=tolerance, min_size=min_size, 76 | max_size=max_size, overlap=overlap, wav_suffix=wav_suffix, vad_suffix=vad_suffix) 77 | if n_jobs == 1: 78 | ret = _process_files((fns, kwargs)) 79 | else: 80 | pool = multiprocessing.Pool(n_jobs) 81 | ret = pool.map(_process_files, ((part, kwargs) for part in Utils.partition(fns, n_jobs))) 82 | return [item for sublist in ret for item in sublist] 83 | 84 | 85 | def process_file(wav_dir, vad_dir, out_dir, file_name, features_extractor, embedding_extractor, 86 | min_size, max_size, overlap, tolerance, wav_suffix='.wav', vad_suffix='.lab.gz'): 87 | """ Process single audio file. 88 | 89 | Args: 90 | wav_dir (str): directory with wav files 91 | vad_dir (str): directory with vad files 92 | out_dir (str): output directory 93 | file_name (str): name of the file 94 | features_extractor (Any): intialized object for feature extraction 95 | embedding_extractor (Any): initialized object for embedding extraction 96 | max_size (int): maximal size of window in ms 97 | max_size (int): maximal size of window in ms 98 | overlap (int): size of window overlap in ms 99 | tolerance (int): accept given number of frames as speech even when it is marked as silence 100 | wav_suffix (str): suffix of wav files 101 | vad_suffix (str): suffix of vad files 102 | 103 | Returns: 104 | EmbeddingSet 105 | """ 106 | logger.info('Processing file {}.'.format(file_name.split()[0])) 107 | num_speakers = None 108 | if len(file_name.split()) > 1: # number of speakers is defined 109 | file_name, num_speakers = file_name.split()[0], int(file_name.split()[1]) 110 | 111 | wav_dir, vad_dir = os.path.abspath(wav_dir), os.path.abspath(vad_dir) 112 | if out_dir: 113 | out_dir = os.path.abspath(out_dir) 114 | 115 | # extract features 116 | features = features_extractor.audio2features(os.path.join(wav_dir, f'{file_name}{wav_suffix}')) 117 | 118 | # load voice activity detection from file 119 | vad, _, _ = get_vad(f'{os.path.join(vad_dir, file_name)}{vad_suffix}', features.shape[0]) 120 | 121 | # parse segments and split features 122 | features_dict = {} 123 | for seg in get_segments(vad, max_size, tolerance): 124 | seg_start, seg_end = seg 125 | start, end = get_time_from_frames(seg_start), get_time_from_frames(seg_end) 126 | if start >= overlap: 127 | seg_start = get_frames_from_time(start - overlap) 128 | if seg_start > features.shape[0] - 1 or seg_end > features.shape[0] - 1: 129 | logger.warning(f'Frames not aligned, number of frames {features.shape[0]} and got ending segment {seg_end}') 130 | seg_end = features.shape[0] 131 | features_dict[(start, end)] = features[seg_start:seg_end] 132 | 133 | # extract embedding for each segment 134 | embedding_set = extract_embeddings(features_dict, embedding_extractor) 135 | embedding_set.name = file_name 136 | embedding_set.num_speakers = num_speakers 137 | 138 | # save embeddings if required 139 | if out_dir is not None: 140 | mkdir_p(os.path.join(out_dir, os.path.dirname(file_name))) 141 | embedding_set.save(os.path.join(out_dir, '{}.pkl'.format(file_name))) 142 | 143 | return embedding_set 144 | 145 | 146 | def set_mkl(num_cores=1): 147 | """ Set number of cores for mkl library. 148 | 149 | Args: 150 | num_cores (int): number of cores 151 | """ 152 | try: 153 | mkl_rt = ctypes.CDLL('libmkl_rt.so') 154 | mkl_rt.mkl_set_dynamic(ctypes.byref(ctypes.c_int(0))) 155 | mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(num_cores))) 156 | except OSError: 157 | logger.warning('Failed to import libmkl_rt.so, it will not be possible to use mkl backend.') 158 | 159 | 160 | def get_gpu(really=True): 161 | try: 162 | if really: 163 | command = 'nvidia-smi --query-gpu=memory.free,memory.total --format=csv |tail -n+2| ' \ 164 | 'awk \'BEGIN{FS=" "}{if ($1/$3 > 0.98) print NR-1}\'' 165 | gpu_idx = subprocess.check_output(command, shell=True).rsplit(b'\n')[0].decode('utf-8') 166 | else: 167 | gpu_idx = '-1' 168 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_idx 169 | except subprocess.CalledProcessError: 170 | logger.warning('No GPUs seems to be available.') 171 | 172 | 173 | if __name__ == '__main__': 174 | parser = argparse.ArgumentParser('Extract embeddings used for diarization from audio wav files.') 175 | 176 | # required 177 | parser.add_argument('-l', '--input-list', help='list of input files without suffix', 178 | action='store', required=True) 179 | parser.add_argument('-c', '--configuration', help='input configuration of models', 180 | action='store', required=True) 181 | parser.add_argument('-m', '--mode', required=True, choices=['sre', 'diarization'], 182 | help='mode used - there are two possible modes, classic `diarization` mode which should' 183 | 'segment utterance into speakers and `sre` mode used for speaker recognition, which ' 184 | 'runs clustering for N iterations and saves all clusters') 185 | 186 | # not required 187 | parser.add_argument('--audio-dir', 188 | help='directory with audio files in .wav format - 8000Hz, 16bit-s, 1c', required=False) 189 | parser.add_argument('--vad-dir', 190 | help='directory with lab files - Voice/Speech activity detection', required=False) 191 | parser.add_argument('--in-emb-dir', 192 | help='input directory containing embeddings', required=False) 193 | parser.add_argument('--out-emb-dir', 194 | help='output directory for storing embeddings', required=False) 195 | parser.add_argument('--norm-list', 196 | help='list of normalization files without suffix', required=False) 197 | parser.add_argument('--in-rttm-dir', 198 | help='input directory with rttm files', required=False) 199 | parser.add_argument('--out-rttm-dir', 200 | help='output directory for storing rttm files', required=False) 201 | parser.add_argument('--out-clusters-dir', required=False, 202 | help='output directory for storing clusters - only used when mode is `sre`') 203 | parser.add_argument('-wav-suffix', 204 | help='wav file suffix', required=False, default='.wav') 205 | parser.add_argument('-vad-suffix', 206 | help='Voice Activity Detector file suffix', required=False, default='.lab.gz') 207 | parser.add_argument('-rttm-suffix', 208 | help='rttm file suffix', required=False, default='.rttm') 209 | parser.add_argument('--min-window-size', default=1000, 210 | help='minimal window size for embedding clustering in ms', type=int, required=False) 211 | parser.add_argument('--max-window-size', default=2000, 212 | help='maximal window size for extracting embedding in ms', type=int, required=False) 213 | parser.add_argument('--window-overlap', 214 | help='overlap in window in ms', type=int, required=False, default=0) 215 | parser.add_argument('--vad-tolerance', default=0, 216 | help='tolerance critetion for ignoring frames of silence', type=float, required=False) 217 | parser.add_argument('--use-gpu', required=False, default=False, action='store_true', 218 | help='use GPU instead of cpu (onnxruntime-gpu must be installed)') 219 | parser.add_argument('--max-num-speakers', 220 | help='maximal number of speakers', required=False, default=10) 221 | 222 | args = parser.parse_args() 223 | 224 | logger.info(f'Running `{" ".join(sys.argv)}`.') 225 | 226 | set_mkl(1) 227 | get_gpu(args.use_gpu) 228 | 229 | # initialize extractor 230 | config = Utils.read_config(args.configuration) 231 | 232 | config_mfcc = config['MFCC'] 233 | config_path = os.path.abspath(config_mfcc['config_path']) 234 | if not os.path.isfile(config_path): 235 | raise ValueError(f'Path to MFCC configuration `{config_path}` not found.') 236 | features_extractor = KaldiMFCCFeatureExtraction( 237 | config_path=config_path, apply_cmvn_sliding=config_mfcc['apply_cmvn_sliding'], 238 | norm_vars=config_mfcc['norm_vars'], center=config_mfcc['center'], cmn_window=config_mfcc['cmn_window']) 239 | 240 | config_embedding_extractor = config['EmbeddingExtractor'] 241 | embedding_extractor = ONNXXVectorExtraction(onnx_path=os.path.abspath(config_embedding_extractor['onnx_path'])) 242 | 243 | config_transforms = config['Transforms'] 244 | mean = config_transforms.get('mean') 245 | lda = config_transforms.get('lda') 246 | if lda is not None: 247 | lda = np.load(lda) 248 | use_l2_norm = config_transforms.get('use_l2_norm') 249 | plda = config.get('PLDA') 250 | if plda is not None: 251 | plda = GPLDA(plda['path']) 252 | 253 | files = [line.rstrip('\n') for line in open(args.input_list)] 254 | 255 | # extract embeddings 256 | if args.in_emb_dir is None: 257 | if args.out_emb_dir is None: 258 | raise ValueError('At least one of `--in-emb-dir` or `--out-emb-dir` must be specified.') 259 | if args.audio_dir is None: 260 | raise ValueError('At least one of `--in-emb-dir` or `--audio-dir` must be specified.') 261 | if args.vad_dir is None: 262 | raise ValueError('`--audio-dir` was specified, `--vad-dir` must be specified too.') 263 | process_files( 264 | fns=files, wav_dir=args.audio_dir, vad_dir=args.vad_dir, out_dir=args.out_emb_dir, 265 | features_extractor=features_extractor, embedding_extractor=embedding_extractor, 266 | min_size=args.min_window_size, max_size=args.max_window_size, overlap=args.window_overlap, 267 | tolerance=args.vad_tolerance, wav_suffix=args.wav_suffix, vad_suffix=args.vad_suffix, 268 | n_jobs=1) 269 | if args.out_emb_dir: 270 | embeddings = args.out_emb_dir 271 | else: 272 | embeddings = args.in_emb_dir 273 | 274 | # initialize normalization 275 | if args.norm_list is not None: 276 | norm = Normalization(norm_list=args.norm_list, audio_dir=args.audio_dir, 277 | in_rttm_dir=args.in_rttm_dir, in_emb_dir=args.in_emb_dir, 278 | out_emb_dir=args.out_emb_dir, min_length=args.min_window_size, plda=plda, 279 | embedding_extractor=embedding_extractor, features_extractor=features_extractor, 280 | wav_suffix=args.wav_suffix, rttm_suffix=args.rttm_suffix, n_jobs=1) 281 | else: 282 | norm = None 283 | 284 | # load transformations if specified 285 | if not norm: 286 | if mean: 287 | mean = np.load(mean) 288 | else: 289 | mean = norm.mean 290 | 291 | # run diarization 292 | diar = Diarization(args.input_list, embeddings, embeddings_mean=mean, lda=lda, 293 | use_l2_norm=use_l2_norm, plda=plda, norm=norm) 294 | result = diar.score_embeddings(args.min_window_size, args.max_num_speakers, args.mode) 295 | 296 | if args.mode == 'diarization': 297 | if args.in_rttm_dir: 298 | diar.evaluate(scores=result, in_rttm_dir=args.in_rttm_dir, collar_size=0.25, evaluate_overlaps=False) 299 | 300 | if args.out_rttm_dir is not None: 301 | diar.dump_rttm(result, args.out_rttm_dir) 302 | else: 303 | if args.out_clusters_dir: 304 | for name in result: 305 | mkdir_p(os.path.join(args.out_clusters_dir, os.path.dirname(name))) 306 | np.save(os.path.join(args.out_clusters_dir, name), result[name]) 307 | -------------------------------------------------------------------------------- /examples/lists/AMI_dev-eval.scp: -------------------------------------------------------------------------------- 1 | EN2002a 4 2 | EN2002b 4 3 | EN2002c 4 4 | EN2002d 4 5 | ES2004a 4 6 | ES2004b 4 7 | ES2004c 4 8 | ES2004d 4 9 | ES2011a 4 10 | ES2011b 4 11 | ES2011c 4 12 | ES2011d 4 13 | IB4001 4 14 | IB4002 4 15 | IB4003 4 16 | IB4004 4 17 | IB4010 4 18 | IB4011 4 19 | IS1008a 4 20 | IS1008b 4 21 | IS1008c 4 22 | IS1008d 4 23 | IS1009a 4 24 | IS1009b 4 25 | IS1009c 4 26 | IS1009d 4 27 | TS3003a 4 28 | TS3003b 4 29 | TS3003c 4 30 | TS3003d 4 31 | TS3004a 4 32 | TS3004b 4 33 | TS3004c 4 34 | TS3004d 4 35 | -------------------------------------------------------------------------------- /examples/lists/AMI_trn.scp: -------------------------------------------------------------------------------- 1 | EN2001a 2 | EN2001b 3 | EN2001d 4 | EN2001e 5 | EN2003a 6 | EN2004a 7 | EN2005a 8 | EN2006a 9 | EN2006b 10 | EN2009b 11 | EN2009c 12 | EN2009d 13 | ES2002a 14 | ES2002b 15 | ES2002c 16 | ES2002d 17 | ES2003a 18 | ES2003b 19 | ES2003c 20 | ES2003d 21 | ES2005a 22 | ES2005b 23 | ES2005c 24 | ES2005d 25 | ES2006a 26 | ES2006b 27 | ES2006c 28 | ES2006d 29 | ES2007a 30 | ES2007b 31 | ES2007c 32 | ES2007d 33 | ES2008a 34 | ES2008b 35 | ES2008c 36 | ES2008d 37 | ES2009a 38 | ES2009b 39 | ES2009c 40 | ES2009d 41 | ES2010a 42 | ES2010b 43 | ES2010c 44 | ES2010d 45 | ES2012a 46 | ES2012b 47 | ES2012c 48 | ES2012d 49 | ES2013a 50 | ES2013b 51 | ES2013c 52 | ES2013d 53 | ES2014a 54 | ES2014b 55 | ES2014c 56 | ES2014d 57 | ES2015a 58 | ES2015b 59 | ES2015c 60 | ES2015d 61 | ES2016a 62 | ES2016b 63 | ES2016c 64 | ES2016d 65 | IB4005 66 | IN1001 67 | IN1002 68 | IN1005 69 | IN1007 70 | IN1008 71 | IN1009 72 | IN1012 73 | IN1013 74 | IN1014 75 | IN1016 76 | IS1000a 77 | IS1000b 78 | IS1000c 79 | IS1000d 80 | IS1001a 81 | IS1001b 82 | IS1001c 83 | IS1001d 84 | IS1002b 85 | IS1002c 86 | IS1002d 87 | IS1003a 88 | IS1003c 89 | IS1003d 90 | IS1004a 91 | IS1004b 92 | IS1004c 93 | IS1004d 94 | IS1005a 95 | IS1005b 96 | IS1005c 97 | IS1006a 98 | IS1006b 99 | IS1006c 100 | IS1006d 101 | IS1007a 102 | IS1007b 103 | IS1007c 104 | TS3005a 105 | TS3005b 106 | TS3005c 107 | TS3005d 108 | TS3006a 109 | TS3006b 110 | TS3006c 111 | TS3006d 112 | TS3007a 113 | TS3007b 114 | TS3007c 115 | TS3007d 116 | TS3008a 117 | TS3008b 118 | TS3008c 119 | TS3008d 120 | TS3009a 121 | TS3009b 122 | TS3009c 123 | TS3009d 124 | TS3010a 125 | TS3010b 126 | TS3010c 127 | TS3010d 128 | TS3011a 129 | TS3011b 130 | TS3011c 131 | TS3011d 132 | TS3012a 133 | TS3012b 134 | TS3012c 135 | TS3012d 136 | -------------------------------------------------------------------------------- /examples/lists/list.scp: -------------------------------------------------------------------------------- 1 | fisher-english-p1/fe_03_00001-a 2 | fisher-english-p1/fe_03_00001-b 3 | -------------------------------------------------------------------------------- /examples/lists/list_spk.scp: -------------------------------------------------------------------------------- 1 | fe_03_00001-a 1 2 | fe_03_00001-b 1 3 | -------------------------------------------------------------------------------- /examples/run.sh: -------------------------------------------------------------------------------- 1 | python diarization.py -c ../configs/vbdiar.yml \ 2 | -l lists/list_spk.scp \ 3 | --audio-dir wav/fisher-english-p1 \ 4 | --vad-dir vad/fisher-english-p1 \ 5 | --mode diarization \ 6 | --out-emb-dir embeddings 7 | -------------------------------------------------------------------------------- /examples/vad/fisher-english-p1/fe_03_00001-a.lab.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/examples/vad/fisher-english-p1/fe_03_00001-a.lab.gz -------------------------------------------------------------------------------- /examples/vad/fisher-english-p1/fe_03_00001-b.lab.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/examples/vad/fisher-english-p1/fe_03_00001-b.lab.gz -------------------------------------------------------------------------------- /examples/wav/fisher-english-p1/README: -------------------------------------------------------------------------------- 1 | These files are taken from the Fisher English database. 2 | 3 | The parameters of the files: 4 | 8kHz, 16bits per sample, 1 channel 5 | 6 | For reference, they were generated using BUT's raw files using: 7 | # for f in *.raw; do nf=$(echo $f|sed 's@raw@wav@'); sox -r 8k -s -b 16 -c 1 $f $nf; done 8 | -------------------------------------------------------------------------------- /examples/wav/fisher-english-p1/fe_03_00001-a.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/examples/wav/fisher-english-p1/fe_03_00001-a.wav -------------------------------------------------------------------------------- /examples/wav/fisher-english-p1/fe_03_00001-b.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/examples/wav/fisher-english-p1/fe_03_00001-b.wav -------------------------------------------------------------------------------- /examples/wav/fisher-english-p1/smalltalk0501.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/examples/wav/fisher-english-p1/smalltalk0501.wav -------------------------------------------------------------------------------- /models/LDA.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/models/LDA.npy -------------------------------------------------------------------------------- /models/final.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/models/final.onnx -------------------------------------------------------------------------------- /models/gplda/CB.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/models/gplda/CB.npy -------------------------------------------------------------------------------- /models/gplda/CW.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/models/gplda/CW.npy -------------------------------------------------------------------------------- /models/gplda/mu.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/models/gplda/mu.npy -------------------------------------------------------------------------------- /models/mean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jamiroquai88/VBDiarization/a1ceabab3f58c925b2b16272fc9b3a4cd795068e/models/mean.npy -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.2.1 2 | scikit-learn==0.20.3 3 | numpy==1.16.2 4 | pyyaml 5 | onnxruntime==0.3.0 6 | spherecluster==0.1.7 7 | pyclustering==0.8.2 8 | kaldiio==2.13.4 9 | -------------------------------------------------------------------------------- /vbdiar/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | -------------------------------------------------------------------------------- /vbdiar/clustering/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | -------------------------------------------------------------------------------- /vbdiar/clustering/pldakmeans.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import numpy as np 9 | 10 | 11 | class PLDAKMeans(object): 12 | """ KMeans clustering algorithm using PLDA output as distance metric. 13 | 14 | """ 15 | def __init__(self, centroids, k, plda, max_iter=10): 16 | """ Class constructor. 17 | 18 | :param centroids: initialization centroids 19 | :type centroids: numpy.array 20 | :param k: number of classes 21 | :type k: int 22 | :param plda: PLDA object 23 | :type plda: PLDA 24 | :param max_iter: maximal number of iterations 25 | :type max_iter: int 26 | """ 27 | self.max_iter = max_iter 28 | self.old_labels = [] 29 | self.data = None 30 | self.cluster_centers_ = centroids 31 | self.k = k 32 | self.plda = plda 33 | 34 | def fit(self, data): 35 | """ Fit the input data. 36 | 37 | :param data: input data 38 | :type data: numpy.array 39 | :returns: cluster centers 40 | :rtype: numpy.array 41 | """ 42 | self.data = data 43 | iterations = 0 44 | while True: 45 | if self.stop(iterations): 46 | break 47 | else: 48 | iterations += 1 49 | return self.cluster_centers_ 50 | 51 | def stop(self, iterations): 52 | """ Make the decision if algorithm should stop. 53 | 54 | :param iterations: number of successfull iterations 55 | :type iterations: int 56 | :returns: True if algorithm should stop, False otherwise 57 | :rtype: bool 58 | """ 59 | labels = self.labels() 60 | if iterations > self.max_iter or self.old_labels == labels: 61 | return True 62 | else: 63 | self.old_labels = labels 64 | return False 65 | 66 | def labels(self): 67 | """ Predict labels. 68 | 69 | """ 70 | scores = self.plda.score(self.data, self.cluster_centers_) 71 | centroids = {} 72 | for ii in range(self.k): 73 | centroids[ii] = [] 74 | labels = [] 75 | for ii in range(self.data.shape[0]): 76 | c = np.argmax(scores[ii]) 77 | labels.append(c) 78 | centroids[c].append(self.data[ii]) 79 | for ii in range(self.k): 80 | centroids[ii] = np.array(centroids[ii]) 81 | # clustering has strange behaviour 82 | if centroids[ii].ndim == 1: 83 | return self.old_labels 84 | self.cluster_centers_[ii] = np.mean(centroids[ii], axis=0) 85 | return labels 86 | -------------------------------------------------------------------------------- /vbdiar/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | from .embedding import Embedding, EmbeddingSet 9 | -------------------------------------------------------------------------------- /vbdiar/embeddings/embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import os 9 | import pickle 10 | import numpy as np 11 | 12 | from vbdiar.features.segments import get_time_from_frames 13 | from vbdiar.utils import mkdir_p 14 | 15 | 16 | def extract_embeddings(features_dict, embedding_extractor): 17 | """ Extract embeddings from multiple segments. 18 | 19 | Args: 20 | features_dict (Dict): dictionary with segment range as key and features as values 21 | embedding_extractor (Any): 22 | 23 | Returns: 24 | EmbeddingSet: extracted embedding in embedding set 25 | """ 26 | embedding_set = EmbeddingSet() 27 | embeddings = embedding_extractor.features2embeddings(features_dict) 28 | for embedding_key in embeddings: 29 | start, end = embedding_key 30 | embedding_set.add(embeddings[embedding_key], window_start=int(float(start)), window_end=int(float(end))) 31 | return embedding_set 32 | 33 | 34 | class Embedding(object): 35 | """ Class for basic i-vector operations. 36 | 37 | """ 38 | 39 | def __init__(self): 40 | """ Class constructor. 41 | 42 | """ 43 | self.data = None 44 | self.features = None 45 | self.window_start = None 46 | self.window_end = None 47 | 48 | 49 | class EmbeddingSet(object): 50 | """ Class for encapsulating ivectors set. 51 | 52 | """ 53 | 54 | def __init__(self): 55 | """ Class constructor. 56 | 57 | """ 58 | self.name = None 59 | self.num_speakers = None 60 | self.embeddings = [] 61 | 62 | def __iter__(self): 63 | current = 0 64 | while current < len(self.embeddings): 65 | yield self.embeddings[current] 66 | current += 1 67 | 68 | def __getitem__(self, key): 69 | return self.embeddings[key] 70 | 71 | def __setitem__(self, key, value): 72 | self.embeddings[key] = value 73 | 74 | def __len__(self): 75 | return len(self.embeddings) 76 | 77 | def get_all_embeddings(self): 78 | """ Get all ivectors. 79 | 80 | """ 81 | a = [] 82 | for i in self.embeddings: 83 | a.append(i.data.flatten()) 84 | return np.array(a) 85 | 86 | def get_longer_embeddings(self, min_length): 87 | """ Get i-vectors extracted from longer segments than minimal length. 88 | 89 | Args: 90 | min_length (int): minimal length of segment in miliseconds 91 | 92 | Returns: 93 | np.array: i-vectors 94 | """ 95 | a = [] 96 | for embedding in self.embeddings: 97 | if embedding.window_end - embedding.window_start >= min_length: 98 | a.append(embedding.data.flatten()) 99 | return np.array(a) 100 | 101 | def add(self, data, window_start, window_end, features=None): 102 | """ Add embedding to set. 103 | 104 | Args: 105 | data (np.array): embeding data 106 | window_start (int): start of the window [ms] 107 | window_end (int): end of the window [ms] 108 | features (np.array): features from which embedding was extracted 109 | """ 110 | i = Embedding() 111 | i.data = data 112 | i.window_start = window_start 113 | i.window_end = window_end 114 | i.features = features 115 | self.__append(i) 116 | 117 | def __append(self, embedding): 118 | """ Append embedding to set of embedding. 119 | 120 | Args: 121 | embedding (Embedding): 122 | """ 123 | ii = 0 124 | for vp in self.embeddings: 125 | if vp.window_start > embedding.window_start: 126 | break 127 | ii += 1 128 | self.embeddings.insert(ii, embedding) 129 | 130 | def save(self, path): 131 | """ Save embedding set as pickled file. 132 | 133 | Args: 134 | path (string_types): output path 135 | """ 136 | mkdir_p(os.path.dirname(path)) 137 | with open(path, 'wb') as f: 138 | pickle.dump(self, f, pickle.HIGHEST_PROTOCOL) 139 | 140 | 141 | if __name__ == "__main__": 142 | import doctest 143 | doctest.testmod() 144 | -------------------------------------------------------------------------------- /vbdiar/features/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | -------------------------------------------------------------------------------- /vbdiar/features/segments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import logging 9 | 10 | import numpy as np 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | RATE = 16000 16 | SOURCERATE = 1250 17 | TARGETRATE = 100000 18 | 19 | ZMEANSOURCE = True 20 | WINDOWSIZE = 250000.0 21 | USEHAMMING = True 22 | PREEMCOEF = 0.97 23 | NUMCHANS = 24 24 | CEPLIFTER = 22 25 | NUMCEPS = 19 26 | ADDDITHER = 1.0 27 | RAWENERGY = True 28 | ENORMALISE = True 29 | 30 | deltawindow = accwindow = 2 31 | 32 | cmvn_lc = 150 33 | cmvn_rc = 150 34 | 35 | fs = 1e7 / SOURCERATE 36 | 37 | 38 | def get_segments(vad, max_size, tolerance): 39 | """ Return clustered speech segments. 40 | 41 | :param vad: list with labels - voice activity detection 42 | :type vad: list 43 | :param max_size: maximal size of window in ms 44 | :type max_size: int 45 | :param tolerance: accept given number of frames as speech even when it is marked as silence 46 | :type tolerance: int 47 | :returns: clustered segments 48 | :rtype: list 49 | """ 50 | clusters = get_clusters(vad, tolerance) 51 | segments = [] 52 | max_frames = get_frames_from_time(max_size) 53 | for item in clusters.values(): 54 | if item[1] - item[0] > max_frames: 55 | for ss in split_segment(item, max_frames): 56 | segments.append(ss) 57 | else: 58 | segments.append(item) 59 | return segments 60 | 61 | 62 | def split_segment(segment, max_size): 63 | """ Split segment to more with adaptive size. 64 | 65 | :param segment: input segment 66 | :type segment: tuple 67 | :param max_size: maximal size of window in ms 68 | :type max_size: int 69 | :returns: splitted segment 70 | :rtype: list 71 | """ 72 | size = segment[1] - segment[0] 73 | num_segments = int(np.math.ceil(size / max_size)) 74 | size_segment = size / num_segments 75 | for ii in range(num_segments): 76 | yield (int(segment[0] + ii * size_segment), int(segment[0] + (ii + 1) * size_segment)) 77 | 78 | 79 | def get_frames_from_time(n): 80 | """ Get number of frames from ms. 81 | 82 | :param n: number of ms 83 | :type n: int 84 | :returns: number of frames 85 | :rtype: int 86 | 87 | >>> get_frames_from_time(25) 88 | 1 89 | >>> get_frames_from_time(35) 90 | 2 91 | """ 92 | assert n >= 0, 'Time must be at least equal to 0.' 93 | if n < 25: 94 | return 0 95 | return int(1 + (n - WINDOWSIZE / 10000) / (TARGETRATE / 10000)) 96 | 97 | 98 | def get_time_from_frames(n): 99 | """ Get count of ms from number of frames. 100 | 101 | :param n: number of frames 102 | :type n: int 103 | :returns: number of ms 104 | :rtype: int 105 | 106 | >>> get_time_from_frames(1) 107 | 25 108 | >>> get_time_from_frames(2) 109 | 35 110 | 111 | """ 112 | return int(n * (TARGETRATE / 10000) - (TARGETRATE / 10000) + (WINDOWSIZE / 10000)) 113 | 114 | 115 | def get_clusters(vad, tolerance=10): 116 | """ Cluster speech segments. 117 | 118 | :param vad: list with labels - voice activity detection 119 | :type vad: list 120 | :param tolerance: accept given number of frames as speech even when it is marked as silence 121 | :type tolerance: int 122 | :returns: clustered speech segments 123 | :rtype: dict 124 | """ 125 | num_prev = 0 126 | in_tolerance = 0 127 | num_clusters = 0 128 | clusters = {} 129 | for ii, frame in enumerate(vad): 130 | if frame: 131 | num_prev += 1 132 | else: 133 | in_tolerance += 1 134 | if in_tolerance > tolerance: 135 | if num_prev > 0: 136 | clusters[num_clusters] = (ii - num_prev, ii) 137 | num_clusters += 1 138 | num_prev = 0 139 | in_tolerance = 0 140 | if num_prev > 0: 141 | clusters[num_clusters] = (ii - num_prev, ii) 142 | num_clusters += 1 143 | return clusters 144 | 145 | 146 | def split_seq(seq, size): 147 | """ Split up seq in pieces of size. 148 | 149 | Args: 150 | seq: 151 | size: 152 | 153 | Returns: 154 | 155 | """ 156 | return [seq[i:i + size] for i in range(0, len(seq), size)] 157 | -------------------------------------------------------------------------------- /vbdiar/kaldi/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import os 9 | 10 | 11 | KALDI_ROOT_PATH = os.getenv('KALDI_ROOT_PATH', '.') 12 | 13 | bin_path = os.path.join(KALDI_ROOT_PATH, 'src', 'bin') 14 | featbin_path = os.path.join(KALDI_ROOT_PATH, 'src', 'featbin') 15 | nnet3bin_path = os.path.join(KALDI_ROOT_PATH, 'src', 'nnet3bin') 16 | 17 | -------------------------------------------------------------------------------- /vbdiar/kaldi/kaldi_xvector_extraction.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import os 9 | import logging 10 | import tempfile 11 | import subprocess 12 | 13 | from vbdiar.kaldi import nnet3bin_path 14 | from vbdiar.kaldi.utils import write_txt_matrix, read_txt_vectors 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class KaldiXVectorExtraction(object): 21 | 22 | def __init__(self, nnet, binary_path=nnet3bin_path, use_gpu=False, 23 | min_chunk_size=25, chunk_size=10000, cache_capacity=64): 24 | """ Initialize Kaldi x-vector extractor. 25 | 26 | Args: 27 | nnet (string_types): path to neural net 28 | use_gpu (bool): 29 | min_chunk_size (int): 30 | chunk_size (int): 31 | cache_capacity (int): 32 | """ 33 | self.nnet3_xvector_compute = os.path.join(binary_path, 'nnet3-xvector-compute') 34 | if not os.path.exists(self.nnet3_xvector_compute): 35 | raise ValueError( 36 | 'Path to nnet3-xvector-compute - `{}` does not exists.'.format(self.nnet3_xvector_compute)) 37 | self.nnet3_copy = os.path.join(binary_path, 'nnet3-copy') 38 | if not os.path.exists(self.nnet3_copy): 39 | raise ValueError( 40 | 'Path to nnet3-copy - `{}` does not exists.'.format(self.nnet3_copy)) 41 | if not os.path.isfile(nnet): 42 | raise ValueError('Invalid path to nnet `{}`.'.format(nnet)) 43 | else: 44 | self.nnet = nnet 45 | self.binary_path = binary_path 46 | self.use_gpu = use_gpu 47 | self.min_chunk_size = min_chunk_size 48 | self.chunk_size = chunk_size 49 | self.cache_capacity = cache_capacity 50 | 51 | def features2embeddings(self, data_dict): 52 | """ Extract x-vector embeddings from feature vectors. 53 | 54 | Args: 55 | data_dict (Dict): 56 | 57 | Returns: 58 | 59 | """ 60 | tmp_data_dict = {} 61 | for key in data_dict: 62 | tmp_data_dict[f'{key[0]}_{key[1]}'] = data_dict[key] 63 | with tempfile.NamedTemporaryFile() as xvec_ark, tempfile.NamedTemporaryFile() as mfcc_ark: 64 | write_txt_matrix(path=mfcc_ark.name, data_dict=tmp_data_dict) 65 | 66 | args = [self.nnet3_xvector_compute, 67 | '--use-gpu={}'.format('yes' if self.use_gpu else 'no'), 68 | '--min-chunk-size={}'.format(str(self.min_chunk_size)), 69 | '--chunk-size={}'.format(str(self.chunk_size)), 70 | '--cache-capacity={}'.format(str(self.cache_capacity)), 71 | self.nnet, 'ark,t:{}'.format(mfcc_ark.name), 'ark,t:{}'.format(xvec_ark.name)] 72 | 73 | logger.info('Extracting x-vectors from {} feature vectors to `{}`.'.format(len(tmp_data_dict), xvec_ark.name)) 74 | process = subprocess.Popen( 75 | args, stderr=subprocess.PIPE, stdout=subprocess.PIPE, cwd=self.binary_path, shell=False) 76 | _, stderr = process.communicate() 77 | if process.returncode != 0: 78 | raise ValueError('`{}` binary returned error code {}.{}{}'.format( 79 | self.nnet3_xvector_compute, process.returncode, os.linesep, stderr)) 80 | tmp_xvec_dict = read_txt_vectors(xvec_ark.name) 81 | xvec_dict = {} 82 | for key in tmp_xvec_dict: 83 | new_key = tuple(key.split('_')) 84 | xvec_dict[new_key] = tmp_xvec_dict[key] 85 | return xvec_dict 86 | -------------------------------------------------------------------------------- /vbdiar/kaldi/mfcc_features_extraction.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import os 9 | import logging 10 | import tempfile 11 | import subprocess 12 | 13 | import kaldiio 14 | 15 | from vbdiar.kaldi import featbin_path 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class KaldiMFCCFeatureExtraction(object): 22 | 23 | def __init__(self, config_path, binary_path=featbin_path, apply_cmvn_sliding=True, 24 | norm_vars=False, center=True, cmn_window=300): 25 | """ Initialize Kaldi MFCC extraction component. Names of the arguments keep original Kaldi convention. 26 | 27 | Args: 28 | config_path (string_types): path to config file 29 | binary_path (string_types): path to directory containing binaries 30 | apply_cmvn_sliding (bool): apply cepstral mean and variance normalization 31 | norm_vars (bool): normalize variances 32 | center (bool): center window 33 | cmn_window (int): window size 34 | """ 35 | self.binary_path = binary_path 36 | self.config_path = config_path 37 | self.apply_cmvn_sliding = apply_cmvn_sliding 38 | self.norm_vars = norm_vars 39 | self.center = center 40 | self.cmn_window = cmn_window 41 | self.compute_mfcc_feats_bin = os.path.join(binary_path, 'compute-mfcc-feats') 42 | if not os.path.exists(self.compute_mfcc_feats_bin): 43 | raise ValueError('Path to compute-mfcc-feats - {} does not exists.'.format(self.compute_mfcc_feats_bin)) 44 | self.copy_feats_bin = os.path.join(binary_path, 'copy-feats') 45 | if not os.path.exists(self.copy_feats_bin): 46 | raise ValueError('Path to copy-feats - {} does not exists.'.format(self.copy_feats_bin)) 47 | self.apply_cmvn_sliding_bin = os.path.join(binary_path, 'apply-cmvn-sliding') 48 | if not os.path.exists(self.apply_cmvn_sliding_bin): 49 | raise ValueError('Path to apply-cmvn-sliding - {} does not exists.'.format(self.apply_cmvn_sliding_bin)) 50 | 51 | def __str__(self): 52 | return ''.format(self.config_path) 53 | 54 | def audio2features(self, input_path): 55 | """ Extract features from list of files into list of numpy.arrays 56 | 57 | Args: 58 | input_path (string_types): audio file path 59 | 60 | Returns: 61 | Tuple[str, np.array]: path to Kaldi ark file containing features and features itself 62 | """ 63 | with tempfile.NamedTemporaryFile(mode='w') as wav_scp, tempfile.NamedTemporaryFile() as mfcc_ark: 64 | # dump list of file to wav.scp file 65 | wav_scp.write('{} {}{}'.format(input_path, input_path, os.linesep)) 66 | wav_scp.flush() 67 | 68 | # run fextract 69 | args = [self.compute_mfcc_feats_bin, f'--config={self.config_path}', f'scp:{wav_scp.name}', 70 | f'ark:{mfcc_ark.name if not self.apply_cmvn_sliding else "-"}'] 71 | logger.info('Extracting MFCC features from `{}`.'.format(input_path)) 72 | compute_mfcc_feats = subprocess.Popen( 73 | args, stderr=subprocess.PIPE, stdout=subprocess.PIPE, cwd=self.binary_path, shell=False) 74 | if not self.apply_cmvn_sliding: 75 | # do not apply cmvn, so just simply compute features 76 | _, stderr = compute_mfcc_feats.communicate() 77 | if compute_mfcc_feats.returncode != 0: 78 | raise ValueError(f'`{self.compute_mfcc_feats_bin}` binary returned error code ' 79 | f'{compute_mfcc_feats.returncode}.{os.linesep}{stderr}') 80 | else: 81 | args2 = [self.apply_cmvn_sliding_bin, f'--norm-vars={str(self.norm_vars).lower()}', 82 | f'--center={str(self.center).lower()}', f'--cmn-window={str(self.cmn_window)}', 83 | 'ark:-', f'ark:{mfcc_ark.name}'] 84 | apply_cmvn_sliding = subprocess.Popen(args2, stdin=compute_mfcc_feats.stdout, 85 | stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=False) 86 | _, stderr = apply_cmvn_sliding.communicate() 87 | if apply_cmvn_sliding.returncode == 0: 88 | pass 89 | else: 90 | raise ValueError(f'`{self.compute_mfcc_feats_bin}` binary returned error code ' 91 | f'{compute_mfcc_feats.returncode}.{os.linesep}{stderr}') 92 | ark = kaldiio.load_ark(mfcc_ark.name) 93 | for key, numpy_array in ark: 94 | return numpy_array 95 | -------------------------------------------------------------------------------- /vbdiar/kaldi/onnx_xvector_extraction.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import logging 9 | import os 10 | 11 | import numpy as np 12 | import onnxruntime 13 | 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | MIN_SIGNAL_LEN = 25 18 | 19 | 20 | class ONNXXVectorExtraction(object): 21 | 22 | def __init__(self, onnx_path): 23 | """ Initialize ONNX x-vector extractor. 24 | 25 | Args: 26 | onnx_path (str): path to neural net in ONNX format, see https://github.com/onnx/onnx 27 | """ 28 | if not os.path.isfile(onnx_path): 29 | raise ValueError(f'Invalid path to nnet `{onnx_path}`.') 30 | else: 31 | self.onnx_path = onnx_path 32 | self.sess = onnxruntime.InferenceSession(onnx_path) 33 | self.input_name = self.sess.get_inputs()[0].name 34 | 35 | def features2embeddings(self, data_dict): 36 | """ Extract x-vector embeddings from feature vectors. 37 | 38 | Args: 39 | data_dict (Dict): 40 | 41 | Returns: 42 | 43 | """ 44 | logger.info(f'Extracting x-vectors from {len(data_dict)} segments.') 45 | xvec_dict = {} 46 | for name in data_dict: 47 | signal_len, num_coefs = data_dict[name].shape 48 | # here we need to avoid failing on very short inputs, so we will just concatenate frames in time 49 | if signal_len == 0: 50 | continue 51 | elif signal_len < MIN_SIGNAL_LEN: 52 | for i in range(MIN_SIGNAL_LEN // signal_len): 53 | data_dict[name] = np.concatenate((data_dict[name], data_dict[name]), axis=0) 54 | xvec = self.sess.run(None, {self.input_name: data_dict[name].T[np.newaxis, :, :]})[0] 55 | xvec_dict[name] = xvec.squeeze() 56 | return xvec_dict 57 | -------------------------------------------------------------------------------- /vbdiar/kaldi/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import os 9 | 10 | import numpy as np 11 | 12 | 13 | def read_txt_matrix(path): 14 | """ Read features file in text format. This code expects correct format of input file. 15 | 16 | Args: 17 | path (string_types): path to txt file 18 | 19 | Returns: 20 | Dict[np.array]: name to array mapping 21 | """ 22 | data_dict, name = {}, None 23 | with open(path) as f: 24 | for line in f: 25 | if '[' in line: 26 | name = line.split()[0] if len(line) > 3 else '' 27 | continue 28 | elif ']' in line: 29 | line = line.replace(' ]', '') 30 | assert name is not None, 'Incorrect format of input file `{}`.'.format(path) 31 | if name not in data_dict: 32 | data_dict[name] = [] 33 | data_dict[name].append(np.fromstring(line, sep=' ', dtype=np.float32)) 34 | for name in data_dict: 35 | data_dict[name] = np.array(data_dict[name]) 36 | return data_dict 37 | 38 | 39 | def write_txt_matrix(path, data_dict): 40 | """ Write features into file in text format. This code expects correct format of input dictionary. 41 | 42 | Args: 43 | path (string_types): path to txt file 44 | data_dict (Dict[np.array]): name to array mapping 45 | """ 46 | with open(path, 'w') as f: 47 | for name in sorted(data_dict.keys()): 48 | f.write('{} ['.format(name, os.linesep)) 49 | for row_idx in range(data_dict[name].shape[0]): 50 | f.write('{} '.format(os.linesep)) 51 | data_dict[name][row_idx].tofile(f, sep=' ', format='%.6f') 52 | f.write(' ]{}'.format(os.linesep)) 53 | 54 | 55 | def read_txt_vectors(path): 56 | """ Read vectors file in text format. This code expects correct format of input file. 57 | 58 | Args: 59 | path (string_types): path to txt file 60 | 61 | Returns: 62 | Dict[np.array]: name to array mapping 63 | """ 64 | data_dict = {} 65 | with open(path) as f: 66 | for line in f: 67 | splitted_line = line.split() 68 | name = splitted_line[0] 69 | end_idx = splitted_line.index(']') 70 | vector_data = np.array([float(single_float) for single_float in splitted_line[2:end_idx]]) 71 | data_dict[name] = vector_data 72 | return data_dict 73 | -------------------------------------------------------------------------------- /vbdiar/scoring/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | -------------------------------------------------------------------------------- /vbdiar/scoring/diarization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import os 9 | import re 10 | import pickle 11 | import logging 12 | from shutil import rmtree 13 | from subprocess import check_output 14 | from tempfile import mkdtemp, NamedTemporaryFile 15 | 16 | import numpy as np 17 | from spherecluster import SphericalKMeans 18 | from pyclustering.cluster.xmeans import xmeans 19 | from sklearn.cluster import KMeans as sklearnKMeans 20 | from sklearn.cluster import AgglomerativeClustering, DBSCAN, MeanShift 21 | from sklearn.metrics.pairwise import cosine_similarity, pairwise_distances 22 | 23 | from vbdiar.clustering.pldakmeans import PLDAKMeans 24 | from vbdiar.scoring.normalization import Normalization 25 | from vbdiar.utils import mkdir_p 26 | from vbdiar.utils.utils import Utils 27 | 28 | 29 | CDIR = os.path.dirname(os.path.realpath(__file__)) 30 | MD_EVAL_SCRIPT_PATH = os.path.join(CDIR, 'md-eval.pl') 31 | MAX_SRE_CLUSTERS = 5 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | def evaluate2rttms(reference_path, hypothesis_path, collar_size=0.25, evaluate_overlaps=False): 37 | """ Evaluate two rttms. 38 | 39 | Args: 40 | reference_path (string_types): 41 | hypothesis_path (string_types): 42 | collar_size (float): 43 | evaluate_overlaps (bool): 44 | 45 | Returns: 46 | float: diarization error rate 47 | """ 48 | args = [MD_EVAL_SCRIPT_PATH, '{}'.format('' if evaluate_overlaps else '-1'), 49 | '-c', str(collar_size), '-r', reference_path, '-s', hypothesis_path] 50 | stdout = check_output(args) 51 | 52 | for line in stdout.decode('utf-8').split(os.linesep): 53 | if ' OVERALL SPEAKER DIARIZATION ERROR = ' in line: 54 | return float(line.replace( 55 | ' OVERALL SPEAKER DIARIZATION ERROR = ', '').replace( 56 | ' percent of scored speaker time `(ALL)', '')) 57 | raise ValueError(f'Command `{" ".join(args)}` failed.') 58 | 59 | 60 | def evaluate_all(reference_dir, hypothesis_dir, names, collar_size=0.25, evaluate_overlaps=False, rttm_ext='.rttm'): 61 | """ Evaluate all rttms in directories specified by list of names. 62 | 63 | Args: 64 | reference_dir (string_types): directory containing reference rttm files 65 | hypothesis_dir (string_types): directory containing hypothesis rttm files 66 | names (List[string_types]): list containing relative names 67 | collar_size (float): 68 | evaluate_overlaps (bool): 69 | rttm_ext (string_types): extension of rttm files 70 | 71 | Returns: 72 | float: diarization error rate 73 | """ 74 | with NamedTemporaryFile(mode='w') as ref, NamedTemporaryFile(mode='w') as hyp: 75 | for name in names: 76 | with open(f'{os.path.join(reference_dir, name)}{rttm_ext}') as f: 77 | for line in f: 78 | ref.write(line) 79 | ref.write(os.linesep) 80 | with open(f'{os.path.join(hypothesis_dir, name)}{rttm_ext}') as f: 81 | for line in f: 82 | hyp.write(line) 83 | hyp.write(os.linesep) 84 | ref.flush() 85 | hyp.flush() 86 | 87 | return evaluate2rttms(ref.name, hyp.name, collar_size=collar_size, evaluate_overlaps=evaluate_overlaps) 88 | 89 | 90 | class Diarization(object): 91 | """ Diarization class used as main diarization focused implementation. 92 | 93 | """ 94 | 95 | def __init__(self, input_list, embeddings, embeddings_mean=None, lda=None, use_l2_norm=True, norm=None, plda=None): 96 | """ Initialize diarization class. 97 | 98 | Args: 99 | input_list (string_types): path to list of input files 100 | embeddings (string_types|List[EmbeddingSet]): path to directory containing embeddings or list 101 | of EmbeddingSet instances 102 | embeddings_mean (np.ndarray): 103 | lda (np.ndarray): linear discriminant analysis - dimensionality reduction 104 | use_l2_norm (bool): do l2 normalization 105 | norm (Normalization): instance of class Normalization 106 | plda (GPLDA): instance of class GPLDA 107 | """ 108 | self.input_list = input_list 109 | if isinstance(embeddings, str): 110 | self.embeddings_dir = embeddings 111 | self.embeddings = list(self.load_embeddings()) 112 | else: 113 | self.embeddings = embeddings 114 | self.lda = lda 115 | self.use_l2_norm = use_l2_norm 116 | self.norm = norm 117 | self.plda = plda 118 | 119 | for embedding_set in self.embeddings: 120 | for embedding in embedding_set: 121 | if embeddings_mean is not None: 122 | embedding.data = embedding.data - embeddings_mean 123 | if lda is not None: 124 | embedding.data = embedding.data.dot(lda) 125 | if use_l2_norm: 126 | embedding.data = Utils.l2_norm(embedding.data[np.newaxis, :]).flatten() 127 | if self.norm: 128 | assert embeddings_mean is not None, 'Expecting usage of mean from normalization set.' 129 | self.norm.embeddings = self.norm.embeddings - embeddings_mean 130 | if lda is not None: 131 | self.norm.embeddings = self.norm.embeddings.dot(lda) 132 | if use_l2_norm: 133 | self.norm.embeddings = Utils.l2_norm(self.norm.embeddings) 134 | 135 | def get_embedding(self, name): 136 | """ Get embedding set by name. 137 | 138 | Args: 139 | name (string_types): 140 | 141 | Returns: 142 | EmbeddingSet: 143 | """ 144 | for ii in self.embeddings: 145 | if name == ii.name: 146 | return ii 147 | raise ValueError(f'Name of the set not found - `{name}`.') 148 | 149 | def load_embeddings(self): 150 | """ Load embedding from pickled files. 151 | 152 | Returns: 153 | List[EmbeddingSet]: 154 | """ 155 | logger.info(f'Loading pickled evaluation embedding from `{self.embeddings_dir}`.') 156 | with open(self.input_list, 'r') as f: 157 | for line in f: 158 | if len(line) > 0: 159 | logger.info(f'Loading evaluation pickle file `{line.rstrip().split()[0]}`.') 160 | line = line.rstrip() 161 | try: 162 | if len(line.split()) == 1: 163 | with open(os.path.join(self.embeddings_dir, line + '.pkl'), 'rb') as i: 164 | yield pickle.load(i) 165 | elif len(line.split()) == 2: 166 | file_name = line.split()[0] 167 | num_spks = int(line.split()[1]) 168 | with open(os.path.join(self.embeddings_dir, file_name + '.pkl'), 'rb') as i: 169 | ivec_set = pickle.load(i) 170 | ivec_set.num_speakers = num_spks 171 | yield ivec_set 172 | else: 173 | raise ValueError(f'Unexpected number of columns in input list `{self.input_list}`.') 174 | except IOError: 175 | logger.warning(f'No pickle file found for `{line.rstrip().split()[0]}`' 176 | f' in `{self.embeddings_dir}`.') 177 | 178 | def score_embeddings(self, min_length, max_num_speakers, mode): 179 | """ Score embeddings. 180 | 181 | Args: 182 | min_length (int): minimal length of segment used for clustering in miliseconds 183 | max_num_speakers (int): maximal number of speakers 184 | mode (str): running mode, see examples/diarization.py for details 185 | 186 | Returns: 187 | dict: dictionary with scores for each file 188 | """ 189 | result_dict = {} 190 | logger.info('Scoring using `{}`.'.format('PLDA' if self.plda is not None else 'cosine distance')) 191 | for embedding_set in self.embeddings: 192 | name = os.path.normpath(embedding_set.name) 193 | embeddings_all = embedding_set.get_all_embeddings() 194 | embeddings_long = embedding_set.get_longer_embeddings(min_length) 195 | if len(embeddings_long) == 0: 196 | logger.warning( 197 | f'No embeddings found longer than {min_length} for embedding set `{name}`.') 198 | continue 199 | size = len(embedding_set) 200 | if size > 0: 201 | logger.info(f'Clustering `{name}` using {len(embeddings_long)} long embeddings.') 202 | if mode == 'diarization': 203 | if embedding_set.num_speakers is not None: 204 | num_speakers = embedding_set.num_speakers 205 | else: 206 | xm = xmeans(embeddings_long, kmax=max_num_speakers) 207 | xm.process() 208 | num_speakers = len(xm.get_clusters()) 209 | 210 | centroids = self.run_clustering(num_speakers, embeddings_long) 211 | if self.norm is None: 212 | if self.plda is None: 213 | result_dict[name] = cosine_similarity(embeddings_all, centroids).T 214 | else: 215 | result_dict[name] = self.plda.score(embeddings_all, centroids) 216 | else: 217 | result_dict[name] = self.norm.s_norm(embeddings_all, centroids) 218 | else: 219 | clusters = [] 220 | for k in range(1, MAX_SRE_CLUSTERS + 1): 221 | if size >= k: 222 | centroids = self.run_clustering(k, embeddings_long) 223 | clusters.extend(x for x in centroids) 224 | result_dict[name] = np.array(clusters) 225 | else: 226 | logger.warning(f'No embeddings to score in `{embedding_set.name}`.') 227 | return result_dict 228 | 229 | def run_ahc(self, n_clusters, embeddings, scores_matrix): 230 | """ Run agglomerative hierarchical clustering. 231 | 232 | Returns: 233 | np.array: means of clusters 234 | """ 235 | scores_matrix = -((scores_matrix - np.min(scores_matrix)) / (np.max(scores_matrix) - np.min(scores_matrix))) 236 | ahc = AgglomerativeClustering(affinity='precomputed', linkage='complete', n_clusters=n_clusters) 237 | labels = ahc.fit_predict(scores_matrix) 238 | return np.array([np.mean(embeddings[np.where(labels == i)], axis=0) for i in range(n_clusters)]) 239 | 240 | def run_clustering(self, num_speakers, embeddings): 241 | if self.use_l2_norm: 242 | kmeans_clustering = SphericalKMeans( 243 | n_clusters=num_speakers, n_init=100, n_jobs=1).fit(embeddings) 244 | else: 245 | kmeans_clustering = sklearnKMeans( 246 | n_clusters=num_speakers, n_init=100, n_jobs=1).fit(embeddings) 247 | centroids = kmeans_clustering.cluster_centers_ 248 | if self.plda: 249 | centroids = PLDAKMeans(centroids=kmeans_clustering.cluster_centers_, k=num_speakers, 250 | plda=self.plda, max_iter=100).fit(embeddings) 251 | return centroids 252 | 253 | def dump_rttm(self, scores, out_dir): 254 | """ Dump rttm files to output directory. This function requires initialized embeddings. 255 | 256 | Args: 257 | scores (Dict): dictionary containing scores 258 | out_dir (string_types): path to output directory 259 | """ 260 | for embedding_set in self.embeddings: 261 | if len(embedding_set) > 0: 262 | name = embedding_set.name 263 | reg_name = re.sub('/.*', '', embedding_set.name) 264 | mkdir_p(os.path.join(out_dir, os.path.dirname(name))) 265 | with open(os.path.join(out_dir, name + '.rttm'), 'w') as f: 266 | for i, embedding in enumerate(embedding_set.embeddings): 267 | start, end = embedding.window_start, embedding.window_end 268 | idx = np.argmax(scores[name][i]) 269 | f.write(f'SPEAKER {reg_name} 1 {float(start / 1000.0)} {float((end - start) / 1000.0)} ' 270 | f' {reg_name}_spkr_{idx} \n') 271 | else: 272 | logger.warning(f'No embedding to dump in {embedding_set.name}.') 273 | 274 | def evaluate(self, scores, in_rttm_dir, collar_size=0.25, evaluate_overlaps=False, rttm_ext='.rttm'): 275 | """ At first, separately evaluate each file based on ground truth segmentation. Then evaluate all files. 276 | 277 | Args: 278 | scores (dict): dictionary containing scores 279 | in_rttm_dir (string_types): input directory with rttm files 280 | collar_size (float): collar size for scoring 281 | evaluate_overlaps (bool): evaluate or ignore overlapping speech segments 282 | rttm_ext (string_types): extension for rttm files 283 | """ 284 | tmp_dir = mkdtemp(prefix='rttm_') 285 | self.dump_rttm(scores, tmp_dir) 286 | for embedding_set in self.embeddings: 287 | name = embedding_set.name 288 | ground_truth_rttm = os.path.join(in_rttm_dir, '{}{}'.format(name, rttm_ext)) 289 | if not os.path.exists(ground_truth_rttm): 290 | logger.warning(f'Ground truth rttm file not found in `{ground_truth_rttm}`.') 291 | continue 292 | # evaluate single rttm 293 | der = evaluate2rttms(ground_truth_rttm, os.path.join(tmp_dir, '{}{}'.format(name, rttm_ext)), 294 | collar_size=collar_size, evaluate_overlaps=evaluate_overlaps) 295 | logger.info(f'`{name}` DER={der}') 296 | 297 | # evaluate all rttms 298 | der = evaluate_all(reference_dir=in_rttm_dir, hypothesis_dir=tmp_dir, names=scores.keys(), 299 | collar_size=collar_size, evaluate_overlaps=evaluate_overlaps, rttm_ext=rttm_ext) 300 | logger.info(f'`Total` DER={der}') 301 | rmtree(tmp_dir) 302 | -------------------------------------------------------------------------------- /vbdiar/scoring/gplda.py: -------------------------------------------------------------------------------- 1 | import os 2 | import operator 3 | from functools import reduce 4 | 5 | import numpy as np 6 | from numpy.linalg import inv, slogdet 7 | 8 | 9 | class GPLDA(object): 10 | """ Gaussian PLDA model. 11 | 12 | 13 | """ 14 | cw = None 15 | cb = None 16 | mean = None 17 | p = None 18 | q = None 19 | k = None 20 | r = None 21 | s = None 22 | t = None 23 | u = None 24 | ct = None 25 | count = None 26 | initialized = False 27 | 28 | def __init__(self, path): 29 | """ Init all parameters needed for Gaussian PLDA Model score computation. 30 | 31 | Args: 32 | path(str): path with Gaussian PLDA model 33 | """ 34 | self.cw = np.load(os.path.join(path, 'CW.npy')) 35 | self.cb = np.load(os.path.join(path, 'CB.npy')) 36 | self.mean = np.load(os.path.join(path, 'mu.npy')) 37 | 38 | self.initialize() 39 | 40 | def initialize(self): 41 | """ Initialize members for faster scoring. """ 42 | self.ct = self.cw + self.cb 43 | self.p = inv(self.ct * 0.5) - inv(0.5 * self.cw + self.cb) 44 | self.q = inv(2 * self.cw) - inv(2 * self.ct) 45 | k1 = reduce(operator.mul, slogdet(0.5 * self.ct)) 46 | k2 = reduce(operator.mul, slogdet(0.5 * self.cw + self.cb)) 47 | k3 = reduce(operator.mul, slogdet(2 * self.ct)) 48 | k4 = reduce(operator.mul, slogdet(2 * self.cw)) 49 | self.k = 0.5 * (k1 - k2 + k3 - k4) 50 | self.r = 0.5 * (0.25 * self.p - self.q) 51 | self.s = 0.5 * (0.25 * self.p + self.q) 52 | self.t = 0.25 * np.dot(self.p, self.mean.T) 53 | u1 = 2 * np.dot(self.mean, 0.25 * self.p) 54 | self.u = self.k + np.dot(u1, self.mean.T) 55 | self.initialized = True 56 | 57 | def score(self, np_vec_1, np_vec_2): 58 | """ Compare two vectors using plda scoring metric. Function is symmetric. 59 | 60 | Args: 61 | np_vec_1 (np.array): array of vectors (e.g. nx250), depends on model 62 | np_vec_2 (np.array): array of vectors (e.g. nx250), depends on model 63 | 64 | Returns: 65 | 2-dimensional np.array: scores matrix 66 | """ 67 | if not self.initialized: 68 | raise ValueError('Model is not trained nor initialized.') 69 | np_vec_1 = np_vec_1.T.copy() 70 | np_vec_2 = np_vec_2.T.copy() 71 | mat1 = np.dot(self.r, np_vec_1) * np_vec_1 72 | mat2 = np.dot(self.r, np_vec_2) * np_vec_2 73 | vct1 = np.sum(mat1, axis=0, keepdims=True) 74 | vct2 = np.sum(mat2, axis=0, keepdims=True) 75 | vct3 = 2 * np.dot(self.t.T, np_vec_1) 76 | vct4 = 2 * np.dot(self.t.T, np_vec_2) 77 | mat3 = np.dot(np_vec_1.T, self.s) 78 | scores_matrix = 2 * np.dot(mat3, np_vec_2) + vct1.T + vct2 - vct3.T - vct4 + self.u 79 | return scores_matrix 80 | -------------------------------------------------------------------------------- /vbdiar/scoring/normalization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import os 9 | import logging 10 | import pickle 11 | import multiprocessing 12 | 13 | import numpy as np 14 | from sklearn.metrics.pairwise import cosine_similarity 15 | 16 | from vbdiar.features.segments import get_frames_from_time 17 | from vbdiar.embeddings.embedding import extract_embeddings 18 | from vbdiar.utils import mkdir_p 19 | from vbdiar.utils.utils import Utils 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def process_files(fns, speakers_dict, features_extractor, embedding_extractor, 25 | audio_dir, wav_suffix, in_rttm_dir, rttm_suffix, min_length, n_jobs=1): 26 | """ 27 | 28 | Args: 29 | fns: 30 | speakers_dict: 31 | features_extractor: 32 | embedding_extractor: 33 | audio_dir: 34 | wav_suffix: 35 | in_rttm_dir: 36 | rttm_suffix: 37 | min_length: 38 | n_jobs: 39 | 40 | Returns: 41 | 42 | """ 43 | kwargs = dict(speakers_dict=speakers_dict, features_extractor=features_extractor, 44 | embedding_extractor=embedding_extractor, audio_dir=audio_dir, wav_suffix=wav_suffix, 45 | in_rttm_dir=in_rttm_dir, rttm_suffix=rttm_suffix, min_length=min_length) 46 | if n_jobs == 1: 47 | ret = _process_files((fns, kwargs)) 48 | else: 49 | pool = multiprocessing.Pool(n_jobs) 50 | ret = pool.map(_process_files, ((part, kwargs) for part in Utils.partition(fns, n_jobs))) 51 | return ret 52 | 53 | 54 | def _process_files(dargs): 55 | """ 56 | 57 | Args: 58 | dargs: 59 | 60 | Returns: 61 | 62 | """ 63 | fns, kwargs = dargs 64 | ret = [] 65 | for fn in fns: 66 | ret.append(process_file(file_name=fn, **kwargs)) 67 | return ret 68 | 69 | 70 | def process_file(file_name, speakers_dict, features_extractor, embedding_extractor, 71 | audio_dir, wav_suffix, in_rttm_dir, rttm_suffix, min_length): 72 | """ Extract embeddings for all defined speakers. 73 | 74 | Args: 75 | file_name (string_types): path to input audio file 76 | speakers_dict (dict): dictionary containing all embedding across speakers 77 | features_extractor (Any): 78 | embedding_extractor (Any): 79 | audio_dir (string_types): 80 | wav_suffix (string_types): 81 | in_rttm_dir (string_types): 82 | rttm_suffix (string_types): 83 | min_length (float): 84 | 85 | Returns: 86 | dict: updated dictionary with speakers 87 | """ 88 | logger.info('Processing file `{}`.'.format(file_name.split()[0])) 89 | # extract features from whole audio 90 | features = features_extractor.audio2features(os.path.join(audio_dir, '{}{}'.format(file_name, wav_suffix))) 91 | 92 | # process utterances of the speakers 93 | features_dict = {} 94 | with open(f'{os.path.join(in_rttm_dir, file_name)}{rttm_suffix}') as f: 95 | for line in f: 96 | start_time, dur = int(float(line.split()[3]) * 1000), int(float(line.split()[4]) * 1000) 97 | speaker = line.split()[7] 98 | if dur > min_length: 99 | end_time = start_time + dur 100 | start, end = get_frames_from_time(int(start_time)), get_frames_from_time(int(end_time)) 101 | if speaker not in features_dict: 102 | features_dict[speaker] = {} 103 | 104 | assert 0 <= start < end, \ 105 | f'Incorrect timing for extracting features, start: {start}, size: {features.shape[0]}, end: {end}.' 106 | if end >= features.shape[0]: 107 | end = features.shape[0] - 1 108 | features_dict[speaker][(start_time, end_time)] = features[start:end] 109 | for speaker in features_dict: 110 | embedding_set = extract_embeddings(features_dict[speaker], embedding_extractor) 111 | embeddings_long = embedding_set.get_all_embeddings() 112 | if speaker not in speakers_dict.keys(): 113 | speakers_dict[speaker] = embeddings_long 114 | else: 115 | speakers_dict[speaker] = np.concatenate((speakers_dict[speaker], embeddings_long), axis=0) 116 | return speakers_dict 117 | 118 | 119 | class Normalization(object): 120 | """ Speaker normalization S-Norm. """ 121 | embeddings = None 122 | in_emb_dir = None 123 | 124 | def __init__(self, norm_list, audio_dir=None, in_rttm_dir=None, in_emb_dir=None, 125 | out_emb_dir=None, min_length=None, features_extractor=None, embedding_extractor=None, 126 | plda=None, wav_suffix='.wav', rttm_suffix='.rttm', n_jobs=1): 127 | """ Initialize normalization object. 128 | 129 | Args: 130 | norm_list (string_types): path to normalization list 131 | audio_dir (string_types|None): path to audio directory 132 | in_rttm_dir (string_types|None): path to directory with rttm files 133 | in_emb_dir (str|None): path to directory with i-vectors 134 | out_emb_dir (str|None): path to directory for storing embeddings 135 | min_length (int): minimal length for extracting embeddings 136 | features_extractor (Any): object for feature extraction 137 | embedding_extractor (Any): object for extracting embedding 138 | plda (PLDA|None): plda model object 139 | wav_suffix (string_types): suffix of wav files 140 | rttm_suffix (string_types): suffix of rttm files 141 | """ 142 | if audio_dir: 143 | self.audio_dir = os.path.abspath(audio_dir) 144 | self.norm_list = norm_list 145 | if in_rttm_dir: 146 | self.in_rttm_dir = os.path.abspath(in_rttm_dir) 147 | else: 148 | raise ValueError('It is required to have input rttm files for normalization.') 149 | self.features_extractor = features_extractor 150 | self.embedding_extractor = embedding_extractor 151 | self.plda = plda 152 | self.wav_suffix = wav_suffix 153 | self.rttm_suffix = rttm_suffix 154 | if in_emb_dir: 155 | self.in_emb_dir = os.path.abspath(in_emb_dir) 156 | if out_emb_dir: 157 | self.out_emb_dir = os.path.abspath(out_emb_dir) 158 | self.min_length = min_length 159 | self.n_jobs = n_jobs 160 | if self.in_emb_dir is None: 161 | self.embeddings = self.extract_embeddings() 162 | else: 163 | self.embeddings = self.load_embeddings() 164 | self.mean = np.mean(self.embeddings, axis=0) 165 | 166 | def __iter__(self): 167 | current = 0 168 | while current < len(self.embeddings): 169 | yield self.embeddings[current] 170 | current += 1 171 | 172 | def __getitem__(self, key): 173 | return self.embeddings[key] 174 | 175 | def __setitem__(self, key, value): 176 | self.embeddings[key] = value 177 | 178 | def __len__(self): 179 | return len(self.embeddings) 180 | 181 | def extract_embeddings(self): 182 | """ Extract normalization embeddings using averaging. 183 | 184 | Returns: 185 | Tuple[np.array, np.array]: vectors for individual speakers, global mean over all speakers 186 | """ 187 | speakers_dict, fns = {}, [] 188 | with open(self.norm_list) as f: 189 | for line in f: 190 | if len(line.split()) > 1: # number of speakers is defined 191 | line = line.split()[0] 192 | else: 193 | line = line.replace(os.linesep, '') 194 | fns.append(line) 195 | 196 | speakers_dict = process_files(fns, speakers_dict=speakers_dict, features_extractor=self.features_extractor, 197 | embedding_extractor=self.embedding_extractor, audio_dir=self.audio_dir, 198 | wav_suffix=self.wav_suffix, in_rttm_dir=self.in_rttm_dir, 199 | rttm_suffix=self.rttm_suffix, min_length=self.min_length, n_jobs=self.n_jobs) 200 | assert len(speakers_dict) == len(fns) 201 | # all are the same 202 | merged_speakers_dict = speakers_dict[0] 203 | 204 | if self.out_emb_dir: 205 | for speaker in merged_speakers_dict: 206 | out_path = os.path.join(self.out_emb_dir, f'{speaker}.pkl') 207 | mkdir_p(os.path.dirname(out_path)) 208 | with open(out_path, 'wb') as f: 209 | pickle.dump(merged_speakers_dict[speaker], f, pickle.HIGHEST_PROTOCOL) 210 | 211 | for speaker in merged_speakers_dict: 212 | merged_speakers_dict[speaker] = np.mean(merged_speakers_dict[speaker], axis=0) 213 | 214 | return np.array(list(merged_speakers_dict.values())) 215 | 216 | def load_embeddings(self): 217 | """ Load normalization embeddings from pickle files. 218 | 219 | Returns: 220 | np.array: embeddings per speaker 221 | """ 222 | embeddings, speakers = [], set() 223 | with open(self.norm_list) as f: 224 | for file_name in f: 225 | if len(file_name.split()) > 1: # number of speakers is defined 226 | file_name = file_name.split()[0] 227 | else: 228 | file_name = file_name.replace(os.linesep, '') 229 | with open('{}{}'.format(os.path.join(self.in_rttm_dir, file_name), self.rttm_suffix)) as fp: 230 | for line in fp: 231 | speakers.add(line.split()[7]) 232 | 233 | logger.info('Loading pickled normalization embeddings from `{}`.'.format(self.in_emb_dir)) 234 | for speaker in speakers: 235 | embedding_path = os.path.join(self.in_emb_dir, '{}.pkl'.format(speaker)) 236 | if os.path.isfile(embedding_path): 237 | logger.info('Loading normalization pickle file `{}`.'.format(speaker)) 238 | with open(embedding_path, 'rb') as f: 239 | # append mean from speaker's embeddings 240 | speaker_embeddings = pickle.load(f) 241 | embeddings.append(np.mean(speaker_embeddings, axis=0)) 242 | else: 243 | logger.warning('No pickle file found for `{}` in `{}`.'.format(speaker, self.in_emb_dir)) 244 | return np.array(embeddings) 245 | 246 | def s_norm(self, test, enroll): 247 | """ Run speaker normalization (S-Norm) on cached embeddings. 248 | 249 | Args: 250 | test (np.array): test embedding 251 | enroll (np.array): enroll embedding 252 | 253 | Returns: 254 | float: hypothesis 255 | """ 256 | if self.plda: 257 | a = self.plda.score(test, self.embeddings).T 258 | b = self.plda.score(enroll, self.embeddings).T 259 | c = self.plda.score(enroll, test).T 260 | else: 261 | a = cosine_similarity(test, self.embeddings).T 262 | b = cosine_similarity(enroll, self.embeddings).T 263 | c = cosine_similarity(enroll, test).T 264 | scores = [] 265 | for ii in range(test.shape[0]): 266 | test_scores = [] 267 | for jj in range(enroll.shape[0]): 268 | test_mean, test_std = np.mean(a.T[ii]), np.std(a.T[ii]) 269 | enroll_mean, enroll_std = np.mean(b.T[jj]), np.std(b.T[jj]) 270 | s = c[ii][jj] 271 | test_scores.append((((s - test_mean) / test_std + (s - enroll_mean) / enroll_std) / 2)) 272 | scores.append(test_scores) 273 | return np.array(scores) 274 | -------------------------------------------------------------------------------- /vbdiar/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import os 9 | import errno 10 | 11 | 12 | def mkdir_p(path): 13 | """ Behaviour similar to mkdir -p in shell. 14 | 15 | Args: 16 | path (string_types): path to create 17 | """ 18 | try: 19 | os.makedirs(path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == errno.EEXIST and os.path.isdir(path): 22 | pass 23 | else: 24 | raise ValueError('Can not create directory {}.'.format(path)) 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /vbdiar/utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import os 9 | import re 10 | import random 11 | from os import listdir 12 | from os.path import isfile, join 13 | import fnmatch 14 | import math 15 | 16 | import numpy as np 17 | import yaml 18 | 19 | 20 | class Utils(object): 21 | """ Class tools handles basic operations with files and directories. 22 | 23 | """ 24 | 25 | def __init__(self): 26 | """ tools class constructor. 27 | 28 | """ 29 | return 30 | 31 | @staticmethod 32 | def list_directory_by_suffix(directory, suffix): 33 | """ Return listed directory of files based on their suffix. 34 | 35 | :param directory: directory to be listed 36 | :type directory: str 37 | :param suffix: suffix of files in directory 38 | :type suffix: str 39 | :returns: list of files specified byt suffix in directory 40 | :rtype: list 41 | 42 | >>> Utils.list_directory_by_suffix('../../tests/tools', '.test') 43 | ['empty1.test', 'empty2.test'] 44 | >>> Utils.list_directory_by_suffix('../../tests/tools_no_ex', '.test') 45 | Traceback (most recent call last): 46 | ... 47 | toolsException: [listDirectoryBySuffix] No directory found! 48 | >>> Utils.list_directory_by_suffix('../../tests/tools', '.py') 49 | [] 50 | """ 51 | abs_dir = os.path.abspath(directory) 52 | try: 53 | ofiles = [f for f in listdir(abs_dir) if isfile(join(abs_dir, f))] 54 | except OSError: 55 | raise ValueError('No directory named {} found!'.format(directory)) 56 | out = [] 57 | for file_in in ofiles: 58 | if file_in.find(suffix) != -1: 59 | out.append(file_in) 60 | out.sort() 61 | return out 62 | 63 | @staticmethod 64 | def list_directory(directory): 65 | """ List directory. 66 | 67 | :param directory: directory to be listed 68 | :type directory: str 69 | :returns: list with files in directory 70 | :rtype: list 71 | 72 | >>> Utils.list_directory('../../tests/tools') 73 | ['empty1.test', 'empty2.test', 'test', 'test.txt'] 74 | >>> Utils.list_directory('../../tests/tools_no_ex') 75 | Traceback (most recent call last): 76 | ... 77 | toolsException: [listDirectory] No directory found! 78 | """ 79 | directory = os.path.abspath(directory) 80 | try: 81 | out = [f for f in listdir(directory)] 82 | except OSError: 83 | raise ValueError('No directory found!') 84 | out.sort() 85 | return out 86 | 87 | @staticmethod 88 | def recursively_list_directory_by_suffix(directory, suffix): 89 | """ Return recursively listed directory of files based on their suffix. 90 | 91 | :param directory: directory to be listed 92 | :type directory: str 93 | :param suffix: suffix of files in directory 94 | :type suffix: str 95 | :returns: list of files specified by suffix in directory 96 | :rtype: list 97 | 98 | >>> Utils.recursively_list_directory_by_suffix( \ 99 | '../../tests/tools', '.test') 100 | ['empty1.test', 'empty2.test', 'test/empty.test'] 101 | >>> Utils.recursively_list_directory_by_suffix( \ 102 | '../../tests/tools_no_ex', '.test') 103 | [] 104 | """ 105 | matches = [] 106 | for root, dirnames, filenames in os.walk(directory): 107 | for filename in fnmatch.filter(filenames, '*' + suffix): 108 | app = os.path.join(root, filename).replace(directory + '/', '') 109 | matches.append(app) 110 | matches.sort() 111 | return matches 112 | 113 | @staticmethod 114 | def sed_in_file(input_file, regex1, regex2): 115 | """ Replace in input file by regex. 116 | 117 | :param input_file: input file 118 | :type input_file: str 119 | :param regex1: regular expression 1 120 | :type regex1: str 121 | :param regex2: regular expression 2 122 | :type regex2: str 123 | """ 124 | with open(input_file, 'r') as sources: 125 | lines = sources.readlines() 126 | with open(input_file, 'w') as sources: 127 | for line in lines: 128 | sources.write(re.sub(regex1, regex2, line)) 129 | 130 | @staticmethod 131 | def remove_lines_in_file_by_indexes(input_file, lines_indexes): 132 | """ Remove specified lines in file. 133 | 134 | :param input_file: input file name 135 | :type input_file: str 136 | :param lines_indexes: list with lines 137 | :type lines_indexes: list 138 | """ 139 | with open(input_file, 'r') as sources: 140 | lines = sources.readlines() 141 | with open(input_file, 'w') as sources: 142 | for i in range(len(lines)): 143 | if i not in lines_indexes: 144 | sources.write(lines[i]) 145 | 146 | @staticmethod 147 | def get_method(instance, method): 148 | """ Get method pointer. 149 | 150 | :param instance: input object 151 | :type instance: object 152 | :param method: name of method 153 | :type method: str 154 | :returns: pointer to method 155 | :rtype: method 156 | """ 157 | try: 158 | attr = getattr(instance, method) 159 | except AttributeError: 160 | raise ValueError('Unknown class method!') 161 | return attr 162 | 163 | @staticmethod 164 | def configure_instance(instance, input_list): 165 | """ Configures instance base on methods list. 166 | 167 | :param instance: reference to class instance 168 | :type instance: object 169 | :param input_list: input list with name of class members 170 | :type input_list: list 171 | :returns: configured instance 172 | :rtype: object 173 | """ 174 | for line in input_list: 175 | variable = line[:line.rfind('=')] 176 | value = line[line.rfind('=') + 1:] 177 | method_callback = Utils.get_method(instance, 'Set' + variable) 178 | method_callback(value) 179 | return instance 180 | 181 | @staticmethod 182 | def sort(scores, col=None): 183 | """ Sort scores list where score is in n-th-1 column. 184 | 185 | :param scores: scores list to be sorted 186 | :type scores: list 187 | :param col: index of column 188 | :type col: int 189 | :returns: sorted scores list 190 | :rtype: list 191 | 192 | >>> Utils.sort([['f1', 'f2', 10.0], \ 193 | ['f3', 'f4', -10.0], \ 194 | ['f5', 'f6', 9.58]], col=2) 195 | [['f3', 'f4', -10.0], ['f5', 'f6', 9.58], ['f1', 'f2', 10.0]] 196 | >>> Utils.sort([4.59, 8.8, 6.9, -10001.478]) 197 | [-10001.478, 4.59, 6.9, 8.8] 198 | """ 199 | if col is None: 200 | return sorted(scores, key=float) 201 | else: 202 | return sorted(scores, key=lambda x: x[col]) 203 | 204 | @staticmethod 205 | def reverse_sort(scores, col=None): 206 | """ Reversively sort scores list where score is in n-th column. 207 | 208 | :param scores: scores list to be sorted 209 | :type scores: list 210 | :param col: number of columns 211 | :type col: int 212 | :returns: reversively sorted scores list 213 | :rtype: list 214 | 215 | >>> Utils.reverse_sort([['f1', 'f2', 10.0], \ 216 | ['f3', 'f4', -10.0], \ 217 | ['f5', 'f6', 9.58]], col=2) 218 | [['f1', 'f2', 10.0], ['f5', 'f6', 9.58], ['f3', 'f4', -10.0]] 219 | >>> Utils.reverse_sort([4.59, 8.8, 6.9, -10001.478]) 220 | [8.8, 6.9, 4.59, -10001.478] 221 | """ 222 | if col is None: 223 | return sorted(scores, key=float, reverse=True) 224 | else: 225 | return sorted(scores, key=lambda x: x[col], reverse=True) 226 | 227 | @staticmethod 228 | def get_nth_col(in_list, col): 229 | """ Extract n-th-1 columns from list. 230 | 231 | :param in_list: input list 232 | :type in_list: list 233 | :param col: column 234 | :type col: int 235 | :returns: list only with one column 236 | :rtype: list 237 | 238 | >>> Utils.get_nth_col([['1', '2'], ['3', '4'], ['5', '6']], col=1) 239 | ['2', '4', '6'] 240 | >>> Utils.get_nth_col([['1', '2'], ['3', '4'], ['5', '6']], col=42) 241 | Traceback (most recent call last): 242 | ... 243 | toolsException: [getNthCol] Column out of range! 244 | """ 245 | try: 246 | out = [row[col] for row in in_list] 247 | except IndexError: 248 | raise ValueError('Column out of range!') 249 | return out 250 | 251 | @staticmethod 252 | def find_in_dictionary(in_dict, value): 253 | """ Find value in directory whose items are lists and return key. 254 | 255 | :param in_dict: dictionary to search in 256 | :type in_dict: dict 257 | :param value: value to find 258 | :type value: any 259 | :returns: dictionary key 260 | :rtype: any 261 | 262 | >>> Utils.find_in_dictionary({ 0 : [42], 1 : [88], 2 : [69]}, 69) 263 | 2 264 | >>> Utils.find_in_dictionary(dict(), 69) 265 | Traceback (most recent call last): 266 | ... 267 | toolsException: [findInDictionary] Value not found! 268 | """ 269 | for key in in_dict: 270 | if value in in_dict[key]: 271 | return key 272 | raise ValueError('Value not found!') 273 | 274 | @staticmethod 275 | def get_scores(scores, key): 276 | """ Get scores from scores list by key. 277 | 278 | :param scores: input scores list 279 | :type scores: list 280 | :param key: key to find 281 | :type key: list 282 | :returns: score if key is present in score, None otherwise 283 | :rtype: float 284 | 285 | >>> Utils.get_scores([['f1', 'f2', 10.1], ['f3', 'f4', 20.1], \ 286 | ['f5', 'f6', 30.1]], ['f6', 'f5']) 287 | 30.1 288 | """ 289 | if len(key) != 2: 290 | raise ValueError('Unexpected key!') 291 | if len(scores[0]) != 3: 292 | raise ValueError('Invalid input list!') 293 | for score in scores: 294 | a = score[0] 295 | b = score[1] 296 | if (key[0] == a and key[1] == b) or (key[0] == b and key[1] == a): 297 | return score[2] 298 | return None 299 | 300 | @staticmethod 301 | def get_line_from_file(line_num, infile): 302 | """ Get specified line from file. 303 | 304 | :param line_num: number of line 305 | :type line_num: int 306 | :param infile: file name 307 | :type infile: str 308 | :returns: specified line, None otherwise 309 | :rtype: str 310 | 311 | >>> Utils.get_line_from_file(3, '../../tests/tools/test.txt') 312 | 'c\\n' 313 | >>> Utils.get_line_from_file(10, '../../tests/tools/test.txt') 314 | Traceback (most recent call last): 315 | ... 316 | toolsException: [getLineFromFile] Line number not found! 317 | """ 318 | with open(infile) as fp: 319 | for i, line in enumerate(fp): 320 | if i == line_num - 1: 321 | return line 322 | raise ValueError('Line number {} not found in file.'.format(line_num, infile)) 323 | 324 | @staticmethod 325 | def list2dict(input_list): 326 | """ Create dictionary from list in format [key1, key2, score]. 327 | 328 | :param input_list: list to process 329 | :type input_list: list 330 | :returns: preprocessed dictionary 331 | :rtype: dict 332 | 333 | >>> Utils.list2dict([['f1', 'f2', 10.1], ['f3', 'f4', 20.1], \ 334 | ['f5', 'f6', 30.1], ['f1', 'f3', 40.1]]) 335 | {'f1 f2': 10.1, 'f5 f6': 30.1, 'f3 f4': 20.1, 'f1 f3': 40.1} 336 | >>> Utils.list2dict([['f1', 'f2', 10.1], ['f3', 'f4']]) 337 | Traceback (most recent call last): 338 | ... 339 | toolsException: [list2Dict] Invalid format of input list! 340 | """ 341 | dictionary = dict() 342 | for item in input_list: 343 | if len(item) != 3: 344 | raise ValueError('Invalid format of input list!') 345 | tmp_list = [item[0], item[1]] 346 | tmp_list.sort() 347 | dictionary[tmp_list[0] + ' ' + tmp_list[1]] = item[2] 348 | return dictionary 349 | 350 | @staticmethod 351 | def merge_dicts(*dict_args): 352 | """ Merge dictionaries into single one. 353 | 354 | :param dict_args: input dictionaries 355 | :type dict_args: dict array 356 | :returns: merged dictionaries into single one 357 | :rtype: dict 358 | 359 | >>> Utils.merge_dicts( \ 360 | {'f1 f2': 10.1, 'f5 f6': 30.1, 'f1 f3': 40.1}, {'f6 f2': 50.1}) 361 | {'f1 f2': 10.1, 'f5 f6': 30.1, 'f6 f2': 50.1, 'f1 f3': 40.1} 362 | """ 363 | result = {} 364 | for dictionary in dict_args: 365 | result.update(dictionary) 366 | return result 367 | 368 | @staticmethod 369 | def save_object(obj, path): 370 | """ Saves object to disk. 371 | 372 | :param obj: reference to object 373 | :type obj: any 374 | :param path: path to file 375 | :type path: str 376 | """ 377 | np.save(path, obj) 378 | 379 | @staticmethod 380 | def load_object(path): 381 | """ Loads object from disk. 382 | 383 | :param path: path to file 384 | :type path: str 385 | """ 386 | np.load(path) 387 | 388 | @staticmethod 389 | def common_prefix(m): 390 | """ Given a list of pathnames, returns the longest prefix." 391 | 392 | :param m: input list 393 | :type m: list 394 | :returns: longest prefix in list 395 | :rtype: str 396 | """ 397 | if not m: 398 | return '' 399 | s1 = min(m) 400 | s2 = max(m) 401 | for i, c in enumerate(s1): 402 | if c != s2[i]: 403 | return s1[:i] 404 | return s1 405 | 406 | @staticmethod 407 | def root_name(d): 408 | """ Return a root directory by name. 409 | 410 | :param d: directory name 411 | :type d: str 412 | :returns: root directory name 413 | :rtype d: str 414 | """ 415 | pass 416 | 417 | @staticmethod 418 | def read_config(config_path): 419 | """ Read config in yaml format. 420 | 421 | Args: 422 | config_path (str): path to config file 423 | 424 | Returns: 425 | 426 | """ 427 | with open(config_path, 'r') as ymlfile: 428 | return yaml.load(ymlfile) 429 | 430 | @staticmethod 431 | def l2_norm(ivecs): 432 | """ Perform L2 normalization. 433 | 434 | Args: 435 | ivecs (np.array): input i-vector 436 | 437 | Returns: 438 | np.array: normalized i-vectors 439 | """ 440 | ret_ivecs = ivecs.copy() 441 | ret_ivecs /= np.sqrt((ret_ivecs ** 2).sum(axis=1)[:, np.newaxis]) 442 | return ret_ivecs 443 | 444 | @staticmethod 445 | def cos_sim(v1, v2): 446 | """ 447 | 448 | Args: 449 | v1 (np.array): first vector 450 | v2 (np.array): second vector 451 | 452 | Returns: 453 | 454 | """ 455 | sumxx, sumxy, sumyy = 0, 0, 0 456 | for i in range(len(v1)): 457 | x = v1[i] 458 | y = v2[i] 459 | sumxx += x * x 460 | sumyy += y * y 461 | sumxy += x * y 462 | return sumxy / math.sqrt(sumxx * sumyy) 463 | 464 | @staticmethod 465 | def partition(large_list, n_sublists, shuffle=False): 466 | """Partition a list ``l`` into ``n`` sublists.""" 467 | return np.array_split(large_list, n_sublists) 468 | 469 | 470 | if __name__ == "__main__": 471 | import doctest 472 | doctest.testmod() 473 | -------------------------------------------------------------------------------- /vbdiar/vad/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | from .vad import load_vad_lab_as_bool_vec, get_vad 9 | -------------------------------------------------------------------------------- /vbdiar/vad/vad.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Copyright (C) 2018 Brno University of Technology FIT 5 | # Author: Jan Profant 6 | # All Rights Reserved 7 | 8 | import logging 9 | 10 | import numpy as np 11 | 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def get_vad(file_name, fea_len): 17 | """ Load .lab file as bool vector. 18 | 19 | Args: 20 | file_name (str): path to .lab file 21 | fea_len (int): length of features 22 | 23 | Returns: 24 | np.array: bool vector 25 | """ 26 | 27 | logger.info('Loading VAD from file `{}`.'.format(file_name)) 28 | return load_vad_lab_as_bool_vec(file_name)[:fea_len] 29 | 30 | 31 | def load_vad_lab_as_bool_vec(lab_file): 32 | """ 33 | 34 | Args: 35 | lab_file: 36 | 37 | Returns: 38 | 39 | """ 40 | lab_cont = np.atleast_2d(np.loadtxt(lab_file, dtype=object)) 41 | 42 | if lab_cont.shape[1] == 0: 43 | return np.empty(0), 0, 0 44 | 45 | if lab_cont.shape[1] == 3: 46 | lab_cont = lab_cont[lab_cont[:, 2] == 'sp', :][:, [0, 1]] 47 | 48 | n_regions = lab_cont.shape[0] 49 | ii = 0 50 | while True: 51 | try: 52 | start1, end1 = float(lab_cont[ii][0]), float(lab_cont[ii][1]) 53 | jj = ii + 1 54 | start2, end2 = float(lab_cont[jj][0]), float(lab_cont[jj][1]) 55 | if end1 >= start2: 56 | lab_cont = np.delete(lab_cont, ii, axis=0) 57 | ii -= 1 58 | lab_cont[jj - 1][0] = str(start1) 59 | lab_cont[jj - 1][1] = str(max(end1, end2)) 60 | ii += 1 61 | except IndexError: 62 | break 63 | 64 | vad = np.round(np.atleast_2d(lab_cont).astype(np.float).T * 100).astype(np.int) 65 | vad[1] += 1 # Paja's bug!!! 66 | 67 | if not vad.size: 68 | return np.empty(0, dtype=bool) 69 | 70 | npc1 = np.c_[np.zeros_like(vad[0], dtype=bool), np.ones_like(vad[0], dtype=bool)] 71 | npc2 = np.c_[vad[0] - np.r_[0, vad[1, :-1]], vad[1] - vad[0]] 72 | npc2[npc2 < 0] = 0 73 | 74 | out = np.repeat(npc1, npc2.flat) 75 | 76 | n_frames = sum(out) 77 | 78 | return out, n_regions, n_frames 79 | 80 | 81 | --------------------------------------------------------------------------------