├── .gitignore ├── LICENSE ├── README.md ├── audio_nets ├── ds_cnn.py ├── kws.py ├── res.py └── tc_resnet.py ├── common ├── __init__.py ├── model_loader.py ├── tf_utils.py └── utils.py ├── const.py ├── datasets ├── __init__.py ├── audio_data_wrapper.py ├── augmentation_factory.py ├── data_wrapper_base.py ├── preprocessor_factory.py └── preprocessors.py ├── evaluate_audio.py ├── execute_script.sh ├── factory ├── __init__.py ├── audio_nets.py └── base.py ├── figure └── main_figure.png ├── freeze.py ├── helper ├── __init__.py ├── base.py ├── evaluator.py └── trainer.py ├── metrics ├── __init__.py ├── base.py ├── funcs.py ├── manager.py ├── ops │ ├── __init__.py │ ├── base_ops.py │ ├── misc_ops.py │ ├── non_tensor_ops.py │ └── tensor_ops.py ├── parser.py └── summaries.py ├── requirements ├── py36-common.txt ├── py36-cpu.txt └── py36-gpu.txt ├── scripts ├── commands │ ├── DSCNNLModel-0_mfcc_10_4020_0.0000_adam_l3.sh │ ├── DSCNNMModel-0_mfcc_10_4020_0.0000_adam_l3.sh │ ├── DSCNNSModel-0_mfcc_10_4020_0.0000_adam_l3.sh │ ├── KWSfpool3-0_mfcc_40_4020_0.0000_adam_l3.sh │ ├── KWSfstride4-0_mfcc_40_4020_0.0000_adam_l2.sh │ ├── Res15Model-0_mfcc_40_3010_0.00001_adam_s1.sh │ ├── Res15NarrowModel-0_mfcc_40_3010_0.00001_adam_s1.sh │ ├── Res8Model-0_mfcc_40_3010_0.00001_adam_s1.sh │ ├── Res8NarrowModel-0_mfcc_40_3010_0.00001_adam_s1.sh │ ├── TCResNet14Model-1.0_mfcc_40_3010_0.001_mom_l1.sh │ ├── TCResNet14Model-1.5_mfcc_40_3010_0.001_mom_l1.sh │ ├── TCResNet2D8Model-1.0_mfcc_40_3010_0.001_mom_l1.sh │ ├── TCResNet2D8PoolModel-1.0_mfcc_40_3010_0.001_mom_l1.sh │ ├── TCResNet8Model-1.0_mfcc_40_3010_0.001_mom_l1.sh │ └── TCResNet8Model-1.5_mfcc_40_3010_0.001_mom_l1.sh └── google_speech_commmands_dataset_to_our_format.py ├── speech_commands_dataset ├── README.md ├── download_and_split.sh ├── google_speech_commmands_dataset_to_our_format_with_split.py ├── test.txt ├── train.txt └── valid.txt ├── tflite_tools ├── benchmark_model_r1.13_official └── run_benchmark.sh └── train_audio.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/git,python,jetbrains 2 | # Edit at https://www.gitignore.io/?templates=git,python,jetbrains 3 | 4 | ### Git ### 5 | # Created by git for backups. To disable backups in Git: 6 | # $ git config --global mergetool.keepBackup false 7 | *.orig 8 | 9 | # Created by git when using merge tools for conflicts 10 | *.BACKUP.* 11 | *.BASE.* 12 | *.LOCAL.* 13 | *.REMOTE.* 14 | *_BACKUP_*.txt 15 | *_BASE_*.txt 16 | *_LOCAL_*.txt 17 | *_REMOTE_*.txt 18 | 19 | ### JetBrains ### 20 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 21 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 22 | 23 | # User-specific stuff 24 | .idea/**/workspace.xml 25 | .idea/**/tasks.xml 26 | .idea/**/usage.statistics.xml 27 | .idea/**/dictionaries 28 | .idea/**/shelf 29 | 30 | # Generated files 31 | .idea/**/contentModel.xml 32 | 33 | # Sensitive or high-churn files 34 | .idea/**/dataSources/ 35 | .idea/**/dataSources.ids 36 | .idea/**/dataSources.local.xml 37 | .idea/**/sqlDataSources.xml 38 | .idea/**/dynamic.xml 39 | .idea/**/uiDesigner.xml 40 | .idea/**/dbnavigator.xml 41 | 42 | # Gradle 43 | .idea/**/gradle.xml 44 | .idea/**/libraries 45 | 46 | # Gradle and Maven with auto-import 47 | # When using Gradle or Maven with auto-import, you should exclude module files, 48 | # since they will be recreated, and may cause churn. 
Uncomment if using 49 | # auto-import. 50 | # .idea/modules.xml 51 | # .idea/*.iml 52 | # .idea/modules 53 | 54 | # CMake 55 | cmake-build-*/ 56 | 57 | # Mongo Explorer plugin 58 | .idea/**/mongoSettings.xml 59 | 60 | # File-based project format 61 | *.iws 62 | 63 | # IntelliJ 64 | out/ 65 | 66 | # mpeltonen/sbt-idea plugin 67 | .idea_modules/ 68 | 69 | # JIRA plugin 70 | atlassian-ide-plugin.xml 71 | 72 | # Cursive Clojure plugin 73 | .idea/replstate.xml 74 | 75 | # Crashlytics plugin (for Android Studio and IntelliJ) 76 | com_crashlytics_export_strings.xml 77 | crashlytics.properties 78 | crashlytics-build.properties 79 | fabric.properties 80 | 81 | # Editor-based Rest Client 82 | .idea/httpRequests 83 | 84 | # Android studio 3.1+ serialized cache file 85 | .idea/caches/build_file_checksums.ser 86 | 87 | ### JetBrains Patch ### 88 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 89 | 90 | # *.iml 91 | # modules.xml 92 | # .idea/misc.xml 93 | # *.ipr 94 | 95 | # Sonarlint plugin 96 | .idea/sonarlint 97 | 98 | ### Python ### 99 | # Byte-compiled / optimized / DLL files 100 | __pycache__/ 101 | *.py[cod] 102 | *$py.class 103 | 104 | # C extensions 105 | *.so 106 | 107 | # Distribution / packaging 108 | .Python 109 | build/ 110 | develop-eggs/ 111 | dist/ 112 | downloads/ 113 | eggs/ 114 | .eggs/ 115 | lib/ 116 | lib64/ 117 | parts/ 118 | sdist/ 119 | var/ 120 | wheels/ 121 | pip-wheel-metadata/ 122 | share/python-wheels/ 123 | *.egg-info/ 124 | .installed.cfg 125 | *.egg 126 | MANIFEST 127 | 128 | # PyInstaller 129 | # Usually these files are written by a python script from a template 130 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 131 | *.manifest 132 | *.spec 133 | 134 | # Installer logs 135 | pip-log.txt 136 | pip-delete-this-directory.txt 137 | 138 | # Unit test / coverage reports 139 | htmlcov/ 140 | .tox/ 141 | .nox/ 142 | .coverage 143 | .coverage.* 144 | .cache 145 | nosetests.xml 146 | coverage.xml 147 | *.cover 148 | .hypothesis/ 149 | .pytest_cache/ 150 | 151 | # Translations 152 | *.mo 153 | *.pot 154 | 155 | # Django stuff: 156 | *.log 157 | local_settings.py 158 | db.sqlite3 159 | 160 | # Flask stuff: 161 | instance/ 162 | .webassets-cache 163 | 164 | # Scrapy stuff: 165 | .scrapy 166 | 167 | # Sphinx documentation 168 | docs/_build/ 169 | 170 | # PyBuilder 171 | target/ 172 | 173 | # Jupyter Notebook 174 | .ipynb_checkpoints 175 | 176 | # IPython 177 | profile_default/ 178 | ipython_config.py 179 | 180 | # pyenv 181 | .python-version 182 | 183 | # pipenv 184 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 185 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 186 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 187 | # install all needed dependencies. 
188 | #Pipfile.lock 189 | 190 | # celery beat schedule file 191 | celerybeat-schedule 192 | 193 | # SageMath parsed files 194 | *.sage.py 195 | 196 | # Environments 197 | .env 198 | .venv 199 | env/ 200 | venv/ 201 | ENV/ 202 | env.bak/ 203 | venv.bak/ 204 | 205 | # Spyder project settings 206 | .spyderproject 207 | .spyproject 208 | 209 | # Rope project settings 210 | .ropeproject 211 | 212 | # mkdocs documentation 213 | /site 214 | 215 | # mypy 216 | .mypy_cache/ 217 | .dmypy.json 218 | dmypy.json 219 | 220 | # Pyre type checker 221 | .pyre/ 222 | 223 | # End of https://www.gitignore.io/api/git,python,jetbrains 224 | 225 | google_speech_commands/ 226 | work/ 227 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Temporal Convolution for Real-time Keyword Spotting on Mobile Devices 2 | 3 |
3 | ![tc-resnet-temporal-convolution](figure/main_figure.png)
4 |
5 |
6 |
7 | ## Abstract
8 | Keyword spotting (KWS) plays a critical role in enabling speech-based user interactions on smart devices.
9 | Recent developments in the field of deep learning have led to wide adoption of convolutional neural networks (CNNs) in KWS systems due to their exceptional accuracy and robustness.
10 | The main challenge faced by KWS systems is the trade-off between high accuracy and low latency.
11 | Unfortunately, there has been little quantitative analysis of the actual latency of KWS models on mobile devices.
12 | This is especially concerning since conventional convolution-based KWS approaches are known to require a large number of operations to attain an adequate level of performance.
13 |
14 | In this paper, we propose a temporal convolution for real-time KWS on mobile devices.
15 | Unlike most of the 2D convolution-based KWS approaches that require a deep architecture to fully capture both low- and high-frequency domains, we exploit temporal convolutions with a compact ResNet architecture.
16 | On the Google Speech Commands dataset, we achieve more than a **385x** speedup on a Google Pixel 1 while surpassing the accuracy of the state-of-the-art model.
17 | In addition, we release the implementation of the proposed and the baseline models, including an end-to-end pipeline for training models and evaluating them on mobile devices.
18 |
19 |
20 | ## Requirements
21 |
22 | * Python 3.6+
23 | * TensorFlow 1.13.1
24 |
25 | ## Installation
26 |
27 | ```bash
28 | git clone https://github.com/hyperconnect/TC-ResNet.git
29 | pip3 install -r requirements/py36-[gpu|cpu].txt
30 | ```
31 |
32 | ## Dataset
33 |
34 | For evaluating the proposed and the baseline models, we use the [Google Speech Commands Dataset](https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html).
35 |
36 | ### Google Speech Commands Dataset
37 |
38 | Follow the instructions in [speech_commands_dataset/](https://github.com/hyperconnect/TC-ResNet/tree/master/speech_commands_dataset).
39 |
40 | ## How to run
41 |
42 | Scripts to reproduce the training and evaluation procedures discussed in the paper are located in `scripts/commands/`. After training a model, you can generate a `.tflite` file by following the instructions below.
43 |
44 | To train the TCResNet8Model-1.0 model, run:
45 |
46 | ```
47 | ./scripts/commands/TCResNet8Model-1.0_mfcc_40_3010_0.001_mom_l1.sh
48 | ```
49 |
50 | To freeze the trained model checkpoint into a `.pb` file, run:
51 |
52 | ```
53 | python freeze.py --checkpoint_path work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/TCResNet8Model-XXX --output_name output/softmax --output_type softmax --preprocess_method no_preprocessing --height 49 --width 40 --channels 1 --num_classes 12 TCResNet8Model --width_multiplier 1.0
54 | ```
55 |
56 | To convert the `.pb` file into a `.tflite` file, run:
57 |
58 | ```
59 | tflite_convert --graph_def_file=work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/TCResNet8Model-XXX.pb --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --output_file=work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/TCResNet8Model-XXX.tflite --inference_type=FLOAT --inference_input_type=FLOAT --input_arrays=input --output_arrays=output/softmax --allow_custom_ops
60 | ```
61 |
62 | As shown in the commands above, you need to set `height`, `width`, the model, and model-specific arguments (e.g. `width_multiplier`) appropriately.
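Before benchmarking on a device, it can be useful to sanity-check the converted `.tflite` file on a desktop first. The snippet below is a minimal sketch (not part of the released pipeline), assuming the freeze/convert flags above, i.e. an input tensor of shape `[1, 49, 40, 1]` and a 12-class softmax output:

```python
# A quick sanity check of the converted model -- a sketch, not part of this repo.
# Assumes the flags used above: input shape [1, 49, 40, 1]
# (height x width x channels) and a 12-class softmax output.
import numpy as np
import tensorflow as tf  # TensorFlow 1.13

interpreter = tf.lite.Interpreter(
    model_path="work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/TCResNet8Model-XXX.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed a dummy MFCC feature map just to verify that the graph executes.
dummy = np.zeros(input_details[0]["shape"], dtype=np.float32)
interpreter.set_tensor(input_details[0]["index"], dummy)
interpreter.invoke()

probs = interpreter.get_tensor(output_details[0]["index"])
print(probs.shape)  # expected: (1, 12)
```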
63 | For more information, please refer to `scripts/commands/`.
64 |
65 | ## Benchmark tool
66 |
67 | [Android Debug Bridge](https://developer.android.com/studio/command-line/adb.html) (`adb`) is required to run the Android benchmark tool (`tflite_tools/run_benchmark.sh`).
68 | `adb` is part of the [Android SDK Platform Tools](https://developer.android.com/studio/releases/platform-tools); you can download it [here](https://developer.android.com/studio/releases/platform-tools.html) and follow the installation instructions.
69 |
70 | ### 1. Connect Android device to your computer
71 |
72 | ### 2. Check if connection is established
73 |
74 | Run the following command.
75 |
76 | ```bash
77 | adb devices
78 | ```
79 |
80 | You should see output similar to the one below.
81 | The ID of a device will, of course, differ.
82 |
83 | ```
84 | List of devices attached
85 | FA77M0304573 device
86 | ```
87 |
88 | ### 3. Run benchmark
89 |
90 | Go to `tflite_tools/`, place the TF Lite model you want to benchmark there (e.g. `mobilenet_v1_1.0_224.tflite`), and execute the following command.
91 | You can pass the optional parameter `cpu_mask` to set the [CPU affinity](https://github.com/tensorflow/tensorflow/tree/r1.13/tensorflow/lite/tools/benchmark#reducing-variance-between-runs-on-android).
92 |
93 |
94 | ```bash
95 | ./run_benchmark.sh TCResNet_14Model-1.5.tflite [cpu_mask]
96 | ```
97 |
98 |
99 | If everything goes well, you should see output similar to the one below.
100 | The important measurement in this benchmark is the `avg=5701.96` part.
101 | The number represents the average latency of a single inference, measured in microseconds.
102 |
103 | ```
104 | ./run_benchmark.sh TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite 3
105 | benchmark_model_r1.13_official: 1 file pushed. 22.1 MB/s (1265528 bytes in 0.055s)
106 | TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite: 1 file pushed.
25.0 MB/s (1217136 bytes in 0.046s) 107 | >>> run_benchmark_summary TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite 3 108 | TCResNet_14Model-1.5_mfcc_40_3010_0.001_mom_l1.tflite > count=50 first=5734 curr=5801 min=4847 max=6516 avg=5701.96 std=210 109 | ``` 110 | 111 | ## License 112 | 113 | [Apache License 2.0](LICENSE) 114 | -------------------------------------------------------------------------------- /audio_nets/ds_cnn.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from functools import partial 3 | 4 | import tensorflow as tf 5 | 6 | slim = tf.contrib.slim 7 | 8 | 9 | _DEFAULT = { 10 | "type": None, # ["conv", "separable"] 11 | "kernel": [3, 3], 12 | "stride": [1, 1], 13 | "depth": None, 14 | "scope": None, 15 | } 16 | 17 | _Block = namedtuple("Block", _DEFAULT.keys()) 18 | Block = partial(_Block, **_DEFAULT) 19 | 20 | S_NET_DEF = [ 21 | Block(type="conv", depth=64, kernel=[10, 4], stride=[2, 2], scope="conv_1"), 22 | Block(type="separable", depth=64, kernel=[3, 3], stride=[1, 1], scope="conv_ds_1"), 23 | Block(type="separable", depth=64, kernel=[3, 3], stride=[1, 1], scope="conv_ds_2"), 24 | Block(type="separable", depth=64, kernel=[3, 3], stride=[1, 1], scope="conv_ds_3"), 25 | Block(type="separable", depth=64, kernel=[3, 3], stride=[1, 1], scope="conv_ds_4"), 26 | ] 27 | 28 | M_NET_DEF = [ 29 | Block(type="conv", depth=172, kernel=[10, 4], stride=[2, 1], scope="conv_1"), 30 | Block(type="separable", depth=172, kernel=[3, 3], stride=[2, 2], scope="conv_ds_1"), 31 | Block(type="separable", depth=172, kernel=[3, 3], stride=[1, 1], scope="conv_ds_2"), 32 | Block(type="separable", depth=172, kernel=[3, 3], stride=[1, 1], scope="conv_ds_3"), 33 | Block(type="separable", depth=172, kernel=[3, 3], stride=[1, 1], scope="conv_ds_4"), 34 | ] 35 | 36 | L_NET_DEF = [ 37 | Block(type="conv", depth=276, kernel=[10, 4], stride=[2, 1], scope="conv_1"), 38 | Block(type="separable", depth=276, kernel=[3, 3], stride=[2, 2], scope="conv_ds_1"), 39 | Block(type="separable", depth=276, kernel=[3, 3], stride=[1, 1], scope="conv_ds_2"), 40 | Block(type="separable", depth=276, kernel=[3, 3], stride=[1, 1], scope="conv_ds_3"), 41 | Block(type="separable", depth=276, kernel=[3, 3], stride=[1, 1], scope="conv_ds_4"), 42 | Block(type="separable", depth=276, kernel=[3, 3], stride=[1, 1], scope="conv_ds_5"), 43 | ] 44 | 45 | 46 | def _depthwise_separable_conv(inputs, num_pwc_filters, kernel_size, stride): 47 | """ Helper function to build the depth-wise separable convolution layer.""" 48 | # skip pointwise by setting num_outputs=None 49 | depthwise_conv = slim.separable_convolution2d(inputs, 50 | num_outputs=None, 51 | stride=stride, 52 | depth_multiplier=1, 53 | kernel_size=kernel_size, 54 | scope="depthwise_conv") 55 | 56 | bn = slim.batch_norm(depthwise_conv, scope="dw_batch_norm") 57 | pointwise_conv = slim.conv2d(bn, 58 | num_pwc_filters, 59 | kernel_size=[1, 1], 60 | scope="pointwise_conv") 61 | bn = slim.batch_norm(pointwise_conv, scope="pw_batch_norm") 62 | return bn 63 | 64 | 65 | def parse_block(input_net, block): 66 | if block.type == "conv": 67 | net = slim.conv2d( 68 | input_net, 69 | num_outputs=block.depth, 70 | kernel_size=block.kernel, 71 | stride=block.stride, 72 | scope=block.scope 73 | ) 74 | net = slim.batch_norm(net, scope=f"{block.scope}/batch_norm") 75 | elif block.type == "separable": 76 | with tf.variable_scope(block.scope): 77 | net = _depthwise_separable_conv( 78 | input_net, 79 | block.depth, 80 | 
kernel_size=block.kernel, 81 | stride=block.stride 82 | ) 83 | else: 84 | raise ValueError(f"Block type {block.type} is not supported!") 85 | 86 | return net 87 | 88 | 89 | def DSCNN(inputs, num_classes, net_def, scope="DSCNN"): 90 | endpoints = dict() 91 | 92 | with tf.variable_scope(scope): 93 | net = inputs 94 | for block in net_def: 95 | net = parse_block(net, block) 96 | 97 | net = slim.avg_pool2d(net, kernel_size=net.shape[1:3], stride=1, scope="avg_pool") 98 | net = tf.squeeze(net, [1, 2], name="SpatialSqueeze") 99 | logits = slim.fully_connected(net, num_classes, activation_fn=None, scope="fc1") 100 | 101 | return logits, endpoints 102 | 103 | 104 | def DSCNN_arg_scope(is_training): 105 | batch_norm_params = { 106 | "is_training": is_training, 107 | "decay": 0.96, 108 | "activation_fn": tf.nn.relu, 109 | } 110 | 111 | with slim.arg_scope([slim.conv2d, slim.separable_convolution2d], 112 | activation_fn=None, 113 | weights_initializer=slim.initializers.xavier_initializer(), 114 | biases_initializer=slim.init_ops.zeros_initializer()): 115 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 116 | with slim.arg_scope([slim.dropout], 117 | is_training=is_training) as scope: 118 | return scope 119 | -------------------------------------------------------------------------------- /audio_nets/res.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | slim = tf.contrib.slim 4 | 5 | 6 | def conv_relu_bn(inputs, num_outputs, kernel_size, stride, idx, use_dilation, bn=False): 7 | scope = f"conv{idx}" 8 | with tf.variable_scope(scope, values=[inputs]): 9 | if use_dilation: 10 | assert stride == 1 11 | rate = int(2**(idx // 3)) 12 | net = slim.conv2d(inputs, 13 | num_outputs=num_outputs, 14 | kernel_size=kernel_size, 15 | stride=stride, 16 | rate=rate) 17 | else: 18 | net = slim.conv2d(inputs, 19 | num_outputs=num_outputs, 20 | kernel_size=kernel_size, 21 | stride=stride) 22 | # conv + relu are done 23 | if bn: 24 | net = slim.batch_norm(net, scope=f"{scope}_bn") 25 | 26 | return net 27 | 28 | 29 | def resnet(inputs, num_classes, num_layers, num_channels, pool_size, use_dilation, scope="Res"): 30 | """Re-implement https://github.com/castorini/honk/blob/master/utils/model.py""" 31 | endpoints = dict() 32 | 33 | with tf.variable_scope(scope): 34 | net = slim.conv2d(inputs, num_channels, kernel_size=3, stride=1, scope="f_conv") 35 | 36 | if pool_size: 37 | net = slim.avg_pool2d(net, kernel_size=pool_size, stride=1, scope="avg_pool0") 38 | 39 | # block 40 | num_blocks = num_layers // 2 41 | idx = 0 42 | for i in range(num_blocks): 43 | layer_in = net 44 | 45 | net = conv_relu_bn(net, num_outputs=num_channels, kernel_size=3, stride=1, idx=idx, 46 | use_dilation=use_dilation, bn=True) 47 | idx += 1 48 | 49 | net = conv_relu_bn(net, num_outputs=num_channels, kernel_size=3, stride=1, idx=(2 * i + 1), 50 | use_dilation=use_dilation, bn=False) 51 | idx += 1 52 | 53 | net += layer_in 54 | net = slim.batch_norm(net, scope=f"conv{2 * i + 1}_bn") 55 | 56 | if num_layers % 2 != 0: 57 | net = conv_relu_bn(net, num_outputs=num_channels, kernel_size=3, stride=1, idx=idx, 58 | use_dilation=use_dilation, bn=True) 59 | 60 | # last 61 | net = slim.avg_pool2d(net, kernel_size=net.shape[1:3], stride=1, scope="avg_pool1") 62 | 63 | logits = slim.conv2d(net, num_classes, 1, activation_fn=None, scope="fc") 64 | logits = tf.reshape(logits, shape=(-1, logits.shape[3]), name="squeeze_logit") 65 | 66 | return logits, endpoints 67 | 68 | 69 | def 
Res8(inputs, num_classes): 70 | return resnet(inputs, 71 | num_classes, 72 | num_layers=6, 73 | num_channels=45, 74 | pool_size=[4, 3], 75 | use_dilation=False) 76 | 77 | 78 | def Res8Narrow(inputs, num_classes): 79 | return resnet(inputs, 80 | num_classes, 81 | num_layers=6, 82 | num_channels=19, 83 | pool_size=[4, 3], 84 | use_dilation=False) 85 | 86 | 87 | def Res15(inputs, num_classes): 88 | return resnet(inputs, 89 | num_classes, 90 | num_layers=13, 91 | num_channels=45, 92 | pool_size=None, 93 | use_dilation=True) 94 | 95 | 96 | def Res15Narrow(inputs, num_classes): 97 | return resnet(inputs, 98 | num_classes, 99 | num_layers=13, 100 | num_channels=19, 101 | pool_size=None, 102 | use_dilation=True) 103 | 104 | 105 | def Res_arg_scope(is_training, weight_decay=0.00001): 106 | batch_norm_params = { 107 | "is_training": is_training, 108 | "center": False, 109 | "scale": False, 110 | "decay": 0.997, 111 | "fused": True, 112 | } 113 | 114 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 115 | weights_initializer=slim.initializers.xavier_initializer(), 116 | weights_regularizer=slim.l2_regularizer(weight_decay), 117 | activation_fn=tf.nn.relu, 118 | biases_initializer=None, 119 | normalizer_fn=None, 120 | padding="SAME", 121 | ): 122 | with slim.arg_scope([slim.batch_norm], **batch_norm_params) as scope: 123 | return scope 124 | -------------------------------------------------------------------------------- /audio_nets/tc_resnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | slim = tf.contrib.slim 4 | 5 | 6 | def tc_resnet(inputs, num_classes, n_blocks, n_channels, scope, debug_2d=False, pool=None): 7 | endpoints = dict() 8 | L = inputs.shape[1] 9 | C = inputs.shape[2] 10 | 11 | assert n_blocks == len(n_channels) - 1 12 | 13 | with tf.variable_scope(scope): 14 | if debug_2d: 15 | conv_kernel = first_conv_kernel = [3, 3] 16 | else: 17 | inputs = tf.reshape(inputs, [-1, L, 1, C]) # [N, L, 1, C] 18 | first_conv_kernel = [3, 1] 19 | conv_kernel = [9, 1] 20 | 21 | net = slim.conv2d(inputs, num_outputs=n_channels[0], kernel_size=first_conv_kernel, stride=1, scope="conv0") 22 | 23 | if pool is not None: 24 | net = slim.avg_pool2d(net, kernel_size=pool[0], stride=pool[1], scope="avg_pool_0") 25 | 26 | n_channels = n_channels[1:] 27 | 28 | for i, n in zip(range(n_blocks), n_channels): 29 | with tf.variable_scope(f"block{i}"): 30 | if n != net.shape[-1]: 31 | stride = 2 32 | layer_in = slim.conv2d(net, num_outputs=n, kernel_size=1, stride=stride, scope=f"down") 33 | else: 34 | layer_in = net 35 | stride = 1 36 | 37 | net = slim.conv2d(net, num_outputs=n, kernel_size=conv_kernel, stride=stride, scope=f"conv{i}_0") 38 | net = slim.conv2d(net, num_outputs=n, kernel_size=conv_kernel, stride=1, scope=f"conv{i}_1", 39 | activation_fn=None) 40 | net += layer_in 41 | net = tf.nn.relu(net) 42 | 43 | net = slim.avg_pool2d(net, kernel_size=net.shape[1:3], stride=1, scope="avg_pool") 44 | 45 | net = slim.dropout(net) 46 | 47 | logits = slim.conv2d(net, num_classes, 1, activation_fn=None, normalizer_fn=None, scope="fc") 48 | logits = tf.reshape(logits, shape=(-1, logits.shape[3]), name="squeeze_logit") 49 | 50 | ranges = slim.conv2d(net, 2, 1, activation_fn=None, normalizer_fn=None, scope="fc2") 51 | ranges = tf.reshape(ranges, shape=(-1, ranges.shape[3]), name="squeeze_logit2") 52 | endpoints["ranges"] = tf.sigmoid(ranges) 53 | 54 | return logits, endpoints 55 | 56 | 57 | def TCResNet8(inputs, num_classes, width_multiplier=1.0, 
scope="TCResNet8"): 58 | n_blocks = 3 59 | n_channels = [16, 24, 32, 48] 60 | n_channels = [int(x * width_multiplier) for x in n_channels] 61 | 62 | return tc_resnet(inputs, num_classes, n_blocks, n_channels, scope=scope) 63 | 64 | 65 | def TCResNet14(inputs, num_classes, width_multiplier=1.0, scope="TCResNet14"): 66 | n_blocks = 6 67 | n_channels = [16, 24, 24, 32, 32, 48, 48] 68 | n_channels = [int(x * width_multiplier) for x in n_channels] 69 | 70 | return tc_resnet(inputs, num_classes, n_blocks, n_channels, scope=scope) 71 | 72 | 73 | def ResNet2D8(inputs, num_classes, width_multiplier=1.0, scope="ResNet2D8"): 74 | n_blocks = 3 75 | n_channels = [16, 24, 32, 48] 76 | n_channels = [int(x * width_multiplier) for x in n_channels] 77 | 78 | # inputs: [N, L, C, 1] 79 | f = inputs.get_shape().as_list()[2] 80 | c1, c2 = n_channels[0:2] 81 | first_c = int((3 * f * c1 + 10 * c1 * c2) / (9 + 10 * c2)) 82 | n_channels[0] = first_c 83 | 84 | return tc_resnet(inputs, num_classes, n_blocks, n_channels, scope=scope, debug_2d=True) 85 | 86 | 87 | 88 | def ResNet2D8Pool(inputs, num_classes, width_multiplier=1.0, scope="ResNet2D8Pool"): 89 | n_blocks = 3 90 | n_channels = [16, 24, 32, 48] 91 | n_channels = [int(x * width_multiplier) for x in n_channels] 92 | 93 | # inputs: [N, L, C, 1] 94 | f = inputs.get_shape().as_list()[2] 95 | c1, c2 = n_channels[0:2] 96 | first_c = int((3 * f * c1 + 10 * c1 * c2) / (9 + 10 * c2)) 97 | n_channels[0] = first_c 98 | 99 | return tc_resnet(inputs, num_classes, n_blocks, n_channels, debug_2d=True, pool=([4, 4], 4), scope=scope) 100 | 101 | 102 | def TCResNet_arg_scope(is_training, weight_decay=0.001, keep_prob=0.5): 103 | batch_norm_params = { 104 | "is_training": is_training, 105 | "center": True, 106 | "scale": True, 107 | "decay": 0.997, 108 | "fused": True, 109 | } 110 | 111 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 112 | weights_initializer=slim.initializers.xavier_initializer(), 113 | weights_regularizer=slim.l2_regularizer(weight_decay), 114 | activation_fn=tf.nn.relu, 115 | biases_initializer=None, 116 | normalizer_fn=slim.batch_norm, 117 | padding="SAME", 118 | ): 119 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 120 | with slim.arg_scope([slim.dropout], 121 | keep_prob=keep_prob, 122 | is_training=is_training) as scope: 123 | return scope 124 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/common/__init__.py -------------------------------------------------------------------------------- /common/model_loader.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import tensorflow as tf 4 | from tensorflow.python import pywrap_tensorflow 5 | from tensorflow.python.ops import control_flow_ops 6 | from tensorflow.python.platform import gfile 7 | 8 | from common.utils import format_text, get_logger 9 | 10 | 11 | class Ckpt(): 12 | def __init__( 13 | self, 14 | session: tf.Session, 15 | variables_to_restore=None, 16 | include_scopes: str="", 17 | exclude_scopes: str="", 18 | ignore_missing_vars: bool=False, 19 | use_ema: bool=False, 20 | ema_decay: float=None, 21 | logger=None, 22 | ): 23 | self.session = session 24 | self.variables_to_restore = self._get_variables_to_restore( 25 | variables_to_restore, 26 | include_scopes, 27 | 
exclude_scopes, 28 | use_ema, 29 | ema_decay, 30 | ) 31 | self.ignore_missing_vars = ignore_missing_vars 32 | self.logger = logger 33 | if logger is None: 34 | self.logger = get_logger("Ckpt Loader") 35 | 36 | # variables to save reusable info from previous load 37 | self.has_previous_info = False 38 | self.grouped_vars = {} 39 | self.placeholders = {} 40 | self.assign_op = None 41 | 42 | def _get_variables_to_restore( 43 | self, 44 | variables_to_restore=None, 45 | include_scopes: str="", 46 | exclude_scopes: str="", 47 | use_ema: bool=False, 48 | ema_decay: float=None, 49 | ): 50 | # variables_to_restore might be List or Dictionary. 51 | 52 | def split_strip(scopes: str): 53 | return list(filter(lambda x: len(x) > 0, [s.strip() for s in scopes.split(",")])) 54 | 55 | def starts_with(var, scopes: List) -> bool: 56 | return any([var.op.name.startswith(prefix) for prefix in scopes]) 57 | 58 | exclusions = split_strip(exclude_scopes) 59 | inclusions = split_strip(include_scopes) 60 | 61 | if variables_to_restore is None: 62 | if use_ema: 63 | if ema_decay is None: 64 | raise ValueError("ema_decay undefined") 65 | else: 66 | ema = tf.train.ExponentialMovingAverage(decay=ema_decay) 67 | variables_to_restore = ema.variables_to_restore() # dictionary 68 | else: 69 | variables_to_restore = tf.contrib.framework.get_variables_to_restore() 70 | 71 | filtered_variables_key = variables_to_restore 72 | if len(inclusions) > 0: 73 | filtered_variables_key = filter(lambda var: starts_with(var, inclusions), filtered_variables_key) 74 | filtered_variables_key = filter(lambda var: not starts_with(var, exclusions), filtered_variables_key) 75 | 76 | if isinstance(variables_to_restore, dict): 77 | variables_to_restore = { 78 | key: variables_to_restore[key] for key in filtered_variables_key 79 | } 80 | elif isinstance(variables_to_restore, list): 81 | variables_to_restore = list(filtered_variables_key) 82 | 83 | return variables_to_restore 84 | 85 | # Copied and revised code not to create duplicated 'assign' operations everytime it gets called. 86 | # https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/contrib/framework/python/ops/variables.py#L558 87 | def load(self, checkpoint_stempath): 88 | def get_variable_full_name(var): 89 | if var._save_slice_info: 90 | return var._save_slice_info.full_name 91 | else: 92 | return var.op.name 93 | 94 | if not self.has_previous_info: 95 | if isinstance(self.variables_to_restore, (tuple, list)): 96 | for var in self.variables_to_restore: 97 | ckpt_name = get_variable_full_name(var) 98 | if ckpt_name not in self.grouped_vars: 99 | self.grouped_vars[ckpt_name] = [] 100 | self.grouped_vars[ckpt_name].append(var) 101 | 102 | else: 103 | for ckpt_name, value in self.variables_to_restore.items(): 104 | if isinstance(value, (tuple, list)): 105 | self.grouped_vars[ckpt_name] = value 106 | else: 107 | self.grouped_vars[ckpt_name] = [value] 108 | 109 | # Read each checkpoint entry. Create a placeholder variable and 110 | # add the (possibly sliced) data from the checkpoint to the feed_dict. 
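        # On calls after the first one (self.has_previous_info is True), the
        # placeholders and the grouped assign op created below are reused, so
        # loading another checkpoint does not keep adding assign operations
        # to the graph.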
111 | reader = pywrap_tensorflow.NewCheckpointReader(str(checkpoint_stempath)) 112 | feed_dict = {} 113 | assign_ops = [] 114 | for ckpt_name in self.grouped_vars: 115 | if not reader.has_tensor(ckpt_name): 116 | log_str = f"Checkpoint is missing variable [{ckpt_name}]" 117 | if self.ignore_missing_vars: 118 | self.logger.warning(log_str) 119 | continue 120 | else: 121 | raise ValueError(log_str) 122 | ckpt_value = reader.get_tensor(ckpt_name) 123 | 124 | for var in self.grouped_vars[ckpt_name]: 125 | placeholder_name = f"placeholder/{var.op.name}" 126 | if self.has_previous_info: 127 | placeholder_tensor = self.placeholders[placeholder_name] 128 | else: 129 | placeholder_tensor = tf.placeholder( 130 | dtype=var.dtype.base_dtype, 131 | shape=var.get_shape(), 132 | name=placeholder_name) 133 | assign_ops.append(var.assign(placeholder_tensor)) 134 | self.placeholders[placeholder_name] = placeholder_tensor 135 | 136 | if not var._save_slice_info: 137 | if var.get_shape() != ckpt_value.shape: 138 | raise ValueError( 139 | f"Total size of new array must be unchanged for {ckpt_name} " 140 | f"lh_shape: [{str(ckpt_value.shape)}], rh_shape: [{str(var.get_shape())}]") 141 | 142 | feed_dict[placeholder_tensor] = ckpt_value.reshape(ckpt_value.shape) 143 | else: 144 | slice_dims = zip(var._save_slice_info.var_offset, 145 | var._save_slice_info.var_shape) 146 | slice_dims = [(start, start + size) for (start, size) in slice_dims] 147 | slice_dims = [slice(*x) for x in slice_dims] 148 | slice_value = ckpt_value[slice_dims] 149 | slice_value = slice_value.reshape(var._save_slice_info.var_shape) 150 | feed_dict[placeholder_tensor] = slice_value 151 | 152 | if not self.has_previous_info: 153 | self.assign_op = control_flow_ops.group(*assign_ops) 154 | 155 | self.session.run(self.assign_op, feed_dict) 156 | 157 | if len(feed_dict) > 0: 158 | for key in feed_dict.keys(): 159 | self.logger.info(f"init from checkpoint > {key}") 160 | else: 161 | self.logger.info(f"No init from checkpoint") 162 | 163 | with format_text("cyan", attrs=["bold", "underline"]) as fmt: 164 | self.logger.info(fmt(f"Restore from {checkpoint_stempath}")) 165 | self.has_previous_info = True 166 | -------------------------------------------------------------------------------- /common/tf_utils.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import shutil 3 | from typing import Dict 4 | from functools import reduce 5 | from pathlib import Path 6 | from operator import mul 7 | 8 | import tensorflow as tf 9 | import pandas as pd 10 | import numpy as np 11 | from termcolor import colored 12 | from tensorflow.contrib.training import checkpoints_iterator 13 | from common.utils import get_logger 14 | from common.utils import wait 15 | 16 | import const 17 | 18 | 19 | def get_variables_to_train(trainable_scopes, logger): 20 | """Returns a list of variables to train. 21 | Returns: 22 | A list of variables to train by the optimizer. 
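    Note: `trainable_scopes` is a comma-separated list of variable scope
    prefixes; when it is None or empty, all trainable variables are returned.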
23 | """ 24 | if trainable_scopes is None or trainable_scopes == "": 25 | return tf.trainable_variables() 26 | else: 27 | scopes = [scope.strip() for scope in trainable_scopes.split(",")] 28 | 29 | variables_to_train = [] 30 | for scope in scopes: 31 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope) 32 | variables_to_train.extend(variables) 33 | 34 | for var in variables_to_train: 35 | logger.info("vars to train > {}".format(var.name)) 36 | 37 | return variables_to_train 38 | 39 | 40 | def show_models(logger): 41 | trainable_variables = set(tf.contrib.framework.get_variables(collection=tf.GraphKeys.TRAINABLE_VARIABLES)) 42 | all_variables = tf.contrib.framework.get_variables() 43 | trainable_vars = tf.trainable_variables() 44 | total_params = 0 45 | total_trainable_params = 0 46 | logger.info(colored(f">> Start of showing all variables", "cyan", attrs=["bold"])) 47 | for v in all_variables: 48 | is_trainable = v in trainable_variables 49 | count_params = reduce(mul, v.get_shape().as_list(), 1) 50 | total_params += count_params 51 | total_trainable_params += (count_params if is_trainable else 0) 52 | color = "cyan" if is_trainable else "green" 53 | logger.info(colored(( 54 | f">> {v.name} {v.dtype} : {v.get_shape().as_list()}, {count_params} ... {total_params} " 55 | f"(is_trainable: {is_trainable})" 56 | ), color)) 57 | logger.info(colored( 58 | f">> End of showing all variables // Number of variables: {len(all_variables)}, " 59 | f"Number of trainable variables : {len(trainable_vars)}, " 60 | f"Total prod + sum of shape: {total_params} ({total_trainable_params} trainable)", 61 | "cyan", attrs=["bold"])) 62 | return total_params 63 | 64 | 65 | def ckpt_iterator(checkpoint_dir, min_interval_secs=0, timeout=None, timeout_fn=None, logger=None): 66 | for ckpt_path in checkpoints_iterator(checkpoint_dir, min_interval_secs, timeout, timeout_fn): 67 | yield ckpt_path 68 | 69 | 70 | class BestKeeper(object): 71 | def __init__( 72 | self, 73 | metric_with_modes, 74 | dataset_name, 75 | directory, 76 | logger=None, 77 | epsilon=0.00005, 78 | score_file="scores.tsv", 79 | metric_best: Dict={}, 80 | ): 81 | """Keep best model's checkpoint by each datasets & metrics 82 | 83 | Args: 84 | metric_with_modes: Dict, metric_name: mode 85 | if mode is 'min', then it means that minimum value is best, for example loss(MSE, MAE) 86 | if mode is 'max', then it means that maximum value is best, for example Accuracy, Precision, Recall 87 | dataset_name: str, dataset name on which metric be will be calculated 88 | directory: directory path for saving best model 89 | epsilon: float, threshold for measuring the new optimum, to only focus on significant changes. 
90 | Because sometimes early-stopping gives better generalization results 91 | """ 92 | if logger is not None: 93 | self.log = logger 94 | else: 95 | self.log = get_logger("BestKeeper") 96 | 97 | self.score_file = score_file 98 | self.metric_best = metric_best 99 | 100 | self.log.info(colored(f"Initialize BestKeeper: Monitor {dataset_name} & Save to {directory}", 101 | "yellow", attrs=["underline"])) 102 | self.log.info(f"{metric_with_modes}") 103 | 104 | self.x_better_than_y = {} 105 | self.directory = Path(directory) 106 | self.output_temp_dir = self.directory / f"{dataset_name}_best_keeper_temp" 107 | 108 | for metric_name, mode in metric_with_modes.items(): 109 | if mode == "min": 110 | self.metric_best[metric_name] = self.load_metric_from_scores_tsv( 111 | directory / dataset_name / metric_name / score_file, 112 | metric_name, 113 | np.inf, 114 | ) 115 | self.x_better_than_y[metric_name] = lambda x, y: np.less(x, y - epsilon) 116 | elif mode == "max": 117 | self.metric_best[metric_name] = self.load_metric_from_scores_tsv( 118 | directory / dataset_name / metric_name / score_file, 119 | metric_name, 120 | -np.inf, 121 | ) 122 | self.x_better_than_y[metric_name] = lambda x, y: np.greater(x, y + epsilon) 123 | else: 124 | raise ValueError(f"Unsupported mode : {mode}") 125 | 126 | def load_metric_from_scores_tsv( 127 | self, 128 | full_path: Path, 129 | metric_name: str, 130 | default_value: float, 131 | ) -> float: 132 | def parse_scores(s: str): 133 | if len(s) > 0: 134 | return float(s) 135 | else: 136 | return default_value 137 | 138 | if full_path.exists(): 139 | with open(full_path, "r") as f: 140 | header = f.readline().strip().split("\t") 141 | values = list(map(parse_scores, f.readline().strip().split("\t"))) 142 | metric_index = header.index(metric_name) 143 | 144 | return values[metric_index] 145 | else: 146 | return default_value 147 | 148 | def monitor(self, dataset_name, eval_scores): 149 | metrics_keep = {} 150 | is_keep = False 151 | for metric_name, score in self.metric_best.items(): 152 | score = eval_scores[metric_name] 153 | if self.x_better_than_y[metric_name](score, self.metric_best[metric_name]): 154 | old_score = self.metric_best[metric_name] 155 | self.metric_best[metric_name] = score 156 | metrics_keep[metric_name] = True 157 | is_keep = True 158 | self.log.info(colored("[KeepBest] {} {:.6f} -> {:.6f}, so keep it!".format( 159 | metric_name, old_score, score), "blue", attrs=["underline"])) 160 | else: 161 | metrics_keep[metric_name] = False 162 | return is_keep, metrics_keep 163 | 164 | def save_best(self, dataset_name, metrics_keep, ckpt_glob): 165 | for metric_name, is_keep in metrics_keep.items(): 166 | if is_keep: 167 | keep_path = self.directory / Path(dataset_name) / Path(metric_name) 168 | self.keep_checkpoint(keep_path, ckpt_glob) 169 | self.keep_converted_files(keep_path) 170 | 171 | def save_scores(self, dataset_name, metrics_keep, eval_scores, meta_info=None): 172 | eval_scores_with_meta = eval_scores.copy() 173 | if meta_info is not None: 174 | eval_scores_with_meta.update(meta_info) 175 | 176 | for metric_name, is_keep in metrics_keep.items(): 177 | if is_keep: 178 | keep_path = self.directory / Path(dataset_name) / Path(metric_name) 179 | if not keep_path.exists(): 180 | keep_path.mkdir(parents=True) 181 | df = pd.DataFrame(pd.Series(eval_scores_with_meta)).sort_index().transpose() 182 | df.to_csv(keep_path / self.score_file, sep="\t", index=False, float_format="%.5f") 183 | 184 | def remove_old_best(self, dataset_name, metrics_keep): 185 | for 
metric_name, is_keep in metrics_keep.items(): 186 | if is_keep: 187 | keep_path = self.directory / Path(dataset_name) / Path(metric_name) 188 | # Remove old directory to save space 189 | if keep_path.exists(): 190 | shutil.rmtree(str(keep_path)) 191 | keep_path.mkdir(parents=True) 192 | 193 | def keep_checkpoint(self, keep_dir, ckpt_glob): 194 | if not isinstance(keep_dir, Path): 195 | keep_dir = Path(keep_dir) 196 | 197 | # .data-00000-of-00001, .meta, .index 198 | for ckpt_path in ckpt_glob.parent.glob(ckpt_glob.name): 199 | shutil.copy(str(ckpt_path), str(keep_dir)) 200 | 201 | with open(keep_dir / "checkpoint", "w") as f: 202 | f.write(f'model_checkpoint_path: "{Path(ckpt_path.name).stem}"') # noqa 203 | 204 | def keep_converted_files(self, keep_path): 205 | if not isinstance(keep_path, Path): 206 | keep_path = Path(keep_path) 207 | 208 | for path in self.output_temp_dir.glob("*"): 209 | if path.is_dir(): 210 | shutil.copytree(str(path), str(keep_path / path.name)) 211 | else: 212 | shutil.copy(str(path), str(keep_path / path.name)) 213 | 214 | def remove_temp_dir(self): 215 | if self.output_temp_dir.exists(): 216 | shutil.rmtree(str(self.output_temp_dir)) 217 | 218 | 219 | def resolve_checkpoint_path(checkpoint_path, log, is_training): 220 | if checkpoint_path is not None and Path(checkpoint_path).is_dir(): 221 | old_ckpt_path = checkpoint_path 222 | checkpoint_path = tf.train.latest_checkpoint(old_ckpt_path) 223 | if not is_training: 224 | def stop_checker(): 225 | return (tf.train.latest_checkpoint(old_ckpt_path) is not None) 226 | wait("There are no checkpoint file yet", stop_checker) # wait until checkpoint occurs 227 | checkpoint_path = tf.train.latest_checkpoint(old_ckpt_path) 228 | log.info(colored( 229 | "self.args.checkpoint_path updated: {} -> {}".format(old_ckpt_path, checkpoint_path), 230 | "yellow", attrs=["bold"])) 231 | else: 232 | log.info(colored("checkpoint_path is {}".format(checkpoint_path), "yellow", attrs=["bold"])) 233 | 234 | return checkpoint_path 235 | 236 | 237 | def get_global_step_from_checkpoint(checkpoint_path): 238 | """It is assumed that `checkpoint_path` is path to checkpoint file, not path to directory 239 | with checkpoint files. 
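    For example, a checkpoint path whose stem ends in `-150000`
    (e.g. `TCResNet8Model-150000`) yields a global step of 150000.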
240 | In case checkpoint path is not defined, 0 is returned.""" 241 | if checkpoint_path is None or checkpoint_path == "": 242 | return 0 243 | else: 244 | if "-" in Path(checkpoint_path).stem: 245 | return int(Path(checkpoint_path).stem.split("-")[-1]) 246 | else: 247 | return 0 248 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import humanfriendly as hf 3 | import contextlib 4 | import argparse 5 | import logging 6 | import getpass 7 | import shutil 8 | import json 9 | import time 10 | from pathlib import Path 11 | from types import SimpleNamespace 12 | from datetime import datetime 13 | 14 | from tensorflow.python.platform import tf_logging 15 | import click 16 | from termcolor import colored 17 | 18 | 19 | LOG_FORMAT = "[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s > %(message)s" 20 | 21 | 22 | def update_train_dir(args): 23 | def replace_func(base_string, a, b): 24 | replaced_string = base_string.replace(a, b) 25 | print(colored("[update_train_dir] replace {} : {} -> {}".format(a, base_string, replaced_string), 26 | "yellow")) 27 | return replaced_string 28 | 29 | def make_placeholder(s: str, circumfix: str="%"): 30 | return circumfix + s.upper() + circumfix 31 | 32 | placeholder_mapping = { 33 | make_placeholder("DATE"): datetime.now().strftime("%y%m%d%H%M%S"), 34 | make_placeholder("USER"): getpass.getuser(), 35 | } 36 | 37 | for key, value in placeholder_mapping.items(): 38 | args.train_dir = replace_func(args.train_dir, key, value) 39 | 40 | unknown = "UNKNOWN" 41 | for key, value in vars(args).items(): 42 | key_placeholder = make_placeholder(key) 43 | if key_placeholder in args.train_dir: 44 | replace_value = value 45 | if isinstance(replace_value, str): 46 | if "/" in replace_value: 47 | replace_value = unknown 48 | elif isinstance(replace_value, list): 49 | replace_value = ",".join(map(str, replace_value)) 50 | elif isinstance(replace_value, float) or isinstance(replace_value, int): 51 | replace_value = str(replace_value) 52 | elif isinstance(replace_value, bool): 53 | replace_value = str(replace_value) 54 | else: 55 | replace_value = unknown 56 | args.train_dir = replace_func(args.train_dir, key_placeholder, replace_value) 57 | 58 | print(colored("[update_train_dir] final train_dir {}".format(args.train_dir), 59 | "yellow", attrs=["bold", "underline"])) 60 | 61 | 62 | def positive_int(value): 63 | ivalue = int(value) 64 | if ivalue <= 0: 65 | raise argparse.ArgumentTypeError("%s is an invalid positive int value" % value) 66 | return ivalue 67 | 68 | 69 | def get_logger(logger_name=None, log_file: Path=None, level=logging.DEBUG): 70 | # "log/data-pipe-{}.log".format(datetime.now().strftime("%Y-%m-%d-%H-%M-%S")) 71 | if logger_name is None: 72 | logger = tf_logging._get_logger() 73 | else: 74 | logger = logging.getLogger(logger_name) 75 | 76 | if not logger.hasHandlers(): 77 | formatter = logging.Formatter(LOG_FORMAT) 78 | 79 | logger.setLevel(level) 80 | 81 | if log_file is not None: 82 | log_file.parent.mkdir(parents=True, exist_ok=True) 83 | fileHandler = logging.FileHandler(log_file, mode="w") 84 | fileHandler.setFormatter(formatter) 85 | logger.addHandler(fileHandler) 86 | 87 | streamHandler = logging.StreamHandler() 88 | streamHandler.setFormatter(formatter) 89 | logger.addHandler(streamHandler) 90 | 91 | return logger 92 | 93 | 94 | def format_timespan(duration): 95 | if duration < 10: 96 | readable_duration = 
"{:.1f} (ms)".format(duration * 1000) 97 | else: 98 | readable_duration = hf.format_timespan(duration) 99 | return readable_duration 100 | 101 | 102 | @contextlib.contextmanager 103 | def timer(name): 104 | st = time.time() 105 | yield 106 | print(" {} : {}".format(name, format_timespan(time.time() - st))) 107 | 108 | 109 | def timeit(method): 110 | def timed(*args, **kw): 111 | hf_timer = hf.Timer() 112 | result = method(*args, **kw) 113 | print(" {!r} ({!r}, {!r}) {}".format(method.__name__, args, kw, hf_timer.rounded)) 114 | return result 115 | return timed 116 | 117 | 118 | class Timer(object): 119 | def __init__(self, log): 120 | self.log = log 121 | 122 | @contextlib.contextmanager 123 | def __call__(self, name, log_func=None): 124 | """ 125 | Example. 126 | timer = Timer(log) 127 | with timer("Some Routines"): 128 | routine1() 129 | routine2() 130 | """ 131 | if log_func is None: 132 | log_func = self.log.info 133 | 134 | start = time.clock() 135 | yield 136 | end = time.clock() 137 | duration = end - start 138 | readable_duration = format_timespan(duration) 139 | log_func(f"{name} :: {readable_duration}") 140 | 141 | 142 | class TextFormatter(object): 143 | def __init__(self, color, attrs): 144 | self.color = color 145 | self.attrs = attrs 146 | 147 | def __call__(self, string): 148 | return colored(string, self.color, attrs=self.attrs) 149 | 150 | 151 | class LogFormatter(object): 152 | def __init__(self, log, color, attrs): 153 | self.log = log 154 | self.color = color 155 | self.attrs = attrs 156 | 157 | def __call__(self, string): 158 | return self.log(colored(string, self.color, attrs=self.attrs)) 159 | 160 | 161 | @contextlib.contextmanager 162 | def format_text(color, attrs=None): 163 | yield TextFormatter(color, attrs) 164 | 165 | 166 | def format_log(log, color, attrs=None): 167 | return LogFormatter(log, color, attrs) 168 | 169 | def wait(message, stop_checker_closure): 170 | assert callable(stop_checker_closure) 171 | st = time.time() 172 | while True: 173 | try: 174 | time_pass = hf.format_timespan(int(time.time() - st)) 175 | sys.stdout.write(colored(( 176 | f"{message}. Do you wanna wait? If not, then ctrl+c! 
:: waiting time: {time_pass}\r" 177 | ), "yellow", attrs=["bold"])) 178 | sys.stdout.flush() 179 | time.sleep(1) 180 | if stop_checker_closure(): 181 | break 182 | except KeyboardInterrupt: 183 | break 184 | 185 | 186 | class MLNamespace(SimpleNamespace): 187 | def __init__(self, *args, **kwargs): 188 | for kwarg in kwargs.keys(): 189 | assert kwarg not in dir(self) 190 | super().__init__(*args, **kwargs) 191 | 192 | def unordered_values(self): 193 | return list(vars(self).values()) 194 | 195 | def __setitem__(self, key, value): 196 | setattr(self, key, value) 197 | -------------------------------------------------------------------------------- /const.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | TF_SESSION_CONFIG = tf.ConfigProto( 5 | gpu_options=tf.GPUOptions(allow_growth=True), 6 | log_device_placement=False, 7 | device_count={"GPU": 1}) 8 | NULL_CLASS_LABEL = "__null__" 9 | BACKGROUND_NOISE_DIR_NAME = "_background_noise_" 10 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/audio_data_wrapper.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import tensorflow as tf 4 | 5 | import const 6 | from datasets.data_wrapper_base import DataWrapperBase 7 | from datasets.augmentation_factory import get_audio_augmentation_fn 8 | 9 | 10 | class AudioDataWrapper(DataWrapperBase): 11 | def __init__( 12 | self, args, session, dataset_split_name, is_training, name: str="AudioDataWrapper" 13 | ): 14 | super().__init__(args, dataset_split_name, is_training, name) 15 | self.setup() 16 | 17 | self.setup_dataset(self.placeholders) 18 | self.setup_iterator( 19 | session, 20 | self.placeholders, 21 | self.data, 22 | ) 23 | 24 | def set_fileformat(self, filenames): 25 | file_formats = set(Path(fn).suffix for fn in filenames) 26 | assert len(file_formats) == 1 27 | self.file_format = file_formats.pop()[1:] # decode_audio receives without .(dot) 28 | 29 | @property 30 | def num_samples(self): 31 | return self._num_samples 32 | 33 | def augment_audio(self, filename, desired_samples, file_format, sample_rate, **kwargs): 34 | aug_fn = get_audio_augmentation_fn(self.args.augmentation_method) 35 | return aug_fn(filename, desired_samples, file_format, sample_rate, **kwargs) 36 | 37 | def _parse_function(self, filename, label): 38 | """ 39 | `filename` tensor holds full path of input 40 | `label` tensor is holding index of class to which it belongs to 41 | """ 42 | desired_samples = int(self.args.sample_rate * self.args.clip_duration_ms / 1000) 43 | 44 | # augment 45 | augmented_audio = self.augment_audio( 46 | filename, 47 | desired_samples, 48 | self.file_format, 49 | self.args.sample_rate, 50 | background_data=self.background_data, 51 | is_training=self.is_training, 52 | background_frequency=self.background_frequency, 53 | background_max_volume=self.background_max_volume, 54 | ) 55 | 56 | label_parsed = self.parse_label(label) 57 | 58 | return augmented_audio, label_parsed 59 | 60 | @staticmethod 61 | def add_arguments(parser): 62 | g = parser.add_argument_group("(AudioDataWrapper) Arguments for Audio DataWrapper") 63 | 
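# NOTE: a worked example of how the two arguments below combine in
# `_parse_function` above: desired_samples = int(sample_rate * clip_duration_ms / 1000),
# so the defaults (16000 Hz, 1000 ms) give int(16000 * 1000 / 1000) = 16000
# samples, i.e. exactly one second of audio per example.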
g.add_argument( 64 | "--sample_rate", 65 | type=int, 66 | default=16000, 67 | help="Expected sample rate of the wavs",) 68 | g.add_argument( 69 | "--clip_duration_ms", 70 | type=int, 71 | default=1000, 72 | help=("Expected duration in milliseconds of the wavs; " 73 | "the audio will be cropped or padded with zeroes based on this value"),) 74 | g.add_argument( 75 | "--window_size_ms", 76 | type=float, 77 | default=30.0, 78 | help="How long each spectrogram timeslice is.",) 79 | g.add_argument( 80 | "--window_stride_ms", 81 | type=float, 82 | default=10.0, 83 | help="How far to move in time between spectrogram timeslices.",) 84 | 85 | # {{ -- Arguments for log-mel spectrograms 86 | # Default values come from the official TensorFlow tutorial 87 | g.add_argument("--lower_edge_hertz", type=float, default=80.0) 88 | g.add_argument("--upper_edge_hertz", type=float, default=7600.0) 89 | g.add_argument("--num_mel_bins", type=int, default=64) 90 | # Arguments for log-mel spectrograms -- }} 91 | 92 | # {{ -- Arguments for mfcc 93 | # Google speech_commands sample uses num_mfccs=40 as a default value 94 | # Official signal processing tutorial uses num_mfccs=13 as a default value 95 | g.add_argument("--num_mfccs", type=int, default=40) 96 | # Arguments for mfcc -- }} 97 | 98 | g.add_argument("--input_file", default=None, type=str) 99 | g.add_argument("--description_file", default=None, type=str) 100 | g.add_argument("--num_partitions", default=2, type=int, 101 | help=("Number of partitions into which the input csv file is split " 102 | "and processed in parallel")) 103 | 104 | # background noise 105 | g.add_argument("--background_max_volume", default=0.1, type=float, 106 | help="How loud the background noise should be, between 0 and 1.") 107 | g.add_argument("--background_frequency", default=0.8, type=float, 108 | help="Fraction of the training samples that have background noise mixed in.") 109 | g.add_argument("--num_silent", default=-1, type=int, 110 | help="How many silent samples should be added.
-1 means automatically calculated.") 111 | 112 | 113 | class SingleLabelAudioDataWrapper(AudioDataWrapper): 114 | def parse_label(self, label): 115 | return tf.sparse_to_dense(sparse_indices=tf.cast(label, tf.int32), 116 | sparse_values=tf.ones([1], tf.float32), 117 | output_shape=[self.num_labels], 118 | validate_indices=False) 119 | 120 | def setup(self): 121 | dataset_paths = self.get_all_dataset_paths() 122 | self.label_names, self.num_labels = self.get_label_names(dataset_paths) 123 | assert const.NULL_CLASS_LABEL in self.label_names 124 | assert self.args.num_classes == self.num_labels 125 | 126 | self.filenames, self.labels = self.get_filenames_labels(dataset_paths) 127 | self.set_fileformat(self.filenames) 128 | 129 | # add dummy data for silent class 130 | self.background_max_volume = tf.constant(self.args.background_max_volume) 131 | self.background_frequency = tf.constant(self.args.background_frequency) 132 | self.background_data = self.prepare_silent_data(dataset_paths) 133 | self.add_silent_data() 134 | self._num_samples = self.count_samples(self.filenames) 135 | 136 | self.data = (self.filenames, self.labels) 137 | 138 | self.filenames_placeholder = tf.placeholder(tf.string, self._num_samples) 139 | self.labels_placeholder = tf.placeholder(tf.int32, self._num_samples) 140 | self.placeholders = (self.filenames_placeholder, self.labels_placeholder) 141 | 142 | # shuffle 143 | if self.shuffle: 144 | self.data = self.do_shuffle(*self.data) 145 | 146 | def prepare_silent_data(self, dataset_paths): 147 | def _gen(filename): 148 | filename = tf.constant(filename, dtype=tf.string) 149 | read_fn = get_audio_augmentation_fn("no_augmentation_audio") 150 | desired_samples = -1 # read all 151 | wav_data = read_fn(filename, desired_samples, self.file_format, self.args.sample_rate) 152 | return wav_data 153 | 154 | background_data = list() 155 | for dataset_path in dataset_paths: 156 | for label_path in dataset_path.iterdir(): 157 | if label_path.name == const.BACKGROUND_NOISE_DIR_NAME: 158 | for wav_fullpath in label_path.glob("*.wav"): 159 | background_data.append(_gen(str(wav_fullpath))) 160 | 161 | self.log.info(f"{len(background_data)} background files are loaded.") 162 | return background_data 163 | 164 | def add_silent_data(self): 165 | num_silent = self.args.num_silent 166 | if self.args.num_silent < 0: 167 | num_samples = self.count_samples(self.filenames) 168 | num_silent = num_samples // self.num_labels 169 | 170 | label_idx = self.label_names.index(const.NULL_CLASS_LABEL) 171 | for _ in range(num_silent): 172 | self.filenames.append("") 173 | self.labels.append(label_idx) 174 | self.log.info(f"{num_silent} silent samples will be added.") 175 | -------------------------------------------------------------------------------- /datasets/augmentation_factory.py: -------------------------------------------------------------------------------- 1 | from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio 2 | import tensorflow as tf 3 | 4 | 5 | _available_audio_augmentation_methods = [ 6 | "anchored_slice_or_pad", 7 | "anchored_slice_or_pad_with_shift", 8 | "no_augmentation_audio", 9 | ] 10 | 11 | 12 | _available_augmentation_methods = ( 13 | _available_audio_augmentation_methods + 14 | ["no_augmentation"] 15 | ) 16 | 17 | 18 | def no_augmentation(x): 19 | return x 20 | 21 | 22 | def _gen_random_from_zero(maxval, dtype=tf.float32): 23 | return tf.random.uniform([], maxval=maxval, dtype=dtype) 24 | 25 | 26 | def _gen_empty_audio(desired_samples): 27 | return 
tf.zeros([desired_samples, 1], dtype=tf.float32) 28 | 29 | 30 | def _mix_background( 31 | audio, 32 | desired_samples, 33 | background_data, 34 | is_silent, 35 | is_training, 36 | background_frequency, 37 | background_max_volume, 38 | naive_version=True, 39 | **kwargs 40 | ): 41 | """ 42 | Args: 43 | audio: Tensor of audio. 44 | desired_samples: int value of desired length. 45 | background_data: List of background audios. 46 | is_silent: Tensor[Bool]. 47 | is_training: Tensor[Bool]. 48 | background_frequency: probability of mixing background. [0.0, 1.0] 49 | background_max_volume: scaling factor of mixing background. [0.0, 1.0] 50 | """ 51 | foreground_wav = tf.cond( 52 | is_silent, 53 | true_fn=lambda: _gen_empty_audio(desired_samples), 54 | false_fn=lambda: tf.identity(audio) 55 | ) 56 | 57 | # sampling background 58 | random_background_data_idx = _gen_random_from_zero( 59 | len(background_data), 60 | dtype=tf.int32 61 | ) 62 | background_wav = tf.case({ 63 | tf.equal(background_data_idx, random_background_data_idx): 64 | lambda tensor=wav: tensor 65 | for background_data_idx, wav in enumerate(background_data) 66 | }, exclusive=True) 67 | background_wav = tf.random_crop(background_wav, [desired_samples, 1]) 68 | 69 | if naive_version: 70 | # Version 1 71 | # https://github.com/ARM-software/ML-KWS-for-MCU/blob/master/input_data.py#L461 72 | if is_training: 73 | background_volume = tf.cond( 74 | tf.less(_gen_random_from_zero(1.0), background_frequency), 75 | true_fn=lambda: _gen_random_from_zero(background_max_volume), 76 | false_fn=lambda: 0.0, 77 | ) 78 | else: 79 | background_volume = 0.0 80 | else: 81 | # Version 2 82 | # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/speech_commands/input_data.py#L570 83 | background_volume = tf.cond( 84 | tf.logical_or(is_training, is_silent), 85 | true_fn=lambda: tf.cond( 86 | is_silent, 87 | true_fn=lambda: _gen_random_from_zero(1.0), 88 | false_fn=lambda: tf.cond( 89 | tf.less(_gen_random_from_zero(1.0), background_frequency), 90 | true_fn=lambda: _gen_random_from_zero(background_max_volume), 91 | false_fn=lambda: 0.0, 92 | ), 93 | ), 94 | false_fn=lambda: 0.0, 95 | ) 96 | 97 | background_wav = tf.multiply(background_wav, background_volume) 98 | background_added = tf.add(background_wav, foreground_wav) 99 | augmented_audio = tf.clip_by_value(background_added, -1.0, 1.0) 100 | 101 | return augmented_audio 102 | 103 | 104 | def _shift_audio(audio, desired_samples, shift_ratio=0.1): 105 | time_shift = int(desired_samples * shift_ratio) 106 | time_shift_amount = tf.random.uniform( 107 | [], 108 | minval=-time_shift, 109 | maxval=time_shift, 110 | dtype=tf.int32 111 | ) 112 | 113 | time_shift_abs = tf.abs(time_shift_amount) 114 | 115 | def _pos_padding(): 116 | return [[time_shift_amount, 0], [0, 0]] 117 | 118 | def _pos_offset(): 119 | return [0, 0] 120 | 121 | def _neg_padding(): 122 | return [[0, time_shift_abs], [0, 0]] 123 | 124 | def _neg_offset(): 125 | return [time_shift_abs, 0] 126 | 127 | padded_audio = tf.pad( 128 | audio, 129 | tf.cond(tf.greater_equal(time_shift_amount, 0), 130 | true_fn=_pos_padding, 131 | false_fn=_neg_padding), 132 | mode="CONSTANT", 133 | ) 134 | 135 | sliced_audio = tf.slice( 136 | padded_audio, 137 | tf.cond(tf.greater_equal(time_shift_amount, 0), 138 | true_fn=_pos_offset, 139 | false_fn=_neg_offset), 140 | [desired_samples, 1], 141 | ) 142 | 143 | return sliced_audio 144 | 145 | 146 | def _load_wav_file(filename, desired_samples, file_format): 147 | if file_format == "wav": 148 | 
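# NOTE: decode_wav parses the RIFF header and returns a namedtuple of
# (audio, sample_rate); `audio` is a float32 tensor of shape
# [samples, channels] with values normalized to [-1.0, 1.0], which is why
# `_mix_background` above can add waveforms directly and clip the sum back
# into [-1.0, 1.0].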
wav_decoder = contrib_audio.decode_wav( 149 | tf.read_file(filename), 150 | desired_channels=1, 151 | # If desired_samples is set, then the audio will be 152 | # cropped or padded with zeroes to the requested length. 153 | desired_samples=desired_samples, 154 | ) 155 | else: 156 | raise ValueError(f"Unsupported file format: {file_format}") 157 | 158 | return wav_decoder.audio 159 | 160 | 161 | def no_augmentation_audio( 162 | filename, 163 | desired_samples, 164 | file_format, 165 | sample_rate, 166 | **kwargs 167 | ): 168 | return _load_wav_file(filename, desired_samples, file_format) 169 | 170 | 171 | def anchored_slice_or_pad( 172 | filename, 173 | desired_samples, 174 | file_format, 175 | sample_rate, 176 | **kwargs, 177 | ): 178 | is_silent = tf.equal(tf.strings.length(filename), 0) 179 | 180 | audio = tf.cond( 181 | is_silent, 182 | true_fn=lambda: _gen_empty_audio(desired_samples), 183 | false_fn=lambda: _load_wav_file(filename, desired_samples, file_format) 184 | ) 185 | 186 | if "background_data" in kwargs: 187 | audio = _mix_background(audio, desired_samples, is_silent=is_silent, **kwargs) 188 | 189 | return audio 190 | 191 | 192 | def anchored_slice_or_pad_with_shift( 193 | filename, 194 | desired_samples, 195 | file_format, 196 | sample_rate, 197 | **kwargs 198 | ): 199 | is_silent = tf.equal(tf.strings.length(filename), 0) 200 | 201 | audio = tf.cond( 202 | is_silent, 203 | true_fn=lambda: _gen_empty_audio(desired_samples), 204 | false_fn=lambda: _load_wav_file(filename, desired_samples, file_format) 205 | ) 206 | audio = _shift_audio(audio, desired_samples, shift_ratio=0.1) 207 | 208 | if "background_data" in kwargs: 209 | audio = _mix_background(audio, desired_samples, is_silent=is_silent, **kwargs) 210 | 211 | return audio 212 | 213 | 214 | def get_audio_augmentation_fn(name): 215 | if name not in _available_audio_augmentation_methods: 216 | raise ValueError(f"Augmentation name [{name}] was not recognized") 217 | return eval(name) 218 | -------------------------------------------------------------------------------- /datasets/data_wrapper_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from abc import abstractmethod 3 | from pathlib import Path 4 | from typing import Tuple 5 | from typing import List 6 | from collections import defaultdict 7 | import random 8 | 9 | import tensorflow as tf 10 | import pandas as pd 11 | from termcolor import colored 12 | 13 | import common.utils as utils 14 | import const 15 | from datasets.augmentation_factory import _available_augmentation_methods 16 | 17 | 18 | class DataWrapperBase(ABC): 19 | def __init__( 20 | self, 21 | args, 22 | dataset_split_name: str, 23 | is_training: bool, 24 | name: str, 25 | ): 26 | self.name = name 27 | self.args = args 28 | self.dataset_split_name = dataset_split_name 29 | self.is_training = is_training 30 | 31 | self.shuffle = self.args.shuffle 32 | 33 | self.log = utils.get_logger(self.name) 34 | self.timer = utils.Timer(self.log) 35 | self.dataset_path = Path(self.args.dataset_path) 36 | self.dataset_path_with_split_name = self.dataset_path / self.dataset_split_name 37 | 38 | with utils.format_text("yellow", ["underline"]) as fmt: 39 | self.log.info(self.name) 40 | self.log.info(fmt(f"dataset_path_with_split_name: {self.dataset_path_with_split_name}")) 41 | self.log.info(fmt(f"dataset_split_name: {self.dataset_split_name}")) 42 | 43 | @property 44 | @abstractmethod 45 | def num_samples(self): 46 | pass 47 | 48 | @property 49 | def 
batch_size(self): 50 | try: 51 | return self._batch_size 52 | except AttributeError: 53 | self._batch_size = 0 54 | return self._batch_size 55 | @batch_size.setter 56 | def batch_size(self, val): 57 | self._batch_size = val 58 | 59 | def setup_dataset( 60 | self, 61 | placeholders: Tuple[tf.placeholder, tf.placeholder], 62 | batch_size: int=None, 63 | ): 64 | self.batch_size = self.args.batch_size if batch_size is None else batch_size 65 | 66 | # single-GPU: prefetch before batch-shuffle-repeat 67 | dataset = tf.data.Dataset.from_tensor_slices(placeholders) 68 | if self.shuffle: 69 | dataset = dataset.shuffle(self.num_samples) # Follow tf.data.Dataset.list_files 70 | dataset = dataset.map(self._parse_function, num_parallel_calls=self.args.num_threads) 71 | 72 | if hasattr(tf.contrib.data, "AUTOTUNE"): 73 | dataset = dataset.prefetch( 74 | buffer_size=tf.contrib.data.AUTOTUNE 75 | ) 76 | else: 77 | dataset = dataset.prefetch( 78 | buffer_size=self.args.prefetch_factor * self.batch_size 79 | ) 80 | 81 | dataset = dataset.batch(self.batch_size) 82 | if self.is_training and self.shuffle: 83 | dataset = dataset.shuffle(buffer_size=self.args.buffer_size, reshuffle_each_iteration=True).repeat(-1) 84 | elif self.is_training and not self.shuffle: 85 | dataset = dataset.repeat(-1) 86 | 87 | self.dataset = dataset 88 | self.iterator = self.dataset.make_initializable_iterator() 89 | self.next_elem = self.iterator.get_next() 90 | 91 | def setup_iterator( 92 | self, 93 | session: tf.Session, 94 | placeholders: Tuple[tf.placeholder, ...], 95 | variables: Tuple[tf.placeholder, ...], 96 | ): 97 | assert len(placeholders) == len(variables), "Lengths of placeholders and variables differ!" 98 | with self.timer(colored("Initialize data iterator.", "yellow")): 99 | session.run(self.iterator.initializer, 100 | feed_dict={placeholder: variable for placeholder, variable in zip(placeholders, variables)}) 101 | 102 | def get_input_and_output_op(self): 103 | return self.next_elem 104 | 105 | def __str__(self): 106 | return f"path: {self.args.dataset_path}, split: {self.args.dataset_split_name}, data size: {self._num_samples}" 107 | 108 | def get_all_dataset_paths(self) -> List[str]: 109 | if self.args.has_sub_dataset: 110 | return sorted([p for p in self.dataset_path_with_split_name.glob("*/") if p.is_dir()]) 111 | else: 112 | return [self.dataset_path_with_split_name] 113 | 114 | def get_label_names( 115 | self, 116 | dataset_paths: List[str], 117 | ): 118 | """Get all label names (either from one or all subdirectories if sub-datasets are defined) 119 | and check consistency of names. 120 | 121 | Args: 122 | dataset_paths: List of paths to datasets. 123 | 124 | Returns: 125 | name_labels: Names of labels. 126 | num_labels: Number of all labels.
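Raises:
    AssertionError: if the sub-dataset directories do not share an
        identical set of label names (checked by the assert below).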
127 | """ 128 | tmp_label_names = [] 129 | for dataset_path in dataset_paths: 130 | dataset_label_names = [] 131 | 132 | if self.args.add_null_class: 133 | dataset_label_names.append(const.NULL_CLASS_LABEL) 134 | 135 | for name in sorted([c.name for c in dataset_path.glob("*")]): 136 | if name[0] != "_": 137 | dataset_label_names.append(name) 138 | tmp_label_names.append(dataset_label_names) 139 | 140 | assert len(set(map(tuple, tmp_label_names))) == 1, "Different labels for each sub-dataset directory" 141 | 142 | name_labels = tmp_label_names[0] 143 | num_labels = len(name_labels) 144 | assert num_labels > 0, f"There're no label directories in {dataset_paths}" 145 | return name_labels, num_labels 146 | 147 | def get_filenames_labels( 148 | self, 149 | dataset_paths: List[str], 150 | ) -> [List[str], List[str]]: 151 | """Get paths to all inputs and their labels. 152 | 153 | Args: 154 | dataset_paths: List of paths to datasets. 155 | 156 | Returns: 157 | filenames: List of paths to all inputs. 158 | labels: List of label indexes with corresponding to filenames. 159 | """ 160 | if self.args.cache_dataset and self.args.cache_dataset_path is None: 161 | cache_directory = self.dataset_path / "_metainfo" 162 | cache_directory.mkdir(parents=True, exist_ok=True) 163 | cache_dataset_path = cache_directory / f"{self.dataset_split_name}.csv" 164 | else: 165 | cache_dataset_path = self.args.cache_dataset_path 166 | 167 | if self.args.cache_dataset and cache_dataset_path.exists(): 168 | dataset_df = pd.read_csv(cache_dataset_path) 169 | 170 | filenames = list(dataset_df["filenames"]) 171 | labels = list(dataset_df["labels"]) 172 | else: 173 | filenames = [] 174 | labels = [] 175 | for label_idx, class_name in enumerate(self.label_names): 176 | for dataset_path in dataset_paths: 177 | for class_filename in dataset_path.joinpath(class_name).glob("*"): 178 | filenames.append(str(class_filename)) 179 | labels.append(label_idx) 180 | 181 | if self.args.cache_dataset: 182 | pd.DataFrame({ 183 | "filenames": filenames, 184 | "labels": labels, 185 | }).to_csv(cache_dataset_path, index=False) 186 | 187 | assert len(filenames) > 0 188 | if self.shuffle: 189 | filenames, labels = self.do_shuffle(filenames, labels) 190 | 191 | return filenames, labels 192 | 193 | def do_shuffle(self, *args): 194 | shuffled_data = list(zip(*args)) 195 | random.shuffle(shuffled_data) 196 | result = tuple(map(lambda l: list(l), zip(*shuffled_data))) 197 | 198 | self.log.info(colored("Data shuffled!", "red")) 199 | return result 200 | 201 | def count_samples( 202 | self, 203 | samples: List, 204 | ) -> int: 205 | """Count number of samples in dataset. 206 | 207 | Args: 208 | samples: List of samples (e.g. filenames, labels). 209 | 210 | Returns: 211 | Number of samples. 212 | """ 213 | num_samples = len(samples) 214 | with utils.format_text("yellow", ["underline"]) as fmt: 215 | self.log.info(fmt(f"number of data: {num_samples}")) 216 | 217 | return num_samples 218 | 219 | def oversampling(self, data, labels): 220 | """Doing oversampling based on labels. 221 | data: list of data. 222 | labels: list of labels. 
223 | """ 224 | assert self.args.oversampling_ratio is not None, ( 225 | "When `--do_oversampling` is set, it also needs a proper value for `--oversampling_ratio`.") 226 | 227 | samples_of_label = defaultdict(list) 228 | for sample, label in zip(data, labels): 229 | samples_of_label[label].append(sample) 230 | 231 | num_samples_of_label = {label: len(lst) for label, lst in samples_of_label.items()} 232 | max_num_samples = max(num_samples_of_label.values()) 233 | min_num_samples = int(max_num_samples * self.args.oversampling_ratio) 234 | 235 | self.log.info(f"Log for oversampling!") 236 | for label, num_samples in sorted(num_samples_of_label.items()): 237 | # for approximation issue, let's put them at least `n` times 238 | n = 5 239 | # ratio = int(max(min_num_samples / num_samples, 1.0) * n / n + 0.5) 240 | ratio = int(max(min_num_samples / num_samples, 1.0) * n + 0.5) 241 | 242 | self.log.info(f"{label}: {num_samples} x {ratio} => {num_samples * ratio}") 243 | 244 | for i in range(ratio - 1): 245 | data.extend(samples_of_label[label]) 246 | labels.extend(label for _ in range(num_samples)) 247 | 248 | return data, labels 249 | 250 | @staticmethod 251 | def add_arguments(parser): 252 | g_common = parser.add_argument_group("(DataWrapperBase) Common Arguments for all data wrapper.") 253 | g_common.add_argument("--dataset_path", required=True, type=str, help="The name of the dataset to load.") 254 | g_common.add_argument("--dataset_split_name", required=True, type=str, nargs="*", 255 | help="The name of the train/test split. Support multiple splits") 256 | g_common.add_argument("--no-has_sub_dataset", dest="has_sub_dataset", action="store_false") 257 | g_common.add_argument("--has_sub_dataset", dest="has_sub_dataset", action="store_true") 258 | g_common.set_defaults(has_sub_dataset=False) 259 | g_common.add_argument("--no-add_null_class", dest="add_null_class", action="store_false", 260 | help="Support null class for idx 0") 261 | g_common.add_argument("--add_null_class", dest="add_null_class", action="store_true") 262 | g_common.set_defaults(add_null_class=True) 263 | 264 | g_common.add_argument("--batch_size", default=32, type=utils.positive_int, 265 | help="The number of examples in batch.") 266 | g_common.add_argument("--no-shuffle", dest="shuffle", action="store_false") 267 | g_common.add_argument("--shuffle", dest="shuffle", action="store_true") 268 | g_common.set_defaults(shuffle=True) 269 | 270 | g_common.add_argument("--cache_dataset", dest="cache_dataset", action="store_true", 271 | help=("If True generates/loads csv file with paths to all inputs. " 272 | "It accelerates loading of large datasets.")) 273 | g_common.add_argument("--no-cache_dataset", dest="cache_dataset", action="store_false") 274 | g_common.set_defaults(cache_dataset=False) 275 | g_common.add_argument("--cache_dataset_path", default=None, type=lambda p: Path(p), 276 | help=("Path to cached csv files containing paths to all inputs. " 277 | "If not given, csv file will be generated in the " 278 | "root data directory. 
This argument is used only if " 279 | "--cache_dataset is used.")) 280 | 281 | g_common.add_argument("--width", type=int, default=-1) 282 | g_common.add_argument("--height", type=int, default=-1) 283 | g_common.add_argument("--augmentation_method", type=str, required=True, 284 | choices=_available_augmentation_methods) 285 | g_common.add_argument("--num_threads", default=8, type=int, 286 | help="We recommend setting this to the number of available CPU cores.") 287 | g_common.add_argument("--buffer_size", default=1000, type=int) 288 | g_common.add_argument("--prefetch_factor", default=100, type=int) 289 | -------------------------------------------------------------------------------- /datasets/preprocessor_factory.py: -------------------------------------------------------------------------------- 1 | from datasets.preprocessors import NoOpPreprocessor 2 | from datasets.preprocessors import LogMelSpectrogramPreprocessor 3 | from datasets.preprocessors import MFCCPreprocessor 4 | 5 | 6 | _available_preprocessors = { 7 | # Audio 8 | "log_mel_spectrogram": LogMelSpectrogramPreprocessor, 9 | "mfcc": MFCCPreprocessor, 10 | # Noop 11 | "no_preprocessing": NoOpPreprocessor, 12 | } 13 | 14 | 15 | def factory(preprocess_method, scope, preprocessed_node_name): 16 | if preprocess_method in _available_preprocessors: 17 | return _available_preprocessors[preprocess_method](scope, preprocessed_node_name) 18 | else: 19 | raise NotImplementedError(f"{preprocess_method}") 20 | -------------------------------------------------------------------------------- /datasets/preprocessors.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from abc import abstractmethod 3 | 4 | from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio 5 | import tensorflow as tf 6 | 7 | import const 8 | 9 | 10 | class PreprocessorBase(ABC): 11 | def __init__(self, scope: str, preprocessed_node_name: str): 12 | self._scope = scope 13 | self._input_node = None 14 | self._preprocessed_node = None 15 | self._preprocessed_node_name = preprocessed_node_name 16 | 17 | @abstractmethod 18 | def preprocess(self, inputs, reuse): 19 | raise NotImplementedError 20 | 21 | @staticmethod 22 | def _uint8_to_float32(inputs): 23 | if inputs.dtype == tf.uint8: 24 | inputs = tf.cast(inputs, tf.float32) 25 | return inputs 26 | 27 | def _assign_input_node(self, inputs): 28 | # We include the preprocessing part in the TFLite float model, so we need an input tensor before preprocessing.
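# In other words, `_input_node` keeps a handle to the raw waveform tensor fed
# at deploy time, while `_make_node_after_preprocessing` below wraps the
# result in tf.identity(..., name=self._preprocessed_node_name), giving the
# post-preprocessing tensor a stable name that the freezing / TFLite
# conversion tooling can reference.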
29 | self._input_node = inputs 30 | 31 | def _make_node_after_preprocessing(self, inputs): 32 | node = tf.identity(inputs, name=self._preprocessed_node_name) 33 | self._preprocessed_node = node 34 | return node 35 | 36 | @property 37 | def input_node(self): 38 | return self._input_node 39 | 40 | @property 41 | def preprocessed_node(self): 42 | return self._preprocessed_node 43 | 44 | 45 | class NoOpPreprocessor(PreprocessorBase): 46 | def preprocess(self, inputs, reuse=False): 47 | self._assign_input_node(inputs) 48 | inputs = self._make_node_after_preprocessing(inputs) 49 | return inputs 50 | 51 | 52 | # For Audio 53 | class AudioPreprocessorBase(PreprocessorBase): 54 | def preprocess(self, inputs, window_size_samples, window_stride_samples, for_deploy, **kwargs): 55 | self._assign_input_node(inputs) 56 | with tf.variable_scope(self._scope): 57 | if for_deploy: 58 | inputs = self._preprocess_for_deploy(inputs, window_size_samples, window_stride_samples, **kwargs) 59 | else: 60 | inputs = self._preprocess(inputs, window_size_samples, window_stride_samples, **kwargs) 61 | inputs = self._make_node_after_preprocessing(inputs) 62 | return inputs 63 | 64 | def _log_mel_spectrogram(self, audio, window_size_samples, window_stride_samples, 65 | magnitude_squared, **kwargs): 66 | # only accept single channels 67 | audio = tf.squeeze(audio, -1) 68 | stfts = tf.contrib.signal.stft(audio, 69 | frame_length=window_size_samples, 70 | frame_step=window_stride_samples) 71 | 72 | # If magnitude_squared = True(power_spectrograms)#, tf.real(stfts * tf.conj(stfts)) 73 | # If magnitude_squared = False(magnitude_spectrograms), tf.abs(stfts) 74 | if magnitude_squared: 75 | spectrograms = tf.real(stfts * tf.conj(stfts)) 76 | else: 77 | spectrograms = tf.abs(stfts) 78 | 79 | num_spectrogram_bins = spectrograms.shape[-1].value 80 | linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix( 81 | kwargs["num_mel_bins"], 82 | num_spectrogram_bins, 83 | kwargs["sample_rate"], 84 | kwargs["lower_edge_hertz"], 85 | kwargs["upper_edge_hertz"], 86 | ) 87 | 88 | mel_spectrograms = tf.tensordot(spectrograms, linear_to_mel_weight_matrix, 1) 89 | mel_spectrograms.set_shape( 90 | spectrograms.shape[:-1].concatenate(linear_to_mel_weight_matrix.shape[-1:]) 91 | ) 92 | 93 | log_offset = 1e-6 94 | log_mel_spectrograms = tf.log(mel_spectrograms + log_offset) 95 | 96 | return log_mel_spectrograms 97 | 98 | def _single_spectrogram(self, audio, window_size_samples, window_stride_samples, magnitude_squared): 99 | # only accept single batch 100 | audio = tf.squeeze(audio, 0) 101 | 102 | spectrogram = contrib_audio.audio_spectrogram( 103 | audio, 104 | window_size=window_size_samples, 105 | stride=window_stride_samples, 106 | magnitude_squared=magnitude_squared 107 | ) 108 | 109 | return spectrogram 110 | 111 | def _single_mfcc(self, audio, window_size_samples, window_stride_samples, magnitude_squared, 112 | **kwargs): 113 | spectrogram = self._single_spectrogram(audio, window_size_samples, window_stride_samples, magnitude_squared) 114 | 115 | mfcc = contrib_audio.mfcc( 116 | spectrogram, 117 | kwargs["sample_rate_const"], 118 | upper_frequency_limit=kwargs["upper_edge_hertz"], 119 | lower_frequency_limit=kwargs["lower_edge_hertz"], 120 | filterbank_channel_count=kwargs["num_mel_bins"], 121 | dct_coefficient_count=kwargs["num_mfccs"], 122 | ) 123 | 124 | return mfcc 125 | 126 | def _get_mel_matrix(self, num_mel_bins, num_spectrogram_bins, sample_rate, 127 | lower_edge_hertz, upper_edge_hertz): 128 | if num_mel_bins == 64 
and num_spectrogram_bins == 257 and sample_rate == 16000 \ 129 | and lower_edge_hertz == 80.0 and upper_edge_hertz == 7600.0: 130 | return tf.constant(const.MEL_WEIGHT_64_257_16000_80_7600, dtype=tf.float32, name="mel_weight_matrix") 131 | elif num_mel_bins == 64 and num_spectrogram_bins == 513 and sample_rate == 16000 \ 132 | and lower_edge_hertz == 80.0 and upper_edge_hertz == 7600.0: 133 | return tf.constant(const.MEL_WEIGHT_64_513_16000_80_7600, dtype=tf.float32, name="mel_weight_matrix") 134 | else: 135 | setting = (num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz, upper_edge_hertz) 136 | raise ValueError(f"Target setting is not defined: {setting}") 137 | 138 | def _single_log_mel_spectrogram(self, audio, window_size_samples, window_stride_samples, 139 | magnitude_squared, **kwargs): 140 | spectrogram = self._single_spectrogram(audio, window_size_samples, window_stride_samples, magnitude_squared) 141 | spectrogram = tf.squeeze(spectrogram, 0) 142 | 143 | num_spectrogram_bins = spectrogram.shape[-1].value 144 | linear_to_mel_weight_matrix = self._get_mel_matrix( 145 | kwargs["num_mel_bins"], 146 | num_spectrogram_bins, 147 | kwargs["sample_rate"], 148 | kwargs["lower_edge_hertz"], 149 | kwargs["upper_edge_hertz"], 150 | ) 151 | 152 | mel_spectrogram = tf.matmul(spectrogram, linear_to_mel_weight_matrix) 153 | 154 | log_offset = 1e-6 155 | log_mel_spectrograms = tf.log(mel_spectrogram + log_offset) 156 | log_mel_spectrograms = tf.expand_dims(log_mel_spectrograms, 0) 157 | 158 | return log_mel_spectrograms 159 | 160 | 161 | class LogMelSpectrogramPreprocessor(AudioPreprocessorBase): 162 | def _preprocess(self, audio, window_size_samples, window_stride_samples, **kwargs): 163 | # When calculate log mel spectogram, set magnitude_squared False 164 | log_mel_spectrograms = self._log_mel_spectrogram(audio, 165 | window_size_samples, 166 | window_stride_samples, 167 | False, 168 | **kwargs) 169 | log_mel_spectrograms = tf.expand_dims(log_mel_spectrograms, axis=-1) 170 | return log_mel_spectrograms 171 | 172 | def _preprocess_for_deploy(self, audio, window_size_samples, window_stride_samples, **kwargs): 173 | log_mel_spectrogram = self._single_log_mel_spectrogram(audio, 174 | window_size_samples, 175 | window_stride_samples, 176 | False, 177 | **kwargs) 178 | log_mel_spectrogram = tf.expand_dims(log_mel_spectrogram, axis=-1) 179 | return log_mel_spectrogram 180 | 181 | 182 | class MFCCPreprocessor(AudioPreprocessorBase): 183 | def _preprocess(self, audio, window_size_samples, window_stride_samples, **kwargs): 184 | # When calculate log mel spectogram, set magnitude_squared True 185 | log_mel_spectrograms = self._log_mel_spectrogram(audio, 186 | window_size_samples, 187 | window_stride_samples, 188 | True, 189 | **kwargs) 190 | 191 | mfccs = tf.contrib.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrograms) 192 | mfccs = mfccs[..., :kwargs["num_mfccs"]] 193 | mfccs = tf.expand_dims(mfccs, axis=-1) 194 | return mfccs 195 | 196 | def _preprocess_for_deploy(self, audio, window_size_samples, window_stride_samples, **kwargs): 197 | mfcc = self._single_mfcc(audio, 198 | window_size_samples, 199 | window_stride_samples, 200 | True, 201 | **kwargs) 202 | mfcc = tf.expand_dims(mfcc, axis=-1) 203 | return mfcc 204 | -------------------------------------------------------------------------------- /evaluate_audio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import tensorflow as tf 4 | 5 | from factory.base import 
TFModel 6 | import factory.audio_nets as audio_nets 7 | from helper.base import Base 8 | from helper.evaluator import Evaluator 9 | from helper.evaluator import SingleLabelAudioEvaluator 10 | from datasets.audio_data_wrapper import AudioDataWrapper 11 | from datasets.audio_data_wrapper import SingleLabelAudioDataWrapper 12 | from datasets.audio_data_wrapper import DataWrapperBase 13 | from metrics.base import MetricManagerBase 14 | from common.tf_utils import ckpt_iterator 15 | import common.utils as utils 16 | import const 17 | 18 | 19 | def main(args): 20 | is_training = False 21 | dataset_name = args.dataset_split_name[0] 22 | session = tf.Session(config=const.TF_SESSION_CONFIG) 23 | 24 | dataset = SingleLabelAudioDataWrapper( 25 | args, 26 | session, 27 | dataset_name, 28 | is_training, 29 | ) 30 | wavs, labels = dataset.get_input_and_output_op() 31 | 32 | model = eval(f"audio_nets.{args.model}")(args, dataset) 33 | model.build(wavs=wavs, labels=labels, is_training=is_training) 34 | 35 | dataset_name = args.dataset_split_name[0] 36 | evaluator = SingleLabelAudioEvaluator( 37 | model, 38 | session, 39 | args, 40 | dataset, 41 | dataset_name, 42 | ) 43 | log = utils.get_logger("EvaluateAudio") 44 | 45 | if args.valid_type == "once": 46 | evaluator.evaluate_once(args.checkpoint_path) 47 | elif args.valid_type == "loop": 48 | log.info(f"Start Loop: watching {evaluator.watch_path}") 49 | 50 | kwargs = { 51 | "min_interval_secs": 0, 52 | "timeout": None, 53 | "timeout_fn": None, 54 | "logger": log, 55 | } 56 | 57 | for ckpt_path in ckpt_iterator(evaluator.watch_path, **kwargs): 58 | log.info(f"[watch] {ckpt_path}") 59 | 60 | evaluator.evaluate_once(ckpt_path) 61 | else: 62 | raise ValueError(f"Undefined valid_type: {args.valid_type}") 63 | 64 | 65 | if __name__ == "__main__": 66 | parser = argparse.ArgumentParser(description=__doc__) 67 | subparsers = parser.add_subparsers(title="Model", description="") 68 | 69 | Base.add_arguments(parser) 70 | Evaluator.add_arguments(parser) 71 | DataWrapperBase.add_arguments(parser) 72 | AudioDataWrapper.add_arguments(parser) 73 | TFModel.add_arguments(parser) 74 | audio_nets.AudioNetModel.add_arguments(parser) 75 | MetricManagerBase.add_arguments(parser) 76 | 77 | for class_name in audio_nets._available_nets: 78 | subparser = subparsers.add_parser(class_name) 79 | subparser.add_argument("--model", default=class_name, type=str, help="DO NOT FIX ME") 80 | add_audio_arguments = eval("audio_nets.{}.add_arguments".format(class_name)) 81 | add_audio_arguments(subparser) 82 | 83 | args = parser.parse_args() 84 | 85 | log = utils.get_logger("AudioNetEvaluate") 86 | log.info(args) 87 | main(args) 88 | -------------------------------------------------------------------------------- /execute_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | set -eux 3 | 4 | # setup your train and data paths 5 | export ROOT_TRAIN_DIR= 6 | export DATA_SPEECH_COMMANDS_V1_DIR=/data/google_audio/data_speech_commands_v0.01_newsplit 7 | export DATA_SPEECH_COMMANDS_V2_DIR=/data/google_audio/data_speech_commands_v0.02_newsplit 8 | 9 | ## DATASET_SPLIT_NAME 10 | # train 11 | # valid 12 | # test 13 | dataset_split_name= 14 | 15 | ## AVAILABLE MODELS 16 | # KWSModel 17 | # Res8Model 18 | # Res8NarrowModel 19 | # Res15Model 20 | # Res15NarrowModel 21 | # DSCNNSModel 22 | # DSCNNMModel 23 | # DSCNNLModel 24 | # TCResNet_8Model 25 | # TCResNet_14Model 26 | # TCResNet_2D8Model 27 | model= 28 | 29 | ## AVAILABLE AUDIO PREPROCESS METHODS 30 
| # log_mel_spectrogram 31 | # mfcc* 32 | audio_preprocess_setting= 33 | 34 | ## AVAILABLE WINDOW SETTINGS 35 | # 3010 36 | # 4020 37 | window_setting= 38 | 39 | ## WEIGHT DECAY 40 | # 0.0 41 | # 0.001 42 | weight_decay= 43 | 44 | ## AVAILABLE OPTIMIZER SETTINGS 45 | # adam 46 | # mom 47 | opt_setting= 48 | 49 | ## AVAILABLE LEARNING RATE SETTINGS 50 | # s1 51 | # l1 52 | # l2 53 | # l3 54 | lr_setting= 55 | 56 | ## DEPTH MULTIPLIER 57 | width_multiplier= 58 | 59 | ## AVAILABLE DATASETS 60 | # v1 61 | # v2 62 | dataset_settings= 63 | 64 | 65 | ./scripts/script_google_audio.sh \ 66 | ${dataset_split_name} \ 67 | ${model} \ 68 | ${audio_preprocess_setting} \ 69 | ${window_setting} \ 70 | ${weight_decay} \ 71 | ${opt_setting} \ 72 | ${lr_setting} \ 73 | ${width_multiplier} \ 74 | ${dataset_settings} 75 | -------------------------------------------------------------------------------- /factory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/factory/__init__.py -------------------------------------------------------------------------------- /factory/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from abc import abstractmethod 3 | 4 | import tensorflow as tf 5 | import tensorflow.contrib.slim as slim 6 | 7 | import common.tf_utils as tf_utils 8 | from datasets.preprocessor_factory import _available_preprocessors 9 | 10 | 11 | class TFModel(ABC): 12 | @staticmethod 13 | def add_arguments(parser): 14 | g_cnn = parser.add_argument_group("(CNNModel) Arguments") 15 | g_cnn.add_argument("--num_classes", type=int, default=None) 16 | g_cnn.add_argument("--checkpoint_path", default="", type=str) 17 | 18 | g_cnn.add_argument("--input_batch_size", type=int, default=1) 19 | g_cnn.add_argument("--output_name", type=str, required=True) 20 | 21 | g_cnn.add_argument("--preprocess_method", required=True, type=str, 22 | choices=list(_available_preprocessors.keys())) 23 | 24 | g_cnn.add_argument("--no-ignore_missing_vars", dest="ignore_missing_vars", action="store_false") 25 | g_cnn.add_argument("--ignore_missing_vars", dest="ignore_missing_vars", action="store_true") 26 | g_cnn.set_defaults(ignore_missing_vars=False) 27 | 28 | g_cnn.add_argument("--checkpoint_exclude_scopes", default="", type=str, 29 | help=("Prefix scopes that shoule be EXLUDED for restoring variables " 30 | "(comma separated)")) 31 | 32 | g_cnn.add_argument("--checkpoint_include_scopes", default="", type=str, 33 | help=("Prefix scopes that should be INCLUDED for restoring variables " 34 | "(comma separated)")) 35 | g_cnn.add_argument("--weight_decay", default=1e-4, type=float) 36 | 37 | @abstractmethod 38 | def build_deployable_model(self, *args, **kwargs): 39 | pass 40 | 41 | @abstractmethod 42 | def preprocess_input(self): 43 | pass 44 | 45 | @abstractmethod 46 | def build_output(self): 47 | pass 48 | 49 | @property 50 | @abstractmethod 51 | def audio(self): 52 | pass 53 | 54 | @property 55 | @abstractmethod 56 | def audio_original(self): 57 | pass 58 | 59 | @property 60 | @abstractmethod 61 | def total_loss(self): 62 | pass 63 | 64 | @property 65 | @abstractmethod 66 | def model_loss(self): 67 | pass 68 | -------------------------------------------------------------------------------- /figure/main_figure.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/figure/main_figure.png -------------------------------------------------------------------------------- /freeze.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | 4 | import tensorflow as tf 5 | from tensorflow.python.framework import graph_util 6 | 7 | import common.model_loader as model_loader 8 | from factory.base import TFModel 9 | import factory.audio_nets as audio_nets 10 | from factory.audio_nets import AudioNetModel 11 | from factory.audio_nets import * 12 | from helper.base import Base 13 | import const 14 | 15 | 16 | def freeze(args): 17 | graph = tf.Graph() 18 | with graph.as_default(): 19 | session = tf.Session(config=const.TF_SESSION_CONFIG) 20 | 21 | model = eval(args.model)(args) 22 | input_tensors, output_tensor = model.build_deployable_model(include_preprocess=False) 23 | 24 | ckpt_loader = model_loader.Ckpt( 25 | session=session, 26 | include_scopes=args.checkpoint_include_scopes, 27 | exclude_scopes=args.checkpoint_exclude_scopes, 28 | ignore_missing_vars=args.ignore_missing_vars, 29 | use_ema=args.use_ema, 30 | ema_decay=args.ema_decay, 31 | ) 32 | session.run(tf.global_variables_initializer()) 33 | session.run(tf.local_variables_initializer()) 34 | 35 | ckpt_loader.load(args.checkpoint_path) 36 | 37 | frozen_graph_def = graph_util.convert_variables_to_constants( 38 | session, 39 | session.graph_def, 40 | [output_tensor.op.name], 41 | ) 42 | 43 | checkpoint_path = Path(args.checkpoint_path) 44 | output_raw_pb_path = checkpoint_path.parent / f"{checkpoint_path.name}.pb" 45 | tf.train.write_graph(frozen_graph_def, 46 | str(output_raw_pb_path.parent), 47 | output_raw_pb_path.name, 48 | as_text=False) 49 | print(f"Save freezed pb : {output_raw_pb_path}") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser(description=__doc__) 54 | subparsers = parser.add_subparsers(title="Model", description="") 55 | 56 | TFModel.add_arguments(parser) 57 | AudioNetModel.add_arguments(parser) 58 | 59 | for class_name in audio_nets._available_nets: 60 | subparser = subparsers.add_parser(class_name) 61 | subparser.add_argument("--model", default=class_name, type=str, help="DO NOT FIX ME") 62 | add_audio_net_arguments = eval(f"audio_nets.{class_name}.add_arguments") 63 | add_audio_net_arguments(subparser) 64 | 65 | Base.add_arguments(parser) 66 | 67 | parser.add_argument("--width", required=True, type=int) 68 | parser.add_argument("--height", required=True, type=int) 69 | parser.add_argument("--channels", required=True, type=int) 70 | 71 | parser.add_argument("--sample_rate", type=int, default=16000, help="Expected sample rate of the wavs",) 72 | parser.add_argument("--clip_duration_ms", type=int) 73 | parser.add_argument("--window_size_ms", type=float, default=30.0, help="How long each spectrogram timeslice is.",) 74 | parser.add_argument("--window_stride_ms", type=float, default=30.0, 75 | help="How far to move in time between spectogram timeslices.",) 76 | parser.add_argument("--num_mel_bins", type=int, default=64) 77 | parser.add_argument("--num_mfccs", type=int, default=64) 78 | parser.add_argument("--lower_edge_hertz", type=float, default=80.0) 79 | parser.add_argument("--upper_edge_hertz", type=float, default=7600.0) 80 | 81 | args = parser.parse_args() 82 | 83 | freeze(args) 84 | -------------------------------------------------------------------------------- 
/helper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/helper/__init__.py -------------------------------------------------------------------------------- /helper/base.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | from abc import ABC 6 | from abc import abstractmethod 7 | from termcolor import colored 8 | 9 | from common.utils import Timer 10 | from common.utils import get_logger 11 | from common.utils import format_log 12 | from common.utils import format_text 13 | from metrics.summaries import BaseSummaries 14 | 15 | 16 | class Base(ABC): 17 | def __init__(self): 18 | self.log = get_logger("Base") 19 | self.timer = Timer(self.log) 20 | 21 | def get_feed_dict(self, is_training: bool=False): 22 | feed_dict = dict() 23 | return feed_dict 24 | 25 | def check_batch_size( 26 | self, 27 | batch_size: int, 28 | terminate: bool=False, 29 | ): 30 | if batch_size != self.args.batch_size: 31 | self.log.info(colored(f"Batch size: required {self.args.batch_size}, obtained {batch_size}", "red")) 32 | if terminate: 33 | raise tf.errors.OutOfRangeError(None, None, "Finished looping dataset.") 34 | 35 | def build_iters_from_batch_size(self, num_samples, batch_size): 36 | iters = self.dataset.num_samples // self.args.batch_size 37 | num_ignored_samples = self.dataset.num_samples % self.args.batch_size 38 | if num_ignored_samples > 0: 39 | with format_text("red", attrs=["bold"]) as fmt: 40 | msg = ( 41 | f"Number of samples cannot be divided by batch_size, " 42 | f"so it ignores some data examples in evaluation: " 43 | f"{self.dataset.num_samples} % {self.args.batch_size} = {num_ignored_samples}" 44 | ) 45 | self.log.warning(fmt(msg)) 46 | return iters 47 | 48 | @abstractmethod 49 | def build_evaluation_fetch_ops(self, do_eval): 50 | raise NotImplementedError 51 | 52 | def run_inference( 53 | self, 54 | global_step: int, 55 | iters: int=None, 56 | is_training: bool=False, 57 | do_eval: bool=True, 58 | ): 59 | """ 60 | Return: Dict[metric_key] -> np.array 61 | array is stacked values for all batches 62 | """ 63 | feed_dict = self.get_feed_dict(is_training=is_training) 64 | 65 | is_first_batch = True 66 | 67 | if iters is None: 68 | iters = self.build_iters_from_batch_size(self.dataset.num_samples, self.args.batch_size) 69 | 70 | # Get summary ops which should be evaluated by session.run 71 | # For example, segmentation task has several loss(GRAD/MAD/MSE) metrics 72 | # And these losses are now calculated by TensorFlow(not numpy) 73 | merged_tensor_type_summaries = self.metric_manager.summary.get_merged_summaries( 74 | collection_key_suffixes=[BaseSummaries.KEY_TYPES.DEFAULT], 75 | is_tensor_summary=True 76 | ) 77 | 78 | fetch_ops = self.build_evaluation_fetch_ops(do_eval) 79 | 80 | aggregator = {key: list() for key in fetch_ops} 81 | aggregator.update({ 82 | "batch_infer_time": list(), 83 | "unit_infer_time": list(), 84 | }) 85 | 86 | for i in range(iters): 87 | try: 88 | st = time.time() 89 | 90 | is_running_summary = do_eval and is_first_batch and merged_tensor_type_summaries is not None 91 | if is_running_summary: 92 | fetch_ops_with_summary = {"summary": merged_tensor_type_summaries} 93 | fetch_ops_with_summary.update(fetch_ops) 94 | 95 | fetch_vals = self.session.run(fetch_ops_with_summary, feed_dict=feed_dict) 96 | 97 | # To avoid duplicated code 
of session.run, we evaluate merged_sum 98 | # Because we run multiple batches within single global_step, 99 | # merged_summaries can have duplicated values. 100 | # So we write only when the session.run is first 101 | self.metric_manager.write_tensor_summaries(global_step, fetch_vals["summary"]) 102 | is_first_batch = False 103 | else: 104 | fetch_vals = self.session.run(fetch_ops, feed_dict=feed_dict) 105 | 106 | batch_infer_time = (time.time() - st) * 1000 # use milliseconds 107 | 108 | # aggregate 109 | for key, fetch_val in fetch_vals.items(): 110 | if key in aggregator: 111 | aggregator[key].append(fetch_val) 112 | 113 | # add inference time 114 | aggregator["batch_infer_time"].append(batch_infer_time) 115 | aggregator["unit_infer_time"].append(batch_infer_time / self.args.batch_size) 116 | 117 | except tf.errors.OutOfRangeError: 118 | format_log(self.log.info, "yellow")(f"Reach end of the dataset.") 119 | break 120 | except tf.errors.InvalidArgumentError as e: 121 | format_log(self.log.error, "red")(f"Invalid instance is detected: {e}") 122 | continue 123 | 124 | aggregator = {k: np.vstack(v) for k, v in aggregator.items()} 125 | return aggregator 126 | 127 | def run_evaluation( 128 | self, 129 | global_step: int, 130 | iters: int=None, 131 | is_training: bool=False, 132 | ): 133 | eval_dict = self.run_inference(global_step, iters, is_training, do_eval=True) 134 | 135 | non_tensor_data = self.build_non_tensor_data_from_eval_dict(eval_dict, step=global_step) 136 | 137 | self.metric_manager.evaluate_and_aggregate_metrics(step=global_step, 138 | non_tensor_data=non_tensor_data, 139 | eval_dict=eval_dict) 140 | 141 | eval_metric_dict = self.metric_manager.get_evaluation_result(step=global_step) 142 | 143 | return eval_metric_dict 144 | 145 | @staticmethod 146 | def add_arguments(parser): 147 | g_base = parser.add_argument_group("Base") 148 | g_base.add_argument("--no-use_ema", dest="use_ema", action="store_false") 149 | g_base.add_argument("--use_ema", dest="use_ema", action="store_true", 150 | help="Exponential Moving Average. 
It may take more memory.") 151 | g_base.set_defaults(use_ema=False) 152 | g_base.add_argument("--ema_decay", default=0.999, type=float, 153 | help=("Exponential Moving Average decay.\n" 154 | "Reasonable values for decay are close to 1.0, typically " 155 | "in the multiple-nines range: 0.999, 0.9999")) 156 | g_base.add_argument("--evaluation_iterations", type=int, default=None) 157 | 158 | 159 | class AudioBase(Base): 160 | def build_evaluation_fetch_ops(self, do_eval): 161 | if do_eval: 162 | fetch_ops = { 163 | "labels_onehot": self.model.labels, 164 | "predictions_onehot": self.model.outputs, 165 | "total_loss": self.model.total_loss, 166 | } 167 | fetch_ops.update(self.metric_tf_op) 168 | else: 169 | fetch_ops = { 170 | "predictions_onehot": self.model.outputs, 171 | } 172 | 173 | return fetch_ops 174 | 175 | def build_basic_loss_ops(self): 176 | losses = { 177 | "total_loss": self.model.total_loss, 178 | "model_loss": self.model.model_loss, 179 | } 180 | losses.update(self.model.endpoints_loss) 181 | 182 | return losses 183 | -------------------------------------------------------------------------------- /helper/evaluator.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | from pathlib import Path 4 | from abc import abstractmethod 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | from tqdm import tqdm 9 | 10 | import common.tf_utils as tf_utils 11 | import metrics.manager as metric_manager 12 | from common.model_loader import Ckpt 13 | from common.utils import format_text 14 | from common.utils import get_logger 15 | from helper.base import AudioBase 16 | from metrics.summaries import BaseSummaries 17 | from metrics.summaries import Summaries 18 | 19 | 20 | class Evaluator(object): 21 | def __init__(self, model, session, args, dataset, dataset_name, name): 22 | self.log = get_logger(name) 23 | 24 | self.model = model 25 | self.session = session 26 | self.args = args 27 | self.dataset = dataset 28 | self.dataset_name = dataset_name 29 | 30 | if Path(self.args.checkpoint_path).is_dir(): 31 | latest_checkpoint = tf.train.latest_checkpoint(self.args.checkpoint_path) 32 | if latest_checkpoint is not None: 33 | self.args.checkpoint_path = latest_checkpoint 34 | self.log.info(f"Get latest checkpoint and update to it: {self.args.checkpoint_path}") 35 | 36 | self.watch_path = self._build_watch_path() 37 | 38 | self.session.run(tf.global_variables_initializer()) 39 | self.session.run(tf.local_variables_initializer()) 40 | 41 | self.ckpt_loader = Ckpt( 42 | session=session, 43 | include_scopes=args.checkpoint_include_scopes, 44 | exclude_scopes=args.checkpoint_exclude_scopes, 45 | ignore_missing_vars=args.ignore_missing_vars, 46 | use_ema=self.args.use_ema, 47 | ema_decay=self.args.ema_decay, 48 | ) 49 | 50 | @abstractmethod 51 | def setup_metric_manager(self): 52 | raise NotImplementedError 53 | 54 | @abstractmethod 55 | def setup_metric_ops(self): 56 | raise NotImplementedError 57 | 58 | @abstractmethod 59 | def build_non_tensor_data_from_eval_dict(self, eval_dict, **kwargs): 60 | raise NotImplementedError 61 | 62 | @abstractmethod 63 | def setup_dataset_iterator(self): 64 | raise NotImplementedError 65 | 66 | def _build_watch_path(self): 67 | if Path(self.args.checkpoint_path).is_dir(): 68 | return Path(self.args.checkpoint_path) 69 | else: 70 | return Path(self.args.checkpoint_path).parent 71 | 72 | def build_evaluation_step(self, checkpoint_path): 73 | if "-" in checkpoint_path and 
checkpoint_path.split("-")[-1].isdigit(): 74 | return int(checkpoint_path.split("-")[-1]) 75 | else: 76 | return 0 77 | 78 | def build_checkpoint_paths(self, checkpoint_path): 79 | checkpoint_glob = Path(checkpoint_path + "*") 80 | checkpoint_path = Path(checkpoint_path) 81 | 82 | return checkpoint_glob, checkpoint_path 83 | 84 | def build_miscellaneous_path(self, name): 85 | target_dir = self.watch_path / "miscellaneous" / self.dataset_name / name 86 | 87 | if not target_dir.exists(): 88 | target_dir.mkdir(parents=True) 89 | 90 | return target_dir 91 | 92 | def setup_best_keeper(self): 93 | metric_with_modes = self.metric_manager.get_best_keep_metric_with_modes() 94 | self.log.debug(metric_with_modes) 95 | self.best_keeper = tf_utils.BestKeeper( 96 | metric_with_modes, 97 | self.dataset_name, 98 | self.watch_path, 99 | self.log, 100 | ) 101 | 102 | def evaluate_once(self, checkpoint_path): 103 | self.log.info("Evaluation started") 104 | self.setup_dataset_iterator() 105 | self.ckpt_loader.load(checkpoint_path) 106 | 107 | step = self.build_evaluation_step(checkpoint_path) 108 | checkpoint_glob, checkpoint_path = self.build_checkpoint_paths(checkpoint_path) 109 | self.session.run(tf.local_variables_initializer()) 110 | 111 | eval_metric_dict = self.run_evaluation(step, is_training=False) 112 | best_keep_metric_dict = self.metric_manager.filter_best_keep_metric(eval_metric_dict) 113 | is_keep, metrics_keep = self.best_keeper.monitor(self.dataset_name, best_keep_metric_dict) 114 | 115 | if self.args.save_best_keeper: 116 | meta_info = { 117 | "step": step, 118 | "model_size": self.model.total_params, 119 | } 120 | self.best_keeper.remove_old_best(self.dataset_name, metrics_keep) 121 | self.best_keeper.save_best(self.dataset_name, metrics_keep, checkpoint_glob) 122 | self.best_keeper.remove_temp_dir() 123 | self.best_keeper.save_scores(self.dataset_name, metrics_keep, best_keep_metric_dict, meta_info) 124 | 125 | self.metric_manager.write_evaluation_summaries(step=step, 126 | collection_keys=[BaseSummaries.KEY_TYPES.DEFAULT]) 127 | self.metric_manager.log_metrics(step=step) 128 | 129 | self.log.info("Evaluation finished") 130 | 131 | if step >= self.args.max_step_from_restore: 132 | self.log.info("Evaluation stopped") 133 | sys.exit() 134 | 135 | def build_train_directory(self): 136 | if Path(self.args.checkpoint_path).is_dir(): 137 | return str(self.args.checkpoint_path) 138 | else: 139 | return str(Path(self.args.checkpoint_path).parent) 140 | 141 | @staticmethod 142 | def add_arguments(parser): 143 | g = parser.add_argument_group("(Evaluator) arguments") 144 | 145 | g.add_argument("--valid_type", default="loop", type=str, choices=["loop", "once"]) 146 | g.add_argument("--max_outputs", default=5, type=int) 147 | 148 | g.add_argument("--maximum_num_labels_for_metric", default=10, type=int, 149 | help="Maximum number of labels for using class-specific metrics(e.g. 
precision/recall/f1score)")
150 |
151 |         g.add_argument("--no-save_best_keeper", dest="save_best_keeper", action="store_false")
152 |         g.add_argument("--save_best_keeper", dest="save_best_keeper", action="store_true")
153 |         g.set_defaults(save_best_keeper=True)
154 |
155 |         g.add_argument("--no-flatten_output", dest="flatten_output", action="store_false")
156 |         g.add_argument("--flatten_output", dest="flatten_output", action="store_true")
157 |         g.set_defaults(flatten_output=False)
158 |
159 |         g.add_argument("--max_step_from_restore", default=1e20, type=int)
160 |
161 |
162 | class SingleLabelAudioEvaluator(Evaluator, AudioBase):
163 |
164 |     def __init__(self, model, session, args, dataset, dataset_name):
165 |         super().__init__(model, session, args, dataset, dataset_name, "SingleLabelAudioEvaluator")
166 |         self.setup_dataset_related_attr()
167 |         self.setup_metric_manager()
168 |         self.setup_metric_ops()
169 |         self.setup_best_keeper()
170 |
171 |     def setup_dataset_related_attr(self):
172 |         assert len(self.dataset.label_names) == self.args.num_classes
173 |         self.use_class_metrics = len(self.dataset.label_names) < self.args.maximum_num_labels_for_metric
174 |
175 |     def setup_metric_manager(self):
176 |         self.metric_manager = metric_manager.AudioMetricManager(
177 |             is_training=False,
178 |             use_class_metrics=self.use_class_metrics,
179 |             exclude_metric_names=self.args.exclude_metric_names,
180 |             summary=Summaries(
181 |                 session=self.session,
182 |                 train_dir=self.build_train_directory(),
183 |                 is_training=False,
184 |                 base_name=self.dataset.dataset_split_name,
185 |                 max_summary_outputs=self.args.max_summary_outputs,
186 |             ),
187 |         )
188 |
189 |     def setup_metric_ops(self):
190 |         losses = self.build_basic_loss_ops()
191 |         self.metric_tf_op = self.metric_manager.build_metric_ops({
192 |             "dataset_split_name": self.dataset_name,
193 |             "label_names": self.dataset.label_names,
194 |             "losses": losses,
195 |             "learning_rate": None,
196 |             "wavs": self.model.audio_original,
197 |         })
198 |
199 |     def build_non_tensor_data_from_eval_dict(self, eval_dict, **kwargs):
200 |         return {
201 |             "dataset_split_name": self.dataset.dataset_split_name,
202 |             "label_names": self.dataset.label_names,
203 |             "predictions_onehot": eval_dict["predictions_onehot"],
204 |             "labels_onehot": eval_dict["labels_onehot"],
205 |         }
206 |
207 |     def setup_dataset_iterator(self):
208 |         self.dataset.setup_iterator(
209 |             self.session,
210 |             self.dataset.placeholders,
211 |             self.dataset.data,
212 |         )
213 |
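`Evaluator.build_evaluation_step` above leans on the TensorFlow checkpoint naming convention (`model.ckpt-<global_step>`). A minimal standalone sketch of that parsing rule (hypothetical helper name, not part of the repository):

def parse_checkpoint_step(checkpoint_path: str) -> int:
    # Checkpoints are named like "model.ckpt-12500"; the global step is the
    # digit suffix after the last "-". Anything else maps to step 0.
    suffix = checkpoint_path.split("-")[-1]
    if "-" in checkpoint_path and suffix.isdigit():
        return int(suffix)
    return 0

assert parse_checkpoint_step("work/v1/model.ckpt-12500") == 12500
assert parse_checkpoint_step("work/v1/model.ckpt") == 0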
-------------------------------------------------------------------------------- /metrics/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/metrics/__init__.py
-------------------------------------------------------------------------------- /metrics/base.py: --------------------------------------------------------------------------------
1 | from abc import ABC
2 | from collections import defaultdict
3 |
4 | from common.utils import get_logger
5 | from common.utils import format_text
6 | from common.utils import timer
7 |
8 |
9 | class DataStructure(ABC):
10 |     """
11 |     Defines an inner data structure.
12 |     Subclasses should define `_keys`.
13 |     """
14 |     _keys = None
15 |
16 |     def __init__(self, data):
17 |         keys = self.__class__.get_keys()
18 |         data_keys = data.keys()
19 |
20 |         if set(keys) != set(data_keys):
21 |             raise ValueError(f"Keys defined in `_keys ({list(keys)})`"
22 |                              f" should appear in "
23 |                              f"`data ({list(data_keys)})`")
24 |         for k in keys:
25 |             setattr(self, k, data[k])
26 |
27 |     def __str__(self):
28 |         return f"<{self.__class__.__name__}> {self.to_dict()}"
29 |
30 |     def __repr__(self):
31 |         return str(self)
32 |
33 |     def to_dict(self):
34 |         return {k: getattr(self, k) for k in self._keys}
35 |
36 |     @classmethod
37 |     def get_keys(cls):
38 |         return cls._keys
39 |
40 |
41 | class MetricAggregator:
42 |     def __init__(self):
43 |         self.step = None
44 |         self.metrics_with_result = None
45 |
46 |         self.init(-1)
47 |
48 |     def init(self, step):
49 |         self.step = step
50 |         self.metrics_with_result = dict()
51 |
52 |     def aggregate(self, metric, metric_result):
53 |         assert metric not in self.metrics_with_result
54 |         self.metrics_with_result[metric] = metric_result
55 |
56 |     def iterate_metrics(self):
57 |         for metric, metric_result in self.metrics_with_result.items():
58 |             yield metric, metric_result
59 |
60 |     def iterate_all(self):
61 |         for metric, metric_result in self.iterate_metrics():
62 |             for metric_key, value in metric_result.items():
63 |                 yield metric, metric_key, value
64 |
65 |     def get_collection_summary_dict(self):
66 |         # for_summary: Dict[str, Dict[MetricOp, List[Tuple(metric_key, tensor_op)]]]
67 |         # collection_key -> metric -> List(summary_key, value)
68 |         for_summary = defaultdict(lambda: defaultdict(list))
69 |         for metric, metric_result in self.metrics_with_result.items():
70 |             if metric.is_for_summary:
71 |                 for metric_key, value in metric_result.items():
72 |                     for_summary[metric.summary_collection_key][metric].append((metric_key, value))
73 |
74 |         return for_summary
75 |
76 |     def get_tensor_metrics(self):
77 |         """
78 |         Get the metrics that should be fetched in a session run.
79 |         """
80 |         tensor_metrics = dict()
81 |         for metric, metric_result in self.metrics_with_result.items():
82 |             if metric.is_tensor_metric:
83 |                 for metric_key, value in metric_result.items():
84 |                     tensor_metrics[metric_key] = value
85 |
86 |         return tensor_metrics
87 |
88 |     def get_logs(self):
89 |         logs = dict()
90 |         for metric, metric_result in self.metrics_with_result.items():
91 |             if metric.is_for_log:
92 |                 for metric_key, value in metric_result.items():
93 |                     if isinstance(value, str):
94 |                         msg = f"> {metric_key}\n{value}"
95 |                     else:
96 |                         msg = f"> {metric_key} : {value}"
97 |                     logs[metric_key] = msg
98 |
99 |         return logs
100 |
101 |
102 | class MetricManagerBase(ABC):
103 |     _metric_input_data_parser = None
104 |
105 |     def __init__(self, exclude_metric_names, summary):
106 |         self.log = get_logger("Metrics")
107 |         self.build_op_aggregator = MetricAggregator()
108 |         self.eval_metric_aggregator = MetricAggregator()
109 |
110 |         self.summary = summary
111 |         self.exclude_metric_names = exclude_metric_names
112 |
113 |         self.metric_ops = []
114 |
115 |     def register_metric(self, metric):
116 |         # skip the metric if it is listed in `exclude_metric_names`
117 |         if metric.__class__.__name__ in self.exclude_metric_names:
118 |             self.log.info(f"{metric.__class__.__name__} is excluded by user setting.")
119 |             return
120 |
121 |         # assert that this metric can be handled by the configured parser
122 |         assert str(self._metric_input_data_parser) in map(lambda c: str(c), metric.valid_input_data_parsers), \
123 |             f"{metric.__class__.__name__} cannot be parsed by {self._metric_input_data_parser}"
124 |
125 |         # register the metric
126 |         self.metric_ops.append(metric)
127 |         self.log.info(f"{metric.__class__.__name__} is added.")
128 |
129 |     def register_metrics(self, metrics: list):
130 |         for metric in metrics:
131 |             self.register_metric(metric)
132 |
133 |     def build_metric_ops(self, data):
134 |         """
135 |         Define tensor metric operations
136 |         1. call `build_op` of metrics, i.e. add operations to the graph
137 |         2. register summaries
138 |
139 |         Return: Dict[str, Tensor]
140 |             metric_key -> metric_op
141 |         """
142 |         output_build_data = self._metric_input_data_parser.parse_build_data(data)
143 |
144 |         # get metric tf ops
145 |         for metric in self.metric_ops:
146 |             try:
147 |                 metric_build_ops = metric.build_op(output_build_data)
148 |             except TypeError as e:
149 |                 raise TypeError(f"[{metric}]: {e}")
150 |             self.build_op_aggregator.aggregate(metric, metric_build_ops)
151 |
152 |         # if a value is not None, the metric is defined as a tensor op
153 |         metric_tf_ops = self.build_op_aggregator.get_tensor_metrics()
154 |
155 |         # register summary
156 |         collection_summary_dict = self.build_op_aggregator.get_collection_summary_dict()
157 |         self.summary.register_summaries(collection_summary_dict)
158 |         self.summary.setup_merged_summaries()
159 |
160 |         return metric_tf_ops
161 |
162 |
163 |     def evaluate_and_aggregate_metrics(self, non_tensor_data, eval_dict, step):
164 |         """
165 |         Run evaluation of non-tensor metrics and aggregate the results
166 |         Args:
167 |             non_tensor_data: data passed from the trainer / evaluator / ...
168 |         """
169 |         non_tensor_data = self._metric_input_data_parser.parse_non_tensor_data(non_tensor_data)
170 |
171 |         # aggregate metrics
172 |         self.eval_metric_aggregator.init(step)
173 |
174 |         # evaluate all metrics
175 |         for metric, metric_key_op_dict in self.build_op_aggregator.iterate_metrics():
176 |             if metric.is_tensor_metric:
177 |                 with timer(f"{metric}.expectation_of"):
178 |                     # already aggregated - tensor ops
179 |                     metric_result = dict()
180 |                     for metric_key in metric_key_op_dict:
181 |                         if metric_key in eval_dict:
182 |                             exp_value = metric.expectation_of(eval_dict[metric_key])
183 |                             metric_result[metric_key] = exp_value
184 |             else:
185 |                 with timer(f"{metric}.evaluate"):
186 |                     # need calculation - non tensor ops
187 |                     metric_result = metric.evaluate(non_tensor_data)
188 |
189 |             self.eval_metric_aggregator.aggregate(metric, metric_result)
190 |
191 |     def write_tensor_summaries(self, step, summary_value):
192 |         self.summary.write(summary_value, step)
193 |
194 |     def write_evaluation_summaries(self, step, collection_keys):
195 |         assert step == self.eval_metric_aggregator.step, \
196 |             (f"step: {step} is different from aggregator's step: {self.eval_metric_aggregator.step}. "
197 |              f"`evaluate` function should be called before calling this function")
198 |
199 |         collection_summary_dict = self.eval_metric_aggregator.get_collection_summary_dict()
200 |         self.summary.write_evaluation_summaries(step=step,
201 |                                                 collection_keys=collection_keys,
202 |                                                 collection_summary_dict=collection_summary_dict)
203 |
204 |     def log_metrics(self, step):
205 |         """
206 |         Log metrics that have been evaluated.
207 |         """
208 |         assert step == self.eval_metric_aggregator.step, \
209 |             (f"step: {step} is different from aggregator's step: {self.eval_metric_aggregator.step}. "
210 |              f"`evaluate` function should be called before calling this function")
211 |
212 |         log_dicts = dict()
213 |         log_dicts.update(self.eval_metric_aggregator.get_logs())
214 |
215 |         with format_text("green", ["bold"]) as fmt:
216 |             for metric_key, log_str in log_dicts.items():
217 |                 self.log.info(fmt(log_str))
218 |
219 |     def get_evaluation_result(self, step):
220 |         """
221 |         Return evaluation result regardless of metric type.
222 |         """
223 |         assert step == self.eval_metric_aggregator.step, \
224 |             (f"step: {step} is different from aggregator's step: {self.eval_metric_aggregator.step}. "
225 |              f"`evaluate` function should be called before calling this function")
226 |
227 |         eval_dict = dict()
228 |         for metric, metric_key, value in self.eval_metric_aggregator.iterate_all():
229 |             eval_dict[metric_key] = value
230 |
231 |         return eval_dict
232 |
233 |     def get_best_keep_metric_with_modes(self):
234 |         metric_min_max_dict = dict()
235 |         for metric, metric_key, _ in self.build_op_aggregator.iterate_all():
236 |             if metric.is_for_best_keep:
237 |                 metric_min_max_dict[metric_key] = metric.min_max_mode
238 |
239 |         return metric_min_max_dict
240 |
241 |     def filter_best_keep_metric(self, eval_metric_dict):
242 |         best_keep_metric_dict = dict()
243 |         for metric, metric_key, _ in self.build_op_aggregator.iterate_all():
244 |             if metric_key in eval_metric_dict and metric.is_for_best_keep:
245 |                 best_keep_metric_dict[metric_key] = eval_metric_dict[metric_key]
246 |
247 |         return best_keep_metric_dict
248 |
249 |     @staticmethod
250 |     def add_arguments(parser):
251 |         subparser = parser.add_argument_group("Metric Manager Arguments")
252 |         subparser.add_argument("--exclude_metric_names",
253 |                                nargs="*",
254 |                                default=[],
255 |                                type=str,
256 |                                help="Name of metrics to be excluded")
257 |         subparser.add_argument("--max_summary_outputs",
258 |                                default=3,
259 |                                type=int,
260 |                                help="Number of maximum summary outputs for multimedia (ex: audio wav)")
261 |
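The `DataStructure` key validation above can be exercised on its own. A minimal sketch, assuming the repository's requirements are installed and its root is on `PYTHONPATH` (`PairData` is a hypothetical subclass for illustration):

from metrics.base import DataStructure

class PairData(DataStructure):
    # `_keys` declares the exact set of keys the input dict must carry
    _keys = ["labels", "predictions"]

pair = PairData({"labels": [0, 1], "predictions": [0, 0]})
print(pair.to_dict())  # {'labels': [0, 1], 'predictions': [0, 0]}

try:
    PairData({"labels": [0, 1]})  # missing "predictions"
except ValueError as err:
    print(err)  # keys in `_keys` must exactly match the keys of `data`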
-------------------------------------------------------------------------------- /metrics/funcs.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def topN_accuracy(y_true: np.ndarray,
5 |                   y_pred_onehot: np.ndarray,
6 |                   N: int):
7 |     """ Top-N accuracy: the fraction of samples whose true label appears among the N highest-scored predictions. """
8 |     assert len(y_true.shape) == 1
9 |     assert len(y_pred_onehot.shape) == 2
10 |     assert y_true.shape[0] == y_pred_onehot.shape[0]
11 |     assert y_pred_onehot.shape[1] >= N
12 |
13 |     true_positive = 0
14 |     for label, top_n_pred in zip(y_true, np.argsort(-y_pred_onehot, axis=-1)[:, :N]):
15 |         if label in top_n_pred:
16 |             true_positive += 1
17 |
18 |     accuracy = true_positive / len(y_true)
19 |
20 |     return accuracy
-------------------------------------------------------------------------------- /metrics/manager.py: --------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import metrics.ops as mops
4 | import metrics.parser as parser
5 | from metrics.base import MetricManagerBase
6 | from metrics.summaries import Summaries
7 |
8 |
9 | class AudioMetricManager(MetricManagerBase):
10 |     _metric_input_data_parser = parser.AudioDataParser
11 |
12 |     def __init__(
13 |         self,
14 |         is_training: bool,
15 |         use_class_metrics: bool,
16 |         exclude_metric_names: List,
17 |         summary: Summaries,
18 |     ):
19 |         super().__init__(exclude_metric_names, summary)
20 |         self.register_metrics([
21 |             # mAP
22 |             mops.MAPMetricOp(),
23 |             # accuracy
24 |             mops.AccuracyMetricOp(),
25 |             mops.Top5AccuracyMetricOp(),
26 |             # misc
27 |             mops.ClassificationReportMetricOp(),
28 |
29 |             # tensor ops
30 |             mops.LossesMetricOp(),
31 |         ])
32 |
33 |         if is_training:
34 |             self.register_metrics([
35 |                 mops.WavSummaryOp(),
36 |                 mops.LearningRateSummaryOp()
37 |             ])
38 |
39 |         if use_class_metrics:
40 |             # per-class
41 |             self.register_metrics([
42 |                 mops.PrecisionMetricOp(),
43 |                 mops.RecallMetricOp(),
44 |                 mops.F1ScoreMetricOp(),
45 |                 mops.APMetricOp(),
46 |             ])
47 |
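Since `metrics/funcs.py` above is pure NumPy, `topN_accuracy` is easy to sanity-check in isolation. A minimal sketch (assuming the repository root is on `PYTHONPATH`; the toy scores are made up):

import numpy as np
from metrics.funcs import topN_accuracy

y_true = np.array([0, 1, 2])
y_scores = np.array([[0.7, 0.2, 0.1],   # top-2 = {0, 1} -> hit
                     [0.1, 0.3, 0.6],   # top-2 = {2, 1} -> hit
                     [0.5, 0.4, 0.1]])  # top-2 = {0, 1} -> miss
print(topN_accuracy(y_true, y_scores, N=2))  # 0.666...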
-------------------------------------------------------------------------------- /metrics/ops/__init__.py: --------------------------------------------------------------------------------
1 | from .base_ops import MetricOpBase
2 | from .base_ops import TensorMetricOpBase
3 | from .base_ops import NonTensorMetricOpBase
4 | from .non_tensor_ops import *
5 | from .tensor_ops import *
6 | from .misc_ops import *
-------------------------------------------------------------------------------- /metrics/ops/base_ops.py: --------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | from common.utils import get_logger
4 | from metrics.summaries import BaseSummaries
5 |
6 |
7 | class MetricOpBase(ABC):
8 |     MIN_MAX_CHOICES = ["min", "max", None]
9 |
10 |     _meta_properties = [
11 |         "is_for_summary",
12 |         "is_for_best_keep",
13 |         "is_for_log",
14 |         "valid_input_data_parsers",
15 |         "summary_collection_key",
16 |         "summary_value_type",
17 |         "min_max_mode",
18 |     ]
19 |     _properties = dict()
20 |
21 |     def __init__(self, **kwargs):
22 |         self.log = get_logger("MetricOp")
23 |
24 |         # init by _properties
25 |         # custom values can be added as kwargs
26 |         for attr in self._meta_properties:
27 |             if attr in kwargs:
28 |                 setattr(self, attr, kwargs[attr])
29 |             else:
30 |                 setattr(self, attr, self._properties[attr])
31 |
32 |         # assertion
33 |         assert self.min_max_mode in self.MIN_MAX_CHOICES
34 |
35 |         if self.is_for_best_keep:
36 |             assert self.min_max_mode is not None
37 |
38 |         if self.is_for_summary:
39 |             assert self.summary_collection_key in vars(BaseSummaries.KEY_TYPES).values()
40 |
41 |     def __hash__(self):
42 |         return hash(str(self))
43 |
44 |     def __eq__(self, other):
45 |         return str(self) == str(other)
46 |
47 |     @property
48 |     def is_placeholder_summary(self):
49 |         assert self.is_for_summary, "DO NOT call `is_placeholder_summary` if this is not a summary metric"
50 |         return self.summary_value_type == BaseSummaries.VALUE_TYPES.PLACEHOLDER
51 |
52 |     @property
53 |     @abstractmethod
54 |     def is_tensor_metric(self):
55 |         raise NotImplementedError
56 |
57 |     @abstractmethod
58 |     def __str__(self):
59 |         raise NotImplementedError
60 |
61 |     @abstractmethod
62 |     def build_op(self, data):
63 |         """ This method should be overridden for
64 |         all cases of `valid_input_data_parsers`
65 |         """
66 |         raise NotImplementedError
67 |
68 |
69 | class NonTensorMetricOpBase(MetricOpBase):
70 |     @property
71 |     def is_tensor_metric(self):
72 |         return False
73 |
74 |     @abstractmethod
75 |     def evaluate(self, data):
76 |         """ This method should be overridden for
77 |         all cases of `valid_input_data_parsers`
78 |         """
79 |         raise NotImplementedError
80 |
81 |
82 | class TensorMetricOpBase(MetricOpBase):
83 |     @property
84 |     def is_tensor_metric(self):
85 |         return True
86 |
87 |     @abstractmethod
88 |     def expectation_of(self, data):
89 |         """ If evaluation is done as a tensor metric, it has to re-calculate the expectation of
90 |         aggregated metric values.
91 |         This function assumes that data is aggregated over all batches
92 |         and returns the proper expectation value.
93 | """ 94 | raise NotImplementedError 95 | -------------------------------------------------------------------------------- /metrics/ops/misc_ops.py: -------------------------------------------------------------------------------- 1 | from metrics.ops.base_ops import TensorMetricOpBase 2 | from metrics.summaries import BaseSummaries 3 | import metrics.parser as parser 4 | 5 | 6 | class LearningRateSummaryOp(TensorMetricOpBase): 7 | """ 8 | Learning rate summary. 9 | """ 10 | _properties = { 11 | "is_for_summary": True, 12 | "is_for_best_keep": False, 13 | "is_for_log": True, 14 | "valid_input_data_parsers": [ 15 | parser.AudioDataParser, 16 | ], 17 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT, 18 | "summary_value_type": BaseSummaries.VALUE_TYPES.SCALAR, 19 | "min_max_mode": None, 20 | } 21 | 22 | def __str__(self): 23 | return "summary_lr" 24 | 25 | def build_op(self, 26 | data): 27 | res = dict() 28 | if data.learning_rate is None: 29 | pass 30 | else: 31 | res[f"learning_rate/{data.dataset_split_name}"] = data.learning_rate 32 | 33 | return res 34 | 35 | def expectation_of(self, data): 36 | pass 37 | -------------------------------------------------------------------------------- /metrics/ops/non_tensor_ops.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import accuracy_score 2 | from sklearn.metrics import average_precision_score 3 | from sklearn.metrics import precision_score 4 | from sklearn.metrics import recall_score 5 | from sklearn.metrics import f1_score 6 | from sklearn.metrics import classification_report 7 | from overload import overload 8 | 9 | import metrics.parser as parser 10 | from metrics.funcs import topN_accuracy 11 | from metrics.ops.base_ops import NonTensorMetricOpBase 12 | from metrics.summaries import BaseSummaries 13 | 14 | 15 | class MAPMetricOp(NonTensorMetricOpBase): 16 | """ 17 | Micro Mean Average Precision Metric. 18 | """ 19 | _properties = { 20 | "is_for_summary": True, 21 | "is_for_best_keep": True, 22 | "is_for_log": True, 23 | "valid_input_data_parsers": [ 24 | parser.AudioDataParser, 25 | ], 26 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT, 27 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER, 28 | "min_max_mode": "max", 29 | } 30 | 31 | _average_fns = { 32 | "macro": lambda t, p: average_precision_score(t, p, average="macro"), 33 | "micro": lambda t, p: average_precision_score(t, p, average="micro"), 34 | "weighted": lambda t, p: average_precision_score(t, p, average="weighted"), 35 | "samples": lambda t, p: average_precision_score(t, p, average="samples"), 36 | } 37 | 38 | def __str__(self): 39 | return "mAP_metric" 40 | 41 | @overload 42 | def build_op(self, 43 | data: parser.AudioDataParser.OutputBuildData): 44 | result = dict() 45 | 46 | for avg_name in self._average_fns: 47 | key = f"mAP/{data.dataset_split_name}/{avg_name}" 48 | result[key] = None 49 | 50 | return result 51 | 52 | @overload 53 | def evaluate(self, 54 | data: parser.AudioDataParser.OutputNonTensorData): 55 | result = dict() 56 | 57 | for avg_name, avg_fn in self._average_fns.items(): 58 | key = f"mAP/{data.dataset_split_name}/{avg_name}" 59 | result[key] = avg_fn(data.labels_onehot, data.predictions_onehot) 60 | 61 | return result 62 | 63 | 64 | class AccuracyMetricOp(NonTensorMetricOpBase): 65 | """ 66 | Accuracy Metric. 
67 | """ 68 | _properties = { 69 | "is_for_summary": True, 70 | "is_for_best_keep": True, 71 | "is_for_log": True, 72 | "valid_input_data_parsers": [ 73 | parser.AudioDataParser, 74 | ], 75 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT, 76 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER, 77 | "min_max_mode": "max", 78 | } 79 | 80 | def __str__(self): 81 | return "accuracy_metric" 82 | 83 | @overload 84 | def build_op(self, 85 | data: parser.AudioDataParser.OutputBuildData): 86 | key = f"accuracy/{data.dataset_split_name}" 87 | 88 | return { 89 | key: None 90 | } 91 | 92 | @overload 93 | def evaluate(self, 94 | data: parser.AudioDataParser.OutputNonTensorData): 95 | key = f"accuracy/{data.dataset_split_name}" 96 | 97 | metric = accuracy_score(data.labels, data.predictions) 98 | 99 | return { 100 | key: metric 101 | } 102 | 103 | 104 | class Top5AccuracyMetricOp(NonTensorMetricOpBase): 105 | """ 106 | Top 5 Accuracy Metric. 107 | """ 108 | _properties = { 109 | "is_for_summary": True, 110 | "is_for_best_keep": True, 111 | "is_for_log": True, 112 | "valid_input_data_parsers": [ 113 | parser.AudioDataParser, 114 | ], 115 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT, 116 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER, 117 | "min_max_mode": "max", 118 | } 119 | 120 | def __str__(self): 121 | return "top5_accuracy_metric" 122 | 123 | @overload 124 | def build_op(self, 125 | data: parser.AudioDataParser.OutputBuildData): 126 | key = f"top5_accuracy/{data.dataset_split_name}" 127 | 128 | return { 129 | key: None 130 | } 131 | 132 | @overload 133 | def evaluate(self, 134 | data: parser.AudioDataParser.OutputNonTensorData): 135 | key = f"top5_accuracy/{data.dataset_split_name}" 136 | 137 | metric = topN_accuracy(y_true=data.labels, 138 | y_pred_onehot=data.predictions_onehot, 139 | N=5) 140 | 141 | return { 142 | key: metric 143 | } 144 | 145 | 146 | class PrecisionMetricOp(NonTensorMetricOpBase): 147 | """ 148 | Precision Metric. 149 | """ 150 | _properties = { 151 | "is_for_summary": True, 152 | "is_for_best_keep": True, 153 | "is_for_log": True, 154 | "valid_input_data_parsers": [ 155 | parser.AudioDataParser, 156 | ], 157 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT, 158 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER, 159 | "min_max_mode": "max", 160 | } 161 | 162 | def __str__(self): 163 | return "precision_metric" 164 | 165 | @overload 166 | def build_op(self, 167 | data: parser.AudioDataParser.OutputBuildData): 168 | result = dict() 169 | 170 | label_idxes = list(range(len(data.label_names))) 171 | 172 | for label_idx in label_idxes: 173 | label_name = data.label_names[label_idx] 174 | key = f"precision/{data.dataset_split_name}/{label_name}" 175 | result[key] = None 176 | 177 | return result 178 | 179 | @overload 180 | def evaluate(self, 181 | data: parser.AudioDataParser.OutputNonTensorData): 182 | result = dict() 183 | 184 | label_idxes = list(range(len(data.label_names))) 185 | precisions = precision_score(data.labels, data.predictions, average=None, labels=label_idxes) 186 | 187 | for label_idx in label_idxes: 188 | label_name = data.label_names[label_idx] 189 | key = f"precision/{data.dataset_split_name}/{label_name}" 190 | metric = precisions[label_idx] 191 | result[key] = metric 192 | 193 | return result 194 | 195 | 196 | class RecallMetricOp(NonTensorMetricOpBase): 197 | """ 198 | Recall Metric. 
199 | """ 200 | _properties = { 201 | "is_for_summary": True, 202 | "is_for_best_keep": True, 203 | "is_for_log": True, 204 | "valid_input_data_parsers": [ 205 | parser.AudioDataParser, 206 | ], 207 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT, 208 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER, 209 | "min_max_mode": "max", 210 | } 211 | 212 | def __str__(self): 213 | return "recall_metric" 214 | 215 | @overload 216 | def build_op(self, 217 | data: parser.AudioDataParser.OutputBuildData): 218 | result = dict() 219 | 220 | label_idxes = list(range(len(data.label_names))) 221 | 222 | for label_idx in label_idxes: 223 | label_name = data.label_names[label_idx] 224 | key = f"recall/{data.dataset_split_name}/{label_name}" 225 | result[key] = None 226 | 227 | return result 228 | 229 | @overload 230 | def evaluate(self, 231 | data: parser.AudioDataParser.OutputNonTensorData): 232 | result = dict() 233 | 234 | label_idxes = list(range(len(data.label_names))) 235 | recalls = recall_score(data.labels, data.predictions, average=None, labels=label_idxes) 236 | 237 | for label_idx in label_idxes: 238 | label_name = data.label_names[label_idx] 239 | key = f"recall/{data.dataset_split_name}/{label_name}" 240 | metric = recalls[label_idx] 241 | result[key] = metric 242 | 243 | return result 244 | 245 | 246 | class F1ScoreMetricOp(NonTensorMetricOpBase): 247 | """ 248 | Per class F1-Score Metric. 249 | """ 250 | _properties = { 251 | "is_for_summary": True, 252 | "is_for_best_keep": True, 253 | "is_for_log": True, 254 | "valid_input_data_parsers": [ 255 | parser.AudioDataParser, 256 | ], 257 | "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT, 258 | "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER, 259 | "min_max_mode": "max", 260 | } 261 | 262 | def __str__(self): 263 | return "f1_score_metric" 264 | 265 | @overload 266 | def build_op(self, 267 | data: parser.AudioDataParser.OutputBuildData): 268 | result = dict() 269 | 270 | label_idxes = list(range(len(data.label_names))) 271 | 272 | for label_idx in label_idxes: 273 | label_name = data.label_names[label_idx] 274 | key = f"f1score/{data.dataset_split_name}/{label_name}" 275 | result[key] = None 276 | 277 | return result 278 | 279 | @overload 280 | def evaluate(self, 281 | data: parser.AudioDataParser.OutputNonTensorData): 282 | result = dict() 283 | 284 | label_idxes = list(range(len(data.label_names))) 285 | f1_scores = f1_score(data.labels, data.predictions, average=None, labels=label_idxes) 286 | 287 | for label_idx in label_idxes: 288 | label_name = data.label_names[label_idx] 289 | key = f"f1score/{data.dataset_split_name}/{label_name}" 290 | metric = f1_scores[label_idx] 291 | result[key] = metric 292 | 293 | return result 294 | 295 | 296 | class APMetricOp(NonTensorMetricOpBase): 297 | """ 298 | Per class Average Precision Metric. 
299 |     """
300 |     _properties = {
301 |         "is_for_summary": True,
302 |         "is_for_best_keep": True,
303 |         "is_for_log": True,
304 |         "valid_input_data_parsers": [
305 |             parser.AudioDataParser,
306 |         ],
307 |         "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
308 |         "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
309 |         "min_max_mode": "max",
310 |     }
311 |
312 |     def __str__(self):
313 |         return "ap_score_metric"
314 |
315 |     @overload
316 |     def build_op(self,
317 |                  data: parser.AudioDataParser.OutputBuildData):
318 |         result = dict()
319 |
320 |         label_idxes = list(range(len(data.label_names)))
321 |
322 |         for label_idx in label_idxes:
323 |             label_name = data.label_names[label_idx]
324 |             key = f"ap/{data.dataset_split_name}/{label_name}"
325 |             result[key] = None
326 |
327 |         return result
328 |
329 |     @overload
330 |     def evaluate(self,
331 |                  data: parser.AudioDataParser.OutputNonTensorData):
332 |         result = dict()
333 |
334 |         label_idxes = list(range(len(data.label_names)))
335 |         aps = average_precision_score(data.labels_onehot, data.predictions_onehot, average=None)
336 |
337 |         for label_idx in label_idxes:
338 |             label_name = data.label_names[label_idx]
339 |             key = f"ap/{data.dataset_split_name}/{label_name}"
340 |             metric = aps[label_idx]
341 |             result[key] = metric
342 |
343 |         return result
344 |
345 |
346 | class ClassificationReportMetricOp(NonTensorMetricOpBase):
347 |     """
348 |     Classification Report Metric.
349 |     """
350 |     _properties = {
351 |         "is_for_summary": False,
352 |         "is_for_best_keep": False,
353 |         "is_for_log": True,
354 |         "valid_input_data_parsers": [
355 |             parser.AudioDataParser,
356 |         ],
357 |         "summary_collection_key": None,
358 |         "summary_value_type": None,
359 |         "min_max_mode": None,
360 |     }
361 |
362 |     def __str__(self):
363 |         return "classification_report_metric"
364 |
365 |     @overload
366 |     def build_op(self,
367 |                  data: parser.AudioDataParser.OutputBuildData):
368 |         key = f"classification_report/{data.dataset_split_name}"
369 |
370 |         return {
371 |             key: None
372 |         }
373 |
374 |     @overload
375 |     def evaluate(self,
376 |                  data: parser.AudioDataParser.OutputNonTensorData):
377 |         key = f"classification_report/{data.dataset_split_name}"
378 |
379 |         label_idxes = list(range(len(data.label_names)))
380 |         metric = classification_report(data.labels,
381 |                                        data.predictions,
382 |                                        labels=label_idxes,
383 |                                        target_names=data.label_names)
384 |         metric = f"[ClassificationReport]\n{metric}"
385 |
386 |         return {
387 |             key: metric
388 |         }
389 |
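The split above between `build_op`, which only reserves placeholder keys, and `evaluate`, which computes values from accumulated predictions, can be seen by driving one op by hand. A minimal sketch (assuming the repository root is on `PYTHONPATH`; the toy labels are made up):

import numpy as np
import metrics.parser as parser
from metrics.ops.non_tensor_ops import AccuracyMetricOp

# Parse raw one-hot arrays the same way the evaluator does; the parser
# also derives integer `predictions` / `labels` via argmax.
data = parser.AudioDataParser.parse_non_tensor_data({
    "dataset_split_name": "valid",
    "label_names": ["yes", "no"],
    "predictions_onehot": np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]]),
    "labels_onehot": np.array([[1, 0], [0, 1], [1, 0]]),
})

op = AccuracyMetricOp()
print(op.evaluate(data))  # {'accuracy/valid': 0.666...}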
-------------------------------------------------------------------------------- /metrics/ops/tensor_ops.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import metrics.parser as parser
4 | from metrics.ops.base_ops import TensorMetricOpBase
5 | from metrics.summaries import BaseSummaries
6 |
7 |
8 | class LossesMetricOp(TensorMetricOpBase):
9 |     """ Loss Metric.
10 |     """
11 |     _properties = {
12 |         "is_for_summary": True,
13 |         "is_for_best_keep": True,
14 |         "is_for_log": True,
15 |         "valid_input_data_parsers": [
16 |             parser.AudioDataParser,
17 |         ],
18 |         "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
19 |         "summary_value_type": BaseSummaries.VALUE_TYPES.PLACEHOLDER,
20 |         "min_max_mode": "min",
21 |     }
22 |
23 |     def __str__(self):
24 |         return "losses"
25 |
26 |     def build_op(self, data):
27 |         result = dict()
28 |
29 |         for loss_name, loss_op in data.losses.items():
30 |             key = f"metric_loss/{data.dataset_split_name}/{loss_name}"
31 |             result[key] = loss_op
32 |
33 |         return result
34 |
35 |     def expectation_of(self, data: np.ndarray):
36 |         assert len(data.shape) == 2
37 |         return np.mean(data)
38 |
39 |
40 | class WavSummaryOp(TensorMetricOpBase):
41 |     _properties = {
42 |         "is_for_summary": True,
43 |         "is_for_best_keep": False,
44 |         "is_for_log": False,
45 |         "valid_input_data_parsers": [
46 |             parser.AudioDataParser,
47 |         ],
48 |         "summary_collection_key": BaseSummaries.KEY_TYPES.DEFAULT,
49 |         "summary_value_type": BaseSummaries.VALUE_TYPES.AUDIO,
50 |         "min_max_mode": None,
51 |     }
52 |
53 |     def __str__(self):
54 |         return "summary_wav"
55 |
56 |     def build_op(self, data: parser.AudioDataParser.OutputBuildData):
57 |         return {
58 |             f"wav/{data.dataset_split_name}": data.wavs
59 |         }
60 |
61 |     def expectation_of(self, data):
62 |         pass
-------------------------------------------------------------------------------- /metrics/parser.py: --------------------------------------------------------------------------------
1 | from abc import ABC, ABCMeta
2 |
3 | from metrics.base import DataStructure
4 |
5 |
6 | class MetricDataParserBase(ABC):
7 |     @classmethod
8 |     def parse_build_data(cls, data):
9 |         """
10 |         Args:
11 |             data: dictionary which will be passed to InputBuildData
12 |         """
13 |         data = cls._validate_build_data(data)
14 |         data = cls._process_build_data(data)
15 |         return data
16 |
17 |     @classmethod
18 |     def parse_non_tensor_data(cls, data):
19 |         """
20 |         Args:
21 |             data: dictionary which will be passed to InputNonTensorData
22 |         """
23 |         input_data = cls._validate_non_tensor_data(data)
24 |         output_data = cls._process_non_tensor_data(input_data)
25 |         return output_data
26 |
27 |     @classmethod
28 |     def _validate_build_data(cls, data):
29 |         """
30 |         Specify assertions that tensor data should satisfy
31 |
32 |         Args:
33 |             data: dictionary
34 |         Return:
35 |             InputBuildData
36 |         """
37 |         return cls.InputBuildData(data)
38 |
39 |     @classmethod
40 |     def _validate_non_tensor_data(cls, data):
41 |         """
42 |         Specify assertions that non-tensor data should satisfy
43 |
44 |         Args:
45 |             data: dictionary
46 |         Return:
47 |             InputNonTensorData
48 |         """
49 |         return cls.InputNonTensorData(data)
50 |
51 |     """
52 |     Override these two functions if needed.
53 |     """
54 |     @classmethod
55 |     def _process_build_data(cls, data):
56 |         """
57 |         Process data so that downstream metrics can use it
58 |
59 |         Args:
60 |             data: InputBuildData
61 |
62 |         Return:
63 |             OutputBuildData
64 |         """
65 |         # default implementation just passes data through
66 |         return cls.OutputBuildData(data.to_dict())
67 |
68 |     @classmethod
69 |     def _process_non_tensor_data(cls, data):
70 |         """
71 |         Process data so that downstream metrics can use it
72 |
73 |         Args:
74 |             data: InputNonTensorData
75 |
76 |         Return:
77 |             OutputNonTensorData
78 |         """
79 |         # default implementation just passes data through
80 |         return cls.OutputNonTensorData(data.to_dict())
81 |
82 |     """
83 |     The classes below should be implemented when inheriting.
84 | """ 85 | class InputBuildData(DataStructure, metaclass=ABCMeta): 86 | pass 87 | 88 | class OutputBuildData(DataStructure, metaclass=ABCMeta): 89 | pass 90 | 91 | class InputNonTensorData(DataStructure, metaclass=ABCMeta): 92 | pass 93 | 94 | class OutputNonTensorData(DataStructure, metaclass=ABCMeta): 95 | pass 96 | 97 | 98 | class AudioDataParser(MetricDataParserBase): 99 | class InputBuildData(DataStructure): 100 | _keys = [ 101 | "dataset_split_name", 102 | "label_names", 103 | "losses", # Dict | loss_key -> Tensor 104 | "learning_rate", 105 | "wavs", 106 | ] 107 | 108 | class OutputBuildData(DataStructure): 109 | _keys = [ 110 | "dataset_split_name", 111 | "label_names", 112 | "losses", 113 | "learning_rate", 114 | "wavs", 115 | ] 116 | 117 | class InputNonTensorData(DataStructure): 118 | _keys = [ 119 | "dataset_split_name", 120 | "label_names", 121 | "predictions_onehot", 122 | "labels_onehot", 123 | ] 124 | 125 | class OutputNonTensorData(DataStructure): 126 | _keys = [ 127 | "dataset_split_name", 128 | "label_names", 129 | "predictions_onehot", 130 | "labels_onehot", 131 | "predictions", 132 | "labels", 133 | ] 134 | 135 | @classmethod 136 | def _process_non_tensor_data(cls, data): 137 | predictions = data.predictions_onehot.argmax(axis=-1) 138 | labels = data.labels_onehot.argmax(axis=-1) 139 | 140 | return cls.OutputNonTensorData({ 141 | "dataset_split_name": data.dataset_split_name, 142 | "label_names": data.label_names, 143 | "predictions_onehot": data.predictions_onehot, 144 | "labels_onehot": data.labels_onehot, 145 | "predictions": predictions, 146 | "labels": labels, 147 | }) 148 | -------------------------------------------------------------------------------- /metrics/summaries.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from types import SimpleNamespace 3 | from pathlib import Path 4 | 5 | import tensorflow as tf 6 | from overload import overload 7 | 8 | from common.utils import get_logger 9 | 10 | 11 | class BaseSummaries(ABC): 12 | KEY_TYPES = SimpleNamespace( 13 | DEFAULT="SUMMARY_DEFAULT", 14 | VERBOSE="SUMMARY_VERBOSE", 15 | FIRST_N="SUMMARY_FIRST_N", 16 | ) 17 | 18 | VALUE_TYPES = SimpleNamespace( 19 | SCALAR="SCALAR", 20 | PLACEHOLDER="PLACEHOLDER", 21 | AUDIO="AUDIO", 22 | NONE="NONE", # None is not used for summary 23 | ) 24 | 25 | def __init__( 26 | self, 27 | session, 28 | train_dir, 29 | is_training, 30 | base_name="eval", 31 | max_summary_outputs=None, 32 | ): 33 | self.log = get_logger("Summary") 34 | 35 | self.session = session 36 | self.train_dir = train_dir 37 | self.max_summary_outputs = max_summary_outputs 38 | self.base_name = base_name 39 | self.merged_summaries = dict() 40 | 41 | self.summary_writer = None 42 | self._setup_summary_writer(is_training) 43 | 44 | def write(self, summary, global_step=0): 45 | self.summary_writer.add_summary(summary, global_step) 46 | 47 | def setup_experiment(self, config): 48 | """ 49 | Args: 50 | config: Namespace 51 | """ 52 | config = vars(config) 53 | 54 | sorted_config = [(k, str(v)) for k, v in sorted(config.items(), key=lambda x: x[0])] 55 | config = tf.summary.text("config", tf.convert_to_tensor(sorted_config)) 56 | 57 | config_val = self.session.run(config) 58 | self.write(config_val) 59 | self.summary_writer.add_graph(tf.get_default_graph()) 60 | 61 | def register_summaries(self, collection_summary_dict): 62 | """ 63 | Args: 64 | collection_summary_dict: Dict[str, Dict[MetricOp, List[Tuple(metric_key, tensor_op)]]] 65 | collection_key 
-> metric -> List(summary_key, value) 66 | """ 67 | for collection_key_suffix, metric_dict in collection_summary_dict.items(): 68 | for metric, key_value_list in metric_dict.items(): 69 | for summary_key, value in key_value_list: 70 | self._routine_add_summary_op(summary_value_type=metric.summary_value_type, 71 | summary_key=summary_key, 72 | value=value, 73 | collection_key_suffix=collection_key_suffix) 74 | 75 | def setup_merged_summaries(self): 76 | for collection_key_suffix in vars(self.KEY_TYPES).values(): 77 | for collection_key in self._iterate_collection_keys(collection_key_suffix): 78 | merged_summary = tf.summary.merge_all(key=collection_key) 79 | self.merged_summaries[collection_key] = merged_summary 80 | 81 | def get_merged_summaries(self, collection_key_suffixes: list, is_tensor_summary: bool): 82 | summaries = [] 83 | 84 | for collection_key_suffix in collection_key_suffixes: 85 | collection_key = self._build_collection_key(collection_key_suffix, is_tensor_summary) 86 | summary = self.merged_summaries[collection_key] 87 | 88 | if summary is not None: 89 | summaries.append(summary) 90 | 91 | if len(summaries) == 0: 92 | return None 93 | elif len(summaries) == 1: 94 | return summaries[0] 95 | else: 96 | return tf.summary.merge(summaries) 97 | 98 | def write_evaluation_summaries(self, step, collection_keys, collection_summary_dict): 99 | """ 100 | Args: 101 | collection_summary_dict: Dict[str, Dict[MetricOp, List[Tuple(metric_key, tensor_op)]]] 102 | collection_key -> metric -> List(summary_key, value) 103 | collection_keys: List 104 | """ 105 | for collection_key_suffix, metric_dict in collection_summary_dict.items(): 106 | if collection_key_suffix in collection_keys: 107 | merged_summary_op = self.get_merged_summaries(collection_key_suffixes=[collection_key_suffix], 108 | is_tensor_summary=False) 109 | feed_dict = dict() 110 | 111 | for metric, key_value_list in metric_dict.items(): 112 | if metric.is_placeholder_summary: 113 | for summary_key, value in key_value_list: 114 | # https://github.com/tensorflow/tensorflow/issues/3378 115 | placeholder_name = self._build_placeholder_name(summary_key) + ":0" 116 | feed_dict[placeholder_name] = value 117 | 118 | summary_value = self.session.run(merged_summary_op, feed_dict=feed_dict) 119 | self.write(summary_value, step) 120 | 121 | def _setup_summary_writer(self, is_training): 122 | summary_directory = self._build_summary_directory(is_training) 123 | self.log.info(f"Write summaries into : {summary_directory}") 124 | 125 | if is_training: 126 | self.summary_writer = tf.summary.FileWriter(summary_directory, self.session.graph) 127 | else: 128 | self.summary_writer = tf.summary.FileWriter(summary_directory) 129 | 130 | def _build_summary_directory(self, is_training): 131 | if is_training: 132 | return self.train_dir 133 | else: 134 | if Path(self.train_dir).is_dir(): 135 | summary_directory = (Path(self.train_dir) / Path(self.base_name)).as_posix() 136 | else: 137 | summary_directory = (Path(self.train_dir).parent / Path(self.base_name)).as_posix() 138 | 139 | if not Path(summary_directory).exists(): 140 | Path(summary_directory).mkdir(parents=True) 141 | 142 | return summary_directory 143 | 144 | def _routine_add_summary_op(self, summary_value_type, summary_key, value, collection_key_suffix): 145 | collection_key = self._build_collection_key(collection_key_suffix, summary_value_type) 146 | 147 | if summary_value_type == self.VALUE_TYPES.SCALAR: 148 | def register_fn(k, v): 149 | return tf.summary.scalar(k, v, 
collections=[collection_key]) 150 | 151 | elif summary_value_type == self.VALUE_TYPES.AUDIO: 152 | def register_fn(k, v): 153 | return tf.summary.audio(k, v, 154 | sample_rate=16000, 155 | max_outputs=self.max_summary_outputs, 156 | collections=[collection_key]) 157 | 158 | elif summary_value_type == self.VALUE_TYPES.PLACEHOLDER: 159 | def register_fn(k, v): 160 | return tf.summary.scalar(k, v, collections=[collection_key]) 161 | value = self._build_placeholder(summary_key) 162 | 163 | else: 164 | raise NotImplementedError 165 | 166 | register_fn(summary_key, value) 167 | 168 | @classmethod 169 | def _build_placeholder(cls, summary_key): 170 | name = cls._build_placeholder_name(summary_key) 171 | return tf.placeholder(tf.float32, [], name=name) 172 | 173 | @staticmethod 174 | def _build_placeholder_name(summary_key): 175 | return f"non_tensor_summary_placeholder/{summary_key}" 176 | 177 | # Below two functions should be class method but defined as instance method 178 | # since it has bug in @overload 179 | # @classmethod 180 | @overload 181 | def _build_collection_key(self, collection_key_suffix, summary_value_type: str): 182 | if summary_value_type == self.VALUE_TYPES.PLACEHOLDER: 183 | prefix = "NON_TENSOR" 184 | else: 185 | prefix = "TENSOR" 186 | 187 | return f"{prefix}_{collection_key_suffix}" 188 | 189 | # @classmethod 190 | @_build_collection_key.add 191 | def _build_collection_key(self, collection_key_suffix, is_tensor_summary: bool): 192 | if not is_tensor_summary: 193 | prefix = "NON_TENSOR" 194 | else: 195 | prefix = "TENSOR" 196 | 197 | return f"{prefix}_{collection_key_suffix}" 198 | 199 | @classmethod 200 | def _iterate_collection_keys(cls, collection_key_suffix): 201 | for prefix in ["NON_TENSOR", "TENSOR"]: 202 | yield f"{prefix}_{collection_key_suffix}" 203 | 204 | 205 | class Summaries(BaseSummaries): 206 | pass 207 | -------------------------------------------------------------------------------- /requirements/py36-common.txt: -------------------------------------------------------------------------------- 1 | # ML 2 | numpy>=1.16.0 3 | pandas 4 | tqdm 5 | scikit-learn==0.19.1 6 | scipy 7 | 8 | # data 9 | termcolor 10 | click 11 | dask[dataframe] 12 | 13 | # misc 14 | humanfriendly 15 | overload 16 | -------------------------------------------------------------------------------- /requirements/py36-cpu.txt: -------------------------------------------------------------------------------- 1 | # Deep learning 2 | tensorflow==1.13.1 3 | -r py36-common.txt -------------------------------------------------------------------------------- /requirements/py36-gpu.txt: -------------------------------------------------------------------------------- 1 | # Deep learning 2 | tensorflow-gpu==1.13.1 3 | -r py36-common.txt 4 | -------------------------------------------------------------------------------- /scripts/commands/DSCNNLModel-0_mfcc_10_4020_0.0000_adam_l3.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/DSCNNLModel-0/mfcc_10_4020_0.0000_adam_l3 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --batch_size 100 --boundaries 10000 --max_step_from_restore 20000 --lr_list 0.0005 0.0001 
--absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam DSCNNLModel & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/DSCNNLModel-0/mfcc_10_4020_0.0000_adam_l3 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 3 --no-shuffle --valid_type loop DSCNNLModel & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/DSCNNLModel-0/mfcc_10_4020_0.0000_adam_l3/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 39 --no-shuffle --valid_type once DSCNNLModel 8 | -------------------------------------------------------------------------------- /scripts/commands/DSCNNMModel-0_mfcc_10_4020_0.0000_adam_l3.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/DSCNNMModel-0/mfcc_10_4020_0.0000_adam_l3 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --batch_size 100 --boundaries 10000 --max_step_from_restore 20000 --lr_list 0.0005 0.0001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam DSCNNMModel & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/DSCNNMModel-0/mfcc_10_4020_0.0000_adam_l3 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 3 --no-shuffle --valid_type loop DSCNNMModel & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/DSCNNMModel-0/mfcc_10_4020_0.0000_adam_l3/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 39 --no-shuffle --valid_type once DSCNNMModel 8 | -------------------------------------------------------------------------------- /scripts/commands/DSCNNSModel-0_mfcc_10_4020_0.0000_adam_l3.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' 
SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/DSCNNSModel-0/mfcc_10_4020_0.0000_adam_l3 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --batch_size 100 --boundaries 10000 --max_step_from_restore 20000 --lr_list 0.0005 0.0001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam DSCNNSModel & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/DSCNNSModel-0/mfcc_10_4020_0.0000_adam_l3 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 3 --no-shuffle --valid_type loop DSCNNSModel & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/DSCNNSModel-0/mfcc_10_4020_0.0000_adam_l3/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 10 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 39 --no-shuffle --valid_type once DSCNNSModel 8 | -------------------------------------------------------------------------------- /scripts/commands/KWSfpool3-0_mfcc_40_4020_0.0000_adam_l3.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/KWSfpool3-0/mfcc_40_4020_0.0000_adam_l3 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --batch_size 100 --boundaries 10000 --max_step_from_restore 20000 --lr_list 0.0005 0.0001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam KWSModel --architecture trad_fpool3 & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/KWSfpool3-0/mfcc_40_4020_0.0000_adam_l3 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 3 --no-shuffle --valid_type loop KWSModel --architecture trad_fpool3 & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/KWSfpool3-0/mfcc_40_4020_0.0000_adam_l3/valid/accuracy/valid --num_silent 257 --augmentation_method 
anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 20000 --batch_size 39 --no-shuffle --valid_type once KWSModel --architecture trad_fpool3 8 | -------------------------------------------------------------------------------- /scripts/commands/KWSfstride4-0_mfcc_40_4020_0.0000_adam_l2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/KWSfstride4-0/mfcc_40_4020_0.0000_adam_l2 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.0005 0.0001 0.00002 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam KWSModel --architecture one_fstride4 & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/KWSfstride4-0/mfcc_40_4020_0.0000_adam_l2 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop KWSModel --architecture one_fstride4 & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/KWSfstride4-0/mfcc_40_4020_0.0000_adam_l2/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 40 --window_stride_ms 20 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once KWSModel --architecture one_fstride4 8 | -------------------------------------------------------------------------------- /scripts/commands/Res15Model-0_mfcc_40_3010_0.00001_adam_s1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/Res15Model-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 64 --boundaries 3000 6000 --max_step_from_restore 9000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam Res15Model --weight_decay 0.00001 & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path 
work/v1/Res15Model-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 3 --no-shuffle --valid_type loop Res15Model --weight_decay 0.00001 & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res15Model-0/mfcc_40_3010_0.00001_adam_s1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 39 --no-shuffle --valid_type once Res15Model --weight_decay 0.00001 8 | -------------------------------------------------------------------------------- /scripts/commands/Res15NarrowModel-0_mfcc_40_3010_0.00001_adam_s1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/Res15NarrowModel-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 64 --boundaries 3000 6000 --max_step_from_restore 9000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam Res15NarrowModel --weight_decay 0.00001 & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res15NarrowModel-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 3 --no-shuffle --valid_type loop Res15NarrowModel --weight_decay 0.00001 & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res15NarrowModel-0/mfcc_40_3010_0.00001_adam_s1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 39 --no-shuffle --valid_type once Res15NarrowModel --weight_decay 0.00001 8 | -------------------------------------------------------------------------------- /scripts/commands/Res8Model-0_mfcc_40_3010_0.00001_adam_s1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 
--train_dir work/v1/Res8Model-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 64 --boundaries 3000 6000 --max_step_from_restore 9000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam Res8Model --weight_decay 0.00001 & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res8Model-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 3 --no-shuffle --valid_type loop Res8Model --weight_decay 0.00001 & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res8Model-0/mfcc_40_3010_0.00001_adam_s1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 39 --no-shuffle --valid_type once Res8Model --weight_decay 0.00001 8 | -------------------------------------------------------------------------------- /scripts/commands/Res8NarrowModel-0_mfcc_40_3010_0.00001_adam_s1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/Res8NarrowModel-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 64 --boundaries 3000 6000 --max_step_from_restore 9000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer adam Res8NarrowModel --weight_decay 0.00001 & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res8NarrowModel-0/mfcc_40_3010_0.00001_adam_s1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 3 --no-shuffle --valid_type loop Res8NarrowModel --weight_decay 0.00001 & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/Res8NarrowModel-0/mfcc_40_3010_0.00001_adam_s1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 
--window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 9000 --batch_size 39 --no-shuffle --valid_type once Res8NarrowModel --weight_decay 0.00001 8 | -------------------------------------------------------------------------------- /scripts/commands/TCResNet14Model-1.0_mfcc_40_3010_0.001_mom_l1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/TCResNet14Model-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer mom --momentum 0.9 TCResNet14Model --weight_decay 0.001 --width_multiplier 1.0 & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet14Model-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop TCResNet14Model --weight_decay 0.001 --width_multiplier 1.0 & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet14Model-1.0/mfcc_40_3010_0.001_mom_l1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once TCResNet14Model --weight_decay 0.001 --width_multiplier 1.0 8 | -------------------------------------------------------------------------------- /scripts/commands/TCResNet14Model-1.5_mfcc_40_3010_0.001_mom_l1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/TCResNet14Model-1.5/mfcc_40_3010_0.001_mom_l1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer mom --momentum 0.9 TCResNet14Model --weight_decay 0.001 --width_multiplier 1.5 & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name 
output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet14Model-1.5/mfcc_40_3010_0.001_mom_l1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop TCResNet14Model --weight_decay 0.001 --width_multiplier 1.5 & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet14Model-1.5/mfcc_40_3010_0.001_mom_l1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once TCResNet14Model --weight_decay 0.001 --width_multiplier 1.5 8 | -------------------------------------------------------------------------------- /scripts/commands/TCResNet2D8Model-1.0_mfcc_40_3010_0.001_mom_l1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/ResNet2D8Model-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer mom --momentum 0.9 ResNet2D8Model --weight_decay 0.001 --width_multiplier 1.0 & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/ResNet2D8Model-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop ResNet2D8Model --weight_decay 0.001 --width_multiplier 1.0 & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/ResNet2D8Model-1.0/mfcc_40_3010_0.001_mom_l1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once ResNet2D8Model --weight_decay 0.001 --width_multiplier 1.0 8 | -------------------------------------------------------------------------------- /scripts/commands/TCResNet2D8PoolModel-1.0_mfcc_40_3010_0.001_mom_l1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 
'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/ResNet2D8PoolModel-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer mom --momentum 0.9 ResNet2D8PoolModel --weight_decay 0.001 --width_multiplier 1.0 & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/ResNet2D8PoolModel-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop ResNet2D8PoolModel --weight_decay 0.001 --width_multiplier 1.0 & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/ResNet2D8PoolModel-1.0/mfcc_40_3010_0.001_mom_l1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once ResNet2D8PoolModel --weight_decay 0.001 --width_multiplier 1.0 8 | -------------------------------------------------------------------------------- /scripts/commands/TCResNet8Model-1.0_mfcc_40_3010_0.001_mom_l1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer mom --momentum 0.9 TCResNet8Model --weight_decay 0.001 --width_multiplier 1.0 & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop TCResNet8Model --weight_decay 0.001 --width_multiplier 1.0 & 6 | wait 
7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet8Model-1.0/mfcc_40_3010_0.001_mom_l1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once TCResNet8Model --weight_decay 0.001 --width_multiplier 1.0 8 | -------------------------------------------------------------------------------- /scripts/commands/TCResNet8Model-1.5_mfcc_40_3010_0.001_mom_l1.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | trap 'pkill -P $$' SIGINT SIGTERM EXIT 3 | python train_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name train --output_name output/softmax --num_classes 12 --train_dir work/v1/TCResNet8Model-1.5/mfcc_40_3010_0.001_mom_l1 --num_silent 1854 --augmentation_method anchored_slice_or_pad_with_shift --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --batch_size 100 --boundaries 10000 20000 --max_step_from_restore 30000 --lr_list 0.1 0.01 0.001 --absolute_schedule --no-boundaries_epoch --max_to_keep 20 --step_save_checkpoint 500 --step_evaluation 500 --optimizer mom --momentum 0.9 TCResNet8Model --weight_decay 0.001 --width_multiplier 1.5 & 4 | sleep 5 5 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name valid --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet8Model-1.5/mfcc_40_3010_0.001_mom_l1 --num_silent 258 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 3 --no-shuffle --valid_type loop TCResNet8Model --weight_decay 0.001 --width_multiplier 1.5 & 6 | wait 7 | python evaluate_audio.py --dataset_path google_speech_commands/splitted_data --dataset_split_name test --output_name output/softmax --num_classes 12 --checkpoint_path work/v1/TCResNet8Model-1.5/mfcc_40_3010_0.001_mom_l1/valid/accuracy/valid --num_silent 257 --augmentation_method anchored_slice_or_pad --preprocess_method mfcc --num_mfccs 40 --clip_duration_ms 1000 --window_size_ms 30 --window_stride_ms 10 --background_frequency 0.0 --background_max_volume 0.0 --max_step_from_restore 30000 --batch_size 39 --no-shuffle --valid_type once TCResNet8Model --weight_decay 0.001 --width_multiplier 1.5 8 | -------------------------------------------------------------------------------- /scripts/google_speech_commmands_dataset_to_our_format.py: -------------------------------------------------------------------------------- 1 | import random 2 | import argparse 3 | 4 | from pathlib import Path 5 | from collections import defaultdict 6 | 7 | 8 | UNKNOWN_WORD_LABEL = 'unknown' 9 | BACKGROUND_NOISE_DIR_NAME = '_background_noise_' 10 | RANDOM_SEED = 59185 11 | random.seed(RANDOM_SEED) 12 | 13 | 14 | def check_path_existence(path, name): 15 | assert path.exists(), (f"{name} ({path}) does not exist!") 16 | 17 | 18 | def parse_arguments(): 19 | parser = argparse.ArgumentParser(description=__doc__) 20 | 21 | parser.add_argument("--input_dir", 
type=lambda p: Path(p), required=True, 22 | help="Directory as the result of `tar -zxvf `.") 23 | parser.add_argument("--background_noise_dir", type=lambda p: Path(p), required=True, 24 | help="Directory containing noise wav files") 25 | parser.add_argument("--test_list_fullpath", type=lambda p: Path(p), required=True, 26 | help="Text file which contains the names of the test wave files.") 27 | parser.add_argument("--valid_list_fullpath", type=lambda p: Path(p), required=True, 28 | help="Text file which contains the names of the validation wave files.") 29 | parser.add_argument("--output_dir", type=lambda p: Path(p), required=True, 30 | help="Directory which will contain the result dataset.") 31 | parser.add_argument("--wanted_words", type=str, default="", 32 | help="Comma-separated words to be categorized as foreground. Default '' means take all.") 33 | 34 | args = parser.parse_args() 35 | 36 | # validation check 37 | check_path_existence(args.input_dir, "Input directory") 38 | check_path_existence(args.background_noise_dir, "Background noise directory") 39 | check_path_existence(args.test_list_fullpath, "`test_list_fullpath`") 40 | check_path_existence(args.valid_list_fullpath, "`valid_list_fullpath`") 41 | assert not args.output_dir.exists() or len([p for p in args.output_dir.iterdir()]) == 0, ( 42 | f"Output directory ({args.output_dir}) should be empty!") 43 | 44 | return args 45 | 46 | 47 | def get_label_and_filename(p): 48 | parts = p.parts 49 | label, filename = parts[-2], parts[-1] 50 | return label, filename 51 | 52 | 53 | def is_valid_label(label, valid_labels): 54 | if valid_labels: 55 | is_valid = label in valid_labels 56 | else: 57 | is_valid = True 58 | 59 | return is_valid 60 | 61 | 62 | def is_noise_label(label): 63 | return label == BACKGROUND_NOISE_DIR_NAME 64 | 65 | 66 | def split_files(input_dir, valid_list_fullpath, test_list_fullpath, wanted_words): 67 | # load split list 68 | with test_list_fullpath.open("r") as fr: 69 | test_names = {row.strip(): True for row in fr.readlines()} 70 | 71 | with valid_list_fullpath.open("r") as fr: 72 | valid_names = {row.strip(): True for row in fr.readlines()} 73 | 74 | # set labels 75 | if len(wanted_words) > 0: 76 | valid_labels = set(wanted_words.split(",")) 77 | else: 78 | valid_labels = None 79 | 80 | labels = list() 81 | for p in input_dir.iterdir(): 82 | if p.is_dir() and is_valid_label(p.name, valid_labels) and not is_noise_label(p.name): 83 | labels.append(p.name) 84 | assert len(set(labels)) == len(labels), f"{len(set(labels))} == {len(labels)}" 85 | 86 | # update valid_labels: every wanted word must exist as a directory 87 | if len(wanted_words) > 0: 88 | assert len(valid_labels) == len(labels), f"{valid_labels} != {set(labels)}" 89 | valid_labels = set(labels) 90 | 91 | # iter input directory to get all wav files 92 | samples = { 93 | 'train': defaultdict(list), 94 | 'valid': defaultdict(list), 95 | 'test': defaultdict(list), 96 | } 97 | 98 | for p in input_dir.rglob("*.wav"): 99 | label, filename = get_label_and_filename(p) 100 | if not is_noise_label(label): 101 | name = f"{label}/{filename}" 102 | 103 | if not is_valid_label(label, valid_labels): 104 | label = UNKNOWN_WORD_LABEL 105 | 106 | if test_names.get(name, False): 107 | samples["test"][label].append(p) 108 | elif valid_names.get(name, False): 109 | samples["valid"][label].append(p) 110 | else: 111 | samples["train"][label].append(p) 112 | 113 | has_unknown = all([UNKNOWN_WORD_LABEL in label_samples for split, label_samples in samples.items()]) 114 | for split, label_samples in samples.items(): 115 | if has_unknown: 116 | assert len(label_samples) ==
len(valid_labels) + 1, f"{set(label_samples)} == {valid_labels}" 117 | else: 118 | assert len(label_samples) == len(valid_labels), f"{set(label_samples)} == {valid_labels}" 119 | 120 | # number of samples 121 | num_train = sum(map(lambda kv: len(kv[1]), samples["train"].items())) 122 | num_valid = sum(map(lambda kv: len(kv[1]), samples["valid"].items())) 123 | num_test = sum(map(lambda kv: len(kv[1]), samples["test"].items())) 124 | 125 | num_samples = num_train + num_valid + num_test 126 | 127 | print(f"Num samples with train / valid / test split: {num_train} / {num_valid} / {num_test}") 128 | print(f"Total {num_samples} samples, {len(valid_labels)} labels") 129 | assert num_train > num_test 130 | assert num_train > num_valid 131 | 132 | # subsample unknown samples: 133 | # cap the number of unknown samples at the mean count of the other labels 134 | if has_unknown: 135 | mean_num_samples_per_label = dict() 136 | for split, label_samples in samples.items(): 137 | s = 0 138 | c = 0 139 | for label, sample_list in label_samples.items(): 140 | if label != UNKNOWN_WORD_LABEL: 141 | s += len(sample_list) 142 | c += 1 143 | 144 | m = int(s / c) 145 | 146 | unknown_samples = label_samples[UNKNOWN_WORD_LABEL] 147 | if len(unknown_samples) > m: 148 | random.shuffle(unknown_samples) 149 | label_samples[UNKNOWN_WORD_LABEL] = unknown_samples[:m] 150 | 151 | # number of samples 152 | print("After filtering:") 153 | num_train = sum(map(lambda kv: len(kv[1]), samples["train"].items())) 154 | num_valid = sum(map(lambda kv: len(kv[1]), samples["valid"].items())) 155 | num_test = sum(map(lambda kv: len(kv[1]), samples["test"].items())) 156 | 157 | num_samples = num_train + num_valid + num_test 158 | 159 | print(f"Num samples with train / valid / test split: {num_train} / {num_valid} / {num_test}") 160 | print(f"Total {num_samples} samples, {len(valid_labels)} labels") 161 | 162 | return samples 163 | 164 | 165 | def generate_dataset(split_samples, background_noise_dir, output_dir): 166 | # make output_dir 167 | if not output_dir.exists(): 168 | output_dir.mkdir(parents=True) 169 | 170 | # link splitted samples 171 | for split, label_samples in split_samples.items(): 172 | base_dir = output_dir / split 173 | base_dir.mkdir() 174 | 175 | # make labels 176 | for label in label_samples: 177 | label_dir = base_dir / label 178 | label_dir.mkdir() 179 | 180 | # link samples 181 | for label, samples in label_samples.items(): 182 | for sample_path in samples: 183 | # do not use the label from get_label_and_filename; 
184 | # we already replaced some of them with UNKNOWN_WORD_LABEL 185 | old_label, filename = get_label_and_filename(sample_path) 186 | 187 | if label == UNKNOWN_WORD_LABEL: 188 | filename = f"{old_label}_{filename}" 189 | 190 | target_path = base_dir / label / filename 191 | target_path.symlink_to(sample_path) 192 | 193 | # link background_noise 194 | if split == "train": 195 | noise_dir = base_dir / BACKGROUND_NOISE_DIR_NAME 196 | noise_dir.symlink_to(background_noise_dir, target_is_directory=True) 197 | 198 | print(f"Make {base_dir} done.") 199 | 200 | 201 | if __name__ == "__main__": 202 | args = parse_arguments() 203 | 204 | samples = split_files(args.input_dir, args.valid_list_fullpath, args.test_list_fullpath, args.wanted_words) 205 | generate_dataset(samples, args.background_noise_dir, args.output_dir) 206 | -------------------------------------------------------------------------------- /speech_commands_dataset/README.md: -------------------------------------------------------------------------------- 1 | # Speech Commands Data Set 2 | 3 | ## About Data Set 4 | [Speech Commands Data Set](https://ai.googleblog.com/2017/08/launching-speech-commands-dataset.html) is a set of one-second .wav audio files, each containing a single spoken English word. 5 | These words are from a small set of commands, and are spoken by a variety of different speakers. 6 | The audio files are organized into folders based on the word they contain, and this data set is designed to help train simple machine learning models. 7 | 8 | ## Data Preparation 9 | 10 | In order to train and evaluate models on the Speech Commands Data Set using our code base, we have to download Speech Commands Data Set v0.01 and organize the data into three exclusive splits: 11 | 12 | * `train` 13 | * `valid` 14 | * `test` 15 | 16 | We offer an automated script (`download_and_split.sh`) which downloads the data set and processes it into the predefined format. 17 | ```bash 18 | bash download_and_split.sh [/path/for/dataset] 19 | ``` 20 | 21 | ## Advanced Usage 22 | Below is a detailed description of how `download_and_split.sh` works. 23 | 24 | ### 1. Download Speech Commands Data Set 25 | 26 | ```bash 27 | work_dir=${1:-$(pwd)/google_speech_commands} 28 | mkdir -p ${work_dir} 29 | pushd ${work_dir} 30 | wget http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz 31 | tar xzf speech_commands_v0.01.tar.gz 32 | popd 33 | ``` 34 | 35 | ### 2. Modify Data Set File Structure 36 | 37 | `speech_commands_dataset` contains three files, `train.txt`, `valid.txt` and `test.txt`, which list the utterance file paths associated with each data set split. 38 | 39 | Our codebase requires the following file structure for the training, validation, and test data: 40 | 41 | ``` 42 | ─ v0.01_split 43 | ├── train 44 | │ ├── _background_noise_ 45 | │ ├── down 46 | │ ├── go 47 | │ ├── left 48 | │ ├── no 49 | │ ├── off 50 | │ ├── on 51 | │ ├── right 52 | │ ├── stop 53 | │ ├── unknown 54 | │ ├── up 55 | │ └── yes 56 | ├── valid 57 | │ ├── _background_noise_ 58 | │ ├── down 59 | │ ├── go 60 | │ ├── left 61 | │ ├── no 62 | │ ├── off 63 | │ ├── on 64 | │ ├── right 65 | │ ├── stop 66 | │ ├── unknown 67 | │ ├── up 68 | │ └── yes 69 | └── test 70 | ├── _background_noise_ 71 | ├── down 72 | ├── go 73 | ├── left 74 | ├── no 75 | ├── off 76 | ├── on 77 | ├── right 78 | ├── stop 79 | ├── unknown 80 | ├── up 81 | └── yes 82 | ``` 83 | 84 | To convert the Speech Commands Data Set to the structure shown above, run the command below. 
85 | This script does not copy any utterance files; it only creates symlinks to the original dataset. 86 | 87 | ```bash 88 | output_dir=${work_dir}/splitted_data 89 | python google_speech_commmands_dataset_to_our_format_with_split.py \ 90 | --input_dir `realpath ${work_dir}` \ 91 | --train_list_fullpath train.txt \ 92 | --valid_list_fullpath valid.txt \ 93 | --test_list_fullpath test.txt \ 94 | --wanted_words yes,no,up,down,left,right,on,off,stop,go \ 95 | --output_dir `realpath ${output_dir}` 96 | ``` 97 | 98 | `output_dir` will be used for subsequent training and evaluation. 99 | 100 | ## Why don't we use the script provided by TensorFlow for the data split? 101 | 102 | TL;DR: The script does not create a deterministic split, even with a fixed random seed. 103 | 104 | TensorFlow provides [preprocessing scripts](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/speech_commands) for the Speech Commands Data Set; however, while [creating the `unknown_index` dictionary](https://github.com/tensorflow/tensorflow/blob/d75adcad0d74e15893a93eb02bed447ec7971664/tensorflow/examples/speech_commands/input_data.py#L280-L294), the nondeterministic ordering of utterance files in `unknown_index` comes into play. 105 | [The `prepare_data_index` method uses only a limited number of words](https://github.com/tensorflow/tensorflow/blob/d75adcad0d74e15893a93eb02bed447ec7971664/tensorflow/examples/speech_commands/input_data.py#L315-L316) from `unknown_index` and assumes that the order of directories and files found with `gfile.Glob(search_path)` is deterministic on every platform. 106 | Unfortunately, the ordering of directories and files returned by `gfile.Glob(search_path)` follows the filesystem's `directory order` (which can be inspected with the [`ls -U`](https://unix.stackexchange.com/questions/13451/what-is-the-directory-order-of-files-in-a-directory-used-by-ls-u) command), so it is not guaranteed to be the same for everybody. 107 | 108 | Because we want to ensure that anybody can reproduce our results, we provide the train, valid, and test splits in `train.txt`, `valid.txt`, and `test.txt`, which were generated by TensorFlow's preprocessing scripts in our environment. 
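For reference, this kind of nondeterminism is straightforward to avoid in user code. Below is a minimal sketch (not part of this codebase; the `speech_commands` path is a placeholder) showing that sorting a glob result is enough to make any seeded split derived from it reproducible across machines:

```python
import glob

# glob returns entries in filesystem "directory order",
# which can differ between machines and filesystems.
wav_paths = glob.glob("speech_commands/*/*.wav")

# Sorting gives a platform-independent ordering, so a random
# split computed from it (with a fixed seed) becomes deterministic.
wav_paths = sorted(wav_paths)
```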
109 | -------------------------------------------------------------------------------- /speech_commands_dataset/download_and_split.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eux 3 | 4 | work_dir=${1:-$(pwd)/google_speech_commands} 5 | 6 | # download file 7 | mkdir -p ${work_dir} 8 | pushd ${work_dir} 9 | wget http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz 10 | tar xzf speech_commands_v0.01.tar.gz 11 | popd 12 | 13 | # split data 14 | output_dir=${work_dir}/splitted_data 15 | python google_speech_commmands_dataset_to_our_format_with_split.py \ 16 | --input_dir `realpath ${work_dir}` \ 17 | --train_list_fullpath train.txt \ 18 | --valid_list_fullpath valid.txt \ 19 | --test_list_fullpath test.txt \ 20 | --wanted_words yes,no,up,down,left,right,on,off,stop,go \ 21 | --output_dir `realpath ${output_dir}` 22 | echo "Dataset is prepared at ${output_dir}" 23 | -------------------------------------------------------------------------------- /speech_commands_dataset/google_speech_commmands_dataset_to_our_format_with_split.py: -------------------------------------------------------------------------------- 1 | """ 2 | python scripts/google_speech_commmands_dataset_to_our_format_with_split.py \ 3 | --input_dir "/data/data/data_speech_commands_v0.01" \ 4 | --train_list_fullpath "/data/data/data_speech_commands_v0.01/newsplit/training_list.txt" \ 5 | --valid_list_fullpath "/data/data/data_speech_commands_v0.01/newsplit/validation_list.txt" \ 6 | --test_list_fullpath "/data/data/data_speech_commands_v0.01/newsplit/testing_list.txt" \ 7 | --wanted_words "yes,no,up,down,left,right,on,off,stop,go" \ 8 | --output_dir "/data/data/google_audio/data_speech_commands_v0.01" 9 | """ 10 | import random 11 | import argparse 12 | 13 | from pathlib import Path 14 | from collections import defaultdict 15 | 16 | 17 | UNKNOWN_WORD_LABEL = 'unknown' 18 | BACKGROUND_NOISE_LABEL = '_silence_' 19 | BACKGROUND_NOISE_DIR_NAME = '_background_noise_' 20 | 21 | 22 | def check_path_existence(path, name): 23 | assert path.exists(), (f"{name} ({path}) does not exist!") 24 | 25 | 26 | def parse_arguments(): 27 | parser = argparse.ArgumentParser(description=__doc__) 28 | 29 | parser.add_argument("--input_dir", type=lambda p: Path(p), required=True, 30 | help="Directory as the result of `tar -zxvf `.") 31 | parser.add_argument("--test_list_fullpath", type=lambda p: Path(p), required=True, 32 | help="Text file which contains the names of the test wave files.") 33 | parser.add_argument("--valid_list_fullpath", type=lambda p: Path(p), required=True, 34 | help="Text file which contains the names of the validation wave files.") 35 | parser.add_argument("--train_list_fullpath", type=lambda p: Path(p), required=True, 36 | help="Text file which contains the names of the train wave files.") 37 | parser.add_argument("--output_dir", type=lambda p: Path(p), required=True, 38 | help="Directory which will contain the result dataset.") 39 | parser.add_argument("--wanted_words", type=str, default="", 40 | help="Comma-separated words to be categorized as foreground. 
Default '' means take all.") 41 | 42 | args = parser.parse_args() 43 | 44 | # validation check 45 | check_path_existence(args.input_dir, "Input directory") 46 | check_path_existence(args.test_list_fullpath, "`test_list_fullpath`") 47 | check_path_existence(args.valid_list_fullpath, "`valid_list_fullpath`") 48 | check_path_existence(args.train_list_fullpath, "`train_list_fullpath`") 49 | assert not args.output_dir.exists() or len([p for p in args.output_dir.iterdir()]) == 0, ( 50 | f"Output directory ({args.output_dir}) should be empty!") 51 | 52 | return args 53 | 54 | 55 | def get_label_and_filename(p): 56 | parts = p.parts 57 | label, filename = parts[-2], parts[-1] 58 | return label, filename 59 | 60 | 61 | def is_valid_label(label, valid_labels): 62 | if valid_labels: 63 | is_valid = label in valid_labels 64 | else: 65 | is_valid = True 66 | 67 | return is_valid 68 | 69 | 70 | def is_noise_label(label): 71 | return label == BACKGROUND_NOISE_DIR_NAME or label == BACKGROUND_NOISE_LABEL 72 | 73 | 74 | def process_files(input_dir, train_list_fullpath, valid_list_fullpath, test_list_fullpath, wanted_words, 75 | output_dir): 76 | # load split list 77 | data = { 78 | "train": [], 79 | "valid": [], 80 | "test": [], 81 | } 82 | 83 | list_fullpath = { 84 | "train": train_list_fullpath, 85 | "valid": valid_list_fullpath, 86 | "test": test_list_fullpath, 87 | } 88 | 89 | for split in data: 90 | with list_fullpath[split].open("r") as fr: 91 | for row in fr.readlines(): 92 | label, filename = get_label_and_filename(Path(row.strip())) 93 | data[split].append((label, filename)) 94 | 95 | # set labels 96 | if len(wanted_words) > 0: 97 | valid_labels = set(wanted_words.split(",")) 98 | else: 99 | valid_labels = None 100 | 101 | labels = list() 102 | for p in input_dir.iterdir(): 103 | if p.is_dir() and is_valid_label(p.name, valid_labels) and not is_noise_label(p.name): 104 | labels.append(p.name) 105 | assert len(set(labels)) == len(labels), f"{len(set(labels))} == {len(labels)}" 106 | 107 | # update valid_labels: every wanted word must exist as a directory 108 | if len(wanted_words) > 0: 109 | assert len(valid_labels) == len(labels), f"{valid_labels} != {set(labels)}" 110 | valid_labels = set(labels) 111 | print(f"Valid Labels: {valid_labels}") 112 | 113 | # make dataset! 
114 | # make output_dir 115 | if not output_dir.exists(): 116 | output_dir.mkdir(parents=True) 117 | 118 | for split, lst in data.items(): 119 | # for each split 120 | base_dir = output_dir / split 121 | base_dir.mkdir() 122 | 123 | # make labels for valid + unknown 124 | for label in list(valid_labels) + [UNKNOWN_WORD_LABEL]: 125 | label_dir = base_dir / label 126 | label_dir.mkdir() 127 | 128 | # link files 129 | noise_count = 0 130 | for label, filename in lst: 131 | source_path = input_dir / label / filename 132 | 133 | if is_noise_label(label): 134 | noise_count += 1 135 | else: 136 | if not is_valid_label(label, valid_labels): 137 | filename = f"{label}_{filename}" 138 | label = UNKNOWN_WORD_LABEL 139 | 140 | target_path = base_dir / label / filename 141 | target_path.symlink_to(source_path) 142 | 143 | # report the number of noise (silence) files 144 | print(f"[{split}] Num of silences: {noise_count}") 145 | 146 | # link noise 147 | source_noise_dir = input_dir / BACKGROUND_NOISE_DIR_NAME 148 | target_noise_dir = base_dir / BACKGROUND_NOISE_DIR_NAME 149 | target_noise_dir.symlink_to(source_noise_dir, target_is_directory=True) 150 | 151 | 152 | if __name__ == "__main__": 153 | args = parse_arguments() 154 | 155 | process_files(args.input_dir, 156 | args.train_list_fullpath, 157 | args.valid_list_fullpath, 158 | args.test_list_fullpath, 159 | args.wanted_words, 160 | args.output_dir) 161 | -------------------------------------------------------------------------------- /tflite_tools/benchmark_model_r1.13_official: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hyperconnect/TC-ResNet/8ccbff3a45590247d8c54cc82129acb90eecf5c8/tflite_tools/benchmark_model_r1.13_official -------------------------------------------------------------------------------- /tflite_tools/run_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Usage: ./run_benchmark.sh /path/to/tflite/model.tflite [cpu_mask] 3 | function run_benchmark() { 4 | model_name=$1 5 | cpu_mask=$2 6 | echo ">>> run_benchmark $1 $2" 7 | if [ $# -eq 1 ] 8 | then 9 | res=$(adb shell /data/local/tmp/benchmark_model_r1.13_official \ 10 | --graph=/data/local/tmp/${model_name} \ 11 | --num_threads=1 \ 12 | --warmup_runs=10 \ 13 | --min_secs=0 \ 14 | --num_runs=50 2>&1 >/dev/null) 15 | elif [ $# -eq 2 ] 16 | then 17 | res=$(adb shell taskset ${cpu_mask} /data/local/tmp/benchmark_model_r1.13_official \ 18 | --graph=/data/local/tmp/${model_name} \ 19 | --num_threads=1 \ 20 | --warmup_runs=10 \ 21 | --min_secs=0 \ 22 | --num_runs=50 2>&1 >/dev/null) 23 | fi 24 | echo "${res}" 25 | } 26 | 27 | function run_benchmark_summary() { 28 | model_name=$1 29 | cpu_mask=$2 30 | echo ">>> run_benchmark_summary $1 $2" 31 | res=$(run_benchmark $model_name $cpu_mask | tail -n 3 | head -n 1) 32 | print_highlighted "${model_name} > ${res}" 33 | } 34 | 35 | function print_highlighted() { 36 | message=$1 37 | light_green="\033[92m" 38 | default="\033[0m" 39 | printf "${light_green}${message}${default}\n" 40 | } 41 | 42 | model_path=$1 43 | cpu_mask=$2 44 | adb push benchmark_model_r1.13_official /data/local/tmp/ 45 | adb shell 'ls /data/local/tmp/benchmark_model_r1.13_official' | tr -d '\r' | xargs -n1 adb shell chmod +x 46 | adb push ${model_path} /data/local/tmp/ 47 | 48 | model_name=`basename $model_path` 49 | run_benchmark_summary $model_name $cpu_mask 50 | 51 | -------------------------------------------------------------------------------- /train_audio.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import List, Optional 3 | 4 | import tensorflow as tf 5 | 6 | import const 7 | import factory.audio_nets as audio_nets 8 | import common.utils as utils 9 | from datasets.data_wrapper_base import DataWrapperBase 10 | from datasets.audio_data_wrapper import AudioDataWrapper 11 | from datasets.audio_data_wrapper import SingleLabelAudioDataWrapper 12 | from helper.base import Base 13 | from helper.trainer import TrainerBase 14 | from helper.trainer import SingleLabelAudioTrainer 15 | from factory.base import TFModel 16 | from metrics.base import MetricManagerBase 17 | 18 | 19 | def train(args): 20 | is_training = True 21 | dataset_name = args.dataset_split_name[0] 22 | session = tf.Session(config=const.TF_SESSION_CONFIG) 23 | 24 | dataset = SingleLabelAudioDataWrapper( 25 | args, 26 | session, 27 | dataset_name, 28 | is_training, 29 | ) 30 | wavs, labels = dataset.get_input_and_output_op() 31 | 32 | model = eval(f"audio_nets.{args.model}")(args, dataset) 33 | model.build(wavs=wavs, labels=labels, is_training=is_training) 34 | 35 | trainer = SingleLabelAudioTrainer( 36 | model, 37 | session, 38 | args, 39 | dataset, 40 | dataset_name, 41 | ) 42 | 43 | trainer.train() 44 | 45 | 46 | def parse_arguments(arguments: Optional[List[str]] = None): 47 | parser = argparse.ArgumentParser(description=__doc__) 48 | subparsers = parser.add_subparsers(title="Model", description="") 49 | 50 | TFModel.add_arguments(parser) 51 | audio_nets.AudioNetModel.add_arguments(parser) 52 | 53 | for class_name in audio_nets._available_nets: 54 | subparser = subparsers.add_parser(class_name) 55 | subparser.add_argument("--model", default=class_name, type=str, help="DO NOT FIX ME") 56 | add_audio_net_arguments = eval(f"audio_nets.{class_name}.add_arguments") 57 | add_audio_net_arguments(subparser) 58 | 59 | DataWrapperBase.add_arguments(parser) 60 | AudioDataWrapper.add_arguments(parser) 61 | Base.add_arguments(parser) 62 | TrainerBase.add_arguments(parser) 63 | SingleLabelAudioTrainer.add_arguments(parser) 64 | MetricManagerBase.add_arguments(parser) 65 | 66 | args = parser.parse_args(arguments) 67 | return args 68 | 69 | 70 | if __name__ == "__main__": 71 | args = parse_arguments() 72 | log = utils.get_logger("Trainer") 73 | 74 | utils.update_train_dir(args) 75 | 76 | log.info(args) 77 | train(args) 78 | --------------------------------------------------------------------------------