2 | Audio Dataset VPR Classifier
3 | A voiceprint recognition classifier for audio datasets
4 |
12 | [简体中文](https://github.com/2DIPW/audio_dataset_vpr/blob/master/README.md) | English
13 |
14 | ## 🚩 Introduction
15 | A dataset screening tool that automatically classifies audio datasets by speaker based on a voiceprint recognition model. To use it, you just need to prepare several representative voice clips for each speaker. It can assist in making datasets for speech models such as VITS/SoVITS/Diff-SVC/RVC/DDSP-SVC.
16 |
17 | This project is modified from [yeyupiaoling/VoiceprintRecognition-Pytorch](https://github.com/yeyupiaoling/VoiceprintRecognition-Pytorch): it adds batch processing, optimizes the speaker judgment rule, and can save the recognition results as a [2DIPW/audio_dataset_screener](https://github.com/2DIPW/audio_dataset_screener) JSON project file for further manual screening. Compared with the original project, all code related to model training has been removed, so if you need to train your own model, use the original project.
18 |
19 | This project is experimental and its results are not guaranteed. It is intended for learning and communication only, not for production use.
20 |
21 | ## 📥 Deploy
22 | ### Clone
23 | ```shell
24 | git clone https://github.com/2DIPW/audio_dataset_vpr.git
25 | cd audio_dataset_vpr
26 | ```
27 | ### Create a virtual environment (optional, using Anaconda as an example)
```shell
29 | conda create -n ad-vpr python=3.8
30 | conda activate ad-vpr
31 | ```
32 | ### Install PyTorch
33 | - Install PyTorch according to your needs; see the [official website](https://pytorch.org/get-started/locally) for details. The following is an example of installing the CUDA build of PyTorch with pip. Skip this step if PyTorch is already installed.
34 | ```shell
35 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
36 | ```
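If you do not have a CUDA-capable GPU, you could install the CPU build instead (and later run `infer.py` with `-d cpu`):
```shell
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
```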
37 | ### Install other requirements
38 | ```shell
39 | pip install -r requirements.txt
40 | ```
41 | ### Configure the voiceprint recognition model
42 | Under the default parameters, you need to put the model files `model.pt`, `model.state`, `optimizer.pt` and the configuration file `config.yml` in the `model` directory. You can also use model and configuration files from other paths by specifying the `-m` and `-c` parameters.
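With the default parameters, the expected layout is:
```
model
├───config.yml
├───model.pt
├───model.state
└───optimizer.pt
```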
43 |
44 | - You can download the EcapaTdnn model I trained on the [zhvoice](https://aistudio.baidu.com/aistudio/datasetdetail/133922) dataset from Hugging Face: [2DIPW/VPR_zhvoice_EcapaTdnn](https://huggingface.co/2DIPW/VPR_zhvoice_EcapaTdnn/tree/main)
45 | - Or download [more models trained by the original project author yeyupiaoling](https://github.com/yeyupiaoling/VoiceprintRecognition-Pytorch#%E6%A8%A1%E5%9E%8B%E4%B8%8B%E8%BD%BD)
46 | - Or [train your own model using the original project](https://github.com/yeyupiaoling/VoiceprintRecognition-Pytorch#%E5%88%9B%E5%BB%BA%E6%95%B0%E6%8D%AE)
47 |
48 | The quality of the recognition depends directly on the quality of the model; you may want to experiment to find the model that works best for you.
49 | ## 🗝 How to use
50 | ### Prepare audio feature library
51 | - For each speaker, select several of the most representative speech segments and put them into the `labels` directory according to the following structure: one subdirectory per speaker, where the directory name is the speaker name and the file names are arbitrary.
52 | - Since the code determines the speaker based on the number of feature segments whose similarity exceeds the given threshold, please ensure that the number of feature segments is **equal** for each speaker (a quick check is sketched below the directory tree).
53 | - If you want to use the voiceprint recognition results for further manual screening with Audio Dataset Screener, the number of speakers should not exceed 5; speakers with serial numbers greater than 5 will be automatically ignored by Audio Dataset Screener.
54 |
55 | ```
56 | labels
57 | ├───speaker1
58 | │ ├───xxx1-xxx1.wav
59 | │ ├───...
60 | │ └───xxx1-xxx4.wav
61 | └───speaker2
62 | ├───xxx2-xxx1.wav
63 | ├───...
64 | └───xxx2-xxx4.wav
65 | ```
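As a quick sanity check for the equal-count requirement above, here is a minimal sketch (assuming the default `labels/` layout):
```python
import os

label_path = "labels/"
# Count the feature clips in each speaker subdirectory
counts = {d: len(os.listdir(os.path.join(label_path, d)))
          for d in os.listdir(label_path)
          if os.path.isdir(os.path.join(label_path, d))}
print(counts)
assert len(set(counts.values())) <= 1, "Each speaker should have the same number of feature clips."
```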
66 | ### Prepare audio files for classification
67 | - With the default parameters, you need to put all the audio files (wav format) into the `input` directory. You can also load audio files from other paths by specifying the `-i` parameter.
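If your source audio is not in wav format, you could convert it beforehand, for example with ffmpeg (assuming ffmpeg is installed; `input.mp3` is a placeholder):
```shell
ffmpeg -i input.mp3 input.wav
```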
68 | ### Run the recognition
69 | - Using `infer.py`
70 | ```shell
71 | python infer.py
72 | ```
73 | Parameters that can be specified:
74 | - `-m` | `--model_path`: Path to model. Default: `model/`
75 | - `-c` | `--configs`: Path to model config file. Default: `model/config.yml`
76 | - `-d` | `--device`: Device to use, gpu or cpu. Default: `gpu`
77 | - `-l` | `--label_path`: Path to the voice feature library. Default: `labels/`
78 | - `-t` | `--threshold`: Similarity threshold for judging whether a clip matches a speaker. Default: `0.6`
79 | - `-i` | `--input_path`: Path to input files. Default: `input/`
80 | - `-o` | `--output_path`: Path to output files. Default: `output/`
81 | - `-k` | `--keep_unrecognized`: Do not move unrecognized files. Default: Disabled
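For example, a run that reads from a custom directory, uses the CPU, raises the threshold, and leaves unrecognized files in place might look like this (the input path is a placeholder):
```shell
python infer.py -i /path/to/wavs -d cpu -t 0.75 -k
```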
82 |
83 | - After the process is complete, the input audio files will be moved into a directory named in the `VPR_Result_YYYYMMDD_HHMMSS` format inside the `output` directory: files recognized as a given speaker are moved into a directory named after that speaker, and unrecognized files are moved into the `Unrecognized` folder.
84 | - The recognition results will also be saved as a `result.json` file, which can be imported into Audio Dataset Screener for further manual screening.
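For reference, a minimal sketch of reading the result file back with Python (the run folder name is a placeholder; the structure follows what `infer.py` writes):
```python
import json

with open("output/VPR_Result_20230101_000000/result.json") as f:  # placeholder run folder
    result = json.load(f)

for entry in result["Files"]:
    label = entry["Label"]  # 0 means unrecognized
    folder = "Unrecognized" if label == 0 else result["Labels"][str(label)]
    print(f'{entry["Filepath"]} -> {folder} (similarity: {entry["Similarity"]})')
```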
85 |
86 | ## ⚖ License
87 | The original project is licensed under the [Apache License 2.0](https://github.com/yeyupiaoling/VoiceprintRecognition-Pytorch/blob/develop/LICENSE). As required by that license, this project includes a MODIFICATION_STATEMENT describing the changes.
88 |
89 | This project is licensed under the [GNU General Public License v3.0](https://github.com/2DIPW/audio_dataset_vpr/blob/master/LICENSE).
90 |
91 | *Open source leads the world to a brighter future.*
92 | ## 📃 References
93 | ```
94 | @inproceedings{desplanques2020ecapa,
95 | title={{ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in TDNN Based Speaker Verification}},
96 | author={Desplanques, Brecht and Thienpondt, Jenthe and Demuynck, Kris},
97 | booktitle={Interspeech 2020},
98 | pages={3830--3834},
99 | year={2020}
100 | }
101 | ```
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | import shutil
5 | from datetime import datetime
6 | from mvector.predict import MVectorPredictor
7 |
8 |
9 | def if_not_mkdir(path):
10 | # Create the directory (and any missing parents) if it does not exist
11 | os.makedirs(path, exist_ok=True)
12 |
13 |
14 | def about():
15 | print(r"""
16 | ___ ___ ____ __ __ _ ______ ____
17 | / | __ ______/ (_)___ / __ \____ _/ /_____ _________ / /_ | | / / __ \/ __ \
18 | / /| |/ / / / __ / / __ \ / / / / __ `/ __/ __ `/ ___/ _ \/ __/ | | / / /_/ / /_/ /
19 | / ___ / /_/ / /_/ / / /_/ / / /_/ / /_/ / /_/ /_/ (__ ) __/ /_ | |/ / ____/ _, _/
20 | /_/ |_\__,_/\__,_/_/\____/ /_____/\__,_/\__/\__,_/____/\___/\__/ |___/_/ /_/ |_|
21 |
22 | Audio Dataset Voiceprint Recognition Classifier by 2DIPW based on yeyupiaoling/VoiceprintRecognition-Pytorch
23 | Licensed under GNU General Public License v3. Open source leads the world to a brighter future!
24 |
25 | """)
26 |
27 |
28 | if __name__ == "__main__":
29 | about()
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('-m', '--model_path', type=str, default="model/", help="Path to model.")
32 | parser.add_argument('-c', '--configs', type=str, default="model/config.yml", help="Path to model config file.")
33 | parser.add_argument('-d', '--device', type=str, default="gpu", help='Device to use, gpu or cpu.')
34 | parser.add_argument('-l', '--label_path', type=str, default="labels/", help="Path to the voice feature library.")
35 | parser.add_argument('-t', '--threshold', type=float, default=0.6, help="Similarity threshold for judging whether a clip matches a speaker.")
36 | parser.add_argument('-i', '--input_path', type=str, default="input/", help="Path to input files.")
37 | parser.add_argument('-o', '--output_path', type=str, default="output/", help="Path to output files.")
38 | parser.add_argument('-k', '--keep_unrecognized', action='store_true', default=False,
39 | help='Do not move unrecognized files.')
40 |
41 | args = parser.parse_args()
42 |
43 | predictor = MVectorPredictor(configs=args.configs,
44 | threshold=args.threshold,
45 | label_path=args.label_path,
46 | model_path=args.model_path,
47 | use_gpu=(args.device == "gpu"))
48 |
49 | if not os.path.exists(args.input_path):
50 | raise Exception("Input path not existed.")
51 |
52 | # Get labels dict from label_path
53 | labels_list = [f for f in os.listdir(args.label_path) if os.path.isdir(os.path.join(args.label_path, f))]
54 | labels_dict = {}
55 | for i, label in enumerate(labels_list):
56 | labels_dict[label] = i + 1
57 | print(f"Feature labels:{labels_dict}")
58 |
59 | # Get input files list from input_path
60 | input_files = []
61 | for root, dirs, files in os.walk(args.input_path):
62 | input_files += [os.path.abspath(os.path.join(root, f)) for f in files if f.split('.')[-1].upper() in ["WAV"]]
63 |
64 | input_files_amount = len(input_files)
65 |
66 | result_dicts_list = []
67 |
68 | for i, file in enumerate(input_files):
69 | try:
70 | label, similarity = predictor.recognition(audio_data=file)
71 | if label:
72 | print(
73 | f"\033[32m[{i + 1}/{input_files_amount}]\033[0m \033[33m{os.path.basename(file)}\033[0m is recognized as speaker \033[31m{label}\033[0m, the max similarity is \033[34m{similarity}\033[0m.")
74 | result_dicts_list.append(
75 | {"Filepath": file, "Label": labels_dict[label], "Similarity": float(similarity)})
76 | else:
77 | print(
78 | f"\033[32m[{i + 1}/{input_files_amount}]\033[0m \033[33m{os.path.basename(file)}\033[0m could not be recognized as any speaker.")
79 | result_dicts_list.append(
80 | {"Filepath": file, "Label": 0, "Similarity": 0})
81 | except Exception as e:
82 | print(
83 | f"\033[32m[{i + 1}/{input_files_amount}]\033[0m An error occurred while processing \033[33m{os.path.basename(file)}\033[0m : {e}")
84 |
85 | output_path_for_this_run = os.path.join(args.output_path, datetime.now().strftime("VPR_Result_%Y%m%d_%H%M%S"))
86 | if_not_mkdir(output_path_for_this_run)
87 | json_path = os.path.abspath(os.path.join(output_path_for_this_run, "result.json"))
88 |
89 | # Move input files to category folders
90 | print("Moving input files to category folders...")
91 | folder_list_without_unrecognized = [os.path.abspath(os.path.join(output_path_for_this_run, label)) for label in
92 | labels_list]
93 | folder_list = [os.path.abspath(
94 | os.path.join(output_path_for_this_run, "Unrecognized"))] + folder_list_without_unrecognized
95 |
96 | for folder in folder_list:
97 | if_not_mkdir(folder)
98 | for result in result_dicts_list:
99 | if args.keep_unrecognized and result["Label"] == 0:
100 | continue
101 | destination_folder = folder_list[result["Label"]]
102 | try:
103 | shutil.move(result["Filepath"], destination_folder)
104 | result["Filepath"] = os.path.abspath(os.path.join(destination_folder, os.path.basename(result["Filepath"])))
105 | except Exception as e:
106 | print(e)
107 |
108 | # Write result json file to output_path
109 | with open(json_path, "w") as f:
110 | json.dump({"Labels": {str(i + 1): folder for i, folder in enumerate(folder_list_without_unrecognized)},
111 | "Files": result_dicts_list}, f, indent=4)
112 | print(f"Result json is saved as {json_path}")
113 |
--------------------------------------------------------------------------------
/input/put wav files here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/2DIPW/audio_dataset_vpr/1b88767aa5792762626989d64910e599fc4a6bd9/input/put wav files here
--------------------------------------------------------------------------------
/labels/put speaker feature audio file here, one folder per speaker:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/2DIPW/audio_dataset_vpr/1b88767aa5792762626989d64910e599fc4a6bd9/labels/put speaker feature audio file here, one folder per speaker
--------------------------------------------------------------------------------
/model/put infer model here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/2DIPW/audio_dataset_vpr/1b88767aa5792762626989d64910e599fc4a6bd9/model/put infer model here
--------------------------------------------------------------------------------
/mvector/MODIFICATION_STATEMENT:
--------------------------------------------------------------------------------
1 | ORI THIS
2 | --------------------------------
3 | data_utils data_utils
4 | +augmentor |
5 | +__init__.py = +__init__.py
6 | +audio.py = +audio.py
7 | +collate_fn.py >> |
8 | +featurizer.py = +featurizer.py
9 | +reader.py >> |
10 | \utils.py = \utils.py
11 | metric
12 | +__init__.py >>
13 | \metrics.py >>
14 | models models
15 | +__init__.py = +__init__.py
16 | +ecapa_tdnn.py = +ecapa_tdnn.py
17 | +fc.py = +fc.py
18 | +loss.py = +loss.py
19 | +pooling.py = +pooling.py
20 | +res2net.py = +res2net.py
21 | +resnet_se.py = +resnet_se.py
22 | \tdnn.py = \tdnn.py
23 | utils utils
24 | +__init__.py = +__init__.py
25 | +logger.py = +logger.py
26 | +record.py >> |
27 | \utils.py <> \utils.py
28 | __init__.py = __init__.py
29 | predict.py <> predict.py
30 | trainer.py >>
31 | --------------------------------
32 |
33 | File: utils\utils.py
34 | 6 from mvector.utils.logger import setup_logger +-
35 | 7
36 | 8 logger = setup_logger(__name__)
37 | 9
38 | 10
39 | 11 def print_arguments(args=None, configs=None):
40 | 12 if args:
41 | 13 logger.info("----------- 额外配置参数 -----------")
42 | 14 for arg, value in sorted(vars(args).items()):
43 | 15 logger.info("%s: %s" % (arg, value))
44 | 16 logger.info("------------------------------------------------")
45 | 17 if configs:
46 | 18 logger.info("----------- 配置文件参数 -----------")
47 | 19 for arg, value in sorted(configs.items()):
48 | 20 if isinstance(value, dict):
49 | 21 logger.info(f"{arg}:")
50 | 22 for a, v in sorted(value.items()):
51 | 23 if isinstance(v, dict):
52 | 24 logger.info(f"\t{a}:")
53 | 25 for a1, v1 in sorted(v.items()):
54 | 26 logger.info("\t\t%s: %s" % (a1, v1))
55 | 27 else:
56 | 28 logger.info("\t%s: %s" % (a, v))
57 | 29 else:
58 | 30 logger.info("%s: %s" % (arg, value))
59 | 31 logger.info("------------------------------------------------")
60 | 32
61 | 33
62 | 34 def add_arguments(argname, type, default, help, argparser, **kwargs):
63 | 35 type = distutils.util.strtobool if type == bool else type
64 | 36 argparser.add_argument("--" + argname,
65 | 37 default=default,
66 | 38 type=type,
67 | 39 help=help + ' 默认: %(default)s.',
68 | 40 **kwargs)
69 | 41
70 | -------------------------------------------------------------------------------
71 |
72 | File: predict.py
73 | 2 import pickle +-
74 | 3 import shutil
75 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
76 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
77 | 21 from mvector.utils.utils import dict_to_object, print_arguments <> 19 from mvector.utils.utils import dict_to_object
78 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
79 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
80 | 30 audio_db_path=None, <> 28 label_path=None,
81 | 31 model_path='models/ecapa_tdnn_spectrogram/best_model/', 29 model_path='./model',
82 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
83 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
84 | 37 :param audio_db_path: 声纹库路径 <> 35 :param label_path: 声纹库路径
85 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
86 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
87 | 42 assert (torch.cuda.is_available()), 'GPU不可用' <> 40 assert (torch.cuda.is_available()), 'GPU not available.'
88 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
89 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
90 | 54 print_arguments(configs=configs) +-
91 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
92 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
93 | 57 '【警告】,您貌似使用了旧的配置文件,如果你同时使用了旧的模型,这是错误的,请重新下载或者重新训练,否则只能回滚代码。' <> 54 'You are using an old version of model which is no longer supported.'
94 | 58 assert self.configs.use_model in SUPPORT_MODEL, f'没有该模型:{self.configs.use_model}' 55 assert self.configs.use_model in SUPPORT_MODEL, f'Model not existed:{self.configs.use_model}'
95 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
96 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
97 | 71 raise Exception(f'{self.configs.use_model} 模型不存在!') <> 68 raise Exception(f'{self.configs.use_model} model not existed!')
98 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
99 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
100 | 77 assert os.path.exists(model_path), f"{model_path} 模型不存在!" <> 74 assert os.path.exists(model_path), f"{model_path} model not existed!"
101 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
102 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
103 | 83 print(f"成功加载模型参数:{model_path}") <> 80 print(f"Model loaded successfully:{model_path}")
104 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
105 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
106 | 94 self.audio_db_path = audio_db_path <> 91 self.audio_db_path = label_path
107 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
108 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
109 | 96 self.audio_indexes_path = os.path.join(audio_db_path, "audio_indexes.bin") +-
110 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
111 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
112 | 100 # 加载声纹特征索引 +-
113 | 101 def __load_face_indexes(self):
114 | 102 # 如果存在声纹特征索引文件就加载
115 | 103 if not os.path.exists(self.audio_indexes_path): return
116 | 104 with open(self.audio_indexes_path, "rb") as f:
117 | 105 indexes = pickle.load(f)
118 | 106 self.users_name = indexes["users_name"]
119 | 107 self.audio_feature = indexes["faces_feature"]
120 | 108 self.users_audio_path = indexes["users_image_path"]
121 | 109
122 | 110 # 保存声纹特征索引
123 | 111 def __write_index(self):
124 | 112 with open(self.audio_indexes_path, "wb") as f:
125 | 113 pickle.dump({"users_name": self.users_name,
126 | 114 "faces_feature": self.audio_feature,
127 | 115 "users_image_path": self.users_audio_path}, f)
128 | 116
129 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
130 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
131 | 120 self.__load_face_indexes() +-
132 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
133 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
134 | 125 if not os.path.isdir(audio_dir):continue <> 103 if not os.path.isdir(audio_dir): continue
135 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
136 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
137 | 130 logger.info('正在加载声纹库数据...') <> 108 print("Loading voice feature library...")
138 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
139 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
140 | 157 assert len(self.audio_feature) == len(self.users_name) == len(self.users_audio_path), '加载的数量对不上!' <> 135 assert len(self.audio_feature) == len(self.users_name) == len(self.users_audio_path), 'Labels count conflict.'
141 | 158 # 将声纹特征保存到索引文件中 136 print("Voice feature library loaded successfully.")
142 | 159 self.__write_index()
143 | 160 logger.info('声纹库数据加载完成!')
144 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
145 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
146 | 164 labels = [] <> 140 results = []
147 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
148 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
149 | 179 max_label = None <> 155 results.append({"label": None, "similarity": None})
150 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
151 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
152 | 182 labels.append(max_label) <> 158 idx_for_max_label = [i for i, x in enumerate(self.users_name) if x == max_label]
153 | 159 similarity_for_max_label = max(abs_similarity[idx_for_max_label])
154 | 160 results.append({"label": max_label, "similarity": similarity_for_max_label})
155 | 183 return labels 161 return results
156 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
157 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
158 | 203 f'音频太短,最小应该为{self.configs.dataset_conf.min_duration}s,当前音频为{audio_segment.duration}s' <> 181 f'Audio segment too short,minimum is {self.configs.dataset_conf.min_duration}s,current is{audio_segment.duration}s'
159 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
160 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
161 | 254 input_lens_ratio.append(seq_length/max_audio_length) <> 232 input_lens_ratio.append(seq_length / max_audio_length)
162 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
163 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
164 | 262 # 声纹对比 +-
165 | 263 def contrast(self, audio_data1, audio_data2):
166 | 264 feature1 = self.predict(audio_data1)
167 | 265 feature2 = self.predict(audio_data2)
168 | 266 # 对角余弦值
169 | 267 dist = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
170 | 268 return dist
171 | 269
172 | 270 def register(self,
173 | 271 user_name,
174 | 272 audio_data,
175 | 273 sample_rate=16000):
176 | 274 """声纹注册
177 | 275 :param user_name: 注册用户的名字
178 | 276 :param audio_data: 需要识别的数据,支持文件路径,文件对象,字节,numpy。如果是字节的话,必须是完整的字节文件
179 | 277 :param sample_rate: 如果传入的事numpy数据,需要指定采样率
180 | 278 :return: 识别的文本结果和解码的得分数
181 | 279 """
182 | 280 # 加载音频文件
183 | 281 if isinstance(audio_data, str):
184 | 282 audio_segment = AudioSegment.from_file(audio_data)
185 | 283 elif isinstance(audio_data, BufferedReader):
186 | 284 audio_segment = AudioSegment.from_file(audio_data)
187 | 285 elif isinstance(audio_data, np.ndarray):
188 | 286 audio_segment = AudioSegment.from_ndarray(audio_data, sample_rate)
189 | 287 elif isinstance(audio_data, bytes):
190 | 288 audio_segment = AudioSegment.from_bytes(audio_data)
191 | 289 else:
192 | 290 raise Exception(f'不支持该数据类型,当前数据类型为:{type(audio_data)}')
193 | 291 feature = self.predict(audio_data=audio_segment.samples, sample_rate=audio_segment.sample_rate)
194 | 292 if self.audio_feature is None:
195 | 293 self.audio_feature = feature
196 | 294 else:
197 | 295 self.audio_feature = np.vstack((self.audio_feature, feature))
198 | 296 # 保存
199 | 297 if not os.path.exists(os.path.join(self.audio_db_path, user_name)):
200 | 298 audio_path = os.path.join(self.audio_db_path, user_name, '0.wav')
201 | 299 else:
202 | 300 audio_path = os.path.join(self.audio_db_path, user_name,
203 | 301 f'{len(os.listdir(os.path.join(self.audio_db_path, user_name)))}.wav')
204 | 302 os.makedirs(os.path.dirname(audio_path), exist_ok=True)
205 | 303 audio_segment.to_wav_file(audio_path)
206 | 304 self.users_audio_path.append(audio_path.replace('\\', '/'))
207 | 305 self.users_name.append(user_name)
208 | 306 self.__write_index()
209 | 307 return True, "注册成功"
210 | 308
211 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
212 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
213 | 319 name = self.__retrieval(np_feature=[feature])[0] <> 250 result = self.__retrieval(np_feature=[feature])[0]
214 | 320 return name
215 | 321
216 | 322 def remove_user(self, user_name):
217 | 323 """删除用户
218 | 324
219 | 325 :param user_name: 用户名
220 | 326 :return:
221 | 327 """
222 | 328 if user_name in self.users_name: 251 return result["label"], result["similarity"]
223 | 329 indexes = [i for i in range(len(self.users_name)) if self.users_name[i] == user_name]
224 | 330 for index in sorted(indexes, reverse=True):
225 | 331 del self.users_name[index]
226 | 332 del self.users_audio_path[index]
227 | 333 self.audio_feature = np.delete(self.audio_feature, index, axis=0)
228 | 334 self.__write_index()
229 | 335 shutil.rmtree(os.path.join(self.audio_db_path, user_name))
230 | 336 return True
231 | 337 else:
232 | 338 return False
233 | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
234 |
235 |
--------------------------------------------------------------------------------
/mvector/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.3.9"
2 | # Models supported by this project
3 | SUPPORT_MODEL = ['ecapa_tdnn', 'EcapaTdnn', 'Res2Net', 'ResNetSE', 'TDNN']
4 |
--------------------------------------------------------------------------------
/mvector/data_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/2DIPW/audio_dataset_vpr/1b88767aa5792762626989d64910e599fc4a6bd9/mvector/data_utils/__init__.py
--------------------------------------------------------------------------------
/mvector/data_utils/audio.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import io
3 | import os
4 | import random
5 |
6 | import numpy as np
7 | import resampy
8 | import soundfile
9 |
10 | from mvector.data_utils.utils import buf_to_float, vad, decode_audio
11 |
12 |
13 | class AudioSegment(object):
14 | """Monaural audio segment abstraction.
15 |
16 | :param samples: Audio samples [num_samples x num_channels].
17 | :type samples: ndarray.float32
18 | :param sample_rate: Audio sample rate.
19 | :type sample_rate: int
20 | :raises TypeError: If the sample data type is not float or int.
21 | """
22 |
23 | def __init__(self, samples, sample_rate):
24 | """Create audio segment from samples.
25 |
26 | Samples are converted to float32 internally, with integers scaled to [-1, 1].
27 | """
28 | self._samples = self._convert_samples_to_float32(samples)
29 | self._sample_rate = sample_rate
30 | if self._samples.ndim >= 2:
31 | self._samples = np.mean(self._samples, 1)
32 |
33 | def __eq__(self, other):
34 | """返回两个对象是否相等"""
35 | if type(other) is not type(self):
36 | return False
37 | if self._sample_rate != other._sample_rate:
38 | return False
39 | if self._samples.shape != other._samples.shape:
40 | return False
41 | if np.any(self.samples != other._samples):
42 | return False
43 | return True
44 |
45 | def __ne__(self, other):
46 | """返回两个对象是否不相等"""
47 | return not self.__eq__(other)
48 |
49 | def __str__(self):
50 | """返回该音频的信息"""
51 | return ("%s: num_samples=%d, sample_rate=%d, duration=%.2fsec, "
52 | "rms=%.2fdB" % (type(self), self.num_samples, self.sample_rate, self.duration, self.rms_db))
53 |
54 | @classmethod
55 | def from_file(cls, file):
56 | """从音频文件创建音频段
57 |
58 | :param file: 文件路径,或者文件对象
59 | :type file: str, BufferedReader
60 | :return: 音频片段实例
61 | :rtype: AudioSegment
62 | """
63 | assert os.path.exists(file), f'File does not exist, please check the path: {file}'
64 | try:
65 | samples, sample_rate = soundfile.read(file, dtype='float32')
66 | except:
67 | # Fall back to a decoder that supports more formats
68 | sample_rate = 16000
69 | samples = decode_audio(file=file, sample_rate=sample_rate)
70 | return cls(samples, sample_rate)
71 |
72 | @classmethod
73 | def slice_from_file(cls, file, start=None, end=None):
74 | """只加载一小段音频,而不需要将整个文件加载到内存中,这是非常浪费的。
75 |
76 | :param file: 输入音频文件路径或文件对象
77 | :type file: str|file
78 | :param start: 开始时间,单位为秒。如果start是负的,则它从末尾开始计算。如果没有提供,这个函数将从最开始读取。
79 | :type start: float
80 | :param end: 结束时间,单位为秒。如果end是负的,则它从末尾开始计算。如果没有提供,默认的行为是读取到文件的末尾。
81 | :type end: float
82 | :return: AudioSegment输入音频文件的指定片的实例。
83 | :rtype: AudioSegment
84 | :raise ValueError: 如开始或结束的设定不正确,例如时间不允许。
85 | """
86 | sndfile = soundfile.SoundFile(file)
87 | sample_rate = sndfile.samplerate
88 | duration = round(float(len(sndfile)) / sample_rate, 3)
89 | start = 0. if start is None else round(start, 3)
90 | end = duration if end is None else round(end, 3)
91 | # Count from the end
92 | if start < 0.0: start += duration
93 | if end < 0.0: end += duration
94 | # Keep indices within bounds
95 | if start < 0.0: start = 0.0
96 | if end > duration: end = duration
97 | if end < 0.0:
98 | raise ValueError("切片结束位置(%f s)越界" % end)
99 | if start > end:
100 | raise ValueError("切片开始位置(%f s)晚于切片结束位置(%f s)" % (start, end))
101 | start_frame = int(start * sample_rate)
102 | end_frame = int(end * sample_rate)
103 | sndfile.seek(start_frame)
104 | data = sndfile.read(frames=end_frame - start_frame, dtype='float32')
105 | return cls(data, sample_rate)
106 |
107 | @classmethod
108 | def from_bytes(cls, data):
109 | """从包含音频样本的字节创建音频段
110 |
111 | :param data: 包含音频样本的字节
112 | :type data: bytes
113 | :return: 音频部分实例
114 | :rtype: AudioSegment
115 | """
116 | samples, sample_rate = soundfile.read(io.BytesIO(data), dtype='float32')
117 | return cls(samples, sample_rate)
118 |
119 | @classmethod
120 | def from_pcm_bytes(cls, data, channels=1, samp_width=2, sample_rate=16000):
121 | """从包含无格式PCM音频的字节创建音频
122 |
123 | :param data: 包含音频样本的字节
124 | :type data: bytes
125 | :param channels: 音频的通道数
126 | :type channels: int
127 | :param samp_width: 音频采样的宽度,如np.int16为2
128 | :type samp_width: int
129 | :param sample_rate: 音频样本采样率
130 | :type sample_rate: int
131 | :return: 音频部分实例
132 | :rtype: AudioSegment
133 | """
134 | samples = buf_to_float(data, n_bytes=samp_width)
135 | if channels > 1:
136 | samples = samples.reshape(-1, channels)
137 | return cls(samples, sample_rate)
138 |
139 | @classmethod
140 | def from_ndarray(cls, data, sample_rate=16000):
141 | """从numpy.ndarray创建音频段
142 |
143 | :param data: numpy.ndarray类型的音频数据
144 | :type data: ndarray
145 | :param sample_rate: 音频样本采样率
146 | :type sample_rate: int
147 | :return: 音频部分实例
148 | :rtype: AudioSegment
149 | """
150 | return cls(data, sample_rate)
151 |
152 | @classmethod
153 | def concatenate(cls, *segments):
154 | """将任意数量的音频片段连接在一起
155 |
156 | :param *segments: 输入音频片段被连接
157 | :type *segments: tuple of AudioSegment
158 | :return: Audio segment instance as concatenating results.
159 | :rtype: AudioSegment
160 | :raises ValueError: If the number of segments is zero, or if the
161 | sample_rate of any segments does not match.
162 | :raises TypeError: If any segment is not AudioSegment instance.
163 | """
164 | # Perform basic sanity-checks.
165 | if len(segments) == 0:
166 | raise ValueError("没有音频片段被给予连接")
167 | sample_rate = segments[0]._sample_rate
168 | for seg in segments:
169 | if sample_rate != seg._sample_rate:
170 | raise ValueError("能用不同的采样率连接片段")
171 | if type(seg) is not cls:
172 | raise TypeError("只有相同类型的音频片段可以连接")
173 | samples = np.concatenate([seg.samples for seg in segments])
174 | return cls(samples, sample_rate)
175 |
176 | @classmethod
177 | def make_silence(cls, duration, sample_rate):
178 | """创建给定持续时间和采样率的静音音频段
179 |
180 | :param duration: 静音的时间,以秒为单位
181 | :type duration: float
182 | :param sample_rate: 音频采样率
183 | :type sample_rate: float
184 | :return: 给定持续时间的静音AudioSegment实例
185 | :rtype: AudioSegment
186 | """
187 | samples = np.zeros(int(duration * sample_rate))
188 | return cls(samples, sample_rate)
189 |
190 | def to_wav_file(self, filepath, dtype='float32'):
191 | """保存音频段到磁盘为wav文件
192 |
193 | :param filepath: WAV文件路径或文件对象,以保存音频段
194 | :type filepath: str|file
195 | :param dtype: Subtype for audio file. Options: 'int16', 'int32',
196 | 'float32', 'float64'. Default is 'float32'.
197 | :type dtype: str
198 | :raises TypeError: If dtype is not supported.
199 | """
200 | samples = self._convert_samples_from_float32(self._samples, dtype)
201 | subtype_map = {
202 | 'int16': 'PCM_16',
203 | 'int32': 'PCM_32',
204 | 'float32': 'FLOAT',
205 | 'float64': 'DOUBLE'
206 | }
207 | soundfile.write(
208 | filepath,
209 | samples,
210 | self._sample_rate,
211 | format='WAV',
212 | subtype=subtype_map[dtype])
213 |
214 | def superimpose(self, other):
215 | """将另一个段的样本添加到这个段的样本中(以样本方式添加,而不是段连接)。
216 |
217 | :param other: 包含样品的片段被添加进去
218 | :type other: AudioSegments
219 | :raise TypeError: 如果两个片段的类型不匹配
220 | :raise ValueError: 不能添加不同类型的段
221 | """
222 | if not isinstance(other, type(self)):
223 | raise TypeError("不能添加不同类型的段: %s 和 %s" % (type(self), type(other)))
224 | if self._sample_rate != other._sample_rate:
225 | raise ValueError("采样率必须匹配才能添加片段")
226 | if len(self._samples) != len(other._samples):
227 | raise ValueError("段长度必须匹配才能添加段")
228 | self._samples += other._samples
229 |
230 | def to_bytes(self, dtype='float32'):
231 | """创建包含音频内容的字节字符串
232 |
233 | :param dtype: Data type for export samples. Options: 'int16', 'int32',
234 | 'float32', 'float64'. Default is 'float32'.
235 | :type dtype: str
236 | :return: Byte string containing audio content.
237 | :rtype: str
238 | """
239 | samples = self._convert_samples_from_float32(self._samples, dtype)
240 | return samples.tobytes()  # tostring() is removed in newer numpy versions
241 |
242 | def to(self, dtype='int16'):
243 | """类型转换
244 |
245 | :param dtype: Data type for export samples. Options: 'int16', 'int32',
246 | 'float32', 'float64'. Default is 'int16'.
247 | :type dtype: str
248 | :return: np.ndarray containing `dtype` audio content.
249 | :rtype: ndarray
250 | """
251 | samples = self._convert_samples_from_float32(self._samples, dtype)
252 | return samples
253 |
254 | def gain_db(self, gain):
255 | """对音频施加分贝增益。
256 |
257 | Note that this is an in-place transformation.
258 |
259 | :param gain: Gain in decibels to apply to samples.
260 | :type gain: float|1darray
261 | """
262 | self._samples *= 10.**(gain / 20.)
263 |
264 | def change_speed(self, speed_rate):
265 | """通过线性插值改变音频速度
266 |
267 | :param speed_rate: Rate of speed change:
268 | speed_rate > 1.0, speed up the audio;
269 | speed_rate = 1.0, unchanged;
270 | speed_rate < 1.0, slow down the audio;
271 | speed_rate <= 0.0, not allowed, raise ValueError.
272 | :type speed_rate: float
273 | :raises ValueError: If speed_rate <= 0.0.
274 | """
275 | if speed_rate == 1.0:
276 | return
277 | if speed_rate <= 0:
278 | raise ValueError("速度速率应大于零")
279 | old_length = self._samples.shape[0]
280 | new_length = int(old_length / speed_rate)
281 | old_indices = np.arange(old_length)
282 | new_indices = np.linspace(start=0, stop=old_length, num=new_length)
283 | self._samples = np.interp(new_indices, old_indices, self._samples).astype(np.float32)
284 |
285 | def normalize(self, target_db=-20, max_gain_db=300.0):
286 | """将音频归一化,使其具有所需的有效值(以分贝为单位)
287 |
288 | :param target_db: Target RMS value in decibels. This value should be
289 | less than 0.0 as 0.0 is full-scale audio.
290 | :type target_db: float
291 | :param max_gain_db: Max amount of gain in dB that can be applied for
292 | normalization. This is to prevent nans when
293 | attempting to normalize a signal consisting of
294 | all zeros.
295 | :type max_gain_db: float
296 | :raises ValueError: If the required gain to normalize the segment to
297 | the target_db value exceeds max_gain_db.
298 | """
299 | if -np.inf == self.rms_db: return
300 | gain = target_db - self.rms_db
301 | if gain > max_gain_db:
302 | raise ValueError(
303 | "无法将段规范化到 %f dB,因为可能的增益已经超过max_gain_db (%f dB)" % (target_db, max_gain_db))
304 | self.gain_db(min(max_gain_db, target_db - self.rms_db))
305 |
306 | def resample(self, target_sample_rate, filter='kaiser_best'):
307 | """按目标采样率重新采样音频
308 |
309 | Note that this is an in-place transformation.
310 |
311 | :param target_sample_rate: Target sample rate.
312 | :type target_sample_rate: int
313 | :param filter: The resampling filter to use one of {'kaiser_best', 'kaiser_fast'}.
314 | :type filter: str
315 | """
316 | self._samples = resampy.resample(self.samples, self.sample_rate, target_sample_rate, filter=filter)
317 | self._sample_rate = target_sample_rate
318 |
319 | def pad_silence(self, duration, sides='both'):
320 | """在这个音频样本上加一段静音
321 |
322 | Note that this is an in-place transformation.
323 |
324 | :param duration: Length of silence in seconds to pad.
325 | :type duration: float
326 | :param sides: Position for padding:
327 | 'beginning' - adds silence in the beginning;
328 | 'end' - adds silence in the end;
329 | 'both' - adds silence in both the beginning and the end.
330 | :type sides: str
331 | :raises ValueError: If sides is not supported.
332 | """
333 | if duration == 0.0:
334 | return self
335 | cls = type(self)
336 | silence = self.make_silence(duration, self._sample_rate)
337 | if sides == "beginning":
338 | padded = cls.concatenate(silence, self)
339 | elif sides == "end":
340 | padded = cls.concatenate(self, silence)
341 | elif sides == "both":
342 | padded = cls.concatenate(silence, self, silence)
343 | else:
344 | raise ValueError("Unknown value for the sides %s" % sides)
345 | self._samples = padded._samples
346 |
347 | def shift(self, shift_ms):
348 | """音频偏移。如果shift_ms为正,则随时间提前移位;如果为负,则随时间延迟移位。填补静音以保持持续时间不变。
349 |
350 | Note that this is an in-place transformation.
351 |
352 | :param shift_ms: Shift time in millseconds. If positive, shift with
353 | time advance; if negative; shift with time delay.
354 | :type shift_ms: float
355 | :raises ValueError: If shift_ms is longer than audio duration.
356 | """
357 | if abs(shift_ms) / 1000.0 > self.duration:
358 | raise ValueError("shift_ms的绝对值应该小于音频持续时间")
359 | shift_samples = int(shift_ms * self._sample_rate / 1000)
360 | if shift_samples > 0:
361 | # time advance
362 | self._samples[:-shift_samples] = self._samples[shift_samples:]
363 | self._samples[-shift_samples:] = 0
364 | elif shift_samples < 0:
365 | # time delay
366 | self._samples[-shift_samples:] = self._samples[:shift_samples]
367 | self._samples[:-shift_samples] = 0
368 |
369 | def subsegment(self, start_sec=None, end_sec=None):
370 | """在给定的边界之间切割音频片段
371 |
372 | Note that this is an in-place transformation.
373 |
374 | :param start_sec: Beginning of subsegment in seconds.
375 | :type start_sec: float
376 | :param end_sec: End of subsegment in seconds.
377 | :type end_sec: float
378 | :raise ValueError: If start_sec or end_sec is incorrectly set, e.g. out
379 | of bounds in time.
380 | """
381 | start_sec = 0.0 if start_sec is None else start_sec
382 | end_sec = self.duration if end_sec is None else end_sec
383 | if start_sec < 0.0:
384 | start_sec = self.duration + start_sec
385 | if end_sec < 0.0:
386 | end_sec = self.duration + end_sec
387 | if start_sec < 0.0:
388 | raise ValueError("切片起始位置(%f s)越界" % start_sec)
389 | if end_sec < 0.0:
390 | raise ValueError("切片结束位置(%f s)越界" % end_sec)
391 | if start_sec > end_sec:
392 | raise ValueError("切片的起始位置(%f s)晚于结束位置(%f s)" % (start_sec, end_sec))
393 | if end_sec > self.duration:
394 | raise ValueError("切片结束位置(%f s)越界(> %f s)" % (end_sec, self.duration))
395 | start_sample = int(round(start_sec * self._sample_rate))
396 | end_sample = int(round(end_sec * self._sample_rate))
397 | self._samples = self._samples[start_sample:end_sample]
398 |
399 | def random_subsegment(self, subsegment_length):
400 | """随机剪切指定长度的音频片段
401 |
402 | Note that this is an in-place transformation.
403 |
404 | :param subsegment_length: Subsegment length in seconds.
405 | :type subsegment_length: float
406 | :raises ValueError: If the length of subsegment is greater than
407 | the original segment.
408 | """
409 | if subsegment_length > self.duration:
410 | raise ValueError("Length of subsegment must not be greater "
411 | "than original segment.")
412 | start_time = random.uniform(0.0, self.duration - subsegment_length)
413 | self.subsegment(start_time, start_time + subsegment_length)
414 |
415 | def add_noise(self,
416 | noise,
417 | snr_dB,
418 | max_gain_db=300.0):
419 | """以特定的信噪比添加给定的噪声段。如果噪声段比该噪声段长,则从该噪声段中采样匹配长度的随机子段。
420 |
421 | Note that this is an in-place transformation.
422 |
423 | :param noise: Noise signal to add.
424 | :type noise: AudioSegment
425 | :param snr_dB: Signal-to-Noise Ratio, in decibels.
426 | :type snr_dB: float
427 | :param max_gain_db: Maximum amount of gain to apply to noise signal
428 | before adding it in. This is to prevent attempting
429 | to apply infinite gain to a zero signal.
430 | :type max_gain_db: float
431 | :raises ValueError: If the sample rate does not match between the two
432 | audio segments, or if the duration of noise segments
433 | is shorter than original audio segments.
434 | """
435 | if noise.sample_rate != self.sample_rate:
436 | raise ValueError("噪声采样率(%d Hz)不等于基信号采样率(%d Hz)" % (noise.sample_rate, self.sample_rate))
437 | if noise.duration < self.duration:
438 | raise ValueError("噪声信号(%f秒)必须至少与基信号(%f秒)一样长" % (noise.duration, self.duration))
439 | noise_gain_db = min(self.rms_db - noise.rms_db - snr_dB, max_gain_db)
440 | noise_new = copy.deepcopy(noise)
441 | noise_new.random_subsegment(self.duration)
442 | noise_new.gain_db(noise_gain_db)
443 | self.superimpose(noise_new)
444 |
445 | def vad(self, top_db=20, overlap=200):
446 | self._samples = vad(wav=self._samples, top_db=top_db, overlap=overlap)
447 |
448 | def crop(self, duration, mode='eval'):
449 | if self.duration > duration:
450 | if mode == 'train':
451 | self.random_subsegment(duration)
452 | else:
453 | self.subsegment(end_sec=duration)
454 |
455 | @property
456 | def samples(self):
457 | """返回音频样本
458 |
459 | :return: Audio samples.
460 | :rtype: ndarray
461 | """
462 | return self._samples.copy()
463 |
464 | @property
465 | def sample_rate(self):
466 | """返回音频采样率
467 |
468 | :return: Audio sample rate.
469 | :rtype: int
470 | """
471 | return self._sample_rate
472 |
473 | @property
474 | def num_samples(self):
475 | """返回样品数量
476 |
477 | :return: Number of samples.
478 | :rtype: int
479 | """
480 | return self._samples.shape[0]
481 |
482 | @property
483 | def duration(self):
484 | """返回音频持续时间
485 |
486 | :return: Audio duration in seconds.
487 | :rtype: float
488 | """
489 | return self._samples.shape[0] / float(self._sample_rate)
490 |
491 | @property
492 | def rms_db(self):
493 | """返回以分贝为单位的音频均方根能量
494 |
495 | :return: Root mean square energy in decibels.
496 | :rtype: float
497 | """
498 | # square root => multiply by 10 instead of 20 for dBs
499 | mean_square = np.mean(self._samples ** 2)
500 | return 10 * np.log10(mean_square)
501 |
502 | def _convert_samples_to_float32(self, samples):
503 | """Convert sample type to float32.
504 |
505 | Audio sample type is usually integer or float-point.
506 | Integers will be scaled to [-1, 1] in float32.
507 | """
508 | float32_samples = samples.astype('float32')
509 | if samples.dtype in np.sctypes['int']:
510 | bits = np.iinfo(samples.dtype).bits
511 | float32_samples *= (1. / 2 ** (bits - 1))
512 | elif samples.dtype in np.sctypes['float']:
513 | pass
514 | else:
515 | raise TypeError("Unsupported sample type: %s." % samples.dtype)
516 | return float32_samples
517 |
518 | def _convert_samples_from_float32(self, samples, dtype):
519 | """Convert sample type from float32 to dtype.
520 |
521 | Audio sample type is usually integer or float-point. For integer
522 | type, float32 will be rescaled from [-1, 1] to the maximum range
523 | supported by the integer type.
524 |
525 | This is for writing an audio file.
526 | """
527 | dtype = np.dtype(dtype)
528 | output_samples = samples.copy()
529 | if dtype in np.sctypes['int']:
530 | bits = np.iinfo(dtype).bits
531 | output_samples *= (2 ** (bits - 1) / 1.)
532 | min_val = np.iinfo(dtype).min
533 | max_val = np.iinfo(dtype).max
534 | output_samples[output_samples > max_val] = max_val
535 | output_samples[output_samples < min_val] = min_val
536 | elif samples.dtype in np.sctypes['float']:
537 | min_val = np.finfo(dtype).min
538 | max_val = np.finfo(dtype).max
539 | output_samples[output_samples > max_val] = max_val
540 | output_samples[output_samples < min_val] = min_val
541 | else:
542 | raise TypeError("Unsupported sample type: %s." % samples.dtype)
543 | return output_samples.astype(dtype)
544 |
--------------------------------------------------------------------------------
/mvector/data_utils/featurizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchaudio.transforms import MelSpectrogram, Spectrogram, MFCC
4 |
5 |
6 | class AudioFeaturizer(nn.Module):
7 | """音频特征器
8 |
9 | :param feature_method: 所使用的预处理方法
10 | :type feature_method: str
11 | :param feature_conf: 预处理方法的参数
12 | :type feature_conf: dict
13 | """
14 |
15 | def __init__(self, feature_method='MelSpectrogram', feature_conf={}):
16 | super().__init__()
17 | self._feature_conf = feature_conf
18 | self._feature_method = feature_method
19 | if feature_method == 'MelSpectrogram':
20 | self.feat_fun = MelSpectrogram(**feature_conf)
21 | elif feature_method == 'Spectrogram':
22 | self.feat_fun = Spectrogram(**feature_conf)
23 | elif feature_method == 'MFCC':
24 | melkwargs = feature_conf.copy()
25 | del melkwargs['sample_rate']
26 | del melkwargs['n_mfcc']
27 | self.feat_fun = MFCC(sample_rate=self._feature_conf.sample_rate,
28 | n_mfcc=self._feature_conf.n_mfcc,
29 | melkwargs=melkwargs)
30 | else:
31 | raise Exception(f'Preprocessing method {self._feature_method} does not exist!')
32 |
33 | def forward(self, waveforms, input_lens_ratio):
34 | """从AudioSegment中提取音频特征
35 |
36 | :param waveforms: Audio segment to extract features from.
37 | :type waveforms: AudioSegment
38 | :param input_lens_ratio: input length ratio
39 | :type input_lens_ratio: tensor
40 | :return: Spectrogram audio feature in 2darray.
41 | :rtype: ndarray
42 | """
43 | feature = self.feat_fun(waveforms)
44 | feature = feature.transpose(2, 1)
45 | # Normalize the features
46 | mean = torch.mean(feature, 1, keepdim=True)
47 | std = torch.std(feature, 1, keepdim=True)
48 | feature = (feature - mean) / (std + 1e-5)
49 | # Expand the mask length ratios to frame counts
50 | input_lens = (input_lens_ratio * feature.shape[1])
51 | mask_lens = torch.round(input_lens).long()
52 | mask_lens = mask_lens.unsqueeze(1)
53 | input_lens = input_lens.int()
54 | # Build the mask tensor
55 | idxs = torch.arange(feature.shape[1], device=feature.device).repeat(feature.shape[0], 1)
56 | mask = idxs < mask_lens
57 | mask = mask.unsqueeze(-1)
58 | # Apply the mask to the features
59 | feature_masked = torch.where(mask, feature, torch.zeros_like(feature))
60 | return feature_masked, input_lens
61 |
62 | @property
63 | def feature_dim(self):
64 | """返回特征大小
65 |
66 | :return: 特征大小
67 | :rtype: int
68 | """
69 | if self._feature_method == 'LogMelSpectrogram':
70 | return self._feature_conf.n_mels
71 | elif self._feature_method == 'MelSpectrogram':
72 | return self._feature_conf.n_mels
73 | elif self._feature_method == 'Spectrogram':
74 | return self._feature_conf.n_fft // 2 + 1
75 | elif self._feature_method == 'MFCC':
76 | return self._feature_conf.n_mfcc
77 | else:
78 | raise Exception('No such preprocessing method: {}'.format(self._feature_method))
79 |
--------------------------------------------------------------------------------
/mvector/data_utils/utils.py:
--------------------------------------------------------------------------------
1 | import io
2 | import itertools
3 |
4 | import av
5 | import librosa
6 | import numpy as np
7 | import torch
8 |
9 |
10 | def vad(wav, top_db=20, overlap=200):
11 | # Split an audio signal into non-silent intervals
12 | intervals = librosa.effects.split(wav, top_db=top_db)
13 | if len(intervals) == 0:
14 | return wav
15 | wav_output = [np.array([])]
16 | for sliced in intervals:
17 | seg = wav[sliced[0]:sliced[1]]
18 | if len(seg) < 2 * overlap:
19 | wav_output[-1] = np.concatenate((wav_output[-1], seg))
20 | else:
21 | wav_output.append(seg)
22 | wav_output = [x for x in wav_output if len(x) > 0]
23 |
24 | if len(wav_output) == 1:
25 | wav_output = wav_output[0]
26 | else:
27 | wav_output = concatenate(wav_output)
28 | return wav_output
29 |
30 |
31 | def concatenate(wave, overlap=200):
32 | total_len = sum([len(x) for x in wave])
33 | unfolded = np.zeros(total_len)
34 |
35 | # Equal power crossfade
36 | window = np.hanning(2 * overlap)
37 | fade_in = window[:overlap]
38 | fade_out = window[-overlap:]
39 |
40 | end = total_len
41 | for i in range(1, len(wave)):
42 | prev = wave[i - 1]
43 | curr = wave[i]
44 |
45 | if i == 1:
46 | end = len(prev)
47 | unfolded[:end] += prev
48 |
49 | max_idx = 0
50 | max_corr = 0
51 | pattern = prev[-overlap:]
52 | # slide the curr batch to match with the pattern of previous one
53 | for j in range(overlap):
54 | match = curr[j:j + overlap]
55 | corr = np.sum(pattern * match) / (np.sqrt(np.sum(pattern ** 2)) * np.sqrt(np.sum(match ** 2)) + 1e-8)  # parenthesized denominator, not a list
56 | if corr > max_corr:
57 | max_idx = j
58 | max_corr = corr
59 |
60 | # Apply the gain to the overlap samples
61 | start = end - overlap
62 | unfolded[start:end] *= fade_out
63 | end = start + (len(curr) - max_idx)
64 | curr[max_idx:max_idx + overlap] *= fade_in
65 | unfolded[start:end] += curr[max_idx:]
66 | return unfolded[:end]
67 |
68 |
69 | def decode_audio(file, sample_rate: int = 16000):
70 | """读取音频,主要用于兜底读取,支持各种数据格式
71 |
72 | Args:
73 | file: Path to the input file or a file-like object.
74 | sample_rate: Resample the audio to this sample rate.
75 |
76 | Returns:
77 | A float32 Numpy array.
78 | """
79 | resampler = av.audio.resampler.AudioResampler(format="s16", layout="mono", rate=sample_rate)
80 |
81 | raw_buffer = io.BytesIO()
82 | dtype = None
83 |
84 | with av.open(file, metadata_errors="ignore") as container:
85 | frames = container.decode(audio=0)
86 | frames = _ignore_invalid_frames(frames)
87 | frames = _group_frames(frames, 500000)
88 | frames = _resample_frames(frames, resampler)
89 |
90 | for frame in frames:
91 | array = frame.to_ndarray()
92 | dtype = array.dtype
93 | raw_buffer.write(array)
94 |
95 | audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)
96 |
97 | # Convert s16 back to f32.
98 | return audio.astype(np.float32) / 32768.0
99 |
100 |
101 | def _ignore_invalid_frames(frames):
102 | iterator = iter(frames)
103 |
104 | while True:
105 | try:
106 | yield next(iterator)
107 | except StopIteration:
108 | break
109 | except av.error.InvalidDataError:
110 | continue
111 |
112 |
113 | def _group_frames(frames, num_samples=None):
114 | fifo = av.audio.fifo.AudioFifo()
115 |
116 | for frame in frames:
117 | frame.pts = None # Ignore timestamp check.
118 | fifo.write(frame)
119 |
120 | if num_samples is not None and fifo.samples >= num_samples:
121 | yield fifo.read()
122 |
123 | if fifo.samples > 0:
124 | yield fifo.read()
125 |
126 |
127 | def _resample_frames(frames, resampler):
128 | # Add None to flush the resampler.
129 | for frame in itertools.chain(frames, [None]):
130 | yield from resampler.resample(frame)
131 |
132 |
133 | # Convert an audio byte stream to numpy
134 | def buf_to_float(x, n_bytes=2, dtype=np.float32):
135 | """Convert an integer buffer to floating point values.
136 | This is primarily useful when loading integer-valued wav data
137 | into numpy arrays.
138 |
139 | Parameters
140 | ----------
141 | x : np.ndarray [dtype=int]
142 | The integer-valued data buffer
143 |
144 | n_bytes : int [1, 2, 4]
145 | The number of bytes per sample in ``x``
146 |
147 | dtype : numeric type
148 | The target output type (default: 32-bit float)
149 |
150 | Returns
151 | -------
152 | x_float : np.ndarray [dtype=float]
153 | The input data buffer cast to floating point
154 | """
155 |
156 | # Invert the scale of the data
157 | scale = 1.0 / float(1 << ((8 * n_bytes) - 1))
158 |
159 | # Construct the format string
160 | fmt = "