├── .flake8 ├── .gitignore ├── LICENSE ├── README.md ├── benchmarks ├── alexa_roc_curve.png ├── hey_jarvis_roc_curve.png └── okay_nabu_roc_curve.png ├── documentation └── data_sources.md ├── etc ├── logo.png └── logo.svg ├── microwakeword ├── __init__.py ├── audio │ ├── audio_utils.py │ ├── augmentation.py │ ├── clips.py │ └── spectrograms.py ├── data.py ├── inception.py ├── inference.py ├── layers │ ├── average_pooling2d.py │ ├── delay.py │ ├── modes.py │ ├── stream.py │ ├── strided_drop.py │ └── sub_spectral_normalization.py ├── mixednet.py ├── model_train_eval.py ├── test.py ├── train.py └── utils.py ├── notebooks └── basic_training_notebook.ipynb ├── pyproject.toml └── setup.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | # Following 4 for black compatibility 4 | # E501: line too long 5 | # W503: Line break occurred before a binary operator 6 | # E203: Whitespace before ':' 7 | # D202 No blank lines allowed after function docstring 8 | 9 | # TODO fix flake8 10 | # D100 Missing docstring in public module 11 | # D101 Missing docstring in public class 12 | # D102 Missing docstring in public method 13 | # D103 Missing docstring in public function 14 | # D104 Missing docstring in public package 15 | # D105 Missing docstring in magic method 16 | # D107 Missing docstring in __init__ 17 | # D200 One-line docstring should fit on one line with quotes 18 | # D205 1 blank line required between summary line and description 19 | # D209 Multi-line docstring closing quotes should be on a separate line 20 | # D400 First line should end with a period 21 | # D401 First line should be in imperative mood 22 | 23 | ignore = 24 | E501, 25 | W503, 26 | E203, 27 | D202, 28 | 29 | D100, 30 | D101, 31 | D102, 32 | D103, 33 | D104, 34 | D105, 35 | D107, 36 | D200, 37 | D205, 38 | D209, 39 | D400, 40 | D401, 41 | 42 | exclude = api_pb2.py 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Hide sublime text stuff 10 | *.sublime-project 11 | *.sublime-workspace 12 | 13 | # Intellij Idea 14 | .idea 15 | 16 | # Eclipse 17 | .project 18 | .cproject 19 | .pydevproject 20 | .settings/ 21 | 22 | # Vim 23 | *.swp 24 | 25 | # Hide some OS X stuff 26 | .DS_Store 27 | .AppleDouble 28 | .LSOverride 29 | Icon 30 | 31 | # Thumbnails 32 | ._* 33 | 34 | # Distribution / packaging 35 | .Python 36 | build/ 37 | develop-eggs/ 38 | dist/ 39 | downloads/ 40 | eggs/ 41 | .eggs/ 42 | lib/ 43 | lib64/ 44 | parts/ 45 | sdist/ 46 | var/ 47 | wheels/ 48 | *.egg-info/ 49 | .installed.cfg 50 | *.egg 51 | MANIFEST 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | .esphome 64 | nosetests.xml 65 | coverage.xml 66 | cov.xml 67 | *.cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # Environments 79 | .env 80 | .venv 81 | env/ 82 | venv/ 83 | ENV/ 84 | env.bak/ 85 | venv.bak/ 86 | venv-*/ 87 | 88 | # mypy 89 | .mypy_cache/ 90 | 91 | .pioenvs 92 | .piolibdeps 93 | .pio 94 | .vscode/ 95 | !.vscode/tasks.json 96 | CMakeListsPrivate.txt 97 | CMakeLists.txt 98 | 99 | # User-specific 
stuff: 100 | .idea/**/workspace.xml 101 | .idea/**/tasks.xml 102 | .idea/dictionaries 103 | 104 | # Sensitive or high-churn files: 105 | .idea/**/dataSources/ 106 | .idea/**/dataSources.ids 107 | .idea/**/dataSources.xml 108 | .idea/**/dataSources.local.xml 109 | .idea/**/dynamic.xml 110 | 111 | # CMake 112 | cmake-build-*/ 113 | 114 | CMakeCache.txt 115 | CMakeFiles 116 | CMakeScripts 117 | Testing 118 | Makefile 119 | cmake_install.cmake 120 | install_manifest.txt 121 | compile_commands.json 122 | CTestTestfile.cmake 123 | /*.cbp 124 | 125 | .clang_complete 126 | .gcc-flags.json 127 | 128 | config/ 129 | tests/build/ 130 | tests/.esphome/ 131 | /.temp-clang-tidy.cpp 132 | /.temp/ 133 | .pio/ 134 | 135 | sdkconfig.* 136 | !sdkconfig.defaults 137 | 138 | .tests/ 139 | 140 | /components 141 | 142 | data_training/ 143 | 144 | trained_models/ 145 | notebooks/*/ 146 | 147 | training_parameters.yaml 148 | 149 | *.npy 150 | *.pb 151 | *.ninja 152 | *.tar 153 | *.zip 154 | 155 | *.flac 156 | *.wav 157 | *.mp3 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![microWakeWord logo](etc/logo.png) 2 | 3 | microWakeWord is an open-source wakeword library for detecting custom wake words on low power devices. It produces models that are suitable for using [TensorFlow Lite for Microcontrollers](https://www.tensorflow.org/lite/microcontrollers). The models are suitable for real-world usage with low false accept and false reject rates. 
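As a rough illustration of how these models can be exercised outside of TensorFlow Lite for Microcontrollers, the sketch below computes spectrogram features with this repository's audio utilities and feeds them to an exported streaming TFLite model from Python. The model filename, the input stride, and the quantization handling are assumptions; inspect your own model's input and output details (the Detection Process section below describes the feature format in more detail).

```python
import tensorflow as tf
from scipy.io import wavfile

from microwakeword.audio.audio_utils import generate_features_for_clip

# Load a 16 kHz, 16-bit mono clip (placeholder path)
_, samples = wavfile.read("test_clip.wav")

# 10 ms step to match the on-device preprocessor (this helper defaults to 20 ms)
spectrogram = generate_features_for_clip(samples, step_ms=10)

# "stream_model.tflite" is a placeholder for an exported streaming model
interpreter = tf.lite.Interpreter(model_path="stream_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

stride = input_details["shape"][1]  # new feature slices consumed per inference
for start in range(0, spectrogram.shape[0] - stride + 1, stride):
    window = spectrogram[start : start + stride].reshape(input_details["shape"])
    scale, zero_point = input_details["quantization"]
    if scale:  # quantize the input if the model expects int8/uint8 tensors
        window = window / scale + zero_point
    interpreter.set_tensor(input_details["index"], window.astype(input_details["dtype"]))
    interpreter.invoke()
    output = interpreter.get_tensor(output_details["index"]).flatten()[0]
    out_scale, out_zero_point = output_details["quantization"]
    probability = (output - out_zero_point) * out_scale if out_scale else output
    print(f"{start * 10} ms: wake word probability {probability:.3f}")
```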
4 | 
5 | **microWakeWord is currently available as an early release. Training new models is intended for advanced users. Training a model that works well is still very difficult, as it typically requires experimentation with hyperparameters and sample generation settings. Please share any insights you find for training a good model!**
6 | 
7 | ## Detection Process
8 | 
9 | We detect the wake word in two stages. Raw audio data is processed into 40 spectrogram features every 10 ms. The streaming inference model uses the newest slice of feature data as input and returns the probability that the wake word was said. If the model consistently predicts the wake word over multiple windows, then we predict that the wake word has been said.
10 | 
11 | The first stage processes the raw single-channel audio data at a sample rate of 16 kHz via the [micro_speech preprocessor](https://github.com/tensorflow/tflite-micro/tree/main/tensorflow/lite/micro/examples/micro_speech). The preprocessor generates 40 features over 30 ms (the window duration) of audio data. The preprocessor generates these features every 10 ms (the stride duration), so the first 20 ms of audio data is shared with the previous window. This process is similar to computing a Mel spectrogram, but it also includes noise suppression and automatic gain control, which makes it suitable for devices with limited processing power. See the linked TFLite Micro example for full details on how the audio is processed.
12 | 
13 | The streaming model performs an inference every 30 ms, with the initial convolution layer striding over three 10 ms slices of audio features. The model is a neural network using [MixConv](https://arxiv.org/abs/1907.09595) mixed depthwise convolutions suitable for streaming. Streaming and training the model use heavily modified open-source code from [Google Research](https://github.com/google-research/google-research/tree/master/kws_streaming) described in the paper [Streaming Keyword Spotting on Mobile Devices](https://arxiv.org/pdf/2005.06720.pdf) by Rybakov, Kononenko, Subrahmanya, Visontai, and Laurenzo.
14 | 
15 | ### Training Process
16 | - We augment the spectrograms in several possible ways during training:
17 |     - [SpecAugment](https://arxiv.org/pdf/1904.08779.pdf) masks time and frequency features.
18 | - The best weights are chosen in a two-step process:
19 |     1. The top priority is minimizing a specific metric, such as the false accepts per hour on ambient background noise.
20 |     2. If the specified minimization target is met, we then maximize a different specified metric, such as accuracy.
21 | - Validation and test sets are split into two portions:
22 |     1. The ``validation`` and ``testing`` sets include the positive and negative generated samples.
23 |     2. The ``validation_ambient`` and ``testing_ambient`` sets are all negative samples representing real-world background sounds; e.g., music, random household noises, and general speech/conversations.
24 | - Generated spectrograms are stored as [Ragged Mmap](https://github.com/hristo-vrigazov/mmap.ninja/tree/master) folders for quick loading from disk during training.
25 | - Each feature set is configured with a ``sampling_weight`` and a ``penalty_weight``. The ``sampling_weight`` parameter controls oversampling, and ``penalty_weight`` controls the weight of incorrect predictions.
26 | - Class weights are also adjustable with the ``positive_class_weight`` and ``negative_class_weight`` parameters.
Increasing the ``negative_class_weight`` is useful for reducing the number of false accepts.
27 | - We train the model in a non-streaming mode; i.e., it trains on the entire spectrogram. When finished, this is converted to a streaming model that updates only on the newest spectrogram features.
28 |     - Not padding the convolutions ensures the non-streaming and streaming models have nearly identical prediction behaviors.
29 | - We estimate the false accepts per hour metric during training by splitting long-duration ambient clips into appropriately sized spectrograms with a 100 ms stride to simulate the streaming model. This is not a perfect estimate of the streaming model's real-world false accepts per hour, but it is sufficient for determining the best weights.
30 | - Spectrogram features should be generated over a longer time period than the model needs for training. The preprocessor applies PCAN and noise reduction, and generating features over a longer time period results in models that generalize better.
31 | - We quantize the streaming models to increase performance on low-power devices. Quantization can slightly degrade detection quality, and the impact varies from model to model, but there is typically no reduction in accuracy.
32 | 
33 | 
34 | ## Model Training Process
35 | 
36 | We generate samples using [Piper sample generator](https://github.com/rhasspy/piper-sample-generator).
37 | 
38 | The generated samples are augmented before or during training to increase variability. Pre-generated spectrogram features for various negative datasets are available on [Hugging Face](https://huggingface.co/datasets/kahrendt/microwakeword).
39 | 
40 | See the ``basic_training_notebook.ipynb`` notebook for an example of how a model is trained. This notebook will produce a model, but it will most likely not be usable! Training a usable model requires a lot of experimentation, and the notebook is meant to serve only as a starting point for advanced users.
41 | 
42 | ## Models
43 | 
44 | See https://github.com/esphome/micro-wake-word-models to download the currently available models.
45 | 
46 | ## Acknowledgements
47 | 
48 | I am very grateful for the support of many people in improving this project!
Thank you, in particular, to the following individuals and organizations for providing feedback, collaboration, and developmental support: 49 | 50 | - [balloob](https://github.com/balloob) 51 | - [dscripka](https://github.com/dscripka) 52 | - [jesserockz](https://github.com/jesserockz) 53 | - [kbx81](https://github.com/kbx81) 54 | - [synesthesiam](https://github.com/synesthesiam) 55 | - [ESPHome](https://github.com/esphome) 56 | - [Nabu Casa](https://github.com/NabuCasa) 57 | - [Open Home Foundation](https://www.openhomefoundation.org/) -------------------------------------------------------------------------------- /benchmarks/alexa_roc_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kahrendt/microWakeWord/ef8befc49fbe0aac2dd046a3b0b1264cfc73eff1/benchmarks/alexa_roc_curve.png -------------------------------------------------------------------------------- /benchmarks/hey_jarvis_roc_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kahrendt/microWakeWord/ef8befc49fbe0aac2dd046a3b0b1264cfc73eff1/benchmarks/hey_jarvis_roc_curve.png -------------------------------------------------------------------------------- /benchmarks/okay_nabu_roc_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kahrendt/microWakeWord/ef8befc49fbe0aac2dd046a3b0b1264cfc73eff1/benchmarks/okay_nabu_roc_curve.png -------------------------------------------------------------------------------- /documentation/data_sources.md: -------------------------------------------------------------------------------- 1 | # Data Sources for Training Wake Words 2 | 3 | ## Generated Samples 4 | 5 | [Piper sample generator](https://github.com/rhasspy/piper-sample-generator) uses text-to-speech to generate many wake word samples. We also generate adversarial phrase samples created using [openWakeWord](https://github.com/dscripka/openWakeWord). 6 | 7 | ## Augmentation Sources 8 | 9 | We apply several augments to the generated samples. We use the following sources for background audio samples: 10 | 11 | - [FSD50K: An Open Dataset of Human-Labeled Sound Events](https://arxiv.org/abs/2010.00475) - (Various Creative Commons Licenses.) 12 | - [FMA: A Dataset For Music Analysis](https://arxiv.org/abs/1612.01840) - (Creative Commons Attribution 4.0 International License.) 13 | - [WHAM!: Extending Speech Separation to Noisy Environments](https://arxiv.org/abs/1907.01160) - (Creative Commons Attribution-NonCommercial 4.0 International License.) 14 | 15 | We reverberate the samples with room impulse responses from [BIRD: Big Impulse Response Dataset](https://arxiv.org/abs/2010.09930). 16 | 17 | ## Ambient Noises for Negative Samples 18 | 19 | We use a variety of sources of ambient background noises as negative samples during training. 20 | 21 | ### Ambient Speech 22 | 23 | - [Voices Obscured in Complex Environmental Settings (VOICES) corpus](https://arxiv.org/abs/1804.05053) - (Creative Commons Attribution 4.0 License.) 24 | - [Common Voice: A Massively-Multilingual Speech Corpus](https://arxiv.org/abs/1912.06670) - (Creative Commons License.) 
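For reference, the sketch below shows one way ambient recordings like these could be converted into long-duration spectrogram features for the ``validation_ambient``/``testing_ambient``-style sets using this repository's ``Clips`` and ``SpectrogramGeneration`` helpers. The directory path, glob pattern, and durations are placeholders, not the settings used for the published models.

```python
from microwakeword.audio.clips import Clips
from microwakeword.audio.spectrograms import SpectrogramGeneration

# Placeholder directory containing long ambient recordings
ambient_clips = Clips(input_directory="ambient_audio", file_pattern="**/*.wav")

# Split each long clip into non-overlapping ~10 s spectrograms with a 10 ms feature step
ambient_spectrograms = SpectrogramGeneration(
    clips=ambient_clips,
    augmenter=None,  # optionally pass an Augmentation instance instead
    step_ms=10,
    split_spectrogram_duration_s=10.0,
)

for features in ambient_spectrograms.spectrogram_generator():
    ...  # e.g., accumulate the features into a Ragged Mmap for the *_ambient sets
```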
25 | 
26 | ### Ambient Background
27 | 
28 | - [FSD50K: An Open Dataset of Human-Labeled Sound Events](https://arxiv.org/abs/2010.00475)
29 | - [FMA: A Dataset For Music Analysis](https://arxiv.org/abs/1612.01840) - reverberated with room impulse responses
30 | - [WHAM!: Extending Speech Separation to Noisy Environments](https://arxiv.org/abs/1907.01160)
31 | 
32 | ## Validation and Test Sets
33 | 
34 | We generate positive and negative samples solely for validation and testing. We augment these samples in the same way as the training data. We split the FSD50K, FMA, and WHAM! datasets 90-10 into training and testing sets (they are not in the validation set). We estimate the false accepts per hour during training with the VOiCES validation set and the [DiPCo - Dinner Party Corpus](https://www.amazon.science/publications/dipco-dinner-party-corpus) (Community Data License Agreement – Permissive Version 1.0 License). We test the false accepts per hour in streaming mode after training with the DiPCo set.
35 | 
--------------------------------------------------------------------------------
/etc/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kahrendt/microWakeWord/ef8befc49fbe0aac2dd046a3b0b1264cfc73eff1/etc/logo.png
--------------------------------------------------------------------------------
/etc/logo.svg:
--------------------------------------------------------------------------------
[SVG markup not preserved in this dump - microWakeWord logo vector image]
--------------------------------------------------------------------------------
/microwakeword/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kahrendt/microWakeWord/ef8befc49fbe0aac2dd046a3b0b1264cfc73eff1/microwakeword/__init__.py
--------------------------------------------------------------------------------
/microwakeword/audio/audio_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 Kevin Ahrendt.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | import numpy as np
17 | import tensorflow as tf
18 | import webrtcvad
19 | 
20 | from tensorflow.lite.experimental.microfrontend.python.ops import (
21 |     audio_microfrontend_op as frontend_op,
22 | )
23 | from scipy.io import wavfile
24 | 
25 | from pymicro_features import MicroFrontend
26 | 
27 | 
28 | def generate_features_for_clip(
29 |     audio_samples: np.ndarray, step_ms: int = 20, use_c: bool = True
30 | ):
31 |     """Generates spectrogram features for the given audio data.
32 | 
33 |     Args:
34 |         audio_samples (numpy.ndarray): The clip's audio samples.
35 |         step_ms (int, optional): The window step size in ms. Defaults to 20.
36 |         use_c (bool, optional): Whether to use the C implementation of the microfrontend via pymicro-features. Defaults to True.
37 | 38 | Raises: 39 | ValueError: If the provided audio data is not a 16-bit integer array. 40 | 41 | 42 | Returns: 43 | numpy.ndarray: The spectrogram features for the provided audio clip. 44 | """ 45 | 46 | # Convert any float formatted audio data to an int16 array 47 | if audio_samples.dtype in (np.float32, np.float64): 48 | audio_samples = np.clip((audio_samples * 32768), -32768, 32767).astype(np.int16) 49 | 50 | if use_c: 51 | audio_samples = audio_samples.tobytes() 52 | micro_frontend = MicroFrontend() 53 | features = [] 54 | audio_idx = 0 55 | num_audio_bytes = len(audio_samples) 56 | while audio_idx + 160 * 2 < num_audio_bytes: 57 | frontend_result = micro_frontend.ProcessSamples( 58 | audio_samples[audio_idx : audio_idx + 160 * 2] 59 | ) 60 | audio_idx += frontend_result.samples_read * 2 61 | if frontend_result.features: 62 | features.append(frontend_result.features) 63 | 64 | return np.array(features).astype(np.float32) 65 | 66 | with tf.device("/cpu:0"): 67 | # The default settings match the TFLM preprocessor settings. 68 | # Preproccesor model is available from the tflite-micro repository, accessed December 2023. 69 | micro_frontend = frontend_op.audio_microfrontend( 70 | tf.convert_to_tensor(audio_samples), 71 | sample_rate=16000, 72 | window_size=30, 73 | window_step=step_ms, 74 | num_channels=40, 75 | upper_band_limit=7500, 76 | lower_band_limit=125, 77 | enable_pcan=True, 78 | min_signal_remaining=0.05, 79 | out_scale=1, 80 | out_type=tf.uint16, 81 | ) 82 | 83 | spectrogram = micro_frontend.numpy() 84 | return spectrogram 85 | 86 | 87 | def save_clip(audio_samples: np.ndarray, output_file: str) -> None: 88 | """Saves an audio clip's sample as a wave file. 89 | 90 | Args: 91 | audio_samples (numpy.ndarray): The clip's audio samples. 92 | output_file (str): Path to the desired output file. 93 | """ 94 | if audio_samples.dtype in (np.float32, np.float64): 95 | audio_samples = (audio_samples * 32767).astype(np.int16) 96 | wavfile.write(output_file, 16000, audio_samples) 97 | 98 | 99 | def remove_silence_webrtc( 100 | audio_data: np.ndarray, 101 | frame_duration: float = 0.030, 102 | sample_rate: int = 16000, 103 | min_start: int = 2000, 104 | ) -> np.ndarray: 105 | """Uses webrtc voice activity detection to remove silence from the clips 106 | 107 | Args: 108 | audio_data (numpy.ndarray): The input clip's audio samples. 109 | frame_duration (float): The frame_duration for webrtcvad. Defaults to 0.03. 110 | sample_rate (int): The audio's sample rate. Defaults to 16000. 111 | min_start: (int): The number of audio samples from the start of the clip to always include. Defaults to 2000. 112 | 113 | Returns: 114 | numpy.ndarray: Array with the trimmed audio clip's samples. 
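
    Example:
        A minimal usage sketch (the input path is a placeholder; any 16 kHz
        mono clip works):

            from scipy.io import wavfile
            from microwakeword.audio.audio_utils import remove_silence_webrtc

            _, samples = wavfile.read("clip.wav")
            trimmed = remove_silence_webrtc(samples)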
115 | """ 116 | vad = webrtcvad.Vad(0) 117 | 118 | # webrtcvad expects int16 arrays as input, so convert if audio_data is a float 119 | float_type = audio_data.dtype in (np.float32, np.float64) 120 | if float_type: 121 | audio_data = (audio_data * 32767).astype(np.int16) 122 | 123 | filtered_audio = audio_data[0:min_start].tolist() 124 | 125 | step_size = int(sample_rate * frame_duration) 126 | 127 | for i in range(min_start, audio_data.shape[0] - step_size, step_size): 128 | vad_detected = vad.is_speech( 129 | audio_data[i : i + step_size].tobytes(), sample_rate 130 | ) 131 | if vad_detected: 132 | # If voice activity is detected, add it to filtered_audio 133 | filtered_audio.extend(audio_data[i : i + step_size].tolist()) 134 | 135 | # If the original audio data was a float array, convert back 136 | if float_type: 137 | trimmed_audio = np.array(filtered_audio) 138 | return np.array(trimmed_audio / 32767).astype(np.float32) 139 | 140 | return np.array(filtered_audio).astype(np.int16) 141 | -------------------------------------------------------------------------------- /microwakeword/audio/augmentation.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Kevin Ahrendt. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import audiomentations 17 | import warnings 18 | 19 | import numpy as np 20 | 21 | from typing import List 22 | 23 | 24 | class Augmentation: 25 | """A class that handles applying augmentations to audio clips. 26 | 27 | Args: 28 | augmentation_duration_s (float): The duration of the augmented clip in seconds. 29 | augmentation_probabilities (dict, optional): Dictionary that specifies each augmentation's probability of being applied. Defaults to { "SevenBandParametricEQ": 0.0, "TanhDistortion": 0.0, "PitchShift": 0.0, "BandStopFilter": 0.0, "AddColorNoise": 0.25, "AddBackgroundNoise": 0.75, "Gain": 1.0, "GainTransition": 0.25, "RIR": 0.5, }. 30 | impulse_paths (List[str], optional): List of directory paths that contain room impulse responses that the audio clip is reverberated with. If the list is empty, then reverberation is not applied. Defaults to []. 31 | background_paths (List[str], optional): List of directory paths that contain audio clips to be mixed into the audio clip. If the list is empty, then the background augmentation is not applied. Defaults to []. 32 | background_min_snr_db (int, optional): The minimum signal to noise ratio for mixing in background audio. Defaults to -10. 33 | background_max_snr_db (int, optional): The maximum signal to noise ratio for mixing in background audio. Defaults to 10. 34 | min_gain_db (float, optional): The minimum gain for the gain augmentation. Defaults to -45.0. 35 | max_gain_db (float, optional): The mmaximum gain for the gain augmentation. Defaults to 0.0. 36 | min_gain_transition_db (float, optional): The minimum gain for the gain transition augmentation. Defaults to -10.0. 
37 | max_gain_transition_db (float, optional): The mmaximum gain for the gain transition augmentation. Defaults to 10.0. 38 | min_jitter_s (float, optional): The minimum duration in seconds that the original clip is positioned before the end of the augmented audio. Defaults to 0.0. 39 | max_jitter_s (float, optional): The maximum duration in seconds that the original clip is positioned before the end of the augmented audio. Defaults to 0.0. 40 | truncate_randomly: (bool, option): If true, the clip is truncated to the specified duration randomly. Otherwise, the start of the clip is truncated. 41 | """ 42 | 43 | def __init__( 44 | self, 45 | augmentation_duration_s: float | None = None, 46 | augmentation_probabilities: dict = { 47 | "SevenBandParametricEQ": 0.0, 48 | "TanhDistortion": 0.0, 49 | "PitchShift": 0.0, 50 | "BandStopFilter": 0.0, 51 | "AddColorNoise": 0.25, 52 | "AddBackgroundNoise": 0.75, 53 | "Gain": 1.0, 54 | "GainTransition": 0.25, 55 | "RIR": 0.5, 56 | }, 57 | impulse_paths: List[str] = [], 58 | background_paths: List[str] = [], 59 | background_min_snr_db: int = -10, 60 | background_max_snr_db: int = 10, 61 | color_min_snr_db: int = 10, 62 | color_max_snr_db: int = 30, 63 | min_gain_db: float = -45, 64 | max_gain_db: float = 0, 65 | min_gain_transition_db: float = -10, 66 | max_gain_transition_db: float = 10, 67 | min_jitter_s: float = 0.0, 68 | max_jitter_s: float = 0.0, 69 | truncate_randomly: bool = False, 70 | ): 71 | self.truncate_randomly = truncate_randomly 72 | ############################################ 73 | # Configure audio duration and positioning # 74 | ############################################ 75 | 76 | self.min_jitter_samples = int(min_jitter_s * 16000) 77 | self.max_jitter_samples = int(max_jitter_s * 16000) 78 | 79 | if augmentation_duration_s is not None: 80 | self.augmented_samples = int(augmentation_duration_s * 16000) 81 | else: 82 | self.augmented_samples = None 83 | 84 | assert ( 85 | self.min_jitter_samples <= self.max_jitter_samples 86 | ), "Minimum jitter must be less than or equal to maximum jitter." 87 | 88 | ####################### 89 | # Setup augmentations # 90 | ####################### 91 | 92 | # If either the background_paths or impulse_paths are not specified, use an identity transform instead 93 | def identity_transform(samples, sample_rate): 94 | return samples 95 | 96 | background_noise_augment = audiomentations.Lambda( 97 | transform=identity_transform, p=0.0 98 | ) 99 | reverb_augment = audiomentations.Lambda(transform=identity_transform, p=0.0) 100 | 101 | if len(background_paths): 102 | background_noise_augment = audiomentations.AddBackgroundNoise( 103 | p=augmentation_probabilities.get("AddBackgroundNoise", 0.0), 104 | sounds_path=background_paths, 105 | min_snr_db=background_min_snr_db, 106 | max_snr_db=background_max_snr_db, 107 | ) 108 | 109 | if len(impulse_paths) > 0: 110 | reverb_augment = audiomentations.ApplyImpulseResponse( 111 | p=augmentation_probabilities.get("RIR", 0.0), 112 | ir_path=impulse_paths, 113 | ) 114 | 115 | # Based on openWakeWord's augmentations, accessed on February 23, 2024. 
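        # The Compose chain below applies, in order: seven-band EQ, tanh
        # distortion, pitch shift, band-stop filtering, colored noise,
        # background noise, gain, a gain transition, reverberation, and a
        # normalization pass that only rescales clips that would otherwise
        # clip. Each transform runs with the probability given in
        # `augmentation_probabilities`; the final normalization always runs.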
116 | self.augment = audiomentations.Compose( 117 | transforms=[ 118 | audiomentations.SevenBandParametricEQ( 119 | p=augmentation_probabilities.get("SevenBandParametricEQ", 0.0), 120 | min_gain_db=-6, 121 | max_gain_db=6, 122 | ), 123 | audiomentations.TanhDistortion( 124 | p=augmentation_probabilities.get("TanhDistortion", 0.0), 125 | min_distortion=0.0001, 126 | max_distortion=0.10, 127 | ), 128 | audiomentations.PitchShift( 129 | p=augmentation_probabilities.get("PitchShift", 0.0), 130 | min_semitones=-3, 131 | max_semitones=3, 132 | ), 133 | audiomentations.BandStopFilter( 134 | p=augmentation_probabilities.get("BandStopFilter", 0.0), 135 | ), 136 | audiomentations.AddColorNoise( 137 | p=augmentation_probabilities.get("AddColorNoise", 0.0), 138 | min_snr_db=color_min_snr_db, 139 | max_snr_db=color_max_snr_db, 140 | ), 141 | background_noise_augment, 142 | audiomentations.Gain( 143 | p=augmentation_probabilities.get("Gain", 0.0), 144 | min_gain_db=min_gain_db, 145 | max_gain_db=max_gain_db, 146 | ), 147 | audiomentations.GainTransition( 148 | p=augmentation_probabilities.get("GainTransition", 0.0), 149 | min_gain_db=min_gain_transition_db, 150 | max_gain_db=max_gain_transition_db, 151 | ), 152 | reverb_augment, 153 | audiomentations.Compose( 154 | transforms=[ 155 | audiomentations.Normalize( 156 | apply_to="only_too_loud_sounds", p=1.0 157 | ) 158 | ] 159 | ), # If the audio is clipped, normalize 160 | ], 161 | shuffle=False, 162 | ) 163 | 164 | def add_jitter(self, input_audio: np.ndarray): 165 | """Pads the clip on the right by a random duration between the class's min_jitter_s and max_jitter_s paramters. 166 | 167 | Args: 168 | input_audio (numpy.ndarray): Array containing the audio clip's samples. 169 | 170 | Returns: 171 | numpy.ndarray: Array of audio samples with silence added to the end. 172 | """ 173 | if self.min_jitter_samples < self.max_jitter_samples: 174 | jitter_samples = np.random.randint( 175 | self.min_jitter_samples, self.max_jitter_samples 176 | ) 177 | else: 178 | jitter_samples = self.min_jitter_samples 179 | 180 | # Pad audio on the right by jitter samples 181 | return np.pad(input_audio, (0, jitter_samples)) 182 | 183 | def create_fixed_size_clip(self, input_audio: np.ndarray): 184 | """Ensures the input audio clip has a fixced length. If the duration is too long, the start of the clip is removed. If it is too short, the start of the clip is padded with silence. 185 | 186 | Args: 187 | input_audio (numpy.ndarray): Array containing the audio clip's samples. 188 | 189 | Returns: 190 | numpy.ndarray: Array of audio samples with `augmented_duration_s` length. 191 | """ 192 | if self.augmented_samples is None: 193 | return input_audio 194 | 195 | if self.augmented_samples < input_audio.shape[0]: 196 | # Truncate the too long audio by removing the start of the clip 197 | if self.truncate_randomly: 198 | random_start = np.random.randint( 199 | 0, input_audio.shape[0] - self.augmented_samples 200 | ) 201 | input_audio = input_audio[ 202 | random_start : random_start + self.augmented_samples 203 | ] 204 | else: 205 | input_audio = input_audio[-self.augmented_samples :] 206 | else: 207 | # Pad with zeros at start of too short audio clip 208 | left_padding_samples = self.augmented_samples - input_audio.shape[0] 209 | 210 | input_audio = np.pad(input_audio, (left_padding_samples, 0)) 211 | 212 | return input_audio 213 | 214 | def augment_clip(self, input_audio: np.ndarray): 215 | """Augments the input audio after adding jitter and creating a fixed size clip. 
216 | 217 | Args: 218 | input_audio (numpy.ndarray): Array containing the audio clip's samples. 219 | 220 | Returns: 221 | numpy.ndarray: The augmented audio of fixed duration. 222 | """ 223 | input_audio = self.add_jitter(input_audio) 224 | input_audio = self.create_fixed_size_clip(input_audio) 225 | 226 | with warnings.catch_warnings(): 227 | warnings.simplefilter( 228 | "ignore" 229 | ) # Suppresses warning about background clip being too quiet... TODO: find better approach! 230 | output_audio = self.augment(input_audio, sample_rate=16000) 231 | 232 | return output_audio 233 | 234 | def augment_generator(self, audio_generator): 235 | """A Python generator that augments clips retrived from the input audio generator. 236 | 237 | Args: 238 | audio_generator (generator): A Python generator that yields audio clips. 239 | 240 | Yields: 241 | numpy.ndarray: The augmented audio clip's samples. 242 | """ 243 | for audio in audio_generator: 244 | yield self.augment_clip(audio) 245 | -------------------------------------------------------------------------------- /microwakeword/audio/clips.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Kevin Ahrendt. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import audio_metadata 17 | import datasets 18 | import math 19 | import os 20 | import random 21 | import wave 22 | 23 | import numpy as np 24 | 25 | from pathlib import Path 26 | 27 | from microwakeword.audio.audio_utils import remove_silence_webrtc 28 | 29 | 30 | class Clips: 31 | """Class for loading audio clips from the specified directory. The clips can first be filtered by their duration using the `min_clip_duration_s` and `max_clip_duration_s` parameters. Clips are retrieved as numpy float arrays via the `get_random_clip` method or via the `audio_generator` or `random_audio_generator` generators. Before retrieval, the audio clip can trim non-voice activiity. Before retrieval, the audio clip can be repeated until it is longer than a specified minimum duration. 32 | 33 | Args: 34 | input_directory (str): Path to audio clip files. 35 | file_pattern (str): File glob pattern for selecting audio clip files. 36 | min_clip_duration_s (float | None, optional): The minimum clip duration (in seconds). Set to None to disable filtering by minimum clip duration. Defaults to None. 37 | max_clip_duration_s (float | None, optional): The maximum clip duration (in seconds). Set to None to disable filtering by maximum clip duration. Defaults to None. 38 | repeat_clip_min_duration_s (float | None, optional): If a clip is shorter than this duration, then it is repeated until it is longer than this duration. Set to None to disable repeating the clip. Defaults to None. 39 | remove_silence (bool, optional): Use webrtcvad to trim non-voice activity in the clip. Defaults to False. 40 | random_split_seed (int | None, optional): The random seed used to split the clips into different sets. 
Set to None to disable splitting the clips. Defaults to None. 41 | split_count (int | float, optional): The percentage/count of clips to be included in the testing and validation sets. Defaults to 0.1. 42 | trimmed_clip_duration_s: (float | None, optional): The duration of the clips to trim the end of long clips. Set to None to disable trimming. Defaults to None. 43 | trim_zerios: (bool, optional): If true, any leading and trailling zeros are removed. Defaults to false. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | input_directory: str, 49 | file_pattern: str, 50 | min_clip_duration_s: float | None = None, 51 | max_clip_duration_s: float | None = None, 52 | repeat_clip_min_duration_s: float | None = None, 53 | remove_silence: bool = False, 54 | random_split_seed: int | None = None, 55 | split_count: int | float = 0.1, 56 | trimmed_clip_duration_s: float | None = None, 57 | trim_zeros: bool = False, 58 | ): 59 | self.trim_zeros = trim_zeros 60 | self.trimmed_clip_duration_s = trimmed_clip_duration_s 61 | 62 | if min_clip_duration_s is not None: 63 | self.min_clip_duration_s = min_clip_duration_s 64 | else: 65 | self.min_clip_duration_s = 0.0 66 | 67 | if max_clip_duration_s is not None: 68 | self.max_clip_duration_s = max_clip_duration_s 69 | else: 70 | self.max_clip_duration_s = math.inf 71 | 72 | if repeat_clip_min_duration_s is not None: 73 | self.repeat_clip_min_duration_s = repeat_clip_min_duration_s 74 | else: 75 | self.repeat_clip_min_duration_s = 0.0 76 | 77 | self.remove_silence = remove_silence 78 | 79 | self.remove_silence_function = remove_silence_webrtc 80 | 81 | paths_to_clips = [str(i) for i in Path(input_directory).glob(file_pattern)] 82 | 83 | if (self.min_clip_duration_s == 0) and (math.isinf(self.max_clip_duration_s)): 84 | # No durations specified, so do not filter by length 85 | filtered_paths = paths_to_clips 86 | else: 87 | # Filter audio clips by length 88 | if file_pattern.endswith("wav"): 89 | # If it is a wave file, assume all wave files have the same parameters and filter by file size. 90 | # Based on openWakeWord's estimate_clip_duration and filter_audio_paths in data.py, accessed March 2, 2024. 91 | with wave.open(paths_to_clips[0], "rb") as input_wav: 92 | channels = input_wav.getnchannels() 93 | sample_width = input_wav.getsampwidth() 94 | sample_rate = input_wav.getframerate() 95 | frames = input_wav.getnframes() 96 | 97 | sizes = [] 98 | sizes.extend([os.path.getsize(i) for i in paths_to_clips]) 99 | 100 | # Correct for the wav file header bytes. Assumes all files in the directory have same parameters. 101 | header_correction = ( 102 | os.path.getsize(paths_to_clips[0]) 103 | - frames * sample_width * channels 104 | ) 105 | 106 | durations = [] 107 | for size in sizes: 108 | durations.append( 109 | (size - header_correction) 110 | / (sample_rate * sample_width * channels) 111 | ) 112 | 113 | filtered_paths = [ 114 | path_to_clip 115 | for path_to_clip, duration in zip(paths_to_clips, durations) 116 | if (self.min_clip_duration_s < duration) 117 | and (duration < self.max_clip_duration_s) 118 | ] 119 | else: 120 | # If not a wave file, use the audio_metadata package to analyze audio file headers for the duration. 121 | # This is slower! 
122 | filtered_paths = [] 123 | 124 | if (self.min_clip_duration_s > 0) or ( 125 | not math.isinf(self.max_clip_duration_s) 126 | ): 127 | for audio_file in paths_to_clips: 128 | metadata = audio_metadata.load(audio_file) 129 | duration = metadata["streaminfo"]["duration"] 130 | if (self.min_clip_duration_s < duration) and ( 131 | duration < self.max_clip_duration_s 132 | ): 133 | filtered_paths.append(audio_file) 134 | 135 | # Load all filtered clips 136 | audio_dataset = datasets.Dataset.from_dict( 137 | {"audio": [str(i) for i in filtered_paths]} 138 | ).cast_column("audio", datasets.Audio()) 139 | 140 | # Convert all clips to 16 kHz sampling rate when accessed 141 | audio_dataset = audio_dataset.cast_column( 142 | "audio", datasets.Audio(sampling_rate=16000) 143 | ) 144 | 145 | if random_split_seed is not None: 146 | train_testvalid = audio_dataset.train_test_split( 147 | test_size=2 * split_count, seed=random_split_seed 148 | ) 149 | test_valid = train_testvalid["test"].train_test_split(test_size=0.5) 150 | split_dataset = datasets.DatasetDict( 151 | { 152 | "train": train_testvalid["train"], 153 | "test": test_valid["test"], 154 | "validation": test_valid["train"], 155 | } 156 | ) 157 | self.split_clips = split_dataset 158 | 159 | self.clips = audio_dataset 160 | 161 | def audio_generator(self, split: str | None = None, repeat: int = 1): 162 | """A Python generator that retrieves all loaded audio clips. 163 | 164 | Args: 165 | split (str | None, optional): Specifies which set the clips are retrieved from. If None, all clips are retrieved. Otherwise, it can be set to `train`, `test`, or `validation`. Defaults to None. 166 | repeat (int, optional): The number of times each audio clip will be yielded. Defaults to 1. 167 | 168 | Yields: 169 | numpy.ndarray: Array with the audio clip's samples. 170 | """ 171 | if split is None: 172 | clip_list = self.clips 173 | else: 174 | clip_list = self.split_clips[split] 175 | for _ in range(repeat): 176 | for clip in clip_list: 177 | clip_audio = clip["audio"]["array"] 178 | 179 | if self.remove_silence: 180 | clip_audio = self.remove_silence_function(clip_audio) 181 | 182 | if self.trim_zeros: 183 | clip_audio = np.trim_zeros(clip_audio) 184 | 185 | if self.trimmed_clip_duration_s: 186 | total_samples = int(self.trimmed_clip_duration_s * 16000) 187 | clip_audio = clip_audio[:total_samples] 188 | 189 | clip_audio = self.repeat_clip(clip_audio) 190 | yield clip_audio 191 | 192 | def get_random_clip(self): 193 | """Retrieves a random audio clip. 194 | 195 | Returns: 196 | numpy.ndarray: Array with the audio clip's samples. 197 | """ 198 | rand_audio_entry = random.choice(self.clips) 199 | clip_audio = rand_audio_entry["audio"]["array"] 200 | 201 | if self.remove_silence: 202 | clip_audio = self.remove_silence_function(clip_audio) 203 | 204 | if self.trim_zeros: 205 | clip_audio = np.trim_zeros(clip_audio) 206 | 207 | if self.trimmed_clip_duration_s: 208 | total_samples = int(self.trimmed_clip_duration_s * 16000) 209 | clip_audio = clip_audio[:total_samples] 210 | 211 | clip_audio = self.repeat_clip(clip_audio) 212 | return clip_audio 213 | 214 | def random_audio_generator(self, max_clips: int = math.inf): 215 | """A Python generator that retrieves random audio clips. 216 | 217 | Args: 218 | max_clips (int, optional): The total number of clips the generator will yield before the StopIteration. Defaults to math.inf. 219 | 220 | Yields: 221 | numpy.ndarray: Array with the random audio clip's samples. 
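
        Example:
            A minimal usage sketch (the directory and glob pattern are
            placeholders):

                clips = Clips(input_directory="generated_samples", file_pattern="*.wav")
                for audio in clips.random_audio_generator(max_clips=10):
                    ...  # each `audio` is a 16 kHz numpy array of samples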
222 | """ 223 | while max_clips > 0: 224 | max_clips -= 1 225 | 226 | yield self.get_random_clip() 227 | 228 | def repeat_clip(self, audio_samples: np.array): 229 | """Repeats the audio clip until its duration exceeds the minimum specified in the class. 230 | 231 | Args: 232 | audio_samples numpy.ndarray: Original audio clip's samples. 233 | 234 | Returns: 235 | numpy.ndarray: Array with duration exceeding self.repeat_clip_min_duration_s. 236 | """ 237 | original_clip = audio_samples 238 | desired_samples = int(self.repeat_clip_min_duration_s * 16000) 239 | while audio_samples.shape[0] < desired_samples: 240 | audio_samples = np.append(audio_samples, original_clip) 241 | return audio_samples 242 | -------------------------------------------------------------------------------- /microwakeword/audio/spectrograms.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Kevin Ahrendt. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import numpy as np 17 | 18 | from microwakeword.audio.audio_utils import generate_features_for_clip 19 | from microwakeword.audio.augmentation import Augmentation 20 | from microwakeword.audio.clips import Clips 21 | 22 | 23 | class SpectrogramGeneration: 24 | """A class that handles generating spectrogram features for audio clips. Spectrograms can optionally be split into nonoverlapping segments for faster file loading or they can optionally be strided by dropping the last feature windows to simulate a streaming model's sequential inputs. 25 | 26 | Args: 27 | clips (Clips): Object that retrieves audio clips. 28 | augmenter (Augmentation | None, optional): Object that augments audio clips. If None, no augmentations are applied. Defaults to None. 29 | step_ms (int, optional): The window step size in ms for the spectrogram features. Defaults to 20. 30 | split_spectrogram_duration_s (float | None, optional): Splits generated spectrograms to yield nonoverlapping spectrograms with this duration. If None, the entire spectrogram is yielded. Defaults to None. 31 | slide_frames (int | None, optional): Strides the generated spectrograms to yield `slide_frames` overlapping spectrogram by removing features at the end of the spectrogram. If None, the entire spectrogram is yielded. Defaults to None. 32 | """ 33 | 34 | def __init__( 35 | self, 36 | clips: Clips, 37 | augmenter: Augmentation | None = None, 38 | step_ms: int = 20, 39 | split_spectrogram_duration_s: float | None = None, 40 | slide_frames: int | None = None, 41 | ): 42 | 43 | self.clips = clips 44 | self.augmenter = augmenter 45 | self.step_ms = step_ms 46 | self.split_spectrogram_duration_s = split_spectrogram_duration_s 47 | self.slide_frames = slide_frames 48 | 49 | def get_random_spectrogram(self): 50 | """Retrieves a random audio clip's spectrogram that is optionally augmented. 51 | 52 | Returns: 53 | numpy.ndarry: 2D spectrogram array for the random (augmented) audio clip. 
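
        Example:
            A minimal usage sketch (the directory, glob pattern, and duration
            are placeholders):

                clips = Clips(input_directory="generated_samples", file_pattern="*.wav")
                augmenter = Augmentation(augmentation_duration_s=3.0)
                spectrograms = SpectrogramGeneration(clips, augmenter=augmenter, step_ms=10)
                features = spectrograms.get_random_spectrogram()  # shape: (time_steps, 40)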
54 | """ 55 | clip = self.clips.get_random_clip() 56 | if self.augmenter is not None: 57 | clip = self.augmenter.augment_clip(clip) 58 | 59 | return generate_features_for_clip(clip, self.step_ms) 60 | 61 | def spectrogram_generator(self, random=False, max_clips=None, **kwargs): 62 | """A Python generator that retrieves (augmented) spectrograms. 63 | 64 | Args: 65 | random (bool, optional): Specifies if the source audio clips should be chosen randomly. Defaults to False. 66 | kwargs: Parameters to pass to the clips audio generator. 67 | 68 | Yields: 69 | numpy.ndarry: 2D spectrogram array for the random (augmented) audio clip. 70 | """ 71 | if random: 72 | if max_clips is not None: 73 | clip_generator = self.clips.random_audio_generator(max_clips=max_clips) 74 | else: 75 | clip_generator = self.clips.random_audio_generator() 76 | else: 77 | clip_generator = self.clips.audio_generator(**kwargs) 78 | 79 | if self.augmenter is not None: 80 | augmented_generator = self.augmenter.augment_generator(clip_generator) 81 | else: 82 | augmented_generator = clip_generator 83 | 84 | for augmented_clip in augmented_generator: 85 | spectrogram = generate_features_for_clip(augmented_clip, self.step_ms) 86 | 87 | if self.split_spectrogram_duration_s is not None: 88 | # Splits the resulting spectrogram into non-overlapping spectrograms. The features from the first 20 feature windows are dropped. 89 | desired_spectrogram_length = int( 90 | self.split_spectrogram_duration_s / (self.step_ms / 1000) 91 | ) 92 | 93 | if spectrogram.shape[0] > desired_spectrogram_length + 20: 94 | slided_spectrograms = np.lib.stride_tricks.sliding_window_view( 95 | spectrogram, 96 | window_shape=(desired_spectrogram_length, spectrogram.shape[1]), 97 | )[20::desired_spectrogram_length, ...] 98 | 99 | for i in range(slided_spectrograms.shape[0]): 100 | yield np.squeeze(slided_spectrograms[i]) 101 | else: 102 | yield spectrogram 103 | elif self.slide_frames is not None: 104 | # Generates self.slide_frames spectrograms by shifting over the already generated spectrogram 105 | spectrogram_length = spectrogram.shape[0] - self.slide_frames + 1 106 | 107 | slided_spectrograms = np.lib.stride_tricks.sliding_window_view( 108 | spectrogram, window_shape=(spectrogram_length, spectrogram.shape[1]) 109 | ) 110 | for i in range(self.slide_frames): 111 | yield np.squeeze(slided_spectrograms[i]) 112 | else: 113 | yield spectrogram 114 | -------------------------------------------------------------------------------- /microwakeword/inception.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Google Research Authors. 3 | # Modifications copyright 2024 Kevin Ahrendt. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """Inception - reduced version of keras/applications/inception_v3.py .""" 18 | import ast 19 | import tensorflow as tf 20 | 21 | 22 | from microwakeword.layers import delay 23 | from microwakeword.layers import stream 24 | from microwakeword.layers import strided_drop 25 | from microwakeword.layers import sub_spectral_normalization 26 | 27 | 28 | def parse(text): 29 | """Parse model parameters. 30 | 31 | Args: 32 | text: string with layer parameters: '128,128' or "'relu','relu'" 33 | 34 | Returns: 35 | list of parsed parameters 36 | """ 37 | if not text: 38 | return [] 39 | res = ast.literal_eval(text) 40 | if isinstance(res, tuple): 41 | return res 42 | else: 43 | return [res] 44 | 45 | 46 | def conv2d_bn( 47 | x, 48 | filters, 49 | kernel_size, 50 | dilation=(1, 1), 51 | padding="same", 52 | strides=(1, 1), 53 | activation="relu", 54 | use_bias=False, 55 | subgroups=1, 56 | ): 57 | """Utility function to apply conv + BN. 58 | 59 | Arguments: 60 | x: input tensor 61 | filters: filters in `Conv2D` 62 | kernel_size: size of convolution kernel 63 | dilation: dilation rate 64 | padding: padding mode in `Conv2D` 65 | strides: strides in `Conv2D` 66 | activation: activation function applied in the end 67 | use_bias: use bias for convolution 68 | subgroups: the number of subgroups used for sub-spectral normaliation 69 | 70 | Returns: 71 | output tensor after applying `Conv2D` and `SubSpectralNormalization` 72 | """ 73 | 74 | x = tf.keras.layers.Conv2D( 75 | filters, 76 | kernel_size, 77 | dilation_rate=dilation, 78 | strides=strides, 79 | padding=padding, 80 | use_bias=use_bias, 81 | )(x) 82 | 83 | sub_spectral_normalization_layer = ( 84 | sub_spectral_normalization.SubSpectralNormalization(subgroups) 85 | ) 86 | x = sub_spectral_normalization_layer(x) 87 | x = tf.keras.layers.Activation(activation)(x) 88 | return x 89 | 90 | 91 | def conv2d_bn_delay( 92 | x, 93 | filters, 94 | kernel_size, 95 | dilation, 96 | padding="same", 97 | strides=(1, 1), 98 | activation="relu", 99 | use_bias=False, 100 | delay_val=1, 101 | subgroups=1, 102 | ): 103 | """Utility function to apply conv + BN. 104 | 105 | Arguments: 106 | x: input tensor 107 | filters: filters in `Conv2D` 108 | kernel_size: size of convolution kernel 109 | dilation: dilation rate 110 | padding: padding mode in `Conv2D` 111 | strides: strides in `Conv2D` 112 | activation: activation function applied in the end 113 | use_bias: use bias for convolution 114 | delay_val: number of features for delay layer when using `same` padding 115 | subgroups: the number of subgroups used for sub-spectral normaliation 116 | 117 | Returns: 118 | output tensor after applying `Conv2D` and `SubSpectralNormalization`. 119 | """ 120 | 121 | if padding == "same": 122 | x = delay.Delay(delay=delay_val)(x) 123 | 124 | x = stream.Stream( 125 | cell=tf.keras.layers.Conv2D( 126 | filters, 127 | kernel_size, 128 | dilation_rate=dilation, 129 | strides=strides, 130 | padding="valid", 131 | use_bias=use_bias, 132 | ), 133 | use_one_step=False, 134 | pad_time_dim=padding, 135 | pad_freq_dim="same", 136 | )(x) 137 | sub_spectral_normalization_layer = ( 138 | sub_spectral_normalization.SubSpectralNormalization(subgroups) 139 | ) 140 | x = sub_spectral_normalization_layer(x) 141 | 142 | x = tf.keras.layers.Activation(activation)(x) 143 | return x 144 | 145 | 146 | def model_parameters(parser_nn): 147 | """Inception model parameters. 
148 | 149 | Args: 150 | parser_nn: global command line args parser 151 | Returns: 152 | parser with updated arguments 153 | """ 154 | parser_nn.add_argument( 155 | "--cnn1_filters", 156 | type=str, 157 | default="24", 158 | help="Number of filters in the first conv blocks", 159 | ) 160 | parser_nn.add_argument( 161 | "--cnn1_kernel_sizes", 162 | type=str, 163 | default="5", 164 | help="Kernel size in time dim of conv blocks", 165 | ) 166 | parser_nn.add_argument( 167 | "--cnn1_subspectral_groups", 168 | type=str, 169 | default="4", 170 | help="The number of subspectral groups for normalization", 171 | ) 172 | parser_nn.add_argument( 173 | "--cnn2_filters1", 174 | type=str, 175 | default="10,10,16", 176 | help="Number of filters inside of inception block " 177 | "will be multipled by 4 because of concatenation of 4 branches", 178 | ) 179 | parser_nn.add_argument( 180 | "--cnn2_filters2", 181 | type=str, 182 | default="10,10,16", 183 | help="Number of filters inside of inception block " 184 | "it is used to reduce the dim of cnn2_filters1*4", 185 | ) 186 | parser_nn.add_argument( 187 | "--cnn2_kernel_sizes", 188 | type=str, 189 | default="5,5,5", 190 | help="Kernel sizes of conv layers in the inception block", 191 | ) 192 | parser_nn.add_argument( 193 | "--cnn2_subspectral_groups", 194 | type=str, 195 | default="1,1,1", 196 | help="The number of subspectral groups for normalization", 197 | ) 198 | parser_nn.add_argument( 199 | "--cnn2_dilation", 200 | type=str, 201 | default="1,1,1", 202 | help="Dilation rate", 203 | ) 204 | parser_nn.add_argument( 205 | "--dropout", 206 | type=float, 207 | default=0.2, 208 | help="Percentage of data dropped", 209 | ) 210 | 211 | 212 | def spectrogram_slices_dropped(flags): 213 | """Computes the number of spectrogram slices dropped due to valid padding. 214 | 215 | Args: 216 | flags: data/model parameters 217 | 218 | Returns: 219 | int: number of spectrogram slices dropped 220 | """ 221 | spectrogram_slices_dropped = 0 222 | 223 | for kernel_size in parse(flags.cnn1_kernel_sizes): 224 | spectrogram_slices_dropped += kernel_size - 1 225 | for kernel_size, dilation in zip( 226 | parse(flags.cnn2_kernel_sizes), parse(flags.cnn2_dilation) 227 | ): 228 | spectrogram_slices_dropped += 2 * dilation * (kernel_size - 1) 229 | 230 | return spectrogram_slices_dropped 231 | 232 | 233 | def model(flags, shape, batch_size): 234 | """Inception model. 
235 | 236 | It is based on paper: 237 | Rethinking the Inception Architecture for Computer Vision 238 | http://arxiv.org/abs/1512.00567 239 | Args: 240 | flags: data/model parameters 241 | config: dictionary containing microWakeWord training configuration 242 | 243 | Returns: 244 | Keras model for training 245 | """ 246 | input_audio = tf.keras.layers.Input( 247 | shape=shape, 248 | batch_size=batch_size, 249 | ) 250 | net = input_audio 251 | 252 | # [batch, time, feature] 253 | net = tf.keras.ops.expand_dims(net, axis=2) 254 | # [batch, time, 1, feature] 255 | 256 | for filters, kernel_size, subgroups in zip( 257 | parse(flags.cnn1_filters), 258 | parse(flags.cnn1_kernel_sizes), 259 | parse(flags.cnn1_subspectral_groups), 260 | ): 261 | # Streaming Conv2D with 'valid' padding 262 | net = stream.Stream( 263 | cell=tf.keras.layers.Conv2D( 264 | filters, (kernel_size, 1), padding="valid", use_bias=False 265 | ), 266 | use_one_step=True, 267 | pad_time_dim=None, 268 | pad_freq_dim="same", 269 | )(net) 270 | sub_spectral_normalization_layer = ( 271 | sub_spectral_normalization.SubSpectralNormalization(subgroups) 272 | ) 273 | net = sub_spectral_normalization_layer(net) 274 | net = tf.keras.layers.Activation("relu")(net) 275 | 276 | for filters1, filters2, kernel_size, subgroups, dilation in zip( 277 | parse(flags.cnn2_filters1), 278 | parse(flags.cnn2_filters2), 279 | parse(flags.cnn2_kernel_sizes), 280 | parse(flags.cnn2_subspectral_groups), 281 | parse(flags.cnn2_dilation), 282 | ): 283 | time_buffer_size = dilation * (kernel_size - 1) 284 | 285 | branch1 = conv2d_bn(net, filters1, (1, 1), dilation=(1, 1), subgroups=subgroups) 286 | 287 | branch2 = conv2d_bn(net, filters1, (1, 1), subgroups=subgroups) 288 | branch2 = conv2d_bn_delay( 289 | branch2, 290 | filters1, 291 | (kernel_size, 1), 292 | (dilation, 1), 293 | padding="None", 294 | delay_val=time_buffer_size // 2, 295 | subgroups=subgroups, 296 | ) 297 | 298 | branch3 = conv2d_bn(net, filters1, (1, 1), subgroups=subgroups) 299 | branch3 = conv2d_bn_delay( 300 | branch3, 301 | filters1, 302 | (kernel_size, 1), 303 | (dilation, 1), 304 | padding="None", 305 | delay_val=time_buffer_size // 2, 306 | subgroups=subgroups, 307 | ) 308 | branch3 = conv2d_bn_delay( 309 | branch3, 310 | filters1, 311 | (kernel_size, 1), 312 | (dilation, 1), 313 | padding="None", 314 | delay_val=time_buffer_size // 2, 315 | subgroups=subgroups, 316 | ) 317 | 318 | branch1_drop_layer = strided_drop.StridedDrop( 319 | branch1.shape[1] - branch3.shape[1] 320 | ) 321 | branch1 = branch1_drop_layer(branch1) 322 | 323 | branch2_drop_layer = strided_drop.StridedDrop( 324 | branch2.shape[1] - branch3.shape[1] 325 | ) 326 | branch2 = branch2_drop_layer(branch2) 327 | 328 | net = tf.keras.layers.concatenate([branch1, branch2, branch3]) 329 | # [batch, time, 1, filters*4] 330 | net = conv2d_bn(net, filters2, (1, 1)) 331 | # [batch, time, 1, filters2] 332 | 333 | net = stream.Stream(cell=tf.keras.layers.Flatten())(net) 334 | # [batch, filters*4] 335 | net = tf.keras.layers.Dropout(flags.dropout)(net) 336 | net = tf.keras.layers.Dense(1, activation="sigmoid")(net) 337 | 338 | return tf.keras.Model(input_audio, net) 339 | -------------------------------------------------------------------------------- /microwakeword/inference.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Google Research Authors. 3 | # Modifications copyright 2024 Kevin Ahrendt. 
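For orientation, the block below sketches how the Inception builder above is typically driven: `model_parameters` registers its hyperparameters on an argparse parser, `parse` expands the comma-separated flag strings, and `spectrogram_slices_dropped` reports how many input slices the valid-padded convolutions consume. This is an illustrative sketch, not part of the repository source.

import argparse

import microwakeword.inception as inception

parser = argparse.ArgumentParser()
inception.model_parameters(parser)  # registers --cnn1_filters, --cnn2_kernel_sizes, ...
flags = parser.parse_args([])       # keep the default hyperparameters

# parse() turns the comma-separated flag strings into lists/tuples of values.
print(inception.parse(flags.cnn2_filters1))      # (10, 10, 16)
print(inception.parse(flags.cnn1_kernel_sizes))  # [5]

# With the defaults: (5 - 1) + 3 * (2 * 1 * (5 - 1)) = 28 slices lost to valid padding.
print(inception.spectrogram_slices_dropped(flags))

# These flags, together with an input shape such as (spectrogram_length, 40),
# are what model(flags, shape, batch_size) consumes to build the Keras graph.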
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Functions and classes for using microwakeword models with audio files/data""" 18 | 19 | # imports 20 | import numpy as np 21 | import tensorflow as tf 22 | from microwakeword.audio.audio_utils import generate_features_for_clip 23 | 24 | 25 | class Model: 26 | """ 27 | Class for loading and running tflite microwakeword models 28 | 29 | Args: 30 | tflite_model_path (str): Path to tflite model file. 31 | stride (int | None, optional): Time dimension's stride. If None, then the stride is the input tensor's time dimension. Defaults to None. 32 | """ 33 | 34 | def __init__(self, tflite_model_path: str, stride: int | None = None): 35 | # Load tflite model 36 | interpreter = tf.lite.Interpreter( 37 | model_path=tflite_model_path, 38 | ) 39 | interpreter.allocate_tensors() 40 | 41 | self.input_details = interpreter.get_input_details() 42 | self.output_details = interpreter.get_output_details() 43 | 44 | self.is_quantized_model = self.input_details[0]["dtype"] == np.int8 45 | self.input_feature_slices = self.input_details[0]["shape"][1] 46 | 47 | if stride is None: 48 | self.stride = self.input_feature_slices 49 | else: 50 | self.stride = stride 51 | 52 | for s in range(len(self.input_details)): 53 | if self.is_quantized_model: 54 | interpreter.set_tensor( 55 | self.input_details[s]["index"], 56 | np.zeros(self.input_details[s]["shape"], dtype=np.int8), 57 | ) 58 | else: 59 | interpreter.set_tensor( 60 | self.input_details[s]["index"], 61 | np.zeros(self.input_details[s]["shape"], dtype=np.float32), 62 | ) 63 | 64 | self.model = interpreter 65 | 66 | def predict_clip(self, data: np.ndarray, step_ms: int = 20): 67 | """Run the model on a single clip of audio data 68 | 69 | Args: 70 | data (numpy.ndarray): input data for the model (16 khz, 16-bit PCM audio data) 71 | step_ms (int): The window step sized used for generating the spectrogram in ms. Defaults to 20. 72 | 73 | Returns: 74 | list: model predictions for the input audio data 75 | """ 76 | 77 | # Get the spectrogram 78 | spectrogram = generate_features_for_clip(data, step_ms=step_ms) 79 | 80 | return self.predict_spectrogram(spectrogram) 81 | 82 | def predict_spectrogram(self, spectrogram: np.ndarray): 83 | """Run the model on a single spectrogram 84 | 85 | Args: 86 | spectrogram (numpy.ndarray): Input spectrogram. 
87 | 88 | Returns: 89 | list: model predictions for the input audio data 90 | """ 91 | 92 | # Spectrograms with type np.uint16 haven't been scaled 93 | if np.issubdtype(spectrogram.dtype, np.uint16): 94 | spectrogram = spectrogram.astype(np.float32) * 0.0390625 95 | elif np.issubdtype(spectrogram.dtype, np.float64): 96 | spectrogram = spectrogram.astype(np.float32) 97 | 98 | # Slice the input data into the required number of chunks 99 | chunks = [] 100 | for last_index in range( 101 | self.input_feature_slices, len(spectrogram) + 1, self.stride 102 | ): 103 | chunk = spectrogram[last_index - self.input_feature_slices : last_index] 104 | if len(chunk) == self.input_feature_slices: 105 | chunks.append(chunk) 106 | 107 | # Get the prediction for each chunk 108 | predictions = [] 109 | for chunk in chunks: 110 | if self.is_quantized_model and spectrogram.dtype != np.int8: 111 | chunk = self.quantize_input_data(chunk, self.input_details[0]) 112 | 113 | self.model.set_tensor( 114 | self.input_details[0]["index"], 115 | np.reshape(chunk, self.input_details[0]["shape"]), 116 | ) 117 | self.model.invoke() 118 | 119 | output = self.model.get_tensor(self.output_details[0]["index"])[0][0] 120 | if self.is_quantized_model: 121 | output = self.dequantize_output_data(output, self.output_details[0]) 122 | 123 | predictions.append(output) 124 | 125 | return predictions 126 | 127 | def quantize_input_data(self, data: np.ndarray, input_details: dict) -> np.ndarray: 128 | """quantize the input data using scale and zero point 129 | 130 | Args: 131 | data (numpy.array in float): input data for the interpreter 132 | input_details (dict): output of get_input_details from the tflm interpreter. 133 | 134 | Returns: 135 | numpy.ndarray: quantized data as int8 dtype 136 | """ 137 | # Get input quantization parameters 138 | data_type = input_details["dtype"] 139 | 140 | input_quantization_parameters = input_details["quantization_parameters"] 141 | input_scale, input_zero_point = ( 142 | input_quantization_parameters["scales"][0], 143 | input_quantization_parameters["zero_points"][0], 144 | ) 145 | # quantize the input data 146 | data = data / input_scale + input_zero_point 147 | return data.astype(data_type) 148 | 149 | def dequantize_output_data( 150 | self, data: np.ndarray, output_details: dict 151 | ) -> np.ndarray: 152 | """Dequantize the model output 153 | 154 | Args: 155 | data (numpy.ndarray): integer data to be dequantized 156 | output_details (dict): TFLM interpreter model output details 157 | 158 | Returns: 159 | numpy.ndarray: dequantized data as float32 dtype 160 | """ 161 | output_quantization_parameters = output_details["quantization_parameters"] 162 | output_scale = 255.0 # assume (u)int8 quantization 163 | output_zero_point = output_quantization_parameters["zero_points"][0] 164 | # Caveat: tflm_output_quant need to be converted to float to avoid integer 165 | # overflow during dequantization 166 | # e.g., (tflm_output_quant -output_zero_point) and 167 | # (tflm_output_quant + (-output_zero_point)) 168 | # can produce different results (int8 calculation) 169 | # return output_scale * (data.astype(np.float32) - output_zero_point) 170 | return 1 / output_scale * (data.astype(np.float32) - output_zero_point) 171 | -------------------------------------------------------------------------------- /microwakeword/layers/average_pooling2d.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Google Research Authors. 
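The `Model` wrapper above is the intended entry point for running a trained wake word model against raw audio. A short usage sketch follows, not part of the repository source; the `.tflite` path is a placeholder, the silent buffer only illustrates the expected 16 kHz, 16-bit PCM input format, and the 0.5 cutoff is an arbitrary example threshold.

import numpy as np

from microwakeword.inference import Model

# Placeholder path: point this at a real streaming or nonstreaming model file.
model = Model("stream_state_internal_quant.tflite")

# One second of silence in the expected input format: 16 kHz, 16-bit PCM samples.
audio = np.zeros(16000, dtype=np.int16)

# Returns one probability per spectrogram window/stride of the clip.
probabilities = model.predict_clip(audio, step_ms=20)

if probabilities and max(probabilities) > 0.5:
    print("wake word detected")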
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Convolutional AveragePooling2D.""" 17 | import numpy as np 18 | import tensorflow as tf 19 | 20 | 21 | class AveragePooling2D(tf.keras.layers.Layer): 22 | """AveragePooling2D layer. 23 | 24 | It is convolutional AveragePooling2D based on depthwise_conv2d. 25 | It can be useful for cases where AveragePooling2D has to run in streaming mode 26 | 27 | The input data with shape [batch_size, time1, feature1, feature2] 28 | are processed by depthwise conv with fixed weights, all weights values 29 | are equal to 1.0/(size_in_time_1*size_in_feature1). 30 | Averaging is done in 'time1' and 'feature1' dims. 31 | Conv filter has size [size_in_time_1, size_in_feature1, feature2], 32 | where first two dims are specified by user and 33 | feature2 is defiend by the last dim of input data. 34 | 35 | So if kernel_size = [time1, feature1] 36 | output will be [batch_size, time1, 1, feature2] 37 | 38 | Attributes: 39 | kernel_size: 2D kernel size - defines the dims 40 | which will be eliminated/averaged. 41 | strides: stride for each dim, with size 4 42 | padding: defiens how to pad 43 | dilation_rate: dilation rate in which we sample input values 44 | across the height and width 45 | **kwargs: additional layer arguments 46 | """ 47 | 48 | def __init__( 49 | self, kernel_size, strides=None, padding="valid", dilation_rate=None, **kwargs 50 | ): 51 | super(AveragePooling2D, self).__init__(**kwargs) 52 | self.kernel_size = kernel_size 53 | self.strides = strides 54 | self.padding = padding 55 | self.dilation_rate = dilation_rate 56 | if not self.strides: 57 | self.strides = [1, 1, 1, 1] 58 | 59 | if not self.dilation_rate: 60 | self.dilation_rate = [1, 1] 61 | 62 | def build(self, input_shape): 63 | super(AveragePooling2D, self).build(input_shape) 64 | # expand filters shape with the last dimension 65 | filter_shape = self.kernel_size + (input_shape[-1],) 66 | self.filters = self.add_weight("kernel", shape=filter_shape) 67 | 68 | init_weight = np.ones(filter_shape) / np.prod(self.kernel_size) 69 | self.set_weights([init_weight]) 70 | 71 | def call(self, inputs): 72 | # inputs [batch_size, time1, feature1, feature2] 73 | time_kernel_exp = tf.expand_dims(self.filters, -1) 74 | # it can be replaced by AveragePooling2D with temporal padding 75 | # and optimized for streaming mode 76 | # output will be [batch_size, time1, feature1, feature2] 77 | return tf.nn.depthwise_conv2d( 78 | inputs, 79 | time_kernel_exp, 80 | strides=self.strides, 81 | padding=self.padding.upper(), 82 | dilations=self.dilation_rate, 83 | name=self.name + "_averPool2D", 84 | ) 85 | 86 | def get_config(self): 87 | config = super(AveragePooling2D, self).get_config() 88 | config.update( 89 | { 90 | "kernel_size": self.kernel_size, 91 | "strides": self.strides, 92 | "padding": self.padding, 93 | "dilation_rate": self.dilation_rate, 94 | } 95 | ) 96 | return config 97 | -------------------------------------------------------------------------------- 
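The layer above implements average pooling as a depthwise convolution with constant weights so that it can later be wrapped for streaming. The sketch below, which is not part of the repository source and uses illustrative shapes, shows the equivalent computation directly with `tf.nn.depthwise_conv2d`.

import numpy as np
import tensorflow as tf

kernel_size = (3, 5)                     # average over 3 time steps x 5 feature1 bins
inputs = tf.random.normal((1, 8, 5, 4))  # [batch, time1, feature1, feature2]

# Constant filter of 1 / (3 * 5), one copy per input channel, channel multiplier 1.
filters = tf.ones(kernel_size + (4, 1)) / float(np.prod(kernel_size))

outputs = tf.nn.depthwise_conv2d(
    inputs, filters, strides=[1, 1, 1, 1], padding="VALID", dilations=[1, 1]
)
print(outputs.shape)  # (1, 6, 1, 4): feature1 is averaged away, time1 shrinks by 2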
/microwakeword/layers/delay.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Dealy layer.""" 17 | 18 | from microwakeword.layers import modes 19 | import tensorflow as tf 20 | 21 | 22 | class Delay(tf.keras.layers.Layer): 23 | """Delay layer. 24 | 25 | It is useful for introducing delay in streaming mode for non causal filters. 26 | For example in residual connections with multiple conv layers 27 | 28 | Attributes: 29 | mode: Training or inference modes: non streaming, streaming. 30 | delay: delay value 31 | inference_batch_size: batch size in inference mode 32 | also_in_non_streaming: Apply delay also in training and non-streaming 33 | inference mode. 34 | **kwargs: additional layer arguments 35 | """ 36 | 37 | def __init__( 38 | self, 39 | mode=modes.Modes.TRAINING, 40 | delay=0, 41 | inference_batch_size=1, 42 | also_in_non_streaming=False, 43 | **kwargs, 44 | ): 45 | super(Delay, self).__init__(**kwargs) 46 | self.mode = mode 47 | self.delay = delay 48 | self.inference_batch_size = inference_batch_size 49 | self.also_in_non_streaming = also_in_non_streaming 50 | 51 | if delay < 0: 52 | raise ValueError("delay (%d) must be non-negative" % delay) 53 | 54 | def build(self, input_shape): 55 | super(Delay, self).build(input_shape) 56 | 57 | if self.delay > 0: 58 | self.state_shape = [ 59 | self.inference_batch_size, 60 | self.delay, 61 | ] + input_shape.as_list()[2:] 62 | if self.mode == modes.Modes.STREAM_INTERNAL_STATE_INFERENCE: 63 | self.states = self.add_weight( 64 | name="states", 65 | shape=self.state_shape, 66 | trainable=False, 67 | initializer=tf.zeros_initializer, 68 | ) 69 | 70 | elif self.mode == modes.Modes.STREAM_EXTERNAL_STATE_INFERENCE: 71 | # For streaming inference with extrnal states, 72 | # the states are passed in as input. 73 | self.input_state = tf.keras.layers.Input( 74 | shape=self.state_shape[1:], 75 | batch_size=self.inference_batch_size, 76 | name=self.name + "/input_state_delay", 77 | ) 78 | self.output_state = None 79 | 80 | def call(self, inputs): 81 | if self.delay == 0: 82 | return inputs 83 | 84 | if self.mode == modes.Modes.STREAM_INTERNAL_STATE_INFERENCE: 85 | return self._streaming_internal_state(inputs) 86 | 87 | elif self.mode == modes.Modes.STREAM_EXTERNAL_STATE_INFERENCE: 88 | # in streaming inference mode with external state 89 | # in addition to the output we return the output state. 
90 | output, self.output_state = self._streaming_external_state( 91 | inputs, self.input_state 92 | ) 93 | return output 94 | 95 | elif self.mode in (modes.Modes.TRAINING, modes.Modes.NON_STREAM_INFERENCE): 96 | # run non streamable training or non streamable inference 97 | return self._non_streaming(inputs) 98 | 99 | else: 100 | raise ValueError(f"Encountered unexpected mode `{self.mode}`.") 101 | 102 | def get_config(self): 103 | config = super(Delay, self).get_config() 104 | config.update( 105 | { 106 | "mode": self.mode, 107 | "delay": self.delay, 108 | "inference_batch_size": self.inference_batch_size, 109 | "also_in_non_streaming": self.also_in_non_streaming, 110 | } 111 | ) 112 | return config 113 | 114 | def _streaming_internal_state(self, inputs): 115 | memory = tf.keras.layers.concatenate([self.states, inputs], 1) 116 | outputs = memory[:, : inputs.shape.as_list()[1]] 117 | new_memory = memory[:, -self.delay :] 118 | assign_states = self.states.assign(new_memory) 119 | 120 | with tf.control_dependencies([assign_states]): 121 | return tf.identity(outputs) 122 | 123 | def _streaming_external_state(self, inputs, states): 124 | memory = tf.keras.layers.concatenate([states, inputs], 1) 125 | outputs = memory[:, : inputs.shape.as_list()[1]] 126 | new_memory = memory[:, -self.delay :] 127 | return outputs, new_memory 128 | 129 | def _non_streaming(self, inputs): 130 | if self.also_in_non_streaming: 131 | return tf.pad( 132 | inputs, ((0, 0), (self.delay, 0)) + ((0, 0),) * (inputs.shape.rank - 2) 133 | )[:, : -self.delay] 134 | else: 135 | return inputs 136 | 137 | def get_input_state(self): 138 | # input state will be used only for STREAM_EXTERNAL_STATE_INFERENCE mode 139 | if self.mode == modes.Modes.STREAM_EXTERNAL_STATE_INFERENCE: 140 | return [self.input_state] 141 | else: 142 | raise ValueError( 143 | "Expected the layer to be in external streaming mode, " 144 | f"not `{self.mode}`." 145 | ) 146 | 147 | def get_output_state(self): 148 | # output state will be used only for STREAM_EXTERNAL_STATE_INFERENCE mode 149 | if self.mode == modes.Modes.STREAM_EXTERNAL_STATE_INFERENCE: 150 | return [self.output_state] 151 | else: 152 | raise ValueError( 153 | "Expected the layer to be in external streaming mode, " 154 | f"not `{self.mode}`." 155 | ) 156 | -------------------------------------------------------------------------------- /microwakeword/layers/modes.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Modes the model can be in and its input data shape.""" 17 | 18 | 19 | class Modes(object): 20 | """Definition of the mode the model is functioning in.""" 21 | 22 | # Model is in a training state. No streaming is done. 
23 |     TRAINING = "TRAINING"
24 | 
25 |     # Below are three options for inference:
26 | 
27 |     # Model is in inference mode and has state for efficient
28 |     # computation/streaming, where state is kept inside of the model
29 |     STREAM_INTERNAL_STATE_INFERENCE = "STREAM_INTERNAL_STATE_INFERENCE"
30 | 
31 |     # Model is in inference mode and has state for efficient
32 |     # computation/streaming, where state is received from outside of the model
33 |     STREAM_EXTERNAL_STATE_INFERENCE = "STREAM_EXTERNAL_STATE_INFERENCE"
34 | 
35 |     # Model is in inference mode and its topology is the same as in training
36 |     # mode (with dropouts etc. removed)
37 |     NON_STREAM_INFERENCE = "NON_STREAM_INFERENCE"
38 | 
39 | 
40 | def get_input_data_shape(config, mode):
41 |     """Gets data shape for a neural net input layer.
42 | 
43 |     Args:
44 |         config: dictionary containing training parameters
45 |         mode: inference mode described above at Modes
46 | 
47 |     Returns:
48 |         data_shape for input layer
49 |     """
50 | 
51 |     if mode not in (
52 |         Modes.TRAINING,
53 |         Modes.NON_STREAM_INFERENCE,
54 |         Modes.STREAM_INTERNAL_STATE_INFERENCE,
55 |         Modes.STREAM_EXTERNAL_STATE_INFERENCE,
56 |     ):
57 |         raise ValueError('Unknown mode "%s"' % mode)
58 | 
59 |     if mode in (Modes.TRAINING, Modes.NON_STREAM_INFERENCE):
60 |         data_shape = (config["spectrogram_length"], 40)
61 |     else:
62 |         stride = config["stride"]
63 |         data_shape = (stride, 40)
64 |     return data_shape
65 | --------------------------------------------------------------------------------
/microwakeword/layers/strided_drop.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 Kevin Ahrendt.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | import tensorflow as tf
17 | 
18 | from microwakeword.layers import modes
19 | 
20 | 
21 | class StridedDrop(tf.keras.layers.Layer):
22 |     """StridedDrop
23 | 
24 |     Drops the specified audio feature slices in nonstreaming mode only.
25 |     Used for matching the dimensions of convolutions with valid padding.
26 | 27 | Attributes: 28 | time_sclices_to_drop: number of audio feature slices to drop 29 | mode: inference mode; e.g., non-streaming, internal streaming 30 | """ 31 | 32 | def __init__( 33 | self, time_slices_to_drop, mode=modes.Modes.NON_STREAM_INFERENCE, **kwargs 34 | ): 35 | super(StridedDrop, self).__init__(**kwargs) 36 | self.time_slices_to_drop = time_slices_to_drop 37 | self.mode = mode 38 | self.state_shape = [] 39 | 40 | def call(self, inputs): 41 | if self.mode == modes.Modes.NON_STREAM_INFERENCE: 42 | return inputs[:, self.time_slices_to_drop :, :, :] 43 | 44 | return inputs 45 | 46 | def get_config(self): 47 | config = { 48 | "time_slices_to_drop": self.time_slices_to_drop, 49 | "mode": self.mode, 50 | } 51 | base_config = super(StridedDrop, self).get_config() 52 | return dict(list(base_config.items()) + list(config.items())) 53 | 54 | def get_input_state(self): 55 | return [] 56 | 57 | def get_output_state(self): 58 | return [] 59 | 60 | 61 | class StridedKeep(tf.keras.layers.Layer): 62 | """StridedKeep 63 | 64 | Keeps the specified audio feature slices in streaming mode only. 65 | Used for splitting a single streaming ring buffer into multiple branches with minimal overhead. 66 | 67 | Attributes: 68 | time_sclices_to_keep: number of audio feature slices to keep 69 | mode: inference mode; e.g., non-streaming, internal streaming 70 | """ 71 | 72 | def __init__( 73 | self, time_slices_to_keep, mode=modes.Modes.NON_STREAM_INFERENCE, **kwargs 74 | ): 75 | super(StridedKeep, self).__init__(**kwargs) 76 | self.time_slices_to_keep = max(time_slices_to_keep, 1) 77 | self.mode = mode 78 | self.state_shape = [] 79 | 80 | def call(self, inputs): 81 | if self.mode != modes.Modes.NON_STREAM_INFERENCE: 82 | return inputs[:, -self.time_slices_to_keep :, :, :] 83 | 84 | return inputs 85 | 86 | def get_config(self): 87 | config = { 88 | "time_slices_to_keep": self.time_slices_to_keep, 89 | "mode": self.mode, 90 | } 91 | base_config = super(StridedKeep, self).get_config() 92 | return dict(list(base_config.items()) + list(config.items())) 93 | 94 | def get_input_state(self): 95 | return [] 96 | 97 | def get_output_state(self): 98 | return [] 99 | -------------------------------------------------------------------------------- /microwakeword/layers/sub_spectral_normalization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Google Research Authors. 3 | # Modifications copyright 2024 Kevin Ahrendt. 4 | # 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | 18 | """Sub spectral normalization layer.""" 19 | from typing import Any, Dict 20 | 21 | import tensorflow as tf 22 | 23 | 24 | class SubSpectralNormalization(tf.keras.layers.Layer): 25 | """Sub spectral normalization layer. 
26 | 27 | It is based on paper: 28 | "SUBSPECTRAL NORMALIZATION FOR NEURAL AUDIO DATA PROCESSING" 29 | https://arxiv.org/pdf/2103.13620.pdf 30 | """ 31 | 32 | def __init__(self, sub_groups, **kwargs): 33 | super(SubSpectralNormalization, self).__init__(**kwargs) 34 | self.sub_groups = sub_groups 35 | 36 | self.batch_norm = tf.keras.layers.BatchNormalization() 37 | 38 | def call(self, inputs): 39 | # expected input: [N, Time, Frequency, Channels] 40 | if inputs.shape.rank != 4: 41 | raise ValueError("input_shape.rank:%d must be 4" % inputs.shape.rank) 42 | 43 | input_shape = inputs.shape.as_list() 44 | if input_shape[3] % self.sub_groups: 45 | raise ValueError( 46 | "input_shape[3]: %d must be divisible by " 47 | "self.sub_groups %d " % (input_shape[3], self.sub_groups) 48 | ) 49 | 50 | net = inputs 51 | if self.sub_groups == 1: 52 | net = self.batch_norm(net) 53 | else: 54 | target_shape = [ 55 | input_shape[1], 56 | input_shape[3] // self.sub_groups, 57 | input_shape[2] * self.sub_groups, 58 | ] 59 | net = tf.keras.layers.Reshape(target_shape)(net) 60 | net = self.batch_norm(net) 61 | net = tf.keras.layers.Reshape(input_shape[1:])(net) 62 | return net 63 | 64 | def get_config(self): 65 | config = {"sub_groups": self.sub_groups} 66 | base_config = super(SubSpectralNormalization, self).get_config() 67 | return dict(list(base_config.items()) + list(config.items())) 68 | -------------------------------------------------------------------------------- /microwakeword/mixednet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Kevin Ahrendt. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Model based on 1D depthwise MixedConvs and 1x1 convolutions in time + residual.""" 17 | 18 | from microwakeword.layers import stream 19 | from microwakeword.layers import strided_drop 20 | 21 | import ast 22 | import tensorflow as tf 23 | 24 | 25 | def parse(text): 26 | """Parse model parameters. 27 | 28 | Args: 29 | text: string with layer parameters: '128,128' or "'relu','relu'". 
30 | 31 | Returns: 32 | list of parsed parameters 33 | """ 34 | if not text: 35 | return [] 36 | res = ast.literal_eval(text) 37 | if isinstance(res, tuple): 38 | return res 39 | else: 40 | return [res] 41 | 42 | 43 | def model_parameters(parser_nn): 44 | """MixedNet model parameters.""" 45 | 46 | parser_nn.add_argument( 47 | "--pointwise_filters", 48 | type=str, 49 | default="48, 48, 48, 48", 50 | help="Number of filters in every MixConv block's pointwise convolution", 51 | ) 52 | parser_nn.add_argument( 53 | "--residual_connection", 54 | type=str, 55 | default="0,0,0,0,0", 56 | help="Use a residual connection in each MixConv block", 57 | ) 58 | parser_nn.add_argument( 59 | "--repeat_in_block", 60 | type=str, 61 | default="1,1,1,1", 62 | help="Number of repeating conv blocks inside of residual block", 63 | ) 64 | parser_nn.add_argument( 65 | "--mixconv_kernel_sizes", 66 | type=str, 67 | default="[5], [9], [13], [21]", 68 | help="Kernel size lists for DepthwiseConv1D in time dim for every MixConv block", 69 | ) 70 | parser_nn.add_argument( 71 | "--max_pool", 72 | type=int, 73 | default=0, 74 | help="apply max pool instead of average pool before final convolution and sigmoid activation", 75 | ) 76 | parser_nn.add_argument( 77 | "--first_conv_filters", 78 | type=int, 79 | default=32, 80 | help="Number of filters on initial convolution layer. Set to 0 to disable.", 81 | ) 82 | parser_nn.add_argument( 83 | "--first_conv_kernel_size", 84 | type=int, 85 | default="3", 86 | help="Temporal kernel size for the initial convolution layer.", 87 | ) 88 | parser_nn.add_argument( 89 | "--spatial_attention", 90 | type=int, 91 | default=0, 92 | help="Add a spatial attention layer before the final pooling layer", 93 | ) 94 | parser_nn.add_argument( 95 | "--pooled", 96 | type=int, 97 | default=0, 98 | help="Pool the temporal dimension before the final fully connected layer. Uses average pooling or max pooling depending on the max_pool argument", 99 | ) 100 | parser_nn.add_argument( 101 | "--stride", 102 | type=int, 103 | default=1, 104 | help="Striding in the time dimension of the initial convolution layer", 105 | ) 106 | 107 | 108 | def spectrogram_slices_dropped(flags): 109 | """Computes the number of spectrogram slices dropped due to valid padding. 
110 | 111 | Args: 112 | flags: data/model parameters 113 | 114 | Returns: 115 | int: number of spectrogram slices dropped 116 | """ 117 | spectrogram_slices_dropped = 0 118 | 119 | if flags.first_conv_filters > 0: 120 | spectrogram_slices_dropped += flags.first_conv_kernel_size - 1 121 | 122 | for repeat, ksize in zip( 123 | parse(flags.repeat_in_block), 124 | parse(flags.mixconv_kernel_sizes), 125 | ): 126 | spectrogram_slices_dropped += (repeat * (max(ksize) - 1)) * flags.stride 127 | 128 | # spectrogram_slices_dropped *= flags.stride 129 | return spectrogram_slices_dropped 130 | 131 | 132 | def _split_channels(total_filters, num_groups): 133 | """Helper for MixConv""" 134 | split = [total_filters // num_groups for _ in range(num_groups)] 135 | split[0] += total_filters - sum(split) 136 | return split 137 | 138 | 139 | def _get_shape_value(maybe_v2_shape): 140 | """Helper for MixConv""" 141 | if maybe_v2_shape is None: 142 | return None 143 | elif isinstance(maybe_v2_shape, int): 144 | return maybe_v2_shape 145 | else: 146 | return maybe_v2_shape.value 147 | 148 | 149 | class ChannelSplit(tf.keras.layers.Layer): 150 | def __init__(self, splits, axis=-1, **kwargs): 151 | super().__init__(**kwargs) 152 | self.splits = splits 153 | self.axis = axis 154 | 155 | def call(self, inputs): 156 | return tf.split(inputs, self.splits, axis=self.axis) 157 | 158 | def compute_output_shape(self, input_shape): 159 | output_shapes = [] 160 | for split in self.splits: 161 | new_shape = list(input_shape) 162 | new_shape[self.axis] = split 163 | output_shapes.append(tuple(new_shape)) 164 | return output_shapes 165 | 166 | 167 | 168 | class MixConv: 169 | """MixConv with mixed depthwise convolutional kernels. 170 | 171 | MDConv is an improved depthwise convolution that mixes multiple kernels (e.g. 172 | 3x1, 5x1, etc). Right now, we use an naive implementation that split channels 173 | into multiple groups and perform different kernels for each group. 174 | 175 | See Mixnet paper for more details. 176 | """ 177 | 178 | def __init__(self, kernel_size, **kwargs): 179 | """Initialize the layer. 180 | 181 | Most of args are the same as tf.keras.layers.DepthwiseConv2D. 182 | 183 | Args: 184 | kernel_size: An integer or a list. If it is a single integer, then it is 185 | same as the original tf.keras.layers.DepthwiseConv2D. If it is a list, 186 | then we split the channels and perform different kernel for each group. 187 | strides: An integer or tuple/list of 2 integers, specifying the strides of 188 | the convolution along the height and width. 189 | **kwargs: other parameters passed to the original depthwise_conv layer. 
190 | """ 191 | self._channel_axis = -1 192 | 193 | self.ring_buffer_length = max(kernel_size) - 1 194 | 195 | self.kernel_sizes = kernel_size 196 | 197 | def __call__(self, inputs): 198 | # We manually handle the streaming ring buffer for each layer 199 | # - There is some latency overhead on the esp devices for loading each ring buffer's data 200 | # - This avoids variable's holding redundant information 201 | # - Reduces the necessary size of the tensor arena 202 | net = stream.Stream( 203 | cell=tf.keras.layers.Identity(), 204 | ring_buffer_size_in_time_dim=self.ring_buffer_length, 205 | use_one_step=False, 206 | )(inputs) 207 | 208 | if len(self.kernel_sizes) == 1: 209 | return tf.keras.layers.DepthwiseConv2D( 210 | (self.kernel_sizes[0], 1), strides=1, padding="valid" 211 | )(net) 212 | 213 | filters = _get_shape_value(net.shape[self._channel_axis]) 214 | splits = _split_channels(filters, len(self.kernel_sizes)) 215 | x_splits = ChannelSplit(splits, axis=self._channel_axis)(net) 216 | 217 | x_outputs = [] 218 | for x, ks in zip(x_splits, self.kernel_sizes): 219 | fit = strided_drop.StridedKeep(ks)(x) 220 | x_outputs.append( 221 | tf.keras.layers.DepthwiseConv2D((ks, 1), strides=1, padding="valid")( 222 | fit 223 | ) 224 | ) 225 | 226 | for i, output in enumerate(x_outputs): 227 | features_drop = output.shape[1] - x_outputs[-1].shape[1] 228 | x_outputs[i] = strided_drop.StridedDrop(features_drop)(output) 229 | 230 | x = tf.keras.layers.concatenate(x_outputs, axis=self._channel_axis) 231 | return x 232 | 233 | 234 | class SpatialAttention: 235 | """Spatial Attention Layer based on CBAM: Convolutional Block Attention Module 236 | https://arxiv.org/pdf/1807.06521v2 237 | 238 | Args: 239 | object (_type_): _description_ 240 | """ 241 | 242 | def __init__(self, kernel_size, ring_buffer_size): 243 | self.kernel_size = kernel_size 244 | self.ring_buffer_size = ring_buffer_size 245 | 246 | def __call__(self, inputs): 247 | tranposed = tf.keras.ops.transpose(inputs, axes=[0, 1, 3, 2]) 248 | channel_avg = tf.keras.layers.AveragePooling2D( 249 | pool_size=(1, tranposed.shape[2]), strides=(1, tranposed.shape[2]) 250 | )(tranposed) 251 | channel_max = tf.keras.layers.MaxPooling2D( 252 | pool_size=(1, tranposed.shape[2]), strides=(1, tranposed.shape[2]) 253 | )(tranposed) 254 | pooled = tf.keras.layers.Concatenate(axis=-1)([channel_avg, channel_max]) 255 | 256 | attention = stream.Stream( 257 | cell=tf.keras.layers.Conv2D( 258 | 1, 259 | (self.kernel_size, 1), 260 | strides=(1, 1), 261 | padding="valid", 262 | use_bias=False, 263 | activation="sigmoid", 264 | ), 265 | use_one_step=False, 266 | )(pooled) 267 | 268 | net = stream.Stream( 269 | cell=tf.keras.layers.Identity(), 270 | ring_buffer_size_in_time_dim=self.ring_buffer_size, 271 | use_one_step=False, 272 | )(inputs) 273 | net = net[:, -attention.shape[1] :, :, :] 274 | 275 | return net * attention 276 | 277 | 278 | def model(flags, shape, batch_size): 279 | """MixedNet model. 
280 | 281 | It is based on the paper 282 | MixConv: Mixed Depthwise Convolutional Kernels 283 | https://arxiv.org/abs/1907.09595 284 | Args: 285 | flags: data/model parameters 286 | shape: shape of the input vector 287 | config: dictionary containing microWakeWord training configuration 288 | 289 | Returns: 290 | Keras model for training 291 | """ 292 | 293 | pointwise_filters = parse(flags.pointwise_filters) 294 | repeat_in_block = parse(flags.repeat_in_block) 295 | mixconv_kernel_sizes = parse(flags.mixconv_kernel_sizes) 296 | residual_connections = parse(flags.residual_connection) 297 | 298 | for list in ( 299 | pointwise_filters, 300 | repeat_in_block, 301 | mixconv_kernel_sizes, 302 | residual_connections, 303 | ): 304 | if len(pointwise_filters) != len(list): 305 | raise ValueError("all input lists have to be the same length") 306 | 307 | input_audio = tf.keras.layers.Input( 308 | shape=shape, 309 | batch_size=batch_size, 310 | ) 311 | net = input_audio 312 | 313 | # make it [batch, time, 1, feature] 314 | net = tf.keras.ops.expand_dims(net, axis=2) 315 | 316 | # Streaming Conv2D with 'valid' padding 317 | if flags.first_conv_filters > 0: 318 | net = stream.Stream( 319 | cell=tf.keras.layers.Conv2D( 320 | flags.first_conv_filters, 321 | (flags.first_conv_kernel_size, 1), 322 | strides=(flags.stride, 1), 323 | padding="valid", 324 | use_bias=False, 325 | ), 326 | use_one_step=False, 327 | pad_time_dim=None, 328 | pad_freq_dim="valid", 329 | )(net) 330 | 331 | net = tf.keras.layers.Activation("relu")(net) 332 | 333 | # encoder 334 | for filters, repeat, ksize, res in zip( 335 | pointwise_filters, 336 | repeat_in_block, 337 | mixconv_kernel_sizes, 338 | residual_connections, 339 | ): 340 | if res: 341 | residual = tf.keras.layers.Conv2D( 342 | filters=filters, kernel_size=1, use_bias=False, padding="same" 343 | )(net) 344 | residual = tf.keras.layers.BatchNormalization()(residual) 345 | 346 | for _ in range(repeat): 347 | if max(ksize) > 1: 348 | net = MixConv(kernel_size=ksize)(net) 349 | net = tf.keras.layers.Conv2D( 350 | filters=filters, kernel_size=1, use_bias=False, padding="same" 351 | )(net) 352 | net = tf.keras.layers.BatchNormalization()(net) 353 | 354 | if res: 355 | residual = strided_drop.StridedDrop(residual.shape[1] - net.shape[1])( 356 | residual 357 | ) 358 | net = net + residual 359 | 360 | net = tf.keras.layers.Activation("relu")(net) 361 | 362 | if net.shape[1] > 1: 363 | if flags.spatial_attention: 364 | net = SpatialAttention( 365 | kernel_size=4, 366 | ring_buffer_size=net.shape[1] - 1, 367 | )(net) 368 | else: 369 | net = stream.Stream( 370 | cell=tf.keras.layers.Identity(), 371 | ring_buffer_size_in_time_dim=net.shape[1] - 1, 372 | use_one_step=False, 373 | )(net) 374 | 375 | if flags.pooled: 376 | # We want to use either Global Max Pooling or Global Average Pooling, but the esp-nn operator optimizations only benefit regular pooling operations 377 | 378 | if flags.max_pool: 379 | net = tf.keras.layers.MaxPooling2D(pool_size=(net.shape[1], 1))(net) 380 | else: 381 | net = tf.keras.layers.AveragePooling2D(pool_size=(net.shape[1], 1))(net) 382 | 383 | net = tf.keras.layers.Flatten()(net) 384 | net = tf.keras.layers.Dense(1, activation="sigmoid")(net) 385 | 386 | return tf.keras.Model(input_audio, net) 387 | -------------------------------------------------------------------------------- /microwakeword/model_train_eval.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Google Research 
Authors. 3 | # Modifications copyright 2024 Kevin Ahrendt. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import argparse 18 | import os 19 | import sys 20 | import yaml 21 | import platform 22 | from absl import logging 23 | 24 | import tensorflow as tf 25 | 26 | # Disable GPU by default on ARM Macs, it's slower than just using the CPU 27 | if os.environ.get("CUDA_VISIBLE_DEVICES") == "-1" or ( 28 | sys.platform == "darwin" 29 | and platform.processor() == "arm" 30 | and "CUDA_VISIBLE_DEVICES" not in os.environ 31 | ): 32 | tf.config.set_visible_devices([], "GPU") 33 | 34 | import microwakeword.data as input_data 35 | import microwakeword.train as train 36 | import microwakeword.test as test 37 | import microwakeword.utils as utils 38 | 39 | import microwakeword.inception as inception 40 | import microwakeword.mixednet as mixednet 41 | 42 | from microwakeword.layers import modes 43 | 44 | 45 | def load_config(flags, model_module): 46 | """Loads the training configuration from the specified yaml file. 47 | 48 | Args: 49 | flags (argparse.Namespace): command line flags 50 | model_module (module): python module for loading the model 51 | 52 | Returns: 53 | dict: dictionary containing training configuration 54 | """ 55 | config_filename = flags.training_config 56 | config = yaml.load(open(config_filename, "r").read(), yaml.Loader) 57 | 58 | config["summaries_dir"] = os.path.join(config["train_dir"], "logs/") 59 | 60 | config["stride"] = flags.__dict__.get("stride", 1) 61 | config["window_step_ms"] = config.get("window_step_ms", 20) 62 | 63 | # Default preprocessor settings 64 | preprocessor_sample_rate = 16000 # Hz 65 | preprocessor_window_size = 30 # ms 66 | preprocessor_window_step = config["window_step_ms"] # ms 67 | 68 | desired_samples = int(preprocessor_sample_rate * config["clip_duration_ms"] / 1000) 69 | 70 | window_size_samples = int( 71 | preprocessor_sample_rate * preprocessor_window_size / 1000 72 | ) 73 | window_step_samples = int( 74 | config["stride"] * preprocessor_sample_rate * preprocessor_window_step / 1000 75 | ) 76 | 77 | length_minus_window = desired_samples - window_size_samples 78 | 79 | if length_minus_window < 0: 80 | config["spectrogram_length_final_layer"] = 0 81 | else: 82 | config["spectrogram_length_final_layer"] = 1 + int( 83 | length_minus_window / window_step_samples 84 | ) 85 | 86 | config["spectrogram_length"] = config[ 87 | "spectrogram_length_final_layer" 88 | ] + model_module.spectrogram_slices_dropped(flags) 89 | 90 | config["flags"] = flags.__dict__ 91 | 92 | config["training_input_shape"] = modes.get_input_data_shape( 93 | config, modes.Modes.TRAINING 94 | ) 95 | 96 | return config 97 | 98 | 99 | def train_model(config, model, data_processor, restore_checkpoint): 100 | """Trains a model. 
101 | 102 | Args: 103 | config (dict): dictionary containing training configuration 104 | model (Keras model): model architecture to train 105 | data_processor (FeatureHandler): feature handler that loads spectrogram data 106 | restore_checkpoint (bool): Whether to restore from checkpoint if model exists 107 | 108 | Raises: 109 | ValueError: If the model exists but the training flag isn't set 110 | """ 111 | try: 112 | os.makedirs(config["train_dir"]) 113 | os.mkdir(config["summaries_dir"]) 114 | except OSError: 115 | if restore_checkpoint: 116 | pass 117 | else: 118 | raise ValueError( 119 | "model already exists in folder %s" % config["train_dir"] 120 | ) from None 121 | config_fname = os.path.join(config["train_dir"], "training_config.yaml") 122 | 123 | with open(config_fname, "w") as outfile: 124 | yaml.dump(config, outfile, default_flow_style=False) 125 | 126 | utils.save_model_summary(model, config["train_dir"]) 127 | 128 | train.train(model, config, data_processor) 129 | 130 | 131 | def evaluate_model( 132 | config, 133 | model, 134 | data_processor, 135 | test_tf_nonstreaming, 136 | test_tflite_nonstreaming, 137 | test_tflite_nonstreaming_quantized, 138 | test_tflite_streaming, 139 | test_tflite_streaming_quantized, 140 | ): 141 | """Evaluates a model on test data. 142 | 143 | Saves the nonstreaming model or streaming model in SavedModel format, 144 | then converts it to TFLite as specified. 145 | 146 | Args: 147 | config (dict): dictionary containing training configuration 148 | model (Keras model): model (with loaded weights) to test 149 | data_processor (FeatureHandler): feature handler that loads spectrogram data 150 | test_tf_nonstreaming (bool): Evaluate the nonstreaming SavedModel 151 | test_tflite_nonstreaming_quantized (bool): Convert and evaluate quantized nonstreaming TFLite model 152 | test_tflite_nonstreaming (bool): Convert and evaluate nonstreaming TFLite model 153 | test_tflite_streaming (bool): Convert and evaluate streaming TFLite model 154 | test_tflite_streaming_quantized (bool): Convert and evaluate quantized streaming TFLite model 155 | """ 156 | 157 | if ( 158 | test_tf_nonstreaming 159 | or test_tflite_nonstreaming 160 | or test_tflite_nonstreaming_quantized 161 | ): 162 | # Save the nonstreaming model to disk 163 | logging.info("Saving nonstreaming model") 164 | 165 | utils.convert_model_saved( 166 | model, 167 | config, 168 | folder="non_stream", 169 | mode=modes.Modes.NON_STREAM_INFERENCE, 170 | ) 171 | 172 | if test_tflite_streaming or test_tflite_streaming_quantized: 173 | # Save the internal streaming model to disk 174 | logging.info("Saving streaming model") 175 | 176 | utils.convert_model_saved( 177 | model, 178 | config, 179 | folder="stream_state_internal", 180 | mode=modes.Modes.STREAM_INTERNAL_STATE_INFERENCE, 181 | ) 182 | 183 | if test_tf_nonstreaming: 184 | logging.info("Testing nonstreaming model") 185 | 186 | folder_name = "non_stream" 187 | test.tf_model_accuracy( 188 | config, 189 | folder_name, 190 | data_processor, 191 | data_set="testing", 192 | accuracy_name="testing_set_metrics.txt", 193 | ) 194 | 195 | tflite_configs = [] 196 | 197 | if test_tflite_nonstreaming: 198 | tflite_configs.append( 199 | { 200 | "log_string": "nonstreaming model", 201 | "source_folder": "non_stream", 202 | "output_folder": "tflite_non_stream", 203 | "filename": "non_stream.tflite", 204 | "testing_dataset": "testing", 205 | "testing_ambient_dataset": "testing_ambient", 206 | "quantize": False, 207 | } 208 | ) 209 | 210 | if test_tflite_nonstreaming_quantized: 
211 | tflite_configs.append( 212 | { 213 | "log_string": "quantized nonstreaming model", 214 | "source_folder": "non_stream", 215 | "output_folder": "tflite_non_stream_quant", 216 | "filename": "non_stream_quant.tflite", 217 | "testing_dataset": "testing", 218 | "testing_ambient_dataset": "testing_ambient", 219 | "quantize": True, 220 | } 221 | ) 222 | 223 | if test_tflite_streaming: 224 | tflite_configs.append( 225 | { 226 | "log_string": "streaming model", 227 | "source_folder": "stream_state_internal", 228 | "output_folder": "tflite_stream_state_internal", 229 | "filename": "stream_state_internal.tflite", 230 | "testing_dataset": "testing", 231 | "testing_ambient_dataset": "testing_ambient", 232 | "quantize": False, 233 | } 234 | ) 235 | 236 | if test_tflite_streaming_quantized: 237 | tflite_configs.append( 238 | { 239 | "log_string": "quantized streaming model", 240 | "source_folder": "stream_state_internal", 241 | "output_folder": "tflite_stream_state_internal_quant", 242 | "filename": "stream_state_internal_quant.tflite", 243 | "testing_dataset": "testing", 244 | "testing_ambient_dataset": "testing_ambient", 245 | "quantize": True, 246 | } 247 | ) 248 | 249 | for tflite_config in tflite_configs: 250 | logging.info("Converting %s to TFLite", tflite_config["log_string"]) 251 | 252 | utils.convert_saved_model_to_tflite( 253 | config, 254 | audio_processor=data_processor, 255 | path_to_model=os.path.join(config["train_dir"], tflite_config["source_folder"]), 256 | folder=os.path.join(config["train_dir"], tflite_config["output_folder"]), 257 | fname=tflite_config["filename"], 258 | quantize=tflite_config["quantize"], 259 | ) 260 | 261 | logging.info( 262 | "Testing the TFLite %s false accept per hour and false rejection rates at various cutoffs.", 263 | tflite_config["log_string"], 264 | ) 265 | 266 | test.tflite_streaming_model_roc( 267 | config, 268 | tflite_config["output_folder"], 269 | data_processor, 270 | data_set=tflite_config["testing_dataset"], 271 | ambient_set=tflite_config["testing_ambient_dataset"], 272 | tflite_model_name=tflite_config["filename"], 273 | accuracy_name="tflite_streaming_roc.txt", 274 | ) 275 | 276 | 277 | if __name__ == "__main__": 278 | parser = argparse.ArgumentParser() 279 | parser.add_argument( 280 | "--training_config", 281 | type=str, 282 | default="trained_models/model/training_parameters.yaml", 283 | help="""\ 284 | Path to the training parameters yaml configuration.action= 285 | """, 286 | ) 287 | parser.add_argument( 288 | "--train", 289 | type=int, 290 | default=1, 291 | help="If 1 run train and test, else run only test", 292 | ) 293 | parser.add_argument( 294 | "--test_tf_nonstreaming", 295 | type=int, 296 | default=0, 297 | help="Save the nonstreaming model and test on the test datasets", 298 | ) 299 | parser.add_argument( 300 | "--test_tflite_nonstreaming", 301 | type=int, 302 | default=0, 303 | help="Save the TFLite nonstreaming model and test on the test datasets", 304 | ) 305 | parser.add_argument( 306 | "--test_tflite_nonstreaming_quantized", 307 | type=int, 308 | default=0, 309 | help="Save the TFLite quantized nonstreaming model and test on the test datasets", 310 | ) 311 | parser.add_argument( 312 | "--test_tflite_streaming", 313 | type=int, 314 | default=0, 315 | help="Save the (non-quantized) streaming model and test on the test datasets", 316 | ) 317 | parser.add_argument( 318 | "--test_tflite_streaming_quantized", 319 | type=int, 320 | default=1, 321 | help="Save the quantized streaming model and test on the test datasets", 322 | ) 
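# For context, a typical end-to-end invocation of this script combines the flags
# defined here with a model sub-command. The command below is only a sketch: the
# YAML path is a placeholder and the hyperparameter values simply echo the
# mixednet defaults.
#
#   python -m microwakeword.model_train_eval \
#       --training_config=training_parameters.yaml \
#       --train 1 \
#       --test_tflite_streaming_quantized 1 \
#       mixednet --pointwise_filters "48,48,48,48" --mixconv_kernel_sizes "[5],[9],[13],[21]"
#
# The sub-command selects the architecture (inception or mixednet) and exposes
# that model's hyperparameters via its model_parameters() helper.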
323 | parser.add_argument( 324 | "--restore_checkpoint", 325 | type=int, 326 | default=0, 327 | help="If 1 it will restore a checkpoint and resume the training " 328 | "by initializing model weights and optimizer with checkpoint values. " 329 | "It will use learning rate and number of training iterations from " 330 | "--learning_rate and --how_many_training_steps accordinlgy. " 331 | "This option is useful in cases when training was interrupted. " 332 | "With it you should adjust learning_rate and how_many_training_steps.", 333 | ) 334 | parser.add_argument( 335 | "--use_weights", 336 | type=str, 337 | default="best_weights", 338 | help="Which set of weights to use when creating the model" 339 | "One of `best_weights`` or `last_weights`.", 340 | ) 341 | 342 | # Function used to parse --verbosity argument 343 | def verbosity_arg(value): 344 | """Parses verbosity argument. 345 | 346 | Args: 347 | value: A member of tf.logging. 348 | 349 | Returns: 350 | TF logging mode 351 | 352 | Raises: 353 | ArgumentTypeError: Not an expected value. 354 | """ 355 | value = value.upper() 356 | if value == "INFO": 357 | return logging.INFO 358 | elif value == "DEBUG": 359 | return logging.DEBUG 360 | elif value == "ERROR": 361 | return logging.ERROR 362 | elif value == "FATAL": 363 | return logging.FATAL 364 | elif value == "WARN": 365 | return logging.WARN 366 | else: 367 | raise argparse.ArgumentTypeError("Not an expected value") 368 | 369 | parser.add_argument( 370 | "--verbosity", 371 | type=verbosity_arg, 372 | default=logging.INFO, 373 | help='Log verbosity. Can be "INFO", "DEBUG", "ERROR", "FATAL", or "WARN"', 374 | ) 375 | 376 | # sub parser for model settings 377 | subparsers = parser.add_subparsers(dest="model_name", help="NN model name") 378 | 379 | # inception model settings 380 | parser_inception = subparsers.add_parser("inception") 381 | inception.model_parameters(parser_inception) 382 | 383 | # mixednet model settings 384 | parser_mixednet = subparsers.add_parser("mixednet") 385 | mixednet.model_parameters(parser_mixednet) 386 | 387 | flags, unparsed = parser.parse_known_args() 388 | if unparsed: 389 | raise ValueError("Unknown argument: {}".format(unparsed)) 390 | 391 | if flags.model_name == "inception": 392 | model_module = inception 393 | elif flags.model_name == "mixednet": 394 | model_module = mixednet 395 | else: 396 | raise ValueError("Unknown model type: {}".format(flags.model_name)) 397 | 398 | logging.set_verbosity(flags.verbosity) 399 | 400 | config = load_config(flags, model_module) 401 | 402 | data_processor = input_data.FeatureHandler(config) 403 | 404 | if flags.train: 405 | model = model_module.model( 406 | flags, config["training_input_shape"], config["batch_size"] 407 | ) 408 | logging.info(model.summary()) 409 | train_model(config, model, data_processor, flags.restore_checkpoint) 410 | else: 411 | if not os.path.isdir(config["train_dir"]): 412 | raise ValueError('model is not trained set "--train 1" and retrain it') 413 | 414 | if ( 415 | flags.test_tf_nonstreaming 416 | or flags.test_tflite_nonstreaming 417 | or flags.test_tflite_streaming 418 | or flags.test_tflite_streaming_quantized 419 | ): 420 | model = model_module.model( 421 | flags, shape=config["training_input_shape"], batch_size=1 422 | ) 423 | 424 | model.load_weights( 425 | os.path.join(config["train_dir"], flags.use_weights) + ".weights.h5" 426 | ) 427 | 428 | logging.info(model.summary()) 429 | 430 | evaluate_model( 431 | config, 432 | model, 433 | data_processor, 434 | flags.test_tf_nonstreaming, 435 | 
flags.test_tflite_nonstreaming, 436 | flags.test_tflite_nonstreaming_quantized, 437 | flags.test_tflite_streaming, 438 | flags.test_tflite_streaming_quantized, 439 | ) 440 | -------------------------------------------------------------------------------- /microwakeword/test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Google Research Authors. 3 | # Modifications copyright 2024 Kevin Ahrendt. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Test utility functions for accuracy evaluation.""" 18 | 19 | import os 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from absl import logging 25 | from typing import List 26 | from microwakeword.inference import Model 27 | from numpy.lib.stride_tricks import sliding_window_view 28 | 29 | 30 | def compute_metrics(true_positives, true_negatives, false_positives, false_negatives): 31 | """Utility function to compute various metrics. 32 | 33 | Arguments: 34 | true_positives: Count of samples correctly predicted as positive 35 | true_negatives: Count of samples correctly predicted as negative 36 | false_positives: Count of samples incorrectly predicted as positive 37 | false_negatives: Count of samples incorrectly predicted as negative 38 | 39 | Returns: 40 | metric dictionary with keys for `accuracy`, `recall`, `precision`, `false_positive_rate`, `false_negative_rate`, and `count` 41 | """ 42 | 43 | accuracy = float("nan") 44 | false_positive_rate = float("nan") 45 | false_negative_rate = float("nan") 46 | recall = float("nan") 47 | precision = float("nan") 48 | 49 | count = true_positives + true_negatives + false_positives + false_negatives 50 | 51 | if true_positives + true_negatives + false_positives + false_negatives > 0: 52 | accuracy = (true_positives + true_negatives) / count 53 | 54 | if false_positives + true_negatives > 0: 55 | false_positive_rate = false_positives / (false_positives + true_negatives) 56 | 57 | if true_positives + false_negatives > 0: 58 | false_negative_rate = false_negatives / (true_positives + false_negatives) 59 | recall = true_positives / (true_positives + false_negatives) 60 | 61 | if (true_positives + false_positives) > 0: 62 | precision = true_positives / (true_positives + false_positives) 63 | 64 | return { 65 | "accuracy": accuracy, 66 | "recall": recall, 67 | "precision": precision, 68 | "false_positive_rate": false_positive_rate, 69 | "false_negative_rate": false_negative_rate, 70 | "count": count, 71 | } 72 | 73 | 74 | def metrics_to_string(metrics): 75 | """Utility function to return a string that describes various metrics. 
76 | 
77 |     Arguments:
78 |         metrics: metric dictionary with keys for `accuracy`, `recall`, `precision`, `false_positive_rate`, `false_negative_rate`, and `count`
79 | 
80 |     Returns:
81 |         string describing the given metrics
82 |     """
83 | 
84 |     return "accuracy = {accuracy:.4%}; recall = {recall:.4%}; precision = {precision:.4%}; fpr = {fpr:.4%}; fnr = {fnr:.4%}; (N={count})".format(
85 |         accuracy=metrics["accuracy"],
86 |         recall=metrics["recall"],
87 |         precision=metrics["precision"],
88 |         fpr=metrics["false_positive_rate"],
89 |         fnr=metrics["false_negative_rate"],
90 |         count=metrics["count"],
91 |     )
92 | 
93 | 
94 | def compute_false_accepts_per_hour(
95 |     streaming_probabilities_list: List[np.ndarray],
96 |     cutoffs: np.ndarray,
97 |     ignore_slices_after_accept: int = 75,
98 |     stride: int = 1,
99 |     step_s: float = 0.02,
100 | ):
101 |     """Computes the false accept per hour rates at various cutoffs given a list of streaming probabilities.
102 | 
103 |     Args:
104 |         streaming_probabilities_list (List[numpy.ndarray]): A list containing streaming probabilities from negative audio clips
105 |         cutoffs (numpy.ndarray): An array of cutoffs/thresholds to test the false accept rate at.
106 |         ignore_slices_after_accept (int, optional): The number of probability slices to ignore after a false accept. Defaults to 75.
107 |         stride (int, optional): The stride of the input layer. Defaults to 1.
108 |         step_s (float, optional): The duration between each probability in seconds. Defaults to 0.02.
109 | 
110 |     Returns:
111 |         numpy.ndarray: The false accepts per hour corresponding to thresholds in `cutoffs`.
112 |     """
113 |     cutoffs_count = cutoffs.shape[0]
114 | 
115 |     false_accepts_at_cutoffs = np.zeros(cutoffs_count)
116 |     probabilities_duration_h = 0
117 | 
118 |     for track_probabilities in streaming_probabilities_list:
119 |         probabilities_duration_h += len(track_probabilities) * stride * step_s / 3600.0
120 | 
121 |         cooldown_at_cutoffs = np.ones(cutoffs_count) * ignore_slices_after_accept
122 | 
123 |         for wakeword_probability in track_probabilities:
124 |             # Decrease the cooldown at each cutoff by 1, with a minimum value of 0
125 |             cooldown_at_cutoffs = np.maximum(
126 |                 cooldown_at_cutoffs - 1, np.zeros(cutoffs_count)
127 |             )
128 |             detection_boolean = (
129 |                 wakeword_probability > cutoffs
130 |             )  # a list of detection states at each cutoff
131 | 
132 |             for index in range(cutoffs_count):
133 |                 if cooldown_at_cutoffs[index] == 0 and detection_boolean[index]:
134 |                     false_accepts_at_cutoffs[index] += 1
135 |                     cooldown_at_cutoffs[index] = ignore_slices_after_accept
136 | 
137 |     return false_accepts_at_cutoffs / probabilities_duration_h
138 | 
139 | 
140 | def generate_roc_curve(
141 |     false_accepts_per_hour: np.ndarray,
142 |     false_rejections: np.ndarray,
143 |     # positive_samples_probabilities: np.ndarray,
144 |     cutoffs: np.ndarray,
145 |     max_faph: float = 2.0,
146 | ):
147 |     """Generates the coordinates for an ROC curve plotting false accepts per hour vs false rejections. Computes the false rejection rate at the specified cutoffs.
148 | 
149 |     Args:
150 |         false_accepts_per_hour (numpy.ndarray): False accepts per hour rates for each threshold in `cutoffs`.
151 |         false_rejections (numpy.ndarray): False rejection rates for each threshold in `cutoffs`.
152 |         cutoffs (numpy.ndarray): Thresholds used for `false_accepts_per_hour`
153 |         max_faph (float, optional): The maximum false accept per hour rate to include in curve's coordinates. Defaults to 2.0.
154 | 
155 |     Returns:
156 |         (numpy.ndarray, numpy.ndarray, numpy.ndarray): (false accept per hour coordinates, false rejection rate coordinates, cutoffs for each coordinate)
157 |     """
158 | 
159 |     if false_accepts_per_hour[0] > max_faph:
160 |         # Use linear interpolation to estimate false negative rate at max_faph
161 | 
162 |         # Increase the index until we find a faph less than max_faph
163 |         index_of_first_viable = 1
164 |         while false_accepts_per_hour[index_of_first_viable] > max_faph:
165 |             index_of_first_viable += 1
166 | 
167 |         x0 = false_accepts_per_hour[index_of_first_viable - 1]
168 |         y0 = false_rejections[index_of_first_viable - 1]
169 |         x1 = false_accepts_per_hour[index_of_first_viable]
170 |         y1 = false_rejections[index_of_first_viable]
171 | 
172 |         fnr_at_max_faph = (y0 * (x1 - max_faph) + y1 * (max_faph - x0)) / (x1 - x0)
173 |         cutoff_at_max_faph = (
174 |             cutoffs[index_of_first_viable] + cutoffs[index_of_first_viable - 1]
175 |         ) / 2.0
176 |     else:
177 |         # Smallest faph is less than max_faph, so assume the false negative rate is constant
178 |         index_of_first_viable = 0
179 |         fnr_at_max_faph = false_rejections[index_of_first_viable]
180 |         cutoff_at_max_faph = cutoffs[index_of_first_viable]
181 | 
182 |     horizontal_coordinates = [max_faph]
183 |     vertical_coordinates = [fnr_at_max_faph]
184 |     cutoffs_at_coordinate = [cutoff_at_max_faph]
185 | 
186 |     for index in range(index_of_first_viable, len(false_rejections)):
187 |         if false_accepts_per_hour[index] != horizontal_coordinates[-1]:
188 |             # Only add a point if it is a new faph
189 |             # This ensures if a faph rate is repeated, we use the smallest false negative rate
190 |             horizontal_coordinates.append(false_accepts_per_hour[index])
191 |             vertical_coordinates.append(false_rejections[index])
192 |             cutoffs_at_coordinate.append(cutoffs[index])
193 | 
194 |     if horizontal_coordinates[-1] > 0:
195 |         # If there isn't a cutoff with 0 faph, then add a coordinate at (0,1)
196 |         horizontal_coordinates.append(0.0)
197 |         vertical_coordinates.append(1.0)
198 |         cutoffs_at_coordinate.append(0.0)
199 | 
200 |     # The points on the curve are listed in descending order, flip them before returning
201 |     horizontal_coordinates = np.flip(horizontal_coordinates)
202 |     vertical_coordinates = np.flip(vertical_coordinates)
203 |     cutoffs_at_coordinate = np.flip(cutoffs_at_coordinate)
204 |     return horizontal_coordinates, vertical_coordinates, cutoffs_at_coordinate
205 | 
206 | 
207 | def tf_model_accuracy(
208 |     config,
209 |     folder,
210 |     audio_processor,
211 |     data_set="testing",
212 |     accuracy_name="tf_model_accuracy.txt",
213 | ):
214 |     """Function to test a TF model on a specified data set.
215 | 
216 |     NOTE: This assumes the wakeword is at the end of the spectrogram. The ``tflite_streaming_model_roc`` method does not make this assumption, and you may get vastly different results depending on how the word is positioned in the spectrogram in the data set.
217 | 218 | Arguments: 219 | config: dictionary containing microWakeWord training configuration 220 | folder: folder containing the TF model 221 | audio_processor: microWakeWord FeatureHandler object for retrieving spectrograms 222 | data_set: data set to test the model on 223 | accuracy_name: filename to save metrics to 224 | 225 | Returns: 226 | metric dictionary with keys for `accuracy`, `recall`, `precision`, `false_positive_rate`, `false_negative_rate`, and `count` 227 | """ 228 | 229 | test_fingerprints, test_ground_truth, _ = audio_processor.get_data( 230 | data_set, 231 | batch_size=config["batch_size"], 232 | features_length=config["spectrogram_length"], 233 | truncation_strategy="truncate_start", 234 | ) 235 | 236 | with tf.device("/cpu:0"): 237 | model = tf.saved_model.load(os.path.join(config["train_dir"], folder)) 238 | inference_batch_size = 1 239 | 240 | true_positives = 0 241 | true_negatives = 0 242 | false_positives = 0 243 | false_negatives = 0 244 | 245 | for i in range(0, len(test_fingerprints), inference_batch_size): 246 | spectrogram_features = test_fingerprints[i : i + inference_batch_size] 247 | sample_ground_truth = test_ground_truth[i] 248 | 249 | result = model(tf.convert_to_tensor(spectrogram_features, dtype=tf.float32)) 250 | 251 | prediction = result.numpy()[0][0] > 0.5 252 | if sample_ground_truth == prediction: 253 | if sample_ground_truth: 254 | true_positives += 1 255 | else: 256 | true_negatives += 1 257 | else: 258 | if sample_ground_truth: 259 | false_negatives += 1 260 | else: 261 | false_positives += 1 262 | 263 | metrics = compute_metrics( 264 | true_positives, true_negatives, false_positives, false_negatives 265 | ) 266 | 267 | if i % 1000 == 0 and i: 268 | logging.info( 269 | "TensorFlow model on the {dataset} set: accuracy = {accuracy:.6}; recall = {recall:.6}; precision = {precision:.6}; fpr = {fpr:.6}; fnr = {fnr:.6} ({iteration} out of {length})".format( 270 | dataset=data_set, 271 | accuracy=metrics["accuracy"], 272 | recall=metrics["recall"], 273 | precision=metrics["precision"], 274 | fpr=metrics["false_positive_rate"], 275 | fnr=metrics["false_negative_rate"], 276 | iteration=i, 277 | length=len(test_fingerprints), 278 | ) 279 | ) 280 | 281 | metrics_string = metrics_to_string(metrics) 282 | 283 | logging.info( 284 | "Final TensorFlow model on the " + data_set + " set: " + metrics_string 285 | ) 286 | 287 | path = os.path.join(config["train_dir"], folder) 288 | with open(os.path.join(path, accuracy_name), "wt") as fd: 289 | fd.write(metrics_string) 290 | return metrics 291 | 292 | 293 | def tflite_streaming_model_roc( 294 | config, 295 | folder, 296 | audio_processor, 297 | data_set="testing", 298 | ambient_set="testing_ambient", 299 | tflite_model_name="stream_state_internal.tflite", 300 | accuracy_name="tflite_streaming_roc.txt", 301 | sliding_window_length=5, 302 | ignore_slices_after_accept=25, 303 | ): 304 | """Function to test a tflite model false accepts per hour and false rejection rates. 305 | 306 | Model can be streaming or nonstreaming. Nonstreaming models are strided by 1 spectrogram feature in the time dimension. 307 | 308 | Args: 309 | config (dict): dictionary containing microWakeWord training configuration 310 | folder (str): folder containing the TFLite model 311 | audio_processor (FeatureHandler): microWakeWord FeatureHandler object for retrieving spectrograms 312 | data_set (str, optional): Dataset for testing recall. Defaults to "testing". 313 | ambient_set (str, optional): Dataset for testing false accepts per hour. 
Defaults to "testing_ambient". 314 | tflite_model_name (str, optional): filename of the TFLite model. Defaults to "stream_state_internal.tflite". 315 | accuracy_name (str, optional): filename to save metrics at various cutoffs. Defaults to "tflite_streaming_roc.txt". 316 | sliding_window_length (int, optional): the length of the sliding window for computing average probabilities. Defaults to 1. 317 | 318 | Returns: 319 | float: The Area under the false accept per hour vs. false rejection curve. 320 | """ 321 | stride = config["stride"] 322 | model = Model( 323 | os.path.join(config["train_dir"], folder, tflite_model_name), stride=stride 324 | ) 325 | 326 | test_ambient_fingerprints, _, _ = audio_processor.get_data( 327 | ambient_set, 328 | batch_size=config["batch_size"], 329 | features_length=config["spectrogram_length"], 330 | truncation_strategy="none", 331 | ) 332 | 333 | logging.info("Testing the " + ambient_set + " set.") 334 | ambient_streaming_probabilities = [] 335 | for spectrogram_track in test_ambient_fingerprints: 336 | streaming_probabilities = model.predict_spectrogram(spectrogram_track) 337 | sliding_window_probabilities = sliding_window_view( 338 | streaming_probabilities, sliding_window_length 339 | ) 340 | moving_average = sliding_window_probabilities.mean(axis=-1) 341 | ambient_streaming_probabilities.append(moving_average) 342 | 343 | cutoffs = np.arange(0, 1.01, 0.01) 344 | # ignore_slices_after_accept = 25 345 | 346 | faph = compute_false_accepts_per_hour( 347 | ambient_streaming_probabilities, 348 | cutoffs, 349 | ignore_slices_after_accept, 350 | stride=config["stride"], 351 | step_s=config["window_step_ms"] / 1000, 352 | ) 353 | 354 | test_fingerprints, test_ground_truth, _ = audio_processor.get_data( 355 | data_set, 356 | batch_size=config["batch_size"], 357 | features_length=config["spectrogram_length"], 358 | truncation_strategy="none", 359 | ) 360 | 361 | logging.info("Testing the " + data_set + " set.") 362 | 363 | positive_sample_streaming_probabilities = [] 364 | for i in range(len(test_fingerprints)): 365 | if test_ground_truth[i]: 366 | # Only test positive samples 367 | streaming_probabilities = model.predict_spectrogram(test_fingerprints[i]) 368 | sliding_window_probabilities = sliding_window_view( 369 | streaming_probabilities[ignore_slices_after_accept:], 370 | sliding_window_length, 371 | ) 372 | moving_average = sliding_window_probabilities.mean(axis=-1) 373 | positive_sample_streaming_probabilities.append(np.max(moving_average)) 374 | 375 | # Compute the false negative rates at each cutoff 376 | false_negative_rate_at_cutoffs = [] 377 | for cutoff in cutoffs: 378 | true_accepts = sum(i > cutoff for i in positive_sample_streaming_probabilities) 379 | false_negative_rate_at_cutoffs.append( 380 | 1 - true_accepts / len(positive_sample_streaming_probabilities) 381 | ) 382 | 383 | x_coordinates, y_coordinates, cutoffs_at_points = generate_roc_curve( 384 | false_accepts_per_hour=faph, 385 | false_rejections=false_negative_rate_at_cutoffs, 386 | cutoffs=cutoffs, 387 | ) 388 | 389 | path = os.path.join(config["train_dir"], folder) 390 | with open(os.path.join(path, accuracy_name), "wt") as fd: 391 | auc = np.trapz(y_coordinates, x_coordinates) 392 | auc_string = "AUC {:.5f}".format(auc) 393 | logging.info(auc_string) 394 | fd.write(auc_string + "\n") 395 | 396 | for i in range(0, x_coordinates.shape[0]): 397 | cutoff_string = "Cutoff {:.2f}: frr={:.4f}; faph={:.3f}".format( 398 | cutoffs_at_points[i], y_coordinates[i], x_coordinates[i] 399 | ) 400 | 
logging.info(cutoff_string) 401 | fd.write(cutoff_string + "\n") 402 | 403 | return auc 404 | 405 | 406 | def tflite_model_accuracy( 407 | config, 408 | folder, 409 | audio_processor, 410 | data_set="testing", 411 | tflite_model_name="stream_state_internal.tflite", 412 | accuracy_name="tflite_model_accuracy.txt", 413 | ): 414 | """Function to test a TFLite model on a specified data set. 415 | 416 | NOTE: This assumes the wakeword is at the end of the spectrogram. The ``tflite_streaming_model_roc`` method does not make this assumption, and you may get vastly different results depending on how word is positioned in the spectrogram in the data set. 417 | 418 | Model can be streaming or nonstreaming. If tested on an "_ambient" set, 419 | it detects a false accept if the previous probability was less than 0.5 420 | and the current probability is greater than 0.5. 421 | 422 | Arguments: 423 | config: dictionary containing microWakeWord training configuration 424 | folder: folder containing the TFLite model 425 | audio_processor: microWakeWord FeatureHandler object for retrieving spectrograms 426 | data_set: data set to test the model on 427 | tflite_model_name: filename of the TFLite model 428 | accuracy_name: filename to save metrics to 429 | 430 | Returns: 431 | Metric dictionary with keys for `accuracy`, `recall`, `precision`, `false_positive_rate`, `false_negative_rate`, and `count` 432 | """ 433 | 434 | model = Model(os.path.join(config["train_dir"], folder, tflite_model_name)) 435 | 436 | truncation_strategy = "truncate_start" 437 | if data_set.endswith("ambient"): 438 | truncation_strategy = "none" 439 | 440 | test_fingerprints, test_ground_truth, _ = audio_processor.get_data( 441 | data_set, 442 | batch_size=config["batch_size"], 443 | features_length=config["spectrogram_length"], 444 | truncation_strategy=truncation_strategy, 445 | ) 446 | 447 | logging.info("Testing TFLite model on the {data_set} set".format(data_set=data_set)) 448 | 449 | true_positives = 0.0 450 | true_negatives = 0.0 451 | false_positives = 0.0 452 | false_negatives = 0.0 453 | 454 | previous_probability = 0.0 455 | 456 | for i in range(0, len(test_fingerprints)): 457 | sample_fingerprint = test_fingerprints[i].astype(np.float32) 458 | sample_ground_truth = test_ground_truth[i] 459 | 460 | probabilities = model.predict_spectrogram(sample_fingerprint) 461 | 462 | if truncation_strategy != "none": 463 | prediction = probabilities[-1] > 0.5 464 | if sample_ground_truth == prediction: 465 | if sample_ground_truth: 466 | true_positives += 1 467 | else: 468 | true_negatives += 1 469 | else: 470 | if sample_ground_truth: 471 | false_negatives += 1 472 | else: 473 | false_positives += 1 474 | else: 475 | previous_probability = 0 476 | last_positive_index = 0 477 | for index, prob in enumerate(probabilities): 478 | if (previous_probability <= 0.5 and prob > 0.5) and ( 479 | index - last_positive_index 480 | > config["spectrogram_length_final_layer"] 481 | ): 482 | false_positives += 1 483 | last_positive_index = index 484 | previous_probability = prob 485 | 486 | metrics = compute_metrics( 487 | true_positives, true_negatives, false_positives, false_negatives 488 | ) 489 | 490 | if i % 1000 == 0 and i: 491 | logging.info( 492 | "TFLite model on the {dataset} set: accuracy = {accuracy:.6}; recall = {recall:.6}; precision = {precision:.6}; fpr = {fpr:.6}; fnr = {fnr:.6} ({iteration} out of {length})".format( 493 | dataset=data_set, 494 | accuracy=metrics["accuracy"], 495 | recall=metrics["recall"], 496 | 
precision=metrics["precision"], 497 | fpr=metrics["false_positive_rate"], 498 | fnr=metrics["false_negative_rate"], 499 | iteration=i, 500 | length=len(test_fingerprints), 501 | ) 502 | ) 503 | 504 | if truncation_strategy != "none": 505 | metrics_string = metrics_to_string(metrics) 506 | else: 507 | metrics_string = "false accepts = {false_positives}; false accepts per hour = {faph:.4}".format( 508 | false_positives=false_positives, 509 | faph=false_positives 510 | / (audio_processor.get_mode_duration(data_set) / 3600.0), 511 | ) 512 | 513 | logging.info("Final TFLite model on the " + data_set + " set: " + metrics_string) 514 | path = os.path.join(config["train_dir"], folder) 515 | with open(os.path.join(path, accuracy_name), "wt") as fd: 516 | fd.write(metrics_string) 517 | return metrics 518 | -------------------------------------------------------------------------------- /microwakeword/train.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Google Research Authors. 3 | # Modifications copyright 2024 Kevin Ahrendt. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import os 18 | import platform 19 | import contextlib 20 | 21 | from absl import logging 22 | 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | from tensorflow.python.util import tf_decorator 27 | 28 | 29 | @contextlib.contextmanager 30 | def swap_attribute(obj, attr, temp_value): 31 | """Temporarily swap an attribute of an object.""" 32 | original_value = getattr(obj, attr) 33 | setattr(obj, attr, temp_value) 34 | 35 | try: 36 | yield 37 | finally: 38 | setattr(obj, attr, original_value) 39 | 40 | 41 | def validate_nonstreaming(config, data_processor, model, test_set): 42 | testing_fingerprints, testing_ground_truth, _ = data_processor.get_data( 43 | test_set, 44 | batch_size=config["batch_size"], 45 | features_length=config["spectrogram_length"], 46 | truncation_strategy="truncate_start", 47 | ) 48 | testing_ground_truth = testing_ground_truth.reshape(-1, 1) 49 | 50 | model.reset_metrics() 51 | 52 | result = model.evaluate( 53 | testing_fingerprints, 54 | testing_ground_truth, 55 | batch_size=1024, 56 | return_dict=True, 57 | verbose=0, 58 | ) 59 | 60 | metrics = {} 61 | metrics["accuracy"] = result["accuracy"] 62 | metrics["recall"] = result["recall"] 63 | metrics["precision"] = result["precision"] 64 | 65 | metrics["auc"] = result["auc"] 66 | metrics["loss"] = result["loss"] 67 | metrics["recall_at_no_faph"] = 0 68 | metrics["cutoff_for_no_faph"] = 0 69 | metrics["ambient_false_positives"] = 0 70 | metrics["ambient_false_positives_per_hour"] = 0 71 | metrics["average_viable_recall"] = 0 72 | 73 | test_set_fp = result["fp"].numpy() 74 | 75 | if data_processor.get_mode_size("validation_ambient") > 0: 76 | ( 77 | ambient_testing_fingerprints, 78 | ambient_testing_ground_truth, 79 | _, 80 | ) = data_processor.get_data( 81 | test_set + "_ambient", 82 | batch_size=config["batch_size"], 83 | 
features_length=config["spectrogram_length"], 84 | truncation_strategy="split", 85 | ) 86 | ambient_testing_ground_truth = ambient_testing_ground_truth.reshape(-1, 1) 87 | 88 | # XXX: tf no longer provides a way to evaluate a model without updating metrics 89 | with swap_attribute(model, "reset_metrics", lambda: None): 90 | ambient_predictions = model.evaluate( 91 | ambient_testing_fingerprints, 92 | ambient_testing_ground_truth, 93 | batch_size=1024, 94 | return_dict=True, 95 | verbose=0, 96 | ) 97 | 98 | duration_of_ambient_set = ( 99 | data_processor.get_mode_duration("validation_ambient") / 3600.0 100 | ) 101 | 102 | # Other than the false positive rate, all other metrics are accumulated across 103 | # both test sets 104 | all_true_positives = ambient_predictions["tp"].numpy() 105 | ambient_false_positives = ambient_predictions["fp"].numpy() - test_set_fp 106 | all_false_negatives = ambient_predictions["fn"].numpy() 107 | 108 | metrics["auc"] = ambient_predictions["auc"] 109 | metrics["loss"] = ambient_predictions["loss"] 110 | 111 | recall_at_cutoffs = ( 112 | all_true_positives / (all_true_positives + all_false_negatives) 113 | ) 114 | faph_at_cutoffs = ambient_false_positives / duration_of_ambient_set 115 | 116 | target_faph_cutoff_probability = 1.0 117 | for index, cutoff in enumerate(np.linspace(0.0, 1.0, 101)): 118 | if faph_at_cutoffs[index] == 0: 119 | target_faph_cutoff_probability = cutoff 120 | recall_at_no_faph = recall_at_cutoffs[index] 121 | break 122 | 123 | if faph_at_cutoffs[0] > 2: 124 | # Use linear interpolation to estimate recall at 2 faph 125 | 126 | # Increase index until we find a faph less than 2 127 | index_of_first_viable = 1 128 | while faph_at_cutoffs[index_of_first_viable] > 2: 129 | index_of_first_viable += 1 130 | 131 | x0 = faph_at_cutoffs[index_of_first_viable - 1] 132 | y0 = recall_at_cutoffs[index_of_first_viable - 1] 133 | x1 = faph_at_cutoffs[index_of_first_viable] 134 | y1 = recall_at_cutoffs[index_of_first_viable] 135 | 136 | recall_at_2faph = (y0 * (x1 - 2.0) + y1 * (2.0 - x0)) / (x1 - x0) 137 | else: 138 | # Lowest faph is already under 2, assume the recall is constant before this 139 | index_of_first_viable = 0 140 | recall_at_2faph = recall_at_cutoffs[0] 141 | 142 | x_coordinates = [2.0] 143 | y_coordinates = [recall_at_2faph] 144 | 145 | for index in range(index_of_first_viable, len(recall_at_cutoffs)): 146 | if faph_at_cutoffs[index] != x_coordinates[-1]: 147 | # Only add a point if it is a new faph 148 | # This ensures if a faph rate is repeated, we use the highest recall 149 | x_coordinates.append(faph_at_cutoffs[index]) 150 | y_coordinates.append(recall_at_cutoffs[index]) 151 | 152 | # Use trapezoid rule to estimate the area under the curve, then divide by 2.0 to get the average recall 153 | average_viable_recall = ( 154 | np.trapz(np.flip(y_coordinates), np.flip(x_coordinates)) / 2.0 155 | ) 156 | 157 | metrics["recall_at_no_faph"] = recall_at_no_faph 158 | metrics["cutoff_for_no_faph"] = target_faph_cutoff_probability 159 | metrics["ambient_false_positives"] = ambient_false_positives[50] 160 | metrics["ambient_false_positives_per_hour"] = faph_at_cutoffs[50] 161 | metrics["average_viable_recall"] = average_viable_recall 162 | 163 | return metrics 164 | 165 | 166 | def train(model, config, data_processor): 167 | # Assign default training settings if not set in the configuration yaml 168 | if not (training_steps_list := config.get("training_steps")): 169 | training_steps_list = [20000] 170 | if not (learning_rates_list := 
config.get("learning_rates")): 171 | learning_rates_list = [0.001] 172 | if not (mix_up_prob_list := config.get("mix_up_augmentation_prob")): 173 | mix_up_prob_list = [0.0] 174 | if not (freq_mix_prob_list := config.get("freq_mix_augmentation_prob")): 175 | freq_mix_prob_list = [0.0] 176 | if not (time_mask_max_size_list := config.get("time_mask_max_size")): 177 | time_mask_max_size_list = [5] 178 | if not (time_mask_count_list := config.get("time_mask_count")): 179 | time_mask_count_list = [2] 180 | if not (freq_mask_max_size_list := config.get("freq_mask_max_size")): 181 | freq_mask_max_size_list = [5] 182 | if not (freq_mask_count_list := config.get("freq_mask_count")): 183 | freq_mask_count_list = [2] 184 | if not (positive_class_weight_list := config.get("positive_class_weight")): 185 | positive_class_weight_list = [1.0] 186 | if not (negative_class_weight_list := config.get("negative_class_weight")): 187 | negative_class_weight_list = [1.0] 188 | 189 | # Ensure all training setting lists are as long as the training step iterations 190 | def pad_list_with_last_entry(list_to_pad, desired_length): 191 | while len(list_to_pad) < desired_length: 192 | last_entry = list_to_pad[-1] 193 | list_to_pad.append(last_entry) 194 | 195 | training_step_iterations = len(training_steps_list) 196 | pad_list_with_last_entry(learning_rates_list, training_step_iterations) 197 | pad_list_with_last_entry(mix_up_prob_list, training_step_iterations) 198 | pad_list_with_last_entry(freq_mix_prob_list, training_step_iterations) 199 | pad_list_with_last_entry(time_mask_max_size_list, training_step_iterations) 200 | pad_list_with_last_entry(time_mask_count_list, training_step_iterations) 201 | pad_list_with_last_entry(freq_mask_max_size_list, training_step_iterations) 202 | pad_list_with_last_entry(freq_mask_count_list, training_step_iterations) 203 | pad_list_with_last_entry(positive_class_weight_list, training_step_iterations) 204 | pad_list_with_last_entry(negative_class_weight_list, training_step_iterations) 205 | 206 | loss = tf.keras.losses.BinaryCrossentropy(from_logits=False) 207 | optimizer = tf.keras.optimizers.Adam() 208 | 209 | cutoffs = np.linspace(0.0, 1.0, 101).tolist() 210 | 211 | metrics = [ 212 | tf.keras.metrics.BinaryAccuracy(name="accuracy"), 213 | tf.keras.metrics.Recall(name="recall"), 214 | tf.keras.metrics.Precision(name="precision"), 215 | tf.keras.metrics.TruePositives(name="tp", thresholds=cutoffs), 216 | tf.keras.metrics.FalsePositives(name="fp", thresholds=cutoffs), 217 | tf.keras.metrics.TrueNegatives(name="tn", thresholds=cutoffs), 218 | tf.keras.metrics.FalseNegatives(name="fn", thresholds=cutoffs), 219 | tf.keras.metrics.AUC(name="auc"), 220 | tf.keras.metrics.BinaryCrossentropy(name="loss"), 221 | ] 222 | 223 | model.compile(optimizer=optimizer, loss=loss, metrics=metrics) 224 | 225 | # We un-decorate the `tf.function`, it's very slow to manually run training batches 226 | model.make_train_function() 227 | _, model.train_function = tf_decorator.unwrap(model.train_function) 228 | 229 | # Configure checkpointer and restore if available 230 | checkpoint_directory = os.path.join(config["train_dir"], "restore/") 231 | checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") 232 | checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) 233 | checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory)) 234 | 235 | # Configure TensorBoard summaries 236 | train_writer = tf.summary.create_file_writer( 237 | os.path.join(config["summaries_dir"], "train") 238 | ) 239 
| validation_writer = tf.summary.create_file_writer( 240 | os.path.join(config["summaries_dir"], "validation") 241 | ) 242 | 243 | training_steps_max = np.sum(training_steps_list) 244 | 245 | best_minimization_quantity = 10000 246 | best_maximization_quantity = 0.0 247 | best_no_faph_cutoff = 1.0 248 | 249 | for training_step in range(1, training_steps_max + 1): 250 | training_steps_sum = 0 251 | for i in range(len(training_steps_list)): 252 | training_steps_sum += training_steps_list[i] 253 | if training_step <= training_steps_sum: 254 | learning_rate = learning_rates_list[i] 255 | mix_up_prob = mix_up_prob_list[i] 256 | freq_mix_prob = freq_mix_prob_list[i] 257 | time_mask_max_size = time_mask_max_size_list[i] 258 | time_mask_count = time_mask_count_list[i] 259 | freq_mask_max_size = freq_mask_max_size_list[i] 260 | freq_mask_count = freq_mask_count_list[i] 261 | positive_class_weight = positive_class_weight_list[i] 262 | negative_class_weight = negative_class_weight_list[i] 263 | break 264 | 265 | model.optimizer.learning_rate.assign(learning_rate) 266 | 267 | augmentation_policy = { 268 | "mix_up_prob": mix_up_prob, 269 | "freq_mix_prob": freq_mix_prob, 270 | "time_mask_max_size": time_mask_max_size, 271 | "time_mask_count": time_mask_count, 272 | "freq_mask_max_size": freq_mask_max_size, 273 | "freq_mask_count": freq_mask_count, 274 | } 275 | 276 | ( 277 | train_fingerprints, 278 | train_ground_truth, 279 | train_sample_weights, 280 | ) = data_processor.get_data( 281 | "training", 282 | batch_size=config["batch_size"], 283 | features_length=config["spectrogram_length"], 284 | truncation_strategy="default", 285 | augmentation_policy=augmentation_policy, 286 | ) 287 | 288 | train_ground_truth = train_ground_truth.reshape(-1, 1) 289 | 290 | class_weights = {0: negative_class_weight, 1: positive_class_weight} 291 | combined_weights = train_sample_weights * np.vectorize(class_weights.get)( 292 | train_ground_truth 293 | ) 294 | 295 | result = model.train_on_batch( 296 | train_fingerprints, 297 | train_ground_truth, 298 | sample_weight=combined_weights, 299 | ) 300 | 301 | # Print the running statistics in the current validation epoch 302 | print( 303 | "Validation Batch #{:d}: Accuracy = {:.3f}; Recall = {:.3f}; Precision = {:.3f}; Loss = {:.4f}; Mini-Batch #{:d}".format( 304 | (training_step // config["eval_step_interval"] + 1), 305 | result[1], 306 | result[2], 307 | result[3], 308 | result[9], 309 | (training_step % config["eval_step_interval"]), 310 | ), 311 | end="\r", 312 | ) 313 | 314 | is_last_step = training_step == training_steps_max 315 | if (training_step % config["eval_step_interval"]) == 0 or is_last_step: 316 | logging.info( 317 | "Step #%d: rate %f, accuracy %.2f%%, recall %.2f%%, precision %.2f%%, cross entropy %f", 318 | *( 319 | training_step, 320 | learning_rate, 321 | result[1] * 100, 322 | result[2] * 100, 323 | result[3] * 100, 324 | result[9], 325 | ), 326 | ) 327 | 328 | with train_writer.as_default(): 329 | tf.summary.scalar("loss", result[9], step=training_step) 330 | tf.summary.scalar("accuracy", result[1], step=training_step) 331 | tf.summary.scalar("recall", result[2], step=training_step) 332 | tf.summary.scalar("precision", result[3], step=training_step) 333 | tf.summary.scalar("auc", result[8], step=training_step) 334 | train_writer.flush() 335 | 336 | model.save_weights( 337 | os.path.join(config["train_dir"], "last_weights.weights.h5") 338 | ) 339 | 340 | nonstreaming_metrics = validate_nonstreaming( 341 | config, data_processor, model, "validation" 342 | 
) 343 | model.reset_metrics() # reset metrics for next validation epoch of training 344 | logging.info( 345 | "Step %d (nonstreaming): Validation: recall at no faph = %.3f with cutoff %.2f, accuracy = %.2f%%, recall = %.2f%%, precision = %.2f%%, ambient false positives = %d, estimated false positives per hour = %.5f, loss = %.5f, auc = %.5f, average viable recall = %.9f", 346 | *( 347 | training_step, 348 | nonstreaming_metrics["recall_at_no_faph"] * 100, 349 | nonstreaming_metrics["cutoff_for_no_faph"], 350 | nonstreaming_metrics["accuracy"] * 100, 351 | nonstreaming_metrics["recall"] * 100, 352 | nonstreaming_metrics["precision"] * 100, 353 | nonstreaming_metrics["ambient_false_positives"], 354 | nonstreaming_metrics["ambient_false_positives_per_hour"], 355 | nonstreaming_metrics["loss"], 356 | nonstreaming_metrics["auc"], 357 | nonstreaming_metrics["average_viable_recall"], 358 | ), 359 | ) 360 | 361 | with validation_writer.as_default(): 362 | tf.summary.scalar( 363 | "loss", nonstreaming_metrics["loss"], step=training_step 364 | ) 365 | tf.summary.scalar( 366 | "accuracy", nonstreaming_metrics["accuracy"], step=training_step 367 | ) 368 | tf.summary.scalar( 369 | "recall", nonstreaming_metrics["recall"], step=training_step 370 | ) 371 | tf.summary.scalar( 372 | "precision", nonstreaming_metrics["precision"], step=training_step 373 | ) 374 | tf.summary.scalar( 375 | "recall_at_no_faph", 376 | nonstreaming_metrics["recall_at_no_faph"], 377 | step=training_step, 378 | ) 379 | tf.summary.scalar( 380 | "auc", 381 | nonstreaming_metrics["auc"], 382 | step=training_step, 383 | ) 384 | tf.summary.scalar( 385 | "average_viable_recall", 386 | nonstreaming_metrics["average_viable_recall"], 387 | step=training_step, 388 | ) 389 | validation_writer.flush() 390 | 391 | os.makedirs(os.path.join(config["train_dir"], "train"), exist_ok=True) 392 | 393 | model.save_weights( 394 | os.path.join( 395 | config["train_dir"], 396 | "train", 397 | f"{int(best_minimization_quantity * 10000)}_weights_{training_step}.weights.h5", 398 | ) 399 | ) 400 | 401 | current_minimization_quantity = 0.0 402 | if config["minimization_metric"] is not None: 403 | current_minimization_quantity = nonstreaming_metrics[ 404 | config["minimization_metric"] 405 | ] 406 | current_maximization_quantity = nonstreaming_metrics[ 407 | config["maximization_metric"] 408 | ] 409 | current_no_faph_cutoff = nonstreaming_metrics["cutoff_for_no_faph"] 410 | 411 | # Save model weights if this is a new best model 412 | if ( 413 | ( 414 | ( 415 | current_minimization_quantity <= config["target_minimization"] 416 | ) # achieved target false positive rate 417 | and ( 418 | ( 419 | current_maximization_quantity > best_maximization_quantity 420 | ) # either accuracy improved 421 | or ( 422 | best_minimization_quantity > config["target_minimization"] 423 | ) # or this is the first time we met the target 424 | ) 425 | ) 426 | or ( 427 | ( 428 | current_minimization_quantity > config["target_minimization"] 429 | ) # we haven't achieved our target 430 | and ( 431 | current_minimization_quantity < best_minimization_quantity 432 | ) # but we have decreased since the previous best 433 | ) 434 | or ( 435 | ( 436 | current_minimization_quantity == best_minimization_quantity 437 | ) # we tied a previous best 438 | and ( 439 | current_maximization_quantity > best_maximization_quantity 440 | ) # and we increased our accuracy 441 | ) 442 | ): 443 | best_minimization_quantity = current_minimization_quantity 444 | best_maximization_quantity = 
current_maximization_quantity 445 | best_no_faph_cutoff = current_no_faph_cutoff 446 | 447 | # overwrite the best model weights 448 | model.save_weights( 449 | os.path.join(config["train_dir"], "best_weights.weights.h5") 450 | ) 451 | checkpoint.save(file_prefix=checkpoint_prefix) 452 | 453 | logging.info( 454 | "So far the best minimization quantity is %.3f with best maximization quantity of %.5f%%; no faph cutoff is %.2f", 455 | best_minimization_quantity, 456 | (best_maximization_quantity * 100), 457 | best_no_faph_cutoff, 458 | ) 459 | 460 | # Save checkpoint after training 461 | checkpoint.save(file_prefix=checkpoint_prefix) 462 | model.save_weights(os.path.join(config["train_dir"], "last_weights.weights.h5")) 463 | -------------------------------------------------------------------------------- /microwakeword/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Google Research Authors. 3 | # Modifications copyright 2024 Kevin Ahrendt. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Utility functions for operations on Model.""" 18 | import os.path 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | from absl import logging 23 | 24 | from microwakeword.layers import modes, stream, strided_drop 25 | 26 | 27 | def _set_mode(model, mode): 28 | """Set model's inference type and disable training.""" 29 | 30 | def _recursive_set_layer_mode(layer, mode): 31 | if isinstance(layer, tf.keras.layers.Wrapper): 32 | _recursive_set_layer_mode(layer.layer, mode) 33 | 34 | config = layer.get_config() 35 | # for every layer set mode, if it has it 36 | if "mode" in config: 37 | assert isinstance( 38 | layer, 39 | (stream.Stream, strided_drop.StridedDrop, strided_drop.StridedKeep), 40 | ) 41 | layer.mode = mode 42 | # with any mode of inference - training is False 43 | if "training" in config: 44 | layer.training = False 45 | if mode == modes.Modes.NON_STREAM_INFERENCE: 46 | if "unroll" in config: 47 | layer.unroll = True 48 | 49 | for layer in model.layers: 50 | _recursive_set_layer_mode(layer, mode) 51 | return model 52 | 53 | 54 | def _copy_weights(new_model, model): 55 | """Copy weights of trained model to an inference one.""" 56 | 57 | def _same_weights(weight, new_weight): 58 | # Check that weights are the same 59 | # Note that states should be marked as non trainable 60 | return ( 61 | weight.trainable == new_weight.trainable 62 | and weight.shape == new_weight.shape 63 | and weight.name[weight.name.rfind("/") : None] 64 | == new_weight.name[new_weight.name.rfind("/") : None] 65 | ) 66 | 67 | if len(new_model.layers) != len(model.layers): 68 | raise ValueError( 69 | "number of layers in new_model: %d != to layers number in model: %d " 70 | % (len(new_model.layers), len(model.layers)) 71 | ) 72 | 73 | for i in range(len(model.layers)): 74 | layer = model.layers[i] 75 | new_layer = new_model.layers[i] 76 | 77 | # if number of weights in the layers are the 
same 78 | # then we can set weights directly 79 | if len(layer.get_weights()) == len(new_layer.get_weights()): 80 | new_layer.set_weights(layer.get_weights()) 81 | elif layer.weights: 82 | k = 0 # index pointing to weights in the copied model 83 | new_weights = [] 84 | # iterate over weights in the new_model 85 | # and prepare a new_weights list which will 86 | # contain weights from model and weight states from new model 87 | for k_new in range(len(new_layer.get_weights())): 88 | new_weight = new_layer.weights[k_new] 89 | new_weight_values = new_layer.get_weights()[k_new] 90 | same_weights = True 91 | 92 | # if there are weights which are not copied yet 93 | if k < len(layer.get_weights()): 94 | weight = layer.weights[k] 95 | weight_values = layer.get_weights()[k] 96 | if ( 97 | weight.shape != weight_values.shape 98 | or new_weight.shape != new_weight_values.shape 99 | ): 100 | raise ValueError("weights are not listed in order") 101 | 102 | # if there are weights available for copying and they are the same 103 | if _same_weights(weight, new_weight): 104 | new_weights.append(weight_values) 105 | k = k + 1 # go to next weight in model 106 | else: 107 | same_weights = False # weights are different 108 | else: 109 | same_weights = ( 110 | False # all weights are copied, remaining is different 111 | ) 112 | 113 | if not same_weights: 114 | # weight with index k_new is missing in model, 115 | # so we will keep iterating over k_new until find similar weights 116 | new_weights.append(new_weight_values) 117 | 118 | # check that all weights from model are copied to a new_model 119 | if k != len(layer.get_weights()): 120 | raise ValueError( 121 | "trained model has: %d weights, but only %d were copied" 122 | % (len(layer.get_weights()), k) 123 | ) 124 | 125 | # now they should have the same number of weights with matched sizes 126 | # so we can set weights directly 127 | new_layer.set_weights(new_weights) 128 | return new_model 129 | 130 | 131 | def save_model_summary(model, path, file_name="model_summary.txt"): 132 | """Saves model topology/summary in text format. 133 | 134 | Args: 135 | model: Keras model 136 | path: path where to store model summary 137 | file_name: model summary file name 138 | """ 139 | with tf.io.gfile.GFile(os.path.join(path, file_name), "w") as fd: 140 | stringlist = [] 141 | model.summary( 142 | print_fn=lambda x: stringlist.append(x) 143 | ) # pylint: disable=unnecessary-lambda 144 | model_summary = "\n".join(stringlist) 145 | fd.write(model_summary) 146 | 147 | 148 | def convert_to_inference_model(model, input_tensors, mode): 149 | """Convert tf._keras_internal.engine.functional `Model` instance to a streaming inference. 150 | 151 | It will create a new model with new inputs: input_tensors. 152 | All weights will be copied. Internal states for streaming mode will be created 153 | Only tf._keras_internal.engine.functional Keras model is supported! 154 | 155 | Args: 156 | model: Instance of `Model`. 157 | input_tensors: list of input tensors to build the model upon. 158 | mode: is defined by modes.Modes 159 | 160 | Returns: 161 | An instance of streaming inference `Model` reproducing the behavior 162 | of the original model, on top of new inputs tensors, 163 | using copied weights. 
164 | 165 | Raises: 166 | ValueError: in case of invalid `model` argument value or input_tensors 167 | """ 168 | 169 | # scope is introduced for simplifiyng access to weights by names 170 | scope_name = "streaming" 171 | 172 | with tf.name_scope(scope_name): 173 | if not isinstance(model, tf.keras.Model): 174 | raise ValueError( 175 | "Expected `model` argument to be a `Model` instance, got ", model 176 | ) 177 | if isinstance(model, tf.keras.Sequential): 178 | raise ValueError( 179 | "Expected `model` argument " 180 | "to be a functional `Model` instance, " 181 | "got a `Sequential` instance instead:", 182 | model, 183 | ) 184 | model = _set_mode(model, mode) 185 | new_model = tf.keras.models.clone_model(model, input_tensors) 186 | 187 | if mode == modes.Modes.STREAM_INTERNAL_STATE_INFERENCE: 188 | return _copy_weights(new_model, model) 189 | elif mode == modes.Modes.NON_STREAM_INFERENCE: 190 | new_model.set_weights(model.get_weights()) 191 | return new_model 192 | else: 193 | raise ValueError("non supported mode ", mode) 194 | 195 | 196 | def to_streaming_inference(model_non_stream, config, mode): 197 | """Convert non streaming trained model to inference modes. 198 | 199 | Args: 200 | model_non_stream: trained Keras model non streamable 201 | config: dictionary containing microWakeWord training configuration 202 | mode: it supports Non streaming inference or Streaming inference with internal 203 | states 204 | 205 | Returns: 206 | Keras inference model of inference_type 207 | """ 208 | 209 | input_data_shape = modes.get_input_data_shape(config, mode) 210 | 211 | # get input data type and use it for input streaming type 212 | if isinstance(model_non_stream.input, (tuple, list)): 213 | dtype = model_non_stream.input[0].dtype 214 | else: 215 | dtype = model_non_stream.input.dtype 216 | 217 | # For streaming, set the batch size to 1 218 | input_tensors = [ 219 | tf.keras.layers.Input( 220 | shape=input_data_shape, batch_size=1, dtype=dtype, name="input_audio" 221 | ) 222 | ] 223 | 224 | if ( 225 | isinstance(model_non_stream.input, (tuple, list)) 226 | and len(model_non_stream.input) > 1 227 | ): 228 | if len(model_non_stream.input) > 2: 229 | raise ValueError( 230 | "Maximum number of inputs supported is 2 (input_audio and " 231 | "cond_features), but got %d inputs" % len(model_non_stream.input) 232 | ) 233 | 234 | input_tensors.append( 235 | tf.keras.layers.Input( 236 | shape=config["cond_shape"], 237 | batch_size=1, 238 | dtype=model_non_stream.input[1].dtype, 239 | name="cond_features", 240 | ) 241 | ) 242 | 243 | # Input tensors must have the same shape as the original 244 | if isinstance(model_non_stream.input, (tuple, list)): 245 | model_inference = convert_to_inference_model( 246 | model_non_stream, input_tensors, mode 247 | ) 248 | else: 249 | model_inference = convert_to_inference_model( 250 | model_non_stream, input_tensors[0], mode 251 | ) 252 | 253 | return model_inference 254 | 255 | 256 | def model_to_saved( 257 | model_non_stream, 258 | config, 259 | mode=modes.Modes.STREAM_INTERNAL_STATE_INFERENCE, 260 | ): 261 | """Convert Keras model to SavedModel. 262 | 263 | Depending on mode: 264 | 1 Converted inference graph and model will be streaming statefull. 265 | 2 Converted inference graph and model will be non streaming stateless. 
266 | 267 | Args: 268 | model_non_stream: Keras non streamable model 269 | config: dictionary containing microWakeWord training configuration 270 | mode: inference mode it can be streaming with internal state or non 271 | streaming 272 | """ 273 | 274 | if mode not in ( 275 | modes.Modes.STREAM_INTERNAL_STATE_INFERENCE, 276 | modes.Modes.NON_STREAM_INFERENCE, 277 | ): 278 | raise ValueError("mode %s is not supported " % mode) 279 | 280 | if mode == modes.Modes.NON_STREAM_INFERENCE: 281 | model = model_non_stream 282 | else: 283 | # convert non streaming Keras model to Keras streaming model, internal state 284 | model = to_streaming_inference(model_non_stream, config, mode) 285 | 286 | return model 287 | 288 | 289 | def convert_saved_model_to_tflite( 290 | config, audio_processor, path_to_model, folder, fname, quantize=False 291 | ): 292 | """Convert SavedModel to TFLite and optionally quantize it. 293 | 294 | Args: 295 | config: dictionary containing microWakeWord training configuration 296 | audio_processor: microWakeWord FeatureHandler object for retrieving spectrograms 297 | path_to_model: path to SavedModel 298 | folder: folder where converted model will be saved 299 | fname: output filename for TFLite file 300 | quantize: boolean selecting whether to quantize the model 301 | """ 302 | 303 | def representative_dataset_gen(): 304 | sample_fingerprints, _, _ = audio_processor.get_data( 305 | "training", 500, features_length=config["spectrogram_length"] 306 | ) 307 | 308 | sample_fingerprints[0][ 309 | 0, 0 310 | ] = 0.0 # guarantee one pixel is the preprocessor min 311 | sample_fingerprints[0][ 312 | 0, 1 313 | ] = 26.0 # guarantee one pixel is the preprocessor max 314 | 315 | # for spectrogram in sample_fingerprints: 316 | # yield spectrogram 317 | 318 | stride = config["stride"] 319 | 320 | for spectrogram in sample_fingerprints: 321 | assert spectrogram.shape[0] % stride == 0 322 | 323 | for i in range(0, spectrogram.shape[0] - stride, stride): 324 | sample = spectrogram[i : i + stride, :].astype(np.float32) 325 | yield [sample] 326 | 327 | converter = tf.lite.TFLiteConverter.from_saved_model(path_to_model) 328 | converter.optimizations = {tf.lite.Optimize.DEFAULT} 329 | 330 | # Without this flag, the Streaming layer `state` variables are left as float32, 331 | # resulting in Quantize and Dequantize operations before and after every `ReadVariable` 332 | # and `AssignVariable` operation. 333 | converter._experimental_variable_quantization = True 334 | 335 | if quantize: 336 | converter.target_spec.supported_ops = {tf.lite.OpsSet.TFLITE_BUILTINS_INT8} 337 | converter.inference_input_type = tf.int8 338 | converter.inference_output_type = tf.uint8 339 | converter.representative_dataset = tf.lite.RepresentativeDataset( 340 | representative_dataset_gen 341 | ) 342 | 343 | if not os.path.exists(folder): 344 | os.makedirs(folder) 345 | 346 | with open(os.path.join(folder, fname), "wb") as f: 347 | tflite_model = converter.convert() 348 | f.write(tflite_model) 349 | 350 | 351 | def convert_model_saved(model, config, folder, mode): 352 | """Convert model to streaming and non streaming SavedModel. 
353 | 354 | Args: 355 | model: model settings 356 | config: dictionary containing microWakeWord training configuration 357 | folder: folder where converted model will be saved 358 | mode: inference mode 359 | """ 360 | 361 | path_model = os.path.join(config["train_dir"], folder) 362 | if not os.path.exists(path_model): 363 | os.makedirs(path_model) 364 | 365 | # Convert trained model to SavedModel 366 | converted_model = model_to_saved(model, config, mode) 367 | converted_model.summary() 368 | 369 | assert converted_model.input.shape[0] is not None 370 | 371 | # XXX: Using `converted_model.export(path_model)` results in obscure errors during 372 | # quantization, we create an export archive directly instead. 373 | export_archive = tf.keras.export.ExportArchive() 374 | export_archive.track(converted_model) 375 | export_archive.add_endpoint( 376 | name="serve", 377 | fn=converted_model.call, 378 | input_signature=[tf.TensorSpec(shape=converted_model.input.shape, dtype=tf.float32)], 379 | ) 380 | export_archive.write_out(path_model) 381 | 382 | save_model_summary(converted_model, path_model) 383 | 384 | return converted_model 385 | -------------------------------------------------------------------------------- /notebooks/basic_training_notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "r11cNiLqvWC6" 7 | }, 8 | "source": [ 9 | "# Training a microWakeWord Model\n", 10 | "\n", 11 | "This notebook steps you through training a basic microWakeWord model. It is intended as a **starting point** for advanced users. You should use Python 3.10.\n", 12 | "\n", 13 | "**The model generated will most likely not be usable for everyday use; it may be difficult to trigger or falsely activates too frequently. You will most likely have to experiment with many different settings to obtain a decent model!**\n", 14 | "\n", 15 | "In the comment at the start of certain blocks, I note some specific settings to consider modifying.\n", 16 | "\n", 17 | "This runs on Google Colab, but is extremely slow compared to training on a local GPU. If you must use Colab, be sure to Change the runtime type to a GPU. Even then, it still slow!\n", 18 | "\n", 19 | "At the end of this notebook, you will be able to download a tflite file. To use this in ESPHome, you need to write a model manifest JSON file. See the [ESPHome documentation](https://esphome.io/components/micro_wake_word) for the details and the [model repo](https://github.com/esphome/micro-wake-word-models/tree/main/models/v2) for examples." 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": { 26 | "id": "BFf6511E65ff" 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "# Installs microWakeWord. 
Be sure to restart the session after this is finished.\n", 31 | "import platform\n", 32 | "\n", 33 | "if platform.system() == \"Darwin\":\n", 34 | " # `pymicro-features` is installed from a fork to support building on macOS\n", 35 | " !pip install 'git+https://github.com/puddly/pymicro-features@puddly/minimum-cpp-version'\n", 36 | "\n", 37 | "# `audio-metadata` is installed from a fork to unpin `attrs` from a version that breaks Jupyter\n", 38 | "!pip install 'git+https://github.com/whatsnowplaying/audio-metadata@d4ebb238e6a401bb1a5aaaac60c9e2b3cb30929f'\n", 39 | "\n", 40 | "!git clone https://github.com/kahrendt/microWakeWord\n", 41 | "!pip install -e ./microWakeWord" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": { 48 | "id": "dEluu7nL7ywd" 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "# Generates 1 sample of the target word for manual verification.\n", 53 | "\n", 54 | "target_word = 'khum_puter' # Phonetic spellings may produce better samples\n", 55 | "\n", 56 | "import os\n", 57 | "import sys\n", 58 | "import platform\n", 59 | "\n", 60 | "from IPython.display import Audio\n", 61 | "\n", 62 | "if not os.path.exists(\"./piper-sample-generator\"):\n", 63 | " if platform.system() == \"Darwin\":\n", 64 | " !git clone -b mps-support https://github.com/kahrendt/piper-sample-generator\n", 65 | " else:\n", 66 | " !git clone https://github.com/rhasspy/piper-sample-generator\n", 67 | "\n", 68 | " !wget -O piper-sample-generator/models/en_US-libritts_r-medium.pt 'https://github.com/rhasspy/piper-sample-generator/releases/download/v2.0.0/en_US-libritts_r-medium.pt'\n", 69 | "\n", 70 | " # Install system dependencies\n", 71 | " !pip install torch torchaudio piper-phonemize-cross==1.2.1\n", 72 | "\n", 73 | " if \"piper-sample-generator/\" not in sys.path:\n", 74 | " sys.path.append(\"piper-sample-generator/\")\n", 75 | "\n", 76 | "!python3 piper-sample-generator/generate_samples.py \"{target_word}\" \\\n", 77 | "--max-samples 1 \\\n", 78 | "--batch-size 1 \\\n", 79 | "--output-dir generated_samples\n", 80 | "\n", 81 | "Audio(\"generated_samples/0.wav\", autoplay=True)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": { 88 | "id": "-SvGtCCM9akR" 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "# Generates a larger amount of wake word samples.\n", 93 | "# Start here when trying to improve your model.\n", 94 | "# See https://github.com/rhasspy/piper-sample-generator for the full set of\n", 95 | "# parameters. In particular, experiment with noise-scales and noise-scale-ws,\n", 96 | "# generating negative samples similar to the wake word, and generating many more\n", 97 | "# wake word samples, possibly with different phonetic pronunciations.\n", 98 | "\n", 99 | "!python3 piper-sample-generator/generate_samples.py \"{target_word}\" \\\n", 100 | "--max-samples 1000 \\\n", 101 | "--batch-size 100 \\\n", 102 | "--output-dir generated_samples" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": { 109 | "id": "YJRG4Qvo9nXG" 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "# Downloads audio data for augmentation. This can be slow!\n", 114 | "# Borrowed from openWakeWord's automatic_model_training.ipynb, accessed March 4, 2024\n", 115 | "#\n", 116 | "# **Important note!** The data downloaded here has a mixture of difference\n", 117 | "# licenses and usage restrictions. 
As such, any custom models trained with this\n", 118 | "# data should be considered as appropriate for **non-commercial** personal use only.\n", 119 | "\n", 120 | "\n", 121 | "import datasets\n", 122 | "import scipy\n", 123 | "import os\n", 124 | "\n", 125 | "import numpy as np\n", 126 | "\n", 127 | "from pathlib import Path\n", 128 | "from tqdm import tqdm\n", 129 | "\n", 130 | "## Download MIR RIR data\n", 131 | "\n", 132 | "output_dir = \"./mit_rirs\"\n", 133 | "if not os.path.exists(output_dir):\n", 134 | " os.mkdir(output_dir)\n", 135 | " rir_dataset = datasets.load_dataset(\"davidscripka/MIT_environmental_impulse_responses\", split=\"train\", streaming=True)\n", 136 | " # Save clips to 16-bit PCM wav files\n", 137 | " for row in tqdm(rir_dataset):\n", 138 | " name = row['audio']['path'].split('/')[-1]\n", 139 | " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", 140 | "\n", 141 | "## Download noise and background audio\n", 142 | "\n", 143 | "# Audioset Dataset (https://research.google.com/audioset/dataset/index.html)\n", 144 | "# Download one part of the audioset .tar files, extract, and convert to 16khz\n", 145 | "# For full-scale training, it's recommended to download the entire dataset from\n", 146 | "# https://huggingface.co/datasets/agkphysics/AudioSet, and\n", 147 | "# even potentially combine it with other background noise datasets (e.g., FSD50k, Freesound, etc.)\n", 148 | "\n", 149 | "if not os.path.exists(\"audioset\"):\n", 150 | " os.mkdir(\"audioset\")\n", 151 | "\n", 152 | " fname = \"bal_train09.tar\"\n", 153 | " out_dir = f\"audioset/{fname}\"\n", 154 | " link = \"https://huggingface.co/datasets/agkphysics/AudioSet/resolve/main/data/\" + fname\n", 155 | " !wget -O {out_dir} {link}\n", 156 | " !cd audioset && tar -xf bal_train09.tar\n", 157 | "\n", 158 | " output_dir = \"./audioset_16k\"\n", 159 | " if not os.path.exists(output_dir):\n", 160 | " os.mkdir(output_dir)\n", 161 | "\n", 162 | " # Save clips to 16-bit PCM wav files\n", 163 | " audioset_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"audioset/audio\").glob(\"**/*.flac\")]})\n", 164 | " audioset_dataset = audioset_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n", 165 | " for row in tqdm(audioset_dataset):\n", 166 | " name = row['audio']['path'].split('/')[-1].replace(\".flac\", \".wav\")\n", 167 | " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n", 168 | "\n", 169 | "# Free Music Archive dataset\n", 170 | "# https://github.com/mdeff/fma\n", 171 | "# (Third-party mchl914 extra small set)\n", 172 | "\n", 173 | "output_dir = \"./fma\"\n", 174 | "if not os.path.exists(output_dir):\n", 175 | " os.mkdir(output_dir)\n", 176 | " fname = \"fma_xs.zip\"\n", 177 | " link = \"https://huggingface.co/datasets/mchl914/fma_xsmall/resolve/main/\" + fname\n", 178 | " out_dir = f\"fma/{fname}\"\n", 179 | " !wget -O {out_dir} {link}\n", 180 | " !cd {output_dir} && unzip -q {fname}\n", 181 | "\n", 182 | " output_dir = \"./fma_16k\"\n", 183 | " if not os.path.exists(output_dir):\n", 184 | " os.mkdir(output_dir)\n", 185 | "\n", 186 | " # Save clips to 16-bit PCM wav files\n", 187 | " fma_dataset = datasets.Dataset.from_dict({\"audio\": [str(i) for i in Path(\"fma/fma_small\").glob(\"**/*.mp3\")]})\n", 188 | " fma_dataset = fma_dataset.cast_column(\"audio\", datasets.Audio(sampling_rate=16000))\n", 189 | " for row in tqdm(fma_dataset):\n", 190 | " name = 
row['audio']['path'].split('/')[-1].replace(\".mp3\", \".wav\")\n", 191 | " scipy.io.wavfile.write(os.path.join(output_dir, name), 16000, (row['audio']['array']*32767).astype(np.int16))\n" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": { 198 | "id": "XW3bmbI5-JAz" 199 | }, 200 | "outputs": [], 201 | "source": [ 202 | "# Sets up the augmentations.\n", 203 | "# To improve your model, experiment with these settings and use more sources of\n", 204 | "# background clips.\n", 205 | "\n", 206 | "from microwakeword.audio.augmentation import Augmentation\n", 207 | "from microwakeword.audio.clips import Clips\n", 208 | "from microwakeword.audio.spectrograms import SpectrogramGeneration\n", 209 | "\n", 210 | "clips = Clips(input_directory='generated_samples',\n", 211 | " file_pattern='*.wav',\n", 212 | " max_clip_duration_s=None,\n", 213 | " remove_silence=False,\n", 214 | " random_split_seed=10,\n", 215 | " split_count=0.1,\n", 216 | " )\n", 217 | "augmenter = Augmentation(augmentation_duration_s=3.2,\n", 218 | " augmentation_probabilities = {\n", 219 | " \"SevenBandParametricEQ\": 0.1,\n", 220 | " \"TanhDistortion\": 0.1,\n", 221 | " \"PitchShift\": 0.1,\n", 222 | " \"BandStopFilter\": 0.1,\n", 223 | " \"AddColorNoise\": 0.1,\n", 224 | " \"AddBackgroundNoise\": 0.75,\n", 225 | " \"Gain\": 1.0,\n", 226 | " \"RIR\": 0.5,\n", 227 | " },\n", 228 | " impulse_paths = ['mit_rirs'],\n", 229 | " background_paths = ['fma_16k', 'audioset_16k'],\n", 230 | " background_min_snr_db = -5,\n", 231 | " background_max_snr_db = 10,\n", 232 | " min_jitter_s = 0.195,\n", 233 | " max_jitter_s = 0.205,\n", 234 | " )\n" 235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "metadata": { 241 | "id": "V5UsJfKKD1k9" 242 | }, 243 | "outputs": [], 244 | "source": [ 245 | "# Augment a random clip and play it back to verify it works well\n", 246 | "\n", 247 | "from IPython.display import Audio\n", 248 | "from microwakeword.audio.audio_utils import save_clip\n", 249 | "\n", 250 | "random_clip = clips.get_random_clip()\n", 251 | "augmented_clip = augmenter.augment_clip(random_clip)\n", 252 | "save_clip(augmented_clip, 'augmented_clip.wav')\n", 253 | "\n", 254 | "Audio(\"augmented_clip.wav\", autoplay=True)" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "metadata": { 261 | "id": "D7BHcY1mEGbK" 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "# Augment samples and save the training, validation, and testing sets.\n", 266 | "# Validating and testing samples generated the same way can make the model\n", 267 | "# benchmark better than it performs in real-word use. 
Use real samples or TTS\n", 268 | "# samples generated with a different TTS engine to potentially get more accurate\n", 269 | "# benchmarks.\n", 270 | "\n", 271 | "import os\n", 272 | "from mmap_ninja.ragged import RaggedMmap\n", 273 | "\n", 274 | "output_dir = 'generated_augmented_features'\n", 275 | "\n", 276 | "if not os.path.exists(output_dir):\n", 277 | " os.mkdir(output_dir)\n", 278 | "\n", 279 | "splits = [\"training\", \"validation\", \"testing\"]\n", 280 | "for split in splits:\n", 281 | " out_dir = os.path.join(output_dir, split)\n", 282 | " if not os.path.exists(out_dir):\n", 283 | " os.mkdir(out_dir)\n", 284 | "\n", 285 | "\n", 286 | " split_name = \"train\"\n", 287 | " repetition = 2\n", 288 | "\n", 289 | " spectrograms = SpectrogramGeneration(clips=clips,\n", 290 | " augmenter=augmenter,\n", 291 | " slide_frames=10, # Uses the same spectrogram repeatedly, just shifted over by one frame. This simulates the streaming inferences while training/validating in nonstreaming mode.\n", 292 | " step_ms=10,\n", 293 | " )\n", 294 | " if split == \"validation\":\n", 295 | " split_name = \"validation\"\n", 296 | " repetition = 1\n", 297 | " elif split == \"testing\":\n", 298 | " split_name = \"test\"\n", 299 | " repetition = 1\n", 300 | " spectrograms = SpectrogramGeneration(clips=clips,\n", 301 | " augmenter=augmenter,\n", 302 | " slide_frames=1, # The testing set uses the streaming version of the model, so no artificial repetition is necessary\n", 303 | " step_ms=10,\n", 304 | " )\n", 305 | "\n", 306 | " RaggedMmap.from_generator(\n", 307 | " out_dir=os.path.join(out_dir, 'wakeword_mmap'),\n", 308 | " sample_generator=spectrograms.spectrogram_generator(split=split_name, repeat=repetition),\n", 309 | " batch_size=100,\n", 310 | " verbose=True,\n", 311 | " )" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": { 318 | "id": "1pGuJDPyp3ax" 319 | }, 320 | "outputs": [], 321 | "source": [ 322 | "# Downloads pre-generated spectrogram features (made for microWakeWord in\n", 323 | "# particular) for various negative datasets. 
This can be slow!\n", 324 | "\n", 325 | "output_dir = './negative_datasets'\n", 326 | "if not os.path.exists(output_dir):\n", 327 | " os.mkdir(output_dir)\n", 328 | " link_root = \"https://huggingface.co/datasets/kahrendt/microwakeword/resolve/main/\"\n", 329 | " filenames = ['dinner_party.zip', 'dinner_party_eval.zip', 'no_speech.zip', 'speech.zip']\n", 330 | " for fname in filenames:\n", 331 | " link = link_root + fname\n", 332 | "\n", 333 | " zip_path = f\"negative_datasets/{fname}\"\n", 334 | " !wget -O {zip_path} {link}\n", 335 | " !unzip -q {zip_path} -d {output_dir}" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": { 342 | "id": "Ii1A14GsGVQT" 343 | }, 344 | "outputs": [], 345 | "source": [ 346 | "# Save a yaml config that controls the training process.\n", 347 | "# These hyperparameters can make a huge difference in model quality.\n", 348 | "# Experiment with sampling and penalty weights and increasing the number of\n", 349 | "# training steps.\n", 350 | "\n", 351 | "import yaml\n", 352 | "import os\n", 353 | "\n", 354 | "config = {}\n", 355 | "\n", 356 | "config[\"window_step_ms\"] = 10\n", 357 | "\n", 358 | "config[\"train_dir\"] = (\n", 359 | " \"trained_models/wakeword\"\n", 360 | ")\n", 361 | "\n", 362 | "\n", 363 | "# Each features_dir should have at least one of the following folders with this structure:\n", 364 | "# training/\n", 365 | "# ragged_mmap_folders_ending_in_mmap\n", 366 | "# testing/\n", 367 | "# ragged_mmap_folders_ending_in_mmap\n", 368 | "# testing_ambient/\n", 369 | "# ragged_mmap_folders_ending_in_mmap\n", 370 | "# validation/\n", 371 | "# ragged_mmap_folders_ending_in_mmap\n", 372 | "# validation_ambient/\n", 373 | "# ragged_mmap_folders_ending_in_mmap\n", 374 | "#\n", 375 | "# sampling_weight: Weight for choosing a spectrogram from this set in the batch\n", 376 | "# penalty_weight: Penalizing weight for incorrect predictions from this set\n", 377 | "# truth: Boolean indicating whether this set has positive samples or negative samples\n", 378 | "# truncation_strategy: If spectrograms in the set are longer than necessary for training, how they are truncated\n", 379 | "# - random: choose a random portion of the entire spectrogram - useful for long negative samples\n", 380 | "# - truncate_start: remove the start of the spectrogram\n", 381 | "# - truncate_end: remove the end of the spectrogram\n", 382 | "# - split: Split the longer spectrogram into separate spectrograms offset by 100 ms.
Only for ambient sets\n", 383 | "\n", 384 | "config[\"features\"] = [\n", 385 | " {\n", 386 | " \"features_dir\": \"generated_augmented_features\",\n", 387 | " \"sampling_weight\": 2.0,\n", 388 | " \"penalty_weight\": 1.0,\n", 389 | " \"truth\": True,\n", 390 | " \"truncation_strategy\": \"truncate_start\",\n", 391 | " \"type\": \"mmap\",\n", 392 | " },\n", 393 | " {\n", 394 | " \"features_dir\": \"negative_datasets/speech\",\n", 395 | " \"sampling_weight\": 10.0,\n", 396 | " \"penalty_weight\": 1.0,\n", 397 | " \"truth\": False,\n", 398 | " \"truncation_strategy\": \"random\",\n", 399 | " \"type\": \"mmap\",\n", 400 | " },\n", 401 | " {\n", 402 | " \"features_dir\": \"negative_datasets/dinner_party\",\n", 403 | " \"sampling_weight\": 10.0,\n", 404 | " \"penalty_weight\": 1.0,\n", 405 | " \"truth\": False,\n", 406 | " \"truncation_strategy\": \"random\",\n", 407 | " \"type\": \"mmap\",\n", 408 | " },\n", 409 | " {\n", 410 | " \"features_dir\": \"negative_datasets/no_speech\",\n", 411 | " \"sampling_weight\": 5.0,\n", 412 | " \"penalty_weight\": 1.0,\n", 413 | " \"truth\": False,\n", 414 | " \"truncation_strategy\": \"random\",\n", 415 | " \"type\": \"mmap\",\n", 416 | " },\n", 417 | " { # Only used for validation and testing\n", 418 | " \"features_dir\": \"negative_datasets/dinner_party_eval\",\n", 419 | " \"sampling_weight\": 0.0,\n", 420 | " \"penalty_weight\": 1.0,\n", 421 | " \"truth\": False,\n", 422 | " \"truncation_strategy\": \"split\",\n", 423 | " \"type\": \"mmap\",\n", 424 | " },\n", 425 | "]\n", 426 | "\n", 427 | "# Number of training steps in each iteration - various other settings are configured as lists that correspond to different steps\n", 428 | "config[\"training_steps\"] = [10000]\n", 429 | "\n", 430 | "# Penalizing weight for incorrect class predictions - lists that correspond to training steps\n", 431 | "config[\"positive_class_weight\"] = [1]\n", 432 | "config[\"negative_class_weight\"] = [20]\n", 433 | "\n", 434 | "config[\"learning_rates\"] = [\n", 435 | " 0.001,\n", 436 | "] # Learning rates for Adam optimizer - list that corresponds to training steps\n", 437 | "config[\"batch_size\"] = 128\n", 438 | "\n", 439 | "config[\"time_mask_max_size\"] = [\n", 440 | " 0\n", 441 | "] # SpecAugment - list that corresponds to training steps\n", 442 | "config[\"time_mask_count\"] = [0] # SpecAugment - list that corresponds to training steps\n", 443 | "config[\"freq_mask_max_size\"] = [\n", 444 | " 0\n", 445 | "] # SpecAugment - list that corresponds to training steps\n", 446 | "config[\"freq_mask_count\"] = [0] # SpecAugment - list that corresponds to training steps\n", 447 | "\n", 448 | "config[\"eval_step_interval\"] = (\n", 449 | " 500 # Test the validation sets after this many training steps\n", 450 | ")\n", 451 | "config[\"clip_duration_ms\"] = (\n", 452 | " 1500 # Maximum length of wake word that the streaming model will accept\n", 453 | ")\n", 454 | "\n", 455 | "# The best model weights are chosen by first reducing the specified minimization metric below the specified target_minimization\n", 456 | "# Once the target has been met, it chooses the maximum of the maximization metric.
Set 'minimization_metric' to None to only maximize\n", 457 | "# Available metrics:\n", 458 | "# - \"loss\" - cross entropy error on validation set\n", 459 | "# - \"accuracy\" - accuracy of validation set\n", 460 | "# - \"recall\" - recall of validation set\n", 461 | "# - \"precision\" - precision of validation set\n", 462 | "# - \"false_positive_rate\" - false positive rate of validation set\n", 463 | "# - \"false_negative_rate\" - false negative rate of validation set\n", 464 | "# - \"ambient_false_positives\" - count of false positives from the split validation_ambient set\n", 465 | "# - \"ambient_false_positives_per_hour\" - estimated number of false positives per hour on the split validation_ambient set\n", 466 | "config[\"target_minimization\"] = 0.9\n", 467 | "config[\"minimization_metric\"] = None # Set to None to disable\n", 468 | "\n", 469 | "config[\"maximization_metric\"] = \"average_viable_recall\"\n", 470 | "\n", 471 | "with open(os.path.join(\"training_parameters.yaml\"), \"w\") as file:\n", 472 | " documents = yaml.dump(config, file)" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": null, 478 | "metadata": { 479 | "id": "WoEXJBaiC9mf" 480 | }, 481 | "outputs": [], 482 | "source": [ 483 | "# Trains a model. When finished, it will quantize and convert the model to a\n", 484 | "# streaming version suitable for on-device detection.\n", 485 | "# It will resume if stopped, but it will start over at the configured training\n", 486 | "# steps in the yaml file.\n", 487 | "# Change --train 0 to only convert and test the best-weighted model.\n", 488 | "# On Google colab, it doesn't print the mini-batch results, so it may appear\n", 489 | "# stuck for several minutes! Additionally, it is very slow compared to training\n", 490 | "# on a local GPU.\n", 491 | "\n", 492 | "!python -m microwakeword.model_train_eval \\\n", 493 | "--training_config='training_parameters.yaml' \\\n", 494 | "--train 1 \\\n", 495 | "--restore_checkpoint 1 \\\n", 496 | "--test_tf_nonstreaming 0 \\\n", 497 | "--test_tflite_nonstreaming 0 \\\n", 498 | "--test_tflite_nonstreaming_quantized 0 \\\n", 499 | "--test_tflite_streaming 0 \\\n", 500 | "--test_tflite_streaming_quantized 1 \\\n", 501 | "--use_weights \"best_weights\" \\\n", 502 | "mixednet \\\n", 503 | "--pointwise_filters \"64,64,64,64\" \\\n", 504 | "--repeat_in_block \"1, 1, 1, 1\" \\\n", 505 | "--mixconv_kernel_sizes '[5], [7,11], [9,15], [23]' \\\n", 506 | "--residual_connection \"0,0,0,0\" \\\n", 507 | "--first_conv_filters 32 \\\n", 508 | "--first_conv_kernel_size 5 \\\n", 509 | "--stride 3" 510 | ] 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": null, 515 | "metadata": { 516 | "id": "ex_UIWvwtjAN" 517 | }, 518 | "outputs": [], 519 | "source": [ 520 | "# Downloads the tflite model file. To use on the device, you need to write a\n", 521 | "# Model JSON file. See https://esphome.io/components/micro_wake_word for the\n", 522 | "# documentation and\n", 523 | "# https://github.com/esphome/micro-wake-word-models/tree/main/models/v2 for\n", 524 | "# examples. Adjust the probability threshold based on the test results obtained\n", 525 | "# after training is finished. 
You may also need to increase the Tensor arena\n", 526 | "# model size if the model fails to load.\n", 527 | "\n", 528 | "from google.colab import files\n", 529 | "\n", 530 | "files.download(f\"trained_models/wakeword/tflite_stream_state_internal_quant/stream_state_internal_quant.tflite\")" 531 | ] 532 | } 533 | ], 534 | "metadata": { 535 | "accelerator": "GPU", 536 | "colab": { 537 | "gpuType": "T4", 538 | "provenance": [] 539 | }, 540 | "kernelspec": { 541 | "display_name": ".venv", 542 | "language": "python", 543 | "name": "python3" 544 | }, 545 | "language_info": { 546 | "codemirror_mode": { 547 | "name": "ipython", 548 | "version": 3 549 | }, 550 | "file_extension": ".py", 551 | "mimetype": "text/x-python", 552 | "name": "python", 553 | "nbconvert_exporter": "python", 554 | "pygments_lexer": "ipython3", 555 | "version": "3.10.15" 556 | } 557 | }, 558 | "nbformat": 4, 559 | "nbformat_minor": 0 560 | } 561 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "microwakeword" 7 | version = "0.1.0" 8 | authors = [ 9 | { name="Kevin Ahrendt", email="kahrendt@gmail.com" }, 10 | ] 11 | description = "A TensorFlow based wake word detection training framework using synthetic sample generation suitable for certain microcontrollers." 12 | readme = "README.md" 13 | requires-python = ">=3.10" 14 | dynamic = ["dependencies"] 15 | classifiers = [ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: Apache Software License", 18 | "Operating System :: OS Independent", 19 | ] 20 | 21 | [project.urls] 22 | Homepage = "https://github.com/kahrendt/microWakeWord" 23 | Issues = "https://github.com/kahrendt/microWakeWord/issues" 24 | 25 | [tool.black] 26 | target-version = ["py310", "py311", "py312"] 27 | exclude = 'generated' -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r", encoding="utf-8") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="microwakeword", 8 | version="0.1.0", 9 | install_requires=[ 10 | "audiomentations", 11 | "audio_metadata", 12 | "datasets", 13 | "mmap_ninja", 14 | "numpy", 15 | "pymicro-features", 16 | "pyyaml", 17 | "tensorflow>=2.16", 18 | "webrtcvad-wheels", 19 | ], 20 | author="Kevin Ahrendt", 21 | author_email="kahrendt@gmail.com", 22 | description="A TensorFlow based wake word detection training framework using synthetic sample generation suitable for certain microcontrollers.", 23 | long_description=long_description, 24 | long_description_content_type="text/markdown", 25 | url="https://github.com/kahrendt/microWakeWord", 26 | project_urls={ 27 | "Bug Tracker": "https://github.com/kahrendt/microWakeWord/issues", 28 | }, 29 | classifiers=[ 30 | "Programming Language :: Python :: 3", 31 | "License :: OSI Approved :: Apache 2.0 License", 32 | "Operating System :: OS Independent", 33 | ], 34 | packages=setuptools.find_packages(), 35 | include_package_data=True, 36 | python_requires=">=3.10", 37 | ) 38 | --------------------------------------------------------------------------------
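The export helper at the top of this section writes a SavedModel that the training pipeline later turns into the int8 streaming TFLite model the notebook downloads. A minimal sketch of that post-training quantization step with the standard TFLite converter, assuming an already-exported SavedModel directory and an iterable of float32 spectrogram batches for calibration; `quantize_saved_model` and `spectrogram_batches` are illustrative names, and microWakeWord's own conversion code may choose different input and output types:

import numpy as np
import tensorflow as tf

def quantize_saved_model(saved_model_dir, spectrogram_batches, output_path):
    # Convert an exported SavedModel into a fully int8-quantized TFLite flatbuffer.
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    def representative_dataset():
        # The converter runs these batches through the float model to calibrate
        # activation ranges for full-integer quantization.
        for batch in spectrogram_batches:
            yield [np.asarray(batch, dtype=np.float32)]

    converter.representative_dataset = representative_dataset
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    # Assumption: int8 input and output; the real pipeline may use different types.
    converter.inference_input_type = tf.int8
    converter.inference_output_type = tf.int8

    tflite_model = converter.convert()
    with open(output_path, "wb") as f:
        f.write(tflite_model)
    return output_path

Without the representative dataset the converter could only quantize weights; the calibration pass is what allows activations to be quantized as well.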
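Before flashing, the quantized streaming model can be sanity-checked on the host. A rough sketch, assuming the `stream_state_internal_quant.tflite` produced above keeps its streaming state internally between `invoke()` calls and takes a [1, frames_per_step, num_features] block of quantized spectrogram features per step; the shapes and quantization parameters are read from the interpreter rather than hard-coded, and `streaming_probabilities` is an illustrative helper, not a microWakeWord API:

import numpy as np
import tensorflow as tf

def streaming_probabilities(tflite_path, spectrogram):
    # spectrogram: float32 array of shape (num_frames, num_features).
    interpreter = tf.lite.Interpreter(model_path=tflite_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    _, frames_per_step, num_features = input_details["shape"]
    in_scale, in_zero = input_details["quantization"]
    out_scale, out_zero = output_details["quantization"]
    qinfo = np.iinfo(input_details["dtype"])

    probabilities = []
    for start in range(0, spectrogram.shape[0] - frames_per_step + 1, frames_per_step):
        chunk = spectrogram[start : start + frames_per_step].reshape(1, frames_per_step, num_features)
        # Quantize the float features into the model's integer input type.
        quantized = np.clip(np.round(chunk / in_scale + in_zero), qinfo.min, qinfo.max)
        interpreter.set_tensor(input_details["index"], quantized.astype(input_details["dtype"]))
        interpreter.invoke()  # internal streaming state carries over to the next step
        raw = interpreter.get_tensor(output_details["index"])
        # Dequantize the output back to a probability.
        probabilities.append(out_scale * (float(raw.flatten()[0]) - out_zero))
    return probabilities

Feeding the spectrogram of a known positive clip through this loop should show the probability rising above your chosen cutoff near the end of the wake word.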
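The notebook also notes that ESPHome needs a model manifest JSON next to the tflite file. The field names below are modeled on the examples in the linked micro-wake-word-models repository and should be verified against it before use; every value here is a placeholder to adjust from your own test results, in particular `probability_cutoff` and `tensor_arena_size`:

import json

manifest = {
    "type": "micro",
    "wake_word": "khum puter",  # the phrase as spoken, not the phonetic spelling
    "author": "your name",
    "website": "https://example.com",
    "model": "stream_state_internal_quant.tflite",
    "trained_languages": ["en"],
    "version": 2,
    "micro": {
        "probability_cutoff": 0.97,  # raise this if you see false activations
        "sliding_window_size": 5,
        "feature_step_size": 10,
        "tensor_arena_size": 30000,  # increase if the model fails to load on the device
        "minimum_esphome_version": "2024.7.0",  # placeholder; match your ESPHome release
    },
}

with open("khum_puter.json", "w") as file:
    json.dump(manifest, file, indent=2)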