├── .gitignore
├── LICENSE
├── README.md
├── audio_nets
│   ├── ds_cnn.py
│   ├── kws.py
│   ├── res.py
│   └── tc_resnet.py
├── common
│   ├── __init__.py
│   ├── model_loader.py
│   ├── tf_utils.py
│   └── utils.py
├── const.py
├── datasets
│   ├── __init__.py
│   ├── audio_data_wrapper.py
│   ├── augmentation_factory.py
│   ├── data_wrapper_base.py
│   ├── preprocessor_factory.py
│   └── preprocessors.py
├── evaluate_audio.py
├── execute_script.sh
├── factory
│   ├── __init__.py
│   ├── audio_nets.py
│   └── base.py
├── figure
│   └── main_figure.png
├── freeze.py
├── helper
│   ├── __init__.py
│   ├── base.py
│   ├── evaluator.py
│   └── trainer.py
├── metrics
│   ├── __init__.py
│   ├── base.py
│   ├── funcs.py
│   ├── manager.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── base_ops.py
│   │   ├── misc_ops.py
│   │   ├── non_tensor_ops.py
│   │   └── tensor_ops.py
│   ├── parser.py
│   └── summaries.py
├── requirements
│   ├── py36-common.txt
│   ├── py36-cpu.txt
│   └── py36-gpu.txt
├── scripts
│   ├── commands
│   │   ├── DSCNNLModel-0_mfcc_10_4020_0.0000_adam_l3.sh
│   │   ├── DSCNNMModel-0_mfcc_10_4020_0.0000_adam_l3.sh
│   │   ├── DSCNNSModel-0_mfcc_10_4020_0.0000_adam_l3.sh
│   │   ├── KWSfpool3-0_mfcc_40_4020_0.0000_adam_l3.sh
│   │   ├── KWSfstride4-0_mfcc_40_4020_0.0000_adam_l2.sh
│   │   ├── Res15Model-0_mfcc_40_3010_0.00001_adam_s1.sh
│   │   ├── Res15NarrowModel-0_mfcc_40_3010_0.00001_adam_s1.sh
│   │   ├── Res8Model-0_mfcc_40_3010_0.00001_adam_s1.sh
│   │   ├── Res8NarrowModel-0_mfcc_40_3010_0.00001_adam_s1.sh
│   │   ├── TCResNet14Model-1.0_mfcc_40_3010_0.001_mom_l1.sh
│   │   ├── TCResNet14Model-1.5_mfcc_40_3010_0.001_mom_l1.sh
│   │   ├── TCResNet2D8Model-1.0_mfcc_40_3010_0.001_mom_l1.sh
│   │   ├── TCResNet2D8PoolModel-1.0_mfcc_40_3010_0.001_mom_l1.sh
│   │   ├── TCResNet8Model-1.0_mfcc_40_3010_0.001_mom_l1.sh
│   │   └── TCResNet8Model-1.5_mfcc_40_3010_0.001_mom_l1.sh
│   └── google_speech_commmands_dataset_to_our_format.py
├── speech_commands_dataset
│   ├── README.md
│   ├── download_and_split.sh
│   ├── google_speech_commmands_dataset_to_our_format_with_split.py
│   ├── test.txt
│   ├── train.txt
│   └── valid.txt
├── tflite_tools
│   ├── benchmark_model_r1.13_official
│   └── run_benchmark.sh
└── train_audio.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.gitignore.io/api/git,python,jetbrains
2 | # Edit at https://www.gitignore.io/?templates=git,python,jetbrains
3 |
4 | ### Git ###
5 | # Created by git for backups. To disable backups in Git:
6 | # $ git config --global mergetool.keepBackup false
7 | *.orig
8 |
9 | # Created by git when using merge tools for conflicts
10 | *.BACKUP.*
11 | *.BASE.*
12 | *.LOCAL.*
13 | *.REMOTE.*
14 | *_BACKUP_*.txt
15 | *_BASE_*.txt
16 | *_LOCAL_*.txt
17 | *_REMOTE_*.txt
18 |
19 | ### JetBrains ###
20 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
21 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
22 |
23 | # User-specific stuff
24 | .idea/**/workspace.xml
25 | .idea/**/tasks.xml
26 | .idea/**/usage.statistics.xml
27 | .idea/**/dictionaries
28 | .idea/**/shelf
29 |
30 | # Generated files
31 | .idea/**/contentModel.xml
32 |
33 | # Sensitive or high-churn files
34 | .idea/**/dataSources/
35 | .idea/**/dataSources.ids
36 | .idea/**/dataSources.local.xml
37 | .idea/**/sqlDataSources.xml
38 | .idea/**/dynamic.xml
39 | .idea/**/uiDesigner.xml
40 | .idea/**/dbnavigator.xml
41 |
42 | # Gradle
43 | .idea/**/gradle.xml
44 | .idea/**/libraries
45 |
46 | # Gradle and Maven with auto-import
47 | # When using Gradle or Maven with auto-import, you should exclude module files,
48 | # since they will be recreated, and may cause churn. Uncomment if using
49 | # auto-import.
50 | # .idea/modules.xml
51 | # .idea/*.iml
52 | # .idea/modules
53 |
54 | # CMake
55 | cmake-build-*/
56 |
57 | # Mongo Explorer plugin
58 | .idea/**/mongoSettings.xml
59 |
60 | # File-based project format
61 | *.iws
62 |
63 | # IntelliJ
64 | out/
65 |
66 | # mpeltonen/sbt-idea plugin
67 | .idea_modules/
68 |
69 | # JIRA plugin
70 | atlassian-ide-plugin.xml
71 |
72 | # Cursive Clojure plugin
73 | .idea/replstate.xml
74 |
75 | # Crashlytics plugin (for Android Studio and IntelliJ)
76 | com_crashlytics_export_strings.xml
77 | crashlytics.properties
78 | crashlytics-build.properties
79 | fabric.properties
80 |
81 | # Editor-based Rest Client
82 | .idea/httpRequests
83 |
84 | # Android studio 3.1+ serialized cache file
85 | .idea/caches/build_file_checksums.ser
86 |
87 | ### JetBrains Patch ###
88 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
89 |
90 | # *.iml
91 | # modules.xml
92 | # .idea/misc.xml
93 | # *.ipr
94 |
95 | # Sonarlint plugin
96 | .idea/sonarlint
97 |
98 | ### Python ###
99 | # Byte-compiled / optimized / DLL files
100 | __pycache__/
101 | *.py[cod]
102 | *$py.class
103 |
104 | # C extensions
105 | *.so
106 |
107 | # Distribution / packaging
108 | .Python
109 | build/
110 | develop-eggs/
111 | dist/
112 | downloads/
113 | eggs/
114 | .eggs/
115 | lib/
116 | lib64/
117 | parts/
118 | sdist/
119 | var/
120 | wheels/
121 | pip-wheel-metadata/
122 | share/python-wheels/
123 | *.egg-info/
124 | .installed.cfg
125 | *.egg
126 | MANIFEST
127 |
128 | # PyInstaller
129 | # Usually these files are written by a python script from a template
130 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
131 | *.manifest
132 | *.spec
133 |
134 | # Installer logs
135 | pip-log.txt
136 | pip-delete-this-directory.txt
137 |
138 | # Unit test / coverage reports
139 | htmlcov/
140 | .tox/
141 | .nox/
142 | .coverage
143 | .coverage.*
144 | .cache
145 | nosetests.xml
146 | coverage.xml
147 | *.cover
148 | .hypothesis/
149 | .pytest_cache/
150 |
151 | # Translations
152 | *.mo
153 | *.pot
154 |
155 | # Django stuff:
156 | *.log
157 | local_settings.py
158 | db.sqlite3
159 |
160 | # Flask stuff:
161 | instance/
162 | .webassets-cache
163 |
164 | # Scrapy stuff:
165 | .scrapy
166 |
167 | # Sphinx documentation
168 | docs/_build/
169 |
170 | # PyBuilder
171 | target/
172 |
173 | # Jupyter Notebook
174 | .ipynb_checkpoints
175 |
176 | # IPython
177 | profile_default/
178 | ipython_config.py
179 |
180 | # pyenv
181 | .python-version
182 |
183 | # pipenv
184 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
185 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
186 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
187 | # install all needed dependencies.
188 | #Pipfile.lock
189 |
190 | # celery beat schedule file
191 | celerybeat-schedule
192 |
193 | # SageMath parsed files
194 | *.sage.py
195 |
196 | # Environments
197 | .env
198 | .venv
199 | env/
200 | venv/
201 | ENV/
202 | env.bak/
203 | venv.bak/
204 |
205 | # Spyder project settings
206 | .spyderproject
207 | .spyproject
208 |
209 | # Rope project settings
210 | .ropeproject
211 |
212 | # mkdocs documentation
213 | /site
214 |
215 | # mypy
216 | .mypy_cache/
217 | .dmypy.json
218 | dmypy.json
219 |
220 | # Pyre type checker
221 | .pyre/
222 |
223 | # End of https://www.gitignore.io/api/git,python,jetbrains
224 |
225 | google_speech_commands/
226 | work/
227 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Temporal Convolution for Real-time Keyword Spotting on Mobile Devices
2 |
3 |
4 |
5 |
6 |
7 | ## Abstract
8 | Keyword spotting (KWS) plays a critical role in enabling speech-based user interactions on smart devices.
9 | Recent developments in the field of deep learning have led to wide adoption of convolutional neural networks (CNNs) in KWS systems due to their exceptional accuracy and robustness.
10 | The main challenge faced by KWS systems is the trade-off between high accuracy and low latency.
11 | Unfortunately, there has been little quantitative analysis of the actual latency of KWS models on mobile devices.
12 | This is especially concerning since conventional convolution-based KWS approaches are known to require a large number of operations to attain an adequate level of performance.
13 |
14 | In this paper, we propose a temporal convolution for real-time KWS on mobile devices.
15 | Unlike most of the 2D convolution-based KWS approaches that require a deep architecture to fully capture both low- and high-frequency domains, we exploit temporal convolutions with a compact ResNet architecture.
16 | On the Google Speech Commands dataset, we achieve more than **385x** speedup on Google Pixel 1 while surpassing the accuracy of the state-of-the-art model.
17 | In addition, we release the implementation of the proposed and baseline models, including an end-to-end pipeline for training models and evaluating them on mobile devices.
18 |
19 |
20 | ## Requirements
21 |
22 | * Python 3.6+
23 | * TensorFlow 1.13.1
24 |
25 | ## Installation
26 |
27 | ```bash
28 | git clone https://github.com/hyperconnect/TC-ResNet.git
29 | pip3 install -r requirements/py36-[gpu|cpu].txt
30 | ```
31 |
32 | ## Dataset
33 |
34 | For evaluating the proposed and the baseline models we use [Google Speech Commands Dataset](https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html).
35 |
36 | ### Google Speech Commands Dataset
37 |
38 | Follow the instructions in [speech_commands_dataset/](https://github.com/hyperconnect/TC-ResNet/tree/master/speech_commands_dataset).
39 |
40 | ## How to run
41 |
42 | Scripts to reproduce the training and evaluation procedures discussed in the paper are located in `scripts/commands/`. After training a model, you can generate a `.tflite` file by following the instructions below.
43 |
44 | To train the TCResNet8Model-1.0 model, run:
45 |
46 | ```
47 | ./scripts/commands/TCResNet8Model-1.0_mfcc_40_3010_0.001_mom_l1.sh
48 | ```
49 |
50 | To freeze the trained model checkpoint into `.pb` file, run:
51 |
52 | ```
53 | python freeze.py --checkpoint_path work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/TCResNet8Model-XXX --output_name output/softmax --output_type softmax --preprocess_method no_preprocessing --height 49 --width 40 --channels 1 --num_classes 12 TCResNet8Model --width_multiplier 1.0
54 | ```
55 |
56 | To convert the `.pb` file into `.tflite` file, run:
57 |
58 | ```
59 | tflite_convert --graph_def_file=work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/TCResNet8Model-XXX.pb --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --output_file=work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/TCResNet8Model-XXX.tflite --inference_type=FLOAT --inference_input_type=FLOAT --input_arrays=input --output_arrays=output/softmax --allow_custom_ops
60 | ```
61 |
62 | As shown in the commands above, you need to properly set `height`, `width`, the model, and model-specific arguments (e.g. `width_multiplier`).
63 | For more information, please refer to the scripts in `scripts/commands/`.
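
If you want to run the whole pipeline in one go, the snippet below is a minimal sketch (not part of the repository) that chains the three steps above with Python's `subprocess`; the `TCResNet8Model-XXX` stem is a placeholder that you must replace with the actual checkpoint name produced by training.

```python
# Sketch: chain train -> freeze -> tflite_convert using the commands above.
import subprocess

WORK_DIR = "work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1"
CKPT_STEM = f"{WORK_DIR}/TCResNet8Model-XXX"  # placeholder: fill in the checkpoint step

# 1. Train
subprocess.run(["./scripts/commands/TCResNet8Model-1.0_mfcc_40_3010_0.001_mom_l1.sh"], check=True)

# 2. Freeze the checkpoint into a .pb file
subprocess.run([
    "python", "freeze.py",
    "--checkpoint_path", CKPT_STEM,
    "--output_name", "output/softmax", "--output_type", "softmax",
    "--preprocess_method", "no_preprocessing",
    "--height", "49", "--width", "40", "--channels", "1", "--num_classes", "12",
    "TCResNet8Model", "--width_multiplier", "1.0",
], check=True)

# 3. Convert the .pb file into a .tflite file
subprocess.run([
    "tflite_convert",
    f"--graph_def_file={CKPT_STEM}.pb",
    "--input_format=TENSORFLOW_GRAPHDEF", "--output_format=TFLITE",
    f"--output_file={CKPT_STEM}.tflite",
    "--inference_type=FLOAT", "--inference_input_type=FLOAT",
    "--input_arrays=input", "--output_arrays=output/softmax",
    "--allow_custom_ops",
], check=True)
```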
64 |
65 | ## Benchmark tool
66 |
67 | [Android Debug Bridge](https://developer.android.com/studio/command-line/adb.html) (`adb`) is required to run the Android benchmark tool (`tflite_tools/run_benchmark.sh`).
68 | `adb` is part of the [Android SDK Platform Tools](https://developer.android.com/studio/releases/platform-tools); download it [here](https://developer.android.com/studio/releases/platform-tools.html) and follow the installation instructions.
69 |
70 | ### 1. Connect Android device to your computer
71 |
72 | ### 2. Check if connection is established
73 |
74 | Run the following command.
75 |
76 | ```bash
77 | adb devices
78 | ```
79 |
80 | You should see output similar to the one below.
81 | The device ID will, of course, differ.
82 |
83 | ```
84 | List of devices attached
85 | FA77M0304573 device
86 | ```
87 |
88 | ### 3. Run benchmark
89 |
90 | Go to `tflite_tools/`, place the TF Lite model you want to benchmark (e.g. `mobilenet_v1_1.0_224.tflite`) there, and execute the following command.
91 | You can pass the optional parameter `cpu_mask` to set the [CPU affinity](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/lite/tools/benchmark#reducing-variance-between-runs-on-android).
92 |
93 |
94 | ```bash
95 | ./run_benchmark.sh TCResNet_14Model-1.5.tflite [cpu_mask]
96 | ```
97 |
98 |
99 | If everything goes well, you should see output similar to the one below.
100 | The important measurement of this benchmark is the `avg=5701.96` part.
101 | The number represents the average latency of a single inference, measured in microseconds.
102 |
103 | ```
104 | ./run_benchmark.sh TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite 3
105 | benchmark_model_r1.13_official: 1 file pushed. 22.1 MB/s (1265528 bytes in 0.055s)
106 | TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite: 1 file pushed. 25.0 MB/s (1217136 bytes in 0.046s)
107 | >>> run_benchmark_summary TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite 3
108 | TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite > count=50 first=5734 curr=5801 min=4847 max=6516 avg=5701.96 std=210
109 | ```
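
Since the summary is a single line, extracting the latency programmatically is straightforward; the snippet below is a small sketch (not part of the repository) that pulls out the `avg` value and converts it to milliseconds.

```python
# Sketch: parse the avg latency (microseconds) from the benchmark summary line.
import re

summary = ("TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite > "
           "count=50 first=5734 curr=5801 min=4847 max=6516 avg=5701.96 std=210")

avg_us = float(re.search(r"avg=([\d.]+)", summary).group(1))
print(f"average latency: {avg_us:.2f} us ({avg_us / 1000:.2f} ms)")  # ~5.70 ms
```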
110 |
111 | ## License
112 |
113 | [Apache License 2.0](LICENSE)
114 |
--------------------------------------------------------------------------------
/audio_nets/ds_cnn.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from functools import partial
3 |
4 | import tensorflow as tf
5 |
6 | slim = tf.contrib.slim
7 |
8 |
9 | _DEFAULT = {
10 | "type": None, # ["conv", "separable"]
11 | "kernel": [3, 3],
12 | "stride": [1, 1],
13 | "depth": None,
14 | "scope": None,
15 | }
16 |
17 | _Block = namedtuple("Block", _DEFAULT.keys())
18 | Block = partial(_Block, **_DEFAULT)
19 |
20 | S_NET_DEF = [
21 | Block(type="conv", depth=64, kernel=[10, 4], stride=[2, 2], scope="conv_1"),
22 | Block(type="separable", depth=64, kernel=[3, 3], stride=[1, 1], scope="conv_ds_1"),
23 | Block(type="separable", depth=64, kernel=[3, 3], stride=[1, 1], scope="conv_ds_2"),
24 | Block(type="separable", depth=64, kernel=[3, 3], stride=[1, 1], scope="conv_ds_3"),
25 | Block(type="separable", depth=64, kernel=[3, 3], stride=[1, 1], scope="conv_ds_4"),
26 | ]
27 |
28 | M_NET_DEF = [
29 | Block(type="conv", depth=172, kernel=[10, 4], stride=[2, 1], scope="conv_1"),
30 | Block(type="separable", depth=172, kernel=[3, 3], stride=[2, 2], scope="conv_ds_1"),
31 | Block(type="separable", depth=172, kernel=[3, 3], stride=[1, 1], scope="conv_ds_2"),
32 | Block(type="separable", depth=172, kernel=[3, 3], stride=[1, 1], scope="conv_ds_3"),
33 | Block(type="separable", depth=172, kernel=[3, 3], stride=[1, 1], scope="conv_ds_4"),
34 | ]
35 |
36 | L_NET_DEF = [
37 | Block(type="conv", depth=276, kernel=[10, 4], stride=[2, 1], scope="conv_1"),
38 | Block(type="separable", depth=276, kernel=[3, 3], stride=[2, 2], scope="conv_ds_1"),
39 | Block(type="separable", depth=276, kernel=[3, 3], stride=[1, 1], scope="conv_ds_2"),
40 | Block(type="separable", depth=276, kernel=[3, 3], stride=[1, 1], scope="conv_ds_3"),
41 | Block(type="separable", depth=276, kernel=[3, 3], stride=[1, 1], scope="conv_ds_4"),
42 | Block(type="separable", depth=276, kernel=[3, 3], stride=[1, 1], scope="conv_ds_5"),
43 | ]
44 |
45 |
46 | def _depthwise_separable_conv(inputs, num_pwc_filters, kernel_size, stride):
47 | """ Helper function to build the depth-wise separable convolution layer."""
48 | # skip pointwise by setting num_outputs=None
49 | depthwise_conv = slim.separable_convolution2d(inputs,
50 | num_outputs=None,
51 | stride=stride,
52 | depth_multiplier=1,
53 | kernel_size=kernel_size,
54 | scope="depthwise_conv")
55 |
56 | bn = slim.batch_norm(depthwise_conv, scope="dw_batch_norm")
57 | pointwise_conv = slim.conv2d(bn,
58 | num_pwc_filters,
59 | kernel_size=[1, 1],
60 | scope="pointwise_conv")
61 | bn = slim.batch_norm(pointwise_conv, scope="pw_batch_norm")
62 | return bn
63 |
64 |
65 | def parse_block(input_net, block):
66 | if block.type == "conv":
67 | net = slim.conv2d(
68 | input_net,
69 | num_outputs=block.depth,
70 | kernel_size=block.kernel,
71 | stride=block.stride,
72 | scope=block.scope
73 | )
74 | net = slim.batch_norm(net, scope=f"{block.scope}/batch_norm")
75 | elif block.type == "separable":
76 | with tf.variable_scope(block.scope):
77 | net = _depthwise_separable_conv(
78 | input_net,
79 | block.depth,
80 | kernel_size=block.kernel,
81 | stride=block.stride
82 | )
83 | else:
84 | raise ValueError(f"Block type {block.type} is not supported!")
85 |
86 | return net
87 |
88 |
89 | def DSCNN(inputs, num_classes, net_def, scope="DSCNN"):
90 | endpoints = dict()
91 |
92 | with tf.variable_scope(scope):
93 | net = inputs
94 | for block in net_def:
95 | net = parse_block(net, block)
96 |
97 | net = slim.avg_pool2d(net, kernel_size=net.shape[1:3], stride=1, scope="avg_pool")
98 | net = tf.squeeze(net, [1, 2], name="SpatialSqueeze")
99 | logits = slim.fully_connected(net, num_classes, activation_fn=None, scope="fc1")
100 |
101 | return logits, endpoints
102 |
103 |
104 | def DSCNN_arg_scope(is_training):
105 | batch_norm_params = {
106 | "is_training": is_training,
107 | "decay": 0.96,
108 | "activation_fn": tf.nn.relu,
109 | }
110 |
111 | with slim.arg_scope([slim.conv2d, slim.separable_convolution2d],
112 | activation_fn=None,
113 | weights_initializer=slim.initializers.xavier_initializer(),
114 | biases_initializer=slim.init_ops.zeros_initializer()):
115 | with slim.arg_scope([slim.batch_norm], **batch_norm_params):
116 | with slim.arg_scope([slim.dropout],
117 | is_training=is_training) as scope:
118 | return scope
119 |
--------------------------------------------------------------------------------
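
Note: the sketch below (not part of the repository) shows one way the blocks and network defined above might be used, assuming TensorFlow 1.x with `tf.contrib.slim` and a 49x10x1 MFCC input, as suggested by the DS-CNN `mfcc_10` training scripts.

```python
# Hypothetical usage sketch for audio_nets/ds_cnn.py.
import tensorflow as tf
from audio_nets import ds_cnn

slim = tf.contrib.slim

# A custom block definition in the same style as S_NET_DEF / M_NET_DEF / L_NET_DEF.
net_def = [
    ds_cnn.Block(type="conv", depth=64, kernel=[10, 4], stride=[2, 2], scope="conv_1"),
    ds_cnn.Block(type="separable", depth=64, kernel=[3, 3], stride=[1, 1], scope="conv_ds_1"),
]

inputs = tf.placeholder(tf.float32, [None, 49, 10, 1], name="input")  # assumed MFCC shape
with slim.arg_scope(ds_cnn.DSCNN_arg_scope(is_training=False)):
    logits, _ = ds_cnn.DSCNN(inputs, num_classes=12, net_def=net_def)
probs = tf.nn.softmax(logits, name="softmax")
```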
/audio_nets/res.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | slim = tf.contrib.slim
4 |
5 |
6 | def conv_relu_bn(inputs, num_outputs, kernel_size, stride, idx, use_dilation, bn=False):
7 | scope = f"conv{idx}"
8 | with tf.variable_scope(scope, values=[inputs]):
9 | if use_dilation:
10 | assert stride == 1
11 | rate = int(2**(idx // 3))
12 | net = slim.conv2d(inputs,
13 | num_outputs=num_outputs,
14 | kernel_size=kernel_size,
15 | stride=stride,
16 | rate=rate)
17 | else:
18 | net = slim.conv2d(inputs,
19 | num_outputs=num_outputs,
20 | kernel_size=kernel_size,
21 | stride=stride)
22 | # conv + relu are done
23 | if bn:
24 | net = slim.batch_norm(net, scope=f"{scope}_bn")
25 |
26 | return net
27 |
28 |
29 | def resnet(inputs, num_classes, num_layers, num_channels, pool_size, use_dilation, scope="Res"):
30 | """Re-implement https://github.com/castorini/honk/blob/master/utils/model.py"""
31 | endpoints = dict()
32 |
33 | with tf.variable_scope(scope):
34 | net = slim.conv2d(inputs, num_channels, kernel_size=3, stride=1, scope="f_conv")
35 |
36 | if pool_size:
37 | net = slim.avg_pool2d(net, kernel_size=pool_size, stride=1, scope="avg_pool0")
38 |
39 | # block
40 | num_blocks = num_layers // 2
41 | idx = 0
42 | for i in range(num_blocks):
43 | layer_in = net
44 |
45 | net = conv_relu_bn(net, num_outputs=num_channels, kernel_size=3, stride=1, idx=idx,
46 | use_dilation=use_dilation, bn=True)
47 | idx += 1
48 |
49 | net = conv_relu_bn(net, num_outputs=num_channels, kernel_size=3, stride=1, idx=(2 * i + 1),
50 | use_dilation=use_dilation, bn=False)
51 | idx += 1
52 |
53 | net += layer_in
54 | net = slim.batch_norm(net, scope=f"conv{2 * i + 1}_bn")
55 |
56 | if num_layers % 2 != 0:
57 | net = conv_relu_bn(net, num_outputs=num_channels, kernel_size=3, stride=1, idx=idx,
58 | use_dilation=use_dilation, bn=True)
59 |
60 | # last
61 | net = slim.avg_pool2d(net, kernel_size=net.shape[1:3], stride=1, scope="avg_pool1")
62 |
63 | logits = slim.conv2d(net, num_classes, 1, activation_fn=None, scope="fc")
64 | logits = tf.reshape(logits, shape=(-1, logits.shape[3]), name="squeeze_logit")
65 |
66 | return logits, endpoints
67 |
68 |
69 | def Res8(inputs, num_classes):
70 | return resnet(inputs,
71 | num_classes,
72 | num_layers=6,
73 | num_channels=45,
74 | pool_size=[4, 3],
75 | use_dilation=False)
76 |
77 |
78 | def Res8Narrow(inputs, num_classes):
79 | return resnet(inputs,
80 | num_classes,
81 | num_layers=6,
82 | num_channels=19,
83 | pool_size=[4, 3],
84 | use_dilation=False)
85 |
86 |
87 | def Res15(inputs, num_classes):
88 | return resnet(inputs,
89 | num_classes,
90 | num_layers=13,
91 | num_channels=45,
92 | pool_size=None,
93 | use_dilation=True)
94 |
95 |
96 | def Res15Narrow(inputs, num_classes):
97 | return resnet(inputs,
98 | num_classes,
99 | num_layers=13,
100 | num_channels=19,
101 | pool_size=None,
102 | use_dilation=True)
103 |
104 |
105 | def Res_arg_scope(is_training, weight_decay=0.00001):
106 | batch_norm_params = {
107 | "is_training": is_training,
108 | "center": False,
109 | "scale": False,
110 | "decay": 0.997,
111 | "fused": True,
112 | }
113 |
114 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
115 | weights_initializer=slim.initializers.xavier_initializer(),
116 | weights_regularizer=slim.l2_regularizer(weight_decay),
117 | activation_fn=tf.nn.relu,
118 | biases_initializer=None,
119 | normalizer_fn=None,
120 | padding="SAME",
121 | ):
122 | with slim.arg_scope([slim.batch_norm], **batch_norm_params) as scope:
123 | return scope
124 |
--------------------------------------------------------------------------------
/audio_nets/tc_resnet.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | slim = tf.contrib.slim
4 |
5 |
6 | def tc_resnet(inputs, num_classes, n_blocks, n_channels, scope, debug_2d=False, pool=None):
7 | endpoints = dict()
8 | L = inputs.shape[1]
9 | C = inputs.shape[2]
10 |
11 | assert n_blocks == len(n_channels) - 1
12 |
13 | with tf.variable_scope(scope):
14 | if debug_2d:
15 | conv_kernel = first_conv_kernel = [3, 3]
16 | else:
17 | inputs = tf.reshape(inputs, [-1, L, 1, C]) # [N, L, 1, C]
18 | first_conv_kernel = [3, 1]
19 | conv_kernel = [9, 1]
20 |
21 | net = slim.conv2d(inputs, num_outputs=n_channels[0], kernel_size=first_conv_kernel, stride=1, scope="conv0")
22 |
23 | if pool is not None:
24 | net = slim.avg_pool2d(net, kernel_size=pool[0], stride=pool[1], scope="avg_pool_0")
25 |
26 | n_channels = n_channels[1:]
27 |
28 | for i, n in zip(range(n_blocks), n_channels):
29 | with tf.variable_scope(f"block{i}"):
30 | if n != net.shape[-1]:
31 | stride = 2
32 | layer_in = slim.conv2d(net, num_outputs=n, kernel_size=1, stride=stride, scope=f"down")
33 | else:
34 | layer_in = net
35 | stride = 1
36 |
37 | net = slim.conv2d(net, num_outputs=n, kernel_size=conv_kernel, stride=stride, scope=f"conv{i}_0")
38 | net = slim.conv2d(net, num_outputs=n, kernel_size=conv_kernel, stride=1, scope=f"conv{i}_1",
39 | activation_fn=None)
40 | net += layer_in
41 | net = tf.nn.relu(net)
42 |
43 | net = slim.avg_pool2d(net, kernel_size=net.shape[1:3], stride=1, scope="avg_pool")
44 |
45 | net = slim.dropout(net)
46 |
47 | logits = slim.conv2d(net, num_classes, 1, activation_fn=None, normalizer_fn=None, scope="fc")
48 | logits = tf.reshape(logits, shape=(-1, logits.shape[3]), name="squeeze_logit")
49 |
50 | ranges = slim.conv2d(net, 2, 1, activation_fn=None, normalizer_fn=None, scope="fc2")
51 | ranges = tf.reshape(ranges, shape=(-1, ranges.shape[3]), name="squeeze_logit2")
52 | endpoints["ranges"] = tf.sigmoid(ranges)
53 |
54 | return logits, endpoints
55 |
56 |
57 | def TCResNet8(inputs, num_classes, width_multiplier=1.0, scope="TCResNet8"):
58 | n_blocks = 3
59 | n_channels = [16, 24, 32, 48]
60 | n_channels = [int(x * width_multiplier) for x in n_channels]
61 |
62 | return tc_resnet(inputs, num_classes, n_blocks, n_channels, scope=scope)
63 |
64 |
65 | def TCResNet14(inputs, num_classes, width_multiplier=1.0, scope="TCResNet14"):
66 | n_blocks = 6
67 | n_channels = [16, 24, 24, 32, 32, 48, 48]
68 | n_channels = [int(x * width_multiplier) for x in n_channels]
69 |
70 | return tc_resnet(inputs, num_classes, n_blocks, n_channels, scope=scope)
71 |
72 |
73 | def ResNet2D8(inputs, num_classes, width_multiplier=1.0, scope="ResNet2D8"):
74 | n_blocks = 3
75 | n_channels = [16, 24, 32, 48]
76 | n_channels = [int(x * width_multiplier) for x in n_channels]
77 |
78 | # inputs: [N, L, C, 1]
79 | f = inputs.get_shape().as_list()[2]
80 | c1, c2 = n_channels[0:2]
81 | first_c = int((3 * f * c1 + 10 * c1 * c2) / (9 + 10 * c2))
82 | n_channels[0] = first_c
83 |
84 | return tc_resnet(inputs, num_classes, n_blocks, n_channels, scope=scope, debug_2d=True)
85 |
86 |
87 |
88 | def ResNet2D8Pool(inputs, num_classes, width_multiplier=1.0, scope="ResNet2D8Pool"):
89 | n_blocks = 3
90 | n_channels = [16, 24, 32, 48]
91 | n_channels = [int(x * width_multiplier) for x in n_channels]
92 |
93 | # inputs: [N, L, C, 1]
94 | f = inputs.get_shape().as_list()[2]
95 | c1, c2 = n_channels[0:2]
96 | first_c = int((3 * f * c1 + 10 * c1 * c2) / (9 + 10 * c2))
97 | n_channels[0] = first_c
98 |
99 | return tc_resnet(inputs, num_classes, n_blocks, n_channels, debug_2d=True, pool=([4, 4], 4), scope=scope)
100 |
101 |
102 | def TCResNet_arg_scope(is_training, weight_decay=0.001, keep_prob=0.5):
103 | batch_norm_params = {
104 | "is_training": is_training,
105 | "center": True,
106 | "scale": True,
107 | "decay": 0.997,
108 | "fused": True,
109 | }
110 |
111 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
112 | weights_initializer=slim.initializers.xavier_initializer(),
113 | weights_regularizer=slim.l2_regularizer(weight_decay),
114 | activation_fn=tf.nn.relu,
115 | biases_initializer=None,
116 | normalizer_fn=slim.batch_norm,
117 | padding="SAME",
118 | ):
119 | with slim.arg_scope([slim.batch_norm], **batch_norm_params):
120 | with slim.arg_scope([slim.dropout],
121 | keep_prob=keep_prob,
122 | is_training=is_training) as scope:
123 | return scope
124 |
--------------------------------------------------------------------------------
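
Note: the sketch below (not part of the repository) illustrates how `TCResNet8` might be invoked, assuming TensorFlow 1.x with `tf.contrib.slim` and the 49x40x1 MFCC input used by the `freeze.py` command in the README. Internally, `tc_resnet` reshapes the input to `[N, time, 1, freq]`, so the `[k, 1]` kernels convolve along the time axis only.

```python
# Hypothetical usage sketch for audio_nets/tc_resnet.py.
import tensorflow as tf
from audio_nets import tc_resnet

slim = tf.contrib.slim

inputs = tf.placeholder(tf.float32, [None, 49, 40, 1], name="input")  # 49 frames x 40 MFCCs
with slim.arg_scope(tc_resnet.TCResNet_arg_scope(is_training=False)):
    logits, endpoints = tc_resnet.TCResNet8(inputs, num_classes=12, width_multiplier=1.0)
probs = tf.nn.softmax(logits, name="softmax")
```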
/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/common/__init__.py
--------------------------------------------------------------------------------
/common/model_loader.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import tensorflow as tf
4 | from tensorflow.python import pywrap_tensorflow
5 | from tensorflow.python.ops import control_flow_ops
6 | from tensorflow.python.platform import gfile
7 |
8 | from common.utils import format_text, get_logger
9 |
10 |
11 | class Ckpt():
12 | def __init__(
13 | self,
14 | session: tf.Session,
15 | variables_to_restore=None,
16 | include_scopes: str="",
17 | exclude_scopes: str="",
18 | ignore_missing_vars: bool=False,
19 | use_ema: bool=False,
20 | ema_decay: float=None,
21 | logger=None,
22 | ):
23 | self.session = session
24 | self.variables_to_restore = self._get_variables_to_restore(
25 | variables_to_restore,
26 | include_scopes,
27 | exclude_scopes,
28 | use_ema,
29 | ema_decay,
30 | )
31 | self.ignore_missing_vars = ignore_missing_vars
32 | self.logger = logger
33 | if logger is None:
34 | self.logger = get_logger("Ckpt Loader")
35 |
36 | # variables to save reusable info from previous load
37 | self.has_previous_info = False
38 | self.grouped_vars = {}
39 | self.placeholders = {}
40 | self.assign_op = None
41 |
42 | def _get_variables_to_restore(
43 | self,
44 | variables_to_restore=None,
45 | include_scopes: str="",
46 | exclude_scopes: str="",
47 | use_ema: bool=False,
48 | ema_decay: float=None,
49 | ):
50 | # variables_to_restore might be List or Dictionary.
51 |
52 | def split_strip(scopes: str):
53 | return list(filter(lambda x: len(x) > 0, [s.strip() for s in scopes.split(",")]))
54 |
55 | def starts_with(var, scopes: List) -> bool:
56 | return any([var.op.name.startswith(prefix) for prefix in scopes])
57 |
58 | exclusions = split_strip(exclude_scopes)
59 | inclusions = split_strip(include_scopes)
60 |
61 | if variables_to_restore is None:
62 | if use_ema:
63 | if ema_decay is None:
64 | raise ValueError("ema_decay undefined")
65 | else:
66 | ema = tf.train.ExponentialMovingAverage(decay=ema_decay)
67 | variables_to_restore = ema.variables_to_restore() # dictionary
68 | else:
69 | variables_to_restore = tf.contrib.framework.get_variables_to_restore()
70 |
71 | filtered_variables_key = variables_to_restore
72 | if len(inclusions) > 0:
73 | filtered_variables_key = filter(lambda var: starts_with(var, inclusions), filtered_variables_key)
74 | filtered_variables_key = filter(lambda var: not starts_with(var, exclusions), filtered_variables_key)
75 |
76 | if isinstance(variables_to_restore, dict):
77 | variables_to_restore = {
78 | key: variables_to_restore[key] for key in filtered_variables_key
79 | }
80 | elif isinstance(variables_to_restore, list):
81 | variables_to_restore = list(filtered_variables_key)
82 |
83 | return variables_to_restore
84 |
85 | # Copied and revised code not to create duplicated 'assign' operations everytime it gets called.
86 | # https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/contrib/framework/python/ops/variables.py#L558
87 | def load(self, checkpoint_stempath):
88 | def get_variable_full_name(var):
89 | if var._save_slice_info:
90 | return var._save_slice_info.full_name
91 | else:
92 | return var.op.name
93 |
94 | if not self.has_previous_info:
95 | if isinstance(self.variables_to_restore, (tuple, list)):
96 | for var in self.variables_to_restore:
97 | ckpt_name = get_variable_full_name(var)
98 | if ckpt_name not in self.grouped_vars:
99 | self.grouped_vars[ckpt_name] = []
100 | self.grouped_vars[ckpt_name].append(var)
101 |
102 | else:
103 | for ckpt_name, value in self.variables_to_restore.items():
104 | if isinstance(value, (tuple, list)):
105 | self.grouped_vars[ckpt_name] = value
106 | else:
107 | self.grouped_vars[ckpt_name] = [value]
108 |
109 | # Read each checkpoint entry. Create a placeholder variable and
110 | # add the (possibly sliced) data from the checkpoint to the feed_dict.
111 | reader = pywrap_tensorflow.NewCheckpointReader(str(checkpoint_stempath))
112 | feed_dict = {}
113 | assign_ops = []
114 | for ckpt_name in self.grouped_vars:
115 | if not reader.has_tensor(ckpt_name):
116 | log_str = f"Checkpoint is missing variable [{ckpt_name}]"
117 | if self.ignore_missing_vars:
118 | self.logger.warning(log_str)
119 | continue
120 | else:
121 | raise ValueError(log_str)
122 | ckpt_value = reader.get_tensor(ckpt_name)
123 |
124 | for var in self.grouped_vars[ckpt_name]:
125 | placeholder_name = f"placeholder/{var.op.name}"
126 | if self.has_previous_info:
127 | placeholder_tensor = self.placeholders[placeholder_name]
128 | else:
129 | placeholder_tensor = tf.placeholder(
130 | dtype=var.dtype.base_dtype,
131 | shape=var.get_shape(),
132 | name=placeholder_name)
133 | assign_ops.append(var.assign(placeholder_tensor))
134 | self.placeholders[placeholder_name] = placeholder_tensor
135 |
136 | if not var._save_slice_info:
137 | if var.get_shape() != ckpt_value.shape:
138 | raise ValueError(
139 | f"Total size of new array must be unchanged for {ckpt_name} "
140 | f"lh_shape: [{str(ckpt_value.shape)}], rh_shape: [{str(var.get_shape())}]")
141 |
142 | feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape)
143 | else:
144 | slice_dims = zip(var._save_slice_info.var_offset,
145 | var._save_slice_info.var_shape)
146 | slice_dims = [(start, start + size) for (start, size) in slice_dims]
147 | slice_dims = [slice(*x) for x in slice_dims]
148 | slice_value = ckpt_value[slice_dims]
149 | slice_value = slice_value.reshape(var._save_slice_info.var_shape)
150 | feed_dict[placeholder_tensor] = slice_value
151 |
152 | if not self.has_previous_info:
153 | self.assign_op = control_flow_ops.group(*assign_ops)
154 |
155 | self.session.run(self.assign_op, feed_dict)
156 |
157 | if len(feed_dict) > 0:
158 | for key in feed_dict.keys():
159 | self.logger.info(f"init from checkpoint > {key}")
160 | else:
161 | self.logger.info(f"No init from checkpoint")
162 |
163 | with format_text("cyan", attrs=["bold", "underline"]) as fmt:
164 | self.logger.info(fmt(f"Restore from {checkpoint_stempath}"))
165 | self.has_previous_info = True
166 |
--------------------------------------------------------------------------------
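
Note: a hypothetical usage sketch for `common/model_loader.py` (not part of the repository), assuming TensorFlow 1.x and that a model graph has already been built; the scope name and checkpoint path are placeholders.

```python
# Hypothetical usage sketch for common/model_loader.py.
import tensorflow as tf
from common.model_loader import Ckpt

# ... build the model graph here ...

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    loader = Ckpt(
        session=sess,
        exclude_scopes="TCResNet8/fc",      # hypothetical scope excluded from restoring
        ignore_missing_vars=True,
    )
    # `load` takes the checkpoint stem path (without .index/.meta/.data suffixes).
    loader.load("path/to/model.ckpt-1000")  # placeholder path
```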
/common/tf_utils.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import shutil
3 | from typing import Dict
4 | from functools import reduce
5 | from pathlib import Path
6 | from operator import mul
7 |
8 | import tensorflow as tf
9 | import pandas as pd
10 | import numpy as np
11 | from termcolor import colored
12 | from tensorflow.contrib.training import checkpoints_iterator
13 | from common.utils import get_logger
14 | from common.utils import wait
15 |
16 | import const
17 |
18 |
19 | def get_variables_to_train(trainable_scopes, logger):
20 | """Returns a list of variables to train.
21 | Returns:
22 | A list of variables to train by the optimizer.
23 | """
24 | if trainable_scopes is None or trainable_scopes == "":
25 | return tf.trainable_variables()
26 | else:
27 | scopes = [scope.strip() for scope in trainable_scopes.split(",")]
28 |
29 | variables_to_train = []
30 | for scope in scopes:
31 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
32 | variables_to_train.extend(variables)
33 |
34 | for var in variables_to_train:
35 | logger.info("vars to train > {}".format(var.name))
36 |
37 | return variables_to_train
38 |
39 |
40 | def show_models(logger):
41 | trainable_variables = set(tf.contrib.framework.get_variables(collection=tf.GraphKeys.TRAINABLE_VARIABLES))
42 | all_variables = tf.contrib.framework.get_variables()
43 | trainable_vars = tf.trainable_variables()
44 | total_params = 0
45 | total_trainable_params = 0
46 | logger.info(colored(f">> Start of showing all variables", "cyan", attrs=["bold"]))
47 | for v in all_variables:
48 | is_trainable = v in trainable_variables
49 | count_params = reduce(mul, v.get_shape().as_list(), 1)
50 | total_params += count_params
51 | total_trainable_params += (count_params if is_trainable else 0)
52 | color = "cyan" if is_trainable else "green"
53 | logger.info(colored((
54 | f">> {v.name} {v.dtype} : {v.get_shape().as_list()}, {count_params} ... {total_params} "
55 | f"(is_trainable: {is_trainable})"
56 | ), color))
57 | logger.info(colored(
58 | f">> End of showing all variables // Number of variables: {len(all_variables)}, "
59 | f"Number of trainable variables : {len(trainable_vars)}, "
60 | f"Total prod + sum of shape: {total_params} ({total_trainable_params} trainable)",
61 | "cyan", attrs=["bold"]))
62 | return total_params
63 |
64 |
65 | def ckpt_iterator(checkpoint_dir, min_interval_secs=0, timeout=None, timeout_fn=None, logger=None):
66 | for ckpt_path in checkpoints_iterator(checkpoint_dir, min_interval_secs, timeout, timeout_fn):
67 | yield ckpt_path
68 |
69 |
70 | class BestKeeper(object):
71 | def __init__(
72 | self,
73 | metric_with_modes,
74 | dataset_name,
75 | directory,
76 | logger=None,
77 | epsilon=0.00005,
78 | score_file="scores.tsv",
79 | metric_best: Dict={},
80 | ):
81 | """Keep best model's checkpoint by each datasets & metrics
82 |
83 | Args:
84 | metric_with_modes: Dict, metric_name: mode
85 | if mode is 'min', then it means that minimum value is best, for example loss(MSE, MAE)
86 | if mode is 'max', then it means that maximum value is best, for example Accuracy, Precision, Recall
87 | dataset_name: str, dataset name on which metric be will be calculated
88 | directory: directory path for saving best model
89 | epsilon: float, threshold for measuring the new optimum, to only focus on significant changes.
90 | Because sometimes early-stopping gives better generalization results
91 | """
92 | if logger is not None:
93 | self.log = logger
94 | else:
95 | self.log = get_logger("BestKeeper")
96 |
97 | self.score_file = score_file
98 | self.metric_best = metric_best
99 |
100 | self.log.info(colored(f"Initialize BestKeeper: Monitor {dataset_name} & Save to {directory}",
101 | "yellow", attrs=["underline"]))
102 | self.log.info(f"{metric_with_modes}")
103 |
104 | self.x_better_than_y = {}
105 | self.directory = Path(directory)
106 | self.output_temp_dir = self.directory / f"{dataset_name}_best_keeper_temp"
107 |
108 | for metric_name, mode in metric_with_modes.items():
109 | if mode == "min":
110 | self.metric_best[metric_name] = self.load_metric_from_scores_tsv(
111 |                     self.directory / dataset_name / metric_name / score_file,
112 | metric_name,
113 | np.inf,
114 | )
115 | self.x_better_than_y[metric_name] = lambda x, y: np.less(x, y - epsilon)
116 | elif mode == "max":
117 | self.metric_best[metric_name] = self.load_metric_from_scores_tsv(
118 |                     self.directory / dataset_name / metric_name / score_file,
119 | metric_name,
120 | -np.inf,
121 | )
122 | self.x_better_than_y[metric_name] = lambda x, y: np.greater(x, y + epsilon)
123 | else:
124 | raise ValueError(f"Unsupported mode : {mode}")
125 |
126 | def load_metric_from_scores_tsv(
127 | self,
128 | full_path: Path,
129 | metric_name: str,
130 | default_value: float,
131 | ) -> float:
132 | def parse_scores(s: str):
133 | if len(s) > 0:
134 | return float(s)
135 | else:
136 | return default_value
137 |
138 | if full_path.exists():
139 | with open(full_path, "r") as f:
140 | header = f.readline().strip().split("\t")
141 | values = list(map(parse_scores, f.readline().strip().split("\t")))
142 | metric_index = header.index(metric_name)
143 |
144 | return values[metric_index]
145 | else:
146 | return default_value
147 |
148 | def monitor(self, dataset_name, eval_scores):
149 | metrics_keep = {}
150 | is_keep = False
151 | for metric_name, score in self.metric_best.items():
152 | score = eval_scores[metric_name]
153 | if self.x_better_than_y[metric_name](score, self.metric_best[metric_name]):
154 | old_score = self.metric_best[metric_name]
155 | self.metric_best[metric_name] = score
156 | metrics_keep[metric_name] = True
157 | is_keep = True
158 | self.log.info(colored("[KeepBest] {} {:.6f} -> {:.6f}, so keep it!".format(
159 | metric_name, old_score, score), "blue", attrs=["underline"]))
160 | else:
161 | metrics_keep[metric_name] = False
162 | return is_keep, metrics_keep
163 |
164 | def save_best(self, dataset_name, metrics_keep, ckpt_glob):
165 | for metric_name, is_keep in metrics_keep.items():
166 | if is_keep:
167 | keep_path = self.directory / Path(dataset_name) / Path(metric_name)
168 | self.keep_checkpoint(keep_path, ckpt_glob)
169 | self.keep_converted_files(keep_path)
170 |
171 | def save_scores(self, dataset_name, metrics_keep, eval_scores, meta_info=None):
172 | eval_scores_with_meta = eval_scores.copy()
173 | if meta_info is not None:
174 | eval_scores_with_meta.update(meta_info)
175 |
176 | for metric_name, is_keep in metrics_keep.items():
177 | if is_keep:
178 | keep_path = self.directory / Path(dataset_name) / Path(metric_name)
179 | if not keep_path.exists():
180 | keep_path.mkdir(parents=True)
181 | df = pd.DataFrame(pd.Series(eval_scores_with_meta)).sort_index().transpose()
182 | df.to_csv(keep_path / self.score_file, sep="\t", index=False, float_format="%.5f")
183 |
184 | def remove_old_best(self, dataset_name, metrics_keep):
185 | for metric_name, is_keep in metrics_keep.items():
186 | if is_keep:
187 | keep_path = self.directory / Path(dataset_name) / Path(metric_name)
188 | # Remove old directory to save space
189 | if keep_path.exists():
190 | shutil.rmtree(str(keep_path))
191 | keep_path.mkdir(parents=True)
192 |
193 | def keep_checkpoint(self, keep_dir, ckpt_glob):
194 | if not isinstance(keep_dir, Path):
195 | keep_dir = Path(keep_dir)
196 |
197 | # .data-00000-of-00001, .meta, .index
198 | for ckpt_path in ckpt_glob.parent.glob(ckpt_glob.name):
199 | shutil.copy(str(ckpt_path), str(keep_dir))
200 |
201 | with open(keep_dir / "checkpoint", "w") as f:
202 | f.write(f'model_checkpoint_path: "{Path(ckpt_path.name).stem}"') # noqa
203 |
204 | def keep_converted_files(self, keep_path):
205 | if not isinstance(keep_path, Path):
206 | keep_path = Path(keep_path)
207 |
208 | for path in self.output_temp_dir.glob("*"):
209 | if path.is_dir():
210 | shutil.copytree(str(path), str(keep_path / path.name))
211 | else:
212 | shutil.copy(str(path), str(keep_path / path.name))
213 |
214 | def remove_temp_dir(self):
215 | if self.output_temp_dir.exists():
216 | shutil.rmtree(str(self.output_temp_dir))
217 |
218 |
219 | def resolve_checkpoint_path(checkpoint_path, log, is_training):
220 | if checkpoint_path is not None and Path(checkpoint_path).is_dir():
221 | old_ckpt_path = checkpoint_path
222 | checkpoint_path = tf.train.latest_checkpoint(old_ckpt_path)
223 | if not is_training:
224 | def stop_checker():
225 | return (tf.train.latest_checkpoint(old_ckpt_path) is not None)
226 | wait("There are no checkpoint file yet", stop_checker) # wait until checkpoint occurs
227 | checkpoint_path = tf.train.latest_checkpoint(old_ckpt_path)
228 | log.info(colored(
229 | "self.args.checkpoint_path updated: {} -> {}".format(old_ckpt_path, checkpoint_path),
230 | "yellow", attrs=["bold"]))
231 | else:
232 | log.info(colored("checkpoint_path is {}".format(checkpoint_path), "yellow", attrs=["bold"]))
233 |
234 | return checkpoint_path
235 |
236 |
237 | def get_global_step_from_checkpoint(checkpoint_path):
238 | """It is assumed that `checkpoint_path` is path to checkpoint file, not path to directory
239 | with checkpoint files.
240 | In case checkpoint path is not defined, 0 is returned."""
241 | if checkpoint_path is None or checkpoint_path == "":
242 | return 0
243 | else:
244 | if "-" in Path(checkpoint_path).stem:
245 | return int(Path(checkpoint_path).stem.split("-")[-1])
246 | else:
247 | return 0
248 |
--------------------------------------------------------------------------------
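
Note: a hypothetical usage sketch for `BestKeeper` in `common/tf_utils.py` (not part of the repository); the directory, dataset name, and score values are placeholders.

```python
# Hypothetical usage sketch for common/tf_utils.py's BestKeeper.
from pathlib import Path
from common.tf_utils import BestKeeper

keeper = BestKeeper(
    metric_with_modes={"accuracy": "max"},   # "max": higher is better, "min": lower is better
    dataset_name="valid",
    directory=Path("work/v1/TCResNet8Model-1.0"),  # placeholder directory
)

eval_scores = {"accuracy": 0.961}            # hypothetical evaluation result
is_keep, metrics_keep = keeper.monitor("valid", eval_scores)
if is_keep:
    keeper.remove_old_best("valid", metrics_keep)
    keeper.save_scores("valid", metrics_keep, eval_scores, meta_info={"global_step": 1000})
```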
/common/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import humanfriendly as hf
3 | import contextlib
4 | import argparse
5 | import logging
6 | import getpass
7 | import shutil
8 | import json
9 | import time
10 | from pathlib import Path
11 | from types import SimpleNamespace
12 | from datetime import datetime
13 |
14 | from tensorflow.python.platform import tf_logging
15 | import click
16 | from termcolor import colored
17 |
18 |
19 | LOG_FORMAT = "[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s > %(message)s"
20 |
21 |
22 | def update_train_dir(args):
23 | def replace_func(base_string, a, b):
24 | replaced_string = base_string.replace(a, b)
25 | print(colored("[update_train_dir] replace {} : {} -> {}".format(a, base_string, replaced_string),
26 | "yellow"))
27 | return replaced_string
28 |
29 | def make_placeholder(s: str, circumfix: str="%"):
30 | return circumfix + s.upper() + circumfix
31 |
32 | placeholder_mapping = {
33 | make_placeholder("DATE"): datetime.now().strftime("%y%m%d%H%M%S"),
34 | make_placeholder("USER"): getpass.getuser(),
35 | }
36 |
37 | for key, value in placeholder_mapping.items():
38 | args.train_dir = replace_func(args.train_dir, key, value)
39 |
40 | unknown = "UNKNOWN"
41 | for key, value in vars(args).items():
42 | key_placeholder = make_placeholder(key)
43 | if key_placeholder in args.train_dir:
44 | replace_value = value
45 | if isinstance(replace_value, str):
46 | if "/" in replace_value:
47 | replace_value = unknown
48 | elif isinstance(replace_value, list):
49 | replace_value = ",".join(map(str, replace_value))
50 | elif isinstance(replace_value, float) or isinstance(replace_value, int):
51 | replace_value = str(replace_value)
52 | elif isinstance(replace_value, bool):
53 | replace_value = str(replace_value)
54 | else:
55 | replace_value = unknown
56 | args.train_dir = replace_func(args.train_dir, key_placeholder, replace_value)
57 |
58 | print(colored("[update_train_dir] final train_dir {}".format(args.train_dir),
59 | "yellow", attrs=["bold", "underline"]))
60 |
61 |
62 | def positive_int(value):
63 | ivalue = int(value)
64 | if ivalue <= 0:
65 | raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value)
66 | return ivalue
67 |
68 |
69 | def get_logger(logger_name=None, log_file: Path=None, level=logging.DEBUG):
70 | # "log/data-pipe-{}.log".format(datetime.now().strftime("%Y-%m-%d-%H-%M-%S"))
71 | if logger_name is None:
72 | logger = tf_logging._get_logger()
73 | else:
74 | logger = logging.getLogger(logger_name)
75 |
76 | if not logger.hasHandlers():
77 | formatter = logging.Formatter(LOG_FORMAT)
78 |
79 | logger.setLevel(level)
80 |
81 | if log_file is not None:
82 | log_file.parent.mkdir(parents=True, exist_ok=True)
83 | fileHandler = logging.FileHandler(log_file, mode="w")
84 | fileHandler.setFormatter(formatter)
85 | logger.addHandler(fileHandler)
86 |
87 | streamHandler = logging.StreamHandler()
88 | streamHandler.setFormatter(formatter)
89 | logger.addHandler(streamHandler)
90 |
91 | return logger
92 |
93 |
94 | def format_timespan(duration):
95 | if duration < 10:
96 | readable_duration = "{:.1f} (ms)".format(duration * 1000)
97 | else:
98 | readable_duration = hf.format_timespan(duration)
99 | return readable_duration
100 |
101 |
102 | @contextlib.contextmanager
103 | def timer(name):
104 | st = time.time()
105 | yield
106 | print(" {} : {}".format(name, format_timespan(time.time() - st)))
107 |
108 |
109 | def timeit(method):
110 | def timed(*args, **kw):
111 | hf_timer = hf.Timer()
112 | result = method(*args, **kw)
113 | print(" {!r} ({!r}, {!r}) {}".format(method.__name__, args, kw, hf_timer.rounded))
114 | return result
115 | return timed
116 |
117 |
118 | class Timer(object):
119 | def __init__(self, log):
120 | self.log = log
121 |
122 | @contextlib.contextmanager
123 | def __call__(self, name, log_func=None):
124 | """
125 | Example.
126 | timer = Timer(log)
127 | with timer("Some Routines"):
128 | routine1()
129 | routine2()
130 | """
131 | if log_func is None:
132 | log_func = self.log.info
133 |
134 |         start = time.perf_counter()
135 |         yield
136 |         end = time.perf_counter()
137 | duration = end - start
138 | readable_duration = format_timespan(duration)
139 | log_func(f"{name} :: {readable_duration}")
140 |
141 |
142 | class TextFormatter(object):
143 | def __init__(self, color, attrs):
144 | self.color = color
145 | self.attrs = attrs
146 |
147 | def __call__(self, string):
148 | return colored(string, self.color, attrs=self.attrs)
149 |
150 |
151 | class LogFormatter(object):
152 | def __init__(self, log, color, attrs):
153 | self.log = log
154 | self.color = color
155 | self.attrs = attrs
156 |
157 | def __call__(self, string):
158 | return self.log(colored(string, self.color, attrs=self.attrs))
159 |
160 |
161 | @contextlib.contextmanager
162 | def format_text(color, attrs=None):
163 | yield TextFormatter(color, attrs)
164 |
165 |
166 | def format_log(log, color, attrs=None):
167 | return LogFormatter(log, color, attrs)
168 |
169 | def wait(message, stop_checker_closure):
170 | assert callable(stop_checker_closure)
171 | st = time.time()
172 | while True:
173 | try:
174 | time_pass = hf.format_timespan(int(time.time() - st))
175 | sys.stdout.write(colored((
176 |                 f"{message}. Press Ctrl+C to stop waiting. :: waiting time: {time_pass}\r"
177 | ), "yellow", attrs=["bold"]))
178 | sys.stdout.flush()
179 | time.sleep(1)
180 | if stop_checker_closure():
181 | break
182 | except KeyboardInterrupt:
183 | break
184 |
185 |
186 | class MLNamespace(SimpleNamespace):
187 | def __init__(self, *args, **kwargs):
188 | for kwarg in kwargs.keys():
189 | assert kwarg not in dir(self)
190 | super().__init__(*args, **kwargs)
191 |
192 | def unordered_values(self):
193 | return list(vars(self).values())
194 |
195 | def __setitem__(self, key, value):
196 | setattr(self, key, value)
197 |
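# A minimal usage sketch for the helpers above; the values below (train_dir
# pattern, model name, sleep duration) are hypothetical and only illustrate
# how the pieces fit together.
if __name__ == "__main__":
    demo_args = MLNamespace(
        train_dir="train_dirs/%USER%/%MODEL%_%DATE%",
        model="TCResNet8Model",
        batch_size=32,
    )
    # %USER% and %DATE% come from placeholder_mapping; %MODEL% is filled from
    # the `model` attribute because its placeholder appears in train_dir.
    update_train_dir(demo_args)

    log = get_logger("utils_demo")
    demo_timer = Timer(log)
    with demo_timer("short sleep"):
        time.sleep(0.1)

    with format_text("green", ["bold"]) as fmt:
        log.info(fmt(f"final train_dir: {demo_args.train_dir}"))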
--------------------------------------------------------------------------------
/const.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | TF_SESSION_CONFIG = tf.ConfigProto(
5 | gpu_options=tf.GPUOptions(allow_growth=True),
6 | log_device_placement=False,
7 | device_count={"GPU": 1})
8 | NULL_CLASS_LABEL = "__null__"
9 | BACKGROUND_NOISE_DIR_NAME = "_background_noise_"
10 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/audio_data_wrapper.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import tensorflow as tf
4 |
5 | import const
6 | from datasets.data_wrapper_base import DataWrapperBase
7 | from datasets.augmentation_factory import get_audio_augmentation_fn
8 |
9 |
10 | class AudioDataWrapper(DataWrapperBase):
11 | def __init__(
12 | self, args, session, dataset_split_name, is_training, name: str="AudioDataWrapper"
13 | ):
14 | super().__init__(args, dataset_split_name, is_training, name)
15 | self.setup()
16 |
17 | self.setup_dataset(self.placeholders)
18 | self.setup_iterator(
19 | session,
20 | self.placeholders,
21 | self.data,
22 | )
23 |
24 | def set_fileformat(self, filenames):
25 | file_formats = set(Path(fn).suffix for fn in filenames)
26 | assert len(file_formats) == 1
27 |         self.file_format = file_formats.pop()[1:]  # decode_audio expects the format without the leading dot
28 |
29 | @property
30 | def num_samples(self):
31 | return self._num_samples
32 |
33 | def augment_audio(self, filename, desired_samples, file_format, sample_rate, **kwargs):
34 | aug_fn = get_audio_augmentation_fn(self.args.augmentation_method)
35 | return aug_fn(filename, desired_samples, file_format, sample_rate, **kwargs)
36 |
37 | def _parse_function(self, filename, label):
38 | """
39 |         `filename` tensor holds the full path of the input
40 |         `label` tensor holds the index of the class the input belongs to
41 | """
42 | desired_samples = int(self.args.sample_rate * self.args.clip_duration_ms / 1000)
43 |
44 | # augment
45 | augmented_audio = self.augment_audio(
46 | filename,
47 | desired_samples,
48 | self.file_format,
49 | self.args.sample_rate,
50 | background_data=self.background_data,
51 | is_training=self.is_training,
52 | background_frequency=self.background_frequency,
53 | background_max_volume=self.background_max_volume,
54 | )
55 |
56 | label_parsed = self.parse_label(label)
57 |
58 | return augmented_audio, label_parsed
59 |
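    # With the default --sample_rate 16000 and --clip_duration_ms 1000, a parsed
    # example is a pair of
    #   augmented_audio: float32 Tensor of shape [16000, 1]
    #   label_parsed:    float32 one-hot Tensor of shape [num_labels]
    # (shapes listed for illustration only; see add_arguments below for the defaults).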
60 | @staticmethod
61 | def add_arguments(parser):
62 | g = parser.add_argument_group("(AudioDataWrapper) Arguments for Audio DataWrapper")
63 | g.add_argument(
64 | "--sample_rate",
65 | type=int,
66 | default=16000,
67 | help="Expected sample rate of the wavs",)
68 | g.add_argument(
69 | "--clip_duration_ms",
70 | type=int,
71 | default=1000,
72 |             help=("Expected duration in milliseconds of the wavs. "
73 |                   "The audio will be cropped or padded with zeroes based on this value."),)
74 | g.add_argument(
75 | "--window_size_ms",
76 | type=float,
77 | default=30.0,
78 | help="How long each spectrogram timeslice is.",)
79 | g.add_argument(
80 | "--window_stride_ms",
81 | type=float,
82 | default=10.0,
83 |             help="How far to move in time between spectrogram timeslices.",)
84 |
85 | # {{ -- Arguments for log-mel spectrograms
86 | # Default values are coming from tensorflow official tutorial
87 | g.add_argument("--lower_edge_hertz", type=float, default=80.0)
88 | g.add_argument("--upper_edge_hertz", type=float, default=7600.0)
89 | g.add_argument("--num_mel_bins", type=int, default=64)
90 | # Arguments for log-mel spectrograms -- }}
91 |
92 | # {{ -- Arguments for mfcc
93 | # Google speech_commands sample uses num_mfccs=40 as a default value
94 | # Official signal processing tutorial uses num_mfccs=13 as a default value
95 | g.add_argument("--num_mfccs", type=int, default=40)
96 | # Arguments for mfcc -- }}
97 |
98 | g.add_argument("--input_file", default=None, type=str)
99 | g.add_argument("--description_file", default=None, type=str)
100 | g.add_argument("--num_partitions", default=2, type=int,
101 |                        help=("Number of partitions into which the input csv file is split "
102 |                              "and processed in parallel"))
103 |
104 | # background noise
105 | g.add_argument("--background_max_volume", default=0.1, type=float,
106 | help="How loud the background noise should be, between 0 and 1.")
107 | g.add_argument("--background_frequency", default=0.8, type=float,
108 |                        help="Fraction of training samples that have background noise mixed in, between 0 and 1.")
109 | g.add_argument("--num_silent", default=-1, type=int,
110 |                        help="How many silent samples should be added. -1 means the count is calculated automatically.")
111 |
112 |
113 | class SingleLabelAudioDataWrapper(AudioDataWrapper):
114 | def parse_label(self, label):
115 | return tf.sparse_to_dense(sparse_indices=tf.cast(label, tf.int32),
116 | sparse_values=tf.ones([1], tf.float32),
117 | output_shape=[self.num_labels],
118 | validate_indices=False)
119 |
120 | def setup(self):
121 | dataset_paths = self.get_all_dataset_paths()
122 | self.label_names, self.num_labels = self.get_label_names(dataset_paths)
123 | assert const.NULL_CLASS_LABEL in self.label_names
124 | assert self.args.num_classes == self.num_labels
125 |
126 | self.filenames, self.labels = self.get_filenames_labels(dataset_paths)
127 | self.set_fileformat(self.filenames)
128 |
129 | # add dummy data for silent class
130 | self.background_max_volume = tf.constant(self.args.background_max_volume)
131 | self.background_frequency = tf.constant(self.args.background_frequency)
132 | self.background_data = self.prepare_silent_data(dataset_paths)
133 | self.add_silent_data()
134 | self._num_samples = self.count_samples(self.filenames)
135 |
136 | self.data = (self.filenames, self.labels)
137 |
138 | self.filenames_placeholder = tf.placeholder(tf.string, self._num_samples)
139 | self.labels_placeholder = tf.placeholder(tf.int32, self._num_samples)
140 | self.placeholders = (self.filenames_placeholder, self.labels_placeholder)
141 |
142 | # shuffle
143 | if self.shuffle:
144 | self.data = self.do_shuffle(*self.data)
145 |
146 | def prepare_silent_data(self, dataset_paths):
147 | def _gen(filename):
148 | filename = tf.constant(filename, dtype=tf.string)
149 | read_fn = get_audio_augmentation_fn("no_augmentation_audio")
150 | desired_samples = -1 # read all
151 | wav_data = read_fn(filename, desired_samples, self.file_format, self.args.sample_rate)
152 | return wav_data
153 |
154 | background_data = list()
155 | for dataset_path in dataset_paths:
156 | for label_path in dataset_path.iterdir():
157 | if label_path.name == const.BACKGROUND_NOISE_DIR_NAME:
158 | for wav_fullpath in label_path.glob("*.wav"):
159 | background_data.append(_gen(str(wav_fullpath)))
160 |
161 | self.log.info(f"{len(background_data)} background files are loaded.")
162 | return background_data
163 |
164 | def add_silent_data(self):
165 | num_silent = self.args.num_silent
166 | if self.args.num_silent < 0:
167 | num_samples = self.count_samples(self.filenames)
168 | num_silent = num_samples // self.num_labels
169 |
170 | label_idx = self.label_names.index(const.NULL_CLASS_LABEL)
171 | for _ in range(num_silent):
172 | self.filenames.append("")
173 | self.labels.append(label_idx)
174 | self.log.info(f"{num_silent} silent samples will be added.")
175 |
--------------------------------------------------------------------------------
/datasets/augmentation_factory.py:
--------------------------------------------------------------------------------
1 | from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
2 | import tensorflow as tf
3 |
4 |
5 | _available_audio_augmentation_methods = [
6 | "anchored_slice_or_pad",
7 | "anchored_slice_or_pad_with_shift",
8 | "no_augmentation_audio",
9 | ]
10 |
11 |
12 | _available_augmentation_methods = (
13 | _available_audio_augmentation_methods +
14 | ["no_augmentation"]
15 | )
16 |
17 |
18 | def no_augmentation(x):
19 | return x
20 |
21 |
22 | def _gen_random_from_zero(maxval, dtype=tf.float32):
23 | return tf.random.uniform([], maxval=maxval, dtype=dtype)
24 |
25 |
26 | def _gen_empty_audio(desired_samples):
27 | return tf.zeros([desired_samples, 1], dtype=tf.float32)
28 |
29 |
30 | def _mix_background(
31 | audio,
32 | desired_samples,
33 | background_data,
34 | is_silent,
35 | is_training,
36 | background_frequency,
37 | background_max_volume,
38 | naive_version=True,
39 | **kwargs
40 | ):
41 | """
42 | Args:
43 | audio: Tensor of audio.
44 | desired_samples: int value of desired length.
45 | background_data: List of background audios.
46 | is_silent: Tensor[Bool].
47 | is_training: Tensor[Bool].
48 | background_frequency: probability of mixing background. [0.0, 1.0]
49 | background_max_volume: scaling factor of mixing background. [0.0, 1.0]
50 | """
51 | foreground_wav = tf.cond(
52 | is_silent,
53 | true_fn=lambda: _gen_empty_audio(desired_samples),
54 | false_fn=lambda: tf.identity(audio)
55 | )
56 |
57 | # sampling background
58 | random_background_data_idx = _gen_random_from_zero(
59 | len(background_data),
60 | dtype=tf.int32
61 | )
62 | background_wav = tf.case({
63 | tf.equal(background_data_idx, random_background_data_idx):
64 | lambda tensor=wav: tensor
65 | for background_data_idx, wav in enumerate(background_data)
66 | }, exclusive=True)
67 | background_wav = tf.random_crop(background_wav, [desired_samples, 1])
68 |
69 | if naive_version:
70 | # Version 1
71 | # https://github.com/ARM-software/ML-KWS-for-MCU/blob/master/input_data.py#L461
72 | if is_training:
73 | background_volume = tf.cond(
74 | tf.less(_gen_random_from_zero(1.0), background_frequency),
75 | true_fn=lambda: _gen_random_from_zero(background_max_volume),
76 | false_fn=lambda: 0.0,
77 | )
78 | else:
79 | background_volume = 0.0
80 | else:
81 | # Version 2
82 | # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/speech_commands/input_data.py#L570
83 | background_volume = tf.cond(
84 | tf.logical_or(is_training, is_silent),
85 | true_fn=lambda: tf.cond(
86 | is_silent,
87 | true_fn=lambda: _gen_random_from_zero(1.0),
88 | false_fn=lambda: tf.cond(
89 | tf.less(_gen_random_from_zero(1.0), background_frequency),
90 | true_fn=lambda: _gen_random_from_zero(background_max_volume),
91 | false_fn=lambda: 0.0,
92 | ),
93 | ),
94 | false_fn=lambda: 0.0,
95 | )
96 |
97 | background_wav = tf.multiply(background_wav, background_volume)
98 | background_added = tf.add(background_wav, foreground_wav)
99 | augmented_audio = tf.clip_by_value(background_added, -1.0, 1.0)
100 |
101 | return augmented_audio
102 |
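# A worked example of the naive mixing path above during training: with the
# AudioDataWrapper defaults background_frequency=0.8 and background_max_volume=0.1,
# roughly 80% of examples get background_volume ~ Uniform(0, 0.1) and the rest get
# background_volume = 0.0, so the result is
#   clip_by_value(foreground_wav + background_volume * background_wav, -1.0, 1.0).
# The numbers are illustrative; they only trace the default arguments through the code.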
103 |
104 | def _shift_audio(audio, desired_samples, shift_ratio=0.1):
105 | time_shift = int(desired_samples * shift_ratio)
106 | time_shift_amount = tf.random.uniform(
107 | [],
108 | minval=-time_shift,
109 | maxval=time_shift,
110 | dtype=tf.int32
111 | )
112 |
113 | time_shift_abs = tf.abs(time_shift_amount)
114 |
115 | def _pos_padding():
116 | return [[time_shift_amount, 0], [0, 0]]
117 |
118 | def _pos_offset():
119 | return [0, 0]
120 |
121 | def _neg_padding():
122 | return [[0, time_shift_abs], [0, 0]]
123 |
124 | def _neg_offset():
125 | return [time_shift_abs, 0]
126 |
127 | padded_audio = tf.pad(
128 | audio,
129 | tf.cond(tf.greater_equal(time_shift_amount, 0),
130 | true_fn=_pos_padding,
131 | false_fn=_neg_padding),
132 | mode="CONSTANT",
133 | )
134 |
135 | sliced_audio = tf.slice(
136 | padded_audio,
137 | tf.cond(tf.greater_equal(time_shift_amount, 0),
138 | true_fn=_pos_offset,
139 | false_fn=_neg_offset),
140 | [desired_samples, 1],
141 | )
142 |
143 | return sliced_audio
144 |
145 |
146 | def _load_wav_file(filename, desired_samples, file_format):
147 | if file_format == "wav":
148 | wav_decoder = contrib_audio.decode_wav(
149 | tf.read_file(filename),
150 | desired_channels=1,
151 | # If desired_samples is set, then the audio will be
152 | # cropped or padded with zeroes to the requested length.
153 | desired_samples=desired_samples,
154 | )
155 | else:
156 | raise ValueError(f"Unsupported file format: {file_format}")
157 |
158 | return wav_decoder.audio
159 |
160 |
161 | def no_augmentation_audio(
162 | filename,
163 | desired_samples,
164 | file_format,
165 | sample_rate,
166 | **kwargs
167 | ):
168 | return _load_wav_file(filename, desired_samples, file_format)
169 |
170 |
171 | def anchored_slice_or_pad(
172 | filename,
173 | desired_samples,
174 | file_format,
175 | sample_rate,
176 | **kwargs,
177 | ):
178 | is_silent = tf.equal(tf.strings.length(filename), 0)
179 |
180 | audio = tf.cond(
181 | is_silent,
182 | true_fn=lambda: _gen_empty_audio(desired_samples),
183 | false_fn=lambda: _load_wav_file(filename, desired_samples, file_format)
184 | )
185 |
186 | if "background_data" in kwargs:
187 | audio = _mix_background(audio, desired_samples, is_silent=is_silent, **kwargs)
188 |
189 | return audio
190 |
191 |
192 | def anchored_slice_or_pad_with_shift(
193 | filename,
194 | desired_samples,
195 | file_format,
196 | sample_rate,
197 | **kwargs
198 | ):
199 | is_silent = tf.equal(tf.strings.length(filename), 0)
200 |
201 | audio = tf.cond(
202 | is_silent,
203 | true_fn=lambda: _gen_empty_audio(desired_samples),
204 | false_fn=lambda: _load_wav_file(filename, desired_samples, file_format)
205 | )
206 | audio = _shift_audio(audio, desired_samples, shift_ratio=0.1)
207 |
208 | if "background_data" in kwargs:
209 | audio = _mix_background(audio, desired_samples, is_silent=is_silent, **kwargs)
210 |
211 | return audio
212 |
213 |
214 | def get_audio_augmentation_fn(name):
215 | if name not in _available_audio_augmentation_methods:
216 | raise ValueError(f"Augmentation name [{name}] was not recognized")
217 | return eval(name)
218 |
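# A minimal sketch (TF 1.x graph mode, hypothetical file path) of how one of the
# augmentation functions above is selected and applied; the real wiring lives in
# AudioDataWrapper._parse_function.
if __name__ == "__main__":
    sample_rate = 16000
    desired_samples = int(sample_rate * 1000 / 1000)  # 1000 ms clip -> 16000 samples

    aug_fn = get_audio_augmentation_fn("anchored_slice_or_pad_with_shift")
    filename = tf.constant("/path/to/example.wav")  # hypothetical path
    audio = aug_fn(filename, desired_samples, "wav", sample_rate)
    print(audio)  # float32 Tensor of shape [desired_samples, 1]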
--------------------------------------------------------------------------------
/datasets/data_wrapper_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from abc import abstractmethod
3 | from pathlib import Path
4 | from typing import Tuple
5 | from typing import List
6 | from collections import defaultdict
7 | import random
8 |
9 | import tensorflow as tf
10 | import pandas as pd
11 | from termcolor import colored
12 |
13 | import common.utils as utils
14 | import const
15 | from datasets.augmentation_factory import _available_augmentation_methods
16 |
17 |
18 | class DataWrapperBase(ABC):
19 | def __init__(
20 | self,
21 | args,
22 | dataset_split_name: str,
23 | is_training: bool,
24 | name: str,
25 | ):
26 | self.name = name
27 | self.args = args
28 | self.dataset_split_name = dataset_split_name
29 | self.is_training = is_training
30 |
31 | self.shuffle = self.args.shuffle
32 |
33 | self.log = utils.get_logger(self.name)
34 | self.timer = utils.Timer(self.log)
35 | self.dataset_path = Path(self.args.dataset_path)
36 | self.dataset_path_with_split_name = self.dataset_path / self.dataset_split_name
37 |
38 | with utils.format_text("yellow", ["underline"]) as fmt:
39 | self.log.info(self.name)
40 | self.log.info(fmt(f"dataset_path_with_split_name: {self.dataset_path_with_split_name}"))
41 | self.log.info(fmt(f"dataset_split_name: {self.dataset_split_name}"))
42 |
43 | @property
44 | @abstractmethod
45 | def num_samples(self):
46 | pass
47 |
48 | @property
49 | def batch_size(self):
50 |         # default to 0 until the setter is called
51 |         if not hasattr(self, "_batch_size"):
52 |             self._batch_size = 0
53 |         return self._batch_size
54 |
55 | @batch_size.setter
56 | def batch_size(self, val):
57 | self._batch_size = val
58 |
59 | def setup_dataset(
60 | self,
61 | placeholders: Tuple[tf.placeholder, tf.placeholder],
62 | batch_size: int=None,
63 | ):
64 | self.batch_size = self.args.batch_size if batch_size is None else batch_size
65 |
66 | # single-GPU: prefetch before batch-shuffle-repeat
67 | dataset = tf.data.Dataset.from_tensor_slices(placeholders)
68 | if self.shuffle:
69 | dataset = dataset.shuffle(self.num_samples) # Follow tf.data.Dataset.list_files
70 | dataset = dataset.map(self._parse_function, num_parallel_calls=self.args.num_threads)
71 |
72 | if hasattr(tf.contrib.data, "AUTOTUNE"):
73 | dataset = dataset.prefetch(
74 | buffer_size=tf.contrib.data.AUTOTUNE
75 | )
76 | else:
77 | dataset = dataset.prefetch(
78 | buffer_size=self.args.prefetch_factor * self.batch_size
79 | )
80 |
81 | dataset = dataset.batch(self.batch_size)
82 | if self.is_training and self.shuffle:
83 | dataset = dataset.shuffle(buffer_size=self.args.buffer_size, reshuffle_each_iteration=True).repeat(-1)
84 | elif self.is_training and not self.shuffle:
85 | dataset = dataset.repeat(-1)
86 |
87 | self.dataset = dataset
88 | self.iterator = self.dataset.make_initializable_iterator()
89 | self.next_elem = self.iterator.get_next()
90 |
91 | def setup_iterator(
92 | self,
93 | session: tf.Session,
94 | placeholders: Tuple[tf.placeholder, ...],
95 | variables: Tuple[tf.placeholder, ...],
96 | ):
97 | assert len(placeholders) == len(variables), "Length of placeholders and variables differ!"
98 | with self.timer(colored("Initialize data iterator.", "yellow")):
99 | session.run(self.iterator.initializer,
100 | feed_dict={placeholder: variable for placeholder, variable in zip(placeholders, variables)})
101 |
102 | def get_input_and_output_op(self):
103 | return self.next_elem
104 |
105 | def __str__(self):
106 |         return f"path: {self.args.dataset_path}, split: {self.args.dataset_split_name}, data size: {self._num_samples}"
107 |
108 | def get_all_dataset_paths(self) -> List[str]:
109 | if self.args.has_sub_dataset:
110 | return sorted([p for p in self.dataset_path_with_split_name.glob("*/") if p.is_dir()])
111 | else:
112 | return [self.dataset_path_with_split_name]
113 |
114 | def get_label_names(
115 | self,
116 | dataset_paths: List[str],
117 | ):
118 | """Get all label names (either from one or all subdirectories if subdatasets are defined)
119 | and check consistency of names.
120 |
121 | Args:
122 | dataset_paths: List of paths to datasets.
123 |
124 | Returns:
125 | name_labels: Names of labels.
126 | num_labels: Number of all labels.
127 | """
128 | tmp_label_names = []
129 | for dataset_path in dataset_paths:
130 | dataset_label_names = []
131 |
132 | if self.args.add_null_class:
133 | dataset_label_names.append(const.NULL_CLASS_LABEL)
134 |
135 | for name in sorted([c.name for c in dataset_path.glob("*")]):
136 | if name[0] != "_":
137 | dataset_label_names.append(name)
138 | tmp_label_names.append(dataset_label_names)
139 |
140 | assert len(set(map(tuple, tmp_label_names))) == 1, "Different labels for each sub-dataset directory"
141 |
142 | name_labels = tmp_label_names[0]
143 | num_labels = len(name_labels)
144 |         assert num_labels > 0, f"There are no label directories in {dataset_paths}"
145 | return name_labels, num_labels
146 |
147 | def get_filenames_labels(
148 | self,
149 | dataset_paths: List[str],
150 |     ) -> Tuple[List[str], List[str]]:
151 | """Get paths to all inputs and their labels.
152 |
153 | Args:
154 | dataset_paths: List of paths to datasets.
155 |
156 | Returns:
157 | filenames: List of paths to all inputs.
158 |             labels: List of label indices corresponding to filenames.
159 | """
160 | if self.args.cache_dataset and self.args.cache_dataset_path is None:
161 | cache_directory = self.dataset_path / "_metainfo"
162 | cache_directory.mkdir(parents=True, exist_ok=True)
163 | cache_dataset_path = cache_directory / f"{self.dataset_split_name}.csv"
164 | else:
165 | cache_dataset_path = self.args.cache_dataset_path
166 |
167 | if self.args.cache_dataset and cache_dataset_path.exists():
168 | dataset_df = pd.read_csv(cache_dataset_path)
169 |
170 | filenames = list(dataset_df["filenames"])
171 | labels = list(dataset_df["labels"])
172 | else:
173 | filenames = []
174 | labels = []
175 | for label_idx, class_name in enumerate(self.label_names):
176 | for dataset_path in dataset_paths:
177 | for class_filename in dataset_path.joinpath(class_name).glob("*"):
178 | filenames.append(str(class_filename))
179 | labels.append(label_idx)
180 |
181 | if self.args.cache_dataset:
182 | pd.DataFrame({
183 | "filenames": filenames,
184 | "labels": labels,
185 | }).to_csv(cache_dataset_path, index=False)
186 |
187 | assert len(filenames) > 0
188 | if self.shuffle:
189 | filenames, labels = self.do_shuffle(filenames, labels)
190 |
191 | return filenames, labels
192 |
193 | def do_shuffle(self, *args):
194 | shuffled_data = list(zip(*args))
195 | random.shuffle(shuffled_data)
196 | result = tuple(map(lambda l: list(l), zip(*shuffled_data)))
197 |
198 | self.log.info(colored("Data shuffled!", "red"))
199 | return result
200 |
201 | def count_samples(
202 | self,
203 | samples: List,
204 | ) -> int:
205 | """Count number of samples in dataset.
206 |
207 | Args:
208 | samples: List of samples (e.g. filenames, labels).
209 |
210 | Returns:
211 | Number of samples.
212 | """
213 | num_samples = len(samples)
214 | with utils.format_text("yellow", ["underline"]) as fmt:
215 | self.log.info(fmt(f"number of data: {num_samples}"))
216 |
217 | return num_samples
218 |
219 | def oversampling(self, data, labels):
220 | """Doing oversampling based on labels.
221 | data: list of data.
222 | labels: list of labels.
223 | """
224 | assert self.args.oversampling_ratio is not None, (
225 | "When `--do_oversampling` is set, it also needs a proper value for `--oversampling_ratio`.")
226 |
227 | samples_of_label = defaultdict(list)
228 | for sample, label in zip(data, labels):
229 | samples_of_label[label].append(sample)
230 |
231 | num_samples_of_label = {label: len(lst) for label, lst in samples_of_label.items()}
232 | max_num_samples = max(num_samples_of_label.values())
233 | min_num_samples = int(max_num_samples * self.args.oversampling_ratio)
234 |
235 |         self.log.info("Oversampling statistics:")
236 | for label, num_samples in sorted(num_samples_of_label.items()):
237 |             # to reduce integer rounding error, replicate every class at least `n` times
238 | n = 5
239 | # ratio = int(max(min_num_samples / num_samples, 1.0) * n / n + 0.5)
240 | ratio = int(max(min_num_samples / num_samples, 1.0) * n + 0.5)
241 |
242 | self.log.info(f"{label}: {num_samples} x {ratio} => {num_samples * ratio}")
243 |
244 | for i in range(ratio - 1):
245 | data.extend(samples_of_label[label])
246 | labels.extend(label for _ in range(num_samples))
247 |
248 | return data, labels
249 |
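    # Worked example with illustrative numbers: for oversampling_ratio = 0.5, n = 5,
    # and class counts {A: 1000, B: 100}, min_num_samples = 500, so
    #   ratio(A) = int(max(500/1000, 1.0) * 5 + 0.5) = 5   -> A: 1000 x 5  = 5000
    #   ratio(B) = int(max(500/100,  1.0) * 5 + 0.5) = 25  -> B: 100  x 25 = 2500
    # i.e. every class is replicated at least `n` times and minority classes are
    # brought up to roughly `oversampling_ratio` of the largest class.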
250 | @staticmethod
251 | def add_arguments(parser):
252 | g_common = parser.add_argument_group("(DataWrapperBase) Common Arguments for all data wrapper.")
253 | g_common.add_argument("--dataset_path", required=True, type=str, help="The name of the dataset to load.")
254 | g_common.add_argument("--dataset_split_name", required=True, type=str, nargs="*",
255 | help="The name of the train/test split. Support multiple splits")
256 | g_common.add_argument("--no-has_sub_dataset", dest="has_sub_dataset", action="store_false")
257 | g_common.add_argument("--has_sub_dataset", dest="has_sub_dataset", action="store_true")
258 | g_common.set_defaults(has_sub_dataset=False)
259 | g_common.add_argument("--no-add_null_class", dest="add_null_class", action="store_false",
260 | help="Support null class for idx 0")
261 | g_common.add_argument("--add_null_class", dest="add_null_class", action="store_true")
262 | g_common.set_defaults(add_null_class=True)
263 |
264 | g_common.add_argument("--batch_size", default=32, type=utils.positive_int,
265 | help="The number of examples in batch.")
266 | g_common.add_argument("--no-shuffle", dest="shuffle", action="store_false")
267 | g_common.add_argument("--shuffle", dest="shuffle", action="store_true")
268 | g_common.set_defaults(shuffle=True)
269 |
270 | g_common.add_argument("--cache_dataset", dest="cache_dataset", action="store_true",
271 | help=("If True generates/loads csv file with paths to all inputs. "
272 | "It accelerates loading of large datasets."))
273 | g_common.add_argument("--no-cache_dataset", dest="cache_dataset", action="store_false")
274 | g_common.set_defaults(cache_dataset=False)
275 | g_common.add_argument("--cache_dataset_path", default=None, type=lambda p: Path(p),
276 | help=("Path to cached csv files containing paths to all inputs. "
277 | "If not given, csv file will be generated in the "
278 |                               "root data directory. This argument is used only if "
279 | "--cache_dataset is used."))
280 |
281 | g_common.add_argument("--width", type=int, default=-1)
282 | g_common.add_argument("--height", type=int, default=-1)
283 | g_common.add_argument("--augmentation_method", type=str, required=True,
284 | choices=_available_augmentation_methods)
285 | g_common.add_argument("--num_threads", default=8, type=int,
286 | help="We recommend using the number of available CPU cores for its value.")
287 | g_common.add_argument("--buffer_size", default=1000, type=int)
288 | g_common.add_argument("--prefetch_factor", default=100, type=int)
289 |
--------------------------------------------------------------------------------
/datasets/preprocessor_factory.py:
--------------------------------------------------------------------------------
1 | from datasets.preprocessors import NoOpPreprocessor
2 | from datasets.preprocessors import LogMelSpectrogramPreprocessor
3 | from datasets.preprocessors import MFCCPreprocessor
4 |
5 |
6 | _available_preprocessors = {
7 | # Audio
8 | "log_mel_spectrogram": LogMelSpectrogramPreprocessor,
9 | "mfcc": MFCCPreprocessor,
10 | # Noop
11 | "no_preprocessing": NoOpPreprocessor,
12 | }
13 |
14 |
15 | def factory(preprocess_method, scope, preprocessed_node_name):
16 | if preprocess_method in _available_preprocessors.keys():
17 | return _available_preprocessors[preprocess_method](scope, preprocessed_node_name)
18 | else:
19 | raise NotImplementedError(f"{preprocess_method}")
20 |
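# A minimal sketch (TF 1.x, settings chosen for illustration) of how the factory is
# used: pick a preprocess_method, convert the ms-based window settings into sample
# counts, and run the preprocessor on a batch of raw audio.
if __name__ == "__main__":
    import tensorflow as tf

    sample_rate = 16000
    window_size_samples = int(sample_rate * 30.0 / 1000)    # 30 ms -> 480 samples
    window_stride_samples = int(sample_rate * 10.0 / 1000)  # 10 ms -> 160 samples

    preprocessor = factory("mfcc", scope="preprocess", preprocessed_node_name="mfcc_out")
    audio = tf.placeholder(tf.float32, [None, sample_rate, 1], name="audio")
    features = preprocessor.preprocess(
        audio,
        window_size_samples,
        window_stride_samples,
        for_deploy=False,
        sample_rate=sample_rate,
        num_mel_bins=64,
        lower_edge_hertz=80.0,
        upper_edge_hertz=7600.0,
        num_mfccs=40,
    )
    print(features)  # float32 Tensor of shape [batch, time, num_mfccs, 1]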
--------------------------------------------------------------------------------
/datasets/preprocessors.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from abc import abstractmethod
3 |
4 | from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
5 | import tensorflow as tf
6 |
7 | import const
8 |
9 |
10 | class PreprocessorBase(ABC):
11 | def __init__(self, scope: str, preprocessed_node_name: str):
12 | self._scope = scope
13 | self._input_node = None
14 | self._preprocessed_node = None
15 | self._preprocessed_node_name = preprocessed_node_name
16 |
17 | @abstractmethod
18 | def preprocess(self, inputs, reuse):
19 | raise NotImplementedError
20 |
21 | @staticmethod
22 | def _uint8_to_float32(inputs):
23 | if inputs.dtype == tf.uint8:
24 | inputs = tf.cast(inputs, tf.float32)
25 | return inputs
26 |
27 | def _assign_input_node(self, inputs):
28 |         # We include the preprocessing part in the TFLite float model, so we need an input tensor before preprocessing.
29 | self._input_node = inputs
30 |
31 | def _make_node_after_preprocessing(self, inputs):
32 | node = tf.identity(inputs, name=self._preprocessed_node_name)
33 | self._preprocessed_node = node
34 | return node
35 |
36 | @property
37 | def input_node(self):
38 | return self._input_node
39 |
40 | @property
41 | def preprocessed_node(self):
42 | return self._preprocessed_node
43 |
44 |
45 | class NoOpPreprocessor(PreprocessorBase):
46 | def preprocess(self, inputs, reuse=False):
47 | self._assign_input_node(inputs)
48 | inputs = self._make_node_after_preprocessing(inputs)
49 | return inputs
50 |
51 |
52 | # For Audio
53 | class AudioPreprocessorBase(PreprocessorBase):
54 | def preprocess(self, inputs, window_size_samples, window_stride_samples, for_deploy, **kwargs):
55 | self._assign_input_node(inputs)
56 | with tf.variable_scope(self._scope):
57 | if for_deploy:
58 | inputs = self._preprocess_for_deploy(inputs, window_size_samples, window_stride_samples, **kwargs)
59 | else:
60 | inputs = self._preprocess(inputs, window_size_samples, window_stride_samples, **kwargs)
61 | inputs = self._make_node_after_preprocessing(inputs)
62 | return inputs
63 |
64 | def _log_mel_spectrogram(self, audio, window_size_samples, window_stride_samples,
65 | magnitude_squared, **kwargs):
66 |         # only accepts single-channel audio
67 | audio = tf.squeeze(audio, -1)
68 | stfts = tf.contrib.signal.stft(audio,
69 | frame_length=window_size_samples,
70 | frame_step=window_stride_samples)
71 |
72 |         # If magnitude_squared = True (power_spectrograms), use tf.real(stfts * tf.conj(stfts))
73 |         # If magnitude_squared = False (magnitude_spectrograms), use tf.abs(stfts)
74 | if magnitude_squared:
75 | spectrograms = tf.real(stfts * tf.conj(stfts))
76 | else:
77 | spectrograms = tf.abs(stfts)
78 |
79 | num_spectrogram_bins = spectrograms.shape[-1].value
80 | linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix(
81 | kwargs["num_mel_bins"],
82 | num_spectrogram_bins,
83 | kwargs["sample_rate"],
84 | kwargs["lower_edge_hertz"],
85 | kwargs["upper_edge_hertz"],
86 | )
87 |
88 | mel_spectrograms = tf.tensordot(spectrograms, linear_to_mel_weight_matrix, 1)
89 | mel_spectrograms.set_shape(
90 | spectrograms.shape[:-1].concatenate(linear_to_mel_weight_matrix.shape[-1:])
91 | )
92 |
93 | log_offset = 1e-6
94 | log_mel_spectrograms = tf.log(mel_spectrograms + log_offset)
95 |
96 | return log_mel_spectrograms
97 |
98 | def _single_spectrogram(self, audio, window_size_samples, window_stride_samples, magnitude_squared):
99 |         # only accepts a batch of size one
100 | audio = tf.squeeze(audio, 0)
101 |
102 | spectrogram = contrib_audio.audio_spectrogram(
103 | audio,
104 | window_size=window_size_samples,
105 | stride=window_stride_samples,
106 | magnitude_squared=magnitude_squared
107 | )
108 |
109 | return spectrogram
110 |
111 | def _single_mfcc(self, audio, window_size_samples, window_stride_samples, magnitude_squared,
112 | **kwargs):
113 | spectrogram = self._single_spectrogram(audio, window_size_samples, window_stride_samples, magnitude_squared)
114 |
115 | mfcc = contrib_audio.mfcc(
116 | spectrogram,
117 | kwargs["sample_rate_const"],
118 | upper_frequency_limit=kwargs["upper_edge_hertz"],
119 | lower_frequency_limit=kwargs["lower_edge_hertz"],
120 | filterbank_channel_count=kwargs["num_mel_bins"],
121 | dct_coefficient_count=kwargs["num_mfccs"],
122 | )
123 |
124 | return mfcc
125 |
126 | def _get_mel_matrix(self, num_mel_bins, num_spectrogram_bins, sample_rate,
127 | lower_edge_hertz, upper_edge_hertz):
128 | if num_mel_bins == 64 and num_spectrogram_bins == 257 and sample_rate == 16000 \
129 | and lower_edge_hertz == 80.0 and upper_edge_hertz == 7600.0:
130 | return tf.constant(const.MEL_WEIGHT_64_257_16000_80_7600, dtype=tf.float32, name="mel_weight_matrix")
131 | elif num_mel_bins == 64 and num_spectrogram_bins == 513 and sample_rate == 16000 \
132 | and lower_edge_hertz == 80.0 and upper_edge_hertz == 7600.0:
133 | return tf.constant(const.MEL_WEIGHT_64_513_16000_80_7600, dtype=tf.float32, name="mel_weight_matrix")
134 | else:
135 | setting = (num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz, upper_edge_hertz)
136 | raise ValueError(f"Target setting is not defined: {setting}")
137 |
138 | def _single_log_mel_spectrogram(self, audio, window_size_samples, window_stride_samples,
139 | magnitude_squared, **kwargs):
140 | spectrogram = self._single_spectrogram(audio, window_size_samples, window_stride_samples, magnitude_squared)
141 | spectrogram = tf.squeeze(spectrogram, 0)
142 |
143 | num_spectrogram_bins = spectrogram.shape[-1].value
144 | linear_to_mel_weight_matrix = self._get_mel_matrix(
145 | kwargs["num_mel_bins"],
146 | num_spectrogram_bins,
147 | kwargs["sample_rate"],
148 | kwargs["lower_edge_hertz"],
149 | kwargs["upper_edge_hertz"],
150 | )
151 |
152 | mel_spectrogram = tf.matmul(spectrogram, linear_to_mel_weight_matrix)
153 |
154 | log_offset = 1e-6
155 | log_mel_spectrograms = tf.log(mel_spectrogram + log_offset)
156 | log_mel_spectrograms = tf.expand_dims(log_mel_spectrograms, 0)
157 |
158 | return log_mel_spectrograms
159 |
160 |
161 | class LogMelSpectrogramPreprocessor(AudioPreprocessorBase):
162 | def _preprocess(self, audio, window_size_samples, window_stride_samples, **kwargs):
163 |         # When calculating the log mel spectrogram, set magnitude_squared to False
164 | log_mel_spectrograms = self._log_mel_spectrogram(audio,
165 | window_size_samples,
166 | window_stride_samples,
167 | False,
168 | **kwargs)
169 | log_mel_spectrograms = tf.expand_dims(log_mel_spectrograms, axis=-1)
170 | return log_mel_spectrograms
171 |
172 | def _preprocess_for_deploy(self, audio, window_size_samples, window_stride_samples, **kwargs):
173 | log_mel_spectrogram = self._single_log_mel_spectrogram(audio,
174 | window_size_samples,
175 | window_stride_samples,
176 | False,
177 | **kwargs)
178 | log_mel_spectrogram = tf.expand_dims(log_mel_spectrogram, axis=-1)
179 | return log_mel_spectrogram
180 |
181 |
182 | class MFCCPreprocessor(AudioPreprocessorBase):
183 | def _preprocess(self, audio, window_size_samples, window_stride_samples, **kwargs):
184 |         # When calculating the log mel spectrogram, set magnitude_squared to True
185 | log_mel_spectrograms = self._log_mel_spectrogram(audio,
186 | window_size_samples,
187 | window_stride_samples,
188 | True,
189 | **kwargs)
190 |
191 | mfccs = tf.contrib.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrograms)
192 | mfccs = mfccs[..., :kwargs["num_mfccs"]]
193 | mfccs = tf.expand_dims(mfccs, axis=-1)
194 | return mfccs
195 |
196 | def _preprocess_for_deploy(self, audio, window_size_samples, window_stride_samples, **kwargs):
197 | mfcc = self._single_mfcc(audio,
198 | window_size_samples,
199 | window_stride_samples,
200 | True,
201 | **kwargs)
202 | mfcc = tf.expand_dims(mfcc, axis=-1)
203 | return mfcc
204 |
--------------------------------------------------------------------------------
/evaluate_audio.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import tensorflow as tf
4 |
5 | from factory.base import TFModel
6 | import factory.audio_nets as audio_nets
7 | from helper.base import Base
8 | from helper.evaluator import Evaluator
9 | from helper.evaluator import SingleLabelAudioEvaluator
10 | from datasets.audio_data_wrapper import AudioDataWrapper
11 | from datasets.audio_data_wrapper import SingleLabelAudioDataWrapper
12 | from datasets.audio_data_wrapper import DataWrapperBase
13 | from metrics.base import MetricManagerBase
14 | from common.tf_utils import ckpt_iterator
15 | import common.utils as utils
16 | import const
17 |
18 |
19 | def main(args):
20 | is_training = False
21 | dataset_name = args.dataset_split_name[0]
22 | session = tf.Session(config=const.TF_SESSION_CONFIG)
23 |
24 | dataset = SingleLabelAudioDataWrapper(
25 | args,
26 | session,
27 | dataset_name,
28 | is_training,
29 | )
30 | wavs, labels = dataset.get_input_and_output_op()
31 |
32 | model = eval(f"audio_nets.{args.model}")(args, dataset)
33 | model.build(wavs=wavs, labels=labels, is_training=is_training)
34 |
35 | dataset_name = args.dataset_split_name[0]
36 | evaluator = SingleLabelAudioEvaluator(
37 | model,
38 | session,
39 | args,
40 | dataset,
41 | dataset_name,
42 | )
43 | log = utils.get_logger("EvaluateAudio")
44 |
45 | if args.valid_type == "once":
46 | evaluator.evaluate_once(args.checkpoint_path)
47 | elif args.valid_type == "loop":
48 | log.info(f"Start Loop: watching {evaluator.watch_path}")
49 |
50 | kwargs = {
51 | "min_interval_secs": 0,
52 | "timeout": None,
53 | "timeout_fn": None,
54 | "logger": log,
55 | }
56 |
57 | for ckpt_path in ckpt_iterator(evaluator.watch_path, **kwargs):
58 | log.info(f"[watch] {ckpt_path}")
59 |
60 | evaluator.evaluate_once(ckpt_path)
61 | else:
62 | raise ValueError(f"Undefined valid_type: {args.valid_type}")
63 |
64 |
65 | if __name__ == "__main__":
66 | parser = argparse.ArgumentParser(description=__doc__)
67 | subparsers = parser.add_subparsers(title="Model", description="")
68 |
69 | Base.add_arguments(parser)
70 | Evaluator.add_arguments(parser)
71 | DataWrapperBase.add_arguments(parser)
72 | AudioDataWrapper.add_arguments(parser)
73 | TFModel.add_arguments(parser)
74 | audio_nets.AudioNetModel.add_arguments(parser)
75 | MetricManagerBase.add_arguments(parser)
76 |
77 | for class_name in audio_nets._available_nets:
78 | subparser = subparsers.add_parser(class_name)
79 | subparser.add_argument("--model", default=class_name, type=str, help="DO NOT FIX ME")
80 | add_audio_arguments = eval("audio_nets.{}.add_arguments".format(class_name))
81 | add_audio_arguments(subparser)
82 |
83 | args = parser.parse_args()
84 |
85 | log = utils.get_logger("AudioNetEvaluate")
86 | log.info(args)
87 | main(args)
88 |
--------------------------------------------------------------------------------
/execute_script.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | set -eux
3 |
4 | # setup your train and data paths
5 | export ROOT_TRAIN_DIR=
6 | export DATA_SPEECH_COMMANDS_V1_DIR=/data/google_audio/data_speech_commands_v0.01_newsplit
7 | export DATA_SPEECH_COMMANDS_V2_DIR=/data/google_audio/data_speech_commands_v0.02_newsplit
8 |
9 | ## DATASET_SPLIT_NAME
10 | # train
11 | # valid
12 | # test
13 | dataset_split_name=
14 |
15 | ## AVAILABLE MODELS
16 | # KWSModel
17 | # Res8Model
18 | # Res8NarrowModel
19 | # Res15Model
20 | # Res15NarrowModel
21 | # DSCNNSModel
22 | # DSCNNMModel
23 | # DSCNNLModel
24 | # TCResNet_8Model
25 | # TCResNet_14Model
26 | # TCResNet_2D8Model
27 | model=
28 |
29 | ## AVAILABLE AUDIO PREPROCESS METHODS
30 | # log_mel_spectrogram
31 | # mfcc*
32 | audio_preprocess_setting=
33 |
34 | ## AVAILABLE WINDOW SETTINGS
35 | # 3010
36 | # 4020
37 | window_setting=
38 |
39 | ## WEIGHT DECAY
40 | # 0.0
41 | # 0.001
42 | weight_decay=
43 |
44 | ## AVAILABLE OPTIMIZER SETTINGS
45 | # adam
46 | # mom
47 | opt_setting=
48 |
49 | ## AVAILABLE LEARNING RATE SETTINGS
50 | # s1
51 | # l1
52 | # l2
53 | # l3
54 | lr_setting=
55 |
56 | ## WIDTH MULTIPLIER
57 | width_multiplier=
58 |
59 | ## AVAILABLE DATASETS
60 | # v1
61 | # v2
62 | dataset_settings=
63 |
64 |
65 | ./scripts/script_google_audio.sh \
66 | ${dataset_split_name} \
67 | ${model} \
68 | ${audio_preprocess_setting} \
69 | ${window_setting} \
70 | ${weight_decay} \
71 | ${opt_setting} \
72 | ${lr_setting} \
73 | ${width_multiplier} \
74 | ${dataset_settings}
75 |
--------------------------------------------------------------------------------
/factory/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/factory/__init__.py
--------------------------------------------------------------------------------
/factory/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from abc import abstractmethod
3 |
4 | import tensorflow as tf
5 | import tensorflow.contrib.slim as slim
6 |
7 | import common.tf_utils as tf_utils
8 | from datasets.preprocessor_factory import _available_preprocessors
9 |
10 |
11 | class TFModel(ABC):
12 | @staticmethod
13 | def add_arguments(parser):
14 | g_cnn = parser.add_argument_group("(CNNModel) Arguments")
15 | g_cnn.add_argument("--num_classes", type=int, default=None)
16 | g_cnn.add_argument("--checkpoint_path", default="", type=str)
17 |
18 | g_cnn.add_argument("--input_batch_size", type=int, default=1)
19 | g_cnn.add_argument("--output_name", type=str, required=True)
20 |
21 | g_cnn.add_argument("--preprocess_method", required=True, type=str,
22 | choices=list(_available_preprocessors.keys()))
23 |
24 | g_cnn.add_argument("--no-ignore_missing_vars", dest="ignore_missing_vars", action="store_false")
25 | g_cnn.add_argument("--ignore_missing_vars", dest="ignore_missing_vars", action="store_true")
26 | g_cnn.set_defaults(ignore_missing_vars=False)
27 |
28 | g_cnn.add_argument("--checkpoint_exclude_scopes", default="", type=str,
29 |                            help=("Prefix scopes that should be EXCLUDED for restoring variables "
30 | "(comma separated)"))
31 |
32 | g_cnn.add_argument("--checkpoint_include_scopes", default="", type=str,
33 | help=("Prefix scopes that should be INCLUDED for restoring variables "
34 | "(comma separated)"))
35 | g_cnn.add_argument("--weight_decay", default=1e-4, type=float)
36 |
37 | @abstractmethod
38 | def build_deployable_model(self, *args, **kwargs):
39 | pass
40 |
41 | @abstractmethod
42 | def preprocess_input(self):
43 | pass
44 |
45 | @abstractmethod
46 | def build_output(self):
47 | pass
48 |
49 | @property
50 | @abstractmethod
51 | def audio(self):
52 | pass
53 |
54 | @property
55 | @abstractmethod
56 | def audio_original(self):
57 | pass
58 |
59 | @property
60 | @abstractmethod
61 | def total_loss(self):
62 | pass
63 |
64 | @property
65 | @abstractmethod
66 | def model_loss(self):
67 | pass
68 |
--------------------------------------------------------------------------------
/figure/main_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/figure/main_figure.png
--------------------------------------------------------------------------------
/freeze.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import argparse
3 |
4 | import tensorflow as tf
5 | from tensorflow.python.framework import graph_util
6 |
7 | import common.model_loader as model_loader
8 | from factory.base import TFModel
9 | import factory.audio_nets as audio_nets
10 | from factory.audio_nets import AudioNetModel
11 | from factory.audio_nets import *
12 | from helper.base import Base
13 | import const
14 |
15 |
16 | def freeze(args):
17 | graph = tf.Graph()
18 | with graph.as_default():
19 | session = tf.Session(config=const.TF_SESSION_CONFIG)
20 |
21 | model = eval(args.model)(args)
22 | input_tensors, output_tensor = model.build_deployable_model(include_preprocess=False)
23 |
24 | ckpt_loader = model_loader.Ckpt(
25 | session=session,
26 | include_scopes=args.checkpoint_include_scopes,
27 | exclude_scopes=args.checkpoint_exclude_scopes,
28 | ignore_missing_vars=args.ignore_missing_vars,
29 | use_ema=args.use_ema,
30 | ema_decay=args.ema_decay,
31 | )
32 | session.run(tf.global_variables_initializer())
33 | session.run(tf.local_variables_initializer())
34 |
35 | ckpt_loader.load(args.checkpoint_path)
36 |
37 | frozen_graph_def = graph_util.convert_variables_to_constants(
38 | session,
39 | session.graph_def,
40 | [output_tensor.op.name],
41 | )
42 |
43 | checkpoint_path = Path(args.checkpoint_path)
44 | output_raw_pb_path = checkpoint_path.parent / f"{checkpoint_path.name}.pb"
45 | tf.train.write_graph(frozen_graph_def,
46 | str(output_raw_pb_path.parent),
47 | output_raw_pb_path.name,
48 | as_text=False)
49 |         print(f"Saved frozen pb: {output_raw_pb_path}")
50 |
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser(description=__doc__)
54 | subparsers = parser.add_subparsers(title="Model", description="")
55 |
56 | TFModel.add_arguments(parser)
57 | AudioNetModel.add_arguments(parser)
58 |
59 | for class_name in audio_nets._available_nets:
60 | subparser = subparsers.add_parser(class_name)
61 | subparser.add_argument("--model", default=class_name, type=str, help="DO NOT FIX ME")
62 | add_audio_net_arguments = eval(f"audio_nets.{class_name}.add_arguments")
63 | add_audio_net_arguments(subparser)
64 |
65 | Base.add_arguments(parser)
66 |
67 | parser.add_argument("--width", required=True, type=int)
68 | parser.add_argument("--height", required=True, type=int)
69 | parser.add_argument("--channels", required=True, type=int)
70 |
71 | parser.add_argument("--sample_rate", type=int, default=16000, help="Expected sample rate of the wavs",)
72 | parser.add_argument("--clip_duration_ms", type=int)
73 | parser.add_argument("--window_size_ms", type=float, default=30.0, help="How long each spectrogram timeslice is.",)
74 | parser.add_argument("--window_stride_ms", type=float, default=30.0,
75 |                         help="How far to move in time between spectrogram timeslices.",)
76 | parser.add_argument("--num_mel_bins", type=int, default=64)
77 | parser.add_argument("--num_mfccs", type=int, default=64)
78 | parser.add_argument("--lower_edge_hertz", type=float, default=80.0)
79 | parser.add_argument("--upper_edge_hertz", type=float, default=7600.0)
80 |
81 | args = parser.parse_args()
82 |
83 | freeze(args)
84 |
--------------------------------------------------------------------------------
/helper/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/helper/__init__.py
--------------------------------------------------------------------------------
/helper/base.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import tensorflow as tf
4 | import numpy as np
5 | from abc import ABC
6 | from abc import abstractmethod
7 | from termcolor import colored
8 |
9 | from common.utils import Timer
10 | from common.utils import get_logger
11 | from common.utils import format_log
12 | from common.utils import format_text
13 | from metrics.summaries import BaseSummaries
14 |
15 |
16 | class Base(ABC):
17 | def __init__(self):
18 | self.log = get_logger("Base")
19 | self.timer = Timer(self.log)
20 |
21 | def get_feed_dict(self, is_training: bool=False):
22 | feed_dict = dict()
23 | return feed_dict
24 |
25 | def check_batch_size(
26 | self,
27 | batch_size: int,
28 | terminate: bool=False,
29 | ):
30 | if batch_size != self.args.batch_size:
31 | self.log.info(colored(f"Batch size: required {self.args.batch_size}, obtained {batch_size}", "red"))
32 | if terminate:
33 | raise tf.errors.OutOfRangeError(None, None, "Finished looping dataset.")
34 |
35 | def build_iters_from_batch_size(self, num_samples, batch_size):
36 | iters = self.dataset.num_samples // self.args.batch_size
37 | num_ignored_samples = self.dataset.num_samples % self.args.batch_size
38 | if num_ignored_samples > 0:
39 | with format_text("red", attrs=["bold"]) as fmt:
40 | msg = (
41 | f"Number of samples cannot be divided by batch_size, "
42 | f"so it ignores some data examples in evaluation: "
43 | f"{self.dataset.num_samples} % {self.args.batch_size} = {num_ignored_samples}"
44 | )
45 | self.log.warning(fmt(msg))
46 | return iters
47 |
48 | @abstractmethod
49 | def build_evaluation_fetch_ops(self, do_eval):
50 | raise NotImplementedError
51 |
52 | def run_inference(
53 | self,
54 | global_step: int,
55 | iters: int=None,
56 | is_training: bool=False,
57 | do_eval: bool=True,
58 | ):
59 | """
60 | Return: Dict[metric_key] -> np.array
61 | array is stacked values for all batches
62 | """
63 | feed_dict = self.get_feed_dict(is_training=is_training)
64 |
65 | is_first_batch = True
66 |
67 | if iters is None:
68 | iters = self.build_iters_from_batch_size(self.dataset.num_samples, self.args.batch_size)
69 |
70 | # Get summary ops which should be evaluated by session.run
71 |         # For example, a segmentation task has several loss (GRAD/MAD/MSE) metrics,
72 |         # and these losses are calculated by TensorFlow (not numpy)
73 | merged_tensor_type_summaries = self.metric_manager.summary.get_merged_summaries(
74 | collection_key_suffixes=[BaseSummaries.KEY_TYPES.DEFAULT],
75 | is_tensor_summary=True
76 | )
77 |
78 | fetch_ops = self.build_evaluation_fetch_ops(do_eval)
79 |
80 | aggregator = {key: list() for key in fetch_ops}
81 | aggregator.update({
82 | "batch_infer_time": list(),
83 | "unit_infer_time": list(),
84 | })
85 |
86 | for i in range(iters):
87 | try:
88 | st = time.time()
89 |
90 | is_running_summary = do_eval and is_first_batch and merged_tensor_type_summaries is not None
91 | if is_running_summary:
92 | fetch_ops_with_summary = {"summary": merged_tensor_type_summaries}
93 | fetch_ops_with_summary.update(fetch_ops)
94 |
95 | fetch_vals = self.session.run(fetch_ops_with_summary, feed_dict=feed_dict)
96 |
97 |                     # To avoid duplicating the session.run call, we fetch the merged summaries here.
98 |                     # Because we run multiple batches within a single global_step,
99 |                     # the merged summaries would otherwise be written with duplicated values,
100 |                     # so we write them only for the first batch.
101 | self.metric_manager.write_tensor_summaries(global_step, fetch_vals["summary"])
102 | is_first_batch = False
103 | else:
104 | fetch_vals = self.session.run(fetch_ops, feed_dict=feed_dict)
105 |
106 | batch_infer_time = (time.time() - st) * 1000 # use milliseconds
107 |
108 | # aggregate
109 | for key, fetch_val in fetch_vals.items():
110 | if key in aggregator:
111 | aggregator[key].append(fetch_val)
112 |
113 | # add inference time
114 | aggregator["batch_infer_time"].append(batch_infer_time)
115 | aggregator["unit_infer_time"].append(batch_infer_time / self.args.batch_size)
116 |
117 | except tf.errors.OutOfRangeError:
118 |                 format_log(self.log.info, "yellow")("Reached end of the dataset.")
119 | break
120 | except tf.errors.InvalidArgumentError as e:
121 | format_log(self.log.error, "red")(f"Invalid instance is detected: {e}")
122 | continue
123 |
124 | aggregator = {k: np.vstack(v) for k, v in aggregator.items()}
125 | return aggregator
126 |
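    # The returned aggregator maps each fetch key (plus "batch_infer_time" and
    # "unit_infer_time") to an np.vstack of per-batch values, e.g.
    # aggregator["predictions_onehot"] has shape (iters * batch_size, num_classes)
    # when every batch is full.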
127 | def run_evaluation(
128 | self,
129 | global_step: int,
130 | iters: int=None,
131 | is_training: bool=False,
132 | ):
133 | eval_dict = self.run_inference(global_step, iters, is_training, do_eval=True)
134 |
135 | non_tensor_data = self.build_non_tensor_data_from_eval_dict(eval_dict, step=global_step)
136 |
137 | self.metric_manager.evaluate_and_aggregate_metrics(step=global_step,
138 | non_tensor_data=non_tensor_data,
139 | eval_dict=eval_dict)
140 |
141 | eval_metric_dict = self.metric_manager.get_evaluation_result(step=global_step)
142 |
143 | return eval_metric_dict
144 |
145 | @staticmethod
146 | def add_arguments(parser):
147 | g_base = parser.add_argument_group("Base")
148 | g_base.add_argument("--no-use_ema", dest="use_ema", action="store_false")
149 | g_base.add_argument("--use_ema", dest="use_ema", action="store_true",
150 | help="Exponential Moving Average. It may take more memory.")
151 | g_base.set_defaults(use_ema=False)
152 | g_base.add_argument("--ema_decay", default=0.999, type=float,
153 | help=("Exponential Moving Average decay.\n"
154 | "Reasonable values for decay are close to 1.0, typically "
155 | "in the multiple-nines range: 0.999, 0.9999"))
156 | g_base.add_argument("--evaluation_iterations", type=int, default=None)
157 |
158 |
159 | class AudioBase(Base):
160 | def build_evaluation_fetch_ops(self, do_eval):
161 | if do_eval:
162 | fetch_ops = {
163 | "labels_onehot": self.model.labels,
164 | "predictions_onehot": self.model.outputs,
165 | "total_loss": self.model.total_loss,
166 | }
167 | fetch_ops.update(self.metric_tf_op)
168 | else:
169 | fetch_ops = {
170 | "predictions_onehot": self.model.outputs,
171 | }
172 |
173 | return fetch_ops
174 |
175 | def build_basic_loss_ops(self):
176 | losses = {
177 | "total_loss": self.model.total_loss,
178 | "model_loss": self.model.model_loss,
179 | }
180 | losses.update(self.model.endpoints_loss)
181 |
182 | return losses
183 |
--------------------------------------------------------------------------------
/helper/evaluator.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import sys
3 | from pathlib import Path
4 | from abc import abstractmethod
5 |
6 | import numpy as np
7 | import tensorflow as tf
8 | from tqdm import tqdm
9 |
10 | import common.tf_utils as tf_utils
11 | import metrics.manager as metric_manager
12 | from common.model_loader import Ckpt
13 | from common.utils import format_text
14 | from common.utils import get_logger
15 | from helper.base import AudioBase
16 | from metrics.summaries import BaseSummaries
17 | from metrics.summaries import Summaries
18 |
19 |
20 | class Evaluator(object):
21 | def __init__(self, model, session, args, dataset, dataset_name, name):
22 | self.log = get_logger(name)
23 |
24 | self.model = model
25 | self.session = session
26 | self.args = args
27 | self.dataset = dataset
28 | self.dataset_name = dataset_name
29 |
30 | if Path(self.args.checkpoint_path).is_dir():
31 | latest_checkpoint = tf.train.latest_checkpoint(self.args.checkpoint_path)
32 | if latest_checkpoint is not None:
33 | self.args.checkpoint_path = latest_checkpoint
34 | self.log.info(f"Get latest checkpoint and update to it: {self.args.checkpoint_path}")
35 |
36 | self.watch_path = self._build_watch_path()
37 |
38 | self.session.run(tf.global_variables_initializer())
39 | self.session.run(tf.local_variables_initializer())
40 |
41 | self.ckpt_loader = Ckpt(
42 | session=session,
43 | include_scopes=args.checkpoint_include_scopes,
44 | exclude_scopes=args.checkpoint_exclude_scopes,
45 | ignore_missing_vars=args.ignore_missing_vars,
46 | use_ema=self.args.use_ema,
47 | ema_decay=self.args.ema_decay,
48 | )
49 |
50 | @abstractmethod
51 | def setup_metric_manager(self):
52 | raise NotImplementedError
53 |
54 | @abstractmethod
55 | def setup_metric_ops(self):
56 | raise NotImplementedError
57 |
58 | @abstractmethod
59 | def build_non_tensor_data_from_eval_dict(self, eval_dict, **kwargs):
60 | raise NotImplementedError
61 |
62 | @abstractmethod
63 | def setup_dataset_iterator(self):
64 | raise NotImplementedError
65 |
66 | def _build_watch_path(self):
67 | if Path(self.args.checkpoint_path).is_dir():
68 | return Path(self.args.checkpoint_path)
69 | else:
70 | return Path(self.args.checkpoint_path).parent
71 |
72 | def build_evaluation_step(self, checkpoint_path):
73 | if "-" in checkpoint_path and checkpoint_path.split("-")[-1].isdigit():
74 | return int(checkpoint_path.split("-")[-1])
75 | else:
76 | return 0
77 |
78 | def build_checkpoint_paths(self, checkpoint_path):
79 | checkpoint_glob = Path(checkpoint_path + "*")
80 | checkpoint_path = Path(checkpoint_path)
81 |
82 | return checkpoint_glob, checkpoint_path
83 |
84 | def build_miscellaneous_path(self, name):
85 | target_dir = self.watch_path / "miscellaneous" / self.dataset_name / name
86 |
87 | if not target_dir.exists():
88 | target_dir.mkdir(parents=True)
89 |
90 | return target_dir
91 |
92 | def setup_best_keeper(self):
93 | metric_with_modes = self.metric_manager.get_best_keep_metric_with_modes()
94 | self.log.debug(metric_with_modes)
95 | self.best_keeper = tf_utils.BestKeeper(
96 | metric_with_modes,
97 | self.dataset_name,
98 | self.watch_path,
99 | self.log,
100 | )
101 |
102 | def evaluate_once(self, checkpoint_path):
103 | self.log.info("Evaluation started")
104 | self.setup_dataset_iterator()
105 | self.ckpt_loader.load(checkpoint_path)
106 |
107 | step = self.build_evaluation_step(checkpoint_path)
108 | checkpoint_glob, checkpoint_path = self.build_checkpoint_paths(checkpoint_path)
109 | self.session.run(tf.local_variables_initializer())
110 |
111 | eval_metric_dict = self.run_evaluation(step, is_training=False)
112 | best_keep_metric_dict = self.metric_manager.filter_best_keep_metric(eval_metric_dict)
113 | is_keep, metrics_keep = self.best_keeper.monitor(self.dataset_name, best_keep_metric_dict)
114 |
115 | if self.args.save_best_keeper:
116 | meta_info = {
117 | "step": step,
118 | "model_size": self.model.total_params,
119 | }
120 | self.best_keeper.remove_old_best(self.dataset_name, metrics_keep)
121 | self.best_keeper.save_best(self.dataset_name, metrics_keep, checkpoint_glob)
122 | self.best_keeper.remove_temp_dir()
123 | self.best_keeper.save_scores(self.dataset_name, metrics_keep, best_keep_metric_dict, meta_info)
124 |
125 | self.metric_manager.write_evaluation_summaries(step=step,
126 | collection_keys=[BaseSummaries.KEY_TYPES.DEFAULT])
127 | self.metric_manager.log_metrics(step=step)
128 |
129 | self.log.info("Evaluation finished")
130 |
131 | if step >= self.args.max_step_from_restore:
132 | self.log.info("Evaluation stopped")
133 | sys.exit()
134 |
135 | def build_train_directory(self):
136 | if Path(self.args.checkpoint_path).is_dir():
137 | return str(self.args.checkpoint_path)
138 | else:
139 | return str(Path(self.args.checkpoint_path).parent)
140 |
141 | @staticmethod
142 | def add_arguments(parser):
143 | g = parser.add_argument_group("(Evaluator) arguments")
144 |
145 | g.add_argument("--valid_type", default="loop", type=str, choices=["loop", "once"])
146 | g.add_argument("--max_outputs", default=5, type=int)
147 |
148 | g.add_argument("--maximum_num_labels_for_metric", default=10, type=int,
149 |                        help="Maximum number of labels for which class-specific metrics (e.g. precision/recall/f1-score) are used")
150 |
151 | g.add_argument("--no-save_best_keeper", dest="save_best_keeper", action="store_false")
152 | g.add_argument("--save_best_keeper", dest="save_best_keeper", action="store_true")
153 | g.set_defaults(save_best_keeper=True)
154 |
155 | g.add_argument("--no-flatten_output", dest="flatten_output", action="store_false")
156 | g.add_argument("--flatten_output", dest="flatten_output", action="store_true")
157 | g.set_defaults(flatten_output=False)
158 |
159 | g.add_argument("--max_step_from_restore", default=1e20, type=int)
160 |
161 |
162 | class SingleLabelAudioEvaluator(Evaluator, AudioBase):
163 |
164 | def __init__(self, model, session, args, dataset, dataset_name):
165 | super().__init__(model, session, args, dataset, dataset_name, "SingleLabelAudioEvaluator")
166 | self.setup_dataset_related_attr()
167 | self.setup_metric_manager()
168 | self.setup_metric_ops()
169 | self.setup_best_keeper()
170 |
171 | def setup_dataset_related_attr(self):
172 | assert len(self.dataset.label_names) == self.args.num_classes
173 | self.use_class_metrics = len(self.dataset.label_names) < self.args.maximum_num_labels_for_metric
174 |
175 | def setup_metric_manager(self):
176 | self.metric_manager = metric_manager.AudioMetricManager(
177 | is_training=False,
178 | use_class_metrics=self.use_class_metrics,
179 | exclude_metric_names=self.args.exclude_metric_names,
180 | summary=Summaries(
181 | session=self.session,
182 | train_dir=self.build_train_directory(),
183 | is_training=False,
184 | base_name=self.dataset.dataset_split_name,
185 | max_summary_outputs=self.args.max_summary_outputs,
186 | ),
187 | )
188 |
189 | def setup_metric_ops(self):
190 | losses = self.build_basic_loss_ops()
191 | self.metric_tf_op = self.metric_manager.build_metric_ops({
192 | "dataset_split_name": self.dataset_name,
193 | "label_names": self.dataset.label_names,
194 | "losses": losses,
195 | "learning_rate": None,
196 | "wavs": self.model.audio_original,
197 | })
198 |
199 | def build_non_tensor_data_from_eval_dict(self, eval_dict, **kwargs):
200 | return {
201 | "dataset_split_name": self.dataset.dataset_split_name,
202 | "label_names": self.dataset.label_names,
203 | "predictions_onehot": eval_dict["predictions_onehot"],
204 | "labels_onehot": eval_dict["labels_onehot"],
205 | }
206 |
207 | def setup_dataset_iterator(self):
208 | self.dataset.setup_iterator(
209 | self.session,
210 | self.dataset.placeholders,
211 | self.dataset.data,
212 | )
213 |
--------------------------------------------------------------------------------
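The evaluator above is normally driven in `--valid_type loop` mode by repeatedly picking up the newest checkpoint and calling `evaluate_once` on it; the actual driver lives in `evaluate_audio.py` (not shown in this section). The sketch below is only an illustration of that pattern, assuming TF 1.x and a hypothetical `watch_and_evaluate` helper that is not part of the repository:

    import time
    import tensorflow as tf

    def watch_and_evaluate(evaluator, train_dir, poll_interval_sec=60):
        # Hypothetical helper: poll train_dir for new checkpoints and evaluate each once.
        seen = set()
        while True:
            ckpt = tf.train.latest_checkpoint(train_dir)
            if ckpt is not None and ckpt not in seen:
                seen.add(ckpt)
                # evaluate_once() calls sys.exit() once max_step_from_restore is reached
                evaluator.evaluate_once(ckpt)
            time.sleep(poll_interval_sec)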
/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/metrics/__init__.py
--------------------------------------------------------------------------------
/metrics/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from collections import defaultdict
3 |
4 | from common.utils import get_logger
5 | from common.utils import format_text
6 | from common.utils import timer
7 |
8 |
9 | class DataStructure(ABC):
10 | """
11 | Define inner data structure
12 | Should define `_keys`
13 | """
14 | _keys = None
15 |
16 | def __init__(self, data):
17 | keys = self.__class__.get_keys()
18 | data_keys = data.keys()
19 |
20 | if set(keys) != set(data_keys):
21 | raise ValueError(f"Keys defined in `_keys ({list(keys)})`"
22 |                              f" should appear in "
23 | f"`data ({list(data_keys)})`")
24 | for k in keys:
25 | setattr(self, k, data[k])
26 |
27 | def __str__(self):
28 |         return f"{self.__class__.__name__}({self.to_dict()})"
29 |
30 | def __repr__(self):
31 | return str(self)
32 |
33 | def to_dict(self):
34 | return {k: getattr(self, k) for k in self._keys}
35 |
36 | @classmethod
37 | def get_keys(cls):
38 | return cls._keys
39 |
40 |
41 | class MetricAggregator:
42 | def __init__(self):
43 | self.step = None
44 | self.metrics_with_result = None
45 |
46 | self.init(-1)
47 |
48 | def init(self, step):
49 | self.step = step
50 | self.metrics_with_result = dict()
51 |
52 | def aggregate(self, metric, metric_result):
53 | assert metric not in self.metrics_with_result
54 | self.metrics_with_result[metric] = metric_result
55 |
56 | def iterate_metrics(self):
57 | for metric, metric_result in self.metrics_with_result.items():
58 | yield metric, metric_result
59 |
60 | def iterate_all(self):
61 | for metric, metric_result in self.iterate_metrics():
62 | for metric_key, value in metric_result.items():
63 | yield metric, metric_key, value
64 |
65 | def get_collection_summary_dict(self):
66 | # for_summary: Dict[str, Dict[MetricOp, List[Tuple(metric_key, tensor_op)]]]
67 | # collection_key -> metric -> List(summary_key, value)
68 | for_summary = defaultdict(lambda: defaultdict(list))
69 | for metric, metric_result in self.metrics_with_result.items():
70 | if metric.is_for_summary:
71 | for metric_key, value in metric_result.items():
72 | for_summary[metric.summary_collection_key][metric].append((metric_key, value))
73 |
74 | return for_summary
75 |
76 | def get_tensor_metrics(self):
77 | """
78 |         Get the metrics that will be fetched in the session run.
79 | """
80 | tensor_metrics = dict()
81 | for metric, metric_result in self.metrics_with_result.items():
82 | if metric.is_tensor_metric:
83 | for metric_key, value in metric_result.items():
84 | tensor_metrics[metric_key] = value
85 |
86 | return tensor_metrics
87 |
88 | def get_logs(self):
89 | logs = dict()
90 | for metric, metric_result in self.metrics_with_result.items():
91 | if metric.is_for_log:
92 | for metric_key, value in metric_result.items():
93 | if isinstance(value, str):
94 | msg = f"> {metric_key}\n{value}"
95 | else:
96 | msg = f"> {metric_key} : {value}"
97 | logs[metric_key] = msg
98 |
99 | return logs
100 |
101 |
102 | class MetricManagerBase(ABC):
103 | _metric_input_data_parser = None
104 |
105 | def __init__(self, exclude_metric_names, summary):
106 | self.log = get_logger("Metrics")
107 | self.build_op_aggregator = MetricAggregator()
108 | self.eval_metric_aggregator = MetricAggregator()
109 |
110 | self.summary = summary
111 | self.exclude_metric_names = exclude_metric_names
112 |
113 | self.metric_ops = []
114 |
115 | def register_metric(self, metric):
116 |         # skip metrics that the user excluded via --exclude_metric_names
117 | if metric.__class__.__name__ in self.exclude_metric_names:
118 | self.log.info(f"{metric.__class__.__name__} is excluded by user setting.")
119 | return
120 |
121 |         # assert that this metric can be handled by the configured data parser
122 | assert str(self._metric_input_data_parser) in map(lambda c: str(c), metric.valid_input_data_parsers), \
123 | f"{metric.__class__.__name__} cannot be parsed by {self._metric_input_data_parser}"
124 |
125 | # add one
126 | self.metric_ops.append(metric)
127 | self.log.info(f"{metric.__class__.__name__} is added.")
128 |
129 | def register_metrics(self, metrics: list):
130 | for metric in metrics:
131 | self.register_metric(metric)
132 |
133 | def build_metric_ops(self, data):
134 | """
135 | Define tensor metric operations
136 | 1. call `build_op` of metrics, i.e. add operations to graph
137 | 2. register summaries
138 |
139 | Return: Dict[str, Tensor]
140 | metric_key -> metric_op
141 | """
142 | output_build_data = self._metric_input_data_parser.parse_build_data(data)
143 |
144 | # get metric tf ops
145 | for metric in self.metric_ops:
146 | try:
147 | metric_build_ops = metric.build_op(output_build_data)
148 | except TypeError as e:
149 | raise TypeError(f"[{metric}]: {e}")
150 | self.build_op_aggregator.aggregate(metric, metric_build_ops)
151 |
152 | # if value is not None, it means it is defined with tensor
153 | metric_tf_ops = self.build_op_aggregator.get_tensor_metrics()
154 |
155 | # register summary
156 | collection_summary_dict = self.build_op_aggregator.get_collection_summary_dict()
157 | self.summary.register_summaries(collection_summary_dict)
158 | self.summary.setup_merged_summaries()
159 |
160 | return metric_tf_ops
161 |
162 | # def evaluate_non_tensor_metric(self, data, step):
163 | def evaluate_and_aggregate_metrics(self, non_tensor_data, eval_dict, step):
164 | """
165 |         Run evaluation of non-tensor metrics and aggregate the results.
166 |         Args:
167 |             non_tensor_data, eval_dict, step: data passed from the trainer / evaluator
168 | """
169 | non_tensor_data = self._metric_input_data_parser.parse_non_tensor_data(non_tensor_data)
170 |
171 | # aggregate metrics
172 | self.eval_metric_aggregator.init(step)
173 |
174 | # evaluate all metrics
175 | for metric, metric_key_op_dict in self.build_op_aggregator.iterate_metrics():
176 | if metric.is_tensor_metric:
177 | with timer(f"{metric}.expectation_of"):
178 | # already aggregated - tensor ops
179 | metric_result = dict()
180 | for metric_key in metric_key_op_dict:
181 | if metric_key in eval_dict:
182 | exp_value = metric.expectation_of(eval_dict[metric_key])
183 | metric_result[metric_key] = exp_value
184 | else:
185 | with timer(f"{metric}.evaluate"):
186 | # need calculation - non tensor ops
187 | metric_result = metric.evaluate(non_tensor_data)
188 |
189 | self.eval_metric_aggregator.aggregate(metric, metric_result)
190 |
191 | def write_tensor_summaries(self, step, summary_value):
192 | self.summary.write(summary_value, step)
193 |
194 | def write_evaluation_summaries(self, step, collection_keys):
195 | assert step == self.eval_metric_aggregator.step, \
196 |             (f"step: {step} is different from aggregator's step: {self.eval_metric_aggregator.step}. "
197 | f"`evaluate` function should be called before calling this function")
198 |
199 | collection_summary_dict = self.eval_metric_aggregator.get_collection_summary_dict()
200 | self.summary.write_evaluation_summaries(step=step,
201 | collection_keys=collection_keys,
202 | collection_summary_dict=collection_summary_dict)
203 |
204 | def log_metrics(self, step):
205 | """
206 | Logging metrics that are evaluated.
207 | """
208 | assert step == self.eval_metric_aggregator.step, \
209 |             (f"step: {step} is different from aggregator's step: {self.eval_metric_aggregator.step}. "
210 | f"`evaluate` function should be called before calling this function")
211 |
212 | log_dicts = dict()
213 | log_dicts.update(self.eval_metric_aggregator.get_logs())
214 |
215 | with format_text("green", ["bold"]) as fmt:
216 | for metric_key, log_str in log_dicts.items():
217 | self.log.info(fmt(log_str))
218 |
219 | def get_evaluation_result(self, step):
220 | """
221 |         Return the evaluation result regardless of metric type.
222 | """
223 | assert step == self.eval_metric_aggregator.step, \
224 |             (f"step: {step} is different from aggregator's step: {self.eval_metric_aggregator.step}. "
225 | f"`evaluate` function should be called before calling this function")
226 |
227 | eval_dict = dict()
228 | for metric, metric_key, value in self.eval_metric_aggregator.iterate_all():
229 | eval_dict[metric_key] = value
230 |
231 | return eval_dict
232 |
233 | def get_best_keep_metric_with_modes(self):
234 | metric_min_max_dict = dict()
235 | for metric, metric_key, _ in self.build_op_aggregator.iterate_all():
236 | if metric.is_for_best_keep:
237 | metric_min_max_dict[metric_key] = metric.min_max_mode
238 |
239 | return metric_min_max_dict
240 |
241 | def filter_best_keep_metric(self, eval_metric_dict):
242 | best_keep_metric_dict = dict()
243 | for metric, metric_key, _ in self.build_op_aggregator.iterate_all():
244 | if metric_key in eval_metric_dict and metric.is_for_best_keep:
245 | best_keep_metric_dict[metric_key] = eval_metric_dict[metric_key]
246 |
247 | return best_keep_metric_dict
248 |
249 | @staticmethod
250 | def add_arguments(parser):
251 | subparser = parser.add_argument_group(f"Metric Manager Arguments")
252 | subparser.add_argument("--exclude_metric_names",
253 | nargs="*",
254 | default=[],
255 | type=str,
256 | help="Name of metrics to be excluded")
257 | subparser.add_argument("--max_summary_outputs",
258 | default=3,
259 | type=int,
260 | help="Number of maximum summary outputs for multimedia (ex: audio wav)")
261 |
--------------------------------------------------------------------------------
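For illustration, a `DataStructure` subclass only declares `_keys` and is constructed from a dict with exactly those keys; anything missing or extra raises the `ValueError` shown above. `ExampleData` below is a hypothetical class, not part of the repository:

    from metrics.base import DataStructure

    class ExampleData(DataStructure):
        # hypothetical subclass used only for illustration
        _keys = ["dataset_split_name", "labels"]

    data = ExampleData({"dataset_split_name": "valid", "labels": [0, 1, 1]})
    print(data.labels)      # [0, 1, 1]
    print(data.to_dict())   # {'dataset_split_name': 'valid', 'labels': [0, 1, 1]}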
/metrics/funcs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def topN_accuracy(y_true: np.array,
5 | y_pred_onehot: np.array,
6 | N: int):
7 | """ Top N accuracy """
8 | assert len(y_true.shape) == 1
9 | assert len(y_pred_onehot.shape) == 2
10 | assert y_true.shape[0] == y_pred_onehot.shape[0]
11 | assert y_pred_onehot.shape[1] >= N
12 |
13 | true_positive = 0
14 | for label, top_n_pred in zip(y_true, np.argsort(-y_pred_onehot, axis=-1)[:, :N]):
15 | if label in top_n_pred:
16 | true_positive += 1
17 |
18 | accuracy = true_positive / len(y_true)
19 |
20 | return accuracy
21 |
--------------------------------------------------------------------------------
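A quick sanity check of `topN_accuracy` (not part of the repository): with N=2, a sample counts as correct whenever its true label is among the two highest-scoring classes, so the toy batch below scores 2 out of 3:

    import numpy as np
    from metrics.funcs import topN_accuracy

    y_true = np.array([0, 2, 1])                 # integer labels, shape (3,)
    y_scores = np.array([[0.7, 0.2, 0.1],        # top-2 = {0, 1} -> hit
                         [0.1, 0.5, 0.4],        # top-2 = {1, 2} -> hit
                         [0.6, 0.1, 0.3]])       # top-2 = {0, 2} -> miss
    print(topN_accuracy(y_true=y_true, y_pred_onehot=y_scores, N=2))  # ~0.667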
/metrics/manager.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import metrics.ops as mops
4 | import metrics.parser as parser
5 | from metrics.base import MetricManagerBase
6 | from metrics.summaries import Summaries
7 |
8 |
9 | class AudioMetricManager(MetricManagerBase):
10 | _metric_input_data_parser = parser.AudioDataParser
11 |
12 | def __init__(
13 | self,
14 | is_training: bool,
15 | use_class_metrics: bool,
16 | exclude_metric_names: List,
17 | summary: Summaries,
18 | ):
19 | super().__init__(exclude_metric_names, summary)
20 | self.register_metrics([
21 | # map
22 | mops.MAPMetricOp(),
23 | # accuracy
24 | mops.AccuracyMetricOp(),
25 | mops.Top5AccuracyMetricOp(),
26 | # misc
27 | mops.ClassificationReportMetricOp(),
28 |
29 | # tensor ops
30 | mops.LossesMetricOp(),
31 | ])
32 |
33 | if is_training:
34 | self.register_metrics([
35 | mops.WavSummaryOp(),
36 | mops.LearningRateSummaryOp()
37 | ])
38 |
39 | if use_class_metrics:
40 | # per-class
41 | self.register_metrics([
42 | mops.PrecisionMetricOp(),
43 | mops.RecallMetricOp(),
44 | mops.F1ScoreMetricOp(),
45 | mops.APMetricOp(),
46 | ])
47 |
--------------------------------------------------------------------------------
/metrics/ops/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_ops import MetricOpBase
2 | from .base_ops import TensorMetricOpBase
3 | from .base_ops import NonTensorMetricOpBase
4 | from .non_tensor_ops import *
5 | from .tensor_ops import *
6 | from .misc_ops import *
7 |
--------------------------------------------------------------------------------
/metrics/ops/base_ops.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from common.utils import get_logger
4 | from metrics.summaries import BaseSummaries
5 |
6 |
7 | class MetricOpBase(ABC):
8 | MIN_MAX_CHOICES = ["min", "max", None]
9 |
10 | _meta_properties = [
11 | "is_for_summary",
12 | "is_for_best_keep",
13 | "is_for_log",
14 | "valid_input_data_parsers",
15 | "summary_collection_key",
16 | "summary_value_type",
17 | "min_max_mode",
18 | ]
19 | _properties = dict()
20 |
21 | def __init__(self, **kwargs):
22 | self.log = get_logger("MetricOp")
23 |
24 | # init by _properties
25 | # custom values can be added as kwargs
26 | for attr in self._meta_properties:
27 | if attr in kwargs:
28 | setattr(self, attr, kwargs[attr])
29 | else:
30 | setattr(self, attr, self._properties[attr])
31 |
32 | # assertion
33 | assert self.min_max_mode in self.MIN_MAX_CHOICES
34 |
35 | if self.is_for_best_keep:
36 | assert self.min_max_mode is not None
37 |
38 | if self.is_for_summary:
39 | assert self.summary_collection_key in vars(BaseSummaries.KEY_TYPES).values()
40 |
41 | def __hash__(self):
42 | return hash(str(self))
43 |
44 | def __eq__(self, other):
45 | return str(self) == str(other)
46 |
47 | @property
48 | def is_placeholder_summary(self):
49 |         assert self.is_for_summary, "`is_placeholder_summary` must not be called on a non-summary metric"
50 | return self.summary_value_type == BaseSummaries.VALUE_TYPES.PLACEHOLDER
51 |
52 | @property
53 | @abstractmethod
54 | def is_tensor_metric(self):
55 | raise NotImplementedError
56 |
57 | @abstractmethod
58 | def __str__(self):
59 | raise NotImplementedError
60 |
61 | @abstractmethod
62 | def build_op(self, data):
63 |         """ This method should be overloaded for
64 |         all cases of `valid_input_data_parsers`
65 | """
66 | raise NotImplementedError
67 |
68 |
69 | class NonTensorMetricOpBase(MetricOpBase):
70 | @property
71 | def is_tensor_metric(self):
72 | return False
73 |
74 | @abstractmethod
75 | def evaluate(self, data):
76 |         """ This method should be overloaded for
77 |         all cases of `valid_input_data_parsers`
78 | """
79 | raise NotImplementedError
80 |
81 |
82 | class TensorMetricOpBase(MetricOpBase):
83 | @property
84 | def is_tensor_metric(self):
85 | return True
86 |
87 | @abstractmethod
88 | def expectation_of(self, data):
89 |         """ If evaluation is done with a tensor metric, the expectation of the
90 |         aggregated metric values has to be re-calculated.
91 |         This function assumes that `data` is aggregated over all batches
92 |         and returns the proper expectation value.
93 | """
94 | raise NotImplementedError
95 |
--------------------------------------------------------------------------------
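Concretely, adding a new metric means subclassing one of the two bases above, filling in `_properties`, and implementing `__str__`, `build_op`, and `evaluate` (or `expectation_of`). The sketch below is a hypothetical non-tensor metric (Matthews correlation coefficient, not part of the repository) that follows the same pattern as the ops in `non_tensor_ops.py`; registering it would then be a one-line addition in `metrics/manager.py`:

    from sklearn.metrics import matthews_corrcoef
    from overload import overload

    import metrics.parser as parser
    from metrics.ops.base_ops import NonTensorMetricOpBase
    from metrics.summaries import BaseSummaries


    class MCCMetricOp(NonTensorMetricOpBase):
        """ Hypothetical Matthews-correlation metric, for illustration only. """
        _properties = {
            "is_for_summary": True,
            "is_for_best_keep": True,
            "is_for_log": True,
            "valid_input_data_parsers": [parser.AudioDataParser],
            "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
            "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
            "min_max_mode": "max",
        }

        def __str__(self):
            return "mcc_metric"

        @overload
        def build_op(self, data: parser.AudioDataParser.OutputBuildData):
            # reserve the summary key; the value is filled in later by evaluate()
            return {f"mcc/{data.dataset_split_name}": None}

        @overload
        def evaluate(self, data: parser.AudioDataParser.OutputNonTensorData):
            key = f"mcc/{data.dataset_split_name}"
            return {key: matthews_corrcoef(data.labels, data.predictions)}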
/metrics/ops/misc_ops.py:
--------------------------------------------------------------------------------
1 | from metrics.ops.base_ops import TensorMetricOpBase
2 | from metrics.summaries import BaseSummaries
3 | import metrics.parser as parser
4 |
5 |
6 | class LearningRateSummaryOp(TensorMetricOpBase):
7 | """
8 | Learning rate summary.
9 | """
10 | _properties = {
11 | "is_for_summary": True,
12 | "is_for_best_keep": False,
13 | "is_for_log": True,
14 | "valid_input_data_parsers": [
15 | parser.AudioDataParser,
16 | ],
17 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
18 | "summary_value_type": BaseSummaries.VALUE_TYPES.SCALAR,
19 | "min_max_mode": None,
20 | }
21 |
22 | def __str__(self):
23 | return "summary_lr"
24 |
25 | def build_op(self,
26 | data):
27 | res = dict()
28 | if data.learning_rate is None:
29 | pass
30 | else:
31 | res[f"learning_rate/{data.dataset_split_name}"] = data.learning_rate
32 |
33 | return res
34 |
35 | def expectation_of(self, data):
36 | pass
37 |
--------------------------------------------------------------------------------
/metrics/ops/non_tensor_ops.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import accuracy_score
2 | from sklearn.metrics import average_precision_score
3 | from sklearn.metrics import precision_score
4 | from sklearn.metrics import recall_score
5 | from sklearn.metrics import f1_score
6 | from sklearn.metrics import classification_report
7 | from overload import overload
8 |
9 | import metrics.parser as parser
10 | from metrics.funcs import topN_accuracy
11 | from metrics.ops.base_ops import NonTensorMetricOpBase
12 | from metrics.summaries import BaseSummaries
13 |
14 |
15 | class MAPMetricOp(NonTensorMetricOpBase):
16 | """
17 | Micro Mean Average Precision Metric.
18 | """
19 | _properties = {
20 | "is_for_summary": True,
21 | "is_for_best_keep": True,
22 | "is_for_log": True,
23 | "valid_input_data_parsers": [
24 | parser.AudioDataParser,
25 | ],
26 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
27 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
28 | "min_max_mode": "max",
29 | }
30 |
31 | _average_fns = {
32 | "macro": lambda t, p: average_precision_score(t, p, average="macro"),
33 | "micro": lambda t, p: average_precision_score(t, p, average="micro"),
34 | "weighted": lambda t, p: average_precision_score(t, p, average="weighted"),
35 | "samples": lambda t, p: average_precision_score(t, p, average="samples"),
36 | }
37 |
38 | def __str__(self):
39 | return "mAP_metric"
40 |
41 | @overload
42 | def build_op(self,
43 | data: parser.AudioDataParser.OutputBuildData):
44 | result = dict()
45 |
46 | for avg_name in self._average_fns:
47 | key = f"mAP/{data.dataset_split_name}/{avg_name}"
48 | result[key] = None
49 |
50 | return result
51 |
52 | @overload
53 | def evaluate(self,
54 | data: parser.AudioDataParser.OutputNonTensorData):
55 | result = dict()
56 |
57 | for avg_name, avg_fn in self._average_fns.items():
58 | key = f"mAP/{data.dataset_split_name}/{avg_name}"
59 | result[key] = avg_fn(data.labels_onehot, data.predictions_onehot)
60 |
61 | return result
62 |
63 |
64 | class AccuracyMetricOp(NonTensorMetricOpBase):
65 | """
66 | Accuracy Metric.
67 | """
68 | _properties = {
69 | "is_for_summary": True,
70 | "is_for_best_keep": True,
71 | "is_for_log": True,
72 | "valid_input_data_parsers": [
73 | parser.AudioDataParser,
74 | ],
75 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
76 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
77 | "min_max_mode": "max",
78 | }
79 |
80 | def __str__(self):
81 | return "accuracy_metric"
82 |
83 | @overload
84 | def build_op(self,
85 | data: parser.AudioDataParser.OutputBuildData):
86 | key = f"accuracy/{data.dataset_split_name}"
87 |
88 | return {
89 | key: None
90 | }
91 |
92 | @overload
93 | def evaluate(self,
94 | data: parser.AudioDataParser.OutputNonTensorData):
95 | key = f"accuracy/{data.dataset_split_name}"
96 |
97 | metric = accuracy_score(data.labels, data.predictions)
98 |
99 | return {
100 | key: metric
101 | }
102 |
103 |
104 | class Top5AccuracyMetricOp(NonTensorMetricOpBase):
105 | """
106 | Top 5 Accuracy Metric.
107 | """
108 | _properties = {
109 | "is_for_summary": True,
110 | "is_for_best_keep": True,
111 | "is_for_log": True,
112 | "valid_input_data_parsers": [
113 | parser.AudioDataParser,
114 | ],
115 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
116 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
117 | "min_max_mode": "max",
118 | }
119 |
120 | def __str__(self):
121 | return "top5_accuracy_metric"
122 |
123 | @overload
124 | def build_op(self,
125 | data: parser.AudioDataParser.OutputBuildData):
126 | key = f"top5_accuracy/{data.dataset_split_name}"
127 |
128 | return {
129 | key: None
130 | }
131 |
132 | @overload
133 | def evaluate(self,
134 | data: parser.AudioDataParser.OutputNonTensorData):
135 | key = f"top5_accuracy/{data.dataset_split_name}"
136 |
137 | metric = topN_accuracy(y_true=data.labels,
138 | y_pred_onehot=data.predictions_onehot,
139 | N=5)
140 |
141 | return {
142 | key: metric
143 | }
144 |
145 |
146 | class PrecisionMetricOp(NonTensorMetricOpBase):
147 | """
148 | Precision Metric.
149 | """
150 | _properties = {
151 | "is_for_summary": True,
152 | "is_for_best_keep": True,
153 | "is_for_log": True,
154 | "valid_input_data_parsers": [
155 | parser.AudioDataParser,
156 | ],
157 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
158 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
159 | "min_max_mode": "max",
160 | }
161 |
162 | def __str__(self):
163 | return "precision_metric"
164 |
165 | @overload
166 | def build_op(self,
167 | data: parser.AudioDataParser.OutputBuildData):
168 | result = dict()
169 |
170 | label_idxes = list(range(len(data.label_names)))
171 |
172 | for label_idx in label_idxes:
173 | label_name = data.label_names[label_idx]
174 | key = f"precision/{data.dataset_split_name}/{label_name}"
175 | result[key] = None
176 |
177 | return result
178 |
179 | @overload
180 | def evaluate(self,
181 | data: parser.AudioDataParser.OutputNonTensorData):
182 | result = dict()
183 |
184 | label_idxes = list(range(len(data.label_names)))
185 | precisions = precision_score(data.labels, data.predictions, average=None, labels=label_idxes)
186 |
187 | for label_idx in label_idxes:
188 | label_name = data.label_names[label_idx]
189 | key = f"precision/{data.dataset_split_name}/{label_name}"
190 | metric = precisions[label_idx]
191 | result[key] = metric
192 |
193 | return result
194 |
195 |
196 | class RecallMetricOp(NonTensorMetricOpBase):
197 | """
198 | Recall Metric.
199 | """
200 | _properties = {
201 | "is_for_summary": True,
202 | "is_for_best_keep": True,
203 | "is_for_log": True,
204 | "valid_input_data_parsers": [
205 | parser.AudioDataParser,
206 | ],
207 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
208 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
209 | "min_max_mode": "max",
210 | }
211 |
212 | def __str__(self):
213 | return "recall_metric"
214 |
215 | @overload
216 | def build_op(self,
217 | data: parser.AudioDataParser.OutputBuildData):
218 | result = dict()
219 |
220 | label_idxes = list(range(len(data.label_names)))
221 |
222 | for label_idx in label_idxes:
223 | label_name = data.label_names[label_idx]
224 | key = f"recall/{data.dataset_split_name}/{label_name}"
225 | result[key] = None
226 |
227 | return result
228 |
229 | @overload
230 | def evaluate(self,
231 | data: parser.AudioDataParser.OutputNonTensorData):
232 | result = dict()
233 |
234 | label_idxes = list(range(len(data.label_names)))
235 | recalls = recall_score(data.labels, data.predictions, average=None, labels=label_idxes)
236 |
237 | for label_idx in label_idxes:
238 | label_name = data.label_names[label_idx]
239 | key = f"recall/{data.dataset_split_name}/{label_name}"
240 | metric = recalls[label_idx]
241 | result[key] = metric
242 |
243 | return result
244 |
245 |
246 | class F1ScoreMetricOp(NonTensorMetricOpBase):
247 | """
248 | Per class F1-Score Metric.
249 | """
250 | _properties = {
251 | "is_for_summary": True,
252 | "is_for_best_keep": True,
253 | "is_for_log": True,
254 | "valid_input_data_parsers": [
255 | parser.AudioDataParser,
256 | ],
257 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
258 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
259 | "min_max_mode": "max",
260 | }
261 |
262 | def __str__(self):
263 | return "f1_score_metric"
264 |
265 | @overload
266 | def build_op(self,
267 | data: parser.AudioDataParser.OutputBuildData):
268 | result = dict()
269 |
270 | label_idxes = list(range(len(data.label_names)))
271 |
272 | for label_idx in label_idxes:
273 | label_name = data.label_names[label_idx]
274 | key = f"f1score/{data.dataset_split_name}/{label_name}"
275 | result[key] = None
276 |
277 | return result
278 |
279 | @overload
280 | def evaluate(self,
281 | data: parser.AudioDataParser.OutputNonTensorData):
282 | result = dict()
283 |
284 | label_idxes = list(range(len(data.label_names)))
285 | f1_scores = f1_score(data.labels, data.predictions, average=None, labels=label_idxes)
286 |
287 | for label_idx in label_idxes:
288 | label_name = data.label_names[label_idx]
289 | key = f"f1score/{data.dataset_split_name}/{label_name}"
290 | metric = f1_scores[label_idx]
291 | result[key] = metric
292 |
293 | return result
294 |
295 |
296 | class APMetricOp(NonTensorMetricOpBase):
297 | """
298 | Per class Average Precision Metric.
299 | """
300 | _properties = {
301 | "is_for_summary": True,
302 | "is_for_best_keep": True,
303 | "is_for_log": True,
304 | "valid_input_data_parsers": [
305 | parser.AudioDataParser,
306 | ],
307 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
308 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
309 | "min_max_mode": "max",
310 | }
311 |
312 | def __str__(self):
313 | return "ap_score_metric"
314 |
315 | @overload
316 | def build_op(self,
317 | data: parser.AudioDataParser.OutputBuildData):
318 | result = dict()
319 |
320 | label_idxes = list(range(len(data.label_names)))
321 |
322 | for label_idx in label_idxes:
323 | label_name = data.label_names[label_idx]
324 | key = f"ap/{data.dataset_split_name}/{label_name}"
325 | result[key] = None
326 |
327 | return result
328 |
329 | @overload
330 | def evaluate(self,
331 | data: parser.AudioDataParser.OutputNonTensorData):
332 | result = dict()
333 |
334 | label_idxes = list(range(len(data.label_names)))
335 | aps = average_precision_score(data.labels_onehot, data.predictions_onehot, average=None)
336 |
337 | for label_idx in label_idxes:
338 | label_name = data.label_names[label_idx]
339 | key = f"ap/{data.dataset_split_name}/{label_name}"
340 | metric = aps[label_idx]
341 | result[key] = metric
342 |
343 | return result
344 |
345 |
346 | class ClassificationReportMetricOp(NonTensorMetricOpBase):
347 | """
348 |     Classification Report Metric.
349 | """
350 | _properties = {
351 | "is_for_summary": False,
352 | "is_for_best_keep": False,
353 | "is_for_log": True,
354 | "valid_input_data_parsers": [
355 | parser.AudioDataParser,
356 | ],
357 | "summary_collection_key": None,
358 | "summary_value_type": None,
359 | "min_max_mode": None,
360 | }
361 |
362 | def __str__(self):
363 | return "classification_report_metric"
364 |
365 | @overload
366 | def build_op(self,
367 | data: parser.AudioDataParser.OutputBuildData):
368 | key = f"classification_report/{data.dataset_split_name}"
369 |
370 | return {
371 | key: None
372 | }
373 |
374 | @overload
375 | def evaluate(self,
376 | data: parser.AudioDataParser.OutputNonTensorData):
377 | key = f"classification_report/{data.dataset_split_name}"
378 |
379 | label_idxes = list(range(len(data.label_names)))
380 | metric = classification_report(data.labels,
381 | data.predictions,
382 | labels=label_idxes,
383 | target_names=data.label_names)
384 | metric = f"[ClassificationReport]\n{metric}"
385 |
386 | return {
387 | key: metric
388 | }
389 |
--------------------------------------------------------------------------------
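To make the expected data shapes concrete (a sketch, not repository code): the evaluator hands one-hot labels and a prediction-score matrix to the parser, which derives index labels, and each op returns a dict keyed as `metric_name/split[/label]`:

    import numpy as np

    import metrics.parser as parser
    from metrics.ops.non_tensor_ops import AccuracyMetricOp

    data = parser.AudioDataParser.parse_non_tensor_data({
        "dataset_split_name": "valid",
        "label_names": ["yes", "no", "_silence_"],
        "predictions_onehot": np.array([[0.9, 0.05, 0.05],
                                        [0.2, 0.7, 0.1]]),
        "labels_onehot": np.array([[1, 0, 0],
                                   [0, 0, 1]]),
    })
    print(AccuracyMetricOp().evaluate(data))   # {'accuracy/valid': 0.5}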
/metrics/ops/tensor_ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import metrics.parser as parser
4 | from metrics.ops.base_ops import TensorMetricOpBase
5 | from metrics.summaries import BaseSummaries
6 |
7 |
8 | class LossesMetricOp(TensorMetricOpBase):
9 | """ Loss Metric.
10 | """
11 | _properties = {
12 | "is_for_summary": True,
13 | "is_for_best_keep": True,
14 | "is_for_log": True,
15 | "valid_input_data_parsers": [
16 | parser.AudioDataParser,
17 | ],
18 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
19 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
20 | "min_max_mode": "min",
21 | }
22 |
23 | def __str__(self):
24 | return "losses"
25 |
26 | def build_op(self, data):
27 | result = dict()
28 |
29 | for loss_name, loss_op in data.losses.items():
30 | key = f"metric_loss/{data.dataset_split_name}/{loss_name}"
31 | result[key] = loss_op
32 |
33 | return result
34 |
35 | def expectation_of(self, data: np.array):
36 | assert len(data.shape) == 2
37 | return np.mean(data)
38 |
39 |
40 | class WavSummaryOp(TensorMetricOpBase):
41 | _properties = {
42 | "is_for_summary": True,
43 | "is_for_best_keep": False,
44 | "is_for_log": False,
45 | "valid_input_data_parsers": [
46 | parser.AudioDataParser,
47 | ],
48 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
49 | "summary_value_type": BaseSummaries.VALUE_TYPES.AUDIO,
50 | "min_max_mode": None,
51 | }
52 |
53 | def __str__(self):
54 | return "summary_wav"
55 |
56 | def build_op(self, data: parser.AudioDataParser.OutputBuildData):
57 | return {
58 | f"wav/{data.dataset_split_name}": data.wavs
59 | }
60 |
61 | def expectation_of(self, data):
62 | pass
63 |
--------------------------------------------------------------------------------
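Tensor metrics such as `LossesMetricOp` are fetched every batch and only reduced afterwards; `expectation_of` expects the collected values as a 2-D array and returns their mean. A toy illustration (not repository code), assuming one fetched loss value per evaluation batch:

    import numpy as np
    from metrics.ops.tensor_ops import LossesMetricOp

    # loss values fetched over three evaluation batches, shape (num_batches, 1)
    batch_losses = np.array([[0.42], [0.37], [0.51]])
    print(LossesMetricOp().expectation_of(batch_losses))   # ~0.433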
/metrics/parser.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, ABCMeta
2 |
3 | from metrics.base import DataStructure
4 |
5 |
6 | class MetricDataParserBase(ABC):
7 | @classmethod
8 | def parse_build_data(cls, data):
9 | """
10 | Args:
11 | data: dictionary which will be passed to InputBuildData
12 | """
13 | data = cls._validate_build_data(data)
14 | data = cls._process_build_data(data)
15 | return data
16 |
17 | @classmethod
18 | def parse_non_tensor_data(cls, data):
19 | """
20 | Args:
21 | data: dictionary which will be passed to InputDataStructure
22 | """
23 | input_data = cls._validate_non_tensor_data(data)
24 | output_data = cls._process_non_tensor_data(input_data)
25 | return output_data
26 |
27 | @classmethod
28 | def _validate_build_data(cls, data):
29 | """
30 |         Specify the assertions that the tensor (build) data should satisfy
31 |
32 | Args:
33 | data: dictionary
34 | Return:
35 | InputDataStructure
36 | """
37 | return cls.InputBuildData(data)
38 |
39 | @classmethod
40 | def _validate_non_tensor_data(cls, data):
41 | """
42 |         Specify the assertions that the non-tensor data should satisfy
43 |
44 | Args:
45 | data: dictionary
46 | Return:
47 | InputDataStructure
48 | """
49 | return cls.InputNonTensorData(data)
50 |
51 | """
52 | Override these two functions if needed.
53 | """
54 | @classmethod
55 | def _process_build_data(cls, data):
56 | """
57 |         Process data so that downstream metrics can use it
58 |
59 | Args:
60 | data: InputBuildData
61 |
62 | Return:
63 | OutputBuildData
64 | """
65 | # default function is just passing data
66 | return cls.OutputBuildData(data.to_dict())
67 |
68 | @classmethod
69 | def _process_non_tensor_data(cls, data):
70 | """
71 |         Process data so that downstream metrics can use it
72 |
73 | Args:
74 | data: InputNonTensorData
75 |
76 | Return:
77 | OutputNonTensorData
78 | """
79 | # default function is just passing data
80 | return cls.OutputNonTensorData(data.to_dict())
81 |
82 | """
83 |     The classes below should be implemented when inheriting.
84 | """
85 | class InputBuildData(DataStructure, metaclass=ABCMeta):
86 | pass
87 |
88 | class OutputBuildData(DataStructure, metaclass=ABCMeta):
89 | pass
90 |
91 | class InputNonTensorData(DataStructure, metaclass=ABCMeta):
92 | pass
93 |
94 | class OutputNonTensorData(DataStructure, metaclass=ABCMeta):
95 | pass
96 |
97 |
98 | class AudioDataParser(MetricDataParserBase):
99 | class InputBuildData(DataStructure):
100 | _keys = [
101 | "dataset_split_name",
102 | "label_names",
103 | "losses", # Dict | loss_key -> Tensor
104 | "learning_rate",
105 | "wavs",
106 | ]
107 |
108 | class OutputBuildData(DataStructure):
109 | _keys = [
110 | "dataset_split_name",
111 | "label_names",
112 | "losses",
113 | "learning_rate",
114 | "wavs",
115 | ]
116 |
117 | class InputNonTensorData(DataStructure):
118 | _keys = [
119 | "dataset_split_name",
120 | "label_names",
121 | "predictions_onehot",
122 | "labels_onehot",
123 | ]
124 |
125 | class OutputNonTensorData(DataStructure):
126 | _keys = [
127 | "dataset_split_name",
128 | "label_names",
129 | "predictions_onehot",
130 | "labels_onehot",
131 | "predictions",
132 | "labels",
133 | ]
134 |
135 | @classmethod
136 | def _process_non_tensor_data(cls, data):
137 | predictions = data.predictions_onehot.argmax(axis=-1)
138 | labels = data.labels_onehot.argmax(axis=-1)
139 |
140 | return cls.OutputNonTensorData({
141 | "dataset_split_name": data.dataset_split_name,
142 | "label_names": data.label_names,
143 | "predictions_onehot": data.predictions_onehot,
144 | "labels_onehot": data.labels_onehot,
145 | "predictions": predictions,
146 | "labels": labels,
147 | })
148 |
--------------------------------------------------------------------------------
/metrics/summaries.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from types import SimpleNamespace
3 | from pathlib import Path
4 |
5 | import tensorflow as tf
6 | from overload import overload
7 |
8 | from common.utils import get_logger
9 |
10 |
11 | class BaseSummaries(ABC):
12 | KEY_TYPES = SimpleNamespace(
13 | DEFAULT="SUMMARY_DEFAULT",
14 | VERBOSE="SUMMARY_VERBOSE",
15 | FIRST_N="SUMMARY_FIRST_N",
16 | )
17 |
18 | VALUE_TYPES = SimpleNamespace(
19 | SCALAR="SCALAR",
20 | PLACEHOLDER="PLACEHOLDER",
21 | AUDIO="AUDIO",
22 | NONE="NONE", # None is not used for summary
23 | )
24 |
25 | def __init__(
26 | self,
27 | session,
28 | train_dir,
29 | is_training,
30 | base_name="eval",
31 | max_summary_outputs=None,
32 | ):
33 | self.log = get_logger("Summary")
34 |
35 | self.session = session
36 | self.train_dir = train_dir
37 | self.max_summary_outputs = max_summary_outputs
38 | self.base_name = base_name
39 | self.merged_summaries = dict()
40 |
41 | self.summary_writer = None
42 | self._setup_summary_writer(is_training)
43 |
44 | def write(self, summary, global_step=0):
45 | self.summary_writer.add_summary(summary, global_step)
46 |
47 | def setup_experiment(self, config):
48 | """
49 | Args:
50 | config: Namespace
51 | """
52 | config = vars(config)
53 |
54 | sorted_config = [(k, str(v)) for k, v in sorted(config.items(), key=lambda x: x[0])]
55 | config = tf.summary.text("config", tf.convert_to_tensor(sorted_config))
56 |
57 | config_val = self.session.run(config)
58 | self.write(config_val)
59 | self.summary_writer.add_graph(tf.get_default_graph())
60 |
61 | def register_summaries(self, collection_summary_dict):
62 | """
63 | Args:
64 | collection_summary_dict: Dict[str, Dict[MetricOp, List[Tuple(metric_key, tensor_op)]]]
65 | collection_key -> metric -> List(summary_key, value)
66 | """
67 | for collection_key_suffix, metric_dict in collection_summary_dict.items():
68 | for metric, key_value_list in metric_dict.items():
69 | for summary_key, value in key_value_list:
70 | self._routine_add_summary_op(summary_value_type=metric.summary_value_type,
71 | summary_key=summary_key,
72 | value=value,
73 | collection_key_suffix=collection_key_suffix)
74 |
75 | def setup_merged_summaries(self):
76 | for collection_key_suffix in vars(self.KEY_TYPES).values():
77 | for collection_key in self._iterate_collection_keys(collection_key_suffix):
78 | merged_summary = tf.summary.merge_all(key=collection_key)
79 | self.merged_summaries[collection_key] = merged_summary
80 |
81 | def get_merged_summaries(self, collection_key_suffixes: list, is_tensor_summary: bool):
82 | summaries = []
83 |
84 | for collection_key_suffix in collection_key_suffixes:
85 | collection_key = self._build_collection_key(collection_key_suffix, is_tensor_summary)
86 | summary = self.merged_summaries[collection_key]
87 |
88 | if summary is not None:
89 | summaries.append(summary)
90 |
91 | if len(summaries) == 0:
92 | return None
93 | elif len(summaries) == 1:
94 | return summaries[0]
95 | else:
96 | return tf.summary.merge(summaries)
97 |
98 | def write_evaluation_summaries(self, step, collection_keys, collection_summary_dict):
99 | """
100 | Args:
101 | collection_summary_dict: Dict[str, Dict[MetricOp, List[Tuple(metric_key, tensor_op)]]]
102 | collection_key -> metric -> List(summary_key, value)
103 | collection_keys: List
104 | """
105 | for collection_key_suffix, metric_dict in collection_summary_dict.items():
106 | if collection_key_suffix in collection_keys:
107 | merged_summary_op = self.get_merged_summaries(collection_key_suffixes=[collection_key_suffix],
108 | is_tensor_summary=False)
109 | feed_dict = dict()
110 |
111 | for metric, key_value_list in metric_dict.items():
112 | if metric.is_placeholder_summary:
113 | for summary_key, value in key_value_list:
114 | # https://github.com/tensorflow/tensorflow/issues/3378
115 | placeholder_name = self._build_placeholder_name(summary_key) + ":0"
116 | feed_dict[placeholder_name] = value
117 |
118 | summary_value = self.session.run(merged_summary_op, feed_dict=feed_dict)
119 | self.write(summary_value, step)
120 |
121 | def _setup_summary_writer(self, is_training):
122 | summary_directory = self._build_summary_directory(is_training)
123 | self.log.info(f"Write summaries into : {summary_directory}")
124 |
125 | if is_training:
126 | self.summary_writer = tf.summary.FileWriter(summary_directory, self.session.graph)
127 | else:
128 | self.summary_writer = tf.summary.FileWriter(summary_directory)
129 |
130 | def _build_summary_directory(self, is_training):
131 | if is_training:
132 | return self.train_dir
133 | else:
134 | if Path(self.train_dir).is_dir():
135 | summary_directory = (Path(self.train_dir) / Path(self.base_name)).as_posix()
136 | else:
137 | summary_directory = (Path(self.train_dir).parent / Path(self.base_name)).as_posix()
138 |
139 | if not Path(summary_directory).exists():
140 | Path(summary_directory).mkdir(parents=True)
141 |
142 | return summary_directory
143 |
144 | def _routine_add_summary_op(self, summary_value_type, summary_key, value, collection_key_suffix):
145 | collection_key = self._build_collection_key(collection_key_suffix, summary_value_type)
146 |
147 | if summary_value_type == self.VALUE_TYPES.SCALAR:
148 | def register_fn(k, v):
149 | return tf.summary.scalar(k, v, collections=[collection_key])
150 |
151 | elif summary_value_type == self.VALUE_TYPES.AUDIO:
152 | def register_fn(k, v):
153 | return tf.summary.audio(k, v,
154 | sample_rate=16000,
155 | max_outputs=self.max_summary_outputs,
156 | collections=[collection_key])
157 |
158 | elif summary_value_type == self.VALUE_TYPES.PLACEHOLDER:
159 | def register_fn(k, v):
160 | return tf.summary.scalar(k, v, collections=[collection_key])
161 | value = self._build_placeholder(summary_key)
162 |
163 | else:
164 | raise NotImplementedError
165 |
166 | register_fn(summary_key, value)
167 |
168 | @classmethod
169 | def _build_placeholder(cls, summary_key):
170 | name = cls._build_placeholder_name(summary_key)
171 | return tf.placeholder(tf.float32, [], name=name)
172 |
173 | @staticmethod
174 | def _build_placeholder_name(summary_key):
175 | return f"non_tensor_summary_placeholder/{summary_key}"
176 |
177 |     # The two functions below should be class methods but are defined as instance
178 |     # methods because of a bug in @overload
179 | # @classmethod
180 | @overload
181 | def _build_collection_key(self, collection_key_suffix, summary_value_type: str):
182 | if summary_value_type == self.VALUE_TYPES.PLACEHOLDER:
183 | prefix = "NON_TENSOR"
184 | else:
185 | prefix = "TENSOR"
186 |
187 | return f"{prefix}_{collection_key_suffix}"
188 |
189 | # @classmethod
190 | @_build_collection_key.add
191 | def _build_collection_key(self, collection_key_suffix, is_tensor_summary: bool):
192 | if not is_tensor_summary:
193 | prefix = "NON_TENSOR"
194 | else:
195 | prefix = "TENSOR"
196 |
197 | return f"{prefix}_{collection_key_suffix}"
198 |
199 | @classmethod
200 | def _iterate_collection_keys(cls, collection_key_suffix):
201 | for prefix in ["NON_TENSOR", "TENSOR"]:
202 | yield f"{prefix}_{collection_key_suffix}"
203 |
204 |
205 | class Summaries(BaseSummaries):
206 | pass
207 |
--------------------------------------------------------------------------------
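The placeholder-summary mechanism above hinges on naming: each non-tensor metric key gets a scalar `tf.placeholder` whose name is derived from the key, and `write_evaluation_summaries` later feeds the value by tensor name with the `:0` suffix (the workaround referenced in the TensorFlow issue linked in the code). A standalone sketch of the same idea, assuming TF 1.x as pinned in the requirements:

    import tensorflow as tf

    key = "accuracy/valid"
    collection = "NON_TENSOR_SUMMARY_DEFAULT"          # "NON_TENSOR" prefix + KEY_TYPES.DEFAULT
    name = f"non_tensor_summary_placeholder/{key}"

    ph = tf.placeholder(tf.float32, [], name=name)
    tf.summary.scalar(key, ph, collections=[collection])
    merged = tf.summary.merge_all(key=collection)

    with tf.Session() as sess:
        # feed the evaluated metric value by tensor name, as write_evaluation_summaries does
        summary_value = sess.run(merged, feed_dict={name + ":0": 0.91})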
/requirements/py36-common.txt:
--------------------------------------------------------------------------------
1 | # ML
2 | numpy>=1.16.0
3 | pandas
4 | tqdm
5 | scikit-learn==0.19.1
6 | scipy
7 |
8 | # data
9 | termcolor
10 | click
11 | dask[dataframe]
12 |
13 | # misc
14 | humanfriendly
15 | overload
16 |
--------------------------------------------------------------------------------
/requirements/py36-cpu.txt:
--------------------------------------------------------------------------------
1 | # Deep learning
2 | tensorflow==1.13.1
3 | -r py36-common.txt
--------------------------------------------------------------------------------
/requirements/py36-gpu.txt:
--------------------------------------------------------------------------------
1 | # Deep learning
2 | tensorflow-gpu==1.13.1
3 | -r py36-common.txt
4 |
--------------------------------------------------------------------------------
/scripts/commands/DSCNNLModel-0_mfcc_10_4020_0.0000_adam_l3.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/DSCNNLModel-0/mfcc_10_4020_0.0000_adam_l3 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --batch_size 100 --boundaries 10000 --max_step_from_restore 20000 --lr_list 0.0005 0.0001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam DSCNNLModel &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/DSCNNLModel-0/mfcc_10_4020_0.0000_adam_l3 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 3 --no-shuffle --valid_type loop DSCNNLModel &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/DSCNNLModel-0/mfcc_10_4020_0.0000_adam_l3/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 39 --no-shuffle --valid_type once DSCNNLModel
8 |
--------------------------------------------------------------------------------
/scripts/commands/DSCNNMModel-0_mfcc_10_4020_0.0000_adam_l3.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/DSCNNMModel-0/mfcc_10_4020_0.0000_adam_l3 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --batch_size 100 --boundaries 10000 --max_step_from_restore 20000 --lr_list 0.0005 0.0001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam DSCNNMModel &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/DSCNNMModel-0/mfcc_10_4020_0.0000_adam_l3 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 3 --no-shuffle --valid_type loop DSCNNMModel &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/DSCNNMModel-0/mfcc_10_4020_0.0000_adam_l3/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 39 --no-shuffle --valid_type once DSCNNMModel
8 |
--------------------------------------------------------------------------------
/scripts/commands/DSCNNSModel-0_mfcc_10_4020_0.0000_adam_l3.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/DSCNNSModel-0/mfcc_10_4020_0.0000_adam_l3 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --batch_size 100 --boundaries 10000 --max_step_from_restore 20000 --lr_list 0.0005 0.0001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam DSCNNSModel &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/DSCNNSModel-0/mfcc_10_4020_0.0000_adam_l3 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 3 --no-shuffle --valid_type loop DSCNNSModel &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/DSCNNSModel-0/mfcc_10_4020_0.0000_adam_l3/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 39 --no-shuffle --valid_type once DSCNNSModel
8 |
--------------------------------------------------------------------------------
/scripts/commands/KWSfpool3-0_mfcc_40_4020_0.0000_adam_l3.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/KWSfpool3-0/mfcc_40_4020_0.0000_adam_l3 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --batch_size 100 --boundaries 10000 --max_step_from_restore 20000 --lr_list 0.0005 0.0001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam KWSModel --architecture trad_fpool3 &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/KWSfpool3-0/mfcc_40_4020_0.0000_adam_l3 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 3 --no-shuffle --valid_type loop KWSModel --architecture trad_fpool3 &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/KWSfpool3-0/mfcc_40_4020_0.0000_adam_l3/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 39 --no-shuffle --valid_type once KWSModel --architecture trad_fpool3
8 |
--------------------------------------------------------------------------------
/scripts/commands/KWSfstride4-0_mfcc_40_4020_0.0000_adam_l2.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/KWSfstride4-0/mfcc_40_4020_0.0000_adam_l2 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.0005 0.0001 0.00002 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam KWSModel --architecture one_fstride4 &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/KWSfstride4-0/mfcc_40_4020_0.0000_adam_l2 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop KWSModel --architecture one_fstride4 &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/KWSfstride4-0/mfcc_40_4020_0.0000_adam_l2/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once KWSModel --architecture one_fstride4
8 |
--------------------------------------------------------------------------------
/scripts/commands/Res15Model-0_mfcc_40_3010_0.00001_adam_s1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/Res15Model-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 64 --boundaries 3000 6000 --max_step_from_restore 9000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam Res15Model --weight_decay 0.00001 &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res15Model-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 3 --no-shuffle --valid_type loop Res15Model --weight_decay 0.00001 &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res15Model-0/mfcc_40_3010_0.00001_adam_s1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 39 --no-shuffle --valid_type once Res15Model --weight_decay 0.00001
8 |
--------------------------------------------------------------------------------
/scripts/commands/Res15NarrowModel-0_mfcc_40_3010_0.00001_adam_s1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/Res15NarrowModel-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 64 --boundaries 3000 6000 --max_step_from_restore 9000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam Res15NarrowModel --weight_decay 0.00001 &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res15NarrowModel-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 3 --no-shuffle --valid_type loop Res15NarrowModel --weight_decay 0.00001 &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res15NarrowModel-0/mfcc_40_3010_0.00001_adam_s1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 39 --no-shuffle --valid_type once Res15NarrowModel --weight_decay 0.00001
8 |
--------------------------------------------------------------------------------
/scripts/commands/Res8Model-0_mfcc_40_3010_0.00001_adam_s1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/Res8Model-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 64 --boundaries 3000 6000 --max_step_from_restore 9000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam Res8Model --weight_decay 0.00001 &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res8Model-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 3 --no-shuffle --valid_type loop Res8Model --weight_decay 0.00001 &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res8Model-0/mfcc_40_3010_0.00001_adam_s1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 39 --no-shuffle --valid_type once Res8Model --weight_decay 0.00001
8 |
--------------------------------------------------------------------------------
/scripts/commands/Res8NarrowModel-0_mfcc_40_3010_0.00001_adam_s1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/Res8NarrowModel-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 64 --boundaries 3000 6000 --max_step_from_restore 9000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam Res8NarrowModel --weight_decay 0.00001 &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res8NarrowModel-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 3 --no-shuffle --valid_type loop Res8NarrowModel --weight_decay 0.00001 &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res8NarrowModel-0/mfcc_40_3010_0.00001_adam_s1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 39 --no-shuffle --valid_type once Res8NarrowModel --weight_decay 0.00001
8 |
--------------------------------------------------------------------------------
/scripts/commands/TCResNet14Model-1.0_mfcc_40_3010_0.001_mom_l1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/TCResNet14Model-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer mom --momentum 0.9 TCResNet14Model --weight_decay 0.001 --width_multiplier 1.0 &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet14Model-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop TCResNet14Model --weight_decay 0.001 --width_multiplier 1.0 &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet14Model-1.0/mfcc_40_3010_0.001_mom_l1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once TCResNet14Model --weight_decay 0.001 --width_multiplier 1.0
8 |
--------------------------------------------------------------------------------
/scripts/commands/TCResNet14Model-1.5_mfcc_40_3010_0.001_mom_l1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/TCResNet14Model-1.5/mfcc_40_3010_0.001_mom_l1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer mom --momentum 0.9 TCResNet14Model --weight_decay 0.001 --width_multiplier 1.5 &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet14Model-1.5/mfcc_40_3010_0.001_mom_l1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop TCResNet14Model --weight_decay 0.001 --width_multiplier 1.5 &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet14Model-1.5/mfcc_40_3010_0.001_mom_l1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once TCResNet14Model --weight_decay 0.001 --width_multiplier 1.5
8 |
--------------------------------------------------------------------------------
/scripts/commands/TCResNet2D8Model-1.0_mfcc_40_3010_0.001_mom_l1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/ResNet2D8Model-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer mom --momentum 0.9 ResNet2D8Model --weight_decay 0.001 --width_multiplier 1.0 &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/ResNet2D8Model-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop ResNet2D8Model --weight_decay 0.001 --width_multiplier 1.0 &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/ResNet2D8Model-1.0/mfcc_40_3010_0.001_mom_l1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once ResNet2D8Model --weight_decay 0.001 --width_multiplier 1.0
8 |
--------------------------------------------------------------------------------
/scripts/commands/TCResNet2D8PoolModel-1.0_mfcc_40_3010_0.001_mom_l1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/ResNet2D8PoolModel-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer mom --momentum 0.9 ResNet2D8PoolModel --weight_decay 0.001 --width_multiplier 1.0 &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/ResNet2D8PoolModel-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop ResNet2D8PoolModel --weight_decay 0.001 --width_multiplier 1.0 &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/ResNet2D8PoolModel-1.0/mfcc_40_3010_0.001_mom_l1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once ResNet2D8PoolModel --weight_decay 0.001 --width_multiplier 1.0
8 |
--------------------------------------------------------------------------------
/scripts/commands/TCResNet8Model-1.0_mfcc_40_3010_0.001_mom_l1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer mom --momentum 0.9 TCResNet8Model --weight_decay 0.001 --width_multiplier 1.0 &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop TCResNet8Model --weight_decay 0.001 --width_multiplier 1.0 &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once TCResNet8Model --weight_decay 0.001 --width_multiplier 1.0
8 |
--------------------------------------------------------------------------------
/scripts/commands/TCResNet8Model-1.5_mfcc_40_3010_0.001_mom_l1.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT
3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/TCResNet8Model-1.5/mfcc_40_3010_0.001_mom_l1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer mom --momentum 0.9 TCResNet8Model --weight_decay 0.001 --width_multiplier 1.5 &
4 | sleep 5
5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet8Model-1.5/mfcc_40_3010_0.001_mom_l1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop TCResNet8Model --weight_decay 0.001 --width_multiplier 1.5 &
6 | wait
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet8Model-1.5/mfcc_40_3010_0.001_mom_l1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once TCResNet8Model --weight_decay 0.001 --width_multiplier 1.5
8 |
--------------------------------------------------------------------------------
/scripts/google_speech_commmands_dataset_to_our_format.py:
--------------------------------------------------------------------------------
1 | import random
2 | import argparse
3 |
4 | from pathlib import Path
5 | from collections import defaultdict
6 |
7 |
8 | UNKNOWN_WORD_LABEL = 'unknown'
9 | BACKGROUND_NOISE_DIR_NAME = '_background_noise_'
10 | RANDOM_SEED = 59185
11 | random.seed(RANDOM_SEED)
12 |
13 |
14 | def check_path_existence(path, name):
15 | assert path.exists(), (f"{name} ({path}) does not exist!")
16 |
17 |
18 | def parse_arguments():
19 | parser = argparse.ArgumentParser(description=__doc__)
20 |
21 | parser.add_argument("--input_dir", type=lambda p: Path(p), required=True,
22 | help="Directory as the result of `tar -zxvf `.")
23 | parser.add_argument("--background_noise_dir", type=lambda p: Path(p), required=True,
24 | help="Directory containing noise wav files")
25 | parser.add_argument("--test_list_fullpath", type=lambda p: Path(p), required=True,
26 |                         help="Text file which contains names of test wave files.")
27 | parser.add_argument("--valid_list_fullpath", type=lambda p: Path(p), required=True,
28 |                         help="Text file which contains names of validation wave files.")
29 | parser.add_argument("--output_dir", type=lambda p: Path(p), required=True,
30 | help="Directory which will contain the result dataset.")
31 | parser.add_argument("--wanted_words", type=str, default="",
32 |                         help="Comma-separated words to be categorized as foreground. Default '' means take all.")
33 |
34 | args = parser.parse_args()
35 |
36 | # validation check
37 | check_path_existence(args.input_dir, "Input directory")
38 | check_path_existence(args.background_noise_dir, "Background noise directory")
39 | check_path_existence(args.test_list_fullpath, "`test_list_fullpath`")
40 | check_path_existence(args.valid_list_fullpath, "`valid_list_fullpath`")
41 | assert not args.output_dir.exists() or len([p for p in args.output_dir.iterdir()]) == 0, (
42 | f"Output directory ({args.output_dir}) should be empty!")
43 |
44 | return args
45 |
46 |
47 | def get_label_and_filename(p):
48 | parts = p.parts
49 | label, filename = parts[-2], parts[-1]
50 | return label, filename
51 |
52 |
53 | def is_valid_label(label, valid_labels):
54 | if valid_labels:
55 | is_valid = label in valid_labels
56 | else:
57 | is_valid = True
58 |
59 | return is_valid
60 |
61 |
62 | def is_noise_label(label):
63 | return label == BACKGROUND_NOISE_DIR_NAME
64 |
65 |
66 | def split_files(input_dir, valid_list_fullpath, test_list_fullpath, wanted_words):
67 | # load split list
68 | with test_list_fullpath.open("r") as fr:
69 | test_names = {row.strip(): True for row in fr.readlines()}
70 |
71 | with valid_list_fullpath.open("r") as fr:
72 | valid_names = {row.strip(): True for row in fr.readlines()}
73 |
74 | # set labels
75 | if len(wanted_words) > 0:
76 | valid_labels = set(wanted_words.split(","))
77 | else:
78 | valid_labels = None
79 |
80 | labels = list()
81 | for p in input_dir.iterdir():
82 | if p.is_dir() and is_valid_label(p.name, valid_labels) and not is_noise_label(p.name):
83 | labels.append(p.name)
84 | assert len(set(labels)) == len(labels), f"{len(set(labels))} == {len(labels)}"
85 |
86 | # update valid_labels
87 | if len(wanted_words) > 0:
88 |         assert len(valid_labels) == len(labels), f"{valid_labels} != {set(labels)}"
89 | valid_labels = set(labels)
90 |
91 | # iter input directory to get all wav files
92 | samples = {
93 | 'train': defaultdict(list),
94 | 'valid': defaultdict(list),
95 | 'test': defaultdict(list),
96 | }
97 |
98 | for p in input_dir.rglob("*.wav"):
99 | label, filename = get_label_and_filename(p)
100 | if not is_noise_label(label):
101 | name = f"{label}/{filename}"
102 |
103 | if not is_valid_label(label, valid_labels):
104 | label = UNKNOWN_WORD_LABEL
105 |
106 | if test_names.get(name, False):
107 | samples["test"][label].append(p)
108 | elif valid_names.get(name, False):
109 | samples["valid"][label].append(p)
110 | else:
111 | samples["train"][label].append(p)
112 |
113 | has_unknown = all([UNKNOWN_WORD_LABEL in label_samples for split, label_samples in samples.items()])
114 | for split, label_samples in samples.items():
115 | if has_unknown:
116 | assert len(label_samples) == len(valid_labels) + 1, f"{set(label_samples)} == {valid_labels}"
117 | else:
118 | assert len(label_samples) == len(valid_labels), f"{set(label_samples)} == {valid_labels}"
119 |
120 | # number of samples
121 | num_train = sum(map(lambda kv: len(kv[1]), samples["train"].items()))
122 | num_valid = sum(map(lambda kv: len(kv[1]), samples["valid"].items()))
123 | num_test = sum(map(lambda kv: len(kv[1]), samples["test"].items()))
124 |
125 | num_samples = num_train + num_valid + num_test
126 |
127 | print(f"Num samples with train / valid / test split: {num_train} / {num_valid} / {num_test}")
128 | print(f"Total {num_samples} samples, {len(valid_labels)} labels")
129 | assert num_train > num_test
130 | assert num_train > num_valid
131 |
132 | # filtering unknown samples
133 | # the number of unknown samples -> mean of all other samples
134 | if has_unknown:
135 | mean_num_samples_per_label = dict()
136 | for split, label_samples in samples.items():
137 | s = 0
138 | c = 0
139 | for label, sample_list in label_samples.items():
140 | if label != UNKNOWN_WORD_LABEL:
141 | s += len(sample_list)
142 | c += 1
143 |
144 | m = int(s / c)
145 |
146 | unknown_samples = label_samples[UNKNOWN_WORD_LABEL]
147 | if len(unknown_samples) > m:
148 | random.shuffle(unknown_samples)
149 | label_samples[UNKNOWN_WORD_LABEL] = unknown_samples[:m]
150 |
151 | # number of samples
152 | print("After Filtered:")
153 | num_train = sum(map(lambda kv: len(kv[1]), samples["train"].items()))
154 | num_valid = sum(map(lambda kv: len(kv[1]), samples["valid"].items()))
155 | num_test = sum(map(lambda kv: len(kv[1]), samples["test"].items()))
156 |
157 | num_samples = num_train + num_valid + num_test
158 |
159 | print(f"Num samples with train / valid / test split: {num_train} / {num_valid} / {num_test}")
160 | print(f"Total {num_samples} samples, {len(valid_labels)} labels")
161 |
162 | return samples
163 |
164 |
165 | def generate_dataset(split_samples, background_noise_dir, output_dir):
166 | # make output_dir
167 | if not output_dir.exists():
168 | output_dir.mkdir(parents=True)
169 |
170 | # link splitted samples
171 | for split, label_samples in split_samples.items():
172 | base_dir = output_dir / split
173 | base_dir.mkdir()
174 |
175 | # make labels
176 | for label in label_samples:
177 | label_dir = base_dir / label
178 | label_dir.mkdir()
179 |
180 | # link samples
181 | for label, samples in label_samples.items():
182 | for sample_path in samples:
183 | # do not use label from get_label_and_filename.
184 | # we already replace some of them as UNKNOWN_WORD_LABEL
185 | old_label, filename = get_label_and_filename(sample_path)
186 |
187 | if label == UNKNOWN_WORD_LABEL:
188 | filename = f"{old_label}_{filename}"
189 |
190 | target_path = base_dir / label / filename
191 | target_path.symlink_to(sample_path)
192 |
193 | # link background_noise
194 | if split == "train":
195 | noise_dir = base_dir / BACKGROUND_NOISE_DIR_NAME
196 | noise_dir.symlink_to(background_noise_dir, target_is_directory=True)
197 |
198 | print(f"Make {base_dir} done.")
199 |
200 |
201 | if __name__ == "__main__":
202 | args = parse_arguments()
203 |
204 | samples = split_files(args.input_dir, args.valid_list_fullpath, args.test_list_fullpath, args.wanted_words)
205 | generate_dataset(samples, args.background_noise_dir, args.output_dir)
206 |
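207 | # Example invocation (hypothetical paths; the flags are those defined in
208 | # parse_arguments above, and testing_list.txt / validation_list.txt ship with the
209 | # Speech Commands v0.01 tarball):
210 | #
211 | #   python scripts/google_speech_commmands_dataset_to_our_format.py \
212 | #       --input_dir /data/speech_commands_v0.01 \
213 | #       --background_noise_dir /data/speech_commands_v0.01/_background_noise_ \
214 | #       --test_list_fullpath /data/speech_commands_v0.01/testing_list.txt \
215 | #       --valid_list_fullpath /data/speech_commands_v0.01/validation_list.txt \
216 | #       --output_dir /data/speech_commands_splitted \
217 | #       --wanted_words yes,no,up,down,left,right,on,off,stop,go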
--------------------------------------------------------------------------------
/speech_commands_dataset/README.md:
--------------------------------------------------------------------------------
1 | # Speech Commands Data Set
2 |
3 | ## About Data Set
4 | [Speech Commands Data Set](https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html) is a set of one-second .wav audio files, each containing a single spoken English word.
5 | These words are from a small set of commands, and are spoken by a variety of different speakers.
6 | The audio files are organized into folders based on the word they contain, and this data set is designed to help train simple machine learning models.
7 |
8 | ## Data Preparation
9 |
10 | To train and evaluate models on the Speech Commands Data Set using our code base, we have to download Speech Commands Data Set v0.01 and organize the data into three exclusive splits:
11 |
12 | * `train`
13 | * `valid`
14 | * `test`
15 |
16 | We offer an automated script (`download_and_split.sh`) which downloads the data set and processes it into the predefined format.
17 | ```bash
18 | bash download_and_split.sh [/path/for/dataset]
19 | ```
20 |
21 | ## Advanced Usage
22 | Below is a detailed description of how `download_and_split.sh` works.
23 |
24 | ### 1. Download Speech Commands Data Set
25 |
26 | ```bash
27 | work_dir=${1:-$(pwd)/google_speech_commands}
28 | mkdir -p ${work_dir}
29 | pushd ${work_dir}
30 | wget http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz
31 | tar xzf speech_commands_v0.01.tar.gz
32 | popd
33 | ```
34 |
35 | ### 2. Modify Data Set File Structure
36 |
37 | `speech_commands_dataset` contains three files, `train.txt`, `valid.txt` and `test.txt`, which list the utterance file paths assigned to each data set split.
38 |
39 | Our codebase requires the following file structure for the training, validation and test data:
40 |
41 | ```
42 | ─ v0.01_split
43 | ├── train
44 | │ ├── _background_noise_
45 | │ ├── down
46 | │ ├── go
47 | │ ├── left
48 | │ ├── no
49 | │ ├── off
50 | │ ├── on
51 | │ ├── right
52 | │ ├── stop
53 | │ ├── unknown
54 | │ ├── up
55 | │ └── yes
56 | ├── valid
57 | │ ├── _background_noise_
58 | │ ├── down
59 | │ ├── go
60 | │ ├── left
61 | │ ├── no
62 | │ ├── off
63 | │ ├── on
64 | │ ├── right
65 | │ ├── stop
66 | │ ├── unknown
67 | │ ├── up
68 | │ └── yes
69 | └── test
70 | ├── _background_noise_
71 | ├── down
72 | ├── go
73 | ├── left
74 | ├── no
75 | ├── off
76 | ├── on
77 | ├── right
78 | ├── stop
79 | ├── unknown
80 | ├── up
81 | └── yes
82 | ```
83 |
84 | To convert the Speech Commands Data Set to the structure shown above, run the command below.
85 | This script does not copy any utterance files; it only creates symlinks to the original dataset.
86 |
87 | ```bash
88 | output_dir=${work_dir}/splitted_data
89 | python google_speech_commmands_dataset_to_our_format_with_split.py \
90 | --input_dir `realpath ${work_dir}` \
91 | --train_list_fullpath train.txt \
92 | --valid_list_fullpath valid.txt \
93 | --test_list_fullpath test.txt \
94 | --wanted_words yes,no,up,down,left,right,on,off,stop,go \
95 | --output_dir `realpath ${output_dir}`
96 | ```
97 |
98 | `output_dir` will be used for further train and evaluation.
99 |
100 | ## Why don't we use the script provided by TensorFlow for the data split?
101 |
102 | TL;DR The script does not create a deterministic split even with a fixed random seed.
103 |
104 | TensorFlow provides [preprocessing scripts](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands) for the Speech Commands Data Set; however, while [creating the `unknown_index` dictionary](https://github.com/tensorflow/tensorflow/blob/d75adcad0d74e15893a93eb02bed447ec7971664/tensorflow/examples/speech_commands/input_data.py#L280-L294), the nondeterministic ordering of utterance files in `unknown_index` comes into play.
105 | [The `prepare_data_index` method uses only a limited number of words](https://github.com/tensorflow/tensorflow/blob/d75adcad0d74e15893a93eb02bed447ec7971664/tensorflow/examples/speech_commands/input_data.py#L315-L316) from `unknown_index` and assumes that the order of directories and files found with `gfile.Glob(search_path)` is deterministic on any platform.
106 | Unfortunately, the ordering of directories and files returned by `gfile.Glob(search_path)` is given by directory order (which can be shown with the [`ls -U`](https://unix.stackexchange.com/questions/13451/what-is-the-directory-order-of-files-in-a-directory-used-by-ls-u) command) and is not guaranteed to be the same for everybody.
107 |
108 | Because we want to ensure that anybody can reproduce our results, we provide the train, valid, and test splits in `train.txt`, `valid.txt`, and `test.txt`, which were generated by TensorFlow's preprocessing scripts in our environment.
109 |
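110 | For illustration, here is a minimal sketch (our own example, not part of TensorFlow's scripts or of our conversion scripts) of how such an ordering issue can be avoided: sorting the globbed paths makes the listing deterministic regardless of the filesystem's directory order, so a fixed random seed then produces the same selection on every machine.
111 | 
112 | ```python
113 | import random
114 | from pathlib import Path
115 | 
116 | # Glob order depends on the on-disk directory order, so sort the paths explicitly.
117 | wav_paths = sorted(Path("google_speech_commands").rglob("*.wav"))
118 | 
119 | random.seed(59185)         # e.g. the fixed seed used in scripts/google_speech_commmands_dataset_to_our_format.py
120 | random.shuffle(wav_paths)  # now reproducible on any machine
121 | ```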
--------------------------------------------------------------------------------
/speech_commands_dataset/download_and_split.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -eux
3 |
4 | work_dir=${1:-$(pwd)/google_speech_commands}
5 |
6 | # download file
7 | mkdir -p ${work_dir}
8 | pushd ${work_dir}
9 | wget http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz
10 | tar xzf speech_commands_v0.01.tar.gz
11 | popd
12 |
13 | # split data
14 | output_dir=${work_dir}/splitted_data
15 | python google_speech_commmands_dataset_to_our_format_with_split.py \
16 | --input_dir `realpath ${work_dir}` \
17 | --train_list_fullpath train.txt \
18 | --valid_list_fullpath valid.txt \
19 | --test_list_fullpath test.txt \
20 | --wanted_words yes,no,up,down,left,right,on,off,stop,go \
21 | --output_dir `realpath ${output_dir}`
22 | echo "Dataset is prepared at ${output_dir}"
23 |
--------------------------------------------------------------------------------
/speech_commands_dataset/google_speech_commmands_dataset_to_our_format_with_split.py:
--------------------------------------------------------------------------------
1 | """
2 | python speech_commands_dataset/google_speech_commmands_dataset_to_our_format_with_split.py \
3 | --input_dir "/data/data/data_speech_commands_v0.01" \
4 | --train_list_fullpath "/data/data/data_speech_commands_v0.01/newsplit/training_list.txt" \
5 | --valid_list_fullpath "/data/data/data_speech_commands_v0.01/newsplit/validation_list.txt" \
6 | --test_list_fullpath "/data/data/data_speech_commands_v0.01/newsplit/testing_list.txt" \
7 | --wanted_words "yes,no,up,down,left,right,on,off,stop,go" \
8 | --output_dir "/data/data/google_audio/data_speech_commands_v0.01"
9 | """
10 | import random
11 | import argparse
12 |
13 | from pathlib import Path
14 | from collections import defaultdict
15 |
16 |
17 | UNKNOWN_WORD_LABEL = 'unknown'
18 | BACKGROUND_NOISE_LABEL = '_silence_'
19 | BACKGROUND_NOISE_DIR_NAME = '_background_noise_'
20 |
21 |
22 | def check_path_existence(path, name):
23 | assert path.exists(), (f"{name} ({path}) does not exist!")
24 |
25 |
26 | def parse_arguments():
27 | parser = argparse.ArgumentParser(description=__doc__)
28 |
29 | parser.add_argument("--input_dir", type=lambda p: Path(p), required=True,
30 | help="Directory as the result of `tar -zxvf `.")
31 | parser.add_argument("--test_list_fullpath", type=lambda p: Path(p), required=True,
32 |                         help="Text file which contains names of test wave files.")
33 | parser.add_argument("--valid_list_fullpath", type=lambda p: Path(p), required=True,
34 |                         help="Text file which contains names of validation wave files.")
35 | parser.add_argument("--train_list_fullpath", type=lambda p: Path(p), required=True,
36 |                         help="Text file which contains names of train wave files.")
37 | parser.add_argument("--output_dir", type=lambda p: Path(p), required=True,
38 | help="Directory which will contain the result dataset.")
39 | parser.add_argument("--wanted_words", type=str, default="",
40 |                         help="Comma-separated words to be categorized as foreground. Default '' means take all.")
41 |
42 | args = parser.parse_args()
43 |
44 | # validation check
45 | check_path_existence(args.input_dir, "Input directory")
46 | check_path_existence(args.test_list_fullpath, "`test_list_fullpath`")
47 | check_path_existence(args.valid_list_fullpath, "`valid_list_fullpath`")
48 |     check_path_existence(args.train_list_fullpath, "`train_list_fullpath`")
49 | assert not args.output_dir.exists() or len([p for p in args.output_dir.iterdir()]) == 0, (
50 | f"Output directory ({args.output_dir}) should be empty!")
51 |
52 | return args
53 |
54 |
55 | def get_label_and_filename(p):
56 | parts = p.parts
57 | label, filename = parts[-2], parts[-1]
58 | return label, filename
59 |
60 |
61 | def is_valid_label(label, valid_labels):
62 | if valid_labels:
63 | is_valid = label in valid_labels
64 | else:
65 | is_valid = True
66 |
67 | return is_valid
68 |
69 |
70 | def is_noise_label(label):
71 | return label == BACKGROUND_NOISE_DIR_NAME or label == BACKGROUND_NOISE_LABEL
72 |
73 |
74 | def process_files(input_dir, train_list_fullpath, valid_list_fullpath, test_list_fullpath, wanted_words,
75 | output_dir):
76 | # load split list
77 | data = {
78 | "train": [],
79 | "valid": [],
80 | "test": [],
81 | }
82 |
83 | list_fullpath = {
84 | "train": train_list_fullpath,
85 | "valid": valid_list_fullpath,
86 | "test": test_list_fullpath,
87 | }
88 |
89 | for split in data:
90 | with list_fullpath[split].open("r") as fr:
91 | for row in fr.readlines():
92 | label, filename = get_label_and_filename(Path(row.strip()))
93 | data[split].append((label, filename))
94 |
95 | # set labels
96 | if len(wanted_words) > 0:
97 | valid_labels = set(wanted_words.split(","))
98 | else:
99 | valid_labels = None
100 |
101 | labels = list()
102 | for p in input_dir.iterdir():
103 | if p.is_dir() and is_valid_label(p.name, valid_labels) and not is_noise_label(p.name):
104 | labels.append(p.name)
105 | assert len(set(labels)) == len(labels), f"{len(set(labels))} == {len(labels)}"
106 |
107 | # update valid_labels
108 | if len(wanted_words) > 0:
109 |         assert len(valid_labels) == len(labels), f"{valid_labels} != {set(labels)}"
110 | valid_labels = set(labels)
111 | print(f"Valid Labels: {valid_labels}")
112 |
113 | # make dataset!
114 | # make output_dir
115 | if not output_dir.exists():
116 | output_dir.mkdir(parents=True)
117 |
118 | for split, lst in data.items():
119 | # for each split
120 | base_dir = output_dir / split
121 | base_dir.mkdir()
122 |
123 | # make labels for valid + unknown
124 | for label in list(valid_labels) + [UNKNOWN_WORD_LABEL]:
125 | label_dir = base_dir / label
126 | label_dir.mkdir()
127 |
128 | # link files
129 | noise_count = 0
130 | for label, filename in lst:
131 | source_path = input_dir / label / filename
132 |
133 | if is_noise_label(label):
134 | noise_count += 1
135 | else:
136 | if not is_valid_label(label, valid_labels):
137 | filename = f"{label}_{filename}"
138 | label = UNKNOWN_WORD_LABEL
139 |
140 | target_path = base_dir / label / filename
141 | target_path.symlink_to(source_path)
142 |
143 | # report number of noise
144 | print(f"[{split}] Num of silences: {noise_count}")
145 |
146 | # link noise
147 | source_noise_dir = input_dir / BACKGROUND_NOISE_DIR_NAME
148 | target_noise_dir = base_dir / BACKGROUND_NOISE_DIR_NAME
149 | target_noise_dir.symlink_to(source_noise_dir, target_is_directory=True)
150 |
151 |
152 | if __name__ == "__main__":
153 | args = parse_arguments()
154 |
155 | process_files(args.input_dir,
156 | args.train_list_fullpath,
157 | args.valid_list_fullpath,
158 | args.test_list_fullpath,
159 | args.wanted_words,
160 | args.output_dir)
161 |
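162 | # Optional sanity check (a minimal sketch, intentionally left commented out):
163 | # count the generated wav symlinks per split and label to verify the conversion.
164 | #
165 | #     for split_dir in sorted(args.output_dir.iterdir()):
166 | #         for label_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
167 | #             num_wavs = len(list(label_dir.glob("*.wav")))
168 | #             print(f"{split_dir.name}/{label_dir.name}: {num_wavs} wav files")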
--------------------------------------------------------------------------------
/tflite_tools/benchmark_model_r1.13_official:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/tflite_tools/benchmark_model_r1.13_official
--------------------------------------------------------------------------------
/tflite_tools/run_benchmark.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Usage: ./run_benchmark.sh /path/to/tflite/model.tflite [cpu_mask]
3 | function run_benchmark() {
4 | model_name=$1
5 | cpu_mask=$2
6 | echo ">>> run_benchmark $1 $2"
7 | if [ $# -eq 1 ]
8 | then
9 | res=$(adb shell /data/local/tmp/benchmark_model_r1.13_official \
10 | --graph=/data/local/tmp/${model_name} \
11 | --num_threads=1 \
12 | --warmup_runs=10 \
13 | --min_secs=0 \
14 | --num_runs=50 2>&1 >/dev/null)
15 | elif [ $# -eq 2 ]
16 | then
17 | res=$(adb shell taskset ${cpu_mask} /data/local/tmp/benchmark_model_r1.13_official \
18 | --graph=/data/local/tmp/${model_name} \
19 | --num_threads=1 \
20 | --warmup_runs=10 \
21 | --min_secs=0 \
22 | --num_runs=50 2>&1 >/dev/null)
23 | fi
24 | echo "${res}"
25 | }
26 |
27 | function run_benchmark_summary() {
28 | model_name=$1
29 | cpu_mask=$2
30 | echo ">>> run_benchmark_summary $1 $2"
31 | res=$(run_benchmark $model_name $cpu_mask | tail -n 3 | head -n 1)
32 | print_highlighted "${model_name} > ${res}"
33 | }
34 |
35 | function print_highlighted() {
36 | message=$1
37 | light_green="\033[92m"
38 | default="\033[0m"
39 | printf "${light_green}${message}${default}\n"
40 | }
41 |
42 | model_path=$1
43 | cpu_mask=$2
44 | adb push benchmark_model_r1.13_official /data/local/tmp/
45 | adb shell 'ls /data/local/tmp/benchmark_model_r1.13_official' | tr -d '\r' | xargs -n1 adb shell chmod +x
46 | adb push ${model_path} /data/local/tmp/
47 |
48 | model_name=`basename $model_path`
49 | run_benchmark_summary $model_name $cpu_mask
50 |
51 |
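52 | # Example (the model file name is a placeholder; the optional second argument is a
53 | # taskset CPU affinity mask, e.g. f0 selects CPUs 4-7 on an octa-core device):
54 | #   ./run_benchmark.sh TCResNet8Model.tflite f0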
--------------------------------------------------------------------------------
/train_audio.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import List
3 |
4 | import tensorflow as tf
5 |
6 | import const
7 | import factory.audio_nets as audio_nets
8 | import common.utils as utils
9 | from datasets.data_wrapper_base import DataWrapperBase
10 | from datasets.audio_data_wrapper import AudioDataWrapper
11 | from datasets.audio_data_wrapper import SingleLabelAudioDataWrapper
12 | from helper.base import Base
13 | from helper.trainer import TrainerBase
14 | from helper.trainer import SingleLabelAudioTrainer
15 | from factory.base import TFModel
16 | from metrics.base import MetricManagerBase
17 |
18 |
19 | def train(args):
20 | is_training = True
21 | dataset_name = args.dataset_split_name[0]
22 | session = tf.Session(config=const.TF_SESSION_CONFIG)
23 |
24 | dataset = SingleLabelAudioDataWrapper(
25 | args,
26 | session,
27 | dataset_name,
28 | is_training,
29 | )
30 | wavs, labels = dataset.get_input_and_output_op()
31 |
32 | model = eval(f"audio_nets.{args.model}")(args, dataset)
33 | model.build(wavs=wavs, labels=labels, is_training=is_training)
34 |
35 | trainer = SingleLabelAudioTrainer(
36 | model,
37 | session,
38 | args,
39 | dataset,
40 | dataset_name,
41 | )
42 |
43 | trainer.train()
44 |
45 |
46 | def parse_arguments(arguments: List[str]=None):
47 | parser = argparse.ArgumentParser(description=__doc__)
48 | subparsers = parser.add_subparsers(title="Model", description="")
49 |
50 | TFModel.add_arguments(parser)
51 | audio_nets.AudioNetModel.add_arguments(parser)
52 |
53 | for class_name in audio_nets._available_nets:
54 | subparser = subparsers.add_parser(class_name)
55 | subparser.add_argument("--model", default=class_name, type=str, help="DO NOT FIX ME")
56 | add_audio_net_arguments = eval(f"audio_nets.{class_name}.add_arguments")
57 | add_audio_net_arguments(subparser)
58 |
59 | DataWrapperBase.add_arguments(parser)
60 | AudioDataWrapper.add_arguments(parser)
61 | Base.add_arguments(parser)
62 | TrainerBase.add_arguments(parser)
63 | SingleLabelAudioTrainer.add_arguments(parser)
64 | MetricManagerBase.add_arguments(parser)
65 |
66 | args = parser.parse_args(arguments)
67 | return args
68 |
69 |
70 | if __name__ == "__main__":
71 | args = parse_arguments()
72 | log = utils.get_logger("Trainer")
73 |
74 | utils.update_train_dir(args)
75 |
76 | log.info(args)
77 | train(args)
78 |
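79 | # Usage note (abridged; see scripts/commands/ for complete flag sets): the model is
80 | # selected via a positional subcommand, so the model name and its model-specific
81 | # flags must come after all common flags, e.g. (`work/v1/example` is a placeholder
82 | # train_dir):
83 | #   python train_audio.py --dataset_path google_speech_commands/splitted_data \
84 | #       --dataset_split_name train --num_classes 12 --train_dir work/v1/example \
85 | #       TCResNet8Model --width_multiplier 1.0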
--------------------------------------------------------------------------------