├── .idea ├── inspectionProfiles │ └── Project_Default.xml ├── misc.xml ├── modules.xml ├── speakerRecognition.iml └── vcs.xml ├── LICENSE ├── README.md ├── code ├── 0-input │ ├── __init__.py │ ├── get_data_list.py │ ├── process_data.py │ ├── unzip_tar.py │ └── vad.py ├── 1-development │ ├── __init__.py │ ├── data_feeder.py │ ├── model.py │ └── train.py ├── 2-enrollment │ ├── __init__.py │ └── enrollment.py ├── 3-evalution │ ├── __init__.py │ ├── confusion_matrix.py │ └── evalution.py ├── 4-roc_curve │ ├── __init__.py │ ├── caculate_score.py │ └── plot_roc.py ├── __init__.py └── utils │ ├── __init__.py │ ├── calculate_cds.py │ ├── csv_util.py │ ├── process_wav.py │ ├── sep_model_weight.py │ ├── split_data.py │ └── vad_util.py ├── docs └── DEEP NEURAL NETWORKS FOR SMALL FOOTPRINT TEXT-DEPENDENT.pdf ├── model_test.sh ├── requirements.txt └── results ├── plot_model_struct.py └── plots └── checkpoint-00484-0.99.jpg /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 17 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/speakerRecognition.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 13 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Speaker Recognition 2 | 3 | [![avatar](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/zacharyclam/speaker_recognition/pulls) 4 | [![avatar](https://badges.frapsoft.com/os/v2/open-source.png?v=103)](https://github.com/ellerbrock/open-source-badges) 5 | [![avatar](https://img.shields.io/badge/Made%20with-Python-1f425f.svg)](https://www.python.org/) 6 | [![avatar](https://img.shields.io/badge/license-Apache_2-blue.svg)](https://github.com/zacharyclam/speaker_recognition/blob/master/LICENSE) 7 | 8 | Dataset: AISHELL-ASR0009-OS1 [download](https://pan.baidu.com/s/1dFKRLwl#list/path=%2F) 9 | 10 | The model structure follows the paper "DEEP NEURAL NETWORKS FOR SMALL FOOTPRINT TEXT-DEPENDENT SPEAKER VERIFICATION". In this implementation, the first two layers of the 4-layer DNN proposed in the paper are replaced with two 1-D convolutional layers. The network is trained with a Softmax classifier; at enrollment and verification time the Softmax layer is removed, the DNN output is taken as the d-vector, and the *cosine-distance* score is computed to decide whether a speaker belongs to the enrolled set. 11 | 12 | * #### Project structure 13 | 14 | ``` 15 | - code 16 | -- 0-input # data preprocessing 17 | -- 1-development # model definition and training 18 | -- 2-enrollment # enrollment 19 | -- 3-evalution # evaluation with stranger (impostor) speakers 20 | -- 4-roc_curve # plot the ROC curve and compute the EER and threshold 21 | -- utils 22 | - data # data storage 23 | - docs # reference paper 24 | - logs # tensorboard log files 25 | - model # saved model files 26 | - results 27 | -- features # d-vectors of enrolled and stranger speakers computed with the model 28 | -- plots # finished ROC curve plots 29 | -- scores # scores needed to plot the ROC curve 30 | ``` 31 | 32 | 33 | 34 | * #### Training 35 | 36 | * First, run VAD on the downloaded dataset; the processing code is in code/0-input/vad.py 37 | 38 | *usage:* 39 | 40 | ```shell 41 | python vad.py --save_dir="../../data/vad_data" --data_dir="path to the unpacked dataset" \ 42 | --category="data category to process, e.g. train, test, dev" 43 | ``` 44 | 45 | * Extract *log fbank* features from the VAD-processed data; this step uses the python_speech_features library 46 | 47 | usage: 48 | 49 | ```shell 50 | python process_data.py --data_dir="../../data/vad_data" --save_dir="directory to save the extracted log fbank .bin files" \ 51 | --category="data category to process" \ 52 | --validata_scale="for the training set, the fraction held out for validation, e.g. 0.05; set to 0 for other categories" 53 | ``` 54 | 55 | * Write the training and validation file paths into txt lists so the data can be shuffled and fed to the model during training 56 | 57 | usage: 58 | 59 | ```shell 60 | python get_data_list.py --save_dir="../../data/bin/" --category="validate" # validation list 61 | python get_data_list.py --save_dir="../../data/bin/" --category="train" # training list 62 | ``` 63 | 64 | * Run train.py to start training 65 | 66 | usage: 67 | 68 | ```shell 69 | python train.py --batch_size=128 --num_epochs=1000 --learn_rate=0.0001 70 | ``` 71 | 72 | * #### Model evaluation 73 | 74 | Run the model_test.sh script to plot the ROC curve and compute the EER; the script takes the model path as an argument, and the result files are saved under the results directory 75 | 76 | usage: 77 | 78 | ``` 79 | model_test.sh "model/checkpoint-00484-0.99.h5" 80 | ``` 81 | 82 | * #### Model parameter count 83 | 84 | Total params: 5,781,524 85 | 86 | Trainable params: 5,781,428 87 | 88 | Non-trainable params: 96 89 | 90 | The model structure diagram can be viewed in the results directory 91 | 92 | * #### Experimental results: 93 | 94 | 340 speakers were used for training. After training, 40 speakers from the dev set were enrolled; each speaker's data was split into an enrollment part and a verification part, with 15 s of audio per speaker used for enrollment and 100 one-second utterances per speaker used for verification to count TP and FP. 20 speakers from the test set were used for stranger (impostor) verification, with 100 one-second utterances per speaker, recording for each utterance the highest CDS score against the enrolled set. The ROC curve plotted from these test data gives an EER of 12.2% at a threshold of 0.7824.
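To make the scoring concrete, the decision described above (d-vector averaging for enrollment, normalized cosine distance against a threshold for verification) can be sketched in a few lines. This is a minimal illustration rather than part of the repository's pipeline: it assumes a model saved without the Softmax top (e.g. the released checkpoint-00484-0.99_notop.h5) and fbank batches of shape (N, 99, 40) prepared as in code/0-input/process_data.py; the helper names `enroll_dvector` and `verify` are made up for this example, the CDS formula mirrors code/utils/calculate_cds.py, and 0.7824 is the EER threshold reported above.

```python
import numpy as np
from keras.models import load_model

# Assumed path to the "_notop" model (Softmax layer removed); adjust to wherever the model was saved.
model = load_model("model/checkpoint-00484-0.99_notop.h5")


def get_cds(a, b):
    # Normalized cosine distance score, as in code/utils/calculate_cds.py: maps cos from [-1, 1] to [0, 1].
    cos = float(a.dot(b.T)) / (np.linalg.norm(a) * np.linalg.norm(b))
    return 0.5 + 0.5 * cos


def enroll_dvector(enroll_fbanks):
    # Average the d-vectors of one speaker's enrollment utterances: (N, 99, 40) -> a single averaged d-vector.
    return np.mean(model.predict(enroll_fbanks), axis=0)


def verify(test_fbank, enrolled_dvectors, threshold=0.7824):
    # Score a single 1-second utterance (99, 40) against every enrolled speaker and accept
    # the best match only if its CDS reaches the threshold (the reported EER threshold).
    test_dvector = model.predict(test_fbank[np.newaxis, ...])[0]
    scores = [get_cds(test_dvector, d) for d in enrolled_dvectors]
    best = int(np.argmax(scores))
    return (best, scores[best]) if scores[best] >= threshold else (None, scores[best])
```

In the actual pipeline these steps are spread across code/2-enrollment/enrollment.py, code/3-evalution/evalution.py and code/4-roc_curve/caculate_score.py, with the intermediate d-vectors cached as CSV files under results/features.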
95 | 96 | 97 | 98 | - #### Model download: 99 | 100 | Baidu Netdisk link: https://pan.baidu.com/s/1rrVCKEIiqzZ3fTr4sKzr1Q password: 3gri 101 | 102 | *checkpoint-00484-0.99.h5 includes the Softmax layer; checkpoint-00484-0.99_notop.h5 has the Softmax layer removed* 103 | 104 | Training hyperparameters: batch_size=128, learn_rate=0.0001; training on a single Titan GPU took about 5 hours 105 | 106 | -------------------------------------------------------------------------------- /code/0-input/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : __init__.py.py 4 | # @Time : 2018/8/29 13:38 5 | # @Software : PyCharm 6 | """ 7 | code is far away from bugs with the god animal protecting 8 | I love animals. They taste delicious. 9 | ┏┓ ┏┓ 10 | ┏┛┻━━━┛┻┓ 11 | ┃ ☃ ┃ 12 | ┃ ┳┛ ┗┳ ┃ 13 | ┃ ┻ ┃ 14 | ┗━┓ ┏━┛ 15 | ┃ ┗━━━┓ 16 | ┃ 神兽保佑 ┣┓ 17 | ┃ 永无BUG! ┏┛ 18 | ┗┓┓┏━┳┓┏┛ 19 | ┃┫┫ ┃┫┫ 20 | ┗┻┛ ┗┻┛ 21 | """ -------------------------------------------------------------------------------- /code/0-input/get_data_list.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : tvlistProduce.py 4 | # @Time : 2018/8/12 12:09 5 | # @Software : PyCharm 6 | import os 7 | from absl import flags, app 8 | 9 | 10 | def get_list(save_dir, category): 11 | """ 12 | 将bin文件路径及标签写入txt 13 | 14 | :param save_dir: 保存文件路径 15 | :param category: 数据集类别 16 | :return: 17 | """ 18 | tname = os.path.join(save_dir, category + "_list.txt") 19 | data_dir = os.path.join(save_dir, category) 20 | # 获取子文件夹下的文件列表 21 | sub_dir = os.listdir(data_dir) 22 | 23 | with open(tname, "w") as f: 24 | for i, subname in enumerate(sub_dir): 25 | subpath = os.path.join(data_dir, subname) 26 | for filename in os.listdir(subpath): 27 | # 文件路径 标签 28 | line = os.path.join(subpath, filename) + " " + str(i) + "\n" 29 | f.write(line) 30 | 31 | 32 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../..")) 33 | FLAGS = flags.FLAGS 34 | 35 | flags.DEFINE_string("save_dir", os.path.join(root_dir, "data/bin/"), "save list to dir") 36 | flags.DEFINE_string("category", "train", "the category of data") 37 | 38 | 39 | def main(argv): 40 | get_list(os.path.abspath(FLAGS.save_dir), FLAGS.category) 41 | 42 | 43 | if __name__ == "__main__": 44 | app.run(main) 45 | # usage 46 | # python3 get_data_list.py --save_dir="../../data/bin/" --category="validate" 47 | # python3 get_data_list.py --save_dir="../../data/bin/" --category="train" 48 | -------------------------------------------------------------------------------- /code/0-input/process_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : process_data.py 4 | # @Time : 2018/8/11 18:02 5 | # @Software : PyCharm 6 | import os 7 | from absl import app, flags 8 | 9 | try: 10 | from code.utils.process_wav import wav2fb 11 | from code.utils.split_data import split_data 12 | except ImportError: 13 | # ubuntu 下运行会出现 ImportError 14 | import sys 15 | sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".."))) 16 | from utils.process_wav import wav2fb 17 | from utils.split_data import split_data 18 | 19 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../..")) 20 | 21 | FLAGS = flags.FLAGS 22 | 23 | flags.DEFINE_string("data_dir", os.path.join(root_dir, "data/vad_data"), help="the vad data dir") 24 | flags.DEFINE_string("save_dir", os.path.join(root_dir, "data/bin"), help="save fbank vector dir") 25 |
flags.DEFINE_string("category", "test", help="the category of data") 26 | flags.DEFINE_float("validata_scale", 0.05, help="the scale of validate data") 27 | 28 | 29 | def main(argv): 30 | # 分割数据集 31 | train_list, validate_list = split_data(FLAGS.data_dir, FLAGS.category, FLAGS.validata_scale) 32 | # 将特征向量存入 bin 文件中 33 | # 处理训练数据 分为训练集和验证集 34 | if FLAGS.validata_scale > 0: 35 | wav2fb(train_list, os.path.abspath(FLAGS.save_dir), "train") 36 | wav2fb(validate_list, os.path.abspath(FLAGS.save_dir), "validate") 37 | # 处理注册和验证时的数据 validata_scale=0 38 | else: 39 | wav2fb(train_list, os.path.abspath(FLAGS.save_dir), FLAGS.category) 40 | 41 | 42 | if __name__ == "__main__": 43 | app.run(main) 44 | # usage 45 | # train on linux 46 | # nohup python3 -u process_data.py --data_dir="../../data/vad_data" --save_dir="../../data/bin" --category="train" > logs.out 2>&1 & 47 | # dev 48 | # python process_data.py --data_dir="../../data/vad_data" --save_dir="../../data/enrollment_evalution" \ 49 | # --category="dev" --validata_scale=0.0 50 | # test 51 | # python process_data.py --data_dir="../../data/vad_data" --save_dir="../../data/enrollment_evalution" \ 52 | # --category="test" --validata_scale=0.0 53 | -------------------------------------------------------------------------------- /code/0-input/unzip_tar.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : unzip_tar.py 4 | # @Time : 2018/8/17 9:01 5 | # @Software : PyCharm 6 | import tarfile 7 | import os 8 | from absl import flags, app 9 | 10 | FLAGS = flags.FLAGS 11 | 12 | flags.DEFINE_string( 13 | "data_dir", "../../data/data_aishell/wav/", 14 | "the original tar data dir") 15 | 16 | flags.DEFINE_string( 17 | "save_dir", "../../untar_data/", 18 | "the save untar data dir") 19 | 20 | 21 | # 解压tar.gz文件到文件夹 22 | def untar_dir(srcname, data_path): 23 | tar_handle = tarfile.open(srcname, "r:gz") 24 | tar_handle.extractall(data_path) 25 | tar_handle.close() 26 | 27 | 28 | def main(argv): 29 | # 将原始音频文件解压 30 | for root, dir, file in os.walk(FLAGS.data_dir): 31 | for filename in file: 32 | untar_dir(os.path.join(root, filename), FLAGS.save_dir) 33 | print("finished") 34 | 35 | 36 | if __name__ == "__main__": 37 | app.run(main) 38 | -------------------------------------------------------------------------------- /code/0-input/vad.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : vad_util.py 4 | # @Time : 2018/8/29 13:37 5 | # @Software : PyCharm 6 | import os 7 | import re 8 | from tqdm import tqdm 9 | from absl import flags, app 10 | import librosa 11 | 12 | try: 13 | from code.utils.vad_util import remove_silence 14 | except ImportError: 15 | # ubuntu 下运行会出现 ImportError 16 | import sys 17 | sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".."))) 18 | from utils.vad_util import remove_silence 19 | 20 | 21 | def get_datalist(data_dir, category): 22 | data_path = os.path.join(data_dir, category) 23 | 24 | # 获取文件列表 25 | data_list = [] 26 | for idx, dir in enumerate(os.listdir(data_path)): 27 | file_list = os.listdir(os.path.join(data_path, dir)) 28 | data_list.append(([os.path.join(data_path, dir, file) for file in file_list], str(idx).zfill(4))) 29 | return data_list 30 | 31 | 32 | def vad_wav(wav_path, save_dir, sr=16000): 33 | wav_name = re.search(r"B\S+.$", wav_path).group(0) 34 | if not os.path.exists(save_dir): 35 | os.makedirs(save_dir) 
36 | 37 | wav_data, rate = librosa.load(wav_path, sr) 38 | 39 | y = remove_silence(wav_data, wav_data, 139, 300) 40 | # 写入文件 41 | librosa.output.write_wav(os.path.join(save_dir, wav_name), y, rate) 42 | 43 | 44 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../..")) 45 | FLAGS = flags.FLAGS 46 | 47 | flags.DEFINE_string("data_dir", os.path.join(root_dir, "data"), help="the original audio data dir") 48 | flags.DEFINE_string("save_dir", os.path.join(root_dir, "data/vad_data"), help="save vad data dir") 49 | flags.DEFINE_string("category", "dev", help="the category of data") 50 | 51 | 52 | def main(args): 53 | data_list = get_datalist(FLAGS.data_dir, FLAGS.category) 54 | save_path = os.path.join(FLAGS.save_dir, FLAGS.category) 55 | 56 | for file_list, label in tqdm(data_list): 57 | for wav_path in file_list: 58 | vad_wav(wav_path, os.path.join(save_path, label)) 59 | 60 | 61 | if __name__ == '__main__': 62 | app.run(main) 63 | # usage 64 | # nohup python3 -u vad.py --save_dir="../../data/vad_data" --data_dir="../../../../untar_data" --category="train" > logs.out 2>&1 & 65 | -------------------------------------------------------------------------------- /code/1-development/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : __init__.py 4 | # @Time : 2018/8/24 17:51 5 | # @Software : PyCharm -------------------------------------------------------------------------------- /code/1-development/data_feeder.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : data_feeder.py 4 | # @Time : 2018/8/24 17:45 5 | # @Software : PyCharm 6 | # 生成数据 7 | import numpy as np 8 | from keras.utils import to_categorical 9 | 10 | 11 | def generate_fit(path, batch_size, classe_nums): 12 | """ 13 | 14 | :param path: list.txt路径 15 | :param batch_size: 16 | :param classe_nums: 类别数 17 | :return: 18 | """ 19 | with open(path) as f: 20 | data_list = [line.strip().split(' ') for line in f] 21 | index = np.arange(len(data_list)) 22 | while True: 23 | cnt = 0 24 | X = [] 25 | Y = [] 26 | # shffle data 27 | np.random.shuffle(index) 28 | data_list = np.array(data_list)[index, :] 29 | for bin_path, label in data_list: 30 | x = np.fromfile(bin_path, dtype=np.float).reshape((99, 40)) 31 | # x = np.fromfile(bin_path, dtype=np.float) 32 | X.append(x) 33 | Y.append(label) 34 | 35 | cnt += 1 36 | 37 | if cnt % batch_size == 0: 38 | yield [np.array(X), np.array(to_categorical(Y, classe_nums))] 39 | X = [] 40 | Y = [] 41 | cnt = 0 42 | -------------------------------------------------------------------------------- /code/1-development/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : model.py 4 | # @Time : 2018/8/24 14:49 5 | # @Software : PyCharm 6 | from keras.regularizers import l2 7 | from keras.models import Sequential 8 | from keras.layers import Dense, Dropout, Input, MaxoutDense, BatchNormalization, Conv1D, Flatten, MaxPool1D 9 | from keras.models import Model 10 | from keras.utils import plot_model 11 | 12 | 13 | def construct_model(classe_nums): 14 | model = Sequential() 15 | 16 | model.add( 17 | Conv1D(filters=256, kernel_size=3, strides=1, activation='relu', input_shape=(99, 40), name='block1_conv1')) 18 | model.add(MaxPool1D(pool_size=2, name='block1_pool1')) 19 | 
model.add(BatchNormalization(momentum=0.9, epsilon=1e-5, axis=1)) 20 | 21 | model.add(Conv1D(filters=256, kernel_size=3, strides=1, activation='relu', name='block1_conv2')) 22 | model.add(MaxPool1D(pool_size=2, name='block1_pool2')) 23 | 24 | model.add(Flatten(name='block1_flat1')) 25 | model.add(Dropout(0.5, name='block1_drop1')) 26 | 27 | model.add(Dense(512, activation='relu', name='block2_dense2')) 28 | model.add(MaxoutDense(512, nb_feature=4, name="block2_maxout2")) 29 | model.add(Dropout(0.5, name='block2_drop2')) 30 | 31 | model.add(Dense(512, activation='relu', name='block2_dense3', kernel_regularizer=l2(1e-4))) 32 | model.add(MaxoutDense(512, nb_feature=4, name="block2_maxout3")) 33 | 34 | model.summary() 35 | 36 | model_input = Input(shape=(99, 40)) 37 | features = model(model_input) 38 | extract_feature_model = Model(inputs=model_input, outputs=features) 39 | 40 | category_predict = Dense(classe_nums, activation='softmax', name="predict")(features) 41 | 42 | sr_model = Model(inputs=model_input, outputs=category_predict) 43 | 44 | plot_model(sr_model, to_file='model.png', show_shapes=True, show_layer_names=False) 45 | return extract_feature_model, sr_model 46 | 47 | 48 | if __name__ == "__main__": 49 | extract_feature_model, sr_model = construct_model(classe_nums=340) 50 | -------------------------------------------------------------------------------- /code/1-development/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : train.py 4 | # @Time : 2018/8/12 12:49 5 | # @Software : PyCharm 6 | 7 | import os 8 | import tensorflow as tf 9 | from keras.optimizers import Adam 10 | from keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau 11 | import keras.backend as K 12 | import numpy as np 13 | 14 | from model import construct_model 15 | from data_feeder import generate_fit 16 | 17 | 18 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../..")) 19 | 20 | tf.flags.DEFINE_integer( 21 | "batch_size", default=128, 22 | help="Batch size (default: 128)") 23 | 24 | tf.flags.DEFINE_integer( 25 | "num_epochs", default=100, 26 | help="Number of training epochs (defaule:100)") 27 | 28 | tf.flags.DEFINE_integer( 29 | "num_classes", default=340, 30 | help="Number of training data classes (default:340)") 31 | 32 | tf.flags.DEFINE_float( 33 | "learn_rate", default=0.0001, 34 | help="learn rate (default: 0.0001)") 35 | 36 | tf.flags.DEFINE_string( 37 | "category", default="train", 38 | help="the category of data") 39 | 40 | tf.flags.DEFINE_string( 41 | "model_dir", default=os.path.join(root_dir, "model"), 42 | help="the model file dir") 43 | 44 | tf.flags.DEFINE_string( 45 | "tensorboard_dir", default=os.path.join(root_dir, "logs"), 46 | help="the tensorboard file dir") 47 | 48 | tf.flags.DEFINE_string( 49 | "datalist_dir", default=os.path.join(root_dir, "data/bin"), 50 | help="the data list file dir") 51 | 52 | # FLAGS 是一个对象,保存了解析后的命令行参数 53 | FLAGS = tf.flags.FLAGS 54 | # 进行解析 55 | FLAGS.flag_values_dict() 56 | 57 | if not os.path.exists(FLAGS.model_dir): 58 | os.makedirs(FLAGS.model_dir) 59 | 60 | if not os.path.exists(FLAGS.tensorboard_dir): 61 | os.makedirs(FLAGS.tensorboard_dir) 62 | 63 | # the paths 64 | train_path = os.path.join(FLAGS.datalist_dir, "train_list.txt") 65 | test_path = os.path.join(FLAGS.datalist_dir, "validate_list.txt") 66 | 67 | # count the number of samples 68 | f = open(train_path) 69 | train_nums = len(f.readlines()) # number of train 
samples 70 | f.close() 71 | 72 | f = open(test_path) 73 | test_nums = len(f.readlines()) # number of train samples 74 | f.close() 75 | 76 | if __name__ == '__main__': 77 | # 指定使用显卡 78 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 79 | 80 | config = tf.ConfigProto() 81 | config.gpu_options.per_process_gpu_memory_fraction = 0.90 # 占用GPU90%的显存 82 | K.set_session(tf.Session(config=config)) 83 | 84 | # 创建模型 85 | extract_feature_model, sr_model = construct_model(FLAGS.num_classes) 86 | 87 | # 创建优化器 88 | opt = Adam(lr=FLAGS.learn_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08) 89 | sr_model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy']) 90 | 91 | # 学习率衰减 92 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, 93 | min_lr=1e-8, mode="min", cooldown=10, verbose=1) 94 | 95 | tbCallBack = TensorBoard(log_dir=FLAGS.tensorboard_dir, 96 | histogram_freq=0, 97 | write_graph=True, 98 | write_images=True) 99 | 100 | checkpoint = ModelCheckpoint(filepath=os.path.join(FLAGS.model_dir, "checkpoint-{epoch:05d}-{val_acc:.2f}.h5"), 101 | monitor='val_acc', verbose=2, save_best_only=True, mode='max') 102 | 103 | # 开始训练 104 | sr_model.fit_generator(generate_fit(train_path, FLAGS.batch_size, FLAGS.num_classes), 105 | steps_per_epoch=np.ceil(train_nums / FLAGS.batch_size), 106 | shuffle=True, 107 | validation_data=generate_fit(test_path, FLAGS.batch_size, FLAGS.num_classes), 108 | validation_steps=np.ceil(test_nums / FLAGS.batch_size), 109 | epochs=FLAGS.num_epochs, 110 | verbose=2, 111 | callbacks=[reduce_lr, checkpoint, tbCallBack] 112 | ) 113 | 114 | sr_model.save("spk.h5") 115 | 116 | # usage 117 | # nohup python3 -u train.py --batch_size=128 --num_epochs=1000 --learn_rate=0.0001 > logs.out 2>&1 & 118 | -------------------------------------------------------------------------------- /code/2-enrollment/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : __init__.py.py 4 | # @Time : 2018/8/24 21:26 5 | # @Software : PyCharm 6 | """ 7 | code is far away from bugs with the god animal protecting 8 | I love animals. They taste delicious. 9 | ┏┓ ┏┓ 10 | ┏┛┻━━━┛┻┓ 11 | ┃ ☃ ┃ 12 | ┃ ┳┛ ┗┳ ┃ 13 | ┃ ┻ ┃ 14 | ┗━┓ ┏━┛ 15 | ┃ ┗━━━┓ 16 | ┃ 神兽保佑 ┣┓ 17 | ┃ 永无BUG! 
┏┛ 18 | ┗┓┓┏━┳┓┏┛ 19 | ┃┫┫ ┃┫┫ 20 | ┗┻┛ ┗┻┛ 21 | """ -------------------------------------------------------------------------------- /code/2-enrollment/enrollment.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : enrollment.py 4 | # @Time : 2018/8/22 14:27 5 | # @Software : PyCharm 6 | import os 7 | from keras.models import load_model 8 | from absl import flags, app 9 | try: 10 | import sys 11 | # 防止通过脚本运行时由于路径问题出现 ModuleNotFoundError 12 | sys.path.append(os.path.join(os.getcwd(), "code")) 13 | from utils.csv_util import features2csv 14 | except ModuleNotFoundError: 15 | from code.utils.csv_util import features2csv 16 | 17 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../..")) 18 | 19 | FLAGS = flags.FLAGS 20 | 21 | flags.DEFINE_string( 22 | "data_dir", os.path.join(root_dir, "data/enrollment_evalution"), 23 | "the enrolled data dir") 24 | 25 | flags.DEFINE_string( 26 | "save_dir", os.path.join(root_dir, "results/features"), 27 | "the save d-vector features dir") 28 | 29 | flags.DEFINE_string( 30 | "weight_path", os.path.join(root_dir, "model/spk-01000-1.00.h5"), 31 | "the model dir") 32 | 33 | flags.DEFINE_string( 34 | "category", "dev", "the category of data") 35 | 36 | flags.DEFINE_integer( 37 | "enroll_sentence_nums", 15, 38 | "the enroll sentence nums") 39 | 40 | flags.DEFINE_integer( 41 | "val_sentence_nums", 100, 42 | "the validate sentence nums") 43 | 44 | 45 | def split_sentences(data_dir, save_dir, usage, enroll_sentence_nums=20, val_sentence_nums=100): 46 | """ 47 | 48 | :paramdataDir: 文件存储路径 49 | :param usage: 数据集类别 50 | :param enroll_num: 注册使用语句条数 51 | :param val_num: 验证使用语句条数 52 | :return: 53 | """ 54 | data_path = os.path.join(data_dir, usage) 55 | # 获取文件列表 56 | data_list = [] 57 | for dir in os.listdir(data_path): 58 | file_list = os.listdir(os.path.join(data_path, dir)) 59 | data_list.append([os.path.join(data_path, dir, file) for file in file_list]) 60 | 61 | # (list , label) 62 | enroll_list = [] 63 | val_list = [] 64 | 65 | for i, file_list in enumerate(data_list): 66 | enroll_list.append((file_list[:enroll_sentence_nums], i)) 67 | val_list.append((file_list[enroll_sentence_nums:enroll_sentence_nums + val_sentence_nums], i)) 68 | 69 | with open(os.path.join(save_dir, "enroll_list.txt"), "w") as f: 70 | for (file_list, label) in enroll_list: 71 | for file in file_list: 72 | line = file + " " + str(label).zfill(4) + "\n" 73 | f.write(line) 74 | 75 | with open(os.path.join(save_dir, "validate_list.txt"), "w") as f: 76 | for (file_list, label) in val_list: 77 | for file in file_list: 78 | line = file + " " + str(label).zfill(4) + "\n" 79 | f.write(line) 80 | 81 | 82 | def main(argv): 83 | if not os.path.exists(FLAGS.save_dir): 84 | os.makedirs(FLAGS.save_dir) 85 | # 导入预测模型权重 86 | model = load_model(FLAGS.weight_path) 87 | # 分割 注册人 数据集 并写入txt 88 | split_sentences(FLAGS.data_dir, FLAGS.save_dir, FLAGS.category, enroll_sentence_nums=FLAGS.enroll_sentence_nums, 89 | val_sentence_nums=FLAGS.val_sentence_nums) 90 | 91 | # 将注册人的注册语句特征写入csv文件 92 | features2csv(FLAGS.save_dir, category="enroll", model=model, mean=True, sentence_nums=FLAGS.enroll_sentence_nums) 93 | 94 | # 将注册人的验证语句特征写入csv文件 95 | features2csv(FLAGS.save_dir, category="validate", model=model, mean=False, sentence_nums=FLAGS.val_sentence_nums) 96 | 97 | 98 | if __name__ == "__main__": 99 | app.run(main) 100 | -------------------------------------------------------------------------------- 
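Unlike the other entry-point scripts in this repository (get_data_list.py, process_data.py, vad.py, train.py), enrollment.py above does not end with a usage comment. A hypothetical invocation in the same style, using the flags the script defines (the weight path shown is an assumption based on the released model named in the README), would be:

```python
# usage (hypothetical example, following the other scripts' conventions)
# python3 enrollment.py --data_dir="../../data/enrollment_evalution" --save_dir="../../results/features" \
#                       --weight_path="../../model/checkpoint-00484-0.99_notop.h5" \
#                       --category="dev" --enroll_sentence_nums=15 --val_sentence_nums=100
```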
/code/3-evalution/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : __init__.py.py 4 | # @Time : 2018/8/24 21:26 5 | # @Software : PyCharm 6 | """ 7 | code is far away from bugs with the god animal protecting 8 | I love animals. They taste delicious. 9 | ┏┓ ┏┓ 10 | ┏┛┻━━━┛┻┓ 11 | ┃ ☃ ┃ 12 | ┃ ┳┛ ┗┳ ┃ 13 | ┃ ┻ ┃ 14 | ┗━┓ ┏━┛ 15 | ┃ ┗━━━┓ 16 | ┃ 神兽保佑 ┣┓ 17 | ┃ 永无BUG! ┏┛ 18 | ┗┓┓┏━┳┓┏┛ 19 | ┃┫┫ ┃┫┫ 20 | ┗┻┛ ┗┻┛ 21 | """ -------------------------------------------------------------------------------- /code/3-evalution/confusion_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | from tqdm import tqdm 5 | from absl import flags, app 6 | 7 | from code.utils.calculate_cds import get_cds 8 | from code.utils.csv_util import read_features 9 | 10 | 11 | def confusion_matrix_test(features_dir, matrix_dir): 12 | validate_dict = read_features(features_dir, "validate") 13 | # 混淆矩阵 14 | confusion_matrix = np.zeros((40, 40), dtype=np.int16) 15 | # 统计 16 | for val_label, val_feat in tqdm(validate_dict): 17 | enroll_dict = read_features(features_dir, "enroll") 18 | distance = [get_cds(val_feat, enroll_feat) for _, enroll_feat in enroll_dict] 19 | predict_label = np.argmax(distance, axis=0) 20 | confusion_matrix[val_label][predict_label] += 1 21 | np.savetxt(os.path.join(matrix_dir, "confusion_matrix.csv"), confusion_matrix, fmt='%d', delimiter=",") 22 | 23 | 24 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../..")) 25 | FLAGS = flags.FLAGS 26 | 27 | flags.DEFINE_string( 28 | "features_dir", os.path.join(root_dir, "results/features"), 29 | "the enrolled data dir") 30 | 31 | flags.DEFINE_string( 32 | "matrix_dir", default=os.path.join(root_dir, "results"), 33 | help="the dir of saving confusion matrix") 34 | 35 | 36 | def main(argv): 37 | confusion_matrix_test(FLAGS.features_dir, FLAGS.matrix_dir) 38 | 39 | 40 | if __name__ == "__main__": 41 | app.run(main) 42 | -------------------------------------------------------------------------------- /code/3-evalution/evalution.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : evalution.py 4 | # @Time : 2018/8/22 17:33 5 | # @Software : PyCharm 6 | import os 7 | from keras.models import load_model 8 | from absl import flags, app 9 | try: 10 | import sys 11 | # 防止通过脚本运行时由于路径问题出现 ModuleNotFoundError 12 | sys.path.append(os.path.join(os.getcwd(), "code")) 13 | from utils.csv_util import features2csv 14 | except ModuleNotFoundError: 15 | from code.utils.csv_util import features2csv 16 | 17 | 18 | def split_data(data_dir, save_dir, usage, sentence_nums=20): 19 | """ 20 | 21 | :paramdataDir: 文件存储路径 22 | :param usage: 数据集类别 23 | :param enroll_num: 注册使用语句条数 24 | :param val_num: 验证使用语句条数 25 | :return: 26 | """ 27 | data_path = os.path.join(data_dir, usage) 28 | # 获取文件列表 29 | data_list = [] 30 | for dir in os.listdir(data_path): 31 | file_list = os.listdir(os.path.join(data_path, dir)) 32 | data_list.append([os.path.join(data_path, dir, file) for file in file_list]) 33 | 34 | # (list , label) 35 | stranger_list = [] 36 | 37 | for i, file_list in enumerate(data_list): 38 | stranger_list.append((file_list[:sentence_nums], i)) 39 | 40 | with open(os.path.join(save_dir, "stranger_list.txt"), "w") as f: 41 | for (file_list, label) in stranger_list: 42 | for file in 
file_list: 43 | line = file + " " + str(label).zfill(4) + "\n" 44 | f.write(line) 45 | 46 | 47 | FLAGS = flags.FLAGS 48 | 49 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../..")) 50 | 51 | flags.DEFINE_string( 52 | "data_dir", os.path.join(root_dir, "data/enrollment_evalution"), 53 | "the enrolled data dir") 54 | 55 | flags.DEFINE_string( 56 | "weight_path", os.path.join(root_dir, "model/spk-01000-1.00.h5"), 57 | "the model dir") 58 | 59 | flags.DEFINE_string( 60 | "category", "test", "the category of data") 61 | 62 | flags.DEFINE_string( 63 | "save_dir", os.path.join(root_dir, "results/features"), 64 | "the strangers' features save dir") 65 | 66 | flags.DEFINE_integer( 67 | "stranger_sentence_nums", 100, 68 | "the stranger sentence nums") 69 | 70 | 71 | def main(argv): 72 | if not os.path.exists(FLAGS.save_dir): 73 | os.makedirs(FLAGS.save_dir) 74 | # 导入预测模型权重 75 | model = load_model(FLAGS.weight_path) 76 | # 分割 陌生人 数据集 并写入txt 77 | split_data(FLAGS.data_dir, FLAGS.save_dir, FLAGS.category, sentence_nums=FLAGS.stranger_sentence_nums) 78 | 79 | # 将陌生人的注册语句特征写入csv文件 80 | features2csv(FLAGS.save_dir, "stranger", model, mean=False, 81 | sentence_nums=FLAGS.stranger_sentence_nums) 82 | 83 | 84 | if __name__ == "__main__": 85 | app.run(main) 86 | -------------------------------------------------------------------------------- /code/4-roc_curve/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : __init__.py.py 4 | # @Time : 2018/8/24 21:26 5 | # @Software : PyCharm 6 | """ 7 | code is far away from bugs with the god animal protecting 8 | I love animals. They taste delicious. 9 | ┏┓ ┏┓ 10 | ┏┛┻━━━┛┻┓ 11 | ┃ ☃ ┃ 12 | ┃ ┳┛ ┗┳ ┃ 13 | ┃ ┻ ┃ 14 | ┗━┓ ┏━┛ 15 | ┃ ┗━━━┓ 16 | ┃ 神兽保佑 ┣┓ 17 | ┃ 永无BUG! 
┏┛ 18 | ┗┓┓┏━┳┓┏┛ 19 | ┃┫┫ ┃┫┫ 20 | ┗┻┛ ┗┻┛ 21 | """ -------------------------------------------------------------------------------- /code/4-roc_curve/caculate_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : caculate_score.py 4 | # @Time : 2018/8/22 17:50 5 | # @Software : PyCharm 6 | import numpy as np 7 | import os 8 | from absl import flags, app 9 | from tqdm import tqdm 10 | try: 11 | import sys 12 | # 防止通过脚本运行时由于路径问题出现 ModuleNotFoundError 13 | sys.path.append(os.path.join(os.getcwd(), "code")) 14 | from utils.csv_util import read_features 15 | from utils.calculate_cds import get_cds 16 | except ModuleNotFoundError: 17 | from code.utils.csv_util import read_features 18 | from code.utils.calculate_cds import get_cds 19 | 20 | 21 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../..")) 22 | 23 | FLAGS = flags.FLAGS 24 | 25 | flags.DEFINE_string( 26 | "features_dir", default=os.path.join(root_dir, "results/features"), 27 | help="the dir of enrolllment") 28 | 29 | flags.DEFINE_string( 30 | "score_dir", default=os.path.join(root_dir, "results/scores"), 31 | help="the dir of saving score") 32 | 33 | 34 | def main(argv): 35 | if not os.path.exists(FLAGS.features_dir): 36 | os.makedirs(FLAGS.features_dir) 37 | if not os.path.exists(FLAGS.score_dir): 38 | os.makedirs(FLAGS.score_dir) 39 | # 读取特征向量 40 | validate_dict = read_features(FLAGS.features_dir, "validate") 41 | strange_dict = read_features(FLAGS.features_dir, "stranger") 42 | 43 | # 计算score 44 | # score n/tp/fp 45 | with open(os.path.join(FLAGS.score_dir, "score.txt"), "w") as f: 46 | for val_label, val_feat in tqdm(validate_dict): 47 | enroll_dict = read_features(FLAGS.features_dir, "enroll") 48 | distance = [get_cds(val_feat, enroll_feat) for _, enroll_feat in enroll_dict] 49 | predict_label = np.argmax(distance, axis=0) 50 | line = str(distance[int(predict_label)]) + " " 51 | if predict_label == val_label: 52 | line += "tp\n" 53 | else: 54 | line += "fp\n" 55 | f.write(line) 56 | 57 | for stranger_label, stranger_feat in tqdm(strange_dict): 58 | enroll_dict = read_features(FLAGS.features_dir, "enroll") 59 | distance = [get_cds(stranger_feat, enroll_feat) for _, enroll_feat in enroll_dict] 60 | predict_label = np.argmax(distance, axis=0) 61 | line = str(distance[int(predict_label)]) + " n\n" 62 | f.write(line) 63 | 64 | 65 | if __name__ == "__main__": 66 | app.run(main) 67 | -------------------------------------------------------------------------------- /code/4-roc_curve/plot_roc.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : plot_roc.py 4 | # @Time : 2018/8/22 20:58 5 | # @Software : PyCharm 6 | import os 7 | import numpy as np 8 | import matplotlib.pylab as plt 9 | from absl import app, flags 10 | 11 | 12 | def cal_rate(score_dict, thres): 13 | all_number = len(score_dict) 14 | # print all_number 15 | TP = 0 16 | FP = 0 17 | FN = 0 18 | TN = 0 19 | for score, label in score_dict: 20 | if score >= thres: 21 | score = 1 22 | if score == 1: 23 | if label == "tp": 24 | TP += 1 25 | else: 26 | FP += 1 27 | else: 28 | if label == "n": 29 | TN += 1 30 | else: 31 | FN += 1 32 | # print TP+FP+TN+FN 33 | accracy = float(TP + FP) / float(all_number) 34 | if TP + FP == 0: 35 | precision = 0 36 | else: 37 | precision = float(TP) / float(TP + FP) 38 | TPR = float(TP) / float(TP + FN) 39 | TNR = float(TN) / float(FP + 
TN) 40 | FNR = float(FN) / float(TP + FN) 41 | FPR = float(FP) / float(FP + TN) 42 | 43 | return accracy, precision, TPR, TNR, FNR, FPR 44 | 45 | 46 | def plot_roc(score_list, save_dir, plot_name): 47 | 48 | save_path = os.path.join(save_dir, plot_name + ".jpg") 49 | # 按照 score 排序 50 | threshold_value = sorted([score for score, _ in score_list]) 51 | 52 | threshold_num = len(threshold_value) 53 | accracy_array = np.zeros(threshold_num) 54 | precision_array = np.zeros(threshold_num) 55 | TPR_array = np.zeros(threshold_num) 56 | TNR_array = np.zeros(threshold_num) 57 | FNR_array = np.zeros(threshold_num) 58 | FPR_array = np.zeros(threshold_num) 59 | 60 | # calculate all the rates 61 | for thres in range(threshold_num): 62 | accracy, precision, TPR, TNR, FNR, FPR = cal_rate(score_list, threshold_value[thres]) 63 | accracy_array[thres] = accracy 64 | precision_array[thres] = precision 65 | TPR_array[thres] = TPR 66 | TNR_array[thres] = TNR 67 | FNR_array[thres] = FNR 68 | FPR_array[thres] = FPR 69 | 70 | AUC = np.trapz(TPR_array, FPR_array) 71 | threshold = np.argmin(abs(FNR_array - FPR_array)) 72 | EER = (FNR_array[threshold] + FPR_array[threshold]) / 2 73 | # print('EER : %f AUC : %f' % (EER, -AUC)) 74 | plt.plot(FPR_array, TPR_array) 75 | 76 | plt.title('ROC') 77 | plt.xlabel('FPR') 78 | plt.ylabel('TPR') 79 | plt.text(0.2, 0, s="EER :{} AUC :{} Threshold:{}".format(round(EER, 4), round(-AUC, 4), 80 | round(threshold_value[threshold], 4)), fontsize=10) 81 | plt.legend() 82 | plt.savefig(save_path) 83 | plt.show() 84 | 85 | 86 | FLAGS = flags.FLAGS 87 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../..")) 88 | 89 | flags.DEFINE_string( 90 | "save_plot_dir", os.path.join(root_dir, "results/plots"), 91 | "the generate plots image dir") 92 | 93 | flags.DEFINE_string( 94 | "plot_name", "plt_roc_spk-01000-0.99", 95 | "the roc image's name") 96 | 97 | flags.DEFINE_string( 98 | "score_dir", os.path.join(root_dir, "results/scores"), 99 | "the score txt dir") 100 | 101 | 102 | def main(argv): 103 | 104 | if not os.path.exists(FLAGS.save_plot_dir): 105 | os.makedirs(FLAGS.save_plot_dir) 106 | if not os.path.exists(FLAGS.score_dir): 107 | os.makedirs(FLAGS.score_dir) 108 | score_list = [] 109 | # 读取 score 数据文件 110 | with open(os.path.join(FLAGS.score_dir, "score.txt"), "r") as f: 111 | for line in f: 112 | score, label = line.split(" ") 113 | score_list.append([float(score), label.rstrip("\n")]) 114 | # 绘制 ROC 曲线 115 | plot_roc(score_list, FLAGS.save_plot_dir, FLAGS.plot_name) 116 | 117 | 118 | if __name__ == "__main__": 119 | app.run(main) 120 | -------------------------------------------------------------------------------- /code/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : __init__.py.py 4 | # @Time : 2018/8/24 21:26 5 | # @Software : PyCharm -------------------------------------------------------------------------------- /code/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : __init__.py.py 4 | # @Time : 2018/8/24 14:10 5 | # @Software : PyCharm -------------------------------------------------------------------------------- /code/utils/calculate_cds.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : calculate_cds.py 4 | # @Time : 2018/8/24 14:32 
5 | # @Software : PyCharm 6 | import numpy as np 7 | 8 | 9 | def get_cds(a, b): 10 | """ 11 | 返回归一化后的余弦距离,得分CDS越接近1越好 12 | :param a: shape[n,-1] 13 | :param b: shape[n, -1] 14 | :return: 15 | """ 16 | 17 | num = float(a.dot(b.T)) 18 | denom = np.linalg.norm(a) * np.linalg.norm(b) 19 | cds = num / denom # 余弦值 20 | cds = 0.5 + 0.5 * cds # 归一化 21 | return cds 22 | -------------------------------------------------------------------------------- /code/utils/csv_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : csv_util.py 4 | # @Time : 2018/8/24 14:20 5 | # @Software : PyCharm 6 | import os 7 | import numpy as np 8 | from tqdm import tqdm 9 | import pandas as pd 10 | 11 | 12 | def features2csv(save_dir, category, model, mean=True, sentence_nums=20): 13 | def caculate_features(fb_input): 14 | """ 15 | 16 | :param fb_input: fbank特征向量 17 | :return: d-vector 18 | """ 19 | features = model.predict(fb_input) 20 | features = np.array(features) 21 | if mean: 22 | # (1,256) 23 | return np.mean(features, axis=0) 24 | else: 25 | # (N,256) 26 | return features 27 | 28 | data_path = os.path.join(save_dir, category + "_list.txt") 29 | 30 | # (label, features) 31 | people_list = [] 32 | with open(data_path) as f: 33 | fbank_list = [] 34 | cnt = 0 35 | for line in tqdm(f): 36 | bin_path, label = line.split(" ") 37 | fbank = np.fromfile(bin_path, dtype=np.float).reshape((99, 40)) 38 | # fbank = np.fromfile(bin_path, dtype=np.float) 39 | fbank_list.append(fbank) 40 | 41 | cnt += 1 42 | if cnt % sentence_nums == 0: 43 | features = caculate_features(np.array(fbank_list)) 44 | cnt = 0 45 | fbank_list = [] 46 | if mean is True: 47 | people_list.append((label.rstrip("\n"), ",".join(str(feat) for feat in features))) 48 | else: 49 | # 不对特征向量求均值 50 | for feature in features: 51 | people_list.append((label.rstrip("\n"), ",".join(str(feat) for feat in feature))) 52 | 53 | # 将特征写入 csv 文件 54 | features_df = pd.DataFrame(people_list, columns=["label", "features_str"]) 55 | df_save_path = os.path.join(save_dir, category + "_features.csv") 56 | features_df.to_csv(df_save_path, index=False, encoding="utf-8") 57 | 58 | 59 | def read_features(csv_dir, category): 60 | """ 61 | 62 | :param csv_dir: 63 | :param category: 64 | :return: label, features array 65 | """ 66 | csv_path = os.path.join(csv_dir, category + "_features.csv") 67 | # 读取文件 68 | data = pd.read_csv(csv_path, encoding="utf-8") 69 | for label, features in data.values: 70 | yield label, np.array(list(map(float, features.split(",")))) 71 | -------------------------------------------------------------------------------- /code/utils/process_wav.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : logfbank.py 4 | # @Time : 2018/8/13 10:43 5 | # @Software : PyCharm 6 | import librosa 7 | from python_speech_features import logfbank 8 | import numpy as np 9 | import os 10 | from tqdm import tqdm 11 | import re 12 | 13 | 14 | def get_log_fbank(wavname, winlen=0.025, winstep=0.01, nfilt=40): 15 | sig, rate = librosa.load(wavname, sr=16000) 16 | # 归一化 (-1,1) 17 | try: 18 | sig = sig.tolist() / max(max(sig), -min(sig)) 19 | except ValueError: 20 | # 读取文件为空 21 | return None 22 | sig = np.array(sig) 23 | section_nums = len(sig) // rate 24 | # 将音频切分为1秒1段 25 | audio_list = [sig[partition * rate:(partition + 1)*rate]for partition in range(section_nums)] 26 | 27 | try: 28 | 
feat = [logfbank(audio, rate, winlen=winlen, winstep=winstep, nfilt=nfilt) for audio in audio_list] 29 | except IndexError: 30 | return None 31 | # (N,40) 32 | return feat 33 | 34 | 35 | def wav2fb(data_list, save_dir, usage): 36 | """ 37 | 将wav音频文件提取 logfbank 特征后写入.bin二进制文件 38 | :param data_list: wav文件路径 39 | :param save_dir: 保存.bin文件路径 40 | :param usage: 数据集类别 41 | :return: 42 | """ 43 | save_path = save_dir 44 | 45 | if not os.path.exists(save_path): 46 | os.makedirs(save_path) 47 | 48 | for fileList, label in tqdm(data_list): 49 | # 创建子文件夹 50 | sub_dir = os.path.join(save_path, usage, label) 51 | if not os.path.exists(sub_dir): 52 | os.makedirs(sub_dir) 53 | for wpath in fileList: 54 | # 计算 fbank 特征 55 | fbank_feats = get_log_fbank(wpath) 56 | if fbank_feats is not None: 57 | file_name = re.search(r"B\S+", wpath).group(0)[:-4] 58 | for idx, fbank in enumerate(fbank_feats): 59 | fbank.tofile(os.path.join(sub_dir, file_name + "_{}.bin".format(idx))) 60 | 61 | 62 | if __name__ == '__main__': 63 | wavname = "BAC009S0002W0122.wav" 64 | 65 | feat = get_log_fbank(wavname) 66 | print(feat) 67 | for f in feat: 68 | print(f.shape) 69 | -------------------------------------------------------------------------------- /code/utils/sep_model_weight.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : calculate_score.py 4 | # @Time : 2018/8/17 14:11 5 | # @Software : PyCharm 6 | from keras import Model 7 | from keras.models import load_model 8 | from absl import flags, app 9 | import os 10 | import re 11 | 12 | FLAGS = flags.FLAGS 13 | 14 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../..")) 15 | 16 | flags.DEFINE_string( 17 | "checkpoint_path", os.path.join(root_dir, "model/checkpoint-00484-0.99.h5"), 18 | "the checkpoint model path") 19 | flags.DEFINE_string( 20 | "model_save_dir", os.path.join(root_dir, "model"), 21 | "the save model dir") 22 | 23 | 24 | # 去掉模型Softmax层 25 | def sep_model_weight(checkpoint_path, save_dir): 26 | model_name = re.search(r"check\S+", checkpoint_path).group(0)[:-3] 27 | model = load_model(checkpoint_path) 28 | seq = model.get_layer("sequential_1") 29 | target_model = Model(inputs=model.input, outputs=seq.get_output_at(1)) 30 | target_model.save(os.path.join(save_dir, model_name + "_notop.h5")) 31 | 32 | 33 | def main(argv): 34 | sep_model_weight(FLAGS.checkpoint_path, FLAGS.model_save_dir) 35 | 36 | 37 | if __name__ == "__main__": 38 | app.run(main) 39 | -------------------------------------------------------------------------------- /code/utils/split_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : split_data.py 4 | # @Time : 2018/8/24 20:43 5 | # @Software : PyCharm 6 | import os 7 | 8 | 9 | def split_data(data_dir, category, split_scale=0.05): 10 | 11 | data_path = os.path.join(data_dir, category) 12 | 13 | # 获取文件列表 14 | data_list = [] 15 | for dir in os.listdir(data_path): 16 | file_list = os.listdir(os.path.join(data_path, dir)) 17 | data_list.append([os.path.join(data_path, dir, file) for file in file_list]) 18 | 19 | # (list , label) 20 | train_list = [] 21 | validate_list = [] 22 | 23 | for i, file_list in enumerate(data_list): 24 | utterence_nums = len(file_list) 25 | test_nums = utterence_nums * split_scale 26 | partition = int(utterence_nums - test_nums) 27 | train_list.append((file_list[:partition], str(i).zfill(4))) 28 | if 
split_scale > 0.0: 29 | validate_list.append((file_list[partition:], str(i).zfill(4))) 30 | 31 | return train_list, validate_list 32 | 33 | 34 | if __name__ == "__main__": 35 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../..")) 36 | 37 | data_dir = os.path.join(root_dir, "data") 38 | category = "vad_data" 39 | train_list, validate_list = split_data(data_dir, category, split_scale=0.05) 40 | for file_list, label in train_list: 41 | print(len(file_list), label) 42 | -------------------------------------------------------------------------------- /code/utils/vad_util.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : vad_util.py 4 | # @Time : 2018/8/29 13:37 5 | # @Software : PyCharm 6 | import numpy as np 7 | from math import log 8 | import librosa 9 | 10 | 11 | def mse(data): 12 | return ((data ** 2).mean()) ** 0.5 13 | 14 | 15 | def dBFS(data): 16 | mse_data = mse(data) 17 | if mse_data == 0.0: 18 | return 0 19 | max_possible_val = 2 ** 16 / 2 20 | return 20 * log(mse_data / max_possible_val, 10) 21 | 22 | 23 | def cut_wav(data, per_f=150): 24 | num_f = int(len(data) / per_f) 25 | data = data[:num_f * per_f] 26 | data = data.reshape((num_f, per_f)) 27 | return data 28 | 29 | 30 | def remove_silence(source_sound, common_sound, silence_threshold=140, chunk_size=148): 31 | source_sounds = cut_wav(source_sound, chunk_size) 32 | common_sounds = cut_wav(common_sound, chunk_size) 33 | y = [] 34 | for i in range(common_sounds.shape[0]): 35 | db = -dBFS(common_sounds[i, ...]) 36 | 37 | if db < silence_threshold: 38 | y.append(source_sounds[i]) 39 | # print("db", i, db) 40 | y = np.array(y) 41 | y = y.flatten() 42 | return y 43 | 44 | 45 | def comman(sound): 46 | abs_sound = np.abs(sound) 47 | return sound / np.max(abs_sound) 48 | 49 | 50 | if __name__ == '__main__': 51 | 52 | wav_data, rate = librosa.load("BAC009S0908W0161.wav", sr=16000) 53 | 54 | y = remove_silence(wav_data, wav_data, 139, 300) 55 | librosa.output.write_wav("c.wav", y, sr=16000) 56 | -------------------------------------------------------------------------------- /docs/DEEP NEURAL NETWORKS FOR SMALL FOOTPRINT TEXT-DEPENDENT.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zacharyclam/speaker_recognition/ce35d3794413b85e8a470b222f899a737a5fa111/docs/DEEP NEURAL NETWORKS FOR SMALL FOOTPRINT TEXT-DEPENDENT.pdf -------------------------------------------------------------------------------- /model_test.sh: -------------------------------------------------------------------------------- 1 | # the enrolled data dir 2 | data_dir="./data/enrollment_evalution" 3 | 4 | # save result dir 5 | save_dir="./results/features" 6 | 7 | # model output features dir 8 | features_dir="./results/features" 9 | 10 | # the checkpoint model path 11 | checkpoint_path=$1 12 | model_postfix=${checkpoint_path#m*/} 13 | model_name=${model_postfix%.h5} 14 | 15 | # model weight dir 16 | weight_path="./model/" 17 | 18 | # save caculate score txt path 19 | score_dir="./results/scores" 20 | 21 | # the polt image save dir 22 | save_plot_dir="./results/plots" 23 | 24 | # the enroll sentence nums 25 | enroll_sentence_nums=15 26 | 27 | # the validate sentence nums 28 | val_sentence_nums=100 29 | 30 | # the stranger sentence nums 31 | stranger_sentence_nums=100 32 | 33 | python -u ./code/utils/sep_model_weight.py --checkpoint_path=$checkpoint_path --model_save_dir=$weight_path 34 | 
echo "模型权重导出完成" 35 | 36 | model_save_path=$weight_path$model_name"_notop.h5" 37 | 38 | python -u ./code/2-enrollment/enrollment.py --data_dir=$data_dir --save_dir=$save_dir -weight_path=$model_save_path \ 39 | --category="dev" --enroll_sentence_nums=$enroll_sentence_nums \ 40 | --val_sentence_nums=$val_sentence_nums 41 | echo "注册人提取特征完成" 42 | 43 | python -u ./code/3-evalution/evalution.py --data_dir=$data_dir --save_dir=$save_dir -weight_path=$model_save_path \ 44 | --category="test" --stranger_sentence_nums=$stranger_sentence_nums \ 45 | echo "陌生人提取特征完成" 46 | 47 | python -u ./code/4-roc_curve/caculate_score.py --features_dir=$features_dir --score_dir=$score_dir 48 | echo "score 计算完成" 49 | 50 | python -u ./code/4-roc_curve/plot_roc.py --save_plot_dir=$save_plot_dir --score_dir=$score_dir --plot_name=$model_name 51 | 52 | echo "roc 图绘制完成" 53 | 54 | read -p "按回车键退出" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Keras==2.1.4 2 | tensorboard>=1.7.0 3 | tensorflow>=1.7.0 4 | tensorflow-tensorboard>=0.1.8 5 | tensorlayer==1.9.0 6 | tqdm==4.23.4 7 | python_speech_features 8 | numpy 9 | pandas 10 | librosa 11 | matplotlib -------------------------------------------------------------------------------- /results/plot_model_struct.py: -------------------------------------------------------------------------------- 1 | #!/usr/env/python python3 2 | # -*- coding: utf-8 -*- 3 | # @File : tes.py 4 | # @Time : 2018/8/31 9:32 5 | # @Software : PyCharm 6 | from keras.regularizers import l2 7 | from keras.models import Sequential 8 | from keras.layers import Dense, Dropout, MaxoutDense, BatchNormalization, Conv1D, Flatten, MaxPool1D 9 | from keras.utils import plot_model 10 | 11 | 12 | def construct_model(classe_nums): 13 | model = Sequential() 14 | 15 | model.add( 16 | Conv1D(filters=256, kernel_size=3, strides=1, activation='relu', input_shape=(99, 40), name='block1_conv1')) 17 | model.add(MaxPool1D(pool_size=2, name='block1_pool1')) 18 | model.add(BatchNormalization(momentum=0.9, epsilon=1e-5, axis=1)) 19 | 20 | model.add(Conv1D(filters=256, kernel_size=3, strides=1, activation='relu', name='block1_conv2')) 21 | model.add(MaxPool1D(pool_size=2, name='block1_pool2')) 22 | 23 | model.add(Flatten(name='block1_flat1')) 24 | model.add(Dropout(0.5, name='block1_drop1')) 25 | 26 | model.add(Dense(512, activation='relu', name='block2_dense2')) 27 | model.add(MaxoutDense(512, nb_feature=4, name="block2_maxout2")) 28 | model.add(Dropout(0.5, name='block2_drop2')) 29 | 30 | model.add(Dense(512, activation='relu', name='block2_dense3', kernel_regularizer=l2(1e-4))) 31 | model.add(MaxoutDense(512, nb_feature=4, name="block2_maxout3")) 32 | model.add(Dense(classe_nums, activation='softmax', name="predict")) 33 | 34 | # plot_model(model, to_file='model_struct.png', show_shapes=True, show_layer_names=False) 35 | 36 | model.summary() 37 | 38 | 39 | if __name__ == "__main__": 40 | construct_model(classe_nums=340) 41 | -------------------------------------------------------------------------------- /results/plots/checkpoint-00484-0.99.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zacharyclam/speaker_recognition/ce35d3794413b85e8a470b222f899a737a5fa111/results/plots/checkpoint-00484-0.99.jpg --------------------------------------------------------------------------------