├── .idea
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── speakerRecognition.iml
│   └── vcs.xml
├── LICENSE
├── README.md
├── code
│   ├── 0-input
│   │   ├── __init__.py
│   │   ├── get_data_list.py
│   │   ├── process_data.py
│   │   ├── unzip_tar.py
│   │   └── vad.py
│   ├── 1-development
│   │   ├── __init__.py
│   │   ├── data_feeder.py
│   │   ├── model.py
│   │   └── train.py
│   ├── 2-enrollment
│   │   ├── __init__.py
│   │   └── enrollment.py
│   ├── 3-evalution
│   │   ├── __init__.py
│   │   ├── confusion_matrix.py
│   │   └── evalution.py
│   ├── 4-roc_curve
│   │   ├── __init__.py
│   │   ├── caculate_score.py
│   │   └── plot_roc.py
│   ├── __init__.py
│   └── utils
│       ├── __init__.py
│       ├── calculate_cds.py
│       ├── csv_util.py
│       ├── process_wav.py
│       ├── sep_model_weight.py
│       ├── split_data.py
│       └── vad_util.py
├── docs
│   └── DEEP NEURAL NETWORKS FOR SMALL FOOTPRINT TEXT-DEPENDENT.pdf
├── model_test.sh
├── requirements.txt
└── results
    ├── plot_model_struct.py
    └── plots
        └── checkpoint-00484-0.99.jpg
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Speaker Recognition
2 |
3 | [PRs Welcome](https://github.com/zacharyclam/speaker_recognition/pulls)
4 | [Open Source Love](https://github.com/ellerbrock/open-source-badges)
5 | [Python](https://www.python.org/)
6 | [License](https://github.com/zacharyclam/speaker_recognition/blob/master/LICENSE)
7 |
8 | Dataset: AISHELL-ASR0009-OS1 [Download](https://pan.baidu.com/s/1dFKRLwl#list/path=%2F)
9 |
10 | The model architecture follows the paper "DEEP NEURAL NETWORKS FOR SMALL FOOTPRINT TEXT-DEPENDENT": the first two layers of the 4-layer DNN proposed in the paper are replaced with two 1-D convolutional layers. The network is trained with a Softmax classifier; for enrollment and verification the Softmax layer is removed, the DNN output is taken as the d-vector, and whether a speaker belongs to the enrolled set is decided by computing the *cosine distance*.
11 |
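A minimal sketch of the verification step described above, assuming a model exported without its Softmax layer by `code/utils/sep_model_weight.py` and preprocessed 1 s log-fbank segments of shape `(99, 40)`; the `.bin` paths below are placeholders:

```python
# hedged sketch of d-vector scoring; the .bin paths are hypothetical placeholders
import numpy as np
from keras.models import load_model

model = load_model("model/checkpoint-00484-0.99_notop.h5")  # DNN without the Softmax layer

def d_vector(bin_path):
    # each .bin file holds one 1-second log-fbank segment of shape (99, 40)
    feat = np.fromfile(bin_path, dtype=np.float64).reshape((1, 99, 40))
    return model.predict(feat)[0]

def cds(a, b):
    # normalized cosine distance score in [0, 1]; closer to 1 means more similar
    return 0.5 + 0.5 * float(a.dot(b)) / (np.linalg.norm(a) * np.linalg.norm(b))

enroll_vec = d_vector("data/bin/dev/0000/example_enroll_0.bin")  # placeholder path
probe_vec = d_vector("data/bin/dev/0000/example_probe_0.bin")    # placeholder path
print("CDS:", cds(enroll_vec, probe_vec))  # accept if the score exceeds the chosen threshold
```
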
12 | * #### Project structure
13 |
14 | ```
15 | - code
16 | -- 0-input # data preprocessing
17 | -- 1-development # model definition and training
18 | -- 2-enrollment # speaker enrollment
19 | -- 3-evalution # evaluation with stranger (impostor) speakers
20 | -- 4-roc_curve # plot the ROC curve and compute the EER and threshold
21 | -- utils
22 | - data # data storage
23 | - docs # reference paper
24 | - logs # tensorboard log files
25 | - model # saved model checkpoints
26 | - results
27 | -- features # d-vectors of the enrolled and stranger speakers computed by the model
28 | -- plots # finished ROC curve plots
29 | -- scores # scores needed for plotting the ROC curve
30 | ```
31 |
32 |
33 |
34 | * #### Training
35 |
36 | * First run VAD on the downloaded dataset; the processing code is in code/0-input/vad.py
37 |
38 | *usage:*
39 |
40 | ```shell
41 | python vad.py --save_dir="../../data/vad_data" --data_dir="path of the extracted dataset" \
42 | --category="the data split to process, e.g. train, test, dev"
43 | ```
44 |
45 | * Extract *log fbank* features from the VAD-processed data; this step uses the python_speech_features library (a short sketch follows this list)
46 |
47 | usage:
48 |
49 | ```shell
50 | python process_data.py --data_dir="../../data/vad_data" --save_dir="directory for the log fbank .bin files" \
51 | --category="the data split to process" \
52 | --validata_scale="for the training set, the fraction held out for validation, e.g. 0.05; set it to 0 for any other split"
53 | ```
54 |
55 | * Write the training and validation file paths into txt lists so the data can be shuffled and fed to the model during training
56 |
57 | usage:
58 |
59 | ```shell
60 | python get_data_list.py --save_dir="../../data/bin/" --category="validate" # validation list
61 | python get_data_list.py --save_dir="../../data/bin/" --category="train" # training list
62 | ```
63 |
64 | * Start training by running train.py
65 |
66 | usage:
67 |
68 | ```shell
69 | python train.py --batch_size=128 --num_epochs=1000 --learn_rate=0.0001
70 | ```
71 |
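As referenced in the feature-extraction step above, a minimal sketch of how a single 1 s segment becomes a `(99, 40)` *log fbank* matrix (mirroring `code/utils/process_wav.py`; the wav file name is only an example):

```python
# hedged sketch of the log-fbank feature extraction; the wav name is a placeholder
import numpy as np
import librosa
from python_speech_features import logfbank

sig, rate = librosa.load("BAC009S0002W0122.wav", sr=16000)
sig = sig / np.abs(sig).max()          # normalize to (-1, 1)
segment = sig[:rate]                   # first 1-second segment
feat = logfbank(segment, rate, winlen=0.025, winstep=0.01, nfilt=40)
print(feat.shape)                      # (99, 40): 99 frames x 40 filterbanks
feat.tofile("BAC009S0002W0122_0.bin")  # float64 binary, later read back with np.fromfile
```
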
72 | * #### Evaluating the model
73 |
74 | Run the model_test.sh script to plot the ROC curve and compute the EER. The script takes the model path as an argument, and the result files are saved under the results directory.
75 |
76 | usage:
77 |
78 | ```
79 | model_test.sh "model/checkpoint-00484-0.99.h5"
80 | ```
81 |
82 | * #### Model parameter counts
83 |
84 | Total params: 5,781,524
85 |
86 | Trainable params: 5,781,428
87 |
88 | Non-trainable params: 96
89 |
90 | The model structure diagram can be found in the results directory.
91 |
92 | * #### Experimental results:
93 |
94 | The model was trained on speech from 340 speakers. After training, 40 speakers from the dev set were enrolled: each speaker's data was split into an enrollment part and a verification part, 15 s of audio per speaker was used for enrollment, and 100 utterances of 1 s each were used for verification, counting TP and FP. 20 speakers from the test set were used for stranger (impostor) verification, with 100 utterances of 1 s each; for every utterance the highest CDS score against the enrolled set was recorded. The ROC curve plotted from these scores gives an EER of 12.2% at a threshold of 0.7824.
95 |
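For reference, a simplified sketch of how an EER and threshold like the ones above can be read off the `results/scores/score.txt` file written by `code/4-roc_curve/caculate_score.py`. It treats every non-stranger trial as a target trial, which is slightly coarser than the tp/fp bookkeeping in `plot_roc.py`:

```python
# hedged EER sketch over "score label" lines, where label is "tp"/"fp" for
# enrolled-speaker trials and "n" for stranger trials
import numpy as np

pairs = [line.split() for line in open("results/scores/score.txt")]
scores = np.array([float(s) for s, _ in pairs])
target = np.array([label != "n" for _, label in pairs])  # True for enrolled-speaker trials

fpr, fnr, thresholds = [], [], np.sort(scores)
for t in thresholds:
    accept = scores >= t
    fpr.append(np.mean(accept[~target]))   # strangers wrongly accepted
    fnr.append(np.mean(~accept[target]))   # enrolled speakers wrongly rejected

i = int(np.argmin(np.abs(np.array(fnr) - np.array(fpr))))
print("EER ~", (fnr[i] + fpr[i]) / 2, "at threshold", thresholds[i])
```
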
96 |
97 |
98 | - #### Model download:
99 |
100 | Baidu Netdisk link: https://pan.baidu.com/s/1rrVCKEIiqzZ3fTr4sKzr1Q password: 3gri
101 |
102 | *The checkpoint-00484-0.99.h5 file still contains the Softmax layer; checkpoint-00484-0.99_notop.h5 has the Softmax layer removed*
103 |
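A minimal sketch of how the Softmax-free model is produced, mirroring `code/utils/sep_model_weight.py` (the layer name `sequential_1` is the name Keras gives the stacked feature extractor in this project):

```python
# hedged sketch of exporting the "_notop" d-vector extractor
from keras.models import load_model, Model

full = load_model("model/checkpoint-00484-0.99.h5")         # includes the Softmax layer
backbone = full.get_layer("sequential_1")                   # conv/dense feature extractor
notop = Model(inputs=full.input, outputs=backbone.get_output_at(1))
notop.save("model/checkpoint-00484-0.99_notop.h5")          # used for enrollment/verification
```
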
104 | Hyperparameters used for training: batch_size=128, learn_rate=0.0001. Training was run on a single Titan GPU and took about 5 hours.
105 |
106 |
--------------------------------------------------------------------------------
/code/0-input/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : __init__.py.py
4 | # @Time : 2018/8/29 13:38
5 | # @Software : PyCharm
6 | """
7 | code is far away from bugs with the god animal protecting
8 | I love animals. They taste delicious.
9 | ┏┓ ┏┓
10 | ┏┛┻━━━┛┻┓
11 | ┃ ☃ ┃
12 | ┃ ┳┛ ┗┳ ┃
13 | ┃ ┻ ┃
14 | ┗━┓ ┏━┛
15 | ┃ ┗━━━┓
16 | ┃ 神兽保佑 ┣┓
17 | ┃ 永无BUG! ┏┛
18 | ┗┓┓┏━┳┓┏┛
19 | ┃┫┫ ┃┫┫
20 | ┗┻┛ ┗┻┛
21 | """
--------------------------------------------------------------------------------
/code/0-input/get_data_list.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : tvlistProduce.py
4 | # @Time : 2018/8/12 12:09
5 | # @Software : PyCharm
6 | import os
7 | from absl import flags, app
8 |
9 |
10 | def get_list(save_dir, category):
11 | """
12 | Write the bin file paths and their labels into a txt list
13 |
14 | :param save_dir: directory where the list file is saved
15 | :param category: the data split to process
16 | :return:
17 | """
18 | tname = os.path.join(save_dir, category + "_list.txt")
19 | data_dir = os.path.join(save_dir, category)
20 | # list the files under each speaker's subfolder
21 | sub_dir = os.listdir(data_dir)
22 |
23 | with open(tname, "w") as f:
24 | for i, subname in enumerate(sub_dir):
25 | subpath = os.path.join(data_dir, subname)
26 | for filename in os.listdir(subpath):
27 | # file path followed by the label
28 | line = os.path.join(subpath, filename) + " " + str(i) + "\n"
29 | f.write(line)
30 |
31 |
32 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
33 | FLAGS = flags.FLAGS
34 |
35 | flags.DEFINE_string("save_dir", os.path.join(root_dir, "data/bin/"), "save list to dir")
36 | flags.DEFINE_string("category", "train", "the category of data")
37 |
38 |
39 | def main(argv):
40 | get_list(os.path.abspath(FLAGS.save_dir), FLAGS.category)
41 |
42 |
43 | if __name__ == "__main__":
44 | app.run(main)
45 | # usage
46 | # python3 get_data_list.py --save_dir="../../data/bin/" --category="validate"
47 | # python3 get_data_list.py --save_dir="../../data/bin/" --category="train"
48 |
--------------------------------------------------------------------------------
/code/0-input/process_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : process_data.py
4 | # @Time : 2018/8/11 18:02
5 | # @Software : PyCharm
6 | import os
7 | from absl import app, flags
8 |
9 | try:
10 | from code.utils.process_wav import wav2fb
11 | from code.utils.split_data import split_data
12 | except ImportError:
13 | # running on Ubuntu raises ImportError, so fall back to path-based imports
14 | import sys
15 | sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
16 | from utils.process_wav import wav2fb
17 | from utils.split_data import split_data
18 |
19 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
20 |
21 | FLAGS = flags.FLAGS
22 |
23 | flags.DEFINE_string("data_dir", os.path.join(root_dir, "data/vad_data"), help="the vad data dir")
24 | flags.DEFINE_string("save_dir", os.path.join(root_dir, "data/bin"), help="save fbank vector dir")
25 | flags.DEFINE_string("category", "test", help="the category of data")
26 | flags.DEFINE_float("validata_scale", 0.05, help="the scale of validate data")
27 |
28 |
29 | def main(argv):
30 | # split the dataset
31 | train_list, validate_list = split_data(FLAGS.data_dir, FLAGS.category, FLAGS.validata_scale)
32 | # write the feature vectors into bin files
33 | # training data is split into a training set and a validation set
34 | if FLAGS.validata_scale > 0:
35 | wav2fb(train_list, os.path.abspath(FLAGS.save_dir), "train")
36 | wav2fb(validate_list, os.path.abspath(FLAGS.save_dir), "validate")
37 | # data used for enrollment and evaluation: validata_scale=0
38 | else:
39 | wav2fb(train_list, os.path.abspath(FLAGS.save_dir), FLAGS.category)
40 |
41 |
42 | if __name__ == "__main__":
43 | app.run(main)
44 | # usage
45 | # train on linux
46 | # nohup python3 -u process_data.py --data_dir="../../data/vad_data" --save_dir="../../data/bin" --category="train" > logs.out 2>&1 &
47 | # dev
48 | # python process_data.py --data_dir="../../data/vad_data" --save_dir="../../data/enrollment_evalution" \
49 | # --category="dev" --validata_scale=0.0
50 | # test
51 | # python process_data.py --data_dir="../../data/vad_data" --save_dir="../../data/enrollment_evalution" \
52 | # --category="test" --validata_scale=0.0
53 |
--------------------------------------------------------------------------------
/code/0-input/unzip_tar.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : unzip_tar.py
4 | # @Time : 2018/8/17 9:01
5 | # @Software : PyCharm
6 | import tarfile
7 | import os
8 | from absl import flags, app
9 |
10 | FLAGS = flags.FLAGS
11 |
12 | flags.DEFINE_string(
13 | "data_dir", "../../data/data_aishell/wav/",
14 | "the original tar data dir")
15 |
16 | flags.DEFINE_string(
17 | "save_dir", "../../untar_data/",
18 | "the save untar data dir")
19 |
20 |
21 | # extract a tar.gz file into the target directory
22 | def untar_dir(srcname, data_path):
23 | tar_handle = tarfile.open(srcname, "r:gz")
24 | tar_handle.extractall(data_path)
25 | tar_handle.close()
26 |
27 |
28 | def main(argv):
29 | # extract all the original audio archives
30 | for root, dir, file in os.walk(FLAGS.data_dir):
31 | for filename in file:
32 | untar_dir(os.path.join(root, filename), FLAGS.save_dir)
33 | print("finished")
34 |
35 |
36 | if __name__ == "__main__":
37 | app.run(main)
38 |
--------------------------------------------------------------------------------
/code/0-input/vad.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : vad_util.py
4 | # @Time : 2018/8/29 13:37
5 | # @Software : PyCharm
6 | import os
7 | import re
8 | from tqdm import tqdm
9 | from absl import flags, app
10 | import librosa
11 |
12 | try:
13 | from code.utils.vad_util import remove_silence
14 | except ImportError:
15 | # running on Ubuntu raises ImportError, so fall back to path-based imports
16 | import sys
17 | sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
18 | from utils.vad_util import remove_silence
19 |
20 |
21 | def get_datalist(data_dir, category):
22 | data_path = os.path.join(data_dir, category)
23 |
24 | # collect the list of wav files per speaker
25 | data_list = []
26 | for idx, dir in enumerate(os.listdir(data_path)):
27 | file_list = os.listdir(os.path.join(data_path, dir))
28 | data_list.append(([os.path.join(data_path, dir, file) for file in file_list], str(idx).zfill(4)))
29 | return data_list
30 |
31 |
32 | def vad_wav(wav_path, save_dir, sr=16000):
33 | wav_name = re.search(r"B\S+.$", wav_path).group(0)
34 | if not os.path.exists(save_dir):
35 | os.makedirs(save_dir)
36 |
37 | wav_data, rate = librosa.load(wav_path, sr)
38 |
39 | y = remove_silence(wav_data, wav_data, 139, 300)
40 | # write the silence-removed audio to file
41 | librosa.output.write_wav(os.path.join(save_dir, wav_name), y, rate)
42 |
43 |
44 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
45 | FLAGS = flags.FLAGS
46 |
47 | flags.DEFINE_string("data_dir", os.path.join(root_dir, "data"), help="the original audio data dir")
48 | flags.DEFINE_string("save_dir", os.path.join(root_dir, "data/vad_data"), help="save vad data dir")
49 | flags.DEFINE_string("category", "dev", help="the category of data")
50 |
51 |
52 | def main(args):
53 | data_list = get_datalist(FLAGS.data_dir, FLAGS.category)
54 | save_path = os.path.join(FLAGS.save_dir, FLAGS.category)
55 |
56 | for file_list, label in tqdm(data_list):
57 | for wav_path in file_list:
58 | vad_wav(wav_path, os.path.join(save_path, label))
59 |
60 |
61 | if __name__ == '__main__':
62 | app.run(main)
63 | # usage
64 | # nohup python3 -u vad.py --save_dir="../../data/vad_data" --data_dir="../../../../untar_data" --category="train" > logs.out 2>&1 &
65 |
--------------------------------------------------------------------------------
/code/1-development/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : __init__.py
4 | # @Time : 2018/8/24 17:51
5 | # @Software : PyCharm
--------------------------------------------------------------------------------
/code/1-development/data_feeder.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : data_feeder.py
4 | # @Time : 2018/8/24 17:45
5 | # @Software : PyCharm
6 | # batch generator for training
7 | import numpy as np
8 | from keras.utils import to_categorical
9 |
10 |
11 | def generate_fit(path, batch_size, classe_nums):
12 | """
13 |
14 | :param path: path to the list txt file
15 | :param batch_size:
16 | :param classe_nums: number of classes
17 | :return:
18 | """
19 | with open(path) as f:
20 | data_list = [line.strip().split(' ') for line in f]
21 | index = np.arange(len(data_list))
22 | while True:
23 | cnt = 0
24 | X = []
25 | Y = []
26 | # shuffle the data
27 | np.random.shuffle(index)
28 | data_list = np.array(data_list)[index, :]
29 | for bin_path, label in data_list:
30 | x = np.fromfile(bin_path, dtype=np.float).reshape((99, 40))
31 | # x = np.fromfile(bin_path, dtype=np.float)
32 | X.append(x)
33 | Y.append(label)
34 |
35 | cnt += 1
36 |
37 | if cnt % batch_size == 0:
38 | yield [np.array(X), np.array(to_categorical(Y, classe_nums))]
39 | X = []
40 | Y = []
41 | cnt = 0
42 |
--------------------------------------------------------------------------------
/code/1-development/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : model.py
4 | # @Time : 2018/8/24 14:49
5 | # @Software : PyCharm
6 | from keras.regularizers import l2
7 | from keras.models import Sequential
8 | from keras.layers import Dense, Dropout, Input, MaxoutDense, BatchNormalization, Conv1D, Flatten, MaxPool1D
9 | from keras.models import Model
10 | from keras.utils import plot_model
11 |
12 |
13 | def construct_model(classe_nums):
14 | model = Sequential()
15 |
16 | model.add(
17 | Conv1D(filters=256, kernel_size=3, strides=1, activation='relu', input_shape=(99, 40), name='block1_conv1'))
18 | model.add(MaxPool1D(pool_size=2, name='block1_pool1'))
19 | model.add(BatchNormalization(momentum=0.9, epsilon=1e-5, axis=1))
20 |
21 | model.add(Conv1D(filters=256, kernel_size=3, strides=1, activation='relu', name='block1_conv2'))
22 | model.add(MaxPool1D(pool_size=2, name='block1_pool2'))
23 |
24 | model.add(Flatten(name='block1_flat1'))
25 | model.add(Dropout(0.5, name='block1_drop1'))
26 |
27 | model.add(Dense(512, activation='relu', name='block2_dense2'))
28 | model.add(MaxoutDense(512, nb_feature=4, name="block2_maxout2"))
29 | model.add(Dropout(0.5, name='block2_drop2'))
30 |
31 | model.add(Dense(512, activation='relu', name='block2_dense3', kernel_regularizer=l2(1e-4)))
32 | model.add(MaxoutDense(512, nb_feature=4, name="block2_maxout3"))
33 |
34 | model.summary()
35 |
36 | model_input = Input(shape=(99, 40))
37 | features = model(model_input)
38 | extract_feature_model = Model(inputs=model_input, outputs=features)
39 |
40 | category_predict = Dense(classe_nums, activation='softmax', name="predict")(features)
41 |
42 | sr_model = Model(inputs=model_input, outputs=category_predict)
43 |
44 | plot_model(sr_model, to_file='model.png', show_shapes=True, show_layer_names=False)
45 | return extract_feature_model, sr_model
46 |
47 |
48 | if __name__ == "__main__":
49 | extract_feature_model, sr_model = construct_model(classe_nums=340)
50 |
--------------------------------------------------------------------------------
/code/1-development/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : train.py
4 | # @Time : 2018/8/12 12:49
5 | # @Software : PyCharm
6 |
7 | import os
8 | import tensorflow as tf
9 | from keras.optimizers import Adam
10 | from keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau
11 | import keras.backend as K
12 | import numpy as np
13 |
14 | from model import construct_model
15 | from data_feeder import generate_fit
16 |
17 |
18 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
19 |
20 | tf.flags.DEFINE_integer(
21 | "batch_size", default=128,
22 | help="Batch size (default: 128)")
23 |
24 | tf.flags.DEFINE_integer(
25 | "num_epochs", default=100,
26 | help="Number of training epochs (default: 100)")
27 |
28 | tf.flags.DEFINE_integer(
29 | "num_classes", default=340,
30 | help="Number of training data classes (default:340)")
31 |
32 | tf.flags.DEFINE_float(
33 | "learn_rate", default=0.0001,
34 | help="learn rate (default: 0.0001)")
35 |
36 | tf.flags.DEFINE_string(
37 | "category", default="train",
38 | help="the category of data")
39 |
40 | tf.flags.DEFINE_string(
41 | "model_dir", default=os.path.join(root_dir, "model"),
42 | help="the model file dir")
43 |
44 | tf.flags.DEFINE_string(
45 | "tensorboard_dir", default=os.path.join(root_dir, "logs"),
46 | help="the tensorboard file dir")
47 |
48 | tf.flags.DEFINE_string(
49 | "datalist_dir", default=os.path.join(root_dir, "data/bin"),
50 | help="the data list file dir")
51 |
52 | # FLAGS holds the parsed command line arguments
53 | FLAGS = tf.flags.FLAGS
54 | # trigger parsing of the flags
55 | FLAGS.flag_values_dict()
56 |
57 | if not os.path.exists(FLAGS.model_dir):
58 | os.makedirs(FLAGS.model_dir)
59 |
60 | if not os.path.exists(FLAGS.tensorboard_dir):
61 | os.makedirs(FLAGS.tensorboard_dir)
62 |
63 | # the paths
64 | train_path = os.path.join(FLAGS.datalist_dir, "train_list.txt")
65 | test_path = os.path.join(FLAGS.datalist_dir, "validate_list.txt")
66 |
67 | # count the number of samples
68 | f = open(train_path)
69 | train_nums = len(f.readlines()) # number of train samples
70 | f.close()
71 |
72 | f = open(test_path)
73 | test_nums = len(f.readlines()) # number of validation samples
74 | f.close()
75 |
76 | if __name__ == '__main__':
77 | # select which GPU to use
78 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
79 |
80 | config = tf.ConfigProto()
81 | config.gpu_options.per_process_gpu_memory_fraction = 0.90 # use at most 90% of the GPU memory
82 | K.set_session(tf.Session(config=config))
83 |
84 | # build the model
85 | extract_feature_model, sr_model = construct_model(FLAGS.num_classes)
86 |
87 | # create the optimizer
88 | opt = Adam(lr=FLAGS.learn_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
89 | sr_model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
90 |
91 | # learning-rate decay on plateau
92 | reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10,
93 | min_lr=1e-8, mode="min", cooldown=10, verbose=1)
94 |
95 | tbCallBack = TensorBoard(log_dir=FLAGS.tensorboard_dir,
96 | histogram_freq=0,
97 | write_graph=True,
98 | write_images=True)
99 |
100 | checkpoint = ModelCheckpoint(filepath=os.path.join(FLAGS.model_dir, "checkpoint-{epoch:05d}-{val_acc:.2f}.h5"),
101 | monitor='val_acc', verbose=2, save_best_only=True, mode='max')
102 |
103 | # start training
104 | sr_model.fit_generator(generate_fit(train_path, FLAGS.batch_size, FLAGS.num_classes),
105 | steps_per_epoch=np.ceil(train_nums / FLAGS.batch_size),
106 | shuffle=True,
107 | validation_data=generate_fit(test_path, FLAGS.batch_size, FLAGS.num_classes),
108 | validation_steps=np.ceil(test_nums / FLAGS.batch_size),
109 | epochs=FLAGS.num_epochs,
110 | verbose=2,
111 | callbacks=[reduce_lr, checkpoint, tbCallBack]
112 | )
113 |
114 | sr_model.save("spk.h5")
115 |
116 | # usage
117 | # nohup python3 -u train.py --batch_size=128 --num_epochs=1000 --learn_rate=0.0001 > logs.out 2>&1 &
118 |
--------------------------------------------------------------------------------
/code/2-enrollment/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : __init__.py.py
4 | # @Time : 2018/8/24 21:26
5 | # @Software : PyCharm
6 | """
7 | code is far away from bugs with the god animal protecting
8 | I love animals. They taste delicious.
9 | ┏┓ ┏┓
10 | ┏┛┻━━━┛┻┓
11 | ┃ ☃ ┃
12 | ┃ ┳┛ ┗┳ ┃
13 | ┃ ┻ ┃
14 | ┗━┓ ┏━┛
15 | ┃ ┗━━━┓
16 | ┃ 神兽保佑 ┣┓
17 | ┃ 永无BUG! ┏┛
18 | ┗┓┓┏━┳┓┏┛
19 | ┃┫┫ ┃┫┫
20 | ┗┻┛ ┗┻┛
21 | """
--------------------------------------------------------------------------------
/code/2-enrollment/enrollment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : enrollment.py
4 | # @Time : 2018/8/22 14:27
5 | # @Software : PyCharm
6 | import os
7 | from keras.models import load_model
8 | from absl import flags, app
9 | try:
10 | import sys
11 | # avoid ModuleNotFoundError caused by the working directory when run from the shell script
12 | sys.path.append(os.path.join(os.getcwd(), "code"))
13 | from utils.csv_util import features2csv
14 | except ModuleNotFoundError:
15 | from code.utils.csv_util import features2csv
16 |
17 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
18 |
19 | FLAGS = flags.FLAGS
20 |
21 | flags.DEFINE_string(
22 | "data_dir", os.path.join(root_dir, "data/enrollment_evalution"),
23 | "the enrolled data dir")
24 |
25 | flags.DEFINE_string(
26 | "save_dir", os.path.join(root_dir, "results/features"),
27 | "the save d-vector features dir")
28 |
29 | flags.DEFINE_string(
30 | "weight_path", os.path.join(root_dir, "model/spk-01000-1.00.h5"),
31 | "the model dir")
32 |
33 | flags.DEFINE_string(
34 | "category", "dev", "the category of data")
35 |
36 | flags.DEFINE_integer(
37 | "enroll_sentence_nums", 15,
38 | "the enroll sentence nums")
39 |
40 | flags.DEFINE_integer(
41 | "val_sentence_nums", 100,
42 | "the validate sentence nums")
43 |
44 |
45 | def split_sentences(data_dir, save_dir, usage, enroll_sentence_nums=20, val_sentence_nums=100):
46 | """
47 |
48 | :param data_dir: data storage path
49 | :param usage: the data split to process
50 | :param enroll_sentence_nums: number of utterances used for enrollment
51 | :param val_sentence_nums: number of utterances used for verification
52 | :return:
53 | """
54 | data_path = os.path.join(data_dir, usage)
55 | # collect the file list per speaker
56 | data_list = []
57 | for dir in os.listdir(data_path):
58 | file_list = os.listdir(os.path.join(data_path, dir))
59 | data_list.append([os.path.join(data_path, dir, file) for file in file_list])
60 |
61 | # (list , label)
62 | enroll_list = []
63 | val_list = []
64 |
65 | for i, file_list in enumerate(data_list):
66 | enroll_list.append((file_list[:enroll_sentence_nums], i))
67 | val_list.append((file_list[enroll_sentence_nums:enroll_sentence_nums + val_sentence_nums], i))
68 |
69 | with open(os.path.join(save_dir, "enroll_list.txt"), "w") as f:
70 | for (file_list, label) in enroll_list:
71 | for file in file_list:
72 | line = file + " " + str(label).zfill(4) + "\n"
73 | f.write(line)
74 |
75 | with open(os.path.join(save_dir, "validate_list.txt"), "w") as f:
76 | for (file_list, label) in val_list:
77 | for file in file_list:
78 | line = file + " " + str(label).zfill(4) + "\n"
79 | f.write(line)
80 |
81 |
82 | def main(argv):
83 | if not os.path.exists(FLAGS.save_dir):
84 | os.makedirs(FLAGS.save_dir)
85 | # load the trained model weights
86 | model = load_model(FLAGS.weight_path)
87 | # split the enrolled speakers' data and write the lists to txt
88 | split_sentences(FLAGS.data_dir, FLAGS.save_dir, FLAGS.category, enroll_sentence_nums=FLAGS.enroll_sentence_nums,
89 | val_sentence_nums=FLAGS.val_sentence_nums)
90 |
91 | # write the enrolled speakers' enrollment-utterance d-vectors to a csv file
92 | features2csv(FLAGS.save_dir, category="enroll", model=model, mean=True, sentence_nums=FLAGS.enroll_sentence_nums)
93 |
94 | # write the enrolled speakers' verification-utterance d-vectors to a csv file
95 | features2csv(FLAGS.save_dir, category="validate", model=model, mean=False, sentence_nums=FLAGS.val_sentence_nums)
96 |
97 |
98 | if __name__ == "__main__":
99 | app.run(main)
100 |
--------------------------------------------------------------------------------
/code/3-evalution/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : __init__.py.py
4 | # @Time : 2018/8/24 21:26
5 | # @Software : PyCharm
6 | """
7 | code is far away from bugs with the god animal protecting
8 | I love animals. They taste delicious.
9 | ┏┓ ┏┓
10 | ┏┛┻━━━┛┻┓
11 | ┃ ☃ ┃
12 | ┃ ┳┛ ┗┳ ┃
13 | ┃ ┻ ┃
14 | ┗━┓ ┏━┛
15 | ┃ ┗━━━┓
16 | ┃ 神兽保佑 ┣┓
17 | ┃ 永无BUG! ┏┛
18 | ┗┓┓┏━┳┓┏┛
19 | ┃┫┫ ┃┫┫
20 | ┗┻┛ ┗┻┛
21 | """
--------------------------------------------------------------------------------
/code/3-evalution/confusion_matrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
4 | from tqdm import tqdm
5 | from absl import flags, app
6 |
7 | from code.utils.calculate_cds import get_cds
8 | from code.utils.csv_util import read_features
9 |
10 |
11 | def confusion_matrix_test(features_dir, matrix_dir):
12 | validate_dict = read_features(features_dir, "validate")
13 | # confusion matrix
14 | confusion_matrix = np.zeros((40, 40), dtype=np.int16)
15 | # tally predictions
16 | for val_label, val_feat in tqdm(validate_dict):
17 | enroll_dict = read_features(features_dir, "enroll")
18 | distance = [get_cds(val_feat, enroll_feat) for _, enroll_feat in enroll_dict]
19 | predict_label = np.argmax(distance, axis=0)
20 | confusion_matrix[val_label][predict_label] += 1
21 | np.savetxt(os.path.join(matrix_dir, "confusion_matrix.csv"), confusion_matrix, fmt='%d', delimiter=",")
22 |
23 |
24 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
25 | FLAGS = flags.FLAGS
26 |
27 | flags.DEFINE_string(
28 | "features_dir", os.path.join(root_dir, "results/features"),
29 | "the enrolled data dir")
30 |
31 | flags.DEFINE_string(
32 | "matrix_dir", default=os.path.join(root_dir, "results"),
33 | help="the dir of saving confusion matrix")
34 |
35 |
36 | def main(argv):
37 | confusion_matrix_test(FLAGS.features_dir, FLAGS.matrix_dir)
38 |
39 |
40 | if __name__ == "__main__":
41 | app.run(main)
42 |
--------------------------------------------------------------------------------
/code/3-evalution/evalution.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : evalution.py
4 | # @Time : 2018/8/22 17:33
5 | # @Software : PyCharm
6 | import os
7 | from keras.models import load_model
8 | from absl import flags, app
9 | try:
10 | import sys
11 | # avoid ModuleNotFoundError caused by the working directory when run from the shell script
12 | sys.path.append(os.path.join(os.getcwd(), "code"))
13 | from utils.csv_util import features2csv
14 | except ModuleNotFoundError:
15 | from code.utils.csv_util import features2csv
16 |
17 |
18 | def split_data(data_dir, save_dir, usage, sentence_nums=20):
19 | """
20 |
21 | :param data_dir: data storage path
22 | :param save_dir: directory where the stranger list is saved
23 | :param usage: the data split to process
24 | :param sentence_nums: number of utterances taken per stranger speaker
25 | :return:
26 | """
27 | data_path = os.path.join(data_dir, usage)
28 | # collect the file list per speaker
29 | data_list = []
30 | for dir in os.listdir(data_path):
31 | file_list = os.listdir(os.path.join(data_path, dir))
32 | data_list.append([os.path.join(data_path, dir, file) for file in file_list])
33 |
34 | # (list , label)
35 | stranger_list = []
36 |
37 | for i, file_list in enumerate(data_list):
38 | stranger_list.append((file_list[:sentence_nums], i))
39 |
40 | with open(os.path.join(save_dir, "stranger_list.txt"), "w") as f:
41 | for (file_list, label) in stranger_list:
42 | for file in file_list:
43 | line = file + " " + str(label).zfill(4) + "\n"
44 | f.write(line)
45 |
46 |
47 | FLAGS = flags.FLAGS
48 |
49 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
50 |
51 | flags.DEFINE_string(
52 | "data_dir", os.path.join(root_dir, "data/enrollment_evalution"),
53 | "the enrolled data dir")
54 |
55 | flags.DEFINE_string(
56 | "weight_path", os.path.join(root_dir, "model/spk-01000-1.00.h5"),
57 | "the model dir")
58 |
59 | flags.DEFINE_string(
60 | "category", "test", "the category of data")
61 |
62 | flags.DEFINE_string(
63 | "save_dir", os.path.join(root_dir, "results/features"),
64 | "the strangers' features save dir")
65 |
66 | flags.DEFINE_integer(
67 | "stranger_sentence_nums", 100,
68 | "the stranger sentence nums")
69 |
70 |
71 | def main(argv):
72 | if not os.path.exists(FLAGS.save_dir):
73 | os.makedirs(FLAGS.save_dir)
74 | # load the trained model weights
75 | model = load_model(FLAGS.weight_path)
76 | # split the stranger speakers' data and write the list to txt
77 | split_data(FLAGS.data_dir, FLAGS.save_dir, FLAGS.category, sentence_nums=FLAGS.stranger_sentence_nums)
78 |
79 | # write the stranger speakers' utterance d-vectors to a csv file
80 | features2csv(FLAGS.save_dir, "stranger", model, mean=False,
81 | sentence_nums=FLAGS.stranger_sentence_nums)
82 |
83 |
84 | if __name__ == "__main__":
85 | app.run(main)
86 |
--------------------------------------------------------------------------------
/code/4-roc_curve/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : __init__.py.py
4 | # @Time : 2018/8/24 21:26
5 | # @Software : PyCharm
6 | """
7 | code is far away from bugs with the god animal protecting
8 | I love animals. They taste delicious.
9 | ┏┓ ┏┓
10 | ┏┛┻━━━┛┻┓
11 | ┃ ☃ ┃
12 | ┃ ┳┛ ┗┳ ┃
13 | ┃ ┻ ┃
14 | ┗━┓ ┏━┛
15 | ┃ ┗━━━┓
16 | ┃ 神兽保佑 ┣┓
17 | ┃ 永无BUG! ┏┛
18 | ┗┓┓┏━┳┓┏┛
19 | ┃┫┫ ┃┫┫
20 | ┗┻┛ ┗┻┛
21 | """
--------------------------------------------------------------------------------
/code/4-roc_curve/caculate_score.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : caculate_score.py
4 | # @Time : 2018/8/22 17:50
5 | # @Software : PyCharm
6 | import numpy as np
7 | import os
8 | from absl import flags, app
9 | from tqdm import tqdm
10 | try:
11 | import sys
12 | # avoid ModuleNotFoundError caused by the working directory when run from the shell script
13 | sys.path.append(os.path.join(os.getcwd(), "code"))
14 | from utils.csv_util import read_features
15 | from utils.calculate_cds import get_cds
16 | except ModuleNotFoundError:
17 | from code.utils.csv_util import read_features
18 | from code.utils.calculate_cds import get_cds
19 |
20 |
21 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
22 |
23 | FLAGS = flags.FLAGS
24 |
25 | flags.DEFINE_string(
26 | "features_dir", default=os.path.join(root_dir, "results/features"),
27 | help="the dir of enrollment")
28 |
29 | flags.DEFINE_string(
30 | "score_dir", default=os.path.join(root_dir, "results/scores"),
31 | help="the dir of saving score")
32 |
33 |
34 | def main(argv):
35 | if not os.path.exists(FLAGS.features_dir):
36 | os.makedirs(FLAGS.features_dir)
37 | if not os.path.exists(FLAGS.score_dir):
38 | os.makedirs(FLAGS.score_dir)
39 | # read the feature vectors
40 | validate_dict = read_features(FLAGS.features_dir, "validate")
41 | strange_dict = read_features(FLAGS.features_dir, "stranger")
42 |
43 | # compute the scores
44 | # each line is written as: score n/tp/fp
45 | with open(os.path.join(FLAGS.score_dir, "score.txt"), "w") as f:
46 | for val_label, val_feat in tqdm(validate_dict):
47 | enroll_dict = read_features(FLAGS.features_dir, "enroll")
48 | distance = [get_cds(val_feat, enroll_feat) for _, enroll_feat in enroll_dict]
49 | predict_label = np.argmax(distance, axis=0)
50 | line = str(distance[int(predict_label)]) + " "
51 | if predict_label == val_label:
52 | line += "tp\n"
53 | else:
54 | line += "fp\n"
55 | f.write(line)
56 |
57 | for stranger_label, stranger_feat in tqdm(strange_dict):
58 | enroll_dict = read_features(FLAGS.features_dir, "enroll")
59 | distance = [get_cds(stranger_feat, enroll_feat) for _, enroll_feat in enroll_dict]
60 | predict_label = np.argmax(distance, axis=0)
61 | line = str(distance[int(predict_label)]) + " n\n"
62 | f.write(line)
63 |
64 |
65 | if __name__ == "__main__":
66 | app.run(main)
67 |
--------------------------------------------------------------------------------
/code/4-roc_curve/plot_roc.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : plot_roc.py
4 | # @Time : 2018/8/22 20:58
5 | # @Software : PyCharm
6 | import os
7 | import numpy as np
8 | import matplotlib.pylab as plt
9 | from absl import app, flags
10 |
11 |
12 | def cal_rate(score_dict, thres):
13 | all_number = len(score_dict)
14 | # print all_number
15 | TP = 0
16 | FP = 0
17 | FN = 0
18 | TN = 0
19 | for score, label in score_dict:
20 | if score >= thres:
21 | score = 1
22 | if score == 1:
23 | if label == "tp":
24 | TP += 1
25 | else:
26 | FP += 1
27 | else:
28 | if label == "n":
29 | TN += 1
30 | else:
31 | FN += 1
32 | # print TP+FP+TN+FN
33 | accracy = float(TP + TN) / float(all_number)  # accuracy = (TP + TN) / total
34 | if TP + FP == 0:
35 | precision = 0
36 | else:
37 | precision = float(TP) / float(TP + FP)
38 | TPR = float(TP) / float(TP + FN)
39 | TNR = float(TN) / float(FP + TN)
40 | FNR = float(FN) / float(TP + FN)
41 | FPR = float(FP) / float(FP + TN)
42 |
43 | return accracy, precision, TPR, TNR, FNR, FPR
44 |
45 |
46 | def plot_roc(score_list, save_dir, plot_name):
47 |
48 | save_path = os.path.join(save_dir, plot_name + ".jpg")
49 | # sort the scores to use as candidate thresholds
50 | threshold_value = sorted([score for score, _ in score_list])
51 |
52 | threshold_num = len(threshold_value)
53 | accracy_array = np.zeros(threshold_num)
54 | precision_array = np.zeros(threshold_num)
55 | TPR_array = np.zeros(threshold_num)
56 | TNR_array = np.zeros(threshold_num)
57 | FNR_array = np.zeros(threshold_num)
58 | FPR_array = np.zeros(threshold_num)
59 |
60 | # calculate all the rates
61 | for thres in range(threshold_num):
62 | accracy, precision, TPR, TNR, FNR, FPR = cal_rate(score_list, threshold_value[thres])
63 | accracy_array[thres] = accracy
64 | precision_array[thres] = precision
65 | TPR_array[thres] = TPR
66 | TNR_array[thres] = TNR
67 | FNR_array[thres] = FNR
68 | FPR_array[thres] = FPR
69 |
70 | AUC = np.trapz(TPR_array, FPR_array)
71 | threshold = np.argmin(abs(FNR_array - FPR_array))
72 | EER = (FNR_array[threshold] + FPR_array[threshold]) / 2
73 | # print('EER : %f AUC : %f' % (EER, -AUC))
74 | plt.plot(FPR_array, TPR_array)
75 |
76 | plt.title('ROC')
77 | plt.xlabel('FPR')
78 | plt.ylabel('TPR')
79 | plt.text(0.2, 0, s="EER :{} AUC :{} Threshold:{}".format(round(EER, 4), round(-AUC, 4),
80 | round(threshold_value[threshold], 4)), fontsize=10)
81 | plt.legend()
82 | plt.savefig(save_path)
83 | plt.show()
84 |
85 |
86 | FLAGS = flags.FLAGS
87 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
88 |
89 | flags.DEFINE_string(
90 | "save_plot_dir", os.path.join(root_dir, "results/plots"),
91 | "the generate plots image dir")
92 |
93 | flags.DEFINE_string(
94 | "plot_name", "plt_roc_spk-01000-0.99",
95 | "the roc image's name")
96 |
97 | flags.DEFINE_string(
98 | "score_dir", os.path.join(root_dir, "results/scores"),
99 | "the score txt dir")
100 |
101 |
102 | def main(argv):
103 |
104 | if not os.path.exists(FLAGS.save_plot_dir):
105 | os.makedirs(FLAGS.save_plot_dir)
106 | if not os.path.exists(FLAGS.score_dir):
107 | os.makedirs(FLAGS.score_dir)
108 | score_list = []
109 | # read the score file
110 | with open(os.path.join(FLAGS.score_dir, "score.txt"), "r") as f:
111 | for line in f:
112 | score, label = line.split(" ")
113 | score_list.append([float(score), label.rstrip("\n")])
114 | # plot the ROC curve
115 | plot_roc(score_list, FLAGS.save_plot_dir, FLAGS.plot_name)
116 |
117 |
118 | if __name__ == "__main__":
119 | app.run(main)
120 |
--------------------------------------------------------------------------------
/code/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : __init__.py.py
4 | # @Time : 2018/8/24 21:26
5 | # @Software : PyCharm
--------------------------------------------------------------------------------
/code/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : __init__.py.py
4 | # @Time : 2018/8/24 14:10
5 | # @Software : PyCharm
--------------------------------------------------------------------------------
/code/utils/calculate_cds.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : calculate_cds.py
4 | # @Time : 2018/8/24 14:32
5 | # @Software : PyCharm
6 | import numpy as np
7 |
8 |
9 | def get_cds(a, b):
10 | """
11 | Return the normalized cosine distance score (CDS); the closer the score is to 1, the more similar the speakers
12 | :param a: shape [n, -1]
13 | :param b: shape [n, -1]
14 | :return:
15 | """
16 |
17 | num = float(a.dot(b.T))
18 | denom = np.linalg.norm(a) * np.linalg.norm(b)
19 | cds = num / denom # cosine similarity
20 | cds = 0.5 + 0.5 * cds # map from [-1, 1] to [0, 1]
21 | return cds
22 |
--------------------------------------------------------------------------------
/code/utils/csv_util.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : csv_util.py
4 | # @Time : 2018/8/24 14:20
5 | # @Software : PyCharm
6 | import os
7 | import numpy as np
8 | from tqdm import tqdm
9 | import pandas as pd
10 |
11 |
12 | def features2csv(save_dir, category, model, mean=True, sentence_nums=20):
13 | def caculate_features(fb_input):
14 | """
15 |
16 | :param fb_input: batch of fbank feature matrices
17 | :return: d-vector
18 | """
19 | features = model.predict(fb_input)
20 | features = np.array(features)
21 | if mean:
22 | # mean d-vector over the utterances, shape (d,)
23 | return np.mean(features, axis=0)
24 | else:
25 | # one d-vector per utterance, shape (N, d)
26 | return features
27 |
28 | data_path = os.path.join(save_dir, category + "_list.txt")
29 |
30 | # (label, features)
31 | people_list = []
32 | with open(data_path) as f:
33 | fbank_list = []
34 | cnt = 0
35 | for line in tqdm(f):
36 | bin_path, label = line.split(" ")
37 | fbank = np.fromfile(bin_path, dtype=np.float).reshape((99, 40))
38 | # fbank = np.fromfile(bin_path, dtype=np.float)
39 | fbank_list.append(fbank)
40 |
41 | cnt += 1
42 | if cnt % sentence_nums == 0:
43 | features = caculate_features(np.array(fbank_list))
44 | cnt = 0
45 | fbank_list = []
46 | if mean is True:
47 | people_list.append((label.rstrip("\n"), ",".join(str(feat) for feat in features)))
48 | else:
49 | # keep one feature vector per utterance (no averaging)
50 | for feature in features:
51 | people_list.append((label.rstrip("\n"), ",".join(str(feat) for feat in feature)))
52 |
53 | # write the features to a csv file
54 | features_df = pd.DataFrame(people_list, columns=["label", "features_str"])
55 | df_save_path = os.path.join(save_dir, category + "_features.csv")
56 | features_df.to_csv(df_save_path, index=False, encoding="utf-8")
57 |
58 |
59 | def read_features(csv_dir, category):
60 | """
61 |
62 | :param csv_dir:
63 | :param category:
64 | :return: label, features array
65 | """
66 | csv_path = os.path.join(csv_dir, category + "_features.csv")
67 | # read the csv file
68 | data = pd.read_csv(csv_path, encoding="utf-8")
69 | for label, features in data.values:
70 | yield label, np.array(list(map(float, features.split(","))))
71 |
--------------------------------------------------------------------------------
/code/utils/process_wav.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : logfbank.py
4 | # @Time : 2018/8/13 10:43
5 | # @Software : PyCharm
6 | import librosa
7 | from python_speech_features import logfbank
8 | import numpy as np
9 | import os
10 | from tqdm import tqdm
11 | import re
12 |
13 |
14 | def get_log_fbank(wavname, winlen=0.025, winstep=0.01, nfilt=40):
15 | sig, rate = librosa.load(wavname, sr=16000)
16 | # normalize to (-1, 1)
17 | try:
18 | sig = sig.tolist() / max(max(sig), -min(sig))
19 | except ValueError:
20 | # the wav file is empty
21 | return None
22 | sig = np.array(sig)
23 | section_nums = len(sig) // rate
24 | # split the audio into 1-second segments
25 | audio_list = [sig[partition * rate:(partition + 1)*rate]for partition in range(section_nums)]
26 |
27 | try:
28 | feat = [logfbank(audio, rate, winlen=winlen, winstep=winstep, nfilt=nfilt) for audio in audio_list]
29 | except IndexError:
30 | return None
31 | # list of (99, 40) log-fbank matrices, one per 1-second segment
32 | return feat
33 |
34 |
35 | def wav2fb(data_list, save_dir, usage):
36 | """
37 | Extract logfbank features from wav audio files and write them to .bin binary files
38 | :param data_list: list of (wav file paths, label) tuples
39 | :param save_dir: directory where the .bin files are saved
40 | :param usage: the data split being processed
41 | :return:
42 | """
43 | save_path = save_dir
44 |
45 | if not os.path.exists(save_path):
46 | os.makedirs(save_path)
47 |
48 | for fileList, label in tqdm(data_list):
49 | # create the per-speaker subdirectory
50 | sub_dir = os.path.join(save_path, usage, label)
51 | if not os.path.exists(sub_dir):
52 | os.makedirs(sub_dir)
53 | for wpath in fileList:
54 | # compute the log-fbank features
55 | fbank_feats = get_log_fbank(wpath)
56 | if fbank_feats is not None:
57 | file_name = re.search(r"B\S+", wpath).group(0)[:-4]
58 | for idx, fbank in enumerate(fbank_feats):
59 | fbank.tofile(os.path.join(sub_dir, file_name + "_{}.bin".format(idx)))
60 |
61 |
62 | if __name__ == '__main__':
63 | wavname = "BAC009S0002W0122.wav"
64 |
65 | feat = get_log_fbank(wavname)
66 | print(feat)
67 | for f in feat:
68 | print(f.shape)
69 |
--------------------------------------------------------------------------------
/code/utils/sep_model_weight.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : calculate_score.py
4 | # @Time : 2018/8/17 14:11
5 | # @Software : PyCharm
6 | from keras import Model
7 | from keras.models import load_model
8 | from absl import flags, app
9 | import os
10 | import re
11 |
12 | FLAGS = flags.FLAGS
13 |
14 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
15 |
16 | flags.DEFINE_string(
17 | "checkpoint_path", os.path.join(root_dir, "model/checkpoint-00484-0.99.h5"),
18 | "the checkpoint model path")
19 | flags.DEFINE_string(
20 | "model_save_dir", os.path.join(root_dir, "model"),
21 | "the save model dir")
22 |
23 |
24 | # strip the Softmax layer from the trained model
25 | def sep_model_weight(checkpoint_path, save_dir):
26 | model_name = re.search(r"check\S+", checkpoint_path).group(0)[:-3]
27 | model = load_model(checkpoint_path)
28 | seq = model.get_layer("sequential_1")
29 | target_model = Model(inputs=model.input, outputs=seq.get_output_at(1))
30 | target_model.save(os.path.join(save_dir, model_name + "_notop.h5"))
31 |
32 |
33 | def main(argv):
34 | sep_model_weight(FLAGS.checkpoint_path, FLAGS.model_save_dir)
35 |
36 |
37 | if __name__ == "__main__":
38 | app.run(main)
39 |
--------------------------------------------------------------------------------
/code/utils/split_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : split_data.py
4 | # @Time : 2018/8/24 20:43
5 | # @Software : PyCharm
6 | import os
7 |
8 |
9 | def split_data(data_dir, category, split_scale=0.05):
10 |
11 | data_path = os.path.join(data_dir, category)
12 |
13 | # collect the file list per speaker
14 | data_list = []
15 | for dir in os.listdir(data_path):
16 | file_list = os.listdir(os.path.join(data_path, dir))
17 | data_list.append([os.path.join(data_path, dir, file) for file in file_list])
18 |
19 | # (list , label)
20 | train_list = []
21 | validate_list = []
22 |
23 | for i, file_list in enumerate(data_list):
24 | utterence_nums = len(file_list)
25 | test_nums = utterence_nums * split_scale
26 | partition = int(utterence_nums - test_nums)
27 | train_list.append((file_list[:partition], str(i).zfill(4)))
28 | if split_scale > 0.0:
29 | validate_list.append((file_list[partition:], str(i).zfill(4)))
30 |
31 | return train_list, validate_list
32 |
33 |
34 | if __name__ == "__main__":
35 | root_dir = os.path.abspath(os.path.join(os.getcwd(), "../.."))
36 |
37 | data_dir = os.path.join(root_dir, "data")
38 | category = "vad_data"
39 | train_list, validate_list = split_data(data_dir, category, split_scale=0.05)
40 | for file_list, label in train_list:
41 | print(len(file_list), label)
42 |
--------------------------------------------------------------------------------
/code/utils/vad_util.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : vad_util.py
4 | # @Time : 2018/8/29 13:37
5 | # @Software : PyCharm
6 | import numpy as np
7 | from math import log
8 | import librosa
9 |
10 |
11 | def mse(data):
12 | return ((data ** 2).mean()) ** 0.5
13 |
14 |
15 | def dBFS(data):
16 | mse_data = mse(data)
17 | if mse_data == 0.0:
18 | return 0
19 | max_possible_val = 2 ** 16 / 2
20 | return 20 * log(mse_data / max_possible_val, 10)
21 |
22 |
23 | def cut_wav(data, per_f=150):
24 | num_f = int(len(data) / per_f)
25 | data = data[:num_f * per_f]
26 | data = data.reshape((num_f, per_f))
27 | return data
28 |
29 |
30 | def remove_silence(source_sound, common_sound, silence_threshold=140, chunk_size=148):
31 | source_sounds = cut_wav(source_sound, chunk_size)
32 | common_sounds = cut_wav(common_sound, chunk_size)
33 | y = []
34 | for i in range(common_sounds.shape[0]):
35 | db = -dBFS(common_sounds[i, ...])
36 |
37 | if db < silence_threshold:  # keep the chunk only if it is louder than the silence threshold
38 | y.append(source_sounds[i])
39 | # print("db", i, db)
40 | y = np.array(y)
41 | y = y.flatten()
42 | return y
43 |
44 |
45 | def comman(sound):
46 | abs_sound = np.abs(sound)
47 | return sound / np.max(abs_sound)
48 |
49 |
50 | if __name__ == '__main__':
51 |
52 | wav_data, rate = librosa.load("BAC009S0908W0161.wav", sr=16000)
53 |
54 | y = remove_silence(wav_data, wav_data, 139, 300)
55 | librosa.output.write_wav("c.wav", y, sr=16000)
56 |
--------------------------------------------------------------------------------
/docs/DEEP NEURAL NETWORKS FOR SMALL FOOTPRINT TEXT-DEPENDENT.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zacharyclam/speaker_recognition/ce35d3794413b85e8a470b222f899a737a5fa111/docs/DEEP NEURAL NETWORKS FOR SMALL FOOTPRINT TEXT-DEPENDENT.pdf
--------------------------------------------------------------------------------
/model_test.sh:
--------------------------------------------------------------------------------
1 | # the enrolled data dir
2 | data_dir="./data/enrollment_evalution"
3 |
4 | # save result dir
5 | save_dir="./results/features"
6 |
7 | # model output features dir
8 | features_dir="./results/features"
9 |
10 | # the checkpoint model path
11 | checkpoint_path=$1
12 | model_postfix=${checkpoint_path#m*/}
13 | model_name=${model_postfix%.h5}
14 |
15 | # model weight dir
16 | weight_path="./model/"
17 |
18 | # save caculate score txt path
19 | score_dir="./results/scores"
20 |
22 | # the plot image save dir
22 | save_plot_dir="./results/plots"
23 |
24 | # the enroll sentence nums
25 | enroll_sentence_nums=15
26 |
27 | # the validate sentence nums
28 | val_sentence_nums=100
29 |
30 | # the stranger sentence nums
31 | stranger_sentence_nums=100
32 |
33 | python -u ./code/utils/sep_model_weight.py --checkpoint_path=$checkpoint_path --model_save_dir=$weight_path
34 | echo "model weights exported"
35 |
36 | model_save_path=$weight_path$model_name"_notop.h5"
37 |
38 | python -u ./code/2-enrollment/enrollment.py --data_dir=$data_dir --save_dir=$save_dir --weight_path=$model_save_path \
39 | --category="dev" --enroll_sentence_nums=$enroll_sentence_nums \
40 | --val_sentence_nums=$val_sentence_nums
41 | echo "enrolled speaker features extracted"
42 |
43 | python -u ./code/3-evalution/evalution.py --data_dir=$data_dir --save_dir=$save_dir --weight_path=$model_save_path \
44 | --category="test" --stranger_sentence_nums=$stranger_sentence_nums
45 | echo "stranger features extracted"
46 |
47 | python -u ./code/4-roc_curve/caculate_score.py --features_dir=$features_dir --score_dir=$score_dir
48 | echo "score computation finished"
49 |
50 | python -u ./code/4-roc_curve/plot_roc.py --save_plot_dir=$save_plot_dir --score_dir=$score_dir --plot_name=$model_name
51 |
52 | echo "ROC curve plotted"
53 |
54 | read -p "Press Enter to exit"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Keras==2.1.4
2 | tensorboard>=1.7.0
3 | tensorflow>=1.7.0
4 | tensorflow-tensorboard>=0.1.8
5 | tensorlayer==1.9.0
6 | tqdm==4.23.4
7 | python_speech_features
8 | numpy
9 | pandas
10 | librosa
11 | matplotlib
--------------------------------------------------------------------------------
/results/plot_model_struct.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @File : tes.py
4 | # @Time : 2018/8/31 9:32
5 | # @Software : PyCharm
6 | from keras.regularizers import l2
7 | from keras.models import Sequential
8 | from keras.layers import Dense, Dropout, MaxoutDense, BatchNormalization, Conv1D, Flatten, MaxPool1D
9 | from keras.utils import plot_model
10 |
11 |
12 | def construct_model(classe_nums):
13 | model = Sequential()
14 |
15 | model.add(
16 | Conv1D(filters=256, kernel_size=3, strides=1, activation='relu', input_shape=(99, 40), name='block1_conv1'))
17 | model.add(MaxPool1D(pool_size=2, name='block1_pool1'))
18 | model.add(BatchNormalization(momentum=0.9, epsilon=1e-5, axis=1))
19 |
20 | model.add(Conv1D(filters=256, kernel_size=3, strides=1, activation='relu', name='block1_conv2'))
21 | model.add(MaxPool1D(pool_size=2, name='block1_pool2'))
22 |
23 | model.add(Flatten(name='block1_flat1'))
24 | model.add(Dropout(0.5, name='block1_drop1'))
25 |
26 | model.add(Dense(512, activation='relu', name='block2_dense2'))
27 | model.add(MaxoutDense(512, nb_feature=4, name="block2_maxout2"))
28 | model.add(Dropout(0.5, name='block2_drop2'))
29 |
30 | model.add(Dense(512, activation='relu', name='block2_dense3', kernel_regularizer=l2(1e-4)))
31 | model.add(MaxoutDense(512, nb_feature=4, name="block2_maxout3"))
32 | model.add(Dense(classe_nums, activation='softmax', name="predict"))
33 |
34 | # plot_model(model, to_file='model_struct.png', show_shapes=True, show_layer_names=False)
35 |
36 | model.summary()
37 |
38 |
39 | if __name__ == "__main__":
40 | construct_model(classe_nums=340)
41 |
--------------------------------------------------------------------------------
/results/plots/checkpoint-00484-0.99.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zacharyclam/speaker_recognition/ce35d3794413b85e8a470b222f899a737a5fa111/results/plots/checkpoint-00484-0.99.jpg
--------------------------------------------------------------------------------